
Commit f1787d9

evilsocket and claude committed
perf: revert F32 lm_head — F16 is correct after norm fixes
The F32 lm_head was added to compensate for logit-distribution errors caused by the GDN gated-norm bug (+1.0 applied to non-residual weights). Now that the norm is fixed, the F16 lm_head produces correct output at all temperatures.

Removing F32 saves:
- 1 GB of memory (no cached F32 weight)
- ~6 ms/token of memory-bandwidth time (reads 508 MB of weight data per token instead of 1 GB)

Benchmark (M3 Pro, Qwen3.5-0.8B, 50 tokens):
- F32 lm_head (cached): 36.7 tok/s, +1 GB memory
- F16 lm_head: 42.4 tok/s, no extra memory (+15.5%)

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
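For context on the gated-norm bug the first paragraph refers to, here is a minimal sketch, not cake-core's actual norm code: the RMSNorm-style layer and the residual_weight flag are illustrative assumptions. Some checkpoints store norm weights to be used directly as w; others store w - 1 and expect the layer to apply 1.0 + w. Adding the +1.0 offset to a weight of the first kind rescales every normalized activation, producing the logit skew that the F32 lm_head had been masking:

// Illustrative sketch only (assumed API, not cake-core's real norm code).
// Applying the +1.0 residual offset to a weight that is NOT stored in
// (w - 1) form rescales every activation, skewing downstream logits.
fn rms_norm(x: &[f32], w: &[f32], eps: f32, residual_weight: bool) -> Vec<f32> {
    let mean_sq = x.iter().map(|v| v * v).sum::<f32>() / x.len() as f32;
    let inv_rms = 1.0 / (mean_sq + eps).sqrt();
    x.iter()
        .zip(w)
        .map(|(v, w)| {
            // The bug: treating a direct weight as if it were residual.
            let w = if residual_weight { 1.0 + *w } else { *w };
            v * inv_rms * w
        })
        .collect()
}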
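The 508 MB and ~6 ms/token figures can be sanity-checked with simple arithmetic. The sketch below assumes a ~248k-entry vocab and a 1024-wide hidden state (both inferred from the 508 MB figure, not read from the model config) plus ~150 GB/s of effective M3 Pro memory bandwidth (a rough assumption):

fn main() {
    // Assumed model shape (inferred from the 508 MB figure, not from config).
    let vocab: u64 = 248_000;
    let hidden: u64 = 1_024;

    let f16_bytes = vocab * hidden * 2; // ~508 MB read per token (F16 lm_head)
    let f32_bytes = vocab * hidden * 4; // ~1 GB read per token (cached F32 weight)

    // At an assumed ~150 GB/s effective bandwidth, the extra F32 traffic alone
    // costs ~3.4 ms/token; the rest of the ~6 ms/token the commit reports
    // plausibly comes from the per-token x.to_dtype(F32) copy it also removes.
    let extra_ms = (f32_bytes - f16_bytes) as f64 / 150e9 * 1e3;
    println!("F16 weight: {} MB", f16_bytes / 1_000_000);
    println!("F32 weight: {} MB", f32_bytes / 1_000_000);
    println!("extra bandwidth cost: {:.1} ms/token", extra_ms);
}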
1 parent: 24f1f1a · commit: f1787d9

1 file changed: cake-core/src/models/common/text_model.rs (1 addition, 12 deletions)
@@ -137,8 +137,6 @@ pub struct TextModelBase {
     pub ln_f_weight: Tensor,
     pub ln_f_eps: f32,
     pub lm_head_weight: Tensor,
-    /// Cached F32 lm_head weight for logit precision (avoids 1 GB allocation per token).
-    pub lm_head_weight_f32: Tensor,

     pub logits_processor: LogitsProcessor,

@@ -247,9 +245,6 @@ impl TextModelBase {

         let generated = 0;

-        // Pre-cache F32 lm_head weight to avoid 1 GB allocation+copy per token.
-        let lm_head_weight_f32 = lm_head_weight.to_dtype(candle_core::DType::F32)?;
-
         Ok(Self {
             tokenizer,
             tokens,
@@ -263,7 +258,6 @@ impl TextModelBase {
             ln_f_weight,
             ln_f_eps,
             lm_head_weight,
-            lm_head_weight_f32,
             logits_processor,
         })
     }
@@ -351,15 +345,10 @@ impl TextModelBase {
             .contiguous()
             .map_err(|e| anyhow!("error in x.i.contiguous: {e}"))?;

-        // lm_head in F32 for logit precision — F16 matmul over 248k vocab
-        // amplifies accumulated errors, shifting the sampling distribution.
-        // Uses pre-cached F32 weight to avoid 1 GB allocation per token.
-        let x_f32 = x.to_dtype(candle_core::DType::F32)
-            .map_err(|e| anyhow!("error in lm_head x to_f32: {e}"))?;
         let logits = self
             .ctx
             .backend
-            .linear_forward(&x_f32, &self.lm_head_weight_f32, None)
+            .linear_forward(&x, &self.lm_head_weight, None)
             .map_err(|e| anyhow!("error in lm_head.forward: {e}"))?;
         // Note: no explicit sync needed here — the CPU-side logits sampling
         // (to_vec1 in LogitsProcessor) implicitly synchronizes the Metal command buffer.
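The trailing comment also explains why no extra precision step is needed after the F16 matmul: sampling happens on the CPU, where the logits are read back as F32. A minimal sketch of that pattern (assuming a candle-style sampler; cake-core's actual LogitsProcessor internals are not shown in this diff):

use candle_core::{DType, Result, Tensor};

// Sketch of the CPU-side sampling step (an assumption, not cake-core's
// real LogitsProcessor): reading ~248k logits back to the host as F32
// both upcasts them and implicitly waits for the Metal command buffer.
fn sample_greedy(logits: &Tensor) -> Result<u32> {
    let logits: Vec<f32> = logits.to_dtype(DType::F32)?.to_vec1()?;
    let next = logits
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i as u32)
        .unwrap_or(0);
    Ok(next)
}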
