[tensilelite] Reject invalid MX scale-format combinations in Solution.py #7814
gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4 only accepts a fixed set of (A matrix class, A scale, B matrix class, B scale) tuples. The AMDGPU assembler does not enforce these joint constraints (ROCm/llvm-project#2634), so kernel candidates with illegal tuples must be rejected by the kernel generator before any codegen happens.

Add the rejection in Solution.assignDerivedParameters, alongside the existing _deriveAndValidateMXScaleLayoutAndTransport helper, so the rule fires at the same point in the pipeline as every other MX-related derived-parameter check.

Rules enforced (per the ISA / table-valid-combinations.txt):

* FP8 / BF8 / FP6 / BF6 (incl. _fnuz variants) must pair with E8 (UE8M0) scale.
* FP4 accepts E8, E5M3, or E4M3 scale.
* When both A and B are FP4 the two scales must match.

A candidate whose joint tuple is illegal is dropped via ``reject(state, ..., "Invalid MX scale-format combination (...): ...; see table-valid-combinations.txt / ROCm/llvm-project#2634.")``, which sets ``state["Valid"] = False`` and (when --global-parameters=PrintSolutionRejectionReason=True) prints the diagnostic with the offending tuple in ISA-spec spelling. Sides whose ``MXBlock`` is 0 carry no MX scale and are skipped, so the guard only fires for real MX problems.

Tests
-----

New ``Tensile/Tests/unit/test_MXScaleFormatValidation.py`` (138 pytest cases) covers the helper directly with minimal Solution-state fixtures:

* Short-circuit when both MXBlock fields are 0 (including the FP4 x FP4 dtype pair, which must NOT trigger the joint rule when neither side has MX scaling).
* Per-side rule: every FP8 / BF8 / FP6 / BF6 (incl. _fnuz) variant accepts E8 and rejects E5M3 / E4M3 on each side independently.
* Per-side rule: FP4 accepts all three legal scales (E8, E5M3, E4M3) when paired with itself.
* Joint rule: every FP4 x FP4 scale-mismatch is rejected; matching scales (including the 32-vs-16 block-width same-family case expressed at the dtype layer) are accepted.
* Mixed-class combos (FP8 x FP4, BF8 x FP4, ...): each side must satisfy its own per-side rule; joint FP4 x FP4 does not fire.
* Asymmetric MX (only one side has mxBlock != 0): the non-MX side is fully skipped, including the FP4 x FP4 joint trigger.
* Diagnostic message contract: ``capsys`` captures the rejection string and asserts it names the offending tuple in ISA-spec spellings, the rule that failed, and the ``table-valid-combinations.txt / ROCm/llvm-project#2634`` reference.
* Field-shape compatibility: ``ProblemType`` fields can arrive as ``DataType`` wrappers (during ``assignDerivedParameters``) or as raw ``DataTypeEnum`` values (after ``cleanupProblemTypeForLogging``). Both shapes are covered.
* Authoritative spec table: every legal (A, scaleA, B, scaleB) tuple permitted by the ISA is exercised end-to-end (positive set) so the test file is a single source of truth for what gfx1250 allows.

Verified locally with::

    PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py
    -> 138 passed.

    PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
    -> 859 passed, 199 skipped (the skips are arch-marker-gated tests that always skip on a non-GPU host; no regressions).

Co-authored-by: Cursor <cursoragent@cursor.com>
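The rule set listed in the commit message above can be sketched in a few lines of Python. This is a minimal illustration, not the code from the PR: the `Matrix` and `Scale` enums, the `legal_scales` and `find_violations` names, and the block width of 32 used in the usage note are all hypothetical stand-ins chosen for readability.

```python
from enum import Enum


class Matrix(Enum):
    """Matrix classes in ISA-spec spelling (hypothetical stand-in type)."""
    FP8 = "FP8"
    BF8 = "BF8"
    FP6 = "FP6"
    BF6 = "BF6"
    FP4 = "FP4"


class Scale(Enum):
    """Scale formats in ISA-spec spelling (hypothetical stand-in type)."""
    E8 = "E8"      # UE8M0
    E5M3 = "E5M3"
    E4M3 = "E4M3"


E8_ONLY = frozenset({Scale.E8})
FP4_SCALES = frozenset({Scale.E8, Scale.E5M3, Scale.E4M3})


def legal_scales(matrix):
    """Per-side rule: FP4 accepts three scales, every other class only E8."""
    return FP4_SCALES if matrix is Matrix.FP4 else E8_ONLY


def find_violations(mat_a, scale_a, block_a, mat_b, scale_b, block_b):
    """Return a list of violated rules; an empty list means the tuple is legal."""
    violations = []
    sides = (("A", mat_a, scale_a, block_a), ("B", mat_b, scale_b, block_b))
    for name, mat, scale, block in sides:
        if block == 0:
            continue  # this side carries no MX scale, so it is skipped
        if scale not in legal_scales(mat):
            violations.append(
                f"{name}: {mat.name} requires E8 scale, got {scale.name}")
    # Joint rule: only applies when both sides are MX-scaled FP4.
    both_mx = block_a != 0 and block_b != 0
    if (both_mx and mat_a is Matrix.FP4 and mat_b is Matrix.FP4
            and scale_a is not scale_b):
        violations.append(
            f"FP4 x FP4 requires matching scales, "
            f"got {scale_a.name} vs {scale_b.name}")
    return violations
```

Under these assumptions, `find_violations(Matrix.FP8, Scale.E8, 32, Matrix.FP4, Scale.E4M3, 32)` returns an empty list, while an FP4 x FP4 pair with E8 on one side and E4M3 on the other yields one joint-rule violation.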
c3991a0 to 8b5fc68
Codecov Report
❌ Patch coverage is
❌ Your project status has failed because the head coverage (77.83%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.
Additional details and impacted files
@@ Coverage Diff @@
## develop #7814 +/- ##
===========================================
- Coverage 62.06% 61.97% -0.09%
===========================================
Files 2088 2089 +1
Lines 358303 358385 +82
Branches 54196 54119 -77
===========================================
- Hits 222365 222091 -274
- Misses 117115 117479 +364
+ Partials 18823 18815 -8
The test and the reject condition should apply only to gfx1250 (or you can use HasWMMA_V3).
Pull request overview
This PR adds a TensileLite-side validator to reject illegal joint MX matrix/scale-format tuples for gfx1250 v_wmma_scale_f32_16x16x128_f8f6f4, preventing invalid kernels from being generated when the assembler doesn’t enforce the hardware constraint.
Changes:
- Added `_validateMXScaleFormatCombination(state, ...)` to enforce per-side and FP4×FP4 joint MX scale-format legality and reject invalid candidates via `reject(...)`.
- Wired the new validator into `Solution.assignDerivedParameters` immediately after existing MX layout/transport derivation.
- Added a new pytest unit suite exercising valid/invalid combinations, one-sided MX behavior, and rejection message content.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.
| File | Description |
|---|---|
| projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py | Implements MX scale-format combination validation and invokes it during derived-parameter assignment. |
| projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py | Adds unit tests covering the new validator’s rule set and diagnostics. |
Address PR #7814 review feedback.

Gate the validator on gfx1250 (jichangjichang)
----------------------------------------------

The joint MX scale-format rules in table-valid-combinations.txt are specific to gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4. Older architectures use different MX instructions whose constraints differ, so applying the gfx1250 rule set to them would over-reject legitimate kernels.

* ``_validateMXScaleFormatCombination`` now takes an ``asmCaps`` argument and short-circuits to True (no-op) when ``asmCaps["HasWMMA_V3"]`` is false.
* The call site in ``Solution.assignDerivedParameters`` passes ``isaInfoMap[isa].asmCaps`` (same shape used by the existing ``_deriveAndValidateMXScaleLayoutAndTransport`` call right above).
* New ``TestArchitectureGate`` test class verifies:

  - illegal-on-gfx1250 combos pass through untouched when the cap is missing OR explicitly false,
  - the FP4 x FP4 joint rule no longer fires off-arch,
  - the same combo is still rejected on gfx1250 (positive control).

Fix Copilot review nits
-----------------------

* ``ids=lambda e: e.name`` was applied to multi-argument parametrize decorations (line 183, line 339) where pytest passes the entire tuple to the lambda. The lambda silently fell back to pytest's default repr-based ids; replaced with an explicit ``[f"{a.name}-{sa.name}_x_{b.name}-{sb.name}" ...]`` list for the authoritative spec-table parametrize, and dropped the unused ``ids=`` from the FP4 x FP4 mismatch table.
* Removed two references to ``Problem.cleanupProblemTypeForLogging`` (it does not exist in the repo). Reworded both spots to describe the actual shapes the helper normalizes.
* Fixed the inline comment in ``_mxEnumValue`` that incorrectly claimed ``DataType.value`` was a ``DataTypeEnum``. Per ``Tensile/Common/DataType.py`` it is the underlying *int*; the helper still works correctly with that knowledge, the comment just had to match reality.

Verified locally
----------------

    $ PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py
    142 passed in 0.26s (138 + 4 new gating tests)

    $ PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
    1057 passed, 5 skipped in 42.62s

Co-authored-by: Cursor <cursoragent@cursor.com>
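The architecture gate described in this commit message can be sketched roughly as follows. It is a hedged illustration only: `reject` here is a local stand-in that mimics the documented behaviour (set `state["Valid"] = False`, optionally print), `validate_gated` and `tuple_is_legal` are hypothetical names, and treating a missing `HasWMMA_V3` key the same as an explicit false follows the behaviour the `TestArchitectureGate` bullets describe.

```python
def reject(state, print_reason, message):
    """Stand-in for Tensile's reject(): mark invalid, optionally print why."""
    state["Valid"] = False
    if print_reason:
        print(f"rejecting: {message}")


def validate_gated(state, asm_caps, print_reason, tuple_is_legal):
    """Apply the gfx1250 rule set only when the WMMA_V3 capability is set.

    `tuple_is_legal` is a hypothetical callable standing in for the real
    per-side and joint checks; it receives the solution state.
    """
    # Off-arch (cap missing or false): no-op, the candidate passes through.
    if not asm_caps.get("HasWMMA_V3", False):
        return True
    if tuple_is_legal(state):
        return True
    reject(state, print_reason, "Invalid MX scale-format combination")
    return False
```

With an empty capability map the function returns `True` and leaves `state["Valid"]` untouched even for an illegal tuple; with `{"HasWMMA_V3": True}` the same tuple is rejected.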
Good point. The MX rules in this PR (per Behavior changes in 26a13d7:
Address review feedback from Nathan Henderson on PR #7814: pull the MX scale-format validation out of Solution.py (already over 5,000 lines) and into its own module under the existing Validators/ package, so other Solution-level validation code can join it later.

* New module: ``Tensile/SolutionStructs/Validators/MXScaleFormat.py``

  - Houses the gfx1250 WMMA_V3 MX rule set, all dtype/scale constants (``_MX_FP8_LIKE``, ``_MX_BF8_LIKE``, ``_MX_F6_LIKE``, ``_MX_FP4``, ``_E8_ONLY``, ``_FP4_SCALES``), the four helpers (``_mxMatrixLabel``, ``_mxScaleLabel``, ``_mxEnumValue``, ``_isLegalMXScaleForMatrix``), and the validator itself.
  - Exposes a single public entry point, ``validateMXScaleFormatCombination(state, asmCaps, printRejectionReason)``, matching the ``validateMIParameters`` / ``validateWorkGroup`` naming convention used by the sibling ``Validators/MatrixInstruction.py`` and ``Validators/WorkGroup.py``.

* ``Solution.py``: drops the moved 200-line MX block (header comment, constants, helpers, validator), drops the now-unused ``from rocisa.enum import DataTypeEnum`` import, and calls the validator through ``from .Validators.MXScaleFormat import validateMXScaleFormatCombination``. Net: -196 lines.
* Unit test: ``Tests/unit/test_MXScaleFormatValidation.py`` updated to import the public symbol from the new module path; the 142 cases run unchanged.

Pure code-motion + public rename. No behavioural change: the helper signature, the rule set, and the exact reject-message text (which the tests assert on) are all preserved.

Verified::

    PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py
    -> 142 passed.

    PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
    -> 1057 passed, 5 skipped.

Co-authored-by: Cursor <cursoragent@cursor.com>
This PR supersedes #7768.
Motivation
gfx1250's `v_wmma_scale_f32_16x16x128_f8f6f4` only accepts a fixed set of joint MX scale-format tuples `(A matrix class, A scale, B matrix class, B scale)`. The AMDGPU assembler does not enforce that joint constraint today (see ROCm/llvm-project#2634), so a kernel candidate whose tuple is illegal would otherwise codegen into an encoding the hardware does not implement.

Rejecting illegal candidates inside the kernel generator turns this into a clean "no solution selected" outcome at solution-selection time instead of a silent miscompile or a hard hardware fault at run time.
Technical Details
Add a validator helper `_validateMXScaleFormatCombination(state, ...)` in `projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py`, and call it from `Solution.assignDerivedParameters` immediately after the existing `_deriveAndValidateMXScaleLayoutAndTransport` helper, so the rule runs at the same point as every other MX-related derived-parameter check.

Rules enforced (per the ISA spec, mirrored in `table-valid-combinations.txt`):

- FP8 / BF8 / FP6 / BF6 (including the `_fnuz` variants of FP8/BF8) must pair with E8 (UE8M0) scale on their own side.
- FP4 accepts E8, E5M3, or E4M3 scale.
- When both A and B are FP4, the two scales must match.

**Inputs read.** Only `state["ProblemType"]` fields: `DataTypeA`, `DataTypeB` (matrix dtype), `DataTypeMXSA`, `DataTypeMXSB` (scale dtype), and `MXBlockA`, `MXBlockB` (per-side MX-block width).

**Field-shape compatibility.** `ProblemType` fields can arrive as `DataType` wrappers (during `assignDerivedParameters`) or as raw `rocisa.enum.DataTypeEnum` values (after `cleanupProblemTypeForLogging`). The helper normalizes both shapes via a small `_mxEnumValue` resolver.

**Short-circuit.** Sides whose `MXBlock` is `0` carry no MX scale and are skipped. The helper returns `True` early when neither side has MX scaling, so non-MX problems are entirely unaffected.

**Rejection path.** A candidate whose joint tuple is illegal is dropped via `reject(state, printRejectionReason, msg)`, which sets `state["Valid"] = False` and (with `--global-parameters=PrintSolutionRejectionReason=True` or the equivalent YAML setting) prints a diagnostic. The diagnostic always names the offending tuple in ISA-spec spelling (`FP8`, `BF8`, `FP6`, `BF6`, `FP4`, `E8`, `E5M3`, `E4M3`), the rule(s) that failed (per-side dtype/scale mismatch and/or the FP4 x FP4 joint rule), and the cross-reference to the source-of-truth (ROCm/llvm-project#2634).

**Scope.** Pure-Python change in one helper plus its caller, gated by `MXBlockA != 0 or MXBlockB != 0`. No new dependencies: the helper uses the already-imported `rocisa.enum.DataTypeEnum` and the existing `reject(...)` mechanism.

Test Plan
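New `projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py` exercises the helper directly with minimal `Solution`-state fixtures (no client build, no GPU, no rocisa device code). Coverage:

- Short-circuit when both `MXBlock` fields are `0`, including an FP4 x FP4 dtype pair that must not trigger the joint rule when neither side has MX scaling.
- Per-side rule: every FP8 / BF8 / FP6 / BF6 (incl. `_fnuz`) variant accepts E8 and rejects E5M3 / E4M3 on each side independently.
- Per-side rule: FP4 accepts all three legal scales (E8, E5M3, E4M3) when paired with itself.
- Joint rule: every FP4 x FP4 scale-mismatch is rejected; matching scales are accepted.
- Mixed-class combos (FP8 x FP4, BF8 x FP4, ...): each side must satisfy its own per-side rule; joint FP4 x FP4 does not fire.
- Asymmetric MX (only one side has `mxBlock != 0`): the non-MX side is fully skipped, including the FP4 x FP4 joint trigger.
- Diagnostic message contract: `capsys` captures the rejection string and asserts it names the offending tuple, the rule that failed, and the `ROCm/llvm-project#2634` reference.
- Field-shape compatibility: both `DataType`-wrapped and raw `DataTypeEnum` field shapes are exercised.
- Authoritative spec table: every legal `(A, scaleA, B, scaleB)` tuple permitted by the ISA is exercised end-to-end (positive set), so the test file doubles as a single source of truth for what gfx1250 allows.

Run locally with: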
Test Result
The 199 skips are arch-marker-gated tests that always skip on a non-GPU host; no regressions versus `develop`.
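The field-shape normalization mentioned under Technical Details can be illustrated with a small sketch. It is not the PR's `_mxEnumValue`: `ScaleKind` is a hypothetical stand-in for `rocisa.enum.DataTypeEnum`, `DataTypeWrapper` is a hypothetical stand-in for Tensile's `DataType` wrapper, and the only property borrowed from the thread is that the wrapper's `.value` is the underlying int rather than an enum member.

```python
from enum import Enum


class ScaleKind(Enum):
    """Hypothetical stand-in for an enum such as rocisa.enum.DataTypeEnum."""
    E8 = 0
    E5M3 = 1
    E4M3 = 2


class DataTypeWrapper:
    """Hypothetical stand-in for a wrapper whose .value is a plain int."""

    def __init__(self, member):
        self.value = member.value  # the underlying int, not the enum member


def mx_enum_value(field):
    """Normalize either field shape to an enum member."""
    if isinstance(field, ScaleKind):
        return field  # already a raw enum value
    # Wrapper shape: look the enum member up by its integer value.
    return ScaleKind(field.value)
```

Both `mx_enum_value(ScaleKind.E4M3)` and `mx_enum_value(DataTypeWrapper(ScaleKind.E4M3))` resolve to the same enum member, so downstream rule checks can compare members directly regardless of which shape the caller supplied.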