Skip to content

[tensilelite] Reject invalid MX scale-format combinations in Solution.py#7814

Merged
jaopaulolc merged 7 commits into
developfrom
users/jolabega/tensilelite-only-mx-scale-format-validation
May 30, 2026
Merged

[tensilelite] Reject invalid MX scale-format combinations in Solution.py#7814
jaopaulolc merged 7 commits into
developfrom
users/jolabega/tensilelite-only-mx-scale-format-validation

Conversation

@jaopaulolc
Copy link
Copy Markdown
Contributor

@jaopaulolc jaopaulolc commented May 27, 2026

This PR supersedes #7768.

Motivation

gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4 only accepts a fixed set of joint MX scale-format tuples (A matrix class, A scale, B matrix class, B scale). The AMDGPU assembler does not enforce that joint constraint today (see ROCm/llvm-project#2634), so a kernel candidate whose tuple is illegal would otherwise codegen into an encoding the hardware does not implement.

Rejecting illegal candidates inside the kernel generator turns this into a clean "no solution selected" outcome at solution-selection time instead of a silent miscompile or a hard hardware fault at run time.

Technical Details

Add a validator helper _validateMXScaleFormatCombination(state, ...) in projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py, and call it from Solution.assignDerivedParameters immediately after the existing _deriveAndValidateMXScaleLayoutAndTransport helper, so the rule runs at the same point as every other MX-related derived-parameter check.

Rules enforced (per the ISA spec, mirrored in table-valid-combinations.txt):

  • FP8 / BF8 / FP6 / BF6 (incl. _fnuz variants of FP8/BF8) must pair with E8 (UE8M0) scale on their own side.
  • FP4 accepts E8, E5M3, or E4M3 scale.
  • When both A and B are FP4 the two scales must match.

Inputs read. Only state["ProblemType"] fields: DataTypeA, DataTypeB (matrix dtype), DataTypeMXSA, DataTypeMXSB (scale dtype), and MXBlockA, MXBlockB (per-side MX-block width).

Field-shape compatibility. ProblemType fields can arrive as DataType wrappers (during assignDerivedParameters) or as raw rocisa.enum.DataTypeEnum values (after cleanupProblemTypeForLogging). The helper normalizes both shapes via a small _mxEnumValue resolver.

Short-circuit. Sides whose MXBlock is 0 carry no MX scale and are skipped. The helper returns True early when neither side has MX scaling, so non-MX problems are entirely unaffected.

Rejection path. A candidate whose joint tuple is illegal is dropped via reject(state, printRejectionReason, msg), which sets state["Valid"] = False and (with --global-parameters=PrintSolutionRejectionReason=True or the equivalent YAML setting) prints a diagnostic of the form:

reject: Invalid MX scale-format combination (A=FP4, AScale=E5M3, B=FP4, BScale=E4M3): FP4 x FP4 requires AScale (E5M3) == BScale (E4M3); see table-valid-combinations.txt / ROCm/llvm-project#2634.

The diagnostic always names the offending tuple in ISA-spec spelling (FP8, BF8, FP6, BF6, FP4, E8, E5M3, E4M3), the rule(s) that failed (per-side dtype/scale mismatch and/or the FP4 x FP4 joint rule), and the cross-reference to the source-of-truth (ROCm/llvm-project#2634).

Scope. Pure-Python change in one helper plus its caller, gated by MXBlockA != 0 or MXBlockB != 0. No new dependencies — the helper uses the already-imported rocisa.enum.DataTypeEnum and the existing reject(...) mechanism.

Test Plan

New projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py exercises the helper directly with minimal Solution-state fixtures (no client build, no GPU, no rocisa device code). Coverage:

  • Short-circuit when both MXBlock fields are 0, including an FP4 x FP4 dtype pair that must not trigger the joint rule when neither side has MX scaling.
  • Per-side rule: every FP8 / BF8 / FP6 / BF6 (incl. _fnuz) variant accepts E8 and rejects E5M3 / E4M3 on each side independently.
  • Per-side rule: FP4 accepts all three legal scales (E8, E5M3, E4M3) when paired with itself.
  • Joint rule: every FP4 x FP4 scale-mismatch is rejected; matching scales are accepted.
  • Mixed-class combos (FP8 x FP4, BF8 x FP4, ...): each side must satisfy its own per-side rule; the joint FP4 x FP4 rule does not fire across mixed classes.
  • Asymmetric MX (only one side has mxBlock != 0): the non-MX side is fully skipped, including the FP4 x FP4 joint trigger.
  • Diagnostic-message contract: capsys captures the rejection string and asserts it names the offending tuple, the rule that failed, and the ROCm/llvm-project#2634 reference.
  • Field-shape compatibility: both DataType-wrapped and raw DataTypeEnum field shapes are exercised.
  • Authoritative spec table: every legal (A, scaleA, B, scaleB) tuple permitted by the ISA is exercised end-to-end (positive set), so the test file doubles as a single source of truth for what gfx1250 allows.

Run locally with:

cd projects/hipblaslt/tensilelite
PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py -v
PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit

Test Result

$ PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py
138 passed in 0.31s

$ PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
859 passed, 199 skipped in 38.08s

The 199 skips are arch-marker-gated tests that always skip on a non-GPU host; no regressions versus develop.

gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4 only accepts a fixed set of
(A matrix class, A scale, B matrix class, B scale) tuples. The AMDGPU
assembler does not enforce these joint constraints
(ROCm/llvm-project#2634), so kernel candidates with illegal tuples must
be rejected by the kernel generator before any codegen happens.

Add the rejection in Solution.assignDerivedParameters, alongside the
existing _deriveAndValidateMXScaleLayoutAndTransport helper, so the
rule fires at the same point in the pipeline as every other MX-related
derived-parameter check.

Rules enforced (per the ISA / table-valid-combinations.txt):
  * FP8 / BF8 / FP6 / BF6 (incl. _fnuz variants) must pair with E8
    (UE8M0) scale.
  * FP4 accepts E8, E5M3, or E4M3 scale.
  * When both A and B are FP4 the two scales must match.

A candidate whose joint tuple is illegal is dropped via
``reject(state, ..., "Invalid MX scale-format combination (...): ...; see
table-valid-combinations.txt / ROCm/llvm-project#2634.")``, which sets
``state["Valid"] = False`` and (when --global-parameters=
PrintSolutionRejectionReason=True) prints the diagnostic with the
offending tuple in ISA-spec spelling.

Sides whose ``MXBlock`` is 0 carry no MX scale and are skipped, so the
guard only fires for real MX problems.

Tests
-----
New ``Tensile/Tests/unit/test_MXScaleFormatValidation.py`` (138 pytest
cases) covers the helper directly with minimal Solution-state fixtures:

  * Short-circuit when both MXBlock fields are 0 (including the
    FP4 x FP4 dtype pair, which must NOT trigger the joint rule when
    neither side has MX scaling).
  * Per-side rule: every FP8 / BF8 / FP6 / BF6 (incl. _fnuz) variant
    accepts E8 and rejects E5M3 / E4M3 on each side independently.
  * Per-side rule: FP4 accepts all three legal scales (E8, E5M3, E4M3)
    when paired with itself.
  * Joint rule: every FP4 x FP4 scale-mismatch is rejected; matching
    scales (including the 32-vs-16 block-width same-family case
    expressed at the dtype layer) are accepted.
  * Mixed-class combos (FP8 x FP4, BF8 x FP4, ...) - each side must
    satisfy its own per-side rule; joint FP4 x FP4 does not fire.
  * Asymmetric MX (only one side has mxBlock != 0) - the non-MX side
    is fully skipped, including the FP4 x FP4 joint trigger.
  * Diagnostic message contract: ``capsys`` captures the rejection
    string and asserts it names the offending tuple in ISA-spec
    spellings, the rule that failed, and the
    ``table-valid-combinations.txt / ROCm/llvm-project#2634`` reference.
  * Field-shape compatibility: ``ProblemType`` fields can arrive as
    ``DataType`` wrappers (during ``assignDerivedParameters``) or as
    raw ``DataTypeEnum`` values (after ``cleanupProblemTypeForLogging``).
    Both shapes are covered.
  * Authoritative spec table: every legal (A, scaleA, B, scaleB) tuple
    permitted by the ISA is exercised end-to-end (positive set) so the
    test file is a single source of truth for what gfx1250 allows.

Verified locally with::

  PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py
  -> 138 passed.

  PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
  -> 859 passed, 199 skipped (the skips are arch-marker-gated tests
     that always skip on a non-GPU host; no regressions).

Co-authored-by: Cursor <cursoragent@cursor.com>
@jaopaulolc jaopaulolc force-pushed the users/jolabega/tensilelite-only-mx-scale-format-validation branch from c3991a0 to 8b5fc68 Compare May 27, 2026 16:57
@jaopaulolc jaopaulolc changed the title [tensilelite] Validate MX scale-format combinations for gfx1250 [tensilelite] Reject invalid MX scale-format combinations in Solution.py May 27, 2026
@codecov-commenter
Copy link
Copy Markdown

codecov-commenter commented May 27, 2026

Codecov Report

❌ Patch coverage is 85.91549% with 10 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
...ensile/SolutionStructs/Validators/MXScaleFormat.py 88.24% 4 Missing and 4 partials ⚠️
...lt/tensilelite/Tensile/SolutionStructs/Solution.py 33.33% 2 Missing ⚠️

❌ Your project status has failed because the head coverage (77.83%) is below the target coverage (80.00%). You can increase the head coverage or adjust the target coverage.

Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #7814      +/-   ##
===========================================
- Coverage    62.06%   61.97%   -0.09%     
===========================================
  Files         2088     2089       +1     
  Lines       358303   358385      +82     
  Branches     54196    54119      -77     
===========================================
- Hits        222365   222091     -274     
- Misses      117115   117479     +364     
+ Partials     18823    18815       -8     
Flag Coverage Δ *Carryforward flag
TensileLite 27.52% <85.92%> (+0.09%) ⬆️
hipBLAS 90.65% <ø> (ø) Carriedforward from 42c31c8
hipBLASLt 41.27% <ø> (ø)
hipCUB 82.21% <ø> (ø) Carriedforward from 42c31c8
hipDNN 85.87% <ø> (-0.72%) ⬇️ Carriedforward from 42c31c8
hipFFT 50.00% <ø> (ø) Carriedforward from 42c31c8
hipRAND 76.12% <ø> (ø) Carriedforward from 42c31c8
hipSOLVER 69.24% <ø> (ø) Carriedforward from 42c31c8
hipSPARSE 85.09% <ø> (-0.32%) ⬇️ Carriedforward from 42c31c8
rocBLAS 48.09% <ø> (ø) Carriedforward from 42c31c8
rocFFT 52.07% <ø> (ø) Carriedforward from 42c31c8
rocRAND 57.04% <ø> (+0.01%) ⬆️ Carriedforward from 42c31c8
rocSOLVER 77.83% <ø> (ø) Carriedforward from 42c31c8
rocSPARSE 72.68% <ø> (ø) Carriedforward from 42c31c8

*This pull request uses carry forward flags. Click here to find out more.

Files with missing lines Coverage Δ
...lt/tensilelite/Tensile/SolutionStructs/Solution.py 7.02% <33.33%> (+0.02%) ⬆️
...ensile/SolutionStructs/Validators/MXScaleFormat.py 88.24% <88.24%> (ø)

... and 35 files with indirect coverage changes

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

@jichangjichang
Copy link
Copy Markdown
Contributor

the test and reject condition should be only for gfx1250 (or you can use HasWMMA_V3)

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR adds a TensileLite-side validator to reject illegal joint MX matrix/scale-format tuples for gfx1250 v_wmma_scale_f32_16x16x128_f8f6f4, preventing invalid kernels from being generated when the assembler doesn’t enforce the hardware constraint.

Changes:

  • Added _validateMXScaleFormatCombination(state, ...) to enforce per-side and FP4×FP4 joint MX scale-format legality and reject invalid candidates via reject(...).
  • Wired the new validator into Solution.assignDerivedParameters immediately after existing MX layout/transport derivation.
  • Added a new pytest unit suite exercising valid/invalid combinations, one-sided MX behavior, and rejection message content.

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 5 comments.

File Description
projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py Implements MX scale-format combination validation and invokes it during derived-parameter assignment.
projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py Adds unit tests covering the new validator’s rule set and diagnostics.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment thread projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py Outdated
Comment thread projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py Outdated
Comment thread projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py Outdated
Comment thread projects/hipblaslt/tensilelite/Tensile/Tests/unit/test_MXScaleFormatValidation.py Outdated
Comment thread projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py Outdated
Address PR #7814 review feedback.

Gate the validator on gfx1250 (jichangjichang)
----------------------------------------------
The joint MX scale-format rules in table-valid-combinations.txt are
specific to gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4. Older
architectures use different MX instructions whose constraints differ,
so applying the gfx1250 rule set to them would over-reject legitimate
kernels.

* ``_validateMXScaleFormatCombination`` now takes an ``asmCaps``
  argument and short-circuits to True (no-op) when
  ``asmCaps["HasWMMA_V3"]`` is false.
* The call site in ``Solution.assignDerivedParameters`` passes
  ``isaInfoMap[isa].asmCaps`` (same shape used by the existing
  ``_deriveAndValidateMXScaleLayoutAndTransport`` call right above).
* New ``TestArchitectureGate`` test class verifies:
    - illegal-on-gfx1250 combos pass through untouched when the cap
      is missing OR explicitly false,
    - the FP4 x FP4 joint rule no longer fires off-arch,
    - the same combo is still rejected on gfx1250 (positive control).

Fix Copilot review nits
-----------------------
* ``ids=lambda e: e.name`` was applied to multi-argument parametrize
  decorations (line 183, line 339) where pytest passes the entire
  tuple to the lambda. The lambda silently fell back to pytest's
  default repr-based ids; replaced with an explicit
  ``[f"{a.name}-{sa.name}_x_{b.name}-{sb.name}" ...]`` list for the
  authoritative spec-table parametrize, and dropped the unused
  ``ids=`` from the FP4 x FP4 mismatch table.
* Removed two references to ``Problem.cleanupProblemTypeForLogging``
  (it does not exist in the repo). Reworded both spots to describe
  the actual shapes the helper normalizes.
* Fixed the inline comment in ``_mxEnumValue`` that incorrectly
  claimed ``DataType.value`` was a ``DataTypeEnum``. Per
  ``Tensile/Common/DataType.py`` it is the underlying *int*; the
  helper still works correctly with that knowledge, the comment
  just had to match reality.

Verified locally
----------------
$ PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py
142 passed in 0.26s   (138 + 4 new gating tests)

$ PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
1057 passed, 5 skipped in 42.62s

Co-authored-by: Cursor <cursoragent@cursor.com>
@jaopaulolc
Copy link
Copy Markdown
Contributor Author

Good point. The MX rules in this PR (per table-valid-combinations.txt / ROCm/llvm-project#2634) are specific to gfx1250's v_wmma_scale_f32_16x16x128_f8f6f4, so the validator and tests are now gated on asmCaps["HasWMMA_V3"] (matching how the surrounding gfx1250-specific rejects in assignDerivedParameters are spelled).

Behavior changes in 26a13d7:

  • _validateMXScaleFormatCombination takes an asmCaps argument and short-circuits to True (no-op) when HasWMMA_V3 is false, so non-gfx1250 candidates are passed through untouched.
  • Call site in Solution.assignDerivedParameters now passes isaInfoMap[isa].asmCaps (same shape used by the existing _deriveAndValidateMXScaleLayoutAndTransport call right above).
  • New TestArchitectureGate test class confirms the gate from both sides: tuples that would be illegal on gfx1250 pass through untouched when the cap is missing or explicitly false, and the same combos are still rejected on gfx1250 as a positive control.

Comment thread projects/hipblaslt/tensilelite/Tensile/SolutionStructs/Solution.py Outdated
jaopaulolc and others added 3 commits May 29, 2026 09:21
Address review feedback from Nathan Henderson on PR #7814: pull the MX
scale-format validation out of Solution.py (already over 5,000 lines) and
into its own module under the existing Validators/ package, so other
Solution-level validation code can join it later.

* New module: ``Tensile/SolutionStructs/Validators/MXScaleFormat.py``
  - Houses the gfx1250 WMMA_V3 MX rule set, all dtype/scale constants
    (``_MX_FP8_LIKE``, ``_MX_BF8_LIKE``, ``_MX_F6_LIKE``, ``_MX_FP4``,
    ``_E8_ONLY``, ``_FP4_SCALES``), the four helpers
    (``_mxMatrixLabel``, ``_mxScaleLabel``, ``_mxEnumValue``,
    ``_isLegalMXScaleForMatrix``), and the validator itself.
  - Exposes a single public entry point,
    ``validateMXScaleFormatCombination(state, asmCaps, printRejectionReason)``,
    matching the ``validateMIParameters`` / ``validateWorkGroup`` naming
    convention used by the sibling ``Validators/MatrixInstruction.py``
    and ``Validators/WorkGroup.py``.
* ``Solution.py``: drops the moved 200-line MX block (header comment,
  constants, helpers, validator), drops the now-unused
  ``from rocisa.enum import DataTypeEnum`` import, and calls the
  validator through ``from .Validators.MXScaleFormat import
  validateMXScaleFormatCombination``. Net: -196 lines.
* Unit test: ``Tests/unit/test_MXScaleFormatValidation.py`` updated to
  import the public symbol from the new module path; the 142 cases run
  unchanged.

Pure code-motion + public rename. No behavioural change: the helper
signature, the rule set, and the exact reject-message text (which the
tests assert on) are all preserved.

Verified::

  PYTHONPATH=. python -m pytest Tensile/Tests/unit/test_MXScaleFormatValidation.py
  -> 142 passed.

  PYTHONPATH=. python -m pytest Tensile/Tests/unit -m unit
  -> 1057 passed, 5 skipped.

Co-authored-by: Cursor <cursoragent@cursor.com>
@jaopaulolc jaopaulolc enabled auto-merge (squash) May 29, 2026 17:29
@jaopaulolc jaopaulolc merged commit d8f3143 into develop May 30, 2026
48 checks passed
@jaopaulolc jaopaulolc deleted the users/jolabega/tensilelite-only-mx-scale-format-validation branch May 30, 2026 08:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants