feat: add --balance-by-flops for FLOPs-balanced micro-batching #2017
Open
HaoDong0027 wants to merge 4 commits into
Conversation
Author
@huang3eng Please help review.
Contributor
LGTM~ @zhuzilin
zhuzilin
reviewed
Jun 4, 2026
    else:
        d_ff = ffn
    ffn_mul = 3 if getattr(args, "swiglu", False) else 2
    args.workload_coeff = 2 * h + ffn_mul * d_ff // 2
Contributor
can we reuse the calculate_fwd_flops function here?
Author
That may be feasible. I'll test it on Qwen3-30B-A3B and report back.
Author
We ran an ablation on Qwen3-30B-A3B (RL GRPO, TP=4, EP=8, DP=2) comparing baseline (token-sum KK), lightweight approximation (coeff*L + L²), and exact FLOPs (calculate_fwd_flops):
| Metric | Baseline (token-sum) | Approx (coeff*L+L²) | Exact (calculate_fwd_flops) |
|---|---|---|---|
| actor_train_time | 119.48s | 94.80s (-20.7%) | 94.11s (-21.2%) |
| actor_train_tflops | 49.58 | 61.26 (+23.5%) | 59.79 (+20.6%) |
| actor_train_tok/s | 14,535 | 18,006 (+23.9%) | 18,633 (+28.2%) |
Since the approximate and exact methods provide nearly identical speedups over the baseline, we refactored the implementation to reuse calculate_fwd_flops directly rather than introducing an additional workload estimation function. Please see the latest commit for details.
Summary

This PR adds `--balance-by-flops` for FLOPs-aware micro-batch partitioning in dynamic batching. Instead of balancing by token count alone, it uses `coeff * L + L²` per sample to account for the quadratic cost of attention. The linear coefficient is auto-computed from the model config (hidden_size, FFN, SwiGLU, MoE experts/topk/shared).

Changes:
- `--balance-by-flops` flag and auto-compute `workload_coeff` in `arguments.py`.
- `dp_schedule.py`.
- `calculate_workload()` in `seqlen_balancing.py`.

Motivation
Dynamic batching packs micro-batches by token count (`Σ L`), but attention FLOPs scale as `O(L²)`. When sequence lengths vary widely, a rank with a few long sequences has a much higher compute cost than one with many short sequences. Since all ranks synchronize at all-reduce, the slowest rank determines step time.

`--balance-by-flops` uses `coeff * L + L²` as the balancing metric. This operates purely at the micro-batch partitioning level and does not affect downstream context-parallel splitting, so it composes directly with CP.

Experiments
All experiments compare `--balance-by-flops` (with `--balance-data`) against a baseline using only `--balance-data` (token-sum KK balancing). Both use `--use-dynamic-batch-size`. Statistics exclude warmup steps.

Setup
Results
Notes
- `--balance-by-flops` requires `--use-dynamic-batch-size`.
- `--balance-data`: when both are set, FLOPs-based weights are used for both micro-batch packing and DP rank assignment.
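
For illustration, the effect of the `coeff * L + L²` metric described under Motivation can be sketched in a few lines of Python. This is a toy example and not the code in this PR: `workload`, `greedy_partition`, and `max_load` are hypothetical helper names, `COEFF = 1024` is an arbitrary placeholder rather than a value derived from a real model config, a simple greedy heaviest-first partitioner stands in for the KK partitioner, and the per-micro-batch token limit of dynamic batching is ignored.

```python
import heapq

COEFF = 1024  # hypothetical linear coefficient, NOT derived from a real model config


def workload(seq_len: int, coeff: int = COEFF) -> int:
    """Approximate per-sample cost: linear term plus quadratic attention term."""
    return coeff * seq_len + seq_len * seq_len


def greedy_partition(seq_lens, num_bins, cost):
    """Heaviest-first greedy: give each sample to the currently lightest bin.

    This is a stand-in for the real partitioner, used only to show how the
    choice of `cost` changes the resulting split.
    """
    heap = [(0, b) for b in range(num_bins)]  # (accumulated cost, bin index)
    bins = [[] for _ in range(num_bins)]
    for length in sorted(seq_lens, key=cost, reverse=True):
        total, b = heapq.heappop(heap)
        bins[b].append(length)
        heapq.heappush(heap, (total + cost(length), b))
    return bins


def max_load(bins):
    """Workload of the slowest bin, which bounds the step time."""
    return max(sum(workload(length) for length in group) for group in bins)


seq_lens = [6144, 4096, 2048, 2048, 2048]

# Balance by token count (cost = sequence length).
by_tokens = greedy_partition(seq_lens, 2, cost=lambda length: length)
# Balance by the approximate FLOPs metric (cost = coeff * L + L^2).
by_flops = greedy_partition(seq_lens, 2, cost=workload)

print("token-balanced:", by_tokens, "max workload:", max_load(by_tokens))
print("flops-balanced:", by_flops, "max workload:", max_load(by_flops))
```

With these toy inputs, the token-balanced split puts 8192 tokens in each bin, but the bin holding the 6144-token sequence has a workload of 50,331,648. The workload-balanced split leaves the long sequence alone in its bin and lowers the maximum workload to 44,040,192, at the price of unequal token counts (6144 vs. 10240).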