Skip to content

[CS598 DLH] Counterfactual VAE (CF-VAE) for binary healthcare prediction tasks #404

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

gjlee4
Copy link

@gjlee4 gjlee4 commented May 4, 2025

Contributor Info

Name: Sharim Khan, Gabriel Lee
NetID: sharimk2, gjlee4
Paper: Explaining a Machine Learning Decision to Physicians via Counterfactuals
Paper Link: https://arxiv.org/abs/2306.06325

Type of Contribution

Model: New model addition
Task: Counterfactual VAE (CF-VAE) for binary healthcare prediction tasks

High-Level Description

This PR contributes a PyHealth-compatible implementation of the Counterfactual Variational Autoencoder (CFVAE) model, which is specifically designed to generate counterfactuals for binary prediction tasks. Specifically, the model is composed of:

  • A variational autoencoder that encodes input into latent space and reconstructs the input via a decoder
  • A classifier branch to guide the learning of counterfactual task (e.g. flipping a binary decision) from the reconstructed input

The model output should generate a counterfactual based on the input and provided classifier.

This implementation is inspired by the model proposed in:

Explaining a Machine Learning Decision to Physicians via Counterfactuals
Nagesh et al., 2023 (arXiv link)

Testing

To aid in testing the model, we provide the below Google Colab Notebook.
https://colab.research.google.com/drive/1gRaE6QDYfgjhopzEaAgCy45WaEL9ON1D

The unit test is recommended to be run in the notebook. The notebook will handle reproducing the environment, such that the test case can be run as:

!python3 pyhealth/unittests/test_cfvae_mortality_prediction.py
Unit Test Output
2025-05-06 07:32:07,322 - __main__ - INFO - ===== Starting CFVAE Unit Test =====
2025-05-06 07:32:08,709 - numexpr.utils - INFO - NumExpr defaulting to 2 threads.
Memory usage Starting MIMIC4Dataset init: 823.0 MB
2025-05-06 07:32:20,081 - pyhealth.datasets.mimic4 - INFO - Memory usage Starting MIMIC4Dataset init: 823.0 MB
Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: False)
2025-05-06 07:32:20,081 - pyhealth.datasets.mimic4 - INFO - Initializing MIMIC4EHRDataset with tables: ['diagnoses_icd', 'procedures_icd', 'prescriptions', 'labevents'] (dev mode: False)
Using default EHR config: /content/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml
2025-05-06 07:32:20,081 - pyhealth.datasets.mimic4 - INFO - Using default EHR config: /content/PyHealth/pyhealth/datasets/configs/mimic4_ehr.yaml
Memory usage Before initializing mimic4_ehr: 823.0 MB
2025-05-06 07:32:20,081 - pyhealth.datasets.mimic4 - INFO - Memory usage Before initializing mimic4_ehr: 823.0 MB
Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False)
2025-05-06 07:32:20,090 - pyhealth.datasets.base_dataset - INFO - Initializing mimic4_ehr dataset from https://physionet.org/files/mimic-iv-demo/2.2/ (dev mode: False)
Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz
2025-05-06 07:32:20,090 - pyhealth.datasets.base_dataset - INFO - Scanning table: diagnoses_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/diagnoses_icd.csv.gz
Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz
2025-05-06 07:32:21,556 - pyhealth.datasets.base_dataset - INFO - Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz
Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz
2025-05-06 07:32:22,449 - pyhealth.datasets.base_dataset - INFO - Scanning table: procedures_icd from https://physionet.org/files/mimic-iv-demo/2.2/hosp/procedures_icd.csv.gz
Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz
2025-05-06 07:32:23,331 - pyhealth.datasets.base_dataset - INFO - Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz
Scanning table: prescriptions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/prescriptions.csv.gz
2025-05-06 07:32:23,766 - pyhealth.datasets.base_dataset - INFO - Scanning table: prescriptions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/prescriptions.csv.gz
Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz
2025-05-06 07:32:26,186 - pyhealth.datasets.base_dataset - INFO - Scanning table: labevents from https://physionet.org/files/mimic-iv-demo/2.2/hosp/labevents.csv.gz
Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz
2025-05-06 07:32:31,908 - pyhealth.datasets.base_dataset - INFO - Joining with table: https://physionet.org/files/mimic-iv-demo/2.2/hosp/d_labitems.csv.gz
Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz
2025-05-06 07:32:32,846 - pyhealth.datasets.base_dataset - INFO - Scanning table: patients from https://physionet.org/files/mimic-iv-demo/2.2/hosp/patients.csv.gz
Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz
2025-05-06 07:32:33,719 - pyhealth.datasets.base_dataset - INFO - Scanning table: admissions from https://physionet.org/files/mimic-iv-demo/2.2/hosp/admissions.csv.gz
Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz
2025-05-06 07:32:34,201 - pyhealth.datasets.base_dataset - INFO - Scanning table: icustays from https://physionet.org/files/mimic-iv-demo/2.2/icu/icustays.csv.gz
Memory usage After initializing mimic4_ehr: 843.0 MB
2025-05-06 07:32:35,070 - pyhealth.datasets.mimic4 - INFO - Memory usage After initializing mimic4_ehr: 843.0 MB
Memory usage After EHR dataset initialization: 843.0 MB
2025-05-06 07:32:35,070 - pyhealth.datasets.mimic4 - INFO - Memory usage After EHR dataset initialization: 843.0 MB
Memory usage Before combining data: 843.0 MB
2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Memory usage Before combining data: 843.0 MB
Combining data from ehr dataset
2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Combining data from ehr dataset
Creating combined dataframe
2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Creating combined dataframe
Memory usage After combining data: 843.0 MB
2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Memory usage After combining data: 843.0 MB
Memory usage Completed MIMIC4Dataset init: 843.0 MB
2025-05-06 07:32:35,071 - pyhealth.datasets.mimic4 - INFO - Memory usage Completed MIMIC4Dataset init: 843.0 MB
Setting task InHospitalMortalityMIMIC4 for mimic4 base dataset...
2025-05-06 07:32:35,071 - pyhealth.datasets.base_dataset - INFO - Setting task InHospitalMortalityMIMIC4 for mimic4 base dataset...
Collecting global event dataframe...
2025-05-06 07:32:35,071 - pyhealth.datasets.base_dataset - INFO - Collecting global event dataframe...
Collected dataframe with shape: (131557, 44)
2025-05-06 07:32:35,400 - pyhealth.datasets.base_dataset - INFO - Collected dataframe with shape: (131557, 44)
Generating samples with 2 worker(s)...
2025-05-06 07:32:35,401 - pyhealth.datasets.base_dataset - INFO - Generating samples with 2 worker(s)...
Generating samples for InHospitalMortalityMIMIC4
2025-05-06 07:32:35,401 - pyhealth.datasets.base_dataset - INFO - Generating samples for InHospitalMortalityMIMIC4
Label mortality vocab: {0: 0, 1: 1}
2025-05-06 07:32:36,705 - pyhealth.processors.label_processor - INFO - Label mortality vocab: {0: 0, 1: 1}
Processing samples: 100% 216/216 [00:00<00:00, 610.73it/s]
Generated 216 samples for task InHospitalMortalityMIMIC4
2025-05-06 07:32:37,059 - pyhealth.datasets.base_dataset - INFO - Generated 216 samples for task InHospitalMortalityMIMIC4
2025-05-06 07:32:37,064 - __main__ - INFO - ===== Loaded 216 samples. ===== 
2025-05-06 07:32:37,065 - __main__ - INFO - ===== Preprocessing samples (mean over time) =====
2025-05-06 07:32:37,090 - __main__ - INFO - ===== Stage 1: Train the dummy classifier =====
WrappedClassifier(
  (model): DummyClassifier(
    (model): Sequential(
      (0): Linear(in_features=27, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
  )
)
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - WrappedClassifier(
  (model): DummyClassifier(
    (model): Sequential(
      (0): Linear(in_features=27, out_features=64, bias=True)
      (1): ReLU()
      (2): Linear(in_features=64, out_features=1, bias=True)
    )
  )
)
Metrics: ['roc_auc', 'accuracy']
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Metrics: ['roc_auc', 'accuracy']
Device: cuda
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Device: cuda

2025-05-06 07:32:37,351 - pyhealth.trainer - INFO -
Training:
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Training:
Batch size: 32
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
2025-05-06 07:32:37,351 - pyhealth.trainer - INFO - Optimizer params: {'lr': 0.001}
Weight decay: 0.0
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Weight decay: 0.0
Max grad norm: None
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b42355ee310>
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b42355ee310>
Monitor: roc_auc
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Monitor: roc_auc
Monitor criterion: max
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Monitor criterion: max
Epochs: 5
2025-05-06 07:32:37,352 - pyhealth.trainer - INFO - Epochs: 5

2025-05-06 07:32:37,354 - pyhealth.trainer - INFO -
Epoch 0 / 5: 100% 5/5 [00:00<00:00, 7.42it/s]
--- Train epoch-0, step-5 ---
2025-05-06 07:32:38,028 - pyhealth.trainer - INFO - --- Train epoch-0, step-5 ---
loss: 0.6972
2025-05-06 07:32:38,028 - pyhealth.trainer - INFO - loss: 0.6972
Evaluation: 100% 1/1 [00:00<00:00, 792.13it/s]
--- Eval epoch-0, step-5 ---
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - --- Eval epoch-0, step-5 ---
roc_auc: 0.1500
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - roc_auc: 0.1500
accuracy: 0.6190
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - accuracy: 0.6190
loss: 0.6875
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - loss: 0.6875
New best roc_auc score (0.1500) at epoch-0, step-5
2025-05-06 07:32:38,053 - pyhealth.trainer - INFO - New best roc_auc score (0.1500) at epoch-0, step-5

2025-05-06 07:32:38,054 - pyhealth.trainer - INFO -
Epoch 1 / 5: 100% 5/5 [00:00<00:00, 450.82it/s]
--- Train epoch-1, step-10 ---
2025-05-06 07:32:38,066 - pyhealth.trainer - INFO - --- Train epoch-1, step-10 ---
loss: 0.6614
2025-05-06 07:32:38,066 - pyhealth.trainer - INFO - loss: 0.6614
Evaluation: 100% 1/1 [00:00<00:00, 973.38it/s]
--- Eval epoch-1, step-10 ---
2025-05-06 07:32:38,073 - pyhealth.trainer - INFO - --- Eval epoch-1, step-10 ---
roc_auc: 0.3500
2025-05-06 07:32:38,073 - pyhealth.trainer - INFO - roc_auc: 0.3500
accuracy: 0.9048
2025-05-06 07:32:38,074 - pyhealth.trainer - INFO - accuracy: 0.9048
loss: 0.6495
2025-05-06 07:32:38,074 - pyhealth.trainer - INFO - loss: 0.6495
New best roc_auc score (0.3500) at epoch-1, step-10
2025-05-06 07:32:38,074 - pyhealth.trainer - INFO - New best roc_auc score (0.3500) at epoch-1, step-10

2025-05-06 07:32:38,075 - pyhealth.trainer - INFO -
Epoch 2 / 5: 100% 5/5 [00:00<00:00, 404.94it/s]
--- Train epoch-2, step-15 ---
2025-05-06 07:32:38,088 - pyhealth.trainer - INFO - --- Train epoch-2, step-15 ---
loss: 0.6310
2025-05-06 07:32:38,088 - pyhealth.trainer - INFO - loss: 0.6310
Evaluation: 100% 1/1 [00:00<00:00, 872.54it/s]
--- Eval epoch-2, step-15 ---
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - --- Eval epoch-2, step-15 ---
roc_auc: 0.5500
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - roc_auc: 0.5500
accuracy: 0.9524
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 0.6134
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - loss: 0.6134
New best roc_auc score (0.5500) at epoch-2, step-15
2025-05-06 07:32:38,094 - pyhealth.trainer - INFO - New best roc_auc score (0.5500) at epoch-2, step-15

2025-05-06 07:32:38,095 - pyhealth.trainer - INFO -
Epoch 3 / 5: 100% 5/5 [00:00<00:00, 401.70it/s]
--- Train epoch-3, step-20 ---
2025-05-06 07:32:38,108 - pyhealth.trainer - INFO - --- Train epoch-3, step-20 ---
loss: 0.6013
2025-05-06 07:32:38,108 - pyhealth.trainer - INFO - loss: 0.6013
Evaluation: 100% 1/1 [00:00<00:00, 957.60it/s]
--- Eval epoch-3, step-20 ---
2025-05-06 07:32:38,113 - pyhealth.trainer - INFO - --- Eval epoch-3, step-20 ---
roc_auc: 0.6500
2025-05-06 07:32:38,114 - pyhealth.trainer - INFO - roc_auc: 0.6500
accuracy: 0.9524
2025-05-06 07:32:38,114 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 0.5788
2025-05-06 07:32:38,114 - pyhealth.trainer - INFO - loss: 0.5788
New best roc_auc score (0.6500) at epoch-3, step-20
2025-05-06 07:32:38,114 - pyhealth.trainer - INFO - New best roc_auc score (0.6500) at epoch-3, step-20

2025-05-06 07:32:38,115 - pyhealth.trainer - INFO -
Epoch 4 / 5: 100% 5/5 [00:00<00:00, 420.70it/s]
--- Train epoch-4, step-25 ---
2025-05-06 07:32:38,127 - pyhealth.trainer - INFO - --- Train epoch-4, step-25 ---
loss: 0.5689
2025-05-06 07:32:38,127 - pyhealth.trainer - INFO - loss: 0.5689
Evaluation: 100% 1/1 [00:00<00:00, 951.95it/s]
--- Eval epoch-4, step-25 ---
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - --- Eval epoch-4, step-25 ---
roc_auc: 0.7500
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - roc_auc: 0.7500
accuracy: 0.9524
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 0.5457
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - loss: 0.5457
New best roc_auc score (0.7500) at epoch-4, step-25
2025-05-06 07:32:38,133 - pyhealth.trainer - INFO - New best roc_auc score (0.7500) at epoch-4, step-25
Loaded best model
2025-05-06 07:32:38,134 - pyhealth.trainer - INFO - Loaded best model
2025-05-06 07:32:38,138 - main - INFO - ===== Freezing the classifier... =====
2025-05-06 07:32:38,138 - main - INFO - ===== Stage 2: Train CFVAE with frozen classifier =====
CFVAE(
(external_classifier): DummyClassifier(
(model): Sequential(
(0): Linear(in_features=27, out_features=64, bias=True)
(1): ReLU()
(2): Linear(in_features=64, out_features=1, bias=True)
)
)
(enc1): Sequential(
(0): Linear(in_features=27, out_features=64, bias=True)
(1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
)
(enc2): Linear(in_features=64, out_features=64, bias=True)
(dec1): Sequential(
(0): Linear(in_features=34, out_features=64, bias=True)
(1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
)
(dec2): Linear(in_features=64, out_features=27, bias=True)
)
2025-05-06 07:32:38,140 - pyhealth.trainer - INFO - CFVAE(
(external_classifier): DummyClassifier(
(model): Sequential(
(0): Linear(in_features=27, out_features=64, bias=True)
(1): ReLU()
(2): Linear(in_features=64, out_features=1, bias=True)
)
)
(enc1): Sequential(
(0): Linear(in_features=27, out_features=64, bias=True)
(1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
)
(enc2): Linear(in_features=64, out_features=64, bias=True)
(dec1): Sequential(
(0): Linear(in_features=34, out_features=64, bias=True)
(1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
(2): ReLU()
)
(dec2): Linear(in_features=64, out_features=27, bias=True)
)
Metrics: ['roc_auc', 'accuracy']
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Metrics: ['roc_auc', 'accuracy']
Device: cuda
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Device: cuda

2025-05-06 07:32:38,141 - pyhealth.trainer - INFO -
Training:
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Training:
Batch size: 32
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Batch size: 32
Optimizer: <class 'torch.optim.adam.Adam'>
2025-05-06 07:32:38,141 - pyhealth.trainer - INFO - Optimizer: <class 'torch.optim.adam.Adam'>
Optimizer params: {'lr': 0.001}
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Optimizer params: {'lr': 0.001}
Weight decay: 0.0
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Weight decay: 0.0
Max grad norm: None
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Max grad norm: None
Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b42355ee310>
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Val dataloader: <torch.utils.data.dataloader.DataLoader object at 0x7b42355ee310>
Monitor: roc_auc
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Monitor: roc_auc
Monitor criterion: max
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Monitor criterion: max
Epochs: 10
2025-05-06 07:32:38,142 - pyhealth.trainer - INFO - Epochs: 10

2025-05-06 07:32:38,142 - pyhealth.trainer - INFO -
Epoch 0 / 10: 100% 5/5 [00:00<00:00, 15.67it/s]
--- Train epoch-0, step-5 ---
2025-05-06 07:32:38,462 - pyhealth.trainer - INFO - --- Train epoch-0, step-5 ---
loss: 1.4165
2025-05-06 07:32:38,462 - pyhealth.trainer - INFO - loss: 1.4165
Evaluation: 100% 1/1 [00:00<00:00, 538.98it/s]
--- Eval epoch-0, step-5 ---
2025-05-06 07:32:38,468 - pyhealth.trainer - INFO - --- Eval epoch-0, step-5 ---
roc_auc: 0.0000
2025-05-06 07:32:38,468 - pyhealth.trainer - INFO - roc_auc: 0.0000
accuracy: 0.9524
2025-05-06 07:32:38,469 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.3753
2025-05-06 07:32:38,469 - pyhealth.trainer - INFO - loss: 1.3753
New best roc_auc score (0.0000) at epoch-0, step-5
2025-05-06 07:32:38,469 - pyhealth.trainer - INFO - New best roc_auc score (0.0000) at epoch-0, step-5

2025-05-06 07:32:38,471 - pyhealth.trainer - INFO -
Epoch 1 / 10: 100% 5/5 [00:00<00:00, 287.57it/s]
--- Train epoch-1, step-10 ---
2025-05-06 07:32:38,488 - pyhealth.trainer - INFO - --- Train epoch-1, step-10 ---
loss: 1.3389
2025-05-06 07:32:38,488 - pyhealth.trainer - INFO - loss: 1.3389
Evaluation: 100% 1/1 [00:00<00:00, 667.14it/s]
--- Eval epoch-1, step-10 ---
2025-05-06 07:32:38,493 - pyhealth.trainer - INFO - --- Eval epoch-1, step-10 ---
roc_auc: 0.5500
2025-05-06 07:32:38,494 - pyhealth.trainer - INFO - roc_auc: 0.5500
accuracy: 0.9524
2025-05-06 07:32:38,494 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.3141
2025-05-06 07:32:38,494 - pyhealth.trainer - INFO - loss: 1.3141
New best roc_auc score (0.5500) at epoch-1, step-10
2025-05-06 07:32:38,494 - pyhealth.trainer - INFO - New best roc_auc score (0.5500) at epoch-1, step-10

2025-05-06 07:32:38,496 - pyhealth.trainer - INFO -
Epoch 2 / 10: 100% 5/5 [00:00<00:00, 296.64it/s]
--- Train epoch-2, step-15 ---
2025-05-06 07:32:38,513 - pyhealth.trainer - INFO - --- Train epoch-2, step-15 ---
loss: 1.2860
2025-05-06 07:32:38,513 - pyhealth.trainer - INFO - loss: 1.2860
Evaluation: 100% 1/1 [00:00<00:00, 669.05it/s]
--- Eval epoch-2, step-15 ---
2025-05-06 07:32:38,518 - pyhealth.trainer - INFO - --- Eval epoch-2, step-15 ---
roc_auc: 0.5500
2025-05-06 07:32:38,518 - pyhealth.trainer - INFO - roc_auc: 0.5500
accuracy: 0.9524
2025-05-06 07:32:38,519 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2922
2025-05-06 07:32:38,519 - pyhealth.trainer - INFO - loss: 1.2922

2025-05-06 07:32:38,519 - pyhealth.trainer - INFO -
Epoch 3 / 10: 100% 5/5 [00:00<00:00, 304.42it/s]
--- Train epoch-3, step-20 ---
2025-05-06 07:32:38,536 - pyhealth.trainer - INFO - --- Train epoch-3, step-20 ---
loss: 1.2599
2025-05-06 07:32:38,536 - pyhealth.trainer - INFO - loss: 1.2599
Evaluation: 100% 1/1 [00:00<00:00, 619.27it/s]
--- Eval epoch-3, step-20 ---
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - --- Eval epoch-3, step-20 ---
roc_auc: 0.8500
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - roc_auc: 0.8500
accuracy: 0.9524
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2581
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - loss: 1.2581
New best roc_auc score (0.8500) at epoch-3, step-20
2025-05-06 07:32:38,541 - pyhealth.trainer - INFO - New best roc_auc score (0.8500) at epoch-3, step-20

2025-05-06 07:32:38,544 - pyhealth.trainer - INFO -
Epoch 4 / 10: 100% 5/5 [00:00<00:00, 295.69it/s]
--- Train epoch-4, step-25 ---
2025-05-06 07:32:38,561 - pyhealth.trainer - INFO - --- Train epoch-4, step-25 ---
loss: 1.2457
2025-05-06 07:32:38,561 - pyhealth.trainer - INFO - loss: 1.2457
Evaluation: 100% 1/1 [00:00<00:00, 648.07it/s]
--- Eval epoch-4, step-25 ---
2025-05-06 07:32:38,566 - pyhealth.trainer - INFO - --- Eval epoch-4, step-25 ---
roc_auc: 0.9500
2025-05-06 07:32:38,567 - pyhealth.trainer - INFO - roc_auc: 0.9500
accuracy: 0.9524
2025-05-06 07:32:38,567 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2414
2025-05-06 07:32:38,567 - pyhealth.trainer - INFO - loss: 1.2414
New best roc_auc score (0.9500) at epoch-4, step-25
2025-05-06 07:32:38,567 - pyhealth.trainer - INFO - New best roc_auc score (0.9500) at epoch-4, step-25

2025-05-06 07:32:38,569 - pyhealth.trainer - INFO -
Epoch 5 / 10: 100% 5/5 [00:00<00:00, 252.61it/s]
--- Train epoch-5, step-30 ---
2025-05-06 07:32:38,589 - pyhealth.trainer - INFO - --- Train epoch-5, step-30 ---
loss: 1.2208
2025-05-06 07:32:38,589 - pyhealth.trainer - INFO - loss: 1.2208
Evaluation: 100% 1/1 [00:00<00:00, 685.90it/s]
--- Eval epoch-5, step-30 ---
2025-05-06 07:32:38,594 - pyhealth.trainer - INFO - --- Eval epoch-5, step-30 ---
roc_auc: 0.3500
2025-05-06 07:32:38,594 - pyhealth.trainer - INFO - roc_auc: 0.3500
accuracy: 0.9524
2025-05-06 07:32:38,594 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2270
2025-05-06 07:32:38,594 - pyhealth.trainer - INFO - loss: 1.2270

2025-05-06 07:32:38,595 - pyhealth.trainer - INFO -
Epoch 6 / 10: 100% 5/5 [00:00<00:00, 294.96it/s]
--- Train epoch-6, step-35 ---
2025-05-06 07:32:38,612 - pyhealth.trainer - INFO - --- Train epoch-6, step-35 ---
loss: 1.1976
2025-05-06 07:32:38,612 - pyhealth.trainer - INFO - loss: 1.1976
Evaluation: 100% 1/1 [00:00<00:00, 660.42it/s]
--- Eval epoch-6, step-35 ---
2025-05-06 07:32:38,617 - pyhealth.trainer - INFO - --- Eval epoch-6, step-35 ---
roc_auc: 0.6000
2025-05-06 07:32:38,617 - pyhealth.trainer - INFO - roc_auc: 0.6000
accuracy: 0.9524
2025-05-06 07:32:38,617 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.2117
2025-05-06 07:32:38,617 - pyhealth.trainer - INFO - loss: 1.2117

2025-05-06 07:32:38,618 - pyhealth.trainer - INFO -
Epoch 7 / 10: 100% 5/5 [00:00<00:00, 300.62it/s]
--- Train epoch-7, step-40 ---
2025-05-06 07:32:38,635 - pyhealth.trainer - INFO - --- Train epoch-7, step-40 ---
loss: 1.1940
2025-05-06 07:32:38,635 - pyhealth.trainer - INFO - loss: 1.1940
Evaluation: 100% 1/1 [00:00<00:00, 666.93it/s]
--- Eval epoch-7, step-40 ---
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO - --- Eval epoch-7, step-40 ---
roc_auc: 0.1000
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO - roc_auc: 0.1000
accuracy: 0.9524
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.1942
2025-05-06 07:32:38,640 - pyhealth.trainer - INFO - loss: 1.1942

2025-05-06 07:32:38,640 - pyhealth.trainer - INFO -
Epoch 8 / 10: 100% 5/5 [00:00<00:00, 296.75it/s]
--- Train epoch-8, step-45 ---
2025-05-06 07:32:38,657 - pyhealth.trainer - INFO - --- Train epoch-8, step-45 ---
loss: 1.1885
2025-05-06 07:32:38,658 - pyhealth.trainer - INFO - loss: 1.1885
Evaluation: 100% 1/1 [00:00<00:00, 641.43it/s]
--- Eval epoch-8, step-45 ---
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO - --- Eval epoch-8, step-45 ---
roc_auc: 0.5500
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO - roc_auc: 0.5500
accuracy: 0.9524
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.1871
2025-05-06 07:32:38,663 - pyhealth.trainer - INFO - loss: 1.1871

2025-05-06 07:32:38,663 - pyhealth.trainer - INFO -
Epoch 9 / 10: 100% 5/5 [00:00<00:00, 289.75it/s]
--- Train epoch-9, step-50 ---
2025-05-06 07:32:38,681 - pyhealth.trainer - INFO - --- Train epoch-9, step-50 ---
loss: 1.1748
2025-05-06 07:32:38,681 - pyhealth.trainer - INFO - loss: 1.1748
Evaluation: 100% 1/1 [00:00<00:00, 621.56it/s]
--- Eval epoch-9, step-50 ---
2025-05-06 07:32:38,686 - pyhealth.trainer - INFO - --- Eval epoch-9, step-50 ---
roc_auc: 0.6500
2025-05-06 07:32:38,686 - pyhealth.trainer - INFO - roc_auc: 0.6500
accuracy: 0.9524
2025-05-06 07:32:38,687 - pyhealth.trainer - INFO - accuracy: 0.9524
loss: 1.1849
2025-05-06 07:32:38,687 - pyhealth.trainer - INFO - loss: 1.1849
Loaded best model
2025-05-06 07:32:38,687 - pyhealth.trainer - INFO - Loaded best model
2025-05-06 07:32:38,691 - main - INFO - ===== Test set evaluation =====
Evaluation: 100% 2/2 [00:00<00:00, 72.16it/s]
{'roc_auc': np.float64(0.23577235772357724), 'accuracy': 0.9318181818181818, 'loss': 1.5149085521697998}
2025-05-06 07:32:38,722 - main - INFO - ===== Successfully completed CFVAE unit test! =====

Files to Review

  • pyhealth/models/cfvae.py
    • The described CFVAE model by Nagesh et al., 2023, with parametrized binary classifier
  • pyhealth/unittests/test_cfvae_mortality_prediction.py
    • Unit test involving 2 stages: training a dummy classifier, and training the CFVAE with the frozen classifier to guide in counterfactual generation

Extension

An extension is made to parametrize the external classifier to de-couple it as much as possible from the CFVAE itself. This deviates from the authors' original code to tightly couple their original MLP model with the CFVAE.

While the binary classifier needs to be part of the internal CFVAE layers, there are many types of binary prediction tasks in health care (mortality prediction, readmission prediction, etc.). The extra burden should not be placed on researchers to hand-craft each use-case, so we opt for more re-usability.

To do this, we enhance their described model to accept a frozen classifier to be passed as an argument to the CFVAE model, which also more closely aligns with the original paper's description of an external, black-box Binary Prediction (BP) model as seen in Figure 2 of the paper.

The PyHealth model in this PR outputs loss, y_true, and y_prob in alignment with PyHealth’s training pipeline expectations.

Important Design Notes

  • Does NOT inherit from the existing vae.py model, which is designed for image signals and uses 2D convolutional blocks.
  • Intentional design to pass a binary classifier to promote re-usability of the model. This is to extend the authors' model in a way that promotes re-usability.
  • The sparsity constraint described by the authors is a part of the training loop and should be handled separately, not in the model itself.

@gjlee4 gjlee4 marked this pull request as ready for review May 4, 2025 21:15
@leegabriel leegabriel force-pushed the master branch 8 times, most recently from 60cb3d4 to 0ff0c4b Compare May 6, 2025 05:40
@sblittlefield sblittlefield added the Highlight for TAs to highlight label May 13, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Highlight for TAs to highlight
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants