In SkyRL v0.2.0 on the FSDP1 path:

- `HFModelWrapper.__init__` and `TrainerConfig` default `bf16` to `True`.
- `should_use_meta_init` defaults non-rank-0 ranks to `meta_init=True`.

Together these leave the rotary embedding's non-persistent `inv_freq` buffer in uninitialized memory on every rank except rank 0. When sequence parallelism is engaged (`sequence_parallel_size > 1`), the corrupted buffer produces NaN at every position of every non-rank-0 rank's forward pass, which surfaces downstream as `policy_kl: nan` and `final_loss: nan` at step 1.
This is closely related to huggingface/transformers#45902, but on the FSDP1 + dense Qwen3 path. In short, transformers expects callers to use `model.initialize_weights()`, which SkyRL does not do.
Notably, the FSDP2 init path in SkyRL uses _sync_non_persistent_buffers, which itself has the same root-cause bug: on rank 0 it just does buf.detach().cuda(), which on a meta-resident buffer reads uninitialized GPU memory and then broadcasts that garbage to every rank.
Mechanism
See HFModelWrapper.__init__ under meta_init=True.
- On the meta-init ranks, `model.to(bf16)` casts every parameter and non-persistent buffer to bf16 on the meta device, including `Qwen3RotaryEmbedding.inv_freq`.
- Rank 0 instead takes the `from_pretrained(torch_dtype=...)` branch, which keeps `inv_freq` at fp32 regardless of the `dtype` kwarg.
- `FSDPStrategy.prepare(wrapped)` wraps the model with `param_init_fn=init_fn` plus `sync_module_states=True`. On non-rank-0 ranks, `init_fn` calls `x.to_empty(device=cuda, recurse=False)`, which allocates real GPU memory but does not initialize its values. Non-rank-0 ranks therefore hold a bf16 slot for `inv_freq` filled with whatever was previously in that memory region.
- `sync_module_states` then tries to broadcast rank 0's fp32 `inv_freq` into those bf16 slots. The dtype mismatch makes the broadcast fail silently, so rank 0's values never get written into the non-rank-0 bf16 slots. Each slot keeps the uninitialized `to_empty()` memory it held before the broadcast call.
- At forward time, transformers' `Qwen3RotaryEmbedding` multiplies the garbage `inv_freq` by the position ids, and the final logits end up NaN on ranks 1..N-1.
Reproducer
import argparse
import os
import sys
import torch
import torch.distributed as dist
from skyrl.backends.skyrl_train.distributed.fsdp_strategy import FSDPStrategy
from skyrl.backends.skyrl_train.distributed.fsdp_utils import should_use_meta_init
from skyrl.backends.skyrl_train.distributed.ulysses import (
apply_monkey_patch,
set_ulysses_sequence_parallel_group,
)
from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper
from skyrl.train.config import FSDPConfig
def parse_args() -> argparse.Namespace:
    """Build the reproducer's command-line parser and parse ``sys.argv``.

    Flags:
        --model:   model id handed to ``HFModelWrapper`` (default Qwen/Qwen3-1.7B).
        --sp-size: Ulysses sequence-parallel degree (int, default 2).
        --seq-len: length of the synthetic token sequence (int, default 128).
        --dtypes:  comma-separated load dtypes to compare (default "bf16,fp32").

    Returns:
        The parsed namespace (attributes ``model``, ``sp_size``, ``seq_len``,
        ``dtypes``).
    """
    cli = argparse.ArgumentParser(description=__doc__)
    # (flag, keyword arguments for add_argument), in declaration order.
    flag_specs = (
        ("--model", {"default": "Qwen/Qwen3-1.7B"}),
        (
            "--sp-size",
            {
                "type": int,
                "default": 2,
                "help": "Ulysses sequence-parallel size. >=2 to fire the bug. "
                "world_size must be divisible by sp_size.",
            },
        ),
        (
            "--seq-len",
            {
                "type": int,
                "default": 128,
                "help": "Tokens per sequence. Any value >=128 fires the bug.",
            },
        ),
        (
            "--dtypes",
            {
                "default": "bf16,fp32",
                "help": "Comma-separated dtypes to test. bf16 (default) reproduces; "
                "fp32 is the workaround.",
            },
        ),
    )
    for flag, options in flag_specs:
        cli.add_argument(flag, **options)
    parsed = cli.parse_args()
    return parsed
def is_rank0() -> bool:
    """Return True on the rank-0 process, or when no process group exists.

    Without an initialized torch.distributed group the caller is treated as
    the only process, and hence as rank 0.
    """
    if dist.is_initialized():
        return dist.get_rank() == 0
    return True
def rprint(msg: str) -> None:
    """Print ``msg`` (flushed immediately) only from the rank-0 process."""
    if not is_rank0():
        return
    print(msg, flush=True)
def probe_inv_freq(wrapped: HFModelWrapper, label: str) -> None:
    """Per-rank dump of the rotary inv_freq buffer (the smoking gun).

    Every rank prints (not just rank 0) one line per buffer of
    ``wrapped.model`` whose name ends in ``inv_freq``, skipping names that
    contain ``original``. Each line shows the buffer's dtype, shape, NaN count
    and the first five flattened-to-list values.

    Args:
        wrapped: object exposing the underlying ``torch.nn.Module`` as
            ``.model``.
        label: tag included in each line (e.g. "bf16" / "fp32").
    """
    rank = 0
    if dist.is_initialized():
        rank = dist.get_rank()
    rope_buffers = (
        (buf_name, tensor)
        for buf_name, tensor in wrapped.model.named_buffers()
        if name_is_rope(buf_name)
    ) if False else (
        (buf_name, tensor)
        for buf_name, tensor in wrapped.model.named_buffers()
        if buf_name.endswith("inv_freq") and "original" not in buf_name
    )
    for buf_name, tensor in rope_buffers:
        nan_count = int(torch.isnan(tensor).sum().item())
        # Copy to host as fp32 so bf16 garbage prints as plain Python floats.
        head = tensor.detach().float().cpu().tolist()[:5]
        print(
            f"[probe_inv_freq {label} rank={rank}] name={buf_name!r} "
            f"dtype={tensor.dtype} shape={tuple(tensor.shape)} n_nan={nan_count} "
            f"first5={head}",
            flush=True,
        )
def load_and_forward(
    *, model_path: str, bf16: bool, sequences: torch.Tensor,
    attention_mask: torch.Tensor, sp_size: int,
) -> torch.Tensor:
    """Mirror SkyRL's ref-worker init + forward, returning the full log_probs.

    Builds a fresh ``FSDPStrategy`` and ``HFModelWrapper`` for ``model_path``,
    optionally installs the Ulysses sequence-parallel patch, wraps the model
    with ``strategy.prepare``, dumps the rotary ``inv_freq`` buffer on every
    rank, and runs one no-grad forward under bf16 autocast.

    Every rank must call this with the same arguments and in the same order:
    it creates a device mesh (new process groups) and FSDP-wraps the model,
    which are cross-rank operations.

    Args:
        model_path: model id or path passed to ``HFModelWrapper``.
        bf16: forwarded as ``HFModelWrapper(bf16=...)``; ``True`` exercises the
            failing path, ``False`` the fp32 workaround.
        sequences: token ids; the caller builds shape (1, seq_len) — other
            shapes untested here.
        attention_mask: mask, presumably the same shape as ``sequences``.
        sp_size: Ulysses sequence-parallel size; values > 1 build a
            (dp, sp) device mesh and apply the monkey patch.

    Returns:
        This rank's log-probs, detached, cast to float32 and moved to CPU.
    """
    # Tag used in the inv_freq probe output.
    label = "bf16" if bf16 else "fp32"
    # "fsdp" presumably selects the FSDP1 code path (vs. FSDP2) — the path
    # this reproducer targets.
    strategy = FSDPStrategy(
        fsdp_config=FSDPConfig(),
        fsdp_strategy="fsdp",
        seed=42,
        micro_train_batch_size_per_gpu=1,
    )
    strategy.setup_distributed()
    # NOTE(review): per the accompanying report, this yields meta_init=True on
    # non-rank-0 ranks, so they build the model on the meta device.
    use_meta = should_use_meta_init(
        use_meta_tensor=True, mesh=strategy.device_mesh,
    )
    wrapped = HFModelWrapper(
        model_path,
        use_flash_attention_2=True,
        bf16=bf16,
        sequence_parallel_size=sp_size,
        use_sample_packing=True,
        meta_init=use_meta,
    )
    # Ulysses setup (mirrors WorkerBase._seq_parallel_monkey_patch).
    if sp_size > 1:
        ws = dist.get_world_size()
        # Assumes ws is divisible by sp_size; otherwise the mesh shape
        # does not cover all ranks and init_device_mesh is expected to fail.
        dp_size = ws // sp_size
        worker_mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp"),
        )
        set_ulysses_sequence_parallel_group(worker_mesh["sp"].get_group())
        apply_monkey_patch(model=wrapped.model, ulysses_sp_size=sp_size)
    fsdp_model = strategy.prepare(wrapped)
    fsdp_model.eval()
    # Probe after FSDP wrapping, i.e. after any module-state sync, so the
    # printed inv_freq is what the forward below will actually use.
    probe_inv_freq(wrapped, label)
    # bf16 autocast is used for both dtype runs, so any bf16/fp32 difference
    # comes from how the model was loaded, not from the forward's precision.
    with torch.no_grad(), torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
        # num_actions = seq_len-1 so we read the whole log_probs range.
        log_probs = fsdp_model(
            sequences, sequences.shape[1] - 1, attention_mask,
            temperature=1.0, return_output=False,
        )
    return log_probs.detach().float().cpu()
def _parse_dtype_list(spec: str) -> list[str]:
    """Split a comma-separated ``--dtypes`` value into validated dtype names.

    Surrounding whitespace and blank entries (e.g. from a trailing comma) are
    ignored. Duplicates are dropped while keeping first-seen order, so each
    dtype is loaded and tested at most once.

    Raises:
        SystemExit: if an entry is neither ``bf16`` nor ``fp32``, or if no
            entry remains.
    """
    names = [part.strip() for part in spec.split(",") if part.strip()]
    for name in names:
        if name not in {"bf16", "fp32"}:
            raise SystemExit(f"unknown dtype: {name}")
    if not names:
        raise SystemExit(f"no dtypes given in --dtypes={spec!r}")
    return list(dict.fromkeys(names))


def _max_across_ranks(value: int) -> int:
    """Return the maximum of ``value`` over all ranks.

    Returns ``value`` unchanged when torch.distributed is not initialized.
    Otherwise this is a collective: every rank must call it, in the same order.
    """
    if not dist.is_initialized():
        return value
    # NCCL only reduces CUDA tensors; other backends (e.g. gloo) take CPU ones.
    device = (
        torch.device("cuda", torch.cuda.current_device())
        if dist.get_backend() == "nccl"
        else torch.device("cpu")
    )
    reduced = torch.tensor([value], dtype=torch.long, device=device)
    dist.all_reduce(reduced, op=dist.ReduceOp.MAX)
    return int(reduced.item())


def main() -> int:
    """Compare bf16 vs fp32 model loading and report NaNs in the log-probs.

    Expects one process per GPU (reads ``LOCAL_RANK``, initializes an NCCL
    process group). For each requested dtype it loads the model, runs one
    forward on a shared random batch, and rank 0 prints the NaN count.

    The reported count is the maximum over all ranks. The corruption lives on
    non-rank-0 ranks while only rank 0 prints, so a rank-0-local count could
    miss it.

    Returns:
        0 on success.

    Raises:
        SystemExit: on an invalid ``--dtypes`` or ``--sp-size`` value.
    """
    args = parse_args()
    # Validate before touching GPUs so a typo doesn't cost a full model load.
    dtype_names = _parse_dtype_list(args.dtypes)
    # Bind this process to its GPU first so NCCL communicators use it.
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    dist.init_process_group(backend="nccl")
    try:
        world_size = dist.get_world_size()
        # Arguments are identical on every rank, so all ranks exit together.
        if args.sp_size < 1 or world_size % args.sp_size != 0:
            raise SystemExit(
                f"--sp-size={args.sp_size} must be >= 1 and divide "
                f"world_size={world_size}"
            )
        rprint(
            f"[repro] model={args.model} sp_size={args.sp_size} "
            f"seq_len={args.seq_len} world_size={world_size}"
        )
        # Build a tiny synthetic batch on rank 0, broadcast to others. Content
        # doesn't matter — the bug is in init, not in numerics.
        device = torch.cuda.current_device()
        sequences = torch.randint(
            10, 10000, (1, args.seq_len), dtype=torch.long, device=device,
        )
        dist.broadcast(sequences, src=0)
        attention_mask = torch.ones_like(sequences)
        nan_counts: dict[str, int] = {}
        for dtype_name in dtype_names:
            log_probs = load_and_forward(
                model_path=args.model,
                bf16=(dtype_name == "bf16"),
                sequences=sequences,
                attention_mask=attention_mask,
                sp_size=args.sp_size,
            )
            # All ranks join the reduction so rank 0 reports the worst rank,
            # not just its own output.
            n_nan = _max_across_ranks(int(torch.isnan(log_probs).sum()))
            nan_counts[dtype_name] = n_nan
            rprint(
                f"[{dtype_name:>4}] log_probs shape={tuple(log_probs.shape)} "
                f"n_nan={n_nan}/{log_probs.numel()}"
            )
        if "bf16" in nan_counts and "fp32" in nan_counts:
            n_nan_bf16 = nan_counts["bf16"]
            n_nan_fp32 = nan_counts["fp32"]
            bug = n_nan_bf16 > 0 and n_nan_fp32 == 0
            rprint(
                f"\n[SUMMARY_1694] n_nan_bf16={n_nan_bf16} n_nan_fp32={n_nan_fp32} "
                f"bug_reproduces={bug}"
            )
        dist.barrier()
    finally:
        # Tear the group down even on error so NCCL resources are released.
        dist.destroy_process_group()
    return 0
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
Expected output:
[probe_inv_freq bf16 rank=0] dtype=torch.float32 first5=[1.0, 0.806, ...] ← correct
[probe_inv_freq bf16 rank=1..N] dtype=torch.bfloat16 first5=[...garbage...] ← uninitialized
[bf16] log_probs n_nan=N/N ← every position NaN
[probe_inv_freq fp32 rank=0..N] dtype=torch.float32 first5=[1.0, 0.806, ...] ← all ranks correct
[fp32] log_probs n_nan=0/N ← clean
[SUMMARY_1694] bug_reproduces=True
Actual output with torch==2.11.0, SkyRL v0.2.0, transformers==5.8.0, and Python 3.12:
[probe_inv_freq bf16 rank=0] name='model.rotary_emb.inv_freq' dtype=torch.float32 shape=(64,) n_nan=0 first5=[1.0, 0.8058422207832336, 0.6493816375732422, 0.5232991576194763, 0.4216965138912201]
[probe_inv_freq bf16 rank=1] name='model.rotary_emb.inv_freq' dtype=torch.bfloat16 shape=(64,) n_nan=1 first5=[0.0, 1.0, 22675456.0, 0.8046875, 0.109375]
[probe_inv_freq bf16 rank=2] name='model.rotary_emb.inv_freq' dtype=torch.bfloat16 shape=(64,) n_nan=1 first5=[0.0, 1.0, 22675456.0, 0.8046875, 0.109375]
... (ranks 3-15 identical to rank 1: bf16 with 1 NaN and the same garbage values)
[bf16] log_probs shape=(1, 127) n_nan=127/127
[probe_inv_freq fp32 rank=0] name='model.rotary_emb.inv_freq' dtype=torch.float32 shape=(64,) n_nan=0 first5=[1.0, 0.8058422207832336, 0.6493816375732422, 0.5232991576194763, 0.4216965138912201]
[probe_inv_freq fp32 rank=1] name='model.rotary_emb.inv_freq' dtype=torch.float32 shape=(64,) n_nan=0 first5=[1.0, 0.8058422207832336, 0.6493816375732422, 0.5232991576194763, 0.4216965138912201]
... (ranks 2-15 identical to rank 0: fp32 with the correct rope frequencies)
[fp32] log_probs shape=(1, 127) n_nan=0/127
[SUMMARY_1694] n_nan_bf16=127 n_nan_fp32=0 bug_reproduces=True
Notice that the bf16-path values on ranks 1+ are byte-identical across all non-rank-0 ranks (`22675456.0` recurs verbatim). Every worker process reads the same uninitialized GPU memory layout from `to_empty()`, at the same point in the allocator's free list.
Suggested fix
- Fix `_sync_non_persistent_buffers` so that it calls `model.initialize_weights()` on rank 0 before broadcasting.
- Invoke the fixed `_sync_non_persistent_buffers` from the FSDP1 init path too, not just the FSDP2 path.
In SkyRL v0.2.0 on the FSDP1 path:
- `HFModelWrapper.__init__` and `TrainerConfig`: default `bf16` to `True`.
- `should_use_meta_init`: defaults non-rank-0 ranks to `meta_init=True`.

Together these leave the rotary embedding's non-persistent `inv_freq` buffer in uninitialized memory on every rank except rank 0. When sequence parallelism is engaged (`sequence_parallel_size > 1`), the corrupted buffer produces NaN at every position of every non-rank-0 rank's forward pass, which surfaces downstream as `policy_kl: nan` and `final_loss: nan` at step 1.

This is closely related to huggingface/transformers#45902, but on the FSDP1 + dense Qwen3 path. In short, `transformers` expects callers to use `model.initialize_weights()`, which SkyRL does not do.

Notably, the FSDP2 init path in SkyRL uses `_sync_non_persistent_buffers`, which itself has the same root-cause bug. On rank 0 it just does `buf.detach().cuda()`, which on a meta-resident buffer reads uninitialized GPU memory and then broadcasts that garbage to every rank.

Mechanism
See `HFModelWrapper.__init__` under `meta_init=True`.
- `model.to(bf16)` casts every parameter and non-persistent buffer to bf16 on the meta device, including `Qwen3RotaryEmbedding.inv_freq`.
- Rank 0 takes the `from_pretrained(torch_dtype=...)` branch, which keeps `inv_freq` at fp32 regardless of the `dtype` kwarg.
- `FSDPStrategy.prepare(wrapped)` wraps with `param_init_fn=init_fn` plus `sync_module_states=True`. On non-rank-0 ranks, `init_fn` calls `x.to_empty(device=cuda, recurse=False)`, which allocates real GPU memory but does not initialize the values.
- Non-rank-0 ranks now have a bf16 slot for `inv_freq` filled with whatever was previously in that memory region.
- `sync_module_states` then tries to broadcast rank 0's fp32 `inv_freq` into those bf16 slots. The dtype mismatch makes the broadcast fail silently, so rank 0's values never get written into the non-rank-0 bf16 slots.
- The slot keeps the uninitialized `to_empty()` memory that was there before the broadcast call.
- At forward time, transformers' `Qwen3RotaryEmbedding` multiplies the garbage `inv_freq` by the position ids, and the final logits end up NaN on ranks 1..N-1.

Reproducer
Expected output:
Actual output with `torch==2.11.0`, SkyRL v0.2.0, `transformers==5.8.0`, and Python 3.12:

Notice that the bf16-path values on ranks 1+ are byte-identical across all non-rank-0 ranks (`22675456.0` recurs verbatim). Every worker process reads the same uninitialized GPU memory layout from `to_empty()`, at the same point in the allocator's free list.

Suggested fix
- Fix `_sync_non_persistent_buffers` so that it calls `model.initialize_weights()` on rank 0 before broadcasting.
- Invoke the fixed `_sync_non_persistent_buffers` from the FSDP1 init path too, not just the FSDP2 path.