[CS598 DLH] Counterfactual VAE (CF-VAE) for binary healthcare prediction tasks #404
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Contributor Info
Name: Sharim Khan, Gabriel Lee
NetID: sharimk2, gjlee4
Paper: Explaining a Machine Learning Decision to Physicians via Counterfactuals
Paper Link: https://arxiv.org/abs/2306.06325
Type of Contribution
Model: New model addition
Task: Counterfactual VAE (CF-VAE) for binary healthcare prediction tasks
High-Level Description
This PR contributes a PyHealth-compatible implementation of the Counterfactual Variational Autoencoder (CFVAE) model, which is specifically designed to generate counterfactuals for binary prediction tasks. Specifically, the model is composed of:
The model output should generate a counterfactual based on the input and provided classifier.
This implementation is inspired by the model proposed in:
Explaining a Machine Learning Decision to Physicians via Counterfactuals
Nagesh et al., 2023 (arXiv link)
Testing
To aid in testing the model, we provide the below Google Colab Notebook.
https://colab.research.google.com/drive/1gRaE6QDYfgjhopzEaAgCy45WaEL9ON1D
The unit test is recommended to be run in the notebook. The notebook will handle reproducing the environment, such that the test case can be run as:
Unit Test Output
Files to Review
pyhealth/models/cfvae.py
pyhealth/unittests/test_cfvae_mortality_prediction.py
Extension
An extension is made to parametrize the external classifier to de-couple it as much as possible from the CFVAE itself. This deviates from the authors' original code to tightly couple their original MLP model with the CFVAE.
While the binary classifier needs to be part of the internal CFVAE layers, there are many types of binary prediction tasks in health care (mortality prediction, readmission prediction, etc.). The extra burden should not be placed on researchers to hand-craft each use-case, so we opt for more re-usability.
To do this, we enhance their described model to accept a frozen classifier to be passed as an argument to the CFVAE model, which also more closely aligns with the original paper's description of an external, black-box Binary Prediction (BP) model as seen in Figure 2 of the paper.
The PyHealth model in this PR outputs
loss
,y_true
, andy_prob
in alignment with PyHealth’s training pipeline expectations.Important Design Notes
vae.py
model, which is designed for image signals and uses 2D convolutional blocks.