diff --git a/docs/api/datasets/pyhealth.datasets.DREAMTDataset.rst b/docs/api/datasets/pyhealth.datasets.DREAMTDataset.rst
index 61cc16c2e..3d45696be 100644
--- a/docs/api/datasets/pyhealth.datasets.DREAMTDataset.rst
+++ b/docs/api/datasets/pyhealth.datasets.DREAMTDataset.rst
@@ -1,21 +1,24 @@
pyhealth.datasets.DREAMTDataset
===================================
-The Dataset for Real-time sleep stage EstimAtion using Multisensor wearable Technology (DREAMT) includes wrist-based wearable and polysomnography (PSG) sleep data from 100 participants recruited from the Duke University Health System (DUHS) Sleep Disorder Lab.
+The Dataset for Real-time sleep stage EstimAtion using Multisensor wearable
+Technology (DREAMT) includes wrist-based wearable and polysomnography (PSG)
+sleep data from 100 participants recruited from the Duke University Health
+System (DUHS) Sleep Disorder Lab.
-This includes wearable signals, PSG signals, sleep labels, and clinical data related to sleep health and disorders.
+This includes wearable signals, PSG signals, sleep labels, and clinical data
+related to sleep health and disorders.
-The DREAMTDataset class provides an interface for loading and working with the DREAMT dataset. It can process DREAMT data across versions into a well-structured dataset object providing support for modeling and analysis.
+``DREAMTDataset`` supports both official DREAMT release layouts and partial
+local subsets. It builds metadata linking each patient to locally available
+signal files and exposes data for both the simplified
+``SleepStagingDREAMT`` window-classification task and the paper-aligned
+sequence-style ``SleepStagingDREAMTSeq`` task.
-Refer to the `doc `_ for more information about the dataset.
+Refer to the `DREAMT documentation <https://physionet.org/content/dreamt/>`_
+for more information about the dataset.
.. autoclass:: pyhealth.datasets.DREAMTDataset
:members:
:undoc-members:
:show-inheritance:
-
-
-
-
-
-
\ No newline at end of file
diff --git a/docs/api/models.rst b/docs/api/models.rst
index 7368dec94..789f1cf21 100644
--- a/docs/api/models.rst
+++ b/docs/api/models.rst
@@ -190,6 +190,7 @@ API Reference
models/pyhealth.models.SparcNet
models/pyhealth.models.StageNet
models/pyhealth.models.StageAttentionNet
+ models/pyhealth.models.WatchSleepNet
models/pyhealth.models.AdaCare
models/pyhealth.models.ConCare
models/pyhealth.models.Agent
diff --git a/docs/api/models/pyhealth.models.WatchSleepNet.rst b/docs/api/models/pyhealth.models.WatchSleepNet.rst
new file mode 100644
index 000000000..4b3cd3627
--- /dev/null
+++ b/docs/api/models/pyhealth.models.WatchSleepNet.rst
@@ -0,0 +1,14 @@
+pyhealth.models.WatchSleepNet
+===================================
+
+Simplified WatchSleepNet-style architecture for wearable sleep staging.
+
+The implementation supports both:
+
+- pooled classification over fixed windows
+- sequence-output classification for epoch-level sleep staging
+
+.. autoclass:: pyhealth.models.WatchSleepNet
+ :members:
+ :undoc-members:
+ :show-inheritance:
diff --git a/docs/api/tasks/pyhealth.tasks.sleep_staging.rst b/docs/api/tasks/pyhealth.tasks.sleep_staging.rst
index eab36ee9b..a9a5a3e0d 100644
--- a/docs/api/tasks/pyhealth.tasks.sleep_staging.rst
+++ b/docs/api/tasks/pyhealth.tasks.sleep_staging.rst
@@ -1,6 +1,19 @@
pyhealth.tasks.sleep_staging
=======================================
+``SleepStagingDREAMT`` provides a simplified window-classification task, while
+``SleepStagingDREAMTSeq`` provides a more paper-aligned sequence-style task.
+
+.. autoclass:: pyhealth.tasks.SleepStagingDREAMT
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
+.. autoclass:: pyhealth.tasks.SleepStagingDREAMTSeq
+ :members:
+ :undoc-members:
+ :show-inheritance:
+
.. autofunction:: pyhealth.tasks.sleep_staging.sleep_staging_isruc_fn
.. autofunction:: pyhealth.tasks.sleep_staging.sleep_staging_sleepedf_fn
-.. autofunction:: pyhealth.tasks.sleep_staging.sleep_staging_shhs_fn
\ No newline at end of file
+.. autofunction:: pyhealth.tasks.sleep_staging.sleep_staging_shhs_fn
diff --git a/examples/dreamt_sleep_staging_watchsleepnet.py b/examples/dreamt_sleep_staging_watchsleepnet.py
new file mode 100644
index 000000000..fca9bb04e
--- /dev/null
+++ b/examples/dreamt_sleep_staging_watchsleepnet.py
@@ -0,0 +1,377 @@
+"""Synthetic DREAMT sleep-staging example with WatchSleepNet ablations.
+
+This script demonstrates two usage patterns:
+
+1. A simplified window-classification pipeline.
+2. A more paper-aligned sequence-style pipeline using IBI-only epoch features.
+
+The sequence example reports metrics emphasized in the WatchSleepNet paper:
+accuracy, macro F1, REM F1, Cohen's kappa, and AUROC. The training loop is
+intentionally lightweight and uses synthetic DREAMT-compatible data only. It
+does not implement the paper's SHHS/MESA pretraining stage.
+"""
+
+from __future__ import annotations
+
+import shutil
+import tempfile
+from pathlib import Path
+from typing import Any
+
+import numpy as np
+import pandas as pd
+import torch
+
+from pyhealth.datasets import DREAMTDataset, get_dataloader
+from pyhealth.models import WatchSleepNet
+from pyhealth.tasks import SleepStagingDREAMT, SleepStagingDREAMTSeq
+
+NUM_CLASSES = 3
+REM_CLASS_INDEX = 2
+IGNORE_INDEX = -100
+
+
+def build_synthetic_dreamt_root(root: Path, num_subjects: int = 4) -> Path:
+ """Create a tiny DREAMT-style directory with wearable CSV files."""
+ dreamt_root = root / "dreamt"
+ (dreamt_root / "data_64Hz").mkdir(parents=True)
+
+ participant_rows = []
+ for subject_index in range(num_subjects):
+ patient_id = f"S{subject_index + 1:03d}"
+ participant_rows.append(
+ {
+ "SID": patient_id,
+ "AGE": 25 + subject_index,
+ "GENDER": "F" if subject_index % 2 == 0 else "M",
+ "BMI": 22.0 + subject_index,
+ "OAHI": 1.0,
+ "AHI": 2.0,
+ "Mean_SaO2": "97%",
+ "Arousal Index": 10.0,
+ "MEDICAL_HISTORY": "None",
+ "Sleep_Disorders": "None",
+ }
+ )
+
+ # Five epochs: one preparation epoch followed by Wake, NREM, REM, NREM.
+ labels = (
+ ["P"] * 30
+ + ["W"] * 30
+ + ["N2"] * 30
+ + ["R"] * 30
+ + ["N3"] * 30
+ )
+ timestamps = np.arange(len(labels), dtype=np.float32)
+ frame = pd.DataFrame(
+ {
+ "TIMESTAMP": timestamps,
+ "IBI": np.sin(timestamps * 0.05 * (subject_index + 1)) + 1.0,
+ "HR": 60.0 + 3.0 * np.cos(timestamps * 0.05),
+ "BVP": np.sin(timestamps * 0.03),
+ "EDA": np.linspace(0.01, 0.04, len(labels)),
+ "TEMP": np.full(len(labels), 33.0 + 0.1 * subject_index),
+ "ACC_X": np.zeros(len(labels)),
+ "ACC_Y": np.ones(len(labels)),
+ "ACC_Z": np.full(len(labels), 2.0),
+ "Sleep_Stage": labels,
+ }
+ )
+ frame.to_csv(
+ dreamt_root / "data_64Hz" / f"{patient_id}_whole_df.csv",
+ index=False,
+ )
+
+ pd.DataFrame(participant_rows).to_csv(
+ dreamt_root / "participant_info.csv",
+ index=False,
+ )
+ return dreamt_root
+
+
+def _to_numpy(value: Any) -> np.ndarray:
+ if isinstance(value, torch.Tensor):
+ return value.detach().cpu().numpy()
+ return np.asarray(value)
+
+
+def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> float:
+ f1_scores = []
+ for class_index in range(num_classes):
+ true_pos = np.sum((y_true == class_index) & (y_pred == class_index))
+ false_pos = np.sum((y_true != class_index) & (y_pred == class_index))
+ false_neg = np.sum((y_true == class_index) & (y_pred != class_index))
+ precision = true_pos / max(true_pos + false_pos, 1)
+ recall = true_pos / max(true_pos + false_neg, 1)
+ if precision + recall == 0:
+ f1_scores.append(0.0)
+ else:
+ f1_scores.append(2 * precision * recall / (precision + recall))
+ return float(np.mean(f1_scores))
+
+
+def _class_f1(y_true: np.ndarray, y_pred: np.ndarray, positive_class: int) -> float:
+ true_pos = np.sum((y_true == positive_class) & (y_pred == positive_class))
+ false_pos = np.sum((y_true != positive_class) & (y_pred == positive_class))
+ false_neg = np.sum((y_true == positive_class) & (y_pred != positive_class))
+ precision = true_pos / max(true_pos + false_pos, 1)
+ recall = true_pos / max(true_pos + false_neg, 1)
+ if precision + recall == 0:
+ return 0.0
+ return float(2 * precision * recall / (precision + recall))
+
+
+def _cohen_kappa(y_true: np.ndarray, y_pred: np.ndarray, num_classes: int) -> float:
+ confusion = np.zeros((num_classes, num_classes), dtype=np.float64)
+ for true_label, pred_label in zip(y_true, y_pred):
+ confusion[int(true_label), int(pred_label)] += 1
+
+ total = confusion.sum()
+ if total == 0:
+ return 0.0
+ observed = np.trace(confusion) / total
+ expected = np.sum(confusion.sum(axis=0) * confusion.sum(axis=1)) / (total * total)
+ if expected >= 1.0:
+ return 0.0
+ return float((observed - expected) / (1.0 - expected))
+
+
+def _binary_auc(y_true: np.ndarray, y_score: np.ndarray) -> float:
+ positive_mask = y_true == 1
+ negative_mask = y_true == 0
+ num_pos = int(positive_mask.sum())
+ num_neg = int(negative_mask.sum())
+ if num_pos == 0 or num_neg == 0:
+ return float("nan")
+
+ order = np.argsort(y_score)
+ ranks = np.empty_like(order, dtype=np.float64)
+ ranks[order] = np.arange(1, len(y_score) + 1, dtype=np.float64)
+ positive_ranks = ranks[positive_mask]
+ auc = (positive_ranks.sum() - num_pos * (num_pos + 1) / 2.0) / (num_pos * num_neg)
+ return float(auc)
+
+
+def _multiclass_auroc(
+ y_true: np.ndarray,
+ y_prob: np.ndarray,
+ num_classes: int,
+) -> float:
+ aucs = []
+ for class_index in range(num_classes):
+ one_vs_rest = (y_true == class_index).astype(np.int64)
+ auc = _binary_auc(one_vs_rest, y_prob[:, class_index])
+ if not np.isnan(auc):
+ aucs.append(auc)
+ if not aucs:
+ return 0.0
+ return float(np.mean(aucs))
+
+
+def compute_metrics(
+ y_true: np.ndarray,
+ y_prob: np.ndarray,
+ num_classes: int = NUM_CLASSES,
+) -> dict[str, float]:
+ y_pred = y_prob.argmax(axis=-1)
+ accuracy = float((y_true == y_pred).mean()) if y_true.size else 0.0
+ macro_f1 = _macro_f1(y_true, y_pred, num_classes)
+ rem_f1 = _class_f1(y_true, y_pred, REM_CLASS_INDEX)
+ kappa = _cohen_kappa(y_true, y_pred, num_classes)
+ auroc = _multiclass_auroc(y_true, y_prob, num_classes)
+ return {
+ "accuracy": accuracy,
+ "macro_f1": macro_f1,
+ "rem_f1": rem_f1,
+ "cohen_kappa": kappa,
+ "auroc": auroc,
+ }
+
+
+def evaluate_model(model, loader, sequence_output: bool = False) -> dict[str, float]:
+ y_true_parts = []
+ y_prob_parts = []
+ model.eval()
+ with torch.no_grad():
+ for batch in loader:
+ output = model(**batch)
+ y_prob = _to_numpy(output["y_prob"])
+ y_true = _to_numpy(output["y_true"])
+ if sequence_output:
+ valid_mask = y_true != IGNORE_INDEX
+ if not np.any(valid_mask):
+ continue
+ y_true_parts.append(y_true[valid_mask])
+ y_prob_parts.append(y_prob[valid_mask])
+ else:
+ y_true_parts.append(y_true.reshape(-1))
+ y_prob_parts.append(y_prob.reshape(-1, y_prob.shape[-1]))
+
+ if not y_true_parts:
+ return {
+ "accuracy": 0.0,
+ "macro_f1": 0.0,
+ "rem_f1": 0.0,
+ "cohen_kappa": 0.0,
+ "auroc": 0.0,
+ }
+
+ y_true_all = np.concatenate(y_true_parts, axis=0)
+ y_prob_all = np.concatenate(y_prob_parts, axis=0)
+ return compute_metrics(y_true_all, y_prob_all)
+
+
+def train_one_epoch(model, loader, optimizer) -> float:
+ model.train()
+ running_loss = 0.0
+ num_batches = 0
+ for batch in loader:
+ optimizer.zero_grad()
+ output = model(**batch)
+ output["loss"].backward()
+ optimizer.step()
+ running_loss += float(output["loss"].item())
+ num_batches += 1
+ return running_loss / max(num_batches, 1)
+
+
+def run_window_ablation(sample_dataset) -> None:
+ """Run a simple window-classification ablation."""
+ configs = [
+ {"name": "baseline", "hidden_dim": 32, "use_tcn": True, "use_attention": True},
+ {
+ "name": "no_attention",
+ "hidden_dim": 32,
+ "use_tcn": True,
+ "use_attention": False,
+ },
+ {"name": "no_tcn", "hidden_dim": 32, "use_tcn": False, "use_attention": True},
+ {
+ "name": "small_hidden",
+ "hidden_dim": 16,
+ "use_tcn": True,
+ "use_attention": True,
+ },
+ ]
+ loader = get_dataloader(sample_dataset, batch_size=4, shuffle=True)
+
+ print("Window classification ablation")
+ for config in configs:
+ model = WatchSleepNet(
+ dataset=sample_dataset,
+ hidden_dim=config["hidden_dim"],
+ conv_channels=config["hidden_dim"],
+ num_attention_heads=4,
+ use_tcn=config["use_tcn"],
+ use_attention=config["use_attention"],
+ )
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ mean_loss = train_one_epoch(model, loader, optimizer)
+ metrics = evaluate_model(model, loader, sequence_output=False)
+ print(
+ f"{config['name']:>12} | loss={mean_loss:.4f} "
+ f"| acc={metrics['accuracy']:.3f} "
+ f"| macro_f1={metrics['macro_f1']:.3f}"
+ )
+
+
+def run_sequence_ablation(sample_dataset, feature_variant: str) -> None:
+ """Run a sequence-style ablation closer to the paper."""
+ configs = [
+ {
+ "name": "baseline_seq",
+ "hidden_dim": 32,
+ "use_tcn": True,
+ "use_attention": True,
+ },
+ {
+ "name": "no_attn_seq",
+ "hidden_dim": 32,
+ "use_tcn": True,
+ "use_attention": False,
+ },
+ {
+ "name": "no_tcn_seq",
+ "hidden_dim": 32,
+ "use_tcn": False,
+ "use_attention": True,
+ },
+ {
+ "name": "small_seq",
+ "hidden_dim": 16,
+ "use_tcn": True,
+ "use_attention": True,
+ },
+ ]
+ loader = get_dataloader(sample_dataset, batch_size=2, shuffle=True)
+
+ print(f"\nSequence-style ablation ({feature_variant})")
+ for config in configs:
+ model = WatchSleepNet(
+ dataset=sample_dataset,
+ hidden_dim=config["hidden_dim"],
+ conv_channels=config["hidden_dim"],
+ num_attention_heads=4,
+ use_tcn=config["use_tcn"],
+ use_attention=config["use_attention"],
+ sequence_output=True,
+ num_classes=NUM_CLASSES,
+ ignore_index=IGNORE_INDEX,
+ )
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ mean_loss = train_one_epoch(model, loader, optimizer)
+ metrics = evaluate_model(model, loader, sequence_output=True)
+ print(
+ f"{config['name']:>12} | loss={mean_loss:.4f} "
+ f"| acc={metrics['accuracy']:.3f} "
+ f"| macro_f1={metrics['macro_f1']:.3f} "
+ f"| rem_f1={metrics['rem_f1']:.3f} "
+ f"| kappa={metrics['cohen_kappa']:.3f} "
+ f"| auroc={metrics['auroc']:.3f}"
+ )
+
+
+def main() -> None:
+ temp_dir = Path(tempfile.mkdtemp(prefix="pyhealth_dreamt_example_"))
+ try:
+ dreamt_root = build_synthetic_dreamt_root(temp_dir)
+ dataset = DREAMTDataset(root=str(dreamt_root), cache_dir=temp_dir / "cache")
+
+ window_task = SleepStagingDREAMT(
+ window_size=30,
+ stride=30,
+ source_preference="wearable",
+ )
+ window_dataset = dataset.set_task(task=window_task, num_workers=1)
+ print(f"Generated {len(window_dataset)} synthetic window samples")
+ run_window_ablation(window_dataset)
+
+ seq_task = SleepStagingDREAMTSeq(
+ feature_columns=("IBI",),
+ epoch_seconds=30.0,
+ sequence_length=4,
+ source_preference="wearable",
+ ignore_index=IGNORE_INDEX,
+ )
+ seq_dataset = dataset.set_task(task=seq_task, num_workers=1)
+ print(f"\nGenerated {len(seq_dataset)} IBI-only sequence samples")
+ run_sequence_ablation(seq_dataset, feature_variant="IBI only")
+
+ rich_seq_task = SleepStagingDREAMTSeq(
+ feature_columns=SleepStagingDREAMT.DEFAULT_FEATURE_COLUMNS,
+ epoch_seconds=30.0,
+ sequence_length=4,
+ source_preference="wearable",
+ ignore_index=IGNORE_INDEX,
+ )
+ rich_seq_dataset = dataset.set_task(task=rich_seq_task, num_workers=1)
+ print(
+ f"Generated {len(rich_seq_dataset)} multi-signal sequence samples"
+ )
+ run_sequence_ablation(rich_seq_dataset, feature_variant="IBI + context")
+ finally:
+ shutil.rmtree(temp_dir, ignore_errors=True)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/pyhealth/datasets/configs/dreamt.yaml b/pyhealth/datasets/configs/dreamt.yaml
index ad6430e45..472c55a58 100644
--- a/pyhealth/datasets/configs/dreamt.yaml
+++ b/pyhealth/datasets/configs/dreamt.yaml
@@ -5,6 +5,7 @@ tables:
patient_id: "patient_id"
timestamp: null
attributes:
+ - "record_id"
- "age"
- "gender"
- "bmi"
@@ -13,5 +14,10 @@ tables:
- "mean_sao2"
- "arousal_index"
- "medical_history"
+ - "sleep_disorders"
+ - "signal_file"
+ - "signal_source"
+ - "sampling_rate_hz"
+ - "signal_format"
- "file_64hz"
- - "file_100hz"
\ No newline at end of file
+ - "file_100hz"
diff --git a/pyhealth/datasets/dreamt.py b/pyhealth/datasets/dreamt.py
index a7c43d23c..c39398597 100644
--- a/pyhealth/datasets/dreamt.py
+++ b/pyhealth/datasets/dreamt.py
@@ -1,179 +1,400 @@
import logging
-import os
+import re
from pathlib import Path
-from typing import Optional, Union
+from typing import Optional
import pandas as pd
-from pyhealth.datasets import BaseDataset
-logger = logging.getLogger(__name__)
-
-class DREAMTDataset(BaseDataset):
- """
- Base Dataset for Real-time sleep stage EstimAtion using Multisensor wearable Technology (DREAMT)
+from .base_dataset import BaseDataset
+from ..tasks.sleep_staging import SleepStagingDREAMT
- Dataset accepts current versions of DREAMT (1.0.0, 1.0.1, 2.0.0, 2.1.0), available at:
- https://physionet.org/content/dreamt/
-
- DREAMT includes wrist-based wearable and polysomnography (PSG) sleep data from 100 participants
- recruited from the Duke University Health System (DUHS) Sleep Disorder Lab. This includes
- wearable signals, PSG signals, sleep labels, and clinical data related to sleep health and disorders.
+logger = logging.getLogger(__name__)
- Citations:
- ---------
- When using this dataset, please cite:
- Wang, K., Yang, J., Shetty, A., & Dunn, J. (2025). DREAMT: Dataset for Real-time sleep stage EstimAtion
- using Multisensor wearable Technology (version 2.1.0). PhysioNet. RRID:SCR_007345.
- https://doi.org/10.13026/7r9r-7r24
+class DREAMTDataset(BaseDataset):
+ """Base dataset wrapper for DREAMT sleep-staging data.
- Will Ke Wang, Jiamu Yang, Leeor Hershkovich, Hayoung Jeong, Bill Chen, Karnika Singh, Ali R Roghanizad,
- Md Mobashir Hasan Shandhi, Andrew R Spector, Jessilyn Dunn. (2024). Proceedings of the fifth
- Conference on Health, Inference, and Learning, PMLR 248:380-396.
+ This dataset is designed for reproducibility work around DREAMT-style sleep
+ staging. It supports two common local layouts:
- Goldberger, A., Amaral, L., Glass, L., Hausdorff, J., Ivanov, P. C., Mark, R., ... & Stanley, H. E. (2000).
- PhysioBank, PhysioToolkit, and PhysioNet: Components of a new research resource for complex
- physiologic signals. Circulation [Online]. 101 (23), pp. e215–e220. RRID:SCR_007345.
+ 1. The official PhysioNet DREAMT release rooted at a version directory
+ containing ``participant_info.csv`` and one or both of ``data_64Hz`` and
+ ``data_100Hz``.
+ 2. A processed per-subject directory containing files named with a subject
+ identifier such as ``S002_record.csv`` or ``S002_features.npz``.
- Note:
- ---------
- Dataset follows file and folder structure of dataset version, looks for participant_info.csv and data folders,
- so root path should be version downloaded, example: root = ".../dreamt/1.0.0/" or ".../dreamt/2.0.0/"
+ For the official release, the dataset builds a metadata table with one
+ event per subject. The downstream task then reads the referenced signal file
+ and produces fixed-size windows for sleep-stage classification.
Args:
- root: root directory containing the dataset files
- dataset_name: optional name of dataset, defaults to "dreamt_sleep"
- config_path: optional configuration file, defaults to "dreamt.yaml"
-
- Attributes:
- root: root directory containing the dataset files
- dataset_name: name of dataset
- config_path: path to configuration file
+ root: Root directory of DREAMT data. This can be the DREAMT version
+ directory itself or a parent directory that contains exactly one
+ DREAMT release.
+ dataset_name: Optional dataset name. Defaults to ``"dreamt_sleep"``.
+ config_path: Optional config path. Defaults to the bundled
+ ``dreamt.yaml`` config.
+ preferred_source: Preferred signal source for the generic
+ ``signal_file`` column. One of ``"auto"``, ``"wearable"``, or
+ ``"psg"``.
+ cache_dir: Optional PyHealth cache directory.
+ num_workers: Number of workers for PyHealth dataset operations.
+ dev: Whether to enable PyHealth dev mode.
Examples:
>>> from pyhealth.datasets import DREAMTDataset
- >>> dataset = DREAMTDataset(root = "/path/to/dreamt/data/version")
- >>> dataset.stats()
- >>>
- >>> # Get all patient ids
- >>> unique_patients = dataset.unique_patient_ids
- >>> print(f"There are {len(unique_patients)} patients")
- >>>
- >>> # Get single patient data
+ >>> dataset = DREAMTDataset(
+ ... root="/path/to/dreamt/2.1.0",
+ ... preferred_source="wearable",
+ ... )
>>> patient = dataset.get_patient("S002")
- >>> print(f"Patient has {len(patient.data_source)} event")
- >>>
- >>> # Get event
- >>> event = patient.get_events(event_type="dreamt_sleep")
- >>>
- >>> # Get Apnea-Hypopnea Index (AHI)
- >>> ahi = event[0].ahi
- >>> print(f"AHI is {ahi}")
- >>>
- >>> # Get 64Hz sleep file path
- >>> file_path = event[0].file_64hz
- >>> print(f"64Hz sleep file path: {file_path}")
+ >>> event = patient.get_events("dreamt_sleep")[0]
+ >>> event.signal_file
+ '/path/to/dreamt/2.1.0/data_64Hz/S002_whole_df.csv'
"""
+ SUPPORTED_SUFFIXES = {".csv", ".parquet", ".pkl", ".pickle", ".npz", ".npy"}
+ _SUBJECT_PATTERN = re.compile(r"(S\d{3})", re.IGNORECASE)
+
def __init__(
- self,
- root: str,
- dataset_name: Optional[str] = None,
- config_path: Optional[str] = None,
+ self,
+ root: str,
+ dataset_name: Optional[str] = None,
+ config_path: Optional[str] = None,
+ preferred_source: str = "auto",
+ cache_dir: str | Path | None = None,
+ num_workers: int = 1,
+ dev: bool = False,
) -> None:
+ preferred_source = preferred_source.lower()
+ if preferred_source not in {"auto", "wearable", "psg"}:
+ raise ValueError(
+ "preferred_source must be one of 'auto', 'wearable', or 'psg'."
+ )
+
if config_path is None:
- logger.info("No config provided, using default config")
config_path = Path(__file__).parent / "configs" / "dreamt.yaml"
-
- metadata_file = Path(root) / "dreamt-metadata.csv"
- if not os.path.exists(metadata_file):
- logger.info(f"{metadata_file} does not exist")
- self.prepare_metadata(root)
-
- default_tables = ["dreamt_sleep"]
+ resolved_root = self._resolve_root(Path(root).expanduser().resolve())
+ self.prepare_metadata(resolved_root, preferred_source=preferred_source)
+
+ self.preferred_source = preferred_source
super().__init__(
- root=root,
- tables=default_tables,
+ root=str(resolved_root),
+ tables=["dreamt_sleep"],
dataset_name=dataset_name or "dreamt_sleep",
- config_path=config_path
+ config_path=str(config_path),
+ cache_dir=cache_dir,
+ num_workers=num_workers,
+ dev=dev,
)
-
- def get_patient_file(self, patient_id: str, root: str, file_path: str) -> Union[str | None]:
- """
- Returns file path of 64Hz and 100Hz data for a patient, or None if no file found
-
- Args:
- patient_id: patient identifier
- root: root directory containing the dataset files
- file_path: path to location of 64Hz or 100Hz file
-
- Returns:
- file: path to file location or None if no file found
- """
-
- if file_path == "data_64Hz" or file_path == "data":
- file = Path(root) / f"{file_path}" / f"{patient_id}_whole_df.csv"
-
- if file_path == "data_100Hz":
- file = Path(root) / f"{file_path}" / f"{patient_id}_PSG_df.csv"
- if not os.path.exists(str(file)):
- logger.info(f"{file} not found")
- file = None
-
- return file
+ @staticmethod
+ def _extract_subject_id(path: Path) -> Optional[str]:
+ match = DREAMTDataset._SUBJECT_PATTERN.search(path.name)
+ if match is None:
+ return None
+ return match.group(1).upper()
+ @classmethod
+ def _is_signal_file(cls, path: Path) -> bool:
+ if path.suffix.lower() not in cls.SUPPORTED_SUFFIXES:
+ return False
+ return cls._extract_subject_id(path) is not None
- def prepare_metadata(self, root: str) -> None:
- """
- Prepares metadata csv file for the DREAMT dataset by performing the following:
- 1. Obtain clinical data from participant_info.csv file
- 2. Process file paths based on patients found in clinical data
- 3. Organize all data into a single DataFrame
- 4. Save the processed DataFrame to a CSV file
-
- Args:
- root: root directory containing the dataset files
- """
+ @classmethod
+ def _resolve_root(cls, root: Path) -> Path:
+ """Resolve a user-provided root to a concrete DREAMT data directory."""
+ if not root.exists():
+ raise FileNotFoundError(f"DREAMT root does not exist: {root}")
+
+ if (root / "participant_info.csv").exists():
+ return root
+
+ if any(
+ cls._is_signal_file(path)
+ for path in root.iterdir()
+ if path.is_file()
+ ):
+ return root
+
+ candidates = sorted(
+ {
+ path.parent.resolve()
+ for path in root.rglob("participant_info.csv")
+ }
+ )
+ if len(candidates) == 1:
+ return candidates[0]
+ if len(candidates) > 1:
+ candidate_text = ", ".join(str(path) for path in candidates[:5])
+ raise ValueError(
+ "Found multiple DREAMT roots under the provided path. Use an "
+ f"explicit version directory instead. Candidates: {candidate_text}"
+ )
+
+ signal_candidates = sorted(
+ {
+ path.parent.resolve()
+ for path in root.rglob("*")
+ if path.is_file() and cls._is_signal_file(path)
+ }
+ )
+ if len(signal_candidates) == 1:
+ return signal_candidates[0]
+ if len(signal_candidates) > 1:
+ counts = {
+ candidate: sum(
+ 1
+ for child in candidate.iterdir()
+ if child.is_file() and cls._is_signal_file(child)
+ )
+ for candidate in signal_candidates
+ }
+ best = max(counts.items(), key=lambda item: item[1])[0]
+ logger.info(
+ "Resolved DREAMT root to %s based on detected subject files.",
+ best,
+ )
+ return best
+
+ raise FileNotFoundError(
+ "Could not find a DREAMT data directory containing "
+ "participant_info.csv or processed subject files."
+ )
- output_path = Path(root) / "dreamt-metadata.csv"
+ @staticmethod
+ def _coerce_float(value: object) -> object:
+ if pd.isna(value):
+ return None
+ text = str(value).strip()
+ if not text:
+ return None
+ if text.endswith("%"):
+ text = text[:-1]
+ try:
+ return float(text)
+ except ValueError:
+ return value
+
+ @classmethod
+ def _locate_subject_file(cls, directory: Path, patient_id: str) -> Optional[Path]:
+ if not directory.exists():
+ return None
+
+ candidates = sorted(
+ path
+ for path in directory.iterdir()
+ if path.is_file()
+ and cls._extract_subject_id(path) == patient_id
+ and path.suffix.lower() in cls.SUPPORTED_SUFFIXES
+ )
+ if not candidates:
+ return None
- # Obtain patient clinical data
- participant_info_path = Path(root) / "participant_info.csv"
+ def score(path: Path) -> tuple[int, int, int]:
+ name = path.name.lower()
+ return (
+ int("updated" in name),
+ int("whole" in name or "psg" in name or "record" in name),
+ -len(name),
+ )
+
+ return max(candidates, key=score)
+
+ @staticmethod
+ def _select_signal_file(
+ file_64hz: Optional[Path],
+ file_100hz: Optional[Path],
+ preferred_source: str,
+ ) -> tuple[Optional[Path], Optional[str], Optional[float]]:
+ if preferred_source == "wearable":
+ if file_64hz is not None:
+ return file_64hz, "wearable", 64.0
+ if file_100hz is not None:
+ return file_100hz, "psg", 100.0
+ elif preferred_source == "psg":
+ if file_100hz is not None:
+ return file_100hz, "psg", 100.0
+ if file_64hz is not None:
+ return file_64hz, "wearable", 64.0
+ else:
+ if file_64hz is not None:
+ return file_64hz, "wearable", 64.0
+ if file_100hz is not None:
+ return file_100hz, "psg", 100.0
+ return None, None, None
+
+ @classmethod
+ def _build_metadata_from_participant_info(
+ cls,
+ root: Path,
+ preferred_source: str,
+ ) -> pd.DataFrame:
+ participant_info_path = root / "participant_info.csv"
participant_info = pd.read_csv(participant_info_path)
+ participant_info.columns = [
+ str(col).strip() for col in participant_info.columns
+ ]
+
+ rename_map = {
+ "SID": "patient_id",
+ "AGE": "age",
+ "Age": "age",
+ "GENDER": "gender",
+ "Gender": "gender",
+ "BMI": "bmi",
+ "OAHI": "oahi",
+ "AHI": "ahi",
+ "Mean_SaO2": "mean_sao2",
+ "Arousal Index": "arousal_index",
+ "MEDICAL_HISTORY": "medical_history",
+ "Sleep_Disorders": "sleep_disorders",
+ }
+ participant_info = participant_info.rename(columns=rename_map)
+ if "patient_id" not in participant_info.columns:
+ raise ValueError(
+ "participant_info.csv must contain a SID/patient_id column."
+ )
+
+ file_64_dir = root / "data_64Hz"
+ if not file_64_dir.exists():
+ legacy_dir = root / "data"
+ file_64_dir = legacy_dir if legacy_dir.exists() else file_64_dir
+ file_100_dir = root / "data_100Hz"
+
+ participant_info["patient_id"] = participant_info["patient_id"].astype(str)
+ participant_info["file_64hz"] = participant_info["patient_id"].apply(
+ lambda pid: cls._locate_subject_file(file_64_dir, pid)
+ )
+ participant_info["file_100hz"] = participant_info["patient_id"].apply(
+ lambda pid: cls._locate_subject_file(file_100_dir, pid)
+ )
+
+ resolved = participant_info.apply(
+ lambda row: cls._select_signal_file(
+ row["file_64hz"],
+ row["file_100hz"],
+ preferred_source=preferred_source,
+ ),
+ axis=1,
+ result_type="expand",
+ )
+ resolved.columns = ["signal_file", "signal_source", "sampling_rate_hz"]
+ participant_info = pd.concat([participant_info, resolved], axis=1)
+ for column in ["age", "bmi", "oahi", "ahi", "mean_sao2", "arousal_index"]:
+ if column in participant_info.columns:
+ participant_info[column] = participant_info[column].apply(
+ cls._coerce_float
+ )
- # Determine folder structure, assign associated file paths based on folder structure
- all_folders = [item.name for item in Path(root).iterdir() if item.is_dir()]
- file_path_64hz = "data_64Hz" if "data_64Hz" in all_folders else "data"
- file_path_100hz = "data_100Hz"
+ for column in ["file_64hz", "file_100hz", "signal_file"]:
+ participant_info[column] = participant_info[column].apply(
+ lambda value: str(value) if isinstance(value, Path) else value
+ )
- # Determine paths for 64Hz and 100Hz files for each patient
- participant_info['file_64hz'] = participant_info['SID'].apply(
- lambda sid: self.get_patient_file(sid, root, file_path_64hz)
+ participant_info["signal_format"] = participant_info["signal_file"].apply(
+ lambda value: Path(value).suffix.lower() if isinstance(value, str) else None
)
- participant_info['file_100hz'] = participant_info['SID'].apply(
- lambda sid: self.get_patient_file(sid, root, file_path_100hz)
+ participant_info["record_id"] = participant_info["patient_id"]
+
+ available_mask = participant_info["signal_file"].notna()
+ missing_count = int((~available_mask).sum())
+ if missing_count:
+ logger.info(
+ "Dropping %s DREAMT participants without local signal files.",
+ missing_count,
+ )
+ participant_info = participant_info.loc[available_mask].reset_index(drop=True)
+
+ return participant_info
+
+ @classmethod
+ def _build_metadata_from_processed_files(
+ cls,
+ root: Path,
+ preferred_source: str,
+ ) -> pd.DataFrame:
+ del preferred_source
+ files = sorted(
+ path
+ for path in root.rglob("*")
+ if path.is_file() and cls._is_signal_file(path)
)
+ if not files:
+ raise FileNotFoundError(
+ "No DREAMT subject files were found under the provided root."
+ )
+
+ rows = []
+ for path in files:
+ patient_id = cls._extract_subject_id(path)
+ if patient_id is None:
+ continue
+ rows.append(
+ {
+ "patient_id": patient_id,
+ "record_id": patient_id,
+ "signal_file": str(path.resolve()),
+ "signal_source": "processed",
+ "sampling_rate_hz": None,
+ "signal_format": path.suffix.lower(),
+ "file_64hz": None,
+ "file_100hz": None,
+ }
+ )
+
+ metadata = pd.DataFrame(rows)
+ metadata = metadata.drop_duplicates(subset=["patient_id"], keep="first")
+ return metadata.sort_values("patient_id").reset_index(drop=True)
+
+ @classmethod
+ def prepare_metadata(cls, root: str | Path, preferred_source: str = "auto") -> None:
+ """Create ``dreamt-metadata.csv`` for a local DREAMT directory.
+
+ The generated metadata always includes:
+
+ - ``patient_id``
+ - ``record_id``
+ - ``signal_file``
+ - ``signal_source``
+ - ``sampling_rate_hz``
+ - ``signal_format``
+ - ``file_64hz``
+ - ``file_100hz``
+
+ For the official DREAMT release, participant metadata columns are also
+ preserved when available.
+ """
+ root = Path(root).expanduser().resolve()
+ output_path = root / "dreamt-metadata.csv"
+
+ if (root / "participant_info.csv").exists():
+ metadata = cls._build_metadata_from_participant_info(
+ root,
+ preferred_source=preferred_source,
+ )
+ else:
+ metadata = cls._build_metadata_from_processed_files(
+ root,
+ preferred_source=preferred_source,
+ )
+
+ if metadata.empty:
+ raise ValueError("No DREAMT subjects were found while preparing metadata.")
+
+ if metadata["signal_file"].notna().sum() == 0:
+ raise ValueError(
+ "Metadata was created, but no usable signal files were detected. "
+ "Expected DREAMT subject files such as S002_whole_df.csv, "
+ "S002_PSG_df.csv, or processed subject files like S002_record.npz."
+ )
+
+ metadata.to_csv(output_path, index=False)
- # Remove "%" from mean SaO2 recording
- participant_info['Mean_SaO2'] = participant_info['Mean_SaO2'].str[:-1]
-
- # Format columns to align with BaseDataset
- participant_info = participant_info.rename(columns = {
- 'SID': 'patient_id',
- 'AGE': 'age',
- 'GENDER': 'gender',
- 'BMI': 'bmi',
- 'OAHI': 'oahi',
- 'AHI': 'ahi',
- 'Mean_SaO2': 'mean_sao2',
- 'Arousal Index': 'arousal_index',
- "MEDICAL_HISTORY": 'medical_history',
- "Sleep_Disorders": 'sleep_disorders'
- })
-
- # Create csv
- participant_info.to_csv(output_path, index=False)
+ @property
+ def default_task(self) -> SleepStagingDREAMT:
+ """Return the default DREAMT sleep-staging task."""
+ return SleepStagingDREAMT()
diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py
index 945822910..e7745b489 100644
--- a/pyhealth/models/__init__.py
+++ b/pyhealth/models/__init__.py
@@ -40,6 +40,7 @@
from .ehrmamba import EHRMamba, MambaBlock
from .vae import VAE
from .vision_embedding import VisionEmbeddingModel
+from .watchsleepnet import WatchSleepNet
from .text_embedding import TextEmbedding
from .sdoh import SdohClassifier
from .medlink import MedLink
diff --git a/pyhealth/models/watchsleepnet.py b/pyhealth/models/watchsleepnet.py
new file mode 100644
index 000000000..43f4d3694
--- /dev/null
+++ b/pyhealth/models/watchsleepnet.py
@@ -0,0 +1,426 @@
+import math
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..datasets.sample_dataset import SampleDataset
+from .base_model import BaseModel
+
+
class ResidualConvBlock(nn.Module):
    """Residual 1D convolution block for local feature extraction.

    Applies two ``Conv1d -> BatchNorm1d`` stages (with ReLU and dropout)
    and adds a skip connection. When the channel count changes, a 1x1
    convolution projects the input so the residual addition is
    shape-compatible.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size. Must be odd so that
            ``kernel_size // 2`` padding preserves the sequence length,
            which the residual addition requires.
        dropout: Dropout probability applied after each conv stage.

    Raises:
        ValueError: If ``kernel_size`` is even.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        if kernel_size % 2 == 0:
            # An even kernel with kernel_size // 2 padding changes the output
            # length, which would break the residual addition in forward().
            raise ValueError(
                "kernel_size must be odd so 'same' padding preserves the "
                "sequence length for the residual connection."
            )
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=padding,
        )
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.ReLU()

        # 1x1 projection so the skip connection matches out_channels.
        self.shortcut = None
        if in_channels != out_channels:
            self.shortcut = nn.Conv1d(in_channels, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block on ``x`` of shape ``[batch, channels, seq_len]``.

        Returns:
            Tensor of shape ``[batch, out_channels, seq_len]``.
        """
        residual = x if self.shortcut is None else self.shortcut(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.dropout1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.dropout2(x)
        x = x + residual
        x = self.relu2(x)
        return x
+
+
class DilatedTCNBlock(nn.Module):
    """Residual temporal block built from two dilated convolutions.

    Both convolutions keep the channel count fixed, and the padding grows
    with the dilation so the sequence length is preserved; the skip
    connection can therefore be added without any projection.

    Args:
        channels: Number of input and output channels.
        dilation: Dilation factor used by both convolutions.
        kernel_size: Convolution kernel size.
        dropout: Dropout probability applied after each conv stage.
    """

    def __init__(
        self,
        channels: int,
        dilation: int,
        kernel_size: int = 3,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # "Same" padding for a dilated kernel keeps the temporal length.
        padding = dilation * (kernel_size - 1) // 2
        self.conv1 = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
        )
        self.bn1 = nn.BatchNorm1d(channels)
        self.relu1 = nn.ReLU()
        self.dropout1 = nn.Dropout(dropout)

        self.conv2 = nn.Conv1d(
            channels,
            channels,
            kernel_size=kernel_size,
            padding=padding,
            dilation=dilation,
        )
        self.bn2 = nn.BatchNorm1d(channels)
        self.dropout2 = nn.Dropout(dropout)
        self.relu2 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the two dilated stages and add the skip connection."""
        skip = x
        out = self.dropout1(self.relu1(self.bn1(self.conv1(x))))
        out = self.dropout2(self.bn2(self.conv2(out)))
        return self.relu2(out + skip)
+
+
class TemporalAttentionPool(nn.Module):
    """Attention pooling over temporal features.

    Scores each timestep with a learned linear projection, normalizes the
    scores with a softmax over time, and returns the weighted sum.

    Args:
        input_dim: Feature dimension of the incoming sequence.
    """

    def __init__(self, input_dim: int) -> None:
        super().__init__()
        # One scalar relevance score per timestep.
        self.score = nn.Linear(input_dim, 1)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pool ``x`` over the time axis.

        Args:
            x: Sequence tensor of shape ``[batch, seq_len, input_dim]``.
            mask: Optional boolean mask of shape ``[batch, seq_len]`` where
                True marks valid timesteps.

        Returns:
            Tuple of pooled features ``[batch, input_dim]`` and attention
            weights ``[batch, seq_len]``. Rows whose mask is entirely False
            receive all-zero weights (and therefore a zero pooled vector).
        """
        scores = self.score(x).squeeze(-1)
        if mask is not None:
            scores = scores.masked_fill(~mask, float("-inf"))
        weights = torch.softmax(scores, dim=1)
        if mask is not None:
            # A fully masked row makes every score -inf, so the softmax
            # yields NaN; map those rows to zero weights instead of letting
            # NaN propagate into the pooled output.
            weights = torch.nan_to_num(weights, nan=0.0)
        pooled = torch.sum(weights.unsqueeze(-1) * x, dim=1)
        return pooled, weights
+
+
class WatchSleepNet(BaseModel):
    """Simplified WatchSleepNet-style classifier for sleep staging.

    The architecture is intentionally compact but preserves the main modeling
    ideas from WatchSleepNet:

    - residual 1D convolution for local morphology
    - optional dilated temporal blocks
    - bidirectional LSTM for sequence modeling
    - optional multi-head self-attention
    - temporal pooling and a classification head

    Args:
        dataset: A task-specific ``SampleDataset`` with one tensor feature and
            one multiclass label.
        feature_key: Feature field to read from the dataset. Defaults to the
            first feature key.
        input_dim: Optional feature dimension. If omitted, inferred from the
            first sample.
        hidden_dim: Hidden size of the bidirectional LSTM.
        conv_channels: Channel width of the residual convolution stack.
        conv_blocks: Number of residual convolution blocks.
        tcn_blocks: Number of dilated TCN blocks.
        lstm_layers: Number of bidirectional LSTM layers.
        num_attention_heads: Number of attention heads when attention is used.
        dropout: Dropout rate applied throughout the network.
        num_classes: Optional override for output classes. Defaults to the task
            label processor size.
        use_tcn: Whether to use the dilated TCN stack.
        use_attention: Whether to use multi-head self-attention and attention
            pooling. When disabled, the model uses masked mean pooling instead.
        sequence_output: Whether to emit logits for each timestep instead of a
            single pooled prediction.
        ignore_index: Ignore label used for padded sequence targets.
    """

    def __init__(
        self,
        dataset: SampleDataset,
        feature_key: Optional[str] = None,
        input_dim: Optional[int] = None,
        hidden_dim: int = 64,
        conv_channels: int = 64,
        conv_blocks: int = 2,
        tcn_blocks: int = 2,
        lstm_layers: int = 1,
        num_attention_heads: int = 4,
        dropout: float = 0.1,
        num_classes: Optional[int] = None,
        use_tcn: bool = True,
        use_attention: bool = True,
        sequence_output: bool = False,
        ignore_index: int = -100,
    ) -> None:
        super().__init__(dataset=dataset)
        # The model predicts exactly one label field.
        if len(self.label_keys) != 1:
            raise ValueError("WatchSleepNet supports a single label field.")

        self.label_key = self.label_keys[0]
        # Default to the first feature declared by the task schema.
        self.feature_key = feature_key or self.feature_keys[0]
        self.hidden_dim = hidden_dim
        self.conv_channels = conv_channels
        self.conv_blocks = conv_blocks
        self.tcn_blocks = tcn_blocks
        self.lstm_layers = lstm_layers
        self.num_attention_heads = num_attention_heads
        self.dropout = dropout
        self.use_tcn = use_tcn
        self.use_attention = use_attention
        self.sequence_output = sequence_output
        self.ignore_index = ignore_index

        self.input_dim = input_dim or self._infer_input_dim(self.feature_key)
        if self.input_dim <= 0:
            raise ValueError("input_dim must be positive.")

        # 1x1 conv maps raw feature channels into the conv stack width.
        self.input_projection = nn.Conv1d(self.input_dim, conv_channels, kernel_size=1)
        self.residual_conv = nn.ModuleList(
            [
                ResidualConvBlock(
                    conv_channels,
                    conv_channels,
                    kernel_size=3,
                    dropout=dropout,
                )
                for _ in range(conv_blocks)
            ]
        )

        # Dilations grow exponentially (1, 2, 4, ...) so a few blocks cover
        # a wide temporal receptive field.
        self.tcn = nn.ModuleList(
            [
                DilatedTCNBlock(
                    conv_channels,
                    dilation=2**block_index,
                    kernel_size=3,
                    dropout=dropout,
                )
                for block_index in range(tcn_blocks)
            ]
        )

        self.bilstm = nn.LSTM(
            input_size=conv_channels,
            hidden_size=hidden_dim,
            num_layers=lstm_layers,
            batch_first=True,
            # nn.LSTM only applies inter-layer dropout for num_layers > 1.
            dropout=dropout if lstm_layers > 1 else 0.0,
        )

        # Bidirectional LSTM doubles the feature width.
        context_dim = hidden_dim * 2
        self.attention_projection = None
        if use_attention:
            # MultiheadAttention requires embed_dim % num_heads == 0; round
            # up to the next multiple and project the LSTM context into it.
            projected_dim = context_dim
            if projected_dim % num_attention_heads != 0:
                projected_dim = (
                    math.ceil(projected_dim / num_attention_heads)
                    * num_attention_heads
                )
            self.attention_projection = nn.Linear(context_dim, projected_dim)
            self.self_attention = nn.MultiheadAttention(
                embed_dim=projected_dim,
                num_heads=num_attention_heads,
                dropout=dropout,
                batch_first=True,
            )
            self.attention_norm = nn.LayerNorm(projected_dim)
            self.attention_pool = TemporalAttentionPool(projected_dim)
            classifier_input_dim = projected_dim
        else:
            self.self_attention = None
            self.attention_norm = None
            self.attention_pool = None
            classifier_input_dim = context_dim

        self.final_dropout = nn.Dropout(dropout)
        # In sequence mode the label processor size is not available via
        # get_output_size(), so the caller must supply num_classes.
        if self.sequence_output and num_classes is None:
            raise ValueError(
                "num_classes must be provided when sequence_output=True."
            )
        self.classifier = nn.Linear(
            classifier_input_dim,
            num_classes if num_classes is not None else self.get_output_size(),
        )

    def _infer_input_dim(self, feature_key: str) -> int:
        """Infer the per-timestep feature width from the first usable sample.

        A 1-D sample is treated as a single channel; otherwise the trailing
        axis is taken as the feature dimension.
        """
        for sample in self.dataset:
            if feature_key not in sample:
                continue
            value = sample[feature_key]
            tensor = (
                value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
            )
            if tensor.dim() == 1:
                return 1
            return int(tensor.shape[-1])
        raise ValueError(
            f"Unable to infer input_dim for feature '{feature_key}' from the dataset."
        )

    def _coerce_input(self, value: Any) -> torch.Tensor:
        """Convert raw input to a float32 ``[batch, seq_len, input_dim]`` tensor.

        Accepts 2-D input (assumed to be a single-channel sequence) or 3-D
        input in either channel-last or channel-first layout.
        """
        tensor = value if isinstance(value, torch.Tensor) else torch.as_tensor(value)
        tensor = tensor.to(self.device, dtype=torch.float32)

        if tensor.dim() == 2:
            tensor = tensor.unsqueeze(-1)
        if tensor.dim() != 3:
            raise ValueError(
                "Expected input tensor with 2 or 3 dims, received shape "
                f"{tuple(tensor.shape)}."
            )

        # Channel-last is checked first, so when both axes equal input_dim
        # the input is assumed to already be [batch, seq_len, input_dim].
        if tensor.shape[-1] == self.input_dim:
            return tensor
        if tensor.shape[1] == self.input_dim:
            return tensor.transpose(1, 2)
        raise ValueError(
            "Unable to infer whether input is channel-last or channel-first. "
            f"Expected one axis to match input_dim={self.input_dim}, got "
            f"shape {tuple(tensor.shape)}."
        )

    @staticmethod
    def _default_mask(x: torch.Tensor) -> torch.Tensor:
        """Return an all-True ``[batch, seq_len]`` mask for unmasked input."""
        return torch.ones(x.shape[0], x.shape[1], dtype=torch.bool, device=x.device)

    def _build_mask(
        self,
        x: torch.Tensor,
        explicit_mask: Optional[Any] = None,
    ) -> torch.Tensor:
        """Build a boolean validity mask, defaulting to all timesteps valid."""
        if explicit_mask is None:
            return self._default_mask(x)
        mask = (
            explicit_mask
            if isinstance(explicit_mask, torch.Tensor)
            else torch.as_tensor(explicit_mask)
        )
        mask = mask.to(self.device)
        if mask.dim() != 2:
            raise ValueError(
                f"Expected mask with shape [batch, seq_len], got {tuple(mask.shape)}."
            )
        # Accept float/int masks (e.g. 0.0/1.0 padding masks) as well.
        return mask > 0

    @staticmethod
    def _masked_mean(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean over the time axis, counting only mask-valid timesteps."""
        weights = mask.unsqueeze(-1).to(dtype=x.dtype)
        # Clamp avoids division by zero for fully padded rows.
        denom = torch.clamp(weights.sum(dim=1), min=1.0)
        return (x * weights).sum(dim=1) / denom

    def _encode_sequence(
        self,
        x: torch.Tensor,
        mask: torch.Tensor,
    ) -> torch.Tensor:
        """Encode input into contextual features ``[batch, seq_len, dim]``.

        Pipeline: 1x1 projection -> residual conv blocks -> (optional) TCN
        -> BiLSTM -> (optional) projected self-attention with residual +
        LayerNorm.
        """
        # x: [batch, seq_len, input_dim] -> [batch, input_dim, seq_len]
        x = x.transpose(1, 2)
        x = self.input_projection(x)
        for block in self.residual_conv:
            x = block(x)
        if self.use_tcn:
            for block in self.tcn:
                x = block(x)

        # x: [batch, conv_channels, seq_len] -> [batch, seq_len, conv_channels]
        x = x.transpose(1, 2)
        x, _ = self.bilstm(x)

        if self.use_attention and self.self_attention is not None:
            residual = x
            if self.attention_projection is not None:
                residual = self.attention_projection(residual)
            # key_padding_mask expects True at PADDED positions, hence ~mask.
            attn_out, _ = self.self_attention(
                residual,
                residual,
                residual,
                key_padding_mask=~mask,
                need_weights=False,
            )
            x = self.attention_norm(attn_out + residual)
        return x

    def _compute_sequence_loss(
        self,
        logits: torch.Tensor,
        y_true: torch.Tensor,
    ) -> torch.Tensor:
        """Per-timestep cross-entropy, skipping ``ignore_index`` targets."""
        return F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            y_true.reshape(-1).long(),
            ignore_index=self.ignore_index,
        )

    def forward(self, **kwargs) -> dict[str, torch.Tensor]:
        """Run a forward pass.

        Reads ``kwargs[self.feature_key]`` (and an optional ``mask``), and
        returns a dict with ``logit``/``y_prob`` plus ``loss``/``y_true``
        when labels are supplied, and ``embed`` when ``embed=True``.
        """
        x = self._coerce_input(kwargs[self.feature_key])
        mask = self._build_mask(x, kwargs.get("mask"))
        x = self._encode_sequence(x, mask)

        if self.sequence_output:
            # Per-timestep prediction path (no pooling).
            embed = self.final_dropout(x)
            logits = self.classifier(embed)
            results = {
                "logit": logits,
                "y_prob": torch.softmax(logits, dim=-1),
            }
            if self.label_key in kwargs:
                y_true = kwargs[self.label_key].to(self.device).long()
                results["loss"] = self._compute_sequence_loss(logits, y_true)
                results["y_true"] = y_true
            if kwargs.get("embed", False):
                results["embed"] = embed
            return results

        # Single-prediction path: pool over time, then classify.
        if self.use_attention and self.attention_pool is not None:
            pooled, _ = self.attention_pool(x, mask)
        else:
            pooled = self._masked_mean(x, mask)

        embed = self.final_dropout(pooled)
        logits = self.classifier(embed)
        results = {
            "logit": logits,
            # NOTE(review): prepare_y_prob presumably applies the task's
            # softmax/sigmoid — inherited from BaseModel; confirm there.
            "y_prob": self.prepare_y_prob(logits),
        }

        if self.label_key in kwargs:
            y_true = kwargs[self.label_key].to(self.device)
            results["loss"] = self.get_loss_function()(logits, y_true)
            results["y_true"] = y_true

        if kwargs.get("embed", False):
            results["embed"] = embed

        return results
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index 2f4294a19..0d3c2207f 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -54,6 +54,8 @@
ReadmissionPredictionOMOP,
)
from .sleep_staging import (
+ SleepStagingDREAMT,
+ SleepStagingDREAMTSeq,
sleep_staging_isruc_fn,
sleep_staging_shhs_fn,
sleep_staging_sleepedf_fn,
diff --git a/pyhealth/tasks/sleep_staging.py b/pyhealth/tasks/sleep_staging.py
index 911b672c8..b975563c4 100644
--- a/pyhealth/tasks/sleep_staging.py
+++ b/pyhealth/tasks/sleep_staging.py
@@ -1,8 +1,14 @@
import os
import pickle
+from pathlib import Path
+from typing import Any, Dict, Optional, Sequence
+
import mne
-import pandas as pd
import numpy as np
+import pandas as pd
+import polars as pl
+
+from .base_task import BaseTask
def sleep_staging_isruc_fn(record, epoch_seconds=10, label_id=1):
@@ -329,6 +335,578 @@ def sleep_staging_shhs_fn(record, epoch_seconds=30):
return samples
class SleepStagingDREAMT(BaseTask):
    """Three-class sleep staging task for DREAMT-style wearable sequences.

    This task converts one DREAMT subject recording into fixed-length windows of
    wearable features and maps detailed sleep stages to three classes:

    - ``0``: Wake
    - ``1``: NREM
    - ``2``: REM

    Supported input formats:

    - Raw per-subject CSV/Parquet/Pickle files containing wearable columns and a
      ``Sleep_Stage`` column.
    - Processed ``.npz``/``.npy`` dictionaries with keys such as
      ``features``, ``labels``, and optional ``feature_names``.

    The default feature set mirrors the smartwatch-oriented signals available in
    DREAMT and includes ``IBI`` plus common wearable context channels.
    """

    task_name: str = "SleepStagingDREAMT"
    input_schema: Dict[str, str] = {"signal": "tensor"}
    output_schema: Dict[str, str] = {"label": "multiclass"}

    # Wearable channels used when the caller does not override them;
    # missing columns are zero-filled in _extract_feature_frame.
    DEFAULT_FEATURE_COLUMNS = (
        "IBI",
        "HR",
        "BVP",
        "EDA",
        "TEMP",
        "ACC_X",
        "ACC_Y",
        "ACC_Z",
    )
    # Sentinel for rows whose stage cannot be mapped to the 3-class scheme.
    _IGNORE_LABEL = -1
    # Detailed stage annotations -> {0: Wake, 1: NREM, 2: REM}; keys are
    # matched after strip().upper() normalization.
    _LABEL_MAP = {
        "W": 0,
        "WAKE": 0,
        "WAKEFUL": 0,
        "0": 0,
        "N1": 1,
        "N2": 1,
        "N3": 1,
        "N4": 1,
        "NREM": 1,
        "1": 1,
        "2": 1,
        "3": 1,
        "4": 1,
        "R": 2,
        "REM": 2,
        "5": 2,
    }
    # Annotations that are explicitly not sleep stages (e.g. "P" = prep).
    _INVALID_LABELS = {"", "P", "PREPARATION", "UNKNOWN", "?", "NAN", "NONE"}

    def __init__(
        self,
        feature_columns: Optional[Sequence[str]] = None,
        label_column: str = "Sleep_Stage",
        source_preference: str = "wearable",
        window_seconds: float = 30.0,
        stride_seconds: Optional[float] = None,
        window_size: Optional[int] = None,
        stride: Optional[int] = None,
        default_sampling_rate_hz: Optional[float] = None,
        min_labeled_fraction: float = 0.5,
        pad_short_windows: bool = False,
        include_partial_last_window: bool = False,
    ) -> None:
        """Configure the windowing task.

        Args:
            feature_columns: Wearable columns to extract; defaults to
                ``DEFAULT_FEATURE_COLUMNS``.
            label_column: Name of the sleep-stage column in the source file.
            source_preference: Which signal file to prefer per event:
                ``"wearable"`` (64 Hz first), ``"psg"`` (100 Hz first), or
                ``"auto"`` (generic ``signal_file`` first).
            window_seconds: Window length in seconds (used when
                ``window_size`` is not given).
            stride_seconds: Stride in seconds; defaults to a non-overlapping
                stride of one window.
            window_size: Explicit window length in samples (overrides
                ``window_seconds``).
            stride: Explicit stride in samples (overrides ``stride_seconds``).
            default_sampling_rate_hz: Fallback rate when it cannot be read
                from metadata or inferred from timestamps.
            min_labeled_fraction: Minimum fraction of labeled rows a window
                must contain, in ``(0, 1]``.
            pad_short_windows: Whether to zero-pad windows shorter than
                ``window_size`` instead of dropping them.
            include_partial_last_window: Whether to emit one extra
                (clipped) window covering the recording tail.

        Raises:
            ValueError: If any argument is outside its valid range.
        """
        if source_preference not in {"auto", "wearable", "psg"}:
            raise ValueError(
                "source_preference must be one of 'auto', 'wearable', or 'psg'."
            )
        if window_seconds <= 0 and window_size is None:
            raise ValueError(
                "window_seconds must be positive when window_size is None."
            )
        if min_labeled_fraction <= 0 or min_labeled_fraction > 1:
            raise ValueError("min_labeled_fraction must be in (0, 1].")

        self.feature_columns = tuple(feature_columns or self.DEFAULT_FEATURE_COLUMNS)
        self.label_column = label_column
        self.source_preference = source_preference
        self.window_seconds = float(window_seconds)
        self.stride_seconds = (
            float(stride_seconds) if stride_seconds is not None else None
        )
        self.window_size = window_size
        self.stride = stride
        self.default_sampling_rate_hz = default_sampling_rate_hz
        self.min_labeled_fraction = min_labeled_fraction
        self.pad_short_windows = pad_short_windows
        self.include_partial_last_window = include_partial_last_window

    def pre_filter(self, df: pl.LazyFrame) -> pl.LazyFrame:
        """Keep only ``dreamt_sleep`` events with at least one signal path."""
        return df.filter(
            (pl.col("event_type") == "dreamt_sleep")
            & (
                pl.col("dreamt_sleep/signal_file").is_not_null()
                | pl.col("dreamt_sleep/file_64hz").is_not_null()
                | pl.col("dreamt_sleep/file_100hz").is_not_null()
            )
        )

    @classmethod
    def _normalize_stage_label(cls, value: Any) -> int:
        """Map one raw stage annotation to {0, 1, 2} or ``_IGNORE_LABEL``."""
        if pd.isna(value):
            return cls._IGNORE_LABEL
        normalized = str(value).strip().upper()
        if normalized in cls._INVALID_LABELS:
            return cls._IGNORE_LABEL
        # Unrecognized annotations are ignored rather than raising.
        return cls._LABEL_MAP.get(normalized, cls._IGNORE_LABEL)

    @staticmethod
    def _load_processed_payload(path: Path) -> Dict[str, Any]:
        """Load an ``.npz``/``.npy`` file into a plain dict of arrays."""
        if path.suffix.lower() == ".npz":
            with np.load(path, allow_pickle=True) as data:
                return {key: data[key] for key in data.files}

        payload = np.load(path, allow_pickle=True)
        # A 0-d object array wraps the pickled dict; unwrap it.
        if isinstance(payload, np.ndarray) and payload.shape == ():
            payload = payload.item()
        if not isinstance(payload, dict):
            raise ValueError(
                f"Unsupported DREAMT numpy payload in {path}. "
                "Expected a dict-like object with features and labels."
            )
        return payload

    def _load_signal_source(self, event: Any) -> tuple[Path, Optional[float]]:
        """Pick the first existing signal file for ``event``.

        Candidates are ordered by ``source_preference``; each candidate
        carries its nominal sampling rate (64/100 Hz for the official files,
        or the event's own ``sampling_rate_hz`` for a generic file).

        Returns:
            Tuple of resolved path and sampling rate (``None`` if unknown).

        Raises:
            FileNotFoundError: If no candidate file exists on disk.
        """
        event_sampling_rate = getattr(event, "sampling_rate_hz", None)
        if event_sampling_rate is not None:
            try:
                event_sampling_rate = float(event_sampling_rate)
            except (TypeError, ValueError):
                event_sampling_rate = None

        candidates: list[tuple[Optional[str], Optional[float]]] = []
        if self.source_preference == "wearable":
            candidates.extend(
                [
                    (getattr(event, "file_64hz", None), 64.0),
                    (getattr(event, "signal_file", None), event_sampling_rate),
                    (getattr(event, "file_100hz", None), 100.0),
                ]
            )
        elif self.source_preference == "psg":
            candidates.extend(
                [
                    (getattr(event, "file_100hz", None), 100.0),
                    (getattr(event, "signal_file", None), event_sampling_rate),
                    (getattr(event, "file_64hz", None), 64.0),
                ]
            )
        else:
            candidates.extend(
                [
                    (getattr(event, "signal_file", None), event_sampling_rate),
                    (getattr(event, "file_64hz", None), 64.0),
                    (getattr(event, "file_100hz", None), 100.0),
                ]
            )

        for file_path, sample_rate in candidates:
            # Missing metadata may surface as None or NaN (from CSV round-trips).
            if file_path is None or (
                isinstance(file_path, float) and np.isnan(file_path)
            ):
                continue
            path = Path(str(file_path)).expanduser().resolve()
            if path.exists():
                return path, sample_rate

        raise FileNotFoundError(
            "No DREAMT signal file was found for the requested patient event."
        )

    def _dataframe_from_payload(
        self,
        payload: Dict[str, Any],
        path: Path,
    ) -> pd.DataFrame:
        """Convert a processed payload dict into a signal DataFrame.

        Accepts either a pre-built ``frame`` entry, or ``features`` [T, F] +
        ``labels`` [T] (with optional ``feature_names`` and ``timestamps``).
        """
        if "frame" in payload:
            frame = payload["frame"]
            if isinstance(frame, pd.DataFrame):
                return frame.copy()
            return pd.DataFrame(frame)

        if "features" not in payload or "labels" not in payload:
            raise ValueError(
                f"Processed DREAMT file {path} must contain 'features' and 'labels'."
            )

        features = np.asarray(payload["features"], dtype=np.float32)
        labels = np.asarray(payload["labels"])
        if features.ndim != 2:
            raise ValueError(
                f"Processed DREAMT file {path} must provide features with shape [T, F]."
            )
        if labels.ndim != 1 or labels.shape[0] != features.shape[0]:
            raise ValueError(
                f"Processed DREAMT file {path} must provide labels with shape [T]."
            )

        feature_names = payload.get("feature_names", self.feature_columns)
        feature_names = [str(name) for name in feature_names]
        if len(feature_names) != features.shape[1]:
            raise ValueError(
                f"feature_names length does not match feature width in {path}."
            )

        frame = pd.DataFrame(features, columns=feature_names)
        frame[self.label_column] = labels
        if "timestamps" in payload:
            frame["TIMESTAMP"] = np.asarray(payload["timestamps"])
        return frame

    def _load_frame(self, path: Path) -> pd.DataFrame:
        """Load one subject file (CSV/Parquet/Pickle/NumPy) as a DataFrame."""
        suffix = path.suffix.lower()
        if suffix == ".csv":
            return pd.read_csv(path)
        if suffix == ".parquet":
            return pd.read_parquet(path)
        if suffix in {".pkl", ".pickle"}:
            payload = pd.read_pickle(path)
            if isinstance(payload, pd.DataFrame):
                return payload
            if isinstance(payload, dict):
                return self._dataframe_from_payload(payload, path)
            raise ValueError(f"Unsupported pickle payload in {path}.")
        if suffix in {".npz", ".npy"}:
            payload = self._load_processed_payload(path)
            return self._dataframe_from_payload(payload, path)
        raise ValueError(f"Unsupported DREAMT file format: {path.suffix}")

    @staticmethod
    def _infer_sampling_rate_hz(
        frame: pd.DataFrame,
        fallback: Optional[float],
    ) -> float:
        """Infer the sampling rate from TIMESTAMP deltas, else use fallback.

        The median of positive, finite timestamp differences is used, which
        is robust to occasional gaps in the recording.
        """
        if "TIMESTAMP" in frame.columns:
            timestamps = pd.to_numeric(frame["TIMESTAMP"], errors="coerce").to_numpy()
            diffs = np.diff(timestamps)
            diffs = diffs[np.isfinite(diffs) & (diffs > 0)]
            if diffs.size > 0:
                median_step = float(np.median(diffs))
                if median_step > 0:
                    return 1.0 / median_step
        if fallback is not None and fallback > 0:
            return float(fallback)
        raise ValueError(
            "Unable to infer DREAMT sampling rate. Provide TIMESTAMP values or "
            "set default_sampling_rate_hz."
        )

    def _resolve_window_params(self, sample_rate_hz: float) -> tuple[int, int]:
        """Resolve (window_size, stride) in samples for the given rate.

        Explicit sample counts win over second-based settings; the stride
        defaults to one full window (non-overlapping windows).
        """
        window_size = self.window_size
        if window_size is None:
            window_size = max(1, int(round(self.window_seconds * sample_rate_hz)))

        stride = self.stride
        if stride is None:
            if self.stride_seconds is not None:
                stride = max(1, int(round(self.stride_seconds * sample_rate_hz)))
            else:
                stride = window_size
        return window_size, stride

    def _extract_feature_frame(self, frame: pd.DataFrame, path: Path) -> pd.DataFrame:
        """Build the numeric feature frame for ``self.feature_columns``.

        Columns absent from the source are zero-filled so the output width
        is fixed; remaining NaNs are forward/backward-filled, then zeroed.
        """
        available_columns = {str(column): column for column in frame.columns}
        feature_data = {}
        for column in self.feature_columns:
            source_column = available_columns.get(column)
            if source_column is None:
                feature_data[column] = np.zeros(len(frame), dtype=np.float32)
                continue
            values = pd.to_numeric(frame[source_column], errors="coerce")
            feature_data[column] = values.to_numpy(dtype=np.float32)

        if not feature_data:
            raise ValueError(f"No wearable feature columns were found in {path}.")

        feature_frame = pd.DataFrame(feature_data)
        feature_frame = feature_frame.ffill().bfill().fillna(0.0)
        return feature_frame

    def _extract_labels(self, frame: pd.DataFrame, path: Path) -> np.ndarray:
        """Return the normalized int64 label array for ``frame``."""
        if self.label_column not in frame.columns:
            raise ValueError(
                f"DREAMT file {path} is missing the label column '{self.label_column}'."
            )
        labels = frame[self.label_column].apply(self._normalize_stage_label).to_numpy()
        return labels.astype(np.int64, copy=False)

    @staticmethod
    def _majority_label(labels: np.ndarray) -> int:
        """Majority vote over valid labels; ``_IGNORE_LABEL`` if none valid."""
        valid = labels[labels >= 0]
        if valid.size == 0:
            return SleepStagingDREAMT._IGNORE_LABEL
        return int(np.bincount(valid).argmax())

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Produce windowed samples for one patient.

        Each sample carries ``signal`` ([window_size, n_features] float32),
        a majority-vote ``label``, and provenance fields.
        """
        events = patient.get_events("dreamt_sleep")
        if not events:
            return []

        samples = []
        for event in events:
            signal_path, sampling_rate_hz = self._load_signal_source(event)
            frame = self._load_frame(signal_path)
            labels = self._extract_labels(frame, signal_path)
            valid_mask = labels >= 0

            if valid_mask.sum() == 0:
                continue

            # NOTE(review): rows with invalid labels are dropped BEFORE
            # windowing, so (a) windows may splice across time gaps, and
            # (b) the min_labeled_fraction check below can never fail since
            # every remaining label is >= 0 — confirm this is intended.
            frame = frame.loc[valid_mask].reset_index(drop=True)
            labels = labels[valid_mask]
            feature_frame = self._extract_feature_frame(frame, signal_path)
            features = feature_frame.to_numpy(dtype=np.float32, copy=False)

            inferred_rate = self._infer_sampling_rate_hz(
                frame,
                sampling_rate_hz or self.default_sampling_rate_hz,
            )
            window_size, stride = self._resolve_window_params(inferred_rate)

            total_length = features.shape[0]
            if total_length < window_size and not self.pad_short_windows:
                continue

            # max(..., 1) guarantees at least one window when padding short
            # recordings is enabled.
            starts = list(range(0, max(total_length - window_size + 1, 1), stride))
            if (
                self.include_partial_last_window
                and total_length > window_size
                and starts[-1] + window_size < total_length
            ):
                starts.append(starts[-1] + stride)

            for window_index, start in enumerate(starts):
                end = min(start + window_size, total_length)
                window_features = features[start:end]
                window_labels = labels[start:end]

                labeled_fraction = float((window_labels >= 0).mean())
                if labeled_fraction < self.min_labeled_fraction:
                    continue

                label = self._majority_label(window_labels)
                if label == self._IGNORE_LABEL:
                    continue

                if end - start < window_size:
                    if not self.pad_short_windows:
                        continue
                    # Zero-pad the tail so every sample has a fixed shape.
                    pad_rows = window_size - (end - start)
                    window_features = np.pad(
                        window_features,
                        pad_width=((0, pad_rows), (0, 0)),
                        mode="constant",
                    )

                record_id = f"{patient.patient_id}-{signal_path.stem}-{window_index}"
                samples.append(
                    {
                        "patient_id": patient.patient_id,
                        "record_id": record_id,
                        "signal": window_features.astype(np.float32, copy=False),
                        "label": int(label),
                        "signal_source": getattr(event, "signal_source", None),
                        "signal_file": str(signal_path),
                    }
                )

        return samples
+
+
class SleepStagingDREAMTSeq(SleepStagingDREAMT):
    """Sequence-style DREAMT sleep staging task closer to WatchSleepNet.

    This task converts a DREAMT recording into a sequence of epoch-level
    feature vectors and labels. Each sample contains:

    - ``signal``: ``[sequence_length, input_dim]``
    - ``mask``: ``[sequence_length]`` with 1 for valid epochs and 0 for padding
    - ``label``: ``[sequence_length]`` with padded labels set to
      ``ignore_index``

    By default, this class uses ``IBI`` only to better align with the paper's
    shared-modality representation.
    """

    task_name: str = "SleepStagingDREAMTSeq"
    # Sequence samples also carry a padding mask alongside the signal.
    input_schema: Dict[str, str] = {
        "signal": "tensor",
        "mask": "tensor",
    }
    output_schema: Dict[str, str] = {"label": "tensor"}

    def __init__(
        self,
        feature_columns: Optional[Sequence[str]] = ("IBI",),
        label_column: str = "Sleep_Stage",
        source_preference: str = "wearable",
        epoch_seconds: float = 30.0,
        sequence_length: int = 1100,
        stride_epochs: Optional[int] = None,
        default_sampling_rate_hz: Optional[float] = None,
        pad_value: float = 0.0,
        truncate: bool = True,
        ignore_index: int = -100,
    ) -> None:
        """Configure the sequence task.

        Args:
            feature_columns: Wearable columns to aggregate per epoch.
            label_column: Name of the sleep-stage column in the source file.
            source_preference: Signal-file preference (see parent class).
            epoch_seconds: Length of one scoring epoch in seconds.
            sequence_length: Fixed number of epochs per emitted sample.
            stride_epochs: Stride (in epochs) between chunks of a long
                recording; defaults to ``sequence_length`` (no overlap).
            default_sampling_rate_hz: Fallback sampling rate when it cannot
                be read from metadata or timestamps.
            pad_value: Fill value for padded signal rows.
            truncate: When False, one extra chunk covering the recording
                tail is emitted even if it overlaps the previous chunk.
            ignore_index: Label value used for padded positions.

        Raises:
            ValueError: If any argument is outside its valid range.
        """
        super().__init__(
            feature_columns=feature_columns,
            label_column=label_column,
            source_preference=source_preference,
            default_sampling_rate_hz=default_sampling_rate_hz,
        )
        if epoch_seconds <= 0:
            raise ValueError("epoch_seconds must be positive.")
        if sequence_length <= 0:
            raise ValueError("sequence_length must be positive.")
        if stride_epochs is not None and stride_epochs <= 0:
            # Validate like the other arguments; previously 0 was silently
            # coerced to sequence_length and negatives misbehaved in range().
            raise ValueError("stride_epochs must be positive when provided.")

        self.epoch_seconds = float(epoch_seconds)
        self.sequence_length = int(sequence_length)
        self.stride_epochs = stride_epochs
        self.pad_value = float(pad_value)
        self.truncate = truncate
        self.ignore_index = int(ignore_index)

    def _epochize_features_and_labels(
        self,
        features: np.ndarray,
        labels: np.ndarray,
        sample_rate_hz: float,
    ) -> tuple[np.ndarray, np.ndarray]:
        """Aggregate raw samples into per-epoch means and majority labels.

        Epochs whose majority label is invalid are skipped, so the returned
        arrays may be shorter than ``T // epoch_size`` and are compacted in
        time (epoch indices refer to the compacted sequence).
        """
        epoch_size = max(1, int(round(self.epoch_seconds * sample_rate_hz)))
        num_epochs = features.shape[0] // epoch_size
        if num_epochs == 0:
            return (
                np.zeros((0, features.shape[1]), dtype=np.float32),
                np.zeros((0,), dtype=np.int64),
            )

        epoch_features: list[np.ndarray] = []
        epoch_labels: list[int] = []
        for epoch_index in range(num_epochs):
            start = epoch_index * epoch_size
            end = start + epoch_size
            epoch_feature_values = features[start:end]
            epoch_label_values = labels[start:end]
            epoch_label = self._majority_label(epoch_label_values)
            if epoch_label == self._IGNORE_LABEL:
                continue
            epoch_features.append(
                epoch_feature_values.mean(axis=0).astype(np.float32, copy=False)
            )
            epoch_labels.append(epoch_label)

        if not epoch_features:
            return (
                np.zeros((0, features.shape[1]), dtype=np.float32),
                np.zeros((0,), dtype=np.int64),
            )

        return (
            np.stack(epoch_features).astype(np.float32, copy=False),
            np.asarray(epoch_labels, dtype=np.int64),
        )

    def _build_sequence_sample(
        self,
        patient_id: str,
        signal_path: Path,
        epoch_features: np.ndarray,
        epoch_labels: np.ndarray,
        chunk_index: int,
        start_epoch: int,
        signal_source: Optional[str] = None,
    ) -> dict[str, Any]:
        """Pad one chunk of epochs into a fixed-length sample dict.

        Args:
            patient_id: Owning patient identifier.
            signal_path: Source file (for provenance fields).
            epoch_features: ``[n_epochs, input_dim]`` chunk (n <= length).
            epoch_labels: ``[n_epochs]`` chunk of class labels.
            chunk_index: Index of this chunk within the recording.
            start_epoch: First (compacted) epoch index of the chunk.
            signal_source: Provenance tag of the signal file, matching the
                ``signal_source`` field emitted by the parent task.
        """
        valid_length = min(epoch_features.shape[0], self.sequence_length)
        feature_dim = epoch_features.shape[1]

        signal = np.full(
            (self.sequence_length, feature_dim),
            fill_value=self.pad_value,
            dtype=np.float32,
        )
        mask = np.zeros((self.sequence_length,), dtype=np.float32)
        labels = np.full(
            (self.sequence_length,),
            fill_value=self.ignore_index,
            dtype=np.int64,
        )

        signal[:valid_length] = epoch_features[:valid_length]
        mask[:valid_length] = 1.0
        labels[:valid_length] = epoch_labels[:valid_length]

        return {
            "patient_id": patient_id,
            "record_id": f"{patient_id}-{signal_path.stem}-seq-{chunk_index}",
            "signal": signal,
            "mask": mask,
            "label": labels,
            "signal_source": signal_source,
            "signal_file": str(signal_path),
            "start_epoch": int(start_epoch),
        }

    def __call__(self, patient: Any) -> list[dict[str, Any]]:
        """Produce fixed-length epoch sequences for one patient."""
        events = patient.get_events("dreamt_sleep")
        if not events:
            return []

        samples = []
        for event in events:
            signal_path, sampling_rate_hz = self._load_signal_source(event)
            frame = self._load_frame(signal_path)
            labels = self._extract_labels(frame, signal_path)
            feature_frame = self._extract_feature_frame(frame, signal_path)
            features = feature_frame.to_numpy(dtype=np.float32, copy=False)
            # Propagate provenance so sequence samples match the parent
            # task's per-window samples (previously always None).
            event_signal_source = getattr(event, "signal_source", None)

            inferred_rate = self._infer_sampling_rate_hz(
                frame,
                sampling_rate_hz or self.default_sampling_rate_hz,
            )
            epoch_features, epoch_labels = self._epochize_features_and_labels(
                features,
                labels,
                inferred_rate,
            )
            if epoch_features.shape[0] == 0:
                continue

            stride_epochs = self.stride_epochs or self.sequence_length
            if epoch_features.shape[0] <= self.sequence_length:
                # The whole recording fits in one (padded) sample.
                samples.append(
                    self._build_sequence_sample(
                        patient.patient_id,
                        signal_path,
                        epoch_features,
                        epoch_labels,
                        chunk_index=0,
                        start_epoch=0,
                        signal_source=event_signal_source,
                    )
                )
                continue

            max_start = epoch_features.shape[0] - self.sequence_length
            starts = list(range(0, max_start + 1, stride_epochs))
            if (
                not self.truncate
                and starts[-1] != max_start
            ):
                # Add a final overlapping chunk so the tail is not dropped.
                starts.append(max_start)

            for chunk_index, start_epoch in enumerate(starts):
                end_epoch = start_epoch + self.sequence_length
                samples.append(
                    self._build_sequence_sample(
                        patient.patient_id,
                        signal_path,
                        epoch_features[start_epoch:end_epoch],
                        epoch_labels[start_epoch:end_epoch],
                        chunk_index=chunk_index,
                        start_epoch=start_epoch,
                        signal_source=event_signal_source,
                    )
                )

        return samples
+
+
if __name__ == "__main__":
from pyhealth.datasets import SleepEDFDataset, SHHSDataset, ISRUCDataset
diff --git a/tests/core/test_dreamt_dataset.py b/tests/core/test_dreamt_dataset.py
new file mode 100644
index 000000000..c0febcfc9
--- /dev/null
+++ b/tests/core/test_dreamt_dataset.py
@@ -0,0 +1,126 @@
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from pyhealth.datasets import DREAMTDataset
+
+
+def _create_synthetic_dreamt_root(root: Path) -> Path:
+ dreamt_root = root / "physionet.org" / "files" / "dreamt" / "2.1.0"
+ (dreamt_root / "data_64Hz").mkdir(parents=True)
+ (dreamt_root / "data_100Hz").mkdir(parents=True)
+
+ participant_info = pd.DataFrame(
+ {
+ "SID": ["S001", "S002"],
+ "AGE": [28.0, 35.0],
+ "GENDER": ["F", "M"],
+ "BMI": [21.0, 24.5],
+ "OAHI": [1.0, 2.0],
+ "AHI": [3.0, 4.0],
+ "Mean_SaO2": ["97%", "96%"],
+ "Arousal Index": [10.0, 12.0],
+ "MEDICAL_HISTORY": ["None", "None"],
+ "Sleep_Disorders": ["OSA", "None"],
+ }
+ )
+ participant_info.to_csv(dreamt_root / "participant_info.csv", index=False)
+
+ stage_pattern = ["P"] * 10 + ["W"] * 10 + ["N2"] * 10 + ["R"] * 10
+ timestamps = np.arange(len(stage_pattern), dtype=np.float32) / 64.0
+
+ for patient_id in ["S001", "S002"]:
+ frame = pd.DataFrame(
+ {
+ "TIMESTAMP": timestamps,
+ "IBI": np.linspace(0.8, 1.2, len(stage_pattern)),
+ "HR": np.linspace(60.0, 68.0, len(stage_pattern)),
+ "BVP": np.linspace(0.1, 0.5, len(stage_pattern)),
+ "EDA": np.linspace(0.01, 0.04, len(stage_pattern)),
+ "TEMP": np.linspace(33.0, 33.5, len(stage_pattern)),
+ "ACC_X": np.zeros(len(stage_pattern)),
+ "ACC_Y": np.ones(len(stage_pattern)),
+ "ACC_Z": np.full(len(stage_pattern), 2.0),
+ "Sleep_Stage": stage_pattern,
+ }
+ )
+ frame.to_csv(
+ dreamt_root / "data_64Hz" / f"{patient_id}_whole_df.csv",
+ index=False,
+ )
+ frame.to_csv(
+ dreamt_root / "data_100Hz" / f"{patient_id}_PSG_df_updated.csv",
+ index=False,
+ )
+
+ return dreamt_root
+
+
class TestDREAMTDataset(unittest.TestCase):
    """Behavioral checks for ``DREAMTDataset`` on a synthetic PhysioNet tree."""

    def setUp(self):
        self.temp_dir = Path(tempfile.mkdtemp())
        self.synthetic_root = _create_synthetic_dreamt_root(self.temp_dir)

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_resolves_nested_root_and_creates_metadata(self):
        # Point the loader at the outer directory; it should locate the
        # nested physionet.org/.../2.1.0 version directory on its own.
        loaded = DREAMTDataset(
            root=str(self.temp_dir),
            cache_dir=self.temp_dir / "cache",
        )

        expected_root = self.synthetic_root.resolve()
        self.assertEqual(Path(loaded.root).resolve(), expected_root)
        self.assertEqual(loaded.dataset_name, "dreamt_sleep")
        metadata_path = self.synthetic_root / "dreamt-metadata.csv"
        self.assertTrue(metadata_path.exists())
        self.assertEqual(len(loaded.unique_patient_ids), 2)

    def test_patient_event_contains_signal_references(self):
        loaded = DREAMTDataset(
            root=str(self.synthetic_root),
            preferred_source="psg",
            cache_dir=self.temp_dir / "cache_psg",
        )

        first_event = loaded.get_patient("S001").get_events("dreamt_sleep")[0]

        # With the PSG preference the active signal file must be the 100 Hz
        # recording, while both per-source file references remain populated.
        self.assertTrue(first_event.signal_file.endswith("S001_PSG_df_updated.csv"))
        self.assertTrue(first_event.file_64hz.endswith("S001_whole_df.csv"))
        self.assertTrue(first_event.file_100hz.endswith("S001_PSG_df_updated.csv"))
        self.assertEqual(first_event.signal_source, "psg")
        self.assertEqual(float(first_event.sampling_rate_hz), 100.0)

    def test_missing_root_raises_helpful_error(self):
        # A nonexistent root should fail fast rather than build an empty set.
        with self.assertRaises(FileNotFoundError):
            DREAMTDataset(root=str(self.temp_dir / "missing"))

    def test_partial_download_drops_missing_subjects(self):
        # Simulate a partial download: remove both of S002's recordings and
        # any previously generated metadata cache file.
        for stale in (
            self.synthetic_root / "data_64Hz" / "S002_whole_df.csv",
            self.synthetic_root / "data_100Hz" / "S002_PSG_df_updated.csv",
        ):
            stale.unlink()
        metadata_file = self.synthetic_root / "dreamt-metadata.csv"
        if metadata_file.exists():
            metadata_file.unlink()

        loaded = DREAMTDataset(
            root=str(self.synthetic_root),
            cache_dir=self.temp_dir / "cache_partial",
        )

        self.assertEqual(len(loaded.unique_patient_ids), 1)
        self.assertEqual(loaded.unique_patient_ids[0], "S001")
+
+
# Allow running this test module directly (python tests/core/test_dreamt_dataset.py).
if __name__ == "__main__":
    unittest.main()
diff --git a/tests/core/test_sleep_staging_task.py b/tests/core/test_sleep_staging_task.py
new file mode 100644
index 000000000..84f9a5a04
--- /dev/null
+++ b/tests/core/test_sleep_staging_task.py
@@ -0,0 +1,209 @@
+import shutil
+import tempfile
+import unittest
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+
+from pyhealth.datasets import DREAMTDataset
+from pyhealth.tasks import SleepStagingDREAMT, SleepStagingDREAMTSeq
+
+
+def _create_task_test_root(root: Path) -> Path:
+ dreamt_root = root / "dreamt"
+ (dreamt_root / "data_64Hz").mkdir(parents=True)
+
+ participant_info = pd.DataFrame(
+ {
+ "SID": ["S001"],
+ "AGE": [29.0],
+ "GENDER": ["F"],
+ "BMI": [22.0],
+ "OAHI": [1.0],
+ "AHI": [1.0],
+ "Mean_SaO2": ["98%"],
+ "Arousal Index": [9.0],
+ "MEDICAL_HISTORY": ["None"],
+ "Sleep_Disorders": ["None"],
+ }
+ )
+ participant_info.to_csv(dreamt_root / "participant_info.csv", index=False)
+
+ stages = ["P"] * 10 + ["W"] * 10 + ["N1"] * 10 + ["R"] * 10
+ timestamps = [index / 64.0 for index in range(len(stages))]
+ frame = pd.DataFrame(
+ {
+ "TIMESTAMP": timestamps,
+ "IBI": [0.8 + 0.01 * index for index in range(len(stages))],
+ "HR": [60.0 + index for index in range(len(stages))],
+ "BVP": [0.1] * len(stages),
+ "EDA": [0.01] * len(stages),
+ "TEMP": [33.2] * len(stages),
+ "ACC_X": [0.0] * len(stages),
+ "ACC_Y": [1.0] * len(stages),
+ "ACC_Z": [2.0] * len(stages),
+ "Sleep_Stage": stages,
+ }
+ )
+ frame.to_csv(dreamt_root / "data_64Hz" / "S001_whole_df.csv", index=False)
+ return dreamt_root
+
+
+def _create_seq_task_test_root(root: Path) -> Path:
+ dreamt_root = root / "dreamt_seq"
+ (dreamt_root / "data_64Hz").mkdir(parents=True)
+
+ participant_info = pd.DataFrame(
+ {
+ "SID": ["S001"],
+ "AGE": [29.0],
+ "GENDER": ["F"],
+ "BMI": [22.0],
+ "OAHI": [1.0],
+ "AHI": [1.0],
+ "Mean_SaO2": ["98%"],
+ "Arousal Index": [9.0],
+ "MEDICAL_HISTORY": ["None"],
+ "Sleep_Disorders": ["None"],
+ }
+ )
+ participant_info.to_csv(dreamt_root / "participant_info.csv", index=False)
+
+ stage_blocks = [
+ ("P", 30),
+ ("W", 30),
+ ("N2", 30),
+ ("R", 30),
+ ]
+ labels = [label for label, count in stage_blocks for _ in range(count)]
+ timestamps = np.arange(len(labels), dtype=np.float32)
+ frame = pd.DataFrame(
+ {
+ "TIMESTAMP": timestamps,
+ "IBI": np.linspace(0.8, 1.3, len(labels)),
+ "HR": np.linspace(60.0, 72.0, len(labels)),
+ "BVP": np.linspace(0.1, 0.4, len(labels)),
+ "EDA": np.linspace(0.01, 0.03, len(labels)),
+ "TEMP": np.full(len(labels), 33.2),
+ "ACC_X": np.zeros(len(labels)),
+ "ACC_Y": np.ones(len(labels)),
+ "ACC_Z": np.full(len(labels), 2.0),
+ "Sleep_Stage": labels,
+ }
+ )
+ frame.to_csv(dreamt_root / "data_64Hz" / "S001_whole_df.csv", index=False)
+ return dreamt_root
+
+
class TestSleepStagingDREAMT(unittest.TestCase):
    """Tests for the fixed-window ``SleepStagingDREAMT`` task on a 40-row
    single-subject fixture (four 10-row stage blocks: P, W, N1, R)."""

    def setUp(self):
        # Fresh synthetic dataset per test; cache lives inside the temp dir.
        self.temp_dir = Path(tempfile.mkdtemp())
        self.root = _create_task_test_root(self.temp_dir)
        self.dataset = DREAMTDataset(
            root=str(self.root),
            cache_dir=self.temp_dir / "cache",
        )

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_task_maps_labels_and_extracts_fixed_windows(self):
        task = SleepStagingDREAMT(
            window_size=10,
            stride=10,
            source_preference="wearable",
        )
        patient = self.dataset.get_patient("S001")

        samples = task(patient)

        # 40 rows / stride 10 gives 4 non-overlapping windows, but only 3
        # samples are expected — the "P" block is presumably dropped by the
        # task's label mapping, with W/N1/R mapping to 0/1/2 (NOTE(review):
        # confirm against SleepStagingDREAMT's stage-label map).
        self.assertEqual(len(samples), 3)
        self.assertEqual([sample["label"] for sample in samples], [0, 1, 2])
        # Each window keeps all 8 feature columns at its 10 rows.
        self.assertEqual(samples[0]["signal"].shape, (10, 8))
        self.assertEqual(samples[0]["patient_id"], "S001")
        self.assertIn("record_id", samples[0])

    def test_set_task_builds_sample_dataset(self):
        task = SleepStagingDREAMT(
            window_size=10,
            stride=10,
            source_preference="wearable",
        )
        # set_task should materialize the same 3 windows as calling the
        # task directly on the patient.
        sample_dataset = self.dataset.set_task(task=task, num_workers=1)

        self.assertEqual(len(sample_dataset), 3)
        sample = sample_dataset[0]
        self.assertIn("signal", sample)
        self.assertIn("label", sample)
        self.assertEqual(tuple(sample["signal"].shape), (10, 8))
+
class TestSleepStagingDREAMTSeq(unittest.TestCase):
    """Tests for the sequence-style ``SleepStagingDREAMTSeq`` task on a
    120-row fixture (four 30-row stage blocks: P, W, N2, R)."""

    def setUp(self):
        self.temp_dir = Path(tempfile.mkdtemp())
        self.root = _create_seq_task_test_root(self.temp_dir)
        self.dataset = DREAMTDataset(
            root=str(self.root),
            cache_dir=self.temp_dir / "cache",
        )

    def tearDown(self):
        shutil.rmtree(self.temp_dir)

    def test_seq_task_builds_epoch_sequence(self):
        task = SleepStagingDREAMTSeq(
            epoch_seconds=30,
            sequence_length=5,
            source_preference="wearable",
        )
        patient = self.dataset.get_patient("S001")

        samples = task(patient)

        # One chunk of 5 epoch slots with a 1-dim feature per epoch.
        # NOTE(review): only 3 of the 5 slots are valid (mask 1) and labels
        # start at 0/1/2 — this count depends on the task's sampling-rate
        # inference and label mapping over the 120-row fixture; confirm
        # against SleepStagingDREAMTSeq.
        self.assertEqual(len(samples), 1)
        self.assertEqual(samples[0]["signal"].shape, (5, 1))
        self.assertEqual(samples[0]["mask"].tolist(), [1.0, 1.0, 1.0, 0.0, 0.0])
        self.assertEqual(samples[0]["label"][:3].tolist(), [0, 1, 2])

    def test_seq_task_outputs_mask_and_padded_labels(self):
        task = SleepStagingDREAMTSeq(
            epoch_seconds=30,
            sequence_length=5,
            source_preference="wearable",
            ignore_index=-100,
        )
        patient = self.dataset.get_patient("S001")
        sample = task(patient)[0]

        # Padded epoch slots carry zeroed features and the ignore_index
        # label so a sequence loss can skip them.
        self.assertTrue(np.allclose(sample["signal"][3:], 0.0))
        self.assertEqual(sample["label"][3:].tolist(), [-100, -100])

    def test_seq_task_ibi_only_mode(self):
        # Restricting feature_columns to IBI yields a single feature channel.
        task = SleepStagingDREAMTSeq(
            feature_columns=("IBI",),
            epoch_seconds=30,
            sequence_length=5,
            source_preference="wearable",
        )
        patient = self.dataset.get_patient("S001")
        sample = task(patient)[0]

        self.assertEqual(sample["signal"].shape[1], 1)

    def test_seq_set_task_builds_sample_dataset(self):
        task = SleepStagingDREAMTSeq(
            epoch_seconds=30,
            sequence_length=5,
            source_preference="wearable",
        )
        # set_task integration: the materialized sample must expose the
        # signal/mask/label triple with the same padded shape.
        sample_dataset = self.dataset.set_task(task=task, num_workers=1)
        sample = sample_dataset[0]

        self.assertIn("signal", sample)
        self.assertIn("mask", sample)
        self.assertIn("label", sample)
        self.assertEqual(tuple(sample["signal"].shape), (5, 1))
+
+
# Allow running this test module directly (python tests/core/test_sleep_staging_task.py).
if __name__ == "__main__":
    unittest.main()
diff --git a/tests/core/test_watchsleepnet.py b/tests/core/test_watchsleepnet.py
new file mode 100644
index 000000000..e7d2417e2
--- /dev/null
+++ b/tests/core/test_watchsleepnet.py
@@ -0,0 +1,197 @@
+import unittest
+
+import numpy as np
+import torch
+
+from pyhealth.datasets import create_sample_dataset, get_dataloader
+from pyhealth.models import WatchSleepNet
+
+
class TestWatchSleepNet(unittest.TestCase):
    """Shape, loss, and gradient tests for the ``WatchSleepNet`` model.

    Two synthetic fixtures are used: a window-classification dataset (one
    integer label per (24, 8) signal) and a sequence dataset (one label per
    epoch, with -100 marking padded positions — the value torch's
    cross-entropy uses as its default ``ignore_index``; NOTE(review): confirm
    the model's sequence loss actually honors it, as asserted below).
    """

    def setUp(self):
        # Seeded generator keeps the synthetic fixtures deterministic.
        rng = np.random.default_rng(42)
        self.samples = [
            {
                "patient_id": f"patient-{index}",
                "record_id": f"record-{index}",
                # 24 time steps x 8 channels per record; labels cycle 0/1/2.
                "signal": rng.normal(size=(24, 8)).astype(np.float32),
                "label": index % 3,
            }
            for index in range(6)
        ]
        self.dataset = create_sample_dataset(
            samples=self.samples,
            input_schema={"signal": "tensor"},
            output_schema={"label": "multiclass"},
            dataset_name="watchsleepnet_test",
        )
        # Small model configuration to keep the tests fast.
        self.model = WatchSleepNet(
            dataset=self.dataset,
            hidden_dim=16,
            conv_channels=16,
            conv_blocks=2,
            tcn_blocks=2,
            num_attention_heads=4,
            dropout=0.1,
        )
        # Sequence-mode fixture: 5 epoch slots of 4 features each, with the
        # last two slots padded (mask 0, label -100).
        self.sequence_samples = [
            {
                "patient_id": f"seq-patient-{index}",
                "record_id": f"seq-record-{index}",
                "signal": rng.normal(size=(5, 4)).astype(np.float32),
                "mask": np.array([1, 1, 1, 0, 0], dtype=np.float32),
                "label": np.array([0, 1, 2, -100, -100], dtype=np.int64),
            }
            for index in range(4)
        ]
        self.sequence_dataset = create_sample_dataset(
            samples=self.sequence_samples,
            input_schema={"signal": "tensor", "mask": "tensor"},
            output_schema={"label": "tensor"},
            dataset_name="watchsleepnet_seq_test",
        )

    def test_model_initialization(self):
        # Constructor should record the feature key and the configured sizes.
        self.assertIsInstance(self.model, WatchSleepNet)
        self.assertEqual(self.model.feature_key, "signal")
        self.assertEqual(self.model.input_dim, 8)
        self.assertEqual(self.model.hidden_dim, 16)
        self.assertEqual(self.model.conv_channels, 16)

    def test_forward_shapes(self):
        loader = get_dataloader(self.dataset, batch_size=3, shuffle=False)
        batch = next(iter(loader))

        with torch.no_grad():
            output = self.model(**batch)

        # Standard pyhealth output dict: loss, probabilities, targets, logits.
        self.assertIn("loss", output)
        self.assertIn("y_prob", output)
        self.assertIn("y_true", output)
        self.assertIn("logit", output)
        # (batch=3, classes=3) for logits/probs; (batch,) for targets.
        self.assertEqual(tuple(output["logit"].shape), (3, 3))
        self.assertEqual(tuple(output["y_prob"].shape), (3, 3))
        self.assertEqual(tuple(output["y_true"].shape), (3,))

    def test_backward_pass(self):
        loader = get_dataloader(self.dataset, batch_size=3, shuffle=False)
        batch = next(iter(loader))

        output = self.model(**batch)
        output["loss"].backward()

        # At least one trainable parameter must have received a gradient.
        has_grad = any(
            parameter.requires_grad and parameter.grad is not None
            for parameter in self.model.parameters()
        )
        self.assertTrue(has_grad)

    def test_without_attention_and_tcn(self):
        # The model should still run with both optional submodules disabled.
        model = WatchSleepNet(
            dataset=self.dataset,
            hidden_dim=12,
            conv_channels=12,
            num_attention_heads=3,
            use_attention=False,
            use_tcn=False,
        )
        loader = get_dataloader(self.dataset, batch_size=2, shuffle=False)
        batch = next(iter(loader))

        with torch.no_grad():
            output = model(**batch)

        self.assertEqual(tuple(output["logit"].shape), (2, 3))

    def test_sequence_output_shapes(self):
        model = WatchSleepNet(
            dataset=self.sequence_dataset,
            input_dim=4,
            hidden_dim=12,
            conv_channels=12,
            num_attention_heads=3,
            num_classes=3,
            sequence_output=True,
        )
        loader = get_dataloader(self.sequence_dataset, batch_size=2, shuffle=False)
        batch = next(iter(loader))

        with torch.no_grad():
            output = model(**batch)

        # Sequence mode emits per-epoch predictions: (batch, seq_len, classes).
        self.assertEqual(tuple(output["logit"].shape), (2, 5, 3))
        self.assertEqual(tuple(output["y_prob"].shape), (2, 5, 3))
        self.assertEqual(tuple(output["y_true"].shape), (2, 5))

    def test_sequence_mode_requires_num_classes(self):
        # sequence_output=True without num_classes must be rejected.
        with self.assertRaises(ValueError):
            WatchSleepNet(
                dataset=self.sequence_dataset,
                input_dim=4,
                hidden_dim=12,
                conv_channels=12,
                num_attention_heads=3,
                sequence_output=True,
            )

    def test_sequence_loss_with_padding(self):
        model = WatchSleepNet(
            dataset=self.sequence_dataset,
            input_dim=4,
            hidden_dim=12,
            conv_channels=12,
            num_attention_heads=3,
            num_classes=3,
            sequence_output=True,
        )
        loader = get_dataloader(self.sequence_dataset, batch_size=2, shuffle=False)
        batch = next(iter(loader))
        output = model(**batch)

        # Padded (-100) positions must not break the loss: scalar and finite.
        self.assertEqual(output["loss"].dim(), 0)
        self.assertTrue(torch.isfinite(output["loss"]))

    def test_sequence_backward_pass(self):
        model = WatchSleepNet(
            dataset=self.sequence_dataset,
            input_dim=4,
            hidden_dim=12,
            conv_channels=12,
            num_attention_heads=3,
            num_classes=3,
            sequence_output=True,
        )
        loader = get_dataloader(self.sequence_dataset, batch_size=2, shuffle=False)
        batch = next(iter(loader))
        output = model(**batch)
        output["loss"].backward()

        # Gradients must flow through the sequence head as well.
        has_grad = any(
            parameter.requires_grad and parameter.grad is not None
            for parameter in model.parameters()
        )
        self.assertTrue(has_grad)

    def test_sequence_mode_without_attention_and_tcn(self):
        # Sequence output combined with both optional submodules disabled.
        model = WatchSleepNet(
            dataset=self.sequence_dataset,
            input_dim=4,
            hidden_dim=12,
            conv_channels=12,
            num_attention_heads=3,
            num_classes=3,
            sequence_output=True,
            use_attention=False,
            use_tcn=False,
        )
        loader = get_dataloader(self.sequence_dataset, batch_size=2, shuffle=False)
        batch = next(iter(loader))

        with torch.no_grad():
            output = model(**batch)

        self.assertEqual(tuple(output["logit"].shape), (2, 5, 3))
+
+
# Allow running this test module directly (python tests/core/test_watchsleepnet.py).
if __name__ == "__main__":
    unittest.main()