From f4791dbc6453ebc18624a5db434d82945813aac5 Mon Sep 17 00:00:00 2001 From: Shuwen-Fang Date: Mon, 15 Jun 2026 21:15:55 +0000 Subject: [PATCH] refactor --- src/maxtext/layers/moe.py | 210 ++++++++++++++++++++------------------ 1 file changed, 112 insertions(+), 98 deletions(-) diff --git a/src/maxtext/layers/moe.py b/src/maxtext/layers/moe.py index 020956098c..86229ad2c1 100644 --- a/src/maxtext/layers/moe.py +++ b/src/maxtext/layers/moe.py @@ -1640,6 +1640,85 @@ def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather): layer_w1 = adc.checkpoint_name(layer_w1, "moe_mlpwi_1") return self.apply_ffn_activation(layer_w0, layer_w1) + def get_gmm_for_local_experts(x, routing, route_metadata): + """Return a partial GMM function with preconfigured routing params.""" + num_ep = self.get_expert_parallelism_size() + num_experts_per_shard = self.config.num_experts // num_ep + use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0] + if use_truncated_buffer: + local_group_sizes = routing.local_group_sizes + return functools.partial( + gmm, + group_sizes=local_group_sizes, + expert_assignments=routing.selected_experts, + group_offset=0, + ) + if self.config.use_ragged_sort and self.config.use_ring_of_experts: + experts_start = route_metadata.expert_shard_id * num_experts_per_shard + else: + experts_start = 0 + return functools.partial( + gmm, + group_sizes=routing.group_sizes, + expert_assignments=routing.selected_experts, + group_offset=experts_start, + ) + + def unsort_output_with_ra2a(intermediate_output, routing, route_metadata, output_shape, is_batch_sharded_by_expert): + """Unsort tokens and return them to original shards using ragged all-to-all.""" + if is_batch_sharded_by_expert: + # locally unpermute back to the original order + if self.config.use_ragged_sort: + # Mirror the ragged-prefix gather used in `local_permute`. The + # un-permute can use the same valid-prefix length because the + # routed token count is identical for forward and backward. + valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32) + local_output = a2a_ragged_unsort( + intermediate_output, + jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable + valid_end, + ) + else: + local_output = _sort_activations( + intermediate_output, + jnp.argsort(route_metadata.local_sorted_indices), + self.config.use_custom_sort_vjp, + ) + + input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( + jnp.transpose(route_metadata.all_shards_group_sizes), + route_metadata.expert_shard_id, + self.get_expert_parallelism_size(), + ) + return jax.lax.ragged_all_to_all( + local_output, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self._expert_parallelism_name, + ) + + # If batch is replicated across EP shards then each shard should send + # 0..local_shard_size data to the other shards and receive the + # local_shard data from all of the other shards using ragged_all_to_all. + input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( + route_metadata.reshaped_group_sizes, + route_metadata.expert_shard_id, + self.get_expert_parallelism_size(), + is_batch_sharded=False, + ) + return jax.lax.ragged_all_to_all( + intermediate_output, + output_shape, + input_offsets, + send_sizes, + output_offsets, + recv_sizes, + axis_name=self._expert_parallelism_name, + ) + @functools.partial( jax.shard_map, mesh=self.mesh, @@ -1663,36 +1742,16 @@ def gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather): ), check_vma=self.config.check_vma, ) - def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs): + def sparse_matmul_route_and_compute( + x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, sharded_input_ids, rngs + ): batch_size, sequence_length, _ = x.shape x, routing, route_metadata = route(x, logits, pre_bias_logits, rngs, input_ids=sharded_input_ids) if self.config.mlp_bias: w0_bias, w1_bias, wo_bias = self.transform_bias(routing.selected_experts, w0_bias, w1_bias, wo_bias) - num_ep = self.get_expert_parallelism_size() - num_experts_per_shard = self.config.num_experts // num_ep - - use_truncated_buffer = self.config.use_ring_of_experts and x.shape[0] < routing.sorted_selected_experts.shape[0] - if use_truncated_buffer: - local_group_sizes = routing.local_group_sizes - gmm_fn = functools.partial( - gmm, - group_sizes=local_group_sizes, - expert_assignments=routing.selected_experts, - group_offset=0, - ) - else: - if self.config.use_ragged_sort and self.config.use_ring_of_experts: - experts_start = route_metadata.expert_shard_id * num_experts_per_shard - else: - experts_start = 0 - gmm_fn = functools.partial( - gmm, - group_sizes=routing.group_sizes, - expert_assignments=routing.selected_experts, - group_offset=experts_start, - ) + gmm_fn = get_gmm_for_local_experts(x, routing, route_metadata) intermediate_layer = gmm_up(x, w0, w1, w0_bias, w1_bias, gmm_fn, weight_gather) wo_gather_axes, wo_tile_size = get_wo_gmm_params() @@ -1727,83 +1786,38 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, s output, (-1, sequence_length, self.moe_expert_input_dim // self.get_tensor_parallelism_size()) ) output = jax.lax.psum_scatter(output, self._expert_parallelism_name, scatter_dimension=0, tiled=True) + return output, routing.lb_loss, routing.bias_updates + + if self.get_expert_parallelism_size() > 1: + original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok + if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim: + raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!") + output_shape = jax.lax.empty( + ( + original_inputs_first_dim, + self.moe_expert_input_dim // self.get_tensor_parallelism_size(), + ), + dtype=intermediate_output.dtype, + ) - else: - if self.get_expert_parallelism_size() > 1: - original_inputs_first_dim = batch_size * sequence_length * self.config.num_experts_per_tok - if routing.sorted_selected_experts.shape[0] != original_inputs_first_dim: - raise ValueError("original_inputs_first_dim does not match the original tensor" " shape!") - output_shape = jax.lax.empty( - ( - original_inputs_first_dim, - self.moe_expert_input_dim // self.get_tensor_parallelism_size(), - ), - dtype=intermediate_output.dtype, - ) - - if is_batch_sharded_by_expert: - # locally unpermute back to the original order - if self.config.use_ragged_sort: - # Mirror the ragged-prefix gather used in `local_permute`. The - # un-permute can use the same valid-prefix length because the - # routed token count is identical for forward and backward. - valid_end = jnp.sum(routing.group_sizes).astype(jnp.int32) - local_output = a2a_ragged_unsort( - intermediate_output, - jnp.argsort(route_metadata.local_sorted_indices), # pylint: disable=undefined-variable - valid_end, - ) - else: - local_output = _sort_activations( - intermediate_output, - jnp.argsort(route_metadata.local_sorted_indices), - self.config.use_custom_sort_vjp, - ) - - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - jnp.transpose(route_metadata.all_shards_group_sizes), - route_metadata.expert_shard_id, - self.get_expert_parallelism_size(), - ) - intermediate_output = jax.lax.ragged_all_to_all( - local_output, - output_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name=self._expert_parallelism_name, - ) - else: - # If batch is replicated across EP shards then each shard should send - # 0..local_shard_size data to the other shards and receive the - # local_shard data from all of the other shards using ragged_all_to_all. - input_offsets, send_sizes, output_offsets, recv_sizes = RoutedMoE.get_all_to_all_params( - route_metadata.reshaped_group_sizes, - route_metadata.expert_shard_id, - self.get_expert_parallelism_size(), - is_batch_sharded=False, - ) - intermediate_output = jax.lax.ragged_all_to_all( - intermediate_output, - output_shape, - input_offsets, - send_sizes, - output_offsets, - recv_sizes, - axis_name=self._expert_parallelism_name, - ) - - output = self.unpermute( + intermediate_output = unsort_output_with_ra2a( intermediate_output, - routing.sorted_selected_experts, - routing.weights, - batch_size=batch_size, - sequence_length=sequence_length, - use_custom_sort_vjp=self.config.use_custom_sort_vjp, - group_sizes=routing.group_sizes, + routing, + route_metadata, + output_shape, + is_batch_sharded_by_expert, ) + output = self.unpermute( + intermediate_output, + routing.sorted_selected_experts, + routing.weights, + batch_size=batch_size, + sequence_length=sequence_length, + use_custom_sort_vjp=self.config.use_custom_sort_vjp, + group_sizes=routing.group_sizes, + ) + return output, routing.lb_loss, routing.bias_updates if self.config.moe_fsdp_use_two_stage_all_gather: @@ -1851,7 +1865,7 @@ def wrapper(x, logits, pre_bias_logits, w0, w1, wo, w0_bias, w1_bias, wo_bias, s if wo_bias is not None: wo_bias = self._maybe_shard_with_pspec(wo_bias, wo_bias_pspec) - return wrapper( + return sparse_matmul_route_and_compute( inputs, gate_logits, pre_bias_logits,