[JAX] Grouped quant+GEMM custom partitioning rules #3058
jberchtold-nvidia wants to merge 12 commits into
…mm-custom-partition-rules
Greptile Summary
This PR adds custom partitioning rules (
Confidence Score: 5/5

Safe to merge; the custom partitioning logic is well-tested with direct unit tests and an end-to-end JIT+VJP test, the removal of dead kernel_fsdp_info plumbing is clean, and the new sharding.py helpers are self-contained.

All changes are additive (new partition/shardy_sharding_rule statics) or clean removals of previously-guarded dead code. The two findings are speculative edge cases: the scale-input warning gap requires a caller to pass an unusual scale sharding, and the reduce_axis/out_spec inconsistency requires JAX's result_infos to supply a dp-sharded output spec in a configuration where both operands are dp-contracted, which does not arise in the documented EP/FSDP usage patterns. The test suite directly exercises the partition spec logic for all supported axis combinations, including unsupported-axis stripping, and the single-process end-to-end test verifies correct sharding on the forward and backward passes.

transformer_engine/jax/cpp_extensions/gemm.py and transformer_engine/jax/cpp_extensions/quantization.py warrant a second look for the two minor gaps noted above.

Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A["JAX jit / partition call"] --> B["GroupedGemmPrimitive._parse_partition_specs"]
A --> C["GroupedQuantizePrimitive._parse_partition_specs"]
B --> B1["filter_spec_axes → allowed EP/DP/FSDP only"]
B1 --> B2{"rhs_is_ragged?"}
B2 -- No --> B3{"ep_axis in group_spec?"}
B3 -- Yes --> B4["Add ep_axis to rhs_data_spec[0]"]
B2 -- Yes --> B5["skip ep injection"]
B4 --> B6{"rhs has fsdp?"}
B5 --> B6
B6 -- Yes --> B7["strip fsdp from rhs → gather_rhs_fsdp=True"]
B6 -- No --> B8["gather_rhs_fsdp=False"]
B7 --> B9["reduce_axis = common_spec_axis(lhs, rhs, dp|fsdp)"]
B8 --> B9
B9 --> B10{"gather_rhs_fsdp AND reduce_axis set?"}
B10 -- Yes --> B11["clear reduce_axis"]
B10 -- No --> B12["keep reduce_axis"]
B11 --> B13["return arg_specs, out_spec, reduce_axis"]
B12 --> B13
B13 --> B14["sharded_impl: call impl with local sizes"]
B14 --> B15{"reduce_axis != None?"}
B15 -- Yes --> B16["psum(out, reduce_axis)"]
B15 -- No --> B17["return out as-is"]
B16 --> B17
C --> C1["filter_spec_axes → allowed EP/DP/FSDP only"]
C1 --> C2["_contiguous_flat_input_spec: strip dims after flatten_axis"]
C2 --> C3["_flat_data_spec: merge all dims → 1D spec"]
C3 --> C4["assign rowwise/colwise/scale_inv/amax specs"]
C4 --> C5["sharded_impl: call impl, pad/slice scale_inv if block_scaling"]
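The grouped GEMM branch of the flowchart (nodes B1 through B13) can be read as a short decision procedure. The sketch below is a plain-Python paraphrase of the diagram only, not the code in this PR: specs are modeled as tuples of mesh-axis names, the helper bodies are guesses that reuse the names `filter_spec_axes` and `common_spec_axis` from the diagram, and `plan_grouped_gemm` together with the default axis names `"ep"`, `"dp"`, and `"fsdp"` are hypothetical. The derivation of `out_spec` is not shown in the flowchart, so it is left out.

```python
from typing import Optional, Tuple

Spec = Tuple[Optional[str], ...]  # one mesh-axis name (or None) per array dim


def filter_spec_axes(spec: Spec, allowed: set) -> Spec:
    """Hypothetical helper: treat any axis outside `allowed` as replicated."""
    return tuple(a if a in allowed else None for a in spec)


def common_spec_axis(lhs: Spec, rhs: Spec, candidates) -> Optional[str]:
    """Hypothetical helper: first candidate axis present in both specs."""
    for axis in candidates:
        if axis in lhs and axis in rhs:
            return axis
    return None


def plan_grouped_gemm(lhs_spec, rhs_spec, group_spec, *, rhs_is_ragged,
                      ep_axis="ep", dp_axis="dp", fsdp_axis="fsdp"):
    """Follow nodes B1..B13 of the flowchart (illustrative only)."""
    allowed = {ep_axis, dp_axis, fsdp_axis}
    lhs = filter_spec_axes(lhs_spec, allowed)              # B1
    rhs = list(filter_spec_axes(rhs_spec, allowed))
    group = filter_spec_axes(group_spec, allowed)

    # B2..B5: a non-ragged rhs takes the EP axis on its leading dim.
    if not rhs_is_ragged and ep_axis in group:
        rhs[0] = ep_axis

    # B6..B8: an FSDP-sharded rhs is gathered, so its spec drops fsdp.
    gather_rhs_fsdp = fsdp_axis in rhs
    if gather_rhs_fsdp:
        rhs = [None if a == fsdp_axis else a for a in rhs]

    # B9: an axis shared by both operands requires a psum of the output.
    reduce_axis = common_spec_axis(lhs, tuple(rhs), (dp_axis, fsdp_axis))

    # B10..B12: the diagram clears the reduction when rhs was gathered.
    if gather_rhs_fsdp and reduce_axis is not None:
        reduce_axis = None

    return lhs, tuple(rhs), gather_rhs_fsdp, reduce_axis
```

Under these assumptions, a ragged rhs with both operands sharded on `"dp"` yields `reduce_axis == "dp"`, which corresponds to the `psum(out, reduce_axis)` step at node B16.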
Reviews (2): Last reviewed commit: "Rename distributed grouped GEMM tests"
027b3e6 to ff0407d
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
70893b7 to 1bd6b54
59ff8e0 to 3c30c9b
/te-ci L1 jax
Description
Adds custom partitioning rules to the grouped quantization and grouped GEMM primitives to support data-parallel (DP), fully sharded data-parallel (FSDP), and expert-parallel (EP) shardings.
Type of change
Changes
Checklist: