[CK] Fix grouped conv bwd data stride>1 silent miscompute (ALMIOPEN-1959)#7732
Merged
JH-Leon-KIM-AMD merged 4 commits intoMay 27, 2026
Merged
Conversation
…959) The flat-descriptor fast path added in PR #6208 used only the first sub-GEMM (flat_*_container_[gemm_set_id]) and silently dropped the remaining sub-GEMMs when stride > dilation, producing zeroed dx slices on the (G=1, stride>1, 2D, NumDTensor=0) intersection. Gate the flat path to gemms_count_ == 1 so it only fires when there is exactly one sub-GEMM per dispatch (its original purpose). Stride>1 now routes through the existing grouped CShuffle path which packs all sub-GEMMs into one descriptor array and walks them on-device in a single launch -- correct, and 6-32% faster than per-sub-GEMM relaunches across the production shapes measured via ckProfiler (Retinanet, i2v, ALMIOPEN-1959). Add regression coverage for the previously-uncovered intersection with gemms_count in {4, 9, 36} so a future regression in the gate is caught by CI.
bartekxk
approved these changes
May 26, 2026
shumway
pushed a commit
to ROCm/composable_kernel
that referenced
this pull request
May 27, 2026
[CK] Fix grouped conv bwd data stride>1 silent miscompute (ALMIOPEN-1959) (#7732)
## Motivation
Fix silent miscompute in the grouped convolution backward-data kernel
(`DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1`) when stride >
dilation (ALMIOPEN-1959). PR #6208 introduced a flat-descriptor fast
path that dropped all but the first sub-GEMM, producing zeroed slices of
`dx` on
the (G=1, stride>1, 2D, NumDTensor=0) intersection. Restore correctness
without giving up the perf gains PR #6208 delivered on stride=1 shapes.
## Technical Details
- Tighten the flat-descriptor fast-path gate to require
`arg.gemms_count_ == 1` (i.e. a single sub-GEMM per dispatch — its
original purpose). For stride > 1, the implicit GEMM is split into
`gemms_count_` sub-GEMMs whose output cells tile `dx` disjointly;
routing them through the flat path required dropping all but the first,
which was the source of the bug.
- Stride > 1 now falls through to the existing grouped CShuffle path,
which packs all sub-GEMMs into one descriptor array and walks them
on-device in a single kernel launch. This is the pre-PR-6208 production
path; correctness is established and per-dispatch launch count is
minimised.
- Add regression coverage for the (G=1, stride>1, 2D, NumDTensor=0)
intersection in
`test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp` with
`gemms_count` ∈ {4, 9, 36}. Pre-existing cases did not hit this
intersection (all stride>1 cases used G=2; all G=1 cases used stride=1),
which is why PR #6208's regression slipped past CI.
## Test Plan
- `ctest -L SMOKE_TEST -R 'grouped_convnd_bwd_data'` on gfx942 (smoke
tier — runs on every PR via `smart_build_and_test.sh`).
- End-to-end verify (`verify=1`) via
`example_grouped_conv_bwd_data_xdl_fp16` on stride 1/2/3/6 shapes
including the original ALMIOPEN-1959 case and a cross-bucket
(`gemms_count=36`) case spanning two `MaxGroupedGemmGroupsNum=32`
buckets.
- ckProfiler A/B sweep on MI300X (gfx942) toggling the flat-path gate
via an environment variable: full kernel-family enumeration, winning
kernel + its avg_time reported under each gate. 33/41 shapes completed
before the sweep was stopped; the remaining 8 were the largest
i2v/synthetic shapes where ckProfiler exceeded its 300s per-shape
enumeration budget (not relevant to the verdict).
## Test Result
### Correctness
| Test | Result |
|---|:---:|
| `test_grouped_convnd_bwd_data` (12 type parameterizations × Test2D,
includes 3 new regression shapes) | **12/12 PASSED** in 14.18 s |
| `test_grouped_convnd_bwd_data_interface` (API checks) | **PASSED** in
0.28 s |
| ALMIOPEN-1959 stride=2 (`verify=1`) | **PASSED** |
| stride=1 K3 (`verify=1`) | **PASSED** |
| stride=3 K3 `gemms_count=9` (`verify=1`) | **PASSED** |
| stride=6 K6 `gemms_count=36` cross-bucket (`verify=1`) | **PASSED** |
### Performance (ckProfiler A/B on gfx942 / MI300X)
Comparing the **post-fix gate** (flat path only when `gemms_count_==1`,
column "B") vs the **inner-loop variant** that keeps the flat path on
stride>1 (column "A") across 25 stride>1 shapes where production picks
a `_v1` instance (so the gate actually fires):
| Stride | Shapes | A wins | Tie | B wins | Notes |
|:------:|:------:|:------:|:---:|:------:|---|
| 1 (sanity, gate moot) | 3 | 0 | 3 | 0 | gate doesn't differentiate — A
== B as expected |
| > 1 (gate fires) | 25 | **0** | 11 | **14** | B wins +6% to +32%; A
never wins |
Highlights from the firing-gate cases:
| Shape (G=1, stride=2 unless noted) | A ms | B ms | B vs A |
|---|---:|---:|---:|
| ALMIOPEN-1959 (N=16, K=256, C=128, 5×5, 40×175) | 0.183 | 0.171 | **B
+6%** |
| Retinanet-L61 (N=32, K=C=256, 3×3, 25×25) | 0.054 | 0.045 | **B +17%**
|
| i2v-010 (N=1, K=C=384, 3×3, 277×209) | 0.174 | 0.125 | **B +28%** |
| Synthetic 50×50 K3 N=32 K=C=256 | 0.131 | 0.088 | **B +32%** |
Why B wins everywhere the gate fires: for `gemms_count = N`, the flat
path needs N kernel launches (one per sub-GEMM), while the grouped path
loops over the same N sub-GEMMs on-device in 1 launch. The (N−1) ×
launch-tax is a structural disadvantage A can't recover from.
### Diff
| File | Lines |
|---|---:|
|
`include/.../device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp`
| +14 / −8 (one extra condition + expanded dispatch comment) |
| `test/.../test_grouped_convnd_bwd_data.cpp` | +9 / −0 (3 new shapes) |
## Submission Checklist
- [x] Look over the contributing guidelines at
https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
yenong-amd
pushed a commit
that referenced
this pull request
May 28, 2026
…959) (#7732) ## Motivation Fix silent miscompute in the grouped convolution backward-data kernel (`DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1`) when stride > dilation (ALMIOPEN-1959). PR #6208 introduced a flat-descriptor fast path that dropped all but the first sub-GEMM, producing zeroed slices of `dx` on the (G=1, stride>1, 2D, NumDTensor=0) intersection. Restore correctness without giving up the perf gains PR #6208 delivered on stride=1 shapes. ## Technical Details - Tighten the flat-descriptor fast-path gate to require `arg.gemms_count_ == 1` (i.e. a single sub-GEMM per dispatch — its original purpose). For stride > 1, the implicit GEMM is split into `gemms_count_` sub-GEMMs whose output cells tile `dx` disjointly; routing them through the flat path required dropping all but the first, which was the source of the bug. - Stride > 1 now falls through to the existing grouped CShuffle path, which packs all sub-GEMMs into one descriptor array and walks them on-device in a single kernel launch. This is the pre-PR-6208 production path; correctness is established and per-dispatch launch count is minimised. - Add regression coverage for the (G=1, stride>1, 2D, NumDTensor=0) intersection in `test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cpp` with `gemms_count` ∈ {4, 9, 36}. Pre-existing cases did not hit this intersection (all stride>1 cases used G=2; all G=1 cases used stride=1), which is why PR #6208's regression slipped past CI. ## Test Plan - `ctest -L SMOKE_TEST -R 'grouped_convnd_bwd_data'` on gfx942 (smoke tier — runs on every PR via `smart_build_and_test.sh`). - End-to-end verify (`verify=1`) via `example_grouped_conv_bwd_data_xdl_fp16` on stride 1/2/3/6 shapes including the original ALMIOPEN-1959 case and a cross-bucket (`gemms_count=36`) case spanning two `MaxGroupedGemmGroupsNum=32` buckets. - ckProfiler A/B sweep on MI300X (gfx942) toggling the flat-path gate via an environment variable: full kernel-family enumeration, winning kernel + its avg_time reported under each gate. 33/41 shapes completed before the sweep was stopped; the remaining 8 were the largest i2v/synthetic shapes where ckProfiler exceeded its 300s per-shape enumeration budget (not relevant to the verdict). ## Test Result ### Correctness | Test | Result | |---|:---:| | `test_grouped_convnd_bwd_data` (12 type parameterizations × Test2D, includes 3 new regression shapes) | **12/12 PASSED** in 14.18 s | | `test_grouped_convnd_bwd_data_interface` (API checks) | **PASSED** in 0.28 s | | ALMIOPEN-1959 stride=2 (`verify=1`) | **PASSED** | | stride=1 K3 (`verify=1`) | **PASSED** | | stride=3 K3 `gemms_count=9` (`verify=1`) | **PASSED** | | stride=6 K6 `gemms_count=36` cross-bucket (`verify=1`) | **PASSED** | ### Performance (ckProfiler A/B on gfx942 / MI300X) Comparing the **post-fix gate** (flat path only when `gemms_count_==1`, column "B") vs the **inner-loop variant** that keeps the flat path on stride>1 (column "A") across 25 stride>1 shapes where production picks a `_v1` instance (so the gate actually fires): | Stride | Shapes | A wins | Tie | B wins | Notes | |:------:|:------:|:------:|:---:|:------:|---| | 1 (sanity, gate moot) | 3 | 0 | 3 | 0 | gate doesn't differentiate — A == B as expected | | > 1 (gate fires) | 25 | **0** | 11 | **14** | B wins +6% to +32%; A never wins | Highlights from the firing-gate cases: | Shape (G=1, stride=2 unless noted) | A ms | B ms | B vs A | |---|---:|---:|---:| | ALMIOPEN-1959 (N=16, K=256, C=128, 5×5, 40×175) | 0.183 | 0.171 | **B +6%** | | Retinanet-L61 (N=32, K=C=256, 3×3, 25×25) | 0.054 | 0.045 | **B +17%** | | i2v-010 (N=1, K=C=384, 3×3, 277×209) | 0.174 | 0.125 | **B +28%** | | Synthetic 50×50 K3 N=32 K=C=256 | 0.131 | 0.088 | **B +32%** | Why B wins everywhere the gate fires: for `gemms_count = N`, the flat path needs N kernel launches (one per sub-GEMM), while the grouped path loops over the same N sub-GEMMs on-device in 1 launch. The (N−1) × launch-tax is a structural disadvantage A can't recover from. ### Diff | File | Lines | |---|---:| | `include/.../device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpp` | +14 / −8 (one extra condition + expanded dispatch comment) | | `test/.../test_grouped_convnd_bwd_data.cpp` | +9 / −0 (3 new shapes) | ## Submission Checklist - [x] Look over the contributing guidelines at https://github.com/ROCm/ROCm/blob/develop/CONTRIBUTING.md#pull-requests.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Motivation
Fix silent miscompute in the grouped convolution backward-data kernel (
DeviceGroupedConvBwdDataMultipleD_Xdl_CShuffle_v1) when stride > dilation (ALMIOPEN-1959). PR #6208 introduced a flat-descriptor fast path that dropped all but the first sub-GEMM, producing zeroed slices ofdxonthe (G=1, stride>1, 2D, NumDTensor=0) intersection. Restore correctness without giving up the perf gains PR #6208 delivered on stride=1 shapes.
Technical Details
arg.gemms_count_ == 1(i.e. a single sub-GEMM per dispatch — its original purpose). For stride > 1, the implicit GEMM is split intogemms_count_sub-GEMMs whose output cells tiledxdisjointly; routing them through the flat path required dropping all but the first, which was the source of the bug.test/grouped_convnd_bwd_data/test_grouped_convnd_bwd_data.cppwithgemms_count∈ {4, 9, 36}. Pre-existing cases did not hit this intersection (all stride>1 cases used G=2; all G=1 cases used stride=1), which is why PR [CK] Enable grouped conv bwd data to match non-grouped perf via NoShuffle + packed descriptors #6208's regression slipped past CI.Test Plan
ctest -L SMOKE_TEST -R 'grouped_convnd_bwd_data'on gfx942 (smoke tier — runs on every PR viasmart_build_and_test.sh).verify=1) viaexample_grouped_conv_bwd_data_xdl_fp16on stride 1/2/3/6 shapes including the original ALMIOPEN-1959 case and a cross-bucket (gemms_count=36) case spanning twoMaxGroupedGemmGroupsNum=32buckets.Test Result
Correctness
test_grouped_convnd_bwd_data(12 type parameterizations × Test2D, includes 3 new regression shapes)test_grouped_convnd_bwd_data_interface(API checks)verify=1)verify=1)gemms_count=9(verify=1)gemms_count=36cross-bucket (verify=1)Performance (ckProfiler A/B on gfx942 / MI300X)
Comparing the post-fix gate (flat path only when
gemms_count_==1,column "B") vs the inner-loop variant that keeps the flat path on
stride>1 (column "A") across 25 stride>1 shapes where production picks
a
_v1instance (so the gate actually fires):Highlights from the firing-gate cases:
Why B wins everywhere the gate fires: for
gemms_count = N, the flat path needs N kernel launches (one per sub-GEMM), while the grouped path loops over the same N sub-GEMMs on-device in 1 launch. The (N−1) × launch-tax is a structural disadvantage A can't recover from.Diff
include/.../device_grouped_conv_bwd_data_multiple_d_xdl_cshuffle_v1.hpptest/.../test_grouped_convnd_bwd_data.cppSubmission Checklist