
Commit c04420f

evilsocket and claude committed
fix: correct residual RMS norm for Qwen3.5 — fixes garbage output
The root cause of garbage output (on both CPU and Metal) was incorrect handling of residual RMS norm weights. Qwen3.5 uses (1 + weight) * rms_norm(x) for all norm layers, meaning stored weights are deltas from 0 that need +1.0 added at load time.

Two bugs:

1. load_rms_norm_weight had an auto-detection heuristic (threshold 0.5) that incorrectly skipped the +1.0 for weights that had drifted above 0.5 during training (e.g., linear_attn.norm, later layer input norms). Removed the heuristic — always add 1.0 when residual_rms_norm is set.

2. RmsNormGated in linear_attention.rs loaded its weight directly without adding +1.0, unlike all other norms that go through load_rms_norm_weight.

Also reverted unnecessary .contiguous() calls added during debugging.

Verified against the HuggingFace reference implementation (transformers 5.3.0), which confirms: forward = output * (1.0 + self.weight.float()).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
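To make bug 1 concrete, here is a small standalone sketch (plain Rust, invented values, not code from this repo) of what the old heuristic did to a weight whose residual deltas drifted above 0.5: the effective scale should be (1 + w), but skipping the +1.0 scales the normalized activations by w alone.

// Hypothetical standalone sketch of the failure mode; all values are made up.
fn rms_norm(x: &[f32], eps: f32) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter().map(|v| v * inv_rms).collect()
}

fn main() {
    let x = [0.5f32, -1.0, 2.0, 0.25];
    // Stored residual deltas that drifted above the old 0.5 threshold during training.
    let w = [0.60f32, 0.55, 0.62, 0.58];
    let normed = rms_norm(&x, 1e-6);

    // Correct: effective scale is (1 + w), i.e. the +1.0 folded in at load time.
    let correct: Vec<f32> = normed.iter().zip(&w).map(|(n, wi)| (1.0 + wi) * n).collect();
    // Old heuristic: mean(w) >= 0.5, so the +1.0 was skipped and the
    // activations come out roughly 2.5-3x too small.
    let buggy: Vec<f32> = normed.iter().zip(&w).map(|(n, wi)| wi * n).collect();

    println!("correct: {correct:?}");
    println!("buggy:   {buggy:?}");
}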
1 parent cc511ce commit c04420f

2 files changed

Lines changed: 6 additions & 16 deletions


cake-core/src/models/common/config.rs

Lines changed: 3 additions & 14 deletions
@@ -159,20 +159,9 @@ pub fn load_rms_norm_weight(
 ) -> candle_core::Result<candle_core::Tensor> {
     let weight = vb.get(size, "weight")?;
     if residual {
-        // Auto-detect: some quantized model variants (e.g., MLX-quantized) already apply
-        // the (1+w) transformation during quantization, storing the final weight directly.
-        // Detect by checking if the mean is closer to 0 (residual) or 1 (already transformed).
-        let mean: f32 = weight.to_dtype(candle_core::DType::F32)?
-            .mean_all()?
-            .to_scalar()?;
-        if mean.abs() < 0.5 {
-            // Weights near 0: residual pattern, add 1.0
-            Ok((weight + 1.0)?)
-        } else {
-            // Weights near 1: already transformed (e.g., MLX quantized), use as-is
-            log::debug!("rms_norm weight mean={mean:.3}: skipping residual +1.0");
-            Ok(weight)
-        }
+        // Residual RMS norm: forward = (1 + weight) * rms_norm(x).
+        // Weights are stored as deltas from 0, so add 1.0 at load time.
+        Ok((weight + 1.0)?)
     } else {
         Ok(weight)
     }
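A minimal candle sketch of the convention the fixed loader implements (hypothetical values and helper function, not this repo's code): the +1.0 is folded into the weight once at load, and the forward pass is then a plain weighted RMS norm, equivalent to (1 + w) * rms_norm(x).

use candle_core::{Device, Result, Tensor, D};

// Plain RMS norm over the last dimension: x / sqrt(mean(x^2) + eps).
fn rms_norm(x: &Tensor, eps: f64) -> Result<Tensor> {
    let mean_sq = x.sqr()?.mean_keepdim(D::Minus1)?;
    x.broadcast_div(&(mean_sq + eps)?.sqrt()?)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Hypothetical stored weight: small deltas around 0, per the residual convention.
    let stored = Tensor::new(&[0.01f32, -0.02, 0.03, 0.00], &dev)?;
    // Fold the +1.0 in once at load time, as the patched loader now always does.
    let weight = (stored + 1.0)?;

    let x = Tensor::new(&[[0.5f32, -1.0, 2.0, 0.25]], &dev)?;
    // Forward is then just rms_norm(x) * weight.
    let y = rms_norm(&x, 1e-6)?.broadcast_mul(&weight)?;
    println!("{y}");
    Ok(())
}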

cake-core/src/models/qwen3_5/linear_attention.rs

Lines changed: 3 additions & 2 deletions
@@ -26,8 +26,9 @@ struct RmsNormGated {
 
 impl RmsNormGated {
     fn load(size: usize, eps: f64, vb: VarBuilder, backend: Arc<dyn ComputeBackend>) -> Result<Self> {
-        // Store weight as F32 to match the recurrent step's F32 output.
-        let weight = vb.get(size, "weight")?.to_dtype(DType::F32)?;
+        // Residual RMS norm: forward = (1 + weight) * rms_norm(x) * silu(z).
+        // Store as F32 to match the recurrent step's F32 output.
+        let weight = (vb.get(size, "weight")?.to_dtype(DType::F32)? + 1.0)?;
         Ok(Self { weight, eps: eps as f32, backend })
     }
 
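RmsNormGated's actual forward pass goes through the project's ComputeBackend and is not shown in this diff. As a hedged sketch of the convention named in the new comment: with the +1.0 folded in at load, the gated output reduces to rms_norm(x) * weight * silu(z). A simplified, self-contained candle version (hypothetical, assuming x and z share a shape; not the real implementation):

use candle_core::{Result, Tensor, D};

// Hypothetical, simplified gated residual RMS norm. Assumes the +1.0 is already
// folded into `weight` at load time, so the forward is rms_norm(x) * weight * silu(z).
fn rms_norm_gated(x: &Tensor, z: &Tensor, weight: &Tensor, eps: f64) -> Result<Tensor> {
    // Plain RMS norm over the last dimension.
    let mean_sq = x.sqr()?.mean_keepdim(D::Minus1)?;
    let normed = x.broadcast_div(&(mean_sq + eps)?.sqrt()?)?;
    // silu(z) = z * sigmoid(z), built from primitive ops.
    let sigmoid = (z.neg()?.exp()? + 1.0)?.recip()?;
    let silu_z = (z * sigmoid)?;
    normed.broadcast_mul(weight)?.mul(&silu_z)
}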