perf(ppo): reduce log-prob + entropy cross-entropy peak memory#2011
Open
Mantissagithub wants to merge 3 commits into
Open
perf(ppo): reduce log-prob + entropy cross-entropy peak memory#2011Mantissagithub wants to merge 3 commits into
Mantissagithub wants to merge 3 commits into
Conversation
The fused backward materialized three new full-vocab [T,1,V] fp32 tensors
(log_softmax, entropy_grad, grad_input) on top of the saved softmax, so the
unchunked path peaked ~4x [T,V] — worse than the pre-fix two-pass code
(H200 B=16384 V=151936: 60.3 vs 51.1 GiB).
Reuse the saved softmax buffer in place as the gradient (mirroring Megatron's
VocabParallelCrossEntropy.calculate_gradients): at most one extra full-vocab
temp, and only when entropy gradient flows. Return a bf16 grad like Megatron's
CE (drops the bf16+fp32 mix).
H200 peak (with_entropy, backward), now <= main at every chunk size:
chunk -1 : 51.1 -> 32.5 GiB (was 60.3, a regression)
chunk 1024: 38.0 -> 27.8 GiB
chunk 512 : 37.6 -> 27.8 GiB
Backward verified vs naive autograd within bf16 tol (TP=1 and TP=2). Also drop
a stale INVESTIGATION.md reference in the clone-guard comment.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What this is
A memory optimization for the log-prob/entropy cross-entropy in
slime/utils/ppo_utils.py. It:softmaxbuffer in place (no extra full-vocab temps), mirroringVocabParallelCrossEntropy.calculate_gradients;_clone_if_grad_tracked), keeping the no-grad ref/old-logprob path lean and fixing aview … modified inplacecrash on the differentiable path;--log-probs-chunk-sizeintoexamples/retool/retool_qwen3_4b_rl.sh.Scope — this is a mitigation, not the whole fix for #1951
I want to be upfront: this PR does not, on its own, resolve the #1951 OOM. Digging into that issue (details there), the failing
58.15 GiBallocation is a single[T, 76032]fp32 tensor withT ≈ 205,280— i.e. anactor_trainmicro-batch ~20× larger than--max-tokens-per-gpu 9216. That happens because the reporter is onslimerl/slime:v0.2.4, which predates the token-balanced training scheduler (slime/utils/dp_schedule.py). The real fix for that OOM is boundingT, which currentmainalready does.So this PR is best viewed as complementary headroom for the cross-entropy once
Tis bounded — not the #1951 fix. I've usedRelates to #1951rather thanFixesfor that reason. Happy to drop the retool example change if you'd prefer this stay a pureppo_utilsoptimization.Changes
slime/utils/ppo_utils.py— fused_VocabParallelLogProbsAndEntropy+ in-place backward +_clone_if_grad_tracked.examples/retool/retool_qwen3_4b_rl.sh— set--log-probs-chunk-size 1024.tests/test_logprob_entropy_fused.py(+ CI registration) — forward parity, chunked==unchunked, bf16-tolerance backward, and a TP=2 gloo gradcheck vs a full-vocab reference.tools/repro_1951.py— synthetic single-rank peak-memory reproducer.slime/backends/megatron_utils/model.py— opt-inSLIME_MEM_PROBEpeak logger (env-gated; happy to drop if it's noise).Measurements (1×H200, real
megatron-core, fp32,--with-entropy --backward)Where it helps — fixed
T, per-card vocabV=76032, sweeping--log-probs-chunk-size:-1(off)1024512And where it doesn't — sweeping the micro-batch token count
T(chunk1024):T[T,76032]So it roughly halves the cross-entropy peak in the regime that fits, and turns the ~29 GiB-class step from OOM into a survivable one — but at
T≈205kit still OOMs, because the logits tensor itself is 58 GiB. No cross-entropy change can help that; boundingTis what does.Correctness
Backward matches naïve autograd within bf16 tolerance (rel ≲ 0.002) at TP=1 (entropy-active and entropy-zero) and via the TP=2 gloo gradcheck.
pytest tests/test_logprob_entropy_fused.py tests/test_chunked_gae.pypasses; pre-commit clean. The bf16 grad return matches Megatron's fused CE.Notes / limitations
loss_maskbefore CE to shrinkTitself.Relates to #1951.