diff --git a/docs/api/datasets/pyhealth.datasets.DREAMTDataset.rst b/docs/api/datasets/pyhealth.datasets.DREAMTDataset.rst index 61cc16c2e..3d45696be 100644 --- a/docs/api/datasets/pyhealth.datasets.DREAMTDataset.rst +++ b/docs/api/datasets/pyhealth.datasets.DREAMTDataset.rst @@ -1,21 +1,24 @@ pyhealth.datasets.DREAMTDataset =================================== -The Dataset for Real-time sleep stage EstimAtion using Multisensor wearable Technology (DREAMT) includes wrist-based wearable and polysomnography (PSG) sleep data from 100 participants recruited from the Duke University Health System (DUHS) Sleep Disorder Lab. +The Dataset for Real-time sleep stage EstimAtion using Multisensor wearable +Technology (DREAMT) includes wrist-based wearable and polysomnography (PSG) +sleep data from 100 participants recruited from the Duke University Health +System (DUHS) Sleep Disorder Lab. -This includes wearable signals, PSG signals, sleep labels, and clinical data related to sleep health and disorders. +This includes wearable signals, PSG signals, sleep labels, and clinical data +related to sleep health and disorders. -The DREAMTDataset class provides an interface for loading and working with the DREAMT dataset. It can process DREAMT data across versions into a well-structured dataset object providing support for modeling and analysis. +``DREAMTDataset`` supports both official DREAMT release layouts and partial +local subsets. It builds metadata linking each patient to locally available +signal files and exposes data for both the simplified +``SleepStagingDREAMT`` window-classification task and the more sequence-style +``SleepStagingDREAMTSeq`` task. -Refer to the `doc `_ for more information about the dataset. +Refer to the `doc `_ for more +information about the dataset. .. 
autoclass:: pyhealth.datasets.DREAMTDataset :members: :undoc-members: :show-inheritance: - - - - - - \ No newline at end of file diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..789f1cf21 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -190,6 +190,7 @@ API Reference models/pyhealth.models.SparcNet models/pyhealth.models.StageNet models/pyhealth.models.StageAttentionNet + models/pyhealth.models.WatchSleepNet models/pyhealth.models.AdaCare models/pyhealth.models.ConCare models/pyhealth.models.Agent diff --git a/docs/api/models/pyhealth.models.WatchSleepNet.rst b/docs/api/models/pyhealth.models.WatchSleepNet.rst new file mode 100644 index 000000000..4b3cd3627 --- /dev/null +++ b/docs/api/models/pyhealth.models.WatchSleepNet.rst @@ -0,0 +1,14 @@ +pyhealth.models.WatchSleepNet +=================================== + +Simplified WatchSleepNet-style architecture for wearable sleep staging. + +The implementation supports both: + +- pooled classification over fixed windows +- sequence-output classification for epoch-level sleep staging + +.. autoclass:: pyhealth.models.WatchSleepNet + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks/pyhealth.tasks.sleep_staging.rst b/docs/api/tasks/pyhealth.tasks.sleep_staging.rst index eab36ee9b..a9a5a3e0d 100644 --- a/docs/api/tasks/pyhealth.tasks.sleep_staging.rst +++ b/docs/api/tasks/pyhealth.tasks.sleep_staging.rst @@ -1,6 +1,19 @@ pyhealth.tasks.sleep_staging ======================================= +``SleepStagingDREAMT`` provides a simplified window-classification task, while +``SleepStagingDREAMTSeq`` provides a more paper-aligned sequence-style task. + +.. autoclass:: pyhealth.tasks.SleepStagingDREAMT + :members: + :undoc-members: + :show-inheritance: + +.. autoclass:: pyhealth.tasks.SleepStagingDREAMTSeq + :members: + :undoc-members: + :show-inheritance: + .. autofunction:: pyhealth.tasks.sleep_staging.sleep_staging_isruc_fn .. 
autofunction:: pyhealth.tasks.sleep_staging.sleep_staging_sleepedf_fn -.. autofunction:: pyhealth.tasks.sleep_staging.sleep_staging_shhs_fn \ No newline at end of file +.. autofunction:: pyhealth.tasks.sleep_staging.sleep_staging_shhs_fn diff --git a/examples/dreamt_sleep_staging_watchsleepnet.py b/examples/dreamt_sleep_staging_watchsleepnet.py new file mode 100644 index 000000000..fca9bb04e --- /dev/null +++ b/examples/dreamt_sleep_staging_watchsleepnet.py @@ -0,0 +1,377 @@ +"""Synthetic DREAMT sleep-staging example with WatchSleepNet ablations. + +This script demonstrates two usage patterns: + +1. A simplified window-classification pipeline. +2. A more paper-aligned sequence-style pipeline using IBI-only epoch features. + +The sequence example reports metrics emphasized in the WatchSleepNet paper: +accuracy, macro F1, REM F1, Cohen's kappa, and AUROC. The training loop is +intentionally lightweight and uses synthetic DREAMT-compatible data only. It +does not implement the paper's SHHS/MESA pretraining stage. 
+""" + +from __future__ import annotations + +import shutil +import tempfile +from pathlib import Path +from typing import Any + +import numpy as np +import pandas as pd +import torch + +from pyhealth.datasets import DREAMTDataset, get_dataloader +from pyhealth.models import WatchSleepNet +from pyhealth.tasks import SleepStagingDREAMT, SleepStagingDREAMTSeq + +NUM_CLASSES = 3 +REM_CLASS_INDEX = 2 +IGNORE_INDEX = -100 + + +def build_synthetic_dreamt_root(root: Path, num_subjects: int = 4) -> Path: + """Create a tiny DREAMT-style directory with wearable CSV files.""" + dreamt_root = root / "dreamt" + (dreamt_root / "data_64Hz").mkdir(parents=True) + + participant_rows = [] + for subject_index in range(num_subjects): + patient_id = f"S{subject_index + 1:03d}" + participant_rows.append( + { + "SID": patient_id, + "AGE": 25 + subject_index, + "GENDER": "F" if subject_index % 2 == 0 else "M", + "BMI": 22.0 + subject_index, + "OAHI": 1.0, + "AHI": 2.0, + "Mean_SaO2": "97%", + "Arousal Index": 10.0, + "MEDICAL_HISTORY": "None", + "Sleep_Disorders": "None", + } + ) + + # Five epochs: one preparation epoch followed by Wake, NREM, REM, NREM. 
+ labels = ( + ["P"] * 30 + + ["W"] * 30 + + ["N2"] * 30 + + ["R"] * 30 + + ["N3"] * 30 + ) + timestamps = np.arange(len(labels), dtype=np.float32) + frame = pd.DataFrame( + { + "TIMESTAMP": timestamps, + "IBI": np.sin(timestamps * 0.05 * (subject_index + 1)) + 1.0, + "HR": 60.0 + 3.0 * np.cos(timestamps * 0.05), + "BVP": np.sin(timestamps * 0.03), + "EDA": np.linspace(0.01, 0.04, len(labels)), + "TEMP": np.full(len(labels), 33.0 + 0.1 * subject_index), + "ACC_X": np.zeros(len(labels)), + "ACC_Y": np.ones(len(labels)), + "ACC_Z": np.full(len(labels), 2.0), + "Sleep_Stage": labels, + } + ) + frame.to_csv( + dreamt_root / "data_64Hz" / f"{patient_id}_whole_df.csv", + index=False, + ) + + pd.DataFrame(participant_rows).to_csv( + dreamt_root / "participant_info.csv", + index=False, + ) + return dreamt_root + + +def _to_numpy(value: Any) -> np.ndarray: + if isinstance(value, torch.Tensor): + return value.detach().cpu().numpy() + return np.asarray(value) + + +def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> float: + f1_scores = [] + for class_index in range(num_classes): + true_pos = np.sum((y_true == class_index) & (y_pred == class_index)) + false_pos = np.sum((y_true != class_index) & (y_pred == class_index)) + false_neg = np.sum((y_true == class_index) & (y_pred != class_index)) + precision = true_pos / max(true_pos + false_pos, 1) + recall = true_pos / max(true_pos + false_neg, 1) + if precision + recall == 0: + f1_scores.append(0.0) + else: + f1_scores.append(2 * precision * recall / (precision + recall)) + return float(np.mean(f1_scores)) + + +def _class_f1(y_true: np.ndarray, y_pred: np.ndarray, positive_class: int) -> float: + true_pos = np.sum((y_true == positive_class) & (y_pred == positive_class)) + false_pos = np.sum((y_true != positive_class) & (y_pred == positive_class)) + false_neg = np.sum((y_true == positive_class) & (y_pred != positive_class)) + precision = true_pos / max(true_pos + false_pos, 1) + recall = true_pos / 
max(true_pos + false_neg, 1) + if precision + recall == 0: + return 0.0 + return float(2 * precision * recall / (precision + recall)) + + +def _cohen_kappa(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> float: + confusion = np.zeros((num_classes, num_classes), dtype=np.float64) + for true_label, pred_label in zip(y_true, y_pred): + confusion[int(true_label), int(pred_label)] += 1 + + total = confusion.sum() + if total == 0: + return 0.0 + observed = np.trace(confusion) / total + expected = np.sum(confusion.sum(axis=0) * confusion.sum(axis=1)) / (total * total) + if expected >= 1.0: + return 0.0 + return float((observed - expected) / (1.0 - expected)) + + +def _binary_auc(y_true: np.ndarray, y_score: np.ndarray) -> float: + positive_mask = y_true == 1 + negative_mask = y_true == 0 + num_pos = int(positive_mask.sum()) + num_neg = int(negative_mask.sum()) + if num_pos == 0 or num_neg == 0: + return float("nan") + + order = np.argsort(y_score) + ranks = np.empty_like(order, dtype=np.float64) + ranks[order] = np.arange(1, len(y_score) + 1, dtype=np.float64) + positive_ranks = ranks[positive_mask] + auc = (positive_ranks.sum() - num_pos * (num_pos + 1) / 2.0) / (num_pos * num_neg) + return float(auc) + + +def _multiclass_auroc( + y_true: np.ndarray, + y_prob: np.ndarray, + num_classes: int, +) -> float: + aucs = [] + for class_index in range(num_classes): + one_vs_rest = (y_true == class_index).astype(np.int64) + auc = _binary_auc(one_vs_rest, y_prob[:, class_index]) + if not np.isnan(auc): + aucs.append(auc) + if not aucs: + return 0.0 + return float(np.mean(aucs)) + + +def compute_metrics( + y_true: np.ndarray, + y_prob: np.ndarray, + num_classes: int = NUM_CLASSES, +) -> dict[str, float]: + y_pred = y_prob.argmax(axis=-1) + accuracy = float((y_true == y_pred).mean()) if y_true.size else 0.0 + macro_f1 = _macro_f1(y_true, y_pred, num_classes) + rem_f1 = _class_f1(y_true, y_pred, REM_CLASS_INDEX) + kappa = _cohen_kappa(y_true, y_pred, num_classes) + auroc 
= _multiclass_auroc(y_true, y_prob, num_classes) + return { + "accuracy": accuracy, + "macro_f1": macro_f1, + "rem_f1": rem_f1, + "cohen_kappa": kappa, + "auroc": auroc, + } + + +def evaluate_model(model, loader, sequence_output: bool = False) -> dict[str, float]: + y_true_parts = [] + y_prob_parts = [] + model.eval() + with torch.no_grad(): + for batch in loader: + output = model(**batch) + y_prob = _to_numpy(output["y_prob"]) + y_true = _to_numpy(output["y_true"]) + if sequence_output: + valid_mask = y_true != IGNORE_INDEX + if not np.any(valid_mask): + continue + y_true_parts.append(y_true[valid_mask]) + y_prob_parts.append(y_prob[valid_mask]) + else: + y_true_parts.append(y_true.reshape(-1)) + y_prob_parts.append(y_prob.reshape(-1, y_prob.shape[-1])) + + if not y_true_parts: + return { + "accuracy": 0.0, + "macro_f1": 0.0, + "rem_f1": 0.0, + "cohen_kappa": 0.0, + "auroc": 0.0, + } + + y_true_all = np.concatenate(y_true_parts, axis=0) + y_prob_all = np.concatenate(y_prob_parts, axis=0) + return compute_metrics(y_true_all, y_prob_all) + + +def train_one_epoch(model, loader, optimizer) -> float: + model.train() + running_loss = 0.0 + num_batches = 0 + for batch in loader: + optimizer.zero_grad() + output = model(**batch) + output["loss"].backward() + optimizer.step() + running_loss += float(output["loss"].item()) + num_batches += 1 + return running_loss / max(num_batches, 1) + + +def run_window_ablation(sample_dataset) -> None: + """Run a simple window-classification ablation.""" + configs = [ + {"name": "baseline", "hidden_dim": 32, "use_tcn": True, "use_attention": True}, + { + "name": "no_attention", + "hidden_dim": 32, + "use_tcn": True, + "use_attention": False, + }, + {"name": "no_tcn", "hidden_dim": 32, "use_tcn": False, "use_attention": True}, + { + "name": "small_hidden", + "hidden_dim": 16, + "use_tcn": True, + "use_attention": True, + }, + ] + loader = get_dataloader(sample_dataset, batch_size=4, shuffle=True) + + print("Window classification ablation") 
+ for config in configs: + model = WatchSleepNet( + dataset=sample_dataset, + hidden_dim=config["hidden_dim"], + conv_channels=config["hidden_dim"], + num_attention_heads=4, + use_tcn=config["use_tcn"], + use_attention=config["use_attention"], + ) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + mean_loss = train_one_epoch(model, loader, optimizer) + metrics = evaluate_model(model, loader, sequence_output=False) + print( + f"{config['name']:>12} | loss={mean_loss:.4f} " + f"| acc={metrics['accuracy']:.3f} " + f"| macro_f1={metrics['macro_f1']:.3f}" + ) + + +def run_sequence_ablation(sample_dataset, feature_variant: str) -> None: + """Run a sequence-style ablation closer to the paper.""" + configs = [ + { + "name": "baseline_seq", + "hidden_dim": 32, + "use_tcn": True, + "use_attention": True, + }, + { + "name": "no_attn_seq", + "hidden_dim": 32, + "use_tcn": True, + "use_attention": False, + }, + { + "name": "no_tcn_seq", + "hidden_dim": 32, + "use_tcn": False, + "use_attention": True, + }, + { + "name": "small_seq", + "hidden_dim": 16, + "use_tcn": True, + "use_attention": True, + }, + ] + loader = get_dataloader(sample_dataset, batch_size=2, shuffle=True) + + print(f"\nSequence-style ablation ({feature_variant})") + for config in configs: + model = WatchSleepNet( + dataset=sample_dataset, + hidden_dim=config["hidden_dim"], + conv_channels=config["hidden_dim"], + num_attention_heads=4, + use_tcn=config["use_tcn"], + use_attention=config["use_attention"], + sequence_output=True, + num_classes=NUM_CLASSES, + ignore_index=IGNORE_INDEX, + ) + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + mean_loss = train_one_epoch(model, loader, optimizer) + metrics = evaluate_model(model, loader, sequence_output=True) + print( + f"{config['name']:>12} | loss={mean_loss:.4f} " + f"| acc={metrics['accuracy']:.3f} " + f"| macro_f1={metrics['macro_f1']:.3f} " + f"| rem_f1={metrics['rem_f1']:.3f} " + f"| kappa={metrics['cohen_kappa']:.3f} " + f"| 
auroc={metrics['auroc']:.3f}" + ) + + +def main() -> None: + temp_dir = Path(tempfile.mkdtemp(prefix="pyhealth_dreamt_example_")) + try: + dreamt_root = build_synthetic_dreamt_root(temp_dir) + dataset = DREAMTDataset(root=str(dreamt_root), cache_dir=temp_dir / "cache") + + window_task = SleepStagingDREAMT( + window_size=30, + stride=30, + source_preference="wearable", + ) + window_dataset = dataset.set_task(task=window_task, num_workers=1) + print(f"Generated {len(window_dataset)} synthetic window samples") + run_window_ablation(window_dataset) + + seq_task = SleepStagingDREAMTSeq( + feature_columns=("IBI",), + epoch_seconds=30.0, + sequence_length=4, + source_preference="wearable", + ignore_index=IGNORE_INDEX, + ) + seq_dataset = dataset.set_task(task=seq_task, num_workers=1) + print(f"\nGenerated {len(seq_dataset)} IBI-only sequence samples") + run_sequence_ablation(seq_dataset, feature_variant="IBI only") + + rich_seq_task = SleepStagingDREAMTSeq( + feature_columns=SleepStagingDREAMT.DEFAULT_FEATURE_COLUMNS, + epoch_seconds=30.0, + sequence_length=4, + source_preference="wearable", + ignore_index=IGNORE_INDEX, + ) + rich_seq_dataset = dataset.set_task(task=rich_seq_task, num_workers=1) + print( + f"Generated {len(rich_seq_dataset)} multi-signal sequence samples" + ) + run_sequence_ablation(rich_seq_dataset, feature_variant="IBI + context") + finally: + shutil.rmtree(temp_dir, ignore_errors=True) + + +if __name__ == "__main__": + main() diff --git a/pyhealth/datasets/configs/dreamt.yaml b/pyhealth/datasets/configs/dreamt.yaml index ad6430e45..472c55a58 100644 --- a/pyhealth/datasets/configs/dreamt.yaml +++ b/pyhealth/datasets/configs/dreamt.yaml @@ -5,6 +5,7 @@ tables: patient_id: "patient_id" timestamp: null attributes: + - "record_id" - "age" - "gender" - "bmi" @@ -13,5 +14,10 @@ tables: - "mean_sao2" - "arousal_index" - "medical_history" + - "sleep_disorders" + - "signal_file" + - "signal_source" + - "sampling_rate_hz" + - "signal_format" - "file_64hz" - - 
"file_100hz" \ No newline at end of file + - "file_100hz" diff --git a/pyhealth/datasets/dreamt.py b/pyhealth/datasets/dreamt.py index a7c43d23c..c39398597 100644 --- a/pyhealth/datasets/dreamt.py +++ b/pyhealth/datasets/dreamt.py @@ -1,179 +1,400 @@ import logging -import os +import re from pathlib import Path -from typing import Optional, Union +from typing import Optional import pandas as pd -from pyhealth.datasets import BaseDataset -logger = logging.getLogger(__name__) - -class DREAMTDataset(BaseDataset): - """ - Base Dataset for Real-time sleep stage EstimAtion using Multisensor wearable Technology (DREAMT) +from .base_dataset import BaseDataset +from ..tasks.sleep_staging import SleepStagingDREAMT - Dataset accepts current versions of DREAMT (1.0.0, 1.0.1, 2.0.0, 2.1.0), available at: - https://physionet.org/content/dreamt/ - - DREAMT includes wrist-based wearable and polysomnography (PSG) sleep data from 100 participants - recruited from the Duke University Health System (DUHS) Sleep Disorder Lab. This includes - wearable signals, PSG signals, sleep labels, and clinical data related to sleep health and disorders. +logger = logging.getLogger(__name__) - Citations: - --------- - When using this dataset, please cite: - Wang, K., Yang, J., Shetty, A., & Dunn, J. (2025). DREAMT: Dataset for Real-time sleep stage EstimAtion - using Multisensor wearable Technology (version 2.1.0). PhysioNet. RRID:SCR_007345. - https://doi.org/10.13026/7r9r-7r24 +class DREAMTDataset(BaseDataset): + """Base dataset wrapper for DREAMT sleep-staging data. - Will Ke Wang, Jiamu Yang, Leeor Hershkovich, Hayoung Jeong, Bill Chen, Karnika Singh, Ali R Roghanizad, - Md Mobashir Hasan Shandhi, Andrew R Spector, Jessilyn Dunn. (2024). Proceedings of the fifth - Conference on Health, Inference, and Learning, PMLR 248:380-396. + This dataset is designed for reproducibility work around DREAMT-style sleep + staging. 
It supports two common local layouts: - Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... & Stanley, H. E. (2000). - PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex - physiologic signals. Circulation [Online]. 101 (23), pp. e215–e220. RRID:SCR_007345. + 1. The official PhysioNet DREAMT release rooted at a version directory + containing ``participant_info.csv`` and one or both of ``data_64Hz`` and + ``data_100Hz``. + 2. A processed per-subject directory containing files named with a subject + identifier such as ``S002_record.csv`` or ``S002_features.npz``. - Note: - --------- - Dataset follows file and folder structure of dataset version, looks for participant_info.csv and data folders, - so root path should be version downloaded, example: root = ".../dreamt/1.0.0/" or ".../dreamt/2.0.0/" + For the official release, the dataset builds a metadata table with one + event per subject. The downstream task then reads the referenced signal file + and produces fixed-size windows for sleep-stage classification. Args: - root: root directory containing the dataset files - dataset_name: optional name of dataset, defaults to "dreamt_sleep" - config_path: optional configuration file, defaults to "dreamt.yaml" - - Attributes: - root: root directory containing the dataset files - dataset_name: name of dataset - config_path: path to configuration file + root: Root directory of DREAMT data. This can be the DREAMT version + directory itself or a parent directory that contains exactly one + DREAMT release. + dataset_name: Optional dataset name. Defaults to ``"dreamt_sleep"``. + config_path: Optional config path. Defaults to the bundled + ``dreamt.yaml`` config. + preferred_source: Preferred signal source for the generic + ``signal_file`` column. One of ``"auto"``, ``"wearable"``, or + ``"psg"``. + cache_dir: Optional PyHealth cache directory. + num_workers: Number of workers for PyHealth dataset operations. 
+ dev: Whether to enable PyHealth dev mode. Examples: >>> from pyhealth.datasets import DREAMTDataset - >>> dataset = DREAMTDataset(root = "/path/to/dreamt/data/version") - >>> dataset.stats() - >>> - >>> # Get all patient ids - >>> unique_patients = dataset.unique_patient_ids - >>> print(f"There are {len(unique_patients)} patients") - >>> - >>> # Get single patient data + >>> dataset = DREAMTDataset( + ... root="/path/to/dreamt/2.1.0", + ... preferred_source="wearable", + ... ) >>> patient = dataset.get_patient("S002") - >>> print(f"Patient has {len(patient.data_source)} event") - >>> - >>> # Get event - >>> event = patient.get_events(event_type="dreamt_sleep") - >>> - >>> # Get Apnea-Hypopnea Index (AHI) - >>> ahi = event[0].ahi - >>> print(f"AHI is {ahi}") - >>> - >>> # Get 64Hz sleep file path - >>> file_path = event[0].file_64hz - >>> print(f"64Hz sleep file path: {file_path}") + >>> event = patient.get_events("dreamt_sleep")[0] + >>> event.signal_file + '/path/to/dreamt/2.1.0/data_64Hz/S002_whole_df.csv' """ + SUPPORTED_SUFFIXES = {".csv", ".parquet", ".pkl", ".pickle", ".npz", ".npy"} + _SUBJECT_PATTERN = re.compile(r"(S\d{3})", re.IGNORECASE) + def __init__( - self, - root: str, - dataset_name: Optional[str] = None, - config_path: Optional[str] = None, + self, + root: str, + dataset_name: Optional[str] = None, + config_path: Optional[str] = None, + preferred_source: str = "auto", + cache_dir: str | Path | None = None, + num_workers: int = 1, + dev: bool = False, ) -> None: + preferred_source = preferred_source.lower() + if preferred_source not in {"auto", "wearable", "psg"}: + raise ValueError( + "preferred_source must be one of 'auto', 'wearable', or 'psg'." 
+ ) + if config_path is None: - logger.info("No config provided, using default config") config_path = Path(__file__).parent / "configs" / "dreamt.yaml" - - metadata_file = Path(root) / "dreamt-metadata.csv" - if not os.path.exists(metadata_file): - logger.info(f"{metadata_file} does not exist") - self.prepare_metadata(root) - - default_tables = ["dreamt_sleep"] + resolved_root = self._resolve_root(Path(root).expanduser().resolve()) + self.prepare_metadata(resolved_root, preferred_source=preferred_source) + + self.preferred_source = preferred_source super().__init__( - root=root, - tables=default_tables, + root=str(resolved_root), + tables=["dreamt_sleep"], dataset_name=dataset_name or "dreamt_sleep", - config_path=config_path + config_path=str(config_path), + cache_dir=cache_dir, + num_workers=num_workers, + dev=dev, ) - - def get_patient_file(self, patient_id: str, root: str, file_path: str) -> Union[str | None]: - """ - Returns file path of 64Hz and 100Hz data for a patient, or None if no file found - - Args: - patient_id: patient identifier - root: root directory containing the dataset files - file_path: path to location of 64Hz or 100Hz file - - Returns: - file: path to file location or None if no file found - """ - - if file_path == "data_64Hz" or file_path == "data": - file = Path(root) / f"{file_path}" / f"{patient_id}_whole_df.csv" - - if file_path == "data_100Hz": - file = Path(root) / f"{file_path}" / f"{patient_id}_PSG_df.csv" - if not os.path.exists(str(file)): - logger.info(f"{file} not found") - file = None - - return file + @staticmethod + def _extract_subject_id(path: Path) -> Optional[str]: + match = DREAMTDataset._SUBJECT_PATTERN.search(path.name) + if match is None: + return None + return match.group(1).upper() + @classmethod + def _is_signal_file(cls, path: Path) -> bool: + if path.suffix.lower() not in cls.SUPPORTED_SUFFIXES: + return False + return cls._extract_subject_id(path) is not None - def prepare_metadata(self, root: str) -> None: - """ 
- Prepares metadata csv file for the DREAMT dataset by performing the following: - 1. Obtain clinical data from participant_info.csv file - 2. Process file paths based on patients found in clinical data - 3. Organize all data into a single DataFrame - 4. Save the processed DataFrame to a CSV file - - Args: - root: root directory containing the dataset files - """ + @classmethod + def _resolve_root(cls, root: Path) -> Path: + """Resolve a user-provided root to a concrete DREAMT data directory.""" + if not root.exists(): + raise FileNotFoundError(f"DREAMT root does not exist: {root}") + + if (root / "participant_info.csv").exists(): + return root + + if any( + cls._is_signal_file(path) + for path in root.iterdir() + if path.is_file() + ): + return root + + candidates = sorted( + { + path.parent.resolve() + for path in root.rglob("participant_info.csv") + } + ) + if len(candidates) == 1: + return candidates[0] + if len(candidates) > 1: + candidate_text = ", ".join(str(path) for path in candidates[:5]) + raise ValueError( + "Found multiple DREAMT roots under the provided path. Use an " + f"explicit version directory instead. Candidates: {candidate_text}" + ) + + signal_candidates = sorted( + { + path.parent.resolve() + for path in root.rglob("*") + if path.is_file() and cls._is_signal_file(path) + } + ) + if len(signal_candidates) == 1: + return signal_candidates[0] + if len(signal_candidates) > 1: + counts = { + candidate: sum( + 1 + for child in candidate.iterdir() + if child.is_file() and cls._is_signal_file(child) + ) + for candidate in signal_candidates + } + best = max(counts.items(), key=lambda item: item[1])[0] + logger.info( + "Resolved DREAMT root to %s based on detected subject files.", + best, + ) + return best + + raise FileNotFoundError( + "Could not find a DREAMT data directory containing " + "participant_info.csv or processed subject files." 
+ ) - output_path = Path(root) / "dreamt-metadata.csv" + @staticmethod + def _coerce_float(value: object) -> object: + if pd.isna(value): + return None + text = str(value).strip() + if not text: + return None + if text.endswith("%"): + text = text[:-1] + try: + return float(text) + except ValueError: + return value + + @classmethod + def _locate_subject_file(cls, directory: Path, patient_id: str) -> Optional[Path]: + if not directory.exists(): + return None + + candidates = sorted( + path + for path in directory.iterdir() + if path.is_file() + and cls._extract_subject_id(path) == patient_id + and path.suffix.lower() in cls.SUPPORTED_SUFFIXES + ) + if not candidates: + return None - # Obtain patient clinical data - participant_info_path = Path(root) / "participant_info.csv" + def score(path: Path) -> tuple[int, int, int]: + name = path.name.lower() + return ( + int("updated" in name), + int("whole" in name or "psg" in name or "record" in name), + -len(name), + ) + + return max(candidates, key=score) + + @staticmethod + def _select_signal_file( + file_64hz: Optional[Path], + file_100hz: Optional[Path], + preferred_source: str, + ) -> tuple[Optional[Path], Optional[str], Optional[float]]: + if preferred_source == "wearable": + if file_64hz is not None: + return file_64hz, "wearable", 64.0 + if file_100hz is not None: + return file_100hz, "psg", 100.0 + elif preferred_source == "psg": + if file_100hz is not None: + return file_100hz, "psg", 100.0 + if file_64hz is not None: + return file_64hz, "wearable", 64.0 + else: + if file_64hz is not None: + return file_64hz, "wearable", 64.0 + if file_100hz is not None: + return file_100hz, "psg", 100.0 + return None, None, None + + @classmethod + def _build_metadata_from_participant_info( + cls, + root: Path, + preferred_source: str, + ) -> pd.DataFrame: + participant_info_path = root / "participant_info.csv" participant_info = pd.read_csv(participant_info_path) + participant_info.columns = [ + str(col).strip() for col in 
participant_info.columns + ] + + rename_map = { + "SID": "patient_id", + "AGE": "age", + "Age": "age", + "GENDER": "gender", + "Gender": "gender", + "BMI": "bmi", + "OAHI": "oahi", + "AHI": "ahi", + "Mean_SaO2": "mean_sao2", + "Arousal Index": "arousal_index", + "MEDICAL_HISTORY": "medical_history", + "Sleep_Disorders": "sleep_disorders", + } + participant_info = participant_info.rename(columns=rename_map) + if "patient_id" not in participant_info.columns: + raise ValueError( + "participant_info.csv must contain a SID/patient_id column." + ) + + file_64_dir = root / "data_64Hz" + if not file_64_dir.exists(): + legacy_dir = root / "data" + file_64_dir = legacy_dir if legacy_dir.exists() else file_64_dir + file_100_dir = root / "data_100Hz" + + participant_info["patient_id"] = participant_info["patient_id"].astype(str) + participant_info["file_64hz"] = participant_info["patient_id"].apply( + lambda pid: cls._locate_subject_file(file_64_dir, pid) + ) + participant_info["file_100hz"] = participant_info["patient_id"].apply( + lambda pid: cls._locate_subject_file(file_100_dir, pid) + ) + + resolved = participant_info.apply( + lambda row: cls._select_signal_file( + row["file_64hz"], + row["file_100hz"], + preferred_source=preferred_source, + ), + axis=1, + result_type="expand", + ) + resolved.columns = ["signal_file", "signal_source", "sampling_rate_hz"] + participant_info = pd.concat([participant_info, resolved], axis=1) + for column in ["age", "bmi", "oahi", "ahi", "mean_sao2", "arousal_index"]: + if column in participant_info.columns: + participant_info[column] = participant_info[column].apply( + cls._coerce_float + ) - # Determine folder structure, assign associated file paths based on folder structure - all_folders = [item.name for item in Path(root).iterdir() if item.is_dir()] - file_path_64hz = "data_64Hz" if "data_64Hz" in all_folders else "data" - file_path_100hz = "data_100Hz" + for column in ["file_64hz", "file_100hz", "signal_file"]: + participant_info[column] 
= participant_info[column].apply( + lambda value: str(value) if isinstance(value, Path) else value + ) - # Determine paths for 64Hz and 100Hz files for each patient - participant_info['file_64hz'] = participant_info['SID'].apply( - lambda sid: self.get_patient_file(sid, root, file_path_64hz) + participant_info["signal_format"] = participant_info["signal_file"].apply( + lambda value: Path(value).suffix.lower() if isinstance(value, str) else None ) - participant_info['file_100hz'] = participant_info['SID'].apply( - lambda sid: self.get_patient_file(sid, root, file_path_100hz) + participant_info["record_id"] = participant_info["patient_id"] + + available_mask = participant_info["signal_file"].notna() + missing_count = int((~available_mask).sum()) + if missing_count: + logger.info( + "Dropping %s DREAMT participants without local signal files.", + missing_count, + ) + participant_info = participant_info.loc[available_mask].reset_index(drop=True) + + return participant_info + + @classmethod + def _build_metadata_from_processed_files( + cls, + root: Path, + preferred_source: str, + ) -> pd.DataFrame: + del preferred_source + files = sorted( + path + for path in root.rglob("*") + if path.is_file() and cls._is_signal_file(path) ) + if not files: + raise FileNotFoundError( + "No DREAMT subject files were found under the provided root." 
+ ) + + rows = [] + for path in files: + patient_id = cls._extract_subject_id(path) + if patient_id is None: + continue + rows.append( + { + "patient_id": patient_id, + "record_id": patient_id, + "signal_file": str(path.resolve()), + "signal_source": "processed", + "sampling_rate_hz": None, + "signal_format": path.suffix.lower(), + "file_64hz": None, + "file_100hz": None, + } + ) + + metadata = pd.DataFrame(rows) + metadata = metadata.drop_duplicates(subset=["patient_id"], keep="first") + return metadata.sort_values("patient_id").reset_index(drop=True) + + @classmethod + def prepare_metadata(cls, root: str | Path, preferred_source: str = "auto") -> None: + """Create ``dreamt-metadata.csv`` for a local DREAMT directory. + + The generated metadata always includes: + + - ``patient_id`` + - ``record_id`` + - ``signal_file`` + - ``signal_source`` + - ``sampling_rate_hz`` + - ``signal_format`` + - ``file_64hz`` + - ``file_100hz`` + + For the official DREAMT release, participant metadata columns are also + preserved when available. + """ + root = Path(root).expanduser().resolve() + output_path = root / "dreamt-metadata.csv" + + if (root / "participant_info.csv").exists(): + metadata = cls._build_metadata_from_participant_info( + root, + preferred_source=preferred_source, + ) + else: + metadata = cls._build_metadata_from_processed_files( + root, + preferred_source=preferred_source, + ) + + if metadata.empty: + raise ValueError("No DREAMT subjects were found while preparing metadata.") + + if metadata["signal_file"].notna().sum() == 0: + raise ValueError( + "Metadata was created, but no usable signal files were detected. " + "Expected DREAMT subject files such as S002_whole_df.csv, " + "S002_PSG_df.csv, or processed subject files like S002_record.npz." 
import math
from typing import Any, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..datasets.sample_dataset import SampleDataset
from .base_model import BaseModel


class ResidualConvBlock(nn.Module):
    """Residual 1D convolution block for local feature extraction.

    Two conv/BN/ReLU stages with dropout; a 1x1 convolution projects the
    residual path when the channel count changes so the skip addition is
    shape-compatible.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # "same" padding for odd kernel sizes keeps the temporal length fixed.
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.ReLU()

        # Only needed when in/out channel widths differ.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape ``[batch, channels, time]``."""
        residual = x if self.shortcut is None else self.shortcut(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.dropout2(x)
        x = x + residual
        x = self.relu2(x)
        return x


class DilatedTCNBlock(nn.Module):
    """Residual dilated temporal block.

    Same channel width in and out; dilation widens the receptive field while
    symmetric padding keeps the temporal length unchanged.
    """

    def __init__(
        self,
        channels: int,
        dilation: int,
        kernel_size: int = 3,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Symmetric padding for an odd kernel preserves sequence length at any
        # dilation: (k - 1) * d is always even when k is odd.
        padding = ((kernel_size - 1) * dilation) // 2
        self.conv1 = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
        )
        self.bn2 = nn.BatchNorm1d(channels)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x`` of shape ``[batch, channels, time]``."""
        residual = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.dropout2(x)
        x = x + residual
        x = self.relu2(x)
        return x


class TemporalAttentionPool(nn.Module):
    """Attention pooling over temporal features.

    Computes a scalar score per timestep, softmaxes over time, and returns the
    weighted sum together with the attention weights.
    """

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.score = nn.Linear(input_dim, 1)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pool ``x`` of shape ``[batch, time, dim]`` over the time axis.

        Args:
            x: Input features ``[batch, time, dim]``.
            mask: Optional boolean mask ``[batch, time]``; False marks padding.

        Returns:
            Tuple of pooled features ``[batch, dim]`` and attention weights
            ``[batch, time]``. A row whose mask is entirely False yields a
            zero pooled vector rather than NaN.
        """
        scores = self.score(x).squeeze(-1)
        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = torch.softmax(scores, dim=1)
        if mask is not None:
            # BUGFIX: when every position of a row is masked, all scores are
            # -inf and softmax yields NaN for the whole row, which would
            # silently poison the pooled embedding and the loss. Replace those
            # NaNs with zero attention so fully-padded rows pool to zeros.
            weights = torch.nan_to_num(weights, nan=0.0)
        pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
        return pooled, weights


class WatchSleepNet(BaseModel):
    """Simplified WatchSleepNet-style classifier for sleep staging.

    The architecture is intentionally compact but preserves the main modeling
    ideas from WatchSleepNet:

    - residual 1D convolution for local morphology
    - optional dilated temporal blocks
    - bidirectional LSTM for sequence modeling
    - optional multi-head self-attention
    - temporal pooling and a classification head

    Args:
        dataset: A task-specific ``SampleDataset`` with one tensor feature and
            one multiclass label.
        feature_key: Feature field to read from the dataset. Defaults to the
            first feature key.
        input_dim: Optional feature dimension. If omitted, inferred from the
            first sample.
        hidden_dim: Hidden size of the bidirectional LSTM.
        conv_channels: Channel width of the residual convolution stack.
        conv_blocks: Number of residual convolution blocks.
        tcn_blocks: Number of dilated TCN blocks.
        lstm_layers: Number of bidirectional LSTM layers.
        num_attention_heads: Number of attention heads when attention is used.
        dropout: Dropout rate applied throughout the network.
        num_classes: Optional override for output classes. Defaults to the task
            label processor size.
        use_tcn: Whether to use the dilated TCN stack.
        use_attention: Whether to use multi-head self-attention and attention
            pooling. When disabled, the model uses masked mean pooling instead.
        sequence_output: Whether to emit logits for each timestep instead of a
            single pooled prediction.
        ignore_index: Ignore label used for padded sequence targets.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        feature_key: Optional[str] = None,
        input_dim: Optional[int] = None,
        hidden_dim: int = 64,
        conv_channels: int = 64,
        conv_blocks: int = 2,
        tcn_blocks: int = 2,
        lstm_layers: int = 1,
        num_attention_heads: int = 4,
        dropout: float = 0.1,
        num_classes: Optional[int] = None,
        use_tcn: bool = True,
        use_attention: bool = True,
        sequence_output: bool = False,
        ignore_index: int = -100,
    ) -> None:
        super().__init__(dataset=dataset)
        if len(self.label_keys) != 1:
            raise ValueError("WatchSleepNet supports a single label field.")

        self.label_key = self.label_keys[0]
        self.feature_key = feature_key or self.feature_keys[0]
        self.hidden_dim = hidden_dim
        self.conv_channels = conv_channels
        self.conv_blocks = conv_blocks
        self.tcn_blocks = tcn_blocks
        self.lstm_layers = lstm_layers
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.use_tcn = use_tcn
        self.use_attention = use_attention
        self.sequence_output = sequence_output
        self.ignore_index = ignore_index

        self.input_dim = input_dim or self._infer_input_dim(self.feature_key)
        if self.input_dim <= 0:
            raise ValueError("input_dim must be positive.")

        self.input_projection = nn.Conv1d(self.input_dim, conv_channels, kernel_size=1)
        self.residual_conv = nn.ModuleList(
            [
                ResidualConvBlock(
                    conv_channels,
                    conv_channels,
                    kernel_size=3,
                    dropout=dropout,
                )
                for _ in range(conv_blocks)
            ]
        )

        # Exponentially increasing dilations (1, 2, 4, ...) widen the
        # receptive field with depth.
        self.tcn = nn.ModuleList(
            [
                DilatedTCNBlock(
                    conv_channels,
                    dilation=2**block_index,
                    kernel_size=3,
                    dropout=dropout,
                )
                for block_index in range(tcn_blocks)
            ]
        )

        self.bilstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if lstm_layers > 1 else 0.0,
        )

        context_dim = hidden_dim * 2
        self.attention_projection = None
        if use_attention:
            # MultiheadAttention requires embed_dim % num_heads == 0; round the
            # LSTM output width up to the nearest multiple if necessary.
            projected_dim = context_dim
            if projected_dim % num_attention_heads != 0:
                projected_dim = (
                    math.ceil(projected_dim / num_attention_heads)
                    * num_attention_heads
                )
            self.attention_projection = nn.Linear(context_dim, projected_dim)
            self.self_attention = nn.MultiheadAttention(
                embed_dim=projected_dim,
                num_heads=num_attention_heads,
                dropout=dropout,
                batch_first=True,
            )
            self.attention_norm = nn.LayerNorm(projected_dim)
            self.attention_pool = TemporalAttentionPool(projected_dim)
            classifier_input_dim = projected_dim
        else:
            self.self_attention = None
            self.attention_norm = None
            self.attention_pool = None
            classifier_input_dim = context_dim

        self.final_dropout = nn.Dropout(dropout)
        if self.sequence_output and num_classes is None:
            raise ValueError(
                "num_classes must be provided when sequence_output=True."
            )
        self.classifier = nn.Linear(
            classifier_input_dim,
            num_classes if num_classes is not None else self.get_output_size(),
        )

    def _infer_input_dim(self, feature_key: str) -> int:
        """Infer the per-timestep feature width from the first usable sample."""
        for sample in self.dataset:
            if feature_key not in sample:
                continue
            value = sample[feature_key]
            tensor = (
                value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
            )
            # A 1-D sample is treated as a univariate sequence.
            if tensor.dim() == 1:
                return 1
            return int(tensor.shape[-1])
        raise ValueError(
            f"Unable to infer input_dim for feature '{feature_key}' from the dataset."
        )

    def _coerce_input(self, value: Any) -> torch.Tensor:
        """Normalize raw input to a float32 ``[batch, time, input_dim]`` tensor.

        Accepts channel-last or channel-first layouts; 2-D input is treated as
        a batch of univariate sequences. When both axes equal ``input_dim``,
        channel-last is assumed.
        """
        tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
        tensor = tensor.to(self.device, dtype=torch.float32)

        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(-1)
        if tensor.dim() != 3:
            raise ValueError(
                "Expected input tensor with 2 or 3 dims, received shape "
                f"{tuple(tensor.shape)}."
            )

        if tensor.shape[-1] == self.input_dim:
            return tensor
        if tensor.shape[1] == self.input_dim:
            return tensor.transpose(1, 2)
        raise ValueError(
            "Unable to infer whether input is channel-last or channel-first. "
            f"Expected one axis to match input_dim={self.input_dim}, got "
            f"shape {tuple(tensor.shape)}."
        )

    @staticmethod
    def _default_mask(x: torch.Tensor) -> torch.Tensor:
        """Return an all-valid boolean mask of shape ``[batch, time]``."""
        return torch.ones(x.shape[0], x.shape[1], dtype=torch.bool, device=x.device)

    def _build_mask(
        self,
        x: torch.Tensor,
        explicit_mask: Optional[Any] = None,
    ) -> torch.Tensor:
        """Build a boolean padding mask, defaulting to all-valid."""
        if explicit_mask is None:
            return self._default_mask(x)
        mask = (
            explicit_mask
            if isinstance(explicit_mask, torch.Tensor)
            else torch.as_tensor(explicit_mask)
        )
        mask = mask.to(self.device)
        if mask.dim() != 2:
            raise ValueError(
                f"Expected mask with shape [batch, seq_len], got {tuple(mask.shape)}."
            )
        # Accept float/int 0-1 masks and convert to boolean.
        return mask > 0

    @staticmethod
    def _masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean over time, counting only unmasked positions (denominator >= 1)."""
        weights = mask.unsqueeze(-1).to(dtype=x.dtype)
        denom = torch.clamp(weights.sum(dim=1), min=1.0)
        return (x * weights).sum(dim=1) / denom

    def _encode_sequence(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run the conv/TCN/LSTM/attention stack; returns ``[batch, time, dim]``."""
        # x: [batch, seq_len, input_dim] -> [batch, input_dim, seq_len]
        x = x.transpose(1, 2)
        x = self.input_projection(x)
        for block in self.residual_conv:
            x = block(x)
        if self.use_tcn:
            for block in self.tcn:
                x = block(x)

        # x: [batch, conv_channels, seq_len] -> [batch, seq_len, conv_channels]
        x = x.transpose(1, 2)
        x, _ = self.bilstm(x)

        if self.use_attention and self.self_attention is not None:
            residual = x
            if self.attention_projection is not None:
                residual = self.attention_projection(residual)
            # key_padding_mask expects True for PADDED positions, hence ~mask.
            attn_out, _ = self.self_attention(
                residual,
                residual,
                residual,
                key_padding_mask=~mask,
                need_weights=False,
            )
            x = self.attention_norm(attn_out + residual)
        return x

    def _compute_sequence_loss(
        self,
        logits: torch.Tensor,
        y_true: torch.Tensor,
    ) -> torch.Tensor:
        """Flattened cross-entropy over all timesteps, skipping ``ignore_index``."""
        return F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            y_true.reshape(-1).long(),
            ignore_index=self.ignore_index,
        )

    def forward(self, **kwargs) -> dict[str, torch.Tensor]:
        """Forward pass.

        Reads the configured feature key (and optional ``mask``) from
        ``kwargs``; returns a dict with ``logit``, ``y_prob`` and, when the
        label is supplied, ``loss`` and ``y_true``. ``embed=True`` additionally
        returns the pre-classifier embedding.
        """
        x = self._coerce_input(kwargs[self.feature_key])
        mask = self._build_mask(x, kwargs.get("mask"))
        x = self._encode_sequence(x, mask)

        if self.sequence_output:
            embed = self.final_dropout(x)
            logits = self.classifier(embed)
            results = {
                "logit": logits,
                "y_prob": torch.softmax(logits, dim=-1),
            }
            if self.label_key in kwargs:
                y_true = kwargs[self.label_key].to(self.device).long()
                results["loss"] = self._compute_sequence_loss(logits, y_true)
                results["y_true"] = y_true
            if kwargs.get("embed", False):
                results["embed"] = embed
            return results

        if self.use_attention and self.attention_pool is not None:
            pooled, _ = self.attention_pool(x, mask)
        else:
            pooled = self._masked_mean(x, mask)

        embed = self.final_dropout(pooled)
        logits = self.classifier(embed)
        results = {
            "logit": logits,
            "y_prob": self.prepare_y_prob(logits),
        }

        if self.label_key in kwargs:
            y_true = kwargs[self.label_key].to(self.device)
            results["loss"] = self.get_loss_function()(logits, y_true)
            results["y_true"] = y_true

        if kwargs.get("embed", False):
            results["embed"] = embed

        return results
class SleepStagingDREAMT(BaseTask):
    """Three-class sleep staging task for DREAMT-style wearable sequences.

    This task converts one DREAMT subject recording into fixed-length windows of
    wearable features and maps detailed sleep stages to three classes:

    - ``0``: Wake
    - ``1``: NREM
    - ``2``: REM

    Supported input formats:

    - Raw per-subject CSV/Parquet/Pickle files containing wearable columns and a
      ``Sleep_Stage`` column.
    - Processed ``.npz``/``.npy`` dictionaries with keys such as
      ``features``, ``labels``, and optional ``feature_names``.

    The default feature set mirrors the smartwatch-oriented signals available in
    DREAMT and includes ``IBI`` plus common wearable context channels.
    """

    task_name: str = "SleepStagingDREAMT"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "multiclass"}

    DEFAULT_FEATURE_COLUMNS = (
        "IBI",
        "HR",
        "BVP",
        "EDA",
        "TEMP",
        "ACC_X",
        "ACC_Y",
        "ACC_Z",
    )
    # Sentinel for unlabeled / unusable epochs within a recording.
    _IGNORE_LABEL = -1
    # Raw stage annotations (string or numeric) collapsed to {0: Wake, 1: NREM, 2: REM}.
    _LABEL_MAP = {
        "W": 0,
        "WAKE": 0,
        "WAKEFUL": 0,
        "0": 0,
        "N1": 1,
        "N2": 1,
        "N3": 1,
        "N4": 1,
        "NREM": 1,
        "1": 1,
        "2": 1,
        "3": 1,
        "4": 1,
        "R": 2,
        "REM": 2,
        "5": 2,
    }
    _INVALID_LABELS = {"", "P", "PREPARATION", "UNKNOWN", "?", "NAN", "NONE"}

    def __init__(
        self,
        feature_columns: Optional[Sequence[str]] = None,
        label_column: str = "Sleep_Stage",
        source_preference: str = "wearable",
        window_seconds: float = 30.0,
        stride_seconds: Optional[float] = None,
        window_size: Optional[int] = None,
        stride: Optional[int] = None,
        default_sampling_rate_hz: Optional[float] = None,
        min_labeled_fraction: float = 0.5,
        pad_short_windows: bool = False,
        include_partial_last_window: bool = False,
    ) -> None:
        if source_preference not in {"auto", "wearable", "psg"}:
            raise ValueError(
                "source_preference must be one of 'auto', 'wearable', or 'psg'."
            )
        if window_seconds <= 0 and window_size is None:
            raise ValueError(
                "window_seconds must be positive when window_size is None."
            )
        if min_labeled_fraction <= 0 or min_labeled_fraction > 1:
            raise ValueError("min_labeled_fraction must be in (0, 1].")

        self.feature_columns = tuple(feature_columns or self.DEFAULT_FEATURE_COLUMNS)
        self.label_column = label_column
        self.source_preference = source_preference
        self.window_seconds = float(window_seconds)
        self.stride_seconds = (
            float(stride_seconds) if stride_seconds is not None else None
        )
        self.window_size = window_size
        self.stride = stride
        self.default_sampling_rate_hz = default_sampling_rate_hz
        self.min_labeled_fraction = min_labeled_fraction
        self.pad_short_windows = pad_short_windows
        self.include_partial_last_window = include_partial_last_window

    def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
        """Keep only DREAMT sleep events that reference at least one local file."""
        return df.filter(
            (pl.col("event_type") == "dreamt_sleep")
            & (
                pl.col("dreamt_sleep/signal_file").is_not_null()
                | pl.col("dreamt_sleep/file_64hz").is_not_null()
                | pl.col("dreamt_sleep/file_100hz").is_not_null()
            )
        )

    @classmethod
    def _normalize_stage_label(cls, value: Any) -> int:
        """Map a raw stage annotation to {0, 1, 2} or ``_IGNORE_LABEL``."""
        if pd.isna(value):
            return cls._IGNORE_LABEL
        normalized = str(value).strip().upper()
        if normalized in cls._INVALID_LABELS:
            return cls._IGNORE_LABEL
        return cls._LABEL_MAP.get(normalized, cls._IGNORE_LABEL)

    @staticmethod
    def _load_processed_payload(path: Path) -> Dict[str, Any]:
        """Load a processed ``.npz``/``.npy`` file into a plain dict."""
        if path.suffix.lower() == ".npz":
            with np.load(path, allow_pickle=True) as data:
                return {key: data[key] for key in data.files}

        payload = np.load(path, allow_pickle=True)
        # A 0-d object array wraps the original dict; unwrap it.
        if isinstance(payload, np.ndarray) and payload.shape == ():
            payload = payload.item()
        if not isinstance(payload, dict):
            raise ValueError(
                f"Unsupported DREAMT numpy payload in {path}. "
                "Expected a dict-like object with features and labels."
            )
        return payload

    def _load_signal_source(self, event) -> tuple[Path, Optional[float]]:
        """Pick the best available signal file for ``event``.

        Candidate files are ordered by ``source_preference`` and the first one
        that exists on disk wins. Returns the resolved path and the sampling
        rate associated with that candidate (may be ``None``).
        """
        event_sampling_rate = getattr(event, "sampling_rate_hz", None)
        if event_sampling_rate is not None:
            try:
                event_sampling_rate = float(event_sampling_rate)
            except (TypeError, ValueError):
                event_sampling_rate = None

        candidates: list[tuple[Optional[str], Optional[float]]] = []
        if self.source_preference == "wearable":
            candidates.extend(
                [
                    (getattr(event, "file_64hz", None), 64.0),
                    (getattr(event, "signal_file", None), event_sampling_rate),
                    (getattr(event, "file_100hz", None), 100.0),
                ]
            )
        elif self.source_preference == "psg":
            candidates.extend(
                [
                    (getattr(event, "file_100hz", None), 100.0),
                    (getattr(event, "signal_file", None), event_sampling_rate),
                    (getattr(event, "file_64hz", None), 64.0),
                ]
            )
        else:
            candidates.extend(
                [
                    (getattr(event, "signal_file", None), event_sampling_rate),
                    (getattr(event, "file_64hz", None), 64.0),
                    (getattr(event, "file_100hz", None), 100.0),
                ]
            )

        for file_path, sample_rate in candidates:
            # NaN shows up when the metadata column was empty in the CSV.
            if file_path is None or (
                isinstance(file_path, float) and np.isnan(file_path)
            ):
                continue
            path = Path(str(file_path)).expanduser().resolve()
            if path.exists():
                return path, sample_rate

        raise FileNotFoundError(
            "No DREAMT signal file was found for the requested patient event."
        )

    def _dataframe_from_payload(
        self,
        payload: Dict[str, Any],
        path: Path,
    ) -> pd.DataFrame:
        """Convert a processed dict payload into a labeled DataFrame."""
        if "frame" in payload:
            frame = payload["frame"]
            if isinstance(frame, pd.DataFrame):
                return frame.copy()
            return pd.DataFrame(frame)

        if "features" not in payload or "labels" not in payload:
            raise ValueError(
                f"Processed DREAMT file {path} must contain 'features' and 'labels'."
            )

        features = np.asarray(payload["features"], dtype=np.float32)
        labels = np.asarray(payload["labels"])
        if features.ndim != 2:
            raise ValueError(
                f"Processed DREAMT file {path} must provide features with shape [T, F]."
            )
        if labels.ndim != 1 or labels.shape[0] != features.shape[0]:
            raise ValueError(
                f"Processed DREAMT file {path} must provide labels with shape [T]."
            )

        feature_names = payload.get("feature_names", self.feature_columns)
        feature_names = [str(name) for name in feature_names]
        if len(feature_names) != features.shape[1]:
            raise ValueError(
                f"feature_names length does not match feature width in {path}."
            )

        frame = pd.DataFrame(features, columns=feature_names)
        frame[self.label_column] = labels
        if "timestamps" in payload:
            frame["TIMESTAMP"] = np.asarray(payload["timestamps"])
        return frame

    def _load_frame(self, path: Path) -> pd.DataFrame:
        """Load a subject recording from any supported on-disk format."""
        suffix = path.suffix.lower()
        if suffix == ".csv":
            return pd.read_csv(path)
        if suffix == ".parquet":
            return pd.read_parquet(path)
        if suffix in {".pkl", ".pickle"}:
            payload = pd.read_pickle(path)
            if isinstance(payload, pd.DataFrame):
                return payload
            if isinstance(payload, dict):
                return self._dataframe_from_payload(payload, path)
            raise ValueError(f"Unsupported pickle payload in {path}.")
        if suffix in {".npz", ".npy"}:
            payload = self._load_processed_payload(path)
            return self._dataframe_from_payload(payload, path)
        raise ValueError(f"Unsupported DREAMT file format: {path.suffix}")

    @staticmethod
    def _infer_sampling_rate_hz(
        frame: pd.DataFrame,
        fallback: Optional[float],
    ) -> float:
        """Infer the sampling rate from TIMESTAMP deltas, else use ``fallback``."""
        if "TIMESTAMP" in frame.columns:
            timestamps = pd.to_numeric(frame["TIMESTAMP"], errors="coerce").to_numpy()
            diffs = np.diff(timestamps)
            diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
            if diffs.size > 0:
                # Median step is robust to occasional recording gaps.
                median_step = float(np.median(diffs))
                if median_step > 0:
                    return 1.0 / median_step
        if fallback is not None and fallback > 0:
            return float(fallback)
        raise ValueError(
            "Unable to infer DREAMT sampling rate. Provide TIMESTAMP values or "
            "set default_sampling_rate_hz."
        )

    def _resolve_window_params(self, sample_rate_hz: float) -> tuple[int, int]:
        """Resolve (window_size, stride) in samples, preferring explicit counts."""
        window_size = self.window_size
        if window_size is None:
            window_size = max(1, int(round(self.window_seconds * sample_rate_hz)))

        stride = self.stride
        if stride is None:
            if self.stride_seconds is not None:
                stride = max(1, int(round(self.stride_seconds * sample_rate_hz)))
            else:
                stride = window_size
        return window_size, stride

    def _extract_feature_frame(self, frame: pd.DataFrame, path: Path) -> pd.DataFrame:
        """Build the numeric feature frame in ``feature_columns`` order.

        Columns missing from the file are filled with zeros; remaining NaNs are
        forward/backward filled and finally zeroed.
        """
        available_columns = {str(column): column for column in frame.columns}
        feature_data = {}
        for column in self.feature_columns:
            source_column = available_columns.get(column)
            if source_column is None:
                feature_data[column] = np.zeros(len(frame), dtype=np.float32)
                continue
            values = pd.to_numeric(frame[source_column], errors="coerce")
            feature_data[column] = values.to_numpy(dtype=np.float32)

        if not feature_data:
            raise ValueError(f"No wearable feature columns were found in {path}.")

        feature_frame = pd.DataFrame(feature_data)
        feature_frame = feature_frame.ffill().bfill().fillna(0.0)
        return feature_frame

    def _extract_labels(self, frame: pd.DataFrame, path: Path) -> np.ndarray:
        """Return normalized int64 stage labels for every row of ``frame``."""
        if self.label_column not in frame.columns:
            raise ValueError(
                f"DREAMT file {path} is missing the label column '{self.label_column}'."
            )
        labels = frame[self.label_column].apply(self._normalize_stage_label).to_numpy()
        return labels.astype(np.int64, copy=False)

    @staticmethod
    def _majority_label(labels: np.ndarray) -> int:
        """Most frequent valid label in ``labels``; ignore label if none valid."""
        valid = labels[labels >= 0]
        if valid.size == 0:
            return SleepStagingDREAMT._IGNORE_LABEL
        return int(np.bincount(valid).argmax())

    def __call__(self, patient) -> list[dict[str, Any]]:
        """Produce one window-classification sample per valid window.

        Each sample carries the window's feature matrix, its majority stage
        label, and provenance fields (signal source/file, record id).
        """
        events = patient.get_events("dreamt_sleep")
        if not events:
            return []

        samples = []
        for event in events:
            signal_path, sampling_rate_hz = self._load_signal_source(event)
            frame = self._load_frame(signal_path)
            labels = self._extract_labels(frame, signal_path)
            valid_mask = labels >= 0

            if valid_mask.sum() == 0:
                continue

            frame = frame.loc[valid_mask].reset_index(drop=True)
            labels = labels[valid_mask]
            feature_frame = self._extract_feature_frame(frame, signal_path)
            features = feature_frame.to_numpy(dtype=np.float32, copy=False)

            inferred_rate = self._infer_sampling_rate_hz(
                frame,
                sampling_rate_hz or self.default_sampling_rate_hz,
            )
            window_size, stride = self._resolve_window_params(inferred_rate)

            total_length = features.shape[0]
            if total_length < window_size and not self.pad_short_windows:
                continue

            starts = list(range(0, max(total_length - window_size + 1, 1), stride))
            if (
                self.include_partial_last_window
                and total_length > window_size
                and starts[-1] + window_size < total_length
            ):
                starts.append(starts[-1] + stride)

            for window_index, start in enumerate(starts):
                end = min(start + window_size, total_length)
                if end <= start:
                    # BUGFIX: include_partial_last_window can propose a start
                    # at or beyond the end of the recording (e.g. when stride
                    # exceeds the remaining samples). The empty slice would
                    # make the labeled-fraction mean below NaN (with a NumPy
                    # RuntimeWarning); skip such windows explicitly instead.
                    continue
                window_features = features[start:end]
                window_labels = labels[start:end]

                labeled_fraction = float((window_labels >= 0).mean())
                if labeled_fraction < self.min_labeled_fraction:
                    continue

                label = self._majority_label(window_labels)
                if label == self._IGNORE_LABEL:
                    continue

                if end - start < window_size:
                    if not self.pad_short_windows:
                        continue
                    # Zero-pad the tail so every emitted window has a uniform
                    # [window_size, num_features] shape.
                    pad_rows = window_size - (end - start)
                    window_features = np.pad(
                        window_features,
                        pad_width=((0, pad_rows), (0, 0)),
                        mode="constant",
                    )

                record_id = f"{patient.patient_id}-{signal_path.stem}-{window_index}"
                samples.append(
                    {
                        "patient_id": patient.patient_id,
                        "record_id": record_id,
                        "signal": window_features.astype(np.float32, copy=False),
                        "label": int(label),
                        "signal_source": getattr(event, "signal_source", None),
                        "signal_file": str(signal_path),
                    }
                )

        return samples


class SleepStagingDREAMTSeq(SleepStagingDREAMT):
    """Sequence-style DREAMT sleep staging task closer to WatchSleepNet.

    This task converts a DREAMT recording into a sequence of epoch-level
    feature vectors and labels. Each sample contains:

    - ``signal``: ``[sequence_length, input_dim]``
    - ``mask``: ``[sequence_length]`` with 1 for valid epochs and 0 for padding
    - ``label``: ``[sequence_length]`` with padded labels set to
      ``ignore_index``

    By default, this class uses ``IBI`` only to better align with the paper's
    shared-modality representation.
    """

    task_name: str = "SleepStagingDREAMTSeq"
    input_schema: Dict[str, str] = {
        "signal": "tensor",
        "mask": "tensor",
    }
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(
        self,
        feature_columns: Optional[Sequence[str]] = ("IBI",),
        label_column: str = "Sleep_Stage",
        source_preference: str = "wearable",
        epoch_seconds: float = 30.0,
        sequence_length: int = 1100,
        stride_epochs: Optional[int] = None,
        default_sampling_rate_hz: Optional[float] = None,
        pad_value: float = 0.0,
        truncate: bool = True,
        ignore_index: int = -100,
    ) -> None:
        super().__init__(
            feature_columns=feature_columns,
            label_column=label_column,
            source_preference=source_preference,
            default_sampling_rate_hz=default_sampling_rate_hz,
        )
        if epoch_seconds <= 0:
            raise ValueError("epoch_seconds must be positive.")
        if sequence_length <= 0:
            raise ValueError("sequence_length must be positive.")

        self.epoch_seconds = float(epoch_seconds)
        self.sequence_length = int(sequence_length)
        self.stride_epochs = stride_epochs
        self.pad_value = float(pad_value)
        self.truncate = truncate
        self.ignore_index = int(ignore_index)

    def _epochize_features_and_labels(
        self,
        features: np.ndarray,
        labels: np.ndarray,
        sample_rate_hz: float,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Aggregate sample-level data into epoch-level means and labels.

        Epochs whose majority label is invalid are dropped entirely. Returns
        arrays of shape ``[num_epochs, F]`` and ``[num_epochs]`` (possibly
        empty).
        """
        epoch_size = max(1, int(round(self.epoch_seconds * sample_rate_hz)))
        num_epochs = features.shape[0] // epoch_size
        if num_epochs == 0:
            return (
                np.zeros((0, features.shape[1]), dtype=np.float32),
                np.zeros((0,), dtype=np.int64),
            )

        epoch_features: list[np.ndarray] = []
        epoch_labels: list[int] = []
        for epoch_index in range(num_epochs):
            start = epoch_index * epoch_size
            end = start + epoch_size
            epoch_feature_values = features[start:end]
            epoch_label_values = labels[start:end]
            epoch_label = self._majority_label(epoch_label_values)
            if epoch_label == self._IGNORE_LABEL:
                continue
            epoch_features.append(
                epoch_feature_values.mean(axis=0).astype(np.float32, copy=False)
            )
            epoch_labels.append(epoch_label)

        if not epoch_features:
            return (
                np.zeros((0, features.shape[1]), dtype=np.float32),
                np.zeros((0,), dtype=np.int64),
            )

        return (
            np.stack(epoch_features).astype(np.float32, copy=False),
            np.asarray(epoch_labels, dtype=np.int64),
        )

    def _build_sequence_sample(
        self,
        patient_id: str,
        signal_path: Path,
        epoch_features: np.ndarray,
        epoch_labels: np.ndarray,
        chunk_index: int,
        start_epoch: int,
    ) -> dict[str, Any]:
        """Pad/assemble one fixed-length sequence sample from epoch arrays."""
        valid_length = min(epoch_features.shape[0], self.sequence_length)
        feature_dim = epoch_features.shape[1]

        signal = np.full(
            (self.sequence_length, feature_dim),
            fill_value=self.pad_value,
            dtype=np.float32,
        )
        mask = np.zeros((self.sequence_length,), dtype=np.float32)
        labels = np.full(
            (self.sequence_length,),
            fill_value=self.ignore_index,
            dtype=np.int64,
        )

        signal[:valid_length] = epoch_features[:valid_length]
        mask[:valid_length] = 1.0
        labels[:valid_length] = epoch_labels[:valid_length]

        return {
            "patient_id": patient_id,
            "record_id": f"{patient_id}-{signal_path.stem}-seq-{chunk_index}",
            "signal": signal,
            "mask": mask,
            "label": labels,
            "signal_source": None,
            "signal_file": str(signal_path),
            "start_epoch": int(start_epoch),
        }

    def __call__(self, patient) -> list[dict[str, Any]]:
        """Produce fixed-length epoch sequences for each DREAMT recording.

        Short recordings yield a single padded sample; long recordings are
        chunked with ``stride_epochs`` (defaults to non-overlapping chunks).
        With ``truncate=False`` a final chunk is added so the tail is covered.
        """
        events = patient.get_events("dreamt_sleep")
        if not events:
            return []

        samples = []
        for event in events:
            signal_path, sampling_rate_hz = self._load_signal_source(event)
            frame = self._load_frame(signal_path)
            labels = self._extract_labels(frame, signal_path)
            feature_frame = self._extract_feature_frame(frame, signal_path)
            features = feature_frame.to_numpy(dtype=np.float32, copy=False)

            inferred_rate = self._infer_sampling_rate_hz(
                frame,
                sampling_rate_hz or self.default_sampling_rate_hz,
            )
            epoch_features, epoch_labels = self._epochize_features_and_labels(
                features,
                labels,
                inferred_rate,
            )
            if epoch_features.shape[0] == 0:
                continue

            stride_epochs = self.stride_epochs or self.sequence_length
            if epoch_features.shape[0] <= self.sequence_length:
                samples.append(
                    self._build_sequence_sample(
                        patient.patient_id,
                        signal_path,
                        epoch_features,
                        epoch_labels,
                        chunk_index=0,
                        start_epoch=0,
                    )
                )
                continue

            max_start = epoch_features.shape[0] - self.sequence_length
            starts = list(range(0, max_start + 1, stride_epochs))
            if (
                not self.truncate
                and starts[-1] != max_start
            ):
                starts.append(max_start)

            for chunk_index, start_epoch in enumerate(starts):
                end_epoch = start_epoch + self.sequence_length
                samples.append(
                    self._build_sequence_sample(
                        patient.patient_id,
                        signal_path,
                        epoch_features[start_epoch:end_epoch],
                        epoch_labels[start_epoch:end_epoch],
                        chunk_index=chunk_index,
                        start_epoch=start_epoch,
                    )
                )

        return samples
# Only export the fixture helper: keeps the TestCase out of star-imports so
# other test modules reusing the helper do not double-collect these tests.
__all__ = ["_create_synthetic_dreamt_root"]


def _create_synthetic_dreamt_root(root: Path) -> Path:
    """Materialize a tiny two-subject replica of the DREAMT 2.1.0 release.

    The layout mirrors the official PhysioNet download:
    ``<root>/physionet.org/files/dreamt/2.1.0`` containing
    ``participant_info.csv`` plus per-subject signal files under
    ``data_64Hz`` (wearable) and ``data_100Hz`` (PSG).

    Args:
        root: Temporary directory to build the tree under.

    Returns:
        Path to the synthetic versioned DREAMT directory.
    """
    dreamt_root = root / "physionet.org" / "files" / "dreamt" / "2.1.0"
    (dreamt_root / "data_64Hz").mkdir(parents=True)
    (dreamt_root / "data_100Hz").mkdir(parents=True)

    participant_rows = [
        {
            "SID": "S001",
            "AGE": 28.0,
            "GENDER": "F",
            "BMI": 21.0,
            "OAHI": 1.0,
            "AHI": 3.0,
            "Mean_SaO2": "97%",
            "Arousal Index": 10.0,
            "MEDICAL_HISTORY": "None",
            "Sleep_Disorders": "OSA",
        },
        {
            "SID": "S002",
            "AGE": 35.0,
            "GENDER": "M",
            "BMI": 24.5,
            "OAHI": 2.0,
            "AHI": 4.0,
            "Mean_SaO2": "96%",
            "Arousal Index": 12.0,
            "MEDICAL_HISTORY": "None",
            "Sleep_Disorders": "None",
        },
    ]
    pd.DataFrame(participant_rows).to_csv(
        dreamt_root / "participant_info.csv", index=False
    )

    # Ten samples per stage at a notional 64 Hz — enough for metadata tests.
    stage_pattern = ["P"] * 10 + ["W"] * 10 + ["N2"] * 10 + ["R"] * 10
    num_samples = len(stage_pattern)
    timestamps = np.arange(num_samples, dtype=np.float32) / 64.0

    for patient_id in ("S001", "S002"):
        frame = pd.DataFrame(
            {
                "TIMESTAMP": timestamps,
                "IBI": np.linspace(0.8, 1.2, num_samples),
                "HR": np.linspace(60.0, 68.0, num_samples),
                "BVP": np.linspace(0.1, 0.5, num_samples),
                "EDA": np.linspace(0.01, 0.04, num_samples),
                "TEMP": np.linspace(33.0, 33.5, num_samples),
                "ACC_X": np.zeros(num_samples),
                "ACC_Y": np.ones(num_samples),
                "ACC_Z": np.full(num_samples, 2.0),
                "Sleep_Stage": stage_pattern,
            }
        )
        # The same trace doubles as wearable (64 Hz) and PSG (100 Hz) input.
        frame.to_csv(
            dreamt_root / "data_64Hz" / f"{patient_id}_whole_df.csv",
            index=False,
        )
        frame.to_csv(
            dreamt_root / "data_100Hz" / f"{patient_id}_PSG_df_updated.csv",
            index=False,
        )

    return dreamt_root


class TestDREAMTDataset(unittest.TestCase):
    """End-to-end checks of ``DREAMTDataset`` against a synthetic local tree."""

    def setUp(self):
        self.temp_dir = Path(tempfile.mkdtemp())
        self.synthetic_root = _create_synthetic_dreamt_root(self.temp_dir)

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_resolves_nested_root_and_creates_metadata(self):
        """A parent directory is resolved down to the versioned DREAMT root."""
        dataset = DREAMTDataset(
            root=str(self.temp_dir),
            cache_dir=self.temp_dir / "cache",
        )

        self.assertEqual(
            Path(dataset.root).resolve(),
            self.synthetic_root.resolve(),
        )
        self.assertEqual(dataset.dataset_name, "dreamt_sleep")
        self.assertTrue((self.synthetic_root / "dreamt-metadata.csv").exists())
        self.assertEqual(len(dataset.unique_patient_ids), 2)

    def test_patient_event_contains_signal_references(self):
        """With ``preferred_source='psg'`` events point at the 100 Hz files."""
        dataset = DREAMTDataset(
            root=str(self.synthetic_root),
            preferred_source="psg",
            cache_dir=self.temp_dir / "cache_psg",
        )

        patient = dataset.get_patient("S001")
        event = patient.get_events("dreamt_sleep")[0]

        self.assertTrue(event.signal_file.endswith("S001_PSG_df_updated.csv"))
        self.assertTrue(event.file_64hz.endswith("S001_whole_df.csv"))
        self.assertTrue(event.file_100hz.endswith("S001_PSG_df_updated.csv"))
        self.assertEqual(event.signal_source, "psg")
        self.assertEqual(float(event.sampling_rate_hz), 100.0)

    def test_missing_root_raises_helpful_error(self):
        with self.assertRaises(FileNotFoundError):
            DREAMTDataset(root=str(self.temp_dir / "missing"))

    def test_partial_download_drops_missing_subjects(self):
        """Subjects whose signal files are absent are excluded, not fatal."""
        for relative in (
            Path("data_64Hz") / "S002_whole_df.csv",
            Path("data_100Hz") / "S002_PSG_df_updated.csv",
        ):
            (self.synthetic_root / relative).unlink()
        # Force metadata regeneration so the missing subject is re-scanned.
        metadata_file = self.synthetic_root / "dreamt-metadata.csv"
        if metadata_file.exists():
            metadata_file.unlink()

        dataset = DREAMTDataset(
            root=str(self.synthetic_root),
            cache_dir=self.temp_dir / "cache_partial",
        )

        self.assertEqual(len(dataset.unique_patient_ids), 1)
        self.assertEqual(dataset.unique_patient_ids[0], "S001")


if __name__ == "__main__":
    unittest.main()
# Only export the fixture helpers: keeps the TestCases out of star-imports so
# other test modules reusing the fixtures do not double-collect these tests.
__all__ = ["_create_task_test_root", "_create_seq_task_test_root"]


def _write_participant_info(dreamt_root: Path) -> None:
    """Write the single-subject participant_info.csv shared by both fixtures."""
    participant_info = pd.DataFrame(
        {
            "SID": ["S001"],
            "AGE": [29.0],
            "GENDER": ["F"],
            "BMI": [22.0],
            "OAHI": [1.0],
            "AHI": [1.0],
            "Mean_SaO2": ["98%"],
            "Arousal Index": [9.0],
            "MEDICAL_HISTORY": ["None"],
            "Sleep_Disorders": ["None"],
        }
    )
    participant_info.to_csv(dreamt_root / "participant_info.csv", index=False)


def _create_task_test_root(root: Path) -> Path:
    """Build a one-subject wearable-only DREAMT tree for the window task.

    Writes 40 rows at a notional 64 Hz with ten samples each of stages
    P / W / N1 / R under ``<root>/dreamt/data_64Hz``.
    """
    dreamt_root = root / "dreamt"
    (dreamt_root / "data_64Hz").mkdir(parents=True)
    _write_participant_info(dreamt_root)

    stages = ["P"] * 10 + ["W"] * 10 + ["N1"] * 10 + ["R"] * 10
    num_rows = len(stages)
    frame = pd.DataFrame(
        {
            "TIMESTAMP": [index / 64.0 for index in range(num_rows)],
            "IBI": [0.8 + 0.01 * index for index in range(num_rows)],
            "HR": [60.0 + index for index in range(num_rows)],
            "BVP": [0.1] * num_rows,
            "EDA": [0.01] * num_rows,
            "TEMP": [33.2] * num_rows,
            "ACC_X": [0.0] * num_rows,
            "ACC_Y": [1.0] * num_rows,
            "ACC_Z": [2.0] * num_rows,
            "Sleep_Stage": stages,
        }
    )
    frame.to_csv(dreamt_root / "data_64Hz" / "S001_whole_df.csv", index=False)
    return dreamt_root


def _create_seq_task_test_root(root: Path) -> Path:
    """Build a one-subject tree with 30-sample stage blocks for the seq task.

    Writes 120 rows (stages P / W / N2 / R, thirty samples each) under
    ``<root>/dreamt_seq/data_64Hz``.
    """
    dreamt_root = root / "dreamt_seq"
    (dreamt_root / "data_64Hz").mkdir(parents=True)
    _write_participant_info(dreamt_root)

    labels = [stage for stage in ("P", "W", "N2", "R") for _ in range(30)]
    num_rows = len(labels)
    frame = pd.DataFrame(
        {
            "TIMESTAMP": np.arange(num_rows, dtype=np.float32),
            "IBI": np.linspace(0.8, 1.3, num_rows),
            "HR": np.linspace(60.0, 72.0, num_rows),
            "BVP": np.linspace(0.1, 0.4, num_rows),
            "EDA": np.linspace(0.01, 0.03, num_rows),
            "TEMP": np.full(num_rows, 33.2),
            "ACC_X": np.zeros(num_rows),
            "ACC_Y": np.ones(num_rows),
            "ACC_Z": np.full(num_rows, 2.0),
            "Sleep_Stage": labels,
        }
    )
    frame.to_csv(dreamt_root / "data_64Hz" / "S001_whole_df.csv", index=False)
    return dreamt_root


class TestSleepStagingDREAMT(unittest.TestCase):
    """Window-classification task over the synthetic wearable fixture."""

    def setUp(self):
        self.temp_dir = Path(tempfile.mkdtemp())
        self.root = _create_task_test_root(self.temp_dir)
        self.dataset = DREAMTDataset(
            root=str(self.root),
            cache_dir=self.temp_dir / "cache",
        )

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_task_maps_labels_and_extracts_fixed_windows(self):
        # NOTE(review): expected values assume the task drops "P" windows and
        # maps W/N1/R -> 0/1/2 — confirm against the task's label mapping.
        task = SleepStagingDREAMT(
            window_size=10,
            stride=10,
            source_preference="wearable",
        )
        patient = self.dataset.get_patient("S001")

        samples = task(patient)

        self.assertEqual(len(samples), 3)
        self.assertEqual([sample["label"] for sample in samples], [0, 1, 2])
        self.assertEqual(samples[0]["signal"].shape, (10, 8))
        self.assertEqual(samples[0]["patient_id"], "S001")
        self.assertIn("record_id", samples[0])

    def test_set_task_builds_sample_dataset(self):
        task = SleepStagingDREAMT(
            window_size=10,
            stride=10,
            source_preference="wearable",
        )
        sample_dataset = self.dataset.set_task(task=task, num_workers=1)

        self.assertEqual(len(sample_dataset), 3)
        sample = sample_dataset[0]
        self.assertIn("signal", sample)
        self.assertIn("label", sample)
        self.assertEqual(tuple(sample["signal"].shape), (10, 8))


class TestSleepStagingDREAMTSeq(unittest.TestCase):
    """Sequence-style task over the synthetic 30-second-epoch fixture."""

    def setUp(self):
        self.temp_dir = Path(tempfile.mkdtemp())
        self.root = _create_seq_task_test_root(self.temp_dir)
        self.dataset = DREAMTDataset(
            root=str(self.root),
            cache_dir=self.temp_dir / "cache",
        )

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_seq_task_builds_epoch_sequence(self):
        task = SleepStagingDREAMTSeq(
            epoch_seconds=30,
            sequence_length=5,
            source_preference="wearable",
        )
        patient = self.dataset.get_patient("S001")

        samples = task(patient)

        self.assertEqual(len(samples), 1)
        self.assertEqual(samples[0]["signal"].shape, (5, 1))
        self.assertEqual(samples[0]["mask"].tolist(), [1.0, 1.0, 1.0, 0.0, 0.0])
        self.assertEqual(samples[0]["label"][:3].tolist(), [0, 1, 2])

    def test_seq_task_outputs_mask_and_padded_labels(self):
        task = SleepStagingDREAMTSeq(
            epoch_seconds=30,
            sequence_length=5,
            source_preference="wearable",
            ignore_index=-100,
        )
        patient = self.dataset.get_patient("S001")
        sample = task(patient)[0]

        self.assertTrue(np.allclose(sample["signal"][3:], 0.0))
        self.assertEqual(sample["label"][3:].tolist(), [-100, -100])

    def test_seq_task_ibi_only_mode(self):
        task = SleepStagingDREAMTSeq(
            feature_columns=("IBI",),
            epoch_seconds=30,
            sequence_length=5,
            source_preference="wearable",
        )
        patient = self.dataset.get_patient("S001")
        sample = task(patient)[0]

        self.assertEqual(sample["signal"].shape[1], 1)

    def test_seq_set_task_builds_sample_dataset(self):
        task = SleepStagingDREAMTSeq(
            epoch_seconds=30,
            sequence_length=5,
            source_preference="wearable",
        )
        sample_dataset = self.dataset.set_task(task=task, num_workers=1)
        sample = sample_dataset[0]

        self.assertIn("signal", sample)
        self.assertIn("mask", sample)
        self.assertIn("label", sample)
        self.assertEqual(tuple(sample["signal"].shape), (5, 1))


if __name__ == "__main__":
    unittest.main()
class TestWatchSleepNet(unittest.TestCase):
    """Shape, loss and gradient checks for the WatchSleepNet model.

    Covers both pooled window classification and sequence-output
    (epoch-level) staging with masked/padded labels.
    """

    def setUp(self):
        rng = np.random.default_rng(42)

        # Pooled-classification fixture: six windows of shape (24, 8).
        self.samples = [
            {
                "patient_id": f"patient-{index}",
                "record_id": f"record-{index}",
                "signal": rng.normal(size=(24, 8)).astype(np.float32),
                "label": index % 3,
            }
            for index in range(6)
        ]
        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema={"signal": "tensor"},
            output_schema={"label": "multiclass"},
            dataset_name="watchsleepnet_test",
        )
        self.model = WatchSleepNet(
            dataset=self.dataset,
            hidden_dim=16,
            conv_channels=16,
            conv_blocks=2,
            tcn_blocks=2,
            num_attention_heads=4,
            dropout=0.1,
        )

        # Sequence fixture: four recordings of five epochs, last two padded
        # (mask 0, label -100).
        self.sequence_samples = [
            {
                "patient_id": f"seq-patient-{index}",
                "record_id": f"seq-record-{index}",
                "signal": rng.normal(size=(5, 4)).astype(np.float32),
                "mask": np.array([1, 1, 1, 0, 0], dtype=np.float32),
                "label": np.array([0, 1, 2, -100, -100], dtype=np.int64),
            }
            for index in range(4)
        ]
        self.sequence_dataset = create_sample_dataset(
            samples=self.sequence_samples,
            input_schema={"signal": "tensor", "mask": "tensor"},
            output_schema={"label": "tensor"},
            dataset_name="watchsleepnet_seq_test",
        )

    def _first_batch(self, dataset, batch_size):
        """Return the first (deterministic) batch from a dataset."""
        loader = get_dataloader(dataset, batch_size=batch_size, shuffle=False)
        return next(iter(loader))

    def _make_sequence_model(self, **overrides):
        """Sequence-output model with the small default test config."""
        config = dict(
            dataset=self.sequence_dataset,
            input_dim=4,
            hidden_dim=12,
            conv_channels=12,
            num_attention_heads=3,
            num_classes=3,
            sequence_output=True,
        )
        config.update(overrides)
        return WatchSleepNet(**config)

    def test_model_initialization(self):
        self.assertIsInstance(self.model, WatchSleepNet)
        self.assertEqual(self.model.feature_key, "signal")
        self.assertEqual(self.model.input_dim, 8)
        self.assertEqual(self.model.hidden_dim, 16)
        self.assertEqual(self.model.conv_channels, 16)

    def test_forward_shapes(self):
        batch = self._first_batch(self.dataset, 3)

        with torch.no_grad():
            output = self.model(**batch)

        for key in ("loss", "y_prob", "y_true", "logit"):
            self.assertIn(key, output)
        self.assertEqual(tuple(output["logit"].shape), (3, 3))
        self.assertEqual(tuple(output["y_prob"].shape), (3, 3))
        self.assertEqual(tuple(output["y_true"].shape), (3,))

    def test_backward_pass(self):
        batch = self._first_batch(self.dataset, 3)

        output = self.model(**batch)
        output["loss"].backward()

        has_grad = any(
            parameter.requires_grad and parameter.grad is not None
            for parameter in self.model.parameters()
        )
        self.assertTrue(has_grad)

    def test_without_attention_and_tcn(self):
        model = WatchSleepNet(
            dataset=self.dataset,
            hidden_dim=12,
            conv_channels=12,
            num_attention_heads=3,
            use_attention=False,
            use_tcn=False,
        )
        batch = self._first_batch(self.dataset, 2)

        with torch.no_grad():
            output = model(**batch)

        self.assertEqual(tuple(output["logit"].shape), (2, 3))

    def test_sequence_output_shapes(self):
        model = self._make_sequence_model()
        batch = self._first_batch(self.sequence_dataset, 2)

        with torch.no_grad():
            output = model(**batch)

        self.assertEqual(tuple(output["logit"].shape), (2, 5, 3))
        self.assertEqual(tuple(output["y_prob"].shape), (2, 5, 3))
        self.assertEqual(tuple(output["y_true"].shape), (2, 5))

    def test_sequence_mode_requires_num_classes(self):
        # Constructed directly (not via the helper) because the point is the
        # missing num_classes argument.
        with self.assertRaises(ValueError):
            WatchSleepNet(
                dataset=self.sequence_dataset,
                input_dim=4,
                hidden_dim=12,
                conv_channels=12,
                num_attention_heads=3,
                sequence_output=True,
            )

    def test_sequence_loss_with_padding(self):
        model = self._make_sequence_model()
        batch = self._first_batch(self.sequence_dataset, 2)
        output = model(**batch)

        # Padded positions carry label -100; loss must still be a finite
        # scalar.
        self.assertEqual(output["loss"].dim(), 0)
        self.assertTrue(torch.isfinite(output["loss"]))

    def test_sequence_backward_pass(self):
        model = self._make_sequence_model()
        batch = self._first_batch(self.sequence_dataset, 2)
        output = model(**batch)
        output["loss"].backward()

        has_grad = any(
            parameter.requires_grad and parameter.grad is not None
            for parameter in model.parameters()
        )
        self.assertTrue(has_grad)

    def test_sequence_mode_without_attention_and_tcn(self):
        model = self._make_sequence_model(use_attention=False, use_tcn=False)
        batch = self._first_batch(self.sequence_dataset, 2)

        with torch.no_grad():
            output = model(**batch)

        self.assertEqual(tuple(output["logit"].shape), (2, 5, 3))


if __name__ == "__main__":
    unittest.main()