Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""Session-wide test configuration.

Scanpy's bundled dataset loaders (``pbmc3k``, ``pbmc3k_processed``,
``pbmc68k_reduced``, ``paul15``) re-parse their h5ad/h5 from disk on *every*
call. Across this suite they are loaded ~100 times (``pbmc68k_reduced`` ~43x,
``pbmc3k`` ~50x, ``paul15`` ~4x), which is pure host-CPU/IO overhead — and
disproportionately expensive on the slow CI host.

We memoize each loader once per session and hand every caller an independent
``.copy()`` so existing tests can keep mutating their AnnData in place without
leaking state to other tests. Callers don't need to change: the loaders are
patched in place on ``scanpy.datasets`` before any test module is imported, so
both ``sc.datasets.pbmc3k()`` and ``from scanpy.datasets import pbmc3k`` pick up
the cached version.

Only the deterministic disk-backed loaders are cached. ``sc.datasets.blobs`` is
intentionally left alone (it synthesizes data per-call with varying parameters).
"""

from __future__ import annotations

import functools

import scanpy as sc

# Attribute names on ``scanpy.datasets`` that get replaced by memoized
# wrappers at import time. Each loader's return value must expose ``.copy()``,
# because the wrapper hands out a copy of the cached object on every call.
_CACHED_LOADERS = ("pbmc3k", "pbmc3k_processed", "pbmc68k_reduced", "paul15")


def _memoize_loader(loader):
cache = {}

@functools.wraps(loader)
def wrapper(*args, **kwargs):
try:
key = (args, tuple(sorted(kwargs.items())))
hash(key)
except TypeError:
# Unhashable arguments (not used in this suite) -> don't cache.
return loader(*args, **kwargs)
if key not in cache:
cache[key] = loader(*args, **kwargs)
return cache[key].copy()

return wrapper


# Swap each listed loader on ``scanpy.datasets`` for its memoized wrapper.
# This runs at conftest import time, as a module-level side effect.
for _name in _CACHED_LOADERS:
    _uncached = getattr(sc.datasets, _name)
    setattr(sc.datasets, _name, _memoize_loader(_uncached))
25 changes: 23 additions & 2 deletions tests/dask/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from dask_cuda.utils_test import IncreasedCloseTimeoutNanny


@pytest.fixture(scope="module")
@pytest.fixture(scope="session")
def cluster():
cluster = LocalCUDACluster(
CUDA_VISIBLE_DEVICES="0",
Expand All @@ -19,7 +19,28 @@ def cluster():


@pytest.fixture(scope="function")
def client(cluster):
def dist_client(cluster):
"""Real distributed client backed by a (session-scoped) LocalCUDACluster.

Only needed by tests that exercise the multi-GPU ``cuml.dask`` /
``cugraph.dask`` code paths (dask clustering, dask logreg, dense ``full``
dask PCA), which raise ``ValueError: No clients found`` without a live
distributed client. The client itself stays function-scoped so each test
gets an isolated client (connecting to the shared cluster is cheap).
"""
client = Client(cluster)
yield client
client.close()


@pytest.fixture(scope="function")
def client():
    """Placeholder ``client`` fixture that yields ``None``.

    Intended for dask tests that are scheduler-agnostic: they only construct
    dask arrays and call ``.compute()`` / ``.persist()`` on dask's default
    scheduler, and never use the client object itself. Yielding ``None``
    means no LocalCUDACluster is started for them and cupy chunks are not
    serialized through a distributed scheduler, which is pure overhead on the
    small arrays used in tests.
    """
    yield
4 changes: 2 additions & 2 deletions tests/dask/test_dask_clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


@pytest.mark.parametrize("clustering_function", [rsc.tl.leiden, rsc.tl.louvain])
def test_dask_clustering(client, clustering_function):
def test_dask_clustering(dist_client, clustering_function):
adata = pbmc3k_processed()
clustering_function(adata, use_dask=True, key_added="test_dask")
clustering_function(adata, key_added="test_no_dask")
Expand All @@ -22,7 +22,7 @@ def test_dask_clustering(client, clustering_function):

@pytest.mark.parametrize("clustering_function", [rsc.tl.leiden, rsc.tl.louvain])
@pytest.mark.parametrize("resolution", [0.1, [0.5, 1.0]])
def test_dask_clustering_resolution(client, clustering_function, resolution):
def test_dask_clustering_resolution(dist_client, clustering_function, resolution):
adata = pbmc3k_processed()
clustering_function(
adata, use_dask=True, key_added="test_dask", resolution=resolution
Expand Down
2 changes: 1 addition & 1 deletion tests/dask/test_dask_pca.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
@pytest.mark.parametrize("zero_center", [True, False])
@pytest.mark.flaky(reruns=2, reruns_delay=5)
def test_pca_dask(client, data_kind, zero_center):
def test_pca_dask(dist_client, data_kind, zero_center):
adata_1 = pbmc3k_processed()
adata_2 = pbmc3k_processed()

Expand Down
2 changes: 1 addition & 1 deletion tests/dask/test_dask_rank_logreg.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _compare_top_genes(result1, result2, top_n=10, min_overlap=9):

@pytest.mark.parametrize("data_kind", ["sparse", "dense"])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
def test_rank_genes_groups_logreg(client, data_kind, dtype):
def test_rank_genes_groups_logreg(dist_client, data_kind, dtype):
if data_kind == "dense":
adata = pbmc68k_reduced()
adata.X = adata.X.astype(dtype)
Expand Down
29 changes: 22 additions & 7 deletions tests/test_harmony.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,22 +110,27 @@ def test_harmony_integrate_bad_prune_threshold(bad_threshold):
)


@pytest.mark.filterwarnings("ignore:Harmony did not converge")
@pytest.mark.parametrize("correction_method", ["fast", "original", "batched"])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
def test_harmony_integrate(correction_method, dtype):
def test_harmony_integrate(correction_method):
"""
Test that Harmony integrate works.

This is a very simple test that just checks to see if the Harmony
integrate wrapper successfully added a new field to ``adata.obsm``
and makes sure it has the same dimensions as the original PCA table.

This is a pure shape/contract check: the output shape is independent of
dtype and iteration count, so we run float32 with a single harmony
iteration to exercise all three correction-method paths cheaply.
"""
adata = sc.datasets.pbmc68k_reduced()
rsc.pp.harmony_integrate(
adata,
"bulk_labels",
correction_method=correction_method,
dtype=dtype,
dtype=cp.float32,
max_iter_harmony=1,
)
assert adata.obsm["X_pca_harmony"].shape == adata.obsm["X_pca"].shape

Expand Down Expand Up @@ -228,6 +233,7 @@ def test_harmony_integrate_reference(
)


@pytest.mark.filterwarnings("ignore:Harmony did not converge")
@pytest.mark.parametrize("correction_method", ["original", "batched"])
@pytest.mark.parametrize("dtype", [cp.float64, cp.float32])
def test_harmony2_correction_methods_agree(
Expand All @@ -240,7 +246,7 @@ def test_harmony2_correction_methods_agree(
"donor",
correction_method=correction_method,
dtype=dtype,
max_iter_harmony=20,
max_iter_harmony=5,
)
h2 = adata.obsm["X_pca_harmony"]

Expand All @@ -251,7 +257,7 @@ def test_harmony2_correction_methods_agree(
"donor",
correction_method="fast",
dtype=dtype,
max_iter_harmony=20,
max_iter_harmony=5,
)
h2_ref = adata_ref.obsm["X_pca_harmony"]

Expand Down Expand Up @@ -450,8 +456,17 @@ def test_compute_lambda_kb_zero_denom(dtype):
cp.testing.assert_allclose(result[0, 1], dtype(1.0))


@pytest.mark.parametrize("correction_method", ["fast", "original", "batched"])
@pytest.mark.parametrize("dtype", [cp.float32, cp.float64])
@pytest.mark.parametrize(
("dtype", "correction_method"),
[
(cp.float32, "fast"),
(cp.float32, "original"),
(cp.float32, "batched"),
# float64 numeric reference for `fast` only: float64 original/batched
# agreement with `fast` is covered by test_harmony2_correction_methods_agree
(cp.float64, "fast"),
],
)
def test_harmony2_ircolitis_reference(
adata_ircolitis_harmony2, correction_method, dtype
):
Expand Down
Loading