
Commit 24f1f1a

evilsocket and claude committed
perf: cache F32 lm_head weight — eliminates 1 GB allocation per token
The F32 lm_head weight conversion (to_dtype(F32)) was creating a fresh 1 GB copy of the
248320×1024 weight matrix on every single generated token. Pre-caching it at model load
time eliminates this allocation.

Benchmark (M3 Pro, Qwen3.5-0.8B, 50 tokens):
- Before: 16.1 tok/s
- After: 36.7 tok/s (+128%, 2.28x speedup)

Memory cost: +1 GB for the cached F32 weight (acceptable on 36 GB M3 Pro).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
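(Sanity check on the figures: the lm_head weight is 248320 × 1024 F32 values, i.e. 248320 × 1024 × 4 bytes ≈ 1.02 GB, consistent with the ~1 GB claim, and 36.7 / 16.1 ≈ 2.28, matching the reported speedup.)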
1 parent b7c2e2a commit 24f1f1a

1 file changed (8 additions, 3 deletions)

cake-core/src/models/common/text_model.rs
@@ -137,6 +137,8 @@ pub struct TextModelBase {
     pub ln_f_weight: Tensor,
     pub ln_f_eps: f32,
     pub lm_head_weight: Tensor,
+    /// Cached F32 lm_head weight for logit precision (avoids 1 GB allocation per token).
+    pub lm_head_weight_f32: Tensor,
 
     pub logits_processor: LogitsProcessor,
 
@@ -245,6 +247,9 @@ impl TextModelBase {
 
         let generated = 0;
 
+        // Pre-cache F32 lm_head weight to avoid 1 GB allocation+copy per token.
+        let lm_head_weight_f32 = lm_head_weight.to_dtype(candle_core::DType::F32)?;
+
         Ok(Self {
             tokenizer,
             tokens,
@@ -258,6 +263,7 @@ impl TextModelBase {
             ln_f_weight,
             ln_f_eps,
             lm_head_weight,
+            lm_head_weight_f32,
             logits_processor,
         })
     }
@@ -347,14 +353,13 @@ impl TextModelBase {
 
         // lm_head in F32 for logit precision — F16 matmul over 248k vocab
         // amplifies accumulated errors, shifting the sampling distribution.
+        // Uses pre-cached F32 weight to avoid 1 GB allocation per token.
         let x_f32 = x.to_dtype(candle_core::DType::F32)
             .map_err(|e| anyhow!("error in lm_head x to_f32: {e}"))?;
-        let lm_w_f32 = self.lm_head_weight.to_dtype(candle_core::DType::F32)
-            .map_err(|e| anyhow!("error in lm_head w to_f32: {e}"))?;
         let logits = self
             .ctx
             .backend
-            .linear_forward(&x_f32, &lm_w_f32, None)
+            .linear_forward(&x_f32, &self.lm_head_weight_f32, None)
             .map_err(|e| anyhow!("error in lm_head.forward: {e}"))?;
         // Note: no explicit sync needed here — the CPU-side logits sampling
         // (to_vec1 in LogitsProcessor) implicitly synchronizes the Metal command buffer.
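The change is a standard hoisting pattern: a dtype conversion that used to run on every forward call is done once at construction and reused afterwards. Below is a minimal standalone sketch of that pattern using candle_core; the LmHead type, field names, and matmul-based forward are illustrative assumptions, not the actual TextModelBase code.

use candle_core::{DType, Result, Tensor};

// Hypothetical helper illustrating the cache-at-load pattern from this commit.
struct LmHead {
    weight: Tensor,     // weights as loaded (e.g. F16 on Metal), kept for other paths
    weight_f32: Tensor, // F32 copy, converted once here instead of once per token
}

impl LmHead {
    fn new(weight: Tensor) -> Result<Self> {
        // One-time conversion; before the commit, the equivalent ran inside forward().
        let weight_f32 = weight.to_dtype(DType::F32)?;
        Ok(Self { weight, weight_f32 })
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor> {
        // Compute logits in F32 for precision, reusing the cached weight.
        let x_f32 = x.to_dtype(DType::F32)?;
        x_f32.matmul(&self.weight_f32.t()?)
    }
}

The trade-off is exactly the one stated in the commit message: one extra full-precision copy of the weight held for the lifetime of the model, in exchange for removing the allocation and conversion from the per-token path.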
