Commit 225396a

Merge branch 'main' of github.com:evilsocket/cake

2 parents 7f2d640 + 316a0de

7 files changed: 64 additions & 48 deletions

CLAUDE.md

Lines changed: 16 additions & 0 deletions

````diff
@@ -24,6 +24,22 @@ cargo build --release --features metal
 cargo build --release --features vulkan
 ```
 
+## Acceleration Features
+
+| Feature | Platform | Backend | Best For | Notes |
+|---------|----------|---------|----------|-------|
+| `metal` | macOS (Apple Silicon) | GPU via MPS + custom MSL kernels | Primary inference on Mac | Fastest option on Apple Silicon (~42 tok/s on M3 Pro) |
+| `cuda` | Linux (NVIDIA GPU) | GPU via cuBLAS/cuDNN | Primary inference on Linux | Requires CUDA toolkit matching driver version |
+| `accelerate` | macOS | CPU via Apple Accelerate (AMX) | CPU-only F32 inference on Mac | 2.7x faster than pure-Rust for F32 matmul; no F16 support |
+| `vulkan` | Any (Vulkan 1.3+) | GPU via Vulkan compute shaders | Steam Deck, AMD GPUs | Portable but less optimized than Metal/CUDA |
+| (none) | Any | CPU via pure-Rust `gemm` | Portable CPU fallback | F16 weights stay F16, avoids bandwidth doubling |
+
+**When to use which:**
+- **Apple Silicon (stevie.local):** Use `--features metal`. Metal is 1.6x faster than CPU F16 (42 vs 26 tok/s). The `accelerate` feature doesn't help with Metal and doesn't support F16 matmul, so CPU F16 (default, no features) is actually faster than `accelerate` with F32 (26 vs 23 tok/s).
+- **NVIDIA GPU (blade/bahamut):** Use `--features cuda`. Add `flash-attn` for flash attention support.
+- **CPU-only with F32 models:** Use `--features accelerate` on macOS for 2.7x faster F32 matmul. On Linux, consider linking against MKL or OpenBLAS.
+- **CPU-only with F16 models:** Use no features — pure-Rust `gemm` with F16 avoids the 2x memory bandwidth penalty of converting to F32.
+
 ## Interactive Chat
 
 ```bash
````
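These feature flags map directly onto candle's device backends (see the cake-core/Cargo.toml diff below). As a minimal sketch of how such flags typically gate device selection at runtime, assuming candle_core; `pick_device` and its fallback order are illustrative, not cake's actual code:

```rust
// Hypothetical sketch of feature-gated device selection with candle_core.
// The function name and fallback order are assumptions, not cake's API.
use candle_core::Device;

fn pick_device() -> Device {
    // Compiled in only with `--features metal`.
    #[cfg(feature = "metal")]
    if let Ok(dev) = Device::new_metal(0) {
        return dev;
    }
    // Compiled in only with `--features cuda`.
    #[cfg(feature = "cuda")]
    if let Ok(dev) = Device::new_cuda(0) {
        return dev;
    }
    // CPU fallback: pure-Rust `gemm`, or Apple Accelerate when that
    // feature is enabled (the dtype caveats above still apply).
    Device::Cpu
}
```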

README.md

Lines changed: 5 additions & 4 deletions

````diff
@@ -32,10 +32,11 @@ Cake is a **multimodal AI inference server** written in Rust that can run models
 ### Build
 
 ```sh
-cargo build --release --features cuda    # Linux (NVIDIA)
-cargo build --release --features metal   # macOS (Apple Silicon)
-cargo build --release --features vulkan  # Linux (AMD/Intel/Steam Deck)
-cargo build --release                    # CPU only
+cargo build --release --features cuda        # Linux (NVIDIA)
+cargo build --release --features metal       # macOS (Apple Silicon GPU)
+cargo build --release --features accelerate  # macOS (Apple Silicon CPU, F32 models)
+cargo build --release --features vulkan      # Linux (AMD/Intel/Steam Deck)
+cargo build --release                        # CPU only (portable)
 ```
 
 ### Models
````

cake-core/Cargo.toml

Lines changed: 1 addition & 0 deletions

````diff
@@ -70,6 +70,7 @@ base64 = "0.22.1"
 default = ["master", "llama", "qwen2", "qwen3_5", "qwen3", "qwen3_moe", "qwen3_5_moe", "phi4", "mistral", "gemma3", "falcon3", "olmo2", "exaone4", "flux", "vibevoice", "luxtts"]
 
 metal = ["candle-core/metal", "candle-nn/metal", "candle-transformers/metal", "dep:candle-metal-kernels", "dep:objc2-metal"]
+accelerate = ["candle-core/accelerate", "candle-nn/accelerate", "candle-transformers/accelerate"]
 cuda = ["candle-core/cuda", "candle-nn/cuda", "candle-transformers/cuda", "dep:bindgen_cuda"]
 flash-attn = ["cuda", "dep:candle-flash-attn"]
 vulkan = ["dep:ash", "dep:gpu-allocator", "dep:bytemuck"]
````

cake-core/src/backends/metal/ops.msl

Lines changed: 1 addition & 0 deletions

````diff
@@ -838,3 +838,4 @@ kernel void fused_vector_attention_f32(
 
     output[bh * head_dim + d] = acc * (1.0f / sum_exp);
 }
+
````
cake-core/src/models/common/text_model.rs

Lines changed: 2 additions & 11 deletions

````diff
@@ -333,17 +333,10 @@ impl TextModelBase {
 
         let head_start = std::time::Instant::now();
 
-        // Final norm + lm_head in F32 for logit precision — F16 through 24 layers
-        // accumulates small errors that get amplified across 248k vocab entries,
-        // shifting the sampling distribution enough to cause wrong-language output.
-        let x_f32 = x.to_dtype(candle_core::DType::F32)
-            .map_err(|e| anyhow!("error in ln_f x to_f32: {e}"))?;
-        let w_f32 = self.ln_f_weight.to_dtype(candle_core::DType::F32)
-            .map_err(|e| anyhow!("error in ln_f w to_f32: {e}"))?;
         let x = self
             .ctx
             .backend
-            .rms_norm(&x_f32, &w_f32, self.ln_f_eps)
+            .rms_norm(&x, &self.ln_f_weight, self.ln_f_eps)
             .map_err(|e| anyhow!("error in ln_f.forward: {e}"))?;
 
         let x = x
@@ -352,12 +345,10 @@ impl TextModelBase {
             .contiguous()
             .map_err(|e| anyhow!("error in x.i.contiguous: {e}"))?;
 
-        let lm_w_f32 = self.lm_head_weight.to_dtype(candle_core::DType::F32)
-            .map_err(|e| anyhow!("error in lm_head w to_f32: {e}"))?;
         let logits = self
             .ctx
             .backend
-            .linear_forward(&x, &lm_w_f32, None)
+            .linear_forward(&x, &self.lm_head_weight, None)
             .map_err(|e| anyhow!("error in lm_head.forward: {e}"))?;
         // Note: no explicit sync needed here — the CPU-side logits sampling
         // (to_vec1 in LogitsProcessor) implicitly synchronizes the Metal command buffer.
````
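The closing comment leans on a candle behavior worth spelling out: copying a tensor back to the host blocks until all pending device work finishes. A minimal sketch of that pattern; `greedy_sample` is illustrative and not cake's `LogitsProcessor`:

```rust
// Illustrative sketch (not cake's LogitsProcessor): reading logits back to
// the host with to_vec1 forces candle to flush and wait on the Metal
// command buffer, so no explicit backend synchronize() call is needed.
use candle_core::{Result, Tensor};

fn greedy_sample(logits: &Tensor) -> Result<u32> {
    // Device-to-host copy; all queued GPU work completes before this returns.
    // Assumes rank-1 F32 logits, as candle's LogitsProcessor works with.
    let v: Vec<f32> = logits.to_vec1()?;
    let mut best = 0usize;
    for (i, &x) in v.iter().enumerate() {
        if x > v[best] {
            best = i;
        }
    }
    Ok(best as u32)
}
```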

cake-core/src/models/qwen3_5/full_attention.rs

Lines changed: 30 additions & 27 deletions

````diff
@@ -138,9 +138,11 @@ impl Qwen3_5FullAttention {
         let qkv = self.backend.linear_forward(x, &self.qkv_proj_weight, None)
             .map_err(|e| anyhow!("qkv_proj: {e}"))?;
 
-        // Flush GPU commands after QKV matmul (always needed — full attention
-        // accumulates ~24 commands between syncs, can't afford more)
-        let _ = self.backend.synchronize();
+        // Flush GPU commands after QKV matmul — needed for prefill where many
+        // operations follow. Generation (seq_len=1) uses fused SDPA with few commands.
+        if seq_len > 1 {
+            let _ = self.backend.synchronize();
+        }
 
         // Split: Q (doubled for gating), K, V
         let q_out = qkv.narrow(D::Minus1, 0, self.q_size)
@@ -206,45 +208,46 @@ impl Qwen3_5FullAttention {
                 ).map_err(|e| anyhow!("flash_attn: {e}"))?;
             }
 
-            // Metal: mixed-precision attention — F16 matmuls + F32 softmax.
-            // F16 SDPA causes garbage, F32 SDPA exceeds threadgroup memory.
+            // Metal path: fused SDPA for generation, mixed-precision for prefill.
             #[cfg(feature = "metal")]
             if matches!(q.device(), candle_core::Device::Metal(_)) {
+                // Generation (seq_len=1): fused kernel — single dispatch with native
+                // GQA (no repeat_kv), online softmax, no attention matrix materialization.
+                // Replaces 4+ separate dispatches (repeat_kv + 2 matmuls + softmax + dtype casts).
+                if seq_len == 1 {
+                    let scale = 1.0 / (self.head_dim as f32).sqrt();
+                    break 'attn self.backend.sdpa(&q, &k, &v, None, false, scale)
+                        .map_err(|e| anyhow!("sdpa: {e}"))?;
+                }
+
+                // Prefill (seq_len > 1): F16 matmuls + F32 softmax (F16 SDPA causes
+                // garbage, F32 SDPA exceeds threadgroup memory).
                 let k = self.repeat_kv(k).map_err(|e| anyhow!("repeat_kv k: {e}"))?;
                 let v = self.repeat_kv(v).map_err(|e| anyhow!("repeat_kv v: {e}"))?;
                 let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
                 let att = att.to_dtype(candle_core::DType::F32)?;
-                let att = if seq_len == 1 {
-                    att
-                } else {
-                    let tril = Tensor::tril2(seq_len, candle_core::DType::F32, att.device())
-                        .map_err(|e| anyhow!("tril: {e}"))?;
-                    let mask = ((tril - 1.0)? * 1e9)?;
-                    let mask = mask.broadcast_as(att.shape())
-                        .map_err(|e| anyhow!("mask broadcast: {e}"))?;
-                    (att + mask).map_err(|e| anyhow!("mask add: {e}"))?
-                };
+                let tril = Tensor::tril2(seq_len, candle_core::DType::F32, att.device())
+                    .map_err(|e| anyhow!("tril: {e}"))?;
+                let mask = ((tril - 1.0)? * 1e9)?;
+                let mask = mask.broadcast_as(att.shape())
+                    .map_err(|e| anyhow!("mask broadcast: {e}"))?;
+                let att = (att + mask).map_err(|e| anyhow!("mask add: {e}"))?;
                 let att = self.backend.softmax(&att, att.rank() - 1)?;
                 let att = att.to_dtype(v.dtype())?;
                 break 'attn att.matmul(&v.contiguous()?)
                     .map_err(|e| anyhow!("att matmul v: {e}"))?;
             }
 
-            // Manual attention with GQA head expansion (CPU fallback)
+            // CPU: manual attention with GQA head expansion
            let k = self.repeat_kv(k).map_err(|e| anyhow!("repeat_kv k: {e}"))?;
            let v = self.repeat_kv(v).map_err(|e| anyhow!("repeat_kv v: {e}"))?;
-
            let att = (q.matmul(&k.t()?)? / (self.head_dim as f64).sqrt())?;
-            let att = if seq_len == 1 {
-                att
-            } else {
-                let mask = cache.mask(seq_len, att.device())
-                    .map_err(|e| anyhow!("mask: {e}"))?
-                    .broadcast_as(att.shape())
-                    .map_err(|e| anyhow!("mask broadcast: {e}"))?;
-                masked_fill(&att, &mask, f32::NEG_INFINITY)
-                    .map_err(|e| anyhow!("masked_fill: {e}"))?
-            };
+            let mask = cache.mask(seq_len, att.device())
+                .map_err(|e| anyhow!("mask: {e}"))?
+                .broadcast_as(att.shape())
+                .map_err(|e| anyhow!("mask broadcast: {e}"))?;
+            let att = masked_fill(&att, &mask, f32::NEG_INFINITY)
+                .map_err(|e| anyhow!("masked_fill: {e}"))?;
            let att = self.backend.softmax(&att, att.rank() - 1)?;
            att.matmul(&v.contiguous()?)?
        };
````
docs/install.md

Lines changed: 9 additions & 6 deletions

````diff
@@ -98,15 +98,18 @@ make mobile_ios
 
 By default, inference runs on CPU. Enable GPU acceleration with:
 
-| Feature | Backend | Platforms |
-|---------|---------|-----------|
-| `cuda` | NVIDIA CUDA (PTX kernels + flash-attn) | Linux, Windows |
-| `metal` | Apple Metal (MSL shaders + fused SDPA) | macOS, iOS |
-| `vulkan` | Vulkan via wgpu | Linux, Windows, Steam Deck |
-| `flash-attn` | Flash Attention 2 (implies `cuda`) | Linux, Windows |
+| Feature | Backend | Platforms | Notes |
+|---------|---------|-----------|-------|
+| `cuda` | NVIDIA CUDA (PTX kernels + flash-attn) | Linux, Windows | Best for NVIDIA GPUs |
+| `metal` | Apple Metal (MSL shaders + fused SDPA) | macOS, iOS | Best for Apple Silicon (~42 tok/s on M3 Pro with 0.8B model) |
+| `accelerate` | Apple Accelerate (AMX hardware) | macOS | CPU-only; 2.7x faster F32 matmul via Apple BLAS. No F16 support — use `metal` for F16 models |
+| `vulkan` | Vulkan via wgpu | Linux, Windows, Steam Deck | Portable GPU backend |
+| `flash-attn` | Flash Attention 2 (implies `cuda`) | Linux, Windows | Fused attention kernel for long sequences |
 
 Multiple backends can be compiled together — the runtime auto-selects based on available hardware.
 
+**Apple Silicon guidance:** Use `metal` for best performance. The `accelerate` feature only helps CPU inference with F32 models — for F16 models (default), CPU without `accelerate` is actually faster (26 vs 23 tok/s) because F16 halves memory bandwidth vs the F32 conversion Accelerate requires.
+
 ### Model Features
 
 By default, all text model architectures are compiled in. To build only for specific models:
````
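The bandwidth argument in that guidance checks out on paper: a memory-bound decoder streams roughly every weight once per token, so per-token traffic scales with bytes per weight. A back-of-envelope sketch; the ~0.8B parameter count comes from the table above, and treating decode as purely bandwidth-bound is an assumption:

```rust
// Back-of-envelope sketch: per-token weight traffic for a memory-bound
// decoder is roughly params * bytes_per_weight. The ~0.8B figure comes
// from the benchmark above; everything else here is an assumption.
fn weight_traffic_gb(params: u64, bytes_per_weight: u64) -> f64 {
    (params * bytes_per_weight) as f64 / 1e9
}

fn main() {
    let params = 800_000_000u64;
    println!("F16: {:.1} GB/token", weight_traffic_gb(params, 2)); // ~1.6 GB
    println!("F32: {:.1} GB/token", weight_traffic_gb(params, 4)); // ~3.2 GB
}
```

Doubling traffic would halve throughput if compute were free; Accelerate's 2.7x faster F32 matmul claws most of that back, which is consistent with the modest 26 vs 23 tok/s gap reported above.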
