diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 4f92d1c783..9d481d4332 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -25,6 +25,8 @@ export XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_triton_gemm=false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_dense.xml $TE_PATH/tests/jax/test_distributed_dense.py || test_fail "test_distributed_dense.py" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_grouped_gemm.xml $TE_PATH/tests/jax/test_distributed_grouped_gemm.py || test_fail "test_distributed_grouped_gemm.py" + python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_helper.xml $TE_PATH/tests/jax/test_distributed_helper.py || test_fail "test_distributed_helper.py" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_layernorm.xml $TE_PATH/tests/jax/test_distributed_layernorm.py || test_fail "test_distributed_layernorm.py" @@ -37,9 +39,6 @@ XLA_FLAGS="$XLA_FLAGS --xla_gpu_enable_nccl_comm_splitting=false" python3 -m pyt python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_dist_fused_attn.xml $TE_PATH/tests/jax/test_distributed_fused_attn.py || test_fail "test_distributed_fused_attn.py" -# TODO(Phuong): add this test back after it is verified -# SCRIPT_NAME=$TE_PATH/tests/jax/test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh || test_fail "test_multi_process_distributed_grouped_gemm.py" - if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" exit 1 diff --git a/tests/jax/test_distributed_grouped_gemm.py b/tests/jax/test_distributed_grouped_gemm.py new file mode 100644 index 0000000000..be2487a8df --- /dev/null +++ b/tests/jax/test_distributed_grouped_gemm.py @@ -0,0 +1,503 @@ +# Copyright (c) 2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Partitioning tests for grouped quantize and grouped GEMM.""" + +from types import SimpleNamespace + +import jax +import jax.numpy as jnp +import numpy as np +import pytest +from jax.sharding import Mesh, NamedSharding, PartitionSpec + +from transformer_engine.jax.cpp_extensions.gemm import GroupedGemmPrimitive +from transformer_engine.jax.cpp_extensions.quantization import GroupedQuantizePrimitive +from transformer_engine.jax.dense import grouped_dense +from transformer_engine.jax.quantize import QuantizeLayout, QuantizerFactory, ScalingMode +from transformer_engine.jax.sharding import MeshResource, global_shard_guard + + +def _mesh(): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh(np.asarray(devices[:4]).reshape(2, 2), ("expert", "fsdp")) + + +def _mesh_with_dp_tp(): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh(np.asarray(devices[:4]).reshape(2, 1, 2, 1), ("expert", "dp", "fsdp", "tp")) + + +def _mesh_with_arbitrary_axis(): + devices = jax.devices() + if len(devices) < 4: + pytest.skip("Grouped GEMM partitioning tests require at least 4 visible GPUs.") + return Mesh( + np.asarray(devices[:4]).reshape(2, 1, 2, 1), + ("expert", "dp", "fsdp", "myaxis123"), + ) + + +def _arg_info(mesh, shape, spec): + return SimpleNamespace( + shape=shape, + ndim=len(shape), + size=int(np.prod(shape)), + sharding=NamedSharding(mesh, PartitionSpec(*spec)), + ) + + +def _normalize_spec(spec): + if isinstance(spec, PartitionSpec): + return tuple(spec) + return spec + + +def _spec_contains_axis(spec, axis): + for axis_spec in spec: + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + if axis in axis_tuple: + return True + return False + + +def _mxfp8_grouped_quantizer_set(n_groups): + return QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=True, + n_groups=n_groups, + ) + + +def test_grouped_quantize_gathers_hidden_axis_for_block_scales(): + mesh = _mesh() + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 64), ("expert", None, "fsdp")), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (8,), ("expert",)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", None, None) + specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(specs[0]) == ("expert",) + assert _normalize_spec(specs[2]) == ("expert",) + assert _normalize_spec(specs[4]) == ("expert",) + + +def test_grouped_quantize_mxfp8_colwise_specs_gather_hidden_axis(): + mesh = _mesh() + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE_COLWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", None, "fsdp")), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (8,), ("expert",)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", None, None) + specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(specs[0]) == ("expert",) + assert _normalize_spec(specs[1]) == ("expert",) + assert _normalize_spec(specs[2]) == ("expert",) + assert _normalize_spec(specs[3]) == ("expert",) + assert _normalize_spec(specs[4]) == ("expert",) + + +def test_grouped_quantize_preserves_row_side_fsdp_for_kernel(): + mesh = _mesh() + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 64), ("expert", "fsdp", None)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (8,), ("expert",)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", "fsdp", None) + specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(specs[0]) == (("expert", "fsdp"),) + assert _normalize_spec(specs[2]) == (("expert", "fsdp"),) + + +def test_grouped_quantize_strips_unsupported_axes_and_gathers_hidden_axes(): + mesh = _mesh_with_dp_tp() + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp", ep_resource="expert") + ): + with pytest.warns(RuntimeWarning, match="Grouped quantize.*tp"): + _, _, out_shardings, arg_shardings = GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "dp", ("fsdp", "tp"))), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + ), + (), + ) + + assert tuple(arg_shardings[0].spec) == ("expert", "dp", None) + assert tuple(arg_shardings[1].spec) == ("expert",) + assert tuple(arg_shardings[2].spec) == ("expert",) + + out_specs = tuple(tuple(sharding.spec) for sharding in out_shardings) + assert _normalize_spec(out_specs[0]) == (("expert", "dp"),) + assert _normalize_spec(out_specs[2]) == (("expert", "dp"),) + assert _normalize_spec(out_specs[4]) == ("expert",) + for spec in (*out_specs, *(tuple(sharding.spec) for sharding in arg_shardings)): + assert not _spec_contains_axis(spec, "tp") + + +def test_grouped_gemm_rhs_weight_specs_gather_fsdp_but_preserve_ep(): + mesh = _mesh() + arg_infos = ( + _arg_info(mesh, (8192,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (65536,), (("expert", "fsdp"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (1,), (None,)), + _arg_info(mesh, (0,), (None,)), + ) + with global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + _, _, out_sharding, arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + arg_infos, + (), + ) + + assert tuple(arg_shardings[2].spec) == ("expert",) + assert tuple(arg_shardings[3].spec) == ("expert",) + assert tuple(out_sharding[0].spec) == (None, None, None) + + +def test_grouped_gemm_strips_unsupported_axes_preserves_dp_and_gathers_rhs_fsdp(): + mesh = _mesh_with_dp_tp() + arg_infos = ( + _arg_info(mesh, (8192,), (("dp", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (65536,), (("expert", "fsdp", "tp"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp", "tp"),)), + _arg_info(mesh, (0,), (("fsdp", "tp"),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (8,), (("expert", "tp"),)), + _arg_info(mesh, (0,), (("tp",),)), + _arg_info(mesh, (1,), (("tp",),)), + _arg_info(mesh, (0,), (("tp",),)), + ) + result_infos = (_arg_info(mesh, (1, 128, 64), ("expert", "tp", None)),) + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", tp_resource="tp", fsdp_resource="fsdp", ep_resource="expert") + ): + with pytest.warns(RuntimeWarning, match="Grouped GEMM.*tp"): + _, _, out_sharding, arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + arg_infos, + result_infos, + ) + + assert tuple(arg_shardings[0].spec) == ("dp",) + assert tuple(arg_shardings[2].spec) == ("expert",) + assert tuple(arg_shardings[3].spec) == ("expert",) + assert tuple(arg_shardings[5].spec) == ("expert",) + assert tuple(out_sharding[0].spec) == ("expert", None, None) + for spec in ( + *(tuple(sharding.spec) for sharding in arg_shardings), + tuple(out_sharding[0].spec), + ): + assert not _spec_contains_axis(spec, "tp") + + +def test_grouped_gemm_reduce_axis_skips_ep_and_uses_dp(): + mesh = _mesh_with_dp_tp() + arg_infos = ( + _arg_info(mesh, (8192,), (("expert", "dp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8192,), (("expert", "dp"),)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (8,), ("expert",)), + _arg_info(mesh, (0,), (None,)), + _arg_info(mesh, (1,), (None,)), + _arg_info(mesh, (0,), (None,)), + ) + + with jax.set_mesh(mesh), global_shard_guard( + MeshResource(dp_resource="dp", fsdp_resource="fsdp", ep_resource="expert") + ): + _, _, reduce_axis = GroupedGemmPrimitive._parse_partition_specs( + mesh, + arg_infos, + (), + out_shape=(1, 128, 64), + lhs_is_trans=False, + lhs_axis_boundary=1, + ) + + assert reduce_axis == "dp" + + +def test_grouped_partitioning_strips_arbitrary_unsupported_axis(): + mesh = _mesh_with_arbitrary_axis() + mesh_resource = MeshResource(dp_resource="dp", fsdp_resource="fsdp", ep_resource="expert") + + with jax.set_mesh(mesh), global_shard_guard(mesh_resource): + with pytest.warns(RuntimeWarning, match="Grouped quantize.*myaxis123"): + _, _, quantize_out_shardings, quantize_arg_shardings = ( + GroupedQuantizePrimitive.partition( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + _arg_info(mesh, (8, 128, 128), ("expert", "myaxis123", ("dp", "fsdp"))), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + ), + (), + ) + ) + + gemm_arg_infos = ( + _arg_info(mesh, (8192,), (("dp", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (65536,), (("expert", "fsdp", "myaxis123"),)), + _arg_info(mesh, (2048,), (("expert", "fsdp", "myaxis123"),)), + _arg_info(mesh, (0,), (("fsdp", "myaxis123"),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (8,), (("expert", "myaxis123"),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + _arg_info(mesh, (1,), (("myaxis123",),)), + _arg_info(mesh, (0,), (("myaxis123",),)), + ) + gemm_result_infos = (_arg_info(mesh, (1, 128, 64), ("expert", "myaxis123", None)),) + with pytest.warns(RuntimeWarning, match="Grouped GEMM.*myaxis123"): + _, _, gemm_out_sharding, gemm_arg_shardings = GroupedGemmPrimitive.partition( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 1, + (1, 128, 64), + 128, + 64, + 128, + 64, + mesh, + gemm_arg_infos, + gemm_result_infos, + ) + + assert tuple(quantize_arg_shardings[0].spec) == ("expert", None, None) + assert tuple(quantize_arg_shardings[1].spec) == ("expert",) + quantize_out_specs = tuple(tuple(sharding.spec) for sharding in quantize_out_shardings) + assert _normalize_spec(quantize_out_specs[0]) == ("expert",) + assert _normalize_spec(quantize_out_specs[2]) == ("expert",) + + assert tuple(gemm_arg_shardings[0].spec) == ("dp",) + assert tuple(gemm_arg_shardings[2].spec) == ("expert",) + assert tuple(gemm_arg_shardings[3].spec) == ("expert",) + assert tuple(gemm_out_sharding[0].spec) == ("expert", None, None) + + all_specs = ( + *quantize_out_specs, + *(tuple(sharding.spec) for sharding in quantize_arg_shardings), + *(tuple(sharding.spec) for sharding in gemm_arg_shardings), + tuple(gemm_out_sharding[0].spec), + ) + for spec in all_specs: + assert not _spec_contains_axis(spec, "myaxis123") + + +def test_grouped_partitioning_shardy_rules_smoke(): + mesh = _mesh() + quantize_rule = GroupedQuantizePrimitive.shardy_sharding_rule( + jnp.float8_e4m3fn, + ScalingMode.MXFP8_1D_SCALING.value, + QuantizeLayout.ROWWISE, + -1, + jnp.float8_e8m0fnu, + mesh, + ( + SimpleNamespace(shape=(8, 128, 64)), + SimpleNamespace(shape=(8,)), + SimpleNamespace(shape=(8,)), + ), + ( + SimpleNamespace(shape=(8 * 128 * 64,)), + SimpleNamespace(shape=(1,)), + SimpleNamespace(shape=(8 * 128 * 64,)), + SimpleNamespace(shape=(1,)), + SimpleNamespace(shape=(8,)), + ), + ) + gemm_rule = GroupedGemmPrimitive.shardy_sharding_rule( + False, + False, + ScalingMode.NO_SCALING.value, + jnp.bfloat16, + False, + False, + False, + 1, + 2, + (128, 64), + 128, + 64, + 128, + 64, + mesh, + tuple(SimpleNamespace(shape=(1,)) for _ in range(13)), + (SimpleNamespace(shape=(128, 64)),), + ) + + assert quantize_rule is not None + assert gemm_rule is not None + + +def test_grouped_dense_mxfp8_ep_fsdp_outside_shard_map_single_process(): + mesh = _mesh() + n_groups = 4 + group_tokens = 128 + hidden = 256 + out_hidden = 128 + x_shape = (n_groups * group_tokens, hidden) + w_shape = (n_groups, hidden, out_hidden) + + x_sharding = NamedSharding(mesh, PartitionSpec("expert", None)) + w_sharding = NamedSharding(mesh, PartitionSpec("expert", "fsdp", None)) + group_sharding = NamedSharding(mesh, PartitionSpec("expert")) + out_sharding = NamedSharding(mesh, PartitionSpec("expert", None)) + + quantizer_set = _mxfp8_grouped_quantizer_set(n_groups) + + with mesh, global_shard_guard(MeshResource(fsdp_resource="fsdp", ep_resource="expert")): + x = jax.device_put( + jax.random.normal(jax.random.PRNGKey(20), x_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + x_sharding, + ) + w = jax.device_put( + jax.random.normal(jax.random.PRNGKey(21), w_shape, dtype=jnp.bfloat16) + * jnp.asarray(0.01, dtype=jnp.bfloat16), + w_sharding, + ) + group_sizes = jax.device_put( + jnp.full((n_groups,), group_tokens, dtype=jnp.int32), + group_sharding, + ) + + def apply_with_vjp(x, w, group_sizes): + def apply(x, w): + return grouped_dense( + x, + w, + group_sizes, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + + out, vjp_fn = jax.vjp(apply, x, w) + dx, dw = vjp_fn(out) + return out, dx, dw + + out, dx, dw = jax.jit( + apply_with_vjp, + in_shardings=(x_sharding, w_sharding, group_sharding), + out_shardings=(out_sharding, x_sharding, w_sharding), + )(x, w, group_sizes) + out, dx, dw = jax.block_until_ready((out, dx, dw)) + + assert tuple(out.sharding.spec) == ("expert", None) + assert tuple(dx.sharding.spec) == ("expert", None) + assert tuple(dw.sharding.spec) == ("expert", "fsdp", None) + for value in (out, dx, dw): + local_value = np.asarray(jax.device_get(value.addressable_data(0))) + assert np.all(np.isfinite(local_value)) + assert np.any(local_value != 0.0) diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py deleted file mode 100644 index 94fed0859f..0000000000 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -from functools import partial - -import jax -import jax.numpy as jnp -import jax.experimental.multihost_utils as jem - -from transformer_engine.jax.dense import grouped_dense as te_grouped_dense -from transformer_engine.jax.quantize import ( - QuantizerFactory, - ScalingMode, -) - -from utils import assert_allclose, dtype_tols - - -N_GROUP = 8 -MESH_AXIS_NAME = "fsdp" - - -def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): - assert kernel_fsdp_axis in [1, 2] - x_shape, w_shape = data_shapes - - x_sharding = NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None, None, None)) - w_sharding = ( - NamedSharding(mesh, PartitionSpec(None, None, MESH_AXIS_NAME)) - if kernel_fsdp_axis == 2 - else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None)) - ) - w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None)) - - def init_data(): - x_key = jax.random.PRNGKey(0) - w_key = jax.random.PRNGKey(1) - x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) - w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) - w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim)) - return x, w, w, w_amax - - def test_func(outter_x, outter_w, outter_w_amax): - in_specs = (x_sharding.spec, w_sharding.spec, None) - out_specs = x_sharding.spec - - @partial( - shard_map.shard_map, - mesh=mesh, - in_specs=in_specs, - out_specs=out_specs, - check_rep=False, - ) - def sharded_group_gemm(x, w, w_amax): - group_size = x.shape[0] - x_reshaped = x.reshape(-1, x.shape[-1]) - n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) - - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, - is_2x2x=True, - n_groups=group_size, - ) - - output = te_grouped_dense( - x_reshaped, - w, - n_groups, - kernel_amax=w_amax, - quantizer_set=quantizer_set, - kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis), - ) - output = output.reshape(*x.shape[:-1], -1) - return output - - def run(x, w, w_amax): - output = sharded_group_gemm(x, w, w_amax) - return output - - output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_w_amax) - dx, dw, _ = vjp_fn(output) - return output, dx, dw - - def ref_func(outter_x, outter_w): - - in_specs = (x_sharding.spec, w_no_sharding.spec) - out_specs = x_sharding.spec - - @partial( - shard_map.shard_map, - mesh=mesh, - in_specs=in_specs, - out_specs=out_specs, - check_rep=False, - ) - def sharded_group_gemm(x, w): - group_size = x.shape[0] - x_reshaped = x.reshape(-1, x.shape[-1]) - n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) - - quantizer_set = QuantizerFactory.create_set( - scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, - fwd_dtype=jnp.float8_e4m3fn, - bwd_dtype=jnp.float8_e5m2, - is_2x2x=True, - n_groups=group_size, - ) - output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set) - output = output.reshape(*x.shape[:-1], -1) - return output - - def run(x, w): - output = sharded_group_gemm(x, w) - return output - - output, vjp_fn = jax.vjp(run, outter_x, outter_w) - dx, dw = vjp_fn(output) - return output, dx, dw - - init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None)) - x, w, w_global, w_amax = init_func() - - o_sharding = x_sharding - test_func_jitted = jax.jit( - test_func, - in_shardings=(x_sharding, w_sharding, None), - out_shardings=(o_sharding, x_sharding, w_sharding), - ) - ref_func_jitted = jax.jit( - ref_func, - in_shardings=(x_sharding, w_no_sharding), - out_shardings=(o_sharding, x_sharding, w_no_sharding), - ) - - out, dx, dw = test_func_jitted(x, w, w_amax) - ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) - - e4m3_tols = dtype_tols(jnp.float8_e4m3fn) - e5m2_tols = dtype_tols(jnp.float8_e5m2) - - out, ref_out = jem.process_allgather((out, ref_out)) - dx, ref_dx = jem.process_allgather((dx, ref_dx)) - dw, ref_dw = jem.process_allgather((dw, ref_dw)) - - jnp.allclose(out, ref_out, **e4m3_tols) - jnp.allclose(dx, ref_dx, **e5m2_tols) - jnp.allclose(dw, ref_dw, **e5m2_tols) - - -if __name__ == "__main__": - from jax.sharding import NamedSharding, PartitionSpec - from jax.experimental import shard_map - import sys - - coord_addr = sys.argv[1] - proc_id = int(sys.argv[2]) - num_procs = int(sys.argv[3]) - - jax.distributed.initialize( - coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id - ) - - mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,)) - - with mesh: - data_shapes = [((4, 16, 128, 7168), (7168, 2048))] - for data_shape in data_shapes: - for kernel_fsdp_axis in [1, 2]: - test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 4ff6d07986..655aa6e3f9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -49,7 +49,16 @@ ) from .misc import get_padded_spec, is_all_reduce_in_float32, get_min_device_compute_capability from ..sharding import ( + common_spec_axis, + filter_spec_axes, global_mesh_resource, + local_2d_sizes_from_spec, + local_shape_from_spec, + merge_axis_specs, + spec_axes, + spec_contains_axis, + strip_axis_from_spec, + supported_grouped_partition_axes, tpsp_axis_size, dp_or_fsdp_axis_size, ) @@ -211,6 +220,20 @@ def _get_nvfp4_tensor_scale_inv(amax): return amax / (DATA_DTYPE_MAX * SCALE_DTYPE_MAX) +def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): + ignored_axes = tuple( + axis for axis in spec_axes(original_spec) if axis not in spec_axes(partition_spec) + ) + if ignored_axes: + warnings.warn( + "Grouped GEMM custom partitioning will ignore/replicate sharding " + f"axes {ignored_axes} from {arg_name}; only DP/FSDP/EP grouped " + "partitioning axes are preserved.", + RuntimeWarning, + stacklevel=3, + ) + + def collective_gemm_bootstrap( num_total_devices, num_devices_per_process, @@ -1738,6 +1761,320 @@ def impl( ) return (out,) + @staticmethod + def _parse_partition_specs( + mesh, + arg_infos, + result_infos, + out_shape=None, + lhs_is_trans=None, + lhs_axis_boundary=None, + ): + gsr = global_mesh_resource(validate=False) + fsdp_axis = gsr.fsdp_resource + allowed_axes = supported_grouped_partition_axes(mesh) + + original_arg_specs = tuple(get_padded_spec(arg_info) for arg_info in arg_infos) + lhs_data_spec = filter_spec_axes(original_arg_specs[0], allowed_axes) + lhs_scale_spec = filter_spec_axes(original_arg_specs[1], allowed_axes) + rhs_data_spec = filter_spec_axes(original_arg_specs[2], allowed_axes) + rhs_scale_spec = filter_spec_axes(original_arg_specs[3], allowed_axes) + bias_spec = filter_spec_axes(original_arg_specs[4], allowed_axes) + + lhs_first_dims_spec = filter_spec_axes(original_arg_specs[5], allowed_axes) + lhs_last_dims_spec = filter_spec_axes(original_arg_specs[6], allowed_axes) + rhs_first_dims_spec = filter_spec_axes(original_arg_specs[7], allowed_axes) + rhs_last_dims_spec = filter_spec_axes(original_arg_specs[8], allowed_axes) + out_first_dims_spec = filter_spec_axes(original_arg_specs[9], allowed_axes) + out_last_dims_spec = filter_spec_axes(original_arg_specs[10], allowed_axes) + additional_arg_0_spec = filter_spec_axes(original_arg_specs[11], allowed_axes) + additional_arg_1_spec = filter_spec_axes(original_arg_specs[12], allowed_axes) + + grouped_dim_specs = ( + lhs_first_dims_spec, + lhs_last_dims_spec, + rhs_first_dims_spec, + rhs_last_dims_spec, + out_first_dims_spec, + out_last_dims_spec, + ) + grouped_dim_infos = arg_infos[5:11] + active_group_spec = next( + (spec for spec, info in zip(grouped_dim_specs, grouped_dim_infos) if info.size > 0), + (None,), + ) + if arg_infos[11].size > 1: + additional_arg_0_spec = active_group_spec + if arg_infos[12].size > 1: + additional_arg_1_spec = active_group_spec + + rhs_is_ragged = arg_infos[7].size > 0 or arg_infos[8].size > 0 + ep_axis = gsr.ep_resource + if ( + ep_axis is not None + and not rhs_is_ragged + and spec_contains_axis(active_group_spec, ep_axis) + ): + if len(rhs_data_spec) > 0 and not spec_contains_axis(rhs_data_spec, ep_axis): + rhs_data_spec = ( + merge_axis_specs(rhs_data_spec[0], ep_axis), + *rhs_data_spec[1:], + ) + if len(rhs_scale_spec) > 0 and not spec_contains_axis(rhs_scale_spec, ep_axis): + rhs_scale_spec = ( + merge_axis_specs(rhs_scale_spec[0], ep_axis), + *rhs_scale_spec[1:], + ) + if len(bias_spec) > 0 and not spec_contains_axis(bias_spec, ep_axis): + bias_spec = (merge_axis_specs(bias_spec[0], ep_axis), *bias_spec[1:]) + + gather_rhs_fsdp = ( + fsdp_axis is not None + and not rhs_is_ragged + and ( + spec_contains_axis(rhs_data_spec, fsdp_axis) + or spec_contains_axis(rhs_scale_spec, fsdp_axis) + or spec_contains_axis(bias_spec, fsdp_axis) + ) + ) + + if gather_rhs_fsdp: + rhs_data_spec = strip_axis_from_spec(rhs_data_spec, fsdp_axis) + rhs_scale_spec = strip_axis_from_spec(rhs_scale_spec, fsdp_axis) + bias_spec = strip_axis_from_spec(bias_spec, fsdp_axis) + + reducible_axes = tuple( + axis for axis in (gsr.dp_resource, gsr.fsdp_resource) if axis is not None + ) + reduce_axis = common_spec_axis(lhs_data_spec, rhs_data_spec, reducible_axes) + if reduce_axis is not None and gather_rhs_fsdp: + reduce_axis = None + + if result_infos: + original_out_spec = get_padded_spec(result_infos[0]) + out_spec = filter_spec_axes(original_out_spec, allowed_axes) + else: + original_out_spec = None + out_spec = (None,) * (len(out_shape) if out_shape is not None else 1) + + if rhs_is_ragged and lhs_is_trans is not None and lhs_axis_boundary is not None: + lhs_non_contracting_dims = ( + range(lhs_axis_boundary, len(lhs_data_spec)) + if lhs_is_trans + else range(0, lhs_axis_boundary) + ) + lhs_data_spec = list(lhs_data_spec) + for out_idx, lhs_dim in enumerate(lhs_non_contracting_dims, start=1): + if out_idx < len(out_spec): + lhs_data_spec[lhs_dim] = merge_axis_specs( + lhs_data_spec[lhs_dim], out_spec[out_idx] + ) + lhs_data_spec = tuple(lhs_data_spec) + + final_arg_specs = ( + lhs_data_spec, + lhs_scale_spec, + rhs_data_spec, + rhs_scale_spec, + bias_spec, + lhs_first_dims_spec, + lhs_last_dims_spec, + rhs_first_dims_spec, + rhs_last_dims_spec, + out_first_dims_spec, + out_last_dims_spec, + additional_arg_0_spec, + additional_arg_1_spec, + ) + for arg_name, original_spec, partition_spec in zip( + ( + "lhs_data", + "lhs_scale_inv", + "rhs_data", + "rhs_scale_inv", + "bias", + "lhs_first_dims", + "lhs_last_dims", + "rhs_first_dims", + "rhs_last_dims", + "out_first_dims", + "out_last_dims", + "additional_arg_0", + "additional_arg_1", + ), + original_arg_specs, + final_arg_specs, + ): + _warn_if_axes_ignored(arg_name, original_spec, partition_spec) + if original_out_spec is not None: + _warn_if_axes_ignored("output", original_out_spec, out_spec) + + return ( + final_arg_specs, + out_spec, + reduce_axis, + ) + + @staticmethod + def partition( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + arg_infos, + result_infos, + ): + arg_specs, out_spec, reduce_axis = GroupedGemmPrimitive._parse_partition_specs( + mesh, + arg_infos, + result_infos, + out_shape, + lhs_is_trans=lhs_is_trans, + lhs_axis_boundary=lhs_axis_boundary, + ) + arg_shardings = tuple(NamedSharding(mesh, PartitionSpec(*spec)) for spec in arg_specs) + out_sharding = (NamedSharding(mesh, PartitionSpec(*out_spec)),) + local_out_shape = local_shape_from_spec(out_shape, out_spec, mesh) + local_lhs_left_size, local_lhs_right_size = local_2d_sizes_from_spec( + arg_infos[0].shape, + arg_specs[0], + lhs_axis_boundary, + lhs_left_size, + lhs_right_size, + mesh, + ) + local_rhs_left_size, local_rhs_right_size = local_2d_sizes_from_spec( + arg_infos[2].shape, + arg_specs[2], + rhs_axis_boundary, + rhs_left_size, + rhs_right_size, + mesh, + ) + + def sharded_impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, + additional_arg_0, + additional_arg_1, + ): + (out,) = GroupedGemmPrimitive.impl( + lhs_data, + lhs_scale_inv, + rhs_data, + rhs_scale_inv, + bias, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, + additional_arg_0, + additional_arg_1, + lhs_is_trans=lhs_is_trans, + rhs_is_trans=rhs_is_trans, + scaling_mode=scaling_mode, + out_dtype=out_dtype, + has_bias=has_bias, + use_async_d2h_group_sizes=use_async_d2h_group_sizes, + use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + out_shape=local_out_shape, + lhs_left_size=local_lhs_left_size, + lhs_right_size=local_lhs_right_size, + rhs_left_size=local_rhs_left_size, + rhs_right_size=local_rhs_right_size, + ) + + if reduce_axis is not None: + if is_all_reduce_in_float32(): + out = jax.lax.psum(out.astype(jnp.float32), reduce_axis).astype(out_dtype) + else: + out = jax.lax.psum(out, reduce_axis) + return (out,) + + return mesh, sharded_impl, out_sharding, arg_shardings + + @staticmethod + def shardy_sharding_rule( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + operand_types, + result_types, + ): + del ( + lhs_is_trans, + rhs_is_trans, + scaling_mode, + out_dtype, + has_bias, + use_async_d2h_group_sizes, + use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + out_shape, + lhs_left_size, + lhs_right_size, + rhs_left_size, + rhs_right_size, + mesh, + ) + + prefix = "GroupedGemm" + + def spec_for(name, rank): + if rank == 0: + return () + return tuple(f"{prefix}_{name}_{i}" for i in range(rank)) + + operand_mappings = tuple( + spec_for(f"arg{i}", len(operand_type.shape)) + for i, operand_type in enumerate(operand_types) + ) + result_mappings = tuple( + spec_for(f"out{i}", len(result_type.shape)) + for i, result_type in enumerate(result_types) + ) + return SdyShardingRule( + operand_mappings=operand_mappings, + result_mappings=result_mappings, + ) + register_primitive(GroupedGemmPrimitive) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 7138cfcf40..9266ab08f0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -6,6 +6,7 @@ from functools import reduce from typing import Tuple, Optional, Union import math +import warnings import jax @@ -31,7 +32,14 @@ from ..sharding import ( all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp, + filter_spec_axes, get_num_devices_in_mesh, + global_mesh_resource, + lax_paral_op, + local_shape_from_spec, + merge_axis_specs, + spec_axes, + supported_grouped_partition_axes, ) from ..quantize import ( ScaledTensor2x, @@ -52,6 +60,54 @@ __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] +def _flat_data_spec(input_spec): + return (merge_axis_specs(*input_spec),) + + +def _normalize_flatten_axis(flatten_axis, ndim): + return flatten_axis + ndim if flatten_axis < 0 else flatten_axis + + +def _contiguous_flat_input_spec(input_spec, flatten_axis): + flatten_axis = _normalize_flatten_axis(flatten_axis, len(input_spec)) + if flatten_axis <= 0 or len(input_spec) == 0: + return (None,) * len(input_spec) + return (*input_spec[:flatten_axis], *((None,) * (len(input_spec) - flatten_axis))) + + +def _warn_if_axes_ignored(arg_name, original_spec, partition_spec): + ignored_axes = tuple( + axis for axis in spec_axes(original_spec) if axis not in spec_axes(partition_spec) + ) + if ignored_axes: + warnings.warn( + "Grouped quantize custom partitioning will ignore/replicate sharding " + f"axes {ignored_axes} from {arg_name}; only supported packed grouped " + "data axes are preserved.", + RuntimeWarning, + stacklevel=3, + ) + + +def _pad_or_slice_to_shape(x, target_shape): + if target_shape is None or x.shape == target_shape: + return x + target_size = math.prod(target_shape) + current_size = math.prod(x.shape) + x = x.reshape(-1) + if current_size > target_size: + return x[:target_size].reshape(target_shape) + return jnp.pad(x, (0, target_size - current_size)).reshape(target_shape) + + +def _all_reduce_grouped_amax_along_dp_fsdp(amax, mesh): + gsr = global_mesh_resource() + for axis in (gsr.dp_resource, gsr.fsdp_resource): + if axis is not None and axis in mesh.axis_names: + amax = lax_paral_op(amax, jax.lax.pmax, axis, mesh) + return amax + + class BaseDBiasQuantizePrimitive(BasePrimitive): """ Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias @@ -1236,6 +1292,152 @@ def impl( ) return rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax + @staticmethod + def _parse_partition_specs(scaling_mode, q_layout, flatten_axis, mesh, arg_infos): + allowed_axes = supported_grouped_partition_axes(mesh) + original_x_spec = get_padded_spec(arg_infos[0]) + x_spec = filter_spec_axes(original_x_spec, allowed_axes) + x_spec = _contiguous_flat_input_spec(x_spec, flatten_axis) + _warn_if_axes_ignored("x", original_x_spec, x_spec) + + original_group_spec = get_padded_spec(arg_infos[2]) + group_spec = filter_spec_axes(original_group_spec, allowed_axes) + if group_spec == (None,) and len(x_spec) > 0: + group_spec = (x_spec[0],) + _warn_if_axes_ignored("group_sizes", original_group_spec, group_spec) + flat_spec = _flat_data_spec(x_spec) + replicated_spec = (None,) + + rowwise_out_spec = flat_spec if q_layout.has_rowwise else replicated_spec + colwise_out_spec = flat_spec if q_layout.has_colwise else replicated_spec + + rowwise_scale_inv_spec = replicated_spec + colwise_scale_inv_spec = replicated_spec + if ScalingMode(scaling_mode).is_block_scaling: + rowwise_scale_inv_spec = flat_spec if q_layout.has_rowwise else replicated_spec + colwise_scale_inv_spec = flat_spec if q_layout.has_colwise else replicated_spec + elif ScalingMode(scaling_mode).is_tensor_scaling(): + rowwise_scale_inv_spec = group_spec if q_layout.has_rowwise else replicated_spec + colwise_scale_inv_spec = group_spec if q_layout.has_colwise else replicated_spec + + updated_amax_spec = group_spec + return ( + x_spec, + group_spec, + ( + rowwise_out_spec, + colwise_out_spec, + rowwise_scale_inv_spec, + colwise_scale_inv_spec, + updated_amax_spec, + ), + ) + + @staticmethod + def partition( + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + scale_dtype, + mesh, + arg_infos, + result_infos, + ): + x_spec, group_spec, out_specs = GroupedQuantizePrimitive._parse_partition_specs( + scaling_mode, q_layout, flatten_axis, mesh, arg_infos + ) + local_out_shapes = ( + tuple( + local_shape_from_spec(info.shape, spec, mesh) + for info, spec in zip(result_infos, out_specs) + ) + if result_infos + else (None,) * len(out_specs) + ) + + arg_shardings = ( + NamedSharding(mesh, PartitionSpec(*x_spec)), + NamedSharding(mesh, PartitionSpec(*group_spec)), + NamedSharding(mesh, PartitionSpec(*group_spec)), + ) + out_shardings = tuple(NamedSharding(mesh, PartitionSpec(*spec)) for spec in out_specs) + + def sharded_impl(x, scale, group_sizes): + ( + rowwise_out, + colwise_out, + rowwise_scale_inv, + colwise_scale_inv, + updated_amax, + ) = GroupedQuantizePrimitive.impl( + x, + scale, + group_sizes, + out_dtype=out_dtype, + scaling_mode=scaling_mode, + q_layout=q_layout, + flatten_axis=flatten_axis, + scale_dtype=scale_dtype, + ) + if ScalingMode(scaling_mode).is_block_scaling: + rowwise_scale_inv = _pad_or_slice_to_shape(rowwise_scale_inv, local_out_shapes[2]) + colwise_scale_inv = _pad_or_slice_to_shape(colwise_scale_inv, local_out_shapes[3]) + if ScalingMode(scaling_mode).is_tensor_scaling(): + updated_amax = _all_reduce_grouped_amax_along_dp_fsdp(updated_amax, mesh) + return ( + rowwise_out, + colwise_out, + rowwise_scale_inv, + colwise_scale_inv, + updated_amax, + ) + + return mesh, sharded_impl, out_shardings, arg_shardings + + @staticmethod + def shardy_sharding_rule( + out_dtype, + scaling_mode, + q_layout, + flatten_axis, + scale_dtype, + mesh, + value_types, + result_types, + ): + del out_dtype, scale_dtype, mesh, result_types, flatten_axis + + prefix = "GroupedQuantize" + input_spec = tuple(f"{prefix}_x_{i}" for i in range(len(value_types[0].shape))) + flat_spec = (f"{prefix}_flat",) + group_spec = (BATCHING + f"{prefix}_group",) + scalar_spec = (BATCHING + f"{prefix}_scalar",) + + rowwise_out_spec = flat_spec if q_layout.has_rowwise else scalar_spec + colwise_out_spec = flat_spec if q_layout.has_colwise else scalar_spec + + if ScalingMode(scaling_mode).is_block_scaling: + rowwise_scale_spec = flat_spec if q_layout.has_rowwise else scalar_spec + colwise_scale_spec = flat_spec if q_layout.has_colwise else scalar_spec + elif ScalingMode(scaling_mode).is_tensor_scaling(): + rowwise_scale_spec = group_spec if q_layout.has_rowwise else scalar_spec + colwise_scale_spec = group_spec if q_layout.has_colwise else scalar_spec + else: + rowwise_scale_spec = scalar_spec + colwise_scale_spec = scalar_spec + + return SdyShardingRule( + operand_mappings=(input_spec, group_spec, group_spec), + result_mappings=( + rowwise_out_spec, + colwise_out_spec, + rowwise_scale_spec, + colwise_scale_spec, + group_spec, + ), + ) + register_primitive(GroupedQuantizePrimitive) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index f8c30ffccb..b60810f5eb 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -26,34 +26,6 @@ ) -def _all_gather_kernel(kernel, mesh_axis, axis_idx): - assert mesh_axis is not None - assert 0 < axis_idx < len(kernel.shape) - - # TODO(Ming Hunag): Add a condition branch for with/without shmap. - kernel_shape = kernel.shape - kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :]) - global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx) - global_kernel = global_kernel.reshape(*kernel_whole_shape) - return global_kernel - - -def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx): - assert mesh_axis is not None - assert 0 < axis_idx < len(scattered_kernel_shape) - - # TODO(Ming Hunag): Add a condition branch for with/without shmap. - kernel = kernel.reshape( - *scattered_kernel_shape[:axis_idx], - -1, - scattered_kernel_shape[axis_idx], - *scattered_kernel_shape[axis_idx + 1 :], - ) - kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx) - kernel = kernel.reshape(scattered_kernel_shape) - return kernel - - def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -325,7 +297,6 @@ def grouped_dense( preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, - kernel_fsdp_info: Tuple[str, int] = (None, -1), ): """ Perform grouped dense (linear) layer transformation with optional quantization. @@ -341,14 +312,15 @@ def grouped_dense( preferred_element_type: Preferred data type for the output tensor group_offset: 1D array containing offsets for each group (not yet implemented) quantizer_set: Set of quantizers for FP8 quantization of the input and output - kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix - represented in the format (str, int). The first element is the - FSDP mesh axis, and the second element is the dimension along - which the weight is sharded. Returns: A jnp.ndarray containing the result of the grouped linear operation """ + x_contracting_dims, kernel_contracting_dims = contracting_dims + x_contracting_dims = tex.sanitize_dims(x.ndim, x_contracting_dims) + kernel_contracting_dims = tex.sanitize_dims(kernel.ndim, kernel_contracting_dims) + contracting_dims = (x_contracting_dims, kernel_contracting_dims) + output = _grouped_dense( x, kernel, @@ -359,12 +331,11 @@ def grouped_dense( preferred_element_type, group_offset, quantizer_set, - kernel_fsdp_info, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7, 9)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) def _grouped_dense( x, kernel, @@ -375,7 +346,6 @@ def _grouped_dense( preferred_element_type, group_offset, quantizer_set, - kernel_fsdp_info, ): output, _ = _grouped_dense_fwd_rule( x, @@ -387,7 +357,6 @@ def _grouped_dense( preferred_element_type, group_offset, quantizer_set, - kernel_fsdp_info, ) return output @@ -402,16 +371,11 @@ def _grouped_dense_fwd_rule( preferred_element_type, group_offset, quantizer_set, - kernel_fsdp_info, ): use_bias = bias is not None - kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info - kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None - assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." - del kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx, kernel_fsdp_info, kernel_fsdp_enabled - x_contracting_dims, k_contracting_dims = contracting_dims + flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis @@ -476,12 +440,8 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( - contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad + contracting_dims, precision, preferred_element_type, group_offset, ctx, grad ): - kernel_fsdp_mesh_axis, _ = kernel_fsdp_info - kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None - assert not kernel_fsdp_enabled, "FSDP sharding for grouped_dense is not supported yet." - fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims ( diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 17c9a242f0..14783ecbe2 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1471,7 +1471,7 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa x, kernel, group_sizes=group_sizes, - contracting_dims=((1,), (1,)), + contracting_dims=((-1,), (1,)), quantizer_set=quantizer_set, ) return out diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 9b13412c14..86ddf72ff7 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -12,6 +12,7 @@ from contextlib import contextmanager from dataclasses import dataclass from typing import Callable, Optional +import math import warnings import jax @@ -133,8 +134,8 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): """ A wrapper function to jax.lax.with_sharding_constraint 1. Does nothing if mesh is empty. - 2. If all mesh axes are manual axes, replaces pspec with all Nones. - 3. Otherwise, strips only the manual axes. + 2. Keeps only auto axes in pspec. + 3. Returns x unchanged if no auto axes remain. """ if pspec is None: return x @@ -143,22 +144,21 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): if mesh.empty: return x - # We want to exclude the axes that already used by shard_map and shard_map - # only sets those in the abstract_mesh, not the physical one - manual_axis_names = get_abstract_mesh().manual_axes + # with_sharding_constraint can only refer to auto axes. Explicit axes are + # already fixed by the active mesh, and manual axes are managed by shard_map. + abstract_mesh = get_abstract_mesh() + auto_axis_names = set(abstract_mesh.auto_axes) # Multiple mesh axes can be mapped to a single shape axis, so we need to unpack and process tuples here too - def filter_manual_axes(name_or_tuple): + def filter_non_auto_axes(name_or_tuple): if isinstance(name_or_tuple, tuple): - out = tuple(n for n in name_or_tuple if n not in manual_axis_names) + out = tuple(n for n in name_or_tuple if n in auto_axis_names) if len(out) == 0: return None return out - if name_or_tuple in manual_axis_names: - return None - return name_or_tuple + return name_or_tuple if name_or_tuple in auto_axis_names else None - cleaned_axis_names = tuple(filter_manual_axes(name_or_tuple) for name_or_tuple in pspec) + cleaned_axis_names = tuple(filter_non_auto_axes(name_or_tuple) for name_or_tuple in pspec) if cleaned_axis_names == (None,) * len(cleaned_axis_names): return x @@ -234,6 +234,155 @@ def get_padded_spec(spec, ndim): return spec + (None,) * (ndim - len(spec)) +def spec_axes(spec): + """Return unique non-None mesh axes used by a PartitionSpec-like tuple.""" + axes = [] + for axis_spec in spec: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + return axes + + +def axis_spec_contains(axis_spec, axis): + """Return whether one dimension's axis spec contains a mesh axis.""" + if axis is None or axis_spec is None: + return False + if isinstance(axis_spec, tuple): + return axis in axis_spec + return axis_spec == axis + + +def spec_contains_axis(spec, axis): + """Return whether a PartitionSpec-like tuple contains a mesh axis.""" + return any(axis_spec_contains(axis_spec, axis) for axis_spec in spec) + + +def filter_axis_spec(axis_spec, allowed_axes): + """Keep only allowed axes in one dimension's axis spec.""" + if axis_spec is None: + return None + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axes = tuple(axis for axis in axis_tuple if axis in allowed_axes) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else axes + + +def filter_spec_axes(spec, allowed_axes): + """Keep only allowed axes in a PartitionSpec-like tuple.""" + return tuple(filter_axis_spec(axis_spec, allowed_axes) for axis_spec in spec) + + +def supported_grouped_partition_axes(mesh): + """Return mesh axes supported by grouped quantize/GEMM custom partitioning.""" + gsr = global_mesh_resource(validate=False) + return { + axis + for axis in (gsr.ep_resource, gsr.dp_resource, gsr.fsdp_resource) + if axis is not None and axis in mesh.axis_names + } + + +def strip_axis_from_axis_spec(axis_spec, axis): + """Remove one mesh axis from one dimension's axis spec.""" + if axis is None or axis_spec is None: + return axis_spec + if isinstance(axis_spec, tuple): + stripped = tuple(a for a in axis_spec if a != axis) + if len(stripped) == 0: + return None + return stripped[0] if len(stripped) == 1 else stripped + return None if axis_spec == axis else axis_spec + + +def strip_axis_from_spec(spec, axis): + """Remove one mesh axis from a PartitionSpec-like tuple.""" + return tuple(strip_axis_from_axis_spec(axis_spec, axis) for axis_spec in spec) + + +def merge_axis_specs(*axis_specs): + """Merge dimension axis specs while preserving first-seen axis order.""" + axes = [] + for axis_spec in axis_specs: + if axis_spec is None: + continue + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + for axis in axis_tuple: + if axis is not None and axis not in axes: + axes.append(axis) + if len(axes) == 0: + return None + return axes[0] if len(axes) == 1 else tuple(axes) + + +def common_spec_axis(spec_a, spec_b, allowed_axes=None): + """Return the first mesh axis that appears in both specs.""" + axes = [] + for spec in (spec_a, spec_b): + for axis in spec_axes(spec): + if axis not in axes: + axes.append(axis) + for axis in axes: + if allowed_axes is not None and axis not in allowed_axes: + continue + if spec_contains_axis(spec_a, axis) and spec_contains_axis(spec_b, axis): + return axis + return None + + +def axis_spec_size(axis_spec, mesh): + """Return the device count represented by one dimension's axis spec.""" + axis_tuple = axis_spec if isinstance(axis_spec, tuple) else (axis_spec,) + axis_size = 1 + for axis in axis_tuple: + if axis is not None: + axis_size *= mesh.shape[axis] + return axis_size + + +def spec_size(spec, mesh): + """Return the total device count represented by a PartitionSpec-like tuple.""" + axis_size = 1 + for axis_spec in spec: + axis_size *= axis_spec_size(axis_spec, mesh) + return axis_size + + +def local_shape_from_spec(global_shape, spec, mesh): + """Derive a local shape from a global shape and PartitionSpec-like tuple.""" + local_shape = [] + for dim, axis_spec in zip(global_shape, spec): + local_shape.append(dim // axis_spec_size(axis_spec, mesh)) + return tuple(local_shape) + + +def local_2d_sizes_from_spec(shape, spec, axis_boundary, left_size, right_size, mesh): + """Derive local collapsed 2D dimensions from a global shape and sharding spec.""" + if len(shape) == len(spec) and len(shape) > 1: + local_shape = local_shape_from_spec(shape, spec, mesh) + return ( + math.prod(local_shape[:axis_boundary]), + math.prod(local_shape[axis_boundary:]), + ) + + size = spec_size(spec, mesh) + if size == 1: + return left_size, right_size + if left_size % size == 0: + return left_size // size, right_size + if right_size % size == 0: + return left_size, right_size // size + raise ValueError( + "Cannot derive local 2D sizes from sharding spec. " + f"shape={shape}, spec={spec}, axis_boundary={axis_boundary}, " + f"left_size={left_size}, right_size={right_size}, spec_size={size}" + ) + + def lax_paral_op( x: jnp.array, ops: Callable, mesh_resource: str, mesh: jax.sharding.Mesh, **kwargs ): @@ -330,6 +479,7 @@ class MeshResource: tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence sharding), default is None fsdp_resource: Axis name for full-sharded data parallelism, default is None + ep_resource: Axis name for expert parallelism (expert sharding), default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None """ @@ -338,6 +488,7 @@ class MeshResource: tp_resource: str = None tpsp_resource: str = None fsdp_resource: str = None + ep_resource: str = None pp_resource: str = None cp_resource: str = None @@ -364,7 +515,7 @@ def global_shard_guard(resource: MeshResource): _GLOBAL_MESH_RESOURCE = old_resources -def global_mesh_resource() -> MeshResource: +def global_mesh_resource(validate: bool = True) -> MeshResource: """Get the current global mesh resource configuration. Returns: @@ -375,7 +526,8 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) - _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) + if validate: + _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) return _GLOBAL_MESH_RESOURCE