diff --git a/pytest.ini b/pytest.ini index 8800326967..eaa026a3c2 100644 --- a/pytest.ini +++ b/pytest.ini @@ -10,6 +10,7 @@ addopts = --ignore=tests/integration/smoke/train_int8_smoke_test.py --ignore=tests/integration/smoke/train_smoke_test.py --ignore=tests/integration/smoke/train_using_ragged_dot_smoke_test.py + --ignore=tests/unit/convert_deepseek_unscanned_low_memory_test.py --ignore=tests/unit/dequantize_mxfp4_test.py --ignore=tests/unit/dequantize_pack_quantized_int4_test.py --ignore=tests/unit/gemma3_layers_test.py diff --git a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py index ff1330660e..75d381380a 100644 --- a/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py +++ b/src/maxtext/checkpoint_conversion/standalone_scripts/convert_deepseek_family_unscanned_ckpt.py @@ -19,6 +19,12 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_unscanned_ckpt \ --base_model_path \ --maxtext_model_path --model_size deepseek2-16b + +Pass --low_memory true to keep peak host RAM at O(one tensor) instead of +O(2x model size): tensors are read from the safetensors shards one at a time +and converted weights are staged in disk-backed memmaps (under TMPDIR) until +the checkpoint is written. Recommended for trillion-parameter checkpoints such +as kimi-k2.6-text, which otherwise needs ~2.5 TB of host RAM. 
""" # pylint: disable=line-too-long @@ -30,6 +36,8 @@ import os import gc import logging +import shutil +import tempfile import absl import numpy as np @@ -46,34 +54,38 @@ absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log -def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info) -> dict: - """Convert Huggingface Checkpoint to Jax.""" - base_num_decoder_layers = model_params["num_layers"] - first_num_dense_layers = model_params["first_num_dense_layers"] - base_num_query_heads = model_params["base_num_query_heads"] - base_emb_dim = model_params["base_emb_dim"] - num_experts = model_params["num_experts"] - q_lora_rank = model_params["q_lora_rank"] - kv_lora_rank = model_params["kv_lora_rank"] - qk_nope_head_dim = model_params["qk_nope_head_dim"] - qk_rope_head_dim = model_params["qk_rope_head_dim"] - v_head_dim = model_params["v_head_dim"] +class _LazyShardLoader: + """Dict-like loader that reads one tensor at a time from safetensors shards. - ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors")) - chkpt_vars = {} - is_compressed = bool(model_params.get("compressed_int4", False)) - hf_key_prefix = model_params.get("hf_key_prefix", "") - - def _normalize(raw_key): - if not hf_key_prefix: - return raw_key - if raw_key.startswith(hf_key_prefix): - return raw_key[len(hf_key_prefix) :] - return None - - for i, ckpt_path in enumerate(ckpt_paths): - max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...") - with safe_open(ckpt_path, framework="pt", device="cpu") as f: + Building the index only reads the shard headers (`safe_open(...).keys()`), so + no tensor data is resident up front. Each `__getitem__` loads (and, for + compressed int4 experts, dequantizes) exactly one tensor from its source shard + and keeps no reference to it, so peak memory during conversion stays at + O(one tensor) instead of O(whole model). See + https://github.com/AI-Hypercomputer/maxtext/issues/4071. 
+ """ + + def __init__(self, ckpt_paths, model_params): + base_num_decoder_layers = model_params["num_layers"] + first_num_dense_layers = model_params["first_num_dense_layers"] + num_experts = model_params["num_experts"] + is_compressed = bool(model_params.get("compressed_int4", False)) + hf_key_prefix = model_params.get("hf_key_prefix", "") + + def _normalize(raw_key): + if not hf_key_prefix: + return raw_key + if raw_key.startswith(hf_key_prefix): + return raw_key[len(hf_key_prefix) :] + return None + + self._handles = {} + # mapped MaxText key -> (shard path, raw key within the shard, whether the + # raw key is the base name of a packed-int4 triple to dequantize on access). + self._index = {} + for i, ckpt_path in enumerate(ckpt_paths): + max_logging.log(f"Indexing checkpoint {i+1} of {len(ckpt_paths)} ...") + f = self._open(ckpt_path) for raw_key in f.keys(): key = _normalize(raw_key) if key is None: @@ -91,13 +103,7 @@ def _normalize(raw_key): ).get(hf_key) if not mapped_key: continue - raw_base = raw_key[: -len(".weight_packed")] - shape_t = f.get_tensor(raw_base + ".weight_shape") - chkpt_vars[mapped_key] = ds_ckpt.dequantize_pack_quantized_int4( - f.get_tensor(raw_key), - f.get_tensor(raw_base + ".weight_scale"), - shape_t.tolist(), - ) + self._index[mapped_key] = (ckpt_path, raw_key[: -len(".weight_packed")], True) continue if is_compressed and key.endswith((".weight_scale", ".weight_shape")): continue @@ -111,11 +117,90 @@ def _normalize(raw_key): layer, num_experts, first_num_dense_layers, base_num_decoder_layers ).get(key) if mapped_key: - chkpt_vars[mapped_key] = f.get_tensor(raw_key) + self._index[mapped_key] = (ckpt_path, raw_key, False) else: # This catches keys that are allowed but missing from the mapping dictionary max_logging.log(f"Debug: Allowed key '{key}' (layer {layer}) has no mapping in hf_to_maxtext_mapping.") + def _open(self, ckpt_path): + if ckpt_path not in self._handles: + self._handles[ckpt_path] = safe_open(ckpt_path, 
framework="pt", device="cpu") + return self._handles[ckpt_path] + + def __getitem__(self, mapped_key) -> torch.Tensor: + ckpt_path, raw_key, is_packed_int4 = self._index[mapped_key] + f = self._open(ckpt_path) + if is_packed_int4: + return ds_ckpt.dequantize_pack_quantized_int4( + f.get_tensor(raw_key + ".weight_packed"), + f.get_tensor(raw_key + ".weight_scale"), + f.get_tensor(raw_key + ".weight_shape").tolist(), + ) + return f.get_tensor(raw_key) + + +class _LeafSpiller: + """Spills converted weight leaves to .npy memmaps in a scratch directory. + + `spill` writes an array to disk and returns it reopened as a read-only memmap, + so its pages are clean and the OS can evict them under memory pressure. + `empty` pre-allocates a zero-initialized leaf on disk for incremental filling + (e.g. per-expert stacking); `seal` flushes such a leaf and reopens it + read-only. + """ + + def __init__(self, spill_dir): + self._spill_dir = spill_dir + self._num_leaves = 0 + + def empty(self, shape, dtype) -> np.memmap: + """Allocates a zero-initialized writable memmap leaf on disk.""" + leaf_path = os.path.join(self._spill_dir, f"leaf_{self._num_leaves:05d}.npy") + self._num_leaves += 1 + return np.lib.format.open_memmap(leaf_path, mode="w+", dtype=dtype, shape=shape) + + def seal(self, leaf) -> np.memmap: + """Flushes a writable memmap leaf and reopens it read-only.""" + leaf.flush() + return np.lib.format.open_memmap(leaf.filename, mode="r") + + def spill(self, leaf) -> np.memmap: + """Writes an in-memory array to disk and returns a read-only memmap view.""" + out = self.empty(leaf.shape, leaf.dtype) + out[...] = leaf + return self.seal(out) + + +def _keep(leaf): + """Identity leaf placement used when low-memory spilling is disabled.""" + return leaf + + +def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info, spill_dir=None) -> dict: + """Convert Huggingface Checkpoint to Jax. 
+ + If spill_dir is set, converted leaves are staged in disk-backed memmaps under + that directory instead of host RAM (low-memory mode). + """ + base_num_decoder_layers = model_params["num_layers"] + first_num_dense_layers = model_params["first_num_dense_layers"] + base_num_query_heads = model_params["base_num_query_heads"] + base_emb_dim = model_params["base_emb_dim"] + num_experts = model_params["num_experts"] + q_lora_rank = model_params["q_lora_rank"] + kv_lora_rank = model_params["kv_lora_rank"] + qk_nope_head_dim = model_params["qk_nope_head_dim"] + qk_rope_head_dim = model_params["qk_rope_head_dim"] + v_head_dim = model_params["v_head_dim"] + + ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors")) + chkpt_vars = _LazyShardLoader(ckpt_paths, model_params) + if spill_dir is None: + to_leaf, alloc_leaf, seal_leaf = _keep, np.zeros, _keep + else: + spiller = _LeafSpiller(spill_dir) + to_leaf, alloc_leaf, seal_leaf = spiller.spill, spiller.empty, spiller.seal + logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # initialize the data structure for storing jax_weights @@ -129,19 +214,19 @@ def _normalize(raw_key): # decoder norm scale ########################################### max_logging.log("Processing decoder norm scale") - jax_weights["decoder"]["decoder_norm"]["scale"] = chkpt_vars["decoder_norm.scale"].to(torch.float16).numpy() + jax_weights["decoder"]["decoder_norm"]["scale"] = to_leaf(chkpt_vars["decoder_norm.scale"].to(torch.float16).numpy()) logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # logits dense ################################################# max_logging.log("Processing logits dense") - jax_weights["decoder"]["logits_dense"]["kernel"] = ( + jax_weights["decoder"]["logits_dense"]["kernel"] = to_leaf( chkpt_vars["logits_dense.kernel"].to(torch.float16).numpy().transpose() ) logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) # token embedding 
############################################## max_logging.log("Processing token embeddings") - jax_weights["token_embedder"]["embedding"] = chkpt_vars["token_embedder.embedding"].to(torch.float16).numpy() + jax_weights["token_embedder"]["embedding"] = to_leaf(chkpt_vars["token_embedder.embedding"].to(torch.float16).numpy()) logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) layers = { @@ -241,18 +326,18 @@ def _normalize(raw_key): else: self_attention.update({"query": {"kernel": None}}) - self_attention["kv_norm"]["scale"] = kv_norm - self_attention["wkv_a"]["kernel"] = wkv_a - self_attention["wkv_b"]["kernel"] = wkv_b - self_attention["out"]["kernel"] = out - pre_self_attention_layer_norm["scale"] = pre_self_attention - post_self_attention_layer_norm["scale"] = post_self_attention + self_attention["kv_norm"]["scale"] = to_leaf(kv_norm) + self_attention["wkv_a"]["kernel"] = to_leaf(wkv_a) + self_attention["wkv_b"]["kernel"] = to_leaf(wkv_b) + self_attention["out"]["kernel"] = to_leaf(out) + pre_self_attention_layer_norm["scale"] = to_leaf(pre_self_attention) + post_self_attention_layer_norm["scale"] = to_leaf(post_self_attention) if q_lora_rank != 0: - self_attention["q_norm"]["scale"] = q_norm - self_attention["wq_a"]["kernel"] = wq_a - self_attention["wq_b"]["kernel"] = wq_b + self_attention["q_norm"]["scale"] = to_leaf(q_norm) + self_attention["wq_a"]["kernel"] = to_leaf(wq_a) + self_attention["wq_b"]["kernel"] = to_leaf(wq_b) else: - self_attention["query"]["kernel"] = query + self_attention["query"]["kernel"] = to_leaf(query) jax_weights["decoder"][layer_name]["self_attention"] = self_attention jax_weights["decoder"][layer_name]["pre_self_attention_layer_norm"] = pre_self_attention_layer_norm @@ -268,9 +353,9 @@ def _normalize(raw_key): wi_0 = chkpt_vars[f"{layer_key}.{layer_idx}.mlp.wi_0.kernel"].to(torch.float16).numpy().transpose() wi_1 = chkpt_vars[f"{layer_key}.{layer_idx}.mlp.wi_1.kernel"].to(torch.float16).numpy().transpose() 
wo = chkpt_vars[f"{layer_key}.{layer_idx}.mlp.wo.kernel"].to(torch.float16).numpy().transpose() - mlp["wi_0"]["kernel"] = wi_0 - mlp["wi_1"]["kernel"] = wi_1 - mlp["wo"]["kernel"] = wo + mlp["wi_0"]["kernel"] = to_leaf(wi_0) + mlp["wi_1"]["kernel"] = to_leaf(wi_1) + mlp["wo"]["kernel"] = to_leaf(wo) jax_weights["decoder"][layer_name]["mlp"] = mlp else: layer_name = f"{layer_key}_{layer_idx}" @@ -308,11 +393,11 @@ def _normalize(raw_key): ) if q_lora_rank != 0: - moe["MoeBlock_0"]["gate"]["bias"] = gate_bias - moe["MoeBlock_0"]["gate"]["kernel"] = gate - moe["shared_experts"]["wi_0"]["kernel"] = shared_wi_0 - moe["shared_experts"]["wi_1"]["kernel"] = shared_wi_1 - moe["shared_experts"]["wo"]["kernel"] = shared_wo + moe["MoeBlock_0"]["gate"]["bias"] = to_leaf(gate_bias) + moe["MoeBlock_0"]["gate"]["kernel"] = to_leaf(gate) + moe["shared_experts"]["wi_0"]["kernel"] = to_leaf(shared_wi_0) + moe["shared_experts"]["wi_1"]["kernel"] = to_leaf(shared_wi_1) + moe["shared_experts"]["wo"]["kernel"] = to_leaf(shared_wo) for k in tqdm(range(num_experts), desc="experts", leave=False): wi_0 = ( @@ -336,13 +421,16 @@ def _normalize(raw_key): if moe["MoeBlock_0"]["wi_0"] is None: stack_shape = (num_experts,) - moe["MoeBlock_0"]["wi_0"] = np.zeros(stack_shape + wi_0.shape, dtype=np.float16) - moe["MoeBlock_0"]["wi_1"] = np.zeros(stack_shape + wi_1.shape, dtype=np.float16) - moe["MoeBlock_0"]["wo"] = np.zeros(stack_shape + wo.shape, dtype=np.float16) + moe["MoeBlock_0"]["wi_0"] = alloc_leaf(stack_shape + wi_0.shape, dtype=np.float16) + moe["MoeBlock_0"]["wi_1"] = alloc_leaf(stack_shape + wi_1.shape, dtype=np.float16) + moe["MoeBlock_0"]["wo"] = alloc_leaf(stack_shape + wo.shape, dtype=np.float16) moe["MoeBlock_0"]["wi_0"][k, ...] = wi_0 moe["MoeBlock_0"]["wi_1"][k, ...] = wi_1 moe["MoeBlock_0"]["wo"][k, ...] 
= wo + moe["MoeBlock_0"]["wi_0"] = seal_leaf(moe["MoeBlock_0"]["wi_0"]) + moe["MoeBlock_0"]["wi_1"] = seal_leaf(moe["MoeBlock_0"]["wi_1"]) + moe["MoeBlock_0"]["wo"] = seal_leaf(moe["MoeBlock_0"]["wo"]) jax_weights["decoder"][layer_name]["DeepSeekMoeBlock_0"] = moe del chkpt_vars @@ -351,7 +439,7 @@ def _normalize(raw_key): return jax_weights -def _convert_to_jax_weights(base_model_path, model_size, mem_info) -> dict: +def _convert_to_jax_weights(base_model_path, model_size, mem_info, spill_dir=None) -> dict: """ Function to convert the checkpoint at base_model_path into Orbax checkpoint for MaxText and output jax_weights ready for MaxText. @@ -360,6 +448,8 @@ def _convert_to_jax_weights(base_model_path, model_size, mem_info) -> dict: base_model_path: Path to the Hugging Face model checkpoint. model_size: Model size key in MODEL_PARAMS_DICT. mem_info: A process instance used for memory tracking. + spill_dir: Optional scratch directory; if set, converted leaves are staged + in disk-backed memmaps under it instead of host RAM (low-memory mode). Returns: The converted JAX weights. 
@@ -367,7 +457,7 @@ def _convert_to_jax_weights(base_model_path, model_size, mem_info) -> dict: model_params = ds_ckpt.MODEL_PARAMS_DICT[model_size] logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3)) max_logging.log(f"Loading the base model from {base_model_path}") - return _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info) + return _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info, spill_dir) def main() -> None: @@ -378,6 +468,15 @@ def main() -> None: parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16) parser.add_argument("--use-ocdbt", type=str2bool, required=False, default=True) parser.add_argument("--use-zarr3", type=str2bool, required=False, default=True) + parser.add_argument( + "--low_memory", + type=str2bool, + required=False, + default=False, + help="Stage converted weights in disk-backed memmaps under TMPDIR instead of host RAM, keeping peak RSS at " + "O(one tensor) instead of O(2x model size). 
Needs free disk for one fp16 copy of the model; the checkpoint is " + "saved without simulated-device sharding (the saved checkpoint is identical and topology-independent).", + ) args = parser.parse_args() if args.model_size not in ds_ckpt.MODEL_PARAMS_DICT: @@ -386,13 +485,26 @@ def main() -> None: os.environ["JAX_PLATFORMS"] = "cpu" os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}" mem_info = psutil.Process() - save_weights_to_checkpoint( - args.maxtext_model_path, - _convert_to_jax_weights(args.base_model_path, args.model_size, mem_info), - args.simulated_cpu_devices_count, - args.use_ocdbt, - args.use_zarr3, - ) + spill_dir = None + device_count = args.simulated_cpu_devices_count + if args.low_memory: + spill_dir = tempfile.mkdtemp(prefix="convert_unscanned_spill_") + max_logging.log(f"low_memory: staging converted weights in {spill_dir} (set TMPDIR to control its location)") + if device_count > 1: + # Simulated-device sharding would re-materialize every leaf in host RAM. + max_logging.log("low_memory: saving without simulated-device sharding so weights stay on disk") + device_count = 1 + try: + save_weights_to_checkpoint( + args.maxtext_model_path, + _convert_to_jax_weights(args.base_model_path, args.model_size, mem_info, spill_dir), + device_count, + args.use_ocdbt, + args.use_zarr3, + ) + finally: + if spill_dir is not None: + shutil.rmtree(spill_dir, ignore_errors=True) if __name__ == "__main__": diff --git a/tests/end_to_end/tpu/kimi/Run_Kimi.md b/tests/end_to_end/tpu/kimi/Run_Kimi.md index b024822399..801274ebd0 100644 --- a/tests/end_to_end/tpu/kimi/Run_Kimi.md +++ b/tests/end_to_end/tpu/kimi/Run_Kimi.md @@ -65,7 +65,7 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_fam --base_model_path $LOCAL_HF_PATH \ --maxtext_model_path $GCS_PATH_TO_SAVE ``` -Use `convert_deepseek_family_unscanned_ckpt.py` with the same `--model_size` for the unscanned (decoding) layout. 
+Use `convert_deepseek_family_unscanned_ckpt.py` with the same `--model_size` for the unscanned (decoding) layout. Without `--low_memory`, converting a trillion-parameter variant holds the whole dequantized model plus the assembled output in host RAM (~2.5 TB for K2.6). On smaller hosts add `--low_memory true`, which streams tensors from the shards one at a time and stages converted weights in disk-backed memmaps under `TMPDIR`; this needs free disk for one fp16 copy of the model (~4x the int4 checkpoint size). > **Note:** The Pre-training / fine-tuning / decoding flows below use `model_name=kimi-k2-1t`. K2-Thinking / K2.5 / K2.6 text branches share K2's architecture, so the same config works — just point `tokenizer_path` at the variant-specific HF tokenizer (e.g. `moonshotai/Kimi-K2.5`) and `load_parameters_path` at the converted checkpoint. diff --git a/tests/unit/convert_deepseek_unscanned_low_memory_test.py b/tests/unit/convert_deepseek_unscanned_low_memory_test.py new file mode 100644 index 0000000000..39b1767f2b --- /dev/null +++ b/tests/unit/convert_deepseek_unscanned_low_memory_test.py @@ -0,0 +1,314 @@ +# Copyright 2023–2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for the streaming/low-memory unscanned DeepSeek-family converter (#4071). 
+ +Uses a tiny synthetic kimi-k2.6-style checkpoint (compressed int4 routed experts, +`language_model.` key prefix, a vision key that must be dropped) to check that: + - no tensor data is read while indexing the shards (the converter used to buffer + every dequantized tensor for all shards up front, needing ~2.5 TB RAM for K2.6) + - each source tensor is read at most once + - low-memory (disk-spilled) conversion is bit-identical to the in-memory one + - a checkpoint saved from disk-spilled leaves restores to the expected values + +Not run in GitHub runners (depends on torch). +""" + +import collections +import os +import pathlib +import shutil +import tempfile +import unittest +import zlib + +os.environ.setdefault("JAX_PLATFORMS", "cpu") + +import numpy as np +import psutil +import pytest +import torch +from safetensors.torch import save_file + +from maxtext.checkpoint_conversion.standalone_scripts import convert_deepseek_family_ckpt as ds_ckpt +from maxtext.checkpoint_conversion.standalone_scripts import convert_deepseek_family_unscanned_ckpt as unscanned_ckpt +from maxtext.checkpoint_conversion.utils.utils import save_weights_to_checkpoint +from maxtext.utils import max_logging + +_TINY_MODEL_SIZE = "tiny-kimi-test" +_TINY_PARAMS = { + "num_layers": 3, # 1 dense + 2 MoE + "first_num_dense_layers": 1, + "base_num_query_heads": 2, + "base_emb_dim": 64, + "num_experts": 4, + "q_lora_rank": 16, + "kv_lora_rank": 16, + "qk_nope_head_dim": 8, + "qk_rope_head_dim": 4, + "v_head_dim": 8, + "has_mtp": False, + "compressed_int4": True, + "hf_key_prefix": "language_model.", +} +_VOCAB = 128 +_DENSE_INTER = 96 +_MOE_INTER = 32 + + +def _rng_for(key): + """Deterministic per-key RNG so expected values can be recomputed independently.""" + return np.random.default_rng(zlib.crc32(key.encode())) + + +def _bf16(key, shape): + return torch.from_numpy(_rng_for(key).standard_normal(shape).astype(np.float32)).to(torch.bfloat16) + + +def _packed_int4(key, out_features, in_features): + 
"""compressed-tensors pack-quantized layout: int32 [out, in/8], bf16 scale [out, in/32], shape [2].""" + rng = _rng_for(key) + packed = torch.from_numpy( + rng.integers(-(2**31), 2**31 - 1, size=(out_features, in_features // 8), dtype=np.int64).astype(np.int32) + ) + scale = torch.from_numpy((rng.standard_normal((out_features, in_features // 32)).astype(np.float32) * 0.01)).to( + torch.bfloat16 + ) + shape = torch.tensor([out_features, in_features], dtype=torch.int64) + return packed, scale, shape + + +def _attention_tensors(params, layer_idx): + """All bf16 tensors for one layer's attention + norms (HF [out, in] layout).""" + emb = params["base_emb_dim"] + num_heads = params["base_num_query_heads"] + prefix = f"model.layers.{layer_idx}" + return { + f"{prefix}.input_layernorm.weight": _bf16(f"{prefix}.iln", (emb,)), + f"{prefix}.post_attention_layernorm.weight": _bf16(f"{prefix}.pln", (emb,)), + f"{prefix}.self_attn.q_a_proj.weight": _bf16(f"{prefix}.qa", (params["q_lora_rank"], emb)), + f"{prefix}.self_attn.q_a_layernorm.weight": _bf16(f"{prefix}.qan", (params["q_lora_rank"],)), + f"{prefix}.self_attn.q_b_proj.weight": _bf16( + f"{prefix}.qb", (num_heads * (params["qk_nope_head_dim"] + params["qk_rope_head_dim"]), params["q_lora_rank"]) + ), + f"{prefix}.self_attn.kv_a_proj_with_mqa.weight": _bf16( + f"{prefix}.kva", (params["kv_lora_rank"] + params["qk_rope_head_dim"], emb) + ), + f"{prefix}.self_attn.kv_a_layernorm.weight": _bf16(f"{prefix}.kvn", (params["kv_lora_rank"],)), + f"{prefix}.self_attn.kv_b_proj.weight": _bf16( + f"{prefix}.kvb", (num_heads * (params["qk_nope_head_dim"] + params["v_head_dim"]), params["kv_lora_rank"]) + ), + f"{prefix}.self_attn.o_proj.weight": _bf16(f"{prefix}.o", (emb, num_heads * params["v_head_dim"])), + } + + +def _write_tiny_hf_checkpoint(params, out_dir): + """Synthesizes a tiny multi-shard kimi-k2.6-style HF checkpoint.""" + emb = params["base_emb_dim"] + pfx = params["hf_key_prefix"] + shards = [] + + # shard 0: embeddings + 
the dense layer + a vision key that must be dropped + shard0 = {pfx + "model.embed_tokens.weight": _bf16("embed", (_VOCAB, emb))} + shard0.update({pfx + k: v for k, v in _attention_tensors(params, 0).items()}) + shard0[pfx + "model.layers.0.mlp.gate_proj.weight"] = _bf16("d0.gate", (_DENSE_INTER, emb)) + shard0[pfx + "model.layers.0.mlp.up_proj.weight"] = _bf16("d0.up", (_DENSE_INTER, emb)) + shard0[pfx + "model.layers.0.mlp.down_proj.weight"] = _bf16("d0.down", (emb, _DENSE_INTER)) + shard0["vision_tower.patch_embed.proj.weight"] = _bf16("vision", (8, 8)) + shards.append(shard0) + + # one shard per MoE layer + for layer_idx in range(params["first_num_dense_layers"], params["num_layers"]): + shard = {pfx + k: v for k, v in _attention_tensors(params, layer_idx).items()} + prefix = f"model.layers.{layer_idx}" + shard[pfx + f"{prefix}.mlp.gate.weight"] = _bf16(f"{prefix}.rgate", (params["num_experts"], emb)) + shard[pfx + f"{prefix}.mlp.gate.e_score_correction_bias"] = _bf16(f"{prefix}.rbias", (params["num_experts"],)) + shard[pfx + f"{prefix}.mlp.shared_experts.gate_proj.weight"] = _bf16(f"{prefix}.sg", (_MOE_INTER, emb)) + shard[pfx + f"{prefix}.mlp.shared_experts.up_proj.weight"] = _bf16(f"{prefix}.su", (_MOE_INTER, emb)) + shard[pfx + f"{prefix}.mlp.shared_experts.down_proj.weight"] = _bf16(f"{prefix}.sd", (emb, _MOE_INTER)) + for expert_idx in range(params["num_experts"]): + for proj, (out_features, in_features) in ( + ("gate_proj", (_MOE_INTER, emb)), + ("up_proj", (_MOE_INTER, emb)), + ("down_proj", (emb, _MOE_INTER)), + ): + base = f"{prefix}.mlp.experts.{expert_idx}.{proj}" + packed, scale, shape = _packed_int4(base, out_features, in_features) + shard[pfx + base + ".weight_packed"] = packed + shard[pfx + base + ".weight_scale"] = scale + shard[pfx + base + ".weight_shape"] = shape + shards.append(shard) + + # last shard: final norm + lm_head + shards.append( + { + pfx + "model.norm.weight": _bf16("norm", (emb,)), + pfx + "lm_head.weight": _bf16("lmhead", 
(_VOCAB, emb)), + } + ) + + for i, shard in enumerate(shards): + save_file(shard, pathlib.Path(out_dir) / f"model-{i+1:05d}-of-{len(shards):05d}.safetensors") + + +def _flatten(tree, prefix=""): + out = {} + if isinstance(tree, dict): + for key, value in tree.items(): + out.update(_flatten(value, f"{prefix}/{key}")) + elif tree is not None: + out[prefix] = np.asarray(tree) + return out + + +class _CountingHandle: + """Wraps a safetensors handle, recording every get_tensor call.""" + + def __init__(self, inner, path, read_log, assembly_started): + self._inner = inner + self._path = path + self._read_log = read_log + self._assembly_started = assembly_started + + def keys(self): + return self._inner.keys() + + def get_tensor(self, name): + self._read_log.append((os.path.basename(str(self._path)), name, bool(self._assembly_started))) + return self._inner.get_tensor(name) + + def __enter__(self): + return self + + def __exit__(self, *exc): + return self._inner.__exit__(*exc) + + +@pytest.mark.cpu_only +class ConvertDeepseekUnscannedLowMemoryTest(unittest.TestCase): + """Streaming + disk-spill behavior of convert_deepseek_family_unscanned_ckpt.""" + + @classmethod + def setUpClass(cls): + cls.tmp_dir = tempfile.mkdtemp(prefix="unscanned_low_memory_test_") + cls.hf_dir = os.path.join(cls.tmp_dir, "hf") + os.makedirs(cls.hf_dir) + _write_tiny_hf_checkpoint(_TINY_PARAMS, cls.hf_dir) + ds_ckpt.MODEL_PARAMS_DICT[_TINY_MODEL_SIZE] = _TINY_PARAMS + + @classmethod + def tearDownClass(cls): + ds_ckpt.MODEL_PARAMS_DICT.pop(_TINY_MODEL_SIZE, None) + shutil.rmtree(cls.tmp_dir, ignore_errors=True) + + def _convert(self, spill_dir=None): + # pylint: disable=protected-access + if spill_dir is None: + return unscanned_ckpt._convert_to_jax_weights(self.hf_dir, _TINY_MODEL_SIZE, psutil.Process()) + return unscanned_ckpt._convert_to_jax_weights(self.hf_dir, _TINY_MODEL_SIZE, psutil.Process(), spill_dir) + + def test_no_tensor_reads_before_assembly_and_no_rereads(self): + """Tensor data 
must be streamed during assembly, not buffered while scanning shards. + + The converter used to load (and dequantize) every tensor of every shard into one + dict before assembling the output pytree, so peak RSS was ~the whole dequantized + model (~2.3 TB for kimi-k2.6); see #4071. + """ + read_log = [] + assembly_started = [] + original_safe_open = unscanned_ckpt.safe_open + original_log = max_logging.log + + def counting_safe_open(path, *args, **kwargs): + return _CountingHandle(original_safe_open(path, *args, **kwargs), path, read_log, assembly_started) + + def phase_marking_log(message, *args, **kwargs): + if isinstance(message, str) and message.startswith("Processing decoder norm scale"): + assembly_started.append(True) + return original_log(message, *args, **kwargs) + + unscanned_ckpt.safe_open = counting_safe_open + max_logging.log = phase_marking_log + try: + self._convert() + finally: + unscanned_ckpt.safe_open = original_safe_open + max_logging.log = original_log + + self.assertTrue(read_log, "expected the converter to read tensors") + reads_before_assembly = [(path, name) for path, name, in_assembly in read_log if not in_assembly] + self.assertEqual( + reads_before_assembly, [], f"{len(reads_before_assembly)} tensors were buffered before assembly began" + ) + read_counts = collections.Counter((path, name) for path, name, _ in read_log) + rereads = {key: count for key, count in read_counts.items() if count > 1} + self.assertEqual(rereads, {}, "each source tensor should be read at most once") + + def test_low_memory_pytree_is_bit_identical(self): + """Disk-spilled conversion must produce the same pytree as the in-memory one.""" + reference = _flatten(self._convert()) + spill_dir = os.path.join(self.tmp_dir, "spill") + os.makedirs(spill_dir, exist_ok=True) + spilled = self._convert(spill_dir=spill_dir) + for leaf in ( + spilled["token_embedder"]["embedding"], + spilled["decoder"]["moe_layers_0"]["DeepSeekMoeBlock_0"]["MoeBlock_0"]["wi_0"], + ): + 
self.assertIsInstance(leaf, np.memmap) + self.assertFalse(leaf.flags.writeable) + spilled = _flatten(spilled) + + self.assertEqual(set(reference), set(spilled)) + for key, want in reference.items(): + self.assertEqual(want.dtype, spilled[key].dtype, key) + np.testing.assert_array_equal(want, spilled[key], err_msg=key) + + def test_low_memory_checkpoint_restores_expected_values(self): + """An Orbax checkpoint saved from disk-spilled leaves restores the expected weights.""" + import orbax.checkpoint as ocp # pylint: disable=import-outside-toplevel + + spill_dir = os.path.join(self.tmp_dir, "spill_save") + os.makedirs(spill_dir, exist_ok=True) + ckpt_dir = os.path.join(self.tmp_dir, "ckpt") + save_weights_to_checkpoint(ckpt_dir, self._convert(spill_dir=spill_dir), 1, True, True) + restored = _flatten(ocp.PyTreeCheckpointer().restore(os.path.join(ckpt_dir, "0", "items"))) + + emb = _TINY_PARAMS["base_emb_dim"] + + def as_np_f16(tensor): + return tensor.to(torch.float16).numpy() + + np.testing.assert_array_equal( + restored["/params/params/token_embedder/embedding"], as_np_f16(_bf16("embed", (_VOCAB, emb))) + ) + np.testing.assert_array_equal( + restored["/params/params/decoder/logits_dense/kernel"], as_np_f16(_bf16("lmhead", (_VOCAB, emb))).transpose() + ) + np.testing.assert_array_equal(restored["/params/params/decoder/decoder_norm/scale"], as_np_f16(_bf16("norm", (emb,)))) + np.testing.assert_array_equal( + restored["/params/params/decoder/dense_layers_0/mlp/wi_0/kernel"], + as_np_f16(_bf16("d0.gate", (_DENSE_INTER, emb))).transpose(), + ) + # routed-expert int4 path: HF layer 1 -> moe_layers_0, expert 2, up_proj -> wi_1 + packed, scale, shape = _packed_int4("model.layers.1.mlp.experts.2.up_proj", _MOE_INTER, emb) + expected_expert = as_np_f16(ds_ckpt.dequantize_pack_quantized_int4(packed, scale, shape.tolist())).transpose() + np.testing.assert_array_equal( + restored["/params/params/decoder/moe_layers_0/DeepSeekMoeBlock_0/MoeBlock_0/wi_1"][2], expected_expert + ) + 
# the vision tower key must not leak into the text-only checkpoint + self.assertFalse([k for k in restored if "vision" in k.lower()]) + + +if __name__ == "__main__": + unittest.main()