Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 2 additions & 145 deletions src/maxtext/checkpoint_conversion/to_maxtext.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@
from maxtext.common.common_types import MODEL_MODE_TRAIN
from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS
from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns, _get_hf_loading_function
from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys
from maxtext.inference.inference_utils import str2bool
from maxtext.layers import quantizations
from maxtext.models import models
Expand Down Expand Up @@ -340,151 +341,7 @@ def get_maxtext_model_info(config):
return maxtext_abstract_dict, abstract_params_treedef


def _build_multi_axis_stacked_tensor(
hf_source_keys: List[List[str]],
tensor_getter_fn: Callable[[str], np.ndarray],
hook_fns: Any,
target_shape: tuple,
config,
) -> np.ndarray:
"""Builds a MaxText tensor by stacking HF weights along two axes (experts and layers).

This function handles the complex case for scanned MoE layers, producing a tensor
with the shape (num_experts, num_layers, ...).

Args:
hf_source_keys: A nested (2D) list of Hugging Face parameter names.
Outer list iterates experts, inner list iterates layers.
tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array).
hook_fns: The hook function(s) to apply to each individual weight.
target_shape: The final shape of the target MaxText tensor.
config: The MaxText pyconfig object.

Returns:
The final, assembled NumPy array for the MaxText parameter.
"""
all_expert_tensors = []
# The hook function needs the shape of an individual slice, not the full stacked tensor.
# For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:]
mt_slice_shape = target_shape[2:]

# Outer loop iterates through experts
for layer_keys_for_expert in hf_source_keys:
layer_tensors_for_expert = []
# Inner loop iterates through layers for the current expert
for hf_key_single in layer_keys_for_expert:
if isinstance(hf_key_single, (list, tuple)):
hf_tensor_numpy = tuple(tensor_getter_fn(k) for k in hf_key_single)
else:
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
layer_tensors_for_expert.append(processed_hf_tensor)
all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0))
return np.stack(all_expert_tensors, axis=0)


def _build_single_axis_stacked_tensor(
hf_source_keys: List[str],
tensor_getter_fn: Callable[[str], np.ndarray],
hook_fns: Any,
target_shape: tuple,
config,
) -> np.ndarray:
"""Builds a MaxText tensor by stacking HF weights along a single axis.

This function handles both standard scanned layers (e.g., attention) and
unscanned MoE layers (which are stacked along the expert axis).

Args:
hf_source_keys: A 1D list of Hugging Face parameter names.
tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array).
hook_fns: The hook function(s) to apply to each individual weight.
target_shape: The final shape of the target MaxText tensor.
config: The MaxText pyconfig object.

Returns:
The final, assembled NumPy array for the MaxText parameter.
"""
tensors_to_stack = []

if config.scan_layers:
# If it's a standard scanned layer, we use the configured param_scan_axis.
axis_to_stack = config.param_scan_axis
else:
# Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0).
axis_to_stack = 0

# The hook function needs the shape of an individual slice, not the full stacked tensor.
# We calculate it by removing the stacking dimension from the final target shape.
mt_slice_shape_list = list(target_shape)
del mt_slice_shape_list[axis_to_stack]
mt_slice_shape = tuple(mt_slice_shape_list)

for hf_key_single in hf_source_keys:
if isinstance(hf_key_single, (list, tuple)):
hf_tensor_numpy = tuple(tensor_getter_fn(k) for k in hf_key_single)
else:
hf_tensor_numpy = tensor_getter_fn(hf_key_single)
processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns)
tensors_to_stack.append(processed_hf_tensor)

# Stack all processed tensors along the determined axis.
return np.stack(tensors_to_stack, axis=axis_to_stack)


def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
"""Determine the loading function for HF keys.

This function natively supports `composite_hf_key` mapping (where multiple HF keys
combine into a single MaxText parameter, like Qwen3.5's qkv and z -> in_proj_qkvz).
If the input is a tuple of strings, they are fetched as a tuple of arrays and passed
together into the model hook.

HF keys can take four forms:
Case 1: Unscanned (single string)
Case 2: Scanned (list of strings)
Case 3: Unscanned with expert stacking (list of strings)
Case 4: Scanned with expert stacking (nested list of strings)
"""
load_fn = None
if not isinstance(hf_source_keys_or_key, list):
# Case 1: Single hf key (str)
def _loader(getter, key, shape, hook):
if isinstance(key, (list, tuple)):
tensors = tuple(getter(k) for k in key)
return apply_hook_fns(tensors, shape, hook)
return apply_hook_fns(getter(key), shape, hook)

load_fn = partial(
_loader,
tensor_getter,
hf_source_keys_or_key,
mt_target_shape_or_shapes,
hook_fn,
)
# Stacked mapping
elif not isinstance(hf_source_keys_or_key[0], list):
# Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
load_fn = partial(
_build_single_axis_stacked_tensor,
hf_source_keys_or_key,
tensor_getter,
hook_fn,
mt_target_shape_or_shapes,
config,
)
else:
# isinstance(hf_source_keys_or_key[0], list)
# Case 4: Multi-Axis Stacked hf keys (nested list)
load_fn = partial(
_build_multi_axis_stacked_tensor,
hf_source_keys_or_key,
tensor_getter,
hook_fn,
mt_target_shape_or_shapes,
config,
)
return load_fn


def _get_maxtext_indices_and_shapes(mt_param_key_or_keys, maxtext_abstract_dict):
Expand Down
Loading
Loading