[hipblaslt][tensilelite] Validate MX scale-format combinations for gfx1250 #7768

Closed
jaopaulolc wants to merge 5 commits into develop from users/jolabega/tensilelite-guard-invalid-mx-scale-fmt-combos

@jaopaulolc
Contributor

@jaopaulolc jaopaulolc commented May 26, 2026

Summary

The AMDGPU assembler currently accepts invalid (matrix_fmt, matrix_scale_fmt) tuples on gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4, silently emitting illegal encodings (ROCm/llvm-project#2634). Out of the 225 possible joint (A, scaleA, B, scaleB) tuples, only 43 are legal; the assembler accepts all 225 today.

This PR guards every host path that can build a kernel for one of those illegal tuples, so the failure surfaces as a clean validation error long before any incorrect codegen happens.

Rules enforced (per the ISA spec):

  • FP8 / BF8 / FP6 / BF6 matrix class must pair with E8 (UE8M0) scale.
  • FP4 matrix class accepts E8, E5M3, or E4M3 scale.
  • When both A and B are FP4, the two scale formats must match.
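
The three rules are enough to reproduce the 43-of-225 count quoted in the summary. The following is a minimal sketch, not the validator from this PR: the string spellings and the names `side_ok` and `combo_ok` are assumptions made for illustration.

```python
from itertools import product

MATRIX_CLASSES = ("FP8", "BF8", "FP6", "BF6", "FP4")
SCALE_FORMATS = ("E8", "E5M3", "E4M3")


def side_ok(matrix: str, scale: str) -> bool:
    """Per-side rule: FP4 takes any of the three scales, everything else needs E8."""
    if matrix == "FP4":
        return scale in SCALE_FORMATS
    return scale == "E8"


def combo_ok(a: str, a_scale: str, b: str, b_scale: str) -> bool:
    """Joint rule: both sides pass, and FP4 x FP4 additionally needs matching scales."""
    if not (side_ok(a, a_scale) and side_ok(b, b_scale)):
        return False
    if a == "FP4" and b == "FP4":
        return a_scale == b_scale
    return True


# Enumerate every (A, scaleA, B, scaleB) tuple and keep the legal ones.
all_tuples = list(product(MATRIX_CLASSES, SCALE_FORMATS, MATRIX_CLASSES, SCALE_FORMATS))
legal = [t for t in all_tuples if combo_ok(*t)]
print(len(all_tuples), len(legal))  # 225 43
```

The count breaks down as 16 tuples with no FP4 side, 24 with exactly one FP4 side, and 3 FP4 x FP4 tuples with matching scales.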

Changes

Foundation (new)

  • tensilelite/include/Tensile/MXScaleFormatValidation.hpp - TensileLite::isValidMXScaleFormatCombination, mxScaleFormatCombinationError, etc. Single source of truth for the rule set.
  • tensilelite/src/MXScaleFormatValidation.cpp — error-string formatting (matching wording in C++ and Python).

Tensilelite C++

  • include/Tensile/ContractionProblem.hpp, src/ContractionProblem.cpp - ContractionProblemGemm::validateMXScaleFormats() (throws std::runtime_error) and isValidMXScaleFormats() (bool). Invoked from setMXScaleA, setMXScaleB, and consistencyCheck so invalid combos throw at the point they are introduced.
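
The "throw at the point they are introduced" behavior can be illustrated with a Python analog. This is a rough sketch only: `ToyMXProblem` and its methods are hypothetical, and treating a not-yet-set side as unable to violate a rule is an assumption about how partial state could be handled, not something the PR states.

```python
from typing import Optional

FP4_SCALES = ("E8", "E5M3", "E4M3")


class ToyMXProblem:
    """Toy stand-in for a problem object; not the ContractionProblemGemm API."""

    def __init__(self, a_type: str, b_type: str) -> None:
        self.a_type, self.b_type = a_type, b_type
        self.a_scale: Optional[str] = None  # None: scale not set yet
        self.b_scale: Optional[str] = None

    def is_valid_mx_scale_formats(self) -> bool:
        for matrix, scale in ((self.a_type, self.a_scale), (self.b_type, self.b_scale)):
            if scale is None:
                continue  # assumption: an unset side cannot violate a rule yet
            allowed = FP4_SCALES if matrix == "FP4" else ("E8",)
            if scale not in allowed:
                return False
        if self.a_type == self.b_type == "FP4" and None not in (self.a_scale, self.b_scale):
            return self.a_scale == self.b_scale
        return True

    def validate_mx_scale_formats(self) -> None:
        if not self.is_valid_mx_scale_formats():
            raise RuntimeError(
                f"invalid MX scale formats: A={self.a_type}/{self.a_scale}, "
                f"B={self.b_type}/{self.b_scale}"
            )

    def set_mx_scale_a(self, scale: str) -> None:
        self.a_scale = scale
        self.validate_mx_scale_formats()  # fail as soon as the bad value arrives

    def set_mx_scale_b(self, scale: str) -> None:
        self.b_scale = scale
        self.validate_mx_scale_formats()


problem = ToyMXProblem("FP4", "FP4")
problem.set_mx_scale_a("E5M3")      # accepted: B's scale is not set yet
try:
    problem.set_mx_scale_b("E4M3")  # FP4 x FP4 with mismatched scales
except RuntimeError as err:
    print("rejected:", err)
```

Pairing a throwing method with a boolean one lets callers choose between an exception at the setter and a quiet query when filtering candidates.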

hipBLASLt host

  • rocblaslt_mat_utils.hpp - rocblasltScalingFormatToMXScaleDataType, rocblasltHipDataTypeToMXMatrixDataType, and validateMXScaleFormatCombination helpers. Wired into rocblaslt_matmul_valid_args so the API returns rocblaslt_status_invalid_value (instead of building a broken kernel).
  • tensile_host.cpp — defensive try/catch around tensileProblem.validateMXScaleFormats() after each ScalingFormat -> setMXScale* switch (both ConstructTensileProblem and updateTensileProblem).

Tensilelite Python

  • Tensile/Common/MXScaleFormatValidation.py — Python mirror of the validator. Error wording is byte-for-byte identical to the C++ implementation.
  • Tensile/SolutionStructs/Problem.py — new ProblemType._checkMXScaleFormatCombination(), called between MX-scale defaulting and _checkIfSupportedGEMMType(). Validates against MacDataTypeA/B (the MAC compute-input type, matching how the ISA constraint applies).

Tests (new)

  • tests/MXScaleFormatValidation_test.cpp — GoogleTest covering all 43 valid combinations from table-valid-combinations.txt plus curated invalid samples, classification helpers, error-string contract, and end-to-end integration on real ContractionProblemGemm instances. Registered in tests/CMakeLists.txt. Uses targeted using-declarations (not a broad using namespace TensileLite;) to avoid colliding with the TensileLite::E8 / TensileLite::E5M3 struct types.
  • Tensile/Tests/unit/test_MXScaleFormatValidation.py — 193 pytest cases mirroring the C++ matrix, including ProblemType construction (accept / reject / no-MX-skip) and a YAML-spelling round-trip guard.

Audit of existing tests/configs

Audited every gfx1250-targeted YAML under Tensile/Tests/common/{gemm,streamk}/gfx12* (15 files, 37 problem-type blocks), every C++ MX test fixture, and every hipBLASLt sample/client helper. Zero invalid combinations were found — all existing configs already obey the rules. No test data fixes were required.

The audit confirmed three correctness notes that this PR depends on:

  1. f8 as a YAML scale spelling means E4M3 (the C++ enum reuses rocisa::DataType::Float8 for the E4M3 byte scale; matches tensile_host.cpp's Block_32_UE4M3 -> Float8 mapping).
  2. _FNUZ matrix dtypes are correctly treated as non-MX and bypass the gfx1250 rules.
  3. The FP4×FP4 fixtures in mxf4_gfx1250.yaml already use matching scales.

Test plan

All items have been executed in this dev environment (ROCm 7.14 at /opt/rocm, FFM/mi450 GPU simulator).

  • Built the C++ gtest binary with the superbuild flags mirroring ~/bin/rocbuild (with GPU_TARGETS=gfx1250 and TENSILELITE_BUILD_TESTING=ON), then ran the validator suites:

    ./tensilelite-tests --gtest_filter='*MXCombinationTest*:*MXPerSideValidTest*:MXClassification.*:MXMixedClass.*:MXErrorString.*:MXProblemIntegration.*:IsMXProblem.*'
    

    Result: 124 passed, 0 failed across 13 test suites — every valid (43) and invalid (curated) combination, the FP4×FP4 mismatch rule, the error-string format contract, the classification helpers, and end-to-end ContractionProblemGemm integration.

  • pytest projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py -v -> 193 passed, 0 failed.

  • Existing Tensile/Tests/unit suite (Python portion of the regression check) — 841 passed, 6 skipped, 0 failed.

  • Loaded a YAML config with an invalid combo (MacDataTypeA: F4, MacDataTypeB: F4, DataTypeMXSA: E5M3, DataTypeMXSB: f8) into ProblemType(...) and confirmed it raises with: Invalid MX scale-format combination (A=FP4, AScale=E5M3, B=FP4, BScale=E4M3): FP4 x FP4 requires AScale (E5M3) == BScale (E4M3); see table-valid-combinations.txt / ROCm/llvm-project#2634.

  • Built libhipblaslt.so (against FFM/mi450 because no real gfx1250 GPU is reachable here; the validator is architecture-independent host C++ so the test path is identical) and ran a smoke test that calls hipblasLtMatmul against four invalid (A type, A scale mode, B type, B scale mode) tuples. All four returned HIPBLAS_STATUS_INVALID_VALUE (3):

    [case 1] FP4 x FP4, A=VEC32_UE5M3_EXT, B=VEC32_UE4M3_EXT (mismatched FP4 scales)   -> 3 OK
    [case 2] FP8 x FP8, A=VEC32_UE5M3_EXT, B=VEC32_UE8M0       (FP8 must use E8)       -> 3 OK
    [case 3] BF8 x BF8, A=VEC32_UE4M3_EXT, B=VEC32_UE8M0       (BF8 must use E8)       -> 3 OK
    [case 4] FP8 x FP4, A=VEC32_UE5M3_EXT, B=VEC32_UE8M0       (FP8 must use E8)       -> 3 OK
    

Made with Cursor

…x1250

The AMDGPU assembler currently accepts invalid (matrix_fmt, matrix_scale_fmt)
tuples on gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4 instruction, silently
emitting illegal encodings (ROCm/llvm-project#2634). Guard against this in
the host stack so kernels for invalid combinations never get built.

Rules enforced (per the ISA / table-valid-combinations.txt):
  * FP8/BF8/FP6/BF6 matrix class must pair with E8 (UE8M0) scale.
  * FP4 matrix class accepts E8, E5M3, or E4M3 scale.
  * When both A and B are FP4, the two scale formats must match.

Implementation:
  * New TensileLite::isValidMXScaleFormatCombination /
    mxScaleFormatCombinationError utilities (header-only validator + .cpp
    for error string formatting), reusable from any host TU.
  * ContractionProblemGemm gains validateMXScaleFormats() (throws) and
    isValidMXScaleFormats() (bool). Called from setMXScaleA/B and
    consistencyCheck so invalid combos throw at the point they are
    introduced.
  * hipBLASLt rejects invalid combinations at the public API surface via a
    new validateMXScaleFormatCombination guard in rocblaslt_matmul_valid_args
    that returns rocblaslt_status_invalid_value. tensile_host.cpp also
    validates defensively after the ScalingFormat -> setMXScale switch.
  * Tensilelite Python side mirrors the validator (with matching error
    wording) and hooks it into ProblemType._checkMXScaleFormatCombination
    so invalid YAML configs raise at problem-type construction.

Test coverage:
  * tests/MXScaleFormatValidation_test.cpp - GoogleTest exercising all 43
    valid (A, scaleA, B, scaleB) combinations from
    table-valid-combinations.txt plus curated invalid samples; integration
    tests that exercise the throw on real ContractionProblemGemm instances.
  * Tensile/Tests/unit/test_MXScaleFormatValidation.py - 193 pytest cases
    covering the same matrix, including end-to-end ProblemType
    construction and YAML-spelling round-trip.

Audit of every gfx1250-targeted YAML in Tensile/Tests/ and every hipBLASLt
sample / client testing helper confirmed zero pre-existing invalid
combinations; no test data fixes were required.

Co-authored-by: Cursor <cursoragent@cursor.com>
`using namespace TensileLite;` pulls in `TensileLite::E8` and
`TensileLite::E5M3` (FP8 byte struct types), which shadow the
`rocisa::DataType::E8` / `E5M3` scale-format enumerators the test table
references via local `constexpr auto E8 = rocisa::DataType::E8;` etc.
Result: every `INSTANTIATE_TEST_SUITE_P` row raised an "ambiguous
reference" error and the `tensilelite-tests` build failed.

Replace the broad `using namespace` with explicit using-declarations
for only the validator symbols we exercise, so the spec table compiles
verbatim. Verified with:

  cmake --build .../tensilelite/tests --target tensilelite-tests -j32
  ./tensilelite-tests --gtest_filter='*MXCombinationTest*:*MXPerSideValidTest*:MXClassification.*:MXMixedClass.*:MXErrorString.*:MXProblemIntegration.*:IsMXProblem.*'
  -> 124 / 124 passed.

Co-authored-by: Cursor <cursoragent@cursor.com>
@codecov-commenter

codecov-commenter commented May 26, 2026

Codecov Report

❌ Patch coverage is 77.77778% with 38 lines in your changes missing coverage. Please review.

| Files with missing lines | Patch % | Lines |
|---|---|---|
| ...il/rocblaslt/include/rocblaslt_mx_scale_format.hpp | 62.30% | 9 Missing and 14 partials ⚠️ |
| ...tail/rocblaslt/src/include/rocblaslt_mat_utils.hpp | 50.00% | 5 Missing and 2 partials ⚠️ |
| ...rary/src/amd_detail/rocblaslt/src/tensile_host.cpp | 66.67% | 6 Missing ⚠️ |
| ...silelite/Tensile/Common/MXScaleFormatValidation.py | 96.92% | 1 Missing and 1 partial ⚠️ |

❌ Your project status has failed because the head coverage (77.83%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #7768      +/-   ##
===========================================
- Coverage    61.87%   61.75%   -0.12%     
===========================================
  Files         2086     2085       -1     
  Lines       357038   356846     -192     
  Branches     53806    53845      +39     
===========================================
- Hits        220892   220348     -544     
- Misses      117348   117739     +391     
+ Partials     18798    18759      -39     
| Flag | Coverage Δ | *Carryforward flag |
|---|---|---|
| TensileLite | 26.09% <97.44%> (+0.16%) ⬆️ | |
| hipBLAS | 90.65% <ø> (ø) | Carriedforward from d708e7a |
| hipBLASLt | 41.41% <61.29%> (+0.13%) ⬆️ | |
| hipCUB | 82.21% <ø> (ø) | Carriedforward from d708e7a |
| hipDNN | 85.61% <ø> (-0.25%) ⬇️ | Carriedforward from d708e7a |
| hipFFT | 51.12% <ø> (+1.12%) ⬆️ | Carriedforward from d708e7a |
| hipRAND | 76.12% <ø> (ø) | Carriedforward from d708e7a |
| hipSOLVER | 69.24% <ø> (ø) | Carriedforward from d708e7a |
| hipSPARSE | 85.09% <ø> (ø) | Carriedforward from d708e7a |
| rocBLAS | 48.11% <ø> (+0.01%) ⬆️ | Carriedforward from d708e7a |
| rocFFT | 51.98% <ø> (-0.09%) ⬇️ | Carriedforward from d708e7a |
| rocRAND | 57.04% <ø> (ø) | Carriedforward from d708e7a |
| rocSOLVER | 77.83% <ø> (+<0.01%) ⬆️ | Carriedforward from d708e7a |
| rocSPARSE | 72.68% <ø> (ø) | Carriedforward from d708e7a |

*This pull request uses carry forward flags. Click here to find out more.

| Files with missing lines | Coverage Δ |
|---|---|
| ...slt/tensilelite/Tensile/SolutionStructs/Problem.py | 48.07% <100.00%> (+4.25%) ⬆️ |
| ...silelite/Tensile/Common/MXScaleFormatValidation.py | 96.92% <96.92%> (ø) |
| ...rary/src/amd_detail/rocblaslt/src/tensile_host.cpp | 41.67% <66.67%> (+0.11%) ⬆️ |
| ...tail/rocblaslt/src/include/rocblaslt_mat_utils.hpp | 38.86% <50.00%> (+0.49%) ⬆️ |
| ...il/rocblaslt/include/rocblaslt_mx_scale_format.hpp | 62.30% <62.30%> (ø) |

... and 43 files with indirect coverage changes


jaopaulolc and others added 3 commits May 26, 2026 14:43
Codecov flagged the MX scale-format inline helpers added in 56dcba9 as
the lowest-coverage region of that PR - they were only exercised
indirectly via the full hipblasLtMatmul API path, so individual branches
(the unknown-enum fall-through, several hipDataType cases, and the
one-side-MX vs both-sides-MX branches in the joint validator) were not
hit by any test.

Split the three helpers out of rocblaslt_mat_utils.hpp into a new lean,
log-free header so they can be unit-tested without dragging the
rocblaslt logging machinery (utility.hpp / log_base / get_logger_os) into
a client gtest:

  * rocblasltScalingFormatToMXScaleDataType    (ScalingFormat -> dtype)
  * rocblasltHipDataTypeToMXMatrixDataType     (hipDataType   -> dtype)
  * checkMXScaleFormatCombination              (pure joint validator
                                                returning optional<string>)

rocblaslt_mat_utils.hpp::validateMXScaleFormatCombination is now a thin
wrapper that calls checkMXScaleFormatCombination and emits log_error
with the returned diagnostic, preserving the existing rocblaslt_status
contract for in-tree callers (validateMatmulArgs, tensile_host.cpp).
The lean header lives next to rocblaslt-types.h so the hipBLASLt library
build keeps finding it via its existing include path, and so the
hipblaslt-test target can include it without also picking up the
sibling private "utility.hpp" (which would shadow the client common
utility.hpp and break the rest of hipblaslt-test).
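
The split between a pure checker and a thin logging wrapper can be sketched in Python. This is an illustration under stated assumptions, not the C++ code from this commit: the function names are snake_case analogs, `Status` and its values are made up, and the diagnostic wording here is hypothetical.

```python
import logging
from enum import Enum
from typing import Optional

logger = logging.getLogger("mx_scale_sketch")


class Status(Enum):
    SUCCESS = 0
    INVALID_VALUE = 1  # illustrative values, not the rocblaslt_status codes


def check_mx_scale_format_combination(a, a_scale, b, b_scale) -> Optional[str]:
    """Pure check with no logging: None when legal, else a diagnostic string.

    A side without MX scaling is passed as (None, None) and skipped.
    """
    problems = []
    for side, matrix, scale in (("A", a, a_scale), ("B", b, b_scale)):
        if matrix is None or scale is None:
            continue
        allowed = ("E8", "E5M3", "E4M3") if matrix == "FP4" else ("E8",)
        if scale not in allowed:
            problems.append(f"{side}: {matrix} cannot use scale {scale}")
    both_scaled = None not in (a_scale, b_scale)
    if a == "FP4" and b == "FP4" and both_scaled and a_scale != b_scale:
        problems.append(f"FP4 x FP4 needs matching scales, got {a_scale} and {b_scale}")
    return "; ".join(problems) or None


def validate_mx_scale_format_combination(a, a_scale, b, b_scale) -> Status:
    """Thin wrapper: logs the diagnostic and maps it to a status code."""
    message = check_mx_scale_format_combination(a, a_scale, b, b_scale)
    if message is None:
        return Status.SUCCESS
    logger.error(message)
    return Status.INVALID_VALUE
```

Because the checker has no logging dependency, a unit test can call it directly and assert on the returned string, which is the motivation the commit message gives for the split.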

Tests
-----
New clients/tests/src/mat_utils_mx_scale_format_gtest.cpp (13 cases)
covers every documented branch of all three helpers:

  * ScalingFormat: every Block_*_UE8M0 / Block_*_UE4M3 / Block_*_UE5M3
    variant, every non-block format (None/Scalar/Vector), and an
    unknown enumerator hitting the default branch.
  * hipDataType: FP8/BF8 (incl. fnuz), FP6/BF6, FP4, plus a sample of
    non-MX types collapsing to rocisa::DataType::None.
  * checkMXScaleFormatCombination: both-sides-non-MX early exit;
    single-side-MX with the other side reported as (None, None) so the
    FP4xFP4 joint rule does not spuriously fire; FP8/BF8/FP6/BF6 -> E8
    enforcement; FP4 accepting all three legal scales (E8 / E5M3 /
    E4M3); FP4xFP4 same-scale enforcement (including the 16-vs-32 same
    family acceptance); mixed FP8 x FP4 legal/illegal combinations.

CMake wiring:

  * library/src/CMakeLists.txt registers the new header alongside the
    other public rocblaslt-include headers.
  * clients/CMakeLists.txt links tensilelite::tensilelite-host into
    hipblaslt-test so TensileLite::isValidMXScaleFormatCombination /
    mxScaleFormatCombinationError resolve at link time.
  * clients/tests/src/CMakeLists.txt adds the new gtest source and the
    rocblaslt public-include directory (deliberately not the private
    src/include path, to avoid the utility.hpp shadow).

Verified:

  cmake --build build --target hipblaslt-test -j8
  ./hipblaslt-test --gtest_filter='MXScaleFormatHelpersTest.*'
  -> 13 / 13 passed.

  cmake --build build --target tensilelite-tests -j8
  ./tensilelite-tests --gtest_filter='IsMXProblem*:MXClassification*:MXMixedClass*:MXErrorString*:MXProblemIntegration*'
  -> 32 / 32 passed (no regression from the refactor).

Co-authored-by: Cursor <cursoragent@cursor.com>
@jaopaulolc
Contributor Author

Closing in favor of #7814, which carries the reduced-scope (tensilelite-C++-only) version of this work. The Python/rocisa validator, the hipBLASLt host-API guard, and the hipBLASLt client-side unit test have been dropped per updated priorities; only the in-tensilelite check that rejects solutions with invalid scale-format combinations is retained there.

@jaopaulolc jaopaulolc closed this May 27, 2026
jaopaulolc added a commit that referenced this pull request May 30, 2026
….py (#7814)

This PR supersedes #7768.

## Motivation

gfx1250's `v_wmma_scale_f32_16x16x128_f8f6f4` only accepts a fixed set
of joint MX scale-format tuples `(A matrix class, A scale, B matrix
class, B scale)`. The AMDGPU assembler does not enforce that joint
constraint today (see ROCm/llvm-project#2634),
so a kernel candidate whose tuple is illegal would otherwise codegen
into an encoding the hardware does not implement.

Rejecting illegal candidates inside the kernel generator turns this into
a clean "no solution selected" outcome at solution-selection time
instead of a silent miscompile or a hard hardware fault at run time.

## Technical Details

Add a validator helper `_validateMXScaleFormatCombination(state, ...)`
in `projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py`,
and call it from `Solution.assignDerivedParameters` immediately after
the existing `_deriveAndValidateMXScaleLayoutAndTransport` helper, so
the rule runs at the same point as every other MX-related
derived-parameter check.

**Rules enforced** (per the ISA spec, mirrored in
`table-valid-combinations.txt`):

* **FP8 / BF8 / FP6 / BF6** (incl. `_fnuz` variants of FP8/BF8) must
pair with **E8** (UE8M0) scale on their own side.
* **FP4** accepts **E8**, **E5M3**, or **E4M3** scale.
* When both A and B are **FP4** the two scales must match.

**Inputs read.** Only `state["ProblemType"]` fields: `DataTypeA`,
`DataTypeB` (matrix dtype), `DataTypeMXSA`, `DataTypeMXSB` (scale
dtype), and `MXBlockA`, `MXBlockB` (per-side MX-block width).

**Field-shape compatibility.** `ProblemType` fields can arrive as
`DataType` wrappers (during `assignDerivedParameters`) or as raw
`rocisa.enum.DataTypeEnum` values (after
`cleanupProblemTypeForLogging`). The helper normalizes both shapes via a
small `_mxEnumValue` resolver.
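
A resolver of this kind might look like the sketch below. The classes are stand-ins so the example runs without rocisa: the enum members and values are illustrative, and the assumption that the wrapper exposes its enum as `.value` is made for this example rather than taken from the `DataType` class. `mx_enum_value` is a hypothetical analog of `_mxEnumValue`.

```python
from enum import Enum


class DataTypeEnum(Enum):
    """Stand-in for rocisa.enum.DataTypeEnum; members and values are illustrative."""
    Float8 = 1
    E8 = 2


class DataType:
    """Stand-in wrapper; assumes the wrapped enum is exposed as `.value`."""

    def __init__(self, value: DataTypeEnum) -> None:
        self.value = value


def mx_enum_value(field) -> DataTypeEnum:
    """Return the raw enum whether `field` is a wrapper or already an enum."""
    # Check for the raw enum first: Enum members also have a `.value`
    # attribute (the underlying int), which must not be returned here.
    if isinstance(field, DataTypeEnum):
        return field
    return field.value


print(mx_enum_value(DataType(DataTypeEnum.E8)) is mx_enum_value(DataTypeEnum.E8))
```

Normalizing once at the top of the helper keeps every later comparison working on a single representation.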

**Short-circuit.** Sides whose `MXBlock` is `0` carry no MX scale and
are skipped. The helper returns `True` early when neither side has MX
scaling, so non-MX problems are entirely unaffected.
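
The short-circuit can be sketched as follows. The dictionary keys are the `ProblemType` field names listed above; the function names and the use of plain strings for dtypes are simplifying assumptions, and the real helper works on a full `state` rather than a bare dict.

```python
FP4_SCALES = ("E8", "E5M3", "E4M3")


def active_mx_sides(problem_type: dict) -> list:
    """Collect (matrix dtype, scale dtype) for sides whose MXBlock is nonzero."""
    sides = []
    for side in ("A", "B"):
        if problem_type.get(f"MXBlock{side}", 0) != 0:
            sides.append((problem_type[f"DataType{side}"],
                          problem_type[f"DataTypeMXS{side}"]))
    return sides


def is_valid_combination(problem_type: dict) -> bool:
    sides = active_mx_sides(problem_type)
    if not sides:
        return True  # neither side is MX-scaled: nothing to check
    for matrix, scale in sides:
        allowed = FP4_SCALES if matrix == "FP4" else ("E8",)
        if scale not in allowed:
            return False
    # The joint rule only applies when both sides are MX-scaled FP4.
    if len(sides) == 2 and all(matrix == "FP4" for matrix, _ in sides):
        return sides[0][1] == sides[1][1]
    return True


non_mx = {"DataTypeA": "FP4", "DataTypeB": "FP4",
          "DataTypeMXSA": "E5M3", "DataTypeMXSB": "E4M3",
          "MXBlockA": 0, "MXBlockB": 0}
print(is_valid_combination(non_mx))  # True: both sides skipped
```

Filtering the sides before applying any rule means the FP4 x FP4 check cannot fire when only one side carries an MX scale.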

**Rejection path.** A candidate whose joint tuple is illegal is dropped
via `reject(state, printRejectionReason, msg)`, which sets
`state["Valid"] = False` and (with
`--global-parameters=PrintSolutionRejectionReason=True` or the
equivalent YAML setting) prints a diagnostic of the form:

```
reject: Invalid MX scale-format combination (A=FP4, AScale=E5M3, B=FP4, BScale=E4M3): FP4 x FP4 requires AScale (E5M3) == BScale (E4M3); see table-valid-combinations.txt / ROCm/llvm-project#2634.
```

The diagnostic always names the offending tuple in ISA-spec spelling
(`FP8`, `BF8`, `FP6`, `BF6`, `FP4`, `E8`, `E5M3`, `E4M3`), the rule(s)
that failed (per-side dtype/scale mismatch and/or the FP4 x FP4 joint
rule), and the cross-reference to the source-of-truth
(`ROCm/llvm-project#2634`).
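
The rejection path for the FP4 x FP4 joint rule can be sketched as below. `reject` here is a local stand-in that mimics the described behavior (set `state["Valid"] = False`, optionally print), not the generator's own function, and `check_fp4_pair` and `joint_rule_message` are hypothetical names. Only the joint-rule wording is reproduced, since that is the one diagnostic quoted above.

```python
import contextlib
import io

REFERENCE = "table-valid-combinations.txt / ROCm/llvm-project#2634"


def reject(state: dict, print_reason: bool, message: str) -> None:
    """Stand-in for reject(...): mark the candidate invalid, optionally print why."""
    state["Valid"] = False
    if print_reason:
        print(f"reject: {message}")


def joint_rule_message(a_scale: str, b_scale: str) -> str:
    """Format the FP4 x FP4 mismatch diagnostic in the quoted shape."""
    return (
        f"Invalid MX scale-format combination "
        f"(A=FP4, AScale={a_scale}, B=FP4, BScale={b_scale}): "
        f"FP4 x FP4 requires AScale ({a_scale}) == BScale ({b_scale}); "
        f"see {REFERENCE}."
    )


def check_fp4_pair(state: dict, a_scale: str, b_scale: str,
                   print_reason: bool = False) -> bool:
    """Apply only the joint rule to a candidate whose two sides are both FP4."""
    if a_scale != b_scale:
        reject(state, print_reason, joint_rule_message(a_scale, b_scale))
        return False
    return True


# Capture stdout to inspect the diagnostic, similar in spirit to pytest's capsys.
state = {"Valid": True}
buffer = io.StringIO()
with contextlib.redirect_stdout(buffer):
    check_fp4_pair(state, "E5M3", "E4M3", print_reason=True)
print(state["Valid"])
print(buffer.getvalue().strip())
```

Keeping the message builder separate from `reject` makes the wording testable without going through solution selection.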

**Scope.** Pure-Python change in one helper plus its caller, gated by
`MXBlockA != 0 or MXBlockB != 0`. No new dependencies — the helper uses
the already-imported `rocisa.enum.DataTypeEnum` and the existing
`reject(...)` mechanism.

## Test Plan

New
`projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py`
exercises the helper directly with minimal `Solution`-state fixtures (no
client build, no GPU, no rocisa device code). Coverage:

* Short-circuit when both `MXBlock` fields are `0`, including an FP4 x
FP4 dtype pair that must **not** trigger the joint rule when neither
side has MX scaling.
* Per-side rule: every FP8 / BF8 / FP6 / BF6 (incl. `_fnuz`) variant
accepts E8 and rejects E5M3 / E4M3 on each side independently.
* Per-side rule: FP4 accepts all three legal scales (E8, E5M3, E4M3)
when paired with itself.
* Joint rule: every FP4 x FP4 scale-mismatch is rejected; matching
scales are accepted.
* Mixed-class combos (FP8 x FP4, BF8 x FP4, ...): each side must satisfy
its own per-side rule; the joint FP4 x FP4 rule does not fire across
mixed classes.
* Asymmetric MX (only one side has `MXBlock != 0`): the non-MX side is
fully skipped, including the FP4 x FP4 joint trigger.
* Diagnostic-message contract: `capsys` captures the rejection string
and asserts it names the offending tuple, the rule that failed, and the
`ROCm/llvm-project#2634` reference.
* Field-shape compatibility: both `DataType`-wrapped and raw
`DataTypeEnum` field shapes are exercised.
* **Authoritative spec table:** every legal `(A, scaleA, B, scaleB)`
tuple permitted by the ISA is exercised end-to-end (positive set), so
the test file doubles as a single source of truth for what gfx1250
allows.

Run locally with:

```
cd projects/hipblaslt/tensilelite
PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py -v
PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
```

## Test Result

```
$ PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py
138 passed in 0.31s

$ PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
859 passed, 199 skipped in 38.08s
```

The 199 skips are arch-marker-gated tests that always skip on a non-GPU
host; no regressions versus `develop`.

---------

Co-authored-by: Cursor <cursoragent@cursor.com>