Add register-only cross-lane reduction for attention#2359
Open
stefankoncarevic wants to merge 7 commits into
Open
Add register-only cross-lane reduction for attention#2359stefankoncarevic wants to merge 7 commits into
stefankoncarevic wants to merge 7 commits into
Conversation
71cac93 to
ffd8579
Compare
cb4dba3 to
e4c2001
Compare
Introduce a register-only reduction path using v_permlanex16_var_b32 (GFX12+) for blockwise broadcast-reduce when partialR=2 on wave32 architectures. This avoids the initial LDS store + barrier by performing the 2-way reduction directly in registers before writing to LDS.
…or Navi4x Restructure the permlanex16_var reduction logic into two distinct paths gated by 2D thread layout awareness (mTidPerWave/nTidPerWave): - SerialPermlane (blockSize <= nrDimProd): XOR butterfly reduction in registers for power-of-2 rDimSize matching mTidPerWave. Uses LDS only for final broadcast. - PR2-Permlane (blockSize > nrDimProd): register-only cross-half-wave reduction for partialR=2 when nTidPerWave=16 (lanes 0-15 <-> 16-31), avoiding the initial LDS store + barrier. Both paths now require has2DThreadLayout and wave32. The PR2-Permlane path is moved into the blockSize > nrDimProd branch alongside DPP and LDS-Tree fallbacks. Clean up comments for brevity.
Introduce register-only cross-lane reduction paths for blockwise reductions
on gfx950 (wave64) and gfx12 / Navi4x (wave32), and skip the LDS round-trip
when single-wave coverage is sufficient.
NR-Large (blockSize <= nrDimProd):
- New permlaneSwapReduce helper using v_permlane{16,32}_swap_b32 for
partner-group sizes 2 or 4 (gfx950 wave64).
- canUsePermlaneSwapReduce path skips LDS entirely when blockSize ==
waveSize, broadcasting via the new readReducedResultsFromPrivateBuffer
helper instead of going through LDS.
- canUseSerialPermlane (wave32) gains the same single-wave END LDS skip,
mirroring the wave64 behaviour.
NR-Small (blockSize > nrDimProd):
- canUsePermlaneInDPP replaces the gpu.subgroup_reduce/DPP cross-lane step
with permlaneSwapReduce on a scalar accumulator.
END LDS round-trips are skipped: the input is read from
partialReductionBuffer[0], and the broadcast goes through
readReducedResultsFromPrivateBuffer.
- canUsePermlaneReduce (wave32) gains an analogous single-wave END LDS
skip (canUsePermlaneReduceLdsSkip).
Architecture gating:
- The wave32 permlanex16_var paths now require the new hasPermlaneVar gate
(gfx950 || gfx12). This prevents emitting v_permlanex16_var_b32 on gfx11
/ Navi3x, which only exposes the immediate-selector form. A future
ds_swizzle path will cover gfx11.
Multi-wave configs, configs with extraOut, and unsupported architectures
fall back to the existing LDS round-trip path unchanged, so the new
fast-paths are purely additive.
…paths Remove the blockSize == waveSize constraint from LDS-skip conditions for NR-Large (SerialPermlane on gfx1201, PermlaneSwapReduce on gfx950) and NR-Small (PR2-Permlane on gfx1201) reduction paths. partialR == mTidPerWave guarantees all reduction partners reside within the same wave, making cross-wave LDS unnecessary even for multi-wave blocks.
Register-only blockwise reduction for gfx908 (MI100), gfx90a (MI250),
and gfx94x (MI300) using ds_swizzle_b32 (XOR within 32-lane halves)
and ds_bpermute_b32 (XOR 32, crossing the half-wave boundary).
NR-Large path: replaces LDS-Tree for partialR in {2,4} when
blockSize <= nrDimProd — fully LDS-free.
NR-Small path: replaces LDS+DPP for single-wave configs with
maxActiveReductionThreads in {2,4}. LDS-skip variant (K==1, no
extraOut) is fully register-only; otherwise only the cross-lane
step is optimized while initial LDS store remains.
ds_swizzle wave32 for gfx11, remove DPP LDS-skip, clean up tests - Add widenTo32Bit/narrowFrom32Bit helpers so sub-32-bit types (f16) use the fast register-only path instead of falling back to LDS-tree. f16 values are widened to f32 before the i32 bitcast/intrinsic and narrowed back afterward. - Remove is32BitElem gate from eligibility predicates — no longer needed with widening in place. - Add ds_swizzle XOR=16 wave32 path for gfx11 (RDNA3/Navi3x) where v_permlanex16_var is not available. - Remove faulty DPP LDS-skip optimization that caused incorrect reads on certain thread layouts. - Add lowering_bcast_reduce_cross_lane.mlir test covering all register-only reduction paths across architectures. - Add RDNA3 gfx1100 pipeline ordering test in pipelines.mlir.
e4c2001 to
4eabdfd
Compare
Contributor
There was a problem hiding this comment.
Pull request overview
This PR adds register-only cross-lane reduction fast paths to rock.blockwise_broadcast_reduce lowering (replacing some LDS store/barrier/read sequences) for specific single-wave layouts relevant to attention-style reductions, leveraging ROCDL cross-lane primitives on multiple AMD GPU families.
Changes:
- Implement multiple cross-lane reduction strategies (permlane swap, ds_swizzle+ds_bpermute, and wave32 ds_swizzle / permlanex16.var) with optional LDS end-to-end skip when safe.
- Add extensive MLIR tests covering the new fast paths across architectures (gfx11/gfx94x/gfx950/gfx12xx) and add a non-power-of-2 NR-dim tid factoring regression test.
- Extend ROCDL dialect definitions and update driver pipeline-dump tests for gfx1100.
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| mlir/test/rocmlir-driver/pipelines.mlir | Adds pipeline-dump coverage for -arch=gfx1100 (RDNA3). |
| mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir | Adds coverage for NR-Small tid factoring when NR product is non-power-of-2. |
| mlir/test/Dialect/Rock/lowering_bcast_reduce_cross_lane.mlir | New test file validating cross-lane register-only reduction paths and LDS-skip behavior. |
| mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp | Implements cross-lane reduction + register-only broadcast/skip logic in the lowering. |
| mlir/include/mlir/Dialect/Rock/Passes.td | Updates pass dependent dialect list for ROCDL usage. |
| external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | Adds a new ROCDL op definition for permlanex16.var. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
…e reductions Use the upstream amdgpu.swizzle_bitmode op instead of manually calling rocdl.ds_swizzle with hand-rolled widening/bitcast/narrowing for sub-32-bit types. The AMDGPU op handles automatic type decomposition and recomposition via decomposeValue/composeValue, simplifying the code and removing 8 lines of boilerplate per swizzle call site. Changes: - BlockwiseGemmToThreadwise.cpp: Replace ROCDL::DsSwizzleOp with amdgpu::SwizzleBitModeOp in dsSwizzleReduceStep, add AMDGPU dialect to legal dialects in conversion target - CMakeLists.txt: Add MLIRAMDGPUDialect dependency - Passes.td: Add amdgpu::AMDGPUDialect to dependentDialects - Tests: Update CHECK patterns from rocdl.ds_swizzle to amdgpu.swizzle_bitmode
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Resolves: https://amd-hub.atlassian.net/browse/AIROCMLIR-771
Motivation
In the current implementation, blockwise broadcast-reduce operations with small partial reduction factors (
partialR∈ {2, 4}) perform cross-lane reduction through LDS (Local Data Share), even though only a simple 2-way or 4-way reduction is needed. This is suboptimal because:v_permlanex16_var_b32,v_permlane{16,32}_swap_b32,ds_swizzle_b32,ds_bpermute_b32) that can replace this pattern entirely, keeping the reduction in registers.This PR replaces LDS-based reduction with register-only cross-lane intrinsics across all supported AMD GPU architectures (CDNA and RDNA), and adds transparent widening so that sub-32-bit types like f16 (common in attention kernels) also benefit from the fast path instead of falling back to the slower LDS-tree reduction.
Technical Details
1. Architecture-specific cross-lane reduction helpers
Each helper implements a complete reduce step: load → widen → bitcast → intrinsic → bitcast → narrow → reduce → store.
permlaneX16VarReducev_permlanex16_var_b32dsSwizzleReduceWave32ds_swizzle_b32XOR=16permlaneSwapReduceStepv_permlane{16,32}_swap_b32dsSwizzleReduceStepds_swizzle_b32XOR within 32-lane halfdsBpermuteReduceStepds_bpermute_b32XOR 32 cross-half2. Sub-32-bit type widening (f16 → f32)
All cross-lane intrinsics operate on 32-bit registers. Previously, sub-32-bit types (f16) could not use the fast path and fell back to the LDS-tree reduction. Three helper functions transparently widen/narrow values around the intrinsic call:
widenTo32Bit()—arith.extf f16→f32(orarith.extsifor integer types)narrowFrom32Bit()—arith.truncf f32→f16(orarith.truncifor integers)get32BitType()— returns the corresponding 32-bit type (f32 for floats, i32 for integers)The flow for f16:
f16 → extf → f32 → bitcast → i32 → intrinsic → i32 → bitcast → f32 → truncf → f16 → reduce → store. For f32 (already 32-bit), widen/narrow are no-ops — existing behavior is unchanged.This is critical for attention kernels which use f16 throughout — without widening they fall back to LDS-tree and lose the performance benefit.
3. Two reduction regimes: NR-Large and NR-Small
NR-Large (
blockSize ≤ nrDimProd): The reduction dimension maps entirely to intra-wave lane distances. Cross-lane reduction replaces the LDS-tree. Each thread reduces its own private buffer elements.partialR == 2 && partialR == mTidPerWave— one cross-half-wave steppartialR ∈ {2, 4} && partialR == mTidPerWave— one or two swap stepsNR-Small (
blockSize > nrDimProd): Threadwise partial reduction produces a single accumulator per thread. The cross-lane step reduces this scalar across partner lanes.partialR == 2 && nTidPerWave == 16— one cross-half-wave steppartialR ∈ {2, 4} && nrDimProd * partialR == waveSize— full register-only path4. LDS-skip optimization
When all safety conditions are met (single-wave or
partialR == mTidPerWave,K == 1, noextraOut), both the upfront LDS write+barrier and the final LDS broadcast are skipped — the entire reduction stays in registers. The result is broadcast directly frompartialReductionBufferviareadReducedResultsFromPrivateBuffer.5. Eligibility gating
Each fast-path has explicit eligibility flags that check:
hasPermlaneVar,hasPermlaneSwap,hasDsSwizzleBpermute,hasDsSwizzleWave32)has2DThreadLayoutviam_tid/n_tidnaming)partialR,mTidPerWave,nTidPerWave,blockSize,nonReductionDimSizeProduct)K == 1,!extraOutfor LDS-skip variants)Configurations that don't match any fast-path fall back to the existing LDS round-trip path unchanged.
Test Plan
New:
lowering_bcast_reduce_cross_lane.mlir(15 tests)test_permlane_nrsmall_ldsskip_gfx1201rocdl.permlanex16.var, nolds_barriertest_permlane_nrlarge_gfx1201rocdl.permlanex16.var, serial XOR butterflytest_dsswizzle_nrsmall_ldsskip_gfx1100rocdl.ds_swizzle, nolds_barriertest_dsswizzle_nrlarge_gfx1100rocdl.ds_swizzle, serial XOR butterflytest_dsbpermute_nrsmall_ldsskip_sum_gfx942rocdl.ds_bpermute, nolds_barriertest_permlaneswap_nrsmall_ldsskip_sum_gfx950rocdl.permlane32.swap, nolds_barriertest_dsbpermute_nrsmall_ldsskip_r4_sum_gfx942ds_swizzle+ds_bpermute, nolds_barriertest_permlaneswap_nrsmall_ldsskip_r4_sum_gfx950permlane16.swap+permlane32.swap, nolds_barriertest_dsbpermute_nrlarge_r2_gfx942rocdl.ds_bpermute, LDS broadcasttest_permlaneswap_nrlarge_r2_gfx950rocdl.permlane32.swap, LDS broadcasttest_dsbpermute_nrlarge_r4_gfx942ds_swizzle+ds_bpermute, LDS broadcasttest_permlaneswap_nrlarge_r4_gfx950permlane16.swap+permlane32.swap, LDS broadcasttest_dsbpermute_nrsmall_noldsskip_gfx942rocdl.ds_bpermute, HASlds_barriertest_permlaneswap_nrsmall_noldsskip_gfx950rocdl.permlane32.swap, HASlds_barriertest_ldstree_fallback_multiwave_gfx942lds_barrierTest Result
Both CI Pass
Submission Checklist