
Commit abc5723

zzachw and John Wu authored
Fix issues with BaseDataset.load_table() (#368)
* Fix two issues with `BaseDataset.load_table()`:
  1. Introduced a `scan_csv_gz_or_csv` function to handle both CSV and CSV.gz files.
  2. Added a preprocessing function call to `BaseDataset.preprocess_{table_name}` that allows table-specific customized processing.
* Fix file extension handling in the `scan_csv_gz_or_csv` function: check for the `.gz` suffix and clarify the logging message when falling back to CSV files.
* Edit it so it can load Synthetic API datasets through URL checking and API calls; this way our tutorials should still work.
* Clean up the `scan_csv_gz_or_csv` function.
* Fix datetime parsing in MortalityPredictionMIMIC3.

Co-authored-by: John Wu <[email protected]>
1 parent b907a3c commit abc5723

File tree: 4 files changed (+146, -56 lines)

pyhealth/datasets/base_dataset.py

Lines changed: 105 additions & 41 deletions
@@ -2,9 +2,12 @@
 import os
 from abc import ABC
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
 from typing import Iterator, List, Optional
+from urllib.parse import urlparse, urlunparse

 import polars as pl
+import requests
 from tqdm import tqdm

 from ..data import Patient
@@ -15,6 +18,67 @@
 logger = logging.getLogger(__name__)


+def is_url(path: str) -> bool:
+    """URL detection."""
+    result = urlparse(path)
+    # Both scheme and netloc must be present for a valid URL
+    return all([result.scheme, result.netloc])
+
+
+def clean_path(path: str) -> str:
+    """Clean a path string."""
+    if is_url(path):
+        parsed = urlparse(path)
+        cleaned_path = os.path.normpath(parsed.path)
+        # Rebuild the full URL
+        return urlunparse(parsed._replace(path=cleaned_path))
+    else:
+        # It's a local path — resolve and normalize
+        return str(Path(path).expanduser().resolve())
+
+
+def path_exists(path: str) -> bool:
+    """
+    Check if a path exists.
+
+    If the path is a URL, it will send a HEAD request.
+    If the path is a local file, it will use Path.exists().
+    """
+    if is_url(path):
+        try:
+            response = requests.head(path, timeout=5)
+            return response.status_code == 200
+        except requests.RequestException:
+            return False
+    else:
+        return Path(path).exists()
+
+
+def scan_csv_gz_or_csv(path: str) -> pl.LazyFrame:
+    """
+    Scan a CSV.gz or CSV file and return a LazyFrame.
+
+    It will fall back to the other extension if not found.
+
+    Args:
+        path (str): URL or local path to a .csv or .csv.gz file.
+
+    Returns:
+        pl.LazyFrame: The LazyFrame for the CSV.gz or CSV file.
+    """
+    if path_exists(path):
+        return pl.scan_csv(path, infer_schema=False)
+    # Try the alternative extension
+    if path.endswith(".csv.gz"):
+        alt_path = path[:-3]  # Remove .gz
+    elif path.endswith(".csv"):
+        alt_path = f"{path}.gz"  # Add .gz
+    else:
+        raise FileNotFoundError(f"Path does not have expected extension: {path}")
+    if path_exists(alt_path):
+        logger.info(f"Original path does not exist. Using alternative: {alt_path}")
+        return pl.scan_csv(alt_path, infer_schema=False)
+    raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")
+
+
 class BaseDataset(ABC):
     """Abstract base class for all PyHealth datasets.

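These module-level helpers compose into the new loading behavior: a config may point at either extension, and the loader picks up whichever file is actually present, locally or behind a URL. A minimal sketch, with a hypothetical local folder that only contains the gzipped file:

from pyhealth.datasets.base_dataset import clean_path, scan_csv_gz_or_csv

# Hypothetical demo folder containing ADMISSIONS.csv.gz but not ADMISSIONS.csv
path = clean_path("~/mimic3-demo/ADMISSIONS.csv")  # expands ~ and normalizes the path
lf = scan_csv_gz_or_csv(path)                      # falls back to ADMISSIONS.csv.gz
print(lf.collect_schema().names())

# URLs work the same way: path_exists() issues a HEAD request before scanning, e.g.
# lf = scan_csv_gz_or_csv("https://example.com/mimic3-demo/ADMISSIONS.csv.gz")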
@@ -79,15 +143,13 @@ def collected_global_event_df(self) -> pl.DataFrame:
         if self.dev:
             # Limit the number of patients in dev mode
             logger.info("Dev mode enabled: limiting to 1000 patients")
-            limited_patients = (
-                df.select(pl.col("patient_id"))
-                .unique()
-                .limit(1000)
-            )
+            limited_patients = df.select(pl.col("patient_id")).unique().limit(1000)
             df = df.join(limited_patients, on="patient_id", how="inner")

         self._collected_global_event_df = df.collect()
-        logger.info(f"Collected dataframe with shape: {self._collected_global_event_df.shape}")
+        logger.info(
+            f"Collected dataframe with shape: {self._collected_global_event_df.shape}"
+        )

         return self._collected_global_event_df

@@ -118,36 +180,42 @@ def load_table(self, table_name: str) -> pl.LazyFrame:

         table_cfg = self.config.tables[table_name]
         csv_path = f"{self.root}/{table_cfg.file_path}"
-        # TODO: check if it's zipped or not.
-
-        # TODO: make this work for remote files
-        # if not Path(csv_path).exists():
-        #     raise FileNotFoundError(f"CSV not found: {csv_path}")
+        csv_path = clean_path(csv_path)

         logger.info(f"Scanning table: {table_name} from {csv_path}")
-
-        df = pl.scan_csv(csv_path, infer_schema=False)
-
-        # TODO: this is an ad hoc fix for the MIMIC-III dataset
-        df = df.with_columns([pl.col(col).alias(col.lower()) for col in df.collect_schema().names()])
+        df = scan_csv_gz_or_csv(csv_path)
+
+        # Convert column names to lowercase before calling preprocess_func
+        col_names = df.collect_schema().names()
+        if any(col != col.lower() for col in col_names):
+            logger.warning("Some column names were converted to lowercase")
+            df = df.with_columns([pl.col(col).alias(col.lower()) for col in col_names])
+
+        # Check if there is a preprocessing function for this table
+        preprocess_func = getattr(self, f"preprocess_{table_name}", None)
+        if preprocess_func is not None:
+            logger.info(
+                f"Preprocessing table: {table_name} with {preprocess_func.__name__}"
+            )
+            df = preprocess_func(df)

         # Handle joins
         for join_cfg in table_cfg.join:
             other_csv_path = f"{self.root}/{join_cfg.file_path}"
-            # if not Path(other_csv_path).exists():
-            #     raise FileNotFoundError(
-            #         f"Join CSV not found: {other_csv_path}"
-            #     )
-
-            join_df = pl.scan_csv(other_csv_path, infer_schema=False)
-            join_df = join_df.with_columns([pl.col(col).alias(col.lower()) for col in join_df.collect_schema().names()])
+            other_csv_path = clean_path(other_csv_path)
+            logger.info(f"Joining with table: {other_csv_path}")
+            join_df = scan_csv_gz_or_csv(other_csv_path)
+            join_df = join_df.with_columns(
+                [
+                    pl.col(col).alias(col.lower())
+                    for col in join_df.collect_schema().names()
+                ]
+            )
             join_key = join_cfg.on
             columns = join_cfg.columns
             how = join_cfg.how

-            df = df.join(
-                join_df.select([join_key] + columns), on=join_key, how=how
-            )
+            df = df.join(join_df.select([join_key] + columns), on=join_key, how=how)

         patient_id_col = table_cfg.patient_id
         timestamp_col = table_cfg.timestamp
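With the getattr-based hook above, any BaseDataset subclass can customize a single table simply by defining a method named preprocess_<table_name> that takes and returns a pl.LazyFrame; column names are already lowercased when the hook runs. A sketch with a hypothetical subclass and filter, for illustration only:

import polars as pl

from pyhealth.datasets.base_dataset import BaseDataset


class MyEHRDataset(BaseDataset):  # hypothetical subclass
    def preprocess_labevents(self, df: pl.LazyFrame) -> pl.LazyFrame:
        # Called automatically by load_table("labevents"); drop rows with no numeric value
        return df.filter(pl.col("valuenum").is_not_null())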
@@ -158,10 +226,9 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
         if timestamp_col:
             if isinstance(timestamp_col, list):
                 # Concatenate all timestamp parts in order with no separator
-                combined_timestamp = (
-                    pl.concat_str([pl.col(col) for col in timestamp_col])
-                    .str.strptime(pl.Datetime, format=timestamp_format, strict=True)
-                )
+                combined_timestamp = pl.concat_str(
+                    [pl.col(col) for col in timestamp_col]
+                ).str.strptime(pl.Datetime, format=timestamp_format, strict=True)
                 timestamp_expr = combined_timestamp
             else:
                 # Single timestamp column
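The list branch above concatenates the timestamp parts with no separator and then parses the result with the configured timestamp_format. A standalone polars sketch with made-up date and time parts (the leading space is baked into the time column so the combined string matches the format):

import polars as pl

lf = pl.LazyFrame({"chartdate": ["2130-01-15"], "charttime": [" 08:30:00"]})
ts = pl.concat_str([pl.col("chartdate"), pl.col("charttime")]).str.strptime(
    pl.Datetime, format="%Y-%m-%d %H:%M:%S", strict=True
)
print(lf.select(ts.alias("timestamp")).collect())  # -> 2130-01-15 08:30:00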
@@ -185,8 +252,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:

         # Flatten attribute columns with event_type prefix
         attribute_columns = [
-            pl.col(attr).alias(f"{table_name}/{attr}")
-            for attr in attribute_cols
+            pl.col(attr).alias(f"{table_name}/{attr}") for attr in attribute_cols
         ]

         event_frame = df.select(base_columns + attribute_columns)
@@ -225,9 +291,7 @@ def get_patient(self, patient_id: str) -> Patient:
         assert (
             patient_id in self.unique_patient_ids
         ), f"Patient {patient_id} not found in dataset"
-        df = self.collected_global_event_df.filter(
-            pl.col("patient_id") == patient_id
-        )
+        df = self.collected_global_event_df.filter(pl.col("patient_id") == patient_id)
         return Patient(patient_id=patient_id, data_source=df)

     def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]:
@@ -260,11 +324,9 @@ def default_task(self) -> Optional[BaseTask]:
             Optional[BaseTask]: The default task, if any.
         """
         return None
-
+
     def set_task(
-        self,
-        task: Optional[BaseTask] = None,
-        num_workers: Optional[int] = None
+        self, task: Optional[BaseTask] = None, num_workers: Optional[int] = None
     ) -> SampleDataset:
         """Processes the base dataset to generate the task-specific sample dataset.

@@ -283,7 +345,9 @@ def set_task(
             assert self.default_task is not None, "No default tasks found"
             task = self.default_task

-        logger.info(f"Setting task {task.task_name} for {self.dataset_name} base dataset...")
+        logger.info(
+            f"Setting task {task.task_name} for {self.dataset_name} base dataset..."
+        )

         filtered_global_event_df = task.pre_filter(self.collected_global_event_df)

@@ -298,7 +362,7 @@ def set_task(
         if num_workers == 1:
             for patient in tqdm(
                 self.iter_patients(filtered_global_event_df),
-                desc=f"Generating samples for {task.task_name}"
+                desc=f"Generating samples for {task.task_name}",
             ):
                 samples.extend(task(patient))
         else:

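Taken together, these changes let the existing tutorials keep working against a hosted copy of the demo data: root can be a URL, table files can be gzipped, and dischtime strings with a time component now parse correctly in MortalityPredictionMIMIC3. A rough usage sketch, assuming the usual root/tables constructor arguments and a hypothetical hosting URL:

from pyhealth.datasets import MIMIC3Dataset
from pyhealth.tasks.mortality_prediction import MortalityPredictionMIMIC3

dataset = MIMIC3Dataset(
    root="https://example.com/mimic3-demo/",  # hypothetical URL; a local path works too
    tables=["diagnoses_icd", "prescriptions", "admissions"],
)
samples = dataset.set_task(MortalityPredictionMIMIC3())
print(len(samples))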
pyhealth/datasets/configs/mimic3.yaml

Lines changed: 12 additions & 14 deletions
@@ -1,7 +1,7 @@
 version: "1.4"
 tables:
   patients:
-    file_path: "PATIENTS.csv"
+    file_path: "PATIENTS.csv.gz"
     patient_id: "subject_id"
     timestamp: null
     attributes:
@@ -13,7 +13,7 @@ tables:
       - "expire_flag"

   admissions:
-    file_path: "ADMISSIONS.csv"
+    file_path: "ADMISSIONS.csv.gz"
     patient_id: "subject_id"
     timestamp: "admittime"
     attributes:
@@ -33,7 +33,7 @@ tables:
       - "hospital_expire_flag"

   icustays:
-    file_path: "ICUSTAYS.csv"
+    file_path: "ICUSTAYS.csv.gz"
     patient_id: "subject_id"
     timestamp: "intime"
     attributes:
@@ -44,10 +44,10 @@ tables:
       - "outtime"

   diagnoses_icd:
-    file_path: "DIAGNOSES_ICD.csv"
+    file_path: "DIAGNOSES_ICD.csv.gz"
     patient_id: "subject_id"
     join:
-      - file_path: "ADMISSIONS.csv"
+      - file_path: "ADMISSIONS.csv.gz"
         "on": "hadm_id"
         how: "inner"
         columns:
@@ -58,7 +58,7 @@ tables:
       - "seq_num"

   prescriptions:
-    file_path: "PRESCRIPTIONS.csv"
+    file_path: "PRESCRIPTIONS.csv.gz"
     patient_id: "subject_id"
     timestamp: "startdate"
     attributes:
@@ -76,12 +76,12 @@ tables:
       - "form_unit_disp"
       - "route"
       - "enddate"
-
+
   procedures_icd:
-    file_path: "PROCEDURES_ICD.csv"
+    file_path: "PROCEDURES_ICD.csv.gz"
     patient_id: "subject_id"
     join:
-      - file_path: "ADMISSIONS.csv"
+      - file_path: "ADMISSIONS.csv.gz"
         "on": "hadm_id"
         how: "inner"
         columns:
@@ -92,10 +92,10 @@ tables:
       - "seq_num"

   labevents:
-    file_path: "LABEVENTS.csv"
+    file_path: "LABEVENTS.csv.gz"
     patient_id: "subject_id"
     join:
-      - file_path: "D_LABITEMS.csv"
+      - file_path: "D_LABITEMS.csv.gz"
         "on": "itemid"
         how: "inner"
         columns:
@@ -114,12 +114,10 @@ tables:
       - "flag"

   noteevents:
-    file_path: "NOTEEVENTS.csv"
+    file_path: "NOTEEVENTS.csv.gz"
     patient_id: "subject_id"
     timestamp:
-      - "chartdate"
       - "charttime"
-    timestamp_format: "%Y%m%d%H%M%S"
     attributes:
       - "text"
       - "category"

pyhealth/datasets/mimic3.py

Lines changed: 27 additions & 0 deletions
@@ -3,6 +3,8 @@
 from pathlib import Path
 from typing import List, Optional

+import polars as pl
+
 from .base_dataset import BaseDataset

 logger = logging.getLogger(__name__)
@@ -58,3 +60,28 @@ def __init__(
             **kwargs
         )
         return
+
+    def preprocess_noteevents(self, df: pl.LazyFrame) -> pl.LazyFrame:
+        """
+        Table-specific preprocess function which will be called by BaseDataset.load_table().
+
+        Preprocesses the noteevents table by ensuring that the charttime column
+        is populated. If charttime is null, it uses chartdate with a default
+        time of 00:00:00.
+
+        See: https://mimic.mit.edu/docs/iii/tables/noteevents/#chartdate-charttime-storetime.
+
+        Args:
+            df (pl.LazyFrame): The input dataframe containing noteevents data.
+
+        Returns:
+            pl.LazyFrame: The processed dataframe with updated charttime
+                values.
+        """
+        df = df.with_columns(
+            pl.when(pl.col("charttime").is_null())
+            .then(pl.col("chartdate") + pl.lit(" 00:00:00"))
+            .otherwise(pl.col("charttime"))
+            .alias("charttime")
+        )
+        return df
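The effect of the new preprocess_noteevents hook can be checked on a toy LazyFrame; the rows below are fabricated for illustration:

import polars as pl

df = pl.LazyFrame({
    "chartdate": ["2130-01-15", "2130-01-16"],
    "charttime": [None, "2130-01-16 09:45:00"],
})
df = df.with_columns(
    pl.when(pl.col("charttime").is_null())
    .then(pl.col("chartdate") + pl.lit(" 00:00:00"))   # fill nulls from chartdate
    .otherwise(pl.col("charttime"))
    .alias("charttime")
)
print(df.collect())
# charttime -> ["2130-01-15 00:00:00", "2130-01-16 09:45:00"]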

pyhealth/tasks/mortality_prediction.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 from datetime import datetime
 from typing import Any, Dict, List, Optional
+
 from .base_task import BaseTask


@@ -41,7 +42,7 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
             try:
                 # Check the type and convert if necessary
                 if isinstance(visit.dischtime, str):
-                    discharge_time = datetime.strptime(visit.dischtime, "%Y-%m-%d")
+                    discharge_time = datetime.strptime(visit.dischtime, "%Y-%m-%d %H:%M:%S")
                 else:
                     discharge_time = visit.dischtime
             except (ValueError, AttributeError):
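MIMIC-III dischtime values carry a time component, so the old "%Y-%m-%d" format raised ValueError ("unconverted data remains") and fell through to the except branch. A minimal check with a fabricated value:

from datetime import datetime

dischtime = "2130-01-16 14:00:00"  # fabricated MIMIC-III-style timestamp
# Old format fails:
# datetime.strptime(dischtime, "%Y-%m-%d")  -> ValueError: unconverted data remains
discharge_time = datetime.strptime(dischtime, "%Y-%m-%d %H:%M:%S")
print(discharge_time)  # 2130-01-16 14:00:00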
