fix: Gemma 4 MoE Metal fusions and Qwen 3.6 GDN prefill pipelining#27
fix: Gemma 4 MoE Metal fusions and Qwen 3.6 GDN prefill pipelining#27solderzzc wants to merge 8 commits into
Conversation
…lining (ml-explore#224, ml-explore#225) - Aligns internal state recurrent precision with Python upstream - Enables async chunk pipelining for 10x M-series prefill throughput on GDN architectures
…spatch via compiledSwiGLU closure
There was a problem hiding this comment.
Pull request overview
This PR targets performance and correctness issues in MLX Swift LLM/VLM inference: reducing MoE dispatch overhead, fixing GatedDelta state-precision drift across multi-step prefill, and improving long-prompt prefill throughput by pipelining eval work.
Changes:
- Replace/route GELU + softcap operations through safer/fused implementations in Gemma 4 text/vision paths.
- Keep GatedDelta recurrent state in FP32 and add tests to catch multi-step prefill instability.
- Pipeline prompt prefill chunking using
asyncEvalto reduce GPU starvation on long prompts; add minimal SDPA repro scripts.
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
Show a summary per file
| File | Description |
|---|---|
test_gemma4_crash.swift |
Adds a Swift reproducer for SDPA crash/regression investigation. |
run_test.py |
Adds a Python reproducer for SDPA crash/regression investigation. |
Tests/MLXLMTests/GatedDeltaTests.swift |
Adds regression tests for multi-step GatedDelta prefill finiteness and determinism. |
Libraries/MLXVLM/Models/Gemma4.swift |
Switches Gemma4 VLM MLP activations to safeGeluApproximate and uses compiled softcap with dtype alignment. |
Libraries/MLXLMCommon/SwitchLayers.swift |
Adjusts MoE expert intermediate computation dtype handling (bf16 casts) in quantized/SSD and fallback paths. |
Libraries/MLXLMCommon/Optimizations.swift |
Introduces safeGeluApproximate and a compiled softcap helper. |
Libraries/MLXLLM/Models/Qwen3MoE.swift |
Updates Qwen3 MoE MLP elementwise dtype handling to avoid unwanted promotions. |
Libraries/MLXLLM/Models/Gemma4Text.swift |
Switches Gemma4 text MLP activations to safeGeluApproximate, aligns scalar dtypes, and uses compiled softcap. |
Libraries/MLXLLM/Models/GatedDelta.swift |
Keeps state dtype separate from input dtype in the Metal kernel and forces FP32 state in updates. |
Libraries/MLXLLM/LLMModel.swift |
Uses asyncEval during chunked prefill and syncs once after the loop. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| let xGate = qGate.computeExperts(x, buffers: usedGate, ranges: ranges) | ||
| let xUp = qUp.computeExperts(x, buffers: usedUp, ranges: ranges) | ||
| let intermediate = activation(xGate) * xUp | ||
| let intermediate = activation(xGate.asType(.bfloat16)) * xUp.asType(.bfloat16) |
There was a problem hiding this comment.
Same dtype concern here: unconditionally downcasting xGate/xUp to .bfloat16 can change results for .float32 models and may not be appropriate on CPU. Consider gating the cast on the actual dtype/device (e.g. only when inputs are .float16 on GPU) so higher-precision paths keep their precision.
| let intermediate = activation(xGate.asType(.bfloat16)) * xUp.asType(.bfloat16) | |
| let useBFloat16Intermediate = xGate.dtype == .float16 && xUp.dtype == .float16 | |
| let gateInput = useBFloat16Intermediate ? xGate.asType(.bfloat16) : xGate | |
| let upInput = useBFloat16Intermediate ? xUp.asType(.bfloat16) : xUp | |
| let intermediate = activation(gateInput) * upInput |
|
|
||
| var cacheArrays: [MLXArray] = [] | ||
| for c in cache { cacheArrays.append(contentsOf: c.innerState()) } | ||
| asyncEval(cacheArrays) |
There was a problem hiding this comment.
cache can be empty (you already pass nil into the model when cache.isEmpty). In that case cacheArrays stays empty and this calls asyncEval([]). If asyncEval treats an empty list specially (or if it’s just wasted work), it’s safer to guard and only call asyncEval when there’s at least one array to evaluate.
| asyncEval(cacheArrays) | |
| if !cacheArrays.isEmpty { | |
| asyncEval(cacheArrays) | |
| } |
| do { | ||
| let out = MLXFast.scaledDotProductAttention(queries: q, keys: k, values: v, scale: 1.0, mask: .array(mask)) | ||
| MLX.eval(out) | ||
| print("SDPA success!") | ||
| } catch { | ||
| print("Caught error: \(error)") | ||
| } |
There was a problem hiding this comment.
do/catch is invalid here: the do block doesn’t contain any try, so the catch will be unreachable (or, if scaledDotProductAttention is actually throws, the call is missing try). This makes the reproducer not compile/run as intended. Consider removing the do/catch entirely, or add the appropriate try/try? on the throwing call(s) and keep the error handling consistent with the MLX API’s throwing behavior.
| do { | |
| let out = MLXFast.scaledDotProductAttention(queries: q, keys: k, values: v, scale: 1.0, mask: .array(mask)) | |
| MLX.eval(out) | |
| print("SDPA success!") | |
| } catch { | |
| print("Caught error: \(error)") | |
| } | |
| let out = MLXFast.scaledDotProductAttention(queries: q, keys: k, values: v, scale: 1.0, mask: .array(mask)) | |
| MLX.eval(out) | |
| print("SDPA success!") |
| eval(diff) | ||
|
|
||
| let maxDiff = diff.item(Float.self) | ||
| XCTAssertEqual(maxDiff, 0.0, "GDN kernel not deterministic across runs") |
There was a problem hiding this comment.
This test requires bitwise-identical outputs across two full forward passes (XCTAssertEqual(maxDiff, 0.0)). On GPU (and sometimes across different Metal/driver versions), tiny nondeterminism or math reordering can produce extremely small non-zero diffs, making the test flaky. Consider asserting maxDiff is below a small tolerance (e.g. XCTAssertLessThan(maxDiff, eps) or XCTAssertEqual(maxDiff, 0, accuracy: eps)), while keeping the existing finite/non-zero checks.
| XCTAssertEqual(maxDiff, 0.0, "GDN kernel not deterministic across runs") | |
| let eps: Float = 1e-6 | |
| XCTAssertFalse(maxDiff.isNaN || maxDiff.isInfinite, "GDN kernel produced a non-finite diff across runs") | |
| XCTAssertLessThanOrEqual(maxDiff, eps, "GDN kernel not deterministic across runs") |
| let xGate = qGate.computeExperts(x, buffers: usedGate, ranges: ranges) | ||
| let xUp = qUp.computeExperts(x, buffers: usedUp, ranges: ranges) | ||
| let intermediate = activation(xGate) * xUp | ||
| let intermediate = activation(xGate.asType(.bfloat16)) * xUp.asType(.bfloat16) |
There was a problem hiding this comment.
Casting xGate and xUp to .bfloat16 unconditionally before applying the activation/multiply changes numerics for non-.float16 execution (e.g. .float32 runs) and may be an unintended accuracy regression outside the specific Metal auto-promotion case. Consider making the cast conditional (only when the arrays are .float16), or otherwise preserving the original dtype when higher precision is already in use.
| let intermediate = activation(xGate.asType(.bfloat16)) * xUp.asType(.bfloat16) | |
| let activatedGate = activation(xGate.dtype == .float16 ? xGate.asType(.bfloat16) : xGate) | |
| let multipliedUp = xUp.dtype == .float16 ? xUp.asType(.bfloat16) : xUp | |
| let intermediate = activatedGate * multipliedUp |
…trained .float32 promotion drift
… SwitchLayers fusions to strict contiguous array signatures
…prevent gatherQuantizedMM NaN corruption on pre-M4 GPUs
…explicitly to preserve numeric approximations across Apple GPU boundaries
SWIGLUandGEGLUclosures usingcompile(shapeless: true)to halve MoE dispatch overhead.asyncEvalpipelining to safely bypass GPU chunking starvation, resolving the ~10x M-series prefill limits on 2K+ prompts.