WIP: Add stratified split feature to model_selection.train_test_split#635

Open
chauhankaranraj wants to merge 1 commit into dask:main from chauhankaranraj:master

Conversation

@chauhankaranraj
Contributor

@chauhankaranraj chauhankaranraj commented Apr 3, 2020

I took a stab at implementing a solution for issue #535

Adding a WIP label because the stratified split is currently not completely lazy for dask arrays (compute_chunk_sizes is called here). Nonetheless, I think it works fine for dask series and dataframes.

Any feedback would be appreciated :)

Member

@TomAugspurger TomAugspurger left a comment

Can you say a bit about the high-level strategy here, and the challenges?

Stepping back, the idea behind stratify is to get approximately the same frequency of each class in the output splits as in the input? So do we absolutely require computing? At the very least, I think we'll need a full pass over the data to compute the frequencies in stratify, since the data may not be shuffled ahead of time. But can that pass be delayed until .compute time rather than when we construct the graph?
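
As a rough illustration of the question above, here is a minimal sketch of a frequency pass that is only scheduled at graph-construction time and runs at .compute time. It is not the PR's implementation: it assumes integer-coded labels and a number of classes known up front (3 here), and the variable names are made up for the example.

```python
import numpy as np
import dask.array as da

# Un-shuffled labels: 50 rows of class 0, 30 of class 1, 20 of class 2.
labels = np.repeat([0, 1, 2], [50, 30, 20])
y = da.from_array(labels, chunks=25)

# Building the graph does not touch the data yet.
counts = da.bincount(y, minlength=3)  # lazy per-class counts
freqs = counts / counts.sum()         # still lazy

# The full pass over the labels happens only here.
result = freqs.compute()
print(result)
```

The catch is that the number of rows to draw per class depends on these counts, so anything that needs concrete sizes while the graph is being built still forces an early compute.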

Comment thread dask_ml/model_selection/_split.py Outdated
@austinzh

My two cents.

  1. classes can be optional, because computing the classes from an out-of-core dataset outside train_test_split will cost the same.
  2. If we split class by class, does that mean the returned train and test datasets are ordered by class?

@TomAugspurger
Member

TomAugspurger commented Jun 18, 2020 via email

@austinzh

austinzh commented Jun 18, 2020

Yes. But if we look at the for ci in classes: loop, we find that we split class by class and then concatenate the pieces back together.
That implies the returned array, for example the train set, looks like
[randomized_classA, randomized_classB, randomized_classC], meaning that in this PR's implementation, rows of the same class stick together.

But if we use the same parameters with scikit-learn's train_test_split, the output will be shuffled.
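
To make the point concrete without dask, here is a small NumPy-only sketch of the pattern described above. It is an illustration under my own assumptions (a 20% test fraction, 50 rows per class), not the PR's code: splitting per class and concatenating gives a class-ordered train set, and a final permutation restores the mix.

```python
import numpy as np

rng = np.random.default_rng(0)
# Un-shuffled labels, 50 rows per class.
y = np.repeat(["setosa", "versicolor", "virginica"], 50)

train_parts = []
for cls in np.unique(y):
    idx = np.flatnonzero(y == cls)     # row positions of this class
    rng.shuffle(idx)                   # randomize within the class
    n_test = int(round(0.2 * idx.size))
    train_parts.append(idx[n_test:])   # keep the rest for train

train_idx = np.concatenate(train_parts)
train_grouped = y[train_idx]                 # all setosa, then versicolor, ...
train_mixed = y[rng.permutation(train_idx)]  # same rows, classes interleaved
```

Both arrays hold the same 40 rows per class; only the order differs.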

For example, I ran this on the un-shuffled iris.csv.
output1 is the output of sklearn's train, test = ms.train_test_split(df, test_size=0.2, random_state=0, shuffle=True, stratify=df['species'])
output2 is the output of this PR. I only print the species column.

output1:

setosa
setosa
setosa
setosa
versicolor
setosa
virginica
virginica
versicolor
virginica
virginica
versicolor
setosa
versicolor
virginica
virginica
setosa
versicolor
versicolor
setosa
virginica
setosa
setosa
virginica
virginica
versicolor
versicolor
setosa
virginica
virginica
versicolor
versicolor
setosa
virginica
virginica
versicolor
virginica
versicolor
virginica
versicolor
versicolor
versicolor
setosa
setosa
versicolor
versicolor
virginica
virginica
versicolor
setosa
virginica
virginica
setosa
setosa
versicolor
versicolor
setosa
setosa
versicolor
virginica
setosa
setosa
versicolor
versicolor
virginica
versicolor
virginica
setosa
setosa
virginica
versicolor
versicolor
setosa
setosa
virginica
versicolor
virginica
setosa
versicolor
virginica
virginica
versicolor
virginica
setosa
versicolor
setosa
setosa
virginica
virginica
versicolor
virginica
setosa
setosa
setosa
setosa
setosa
versicolor
versicolor
versicolor
virginica
setosa
virginica
setosa
virginica
setosa
versicolor
versicolor
versicolor
versicolor
setosa
virginica
virginica
setosa
versicolor
versicolor
virginica
setosa
virginica
virginica
virginica

output2

setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
setosa
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
versicolor
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica
virginica

TomAugspurger (Jun 18, 2020, via email):

I don't think we would want to order by class. Does scikit-learn do that?

On Thu, Jun 18, 2020 at 9:58 AM austinzh wrote:

  My two cents.

    1. classes can be optional because computing classes from an out-of-core dataset, outside train test split will cost the same.
    2. If we split classes by classes, does it mean the return train, test datasets are ordered by classes?

Comment thread dask_ml/model_selection/_split.py Outdated
Comment on lines +535 to +537
train_test_pairs.append(
[dd.concat(arr_train_slices), dd.concat(arr_test_slices)]
)

So that the output is not ordered by class, I think we need to add some kind of shuffle here.

Suggested change
- train_test_pairs.append(
-     [dd.concat(arr_train_slices), dd.concat(arr_test_slices)]
- )
+ train = dd.concat(arr_train_slices)
+ test = dd.concat(arr_test_slices)
+ train = train.shuffle(train.index)
+ test = test.shuffle(test.index)
+ # concat all train subdfs as 1 train df, same for test
+ train_test_pairs.append([train, test])

Base automatically changed from master to main February 2, 2021 03:43

@hsteinshiromoto hsteinshiromoto left a comment

Hi,

I am a new guy and I had a question / suggestion to improve the code.

Best,


types = set(type(arr) for arr in arrays)

if stratify is not None:

Quick question: why are you not using if stratify: ?
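
One likely reason, sketched here with a NumPy array standing in for the stratify argument (an assumption on my part; the thread does not answer the question): truth-testing an array with more than one element raises an error, whereas the identity check against None never looks at the contents.

```python
import numpy as np

stratify = np.array([0, 1, 1, 0])

try:
    if stratify:  # ambiguous for arrays with more than one element
        pass
    truthiness_error = None
except ValueError as exc:
    truthiness_error = exc

# The explicit check only asks whether an argument was passed at all.
use_stratify = stratify is not None

print(type(truthiness_error).__name__, use_stratify)
```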

@bpsut

bpsut commented May 12, 2026

What is the current status of this? It's 2026 and this appears to still be open, but it would be REALLY nice to have.

Squashed:
- Fix linting errors.
@chauhankaranraj
Contributor Author

Hi dask community,

Apologies for being MIA, had a busy stretch at my day job (we were launching a new AWS product, S3 Files). Finally got some time to work on this PR this weekend with some LLM help. Would really appreciate a re-review whenever you have cycles! 🙏

Can you say a bit about the high-level strategy here

The high-level approach is as follows:

  1. Count how many rows of each class live in each block
  2. Sum these counts to get the global class distribution. Use test_size to decide how many test rows each class should contribute overall (keeping in mind sklearn's rule of at least one row of every class in both train and test).
  3. For each class, split its "test rows budget" across blocks in proportion to how many rows of that class each block holds.
  4. Each block picks that many rows at random per class. Slice every input array by those indices and concatenate the pieces.
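
Steps 2 to 4 above could look roughly like the following NumPy sketch. This is my reading of the description, not the code in the PR: the function names allocate_test_rows and pick_test_rows are hypothetical, labels are assumed to be integer-coded as 0..n_classes-1, every class is assumed to have at least two rows, and the rounding rules (round to nearest for the budget, largest remainder across blocks) are choices made for the example.

```python
import numpy as np

def allocate_test_rows(counts, test_size):
    """Given the (n_blocks, n_classes) row-count matrix, return how many
    test rows each block should contribute for each class."""
    counts = np.asarray(counts, dtype=np.int64)
    class_totals = counts.sum(axis=0)  # step 2: global class distribution

    # Per-class test budget, clipped so that every class keeps at least
    # one row in train and one in test (assumes class_totals >= 2).
    budget = np.rint(class_totals * test_size).astype(np.int64)
    budget = np.clip(budget, 1, class_totals - 1)

    # Step 3: proportional share per block, in integer arithmetic.
    scaled = counts * budget
    alloc = scaled // class_totals
    remainder = scaled % class_totals

    # Hand out the rows lost to flooring, largest remainder first.
    shortfall = budget - alloc.sum(axis=0)
    for c in range(counts.shape[1]):
        top = np.argsort(-remainder[:, c], kind="stable")[: shortfall[c]]
        alloc[top, c] += 1
    return alloc

def pick_test_rows(block_labels, n_per_class, rng):
    """Step 4 for one block: draw n_per_class[c] random row positions
    of class c and return them sorted."""
    picked = [
        rng.permutation(np.flatnonzero(block_labels == c))[:n]
        for c, n in enumerate(n_per_class)
    ]
    return np.sort(np.concatenate(picked))

# Three blocks, two classes.
counts = np.array([[7, 3], [5, 4], [3, 3]])
alloc = allocate_test_rows(counts, test_size=0.25)

rng = np.random.default_rng(0)
block0 = np.array([0] * 7 + [1] * 3)  # labels held by the first block
test_rows = pick_test_rows(block0, alloc[0], rng)
```

Each column of alloc sums to that class's test budget, and no entry exceeds the number of rows the block actually holds, so every block can satisfy its share locally.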

Note that everything else stays lazy until .compute(). Only the small row-count matrix of shape (n_blocks, n_classes) is computed and brought into memory up front, which to me seems like a fair trade-off for the split accuracy.
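
For a sense of scale, the row-count matrix mentioned above can be built with one small task per block. The sketch below is an assumption about how that might be done, not the PR's code; it again uses integer-coded labels and a known n_classes.

```python
import numpy as np
import dask
import dask.array as da

y = da.from_array(np.array([0, 0, 1, 2, 1, 0, 2, 2, 1, 0]), chunks=4)
n_classes = 3

# One delayed bincount per block; nothing runs yet.
per_block = [
    dask.delayed(np.bincount)(block, minlength=n_classes)
    for block in y.to_delayed()
]

# The only eager step: materialize the (n_blocks, n_classes) matrix.
counts = np.stack(dask.compute(*per_block))
print(counts.shape)
```

The matrix has one row per block and one column per class, so its size is independent of the number of rows in the dataset.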

Please lmk what y'all think!

5 participants