diff --git a/pyhealth/datasets/sample_dataset.py b/pyhealth/datasets/sample_dataset.py index 134bede35..57635ca4b 100644 --- a/pyhealth/datasets/sample_dataset.py +++ b/pyhealth/datasets/sample_dataset.py @@ -338,6 +338,23 @@ def __str__(self) -> str: """ return f"Sample dataset {self.dataset_name} {self.task_name}" + def _rebuild_index_mappings(self) -> None: + """Rebuild patient and record lookup maps for the current dataset view.""" + patient_to_index: Dict[str, List[int]] = {} + record_to_index: Dict[str, List[int]] = {} + + for i in range(len(self)): + sample = self[i] + patient_id = sample.get("patient_id") + if patient_id is not None: + patient_to_index.setdefault(patient_id, []).append(i) + record_id = sample.get("record_id", sample.get("visit_id")) + if record_id is not None: + record_to_index.setdefault(record_id, []).append(i) + + self.patient_to_index = patient_to_index + self.record_to_index = record_to_index + def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset": """Create a StreamingDataset restricted to the provided indices.""" @@ -404,6 +421,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> "SampleDataset": new_dataset.subsampled_files = new_subsampled_files new_dataset.region_of_interest = new_roi new_dataset.reset() + new_dataset._rebuild_index_mappings() return new_dataset @@ -528,6 +546,7 @@ def subset(self, indices: Union[Sequence[int], slice]) -> SampleDataset: new_dataset = copy.deepcopy(self) new_dataset._data = samples + new_dataset._rebuild_index_mappings() return new_dataset def close(self) -> None: diff --git a/tests/core/test_sample_dataset.py b/tests/core/test_sample_dataset.py index 8dd4629c0..f9f84075d 100644 --- a/tests/core/test_sample_dataset.py +++ b/tests/core/test_sample_dataset.py @@ -28,6 +28,28 @@ def _get_datasets(self): ) return ds_disk, ds_mem + def _get_metadata_datasets(self): + samples = [ + {"patient_id": "p1", "record_id": "r1", "feature": 0, "label": 0}, + {"patient_id": "p1", "record_id": "r2", "feature": 1, "label": 1}, + {"patient_id": "p2", "record_id": "r3", "feature": 2, "label": 0}, + {"patient_id": "p1", "record_id": "r4", "feature": 3, "label": 1}, + {"patient_id": "p3", "record_id": "r5", "feature": 4, "label": 0}, + ] + ds_disk = create_sample_dataset( + samples=samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + in_memory=False + ) + ds_mem = create_sample_dataset( + samples=samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + in_memory=True + ) + return ds_disk, ds_mem + def test_len(self): ds_disk, ds_mem = self._get_datasets() self.assertEqual(len(ds_disk), 20) @@ -83,6 +105,38 @@ def test_subset_slice(self): for d, m in zip(list_disk, list_mem): self.assertEqual(d["feature"], m["feature"]) + def test_subset_rebuilds_index_metadata(self): + ds_disk, ds_mem = self._get_metadata_datasets() + indices = [1, 3, 4] + + sub_disk = ds_disk.subset(indices) + sub_mem = ds_mem.subset(indices) + + expected_patient = {"p1": [0, 1], "p3": [2]} + expected_record = {"r2": [0], "r4": [1], "r5": [2]} + + self.assertEqual(ds_disk.patient_to_index["p1"], [0, 1, 3]) + self.assertEqual(ds_mem.patient_to_index["p1"], [0, 1, 3]) + self.assertEqual(sub_disk.patient_to_index, expected_patient) + self.assertEqual(sub_mem.patient_to_index, expected_patient) + self.assertEqual(sub_disk.record_to_index, expected_record) + self.assertEqual(sub_mem.record_to_index, expected_record) + + def test_subset_slice_rebuilds_index_metadata(self): + ds_disk, ds_mem = self._get_metadata_datasets() + s = slice(1, 5, 2) + + sub_disk = ds_disk.subset(s) + sub_mem = ds_mem.subset(s) + + expected_patient = {"p1": [0, 1]} + expected_record = {"r2": [0], "r4": [1]} + + self.assertEqual(sub_disk.patient_to_index, expected_patient) + self.assertEqual(sub_mem.patient_to_index, expected_patient) + self.assertEqual(sub_disk.record_to_index, expected_record) + self.assertEqual(sub_mem.record_to_index, expected_record) + def test_set_shuffle(self): ds_disk, ds_mem = self._get_datasets()