FSDP1 ref worker leaves non-rank-0 inv_freq uninitialized under meta_init=True + bf16=True, producing NaN forward under SP>1 #1709

@jamesbraza


In SkyRL v0.2.0 on the FSDP1 path, the ref worker is initialized with meta_init=True and bf16=True.

Together these leave the rotary embedding's non-persistent inv_freq buffer in uninitialized memory on every rank except rank 0. When sequence parallelism is engaged (sequence_parallel_size > 1) the corrupted buffer produces NaN at every position of every non-rank-0 rank's forward, which surfaces downstream as policy_kl: nan and final_loss: nan at step 1.

This is closely related to huggingface/transformers#45902, just on the FSDP1 + dense Qwen3 path. Basically transformers wants callers to use model.initialize_weights() (which SkyRL is not doing).

Notably, the FSDP2 init path in SkyRL uses _sync_non_persistent_buffers, which itself has the same root-cause bug: on rank 0 it just does buf.detach().cuda(), which on a meta-resident buffer reads uninitialized GPU memory and then broadcasts that garbage to every rank.

Mechanism

See HFModelWrapper.__init__ under meta_init=True.

  1. model.to(bf16) casts every parameter and non-persistent buffer to bf16 on the meta device.
    • Including Qwen3RotaryEmbedding.inv_freq.
  2. Rank 0 takes the from_pretrained(torch_dtype=...) branch, which keeps inv_freq at fp32 regardless of the dtype kwarg.
  3. FSDPStrategy.prepare(wrapped) wraps with param_init_fn=init_fn plus sync_module_states=True. init_fn on non-rank-0 ranks calls x.to_empty(device=cuda, recurse=False), which allocates real GPU memory, but does not initialize the values.
    • Non-rank-0 ranks now have a bf16 slot for inv_freq filled with whatever was previously in that memory region.
  4. sync_module_states then tries to broadcast rank-0's fp32 inv_freq to those bf16 slots; the dtype mismatch makes the broadcast silently fail, and rank-0's values never get written into non-rank-0's bf16 slot.
    • The slot retains the uninitialized to_empty() memory that was there before the broadcast call.
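
Steps 1 and 3 can be reproduced on CPU without FSDP. The sketch below is a minimal stand-in, not SkyRL or transformers code: TinyRotary is a hypothetical module, and head_dim=128 and base=1e6 are assumptions chosen to match the shape=(64,) buffer and the rank-0 values reported in the output further down. It shows that a non-persistent buffer is absent from state_dict(), that to_empty() allocates a bf16 slot without meaningful values, and that the values only come back by recomputing them.

```python
import torch
from torch import nn


class TinyRotary(nn.Module):
    """Hypothetical stand-in for a rotary embedding (not the transformers class)."""

    def __init__(self, head_dim: int = 128, base: float = 1_000_000.0) -> None:
        super().__init__()
        self.head_dim = head_dim
        self.base = base
        # persistent=False keeps the buffer out of state_dict(), so loading a
        # checkpoint never restores it.
        self.register_buffer("inv_freq", self.compute_inv_freq(), persistent=False)

    def compute_inv_freq(self) -> torch.Tensor:
        exponents = torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim
        return 1.0 / (self.base ** exponents)


rotary = TinyRotary().to("meta")    # meta init: shapes only, no data
rotary = rotary.to(torch.bfloat16)  # step 1: buffer becomes bf16 on meta
rotary.to_empty(device="cpu")       # step 3: real storage, arbitrary contents

assert "inv_freq" not in rotary.state_dict()
assert rotary.inv_freq.dtype == torch.bfloat16
assert not rotary.inv_freq.is_meta

# Nothing in a checkpoint can fill this slot; recompute it from the config.
rotary.inv_freq = rotary.compute_inv_freq()
print(rotary.inv_freq.dtype, rotary.inv_freq[:3].tolist())
```

The sketch recomputes in fp32 on purpose: rank 0 keeps inv_freq at fp32, so a bf16 slot on the other ranks would differ from rank 0 even if it held sensible values.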

At forward time, transformers' Qwen3RotaryEmbedding multiplies the garbage inv_freq with the position ids, and ultimately the final logits are NaN on ranks 1..N-1.
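
To see why a single bad entry is enough, the following sketch mimics the general shape of a rotary computation (outer product of positions and inv_freq, then cos/sin). It is a simplified illustration, not the transformers implementation. The garbage values are copied from the rank-1 probe in the output below, with one NaN appended at an arbitrary index, since the probe reports n_nan=1 but not where it sits.

```python
import torch

# First five values reported by the rank-1 probe, plus one NaN placed at an
# arbitrary index (an assumption: the probe only reports the NaN count).
garbage = torch.tensor(
    [0.0, 1.0, 22675456.0, 0.8046875, 0.109375, float("nan")],
    dtype=torch.bfloat16,
)
positions = torch.arange(8, dtype=torch.float32)

# Simplified rotary angles: one row per position, one column per frequency.
angles = torch.outer(positions, garbage.float())
cos, sin = angles.cos(), angles.sin()

# A NaN frequency poisons its column for every position, including
# position 0, because 0 * NaN is NaN.
nan_rows = torch.isnan(cos).any(dim=1)
print(nan_rows.tolist())
```

Because the rotated query and key vectors feed a dot product over the head dimension, one NaN column is enough to make every attention score NaN, which fits the n_nan=127/127 result.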

Reproducer

import argparse
import os
import sys

import torch
import torch.distributed as dist
from skyrl.backends.skyrl_train.distributed.fsdp_strategy import FSDPStrategy
from skyrl.backends.skyrl_train.distributed.fsdp_utils import should_use_meta_init
from skyrl.backends.skyrl_train.distributed.ulysses import (
    apply_monkey_patch,
    set_ulysses_sequence_parallel_group,
)
from skyrl.backends.skyrl_train.workers.model_wrapper import HFModelWrapper
from skyrl.train.config import FSDPConfig


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description=__doc__)
    p.add_argument("--model", default="Qwen/Qwen3-1.7B")
    p.add_argument(
        "--sp-size", type=int, default=2,
        help="Ulysses sequence-parallel size. >=2 to fire the bug. "
             "world_size must be divisible by sp_size.",
    )
    p.add_argument(
        "--seq-len", type=int, default=128,
        help="Tokens per sequence. Any value >=128 fires the bug.",
    )
    p.add_argument(
        "--dtypes", default="bf16,fp32",
        help="Comma-separated dtypes to test. bf16 (default) reproduces; "
             "fp32 is the workaround.",
    )
    return p.parse_args()


def is_rank0() -> bool:
    return not dist.is_initialized() or dist.get_rank() == 0


def rprint(msg: str) -> None:
    if is_rank0():
        print(msg, flush=True)


def probe_inv_freq(wrapped: HFModelWrapper, label: str) -> None:
    """Per-rank dump of the rotary inv_freq buffer (the smoking gun)."""
    rank = dist.get_rank() if dist.is_initialized() else 0
    for name, buf in wrapped.model.named_buffers():
        if not name.endswith("inv_freq") or "original" in name:
            continue
        n_nan = int(torch.isnan(buf).sum().item())
        first5 = buf.detach().float().cpu().tolist()[:5]
        print(
            f"[probe_inv_freq {label} rank={rank}] name={name!r} "
            f"dtype={buf.dtype} shape={tuple(buf.shape)} n_nan={n_nan} "
            f"first5={first5}",
            flush=True,
        )


def load_and_forward(
    *, model_path: str, bf16: bool, sequences: torch.Tensor,
    attention_mask: torch.Tensor, sp_size: int,
) -> torch.Tensor:
    """Mirror SkyRL's ref-worker init + forward, returning the full log_probs."""
    label = "bf16" if bf16 else "fp32"
    strategy = FSDPStrategy(
        fsdp_config=FSDPConfig(),
        fsdp_strategy="fsdp",
        seed=42,
        micro_train_batch_size_per_gpu=1,
    )
    strategy.setup_distributed()

    use_meta = should_use_meta_init(
        use_meta_tensor=True, mesh=strategy.device_mesh,
    )
    wrapped = HFModelWrapper(
        model_path,
        use_flash_attention_2=True,
        bf16=bf16,
        sequence_parallel_size=sp_size,
        use_sample_packing=True,
        meta_init=use_meta,
    )

    # Ulysses setup (mirrors WorkerBase._seq_parallel_monkey_patch).
    if sp_size > 1:
        ws = dist.get_world_size()
        dp_size = ws // sp_size
        worker_mesh = torch.distributed.device_mesh.init_device_mesh(
            "cuda", mesh_shape=(dp_size, sp_size), mesh_dim_names=("dp", "sp"),
        )
        set_ulysses_sequence_parallel_group(worker_mesh["sp"].get_group())
        apply_monkey_patch(model=wrapped.model, ulysses_sp_size=sp_size)

    fsdp_model = strategy.prepare(wrapped)
    fsdp_model.eval()
    probe_inv_freq(wrapped, label)

    with torch.no_grad(), torch.autocast(dtype=torch.bfloat16, device_type="cuda"):
        # num_actions = seq_len-1 so we read the whole log_probs range.
        log_probs = fsdp_model(
            sequences, sequences.shape[1] - 1, attention_mask,
            temperature=1.0, return_output=False,
        )
    return log_probs.detach().float().cpu()


def main() -> int:
    args = parse_args()
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    rprint(
        f"[repro] model={args.model} sp_size={args.sp_size} "
        f"seq_len={args.seq_len} world_size={dist.get_world_size()}"
    )

    # Build a tiny synthetic batch on rank 0, broadcast to others. Content
    # doesn't matter — the bug is in init, not in numerics.
    device = torch.cuda.current_device()
    sequences = torch.randint(
        10, 10000, (1, args.seq_len), dtype=torch.long, device=device,
    )
    dist.broadcast(sequences, src=0)
    attention_mask = torch.ones_like(sequences)

    results: dict[str, torch.Tensor] = {}
    for dtype_name in [d.strip() for d in args.dtypes.split(",")]:
        if dtype_name not in {"bf16", "fp32"}:
            raise SystemExit(f"unknown dtype: {dtype_name}")
        results[dtype_name] = load_and_forward(
            model_path=args.model,
            bf16=(dtype_name == "bf16"),
            sequences=sequences,
            attention_mask=attention_mask,
            sp_size=args.sp_size,
        )
        log_probs = results[dtype_name]
        n_nan = int(torch.isnan(log_probs).sum())
        rprint(
            f"[{dtype_name:>4}] log_probs shape={tuple(log_probs.shape)} "
            f"n_nan={n_nan}/{log_probs.numel()}"
        )

    if "bf16" in results and "fp32" in results:
        n_nan_bf16 = int(torch.isnan(results["bf16"]).sum())
        n_nan_fp32 = int(torch.isnan(results["fp32"]).sum())
        bug = n_nan_bf16 > 0 and n_nan_fp32 == 0
        rprint(
            f"\n[SUMMARY_1694] n_nan_bf16={n_nan_bf16} n_nan_fp32={n_nan_fp32} "
            f"bug_reproduces={bug}"
        )

    dist.barrier()
    dist.destroy_process_group()
    return 0


if __name__ == "__main__":
    sys.exit(main())

Expected output:

[probe_inv_freq bf16 rank=0]    dtype=torch.float32  first5=[1.0, 0.806, ...]   ← correct
[probe_inv_freq bf16 rank=1..N] dtype=torch.bfloat16 first5=[...garbage...]      ← uninitialized
[bf16]  log_probs  n_nan=N/N    ← every position NaN
[probe_inv_freq fp32 rank=0..N] dtype=torch.float32  first5=[1.0, 0.806, ...]   ← all ranks correct
[fp32]  log_probs  n_nan=0/N    ← clean
[SUMMARY_1694] bug_reproduces=True

Actual output with torch==2.11.0, SkyRL v0.2.0, transformers==5.8.0, and Python 3.12:

[probe_inv_freq bf16 rank=0]  name='model.rotary_emb.inv_freq' dtype=torch.float32  shape=(64,) n_nan=0 first5=[1.0, 0.8058422207832336, 0.6493816375732422, 0.5232991576194763, 0.4216965138912201]
[probe_inv_freq bf16 rank=1]  name='model.rotary_emb.inv_freq' dtype=torch.bfloat16 shape=(64,) n_nan=1 first5=[0.0, 1.0, 22675456.0, 0.8046875, 0.109375]
[probe_inv_freq bf16 rank=2]  name='model.rotary_emb.inv_freq' dtype=torch.bfloat16 shape=(64,) n_nan=1 first5=[0.0, 1.0, 22675456.0, 0.8046875, 0.109375]
... (ranks 3-15 identical to rank 1: bf16 with 1 NaN and the same garbage values)
[bf16]  log_probs shape=(1, 127) n_nan=127/127

[probe_inv_freq fp32 rank=0]  name='model.rotary_emb.inv_freq' dtype=torch.float32 shape=(64,) n_nan=0 first5=[1.0, 0.8058422207832336, 0.6493816375732422, 0.5232991576194763, 0.4216965138912201]
[probe_inv_freq fp32 rank=1]  name='model.rotary_emb.inv_freq' dtype=torch.float32 shape=(64,) n_nan=0 first5=[1.0, 0.8058422207832336, 0.6493816375732422, 0.5232991576194763, 0.4216965138912201]
... (ranks 2-15 identical to rank 0: fp32 with the correct rope frequencies)
[fp32]  log_probs shape=(1, 127) n_nan=0/127

[SUMMARY_1694] n_nan_bf16=127 n_nan_fp32=0 bug_reproduces=True

Notice that the bf16-path values are byte-identical across all non-rank-0 ranks (22675456.0 recurs verbatim). This is consistent with to_empty() handing back the same uninitialized GPU memory layout on every worker process, at the same point in each allocator's free list.

Suggested fix

  1. Fix _sync_non_persistent_buffers so it calls model.initialize_weights() on rank 0 before broadcasting.
  2. Invoke the fixed _sync_non_persistent_buffers from the FSDP1 init path too, not just the FSDP2 path.
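
As a starting point for either change, the buffers that need recomputation can be found generically: they show up in named_buffers() but not in state_dict(). The helper below is a hypothetical sketch (non_persistent_buffers is not a SkyRL or PyTorch API) and assumes a plain module before FSDP wrapping, where buffer names and state_dict keys line up. It reports each such buffer's dtype and whether it is still on the meta device, which could be compared across ranks before any broadcast.

```python
import torch
from torch import nn


def non_persistent_buffers(model: nn.Module) -> dict[str, tuple[torch.dtype, bool]]:
    """Map each non-persistent buffer name to (dtype, is_meta)."""
    persistent = set(model.state_dict().keys())
    return {
        name: (buf.dtype, buf.is_meta)
        for name, buf in model.named_buffers()
        if name not in persistent
    }


# Tiny demo module; "inv_freq" and "running_mean" are illustrative names.
demo = nn.Module()
demo.register_buffer("inv_freq", torch.ones(4), persistent=False)
demo.register_buffer("running_mean", torch.zeros(4))  # persistent by default
demo = demo.to("meta").to(torch.bfloat16)

report = non_persistent_buffers(demo)
print(report)
```

Any entry with is_meta=True, or with a dtype that differs from rank 0's, marks a buffer that has to be recomputed rather than copied or broadcast as-is.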
