This repository provides code and data needed to reproduce the experiments presented in our paper: Counterfactual Explanations for Continuous Action Reinforcement Learning (IJCAI 2025).
The project involves two case studies:
- A diabetes management scenario based on a Type 1 Diabetes (T1D) patient simulator
- A physics-based Lunar Lander environment
Each case study includes steps to:
- Train baseline policies using PPO
- Generate counterfactual explanations using TD3
- Post-process results for quantitative analysis and visualization
The Diabetes case study corresponds to the experiments reported in Section 5.1 of the paper. It includes:
- Training a PPO agent to generate original traces
- Running TD3-based counterfactual experiments
- Post-processing experimental results
We use Stable-Baselines3 to train a PPO agent that generates the original traces and serves as the baseline model.
conda create -n CF_diabetic_train_ppo python=3.8
conda activate CF_diabetic_train_ppo
pip install stable-baselines3==1.7.0
pip install gym==0.21.0
pip install torch==2.0.0
pip install pandas==1.5.3
pip install tensorboard
Replace the default packages with the provided ones:
cd yourpath/.conda/envs/CF_diabetic_train_ppo/lib/python3.8/site-packages
Copy the files provided in train_ppo/package into this directory, replacing the defaults.
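If you prefer to script this step, here is a minimal sketch; the site-packages path and the layout of train_ppo/package are assumptions, so adjust them to your setup.

```python
# Sketch of the package-replacement step (illustrative paths only).
import shutil
from pathlib import Path

site_packages = Path.home() / ".conda/envs/CF_diabetic_train_ppo/lib/python3.8/site-packages"
provided = Path("train_ppo/package")

for item in provided.iterdir():
    dest = site_packages / item.name
    if item.is_dir():
        shutil.copytree(item, dest, dirs_exist_ok=True)  # overwrite existing modules
    else:
        shutil.copy2(item, dest)
```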
- Single Environment (Adult Patient #7):
python diabetic_case_train_ppo.py \
-arg_patient_type adult \
-arg_patient_id 7 \
-arg_cuda 0 \
-arg_train_step 100000 \
-arg_callback_step 100000
- Multi Environment (Train: Patients 4–6, Test: Patients 7–9):
python diabetic_case_train_ppo_generalize.py \
-arg_exp_id 1 \
-arg_cuda 0 \
-arg_train_step_each_env 3000 \
-arg_callback_step 3000 \
-arg_train_round 179 \
-arg_lr 0.0001 \
-arg_test_epochs_each_env 1 \
-arg_max_test_time_each_env 3000
The trained models are saved under train_ppo/trained_model.
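For orientation, the scripts above essentially perform standard Stable-Baselines3 PPO training. The sketch below shows the basic pattern; the environment id and save path are placeholders, and the actual scripts construct the T1D patient simulator environment and register their own callbacks.

```python
# Rough sketch of the PPO baseline training (env id and paths are hypothetical).
import gym
from stable_baselines3 import PPO

env = gym.make("simglucose-adult7-v0")            # placeholder env id for Adult Patient #7
model = PPO("MlpPolicy", env, learning_rate=1e-4, verbose=1,
            tensorboard_log="./ppo_tensorboard/")
model.learn(total_timesteps=100_000)              # cf. -arg_train_step above
model.save("train_ppo/trained_model/ppo_adult7")  # placeholder filename
```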
conda create -n CF_diabetic python=3.8
conda activate CF_diabetic
pip install stable-baselines3==1.7.0
pip install gym==0.21.0
pip install torch==2.0.0
pip install pandas==1.5.3
Replace packages as before:
cd yourpath/.conda/envs/CF_diabetic/lib/python3.8/site-packages
Copy the files from train_td3/package into this directory, replacing the defaults.
The code automatically filters out trace segments with:
- Maximum reward (20)
- Initial CGM < 65
- All-zero actions
The remaining segments are grouped by initial CGM into:
- (65, 100]
- (100, 150]
- (150, 260]
Then, 6 segments per group are selected for training and 6 for testing.
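The repository's code performs this selection automatically; the sketch below only illustrates the logic, with hypothetical file and column names.

```python
# Illustrative version of the segment filtering and selection described above.
import pandas as pd

segments = pd.read_csv("original_trace_segments.csv")   # hypothetical file

# Drop segments that are already optimal or degenerate.
keep = segments[
    (segments["reward"] < 20)            # exclude segments already at the max reward (20)
    & (segments["initial_cgm"] >= 65)    # exclude initial CGM < 65
    & (segments["action_sum"] > 0)       # exclude all-zero action segments
]

# Group by initial CGM range and pick 6 segments per group for training and 6 for testing.
cgm_bins = [65, 100, 150, 260]
train = keep.groupby(pd.cut(keep["initial_cgm"], cgm_bins)).sample(n=6, random_state=0)
rest = keep.drop(train.index)
test = rest.groupby(pd.cut(rest["initial_cgm"], cgm_bins)).sample(n=6, random_state=0)
```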
- P1 (Baseline Reward):
# Single-environment
sbatch run_diabetic_exp1.sh
# Multi-environment
sbatch run_diabetic_exp2.sh
- P2-base (State Threshold):
# Single-environment
sbatch run_diabetic_exp3_State_RP2.sh with diabetic_experiments_exp1_state_RP2_UET11.txt
# Multi-environment
sbatch run_diabetic_exp3_State_RP2.sh with diabetic_experiments_exp3_state_RP2_UET6.txt
- P2-fixed (State Threshold + Fixed Action):
# Single-environment
sbatch run_diabetic_exp3_State_RP3.sh with diabetic_experiments_exp1_state_RP3_UET11.txt
# Multi-environment
sbatch run_diabetic_exp3_State_RP3.sh with diabetic_experiments_exp3_state_RP3_UET6.txt
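The SLURM scripts above launch the repository's TD3-based counterfactual training. For orientation only, this is the generic Stable-Baselines3 TD3 API they build on; the counterfactual objectives and constraints from the paper live in the repository's own code and are not shown here.

```python
# Generic Stable-Baselines3 TD3 usage (env id is a placeholder).
import gym
from stable_baselines3 import TD3

env = gym.make("simglucose-adult7-v0")   # hypothetical env id
model = TD3("MlpPolicy", env, learning_rate=1e-4, verbose=1)
model.learn(total_timesteps=50_000)
```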
To run on newly generated original traces:
- Use script_diabetic_exp_multiple_trace.py to generate parameter .txt files
- Apply them in the above .sh SLURM scripts
Example scripts are available in:
train_td3/code/P1/generate_new_trace/
To organize the TD3 results for post-processing (a scripted sketch follows this list):
- Rename trial_1 to trial_X (X = 1 to 7)
- Move trial_2 ~ trial_7 into td3_cf_results inside trial_1
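A minimal scripted version of this re-organisation, assuming the seven run folders live under a hypothetical td3_results directory and are not yet named trial_*:

```python
# Sketch of the folder re-organisation above (paths are illustrative).
import shutil
from pathlib import Path

results = Path("td3_results")                        # hypothetical parent of the 7 runs
runs = sorted(p for p in results.iterdir() if p.is_dir())
trials = [run.rename(results / f"trial_{i}") for i, run in enumerate(runs, start=1)]

target = results / "trial_1" / "td3_cf_results"
target.mkdir(parents=True, exist_ok=True)
for trial in trials[1:]:                             # move trial_2 ... trial_7
    shutil.move(str(trial), str(target))
```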
Edit data_postprocess.py:
case_name = 'diabetic' # Line 3059
Then run the block under:
# basic postprocess to get all metric data
Ensure experiment parameters and file paths are consistent with the .txt configs.
Key output files and columns (see the reading sketch below):
- all_metric_aveg_best_baseline_test_0_RP1.csv: P1, no constraints
- all_metric_aveg_best_baseline_test_100_RP2.csv: P2-base, threshold 100
- all_metric_aveg_best_baseline_test_100-0.03_RP3.csv: P2-fixed, threshold 100 & fixed action 0.03
- The P_hr column corresponds to rho_plus
- The compare_count_perc column corresponds to rho_adv
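For example, reading one of the metric files and mapping its columns to the paper's notation (file and column names as listed above):

```python
import pandas as pd

df = pd.read_csv("all_metric_aveg_best_baseline_test_0_RP1.csv")
rho_plus = df["P_hr"]                  # rho_plus in the paper
rho_adv = df["compare_count_perc"]     # rho_adv in the paper
```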
- Create a folder across_rp
- Move all *_RP*.csv results into it
- Run the code under # draw final learning curve (a plotting sketch follows this list)
- Set metric_name='rho_plus' for Figure 2
- Set metric_name='rho_adv' for Figure 3
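An illustrative learning-curve plot over the *_RP*.csv files collected in across_rp/. The repository's "# draw final learning curve" block produces the actual paper figures; the x-axis and file layout here are assumptions.

```python
from pathlib import Path
import pandas as pd
import matplotlib.pyplot as plt

column_for = {"rho_plus": "P_hr", "rho_adv": "compare_count_perc"}
metric_name = "rho_plus"               # 'rho_plus' -> Figure 2, 'rho_adv' -> Figure 3

for csv_file in sorted(Path("across_rp").glob("*_RP*.csv")):
    df = pd.read_csv(csv_file)
    plt.plot(df.index, df[column_for[metric_name]], label=csv_file.stem)

plt.xlabel("checkpoint")
plt.ylabel(metric_name)
plt.legend()
plt.savefig(f"learning_curve_{metric_name}.png")
```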
Final data and figures are located in:
Supplementary/code_and_data/data_postprocess/Diabetes_case_study/across_rp/final_data
The Lunar Lander case study corresponds to the experiments reported in Section 5.2 of the paper.
conda create -n CF_LunarLander_train_ppo_generalize python=3.8
conda activate CF_LunarLander_train_ppo_generalize
pip install stable-baselines3==1.7.0
pip install gym==0.21.0
pip install torch==2.0.0
pip install pandas==1.5.3
pip install tensorboard
Replace packages:
cd yourpath/.conda/envs/CF_LunarLander_train_ppo_generalize/lib/python3.8/site-packages
Copy the files from train_ppo/package into this directory, replacing the defaults.
- Single Environment (gravity = -9):
python openai_case_train_ppo_generalize.py \
-arg_exp_id 1 \
-arg_cuda 0 \
-arg_train_step_each_env 500 \
-arg_callback_step 500 \
-arg_train_round 5 \
-arg_lr 0.0001 \
-arg_test_epochs_each_env 1 \
-arg_max_test_time_each_env 1800 \
-arg_if_train_personalize 1 \
-arg_assigned_gravity -9.0
- Multi Environment (Train: g = -11, -9, -5; Test: g = -10, -8, -6):
python openai_case_train_ppo_generalize.py \
-arg_exp_id 1 \
-arg_cuda 0 \
-arg_train_step_each_env 500 \
-arg_callback_step 500 \
-arg_train_round 3 \
-arg_lr 0.0001 \
-arg_test_epochs_each_env 1 \
-arg_max_test_time_each_env 1800
The trained models are saved under train_ppo/trained_model.
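A sketch of loading a trained baseline and rolling out one original trace in LunarLanderContinuous. The model path is a placeholder, and how gravity is set depends on the modified gym files shipped in train_ppo/package (stock gym 0.21 does not expose gravity as an argument).

```python
import gym
from stable_baselines3 import PPO

env = gym.make("LunarLanderContinuous-v2")
model = PPO.load("train_ppo/trained_model/ppo_lunarlander", env=env)  # placeholder path

obs = env.reset()
done = False
while not done:
    action, _ = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)
```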
conda create -n CF_lunarlander python=3.8
conda activate CF_lunarlander
pip install stable-baselines3==1.7.0
conda install swig
pip install box2d-py
pip install gym==0.21.0
pip install Box2D
pip install torch==2.0.0
pip install pandas==1.5.3
pip install tensorflow==2.13.0
Replace packages:
cd yourpath/.conda/envs/CF_lunarlander/lib/python3.8/site-packages
Copy the files from train_td3/package into this directory, replacing the defaults.
- P1:
# Single-environment
sbatch run_LL_exp1_double_test.sh
# Multi-environment
sbatch run_LL_exp3_double_test.sh
- P2-base:
# Single-environment
sbatch run_LL_RP2_exp1_S_double_test.sh
# Multi-environment
sbatch run_LL_RP2_exp3_S_double_test.sh
- P2-fixed:
# Single-environment
sbatch run_LL_RP3_exp1_S_double_test.sh
# Multi-environment
sbatch run_LL_RP3_exp3_S_double_test.sh
Use script_LL_exp.py to generate new .txt configs, and apply them in the SLURM scripts above.
Example scripts:
train_td3/code/P1/generate_new_trace/
Follow the same steps as in the Diabetes Case Study.
Set case_name = 'lunar_lander' in data_postprocess.py, and execute the block under:
# basic postprocess to get all metric data
Example result files:
all_metric_aveg_best_baseline_test_0.0_RP1.csv
all_metric_aveg_best_baseline_test_0.18_RP2.csv
all_metric_aveg_best_baseline_test_0.18-0_0_RP3.csv
Move all results into across_rp/ and run the block under # draw final learning curve:
- Set metric_name='rho_plus' for Figure 4
- Set metric_name='rho_adv' for Figure 5
Final data and figures:
data_postprocess/lunar_lander/across_rp/final_data
This repository was originally developed as part of the research project:
Counterfactual Explanations for Continuous Action Reinforcement Learning
Presented at the 34th International Joint Conference on Artificial Intelligence (IJCAI), 2025
If you use this code or build upon it, please cite the original work:
@inproceedings{dong2025counterfactual,
title={Counterfactual Explanations for Continuous Action Reinforcement Learning},
author={Shuyang Dong and Shangtong Zhang and Lu Feng},
booktitle={Proceedings of the 34th International Joint Conference on Artificial Intelligence (IJCAI)},
year={2025}
}