
Add register-only cross-lane reduction for attention #2359

Open
stefankoncarevic wants to merge 7 commits into ROCm:develop from stefankoncarevic:blockwise-reduce-crosslane

Contributor

@stefankoncarevic stefankoncarevic commented Apr 27, 2026

Resolves: https://amd-hub.atlassian.net/browse/AIROCMLIR-771

Motivation

In the current implementation, blockwise broadcast-reduce operations with small partial reduction factors (partialR ∈ {2, 4}) perform cross-lane reduction through LDS (Local Data Share), even though only a simple 2-way or 4-way reduction is needed. This is suboptimal because:

  • A 2-way or 4-way reduction needs only one or two register-to-register cross-lane exchanges; no shared memory is required.
  • The LDS store → barrier → tree-reduce → barrier → readback sequence adds latency that is disproportionate to the actual work (one or two reduction steps).
  • Modern AMD GPUs provide efficient cross-lane instructions (v_permlanex16_var_b32, v_permlane{16,32}_swap_b32, ds_swizzle_b32, ds_bpermute_b32) that can replace this pattern entirely, keeping the reduction in registers.

This PR replaces LDS-based reduction with register-only cross-lane intrinsics across all supported AMD GPU architectures (CDNA and RDNA), and adds transparent widening so that sub-32-bit types like f16 (common in attention kernels) also benefit from the fast path instead of falling back to the slower LDS-tree reduction.

Technical Details

1. Architecture-specific cross-lane reduction helpers

Each helper implements a complete reduce step: load → widen → bitcast → intrinsic → bitcast → narrow → reduce → store.

| Helper | Intrinsic(s) | Architecture | Wave size |
|---|---|---|---|
| permlaneX16VarReduce | v_permlanex16_var_b32 | gfx950, gfx12 (RDNA4) | 32 |
| dsSwizzleReduceWave32 | ds_swizzle_b32 XOR=16 | gfx11 (RDNA3) | 32 |
| permlaneSwapReduceStep | v_permlane{16,32}_swap_b32 | gfx950 (CDNA4) | 64 |
| dsSwizzleReduceStep | ds_swizzle_b32 XOR within 32-lane half | gfx908, gfx90a, gfx94x (CDNA 1/2/3) | 64 |
| dsBpermuteReduceStep | ds_bpermute_b32 XOR 32 cross-half | gfx908, gfx90a, gfx94x (CDNA 1/2/3) | 64 |
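
The data movement behind these helpers can be modelled on the host as a XOR butterfly: each lane combines its value with the value held by lane `i ^ mask`. The following is a minimal Python sketch of that idea only, not the PR's C++ lowering. The function names are hypothetical, and the mask sequence `[16, 32]` for the 4-way wave64 case is an assumption for illustration; the description confirms only XOR 16 for wave32 and XOR 32 for the wave64 cross-half step.

```python
from operator import add


def xor_exchange_reduce(lanes, mask, combine=add):
    """One cross-lane step: lane i combines its value with lane (i ^ mask)."""
    return [combine(lanes[i], lanes[i ^ mask]) for i in range(len(lanes))]


def cross_lane_reduce(lanes, masks, combine=add):
    """Apply one exchange-and-reduce step per mask, all 'in registers'."""
    for mask in masks:
        lanes = xor_exchange_reduce(lanes, mask, combine)
    return lanes


# wave32, 2-way reduction: lanes 0-15 pair with lanes 16-31 (XOR 16).
reduced32 = cross_lane_reduce(list(range(32)), [16])

# wave64, 4-way reduction: two steps (masks assumed to be 16, then 32).
reduced64 = cross_lane_reduce(list(range(64)), [16, 32])

print(reduced32[0], reduced32[16])  # both lanes hold 0 + 16 = 16
print(reduced64[0], reduced64[48])  # both lanes hold 0 + 16 + 32 + 48 = 96
```

After the last step every lane in a partner group holds the same reduced value, which is why no LDS round-trip is needed to share it within the group.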

2. Sub-32-bit type widening (f16 → f32)

All cross-lane intrinsics operate on 32-bit registers. Previously, sub-32-bit types (f16) could not use the fast path and fell back to the LDS-tree reduction. Three helper functions transparently widen/narrow values around the intrinsic call:

  • widenTo32Bit(): arith.extf f16→f32 (or arith.extsi for integer types)
  • narrowFrom32Bit(): arith.truncf f32→f16 (or arith.trunci for integers)
  • get32BitType(): returns the corresponding 32-bit type (f32 for floats, i32 for integers)

The flow for f16: f16 → extf → f32 → bitcast → i32 → intrinsic → i32 → bitcast → f32 → truncf → f16 → reduce → store. For f32 (already 32-bit), widen/narrow are no-ops, so existing behavior is unchanged.

This is critical for attention kernels, which use f16 throughout: without widening they fall back to LDS-tree and lose the performance benefit.
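
The widen, bitcast, exchange, bitcast, narrow sequence can be checked on the host with bit patterns. The sketch below is a Python model under stated assumptions: the function names are hypothetical, `struct` stands in for `arith.extf`/`arith.truncf` and the bitcasts, and a plain list lookup stands in for the 32-bit cross-lane intrinsic. It illustrates that every finite f16 value survives the f16 → f32 → f16 round trip unchanged, since f32 represents each of them exactly.

```python
import struct


def f16_bits_to_f32_bits(h):
    """Widen: read a 16-bit pattern as f16, extend to f32, return the i32 bits."""
    (value,) = struct.unpack("<e", struct.pack("<H", h))
    (bits,) = struct.unpack("<I", struct.pack("<f", value))
    return bits


def f32_bits_to_f16_bits(w):
    """Narrow: read a 32-bit pattern as f32, truncate to f16, return the i16 bits."""
    (value,) = struct.unpack("<f", struct.pack("<I", w))
    (bits,) = struct.unpack("<H", struct.pack("<e", value))
    return bits


def exchange_f16(lanes_bits, mask):
    """Model of one f16 exchange carried through a 32-bit cross-lane move."""
    wide = [f16_bits_to_f32_bits(h) for h in lanes_bits]      # extf + bitcast
    received = [wide[i ^ mask] for i in range(len(wide))]     # stand-in for the intrinsic
    return [f32_bits_to_f16_bits(w) for w in received]        # bitcast + truncf


# 0x3C00 is 1.0 in f16; lane 0 receives the value that lane 16 held.
lanes = [0x3C00 + i for i in range(32)]
print(hex(exchange_f16(lanes, 16)[0]))
```

NaN payloads are deliberately left out of this model; it covers finite values only.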

3. Two reduction regimes: NR-Large and NR-Small

NR-Large (blockSize ≤ nrDimProd): The reduction dimension maps entirely to intra-wave lane distances. Cross-lane reduction replaces the LDS-tree. Each thread reduces its own private buffer elements.

  • Wave32: partialR == 2 && partialR == mTidPerWave (one cross-half-wave step)
  • Wave64: partialR ∈ {2, 4} && partialR == mTidPerWave (one or two swap steps)

NR-Small (blockSize > nrDimProd): Threadwise partial reduction produces a single accumulator per thread. The cross-lane step reduces this scalar across partner lanes.

  • Wave32: partialR == 2 && nTidPerWave == 16 (one cross-half-wave step)
  • Wave64: partialR ∈ {2, 4} && nrDimProd * partialR == waveSize (full register-only path)
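
The geometry conditions above can be written down as a small predicate. This is a Python sketch that transcribes only the conditions listed in this section; `ReduceConfig`, `regime`, and `geometry_allows_cross_lane` are hypothetical names, and the actual lowering applies further gates (architecture, thread layout) that are not modelled here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class ReduceConfig:
    wave_size: int
    block_size: int
    nr_dim_prod: int
    partial_r: int
    m_tid_per_wave: int
    n_tid_per_wave: int


def regime(cfg):
    """NR-Large when blockSize <= nrDimProd, otherwise NR-Small."""
    return "NR-Large" if cfg.block_size <= cfg.nr_dim_prod else "NR-Small"


def geometry_allows_cross_lane(cfg):
    """Geometry-only check for the register-only cross-lane step."""
    if regime(cfg) == "NR-Large":
        allowed = (2,) if cfg.wave_size == 32 else (2, 4)
        return (cfg.wave_size in (32, 64)
                and cfg.partial_r in allowed
                and cfg.partial_r == cfg.m_tid_per_wave)
    if cfg.wave_size == 32:
        return cfg.partial_r == 2 and cfg.n_tid_per_wave == 16
    if cfg.wave_size == 64:
        return (cfg.partial_r in (2, 4)
                and cfg.nr_dim_prod * cfg.partial_r == cfg.wave_size)
    return False


# Illustrative values only, not taken from the PR's tests.
example = ReduceConfig(wave_size=64, block_size=64, nr_dim_prod=128,
                       partial_r=4, m_tid_per_wave=4, n_tid_per_wave=16)
print(regime(example), geometry_allows_cross_lane(example))
```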

4. LDS-skip optimization

When all safety conditions are met (single-wave or partialR == mTidPerWave, K == 1, no extraOut), both the upfront LDS write+barrier and the final LDS broadcast are skipped — the entire reduction stays in registers. The result is broadcast directly from partialReductionBuffer via readReducedResultsFromPrivateBuffer.
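
Read as a boolean, the safety conditions in this paragraph combine as follows. This is a Python sketch with a hypothetical function name and parameters; it transcribes only the conditions stated above and may omit checks that the real lowering performs.

```python
def can_skip_lds(single_wave, partial_r, m_tid_per_wave, k, has_extra_out):
    """True when the whole reduction may stay in registers (no LDS write or broadcast)."""
    # All reduction partners live in one wave, so no cross-wave exchange is needed.
    partners_in_one_wave = single_wave or partial_r == m_tid_per_wave
    return partners_in_one_wave and k == 1 and not has_extra_out


# A multi-wave block still qualifies when partialR == mTidPerWave.
print(can_skip_lds(single_wave=False, partial_r=4, m_tid_per_wave=4,
                   k=1, has_extra_out=False))
```

This matches the test plan, where the K>1 variants are expected to keep an lds_barrier.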

5. Eligibility gating

Each fast-path has explicit eligibility flags that check:

  • Architecture support (hasPermlaneVar, hasPermlaneSwap, hasDsSwizzleBpermute, hasDsSwizzleWave32)
  • 2D thread layout (has2DThreadLayout via m_tid/n_tid naming)
  • Reduction geometry (partialR, mTidPerWave, nTidPerWave, blockSize, nonReductionDimSizeProduct)
  • Safety constraints (K == 1, !extraOut for LDS-skip variants)

Configurations that don't match any fast-path fall back to the existing LDS round-trip path unchanged.
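
The architecture part of this gating can be pictured as a lookup from target to helper. The Python sketch below is an illustration only: the flag names come from this description, but the mapping from architecture strings to flags is inferred from the helper table in section 1 (only `hasPermlaneVar = gfx950 || gfx12` is stated explicitly in the commit messages), and `arch_flags` and `select_helpers` are hypothetical functions. An empty result stands for the LDS fallback.

```python
def arch_flags(arch):
    """Architecture gates, with the arch-to-flag mapping inferred from the helper table."""
    return {
        "hasPermlaneVar": arch.startswith(("gfx950", "gfx12")),
        "hasPermlaneSwap": arch.startswith("gfx950"),
        "hasDsSwizzleBpermute": arch.startswith(("gfx908", "gfx90a", "gfx94")),
        "hasDsSwizzleWave32": arch.startswith("gfx11"),
    }


def select_helpers(arch, wave_size):
    """Return the helper names a target would use, or [] for the LDS fallback."""
    flags = arch_flags(arch)
    if wave_size == 32:
        if flags["hasPermlaneVar"]:
            return ["permlaneX16VarReduce"]
        if flags["hasDsSwizzleWave32"]:
            return ["dsSwizzleReduceWave32"]
    elif wave_size == 64:
        if flags["hasPermlaneSwap"]:
            return ["permlaneSwapReduceStep"]
        if flags["hasDsSwizzleBpermute"]:
            return ["dsSwizzleReduceStep", "dsBpermuteReduceStep"]
    return []


for arch, wave in [("gfx1201", 32), ("gfx1100", 32), ("gfx942", 64), ("gfx950", 64)]:
    print(arch, wave, select_helpers(arch, wave))
```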

Test Plan

New: lowering_bcast_reduce_cross_lane.mlir (15 tests)

| Test | Arch | Path | Verifies |
|---|---|---|---|
| test_permlane_nrsmall_ldsskip_gfx1201 | RDNA4 w32 | NR-Small PermlaneX16Var LDS-skip | rocdl.permlanex16.var, no lds_barrier |
| test_permlane_nrlarge_gfx1201 | RDNA4 w32 | NR-Large PermlaneX16Var | rocdl.permlanex16.var, serial XOR butterfly |
| test_dsswizzle_nrsmall_ldsskip_gfx1100 | RDNA3 w32 | NR-Small DsSwizzle LDS-skip | rocdl.ds_swizzle, no lds_barrier |
| test_dsswizzle_nrlarge_gfx1100 | RDNA3 w32 | NR-Large DsSwizzle | rocdl.ds_swizzle, serial XOR butterfly |
| test_dsbpermute_nrsmall_ldsskip_sum_gfx942 | CDNA3 w64 | NR-Small DsSwizzle+Bpermute LDS-skip | rocdl.ds_bpermute, no lds_barrier |
| test_permlaneswap_nrsmall_ldsskip_sum_gfx950 | CDNA4 w64 | NR-Small PermlaneSwap LDS-skip | rocdl.permlane32.swap, no lds_barrier |
| test_dsbpermute_nrsmall_ldsskip_r4_sum_gfx942 | CDNA3 w64 | NR-Small partialR=4 | ds_swizzle + ds_bpermute, no lds_barrier |
| test_permlaneswap_nrsmall_ldsskip_r4_sum_gfx950 | CDNA4 w64 | NR-Small partialR=4 | permlane16.swap + permlane32.swap, no lds_barrier |
| test_dsbpermute_nrlarge_r2_gfx942 | CDNA3 w64 | NR-Large partialR=2 | rocdl.ds_bpermute, LDS broadcast |
| test_permlaneswap_nrlarge_r2_gfx950 | CDNA4 w64 | NR-Large partialR=2 | rocdl.permlane32.swap, LDS broadcast |
| test_dsbpermute_nrlarge_r4_gfx942 | CDNA3 w64 | NR-Large partialR=4 | ds_swizzle + ds_bpermute, LDS broadcast |
| test_permlaneswap_nrlarge_r4_gfx950 | CDNA4 w64 | NR-Large partialR=4 | permlane16.swap + permlane32.swap, LDS broadcast |
| test_dsbpermute_nrsmall_noldsskip_gfx942 | CDNA3 w64 | NR-Small no LDS-skip (K>1) | rocdl.ds_bpermute, HAS lds_barrier |
| test_permlaneswap_nrsmall_noldsskip_gfx950 | CDNA4 w64 | NR-Small no LDS-skip (K>1) | rocdl.permlane32.swap, HAS lds_barrier |
| test_ldstree_fallback_multiwave_gfx942 | CDNA3 w64 | LDS-tree fallback | No intrinsics, HAS lds_barrier |

Test Result

Both CI runs pass.

Submission Checklist

@stefankoncarevic stefankoncarevic changed the base branch from dpp-refactor-blockwise-reduce to develop April 27, 2026 11:44
@stefankoncarevic stefankoncarevic force-pushed the blockwise-reduce-crosslane branch 2 times, most recently from 71cac93 to ffd8579 Compare May 8, 2026 21:37
@stefankoncarevic stefankoncarevic force-pushed the blockwise-reduce-crosslane branch from cb4dba3 to e4c2001 Compare May 21, 2026 15:45
Introduce a register-only reduction path using v_permlanex16_var_b32
(GFX12+) for blockwise broadcast-reduce when partialR=2 on wave32
architectures. This avoids the initial LDS store + barrier by performing
the 2-way reduction directly in registers before writing to LDS.

…or Navi4x

Restructure the permlanex16_var reduction logic into two distinct paths
gated by 2D thread layout awareness (mTidPerWave/nTidPerWave):
- SerialPermlane (blockSize <= nrDimProd): XOR butterfly reduction in
  registers for power-of-2 rDimSize matching mTidPerWave. Uses LDS only
  for final broadcast.
- PR2-Permlane (blockSize > nrDimProd): register-only cross-half-wave
  reduction for partialR=2 when nTidPerWave=16 (lanes 0-15 <-> 16-31),
  avoiding the initial LDS store + barrier.
Both paths now require has2DThreadLayout and wave32. The PR2-Permlane
path is moved into the blockSize > nrDimProd branch alongside DPP and
LDS-Tree fallbacks. Clean up comments for brevity.

Introduce register-only cross-lane reduction paths for blockwise reductions
on gfx950 (wave64) and gfx12 / Navi4x (wave32), and skip the LDS round-trip
when single-wave coverage is sufficient.
NR-Large (blockSize <= nrDimProd):
- New permlaneSwapReduce helper using v_permlane{16,32}_swap_b32 for
  partner-group sizes 2 or 4 (gfx950 wave64).
- canUsePermlaneSwapReduce path skips LDS entirely when blockSize ==
  waveSize, broadcasting via the new readReducedResultsFromPrivateBuffer
  helper instead of going through LDS.
- canUseSerialPermlane (wave32) gains the same single-wave END LDS skip,
  mirroring the wave64 behaviour.
NR-Small (blockSize > nrDimProd):
- canUsePermlaneInDPP replaces the gpu.subgroup_reduce/DPP cross-lane step
  with permlaneSwapReduce on a scalar accumulator.
  END LDS round-trips are skipped: the input is read from
  partialReductionBuffer[0], and the broadcast goes through
  readReducedResultsFromPrivateBuffer.
- canUsePermlaneReduce (wave32) gains an analogous single-wave END LDS
  skip (canUsePermlaneReduceLdsSkip).
Architecture gating:
- The wave32 permlanex16_var paths now require the new hasPermlaneVar gate
  (gfx950 || gfx12). This prevents emitting v_permlanex16_var_b32 on gfx11
  / Navi3x, which only exposes the immediate-selector form. A future
  ds_swizzle path will cover gfx11.
Multi-wave configs, configs with extraOut, and unsupported architectures
fall back to the existing LDS round-trip path unchanged, so the new
fast-paths are purely additive.

…paths

Remove the blockSize == waveSize constraint from LDS-skip conditions
for NR-Large (SerialPermlane on gfx1201, PermlaneSwapReduce on gfx950)
and NR-Small (PR2-Permlane on gfx1201) reduction paths. partialR ==
mTidPerWave guarantees all reduction partners reside within the same
wave, making cross-wave LDS unnecessary even for multi-wave blocks.

Register-only blockwise reduction for gfx908 (MI100), gfx90a (MI250),
and gfx94x (MI300) using ds_swizzle_b32 (XOR within 32-lane halves)
and ds_bpermute_b32 (XOR 32, crossing the half-wave boundary).
NR-Large path: replaces LDS-Tree for partialR in {2,4} when
blockSize <= nrDimProd — fully LDS-free.
NR-Small path: replaces LDS+DPP for single-wave configs with
maxActiveReductionThreads in {2,4}. LDS-skip variant (K==1, no
extraOut) is fully register-only; otherwise only the cross-lane
step is optimized while initial LDS store remains.

ds_swizzle wave32 for gfx11, remove DPP LDS-skip, clean up tests

- Add widenTo32Bit/narrowFrom32Bit helpers so sub-32-bit types (f16)
  use the fast register-only path instead of falling back to LDS-tree.
  f16 values are widened to f32 before the i32 bitcast/intrinsic and
  narrowed back afterward.
- Remove is32BitElem gate from eligibility predicates — no longer
  needed with widening in place.
- Add ds_swizzle XOR=16 wave32 path for gfx11 (RDNA3/Navi3x) where
  v_permlanex16_var is not available.
- Remove faulty DPP LDS-skip optimization that caused incorrect reads
  on certain thread layouts.
- Add lowering_bcast_reduce_cross_lane.mlir test covering all
  register-only reduction paths across architectures.
- Add RDNA3 gfx1100 pipeline ordering test in pipelines.mlir.
@stefankoncarevic stefankoncarevic force-pushed the blockwise-reduce-crosslane branch from e4c2001 to 4eabdfd Compare May 21, 2026 15:45
@stefankoncarevic stefankoncarevic changed the title [DRAFT] Add register-only cross-lane reduction for attention Add register-only cross-lane reduction for attention May 21, 2026
@stefankoncarevic stefankoncarevic requested a review from Copilot May 21, 2026 15:48
Contributor

Copilot AI left a comment


Pull request overview

This PR adds register-only cross-lane reduction fast paths to rock.blockwise_broadcast_reduce lowering (replacing some LDS store/barrier/read sequences) for specific single-wave layouts relevant to attention-style reductions, leveraging ROCDL cross-lane primitives on multiple AMD GPU families.

Changes:

  • Implement multiple cross-lane reduction strategies (permlane swap, ds_swizzle+ds_bpermute, and wave32 ds_swizzle / permlanex16.var) with optional LDS end-to-end skip when safe.
  • Add extensive MLIR tests covering the new fast paths across architectures (gfx11/gfx94x/gfx950/gfx12xx) and add a non-power-of-2 NR-dim tid factoring regression test.
  • Extend ROCDL dialect definitions and update driver pipeline-dump tests for gfx1100.

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 1 comment.

| File | Description |
|---|---|
| mlir/test/rocmlir-driver/pipelines.mlir | Adds pipeline-dump coverage for -arch=gfx1100 (RDNA3). |
| mlir/test/Dialect/Rock/lowering_blockwise_broadcast_reduce.mlir | Adds coverage for NR-Small tid factoring when NR product is non-power-of-2. |
| mlir/test/Dialect/Rock/lowering_bcast_reduce_cross_lane.mlir | New test file validating cross-lane register-only reduction paths and LDS-skip behavior. |
| mlir/lib/Dialect/Rock/Transforms/BlockwiseGemmToThreadwise.cpp | Implements cross-lane reduction + register-only broadcast/skip logic in the lowering. |
| mlir/include/mlir/Dialect/Rock/Passes.td | Updates pass dependent dialect list for ROCDL usage. |
| external/llvm-project/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | Adds a new ROCDL op definition for permlanex16.var. |


Comment thread mlir/include/mlir/Dialect/Rock/Passes.td Outdated
…e reductions

Use the upstream amdgpu.swizzle_bitmode op instead of manually calling
rocdl.ds_swizzle with hand-rolled widening/bitcast/narrowing for
sub-32-bit types. The AMDGPU op handles automatic type decomposition
and recomposition via decomposeValue/composeValue, simplifying the code
and removing 8 lines of boilerplate per swizzle call site.
Changes:
- BlockwiseGemmToThreadwise.cpp: Replace ROCDL::DsSwizzleOp with
  amdgpu::SwizzleBitModeOp in dsSwizzleReduceStep, add AMDGPU dialect
  to legal dialects in conversion target
- CMakeLists.txt: Add MLIRAMDGPUDialect dependency
- Passes.td: Add amdgpu::AMDGPUDialect to dependentDialects
- Tests: Update CHECK patterns from rocdl.ds_swizzle to
  amdgpu.swizzle_bitmode