[JAX] Grouped quant+GEMM custom partitioning rules#3058

Open
jberchtold-nvidia wants to merge 12 commits into NVIDIA:main from jberchtold-nvidia:jberchtold/gmm-custom-partition-rules

Conversation

@jberchtold-nvidia
Collaborator

@jberchtold-nvidia jberchtold-nvidia commented May 28, 2026

Description

Adds custom partitioning rules to the grouped quantization and grouped GEMM primitives to support DP/FSDP and EP shardings.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add custom partitioning to grouped quantize and grouped GEMM primitives
  • Add tests that validate the custom partitioning functions directly, plus an end-to-end test on a JIT-compiled program that verifies the resulting shardings are correct

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@jberchtold-nvidia jberchtold-nvidia marked this pull request as draft May 28, 2026 20:59
@greptile-apps
Contributor

greptile-apps Bot commented May 28, 2026

Greptile Summary

This PR adds custom partitioning rules (partition and shardy_sharding_rule) to GroupedQuantizePrimitive and GroupedGemmPrimitive, enabling these primitives to participate in JAX's SPMD partitioning for DP, FSDP, and EP mesh axes. It also adds supporting utilities to sharding.py, removes the now-superseded kernel_fsdp_info plumbing from grouped_dense, and fixes with_sharding_constraint to keep only auto mesh axes (instead of filtering manual-only axes).

  • Custom partitioning for grouped ops: _parse_partition_specs in both primitives filters input specs to EP/DP/FSDP axes, gathers the RHS (weight) along the FSDP axis when needed, and emits warnings when unsupported axes (e.g. TP) are silently dropped. GroupedGemmPrimitive additionally inserts a psum all-reduce when both operands share a DP/FSDP contracting axis.
  • New sharding.py utilities: spec_axes, filter_spec_axes, merge_axis_specs, common_spec_axis, local_shape_from_spec, local_2d_sizes_from_spec, and supported_grouped_partition_axes are extracted as shared helpers consumed by both primitives.
  • Cleanup in dense.py: _all_gather_kernel / _psum_scatter_kernel (previously dead due to an assert not kernel_fsdp_enabled) and the kernel_fsdp_info parameter are removed; sanitization of negative contracting dims is now done upfront in grouped_dense.
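The axis filtering described in the first bullet can be illustrated with a minimal, stdlib-only sketch. This is not the PR's `filter_spec_axes`; the function name `filter_axes`, its return shape, and the axis names `"dp"`, `"fsdp"`, `"ep"`, and `"tp"` are assumptions chosen for illustration. It keeps only allowed mesh axes in each dimension of a partition spec and reports what was dropped so the caller can warn.

```python
import warnings
from typing import Iterable, Optional, Set, Tuple, Union

# One entry per tensor dimension: unsharded (None), one axis, or several axes.
AxisEntry = Union[None, str, Tuple[str, ...]]


def filter_axes(
    spec: Iterable[AxisEntry], allowed: Set[str]
) -> Tuple[Tuple[AxisEntry, ...], Tuple[str, ...]]:
    """Hypothetical helper: keep only `allowed` mesh axes in `spec`.

    Returns the filtered spec and the axis names that were dropped.
    """
    filtered = []
    dropped = []
    for entry in spec:
        if entry is None:
            names: Tuple[str, ...] = ()
        elif isinstance(entry, str):
            names = (entry,)
        else:
            names = tuple(entry)
        kept = tuple(n for n in names if n in allowed)
        dropped.extend(n for n in names if n not in allowed)
        if not kept:
            filtered.append(None)      # dimension becomes replicated
        elif len(kept) == 1:
            filtered.append(kept[0])   # collapse a 1-tuple to a bare name
        else:
            filtered.append(kept)
    return tuple(filtered), tuple(dropped)


# Assumed example spec: dim 0 sharded over ("dp", "tp"), dim 2 over "fsdp".
spec = (("dp", "tp"), None, "fsdp")
filtered, dropped = filter_axes(spec, allowed={"dp", "fsdp", "ep"})
if dropped:
    # Mirrors the idea of warning instead of silently dropping an axis.
    warnings.warn(f"Unsupported mesh axes dropped from spec: {dropped}")
```

Returning the dropped axes rather than warning inside the helper keeps the filter pure, which makes it easy to test directly.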

Confidence Score: 5/5

Safe to merge; the custom partitioning logic is well-tested with direct unit tests and an end-to-end JIT+VJP test, the removal of dead kernel_fsdp_info plumbing is clean, and the new sharding.py helpers are self-contained.

All changes are additive (new partition/shardy_sharding_rule statics) or clean removals of previously-guarded-dead code. The two findings are speculative edge cases — the scale-input warning gap requires a caller to pass an unusual scale sharding, and the reduce_axis/out_spec inconsistency requires JAX's result_infos to supply a dp-sharded output spec in a configuration where both operands are dp-contracted, which does not arise in the documented EP/FSDP usage patterns. The test suite directly exercises the partition spec logic for all supported axis combinations including unsupported-axis stripping, and the single-process end-to-end test verifies correct sharding on forward and backward passes.

transformer_engine/jax/cpp_extensions/gemm.py and transformer_engine/jax/cpp_extensions/quantization.py warrant a second look for the two minor gaps noted above.
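To see why a reduction is needed when both operands are sharded along the contracting dimension, the following plain-Python sketch simulates two devices. The matrices, the device count, and the helper names are illustrative assumptions; the elementwise sum stands in for what `psum` would do across a mesh axis.

```python
def matmul(a, b):
    """Plain-Python matrix product of two list-of-lists matrices."""
    return [
        [sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
        for i in range(len(a))
    ]


def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[x + y for x, y in zip(ra, rb)] for ra, rb in zip(a, b)]


lhs = [[1, 2, 3, 4], [5, 6, 7, 8]]    # shape (2, 4), contracting dim is 1
rhs = [[1, 0], [0, 1], [2, 1], [1, 2]]  # shape (4, 2), contracting dim is 0

num_devices = 2                        # assumed size of the shared mesh axis
k_local = len(rhs) // num_devices      # contracting-dim slice per device

# Each simulated device multiplies only its slice of the contracting dim.
partials = []
for d in range(num_devices):
    lo, hi = d * k_local, (d + 1) * k_local
    lhs_shard = [row[lo:hi] for row in lhs]
    rhs_shard = rhs[lo:hi]
    partials.append(matmul(lhs_shard, rhs_shard))

# Stand-in for the all-reduce: sum the per-device partial products.
reduced = partials[0]
for p in partials[1:]:
    reduced = add(reduced, p)

full = matmul(lhs, rhs)
```

Each per-device result has the full output shape but holds only a partial sum, so declaring the output as sharded along that same axis without reducing would not describe the data correctly.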

Important Files Changed

| Filename | Overview |
| --- | --- |
| transformer_engine/jax/cpp_extensions/gemm.py | Adds _parse_partition_specs, partition, and shardy_sharding_rule to GroupedGemmPrimitive; the psum reduce_axis is not stripped from out_spec when result_infos provides a dp/fsdp-sharded output, which could declare an inconsistent output sharding. |
| transformer_engine/jax/cpp_extensions/quantization.py | Adds _parse_partition_specs, partition, and shardy_sharding_rule to GroupedQuantizePrimitive; scale input (arg_infos[1]) silently drops unsupported axes without a warning, unlike x and group_sizes. |
| transformer_engine/jax/sharding.py | Adds 15 new shared partitioning helpers (spec_axes, filter_spec_axes, merge_axis_specs, local_shape_from_spec, etc.), adds ep_resource to MeshResource, adds validate flag to global_mesh_resource, and fixes with_sharding_constraint to filter to auto-axes only. |
| transformer_engine/jax/dense.py | Removes dead kernel_fsdp_info plumbing and _all_gather_kernel/_psum_scatter_kernel helpers; adds negative contracting-dim sanitization upfront in grouped_dense. |
| transformer_engine/jax/flax/module.py | Changes x contracting dim from (1,) to (-1,) to align with the new sanitize_dims normalization in grouped_dense; semantically identical for 2D inputs. |
| tests/jax/test_distributed_grouped_gemm.py | New test file covering partition spec logic for both GroupedQuantizePrimitive and GroupedGemmPrimitive, plus a single-process end-to-end test with JIT, VJP, and sharding assertions. |
| tests/jax/test_multi_process_distributed_grouped_gemm.py | Deleted; replaced by the new single-process test and the multi-process test now runs via the updated test.sh. |
| qa/L1_jax_distributed_unittest/test.sh | Adds test_distributed_grouped_gemm.py to the CI suite and removes the commented-out TODO for the now-deleted multi-process test. |
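The claim that a contracting dim of `(-1,)` is equivalent to `(1,)` for 2D inputs follows from ordinary negative-index normalization. A minimal sketch, assuming a hypothetical helper named `normalize_contracting_dims` (the PR's own `sanitize_dims` may differ in signature and behavior):

```python
from typing import Sequence, Tuple


def normalize_contracting_dims(ndim: int, dims: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper: map negative dims to their non-negative form."""
    out = []
    for d in dims:
        if not -ndim <= d < ndim:
            raise ValueError(f"dim {d} is out of range for a {ndim}-D array")
        out.append(d % ndim)  # e.g. -1 becomes ndim - 1
    return tuple(out)


dims_2d = normalize_contracting_dims(2, (-1,))  # last dim of a 2D input
dims_3d = normalize_contracting_dims(3, (-1,))  # last dim of a 3D input
```

Normalizing once at the entry point means later code can compare dims by value without handling both spellings, and `(-1,)` keeps meaning "last dim" if the input rank changes.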

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["JAX jit / partition call"] --> B["GroupedGemmPrimitive._parse_partition_specs"]
    A --> C["GroupedQuantizePrimitive._parse_partition_specs"]

    B --> B1["filter_spec_axes → allowed EP/DP/FSDP only"]
    B1 --> B2{"rhs_is_ragged?"}
    B2 -- No --> B3{"ep_axis in group_spec?"}
    B3 -- Yes --> B4["Add ep_axis to rhs_data_spec[0]"]
    B2 -- Yes --> B5["skip ep injection"]
    B4 --> B6{"rhs has fsdp?"}
    B5 --> B6
    B6 -- Yes --> B7["strip fsdp from rhs → gather_rhs_fsdp=True"]
    B6 -- No --> B8["gather_rhs_fsdp=False"]
    B7 --> B9["reduce_axis = common_spec_axis(lhs, rhs, dp|fsdp)"]
    B8 --> B9
    B9 --> B10{"gather_rhs_fsdp AND reduce_axis set?"}
    B10 -- Yes --> B11["clear reduce_axis"]
    B10 -- No --> B12["keep reduce_axis"]
    B11 --> B13["return arg_specs, out_spec, reduce_axis"]
    B12 --> B13

    B13 --> B14["sharded_impl: call impl with local sizes"]
    B14 --> B15{"reduce_axis != None?"}
    B15 -- Yes --> B16["psum(out, reduce_axis)"]
    B15 -- No --> B17["return out as-is"]
    B16 --> B17

    C --> C1["filter_spec_axes → allowed EP/DP/FSDP only"]
    C1 --> C2["_contiguous_flat_input_spec: strip dims after flatten_axis"]
    C2 --> C3["_flat_data_spec: merge all dims → 1D spec"]
    C3 --> C4["assign rowwise/colwise/scale_inv/amax specs"]
    C4 --> C5["sharded_impl: call impl, pad/slice scale_inv if block_scaling"]
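The GroupedGemmPrimitive branch of the flowchart can be read as a small decision procedure. The sketch below follows the node order literally under several loudly stated assumptions: each operand's sharding is reduced to a flat set of axis names (ignoring which dimension carries them), the axis names `"ep"`, `"dp"`, and `"fsdp"` are hard-coded placeholders, and the common-axis check runs on the RHS axes as they stand after the FSDP strip, which the flowchart does not specify. `GemmPlan` and `plan_grouped_gemm` are hypothetical names, not the PR's code.

```python
from dataclasses import dataclass
from typing import FrozenSet, Iterable, Optional

ALLOWED = {"ep", "dp", "fsdp"}  # assumed axis names for illustration


@dataclass(frozen=True)
class GemmPlan:
    rhs_axes: FrozenSet[str]       # RHS sharding after filtering and gathering
    gather_rhs_fsdp: bool          # whether RHS is gathered along FSDP
    reduce_axis: Optional[str]     # axis to all-reduce the output over, if any


def plan_grouped_gemm(
    lhs_axes: Iterable[str],
    rhs_axes: Iterable[str],
    group_axes: Iterable[str],
    rhs_is_ragged: bool,
) -> GemmPlan:
    # B1: keep only supported axes (unsupported ones such as TP are dropped).
    lhs = set(lhs_axes) & ALLOWED
    rhs = set(rhs_axes) & ALLOWED
    group = set(group_axes) & ALLOWED

    # B2-B5: inject the EP axis into RHS unless RHS is ragged.
    if not rhs_is_ragged and "ep" in group:
        rhs.add("ep")

    # B6-B8: an FSDP-sharded RHS is gathered, so FSDP leaves its spec.
    gather_rhs_fsdp = "fsdp" in rhs
    if gather_rhs_fsdp:
        rhs.discard("fsdp")

    # B9: first DP/FSDP axis shared by both operands (assumed post-strip).
    common = [a for a in ("dp", "fsdp") if a in lhs and a in rhs]
    reduce_axis = common[0] if common else None

    # B10-B12: gathering RHS clears any pending reduction.
    if gather_rhs_fsdp and reduce_axis is not None:
        reduce_axis = None

    return GemmPlan(frozenset(rhs), gather_rhs_fsdp, reduce_axis)


dp_plan = plan_grouped_gemm({"dp"}, {"dp"}, set(), rhs_is_ragged=False)
fsdp_plan = plan_grouped_gemm({"fsdp"}, {"fsdp"}, {"ep"}, rhs_is_ragged=False)
```

In this reading, `dp_plan` requests a reduction over the shared axis, while `fsdp_plan` gathers the RHS instead and needs none.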

Reviews (2): Last reviewed commit: "Rename distributed grouped GEMM tests"

Comment thread transformer_engine/jax/dense.py Outdated
Comment thread transformer_engine/jax/cpp_extensions/gemm.py Outdated
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-custom-partition-rules branch from 027b3e6 to ff0407d Compare June 1, 2026 15:46
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-custom-partition-rules branch from 70893b7 to 1bd6b54 Compare June 2, 2026 20:11
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia jberchtold-nvidia force-pushed the jberchtold/gmm-custom-partition-rules branch from 59ff8e0 to 3c30c9b Compare June 2, 2026 20:25
pre-commit-ci Bot and others added 5 commits June 2, 2026 20:26
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
@jberchtold-nvidia
Collaborator Author

/te-ci L1 jax

@jberchtold-nvidia jberchtold-nvidia marked this pull request as ready for review June 3, 2026 23:15