
Commit 9305a6f

2-3× Speedup for BaseDataset.set_task() (#378)
* Enhance Patient class with optimized filtering methods
  - Add fast time range filtering via binary search on sorted timestamps
  - Add efficient event type filtering using pre-built index lookups
  - Reduce timestamp precision from microseconds to milliseconds
  - Set default num_workers=1 in set_task for better memory control
  - Remove unused dev flag from child MIMIC4 dataset classes
* Refactor InHospitalMortalityMIMIC4 and Readmission30DaysMIMIC4 for improved performance and clarity
* Fix bug in get_events to ensure event_type is only asserted when filters are provided
1 parent 73a2172 · commit 9305a6f

File tree: 5 files changed, +151 −96 lines changed

pyhealth/data/data.py

Lines changed: 61 additions & 17 deletions
```diff
@@ -1,7 +1,10 @@
+import operator
 from dataclasses import dataclass, field
 from datetime import datetime
+from functools import reduce
 from typing import Dict, List, Mapping, Optional, Union
 
+import numpy as np
 import polars as pl
 
 
@@ -91,7 +94,8 @@ class Patient:
 
     Attributes:
         patient_id (str): Unique patient identifier.
-        data_source (pl.DataFrame): DataFrame containing all events.
+        data_source (pl.DataFrame): DataFrame containing all events, sorted by timestamp.
+        event_type_partitions (Dict[str, pl.DataFrame]): Dictionary mapping event types to their respective DataFrame partitions.
     """
 
     def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None:
@@ -104,6 +108,42 @@ def __init__(self, patient_id: str, data_source: pl.DataFrame) -> None:
         """
         self.patient_id = patient_id
         self.data_source = data_source.sort("timestamp")
+        self.event_type_partitions = self.data_source.partition_by("event_type", maintain_order=True, as_dict=True)
+
+    def _filter_by_time_range_regular(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame:
+        """Regular filtering by time. Time complexity: O(n)."""
+        if start is not None:
+            df = df.filter(pl.col("timestamp") >= start)
+        if end is not None:
+            df = df.filter(pl.col("timestamp") <= end)
+        return df
+
+    def _filter_by_time_range_fast(self, df: pl.DataFrame, start: Optional[datetime], end: Optional[datetime]) -> pl.DataFrame:
+        """Fast filtering by time using binary search on sorted timestamps. Time complexity: O(log n)."""
+        if start is None and end is None:
+            return df
+        df = df.filter(pl.col("timestamp").is_not_null())
+        ts_col = df["timestamp"].to_numpy()
+        start_idx = 0
+        end_idx = len(ts_col)
+        if start is not None:
+            start_idx = np.searchsorted(ts_col, start, side="left")
+        if end is not None:
+            end_idx = np.searchsorted(ts_col, end, side="right")
+        return df.slice(start_idx, end_idx - start_idx)
+
+    def _filter_by_event_type_regular(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame:
+        """Regular filtering by event type. Time complexity: O(n)."""
+        if event_type:
+            df = df.filter(pl.col("event_type") == event_type)
+        return df
+
+    def _filter_by_event_type_fast(self, df: pl.DataFrame, event_type: Optional[str]) -> pl.DataFrame:
+        """Fast filtering by event type using pre-built event type index. Time complexity: O(1)."""
+        if event_type:
+            return self.event_type_partitions.get((event_type,), df[:0])
+        else:
+            return df
 
     def get_events(
         self,
```
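The two fast paths boil down to a one-time `partition_by` for O(1) event-type lookup and `numpy.searchsorted` on the already-sorted timestamp column for O(log n) time slicing. Here is a minimal standalone sketch of the same idea on toy data (the frame and values are illustrative, not PyHealth's actual schema):

```python
from datetime import datetime

import numpy as np
import polars as pl

# Toy event frame, sorted once by timestamp (as Patient.__init__ does).
df = pl.DataFrame({
    "event_type": ["admissions", "labevents", "labevents", "labevents"],
    "timestamp": [
        datetime(2020, 1, 1),
        datetime(2020, 1, 2),
        datetime(2020, 1, 3),
        datetime(2020, 1, 4),
    ],
}).sort("timestamp")

# O(1) event-type lookup: partition once, then fetch by key.
# partition_by(..., as_dict=True) keys are tuples, hence ("labevents",).
partitions = df.partition_by("event_type", maintain_order=True, as_dict=True)
labs = partitions.get(("labevents",), df[:0])  # empty frame when the type is absent

# O(log n) time-range slice: binary search on the sorted timestamps
# replaces a full-scan df.filter(...).
ts = labs["timestamp"].to_numpy()
lo = np.searchsorted(ts, datetime(2020, 1, 2), side="left")
hi = np.searchsorted(ts, datetime(2020, 1, 3), side="right")
print(labs.slice(lo, hi - lo))  # rows with 2020-01-02 <= timestamp <= 2020-01-03
```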
```diff
@@ -129,37 +169,41 @@ def get_events(
             Union[pl.DataFrame, List[Event]]: Filtered events as a DataFrame
                 or a list of Event objects.
         """
-        df = self.data_source
-        if event_type:
-            df = df.filter(pl.col("event_type") == event_type)
-        if start:
-            df = df.filter(pl.col("timestamp") >= start)
-        if end:
-            df = df.filter(pl.col("timestamp") <= end)
+        # faster filtering (by default)
+        df = self._filter_by_event_type_fast(self.data_source, event_type)
+        df = self._filter_by_time_range_fast(df, start, end)
 
-        filters = filters or []
-        for filt in filters:
+        # regular filtering (commented out by default)
+        # df = self._filter_by_event_type_regular(self.data_source, event_type)
+        # df = self._filter_by_time_range_regular(df, start, end)
+
+        if filters:
             assert event_type is not None, "event_type must be provided if filters are provided"
+        else:
+            filters = []
+        exprs = []
+        for filt in filters:
             if not (isinstance(filt, tuple) and len(filt) == 3):
                 raise ValueError(f"Invalid filter format: {filt} (must be tuple of (attr, op, value))")
             attr, op, val = filt
             col_expr = pl.col(f"{event_type}/{attr}")
             # Build operator expression
             if op == "==":
-                expr = col_expr == val
+                exprs.append(col_expr == val)
             elif op == "!=":
-                expr = col_expr != val
+                exprs.append(col_expr != val)
             elif op == "<":
-                expr = col_expr < val
+                exprs.append(col_expr < val)
             elif op == "<=":
-                expr = col_expr <= val
+                exprs.append(col_expr <= val)
             elif op == ">":
-                expr = col_expr > val
+                exprs.append(col_expr > val)
             elif op == ">=":
-                expr = col_expr >= val
+                exprs.append(col_expr >= val)
             else:
                 raise ValueError(f"Unsupported operator: {op} in filter {filt}")
-            df = df.filter(expr)
+        if exprs:
+            df = df.filter(reduce(operator.and_, exprs))
         if return_df:
             return df
         return [Event.from_dict(d) for d in df.to_dicts()]
```
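This hunk also fixes the assertion so that `event_type` is only required when filters are actually passed. Collecting the expressions and AND-ing them with `functools.reduce` lets Polars apply one combined predicate instead of one `filter` pass per condition. A condensed sketch of the same pattern; the `OPS` dict here replaces the commit's if/elif chain and is illustrative only:

```python
import operator
from functools import reduce

import polars as pl

df = pl.DataFrame({
    "labevents/itemid": ["50912", "50971", "50912"],
    "labevents/valuenum": [1.2, 4.5, 0.9],
})

# Map operator strings to callables that build Polars expressions.
OPS = {"==": operator.eq, "!=": operator.ne, "<": operator.lt,
       "<=": operator.le, ">": operator.gt, ">=": operator.ge}

filters = [("itemid", "==", "50912"), ("valuenum", "<", 1.0)]
exprs = [OPS[op](pl.col(f"labevents/{attr}"), val) for attr, op, val in filters]

# One combined filter pass (expr1 & expr2 & ...) instead of one per expression.
print(df.filter(reduce(operator.and_, exprs)))
```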

pyhealth/datasets/base_dataset.py

Lines changed: 30 additions & 13 deletions
```diff
@@ -97,7 +97,7 @@ def __init__(
         tables: List[str],
         dataset_name: Optional[str] = None,
         config_path: Optional[str] = None,
-        dev: bool = False,  # Added dev parameter
+        dev: bool = False,
     ):
         """Initializes the BaseDataset.
 
@@ -115,7 +115,7 @@ def __init__(
         self.tables = tables
         self.dataset_name = dataset_name or self.__class__.__name__
         self.config = load_yaml_config(config_path)
-        self.dev = dev  # Store dev mode flag
+        self.dev = dev
 
         logger.info(
             f"Initializing {self.dataset_name} dataset from {self.root} (dev mode: {self.dev})"
@@ -147,6 +147,21 @@ def collected_global_event_df(self) -> pl.DataFrame:
             df = df.join(limited_patients, on="patient_id", how="inner")
 
         self._collected_global_event_df = df.collect()
+
+        # Profile the Polars collect() operation (commented out by default)
+        # self._collected_global_event_df, profile = df.profile()
+        # profile = profile.with_columns([
+        #     (pl.col("end") - pl.col("start")).alias("duration"),
+        # ])
+        # profile = profile.with_columns([
+        #     (pl.col("duration") / profile["duration"].sum() * 100).alias("percentage")
+        # ])
+        # profile = profile.sort("duration", descending=True)
+        # with pl.Config() as cfg:
+        #     cfg.set_tbl_rows(-1)
+        #     cfg.set_fmt_str_lengths(200)
+        #     print(profile)
+
         logger.info(
             f"Collected dataframe with shape: {self._collected_global_event_df.shape}"
         )
```
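For reference, the commented-out block relies on Polars' `LazyFrame.profile()`, which executes the query and returns the materialized frame together with per-node timings. A self-contained sketch of the same pattern on toy data:

```python
import polars as pl

lf = (
    pl.LazyFrame({"a": range(1_000_000)})
    .with_columns((pl.col("a") * 2).alias("b"))
    .filter(pl.col("b") % 3 == 0)
)

# profile() runs the query and returns (result, timings); the timings frame
# has node/start/end columns with times in microseconds.
result, prof = lf.profile()
prof = prof.with_columns(
    (pl.col("end") - pl.col("start")).alias("duration")
).sort("duration", descending=True)
print(prof)
```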
```diff
@@ -247,7 +262,8 @@ def load_table(self, table_name: str) -> pl.LazyFrame:
         base_columns = [
             patient_id_expr.alias("patient_id"),
             pl.lit(table_name).cast(pl.Utf8).alias("event_type"),
-            timestamp_expr.cast(pl.Datetime).alias("timestamp"),
+            # ms should be sufficient for most cases
+            timestamp_expr.cast(pl.Datetime(time_unit="ms")).alias("timestamp"),
         ]
 
         # Flatten attribute columns with event_type prefix
```
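The precision change is a one-line cast. The commit comment only says ms "should be sufficient"; note that Polars stores `Datetime` as a 64-bit integer regardless of unit, so the gain is a uniform, coarser unit across tables rather than smaller storage. A sketch (toy column, assumed parseable string format):

```python
import polars as pl

df = pl.DataFrame({"timestamp": ["2020-01-01 12:34:56.789123"]})

# Sub-millisecond detail is dropped; all tables then share one time unit.
df = df.with_columns(
    pl.col("timestamp").str.to_datetime().cast(pl.Datetime(time_unit="ms"))
)
print(df.schema)  # {'timestamp': Datetime(time_unit='ms', time_zone=None)}
```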
```diff
@@ -326,14 +342,15 @@ def default_task(self) -> Optional[BaseTask]:
         return None
 
     def set_task(
-        self, task: Optional[BaseTask] = None, num_workers: Optional[int] = None
+        self, task: Optional[BaseTask] = None, num_workers: int = 1
     ) -> SampleDataset:
         """Processes the base dataset to generate the task-specific sample dataset.
 
         Args:
             task (Optional[BaseTask]): The task to set. Uses default task if None.
-            num_workers (Optional[int]): Number of workers for parallel processing.
-                Use None to use all available cores (max 32). Use 1 for single-threaded.
+            num_workers (int): Number of workers for multi-threading. Default is 1.
+                This is because the task function is usually CPU-bound. And using
+                multi-threading may not speed up the task function.
 
         Returns:
             SampleDataset: The generated sample dataset.
@@ -351,26 +368,26 @@ def set_task(
 
         filtered_global_event_df = task.pre_filter(self.collected_global_event_df)
 
-        # Determine number of workers
-        if num_workers is None:
-            num_workers = min(8, os.cpu_count())
-
         logger.info(f"Generating samples with {num_workers} worker(s)...")
 
         samples = []
 
         if num_workers == 1:
+            # single-threading (by default)
             for patient in tqdm(
                 self.iter_patients(filtered_global_event_df),
-                desc=f"Generating samples for {task.task_name}",
+                total=filtered_global_event_df["patient_id"].n_unique(),
+                desc=f"Generating samples for {task.task_name} with 1 worker",
+                smoothing=0,
             ):
                 samples.extend(task(patient))
         else:
-            logger.info(f"Generating samples for {task.task_name}")
+            # multi-threading (not recommended)
+            logger.info(f"Generating samples for {task.task_name} with {num_workers} workers")
             patients = list(self.iter_patients(filtered_global_event_df))
             with ThreadPoolExecutor(max_workers=num_workers) as executor:
                 futures = [executor.submit(task, patient) for patient in patients]
-                for future in as_completed(futures):
+                for future in tqdm(as_completed(futures), total=len(futures), desc=f"Collecting samples for {task.task_name} from {num_workers} workers"):
                     samples.extend(future.result())
 
         sample_dataset = SampleDataset(
```
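In practice the new default means `set_task` runs single-threaded unless a caller asks otherwise. A hypothetical usage sketch; the dataset path and table list are placeholders, and the constructor keywords follow the names visible in the MIMIC4Dataset diff below:

```python
from pyhealth.datasets import MIMIC4Dataset
from pyhealth.tasks import InHospitalMortalityMIMIC4

# Placeholder root and tables, assumed keyword names.
dataset = MIMIC4Dataset(
    ehr_root="/path/to/mimic-iv",
    ehr_tables=["patients", "admissions", "labevents"],
)

# num_workers now defaults to 1: task callables are CPU-bound, so Python
# threads mostly contend for the GIL, and one worker also bounds memory.
sample_dataset = dataset.set_task(InHospitalMortalityMIMIC4())

# Threads can still be requested explicitly (not recommended per the diff):
# sample_dataset = dataset.set_task(InHospitalMortalityMIMIC4(), num_workers=4)
```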

pyhealth/datasets/mimic4.py

Lines changed: 5 additions & 6 deletions
```diff
@@ -210,7 +210,7 @@ def __init__(
         note_config_path: Optional[str] = None,
         cxr_config_path: Optional[str] = None,
         dataset_name: str = "mimic4",
-        dev: bool = False,  # Added dev parameter
+        dev: bool = False,
     ):
         log_memory_usage("Starting MIMIC4Dataset init")
 
@@ -220,8 +220,10 @@ def __init__(
         self.root = None
         self.tables = None
         self.config = None
-        self.dev = dev  # Store dev mode flag
-
+        # Dev flag is only used in the MIMIC4Dataset class
+        # to ensure the same set of patients are used for all sub-datasets.
+        self.dev = dev
+
         # We need at least one root directory
         if not any([ehr_root, note_root, cxr_root]):
             raise ValueError("At least one root directory must be provided")
@@ -238,7 +240,6 @@ def __init__(
                 root=ehr_root,
                 tables=ehr_tables,
                 config_path=ehr_config_path,
-                dev=dev  # Pass dev mode flag
             )
             log_memory_usage("After EHR dataset initialization")
 
@@ -249,7 +250,6 @@ def __init__(
                 root=note_root,
                 tables=note_tables,
                 config_path=note_config_path,
-                dev=dev  # Pass dev mode flag
             )
             log_memory_usage("After Note dataset initialization")
 
@@ -260,7 +260,6 @@
                 root=cxr_root,
                 tables=cxr_tables,
                 config_path=cxr_config_path,
-                dev=dev  # Pass dev mode flag
             )
             log_memory_usage("After CXR dataset initialization")
```

pyhealth/tasks/in_hospital_mortality_mimic4.py

Lines changed: 18 additions & 20 deletions
```diff
@@ -1,5 +1,5 @@
 from datetime import datetime, timedelta
-from typing import Any, Dict, List, ClassVar
+from typing import Any, ClassVar, Dict, List
 
 import polars as pl
 
@@ -8,11 +8,16 @@
 
 class InHospitalMortalityMIMIC4(BaseTask):
     """Task for predicting in-hospital mortality using MIMIC-IV dataset.
-
+
+    This task leverages lab results to predict the likelihood of in-hospital
+    mortality.
+
     Attributes:
         task_name (str): The name of the task.
-        input_schema (Dict[str, str]): The input schema for the task.
-        output_schema (Dict[str, str]): The output schema for the task.
+        input_schema (Dict[str, str]): The schema for input data, which includes:
+            - labs: A timeseries of lab results.
+        output_schema (Dict[str, str]): The schema for output data, which includes:
+            - mortality: A binary indicator of mortality.
     """
     task_name: str = "InHospitalMortalityMIMIC4"
     input_schema: Dict[str, str] = {"labs": "timeseries"}
@@ -33,7 +38,7 @@ class InHospitalMortalityMIMIC4(BaseTask):
             "Phosphate": ["50970"],
         },
     }
-
+
     # Create flat list of all lab items for use in the function
     LABITEMS: ClassVar[List[str]] = [
         item for category in LAB_CATEGORIES.values()
@@ -42,25 +47,16 @@ class InHospitalMortalityMIMIC4(BaseTask):
     ]
 
     def __call__(self, patient: Any) -> List[Dict[str, Any]]:
-        """Processes a single patient for the in-hospital mortality prediction task.
-
-        Args:
-            patient (Any): A Patient object containing patient data.
-
-        Returns:
-            List[Dict[str, Any]]: A list of samples, each sample is a dict with patient_id,
-                admission_id, labs, and mortality as keys.
-        """
         input_window_hours = 48
         samples = []
-
+
         demographics = patient.get_events(event_type="patients")
         assert len(demographics) == 1
         demographics = demographics[0]
        anchor_age = int(demographics.anchor_age)
         if anchor_age < 18:
             return []
-
+
         admissions = patient.get_events(event_type="admissions")
         for admission in admissions:
             admission_dischtime = datetime.strptime(admission.dischtime, "%Y-%m-%d %H:%M:%S")
@@ -95,7 +91,9 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
             labevents_df = labevents_df.pivot(
                 index="timestamp",
                 columns="labevents/itemid",
-                values="labevents/valuenum"
+                values="labevents/valuenum",
+                # in case of multiple values for the same timestamp
+                aggregate_function="first",
             )
             labevents_df = labevents_df.sort("timestamp")
 
```
```diff
@@ -104,13 +102,13 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
             missing_cols = [item for item in self.LABITEMS if item not in existing_cols]
             for col in missing_cols:
                 labevents_df = labevents_df.with_columns(pl.lit(None).alias(col))
-
+
             # Reorder columns by LABITEMS
             labevents_df = labevents_df.select(
                 "timestamp",
                 *self.LABITEMS
             )
-
+
             timestamps = labevents_df["timestamp"].to_list()
             lab_values = labevents_df.drop("timestamp").to_numpy()
 
@@ -124,4 +122,4 @@ def __call__(self, patient: Any) -> List[Dict[str, Any]]:
                 }
             )
 
-        return samples
\ No newline at end of file
+        return samples
```
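For orientation, the tail of `__call__` converts the pivoted frame into the (timestamps, values) pair stored under the `labs` key. A toy illustration of that final shape; the column names are example itemids:

```python
from datetime import datetime

import polars as pl

wide = pl.DataFrame({
    "timestamp": [datetime(2020, 1, 1, 8), datetime(2020, 1, 1, 9)],
    "50912": [1.1, None],
    "50971": [None, 4.0],
})

# Each sample's labs become a list of timestamps plus a dense float matrix:
# one row per lab draw, one column per itemid, NaN where a lab is missing.
timestamps = wide["timestamp"].to_list()
lab_values = wide.drop("timestamp").to_numpy()
print(timestamps)
print(lab_values)  # shape (2, 2); nulls surface as NaN
```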
