Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ addopts =
--ignore=tests/integration/smoke/train_int8_smoke_test.py
--ignore=tests/integration/smoke/train_smoke_test.py
--ignore=tests/integration/smoke/train_using_ragged_dot_smoke_test.py
--ignore=tests/unit/convert_deepseek_unscanned_low_memory_test.py
--ignore=tests/unit/dequantize_mxfp4_test.py
--ignore=tests/unit/dequantize_pack_quantized_int4_test.py
--ignore=tests/unit/gemma3_layers_test.py
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@
python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_family_unscanned_ckpt \
--base_model_path <path/to/meta/ckpt> \
--maxtext_model_path <GCS/path/to/save/new/maxtext/ckpt> --model_size deepseek2-16b

Pass --low_memory true to keep peak host RAM at O(one tensor) instead of
O(2x model size): tensors are read from the safetensors shards one at a time
and converted weights are staged in disk-backed memmaps (under TMPDIR) until
the checkpoint is written. Recommended for trillion-parameter checkpoints such
as kimi-k2.6-text, which otherwise needs ~2.5 TB of host RAM.
"""

# pylint: disable=line-too-long
Expand All @@ -30,6 +36,8 @@
import os
import gc
import logging
import shutil
import tempfile
import absl

import numpy as np
Expand All @@ -46,34 +54,38 @@
absl.logging.set_verbosity(absl.logging.INFO) # for max_logging.log


def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info) -> dict:
"""Convert Huggingface Checkpoint to Jax."""
base_num_decoder_layers = model_params["num_layers"]
first_num_dense_layers = model_params["first_num_dense_layers"]
base_num_query_heads = model_params["base_num_query_heads"]
base_emb_dim = model_params["base_emb_dim"]
num_experts = model_params["num_experts"]
q_lora_rank = model_params["q_lora_rank"]
kv_lora_rank = model_params["kv_lora_rank"]
qk_nope_head_dim = model_params["qk_nope_head_dim"]
qk_rope_head_dim = model_params["qk_rope_head_dim"]
v_head_dim = model_params["v_head_dim"]
class _LazyShardLoader:
"""Dict-like loader that reads one tensor at a time from safetensors shards.

ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
chkpt_vars = {}
is_compressed = bool(model_params.get("compressed_int4", False))
hf_key_prefix = model_params.get("hf_key_prefix", "")

def _normalize(raw_key):
if not hf_key_prefix:
return raw_key
if raw_key.startswith(hf_key_prefix):
return raw_key[len(hf_key_prefix) :]
return None

for i, ckpt_path in enumerate(ckpt_paths):
max_logging.log(f"Loading checkpoint {i+1} of {len(ckpt_paths)} ...")
with safe_open(ckpt_path, framework="pt", device="cpu") as f:
Building the index only reads the shard headers (`safe_open(...).keys()`), so
no tensor data is resident up front. Each `__getitem__` loads (and, for
compressed int4 experts, dequantizes) exactly one tensor from its source shard
and keeps no reference to it, so peak memory during conversion stays at
O(one tensor) instead of O(whole model). See
https://github.com/AI-Hypercomputer/maxtext/issues/4071.
"""

def __init__(self, ckpt_paths, model_params):
base_num_decoder_layers = model_params["num_layers"]
first_num_dense_layers = model_params["first_num_dense_layers"]
num_experts = model_params["num_experts"]
is_compressed = bool(model_params.get("compressed_int4", False))
hf_key_prefix = model_params.get("hf_key_prefix", "")

def _normalize(raw_key):
if not hf_key_prefix:
return raw_key
if raw_key.startswith(hf_key_prefix):
return raw_key[len(hf_key_prefix) :]
return None

self._handles = {}
# mapped MaxText key -> (shard path, raw key within the shard, whether the
# raw key is the base name of a packed-int4 triple to dequantize on access).
self._index = {}
for i, ckpt_path in enumerate(ckpt_paths):
max_logging.log(f"Indexing checkpoint {i+1} of {len(ckpt_paths)} ...")
f = self._open(ckpt_path)
for raw_key in f.keys():
key = _normalize(raw_key)
if key is None:
Expand All @@ -91,13 +103,7 @@ def _normalize(raw_key):
).get(hf_key)
if not mapped_key:
continue
raw_base = raw_key[: -len(".weight_packed")]
shape_t = f.get_tensor(raw_base + ".weight_shape")
chkpt_vars[mapped_key] = ds_ckpt.dequantize_pack_quantized_int4(
f.get_tensor(raw_key),
f.get_tensor(raw_base + ".weight_scale"),
shape_t.tolist(),
)
self._index[mapped_key] = (ckpt_path, raw_key[: -len(".weight_packed")], True)
continue
if is_compressed and key.endswith((".weight_scale", ".weight_shape")):
continue
Expand All @@ -111,11 +117,90 @@ def _normalize(raw_key):
layer, num_experts, first_num_dense_layers, base_num_decoder_layers
).get(key)
if mapped_key:
chkpt_vars[mapped_key] = f.get_tensor(raw_key)
self._index[mapped_key] = (ckpt_path, raw_key, False)
else:
# This catches keys that are allowed but missing from the mapping dictionary
max_logging.log(f"Debug: Allowed key '{key}' (layer {layer}) has no mapping in hf_to_maxtext_mapping.")

def _open(self, ckpt_path):
if ckpt_path not in self._handles:
self._handles[ckpt_path] = safe_open(ckpt_path, framework="pt", device="cpu")
return self._handles[ckpt_path]

def __getitem__(self, mapped_key) -> torch.Tensor:
ckpt_path, raw_key, is_packed_int4 = self._index[mapped_key]
f = self._open(ckpt_path)
if is_packed_int4:
return ds_ckpt.dequantize_pack_quantized_int4(
f.get_tensor(raw_key + ".weight_packed"),
f.get_tensor(raw_key + ".weight_scale"),
f.get_tensor(raw_key + ".weight_shape").tolist(),
)
return f.get_tensor(raw_key)


class _LeafSpiller:
"""Spills converted weight leaves to .npy memmaps in a scratch directory.

`spill` writes an array to disk and returns it reopened as a read-only memmap,
so its pages are clean and the OS can evict them under memory pressure.
`empty` pre-allocates a zero-initialized leaf on disk for incremental filling
(e.g. per-expert stacking); `seal` flushes such a leaf and reopens it
read-only.
"""

def __init__(self, spill_dir):
self._spill_dir = spill_dir
self._num_leaves = 0

def empty(self, shape, dtype) -> np.memmap:
"""Allocates a zero-initialized writable memmap leaf on disk."""
leaf_path = os.path.join(self._spill_dir, f"leaf_{self._num_leaves:05d}.npy")
self._num_leaves += 1
return np.lib.format.open_memmap(leaf_path, mode="w+", dtype=dtype, shape=shape)

def seal(self, leaf) -> np.memmap:
"""Flushes a writable memmap leaf and reopens it read-only."""
leaf.flush()
return np.lib.format.open_memmap(leaf.filename, mode="r")

def spill(self, leaf) -> np.memmap:
"""Writes an in-memory array to disk and returns a read-only memmap view."""
out = self.empty(leaf.shape, leaf.dtype)
out[...] = leaf
return self.seal(out)


def _keep(leaf):
"""Identity leaf placement used when low-memory spilling is disabled."""
return leaf


def _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info, spill_dir=None) -> dict:
"""Convert Huggingface Checkpoint to Jax.

If spill_dir is set, converted leaves are staged in disk-backed memmaps under
that directory instead of host RAM (low-memory mode).
"""
base_num_decoder_layers = model_params["num_layers"]
first_num_dense_layers = model_params["first_num_dense_layers"]
base_num_query_heads = model_params["base_num_query_heads"]
base_emb_dim = model_params["base_emb_dim"]
num_experts = model_params["num_experts"]
q_lora_rank = model_params["q_lora_rank"]
kv_lora_rank = model_params["kv_lora_rank"]
qk_nope_head_dim = model_params["qk_nope_head_dim"]
qk_rope_head_dim = model_params["qk_rope_head_dim"]
v_head_dim = model_params["v_head_dim"]

ckpt_paths = sorted(pathlib.Path(base_model_path).glob("[!.]*.safetensors"))
chkpt_vars = _LazyShardLoader(ckpt_paths, model_params)
if spill_dir is None:
to_leaf, alloc_leaf, seal_leaf = _keep, np.zeros, _keep
else:
spiller = _LeafSpiller(spill_dir)
to_leaf, alloc_leaf, seal_leaf = spiller.spill, spiller.empty, spiller.seal

logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

# initialize the data structure for storing jax_weights
Expand All @@ -129,19 +214,19 @@ def _normalize(raw_key):

# decoder norm scale ###########################################
max_logging.log("Processing decoder norm scale")
jax_weights["decoder"]["decoder_norm"]["scale"] = chkpt_vars["decoder_norm.scale"].to(torch.float16).numpy()
jax_weights["decoder"]["decoder_norm"]["scale"] = to_leaf(chkpt_vars["decoder_norm.scale"].to(torch.float16).numpy())
logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

# logits dense #################################################
max_logging.log("Processing logits dense")
jax_weights["decoder"]["logits_dense"]["kernel"] = (
jax_weights["decoder"]["logits_dense"]["kernel"] = to_leaf(
chkpt_vars["logits_dense.kernel"].to(torch.float16).numpy().transpose()
)
logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

# token embedding ##############################################
max_logging.log("Processing token embeddings")
jax_weights["token_embedder"]["embedding"] = chkpt_vars["token_embedder.embedding"].to(torch.float16).numpy()
jax_weights["token_embedder"]["embedding"] = to_leaf(chkpt_vars["token_embedder.embedding"].to(torch.float16).numpy())
logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))

layers = {
Expand Down Expand Up @@ -241,18 +326,18 @@ def _normalize(raw_key):
else:
self_attention.update({"query": {"kernel": None}})

self_attention["kv_norm"]["scale"] = kv_norm
self_attention["wkv_a"]["kernel"] = wkv_a
self_attention["wkv_b"]["kernel"] = wkv_b
self_attention["out"]["kernel"] = out
pre_self_attention_layer_norm["scale"] = pre_self_attention
post_self_attention_layer_norm["scale"] = post_self_attention
self_attention["kv_norm"]["scale"] = to_leaf(kv_norm)
self_attention["wkv_a"]["kernel"] = to_leaf(wkv_a)
self_attention["wkv_b"]["kernel"] = to_leaf(wkv_b)
self_attention["out"]["kernel"] = to_leaf(out)
pre_self_attention_layer_norm["scale"] = to_leaf(pre_self_attention)
post_self_attention_layer_norm["scale"] = to_leaf(post_self_attention)
if q_lora_rank != 0:
self_attention["q_norm"]["scale"] = q_norm
self_attention["wq_a"]["kernel"] = wq_a
self_attention["wq_b"]["kernel"] = wq_b
self_attention["q_norm"]["scale"] = to_leaf(q_norm)
self_attention["wq_a"]["kernel"] = to_leaf(wq_a)
self_attention["wq_b"]["kernel"] = to_leaf(wq_b)
else:
self_attention["query"]["kernel"] = query
self_attention["query"]["kernel"] = to_leaf(query)

jax_weights["decoder"][layer_name]["self_attention"] = self_attention
jax_weights["decoder"][layer_name]["pre_self_attention_layer_norm"] = pre_self_attention_layer_norm
Expand All @@ -268,9 +353,9 @@ def _normalize(raw_key):
wi_0 = chkpt_vars[f"{layer_key}.{layer_idx}.mlp.wi_0.kernel"].to(torch.float16).numpy().transpose()
wi_1 = chkpt_vars[f"{layer_key}.{layer_idx}.mlp.wi_1.kernel"].to(torch.float16).numpy().transpose()
wo = chkpt_vars[f"{layer_key}.{layer_idx}.mlp.wo.kernel"].to(torch.float16).numpy().transpose()
mlp["wi_0"]["kernel"] = wi_0
mlp["wi_1"]["kernel"] = wi_1
mlp["wo"]["kernel"] = wo
mlp["wi_0"]["kernel"] = to_leaf(wi_0)
mlp["wi_1"]["kernel"] = to_leaf(wi_1)
mlp["wo"]["kernel"] = to_leaf(wo)
jax_weights["decoder"][layer_name]["mlp"] = mlp
else:
layer_name = f"{layer_key}_{layer_idx}"
Expand Down Expand Up @@ -308,11 +393,11 @@ def _normalize(raw_key):
)

if q_lora_rank != 0:
moe["MoeBlock_0"]["gate"]["bias"] = gate_bias
moe["MoeBlock_0"]["gate"]["kernel"] = gate
moe["shared_experts"]["wi_0"]["kernel"] = shared_wi_0
moe["shared_experts"]["wi_1"]["kernel"] = shared_wi_1
moe["shared_experts"]["wo"]["kernel"] = shared_wo
moe["MoeBlock_0"]["gate"]["bias"] = to_leaf(gate_bias)
moe["MoeBlock_0"]["gate"]["kernel"] = to_leaf(gate)
moe["shared_experts"]["wi_0"]["kernel"] = to_leaf(shared_wi_0)
moe["shared_experts"]["wi_1"]["kernel"] = to_leaf(shared_wi_1)
moe["shared_experts"]["wo"]["kernel"] = to_leaf(shared_wo)

for k in tqdm(range(num_experts), desc="experts", leave=False):
wi_0 = (
Expand All @@ -336,13 +421,16 @@ def _normalize(raw_key):

if moe["MoeBlock_0"]["wi_0"] is None:
stack_shape = (num_experts,)
moe["MoeBlock_0"]["wi_0"] = np.zeros(stack_shape + wi_0.shape, dtype=np.float16)
moe["MoeBlock_0"]["wi_1"] = np.zeros(stack_shape + wi_1.shape, dtype=np.float16)
moe["MoeBlock_0"]["wo"] = np.zeros(stack_shape + wo.shape, dtype=np.float16)
moe["MoeBlock_0"]["wi_0"] = alloc_leaf(stack_shape + wi_0.shape, dtype=np.float16)
moe["MoeBlock_0"]["wi_1"] = alloc_leaf(stack_shape + wi_1.shape, dtype=np.float16)
moe["MoeBlock_0"]["wo"] = alloc_leaf(stack_shape + wo.shape, dtype=np.float16)
moe["MoeBlock_0"]["wi_0"][k, ...] = wi_0
moe["MoeBlock_0"]["wi_1"][k, ...] = wi_1
moe["MoeBlock_0"]["wo"][k, ...] = wo

moe["MoeBlock_0"]["wi_0"] = seal_leaf(moe["MoeBlock_0"]["wi_0"])
moe["MoeBlock_0"]["wi_1"] = seal_leaf(moe["MoeBlock_0"]["wi_1"])
moe["MoeBlock_0"]["wo"] = seal_leaf(moe["MoeBlock_0"]["wo"])
jax_weights["decoder"][layer_name]["DeepSeekMoeBlock_0"] = moe

del chkpt_vars
Expand All @@ -351,7 +439,7 @@ def _normalize(raw_key):
return jax_weights


def _convert_to_jax_weights(base_model_path, model_size, mem_info) -> dict:
def _convert_to_jax_weights(base_model_path, model_size, mem_info, spill_dir=None) -> dict:
"""
Function to convert the checkpoint at base_model_path into Orbax checkpoint
for MaxText and output jax_weights ready for MaxText.
Expand All @@ -360,14 +448,16 @@ def _convert_to_jax_weights(base_model_path, model_size, mem_info) -> dict:
base_model_path: Path to the Hugging Face model checkpoint.
model_size: Model size key in MODEL_PARAMS_DICT.
mem_info: A process instance used for memory tracking.
spill_dir: Optional scratch directory; if set, converted leaves are staged
in disk-backed memmaps under it instead of host RAM (low-memory mode).

Returns:
The converted JAX weights.
"""
model_params = ds_ckpt.MODEL_PARAMS_DICT[model_size]
logging.debug("Memory usage: %f GB", mem_info.memory_info().rss / (1024**3))
max_logging.log(f"Loading the base model from {base_model_path}")
return _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info)
return _convert_huggingface_to_jax_weights(base_model_path, model_params, mem_info, spill_dir)


def main() -> None:
Expand All @@ -378,6 +468,15 @@ def main() -> None:
parser.add_argument("--simulated_cpu_devices_count", type=int, required=False, default=16)
parser.add_argument("--use-ocdbt", type=str2bool, required=False, default=True)
parser.add_argument("--use-zarr3", type=str2bool, required=False, default=True)
parser.add_argument(
"--low_memory",
type=str2bool,
required=False,
default=False,
help="Stage converted weights in disk-backed memmaps under TMPDIR instead of host RAM, keeping peak RSS at "
"O(one tensor) instead of O(2x model size). Needs free disk for one fp16 copy of the model; the checkpoint is "
"saved without simulated-device sharding (the saved checkpoint is identical and topology-independent).",
)
args = parser.parse_args()

if args.model_size not in ds_ckpt.MODEL_PARAMS_DICT:
Expand All @@ -386,13 +485,26 @@ def main() -> None:
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={args.simulated_cpu_devices_count}"
mem_info = psutil.Process()
save_weights_to_checkpoint(
args.maxtext_model_path,
_convert_to_jax_weights(args.base_model_path, args.model_size, mem_info),
args.simulated_cpu_devices_count,
args.use_ocdbt,
args.use_zarr3,
)
spill_dir = None
device_count = args.simulated_cpu_devices_count
if args.low_memory:
spill_dir = tempfile.mkdtemp(prefix="convert_unscanned_spill_")
max_logging.log(f"low_memory: staging converted weights in {spill_dir} (set TMPDIR to control its location)")
if device_count > 1:
# Simulated-device sharding would re-materialize every leaf in host RAM.
max_logging.log("low_memory: saving without simulated-device sharding so weights stay on disk")
device_count = 1
try:
save_weights_to_checkpoint(
args.maxtext_model_path,
_convert_to_jax_weights(args.base_model_path, args.model_size, mem_info, spill_dir),
device_count,
args.use_ocdbt,
args.use_zarr3,
)
finally:
if spill_dir is not None:
shutil.rmtree(spill_dir, ignore_errors=True)


if __name__ == "__main__":
Expand Down
2 changes: 1 addition & 1 deletion tests/end_to_end/tpu/kimi/Run_Kimi.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ python3 -m maxtext.checkpoint_conversion.standalone_scripts.convert_deepseek_fam
--base_model_path $LOCAL_HF_PATH \
--maxtext_model_path $GCS_PATH_TO_SAVE
```
Use `convert_deepseek_family_unscanned_ckpt.py` with the same `--model_size` for the unscanned (decoding) layout.
Use `convert_deepseek_family_unscanned_ckpt.py` with the same `--model_size` for the unscanned (decoding) layout. Converting a trillion-parameter variant this way holds the whole dequantized model plus the assembled output in host RAM (~2.5 TB for K2.6); on smaller hosts add `--low_memory true`, which streams tensors from the shards one at a time and stages converted weights in disk-backed memmaps under `TMPDIR` (needs free disk for one fp16 copy of the model, ~4x the int4 checkpoint size).

> **Note:** The Pre-training / fine-tuning / decoding flows below use `model_name=kimi-k2-1t`. K2-Thinking / K2.5 / K2.6 text branches share K2's architecture, so the same config works — just point `tokenizer_path` at the variant-specific HF tokenizer (e.g. `moonshotai/Kimi-K2.5`) and `load_parameters_path` at the converted checkpoint.

Expand Down
Loading
Loading