Skip to content

feat: add --balance-by-flops for FLOPs-balanced micro-batching#2017

Open
HaoDong0027 wants to merge 4 commits into
THUDM:mainfrom
HaoDong0027:balance
Open

feat: add --balance-by-flops for FLOPs-balanced micro-batching#2017
HaoDong0027 wants to merge 4 commits into
THUDM:mainfrom
HaoDong0027:balance

Conversation

@HaoDong0027
Copy link
Copy Markdown

@HaoDong0027 HaoDong0027 commented Jun 4, 2026

Summary

This PR adds --balance-by-flops for FLOPs-aware micro-batch partitioning in dynamic batching. Instead of balancing by token count alone, it uses coeff * L + L² per sample to account for the quadratic cost of attention. The linear coefficient is auto-computed from model config (hidden_size, FFN, SwiGLU, MoE experts/topk/shared).

Changes:

  • Add --balance-by-flops flag and auto-compute workload_coeff in arguments.py.
  • Replace first-fit token packing with Karmarkar-Karp FLOPs-balanced partitioning for micro-batch grouping and DP rank assignment in dp_schedule.py.
  • Add calculate_workload() in seqlen_balancing.py.

Motivation

Dynamic batching packs micro-batches by token count (Σ L), but attention FLOPs scale as O(L²). When sequence lengths vary widely, a rank with a few long sequences has much higher compute cost than one with many short sequences. Since all ranks synchronize at all-reduce, the slowest rank determines step time.

--balance-by-flops uses coeff * L + L² as the balancing metric. This operates purely at the micro-batch partitioning level and does not affect downstream context-parallel splitting, so it composes directly with CP.

Experiments

All experiments compare --balance-by-flops (with --balance-data) against a baseline using only --balance-data (token-sum KK balancing). Both use --use-dynamic-batch-size. Statistics exclude warmup steps.

Setup

SFT Dense SFT MoE RL Dense RL MoE
GPU 4×L20Z 8×L20Z 8×L20Z 8×L20Z
Model Qwen3-4B Qwen3.5-35B-A3B Qwen3-4B Qwen3-30B-A3B
Type Dense MoE Dense MoE
Task SFT SFT RL (GRPO) RL (GRPO)
TP / PP / EP / DP 1 / 1 / — / 4 2 / 1 / 8 / 4 2 / 1 / — / 4 4 / 1 / 8 / 2
max_tokens_per_gpu 9,216 8,192 9,216 20,480
global_batch_size 128 128 256 256
max_response_len 8,192 8,192

Results

SFT Dense SFT MoE RL Dense RL MoE
actor_train_time Δ −12.7% −30.1% −14.0% −20.6%
actor_train_tflops Δ +13.7% +23.4% +4.5% +23.7%
actor_train_tok/s Δ +13.7% +23.4% +6.0% +24.0%

Notes

  • --balance-by-flops requires --use-dynamic-batch-size.
  • The flag is orthogonal to --balance-data; when both are set, FLOPs-based weights are used for both micro-batch packing and DP rank assignment.

@HaoDong0027
Copy link
Copy Markdown
Author

@huang3eng Please help to review.

@huang3eng
Copy link
Copy Markdown
Contributor

LGTM~ @zhuzilin

Comment thread slime/utils/arguments.py Outdated
else:
d_ff = ffn
ffn_mul = 3 if getattr(args, "swiglu", False) else 2
args.workload_coeff = 2 * h + ffn_mul * d_ff // 2
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we reuse the calculate_fwd_flops function here?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe feasible,I'll test it on Qwen3-30B-A3B and report back.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We ran an ablation on Qwen3-30B-A3B (RL GRPO, TP=4, EP=8, DP=2) comparing baseline (token-sum KK), lightweight approximation (coeff*L + L²), and exact FLOPs (calculate_fwd_flops):

Metric Baseline (token-sum) Approx (coeff*L+L²) Exact (calculate_fwd_flops)
actor_train_time 119.48s 94.80s (-20.7%) 94.11s (-21.2%)
actor_train_tflops 49.58 61.26 (+23.5%) 59.79 (+20.6%)
actor_train_tok/s 14,535 18,006 (+23.9%) 18,633 (+28.2%)

Since the approximate and exact methods provide nearly identical speedups over the baseline, we refactored the implementation to reuse calculate_fwd_flops directly rather than introducing an additional workload estimation function. Please see the latest commit for details.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants