[Common/PyTorch] bugfix: Token-linear fused RoPE impl. for THD tensors. #3057
plugyawn wants to merge 6 commits into
Greptile Summary
This PR fixes a block-launch scaling bug in the THD fused RoPE path where the old kernel launched
Confidence Score: 4/5
Safe to merge for well-formed inputs; the new kernels are mathematically equivalent to the original, but lack the old kernel's explicit out-of-range token guard, a concern noted in prior review threads that is worth addressing before wide deployment.
The core correctness path (binary search → s_id computation → fused_rope_block_forward/backward call) is equivalent to the original kernel for all valid inputs, and the bitwise parity test confirms this. The outstanding concern (no t_id >= cu_seqlens[-1]/cp_size guard in the new kernels) was flagged in a prior review thread and remains unresolved. Test coverage also leaves float16 and cp_rank > 0 with cp_size > 1 untested in the new path.
transformer_engine/common/fused_rope/fused_rope.cu: the new token-linear kernel bodies and the fused_rope_thd_use_token_linear dispatcher warrant a second look, specifically the missing bounds guard discussed in the prior review thread.
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["fused_rope_forward / fused_rope_backward"] --> B["Compute total_tokens from input.data.shape[0]\n(THD only)"]
B --> C["fused_rope_thd_use_token_linear(qkv_format, b, s, total_tokens)"]
C --> D{qkv_format == NVTE_THD?}
D -- No --> E["return false"]
D -- Yes --> F{total_tokens <= 0?}
F -- Yes --> E
F -- No --> G{"NVTE_FUSED_ROPE_THD_TOKEN_LINEAR env?"}
G -- "== 0" --> E
G -- "== 1" --> H["return true (forced new)"]
G -- "unset / other" --> I{b >= 64 AND s*b >= 8*total_tokens?}
I -- No --> E
I -- Yes --> H
H --> J["New token-linear kernel\ndim3(total_tokens) blocks\none block per packed token"]
J --> K["fused_rope_thd_find_seq_id\nbinary search over cu_seqlens"]
K --> L["fused_rope_block_forward /\nfused_rope_block_backward"]
E --> M["Original fused_rope_forward_kernel\ndim3(s, b) blocks — many dead blocks\nwhen s >> avg seqlen"]
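For readers who prefer code to the diagram, the dispatch decision above can be sketched on the host side roughly as follows. This is a minimal model, not the PR's implementation: the function name `use_token_linear`, the `is_thd` flag (standing in for `qkv_format == NVTE_THD`) and the `env_value` parameter (standing in for the value of `NVTE_FUSED_ROPE_THD_TOKEN_LINEAR`, `nullptr` when unset) are assumptions made for illustration, and the string comparison is only one plausible way to read the "== 0" / "== 1" branches.

```cpp
#include <cstdint>
#include <cstring>

// Hypothetical host-side model of the flowchart's dispatch decision.
// Returns true when the token-linear kernel would be chosen.
bool use_token_linear(bool is_thd, int64_t b, int64_t s, int64_t total_tokens,
                      const char *env_value) {
  if (!is_thd) return false;            // only the THD layout has the new path
  if (total_tokens <= 0) return false;  // nothing to launch
  if (env_value != nullptr) {
    if (std::strcmp(env_value, "0") == 0) return false;  // forced old kernel
    if (std::strcmp(env_value, "1") == 0) return true;   // forced new kernel
    // Any other value falls through to the heuristic, as in the flowchart.
  }
  // Heuristic: many sequences, and the old (s, b) grid would launch at
  // least 8x as many blocks as there are packed tokens.
  return b >= 64 && s * b >= 8 * total_tokens;
}
```

With b = 64, s = 1024 and 8192 packed tokens the old grid is exactly 8× the token count, so this sketch picks the token-linear kernel; one more token tips it back to the original kernel.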
Reviews (5): Last reviewed commit: "Merge branch 'main' into rope-thd-token-..."
    int t_id = blockIdx.x;
    int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
Redundant binary search across all threads in the block
Every thread in the block calls fused_rope_thd_find_seq_id with the same arguments (t_id = blockIdx.x, nseq, cp_size) and produces an identical result. With warps_per_block = 8, that's 256 threads each doing O(log nseq) global-memory reads of cu_seqlens that could be performed once. For nseq=2401 (~12 iterations x 256 threads), each block reads ~3,072 redundant entries from cu_seqlens. Performing the search once in thread 0 and broadcasting the result via shared memory would eliminate that overhead.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
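To make the cost estimate in this comment concrete, here is a host-side sketch of the kind of binary search being discussed, plus a helper that reproduces the "~12 iterations" figure. The body of `fused_rope_thd_find_seq_id` is not shown in this thread, so `find_seq_id` below is an assumed equivalent (largest `b` with `cu_seqlens[b] / cp_size <= t_id`), and `max_search_steps` is a hypothetical helper added only for the arithmetic.

```cpp
#include <vector>

// Assumed behaviour: return the largest b in [0, nseq) such that
// cu_seqlens[b] / cp_size <= t_id. cu_seqlens has nseq + 1 entries.
int find_seq_id(const std::vector<int> &cu_seqlens, int nseq, int t_id,
                int cp_size) {
  int lo = 0, hi = nseq;  // invariant: the answer lies in [lo, hi)
  while (hi - lo > 1) {
    int mid = lo + (hi - lo) / 2;
    if (cu_seqlens[mid] / cp_size <= t_id) {
      lo = mid;  // sequence `mid` starts at or before this token
    } else {
      hi = mid;  // sequence `mid` starts after this token
    }
  }
  return lo;
}

// Worst-case number of loop iterations for a search over nseq sequences.
int max_search_steps(int nseq) {
  int steps = 0;
  for (int span = nseq; span > 1; span = (span + 1) / 2) ++steps;
  return steps;
}
```

For nseq = 2401 the helper gives 12 steps, and 12 steps × 256 threads is the 3,072 reads per block quoted above. In the kernel itself, the suggested fix would presumably have one thread run the search and publish `b_id` through a `__shared__` variable followed by `__syncthreads()`; that part is not modelled here.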
    int t_id = blockIdx.x;
    int b_id = fused_rope_thd_find_seq_id(cu_seqlens, nseq, t_id, cp_size);
    int start = cu_seqlens[b_id] / cp_size;
    int end = cu_seqlens[b_id + 1] / cp_size;
    int s_id = t_id - start;
    int cur_seqlens = end - start;
No guard for t_id exceeding valid cu_seqlens range
The old kernel explicitly filters dead blocks with if (t_id >= end) return; before any computation. The new kernel does not: it trusts that blockIdx.x < cu_seqlens[nseq]/cp_size because total_tokens is read from input.data.shape[0]. If a caller passes a tensor with shape[0] larger than cu_seqlens[-1]/cp_size, the binary search lands on b_id = nseq-1, computes s_id = t_id - start >= cur_seqlens, and fused_rope_block_forward indexes freqs at an out-of-range s_id_for_freqs. Adding if (t_id >= (int)(cu_seqlens[nseq] / cp_size)) return; after the binary search would restore the safety property the old kernel had.
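A host-side sketch of the index computation with the proposed guard might look like the following. It is an illustration under assumptions, not the kernel code: `TokenPos` and `locate_token` are hypothetical names, the guard is placed before the lookup rather than after it (equivalent for this purpose), and a linear scan replaces the binary search purely to keep the example short.

```cpp
#include <vector>

// Hypothetical result type: which sequence a packed token belongs to,
// its position inside that sequence, and that sequence's local length.
struct TokenPos {
  int b_id;
  int s_id;
  int cur_seqlens;
};

// Returns false (the analogue of an early `return` in the kernel) when
// t_id lies outside the packed-token range described by cu_seqlens.
bool locate_token(const std::vector<int> &cu_seqlens, int nseq, int t_id,
                  int cp_size, TokenPos *out) {
  // Guard: a block whose index is past the last packed token does no work.
  if (t_id < 0 || t_id >= cu_seqlens[nseq] / cp_size) return false;

  // Linear scan for clarity; the PR uses a binary search here.
  int b_id = 0;
  while (b_id + 1 < nseq && cu_seqlens[b_id + 1] / cp_size <= t_id) ++b_id;

  int start = cu_seqlens[b_id] / cp_size;
  int end = cu_seqlens[b_id + 1] / cp_size;
  *out = TokenPos{b_id, t_id - start, end - start};
  return true;
}
```

With the guard in place, every accepted token satisfies `0 <= s_id < cur_seqlens`, which is the property the old kernel's `if (t_id >= end) return;` provided.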
@plugyawn Hi, could you sign your commits? See https://github.com/NVIDIA/TransformerEngine/blob/main/CONTRIBUTING.rst#sign-your-work
@sudhakarsingh27 Could you take a look?
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
for more information, see https://pre-commit.ci Signed-off-by: plugyawn <progyan.das@iitgn.ac.in>
331a3a0 to 6c46696
Thanks! Signed! FWIW, I think the binary search overhead in normal cases can also be reduced; I'll probably add some improvements.
Description
Adds a token-linear implementation of the existing THD fused RoPE path to remove a launch-scaling bug.
Addresses #2866, which reports a case where the THD fused RoPE launch scales with freqs_len × n_spans, which is pathological; it should scale with the total number of tokens. I reproduced the issue and found that it causes noticeable performance drops even on plausibly routine shapes, e.g. the [128/512] and [512/128] cases here.
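As a rough illustration of the scaling difference, the sketch below compares the two grid sizes, taking the old launch to be `dim3(s, b)` and the new one to be one block per packed token. `LaunchCounts` and `count_blocks` are hypothetical names, and reading `s` as freqs_len and `b` as the number of spans is an assumption about how the issue's terms map onto the kernel arguments.

```cpp
#include <cstdint>

// Hypothetical summary of how many blocks each launch strategy issues.
struct LaunchCounts {
  int64_t old_blocks;   // s * b blocks from the original (s, b) grid
  int64_t new_blocks;   // one block per packed token
  int64_t dead_blocks;  // old-grid blocks that exit without doing work
};

LaunchCounts count_blocks(int64_t s, int64_t b, int64_t total_tokens) {
  LaunchCounts c;
  c.old_blocks = s * b;
  c.new_blocks = total_tokens;
  // Each packed token is handled by exactly one (s_id, b_id) block, so the
  // rest of the old grid hits the early return.
  c.dead_blocks = c.old_blocks - total_tokens;
  return c;
}
```

For example, with s = 65536, an assumed b = 512 spans and 65536 packed tokens, the old grid has 33,554,432 blocks against 65,536 for the token-linear launch, a 512× difference, all of it dead blocks.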
The new kernel reuses the existing fused_rope_block_forward and fused_rope_block_backward device helpers, so the math doesn't change. All we need to do is add a THD-only path that launches one block per packed token.
The slow case is mostly pathological, however, so I've added a condition on the dispatch to avoid unnecessary binary search overhead in ordinary cases, although that overhead appears to be small. The condition is: token-linear only when b >= 64 and the old launch would issue ≥ 8× as many blocks as there are tokens. I'm not sure if this is the usual shape of TE updates, so I could remove it!
Some more relevant tests:
Microbenchmark on H100 (bf16, h=32, d=d2=128, freqs_len=T_local=65536, single GPU):
Fixes: #2866.
Type of change
Changes
Please list the changes introduced in this PR:
NVTE_FUSED_ROPE_THD_TOKEN_LINEAR=0|1.
fused_rope_block_forward and fused_rope_block_backward device helpers.
Checklist: