[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors.#3057

Open: plugyawn wants to merge 6 commits into NVIDIA:main from plugyawn:rope-thd-token-linear

@plugyawn plugyawn commented May 28, 2026

Description

Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.

Addresses #2866, which reports that the THD fused RoPE launch scales with freqs_len × n_spans. That is pathological; it should scale with the total number of tokens. I reproduced the issue and found that it causes a noticeable slowdown even on plausibly routine shapes, for example the [128/512] and [512/128] cases below.

The new kernel reuses the existing fused_rope_block_forward and fused_rope_block_backward device helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one block per packed token.

| n_seqs | max span | old layer fwd+bwd (ms) | new layer fwd+bwd (ms) | layer speedup | old paired-RoPE share | new paired-RoPE share |
|---|---|---|---|---|---|---|
| 128 | 512 | 41.8151 | 23.0284 | 1.816x | 49.12% | 6.14% |
| 512 | 128 | 102.1047 | 23.0167 | 4.436x | 79.38% | 6.59% |
| 1024 | 64 | 182.9933 | 23.3783 | 7.827x | 88.36% | 6.77% |
| 2401 | 28 | 401.0516 | 24.5668 | 16.325x | 94.40% | 6.41% |

The slowdown mostly shows up in pathological shapes, however, so I've added a dispatch condition that avoids the binary-search overhead in ordinary cases, although that overhead appears to be small. The condition: use the token-linear path only when b >= 64 and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure whether this is the usual shape of TE updates, so I could remove it!
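As a rough illustration of that condition, here is a host-side Python sketch of the dispatch decision, following the order shown in the review's flowchart (format check, empty-input check, environment override, then the heuristic). The function name `use_token_linear` and the string-valued `env_value` parameter are hypothetical stand-ins; the actual dispatcher is the C++ function fused_rope_thd_use_token_linear, and how it parses the environment variable is not shown in this thread.

```python
def use_token_linear(is_thd, b, s, total_tokens, env_value=None):
    """Hypothetical Python analogue of the dispatch decision.

    is_thd       -- True when qkv_format is THD
    b, s         -- grid dimensions of the old launch, dim3(s, b)
    total_tokens -- number of packed local token rows
    env_value    -- value of NVTE_FUSED_ROPE_THD_TOKEN_LINEAR, or None if unset
                    (assumed here to be compared as a string)
    """
    if not is_thd or total_tokens <= 0:
        return False
    if env_value == "0":  # force the original kernel
        return False
    if env_value == "1":  # force the token-linear kernel
        return True
    # Heuristic: enough sequences, and the old launch would issue
    # at least 8x as many blocks as there are tokens.
    return b >= 64 and s * b >= 8 * total_tokens


if __name__ == "__main__":
    for b in (1, 8, 32, 128, 2401):
        print(b, use_token_linear(True, b, 65536, 65536))
```

With s = total_tokens = 65536 (an assumed shape matching the microbenchmark's freqs_len), the heuristic stays on the original kernel for b = 1, 8 and 32 and switches to the token-linear path from b = 128 upward.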

Some more relevant tests:
Microbenchmark on H100 (bf16, h=32, d=d2=128, freqs_len=T_local=65536, single GPU):

| n_seqs | old fwd+bwd (ms) | new fwd+bwd (ms) | speedup |
|---|---|---|---|
| 1 | 1.2746 | 1.2734 | 1.001x |
| 8 | 1.8860 | 1.3827 | 1.364x |
| 32 | 3.9359 | 1.4462 | 2.722x |
| 128 | 12.1849 | 1.5024 | 8.110x |
| 512 | 44.9411 | 1.5600 | 28.808x |
| 1024 | 89.1110 | 1.5919 | 55.977x |
| 2401 | 208.4182 | 1.6373 | 127.296x |
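To make the scaling concrete, the block counts behind these rows can be worked out by hand. The sketch below assumes the old launch issues freqs_len × n_seqs blocks and the new one issues one block per token, and it assumes the total local token count equals freqs_len = 65536 (the header says freqs_len=T_local=65536, but the exact token count is an assumption here). `launch_blocks` is an illustrative helper, not part of the PR.

```python
FREQS_LEN = 65536
TOTAL_TOKENS = 65536  # assumption: total packed local tokens equals T_local


def launch_blocks(n_seqs, freqs_len=FREQS_LEN, total_tokens=TOTAL_TOKENS):
    """Return (old_blocks, new_blocks) for one kernel launch."""
    old_blocks = freqs_len * n_seqs  # old grid: one block per (position, sequence)
    new_blocks = total_tokens        # new grid: one block per packed token
    return old_blocks, new_blocks


if __name__ == "__main__":
    for n_seqs in (1, 8, 32, 128, 512, 1024, 2401):
        old, new = launch_blocks(n_seqs)
        print(f"n_seqs={n_seqs:5d} old_blocks={old:>11,d} "
              f"new_blocks={new:,d} ratio={old // new}x")
```

Under these assumptions the old launch issues n_seqs times as many blocks as there are tokens, which is consistent with the old timings growing roughly linearly in n_seqs while the new timings stay nearly flat.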

Fixes: #2866.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add token-linear THD fused RoPE forward/backward kernels that launch one CUDA block per packed local token row.
  • Add the NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1 environment variable to force the original (0) or token-linear (1) path.
  • Reuse the existing fused_rope_block_forward and fused_rope_block_backward device helpers.
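For comparing the two paths from Python, one option is to scope the environment variable around each call. The context manager below is a hypothetical helper (the name `force_thd_rope_path` is not from the PR), and it only helps if the library reads NVTE_FUSED_ROPE_THD_TOKEN_LINEAR at call time rather than caching it, which this thread does not confirm.

```python
import os
from contextlib import contextmanager

ENV_NAME = "NVTE_FUSED_ROPE_THD_TOKEN_LINEAR"


@contextmanager
def force_thd_rope_path(value):
    """Temporarily set the override; value is "0", "1", or None to unset it."""
    previous = os.environ.get(ENV_NAME)
    try:
        if value is None:
            os.environ.pop(ENV_NAME, None)
        else:
            os.environ[ENV_NAME] = value
        yield
    finally:
        # Restore whatever was there before, including "unset".
        if previous is None:
            os.environ.pop(ENV_NAME, None)
        else:
            os.environ[ENV_NAME] = previous


if __name__ == "__main__":
    with force_thd_rope_path("1"):
        # Call the fused RoPE op here to exercise the token-linear kernel.
        print(os.environ[ENV_NAME])
```

Restoring the previous value in `finally` keeps one forced run from leaking into the next, which matters when the same process measures the old path, the new path, and the heuristic in turn.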

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation <<(none?)>>
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@github-actions github-actions Bot added the community-contribution PRs from external contributor outside the core maintainers, representing community-driven work. label May 28, 2026
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR fixes a block-launch scaling bug in the THD fused RoPE path where the old kernel launched freqs_len × n_seqs blocks but only total_tokens of them did useful work. The new path launches exactly total_tokens blocks (one per packed token), performs a per-block binary search over cu_seqlens to identify the owning sequence, and calls the unchanged fused_rope_block_forward/backward device helpers — so the per-token math is bit-for-bit identical to the original.

  • New CUDA kernels (fused_rope_thd_token_forward_kernel, fused_rope_thd_token_backward_kernel) and a host-side heuristic dispatcher that selects the token-linear path when b ≥ 64 and the old block count is ≥ 8× the token count; overridable via NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.
  • New parity test (test_fused_rope_thd_token_linear_parity) that forces each kernel path and asserts bitwise equality on forward outputs and input gradients across a broad set of shapes including zero-length spans and context-parallel configurations.
  • Two new benchmarks providing microbenchmark and full TransformerLayer measurements to quantify the speedup (up to ~127× for n_seqs=2401 on H100).

Confidence Score: 4/5

Safe to merge for well-formed inputs; the new kernels are mathematically equivalent to the original, but lack the old kernel's explicit out-of-range token guard — a concern noted in prior review threads that is worth addressing before wide deployment.

The core correctness path — binary search → s_id computation → fused_rope_block_forward/backward call — is equivalent to the original kernel for all valid inputs, and the bitwise parity test confirms this. The outstanding concern (no t_id >= cu_seqlens[-1]/cp_size guard in the new kernels) was flagged in a prior review thread and remains unresolved. Test coverage also leaves float16 and cp_rank > 0 with cp_size > 1 untested in the new path.

transformer_engine/common/fused_rope/fused_rope.cu — the new token-linear kernel bodies and the fused_rope_thd_use_token_linear dispatcher warrant a second look, specifically the missing bounds guard discussed in the prior review thread.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/common/fused_rope/fused_rope.cu | Adds token-linear THD forward/backward kernels and a heuristic dispatcher; math is equivalent to the original kernel for valid inputs, but the new kernels lack the old kernel's t_id >= end guard that protected against mismatched tensor shapes and out-of-range cu_seqlens access (already flagged in prior review thread). |
| tests/pytorch/test_fused_rope.py | Adds a parity test comparing old and new THD kernels with bitwise equality; covers most shape combinations including zero-length spans, but omits float16 from the dtype parametrize and always uses cp_rank=0 with cp_size=2. |
| benchmarks/attention/benchmark_rope_thd_token_linear.py | New microbenchmark sweeping n_seqs while holding total tokens fixed; straightforward and correct, captures old/new/heuristic regimes with env-var toggling. |
| benchmarks/attention/benchmark_rope_thd_full_layer.py | New end-to-end TransformerLayer benchmark measuring RoPE share of total layer time; well-structured with csv/png output and correct env-var scoping. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["fused_rope_forward / fused_rope_backward"] --> B["Compute total_tokens from input.data.shape[0]\n(THD only)"]
    B --> C["fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)"]
    C --> D{qkv_format == NVTE_THD?}
    D -- No --> E["return false"]
    D -- Yes --> F{total_tokens <= 0?}
    F -- Yes --> E
    F -- No --> G{"NVTE_FUSED_ROPE_THD_TOKEN_LINEAR env?"}
    G -- "== 0" --> E
    G -- "== 1" --> H["return true (forced new)"]
    G -- "unset / other" --> I{b >= 64 AND s*b >= 8*total_tokens?}
    I -- No --> E
    I -- Yes --> H
    H --> J["New token-linear kernel\ndim3(total_tokens) blocks\none block per packed token"]
    J --> K["fused_rope_thd_find_seq_id\nbinary search over cu_seqlens"]
    K --> L["fused_rope_block_forward /\nfused_rope_block_backward"]
    E --> M["Original fused_rope_forward_kernel\ndim3(s, b) blocks — many dead blocks\nwhen s >> avg seqlen"]

Reviews (5): Last reviewed commit: "Merge branch 'main' into rope-thd-token-..."

Comment on lines +250 to +251
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
Contributor

P2 Redundant binary search across all threads in the block

Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.
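The figure above can be checked with a few lines of arithmetic. This sketch assumes one cu_seqlens read per binary-search iteration and about ceil(log2(nseq)) iterations, which is an approximation of the kernel's search rather than a trace of it; `search_reads_per_block` is an illustrative helper.

```python
def search_reads_per_block(nseq, warps_per_block=8, threads_per_warp=32):
    """Estimate cu_seqlens reads when every thread runs the same search.

    Returns (iterations, threads, total_reads).
    """
    # ceil(log2(nseq)) for nseq >= 1, computed without floating point.
    iterations = (nseq - 1).bit_length()
    threads = warps_per_block * threads_per_warp
    return iterations, threads, iterations * threads


if __name__ == "__main__":
    iterations, threads, total = search_reads_per_block(2401)
    print(f"{iterations} iterations x {threads} threads = {total} reads per block")
    print(f"a single searching thread would need about {iterations} reads")
```

For nseq = 2401 this gives 12 iterations and 256 threads, so about 3,072 reads per block when every thread searches, against about 12 if a single thread searched and shared the result.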

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Author

Smart bot!

Comment on lines +250 to +255
int t_id = blockIdx.x;
int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
int start = cu_seqlens[b_id] / cp_size;
int end = cu_seqlens[b_id + 1] / cp_size;
int s_id = t_id - start;
int cur_seqlens = end - start;
Contributor

P2 No guard for t_id exceeding valid cu_seqlens range

The old kernel explicitly filters dead blocks with if (t_id >= end) return; before any computation. The new kernel does not: it trusts that blockIdx.x < cu_seqlens[nseq]/cp_size because total_tokens is read from input.data.shape[0]. If a caller passes a tensor with shape[0] larger than cu_seqlens[-1]/cp_size, the binary search lands on b_id = nseq-1, computes s_id = t_id - start >= cur_seqlens, and fused_rope_block_forward indexes freqs at an out-of-range s_id_for_freqs. Adding if (t_id >= (int)(cu_seqlens[nseq] / cp_size)) return; after the binary search would restore the safety property the old kernel had.
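A host-side Python model of the per-block index computation shows what the guard changes. It is a sketch only: `locate_token` is a hypothetical analogue of the kernel prologue, and it uses `bisect_right` with a clamp in place of fused_rope_thd_find_seq_id, whose exact implementation is not shown in this thread.

```python
from bisect import bisect_right


def locate_token(cu_seqlens, cp_size, t_id, guard=True):
    """Map a packed token index to (b_id, s_id, cur_seqlens).

    Returns None when the guard is enabled and t_id lies past the last
    local token, mirroring an early return in the kernel.
    """
    local = [c // cp_size for c in cu_seqlens]  # per-rank cumulative lengths
    nseq = len(cu_seqlens) - 1
    # Last sequence whose start is <= t_id, clamped to a valid index.
    b_id = min(bisect_right(local, t_id) - 1, nseq - 1)
    if guard and t_id >= local[nseq]:
        return None
    start, end = local[b_id], local[b_id + 1]
    return b_id, t_id - start, end - start


if __name__ == "__main__":
    cu = [0, 4, 4, 10]  # includes a zero-length span
    for t in range(6):
        print(t, locate_token(cu, 2, t), locate_token(cu, 2, t, guard=False))
```

With cu_seqlens = [0, 4, 4, 10] and cp_size = 2 there are 5 local tokens. Token index 5 is out of range: without the guard the model returns s_id = 3 for a sequence of length 3, which is the out-of-range position described above, and with the guard it returns None.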

Member

ptrendx commented May 28, 2026

@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work
Nice improvement :-).

@sudhakarsingh27 Could you take a look?

plugyawn and others added 3 commits May 29, 2026 03:23
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci

Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
@plugyawn plugyawn force-pushed the rope-thd-token-linear branch from 331a3a0 to 6c46696 Compare May 28, 2026 21:55
Author

plugyawn commented May 28, 2026

Thanks! Signed!

fwiw I think the binary search overhead on normal cases can be reduced also, I'll probably add some improvements.

@sudhakarsingh27 sudhakarsingh27 self-requested a review June 3, 2026 22:08

Development

Successfully merging this pull request may close these issues.

[Performance] Fused RoPE THD kernel becomes dominant bottleneck in long-context training with many packed sequences
