Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 40 additions & 6 deletions examples/speechlm2/to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import json
import os
from copy import deepcopy
from dataclasses import dataclass
from pathlib import Path
from typing import Any
Expand All @@ -22,7 +23,9 @@
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import save_file

from nemo.collections.speechlm2.parts.hf_hub import LLM_BACKBONE_DIR
from nemo.core.config import hydra_runner
from nemo.utils.dtype import str_to_dtype
from nemo.utils.model_utils import import_class_by_path


Expand Down Expand Up @@ -92,19 +95,45 @@ def consolidate_state_dict(model: torch.nn.Module) -> dict[str, torch.Tensor]:
return consolidated


def _canonical_torch_dtype_name(dtype: str | torch.dtype) -> str:
    """Normalize a dtype spec to the bare PyTorch dtype name used in exported configs.

    ``dtype`` is resolved with ``str_to_dtype``; the ``"torch."`` prefix is then dropped from the
    string form of the result.
    """
    resolved = str_to_dtype(dtype)
    qualified_name = str(resolved)
    return qualified_name.replace("torch.", "")


def _hf_export_config(model: torch.nn.Module, dtype: str | torch.dtype) -> dict[str, Any]:
    """Return a copy of ``model.cfg`` with the export dtype stored under ``dtype`` and ``torch_dtype``.

    A ``DictConfig`` is converted to a plain container; any other config object is deep-copied,
    so the config held by the model itself is never written to.
    """
    if isinstance(model.cfg, DictConfig):
        exported = OmegaConf.to_container(model.cfg)
    else:
        exported = deepcopy(model.cfg)

    canonical = _canonical_torch_dtype_name(dtype)
    # Both keys carry the same canonical spelling.
    for key in ("dtype", "torch_dtype"):
        exported[key] = canonical
    return exported


def save_hf_checkpoint(model: torch.nn.Module, state_dict: dict, cfg: HfExportConfig) -> None:
    """Save a consolidated state dict and model config in HuggingFace Hub format.

    Writes ``model.safetensors`` and ``config.json`` into ``cfg.output_dir`` (created if missing),
    then stores the LLM backbone config through ``save_llm_backbone_config``.

    Args:
        model: Model whose ``cfg`` is exported as the root ``config.json``.
        state_dict: Consolidated tensors; each one is cast to the export dtype before saving.
        cfg: Export settings; ``output_dir`` and ``dtype`` are read here.
    """
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Resolve the dtype through str_to_dtype only: a raw ``getattr(torch, cfg.dtype)`` lookup
    # raises AttributeError for aliases such as "bf16" that are not attributes of ``torch``.
    target_dtype = str_to_dtype(cfg.dtype)
    state_dict = {k: v.to(target_dtype) for k, v in state_dict.items()}

    save_file(state_dict, output_dir / "model.safetensors")

    # Export a copy that carries the canonical dtype fields instead of the model's own config.
    config = _hf_export_config(model, cfg.dtype)
    with open(output_dir / "config.json", "w") as f:
        json.dump(config, f, indent=2)
    save_llm_backbone_config(model, output_dir)


def save_llm_backbone_config(model: torch.nn.Module, output_dir: str | Path) -> None:
    """Write ``model.llm.config`` into the ``LLM_BACKBONE_DIR`` subdirectory of ``output_dir``.

    Nothing is written when the model has no ``llm`` attribute or that LLM exposes no ``config``.
    """
    llm = getattr(model, "llm", None)
    backbone_config = getattr(llm, "config", None)
    if backbone_config is not None:
        target_dir = Path(output_dir) / LLM_BACKBONE_DIR
        target_dir.mkdir(parents=True, exist_ok=True)
        backbone_config.save_pretrained(str(target_dir))


def _detect_vllm_architecture(model_cfg: dict) -> str:
Expand Down Expand Up @@ -165,7 +194,11 @@ def prepare_for_vllm(output_dir: str, model_cfg: dict) -> None:
raise ValueError("model config has no 'audio_locator_tag' (set it in the training YAML).")

# 1. Patch config.json (arch, model_type, audio_locator_tag for vLLM plugin).
arch = _detect_vllm_architecture(model_cfg)
arch_model_cfg = dict(model_cfg)
llm_backbone_dir = output_dir / LLM_BACKBONE_DIR
if (llm_backbone_dir / "config.json").exists():
arch_model_cfg["pretrained_llm"] = str(llm_backbone_dir)
arch = _detect_vllm_architecture(arch_model_cfg)
config_path = output_dir / "config.json"
config = json.loads(config_path.read_text())
config["model_type"] = "nemo_speechlm"
Expand Down Expand Up @@ -274,7 +307,7 @@ def main(cfg: HfExportConfig) -> None:

full_cfg = OmegaConf.to_container(OmegaConf.load(cfg.ckpt_config), resolve=True)
model_cfg = full_cfg["model"]
model_cfg["torch_dtype"] = cfg.dtype
model_cfg["torch_dtype"] = _canonical_torch_dtype_name(cfg.dtype)
cls = import_class_by_path(cfg.class_path)

strategy_cfg = full_cfg.get("trainer", {}).get("strategy", {})
Expand Down Expand Up @@ -317,9 +350,10 @@ def main(cfg: HfExportConfig) -> None:
model_cfg["init_configure_model"] = True
model = cls(model_cfg)
load_checkpoint(model, cfg.ckpt_path)
model = model.to(getattr(torch, cfg.dtype))
model = model.to(str_to_dtype(cfg.dtype))
model_cfg["pretrained_weights"] = False
model.save_pretrained(cfg.output_dir)
model.save_pretrained(cfg.output_dir, config=_hf_export_config(model, cfg.dtype))
save_llm_backbone_config(model, cfg.output_dir)
_try_prepare_for_vllm(cfg.output_dir, model_cfg)


Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/speechlm2/models/duplex_ear_tts.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,9 @@ def __init__(self, cfg: dict) -> None:
self.audio_codec_run_dtype = getattr(torch, self.cfg.get("audio_codec_run_dtype", "float32"), torch.float32)

# Load tokenizer
tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_lm_name
self.tokenizer = AutoTokenizer(
self.cfg.pretrained_lm_name,
tokenizer_src,
use_fast=True,
trust_remote_code=True,
bos_token=self.cfg.get("bos_token", None),
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/speechlm2/models/duplex_s2s_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def __init__(self, cfg: dict) -> None:
# pretrained LM head weights.
# However, for S2S we need to access the activations before LM head directly
# to feed them to the audio codec head.
self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True)
tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm
self.tokenizer = AutoTokenizer(tokenizer_src, use_fast=True)
llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights).train()
self.llm = llm.model # fetch PretrainedBaseModel from model "ForCausalLM"
self.lm_head = llm.lm_head
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def __init__(self, cfg: dict) -> None:
# pretrained LM head weights.
# However, for S2S we need to access the activations before LM head directly
# to feed them to the audio codec head.
self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True)
tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm
self.tokenizer = AutoTokenizer(tokenizer_src, use_fast=True)
llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights).train()
self.llm = llm.model # fetch PretrainedBaseModel from model "ForCausalLM"
self.lm_head = llm.lm_head
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/speechlm2/models/duplex_stt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def __init__(self, cfg: dict) -> None:
).train()

# Initialize tokenizer with optional special tokens from config
tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm
self.tokenizer = AutoTokenizer(
self.cfg.pretrained_llm,
tokenizer_src,
use_fast=True,
bos_token=self.cfg.get("bos_token", None),
eos_token=self.cfg.get("eos_token", None),
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/speechlm2/models/salm.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ def __init__(self, cfg) -> None:
self.cfg = DictConfig(cfg)
self.audio_locator_tag = self.cfg.audio_locator_tag

tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm
self.tokenizer = AutoTokenizer(
self.cfg.pretrained_llm, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False)
tokenizer_src, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False)
)
self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]})
self.llm = load_pretrained_hf(
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/speechlm2/models/salm_asr_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ def __init__(self, cfg) -> None:
self.cfg = DictConfig(cfg)
self.audio_locator_tag = self.cfg.audio_locator_tag

self.tokenizer = AutoTokenizer(self.cfg.pretrained_llm, use_fast=True)
tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm
self.tokenizer = AutoTokenizer(tokenizer_src, use_fast=True)
self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]})
self.llm = load_pretrained_hf(self.cfg.pretrained_llm, pretrained_weights=self.cfg.pretrained_weights)
if not hasattr(self.llm, "model") and hasattr(self.llm, "backbone"):
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/speechlm2/models/salm_automodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ def __init__(self, cfg) -> None:
self.cfg = DictConfig(cfg)
self.audio_locator_tag = self.cfg.audio_locator_tag

tokenizer_src = self.cfg.get("tokenizer_path", None) or self.cfg.pretrained_llm
self.tokenizer = AutoTokenizer(
self.cfg.pretrained_llm, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False)
tokenizer_src, use_fast=True, trust_remote_code=self.cfg.get("trust_remote_code", False)
)
self.tokenizer.add_special_tokens({"additional_special_tokens": [self.audio_locator_tag]})
self.llm = None # populated by configure_model
Expand Down
27 changes: 27 additions & 0 deletions nemo/collections/speechlm2/parts/hf_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from transformers.utils import cached_file

SAFETENSORS_SINGLE_FILE = "model.safetensors"
LLM_BACKBONE_DIR = "llm_backbone"


class HFHubMixin(
Expand Down Expand Up @@ -80,6 +81,7 @@ def _from_pretrained(
if resolved_config_file is None:
raise RuntimeError(f"Missing {CONFIG_NAME} file for {model_id=}")
model_kwargs['cfg'] = OmegaConf.to_container(OmegaConf.load(resolved_config_file))
_inject_local_artifact_paths(model_kwargs['cfg'], model_id, _cached_file_kwargs)
# The setting below tells the model's __init__ not to load the original pretrained weights
# for individual children modules.
# To illustrate: if you trained a new model M using a pretrained ASR and a pretrained LLM,
Expand Down Expand Up @@ -252,3 +254,28 @@ def _load_state_dict_with_dtensors(model, weight_dir):
# the planner narrows each tensor to the local DTensor shard,
# and copies directly into model parameter storage.
dcp.load(state_dict, storage_reader=reader)


def _inject_local_artifact_paths(cfg: dict, model_id: str, cached_file_kwargs: dict) -> None:
    """
    Point a loaded checkpoint config at tokenizer / LLM-config artifacts stored with the checkpoint.

    ``cfg`` is mutated in-place in two independent steps:

    * If ``tokenizer_config.json`` resolves for ``model_id`` and ``cfg`` has a ``pretrained_llm`` or
      ``pretrained_lm_name`` key, ``tokenizer_path`` is set to the directory holding that file.
    * If ``LLM_BACKBONE_DIR/CONFIG_NAME`` resolves, whichever of ``pretrained_llm`` and
      ``pretrained_lm_name`` are present are overwritten with the directory holding that file.

    Keys are left untouched when the corresponding file cannot be resolved.
    """
    llm_keys = ("pretrained_llm", "pretrained_lm_name")

    tokenizer_file = cached_file(model_id, "tokenizer_config.json", **cached_file_kwargs)
    if tokenizer_file is not None and any(key in cfg for key in llm_keys):
        cfg["tokenizer_path"] = str(Path(tokenizer_file).parent)

    llm_config_file = cached_file(model_id, f"{LLM_BACKBONE_DIR}/{CONFIG_NAME}", **cached_file_kwargs)
    if llm_config_file is not None:
        backbone_dir = str(Path(llm_config_file).parent)
        for key in llm_keys:
            if key in cfg:
                cfg[key] = backbone_dir
75 changes: 75 additions & 0 deletions tests/collections/speechlm2/test_hf_hub.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nemo.collections.speechlm2.parts.hf_hub import _inject_local_artifact_paths


def _cached_file_kwargs():
return {
"cache_dir": None,
"force_download": False,
"local_files_only": True,
"token": None,
"revision": None,
"_raise_exceptions_for_gated_repo": False,
"_raise_exceptions_for_missing_entries": False,
"_raise_exceptions_for_connection_errors": False,
}


def _write_local_export_artifacts(tmp_path):
(tmp_path / "tokenizer_config.json").write_text("{}")
(tmp_path / "llm_backbone").mkdir()
(tmp_path / "llm_backbone" / "config.json").write_text("{}")


def test_inject_local_artifact_paths_salm_config(tmp_path):
    """A config keyed by ``pretrained_llm`` is redirected; ``pretrained_asr`` stays as given."""
    _write_local_export_artifacts(tmp_path)
    root = str(tmp_path)
    cfg = dict(pretrained_llm="remote-llm", pretrained_asr="remote-asr")

    _inject_local_artifact_paths(cfg, root, _cached_file_kwargs())

    expected = {
        "pretrained_llm": str(tmp_path / "llm_backbone"),
        "pretrained_asr": "remote-asr",
        "tokenizer_path": root,
    }
    for key, value in expected.items():
        assert cfg[key] == value


def test_inject_local_artifact_paths_duplex_eartts_config(tmp_path):
    """A config keyed by ``pretrained_lm_name`` gets the backbone dir and a root ``tokenizer_path``."""
    _write_local_export_artifacts(tmp_path)
    root = str(tmp_path)
    cfg = dict(pretrained_lm_name="remote-llm", tts_config={})

    _inject_local_artifact_paths(cfg, root, _cached_file_kwargs())

    expected = {
        "pretrained_lm_name": str(tmp_path / "llm_backbone"),
        "tokenizer_path": root,
    }
    for key, value in expected.items():
        assert cfg[key] == value


def test_inject_local_artifact_paths_no_artifacts_keeps_old_config(tmp_path):
    """With no artifact files under ``tmp_path`` the config must come back unchanged."""
    untouched = {
        "pretrained_llm": "remote-llm",
        "pretrained_weights": True,
    }
    # Work on a copy so the comparison below is against pristine values.
    cfg = dict(untouched)

    _inject_local_artifact_paths(cfg, str(tmp_path), _cached_file_kwargs())

    assert cfg == untouched
69 changes: 69 additions & 0 deletions tests/collections/speechlm2/test_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
from unittest.mock import patch

import pytest
import torch
from safetensors.torch import load_file

_TO_HF_PATH = Path(__file__).parents[3] / "examples" / "speechlm2" / "to_hf.py"
_spec = importlib.util.spec_from_file_location("to_hf_for_test", _TO_HF_PATH)
Expand Down Expand Up @@ -101,6 +103,73 @@ def _seed_output_dir(tmp_path, llm_arch="Qwen2ForCausalLM"):
return tmp_path


class _FakeLLMConfig:
def save_pretrained(self, output_dir):
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
(output_dir / "config.json").write_text(
json.dumps(
{
"model_type": "qwen2",
"architectures": ["Qwen2ForCausalLM"],
"hidden_size": 2048,
}
)
)


class _FakeExportModel:
    """Minimal model double for export tests: a plain-dict ``cfg`` plus an ``llm`` with a ``config``."""

    # The dtype fields deliberately use the "bf16" alias; the export tests check what is written out.
    cfg = dict(
        pretrained_llm="fake-model",
        pretrained_asr="fake-asr",
        pretrained_weights=False,
        dtype="bf16",
        torch_dtype="bf16",
        audio_locator_tag=AUDIO_TOKEN,
    )
    # Anonymous object whose only attribute is ``config``.
    llm = type("_FakeLLM", (), dict(config=_FakeLLMConfig()))()


def test_save_hf_checkpoint_writes_llm_backbone_config(tmp_path):
    """Export writes the root config and a separate ``llm_backbone/config.json`` without touching ``cfg``."""
    export_cfg = to_hf.HfExportConfig(
        class_path="fake.Class",
        ckpt_path="fake.ckpt",
        ckpt_config="fake.yaml",
        output_dir=str(tmp_path),
        dtype="bfloat16",
    )
    weights = {"weight": torch.zeros(1)}
    to_hf.save_hf_checkpoint(_FakeExportModel(), weights, export_cfg)

    root_cfg = json.loads((tmp_path / "config.json").read_text())
    llm_cfg = json.loads((tmp_path / "llm_backbone" / "config.json").read_text())

    assert "llm_config" not in root_cfg
    expected_root = {
        "pretrained_llm": "fake-model",
        "dtype": "bfloat16",
        "torch_dtype": "bfloat16",
    }
    for key, value in expected_root.items():
        assert root_cfg[key] == value

    expected_llm = {
        "model_type": "qwen2",
        "architectures": ["Qwen2ForCausalLM"],
    }
    for key, value in expected_llm.items():
        assert llm_cfg[key] == value

    # The class-level training config must still hold its original alias.
    assert _FakeExportModel.cfg["dtype"] == "bf16"


def test_save_hf_checkpoint_accepts_bf16_export_dtype(tmp_path):
    """The ``"bf16"`` alias is accepted: config records ``bfloat16`` and weights are saved as bfloat16."""
    export_cfg = to_hf.HfExportConfig(
        class_path="fake.Class",
        ckpt_path="fake.ckpt",
        ckpt_config="fake.yaml",
        output_dir=str(tmp_path),
        dtype="bf16",
    )
    weights = {"weight": torch.zeros(1)}
    to_hf.save_hf_checkpoint(_FakeExportModel(), weights, export_cfg)

    root_cfg = json.loads((tmp_path / "config.json").read_text())
    saved = load_file(tmp_path / "model.safetensors")

    for key in ("dtype", "torch_dtype"):
        assert root_cfg[key] == "bfloat16"
    assert saved["weight"].dtype == torch.bfloat16


# ──────────────────────────────────────────────────────────────────────
# Error paths (no mocking required — checks run before any HF calls)
# ──────────────────────────────────────────────────────────────────────
Expand Down
Loading