From e92d672be2ed0024d76a2f34805ca2746b70e3df Mon Sep 17 00:00:00 2001 From: Liudeng Zhang Date: Fri, 15 May 2026 13:38:36 -0500 Subject: [PATCH 1/2] guard torch.cuda.synchronize for CPU/MPS eval paths --- src/state/_cli/_emb/_eval.py | 5 +++-- src/state/emb/nn/eval_utils.py | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/state/_cli/_emb/_eval.py b/src/state/_cli/_emb/_eval.py index 74ddd8fb..3700e886 100644 --- a/src/state/_cli/_emb/_eval.py +++ b/src/state/_cli/_emb/_eval.py @@ -118,8 +118,9 @@ def load_config_override(config_path: str | None = None): with torch.no_grad(): with torch.autocast(device_type=device_type, dtype=precision): for batch in tqdm(dataloader, desc="Processing batches"): - torch.cuda.synchronize() - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.synchronize() + torch.cuda.empty_cache() # Compute embeddings _, _, _, emb, ds_emb = model._compute_embedding_for_batch(batch) diff --git a/src/state/emb/nn/eval_utils.py b/src/state/emb/nn/eval_utils.py index 283d2977..1df78a94 100644 --- a/src/state/emb/nn/eval_utils.py +++ b/src/state/emb/nn/eval_utils.py @@ -76,7 +76,8 @@ def evaluate_de(model, cfg, device=None, logger=print): pred_exp = model._predict_exp_for_adata( tmp_adata, cfg["validations"]["diff_exp"]["dataset_name"], cfg["validations"]["diff_exp"]["obs_pert_col"] ) - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() de_metrics = compute_gene_overlap_cross_pert( pred_exp, true_top_genes, k=cfg["validations"]["diff_exp"]["top_k_rank"] ) From a22a2bfb3ce5a2609563a0902102946abf68196d Mon Sep 17 00:00:00 2001 From: Liudeng Zhang Date: Fri, 15 May 2026 15:03:59 -0500 Subject: [PATCH 2/2] Use device_type variable for CUDA guard consistency Per review feedback: reuse the device_type variable defined above instead of calling torch.cuda.is_available() a second time. --- src/state/_cli/_emb/_eval.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/state/_cli/_emb/_eval.py b/src/state/_cli/_emb/_eval.py index 3700e886..d2313f4c 100644 --- a/src/state/_cli/_emb/_eval.py +++ b/src/state/_cli/_emb/_eval.py @@ -118,7 +118,7 @@ def load_config_override(config_path: str | None = None): with torch.no_grad(): with torch.autocast(device_type=device_type, dtype=precision): for batch in tqdm(dataloader, desc="Processing batches"): - if torch.cuda.is_available(): + if device_type == "cuda": torch.cuda.synchronize() torch.cuda.empty_cache()