
perf(ppo): reduce log-prob + entropy cross-entropy peak memory #2011

Open. Mantissagithub wants to merge 3 commits into THUDM:main from Mantissagithub:fix/logprob-vocab-ce-oom

@Mantissagithub Mantissagithub commented Jun 2, 2026

What this is

A memory optimization for the log-prob/entropy cross-entropy in slime/utils/ppo_utils.py. It:

  • fuses the log-prob and entropy computation into a single autograd Function, so it keeps one working copy of the logits instead of two clones;
  • makes the backward reuse Megatron's saved softmax buffer in place (no extra full-vocab temps), mirroring VocabParallelCrossEntropy.calculate_gradients;
  • clones only when autograd would observe Megatron's in-place mutation (_clone_if_grad_tracked), keeping the no-grad ref/old-logprob path lean and fixing a view … modified inplace crash on the differentiable path;
  • wires the existing --log-probs-chunk-size into examples/retool/retool_qwen3_4b_rl.sh.
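The fusion in the first bullet can be illustrated for a single token row. This is a minimal sketch in plain Python, not the PR's tensor code: `logprob_and_entropy` is a hypothetical name, and it only shows that both outputs can be read off one exponentiated buffer instead of two separate softmax passes.

```python
import math

def logprob_and_entropy(logits, target):
    """Return (log p[target], entropy) using one softmax working buffer.

    Hypothetical helper for illustration; the PR operates on [T, V] tensors.
    """
    m = max(logits)                               # subtract the max for stability
    exps = [math.exp(x - m) for x in logits]      # the single working buffer
    z = sum(exps)
    log_z = math.log(z)
    log_prob = (logits[target] - m) - log_z
    # H = -sum_i p_i log p_i = log Z - sum_i p_i (x_i - m)
    expected_shifted = sum(e * (x - m) for e, x in zip(exps, logits)) / z
    entropy = log_z - expected_shifted
    return log_prob, entropy

lp, h = logprob_and_entropy([2.0, 1.0, 0.1], target=0)
print(f"log_prob={lp:.4f} entropy={h:.4f}")
```

Both quantities reuse `exps`, `z`, and `log_z`, which is the property that lets a fused implementation keep one working copy of the logits.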

Scope — this is a mitigation, not the whole fix for #1951

I want to be upfront: this PR does not, on its own, resolve the #1951 OOM. Digging into that issue (details there), the failing 58.15 GiB allocation is a single [T, 76032] fp32 tensor with T ≈ 205,280 — i.e. an actor_train micro-batch ~20× larger than --max-tokens-per-gpu 9216. That happens because the reporter is on slimerl/slime:v0.2.4, which predates the token-balanced training scheduler (slime/utils/dp_schedule.py). The real fix for that OOM is bounding T, which current main already does.

So this PR is best viewed as complementary headroom for the cross-entropy once T is bounded — not the #1951 fix. I've used Relates to #1951 rather than Fixes for that reason. Happy to drop the retool example change if you'd prefer this stay a pure ppo_utils optimization.

Changes

  • slime/utils/ppo_utils.py — fused _VocabParallelLogProbsAndEntropy + in-place backward + _clone_if_grad_tracked.
  • examples/retool/retool_qwen3_4b_rl.sh — set --log-probs-chunk-size 1024.
  • tests/test_logprob_entropy_fused.py (+ CI registration) — forward parity, chunked==unchunked, bf16-tolerance backward, and a TP=2 gloo gradcheck vs a full-vocab reference.
  • tools/repro_1951.py — synthetic single-rank peak-memory reproducer.
  • slime/backends/megatron_utils/model.py — opt-in SLIME_MEM_PROBE peak logger (env-gated; happy to drop if it's noise).
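The TP=2 check against a full-vocab reference rests on the fact that log-prob and entropy can be assembled from per-shard partial results. The sketch below simulates that in plain Python with lists standing in for ranks. The function names are hypothetical, the loops stand in for all-reduce calls, and it is not the test in `tests/test_logprob_entropy_fused.py`.

```python
import math

def full_vocab_reference(logits, target):
    """Naive single-rank computation over the whole vocabulary."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    log_probs = [x - m - math.log(z) for x in logits]
    entropy = -sum((e / z) * lp for e, lp in zip(exps, log_probs))
    return log_probs[target], entropy

def sharded_logprob_entropy(shards, target):
    """Same result from vocab shards; each shard plays the role of one TP rank."""
    # Phase 1: stands in for all-reduce(MAX) over per-rank maxima.
    global_max = max(max(shard) for shard in shards)
    # Phase 2: per-rank partial sums, combined as an all-reduce(SUM) would.
    sum_exp = sum_exp_x = target_logit = 0.0
    offset = 0
    for shard in shards:
        for x in shard:
            e = math.exp(x - global_max)
            sum_exp += e
            sum_exp_x += e * (x - global_max)
        # Only the rank that owns the target index contributes its logit.
        if offset <= target < offset + len(shard):
            target_logit += shard[target - offset] - global_max
        offset += len(shard)
    log_z = math.log(sum_exp)
    return target_logit - log_z, log_z - sum_exp_x / sum_exp

logits = [0.3, -1.2, 2.5, 0.0, 1.1, -0.7]
print(full_vocab_reference(logits, 4))
print(sharded_logprob_entropy([logits[:3], logits[3:]], 4))
```

The two printed pairs should agree up to floating-point rounding, which is the kind of parity a sharded gradcheck is meant to confirm.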

Measurements (1×H200, real megatron-core, fp32, --with-entropy --backward)

Where it helps — fixed T, per-card vocab V=76032, sweeping --log-probs-chunk-size:

| chunk_size | before | this PR |
|---|---|---|
| -1 (off) | 51.06 GiB | 32.46 GiB |
| 1024 | 38.05 GiB | 27.83 GiB |
| 512 | 37.59 GiB | 27.82 GiB |
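For readers unfamiliar with the chunked path, here is a rough sketch of the idea, assuming the flag splits work along the token dimension so that only one chunk's full-vocab temporaries are live at a time. `chunked_logprobs` and `row_logprob` are hypothetical names, and the real code works on tensors rather than Python lists.

```python
import math

def row_logprob(logits, target):
    """Log-probability of `target` for one token row."""
    m = max(logits)
    log_z = math.log(sum(math.exp(x - m) for x in logits))
    return logits[target] - m - log_z

def chunked_logprobs(rows, targets, chunk_size):
    """Process `rows` in slices of `chunk_size`; chunk_size <= 0 means one chunk."""
    n = len(rows)
    step = chunk_size if chunk_size > 0 else max(n, 1)
    out = []
    for start in range(0, n, step):
        stop = min(start + step, n)
        # Only this slice's intermediates exist during the iteration.
        out.extend(row_logprob(rows[i], targets[i]) for i in range(start, stop))
    return out

rows = [[0.1 * (i + j) for j in range(5)] for i in range(7)]
targets = [i % 5 for i in range(7)]
assert chunked_logprobs(rows, targets, 3) == chunked_logprobs(rows, targets, -1)
```

Chunking changes peak memory, not results, which is why the test suite can assert chunked and unchunked outputs are equal.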

And where it doesn't — sweeping the micro-batch token count T (chunk 1024):

| T | input [T,76032] | before | this PR |
|---|---|---|---|
| 9,216 | 2.61 GiB | 14.42 GiB | 7.84 GiB |
| 51,200 | 14.50 GiB | 79.82 GiB | 43.56 GiB |
| 103,756 | 29.39 GiB | OOM | 88.24 GiB |
| 205,280 | 58.14 GiB | OOM | OOM |

So it roughly halves the cross-entropy peak in the regime that fits, and turns the ~29 GiB-class step from OOM into a survivable one — but at T≈205k it still OOMs, because the logits tensor itself is 58 GiB. No cross-entropy change can help that; bounding T is what does.
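The input sizes quoted above follow directly from T × V × 4 bytes for an fp32 `[T, 76032]` tensor. A short check, where `logits_gib` is just an illustrative helper:

```python
GIB = 2**30

def logits_gib(tokens, vocab=76032, bytes_per_elem=4):
    """Size in GiB of a dense [tokens, vocab] tensor (fp32 by default)."""
    return tokens * vocab * bytes_per_elem / GIB

for t in (9216, 51200, 103756, 205280):
    print(f"T={t:>7,}: {logits_gib(t):6.2f} GiB")
# T=  9,216:   2.61 GiB
# T= 51,200:  14.50 GiB
# T=103,756:  29.39 GiB
# T=205,280:  58.14 GiB

print(f"T ratio vs --max-tokens-per-gpu 9216: {205280 / 9216:.1f}x")
```

The last line gives roughly 22x, consistent with the "~20× larger" figure in the scope section.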

Correctness

Backward matches naïve autograd within bf16 tolerance (rel ≲ 0.002) at TP=1 (entropy-active and entropy-zero) and via the TP=2 gloo gradcheck. pytest tests/test_logprob_entropy_fused.py tests/test_chunked_gae.py passes; pre-commit clean. The bf16 grad return matches Megatron's fused CE.
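To make the backward check concrete, the sketch below writes out the analytic gradients for one row, overwrites the saved softmax buffer with them, and compares against central finite differences. It is a plain-Python illustration under the assumption that the loss is a weighted sum `g_lp * log_prob + g_h * entropy`; the function names are hypothetical and the PR's tests use autograd and bf16 tolerances instead.

```python
import math

def forward(logits, target):
    """Return log_prob, entropy, the saved softmax, and shift = max + log Z."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    probs = [e / z for e in exps]
    shift = m + math.log(z)                       # log p_j = x_j - shift
    log_prob = logits[target] - shift
    entropy = -sum(p * (x - shift) for p, x in zip(probs, logits))
    return log_prob, entropy, probs, shift

def backward_in_place(buf, logits, target, entropy, shift, g_lp, g_h):
    """Overwrite the saved softmax `buf` with d(loss)/d(logits)."""
    for j, x in enumerate(logits):
        p = buf[j]
        onehot = 1.0 if j == target else 0.0
        # d log p_t / d x_j = onehot - p_j ;  d H / d x_j = -p_j (log p_j + H)
        buf[j] = g_lp * (onehot - p) + g_h * (-p * ((x - shift) + entropy))
    return buf

def numeric_grad(logits, target, g_lp, g_h, eps=1e-6):
    """Central finite differences of g_lp * log_prob + g_h * entropy."""
    grads = []
    for j in range(len(logits)):
        hi = list(logits); hi[j] += eps
        lo = list(logits); lo[j] -= eps
        lp_hi, h_hi, _, _ = forward(hi, target)
        lp_lo, h_lo, _, _ = forward(lo, target)
        grads.append((g_lp * (lp_hi - lp_lo) + g_h * (h_hi - h_lo)) / (2 * eps))
    return grads

x = [1.5, -0.3, 0.8, 2.1]
lp, h, saved, shift = forward(x, target=2)
analytic = backward_in_place(saved, x, 2, h, shift, g_lp=1.0, g_h=0.5)
numeric = numeric_grad(x, 2, g_lp=1.0, g_h=0.5)
print(max(abs(a - n) for a, n in zip(analytic, numeric)))
```

Setting `g_h=0.0` reduces the update to `g_lp * (onehot - p)`, the sign-flipped form of the usual cross-entropy gradient, which needs no extra full-vocab temporary.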

Notes / limitations

  • The backward is hand-written; verified at TP=1 and TP=2 only — larger-TP behavior leans on the existing GPU CI suite.
  • This reduces the cross-entropy working set, not the input logits tensor.
  • Follow-up idea (separate PR): filter logits by loss_mask before CE to shrink T itself.

Relates to #1951.

  The fused backward materialized three new full-vocab [T,1,V] fp32 tensors
  (log_softmax, entropy_grad, grad_input) on top of the saved softmax, so the
  unchunked path peaked ~4x [T,V] — worse than the pre-fix two-pass code
  (H200 B=16384 V=151936: 60.3 vs 51.1 GiB).

  Reuse the saved softmax buffer in place as the gradient (mirroring Megatron's
  VocabParallelCrossEntropy.calculate_gradients): at most one extra full-vocab
  temp, and only when entropy gradient flows. Return a bf16 grad like Megatron's
  CE (drops the bf16+fp32 mix).

  H200 peak (with_entropy, backward), now <= main at every chunk size:
    chunk -1  : 51.1 -> 32.5 GiB  (was 60.3, a regression)
    chunk 1024: 38.0 -> 27.8 GiB
    chunk 512 : 37.6 -> 27.8 GiB

  Backward verified vs naive autograd within bf16 tol (TP=1 and TP=2). Also drop
  a stale INVESTIGATION.md reference in the clone-guard comment.
@Mantissagithub Mantissagithub changed the title fix(ppo): reduce fused logprob memory peak perf(ppo): reduce log-prob + entropy cross-entropy peak memory Jun 3, 2026