
ck_tile: add FillUniformScaleDistribution and fix MX GEMM scale init #7724

Open
AviralGoelAMD wants to merge 11 commits into develop from users/avirgoel/ck/fix-mxgemm-scale-init

@AviralGoelAMD AviralGoelAMD commented May 24, 2026

Summary

Problem

MX GEMM pipeline tests were passing vacuously: scale bytes were drawn from a fixed range (40–60), which for e8m0 maps to scales between 2⁻⁸⁷ ≈ 6.5×10⁻²⁷ and 2⁻⁶⁷ ≈ 6.8×10⁻²¹, far below the FP16 minimum subnormal (2⁻²⁴ ≈ 6×10⁻⁸). Both GPU and CPU produced all-zero outputs, so numerical checks passed without exercising the GEMM.
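The arithmetic behind this can be checked with a small standalone sketch. It assumes the usual e8m0 decoding of 2^(byte − 127) and is not taken from ck_tile; `decode_e8m0`, `fp16_min_denorm`, and `range_underflows_fp16` are illustrative names only.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Decode an e8m0 scale byte: 8 exponent bits, no mantissa, bias 127.
// Raw byte 255 is reserved for NaN in the OCP Microscaling format and is
// rejected here rather than decoded.
inline double decode_e8m0(std::uint8_t raw)
{
    assert(raw != 255);
    return std::ldexp(1.0, static_cast<int>(raw) - 127);
}

// Smallest positive FP16 subnormal: 2^-24.
inline double fp16_min_denorm() { return std::ldexp(1.0, -24); }

// True if every scale byte in [lo, hi] decodes below the FP16 subnormal floor.
// Decoding is monotonic in the raw byte, so checking the upper end suffices.
inline bool range_underflows_fp16(std::uint8_t lo, std::uint8_t hi)
{
    assert(lo <= hi);
    return decode_e8m0(hi) < fp16_min_denorm();
}
```

Under these assumptions the old range (40, 60) underflows in its entirety, while bytes 124 through 128 decode to 0.125 through 2.0.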

Changes

include/ck_tile/host/fill.hpp — new FillUniformScaleDistribution<ScaleType> functor

  • Accepts human-readable float bounds and maps them to the raw byte range of any ExMy scale type (e8m0, e4m3, e5m3) by re-centering the IEEE 754 exponent into the type's bias space
  • Sampling is uniform over raw bytes → uniform over representable values
  • Fixes left-shift UB: uses multiplication instead of << mant_bits to avoid shifting negative signed integers (C++17 UB)
  • Adds assert(min_r <= max_r) to catch inverted-range UB when both bounds exceed the type's representable range
  • Provides default member values (0.125f, 2.0f) and std::optional seed consistent with sibling fillers
  • Doc comment uses /** */ Doxygen style, with an @note on the min-snapping asymmetry
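As a rough illustration of the mapping described in the bullets above, here is a minimal sketch, not the code in fill.hpp. `ScaleFormat`, `float_to_raw`, and `fill_scales` are hypothetical names. The sketch assumes an unsigned layout with the exponent field above the mantissa field, snaps each bound down to a power of two, excludes raw byte 0, and does not treat NaN or other reserved encodings specially.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Hypothetical descriptor for an unsigned ExMy scale format (not a ck_tile type).
struct ScaleFormat
{
    int exp_bits;
    int mant_bits;
    int bias;
};

constexpr ScaleFormat kE8M0{8, 0, 127};
constexpr ScaleFormat kE4M3{4, 3, 7};
constexpr ScaleFormat kE5M3{5, 3, 15};

// Map a positive float bound to the raw byte of the power of two at or below it.
inline int float_to_raw(float value, ScaleFormat fmt)
{
    assert(std::isfinite(value) && value > 0.0f);
    const int unbiased = std::ilogb(value);    // unbiased binary exponent of the float
    const int biased   = unbiased + fmt.bias;  // re-centered into the target bias space
    // Multiply instead of shifting: `biased` can be negative for tiny inputs,
    // and left-shifting a negative signed int is undefined behavior in C++17.
    const int raw     = biased * (1 << fmt.mant_bits);
    const int raw_max = (1 << (fmt.exp_bits + fmt.mant_bits)) - 1;
    return std::clamp(raw, 1, raw_max);        // raw byte 0 is excluded
}

// Fill n scale bytes uniformly over the raw range implied by [min_scale, max_scale].
inline std::vector<std::uint8_t> fill_scales(
    std::size_t n, float min_scale, float max_scale, ScaleFormat fmt, std::uint32_t seed)
{
    const int min_r = float_to_raw(min_scale, fmt);
    const int max_r = float_to_raw(max_scale, fmt);
    assert(min_r <= max_r); // an inverted range would be UB for the distribution
    std::mt19937 gen(seed);
    std::uniform_int_distribution<int> dist(min_r, max_r);
    std::vector<std::uint8_t> out(n);
    for(auto& b : out)
        b = static_cast<std::uint8_t>(dist(gen));
    return out;
}
```

With bounds (0.125f, 2.0f) this sketch yields raw ranges 124..128 for e8m0, 32..64 for e4m3, and 96..128 for e5m3. The real functor may snap and clamp differently.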

test/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp — fix scale initialization

  • Replace manual byte-range distribution with FillUniformScaleDistribution<>{0.125f, 2.0f}
  • Use distinct seeds for scale_a (11941) and scale_b (11943) to avoid correlated scale tensors that were causing 60 test failures for fp4+e5m3/e4m3 combinations
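The seed issue can be reproduced outside ck_tile with a short sketch. It assumes a `std::mt19937` generator and a hypothetical helper `sample_bytes`, and the filler's actual engine may differ. Two fills that share a seed and a range produce element-wise identical tensors, which is the correlation the distinct seeds avoid.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <random>
#include <vector>

// Draw n raw scale bytes uniformly from [lo, hi] using a seeded generator.
// Hypothetical helper for illustration; not part of ck_tile.
inline std::vector<std::uint8_t> sample_bytes(std::uint32_t seed, std::size_t n, int lo, int hi)
{
    assert(0 <= lo && lo <= hi && hi <= 255);
    std::mt19937 gen(seed);
    std::uniform_int_distribution<int> dist(lo, hi);
    std::vector<std::uint8_t> out(n);
    for(auto& b : out)
        b = static_cast<std::uint8_t>(dist(gen));
    return out;
}
```

Calling `sample_bytes(11941, n, lo, hi)` twice gives the same vector both times, while seeding the second call with 11943 gives an independent stream.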

test/ck_tile/utility/test_fill.cpp — new unit tests for FillUniformScaleDistribution

  • 16 typed tests across e8m0, e4m3, e5m3: validity, range, reproducibility, coverage, snapping, stress, nullopt seed, and range overload
  • Test helper expected_raw_range mirrors implementation clamping exactly

Scale byte values were drawn from dist(40, 60), which for e8m0
gives 2^(byte-127) ≈ 10^-27 — far below FP16 min denorm. Both
GPU and CPU produced all-zero C, so check_err passed vacuously
without testing the GEMM computation.

Use each scale type's bias to center the range near unity:
  e8m0 (bias=127): bytes 124-128 → scales 0.125..2.0
  e4m3 (bias=7):   bytes 4-8     → scales near 1.0
  e5m3 (bias=15):  bytes 12-16   → scales near 1.0

Also use a fixed seed for reproducibility and apply clang-format.
…y types

The previous approach used numeric_traits::bias directly as the center
for random raw byte generation. This only works for E8M0 where the
entire byte is the exponent. For E5M3/E4M3, bias is the exponent field
bias but the raw byte encodes both exponent and mantissa bits, so the
exponent must be shifted left by the mantissa width.
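A small decoder makes the difference concrete. This is a sketch under assumptions: an unsigned [exponent | mantissa] layout, IEEE-style subnormals when the exponent field is zero, and no special handling of NaN or other reserved encodings. `decode_exmy` is an illustrative name, not a ck_tile function.

```cpp
#include <cassert>
#include <cmath>

// Decode a raw scale byte for a hypothetical unsigned ExMy layout.
inline double decode_exmy(int raw, int mant_bits, int bias)
{
    assert(raw >= 0 && raw <= 255 && mant_bits >= 0);
    if(mant_bits == 0) // e8m0: the whole byte is the exponent, raw 0 is 2^-bias
        return std::ldexp(1.0, raw - bias);
    const int exp_field = raw >> mant_bits;          // safe: raw is non-negative
    const int mant      = raw & ((1 << mant_bits) - 1);
    const double frac   = static_cast<double>(mant) / (1 << mant_bits);
    if(exp_field == 0) // subnormal: no implicit leading one
        return std::ldexp(frac, 1 - bias);
    return std::ldexp(1.0 + frac, exp_field - bias);
}
```

Under these assumptions, using the bias itself as the raw byte for e4m3 (raw 7) sets only mantissa bits and decodes to 0.875 × 2⁻⁶, whereas bias × 2³ (raw 56) decodes to 1.0. For e8m0 the two coincide, which is why the earlier approach only worked there.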

Also reverts unrelated formatting changes.
…ine test

Add FillUniformScaleDistribution<ScaleType> to fill.hpp. Maps human-readable
float bounds [min, max] to the raw byte range of ExMy scale types (e8m0, e4m3,
e5m3) by re-centering the IEEE 754 exponent into the type's own bias space,
then sampling uniformly over raw bytes. Upper bound is strict: no generated
value exceeds max_scale_.

Replace the open-coded raw-byte distribution in test_mx_gemm_pipeline_util.hpp
with FillUniformScaleDistribution{0.125f, 2.0f}. Scale-A and scale-B use
different seeds (11941 / 11943) so their values are uncorrelated.

Add 41 unit tests in test_fill.cpp covering e8m0_t, e4m3_t, and e5m3_t:
range clamping, power-of-two upper bound guarantee, seed reproducibility,
non-power-of-two bound snapping, empty range, stress coverage, and death on
inverted bounds.

- Replace left-shift of signed int with multiplication to avoid UB in C++17
  when the pre-shift value is negative (e.g. small scale inputs with low-bias types)
- Add assert(min_r <= max_r) to catch inverted-range UB from out-of-bounds inputs
- Add default member values (0.125f, 2.0f) matching the convention of sibling fillers
- Convert doc comment from /// to /** */ style consistent with the rest of fill.hpp,
  and add @note documenting the min-snapping asymmetry and seed default
- Fix misleading raw=0 comment: describe e8m0 and e4m3/e5m3 cases separately
- Add clamping to expected_raw_range() test helper to match implementation
- Remove EXPECT_DEATH test (assert fires only in debug builds; prefer not testing UB)
- Re-number tests after removal to keep sequence contiguous
- Add tests for nullopt seed path and range overload
- Delete debug-only dump_scale_fill.cpp (no CMake target, not built by CI)
- Update run_mx_gemm.inc and run_mx_flatmm.inc to use FillUniformScaleDistribution
  instead of FillUniformDistribution for ExMy scale tensors

Remove run_mx_gemm.inc, run_mx_flatmm.inc, and test_mx_grouped_gemm_ut_cases.inc
changes — these belong in separate PRs.

Switch the RangeOverloadFills* typed test from std::vector to
ck_tile::HostTensor, which is the library's preferred host storage type.

Replace std::vector with ck_tile::HostTensor in all FillUniformScaleDistribution
typed tests, which is the preferred CK host storage type. Index-based loops
converted to range-based for loops since HostTensor does not provide operator[].

Add a comment above raw_min=1 in expected_raw_range to document that
raw byte 0 is excluded on the same grounds as FillUniformScaleDistribution.

Expand the comment before StrictFloatUpperBound (test 12) to explain
how it differs from StrictFloatBounds (test 11): test 12 covers all
ExMy types but only enforces the upper bound, because non-zero mantissa
bits in e4m3/e5m3 allow values below min_scale (covered by test 13).