Commit 64b66e5

evilsocket and claude committed
perf: use fused SDPA kernel for Qwen3.5 full attention generation
Route seq_len=1 (token generation) in Qwen3.5 full attention layers through the fused_vector_attention Metal kernel via backend.sdpa(). This replaces 4+ separate GPU dispatches (repeat_kv head expansion, Q@K^T matmul, dtype cast, softmax, att@V matmul) with a single fused kernel that handles GQA natively and uses online softmax.

Benchmark (M3 Pro, Qwen3.5-0.8B, 50 tokens):
- Before: 14.7 tok/s
- After: 15.2 tok/s (+3.4%)

The prefill path (seq_len > 1) still uses the manual mixed-precision attention, which is only called once at the start of a conversation.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
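For reference, below is a minimal sketch of what the fused single-token path computes per query head: scaled dot-product attention over the cached K/V, with the GQA query-to-KV-head mapping done by index arithmetic instead of a repeat_kv copy, and a streaming (online) softmax so the full attention row is never materialized. This is plain Rust over f32 slices for illustration only, not the fused_vector_attention Metal kernel; the function name, layouts, and loop structure are assumptions. Note that a single new token attends to every cached position, so no causal mask is needed, which matches the seq_len == 1 branches in the diff.

    // Conceptual sketch only (not the Metal kernel). Assumed layouts:
    //   q: [n_heads][head_dim]            one new query token
    //   k: [n_kv_heads][ctx][head_dim]    cached keys
    //   v: [n_kv_heads][ctx][head_dim]    cached values
    // GQA: query head h reads kv head h / (n_heads / n_kv_heads), so no repeat_kv copy.
    fn sdpa_single_token(
        q: &[Vec<f32>],
        k: &[Vec<Vec<f32>>],
        v: &[Vec<Vec<f32>>],
        scale: f32,
    ) -> Vec<Vec<f32>> {
        let group = q.len() / k.len(); // query heads per kv head
        q.iter()
            .enumerate()
            .map(|(h, qh)| {
                let kv = h / group;
                let head_dim = qh.len();
                // Online softmax: keep a running max, normalizer, and weighted sum
                // in one pass, so the attention row is never stored.
                let (mut m, mut denom) = (f32::NEG_INFINITY, 0.0f32);
                let mut acc = vec![0.0f32; head_dim];
                for (kt, vt) in k[kv].iter().zip(&v[kv]) {
                    let score = scale * qh.iter().zip(kt).map(|(a, b)| a * b).sum::<f32>();
                    let m_new = m.max(score);
                    let correction = (m - m_new).exp(); // rescale earlier partial sums
                    let w = (score - m_new).exp();
                    denom = denom * correction + w;
                    for (a, x) in acc.iter_mut().zip(vt) {
                        *a = *a * correction + w * x;
                    }
                    m = m_new;
                }
                acc.iter().map(|a| a / denom).collect()
            })
            .collect()
    }

Numerically this matches a two-pass softmax followed by att @ V, but done in a single pass over the cache, which is what allows the whole computation to fuse into one dispatch.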
1 parent 4c33647 commit 64b66e5

1 file changed: 20 additions & 17 deletions

cake-core/src/models/qwen3_5/full_attention.rs

@@ -206,38 +206,41 @@ impl Qwen3_5FullAttention {
             ).map_err(|e| anyhow!("flash_attn: {e}"))?;
         }

-        // Metal: mixed-precision attention — F16 matmuls + F32 softmax.
-        // F16 SDPA causes garbage, F32 SDPA exceeds threadgroup memory.
+        // Metal path: fused SDPA for generation, mixed-precision for prefill.
         #[cfg(feature = "metal")]
         if matches!(q.device(), candle_core::Device::Metal(_)) {
+            // Generation (seq_len=1): fused kernel — single dispatch with native
+            // GQA (no repeat_kv), online softmax, no attention matrix materialization.
+            // Replaces 4+ separate dispatches (repeat_kv + 2 matmuls + softmax + dtype casts).
+            if seq_len == 1 {
+                let scale = 1.0 / (self.head_dim as f32).sqrt();
+                break 'attn self.backend.sdpa(&q, &k, &v, None, false, scale)
+                    .map_err(|e| anyhow!("sdpa: {e}"))?;
+            }
+
+            // Prefill (seq_len > 1): F16 matmuls + F32 softmax (F16 SDPA causes
+            // garbage, F32 SDPA exceeds threadgroup memory).
             let k = self.repeat_kv(k).map_err(|e| anyhow!("repeat_kv k: {e}"))?;
             let v = self.repeat_kv(v).map_err(|e| anyhow!("repeat_kv v: {e}"))?;
             let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
             let att = att.to_dtype(candle_core::DType::F32)?;
-            let att = if seq_len == 1 {
-                att
-            } else {
-                let tril = Tensor::tril2(seq_len, candle_core::DType::F32, att.device())
-                    .map_err(|e| anyhow!("tril: {e}"))?;
-                let mask = ((tril - 1.0)? * 1e9)?;
-                let mask = mask.broadcast_as(att.shape())
-                    .map_err(|e| anyhow!("mask broadcast: {e}"))?;
-                (att + mask).map_err(|e| anyhow!("mask add: {e}"))?
-            };
+            let tril = Tensor::tril2(seq_len, candle_core::DType::F32, att.device())
+                .map_err(|e| anyhow!("tril: {e}"))?;
+            let mask = ((tril - 1.0)? * 1e9)?;
+            let mask = mask.broadcast_as(att.shape())
+                .map_err(|e| anyhow!("mask broadcast: {e}"))?;
+            let att = (att + mask).map_err(|e| anyhow!("mask add: {e}"))?;
             let att = self.backend.softmax(&att, att.rank() - 1)?;
             let att = att.to_dtype(v.dtype())?;
             break 'attn att.matmul(&v.contiguous()?)
                 .map_err(|e| anyhow!("att matmul v: {e}"))?;
         }

-        // Manual attention with GQA head expansion (CPU fallback)
+        // CPU fallback: manual attention with GQA head expansion
         let k = self.repeat_kv(k).map_err(|e| anyhow!("repeat_kv k: {e}"))?;
         let v = self.repeat_kv(v).map_err(|e| anyhow!("repeat_kv v: {e}"))?;
-
         let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-        let att = if seq_len == 1 {
-            att
-        } else {
+        let att = if seq_len == 1 { att } else {
             let mask = cache.mask(seq_len, att.device())
                 .map_err(|e| anyhow!("mask: {e}"))?
                 .broadcast_as(att.shape())
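In the prefill branch above, the additive causal mask is built as ((tril - 1.0)? * 1e9): entries on or below the diagonal become 0 and entries above it become -1e9, so the following softmax drives masked positions to effectively zero weight. A standalone sketch of that arithmetic, using plain nested Vecs rather than candle tensors (the helper name is hypothetical):

    // Standalone illustration of the prefill mask math: (tril - 1.0) * 1e9.
    fn causal_mask(seq_len: usize) -> Vec<Vec<f32>> {
        (0..seq_len)
            .map(|i| {
                (0..seq_len)
                    .map(|j| {
                        // tril is 1 on and below the diagonal, 0 above it.
                        let tril = if j <= i { 1.0f32 } else { 0.0 };
                        (tril - 1.0) * 1e9 // 0 where attention is allowed, -1e9 where masked
                    })
                    .collect()
            })
            .collect()
    }

    fn main() {
        // For seq_len = 3, row i is 0 for columns 0..=i and -1e9 afterwards.
        for row in causal_mask(3) {
            println!("{row:?}");
        }
    }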
