ck_tile: add FillUniformScaleDistribution and fix MX GEMM scale init#7724
Open
AviralGoelAMD wants to merge 11 commits into
Open
ck_tile: add FillUniformScaleDistribution and fix MX GEMM scale init#7724AviralGoelAMD wants to merge 11 commits into
AviralGoelAMD wants to merge 11 commits into
Conversation
Scale byte values were drawn from dist(40, 60), which for e8m0 gives 2^(byte-127) ≈ 10^-27 — far below FP16 min denorm. Both GPU and CPU produced all-zero C, so check_err passed vacuously without testing the GEMM computation. Use each scale type's bias to center the range near unity: e8m0 (bias=127): bytes 124-128 → scales 0.125..2.0 e4m3 (bias=7): bytes 4-8 → scales near 1.0 e5m3 (bias=15): bytes 12-16 → scales near 1.0 Also use a fixed seed for reproducibility and apply clang-format.
…y types The previous approach used numeric_traits::bias directly as the center for random raw byte generation. This only works for E8M0 where the entire byte is the exponent. For E5M3/E4M3, bias is the exponent field bias but the raw byte encodes both exponent and mantissa bits, so the exponent must be shifted left by the mantissa width. Also reverts unrelated formatting changes.
…ine test
Add FillUniformScaleDistribution<ScaleType> to fill.hpp. Maps human-readable
float bounds [min, max] to the raw byte range of ExMy scale types (e8m0, e4m3,
e5m3) by re-centering the IEEE 754 exponent into the type's own bias space,
then sampling uniformly over raw bytes. Upper bound is strict: no generated
value exceeds max_scale_.
Replace the open-coded raw-byte distribution in test_mx_gemm_pipeline_util.hpp
with FillUniformScaleDistribution{0.125f, 2.0f}. Scale-A and scale-B use
different seeds (11941 / 11943) so their values are uncorrelated.
Add 41 unit tests in test_fill.cpp covering e8m0_t, e4m3_t, and e5m3_t:
range clamping, power-of-two upper bound guarantee, seed reproducibility,
non-power-of-two bound snapping, empty range, stress coverage, and death on
inverted bounds.
- Replace left-shift of signed int with multiplication to avoid UB in C++17 when the pre-shift value is negative (e.g. small scale inputs with low-bias types) - Add assert(min_r <= max_r) to catch inverted-range UB from out-of-bounds inputs - Add default member values (0.125f, 2.0f) matching the convention of sibling fillers - Convert doc comment from /// to /** */ style consistent with the rest of fill.hpp, and add @note documenting the min-snapping asymmetry and seed default - Fix misleading raw=0 comment: describe e8m0 and e4m3/e5m3 cases separately - Add clamping to expected_raw_range() test helper to match implementation - Remove EXPECT_DEATH test (assert fires only in debug builds; prefer not testing UB) - Re-number tests after removal to keep sequence contiguous - Add tests for nullopt seed path and range overload - Delete debug-only dump_scale_fill.cpp (no CMake target, not built by CI) - Update run_mx_gemm.inc and run_mx_flatmm.inc to use FillUniformScaleDistribution instead of FillUniformDistribution for ExMy scale tensors
Remove run_mx_gemm.inc, run_mx_flatmm.inc, and test_mx_grouped_gemm_ut_cases.inc changes — these belong in separate PRs.
Switch the RangeOverloadFills* typed test from std::vector to ck_tile::HostTensor, which is the library's preferred host storage type.
Replace std::vector with ck_tile::HostTensor in all FillUniformScaleDistribution typed tests, which is the preferred CK host storage type. Index-based loops converted to range-based for loops since HostTensor does not provide operator[].
Add a comment above raw_min=1 in expected_raw_range to document that raw byte 0 is excluded on the same grounds as FillUniformScaleDistribution. Expand the comment before StrictFloatUpperBound (test 12) to explain how it differs from StrictFloatBounds (test 11): test 12 covers all ExMy types but only enforces the upper bound, because non-zero mantissa bits in e4m3/e5m3 allow values below min_scale (covered by test 13).
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Problem
MX GEMM pipeline tests were passing vacuously: scale bytes were drawn from a fixed range (40–60) which, for e8m0, maps to scales ≈ 10⁻²⁷ — far below FP16 min denorm. Both GPU and CPU produced all-zero outputs, so numerical checks passed without exercising the GEMM.
Changes
include/ck_tile/host/fill.hpp— newFillUniformScaleDistribution<ScaleType>functor<< mant_bitsto avoid shifting negative signed integers (C++17 UB)assert(min_r <= max_r)to catch inverted-range UB when both bounds exceed the type's representable rangestd::optionalseed consistent with sibling fillers/** */Doxygen style with@noteon snapping asymmetrytest/ck_tile/gemm_mx/test_mx_gemm_pipeline_util.hpp— fix scale initializationFillUniformScaleDistribution<>{0.125f, 2.0f}test/ck_tile/utility/test_fill.cpp— new unit tests forFillUniformScaleDistributionexpected_raw_rangemirrors implementation clamping exactly