import os
from abc import ABC
from concurrent.futures import ThreadPoolExecutor, as_completed
+ from pathlib import Path
from typing import Iterator, List, Optional
+ from urllib.parse import urlparse, urlunparse

import polars as pl
+ import requests
from tqdm import tqdm

from ..data import Patient

logger = logging.getLogger(__name__)


+ def is_url(path: str) -> bool:
+     """Return True if the given path string is a URL (has both a scheme and a netloc)."""
+     result = urlparse(path)
+     # Both scheme and netloc must be present for a valid URL
+     return all([result.scheme, result.netloc])
+
+
+ def clean_path(path: str) -> str:
+     """Normalize a path string, handling both URLs and local filesystem paths."""
+     if is_url(path):
+         parsed = urlparse(path)
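+         # os.path.normpath collapses duplicate slashes and ".." segments in the
+         # URL's path component; scheme, host, and query are left unchanged.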
+         cleaned_path = os.path.normpath(parsed.path)
+         # Rebuild the full URL
+         return urlunparse(parsed._replace(path=cleaned_path))
+     else:
+         # It's a local path: resolve and normalize it
+         return str(Path(path).expanduser().resolve())
+
+
+ def path_exists(path: str) -> bool:
+     """
+     Check if a path exists.
+
+     If the path is a URL, send a HEAD request.
+     If the path is a local file, use Path.exists().
+     """
+     if is_url(path):
+         try:
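+             # A HEAD request avoids downloading the file; only a 200 response counts
+             # as existing, so redirects or auth-protected URLs are treated as missing.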
+             response = requests.head(path, timeout=5)
+             return response.status_code == 200
+         except requests.RequestException:
+             return False
+     else:
+         return Path(path).exists()
+
+
+ def scan_csv_gz_or_csv(path: str) -> pl.LazyFrame:
+     """
+     Scan a .csv.gz or .csv file and return a LazyFrame, falling back to the
+     other extension if the given one is not found.
+
+     Args:
+         path (str): URL or local path to a .csv or .csv.gz file.
+
+     Returns:
+         pl.LazyFrame: The LazyFrame for the CSV.gz or CSV file.
+     """
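+     # pl.scan_csv is lazy: it builds a query plan here and does not load the data yet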
+     if path_exists(path):
+         return pl.scan_csv(path, infer_schema=False)
+     # Try the alternative extension
+     if path.endswith(".csv.gz"):
+         alt_path = path[:-3]  # Remove .gz
+     elif path.endswith(".csv"):
+         alt_path = f"{path}.gz"  # Add .gz
+     else:
+         raise FileNotFoundError(f"Path does not have expected extension: {path}")
+     if path_exists(alt_path):
+         logger.info(f"Original path does not exist. Using alternative: {alt_path}")
+         return pl.scan_csv(alt_path, infer_schema=False)
+     raise FileNotFoundError(f"Neither path exists: {path} or {alt_path}")
+
+
class BaseDataset(ABC):
    """Abstract base class for all PyHealth datasets.
@@ -79,15 +143,13 @@ def collected_global_event_df(self) -> pl.DataFrame:
        if self.dev:
            # Limit the number of patients in dev mode
            logger.info("Dev mode enabled: limiting to 1000 patients")
-             limited_patients = (
-                 df.select(pl.col("patient_id"))
-                 .unique()
-                 .limit(1000)
-             )
+             limited_patients = df.select(pl.col("patient_id")).unique().limit(1000)
            df = df.join(limited_patients, on="patient_id", how="inner")

        self._collected_global_event_df = df.collect()
-         logger.info(f"Collected dataframe with shape: {self._collected_global_event_df.shape}")
+         logger.info(
+             f"Collected dataframe with shape: {self._collected_global_event_df.shape}"
+         )

        return self._collected_global_event_df
@@ -118,36 +180,42 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
        table_cfg = self.config.tables[table_name]
        csv_path = f"{self.root}/{table_cfg.file_path}"
-         # TODO: check if it's zipped or not.
-
-         # TODO: make this work for remote files
-         # if not Path(csv_path).exists():
-         #     raise FileNotFoundError(f"CSV not found: {csv_path}")
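+         # clean_path handles both local filesystem paths and URLs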
+         csv_path = clean_path(csv_path)

        logger.info(f"Scanning table: {table_name} from {csv_path}")
-
-         df = pl.scan_csv(csv_path, infer_schema=False)
-
-         # TODO: this is an ad hoc fix for the MIMIC-III dataset
-         df = df.with_columns([pl.col(col).alias(col.lower()) for col in df.collect_schema().names()])
+         df = scan_csv_gz_or_csv(csv_path)
+
+         # Convert column names to lowercase before calling preprocess_func
+         col_names = df.collect_schema().names()
+         if any(col != col.lower() for col in col_names):
+             logger.warning("Some column names were converted to lowercase")
+             df = df.with_columns([pl.col(col).alias(col.lower()) for col in col_names])
+
+         # Check if there is a preprocessing function for this table
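+         # Subclasses may define a preprocess_<table_name> method; if present, it is
+         # applied to the LazyFrame before any joins are performed.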
+         preprocess_func = getattr(self, f"preprocess_{table_name}", None)
+         if preprocess_func is not None:
+             logger.info(
+                 f"Preprocessing table: {table_name} with {preprocess_func.__name__}"
+             )
+             df = preprocess_func(df)

        # Handle joins
        for join_cfg in table_cfg.join:
            other_csv_path = f"{self.root}/{join_cfg.file_path}"
-             # if not Path(other_csv_path).exists():
-             #     raise FileNotFoundError(
-             #         f"Join CSV not found: {other_csv_path}"
-             #     )
-
-             join_df = pl.scan_csv(other_csv_path, infer_schema=False)
-             join_df = join_df.with_columns([pl.col(col).alias(col.lower()) for col in join_df.collect_schema().names()])
+             other_csv_path = clean_path(other_csv_path)
+             logger.info(f"Joining with table: {other_csv_path}")
+             join_df = scan_csv_gz_or_csv(other_csv_path)
+             join_df = join_df.with_columns(
+                 [
+                     pl.col(col).alias(col.lower())
+                     for col in join_df.collect_schema().names()
+                 ]
+             )
            join_key = join_cfg.on
            columns = join_cfg.columns
            how = join_cfg.how

-             df = df.join(
-                 join_df.select([join_key] + columns), on=join_key, how=how
-             )
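+             # Select only the join key plus the configured columns from the joined table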
+             df = df.join(join_df.select([join_key] + columns), on=join_key, how=how)

        patient_id_col = table_cfg.patient_id
        timestamp_col = table_cfg.timestamp
@@ -158,10 +226,9 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
        if timestamp_col:
            if isinstance(timestamp_col, list):
                # Concatenate all timestamp parts in order with no separator
-                 combined_timestamp = (
-                     pl.concat_str([pl.col(col) for col in timestamp_col])
-                     .str.strptime(pl.Datetime, format=timestamp_format, strict=True)
-                 )
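+                 # strict=True: strptime raises if any value fails to parse with timestamp_format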
+                 combined_timestamp = pl.concat_str(
+                     [pl.col(col) for col in timestamp_col]
+                 ).str.strptime(pl.Datetime, format=timestamp_format, strict=True)
                timestamp_expr = combined_timestamp
            else:
                # Single timestamp column
@@ -185,8 +252,7 @@ def load_table(self, table_name: str) -> pl.LazyFrame:

        # Flatten attribute columns with event_type prefix
        attribute_columns = [
-             pl.col(attr).alias(f"{table_name}/{attr}")
-             for attr in attribute_cols
+             pl.col(attr).alias(f"{table_name}/{attr}") for attr in attribute_cols
        ]

        event_frame = df.select(base_columns + attribute_columns)
@@ -225,9 +291,7 @@ def get_patient(self, patient_id: str) -> Patient:
        assert (
            patient_id in self.unique_patient_ids
        ), f"Patient {patient_id} not found in dataset"
-         df = self.collected_global_event_df.filter(
-             pl.col("patient_id") == patient_id
-         )
+         df = self.collected_global_event_df.filter(pl.col("patient_id") == patient_id)
        return Patient(patient_id=patient_id, data_source=df)

    def iter_patients(self, df: Optional[pl.LazyFrame] = None) -> Iterator[Patient]:
@@ -260,11 +324,9 @@ def default_task(self) -> Optional[BaseTask]:
            Optional[BaseTask]: The default task, if any.
        """
        return None
-
+
    def set_task(
-         self,
-         task: Optional[BaseTask] = None,
-         num_workers: Optional[int] = None
+         self, task: Optional[BaseTask] = None, num_workers: Optional[int] = None
    ) -> SampleDataset:
        """Processes the base dataset to generate the task-specific sample dataset.
@@ -283,7 +345,9 @@ def set_task(
            assert self.default_task is not None, "No default tasks found"
            task = self.default_task

-         logger.info(f"Setting task {task.task_name} for {self.dataset_name} base dataset...")
+         logger.info(
+             f"Setting task {task.task_name} for {self.dataset_name} base dataset..."
+         )

        filtered_global_event_df = task.pre_filter(self.collected_global_event_df)
@@ -298,7 +362,7 @@ def set_task(
        if num_workers == 1:
            for patient in tqdm(
                self.iter_patients(filtered_global_event_df),
-                 desc=f"Generating samples for {task.task_name}"
+                 desc=f"Generating samples for {task.task_name}",
            ):
                samples.extend(task(patient))
        else: