Skip to content
Merged
27 changes: 15 additions & 12 deletions Sources/MLXInferenceCore/InferenceEngine.swift
Original file line number Diff line number Diff line change
Expand Up @@ -613,24 +613,27 @@ extension InferenceEngine {

// maxContextWindow is already set during loadModel() from config.json

// TurboKV: enable 3-bit PolarQuant+QJL on every KVCacheSimple layer
// before generation. Must be set on the model (not the cache) so the
// cache inherits the flag when newCache() is called inside generate().
// TurboKV: enable 3-bit PolarQuant+QJL on every KVCacheSimple cache layer.
// KVCacheSimple is a cache object (not a neural-network Module), so we
// iterate the cache array — mirroring the pattern in Server.swift.
let cache = await container.perform { ctx in ctx.model.newCache(parameters: params) }
if config.turboKV {
await container.perform { ctx in
for module in ctx.model.modules() {
if let simple = module as? KVCacheSimple {
simple.turboQuantEnabled = true
}
for layer in cache {
if let simple = layer as? KVCacheSimple {
simple.turboQuantEnabled = true
}
}
print("[InferenceEngine] TurboKV enabled for this request")
}

let stream: AsyncStream<Generation> = try await container.generate(
input: lmInput,
parameters: params
)
let stream: AsyncStream<Generation> = try await container.perform { ctx in
try MLXLMCommon.generate(
input: lmInput,
cache: cache,
parameters: params,
context: ctx
)
}

for await generation in stream {
guard !Task.isCancelled else { break }
Expand Down
21 changes: 14 additions & 7 deletions SwiftBuddy/SwiftBuddy/ViewModels/ServerManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ final class ServerManager: ObservableObject {
guard !isOnline else { return }
let configuration = startupConfiguration.normalized

task = Task {
task = Task.detached { [weak self] in
guard let self = self else { return }
do {
let router = Router()

Expand Down Expand Up @@ -259,18 +260,24 @@ final class ServerManager: ObservableObject {
configuration: .init(address: .hostname(configuration.host, port: configuration.port))
)

self.isOnline = true
self.host = configuration.host
self.port = configuration.port
self.runningConfiguration = configuration
self.restartRequired = false
await MainActor.run {
self.isOnline = true
self.host = configuration.host
self.port = configuration.port
self.runningConfiguration = configuration
self.restartRequired = false
}
ConsoleLog.shared.info("Server online at http://\(configuration.host):\(configuration.port)")

try await app.runService()
} catch {
print("Server failed: \(error)")
ConsoleLog.shared.error("Server failed: \(error.localizedDescription)")
self.isOnline = false
await MainActor.run {
self.isOnline = false
self.runningConfiguration = nil
self.restartRequired = false
}
}
}
}
Expand Down
154 changes: 88 additions & 66 deletions SwiftBuddy/SwiftBuddy/Views/SettingsView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,24 @@ struct SettingsView: View {
return ModelCatalog.all.first(where: { $0.id == modelId })?.isMoE ?? false
}

// The id of the model the engine currently has loaded, or `nil`
// whenever the engine is in any state other than `.ready`.
private var currentModelId: String? {
    if case .ready(let modelId) = engine.state {
        return modelId
    }
    return nil
}

// Resolves the stream-experts toggle: an explicit user setting wins;
// otherwise fall back to whether the current model is a MoE catalog model.
private var effectiveStreamExpertsSetting: Bool {
    let moeDefault = currentModelIsMoE
    return viewModel.config.effectiveStreamExperts(defaultingTo: moeDefault)
}

// Tracks the stream-experts value that was in effect when the current model was loaded.
// A mismatch with `effectiveStreamExpertsSetting` means a reload is required.
@State private var appliedStreamExperts: Bool? = nil

// True once the effective stream-experts setting diverges from the value
// captured when the model was loaded; false while no baseline exists yet.
private var needsModelReloadForStreamingChange: Bool {
    appliedStreamExperts.map { effectiveStreamExpertsSetting != $0 } ?? false
}
Comment on lines +53 to +56

private var ssdStreamingBinding: Binding<Bool> {
Binding(
get: { effectiveStreamExpertsSetting },
Expand Down Expand Up @@ -117,6 +131,16 @@ struct SettingsView: View {
}
.onAppear {
draftServerConfiguration = server.startupConfiguration
// Seed the applied value from the current engine state so the reload
// prompt doesn't fire spuriously on first open.
if case .ready = engine.state {
appliedStreamExperts = effectiveStreamExpertsSetting
}
}
.onChange(of: engine.state) { _, newState in
if case .ready = newState {
appliedStreamExperts = effectiveStreamExpertsSetting
}
}
#if os(macOS)
.frame(minWidth: 520, minHeight: 580)
Expand Down Expand Up @@ -295,6 +319,9 @@ struct SettingsView: View {
tint: SwiftBuddyTheme.warning,
hint: "Stream MoE expert weights from NVMe (requires model reload)"
)
if needsModelReloadForStreamingChange {
modelReloadPrompt
}
toggleRow(
label: "TurboQuant KV", icon: "bolt.badge.clock",
isOn: $viewModel.config.turboKV,
Expand Down Expand Up @@ -555,70 +582,8 @@ struct SettingsView: View {
tint: SwiftBuddyTheme.accentSecondary,
hint: "mmap expert weights from NVMe — only active expert pages stay in RAM. Auto-enabled for MoE catalog models."
)
if effectiveStreamExpertsSetting != currentModelIsMoE {
VStack(alignment: .leading, spacing: 8) {
HStack(spacing: 6) {
Image(systemName: "arrow.clockwise.circle.fill")
.foregroundStyle(SwiftBuddyTheme.warning)
.font(.caption)
Text("Reload model to apply this change")
.font(.caption2.weight(.medium))
.foregroundStyle(SwiftBuddyTheme.warning)
Spacer()
Button("Reload") {
let currentId: String? = {
if case .ready(let id) = engine.state { return id }
return nil
}()
if let id = currentId {
Task {
engine.unload()
await engine.load(modelId: id)
}
}
}
.font(.caption2.weight(.semibold))
.foregroundStyle(SwiftBuddyTheme.accent)
.buttonStyle(.plain)
}

switch engine.state {
case .loading(let progress, let stage):
VStack(alignment: .leading, spacing: 4) {
HStack {
Text(stage)
.font(.caption2.weight(.medium))
.foregroundStyle(SwiftBuddyTheme.textSecondary)
Spacer()
Text("\(Int(progress * 100))%")
.font(.caption2.monospacedDigit())
.foregroundStyle(SwiftBuddyTheme.textTertiary)
}
ProgressView(value: progress)
.tint(SwiftBuddyTheme.accent)
}
case .downloading(let progress, let speed):
VStack(alignment: .leading, spacing: 4) {
HStack {
Text("Downloading model files")
.font(.caption2.weight(.medium))
.foregroundStyle(SwiftBuddyTheme.textSecondary)
Spacer()
Text("\(Int(progress * 100))% · \(speed)")
.font(.caption2.monospacedDigit())
.foregroundStyle(SwiftBuddyTheme.textTertiary)
}
ProgressView(value: progress)
.tint(SwiftBuddyTheme.accent)
}
default:
EmptyView()
}
}
.padding(.horizontal, 4)
.padding(.vertical, 6)
.background(SwiftBuddyTheme.warning.opacity(0.08))
.clipShape(RoundedRectangle(cornerRadius: 8))
if needsModelReloadForStreamingChange {
modelReloadPrompt
}
}
}
Expand Down Expand Up @@ -702,7 +667,7 @@ struct SettingsView: View {
}
.pickerStyle(.segmented)
.tint(SwiftBuddyTheme.accent)
.onChange(of: localColorScheme) { newValue in
.onChange(of: localColorScheme) { _, newValue in
// Defer the @Published write to avoid the view update crash
Task { @MainActor in
appearance.preference = newValue
Expand Down Expand Up @@ -917,7 +882,7 @@ struct SettingsView: View {
port: server.port,
parallel: server.startupConfiguration.parallelSlots,
apiKeySet: !server.startupConfiguration.apiKey.isEmpty,
modelId: {
modelId: { () -> String? in
if case .ready(let id) = engine.state { return id }
return nil
}()
Expand Down Expand Up @@ -981,6 +946,63 @@ struct SettingsView: View {
}
}

// Warning banner shown when the stream-experts setting no longer matches the
// value that was in effect at model-load time. Offers a one-tap reload and
// surfaces inline load/download progress while the reload runs.
@ViewBuilder
private var modelReloadPrompt: some View {
VStack(alignment: .leading, spacing: 8) {
HStack(spacing: 6) {
Image(systemName: "arrow.clockwise.circle.fill")
.foregroundStyle(SwiftBuddyTheme.warning)
.font(.caption)
Text("Reload model to apply this change")
.font(.caption2.weight(.medium))
.foregroundStyle(SwiftBuddyTheme.warning)
Spacer()
Button("Reload") {
reloadCurrentModel()
}
.font(.caption2.weight(.semibold))
.foregroundStyle(SwiftBuddyTheme.accent)
.buttonStyle(.plain)
// No model loaded means there is nothing to reload.
.disabled(currentModelId == nil)
}

// While the engine is (re)loading or downloading, show inline progress;
// all other engine states render nothing below the banner.
switch engine.state {
case .loading(let progress, let stage):
progressRow(label: stage, progress: progress)
case .downloading(let progress, let speed):
progressRow(label: "Downloading · \(speed)", progress: progress)
default:
EmptyView()
}
}
}

// Compact progress row: caption label on the left, percent on the right,
// and a tinted linear progress bar beneath. Used by `modelReloadPrompt`.
// - Parameters:
//   - label: Leading caption text (stage name or download status).
//   - progress: Fractional completion; rendered as a whole percentage.
@ViewBuilder
private func progressRow(label: String, progress: Double) -> some View {
VStack(alignment: .leading, spacing: 4) {
HStack {
Text(label)
.font(.caption2.weight(.medium))
.foregroundStyle(SwiftBuddyTheme.textSecondary)
Spacer()
// Monospaced digits keep the percentage from jittering as it updates.
Text("\(Int(progress * 100))%")
.font(.caption2.monospacedDigit())
.foregroundStyle(SwiftBuddyTheme.textTertiary)
}
ProgressView(value: progress)
.tint(SwiftBuddyTheme.accent)
.controlSize(.small)
}
}

// Unloads the current model and immediately reloads the same model id so
// load-time-only settings (e.g. expert streaming) take effect. Does nothing
// when no model is loaded.
private func reloadCurrentModel() {
    if let modelId = currentModelId {
        Task {
            engine.unload()
            await engine.load(modelId: modelId)
        }
    }
}

@ViewBuilder
private func parameterCard<Content: View>(_ title: String, @ViewBuilder content: () -> Content) -> some View {
VStack(alignment: .leading, spacing: 10) {
Expand Down
3 changes: 2 additions & 1 deletion SwiftBuddy/generate_xcodeproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def uid():
# ── MLXInferenceCore sources (path relative to SwiftBuddy/)
core_sources = [
("../Sources/MLXInferenceCore/ChatMessage.swift", uid(), uid()),
("../Sources/MLXInferenceCore/CLICommandBuilder.swift", uid(), uid()),
("../Sources/MLXInferenceCore/GenerationConfig.swift", uid(), uid()),
("../Sources/MLXInferenceCore/ModelCatalog.swift", uid(), uid()),
("../Sources/MLXInferenceCore/ModelStorage.swift", uid(), uid()),
Expand Down Expand Up @@ -512,7 +513,7 @@ def main():
print(" • ../mlx-swift-lm → MLXLLM, MLXLMCommon")
print()
print("📂 MLXInferenceCore sources included directly:")
for p, _, _ in [("ChatMessage", None, None), ("GenerationConfig", None, None),
for p, _, _ in [("ChatMessage", None, None), ("CLICommandBuilder", None, None), ("GenerationConfig", None, None),
("ModelCatalog", None, None), ("ModelDownloadManager", None, None),
("ModelArchitectureProbe", None, None), ("InferenceEngine", None, None)]:
print(f" • {p}.swift")
Expand Down
30 changes: 30 additions & 0 deletions tests/SwiftBuddyTests/ContextWindowCalculationTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import XCTest
import MLX
import MLXLMCommon
@testable import MLXInferenceCore

final class ContextWindowCalculationTests: XCTestCase {

    /// Verifies that the prompt-token count must come from the total element
    /// count (`.size`) of the prepared token tensor, not `.shape[0]`, which
    /// only reflects the batch dimension once the tokenizer batches input.
    @MainActor
    func testContextTokensCalculation() async throws {
        // Build a 512-token prompt and batch it the way the tokenizer would,
        // yielding a tensor of shape [1, 512]. No InferenceEngine is needed:
        // the behavior under test lives entirely in LMInput's token tensor.
        let mockTokens = MLXArray(Array(0..<512))
        let batchedTokens = mockTokens.reshaped([1, 512])

        // MLXLMCommon's LMInput struct, mirroring what generate() receives.
        let lmInput = LMInput(tokens: batchedTokens)

        // Using shape[0] for the context-window calculation was the original
        // bug: it captures only the batch dimension. size is the real count.
        XCTAssertEqual(lmInput.text.tokens.shape[0], 1, "shape[0] captures the batch dimension, returning 1")
        XCTAssertEqual(lmInput.text.tokens.size, 512, "size captures the total token count, resolving the context window bug")
    }
}
Loading
Loading