WIP: Add stratified split feature to model_selection.train_test_split#635
WIP: Add stratified split feature to model_selection.train_test_split#635chauhankaranraj wants to merge 1 commit into
Conversation
TomAugspurger
left a comment
There was a problem hiding this comment.
Can you say a bit about the high-level strategy here, and the challenges?
Stepping back, the idea behind stratify is to get approximately the same frequency of each class in the output splits as in the input? So do we absolutely require computing? At the very least, I think we'll need a full pass over the data to compute the frequencies in stratify, since the data may not be shuffled ahead of time. But can that pass be delayed until .compute time rather than when we construct the graph?
|
My two cents.
|
|
I don't think we would want to order by class. Does scikit-learn do that?
…On Thu, Jun 18, 2020 at 9:58 AM austinzh ***@***.***> wrote:
My two cents.
1. classes can be optional because computing classes from an
out-of-core dataset, outside train test split will cost the same.
2. If we split classes by classes, does it mean the return train, test
datasets are ordered by classes?
—
You are receiving this because you commented.
Reply to this email directly, view it on GitHub
<#635 (comment)>, or
unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAKAOIQKBAVFY55235RQT4DRXITQXANCNFSM4L4RTUDQ>
.
|
|
Yes. But if we check But If we use the same parameter for scikit-learn's train_test_split, the output will be shuffled. For example, I run this on un-shuffled, iris.csv. output1: output2
|
| train_test_pairs.append( | ||
| [dd.concat(arr_train_slices), dd.concat(arr_test_slices)] | ||
| ) |
There was a problem hiding this comment.
For the output is not ordered by classes. I think we needs to add some kind of shuffle here.
| train_test_pairs.append( | |
| [dd.concat(arr_train_slices), dd.concat(arr_test_slices)] | |
| ) | |
| train = dd.concat(arr_train_slices) | |
| test = dd.concat(arr_test_slices) | |
| train = train.shuffle(train.index) | |
| test = test.shuffle(test.index) | |
| # concat all train subdfs as 1 train df, same for test | |
| train_test_pairs.append([train, test]) |
hsteinshiromoto
left a comment
There was a problem hiding this comment.
Hi,
I am a new guy and I had a question / suggestion to improve the code.
Best,
|
|
||
| types = set(type(arr) for arr in arrays) | ||
|
|
||
| if stratify is not None: |
There was a problem hiding this comment.
Quick question: why you are not using if stratify: ?
|
What is the current status of this? It's 2026 and this appears to still be open, but it would be REALLY nice to have. |
Squashed: - Fix linting errors.
|
Hi dask community, Apologies for being MIA, had a busy stretch at my day job (we were launching a new AWS product, S3 Files). Finally got some time to work on this PR this weekend with some LLM help. Would really appreciate a re-review whenever you have cycles! 🙏
The high level approach is as follows
Note that everything stays lazy until Please lmk what y'all think! |
I took a stab at implementing a solution for issue #535
Adding a WIP label because currently the stratified split is not completely lazily for dask arrays (
compute_chunk_sizesbeing called here). Nonetheless, I think it works fine for dask series and dataframes.Any feedback would be appreciated :)