diff --git a/src/state/_cli/_emb/_eval.py b/src/state/_cli/_emb/_eval.py index 74ddd8fb..d2313f4c 100644 --- a/src/state/_cli/_emb/_eval.py +++ b/src/state/_cli/_emb/_eval.py @@ -118,8 +118,9 @@ def load_config_override(config_path: str | None = None): with torch.no_grad(): with torch.autocast(device_type=device_type, dtype=precision): for batch in tqdm(dataloader, desc="Processing batches"): - torch.cuda.synchronize() - torch.cuda.empty_cache() + if device_type == "cuda": + torch.cuda.synchronize() + torch.cuda.empty_cache() # Compute embeddings _, _, _, emb, ds_emb = model._compute_embedding_for_batch(batch) diff --git a/src/state/emb/nn/eval_utils.py b/src/state/emb/nn/eval_utils.py index 283d2977..1df78a94 100644 --- a/src/state/emb/nn/eval_utils.py +++ b/src/state/emb/nn/eval_utils.py @@ -76,7 +76,8 @@ def evaluate_de(model, cfg, device=None, logger=print): pred_exp = model._predict_exp_for_adata( tmp_adata, cfg["validations"]["diff_exp"]["dataset_name"], cfg["validations"]["diff_exp"]["obs_pert_col"] ) - torch.cuda.synchronize() + if torch.cuda.is_available(): + torch.cuda.synchronize() de_metrics = compute_gene_overlap_cross_pert( pred_exp, true_top_genes, k=cfg["validations"]["diff_exp"]["top_k_rank"] )