Skip to content

[CK] Fix grouped conv bwd data stride>1 silent miscompute (ALMIOPEN-1959)#7732

Merged
JH-Leon-KIM-AMD merged 4 commits into
developfrom
users/jeongkim/ck/fix-grouped-conv-bwd-data-noshuffle
May 27, 2026
Merged

[CK] Fix grouped conv bwd data stride>1 silent miscompute (ALMIOPEN-1959)#7732
JH-Leon-KIM-AMD merged 4 commits into
developfrom
users/jeongkim/ck/fix-grouped-conv-bwd-data-noshuffle

Conversation

@JH-Leon-KIM-AMD
Copy link
Copy Markdown
Contributor

@JH-Leon-KIM-AMD JH-Leon-KIM-AMD commented May 25, 2026

Motivation

Fix silent miscompute in the grouped convolution backward-data kernel (DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1) when stride > dilation (ALMIOPEN-1959). PR #6208 introduced a flat-descriptor fast path that dropped all but the first sub-GEMM, producing zeroed slices of dx on
the (G=1, stride>1, 2D, NumDTensor=0) intersection. Restore correctness without giving up the perf gains PR #6208 delivered on stride=1 shapes.

Technical Details

  • Tighten the flat-descriptor fast-path gate to require arg.gemms_count_ == 1 (i.e. a single sub-GEMM per dispatch — its original purpose). For stride > 1, the implicit GEMM is split into gemms_count_ sub-GEMMs whose output cells tile dx disjointly; routing them through the flat path required dropping all but the first, which was the source of the bug.
  • Stride > 1 now falls through to the existing grouped CShuffle path, which packs all sub-GEMMs into one descriptor array and walks them on-device in a single kernel launch. This is the pre-PR-6208 production path; correctness is established and per-dispatch launch count is minimised.
  • Add regression coverage for the (G=1, stride>1, 2D, NumDTensor=0) intersection in test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp with gemms_count ∈ {4, 9, 36}. Pre-existing cases did not hit this intersection (all stride>1 cases used G=2; all G=1 cases used stride=1), which is why PR [CK] Enable grouped conv bwd data to match non-grouped perf via NoShuffle + packed descriptors #6208's regression slipped past CI.

Test Plan

  • ctest -L SMOKE_TEST -R 'grouped_convnd_bwd_data' on gfx942 (smoke tier — runs on every PR via smart_build_and_test.sh).
  • End-to-end verify (verify=1) via example_grouped_conv_bwd_data_xdl_fp16 on stride 1/2/3/6 shapes including the original ALMIOPEN-1959 case and a cross-bucket (gemms_count=36) case spanning two MaxGroupedGemmGroupsNum=32 buckets.
  • ckProfiler A/B sweep on MI300X (gfx942) toggling the flat-path gate via an environment variable: full kernel-family enumeration, winning kernel + its avg_time reported under each gate. 33/41 shapes completed before the sweep was stopped; the remaining 8 were the largest i2v/synthetic shapes where ckProfiler exceeded its 300s per-shape enumeration budget (not relevant to the verdict).

Test Result

Correctness

Test Result
test_grouped_convnd_bwd_data (12 type parameterizations × Test2D, includes 3 new regression shapes) 12/12 PASSED in 14.18 s
test_grouped_convnd_bwd_data_interface (API checks) PASSED in 0.28 s
ALMIOPEN-1959 stride=2 (verify=1) PASSED
stride=1 K3 (verify=1) PASSED
stride=3 K3 gemms_count=9 (verify=1) PASSED
stride=6 K6 gemms_count=36 cross-bucket (verify=1) PASSED

Performance (ckProfiler A/B on gfx942 / MI300X)

Comparing the post-fix gate (flat path only when gemms_count_==1,
column "B") vs the inner-loop variant that keeps the flat path on
stride>1 (column "A") across 25 stride>1 shapes where production picks
a _v1 instance (so the gate actually fires):

Stride Shapes A wins Tie B wins Notes
1 (sanity, gate moot) 3 0 3 0 gate doesn't differentiate — A == B as expected
> 1 (gate fires) 25 0 11 14 B wins +6% to +32%; A never wins

Highlights from the firing-gate cases:

Shape (G=1, stride=2 unless noted) A ms B ms B vs A
ALMIOPEN-1959 (N=16, K=256, C=128, 5×5, 40×175) 0.183 0.171 B +6%
Retinanet-L61 (N=32, K=C=256, 3×3, 25×25) 0.054 0.045 B +17%
i2v-010 (N=1, K=C=384, 3×3, 277×209) 0.174 0.125 B +28%
Synthetic 50×50 K3 N=32 K=C=256 0.131 0.088 B +32%

Why B wins everywhere the gate fires: for gemms_count = N, the flat path needs N kernel launches (one per sub-GEMM), while the grouped path loops over the same N sub-GEMMs on-device in 1 launch. The (N−1) × launch-tax is a structural disadvantage A can't recover from.

Diff

File Lines
include/.../device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp +14 / −8 (one extra condition + expanded dispatch comment)
test/.../test_grouped_convnd_bwd_data.cpp +9 / −0 (3 new shapes)

Submission Checklist

…959)

The flat-descriptor fast path added in PR #6208 used only the first
sub-GEMM (flat_*_container_[gemm_set_id]) and silently dropped the
remaining sub-GEMMs when stride > dilation, producing zeroed dx slices
on the (G=1, stride>1, 2D, NumDTensor=0) intersection.

Gate the flat path to gemms_count_ == 1 so it only fires when there is
exactly one sub-GEMM per dispatch (its original purpose). Stride>1 now
routes through the existing grouped CShuffle path which packs all
sub-GEMMs into one descriptor array and walks them on-device in a
single launch -- correct, and 6-32% faster than per-sub-GEMM relaunches
across the production shapes measured via ckProfiler (Retinanet, i2v,
ALMIOPEN-1959).

Add regression coverage for the previously-uncovered intersection with
gemms_count in {4, 9, 36} so a future regression in the gate is caught
by CI.
@JH-Leon-KIM-AMD JH-Leon-KIM-AMD marked this pull request as ready for review May 26, 2026 08:59
@JH-Leon-KIM-AMD JH-Leon-KIM-AMD requested a review from a team as a code owner May 26, 2026 08:59
@JH-Leon-KIM-AMD JH-Leon-KIM-AMD merged commit b0e29d9 into develop May 27, 2026
39 of 43 checks passed
@JH-Leon-KIM-AMD JH-Leon-KIM-AMD deleted the users/jeongkim/ck/fix-grouped-conv-bwd-data-noshuffle branch May 27, 2026 06:59
shumway pushed a commit to ROCm/composable_kernel that referenced this pull request May 27, 2026
[CK] Fix grouped conv bwd data stride>1 silent miscompute (ALMIOPEN-1959) (#7732)

## Motivation

Fix silent miscompute in the grouped convolution backward-data kernel
(`DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1`) when stride >
dilation (ALMIOPEN-1959). PR #6208 introduced a flat-descriptor fast
path that dropped all but the first sub-GEMM, producing zeroed slices of
`dx` on
the (G=1, stride>1, 2D, NumDTensor=0) intersection. Restore correctness
without giving up the perf gains PR #6208 delivered on stride=1 shapes.

## Technical Details

- Tighten the flat-descriptor fast-path gate to require
`arg.gemms_count_ == 1` (i.e. a single sub-GEMM per dispatch — its
original purpose). For stride > 1, the implicit GEMM is split into
`gemms_count_` sub-GEMMs whose output cells tile `dx` disjointly;
routing them through the flat path required dropping all but the first,
which was the source of the bug.
- Stride > 1 now falls through to the existing grouped CShuffle path,
which packs all sub-GEMMs into one descriptor array and walks them
on-device in a single kernel launch. This is the pre-PR-6208 production
path; correctness is established and per-dispatch launch count is
minimised.
- Add regression coverage for the (G=1, stride>1, 2D, NumDTensor=0)
intersection in
`test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp` with
`gemms_count` ∈ {4, 9, 36}. Pre-existing cases did not hit this
intersection (all stride>1 cases used G=2; all G=1 cases used stride=1),
which is why PR #6208's regression slipped past CI.

## Test Plan

- `ctest -L SMOKE_TEST -R 'grouped_convnd_bwd_data'` on gfx942 (smoke
tier — runs on every PR via `smart_build_and_test.sh`).
- End-to-end verify (`verify=1`) via
`example_grouped_conv_bwd_data_xdl_fp16` on stride 1/2/3/6 shapes
including the original ALMIOPEN-1959 case and a cross-bucket
(`gemms_count=36`) case spanning two `MaxGroupedGemmGroupsNum=32`
buckets.
- ckProfiler A/B sweep on MI300X (gfx942) toggling the flat-path gate
via an environment variable: full kernel-family enumeration, winning
kernel + its avg_time reported under each gate. 33/41 shapes completed
before the sweep was stopped; the remaining 8 were the largest
i2v/synthetic shapes where ckProfiler exceeded its 300s per-shape
enumeration budget (not relevant to the verdict).

## Test Result

### Correctness

| Test | Result |
|---|:---:|
| `test_grouped_convnd_bwd_data` (12 type parameterizations × Test2D,
includes 3 new regression shapes) | **12/12 PASSED** in 14.18 s |
| `test_grouped_convnd_bwd_data_interface` (API checks) | **PASSED** in
0.28 s |
| ALMIOPEN-1959 stride=2 (`verify=1`) | **PASSED** |
| stride=1 K3 (`verify=1`) | **PASSED** |
| stride=3 K3 `gemms_count=9` (`verify=1`) | **PASSED** |
| stride=6 K6 `gemms_count=36` cross-bucket (`verify=1`) | **PASSED** |

### Performance (ckProfiler A/B on gfx942 / MI300X)

Comparing the **post-fix gate** (flat path only when `gemms_count_==1`,
column "B") vs the **inner-loop variant** that keeps the flat path on
stride>1 (column "A") across 25 stride>1 shapes where production picks
a `_v1` instance (so the gate actually fires):

| Stride | Shapes | A wins | Tie | B wins | Notes |
|:------:|:------:|:------:|:---:|:------:|---|
| 1 (sanity, gate moot) | 3 | 0 | 3 | 0 | gate doesn't differentiate — A
== B as expected |
| > 1 (gate fires) | 25 | **0** | 11 | **14** | B wins +6% to +32%; A
never wins |

Highlights from the firing-gate cases:

| Shape (G=1, stride=2 unless noted) | A ms | B ms | B vs A |
|---|---:|---:|---:|
| ALMIOPEN-1959 (N=16, K=256, C=128, 5×5, 40×175) | 0.183 | 0.171 | **B
+6%** |
| Retinanet-L61 (N=32, K=C=256, 3×3, 25×25) | 0.054 | 0.045 | **B +17%**
|
| i2v-010 (N=1, K=C=384, 3×3, 277×209) | 0.174 | 0.125 | **B +28%** |
| Synthetic 50×50 K3 N=32 K=C=256 | 0.131 | 0.088 | **B +32%** |

Why B wins everywhere the gate fires: for `gemms_count = N`, the flat
path needs N kernel launches (one per sub-GEMM), while the grouped path
loops over the same N sub-GEMMs on-device in 1 launch. The (N−1) ×
launch-tax is a structural disadvantage A can't recover from.

### Diff

| File | Lines |
|---|---:|
|
`include/.../device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp`
| +14 / −8 (one extra condition + expanded dispatch comment) |
| `test/.../test_grouped_convnd_bwd_data.cpp` | +9 / −0 (3 new shapes) |

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
yenong-amd pushed a commit that referenced this pull request May 28, 2026
…959) (#7732)

## Motivation

Fix silent miscompute in the grouped convolution backward-data kernel
(`DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1`) when stride >
dilation (ALMIOPEN-1959). PR #6208 introduced a flat-descriptor fast
path that dropped all but the first sub-GEMM, producing zeroed slices of
`dx` on
the (G=1, stride>1, 2D, NumDTensor=0) intersection. Restore correctness
without giving up the perf gains PR #6208 delivered on stride=1 shapes.

## Technical Details

- Tighten the flat-descriptor fast-path gate to require
`arg.gemms_count_ == 1` (i.e. a single sub-GEMM per dispatch — its
original purpose). For stride > 1, the implicit GEMM is split into
`gemms_count_` sub-GEMMs whose output cells tile `dx` disjointly;
routing them through the flat path required dropping all but the first,
which was the source of the bug.
- Stride > 1 now falls through to the existing grouped CShuffle path,
which packs all sub-GEMMs into one descriptor array and walks them
on-device in a single kernel launch. This is the pre-PR-6208 production
path; correctness is established and per-dispatch launch count is
minimised.
- Add regression coverage for the (G=1, stride>1, 2D, NumDTensor=0)
intersection in
`test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp` with
`gemms_count` ∈ {4, 9, 36}. Pre-existing cases did not hit this
intersection (all stride>1 cases used G=2; all G=1 cases used stride=1),
which is why PR #6208's regression slipped past CI.

## Test Plan

- `ctest -L SMOKE_TEST -R 'grouped_convnd_bwd_data'` on gfx942 (smoke
tier — runs on every PR via `smart_build_and_test.sh`).
- End-to-end verify (`verify=1`) via
`example_grouped_conv_bwd_data_xdl_fp16` on stride 1/2/3/6 shapes
including the original ALMIOPEN-1959 case and a cross-bucket
(`gemms_count=36`) case spanning two `MaxGroupedGemmGroupsNum=32`
buckets.
- ckProfiler A/B sweep on MI300X (gfx942) toggling the flat-path gate
via an environment variable: full kernel-family enumeration, winning
kernel + its avg_time reported under each gate. 33/41 shapes completed
before the sweep was stopped; the remaining 8 were the largest
i2v/synthetic shapes where ckProfiler exceeded its 300s per-shape
enumeration budget (not relevant to the verdict).

## Test Result

### Correctness

| Test | Result |
|---|:---:|
| `test_grouped_convnd_bwd_data` (12 type parameterizations × Test2D,
includes 3 new regression shapes) | **12/12 PASSED** in 14.18 s |
| `test_grouped_convnd_bwd_data_interface` (API checks) | **PASSED** in
0.28 s |
| ALMIOPEN-1959 stride=2 (`verify=1`) | **PASSED** |
| stride=1 K3 (`verify=1`) | **PASSED** |
| stride=3 K3 `gemms_count=9` (`verify=1`) | **PASSED** |
| stride=6 K6 `gemms_count=36` cross-bucket (`verify=1`) | **PASSED** |

### Performance (ckProfiler A/B on gfx942 / MI300X)

Comparing the **post-fix gate** (flat path only when `gemms_count_==1`,
column "B") vs the **inner-loop variant** that keeps the flat path on
stride>1 (column "A") across 25 stride>1 shapes where production picks
a `_v1` instance (so the gate actually fires):

| Stride | Shapes | A wins | Tie | B wins | Notes |
|:------:|:------:|:------:|:---:|:------:|---|
| 1 (sanity, gate moot) | 3 | 0 | 3 | 0 | gate doesn't differentiate — A
== B as expected |
| > 1 (gate fires) | 25 | **0** | 11 | **14** | B wins +6% to +32%; A
never wins |

Highlights from the firing-gate cases:

| Shape (G=1, stride=2 unless noted) | A ms | B ms | B vs A |
|---|---:|---:|---:|
| ALMIOPEN-1959 (N=16, K=256, C=128, 5×5, 40×175) | 0.183 | 0.171 | **B
+6%** |
| Retinanet-L61 (N=32, K=C=256, 3×3, 25×25) | 0.054 | 0.045 | **B +17%**
|
| i2v-010 (N=1, K=C=384, 3×3, 277×209) | 0.174 | 0.125 | **B +28%** |
| Synthetic 50×50 K3 N=32 K=C=256 | 0.131 | 0.088 | **B +32%** |

Why B wins everywhere the gate fires: for `gemms_count = N`, the flat
path needs N kernel launches (one per sub-GEMM), while the grouped path
loops over the same N sub-GEMMs on-device in 1 launch. The (N−1) ×
launch-tax is a structural disadvantage A can't recover from.

### Diff

| File | Lines |
|---|---:|
|
`include/.../device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp`
| +14 / −8 (one extra condition + expanded dispatch comment) |
| `test/.../test_grouped_convnd_bwd_data.cpp` | +9 / −0 (3 new shapes) |

## Submission Checklist

- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants