[hipblaslt][tensilelite] Validate MX scale-format combinations for gfx1250#7768
Closed
jaopaulolc wants to merge 5 commits into
Closed
[hipblaslt][tensilelite] Validate MX scale-format combinations for gfx1250#7768jaopaulolc wants to merge 5 commits into
jaopaulolc wants to merge 5 commits into
Conversation
…x1250 The AMDGPU assembler currently accepts invalid (matrix_fmt, matrix_scale_fmt) tuples on gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4 instruction, silently emitting illegal encodings (ROCm/llvm-project#2634). Guard against this in the host stack so kernels for invalid combinations never get built. Rules enforced (per the ISA / table-valid-combinations.txt): * FP8/BF8/FP6/BF6 matrix class must pair with E8 (UE8M0) scale. * FP4 matrix class accepts E8, E5M3, or E4M3 scale. * When both A and B are FP4, the two scale formats must match. Implementation: * New TensileLite::isValidMXScaleFormatCombination / mxScaleFormatCombinationError utilities (header-only validator + .cpp for error string formatting), reusable from any host TU. * ContractionProblemGemm gains validateMXScaleFormats() (throws) and isValidMXScaleFormats() (bool). Called from setMXScaleA/B and consistencyCheck so invalid combos throw at the point they are introduced. * hipBLASLt rejects invalid combinations at the public API surface via a new validateMXScaleFormatCombination guard in rocblaslt_matmul_valid_args that returns rocblaslt_status_invalid_value. tensile_host.cpp also validates defensively after the ScalingFormat -> setMXScale switch. * Tensilelite Python side mirrors the validator (with matching error wording) and hooks it into ProblemType._checkMXScaleFormatCombination so invalid YAML configs raise at problem-type construction. Test coverage: * tests/MXScaleFormatValidation_test.cpp - GoogleTest exercising all 43 valid (A, scaleA, B, scaleB) combinations from table-valid-combinations.txt plus curated invalid samples; integration tests that exercise the throw on real ContractionProblemGemm instances. * Tensile/Tests/unit/test_MXScaleFormatValidation.py - 193 pytest cases covering the same matrix, including end-to-end ProblemType construction and YAML-spelling round-trip. Audit of every gfx1250-targeted YAML in Tensile/Tests/ and every hipBLASLt sample / client testing helper confirmed zero pre-existing invalid combinations; no test data fixes were required. Co-authored-by: Cursor <cursoragent@cursor.com>
`using namespace TensileLite;` pulls in `TensileLite::E8` and `TensileLite::E5M3` (FP8 byte struct types), which shadow the `rocisa::DataType::E8` / `E5M3` scale-format enumerators the test table references via local `constexpr auto E8 = rocisa::DataType::E8;` etc. Result: every `INSTANTIATE_TEST_SUITE_P` row raised an "ambiguous reference" error and the `tensilelite-tests` build failed. Replace the broad `using namespace` with explicit using-declarations for only the validator symbols we exercise, so the spec table compiles verbatim. Verified with: cmake --build .../tensilelite/tests --target tensilelite-tests -j32 ./tensilelite-tests --gtest_filter='*MXCombinationTest*:*MXPerSideValidTest*:MXClassification.*:MXMixedClass.*:MXErrorString.*:MXProblemIntegration.*:IsMXProblem.*' -> 124 / 124 passed. Co-authored-by: Cursor <cursoragent@cursor.com>
Codecov Report❌ Patch coverage is ❌ Your project status has failed because the head coverage (77.83%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage. Additional details and impacted files@@ Coverage Diff @@
## develop #7768 +/- ##
===========================================
- Coverage 61.87% 61.75% -0.12%
===========================================
Files 2086 2085 -1
Lines 357038 356846 -192
Branches 53806 53845 +39
===========================================
- Hits 220892 220348 -544
- Misses 117348 117739 +391
+ Partials 18798 18759 -39
*This pull request uses carry forward flags. Click here to find out more.
🚀 New features to boost your workflow:
|
…mx-scale-fmt-combos
Codecov flagged the MX scale-format inline helpers added in 56dcba9 as the lowest-coverage region of that PR - they were only exercised indirectly via the full hipblasLtMatmul API path, so individual branches (the unknown-enum fall-through, several hipDataType cases, and the one-side-MX vs both-sides-MX branches in the joint validator) were not hit by any test. Split the three helpers out of rocblaslt_mat_utils.hpp into a new lean, log-free header so they can be unit-tested without dragging the rocblaslt logging machinery (utility.hpp / log_base / get_logger_os) into a client gtest: * rocblasltScalingFormatToMXScaleDataType (ScalingFormat -> dtype) * rocblasltHipDataTypeToMXMatrixDataType (hipDataType -> dtype) * checkMXScaleFormatCombination (pure joint validator returning optional<string>) rocblaslt_mat_utils.hpp::validateMXScaleFormatCombination is now a thin wrapper that calls checkMXScaleFormatCombination and emits log_error with the returned diagnostic, preserving the existing rocblaslt_status contract for in-tree callers (validateMatmulArgs, tensile_host.cpp). The lean header lives next to rocblaslt-types.h so the hipBLASLt library build keeps finding it via its existing include path, and so the hipblaslt-test target can include it without also picking up the sibling private "utility.hpp" (which would shadow the client common utility.hpp and break the rest of hipblaslt-test). Tests ----- New clients/tests/src/mat_utils_mx_scale_format_gtest.cpp (13 cases) covers every documented branch of all three helpers: * ScalingFormat: every Block_*_UE8M0 / Block_*_UE4M3 / Block_*_UE5M3 variant, every non-block format (None/Scalar/Vector), and an unknown enumerator hitting the default branch. * hipDataType: FP8/BF8 (incl. fnuz), FP6/BF6, FP4, plus a sample of non-MX types collapsing to rocisa::DataType::None. * checkMXScaleFormatCombination: both-sides-non-MX early exit; single-side-MX with the other side reported as (None, None) so the FP4xFP4 joint rule does not spuriously fire; FP8/BF8/FP6/BF6 -> E8 enforcement; FP4 accepting all three legal scales (E8 / E5M3 / E4M3); FP4xFP4 same-scale enforcement (including the 16-vs-32 same family acceptance); mixed FP8 x FP4 legal/illegal combinations. CMake wiring: * library/src/CMakeLists.txt registers the new header alongside the other public rocblaslt-include headers. * clients/CMakeLists.txt links tensilelite::tensilelite-host into hipblaslt-test so TensileLite::isValidMXScaleFormatCombination / mxScaleFormatCombinationError resolve at link time. * clients/tests/src/CMakeLists.txt adds the new gtest source and the rocblaslt public-include directory (deliberately not the private src/include path, to avoid the utility.hpp shadow). Verified: cmake --build build --target hipblaslt-test -j8 ./hipblaslt-test --gtest_filter='MXScaleFormatHelpersTest.*' -> 13 / 13 passed. cmake --build build --target tensilelite-tests -j8 ./tensilelite-tests --gtest_filter='IsMXProblem*:MXClassification*:MXMixedClass*:MXErrorString*:MXProblemIntegration*' -> 32 / 32 passed (no regression from the refactor). Co-authored-by: Cursor <cursoragent@cursor.com>
…mx-scale-fmt-combos
Contributor
Author
|
Closing in favor of #7814, which carries the reduced-scope (tensilelite-C++-only) version of this work. The Python/rocisa validator, the hipBLASLt host-API guard, and the hipBLASLt client-side unit test have been dropped per updated priorities; only the in-tensilelite check that rejects solutions with invalid scale-format combinations is retained there. |
jaopaulolc
added a commit
that referenced
this pull request
May 30, 2026
….py (#7814) This PR supersedes #7768. ## Motivation gfx1250's `v_wmma_scale_f32_16x16x128_f8f6f4` only accepts a fixed set of joint MX scale-format tuples `(A matrix class, A scale, B matrix class, B scale)`. The AMDGPU assembler does not enforce that joint constraint today (see [ROCm/llvm-project#2634](ROCm/llvm-project#2634)), so a kernel candidate whose tuple is illegal would otherwise codegen into an encoding the hardware does not implement. Rejecting illegal candidates inside the kernel generator turns this into a clean "no solution selected" outcome at solution-selection time instead of a silent miscompile or a hard hardware fault at run time. ## Technical Details Add a validator helper `_validateMXScaleFormatCombination(state, ...)` in `projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py`, and call it from `Solution.assignDerivedParameters` immediately after the existing `_deriveAndValidateMXScaleLayoutAndTransport` helper, so the rule runs at the same point as every other MX-related derived-parameter check. **Rules enforced** (per the ISA spec, mirrored in `table-valid-combinations.txt`): * **FP8 / BF8 / FP6 / BF6** (incl. `_fnuz` variants of FP8/BF8) must pair with **E8** (UE8M0) scale on their own side. * **FP4** accepts **E8**, **E5M3**, or **E4M3** scale. * When both A and B are **FP4** the two scales must match. **Inputs read.** Only `state["ProblemType"]` fields: `DataTypeA`, `DataTypeB` (matrix dtype), `DataTypeMXSA`, `DataTypeMXSB` (scale dtype), and `MXBlockA`, `MXBlockB` (per-side MX-block width). **Field-shape compatibility.** `ProblemType` fields can arrive as `DataType` wrappers (during `assignDerivedParameters`) or as raw `rocisa.enum.DataTypeEnum` values (after `cleanupProblemTypeForLogging`). The helper normalizes both shapes via a small `_mxEnumValue` resolver. **Short-circuit.** Sides whose `MXBlock` is `0` carry no MX scale and are skipped. The helper returns `True` early when neither side has MX scaling, so non-MX problems are entirely unaffected. **Rejection path.** A candidate whose joint tuple is illegal is dropped via `reject(state, printRejectionReason, msg)`, which sets `state["Valid"] = False` and (with `--global-parameters=PrintSolutionRejectionReason=True` or the equivalent YAML setting) prints a diagnostic of the form: ``` reject: Invalid MX scale-format combination (A=FP4, AScale=E5M3, B=FP4, BScale=E4M3): FP4 x FP4 requires AScale (E5M3) == BScale (E4M3); see table-valid-combinations.txt / ROCm/llvm-project#2634. ``` The diagnostic always names the offending tuple in ISA-spec spelling (`FP8`, `BF8`, `FP6`, `BF6`, `FP4`, `E8`, `E5M3`, `E4M3`), the rule(s) that failed (per-side dtype/scale mismatch and/or the FP4 x FP4 joint rule), and the cross-reference to the source-of-truth (`ROCm/llvm-project#2634`). **Scope.** Pure-Python change in one helper plus its caller, gated by `MXBlockA != 0 or MXBlockB != 0`. No new dependencies — the helper uses the already-imported `rocisa.enum.DataTypeEnum` and the existing `reject(...)` mechanism. ## Test Plan New `projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py` exercises the helper directly with minimal `Solution`-state fixtures (no client build, no GPU, no rocisa device code). Coverage: * Short-circuit when both `MXBlock` fields are `0`, including an FP4 x FP4 dtype pair that must **not** trigger the joint rule when neither side has MX scaling. * Per-side rule: every FP8 / BF8 / FP6 / BF6 (incl. `_fnuz`) variant accepts E8 and rejects E5M3 / E4M3 on each side independently. * Per-side rule: FP4 accepts all three legal scales (E8, E5M3, E4M3) when paired with itself. * Joint rule: every FP4 x FP4 scale-mismatch is rejected; matching scales are accepted. * Mixed-class combos (FP8 x FP4, BF8 x FP4, ...): each side must satisfy its own per-side rule; the joint FP4 x FP4 rule does not fire across mixed classes. * Asymmetric MX (only one side has `mxBlock != 0`): the non-MX side is fully skipped, including the FP4 x FP4 joint trigger. * Diagnostic-message contract: `capsys` captures the rejection string and asserts it names the offending tuple, the rule that failed, and the `ROCm/llvm-project#2634` reference. * Field-shape compatibility: both `DataType`-wrapped and raw `DataTypeEnum` field shapes are exercised. * **Authoritative spec table:** every legal `(A, scaleA, B, scaleB)` tuple permitted by the ISA is exercised end-to-end (positive set), so the test file doubles as a single source of truth for what gfx1250 allows. Run locally with: ``` cd projects/hipblaslt/tensilelite PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py -v PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit ``` ## Test Result ``` $ PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py 138 passed in 0.31s $ PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit 859 passed, 199 skipped in 38.08s ``` The 199 skips are arch-marker-gated tests that always skip on a non-GPU host; no regressions versus `develop`. --------- Co-authored-by: Cursor <cursoragent@cursor.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
The AMDGPU assembler currently accepts invalid
(matrix_fmt, matrix_scale_fmt)tuples on gfx1250'sv_wmma_scale_f32_16x16x128_f8f6f4, silently emitting illegal encodings (ROCm/llvm-project#2634). Out of 225 possible tuples only 43 are legal; the assembler accepts all 225 today.This PR guards every host path that can build a kernel for one of those illegal tuples, so the failure surfaces as a clean validation error long before any incorrect codegen happens.
Rules enforced (per the ISA spec):
Changes
Foundation (new)
tensilelite/include/Tensile/MXScaleFormatValidation.hpp—TensileLite::isValidMXScaleFormatCombination,mxScaleFormatCombinationError, etc. Single source of truth for the rule set.tensilelite/src/MXScaleFormatValidation.cpp— error-string formatting (matching wording in C++ and Python).Tensilelite C++
include/Tensile/ContractionProblem.hpp,src/ContractionProblem.cpp—ContractionProblemGemm::validateMXScaleFormats()(throwsstd::runtime_error) andisValidMXScaleFormats()(bool). Invoked fromsetMXScaleA,setMXScaleB, andconsistencyCheckso invalid combos throw at the point they are introduced.hipBLASLt host
rocblaslt_mat_utils.hpp—rocblasltScalingFormatToMXScaleDataType,rocblasltHipDataTypeToMXMatrixDataType, andvalidateMXScaleFormatCombinationhelpers. Wired intorocblaslt_matmul_valid_argsso the API returnsrocblaslt_status_invalid_value(instead of building a broken kernel).tensile_host.cpp— defensive try/catch aroundtensileProblem.validateMXScaleFormats()after eachScalingFormat->setMXScale*switch (bothConstructTensileProblemandupdateTensileProblem).Tensilelite Python
Tensile/Common/MXScaleFormatValidation.py— Python mirror of the validator. Error wording is byte-for-byte identical to the C++ implementation.Tensile/SolutionStructs/Problem.py— newProblemType._checkMXScaleFormatCombination(), called between MX-scale defaulting and_checkIfSupportedGEMMType(). Validates againstMacDataTypeA/B(the MAC compute-input type, matching how the ISA constraint applies).Tests (new)
tests/MXScaleFormatValidation_test.cpp— GoogleTest covering all 43 valid combinations fromtable-valid-combinations.txtplus curated invalid samples, classification helpers, error-string contract, and end-to-end integration on realContractionProblemGemminstances. Registered intests/CMakeLists.txt. Uses targeted using-declarations (not a broadusing namespace TensileLite;) to avoid colliding with theTensileLite::E8/TensileLite::E5M3struct types.Tensile/Tests/unit/test_MXScaleFormatValidation.py— 193 pytest cases mirroring the C++ matrix, including ProblemType construction (accept / reject / no-MX-skip) and a YAML-spelling round-trip guard.Audit of existing tests/configs
Audited every gfx1250-targeted YAML under
Tensile/Tests/common/{gemm,streamk}/gfx12*(15 files, 37 problem-type blocks), every C++ MX test fixture, and every hipBLASLt sample/client helper. Zero invalid combinations were found — all existing configs already obey the rules. No test data fixes were required.The audit confirmed three correctness notes that this PR depends on:
f8as a YAML scale spelling means E4M3 (the C++ enum reusesrocisa::DataType::Float8for the E4M3 byte scale; matchestensile_host.cpp'sBlock_32_UE4M3 -> Float8mapping)._FNUZmatrix dtypes are correctly treated as non-MX and bypass the gfx1250 rules.mxf4_gfx1250.yamlalready use matching scales.Test plan
All items have been executed in this dev environment (ROCm 7.14 at
/opt/rocm, FFM/mi450 GPU simulator).Built the C++ gtest binary with the superbuild flags mirroring
~/bin/rocbuild(withGPU_TARGETS=gfx1250andTENSILELITE_BUILD_TESTING=ON), then ran the validator suites:Result: 124 passed, 0 failed across 13 test suites — every valid (43) and invalid (curated) combination, the FP4×FP4 mismatch rule, the error-string format contract, the classification helpers, and end-to-end
ContractionProblemGemmintegration.pytest projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py -v— 193 passed, 0 failed.Existing
Tensile/Tests/unitsuite (Python portion of the regression check) — 841 passed, 6 skipped, 0 failed.Loaded a YAML config with an invalid combo (
MacDataTypeA: F4,MacDataTypeB: F4,DataTypeMXSA: E5M3,DataTypeMXSB: f8) intoProblemType(...)and confirmed it raises with:Invalid MX scale-format combination (A=FP4, AScale=E5M3, B=FP4, BScale=E4M3): FP4 x FP4 requires AScale (E5M3) == BScale (E4M3); see table-valid-combinations.txt / ROCm/llvm-project#2634.Built
libhipblaslt.so(against FFM/mi450 because no real gfx1250 GPU is reachable here; the validator is architecture-independent host C++ so the test path is identical) and ran a smoke test that callshipblasLtMatmulagainst four invalid(A type, A scale mode, B type, B scale mode)tuples. All four returnedHIPBLAS_STATUS_INVALID_VALUE(3):Made with Cursor