14 changes: 0 additions & 14 deletions README.md
@@ -73,20 +73,6 @@ Benchmark results for `gemma-4-26b-a4b-it-4bit` (26B MoE, 4-bit) on M5 Pro 64 GB

> Run `./run_benchmark.sh` to generate these metrics on your own device. (See **Benchmarks & Testing** below).

### Qwen3.6-35B-A3B-UD-MLX-4bit (Full-RAM) — M1 Ultra 64 GB

Benchmark results for full-RAM (no SSD streaming) MoE inference on M1 Ultra. The 3.4× vanilla improvement vs. earlier builds comes from the `needsMoeFlush` gate in `mlx-swift-lm` (see [SwiftLM #84](https://github.com/SharpAI/SwiftLM/issues/84)) — the per-layer GPU sync barrier required for SSD streaming was firing unconditionally on the full-RAM path and flushing MLX's kernel-batching pipeline.
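
For context, a minimal sketch of what such a gate looks like conceptually. `needsMoeFlush`, the layer structure, and the `eval` placement here are illustrative assumptions, not the actual `mlx-swift-lm` code:

```swift
import MLX

// Illustrative sketch only — names and structure are assumptions, not mlx-swift-lm's API.
struct MoELayer {
    /// Set only when experts are streamed from SSD, where a per-layer sync
    /// barrier is needed to bound the number of in-flight expert reads.
    let needsMoeFlush: Bool

    func callAsFunction(_ hidden: MLXArray) -> MLXArray {
        let output = hidden // stand-in for expert routing + weighted combine
        if needsMoeFlush {
            // Barrier only on the SSD-streaming path. Running this
            // unconditionally also flushed MLX's kernel batching on the
            // full-RAM path, which is what capped full-RAM throughput.
            eval(output)
        }
        return output
    }
}
```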

| Configuration | Short (~126 tok) | Medium (~400 tok) | Long (~800 tok) |
|---|---|---|---|
| **Vanilla full-GPU** | **61.7 tok/s** | **62.3 tok/s** | **62.1 tok/s** |

> *Hardware:* Apple M1 Ultra, 64 GB unified memory, macOS 26.x. Model ~20 GB on disk, ~21.6 GB resident weight + ~2.1 GB KV at runtime.
> *Flags:* `--repeat-penalty 1.1 --max-tokens 2000`, `temperature: 0.6`, single-stream `/v1/chat/completions`.
> *Vanilla baseline before* `needsMoeFlush` *gate (for reference):* 19.2 / 18.1 / 18.3 tok/s — see #84.

> ⚠️ **DFlash on this model is currently unsuitable for production.** DFlash uses pure greedy (`argMax`) decoding regardless of `temperature`, which on Qwen3.6-35B-A3B + the [`z-lab/Qwen3.6-35B-A3B-DFlash`](https://huggingface.co/z-lab/Qwen3.6-35B-A3B-DFlash) draft locks into low-entropy attractors (`"and and and..."`, `"**UMA** **UMA**..."`). Earlier 70 tok/s DFlash numbers were degenerate output that scored high acceptance because draft and target both committed to the same locked-in token. Repetition-penalty mitigation works on some prompts but tanks acceptance on others — the proper fix is stochastic posterior sampling with rejection-based accept ([Leviathan/Chen](https://arxiv.org/abs/2211.17192) formulation), which is a DFlash architecture change tracked at [z-lab/dflash#91](https://github.com/z-lab/dflash/issues/91).
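
For reference, the rejection-based accept step referred to above works roughly as follows. This is a sketch under assumed names (`pTarget`/`pDraft` as full next-token distributions), not DFlash's actual interfaces:

```swift
import Foundation

// Leviathan/Chen accept-or-resample step for a single drafted token.
func acceptOrResample(draftToken: Int, pTarget: [Double], pDraft: [Double]) -> Int {
    // Accept the draft token with probability min(1, p_target / p_draft).
    let acceptProb = min(1.0, pTarget[draftToken] / max(pDraft[draftToken], 1e-12))
    if Double.random(in: 0..<1) < acceptProb {
        return draftToken
    }
    // On rejection, sample from the residual max(0, p_target - p_draft),
    // renormalized — the output distribution then matches the target model
    // exactly, so temperature still applies and repetition loops can't
    // self-reinforce through the draft.
    var residual = zip(pTarget, pDraft).map { max(0.0, $0 - $1) }
    let z = residual.reduce(0, +)
    residual = z > 0 ? residual.map { $0 / z } : pTarget
    var r = Double.random(in: 0..<1)
    for (token, p) in residual.enumerated() {
        r -= p
        if r <= 0 { return token }
    }
    return residual.indices.last ?? draftToken
}
```

Greedy `argMax` acceptance, by contrast, lets draft and target commit to the same repeated token, which is why high acceptance rates on this model were not evidence of healthy output.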

### DeepSeek-V4-Flash (126 GB, Q3-mixed-gs128-affine) — M5 Pro 64 GB

Model: [`Thump604/DeepSeek-V4-Flash-MLX-Q3-mixed-gs128-affine`](https://huggingface.co/Thump604/DeepSeek-V4-Flash-MLX-Q3-mixed-gs128-affine)
1 change: 1 addition & 0 deletions Sources/DFlash/DFlashIntermediateDumper.swift
@@ -42,6 +42,7 @@ public enum DFlashDumper {
eval(floatArr)

let shape = (0..<floatArr.ndim).map { floatArr.dim($0) }
let totalElements = shape.reduce(1, *)

// Build spec-compliant .npy header: shape must be a Python tuple,
// spaces pad before the final newline byte.
2 changes: 1 addition & 1 deletion Sources/DFlash/DFlashRuntime.swift
@@ -329,7 +329,7 @@ public enum DFlashRuntime {

let draftBackend = DFlashDraftBackend()

let targetCache = makeTargetCache(targetModel: targetModel)
var targetCache = makeTargetCache(targetModel: targetModel)

let draftCache = draftBackend.makeCache(
draftModel: draftModel,
30 changes: 1 addition & 29 deletions Sources/MLXInferenceCore/GenerationConfig.swift
@@ -12,24 +12,6 @@ public struct GenerationConfig: Sendable {
public var seed: UInt64?
public var enableThinking: Bool

// ── SwiftLM Engine Parameters ──────────────────────────────────────
/// Enable TurboQuant KV-cache compression (3-bit PolarQuant+QJL).
/// Compresses KV history > 8192 tokens to ~3.5 bits/token.
public var turboKV: Bool

/// Enable SSD expert streaming for MoE models.
public var streamExperts: Bool

/// Chunk size for prefill evaluation.
/// Lower values prevent GPU timeout on large models.
public var prefillSize: Int

/// KV-cache quantization bits (nil = no quantization, 4 or 8 typical).
public var kvBits: Int?

/// KV-cache quantization group size (default 64).
public var kvGroupSize: Int

public init(
maxTokens: Int = 2048,
temperature: Float = 0.6,
@@ -38,12 +20,7 @@
minP: Float = 0.0,
repetitionPenalty: Float = 1.05,
seed: UInt64? = nil,
enableThinking: Bool = false,
turboKV: Bool = false,
streamExperts: Bool = false,
prefillSize: Int = 512,
kvBits: Int? = nil,
kvGroupSize: Int = 64
enableThinking: Bool = false
) {
self.maxTokens = maxTokens
self.temperature = temperature
Expand All @@ -53,11 +30,6 @@ public struct GenerationConfig: Sendable {
self.repetitionPenalty = repetitionPenalty
self.seed = seed
self.enableThinking = enableThinking
self.turboKV = turboKV
self.streamExperts = streamExperts
self.prefillSize = prefillSize
self.kvBits = kvBits
self.kvGroupSize = kvGroupSize
}

public static let `default` = GenerationConfig()
186 changes: 16 additions & 170 deletions Sources/MLXInferenceCore/InferenceEngine.swift
@@ -114,10 +114,6 @@ public final class InferenceEngine: ObservableObject {
@Published public private(set) var activeContextTokens: Int = 0
@Published public private(set) var maxContextWindow: Int = 0

/// Set when a corrupted/truncated model is detected during inference.
/// The UI should observe this and offer to delete & re-download.
@Published public var corruptedModelId: String? = nil

/// Whether to automatically unload the model when the app backgrounds
/// and reload it when returning to foreground.
/// Defaults to true on iOS (prevents jetsam), false on macOS.
@@ -281,44 +277,7 @@
state = .error("Device is too hot. Let it cool before loading a model.")
return
}
corruptedModelId = nil

guard ModelStorage.verifyModelIntegrity(for: modelId) else {
await downloadThenLoad(modelId: modelId)
return
}

await loadVerifiedModel(modelId: modelId)
}

private func downloadThenLoad(modelId: String) async {
print("[InferenceEngine] Model \(modelId) is missing or incomplete. Starting download before load.")
releaseLoadedModelResources()
state = .downloading(progress: 0.0, speed: "Preparing...")

let task = downloadManager.startDownload(modelId: modelId)

do {
try await task.value
state = .downloading(progress: 1.0, speed: "Verifying...")

guard ModelStorage.verifyModelIntegrity(for: modelId) else {
markModelCorrupted(
modelId: modelId,
message: "Model files are incomplete after download. Choose a recovery option."
)
return
}

await loadVerifiedModel(modelId: modelId)
} catch is CancellationError {
state = .idle
} catch {
state = .error("Failed to download \(modelId): \(error.localizedDescription)")
}
}

private func loadVerifiedModel(modelId: String) async {
state = .loading
currentModelId = modelId

@@ -353,29 +312,25 @@
downloader: downloader
)

let speedTracker = DownloadSpeedTracker()

if architecture.supportsVision {
container = try await VLMModelFactory.shared.loadContainer(
from: downloader,
using: TransformersTokenizerLoader(),
configuration: config
) { [weak self] progress in
speedTracker.record(totalBytes: progress.completedUnitCount)
let smoothedSpeed = speedTracker.speedBytesPerSec

Task { @MainActor in
guard let self else { return }
let pct = progress.fractionCompleted
let speedStr = smoothedSpeed
let speedBytesPerSec = progress.userInfo[ProgressUserInfoKey("throughputKey")] as? Double
let speedStr = speedBytesPerSec
.map { String(format: "%.1f MB/s", $0 / 1_000_000) } ?? ""
self.state = .downloading(progress: pct, speed: speedStr)

self.downloadManager.updateProgress(ModelDownloadProgress(
modelId: modelId,
fractionCompleted: pct,
currentFile: "",
speedMBps: smoothedSpeed.map { $0 / 1_000_000 }
speedMBps: speedBytesPerSec.map { $0 / 1_000_000 }
))
}
}
@@ -385,21 +340,19 @@
using: TransformersTokenizerLoader(),
configuration: config
) { [weak self] progress in
speedTracker.record(totalBytes: progress.completedUnitCount)
let smoothedSpeed = speedTracker.speedBytesPerSec

Task { @MainActor in
guard let self else { return }
let pct = progress.fractionCompleted
let speedStr = smoothedSpeed
let speedBytesPerSec = progress.userInfo[ProgressUserInfoKey("throughputKey")] as? Double
let speedStr = speedBytesPerSec
.map { String(format: "%.1f MB/s", $0 / 1_000_000) } ?? ""
self.state = .downloading(progress: pct, speed: speedStr)

self.downloadManager.updateProgress(ModelDownloadProgress(
modelId: modelId,
fractionCompleted: pct,
currentFile: "",
speedMBps: smoothedSpeed.map { $0 / 1_000_000 }
speedMBps: speedBytesPerSec.map { $0 / 1_000_000 }
))
}
}
@@ -408,85 +361,26 @@
downloadManager.clearProgress(modelId: modelId)
downloadManager.lastLoadedModelId = modelId
downloadManager.refresh()

// Verify integrity to catch incomplete downloads before marking as ready
guard ModelStorage.verifyModelIntegrity(for: modelId) else {
throw NSError(domain: "InferenceEngine", code: 1, userInfo: [NSLocalizedDescriptionKey: "Model safetensors files are incomplete. Please delete and re-download."])
}

// Read the model's actual max context length from config.json
if let ctxLen = ModelStorage.readMaxContextLength(for: modelId) {
self.maxContextWindow = ctxLen
print("[InferenceEngine] Model context window: \(ctxLen) tokens")
} else {
self.maxContextWindow = 8192 // conservative fallback for models without explicit limits
print("[InferenceEngine] No explicit context limit found in config.json, defaulting to 8192")
}

state = .ready(modelId: modelId)

} catch {
ExpertStreamingConfig.shared.deactivate()
downloadManager.clearProgress(modelId: modelId)
state = .error("Failed to load \(modelId): \(error.localizedDescription)")

// If the model is incomplete/corrupted, flag it so the UI shows the "Delete & Re-download" button
let nsError = error as NSError
if nsError.domain == "InferenceEngine" && nsError.code == 1 || Self.isModelCorruptionError(error) {
markModelCorrupted(
modelId: modelId,
message: "Model weights are corrupted or incomplete. Choose a recovery option."
)
return
}

container = nil
self.maxContextWindow = 0
self.activeContextTokens = 0
}
}

/// Unload the current model and free all GPU memory.
public func unload() {
releaseLoadedModelResources()
corruptedModelId = nil
state = .idle
}

private func releaseLoadedModelResources() {
generationTask?.cancel()
generationTask = nil
container = nil
currentModelId = nil
maxContextWindow = 0
activeContextTokens = 0
state = .idle
ExpertStreamingConfig.shared.deactivate()
MLX.Memory.cacheLimit = 0
}

private func markModelCorrupted(modelId: String?, message: String) {
let failedModelId = modelId ?? currentModelId
releaseLoadedModelResources()
state = .error(message)
corruptedModelId = failedModelId
}

private static func isModelCorruptionError(_ error: Error) -> Bool {
let description = error.localizedDescription.lowercased()
return description.contains("ssd streaming")
|| description.contains("pread")
|| description.contains("safetensors")
|| description.contains("corrupt")
|| description.contains("incomplete")
}

public func clearCorruptionRecovery() {
corruptedModelId = nil
if case .error = state {
state = .idle
}
}

// MARK: — Generation

public nonisolated func generate(
@@ -528,17 +422,11 @@
}

let mlxMessages = finalMessages
var params = GenerateParameters(
maxTokens: config.maxTokens,
kvBits: config.kvBits,
kvGroupSize: config.kvGroupSize,
temperature: config.temperature,
topP: config.topP,
topK: config.topK,
minP: config.minP,
repetitionPenalty: config.repetitionPenalty,
prefillStepSize: config.prefillSize
)
var params = GenerateParameters(temperature: config.temperature)
params.topP = config.topP
params.topK = config.topK
params.minP = config.minP
params.repetitionPenalty = config.repetitionPenalty
params.repetitionContextSize = 20

var thinkingActive = false
@@ -554,7 +442,9 @@
let baseTokens = Int(Double(stringLength) / 3.5)
self.activeContextTokens = baseTokens

// maxContextWindow is already set during loadModel() from config.json
// If we have a max length config, expose it
// TODO: Safely extract from ModelConfiguration when MLX exposes it dynamically
self.maxContextWindow = 8192

let stream: AsyncStream<Generation> = try await container.generate(
input: lmInput,
@@ -595,30 +485,11 @@
continuation.yield(GenerationToken(text: text, isThinking: thinkingActive))
}
}
} catch let ssdError as SSDStreamingError {
// Corrupted/truncated safetensors — surface a clear, actionable error
let msg = "Model weights are corrupted or incomplete. Please re-download the model."
print("[InferenceEngine] SSD Streaming Error: \(ssdError.localizedDescription)")
continuation.yield(GenerationToken(text: "\n\n[Error: \(msg)]"))
self.markModelCorrupted(modelId: self.currentModelId, message: msg)
} catch {
// Check if the generic error is also an SSD streaming issue
if Self.isModelCorruptionError(error) {
let msg = "Model weights are corrupted or incomplete. Please re-download the model."
self.markModelCorrupted(modelId: self.currentModelId, message: msg)
}
continuation.yield(GenerationToken(text: "\n\n[Error: \(error.localizedDescription)]"))
}

if let latchedError = SSDStreamingErrorLatch.shared.consume() {
let msg = "Model weights are corrupted or incomplete. Please re-download the model."
print("[InferenceEngine] Latched SSD error after generation: \(latchedError.localizedDescription)")
self.markModelCorrupted(modelId: self.currentModelId, message: msg)
} else if case .error = self.state {
// Already in error state from catch block above
} else {
self.state = self.currentModelId.map { .ready(modelId: $0) } ?? .idle
}
self.state = self.currentModelId.map { .ready(modelId: $0) } ?? .idle
continuation.finish()
}
}
@@ -629,29 +500,4 @@
generationTask = nil
if let id = currentModelId { state = .ready(modelId: id) }
}

/// Delete corrupted model files and start a fresh download.
/// Called from the UI when the user confirms re-download after corruption is detected.
public func deleteCorruptedAndRedownload() {
guard let modelId = corruptedModelId else { return }

releaseLoadedModelResources()
state = .downloading(progress: 0.0, speed: "Deleting corrupted files...")

do {
try ModelStorage.delete(modelId)
print("[InferenceEngine] Successfully deleted corrupted cache directory for \(modelId).")
} catch {
print("[InferenceEngine] FAILED to delete corrupted cache: \(error.localizedDescription)")
state = .error("Failed to delete corrupted model: \(error.localizedDescription)")
return
}
downloadManager.refresh()
corruptedModelId = nil

print("[InferenceEngine] Deleted corrupted files for \(modelId), starting fresh download")
Task { @MainActor in
await downloadThenLoad(modelId: modelId)
}
}
}