diff --git a/src/maxtext/checkpoint_conversion/to_maxtext.py b/src/maxtext/checkpoint_conversion/to_maxtext.py index 4245201b4e..47e6b5868a 100644 --- a/src/maxtext/checkpoint_conversion/to_maxtext.py +++ b/src/maxtext/checkpoint_conversion/to_maxtext.py @@ -67,7 +67,8 @@ from maxtext.common.common_types import MODEL_MODE_TRAIN from maxtext.checkpoint_conversion.utils.hf_model_configs import HF_MODEL_CONFIGS from maxtext.checkpoint_conversion.utils.param_mapping import HOOK_FNS, PARAM_MAPPING -from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, apply_hook_fns, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys +from maxtext.checkpoint_conversion.utils.tensor_handling import apply_hook_fns, _get_hf_loading_function +from maxtext.checkpoint_conversion.utils.utils import MemoryMonitorTqdm, load_hf_dict_from_transformers, load_hf_dict_from_safetensors, param_key_parts_from_path, print_peak_memory, print_ram_usage, save_weights_to_checkpoint, validate_and_filter_param_map_keys from maxtext.inference.inference_utils import str2bool from maxtext.layers import quantizations from maxtext.models import models @@ -340,151 +341,7 @@ def get_maxtext_model_info(config): return maxtext_abstract_dict, abstract_params_treedef -def _build_multi_axis_stacked_tensor( - hf_source_keys: List[List[str]], - tensor_getter_fn: Callable[[str], np.ndarray], - hook_fns: Any, - target_shape: tuple, - config, -) -> np.ndarray: - """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers). - - This function handles the complex case for scanned MoE layers, producing a tensor - with the shape (num_experts, num_layers, ...). - - Args: - hf_source_keys: A nested (2D) list of Hugging Face parameter names. - Outer list iterates experts, inner list iterates layers. 
- tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array). - hook_fns: The hook function(s) to apply to each individual weight. - target_shape: The final shape of the target MaxText tensor. - config: The MaxText pyconfig object. - - Returns: - The final, assembled NumPy array for the MaxText parameter. - """ - all_expert_tensors = [] - # The hook function needs the shape of an individual slice, not the full stacked tensor. - # For multi-axis stacking (experts, layers, ...), the slice shape is target_shape[2:] - mt_slice_shape = target_shape[2:] - - # Outer loop iterates through experts - for layer_keys_for_expert in hf_source_keys: - layer_tensors_for_expert = [] - # Inner loop iterates through layers for the current expert - for hf_key_single in layer_keys_for_expert: - if isinstance(hf_key_single, (list, tuple)): - hf_tensor_numpy = tuple(tensor_getter_fn(k) for k in hf_key_single) - else: - hf_tensor_numpy = tensor_getter_fn(hf_key_single) - processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) - layer_tensors_for_expert.append(processed_hf_tensor) - all_expert_tensors.append(np.stack(layer_tensors_for_expert, axis=0)) - return np.stack(all_expert_tensors, axis=0) - - -def _build_single_axis_stacked_tensor( - hf_source_keys: List[str], - tensor_getter_fn: Callable[[str], np.ndarray], - hook_fns: Any, - target_shape: tuple, - config, -) -> np.ndarray: - """Builds a MaxText tensor by stacking HF weights along a single axis. - - This function handles both standard scanned layers (e.g., attention) and - unscanned MoE layers (which are stacked along the expert axis). - - Args: - hf_source_keys: A 1D list of Hugging Face parameter names. - tensor_getter_fn: A callable that takes a HF key and returns the tensor (as numpy array). - hook_fns: The hook function(s) to apply to each individual weight. - target_shape: The final shape of the target MaxText tensor. - config: The MaxText pyconfig object. 
- - Returns: - The final, assembled NumPy array for the MaxText parameter. - """ - tensors_to_stack = [] - - if config.scan_layers: - # If it's a standard scanned layer, we use the configured param_scan_axis. - axis_to_stack = config.param_scan_axis - else: - # Otherwise, if an unscanned MoE layer, and we stack along the expert axis (0). - axis_to_stack = 0 - - # The hook function needs the shape of an individual slice, not the full stacked tensor. - # We calculate it by removing the stacking dimension from the final target shape. - mt_slice_shape_list = list(target_shape) - del mt_slice_shape_list[axis_to_stack] - mt_slice_shape = tuple(mt_slice_shape_list) - - for hf_key_single in hf_source_keys: - if isinstance(hf_key_single, (list, tuple)): - hf_tensor_numpy = tuple(tensor_getter_fn(k) for k in hf_key_single) - else: - hf_tensor_numpy = tensor_getter_fn(hf_key_single) - processed_hf_tensor = apply_hook_fns(hf_tensor_numpy, mt_slice_shape, hook_fns) - tensors_to_stack.append(processed_hf_tensor) - - # Stack all processed tensors along the determined axis. - return np.stack(tensors_to_stack, axis=axis_to_stack) - - -def _get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config): - """Determine the loading function for HF keys. - - This function natively supports `composite_hf_key` mapping (where multiple HF keys - combine into a single MaxText parameter, like Qwen3.5's qkv and z -> in_proj_qkvz). - If the input is a tuple of strings, they are fetched as a tuple of arrays and passed - together into the model hook. 
- HF keys can take four forms: - Case 1: Unscanned (single string) - Case 2: Scanned (list of strings) - Case 3: Unscanned with expert stacking (list of strings) - Case 4: Scanned with expert stacking (nested list of strings) - """ - load_fn = None - if not isinstance(hf_source_keys_or_key, list): - # Case 1: Single hf key (str) - def _loader(getter, key, shape, hook): - if isinstance(key, (list, tuple)): - tensors = tuple(getter(k) for k in key) - return apply_hook_fns(tensors, shape, hook) - return apply_hook_fns(getter(key), shape, hook) - - load_fn = partial( - _loader, - tensor_getter, - hf_source_keys_or_key, - mt_target_shape_or_shapes, - hook_fn, - ) - # Stacked mapping - elif not isinstance(hf_source_keys_or_key[0], list): - # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list) - load_fn = partial( - _build_single_axis_stacked_tensor, - hf_source_keys_or_key, - tensor_getter, - hook_fn, - mt_target_shape_or_shapes, - config, - ) - else: - # isinstance(hf_source_keys_or_key[0], list) - # Case 4: Multi-Axis Stacked hf keys (nested list) - load_fn = partial( - _build_multi_axis_stacked_tensor, - hf_source_keys_or_key, - tensor_getter, - hook_fn, - mt_target_shape_or_shapes, - config, - ) - return load_fn def _get_maxtext_indices_and_shapes(mt_param_key_or_keys, maxtext_abstract_dict): diff --git a/src/maxtext/checkpoint_conversion/utils/load_dynamic.py b/src/maxtext/checkpoint_conversion/utils/load_dynamic.py new file mode 100644 index 0000000000..9e57da68de --- /dev/null +++ b/src/maxtext/checkpoint_conversion/utils/load_dynamic.py @@ -0,0 +1,315 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
def build_gcs_cache_worker(fpath, gcs_cache_dir, hf_access_token, *, max_retries=5, retry_delay_s=15):
  """Copies one Hugging Face Hub file into the shared GCS cache directory.

  The destination blob is named after the basename of `fpath` and stored under
  the prefix encoded in `gcs_cache_dir`. If that blob already exists the call
  is treated as a cache hit and returns without contacting the Hub.

  Args:
    fpath: Path of the source file as understood by `HfFileSystem`.
    gcs_cache_dir: Destination of the form `gs://<bucket>[/<prefix>]`.
    hf_access_token: Token handed to `HfFileSystem`.
    max_retries: Total number of upload attempts before giving up.
    retry_delay_s: Seconds to sleep after a failed attempt.

  Raises:
    Exception: Whatever the last attempt raised, once `max_retries` attempts
      have failed.
  """
  file_name = os.path.basename(fpath)
  bucket_name, _, blob_prefix = gcs_cache_dir.replace("gs://", "").partition("/")
  blob_name = os.path.join(blob_prefix, file_name)

  storage_client = storage.Client()
  blob = storage_client.bucket(bucket_name).blob(blob_name)

  if blob.exists():
    max_logging.log(f"[Worker] Cache hit for {file_name}.")
    return

  # Random start-up jitter (presumably to stagger concurrent workers hitting the
  # Hub -- confirm intent). It is only paid when a download is actually needed.
  time.sleep(random.uniform(0.0, 5.0))
  fs = HfFileSystem(token=hf_access_token)
  blob.chunk_size = 1024 * 1024 * 32  # 32MB chunks

  t0 = time.time()
  for attempt in range(1, max_retries + 1):
    try:
      with fs.open(fpath, "rb") as remote_f:
        blob.upload_from_file(remote_f, client=storage_client)
    except Exception as e:  # pylint: disable=broad-except
      if attempt == max_retries:
        max_logging.log(f"Failed to fetch {fpath} to GCS after {max_retries} attempts.")
        raise
      max_logging.log(
          f"Error fetching {fpath} to GCS: {e}. Retrying in {retry_delay_s} seconds... "
          f"(Attempt {attempt}/{max_retries})"
      )
      time.sleep(retry_delay_s)
    else:
      max_logging.log(f"[Worker] Cached {file_name} in {time.time() - t0:.1f}s")
      break
def get_hf_config_and_mappings(config):
  """Builds the MaxText->HF parameter map and the matching hook map.

  The lookup key is `config.model_name` with every "-Instruct" occurrence
  removed. It indexes `HF_MODEL_CONFIGS`, `param_mapping.PARAM_MAPPING` and
  `param_mapping.HOOK_FNS`.

  Args:
    config: MaxText config; `model_name` and `scan_layers` are read.

  Returns:
    A `(param_map_mt_to_hf, hook_fn_map_mt)` tuple. The hook map is built with
    `saving_to_hf=False`.
  """
  # str.replace is a no-op when the substring is absent, so no membership test is needed.
  lookup_key = config.model_name.replace("-Instruct", "")
  hf_config_as_dict = HF_MODEL_CONFIGS[lookup_key].to_dict()

  mapping_builder = param_mapping.PARAM_MAPPING[lookup_key]
  mt_to_hf = mapping_builder(hf_config_as_dict, config, scan_layers=config.scan_layers)

  hook_builder = param_mapping.HOOK_FNS[lookup_key]
  hooks = hook_builder(hf_config_as_dict, config, scan_layers=config.scan_layers, saving_to_hf=False)

  return mt_to_hf, hooks
def load_sharded_hf_state(path, devices=None):
  """Loads a safetensors checkpoint as arrays sharded across `devices`.

  Within an Orbax v1 context configured for the SAFETENSORS layout, the
  checkpoint metadata is read, a maximal sharding is constructed for each entry,
  and the checkpoint is loaded against the resulting sharded abstract state.

  Args:
    path: Location of the safetensors checkpoint passed to Orbax.
    devices: Devices to shard over; `jax.devices()` is used when None.

  Returns:
    The state returned by `ocp_v1.load`.
  """
  started_at = time.time()
  safetensors_context = ocp_v1.Context(
      checkpoint_layout=ocp_v1.options.CheckpointLayout.SAFETENSORS,
      safetensors_options=ocp_v1.options.SafetensorsOptions(ignore_load_sharding=False),
  )
  with safetensors_context:
    abstract_state = ocp_v1.metadata(path).metadata

    # Spread every tensor over all devices so no single host has to hold the
    # whole checkpoint while it is being read.
    mesh_devices = jax.devices() if devices is None else devices
    leaf_shardings = sharding_utils.construct_maximal_shardings(abstract_state, devices=mesh_devices)

    sharded_abstract_state = jax.tree.map(
        lambda leaf, leaf_sharding: jax.ShapeDtypeStruct(
            shape=leaf.shape, dtype=leaf.dtype, sharding=leaf_sharding
        ),
        abstract_state,
        leaf_shardings,
    )

    max_logging.log("Reading raw Safetensors into memory (Distributed Sharded GCS Download)...")
    loaded_state = ocp_v1.load(path, sharded_abstract_state)

  max_logging.log(f"load_sharded_hf_state took {time.time() - started_at:.2f}s")
  return loaded_state
def _resolve_target_key(mt_key, flat_target):
  """Returns the key of `flat_target` matching `mt_key`, or None if there is none."""
  mt_name = mt_key.replace("params-", "").replace("-", ".")
  for candidate in (mt_name, "params." + mt_name, mt_key.replace("-", ".")):
    if candidate in flat_target:
      return candidate
  return None


def _iter_hf_keys(hf_source):
  """Yields every HF key contained in a (possibly nested) list/tuple mapping entry."""
  if isinstance(hf_source, (list, tuple)):
    for item in hf_source:
      yield from _iter_hf_keys(item)
  else:
    yield hf_source


def _place_on_target_sharding(array, target_sharding, name):
  """Returns `array` laid out according to `target_sharding`, logging the route taken."""
  if isinstance(array, jax.Array):
    if target_sharding.device_set == array.sharding.device_set:
      max_logging.log(f"Loaded {name} via TPU-to-TPU direct resharding.")
      return jax.device_put(array, device=target_sharding)
    max_logging.log(f"Loaded {name} via JAX JIT TPU-to-TPU resharding.")
    return jax.jit(lambda x: x, out_shardings=target_sharding)(array)

  if jax.process_count() > 1 and not target_sharding.is_fully_addressable:
    max_logging.log(f"Loaded {name} via Host CPU callback fallback (NumPy array).")
    return jax.make_array_from_callback(array.shape, target_sharding, lambda index: array[index])

  max_logging.log(f"Loaded {name} via Host CPU device_put.")
  return jax.device_put(array, device=target_sharding)


def transform_hf_state_to_mt_state(
    hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, config
):
  """Transforms HF state into MaxText state by applying param mappings and hooks.

  Args:
    hf_state: Mutable mapping from HF key to tensor. Entries are removed from
      it as they are consumed so their memory can be released.
    target_tree: Nested dict whose leaves expose `.shape` and `.sharding`.
    param_map_mt_to_hf: Mapping from MaxText key to HF key(s).
    hook_fn_map_mt: Mapping from MaxText key to hook function(s); looked up
      with `.get`, so entries are optional.
    config: MaxText config forwarded to the loading function.

  Returns:
    `{"params": restored}`. Target leaves with no mapping are left as they were
    in `target_tree`.

  Raises:
    NotImplementedError: If a mapping key is not a string (composite MaxText
      keys are not handled by this loader).
  """
  t0 = time.time()

  flat_target = traverse_util.flatten_dict(target_tree, sep=".")
  flat_restored = flat_target.copy()

  # First pass: resolve every mapping and count how often each HF tensor is
  # needed, so a tensor shared by several parameters is only dropped after its
  # last use instead of raising KeyError on the second access.
  resolved = []
  keys_missed = []
  pending_uses = {}
  for mt_key, hf_source in param_map_mt_to_hf.items():
    if not isinstance(mt_key, str):
      raise NotImplementedError(
          f"safetensors_dynamic loading does not support composite MaxText keys: {mt_key!r}"
      )
    check_name = _resolve_target_key(mt_key, flat_target)
    if check_name is None:
      keys_missed.append(mt_key.replace("params-", "").replace("-", "."))
      continue
    resolved.append((mt_key, check_name, hf_source))
    for hf_key in _iter_hf_keys(hf_source):
      pending_uses[hf_key] = pending_uses.get(hf_key, 0) + 1

  def tensor_getter(key):
    remaining = pending_uses.get(key, 1) - 1
    pending_uses[key] = remaining
    if remaining > 0:
      return hf_state[key]
    return hf_state.pop(key)

  mapped_count = 0
  max_logging.log("Starting fast in-memory Distributed Transformations...")

  for mt_key, check_name, hf_source in resolved:
    target_shape = flat_target[check_name].shape
    hook_fn = hook_fn_map_mt.get(mt_key)

    load_fn = get_hf_loading_function(
        hf_source,
        tensor_getter,
        hook_fn,
        target_shape,
        config,
    )

    # Execute transformation and assign to flat_restored
    t_layer = time.time()
    unsharded_array = load_fn()

    # Ensure the result is sharded exactly as the target leaf declares.
    target_sharding = flat_target[check_name].sharding
    flat_restored[check_name] = _place_on_target_sharding(unsharded_array, target_sharding, check_name)
    del unsharded_array

    max_logging.log(f"Transformed {check_name} from {hf_source} in {time.time() - t_layer:.4f}s")
    mapped_count += 1

    if mapped_count % 10 == 0:
      gc.collect()

  if mapped_count == 0:
    max_logging.log(f"All transformations missed! Sample missed mt_names: {keys_missed[:5]}")
    max_logging.log(f"Sample flat_target keys: {list(flat_target.keys())[:5]}")
  elif keys_missed:
    # Partial misses used to go unreported.
    max_logging.log(
        f"Warning: {len(keys_missed)} mapped parameters had no matching target entry and were skipped. "
        f"Sample: {keys_missed[:5]}"
    )

  max_logging.log(f"Successfully mapped {mapped_count} parameters.")
  restored_params = traverse_util.unflatten_dict(flat_restored, sep=".")

  if "params" in restored_params:
    restored_params = restored_params["params"]

  max_logging.log(f"transform_hf_state_to_mt_state took {time.time() - t0:.2f}s")

  return {"params": restored_params}
def _get_global_mesh(target_tree):
  """Returns the mesh of the first leaf in `target_tree` whose sharding has one.

  Leaves without a `sharding` attribute, with a `None` sharding, or whose
  sharding exposes no `mesh` attribute are skipped rather than raising.

  Args:
    target_tree: Nested dict flattened with `traverse_util.flatten_dict`.

  Returns:
    The first mesh found, or None when no leaf provides one.
  """
  flat_target = traverse_util.flatten_dict(target_tree, sep=".")
  for leaf in flat_target.values():
    mesh = getattr(getattr(leaf, "sharding", None), "mesh", None)
    if mesh is not None:
      return mesh
  return None
def _cache_hf_repo_in_gcs(repo_id, config):
  """Mirrors the `*.safetensors` files of a Hub repo into a shared GCS cache and returns its path."""
  # Imported explicitly: `import jax` alone does not guarantee that the
  # `jax.experimental.multihost_utils` submodule has been loaded.
  from jax.experimental import multihost_utils  # pylint: disable=import-outside-toplevel

  base_dir = getattr(config, "base_output_directory", None)
  if not isinstance(base_dir, str) or not base_dir.startswith("gs://"):
    raise ValueError("base_output_directory with gs:// prefix is required for huggingface downloads.")

  # Strip a trailing slash so the cache path never contains an empty segment.
  gcs_cache_dir = f"{base_dir.rstrip('/')}/hf_cache/{repo_id.replace('/', '_')}"
  host_id = jax.process_index()

  # Only Host 0 talks to the Hub and writes to the shared GCS cache.
  if host_id == 0:
    max_logging.log(f"Dynamic HF Hub Fast DL: Host 0 is downloading to shared GCS Cache: {gcs_cache_dir}")
    t_gcs_start = time.time()

    fs = HfFileSystem(token=config.hf_access_token)
    files = fs.glob(f"{repo_id}/*.safetensors")
    if not files:
      max_logging.log(f"Warning: no *.safetensors files found under Hugging Face repo '{repo_id}'.")

    # List existing blobs to avoid spawning processes for already cached files
    storage_client = storage.Client()
    bucket_name, _, blob_prefix = gcs_cache_dir.replace("gs://", "").partition("/")
    existing_blobs = {blob.name for blob in storage_client.list_blobs(bucket_name, prefix=blob_prefix)}

    files_to_download = [
        fpath for fpath in files if os.path.join(blob_prefix, os.path.basename(fpath)) not in existing_blobs
    ]

    if files_to_download:
      with concurrent.futures.ProcessPoolExecutor(
          max_workers=32, mp_context=multiprocessing.get_context("spawn")
      ) as executor:
        futures = [
            executor.submit(build_gcs_cache_worker, fpath, gcs_cache_dir, config.hf_access_token)
            for fpath in files_to_download
        ]

        while futures:
          done, futures = concurrent.futures.wait(futures, timeout=10)

          # Raise any exceptions if a worker failed
          for f in done:
            f.result()

    t_gcs_end = time.time()
    max_logging.log(
        f"GCS caching complete in {t_gcs_end - t_gcs_start:.2f}s. "
        f"Downloaded {len(files_to_download)} missing files."
    )

  # Global barrier: all hosts wait for Host 0 to finish populating the shared GCS bucket.
  # NOTE(review): if Host 0 raises above, the other hosts stay blocked here until the
  # runtime tears the job down -- confirm that is acceptable.
  max_logging.log(f"Host {host_id} waiting for GCS cache at {gcs_cache_dir} to be populated by Host 0...")
  multihost_utils.sync_global_devices("dynamic_hf_download_complete")
  max_logging.log(f"Host {host_id} detected GCS cache is ready!")
  return gcs_cache_dir


def load_safetensors_dynamic_state(path, abstract_unboxed_pre_state, config):
  """Main entry point to dynamically build and load safetensors into MaxText format.

  Splits execution into:
    1. Deriving mappings.
    2. Loading sharded arrays.
    3. Transforming them into the MaxText layout.

  Args:
    path: `gs://` path, local directory, `hf://<repo>` or bare Hub repo id. When
      empty, the repo is resolved from `HF_IDS` using `config.model_name`.
    abstract_unboxed_pre_state: Either an `nnx.State` or an object with a
      `.params` attribute describing the target parameters.
    config: MaxText config; required.

  Returns:
    `(None, restored_params)`.

  Raises:
    ValueError: If `config` is None, the model has no known Hub repo, or a Hub
      download is requested without a `gs://` `base_output_directory`.
  """
  if config is None:
    raise ValueError("config must be provided for safetensors_dynamic loading.")

  model_name = config.model_name.replace("-Instruct", "")

  if not path:
    if model_name not in HF_IDS:
      raise ValueError(f"Unsupported model name for automatic HF repo resolution: {model_name}.")
    path = HF_IDS[model_name]

  if path.startswith("hf://"):
    path = path[len("hf://"):]

  # Anything that is neither a GCS path nor a local directory is treated as a Hub repo id.
  if not path.startswith("gs://") and not os.path.isdir(path):
    path = _cache_hf_repo_in_gcs(path, config)

  t_total = time.time()
  param_map_mt_to_hf, hook_fn_map_mt = get_hf_config_and_mappings(config)
  max_logging.log(f"[1/3] Mappings derived in {time.time() - t_total:.2f}s")

  target_tree = (
      abstract_unboxed_pre_state.to_pure_dict()
      if isinstance(abstract_unboxed_pre_state, nnx.State)
      else abstract_unboxed_pre_state.params
  )

  t1 = time.time()
  hf_state = load_sharded_hf_state(path, devices=None)
  max_logging.log(f"[2/3] Distributed Sharded GCS load completed in {time.time() - t1:.2f}s")

  t2 = time.time()
  restored_params = transform_hf_state_to_mt_state(
      hf_state, target_tree, param_map_mt_to_hf, hook_fn_map_mt, config
  )
  max_logging.log(f"[3/3] CPU Transformations completed in {time.time() - t2:.2f}s")
  max_logging.log(f"Total safetensors_dynamic duration: {time.time() - t_total:.2f}s")

  return None, restored_params
def concat(args, axis=0):
  """Concatenates arrays along `axis`, using jax.numpy if any input is a `jax.Array`, else NumPy.

  Args:
    args: Iterable of arrays. It is materialised once, so one-shot iterators
      and generators are accepted.
    axis: Axis to concatenate along.

  Returns:
    A `jax.Array` when any input is a `jax.Array`, otherwise a NumPy array.
  """
  arrays = tuple(args)
  if any(isinstance(x, jax.Array) for x in arrays):
    return jnp.concatenate(arrays, axis=axis)
  return np.concatenate(arrays, axis=axis)


def stack(args, axis=0):
  """Stacks arrays along a new `axis`, using jax.numpy if any input is a `jax.Array`, else NumPy.

  Args:
    args: Iterable of arrays. It is materialised once, so one-shot iterators
      and generators are accepted.
    axis: Position of the new axis.

  Returns:
    A `jax.Array` when any input is a `jax.Array`, otherwise a NumPy array.
  """
  arrays = tuple(args)
  if any(isinstance(x, jax.Array) for x in arrays):
    return jnp.stack(arrays, axis=axis)
  return np.stack(arrays, axis=axis)


def split(arr, num_or_size_splits, axis=0):
  """Splits `arr` along `axis` with `jnp.split` for a `jax.Array`, else `np.split`.

  Args:
    arr: Array to split.
    num_or_size_splits: Forwarded unchanged as the second argument of the
      underlying `split` call.
    axis: Axis to split along.

  Returns:
    The list of sub-arrays returned by the underlying `split` call.
  """
  if isinstance(arr, jax.Array):
    return jnp.split(arr, num_or_size_splits, axis=axis)
  return np.split(arr, num_or_size_splits, axis=axis)
z_r], axis=1) + interleaved = concat([q_r, k_r, v_r, z_r], axis=1) return interleaved.reshape(-1, qkv_m.shape[-1]).T def concat_ba_and_transpose(input_tensor, target_shape=None): @@ -1232,7 +1247,7 @@ def concat_ba_and_transpose(input_tensor, target_shape=None): b_m, a_m = input_tensor b_r = b_m.reshape(H_k, V_per_K, -1) a_r = a_m.reshape(H_k, V_per_K, -1) - interleaved = np.concatenate([b_r, a_r], axis=1) + interleaved = concat([b_r, a_r], axis=1) return interleaved.reshape(-1, b_m.shape[-1]).T # Initialize Hooks @@ -1924,7 +1939,7 @@ def interleave(input_tensor, target_shape=None): wi_0_1 = input_tensor wi_0 = wi_0_1[..., ::2] wi_1 = wi_0_1[..., 1::2] - return np.stack([wi_0, wi_1], axis=-1) + return stack([wi_0, wi_1], axis=-1) n_layers = config["num_hidden_layers"] # hf config layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval @@ -2489,13 +2504,13 @@ def adjust_rope(input_tensor, target_shape): # Convert from MaxText's interleaved layout to HF's concatenated layout evens = arr[..., ::2] odds = arr[..., 1::2] - return jax.numpy.concatenate((evens, odds), axis=arr.ndim - 1) + return concat((evens, odds), axis=arr.ndim - 1) else: # Convert from HF's concatenated layout to MaxText's interleaved layout half_dim = arr.shape[-1] // 2 first_half = arr[..., :half_dim] second_half = arr[..., half_dim:] - return jax.numpy.stack([first_half, second_half], axis=-1).reshape(arr.shape) + return stack([first_half, second_half], axis=-1).reshape(arr.shape) def reshape_kernel(input_tensor, target_shape): if saving_to_hf: diff --git a/src/maxtext/checkpoint_conversion/utils/tensor_handling.py b/src/maxtext/checkpoint_conversion/utils/tensor_handling.py new file mode 100644 index 0000000000..9d0db7f8d7 --- /dev/null +++ b/src/maxtext/checkpoint_conversion/utils/tensor_handling.py @@ -0,0 +1,249 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
def _is_jax(x):
  """Returns True if `x` is a `jax.Array` or a (nested) list/tuple containing one."""
  if not isinstance(x, (list, tuple)):
    return isinstance(x, jax.Array)
  return any(map(_is_jax, x))


@partial(jax.jit, static_argnums=(1, 2))
def _apply_hook_fns_compiled(weight, target_shape, hook_fns_tuple):
  """Jitted: feeds `weight` through each hook in order as `hook(value, target_shape)`.

  `target_shape` and `hook_fns_tuple` are static arguments.
  """
  result = weight
  for fn in hook_fns_tuple:
    result = fn(result, target_shape)
  return result


@partial(jax.jit, static_argnums=(1, 2, 3, 4))
def _build_single_axis_compiled(tensors_tuple, target_shape, hook_fns_tuple, axis_to_stack, mt_slice_shape):
  """Jitted: applies the hooks to every tensor with `mt_slice_shape`, then stacks on `axis_to_stack`.

  `target_shape` is accepted as a static argument but not read here.
  """
  processed = [
      _apply_hook_fns_compiled(raw_tensor, mt_slice_shape, hook_fns_tuple) for raw_tensor in tensors_tuple
  ]
  return jnp.stack(processed, axis=axis_to_stack)


@partial(jax.jit, static_argnums=(1, 2, 3))
def _build_multi_axis_compiled(raw_tensors_nested_tuple, target_shape, hook_fns_tuple, mt_slice_shape):
  """Jitted: applies the hooks to every tensor, stacking inner groups then outer groups on axis 0.

  `target_shape` is accepted as a static argument but not read here.
  """
  per_group = [
      jnp.stack(
          [_apply_hook_fns_compiled(raw_tensor, mt_slice_shape, hook_fns_tuple) for raw_tensor in group],
          axis=0,
      )
      for group in raw_tensors_nested_tuple
  ]
  return jnp.stack(per_group, axis=0)
def apply_hook_fns(weight, target_shape, hook_fns):
  """Apply hook functions, essential for to_maxtext and to_huggingface.

  Args:
    weight: Array (or tuple/list of arrays) to transform.
    target_shape: Passed unchanged as the second argument of every hook.
    hook_fns: None (identity), a single hook, or a list of hooks applied in order.

  Returns:
    The transformed weight. When `weight` contains a `jax.Array` the hooks are
    first run through the jitted helper. If that raises, the failure is logged
    and the hooks are run eagerly on the same input instead.
  """
  # If hook is unspecified, use identity
  if hook_fns is None:
    return weight
  if not isinstance(hook_fns, list):
    hook_fns = [hook_fns]

  if _is_jax(weight):
    try:
      return _apply_hook_fns_compiled(weight, target_shape, tuple(hook_fns))
    except Exception as err:  # pylint: disable=broad-except
      # Best-effort: e.g. a hook that cannot be traced, or an unhashable static
      # argument. Record why before falling back instead of hiding the failure.
      import logging  # pylint: disable=import-outside-toplevel

      logger = logging.getLogger(__name__)
      logger.warning(
          "Compiled hook application failed (%s); falling back to eager execution.", type(err).__name__
      )
      logger.debug("Compiled hook failure details", exc_info=True)

  # Eager execution, be careful of hook order
  for hook_fn in hook_fns:
    weight = hook_fn(weight, target_shape)
  return weight
def build_multi_axis_stacked_tensor(
    hf_source_keys: List[List[str]],
    tensor_getter_fn: Callable[[str], Any],
    hook_fns: Any,
    target_shape: tuple,
    config,
) -> Any:
  """Builds a MaxText tensor by stacking HF weights along two axes (experts and layers).

  This function handles the case for scanned MoE layers, producing a tensor
  with the shape (num_experts, num_layers, ...).

  Args:
    hf_source_keys: A nested (2D) list of Hugging Face parameter names.
      Outer list iterates experts, inner list iterates layers. An entry may
      itself be a list/tuple of names, fetched as a tuple of tensors.
    tensor_getter_fn: A callable that takes a HF key and returns the tensor
      (NumPy array or `jax.Array`).
    hook_fns: The hook function(s) to apply to each individual weight.
    target_shape: The final shape of the target MaxText tensor.
    config: The MaxText pyconfig object (not read by this function).

  Returns:
    The assembled array: a `jax.Array` when any fetched tensor is a
    `jax.Array`, otherwise a NumPy array.
  """
  # jit static arguments must be hashable, so normalise list-like shapes.
  target_shape = tuple(target_shape)
  # Hooks need the shape of one slice: drop the two leading stacking dimensions.
  mt_slice_shape = target_shape[2:]

  def _fetch(hf_key):
    if isinstance(hf_key, (list, tuple)):
      return tuple(tensor_getter_fn(k) for k in hf_key)
    return tensor_getter_fn(hf_key)

  # Load all raw tensors first on python side
  raw_tensors_nested = tuple(
      tuple(_fetch(hf_key) for hf_key in layer_keys_for_expert) for layer_keys_for_expert in hf_source_keys
  )

  if _is_jax(raw_tensors_nested):
    if hook_fns is None:
      hook_fns_tuple = ()
    elif isinstance(hook_fns, list):
      hook_fns_tuple = tuple(hook_fns)
    else:
      hook_fns_tuple = (hook_fns,)
    try:
      return _build_multi_axis_compiled(raw_tensors_nested, target_shape, hook_fns_tuple, mt_slice_shape)
    except Exception as err:  # pylint: disable=broad-except
      # Same best-effort policy as apply_hook_fns: fall back to eager execution.
      import logging  # pylint: disable=import-outside-toplevel

      logging.getLogger(__name__).warning(
          "Compiled multi-axis stacking failed (%s); falling back to eager execution.", type(err).__name__
      )

    def _process(tensor):
      for hook_fn in hook_fns_tuple:
        tensor = hook_fn(tensor, mt_slice_shape)
      return tensor

    stack_fn = jnp.stack
  else:

    def _process(tensor):
      return apply_hook_fns(tensor, mt_slice_shape, hook_fns)

    stack_fn = np.stack

  all_expert_tensors = [
      stack_fn([_process(tensor) for tensor in expert_tensors], axis=0) for expert_tensors in raw_tensors_nested
  ]
  return stack_fn(all_expert_tensors, axis=0)
def build_single_axis_stacked_tensor(
    hf_source_keys: List[str],
    tensor_getter_fn: Callable[[str], Any],
    hook_fns: Any,
    target_shape: tuple,
    config,
) -> Any:
  """Builds a MaxText tensor by stacking HF weights along a single axis.

  This function handles both standard scanned layers (e.g., attention) and
  unscanned MoE layers (which are stacked along the expert axis).

  Args:
    hf_source_keys: A 1D list of Hugging Face parameter names. An entry may
      itself be a list/tuple of names, fetched as a tuple of tensors.
    tensor_getter_fn: A callable that takes a HF key and returns the tensor
      (NumPy array or `jax.Array`).
    hook_fns: The hook function(s) to apply to each individual weight.
    target_shape: The final shape of the target MaxText tensor.
    config: The MaxText pyconfig object; `scan_layers` and, when it is true,
      `param_scan_axis` are read.

  Returns:
    The assembled array: a `jax.Array` when any fetched tensor is a
    `jax.Array`, otherwise a NumPy array.
  """
  # Scanned layers stack on the configured scan axis; otherwise this is an
  # unscanned MoE layer stacked on the expert axis (0).
  axis_to_stack = config.param_scan_axis if config.scan_layers else 0

  # jit static arguments must be hashable, so normalise list-like shapes.
  target_shape = tuple(target_shape)

  # The hook function needs the shape of an individual slice, not the full stacked tensor.
  # We calculate it by removing the stacking dimension from the final target shape.
  mt_slice_shape_list = list(target_shape)
  del mt_slice_shape_list[axis_to_stack]
  mt_slice_shape = tuple(mt_slice_shape_list)

  def _fetch(hf_key):
    if isinstance(hf_key, (list, tuple)):
      return tuple(tensor_getter_fn(k) for k in hf_key)
    return tensor_getter_fn(hf_key)

  # Load all raw tensors first on python side
  raw_tensors = tuple(_fetch(hf_key) for hf_key in hf_source_keys)

  if _is_jax(raw_tensors):
    if hook_fns is None:
      hook_fns_tuple = ()
    elif isinstance(hook_fns, list):
      hook_fns_tuple = tuple(hook_fns)
    else:
      hook_fns_tuple = (hook_fns,)
    try:
      return _build_single_axis_compiled(raw_tensors, target_shape, hook_fns_tuple, axis_to_stack, mt_slice_shape)
    except Exception as err:  # pylint: disable=broad-except
      # Same best-effort policy as apply_hook_fns: fall back to eager execution.
      import logging  # pylint: disable=import-outside-toplevel

      logging.getLogger(__name__).warning(
          "Compiled single-axis stacking failed (%s); falling back to eager execution.", type(err).__name__
      )

    def _process(tensor):
      for hook_fn in hook_fns_tuple:
        tensor = hook_fn(tensor, mt_slice_shape)
      return tensor

    stack_fn = jnp.stack
  else:

    def _process(tensor):
      return apply_hook_fns(tensor, mt_slice_shape, hook_fns)

    stack_fn = np.stack

  # Stack all processed tensors along the determined axis.
  return stack_fn([_process(tensor) for tensor in raw_tensors], axis=axis_to_stack)
def get_hf_loading_function(hf_source_keys_or_key, tensor_getter, hook_fn, mt_target_shape_or_shapes, config):
  """Determine the loading function for HF keys.

  This function natively supports `composite_hf_key` mapping.
  If the input is a tuple of strings, they are fetched as a tuple of arrays and passed
  together into the model hook.

  HF keys can take four forms:
    Case 1: Unscanned (single string)
    Case 2: Scanned (list of strings)
    Case 3: Unscanned with expert stacking (list of strings)
    Case 4: Scanned with expert stacking (nested list of strings)

  Args:
    hf_source_keys_or_key: HF key(s) in one of the forms above.
    tensor_getter: Callable mapping a HF key to its tensor.
    hook_fn: Hook function(s) forwarded to the chosen loader; may be None.
    mt_target_shape_or_shapes: Target shape(s) forwarded to the chosen loader.
    config: MaxText config forwarded to the stacked loaders.

  Returns:
    A zero-argument callable that performs the load when invoked.
  """
  if not isinstance(hf_source_keys_or_key, list):
    # Case 1: Single hf key (str), or a tuple of keys fetched together
    def _loader(getter, key, shape, hook):
      if isinstance(key, (list, tuple)):
        tensors = tuple(getter(k) for k in key)
        return apply_hook_fns(tensors, shape, hook)
      return apply_hook_fns(getter(key), shape, hook)

    return partial(
        _loader,
        tensor_getter,
        hf_source_keys_or_key,
        mt_target_shape_or_shapes,
        hook_fn,
    )

  # Stacked mapping
  if not isinstance(hf_source_keys_or_key[0], list):
    # Case 2 or 3: Single-Axis Stacked hf keys (un-nested list)
    builder = build_single_axis_stacked_tensor
  else:
    # Case 4: Multi-Axis Stacked hf keys (nested list)
    builder = build_multi_axis_stacked_tensor

  return partial(
      builder,
      hf_source_keys_or_key,
      tensor_getter,
      hook_fn,
      mt_target_shape_or_shapes,
      config,
  )


# Backward-compatible alias: to_maxtext.py imports this function under its
# former private name, which would otherwise raise ImportError.
_get_hf_loading_function = get_hf_loading_function
functions, essential for to_maxtext and to_huggingface""" - # If hook is unsepecified, use identity - if hook_fns is None: - return weight - if not isinstance(hook_fns, list): - hook_fns = [hook_fns] - # Apply a list of hooks, be careful of order - for hook_fn in hook_fns: - weight = hook_fn(weight, target_shape) - return weight - - def convert_jax_weight_to_numpy(weight: "jax.Array", dtype_str: None | str = None) -> np.ndarray: """Converts a JAX array to a NumPy array with the specified dtype, used in to_huggingface. diff --git a/src/maxtext/common/checkpointing.py b/src/maxtext/common/checkpointing.py index 73f475bb39..24ed91228b 100644 --- a/src/maxtext/common/checkpointing.py +++ b/src/maxtext/common/checkpointing.py @@ -25,6 +25,7 @@ from flax.training import train_state import jax import jax.numpy as jnp +from maxtext.checkpoint_conversion.utils import load_dynamic from maxtext.utils.globals import DEFAULT_OCDBT_TARGET_DATA_FILE_SIZE from maxtext.input_pipeline.multihost_dataloading import MultiHostDataLoadIterator from maxtext.input_pipeline.multihost_dataloading import RemoteIteratorWrapper @@ -773,6 +774,7 @@ def load_state_if_possible( checkpoint_conversion_fn=None, source_checkpoint_layout="orbax", expansion_factor_real_data: int = -1, + config: Any | None = None, ): """Loads TrainState as possible from the inputs. @@ -912,7 +914,14 @@ def map_to_pspec(data): _assert_no_shaped_dtype_struct(restored) return (restored, None) - if load_parameters_from_path != "": + if source_checkpoint_layout == "safetensors_dynamic": + path = load_parameters_from_path or load_full_state_from_path + max_logging.log(f"Dynamic On-the-Fly Formatting: Loading SafeTensors from {path}") + + return load_dynamic.load_safetensors_dynamic_state( + path, abstract_unboxed_pre_state, config + ) + elif load_parameters_from_path != "": if isinstance(abstract_unboxed_pre_state, nnx.State): _, params, _ = nnx.split(abstract_unboxed_pre_state.model, nnx.Param, ...) 
else: @@ -927,6 +936,9 @@ def map_to_pspec(data): checkpoint_storage_concurrent_gb, use_ocdbt=use_ocdbt, use_zarr3=use_zarr3, + enable_orbax_v1=enable_orbax_v1, + source_checkpoint_layout=source_checkpoint_layout, + checkpoint_conversion_fn=checkpoint_conversion_fn, ) _assert_no_shaped_dtype_struct(restored_params) return None, restored_params @@ -972,7 +984,14 @@ def setup_checkpoint_logger(config) -> Any | None: # pytype: disable=attribute- def load_params_from_path( - load_parameters_from_path, abstract_unboxed_params, checkpoint_storage_concurrent_gb, use_ocdbt=True, use_zarr3=True + load_parameters_from_path, + abstract_unboxed_params, + checkpoint_storage_concurrent_gb, + use_ocdbt=True, + use_zarr3=True, + enable_orbax_v1=False, + source_checkpoint_layout="orbax", + checkpoint_conversion_fn=None, ): """Load decode params from checkpoint at specified path.""" assert load_parameters_from_path, "load_parameters_from_path is not defined." diff --git a/src/maxtext/configs/types.py b/src/maxtext/configs/types.py index cb1987eb77..051af1a046 100644 --- a/src/maxtext/configs/types.py +++ b/src/maxtext/configs/types.py @@ -345,7 +345,7 @@ class Checkpointing(BaseModel): save_quantized_params_path: PathStr = Field("", description="Path to save params quantized on the fly.") enable_orbax_v1: bool = Field(False, description="Bool flag for enabling Orbax v1.") checkpoint_conversion_fn: None | str = Field(None, description="Function for processing loaded checkpoint dict.") - source_checkpoint_layout: Literal["orbax", "safetensors"] = Field( + source_checkpoint_layout: Literal["orbax", "safetensors", "safetensors_dynamic"] = Field( "orbax", description="The layout of the source checkpoint to load." 
) save_checkpoint_on_completion: bool = Field( diff --git a/src/maxtext/utils/maxtext_utils.py b/src/maxtext/utils/maxtext_utils.py index 16b022c3a4..28a8878391 100644 --- a/src/maxtext/utils/maxtext_utils.py +++ b/src/maxtext/utils/maxtext_utils.py @@ -1496,6 +1496,7 @@ def setup_initial_state( checkpoint_conversion_fn=config.checkpoint_conversion_fn, source_checkpoint_layout=config.source_checkpoint_layout, expansion_factor_real_data=config.expansion_factor_real_data, + config=config, ) if restored: