[speechlm2] SALMAutomodel: THD (packed sequence) and context parallel support #15679
Adds an opt-in packed-sequence (THD) training/validation path so that ``SALMAutomodel`` can feed a Nemotron-V3 LLM via ``cu_seqlens``-aware varlen attention instead of the right-padded BSHD layout. Padding overhead drops from O(B*max_T) to O(rounding) per minibatch, which is substantial for variable-length speech inputs.
Activated by ``model.packed_sequences: True`` in the YAML; the BSHD path is unchanged when the flag is unset. Generate / inference still use BSHD (it doesn't go through ``prepare_inputs``).
Pieces:
* ``parts/packed_sequences.py``: concatenates per-utterance text + audio embeddings into a single flat ``[T_total, H]`` sequence with a ``cu_seqlens`` index, applies the per-utt next-token shift, and rounds each utterance's flat length up to a multiple of ``2*cp_size`` so the same packing also satisfies TE's CP DualChunkSwap contract. Output shape mirrors Automodel's canonical THD layout (``components/distributed/thd_utils.process_input_for_thd`` / ``cp_utils._shard_thd_chunk_for_te``: 2D, no leading batch dim) so no extra squeeze/unsqueeze hops are needed.
* ``parts/cp_helpers.py``: ``get_cp_mesh`` reads the CP submesh out of the device mesh and returns ``(None, 1, 0)`` when CP is inactive. Used by ``prepare_packed_llm_inputs`` to short-circuit the CP-shard path. Further CP plumbing for the BSHD path lands in a follow-up.
* ``models/salm_automodel.py``: ``forward`` accepts ``**llm_kwargs`` (THD metadata: ``qkv_format``, ``cu_seqlens``, ``position_ids``, ``max_seqlen``); ``prepare_inputs`` returns the THD dict early when ``packed_sequences`` is set; ``training_step``/``validation_step`` splat ``llm_kwargs`` into the forward call and use a shape-generic ``logits.reshape(-1, V)`` so the same code handles both BSHD ``(B, T, V)`` and THD ``(T, V)`` outputs.
Tests:
* ``test_salm_packed_sequences.py``: covers shape contracts, cu_seqlens invariants, per-utt next-token shift, audio-frame label masking, the ``cp_size``/``tp_size`` rounding, and the TE preprocessor regression test that pins down the ``cu_seqlens`` + ``max_seqlen`` (singular) contract. Includes BSHD-vs-THD pair-equivalence checks: the set of supervised ``(input_embedding, target_token_id)`` pairs reaching the cross-entropy must be identical between the two layouts on the same batch.
* ``test_salm_cp_helpers.py``: three CPU tests for ``get_cp_mesh`` covering the no-mesh, ``cp_size==1``, and missing-``cp``-axis paths.
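The length bookkeeping described above can be sketched in a few lines. This is only an illustration, not the code in ``parts/packed_sequences.py``: the function names are hypothetical, and the assumption that rounding applies only when ``cp_size > 1`` is inferred from the review snippet further down the thread.

```python
from itertools import accumulate


def pack_lengths(real_lens, cp_size=1):
    """Round each utterance length up to a multiple of 2*cp_size and
    build the cumulative-offset index (cu_seqlens) for the flat layout.

    Hypothetical helper for illustration only.
    """
    # Assumption: no rounding is needed when CP is inactive.
    mult = 2 * cp_size if cp_size > 1 else 1
    padded = [((n + mult - 1) // mult) * mult for n in real_lens]
    cu_seqlens = [0] + list(accumulate(padded))
    return padded, cu_seqlens


def bshd_token_count(real_lens):
    """Tokens materialized by a right-padded (B, max_T) batch."""
    return len(real_lens) * max(real_lens)


real_lens = [37, 120, 64]
padded, cu_seqlens = pack_lengths(real_lens, cp_size=2)
thd_tokens = cu_seqlens[-1]              # flat T_total after rounding
bshd_tokens = bshd_token_count(real_lens)
```

For this batch the packed layout holds 224 tokens against 360 for the right-padded one, and only the first utterance picks up any rounding.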
Adds context-parallelism support to ``SALMAutomodel`` so that large
Nemotron-V3 LLMs with hybrid Mamba/attention layers can train on long
audio sequences across multiple GPUs. Builds on the THD packed-sequence
path from the previous commit; the BSHD path is also supported but the
THD path is the recommended configuration under CP.
Activated by ``cp_size > 1`` in the strategy config (e.g.
``AutomodelParallelStrategy(cp_size=2, ...)``); the existing TP
truncation path is folded into the CP padding so single-axis runs are
unchanged.
Pieces:
* ``parts/cp_helpers.py`` — extends the module with two CP-aware
helpers used by ``SALMAutomodel.prepare_inputs``:
- ``shard_bshd_for_cp`` pads the BSHD seq dim to a multiple of
``2*cp_size*tp_size`` and partitions along the seq dim using TE's
``thd_get_partitioned_indices`` (the same DualChunkSwap pattern
Automodel's ``Config 1`` reference test uses).
- ``encode_audio_with_cp_distribution`` distributes the audio
encoder forward across CP ranks instead of recomputing it
``cp_size`` times. Right-pads the audio batch with zero-audio
dummies so every rank participates in FSDP all-gather (and AC
fires uniformly), then all-gathers the variable-length embedding
tensors back so each rank reconstructs the full ordered list.
* ``models/salm_automodel.py`` — ``prepare_inputs`` derives the CP
mesh once via ``get_cp_mesh``, swaps the audio encoder call to the
CP-distributed version, and (for the BSHD branch) inserts a
``shard_bshd_for_cp`` step before the TP-truncation fallback. Under
CP the BSHD path also drops the padding mask before passing the
batch to the LLM (TE's fused-attention CP path supports ``causal``
but not ``padding_causal``). This is documented as a known
limitation; the durable fix is the THD packed-sequence path.
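For readers unfamiliar with the DualChunkSwap pattern mentioned above, here is a rough pure-Python sketch of the index selection as it is usually described for load-balanced causal CP: each sequence is cut into ``2*cp_size`` equal chunks and rank ``r`` keeps chunks ``r`` and ``2*cp_size - 1 - r``. This is an assumption about the pattern, not TE's ``thd_get_partitioned_indices`` implementation, and ``dual_chunk_indices`` is a hypothetical name.

```python
def dual_chunk_indices(cu_seqlens, cp_size, cp_rank):
    """Return the flat token indices a CP rank keeps under a
    DualChunkSwap-style split (illustrative sketch only)."""
    n_chunks = 2 * cp_size
    picked = []
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        seq_len = end - start
        if seq_len % n_chunks != 0:
            # This is why each utterance is rounded to a multiple of 2*cp_size.
            raise ValueError(f"sequence length {seq_len} not divisible by {n_chunks}")
        chunk = seq_len // n_chunks
        # Pair an early chunk with its mirrored late chunk so causal
        # attention work is roughly balanced across ranks.
        for c in (cp_rank, n_chunks - 1 - cp_rank):
            lo = start + c * chunk
            picked.extend(range(lo, lo + chunk))
    return picked


cu_seqlens = [0, 8, 12]  # two packed sequences of length 8 and 4
rank0 = dual_chunk_indices(cu_seqlens, cp_size=2, cp_rank=0)
rank1 = dual_chunk_indices(cu_seqlens, cp_size=2, cp_rank=1)
```

Together the two ranks cover every token exactly once, and each rank holds the same number of tokens from every sequence.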
Tests:
* ``test_salm_cp_helpers.py`` — adds a ``_PerceptionStub`` and CPU
fallback tests for ``encode_audio_with_cp_distribution``
(``cp_mesh is None`` and ``B_aud == 0`` paths). The
``cp_size > 1`` paths in ``shard_bshd_for_cp`` and
``encode_audio_with_cp_distribution`` require ``transformer_engine_torch``
and a real ``torch.distributed`` process group respectively;
exercised by 2-GPU smoke tests.
Adds two new subsections under "AutomodelParallelStrategy (SALMAutomodel)" in the training-and-scaling guide:
* "Packed Sequences (THD)": explains the layout, when it helps (variable-length speech batches), and the YAML knob (``model.packed_sequences: true`` plus ``attn: te``).
* "Context Parallelism (CP)": explains the strategy knob (``cp_size > 1``), the BSHD vs THD pairing, and the recommended configuration. Documents the BSHD-under-CP padding-mask drop as a known limitation, with THD as the durable fix.
Calls out the TransformerEngine 2.14 cuDNN-backend bug on certain GPU architectures (notably Blackwell sm_120) that returns correct THD forward activations but gradients amplified 8x-960x per layer, and the ``NVTE_FUSED_ATTN=0`` workaround that forces FlashAttention dispatch (which is gradient-correct on the same shapes).
Adds a matching "Packed sequences (THD)" entry to the SALMAutomodel config reference, with a cross-reference to the training-and-scaling guide for the CP pairing.
Catches three configuration combinations that produce silent NaN
gradients or hangs at training time, and raises an informative error
(or warns where the bug is architecture-specific) before the user
spends ~7 minutes on model load only to watch their loss go NaN.
The check is a pure function in ``parts/parallel.py`` and runs from
``SALMAutomodel.on_fit_start`` once the device mesh is wired up.
Cases:
1. ``model.packed_sequences=false`` (BSHD) under ``cp_size > 1`` —
hard error. TE's fused-attention CP path rejects ``padding_causal``
so the right-pad mask is dropped, which lets pad K/V leak into
real-token attention through the causal mask and produces NaN
gradients after step 1. There is no supported workaround; the
error message points users to ``packed_sequences: true``.
2. ``model.packed_sequences=true`` (THD) with
``automodel_backend.attn != "te"`` — hard error. THD packing emits
a 2D ``[T_total, H]`` layout for TE varlen FlashAttention; the
SDPA THD code path in the Automodel branch transposes assuming 4D
BSHD inputs and breaks.
3. ``model.packed_sequences=true`` + ``attn="te"`` +
``NVTE_FUSED_ATTN != "0"`` — TE 2.14's cuDNN fused-attention
backward kernel amplifies THD/padding_causal gradients 8x-960x per
layer on Blackwell sm_120; the resulting ``inf`` gradients drive
the optimizer to NaN. We have no way to be certain the bug only
affects sm_120, so this is a ``warnings.warn`` on other arches and
a hard ``ValueError`` on sm_120 (where the failure is reproduced).
Tests:
* ``test_salm_parallelism_validation.py`` — 19 unit tests covering
every (BSHD, THD) x (cp=1, cp>1) x (attn ∈ {te, sdpa, flex}) x
(NVTE_FUSED_ATTN ∈ {None, "", "0", "1", "true"}) x
(device_capability ∈ {(9,0), (12,0), None}) combination that
matters. Pure-function tests — no Lightning, no model, no device
mesh required.
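The three cases above could be expressed as a pure function along the following lines. This is a sketch under stated assumptions: the name ``check_parallelism_config``, its argument list, and the message wording are hypothetical and are not the signature of the validator in ``parts/parallel.py``.

```python
import warnings


def check_parallelism_config(packed_sequences, cp_size, attn,
                             nvte_fused_attn, device_capability):
    """Raise or warn on the known-bad combinations (illustrative only).

    nvte_fused_attn is the raw value of the NVTE_FUSED_ATTN environment
    variable (None when unset); device_capability is a (major, minor)
    tuple or None when no GPU is visible.
    """
    if not packed_sequences:
        # Case 1: BSHD under CP drops the padding mask, so pad K/V leak in.
        if cp_size > 1:
            raise ValueError(
                "BSHD with cp_size > 1 is unsupported; set model.packed_sequences: true"
            )
        return
    # Case 2: the THD layout is only handled by the TE attention backend.
    if attn != "te":
        raise ValueError("packed_sequences requires automodel_backend.attn == 'te'")
    # Case 3: cuDNN fused-attention backward is reported as unsafe for THD.
    if nvte_fused_attn != "0":
        msg = "set NVTE_FUSED_ATTN=0 when training with packed sequences"
        if device_capability == (12, 0):
            raise ValueError(msg)  # failure reproduced on sm_120
        warnings.warn(msg)         # unconfirmed on other architectures
```

Keeping the check free of Lightning, model, and mesh objects is what makes the cartesian-product style of unit test described above cheap to write.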
Under context-parallel (CP) and tensor-parallel (TP) training, all ranks in the same (cp, tp) sub-mesh of a DP slot must process the *same* global batch each step. Independent per-rank Lhotse loaders (with default concurrent_bucketing=True or shard_seed="randomized") can produce divergent cu_seqlens at a fraction of steps, deadlocking NCCL collectives with mismatched per-rank shapes.
Add BroadcastingDataLoader: a thin wrapper around (real DataLoader | None) that broadcasts each batch from the DP source rank (cp_rank == 0 and tp_rank == 0) to non-source ranks in the (cp, tp) sub-mesh, plus a continue/stop sentinel so iteration ends in lockstep regardless of whether the source loader exposes __len__. state_dict / load_state_dict are delegated to the source so checkpoint/resume keeps working with DataLoader, torchdata.StatefulDataLoader, etc.
The wrapper is a no-op when device_mesh is None or every named axis present has size 1, so the same call site works for single-GPU, DDP-only, and CP/TP runs.
Wire the speechlm2 datamodule's train and validation dataloaders through the wrapper. Existing per-rank Lhotse construction stays unchanged on the source rank.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
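The continue/stop sentinel idea can be illustrated without a process group by abstracting the collective as a callable. This is a simplified sketch, not the BroadcastingDataLoader class itself: ``lockstep_batches`` and the ``broadcast`` argument are hypothetical, and in a real run the callable would wrap something like ``torch.distributed.broadcast_object_list`` over the (cp, tp) group.

```python
from typing import Any, Callable, Iterable, Iterator, Optional


def lockstep_batches(loader: Optional[Iterable],
                     broadcast: Callable[[Any], Any]) -> Iterator[Any]:
    """Yield identical batches on every rank of a sub-mesh.

    loader is the real dataloader on the source rank and None elsewhere.
    broadcast(obj) must return the source rank's object on every rank.
    """
    it = iter(loader) if loader is not None else None
    while True:
        payload = None
        if it is not None:
            try:
                payload = (True, next(it))   # "continue" plus the batch
            except StopIteration:
                payload = (False, None)      # "stop" sentinel
        keep_going, batch = broadcast(payload)
        if not keep_going:
            return                           # every rank stops on the same step
        yield batch


# Single-process simulation: record what the source sends, replay it on a follower.
sent = []

def source_broadcast(obj):
    sent.append(obj)
    return obj

source_batches = list(lockstep_batches(["b0", "b1", "b2"], source_broadcast))

replay = iter(sent)

def follower_broadcast(_ignored):
    return next(replay)

follower_batches = list(lockstep_batches(None, follower_broadcast))
```

Because the stop decision travels with the data, the follower never needs to know the loader's length.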
Add a section to the Lhotse dataloading docs covering why CP/TP needs identical batches per (cp, tp) sub-mesh and how to use BroadcastingDataLoader. Add a short cross-reference note in the speechlm2 CP section so users know the datamodule applies it automatically.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The fit-start validator already rejects BSHD + CP > 1 with a hard error pointing users to model.packed_sequences=true (see validate_parallelism_compatibility in parts/parallel.py), so any code that exists only to support BSHD under CP is unreachable.
In SALMAutomodel.prepare_inputs the BSHD branch's ``if cp_size > 1: shard_bshd_for_cp(...)`` and the ``llm_attention_mask = None if cp_size > 1 else attention_mask`` ternary both presupposed BSHD + CP > 1; remove them and inline the TP-truncation into the BSHD path. Drop the unused shard_bshd_for_cp helper from cp_helpers.py and update its module docstring + the cp_helpers test docstring accordingly.
No behavior change for any reachable configuration.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
/ok to test ad8520c
Fixes the check_isort_and_black CI step on PR #15679. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
/ok to test 9ba5eec
[🤖]: Hi @pzelasko 👋, We wanted to let you know that a CICD pipeline for this PR just finished successfully. So it might be time to merge this PR or get some approvals.
/ok to test 73862e5
/ok to test d15c316
    gathered_stack = [torch.zeros_like(local_stack) for _ in range(cp_size)]
    gathered_lens = [torch.zeros_like(local_lens) for _ in range(cp_size)]
    dist.all_gather(gathered_stack, local_stack, group=cp_mesh.get_group())
Blocker: it looks like dist.all_gather isn't autograd-aware, so gathered_stack comes back with no grad_fn. Since full_embs is built entirely from gathered_stack, the perception module receives no gradient at all when CP is active. In the cp_size==1 path, however, encode_audio_with_optional_chunking returns grad-connected embeddings, so this is also a silent behavioral divergence between CP and non-CP. Is this intended? It would be safe only when self.perception is fully frozen, which might not always be the case.
great catch, thanks
    num_frames = (inputs["target_ids"] != -100).long().sum()
    with loss_parallel():
        logits = forward_outputs["logits"]
        loss = (
Here each rank would hold only a shard of a global sequence's tokens under CP, so local_CE_sum / local_num_frames would be a per-shard ratio. on_validation_epoch_end would then take an unweighted mean() (with sync_dist=True) over ranks with different denominators, which isn't the true Σsum / Σframes. Accuracy looks to have the same micro-vs-macro issue, since preds.eq(refs).float().mean() is averaged across unequal shards. This wouldn't affect training but biases the val_acc used for checkpoint selection. Should we mirror the non-CP path - aggregate loss_sum and num_frames separately across the dp_cp group and divide once?
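A small worked example of the micro-vs-macro gap described in this comment, with made-up numbers for two CP ranks holding unequal shards. The helper names are hypothetical, and the sketch assumes every rank has at least one supervised frame.

```python
def mean_of_ratios(loss_sums, frame_counts):
    """What an unweighted mean over per-rank loss_sum / num_frames computes."""
    ratios = [s / c for s, c in zip(loss_sums, frame_counts)]
    return sum(ratios) / len(ratios)


def ratio_of_sums(loss_sums, frame_counts):
    """The true token-level average: aggregate numerator and denominator first."""
    return sum(loss_sums) / sum(frame_counts)


# Rank 0 holds 10 supervised frames, rank 1 only 2 (illustrative values).
loss_sums = [30.0, 2.0]
frame_counts = [10, 2]

macro = mean_of_ratios(loss_sums, frame_counts)   # (3.0 + 1.0) / 2
micro = ratio_of_sums(loss_sums, frame_counts)    # 32.0 / 12
```

The two agree only when every rank holds the same number of supervised frames, which packed, sharded sequences do not guarantee.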
        padded_lens = [((L + cp_mult - 1) // cp_mult) * cp_mult for L in real_lens]
    else:
        padded_lens = list(real_lens)
    if tp_size > 1:
minor nit: Each per-utt length is rounded to a multiple of 2*cp_size, but the final tp_size bump adds tp_size - rem only to the last segment. When tp_size isn't a multiple of 2*cp_size, that bump breaks the last segment's 2*cp_size alignment, which violates the thd_get_partitioned_indices contract. The current test (test_tp_and_cp_combined) has cp_size=2, tp_size=8, so this issue is not hit. Maybe add an assert tp_size % (2*cp_size) == 0 here?
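The alignment break is easy to reproduce with a sketch of the padding logic as the comment describes it. Only the first lines of the real code are visible in the snippet above, so the ``tp_size`` branch here is an assumption reconstructed from the description, and ``pad_lengths`` is a hypothetical name.

```python
def pad_lengths(real_lens, cp_size, tp_size):
    """Per-utterance CP rounding followed by a TP bump on the last segment."""
    cp_mult = 2 * cp_size
    if cp_size > 1:
        padded = [((n + cp_mult - 1) // cp_mult) * cp_mult for n in real_lens]
    else:
        padded = list(real_lens)
    if tp_size > 1:
        # Assumed behavior: make the flat total divisible by tp_size by
        # growing only the final segment.
        rem = sum(padded) % tp_size
        if rem:
            padded[-1] += tp_size - rem
    return padded


ok = pad_lengths([5, 3], cp_size=2, tp_size=8)       # tp_size % (2*cp_size) == 0
bad = pad_lengths([5, 3, 1], cp_size=2, tp_size=6)   # tp_size % (2*cp_size) != 0
last_aligned = bad[-1] % (2 * 2) == 0
```

When ``tp_size`` is a multiple of ``2*cp_size``, the remainder and therefore the bump are also multiples of ``2*cp_size``, so the suggested assert is sufficient to keep the last segment aligned.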
/ok to test 34fca17
/ok to test 095ce23
KunalDhawan left a comment
Great work, thanks @pzelasko!
Important
The Update branch button must only be pressed on very rare occasions. An outdated branch never blocks the merge of a PR.
Please reach out to the automation team before pressing that button.
What does this PR do ?
Adds opt-in THD packed-sequence training and full context-parallel support to SALMAutomodel, with a fit-start validator that catches known-bad configs.
Collection: speechlm2
Changelog
* Packed sequences (THD): parts/packed_sequences.py + parts/cp_helpers.get_cp_mesh.
* Context parallelism (cp_size > 1): sequence-shard hybrid Mamba/attention LLMs across GPUs. Requires TransformerEngine. New parts/cp_helpers.shard_bshd_for_cp + encode_audio_with_cp_distribution. Recommended pairing is THD + CP. prepare_inputs derives the CP mesh once, distributes audio encoding, and (for the BSHD branch) inserts the TE DualChunkSwap shard before the TP-truncation fallback.
* parts/parallel.validate_parallelism_compatibility runs at on_fit_start and raises errors for three known-bad combos: BSHD+CP>1 (NaN at step 2 from pad-K/V leak), THD+non-TE attention (Automodel's SDPA THD path is not ready), and THD+TE without NVTE_FUSED_ATTN=0 (cuDNN backward gradient amplification on Blackwell sm_120; hard error on sm_120, warning elsewhere).
* training_and_scaling.rst and configs.rst document the new packed_sequences flag, the CP knob, the BSHD+CP-not-supported warning, and the NVTE_FUSED_ATTN=0 workaround.
Usage
# Add a code snippet demonstrating how to use this
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".