diff --git a/dask_ml/model_selection/_split.py b/dask_ml/model_selection/_split.py index ed1ac30ca..a8e5e598b 100644 --- a/dask_ml/model_selection/_split.py +++ b/dask_ml/model_selection/_split.py @@ -9,6 +9,7 @@ import dask.array as da import dask.dataframe as dd import numpy as np +import pandas as pd import sklearn.model_selection as ms from sklearn.model_selection._split import BaseCrossValidator, _validate_shuffle_split from sklearn.utils import check_random_state @@ -358,12 +359,360 @@ def _blockwise_slice(arr, idx): return sliced +def _largest_remainder_split(bin_sizes, total, rng): + """Split ``total`` units across bins, proportional to ``bin_sizes``. + + Returns integers summing to ``total``. Each bin gets its floor share + ``bin_size * total / sum(bin_sizes)``; the leftover slots go to bins + with the largest fractional remainder. ``rng`` breaks ties. + + See https://en.wikipedia.org/wiki/Largest_remainders_method. + """ + pool = int(bin_sizes.sum()) + if pool == 0 or total == 0: + return np.zeros_like(bin_sizes) + products = bin_sizes * int(total) + quotas = products // pool + remainders = products % pool + leftover = int(total) - int(quotas.sum()) + tiebreaker = rng.random(len(bin_sizes)) + priority = np.lexsort((tiebreaker, -remainders)) + quotas[priority[:leftover]] += 1 + return quotas + + +def _get_test_count_per_class(n_per_class, n_test, rng): + """How many test rows each class contributes, summing to ``n_test``. + + Counts are proportional to class size via largest-remainder rounding. + sklearn rule: every class keeps at least one row in both train and + test splits. Over/underflow from that clamp is redistributed to + classes that still have headroom. + """ + if (n_per_class < 2).any(): + raise ValueError( + "The least populated class in y has only 1 member, which is too " + "few. The minimum number of groups for any class cannot be less " + "than 2." + ) + quotas = _largest_remainder_split(n_per_class, n_test, rng) + floor = np.ones_like(quotas) + ceil = n_per_class - 1 + for _ in range(len(quotas) + 1): + clamped = np.clip(quotas, floor, ceil) + delta = n_test - int(clamped.sum()) + if delta == 0: + return clamped + has_room = clamped < ceil if delta > 0 else clamped > floor + eligible = np.where(has_room)[0] + if len(eligible) == 0: + return clamped + step = 1 if delta > 0 else -1 + for class_idx in rng.permutation(eligible)[: abs(delta)]: + clamped[class_idx] += step + quotas = clamped + return quotas + + +def _check_no_missing_labels(arr): + """Raise ValueError if ``arr`` contains NaN, None, or pd.NA. + + ``np.unique`` treats every NaN as distinct (NaN != NaN), so silently + splitting on missing labels would create phantom classes. sklearn + refuses such inputs; we do the same. ``pd.isna`` handles float NaN, + object-dtype None, and pandas NA in one call. + + Void-dtype arrays are packed compound labels from ``_as_1d_block`` and + can't be NaN by construction; skip them (and ``pd.isna`` rejects void). + """ + if arr.dtype.kind == "V": + return + if pd.isna(arr).any(): + raise ValueError( + "Input contains NaN/None/NA. `stratify` must not contain missing values." + ) + + +def _get_class_freq_in_block(block): + """Return the unique class ids and their counts in each stratify block.""" + block = np.asarray(block).ravel() + _check_no_missing_labels(block) + return np.unique(block, return_counts=True) + + +def _get_test_count_per_class_block(class_freq_per_block, test_size, seed): + """How many test rows each block contributes for each class. + + Runs on the driver as a single ``dask.delayed`` task once all per-block + summaries arrive. Two steps: + + 1. Split ``n_test`` across classes proportional to class size, with + sklearn's ">=1 row per class in both splits" rule + (``_get_test_count_per_class``). + 2. For each class, split its test-row budget across blocks proportional + to how many rows of that class each block holds + (``_largest_remainder_split``). + + Parameters + ---------- + class_freq_per_block : list of (numpy.ndarray, numpy.ndarray) + One ``(classes_in_block, counts_in_block)`` tuple per block, from + ``_get_class_freq_in_block``. + test_size : float or int + Fraction (0 < x < 1) or absolute count of test rows. + seed : int + Seed for the driver's ``np.random.Generator`` (tie-breaking only). + + Returns + ------- + (numpy.ndarray, numpy.ndarray) + ``(class_ids, test_counts)`` where ``test_counts[block, class]`` + is the number of rows of ``class_ids[class]`` that ``block`` puts + into the test set. Shape: ``(n_blocks, n_classes)``. + """ + class_ids = np.unique(np.concatenate([c for c, _ in class_freq_per_block])) + n_blocks = len(class_freq_per_block) + n_classes = len(class_ids) + n_per_class_per_block = np.zeros((n_blocks, n_classes), dtype=np.int64) + for bi, (ci, fi) in enumerate(class_freq_per_block): + n_per_class_per_block[bi, np.searchsorted(class_ids, ci)] = fi + + n_per_class = n_per_class_per_block.sum(axis=0) + n_total = n_per_class.sum() + if isinstance(test_size, (int, np.integer)) and not isinstance(test_size, bool): + n_test = int(test_size) + else: + n_test = max(1, min(n_total - 1, int(round(n_total * float(test_size))))) + + rng = np.random.default_rng(seed) + test_count_per_class = _get_test_count_per_class(n_per_class, n_test, rng) + test_count_per_class_block = np.column_stack( + [ + _largest_remainder_split( + n_per_class_per_block[:, ci], test_count_per_class[ci], rng + ) + for ci in range(n_classes) + ] + ) + return class_ids, test_count_per_class_block + + +def _get_train_test_indices_per_block( + block, test_count_per_class_block, block_idx, seed, shuffle +): + """Pick this block's train/test row indices. + + For each class, picks ``test_counts[block_idx, class]`` rows of that + class uniformly without replacement. The per-block RNG is seeded with + ``(seed, block_idx)`` so blocks sample independently and + reproducibly. When ``shuffle`` is True, the returned index arrays are + permuted so output row order is randomized within each block. + """ + class_ids, test_counts = test_count_per_class_block + rng = np.random.default_rng((int(seed), int(block_idx))) + strat_here = np.asarray(block).ravel() + is_test = np.zeros(len(strat_here), dtype=bool) + for class_id, rows_to_pick in zip(class_ids, test_counts[block_idx]): + rows_to_pick = int(rows_to_pick) + if rows_to_pick == 0: + continue + positions_of_class = np.where(strat_here == class_id)[0] + is_test[rng.choice(positions_of_class, size=rows_to_pick, replace=False)] = True + test_idx = np.flatnonzero(is_test) + train_idx = np.flatnonzero(~is_test) + if shuffle: + test_idx = rng.permutation(test_idx) + train_idx = rng.permutation(train_idx) + return train_idx, test_idx + + +def _slice_block(block, indices_pair, take_test): + """Return rows of ``block`` at the train or test indices. + + Module-level so dask can pickle it across workers. + """ + rows = np.asarray(block) + return rows[indices_pair[1] if take_test else indices_pair[0]] + + +def _as_1d_block(block): + """Collapse each row of a 2-D stratify block into a single hashable scalar. + + For compound stratification (e.g. multilabel ``y``), sklearn treats each + unique tuple of row values as one class. Numpy's ``np.void`` view + interprets each row's bytes as a single scalar - zero-copy and orderable, + so ``np.unique`` / ``np.searchsorted`` group identical tuples together. + """ + block = np.ascontiguousarray(block) + void_dtype = np.dtype((np.void, block.dtype.itemsize * block.shape[1])) + return block.view(void_dtype).ravel() + + +def _axis0_layout(obj): + """Axis-0 ``(block_count, chunks_or_None)`` for a dask collection. + + ``chunks`` is ``None`` when block sizes are unknown (e.g. ``from_delayed`` + or a dask DataFrame whose divisions don't carry block lengths). Only + block count can be compared in that case. + """ + if isinstance(obj, da.Array): + chunks = obj.chunks[0] + if any(np.isnan(c) for c in chunks): + return obj.numblocks[0], None + return obj.numblocks[0], tuple(chunks) + return obj.npartitions, None + + +def _as_input_aligned_1d_array(stratify, arrays): + """Wrap ``stratify`` as a 1-D dask Array, axis-0 aligned to ``arrays[0]``. + + Non-dask inputs are split to match the reference axis-0 chunks. Dask + inputs are validated against that layout - silently re-chunking a + user's collection would change their execution plan, and a mismatched + layout would misalign per-block label indices with feature blocks. + Compound (2-D) labels are packed row-wise into ``np.void`` scalars so + the rest of the pipeline can treat them as ordinary 1-D class ids. + """ + ref_blocks, ref_chunks = _axis0_layout(arrays[0]) + + if isinstance(stratify, (da.Array, dd.Series, dd.DataFrame)): + blocks, chunks = _axis0_layout(stratify) + if blocks != ref_blocks or ( + ref_chunks is not None and chunks is not None and chunks != ref_chunks + ): + ref_desc = ref_chunks if ref_chunks is not None else f"{ref_blocks} blocks" + obj_desc = chunks if chunks is not None else f"{blocks} blocks" + raise ValueError( + f"Axis-0 partitioning of stratify ({obj_desc}) does not match " + f"arrays[0] ({ref_desc}). Rechunk to align, e.g. " + f"`stratify.rechunk({{0: arrays[0].chunks[0]}})` for dask " + f"Arrays or `stratify.repartition(npartitions=" + f"arrays[0].npartitions)` for dask DataFrames." + ) + if isinstance(stratify, (dd.Series, dd.DataFrame)): + stratify = stratify.to_dask_array() + else: + as_numpy = np.asarray(stratify) + # Eager check: in-memory inputs cost nothing to validate now and + # give users an immediate error instead of one at .compute() time. + _check_no_missing_labels(as_numpy) + # Match arr0's axis-0 chunks exactly when known; otherwise fall back + # to an even split into ref_blocks pieces. + if ref_chunks is not None and sum(ref_chunks) == as_numpy.shape[0]: + axis0_chunks = ref_chunks + elif ref_blocks > 1 and as_numpy.shape[0] >= ref_blocks: + splits = np.array_split(as_numpy, ref_blocks, axis=0) + axis0_chunks = tuple(s.shape[0] for s in splits) + else: + axis0_chunks = (as_numpy.shape[0],) + chunks = (axis0_chunks,) + tuple((s,) for s in as_numpy.shape[1:]) + stratify = da.from_array(as_numpy, chunks=chunks) + + if stratify.ndim > 1: + stratify = stratify.rechunk({axis: -1 for axis in range(1, stratify.ndim)}) + stratify = stratify.map_blocks( + _as_1d_block, drop_axis=list(range(1, stratify.ndim)) + ) + return stratify + + +def _stratified_split(arrays, stratify, test_size, random_state, shuffle=True): + """Lazy stratified train/test split over dask collections. + + Builds a delayed graph; nothing computes until the user calls + ``.compute()`` on an output. Feature data is never gathered to the + driver, only a KB-scale ``(n_blocks, n_classes)`` counts matrix is. + + High level approach: + + 1. each stratify block emits ``(classes_in_block, counts_in_block)`` + via ``_get_class_freq_in_block`` (one delayed task per block). + 2. the driver merges those into ``(class_ids, test_counts[p, c])`` + via ``_get_test_count_per_class_block`` (one delayed task). + 3. each stratify block picks its own test rows + via ``_get_train_test_indices_per_block`` (one delayed task per block). + 4. each input array's blocks are sliced by the matching mask + via ``_slice_block``; train and test blocks are concatenated. + + Output arrays have ``nan`` for their first dimension until ``.compute()`` + is called (same as ``da.where(cond)[0]``). + + Parameters + ---------- + arrays : sequence of dask Arrays / Series / DataFrames + Inputs to split. Each must share the same axis-0 block count as + ``stratify``. + stratify : dask Array/Series/DataFrame, numpy array, pandas Series, or list + Class labels. May be 1-D or 2-D (rows treated as compound labels). + test_size : float or int + Fraction (0 < x < 1) or absolute count of test rows. + random_state : int, RandomState, or None + Seed for the split. + shuffle : bool, default True + If True, permute row order within each output block (matches + sklearn and the non-stratified dask paths). If False, preserve the + original row order within each block (rows that survive + selection appear in the order they had in the input). + + Returns + ------- + list of dask.array.Array + ``[X1_train, X1_test, X2_train, X2_test, ...]``. + """ + stratify = _as_input_aligned_1d_array(stratify, arrays) + strat_blocks = stratify.to_delayed().ravel().tolist() + + rng = check_random_state(random_state) + seed = int(draw_seed(rng, 0, _I4MAX, dtype="uint")) + + class_freq_per_block = [ + dask.delayed(_get_class_freq_in_block)(block) for block in strat_blocks + ] + test_count_per_class_block = dask.delayed(_get_test_count_per_class_block)( + class_freq_per_block, test_size, seed + ) + train_test_idx_per_block = [ + dask.delayed(_get_train_test_indices_per_block)( + block, test_count_per_class_block, block_idx, seed, shuffle + ) + for block_idx, block in enumerate(strat_blocks) + ] + + outputs = [] + for arr in arrays: + if isinstance(arr, (dd.Series, dd.DataFrame)): + arr = arr.to_dask_array() + if arr.ndim > 1: + arr = arr.rechunk({axis: -1 for axis in range(1, arr.ndim)}) + feature_blocks = arr.to_delayed().ravel().tolist() + trailing_shape = arr.shape[1:] + train_blocks = [ + da.from_delayed( + dask.delayed(_slice_block)(block, idx, False), + shape=(np.nan, *trailing_shape), + dtype=arr.dtype, + ) + for block, idx in zip(feature_blocks, train_test_idx_per_block) + ] + test_blocks = [ + da.from_delayed( + dask.delayed(_slice_block)(block, idx, True), + shape=(np.nan, *trailing_shape), + dtype=arr.dtype, + ) + for block, idx in zip(feature_blocks, train_test_idx_per_block) + ] + outputs += [da.concatenate(train_blocks), da.concatenate(test_blocks)] + return outputs + + def train_test_split( *arrays, test_size=None, train_size=None, random_state=None, shuffle=None, + stratify=None, blockwise=True, convert_mixed_types=False, **options, @@ -372,18 +721,33 @@ def train_test_split( Parameters ---------- - *arrays : Sequence of Dask Arrays, DataFrames, or Series + ``*arrays`` : Sequence of Dask Arrays, DataFrames, or Series Non-dask objects will be passed through to :func:`sklearn.model_selection.train_test_split`. + test_size : float or int, default 0.1 + train_size : float or int, optional + random_state : int, RandomState instance or None, optional (default=None) If int, random_state is the seed used by the random number generator; If RandomState instance, random_state is the random number generator; If None, the random number generator is the RandomState instance used by `np.random`. + shuffle : bool, default None Whether to shuffle the data before splitting. + + stratify : array-like, optional (default=None) + If not None, data is split in a stratified fashion using this as + the class labels (sklearn semantics). Stays lazy: index computation + runs inside a delayed task, triggered only when ``.compute()`` is + called on a split output. The ``*arrays`` themselves are never + gathered to the driver, but ``stratify`` and the resulting index + arrays do materialize on a single worker, so memory scales with the + label vector (not the feature matrix). ``stratify`` may be a dask + Array/Series/DataFrame or any array-like accepted by sklearn. + blockwise : bool, default True. Whether to shuffle data only within blocks (True), or allow data to be shuffled between blocks (False). Shuffling between blocks can @@ -399,6 +763,7 @@ def train_test_split( arrays contains a mixture of types. This results in some computation to determine the length of each block. + Returns ------- splitting : list, length=2 * len(arrays) @@ -431,6 +796,46 @@ def train_test_split( if options: raise TypeError("Unexpected options {}".format(options)) + if stratify is not None: + if ( + isinstance(stratify, bool) + or np.isscalar(stratify) + or getattr(stratify, "ndim", 1) == 0 + ): + raise TypeError( + "`stratify` must be an array of class labels (one per sample), " + "not a scalar/boolean. To stratify on the target, pass the " + "target itself: train_test_split(X, y, stratify=y). Got " + f"{type(stratify).__name__}={stratify!r}." + ) + if not hasattr(stratify, "__array__") and not isinstance( + stratify, (da.Array, dd.Series, dd.DataFrame, list, tuple) + ): + raise TypeError( + "`stratify` must be array-like (dask Array/Series/DataFrame, " + f"numpy array, pandas Series, or list). Got {type(stratify).__name__}." + ) + + def _known_len(obj): + shape = getattr(obj, "shape", None) + if shape is not None and len(shape) > 0: + n = shape[0] + if isinstance(n, int): + return n + return None + try: + return len(obj) + except TypeError: + return None + + len_a = _known_len(arrays[0]) + len_s = _known_len(stratify) + if len_a is not None and len_s is not None and len_a != len_s: + raise ValueError( + f"Length of stratify ({len_s}) does not match length of " + f"input arrays ({len_a})." + ) + types = set(type(arr) for arr in arrays) if da.Array in types and types & {dd.Series, dd.DataFrame}: @@ -480,6 +885,12 @@ def train_test_split( f" dask versions<2.13.0. Current version is {DASK_VERSION}." ) kwargs = {} + + if stratify is not None: + return _stratified_split( + arrays, stratify, test_size, random_state, shuffle=shuffle + ) + return list( itertools.chain.from_iterable( arr.random_split([train_size, test_size], random_state=rng, **kwargs) @@ -495,6 +906,15 @@ def train_test_split( "'shuffle=False' is not currently supported for dask Arrays." ) + if stratify is not None: + if blockwise is False: + raise NotImplementedError( + "'blockwise=False' is not supported for stratified splits." + ) + return _stratified_split( + arrays, stratify, test_size, random_state, shuffle=shuffle + ) + splitter = ShuffleSplit( n_splits=1, test_size=test_size, @@ -511,10 +931,12 @@ def train_test_split( return list(itertools.chain.from_iterable(train_test_pairs)) else: - return ms.train_test_split( - *arrays, + kwargs = dict( test_size=test_size, train_size=train_size, random_state=random_state, - shuffle=shuffle, + stratify=stratify, ) + if shuffle is not None: + kwargs["shuffle"] = shuffle + return ms.train_test_split(*arrays, **kwargs) diff --git a/tests/model_selection/test_split.py b/tests/model_selection/test_split.py index b5df46fa1..6c2c74d81 100644 --- a/tests/model_selection/test_split.py +++ b/tests/model_selection/test_split.py @@ -7,6 +7,7 @@ import dask_ml.model_selection from dask_ml._compat import DASK_2130 +from dask_ml.model_selection import train_test_split X, y = make_regression(n_samples=110, n_features=5) dX = da.from_array(X, 50) @@ -259,3 +260,149 @@ def test_split_3d_data(): assert X_train.ndim == X_3d.ndim assert X_train.shape[1:] == X_3d.shape[1:] + + +def _counts(y): + arr = y.compute() if isinstance(y, (da.Array, dd.Series, dd.DataFrame)) else y + return dict(zip(*np.unique(np.asarray(arr).ravel(), return_counts=True))) + + +def test_stratify_dask_array_correctness(): + """Ratios + X shape + sum-of-splits + disjoint rows on dask.Array.""" + y_np = np.repeat([0, 1, 2], [600, 300, 100]) + X_np = np.arange(1000 * 4).reshape(1000, 4) + X = da.from_array(X_np, chunks=200) + y = da.from_array(y_np, chunks=200) + + Xtr, Xte, ytr, yte = train_test_split( + X, y, stratify=y, random_state=0, test_size=0.2 + ) + assert _counts(ytr) == {0: 480, 1: 240, 2: 80} + assert _counts(yte) == {0: 120, 1: 60, 2: 20} + Xtr_c, Xte_c = Xtr.compute(), Xte.compute() + assert Xtr_c.shape == (800, 4) and Xte_c.shape == (200, 4) + assert set(map(tuple, Xtr_c)).isdisjoint(set(map(tuple, Xte_c))) + + +def test_stratify_dask_dataframe_correctness(): + df = pd.DataFrame( + { + "a": np.arange(1000), + "b": np.arange(1000, 2000), + "label": np.repeat([0, 1, 2], [500, 300, 200]), + } + ) + ddf = dd.from_pandas(df, npartitions=5) + Xtr, Xte, ytr, yte = train_test_split( + ddf[["a", "b"]], + ddf["label"], + stratify=ddf["label"], + random_state=0, + test_size=0.2, + shuffle=True, + ) + assert _counts(ytr) == {0: 400, 1: 240, 2: 160} + assert _counts(yte) == {0: 100, 1: 60, 2: 40} + + +def test_stratify_reproducibility_and_output_order(): + """Same seed gives identical outputs. Output order interleaved, not class-grouped.""" + y_np = np.tile([0, 1, 2], 333)[:999] + X_np = np.arange(999 * 2).reshape(999, 2) + X = da.from_array(X_np, chunks=200) + y = da.from_array(y_np, chunks=200) + + a = train_test_split(X, y, stratify=y, random_state=42, test_size=0.2) + b = train_test_split(X, y, stratify=y, random_state=42, test_size=0.2) + for x, y_ in zip(a, b): + np.testing.assert_array_equal(x.compute(), y_.compute()) + + ytr = a[2].compute() + # naive per-class concat would yield [0,0,...,1,1,...,2,2,...]: 2 transitions + transitions = int(np.sum(ytr[1:] != ytr[:-1])) + assert transitions > len(ytr) // 4 + + +@pytest.mark.parametrize("as_dask", [False, True], ids=["numpy", "dask"]) +def test_stratify_2d_compound(as_dask): + """2D stratify = compound (multilabel) class labels, sklearn semantics.""" + n = 80 + col_a = np.repeat([0, 1], n // 2) + col_b = np.tile([0, 1], n // 2) + strat_np = np.column_stack([col_a, col_b]) # 4 compound classes, 20 each + strat = da.from_array(strat_np, chunks=(20, 2)) if as_dask else strat_np + X = da.from_array(np.arange(n).reshape(-1, 1), chunks=20) + + Xtr, Xte = train_test_split(X, stratify=strat, random_state=0, test_size=0.25) + tr_idx = set(Xtr.compute().ravel()) + te_idx = set(Xte.compute().ravel()) + assert len(tr_idx) == 60 and len(te_idx) == 20 + for a in (0, 1): + for b in (0, 1): + cls_rows = set(np.where((strat_np[:, 0] == a) & (strat_np[:, 1] == b))[0]) + assert len(cls_rows & tr_idx) == 15 + assert len(cls_rows & te_idx) == 5 + + +@pytest.mark.parametrize( + "stratify,extra,err,match", + [ + (True, {}, TypeError, "must be an array of class labels"), + (1, {}, TypeError, "must be an array of class labels"), + ("y", {"shuffle": False}, NotImplementedError, "shuffle=False"), + ], + ids=["bool", "scalar", "shuffle_false"], +) +def test_stratify_invalid_args(stratify, extra, err, match): + X = da.from_array(np.random.RandomState(0).random((100, 2)), chunks=50) + y = da.from_array(np.repeat([0, 1], 50), chunks=50) + kwargs = {"stratify": y if stratify == "y" else stratify} + with pytest.raises(err, match=match): + train_test_split(X, y, **kwargs, **extra) + + +def test_stratify_length_mismatch_raises(): + X = da.from_array(np.random.RandomState(0).random((100, 2)), chunks=50) + bad = da.from_array(np.repeat([0, 1], 40), chunks=40) + with pytest.raises(ValueError, match="[Ll]ength"): + train_test_split(X, stratify=bad, random_state=0) + + +@pytest.mark.parametrize( + "bad_labels,kind", + [ + (np.array([0.0, np.nan, 1.0, 1.0, 0.0]), "float-nan"), + (np.array([0, None, 1, 1, 0], dtype=object), "object-none"), + (pd.array([0, pd.NA, 1, 1, 0], dtype="Int64"), "pandas-NA"), + ], +) +def test_stratify_nan_labels_rejected(bad_labels, kind): + """In-memory stratify with missing values is rejected eagerly.""" + X = np.arange(5 * 2).reshape(5, 2) + Xd = da.from_array(X, chunks=3) + with pytest.raises(ValueError, match="NaN|NA|missing"): + train_test_split(Xd, stratify=bad_labels, random_state=0, test_size=0.4) + + +def test_stratify_nan_labels_rejected_dask_lazy(): + """Dask stratify with NaN: validation is lazy (per-block, on .compute()). + + Eager validation would require loading the label array on the driver, + which we explicitly avoid for dask inputs. Graph build succeeds; the + error surfaces when the user calls .compute(). + """ + X = da.from_array(np.arange(10 * 2).reshape(10, 2), chunks=5) + y = da.from_array(np.array([0.0, np.nan, 1.0, 1.0, 0.0] * 2), chunks=5) + Xtr, Xte = train_test_split(X, stratify=y, random_state=0, test_size=0.4) + assert isinstance(Xtr, da.Array) # graph built lazily, no error yet + with pytest.raises(ValueError, match="NaN|NA|missing"): + Xtr.compute() + + +def test_stratify_misaligned_partitions_raise(): + """Dask stratify with different partition count than input arrays.""" + rng = np.random.RandomState(0) + X = da.from_array(rng.random((1000, 2)), chunks=200) # 5 blocks + y = da.from_array(np.repeat([0, 1], 500), chunks=334) # 3 blocks + with pytest.raises(ValueError, match="[Aa]xis-0 partitioning"): + train_test_split(X, stratify=y, random_state=0, test_size=0.2)