
fix: Gemma 4 MoE Metal fusions and Qwen 3.6 GDN prefill pipelining #27

Open
solderzzc wants to merge 8 commits into main from feature/gemma4-moe-4x-alignment

Conversation

@solderzzc
Member

  • Fuses the Metal SwiGLU and GeGLU closures using compile(shapeless: true) to halve MoE dispatch overhead.
  • Keeps the GDN recurrent state buffers in FP32, preventing floating-point precision decay across step boundaries when T > 1.
  • Introduces asyncEval pipelining so that chunked prefill no longer starves the GPU, resolving the ~10x prefill throughput limit on M-series chips for prompts of 2K+ tokens.
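
To illustrate why the recurrent state precision matters, here is a minimal sketch in plain Swift. It is not the GDN kernel: the recurrence `S <- decay * S + update`, the constants, and the helper names `roundToBFloat16` and `runRecurrence` are assumptions chosen for illustration. It emulates a low-precision state by rounding to bfloat16 after every step and compares it with a state kept in FP32.

```swift
import Foundation

/// Rounds a finite Float32 to the nearest bfloat16-representable value
/// (round-to-nearest-even) and returns it as a Float32.
func roundToBFloat16(_ x: Float) -> Float {
    let bits = x.bitPattern
    let bias: UInt32 = 0x7FFF + ((bits >> 16) & 1)
    return Float(bitPattern: (bits &+ bias) & 0xFFFF_0000)
}

/// Runs a toy gated recurrence S <- decay * S + update for `steps` steps.
/// When `lowPrecisionState` is true, the state is rounded to bfloat16 after each step.
func runRecurrence(steps: Int, decay: Float, update: Float, lowPrecisionState: Bool) -> Float {
    var state: Float = 0
    for _ in 0..<steps {
        state = decay * state + update
        if lowPrecisionState { state = roundToBFloat16(state) }
    }
    return state
}

// Hypothetical constants: a slow decay and a small per-step update.
let fp32State = runRecurrence(steps: 2048, decay: 0.999, update: 0.001, lowPrecisionState: false)
let bf16State = runRecurrence(steps: 2048, decay: 0.999, update: 0.001, lowPrecisionState: true)
print("FP32 state:", fp32State)
print("bf16-rounded state:", bf16State)
print("absolute difference:", abs(fp32State - bf16State))
```

In this toy setting the low-precision state stops growing once the per-step change is smaller than half the spacing between representable values, while the FP32 state keeps accumulating. That is one plausible mechanism behind drift over many steps; the actual kernel may differ.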

…lining (ml-explore#224, ml-explore#225)

- Aligns internal state recurrent precision with Python upstream
- Enables async chunk pipelining for 10x M-series prefill throughput on GDN architectures
Copilot AI review requested due to automatic review settings April 20, 2026 23:17

Copilot AI left a comment


Pull request overview

This PR targets performance and correctness issues in MLX Swift LLM/VLM inference: reducing MoE dispatch overhead, fixing GatedDelta state-precision drift across multi-step prefill, and improving long-prompt prefill throughput by pipelining eval work.

Changes:

  • Replace/route GELU + softcap operations through safer/fused implementations in Gemma 4 text/vision paths.
  • Keep GatedDelta recurrent state in FP32 and add tests to catch multi-step prefill instability.
  • Pipeline prompt prefill chunking using asyncEval to reduce GPU starvation on long prompts; add minimal SDPA repro scripts.
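
The pipelining pattern described in the last bullet can be sketched without MLX. In the sketch below, `enqueue` and `synchronize` are hypothetical stand-ins: `enqueue` plays the role that asyncEval plays in the PR (schedule work and return immediately), and `synchronize` is the single blocking wait after the loop. The function names and the chunk size are assumptions, not the code in LLMModel.swift.

```swift
/// Splits a prompt of `tokenCount` tokens into consecutive chunks of at most `chunkSize`.
func prefillChunks(tokenCount: Int, chunkSize: Int) -> [Range<Int>] {
    precondition(chunkSize > 0, "chunkSize must be positive")
    return stride(from: 0, to: tokenCount, by: chunkSize).map { start in
        start..<min(start + chunkSize, tokenCount)
    }
}

/// Schedules every chunk without blocking, then waits once at the end.
/// `enqueue` and `synchronize` are hypothetical callbacks supplied by the caller.
func pipelinedPrefill(
    tokenCount: Int,
    chunkSize: Int,
    enqueue: (Range<Int>) -> Void,
    synchronize: () -> Void
) {
    let chunks = prefillChunks(tokenCount: tokenCount, chunkSize: chunkSize)
    guard !chunks.isEmpty else { return }  // nothing scheduled, so nothing to wait for
    for chunk in chunks {
        enqueue(chunk)  // returns immediately so the next chunk can be prepared
    }
    synchronize()  // one blocking sync after the loop instead of one per chunk
}

// Usage: record what would be scheduled for a 2500-token prompt.
var scheduled: [Range<Int>] = []
var syncCount = 0
pipelinedPrefill(
    tokenCount: 2500,
    chunkSize: 1024,
    enqueue: { scheduled.append($0) },
    synchronize: { syncCount += 1 }
)
print(scheduled, syncCount)
```

The point of the design is that the host never blocks between chunks, so the GPU queue stays populated; the early return on an empty chunk list mirrors the guard against scheduling empty work.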

Reviewed changes

Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.

Show a summary per file
| File | Description |
| --- | --- |
| test_gemma4_crash.swift | Adds a Swift reproducer for SDPA crash/regression investigation. |
| run_test.py | Adds a Python reproducer for SDPA crash/regression investigation. |
| Tests/MLXLMTests/GatedDeltaTests.swift | Adds regression tests for multi-step GatedDelta prefill finiteness and determinism. |
| Libraries/MLXVLM/Models/Gemma4.swift | Switches Gemma4 VLM MLP activations to safeGeluApproximate and uses compiled softcap with dtype alignment. |
| Libraries/MLXLMCommon/SwitchLayers.swift | Adjusts MoE expert intermediate computation dtype handling (bf16 casts) in quantized/SSD and fallback paths. |
| Libraries/MLXLMCommon/Optimizations.swift | Introduces safeGeluApproximate and a compiled softcap helper. |
| Libraries/MLXLLM/Models/Qwen3MoE.swift | Updates Qwen3 MoE MLP elementwise dtype handling to avoid unwanted promotions. |
| Libraries/MLXLLM/Models/Gemma4Text.swift | Switches Gemma4 text MLP activations to safeGeluApproximate, aligns scalar dtypes, and uses compiled softcap. |
| Libraries/MLXLLM/Models/GatedDelta.swift | Keeps state dtype separate from input dtype in the Metal kernel and forces FP32 state in updates. |
| Libraries/MLXLLM/LLMModel.swift | Uses asyncEval during chunked prefill and syncs once after the loop. |
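
For readers unfamiliar with the two elementwise operations named in the table, the following scalar sketch shows the standard formulas they are usually based on: the tanh approximation of GELU and tanh-based soft-capping. It is an assumption that the PR's safeGeluApproximate and softcap helpers compute these same functions; what makes the PR's variant "safe" is not shown on this page, and the real helpers operate on MLXArray values rather than scalars.

```swift
import Foundation

/// Tanh approximation of GELU:
/// 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3))).
func geluApproximate(_ x: Float) -> Float {
    let c = Float((2.0 / Double.pi).squareRoot())
    return 0.5 * x * (1 + tanh(c * (x + 0.044715 * x * x * x)))
}

/// Soft-capping: smoothly bounds a value to the interval [-cap, cap].
func softcap(_ x: Float, cap: Float) -> Float {
    cap * tanh(x / cap)
}

// The cap value 30 is an arbitrary example, not taken from the PR.
for x in [Float(-3), 0, 3, 1000] {
    print(x, geluApproximate(x), softcap(x, cap: 30))
}
```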


let xGate = qGate.computeExperts(x, buffers: usedGate, ranges: ranges)
let xUp = qUp.computeExperts(x, buffers: usedUp, ranges: ranges)
- let intermediate = activation(xGate) * xUp
+ let intermediate = activation(xGate.asType(.bfloat16)) * xUp.asType(.bfloat16)

Copilot AI Apr 20, 2026


Same dtype concern here: unconditionally downcasting xGate/xUp to .bfloat16 can change results for .float32 models and may not be appropriate on CPU. Consider gating the cast on the actual dtype/device (e.g. only when inputs are .float16 on GPU) so higher-precision paths keep their precision.

Suggested change
- let intermediate = activation(xGate.asType(.bfloat16)) * xUp.asType(.bfloat16)
+ let useBFloat16Intermediate = xGate.dtype == .float16 && xUp.dtype == .float16
+ let gateInput = useBFloat16Intermediate ? xGate.asType(.bfloat16) : xGate
+ let upInput = useBFloat16Intermediate ? xUp.asType(.bfloat16) : xUp
+ let intermediate = activation(gateInput) * upInput
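
The precision concern can be made concrete with a small plain-Swift sketch. `SketchDType`, `intermediateDType`, and `roundToBFloat16` are hypothetical names used only here (they are not MLX types or functions); bfloat16 is emulated by rounding a Float32 to the nearest value with a 7-bit explicit mantissa.

```swift
/// Hypothetical dtype tag used only for this sketch (not the MLX DType).
enum SketchDType { case float16, bfloat16, float32 }

/// Mirrors the gating in the suggestion: only float16 inputs move to bfloat16,
/// every other combination keeps the gate's dtype.
func intermediateDType(gate: SketchDType, up: SketchDType) -> SketchDType {
    (gate == .float16 && up == .float16) ? .bfloat16 : gate
}

/// Rounds a finite Float32 to the nearest bfloat16-representable value
/// (round-to-nearest-even) and returns it as a Float32.
func roundToBFloat16(_ x: Float) -> Float {
    let bits = x.bitPattern
    let bias: UInt32 = 0x7FFF + ((bits >> 16) & 1)
    return Float(bitPattern: (bits &+ bias) & 0xFFFF_0000)
}

// A float32 value that differs from 1.0 only below bfloat16 resolution.
let original: Float = 1.001
let downcast = roundToBFloat16(original)
print("float32:", original, "after bfloat16 rounding:", downcast)
print("float32 inputs stay:", intermediateDType(gate: .float32, up: .float32))
print("float16 inputs become:", intermediateDType(gate: .float16, up: .float16))
```

An unconditional cast would apply the rounding shown for `original` to every float32 intermediate, which is the regression the comment warns about; the gated version leaves those values untouched.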

Copilot uses AI. Check for mistakes.
Comment thread Libraries/MLXLLM/LLMModel.swift Outdated

var cacheArrays: [MLXArray] = []
for c in cache { cacheArrays.append(contentsOf: c.innerState()) }
asyncEval(cacheArrays)

Copilot AI Apr 20, 2026


cache can be empty (you already pass nil into the model when cache.isEmpty). In that case cacheArrays stays empty and this calls asyncEval([]). If asyncEval treats an empty list specially (or if it’s just wasted work), it’s safer to guard and only call asyncEval when there’s at least one array to evaluate.

Suggested change
- asyncEval(cacheArrays)
+ if !cacheArrays.isEmpty {
+     asyncEval(cacheArrays)
+ }

Comment thread test_gemma4_crash.swift Outdated
Comment on lines +16 to +22
do {
    let out = MLXFast.scaledDotProductAttention(queries: q, keys: k, values: v, scale: 1.0, mask: .array(mask))
    MLX.eval(out)
    print("SDPA success!")
} catch {
    print("Caught error: \(error)")
}

Copilot AI Apr 20, 2026


do/catch is ineffective here: the do block doesn’t contain any try, so the catch is unreachable and Swift flags it with a warning; if scaledDotProductAttention is actually declared throws, the call is missing try and will not compile. Either way, the reproducer does not behave as intended. Consider removing the do/catch entirely, or add the appropriate try/try? on the throwing call(s) and keep the error handling consistent with the MLX API’s throwing behavior.

Suggested change
- do {
-     let out = MLXFast.scaledDotProductAttention(queries: q, keys: k, values: v, scale: 1.0, mask: .array(mask))
-     MLX.eval(out)
-     print("SDPA success!")
- } catch {
-     print("Caught error: \(error)")
- }
+ let out = MLXFast.scaledDotProductAttention(queries: q, keys: k, values: v, scale: 1.0, mask: .array(mask))
+ MLX.eval(out)
+ print("SDPA success!")
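
For contrast, here is a minimal sketch of a do/catch whose catch block is reachable. `checkedAttention` and `HeadMismatch` are hypothetical and exist only for this example; they are not part of MLX. The catch can only run because the call inside do is marked with try and the callee is declared throws.

```swift
/// Hypothetical error type for this sketch.
struct HeadMismatch: Error {
    let queryHeads: Int
    let keyHeads: Int
}

/// Hypothetical throwing wrapper: rejects head counts that do not divide evenly.
func checkedAttention(queryHeads: Int, keyHeads: Int) throws -> String {
    guard keyHeads > 0, queryHeads % keyHeads == 0 else {
        throw HeadMismatch(queryHeads: queryHeads, keyHeads: keyHeads)
    }
    return "SDPA success!"
}

/// The `try` inside `do` is what makes the `catch` block reachable.
func runReproducer(queryHeads: Int, keyHeads: Int) -> String {
    do {
        return try checkedAttention(queryHeads: queryHeads, keyHeads: keyHeads)
    } catch {
        return "Caught error: \(error)"
    }
}

print(runReproducer(queryHeads: 8, keyHeads: 2))
print(runReproducer(queryHeads: 8, keyHeads: 3))
```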

Comment thread Tests/MLXLMTests/GatedDeltaTests.swift Outdated
eval(diff)

let maxDiff = diff.item(Float.self)
XCTAssertEqual(maxDiff, 0.0, "GDN kernel not deterministic across runs")

Copilot AI Apr 20, 2026


This test requires bitwise-identical outputs across two full forward passes (XCTAssertEqual(maxDiff, 0.0)). On GPU (and sometimes across different Metal/driver versions), tiny nondeterminism or math reordering can produce extremely small non-zero diffs, making the test flaky. Consider asserting maxDiff is below a small tolerance (e.g. XCTAssertLessThan(maxDiff, eps) or XCTAssertEqual(maxDiff, 0, accuracy: eps)), while keeping the existing finite/non-zero checks.

Suggested change
- XCTAssertEqual(maxDiff, 0.0, "GDN kernel not deterministic across runs")
+ let eps: Float = 1e-6
+ XCTAssertFalse(maxDiff.isNaN || maxDiff.isInfinite, "GDN kernel produced a non-finite diff across runs")
+ XCTAssertLessThanOrEqual(maxDiff, eps, "GDN kernel not deterministic across runs")
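
The same tolerance check can be expressed outside XCTest. The sketch below is plain Swift over arrays of Float rather than MLXArray values, and `matchesWithinTolerance` is a hypothetical helper name. It rejects non-finite differences explicitly, because a NaN difference compares false against any threshold and could otherwise slip past a naive comparison.

```swift
/// Returns true when both runs have the same length and every elementwise
/// difference is finite and no larger than `eps`.
func matchesWithinTolerance(_ a: [Float], _ b: [Float], eps: Float = 1e-6) -> Bool {
    guard a.count == b.count else { return false }
    for (x, y) in zip(a, b) {
        let diff = abs(x - y)
        // NaN or infinite differences fail explicitly instead of relying on `<=`.
        if !diff.isFinite || diff > eps { return false }
    }
    return true
}

// Usage with made-up outputs from two runs.
let firstRun: [Float] = [0.5, -1.25, 3.0]
let secondRun: [Float] = [0.5 + 5e-7, -1.25, 3.0]
print(matchesWithinTolerance(firstRun, secondRun))
print(matchesWithinTolerance(firstRun, [0.5, -1.25, .nan]))
```

The value of `eps` is a judgment call: it should sit above the run-to-run noise observed on the target hardware while staying well below the size of the regression the test is meant to catch.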

let xGate = qGate.computeExperts(x, buffers: usedGate, ranges: ranges)
let xUp = qUp.computeExperts(x, buffers: usedUp, ranges: ranges)
- let intermediate = activation(xGate) * xUp
+ let intermediate = activation(xGate.asType(.bfloat16)) * xUp.asType(.bfloat16)

Copilot AI Apr 20, 2026


Casting xGate and xUp to .bfloat16 unconditionally before applying the activation/multiply changes numerics for non-.float16 execution (e.g. .float32 runs) and may be an unintended accuracy regression outside the specific Metal auto-promotion case. Consider making the cast conditional (only when the arrays are .float16), or otherwise preserving the original dtype when higher precision is already in use.

Suggested change
- let intermediate = activation(xGate.asType(.bfloat16)) * xUp.asType(.bfloat16)
+ let activatedGate = activation(xGate.dtype == .float16 ? xGate.asType(.bfloat16) : xGate)
+ let multipliedUp = xUp.dtype == .float16 ? xUp.asType(.bfloat16) : xUp
+ let intermediate = activatedGate * multipliedUp

Aegis-AI added 5 commits April 20, 2026 16:33
… SwitchLayers fusions to strict contiguous array signatures
…prevent gatherQuantizedMM NaN corruption on pre-M4 GPUs
…explicitly to preserve numeric approximations across Apple GPU boundaries

2 participants