
Commit cd455fd

Add cfvae.py and unit test
1 parent aca976d commit cd455fd

File tree

3 files changed: 366 additions, 0 deletions


pyhealth/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -25,3 +25,4 @@
  from .transformer import Transformer, TransformerLayer
  from .transformers_model import TransformersModel
  from .vae import VAE
+ from .cfvae import CFVAE

pyhealth/models/cfvae.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
# ==============================================================================
# Author(s): Sharim Khan, Gabriel Lee
# NetID(s): sharimk2, gjlee4
# Paper title:
#   Explaining A Machine Learning Decision to Physicians via Counterfactuals
# Paper link: https://arxiv.org/abs/2306.06325
# Description: This file defines the Counterfactual Variational Autoencoder (CFVAE)
#   model, which reconstructs input data while generating counterfactual
#   examples that flip the prediction of a frozen classifier.
# ==============================================================================

from typing import List, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F

from pyhealth.models import BaseModel


class CFVAE(BaseModel):
    """Counterfactual Variational Autoencoder (CFVAE) for binary prediction tasks.

    This is a parametrized version of the CFVAE model described by Nagesh et al.

    The CFVAE learns to reconstruct inputs while generating counterfactual samples
    that flip the output of a fixed, externally trained binary classifier. It combines
    VAE reconstruction and KL divergence losses with a classifier-based loss.

    NOTE: A binary classifier MUST be passed as an argument.
    NOTE: The sparsity constraint should be implemented in the training loop.

    Attributes:
        feature_keys: Feature keys used as inputs.
        label_keys: A list containing the label key.
        mode: Task mode (must be 'binary').
        latent_dim: Latent dimensionality of the VAE.
        external_classifier: Frozen external classifier for guiding counterfactuals.
        enc1: First encoder layer.
        enc2: Layer projecting to latent mean and log-variance.
        dec1: First decoder layer.
        dec2: Layer projecting to reconstructed input space.

    Example:
        cfvae = CFVAE(
            dataset=samples,
            feature_keys=["labs"],
            label_key="mortality",
            mode="binary",
            feat_dim=27,
            latent_dim=32,
            hidden_dim=64,
            external_classifier=frozen_classifier
        )
    """

    def __init__(
        self,
        dataset,
        feature_keys: List[str],
        label_key: str,
        mode: str,
        feat_dim: int,
        latent_dim: int = 32,
        hidden_dim: int = 64,
        external_classifier: nn.Module = None,
    ):
        """
        Initializes the CFVAE model and freezes the external classifier.

        Args:
            dataset: PyHealth-compatible dataset object.
            feature_keys: List of input feature keys.
            label_key: Output label key (must be binary).
            mode: Task mode ('binary' only supported).
            feat_dim: Input feature dimensionality.
            latent_dim: Latent space dimensionality.
            hidden_dim: Hidden layer size in encoder/decoder.
            external_classifier: Frozen binary classifier to guide counterfactuals.
        """
        super().__init__(dataset)
        self.feature_keys = feature_keys
        self.label_keys = [label_key]
        self.mode = mode

        assert mode == "binary", "Only binary classification is supported."
        assert external_classifier is not None, "external_classifier must be provided."

        self.latent_dim = latent_dim
        self.external_classifier = external_classifier.eval()
        for param in self.external_classifier.parameters():
            param.requires_grad = False

        self.enc1 = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        self.enc2 = nn.Linear(hidden_dim, 2 * latent_dim)

        self.dec1 = nn.Sequential(
            nn.Linear(latent_dim + 2, hidden_dim),
            nn.LayerNorm(hidden_dim),
            nn.ReLU()
        )
        self.dec2 = nn.Linear(hidden_dim, feat_dim)

    def reparameterize(
        self, mu: torch.Tensor, log_var: torch.Tensor
    ) -> torch.Tensor:
        """
        Applies the reparameterization trick to sample z from the Gaussian N(mu, sigma^2).

        Args:
            mu: Mean of the latent distribution, shape (B, latent_dim).
            log_var: Log variance of the latent distribution, shape (B, latent_dim).

        Returns:
            z: Sampled latent variable, shape (B, latent_dim).
        """
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def forward(self, **kwargs) -> Dict[str, torch.Tensor]:
        """
        Forward pass for CFVAE: encodes input, reparameterizes, decodes with flipped
        labels, and computes reconstruction, KL, and classifier-based losses.

        Args:
            kwargs: Dict of inputs including:
                - feature_keys[0]: Input tensor (B, feat_dim)
                - label_keys[0]: Ground truth label tensor (B,)

        Returns:
            Dictionary containing:
                - loss: Total training loss (recon + KL + classifier disagreement).
                - y_prob: Classifier output probabilities for reconstructed inputs.
                - y_true: Ground truth labels.
        """
        x = kwargs[self.feature_keys[0]].to(self.device)
        y = kwargs[self.label_keys[0]].to(self.device)

        # Encode inputs
        h = self.enc1(x)
        h = self.enc2(h).view(-1, 2, self.latent_dim)
        mu, log_var = h[:, 0, :], h[:, 1, :]
        z = self.reparameterize(mu, log_var)

        # Flip labels to condition decoder on opposite class (counterfactual)
        y_cf = 1 - y
        y_cf_onehot = F.one_hot(y_cf.view(-1).long(), num_classes=2).float()
        z_cond = torch.cat([z, y_cf_onehot], dim=1)

        h_dec = self.dec1(z_cond)
        x_recon = torch.sigmoid(self.dec2(h_dec))

        # Evaluate external classifier on counterfactual
        with torch.no_grad():
            logits = self.external_classifier(x_recon)

        # Compute losses
        clf_loss = self.get_loss_function()(logits, y)
        recon_loss = F.mse_loss(x_recon, x, reduction="mean")
        kld_loss = -0.5 * torch.mean(
            1 + log_var - mu.pow(2) - log_var.exp()
        )
        total_loss = recon_loss + kld_loss + clf_loss

        return {
            "loss": total_loss,
            "y_prob": self.prepare_y_prob(logits),
            "y_true": y,
        }
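As the class docstring above notes, the sparsity constraint from the paper is left to the training loop rather than enforced inside the model. The sketch below shows one way a caller might add it; it is not part of this commit. The function name `train_cfvae_with_sparsity`, the hyperparameter `sparsity_weight`, and the use of a manual Adam loop in place of `Trainer.train` are all illustrative assumptions; only the model attributes it touches (`enc1`, `enc2`, `dec1`, `dec2`, `reparameterize`, `latent_dim`, `feature_keys`, `label_keys`, `device`) come from cfvae.py.

import torch
import torch.nn.functional as F

def train_cfvae_with_sparsity(model, dataloader, epochs=10, lr=1e-3, sparsity_weight=0.1):
    """Hypothetical loop adding an L1 sparsity term between input and counterfactual."""
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    for _ in range(epochs):
        model.train()
        for batch in dataloader:
            x = batch[model.feature_keys[0]].to(model.device)
            y = batch[model.label_keys[0]].to(model.device)

            # Base CFVAE loss (reconstruction + KL + classifier term), as in forward().
            out = model(**batch)
            loss = out["loss"]

            # Rebuild the counterfactual with the model's own layers so the sparsity
            # term can compare it to the original input. Note: this draws a fresh z;
            # a tighter integration would reuse the forward pass.
            h = model.enc2(model.enc1(x)).view(-1, 2, model.latent_dim)
            z = model.reparameterize(h[:, 0, :], h[:, 1, :])
            y_cf = F.one_hot((1 - y).view(-1).long(), num_classes=2).float()
            x_cf = torch.sigmoid(model.dec2(model.dec1(torch.cat([z, y_cf], dim=1))))

            # L1 sparsity: prefer counterfactuals that change few features.
            loss = loss + sparsity_weight * torch.mean(torch.abs(x_cf - x))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()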
Lines changed: 190 additions & 0 deletions
@@ -0,0 +1,190 @@
# ==============================================================================
# Author(s): Sharim Khan, Gabriel Lee
# NetID(s): sharimk2, gjlee4
# Paper title:
#   Explaining A Machine Learning Decision to Physicians via Counterfactuals
# Paper link: https://arxiv.org/abs/2306.06325
# Description: Test script to train and evaluate a Counterfactual VAE (CFVAE) on
#   MIMIC-IV for mortality prediction using PyHealth, including training
#   a frozen dummy classifier and then CFVAE with that classifier.
# ==============================================================================

import logging
import os
import sys
from typing import Any

import torch
import torch.nn as nn

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Add parent directory to sys.path for relative imports
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(current_dir))
if parent_dir not in sys.path:
    sys.path.insert(0, parent_dir)


def test_cfvae_mortality_prediction_mimic4() -> None:
    """Trains a CFVAE model on MIMIC-IV demo data with a frozen dummy classifier.

    Steps:
        - Load and preprocess MIMIC-IV lab data.
        - Train a binary classifier on in-hospital mortality.
        - Freeze the classifier.
        - Train a CFVAE model to produce counterfactuals.
        - Evaluate CFVAE on test data.
    """
    logger.info("===== Starting CFVAE Unit Test =====")
    from pyhealth.datasets import MIMIC4Dataset
    from pyhealth.tasks import InHospitalMortalityMIMIC4
    from pyhealth.datasets import split_by_sample, get_dataloader
    from pyhealth.trainer import Trainer
    from pyhealth.models import BaseModel, CFVAE

    # Load MIMIC-IV demo dataset
    dataset = MIMIC4Dataset(
        ehr_root="https://physionet.org/files/mimic-iv-demo/2.2/",
        ehr_tables=["diagnoses_icd", "procedures_icd", "prescriptions", "labevents"],
    )

    task = InHospitalMortalityMIMIC4()
    samples = dataset.set_task(task)
    logger.info(f"===== Loaded {len(samples)} samples. =====")

    # Preprocessing: mean over time, normalize across samples
    logger.info("===== Preprocessing samples (mean over time) =====")
    for sample in samples:
        sample["labs"] = torch.mean(sample["labs"], dim=0)

    labs_tensor = torch.stack([s["labs"] for s in samples])
    feature_mean = labs_tensor.mean(dim=0)
    feature_std = labs_tensor.std(dim=0) + 1e-6

    for sample in samples:
        sample["labs"] = (sample["labs"] - feature_mean) / feature_std

    # Split data
    train_dataset, val_dataset, test_dataset = split_by_sample(
        dataset=samples,
        ratios=[0.7, 0.1, 0.2]
    )

    train_dataloader = get_dataloader(train_dataset, batch_size=32, shuffle=True)
    val_dataloader = get_dataloader(val_dataset, batch_size=32, shuffle=False)
    test_dataloader = get_dataloader(test_dataset, batch_size=32, shuffle=False)

    logger.info("===== Stage 1: Train the dummy classifier =====")

    class DummyClassifier(nn.Module):
        """Simple feedforward binary classifier."""

        def __init__(self, input_dim: int = 27, hidden_dim: int = 64):
            """
            Args:
                input_dim: Dimension of input feature vector.
                hidden_dim: Size of hidden layer.
            """
            super().__init__()
            self.model = nn.Sequential(
                nn.Linear(input_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            )

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            """Forward pass.

            Args:
                x: Tensor of shape [batch_size, input_dim].

            Returns:
                Output logits as a tensor of shape [batch_size, 1].
            """
            return self.model(x)

    class WrappedClassifier(BaseModel):
        """Wraps a PyTorch classifier into the PyHealth BaseModel interface."""

        def __init__(self, dataset: Any, model: nn.Module):
            """
            Args:
                dataset: PyHealth dataset object.
                model: PyTorch model to be wrapped.
            """
            super().__init__(dataset)
            self.model = model
            self.mode = self.dataset.output_schema[self.label_keys[0]]

        def forward(self, **kwargs) -> dict:
            """Forward pass and loss computation.

            Args:
                kwargs: Dict containing "labs" and "mortality".

            Returns:
                Dictionary with keys "loss", "y_prob", and "y_true".
            """
            x = kwargs[self.feature_keys[0]].to(self.device)
            y = kwargs[self.label_keys[0]].to(self.device)
            logits = self.model(x)
            loss = self.get_loss_function()(logits, y)
            y_prob = self.prepare_y_prob(logits)
            return {
                "loss": loss,
                "y_prob": y_prob,
                "y_true": y
            }

    clf = DummyClassifier(input_dim=27)
    wrapped_model = WrappedClassifier(dataset=samples, model=clf)

    trainer = Trainer(model=wrapped_model, metrics=["roc_auc", "accuracy"])
    trainer.train(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        epochs=5,
        monitor="roc_auc"
    )

    logger.info("===== Freezing the classifier... =====")
    clf.eval()
    for param in clf.parameters():
        param.requires_grad = False

    logger.info("===== Stage 2: Train CFVAE with frozen classifier =====")

    cfvae_model = CFVAE(
        dataset=samples,
        feature_keys=["labs"],
        label_key="mortality",
        mode="binary",
        feat_dim=27,
        latent_dim=32,
        hidden_dim=64,
        external_classifier=clf
    )

    cfvae_trainer = Trainer(model=cfvae_model, metrics=["roc_auc", "accuracy"])
    cfvae_trainer.train(
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        epochs=10,
        monitor="roc_auc",
        optimizer_params={"lr": 1e-3}
    )

    logger.info("===== Test set evaluation =====")
    print(cfvae_trainer.evaluate(test_dataloader))
    logger.info("===== Successfully completed CFVAE unit test! =====")


if __name__ == "__main__":
    test_cfvae_mortality_prediction_mimic4()
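The test above only reports the frozen classifier's metrics on the reconstructed inputs. For readers who want to look at the counterfactuals themselves, the following sketch shows one way to decode them from a trained CFVAE. It is not part of this commit: the helper name `generate_counterfactuals` is illustrative, and it assumes the `cfvae_model` and `test_dataloader` variables from the test function are in scope; it only calls attributes defined in cfvae.py.

import torch
import torch.nn.functional as F

@torch.no_grad()
def generate_counterfactuals(cfvae_model, batch):
    """Decode counterfactual lab vectors for one batch (illustrative helper)."""
    x = batch["labs"].to(cfvae_model.device)
    y = batch["mortality"].to(cfvae_model.device)

    # Encode each input and sample a latent code.
    h = cfvae_model.enc2(cfvae_model.enc1(x)).view(-1, 2, cfvae_model.latent_dim)
    z = cfvae_model.reparameterize(h[:, 0, :], h[:, 1, :])

    # Condition the decoder on the flipped label to obtain the counterfactual.
    y_cf = F.one_hot((1 - y).view(-1).long(), num_classes=2).float()
    x_cf = torch.sigmoid(cfvae_model.dec2(cfvae_model.dec1(torch.cat([z, y_cf], dim=1))))

    # The per-feature difference suggests which (normalized) labs would need
    # to change to flip the frozen classifier's decision.
    return x_cf, x_cf - x

cfvae_model.eval()
batch = next(iter(test_dataloader))
counterfactuals, deltas = generate_counterfactuals(cfvae_model, batch)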
