Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions pyhealth/datasets/sample_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,23 @@ def __str__(self) -> str:
"""
return f"Sample dataset {self.dataset_name} {self.task_name}"

def _rebuild_index_mappings(self) -> None:
"""Rebuild patient and record lookup maps for the current dataset view."""
patient_to_index: Dict[str, List[int]] = {}
record_to_index: Dict[str, List[int]] = {}

for i in range(len(self)):
sample = self[i]
patient_id = sample.get("patient_id")
if patient_id is not None:
patient_to_index.setdefault(patient_id, []).append(i)
record_id = sample.get("record_id", sample.get("visit_id"))
if record_id is not None:
record_to_index.setdefault(record_id, []).append(i)

self.patient_to_index = patient_to_index
self.record_to_index = record_to_index

def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset":
"""Create a StreamingDataset restricted to the provided indices."""

Expand Down Expand Up @@ -404,6 +421,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset":
new_dataset.subsampled_files = new_subsampled_files
new_dataset.region_of_interest = new_roi
new_dataset.reset()
new_dataset._rebuild_index_mappings()

return new_dataset

Expand Down Expand Up @@ -528,6 +546,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset:

new_dataset = copy.deepcopy(self)
new_dataset._data = samples
new_dataset._rebuild_index_mappings()
return new_dataset

def close(self) -> None:
Expand Down
54 changes: 54 additions & 0 deletions tests/core/test_sample_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,28 @@ def _get_datasets(self):
)
return ds_disk, ds_mem

def _get_metadata_datasets(self):
    """Build twin datasets (disk- and memory-backed) with known id metadata."""
    rows = [
        ("p1", "r1", 0, 0),
        ("p1", "r2", 1, 1),
        ("p2", "r3", 2, 0),
        ("p1", "r4", 3, 1),
        ("p3", "r5", 4, 0),
    ]
    samples = [
        {"patient_id": pid, "record_id": rid, "feature": feat, "label": lab}
        for pid, rid, feat, lab in rows
    ]
    datasets = [
        create_sample_dataset(
            samples=samples,
            input_schema=self.input_schema,
            output_schema=self.output_schema,
            in_memory=in_memory,
        )
        for in_memory in (False, True)
    ]
    return tuple(datasets)

def test_len(self):
ds_disk, ds_mem = self._get_datasets()
self.assertEqual(len(ds_disk), 20)
Expand Down Expand Up @@ -83,6 +105,38 @@ def test_subset_slice(self):
for d, m in zip(list_disk, list_mem):
self.assertEqual(d["feature"], m["feature"])

def test_subset_rebuilds_index_metadata(self):
    """subset() with explicit indices must re-map ids to subset positions."""
    ds_disk, ds_mem = self._get_metadata_datasets()
    chosen = [1, 3, 4]

    sub_disk = ds_disk.subset(chosen)
    sub_mem = ds_mem.subset(chosen)

    want_patient = {"p1": [0, 1], "p3": [2]}
    want_record = {"r2": [0], "r4": [1], "r5": [2]}

    # Parent datasets keep their original (full-view) mappings untouched.
    for parent in (ds_disk, ds_mem):
        self.assertEqual(parent.patient_to_index["p1"], [0, 1, 3])

    for sub in (sub_disk, sub_mem):
        self.assertEqual(sub.patient_to_index, want_patient)
        self.assertEqual(sub.record_to_index, want_record)

def test_subset_slice_rebuilds_index_metadata(self):
    """subset() with a slice must also rebuild the id lookup tables."""
    ds_disk, ds_mem = self._get_metadata_datasets()
    window = slice(1, 5, 2)  # selects original positions 1 and 3

    want_patient = {"p1": [0, 1]}
    want_record = {"r2": [0], "r4": [1]}

    for parent in (ds_disk, ds_mem):
        sub = parent.subset(window)
        self.assertEqual(sub.patient_to_index, want_patient)
        self.assertEqual(sub.record_to_index, want_record)

def test_set_shuffle(self):
ds_disk, ds_mem = self._get_datasets()

Expand Down
Loading