Sources/MLXInferenceCore/GenerationConfig.swift (19 additions, 1 deletion)
@@ -50,6 +50,20 @@ public struct GenerationConfig: Sendable, Codable {
/// force-disable streaming even on MoE models.
public var streamExperts: Bool

/// Enable MTP (Multi-Token Prediction) speculative decoding.
/// When true, the inference engine will use the model's internal MTP heads
/// to draft `numMTPTokens` candidate tokens per step, then verify them in
/// a single batched forward pass — targeting 2x+ throughput improvement.
/// Requires a checkpoint that retains `mtp.*` weights (set SWIFTLM_MTP_ENABLE=1
/// at model-load time). No-ops gracefully if the model does not conform to
/// `MTPLanguageModel`.
/// ⚠️ LOAD-TIME flag: changes take effect on the next model load.
public var enableMTP: Bool

/// Number of tokens the MTP heads draft per speculation round (default 1).
/// Higher values increase potential speedup but also increase rejection rate.
public var numMTPTokens: Int

public init(
maxTokens: Int = 2048,
temperature: Float = 0.6,
@@ -63,7 +77,9 @@
kvBits: Int? = nil,
kvGroupSize: Int = 64,
turboKV: Bool = false,
-    streamExperts: Bool = false
+    streamExperts: Bool = false,
+    enableMTP: Bool = false,
+    numMTPTokens: Int = 1
) {
self.maxTokens = maxTokens
self.temperature = temperature
@@ -78,6 +94,8 @@
self.kvGroupSize = kvGroupSize
self.turboKV = turboKV
self.streamExperts = streamExperts
self.enableMTP = enableMTP
self.numMTPTokens = numMTPTokens
}

public static let `default` = GenerationConfig()
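For orientation, a caller would opt into the new fields like this; a minimal sketch against the initializer above, not code from this PR:

    import MLXInferenceCore

    // enableMTP is a load-time flag: the checkpoint must have been loaded with
    // SWIFTLM_MTP_ENABLE=1 so the mtp.* weights are retained. If the model does
    // not conform to MTPLanguageModel, generation falls back to normal decoding.
    let config = GenerationConfig(
        maxTokens: 2048,
        temperature: 0.6,
        enableMTP: true,
        numMTPTokens: 2
    )

On the numMTPTokens trade-off: under the usual independence assumption with per-token acceptance rate p, a round that drafts k tokens yields about (1 - p^(k+1)) / (1 - p) tokens per verify pass, so raising k has diminishing returns once rejections dominate.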
Sources/MLXInferenceCore/InferenceEngine.swift (63 additions, 7 deletions)
@@ -105,6 +105,20 @@ public struct GenerationToken: Sendable {
}
}

// MARK: - Inference Metrics

/// Live performance counters updated at the end of each generation pass.
public struct InferenceMetrics: Sendable {
/// Time from the start of generation to the first decoded token (seconds).
public var ttft: Double
/// Prompt / prefill throughput (tokens per second).
public var prefillToksPerSec: Double
/// Decode throughput — tokens generated per second after the first token.
public var decodeToksPerSec: Double

public static let zero = InferenceMetrics(ttft: 0, prefillToksPerSec: 0, decodeToksPerSec: 0)
}
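
A client observing lastMetrics (published below) might format these counters roughly as follows; a sketch, not part of the PR:

    import Foundation

    // Render the three counters as a one-line status string.
    func statusLine(_ m: InferenceMetrics) -> String {
        return String(format: "ttft %.2fs · prefill %.0f tok/s · decode %.1f tok/s",
                      m.ttft, m.prefillToksPerSec, m.decodeToksPerSec)
    }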

// MARK: - InferenceEngine

@MainActor
Expand All @@ -113,6 +127,8 @@ public final class InferenceEngine: ObservableObject {
@Published public private(set) var thermalLevel: ThermalLevel = .nominal
@Published public private(set) var activeContextTokens: Int = 0
@Published public private(set) var maxContextWindow: Int = 0
/// Performance counters from the most recent completed generation.
@Published public private(set) var lastMetrics: InferenceMetrics = .zero

/// Set when a corrupted/truncated model is detected during inference.
/// The UI should observe this and offer to delete & re-download.
@@ -587,6 +603,10 @@ extension InferenceEngine {
var outputText = ""
var tokenCount = 0

// ── Metrics timing ──────────────────────────────────────
let generationStart = Date()
var firstTokenDate: Date? = nil

// Set RNG seed for reproducible output when requested.
if let seed = config.seed {
MLX.seed(seed)
@@ -627,21 +647,39 @@
}

let stream: AsyncStream<Generation> = try await container.perform { ctx in
-    try MLXLMCommon.generate(
-        input: lmInput,
-        cache: cache,
-        parameters: params,
-        context: ctx
-    )
+    // MTP speculative decoding path: use MTPTokenIterator when
+    // 1. The config requests MTP (enableMTP=true)
+    // 2. The loaded model conforms to MTPLanguageModel
+    if config.enableMTP, ctx.model is (any MTPLanguageModel) {
+        return try MLXLMCommon.generateMTP(
+            input: lmInput,
+            cache: cache,
+            parameters: params,
+            context: ctx,
+            numMTPTokens: config.numMTPTokens
+        )
+    } else {
+        return try MLXLMCommon.generate(
+            input: lmInput,
+            cache: cache,
+            parameters: params,
+            context: ctx
+        )
+    }
}

for await generation in stream {
guard !Task.isCancelled else { break }

if case .chunk(let text, tokenId: _) = generation {
// Record time-to-first-token on the very first chunk
if firstTokenDate == nil {
firstTokenDate = Date()
}

outputText += text
tokenCount += 1

// Update the UI token counter periodically to save CPU
if tokenCount % 10 == 0 {
self.activeContextTokens = baseTokens + tokenCount
@@ -669,6 +707,24 @@ extension InferenceEngine {
continuation.yield(GenerationToken(text: text, isThinking: thinkingActive))
}
}

// ── Publish metrics for the completed turn ───────────────
let totalElapsed = Date().timeIntervalSince(generationStart)
let ttft = firstTokenDate.map { $0.timeIntervalSince(generationStart) } ?? 0
// Prefill throughput: prompt tokens / time-to-first-token
let prefillTps = (ttft > 0 && baseTokens > 0)
? Double(baseTokens) / ttft
: 0
// Decode throughput: generated tokens / time spent decoding
let decodeElapsed = totalElapsed - ttft
let decodeTps = (decodeElapsed > 0 && tokenCount > 1)
? Double(tokenCount - 1) / decodeElapsed
: 0
self.lastMetrics = InferenceMetrics(
ttft: ttft,
prefillToksPerSec: prefillTps,
decodeToksPerSec: decodeTps
)
} catch let ssdError as SSDStreamingError {
// Corrupted/truncated safetensors — surface a clear, actionable error
let msg = "Model weights are corrupted or incomplete. Please re-download the model."
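generateMTP itself is not part of this diff. As a mental model, one speculation round of draft-and-verify reduces to the toy below, written over plain Int token IDs; the draft and verify closures are hypothetical stand-ins for the MTP heads and the main model's batched forward pass:

    // One draft-and-verify round. `draft` proposes k cheap candidate tokens;
    // `verify` returns the main model's greedy token at each draft position,
    // all scored in a single batched pass.
    func speculationRound(
        context: [Int],
        draftCount: Int,
        draft: ([Int], Int) -> [Int],
        verify: ([Int], [Int]) -> [Int]
    ) -> [Int] {
        let proposed = draft(context, draftCount)
        let targets = verify(context, proposed)
        var accepted: [Int] = []
        for (p, t) in zip(proposed, targets) {
            guard p == t else { break }   // first mismatch ends acceptance
            accepted.append(p)
        }
        // Keep the verifier's token at the mismatch position, so every round
        // emits at least one token and output matches non-speculative decoding.
        if accepted.count < targets.count {
            accepted.append(targets[accepted.count])
        }
        return accepted
    }

    // Toy check: drafts agree with the verifier on 2 of 3 positions.
    let out = speculationRound(
        context: [1, 2, 3],
        draftCount: 3,
        draft: { ctx, _ in let n = ctx.last ?? 0; return [n + 1, n + 2, 99] },
        verify: { ctx, drafts in let n = ctx.last ?? 0; return (1...drafts.count).map { n + $0 } }
    )
    // out == [4, 5, 6]: two drafts accepted, the third corrected by the verifier.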
Sources/SwiftLM/Server.swift (36 additions, 9 deletions)
@@ -280,6 +280,12 @@ struct MLXServer: AsyncParsableCommand {
@Option(name: .long, help: "DFlash block size (number of tokens per draft block). Default: use draft model's configured block_size.")
var dflashBlockSize: Int?

@Flag(name: .long, help: "Enable Multi-Token Prediction (MTP) speculative decoding.")
var mtp: Bool = false

@Option(name: .long, help: "Number of MTP tokens to draft per speculation round (default: 3).")
var numMtpTokens: Int = 3

mutating func run() async throws {
// Raise the open-file limit: large sharded models (e.g. Kimi K2.5, 182 safetensor
// shards) + draft model + metallib + dylibs can exhaust the default macOS FD limit of 256.
@@ -295,10 +301,14 @@
// This env var must be set before MLX's Metal backend initializes.
// Value 50 splits large computation graphs into ~1-layer chunks so macOS
// can page in weights incrementally without exceeding the watchdog timeout.
-        if self.draftModel != nil || self.streamExperts {
+        if self.draftModel != nil || self.streamExperts || self.mtp {
setenv("MLX_MAX_OPS_PER_BUFFER", "50", 1)
}

if self.mtp {
setenv("SWIFTLM_MTP_ENABLE", "1", 1)
}

// Register SwiftLM-owned DFlash model types before any model loading.
await registerDFlashModelTypes()

@@ -766,7 +776,9 @@ struct MLXServer: AsyncParsableCommand {
thinking: self.thinking,
isVision: isVision,
prefillSize: self.prefillSize,
-            turboKV: self.turboKV
+            turboKV: self.turboKV,
+            mtp: self.mtp,
+            numMtpTokens: self.numMtpTokens
)

let parallelSlots = self.parallel
@@ -797,7 +809,8 @@ struct MLXServer: AsyncParsableCommand {
let thinkingStr = config.thinking ? "enabled" : "disabled"
let ssdStr = self.streamExperts ? "enabled" : "disabled"
let turboKVStr = config.turboKV ? "enabled" : "disabled"
-        print("[SwiftLM] Config: ctx_size=\(ctxSizeStr), temp=\(config.temp), top_p=\(config.topP), top_k=\(topKStr), min_p=\(minPStr), repeat_penalty=\(penaltyStr), parallel=\(parallelSlots), cors=\(corsStr), mem_limit=\(memLimitStr), auth=\(authStr), thinking=\(thinkingStr), ssd_stream=\(ssdStr), turbo_kv=\(turboKVStr)")
+        let mtpStr = config.mtp ? "enabled (\(config.numMtpTokens) tokens/round)" : "disabled"
+        print("[SwiftLM] Config: ctx_size=\(ctxSizeStr), temp=\(config.temp), top_p=\(config.topP), top_k=\(topKStr), min_p=\(minPStr), repeat_penalty=\(penaltyStr), parallel=\(parallelSlots), cors=\(corsStr), mem_limit=\(memLimitStr), auth=\(authStr), thinking=\(thinkingStr), ssd_stream=\(ssdStr), turbo_kv=\(turboKVStr), mtp=\(mtpStr)")

// ── Build Hummingbird router ──
let router = Router()
@@ -1044,6 +1057,8 @@ struct ServerConfig: Sendable {
let prefillSize: Int
/// When true, each KVCacheSimple layer compresses history > 8192 tokens to 3-bit PolarQuant.
let turboKV: Bool
/// Enable MTP (Multi-Token Prediction) speculative decoding when the loaded model supports it.
let mtp: Bool
/// Number of tokens the MTP heads draft per speculation round.
let numMtpTokens: Int
}

// ── SSD Memory Budget ────────────────────────────────────────────────────────
@@ -1584,14 +1599,26 @@ func handleChatCompletion(
}
let remainingTokens = lmInput.text.tokens[startIndex...]
let trimmedInput = LMInput(tokens: remainingTokens)
-        stream = try MLXLMCommon.generate(
-            input: trimmedInput, cache: cache, parameters: params, context: context
-        )
+        if config.mtp, context.model is any MTPLanguageModel {
+            stream = try MLXLMCommon.generateMTP(
+                input: trimmedInput, cache: cache, parameters: params, context: context,
+                numMTPTokens: config.numMtpTokens
+            )
+        } else {
+            stream = try MLXLMCommon.generate(
+                input: trimmedInput, cache: cache, parameters: params, context: context
+            )
+        }
} else {
// Cache miss: process the full prompt.
-        stream = try MLXLMCommon.generate(
-            input: lmInput, cache: cache, parameters: params, context: context
-        )
+        if config.mtp, context.model is any MTPLanguageModel {
+            stream = try MLXLMCommon.generateMTP(
+                input: lmInput, cache: cache, parameters: params, context: context,
+                numMTPTokens: config.numMtpTokens
+            )
+        } else {
+            stream = try MLXLMCommon.generate(
+                input: lmInput, cache: cache, parameters: params, context: context
+            )
+        }
}

// Return a closure that will save the cache state synchronously AFTER
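End to end, the server wiring above means that a launch passing `--mtp --num-mtp-tokens 3` (ArgumentParser derives these long option names from the `mtp` and `numMtpTokens` properties) sets SWIFTLM_MTP_ENABLE=1 before the model loads, so the `mtp.*` weights are retained, and handleChatCompletion then routes each request through generateMTP whenever the loaded model conforms to MTPLanguageModel; on other models the flag is a no-op and the standard generate path is used.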