Skip to content

Fix meta-init bf16 cast corrupting rotary inv_freq under sequence parallelism#1725

Open
jamesbraza wants to merge 3 commits into
NovaSky-AI:mainfrom
EdisonScientific:fix/meta-init-inv-freq-bf16-cast
Open

Fix meta-init bf16 cast corrupting rotary inv_freq under sequence parallelism#1725
jamesbraza wants to merge 3 commits into
NovaSky-AI:mainfrom
EdisonScientific:fix/meta-init-inv-freq-bf16-cast

Conversation

@jamesbraza
Copy link
Copy Markdown
Contributor

Meta-init cast non-persistent buffers (rotary inv_freq) to bf16 while rank-0's from_pretrained kept them fp32, so the init-time buffer broadcast reinterpreted rank-0's fp32 bytes as bf16 garbage and produced NaN attention under SP>1. Now we cast only params and persistent buffers, matching transformers from_pretrained.

Closes #1709

Copy link
Copy Markdown
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request modifies the meta-initialization logic in model_wrapper.py to cast only parameters and persistent buffers to the target dtype, preventing corruption of non-persistent buffers (like inv_freq in rotary embeddings) during FSDP initialization. It also adds corresponding unit and integration tests. The reviewer identified critical issues, including the use of a non-existent PyTorch method named_non_persistent_buffers() which will cause runtime AttributeErrors, the use of a non-existent model (Qwen/Qwen3-0.6B) in tests, and a similar blanket .to cast issue in fsdp_worker.py that needs to be addressed.

Comment thread skyrl/backends/skyrl_train/workers/model_wrapper.py
Comment thread tests/backends/skyrl_train/models/test_models.py Outdated
Comment thread skyrl/backends/skyrl_train/workers/model_wrapper.py
Comment thread tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py Outdated
@SumanthRH SumanthRH self-requested a review May 28, 2026 23:45
@SumanthRH
Copy link
Copy Markdown
Member

@jamesbraza thanks for the PR!

As a sanity check I ran the test script test_meta_init.py:

uv run --isolated --extra dev --extra fsdp pytest -s -vvv tests/backends/skyrl_train/gpu/gpu_ci/test_meta_init.py

against origin/main

I expected NaN results without the fix in your PR, but tests pass. Can you confirm that the unit test is correctly simulating the issue you saw?

@jamesbraza
Copy link
Copy Markdown
Contributor Author

I expected NaN results without the fix in your PR, but tests pass. Can you confirm that the unit test is correctly simulating the issue you saw?

Yep sorry about that, you're right the GPU tests were passing on main. Fixed it just now, turns out there was a small detail missing. Nice catch

@SumanthRH
Copy link
Copy Markdown
Member

@jamesbraza thanks for the updated test, but it would be great to get a regression test for the NaN failures with SP > 1.

I re-ran the updated test and

  1. assertions fail for dtype - inv_freq gets cast to bfloat16 instead of float32
  2. assertions pass for Nan - no NaNs observed with the model used.

The reproduction script you provided before in #1709 clearly produces NaNs with BF16, but then the script forces meta init for Qwen3-1.7B, which clearly uses tie_word_embeddings=True, which means that SkyRL code skips meta init: https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/config.json

So the reproducer is not simulating an actual FSDP init for Qwen3 1.7B with SkyRL.

jamesbraza and others added 3 commits May 29, 2026 14:34
…aSky-AI#1709)

HFModelWrapper's meta-init cast every parameter and buffer to the target dtype, including non-persistent buffers like Qwen3RotaryEmbedding.inv_freq that from_pretrained (rank 0) leaves at fp32. The dtype divergence made the init-time rank-0->all non-persistent-buffer broadcast reinterpret rank-0's fp32 bytes into the half-width bf16 buffers, NaN-ing rotary attention on every non-rank-0 rank under sequence parallelism.

Now cast only parameters and persistent buffers, matching from_pretrained's default-dtype context. Adds a CPU unit test for the dtype invariant and a 2-rank SP=2 GPU test (inv_freq finite + forward NaN-free); verified on 2-node 8xH100 that the reproducer flips bug_reproduces True->False.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
The GPU test used `Qwen/Qwen3-0.6B`, which ties word embeddings, so `FSDPRefWorkerBase.init_model` gated meta-init off (`use_meta_tensor=not tie_word_embeddings`) and the test passed on `main` even without the fix — it never reached the buggy path.

Switch both meta-init tests to a non-tied model (`llamafactory/tiny-random-Llama-3`) so the meta path is taken, and assert each rank's `inv_freq` stays fp32: the corrupted bf16 values can be finite garbage (e.g. ~2e7), so the finiteness/forward checks alone could pass. Confirmed on 2xH100 that both tests now fail on `main` and pass with the fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
…divergence

The prior test used a tiny non-tied model (tiny-random-Llama-3, head_dim 4), which surfaced the `inv_freq` bf16 dtype divergence but not the forward NaN: too few rotary frequencies for the corrupted bf16 buffer to land on a NaN. Switch to Qwen/Qwen3-8B (non-tied, head_dim 128) so the corruption reproduces the actual NaN logits under SP>1, and assert forward-NaN-free first (the headline symptom) with the dtype check as a deterministic backstop.

Verified on 2xH100: the forward NaN fires deterministically on main (3/3 runs, bf16=True) and the test passes with the fix.

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
@jamesbraza jamesbraza force-pushed the fix/meta-init-inv-freq-bf16-cast branch from be23934 to ae0e8bf Compare May 30, 2026 00:57
@jamesbraza
Copy link
Copy Markdown
Contributor Author

@jamesbraza thanks for the updated test, but it would be great to get a regression test for the NaN failures with SP > 1.
...

Yeah thanks for pointing this out, agreed let's get this right, I just pushed a commit upgrading the test.

So inv_freq holds head_dim/2 rotary frequencies. The prior model tiny-random-Llama-3 has head_dim of 4, so it only had 2 rotary frequencies. Now with Qwen3-8B (head_dim of 128) there's 64 rotary frequencies, so the NaN will be reliably reproduced (it worked on 3/3 attempts) on a non-rank-0 inv_freq.

Also Qwen3-8B is a non-tied model, so init_model now takes the meta path on its own, covering that code.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

FSDP1 ref worker leaves non-rank-0 inv_freq uninitialized under meta_init=True + bf16=True, producing NaN forward under SP>1

2 participants