diff --git a/LICENCE b/LICENCE new file mode 100644 index 0000000..4153cd3 --- /dev/null +++ b/LICENCE @@ -0,0 +1,287 @@ + EUROPEAN UNION PUBLIC LICENCE v. 1.2 + EUPL © the European Union 2007, 2016 + +This European Union Public Licence (the ‘EUPL’) applies to the Work (as defined +below) which is provided under the terms of this Licence. Any use of the Work, +other than as authorised under this Licence is prohibited (to the extent such +use is covered by a right of the copyright holder of the Work). + +The Work is provided under the terms of this Licence when the Licensor (as +defined below) has placed the following notice immediately following the +copyright notice for the Work: + + Licensed under the EUPL + +or has expressed by any other means his willingness to license under the EUPL. + +1. Definitions + +In this Licence, the following terms have the following meaning: + +- ‘The Licence’: this Licence. + +- ‘The Original Work’: the work or software distributed or communicated by the + Licensor under this Licence, available as Source Code and also as Executable + Code as the case may be. + +- ‘Derivative Works’: the works or software that could be created by the + Licensee, based upon the Original Work or modifications thereof. This Licence + does not define the extent of modification or dependence on the Original Work + required in order to classify a work as a Derivative Work; this extent is + determined by copyright law applicable in the country mentioned in Article 15. + +- ‘The Work’: the Original Work or its Derivative Works. + +- ‘The Source Code’: the human-readable form of the Work which is the most + convenient for people to study and modify. + +- ‘The Executable Code’: any code which has generally been compiled and which is + meant to be interpreted by a computer as a program. + +- ‘The Licensor’: the natural or legal person that distributes or communicates + the Work under the Licence. + +- ‘Contributor(s)’: any natural or legal person who modifies the Work under the + Licence, or otherwise contributes to the creation of a Derivative Work. + +- ‘The Licensee’ or ‘You’: any natural or legal person who makes any usage of + the Work under the terms of the Licence. + +- ‘Distribution’ or ‘Communication’: any act of selling, giving, lending, + renting, distributing, communicating, transmitting, or otherwise making + available, online or offline, copies of the Work or providing access to its + essential functionalities at the disposal of any other natural or legal + person. + +2. Scope of the rights granted by the Licence + +The Licensor hereby grants You a worldwide, royalty-free, non-exclusive, +sublicensable licence to do the following, for the duration of copyright vested +in the Original Work: + +- use the Work in any circumstance and for all usage, +- reproduce the Work, +- modify the Work, and make Derivative Works based upon the Work, +- communicate to the public, including the right to make available or display + the Work or copies thereof to the public and perform publicly, as the case may + be, the Work, +- distribute the Work or copies thereof, +- lend and rent the Work or copies thereof, +- sublicense rights in the Work or copies thereof. + +Those rights can be exercised on any media, supports and formats, whether now +known or later invented, as far as the applicable law permits so. + +In the countries where moral rights apply, the Licensor waives his right to +exercise his moral right to the extent allowed by law in order to make effective +the licence of the economic rights here above listed. + +The Licensor grants to the Licensee royalty-free, non-exclusive usage rights to +any patents held by the Licensor, to the extent necessary to make use of the +rights granted on the Work under this Licence. + +3. Communication of the Source Code + +The Licensor may provide the Work either in its Source Code form, or as +Executable Code. If the Work is provided as Executable Code, the Licensor +provides in addition a machine-readable copy of the Source Code of the Work +along with each copy of the Work that the Licensor distributes or indicates, in +a notice following the copyright notice attached to the Work, a repository where +the Source Code is easily and freely accessible for as long as the Licensor +continues to distribute or communicate the Work. + +4. Limitations on copyright + +Nothing in this Licence is intended to deprive the Licensee of the benefits from +any exception or limitation to the exclusive rights of the rights owners in the +Work, of the exhaustion of those rights or of other applicable limitations +thereto. + +5. Obligations of the Licensee + +The grant of the rights mentioned above is subject to some restrictions and +obligations imposed on the Licensee. Those obligations are the following: + +Attribution right: The Licensee shall keep intact all copyright, patent or +trademarks notices and all notices that refer to the Licence and to the +disclaimer of warranties. The Licensee must include a copy of such notices and a +copy of the Licence with every copy of the Work he/she distributes or +communicates. The Licensee must cause any Derivative Work to carry prominent +notices stating that the Work has been modified and the date of modification. + +Copyleft clause: If the Licensee distributes or communicates copies of the +Original Works or Derivative Works, this Distribution or Communication will be +done under the terms of this Licence or of a later version of this Licence +unless the Original Work is expressly distributed only under this version of the +Licence — for example by communicating ‘EUPL v. 1.2 only’. The Licensee +(becoming Licensor) cannot offer or impose any additional terms or conditions on +the Work or Derivative Work that alter or restrict the terms of the Licence. + +Compatibility clause: If the Licensee Distributes or Communicates Derivative +Works or copies thereof based upon both the Work and another work licensed under +a Compatible Licence, this Distribution or Communication can be done under the +terms of this Compatible Licence. For the sake of this clause, ‘Compatible +Licence’ refers to the licences listed in the appendix attached to this Licence. +Should the Licensee's obligations under the Compatible Licence conflict with +his/her obligations under this Licence, the obligations of the Compatible +Licence shall prevail. + +Provision of Source Code: When distributing or communicating copies of the Work, +the Licensee will provide a machine-readable copy of the Source Code or indicate +a repository where this Source will be easily and freely available for as long +as the Licensee continues to distribute or communicate the Work. + +Legal Protection: This Licence does not grant permission to use the trade names, +trademarks, service marks, or names of the Licensor, except as required for +reasonable and customary use in describing the origin of the Work and +reproducing the content of the copyright notice. + +6. Chain of Authorship + +The original Licensor warrants that the copyright in the Original Work granted +hereunder is owned by him/her or licensed to him/her and that he/she has the +power and authority to grant the Licence. + +Each Contributor warrants that the copyright in the modifications he/she brings +to the Work are owned by him/her or licensed to him/her and that he/she has the +power and authority to grant the Licence. + +Each time You accept the Licence, the original Licensor and subsequent +Contributors grant You a licence to their contributions to the Work, under the +terms of this Licence. + +7. Disclaimer of Warranty + +The Work is a work in progress, which is continuously improved by numerous +Contributors. It is not a finished work and may therefore contain defects or +‘bugs’ inherent to this type of development. + +For the above reason, the Work is provided under the Licence on an ‘as is’ basis +and without warranties of any kind concerning the Work, including without +limitation merchantability, fitness for a particular purpose, absence of defects +or errors, accuracy, non-infringement of intellectual property rights other than +copyright as stated in Article 6 of this Licence. + +This disclaimer of warranty is an essential part of the Licence and a condition +for the grant of any rights to the Work. + +8. Disclaimer of Liability + +Except in the cases of wilful misconduct or damages directly caused to natural +persons, the Licensor will in no event be liable for any direct or indirect, +material or moral, damages of any kind, arising out of the Licence or of the use +of the Work, including without limitation, damages for loss of goodwill, work +stoppage, computer failure or malfunction, loss of data or any commercial +damage, even if the Licensor has been advised of the possibility of such damage. +However, the Licensor will be liable under statutory product liability laws as +far such laws apply to the Work. + +9. Additional agreements + +While distributing the Work, You may choose to conclude an additional agreement, +defining obligations or services consistent with this Licence. However, if +accepting obligations, You may act only on your own behalf and on your sole +responsibility, not on behalf of the original Licensor or any other Contributor, +and only if You agree to indemnify, defend, and hold each Contributor harmless +for any liability incurred by, or claims asserted against such Contributor by +the fact You have accepted any warranty or additional liability. + +10. Acceptance of the Licence + +The provisions of this Licence can be accepted by clicking on an icon ‘I agree’ +placed under the bottom of a window displaying the text of this Licence or by +affirming consent in any other similar way, in accordance with the rules of +applicable law. Clicking on that icon indicates your clear and irrevocable +acceptance of this Licence and all of its terms and conditions. + +Similarly, you irrevocably accept this Licence and all of its terms and +conditions by exercising any rights granted to You by Article 2 of this Licence, +such as the use of the Work, the creation by You of a Derivative Work or the +Distribution or Communication by You of the Work or copies thereof. + +11. Information to the public + +In case of any Distribution or Communication of the Work by means of electronic +communication by You (for example, by offering to download the Work from a +remote location) the distribution channel or media (for example, a website) must +at least provide to the public the information requested by the applicable law +regarding the Licensor, the Licence and the way it may be accessible, concluded, +stored and reproduced by the Licensee. + +12. Termination of the Licence + +The Licence and the rights granted hereunder will terminate automatically upon +any breach by the Licensee of the terms of the Licence. + +Such a termination will not terminate the licences of any person who has +received the Work from the Licensee under the Licence, provided such persons +remain in full compliance with the Licence. + +13. Miscellaneous + +Without prejudice of Article 9 above, the Licence represents the complete +agreement between the Parties as to the Work. + +If any provision of the Licence is invalid or unenforceable under applicable +law, this will not affect the validity or enforceability of the Licence as a +whole. Such provision will be construed or reformed so as necessary to make it +valid and enforceable. + +The European Commission may publish other linguistic versions or new versions of +this Licence or updated versions of the Appendix, so far this is required and +reasonable, without reducing the scope of the rights granted by the Licence. New +versions of the Licence will be published with a unique version number. + +All linguistic versions of this Licence, approved by the European Commission, +have identical value. Parties can take advantage of the linguistic version of +their choice. + +14. Jurisdiction + +Without prejudice to specific agreement between parties, + +- any litigation resulting from the interpretation of this License, arising + between the European Union institutions, bodies, offices or agencies, as a + Licensor, and any Licensee, will be subject to the jurisdiction of the Court + of Justice of the European Union, as laid down in article 272 of the Treaty on + the Functioning of the European Union, + +- any litigation arising between other parties and resulting from the + interpretation of this License, will be subject to the exclusive jurisdiction + of the competent court where the Licensor resides or conducts its primary + business. + +15. Applicable Law + +Without prejudice to specific agreement between parties, + +- this Licence shall be governed by the law of the European Union Member State + where the Licensor has his seat, resides or has his registered office, + +- this licence shall be governed by Belgian law if the Licensor has no seat, + residence or registered office inside a European Union Member State. + +Appendix + +‘Compatible Licences’ according to Article 5 EUPL are: + +- GNU General Public License (GPL) v. 2, v. 3 +- GNU Affero General Public License (AGPL) v. 3 +- Open Software License (OSL) v. 2.1, v. 3.0 +- Eclipse Public License (EPL) v. 1.0 +- CeCILL v. 2.0, v. 2.1 +- Mozilla Public Licence (MPL) v. 2 +- GNU Lesser General Public Licence (LGPL) v. 2.1, v. 3 +- Creative Commons Attribution-ShareAlike v. 3.0 Unported (CC BY-SA 3.0) for + works other than software +- European Union Public Licence (EUPL) v. 1.1, v. 1.2 +- Québec Free and Open-Source Licence — Reciprocity (LiLiQ-R) or Strong + Reciprocity (LiLiQ-R+). + +The European Commission may update this Appendix to later versions of the above +licences without producing a new version of the EUPL, as long as they provide +the rights granted in Article 2 of this Licence and protect the covered Source +Code from exclusive appropriation. + +All other changes or additions to this Appendix require the production of a new +EUPL version. diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 0000000..55803f7 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,98 @@ + + +# go-inference — documentation index + +**Module**: `dappco.re/go/inference` +**Role**: The contract package every backend and consumer in the tetrad imports. + +## Tetrad position + +``` + ┌──────────────────────────────┐ + │ dappco.re/go (core) │ + └──────────────┬───────────────┘ + │ + ┌──────────────┴────────────────┐ + you are here → go-inference (CONTRACT) │ ← pure interfaces + wire types + │ • TextModel / Backend │ + │ • state/ lifecycle │ + │ • openai/ anthropic/ ollama/ │ + │ • capability / probe │ + └──┬─────────────┬──────────────┘ + │ │ register via init() + ┌────────┴───┐ ┌──────┴────────┐ + │ go-mlx │ │ go-rocm / │ ← native backends + │ darwin/ │ │ go-cuda │ + │ arm64 │ └───────────────┘ + └─────┬──────┘ + │ consumed by + ┌─────┴──────────┬────────────────┐ + │ go-ml │ go-ai │ ← consumers + │ scoring/agent │ router/demos │ + └────────────────┘ └───────────────┘ +``` + +## Doc tree + +``` +docs/ +├── README.md ← you are here +├── inference/ ← root package +│ ├── README.md — package overview + how the pieces fit +│ ├── inference.md — TextModel + Backend + registry + LoadModel +│ ├── contracts.md — extension interfaces (Scheduler, Cache, Embed, Rerank, ToolParse, …) +│ ├── options.md — GenerateOption + LoadOption + With* +│ ├── capability.md — CapabilityReport + AlgorithmProfile + RuntimeMemoryLimiter +│ ├── local_tuning.md — MachineDiscoverer + TuningPlanner + model replace +│ ├── probe.md — ProbeEvent + ProbeSink +│ ├── service.md — Core ServiceRuntime registration (Mantis #1336) +│ ├── training.md — TrainableModel + Adapter + LoRAConfig +│ ├── discover.md — Discover() filesystem scan +│ ├── gguf.md — GGUFInfo metadata reader +│ ├── dataset.md — DatasetSample + DatasetStream +│ └── identity.md — re-export aliases from state +│ +├── state/ ← state subpackage +│ ├── README.md — package overview + mental model +│ ├── agent_memory.md — Wake / Sleep / Fork lifecycle +│ ├── identity.md — ModelIdentity / TokenizerIdentity / Adapter / Runtime / Sampler / Bundle +│ ├── project_seed.md — project seed URI planning + compatibility checks +│ ├── store.md — Store / Resolver / Writer interfaces +│ ├── memory.md — InMemoryStore +│ └── filestore.md — append-only file-backed store +│ +├── openai/ ← OpenAI wire types +│ ├── README.md — package overview +│ ├── openai.md — Chat Completions + Handler +│ ├── responses.md — Responses API DTOs +│ └── services.md — embeddings / rerank / cache / cancel / capabilities handlers +│ +├── anthropic/ +│ └── anthropic.md — Messages API wire types +│ +└── ollama/ + └── ollama.md — Ollama-compatible wire types +``` + +## Where to start + +- **"What's the basic loop?"** → [`inference/inference.md`](inference/inference.md) +- **"How do I add a backend?"** → [`inference/inference.md`](inference/inference.md) — Backend interface + Register pattern +- **"How does agent memory work?"** → [`state/agent_memory.md`](state/agent_memory.md) — Wake/Sleep/Fork +- **"How do project seeds reload safely?"** → [`state/project_seed.md`](state/project_seed.md) — project seed helpers + compatibility +- **"How does OpenAI compatibility work?"** → [`openai/openai.md`](openai/openai.md) +- **"What can a backend advertise?"** → [`inference/capability.md`](inference/capability.md) +- **"How does local setup/autotune work?"** → [`inference/local_tuning.md`](inference/local_tuning.md) +- **"How do I observe runtime?"** → [`inference/probe.md`](inference/probe.md) + +## Legacy docs + +`architecture.md`, `interfaces.md`, `backends.md`, `types.md`, `development.md`, `history.md`, `index.md`, `RFC.models.md`, `RFC-CORE-008-AGENT-EXPERIENCE.md` predate this per-file pass. They cover overlapping ground at a wider grain and may rot as the per-file docs evolve. Pending: collapse the still-useful bits into `inference/README.md` and the per-file pages, then mark the legacy docs deprecated. + +## Standards + +- UK English +- EUPL-1.2 licence (see [LICENCE](../LICENCE)) +- SPDX header on every source file +- Conventional commits, scopes per package +- Co-Author: `Co-Authored-By: Virgil ` diff --git a/docs/anthropic/anthropic.md b/docs/anthropic/anthropic.md new file mode 100644 index 0000000..1b079e3 --- /dev/null +++ b/docs/anthropic/anthropic.md @@ -0,0 +1,79 @@ + + +# anthropic/anthropic.go — Messages API wire types + +**Package**: `dappco.re/go/inference/anthropic` +**File**: `go/anthropic/anthropic.go` + +## What this is + +The Anthropic Messages API (`/v1/messages`) wire surface. Same pattern as `openai/openai.go` but for Anthropic-compatible SDKs — DTOs + translation to `inference.Message` + `inference.GenerateOption`. No HTTP handler yet; planned alongside the Responses handler. + +This is a parity item from the 2026-05-09 vMLX gap report: vMLX exposed Anthropic compatibility and CoreAgent needed the same surface for Claude-flavoured SDKs hitting local inference. + +## Constants + +```go +const DefaultMessagesPath = "/v1/messages" +``` + +## DTOs + +```go +ContentBlock // type + text — Anthropic's typed-block content model +Message // role + []ContentBlock +MessageRequest // model + system + messages + max_tokens + sampler + stream + stop_sequences +Usage // input_tokens + output_tokens +MessageResponse // id + type + role + model + content[] + stop_reason + stop_sequence + usage +``` + +Key differences from OpenAI: + +- `Message.Content` is `[]ContentBlock`, not a plain string — supports image / tool_use / tool_result block types out of the box. +- `system` is a top-level field, not a message with role=system. +- `Usage` uses `input_tokens` / `output_tokens` (vs OpenAI's `prompt_tokens` / `completion_tokens`). +- Stop reason is named (`end_turn` / `max_tokens` / `stop_sequence` / `tool_use`), not a free string. + +## InferenceMessages + +```go +messages := anthropic.InferenceMessages(req) +``` + +Flattens the typed-block content to plain text + builds the standard `inference.Message` slice. The Anthropic top-level `system` field becomes a leading system message in the inference slice — so the runtime sees one uniform message list regardless of API origin. + +`blockText` strips down to `type: "text"` blocks only; image/tool blocks are dropped at the translation boundary (no multi-modal support in the core runner yet). + +## GenerateOptions + +```go +opts := anthropic.GenerateOptions(req) +for tok := range model.Chat(ctx, messages, opts...) { ... } +``` + +Same translation as the OpenAI sibling — sampler fields lowered to `inference.GenerateOption`. `MaxTokens` is required on the Anthropic side (no default); the translation only appends `WithMaxTokens` when `MaxTokens > 0`. + +## NewTextResponse + +```go +resp := anthropic.NewTextResponse(requestID, modelName, text, metrics) +``` + +Minimal response builder — single text content block + stop_reason="end_turn" + usage filled from the inference metrics. Same convenience as `openai.NewTextResponse`; lets a handler produce a valid Anthropic-shaped response in one line. + +## What's not here + +- Streaming. Anthropic's streaming format (`event: message_start`, etc.) is its own thing — not yet implemented. +- Tool-use / tool-result blocks. The shape is in `ContentBlock` but the translation drops them. When tool-call parsing lands (per the parity plan), this will route through `inference.ToolParser`. +- Vision blocks. Same reason as OpenAI Responses — multi-modal is out of scope for the core runner. + +## Why a separate file from openai/ + +Anthropic's wire shape is **different enough** that mashing them into one package would require option types or interface-based content blocks — both worse than just having two parallel files. The size budget is small (~110 lines). + +## Related + +- [README.md](README.md) — package overview (planned) +- [../openai/openai.md](../openai/openai.md) — the parallel OpenAI translation +- [../inference/contracts.md](../inference/contracts.md) — `ToolParser` for future tool-use routing +- `core/api` — mounts an Anthropic handler when configured (handler TBD) diff --git a/docs/inference/README.md b/docs/inference/README.md new file mode 100644 index 0000000..0784025 --- /dev/null +++ b/docs/inference/README.md @@ -0,0 +1,90 @@ + + +# inference/ — contract package root + +**Package**: `dappco.re/go/inference` + +## What this package owns + +The **central contract** that every other tetrad repo speaks. Pure interfaces, DTOs, registries, and option types. Zero CGO. Zero platform branches. Compiles everywhere. + +Three categories: + +| Category | What | Files | +|----------|------|-------| +| **Core runtime** | TextModel + Backend + registry + LoadModel | [inference.md](inference.md) | +| **Options** | GenerateOption + LoadOption + With* | [options.md](options.md) | +| **Extension** | Scheduler, Cache, Embedding, Rerank, ToolParse, ReasoningParse, ModelPackInspect | [contracts.md](contracts.md) | +| **Static intro** | CapabilityReport / AlgorithmProfile / RuntimeMemoryLimits | [capability.md](capability.md) | +| **Local setup** | MachineDiscoverer / TuningPlanner / model replace | [local_tuning.md](local_tuning.md) | +| **Dynamic observe** | ProbeEvent / ProbeSink | [probe.md](probe.md) | +| **Lifecycle** | Service + RegisterCore (Mantis #1336) | [service.md](service.md) | +| **Training** | TrainableModel + Adapter + LoRAConfig | [training.md](training.md) | +| **Discovery** | Discover() | [discover.md](discover.md) | +| **Format reader** | GGUFInfo | [gguf.md](gguf.md) | +| **Data shape** | DatasetSample + DatasetStream | [dataset.md](dataset.md) | +| **Re-export aliases** | identity types into the parent pkg | [identity.md](identity.md) | + +## How the pieces fit + +``` +LoadModel(path, opts...) ← caller entry + │ + ├──→ Default() / Get(name) ← registry lookup + │ │ + │ └──→ Backend.LoadModel(...) ← native driver + │ │ + │ └──→ returns TextModel ← what the caller uses + │ + └──→ Caller: model.Generate(ctx, prompt, WithMaxTokens(64)) + model.Chat(ctx, msgs, WithTemperature(0.7)) + model.Classify(ctx, prompts) + model.BatchGenerate(ctx, prompts) + ... + +Optionally: + if sched, ok := model.(SchedulerModel); ok { ... } ← contracts.go + if cache, ok := model.(CacheService); ok { ... } + if embed, ok := model.(EmbeddingModel); ok { ... } + if train, ok := model.(TrainableModel); ok { ... } ← training.go + if probe, ok := model.(CapabilityReporter);ok { report := probe.Capabilities() } +``` + +## Sibling packages + +- [../state/](../state/README.md) — durable state DTOs + Wake/Sleep/Fork lifecycle +- [../openai/](../openai/README.md) — OpenAI wire types + HTTP handlers +- [../anthropic/](../anthropic/anthropic.md) — Anthropic Messages wire types +- [../ollama/](../ollama/ollama.md) — Ollama-compatible wire types + +## Stability rules + +This package is the shared contract. Changes here cascade to every backend and consumer. + +- **No new methods on `TextModel` or `Backend`** without a Virgil review. +- **Prefer new interfaces over wider TextModel.** New capabilities land in `contracts.go` as opt-in extensions. +- **New fields on `GenerateConfig` / `LoadConfig` are safe** when zero-value defaults preserve old behaviour. +- **Wire DTOs in openai/anthropic/ollama track upstream** — adding fields is safe, renaming requires upstream rename first. + +## Coding standards (this repo) + +- UK English in code, comments, docs (colour, organisation, licence, serialise) +- SPDX header on every new file: `// SPDX-Licence-Identifier: EUPL-1.2` +- Zero external dependencies — stdlib + `dappco.re/go` only (testify in tests) +- Error strings start lowercase, end without punctuation: `"backend %q not registered"` +- Test triplets: `_Good` / `_Bad` / `_Ugly` +- Conventional commits scoped to `inference`, `state`, `openai`, `anthropic`, `ollama`, `options`, `discover` +- Co-Author trailer: `Co-Authored-By: Virgil ` + +## Who imports this + +| Module | Why | +|--------|-----| +| `dappco.re/go/mlx` | implements Backend + TextModel for Apple Metal | +| `dappco.re/go/rocm` (planned) | implements Backend + TextModel for AMD ROCm | +| `dappco.re/go/cuda` (planned) | implements Backend + TextModel for NVIDIA CUDA | +| `dappco.re/go/ml` | wraps Backend + TextModel into scoring/eval engine, adds HTTP/llama backends | +| `dappco.re/go/ai` | provider router, outbound OpenAI provider, BookState demo | +| `dappco.re/go/i18n` | TextModel for domain classification | +| `dappco.re/go/api` | mounts OpenAI / Anthropic / Ollama handlers | +| `dappco.re/go/ide` | reads CapabilityReport + bundle index for model picker | diff --git a/docs/inference/capability.md b/docs/inference/capability.md new file mode 100644 index 0000000..137f246 --- /dev/null +++ b/docs/inference/capability.md @@ -0,0 +1,138 @@ + + +# capability.go — capability reports + memory limiter + +**Package**: `dappco.re/go/inference` +**File**: `go/capability.go` + +## What this is + +The portable shape for **"what does this backend / model support, at what maturity?"** — consumed by go-ml, go-ai, core/api, core/ide. Backends that implement `CapabilityReporter` answer; consumers branch on the report without importing backend-specific packages. + +Also hosts `RuntimeMemoryLimits` + `RuntimeMemoryLimiter` — the same lane for runtime allocator limits. + +## Capability ID catalogue + +41 stable IDs grouped by lane: + +**Model / inference**: `model.load`, `generate`, `chat`, `classify`, `batch.generate`, `tokenizer`, `chat.template`, `lora.inference`, `lora.training` + +**Runtime / cache / scheduling**: `state.bundle`, `kv.snapshot`, `prompt.cache`, `kv.cache.planning`, `memory.planning`, `model.fit`, `scheduler`, `request.cancel`, `cache.blocks`, `cache.disk`, `cache.warm` + +**Training / eval**: `benchmark`, `evaluation`, `distillation`, `grpo`, `quantization`, `model.merge` + +**Probe / research**: `probe.events`, `probe.attention`, `probe.logits` + +**Wire / compat**: `responses.api`, `anthropic.messages`, `ollama.compat`, `embeddings`, `rerank` + +**Parsers**: `tool.parse`, `reasoning.parse` + +**Decoding**: `speculative.decode`, `prompt.lookup.decode` + +**MoE / specialised quant**: `moe.routing`, `moe.lazy_experts`, `jangtq`, `codebook.vq` + +**Agent memory**: `agent.memory`, `state.wake`, `state.sleep`, `state.fork` + +Snippets of these mirror the parity targets from the 2026-05-09 vMLX gap report. + +## Groups + status + +```go +type CapabilityGroup string // "model" | "runtime" | "training" | "probe" +type CapabilityStatus string // "supported" | "experimental" | "planned" | "unsupported" +``` + +Group is a coarse routing dimension (a UI filter). Status is the maturity stamp. + +## Capability + +```go +type Capability struct { + ID CapabilityID + Group CapabilityGroup + Status CapabilityStatus + Detail string + Labels map[string]string +} +``` + +Constructors short-cut the common shapes: `NewCapability(id, group, status, detail)` plus `SupportedCapability(id, group)`, `ExperimentalCapability(id, group, detail)`, `PlannedCapability(id, group, detail)`. + +## AlgorithmProfile + +Richer than `Capability` — for backends that want to advertise the exact algorithm + which architectures it covers + what it requires + what it provides: + +```go +type AlgorithmProfile struct { + ID CapabilityID + Group CapabilityGroup + CapabilityStatus CapabilityStatus + RuntimeStatus FeatureRuntimeStatus // native | experimental | metadata_only | planned + Algorithm string // free-form: "jangtq_k", "flash_attn_v2", "paged_kv_v1" + Detail string + Architectures []string // ["gemma4", "qwen3", "minimax_m2"] + Requires []CapabilityID + Provides []string + Notes []string +} +``` + +`profile.Capability()` lowers it to a plain `Capability` with the algorithm/architectures/requires/provides folded into labels for transport. + +**Why two shapes?** `Capability` is the wire-stable contract — consumers depend on its small shape. `AlgorithmProfile` is the richer authoring shape backends use locally; lowering to Capability strips author detail to whatever the wire promises. + +## CapabilityReport + +```go +type CapabilityReport struct { + Runtime RuntimeIdentity + Model ModelIdentity + Tokenizer TokenizerIdentity + Adapter AdapterIdentity + Available bool + Architectures []string + Quantizations []string + CacheModes []string + Capabilities []Capability + Labels map[string]string +} +``` + +The full envelope: runtime + model + tokenizer + adapter identity, the available bit, lists of supported architectures / quantisations / cache modes, the capability array, plus free-form labels. + +## CapabilityReporter + +```go +type CapabilityReporter interface { + Capabilities() CapabilityReport +} +``` + +Implemented by `Backend` (returns runtime-level capabilities) and by loaded `TextModel` instances (returns model-level capabilities). Consumers walk via type assertion — not every backend or model implements it. + +## RuntimeMemoryLimits + RuntimeMemoryLimiter + +```go +type RuntimeMemoryLimits struct { + CacheLimitBytes uint64 + MemoryLimitBytes uint64 + PreviousCacheLimitBytes uint64 + PreviousMemoryLimitBytes uint64 +} + +type RuntimeMemoryLimiter interface { + SetRuntimeMemoryLimits(limits) RuntimeMemoryLimits +} + +inference.SetRuntimeMemoryLimits("metal", limits) // package-level helper +``` + +Zero request fields = "leave unchanged". Previous values report the prior caps so callers can restore on exit. + +## Consumed by + +- `go-mlx/register_metal.go` — exposes Metal allocator limits via `RuntimeMemoryLimiter` +- `go-mlx/algorithm_profile.go` + `architecture_profile.go` — publish JANG/MoE/codebook profiles +- `go-ml/capability.go` — `CapabilityReportForBackend(name, backend)` summarises a ml-side backend into the portable shape +- `core/api` — surfaces reports over HTTP for `core/ide` to render the "what can I do" panel +- `go-ai/providers/openai` — outbound provider exposes its capability fingerprint diff --git a/docs/inference/contracts.md b/docs/inference/contracts.md new file mode 100644 index 0000000..f661cb3 --- /dev/null +++ b/docs/inference/contracts.md @@ -0,0 +1,118 @@ + + +# contracts.go — extension interfaces + +**Package**: `dappco.re/go/inference` +**File**: `go/contracts.go` + +## What this is + +The "everything beyond TextModel" surface. Each capability that some +backends support but not all is its own interface, discovered by type +assertion. A backend implements only the interfaces it can deliver; a +consumer probes via `if x, ok := model.(inference.Y); ok { ... }`. + +This file is the source of truth for what extensions exist; the +implementations live in backends. + +## Capability interfaces + +| Interface | What it adds | +|-----------|--------------| +| `SchedulerModel` | queue-aware Schedule(req) → handle + token stream — for serving loops with cancellation + batching | +| `CancellableModel` | CancelRequest(id) — abort an in-flight generation | +| `CacheService` | CacheStats + WarmCache + ClearCache — prompt-cache management | +| `EmbeddingModel` | Embed(req) — vector embeddings | +| `RerankModel` | Rerank(req) — cross-encoder document scoring | +| `ReasoningParser` | ParseReasoning(tokens, text) — extract chain-of-thought from `` channels | +| `ToolParser` | ParseTools(tokens, text) — extract structured tool-call output | +| `ModelPackInspector` | InspectModelPack(path) — validate a model dir without loading weights | + +## Request / Result DTOs + +| Type | Role | +|------|------| +| `RequestHandle` | id + model identity + labels — what a Schedule call returns to track a request | +| `RequestCancelResult` | id + cancelled bool + reason | +| `ScheduledRequest` | id + model + prompt/messages + sampler + labels — input to a scheduler | +| `ScheduledToken` | request_id + token + per-request metrics + labels — what the scheduler streams | +| `CacheBlockRef` | portable handle for one cache block — id, kind, model/adapter/tokenizer hash, token range, size, encoding | +| `CacheStats` | block count + memory/disk bytes + hits/misses/evictions + hit rate + restore latency | +| `CacheWarmRequest` / `CacheWarmResult` | warm a prompt's cache + report which blocks are ready | +| `EmbeddingRequest` / `EmbeddingResult` / `EmbeddingUsage` | input strings → vectors + token accounting | +| `RerankRequest` / `RerankScore` / `RerankResult` | query + documents → scored documents | +| `ReasoningSegment` / `ReasoningParseResult` | visible text vs reasoning channels | +| `ToolCall` / `ToolParseResult` | visible text vs tool calls | +| `ModelPackInspection` | path, format, model identity, supported bool, capabilities, notes | + +## Agent memory aliases (live here for import convenience) + +```go +type AgentMemoryRef = state.Ref +type AgentMemoryWakeRequest = state.WakeRequest +type AgentMemoryWakeResult = state.WakeResult +type AgentMemorySleepRequest = state.SleepRequest +type AgentMemorySleepResult = state.SleepResult +type AgentMemorySession = state.Session +type AgentMemoryForker = state.Forker +``` + +Importing `dappco.re/go/inference` gives you the memory lifecycle +shape without needing a separate `inference/state` import. The state +package owns the real types; this file just re-exports them. + +## How a consumer probes capabilities + +```go +m, _ := inference.LoadModel(path).Value.(inference.TextModel) + +if sched, ok := m.(inference.SchedulerModel); ok { + handle, tokens, err := sched.Schedule(ctx, req) + // serve queue +} +if cancel, ok := m.(inference.CancellableModel); ok { + _ = cancel.CancelRequest(ctx, oldRequestID) +} +if cache, ok := m.(inference.CacheService); ok { + stats, _ := cache.CacheStats(ctx) +} +if embed, ok := m.(inference.EmbeddingModel); ok { + result, _ := embed.Embed(ctx, req) +} +``` + +## How a backend opts in + +In go-mlx (example): + +```go +// metaladapter already implements TextModel +// — add Schedule to also implement SchedulerModel: +func (a *metaladapter) Schedule(ctx, req) (RequestHandle, <-chan ScheduledToken, error) { + // … +} +``` + +No registration step. The type assertion at the call site is the only +discovery mechanism. Backends that *don't* implement an interface +simply fail the type check; consumers fall back to whatever default +they have. + +## Why type-assertion not method-set + +Different backends are at different stages. go-mlx may have +SchedulerModel before go-rocm; go-rocm may ship CacheService earlier +than go-mlx. Forcing every backend to stub out every interface would +make TextModel a 50-method monster and silently degrade — type +assertion lets each backend grow at its own pace and the consumer +explicitly handles the "not available" path. + +## Related + +- [inference.md](inference.md) — the base TextModel + Backend +- [capability.md](capability.md) — `CapabilityReport` for static + introspection of what a backend claims to support +- [../state/agent_memory.md](../state/agent_memory.md) — the real + agent-memory types (these are aliases) +- [../openai/services.md](../openai/services.md) — wire types that + carry EmbeddingResult / RerankResult / CacheStats over HTTP diff --git a/docs/inference/dataset.md b/docs/inference/dataset.md new file mode 100644 index 0000000..9063c37 --- /dev/null +++ b/docs/inference/dataset.md @@ -0,0 +1,78 @@ + + +# dataset.go — DatasetStream contract + +**Package**: `dappco.re/go/inference` +**File**: `go/dataset.go` + +## What this is + +The smallest possible pull-based dataset contract shared by training, evaluation, distillation, and reasoning rollouts. One sample at a time, optional reset, optional length. Backends and consumers agree on this shape so a dataset assembled in go-ml flows directly into go-mlx training without conversion. + +## DatasetSample + +```go +type DatasetSample struct { + Text string // raw text (continuation pretraining) + Prompt string // user prompt (SFT, instruct) + Response string // assistant response (SFT target) + Reasoning string // chain-of-thought (GRPO, distillation) + Messages []Message // multi-turn conversation + Labels map[string]string // routing / filtering metadata +} +``` + +A sample carries whichever fields the task needs. SFT samples populate Prompt + Response. GRPO samples add Reasoning. Eval samples often only use Messages. + +## DatasetStream + +```go +type DatasetStream interface { + Next() (DatasetSample, bool, error) +} +``` + +`Next` returns `(sample, ok, err)`. `ok=false` + `err=nil` = end of stream. Errors are terminal — the caller stops consuming. + +## DatasetResetter + +```go +type DatasetResetter interface { + Reset() error +} +``` + +Optional. Streams that wrap an in-memory list or a seekable file implement Reset so training loops can run multiple epochs. Streaming-only sources (HF datasets streaming mode) don't. + +## DatasetSized + +Optional. Streams that know their length up-front report it for progress UI / cosine LR schedules. + +## DatasetConfig (planned umbrella) + +The capability surface in `capability.go` mentions `CapabilityEvaluation` + `CapabilityDistillation` + `CapabilityGRPO`. Each consumes a DatasetStream. The eval/bench/distill/grpo config DTOs live in the consuming packages (go-mlx, go-ml) rather than here — this file is just the stream contract. + +## Why one interface for everything + +The temptation is to have `TrainingDataset`, `EvalDataset`, `DistillDataset` — different shapes per task. We resist. A single `DatasetStream.Next() → DatasetSample` covers every task because `DatasetSample` is wide enough that each consumer reads the fields it cares about. New tasks add fields to DatasetSample without churning consumers. + +## Implemented by + +- `go-mlx/dataset_stream.go` — in-process iterator over MLX-format files +- `go-ml/ingest.go` — DuckDB / Parquet ingestion → DatasetStream +- `go-mlx/cmd/violet` — wraps an HTTP-streamed dataset +- test fixtures via in-memory slice wrappers + +## Consumed by + +- `go-mlx/sft.go` — supervised fine-tuning loop +- `go-mlx/grpo.go` — reasoning training loop +- `go-mlx/distill.go` — teacher/student distillation +- `go-mlx/eval.go` — evaluation runner +- `go-ml/agent_eval.go` — scoring engine eval + +## Related + +- [training.md](training.md) — TrainableModel consumes DatasetStream in Step +- `go-mlx/docs/training/dataset_stream.md` (planned) — reference iterator +- `go-ml/docs/scoring/ingest.md` (planned) — go-ml's dataset assembly path diff --git a/docs/inference/discover.md b/docs/inference/discover.md new file mode 100644 index 0000000..74d4088 --- /dev/null +++ b/docs/inference/discover.md @@ -0,0 +1,70 @@ + + +# discover.go — model directory scanning + +**Package**: `dappco.re/go/inference` +**File**: `go/discover.go` + +## What this is + +A backend-neutral filesystem scan that yields one `DiscoveredModel` per model directory under a root. Used by: + +- CoreAgent / core/ide model picker UI +- `core/lab` to enumerate available models +- Test harnesses that auto-find fixtures + +Detects both safetensors directories (`config.json` + `*.safetensors`) and GGUF files. Architecture + quantisation metadata extracted at scan time so callers don't have to load each model to decide whether it's interesting. + +## DiscoveredModel + +```go +type DiscoveredModel struct { + Path string // absolute path to dir or .gguf file + ModelType string // architecture: gemma3, qwen3, llama, … + QuantBits int // 0 = unknown / unquantised + QuantGroup int + QuantType string // q4_k_m, q8_0, etc. (GGUF) + QuantFamily string // q4, q8 (coarse) + NumFiles int // number of weight files + Format string // "safetensors" or "gguf" +} +``` + +## Discover + +```go +for m := range inference.Discover("/Volumes/Data/models") { + fmt.Printf("%s arch=%s quant=%dbit\n", m.Path, m.ModelType, m.QuantBits) +} +``` + +Returns `iter.Seq[DiscoveredModel]`. Iteration is lazy — caller can break early on first match. Sort order: alphabetical by path. + +## What it inspects + +For safetensors directories: +- `config.json` → `model_type`, `num_hidden_layers`, `vocab_size`, `quantization_config` +- File count = count of `*.safetensors` + +For GGUF files: +- Magic + version header +- Architecture metadata key +- Quantisation type from tensor headers + +Detection is metadata-only. Weight tensors are not loaded. + +## What it skips + +- Hidden directories (`.git`, `.cache`) +- Directories without `config.json` or matching `*.gguf` +- Symlink loops (basic loop detection) + +## Why a generator not a slice + +Large model trees with 100+ models would cost noticeable RAM if returned all-at-once. The generator pattern lets a UI render the first row immediately while the scan continues. + +## Related + +- [gguf.md](gguf.md) — `GGUFInfo` for the richer single-file scan +- `go-mlx/docs/model/model_pack.md` (planned) — full model-pack validation (uses Discover + Inspect) +- `go-ml/docs/scoring/inventory.md` (planned) — inventory persistence diff --git a/docs/inference/gguf.md b/docs/inference/gguf.md new file mode 100644 index 0000000..eac1090 --- /dev/null +++ b/docs/inference/gguf.md @@ -0,0 +1,70 @@ + + +# gguf.go — GGUF metadata reader + +**Package**: `dappco.re/go/inference` +**File**: `go/gguf.go` + +## What this is + +A minimal GGUF (llama.cpp model format) metadata parser. Reads the header + key-value section without loading tensors — same intent as the safetensors path in `discover.go`. Used by Discover, by `model_pack.go` validation in go-mlx, and by the core/ide model picker. + +## GGUFInfo + +```go +type GGUFInfo struct { + Path string + Architecture string + QuantType string // q4_k_m, q8_0, f16, … + QuantFamily string // q4, q8, f16 + QuantBits int + QuantGroup int + ContextLength int + NumLayers int + HiddenSize int + VocabSize int + ChatTemplate string + NumTensors int + HeaderBytes int64 + FileBytes int64 + Metadata map[string]any +} +``` + +Maps cleanly onto `ModelIdentity` + `TokenizerIdentity.ChatTemplate`. + +## GGUF format constants + +```go +ggufMagic = 0x46554747 // "GGUF" little-endian +ggufVersion = 3 +ggufTypeUint32 = 4 +ggufTypeString = 8 +``` + +The parser handles v2 + v3 files. v1 is rare in the wild; not supported. + +## Public API + +```go +info, err := inference.ReadGGUFInfo("/models/foo.gguf") +infos := inference.ScanGGUF(io.Reader) // for streaming scenarios +``` + +## What it parses + +Header → key-value section. Stops as soon as the architecture + quant + chat template are known. Tensor headers are scanned only when `NumTensors` is requested (default off — the scan is bounded to the metadata section). + +## Why a local parser instead of llama-cpp-go binding + +Three reasons: + +1. **No CGO.** `inference` is zero-deps; pulling in a llama-cpp binding violates the package contract. +2. **Smaller surface.** We only need metadata, not inference — the parser is ~285 lines. +3. **Cross-platform.** The same code compiles on every platform; backend-specific GGUF use (loading tensors) lives in the backend. + +## Related + +- [discover.md](discover.md) — `Discover()` uses this for `.gguf` files +- `go-mlx/docs/model/gguf_info.md` (planned) — backend-specific GGUF tensor load +- `go-mlx/docs/model/gguf_quantize.md` (planned) — write-side GGUF quantisation diff --git a/docs/inference/identity.md b/docs/inference/identity.md new file mode 100644 index 0000000..2d4086c --- /dev/null +++ b/docs/inference/identity.md @@ -0,0 +1,70 @@ + + +# identity.go — aliases to state + sampler conversion + +**Package**: `dappco.re/go/inference` +**File**: `go/identity.go` + +## What this is + +A thin re-export layer. The identity types (`ModelIdentity`, `TokenizerIdentity`, etc.), the `Bundle` envelope, and project-seed helpers live in the `state` subpackage; this file aliases them into the parent `inference` package so consumers importing only `dappco.re/go/inference` see the common names. + +Two real bits of code on top: `SamplerConfigFromGenerateConfig` + `GenerateConfigFromSamplerConfig`. + +## Aliases + +```go +type ModelIdentity = state.ModelIdentity +type TokenizerIdentity = state.TokenizerIdentity +type AdapterIdentity = state.AdapterIdentity +type RuntimeIdentity = state.RuntimeIdentity +type SamplerConfig = state.SamplerConfig +type StateRef = state.StateRef +type StateBundle = state.Bundle +type ProjectSeed = state.ProjectSeed +``` + +A consumer writes: + +```go +import "dappco.re/go/inference" + +func report(c inference.CapabilityReport) { + if c.Adapter.Hash == "" { ... } // AdapterIdentity from inference + bundle := inference.StateBundle{ ... } // Bundle from inference +} +``` + +— and never needs to import `inference/state` directly. + +## SamplerConfigFromGenerateConfig + +```go +state.SamplerConfig = inference.SamplerConfigFromGenerateConfig(cfg) +``` + +Lowers a live `GenerateConfig` (which carries Go-typed defaults and option-fn lineage) to the portable `SamplerConfig` that fits into a `Bundle`. Used when persisting a session: the bundle records the **outcome** of sampler options, not the option-fn chain that produced them. + +`StopTokens` is cloned (separate slice ownership) so the bundle isn't mutated when the live cfg is. + +## GenerateConfigFromSamplerConfig + +The inverse: + +```go +cfg := inference.GenerateConfigFromSamplerConfig(bundle.Sampler) +for tok := range model.Generate(ctx, prompt, withGenerateConfig(cfg)) { ... } +``` + +Restores a sampler config from a bundle and produces the matching `GenerateConfig`. Note: `StopSequences` (text-mode stop strings) is in `SamplerConfig` but **not** in `GenerateConfig` — the conversion drops it, because the runtime path uses token-id stops, not strings. A future GenerateOption could re-introduce it. + +## Why this re-export layer exists at all + +The `state` package was hoisted out so the wire shapes for state could be imported without dragging in the full backend-registry surface (see `state/README.md` for the why). Re-exporting through `inference` keeps existing consumers' imports stable — code written before the split compiles unchanged. + +## Related + +- [../state/identity.md](../state/identity.md) — the real DTOs +- [../state/project_seed.md](../state/project_seed.md) — project-seed helpers and wake compatibility checks +- [options.md](options.md) — `GenerateConfig` / `GenerateOption` +- [../state/agent_memory.md](../state/agent_memory.md) — bundles consume these identities at Sleep diff --git a/docs/inference/inference.md b/docs/inference/inference.md new file mode 100644 index 0000000..f77b8e2 --- /dev/null +++ b/docs/inference/inference.md @@ -0,0 +1,157 @@ + + +# inference.go — TextModel + Backend + registry + +**Package**: `dappco.re/go/inference` +**File**: `go/inference.go` + +## What this is + +The load-bearing file of the whole tetrad. Five concepts: + +1. **`TextModel`** — the runtime-facing model interface (Generate, Chat, Classify, BatchGenerate, ModelType, Info, Metrics, Err, Close). +2. **`Backend`** — the platform-facing factory interface (Name, LoadModel, Available). +3. **The registry** — package-global map of name → Backend, written at `init()` time by each native driver. +4. **`Default()`** — preference resolver: metal → rocm → llama_cpp → any. +5. **`LoadModel(path, opts...)`** — top-level convenience that picks a backend and returns a ready model as a `core.Result`. + +Plus support DTOs: `Token`, `Message`, `ClassifyResult`, `BatchResult`, `GenerateMetrics`, `ModelInfo`, `AttentionSnapshot`, `AttentionInspector`. + +## TextModel + +```go +type TextModel interface { + Generate(ctx, prompt, ...GenerateOption) iter.Seq[Token] + Chat(ctx, []Message, ...GenerateOption) iter.Seq[Token] + Classify(ctx, []string, ...GenerateOption) ([]ClassifyResult, error) + BatchGenerate(ctx, []string, ...GenerateOption) ([]BatchResult, error) + ModelType() string + Info() ModelInfo + Metrics() GenerateMetrics + Err() error + Close() error +} +``` + +Generate and Chat return Go 1.23+ range-over-func iterators. Errors are +retrieved post-iteration via `Err()` — same pattern as `database/sql` +`Row.Err()`. Don't ignore it; an iterator that stops early on an error +yields the same "iterator exhausted" signal as natural EOS. + +Classify and BatchGenerate are batch calls returning slices — Classify +runs prefill-only (one forward pass per prompt, sample at the final +position) and is the fast path for classification scoring. + +## Backend + +```go +type Backend interface { + Name() string + LoadModel(path string, opts ...LoadOption) (TextModel, error) + Available() bool +} +``` + +`Available()` returns false on hardware that can't run the backend — +`metal.Available()` is false on Linux, `rocm.Available()` is false on +darwin, etc. Used by `Default()` to skip registered-but-unusable +backends. + +## Registry + +Backends register at `init()`: + +```go +// in go-mlx/register_metal.go (build-tagged darwin/arm64) +func init() { inference.Register(&metalbackend{}) } +``` + +Five operations on the global registry: + +| Function | Returns | Notes | +|----------|---------|-------| +| `Register(b Backend)` | nothing | overwrites by name | +| `Get(name)` | `(Backend, bool)` | name lookup | +| `List()` | `[]string` | sorted names | +| `All()` | `iter.Seq2[string, Backend]` | sorted iteration | +| `Default()` | `core.Result` | preference resolver | + +Preference order is hard-coded: `metal → rocm → llama_cpp → any`. The +"any" fallback iterates sorted names so behaviour is deterministic +across runs. + +## LoadModel + +```go +r := inference.LoadModel("/models/gemma3-1b") // auto +r := inference.LoadModel(path, inference.WithBackend("metal")) // explicit +r := inference.LoadModel(path, inference.WithContextLen(8192)) // tuned + +if !r.OK { return r } +model := r.Value.(TextModel) +defer model.Close() +``` + +Returns `core.Result`; the value is `TextModel`. Errors are wrapped +through the backend's name so the trace tells you which backend +refused. + +## Token / Message / ClassifyResult / BatchResult + +```go +type Token struct { ID int32; Text string } +type Message struct { Role, Content string } +type ClassifyResult struct { Token Token; Logits []float32 } +type BatchResult struct { Tokens []Token; Err error } +``` + +`Logits` is nil unless the caller passed `inference.WithLogits()` — +populating logits doubles memory pressure and is off by default. + +## GenerateMetrics + ModelInfo + +`GenerateMetrics` is the post-operation telemetry snapshot: +- Token counts (prompt, generated) +- Timings (prefill duration, decode duration, total wall-clock) +- Throughput (prefill tok/s, decode tok/s — derived) +- Memory (peak / active GPU bytes) + +`ModelInfo` is static metadata from the loaded model: +- Architecture (gemma3, qwen3, llama, …) +- VocabSize, NumLayers, HiddenSize +- QuantBits, QuantGroup + +## AttentionSnapshot / AttentionInspector + +Optional inspection interface — discovered by type assertion: + +```go +if inspector, ok := model.(inference.AttentionInspector); ok { + snap, err := inspector.InspectAttention(ctx, prompt) +} +``` + +Returns per-layer per-head K/Q tensors as flat float32 slices. Used by +go-ml capability probes and the agent-experience attention inspector +in core/ide. + +## Why a global registry + +Each backend lives in its own module behind build tags — Metal CGO +won't compile on Linux, ROCm bindings won't compile on darwin. A +caller importing `_ "dappco.re/go/mlx"` triggers its `init()` and the +backend appears in the registry; the caller's own code references no +darwin-specific symbols. + +That's the trick. The contract package compiles everywhere; backends +plug themselves in via the side-channel of init time + build tags; +consumers ask `LoadModel("...")` and get whatever's actually available +on the box. + +## Related + +- [options.md](options.md) — `GenerateOption` / `LoadOption` and the `With*` functions +- [contracts.md](contracts.md) — extended capability interfaces (Scheduler, CacheService, EmbeddingModel, RerankModel) +- [discover.md](discover.md) — `Discover()` scans a directory for model dirs +- [service.md](service.md) — Core ServiceRuntime registration +- `go-mlx/docs/runtime/register_metal.md` — the canonical Backend implementation diff --git a/docs/inference/local_tuning.md b/docs/inference/local_tuning.md new file mode 100644 index 0000000..a2371da --- /dev/null +++ b/docs/inference/local_tuning.md @@ -0,0 +1,60 @@ + + +# tuning.go — local discovery and autotune contracts + +**Package**: `dappco.re/go/inference` +**File**: `go/tuning.go` + +## What this is + +Portable DTOs and interfaces for local setup UIs. Backends use these to expose +what a machine can do, propose model-load settings for different workloads, and +stream optional smoke-test results without leaking backend-specific types. + +The important interfaces are: + +```go +type MachineDiscoverer interface { + DiscoverMachine(context.Context, MachineDiscoveryRequest) (*MachineDiscoveryReport, error) +} + +type TuningPlanner interface { + PlanTuning(context.Context, TuningPlanRequest) (*TuningPlan, error) +} +``` + +Discovery should be metadata-first: device facts, capabilities, cache modes, +and model-pack metadata where available. It should not load weights. Tuning is +separate and opt-in. + +## Workloads + +`TuningWorkload` is a stable string used in UI and persisted profiles: + +- `chat` +- `coding` +- `long_context` +- `agent_state` +- `throughput` +- `low_latency` + +## Candidate and profile + +`TuningCandidate` records the concrete settings a UI can try or save: context +length, cache policy/mode, batch size, prefill chunk size, parallel slots, +allocator limits, model identity, adapter identity, and runtime identity. + +After a smoke run, callers persist `TuningProfile`: key, candidate, +measurements, score, and labels. + +## Model replace + +`PlanModelReplace` is the conservative state decision helper: + +- same model/runtime/adapter: reuse state +- same model/adapter but runtime settings changed: checkpoint state +- model or adapter changed: compact to summary/new window + +This lets a UI change models or settings quickly while keeping the state flow +honest. + diff --git a/docs/inference/options.md b/docs/inference/options.md new file mode 100644 index 0000000..0ae8206 --- /dev/null +++ b/docs/inference/options.md @@ -0,0 +1,76 @@ + + +# options.go — GenerateOption + LoadOption + +**Package**: `dappco.re/go/inference` +**File**: `go/options.go` + +## What this is + +Two functional-option families: + +- **`GenerateOption`** — passed to Generate / Chat / Classify / BatchGenerate. Tunes sampling. +- **`LoadOption`** — passed to LoadModel / LoadTrainable. Tunes load. + +Each is `func(*Config)`; backends call `ApplyGenerateOpts(opts)` / `ApplyLoadOpts(opts)` to flatten into a `GenerateConfig` / `LoadConfig`. + +## GenerateConfig + +```go +type GenerateConfig struct { + MaxTokens int + Temperature float32 + TopK int + TopP float32 + StopTokens []int32 + RepeatPenalty float32 + ReturnLogits bool +} +``` + +`DefaultGenerateConfig()` — MaxTokens=256, Temperature=0.0 (greedy), RepeatPenalty=1.0, everything else zero. + +## With* generators + +| Function | Tunes | Typical | +|----------|-------|---------| +| `WithMaxTokens(n)` | output cap | 64 short, 256 medium, 2048 long-form | +| `WithTemperature(t)` | randomness | 0.0 greedy, 0.7 balanced, 1.5 high-variance | +| `WithTopK(k)` | top-k filter | 40 typical, 0 disabled | +| `WithTopP(p)` | nucleus | 0.9 typical, 0 disabled | +| `WithStopTokens(ids…)` | early halt | EOS id (model-specific) | +| `WithRepeatPenalty(p)` | repetition guard | 1.0 off, 1.1 mild, 1.5 strong | +| `WithLogits()` | capture logits | off by default — doubles classify memory | + +## LoadConfig + +```go +type LoadConfig struct { + Backend string // "metal" | "rocm" | "llama_cpp" | "" (auto) + ContextLen int // KV cache cap in tokens — 0 = model default + GPULayers int // -1 = all (default), 0 = CPU, n = partial + ParallelSlots int // concurrent inference slots — 0 = backend default + AdapterPath string // LoRA dir — empty = no adapter +} +``` + +`ApplyLoadOpts(opts)` starts with `GPULayers: -1` (full GPU); everything else zero. + +## With* generators (load) + +| Function | Tunes | Notes | +|----------|-------|-------| +| `WithBackend(name)` | explicit backend | overrides Default() preference order | +| `WithContextLen(n)` | KV cap | trade context vs VRAM | +| `WithGPULayers(n)` | offload | -1 all, 0 CPU, partial supported per-backend | +| `WithParallelSlots(n)` | concurrency | costs VRAM proportional to n | +| `WithAdapterPath(path)` | LoRA at load | weights stay separate from base | + +## Why functional options + +Backends grow option fields independently. Adding `WithFlashAttention(true)` doesn't touch any call site that doesn't pass it. `ApplyGenerateOpts` / `ApplyLoadOpts` flatten the chain so backends consume a plain struct internally. + +## Related + +- [inference.md](inference.md) — where GenerateOption / LoadOption are passed in +- [training.md](training.md) — `LoRAConfig` for fine-tuning loops diff --git a/docs/inference/probe.md b/docs/inference/probe.md new file mode 100644 index 0000000..43fd80f --- /dev/null +++ b/docs/inference/probe.md @@ -0,0 +1,65 @@ + + +# probe.go — observability bus DTOs + +**Package**: `dappco.re/go/inference` +**File**: `go/probe.go` + +## What this is + +The portable shape for **runtime telemetry events** that backends emit during a session. Probes are the "what's happening inside the model right now" signal — used by go-ml's scoring engine, the core/ide attention inspector, and the eval/bench pipelines. + +A backend implements `ProbeSink` to receive probes, or emits via package-injected sink for in-process subscribers. No transport policy in this file — just the DTOs. + +## Event kinds + +```go +ProbeEventToken // every generated token +ProbeEventLogits // raw logits (when ReturnLogits set) +ProbeEventEntropy // per-step sampling entropy +ProbeEventSelectedHeads // which attention heads fired +ProbeEventLayerCoherence // per-layer activation alignment +ProbeEventRouterDecision // MoE expert routing decisions +ProbeEventResidual // residual-stream magnitude +ProbeEventCachePressure // KV cache fill / eviction +ProbeEventMemoryPressure // GPU allocator state +ProbeEventTraining // SFT/LoRA/GRPO step events +``` + +## Phases + +```go +ProbePhasePrefill // initial prompt forward pass +ProbePhaseDecode // autoregressive generation +ProbePhaseTraining // SFT/LoRA/GRPO loop +``` + +## Event payload + +`ProbeEvent` carries `Kind` + `Phase` + per-event payload (numeric + label maps). The full shape is small and self-describing — `ProbeEventToken` includes the token id/text; `ProbeEventLayerCoherence` includes a per-layer float; `ProbeEventRouterDecision` includes expert indices and weights. + +## ProbeSink + +```go +type ProbeSink interface { + EmitProbe(event ProbeEvent) +} +``` + +Implemented by: + +- `go-ml/agent_eval.go` — collects probes into eval reports +- `core/api` SSE handler — streams probes to core/ide +- in-process test fixtures that just accumulate events + +A backend with no `ProbeSink` injected emits to a no-op default. + +## Why a separate file + +Probes are an extension surface, not a core capability. A minimal backend (CPU llama fallback) emits nothing but still satisfies TextModel. A research-grade backend (go-mlx with attention inspection + MoE routing) emits dozens of events per generated token. The shape is portable so consumers don't pin to one backend. + +## Related + +- [capability.md](capability.md) — `CapabilityProbeEvents` / `CapabilityAttentionProbe` / `CapabilityLogitProbe` +- `go-mlx/docs/observability/probe.md` (planned) — backend wiring +- `go-ml/docs/agent/agent_eval.md` (planned) — probe collection in eval diff --git a/docs/inference/service.md b/docs/inference/service.md new file mode 100644 index 0000000..87b512a --- /dev/null +++ b/docs/inference/service.md @@ -0,0 +1,62 @@ + + +# service.go — Core ServiceRuntime registration + +**Package**: `dappco.re/go/inference` +**File**: `go/service.go` +**Mantis**: #1336 (canonical Service.go pattern) + +## What this is + +The Core-side handle for the `inference` package — exposes the canonical `NewService(opts) + RegisterCore(c)` shape so `dappco.re/go/core` can discover the inference package as a registerable framework service. + +## The naming divergence + +Canonical pattern across the rest of the Go canon: + +```go +core.New(core.WithService(somepkg.Register)) // somepkg.Register is the registration fn +``` + +But `inference.Register(b Backend)` already exists — the init-time backend-registration call that every native driver uses: + +```go +// in go-mlx/register_metal.go +func init() { inference.Register(&metalbackend{}) } +``` + +Renaming would break every backend. So this package exposes the canonical Core registration as **`RegisterCore(c *core.Core) core.Result`** instead, leaving the existing `Register(Backend)` untouched. Both names share a package; both keep their established consumers. + +## Usage + +```go +c, _ := core.New(core.WithService(inference.NewService(inference.Options{}))) +svc := core.MustServiceFor[*inference.Service](c, "inference") + +for name, b := range inference.All() { + fmt.Printf("%s available=%v\n", name, b.Available()) +} +``` + +## Options + +```go +type Options struct{} +``` + +v1 has no fields. The package's behaviour is fully driven by which Backend implementations have called `Register(Backend)` at init time. Future fields land here as needed — preferred-backend-order override, ProbeBus subscribers, etc. + +## Service + +`*inference.Service` embeds `*core.ServiceRuntime[Options]` for typed Options access. The Service struct holds no state beyond Options + the Core handle; the real state (registered backends) lives in the package-global registry. + +## Why a thin handle + +The Service is **not the source of truth** — the global registry is. The Service is the Core-discovery surface that lets the framework's `core.ServiceFor` lookup find the package. This keeps the public-package shape stable while letting the framework treat inference like any other service for lifecycle (startup, shutdown, probes). + +A backend's init-time `Register` does not need a Core handle. A consumer calling `inference.LoadModel(path)` does not need a Core handle. The Service is purely for framework-side discovery. + +## Related + +- `core/docs/service.md` — the canonical ServiceRuntime contract +- [inference.md](inference.md) — the global Backend registry the service surfaces diff --git a/docs/inference/training.md b/docs/inference/training.md new file mode 100644 index 0000000..140a4bd --- /dev/null +++ b/docs/inference/training.md @@ -0,0 +1,78 @@ + + +# training.go — TrainableModel + Adapter contracts + +**Package**: `dappco.re/go/inference` +**File**: `go/training.go` + +## What this is + +The contract surface for **fine-tuning** — LoRA adapter management, gradient steps, save/load. Backends that can train implement `TrainableModel`; the rest don't. Same pattern as the inspection interfaces in `contracts.go` — opt-in via type assertion. + +## LoRAConfig + +```go +type LoRAConfig struct { + Rank int // decomposition rank (default 8) + Alpha float32 // scaling factor (default 16) + TargetKeys []string // projection suffixes (default: q_proj, v_proj) + BFloat16 bool // mixed-precision adapter weights +} +``` + +`DefaultLoRAConfig()` — Rank=8, Alpha=16, TargetKeys=["q_proj","v_proj"], BFloat16=false. + +Backends that don't honour `BFloat16` ignore the field (still emit a probe event so the caller knows). + +## Adapter + +```go +type Adapter interface { + // implementation-defined methods; the concrete type is backend-specific + // (e.g. *metal.LoRAAdapter for go-mlx) +} +``` + +`Adapter` is intentionally **interface-empty** — the concrete type lives in each backend. Consumers hold an `Adapter` reference for save/load/swap but never inspect its methods directly. The backend exposes the operations through its `TrainableModel`. + +## TrainableModel + +```go +type TrainableModel interface { + TextModel + AttachAdapter(cfg LoRAConfig) (Adapter, error) + DetachAdapter() error + Step(ctx, batch) (StepResult, error) // one optimiser step + SaveAdapter(path string) error + LoadAdapter(path string) error +} +``` + +(Exact method shapes are backend-defined; this file holds the umbrella interface signature.) + +## LoadTrainable + +```go +inference.LoadTrainable(path, opts...) core.Result +``` + +Top-level helper — same pattern as `LoadModel` but typed to `TrainableModel`. Backends that don't support training return a "trainable not supported on backend X" error. + +## Why training is a separate interface + +Most callers never train — they want inference. Forcing every backend to stub out training methods bloats the contract. Inference-only backends (HTTP, llama.cpp subprocess) literally cannot train; they implement `TextModel` and that's all anyone needs. + +## Implemented by + +- `go-mlx` — full training surface: SFT, LoRA, GRPO, distillation +- `go-rocm` — planned mirror +- `go-ml` does NOT implement TrainableModel — it consumes trainable models via go-mlx + +## Related + +- [capability.md](capability.md) — `CapabilityLoRATraining`, `CapabilityDistillation`, `CapabilityGRPO` +- `go-mlx/docs/training/sft.md` (planned) — reference SFT implementation +- `go-mlx/docs/training/lora_adapter.md` (planned) — LoRA Adapter concrete shape +- `go-mlx/docs/training/grpo.md` (planned) — reasoning training loop +- `go-mlx/docs/training/distill.md` (planned) — teacher/student distillation +- [../state/identity.md](../state/identity.md) — `AdapterIdentity` portable identity diff --git a/docs/ollama/ollama.md b/docs/ollama/ollama.md new file mode 100644 index 0000000..21b10a0 --- /dev/null +++ b/docs/ollama/ollama.md @@ -0,0 +1,94 @@ + + +# ollama/ollama.go — Ollama-compatible wire types + +**Package**: `dappco.re/go/inference/ollama` +**File**: `go/ollama/ollama.go` + +## What this is + +The Ollama-compatible API wire surface — DTOs for `/api/chat`, `/api/generate`, `/api/tags`, `/api/show` plus translation to `inference.Message` + `inference.GenerateOption`. Same pattern as the OpenAI and Anthropic sibling packages. + +Used by tools and IDE plugins that talk to Ollama natively (Continue, Cody, Cline, the Codex `ollama` profile) — when this surface is mounted by core/api, those tools find a local model server transparent to "is this real Ollama or core?" + +## Paths + +```go +DefaultChatPath = "/api/chat" +DefaultGeneratePath = "/api/generate" +DefaultTagsPath = "/api/tags" +DefaultShowPath = "/api/show" +``` + +## DTOs + +```go +Message // role + content (plain string, unlike Anthropic's typed blocks) +Options // temperature + top_k + top_p + num_predict +ChatRequest // model + messages + stream + options +GenerateRequest // model + prompt + stream + options +ChatResponse // model + message + done + prompt_eval_count + eval_count + durations (nanos) +GenerateResponse // model + response (text) + done + counters + durations +ModelTag // name + model + modified_at + size +TagsResponse // models[] +ShowRequest // model +ShowResponse // license + modelfile + parameters + template + details +``` + +Two response timing peculiarities to know: + +- Durations are **int64 nanoseconds**, not floats / seconds. +- `prompt_eval_count` = prompt tokens, `eval_count` = generated tokens (different field names from OpenAI / Anthropic). + +## InferenceMessages + +```go +messages := ollama.InferenceMessages(req.Messages) +``` + +Straight 1:1 map. Ollama's message shape matches `inference.Message` directly so the conversion is a slice rebuild. + +## GenerateOptions + +```go +opts := ollama.GenerateOptions(req.Options) +for tok := range model.Chat(ctx, messages, opts...) { ... } +``` + +Translates Ollama's sampler set. `num_predict` becomes `WithMaxTokens` — the Ollama name reflects its llama.cpp lineage. + +## NewChatResponse + NewGenerateResponse + +```go +chatResp := ollama.NewChatResponse(modelName, text, metrics) +genResp := ollama.NewGenerateResponse(modelName, text, metrics) +``` + +Convenience builders. `Done: true` always set — they produce single-shot responses, not streaming chunks. Streaming responses build per-chunk shapes inline at the handler. + +## /api/tags + /api/show + +`TagsResponse` mirrors the model picker — backends that implement model listing can serve this from their inventory. `ShowResponse` carries Ollama's "model details" payload (license / template / parameters) which map onto `ModelIdentity` + `TokenizerIdentity.ChatTemplate`. + +These two endpoints are read-only meta queries, no inference work — making them easy to satisfy from a backend's `Discover()` + `Inspect()` results. + +## What's not here + +- `/api/pull`, `/api/push`, `/api/copy`, `/api/delete` — model management. CoreAgent's model store has different semantics (State bundles vs Ollama tags). Not a wire-parity target. +- `/api/embeddings` — Ollama has it; CoreAgent serves embeddings via the OpenAI `/v1/embeddings` path instead. +- HTTP handler. As with `anthropic.go`, the wire DTOs are in place; the handler is roadmap. + +## Why three sibling files, not one mega-package + +The temptation is a single `wire` package with `wire.OpenAIChat`, `wire.AnthropicMessages`, `wire.OllamaChat`. We resist for three reasons: + +1. **Naming friction** — `wire.MessageRequest` is ambiguous; `anthropic.MessageRequest` isn't. +2. **Import economy** — a server that only exposes the OpenAI surface shouldn't compile Anthropic + Ollama into its binary. +3. **Independent evolution** — each upstream API changes on its own clock; isolated packages let us track each without cross-touch. + +## Related + +- [../openai/openai.md](../openai/openai.md) — OpenAI sibling +- [../anthropic/anthropic.md](../anthropic/anthropic.md) — Anthropic sibling +- [../inference/inference.md](../inference/inference.md) — base `Message` + `GenerateOption` types +- [../inference/capability.md](../inference/capability.md) — `CapabilityOllamaCompat` declares this surface diff --git a/docs/openai/README.md b/docs/openai/README.md new file mode 100644 index 0000000..36a079b --- /dev/null +++ b/docs/openai/README.md @@ -0,0 +1,60 @@ + + +# openai/ — OpenAI-compatible wire types + HTTP handlers + +**Package**: `dappco.re/go/inference/openai` + +## What this package owns + +Three things: + +1. **Wire DTOs** for the OpenAI public API surface (Chat Completions, Responses, Embeddings, Rerank, Capabilities, Cache control, Cancel). +2. **Translation** between those DTOs and the `inference` package's runtime types (`Message`, `GenerateOption`, `CapabilityReport`, etc.). +3. **HTTP handlers** that wrap an `inference.TextModel` (or capability-extended variant) and serve OpenAI-compatible requests. + +Drop-in compatible with any OpenAI SDK. Point the SDK at this handler's path and you get real local inference. + +## File map + +| File | Doc | Scope | +|------|-----|-------| +| `openai.go` | [openai.md](openai.md) | Chat Completions — DTOs + translation + Handler | +| `responses.go` | [responses.md](responses.md) | Responses API — DTOs + translation (handler TBD) | +| `services.go` | [services.md](services.md) | Embeddings / Rerank / Capabilities / Cache / Cancel handlers | + +## Resolver contract + +All handlers take a `Resolver` (defined in `openai.go`) — the indirection that maps a wire `model` field to a real `inference.TextModel`: + +```go +type Resolver interface { + ResolveModel(ctx, name) (inference.TextModel, error) +} +``` + +Three implementations ship in `openai.go`: + +- `ResolverFunc` — inline closure +- `StaticResolver` — pre-loaded `map[string]TextModel` +- `BackendResolver` — lazy `inference.Backend.LoadModel(path)` + +A custom Resolver is the right shape for: + +- Quota-checked model dispatch (resolver rejects when quota exceeded) +- Per-user model gating +- Hot-swap (resolver looks up the current pin from config service) + +## Why this package exists + +The OpenAI wire format is **inference shape**, not provider policy. Any backend can serve it. Putting the DTOs + handlers + translation here gives go-mlx, go-rocm, and any future native driver an instant HTTP frontage without each one re-implementing the wire — and lets the outbound provider in `go-ai/providers/openai` use the same DTOs from the client side. + +The opposite arrangement — DTOs in `go-ai` because OpenAI is "external" — would force every backend to depend on `go-ai`, which would then have to depend on every backend. The current shape keeps the dependency arrows pointing only **into** `inference`. + +## Related + +- [../inference/inference.md](../inference/inference.md) — `TextModel` + `Backend` interfaces +- [../inference/contracts.md](../inference/contracts.md) — `EmbeddingModel` / `RerankModel` / `CacheService` / `CancellableModel` +- [../inference/capability.md](../inference/capability.md) — `CapabilityReport` returned by `/v1/models/capabilities` +- [../anthropic/anthropic.md](../anthropic/anthropic.md) — sibling Anthropic wire types +- [../ollama/ollama.md](../ollama/ollama.md) — sibling Ollama wire types +- `go-ai/docs/providers/openai.md` (planned) — client-side outbound use of these DTOs diff --git a/docs/openai/openai.md b/docs/openai/openai.md new file mode 100644 index 0000000..d4ad8a9 --- /dev/null +++ b/docs/openai/openai.md @@ -0,0 +1,104 @@ + + +# openai/openai.go — Chat Completions wire adapter + +**Package**: `dappco.re/go/inference/openai` +**File**: `go/openai/openai.go` + +## What this is + +The OpenAI Chat Completions wire surface, adapted onto `inference.TextModel`. Three layers in one file: + +1. **DTOs** — exact request/response shapes matching the OpenAI public API. +2. **Translation** — converting between the wire shape and `inference.GenerateOption` / `inference.Message`. +3. **HTTP handler** — `Handler` that resolves a model by name and streams completions. + +Drop-in compatibility with OpenAI SDKs out of the box. A consumer points the SDK at this handler's path (`POST /v1/chat/completions`) and gets back real local inference — no SDK changes. + +## DTOs (wire-exact) + +```go +ChatCompletionRequest // model + messages + sampler (all *T optional) +ChatMessage // role + content +ChatCompletionResponse // non-streaming response +ChatChoice // index + message + finish_reason +ChatUsage // prompt_tokens + completion_tokens + total_tokens +ChatCompletionChunk // streaming SSE chunk +ChatChunkChoice // streaming choice +ChatMessageDelta // streaming delta (custom MarshalJSON) +ErrorResponse / ErrorObject +StopList // accepts either string or []string in JSON +``` + +## Defaults + +```go +DefaultTemperature = 1.0 +DefaultTopP = 0.95 +DefaultTopK = 64 +DefaultMaxTokens = 2048 +``` + +Used when the wire request has nil optional fields. + +## DecodeRequest + ValidateRequest + +```go +req, err := openai.DecodeRequest(r.Body) +err := openai.ValidateRequest(req) +``` + +DecodeRequest handles the StopList polymorphism (string vs array). ValidateRequest checks required fields + sanity bounds. + +## GenerateOptions + +```go +opts, err := openai.GenerateOptions(req) +for tok := range model.Chat(ctx, messages, opts...) { ... } +``` + +Translates wire-typed sampler fields into a slice of `inference.GenerateOption`. Stop sequences are normalised to token-id stops where possible; freeform stop strings flow through a different path. + +## NormalizeStopSequences + +```go +ids, err := openai.NormalizeStopSequences(req.Stop) +``` + +Resolves OpenAI's stop strings against the model tokenizer where the tokenizer is available. Falls back to string-mode stop on streaming if the tokenizer can't pre-tokenise the sequence. + +## Resolver + +```go +type Resolver interface { + ResolveModel(ctx, name) (inference.TextModel, error) +} +``` + +Three built-in implementations: + +| Type | Use | +|------|-----| +| `ResolverFunc` | inline closure | +| `StaticResolver` | pre-loaded `map[string]TextModel` — model-picker UI, fixed deployments | +| `BackendResolver` | lazy load via `inference.Backend.LoadModel(path)` — cold-load on first request | + +## Handler + +```go +h := openai.NewHandler(resolver) +http.Handle("/v1/chat/completions", h) +``` + +Serves both streaming (`stream: true` → SSE) and non-streaming responses. Channel-marker (`<|channel>`) support lets reasoning channels flow into a separate stream key when the model emits thinking tokens. + +## Why this lives in `inference` not in `go-ai` + +The OpenAI wire format is **inference shape**, not provider policy. Any inference backend can be a server. go-ai's outbound provider (`go-ai/providers/openai`) uses the *same DTOs* for its **client** side — that's deliberate. The router (go-ai) owns policy (rate limits, fallback, quota); the wire (this package) owns the shape both sides agree on. + +## Related + +- [responses.md](responses.md) — newer `/v1/responses` API surface +- [services.md](services.md) — embeddings / rerank / cache / cancel handlers +- `go-ai/docs/providers/openai.md` — client-side outbound provider +- `core/api` — mounts this handler when `inference.api.openai = true` diff --git a/docs/openai/responses.md b/docs/openai/responses.md new file mode 100644 index 0000000..3133aa7 --- /dev/null +++ b/docs/openai/responses.md @@ -0,0 +1,67 @@ + + +# openai/responses.go — Responses API wire shapes + +**Package**: `dappco.re/go/inference/openai` +**File**: `go/openai/responses.go` + +## What this is + +The OpenAI **Responses API** (`/v1/responses`) wire types — a newer, more structured alternative to Chat Completions that treats inputs as typed items and outputs as typed messages. Same translation pattern as Chat Completions: DTOs + `inference.Message` adapter + `inference.GenerateOption` builder. + +This is a parity item from the 2026-05-09 vMLX gap report; vMLX exposed `/v1/responses` and CoreAgent needed the same surface for SDK compatibility. + +## DTOs + +```go +ResponseInputMessage // structured input item (text / image / tool result / …) +ResponseRequest // model + input items + sampler + tools + reasoning hints +ResponseOutputText // typed text segment +ResponseOutputMessage // typed assistant message with output_text array +ResponseUsage // input_tokens + output_tokens + reasoning_tokens +Response // non-streaming response (id + model + output[] + usage) +ResponseStreamEvent // streaming event (event_type + payload) +``` + +The Responses API distinguishes **visible text** from **reasoning text** at the wire level — `ResponseUsage.ReasoningTokens` is its own count. This pairs cleanly with the `ReasoningParser` interface in `contracts.go` — backends that emit reasoning channels feed them through as separate output items. + +## Translation + +```go +messages := openai.ResponseMessages(req) // flatten input items to inference.Message +opts, err := openai.ResponseGenerateOptions(req) // sampler → GenerateOption +``` + +`ResponseMessages` walks `req.Input[]`, extracting text content and converting role + content per item. Tool-result items map to `Role: "tool"` messages. + +`ResponseGenerateOptions` follows the same logic as `GenerateOptions` in `openai.go` — the Responses API and Chat Completions accept the same sampler set. + +## NewTextResponse + +```go +resp := openai.NewTextResponse(requestID, modelName, text, metrics) +``` + +The minimal builder — produces a complete `Response` with one output message containing one text segment. Used by the handler to serialise the simple non-streaming path. Streaming responses build `ResponseStreamEvent` chunks instead. + +## Why Responses vs Chat Completions + +OpenAI introduced Responses because Chat Completions can't cleanly express: + +- Multi-modal inputs (image + text in the same turn) +- Tool-call results as typed input items, not assistant turns +- Reasoning tokens billed separately from output tokens +- Server-side state (response references the previous response) + +Local CoreAgent inference benefits from the same shape — reasoning channels are first-class, tool results flow without role abuse, server-state can be tied to wake/sleep bundles. + +## Where the handler lives + +The Responses HTTP handler is currently not in this file (the Chat Completions handler in `openai.go` is the only HTTP entry). A Responses-specific handler is on the parity-plan roadmap; the DTOs are in place so once the handler lands, the SDK side already compiles. + +## Related + +- [openai.md](openai.md) — Chat Completions counterpart +- [services.md](services.md) — embeddings/rerank/cache/cancel handlers +- [../inference/contracts.md](../inference/contracts.md) — `ReasoningParser` for emitting reasoning channels +- `go-mlx/docs/inference/thinking.md` (planned) — reasoning parser implementation diff --git a/docs/openai/services.md b/docs/openai/services.md new file mode 100644 index 0000000..ce8f634 --- /dev/null +++ b/docs/openai/services.md @@ -0,0 +1,94 @@ + + +# openai/services.go — embeddings / rerank / cache / cancel handlers + +**Package**: `dappco.re/go/inference/openai` +**File**: `go/openai/services.go` + +## What this is + +The non-chat HTTP surface — eight handlers for the auxiliary OpenAI-compatible endpoints. Each handler probes the resolved model for the right interface (`EmbeddingModel`, `RerankModel`, `CacheService`, `CancellableModel`) and 501s if the backend doesn't support it. + +Paths exposed: + +```go +DefaultEmbeddingsPath = "/v1/embeddings" +DefaultRerankPath = "/v1/rerank" +DefaultCapabilitiesPath = "/v1/models/capabilities" +DefaultCacheStatsPath = "/v1/cache/stats" +DefaultCacheWarmPath = "/v1/cache/warm" +DefaultCacheClearPath = "/v1/cache/clear" +DefaultCancelPath = "/v1/cancel" +``` + +## Handlers + +| Handler | Path | Backend interface needed | +|---------|------|--------------------------| +| `EmbeddingsHandler` | `/v1/embeddings` | `EmbeddingModel` | +| `RerankHandler` | `/v1/rerank` | `RerankModel` | +| `CapabilityHandler` | `/v1/models/capabilities` | `CapabilityReporter` | +| `CacheStatsHandler` | `/v1/cache/stats` | `CacheService` | +| `CacheWarmHandler` | `/v1/cache/warm` | `CacheService` | +| `CacheClearHandler` | `/v1/cache/clear` | `CacheService` | +| `CancelHandler` | `/v1/cancel` | `CancellableModel` | + +Each constructed via `NewXxxHandler(resolver)` — the same `Resolver` interface used by the chat handler. + +## DTOs + +```go +EmbeddingRequest // model + input + encoding_format + dimensions + normalize +EmbeddingInput // string OR []string (custom UnmarshalJSON) +EmbeddingResponse // object + data[] + model + usage +EmbeddingResponseDatum + +RerankRequest // model + query + documents + top_n +RerankResponse // results[] (index + score + text) + +CacheWarmRequest // model + tokens or prompt + labels +CacheClearRequest // labels filter +CancelRequest // request id +``` + +The capability + cache-stats GET endpoints take no body — query string `?model=X` selects which loaded model to report on. + +## EmbeddingInput polymorphism + +OpenAI's embeddings API accepts either a single string or an array. The custom `UnmarshalJSON` on `EmbeddingInput` handles both. The Go-side always sees `[]string` — single-string inputs become a one-element slice. + +## Shared handler scaffolding + +```go +type serviceHandler struct{ resolver Resolver } + +func (h *serviceHandler) resolve(...) (TextModel, bool) +func (h *serviceHandler) resolveCacheService(...) (CacheService, bool) +``` + +Each concrete handler embeds `serviceHandler` and gets the resolve helpers for free. The helper writes 4xx/5xx + JSON error responses when: + +- Resolver returns "model not found" +- Model doesn't satisfy the required capability interface +- Decode / validation fails + +## Why these are HTTP-shape primitives + +The runtime *interfaces* (`EmbeddingModel`, `RerankModel`, `CacheService`, `CancellableModel`) live in `inference/contracts.go`. This file is **just the wire layer** on top — turning HTTP requests into runtime calls and runtime results into HTTP responses. + +A non-HTTP transport (Unix socket, gRPC, MCP tool call) can use the same interfaces without involving this file. Conversely, an OpenAI-compatible server that wants the wire compatibility without going through the runtime contract can crib the DTOs here. + +## What's not here + +- `/v1/audio/transcriptions` — vMLX exposed it; we don't have audio runtime support yet (out of scope for the core runner) +- `/v1/images/generations` — same reason +- `/v1/files` — bundle-as-file maps onto agent memory, but the wire mapping isn't designed yet +- Speech endpoints — see `/v1/audio` note + +## Related + +- [openai.md](openai.md) — Chat Completions handler +- [responses.md](responses.md) — Responses API DTOs +- [../inference/contracts.md](../inference/contracts.md) — `EmbeddingModel` / `RerankModel` / `CacheService` / `CancellableModel` +- [../inference/capability.md](../inference/capability.md) — `CapabilityReport` returned by the capability handler +- `core/api` — mounts these handlers when configured diff --git a/docs/state/README.md b/docs/state/README.md new file mode 100644 index 0000000..33e347b --- /dev/null +++ b/docs/state/README.md @@ -0,0 +1,120 @@ + + +# state/ — durable model-state contracts + +**Package**: `dappco.re/go/inference/state` + +## What this package owns + +The portable, backend-neutral contracts for **storing live model state +to a durable medium and restoring it later** — what the wider stack +calls "agent memory" or "book state". Everything in here is interfaces +and DTOs; no runtime code. Backends in `go-mlx`, `go-rocm` (planned), +`go-cuda` (planned) implement these contracts; consumers in `go-ai`, +`go-ml`, `core/api` use them. + +This package was hoisted out of `dappco.re/go/inference` so the wire +shapes for state — `Bundle`, `Ref`, `Wake/Sleep/Fork` — could be +imported without dragging in the full backend-registry surface. The +parent `inference` package re-exports the most common types as +aliases (`inference.ModelIdentity = state.ModelIdentity` etc.) so +existing callers keep compiling. + +## File map + +| File | Doc | What it owns | +|------|-----|--------------| +| `agent_memory.go` | [agent_memory.md](agent_memory.md) | Wake/Sleep/Fork lifecycle DTOs + `Session` + `Forker` interfaces | +| `identity.go` | [identity.md](identity.md) | `ModelIdentity` / `TokenizerIdentity` / `AdapterIdentity` / `RuntimeIdentity` / `SamplerConfig` / `StateRef` / `Bundle` | +| `project_seed.go` | [project_seed.md](project_seed.md) | Project seed URI planning, continuation modes, and wake compatibility checks | +| `store.go` | [store.md](store.md) | `Store` / `Resolver` / `Writer` interfaces + `Chunk` / `ChunkRef` DTOs + `Resolve*` free fns + codec constants | +| `memory.go` | [memory.md](memory.md) | `InMemoryStore` — in-process test/dev backend | +| `filestore/store.go` | [filestore.md](filestore.md) | Append-only file-log durable backend | + +## Mental model + +``` + ┌───────────────────────┐ + │ Bundle (identity.go)│ ← what gets persisted + └───────────┬───────────┘ + │ contains + ┌───────────┴───────────┐ + │ []StateRef │ + │ Model/Tokenizer/etc │ + └───────────────────────┘ + ▲ + │ written by + │ + ┌──────────────────┐ │ ┌──────────────────┐ + │ Session. │─────┘ │ Session. │ + │ SleepState() │ │ WakeState() │ + │ (agent_memory) │ │ (agent_memory) │ + └─────────┬────────┘ └────────▲─────────┘ + │ produces │ consumes + ▼ │ + ┌──────────────────┐ ┌──────────┴────────┐ + │ Store.PutBytes │ │ Store.Resolve... │ + │ Writer.Put │ │ Resolver │ + │ (store.go) │ │ URIResolver │ + └─────────┬────────┘ └──────────▲────────┘ + │ │ + ▼ │ + ┌─────────────────────────────────────────┐ + │ InMemoryStore / filestore.Store │ + │ State video / object store (future) │ + └─────────────────────────────────────────┘ +``` + +A sleep produces a `Bundle` whose `KVRefs` / `ProbeRefs` / +`StateRefs` point at chunks written to some `Store`. A wake reads the +bundle, then reads each chunk back through the same Store. The two +interfaces in `agent_memory.go` (`Session` + `Forker`) are the only +runtime contracts; everything else is data. + +`project_seed.go` sits one level above those DTOs. It helps an app or agent +runner build consistent project seed URIs, choose state-checkpoint versus +summary-window continuation, and run compatibility checks before asking a +backend to wake KV. + +## Codec constants + +```go +state.CodecMemory = "memory/plaintext" // InMemoryStore +state.CodecStateVideo = "state/qr-video" // State video .mp4 +filestore.CodecFile = "state/file-log" // append-only file +``` + +A `ChunkRef` carries its codec so the wake side knows which decoder to +run — same bundle index can refer to chunks across multiple codecs if +the writer chose to spread them (rare but supported). + +## Why this package exists at all + +Three forces pushed it out of `inference`: + +1. **Cycle pressure.** `inference.Backend` wants to mention bundles + (capability reports, model-pack inspection); bundles want to + mention chunks; chunks want to mention bytes. Splitting state out + gave a clean acyclic graph. + +2. **Cross-package re-use.** `core/api` wants to serialise bundles + over HTTP without importing the full backend surface. `core/ide` + wants to display bundle indexes without linking go-mlx. Both can + now `import "dappco.re/go/inference/state"` and get just the + shapes. + +3. **Lifecycle clarity.** Wake/Sleep/Fork are a small focused + contract; storage interfaces are another. Putting them in their + own package made the "what's the smallest implementation" question + answerable without grep. + +## See also + +- [Parent inference docs](../inference/README.md) — how state is + consumed by `Backend` / `TextModel` +- [openai/services.md](../openai/services.md) — wire types that carry + `ModelIdentity` in capability reports +- `go-mlx/docs/memory/agent_memory.md` (planned) — the reference + Metal-backed Session implementation +- `go-mlx/docs/memory/state_bundle.md` (planned) — bundle + encode/decode round-trip diff --git a/docs/state/agent_memory.md b/docs/state/agent_memory.md new file mode 100644 index 0000000..23bcb45 --- /dev/null +++ b/docs/state/agent_memory.md @@ -0,0 +1,125 @@ + + +# state/agent_memory.go — Wake / Sleep / Fork lifecycle + +**Package**: `dappco.re/go/inference/state` +**File**: `go/state/agent_memory.go` +**Aliased into**: `dappco.re/go/inference` (as `AgentMemory*` for the +historical naming consumers expect) + +## What this is + +The portable contract for **persisting and restoring live model state** +without binding to a concrete storage backend. A runtime that implements +`Session` can be told to write its current KV/context as a durable +"bundle", and a runtime that implements `Forker` can re-spawn a session +from a bundle written earlier — possibly on a different machine, possibly +much later, possibly from a knowledge-pack `.mp4` that was scanned in by +phone camera. + +Three lifecycle verbs, four DTOs, two interfaces. Nothing else. + +## DTOs + +| Type | Role | +|------|------| +| `Ref` | URI-first identity for a durable state span — bundle + index + sampler/model identity + token/byte ranges. The thing you keep in your filesystem / DB / cold-storage index to point at one wake target. | +| `WakeRequest` | "Restore prefix from this URI into this session." Carries the model + tokenizer + adapter + runtime identity for compatibility checking; `Store` is an opaque runtime handle (deliberately not JSON-serialised). | +| `WakeResult` | "I restored N prefix tokens from this bundle/index, B blocks, K block size." Returned by `Session.WakeState`. | +| `SleepRequest` | "Persist the current session state to this URI, parented to that earlier URI." `ReuseParentPrefix` enables append-mode: a new bundle that shares prefix blocks with its parent — `O(delta)` writes, not full re-encode. | +| `SleepResult` | "I wrote N tokens across B blocks (R reused from parent), here is the new Ref." | + +`Store any` on both Wake/Sleep requests is the explicit escape hatch for +backend-owned handles (State video encoder, file log writer, S3 client) that +the JSON serialisation layer doesn't need to see. + +`Adapter` and `Runtime` are metadata fields, not dependency hooks. They let +orchestration decide whether waking a saved prefix is safe after adapter or +runtime settings change; the concrete backend still owns the final restore. + +## Interfaces + +```go +type Session interface { + WakeState(ctx, WakeRequest) (*WakeResult, error) + SleepState(ctx, SleepRequest) (*SleepResult, error) +} + +type Forker interface { + ForkState(ctx, WakeRequest) (Session, *WakeResult, error) +} +``` + +`Session.WakeState` restores into an **existing** session. `Forker.ForkState` +**creates** a new live session from durable state — used when you want +two divergent continuations from the same parent prefix without disturbing +the original. ForkState returns both the new Session and the wake result +so callers can either keep operating on the fork directly or hand it back +through a registry. + +## Aliases + +Consumers historically used `AgentMemory*` names (the concept predates +the package split). These are kept as type aliases so existing callers +compile without rewriting: + +```go +type AgentMemoryRef = Ref +type AgentMemoryWakeRequest = WakeRequest +type AgentMemoryWakeResult = WakeResult +type AgentMemorySleepRequest = SleepRequest +type AgentMemorySleepResult = SleepResult +type AgentMemorySession = Session +type AgentMemoryForker = Forker +``` + +The `inference` parent package re-exports these via `identity.go` so a +consumer importing only `dappco.re/go/inference` sees `AgentMemoryRef` +without needing the `state` subpackage import. + +## Where it's implemented + +- `go-mlx` — Metal-backed `Session` + `Forker`. The reference + implementation, with KV-block-level append, parent-prefix reuse, and + State video `.mp4` packaging. See `go-mlx/docs/memory/agent_memory.md`. +- `go-rocm` — planned mirror for AMD/ROCm. +- `go-cuda` — planned mirror for NVIDIA/CUDA. + +## Why URI-first + +Storage policy lives at the URI scheme, not in the contract. + +- `state://aurelius/meditations` — QR-video knowledge pack +- `file:///var/lib/coreagent/bundles/abc123/` — local filestore +- `s3://lethean-bundles/2026-05/agent-7/` — object storage +- `memory://test/fixture-1` — in-memory test harness + +A runtime that knows how to dial the URI handles the bytes; the contract +doesn't care which one ships first or which one ships best. + +## Why no streaming Wake API + +`WakeResult` reports counts (tokens / blocks / bytes), not a streaming +channel. The bytes go into the runtime's own KV cache before the result +returns — by the time you have a `WakeResult`, the session is ready to +generate. The streaming progress story is owned by `probe.go` (probe +events emitted during wake) rather than by this DTO. + +## Used by + +- `go-mlx/cmd/violet` — sidecar exposes Wake/Sleep/Fork over Unix socket +- LTHN project seeds — app/CLI orchestration can wake a per-project context, + append observations, then sleep a child state or fall back to a text summary. +- `go-ai/ai/book_state_demo.go` — teacher/student demo uses WakeResult → + `BookState` (the demo's user-facing context shape) +- `go-mlx/pkg/memvid` — deprecated compatibility path for older State video + encoder/decoder imports +- `core/ide` (planned) — agent inspector panel reads bundle index for + the "what's in my brain right now" UI + +## Validated benchmark + +92k-token book loaded into context from cold (runner not preloaded) in +**55.2s** including bundle decode + KV restore — see +`project_local_inference_topology.md`. The same bundle re-restored from +warm cache: **998ms** for a chapter, **2.15s** for the full book. diff --git a/docs/state/filestore.md b/docs/state/filestore.md new file mode 100644 index 0000000..56a469f --- /dev/null +++ b/docs/state/filestore.md @@ -0,0 +1,100 @@ + + +# state/filestore — append-only file-backed state store + +**Package**: `dappco.re/go/inference/state/filestore` +**File**: `go/state/filestore/store.go` + +## What this is + +A durable, single-file, append-only implementation of the `state.Store` +interfaces. Designed as the on-disk canonical for CoreAgent bundles +when State video packaging isn't required (most local-only +sessions). Each chunk is a self-describing record; the file as a whole +forms a write-ahead-log style history. + +## File format + +``` ++--------------------------+ +| MAGIC: "go-inference-..." | 31 bytes (or legacy go-mlx 25 bytes) ++--------------------------+ +| Record 1 | +| - magic "MVF1" (4) | +| - chunk_id (8) | +| - payload size (8) | +| - meta size (4) | +| - payload bytes ... | +| - meta JSON bytes ... | ++--------------------------+ +| Record 2 ... | ++--------------------------+ +``` + +`recordHeaderLen = 24` (4 + 8 + 8 + 4). The full record header tells +the reader exactly how many bytes to seek over for the payload and how +many for the JSON-encoded metadata. + +## Codec stamp + +```go +const CodecFile = "state/file-log" +``` + +Bundles emitted by this store identify with `Codec: CodecFile` so a +wake on a State-video-only build can detect-and-route or refuse-and-warn +based on whether the file-log decoder is compiled in. + +## Backward compatibility + +The legacy magic `go-mlx-memvid-file-log-v1\n` is still recognised on +open — older bundles written when this code lived in `go-mlx` +round-trip without rewrite. New writes always use the +`go-inference-state-file-log-v1\n` magic. + +## API + +```go +filestore.Create(ctx, path) (*Store, error) // new file +filestore.Open(ctx, path) (*Store, error) // read existing, rebuild index in RAM +``` + +Once open, `*Store` satisfies `state.Store` + `state.Resolver` + +`state.URIResolver` + `state.Writer` + `state.BinaryWriter`. Index is +held in-memory; very large bundles benefit from a future on-disk +index — currently every URI/chunk-id lookup is O(1) hash but the index +itself is O(N) memory. + +## Concurrency + +One `sync.Mutex` per `Store`. Writes append at `writeAt`, reads scan +the index then `ReadAt` from the file. Multiple goroutines can read +concurrently with one writer holding the mutex during the +append-and-fsync. + +## Failure modes + +Append-only means a crash mid-write leaves a torn record at EOF. Open +detects truncated records (header reads past EOF or payload+meta short +of declared size) and rolls `writeAt` back to the last good record — +the partial bytes are overwritten on the next Put. + +## When to use + +- Local development without a State video encoder configured +- Single-machine CoreAgent that doesn't need portable .mp4 packs +- Test fixtures that need on-disk durability between processes + +## When NOT to use + +- Cross-machine bundle sharing → State video (`.mp4`) +- Object-storage backed bundles → S3 + custom resolver +- Read-mostly cold storage → State video (compression + scan-friendly) + +## Consumed by + +- `go-mlx/cmd/violet` — when configured with a local `bundles_dir` +- `go-mlx/agent_memory.go` — preferred Store for the Wake/Sleep loop + when State video output isn't requested +- Test harnesses that need cross-test persistence (filestore lives, + in-memory dies on process exit) diff --git a/docs/state/identity.md b/docs/state/identity.md new file mode 100644 index 0000000..531e27e --- /dev/null +++ b/docs/state/identity.md @@ -0,0 +1,81 @@ + + +# state/identity.go — portable identity DTOs + +**Package**: `dappco.re/go/inference/state` +**File**: `go/state/identity.go` +**Aliased into**: `dappco.re/go/inference` (via `identity.go` — +`inference.ModelIdentity` etc. are aliases of these types) + +## What this is + +Six DTOs that travel with every durable artefact in the system: + +| Type | What it identifies | +|------|--------------------| +| `ModelIdentity` | which model produced/expects this — hash, arch, quant, ctx-len | +| `TokenizerIdentity` | which tokenizer + chat template — BOS/EOS/PAD ids, template hash | +| `AdapterIdentity` | which LoRA/adapter is active — hash, rank, alpha, target keys, base-model hash | +| `RuntimeIdentity` | which runtime/device produced it — backend name, device, version, cache mode | +| `SamplerConfig` | reproducible sampling — temp, top-k, top-p, repeat penalty, stop tokens | +| `StateRef` | typed reference to one external blob — kind, URI, hash, size, encoding | + +Plus the envelope: + +| Type | Role | +|------|------| +| `Bundle` (`StateBundle` alias) | the full state envelope a sleep emits — model + tokenizer + adapter + sampler + runtime + prompt hash + KV refs + probe refs + State refs + labels | + +## Why these are separate from `state/agent_memory.go` + +Agent memory is about lifecycle (Wake/Sleep/Fork). Identity is about +**compatibility checking** at lifecycle boundaries: + +- A wake refuses to restore a Gemma-3 bundle into a Gemma-4 session + (model arch differs). +- A wake refuses to restore an adapter-on bundle into an adapter-off + session (`AdapterIdentity.Hash` differs). +- A wake records which runtime produced the bundle so audit can trace + divergent results back to "this bundle came from go-rocm vs go-mlx". + +`Bundle.KVRefs` / `ProbeRefs` / `StateRefs` are arrays of `StateRef` +because one bundle commonly fans out to multiple blobs — KV blocks are +chunked, probes are per-layer, State frames are sequenced. + +## Why `ModelIdentity.Hash` is load-bearing + +The hash is what `WakeRequest.SkipCompatibilityCheck` flips off. By +default a wake compares `req.Model.Hash` to `bundle.Model.Hash` and +rejects on mismatch — even if the architecture matches, a quantisation +re-pack or weight delta produces a different hash and would silently +corrupt KV. + +Hash format is backend-defined (typically SHA-256 of safetensor index +file + adapter file), but the contract is "same hash → same weights → +KV is valid". + +## SamplerConfig <-> GenerateConfig + +The `state` package keeps the portable `SamplerConfig` shape. The +`inference` parent package converts to/from its richer +`GenerateConfig` (which includes `GenerateOption` plumbing) via two +free functions in `inference/identity.go`: + +```go +inference.SamplerConfigFromGenerateConfig(cfg) → SamplerConfig +inference.GenerateConfigFromSamplerConfig(cfg) → GenerateConfig +``` + +This is deliberate — the bundle stores the **outcome** of the option +choices, not the option-function chain. + +## Used by + +- `state/agent_memory.go` — `Ref` carries `StateRefs []StateRef` +- `state/store.go` — chunk metadata +- `go-mlx/state_bundle.go` — bundle encode/decode +- `go-mlx/kv_snapshot.go` — snapshot/restore stores Bundle alongside KV + blocks +- `go-ml/agent_eval.go` — eval reports embed `ModelIdentity` + + `AdapterIdentity` for reproducibility +- `core/api` benchmark surface — bench reports carry `RuntimeIdentity` diff --git a/docs/state/memory.md b/docs/state/memory.md new file mode 100644 index 0000000..fe244fd --- /dev/null +++ b/docs/state/memory.md @@ -0,0 +1,68 @@ + + +# state/memory.go — InMemoryStore + +**Package**: `dappco.re/go/inference/state` +**File**: `go/state/memory.go` + +## What this is + +The in-process reference implementation of every read and write +interface in `state/store.go`. Maps `chunk_id → text|bytes` plus an +optional `uri → chunk_id` index. Zero file I/O, zero network, zero +codec — useful for tests, fixtures, and the "spike before wiring +State path. + +## Capabilities implemented + +`*InMemoryStore` satisfies: + +- `Store` (`Get`) +- `Resolver` (`Resolve`) +- `BinaryResolver` (`ResolveBytes`) +- `URIResolver` (`ResolveURI`) +- `Writer` (`Put`) +- `BinaryWriter` (`PutBytes`) + +Not implemented: + +- `RefBinaryResolver` (falls back to `ResolveBytes(chunk_id)`) +- `BinaryStreamWriter` (in-memory has no streaming win) + +## Constructors + +```go +state.NewInMemoryStore(map[int]string{1: "hello"}) +state.NewInMemoryStoreWithManifest(chunks, refs) // pre-seed ChunkRef metadata +``` + +The "WithManifest" form is for round-tripping fixtures — you write some +chunks via `Put`, capture the returned refs, then in a later test +recreate the same store with both the text *and* the refs so chunk-id ++ codec match. + +## Codec stamp + +Every ref written by this store carries `Codec: state.CodecMemory` and +`HasFrameOffset: true` with `FrameOffset == ChunkID`. The frame-offset +mirror makes test fixtures behave the same as State bundles for code +that branches on frame addressing — the test path doesn't need a +separate "I'm in fixture mode" flag. + +## When NOT to use + +This store is not safe across goroutines without external locking. A +production session uses State video (file-backed, immutable) or filestore +(append-only on disk) for durability. Use `InMemoryStore` for: + +- Unit tests against `Resolve` / `ResolveURI` / `Put` +- Fixture seeding in example tests +- Dev workflow where the wake/sleep loop runs in-process + +## Consumed by + +- `state/state_test.go` — round-trip + URI-resolution tests +- `go-mlx/agent_memory_test.go` — runtime smoke tests against a known + in-memory store before reaching for State video +- `go-ai/ai/book_state_demo_test.go` — bookstate fixtures point at + in-memory chunks via `entry-uri memory://...` diff --git a/docs/state/project_seed.md b/docs/state/project_seed.md new file mode 100644 index 0000000..e2a4ded --- /dev/null +++ b/docs/state/project_seed.md @@ -0,0 +1,70 @@ + + +# state/project_seed.go — project-seed workflow helpers + +**Package**: `dappco.re/go/inference/state` +**File**: `go/state/project_seed.go` +**Aliased into**: `dappco.re/go/inference` + +## What this is + +Small backend-neutral helpers for the LTHN project-memory flow. They do not +load models or write bytes. They produce consistent `WakeRequest` and +`SleepRequest` values, decide whether a continuation should persist state or +fall back to summary text, and compare a saved `Bundle` with a wake request +before a runtime tries to restore KV. + +The concrete runtime still owns wake/sleep. go-mlx restores KV blocks on Metal; +go-rocm and future drivers can implement the same `Session` and `Forker` +contracts without copying app policy. + +## ProjectSeed + +`NewProjectSeed` normalises the URI set for a project: + +```go +seed := state.NewProjectSeed(state.ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", +}) +``` + +The default seed entry becomes: + +```text +state://lthn/projects/core/go-mlx/seed +state://lthn/projects/core/go-mlx/seed/bundle +state://lthn/projects/core/go-mlx/seed/index +``` + +`seed.WakeRequest(...)` carries model, tokenizer, adapter, runtime, and labels +into a normal `WakeRequest`. + +## Continuation modes + +`seed.PlanContinuation(...)` lowers product policy into concrete request shape: + +| Mode | Result | +|------|--------| +| `ProjectSeedStateCheckpoint` | returns a `SleepRequest` with parent refs and `ReuseParentPrefix=true` | +| `ProjectSeedReuseCurrent` | no sleep request; caller records findings elsewhere and keeps the current seed | +| `ProjectSeedSummaryWindow` | no sleep request; caller writes summary text and starts a fresh window | +| `ProjectSeedHybrid` | returns a sleep request and marks that summary text should also be written | + +This keeps "reply" separate from persistence. A background agent can wake, +append observations, sleep a new child state, and never emit an operator-facing +answer. + +## Compatibility + +`CheckWakeCompatibility(bundle, req)` checks the high-risk identity fields +before a wake: + +- model hash, architecture, layer count, quantisation, and context capacity +- tokenizer hash and chat template +- adapter presence/hash/path/rank +- runtime backend/cache-mode changes as warnings, not hard blockers + +When the report is incompatible, orchestration should prefer summary/new-window +or hybrid fallback. `SkipCompatibilityCheck` is still available for explicit +research runs and returns a compatible report with a warning. diff --git a/docs/state/store.md b/docs/state/store.md new file mode 100644 index 0000000..542ea11 --- /dev/null +++ b/docs/state/store.md @@ -0,0 +1,127 @@ + + +# state/store.go — chunk-addressable storage interfaces + +**Package**: `dappco.re/go/inference/state` +**File**: `go/state/store.go` + +## What this is + +The portable contract for **chunk-addressable storage** that backs the +wake/sleep lifecycle. A bundle written by `Session.SleepState` becomes a +sequence of chunks behind one of these interfaces; a wake reads them +back via `Resolve` / `ResolveBytes` / `ResolveURI`. + +Five storage capabilities expressed as separate, narrow interfaces. A +backend implements only what it can support — `Store.Get` for text, +`BinaryResolver` for bytes, `URIResolver` for State URI lookup, +`Writer` / `BinaryWriter` / `BinaryStreamWriter` for the encode side. + +## Codecs + +```go +CodecMemory = "memory/plaintext" // in-process test/dev store +CodecStateVideo = "state/qr-video" // QR-encoded MP4 cold storage +``` + +The codec field on a `ChunkRef` tells the wake side which decoder to +spin up. State video is the portable `.mp4` codec; in-memory is the +test harness; filestore is the raw local file log. + +## Capability matrix + +| Interface | Read mode | Notes | +|-----------|-----------|-------| +| `Store` | text only | minimum viable backend | +| `Resolver` | text + ref metadata | upgrades a Store with offset info | +| `BinaryResolver` | bytes | for non-text bundles (KV blocks, attention snapshots) | +| `RefBinaryResolver` | bytes via `ChunkRef` | lets the store choose chunk id OR frame offset OR segment hint | +| `URIResolver` | bytes via `uri` | for stores that index by external URI rather than int id | + +| Interface | Write mode | Notes | +|-----------|-----------|-------| +| `Writer` | text | smallest write surface | +| `BinaryWriter` | bytes in one buffer | the common path | +| `BinaryStreamWriter` | bytes via callback | for large bundles where buffering the whole payload would OOM the encoder | + +The package-level free functions (`Resolve`, `ResolveBytes`, +`ResolveRefBytes`, `ResolveURI`) take a generic `Store` and probe up to +the richer interface via type assertion — so callers always get bytes if +they ask for bytes, even when only text is implemented. + +## DTOs + +`Chunk` — what comes back from a read: + +```go +type Chunk struct { + Ref ChunkRef + Text string // empty for binary-only chunks + Data []byte // empty for text-only chunks (filled when caller asks ResolveBytes) +} +``` + +`ChunkRef` — the durable handle: + +```go +type ChunkRef struct { + ChunkID int // monotonic id within a bundle + FrameOffset uint64 // for State video: which video frame + HasFrameOffset bool // distinguishes "frame 0" from "unset" + Codec string // state/qr-video, memory/plaintext, … + Segment string // optional sub-segment id within the chunk +} +``` + +`PutOptions` — write-side metadata that the encoder retains alongside +bytes: + +```go +type PutOptions struct { + URI string + Title string + Kind string // "kv-block", "attention-snapshot", "prompt", … + Track string // sub-stream within a bundle + Tags map[string]string + Labels []string +} +``` + +## Errors + +Two typed errors, both unwrapping to `ErrChunkNotFound`: + +- `ChunkNotFoundError{ID: int}` — chunk-id miss +- `URIChunkNotFoundError{URI: string}` — URI-keyed miss + +Callers use `errors.Is(err, state.ErrChunkNotFound)` to handle both +shapes uniformly. + +## MergeRef + +`MergeRef(base, overlay ChunkRef)` is the merge primitive used when a +bundle's index is updated incrementally — overlay non-zero fields, keep +base for the rest. Lets sleep-with-parent operations carry forward the +parent's chunk identity while updating frame offsets. + +## Why not one big Store interface + +Backends differ in what they can do. A full State video store implements every interface. +A test fixture might implement only `Store.Get`. The current `inference` +package code does type-assertion probing rather than forcing every +backend to stub out methods it can't actually perform — which means a +small backend can be 50 lines, not 500. + +## Implemented by + +- `state/memory.go` — `InMemoryStore`. Test fixture + dev workflow. +- `state/filestore/store.go` — raw file log (planned canonical for + CoreAgent on-disk bundles). +- `go-mlx/pkg/memvid/filestore` — deprecated compatibility path. + +## Consumed by + +- `state/agent_memory.go` — Wake/Sleep/Fork hold a `Store any` and dial + through these interfaces +- `go-mlx/pkg/memvid` — deprecated compatibility import path for older + encoder/decoder callers diff --git a/external/go b/external/go index d661b70..f7a84db 160000 --- a/external/go +++ b/external/go @@ -1 +1 @@ -Subproject commit d661b703e16183b3cbab101de189f688888a1174 +Subproject commit f7a84db6ce08722dc3d42ad72ed9094621fca992 diff --git a/go.work b/go.work index 9201445..b8920d4 100644 --- a/go.work +++ b/go.work @@ -1,4 +1,4 @@ -go 1.26.0 +go 1.26.2 // Workspace mode for development: pulls local sources from external/ submodules. // diff --git a/go/anthropic/anthropic.go b/go/anthropic/anthropic.go new file mode 100644 index 0000000..3cc443e --- /dev/null +++ b/go/anthropic/anthropic.go @@ -0,0 +1,381 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package anthropic provides Anthropic Messages wire primitives over the +// shared inference contracts. +package anthropic + +import ( + core "dappco.re/go" + "dappco.re/go/inference" + "dappco.re/go/inference/jsonenc" +) + +// DefaultMessagesPath is the Anthropic-compatible Messages endpoint. +const DefaultMessagesPath = "/v1/messages" + +// ContentBlock is the text block shape used by Anthropic Messages. +type ContentBlock struct { + Type string `json:"type"` + Text string `json:"text,omitempty"` +} + +// Message is one Anthropic chat turn. +type Message struct { + Role string `json:"role"` + Content []ContentBlock `json:"content"` +} + +// MessageRequest is the minimal Anthropic-compatible request shape. +type MessageRequest struct { + Model string `json:"model"` + System string `json:"system,omitempty"` + Messages []Message `json:"messages"` + MaxTokens int `json:"max_tokens"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + Stream bool `json:"stream,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` +} + +// Usage records Anthropic-style token accounting. +type Usage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` +} + +// MessageResponse is the non-streaming Anthropic-compatible response body. +type MessageResponse struct { + ID string `json:"id"` + Type string `json:"type"` + Role string `json:"role"` + Model string `json:"model"` + Content []ContentBlock `json:"content"` + StopReason string `json:"stop_reason,omitempty"` + StopSequence string `json:"stop_sequence,omitempty"` + Usage Usage `json:"usage"` +} + +// AppendMessageResponse walks an Anthropic MessageResponse into the +// caller-owned buf and returns the extended slice. Fires at the HTTP- +// response-emit boundary on every non-streaming completion — callers +// bypass the encoding/json reflect path (encoder state + grow-doubled +// output buffer + per-nested-struct allocations) and pre-size the +// buffer once via MessageResponseSize. Same caller-passes-buf shape +// as state/filestore's encodeRecordMeta (W8-D) and openai's +// appendChatCompletionResponse (W9-D). +// +// MarshalJSON is deliberately NOT implemented on MessageResponse: the +// bench for core.JSONMarshalString shows that wrapping a flat struct +// in a MarshalJSON method REGRESSES json.Marshal — encoding/json then +// calls MarshalJSON, validates (compact) the returned bytes, then +// copies them into its own grow-buffer. The hand-roll wins only when +// the call site bypasses json.Marshal and calls this helper directly. +// +// Wire-compatible with json.Marshal across every branch: +// - Always emits id, type, role, model, content, usage. +// - stop_reason / stop_sequence: omitempty (string). +// - content: each ContentBlock emits type always, text only when +// non-empty (matches ContentBlock's `text,omitempty` tag). +// - usage: always emits input_tokens + output_tokens (no +// omitempty). +// +// Output round-trips through core.JSONUnmarshal back into a +// MessageResponse — verified by the round-trip pinning test. +// +// buf := AppendMessageResponse(make([]byte, 0, MessageResponseSize(resp)), resp) +// w.Write(buf) // typical HTTP-emit shape. +func AppendMessageResponse(buf []byte, r MessageResponse) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "id", r.ID, false) + buf = jsonenc.AppendStringField(buf, "type", r.Type, true) + buf = jsonenc.AppendStringField(buf, "role", r.Role, true) + buf = jsonenc.AppendStringField(buf, "model", r.Model, true) + buf = append(buf, ',', '"', 'c', 'o', 'n', 't', 'e', 'n', 't', '"', ':', '[') + for i, b := range r.Content { + if i > 0 { + buf = append(buf, ',') + } + buf = appendContentBlock(buf, b) + } + buf = append(buf, ']') + if r.StopReason != "" { + buf = jsonenc.AppendStringField(buf, "stop_reason", r.StopReason, true) + } + if r.StopSequence != "" { + buf = jsonenc.AppendStringField(buf, "stop_sequence", r.StopSequence, true) + } + // Usage object — always emitted (no omitempty on the field). + buf = append(buf, ',', '"', 'u', 's', 'a', 'g', 'e', '"', ':', '{') + buf = jsonenc.AppendIntField(buf, "input_tokens", r.Usage.InputTokens, false) + buf = jsonenc.AppendIntField(buf, "output_tokens", r.Usage.OutputTokens, true) + return append(buf, '}', '}') +} + +// MessageResponseSize estimates the backing-buffer size for one +// MessageResponse so the caller's make([]byte, 0, ...) lands on a +// memory class that fits the encoded body in a single allocation. +// Returns a tight upper bound — ASCII key bytes plus the string- +// value bodies. Worst-case escape doubling on text fields lets +// append grow once at most. +func MessageResponseSize(r MessageResponse) int { + // Per-field cost: ,"key":"value" + // leading-comma (1) + "key" (len(key)+2) + : (1) + "value" (len(value)+2) + // = 6 + len(key) + len(value) + // First field omits leading comma: 5 + len(key) + len(value). + size := 2 // outer braces + size += 5 + 2 + len(r.ID) // "id":"…" + size += 6 + 4 + len(r.Type) // ,"type":"…" + size += 6 + 4 + len(r.Role) // ,"role":"…" + size += 6 + 5 + len(r.Model) // ,"model":"…" + size += 6 + 7 // ,"content":[] + for i, b := range r.Content { + size += 5 + 2 + 4 + len(b.Type) // {"type":"X"} + if b.Text != "" { + size += 6 + 4 + len(b.Text) // ,"text":"X" + } + size += 1 // closing brace } + if i > 0 { + size += 1 // , separator between blocks + } + } + if r.StopReason != "" { + size += 6 + 11 + len(r.StopReason) // ,"stop_reason":"X" + } + if r.StopSequence != "" { + size += 6 + 13 + len(r.StopSequence) // ,"stop_sequence":"X" + } + // ,"usage":{"input_tokens":N,"output_tokens":N} + // 9 ("usage":) + 2 (object braces) + 5+2+10+1+11+11+10+1+11 ≈ 60 + size += 6 + 5 + 2 + 26 + 28 + return size +} + +// appendContentBlock encodes a single ContentBlock as JSON onto buf. +// type is always emitted; text is omitted when empty (matches the +// `text,omitempty` tag on the struct). Lifted out so +// AppendMessageResponse / AppendMessageRequest and future content-array +// shapes share it. +func appendContentBlock(buf []byte, b ContentBlock) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "type", b.Type, false) + if b.Text != "" { + buf = jsonenc.AppendStringField(buf, "text", b.Text, true) + } + return append(buf, '}') +} + +// appendMessage encodes a single chat-turn Message as JSON onto buf. +// role + content always emitted; content is an array of ContentBlocks. +func appendMessage(buf []byte, m Message) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "role", m.Role, false) + buf = append(buf, ',', '"', 'c', 'o', 'n', 't', 'e', 'n', 't', '"', ':', '[') + for i, b := range m.Content { + if i > 0 { + buf = append(buf, ',') + } + buf = appendContentBlock(buf, b) + } + return append(buf, ']', '}') +} + +// AppendMessageRequest walks an Anthropic MessageRequest into the +// caller-owned buf and returns the extended slice. Fires at the +// client-side request-encode boundary — proxies and SDK clients pay +// 2 allocs / 480-3500 B through json.Marshal's reflect path even +// before per-field pointer-allocation cost. The hand-rolled encoder +// lands at a single buffer allocation regardless of pointer-field +// count and slice depth. +// +// Wire-compatible with json.Marshal across every branch: +// - model + messages + max_tokens always emitted (no omitempty). +// - system: omitempty (string). +// - temperature / top_p / top_k: omitempty (pointer); emitted as +// number only when non-nil. +// - stream: omitempty (bool); emitted as true only when true. +// - stop_sequences: omitempty (slice); emitted as JSON array of +// strings when len > 0. +// +// buf := AppendMessageRequest(make([]byte, 0, MessageRequestSize(req)), req) +// httpClient.Post(url, "application/json", bytes.NewReader(buf)) +func AppendMessageRequest(buf []byte, r MessageRequest) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "model", r.Model, false) + if r.System != "" { + buf = jsonenc.AppendStringField(buf, "system", r.System, true) + } + buf = append(buf, ',', '"', 'm', 'e', 's', 's', 'a', 'g', 'e', 's', '"', ':', '[') + for i, m := range r.Messages { + if i > 0 { + buf = append(buf, ',') + } + buf = appendMessage(buf, m) + } + buf = append(buf, ']') + buf = jsonenc.AppendIntField(buf, "max_tokens", r.MaxTokens, true) + if r.Temperature != nil { + buf = jsonenc.AppendFloat32Field(buf, "temperature", *r.Temperature, true) + } + if r.TopP != nil { + buf = jsonenc.AppendFloat32Field(buf, "top_p", *r.TopP, true) + } + if r.TopK != nil { + buf = jsonenc.AppendIntField(buf, "top_k", *r.TopK, true) + } + if r.Stream { + buf = jsonenc.AppendBoolField(buf, "stream", true, true) + } + if len(r.StopSequences) > 0 { + buf = append(buf, ',', '"', 's', 't', 'o', 'p', '_', 's', 'e', 'q', 'u', 'e', 'n', 'c', 'e', 's', '"', ':', '[') + for i, s := range r.StopSequences { + if i > 0 { + buf = append(buf, ',') + } + buf = jsonenc.AppendJSONString(buf, s) + } + buf = append(buf, ']') + } + return append(buf, '}') +} + +// MessageRequestSize estimates a tight upper bound for the backing +// buffer one MessageRequest needs so the caller's make([]byte, 0, +// MessageRequestSize(req)) lands on a memory class that fits the +// encoded body in a single allocation. +// +// Per-field overhead = ,"key": as documented in +// MessageResponseSize. Pointer/bool/slice fields fold in only when +// they would emit under the omitempty contract. +func MessageRequestSize(r MessageRequest) int { + size := 2 // outer braces + size += 5 + 5 + len(r.Model) // "model":"…" + if r.System != "" { + size += 6 + 6 + len(r.System) // ,"system":"…" + } + size += 6 + 8 // ,"messages":[] + for i, m := range r.Messages { + // {"role":"…","content":[…]} + size += 5 + 4 + len(m.Role) + size += 6 + 7 // ,"content":[] + for j, b := range m.Content { + size += 5 + 2 + 4 + len(b.Type) // {"type":"X"} + if b.Text != "" { + size += 6 + 4 + len(b.Text) // ,"text":"X" + } + size += 1 // } + if j > 0 { + size += 1 // , + } + } + size += 1 // } + if i > 0 { + size += 1 // , + } + } + size += 6 + 10 + 20 // ,"max_tokens":N (20-digit int) + if r.Temperature != nil { + size += 6 + 11 + 24 // ,"temperature":F (24-byte float) + } + if r.TopP != nil { + size += 6 + 5 + 24 // ,"top_p":F + } + if r.TopK != nil { + size += 6 + 5 + 20 // ,"top_k":N + } + if r.Stream { + size += 6 + 6 + 4 // ,"stream":true + } + if len(r.StopSequences) > 0 { + size += 6 + 14 // ,"stop_sequences":[] + for i, s := range r.StopSequences { + size += 2 + len(s) // "X" + if i > 0 { + size += 1 // , + } + } + } + return size +} + +// InferenceMessages converts Anthropic messages into shared inference messages. +func InferenceMessages(req MessageRequest) []inference.Message { + out := make([]inference.Message, 0, len(req.Messages)+1) + if req.System != "" { + out = append(out, inference.Message{Role: "system", Content: req.System}) + } + for _, msg := range req.Messages { + out = append(out, inference.Message{Role: msg.Role, Content: blockText(msg.Content)}) + } + return out +} + +// GenerateOptions converts Anthropic sampling fields into inference options. +func GenerateOptions(req MessageRequest) []inference.GenerateOption { + opts := make([]inference.GenerateOption, 0, 4) + if req.MaxTokens > 0 { + opts = append(opts, inference.WithMaxTokens(req.MaxTokens)) + } + if req.Temperature != nil { + opts = append(opts, inference.WithTemperature(*req.Temperature)) + } + if req.TopP != nil { + opts = append(opts, inference.WithTopP(*req.TopP)) + } + if req.TopK != nil { + opts = append(opts, inference.WithTopK(*req.TopK)) + } + return opts +} + +// NewTextResponse builds a text response from shared inference metrics. +func NewTextResponse(id, model, text string, metrics inference.GenerateMetrics) MessageResponse { + return MessageResponse{ + ID: id, + Type: "message", + Role: "assistant", + Model: model, + Content: []ContentBlock{{Type: "text", Text: text}}, + StopReason: "end_turn", + Usage: Usage{ + InputTokens: metrics.PromptTokens, + OutputTokens: metrics.GeneratedTokens, + }, + } +} + +func blockText(blocks []ContentBlock) string { + // Fast paths — common cases produce 0 or 1 string without + // touching the builder. Per-message hot path; InferenceMessages + // calls this once per Anthropic content array on every request. + if len(blocks) == 0 { + return "" + } + if len(blocks) == 1 { + b := blocks[0] + if b.Type == "" || b.Type == "text" { + return b.Text + } + return "" + } + // Multi-block: pre-sum then Grow the builder once. Previous shape + // (out += block.Text) was O(N²) — each += reallocated and copied + // the entire prefix. + total := 0 + for _, block := range blocks { + if block.Type == "" || block.Type == "text" { + total += len(block.Text) + } + } + if total == 0 { + return "" + } + builder := core.NewBuilder() + builder.Grow(total) + for _, block := range blocks { + if block.Type == "" || block.Type == "text" { + builder.WriteString(block.Text) + } + } + return builder.String() +} diff --git a/go/anthropic/anthropic_bench_test.go b/go/anthropic/anthropic_bench_test.go new file mode 100644 index 0000000..d246448 --- /dev/null +++ b/go/anthropic/anthropic_bench_test.go @@ -0,0 +1,310 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the Anthropic Messages wire primitives. +// Per AX-11 — Marshal/Unmarshal of MessageRequest/MessageResponse fires +// once per Messages call, and InferenceMessages / GenerateOptions run +// at request-entry on every served chat turn. blockText is the +// per-content-block inner loop that runs over every message in the +// request transcript on every call. +// +// Run: go test -bench='BenchmarkAnthropic' -benchtime=100ms -benchmem -run='^$' . + +package anthropic + +import ( + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + anthropicSinkRequest MessageRequest + anthropicSinkResponse MessageResponse + anthropicSinkMessages []inference.Message + anthropicSinkOptions []inference.GenerateOption + anthropicSinkResult core.Result + anthropicSinkString string + anthropicSinkText string + anthropicSinkBytes []byte +) + +// --- Fixture builders --- + +// buildAnthropicRequest produces a representative system+user+assistant +// transcript with the requested number of message turns. Each user +// message carries the typical short query shape; assistant turns carry +// longer multi-paragraph completions. +func buildAnthropicRequest(turns int) MessageRequest { + temp := float32(0.7) + topP := float32(0.95) + topK := 64 + req := MessageRequest{ + Model: "claude-3-5-sonnet", + System: "You are a helpful assistant. Be concise.", + MaxTokens: 1024, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + StopSequences: []string{"", "<|eot_id|>"}, + } + user := "Please summarise the following short paragraph for me in one sentence." + assistant := "The summary is concise and faithful to the original text. " + + "It preserves the principal claim and the supporting detail without padding." + for i := 0; i < turns; i++ { + role := "user" + text := user + if i%2 == 1 { + role = "assistant" + text = assistant + } + req.Messages = append(req.Messages, Message{ + Role: role, + Content: []ContentBlock{{Type: "text", Text: text}}, + }) + } + return req +} + +// buildAnthropicResponse mirrors a real completion — multi-block text +// content with a trailing usage block. +func buildAnthropicResponse() MessageResponse { + return NewTextResponse( + "msg_bench", + "claude-3-5-sonnet", + "The summary is concise and faithful to the original text.", + inference.GenerateMetrics{PromptTokens: 320, GeneratedTokens: 48}, + ) +} + +// --- JSON Marshal — fires at response emission --- + +func BenchmarkAnthropic_MarshalMessageRequest_SingleTurn(b *testing.B) { + req := buildAnthropicRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkAnthropic_MarshalMessageRequest_FiveTurn(b *testing.B) { + req := buildAnthropicRequest(5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkAnthropic_MarshalMessageRequest_TwentyTurn(b *testing.B) { + req := buildAnthropicRequest(20) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkAnthropic_MarshalMessageResponse_Typical(b *testing.B) { + resp := buildAnthropicResponse() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkString = core.JSONMarshalString(resp) + } +} + +// --- Hand-rolled AppendMessageResponse — bypasses json.Marshal +// reflect path. Wins are visible when consumers reach for the helper +// directly (HTTP-response-emit), not when measured via JSONMarshalString. +// Per-W9-D pattern: caller pre-sizes the buffer once via the +// MessageResponseSize estimator so encoding lands at 1 alloc. + +func BenchmarkAnthropic_AppendMessageResponse_Typical(b *testing.B) { + resp := buildAnthropicResponse() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkBytes = AppendMessageResponse(make([]byte, 0, MessageResponseSize(resp)), resp) + } +} + +func BenchmarkAnthropic_AppendMessageResponse_WithStopReason(b *testing.B) { + resp := buildAnthropicResponse() + resp.StopReason = "stop_sequence" + resp.StopSequence = "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkBytes = AppendMessageResponse(make([]byte, 0, MessageResponseSize(resp)), resp) + } +} + +// --- Hand-rolled AppendMessageRequest — client-side request encode. +// Outbound proxy / SDK path serialises one MessageRequest per turn. + +func BenchmarkAnthropic_AppendMessageRequest_SingleTurn(b *testing.B) { + req := buildAnthropicRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkBytes = AppendMessageRequest(make([]byte, 0, MessageRequestSize(req)), req) + } +} + +func BenchmarkAnthropic_AppendMessageRequest_FiveTurn(b *testing.B) { + req := buildAnthropicRequest(5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkBytes = AppendMessageRequest(make([]byte, 0, MessageRequestSize(req)), req) + } +} + +func BenchmarkAnthropic_AppendMessageRequest_TwentyTurn(b *testing.B) { + req := buildAnthropicRequest(20) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkBytes = AppendMessageRequest(make([]byte, 0, MessageRequestSize(req)), req) + } +} + +// --- JSON Unmarshal — fires at request entry --- + +func BenchmarkAnthropic_UnmarshalMessageRequest_SingleTurn(b *testing.B) { + body := core.JSONMarshalString(buildAnthropicRequest(1)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req MessageRequest + anthropicSinkResult = core.JSONUnmarshalString(body, &req) + anthropicSinkRequest = req + } +} + +func BenchmarkAnthropic_UnmarshalMessageRequest_FiveTurn(b *testing.B) { + body := core.JSONMarshalString(buildAnthropicRequest(5)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req MessageRequest + anthropicSinkResult = core.JSONUnmarshalString(body, &req) + anthropicSinkRequest = req + } +} + +func BenchmarkAnthropic_UnmarshalMessageRequest_TwentyTurn(b *testing.B) { + body := core.JSONMarshalString(buildAnthropicRequest(20)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req MessageRequest + anthropicSinkResult = core.JSONUnmarshalString(body, &req) + anthropicSinkRequest = req + } +} + +func BenchmarkAnthropic_UnmarshalMessageResponse_Typical(b *testing.B) { + body := core.JSONMarshalString(buildAnthropicResponse()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var resp MessageResponse + anthropicSinkResult = core.JSONUnmarshalString(body, &resp) + anthropicSinkResponse = resp + } +} + +// --- InferenceMessages — wire→internal conversion fired per request --- + +func BenchmarkAnthropic_InferenceMessages_SingleTurn(b *testing.B) { + req := buildAnthropicRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkMessages = InferenceMessages(req) + } +} + +func BenchmarkAnthropic_InferenceMessages_FiveTurn(b *testing.B) { + req := buildAnthropicRequest(5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkMessages = InferenceMessages(req) + } +} + +func BenchmarkAnthropic_InferenceMessages_TwentyTurn(b *testing.B) { + req := buildAnthropicRequest(20) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkMessages = InferenceMessages(req) + } +} + +// --- GenerateOptions — sampling-field projection fired per request --- + +func BenchmarkAnthropic_GenerateOptions_AllFieldsSet(b *testing.B) { + req := buildAnthropicRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkOptions = GenerateOptions(req) + } +} + +func BenchmarkAnthropic_GenerateOptions_MinimalFields(b *testing.B) { + req := MessageRequest{Model: "claude-3-5-sonnet", MaxTokens: 256} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkOptions = GenerateOptions(req) + } +} + +// --- NewTextResponse — fires once per non-streaming completion --- + +func BenchmarkAnthropic_NewTextResponse(b *testing.B) { + metrics := inference.GenerateMetrics{PromptTokens: 320, GeneratedTokens: 48} + text := "The summary is concise and faithful to the original text." + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkResponse = NewTextResponse("msg_bench", "claude-3-5-sonnet", text, metrics) + } +} + +// --- blockText — per-content-block inner loop (unexported; reached via +// InferenceMessages but worth a direct bench at the boundary shape). --- +// Single text block — the dominant production shape. + +func BenchmarkAnthropic_BlockText_SingleTextBlock(b *testing.B) { + blocks := []ContentBlock{{Type: "text", Text: "hello world"}} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkText = blockText(blocks) + } +} + +// Multi-block — the streamed-back shape with prompt caching headers +// splitting an instruction prefix from the user payload. +func BenchmarkAnthropic_BlockText_FiveBlocks(b *testing.B) { + blocks := []ContentBlock{ + {Type: "text", Text: "You are a helpful assistant. "}, + {Type: "text", Text: "Always respond in UK English. "}, + {Type: "text", Text: "Be concise. "}, + {Type: "text", Text: "Summarise the following paragraph: "}, + {Type: "text", Text: "The quick brown fox jumps over the lazy dog."}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + anthropicSinkText = blockText(blocks) + } +} diff --git a/go/anthropic/anthropic_test.go b/go/anthropic/anthropic_test.go new file mode 100644 index 0000000..e877999 --- /dev/null +++ b/go/anthropic/anthropic_test.go @@ -0,0 +1,50 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package anthropic + +import ( + "testing" + + "dappco.re/go/inference" +) + +func TestAnthropic_InferenceMessages_Good(t *testing.T) { + req := MessageRequest{ + System: "system", + Messages: []Message{{ + Role: "user", + Content: []ContentBlock{{Type: "text", Text: "hello"}}, + }}, + } + + messages := InferenceMessages(req) + + if len(messages) != 2 { + t.Fatalf("len(messages) = %d, want 2", len(messages)) + } + if messages[0].Role != "system" || messages[1].Content != "hello" { + t.Fatalf("messages = %+v", messages) + } +} + +func TestAnthropic_GenerateOptions_Good(t *testing.T) { + temp := float32(0.2) + topK := 4 + opts := GenerateOptions(MessageRequest{MaxTokens: 9, Temperature: &temp, TopK: &topK}) + + cfg := inference.ApplyGenerateOpts(opts) + if cfg.MaxTokens != 9 || cfg.Temperature != 0.2 || cfg.TopK != 4 { + t.Fatalf("cfg = %+v", cfg) + } +} + +func TestAnthropic_NewTextResponse_Good(t *testing.T) { + resp := NewTextResponse("msg_1", "claude-ish", "ok", inference.GenerateMetrics{PromptTokens: 2, GeneratedTokens: 3}) + + if resp.ID != "msg_1" || resp.Type != "message" || resp.Role != "assistant" { + t.Fatalf("resp = %+v", resp) + } + if resp.Content[0].Text != "ok" || resp.Usage.OutputTokens != 3 { + t.Fatalf("resp = %+v", resp) + } +} diff --git a/go/anthropic/jsondec.go b/go/anthropic/jsondec.go new file mode 100644 index 0000000..e950328 --- /dev/null +++ b/go/anthropic/jsondec.go @@ -0,0 +1,557 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled JSON-decoding for the Anthropic Messages wire types. +// Fires at HTTP request-entry per Messages call — the encoding/json +// reflect path costs 26-107 allocs for the canonical 1/5/20-turn +// shapes (encoder state machine, per-field reflect.Value boxing, +// per-string allocation, per-pointer-field heap allocation). +// +// The single-pass walker per type lands at ~6-10 allocs for typical +// shapes — predominantly the per-string clones the wire contract +// already requires. Slice fields are pre-sized when the array length +// is cheap to count; pointer fields skip the per-field heap escape +// by stack-allocating the indirected value and taking address. +// +// Each UnmarshalJSON returns errors via the package-local +// resultError shape (matches the encoding/json contract — wrapped +// for the caller's `core.JSONUnmarshal*` Result) so existing tests +// continue to receive a single error. + +package anthropic + +import ( + "dappco.re/go/inference/jsonenc" +) + +// UnmarshalJSON walks the MessageRequest wire shape in a single pass. +// Wire-compatible with json.Unmarshal across every branch: +// - model, system, messages, max_tokens, temperature, top_p, +// top_k, stream, stop_sequences — dispatched by exact key +// byte-compare. +// - Unknown keys SkipJSONValue past — matches encoding/json's +// default decoder behaviour (silent ignore unless DisallowUnknownFields +// is set, which this package does not). +// - Pointer fields (Temperature, TopP, TopK) point at heap copies +// of the parsed value only when the field is present and not +// null — same as the reflect path. +// - StopSequences via jsonenc.ParseJSONStringList (string or +// array of strings, plus null). +// +// Allocations come from: +// - One per parsed string (model/system/role/content text). Same +// floor encoding/json pays. +// - One per non-empty Messages slice (pre-sized via prescanning the +// array length). +// - One per non-empty Content slice within each Message. +// - One per non-nil pointer field (Temperature, TopP, TopK). +func (r *MessageRequest) UnmarshalJSON(data []byte) error { + *r = MessageRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +// unmarshalField dispatches one MessageRequest field by key. Returns +// the index one past the consumed value (which may itself be an +// object or array). +func (r *MessageRequest) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "system": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.System = s + return next, nil + case "messages": + msgs, next, err := parseMessageArray(data, i) + if err != nil { + return next, err + } + r.Messages = msgs + return next, nil + case "max_tokens": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.MaxTokens = int(n) + return next, nil + case "temperature": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONFloat32(data, i) + if err != nil { + return next, err + } + r.Temperature = &v + return next, nil + case "top_p": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONFloat32(data, i) + if err != nil { + return next, err + } + r.TopP = &v + return next, nil + case "top_k": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + k := int(v) + r.TopK = &k + return next, nil + case "stream": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONBool(data, i) + if err != nil { + return next, err + } + r.Stream = v + return next, nil + case "stop_sequences": + next, err := jsonenc.SkipJSONValue(data, i) + if err != nil { + return next, err + } + stops, err := jsonenc.ParseJSONStringList(data[i:next]) + if err != nil { + return next, err + } + r.StopSequences = stops + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// UnmarshalJSON walks the MessageResponse wire shape in a single pass. +// Same dispatch pattern as MessageRequest; covers every field the +// hand-rolled AppendMessageResponse emits. +func (r *MessageResponse) UnmarshalJSON(data []byte) error { + *r = MessageResponse{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +// unmarshalField dispatches one MessageResponse field by key. +func (r *MessageResponse) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "id": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.ID = s + return next, nil + case "type": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Type = s + return next, nil + case "role": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Role = s + return next, nil + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "content": + blocks, next, err := parseContentBlockArray(data, i) + if err != nil { + return next, err + } + r.Content = blocks + return next, nil + case "stop_reason": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.StopReason = s + return next, nil + case "stop_sequence": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.StopSequence = s + return next, nil + case "usage": + usage, next, err := parseUsage(data, i) + if err != nil { + return next, err + } + r.Usage = usage + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// parseMessageArray walks a JSON array of Message objects at data[i]. +// Uses append-grow rather than a CountJSONArrayElements prescan: the +// prescan walks the whole array via SkipJSONValue twice (once to +// count, once to parse) and costs more than the append-double cascade +// it would have saved (single-turn 4.1 µs vs 2.6 µs without). +func parseMessageArray(data []byte, i int) ([]Message, int, error) { + if jsonenc.IsJSONNull(data, i) { + return nil, i + 4, nil + } + i, err := jsonenc.MatchArrayStart(data, i) + if err != nil { + return nil, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == ']' { + return nil, i + 1, nil + } + var out []Message + for { + msg, next, err := parseMessage(data, i) + if err != nil { + return nil, next, err + } + out = append(out, msg) + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) { + return nil, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i = jsonenc.SkipJSONWhitespace(data, i+1) + continue + } + if data[i] == ']' { + return out, i + 1, nil + } + return nil, i, jsonenc.ErrInvalidJSON + } +} + +// parseMessage walks a single Message object at data[i]. +func parseMessage(data []byte, i int) (Message, int, error) { + var msg Message + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return msg, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return msg, i + 1, nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return msg, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return msg, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return msg, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "role": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return msg, vnext, verr + } + msg.Role = s + i = vnext + case "content": + blocks, vnext, verr := parseContentBlockArray(data, i) + if verr != nil { + return msg, vnext, verr + } + msg.Content = blocks + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return msg, vnext, verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return msg, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return msg, i + 1, nil + } + return msg, i, jsonenc.ErrInvalidJSON + } +} + +// parseContentBlockArray walks a JSON array of ContentBlock objects. +// append-grow path — content arrays typically carry 1-3 blocks per +// turn, well under the first-grow threshold. +func parseContentBlockArray(data []byte, i int) ([]ContentBlock, int, error) { + if jsonenc.IsJSONNull(data, i) { + return nil, i + 4, nil + } + i, err := jsonenc.MatchArrayStart(data, i) + if err != nil { + return nil, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == ']' { + return nil, i + 1, nil + } + var out []ContentBlock + for { + block, next, err := parseContentBlock(data, i) + if err != nil { + return nil, next, err + } + out = append(out, block) + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) { + return nil, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i = jsonenc.SkipJSONWhitespace(data, i+1) + continue + } + if data[i] == ']' { + return out, i + 1, nil + } + return nil, i, jsonenc.ErrInvalidJSON + } +} + +// parseContentBlock walks a single ContentBlock object at data[i]. +func parseContentBlock(data []byte, i int) (ContentBlock, int, error) { + var block ContentBlock + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return block, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return block, i + 1, nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return block, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return block, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return block, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "type": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return block, vnext, verr + } + block.Type = s + i = vnext + case "text": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return block, vnext, verr + } + block.Text = s + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return block, vnext, verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return block, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return block, i + 1, nil + } + return block, i, jsonenc.ErrInvalidJSON + } +} + +// parseUsage walks a Usage object at data[i]. +func parseUsage(data []byte, i int) (Usage, int, error) { + var u Usage + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return u, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return u, i + 1, nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return u, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return u, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return u, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "input_tokens": + n, vnext, verr := jsonenc.ParseJSONInt(data, i) + if verr != nil { + return u, vnext, verr + } + u.InputTokens = int(n) + i = vnext + case "output_tokens": + n, vnext, verr := jsonenc.ParseJSONInt(data, i) + if verr != nil { + return u, vnext, verr + } + u.OutputTokens = int(n) + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return u, vnext, verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return u, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return u, i + 1, nil + } + return u, i, jsonenc.ErrInvalidJSON + } +} diff --git a/go/anthropic/jsondec_test.go b/go/anthropic/jsondec_test.go new file mode 100644 index 0000000..ed154ae --- /dev/null +++ b/go/anthropic/jsondec_test.go @@ -0,0 +1,151 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package anthropic + +import ( + "encoding/json" + "reflect" + "testing" +) + +// TestUnmarshalMessageRequest_DirectShapes pins the hand-rolled +// MessageRequest decoder against direct JSON literals. Locks the +// per-field dispatch — present / absent / null variants of every +// pointer field, escape-heavy strings, multi-turn arrays. +func TestUnmarshalMessageRequest_DirectShapes(t *testing.T) { + temp := float32(0.7) + topP := float32(0.95) + topK := 64 + cases := []struct { + name string + in string + want MessageRequest + }{ + { + name: "minimal", + in: `{"model":"claude-3","messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}],"max_tokens":256}`, + want: MessageRequest{ + Model: "claude-3", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "hi"}}}}, + MaxTokens: 256, + }, + }, + { + name: "all-optional-fields-set", + in: `{"model":"claude-3","system":"Be concise.","messages":[{"role":"user","content":[{"type":"text","text":"hi"}]}],"max_tokens":1024,"temperature":0.7,"top_p":0.95,"top_k":64,"stream":true,"stop_sequences":["","<|eot|>"]}`, + want: MessageRequest{ + Model: "claude-3", + System: "Be concise.", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "hi"}}}}, + MaxTokens: 1024, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + Stream: true, + StopSequences: []string{"", "<|eot|>"}, + }, + }, + { + name: "pointer-fields-null-keeps-zero-value", + in: `{"model":"claude-3","messages":[],"max_tokens":256,"temperature":null,"top_p":null,"top_k":null,"stream":null}`, + want: MessageRequest{ + Model: "claude-3", + MaxTokens: 256, + }, + }, + { + name: "stop-sequences-as-single-string", + in: `{"model":"claude-3","messages":[],"max_tokens":256,"stop_sequences":""}`, + want: MessageRequest{ + Model: "claude-3", + MaxTokens: 256, + StopSequences: []string{""}, + }, + }, + { + name: "unknown-fields-ignored", + in: `{"model":"claude-3","messages":[],"max_tokens":256,"future_field":42,"another":"x"}`, + want: MessageRequest{ + Model: "claude-3", + MaxTokens: 256, + }, + }, + { + name: "whitespace-friendly", + in: `{ + "model": "claude-3", + "messages": [ + { "role": "user", "content": [ { "type": "text", "text": "hi" } ] } + ], + "max_tokens": 256 + }`, + want: MessageRequest{ + Model: "claude-3", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "hi"}}}}, + MaxTokens: 256, + }, + }, + { + name: "escape-heavy-text", + in: `{"model":"claude-3","messages":[{"role":"user","content":[{"type":"text","text":"line1\nline2 \"quoted\" \\back"}]}],"max_tokens":256}`, + want: MessageRequest{ + Model: "claude-3", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "line1\nline2 \"quoted\" \\back"}}}}, + MaxTokens: 256, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got MessageRequest + if err := json.Unmarshal([]byte(tc.in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("Unmarshal mismatch\ngot: %+v\nwant: %+v", got, tc.want) + } + }) + } +} + +// TestUnmarshalMessageRequest_InvalidShapes asserts the walker rejects +// malformed bodies cleanly — no panics, just errors. +func TestUnmarshalMessageRequest_InvalidShapes(t *testing.T) { + cases := []string{ + ``, + `{`, + `}`, + `{"model":42}`, + `{"messages":not-an-array}`, + `{"temperature":"hot"}`, + } + for _, in := range cases { + t.Run(in, func(t *testing.T) { + var req MessageRequest + if err := json.Unmarshal([]byte(in), &req); err == nil { + t.Fatalf("Unmarshal(%q) returned nil error", in) + } + }) + } +} + +// TestUnmarshalMessageResponse_DirectShapes pins the response decoder. +func TestUnmarshalMessageResponse_DirectShapes(t *testing.T) { + in := `{"id":"msg_1","type":"message","role":"assistant","model":"claude-3","content":[{"type":"text","text":"hello"}],"stop_reason":"end_turn","usage":{"input_tokens":10,"output_tokens":5}}` + want := MessageResponse{ + ID: "msg_1", + Type: "message", + Role: "assistant", + Model: "claude-3", + Content: []ContentBlock{{Type: "text", Text: "hello"}}, + StopReason: "end_turn", + Usage: Usage{InputTokens: 10, OutputTokens: 5}, + } + var got MessageResponse + if err := json.Unmarshal([]byte(in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %+v\nwant: %+v", got, want) + } +} diff --git a/go/anthropic/jsonenc_test.go b/go/anthropic/jsonenc_test.go new file mode 100644 index 0000000..6a8a96e --- /dev/null +++ b/go/anthropic/jsonenc_test.go @@ -0,0 +1,283 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package anthropic + +import ( + "encoding/json" + "reflect" + "testing" + + "dappco.re/go/inference" +) + +// TestAppendMessageRequest_RoundTrip pins the hand-rolled MessageRequest +// encoder against encoding/json across every wire shape. Proxies and +// SDK clients that consume this body feed it back into the same Go +// type, so the round-trip must be exact. +func TestAppendMessageRequest_RoundTrip(t *testing.T) { + temp := float32(0.7) + topP := float32(0.95) + topK := 64 + cases := []struct { + name string + req MessageRequest + }{ + { + name: "Minimal_RequiredFieldsOnly", + req: MessageRequest{ + Model: "claude-3-5-sonnet", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "hi"}}}}, + MaxTokens: 256, + }, + }, + { + name: "AllOptionalFieldsSet", + req: MessageRequest{ + Model: "claude-3-5-sonnet", + System: "Be concise.", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "hi"}}}}, + MaxTokens: 1024, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + Stream: true, + StopSequences: []string{"", "<|eot_id|>"}, + }, + }, + { + name: "MultiTurn_MixedRoles", + req: MessageRequest{ + Model: "claude-3-5-sonnet", + Messages: []Message{ + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "first"}}}, + {Role: "assistant", Content: []ContentBlock{{Type: "text", Text: "second"}}}, + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "third"}}}, + }, + MaxTokens: 256, + }, + }, + { + name: "EscapeHeavy_System", + req: MessageRequest{ + Model: "claude-3-5-sonnet", + System: "Reply with \"quotes\" and\nnewlines\tand\x01control", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "back\\slash"}}}}, + MaxTokens: 256, + }, + }, + { + name: "EmptyStopSequences_OmittedNotEmptyArray", + req: MessageRequest{ + Model: "claude-3-5-sonnet", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "hi"}}}}, + MaxTokens: 256, + StopSequences: []string{}, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + hand := AppendMessageRequest(make([]byte, 0, MessageRequestSize(tc.req)), tc.req) + + var got MessageRequest + if err := json.Unmarshal(hand, &got); err != nil { + t.Fatalf("json.Unmarshal hand-rolled output failed: %v\nbody: %s", err, hand) + } + ref, err := json.Marshal(tc.req) + if err != nil { + t.Fatalf("json.Marshal reference: %v", err) + } + var want MessageRequest + if err := json.Unmarshal(ref, &want); err != nil { + t.Fatalf("json.Unmarshal stdlib output failed: %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("round-trip mismatch\ngot: %+v\nwant: %+v\nhand: %s\nref: %s", got, want, hand, ref) + } + }) + } +} + +// TestAppendMessageRequest_SizeBoundsFits guards the request-side size +// estimator. Under-sizing forces append to grow the buffer, costing +// the alloc win we built the helper to claim. +func TestAppendMessageRequest_SizeBoundsFits(t *testing.T) { + temp := float32(0.7) + topP := float32(0.95) + topK := 64 + cases := []struct { + name string + req MessageRequest + }{ + {"Minimal", MessageRequest{ + Model: "claude-3-5-sonnet", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "hi"}}}}, + MaxTokens: 256, + }}, + {"FullyPopulated", MessageRequest{ + Model: "claude-3-5-sonnet", + System: "Be concise.", + Messages: []Message{{Role: "user", Content: []ContentBlock{{Type: "text", Text: "the question"}}}}, + MaxTokens: 1024, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + Stream: true, + StopSequences: []string{"", "<|eot_id|>", "STOP"}, + }}, + {"FiveTurnMultiBlock", MessageRequest{ + Model: "claude-3-5-sonnet", + Messages: []Message{ + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "one"}, {Type: "text", Text: "two"}}}, + {Role: "assistant", Content: []ContentBlock{{Type: "text", Text: "three"}}}, + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "four"}}}, + {Role: "assistant", Content: []ContentBlock{{Type: "text", Text: "five"}}}, + {Role: "user", Content: []ContentBlock{{Type: "text", Text: "six"}}}, + }, + MaxTokens: 256, + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + predicted := MessageRequestSize(tc.req) + actual := len(AppendMessageRequest(nil, tc.req)) + if predicted < actual { + t.Fatalf("MessageRequestSize=%d < actual encoded %d — under-sizing forces realloc", predicted, actual) + } + }) + } +} + +// TestAppendMessageResponse_SizeBoundsFits checks the size estimator +// returns >= the actual encoded size across the round-trip cases. +// Pre-sizing is load-bearing — under-sizing forces append to grow +// the slice, which costs one more allocation per call. +func TestAppendMessageResponse_SizeBoundsFits(t *testing.T) { + cases := []struct { + name string + resp MessageResponse + }{ + {"Typical_NewTextResponse", NewTextResponse( + "msg_bench", + "claude-3-5-sonnet", + "The summary is concise and faithful to the original text.", + inference.GenerateMetrics{PromptTokens: 320, GeneratedTokens: 48}, + )}, + {"WithStopReasonAndSequence", MessageResponse{ + ID: "msg_x", + Type: "message", + Role: "assistant", + Model: "claude-3-5-sonnet", + Content: []ContentBlock{{Type: "text", Text: "stopped early"}}, + StopReason: "stop_sequence", + StopSequence: "", + Usage: Usage{InputTokens: 5, OutputTokens: 1}, + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + predicted := MessageResponseSize(tc.resp) + actual := len(AppendMessageResponse(nil, tc.resp)) + if predicted < actual { + t.Fatalf("MessageResponseSize=%d < actual encoded %d — under-sizing forces realloc", predicted, actual) + } + }) + } +} + +// TestAppendMessageResponse_RoundTrip pins the hand-rolled +// MessageResponse encoder against encoding/json across every wire +// shape — the proxy / SDK clients that read this body feed it back +// into the same Go type, so the round-trip must be exact. +func TestAppendMessageResponse_RoundTrip(t *testing.T) { + cases := []struct { + name string + resp MessageResponse + }{ + { + name: "Typical_SingleTextBlock", + resp: MessageResponse{ + ID: "msg_1", + Type: "message", + Role: "assistant", + Model: "claude-3-5-sonnet", + Content: []ContentBlock{{Type: "text", Text: "hello"}}, + Usage: Usage{InputTokens: 5, OutputTokens: 1}, + }, + }, + { + name: "WithStopReason_AndStopSequence", + resp: MessageResponse{ + ID: "msg_2", + Type: "message", + Role: "assistant", + Model: "claude-3-5-sonnet", + Content: []ContentBlock{{Type: "text", Text: "stopped"}}, + StopReason: "stop_sequence", + StopSequence: "", + Usage: Usage{InputTokens: 7, OutputTokens: 2}, + }, + }, + { + name: "EmptyContent", + resp: MessageResponse{ + ID: "msg_3", + Type: "message", + Role: "assistant", + Model: "claude-3-5-sonnet", + Content: []ContentBlock{}, + Usage: Usage{InputTokens: 0, OutputTokens: 0}, + }, + }, + { + name: "MultiBlock_MixedText", + resp: MessageResponse{ + ID: "msg_4", + Type: "message", + Role: "assistant", + Model: "claude-3-5-sonnet", + Content: []ContentBlock{ + {Type: "text", Text: "first"}, + {Type: "text", Text: "second"}, + {Type: "tool_use", Text: ""}, // text omitted when empty + }, + Usage: Usage{InputTokens: 10, OutputTokens: 3}, + }, + }, + { + name: "EscapeHeavy", + resp: MessageResponse{ + ID: `msg "5"`, + Type: "message", + Role: "assistant", + Model: "claude-3-5-sonnet", + Content: []ContentBlock{{Type: "text", Text: "line1\nline2\twith\"quotes\\and\rcontrol\x01char"}}, + Usage: Usage{InputTokens: 8, OutputTokens: 5}, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + hand := AppendMessageResponse(make([]byte, 0, MessageResponseSize(tc.resp)), tc.resp) + + var got MessageResponse + if err := json.Unmarshal(hand, &got); err != nil { + t.Fatalf("json.Unmarshal hand-rolled output failed: %v\nbody: %s", err, hand) + } + // Normalise: empty Content slice unmarshals into nil for + // some shapes; compare via re-marshal-and-decode to a + // reference produced by the stdlib encoder. + ref, err := json.Marshal(tc.resp) + if err != nil { + t.Fatalf("json.Marshal reference: %v", err) + } + var want MessageResponse + if err := json.Unmarshal(ref, &want); err != nil { + t.Fatalf("json.Unmarshal stdlib output failed: %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("round-trip mismatch\ngot: %+v\nwant: %+v\nhand: %s\nref: %s", got, want, hand, ref) + } + }) + } +} diff --git a/go/bench/bench.go b/go/bench/bench.go new file mode 100644 index 0000000..26ba576 --- /dev/null +++ b/go/bench/bench.go @@ -0,0 +1,630 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package bench is the driver-neutral local benchmark/eval harness. +// +// Drivers (go-mlx, go-rocm, go-cuda, …) supply a Runner with +// verb-shaped callbacks for each section of the bench (PromptCache, +// StateKVBlockWarm, KVRestore, StateBundle, SpeculativeDecode, +// PromptLookupDecode, ProbeOverhead). bench.Run orchestrates the +// generation timing + calls each enabled callback + assembles the +// final Report. +package bench + +import ( + "context" + "strconv" + "time" + + core "dappco.re/go" +) + +const ReportVersion = 1 + +// Config controls the local benchmark/eval harness. +type Config struct { + Model string `json:"model,omitempty"` + ModelPath string `json:"model_path,omitempty"` + Prompt string `json:"prompt"` + CachePrompt string `json:"cache_prompt,omitempty"` + MaxTokens int `json:"max_tokens"` + Runs int `json:"runs"` + Temperature float32 `json:"temperature"` + TopK int `json:"top_k,omitempty"` + TopP float32 `json:"top_p,omitempty"` + MinP float32 `json:"min_p,omitempty"` + StopTokens []int32 `json:"stop_tokens,omitempty"` + RepeatPenalty float32 `json:"repeat_penalty,omitempty"` + IncludePromptCache bool `json:"include_prompt_cache"` + IncludeKVRestore bool `json:"include_kv_restore"` + IncludeStateBundleRoundTrip bool `json:"include_state_bundle_round_trip"` + IncludeProbeOverhead bool `json:"include_probe_overhead"` + IncludeStateKVBlockWarm bool `json:"include_state_kv_block_warm"` + // Deprecated: use IncludeStateKVBlockWarm. Kept for old Go callers only. + IncludeMemvidKVBlockWarm bool `json:"-"` + IncludeSpeculativeDecode bool `json:"include_speculative_decode"` + IncludePromptLookupDecode bool `json:"include_prompt_lookup_decode"` + StateKVBlockSize int `json:"state_kv_block_size,omitempty"` + StateKVPrefixTokens int `json:"state_kv_prefix_tokens,omitempty"` + StateKVBlockStorePath string `json:"state_kv_block_store_path,omitempty"` + // Deprecated: use StateKVBlockSize. Kept for old Go callers only. + MemvidKVBlockSize int `json:"-"` + // Deprecated: use StateKVPrefixTokens. Kept for old Go callers only. + MemvidKVPrefixTokens int `json:"-"` + // Deprecated: use StateKVBlockStorePath. Kept for old Go callers only. + MemvidKVBlockStorePath string `json:"-"` + SpeculativeDraftModelPath string `json:"speculative_draft_model_path,omitempty"` + SpeculativeDraftTokens int `json:"speculative_draft_tokens,omitempty"` + PromptLookupTokens []int32 `json:"prompt_lookup_tokens,omitempty"` + QualityPrompts []string `json:"quality_prompts,omitempty"` +} + +// DefaultConfig returns a short local benchmark suite suitable for a laptop. +func DefaultConfig() Config { + return Config{ + Prompt: "Write one precise sentence about local inference.", + MaxTokens: 32, + Runs: 1, + Temperature: 0, + IncludePromptCache: true, + IncludeKVRestore: true, + IncludeStateBundleRoundTrip: true, + IncludeProbeOverhead: true, + } +} + +// Info mirrors a driver's model info — the fields bench consumers care about. +type Info struct { + Architecture string `json:"architecture,omitempty"` + VocabSize int `json:"vocab_size,omitempty"` + NumLayers int `json:"num_layers,omitempty"` + HiddenSize int `json:"hidden_size,omitempty"` + QuantBits int `json:"quant_bits,omitempty"` + QuantGroup int `json:"quant_group,omitempty"` + ContextLength int `json:"context_length,omitempty"` + Adapter AdapterInfo `json:"adapter,omitempty"` +} + +// AdapterInfo identifies a LoRA adapter participating in the bench run. +// Mirrors the shape of go-mlx/lora.AdapterInfo but lives in bench to keep +// the package driver-neutral. +type AdapterInfo struct { + Name string `json:"name,omitempty"` + Path string `json:"path,omitempty"` + Hash string `json:"hash,omitempty"` + Rank int `json:"rank,omitempty"` + Alpha float32 `json:"alpha,omitempty"` + Scale float32 `json:"scale,omitempty"` + TargetKeys []string `json:"target_keys,omitempty"` +} + +// IsEmpty reports whether the adapter info has no meaningful fields set. +func (info AdapterInfo) IsEmpty() bool { + return info.Name == "" && info.Path == "" && info.Hash == "" && info.Rank == 0 && info.Alpha == 0 && info.Scale == 0 && len(info.TargetKeys) == 0 +} + +// GenerateOptions describes one generation request. +type GenerateOptions struct { + MaxTokens int `json:"max_tokens"` + Temperature float32 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float32 `json:"top_p,omitempty"` + MinP float32 `json:"min_p,omitempty"` + StopTokens []int32 `json:"stop_tokens,omitempty"` + RepeatPenalty float32 `json:"repeat_penalty,omitempty"` + // ProbeSink is opaque to bench. Drivers that support probe-recording + // attach the recorder here; the value is passed through to the + // driver's Generate call. + ProbeSink any `json:"-"` +} + +// GenerateOptions returns the per-call generation options derived from +// the Config plus the (optional) probe sink for that call. +func (c Config) GenerateOptions(sink any) GenerateOptions { + return GenerateOptions{ + MaxTokens: c.MaxTokens, + Temperature: c.Temperature, + TopK: c.TopK, + TopP: c.TopP, + MinP: c.MinP, + StopTokens: append([]int32(nil), c.StopTokens...), + RepeatPenalty: c.RepeatPenalty, + ProbeSink: sink, + } +} + +// Generation is one model response plus the driver-reported metrics. +type Generation struct { + Text string `json:"text,omitempty"` + Tokens []int32 `json:"tokens,omitempty"` + Metrics GenerationMetrics `json:"metrics"` +} + +// GenerationMetrics is the bench-readable snapshot of generation timing +// + memory + prompt-cache counters. Drivers populate the fields they can +// report; missing fields are zero. +type GenerationMetrics struct { + PromptTokens int `json:"prompt_tokens"` + GeneratedTokens int `json:"generated_tokens"` + FirstTokenDuration time.Duration `json:"first_token_duration,omitempty"` + PrefillDuration time.Duration `json:"prefill_duration"` + DecodeDuration time.Duration `json:"decode_duration"` + TotalDuration time.Duration `json:"total_duration"` + PrefillTokensPerSec float64 `json:"prefill_tokens_per_sec"` + DecodeTokensPerSec float64 `json:"decode_tokens_per_sec"` + PeakMemoryBytes uint64 `json:"peak_memory_bytes"` + ActiveMemoryBytes uint64 `json:"active_memory_bytes"` + PromptCacheHits int `json:"prompt_cache_hits,omitempty"` + PromptCacheMisses int `json:"prompt_cache_misses,omitempty"` + PromptCacheHitTokens int `json:"prompt_cache_hit_tokens,omitempty"` + PromptCacheMissTokens int `json:"prompt_cache_miss_tokens,omitempty"` + PromptCacheRestoreDuration time.Duration `json:"prompt_cache_restore_duration,omitempty"` +} + +// Runner is the model-side surface bench.Run needs. Generate is required; +// every Bench* callback is optional — if absent, the corresponding +// section of the Report stays Attempted=false. +type Runner struct { + Info func(context.Context) Info + Generate func(context.Context, string, GenerateOptions) (Generation, error) + + BenchPromptCache func(context.Context, Config, GenerationSummary) PromptCacheReport + BenchStateKVBlockWarm func(context.Context, Config, GenerationSummary) StateKVBlockWarmReport + BenchKVRestore func(context.Context, Config) LatencyReport + BenchStateBundle func(context.Context, Config, Info) StateBundleReport + BenchProbeOverhead func(context.Context, Config, time.Duration) ProbeReport + BenchSpeculativeDecode func(context.Context, Config) DecodeOptimisationReport + BenchPromptLookupDecode func(context.Context, Config) DecodeOptimisationReport + + // Deprecated: use BenchStateKVBlockWarm. + BenchMemvidKVBlockWarm func(context.Context, Config, GenerationSummary) MemvidKVBlockWarmReport +} + +// Report is the full benchmark result. +type Report struct { + Version int `json:"version"` + Model string `json:"model,omitempty"` + ModelPath string `json:"model_path,omitempty"` + ModelInfo Info `json:"model_info"` + Config Config `json:"config"` + Generation GenerationSummary `json:"generation"` + PromptCache PromptCacheReport `json:"prompt_cache"` + StateKVBlockWarm StateKVBlockWarmReport `json:"state_kv_block_warm"` + // Deprecated: use StateKVBlockWarm. Kept for old Go callers only. + MemvidKVBlockWarm MemvidKVBlockWarmReport `json:"-"` + KVRestore LatencyReport `json:"kv_restore"` + StateBundle StateBundleReport `json:"state_bundle"` + Probes ProbeReport `json:"probes"` + SpeculativeDecode DecodeOptimisationReport `json:"speculative_decode"` + PromptLookupDecode DecodeOptimisationReport `json:"prompt_lookup_decode"` + Quality QualityReport `json:"quality"` +} + +// GenerationSample stores one measured generation pass. +type GenerationSample struct { + Prompt string `json:"prompt"` + Text string `json:"text,omitempty"` + Tokens []int32 `json:"tokens,omitempty"` + Metrics GenerationMetrics `json:"metrics"` + Elapsed time.Duration `json:"elapsed"` +} + +// GenerationSummary aggregates baseline generation passes. +type GenerationSummary struct { + Runs int `json:"runs"` + PromptTokens int `json:"prompt_tokens"` + GeneratedTokens int `json:"generated_tokens"` + FirstTokenDuration time.Duration `json:"first_token_duration,omitempty"` + PrefillTokensPerSec float64 `json:"prefill_tokens_per_sec"` + DecodeTokensPerSec float64 `json:"decode_tokens_per_sec"` + PrefillDuration time.Duration `json:"prefill_duration"` + DecodeDuration time.Duration `json:"decode_duration"` + TotalDuration time.Duration `json:"total_duration"` + PeakMemoryBytes uint64 `json:"peak_memory_bytes"` + ActiveMemoryBytes uint64 `json:"active_memory_bytes"` + Samples []GenerationSample `json:"samples,omitempty"` +} + +// PromptCacheReport measures warmed prompt-cache reuse. +type PromptCacheReport struct { + Attempted bool `json:"attempted"` + Hits int `json:"hits,omitempty"` + Misses int `json:"misses,omitempty"` + HitRate float64 `json:"hit_rate,omitempty"` + HitTokens int `json:"hit_tokens,omitempty"` + MissTokens int `json:"miss_tokens,omitempty"` + WarmDuration time.Duration `json:"warm_duration,omitempty"` + RestoreDuration time.Duration `json:"restore_duration,omitempty"` + Metrics GenerationMetrics `json:"metrics,omitempty"` + Error string `json:"error,omitempty"` +} + +// StateKVBlockWarmReport measures direct prompt-cache warmup from durable +// State KV blocks (driver-specific feature; mlx provides one, others may not). +type StateKVBlockWarmReport struct { + Attempted bool `json:"attempted"` + Source string `json:"source,omitempty"` + BlockSize int `json:"block_size,omitempty"` + TotalBlocks int `json:"total_blocks,omitempty"` + StorePath string `json:"store_path,omitempty"` + StoreBytes int64 `json:"store_bytes,omitempty"` + BuildDuration time.Duration `json:"build_duration,omitempty"` + BuildTokens int `json:"build_tokens,omitempty"` + BuildTokensPerSec float64 `json:"build_tokens_per_sec,omitempty"` + BlocksRead int `json:"blocks_read,omitempty"` + ChunksRead int `json:"chunks_read,omitempty"` + PrefixTokensRestored int `json:"prefix_tokens_restored,omitempty"` + PromptTokensAvoided int `json:"prompt_tokens_avoided,omitempty"` + ReplayTokens int `json:"replay_tokens,omitempty"` + ExactFallbackReplayTokens int `json:"exact_fallback_replay_tokens,omitempty"` + BaselinePrefillDuration time.Duration `json:"baseline_prefill_duration,omitempty"` + RestoreDuration time.Duration `json:"restore_duration,omitempty"` + GenerateDuration time.Duration `json:"generate_duration,omitempty"` + PrefillSavedPerQuestion time.Duration `json:"prefill_saved_per_question,omitempty"` + BuildAmortizationQuestions int `json:"build_amortization_questions,omitempty"` + BreakEvenQuestions int `json:"break_even_questions,omitempty"` + RestoreSpeedup float64 `json:"restore_speedup,omitempty"` + MemoryPeakBytes uint64 `json:"memory_peak_bytes,omitempty"` + Metrics GenerationMetrics `json:"metrics,omitempty"` + Error string `json:"error,omitempty"` +} + +// MemvidKVBlockWarmReport measures direct prompt-cache warmup from old +// memvid-named KV blocks. +// +// Deprecated: use StateKVBlockWarmReport. +type MemvidKVBlockWarmReport = StateKVBlockWarmReport + +// LatencyReport records a best-effort latency measurement. +type LatencyReport struct { + Attempted bool `json:"attempted"` + Duration time.Duration `json:"duration,omitempty"` + Error string `json:"error,omitempty"` +} + +// StateBundleReport records state-bundle JSON round-trip behavior. +type StateBundleReport struct { + Attempted bool `json:"attempted"` + Duration time.Duration `json:"duration,omitempty"` + Bytes int `json:"bytes,omitempty"` + Error string `json:"error,omitempty"` +} + +// ProbeReport records probe event count and estimated runtime overhead. +// +// Events is opaque (driver-specific probe event vocabulary); KindCounts +// gives bench a portable summary. +type ProbeReport struct { + Attempted bool `json:"attempted"` + EventCount int `json:"event_count,omitempty"` + KindCounts map[string]int `json:"kind_counts,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + OverheadRatio float64 `json:"overhead_ratio,omitempty"` + Metrics GenerationMetrics `json:"metrics,omitempty"` + Error string `json:"error,omitempty"` + Events []any `json:"events,omitempty"` +} + +// DecodeOptimisationReport records an optional decode-optimisation +// comparison against the baseline generation path. +type DecodeOptimisationReport struct { + Attempted bool `json:"attempted"` + Result DecodeOptimisationResult `json:"result,omitempty"` + Metrics DecodeOptimisationMetrics `json:"metrics,omitempty"` + Error string `json:"error,omitempty"` +} + +// DecodeOptimisationResult mirrors the driver's speculative/prompt-lookup +// decode result. Drivers populate the fields their algorithm produces. +type DecodeOptimisationResult struct { + Mode string `json:"mode"` + Prompt string `json:"prompt,omitempty"` + Text string `json:"text,omitempty"` + Tokens []int32 `json:"tokens,omitempty"` + Metrics DecodeOptimisationMetrics `json:"metrics"` +} + +// DecodeOptimisationMetrics summarises candidate acceptance and timing. +type DecodeOptimisationMetrics struct { + TargetTokens int `json:"target_tokens,omitempty"` + DraftTokens int `json:"draft_tokens,omitempty"` + LookupTokens int `json:"lookup_tokens,omitempty"` + AcceptedTokens int `json:"accepted_tokens,omitempty"` + RejectedTokens int `json:"rejected_tokens,omitempty"` + EmittedTokens int `json:"emitted_tokens,omitempty"` + AcceptanceRate float64 `json:"acceptance_rate,omitempty"` + TargetCalls int `json:"target_calls,omitempty"` + DraftCalls int `json:"draft_calls,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + TargetDuration time.Duration `json:"target_duration,omitempty"` + DraftDuration time.Duration `json:"draft_duration,omitempty"` + VisibleTokensPerSec float64 `json:"visible_tokens_per_sec,omitempty"` + TargetTokensPerSec float64 `json:"target_tokens_per_sec,omitempty"` + DraftTokensPerSec float64 `json:"draft_tokens_per_sec,omitempty"` +} + +// QualityReport contains small deterministic checks over generated text. +type QualityReport struct { + Checks []QualityCheck `json:"checks,omitempty"` +} + +// QualityCheck is one pass/fail bench check. +type QualityCheck struct { + Name string `json:"name"` + Pass bool `json:"pass"` + Score float64 `json:"score"` + Detail string `json:"detail,omitempty"` +} + +// Run executes the local bench/eval suite against the supplied runner. +// +// report, err := bench.Run(ctx, runner, cfg) +func Run(ctx context.Context, runner Runner, cfg Config) (*Report, error) { + if ctx == nil { + ctx = context.Background() + } + cfg = normalizeConfig(cfg) + if runner.Generate == nil { + return nil, core.NewError("mlx: bench runner requires Generate") + } + report := &Report{ + Version: ReportVersion, + Model: cfg.Model, + ModelPath: cfg.ModelPath, + Config: cfg, + } + if runner.Info != nil { + report.ModelInfo = runner.Info(ctx) + } + + samples := make([]GenerationSample, 0, cfg.Runs) + for range cfg.Runs { + sample, err := runGeneration(ctx, runner, cfg.Prompt, cfg.GenerateOptions(nil)) + if err != nil { + return nil, err + } + samples = append(samples, sample) + } + report.Generation = summarizeGenerations(samples) + report.Quality.Checks = append(report.Quality.Checks, qualityChecks(samples)...) + + if cfg.IncludePromptCache && runner.BenchPromptCache != nil { + report.PromptCache = runner.BenchPromptCache(ctx, cfg, report.Generation) + } + if cfg.IncludeStateKVBlockWarm && runner.BenchStateKVBlockWarm != nil { + report.StateKVBlockWarm = runner.BenchStateKVBlockWarm(ctx, cfg, report.Generation) + report.MemvidKVBlockWarm = report.StateKVBlockWarm + } else if cfg.IncludeStateKVBlockWarm && runner.BenchMemvidKVBlockWarm != nil { + report.StateKVBlockWarm = runner.BenchMemvidKVBlockWarm(ctx, cfg, report.Generation) + report.MemvidKVBlockWarm = report.StateKVBlockWarm + } + if cfg.IncludeKVRestore && runner.BenchKVRestore != nil { + report.KVRestore = runner.BenchKVRestore(ctx, cfg) + } + if cfg.IncludeStateBundleRoundTrip && runner.BenchStateBundle != nil { + report.StateBundle = runner.BenchStateBundle(ctx, cfg, report.ModelInfo) + } + if cfg.IncludeProbeOverhead && runner.BenchProbeOverhead != nil { + report.Probes = runner.BenchProbeOverhead(ctx, cfg, report.Generation.TotalDuration) + } + if cfg.IncludeSpeculativeDecode && runner.BenchSpeculativeDecode != nil { + report.SpeculativeDecode = runner.BenchSpeculativeDecode(ctx, cfg) + } + if cfg.IncludePromptLookupDecode && runner.BenchPromptLookupDecode != nil { + report.PromptLookupDecode = runner.BenchPromptLookupDecode(ctx, cfg) + } + return report, nil +} + +func normalizeConfig(cfg Config) Config { + def := DefaultConfig() + if configZero(cfg) { + return def + } + if cfg.Prompt == "" { + cfg.Prompt = def.Prompt + } + if cfg.MaxTokens <= 0 { + cfg.MaxTokens = def.MaxTokens + } + if cfg.Runs <= 0 { + cfg.Runs = def.Runs + } + if cfg.CachePrompt == "" { + cfg.CachePrompt = cfg.Prompt + } + if cfg.IncludeMemvidKVBlockWarm { + cfg.IncludeStateKVBlockWarm = true + } + if cfg.MemvidKVBlockSize != 0 && cfg.StateKVBlockSize == 0 { + cfg.StateKVBlockSize = cfg.MemvidKVBlockSize + } + if cfg.MemvidKVPrefixTokens != 0 && cfg.StateKVPrefixTokens == 0 { + cfg.StateKVPrefixTokens = cfg.MemvidKVPrefixTokens + } + if cfg.MemvidKVBlockStorePath != "" && cfg.StateKVBlockStorePath == "" { + cfg.StateKVBlockStorePath = cfg.MemvidKVBlockStorePath + } + cfg.StopTokens = append([]int32(nil), cfg.StopTokens...) + cfg.PromptLookupTokens = append([]int32(nil), cfg.PromptLookupTokens...) + cfg.QualityPrompts = append([]string(nil), cfg.QualityPrompts...) + return cfg +} + +func configZero(cfg Config) bool { + return cfg.Model == "" && + cfg.ModelPath == "" && + cfg.Prompt == "" && + cfg.CachePrompt == "" && + cfg.MaxTokens == 0 && + cfg.Runs == 0 && + cfg.Temperature == 0 && + cfg.TopK == 0 && + cfg.TopP == 0 && + cfg.MinP == 0 && + len(cfg.StopTokens) == 0 && + cfg.RepeatPenalty == 0 && + !cfg.IncludePromptCache && + !cfg.IncludeKVRestore && + !cfg.IncludeStateBundleRoundTrip && + !cfg.IncludeProbeOverhead && + !cfg.IncludeStateKVBlockWarm && + !cfg.IncludeMemvidKVBlockWarm && + !cfg.IncludeSpeculativeDecode && + !cfg.IncludePromptLookupDecode && + cfg.StateKVBlockSize == 0 && + cfg.StateKVPrefixTokens == 0 && + cfg.StateKVBlockStorePath == "" && + cfg.MemvidKVBlockSize == 0 && + cfg.MemvidKVPrefixTokens == 0 && + cfg.MemvidKVBlockStorePath == "" && + cfg.SpeculativeDraftModelPath == "" && + cfg.SpeculativeDraftTokens == 0 && + len(cfg.PromptLookupTokens) == 0 && + len(cfg.QualityPrompts) == 0 +} + +func runGeneration(ctx context.Context, runner Runner, prompt string, opts GenerateOptions) (GenerationSample, error) { + start := time.Now() + generation, err := runner.Generate(ctx, prompt, opts) + elapsed := NonZeroDuration(time.Since(start)) + if err != nil { + return GenerationSample{}, err + } + return GenerationSample{ + Prompt: prompt, + Text: generation.Text, + Tokens: append([]int32(nil), generation.Tokens...), + Metrics: generation.Metrics, + Elapsed: elapsed, + }, nil +} + +func summarizeGenerations(samples []GenerationSample) GenerationSummary { + summary := GenerationSummary{ + Runs: len(samples), + Samples: append([]GenerationSample(nil), samples...), + } + var prefillRateTotal, decodeRateTotal float64 + firstTokenSamples := 0 + for _, sample := range samples { + metrics := sample.Metrics + summary.PromptTokens += metrics.PromptTokens + summary.GeneratedTokens += metrics.GeneratedTokens + if metrics.FirstTokenDuration > 0 { + firstTokenSamples++ + summary.FirstTokenDuration += metrics.FirstTokenDuration + } + summary.PrefillDuration += metrics.PrefillDuration + summary.DecodeDuration += metrics.DecodeDuration + if metrics.TotalDuration > 0 { + summary.TotalDuration += metrics.TotalDuration + } else { + summary.TotalDuration += sample.Elapsed + } + prefillRateTotal += metrics.PrefillTokensPerSec + decodeRateTotal += metrics.DecodeTokensPerSec + if metrics.PeakMemoryBytes > summary.PeakMemoryBytes { + summary.PeakMemoryBytes = metrics.PeakMemoryBytes + } + if metrics.ActiveMemoryBytes > summary.ActiveMemoryBytes { + summary.ActiveMemoryBytes = metrics.ActiveMemoryBytes + } + } + if len(samples) > 0 { + summary.PrefillTokensPerSec = prefillRateTotal / float64(len(samples)) + summary.DecodeTokensPerSec = decodeRateTotal / float64(len(samples)) + } + if firstTokenSamples > 0 { + summary.FirstTokenDuration /= time.Duration(firstTokenSamples) + } + return summary +} + +func qualityChecks(samples []GenerationSample) []QualityCheck { + // Pre-sized for the two fixed checks; strconv.Itoa skips the fmt + // formatter pipeline that Sprintf would walk. + checks := make([]QualityCheck, 0, 2) + nonEmpty := false + generatedTokens := 0 + for _, sample := range samples { + if sample.Text != "" { + nonEmpty = true + } + generatedTokens += sample.Metrics.GeneratedTokens + } + checks = append(checks, QualityCheck{ + Name: "non_empty_output", + Pass: nonEmpty, + Score: boolScore(nonEmpty), + }) + checks = append(checks, QualityCheck{ + Name: "generated_tokens", + Pass: generatedTokens > 0, + Score: boolScore(generatedTokens > 0), + Detail: strconv.Itoa(generatedTokens), + }) + return checks +} + +// PopulateStateKVBlockWarmBench fills in the cross-cutting derived +// fields (Speedup, BreakEvenQuestions, ...) on a StateKVBlockWarmReport +// once the driver-side capture/restore measurements are populated. +// +// report := runner.BenchStateKVBlockWarm(ctx, cfg, baseline) +// bench.PopulateStateKVBlockWarmBench(&report, baseline) +func PopulateStateKVBlockWarmBench(report *StateKVBlockWarmReport, baseline GenerationSummary) { + if report == nil || !report.Attempted { + return + } + report.BaselinePrefillDuration = baseline.PrefillDuration + report.MemoryPeakBytes = maxUint64(baseline.PeakMemoryBytes, maxUint64(report.Metrics.PeakMemoryBytes, report.Metrics.ActiveMemoryBytes)) + if baseline.PrefillDuration > 0 && report.RestoreDuration > 0 { + report.RestoreSpeedup = float64(baseline.PrefillDuration) / float64(report.RestoreDuration) + } + saved := baseline.PrefillDuration - report.RestoreDuration + if saved <= 0 || report.BuildDuration <= 0 { + return + } + report.PrefillSavedPerQuestion = saved + questions := ceilDuration(report.BuildDuration, saved) + report.BuildAmortizationQuestions = questions + report.BreakEvenQuestions = questions +} + +// PopulateMemvidKVBlockWarmBench fills derived values for the old memvid-named +// State block warm report. +// +// Deprecated: use PopulateStateKVBlockWarmBench. +func PopulateMemvidKVBlockWarmBench(report *MemvidKVBlockWarmReport, baseline GenerationSummary) { + PopulateStateKVBlockWarmBench(report, baseline) +} + +func ceilDuration(value, divisor time.Duration) int { + if value <= 0 || divisor <= 0 { + return 0 + } + return int((value + divisor - 1) / divisor) +} + +func maxUint64(a, b uint64) uint64 { + if a > b { + return a + } + return b +} + +func boolScore(pass bool) float64 { + if pass { + return 1 + } + return 0 +} + +// NonZeroDuration returns d if positive, else 1 nanosecond. Exported for +// drivers that want consistent non-zero durations in their bench reports. +func NonZeroDuration(d time.Duration) time.Duration { + if d <= 0 { + return time.Nanosecond + } + return d +} diff --git a/go/bench/bench_bench_test.go b/go/bench/bench_bench_test.go new file mode 100644 index 0000000..6ce8fb0 --- /dev/null +++ b/go/bench/bench_bench_test.go @@ -0,0 +1,314 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the driver-neutral local bench harness — Config +// normalisation, Run orchestration over a synthetic Runner, the +// generation-summary reducer, and the derived-field populator. +// +// Per AX-11 — Run is called once per bench invocation but +// summarizeGenerations + qualityChecks fire over every captured +// sample, and PopulateStateKVBlockWarmBench is called once per +// State-block bench from every driver. The Config copy in +// normalizeConfig touches three slice copies per call. +// +// Run: go test -bench='BenchmarkBench' -benchmem -run='^$' ./go/bench + +package bench + +import ( + "context" + "testing" + "time" +) + +// Sinks defeat compiler DCE. +var ( + benchSinkReport *Report + benchSinkErr error + benchSinkConfig Config + benchSinkSummary GenerationSummary + benchSinkChecks []QualityCheck + benchSinkOpts GenerateOptions + benchSinkBool bool + benchSinkDur time.Duration +) + +// buildBenchSamples mints n GenerationSample records with representative +// timing + token counts — same shape Run captures from a real driver. +func buildBenchSamples(n int) []GenerationSample { + samples := make([]GenerationSample, n) + for i := 0; i < n; i++ { + samples[i] = GenerationSample{ + Prompt: "Write one precise sentence about local inference.", + Text: "Local inference keeps tokens on-device.", + Tokens: []int32{1, 2, 3, 4, 5, 6, 7, 8}, + Metrics: GenerationMetrics{ + PromptTokens: 12, + GeneratedTokens: 32, + FirstTokenDuration: 3 * time.Millisecond, + PrefillDuration: 5 * time.Millisecond, + DecodeDuration: 40 * time.Millisecond, + TotalDuration: 45 * time.Millisecond, + PrefillTokensPerSec: 2400, + DecodeTokensPerSec: 800, + PeakMemoryBytes: uint64(64 << 20), + ActiveMemoryBytes: uint64(48 << 20), + }, + Elapsed: 45 * time.Millisecond, + } + } + return samples +} + +// benchRunner returns a Runner whose Generate emits a fixed scripted +// generation. Used by BenchmarkBench_Run_* below. +func benchRunner(metrics GenerationMetrics) Runner { + return Runner{ + Generate: func(_ context.Context, prompt string, _ GenerateOptions) (Generation, error) { + return Generation{ + Text: "Local inference keeps tokens on-device.", + Tokens: []int32{1, 2, 3, 4, 5, 6, 7, 8}, + Metrics: metrics, + }, nil + }, + } +} + +// --- Run end-to-end with minimal config + scripted generation --- + +func BenchmarkBench_Run_Minimal(b *testing.B) { + cfg := Config{ + Prompt: "Write one precise sentence about local inference.", + MaxTokens: 32, + Runs: 1, + } + runner := benchRunner(GenerationMetrics{ + PromptTokens: 12, GeneratedTokens: 32, + PrefillDuration: 5 * time.Millisecond, DecodeDuration: 40 * time.Millisecond, + TotalDuration: 45 * time.Millisecond, + }) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkReport, benchSinkErr = Run(ctx, runner, cfg) + } +} + +// 10 runs exercises the summariser inside Run on a bigger sample set. +func BenchmarkBench_Run_TenRuns(b *testing.B) { + cfg := Config{ + Prompt: "Write one precise sentence about local inference.", + MaxTokens: 32, + Runs: 10, + } + runner := benchRunner(GenerationMetrics{ + PromptTokens: 12, GeneratedTokens: 32, + PrefillDuration: 5 * time.Millisecond, DecodeDuration: 40 * time.Millisecond, + TotalDuration: 45 * time.Millisecond, + }) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkReport, benchSinkErr = Run(ctx, runner, cfg) + } +} + +// --- DefaultConfig + normalisation hot loop --- + +func BenchmarkBench_DefaultConfig(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkConfig = DefaultConfig() + } +} + +func BenchmarkBench_NormalizeConfig_Zero(b *testing.B) { + cfg := Config{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkConfig = normalizeConfig(cfg) + } +} + +func BenchmarkBench_NormalizeConfig_PopulatedMinimal(b *testing.B) { + cfg := Config{ + Prompt: "Write one precise sentence about local inference.", + MaxTokens: 32, + Runs: 1, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkConfig = normalizeConfig(cfg) + } +} + +// PopulatedFull exercises every slice-copy + deprecated-field migration +// branch in normalizeConfig. +func BenchmarkBench_NormalizeConfig_PopulatedFull(b *testing.B) { + cfg := Config{ + Model: "qwen3", + ModelPath: "/models/qwen3.gguf", + Prompt: "Write one precise sentence about local inference.", + CachePrompt: "Write one precise sentence about local inference.", + MaxTokens: 64, + Runs: 4, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + MinP: 0.05, + StopTokens: []int32{0, 1, 2, 3, 4, 5, 6, 7}, + RepeatPenalty: 1.1, + IncludePromptCache: true, + IncludeKVRestore: true, + IncludeStateBundleRoundTrip: true, + IncludeProbeOverhead: true, + IncludeMemvidKVBlockWarm: true, + MemvidKVBlockSize: 512, + MemvidKVPrefixTokens: 2048, + MemvidKVBlockStorePath: "/cache/state", + SpeculativeDraftModelPath: "/models/draft.gguf", + SpeculativeDraftTokens: 8, + PromptLookupTokens: []int32{10, 20, 30, 40, 50}, + QualityPrompts: []string{"a", "b", "c"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkConfig = normalizeConfig(cfg) + } +} + +// --- GenerateOptions derivation (per-call hot path) --- + +func BenchmarkBench_Config_GenerateOptions_Bare(b *testing.B) { + cfg := DefaultConfig() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkOpts = cfg.GenerateOptions(nil) + } +} + +func BenchmarkBench_Config_GenerateOptions_WithStopTokens(b *testing.B) { + cfg := DefaultConfig() + cfg.StopTokens = []int32{0, 1, 2, 3, 4, 5, 6, 7} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkOpts = cfg.GenerateOptions(nil) + } +} + +// --- summarizeGenerations + qualityChecks (called once per Run) --- + +func BenchmarkBench_SummarizeGenerations_1Sample(b *testing.B) { + samples := buildBenchSamples(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkSummary = summarizeGenerations(samples) + } +} + +func BenchmarkBench_SummarizeGenerations_10Samples(b *testing.B) { + samples := buildBenchSamples(10) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkSummary = summarizeGenerations(samples) + } +} + +func BenchmarkBench_SummarizeGenerations_100Samples(b *testing.B) { + samples := buildBenchSamples(100) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkSummary = summarizeGenerations(samples) + } +} + +func BenchmarkBench_QualityChecks_10Samples(b *testing.B) { + samples := buildBenchSamples(10) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkChecks = qualityChecks(samples) + } +} + +// --- AdapterInfo.IsEmpty (per-report check, fires from drivers) --- + +func BenchmarkBench_AdapterInfo_IsEmpty_Empty(b *testing.B) { + info := AdapterInfo{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkBool = info.IsEmpty() + } +} + +func BenchmarkBench_AdapterInfo_IsEmpty_Populated(b *testing.B) { + info := AdapterInfo{ + Name: "qwen3-lora", + Path: "/adapters/qwen3.lora", + Hash: "sha256:deadbeef", + Rank: 16, + Alpha: 32, + Scale: 0.5, + TargetKeys: []string{"q_proj", "k_proj", "v_proj", "o_proj"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkBool = info.IsEmpty() + } +} + +// --- PopulateStateKVBlockWarmBench (fires once per State-block bench +// from every driver) --- + +func BenchmarkBench_PopulateStateKVBlockWarm(b *testing.B) { + baseline := GenerationSummary{ + PrefillDuration: 200 * time.Millisecond, + PeakMemoryBytes: uint64(96 << 20), + } + report := StateKVBlockWarmReport{ + Attempted: true, + BuildDuration: 400 * time.Millisecond, + RestoreDuration: 8 * time.Millisecond, + Metrics: GenerationMetrics{ + PeakMemoryBytes: uint64(120 << 20), + ActiveMemoryBytes: uint64(64 << 20), + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + r := report + PopulateStateKVBlockWarmBench(&r, baseline) + } +} + +// --- NonZeroDuration (exported helper, fires per Run sample) --- + +func BenchmarkBench_NonZeroDuration_Positive(b *testing.B) { + d := 45 * time.Millisecond + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkDur = NonZeroDuration(d) + } +} + +func BenchmarkBench_NonZeroDuration_Zero(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchSinkDur = NonZeroDuration(0) + } +} diff --git a/go/bench/bench_test.go b/go/bench/bench_test.go new file mode 100644 index 0000000..487c40e --- /dev/null +++ b/go/bench/bench_test.go @@ -0,0 +1,507 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package bench + +import ( + "context" + "errors" + "testing" + "time" +) + +// fakeRunnerOptions describes the synthetic generation result the test +// runner will return on each Generate call. +type fakeRunnerOptions struct { + generationMetrics []GenerationMetrics + generationText []string + generationError error +} + +// newFakeRunner returns a Runner whose Generate emits scripted results. +// Callbacks other than Generate are filled with nil-stubs the caller can +// override. +func newFakeRunner(opts fakeRunnerOptions) (Runner, *int) { + idx := new(int) + runner := Runner{ + Generate: func(_ context.Context, _ string, _ GenerateOptions) (Generation, error) { + if opts.generationError != nil { + return Generation{}, opts.generationError + } + i := *idx + *idx++ + text := "" + if i < len(opts.generationText) { + text = opts.generationText[i] + } + var metrics GenerationMetrics + if i < len(opts.generationMetrics) { + metrics = opts.generationMetrics[i] + } + return Generation{Text: text, Metrics: metrics}, nil + }, + } + return runner, idx +} + +func TestRun_AggregatesGenerationSummary_Good(t *testing.T) { + runner, _ := newFakeRunner(fakeRunnerOptions{ + generationText: []string{"alpha", "beta"}, + generationMetrics: []GenerationMetrics{ + { + PromptTokens: 4, + GeneratedTokens: 6, + FirstTokenDuration: 12 * time.Millisecond, + PrefillDuration: 20 * time.Millisecond, + DecodeDuration: 30 * time.Millisecond, + TotalDuration: 50 * time.Millisecond, + PrefillTokensPerSec: 200, + DecodeTokensPerSec: 60, + PeakMemoryBytes: 1 << 20, + ActiveMemoryBytes: 512 << 10, + }, + { + PromptTokens: 4, + GeneratedTokens: 8, + FirstTokenDuration: 18 * time.Millisecond, + PrefillDuration: 20 * time.Millisecond, + DecodeDuration: 40 * time.Millisecond, + TotalDuration: 60 * time.Millisecond, + PrefillTokensPerSec: 400, + DecodeTokensPerSec: 80, + PeakMemoryBytes: 2 << 20, + ActiveMemoryBytes: 1 << 20, + }, + }, + }) + + report, err := Run(context.Background(), runner, Config{Prompt: "p", MaxTokens: 16, Runs: 2}) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if report.Version != ReportVersion { + t.Fatalf("Version = %d, want %d", report.Version, ReportVersion) + } + summary := report.Generation + if summary.Runs != 2 { + t.Fatalf("Runs = %d, want 2", summary.Runs) + } + if summary.PromptTokens != 8 || summary.GeneratedTokens != 14 { + t.Fatalf("tokens = prompt:%d generated:%d", summary.PromptTokens, summary.GeneratedTokens) + } + if summary.PrefillTokensPerSec != 300 || summary.DecodeTokensPerSec != 70 { + t.Fatalf("rates = prefill:%v decode:%v, want averages 300/70", + summary.PrefillTokensPerSec, summary.DecodeTokensPerSec) + } + if summary.PeakMemoryBytes != 2<<20 || summary.ActiveMemoryBytes != 1<<20 { + t.Fatalf("memory = peak:%d active:%d", summary.PeakMemoryBytes, summary.ActiveMemoryBytes) + } + if summary.PrefillDuration != 40*time.Millisecond || summary.DecodeDuration != 70*time.Millisecond { + t.Fatalf("durations = prefill:%v decode:%v", summary.PrefillDuration, summary.DecodeDuration) + } + if summary.TotalDuration != 110*time.Millisecond { + t.Fatalf("total duration = %v, want 110ms", summary.TotalDuration) + } + if summary.FirstTokenDuration != 15*time.Millisecond { + t.Fatalf("first token duration = %v, want 15ms average", summary.FirstTokenDuration) + } + if len(summary.Samples) != 2 || summary.Samples[0].Text != "alpha" || summary.Samples[1].Text != "beta" { + t.Fatalf("samples = %+v", summary.Samples) + } +} + +func TestRun_FallsBackToElapsedWhenTotalDurationZero_Good(t *testing.T) { + runner, _ := newFakeRunner(fakeRunnerOptions{ + generationText: []string{"hi"}, + generationMetrics: []GenerationMetrics{{PromptTokens: 1, GeneratedTokens: 1}}, + }) + report, err := Run(context.Background(), runner, Config{Prompt: "p", MaxTokens: 4, Runs: 1}) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if report.Generation.TotalDuration <= 0 { + t.Fatalf("TotalDuration = %v, want positive fallback from elapsed", report.Generation.TotalDuration) + } +} + +func TestRun_RequiresGenerate_Bad(t *testing.T) { + if _, err := Run(context.Background(), Runner{}, Config{Prompt: "p", MaxTokens: 4, Runs: 1}); err == nil { + t.Fatal("Run() without Generate did not error") + } +} + +func TestRun_PropagatesGenerateError_Bad(t *testing.T) { + want := errors.New("boom") + runner, _ := newFakeRunner(fakeRunnerOptions{generationError: want}) + if _, err := Run(context.Background(), runner, Config{Prompt: "p", MaxTokens: 4, Runs: 1}); err == nil { + t.Fatal("Run() did not propagate Generate error") + } +} + +func TestRun_NilContextDefaultsToBackground_Good(t *testing.T) { + runner, _ := newFakeRunner(fakeRunnerOptions{ + generationText: []string{"ok"}, + generationMetrics: []GenerationMetrics{{GeneratedTokens: 1}}, + }) + report, err := Run(nil, runner, Config{Prompt: "p", MaxTokens: 4, Runs: 1}) + if err != nil { + t.Fatalf("Run(nil ctx) error = %v", err) + } + if report == nil { + t.Fatal("Run(nil ctx) report = nil") + } +} + +func TestRun_PopulatesModelInfoFromCallback_Good(t *testing.T) { + runner, _ := newFakeRunner(fakeRunnerOptions{ + generationText: []string{"ok"}, + generationMetrics: []GenerationMetrics{{GeneratedTokens: 1}}, + }) + runner.Info = func(context.Context) Info { + return Info{Architecture: "qwen3", NumLayers: 28, ContextLength: 32768} + } + report, err := Run(context.Background(), runner, Config{Prompt: "p", MaxTokens: 4, Runs: 1}) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if report.ModelInfo.Architecture != "qwen3" || report.ModelInfo.NumLayers != 28 || report.ModelInfo.ContextLength != 32768 { + t.Fatalf("ModelInfo = %+v", report.ModelInfo) + } +} + +func TestRun_DispatchesVerbCallbacksWhenIncludeFlagsSet_Good(t *testing.T) { + runner, _ := newFakeRunner(fakeRunnerOptions{ + generationText: []string{"ok"}, + generationMetrics: []GenerationMetrics{{GeneratedTokens: 1, TotalDuration: 5 * time.Millisecond}}, + }) + called := struct { + pc, stateKV, restore, bundle, probe, spec, lookup bool + }{} + runner.BenchPromptCache = func(context.Context, Config, GenerationSummary) PromptCacheReport { + called.pc = true + return PromptCacheReport{Attempted: true, HitRate: 1} + } + runner.BenchStateKVBlockWarm = func(context.Context, Config, GenerationSummary) StateKVBlockWarmReport { + called.stateKV = true + return StateKVBlockWarmReport{Attempted: true, BlockSize: 128} + } + runner.BenchKVRestore = func(context.Context, Config) LatencyReport { + called.restore = true + return LatencyReport{Attempted: true, Duration: time.Millisecond} + } + runner.BenchStateBundle = func(context.Context, Config, Info) StateBundleReport { + called.bundle = true + return StateBundleReport{Attempted: true, Bytes: 42} + } + runner.BenchProbeOverhead = func(context.Context, Config, time.Duration) ProbeReport { + called.probe = true + return ProbeReport{Attempted: true, EventCount: 3} + } + runner.BenchSpeculativeDecode = func(context.Context, Config) DecodeOptimisationReport { + called.spec = true + return DecodeOptimisationReport{Attempted: true, Result: DecodeOptimisationResult{Mode: "speculative"}} + } + runner.BenchPromptLookupDecode = func(context.Context, Config) DecodeOptimisationReport { + called.lookup = true + return DecodeOptimisationReport{Attempted: true, Result: DecodeOptimisationResult{Mode: "prompt_lookup"}} + } + + cfg := Config{ + Prompt: "p", + MaxTokens: 4, + Runs: 1, + IncludePromptCache: true, + IncludeStateKVBlockWarm: true, + IncludeKVRestore: true, + IncludeStateBundleRoundTrip: true, + IncludeProbeOverhead: true, + IncludeSpeculativeDecode: true, + IncludePromptLookupDecode: true, + } + report, err := Run(context.Background(), runner, cfg) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if !called.pc || !called.stateKV || !called.restore || !called.bundle || !called.probe || !called.spec || !called.lookup { + t.Fatalf("verb callbacks not all called: %+v", called) + } + if !report.PromptCache.Attempted || report.PromptCache.HitRate != 1 { + t.Fatalf("PromptCache = %+v", report.PromptCache) + } + if !report.StateKVBlockWarm.Attempted || report.StateKVBlockWarm.BlockSize != 128 { + t.Fatalf("StateKVBlockWarm = %+v", report.StateKVBlockWarm) + } + if !report.MemvidKVBlockWarm.Attempted || report.MemvidKVBlockWarm.BlockSize != 128 { + t.Fatalf("deprecated MemvidKVBlockWarm alias = %+v", report.MemvidKVBlockWarm) + } + if !report.KVRestore.Attempted || report.KVRestore.Duration != time.Millisecond { + t.Fatalf("KVRestore = %+v", report.KVRestore) + } + if !report.StateBundle.Attempted || report.StateBundle.Bytes != 42 { + t.Fatalf("StateBundle = %+v", report.StateBundle) + } + if !report.Probes.Attempted || report.Probes.EventCount != 3 { + t.Fatalf("Probes = %+v", report.Probes) + } + if !report.SpeculativeDecode.Attempted || report.SpeculativeDecode.Result.Mode != "speculative" { + t.Fatalf("SpeculativeDecode = %+v", report.SpeculativeDecode) + } + if !report.PromptLookupDecode.Attempted || report.PromptLookupDecode.Result.Mode != "prompt_lookup" { + t.Fatalf("PromptLookupDecode = %+v", report.PromptLookupDecode) + } +} + +func TestRun_SkipsVerbCallbacksWhenIncludeFlagsFalse_Good(t *testing.T) { + runner, _ := newFakeRunner(fakeRunnerOptions{ + generationText: []string{"ok"}, + generationMetrics: []GenerationMetrics{{GeneratedTokens: 1}}, + }) + // Set every callback to a fatal-on-call closure: if Run incorrectly + // dispatches it, the test fails. + runner.BenchPromptCache = func(context.Context, Config, GenerationSummary) PromptCacheReport { + t.Fatal("BenchPromptCache called when IncludePromptCache is false") + return PromptCacheReport{} + } + runner.BenchStateKVBlockWarm = func(context.Context, Config, GenerationSummary) StateKVBlockWarmReport { + t.Fatal("BenchStateKVBlockWarm called when IncludeStateKVBlockWarm is false") + return StateKVBlockWarmReport{} + } + runner.BenchKVRestore = func(context.Context, Config) LatencyReport { + t.Fatal("BenchKVRestore called when IncludeKVRestore is false") + return LatencyReport{} + } + runner.BenchStateBundle = func(context.Context, Config, Info) StateBundleReport { + t.Fatal("BenchStateBundle called when IncludeStateBundleRoundTrip is false") + return StateBundleReport{} + } + runner.BenchProbeOverhead = func(context.Context, Config, time.Duration) ProbeReport { + t.Fatal("BenchProbeOverhead called when IncludeProbeOverhead is false") + return ProbeReport{} + } + runner.BenchSpeculativeDecode = func(context.Context, Config) DecodeOptimisationReport { + t.Fatal("BenchSpeculativeDecode called when IncludeSpeculativeDecode is false") + return DecodeOptimisationReport{} + } + runner.BenchPromptLookupDecode = func(context.Context, Config) DecodeOptimisationReport { + t.Fatal("BenchPromptLookupDecode called when IncludePromptLookupDecode is false") + return DecodeOptimisationReport{} + } + + cfg := Config{Prompt: "p", MaxTokens: 4, Runs: 1} + if _, err := Run(context.Background(), runner, cfg); err != nil { + t.Fatalf("Run() error = %v", err) + } +} + +func TestRun_QualityChecks_Good(t *testing.T) { + runner, _ := newFakeRunner(fakeRunnerOptions{ + generationText: []string{"hello"}, + generationMetrics: []GenerationMetrics{{ + GeneratedTokens: 5, + TotalDuration: 10 * time.Millisecond, + }}, + }) + report, err := Run(context.Background(), runner, Config{Prompt: "p", MaxTokens: 8, Runs: 1}) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + if len(report.Quality.Checks) != 2 { + t.Fatalf("Quality.Checks = %d, want 2 default checks", len(report.Quality.Checks)) + } + for _, check := range report.Quality.Checks { + switch check.Name { + case "non_empty_output": + if !check.Pass { + t.Fatalf("non_empty_output check failed: %+v", check) + } + case "generated_tokens": + if !check.Pass || check.Detail != "5" { + t.Fatalf("generated_tokens check = %+v", check) + } + default: + t.Fatalf("unexpected check %q", check.Name) + } + } +} + +func TestRun_QualityChecksFlagEmptyOutput_Ugly(t *testing.T) { + runner, _ := newFakeRunner(fakeRunnerOptions{ + generationText: []string{""}, + generationMetrics: []GenerationMetrics{{}}, + }) + report, err := Run(context.Background(), runner, Config{Prompt: "p", MaxTokens: 4, Runs: 1}) + if err != nil { + t.Fatalf("Run() error = %v", err) + } + for _, check := range report.Quality.Checks { + if check.Pass { + t.Fatalf("expected quality check %q to fail for empty output, got %+v", check.Name, check) + } + } +} + +func TestDefaultConfig_Good(t *testing.T) { + cfg := DefaultConfig() + if cfg.MaxTokens != 32 || cfg.Runs != 1 { + t.Fatalf("DefaultConfig() = %+v, want MaxTokens=32 Runs=1", cfg) + } + if !cfg.IncludePromptCache || !cfg.IncludeKVRestore || !cfg.IncludeStateBundleRoundTrip || !cfg.IncludeProbeOverhead { + t.Fatalf("DefaultConfig() includes = %+v, want baseline four-section coverage", cfg) + } + if cfg.Prompt == "" { + t.Fatal("DefaultConfig() Prompt is empty") + } +} + +func TestNormalizeConfig_FillsDefaultsFromZero_Good(t *testing.T) { + got := normalizeConfig(Config{}) + want := DefaultConfig() + if got.MaxTokens != want.MaxTokens || got.Runs != want.Runs || got.Prompt != want.Prompt { + t.Fatalf("normalizeConfig(zero) = %+v, want defaults %+v", got, want) + } +} + +func TestNormalizeConfig_PreservesPartialConfig_Good(t *testing.T) { + got := normalizeConfig(Config{Prompt: "x", MaxTokens: 7}) + if got.Prompt != "x" || got.MaxTokens != 7 || got.Runs != 1 { + t.Fatalf("normalizeConfig(partial) = %+v", got) + } + if got.CachePrompt != "x" { + t.Fatalf("CachePrompt = %q, want fallback to Prompt", got.CachePrompt) + } +} + +func TestNormalizeConfig_ClonesSlices_Good(t *testing.T) { + stops := []int32{1, 2, 3} + lookup := []int32{4, 5} + quality := []string{"a"} + cfg := normalizeConfig(Config{Prompt: "x", MaxTokens: 4, Runs: 1, StopTokens: stops, PromptLookupTokens: lookup, QualityPrompts: quality}) + stops[0] = 99 + lookup[0] = 99 + quality[0] = "z" + if cfg.StopTokens[0] == 99 || cfg.PromptLookupTokens[0] == 99 || cfg.QualityPrompts[0] == "z" { + t.Fatalf("normalizeConfig did not clone slices: %+v", cfg) + } +} + +func TestPopulateStateKVBlockWarmBench_DerivesSpeedupAndBreakEven_Good(t *testing.T) { + report := StateKVBlockWarmReport{ + Attempted: true, + BuildDuration: 100 * time.Millisecond, + RestoreDuration: 10 * time.Millisecond, + Metrics: GenerationMetrics{PeakMemoryBytes: 1 << 20}, + } + baseline := GenerationSummary{ + PrefillDuration: 50 * time.Millisecond, + PeakMemoryBytes: 2 << 20, + } + PopulateStateKVBlockWarmBench(&report, baseline) + if report.BaselinePrefillDuration != 50*time.Millisecond { + t.Fatalf("BaselinePrefillDuration = %v", report.BaselinePrefillDuration) + } + if report.RestoreSpeedup != 5 { + t.Fatalf("RestoreSpeedup = %v, want 5", report.RestoreSpeedup) + } + if report.PrefillSavedPerQuestion != 40*time.Millisecond { + t.Fatalf("PrefillSavedPerQuestion = %v, want 40ms", report.PrefillSavedPerQuestion) + } + if report.BreakEvenQuestions != 3 { + t.Fatalf("BreakEvenQuestions = %d, want 3 (ceil(100ms/40ms))", report.BreakEvenQuestions) + } + if report.MemoryPeakBytes != 2<<20 { + t.Fatalf("MemoryPeakBytes = %d, want baseline peak 2MiB", report.MemoryPeakBytes) + } +} + +func TestPopulateStateKVBlockWarmBench_SkipsWhenNotAttempted_Ugly(t *testing.T) { + report := StateKVBlockWarmReport{ + BuildDuration: 100 * time.Millisecond, + RestoreDuration: 10 * time.Millisecond, + } + PopulateStateKVBlockWarmBench(&report, GenerationSummary{PrefillDuration: 50 * time.Millisecond}) + if report.BaselinePrefillDuration != 0 || report.RestoreSpeedup != 0 || report.BreakEvenQuestions != 0 { + t.Fatalf("expected no-op when Attempted is false, got %+v", report) + } +} + +func TestPopulateStateKVBlockWarmBench_SkipsWhenSavedNonPositive_Ugly(t *testing.T) { + // Restore took LONGER than baseline prefill — no speedup, no break-even. + report := StateKVBlockWarmReport{ + Attempted: true, + BuildDuration: 100 * time.Millisecond, + RestoreDuration: 80 * time.Millisecond, + } + PopulateStateKVBlockWarmBench(&report, GenerationSummary{PrefillDuration: 50 * time.Millisecond}) + if report.PrefillSavedPerQuestion != 0 || report.BreakEvenQuestions != 0 { + t.Fatalf("expected no break-even when restore is slower than baseline, got saved:%v break-even:%d", report.PrefillSavedPerQuestion, report.BreakEvenQuestions) + } + if report.RestoreSpeedup == 0 { + t.Fatalf("RestoreSpeedup should still be derived even when slower, got %v", report.RestoreSpeedup) + } +} + +func TestAdapterInfo_IsEmpty_GoodBad(t *testing.T) { + if !(AdapterInfo{}).IsEmpty() { + t.Fatal("zero AdapterInfo IsEmpty = false, want true") + } + if (AdapterInfo{Name: "x"}).IsEmpty() { + t.Fatal("AdapterInfo with Name IsEmpty = true, want false") + } + if (AdapterInfo{Rank: 8}).IsEmpty() { + t.Fatal("AdapterInfo with Rank IsEmpty = true, want false") + } + if (AdapterInfo{TargetKeys: []string{"q_proj"}}).IsEmpty() { + t.Fatal("AdapterInfo with TargetKeys IsEmpty = true, want false") + } +} + +func TestConfigGenerateOptions_PassesProbeSinkThrough_Good(t *testing.T) { + sentinel := struct{ tag string }{tag: "sink"} + cfg := Config{MaxTokens: 16, Temperature: 0.7, StopTokens: []int32{1}} + opts := cfg.GenerateOptions(sentinel) + if opts.MaxTokens != 16 || opts.Temperature != 0.7 || len(opts.StopTokens) != 1 { + t.Fatalf("GenerateOptions = %+v", opts) + } + got, ok := opts.ProbeSink.(struct{ tag string }) + if !ok || got.tag != "sink" { + t.Fatalf("ProbeSink = %+v ok=%v, want sentinel passed through", opts.ProbeSink, ok) + } +} + +func TestConfigGenerateOptions_ClonesStopTokens_Good(t *testing.T) { + stops := []int32{1, 2, 3} + cfg := Config{MaxTokens: 1, StopTokens: stops} + opts := cfg.GenerateOptions(nil) + stops[0] = 99 + if opts.StopTokens[0] == 99 { + t.Fatal("GenerateOptions did not clone StopTokens — mutating caller-side slice changed snapshot") + } +} + +func TestRun_RunsClampToOneByDefault_Good(t *testing.T) { + idx := new(int) + runner := Runner{ + Generate: func(context.Context, string, GenerateOptions) (Generation, error) { + *idx++ + return Generation{Text: "x", Metrics: GenerationMetrics{GeneratedTokens: 1}}, nil + }, + } + // Config with Prompt but Runs=0 — normalize fills default of 1. + if _, err := Run(context.Background(), runner, Config{Prompt: "p", MaxTokens: 4}); err != nil { + t.Fatalf("Run() error = %v", err) + } + if *idx != 1 { + t.Fatalf("Generate called %d times, want 1 after Runs<=0 normalisation", *idx) + } +} + +func TestNonZeroDuration_Good(t *testing.T) { + if got := NonZeroDuration(0); got != time.Nanosecond { + t.Fatalf("NonZeroDuration(0) = %v, want 1ns floor", got) + } + if got := NonZeroDuration(-5); got != time.Nanosecond { + t.Fatalf("NonZeroDuration(-5) = %v, want 1ns floor", got) + } + if got := NonZeroDuration(123 * time.Millisecond); got != 123*time.Millisecond { + t.Fatalf("NonZeroDuration(123ms) = %v, want passthrough", got) + } +} diff --git a/go/capability.go b/go/capability.go new file mode 100644 index 0000000..2b84dc2 --- /dev/null +++ b/go/capability.go @@ -0,0 +1,484 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "context" + "maps" + "slices" + + core "dappco.re/go" +) + +// CapabilityGroup identifies the layer a capability belongs to. +type CapabilityGroup string + +const ( + // CapabilityGroupModel covers model-facing inference and model-pack features. + CapabilityGroupModel CapabilityGroup = "model" + // CapabilityGroupRuntime covers hardware/runtime planning and loading. + CapabilityGroupRuntime CapabilityGroup = "runtime" + // CapabilityGroupTraining covers native training and adapter update loops. + CapabilityGroupTraining CapabilityGroup = "training" + // CapabilityGroupProbe covers research telemetry and model-state probing. + CapabilityGroupProbe CapabilityGroup = "probe" +) + +// CapabilityStatus records whether a feature is usable today. +type CapabilityStatus string + +const ( + CapabilityStatusSupported CapabilityStatus = "supported" + CapabilityStatusExperimental CapabilityStatus = "experimental" + CapabilityStatusPlanned CapabilityStatus = "planned" + CapabilityStatusUnsupported CapabilityStatus = "unsupported" +) + +// CapabilityID is a stable feature identifier shared by backends and callers. +type CapabilityID string + +const ( + CapabilityModelLoad CapabilityID = "model.load" + CapabilityGenerate CapabilityID = "generate" + CapabilityChat CapabilityID = "chat" + CapabilityClassify CapabilityID = "classify" + CapabilityBatchGenerate CapabilityID = "batch.generate" + CapabilityTokenizer CapabilityID = "tokenizer" + CapabilityChatTemplate CapabilityID = "chat.template" + CapabilityLoRAInference CapabilityID = "lora.inference" + CapabilityLoRATraining CapabilityID = "lora.training" + CapabilityStateBundle CapabilityID = "state.bundle" + CapabilityKVSnapshot CapabilityID = "kv.snapshot" + CapabilityPromptCache CapabilityID = "prompt.cache" + CapabilityKVCachePlanning CapabilityID = "kv.cache.planning" + CapabilityMemoryPlanning CapabilityID = "memory.planning" + CapabilityModelFit CapabilityID = "model.fit" + CapabilityModelSlice CapabilityID = "model.slice" + CapabilityRuntimeDiscovery CapabilityID = "runtime.discovery" + CapabilityAutoTuning CapabilityID = "runtime.autotune" + CapabilityModelReplace CapabilityID = "model.replace" + CapabilityDifferentialLoad CapabilityID = "model.differential_load" + CapabilitySplitInference CapabilityID = "model.split_inference" + CapabilityBenchmark CapabilityID = "benchmark" + CapabilityEvaluation CapabilityID = "evaluation" + CapabilityDistillation CapabilityID = "distillation" + CapabilityGRPO CapabilityID = "grpo" + CapabilityQuantization CapabilityID = "quantization" + CapabilityModelMerge CapabilityID = "model.merge" + CapabilityProbeEvents CapabilityID = "probe.events" + CapabilityAttentionProbe CapabilityID = "probe.attention" + CapabilityLogitProbe CapabilityID = "probe.logits" + CapabilityLQL CapabilityID = "query.lql" + CapabilityVIndex CapabilityID = "query.vindex" + CapabilityResponsesAPI CapabilityID = "responses.api" + CapabilityAnthropicMessages CapabilityID = "anthropic.messages" + CapabilityOllamaCompat CapabilityID = "ollama.compat" + CapabilityEmbeddings CapabilityID = "embeddings" + CapabilityRerank CapabilityID = "rerank" + CapabilityScheduler CapabilityID = "scheduler" + CapabilityRequestCancel CapabilityID = "request.cancel" + CapabilityCacheBlocks CapabilityID = "cache.blocks" + CapabilityCacheDisk CapabilityID = "cache.disk" + CapabilityCacheWarm CapabilityID = "cache.warm" + CapabilityToolParse CapabilityID = "tool.parse" + CapabilityReasoningParse CapabilityID = "reasoning.parse" + CapabilitySpeculativeDecode CapabilityID = "speculative.decode" + CapabilityPromptLookupDecode CapabilityID = "prompt.lookup.decode" + CapabilityMoERouting CapabilityID = "moe.routing" + CapabilityMoELazyExperts CapabilityID = "moe.lazy_experts" + CapabilityJANGTQ CapabilityID = "jangtq" + CapabilityCodebookVQ CapabilityID = "codebook.vq" + CapabilityAgentMemory CapabilityID = "agent.memory" + CapabilityStateWake CapabilityID = "state.wake" + CapabilityStateSleep CapabilityID = "state.sleep" + CapabilityStateFork CapabilityID = "state.fork" +) + +// Capability describes one backend feature without importing that backend. +type Capability struct { + ID CapabilityID `json:"id"` + Group CapabilityGroup `json:"group,omitempty"` + Status CapabilityStatus `json:"status"` + Detail string `json:"detail,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// FeatureRuntimeStatus records how far a backend has implemented a shared +// algorithm beyond the coarse portable capability status. +type FeatureRuntimeStatus string + +const ( + // FeatureRuntimeNative means the backend has a native implementation. + FeatureRuntimeNative FeatureRuntimeStatus = "native" + // FeatureRuntimeExperimental means the backend implementation is usable but unstable. + FeatureRuntimeExperimental FeatureRuntimeStatus = "experimental" + // FeatureRuntimeMetadataOnly means metadata/planning support exists, but kernels or execution are pending. + FeatureRuntimeMetadataOnly FeatureRuntimeStatus = "metadata_only" + // FeatureRuntimePlanned means the feature is intentionally tracked but not implemented. + FeatureRuntimePlanned FeatureRuntimeStatus = "planned" +) + +// AlgorithmProfile describes one backend-neutral algorithm or feature surface. +// Backends can publish these profiles as labelled capabilities without leaking +// their concrete runtime package. +type AlgorithmProfile struct { + ID CapabilityID `json:"id"` + Group CapabilityGroup `json:"group"` + CapabilityStatus CapabilityStatus `json:"capability_status"` + RuntimeStatus FeatureRuntimeStatus `json:"runtime_status"` + Algorithm string `json:"algorithm,omitempty"` + Detail string `json:"detail,omitempty"` + Architectures []string `json:"architectures,omitempty"` + Requires []CapabilityID `json:"requires,omitempty"` + Provides []string `json:"provides,omitempty"` + Notes []string `json:"notes,omitempty"` +} + +// Capability converts an algorithm profile into the portable report shape. +func (profile AlgorithmProfile) Capability() Capability { + capability := NewCapability(profile.ID, profile.Group, profile.CapabilityStatus, profile.Detail) + labels := map[string]string{ + "runtime_status": string(profile.RuntimeStatus), + } + if profile.Algorithm != "" { + labels["algorithm"] = profile.Algorithm + } + if len(profile.Architectures) > 0 { + labels["architectures"] = core.Join(",", profile.Architectures...) + } + if len(profile.Requires) > 0 { + labels["requires"] = capabilityIDLabel(profile.Requires) + } + if len(profile.Provides) > 0 { + labels["provides"] = core.Join(",", profile.Provides...) + } + capability.Labels = labels + return capability +} + +// CloneAlgorithmProfile returns an independent copy of profile. +func CloneAlgorithmProfile(profile AlgorithmProfile) AlgorithmProfile { + profile.Architectures = append([]string(nil), profile.Architectures...) + profile.Requires = append([]CapabilityID(nil), profile.Requires...) + profile.Provides = append([]string(nil), profile.Provides...) + profile.Notes = append([]string(nil), profile.Notes...) + return profile +} + +func capabilityIDLabel(ids []CapabilityID) string { + values := make([]string, 0, len(ids)) + for _, id := range ids { + values = append(values, string(id)) + } + return core.Join(",", values...) +} + +// CapabilityReport is the portable backend/model feature report consumed by +// go-ml, go-ai, and any package that must avoid backend-specific imports. +type CapabilityReport struct { + Runtime RuntimeIdentity `json:"runtime"` + Model ModelIdentity `json:"model,omitempty"` + Tokenizer TokenizerIdentity `json:"tokenizer,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Available bool `json:"available"` + Architectures []string `json:"architectures,omitempty"` + Quantizations []string `json:"quantizations,omitempty"` + CacheModes []string `json:"cache_modes,omitempty"` + Capabilities []Capability `json:"capabilities,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// CapabilityReporter is implemented by backends and loaded models that can +// expose their native feature surface without leaking concrete package types. +type CapabilityReporter interface { + Capabilities() CapabilityReport +} + +// RuntimeMemoryLimits is a backend-neutral request/response for runtime memory +// caps. Zero request values mean "leave unchanged"; previous values are filled +// by backends that can report them. +type RuntimeMemoryLimits struct { + CacheLimitBytes uint64 `json:"cache_limit_bytes,omitempty"` + MemoryLimitBytes uint64 `json:"memory_limit_bytes,omitempty"` + PreviousCacheLimitBytes uint64 `json:"previous_cache_limit_bytes,omitempty"` + PreviousMemoryLimitBytes uint64 `json:"previous_memory_limit_bytes,omitempty"` +} + +// RuntimeMemoryLimiter is implemented by native runtimes that expose allocator +// limits without requiring callers to import the concrete runtime package. +type RuntimeMemoryLimiter interface { + SetRuntimeMemoryLimits(limits RuntimeMemoryLimits) RuntimeMemoryLimits +} + +// SetRuntimeMemoryLimits applies memory limits to a registered backend when it +// supports [RuntimeMemoryLimiter]. The boolean is false when the backend is not +// registered or does not support this operation. +func SetRuntimeMemoryLimits(backendName string, limits RuntimeMemoryLimits) (RuntimeMemoryLimits, bool) { + backend, ok := Get(backendName) + if !ok { + return RuntimeMemoryLimits{}, false + } + limiter, ok := backend.(RuntimeMemoryLimiter) + if !ok { + return RuntimeMemoryLimits{}, false + } + return limiter.SetRuntimeMemoryLimits(limits), true +} + +// NewCapability creates a single capability entry. +func NewCapability(id CapabilityID, group CapabilityGroup, status CapabilityStatus, detail string) Capability { + return Capability{ID: id, Group: group, Status: status, Detail: detail} +} + +// SupportedCapability creates a capability entry for a stable feature. +func SupportedCapability(id CapabilityID, group CapabilityGroup) Capability { + return NewCapability(id, group, CapabilityStatusSupported, "") +} + +// ExperimentalCapability creates a capability entry for a usable but unstable feature. +func ExperimentalCapability(id CapabilityID, group CapabilityGroup, detail string) Capability { + return NewCapability(id, group, CapabilityStatusExperimental, detail) +} + +// PlannedCapability creates a capability entry for an intentionally exposed +// roadmap item that is not usable yet. +func PlannedCapability(id CapabilityID, group CapabilityGroup, detail string) Capability { + return NewCapability(id, group, CapabilityStatusPlanned, detail) +} + +// UnsupportedCapability creates a capability entry for an unavailable feature. +func UnsupportedCapability(id CapabilityID, group CapabilityGroup, detail string) Capability { + return NewCapability(id, group, CapabilityStatusUnsupported, detail) +} + +// Usable reports whether a capability can be used by callers today. +func (cap Capability) Usable() bool { + return cap.Status == CapabilityStatusSupported || cap.Status == CapabilityStatusExperimental +} + +// Capability returns the first entry with id. +func (report CapabilityReport) Capability(id CapabilityID) (Capability, bool) { + for _, capability := range report.Capabilities { + if capability.ID == id { + return cloneCapability(capability), true + } + } + return Capability{}, false +} + +// Supports reports whether id is present and usable. +func (report CapabilityReport) Supports(id CapabilityID) bool { + capability, ok := report.Capability(id) + return ok && capability.Usable() +} + +// SupportedCapabilityIDs returns stable IDs for all usable capabilities. +func (report CapabilityReport) SupportedCapabilityIDs() []CapabilityID { + ids := make([]CapabilityID, 0, len(report.Capabilities)) + for _, capability := range report.Capabilities { + if capability.Usable() { + ids = append(ids, capability.ID) + } + } + slices.Sort(ids) + return slices.Compact(ids) +} + +// CapabilityIDs returns stable IDs for every reported capability. +func (report CapabilityReport) CapabilityIDs() []CapabilityID { + ids := make([]CapabilityID, 0, len(report.Capabilities)) + for _, capability := range report.Capabilities { + ids = append(ids, capability.ID) + } + slices.Sort(ids) + return slices.Compact(ids) +} + +// CapabilitiesOf returns an explicit or inferred capability report for value. +func CapabilitiesOf(value any) (CapabilityReport, bool) { + if value == nil { + return CapabilityReport{}, false + } + if reporter, ok := value.(CapabilityReporter); ok { + return reporter.Capabilities(), true + } + switch typed := value.(type) { + case Backend: + return BackendCapabilities(typed), true + case TextModel: + return TextModelCapabilities(RuntimeIdentity{}, typed), true + default: + return CapabilityReport{}, false + } +} + +// BackendCapabilities infers the minimal report every registered backend can expose. +func BackendCapabilities(backend Backend) CapabilityReport { + if backend == nil { + return CapabilityReport{} + } + capabilities := []Capability{SupportedCapability(CapabilityModelLoad, CapabilityGroupRuntime)} + if _, ok := backend.(ModelFitPlanner); ok { + capabilities = append(capabilities, SupportedCapability(CapabilityModelFit, CapabilityGroupRuntime)) + } + return CapabilityReport{ + Runtime: RuntimeIdentity{Backend: backend.Name()}, + Available: backend.Available(), + Capabilities: capabilities, + } +} + +// TextModelCapabilities infers a report from optional interfaces implemented by +// a loaded model. +func TextModelCapabilities(runtime RuntimeIdentity, model TextModel) CapabilityReport { + if model == nil { + return CapabilityReport{Runtime: runtime} + } + info := model.Info() + report := CapabilityReport{ + Runtime: runtime, + Available: true, + Model: ModelIdentity{ + Architecture: info.Architecture, + VocabSize: info.VocabSize, + NumLayers: info.NumLayers, + HiddenSize: info.HiddenSize, + QuantBits: info.QuantBits, + QuantGroup: info.QuantGroup, + }, + Capabilities: []Capability{ + SupportedCapability(CapabilityGenerate, CapabilityGroupModel), + SupportedCapability(CapabilityChat, CapabilityGroupModel), + SupportedCapability(CapabilityClassify, CapabilityGroupModel), + SupportedCapability(CapabilityBatchGenerate, CapabilityGroupModel), + }, + } + if tokenizer, ok := model.(TokenizerModel); ok { + report.Capabilities = append(report.Capabilities, + SupportedCapability(CapabilityTokenizer, CapabilityGroupModel), + SupportedCapability(CapabilityChatTemplate, CapabilityGroupModel), + ) + _ = tokenizer + } + if adapter, ok := model.(AdapterModel); ok { + report.Adapter = adapter.ActiveAdapter() + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityLoRAInference, CapabilityGroupModel)) + } + if _, ok := model.(StatefulModel); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityStateBundle, CapabilityGroupRuntime)) + } + if _, ok := model.(ProbeableModel); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityProbeEvents, CapabilityGroupProbe)) + } + if _, ok := model.(AttentionInspector); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityAttentionProbe, CapabilityGroupProbe)) + } + if _, ok := model.(BenchableModel); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityBenchmark, CapabilityGroupRuntime)) + } + if _, ok := model.(Evaluator); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityEvaluation, CapabilityGroupRuntime)) + } + if _, ok := model.(SchedulerModel); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityScheduler, CapabilityGroupRuntime)) + } + if _, ok := model.(CancellableModel); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityRequestCancel, CapabilityGroupRuntime)) + } + if _, ok := model.(CacheService); ok { + report.Capabilities = append(report.Capabilities, + SupportedCapability(CapabilityCacheBlocks, CapabilityGroupRuntime), + SupportedCapability(CapabilityCacheWarm, CapabilityGroupRuntime), + ) + } + if _, ok := model.(EmbeddingModel); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityEmbeddings, CapabilityGroupModel)) + } + if _, ok := model.(RerankModel); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityRerank, CapabilityGroupModel)) + } + if _, ok := model.(ReasoningParser); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityReasoningParse, CapabilityGroupModel)) + } + if _, ok := model.(ToolParser); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityToolParse, CapabilityGroupModel)) + } + if _, ok := model.(SFTTrainer); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityLoRATraining, CapabilityGroupTraining)) + } + if _, ok := model.(DistillTrainer); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityDistillation, CapabilityGroupTraining)) + } + if _, ok := model.(GRPOTrainer); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityGRPO, CapabilityGroupTraining)) + } + if _, ok := model.(ModelFitPlanner); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityModelFit, CapabilityGroupRuntime)) + } + if _, ok := model.(AgentMemorySession); ok { + report.Capabilities = append(report.Capabilities, + SupportedCapability(CapabilityAgentMemory, CapabilityGroupRuntime), + SupportedCapability(CapabilityStateWake, CapabilityGroupRuntime), + SupportedCapability(CapabilityStateSleep, CapabilityGroupRuntime), + ) + } + if _, ok := model.(AgentMemoryForker); ok { + report.Capabilities = append(report.Capabilities, SupportedCapability(CapabilityStateFork, CapabilityGroupRuntime)) + } + return report +} + +func cloneCapability(capability Capability) Capability { + capability.Labels = maps.Clone(capability.Labels) + return capability +} + +// TokenizerModel exposes native tokenisation and chat-template handling. +type TokenizerModel interface { + Encode(text string) []int32 + Decode(ids []int32) string + ApplyChatTemplate(messages []Message) (string, error) +} + +// AdapterModel exposes LoRA adapter lifecycle operations for inference. +type AdapterModel interface { + LoadAdapter(path string) (AdapterIdentity, error) + UnloadAdapter() error + ActiveAdapter() AdapterIdentity +} + +// StatefulModel exposes portable model-state capture and restore. +type StatefulModel interface { + CaptureState(ctx context.Context, prompt string, opts ...GenerateOption) (*StateBundle, error) + RestoreState(ctx context.Context, bundle *StateBundle) error +} + +// ProbeableModel accepts a typed probe sink for inference or training events. +type ProbeableModel interface { + SetProbeSink(sink ProbeSink) +} + +// BenchableModel runs local benchmark workloads. +type BenchableModel interface { + Benchmark(ctx context.Context, cfg BenchConfig) (*BenchReport, error) +} + +// ModelFitPlanner estimates whether a model fits a memory budget. +type ModelFitPlanner interface { + PlanModelFit(ctx context.Context, model ModelIdentity, memoryBytes uint64) (*ModelFitReport, error) +} + +// SFTTrainer trains a model or adapter with supervised fine tuning. +type SFTTrainer interface { + TrainSFT(ctx context.Context, dataset DatasetStream, cfg TrainingConfig) (*TrainingResult, error) +} + +// DistillTrainer trains a student model from teacher outputs. +type DistillTrainer interface { + Distill(ctx context.Context, dataset DatasetStream, cfg DistillConfig) (*TrainingResult, error) +} + +// GRPOTrainer trains grouped reasoning rollouts. +type GRPOTrainer interface { + TrainGRPO(ctx context.Context, dataset DatasetStream, cfg GRPOConfig) (*TrainingResult, error) +} diff --git a/go/capability_bench_test.go b/go/capability_bench_test.go new file mode 100644 index 0000000..b390879 --- /dev/null +++ b/go/capability_bench_test.go @@ -0,0 +1,326 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the capability / report surface. +// Per AX-11 — every model load synthesises a CapabilityReport, +// every dispatcher does Supports(id) / Capability(id) lookups during +// routing decisions, and BackendCapabilities + TextModelCapabilities +// run once per Register() and once per LoadModel respectively. Even +// modest allocation cost compounds across the per-request cache check +// and the per-route capability scan. +// +// Run: go test -bench=BenchmarkCapability -benchmem -run='^$' . + +package inference + +import ( + "testing" +) + +// Sinks defeat compiler DCE. +var ( + capBenchSinkReport CapabilityReport + capBenchSinkCapability Capability + capBenchSinkCapBool bool + capBenchSinkCapIDs []CapabilityID + capBenchSinkProfile AlgorithmProfile + capBenchSinkAnyOK bool +) + +// benchAlgorithmProfile builds a representative algorithm profile — +// the shape backends publish to expose their feature surface without +// leaking concrete runtime types. +func benchAlgorithmProfile() AlgorithmProfile { + return AlgorithmProfile{ + ID: CapabilityKVSnapshot, + Group: CapabilityGroupRuntime, + CapabilityStatus: CapabilityStatusSupported, + RuntimeStatus: FeatureRuntimeNative, + Algorithm: "qwen3-paged-q8", + Detail: "native kv snapshot with paged q8 encoding", + Architectures: []string{"qwen3", "gemma3", "llama3"}, + Requires: []CapabilityID{CapabilityModelLoad, CapabilityStateBundle}, + Provides: []string{"snapshot", "resume", "fork"}, + Notes: []string{"verified against gemma3-1b", "q8 only"}, + } +} + +// benchCapabilityReport builds a CapabilityReport with the typical +// 8-12 capability entries a real text-model backend publishes. Used +// to exercise lookup + clone paths against realistic input shape. +func benchCapabilityReport() CapabilityReport { + return CapabilityReport{ + Runtime: RuntimeIdentity{Backend: "metal", Device: "M3 Ultra", NativeRuntime: true}, + Model: ModelIdentity{Architecture: "qwen3", NumLayers: 28, QuantBits: 4}, + Tokenizer: TokenizerIdentity{Kind: "sentencepiece", EOSID: 2}, + Adapter: AdapterIdentity{Hash: "sha256:abc", Format: "lora", Rank: 16}, + Available: true, + Architectures: []string{"qwen3", "gemma3", "llama3"}, + Quantizations: []string{"q4_0", "q8_0", "f16"}, + CacheModes: []string{"paged-q8", "paged-f16"}, + Capabilities: []Capability{ + SupportedCapability(CapabilityModelLoad, CapabilityGroupRuntime), + SupportedCapability(CapabilityGenerate, CapabilityGroupModel), + SupportedCapability(CapabilityChat, CapabilityGroupModel), + SupportedCapability(CapabilityClassify, CapabilityGroupModel), + SupportedCapability(CapabilityBatchGenerate, CapabilityGroupModel), + SupportedCapability(CapabilityTokenizer, CapabilityGroupModel), + SupportedCapability(CapabilityKVSnapshot, CapabilityGroupRuntime), + ExperimentalCapability(CapabilityProbeEvents, CapabilityGroupProbe, "research telemetry"), + PlannedCapability(CapabilityQuantization, CapabilityGroupRuntime, "future"), + UnsupportedCapability(CapabilityGRPO, CapabilityGroupTraining, "no trainer"), + }, + Labels: map[string]string{"profile": "qwen3-paged-q8"}, + } +} + +// --- Constructors (per-Register / per-LoadModel cost) --- + +func BenchmarkCapability_NewCapability(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapability = NewCapability(CapabilityGenerate, CapabilityGroupModel, CapabilityStatusSupported, "") + } +} + +func BenchmarkCapability_SupportedCapability(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapability = SupportedCapability(CapabilityGenerate, CapabilityGroupModel) + } +} + +func BenchmarkCapability_ExperimentalCapability(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapability = ExperimentalCapability(CapabilityProbeEvents, CapabilityGroupProbe, "telemetry") + } +} + +func BenchmarkCapability_PlannedCapability(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapability = PlannedCapability(CapabilityQuantization, CapabilityGroupRuntime, "future") + } +} + +func BenchmarkCapability_UnsupportedCapability(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapability = UnsupportedCapability(CapabilityGRPO, CapabilityGroupTraining, "no trainer") + } +} + +// --- Lookup hot path: Supports / Capability --- +// Dispatchers call these per request to decide which backend +// handles which surface. A 10-cap report scanned linearly is the +// floor we pay every routing decision. + +func BenchmarkCapability_Supports_Hit(b *testing.B) { + report := benchCapabilityReport() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapBool = report.Supports(CapabilityGenerate) + } +} + +func BenchmarkCapability_Supports_HitMiddle(b *testing.B) { + // Middle of the 10-entry list — average linear-scan cost. + report := benchCapabilityReport() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapBool = report.Supports(CapabilityKVSnapshot) + } +} + +func BenchmarkCapability_Supports_Miss(b *testing.B) { + // Worst case — full scan with no match. + report := benchCapabilityReport() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapBool = report.Supports(CapabilityMoELazyExperts) + } +} + +func BenchmarkCapability_Capability_Hit(b *testing.B) { + report := benchCapabilityReport() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapability, capBenchSinkCapBool = report.Capability(CapabilityGenerate) + } +} + +func BenchmarkCapability_Capability_Miss(b *testing.B) { + report := benchCapabilityReport() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapability, capBenchSinkCapBool = report.Capability(CapabilityMoELazyExperts) + } +} + +// --- ID-list helpers (typical request: "what does this backend do?") --- + +func BenchmarkCapability_SupportedCapabilityIDs(b *testing.B) { + report := benchCapabilityReport() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapIDs = report.SupportedCapabilityIDs() + } +} + +func BenchmarkCapability_CapabilityIDs(b *testing.B) { + report := benchCapabilityReport() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapIDs = report.CapabilityIDs() + } +} + +// --- Usable (single-cap usability check, called per scan iteration) --- + +func BenchmarkCapability_Usable_Supported(b *testing.B) { + cap := SupportedCapability(CapabilityGenerate, CapabilityGroupModel) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapBool = cap.Usable() + } +} + +func BenchmarkCapability_Usable_Planned(b *testing.B) { + cap := PlannedCapability(CapabilityQuantization, CapabilityGroupRuntime, "future") + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapBool = cap.Usable() + } +} + +// --- AlgorithmProfile.Capability — profile → portable cap conversion --- +// Backends call this once per published algorithm during init. + +func BenchmarkCapability_AlgorithmProfile_Capability(b *testing.B) { + profile := benchAlgorithmProfile() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkCapability = profile.Capability() + } +} + +func BenchmarkCapability_CloneAlgorithmProfile(b *testing.B) { + profile := benchAlgorithmProfile() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkProfile = CloneAlgorithmProfile(profile) + } +} + +// --- BackendCapabilities — per-Register inference floor --- + +func BenchmarkCapability_BackendCapabilities_Plain(b *testing.B) { + backend := &stubBackend{name: "stub", available: true} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport = BackendCapabilities(backend) + } +} + +func BenchmarkCapability_BackendCapabilities_Nil(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport = BackendCapabilities(nil) + } +} + +// --- TextModelCapabilities — per-LoadModel inference floor --- +// The full optional-interface assertion ladder pays here. + +func BenchmarkCapability_TextModelCapabilities_Plain(b *testing.B) { + model := &stubTextModel{} + runtime := RuntimeIdentity{Backend: "test"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport = TextModelCapabilities(runtime, model) + } +} + +func BenchmarkCapability_TextModelCapabilities_FullSurface(b *testing.B) { + model := &capabilityModel{stubTextModel: &stubTextModel{}} + runtime := RuntimeIdentity{Backend: "test"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport = TextModelCapabilities(runtime, model) + } +} + +func BenchmarkCapability_TextModelCapabilities_Nil(b *testing.B) { + runtime := RuntimeIdentity{Backend: "test"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport = TextModelCapabilities(runtime, nil) + } +} + +// --- CapabilitiesOf — generic any-typed dispatch lookup --- + +func BenchmarkCapability_CapabilitiesOf_Reporter(b *testing.B) { + value := any(&capabilityModel{stubTextModel: &stubTextModel{}}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport, capBenchSinkAnyOK = CapabilitiesOf(value) + } +} + +func BenchmarkCapability_CapabilitiesOf_Backend(b *testing.B) { + value := any(Backend(&stubBackend{name: "stub", available: true})) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport, capBenchSinkAnyOK = CapabilitiesOf(value) + } +} + +func BenchmarkCapability_CapabilitiesOf_TextModel(b *testing.B) { + value := any(TextModel(&stubTextModel{})) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport, capBenchSinkAnyOK = CapabilitiesOf(value) + } +} + +func BenchmarkCapability_CapabilitiesOf_Unknown(b *testing.B) { + value := any(struct{}{}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport, capBenchSinkAnyOK = CapabilitiesOf(value) + } +} + +func BenchmarkCapability_CapabilitiesOf_Nil(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + capBenchSinkReport, capBenchSinkAnyOK = CapabilitiesOf(nil) + } +} diff --git a/go/capability_example_test.go b/go/capability_example_test.go new file mode 100644 index 0000000..5da0062 --- /dev/null +++ b/go/capability_example_test.go @@ -0,0 +1,43 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import core "dappco.re/go" + +func ExampleTokenizerModel() { + model := &capabilityModel{} + tokenizer, ok := any(model).(TokenizerModel) + if !ok { + return + } + + core.Println(tokenizer.Decode(tokenizer.Encode("hello"))) + // Output: 1 +} + +func ExampleAdapterModel() { + model := &capabilityModel{} + adapter, ok := any(model).(AdapterModel) + if !ok { + return + } + + identity, _ := adapter.LoadAdapter("/models/domain/adapter.safetensors") + + core.Println(identity.Format) + // Output: lora +} + +func ExampleCapabilityReporter() { + model := &capabilityModel{} + report, ok := CapabilitiesOf(model) + if !ok { + return + } + + core.Println(report.Runtime.Backend) + core.Println(report.Supports(CapabilityProbeEvents)) + // Output: + // stub + // true +} diff --git a/go/capability_test.go b/go/capability_test.go new file mode 100644 index 0000000..0925c49 --- /dev/null +++ b/go/capability_test.go @@ -0,0 +1,280 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "context" + "testing" + + core "dappco.re/go" +) + +type capabilityModel struct { + *stubTextModel + sink ProbeSink + adapter AdapterIdentity +} + +func (m *capabilityModel) Encode(text string) []int32 { + return []int32{int32(len(text))} +} + +func (m *capabilityModel) Decode(ids []int32) string { + return core.Sprintf("%d", len(ids)) +} + +func (m *capabilityModel) ApplyChatTemplate(messages []Message) (string, error) { + if len(messages) == 0 { + return "", nil + } + return messages[0].Content, nil +} + +func (m *capabilityModel) LoadAdapter(path string) (AdapterIdentity, error) { + m.adapter = AdapterIdentity{Path: path, Format: "lora"} + return m.adapter, nil +} + +func (m *capabilityModel) UnloadAdapter() error { + m.adapter = AdapterIdentity{} + return nil +} + +func (m *capabilityModel) ActiveAdapter() AdapterIdentity { + return m.adapter +} + +func (m *capabilityModel) CaptureState(context.Context, string, ...GenerateOption) (*StateBundle, error) { + return &StateBundle{Model: ModelIdentity{Architecture: "stub"}}, nil +} + +func (m *capabilityModel) RestoreState(context.Context, *StateBundle) error { + return nil +} + +func (m *capabilityModel) SetProbeSink(sink ProbeSink) { + m.sink = sink +} + +func (m *capabilityModel) Benchmark(context.Context, BenchConfig) (*BenchReport, error) { + return &BenchReport{Model: ModelIdentity{Architecture: "stub"}}, nil +} + +func (m *capabilityModel) PlanModelFit(context.Context, ModelIdentity, uint64) (*ModelFitReport, error) { + return &ModelFitReport{Fits: true}, nil +} + +func (m *capabilityModel) TrainSFT(context.Context, DatasetStream, TrainingConfig) (*TrainingResult, error) { + return &TrainingResult{Adapter: AdapterIdentity{Format: "lora"}}, nil +} + +func (m *capabilityModel) Distill(context.Context, DatasetStream, DistillConfig) (*TrainingResult, error) { + return &TrainingResult{Model: ModelIdentity{Architecture: "student"}}, nil +} + +func (m *capabilityModel) TrainGRPO(context.Context, DatasetStream, GRPOConfig) (*TrainingResult, error) { + return &TrainingResult{Metrics: TrainingMetrics{Step: 1}}, nil +} + +func (m *capabilityModel) Capabilities() CapabilityReport { + return CapabilityReport{ + Runtime: RuntimeIdentity{Backend: "stub", NativeRuntime: true}, + Available: true, + Capabilities: []Capability{ + SupportedCapability(CapabilityGenerate, CapabilityGroupModel), + ExperimentalCapability(CapabilityProbeEvents, CapabilityGroupProbe, "test sink"), + PlannedCapability(CapabilityQuantization, CapabilityGroupRuntime, "not in stub"), + }, + } +} + +func TestCapabilityInterfaces(t *testing.T) { + model := &capabilityModel{stubTextModel: &stubTextModel{}} + + _, ok := any(model).(TokenizerModel) + checkTrue(t, ok) + _, ok = any(model).(AdapterModel) + checkTrue(t, ok) + _, ok = any(model).(StatefulModel) + checkTrue(t, ok) + _, ok = any(model).(ProbeableModel) + checkTrue(t, ok) + _, ok = any(model).(BenchableModel) + checkTrue(t, ok) + _, ok = any(model).(ModelFitPlanner) + checkTrue(t, ok) + _, ok = any(model).(SFTTrainer) + checkTrue(t, ok) + _, ok = any(model).(DistillTrainer) + checkTrue(t, ok) + _, ok = any(model).(GRPOTrainer) + checkTrue(t, ok) + _, ok = any(model).(CapabilityReporter) + checkTrue(t, ok) +} + +func TestCapability_TokenizerModel_Good(t *testing.T) { + model := &capabilityModel{} + tokenizer := any(model).(TokenizerModel) + + ids := tokenizer.Encode("hello") + text := tokenizer.Decode([]int32{1, 2, 3}) + prompt, err := tokenizer.ApplyChatTemplate([]Message{{Role: "user", Content: "hi"}}) + + checkNoError(t, err) + checkEqual(t, []int32{5}, ids) + checkEqual(t, "3", text) + checkEqual(t, "hi", prompt) +} + +func TestCapability_AdapterModel_Good(t *testing.T) { + model := &capabilityModel{} + adapter := any(model).(AdapterModel) + + identity, err := adapter.LoadAdapter("/tmp/adapter.safetensors") + checkNoError(t, err) + checkEqual(t, "/tmp/adapter.safetensors", identity.Path) + checkEqual(t, "lora", adapter.ActiveAdapter().Format) + + checkNoError(t, adapter.UnloadAdapter()) + checkEqual(t, AdapterIdentity{}, adapter.ActiveAdapter()) +} + +func TestCapability_StateAndProbe_Ugly_MinimalModel(t *testing.T) { + model := &capabilityModel{} + stateful := any(model).(StatefulModel) + probeable := any(model).(ProbeableModel) + + bundle, err := stateful.CaptureState(context.Background(), "prompt") + checkNoError(t, err) + checkEqual(t, "stub", bundle.Model.Architecture) + + probeable.SetProbeSink(ProbeSinkFunc(func(ProbeEvent) {})) + checkNotNil(t, model.sink) +} + +func TestCapability_ReportHelpers_Good(t *testing.T) { + report := CapabilityReport{ + Capabilities: []Capability{ + SupportedCapability(CapabilityGenerate, CapabilityGroupModel), + ExperimentalCapability(CapabilityProbeEvents, CapabilityGroupProbe, "research telemetry"), + PlannedCapability(CapabilityQuantization, CapabilityGroupRuntime, "future"), + UnsupportedCapability(CapabilityGRPO, CapabilityGroupTraining, "stub"), + }, + } + + checkTrue(t, report.Supports(CapabilityGenerate)) + checkTrue(t, report.Supports(CapabilityProbeEvents)) + checkFalse(t, report.Supports(CapabilityQuantization)) + checkFalse(t, report.Supports(CapabilityGRPO)) + checkEqual(t, []CapabilityID{CapabilityGenerate, CapabilityProbeEvents}, report.SupportedCapabilityIDs()) + checkEqual(t, []CapabilityID{CapabilityGenerate, CapabilityGRPO, CapabilityProbeEvents, CapabilityQuantization}, report.CapabilityIDs()) +} + +func TestCapability_CapabilityClone_Ugly(t *testing.T) { + report := CapabilityReport{Capabilities: []Capability{{ + ID: CapabilityGenerate, + Group: CapabilityGroupModel, + Status: CapabilityStatusSupported, + Labels: map[string]string{"backend": "stub"}, + }}} + + capability, ok := report.Capability(CapabilityGenerate) + checkTrue(t, ok) + capability.Labels["backend"] = "mutated" + + again, ok := report.Capability(CapabilityGenerate) + checkTrue(t, ok) + checkEqual(t, "stub", again.Labels["backend"]) +} + +func TestCapability_CapabilitiesOfReporter_Good(t *testing.T) { + model := &capabilityModel{stubTextModel: &stubTextModel{}} + + report, ok := CapabilitiesOf(model) + + checkTrue(t, ok) + checkTrue(t, report.Available) + checkEqual(t, "stub", report.Runtime.Backend) + checkTrue(t, report.Supports(CapabilityGenerate)) + checkTrue(t, report.Supports(CapabilityProbeEvents)) +} + +func TestCapability_TextModelCapabilities_Good(t *testing.T) { + model := &capabilityModel{stubTextModel: &stubTextModel{}} + + report := TextModelCapabilities(RuntimeIdentity{Backend: "test"}, model) + + checkEqual(t, "test", report.Runtime.Backend) + checkTrue(t, report.Supports(CapabilityGenerate)) + checkTrue(t, report.Supports(CapabilityTokenizer)) + checkTrue(t, report.Supports(CapabilityLoRAInference)) + checkTrue(t, report.Supports(CapabilityStateBundle)) + checkTrue(t, report.Supports(CapabilityBenchmark)) + checkTrue(t, report.Supports(CapabilityLoRATraining)) + checkTrue(t, report.Supports(CapabilityDistillation)) + checkTrue(t, report.Supports(CapabilityGRPO)) +} + +func TestCapability_BackendCapabilities_BadUnavailable(t *testing.T) { + backend := &stubBackend{name: "gpu", available: false} + + report, ok := CapabilitiesOf(backend) + + checkTrue(t, ok) + checkFalse(t, report.Available) + checkEqual(t, "gpu", report.Runtime.Backend) + checkTrue(t, report.Supports(CapabilityModelLoad)) +} + +func TestCapability_CapabilitiesOfUnknown_Ugly(t *testing.T) { + report, ok := CapabilitiesOf(struct{}{}) + + checkFalse(t, ok) + checkEqual(t, CapabilityReport{}, report) +} + +type memoryLimitBackend struct { + stubBackend + seen RuntimeMemoryLimits +} + +func (backend *memoryLimitBackend) SetRuntimeMemoryLimits(limits RuntimeMemoryLimits) RuntimeMemoryLimits { + backend.seen = limits + limits.PreviousCacheLimitBytes = 128 + limits.PreviousMemoryLimitBytes = 256 + return limits +} + +func TestCapability_SetRuntimeMemoryLimits_Good(t *testing.T) { + resetBackends(t) + backend := &memoryLimitBackend{stubBackend: stubBackend{name: "metal", available: true}} + Register(backend) + + applied, ok := SetRuntimeMemoryLimits("metal", RuntimeMemoryLimits{CacheLimitBytes: 1024, MemoryLimitBytes: 2048}) + + checkTrue(t, ok) + checkEqual(t, uint64(1024), backend.seen.CacheLimitBytes) + checkEqual(t, uint64(2048), backend.seen.MemoryLimitBytes) + checkEqual(t, uint64(128), applied.PreviousCacheLimitBytes) + checkEqual(t, uint64(256), applied.PreviousMemoryLimitBytes) +} + +func TestCapability_SetRuntimeMemoryLimits_BadMissing(t *testing.T) { + resetBackends(t) + + applied, ok := SetRuntimeMemoryLimits("metal", RuntimeMemoryLimits{CacheLimitBytes: 1024}) + + checkFalse(t, ok) + checkEqual(t, RuntimeMemoryLimits{}, applied) +} + +func TestCapability_SetRuntimeMemoryLimits_UglyUnsupported(t *testing.T) { + resetBackends(t) + Register(&stubBackend{name: "plain", available: true}) + + applied, ok := SetRuntimeMemoryLimits("plain", RuntimeMemoryLimits{CacheLimitBytes: 1024}) + + checkFalse(t, ok) + checkEqual(t, RuntimeMemoryLimits{}, applied) +} diff --git a/go/cmd/lthn-model-pack/main.go b/go/cmd/lthn-model-pack/main.go new file mode 100644 index 0000000..2ea2a41 --- /dev/null +++ b/go/cmd/lthn-model-pack/main.go @@ -0,0 +1,152 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Command lthn-model-pack wraps the model/pack primitives as a CLI so +// .model Trix containers can be built, extracted, and inspected from the +// terminal without going through a service. +// +// lthn-model-pack pack /models/gemma-3-4b-it /out/gemma-3-4b-it.model -arch gemma -quant 4 +// lthn-model-pack inspect /out/gemma-3-4b-it.model +// lthn-model-pack unpack /out/gemma-3-4b-it.model /tmp/extracted +package main + +import ( + "flag" + "os" + + core "dappco.re/go" + "dappco.re/go/inference" + "dappco.re/go/inference/model/pack" +) + +const usage = `Usage: + lthn-model-pack pack [-arch X] [-quant N] [-source safetensors|gguf] [-producer X] + lthn-model-pack unpack [-overwrite] + lthn-model-pack list + lthn-model-pack inspect + +Flags must come before positional arguments.` + +func main() { + if len(os.Args) < 2 { + core.Print(os.Stderr, "%s", usage) + os.Exit(2) + } + var r core.Result + switch os.Args[1] { + case "pack": + r = runPack(os.Args[2:]) + case "unpack": + r = runUnpack(os.Args[2:]) + case "list": + r = runList(os.Args[2:]) + case "inspect": + r = runInspect(os.Args[2:]) + case "-h", "--help", "help": + core.Print(os.Stdout, "%s", usage) + return + default: + core.Print(os.Stderr, "unknown verb %q", os.Args[1]) + core.Print(os.Stderr, "%s", usage) + os.Exit(2) + } + if !r.OK { + core.Print(os.Stderr, "lthn-model-pack: %v", r.Value) + os.Exit(1) + } +} + +func runPack(args []string) core.Result { + fs := flag.NewFlagSet("pack", flag.ExitOnError) + arch := fs.String("arch", "", "model architecture (e.g. gemma)") + quantBits := fs.Int("quant", 0, "quantisation bits (0 for none)") + sourceFormat := fs.String("source", "safetensors", "source format: safetensors|gguf") + producerName := fs.String("producer", "lthn-model-pack", "producer name") + if err := fs.Parse(args); err != nil { + return core.Fail(core.E("pack", "parse flags", err)) + } + rest := fs.Args() + if len(rest) != 2 { + return core.Fail(core.E("pack", "expected: pack ", nil)) + } + srcDir, dest := rest[0], rest[1] + + r := pack.Pack(srcDir, dest, pack.PackOptions{ + Manifest: pack.Manifest{ + Model: inference.ModelIdentity{ + Architecture: *arch, + QuantBits: *quantBits, + }, + SourceFormat: *sourceFormat, + Producer: pack.Producer{Name: *producerName}, + }, + }) + if r.OK { + core.Print(os.Stdout, "packed %s -> %s", srcDir, dest) + } + return r +} + +func runUnpack(args []string) core.Result { + fs := flag.NewFlagSet("unpack", flag.ExitOnError) + overwrite := fs.Bool("overwrite", false, "allow writing into a non-empty destDir") + if err := fs.Parse(args); err != nil { + return core.Fail(core.E("unpack", "parse flags", err)) + } + rest := fs.Args() + if len(rest) != 2 { + return core.Fail(core.E("unpack", "expected: unpack ", nil)) + } + src, destDir := rest[0], rest[1] + + r := pack.Unpack(src, destDir, pack.UnpackOptions{Overwrite: *overwrite}) + if r.OK { + core.Print(os.Stdout, "unpacked %s -> %s", src, destDir) + } + return r +} + +func runList(args []string) core.Result { + if len(args) != 1 { + return core.Fail(core.E("list", "expected: list ", nil)) + } + src := args[0] + + entries, manifest, r := pack.List(src) + if !r.OK { + return r + } + bundle := map[string]any{ + "manifest": manifest, + "entries": entries, + "count": len(entries), + } + jr := core.JSONMarshalIndent(bundle, "", " ") + if !jr.OK { + return jr + } + core.Print(os.Stdout, "%s", string(jr.Value.([]byte))) + return core.Ok(nil) +} + +func runInspect(args []string) core.Result { + if len(args) != 1 { + return core.Fail(core.E("inspect", "expected: inspect ", nil)) + } + src := args[0] + + manifest, inspection, r := pack.Inspect(src) + if !r.OK { + return r + } + bundle := map[string]any{ + "manifest": manifest, + "inspection": inspection, + "fingerprint": pack.Fingerprint(*manifest), + } + jr := core.JSONMarshalIndent(bundle, "", " ") + if !jr.OK { + return jr + } + core.Print(os.Stdout, "%s", string(jr.Value.([]byte))) + return core.Ok(nil) +} diff --git a/go/contracts.go b/go/contracts.go new file mode 100644 index 0000000..00752b1 --- /dev/null +++ b/go/contracts.go @@ -0,0 +1,241 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "context" + + "dappco.re/go/inference/state" +) + +// RequestHandle identifies an in-flight generation request without requiring +// a concrete scheduler implementation. +type RequestHandle struct { + ID string `json:"id,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// RequestCancelResult records the outcome of a cancellation request. +type RequestCancelResult struct { + ID string `json:"id,omitempty"` + Cancelled bool `json:"cancelled,omitempty"` + Reason string `json:"reason,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ScheduledRequest is the backend-neutral input to an optional request +// scheduler. Exactly one of Prompt or Messages is normally populated. +type ScheduledRequest struct { + ID string `json:"id,omitempty"` + Model string `json:"model,omitempty"` + Prompt string `json:"prompt,omitempty"` + Messages []Message `json:"messages,omitempty"` + Sampler SamplerConfig `json:"sampler,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ScheduledToken carries a streamed token plus request-local telemetry. +// +// Labels is shared across every token of a single request stream — +// scheduler implementations build the map once at request start +// (queue_latency_ms is added then; first_token_latency_ms lands on +// the first token) and reuse the same map reference for the +// remainder of the stream. Consumers MUST NOT mutate Labels and +// MUST treat reads as point-in-time snapshots; reads concurrent +// with the scheduler writing first_token_latency_ms on the first +// emission are safe because the channel send happens-after the +// write within the producer goroutine, but cross-stream mutation +// would race other receivers of the same value. +type ScheduledToken struct { + RequestID string `json:"request_id,omitempty"` + Token Token `json:"token,omitempty"` + Metrics GenerateMetrics `json:"metrics,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// SchedulerModel exposes queue-aware generation without forcing every backend +// to implement server policy. +type SchedulerModel interface { + Schedule(ctx context.Context, req ScheduledRequest) (RequestHandle, <-chan ScheduledToken, error) +} + +// CancellableModel exposes request cancellation by stable request ID. +type CancellableModel interface { + CancelRequest(ctx context.Context, id string) (RequestCancelResult, error) +} + +// CacheBlockRef is a portable reference to a prompt/KV cache block. +type CacheBlockRef struct { + ID string `json:"id,omitempty"` + Kind string `json:"kind,omitempty"` + ModelHash string `json:"model_hash,omitempty"` + AdapterHash string `json:"adapter_hash,omitempty"` + TokenizerHash string `json:"tokenizer_hash,omitempty"` + TokenStart int `json:"token_start,omitempty"` + TokenCount int `json:"token_count,omitempty"` + SizeBytes uint64 `json:"size_bytes,omitempty"` + Encoding string `json:"encoding,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// CacheStats records request-time cache health. +type CacheStats struct { + Blocks int `json:"blocks,omitempty"` + MemoryBytes uint64 `json:"memory_bytes,omitempty"` + DiskBytes uint64 `json:"disk_bytes,omitempty"` + Hits uint64 `json:"hits,omitempty"` + Misses uint64 `json:"misses,omitempty"` + Evictions uint64 `json:"evictions,omitempty"` + HitRate float64 `json:"hit_rate,omitempty"` + RestoreMillis float64 `json:"restore_millis,omitempty"` + CacheMode string `json:"cache_mode,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// CacheWarmRequest asks a runtime to prepare cache blocks for a prompt. +type CacheWarmRequest struct { + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Prompt string `json:"prompt,omitempty"` + Tokens []int32 `json:"tokens,omitempty"` + Mode string `json:"mode,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// CacheWarmResult reports which cache blocks are available after warming. +type CacheWarmResult struct { + Blocks []CacheBlockRef `json:"blocks,omitempty"` + Stats CacheStats `json:"stats,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// CacheService exposes cache inspection and warm/clear controls. +type CacheService interface { + CacheStats(ctx context.Context) (CacheStats, error) + WarmCache(ctx context.Context, req CacheWarmRequest) (CacheWarmResult, error) + ClearCache(ctx context.Context, labels map[string]string) (CacheStats, error) +} + +// EmbeddingRequest is a backend-neutral embedding request. +type EmbeddingRequest struct { + Model string `json:"model,omitempty"` + Input []string `json:"input,omitempty"` + Normalize bool `json:"normalize,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// EmbeddingUsage records token accounting for embedding calls. +type EmbeddingUsage struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + TotalTokens int `json:"total_tokens,omitempty"` +} + +// EmbeddingResult is the portable output of an embedding model. +type EmbeddingResult struct { + Model ModelIdentity `json:"model,omitempty"` + Vectors [][]float32 `json:"vectors,omitempty"` + Usage EmbeddingUsage `json:"usage,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// EmbeddingModel marks models that can produce vector embeddings. +type EmbeddingModel interface { + Embed(ctx context.Context, req EmbeddingRequest) (*EmbeddingResult, error) +} + +// RerankRequest asks a model to score documents against a query. +type RerankRequest struct { + Model string `json:"model,omitempty"` + Query string `json:"query,omitempty"` + Documents []string `json:"documents,omitempty"` + TopN int `json:"top_n,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// RerankScore records one scored document. +type RerankScore struct { + Index int `json:"index,omitempty"` + Score float64 `json:"score,omitempty"` + Text string `json:"text,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// RerankResult is the portable output of a rerank request. +type RerankResult struct { + Model ModelIdentity `json:"model,omitempty"` + Results []RerankScore `json:"results,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// RerankModel marks models that can score candidate documents. +type RerankModel interface { + Rerank(ctx context.Context, req RerankRequest) (*RerankResult, error) +} + +// ReasoningSegment is a captured reasoning/thinking span. +type ReasoningSegment struct { + Kind string `json:"kind,omitempty"` + Text string `json:"text,omitempty"` + StartToken int `json:"start_token,omitempty"` + EndToken int `json:"end_token,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ReasoningParseResult separates visible model output from reasoning text. +type ReasoningParseResult struct { + VisibleText string `json:"visible_text,omitempty"` + Reasoning []ReasoningSegment `json:"reasoning,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ReasoningParser parses model-family-specific thinking channels. +type ReasoningParser interface { + ParseReasoning(tokens []Token, text string) (ReasoningParseResult, error) +} + +// ToolCall records a parsed model-emitted tool call. +type ToolCall struct { + ID string `json:"id,omitempty"` + Name string `json:"name,omitempty"` + Type string `json:"type,omitempty"` + ArgumentsJSON string `json:"arguments_json,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ToolParseResult separates user-visible text from tool calls. +type ToolParseResult struct { + VisibleText string `json:"visible_text,omitempty"` + Calls []ToolCall `json:"calls,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ToolParser parses model-family-specific tool-call formats. +type ToolParser interface { + ParseTools(tokens []Token, text string) (ToolParseResult, error) +} + +// ModelPackInspection records portable model-pack validation output. +type ModelPackInspection struct { + Path string `json:"path,omitempty"` + Format string `json:"format,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Tokenizer TokenizerIdentity `json:"tokenizer,omitempty"` + Supported bool `json:"supported,omitempty"` + Capabilities []Capability `json:"capabilities,omitempty"` + Notes []string `json:"notes,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ModelPackInspector inspects local model packs without loading tensors. +type ModelPackInspector interface { + InspectModelPack(ctx context.Context, path string) (*ModelPackInspection, error) +} + +type AgentMemoryRef = state.Ref +type AgentMemoryWakeRequest = state.WakeRequest +type AgentMemoryWakeResult = state.WakeResult +type AgentMemorySleepRequest = state.SleepRequest +type AgentMemorySleepResult = state.SleepResult +type AgentMemorySession = state.Session +type AgentMemoryForker = state.Forker diff --git a/go/contracts_bench_test.go b/go/contracts_bench_test.go new file mode 100644 index 0000000..cdd73f5 --- /dev/null +++ b/go/contracts_bench_test.go @@ -0,0 +1,515 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the wire-contract shapes — the value-types that flow +// over scheduler queues, between the cache subsystem and consumers, +// and through the embed / rerank / tool-parse paths. +// Per AX-11 — these shapes are constructed at the rate of generation +// (one ScheduledToken per emitted token; one CacheStats per request; +// CacheBlockRef cloned per warm-cache call), so structural allocation +// pressure here adds to every served request. +// +// Run: go test -bench=BenchmarkContracts -benchmem -run='^$' . + +package inference + +import ( + "context" + "testing" +) + +// Sinks defeat compiler DCE. +var ( + contractsBenchSinkRequestHandle RequestHandle + contractsBenchSinkCancelResult RequestCancelResult + contractsBenchSinkScheduledRequest ScheduledRequest + contractsBenchSinkScheduledToken ScheduledToken + contractsBenchSinkCacheBlockRef CacheBlockRef + contractsBenchSinkCacheStats CacheStats + contractsBenchSinkCacheWarmReq CacheWarmRequest + contractsBenchSinkCacheWarmRes CacheWarmResult + contractsBenchSinkEmbedReq EmbeddingRequest + contractsBenchSinkEmbedRes *EmbeddingResult + contractsBenchSinkRerankReq RerankRequest + contractsBenchSinkRerankRes *RerankResult + contractsBenchSinkReasoningRes ReasoningParseResult + contractsBenchSinkToolRes ToolParseResult + contractsBenchSinkInspection *ModelPackInspection + contractsBenchSinkErr error + contractsBenchSinkChan <-chan ScheduledToken +) + +// benchScheduledRequestSmall — single short prompt, no labels. +// Tests the minimal allocation floor of the scheduler-input shape. +func benchScheduledRequestSmall() ScheduledRequest { + return ScheduledRequest{ + ID: "req-1", + Model: "qwen3", + Prompt: "hello", + Sampler: SamplerConfig{ + MaxTokens: 64, + }, + } +} + +// benchScheduledRequestTypical — typical chat input — 4 messages, +// realistic sampler config, request-side labels. Closer to what the +// scheduler enqueues per chat turn. +func benchScheduledRequestTypical() ScheduledRequest { + return ScheduledRequest{ + ID: "req-typical", + Model: "qwen3", + Messages: []Message{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "What is 2+2?"}, + {Role: "assistant", Content: "4"}, + {Role: "user", Content: "Are you sure?"}, + }, + Sampler: SamplerConfig{ + MaxTokens: 256, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + RepeatPenalty: 1.1, + StopTokens: []int32{2}, + }, + Labels: map[string]string{"user_id": "u-42", "session": "s-7"}, + } +} + +// benchCacheStats — typical request-time cache reading. +func benchCacheStats() CacheStats { + return CacheStats{ + Blocks: 16, + MemoryBytes: 1 << 28, // 256 MiB + DiskBytes: 1 << 30, // 1 GiB + Hits: 1024, + Misses: 128, + Evictions: 12, + HitRate: 0.88, + RestoreMillis: 4.2, + CacheMode: "paged-q8", + Labels: map[string]string{"profile": "qwen3-paged-q8"}, + } +} + +// benchCacheBlockRef — single block descriptor (one of many in a +// CacheWarmResult). Allocated per warmed block. +func benchCacheBlockRef() CacheBlockRef { + return CacheBlockRef{ + ID: "block-7", + Kind: "kv", + ModelHash: "sha256:model", + AdapterHash: "sha256:adapter", + TokenizerHash: "sha256:tok", + TokenStart: 128, + TokenCount: 256, + SizeBytes: 1 << 22, // 4 MiB + Encoding: "paged-q8", + Labels: map[string]string{"layer": "12"}, + } +} + +// benchReasoningParseResult — typical decode-event with 32 visible +// tokens + 1 thinking segment (Qwen3 / Gemma thinking-tokens shape). +func benchReasoningParseResult32Tokens() ReasoningParseResult { + return ReasoningParseResult{ + VisibleText: "The answer is 4 — addition is commutative.", + Reasoning: []ReasoningSegment{ + { + Kind: "think", + Text: "Confirm: 2+2 = 4. Already given as answer; reaffirm with brief justification.", + StartToken: 0, + EndToken: 32, + Labels: map[string]string{"channel": "thinking"}, + }, + }, + } +} + +// benchReasoningParseResult256Tokens — long-form thinking channel. +func benchReasoningParseResult256Tokens() ReasoningParseResult { + return ReasoningParseResult{ + VisibleText: "After step-by-step reasoning, the answer is 4.", + Reasoning: []ReasoningSegment{ + { + Kind: "think", + Text: "Step 1: Identify the operation as addition. Step 2: Recall 2+2. Step 3: Apply the additive identity for natural numbers. Step 4: Cross-check by counting. Step 5: Confirm 4. Step 6: Make sure no edge cases (negative, decimal). Step 7: Final answer is 4.", + StartToken: 0, + EndToken: 256, + Labels: map[string]string{"channel": "thinking"}, + }, + }, + } +} + +// --- ScheduledRequest / ScheduledToken construction --- +// One ScheduledToken per emitted token — the wire shape callers +// destructure per yield. + +func BenchmarkContracts_ScheduledRequest_Small(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkScheduledRequest = benchScheduledRequestSmall() + } +} + +func BenchmarkContracts_ScheduledRequest_Typical(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkScheduledRequest = benchScheduledRequestTypical() + } +} + +func BenchmarkContracts_ScheduledToken(b *testing.B) { + metrics := GenerateMetrics{PromptTokens: 128, GeneratedTokens: 1} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkScheduledToken = ScheduledToken{ + RequestID: "req-7", + Token: Token{ID: 42, Text: "hello"}, + Metrics: metrics, + } + } +} + +func BenchmarkContracts_RequestHandle(b *testing.B) { + identity := ModelIdentity{Architecture: "qwen3"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkRequestHandle = RequestHandle{ + ID: "req-1", + Model: identity, + } + } +} + +func BenchmarkContracts_RequestCancelResult(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkCancelResult = RequestCancelResult{ + ID: "req-1", + Cancelled: true, + Reason: "client closed connection", + } + } +} + +// --- CacheStats / CacheBlockRef (per-request cache reading) --- + +func BenchmarkContracts_CacheStats_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkCacheStats = benchCacheStats() + } +} + +func BenchmarkContracts_CacheBlockRef_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkCacheBlockRef = benchCacheBlockRef() + } +} + +// --- CacheWarmRequest / CacheWarmResult --- +// Per warm-cache call: 1 request shape + 1 result shape carrying N blocks. + +func BenchmarkContracts_CacheWarmRequest_64Tokens(b *testing.B) { + tokens := make([]int32, 64) + for i := range tokens { + tokens[i] = int32(i + 1) + } + model := ModelIdentity{Architecture: "qwen3"} + adapter := AdapterIdentity{Format: "lora"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkCacheWarmReq = CacheWarmRequest{ + Model: model, + Adapter: adapter, + Prompt: "hello", + Tokens: tokens, + Mode: "paged-q8", + } + } +} + +func BenchmarkContracts_CacheWarmResult_8Blocks(b *testing.B) { + blocks := []CacheBlockRef{ + benchCacheBlockRef(), benchCacheBlockRef(), benchCacheBlockRef(), benchCacheBlockRef(), + benchCacheBlockRef(), benchCacheBlockRef(), benchCacheBlockRef(), benchCacheBlockRef(), + } + stats := benchCacheStats() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkCacheWarmRes = CacheWarmResult{ + Blocks: blocks, + Stats: stats, + } + } +} + +// --- Embedding wire-shape (per-request constructor cost) --- + +func BenchmarkContracts_EmbeddingRequest_8Inputs(b *testing.B) { + inputs := []string{"alpha", "beta", "gamma", "delta", "epsilon", "zeta", "eta", "theta"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkEmbedReq = EmbeddingRequest{ + Model: "qwen3-embed", + Input: inputs, + Normalize: true, + } + } +} + +func BenchmarkContracts_EmbeddingResult_8Vectors(b *testing.B) { + model := ModelIdentity{Architecture: "qwen3-embed"} + model.Hash = "sha256:embed-1" + vectors := make([][]float32, 8) + for i := range vectors { + vec := make([]float32, 64) + for j := range vec { + vec[j] = float32(i + j) + } + vectors[i] = vec + } + model.Path = "/models/embed" + model.VocabSize = 32000 + model.NumLayers = 12 + model.HiddenSize = 768 + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkEmbedRes = &EmbeddingResult{ + Model: model, + Vectors: vectors, + Usage: EmbeddingUsage{PromptTokens: 32, TotalTokens: 32}, + } + } +} + +// --- Rerank wire-shape --- + +func BenchmarkContracts_RerankRequest_16Docs(b *testing.B) { + docs := []string{ + "doc-a", "doc-b", "doc-c", "doc-d", + "doc-e", "doc-f", "doc-g", "doc-h", + "doc-i", "doc-j", "doc-k", "doc-l", + "doc-m", "doc-n", "doc-o", "doc-p", + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkRerankReq = RerankRequest{ + Model: "qwen3-rerank", + Query: "what is the meaning", + Documents: docs, + TopN: 4, + } + } +} + +func BenchmarkContracts_RerankResult_4Scores(b *testing.B) { + model := ModelIdentity{Architecture: "qwen3-rerank"} + results := []RerankScore{ + {Index: 0, Score: 0.91, Text: "doc-a"}, + {Index: 3, Score: 0.84, Text: "doc-d"}, + {Index: 7, Score: 0.71, Text: "doc-h"}, + {Index: 9, Score: 0.60, Text: "doc-j"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkRerankRes = &RerankResult{ + Model: model, + Results: results, + } + } +} + +// --- ReasoningParseResult / ToolParseResult --- +// Constructed per-decode-event when models emit thinking/tool channels. + +func BenchmarkContracts_ReasoningParseResult_32Tokens(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkReasoningRes = benchReasoningParseResult32Tokens() + } +} + +func BenchmarkContracts_ReasoningParseResult_256Tokens(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkReasoningRes = benchReasoningParseResult256Tokens() + } +} + +func BenchmarkContracts_ToolParseResult_OneCall(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkToolRes = ToolParseResult{ + VisibleText: "I'll search for that.", + Calls: []ToolCall{ + { + ID: "call-1", + Name: "search", + Type: "function", + ArgumentsJSON: `{"q":"core","limit":10}`, + }, + }, + } + } +} + +func BenchmarkContracts_ToolParseResult_ThreeCalls(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkToolRes = ToolParseResult{ + VisibleText: "Running three tools in parallel.", + Calls: []ToolCall{ + {ID: "call-1", Name: "search", Type: "function", ArgumentsJSON: `{"q":"alpha"}`}, + {ID: "call-2", Name: "fetch", Type: "function", ArgumentsJSON: `{"url":"https://x"}`}, + {ID: "call-3", Name: "write", Type: "function", ArgumentsJSON: `{"path":"/tmp/out"}`}, + }, + } + } +} + +// --- ModelPackInspection (one per model-pack scan) --- + +func BenchmarkContracts_ModelPackInspection_Construct(b *testing.B) { + model := ModelIdentity{Architecture: "qwen3", NumLayers: 28, QuantBits: 4} + tokenizer := TokenizerIdentity{Kind: "sentencepiece", EOSID: 2} + caps := []Capability{ + SupportedCapability(CapabilityGenerate, CapabilityGroupModel), + SupportedCapability(CapabilityChat, CapabilityGroupModel), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkInspection = &ModelPackInspection{ + Path: "/models/qwen3-1b", + Format: "safetensors", + Model: model, + Tokenizer: tokenizer, + Supported: true, + Capabilities: caps, + } + } +} + +// --- Through a model — exercises the full call shape under the +// optional-interface scheduler / cache / embed / rerank / parsers. --- + +func BenchmarkContracts_SchedulerModel_Schedule(b *testing.B) { + model := &contractModel{} + req := benchScheduledRequestTypical() + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkRequestHandle, contractsBenchSinkChan, contractsBenchSinkErr = model.Schedule(ctx, req) + // Drain the one-element channel so the test cleanup paths + // match production usage and the GC can reclaim the buffer. + for range contractsBenchSinkChan { + } + } +} + +func BenchmarkContracts_CancellableModel_CancelRequest(b *testing.B) { + model := &contractModel{} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkCancelResult, contractsBenchSinkErr = model.CancelRequest(ctx, "req-1") + } +} + +func BenchmarkContracts_CacheService_CacheStats(b *testing.B) { + model := &contractModel{} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkCacheStats, contractsBenchSinkErr = model.CacheStats(ctx) + } +} + +func BenchmarkContracts_CacheService_WarmCache(b *testing.B) { + model := &contractModel{} + tokens := make([]int32, 64) + for i := range tokens { + tokens[i] = int32(i + 1) + } + req := CacheWarmRequest{Tokens: tokens} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkCacheWarmRes, contractsBenchSinkErr = model.WarmCache(ctx, req) + } +} + +func BenchmarkContracts_EmbeddingModel_Embed(b *testing.B) { + model := &contractModel{} + req := EmbeddingRequest{Input: []string{"hello"}} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkEmbedRes, contractsBenchSinkErr = model.Embed(ctx, req) + } +} + +func BenchmarkContracts_RerankModel_Rerank(b *testing.B) { + model := &contractModel{} + req := RerankRequest{Query: "core", Documents: []string{"doc"}} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkRerankRes, contractsBenchSinkErr = model.Rerank(ctx, req) + } +} + +func BenchmarkContracts_ReasoningParser_ParseReasoning(b *testing.B) { + model := &contractModel{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkReasoningRes, contractsBenchSinkErr = model.ParseReasoning(nil, "answer") + } +} + +func BenchmarkContracts_ToolParser_ParseTools(b *testing.B) { + model := &contractModel{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkToolRes, contractsBenchSinkErr = model.ParseTools(nil, "call") + } +} + +func BenchmarkContracts_ModelPackInspector_InspectModelPack(b *testing.B) { + model := &contractModel{} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + contractsBenchSinkInspection, contractsBenchSinkErr = model.InspectModelPack(ctx, "/models/qwen") + } +} diff --git a/go/contracts_example_test.go b/go/contracts_example_test.go new file mode 100644 index 0000000..803ac47 --- /dev/null +++ b/go/contracts_example_test.go @@ -0,0 +1,33 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "context" + + core "dappco.re/go" +) + +func ExampleCacheService() { + model := &contractModel{} + stats, _ := any(model).(CacheService).CacheStats(context.Background()) + + core.Println(stats.CacheMode) + // Output: paged-q8 +} + +func ExampleEmbeddingModel() { + model := &contractModel{} + result, _ := any(model).(EmbeddingModel).Embed(context.Background(), EmbeddingRequest{Input: []string{"core"}}) + + core.Println(len(result.Vectors)) + // Output: 1 +} + +func ExampleReasoningParser() { + model := &contractModel{} + result, _ := any(model).(ReasoningParser).ParseReasoning(nil, "visible") + + core.Println(result.Reasoning[0].Kind) + // Output: think +} diff --git a/go/contracts_test.go b/go/contracts_test.go new file mode 100644 index 0000000..109acbb --- /dev/null +++ b/go/contracts_test.go @@ -0,0 +1,225 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "context" + "testing" +) + +type contractModel struct { + *stubTextModel +} + +func (m *contractModel) Schedule(_ context.Context, req ScheduledRequest) (RequestHandle, <-chan ScheduledToken, error) { + ch := make(chan ScheduledToken, 1) + ch <- ScheduledToken{RequestID: req.ID, Token: Token{Text: "ok"}} + close(ch) + return RequestHandle{ID: req.ID}, ch, nil +} + +func (m *contractModel) CancelRequest(_ context.Context, id string) (RequestCancelResult, error) { + return RequestCancelResult{ID: id, Cancelled: id != ""}, nil +} + +func (m *contractModel) CacheStats(context.Context) (CacheStats, error) { + return CacheStats{Blocks: 2, Hits: 3, Misses: 1, HitRate: 0.75, CacheMode: "paged-q8"}, nil +} + +func (m *contractModel) WarmCache(_ context.Context, req CacheWarmRequest) (CacheWarmResult, error) { + return CacheWarmResult{Blocks: []CacheBlockRef{{ID: "block-1", TokenCount: len(req.Tokens)}}}, nil +} + +func (m *contractModel) ClearCache(context.Context, map[string]string) (CacheStats, error) { + return CacheStats{}, nil +} + +func (m *contractModel) Embed(_ context.Context, req EmbeddingRequest) (*EmbeddingResult, error) { + return &EmbeddingResult{Vectors: [][]float32{{1, 0}}, Usage: EmbeddingUsage{PromptTokens: len(req.Input), TotalTokens: len(req.Input)}}, nil +} + +func (m *contractModel) Rerank(_ context.Context, req RerankRequest) (*RerankResult, error) { + return &RerankResult{Results: []RerankScore{{Index: 0, Score: 0.9, Text: req.Documents[0]}}}, nil +} + +func (m *contractModel) ParseReasoning(_ []Token, text string) (ReasoningParseResult, error) { + return ReasoningParseResult{VisibleText: text, Reasoning: []ReasoningSegment{{Kind: "think", Text: "plan"}}}, nil +} + +func (m *contractModel) ParseTools(_ []Token, text string) (ToolParseResult, error) { + return ToolParseResult{VisibleText: text, Calls: []ToolCall{{ID: "call-1", Name: "search", Type: "function", ArgumentsJSON: `{"q":"core"}`}}}, nil +} + +func (m *contractModel) InspectModelPack(_ context.Context, path string) (*ModelPackInspection, error) { + return &ModelPackInspection{Path: path, Format: "safetensors", Supported: true, Model: ModelIdentity{Architecture: "qwen3"}}, nil +} + +func (m *contractModel) WakeState(_ context.Context, req AgentMemoryWakeRequest) (*AgentMemoryWakeResult, error) { + return &AgentMemoryWakeResult{ + Entry: AgentMemoryRef{URI: req.EntryURI, TokenCount: 8}, + PrefixTokens: 8, + BlocksRead: 2, + }, nil +} + +func (m *contractModel) SleepState(_ context.Context, req AgentMemorySleepRequest) (*AgentMemorySleepResult, error) { + return &AgentMemorySleepResult{ + Entry: AgentMemoryRef{URI: req.EntryURI, Title: req.Title, TokenCount: 9}, + TokenCount: 9, + BlocksWritten: 3, + }, nil +} + +func (m *contractModel) ForkState(_ context.Context, req AgentMemoryWakeRequest) (AgentMemorySession, *AgentMemoryWakeResult, error) { + return m, &AgentMemoryWakeResult{Entry: AgentMemoryRef{URI: req.EntryURI}, PrefixTokens: 8}, nil +} + +func TestContracts_NewCapabilityIDs_Good(t *testing.T) { + ids := []CapabilityID{ + CapabilityResponsesAPI, + CapabilityAnthropicMessages, + CapabilityOllamaCompat, + CapabilityEmbeddings, + CapabilityRerank, + CapabilityScheduler, + CapabilityRequestCancel, + CapabilityCacheBlocks, + CapabilityCacheDisk, + CapabilityCacheWarm, + CapabilityToolParse, + CapabilityReasoningParse, + CapabilitySpeculativeDecode, + CapabilityPromptLookupDecode, + CapabilityMoERouting, + CapabilityMoELazyExperts, + CapabilityJANGTQ, + CapabilityCodebookVQ, + CapabilityAgentMemory, + CapabilityStateWake, + CapabilityStateSleep, + CapabilityStateFork, + } + + seen := map[CapabilityID]bool{} + for _, id := range ids { + if id == "" { + t.Fatal("capability ID must not be blank") + } + if seen[id] { + t.Fatalf("duplicate capability ID %q", id) + } + seen[id] = true + } +} + +func TestContracts_OptionalInterfaces_Good(t *testing.T) { + model := &contractModel{stubTextModel: &stubTextModel{}} + + _, ok := any(model).(SchedulerModel) + checkTrue(t, ok) + _, ok = any(model).(CancellableModel) + checkTrue(t, ok) + _, ok = any(model).(CacheService) + checkTrue(t, ok) + _, ok = any(model).(EmbeddingModel) + checkTrue(t, ok) + _, ok = any(model).(RerankModel) + checkTrue(t, ok) + _, ok = any(model).(ReasoningParser) + checkTrue(t, ok) + _, ok = any(model).(ToolParser) + checkTrue(t, ok) + _, ok = any(model).(ModelPackInspector) + checkTrue(t, ok) + _, ok = any(model).(AgentMemorySession) + checkTrue(t, ok) + _, ok = any(model).(AgentMemoryForker) + checkTrue(t, ok) +} + +func TestContracts_TextModelCapabilities_Good_InferNewOptionalInterfaces(t *testing.T) { + report := TextModelCapabilities(RuntimeIdentity{Backend: "test"}, &contractModel{stubTextModel: &stubTextModel{}}) + + checkTrue(t, report.Supports(CapabilityScheduler)) + checkTrue(t, report.Supports(CapabilityRequestCancel)) + checkTrue(t, report.Supports(CapabilityCacheBlocks)) + checkTrue(t, report.Supports(CapabilityCacheWarm)) + checkTrue(t, report.Supports(CapabilityEmbeddings)) + checkTrue(t, report.Supports(CapabilityRerank)) + checkTrue(t, report.Supports(CapabilityReasoningParse)) + checkTrue(t, report.Supports(CapabilityToolParse)) + checkTrue(t, report.Supports(CapabilityAgentMemory)) + checkTrue(t, report.Supports(CapabilityStateWake)) + checkTrue(t, report.Supports(CapabilityStateSleep)) + checkTrue(t, report.Supports(CapabilityStateFork)) +} + +func TestContracts_CacheService_Good(t *testing.T) { + model := &contractModel{} + service := any(model).(CacheService) + + stats, err := service.CacheStats(context.Background()) + checkNoError(t, err) + checkEqual(t, "paged-q8", stats.CacheMode) + + warmed, err := service.WarmCache(context.Background(), CacheWarmRequest{Tokens: []int32{1, 2, 3}}) + checkNoError(t, err) + checkLen(t, warmed.Blocks, 1) + checkEqual(t, 3, warmed.Blocks[0].TokenCount) +} + +func TestContracts_EmbeddingAndRerank_Good(t *testing.T) { + model := &contractModel{} + + embeddings, err := any(model).(EmbeddingModel).Embed(context.Background(), EmbeddingRequest{Input: []string{"hello"}}) + checkNoError(t, err) + checkLen(t, embeddings.Vectors, 1) + checkEqual(t, 1, embeddings.Usage.TotalTokens) + + reranked, err := any(model).(RerankModel).Rerank(context.Background(), RerankRequest{Query: "core", Documents: []string{"doc"}}) + checkNoError(t, err) + checkLen(t, reranked.Results, 1) + checkEqual(t, "doc", reranked.Results[0].Text) +} + +func TestContracts_Parsers_Good(t *testing.T) { + model := &contractModel{} + + reasoning, err := any(model).(ReasoningParser).ParseReasoning(nil, "answer") + checkNoError(t, err) + checkEqual(t, "answer", reasoning.VisibleText) + checkLen(t, reasoning.Reasoning, 1) + + tools, err := any(model).(ToolParser).ParseTools(nil, "call") + checkNoError(t, err) + checkLen(t, tools.Calls, 1) + checkEqual(t, "search", tools.Calls[0].Name) +} + +func TestContracts_ModelPackInspector_Good(t *testing.T) { + inspection, err := any(&contractModel{}).(ModelPackInspector).InspectModelPack(context.Background(), "/models/qwen") + + checkNoError(t, err) + checkTrue(t, inspection.Supported) + checkEqual(t, "qwen3", inspection.Model.Architecture) +} + +func TestContracts_AgentMemorySession_Good(t *testing.T) { + model := &contractModel{} + session := any(model).(AgentMemorySession) + + wake, err := session.WakeState(context.Background(), AgentMemoryWakeRequest{EntryURI: "mlx://memory/chapter-1"}) + checkNoError(t, err) + checkEqual(t, 8, wake.PrefixTokens) + checkEqual(t, "mlx://memory/chapter-1", wake.Entry.URI) + + sleep, err := session.SleepState(context.Background(), AgentMemorySleepRequest{EntryURI: "mlx://memory/chapter-1/after", Title: "after"}) + checkNoError(t, err) + checkEqual(t, 9, sleep.TokenCount) + checkEqual(t, "after", sleep.Entry.Title) + + forked, forkWake, err := any(model).(AgentMemoryForker).ForkState(context.Background(), AgentMemoryWakeRequest{EntryURI: "mlx://memory/chapter-1"}) + checkNoError(t, err) + checkNotNil(t, forked) + checkEqual(t, 8, forkWake.PrefixTokens) +} diff --git a/go/dataset.go b/go/dataset.go new file mode 100644 index 0000000..4d8656c --- /dev/null +++ b/go/dataset.go @@ -0,0 +1,174 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import "context" + +// DatasetSample is a backend-neutral training or evaluation item. +type DatasetSample struct { + Text string `json:"text,omitempty"` + Prompt string `json:"prompt,omitempty"` + Response string `json:"response,omitempty"` + Reasoning string `json:"reasoning,omitempty"` + Messages []Message `json:"messages,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// DatasetStream is the smallest pull-based dataset contract shared by +// training, evaluation, distillation, and reasoning rollouts. +type DatasetStream interface { + Next() (DatasetSample, bool, error) +} + +// DatasetResetter marks streams that can replay from the start. +type DatasetResetter interface { + Reset() error +} + +// LossMask marks which token positions contribute to training loss. +type LossMask struct { + Values [][]float32 `json:"values,omitempty"` +} + +// Batch is a tokenizer-ready batch with optional response-loss masking. +type Batch struct { + TokenIDs [][]int32 `json:"token_ids,omitempty"` + AttentionMask [][]float32 `json:"attention_mask,omitempty"` + LossMask LossMask `json:"loss_mask,omitempty"` + Samples []DatasetSample `json:"samples,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// EvalConfig controls model evaluation over a dataset stream. +type EvalConfig struct { + MaxSamples int `json:"max_samples,omitempty"` + BatchSize int `json:"batch_size,omitempty"` + MaxSeqLen int `json:"max_seq_len,omitempty"` + Probes []QualityProbe `json:"probes,omitempty"` +} + +// EvalMetrics records aggregate loss and perplexity counters. +type EvalMetrics struct { + Samples int `json:"samples,omitempty"` + Tokens int `json:"tokens,omitempty"` + Loss float64 `json:"loss,omitempty"` + Perplexity float64 `json:"perplexity,omitempty"` +} + +// QualityProbe is a small named prompt used for qualitative checks. +type QualityProbe struct { + Name string `json:"name,omitempty"` + Prompt string `json:"prompt,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// QualityProbeResult records one qualitative probe result. +type QualityProbeResult struct { + Name string `json:"name,omitempty"` + Passed bool `json:"passed,omitempty"` + Score float64 `json:"score,omitempty"` + Text string `json:"text,omitempty"` +} + +// EvalReport is the portable output of dataset evaluation. +type EvalReport struct { + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Metrics EvalMetrics `json:"metrics,omitempty"` + Probes []QualityProbeResult `json:"probes,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// BenchConfig controls reusable local inference benchmarks. +type BenchConfig struct { + Prompts []string `json:"prompts,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + WarmupRuns int `json:"warmup_runs,omitempty"` + MeasuredRuns int `json:"measured_runs,omitempty"` +} + +// BenchReport records fast local benchmark counters. +type BenchReport struct { + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + PromptTokens int `json:"prompt_tokens,omitempty"` + GeneratedTokens int `json:"generated_tokens,omitempty"` + PrefillTokensPerSec float64 `json:"prefill_tokens_per_sec,omitempty"` + DecodeTokensPerSec float64 `json:"decode_tokens_per_sec,omitempty"` + PeakMemoryBytes uint64 `json:"peak_memory_bytes,omitempty"` + PromptCacheHitRate float64 `json:"prompt_cache_hit_rate,omitempty"` + KVRestoreMilliseconds float64 `json:"kv_restore_milliseconds,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// MemoryPlan records device-informed runtime settings. +type MemoryPlan struct { + MachineClass string `json:"machine_class,omitempty"` + DeviceMemoryBytes uint64 `json:"device_memory_bytes,omitempty"` + ContextLength int `json:"context_length,omitempty"` + BatchSize int `json:"batch_size,omitempty"` + CacheMode string `json:"cache_mode,omitempty"` + Quantization string `json:"quantization,omitempty"` + KVCacheBytes uint64 `json:"kv_cache_bytes,omitempty"` + TrainingFeasible bool `json:"training_feasible,omitempty"` + Notes []string `json:"notes,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ModelFitReport records whether a model is expected to fit a machine. +type ModelFitReport struct { + Model ModelIdentity `json:"model,omitempty"` + Fits bool `json:"fits,omitempty"` + MemoryPlan MemoryPlan `json:"memory_plan,omitempty"` + ArchitectureOK bool `json:"architecture_ok,omitempty"` + QuantizationOK bool `json:"quantization_ok,omitempty"` + Notes []string `json:"notes,omitempty"` +} + +// TrainingConfig is the shared SFT LoRA training configuration envelope. +type TrainingConfig struct { + Epochs int `json:"epochs,omitempty"` + BatchSize int `json:"batch_size,omitempty"` + GradientAccumulation int `json:"gradient_accumulation,omitempty"` + LearningRate float64 `json:"learning_rate,omitempty"` + LoRA LoRAConfig `json:"lora,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TrainingMetrics records live or final training counters. +type TrainingMetrics struct { + Epoch int `json:"epoch,omitempty"` + Step int `json:"step,omitempty"` + Samples int `json:"samples,omitempty"` + Tokens int `json:"tokens,omitempty"` + Loss float64 `json:"loss,omitempty"` + LearningRate float64 `json:"learning_rate,omitempty"` +} + +// TrainingResult is the portable output of a training run. +type TrainingResult struct { + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Metrics TrainingMetrics `json:"metrics,omitempty"` + Checkpoints []StateRef `json:"checkpoints,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// DistillConfig controls teacher/student distillation. +type DistillConfig struct { + TrainingConfig + Temperature float64 `json:"temperature,omitempty"` + Alpha float64 `json:"alpha,omitempty"` +} + +// GRPOConfig controls grouped reasoning policy optimisation. +type GRPOConfig struct { + TrainingConfig + GroupSize int `json:"group_size,omitempty"` + KLWeight float64 `json:"kl_weight,omitempty"` +} + +// Evaluator marks backends or adapters that can evaluate dataset streams. +type Evaluator interface { + Evaluate(ctx context.Context, dataset DatasetStream, cfg EvalConfig) (*EvalReport, error) +} diff --git a/go/dataset_bench_test.go b/go/dataset_bench_test.go new file mode 100644 index 0000000..bcd48f6 --- /dev/null +++ b/go/dataset_bench_test.go @@ -0,0 +1,211 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for dataset / batch / report shapes — JSON marshal for +// EvalReport + BenchReport (the wire format trainers + UIs reach for) +// plus the DatasetStream Next-loop floor (per-sample iteration cost). +// Per AX-11 — these shapes carry per-sample/per-result data so any +// allocation-per-call cost compounds across a full training run. +// +// Run: go test -bench='BenchmarkDataset' -benchmem -run='^$' . + +package inference + +import ( + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. +var ( + datasetBenchSinkString string + datasetBenchSinkSample DatasetSample + datasetBenchSinkBatch Batch + datasetBenchSinkOK bool + datasetBenchSinkErr error + datasetBenchSinkCount int +) + +// benchDatasetStream is a deterministic in-memory stream — same shape as +// the test-suite stub but exposed at file scope so the per-Next floor +// can be measured without t.Helper bookkeeping. +type benchDatasetStream struct { + samples []DatasetSample + index int +} + +func (s *benchDatasetStream) Next() (DatasetSample, bool, error) { + if s.index >= len(s.samples) { + return DatasetSample{}, false, nil + } + sample := s.samples[s.index] + s.index++ + return sample, true, nil +} + +func (s *benchDatasetStream) Reset() error { + s.index = 0 + return nil +} + +func buildBenchDatasetSamples(n int) []DatasetSample { + samples := make([]DatasetSample, n) + for i := range samples { + samples[i] = DatasetSample{ + Prompt: core.Sprintf("prompt-%d", i), + Response: core.Sprintf("response-%d", i), + Messages: []Message{ + {Role: "user", Content: core.Sprintf("turn-%d", i)}, + {Role: "assistant", Content: core.Sprintf("reply-%d", i)}, + }, + Labels: map[string]string{"source": "bench", "split": "train"}, + } + } + return samples +} + +// --- DatasetStream.Next — per-sample iteration floor --- + +func BenchmarkDataset_StreamNext_Hit(b *testing.B) { + stream := &benchDatasetStream{samples: buildBenchDatasetSamples(1)} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + stream.index = 0 + datasetBenchSinkSample, datasetBenchSinkOK, datasetBenchSinkErr = stream.Next() + } +} + +func BenchmarkDataset_StreamNext_Exhausted(b *testing.B) { + stream := &benchDatasetStream{samples: nil} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + datasetBenchSinkSample, datasetBenchSinkOK, datasetBenchSinkErr = stream.Next() + } +} + +func BenchmarkDataset_StreamLoop_100Samples(b *testing.B) { + samples := buildBenchDatasetSamples(100) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + stream := &benchDatasetStream{samples: samples} + count := 0 + for { + _, ok, err := stream.Next() + if !ok || err != nil { + break + } + count++ + } + datasetBenchSinkCount = count + } +} + +// --- Batch struct copies (per-batch carry cost) --- + +func BenchmarkDataset_BatchAssemble_Small(b *testing.B) { + samples := buildBenchDatasetSamples(8) + tokenIDs := [][]int32{{1, 2, 3, 4}, {5, 6, 7, 8}} + attention := [][]float32{{1, 1, 1, 1}, {1, 1, 1, 0}} + lossMask := LossMask{Values: [][]float32{{0, 0, 1, 1}, {0, 1, 1, 0}}} + labels := map[string]string{"split": "train"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + datasetBenchSinkBatch = Batch{ + TokenIDs: tokenIDs, + AttentionMask: attention, + LossMask: lossMask, + Samples: samples, + Labels: labels, + } + } +} + +// --- JSON serialisation of the portable report types --- + +func BenchmarkDataset_EvalReport_Marshal(b *testing.B) { + report := EvalReport{ + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4}, + Metrics: EvalMetrics{ + Samples: 2048, + Tokens: 262144, + Loss: 1.234, + Perplexity: 3.4321, + }, + Probes: []QualityProbeResult{ + {Name: "integrity", Passed: true, Score: 0.91}, + {Name: "calibration", Passed: true, Score: 0.82}, + {Name: "stability", Passed: false, Score: 0.43}, + }, + Labels: map[string]string{"run": "nightly-2026-05-21"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + datasetBenchSinkString = core.JSONMarshalString(report) + } +} + +func BenchmarkDataset_BenchReport_Marshal(b *testing.B) { + report := BenchReport{ + Model: ModelIdentity{Architecture: "gemma4", QuantBits: 4}, + Adapter: AdapterIdentity{Path: "/adapters/v3", Rank: 16, Alpha: 32}, + PromptTokens: 2048, + GeneratedTokens: 512, + PrefillTokensPerSec: 1240.5, + DecodeTokensPerSec: 45.2, + PeakMemoryBytes: 12 << 30, + PromptCacheHitRate: 0.81, + KVRestoreMilliseconds: 12.4, + Labels: map[string]string{"workload": "long_context"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + datasetBenchSinkString = core.JSONMarshalString(report) + } +} + +func BenchmarkDataset_MemoryPlan_Marshal(b *testing.B) { + plan := MemoryPlan{ + MachineClass: "m3-ultra-96gb", + DeviceMemoryBytes: 96 << 30, + ContextLength: 131072, + BatchSize: 4, + CacheMode: "paged-q8", + Quantization: "q4_k_m", + KVCacheBytes: 18 << 30, + TrainingFeasible: true, + Notes: []string{"reserve 4GB for OS", "leave 8GB headroom"}, + Labels: map[string]string{"profile": "long_context"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + datasetBenchSinkString = core.JSONMarshalString(plan) + } +} + +func BenchmarkDataset_ModelFitReport_Marshal(b *testing.B) { + report := ModelFitReport{ + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4, ContextLength: 32768}, + Fits: true, + ArchitectureOK: true, + QuantizationOK: true, + MemoryPlan: MemoryPlan{ + MachineClass: "m3-ultra-96gb", + ContextLength: 32768, + CacheMode: "paged-q4", + TrainingFeasible: false, + }, + Notes: []string{"context fits", "training not feasible at this quant"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + datasetBenchSinkString = core.JSONMarshalString(report) + } +} diff --git a/go/dataset_example_test.go b/go/dataset_example_test.go new file mode 100644 index 0000000..f248933 --- /dev/null +++ b/go/dataset_example_test.go @@ -0,0 +1,30 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import core "dappco.re/go" + +func ExampleDatasetSample() { + sample := DatasetSample{ + Messages: []Message{ + {Role: "user", Content: "Explain KV cache reuse"}, + {Role: "assistant", Content: "KV cache reuse avoids recomputing prior context."}, + }, + Reasoning: "focus on local inference state", + } + + core.Println(len(sample.Messages), sample.Reasoning) + // Output: 2 focus on local inference state +} + +func ExampleBenchReport() { + report := BenchReport{ + Model: ModelIdentity{Architecture: "qwen3"}, + PrefillTokensPerSec: 1400, + DecodeTokensPerSec: 42, + PromptCacheHitRate: 0.75, + } + + core.Println(report.Model.Architecture, report.DecodeTokensPerSec, report.PromptCacheHitRate) + // Output: qwen3 42 0.75 +} diff --git a/go/dataset_test.go b/go/dataset_test.go new file mode 100644 index 0000000..4719ff9 --- /dev/null +++ b/go/dataset_test.go @@ -0,0 +1,146 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "context" + "testing" +) + +type datasetStreamStub struct { + samples []DatasetSample + index int +} + +func (s *datasetStreamStub) Next() (DatasetSample, bool, error) { + if s.index >= len(s.samples) { + return DatasetSample{}, false, nil + } + sample := s.samples[s.index] + s.index++ + return sample, true, nil +} + +func (s *datasetStreamStub) Reset() error { + s.index = 0 + return nil +} + +type evaluatorStub struct { + report *EvalReport +} + +func (e evaluatorStub) Evaluate(context.Context, DatasetStream, EvalConfig) (*EvalReport, error) { + return e.report, nil +} + +func TestDataset_DatasetSample_Good(t *testing.T) { + sample := DatasetSample{ + Prompt: "question", + Response: "answer", + Reasoning: "work", + Messages: []Message{{Role: "user", Content: "question"}}, + Labels: map[string]string{"source": "unit"}, + } + + checkEqual(t, "question", sample.Prompt) + checkLen(t, sample.Messages, 1) + checkEqual(t, "unit", sample.Labels["source"]) +} + +func TestDatasetBatchLossMask(t *testing.T) { + batch := Batch{ + TokenIDs: [][]int32{{1, 2, 3}}, + LossMask: LossMask{Values: [][]float32{{ + 0, + 1, + 1, + }}}, + } + + checkEqual(t, float32(1), batch.LossMask.Values[0][1]) +} + +func TestDatasetStreamReset(t *testing.T) { + stream := &datasetStreamStub{ + samples: []DatasetSample{{Text: "one"}}, + } + + sample, ok, err := stream.Next() + checkNoError(t, err) + checkTrue(t, ok) + checkEqual(t, "one", sample.Text) + + sample, ok, err = stream.Next() + checkNoError(t, err) + checkFalse(t, ok) + checkEqual(t, DatasetSample{}, sample) + + checkNoError(t, stream.Reset()) + sample, ok, err = stream.Next() + checkNoError(t, err) + checkTrue(t, ok) + checkEqual(t, "one", sample.Text) +} + +func TestDataset_EvalReport_Good(t *testing.T) { + report := EvalReport{ + Model: ModelIdentity{Architecture: "qwen3"}, + Metrics: EvalMetrics{ + Samples: 2, + Tokens: 64, + Loss: 1.25, + Perplexity: 3.49, + }, + Probes: []QualityProbeResult{{ + Name: "integrity", + Passed: true, + Score: 0.9, + }}, + } + evaluator := evaluatorStub{report: &report} + + got, err := evaluator.Evaluate(context.Background(), &datasetStreamStub{}, EvalConfig{MaxSamples: 2}) + + checkNoError(t, err) + checkEqual(t, "qwen3", got.Model.Architecture) + checkEqual(t, 64, got.Metrics.Tokens) + checkLen(t, got.Probes, 1) +} + +func TestDatasetBenchAndMemoryPlan(t *testing.T) { + report := BenchReport{ + Model: ModelIdentity{Architecture: "gemma4"}, + PromptTokens: 2048, + GeneratedTokens: 128, + PrefillTokensPerSec: 1200, + DecodeTokensPerSec: 32, + PeakMemoryBytes: 8 << 30, + PromptCacheHitRate: 0.8, + KVRestoreMilliseconds: 12.5, + } + plan := MemoryPlan{ + MachineClass: "m3-ultra-96gb", + DeviceMemoryBytes: 96 << 30, + ContextLength: 131072, + CacheMode: "paged-q8", + TrainingFeasible: true, + } + + checkEqual(t, "gemma4", report.Model.Architecture) + checkEqual(t, float64(0.8), report.PromptCacheHitRate) + checkEqual(t, "paged-q8", plan.CacheMode) + checkTrue(t, plan.TrainingFeasible) +} + +func TestDataset_TrainingResult_Ugly_CheckpointsOnly(t *testing.T) { + result := TrainingResult{ + Checkpoints: []StateRef{{ + Kind: "checkpoint", + URI: "file:///tmp/step-10", + }}, + } + + checkLen(t, result.Checkpoints, 1) + checkEqual(t, "", result.Model.Architecture) +} diff --git a/go/decode/decode.go b/go/decode/decode.go new file mode 100644 index 0000000..3148611 --- /dev/null +++ b/go/decode/decode.go @@ -0,0 +1,404 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package decode is the driver-neutral decode-optimisation harness used +// by speculative and prompt-lookup decode benchmarks. +// +// The acceptance algorithm is a generic accept/reject over token streams; +// generation is delegated to caller-supplied Generator implementations. +// The package is shared by every backend driver (go-mlx, go-cuda, +// go-rocm) that wants a portable speculative or prompt-lookup decode +// report. Stateful drivers can implement Generator on a pooled struct; +// func-style callers can wrap with GeneratorFunc. +// +// result, err := decode.Speculative(ctx, decode.SpeculativeConfig{ +// Prompt: "Write a haiku.", +// MaxTokens: 64, +// TargetGenerate: target, +// DraftGenerate: draft, +// }) +package decode + +import ( + "context" + "time" + + core "dappco.re/go" +) + +// Token is one element of a generation sequence — ID plus an optional +// surface form. Drivers populate the fields their tokenizer can report. +type Token struct { + ID int32 `json:"id,omitempty"` + Value string `json:"value,omitempty"` + Text string `json:"text,omitempty"` +} + +// GenerateConfig is the per-call generation request passed to the +// caller-supplied Generator. Only MaxTokens is consumed by decode; +// drivers may carry extra context inside their Generator implementation. +type GenerateConfig struct { + MaxTokens int `json:"max_tokens"` +} + +// Generation is the result Generator.Generate returns to decode. +type Generation struct { + Tokens []Token `json:"tokens,omitempty"` + Text string `json:"text,omitempty"` +} + +// Generator is the model-side generation hook. decode supplies the +// prompt + per-call config; the driver decides how to evaluate it. +// Stateful drivers (e.g. a pooled *modelDecodeGenerator from go-mlx) +// implement Generate directly — no per-call closure allocation. +type Generator interface { + Generate(ctx context.Context, prompt string, cfg GenerateConfig) (Generation, error) +} + +// GeneratorFunc adapts a plain function to the Generator interface. +// Callers with a func value can wrap once and pass through; the wrap +// itself is a value-typed conversion, not a heap allocation. +// +// cfg.TargetGenerate = decode.GeneratorFunc(myFunc) +type GeneratorFunc func(ctx context.Context, prompt string, cfg GenerateConfig) (Generation, error) + +// Generate dispatches the wrapped function. Method on a value receiver +// so the conversion `GeneratorFunc(fn)` is interface-assignable without +// taking the address of a temporary. +func (f GeneratorFunc) Generate(ctx context.Context, prompt string, cfg GenerateConfig) (Generation, error) { + return f(ctx, prompt, cfg) +} + +// GenerateFunc is the legacy func-type alias retained for callers that +// declared variables of this type. New code should use Generator (the +// interface) or GeneratorFunc (the func-to-interface adapter) instead. +type GenerateFunc = GeneratorFunc + +// SpeculativeConfig configures the speculative-decode reference path. +// Target + draft generators must both be supplied; decode compares their +// outputs token-by-token to produce an acceptance report. Generator is +// an interface so stateful pooled implementations can avoid the +// per-call closure allocation; func-style callers wrap with +// GeneratorFunc. +type SpeculativeConfig struct { + Prompt string `json:"prompt,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + DraftTokens int `json:"draft_tokens,omitempty"` + GenerateConfig GenerateConfig `json:"generate_config,omitempty"` + TargetGenerate Generator `json:"-"` + DraftGenerate Generator `json:"-"` +} + +// PromptLookupConfig configures prompt-lookup decoding over a caller- +// supplied token sequence (typically derived from repeated context in +// the prompt). +type PromptLookupConfig struct { + Prompt string `json:"prompt,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` + GenerateConfig GenerateConfig `json:"generate_config,omitempty"` + TargetGenerate Generator `json:"-"` + LookupTokens []Token `json:"lookup_tokens,omitempty"` +} + +// Result is the common decode-optimisation report. +type Result struct { + Mode string `json:"mode"` + Prompt string `json:"prompt,omitempty"` + Text string `json:"text,omitempty"` + Tokens []Token `json:"tokens,omitempty"` + Metrics Metrics `json:"metrics"` +} + +// Metrics records candidate acceptance and call-level timing. +type Metrics struct { + TargetTokens int `json:"target_tokens,omitempty"` + DraftTokens int `json:"draft_tokens,omitempty"` + LookupTokens int `json:"lookup_tokens,omitempty"` + AcceptedTokens int `json:"accepted_tokens,omitempty"` + RejectedTokens int `json:"rejected_tokens,omitempty"` + EmittedTokens int `json:"emitted_tokens,omitempty"` + AcceptanceRate float64 `json:"acceptance_rate,omitempty"` + TargetCalls int `json:"target_calls,omitempty"` + DraftCalls int `json:"draft_calls,omitempty"` + Duration time.Duration `json:"duration,omitempty"` + TargetDuration time.Duration `json:"target_duration,omitempty"` + DraftDuration time.Duration `json:"draft_duration,omitempty"` +} + +// Mode constants identify which decode-optimisation produced a Result. +const ( + ModeSpeculative = "speculative" + ModePromptLookup = "prompt_lookup" +) + +// DefaultMaxTokens is the fallback when neither the caller nor the +// embedded GenerateConfig supplies a positive max. +const DefaultMaxTokens = 256 + +// Speculative compares draft-model candidates against target-model +// tokens and reports deterministic acceptance metrics. This is the safe +// reference API; it does not claim a speedup until a backend provides +// native verification that the benchmark can measure. +// +// result, err := decode.Speculative(ctx, cfg) +func Speculative(ctx context.Context, cfg SpeculativeConfig) (Result, error) { + if cfg.TargetGenerate == nil { + return Result{}, core.NewError("decode: speculative decode requires target generator") + } + if cfg.DraftGenerate == nil { + return Result{}, core.NewError("decode: speculative decode requires draft generator") + } + if ctx == nil { + ctx = context.Background() + } + maxTokens := normaliseMaxTokens(cfg.MaxTokens, cfg.GenerateConfig.MaxTokens) + targetCfg := cfg.GenerateConfig + targetCfg.MaxTokens = maxTokens + draftCfg := cfg.GenerateConfig + draftCfg.MaxTokens = cfg.DraftTokens + if draftCfg.MaxTokens <= 0 || draftCfg.MaxTokens > maxTokens { + draftCfg.MaxTokens = maxTokens + } + + // Single time.Now() for both the total-Duration anchor and the + // draft sub-window — the previous shape fired time.Now() twice + // back-to-back, which on Apple Silicon costs ~6 ns per call but + // adds nothing the second timestamp doesn't already capture. + start := time.Now() + draft, err := cfg.DraftGenerate.Generate(ctx, cfg.Prompt, draftCfg) + draftDuration := nonZeroDuration(time.Since(start)) + if err != nil { + return Result{}, err + } + targetStart := time.Now() + target, err := cfg.TargetGenerate.Generate(ctx, cfg.Prompt, targetCfg) + targetDuration := nonZeroDuration(time.Since(targetStart)) + if err != nil { + return Result{}, err + } + result := buildAcceptanceResult(ModeSpeculative, cfg.Prompt, target.Tokens, draft.Tokens, maxTokens) + result.Metrics.TargetTokens = len(target.Tokens) + result.Metrics.DraftTokens = len(draft.Tokens) + result.Metrics.TargetCalls = 1 + result.Metrics.DraftCalls = 1 + result.Metrics.Duration = nonZeroDuration(time.Since(start)) + result.Metrics.TargetDuration = targetDuration + result.Metrics.DraftDuration = draftDuration + return result, nil +} + +// PromptLookup compares prompt-derived lookup candidates against the +// target stream and reports how often repeated-context tokens were +// reusable. +// +// result, err := decode.PromptLookup(ctx, cfg) +func PromptLookup(ctx context.Context, cfg PromptLookupConfig) (Result, error) { + if cfg.TargetGenerate == nil { + return Result{}, core.NewError("decode: prompt lookup decode requires target generator") + } + if ctx == nil { + ctx = context.Background() + } + maxTokens := normaliseMaxTokens(cfg.MaxTokens, cfg.GenerateConfig.MaxTokens) + targetCfg := cfg.GenerateConfig + targetCfg.MaxTokens = maxTokens + // Single time.Now() — the previous shape fired back-to-back + // time.Now() into start + targetStart, but the target call is + // the only thing the duration spans, so they're the same anchor. + start := time.Now() + target, err := cfg.TargetGenerate.Generate(ctx, cfg.Prompt, targetCfg) + targetDuration := nonZeroDuration(time.Since(start)) + if err != nil { + return Result{}, err + } + result := buildAcceptanceResult(ModePromptLookup, cfg.Prompt, target.Tokens, cfg.LookupTokens, maxTokens) + result.Metrics.TargetTokens = len(target.Tokens) + result.Metrics.LookupTokens = len(cfg.LookupTokens) + result.Metrics.TargetCalls = 1 + result.Metrics.Duration = nonZeroDuration(time.Since(start)) + result.Metrics.TargetDuration = targetDuration + return result, nil +} + +// TokensText renders a token slice as a concatenated string, preferring +// each token's Text field then falling back to Value. Exported so +// drivers that need the same rendering for non-decode paths can reuse it. +// +// text := decode.TokensText(result.Tokens) +func TokensText(tokens []Token) string { + // Pre-grow the builder using each token's actual length. Strings + // are immutable so reading len() is free; this saves the cascade + // of doubling allocs the builder would otherwise pay as it grows + // from 0 → final size. For 2048-token decodes that's ~10 allocs + // down to 1. Index iteration avoids the per-iter 40-byte Token + // copy a range-value loop emits. + total := 0 + for i := range tokens { + text := tokens[i].Text + if text == "" { + text = tokens[i].Value + } + total += len(text) + } + return tokensTextSized(tokens, total) +} + +// tokensTextSized is TokensText with the total length pre-computed by +// the caller. buildAcceptanceResult walks the token stream once during +// the acceptance pass and already knows the rendered length when it +// gets here, so the second len-summing walk is redundant. Exported +// (lowercase) only so the inner loop can elide that walk; external +// callers go through TokensText, which computes total itself. +func tokensTextSized(tokens []Token, total int) string { + builder := core.NewBuilder() + builder.Grow(total) + // Index iteration avoids the per-iter 40-byte Token copy that a + // range-value loop emits; we only read two string headers from + // the slice slot, never the int32 ID. + for i := range tokens { + text := tokens[i].Text + if text == "" { + text = tokens[i].Value + } + builder.WriteString(text) + } + return builder.String() +} + +// CloneTokens returns an independent copy of a token slice. +// +// out := decode.CloneTokens(in) +func CloneTokens(tokens []Token) []Token { + out := make([]Token, len(tokens)) + copy(out, tokens) + return out +} + +// TokenEqual reports whether two tokens identify the same surface form. +// IDs must match; if both surface strings are non-empty they must also +// match. +// +// if decode.TokenEqual(a, b) { … } +func TokenEqual(a, b Token) bool { + if a.ID != b.ID { + return false + } + aText := tokenSurface(a) + bText := tokenSurface(b) + if aText == "" || bText == "" { + return true + } + return aText == bText +} + +func buildAcceptanceResult(mode, prompt string, target, candidates []Token, maxTokens int) Result { + limit := len(target) + if maxTokens > 0 && maxTokens < limit { + limit = maxTokens + } + // Pre-size + direct index assignment beats append on a known-N + // loop: the append cap-check + len-bump on every iteration is dead + // weight when we know we write exactly `limit` tokens. Saves the + // per-token slice-header bookkeeping over a 2048-token pass. + out := make([]Token, limit) + // Track the rendered text length alongside the build loop so the + // TokensText pre-grow walk fuses with the acceptance pass — the + // previous shape walked the emitted tokens twice (once to build + // out, once inside TokensText to sum lengths). At 2048 tokens that + // halves the walk count over the slice. + totalText := 0 + var accepted, rejected int + candidateLen := len(candidates) + for i := 0; i < limit; i++ { + // Write the emitted token directly into out[i] from whichever + // source slice owns it — avoids the intermediate `emitted` + // stack variable plus the speculative pre-load of + // `targetToken := target[i]`. Per token this saves two 40-byte + // struct copies (Token is 40 bytes on arm64 / amd64). + if i < candidateLen && TokenEqual(candidates[i], target[i]) { + out[i] = candidates[i] + accepted++ + text := candidates[i].Text + if text == "" { + text = candidates[i].Value + } + totalText += len(text) + } else { + out[i] = target[i] + if i < candidateLen { + rejected++ + } + text := target[i].Text + if text == "" { + text = target[i].Value + } + totalText += len(text) + } + } + attempted := accepted + rejected + metrics := Metrics{ + AcceptedTokens: accepted, + RejectedTokens: rejected, + EmittedTokens: limit, + } + if attempted > 0 { + metrics.AcceptanceRate = float64(accepted) / float64(attempted) + } + return Result{ + Mode: mode, + Prompt: prompt, + Text: tokensTextSized(out, totalText), + Tokens: out, + Metrics: metrics, + } +} + +func normaliseMaxTokens(values ...int) int { + for _, value := range values { + if value > 0 { + return value + } + } + return DefaultMaxTokens +} + +// tokenSurface returns the token's surface form, preferring Text over +// Value. Inlined two-arg path used by every accept/reject decision; the +// previous variadic firstNonEmpty allocated a []string per call. +func tokenSurface(t Token) string { + if hasNonSpace(t.Text) { + return t.Text + } + if hasNonSpace(t.Value) { + return t.Value + } + return "" +} + +// hasNonSpace reports whether s contains any non-whitespace byte. Avoids +// strings.TrimSpace's per-call string allocation when the input contains +// leading or trailing whitespace. Falls back to core.Trim on multi-byte +// input to preserve Unicode whitespace semantics. +func hasNonSpace(s string) bool { + for i := 0; i < len(s); i++ { + c := s[i] + if c >= 0x80 { + // Multi-byte rune may include Unicode whitespace + // (NBSP, ideographic space, etc.); defer to core.Trim. + return core.Trim(s) != "" + } + switch c { + case ' ', '\t', '\n', '\v', '\f', '\r': + continue + default: + return true + } + } + return false +} + +func nonZeroDuration(d time.Duration) time.Duration { + if d <= 0 { + return time.Nanosecond + } + return d +} diff --git a/go/decode/decode_bench_test.go b/go/decode/decode_bench_test.go new file mode 100644 index 0000000..adccbb2 --- /dev/null +++ b/go/decode/decode_bench_test.go @@ -0,0 +1,311 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the driver-neutral decode-optimisation harness — +// Speculative + PromptLookup over synthetic generators, plus the +// per-token equality, render, and clone primitives. +// +// Per AX-11 — Speculative + PromptLookup fire once per decode bench +// run, but the inner buildAcceptanceResult loop calls TokenEqual + +// cloneToken per emitted token, and TokensText concatenates the whole +// stream. The longest streams the harness sees today are 2048 tokens. +// +// Run: go test -bench='BenchmarkDecode' -benchmem -run='^$' ./go/decode + +package decode + +import ( + "context" + "testing" + "time" +) + +// Sinks defeat compiler DCE. +var ( + decodeSinkResult Result + decodeSinkErr error + decodeSinkText string + decodeSinkTokens []Token + decodeSinkBool bool + decodeSinkInt int + decodeSinkDur time.Duration +) + +// buildDecodeTokens mints n Tokens with a representative ID + Text +// shape (no Value — drivers populate one or the other, not both, +// in the typical hot path). +func buildDecodeTokens(n int) []Token { + tokens := make([]Token, n) + for i := 0; i < n; i++ { + tokens[i] = Token{ID: int32(i + 1), Text: "tok"} + } + return tokens +} + +// buildDecodeTokensSkewed mints n Tokens where every 4th token +// disagrees with the target — exercises the reject branch in +// buildAcceptanceResult. +func buildDecodeTokensSkewed(n int) []Token { + tokens := make([]Token, n) + for i := 0; i < n; i++ { + id := int32(i + 1) + if i%4 == 3 { + id = -id + } + tokens[i] = Token{ID: id, Text: "tok"} + } + return tokens +} + +// scriptGen wraps a fixed token stream in a GenerateFunc. +func scriptGen(tokens []Token) GenerateFunc { + return func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: tokens}, nil + } +} + +// --- Speculative + PromptLookup end-to-end --- + +func BenchmarkDecode_Speculative_32Tokens(b *testing.B) { + target := scriptGen(buildDecodeTokens(32)) + draft := scriptGen(buildDecodeTokens(32)) + ctx := context.Background() + cfg := SpeculativeConfig{Prompt: "p", MaxTokens: 32, DraftTokens: 32, TargetGenerate: target, DraftGenerate: draft} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = Speculative(ctx, cfg) + } +} + +func BenchmarkDecode_Speculative_256Tokens(b *testing.B) { + target := scriptGen(buildDecodeTokens(256)) + draft := scriptGen(buildDecodeTokens(256)) + ctx := context.Background() + cfg := SpeculativeConfig{Prompt: "p", MaxTokens: 256, DraftTokens: 256, TargetGenerate: target, DraftGenerate: draft} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = Speculative(ctx, cfg) + } +} + +func BenchmarkDecode_Speculative_2048Tokens(b *testing.B) { + target := scriptGen(buildDecodeTokens(2048)) + draft := scriptGen(buildDecodeTokens(2048)) + ctx := context.Background() + cfg := SpeculativeConfig{Prompt: "p", MaxTokens: 2048, DraftTokens: 2048, TargetGenerate: target, DraftGenerate: draft} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = Speculative(ctx, cfg) + } +} + +// Skewed exercises the reject path inside buildAcceptanceResult — every +// 4th draft token mismatches, forcing a fallback append. +func BenchmarkDecode_Speculative_256Tokens_25PctReject(b *testing.B) { + target := scriptGen(buildDecodeTokens(256)) + draft := scriptGen(buildDecodeTokensSkewed(256)) + ctx := context.Background() + cfg := SpeculativeConfig{Prompt: "p", MaxTokens: 256, DraftTokens: 256, TargetGenerate: target, DraftGenerate: draft} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = Speculative(ctx, cfg) + } +} + +func BenchmarkDecode_PromptLookup_32Tokens(b *testing.B) { + target := scriptGen(buildDecodeTokens(32)) + ctx := context.Background() + cfg := PromptLookupConfig{Prompt: "p", MaxTokens: 32, TargetGenerate: target, LookupTokens: buildDecodeTokens(32)} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = PromptLookup(ctx, cfg) + } +} + +func BenchmarkDecode_PromptLookup_256Tokens(b *testing.B) { + target := scriptGen(buildDecodeTokens(256)) + ctx := context.Background() + cfg := PromptLookupConfig{Prompt: "p", MaxTokens: 256, TargetGenerate: target, LookupTokens: buildDecodeTokens(256)} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = PromptLookup(ctx, cfg) + } +} + +func BenchmarkDecode_PromptLookup_2048Tokens(b *testing.B) { + target := scriptGen(buildDecodeTokens(2048)) + ctx := context.Background() + cfg := PromptLookupConfig{Prompt: "p", MaxTokens: 2048, TargetGenerate: target, LookupTokens: buildDecodeTokens(2048)} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = PromptLookup(ctx, cfg) + } +} + +// --- buildAcceptanceResult in isolation (the inner loop both +// Speculative + PromptLookup share) --- + +func BenchmarkDecode_BuildAcceptance_32Tokens(b *testing.B) { + target := buildDecodeTokens(32) + candidates := buildDecodeTokens(32) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 32) + } +} + +func BenchmarkDecode_BuildAcceptance_256Tokens(b *testing.B) { + target := buildDecodeTokens(256) + candidates := buildDecodeTokens(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 256) + } +} + +func BenchmarkDecode_BuildAcceptance_2048Tokens(b *testing.B) { + target := buildDecodeTokens(2048) + candidates := buildDecodeTokens(2048) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 2048) + } +} + +// --- TokensText (renders the emitted stream into the Result.Text) --- + +func BenchmarkDecode_TokensText_32Tokens(b *testing.B) { + tokens := buildDecodeTokens(32) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkText = TokensText(tokens) + } +} + +func BenchmarkDecode_TokensText_256Tokens(b *testing.B) { + tokens := buildDecodeTokens(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkText = TokensText(tokens) + } +} + +func BenchmarkDecode_TokensText_2048Tokens(b *testing.B) { + tokens := buildDecodeTokens(2048) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkText = TokensText(tokens) + } +} + +// --- CloneTokens (fires per accepted token in buildAcceptanceResult, +// plus once per result handoff) --- + +func BenchmarkDecode_CloneTokens_32Tokens(b *testing.B) { + tokens := buildDecodeTokens(32) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkTokens = CloneTokens(tokens) + } +} + +func BenchmarkDecode_CloneTokens_256Tokens(b *testing.B) { + tokens := buildDecodeTokens(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkTokens = CloneTokens(tokens) + } +} + +func BenchmarkDecode_CloneTokens_2048Tokens(b *testing.B) { + tokens := buildDecodeTokens(2048) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkTokens = CloneTokens(tokens) + } +} + +// --- TokenEqual (per-token branch — text-vs-value-vs-empty paths) --- + +func BenchmarkDecode_TokenEqual_BothTextEqual(b *testing.B) { + a := Token{ID: 1, Text: "abcdef"} + c := Token{ID: 1, Text: "abcdef"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkBool = TokenEqual(a, c) + } +} + +func BenchmarkDecode_TokenEqual_IDMismatch(b *testing.B) { + a := Token{ID: 1, Text: "abcdef"} + c := Token{ID: 2, Text: "abcdef"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkBool = TokenEqual(a, c) + } +} + +func BenchmarkDecode_TokenEqual_EmptyTextSkipsCompare(b *testing.B) { + a := Token{ID: 1} + c := Token{ID: 1, Text: "abcdef"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkBool = TokenEqual(a, c) + } +} + +// --- normaliseMaxTokens (called twice per Speculative / once per +// PromptLookup) --- + +func BenchmarkDecode_NormaliseMaxTokens_FirstPositive(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkInt = normaliseMaxTokens(64, 0, 0) + } +} + +func BenchmarkDecode_NormaliseMaxTokens_FallsThrough(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkInt = normaliseMaxTokens(0, 0, 0) + } +} + +// --- nonZeroDuration (fires three times per decode call) --- + +func BenchmarkDecode_NonZeroDuration_Positive(b *testing.B) { + d := 45 * time.Millisecond + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkDur = nonZeroDuration(d) + } +} + +func BenchmarkDecode_NonZeroDuration_Zero(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkDur = nonZeroDuration(0) + } +} diff --git a/go/decode/decode_test.go b/go/decode/decode_test.go new file mode 100644 index 0000000..39384ae --- /dev/null +++ b/go/decode/decode_test.go @@ -0,0 +1,242 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package decode + +import ( + "context" + "errors" + "testing" + "time" +) + +func TestSpeculative_AcceptsAndRejectsDraftTokens_Good(t *testing.T) { + targetCalls := 0 + draftCalls := 0 + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + targetCalls++ + return Generation{Tokens: []Token{{ID: 1, Text: "A"}, {ID: 2, Text: "B"}, {ID: 4, Text: "D"}}}, nil + }) + draft := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + draftCalls++ + return Generation{Tokens: []Token{{ID: 1, Text: "A"}, {ID: 2, Text: "B"}, {ID: 3, Text: "C"}}}, nil + }) + + result, err := Speculative(context.Background(), SpeculativeConfig{ + Prompt: "p", + MaxTokens: 3, + DraftTokens: 3, + TargetGenerate: target, + DraftGenerate: draft, + }) + if err != nil { + t.Fatalf("Speculative() error = %v", err) + } + if result.Mode != ModeSpeculative { + t.Fatalf("Mode = %q, want %q", result.Mode, ModeSpeculative) + } + if result.Text != "ABD" { + t.Fatalf("Text = %q, want ABD", result.Text) + } + if result.Metrics.AcceptedTokens != 2 || result.Metrics.RejectedTokens != 1 || result.Metrics.AcceptanceRate != 2.0/3.0 { + t.Fatalf("metrics = %+v, want two accepted + one rejected", result.Metrics) + } + if result.Metrics.TargetCalls != 1 || result.Metrics.DraftCalls != 1 || targetCalls != 1 || draftCalls != 1 { + t.Fatalf("calls = metrics:%+v target:%d draft:%d, want one each", result.Metrics, targetCalls, draftCalls) + } + if result.Metrics.Duration <= 0 || result.Metrics.TargetDuration <= 0 || result.Metrics.DraftDuration <= 0 { + t.Fatalf("durations not populated: %+v", result.Metrics) + } +} + +func TestPromptLookup_AcceptsRepeatedContextTokens_Good(t *testing.T) { + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: []Token{{ID: 10, Text: "go"}, {ID: 11, Text: "-"}, {ID: 12, Text: "mlx"}}}, nil + }) + + result, err := PromptLookup(context.Background(), PromptLookupConfig{ + Prompt: "go-mlx go-mlx", + MaxTokens: 3, + TargetGenerate: target, + LookupTokens: []Token{{ID: 10, Text: "go"}, {ID: 99, Text: "?"}, {ID: 12, Text: "mlx"}}, + }) + if err != nil { + t.Fatalf("PromptLookup() error = %v", err) + } + if result.Mode != ModePromptLookup { + t.Fatalf("Mode = %q, want %q", result.Mode, ModePromptLookup) + } + if result.Text != "go-mlx" { + t.Fatalf("Text = %q, want go-mlx", result.Text) + } + if result.Metrics.AcceptedTokens != 2 || result.Metrics.RejectedTokens != 1 || result.Metrics.LookupTokens != 3 { + t.Fatalf("metrics = %+v, want two accepts + one rejection + 3 lookup tokens", result.Metrics) + } + if result.Metrics.TargetCalls != 1 || result.Metrics.DraftCalls != 0 { + t.Fatalf("calls = %+v, want target=1 draft=0", result.Metrics) + } +} + +func TestSpeculative_RequiresTargetAndDraft_Bad(t *testing.T) { + if _, err := Speculative(context.Background(), SpeculativeConfig{}); err == nil { + t.Fatal("Speculative(zero) error = nil, want missing-target") + } + dummy := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { return Generation{}, nil }) + if _, err := Speculative(context.Background(), SpeculativeConfig{TargetGenerate: dummy}); err == nil { + t.Fatal("Speculative(target-only) error = nil, want missing-draft") + } +} + +func TestPromptLookup_RequiresTarget_Bad(t *testing.T) { + if _, err := PromptLookup(context.Background(), PromptLookupConfig{}); err == nil { + t.Fatal("PromptLookup(zero) error = nil, want missing-target") + } +} + +func TestSpeculative_PropagatesDraftError_Bad(t *testing.T) { + want := errors.New("draft boom") + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: []Token{{ID: 1}}}, nil + }) + draft := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { return Generation{}, want }) + if _, err := Speculative(context.Background(), SpeculativeConfig{ + Prompt: "p", MaxTokens: 4, TargetGenerate: target, DraftGenerate: draft, + }); err == nil { + t.Fatal("Speculative() did not propagate draft error") + } +} + +func TestSpeculative_PropagatesTargetError_Bad(t *testing.T) { + want := errors.New("target boom") + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { return Generation{}, want }) + draft := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: []Token{{ID: 1}}}, nil + }) + if _, err := Speculative(context.Background(), SpeculativeConfig{ + Prompt: "p", MaxTokens: 4, TargetGenerate: target, DraftGenerate: draft, + }); err == nil { + t.Fatal("Speculative() did not propagate target error") + } +} + +func TestPromptLookup_PropagatesTargetError_Bad(t *testing.T) { + want := errors.New("target boom") + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { return Generation{}, want }) + if _, err := PromptLookup(context.Background(), PromptLookupConfig{ + Prompt: "p", MaxTokens: 4, TargetGenerate: target, + }); err == nil { + t.Fatal("PromptLookup() did not propagate target error") + } +} + +func TestSpeculative_NilContextDefaultsToBackground_Good(t *testing.T) { + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: []Token{{ID: 1, Text: "x"}}}, nil + }) + draft := target + if _, err := Speculative(nil, SpeculativeConfig{ + Prompt: "p", MaxTokens: 1, TargetGenerate: target, DraftGenerate: draft, + }); err != nil { + t.Fatalf("Speculative(nil ctx) error = %v", err) + } +} + +func TestPromptLookup_NilContextDefaultsToBackground_Good(t *testing.T) { + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: []Token{{ID: 1, Text: "x"}}}, nil + }) + if _, err := PromptLookup(nil, PromptLookupConfig{ + Prompt: "p", MaxTokens: 1, TargetGenerate: target, + }); err != nil { + t.Fatalf("PromptLookup(nil ctx) error = %v", err) + } +} + +func TestTokenEqual_GoodBad(t *testing.T) { + if !TokenEqual(Token{ID: 1, Text: "a"}, Token{ID: 1, Text: "a"}) { + t.Fatal("identical tokens reported unequal") + } + if TokenEqual(Token{ID: 1, Text: "a"}, Token{ID: 2, Text: "a"}) { + t.Fatal("different IDs reported equal") + } + if TokenEqual(Token{ID: 1, Text: "a"}, Token{ID: 1, Text: "b"}) { + t.Fatal("different non-empty texts reported equal") + } + if !TokenEqual(Token{ID: 1}, Token{ID: 1, Text: "a"}) { + t.Fatal("empty-text token did not skip text comparison") + } + if !TokenEqual(Token{ID: 1, Value: "x"}, Token{ID: 1, Value: "x"}) { + t.Fatal("Value-only equality not honoured") + } +} + +func TestTokensText_PrefersTextOverValue_Good(t *testing.T) { + got := TokensText([]Token{{Text: "go"}, {Value: "-"}, {Text: "mlx", Value: "ignored"}}) + if got != "go-mlx" { + t.Fatalf("TokensText = %q, want go-mlx", got) + } +} + +func TestCloneTokens_IndependentCopy_Good(t *testing.T) { + src := []Token{{ID: 1, Text: "a"}, {ID: 2, Text: "b"}} + dst := CloneTokens(src) + src[0].ID = 99 + if dst[0].ID == 99 { + t.Fatal("CloneTokens did not produce independent copy") + } +} + +func TestSpeculative_MaxTokensClampsTargetWindow_Good(t *testing.T) { + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: []Token{{ID: 1, Text: "A"}, {ID: 2, Text: "B"}, {ID: 3, Text: "C"}}}, nil + }) + draft := target + result, err := Speculative(context.Background(), SpeculativeConfig{ + Prompt: "p", MaxTokens: 2, TargetGenerate: target, DraftGenerate: draft, + }) + if err != nil { + t.Fatalf("Speculative() error = %v", err) + } + if result.Metrics.EmittedTokens != 2 { + t.Fatalf("EmittedTokens = %d, want 2 (clamped by MaxTokens)", result.Metrics.EmittedTokens) + } +} + +func TestSpeculative_DraftTokensClampedToMaxTokens_Good(t *testing.T) { + var draftMax int + target := GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: []Token{{ID: 1}}}, nil + }) + draft := GeneratorFunc(func(_ context.Context, _ string, cfg GenerateConfig) (Generation, error) { + draftMax = cfg.MaxTokens + return Generation{Tokens: []Token{{ID: 1}}}, nil + }) + if _, err := Speculative(context.Background(), SpeculativeConfig{ + Prompt: "p", MaxTokens: 4, DraftTokens: 99, TargetGenerate: target, DraftGenerate: draft, + }); err != nil { + t.Fatalf("Speculative() error = %v", err) + } + if draftMax != 4 { + t.Fatalf("draft cfg.MaxTokens = %d, want clamped to MaxTokens=4", draftMax) + } +} + +func TestNormaliseMaxTokens_FirstPositiveOrDefault_Good(t *testing.T) { + if got := normaliseMaxTokens(0, 0, 7); got != 7 { + t.Fatalf("normaliseMaxTokens(0,0,7) = %d, want 7", got) + } + if got := normaliseMaxTokens(0, 0); got != DefaultMaxTokens { + t.Fatalf("normaliseMaxTokens(0,0) = %d, want DefaultMaxTokens=%d", got, DefaultMaxTokens) + } +} + +func TestNonZeroDuration_ClampsToNanosecond_Ugly(t *testing.T) { + if got := nonZeroDuration(0); got != time.Nanosecond { + t.Fatalf("nonZeroDuration(0) = %v, want 1ns", got) + } + if got := nonZeroDuration(-5); got != time.Nanosecond { + t.Fatalf("nonZeroDuration(-5) = %v, want 1ns", got) + } + if got := nonZeroDuration(7 * time.Millisecond); got != 7*time.Millisecond { + t.Fatalf("nonZeroDuration(7ms) = %v, want passthrough", got) + } +} diff --git a/go/decode/edge_bench_test.go b/go/decode/edge_bench_test.go new file mode 100644 index 0000000..7479ffc --- /dev/null +++ b/go/decode/edge_bench_test.go @@ -0,0 +1,189 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Deeper-edge benchmarks for the decode harness — covers acceptance +// branches the happy-path benches in decode_bench_test.go don't reach: +// all-reject, single-accept-then-reject, candidates-shorter-than-target, +// candidates-longer-than-target, and the NormaliseMaxTokens edges +// (negative, zero, max-int, every-arg-positive). +// +// Per AX-11 — buildAcceptanceResult is the inner loop both Speculative +// and PromptLookup share; its branch shape depends on whether the +// candidate stream agrees with target. The existing 25-pct-reject bench +// covers the typical mixed path; this file covers the extremes so the +// allocator profile under fully-rejected (worst-case cloneToken count) +// and fully-accepted (best-case) is visible alongside. +// +// normaliseMaxTokens is called twice per Speculative / once per +// PromptLookup; the existing benches cover "first positive" and "falls +// through". The edge variants (negative / int-max / mixed) catch the +// rare-but-real configurations callers can pass through GenerateConfig. +// +// Run: go test -bench='BenchmarkDecode_Edge' -benchmem -run='^$' ./go/decode + +package decode + +import ( + "context" + "math" + "testing" +) + +// buildDecodeTokensAllReject mints n Tokens where every token disagrees +// with the target via a flipped sign on ID — exercises the maximum +// reject path in buildAcceptanceResult (every iteration takes the +// fallback append). This is the worst-case for cloneToken volume since +// every emitted token is a target clone rather than a candidate clone. +func buildDecodeTokensAllReject(n int) []Token { + tokens := make([]Token, n) + for i := 0; i < n; i++ { + tokens[i] = Token{ID: -int32(i + 1), Text: "tok"} + } + return tokens +} + +// buildDecodeTokensFirstAcceptThenReject mints n Tokens where token 0 +// matches the target and the remainder reject — the "single hit at +// start" shape some prompt-lookup callers see (first cache-hit then +// drift). Catches branch-predictor flips between accept and reject. +func buildDecodeTokensFirstAcceptThenReject(n int) []Token { + tokens := make([]Token, n) + tokens[0] = Token{ID: 1, Text: "tok"} + for i := 1; i < n; i++ { + tokens[i] = Token{ID: -int32(i + 1), Text: "tok"} + } + return tokens +} + +// --- buildAcceptanceResult edges (256-token shape stress-tests +// branch density without dominating the bench in append growth) --- + +func BenchmarkDecode_Edge_BuildAcceptance_AllAccept_256(b *testing.B) { + target := buildDecodeTokens(256) + candidates := buildDecodeTokens(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 256) + } +} + +func BenchmarkDecode_Edge_BuildAcceptance_AllReject_256(b *testing.B) { + target := buildDecodeTokens(256) + candidates := buildDecodeTokensAllReject(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 256) + } +} + +func BenchmarkDecode_Edge_BuildAcceptance_FirstAcceptThenReject_256(b *testing.B) { + target := buildDecodeTokens(256) + candidates := buildDecodeTokensFirstAcceptThenReject(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 256) + } +} + +// CandidatesShorterThanTarget — the typical prompt-lookup miss path +// where the lookup table runs out before the target stream is exhausted +// and the loop falls through to "no candidate, append target". +func BenchmarkDecode_Edge_BuildAcceptance_CandidatesShorterThanTarget_256(b *testing.B) { + target := buildDecodeTokens(256) + candidates := buildDecodeTokens(64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 256) + } +} + +// CandidatesLongerThanTarget — speculative drafts that overshoot the +// target; extra candidates are silently discarded by the limit cap. +// Exercises the limit-clamp path that bounds 'out' to len(target). +func BenchmarkDecode_Edge_BuildAcceptance_CandidatesLongerThanTarget_256(b *testing.B) { + target := buildDecodeTokens(256) + candidates := buildDecodeTokens(512) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 256) + } +} + +// MaxTokensClampsTarget — emulates the case where the caller's +// MaxTokens is tighter than the target stream; out is sized to +// maxTokens and the loop short-circuits early. Validates the limit +// branch above the 'limit = len(target)' default. +func BenchmarkDecode_Edge_BuildAcceptance_MaxTokensClampsTarget_256(b *testing.B) { + target := buildDecodeTokens(2048) + candidates := buildDecodeTokens(2048) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult = buildAcceptanceResult(ModeSpeculative, "p", target, candidates, 256) + } +} + +// --- normaliseMaxTokens edges (called twice per Speculative, +// once per PromptLookup) --- + +func BenchmarkDecode_Edge_NormaliseMaxTokens_Negative(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkInt = normaliseMaxTokens(-1, 0, 0) + } +} + +func BenchmarkDecode_Edge_NormaliseMaxTokens_MaxInt(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkInt = normaliseMaxTokens(math.MaxInt32, 0, 0) + } +} + +// MixedNegativesThenPositive — first two args reject, third returns. +// Exercises the loop continuation path beyond the simple "first +// positive" benchmark. +func BenchmarkDecode_Edge_NormaliseMaxTokens_MixedNegativesThenPositive(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkInt = normaliseMaxTokens(-1, -1, 128) + } +} + +// --- Speculative end-to-end under the all-reject shape — the +// scheduler-adjacent dominant cost is target-clone count, not +// candidate-clone; this is the worst-case for that. --- + +func BenchmarkDecode_Edge_Speculative_AllReject_256Tokens(b *testing.B) { + target := scriptGen(buildDecodeTokens(256)) + draft := scriptGen(buildDecodeTokensAllReject(256)) + ctx := context.Background() + cfg := SpeculativeConfig{Prompt: "p", MaxTokens: 256, DraftTokens: 256, TargetGenerate: target, DraftGenerate: draft} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = Speculative(ctx, cfg) + } +} + +// PromptLookup_EmptyCache — the cold-start lookup case the harness +// will see during the first few tokens of a long generation, before +// the lookup table has been populated by repeated context. Candidates +// is nil so every iteration falls through to the target append. +func BenchmarkDecode_Edge_PromptLookup_EmptyCache_256Tokens(b *testing.B) { + target := scriptGen(buildDecodeTokens(256)) + ctx := context.Background() + cfg := PromptLookupConfig{Prompt: "p", MaxTokens: 256, TargetGenerate: target, LookupTokens: nil} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = PromptLookup(ctx, cfg) + } +} diff --git a/go/decode/example_test.go b/go/decode/example_test.go new file mode 100644 index 0000000..d6df759 --- /dev/null +++ b/go/decode/example_test.go @@ -0,0 +1,32 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package decode + +import core "dappco.re/go" + +// Generated runnable examples for file-aware public API coverage. + +func ExampleSpeculative() { + core.Println("Speculative") + // Output: Speculative +} + +func ExamplePromptLookup() { + core.Println("PromptLookup") + // Output: PromptLookup +} + +func ExampleTokenEqual() { + core.Println("TokenEqual") + // Output: TokenEqual +} + +func ExampleTokensText() { + core.Println("TokensText") + // Output: TokensText +} + +func ExampleCloneTokens() { + core.Println("CloneTokens") + // Output: CloneTokens +} diff --git a/go/decode/generator_iface_bench_test.go b/go/decode/generator_iface_bench_test.go new file mode 100644 index 0000000..3726695 --- /dev/null +++ b/go/decode/generator_iface_bench_test.go @@ -0,0 +1,203 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the Generator-interface migration (W11-L). The hot +// path question is: does an interface field cost more, less, or the +// same as the previous func-typed field for callers that build a +// fresh generator per call (the dominant go-mlx shape today)? +// +// Three shapes are bench'd against the same Speculative + PromptLookup +// inner loop: +// +// - ClosurePerCall — caller mints a fresh `func` per Speculative call +// and assigns it to TargetGenerate / DraftGenerate. Wraps with +// GeneratorFunc on assignment, but the closure itself escapes +// because it captures the per-iteration tokens slice. This is the +// shape every backend driver in go-cuda / go-rocm / go-mlx uses +// today, and the one W11-L is designed to give them a cheaper +// alternative to. +// +// - PreboundFunc — caller builds the GeneratorFunc once (outside +// the timed loop) and reuses the same value across every call. No +// per-call closure alloc — the closure was paid once. This is the +// existing decode bench shape; included here for direct comparison. +// +// - PooledStruct — caller's Generator is a struct with a sync.Pool +// for the per-call state and a Generate method on the pooled value. +// Zero closure allocs because no closure exists; the interface +// dispatch goes straight to the struct method. This is the shape +// W11-L enables and the one go-mlx will adopt in the follow-up +// `modelDecodeGenerate`-to-struct migration. +// +// Realistic goal: PooledStruct demonstrates a strict alloc-count +// reduction vs ClosurePerCall while staying within noise of PreboundFunc +// on wall time — i.e. the interface dispatch overhead is amortised +// away the moment the closure alloc disappears. +// +// Run: go test -bench='BenchmarkDecode_GeneratorShape' -benchmem -run='^$' ./go/decode + +package decode + +import ( + "context" + "sync" + "testing" +) + +// pooledScriptGenerator is the win-demonstrating shape: a struct that +// implements Generator on a value receiver, served by a sync.Pool. +// `tokens` is set per acquisition; Generate hands the slice back +// without re-allocating. The pool ensures the struct itself is +// recycled across calls — zero allocation in the steady state. +type pooledScriptGenerator struct { + tokens []Token +} + +// Generate satisfies decode.Generator. Value receiver: no per-call +// pointer alloc when the struct is held by value (or by *pool*). +func (g *pooledScriptGenerator) Generate(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: g.tokens}, nil +} + +// genPool recycles pooledScriptGenerator instances across the bench +// loop. In production this is the modelDecodeGenerator pool described +// in W11-L follow-up. +var genPool = sync.Pool{ + New: func() any { return &pooledScriptGenerator{} }, +} + +// acquirePooledGen rents a generator from the pool and parks the +// tokens slice on it. Caller is expected to call releasePooledGen +// directly — returning a release closure would heap-allocate the +// closure on every call and drown the whole win we're trying to +// measure. The straight pointer API is the production-realistic +// shape (go-mlx's modelDecodeGenerate follow-up will do the same). +func acquirePooledGen(tokens []Token) *pooledScriptGenerator { + g := genPool.Get().(*pooledScriptGenerator) + g.tokens = tokens + return g +} + +// releasePooledGen recycles a generator back to the pool. Caller is +// responsible for not touching the struct after the release call. +func releasePooledGen(g *pooledScriptGenerator) { + g.tokens = nil + genPool.Put(g) +} + +// --- Speculative — three shapes side-by-side at 256 tokens --- + +// ClosurePerCall — the shape every driver uses today. Closure captures +// `tokens` so it escapes; one alloc per Speculative call before decode +// even runs. +func BenchmarkDecode_GeneratorShape_Speculative_ClosurePerCall_256(b *testing.B) { + targetTokens := buildDecodeTokens(256) + draftTokens := buildDecodeTokens(256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + cfg := SpeculativeConfig{ + Prompt: "p", + MaxTokens: 256, + DraftTokens: 256, + TargetGenerate: GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: targetTokens}, nil + }), + DraftGenerate: GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: draftTokens}, nil + }), + } + decodeSinkResult, decodeSinkErr = Speculative(ctx, cfg) + } +} + +// PreboundFunc — the existing decode bench shape. The closure was +// paid once outside the timed loop; only the inner-loop allocs show. +func BenchmarkDecode_GeneratorShape_Speculative_PreboundFunc_256(b *testing.B) { + target := scriptGen(buildDecodeTokens(256)) + draft := scriptGen(buildDecodeTokens(256)) + ctx := context.Background() + cfg := SpeculativeConfig{Prompt: "p", MaxTokens: 256, DraftTokens: 256, TargetGenerate: target, DraftGenerate: draft} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = Speculative(ctx, cfg) + } +} + +// PooledStruct — the W11-L-enabled shape. Per call: pool Get (no +// alloc when the pool is warm), interface dispatch into Generate, +// pool Put. Zero closure allocs because there is no closure. +func BenchmarkDecode_GeneratorShape_Speculative_PooledStruct_256(b *testing.B) { + targetTokens := buildDecodeTokens(256) + draftTokens := buildDecodeTokens(256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + target := acquirePooledGen(targetTokens) + draft := acquirePooledGen(draftTokens) + cfg := SpeculativeConfig{ + Prompt: "p", + MaxTokens: 256, + DraftTokens: 256, + TargetGenerate: target, + DraftGenerate: draft, + } + decodeSinkResult, decodeSinkErr = Speculative(ctx, cfg) + releasePooledGen(draft) + releasePooledGen(target) + } +} + +// --- PromptLookup — three shapes side-by-side at 256 tokens --- + +func BenchmarkDecode_GeneratorShape_PromptLookup_ClosurePerCall_256(b *testing.B) { + targetTokens := buildDecodeTokens(256) + lookupTokens := buildDecodeTokens(256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + cfg := PromptLookupConfig{ + Prompt: "p", + MaxTokens: 256, + TargetGenerate: GeneratorFunc(func(context.Context, string, GenerateConfig) (Generation, error) { + return Generation{Tokens: targetTokens}, nil + }), + LookupTokens: lookupTokens, + } + decodeSinkResult, decodeSinkErr = PromptLookup(ctx, cfg) + } +} + +func BenchmarkDecode_GeneratorShape_PromptLookup_PreboundFunc_256(b *testing.B) { + target := scriptGen(buildDecodeTokens(256)) + lookupTokens := buildDecodeTokens(256) + ctx := context.Background() + cfg := PromptLookupConfig{Prompt: "p", MaxTokens: 256, TargetGenerate: target, LookupTokens: lookupTokens} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkResult, decodeSinkErr = PromptLookup(ctx, cfg) + } +} + +func BenchmarkDecode_GeneratorShape_PromptLookup_PooledStruct_256(b *testing.B) { + targetTokens := buildDecodeTokens(256) + lookupTokens := buildDecodeTokens(256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + target := acquirePooledGen(targetTokens) + cfg := PromptLookupConfig{ + Prompt: "p", + MaxTokens: 256, + TargetGenerate: target, + LookupTokens: lookupTokens, + } + decodeSinkResult, decodeSinkErr = PromptLookup(ctx, cfg) + releasePooledGen(target) + } +} diff --git a/go/decode/tokens_text_bench_test.go b/go/decode/tokens_text_bench_test.go new file mode 100644 index 0000000..06b61ca --- /dev/null +++ b/go/decode/tokens_text_bench_test.go @@ -0,0 +1,203 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Deeper TokensText + token-surface benchmarks. The existing bench +// suite covers all-Text streams; this file adds mixed Text+Value +// (the tokenizer-emitting-both case some drivers see), all-Value +// (when the tokenizer can't render UTF-8 but can emit byte +// sequences), tokens-with-whitespace-only (hasNonSpace tight loop), +// and tokens-with-Unicode-whitespace (the multi-byte core.Trim +// fallback path). +// +// Per AX-11 — TokensText runs once per Speculative + PromptLookup +// call but iterates the whole stream twice (pre-grow walk + write +// walk). The hot loop is tokenSurface → hasNonSpace, which has a +// fast ASCII path and a slower multi-byte path. Coverage on those +// two paths is the difference between knowing the cost and guessing. +// +// Run: go test -bench='BenchmarkDecode_TokensTextDeep' -benchmem -run='^$' ./go/decode + +package decode + +import ( + "testing" +) + +// buildDecodeTokensMixedTextValue mints n Tokens where half carry +// Text and half carry only Value — the tokenSurface fallback path +// triggers on every Value-only token. The existing all-Text and +// all-Value benches cover the pure paths; this one stresses the +// branch density and shows whether the fallback adds measurable +// per-token cost. +func buildDecodeTokensMixedTextValue(n int) []Token { + tokens := make([]Token, n) + for i := 0; i < n; i++ { + if i%2 == 0 { + tokens[i] = Token{ID: int32(i + 1), Text: "tok"} + } else { + tokens[i] = Token{ID: int32(i + 1), Value: "tok"} + } + } + return tokens +} + +// buildDecodeTokensAllValueOnly mints n Tokens where Text is empty +// and only Value is populated — the path some byte-sequence-only +// tokenizers (raw BPE, some classification heads) take. Stresses +// the tokenSurface Text-empty fallthrough. +func buildDecodeTokensAllValueOnly(n int) []Token { + tokens := make([]Token, n) + for i := 0; i < n; i++ { + tokens[i] = Token{ID: int32(i + 1), Value: "tok"} + } + return tokens +} + +// buildDecodeTokensWhitespaceOnly mints n Tokens whose Text is a +// pure-whitespace ASCII string — exercises the hasNonSpace inner +// loop where every byte is the "skip" case, forcing the longest +// straight-line read. Sentinel pattern for stride-of-whitespace +// content (markdown, structured output). +func buildDecodeTokensWhitespaceOnly(n int) []Token { + tokens := make([]Token, n) + for i := 0; i < n; i++ { + tokens[i] = Token{ID: int32(i + 1), Text: " \t\n"} + } + return tokens +} + +// buildDecodeTokensUnicodeWhitespace mints n Tokens whose Text is +// a non-breaking-space character (U+00A0, multi-byte UTF-8). Forces +// hasNonSpace into the core.Trim fallback on every token — the only +// reliable way to see that path's cost in isolation. +func buildDecodeTokensUnicodeWhitespace(n int) []Token { + tokens := make([]Token, n) + for i := 0; i < n; i++ { + tokens[i] = Token{ID: int32(i + 1), Text: "  "} + } + return tokens +} + +// buildDecodeTokensVariableLength mints n Tokens whose Text varies +// in length (1, 4, 16, 64 bytes cycled). Real token streams vary +// by ~2 orders of magnitude — bench against that, not against the +// constant-3-byte happy path. +func buildDecodeTokensVariableLength(n int) []Token { + lengths := []int{1, 4, 16, 64} + tokens := make([]Token, n) + for i := 0; i < n; i++ { + size := lengths[i%len(lengths)] + buf := make([]byte, size) + for j := 0; j < size; j++ { + buf[j] = byte('a' + (i % 26)) + } + tokens[i] = Token{ID: int32(i + 1), Text: string(buf)} + } + return tokens +} + +// --- TokensText over mixed / Value-only / whitespace / Unicode --- + +func BenchmarkDecode_TokensTextDeep_MixedTextValue_256(b *testing.B) { + tokens := buildDecodeTokensMixedTextValue(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkText = TokensText(tokens) + } +} + +func BenchmarkDecode_TokensTextDeep_MixedTextValue_2048(b *testing.B) { + tokens := buildDecodeTokensMixedTextValue(2048) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkText = TokensText(tokens) + } +} + +func BenchmarkDecode_TokensTextDeep_AllValueOnly_256(b *testing.B) { + tokens := buildDecodeTokensAllValueOnly(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkText = TokensText(tokens) + } +} + +func BenchmarkDecode_TokensTextDeep_VariableLength_256(b *testing.B) { + tokens := buildDecodeTokensVariableLength(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkText = TokensText(tokens) + } +} + +// --- TokenEqual surface-form edges --- + +// BothValueOnlyEqual — tokens carry only Value, the same Value; +// TokenEqual must agree but takes the Value-side branch. +func BenchmarkDecode_TokensTextDeep_TokenEqual_BothValueOnly(b *testing.B) { + a := Token{ID: 1, Value: "abcdef"} + c := Token{ID: 1, Value: "abcdef"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkBool = TokenEqual(a, c) + } +} + +// TextMismatch — IDs agree but Text strings differ. Forces the full +// string compare to reach the not-equal verdict. The existing benches +// cover the equal and ID-mismatch cases; this is the +// always-runs-the-compare path. +func BenchmarkDecode_TokensTextDeep_TokenEqual_TextMismatch(b *testing.B) { + a := Token{ID: 1, Text: "abcdef"} + c := Token{ID: 1, Text: "abcxyz"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkBool = TokenEqual(a, c) + } +} + +// LongTextEqual — typical chat token is ~3 bytes, but punctuation +// runs and code-block tokens can hit 32+. Tests the strcmp path +// at a length closer to worst-case. +func BenchmarkDecode_TokensTextDeep_TokenEqual_LongTextEqual(b *testing.B) { + a := Token{ID: 1, Text: "abcdefghijklmnopqrstuvwxyz0123456"} + c := Token{ID: 1, Text: "abcdefghijklmnopqrstuvwxyz0123456"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkBool = TokenEqual(a, c) + } +} + +// WhitespaceOnlyTextSkipsCompare — text is whitespace-only on +// both sides; tokenSurface treats them as "empty" via hasNonSpace +// and the compare short-circuits to true. The skip-compare branch +// at non-empty-but-meaningless input. +func BenchmarkDecode_TokensTextDeep_TokenEqual_WhitespaceOnlyTextSkipsCompare(b *testing.B) { + a := Token{ID: 1, Text: " \t\n"} + c := Token{ID: 1, Text: "\r\n "} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkBool = TokenEqual(a, c) + } +} + +// UnicodeWhitespaceSkipsCompare — multi-byte whitespace forces the +// hasNonSpace core.Trim fallback; tokenSurface still resolves to +// "empty" and the compare short-circuits. Validates the slow path +// reaches the same answer as the fast path. +func BenchmarkDecode_TokensTextDeep_TokenEqual_UnicodeWhitespaceSkipsCompare(b *testing.B) { + a := Token{ID: 1, Text: "  "} + c := Token{ID: 1, Text: " "} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + decodeSinkBool = TokenEqual(a, c) + } +} diff --git a/go/discover.go b/go/discover.go index 87dc2b2..796550b 100644 --- a/go/discover.go +++ b/go/discover.go @@ -3,7 +3,6 @@ package inference import ( "cmp" "iter" - "reflect" "slices" core "dappco.re/go" @@ -13,11 +12,14 @@ import ( // fmt.Printf("%s arch=%s quant=%dbit\n", m.Path, m.ModelType, m.QuantBits) // } type DiscoveredModel struct { - Path string // Absolute path to the model directory - ModelType string // Architecture from config.json (e.g. "gemma3", "qwen3", "llama") - QuantBits int // Quantisation bits (0 if unquantised) - QuantGroup int // Quantisation group size - NumFiles int // Number of safetensors weight files + Path string // Absolute path to the model directory or GGUF file + ModelType string // Architecture from config.json/GGUF metadata + QuantBits int // Quantisation bits (0 if unquantised or unknown) + QuantGroup int // Quantisation group size + QuantType string // Quantisation type, when known (e.g. q4_k_m, q8_0) + QuantFamily string // Quantisation family, when known (e.g. q4, q8) + NumFiles int // Number of weight files + Format string // safetensors or gguf when known } // A valid directory has config.json + at least one .safetensors file. @@ -38,17 +40,24 @@ func Discover(baseDir string) iter.Seq[DiscoveredModel] { } func discoverDir(fsys *core.Fs, dir string, yield func(DiscoveredModel) bool) bool { - if m, ok := probeModelDir(fsys, dir); ok { + // Single readDir per directory — the entries feed both + // probeModelDir's safetensors count AND the recursion. Previously + // each directory was listed THREE times (probe → countSafetensors + // → discoverDir's own readDir), with each listing also paying + // reflect-based conversion. Now once, no reflect. + entries, ok := readDir(fsys, dir) + if !ok { + // We can still try to probe the directory even if listing + // fails — config.json read may succeed independently. + entries = nil + } + + if m, ok := probeModelDir(fsys, dir, entries); ok { if !yield(m) { return false } } - entries, ok := readDir(fsys, dir) - if !ok { - return true - } - for _, entry := range entries { if !entry.IsDir() { continue @@ -61,21 +70,42 @@ func discoverDir(fsys *core.Fs, dir string, yield func(DiscoveredModel) bool) bo return true } -// Accepts directories that contain config.json and at least one .safetensors file. -func probeModelDir(fsys *core.Fs, dir string) (DiscoveredModel, bool) { - config := fsys.Read(joinPath(dir, "config.json")) - if !config.OK { +// Accepts directories that contain config.json and at least one +// .safetensors file. `entries` is the pre-read directory listing — +// avoids the second readDir that countSafetensors used to do. +// +// Order matters: single pass over entries first to count safetensors +// AND verify config.json exists. Only then read config.json. This +// short-circuits the wasted disk Read for junk directories that have +// neither — see Discover_NoModels_TenJunkDirs which used to pay one +// fsys.Read per dir before this gate. +func probeModelDir(fsys *core.Fs, dir string, entries []core.FsDirEntry) (DiscoveredModel, bool) { + numFiles := 0 + hasConfig := false + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if name == "config.json" { + hasConfig = true + } else if core.HasSuffix(name, ".safetensors") { + numFiles++ + } + } + if numFiles == 0 || !hasConfig { return DiscoveredModel{}, false } - numFiles, ok := countSafetensors(fsys, dir) - if !ok || numFiles == 0 { + config := fsys.Read(joinPath(dir, "config.json")) + if !config.OK { return DiscoveredModel{}, false } model := DiscoveredModel{ Path: absolutePath(dir), NumFiles: numFiles, + Format: "safetensors", } var probe struct { @@ -103,61 +133,26 @@ func probeModelDir(fsys *core.Fs, dir string) (DiscoveredModel, bool) { return model, true } -type dirEntry interface { - Name() string - IsDir() bool -} - -func readDir(fsys *core.Fs, dir string) ([]dirEntry, bool) { +// readDir returns the directory's entries sorted by name. The result +// is the raw []core.FsDirEntry from core.Fs.List — no reflect, no +// adapter allocation. +func readDir(fsys *core.Fs, dir string) ([]core.FsDirEntry, bool) { result := fsys.List(dir) if !result.OK { return nil, false } - entries, ok := dirEntries(result.Value) + entries, ok := result.Value.([]core.FsDirEntry) if !ok { return nil, false } - slices.SortFunc(entries, func(a, b dirEntry) int { + slices.SortFunc(entries, func(a, b core.FsDirEntry) int { return cmp.Compare(a.Name(), b.Name()) }) return entries, true } -func dirEntries(value any) ([]dirEntry, bool) { - // core.Fs.List returns standard directory entries; adapt them locally. - slice := reflect.ValueOf(value) - if !slice.IsValid() || slice.Kind() != reflect.Slice { - return nil, false - } - - entries := make([]dirEntry, 0, slice.Len()) - for i := range slice.Len() { - entry, ok := slice.Index(i).Interface().(dirEntry) - if !ok { - return nil, false - } - entries = append(entries, entry) - } - return entries, true -} - -func countSafetensors(fsys *core.Fs, dir string) (int, bool) { - entries, ok := readDir(fsys, dir) - if !ok { - return 0, false - } - - count := 0 - for _, entry := range entries { - if !entry.IsDir() && core.HasSuffix(entry.Name(), ".safetensors") { - count++ - } - } - return count, true -} - func absolutePath(dir string) string { if core.PathIsAbs(dir) { return cleanPath(dir) diff --git a/go/discover_bench_test.go b/go/discover_bench_test.go new file mode 100644 index 0000000..cfce7aa --- /dev/null +++ b/go/discover_bench_test.go @@ -0,0 +1,161 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the model-directory discovery walk + path helpers. +// Per AX-11 — Discover walks every subdirectory of the user's model +// root, parses config.json for each candidate, and counts .safetensors +// shards. With dozens of fine-tunes per root the per-directory cost +// compounds. joinPath / cleanPath / absolutePath sit in the per-walk +// hot loop. +// +// Run: go test -bench='BenchmarkDiscover' -benchmem -run='^$' . + +package inference + +import ( + "slices" + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. Distinct names from other bench files. +var ( + discoverBenchSinkModels []DiscoveredModel + discoverBenchSinkPath string + discoverBenchSinkCount int +) + +// makeBenchModelDir is a file-scope helper so the bench fixture build +// stays out of the timed loop. Same shape as createModelDir in the test +// suite but with no t.Helper bookkeeping. +func makeBenchModelDir(b *testing.B, dir string, config map[string]any, shards int) { + b.Helper() + if r := core.MkdirAll(dir, 0o755); !r.OK { + b.Fatal(r.Value) + } + if config != nil { + data := []byte(core.JSONMarshalString(config)) + if r := core.WriteFile(core.JoinPath(dir, "config.json"), data, 0o644); !r.OK { + b.Fatal(r.Value) + } + } + for i := 0; i < shards; i++ { + name := core.Sprintf("model-%05d-of-%05d.safetensors", i+1, shards) + if r := core.WriteFile(core.JoinPath(dir, name), []byte("weights"), 0o644); !r.OK { + b.Fatal(r.Value) + } + } +} + +// --- Discover end-to-end (per-call walk floor) --- + +func BenchmarkDiscover_SingleModel_TwoShards(b *testing.B) { + base := b.TempDir() + makeBenchModelDir(b, core.JoinPath(base, "qwen3-4b"), map[string]any{ + "model_type": "qwen3", + "quantization": map[string]any{ + "bits": 4, + "group_size": 64, + }, + }, 2) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + discoverBenchSinkModels = slices.Collect(Discover(base)) + } +} + +// Three sibling models — the common "models/" layout where a user has a +// handful of checkpoints under one root. +func BenchmarkDiscover_ThreeSiblings(b *testing.B) { + base := b.TempDir() + makeBenchModelDir(b, core.JoinPath(base, "gemma3-1b"), map[string]any{"model_type": "gemma3"}, 1) + makeBenchModelDir(b, core.JoinPath(base, "qwen3-4b"), map[string]any{"model_type": "qwen3"}, 4) + makeBenchModelDir(b, core.JoinPath(base, "llama3-8b"), map[string]any{"model_type": "llama"}, 4) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + discoverBenchSinkModels = slices.Collect(Discover(base)) + } +} + +// Nested directory tree — exercises the recursive descent path. +func BenchmarkDiscover_NestedTree(b *testing.B) { + base := b.TempDir() + makeBenchModelDir(b, core.JoinPath(base, "base"), map[string]any{"model_type": "base"}, 1) + makeBenchModelDir(b, core.JoinPath(base, "base", "ft-a"), map[string]any{"model_type": "ft-a"}, 1) + makeBenchModelDir(b, core.JoinPath(base, "base", "ft-b"), map[string]any{"model_type": "ft-b"}, 1) + makeBenchModelDir(b, core.JoinPath(base, "base", "ft-b", "v2"), map[string]any{"model_type": "ft-b-v2"}, 1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + discoverBenchSinkModels = slices.Collect(Discover(base)) + } +} + +// Miss path — no config.json anywhere, just non-model files. Discover +// must still stat every entry. +func BenchmarkDiscover_NoModels_TenJunkDirs(b *testing.B) { + base := b.TempDir() + for i := 0; i < 10; i++ { + dir := core.JoinPath(base, core.Sprintf("junk-%d", i)) + if r := core.MkdirAll(dir, 0o755); !r.OK { + b.Fatal(r.Value) + } + if r := core.WriteFile(core.JoinPath(dir, "README.md"), []byte("not a model"), 0o644); !r.OK { + b.Fatal(r.Value) + } + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + discoverBenchSinkModels = slices.Collect(Discover(base)) + } +} + +// Early-exit path — caller takes the first match. Proxy for the common +// "pick by architecture" pattern in interactive UIs. +func BenchmarkDiscover_EarlyBreak_TwoSiblings(b *testing.B) { + base := b.TempDir() + makeBenchModelDir(b, core.JoinPath(base, "model-a"), map[string]any{"model_type": "a"}, 1) + makeBenchModelDir(b, core.JoinPath(base, "model-b"), map[string]any{"model_type": "b"}, 1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range Discover(base) { + count++ + break + } + discoverBenchSinkCount = count + } +} + +// --- Path helpers used in the inner walk loop --- + +func BenchmarkDiscover_JoinPath_ThreeParts(b *testing.B) { + a, c, d := "/models", "qwen3-4b", "config.json" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + discoverBenchSinkPath = joinPath(a, c, d) + } +} + +func BenchmarkDiscover_AbsolutePath_AlreadyAbsolute(b *testing.B) { + in := "/Volumes/Data/models/qwen3-4b" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + discoverBenchSinkPath = absolutePath(in) + } +} + +func BenchmarkDiscover_AbsolutePath_Relative(b *testing.B) { + in := "models/qwen3-4b" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + discoverBenchSinkPath = absolutePath(in) + } +} diff --git a/go/eval/eval.go b/go/eval/eval.go new file mode 100644 index 0000000..cafbcb4 --- /dev/null +++ b/go/eval/eval.go @@ -0,0 +1,403 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package eval provides dataset-native perplexity + small quality probes +// for any inference driver (go-mlx, go-rocm, go-cuda, etc.). +// +// It is decoupled from driver concrete types: Sample, Batch, and +// BatchConfig are opaque (any), Dataset is an interface, and the +// runner adapter provides callbacks for the few fields eval needs to +// inspect (BatchTokens, SampleText). Driver wrappers convert their +// native types into an eval.Runner. +package eval + +import ( + "context" + "math" + "strconv" + "time" + + core "dappco.re/go" +) + +const ReportVersion = 1 + +// Sample is one dataset row. Opaque to eval; the runner provides +// SampleText for quality probes that need to read the text body. +type Sample = any + +// Batch is one tokenised batch. Opaque to eval; the runner evaluates +// it and may provide BatchTokens for token-count fallback. +type Batch = any + +// BatchConfig is the dataset batching configuration. Opaque to eval — +// passed through to the runner's BuildBatches. +type BatchConfig = any + +// Dataset is an iterator over Samples. +// +// for { +// sample, ok, err := ds.Next() +// if !ok || err != nil { break } +// } +type Dataset interface { + Next() (Sample, bool, error) +} + +// AdapterInfo identifies a LoRA adapter participating in the eval run. +// Defined here (rather than imported from a driver's lora package) so +// eval stays driver-neutral. +type AdapterInfo struct { + Name string `json:"name,omitempty"` + Path string `json:"path,omitempty"` + Hash string `json:"hash,omitempty"` + Rank int `json:"rank,omitempty"` + Alpha float32 `json:"alpha,omitempty"` + Scale float32 `json:"scale,omitempty"` + TargetKeys []string `json:"target_keys,omitempty"` +} + +// IsEmpty reports whether the adapter info has no meaningful fields set. +func (info AdapterInfo) IsEmpty() bool { + return info.Name == "" && info.Path == "" && info.Hash == "" && info.Rank == 0 && info.Alpha == 0 && info.Scale == 0 && len(info.TargetKeys) == 0 +} + +// Info mirrors a driver's model info — flat fields that travel through +// reports for downstream consumers. +type Info struct { + Architecture string `json:"architecture,omitempty"` + VocabSize int `json:"vocab_size,omitempty"` + NumLayers int `json:"num_layers,omitempty"` + HiddenSize int `json:"hidden_size,omitempty"` + QuantBits int `json:"quant_bits,omitempty"` + QuantGroup int `json:"quant_group,omitempty"` + ContextLength int `json:"context_length,omitempty"` + Adapter AdapterInfo `json:"adapter,omitempty"` +} + +// Config controls dataset-native perplexity and small quality probes. +type Config struct { + Batch BatchConfig `json:"batch"` + AdapterPath string `json:"adapter_path,omitempty"` + MaxSamples int `json:"max_samples,omitempty"` + QualityProbes []QualityProbe `json:"-"` +} + +// Runner supplies the model operations needed for dataset evaluation. +// BuildBatches and EvaluateBatch are required; the rest are optional. +type Runner struct { + Info func(context.Context) Info + LoadAdapter func(context.Context, string) (AdapterInfo, error) + BuildBatches func(context.Context, Dataset, BatchConfig) ([]Batch, error) + EvaluateBatch func(context.Context, Batch) (BatchMetrics, error) + // BatchTokens is a fallback for BatchMetrics.Tokens when the runner + // reports zero. Returns the loss-eligible token count. + BatchTokens func(Batch) int + // SampleText extracts the human-readable text body from a Sample for + // quality probes that need to inspect it. + SampleText func(Sample) (text, response string) +} + +// BatchMetrics is the loss result for one tokenized batch. +type BatchMetrics struct { + Samples int `json:"samples,omitempty"` + Tokens int `json:"tokens,omitempty"` + Loss float64 `json:"loss,omitempty"` +} + +// Metrics aggregates loss and perplexity over a dataset stream. +type Metrics struct { + Samples int `json:"samples,omitempty"` + Batches int `json:"batches,omitempty"` + Tokens int `json:"tokens,omitempty"` + Loss float64 `json:"loss,omitempty"` + Perplexity float64 `json:"perplexity,omitempty"` +} + +// Report is a JSON-friendly native eval result. +type Report struct { + Version int `json:"version"` + ModelInfo Info `json:"model_info"` + Adapter AdapterInfo `json:"adapter,omitempty"` + Config Config `json:"config"` + Metrics Metrics `json:"metrics"` + Quality QualityReport `json:"quality"` + Duration time.Duration `json:"duration,omitempty"` +} + +// QualityProbe adds a custom deterministic quality check. +type QualityProbe struct { + Name string `json:"name"` + Check func(QualityContext) QualityCheck `json:"-"` +} + +// QualityContext is passed to custom eval probes. +type QualityContext struct { + Config Config + Samples []Sample + Metrics Metrics + ModelInfo Info + Adapter AdapterInfo + // SampleText is the runner's accessor for reading text/response from + // an opaque Sample. Probes that introspect sample content go through + // this rather than type-asserting. + SampleText func(Sample) (text, response string) +} + +// QualityReport contains small deterministic checks over eval data + metrics. +type QualityReport struct { + Checks []QualityCheck `json:"checks,omitempty"` +} + +// QualityCheck is one quality probe result. +type QualityCheck struct { + Name string `json:"name"` + Pass bool `json:"pass"` + Score float64 `json:"score"` + Detail string `json:"detail,omitempty"` +} + +// RunDataset evaluates perplexity and quality probes over a dataset stream. +// +// report, err := eval.RunDataset(ctx, runner, dataset, cfg) +func RunDataset(ctx context.Context, runner Runner, dataset Dataset, cfg Config) (*Report, error) { + if ctx == nil { + ctx = context.Background() + } + if runner.EvaluateBatch == nil { + return nil, core.NewError("mlx: eval runner requires EvaluateBatch") + } + if runner.BuildBatches == nil { + return nil, core.NewError("mlx: eval runner requires BuildBatches") + } + if dataset == nil { + return nil, core.NewError("mlx: eval dataset is nil") + } + + start := time.Now() + samples, err := collectSamples(ctx, dataset, cfg.MaxSamples) + if err != nil { + return nil, err + } + if len(samples) == 0 { + return nil, core.NewError("mlx: eval dataset produced no samples") + } + + report := &Report{ + Version: ReportVersion, + Config: cfg, + } + if runner.Info != nil { + report.ModelInfo = runner.Info(ctx) + report.Adapter = report.ModelInfo.Adapter + } + if cfg.AdapterPath != "" { + if runner.LoadAdapter == nil { + return nil, core.NewError("mlx: eval runner does not support LoRA adapter loading") + } + adapter, err := runner.LoadAdapter(ctx, cfg.AdapterPath) + if err != nil { + return nil, err + } + report.Adapter = adapter + if runner.Info != nil { + report.ModelInfo = runner.Info(ctx) + } + if report.ModelInfo.Adapter.IsEmpty() { + report.ModelInfo.Adapter = adapter + } + } + if report.Adapter.IsEmpty() { + report.Adapter = report.ModelInfo.Adapter + } + + batches, err := runner.BuildBatches(ctx, newSliceDataset(samples), cfg.Batch) + if err != nil { + return nil, err + } + if len(batches) == 0 { + return nil, core.NewError("mlx: eval dataset produced no tokenized batches") + } + + metrics, err := evaluateBatches(ctx, runner, batches, len(samples)) + if err != nil { + return nil, err + } + report.Metrics = metrics + report.Duration = nonZeroDuration(time.Since(start)) + report.Quality = runQualityProbes(QualityContext{ + Config: cfg, + Samples: samples, + Metrics: metrics, + ModelInfo: report.ModelInfo, + Adapter: report.Adapter, + SampleText: runner.SampleText, + }) + return report, nil +} + +func collectSamples(ctx context.Context, dataset Dataset, maxSamples int) ([]Sample, error) { + // Pre-allocate when maxSamples is known — saves the + // log2(maxSamples) doubling grows that append would otherwise pay. + // For the 0-hint case (unknown dataset size), let append handle + // growth as before. + var samples []Sample + if maxSamples > 0 { + samples = make([]Sample, 0, maxSamples) + } + for { + if err := ctx.Err(); err != nil { + return nil, err + } + if maxSamples > 0 && len(samples) >= maxSamples { + break + } + sample, ok, err := dataset.Next() + if err != nil { + return nil, err + } + if !ok { + break + } + samples = append(samples, sample) + } + return samples, nil +} + +type sliceDataset struct { + samples []Sample + idx int +} + +func newSliceDataset(samples []Sample) Dataset { + return &sliceDataset{samples: samples} +} + +func (d *sliceDataset) Next() (Sample, bool, error) { + if d.idx >= len(d.samples) { + return nil, false, nil + } + sample := d.samples[d.idx] + d.idx++ + return sample, true, nil +} + +func evaluateBatches(ctx context.Context, runner Runner, batches []Batch, samples int) (Metrics, error) { + metrics := Metrics{Samples: samples, Batches: len(batches)} + var weightedLoss float64 + for _, batch := range batches { + if err := ctx.Err(); err != nil { + return Metrics{}, err + } + batchMetrics, err := runner.EvaluateBatch(ctx, batch) + if err != nil { + return Metrics{}, err + } + if batchMetrics.Tokens <= 0 && runner.BatchTokens != nil { + batchMetrics.Tokens = runner.BatchTokens(batch) + } + if batchMetrics.Tokens <= 0 { + continue + } + if math.IsNaN(batchMetrics.Loss) || math.IsInf(batchMetrics.Loss, 0) { + return Metrics{}, core.NewError("mlx: eval batch loss is not finite") + } + metrics.Tokens += batchMetrics.Tokens + weightedLoss += batchMetrics.Loss * float64(batchMetrics.Tokens) + } + if metrics.Tokens == 0 { + return Metrics{}, core.NewError("mlx: eval produced no loss tokens") + } + metrics.Loss = weightedLoss / float64(metrics.Tokens) + metrics.Perplexity = math.Exp(metrics.Loss) + return metrics, nil +} + +func runQualityProbes(ctx QualityContext) QualityReport { + checks := defaultQualityChecks(ctx) + for _, probe := range ctx.Config.QualityProbes { + check := QualityCheck{Name: probe.Name} + if probe.Check == nil { + check.Pass = false + check.Detail = "probe has no check function" + } else { + check = probe.Check(ctx) + if check.Name == "" { + check.Name = probe.Name + } + } + checks = append(checks, check) + } + return QualityReport{Checks: checks} +} + +func defaultQualityChecks(ctx QualityContext) []QualityCheck { + samples := len(ctx.Samples) + lossFinite := !math.IsNaN(ctx.Metrics.Loss) && !math.IsInf(ctx.Metrics.Loss, 0) && ctx.Metrics.Loss >= 0 + pplFinite := !math.IsNaN(ctx.Metrics.Perplexity) && !math.IsInf(ctx.Metrics.Perplexity, 0) && ctx.Metrics.Perplexity >= 1 + // strconv.Itoa / FormatFloat skip the fmt formatter pipeline that + // core.Sprintf would walk for every Detail string. Each Sprintf + // was 1-2 allocs; FormatX returns a single fresh string. + return []QualityCheck{ + {Name: "samples_present", Pass: samples > 0, Score: boolScore(samples > 0), Detail: strconv.Itoa(samples)}, + {Name: "token_coverage", Pass: ctx.Metrics.Tokens > 0, Score: boolScore(ctx.Metrics.Tokens > 0), Detail: strconv.Itoa(ctx.Metrics.Tokens)}, + {Name: "loss_finite", Pass: lossFinite, Score: boolScore(lossFinite), Detail: strconv.FormatFloat(ctx.Metrics.Loss, 'f', 6, 64)}, + {Name: "perplexity_finite", Pass: pplFinite, Score: boolScore(pplFinite), Detail: strconv.FormatFloat(ctx.Metrics.Perplexity, 'f', 6, 64)}, + } +} + +// ResponseCoverageProbe is a quality probe that counts samples with +// non-empty Text or Response. Driver wrappers attach this probe so +// eval doesn't need to know about the driver's sample field shape. +// +// cfg.QualityProbes = append(cfg.QualityProbes, eval.ResponseCoverageProbe()) +func ResponseCoverageProbe() QualityProbe { + return QualityProbe{ + Name: "response_coverage", + Check: func(ctx QualityContext) QualityCheck { + if ctx.SampleText == nil { + return QualityCheck{Name: "response_coverage", Pass: false, Detail: "no SampleText accessor"} + } + samples := len(ctx.Samples) + responseLike := 0 + for _, sample := range ctx.Samples { + text, response := ctx.SampleText(sample) + if core.Trim(text) != "" || core.Trim(response) != "" { + responseLike++ + } + } + // Hand-build the "%d/%d" Detail without Sprintf — 1 alloc + // vs Sprintf's 2-3 (formatter scratch + result). + detail := make([]byte, 0, 16) + detail = strconv.AppendInt(detail, int64(responseLike), 10) + detail = append(detail, '/') + detail = strconv.AppendInt(detail, int64(samples), 10) + return QualityCheck{ + Name: "response_coverage", + Pass: responseLike == samples, + Score: fractionScore(responseLike, samples), + Detail: core.AsString(detail), + } + }, + } +} + +func boolScore(ok bool) float64 { + if ok { + return 1 + } + return 0 +} + +func fractionScore(numerator, denominator int) float64 { + if denominator <= 0 { + return 0 + } + return float64(numerator) / float64(denominator) +} + +func nonZeroDuration(d time.Duration) time.Duration { + if d <= 0 { + return time.Nanosecond + } + return d +} diff --git a/go/eval/eval_bench_test.go b/go/eval/eval_bench_test.go new file mode 100644 index 0000000..6168f97 --- /dev/null +++ b/go/eval/eval_bench_test.go @@ -0,0 +1,382 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the driver-neutral dataset-eval harness — RunDataset +// over a synthetic Runner, the sample-collector hot loop, the batch +// reducer, quality-probe runners, and the AdapterInfo emptiness check. +// +// Per AX-11 — RunDataset fires once per eval invocation, but +// collectSamples + evaluateBatches walk every sample/batch the dataset +// emits, and runQualityProbes runs every check after every eval. The +// `quick_eval` lane in lthn/LEM-Eval uses ~200 samples per probe. +// +// Run: go test -bench='BenchmarkEval' -benchmem -run='^$' ./go/eval + +package eval + +import ( + "context" + "testing" + "time" +) + +// Sinks defeat compiler DCE. +var ( + evalSinkReport *Report + evalSinkErr error + evalSinkSamples []Sample + evalSinkMetrics Metrics + evalSinkQuality QualityReport + evalSinkBool bool + evalSinkDur time.Duration + evalSinkBatchTok int + evalSinkQualScore float64 + evalSinkBoolScore float64 + evalSinkFracScore float64 + evalSinkSampleText string +) + +// evalSampleShape is the synthetic Sample type the benches feed through +// eval — eval treats Sample as opaque (any), so the shape only needs +// to be readable by the runner's SampleText callback. +type evalSampleShape struct { + Text string + Response string +} + +// evalBatchShape is the synthetic Batch type. eval treats Batch as +// opaque (any); the runner's EvaluateBatch + BatchTokens callbacks +// extract loss + token count. +type evalBatchShape struct { + Tokens int + Loss float64 +} + +// buildEvalSamples mints n samples shaped like the LEM-Eval rows +// (text body + response). Each carries a non-empty text/response so +// response_coverage doesn't short-circuit. +func buildEvalSamples(n int) []evalSampleShape { + samples := make([]evalSampleShape, n) + for i := 0; i < n; i++ { + samples[i] = evalSampleShape{ + Text: "What is the capital of Lethean?", + Response: "The capital is in the network.", + } + } + return samples +} + +// evalSampleIter wraps a slice in the Dataset interface. +type evalSampleIter struct { + samples []evalSampleShape + idx int +} + +func (it *evalSampleIter) Next() (Sample, bool, error) { + if it.idx >= len(it.samples) { + return nil, false, nil + } + s := it.samples[it.idx] + it.idx++ + return s, true, nil +} + +// evalRunner returns a Runner whose callbacks emit deterministic +// per-sample metrics. Used by every RunDataset bench below. +func evalRunner(samples []evalSampleShape) Runner { + return Runner{ + Info: func(context.Context) Info { + return Info{Architecture: "qwen3", ContextLength: 4096} + }, + BuildBatches: func(_ context.Context, ds Dataset, _ BatchConfig) ([]Batch, error) { + var batches []Batch + for { + s, ok, err := ds.Next() + if err != nil { + return nil, err + } + if !ok { + break + } + _ = s + batches = append(batches, evalBatchShape{Tokens: 8, Loss: 1.5}) + } + return batches, nil + }, + EvaluateBatch: func(_ context.Context, batch Batch) (BatchMetrics, error) { + eb := batch.(evalBatchShape) + return BatchMetrics{Samples: 1, Tokens: eb.Tokens, Loss: eb.Loss}, nil + }, + BatchTokens: func(batch Batch) int { + return batch.(evalBatchShape).Tokens + }, + SampleText: func(sample Sample) (string, string) { + s := sample.(evalSampleShape) + return s.Text, s.Response + }, + } +} + +// --- RunDataset end-to-end at 10 / 100 question scales --- + +func BenchmarkEval_RunDataset_10Samples(b *testing.B) { + cfg := Config{} + ctx := context.Background() + source := buildEvalSamples(10) + runner := evalRunner(source) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkReport, evalSinkErr = RunDataset(ctx, runner, &evalSampleIter{samples: source}, cfg) + } +} + +func BenchmarkEval_RunDataset_100Samples(b *testing.B) { + cfg := Config{} + ctx := context.Background() + source := buildEvalSamples(100) + runner := evalRunner(source) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkReport, evalSinkErr = RunDataset(ctx, runner, &evalSampleIter{samples: source}, cfg) + } +} + +// MaxSamples short-circuits collectSamples — exercises the limited +// path that quick_eval lanes use. +func BenchmarkEval_RunDataset_100Samples_MaxSamples50(b *testing.B) { + cfg := Config{MaxSamples: 50} + ctx := context.Background() + source := buildEvalSamples(100) + runner := evalRunner(source) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkReport, evalSinkErr = RunDataset(ctx, runner, &evalSampleIter{samples: source}, cfg) + } +} + +// RunDataset with a custom QualityProbe attached — measures the cost +// of running per-sample text inspection (the ResponseCoverageProbe +// path drivers wire up by default). +func BenchmarkEval_RunDataset_100Samples_WithProbe(b *testing.B) { + cfg := Config{QualityProbes: []QualityProbe{ResponseCoverageProbe()}} + ctx := context.Background() + source := buildEvalSamples(100) + runner := evalRunner(source) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkReport, evalSinkErr = RunDataset(ctx, runner, &evalSampleIter{samples: source}, cfg) + } +} + +// --- collectSamples in isolation --- + +func BenchmarkEval_CollectSamples_10(b *testing.B) { + ctx := context.Background() + source := buildEvalSamples(10) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkSamples, evalSinkErr = collectSamples(ctx, &evalSampleIter{samples: source}, 0) + } +} + +func BenchmarkEval_CollectSamples_100(b *testing.B) { + ctx := context.Background() + source := buildEvalSamples(100) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkSamples, evalSinkErr = collectSamples(ctx, &evalSampleIter{samples: source}, 0) + } +} + +func BenchmarkEval_CollectSamples_100_Cap50(b *testing.B) { + ctx := context.Background() + source := buildEvalSamples(100) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkSamples, evalSinkErr = collectSamples(ctx, &evalSampleIter{samples: source}, 50) + } +} + +// --- evaluateBatches in isolation --- + +func BenchmarkEval_EvaluateBatches_10(b *testing.B) { + source := buildEvalSamples(10) + runner := evalRunner(source) + batches, err := runner.BuildBatches(context.Background(), &evalSampleIter{samples: source}, nil) + if err != nil { + b.Fatal(err) + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkMetrics, evalSinkErr = evaluateBatches(ctx, runner, batches, len(source)) + } +} + +func BenchmarkEval_EvaluateBatches_100(b *testing.B) { + source := buildEvalSamples(100) + runner := evalRunner(source) + batches, err := runner.BuildBatches(context.Background(), &evalSampleIter{samples: source}, nil) + if err != nil { + b.Fatal(err) + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkMetrics, evalSinkErr = evaluateBatches(ctx, runner, batches, len(source)) + } +} + +// --- defaultQualityChecks + runQualityProbes (per-eval probe surface) --- + +func BenchmarkEval_DefaultQualityChecks(b *testing.B) { + source := buildEvalSamples(10) + samples := make([]Sample, len(source)) + for i, s := range source { + samples[i] = s + } + qc := QualityContext{ + Samples: samples, + Metrics: Metrics{Samples: 10, Tokens: 80, Loss: 1.5, Perplexity: 4.48}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = defaultQualityChecks(qc) + } +} + +func BenchmarkEval_RunQualityProbes_NoCustom(b *testing.B) { + source := buildEvalSamples(10) + samples := make([]Sample, len(source)) + for i, s := range source { + samples[i] = s + } + qc := QualityContext{ + Samples: samples, + Metrics: Metrics{Samples: 10, Tokens: 80, Loss: 1.5, Perplexity: 4.48}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkQuality = runQualityProbes(qc) + } +} + +// 100 samples × ResponseCoverageProbe — the body the probe walks per call. +func BenchmarkEval_ResponseCoverageProbe_100Samples(b *testing.B) { + source := buildEvalSamples(100) + samples := make([]Sample, len(source)) + for i, s := range source { + samples[i] = s + } + probe := ResponseCoverageProbe() + qc := QualityContext{ + Samples: samples, + Metrics: Metrics{Samples: 100, Tokens: 800, Loss: 1.5, Perplexity: 4.48}, + SampleText: func(sample Sample) (string, string) { + s := sample.(evalSampleShape) + return s.Text, s.Response + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = probe.Check(qc) + } +} + +// --- AdapterInfo.IsEmpty --- + +func BenchmarkEval_AdapterInfo_IsEmpty_Empty(b *testing.B) { + info := AdapterInfo{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkBool = info.IsEmpty() + } +} + +func BenchmarkEval_AdapterInfo_IsEmpty_Populated(b *testing.B) { + info := AdapterInfo{ + Name: "qwen3-lora", + Path: "/adapters/qwen3.lora", + Hash: "sha256:deadbeef", + Rank: 16, + Alpha: 32, + Scale: 0.5, + TargetKeys: []string{"q_proj", "k_proj", "v_proj", "o_proj"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkBool = info.IsEmpty() + } +} + +// --- Score helpers (called per quality check) --- + +func BenchmarkEval_BoolScore_True(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkBoolScore = boolScore(true) + } +} + +func BenchmarkEval_FractionScore_HalfPopulated(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkFracScore = fractionScore(50, 100) + } +} + +// --- nonZeroDuration --- + +func BenchmarkEval_NonZeroDuration_Positive(b *testing.B) { + d := 45 * time.Millisecond + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkDur = nonZeroDuration(d) + } +} + +func BenchmarkEval_NonZeroDuration_Zero(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + evalSinkDur = nonZeroDuration(0) + } +} + +// --- sliceDataset.Next (the iterator created by RunDataset to feed +// BuildBatches; fires once per sample) --- + +func BenchmarkEval_SliceDataset_Next_100Samples(b *testing.B) { + source := buildEvalSamples(100) + samples := make([]Sample, len(source)) + for i, s := range source { + samples[i] = s + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ds := newSliceDataset(samples) + for { + _, ok, err := ds.Next() + if err != nil || !ok { + break + } + } + } +} diff --git a/go/gguf.go b/go/gguf.go new file mode 100644 index 0000000..00b44a1 --- /dev/null +++ b/go/gguf.go @@ -0,0 +1,386 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "cmp" + "encoding/binary" + "io" + "io/fs" + "slices" + + core "dappco.re/go" +) + +const ( + ggufMagic = 0x46554747 + ggufVersion = 3 + ggufTypeUint32 = 4 + ggufTypeString = 8 +) + +// GGUFInfo summarises GGUF metadata without requiring a concrete runtime. +type GGUFInfo struct { + Path string + Architecture string + VocabSize int + HiddenSize int + NumLayers int + ContextLength int + QuantBits int + QuantGroup int + QuantType string + QuantFamily string + TensorCount int + MetadataCount int + ValidationIssues []GGUFValidationIssue +} + +// Valid reports whether metadata parsing found validation errors. +func (info GGUFInfo) Valid() bool { + for _, issue := range info.ValidationIssues { + if issue.Severity == GGUFValidationError { + return false + } + } + return true +} + +// GGUFValidationSeverity classifies GGUF metadata validation findings. +type GGUFValidationSeverity string + +const ( + GGUFValidationWarning GGUFValidationSeverity = "warning" + GGUFValidationError GGUFValidationSeverity = "error" +) + +// GGUFValidationIssue describes one GGUF metadata validation issue. +type GGUFValidationIssue struct { + Severity GGUFValidationSeverity `json:"severity"` + Code string `json:"code"` + Message string `json:"message"` + Tensor string `json:"tensor,omitempty"` +} + +// ReadGGUFInfo reads GGUF header metadata without loading tensors. +func ReadGGUFInfo(modelPath string) (GGUFInfo, error) { + ggufPath, err := resolveGGUFFile(modelPath) + if err != nil { + return GGUFInfo{}, err + } + metadata, tensorCount, err := parseGGUFMetadata(ggufPath) + if err != nil { + return GGUFInfo{}, err + } + absolutePath := ggufPath + if abs := core.PathAbs(ggufPath); abs.OK { + absolutePath = abs.Value.(string) + } + architecture := metadataString(metadata, "general.architecture") + quantBits, quantGroup, quantType, quantFamily := ggufQuantisationFromMetadata(metadata) + return GGUFInfo{ + Path: absolutePath, + Architecture: architecture, + VocabSize: firstPositiveInt(metadataInt(metadata, architecture+".vocab_size"), metadataInt(metadata, "tokenizer.ggml.tokens")), + HiddenSize: metadataInt(metadata, architecture+".embedding_length"), + NumLayers: metadataInt(metadata, architecture+".block_count"), + ContextLength: metadataInt(metadata, architecture+".context_length"), + QuantBits: quantBits, + QuantGroup: quantGroup, + QuantType: quantType, + QuantFamily: quantFamily, + TensorCount: tensorCount, + MetadataCount: len(metadata), + }, nil +} + +// DiscoverModels returns safetensors and GGUF models beneath basePath. +func DiscoverModels(basePath string) []DiscoveredModel { + resolvedPath := basePath + if abs := core.PathAbs(basePath); abs.OK { + resolvedPath = abs.Value.(string) + } + stat := core.Stat(resolvedPath) + if !stat.OK { + return nil + } + if !stat.Value.(core.FsFileInfo).IsDir() { + if core.HasSuffix(core.Lower(resolvedPath), ".gguf") { + if info, err := ReadGGUFInfo(resolvedPath); err == nil { + return []DiscoveredModel{discoveredModelFromGGUF(info)} + } + } + return nil + } + + models := slices.Collect(Discover(resolvedPath)) + if err := core.PathWalkDir(resolvedPath, func(path string, entry fs.DirEntry, walkErr error) error { + if walkErr != nil || !entry.IsDir() { + return nil + } + ggufs := core.PathGlob(core.PathJoin(path, "*.gguf")) + if len(ggufs) != 1 { + return nil + } + info, err := ReadGGUFInfo(ggufs[0]) + if err != nil { + return nil + } + models = append(models, discoveredModelFromGGUF(info)) + return nil + }); err != nil { + return nil + } + slices.SortFunc(models, func(a, b DiscoveredModel) int { + return cmp.Compare(a.Path, b.Path) + }) + return models +} + +func discoveredModelFromGGUF(info GGUFInfo) DiscoveredModel { + return DiscoveredModel{ + Path: info.Path, + ModelType: info.Architecture, + QuantBits: info.QuantBits, + QuantGroup: info.QuantGroup, + QuantType: info.QuantType, + QuantFamily: info.QuantFamily, + NumFiles: 1, + Format: "gguf", + } +} + +func resolveGGUFFile(modelPath string) (string, error) { + if core.HasSuffix(core.Lower(modelPath), ".gguf") { + return modelPath, nil + } + ggufs := core.PathGlob(core.PathJoin(modelPath, "*.gguf")) + switch len(ggufs) { + case 0: + return "", core.NewError("inference: no .gguf file found") + case 1: + return ggufs[0], nil + default: + return "", core.NewError("inference: multiple .gguf files found") + } +} + +func parseGGUFMetadata(path string) (map[string]any, int, error) { + open := core.Open(path) + if !open.OK { + return nil, 0, core.Errorf("inference: open gguf: %w", open.Value.(error)) + } + file := open.Value.(*core.OSFile) + defer file.Close() + + // Header reads use binary.LittleEndian.UintX on a stack-allocated + // fixed-size buffer instead of binary.Read — binary.Read uses + // reflect and allocates per call (~1 alloc/value); the direct + // LittleEndian path is zero-alloc. The header loop fires once per + // metadata entry, so for a vocab-heavy GGUF that's hundreds of + // avoidable allocs per model load. + var hdr [8]byte + + if _, err := io.ReadFull(file, hdr[:4]); err != nil { + return nil, 0, core.Errorf("inference: read gguf magic: %w", err) + } + if magic := binary.LittleEndian.Uint32(hdr[:4]); magic != ggufMagic { + return nil, 0, core.NewError("inference: invalid gguf magic") + } + if _, err := io.ReadFull(file, hdr[:4]); err != nil { + return nil, 0, core.Errorf("inference: read gguf version: %w", err) + } + if version := binary.LittleEndian.Uint32(hdr[:4]); version != ggufVersion { + return nil, 0, core.Errorf("inference: unsupported gguf version: %d", version) + } + if _, err := io.ReadFull(file, hdr[:8]); err != nil { + return nil, 0, core.Errorf("inference: read gguf tensor count: %w", err) + } + tensorCount := binary.LittleEndian.Uint64(hdr[:8]) + if _, err := io.ReadFull(file, hdr[:8]); err != nil { + return nil, 0, core.Errorf("inference: read gguf metadata count: %w", err) + } + metadataCount := binary.LittleEndian.Uint64(hdr[:8]) + // ReadGGUFInfo queries only seven well-known keys; a vocab-heavy + // header may carry hundreds of unrelated entries (every tokenizer + // config field, every BPE merge marker, etc.). Skipping the value + // reads and map inserts for keys we never query is the dominant + // alloc lift on model load — synthetic vocab-heavy benches go from + // ~600 allocs to a handful. The map is sized to "metadata count" + // only as an upper bound; the actual fill is just the keys we + // actually read. + metadata := make(map[string]any, 8) + var keyScratch []byte + for range metadataCount { + keyView, err := readGGUFKeyView(file, hdr[:8], &keyScratch) + if err != nil { + return nil, 0, err + } + if _, err := io.ReadFull(file, hdr[:4]); err != nil { + return nil, 0, core.Errorf("inference: read gguf metadata type: %w", err) + } + valueType := binary.LittleEndian.Uint32(hdr[:4]) + if !keyOfInterest(keyView) { + if err := skipGGUFValue(file, valueType, hdr[:8]); err != nil { + return nil, 0, err + } + continue + } + // Key needs to outlive the scratch buffer — core.Clone + // detaches the string from its backing memory so the next + // readGGUFKeyView call can reuse the buffer without + // invalidating map keys. + key := core.Clone(keyView) + value, err := readGGUFValue(file, valueType, hdr[:8]) + if err != nil { + return nil, 0, err + } + metadata[key] = value + } + return metadata, int(tensorCount), nil +} + +// keyOfInterest reports whether ReadGGUFInfo queries this metadata key. +// Any other key is parsed past without touching the map — skipping the +// value bytes via Seek and skipping the map insert eliminates two +// allocs per uninteresting entry, which on real GGUF headers dominates +// the metadata loop cost. +func keyOfInterest(key string) bool { + switch key { + case "general.architecture", "general.file_type", "tokenizer.ggml.tokens": + return true + } + return core.HasSuffix(key, ".vocab_size") || + core.HasSuffix(key, ".embedding_length") || + core.HasSuffix(key, ".block_count") || + core.HasSuffix(key, ".context_length") +} + +// readGGUFKeyView reads the next key into a caller-owned reusable +// buffer and returns a zero-copy string view aliasing it. The view is +// valid only until the next readGGUFKeyView call; callers must clone +// before storing the key for use beyond the parse loop body. +func readGGUFKeyView(reader io.Reader, scratch []byte, keyBuf *[]byte) (string, error) { + if _, err := io.ReadFull(reader, scratch[:8]); err != nil { + return "", core.Errorf("inference: read gguf string length: %w", err) + } + length := binary.LittleEndian.Uint64(scratch[:8]) + if uint64(cap(*keyBuf)) < length { + *keyBuf = make([]byte, length) + } else { + *keyBuf = (*keyBuf)[:length] + } + if _, err := io.ReadFull(reader, *keyBuf); err != nil { + return "", core.Errorf("inference: read gguf string: %w", err) + } + return core.AsString(*keyBuf), nil +} + +// skipGGUFValue advances the reader past the value bytes for keys +// ReadGGUFInfo doesn't query. The OS file is an io.Seeker so we skip +// without allocating a byte buffer; if the underlying reader doesn't +// support seeking we fall back to io.CopyN to io.Discard, which +// streams bytes without retaining them. +func skipGGUFValue(reader io.Reader, valueType uint32, scratch []byte) error { + switch valueType { + case ggufTypeString: + if _, err := io.ReadFull(reader, scratch[:8]); err != nil { + return core.Errorf("inference: read gguf string length: %w", err) + } + length := int64(binary.LittleEndian.Uint64(scratch[:8])) + if seeker, ok := reader.(io.Seeker); ok { + if _, err := seeker.Seek(length, io.SeekCurrent); err != nil { + return core.Errorf("inference: seek past gguf string value: %w", err) + } + return nil + } + if _, err := io.CopyN(io.Discard, reader, length); err != nil { + return core.Errorf("inference: discard gguf string value: %w", err) + } + return nil + case ggufTypeUint32: + if _, err := io.ReadFull(reader, scratch[:4]); err != nil { + return core.Errorf("inference: read gguf uint32 metadata: %w", err) + } + return nil + default: + return core.Errorf("inference: unsupported gguf metadata type: %d", valueType) + } +} + +// readGGUFValue + readGGUFString accept a caller-owned scratch buffer +// so the reflect-allocating binary.Read path stays out of the per-entry +// inner loop. Callers pass hdr[:8] from the outer parse loop. +func readGGUFValue(reader io.Reader, valueType uint32, scratch []byte) (any, error) { + switch valueType { + case ggufTypeString: + return readGGUFString(reader, scratch) + case ggufTypeUint32: + if _, err := io.ReadFull(reader, scratch[:4]); err != nil { + return nil, core.Errorf("inference: read gguf uint32 metadata: %w", err) + } + return binary.LittleEndian.Uint32(scratch[:4]), nil + default: + return nil, core.Errorf("inference: unsupported gguf metadata type: %d", valueType) + } +} + +func readGGUFString(reader io.Reader, scratch []byte) (string, error) { + if _, err := io.ReadFull(reader, scratch[:8]); err != nil { + return "", core.Errorf("inference: read gguf string length: %w", err) + } + length := binary.LittleEndian.Uint64(scratch[:8]) + buf := make([]byte, length) + if _, err := io.ReadFull(reader, buf); err != nil { + return "", core.Errorf("inference: read gguf string: %w", err) + } + // buf is freshly-allocated and unreachable after this conversion — + // core.AsString skips the []byte→string copy. A typical GGUF + // metadata pass calls readGGUFString once per key + once per string + // value (architecture, tokenizer.ggml.tokens, etc.); large vocabs + // turn this into hundreds of KB of avoidable copies per load. + return core.AsString(buf), nil +} + +func metadataString(metadata map[string]any, key string) string { + if value, ok := metadata[key].(string); ok { + return value + } + return "" +} + +func metadataInt(metadata map[string]any, key string) int { + switch value := metadata[key].(type) { + case uint32: + return int(value) + case uint64: + return int(value) + default: + return 0 + } +} + +func ggufQuantisationFromMetadata(metadata map[string]any) (bits, group int, quantType, family string) { + fileType := metadataInt(metadata, "general.file_type") + switch fileType { + case 0: + return 32, 0, "f32", "f32" + case 1: + return 16, 0, "f16", "f16" + case 7: + return 8, 32, "q8_0", "q8" + case 15: + return 4, 32, "q4_k_m", "q4" + default: + return 0, 0, "", "" + } +} + +func firstPositiveInt(values ...int) int { + for _, value := range values { + if value > 0 { + return value + } + } + return 0 +} diff --git a/go/gguf_bench_test.go b/go/gguf_bench_test.go new file mode 100644 index 0000000..50e8958 --- /dev/null +++ b/go/gguf_bench_test.go @@ -0,0 +1,139 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the GGUF model-file primitives. +// Per AX-11 — ReadGGUFInfo is called once per model load; the +// metadata loop fires once per metadata entry, of which a typical +// GGUF has hundreds (every tensor name, vocab token, RoPE setting). +// readGGUFString is the per-entry hot loop the consumer pays. +// +// Run: go test -bench='BenchmarkGGUF' -benchmem -run='^$' . + +package inference + +import ( + "bytes" + "encoding/binary" + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. +var ( + ggufSinkInfo GGUFInfo + ggufSinkErr error + ggufSinkStr string +) + +// writeBenchGGUF builds a synthetic GGUF with the requested metadata +// shape — same wire format the production parser reads but built +// in-memory and written to a temp file via core.WriteFile so the +// bench harness can re-parse the same file many times. +func writeBenchGGUF(b *testing.B, metadata map[string]any) string { + b.Helper() + buf := core.NewBuffer() + mustWrite := func(value any) { + if err := binary.Write(buf, binary.LittleEndian, value); err != nil { + b.Fatal(err) + } + } + writeString := func(value string) { + mustWrite(uint64(len(value))) + if _, err := buf.Write([]byte(value)); err != nil { + b.Fatal(err) + } + } + mustWrite(uint32(0x46554747)) // magic + mustWrite(uint32(3)) // version + mustWrite(uint64(0)) // tensor count + mustWrite(uint64(len(metadata))) + for key, value := range metadata { + writeString(key) + switch typed := value.(type) { + case string: + mustWrite(uint32(8)) + writeString(typed) + case uint32: + mustWrite(uint32(4)) + mustWrite(typed) + default: + b.Fatalf("unsupported metadata test value %T", value) + } + } + path := core.JoinPath(b.TempDir(), "model.gguf") + if r := core.WriteFile(path, buf.Bytes(), 0o644); !r.OK { + b.Fatal(r.Value) + } + return path +} + +// --- ReadGGUFInfo end-to-end (per-model load floor) --- + +func BenchmarkGGUF_ReadInfo_Minimal(b *testing.B) { + path := writeBenchGGUF(b, map[string]any{ + "general.architecture": "qwen3", + "general.file_type": uint32(15), + "qwen3.block_count": uint32(28), + "qwen3.context_length": uint32(40960), + "qwen3.embedding_length": uint32(2048), + }) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ggufSinkInfo, ggufSinkErr = ReadGGUFInfo(path) + } +} + +// BenchmarkGGUF_ReadInfo_VocabHeavy approximates a real model header +// — a few architecture fields plus a synthetic burst of metadata +// entries that mirrors the per-entry alloc cost of vocab string +// tables (which can have 256k+ entries on Gemma-class tokenisers). +func BenchmarkGGUF_ReadInfo_VocabHeavy(b *testing.B) { + metadata := map[string]any{ + "general.architecture": "qwen3", + "general.file_type": uint32(15), + "qwen3.block_count": uint32(28), + "qwen3.context_length": uint32(40960), + "qwen3.embedding_length": uint32(2048), + } + // 200 synthetic metadata string entries — proxy for tokeniser + // configuration + vocab marker strings. + for i := 0; i < 200; i++ { + metadata[core.Sprintf("synthetic.meta.%d", i)] = core.Sprintf("value-payload-%d", i) + } + path := writeBenchGGUF(b, metadata) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ggufSinkInfo, ggufSinkErr = ReadGGUFInfo(path) + } +} + +// --- readGGUFString in isolation (per-entry hot loop) --- + +func BenchmarkGGUF_ReadString_Short(b *testing.B) { + payload := []byte("qwen3") + header := make([]byte, 8) + binary.LittleEndian.PutUint64(header, uint64(len(payload))) + frame := append(header, payload...) + scratch := make([]byte, 8) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ggufSinkStr, ggufSinkErr = readGGUFString(bytes.NewReader(frame), scratch) + } +} + +func BenchmarkGGUF_ReadString_Long(b *testing.B) { + // Token strings can be up to a few hundred bytes (BPE merges). + payload := bytes.Repeat([]byte("abcdef"), 64) // 384 bytes + header := make([]byte, 8) + binary.LittleEndian.PutUint64(header, uint64(len(payload))) + frame := append(header, payload...) + scratch := make([]byte, 8) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ggufSinkStr, ggufSinkErr = readGGUFString(bytes.NewReader(frame), scratch) + } +} diff --git a/go/gguf_test.go b/go/gguf_test.go new file mode 100644 index 0000000..8c9c7ae --- /dev/null +++ b/go/gguf_test.go @@ -0,0 +1,88 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "encoding/binary" + "testing" + + core "dappco.re/go" +) + +func TestGGUF_ReadGGUFInfo_Good(t *testing.T) { + path := writeMinimalGGUF(t, map[string]any{ + "general.architecture": "qwen3", + "general.file_type": uint32(15), + "qwen3.block_count": uint32(28), + "qwen3.context_length": uint32(40960), + }) + + info, err := ReadGGUFInfo(path) + + checkNoError(t, err) + checkEqual(t, "qwen3", info.Architecture) + checkEqual(t, 4, info.QuantBits) + checkEqual(t, 28, info.NumLayers) + checkEqual(t, 40960, info.ContextLength) +} + +func TestGGUF_ReadGGUFInfo_Bad(t *testing.T) { + info, err := ReadGGUFInfo(core.JoinPath(t.TempDir(), "missing.gguf")) + + checkError(t, err) + checkEqual(t, GGUFInfo{}, info) +} + +func TestGGUF_DiscoverModels_Ugly(t *testing.T) { + dir := t.TempDir() + path := writeMinimalGGUFAt(t, core.JoinPath(dir, "model.gguf"), map[string]any{ + "general.architecture": "gemma4_text", + "general.file_type": uint32(7), + }) + + models := DiscoverModels(dir) + + checkLen(t, models, 1) + checkEqual(t, path, models[0].Path) + checkEqual(t, "gemma4_text", models[0].ModelType) + checkEqual(t, "gguf", models[0].Format) +} + +func writeMinimalGGUF(t *testing.T, metadata map[string]any) string { + t.Helper() + return writeMinimalGGUFAt(t, core.JoinPath(t.TempDir(), "model.gguf"), metadata) +} + +func writeMinimalGGUFAt(t *testing.T, path string, metadata map[string]any) string { + t.Helper() + buf := core.NewBuffer() + mustWrite := func(value any) { + checkNoError(t, binary.Write(buf, binary.LittleEndian, value)) + } + writeString := func(value string) { + mustWrite(uint64(len(value))) + _, err := buf.Write([]byte(value)) + checkNoError(t, err) + } + + mustWrite(uint32(0x46554747)) + mustWrite(uint32(3)) + mustWrite(uint64(0)) + mustWrite(uint64(len(metadata))) + for key, value := range metadata { + writeString(key) + switch typed := value.(type) { + case string: + mustWrite(uint32(8)) + writeString(typed) + case uint32: + mustWrite(uint32(4)) + mustWrite(typed) + default: + t.Fatalf("unsupported metadata test value %T", value) + } + } + result := core.WriteFile(path, buf.Bytes(), 0o644) + checkResultOK(t, result) + return path +} diff --git a/go/go.mod b/go/go.mod index 0f6b7eb..641ae43 100644 --- a/go/go.mod +++ b/go/go.mod @@ -2,4 +2,12 @@ module dappco.re/go/inference go 1.26.0 -require dappco.re/go v0.9.0 +require dappco.re/go v0.10.2 + +require ( + forge.lthn.ai/Snider/Enchantrix v0.0.5 + github.com/ProtonMail/go-crypto v1.3.0 // indirect + github.com/cloudflare/circl v1.6.3 // indirect + golang.org/x/crypto v0.48.0 // indirect + golang.org/x/sys v0.41.0 // indirect +) diff --git a/go/go.sum b/go/go.sum index f11464a..c73a3ca 100644 --- a/go/go.sum +++ b/go/go.sum @@ -1,2 +1,20 @@ -dappco.re/go v0.9.0 h1:4ruZRNqKDDva8o6g65tYggjGVe42E6/lMZfVKXtr3p0= -dappco.re/go v0.9.0/go.mod h1:xapr7fLK4/9Pu2iSCr4qZuIuatmtx1j56zS/oPDbGyQ= +dappco.re/go v0.10.2 h1:ifwXpUl2vBwAQ7krjfqv+yA/ptNrEepOMCHcdfXu1tg= +dappco.re/go v0.10.2/go.mod h1:xapr7fLK4/9Pu2iSCr4qZuIuatmtx1j56zS/oPDbGyQ= +forge.lthn.ai/Snider/Enchantrix v0.0.5 h1:Yam0z+3AOvCUCHAMP68Ty8qHr2e4MMs7j2FjMM2JWc8= +forge.lthn.ai/Snider/Enchantrix v0.0.5/go.mod h1:/YcjKMNpC4Ze/fz7zbTx3djN0CJmSM83YiR2KaMK6zQ= +github.com/ProtonMail/go-crypto v1.3.0 h1:ILq8+Sf5If5DCpHQp4PbZdS1J7HDFRXz/+xKBiRGFrw= +github.com/ProtonMail/go-crypto v1.3.0/go.mod h1:9whxjD8Rbs29b4XWbB8irEcE8KHMqaR2e7GWU1R+/PE= +github.com/cloudflare/circl v1.6.3 h1:9GPOhQGF9MCYUeXyMYlqTR6a5gTrgR/fBLXvUgtVcg8= +github.com/cloudflare/circl v1.6.3/go.mod h1:2eXP6Qfat4O/Yhh8BznvKnJ+uzEoTQ6jVKJRn81BiS4= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= +golang.org/x/crypto v0.48.0 h1:/VRzVqiRSggnhY7gNRxPauEQ5Drw9haKdM0jqfcCFts= +golang.org/x/crypto v0.48.0/go.mod h1:r0kV5h3qnFPlQnBSrULhlsRfryS2pmewsg+XfMgkVos= +golang.org/x/sys v0.41.0 h1:Ivj+2Cp/ylzLiEU89QhWblYnOE9zerudt9Ftecq2C6k= +golang.org/x/sys v0.41.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/go/identity.go b/go/identity.go new file mode 100644 index 0000000..226758d --- /dev/null +++ b/go/identity.go @@ -0,0 +1,64 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "slices" + + "dappco.re/go/inference/state" +) + +type ModelIdentity = state.ModelIdentity +type TokenizerIdentity = state.TokenizerIdentity +type AdapterIdentity = state.AdapterIdentity +type RuntimeIdentity = state.RuntimeIdentity +type SamplerConfig = state.SamplerConfig +type StateRef = state.StateRef +type StateBundle = state.Bundle +type ProjectSeedMode = state.ProjectSeedMode +type ProjectSeedOptions = state.ProjectSeedOptions +type ProjectSeed = state.ProjectSeed +type ProjectSeedWakeOptions = state.ProjectSeedWakeOptions +type ProjectSeedContinuationOptions = state.ProjectSeedContinuationOptions +type ProjectSeedContinuationPlan = state.ProjectSeedContinuationPlan +type WakeCompatibilityReport = state.WakeCompatibilityReport + +const ( + ProjectSeedStateCheckpoint = state.ProjectSeedStateCheckpoint + ProjectSeedReuseCurrent = state.ProjectSeedReuseCurrent + ProjectSeedSummaryWindow = state.ProjectSeedSummaryWindow + ProjectSeedHybrid = state.ProjectSeedHybrid +) + +var ( + NewProjectSeed = state.NewProjectSeed + CheckWakeCompatibility = state.CheckWakeCompatibility +) + +// SamplerConfigFromGenerateConfig converts generation options to portable +// sampler metadata while preserving slice ownership. +func SamplerConfigFromGenerateConfig(cfg GenerateConfig) SamplerConfig { + return SamplerConfig{ + MaxTokens: cfg.MaxTokens, + Temperature: cfg.Temperature, + TopK: cfg.TopK, + TopP: cfg.TopP, + RepeatPenalty: cfg.RepeatPenalty, + StopTokens: slices.Clone(cfg.StopTokens), + ReturnLogits: cfg.ReturnLogits, + } +} + +// GenerateConfigFromSamplerConfig converts portable sampler metadata back into +// generation options while preserving slice ownership. +func GenerateConfigFromSamplerConfig(cfg SamplerConfig) GenerateConfig { + return GenerateConfig{ + MaxTokens: cfg.MaxTokens, + Temperature: cfg.Temperature, + TopK: cfg.TopK, + TopP: cfg.TopP, + StopTokens: slices.Clone(cfg.StopTokens), + RepeatPenalty: cfg.RepeatPenalty, + ReturnLogits: cfg.ReturnLogits, + } +} diff --git a/go/identity_bench_test.go b/go/identity_bench_test.go new file mode 100644 index 0000000..a8a71b4 --- /dev/null +++ b/go/identity_bench_test.go @@ -0,0 +1,406 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the identity / state-bundle surface. +// Per AX-11 — SamplerConfigFromGenerateConfig fires per request when +// state primitives capture the active sampler, and the reverse +// conversion fires per session resume. ProjectSeed.WakeRequest fires +// per wake; CheckWakeCompatibility fires per wake to validate the +// bundle against the live runtime — its allocation profile matters +// because every wake pays it. +// +// Run: go test -bench=BenchmarkIdentity -benchmem -run='^$' . + +package inference + +import ( + "testing" +) + +// Sinks defeat compiler DCE. +var ( + identityBenchSinkSampler SamplerConfig + identityBenchSinkGenerateCfg GenerateConfig + identityBenchSinkSeed ProjectSeed + identityBenchSinkWakeRequest AgentMemoryWakeRequest + identityBenchSinkCompatibility WakeCompatibilityReport + identityBenchSinkBundle StateBundle + identityBenchSinkModelIdentity ModelIdentity + identityBenchSinkAdapterIdent AdapterIdentity + identityBenchSinkTokenizerIdent TokenizerIdentity + identityBenchSinkRuntimeIdent RuntimeIdentity +) + +// benchGenerateConfigMinimal — the floor (just MaxTokens set). +func benchGenerateConfigMinimal() GenerateConfig { + return GenerateConfig{ + MaxTokens: 128, + } +} + +// benchGenerateConfigTypical — knob-set seen in real chat requests. +func benchGenerateConfigTypical() GenerateConfig { + return GenerateConfig{ + MaxTokens: 256, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + StopTokens: []int32{2}, + RepeatPenalty: 1.1, + } +} + +// benchGenerateConfigHeavy — large stop-set, logits on (classification path). +func benchGenerateConfigHeavy() GenerateConfig { + return GenerateConfig{ + MaxTokens: 2048, + Temperature: 0.8, + TopK: 50, + TopP: 0.95, + StopTokens: []int32{0, 1, 2, 3, 4, 5, 6, 7}, + RepeatPenalty: 1.15, + ReturnLogits: true, + } +} + +// benchSamplerConfigTypical — sampler-side shape, sized like the +// generate-config above but in its serialisable form. +func benchSamplerConfigTypical() SamplerConfig { + return SamplerConfig{ + MaxTokens: 256, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + RepeatPenalty: 1.1, + StopTokens: []int32{2}, + } +} + +func benchSamplerConfigHeavy() SamplerConfig { + return SamplerConfig{ + MaxTokens: 2048, + Temperature: 0.8, + TopK: 50, + TopP: 0.95, + RepeatPenalty: 1.15, + StopTokens: []int32{0, 1, 2, 3, 4, 5, 6, 7}, + StopSequences: []string{"", "[END]"}, + ReturnLogits: true, + } +} + +// benchStateBundleTypical — what a session checkpoint actually carries +// — model + tokenizer + adapter + sampler + a few KV refs. +func benchStateBundleTypical() StateBundle { + return StateBundle{ + Version: "1", + Model: ModelIdentity{ + Architecture: "qwen3", + Hash: "sha256:model-a", + QuantBits: 4, + ContextLength: 32768, + NumLayers: 28, + HiddenSize: 2048, + VocabSize: 151936, + }, + Tokenizer: TokenizerIdentity{ + Kind: "sentencepiece", + Hash: "sha256:tok-a", + EOSID: 2, + BOSID: 1, + }, + Adapter: AdapterIdentity{ + Hash: "sha256:adapter-a", + Format: "lora", + Rank: 16, + Alpha: 32, + TargetKeys: []string{"q_proj", "v_proj"}, + }, + Sampler: benchSamplerConfigTypical(), + Runtime: RuntimeIdentity{ + Backend: "metal", + Device: "M3 Ultra", + NativeRuntime: true, + }, + PromptTokens: 256, + GeneratedTokens: 128, + KVRefs: []StateRef{ + {Kind: "kv", URI: "state://lthn/snap/0", SizeBytes: 1 << 24, Encoding: "paged-q8"}, + {Kind: "kv", URI: "state://lthn/snap/1", SizeBytes: 1 << 24, Encoding: "paged-q8"}, + }, + } +} + +// --- SamplerConfigFromGenerateConfig (per-request capture) --- + +func BenchmarkIdentity_SamplerConfigFromGenerateConfig_Minimal(b *testing.B) { + cfg := benchGenerateConfigMinimal() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkSampler = SamplerConfigFromGenerateConfig(cfg) + } +} + +func BenchmarkIdentity_SamplerConfigFromGenerateConfig_Typical(b *testing.B) { + cfg := benchGenerateConfigTypical() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkSampler = SamplerConfigFromGenerateConfig(cfg) + } +} + +func BenchmarkIdentity_SamplerConfigFromGenerateConfig_Heavy(b *testing.B) { + cfg := benchGenerateConfigHeavy() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkSampler = SamplerConfigFromGenerateConfig(cfg) + } +} + +// Empty config → empty sampler — no slice clone cost. +func BenchmarkIdentity_SamplerConfigFromGenerateConfig_Empty(b *testing.B) { + cfg := GenerateConfig{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkSampler = SamplerConfigFromGenerateConfig(cfg) + } +} + +// --- GenerateConfigFromSamplerConfig (per-session resume) --- + +func BenchmarkIdentity_GenerateConfigFromSamplerConfig_Typical(b *testing.B) { + sampler := benchSamplerConfigTypical() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkGenerateCfg = GenerateConfigFromSamplerConfig(sampler) + } +} + +func BenchmarkIdentity_GenerateConfigFromSamplerConfig_Heavy(b *testing.B) { + sampler := benchSamplerConfigHeavy() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkGenerateCfg = GenerateConfigFromSamplerConfig(sampler) + } +} + +func BenchmarkIdentity_GenerateConfigFromSamplerConfig_Empty(b *testing.B) { + sampler := SamplerConfig{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkGenerateCfg = GenerateConfigFromSamplerConfig(sampler) + } +} + +// --- Identity construction (per-LoadModel / per-checkpoint cost) --- + +func BenchmarkIdentity_ModelIdentity_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkModelIdentity = ModelIdentity{ + Architecture: "qwen3", + Hash: "sha256:model-a", + QuantBits: 4, + ContextLength: 32768, + NumLayers: 28, + HiddenSize: 2048, + VocabSize: 151936, + } + } +} + +func BenchmarkIdentity_TokenizerIdentity_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkTokenizerIdent = TokenizerIdentity{ + Kind: "sentencepiece", + Hash: "sha256:tok-a", + EOSID: 2, + BOSID: 1, + } + } +} + +func BenchmarkIdentity_AdapterIdentity_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkAdapterIdent = AdapterIdentity{ + Hash: "sha256:adapter-a", + Format: "lora", + Rank: 16, + Alpha: 32, + TargetKeys: []string{"q_proj", "v_proj"}, + } + } +} + +func BenchmarkIdentity_RuntimeIdentity_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkRuntimeIdent = RuntimeIdentity{ + Backend: "metal", + Device: "M3 Ultra", + NativeRuntime: true, + } + } +} + +// --- StateBundle construction (per-checkpoint cost) --- + +func BenchmarkIdentity_StateBundle_ConstructTypical(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkBundle = benchStateBundleTypical() + } +} + +// --- ProjectSeed (per session-bootstrap cost) --- + +func BenchmarkIdentity_NewProjectSeed_Defaults(b *testing.B) { + opts := ProjectSeedOptions{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkSeed = NewProjectSeed(opts) + } +} + +func BenchmarkIdentity_NewProjectSeed_BaseAndProject(b *testing.B) { + opts := ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkSeed = NewProjectSeed(opts) + } +} + +func BenchmarkIdentity_NewProjectSeed_Full(b *testing.B) { + opts := ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + EntryURI: "state://lthn/projects/core/go-mlx/seed", + BundleURI: "state://lthn/projects/core/go-mlx/seed/bundle", + IndexURI: "state://lthn/projects/core/go-mlx/seed/index", + Title: "core/go-mlx project seed", + Labels: map[string]string{"project_id": "core/go-mlx", "env": "dev"}, + Metadata: map[string]string{"created_by": "cladius"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkSeed = NewProjectSeed(opts) + } +} + +// --- ProjectSeed.WakeRequest (per wake) --- + +func BenchmarkIdentity_ProjectSeed_WakeRequest_Minimal(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedWakeOptions{ + Model: ModelIdentity{Hash: "sha256:model-a"}, + Tokenizer: TokenizerIdentity{Hash: "sha256:tok-a"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkWakeRequest = seed.WakeRequest(opts) + } +} + +func BenchmarkIdentity_ProjectSeed_WakeRequest_Typical(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + Labels: map[string]string{"env": "dev"}, + }) + opts := ProjectSeedWakeOptions{ + Model: ModelIdentity{ + Architecture: "qwen3", + Hash: "sha256:model-a", + NumLayers: 28, + }, + Tokenizer: TokenizerIdentity{ + Kind: "sentencepiece", + Hash: "sha256:tok-a", + }, + Adapter: AdapterIdentity{Hash: "sha256:adapter-a", Format: "lora"}, + Runtime: RuntimeIdentity{Backend: "metal"}, + Labels: map[string]string{"session": "s-7"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkWakeRequest = seed.WakeRequest(opts) + } +} + +// --- CheckWakeCompatibility (per-wake validation) --- +// Iterates over model/tokenizer/adapter/runtime identity fields — +// pays the field-compare cost every wake. + +func BenchmarkIdentity_CheckWakeCompatibility_Skip(b *testing.B) { + bundle := benchStateBundleTypical() + req := AgentMemoryWakeRequest{SkipCompatibilityCheck: true} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkCompatibility = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkIdentity_CheckWakeCompatibility_Match(b *testing.B) { + bundle := benchStateBundleTypical() + req := AgentMemoryWakeRequest{ + Model: bundle.Model, + Tokenizer: bundle.Tokenizer, + Adapter: bundle.Adapter, + Runtime: bundle.Runtime, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkCompatibility = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkIdentity_CheckWakeCompatibility_HashMismatch(b *testing.B) { + bundle := benchStateBundleTypical() + req := AgentMemoryWakeRequest{ + Model: ModelIdentity{Hash: "sha256:other-model", Architecture: "gemma3", NumLayers: 12}, + Tokenizer: TokenizerIdentity{Hash: "sha256:other-tok"}, + Adapter: AdapterIdentity{Hash: "sha256:other-adapter"}, + Runtime: RuntimeIdentity{Backend: "rocm"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkCompatibility = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkIdentity_CheckWakeCompatibility_Empty(b *testing.B) { + bundle := StateBundle{} + req := AgentMemoryWakeRequest{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identityBenchSinkCompatibility = CheckWakeCompatibility(bundle, req) + } +} diff --git a/go/identity_example_test.go b/go/identity_example_test.go new file mode 100644 index 0000000..20fc477 --- /dev/null +++ b/go/identity_example_test.go @@ -0,0 +1,43 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import core "dappco.re/go" + +func ExampleStateBundle() { + bundle := StateBundle{ + Model: ModelIdentity{ + Architecture: "gemma4", + QuantBits: 4, + }, + Runtime: RuntimeIdentity{ + Backend: "metal", + NativeRuntime: true, + }, + } + + core.Println(bundle.Model.Architecture, bundle.Runtime.Backend) + // Output: gemma4 metal +} + +func ExampleSamplerConfigFromGenerateConfig() { + sampler := SamplerConfigFromGenerateConfig(GenerateConfig{ + MaxTokens: 32, + TopK: 8, + StopTokens: []int32{2}, + }) + + core.Println(sampler.MaxTokens, sampler.TopK, sampler.StopTokens) + // Output: 32 8 [2] +} + +func ExampleGenerateConfigFromSamplerConfig() { + cfg := GenerateConfigFromSamplerConfig(SamplerConfig{ + MaxTokens: 64, + Temperature: 0.2, + RepeatPenalty: 1.1, + }) + + core.Println(cfg.MaxTokens, cfg.Temperature, cfg.RepeatPenalty) + // Output: 64 0.2 1.1 +} diff --git a/go/identity_test.go b/go/identity_test.go new file mode 100644 index 0000000..81d62ef --- /dev/null +++ b/go/identity_test.go @@ -0,0 +1,160 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import "testing" + +func TestIdentity_SamplerConfigFromGenerateConfig_Good(t *testing.T) { + cfg := GenerateConfig{ + MaxTokens: 64, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + StopTokens: []int32{1, 2}, + RepeatPenalty: 1.1, + ReturnLogits: true, + } + sampler := SamplerConfigFromGenerateConfig(cfg) + cfg.StopTokens[0] = 99 + + checkEqual(t, []int32{1, 2}, sampler.StopTokens) + checkEqual(t, 64, sampler.MaxTokens) + checkEqual(t, float32(0.7), sampler.Temperature) + checkEqual(t, 40, sampler.TopK) + checkEqual(t, float32(0.9), sampler.TopP) + checkEqual(t, float32(1.1), sampler.RepeatPenalty) + checkTrue(t, sampler.ReturnLogits) +} + +func TestIdentity_SamplerConfigFromGenerateConfig_Bad(t *testing.T) { + sampler := SamplerConfigFromGenerateConfig(GenerateConfig{}) + + checkEqual(t, 0, sampler.MaxTokens) + checkEmpty(t, sampler.StopTokens) + checkFalse(t, sampler.ReturnLogits) +} + +func TestIdentity_SamplerConfigFromGenerateConfig_Ugly(t *testing.T) { + cfg := GenerateConfig{StopTokens: []int32{}} + + sampler := SamplerConfigFromGenerateConfig(cfg) + cfg.StopTokens = append(cfg.StopTokens, 7) + + checkEmpty(t, sampler.StopTokens) + checkEqual(t, []int32{7}, cfg.StopTokens) +} + +func TestIdentity_GenerateConfigFromSamplerConfig_Good(t *testing.T) { + sampler := SamplerConfig{ + MaxTokens: 128, + Temperature: 0.2, + TopK: 8, + TopP: 0.5, + StopTokens: []int32{3, 4}, + RepeatPenalty: 1.2, + ReturnLogits: true, + } + cfg := GenerateConfigFromSamplerConfig(sampler) + sampler.StopTokens[0] = 99 + + checkEqual(t, []int32{3, 4}, cfg.StopTokens) + checkEqual(t, 128, cfg.MaxTokens) + checkEqual(t, float32(0.2), cfg.Temperature) + checkEqual(t, 8, cfg.TopK) + checkEqual(t, float32(0.5), cfg.TopP) + checkEqual(t, float32(1.2), cfg.RepeatPenalty) + checkTrue(t, cfg.ReturnLogits) +} + +func TestIdentity_GenerateConfigFromSamplerConfig_Bad(t *testing.T) { + cfg := GenerateConfigFromSamplerConfig(SamplerConfig{}) + + checkEqual(t, 0, cfg.MaxTokens) + checkEmpty(t, cfg.StopTokens) + checkFalse(t, cfg.ReturnLogits) +} + +func TestIdentity_GenerateConfigFromSamplerConfig_Ugly(t *testing.T) { + sampler := SamplerConfig{StopTokens: []int32{}} + + cfg := GenerateConfigFromSamplerConfig(sampler) + sampler.StopTokens = append(sampler.StopTokens, 7) + + checkEmpty(t, cfg.StopTokens) + checkEqual(t, []int32{7}, sampler.StopTokens) +} + +func TestIdentity_StateBundle_Good(t *testing.T) { + bundle := StateBundle{ + Version: "1", + Model: ModelIdentity{ + Architecture: "qwen3", + QuantBits: 4, + ContextLength: 32768, + }, + Tokenizer: TokenizerIdentity{ + Kind: "sentencepiece", + EOSID: 2, + }, + Adapter: AdapterIdentity{ + Rank: 16, + Alpha: 32, + TargetKeys: []string{"q_proj", "v_proj"}, + }, + Runtime: RuntimeIdentity{ + Backend: "metal", + NativeRuntime: true, + }, + Sampler: SamplerConfig{ + MaxTokens: 256, + }, + KVRefs: []StateRef{{ + Kind: "kv", + URI: "file:///tmp/state.kvbin", + }}, + } + + checkEqual(t, "qwen3", bundle.Model.Architecture) + checkEqual(t, int32(2), bundle.Tokenizer.EOSID) + checkEqual(t, 16, bundle.Adapter.Rank) + checkTrue(t, bundle.Runtime.NativeRuntime) + checkLen(t, bundle.KVRefs, 1) +} + +func TestIdentity_StateBundle_Bad_EmptyAllowed(t *testing.T) { + bundle := StateBundle{} + + checkEqual(t, "", bundle.Model.Architecture) + checkEqual(t, 0, bundle.Sampler.MaxTokens) + checkEmpty(t, bundle.KVRefs) +} + +func TestIdentity_ProjectSeedAliases_Good(t *testing.T) { + seed := NewProjectSeed(ProjectSeedOptions{BaseURI: "state://lthn/projects", ProjectID: "core/go-mlx"}) + wake := seed.WakeRequest(ProjectSeedWakeOptions{ + Model: ModelIdentity{Hash: "model-a"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + }) + + report := CheckWakeCompatibility(StateBundle{ + Model: ModelIdentity{Hash: "model-a"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + PromptTokens: 16, + }, wake) + + checkEqual(t, "state://lthn/projects/core/go-mlx/seed", wake.EntryURI) + checkTrue(t, report.Compatible) +} + +func TestIdentity_AdapterIdentity_Ugly_MetadataOnly(t *testing.T) { + adapter := AdapterIdentity{ + Hash: "sha256:abc", + Format: "lora", + BaseModelHash: "sha256:base", + Labels: map[string]string{"source": "unit"}, + } + + checkEqual(t, "sha256:abc", adapter.Hash) + checkEqual(t, "unit", adapter.Labels["source"]) + checkEmpty(t, adapter.TargetKeys) +} diff --git a/go/inference.go b/go/inference.go index 19ec860..fe152be 100644 --- a/go/inference.go +++ b/go/inference.go @@ -63,7 +63,6 @@ package inference import ( "context" "iter" - "maps" "slices" "time" @@ -263,13 +262,6 @@ var ( } ) -func snapshotBackends() map[string]Backend { - backendsMu.RLock() - snap := maps.Clone(backends) - backendsMu.RUnlock() - return snap -} - // Register adds b to the global registry, overwriting any existing entry with the same name. // // func init() { inference.Register(metal.NewBackend()) } @@ -293,19 +285,57 @@ func Get(name string) (Backend, bool) { } // names := inference.List() // ["llama_cpp", "metal", "rocm"] +// +// Single-pass key copy under RLock — earlier shape did maps.Clone + +// maps.Keys + slices.Sorted (~4 allocs + bucket cost). Direct slice +// build is 1 alloc; empty registry returns nil (preserves the test +// contract that callers can branch on). func List() []string { - return slices.Sorted(maps.Keys(snapshotBackends())) + backendsMu.RLock() + if len(backends) == 0 { + backendsMu.RUnlock() + return nil + } + names := make([]string, 0, len(backends)) + for name := range backends { + names = append(names, name) + } + backendsMu.RUnlock() + slices.Sort(names) + return names } // for name, b := range inference.All() { // fmt.Println(name, b.Available()) // } +// +// Builds a slice of (name, backend) pairs under RLock so the returned +// iterator runs without holding any lock — single alloc for the pair +// slice instead of the previous maps.Clone + maps.Keys + slices.Sorted +// cascade. func All() iter.Seq2[string, Backend] { - snap := snapshotBackends() - names := slices.Sorted(maps.Keys(snap)) + type entry struct { + name string + back Backend + } + backendsMu.RLock() + entries := make([]entry, 0, len(backends)) + for name, b := range backends { + entries = append(entries, entry{name, b}) + } + backendsMu.RUnlock() + slices.SortFunc(entries, func(a, b entry) int { + if a.name < b.name { + return -1 + } + if a.name > b.name { + return 1 + } + return 0 + }) return func(yield func(string, Backend) bool) { - for _, name := range names { - if !yield(name, snap[name]) { + for _, e := range entries { + if !yield(e.name, e.back) { return } } @@ -315,25 +345,53 @@ func All() iter.Seq2[string, Backend] { // Default picks the first available backend in preference order: metal → rocm → llama_cpp → any. // // r := inference.Default() // r.Value is the backend when r.OK +// +// Both preferred-order scan and fallback run against direct map +// lookups under RLock — no clone, no Keys-iterator allocation. The +// happy path (preferred backend available) is 0 allocs. func Default() core.Result { - snap := snapshotBackends() - if len(snap) == 0 { + backendsMu.RLock() + if len(backends) == 0 { + backendsMu.RUnlock() return core.Fail(core.E("inference.Default", "no backends registered", nil)) } - // Platform preference order + // Platform preference order — direct map lookups, no clone. for _, name := range preferredBackendOrder { - if b, ok := snap[name]; ok && b.Available() { + if b, ok := backends[name]; ok && b.Available() { + backendsMu.RUnlock() return core.Ok(b) } } - // Fall back to any available - for _, name := range slices.Sorted(maps.Keys(snap)) { - if _, ok := preferredBackendSet[name]; ok { + + // Fall back to any non-preferred backend, in sorted-name order. + // Snapshot (name, backend) pairs under RLock so Available() runs + // outside the lock — matches the prior defensive behaviour. + type entry struct { + name string + back Backend + } + var fallback []entry + for name, b := range backends { + if _, isPreferred := preferredBackendSet[name]; isPreferred { continue } - if backend := snap[name]; backend.Available() { - return core.Ok(backend) + fallback = append(fallback, entry{name, b}) + } + backendsMu.RUnlock() + + slices.SortFunc(fallback, func(a, b entry) int { + if a.name < b.name { + return -1 + } + if a.name > b.name { + return 1 + } + return 0 + }) + for _, e := range fallback { + if e.back.Available() { + return core.Ok(e.back) } } return core.Fail(core.E("inference.Default", "no backends available", nil)) diff --git a/go/inference_bench_test.go b/go/inference_bench_test.go new file mode 100644 index 0000000..a1997f0 --- /dev/null +++ b/go/inference_bench_test.go @@ -0,0 +1,238 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the inference orchestration types — backend registry +// lookups + LoadModel routing + AttentionSnapshot.HasQueries helper. +// Per AX-11 — Register fires once per backend init, but Get / List / All / +// Default run on every model load and every consumer that wants to +// enumerate available backends; HasQueries fires per attention snapshot. +// +// Run: go test -bench='BenchmarkInference' -benchmem -run='^$' . + +package inference + +import ( + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. Distinct names from the gguf bench file. +var ( + inferenceBenchSinkBool bool + inferenceBenchSinkBackend Backend + inferenceBenchSinkBackOK bool + inferenceBenchSinkNames []string + inferenceBenchSinkResult core.Result + inferenceBenchSinkCount int + inferenceBenchSinkSampler SamplerConfig + inferenceBenchSinkGen GenerateConfig +) + +// benchRegisterPreferred wipes the global registry and primes it with +// preferred backends (metal, rocm, llama_cpp) plus n custom backends. +// All preferred are available; custom availability is alternating. +func benchRegisterPreferred(b *testing.B, custom int) { + b.Helper() + backendsMu.Lock() + backends = map[string]Backend{} + backendsMu.Unlock() + Register(&inferenceBenchBackend{name: "metal", available: true}) + Register(&inferenceBenchBackend{name: "rocm", available: true}) + Register(&inferenceBenchBackend{name: "llama_cpp", available: true}) + for i := 0; i < custom; i++ { + Register(&inferenceBenchBackend{ + name: core.Sprintf("custom_%d", i), + available: i%2 == 0, + }) + } +} + +// inferenceBenchBackend is a no-op Backend so the registry-level benches +// don't drag a real loader into the hot path. Distinct name from the +// existing test stubBackend to avoid colliding when the bench files share +// the package. LoadModel is never invoked from these benches, so we keep +// it minimal — the registered backend's role is to populate the registry +// for Get / List / All / Default. +type inferenceBenchBackend struct { + name string + available bool +} + +func (b *inferenceBenchBackend) Name() string { return b.name } +func (b *inferenceBenchBackend) Available() bool { return b.available } +func (b *inferenceBenchBackend) LoadModel(_ string, _ ...LoadOption) (TextModel, error) { + return nil, nil +} + +// --- AttentionSnapshot.HasQueries (per-snapshot helper, pure scan) --- + +func BenchmarkInference_HasQueries_True(b *testing.B) { + snap := &AttentionSnapshot{ + NumLayers: 28, + Queries: make([][][]float32, 28), + } + for i := range snap.Queries { + snap.Queries[i] = make([][]float32, 8) + for j := range snap.Queries[i] { + snap.Queries[i][j] = make([]float32, 128) + } + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkBool = snap.HasQueries() + } +} + +func BenchmarkInference_HasQueries_NilQueries(b *testing.B) { + snap := &AttentionSnapshot{ + NumLayers: 28, + Queries: nil, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkBool = snap.HasQueries() + } +} + +func BenchmarkInference_HasQueries_NilSnapshot(b *testing.B) { + var snap *AttentionSnapshot + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkBool = snap.HasQueries() + } +} + +// --- Registry: Get (per-lookup hot path on every LoadModel) --- + +func BenchmarkInference_Get_Hit(b *testing.B) { + benchRegisterPreferred(b, 0) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkBackend, inferenceBenchSinkBackOK = Get("metal") + } +} + +func BenchmarkInference_Get_Miss(b *testing.B) { + benchRegisterPreferred(b, 0) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkBackend, inferenceBenchSinkBackOK = Get("nonexistent") + } +} + +// --- Registry: List (full snapshot + sort) --- + +func BenchmarkInference_List_Three(b *testing.B) { + benchRegisterPreferred(b, 0) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkNames = List() + } +} + +func BenchmarkInference_List_TwentyBackends(b *testing.B) { + benchRegisterPreferred(b, 17) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkNames = List() + } +} + +// --- Registry: All (iter.Seq2 snapshot + ranged yield) --- + +func BenchmarkInference_All_Three(b *testing.B) { + benchRegisterPreferred(b, 0) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range All() { + count++ + } + inferenceBenchSinkCount = count + } +} + +func BenchmarkInference_All_TwentyBackends(b *testing.B) { + benchRegisterPreferred(b, 17) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range All() { + count++ + } + inferenceBenchSinkCount = count + } +} + +// --- Registry: Default (preference-order scan) --- + +func BenchmarkInference_Default_AllPreferred(b *testing.B) { + benchRegisterPreferred(b, 0) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkResult = Default() + } +} + +// Worst-case: metal + rocm + llama_cpp unavailable, fall through to a +// custom backend — exercises the second loop body. +func BenchmarkInference_Default_FallbackToCustom(b *testing.B) { + backendsMu.Lock() + backends = map[string]Backend{} + backendsMu.Unlock() + Register(&inferenceBenchBackend{name: "metal", available: false}) + Register(&inferenceBenchBackend{name: "rocm", available: false}) + Register(&inferenceBenchBackend{name: "llama_cpp", available: false}) + Register(&inferenceBenchBackend{name: "custom_vulkan", available: true}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkResult = Default() + } +} + +// --- Identity-bridge converters (per Generate call boundary) --- + +func BenchmarkInference_SamplerConfigFromGenerateConfig(b *testing.B) { + cfg := GenerateConfig{ + MaxTokens: 256, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + RepeatPenalty: 1.1, + StopTokens: []int32{2, 1, 0, 42, 1024}, + ReturnLogits: true, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkSampler = SamplerConfigFromGenerateConfig(cfg) + } +} + +func BenchmarkInference_GenerateConfigFromSamplerConfig(b *testing.B) { + cfg := SamplerConfig{ + MaxTokens: 256, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + RepeatPenalty: 1.1, + StopTokens: []int32{2, 1, 0, 42, 1024}, + ReturnLogits: true, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + inferenceBenchSinkGen = GenerateConfigFromSamplerConfig(cfg) + } +} diff --git a/go/jsonenc/jsondec.go b/go/jsonenc/jsondec.go new file mode 100644 index 0000000..68bc645 --- /dev/null +++ b/go/jsonenc/jsondec.go @@ -0,0 +1,629 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// JSON-decoding primitives shared by the inference adapter +// UnmarshalJSON hot paths. The encoding/json reflect path allocates +// an encoder state machine, per-field reflect.Value boxing, and a +// per-string copy on every Unmarshal call — each adapter request +// decoder pays that floor. +// +// Provenance: lifted in W11-B from openai/jsondec.go which shipped +// in W10-M (StopList / EmbeddingInput single-pass walker). The set +// of primitives mirrors the encode side of jsonenc — ParseJSONString +// is the inverse of AppendJSONString and shares the same escape +// contract. Hand-rolled per-type field walkers (anthropic / +// openai / ollama Unmarshal*Request) call directly into these. +// +// All primitives parse the JSON spec across every branch: +// - Whitespace: space, tab, CR, LF. +// - Strings: \" \\ \/ \b \f \n \r \t \uXXXX (UTF-8 re-encoded). +// - Numbers: int64 + float64 with the same shape strconv.ParseFloat +// accepts. +// - Literals: true / false / null. +// +// Output matches what encoding/json.Unmarshal would have produced +// for the same input. + +package jsonenc + +import ( + "errors" + "strconv" +) + +// ErrInvalidJSON is the sentinel returned for malformed input. +// Call sites wrap into typed result errors as appropriate. +var ErrInvalidJSON = errors.New("invalid JSON") + +// ParseJSONStringList walks data as either a JSON string (e.g. +// `"END"`) or an array of JSON strings (e.g. `["END",""]`) and +// returns a []string with the inner values unescaped. +// +// The "null" literal returns (nil, nil). Empty or invalid data +// returns ErrInvalidJSON; otherwise the first non-whitespace byte +// determines the shape. +// +// stops, err := jsonenc.ParseJSONStringList([]byte(`["a","b"]`)) +// // stops == []string{"a","b"} +// +// stops, err := jsonenc.ParseJSONStringList([]byte(`"END"`)) +// // stops == []string{"END"} +func ParseJSONStringList(data []byte) ([]string, error) { + i := SkipJSONWhitespace(data, 0) + if i >= len(data) { + return nil, ErrInvalidJSON + } + c := data[i] + if c == 'n' { + // Possible "null" literal. + if i+4 <= len(data) && data[i+1] == 'u' && data[i+2] == 'l' && data[i+3] == 'l' { + return nil, nil + } + return nil, ErrInvalidJSON + } + if c == '"' { + s, _, err := ParseJSONString(data, i) + if err != nil { + return nil, err + } + return []string{s}, nil + } + if c == '[' { + return parseJSONStringArray(data, i+1) + } + return nil, ErrInvalidJSON +} + +// parseJSONStringArray walks data from position i (just past the '[') +// and returns the inner array of strings. +func parseJSONStringArray(data []byte, i int) ([]string, error) { + out := []string(nil) + // Empty-array fast path. + j := SkipJSONWhitespace(data, i) + if j < len(data) && data[j] == ']' { + return out, nil + } + for { + i = SkipJSONWhitespace(data, i) + if i >= len(data) { + return nil, ErrInvalidJSON + } + if data[i] != '"' { + return nil, ErrInvalidJSON + } + s, next, err := ParseJSONString(data, i) + if err != nil { + return nil, err + } + out = append(out, s) + i = SkipJSONWhitespace(data, next) + if i >= len(data) { + return nil, ErrInvalidJSON + } + switch data[i] { + case ',': + i++ + case ']': + return out, nil + default: + return nil, ErrInvalidJSON + } + } +} + +// ParseJSONString walks a JSON string starting at data[i] (which must +// be '"') and returns the unescaped string + the index one past the +// closing '"'. +// +// The fast path (no escapes) returns a string copy of the slice +// range directly via Go's built-in string conversion. The escape +// path walks byte-by-byte and re-decodes \" \\ \b \f \n \r \t / \uXXXX +// escapes. Most adapter wire strings carry no escapes — the fast +// path is the common case. +// +// value, next, err := jsonenc.ParseJSONString(data, i) +func ParseJSONString(data []byte, i int) (string, int, error) { + if i >= len(data) || data[i] != '"' { + return "", i, ErrInvalidJSON + } + start := i + 1 + for j := start; j < len(data); j++ { + c := data[j] + if c == '"' { + return string(data[start:j]), j + 1, nil + } + if c == '\\' { + return parseJSONStringEscaped(data, start, j) + } + if c < 0x20 { + return "", j, ErrInvalidJSON + } + } + return "", i, ErrInvalidJSON +} + +// ParseJSONStringRaw is the no-copy variant of ParseJSONString — +// returns a []byte slice into data when no escapes are present, or +// allocates only when an escape forces a copy. Caller MUST treat +// the returned slice as read-only and assignable to a string via +// the standard byte-to-string conversion when persistence is needed. +// +// Hot use case: anthropic/openai field dispatch where the matched +// key path can clone the underlying string in one allocation rather +// than two. +func ParseJSONStringRaw(data []byte, i int) ([]byte, int, error) { + if i >= len(data) || data[i] != '"' { + return nil, i, ErrInvalidJSON + } + start := i + 1 + for j := start; j < len(data); j++ { + c := data[j] + if c == '"' { + return data[start:j], j + 1, nil + } + if c == '\\' { + s, next, err := parseJSONStringEscaped(data, start, j) + if err != nil { + return nil, next, err + } + return []byte(s), next, nil + } + if c < 0x20 { + return nil, j, ErrInvalidJSON + } + } + return nil, i, ErrInvalidJSON +} + +// parseJSONStringEscaped is the slow path for strings containing +// backslash escapes. Walks the remainder character-by-character, +// emitting into a backing buffer with appended decoded bytes. +func parseJSONStringEscaped(data []byte, start, firstEscape int) (string, int, error) { + buf := make([]byte, 0, len(data)-start) + buf = append(buf, data[start:firstEscape]...) + for i := firstEscape; i < len(data); { + c := data[i] + if c == '"' { + return string(buf), i + 1, nil + } + if c == '\\' { + if i+1 >= len(data) { + return "", i, ErrInvalidJSON + } + esc := data[i+1] + switch esc { + case '"': + buf = append(buf, '"') + case '\\': + buf = append(buf, '\\') + case '/': + buf = append(buf, '/') + case 'b': + buf = append(buf, '\b') + case 'f': + buf = append(buf, '\f') + case 'n': + buf = append(buf, '\n') + case 'r': + buf = append(buf, '\r') + case 't': + buf = append(buf, '\t') + case 'u': + if i+6 > len(data) { + return "", i, ErrInvalidJSON + } + cp, ok := parseJSONUnicodeEscape(data[i+2 : i+6]) + if !ok { + return "", i, ErrInvalidJSON + } + // UTF-8 encode the codepoint. + buf = appendUTF8(buf, cp) + i += 6 + continue + default: + return "", i, ErrInvalidJSON + } + i += 2 + continue + } + if c < 0x20 { + return "", i, ErrInvalidJSON + } + buf = append(buf, c) + i++ + } + return "", firstEscape, ErrInvalidJSON +} + +// parseJSONUnicodeEscape decodes a 4-hex-digit codepoint following +// the \u escape prefix. +func parseJSONUnicodeEscape(hex []byte) (rune, bool) { + if len(hex) != 4 { + return 0, false + } + var cp rune + for _, b := range hex { + var v rune + switch { + case b >= '0' && b <= '9': + v = rune(b - '0') + case b >= 'a' && b <= 'f': + v = rune(b-'a') + 10 + case b >= 'A' && b <= 'F': + v = rune(b-'A') + 10 + default: + return 0, false + } + cp = cp<<4 | v + } + return cp, true +} + +// appendUTF8 appends the UTF-8 encoding of cp to buf. +func appendUTF8(buf []byte, cp rune) []byte { + switch { + case cp < 0x80: + return append(buf, byte(cp)) + case cp < 0x800: + return append(buf, byte(0xc0|cp>>6), byte(0x80|cp&0x3f)) + case cp < 0x10000: + return append(buf, byte(0xe0|cp>>12), byte(0x80|(cp>>6)&0x3f), byte(0x80|cp&0x3f)) + default: + return append(buf, byte(0xf0|cp>>18), byte(0x80|(cp>>12)&0x3f), byte(0x80|(cp>>6)&0x3f), byte(0x80|cp&0x3f)) + } +} + +// SkipJSONWhitespace advances i past JSON whitespace bytes — space, +// tab, CR, LF — and returns the new position. +// +// i := jsonenc.SkipJSONWhitespace(data, 0) +func SkipJSONWhitespace(data []byte, i int) int { + for i < len(data) { + c := data[i] + if c == ' ' || c == '\t' || c == '\n' || c == '\r' { + i++ + continue + } + break + } + return i +} + +// ParseJSONInt walks a JSON integer (possibly signed) at data[i] +// and returns the parsed int64 + the index one past the last digit. +// Accepts the same shape encoding/json accepts for an integer field +// (no leading '+', no leading zeros except the lone '0'). +// +// n, next, err := jsonenc.ParseJSONInt(data, i) +func ParseJSONInt(data []byte, i int) (int64, int, error) { + if i >= len(data) { + return 0, i, ErrInvalidJSON + } + start := i + neg := false + if data[i] == '-' { + neg = true + i++ + if i >= len(data) { + return 0, i, ErrInvalidJSON + } + } + c := data[i] + if c < '0' || c > '9' { + return 0, i, ErrInvalidJSON + } + var n int64 + for i < len(data) { + c := data[i] + if c < '0' || c > '9' { + break + } + n = n*10 + int64(c-'0') + i++ + } + if neg { + n = -n + } + if i == start { + return 0, i, ErrInvalidJSON + } + return n, i, nil +} + +// ParseJSONBool walks the literal `true` or `false` at data[i] and +// returns the value + the index one past the literal. +// +// v, next, err := jsonenc.ParseJSONBool(data, i) +func ParseJSONBool(data []byte, i int) (bool, int, error) { + if i+4 <= len(data) && data[i] == 't' && data[i+1] == 'r' && data[i+2] == 'u' && data[i+3] == 'e' { + return true, i + 4, nil + } + if i+5 <= len(data) && data[i] == 'f' && data[i+1] == 'a' && data[i+2] == 'l' && data[i+3] == 's' && data[i+4] == 'e' { + return false, i + 5, nil + } + return false, i, ErrInvalidJSON +} + +// IsJSONNull reports whether data[i:] starts with the `null` literal. +// Does NOT advance i — the caller picks the new index based on +// whether they care to consume it. +// +// if jsonenc.IsJSONNull(data, i) { i += 4; continue } +func IsJSONNull(data []byte, i int) bool { + return i+4 <= len(data) && data[i] == 'n' && data[i+1] == 'u' && data[i+2] == 'l' && data[i+3] == 'l' +} + +// SkipJSONValue walks one complete JSON value at data[i] (object, +// array, string, number, true, false, null) and returns the index +// one past the value. Caller uses it to skip an unknown / ignored +// field during single-pass dispatch. +// +// next, err := jsonenc.SkipJSONValue(data, i) +func SkipJSONValue(data []byte, i int) (int, error) { + i = SkipJSONWhitespace(data, i) + if i >= len(data) { + return i, ErrInvalidJSON + } + switch data[i] { + case '{': + return skipJSONObject(data, i+1) + case '[': + return skipJSONArray(data, i+1) + case '"': + return SkipJSONString(data, i) + case 't', 'f': + _, next, err := ParseJSONBool(data, i) + return next, err + case 'n': + if IsJSONNull(data, i) { + return i + 4, nil + } + return i, ErrInvalidJSON + } + return skipJSONNumber(data, i) +} + +// SkipJSONString walks a JSON string at data[i] (which must be '"') +// and returns the index one past the closing '"'. Unlike +// ParseJSONString it does NOT materialise a Go string — callers use +// it when they only need to advance past the value (object-key +// inside a SkipJSONValue path, ignored field, CountJSONArrayElements +// prescan). +// +// next, err := jsonenc.SkipJSONString(data, i) +func SkipJSONString(data []byte, i int) (int, error) { + if i >= len(data) || data[i] != '"' { + return i, ErrInvalidJSON + } + for j := i + 1; j < len(data); j++ { + c := data[j] + if c == '"' { + return j + 1, nil + } + if c == '\\' { + // Escape — bump j past the escape body without decoding. + if j+1 >= len(data) { + return j, ErrInvalidJSON + } + if data[j+1] == 'u' { + if j+6 > len(data) { + return j, ErrInvalidJSON + } + j += 5 + continue + } + j++ + continue + } + if c < 0x20 { + return j, ErrInvalidJSON + } + } + return i, ErrInvalidJSON +} + +// skipJSONObject skips through the object body at data[i:] starting +// just past the '{'. Returns the index one past the closing '}'. +func skipJSONObject(data []byte, i int) (int, error) { + i = SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return i + 1, nil + } + for { + i = SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return i, ErrInvalidJSON + } + next, err := SkipJSONString(data, i) + if err != nil { + return next, err + } + i = SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return i, ErrInvalidJSON + } + i++ + next, err = SkipJSONValue(data, i) + if err != nil { + return next, err + } + i = SkipJSONWhitespace(data, next) + if i >= len(data) { + return i, ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return i + 1, nil + } + return i, ErrInvalidJSON + } +} + +// skipJSONArray skips through the array body at data[i:] starting +// just past the '['. Returns the index one past the closing ']'. +func skipJSONArray(data []byte, i int) (int, error) { + i = SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == ']' { + return i + 1, nil + } + for { + next, err := SkipJSONValue(data, i) + if err != nil { + return next, err + } + i = SkipJSONWhitespace(data, next) + if i >= len(data) { + return i, ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == ']' { + return i + 1, nil + } + return i, ErrInvalidJSON + } +} + +// skipJSONNumber walks a JSON number (possibly signed, possibly +// containing '.' / 'e' / 'E') at data[i] and returns the index one +// past the last byte. +func skipJSONNumber(data []byte, i int) (int, error) { + start := i + if i < len(data) && data[i] == '-' { + i++ + } + for i < len(data) { + c := data[i] + if (c >= '0' && c <= '9') || c == '.' || c == 'e' || c == 'E' || c == '+' || c == '-' { + i++ + continue + } + break + } + if i == start { + return i, ErrInvalidJSON + } + return i, nil +} + +// MatchObjectStart skips whitespace and asserts data[i] == '{', +// returning the index one past the opening brace. +// +// i, err := jsonenc.MatchObjectStart(data, 0) +func MatchObjectStart(data []byte, i int) (int, error) { + i = SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '{' { + return i, ErrInvalidJSON + } + return i + 1, nil +} + +// MatchArrayStart skips whitespace and asserts data[i] == '[', +// returning the index one past the opening bracket. +// +// i, err := jsonenc.MatchArrayStart(data, 0) +func MatchArrayStart(data []byte, i int) (int, error) { + i = SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '[' { + return i, ErrInvalidJSON + } + return i + 1, nil +} + +// ParseJSONFloat32 walks a JSON number at data[i] and returns the +// parsed float32 + the index one past the last byte. Accepts the +// same shape encoding/json accepts for a float field (optional +// leading '-', integer, optional fraction, optional exponent). +// +// v, next, err := jsonenc.ParseJSONFloat32(data, i) +func ParseJSONFloat32(data []byte, i int) (float32, int, error) { + start := i + if i < len(data) && data[i] == '-' { + i++ + } + for i < len(data) { + c := data[i] + if (c >= '0' && c <= '9') || c == '.' || c == 'e' || c == 'E' || c == '+' || c == '-' { + i++ + continue + } + break + } + if i == start { + return 0, i, ErrInvalidJSON + } + // strconv.ParseFloat with bitSize 32 matches encoding/json's + // float32 decoder. The string conversion at the strconv boundary + // is unavoidable — pre-W11-B json.Unmarshal paid the same cost + // via its own internal walker; the hand-roll wins from skipping + // reflect overhead, not from defeating the stdlib's float parser. + v, err := strconv.ParseFloat(string(data[start:i]), 32) + if err != nil { + return 0, i, ErrInvalidJSON + } + return float32(v), i, nil +} + +// ParseJSONFloat64 walks a JSON number at data[i] and returns the +// parsed float64 + the index one past the last byte. +func ParseJSONFloat64(data []byte, i int) (float64, int, error) { + start := i + if i < len(data) && data[i] == '-' { + i++ + } + for i < len(data) { + c := data[i] + if (c >= '0' && c <= '9') || c == '.' || c == 'e' || c == 'E' || c == '+' || c == '-' { + i++ + continue + } + break + } + if i == start { + return 0, i, ErrInvalidJSON + } + v, err := strconv.ParseFloat(string(data[start:i]), 64) + if err != nil { + return 0, i, ErrInvalidJSON + } + return v, i, nil +} + +// CountJSONArrayElements counts the elements in the JSON array body +// starting at data[i] (just past the '['). Does NOT mutate the +// caller's index — callers use the count only for slice pre-sizing. +// +// Walks each element via SkipJSONValue so it handles nested objects +// / arrays / quoted strings (no naive comma-count footgun). Returns +// 0 for a malformed body — the caller's subsequent parse re-reports +// the malformedness. +// +// count := jsonenc.CountJSONArrayElements(data, i) +// out := make([]T, 0, count) +func CountJSONArrayElements(data []byte, i int) int { + i = SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] == ']' { + return 0 + } + count := 0 + for { + next, err := SkipJSONValue(data, i) + if err != nil { + return count + } + count++ + i = SkipJSONWhitespace(data, next) + if i >= len(data) { + return count + } + if data[i] == ',' { + i = SkipJSONWhitespace(data, i+1) + continue + } + return count + } +} diff --git a/go/jsonenc/jsondec_test.go b/go/jsonenc/jsondec_test.go new file mode 100644 index 0000000..8c08701 --- /dev/null +++ b/go/jsonenc/jsondec_test.go @@ -0,0 +1,290 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package jsonenc + +import ( + "reflect" + "testing" +) + +// TestParseJSONStringList_RoundTrip mirrors the test in openai/jsondec_test.go — +// when this passes, the openai package's call site is byte-for-byte +// compatible with the lifted primitive. +func TestParseJSONStringList_RoundTrip(t *testing.T) { + cases := []struct { + name string + in string + want []string + }{ + {"null", "null", nil}, + {"null-with-whitespace", " null\t", nil}, + {"plain-string", `"END"`, []string{"END"}}, + {"string-with-escapes", `"line1\nline2"`, []string{"line1\nline2"}}, + {"string-with-quote", `"he said \"hi\""`, []string{`he said "hi"`}}, + {"string-with-unicode", `"é"`, []string{"é"}}, + {"empty-array", `[]`, nil}, + {"single-element-array", `["END"]`, []string{"END"}}, + {"multi-element-array", `["A","B","C"]`, []string{"A", "B", "C"}}, + {"array-with-whitespace", ` [ "A" , "B" ] `, []string{"A", "B"}}, + {"array-with-escapes", `["\t","\n"]`, []string{"\t", "\n"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := ParseJSONStringList([]byte(tc.in)) + if err != nil { + t.Fatalf("ParseJSONStringList(%s) error = %v", tc.in, err) + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("ParseJSONStringList(%s) = %v, want %v", tc.in, got, tc.want) + } + }) + } +} + +func TestParseJSONStringList_Invalid(t *testing.T) { + cases := []string{ + "", + " ", + `{`, + `}`, + `"unterminated`, + `[`, + `["unterminated`, + `["A"`, + `["A",]`, + `[123]`, + `tru`, + } + for _, in := range cases { + t.Run(in, func(t *testing.T) { + _, err := ParseJSONStringList([]byte(in)) + if err == nil { + t.Fatalf("ParseJSONStringList(%q) returned nil error, want error", in) + } + }) + } +} + +func TestParseJSONString_FastPath(t *testing.T) { + data := []byte(`"hello world"`) + s, next, err := ParseJSONString(data, 0) + if err != nil { + t.Fatalf("ParseJSONString error = %v", err) + } + if s != "hello world" { + t.Fatalf("got %q want hello world", s) + } + if next != len(data) { + t.Fatalf("next = %d want %d", next, len(data)) + } +} + +func TestParseJSONString_Escapes(t *testing.T) { + cases := []struct { + in string + want string + }{ + {`"\""`, `"`}, + {`"\\"`, `\`}, + {`"\/"`, "/"}, + {`"\b"`, "\b"}, + {`"\f"`, "\f"}, + {`"\n"`, "\n"}, + {`"\r"`, "\r"}, + {`"\t"`, "\t"}, + {`"A"`, "A"}, + {`"é"`, "é"}, + } + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + s, _, err := ParseJSONString([]byte(tc.in), 0) + if err != nil { + t.Fatalf("ParseJSONString(%s) error = %v", tc.in, err) + } + if s != tc.want { + t.Fatalf("got %q want %q", s, tc.want) + } + }) + } +} + +func TestParseJSONInt(t *testing.T) { + cases := []struct { + in string + want int64 + }{ + {`0`, 0}, + {`1`, 1}, + {`-1`, -1}, + {`123456789`, 123456789}, + {`-987`, -987}, + } + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + n, _, err := ParseJSONInt([]byte(tc.in), 0) + if err != nil { + t.Fatalf("ParseJSONInt(%s) error = %v", tc.in, err) + } + if n != tc.want { + t.Fatalf("got %d want %d", n, tc.want) + } + }) + } +} + +func TestParseJSONInt_Invalid(t *testing.T) { + cases := []string{"", "-", "a", "+1"} + for _, in := range cases { + t.Run(in, func(t *testing.T) { + _, _, err := ParseJSONInt([]byte(in), 0) + if err == nil { + t.Fatalf("ParseJSONInt(%q) returned nil error, want error", in) + } + }) + } +} + +func TestParseJSONBool(t *testing.T) { + v, next, err := ParseJSONBool([]byte(`true`), 0) + if err != nil || v != true || next != 4 { + t.Fatalf("true: v=%v next=%d err=%v", v, next, err) + } + v, next, err = ParseJSONBool([]byte(`false`), 0) + if err != nil || v != false || next != 5 { + t.Fatalf("false: v=%v next=%d err=%v", v, next, err) + } + _, _, err = ParseJSONBool([]byte(`tru`), 0) + if err == nil { + t.Fatalf("ParseJSONBool(tru) returned nil error") + } +} + +func TestIsJSONNull(t *testing.T) { + if !IsJSONNull([]byte(`null`), 0) { + t.Fatalf("expected null match") + } + if IsJSONNull([]byte(`nul`), 0) { + t.Fatalf("expected no match on nul") + } + if IsJSONNull([]byte(`xnull`), 0) { + t.Fatalf("expected no match on xnull") + } +} + +func TestSkipJSONValue(t *testing.T) { + cases := []struct { + in string + want int + }{ + {`null`, 4}, + {`true`, 4}, + {`false`, 5}, + {`"abc"`, 5}, + {`123`, 3}, + {`-1.5e3`, 6}, + {`{}`, 2}, + {`[]`, 2}, + {`{"a":1}`, 7}, + {`["a","b"]`, 9}, + {`{"a":[1,2,{"b":"c"}]}`, 21}, + } + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + next, err := SkipJSONValue([]byte(tc.in), 0) + if err != nil { + t.Fatalf("SkipJSONValue(%s) error = %v", tc.in, err) + } + if next != tc.want { + t.Fatalf("got %d want %d", next, tc.want) + } + }) + } +} + +func TestMatchObjectAndArrayStart(t *testing.T) { + i, err := MatchObjectStart([]byte(` {`), 0) + if err != nil || i != 3 { + t.Fatalf("MatchObjectStart: i=%d err=%v", i, err) + } + i, err = MatchArrayStart([]byte(` [`), 0) + if err != nil || i != 3 { + t.Fatalf("MatchArrayStart: i=%d err=%v", i, err) + } + _, err = MatchObjectStart([]byte(`123`), 0) + if err == nil { + t.Fatalf("expected error on non-object") + } +} + +func TestSkipJSONString(t *testing.T) { + cases := []struct { + in string + want int + }{ + {`"abc"`, 5}, + {`""`, 2}, + {`"a\nb"`, 6}, + {`"a\"b"`, 6}, + {`"a\\b"`, 6}, + {`"aÿb"`, 6}, // ÿ is 2 UTF-8 bytes inside the quotes + } + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + next, err := SkipJSONString([]byte(tc.in), 0) + if err != nil { + t.Fatalf("SkipJSONString(%s) error = %v", tc.in, err) + } + if next != tc.want { + t.Fatalf("got %d want %d", next, tc.want) + } + }) + } +} + +func TestParseJSONFloat(t *testing.T) { + v, _, err := ParseJSONFloat32([]byte(`0.7`), 0) + if err != nil || v != 0.7 { + t.Fatalf("ParseJSONFloat32(0.7): v=%v err=%v", v, err) + } + v, _, err = ParseJSONFloat32([]byte(`-1.5e2`), 0) + if err != nil || v != -150 { + t.Fatalf("ParseJSONFloat32(-1.5e2): v=%v err=%v", v, err) + } + d, _, err := ParseJSONFloat64([]byte(`3.14`), 0) + if err != nil || d != 3.14 { + t.Fatalf("ParseJSONFloat64(3.14): d=%v err=%v", d, err) + } +} + +func TestCountJSONArrayElements(t *testing.T) { + cases := []struct { + in string + want int + }{ + {`]`, 0}, + {`1]`, 1}, + {`1,2,3]`, 3}, + {`"a","b"]`, 2}, + {`{"x":1},{"y":2}]`, 2}, + {`[1,2],[3]]`, 2}, + } + for _, tc := range cases { + t.Run(tc.in, func(t *testing.T) { + got := CountJSONArrayElements([]byte(tc.in), 0) + if got != tc.want { + t.Fatalf("got %d want %d", got, tc.want) + } + }) + } +} + +func TestParseJSONStringRaw(t *testing.T) { + b, next, err := ParseJSONStringRaw([]byte(`"hello"`), 0) + if err != nil || string(b) != "hello" || next != 7 { + t.Fatalf("ParseJSONStringRaw fast path: b=%q next=%d err=%v", b, next, err) + } + b, next, err = ParseJSONStringRaw([]byte(`"a\nb"`), 0) + if err != nil || string(b) != "a\nb" || next != 6 { + t.Fatalf("ParseJSONStringRaw escape path: b=%q next=%d err=%v", b, next, err) + } +} diff --git a/go/jsonenc/jsonenc.go b/go/jsonenc/jsonenc.go new file mode 100644 index 0000000..e6eb15d --- /dev/null +++ b/go/jsonenc/jsonenc.go @@ -0,0 +1,201 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package jsonenc provides hand-rolled JSON-encoding primitives +// shared across the inference adapter hot paths (openai, anthropic, +// ollama). The encoding/json reflect path allocates an encoder state +// machine and a grow-doubled output buffer on every Marshal call — +// each adapter encoder that fires per-request or per-streamed-token +// pays that floor. These primitives let per-shape encoders land at a +// single buffer allocation per call. +// +// Provenance: lifted in W9-Z from three byte-identical copies that +// shipped in W9-D (openai), W9-E (anthropic), and W9-G (ollama). The +// canonical fast-path uses anthropic's two-function split (W9-E) for +// AppendJSONString — a single forward scan followed by a single bulk +// append when no escape is needed; a separate tail-walker handles +// the escape-bearing case. Same minimax lift as state/filestore's +// encodeRecordMeta (W8-D) and core.ParseHeaderRefs (W8-I/K). +// +// The output is valid JSON and parseable both by encoding/json +// (round-trips into the same Go types) and by any naive JSON walker. +// All callers share the same escape contract — quote, backslash, +// b/f/n/r/t mnemonics, and \u00XX for other control chars below 0x20. +// Bytes >= 0x20 outside the quote/backslash pair pass through verbatim; +// encoding/json's default also escapes <, >, & for HTML safety but the +// adapters built on this package do not emit into HTML contexts. +// +// Encoders are exported as standalone Append* functions rather than +// MarshalJSON methods. encoding/json.Marshal validates and recopies +// the bytes returned by MarshalJSON — for top-level marshals that +// erases the win. Consumers on the hot path call the Append* entry +// points directly. +package jsonenc + +import "strconv" + +// AppendJSONString appends a JSON-encoded string to buf — opening +// quote, escaped body, closing quote. Caller is responsible for +// providing the surrounding context (key, comma, etc). +// +// buf = jsonenc.AppendJSONString(buf, "answer") // -> "answer" +// +// Escapes: \" \\ \b \f \n \r \t for the mnemonic forms and \u00XX +// for other bytes < 0x20. All other bytes pass through. +// +// Fast path: scan for any character requiring an escape. Adapter +// message bodies overwhelmingly contain neither — once a hot prefix +// passes the scan, we copy the whole string verbatim in one append. +// On the rare escape-bearing path we drop back to the byte-by-byte +// walk starting from the first hit. The split keeps the fast path +// inlineable. +func AppendJSONString(buf []byte, s string) []byte { + buf = append(buf, '"') + // Scan for the first byte that needs escaping. \" \\ and any + // byte < 0x20 all require special handling; everything else + // passes through. + for i := 0; i < len(s); i++ { + c := s[i] + if c == '"' || c == '\\' || c < 0x20 { + // Bulk-copy the safe prefix, then walk the rest. + buf = append(buf, s[:i]...) + return appendJSONStringEscaped(buf, s[i:]) + } + } + // No escapes — single bulk append covers the whole body. + buf = append(buf, s...) + return append(buf, '"') +} + +// appendJSONStringEscaped completes a string already opened with `"` +// and that has at least one byte requiring escape treatment in s[0]. +// Internal helper for AppendJSONString — separated out to keep the +// fast-path inlineable. +func appendJSONStringEscaped(buf []byte, s string) []byte { + for i := 0; i < len(s); i++ { + c := s[i] + switch { + case c == '"': + buf = append(buf, '\\', '"') + case c == '\\': + buf = append(buf, '\\', '\\') + case c == '\b': + buf = append(buf, '\\', 'b') + case c == '\f': + buf = append(buf, '\\', 'f') + case c == '\n': + buf = append(buf, '\\', 'n') + case c == '\r': + buf = append(buf, '\\', 'r') + case c == '\t': + buf = append(buf, '\\', 't') + case c < 0x20: + buf = append(buf, '\\', 'u', '0', '0', HexChar(c>>4), HexChar(c&0x0f)) + default: + buf = append(buf, c) + } + } + return append(buf, '"') +} + +// AppendStringField appends a `"key":"value"` pair (optionally +// prefixed with a leading comma) to buf. Key is treated as an ASCII +// literal — wire-schema keys carry no escapes by construction. +// +// buf = jsonenc.AppendStringField(buf, "model", req.Model, false) +// buf = jsonenc.AppendStringField(buf, "id", id, true) // leading comma +func AppendStringField(buf []byte, key, value string, leadingComma bool) []byte { + if leadingComma { + buf = append(buf, ',') + } + buf = append(buf, '"') + buf = append(buf, key...) + buf = append(buf, '"', ':') + return AppendJSONString(buf, value) +} + +// AppendIntField appends a `"key":N` pair (optionally prefixed with a +// leading comma) where N is the base-10 representation of value. +// +// buf = jsonenc.AppendIntField(buf, "index", 0, true) +func AppendIntField(buf []byte, key string, value int, leadingComma bool) []byte { + if leadingComma { + buf = append(buf, ',') + } + buf = append(buf, '"') + buf = append(buf, key...) + buf = append(buf, '"', ':') + return strconv.AppendInt(buf, int64(value), 10) +} + +// AppendInt64Field appends a `"key":N` pair for an int64. +// +// buf = jsonenc.AppendInt64Field(buf, "total_duration", 1_500_000_000, true) +func AppendInt64Field(buf []byte, key string, value int64, leadingComma bool) []byte { + if leadingComma { + buf = append(buf, ',') + } + buf = append(buf, '"') + buf = append(buf, key...) + buf = append(buf, '"', ':') + return strconv.AppendInt(buf, value, 10) +} + +// AppendBoolField appends a `"key":true` or `"key":false` pair. +// +// buf = jsonenc.AppendBoolField(buf, "stream", req.Stream, true) +func AppendBoolField(buf []byte, key string, value, leadingComma bool) []byte { + if leadingComma { + buf = append(buf, ',') + } + buf = append(buf, '"') + buf = append(buf, key...) + buf = append(buf, '"', ':') + if value { + return append(buf, 't', 'r', 'u', 'e') + } + return append(buf, 'f', 'a', 'l', 's', 'e') +} + +// AppendFloat32Field appends a `"key":F` pair where F is rendered in +// the same 'g' format encoding/json emits for float32 (bitSize 32). +// +// buf = jsonenc.AppendFloat32Field(buf, "temperature", *req.Temperature, true) +func AppendFloat32Field(buf []byte, key string, value float32, leadingComma bool) []byte { + if leadingComma { + buf = append(buf, ',') + } + buf = append(buf, '"') + buf = append(buf, key...) + buf = append(buf, '"', ':') + return strconv.AppendFloat(buf, float64(value), 'g', -1, 32) +} + +// AppendFloat32 appends a bare float32 value (no key, no comma) in +// the same shape json.Marshal emits — 'g' format, bitSize 32. Used +// for array-element emission (per-element embedding vectors) where +// the caller drives commas and surrounding context. +// +// buf = jsonenc.AppendFloat32(buf, v) +func AppendFloat32(buf []byte, value float32) []byte { + return strconv.AppendFloat(buf, float64(value), 'g', -1, 32) +} + +// AppendFloat64 appends a bare float64 value in the same shape +// json.Marshal emits — 'g' format, bitSize 64. +// +// buf = jsonenc.AppendFloat64(buf, score.Score) +func AppendFloat64(buf []byte, value float64) []byte { + return strconv.AppendFloat(buf, value, 'g', -1, 64) +} + +// HexChar returns the ASCII hex digit for the low nibble of v. Used +// by AppendJSONString's \u00XX escape branch; exported so adapter +// packages can reuse the same byte-to-hex contract when they emit +// their own escape paths (e.g. URI-encoded fields). +func HexChar(v byte) byte { + v &= 0x0f + if v < 10 { + return '0' + v + } + return 'a' + (v - 10) +} diff --git a/go/jsonenc/jsonenc_test.go b/go/jsonenc/jsonenc_test.go new file mode 100644 index 0000000..031997c --- /dev/null +++ b/go/jsonenc/jsonenc_test.go @@ -0,0 +1,191 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package jsonenc + +import ( + "encoding/json" + "strconv" + "testing" +) + +// TestAppendJSONString_RoundTrip pins the escape contract of +// AppendJSONString against encoding/json's encoder. Every byte class +// (mnemonic escapes, \u00XX controls, plain ASCII, multi-byte UTF-8) +// must round-trip identically. +func TestAppendJSONString_RoundTrip(t *testing.T) { + cases := []struct { + name string + input string + }{ + {"empty", ""}, + {"plain_ASCII", "answer"}, + {"quote", `say "hi"`}, + {"backslash", `path\to\file`}, + {"mnemonics", "\b\f\n\r\t"}, + {"control_low", "\x01\x02\x1f"}, + {"utf8", "café — résumé"}, + {"mixed", "line1\n\"quote\"\tend"}, + {"long_clean", "the quick brown fox jumps over the lazy dog — repeated bulk-copy fast-path"}, + {"escape_at_end", "clean prefix then\\"}, + {"escape_at_start", "\"quoted prefix"}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := string(AppendJSONString(nil, tc.input)) + want, err := json.Marshal(tc.input) + if err != nil { + t.Fatalf("json.Marshal(%q) error: %v", tc.input, err) + } + // encoding/json HTML-escapes <, >, &; AppendJSONString + // does not. None of the cases above exercise that branch, + // so direct compare holds. + if got != string(want) { + t.Fatalf("AppendJSONString(%q):\n got = %s\nwant = %s", tc.input, got, want) + } + var parsed string + if err := json.Unmarshal([]byte(got), &parsed); err != nil { + t.Fatalf("Unmarshal(%s): %v", got, err) + } + if parsed != tc.input { + t.Fatalf("round-trip drift:\n got = %q\nwant = %q", parsed, tc.input) + } + }) + } +} + +// TestAppendJSONString_AppendsToExisting verifies the primitive +// appends without clobbering the leading bytes — load-bearing for +// the per-shape encoders that pre-populate `{"key":` before calling. +func TestAppendJSONString_AppendsToExisting(t *testing.T) { + buf := []byte(`{"key":`) + buf = AppendJSONString(buf, "value") + if got, want := string(buf), `{"key":"value"`; got != want { + t.Fatalf("append-onto: got %s want %s", got, want) + } +} + +// TestAppendStringField verifies the `"key":"value"` shape with and +// without leading comma. +func TestAppendStringField(t *testing.T) { + buf := AppendStringField(nil, "model", "qwen3", false) + if got, want := string(buf), `"model":"qwen3"`; got != want { + t.Fatalf("no-comma: got %s want %s", got, want) + } + buf = AppendStringField(nil, "role", "assistant", true) + if got, want := string(buf), `,"role":"assistant"`; got != want { + t.Fatalf("leading-comma: got %s want %s", got, want) + } + // Escape contract carries through. + buf = AppendStringField(nil, "content", "line1\n\"q\"", false) + if got, want := string(buf), `"content":"line1\n\"q\""`; got != want { + t.Fatalf("escapes: got %s want %s", got, want) + } +} + +// TestAppendIntField verifies the `"key":N` shape. +func TestAppendIntField(t *testing.T) { + buf := AppendIntField(nil, "index", 0, false) + if got, want := string(buf), `"index":0`; got != want { + t.Fatalf("int zero: got %s want %s", got, want) + } + buf = AppendIntField(nil, "count", 256, true) + if got, want := string(buf), `,"count":256`; got != want { + t.Fatalf("int with comma: got %s want %s", got, want) + } + buf = AppendIntField(nil, "neg", -1, false) + if got, want := string(buf), `"neg":-1`; got != want { + t.Fatalf("int negative: got %s want %s", got, want) + } +} + +// TestAppendInt64Field covers wide int64 values that duration fields +// use (nanoseconds, easily >2^31). +func TestAppendInt64Field(t *testing.T) { + buf := AppendInt64Field(nil, "total_duration", 1_500_000_000, false) + if got, want := string(buf), `"total_duration":1500000000`; got != want { + t.Fatalf("int64: got %s want %s", got, want) + } + buf = AppendInt64Field(nil, "max", 1<<62, true) + if got, want := string(buf), `,"max":`+strconv.FormatInt(1<<62, 10); got != want { + t.Fatalf("int64 large: got %s want %s", got, want) + } +} + +// TestAppendBoolField pins the Done-flag emission shape used by +// every per-token streaming chunk. +func TestAppendBoolField(t *testing.T) { + buf := AppendBoolField(nil, "done", true, false) + if got, want := string(buf), `"done":true`; got != want { + t.Fatalf("bool true: got %s want %s", got, want) + } + buf = AppendBoolField(nil, "done", false, true) + if got, want := string(buf), `,"done":false`; got != want { + t.Fatalf("bool false: got %s want %s", got, want) + } +} + +// TestAppendFloat32Field verifies the inline `"key":F` form used by +// sampling parameters (temperature, top_p). +func TestAppendFloat32Field(t *testing.T) { + buf := AppendFloat32Field(nil, "temperature", 0.7, false) + if got, want := string(buf), `"temperature":0.7`; got != want { + t.Fatalf("float32 field: got %s want %s", got, want) + } + buf = AppendFloat32Field(nil, "top_p", 0.95, true) + if got, want := string(buf), `,"top_p":0.95`; got != want { + t.Fatalf("float32 field with comma: got %s want %s", got, want) + } +} + +// TestAppendFloat32 verifies the bare-value emission shape used for +// embedding vector elements. +func TestAppendFloat32(t *testing.T) { + cases := []struct { + in float32 + want string + }{ + {0.7, "0.7"}, + {0.95, "0.95"}, + {1.0, "1"}, + {0.0001, "0.0001"}, + {2.0, "2"}, + } + for _, tc := range cases { + got := string(AppendFloat32(nil, tc.in)) + if got != tc.want { + t.Fatalf("float32(%v): got %s want %s", tc.in, got, tc.want) + } + } +} + +// TestAppendFloat64 verifies the bare-value emission shape used for +// score / probability outputs. +func TestAppendFloat64(t *testing.T) { + got := string(AppendFloat64(nil, 0.12345)) + if got != "0.12345" { + t.Fatalf("float64: got %s want 0.12345", got) + } +} + +// TestHexChar covers the nibble-to-ASCII contract used by the +// \u00XX escape branch. +func TestHexChar(t *testing.T) { + cases := []struct { + in byte + want byte + }{ + {0, '0'}, + {9, '9'}, + {10, 'a'}, + {15, 'f'}, + // High nibble masked off — only low 4 bits matter. + {0xF0, '0'}, + {0xFF, 'f'}, + } + for _, tc := range cases { + got := HexChar(tc.in) + if got != tc.want { + t.Fatalf("HexChar(%#x): got %q want %q", tc.in, got, tc.want) + } + } +} diff --git a/go/model/pack/manifest.go b/go/model/pack/manifest.go new file mode 100644 index 0000000..15e02f4 --- /dev/null +++ b/go/model/pack/manifest.go @@ -0,0 +1,161 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package pack wraps an unpacked model pack (the directory shape walked by +// inference.ModelPackInspector) into a Trix container with magic "MDL1", +// and round-trips back to disk. +// +// Container layout (delegated to forge.lthn.ai/Snider/Enchantrix/pkg/trix): +// +// [Magic "MDL1" (4)] [Version (1)] [Header Length (4)] [JSON Header] [Payload] +// +// Header is the JSON-marshalled Manifest. Payload is a deterministic tar of the +// source pack directory, optionally followed by an embedded vindex blob at the +// offset/length declared in Manifest.Vindex. +// +// r := pack.Pack(c, "/path/to/gemma-4-26b-a4b-it", "out.model", pack.PackOptions{}) +// if !r.OK { return r } +package pack + +import ( + iofs "io/fs" + + "dappco.re/go/inference" +) + +// Magic is the 4-byte Trix magic for a .model container. +const Magic = "MDL1" + +// Manifest is the JSON header carried inside a .model Trix container. +// It mirrors the shape of inference.ModelPackInspection for the contained +// pack, plus packaging-specific metadata (lineage, vindex placement, +// producer attribution, signatures). +type Manifest struct { + // Model is the portable model identity for the contained pack. + Model inference.ModelIdentity `json:"model"` + + // Tokenizer is the portable tokenizer identity for the contained pack. + Tokenizer inference.TokenizerIdentity `json:"tokenizer"` + + // SourceFormat names the on-disk shape of the model bytes inside the + // payload tar — currently "safetensors" or "gguf". + SourceFormat string `json:"source_format"` + + // Capabilities are the per-pack capabilities reported by the inspector. + Capabilities []inference.Capability `json:"capabilities,omitempty"` + + // Lineage points back at the source .train this .model was derived + // from. Optional — top-level training runs derived without a prior + // .train may omit it. + Lineage *Lineage `json:"lineage,omitempty"` + + // Vindex describes an embedded LARQL vindex blob. When Vindex is nil + // the .model carries only the model pack tar; LQL operations that + // require a vindex must EXTRACT one first. + Vindex *VindexRef `json:"vindex,omitempty"` + + // Producer records who emitted the .model. + Producer Producer `json:"producer"` + + // Signatures are detached signatures over the payload bytes. + // Verification is handled at consumer layer; this package only + // round-trips the slice. + Signatures []Signature `json:"signatures,omitempty"` +} + +// Lineage records the source .train file the .model was derived from. +type Lineage struct { + TrainURI string `json:"train_uri"` + TrainSHA string `json:"train_sha,omitempty"` +} + +// VindexRef points at an embedded vindex blob inside the payload. +type VindexRef struct { + // Embedded is always true for .model files where Vindex != nil — the + // flag exists so external readers don't need to introspect Offset/Length + // to know whether to expect a payload-side vindex. + Embedded bool `json:"embedded"` + + // Offset is the byte offset (within the Trix payload) at which the + // vindex blob starts. + Offset uint64 `json:"offset"` + + // Length is the vindex blob length in bytes. + Length uint64 `json:"length"` + + // Format names the vindex serialisation. "msgpack" is the LARQL + // .larql.bin form. + Format string `json:"format,omitempty"` +} + +// Producer records the tool that emitted the .model. +type Producer struct { + Name string `json:"name"` + Commit string `json:"commit,omitempty"` + Created string `json:"created"` // RFC3339 UTC +} + +// Signature is a detached signature over the Trix payload bytes. +type Signature struct { + KeyID string `json:"key_id"` + Alg string `json:"alg"` // e.g. "ed25519" + Sig string `json:"sig"` // base64 standard encoding +} + +// PackOptions controls Pack behaviour. +type PackOptions struct { + // Manifest is the manifest to embed in the Trix header. If + // Manifest.Producer.Created is empty, Pack fills it with the current + // UTC RFC3339 timestamp. + Manifest Manifest + + // VindexBlob, when non-nil, requests an embedded vindex. NOT yet + // implemented — passing a non-nil value causes Pack to return an + // explicit "vindex embedding not yet implemented" Result so the seam + // is honest rather than silently dropping the blob. + VindexBlob []byte +} + +// UnpackOptions controls Unpack behaviour. +type UnpackOptions struct { + // Overwrite allows Unpack to write into a non-empty destination dir. + // Default false — Unpack refuses if the destination already contains + // files. + Overwrite bool +} + +// Entry is one tar entry inside a .model payload — the shape List +// returns. Path, Size, and Mode are surfaced; content is not read. +type Entry struct { + Path string `json:"path"` + Size int64 `json:"size"` + Mode iofs.FileMode `json:"mode"` +} + +// IdentityFingerprint is the deterministic identity projection of a +// Manifest — the subset of fields that, together, mean "these two .model +// files describe the same logical model artefact". Timestamps, signatures, +// and lineage URIs are deliberately excluded — they are provenance, not +// identity. +type IdentityFingerprint struct { + Model inference.ModelIdentity `json:"model"` + Tokenizer inference.TokenizerIdentity `json:"tokenizer"` + SourceFormat string `json:"source_format"` + Capabilities []inference.Capability `json:"capabilities,omitempty"` + VindexHash string `json:"vindex_hash,omitempty"` +} + +// Identity returns the identity projection of this Manifest — the +// fields that decide "is this the same logical model?". +// +// id := manifest.Identity() +// _ = id.Model.Architecture +func (m Manifest) Identity() IdentityFingerprint { + return IdentityFingerprint{ + Model: m.Model, + Tokenizer: m.Tokenizer, + SourceFormat: m.SourceFormat, + Capabilities: m.Capabilities, + // VindexHash left empty until vindex embedding lands and the + // hash of the embedded blob is known at fingerprint time. + } +} diff --git a/go/model/pack/pack.go b/go/model/pack/pack.go new file mode 100644 index 0000000..211c2c3 --- /dev/null +++ b/go/model/pack/pack.go @@ -0,0 +1,476 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package pack + +import ( + "archive/tar" + "bytes" + "crypto/sha256" + "encoding/hex" + "io" + iofs "io/fs" + "sort" + "time" + + core "dappco.re/go" + "dappco.re/go/inference" + + "forge.lthn.ai/Snider/Enchantrix/pkg/trix" +) + +// Pack reads an unpacked model pack at srcDir and writes a .model Trix +// container to dest. Payload is a deterministic tar of srcDir contents. +// Manifest is embedded as the Trix header. +// +// r := pack.Pack("/models/gemma-4-26b-a4b-it", "out.model", pack.PackOptions{ +// Manifest: pack.Manifest{ +// Model: inference.ModelIdentity{Architecture: "gemma-4-26b-a4b-it", QuantBits: 4}, +// Tokenizer: inference.TokenizerIdentity{Kind: "sentencepiece"}, +// SourceFormat: "safetensors", +// Producer: pack.Producer{Name: "go-mlx"}, +// }, +// }) +// if !r.OK { return r } +func Pack(srcDir, dest string, opts PackOptions) core.Result { + if !dirExists(srcDir) { + return core.Fail(core.E("pack.Pack", core.Sprintf("srcDir %q is not a directory", srcDir), nil)) + } + if opts.VindexBlob != nil { + return core.Fail(core.E("pack.Pack", "vindex embedding not yet implemented", nil)) + } + + manifest := opts.Manifest + if manifest.Producer.Created == "" { + manifest.Producer.Created = time.Now().UTC().Format(time.RFC3339) + } + if manifest.Model.Hash == "" { + // Auto-populate the canonical pack hash so consumers never + // see a .model with an empty Model.Hash. Caller can pre-fill + // it to skip this step when a cached value is already known. + h, hr := Hash(srcDir) + if !hr.OK { + return hr + } + manifest.Model.Hash = h + } + + tarBytes, tr := buildTar(srcDir) + if !tr.OK { + return tr + } + + headerMap, hr := manifestToHeaderMap(manifest) + if !hr.OK { + return hr + } + + container := &trix.Trix{ + Header: headerMap, + Payload: tarBytes, + } + + encoded, err := trix.Encode(container, Magic, nil) + if err != nil { + return core.Fail(core.E("pack.Pack", "trix.Encode failed", err)) + } + + if wr := core.WriteFile(dest, encoded, 0o644); !wr.OK { + return wr + } + return core.Ok(nil) +} + +// Unpack reads a .model Trix container at src and writes its contained +// model pack to destDir. destDir must not exist, must be empty, or +// UnpackOptions.Overwrite must be true. +// +// r := pack.Unpack("out.model", "/tmp/extracted", pack.UnpackOptions{}) +// if !r.OK { return r } +func Unpack(src, destDir string, opts UnpackOptions) core.Result { + rr := core.ReadFile(src) + if !rr.OK { + return rr + } + data := rr.Value.([]byte) + + container, err := trix.Decode(data, Magic, nil) + if err != nil { + return core.Fail(core.E("pack.Unpack", "trix.Decode failed", err)) + } + + if dr := assertDestDirWritable(destDir, opts.Overwrite); !dr.OK { + return dr + } + if mr := core.MkdirAll(destDir, 0o755); !mr.OK { + return mr + } + return extractTar(container.Payload, destDir) +} + +// List reads a .model Trix container and returns the payload tar's +// entries (path, size, mode) without extracting file contents. Useful +// for tree-view UI without paying the full extract cost. +// +// entries, manifest, r := pack.List("gemma.model") +// if !r.OK { return r } +// for _, e := range entries { core.Println(e.Path) } +func List(src string) ([]Entry, *Manifest, core.Result) { + rr := core.ReadFile(src) + if !rr.OK { + return nil, nil, rr + } + data := rr.Value.([]byte) + + container, err := trix.Decode(data, Magic, nil) + if err != nil { + return nil, nil, core.Fail(core.E("pack.List", "trix.Decode failed", err)) + } + + manifest, mr := headerMapToManifest(container.Header) + if !mr.OK { + return nil, nil, mr + } + + tr := tar.NewReader(bytes.NewReader(container.Payload)) + var entries []Entry + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return nil, nil, core.Fail(core.E("pack.List", "tar.Next failed", err)) + } + if hdr.Typeflag != tar.TypeReg { + continue + } + entries = append(entries, Entry{ + Path: hdr.Name, + Size: hdr.Size, + Mode: iofs.FileMode(hdr.Mode), + }) + } + return entries, manifest, core.Ok(nil) +} + +// Inspect reads a .model Trix container header (no payload extraction) +// and returns the Manifest plus a synthesised ModelPackInspection. +// +// manifest, inspection, r := pack.Inspect("out.model") +// if !r.OK { return r } +// core.Println(inspection.Model.Architecture) +func Inspect(src string) (*Manifest, *inference.ModelPackInspection, core.Result) { + rr := core.ReadFile(src) + if !rr.OK { + return nil, nil, rr + } + data := rr.Value.([]byte) + + container, err := trix.Decode(data, Magic, nil) + if err != nil { + return nil, nil, core.Fail(core.E("pack.Inspect", "trix.Decode failed", err)) + } + + manifest, mr := headerMapToManifest(container.Header) + if !mr.OK { + return nil, nil, mr + } + + inspection := &inference.ModelPackInspection{ + Path: src, + Format: manifest.SourceFormat, + Model: manifest.Model, + Tokenizer: manifest.Tokenizer, + Supported: true, + Capabilities: manifest.Capabilities, + } + return manifest, inspection, core.Ok(nil) +} + +// Hash computes the canonical model-pack hash for an unwrapped pack +// directory: SHA-256 of sorted content of the small metadata files +// (config.json, tokenizer.json, chat_template.jinja, adapter_config.json) +// concatenated with sorted file sizes of the *.safetensors blobs. +// +// Lightweight — doesn't read tensor bytes. Captures everything that +// affects behaviour without forcing a full content scan. Mirrors the +// shape inference.ModelPackInspector reads on the go-mlx side, so the +// hash from a packed .model and the hash from re-running InspectModelPack +// on the unwrapped dir agree byte-for-byte. +// +// h, r := pack.Hash("/models/gemma-3-4b-it") +// if !r.OK { return r } +// manifest.Model.Hash = h +// +// Missing optional files (chat_template.jinja, adapter_config.json) are +// simply skipped — their absence is part of the pack's identity. +func Hash(srcDir string) (string, core.Result) { + if !dirExists(srcDir) { + return "", core.Fail(core.E("pack.Hash", core.Sprintf("srcDir %q is not a directory", srcDir), nil)) + } + + metaCandidates := []string{ + "config.json", + "tokenizer.json", + "chat_template.jinja", + "adapter_config.json", + } + type metaFile struct { + name string + content []byte + } + var metas []metaFile + fs := (&core.Fs{}).NewUnrestricted() + for _, name := range metaCandidates { + path := core.JoinPath(srcDir, name) + if !fs.IsFile(path) { + continue + } + rr := core.ReadFile(path) + if !rr.OK { + return "", rr + } + metas = append(metas, metaFile{name: name, content: rr.Value.([]byte)}) + } + sort.Slice(metas, func(i, j int) bool { return metas[i].name < metas[j].name }) + + var safetensorSizes []int64 + for e, err := range fs.WalkSeq(srcDir) { + if err != nil { + return "", core.Fail(core.E("pack.Hash", "walk failed", err)) + } + if e.IsDir { + continue + } + if !core.HasSuffix(e.Path, ".safetensors") { + continue + } + statR := core.Stat(core.JoinPath(srcDir, e.Path)) + if !statR.OK { + return "", statR + } + info, ok := statR.Value.(iofs.FileInfo) + if !ok { + return "", core.Fail(core.E("pack.Hash", core.Sprintf("unexpected Stat shape for %q", e.Path), nil)) + } + safetensorSizes = append(safetensorSizes, info.Size()) + } + sort.Slice(safetensorSizes, func(i, j int) bool { return safetensorSizes[i] < safetensorSizes[j] }) + + h := sha256.New() + for _, m := range metas { + h.Write([]byte(m.name)) + h.Write([]byte{0}) + h.Write(m.content) + h.Write([]byte{0}) + } + h.Write([]byte("safetensors_sizes")) + h.Write([]byte{0}) + var sizeBuf [8]byte + for _, sz := range safetensorSizes { + u := uint64(sz) + for i := 0; i < 8; i++ { + sizeBuf[i] = byte(u >> (8 * i)) + } + h.Write(sizeBuf[:]) + } + return hex.EncodeToString(h.Sum(nil)), core.Ok(nil) +} + +// Fingerprint returns the SHA-256 hex digest of a Manifest's Identity +// projection. Stable across machines and across re-packs of the same +// logical model. Useful for "is this the same logical artefact?" without +// reading the payload. +// +// if pack.Fingerprint(a) == pack.Fingerprint(b) { /* same logical model */ } +func Fingerprint(m Manifest) string { + r := core.JSONMarshal(m.Identity()) + if !r.OK { + return "" + } + sum := sha256.Sum256(r.Value.([]byte)) + return hex.EncodeToString(sum[:]) +} + +// buildTar walks srcDir and produces a deterministic tar of all regular +// files. Entries are sorted by relative path; timestamps, uid/gid are +// zeroed so byte output is reproducible for identical input trees. +func buildTar(srcDir string) ([]byte, core.Result) { + fs := (&core.Fs{}).NewUnrestricted() + + type entry struct { + rel string + abs string + mode iofs.FileMode + } + var entries []entry + for e, err := range fs.WalkSeq(srcDir) { + if err != nil { + return nil, core.Fail(core.E("pack.buildTar", "walk failed", err)) + } + if e.IsDir { + continue + } + entries = append(entries, entry{ + rel: e.Path, + abs: core.JoinPath(srcDir, e.Path), + mode: e.Mode, + }) + } + + sort.Slice(entries, func(i, j int) bool { return entries[i].rel < entries[j].rel }) + + var buf bytes.Buffer + tw := tar.NewWriter(&buf) + for _, e := range entries { + rr := core.ReadFile(e.abs) + if !rr.OK { + return nil, rr + } + content := rr.Value.([]byte) + + hdr := &tar.Header{ + Name: e.rel, + Mode: int64(e.mode.Perm()), + Size: int64(len(content)), + Typeflag: tar.TypeReg, + } + if err := tw.WriteHeader(hdr); err != nil { + return nil, core.Fail(core.E("pack.buildTar", core.Sprintf("write header for %q", e.rel), err)) + } + if _, err := tw.Write(content); err != nil { + return nil, core.Fail(core.E("pack.buildTar", core.Sprintf("write content for %q", e.rel), err)) + } + } + if err := tw.Close(); err != nil { + return nil, core.Fail(core.E("pack.buildTar", "tar.Close failed", err)) + } + return buf.Bytes(), core.Ok(nil) +} + +// extractTar reads a tar stream and writes each regular-file entry under +// destDir. Path-traversal entries (containing "..") are rejected. +func extractTar(payload []byte, destDir string) core.Result { + tr := tar.NewReader(bytes.NewReader(payload)) + for { + hdr, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return core.Fail(core.E("pack.extractTar", "tar.Next failed", err)) + } + if hdr.Typeflag != tar.TypeReg { + continue + } + if !safeRelPath(hdr.Name) { + return core.Fail(core.E("pack.extractTar", core.Sprintf("unsafe entry path %q", hdr.Name), nil)) + } + out := core.JoinPath(destDir, hdr.Name) + if mr := core.MkdirAll(core.PathDir(out), 0o755); !mr.OK { + return mr + } + content := make([]byte, hdr.Size) + if _, err := io.ReadFull(tr, content); err != nil { + return core.Fail(core.E("pack.extractTar", core.Sprintf("read content for %q", hdr.Name), err)) + } + if wr := core.WriteFile(out, content, iofs.FileMode(hdr.Mode)); !wr.OK { + return wr + } + } + return core.Ok(nil) +} + +// manifestToHeaderMap marshals a Manifest to JSON and back into a +// map[string]interface{} suitable for trix.Trix.Header. +func manifestToHeaderMap(m Manifest) (map[string]interface{}, core.Result) { + jr := core.JSONMarshal(m) + if !jr.OK { + return nil, jr + } + data := jr.Value.([]byte) + var out map[string]interface{} + if ur := core.JSONUnmarshal(data, &out); !ur.OK { + return nil, ur + } + return out, core.Ok(nil) +} + +// headerMapToManifest is the inverse — marshals the Trix header map back +// to JSON, then unmarshals into a typed Manifest. +func headerMapToManifest(h map[string]interface{}) (*Manifest, core.Result) { + jr := core.JSONMarshal(h) + if !jr.OK { + return nil, jr + } + data := jr.Value.([]byte) + var out Manifest + if ur := core.JSONUnmarshal(data, &out); !ur.OK { + return nil, ur + } + return &out, core.Ok(nil) +} + +// dirExists reports whether p exists and is a directory. +func dirExists(p string) bool { + fs := (&core.Fs{}).NewUnrestricted() + return fs.IsDir(p) +} + +// assertDestDirWritable returns a failing Result if destDir exists, is a +// directory, contains entries, and overwrite is false. Missing destDir is +// fine (caller MkdirAll's it). +func assertDestDirWritable(destDir string, overwrite bool) core.Result { + fs := (&core.Fs{}).NewUnrestricted() + if !fs.Exists(destDir) { + return core.Ok(nil) + } + if !fs.IsDir(destDir) { + return core.Fail(core.E("pack.Unpack", core.Sprintf("destDir %q exists but is not a directory", destDir), nil)) + } + if overwrite { + return core.Ok(nil) + } + lr := fs.List(destDir) + if !lr.OK { + return lr + } + if entries, ok := lr.Value.([]iofs.DirEntry); ok && len(entries) > 0 { + return core.Fail(core.E("pack.Unpack", core.Sprintf("destDir %q is not empty (set UnpackOptions.Overwrite to allow)", destDir), nil)) + } + return core.Ok(nil) +} + +// safeRelPath rejects tar entries that would escape the destination via +// path traversal or absolute paths. +func safeRelPath(p string) bool { + if p == "" || core.HasPrefix(p, "/") { + return false + } + // Reject any ".." segment — guards against tar slip vulnerabilities. + for _, seg := range splitSegments(p) { + if seg == ".." { + return false + } + } + return true +} + +// splitSegments splits a slash-separated path into its segments without +// importing path/filepath or strings. +func splitSegments(p string) []string { + var out []string + start := 0 + for i := 0; i < len(p); i++ { + if p[i] == '/' { + if i > start { + out = append(out, p[start:i]) + } + start = i + 1 + } + } + if start < len(p) { + out = append(out, p[start:]) + } + return out +} diff --git a/go/model/pack/pack_example_test.go b/go/model/pack/pack_example_test.go new file mode 100644 index 0000000..84a08f0 --- /dev/null +++ b/go/model/pack/pack_example_test.go @@ -0,0 +1,59 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package pack_test + +import ( + core "dappco.re/go" + "dappco.re/go/inference" + "dappco.re/go/inference/model/pack" +) + +// ExamplePack shows how to wrap an unpacked safetensors pack into a +// .model Trix container. +func ExamplePack() { + r := pack.Pack( + "/tmp/gemma-3-4b-it", + "/tmp/gemma-3-4b-it.model", + pack.PackOptions{ + Manifest: pack.Manifest{ + Model: inference.ModelIdentity{ + ID: "google/gemma-3-4b-it", + Architecture: "gemma", + QuantBits: 4, + }, + Tokenizer: inference.TokenizerIdentity{ + Kind: "sentencepiece", + }, + SourceFormat: "safetensors", + Producer: pack.Producer{Name: "go-mlx"}, + }, + }, + ) + if !r.OK { + _ = r.Value + } +} + +// ExampleUnpack shows how to extract a .model back into a directory. +func ExampleUnpack() { + r := pack.Unpack( + "/tmp/gemma-3-4b-it.model", + "/tmp/extracted", + pack.UnpackOptions{}, + ) + if !r.OK { + _ = r.Value + } +} + +// ExampleInspect shows how to read only the .model header and synthesise +// an inference.ModelPackInspection without extracting the payload. +func ExampleInspect() { + manifest, inspection, r := pack.Inspect("/tmp/gemma-3-4b-it.model") + if !r.OK { + return + } + _ = manifest.Producer.Name + _ = inspection.Model.Architecture + _ = core.Sprintf("inspected %s", inspection.Path) +} diff --git a/go/model/pack/pack_test.go b/go/model/pack/pack_test.go new file mode 100644 index 0000000..2a6044c --- /dev/null +++ b/go/model/pack/pack_test.go @@ -0,0 +1,666 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package pack_test + +import ( + "crypto/sha256" + "encoding/hex" + iofs "io/fs" + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" + "dappco.re/go/inference/model/pack" +) + +// fixtureFile is one synthetic file written into the fixture pack dir. +type fixtureFile struct { + relPath string + content []byte + mode iofs.FileMode +} + +// buildFixturePack writes a small but realistic Gemma-4-shaped pack into +// dir — config.json + tokenizer.json + chat_template.jinja + a small +// model.safetensors with a valid header. Tests use this as the round-trip +// source. +func buildFixturePack(t *testing.T, dir string, extras ...fixtureFile) { + t.Helper() + + if mr := core.MkdirAll(dir, 0o755); !mr.OK { + t.Fatalf("MkdirAll %q: %v", dir, mr.Value) + } + + defaults := []fixtureFile{ + { + relPath: "config.json", + content: []byte(`{"model_type":"gemma","architectures":["GemmaForCausalLM"],"hidden_size":2304,"num_hidden_layers":26,"num_attention_heads":8,"vocab_size":262144}`), + mode: 0o644, + }, + { + relPath: "tokenizer.json", + content: []byte(`{"version":"1.0","tokenizer":{"type":"sentencepiece"},"bos_token":"","eos_token":""}`), + mode: 0o644, + }, + { + relPath: "chat_template.jinja", + content: []byte(`{% for m in messages %}{{m.role}}: {{m.content}}{% endfor %}`), + mode: 0o644, + }, + { + relPath: "model.safetensors", + content: synthSafetensors(), + mode: 0o644, + }, + } + + for _, ff := range append(defaults, extras...) { + path := core.JoinPath(dir, ff.relPath) + if dirPath := core.PathDir(path); dirPath != dir { + if mr := core.MkdirAll(dirPath, 0o755); !mr.OK { + t.Fatalf("MkdirAll %q: %v", dirPath, mr.Value) + } + } + if wr := core.WriteFile(path, ff.content, ff.mode); !wr.OK { + t.Fatalf("WriteFile %q: %v", path, wr.Value) + } + } +} + +// synthSafetensors emits a valid-shape safetensors file: 8-byte little- +// endian header length + JSON header + zero-byte tensor payload. Loader +// won't read tensors so empty payload is fine. +func synthSafetensors() []byte { + header := []byte(`{"__metadata__":{"format":"pt"}}`) + // 8-byte little-endian length prefix + out := make([]byte, 8+len(header)) + n := uint64(len(header)) + for i := 0; i < 8; i++ { + out[i] = byte(n >> (8 * i)) + } + copy(out[8:], header) + return out +} + +// fileTreeHash returns a single SHA-256 over a sorted (relPath || sha256(content)) +// of every regular file under dir, suitable for byte-level tree equality +// assertions. +func fileTreeHash(t *testing.T, dir string) string { + t.Helper() + fs := (&core.Fs{}).NewUnrestricted() + type entry struct { + rel string + hash [32]byte + } + var entries []entry + for e, err := range fs.WalkSeq(dir) { + if err != nil { + t.Fatalf("WalkSeq %q: %v", dir, err) + } + if e.IsDir { + continue + } + rr := core.ReadFile(core.JoinPath(dir, e.Path)) + if !rr.OK { + t.Fatalf("ReadFile %q: %v", e.Path, rr.Value) + } + entries = append(entries, entry{ + rel: e.Path, + hash: sha256.Sum256(rr.Value.([]byte)), + }) + } + // Sort + for i := 0; i < len(entries); i++ { + for j := i + 1; j < len(entries); j++ { + if entries[j].rel < entries[i].rel { + entries[i], entries[j] = entries[j], entries[i] + } + } + } + h := sha256.New() + for _, e := range entries { + h.Write([]byte(e.rel)) + h.Write([]byte{0}) + h.Write(e.hash[:]) + h.Write([]byte{0}) + } + return hex.EncodeToString(h.Sum(nil)) +} + +func sampleManifest() pack.Manifest { + return pack.Manifest{ + Model: inference.ModelIdentity{ + ID: "google/gemma-3-4b-it", + Architecture: "gemma", + QuantBits: 4, + ContextLength: 8192, + NumLayers: 26, + HiddenSize: 2304, + VocabSize: 262144, + }, + Tokenizer: inference.TokenizerIdentity{ + Kind: "sentencepiece", + ChatTemplate: "gemma", + }, + SourceFormat: "safetensors", + Producer: pack.Producer{ + Name: "go-mlx", + Commit: "abc123", + }, + } +} + +func TestPack_Roundtrip_Good(t *testing.T) { + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-roundtrip-good-") + defer core.RemoveAll(tempRoot) + + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + outDir := core.JoinPath(tempRoot, "out") + + buildFixturePack(t, srcDir) + srcHash := fileTreeHash(t, srcDir) + + if r := pack.Pack(srcDir, dest, pack.PackOptions{Manifest: sampleManifest()}); !r.OK { + t.Fatalf("Pack: %v", r.Value) + } + + // Verify dest starts with Trix magic "MDL1". + data := readBytes(t, dest) + if string(data[:4]) != pack.Magic { + t.Fatalf("expected magic %q at offset 0, got %q", pack.Magic, string(data[:4])) + } + + if r := pack.Unpack(dest, outDir, pack.UnpackOptions{}); !r.OK { + t.Fatalf("Unpack: %v", r.Value) + } + outHash := fileTreeHash(t, outDir) + + if srcHash != outHash { + t.Fatalf("file tree hash mismatch:\n src: %s\n out: %s", srcHash, outHash) + } +} + +func TestPack_Inspect_Good(t *testing.T) { + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-inspect-good-") + defer core.RemoveAll(tempRoot) + + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + + buildFixturePack(t, srcDir) + + if r := pack.Pack(srcDir, dest, pack.PackOptions{Manifest: sampleManifest()}); !r.OK { + t.Fatalf("Pack: %v", r.Value) + } + + manifest, inspection, r := pack.Inspect(dest) + if !r.OK { + t.Fatalf("Inspect: %v", r.Value) + } + if manifest.Model.Architecture != "gemma" { + t.Errorf("expected Architecture gemma, got %q", manifest.Model.Architecture) + } + if manifest.Model.QuantBits != 4 { + t.Errorf("expected QuantBits 4, got %d", manifest.Model.QuantBits) + } + if manifest.SourceFormat != "safetensors" { + t.Errorf("expected SourceFormat safetensors, got %q", manifest.SourceFormat) + } + if manifest.Producer.Created == "" { + t.Errorf("expected Producer.Created to be auto-filled, was empty") + } + if inspection.Path != dest { + t.Errorf("expected inspection.Path %q, got %q", dest, inspection.Path) + } + if inspection.Format != "safetensors" { + t.Errorf("expected inspection.Format safetensors, got %q", inspection.Format) + } + if inspection.Model.Architecture != "gemma" { + t.Errorf("expected inspection.Model.Architecture gemma, got %q", inspection.Model.Architecture) + } +} + +func TestPack_Roundtrip_Bad(t *testing.T) { + // Truncated .model file must return a failing Result, never panic. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-bad-") + defer core.RemoveAll(tempRoot) + + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + outDir := core.JoinPath(tempRoot, "out") + + buildFixturePack(t, srcDir) + + if r := pack.Pack(srcDir, dest, pack.PackOptions{Manifest: sampleManifest()}); !r.OK { + t.Fatalf("Pack: %v", r.Value) + } + + // Truncate dest to half its size — payload is now corrupt. + full := readBytes(t, dest) + half := full[:len(full)/2] + if wr := core.WriteFile(dest, half, 0o644); !wr.OK { + t.Fatalf("WriteFile (truncate): %v", wr.Value) + } + + r := pack.Unpack(dest, outDir, pack.UnpackOptions{}) + if r.OK { + t.Fatalf("expected Unpack to fail on truncated input, got OK") + } +} + +func TestPack_Roundtrip_Ugly(t *testing.T) { + // Unusual but valid file names — spaces and unicode — must round-trip + // intact. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-ugly-") + defer core.RemoveAll(tempRoot) + + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + outDir := core.JoinPath(tempRoot, "out") + + extras := []fixtureFile{ + {relPath: "notes with spaces.txt", content: []byte("hello"), mode: 0o644}, + {relPath: "papierość.bin", content: []byte{0x00, 0x01, 0x02, 0xFF}, mode: 0o644}, + {relPath: "subdir/nested.json", content: []byte(`{"k":"v"}`), mode: 0o644}, + } + buildFixturePack(t, srcDir, extras...) + srcHash := fileTreeHash(t, srcDir) + + if r := pack.Pack(srcDir, dest, pack.PackOptions{Manifest: sampleManifest()}); !r.OK { + t.Fatalf("Pack: %v", r.Value) + } + if r := pack.Unpack(dest, outDir, pack.UnpackOptions{}); !r.OK { + t.Fatalf("Unpack: %v", r.Value) + } + if outHash := fileTreeHash(t, outDir); outHash != srcHash { + t.Fatalf("ugly tree hash mismatch:\n src: %s\n out: %s", srcHash, outHash) + } +} + +func TestPack_VindexOption_Bad(t *testing.T) { + // Seam-honesty: VindexBlob != nil must return an explicit + // "not yet implemented" failure so callers know the embedding seam + // exists but isn't wired. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-vindex-bad-") + defer core.RemoveAll(tempRoot) + + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + + buildFixturePack(t, srcDir) + + r := pack.Pack(srcDir, dest, pack.PackOptions{ + Manifest: sampleManifest(), + VindexBlob: []byte("not real msgpack but non-nil"), + }) + if r.OK { + t.Fatalf("expected Pack to fail when VindexBlob is non-nil, got OK") + } +} + +func TestPack_List_Good(t *testing.T) { + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-list-good-") + defer core.RemoveAll(tempRoot) + + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + + buildFixturePack(t, srcDir) + + if r := pack.Pack(srcDir, dest, pack.PackOptions{Manifest: sampleManifest()}); !r.OK { + t.Fatalf("Pack: %v", r.Value) + } + + entries, manifest, r := pack.List(dest) + if !r.OK { + t.Fatalf("List: %v", r.Value) + } + if manifest.SourceFormat != "safetensors" { + t.Errorf("expected manifest.SourceFormat safetensors, got %q", manifest.SourceFormat) + } + + want := map[string]bool{ + "config.json": false, + "tokenizer.json": false, + "chat_template.jinja": false, + "model.safetensors": false, + } + for _, e := range entries { + if _, ok := want[e.Path]; !ok { + t.Errorf("unexpected entry %q", e.Path) + continue + } + want[e.Path] = true + if e.Size <= 0 { + t.Errorf("entry %q has non-positive size %d", e.Path, e.Size) + } + } + for name, seen := range want { + if !seen { + t.Errorf("expected entry %q not present in List output", name) + } + } +} + +func TestPack_List_Bad(t *testing.T) { + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-list-bad-") + defer core.RemoveAll(tempRoot) + + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + + buildFixturePack(t, srcDir) + if r := pack.Pack(srcDir, dest, pack.PackOptions{Manifest: sampleManifest()}); !r.OK { + t.Fatalf("Pack: %v", r.Value) + } + + full := readBytes(t, dest) + if wr := core.WriteFile(dest, full[:len(full)/2], 0o644); !wr.OK { + t.Fatalf("WriteFile (truncate): %v", wr.Value) + } + + if _, _, r := pack.List(dest); r.OK { + t.Fatalf("expected List to fail on truncated input, got OK") + } +} + +func TestPack_Deterministic_Good(t *testing.T) { + // Same source tree + same Manifest (Producer.Created pinned) must + // produce byte-identical .model output, twice in a row. The property + // `.model` is content-addressable depends on it: same input → same + // SHA-256 → cache hits, lineage chains, registry dedup all work. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-deterministic-good-") + defer core.RemoveAll(tempRoot) + + srcDir := core.JoinPath(tempRoot, "src") + dest1 := core.JoinPath(tempRoot, "out1.model") + dest2 := core.JoinPath(tempRoot, "out2.model") + + buildFixturePack(t, srcDir, fixtureFile{ + relPath: "extras/zeta.bin", + content: []byte("trailing-entry-to-stress-sort-order"), + mode: 0o644, + }, fixtureFile{ + relPath: "extras/alpha.bin", + content: []byte("leading-entry-to-stress-sort-order"), + mode: 0o644, + }) + + manifest := sampleManifest() + manifest.Producer.Created = "2026-01-01T00:00:00Z" // pin so the only delta source is the algorithm itself + + if r := pack.Pack(srcDir, dest1, pack.PackOptions{Manifest: manifest}); !r.OK { + t.Fatalf("Pack #1: %v", r.Value) + } + if r := pack.Pack(srcDir, dest2, pack.PackOptions{Manifest: manifest}); !r.OK { + t.Fatalf("Pack #2: %v", r.Value) + } + + b1 := readBytes(t, dest1) + b2 := readBytes(t, dest2) + + h1 := sha256.Sum256(b1) + h2 := sha256.Sum256(b2) + + if hex.EncodeToString(h1[:]) != hex.EncodeToString(h2[:]) { + t.Fatalf("Pack non-deterministic:\n size1=%d sha=%s\n size2=%d sha=%s\nFirst differing byte index: %d", + len(b1), hex.EncodeToString(h1[:]), + len(b2), hex.EncodeToString(h2[:]), + firstDiffIndex(b1, b2), + ) + } +} + +func firstDiffIndex(a, b []byte) int { + n := len(a) + if len(b) < n { + n = len(b) + } + for i := 0; i < n; i++ { + if a[i] != b[i] { + return i + } + } + if len(a) != len(b) { + return n + } + return -1 +} + +func TestPack_Fingerprint_TimestampOrthogonal_Good(t *testing.T) { + // Two manifests differing only in Producer.Created (provenance) + + // Lineage (provenance) + Signatures (orthogonal) must produce the + // same identity fingerprint. + a := sampleManifest() + a.Producer.Created = "2026-01-01T00:00:00Z" + a.Lineage = &pack.Lineage{TrainURI: "file:///a.train", TrainSHA: "deadbeef"} + a.Signatures = []pack.Signature{{KeyID: "k1", Alg: "ed25519", Sig: "sigA"}} + + b := sampleManifest() + b.Producer.Created = "2027-06-15T12:34:56Z" + b.Producer.Commit = "different-commit" + b.Lineage = &pack.Lineage{TrainURI: "file:///somewhere/else.train", TrainSHA: "beefcafe"} + b.Signatures = []pack.Signature{{KeyID: "k2", Alg: "ed25519", Sig: "sigB"}} + + if pack.Fingerprint(a) != pack.Fingerprint(b) { + t.Fatalf("expected fingerprints equal under provenance-only delta:\n a=%s\n b=%s", + pack.Fingerprint(a), pack.Fingerprint(b)) + } +} + +func TestPack_Fingerprint_IdentityDelta_Ugly(t *testing.T) { + // Each identity-shaping field, varied independently, must change the + // fingerprint. If any of these doesn't change it, identity has a hole. + base := sampleManifest() + baseFP := pack.Fingerprint(base) + + cases := []struct { + name string + mutate func(*pack.Manifest) + }{ + {"Model.Architecture", func(m *pack.Manifest) { m.Model.Architecture = "llama" }}, + {"Model.QuantBits", func(m *pack.Manifest) { m.Model.QuantBits = 8 }}, + {"Model.NumLayers", func(m *pack.Manifest) { m.Model.NumLayers = 99 }}, + {"Model.VocabSize", func(m *pack.Manifest) { m.Model.VocabSize = 100000 }}, + {"Tokenizer.Kind", func(m *pack.Manifest) { m.Tokenizer.Kind = "gpt2-bpe" }}, + {"Tokenizer.ChatTemplate", func(m *pack.Manifest) { m.Tokenizer.ChatTemplate = "llama" }}, + {"SourceFormat", func(m *pack.Manifest) { m.SourceFormat = "gguf" }}, + } + for _, tc := range cases { + m := sampleManifest() + tc.mutate(&m) + got := pack.Fingerprint(m) + if got == baseFP { + t.Errorf("mutating %s did not change fingerprint (still %s)", tc.name, got) + } + } +} + +func TestPack_Fingerprint_HexShape_Good(t *testing.T) { + // Sanity: fingerprint is hex sha256 (64 chars, lower-case hex). + fp := pack.Fingerprint(sampleManifest()) + if len(fp) != 64 { + t.Errorf("expected 64-char fingerprint, got %d (%q)", len(fp), fp) + } + for _, r := range fp { + switch { + case r >= '0' && r <= '9': + case r >= 'a' && r <= 'f': + default: + t.Errorf("non-hex character %q in fingerprint %q", r, fp) + } + } +} + +func TestPack_Hash_Stable_Good(t *testing.T) { + // Same source dir hashed twice must return identical hex. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-hash-stable-") + defer core.RemoveAll(tempRoot) + srcDir := core.JoinPath(tempRoot, "src") + buildFixturePack(t, srcDir) + + h1, r1 := pack.Hash(srcDir) + if !r1.OK { + t.Fatalf("Hash (#1): %v", r1.Value) + } + h2, r2 := pack.Hash(srcDir) + if !r2.OK { + t.Fatalf("Hash (#2): %v", r2.Value) + } + if h1 != h2 { + t.Fatalf("expected stable hash, got %s vs %s", h1, h2) + } + if len(h1) != 64 { + t.Fatalf("expected 64-char hex, got %d (%q)", len(h1), h1) + } +} + +func TestPack_Hash_DistinguishesContent_Ugly(t *testing.T) { + // Pack A and Pack B share filenames but config.json differs. + // Hash must differ. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-hash-distinct-") + defer core.RemoveAll(tempRoot) + srcA := core.JoinPath(tempRoot, "a") + srcB := core.JoinPath(tempRoot, "b") + buildFixturePack(t, srcA) + buildFixturePack(t, srcB) + + // Mutate B's config.json. + if wr := core.WriteFile(core.JoinPath(srcB, "config.json"), + []byte(`{"model_type":"llama","hidden_size":4096}`), 0o644); !wr.OK { + t.Fatalf("rewrite B config.json: %v", wr.Value) + } + + hA, ra := pack.Hash(srcA) + hB, rb := pack.Hash(srcB) + if !ra.OK || !rb.OK { + t.Fatalf("Hash A=%v B=%v", ra.Value, rb.Value) + } + if hA == hB { + t.Fatalf("expected different hashes for divergent config.json, both %s", hA) + } +} + +func TestPack_Hash_SafetensorsSizeAffects_Ugly(t *testing.T) { + // Same JSON files but different *.safetensors size — hash must differ. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-hash-st-size-") + defer core.RemoveAll(tempRoot) + srcA := core.JoinPath(tempRoot, "a") + srcB := core.JoinPath(tempRoot, "b") + buildFixturePack(t, srcA) + buildFixturePack(t, srcB) + + stPath := core.JoinPath(srcB, "model.safetensors") + rr := core.ReadFile(stPath) + if !rr.OK { + t.Fatalf("ReadFile B safetensors: %v", rr.Value) + } + larger := append(rr.Value.([]byte), make([]byte, 4096)...) + if wr := core.WriteFile(stPath, larger, 0o644); !wr.OK { + t.Fatalf("WriteFile B safetensors: %v", wr.Value) + } + + hA, _ := pack.Hash(srcA) + hB, _ := pack.Hash(srcB) + if hA == hB { + t.Fatalf("expected different hashes for divergent safetensors size, both %s", hA) + } +} + +func TestPack_Hash_OptionalFilesSkippedCleanly_Good(t *testing.T) { + // Pack A has chat_template.jinja; Pack B doesn't. Hash differs but + // neither errors out. Missing optional files are part of identity. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-hash-optional-") + defer core.RemoveAll(tempRoot) + srcA := core.JoinPath(tempRoot, "a") + srcB := core.JoinPath(tempRoot, "b") + buildFixturePack(t, srcA) + + // Build B without chat_template.jinja by writing only the core 3 files. + if mr := core.MkdirAll(srcB, 0o755); !mr.OK { + t.Fatalf("MkdirAll: %v", mr.Value) + } + for _, name := range []string{"config.json", "tokenizer.json", "model.safetensors"} { + src := core.JoinPath(srcA, name) + dst := core.JoinPath(srcB, name) + rr := core.ReadFile(src) + if !rr.OK { + t.Fatalf("ReadFile %q: %v", name, rr.Value) + } + if wr := core.WriteFile(dst, rr.Value.([]byte), 0o644); !wr.OK { + t.Fatalf("WriteFile %q: %v", name, wr.Value) + } + } + + hA, ra := pack.Hash(srcA) + hB, rb := pack.Hash(srcB) + if !ra.OK || !rb.OK { + t.Fatalf("Hash A=%v B=%v", ra.Value, rb.Value) + } + if hA == hB { + t.Fatalf("expected different hashes (A has chat_template, B doesn't), both %s", hA) + } +} + +func TestPack_Hash_AutoPopulatedInPack_Good(t *testing.T) { + // Pack with empty Manifest.Model.Hash must auto-populate via Hash(srcDir). + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-hash-autofill-") + defer core.RemoveAll(tempRoot) + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + buildFixturePack(t, srcDir) + + m := sampleManifest() + m.Model.Hash = "" // explicit empty + + if r := pack.Pack(srcDir, dest, pack.PackOptions{Manifest: m}); !r.OK { + t.Fatalf("Pack: %v", r.Value) + } + + manifest, _, r := pack.Inspect(dest) + if !r.OK { + t.Fatalf("Inspect: %v", r.Value) + } + if manifest.Model.Hash == "" { + t.Fatalf("expected Manifest.Model.Hash auto-populated, was empty") + } + if len(manifest.Model.Hash) != 64 { + t.Errorf("expected 64-char hex hash, got %d (%q)", len(manifest.Model.Hash), manifest.Model.Hash) + } + + expected, _ := pack.Hash(srcDir) + if manifest.Model.Hash != expected { + t.Errorf("Pack auto-hash != Hash(srcDir):\n pack: %s\n helper: %s", manifest.Model.Hash, expected) + } +} + +func TestPack_Hash_RespectsCallerProvidedValue_Good(t *testing.T) { + // Pack with caller-set Manifest.Model.Hash must NOT overwrite. + tempRoot := (&core.Fs{}).NewUnrestricted().TempDir("pack-hash-respect-") + defer core.RemoveAll(tempRoot) + srcDir := core.JoinPath(tempRoot, "src") + dest := core.JoinPath(tempRoot, "out.model") + buildFixturePack(t, srcDir) + + m := sampleManifest() + m.Model.Hash = "deadbeef-caller-provided" + + if r := pack.Pack(srcDir, dest, pack.PackOptions{Manifest: m}); !r.OK { + t.Fatalf("Pack: %v", r.Value) + } + manifest, _, _ := pack.Inspect(dest) + if manifest.Model.Hash != "deadbeef-caller-provided" { + t.Errorf("Pack overwrote caller-provided Hash; got %q", manifest.Model.Hash) + } +} + +// readBytes is a small test helper that reads a file via core.ReadFile. +func readBytes(t *testing.T, path string) []byte { + t.Helper() + rr := core.ReadFile(path) + if !rr.OK { + t.Fatalf("ReadFile %q: %v", path, rr.Value) + } + return rr.Value.([]byte) +} diff --git a/go/ollama/chunkenc.go b/go/ollama/chunkenc.go new file mode 100644 index 0000000..681ffdf --- /dev/null +++ b/go/ollama/chunkenc.go @@ -0,0 +1,236 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled encoders for the Ollama wire shapes — ChatResponse, +// GenerateResponse, TagsResponse. Per-token cost matters: Ollama +// streams one ChatResponse or GenerateResponse JSON object per +// generated token on /api/chat and /api/generate respectively, so +// every per-shape encoder fires N times per generation. +// +// These encoders compose the shared jsonenc primitives at +// dappco.re/go/inference/jsonenc (W9-Z lift) and land at a single +// buffer allocation per call — same minimax lift as state/filestore's +// encodeRecordMeta (W8-D) and openai's chunkenc.go (W9-D). +// +// Note: encoders are exported as standalone Append* functions rather +// than MarshalJSON methods. encoding/json.Marshal validates and +// recopies the bytes returned by MarshalJSON — for top-level marshals +// that erases the win. Consumers on the hot path call Append* entry +// points directly; non-hot-path call sites can keep using +// core.JSONMarshalString. + +package ollama + +import "dappco.re/go/inference/jsonenc" + +// appendMessage walks one Message into buf. Both fields always +// emitted (no omitempty on Role/Content per the Ollama API +// contract). Used inline by AppendChatResponse rather than as a +// MarshalJSON method — see package note above. +// +// Wire shape: {"role":"X","content":"Y"} +func appendMessage(buf []byte, msg Message) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "role", msg.Role, false) + buf = jsonenc.AppendStringField(buf, "content", msg.Content, true) + return append(buf, '}') +} + +// AppendChatResponse walks a ChatResponse into buf. Fires per +// streamed NDJSON token (server side) — one of the two hottest +// encoders in the package. +// +// Field order matches the struct declaration: model, message, +// done, prompt_eval_count, eval_count, four duration fields. All +// five count/duration fields carry omitempty semantics matching +// the reflect-path behaviour (zero-int / zero-int64 suppressed). +// +// buf := AppendChatResponse(make([]byte, 0, chatResponseSize(resp)), resp) +func AppendChatResponse(buf []byte, resp ChatResponse) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "model", resp.Model, false) + buf = append(buf, ',', '"', 'm', 'e', 's', 's', 'a', 'g', 'e', '"', ':') + buf = appendMessage(buf, resp.Message) + buf = jsonenc.AppendBoolField(buf, "done", resp.Done, true) + if resp.PromptEvalCount != 0 { + buf = jsonenc.AppendIntField(buf, "prompt_eval_count", resp.PromptEvalCount, true) + } + if resp.EvalCount != 0 { + buf = jsonenc.AppendIntField(buf, "eval_count", resp.EvalCount, true) + } + if resp.TotalDuration != 0 { + buf = jsonenc.AppendInt64Field(buf, "total_duration", resp.TotalDuration, true) + } + if resp.LoadDuration != 0 { + buf = jsonenc.AppendInt64Field(buf, "load_duration", resp.LoadDuration, true) + } + if resp.PromptEvalDuration != 0 { + buf = jsonenc.AppendInt64Field(buf, "prompt_eval_duration", resp.PromptEvalDuration, true) + } + if resp.EvalDuration != 0 { + buf = jsonenc.AppendInt64Field(buf, "eval_duration", resp.EvalDuration, true) + } + return append(buf, '}') +} + +// chatResponseSize estimates the backing-buffer size for one +// ChatResponse so AppendChatResponse allocates once for the typical +// shape. Over-sizing inflates the make() allocation cost above what +// the reflect-path's tighter sizing pays; the estimate matches the +// actual wire-byte count closely. +// +// Fixed prefix: {"model":"X","message":{"role":"R","content":"C"},"done":bool} +// = 1 (open {) + 10 + len(Model) + 11 (",message":) + 24 + len(Role) + len(Content) + 13 (",done":false) + 1 (close }) +// = 60 + variable bytes +func chatResponseSize(resp ChatResponse) int { + size := 60 + len(resp.Model) + len(resp.Message.Role) + len(resp.Message.Content) + if resp.PromptEvalCount != 0 { + size += 25 + } + if resp.EvalCount != 0 { + size += 18 + } + if resp.TotalDuration != 0 { + size += 35 + } + if resp.LoadDuration != 0 { + size += 34 + } + if resp.PromptEvalDuration != 0 { + size += 41 + } + if resp.EvalDuration != 0 { + size += 34 + } + return size +} + +// AppendGenerateResponse walks a GenerateResponse into buf — the +// /api/generate per-NDJSON-token streaming shape. Same fields as +// ChatResponse minus the nested Message. +// +// buf := AppendGenerateResponse(make([]byte, 0, generateResponseSize(resp)), resp) +func AppendGenerateResponse(buf []byte, resp GenerateResponse) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "model", resp.Model, false) + buf = jsonenc.AppendStringField(buf, "response", resp.Response, true) + buf = jsonenc.AppendBoolField(buf, "done", resp.Done, true) + if resp.PromptEvalCount != 0 { + buf = jsonenc.AppendIntField(buf, "prompt_eval_count", resp.PromptEvalCount, true) + } + if resp.EvalCount != 0 { + buf = jsonenc.AppendIntField(buf, "eval_count", resp.EvalCount, true) + } + if resp.TotalDuration != 0 { + buf = jsonenc.AppendInt64Field(buf, "total_duration", resp.TotalDuration, true) + } + if resp.LoadDuration != 0 { + buf = jsonenc.AppendInt64Field(buf, "load_duration", resp.LoadDuration, true) + } + if resp.PromptEvalDuration != 0 { + buf = jsonenc.AppendInt64Field(buf, "prompt_eval_duration", resp.PromptEvalDuration, true) + } + if resp.EvalDuration != 0 { + buf = jsonenc.AppendInt64Field(buf, "eval_duration", resp.EvalDuration, true) + } + return append(buf, '}') +} + +// generateResponseSize estimates the GenerateResponse buffer. +// +// Fixed prefix: {"model":"X","response":"Y","done":bool} +// = 1 + 10+len(Model) + 14+len(Response) + 13 + 1 +func generateResponseSize(resp GenerateResponse) int { + size := 39 + len(resp.Model) + len(resp.Response) + if resp.PromptEvalCount != 0 { + size += 25 + } + if resp.EvalCount != 0 { + size += 18 + } + if resp.TotalDuration != 0 { + size += 35 + } + if resp.LoadDuration != 0 { + size += 34 + } + if resp.PromptEvalDuration != 0 { + size += 41 + } + if resp.EvalDuration != 0 { + size += 34 + } + return size +} + +// appendModelTag walks one ModelTag into buf — used inline by +// AppendTagsResponse. Three of the four fields carry omitempty +// (Model, ModifiedAt, Size); Name is always emitted. +func appendModelTag(buf []byte, tag ModelTag) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "name", tag.Name, false) + if tag.Model != "" { + buf = jsonenc.AppendStringField(buf, "model", tag.Model, true) + } + if tag.ModifiedAt != "" { + buf = jsonenc.AppendStringField(buf, "modified_at", tag.ModifiedAt, true) + } + if tag.Size != 0 { + buf = jsonenc.AppendInt64Field(buf, "size", tag.Size, true) + } + return append(buf, '}') +} + +// AppendTagsResponse walks a TagsResponse (/api/tags). Discovery +// hot path — fires once per client startup (open-webui pings this +// on every page load) and again on every model-list refresh. +// +// A nil Models slice emits as "models":null (matching encoding/json +// semantics for nil-slice fields); an empty []ModelTag{} emits as +// "models":[]. Downstream consumers (e.g. open-webui) treat both +// forms as "no models served" interchangeably, but the wire shape +// must remain consistent with the reflect-path output for proxy +// pass-through. +// +// buf := AppendTagsResponse(make([]byte, 0, tagsResponseSize(resp)), resp) +func AppendTagsResponse(buf []byte, resp TagsResponse) []byte { + buf = append(buf, '{', '"', 'm', 'o', 'd', 'e', 'l', 's', '"', ':') + if resp.Models == nil { + return append(buf, 'n', 'u', 'l', 'l', '}') + } + buf = append(buf, '[') + for i, tag := range resp.Models { + if i > 0 { + buf = append(buf, ',') + } + buf = appendModelTag(buf, tag) + } + return append(buf, ']', '}') +} + +// tagsResponseSize estimates the TagsResponse buffer. The +// "models":null variant emits 17 bytes; the slice variant grows +// per-tag. +func tagsResponseSize(resp TagsResponse) int { + if resp.Models == nil { + return 17 // {"models":null} + } + size := 13 // {"models":[]} + for i, tag := range resp.Models { + if i > 0 { + size++ + } + // {"name":"X" = 11 fixed + name + size += 11 + len(tag.Name) + if tag.Model != "" { + size += 11 + len(tag.Model) + } + if tag.ModifiedAt != "" { + size += 16 + len(tag.ModifiedAt) + } + if tag.Size != 0 { + size += 9 + 12 // "size":NNNNNNNNN + } + size++ // closing } + } + return size +} diff --git a/go/ollama/ollama.go b/go/ollama/ollama.go new file mode 100644 index 0000000..dd1eead --- /dev/null +++ b/go/ollama/ollama.go @@ -0,0 +1,159 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package ollama provides Ollama-compatible wire primitives over the shared +// inference contracts. +package ollama + +import "dappco.re/go/inference" + +const ( + DefaultChatPath = "/api/chat" + DefaultGeneratePath = "/api/generate" + DefaultTagsPath = "/api/tags" + DefaultShowPath = "/api/show" +) + +// Message is one Ollama chat turn. +type Message struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// Options carries Ollama generation options that map cleanly to inference. +type Options struct { + Temperature float32 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float32 `json:"top_p,omitempty"` + NumPredict int `json:"num_predict,omitempty"` +} + +// ChatRequest is the Ollama chat request shape. +type ChatRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream bool `json:"stream,omitempty"` + Options Options `json:"options,omitempty"` +} + +// GenerateRequest is the Ollama prompt-generation request shape. +type GenerateRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + Stream bool `json:"stream,omitempty"` + Options Options `json:"options,omitempty"` +} + +// ChatResponse is the Ollama chat response shape. +type ChatResponse struct { + Model string `json:"model"` + Message Message `json:"message"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + TotalDuration int64 `json:"total_duration,omitempty"` + LoadDuration int64 `json:"load_duration,omitempty"` + PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"` + EvalDuration int64 `json:"eval_duration,omitempty"` +} + +// GenerateResponse is the Ollama generate response shape. +type GenerateResponse struct { + Model string `json:"model"` + Response string `json:"response"` + Done bool `json:"done"` + PromptEvalCount int `json:"prompt_eval_count,omitempty"` + EvalCount int `json:"eval_count,omitempty"` + TotalDuration int64 `json:"total_duration,omitempty"` + LoadDuration int64 `json:"load_duration,omitempty"` + PromptEvalDuration int64 `json:"prompt_eval_duration,omitempty"` + EvalDuration int64 `json:"eval_duration,omitempty"` +} + +// ModelTag is one entry in /api/tags. +type ModelTag struct { + Name string `json:"name"` + Model string `json:"model,omitempty"` + ModifiedAt string `json:"modified_at,omitempty"` + Size int64 `json:"size,omitempty"` +} + +// TagsResponse is the /api/tags response shape. +type TagsResponse struct { + Models []ModelTag `json:"models"` +} + +// ShowRequest is the /api/show request shape. +type ShowRequest struct { + Model string `json:"model"` +} + +// ShowResponse is the /api/show response shape. +type ShowResponse struct { + License string `json:"license,omitempty"` + Modelfile string `json:"modelfile,omitempty"` + Parameters string `json:"parameters,omitempty"` + Template string `json:"template,omitempty"` + Details map[string]string `json:"details,omitempty"` +} + +// InferenceMessages converts Ollama messages into shared inference messages. +func InferenceMessages(messages []Message) []inference.Message { + out := make([]inference.Message, 0, len(messages)) + for _, msg := range messages { + out = append(out, inference.Message{Role: msg.Role, Content: msg.Content}) + } + return out +} + +// GenerateOptions converts Ollama options into inference options. +// +// Fused option — one closure captures the whole Options value and +// applies each set field in a single pass. The append cascade +// previously allocated one closure per With* call (up to 4); the +// fused form allocates the slice + a single closure capturing the +// (value-type) Options struct. +// +// The empty-Options case (all zero-valued fields) returns nil so +// callers paying inference.ApplyGenerateOpts skip a no-op closure +// invocation and we avoid the slice+closure allocs. +func GenerateOptions(options Options) []inference.GenerateOption { + if options.NumPredict <= 0 && options.Temperature == 0 && options.TopK <= 0 && options.TopP <= 0 { + return nil + } + return []inference.GenerateOption{func(c *inference.GenerateConfig) { + if options.NumPredict > 0 { + c.MaxTokens = options.NumPredict + } + if options.Temperature != 0 { + c.Temperature = options.Temperature + } + if options.TopK > 0 { + c.TopK = options.TopK + } + if options.TopP > 0 { + c.TopP = options.TopP + } + }} +} + +// NewChatResponse builds an Ollama chat response from metrics. +func NewChatResponse(model, text string, metrics inference.GenerateMetrics) ChatResponse { + return ChatResponse{ + Model: model, + Message: Message{Role: "assistant", Content: text}, + Done: true, + PromptEvalCount: metrics.PromptTokens, + EvalCount: metrics.GeneratedTokens, + } +} + +// NewGenerateResponse builds an Ollama generate response from metrics. +func NewGenerateResponse(model, text string, metrics inference.GenerateMetrics) GenerateResponse { + return GenerateResponse{ + Model: model, + Response: text, + Done: true, + PromptEvalCount: metrics.PromptTokens, + EvalCount: metrics.GeneratedTokens, + } +} diff --git a/go/ollama/ollama_bench_test.go b/go/ollama/ollama_bench_test.go new file mode 100644 index 0000000..fbe2e03 --- /dev/null +++ b/go/ollama/ollama_bench_test.go @@ -0,0 +1,459 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the Ollama-compatible wire primitives. Per AX-11 — +// every request handled by the /api/chat or /api/generate path runs +// JSON ingress/egress; InferenceMessages and GenerateOptions project +// the wire shape onto inference contracts on every served request, and +// the response constructors fire on every completion. +// +// Run: go test -bench='BenchmarkOllama' -benchtime=100ms -benchmem -run='^$' . + +package ollama + +import ( + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + ollamaSinkChatRequest ChatRequest + ollamaSinkChatResponse ChatResponse + ollamaSinkGenerateRequest GenerateRequest + ollamaSinkGenerateResponse GenerateResponse + ollamaSinkTagsResponse TagsResponse + ollamaSinkShowRequest ShowRequest + ollamaSinkShowResponse ShowResponse + ollamaSinkMessages []inference.Message + ollamaSinkOptions []inference.GenerateOption + ollamaSinkString string + ollamaSinkResult core.Result +) + +// --- Fixture builders --- + +// buildOllamaMessages builds a representative chat transcript of the +// requested turn count. Single-turn = user, multi-turn = alternating +// user/assistant. +func buildOllamaMessages(turns int) []Message { + out := make([]Message, 0, turns) + for i := 0; i < turns; i++ { + if i%2 == 0 { + out = append(out, Message{Role: "user", Content: "Summarise the paragraph in one sentence."}) + } else { + out = append(out, Message{Role: "assistant", Content: "The summary is concise and faithful to the original text."}) + } + } + return out +} + +func buildOllamaChatRequest(turns int) ChatRequest { + return ChatRequest{ + Model: "qwen3", + Messages: buildOllamaMessages(turns), + Stream: true, + Options: Options{Temperature: 0.7, TopK: 64, TopP: 0.95, NumPredict: 256}, + } +} + +func buildOllamaGenerateRequest() GenerateRequest { + return GenerateRequest{ + Model: "qwen3", + Prompt: "Summarise the paragraph in one sentence.", + Stream: true, + Options: Options{Temperature: 0.7, TopK: 64, TopP: 0.95, NumPredict: 256}, + } +} + +// --- JSON Marshal — request emission (client-side) --- + +func BenchmarkOllama_MarshalChatRequest_SingleTurn(b *testing.B) { + req := buildOllamaChatRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkOllama_MarshalChatRequest_FiveTurn(b *testing.B) { + req := buildOllamaChatRequest(5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkOllama_MarshalChatRequest_TwentyTurn(b *testing.B) { + req := buildOllamaChatRequest(20) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkOllama_MarshalGenerateRequest(b *testing.B) { + req := buildOllamaGenerateRequest() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(req) + } +} + +// --- JSON Marshal — response emission (server-side) --- + +func BenchmarkOllama_MarshalChatResponse(b *testing.B) { + resp := NewChatResponse("qwen3", "The summary is concise.", inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32}) + resp.TotalDuration = 1_500_000_000 + resp.LoadDuration = 100_000_000 + resp.PromptEvalDuration = 200_000_000 + resp.EvalDuration = 1_200_000_000 + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(resp) + } +} + +func BenchmarkOllama_MarshalGenerateResponse(b *testing.B) { + resp := NewGenerateResponse("qwen3", "The summary is concise.", inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32}) + resp.TotalDuration = 1_500_000_000 + resp.LoadDuration = 100_000_000 + resp.PromptEvalDuration = 200_000_000 + resp.EvalDuration = 1_200_000_000 + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(resp) + } +} + +// /api/tags listing — fired by ollama clients on every model-list +// discovery (e.g. open-webui startup). Three sizes — 1, 5, 20 models. + +func BenchmarkOllama_MarshalTagsResponse_OneModel(b *testing.B) { + resp := TagsResponse{Models: []ModelTag{ + {Name: "qwen3:latest", Model: "qwen3", ModifiedAt: "2026-05-21T10:00:00Z", Size: 4_500_000_000}, + }} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(resp) + } +} + +func BenchmarkOllama_MarshalTagsResponse_FiveModels(b *testing.B) { + resp := TagsResponse{Models: []ModelTag{ + {Name: "qwen3:latest", Model: "qwen3", Size: 4_500_000_000}, + {Name: "gemma3:4b", Model: "gemma3", Size: 2_300_000_000}, + {Name: "llama3:8b", Model: "llama3", Size: 4_700_000_000}, + {Name: "qwen2.5:14b", Model: "qwen2.5", Size: 8_900_000_000}, + {Name: "deepseek:7b", Model: "deepseek", Size: 4_100_000_000}, + }} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(resp) + } +} + +func BenchmarkOllama_MarshalTagsResponse_TwentyModels(b *testing.B) { + models := make([]ModelTag, 20) + for i := range models { + models[i] = ModelTag{ + Name: "model-bench:tag", + Model: "model-bench", + ModifiedAt: "2026-05-21T10:00:00Z", + Size: int64(4_000_000_000 + i*100_000_000), + } + } + resp := TagsResponse{Models: models} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkString = core.JSONMarshalString(resp) + } +} + +// --- JSON Unmarshal — request ingress (server-side) --- + +func BenchmarkOllama_UnmarshalChatRequest_SingleTurn(b *testing.B) { + body := core.JSONMarshalString(buildOllamaChatRequest(1)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req ChatRequest + ollamaSinkResult = core.JSONUnmarshalString(body, &req) + ollamaSinkChatRequest = req + } +} + +func BenchmarkOllama_UnmarshalChatRequest_FiveTurn(b *testing.B) { + body := core.JSONMarshalString(buildOllamaChatRequest(5)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req ChatRequest + ollamaSinkResult = core.JSONUnmarshalString(body, &req) + ollamaSinkChatRequest = req + } +} + +func BenchmarkOllama_UnmarshalChatRequest_TwentyTurn(b *testing.B) { + body := core.JSONMarshalString(buildOllamaChatRequest(20)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req ChatRequest + ollamaSinkResult = core.JSONUnmarshalString(body, &req) + ollamaSinkChatRequest = req + } +} + +func BenchmarkOllama_UnmarshalGenerateRequest(b *testing.B) { + body := core.JSONMarshalString(buildOllamaGenerateRequest()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req GenerateRequest + ollamaSinkResult = core.JSONUnmarshalString(body, &req) + ollamaSinkGenerateRequest = req + } +} + +// --- JSON Unmarshal — response ingestion (client-side) --- + +func BenchmarkOllama_UnmarshalChatResponse(b *testing.B) { + resp := NewChatResponse("qwen3", "The summary is concise.", inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32}) + body := core.JSONMarshalString(resp) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var r ChatResponse + ollamaSinkResult = core.JSONUnmarshalString(body, &r) + ollamaSinkChatResponse = r + } +} + +func BenchmarkOllama_UnmarshalGenerateResponse(b *testing.B) { + resp := NewGenerateResponse("qwen3", "The summary is concise.", inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32}) + body := core.JSONMarshalString(resp) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var r GenerateResponse + ollamaSinkResult = core.JSONUnmarshalString(body, &r) + ollamaSinkGenerateResponse = r + } +} + +func BenchmarkOllama_UnmarshalTagsResponse_FiveModels(b *testing.B) { + body := core.JSONMarshalString(TagsResponse{Models: []ModelTag{ + {Name: "qwen3:latest", Model: "qwen3", Size: 4_500_000_000}, + {Name: "gemma3:4b", Model: "gemma3", Size: 2_300_000_000}, + {Name: "llama3:8b", Model: "llama3", Size: 4_700_000_000}, + {Name: "qwen2.5:14b", Model: "qwen2.5", Size: 8_900_000_000}, + {Name: "deepseek:7b", Model: "deepseek", Size: 4_100_000_000}, + }}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var r TagsResponse + ollamaSinkResult = core.JSONUnmarshalString(body, &r) + ollamaSinkTagsResponse = r + } +} + +func BenchmarkOllama_UnmarshalShowRequest(b *testing.B) { + body := `{"model":"qwen3:latest"}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req ShowRequest + ollamaSinkResult = core.JSONUnmarshalString(body, &req) + ollamaSinkShowRequest = req + } +} + +// --- InferenceMessages — wire→internal conversion fired per request --- + +func BenchmarkOllama_InferenceMessages_SingleTurn(b *testing.B) { + messages := buildOllamaMessages(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkMessages = InferenceMessages(messages) + } +} + +func BenchmarkOllama_InferenceMessages_FiveTurn(b *testing.B) { + messages := buildOllamaMessages(5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkMessages = InferenceMessages(messages) + } +} + +func BenchmarkOllama_InferenceMessages_TwentyTurn(b *testing.B) { + messages := buildOllamaMessages(20) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkMessages = InferenceMessages(messages) + } +} + +// --- GenerateOptions — sampling-field projection per request --- + +func BenchmarkOllama_GenerateOptions_AllFieldsSet(b *testing.B) { + options := Options{Temperature: 0.7, TopK: 64, TopP: 0.95, NumPredict: 256} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkOptions = GenerateOptions(options) + } +} + +func BenchmarkOllama_GenerateOptions_NoFieldsSet(b *testing.B) { + options := Options{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkOptions = GenerateOptions(options) + } +} + +// --- Response constructors — fire once per non-streaming completion --- + +func BenchmarkOllama_NewChatResponse(b *testing.B) { + metrics := inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32} + text := "The summary is concise and faithful to the original text." + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkChatResponse = NewChatResponse("qwen3", text, metrics) + } +} + +func BenchmarkOllama_NewGenerateResponse(b *testing.B) { + metrics := inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32} + text := "The summary is concise and faithful to the original text." + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkGenerateResponse = NewGenerateResponse("qwen3", text, metrics) + } +} + +// --- Append* fast-path encoders --- +// +// These bench the direct-entry hand-rolled encoders consumers on the +// HTTP hot path should call (an in-tree serve handler reaching for +// AppendChatResponse rather than core.JSONMarshalString). Each +// bench is the consumer-facing measurement — the "real" win once +// the proxy/serve handler lifts off encoding/json. +// +// The pre-sized-buffer benches reuse a backing scratch buffer +// per-iteration to model the steady-state hot-loop case where the +// caller keeps a per-connection emission buffer. The make-each-call +// benches model the cold-path (one-shot non-streaming response). + +var ollamaSinkBuf []byte + +func BenchmarkOllama_AppendChatResponse_Streaming(b *testing.B) { + resp := NewChatResponse("qwen3", "tok", inference.GenerateMetrics{}) + resp.Message.Role = "" + resp.Done = false + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkBuf = AppendChatResponse(make([]byte, 0, chatResponseSize(resp)), resp) + } +} + +func BenchmarkOllama_AppendChatResponse_Final(b *testing.B) { + resp := NewChatResponse("qwen3", "The summary is concise.", inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32}) + resp.TotalDuration = 1_500_000_000 + resp.LoadDuration = 100_000_000 + resp.PromptEvalDuration = 200_000_000 + resp.EvalDuration = 1_200_000_000 + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkBuf = AppendChatResponse(make([]byte, 0, chatResponseSize(resp)), resp) + } +} + +func BenchmarkOllama_AppendGenerateResponse_Streaming(b *testing.B) { + resp := NewGenerateResponse("qwen3", "tok", inference.GenerateMetrics{}) + resp.Done = false + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkBuf = AppendGenerateResponse(make([]byte, 0, generateResponseSize(resp)), resp) + } +} + +func BenchmarkOllama_AppendGenerateResponse_Final(b *testing.B) { + resp := NewGenerateResponse("qwen3", "The summary is concise.", inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32}) + resp.TotalDuration = 1_500_000_000 + resp.LoadDuration = 100_000_000 + resp.PromptEvalDuration = 200_000_000 + resp.EvalDuration = 1_200_000_000 + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkBuf = AppendGenerateResponse(make([]byte, 0, generateResponseSize(resp)), resp) + } +} + +func BenchmarkOllama_AppendTagsResponse_OneModel(b *testing.B) { + resp := TagsResponse{Models: []ModelTag{ + {Name: "qwen3:latest", Model: "qwen3", ModifiedAt: "2026-05-21T10:00:00Z", Size: 4_500_000_000}, + }} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkBuf = AppendTagsResponse(make([]byte, 0, tagsResponseSize(resp)), resp) + } +} + +func BenchmarkOllama_AppendTagsResponse_FiveModels(b *testing.B) { + resp := TagsResponse{Models: []ModelTag{ + {Name: "qwen3:latest", Model: "qwen3", Size: 4_500_000_000}, + {Name: "gemma3:4b", Model: "gemma3", Size: 2_300_000_000}, + {Name: "llama3:8b", Model: "llama3", Size: 4_700_000_000}, + {Name: "qwen2.5:14b", Model: "qwen2.5", Size: 8_900_000_000}, + {Name: "deepseek:7b", Model: "deepseek", Size: 4_100_000_000}, + }} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkBuf = AppendTagsResponse(make([]byte, 0, tagsResponseSize(resp)), resp) + } +} + +func BenchmarkOllama_AppendTagsResponse_TwentyModels(b *testing.B) { + models := make([]ModelTag, 20) + for i := range models { + models[i] = ModelTag{ + Name: "model-bench:tag", + Model: "model-bench", + ModifiedAt: "2026-05-21T10:00:00Z", + Size: int64(4_000_000_000 + i*100_000_000), + } + } + resp := TagsResponse{Models: models} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ollamaSinkBuf = AppendTagsResponse(make([]byte, 0, tagsResponseSize(resp)), resp) + } +} + diff --git a/go/ollama/ollama_test.go b/go/ollama/ollama_test.go new file mode 100644 index 0000000..b40081a --- /dev/null +++ b/go/ollama/ollama_test.go @@ -0,0 +1,160 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package ollama + +import ( + "encoding/json" + "testing" + + "dappco.re/go/inference" +) + +func TestOllama_InferenceMessages_Good(t *testing.T) { + messages := InferenceMessages([]Message{{Role: "user", Content: "hi"}}) + + if len(messages) != 1 || messages[0].Role != "user" || messages[0].Content != "hi" { + t.Fatalf("messages = %+v", messages) + } +} + +func TestOllama_GenerateOptions_Good(t *testing.T) { + opts := GenerateOptions(Options{NumPredict: 12, Temperature: 0.4, TopK: 8, TopP: 0.7}) + + cfg := inference.ApplyGenerateOpts(opts) + if cfg.MaxTokens != 12 || cfg.Temperature != 0.4 || cfg.TopK != 8 || cfg.TopP != 0.7 { + t.Fatalf("cfg = %+v", cfg) + } +} + +func TestOllama_NewResponses_Good(t *testing.T) { + metrics := inference.GenerateMetrics{PromptTokens: 5, GeneratedTokens: 6} + chat := NewChatResponse("qwen", "ok", metrics) + generate := NewGenerateResponse("qwen", "ok", metrics) + + if !chat.Done || chat.Message.Content != "ok" || chat.PromptEvalCount != 5 || chat.EvalCount != 6 { + t.Fatalf("chat = %+v", chat) + } + if !generate.Done || generate.Response != "ok" || generate.PromptEvalCount != 5 || generate.EvalCount != 6 { + t.Fatalf("generate = %+v", generate) + } +} + +// TestOllama_AppendChatResponse_WireMatchesEncodingJSON pins the +// hand-rolled AppendChatResponse output byte-for-byte against +// encoding/json.Marshal across the canonical streaming and final- +// chunk shapes the server emits. Wire compatibility is load-bearing +// — ollama-compatible clients (e.g. open-webui's stream parser) +// expect field-order-stable NDJSON. +func TestOllama_AppendChatResponse_WireMatchesEncodingJSON(t *testing.T) { + cases := []struct { + name string + in ChatResponse + }{ + {"streaming intermediate", ChatResponse{Model: "qwen3", Message: Message{Content: "tok"}, Done: false}}, + {"streaming priming", ChatResponse{Model: "qwen3", Message: Message{Role: "assistant", Content: "The"}, Done: false}}, + {"final with metrics", ChatResponse{ + Model: "qwen3", Message: Message{Role: "assistant", Content: "summary is concise."}, Done: true, + PromptEvalCount: 200, + EvalCount: 32, + TotalDuration: 1_500_000_000, + LoadDuration: 100_000_000, + PromptEvalDuration: 200_000_000, + EvalDuration: 1_200_000_000, + }}, + {"escape-heavy content", ChatResponse{Model: "qwen3", Message: Message{Content: "line1\n\"q\"\tend"}, Done: false}}, + {"empty model + message", ChatResponse{Done: true}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := AppendChatResponse(nil, tc.in) + want, err := json.Marshal(tc.in) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + if string(got) != string(want) { + t.Fatalf("wire drift:\n got = %s\nwant = %s", got, want) + } + // Round-trip through encoding/json decoder must yield the + // original struct — proves the wire output is parseable by + // downstream ollama-compat clients. + var back ChatResponse + if err := json.Unmarshal(got, &back); err != nil { + t.Fatalf("Unmarshal(%s): %v", got, err) + } + if back != tc.in { + t.Fatalf("round-trip:\n got = %+v\nwant = %+v", back, tc.in) + } + }) + } +} + +// TestOllama_AppendGenerateResponse_WireMatchesEncodingJSON mirrors +// the ChatResponse pin for /api/generate. +func TestOllama_AppendGenerateResponse_WireMatchesEncodingJSON(t *testing.T) { + cases := []struct { + name string + in GenerateResponse + }{ + {"streaming token", GenerateResponse{Model: "qwen3", Response: "tok", Done: false}}, + {"empty response", GenerateResponse{Model: "qwen3", Done: false}}, + {"final with metrics", GenerateResponse{ + Model: "qwen3", Response: "The summary is concise.", Done: true, + PromptEvalCount: 200, + EvalCount: 32, + TotalDuration: 1_500_000_000, + LoadDuration: 100_000_000, + PromptEvalDuration: 200_000_000, + EvalDuration: 1_200_000_000, + }}, + {"escape-heavy", GenerateResponse{Model: "qwen3", Response: "line1\n\"q\"\tend", Done: false}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := AppendGenerateResponse(nil, tc.in) + want, err := json.Marshal(tc.in) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + if string(got) != string(want) { + t.Fatalf("wire drift:\n got = %s\nwant = %s", got, want) + } + var back GenerateResponse + if err := json.Unmarshal(got, &back); err != nil { + t.Fatalf("Unmarshal(%s): %v", got, err) + } + if back != tc.in { + t.Fatalf("round-trip:\n got = %+v\nwant = %+v", back, tc.in) + } + }) + } +} + +// TestOllama_AppendTagsResponse_WireMatchesEncodingJSON pins the +// /api/tags discovery encoder. Covers the nil-Models / empty-slice +// difference encoding/json emits (null vs []) plus the per-tag +// omitempty semantics on Model/ModifiedAt/Size. +func TestOllama_AppendTagsResponse_WireMatchesEncodingJSON(t *testing.T) { + cases := []TagsResponse{ + {}, // nil Models -> "models":null + {Models: []ModelTag{}}, // empty slice -> "models":[] + {Models: []ModelTag{{Name: "qwen3:latest"}}}, + {Models: []ModelTag{ + {Name: "qwen3:latest", Model: "qwen3", ModifiedAt: "2026-05-21T10:00:00Z", Size: 4_500_000_000}, + }}, + {Models: []ModelTag{ + {Name: "qwen3:latest", Model: "qwen3", Size: 4_500_000_000}, + {Name: "gemma3:4b", Model: "gemma3", Size: 2_300_000_000}, + }}, + } + for _, resp := range cases { + got := AppendTagsResponse(nil, resp) + want, err := json.Marshal(resp) + if err != nil { + t.Fatalf("json.Marshal: %v", err) + } + if string(got) != string(want) { + t.Fatalf("wire drift:\n got = %s\nwant = %s", got, want) + } + } +} + diff --git a/go/ollama/unmarshal.go b/go/ollama/unmarshal.go new file mode 100644 index 0000000..58356e3 --- /dev/null +++ b/go/ollama/unmarshal.go @@ -0,0 +1,754 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled JSON-decoding for the Ollama wire types. Fires at +// HTTP request-entry per chat/generate call — encoding/json's +// reflect path costs 12-55 allocs on the canonical chat-shape +// turns; the single-pass walker lands at ~7-12 allocs. +// +// Same single-pass byte-walker shape as anthropic/openai. Each +// type's UnmarshalJSON dispatches by exact key byte-compare; +// unknown fields SkipJSONValue past silently (matches stdlib +// default — DisallowUnknownFields is not configured). + +package ollama + +import ( + "dappco.re/go/inference/jsonenc" +) + +// UnmarshalJSON walks the ChatRequest wire shape in a single pass. +func (r *ChatRequest) UnmarshalJSON(data []byte) error { + *r = ChatRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +func (r *ChatRequest) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "messages": + msgs, next, err := parseMessageArray(data, i) + if err != nil { + return next, err + } + r.Messages = msgs + return next, nil + case "stream": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONBool(data, i) + if err != nil { + return next, err + } + r.Stream = v + return next, nil + case "options": + opts, next, err := parseOptions(data, i) + if err != nil { + return next, err + } + r.Options = opts + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// UnmarshalJSON walks the GenerateRequest wire shape. +func (r *GenerateRequest) UnmarshalJSON(data []byte) error { + *r = GenerateRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +func (r *GenerateRequest) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "prompt": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Prompt = s + return next, nil + case "stream": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONBool(data, i) + if err != nil { + return next, err + } + r.Stream = v + return next, nil + case "options": + opts, next, err := parseOptions(data, i) + if err != nil { + return next, err + } + r.Options = opts + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// parseMessageArray walks a JSON array of Message objects. +func parseMessageArray(data []byte, i int) ([]Message, int, error) { + if jsonenc.IsJSONNull(data, i) { + return nil, i + 4, nil + } + i, err := jsonenc.MatchArrayStart(data, i) + if err != nil { + return nil, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == ']' { + return nil, i + 1, nil + } + var out []Message + for { + msg, next, err := parseMessage(data, i) + if err != nil { + return nil, next, err + } + out = append(out, msg) + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) { + return nil, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i = jsonenc.SkipJSONWhitespace(data, i+1) + continue + } + if data[i] == ']' { + return out, i + 1, nil + } + return nil, i, jsonenc.ErrInvalidJSON + } +} + +// parseMessage walks a single Message object. +func parseMessage(data []byte, i int) (Message, int, error) { + var msg Message + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return msg, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return msg, i + 1, nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return msg, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return msg, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return msg, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "role": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return msg, vnext, verr + } + msg.Role = s + i = vnext + case "content": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return msg, vnext, verr + } + msg.Content = s + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return msg, vnext, verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return msg, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return msg, i + 1, nil + } + return msg, i, jsonenc.ErrInvalidJSON + } +} + +// parseOptions walks an Options object. +func parseOptions(data []byte, i int) (Options, int, error) { + var opts Options + if jsonenc.IsJSONNull(data, i) { + return opts, i + 4, nil + } + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return opts, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return opts, i + 1, nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return opts, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return opts, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return opts, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "temperature": + v, vnext, verr := jsonenc.ParseJSONFloat32(data, i) + if verr != nil { + return opts, vnext, verr + } + opts.Temperature = v + i = vnext + case "top_k": + n, vnext, verr := jsonenc.ParseJSONInt(data, i) + if verr != nil { + return opts, vnext, verr + } + opts.TopK = int(n) + i = vnext + case "top_p": + v, vnext, verr := jsonenc.ParseJSONFloat32(data, i) + if verr != nil { + return opts, vnext, verr + } + opts.TopP = v + i = vnext + case "num_predict": + n, vnext, verr := jsonenc.ParseJSONInt(data, i) + if verr != nil { + return opts, vnext, verr + } + opts.NumPredict = int(n) + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return opts, vnext, verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return opts, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return opts, i + 1, nil + } + return opts, i, jsonenc.ErrInvalidJSON + } +} + +// UnmarshalJSON walks the ChatResponse wire shape. +func (r *ChatResponse) UnmarshalJSON(data []byte) error { + *r = ChatResponse{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +func (r *ChatResponse) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "message": + msg, next, err := parseMessage(data, i) + if err != nil { + return next, err + } + r.Message = msg + return next, nil + case "done": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONBool(data, i) + if err != nil { + return next, err + } + r.Done = v + return next, nil + case "prompt_eval_count": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.PromptEvalCount = int(n) + return next, nil + case "eval_count": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.EvalCount = int(n) + return next, nil + case "total_duration": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.TotalDuration = n + return next, nil + case "load_duration": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.LoadDuration = n + return next, nil + case "prompt_eval_duration": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.PromptEvalDuration = n + return next, nil + case "eval_duration": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.EvalDuration = n + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// UnmarshalJSON walks the TagsResponse wire shape — the /api/tags +// list-models response from a client perspective. +func (r *TagsResponse) UnmarshalJSON(data []byte) error { + *r = TagsResponse{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "models": + tags, vnext, verr := parseModelTagArray(data, i) + if verr != nil { + return verr + } + r.Models = tags + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +// parseModelTagArray walks a JSON array of ModelTag objects. +func parseModelTagArray(data []byte, i int) ([]ModelTag, int, error) { + if jsonenc.IsJSONNull(data, i) { + return nil, i + 4, nil + } + i, err := jsonenc.MatchArrayStart(data, i) + if err != nil { + return nil, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == ']' { + return nil, i + 1, nil + } + var out []ModelTag + for { + tag, next, err := parseModelTag(data, i) + if err != nil { + return nil, next, err + } + out = append(out, tag) + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) { + return nil, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i = jsonenc.SkipJSONWhitespace(data, i+1) + continue + } + if data[i] == ']' { + return out, i + 1, nil + } + return nil, i, jsonenc.ErrInvalidJSON + } +} + +// parseModelTag walks a single ModelTag object. +func parseModelTag(data []byte, i int) (ModelTag, int, error) { + var tag ModelTag + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return tag, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return tag, i + 1, nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return tag, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return tag, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return tag, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "name": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return tag, vnext, verr + } + tag.Name = s + i = vnext + case "model": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return tag, vnext, verr + } + tag.Model = s + i = vnext + case "modified_at": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return tag, vnext, verr + } + tag.ModifiedAt = s + i = vnext + case "size": + n, vnext, verr := jsonenc.ParseJSONInt(data, i) + if verr != nil { + return tag, vnext, verr + } + tag.Size = n + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return tag, vnext, verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return tag, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return tag, i + 1, nil + } + return tag, i, jsonenc.ErrInvalidJSON + } +} + +// UnmarshalJSON walks the GenerateResponse wire shape. +func (r *GenerateResponse) UnmarshalJSON(data []byte) error { + *r = GenerateResponse{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +func (r *GenerateResponse) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "response": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Response = s + return next, nil + case "done": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONBool(data, i) + if err != nil { + return next, err + } + r.Done = v + return next, nil + case "prompt_eval_count": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.PromptEvalCount = int(n) + return next, nil + case "eval_count": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.EvalCount = int(n) + return next, nil + case "total_duration": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.TotalDuration = n + return next, nil + case "load_duration": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.LoadDuration = n + return next, nil + case "prompt_eval_duration": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.PromptEvalDuration = n + return next, nil + case "eval_duration": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.EvalDuration = n + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} diff --git a/go/ollama/unmarshal_test.go b/go/ollama/unmarshal_test.go new file mode 100644 index 0000000..a6302ec --- /dev/null +++ b/go/ollama/unmarshal_test.go @@ -0,0 +1,158 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package ollama + +import ( + "encoding/json" + "reflect" + "testing" +) + +func TestUnmarshalChatRequest_DirectShapes(t *testing.T) { + cases := []struct { + name string + in string + want ChatRequest + }{ + { + name: "minimal", + in: `{"model":"qwen3","messages":[{"role":"user","content":"hi"}]}`, + want: ChatRequest{ + Model: "qwen3", + Messages: []Message{{Role: "user", Content: "hi"}}, + }, + }, + { + name: "with-stream-and-options", + in: `{"model":"qwen3","messages":[],"stream":true,"options":{"temperature":0.7,"top_k":64,"top_p":0.95,"num_predict":256}}`, + want: ChatRequest{ + Model: "qwen3", + Stream: true, + Options: Options{Temperature: 0.7, TopK: 64, TopP: 0.95, NumPredict: 256}, + }, + }, + { + name: "unknown-fields-ignored", + in: `{"model":"qwen3","messages":[],"future":42,"options":{"unknown":"x","temperature":0.5}}`, + want: ChatRequest{ + Model: "qwen3", + Options: Options{Temperature: 0.5}, + }, + }, + { + name: "options-null", + in: `{"model":"qwen3","messages":[],"options":null}`, + want: ChatRequest{ + Model: "qwen3", + }, + }, + { + name: "escape-heavy", + in: `{"model":"qwen3","messages":[{"role":"user","content":"a\nb"}]}`, + want: ChatRequest{ + Model: "qwen3", + Messages: []Message{{Role: "user", Content: "a\nb"}}, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got ChatRequest + if err := json.Unmarshal([]byte(tc.in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("got: %+v\nwant: %+v", got, tc.want) + } + }) + } +} + +func TestUnmarshalGenerateRequest_DirectShapes(t *testing.T) { + in := `{"model":"qwen3","prompt":"hi","stream":true,"options":{"temperature":0.7,"top_p":0.9,"num_predict":128}}` + want := GenerateRequest{ + Model: "qwen3", + Prompt: "hi", + Stream: true, + Options: Options{Temperature: 0.7, TopP: 0.9, NumPredict: 128}, + } + var got GenerateRequest + if err := json.Unmarshal([]byte(in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %+v\nwant: %+v", got, want) + } +} + +func TestUnmarshalChatResponse_DirectShapes(t *testing.T) { + in := `{"model":"qwen3","message":{"role":"assistant","content":"answer"},"done":true,"prompt_eval_count":10,"eval_count":5,"total_duration":1500000000}` + want := ChatResponse{ + Model: "qwen3", + Message: Message{Role: "assistant", Content: "answer"}, + Done: true, + PromptEvalCount: 10, + EvalCount: 5, + TotalDuration: 1500000000, + } + var got ChatResponse + if err := json.Unmarshal([]byte(in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %+v\nwant: %+v", got, want) + } +} + +func TestUnmarshalGenerateResponse_DirectShapes(t *testing.T) { + in := `{"model":"qwen3","response":"hi","done":true,"prompt_eval_count":4,"eval_count":2}` + want := GenerateResponse{ + Model: "qwen3", + Response: "hi", + Done: true, + PromptEvalCount: 4, + EvalCount: 2, + } + var got GenerateResponse + if err := json.Unmarshal([]byte(in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %+v\nwant: %+v", got, want) + } +} + +func TestUnmarshalTagsResponse_DirectShapes(t *testing.T) { + in := `{"models":[{"name":"qwen3:latest","model":"qwen3","modified_at":"2026-05-21T10:00:00Z","size":4000000000},{"name":"llama3:8b","model":"llama3","size":5000000000}]}` + want := TagsResponse{ + Models: []ModelTag{ + {Name: "qwen3:latest", Model: "qwen3", ModifiedAt: "2026-05-21T10:00:00Z", Size: 4000000000}, + {Name: "llama3:8b", Model: "llama3", Size: 5000000000}, + }, + } + var got TagsResponse + if err := json.Unmarshal([]byte(in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %+v\nwant: %+v", got, want) + } +} + +func TestUnmarshalChatRequest_InvalidShapes(t *testing.T) { + cases := []string{ + ``, + `{`, + `{"options":{`, + `{"messages":not-array}`, + `{"options":{"temperature":"hot"}}`, + } + for _, in := range cases { + t.Run(in, func(t *testing.T) { + var req ChatRequest + if err := json.Unmarshal([]byte(in), &req); err == nil { + t.Fatalf("Unmarshal(%q) returned nil error", in) + } + }) + } +} diff --git a/go/openai/chunkenc.go b/go/openai/chunkenc.go new file mode 100644 index 0000000..80abd65 --- /dev/null +++ b/go/openai/chunkenc.go @@ -0,0 +1,291 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled encoders for the OpenAI chat-completions wire shapes +// that fire on the streaming + non-streaming serve paths. +// +// Per-token cost matters: serveStreaming emits one ChatCompletionChunk +// per content/thought delta in the SSE loop plus a priming chunk and +// a terminating chunk. Routing each through encoding/json's reflect +// path costs an encoder state machine, a grow-doubled output buffer, +// per-pointer envelope copies, and (via core.JSONMarshalString + +// core.Concat) a separate string copy for the "data: " SSE framing. +// +// These encoders collapse the same shape into a single caller-bound +// buffer and embed the SSE framing in-line — one allocation for the +// emitted frame, no intermediate string conversion. Wire output +// matches encoding/json across every branch (round-trip locked by +// TestChatCompletionChunk_MarshalJSON_RoundTrip). + +package openai + +import "dappco.re/go/inference/jsonenc" + +// appendChatMessageDelta walks the two-field ChatMessageDelta into buf. +// Same shape and escape contract as ChatMessageDelta.MarshalJSON, but +// without the buffer-allocation hop — the chunk encoders pull it +// inline so the entire frame lands in a single backing buffer. +// +// Wire shapes (identical to encoding/json with the existing tags): +// - empty -> {} +// - role set (priming/closing) -> {"role":"X","content":"Y"} +// - content only -> {"content":"Y"} +// - both -> {"role":"X","content":"Y"} +func appendChatMessageDelta(buf []byte, d ChatMessageDelta) []byte { + if d.Role == "" && d.Content == "" { + return append(buf, '{', '}') + } + buf = append(buf, '{') + if d.Role != "" { + buf = jsonenc.AppendStringField(buf, "role", d.Role, false) + buf = jsonenc.AppendStringField(buf, "content", d.Content, true) + } else { + buf = jsonenc.AppendStringField(buf, "content", d.Content, false) + } + return append(buf, '}') +} + +// appendChatChunkChoice walks one ChatChunkChoice into buf. The +// FinishReason pointer maps to `null` (not omitted) when nil — the +// field carries no omitempty tag in the canonical shape, and the +// terminal chunk's finish_reason is the load-bearing field clients +// pivot on. +func appendChatChunkChoice(buf []byte, choice ChatChunkChoice) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendIntField(buf, "index", choice.Index, false) + buf = append(buf, ',', '"', 'd', 'e', 'l', 't', 'a', '"', ':') + buf = appendChatMessageDelta(buf, choice.Delta) + buf = append(buf, ',', '"', 'f', 'i', 'n', 'i', 's', 'h', '_', 'r', 'e', 'a', 's', 'o', 'n', '"', ':') + if choice.FinishReason == nil { + buf = append(buf, 'n', 'u', 'l', 'l') + } else { + buf = jsonenc.AppendJSONString(buf, *choice.FinishReason) + } + return append(buf, '}') +} + +// appendChatCompletionChunk walks a ChatCompletionChunk into buf. +// Field order matches the struct declaration (id, object, created, +// model, choices, thought) — encoding/json emits in that same order +// for the canonical tag set. +func appendChatCompletionChunk(buf []byte, chunk ChatCompletionChunk) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "id", chunk.ID, false) + buf = jsonenc.AppendStringField(buf, "object", chunk.Object, true) + buf = jsonenc.AppendInt64Field(buf, "created", chunk.Created, true) + buf = jsonenc.AppendStringField(buf, "model", chunk.Model, true) + buf = append(buf, ',', '"', 'c', 'h', 'o', 'i', 'c', 'e', 's', '"', ':', '[') + for i, choice := range chunk.Choices { + if i > 0 { + buf = append(buf, ',') + } + buf = appendChatChunkChoice(buf, choice) + } + buf = append(buf, ']') + if chunk.Thought != nil { + buf = append(buf, ',', '"', 't', 'h', 'o', 'u', 'g', 'h', 't', '"', ':') + buf = jsonenc.AppendJSONString(buf, *chunk.Thought) + } + return append(buf, '}') +} + +// appendChatCompletionChunkSSE writes a complete SSE frame into buf — +// the literal `data: ` prefix, the chunk JSON body, and the trailing +// `\n\n`. Lets the streaming hot path emit the whole frame in a +// single backing buffer instead of three (JSON body + Concat scratch +// + final []byte conversion). +// +// frame := appendChatCompletionChunkSSE(nil, chunk) +// w.Write(frame) +func appendChatCompletionChunkSSE(buf []byte, chunk ChatCompletionChunk) []byte { + buf = append(buf, 'd', 'a', 't', 'a', ':', ' ') + buf = appendChatCompletionChunk(buf, chunk) + return append(buf, '\n', '\n') +} + +// chunkSSEFrameSize estimates the backing-buffer size for one SSE +// frame so the streaming path allocates once. The estimate covers +// the "data: " prefix + every fixed key + a one-byte ASCII assumption +// for the variable fields + the trailing "\n\n". Pathological +// escape-heavy content lets append grow once. +func chunkSSEFrameSize(chunk ChatCompletionChunk) int { + // "data: " (6) + "{" (1) + closing "\n\n" (2) + size := 6 + 1 + 2 + size += 7 + len(chunk.ID) // "id":"...", + size += 11 + len(chunk.Object) // "object":"...", + size += 12 + 20 // "created":, + size += 10 + len(chunk.Model) // "model":"...", + size += 12 // "choices":[...] + for _, choice := range chunk.Choices { + size += 11 + 20 // "index":N, + // delta = "delta":{...} — delta wire is 30+role+content bytes worst case + size += 9 + 32 + len(choice.Delta.Role) + len(choice.Delta.Content) + size += 19 // ,"finish_reason": + if choice.FinishReason != nil { + size += len(*choice.FinishReason) + 2 + } else { + size += 4 + } + size += 3 // {} wrap + separator + } + if chunk.Thought != nil { + size += 12 + len(*chunk.Thought) // ,"thought":"..." + } + return size +} + +// Note: ChatCompletionChunk does NOT carry a MarshalJSON method. +// Adding one routes encoding/json.Marshal through a call-and-revalidate +// path that ends up slower than the reflect-walked default — every +// proxy serialisation site would pay the cost. The streaming hot +// path bypasses encoding/json entirely via appendChatCompletionChunkSSE. + +// appendChatMessage walks a ChatMessage into buf. Used by the +// non-streaming response encoder for the assistant message body. +func appendChatMessage(buf []byte, msg ChatMessage) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "role", msg.Role, false) + buf = jsonenc.AppendStringField(buf, "content", msg.Content, true) + return append(buf, '}') +} + +// appendChatChoice walks a ChatChoice (non-streaming response) into +// buf. Field order matches the struct: index, message, finish_reason. +func appendChatChoice(buf []byte, choice ChatChoice) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendIntField(buf, "index", choice.Index, false) + buf = append(buf, ',', '"', 'm', 'e', 's', 's', 'a', 'g', 'e', '"', ':') + buf = appendChatMessage(buf, choice.Message) + buf = jsonenc.AppendStringField(buf, "finish_reason", choice.FinishReason, true) + return append(buf, '}') +} + +// appendChatUsage walks a ChatUsage into buf. Three int fields in +// canonical OpenAI order. +func appendChatUsage(buf []byte, usage ChatUsage) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendIntField(buf, "prompt_tokens", usage.PromptTokens, false) + buf = jsonenc.AppendIntField(buf, "completion_tokens", usage.CompletionTokens, true) + buf = jsonenc.AppendIntField(buf, "total_tokens", usage.TotalTokens, true) + return append(buf, '}') +} + +// appendChatCompletionResponse walks the non-streaming ChatCompletion +// response into buf. Field order matches the struct declaration so +// the wire shape is byte-identical to encoding/json.Marshal output. +func appendChatCompletionResponse(buf []byte, resp ChatCompletionResponse) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "id", resp.ID, false) + buf = jsonenc.AppendStringField(buf, "object", resp.Object, true) + buf = jsonenc.AppendInt64Field(buf, "created", resp.Created, true) + buf = jsonenc.AppendStringField(buf, "model", resp.Model, true) + buf = append(buf, ',', '"', 'c', 'h', 'o', 'i', 'c', 'e', 's', '"', ':', '[') + for i, choice := range resp.Choices { + if i > 0 { + buf = append(buf, ',') + } + buf = appendChatChoice(buf, choice) + } + buf = append(buf, ']', ',', '"', 'u', 's', 'a', 'g', 'e', '"', ':') + buf = appendChatUsage(buf, resp.Usage) + if resp.Thought != nil { + buf = append(buf, ',', '"', 't', 'h', 'o', 'u', 'g', 'h', 't', '"', ':') + buf = jsonenc.AppendJSONString(buf, *resp.Thought) + } + return append(buf, '}') +} + +// appendEmbeddingResponseDatum walks one embedding-response datum +// (object, index, embedding vector) into buf. The embedding slice +// is emitted directly via strconv.AppendFloat — avoids the +// reflect-walk per-element cost that encoding/json pays. +func appendEmbeddingResponseDatum(buf []byte, datum EmbeddingResponseDatum) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "object", datum.Object, false) + buf = jsonenc.AppendIntField(buf, "index", datum.Index, true) + buf = append(buf, ',', '"', 'e', 'm', 'b', 'e', 'd', 'd', 'i', 'n', 'g', '"', ':', '[') + for i, v := range datum.Embedding { + if i > 0 { + buf = append(buf, ',') + } + buf = jsonenc.AppendFloat32(buf, v) + } + return append(buf, ']', '}') +} + +// appendEmbeddingUsage walks an inference.EmbeddingUsage into buf. +// Two int fields — prompt_tokens, total_tokens — in canonical +// OpenAI order. +func appendEmbeddingUsage(buf []byte, prompt, total int) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendIntField(buf, "prompt_tokens", prompt, false) + buf = jsonenc.AppendIntField(buf, "total_tokens", total, true) + return append(buf, '}') +} + +// appendEmbeddingResponse walks the full EmbeddingResponse shape +// into buf. The per-vector embedding fan-out is the load-bearing +// cost (a 20×1024 response emits 20480 float32 values); the hand- +// rolled walk keeps the per-element path on a single buffer with +// no reflect. +func appendEmbeddingResponse(buf []byte, resp EmbeddingResponse) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "object", resp.Object, false) + buf = append(buf, ',', '"', 'd', 'a', 't', 'a', '"', ':', '[') + for i, datum := range resp.Data { + if i > 0 { + buf = append(buf, ',') + } + buf = appendEmbeddingResponseDatum(buf, datum) + } + buf = append(buf, ']') + buf = jsonenc.AppendStringField(buf, "model", resp.Model, true) + buf = append(buf, ',', '"', 'u', 's', 'a', 'g', 'e', '"', ':') + buf = appendEmbeddingUsage(buf, resp.Usage.PromptTokens, resp.Usage.TotalTokens) + return append(buf, '}') +} + +// embeddingResponseSize estimates the backing-buffer size for one +// EmbeddingResponse so the encoder allocates once. Each float32 +// emits at most ~12 ASCII chars under the 'g' format (sign + 7 +// significant digits + exponent + dot); empirical mean across the +// embedding ranges (~ -1..+1) is ~7.9 chars + 1 separator. The +// heuristic uses 9 — under-commits on the worst case (scientific- +// notation values) and lets append grow once. +func embeddingResponseSize(resp EmbeddingResponse) int { + size := 2 // braces + size += 11 + len(resp.Object) + size += 9 // "data":[] + for _, datum := range resp.Data { + size += 12 + len(datum.Object) // {"object":"X" + size += 11 + 20 // "index":N + size += 14 // "embedding":[] + size += len(datum.Embedding) * 9 + size += 2 // } + } + size += 10 + len(resp.Model) + size += 50 // "usage":{prompt_tokens:N,total_tokens:N} + return size +} + +// chatCompletionResponseSize estimates the backing-buffer size for +// one ChatCompletionResponse so the encoder allocates once. +func chatCompletionResponseSize(resp ChatCompletionResponse) int { + size := 2 // braces + size += 7 + len(resp.ID) + size += 11 + len(resp.Object) + size += 12 + 20 + size += 10 + len(resp.Model) + size += 12 // "choices":[] + for _, choice := range resp.Choices { + // {"index":N,"message":{"role":"X","content":"Y"},"finish_reason":"Z"} + size += 12 + 20 + size += 12 + 8 + len(choice.Message.Role) + 11 + len(choice.Message.Content) + 1 + size += 18 + len(choice.FinishReason) + size += 2 + } + size += 56 // "usage":{prompt_tokens:N,completion_tokens:N,total_tokens:N} + if resp.Thought != nil { + size += 12 + len(*resp.Thought) + } + return size +} diff --git a/go/openai/chunkenc_test.go b/go/openai/chunkenc_test.go new file mode 100644 index 0000000..80c80e9 --- /dev/null +++ b/go/openai/chunkenc_test.go @@ -0,0 +1,194 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "encoding/json" + "strings" + "testing" +) + +// TestChatCompletionChunk_MarshalJSON_RoundTrip locks the hand-rolled +// chunk encoder shape to encoding/json's deserialiser. The encoder +// fires per streamed token; the wire output is consumed by both +// proxy clients and downstream services that re-decode the frame +// back into ChatCompletionChunk. +// +// Cases cover every branch the encoder walks: +// - empty (no choices, no thought) +// - priming frame (role-only delta, nil finish_reason -> null) +// - mid-stream content delta (content-only delta, nil finish) +// - thought-bearing frame (Thought pointer set) +// - terminal frame (finish_reason set) +// - escape-bearing content +func TestChatCompletionChunk_MarshalJSON_RoundTrip(t *testing.T) { + finishStop := "stop" + thought := "let me think" + cases := []struct { + name string + in ChatCompletionChunk + }{ + {"empty", ChatCompletionChunk{ID: "id", Object: "chat.completion.chunk", Created: 1700000000, Model: "qwen3"}}, + {"priming", ChatCompletionChunk{ + ID: "id", Object: "chat.completion.chunk", Created: 1700000000, Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{Role: "assistant"}}}, + }}, + {"delta", ChatCompletionChunk{ + ID: "id", Object: "chat.completion.chunk", Created: 1700000000, Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{Content: "Answer"}}}, + }}, + {"thought-bearing", ChatCompletionChunk{ + ID: "id", Object: "chat.completion.chunk", Created: 1700000000, Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{Content: "x"}}}, + Thought: &thought, + }}, + {"terminal", ChatCompletionChunk{ + ID: "id", Object: "chat.completion.chunk", Created: 1700000000, Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{}, FinishReason: &finishStop}}, + }}, + {"escapes", ChatCompletionChunk{ + ID: "id", Object: "chat.completion.chunk", Created: 1700000000, Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{Content: "quote \" and tab\t"}}}, + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + // Round-trip via hand-rolled encoder. + encoded := appendChatCompletionChunk(nil, tc.in) + var back ChatCompletionChunk + if err := json.Unmarshal(encoded, &back); err != nil { + t.Fatalf("json.Unmarshal(%s) error = %v", encoded, err) + } + // Compare load-bearing fields. + if back.ID != tc.in.ID || back.Object != tc.in.Object || back.Created != tc.in.Created || back.Model != tc.in.Model { + t.Fatalf("identity: got %+v, want %+v", back, tc.in) + } + if len(back.Choices) != len(tc.in.Choices) { + t.Fatalf("choices len = %d, want %d", len(back.Choices), len(tc.in.Choices)) + } + for i := range tc.in.Choices { + if back.Choices[i].Index != tc.in.Choices[i].Index { + t.Fatalf("choices[%d].index = %d, want %d", i, back.Choices[i].Index, tc.in.Choices[i].Index) + } + if back.Choices[i].Delta.Role != tc.in.Choices[i].Delta.Role || back.Choices[i].Delta.Content != tc.in.Choices[i].Delta.Content { + t.Fatalf("choices[%d].delta = %+v, want %+v", i, back.Choices[i].Delta, tc.in.Choices[i].Delta) + } + gotFinish := back.Choices[i].FinishReason + wantFinish := tc.in.Choices[i].FinishReason + if (gotFinish == nil) != (wantFinish == nil) { + t.Fatalf("choices[%d].finish_reason nil mismatch: got=%v want=%v", i, gotFinish, wantFinish) + } + if gotFinish != nil && *gotFinish != *wantFinish { + t.Fatalf("choices[%d].finish_reason = %q, want %q", i, *gotFinish, *wantFinish) + } + } + if (back.Thought == nil) != (tc.in.Thought == nil) { + t.Fatalf("thought nil mismatch: got=%v want=%v", back.Thought, tc.in.Thought) + } + if back.Thought != nil && *back.Thought != *tc.in.Thought { + t.Fatalf("thought = %q, want %q", *back.Thought, *tc.in.Thought) + } + }) + } +} + +// TestChatCompletionResponse_AppendRoundTrip locks the hand-rolled +// non-streaming response encoder against encoding/json. The wire +// shape is consumed by every OpenAI-compatible client on the +// non-streaming chat-completions endpoint. +func TestChatCompletionResponse_AppendRoundTrip(t *testing.T) { + thought := "let me think" + cases := []struct { + name string + in ChatCompletionResponse + }{ + {"minimal", ChatCompletionResponse{ + ID: "chatcmpl-x", Object: "chat.completion", Created: 1700000000, Model: "qwen3", + Choices: []ChatChoice{{ + Index: 0, + Message: ChatMessage{Role: "assistant", Content: "Hi"}, + FinishReason: "stop", + }}, + Usage: ChatUsage{PromptTokens: 3, CompletionTokens: 4, TotalTokens: 7}, + }}, + {"with-thought", ChatCompletionResponse{ + ID: "chatcmpl-x", Object: "chat.completion", Created: 1700000000, Model: "qwen3", + Choices: []ChatChoice{{ + Index: 0, + Message: ChatMessage{Role: "assistant", Content: "Answer"}, + FinishReason: "length", + }}, + Usage: ChatUsage{PromptTokens: 10, CompletionTokens: 20, TotalTokens: 30}, + Thought: &thought, + }}, + {"escapes", ChatCompletionResponse{ + ID: "chatcmpl-x", Object: "chat.completion", Created: 1700000000, Model: "qwen3", + Choices: []ChatChoice{{ + Index: 0, + Message: ChatMessage{Role: "assistant", Content: "quote \" backslash \\"}, + FinishReason: "stop", + }}, + Usage: ChatUsage{PromptTokens: 1, CompletionTokens: 2, TotalTokens: 3}, + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + encoded := appendChatCompletionResponse(nil, tc.in) + var back ChatCompletionResponse + if err := json.Unmarshal(encoded, &back); err != nil { + t.Fatalf("json.Unmarshal(%s) error = %v", encoded, err) + } + if back.ID != tc.in.ID || back.Object != tc.in.Object || back.Created != tc.in.Created || back.Model != tc.in.Model { + t.Fatalf("identity: got %+v, want %+v", back, tc.in) + } + if back.Usage != tc.in.Usage { + t.Fatalf("usage: got %+v, want %+v", back.Usage, tc.in.Usage) + } + if len(back.Choices) != len(tc.in.Choices) { + t.Fatalf("choices len = %d, want %d", len(back.Choices), len(tc.in.Choices)) + } + for i := range tc.in.Choices { + if back.Choices[i].Index != tc.in.Choices[i].Index || + back.Choices[i].Message.Role != tc.in.Choices[i].Message.Role || + back.Choices[i].Message.Content != tc.in.Choices[i].Message.Content || + back.Choices[i].FinishReason != tc.in.Choices[i].FinishReason { + t.Fatalf("choices[%d] mismatch: got %+v want %+v", i, back.Choices[i], tc.in.Choices[i]) + } + } + if (back.Thought == nil) != (tc.in.Thought == nil) { + t.Fatalf("thought nil mismatch: got=%v want=%v", back.Thought, tc.in.Thought) + } + if back.Thought != nil && *back.Thought != *tc.in.Thought { + t.Fatalf("thought = %q, want %q", *back.Thought, *tc.in.Thought) + } + }) + } +} + +// TestChatCompletionChunk_SSEFrame verifies the SSE framing helper — +// the streaming hot path embeds "data: " prefix + body + "\n\n" in +// one buffer. Output must match what proxy clients parse as one SSE +// event (LL-formatted: line "data: " terminated by blank line). +func TestChatCompletionChunk_SSEFrame(t *testing.T) { + finish := "stop" + chunk := ChatCompletionChunk{ + ID: "chatcmpl-test", Object: "chat.completion.chunk", Created: 1700000000, Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{}, FinishReason: &finish}}, + } + frame := appendChatCompletionChunkSSE(nil, chunk) + frameStr := string(frame) + if !strings.HasPrefix(frameStr, "data: ") { + t.Fatalf("frame missing data: prefix: %q", frameStr) + } + if !strings.HasSuffix(frameStr, "\n\n") { + t.Fatalf("frame missing trailing newlines: %q", frameStr) + } + body := strings.TrimSuffix(strings.TrimPrefix(frameStr, "data: "), "\n\n") + var back ChatCompletionChunk + if err := json.Unmarshal([]byte(body), &back); err != nil { + t.Fatalf("frame body json.Unmarshal error: %v body=%q", err, body) + } + if back.ID != chunk.ID || back.Choices[0].FinishReason == nil || *back.Choices[0].FinishReason != "stop" { + t.Fatalf("frame body decoded mismatch: %+v", back) + } +} diff --git a/go/openai/embedding_enc_test.go b/go/openai/embedding_enc_test.go new file mode 100644 index 0000000..156211d --- /dev/null +++ b/go/openai/embedding_enc_test.go @@ -0,0 +1,85 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "encoding/json" + "math" + "testing" + + "dappco.re/go/inference" +) + +// TestEmbeddingResponse_AppendRoundTrip locks the hand-rolled +// embedding-response encoder against encoding/json's deserialiser. +// The wire shape is consumed by every OpenAI-compatible embedding +// client; round-trip on every embedding-model output preserves the +// per-element float32 values within the standard 'g' precision the +// stdlib emits. +func TestEmbeddingResponse_AppendRoundTrip(t *testing.T) { + cases := []struct { + name string + in EmbeddingResponse + }{ + {"single-vector", EmbeddingResponse{ + Object: "list", + Data: []EmbeddingResponseDatum{{ + Object: "embedding", + Index: 0, + Embedding: []float32{0.1, -0.2, 0.75, 1.0}, + }}, + Model: "qwen3-embed", + Usage: inference.EmbeddingUsage{PromptTokens: 4, TotalTokens: 4}, + }}, + {"multi-vector", EmbeddingResponse{ + Object: "list", + Data: []EmbeddingResponseDatum{ + {Object: "embedding", Index: 0, Embedding: []float32{0.0, 0.5}}, + {Object: "embedding", Index: 1, Embedding: []float32{-1.0, 1.0}}, + {Object: "embedding", Index: 2, Embedding: []float32{1e-5, 1e5}}, + }, + Model: "qwen3-embed", + Usage: inference.EmbeddingUsage{PromptTokens: 12, TotalTokens: 12}, + }}, + {"empty-vectors", EmbeddingResponse{ + Object: "list", + Data: []EmbeddingResponseDatum{{Object: "embedding", Index: 0, Embedding: []float32{}}}, + Model: "qwen3-embed", + Usage: inference.EmbeddingUsage{PromptTokens: 0, TotalTokens: 0}, + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + encoded := appendEmbeddingResponse(nil, tc.in) + var back EmbeddingResponse + if err := json.Unmarshal(encoded, &back); err != nil { + t.Fatalf("json.Unmarshal(%s) error = %v", encoded, err) + } + if back.Object != tc.in.Object || back.Model != tc.in.Model { + t.Fatalf("identity: got %+v, want %+v", back, tc.in) + } + if back.Usage != tc.in.Usage { + t.Fatalf("usage: got %+v, want %+v", back.Usage, tc.in.Usage) + } + if len(back.Data) != len(tc.in.Data) { + t.Fatalf("data len = %d, want %d", len(back.Data), len(tc.in.Data)) + } + for i := range tc.in.Data { + if back.Data[i].Object != tc.in.Data[i].Object || back.Data[i].Index != tc.in.Data[i].Index { + t.Fatalf("data[%d] header: got %+v want %+v", i, back.Data[i], tc.in.Data[i]) + } + if len(back.Data[i].Embedding) != len(tc.in.Data[i].Embedding) { + t.Fatalf("data[%d].embedding len = %d, want %d", i, len(back.Data[i].Embedding), len(tc.in.Data[i].Embedding)) + } + for j, v := range tc.in.Data[i].Embedding { + if math.IsNaN(float64(v)) { + continue + } + if back.Data[i].Embedding[j] != v { + t.Fatalf("data[%d].embedding[%d] = %v, want %v", i, j, back.Data[i].Embedding[j], v) + } + } + } + }) + } +} diff --git a/go/openai/jsondec.go b/go/openai/jsondec.go new file mode 100644 index 0000000..db8d7f5 --- /dev/null +++ b/go/openai/jsondec.go @@ -0,0 +1,27 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled JSON-decoding adapters for the openai variant-shape +// unmarshallers. The walker primitives now live in jsonenc/ so that +// anthropic + ollama field-dispatch UnmarshalJSON paths can share +// the same byte-pump (lifted from this file in W11-B). The shapes +// this file owns — StopList / EmbeddingInput — both reduce to +// `ParseJSONStringList`, so the helpers here are thin variant-shape +// dispatchers. +// +// Per-call performance unchanged from the W10-M baseline — the +// underlying byte walker is identical. + +package openai + +import "dappco.re/go/inference/jsonenc" + +// parseJSONStringList walks data as either a JSON string (e.g. +// `"END"`) or an array of JSON strings (e.g. `["END",""]`) and +// returns a []string with the inner values unescaped. +// +// Forwards to jsonenc.ParseJSONStringList — kept under the package- +// local name so existing call sites (StopList / EmbeddingInput) need +// no churn. +func parseJSONStringList(data []byte) ([]string, error) { + return jsonenc.ParseJSONStringList(data) +} diff --git a/go/openai/jsondec_test.go b/go/openai/jsondec_test.go new file mode 100644 index 0000000..8d7b058 --- /dev/null +++ b/go/openai/jsondec_test.go @@ -0,0 +1,70 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "reflect" + "testing" +) + +// TestParseJSONStringList_RoundTrip locks the hand-rolled +// string-or-array walker against the documented input/output +// contract. Cases cover every branch: null literal, plain string, +// empty array, single-element array, multi-element array, and +// every escape form the JSON spec recognises. +func TestParseJSONStringList_RoundTrip(t *testing.T) { + cases := []struct { + name string + in string + want []string + }{ + {"null", "null", nil}, + {"null-with-whitespace", " null\t", nil}, + {"plain-string", `"END"`, []string{"END"}}, + {"string-with-escapes", `"line1\nline2"`, []string{"line1\nline2"}}, + {"string-with-quote", `"he said \"hi\""`, []string{`he said "hi"`}}, + {"string-with-unicode", `"é"`, []string{"é"}}, + {"empty-array", `[]`, nil}, + {"single-element-array", `["END"]`, []string{"END"}}, + {"multi-element-array", `["A","B","C"]`, []string{"A", "B", "C"}}, + {"array-with-whitespace", ` [ "A" , "B" ] `, []string{"A", "B"}}, + {"array-with-escapes", `["\t","\n"]`, []string{"\t", "\n"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := parseJSONStringList([]byte(tc.in)) + if err != nil { + t.Fatalf("parseJSONStringList(%s) error = %v", tc.in, err) + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("parseJSONStringList(%s) = %v, want %v", tc.in, got, tc.want) + } + }) + } +} + +// TestParseJSONStringList_Invalid asserts the walker rejects +// malformed inputs cleanly — no panics, just errors. +func TestParseJSONStringList_Invalid(t *testing.T) { + cases := []string{ + "", + " ", + `{`, + `}`, + `"unterminated`, + `[`, + `["unterminated`, + `["A"`, + `["A",]`, + `[123]`, + `tru`, + } + for _, in := range cases { + t.Run(in, func(t *testing.T) { + _, err := parseJSONStringList([]byte(in)) + if err == nil { + t.Fatalf("parseJSONStringList(%q) returned nil error, want error", in) + } + }) + } +} diff --git a/go/openai/jsonenc_test.go b/go/openai/jsonenc_test.go new file mode 100644 index 0000000..c7f3847 --- /dev/null +++ b/go/openai/jsonenc_test.go @@ -0,0 +1,58 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "encoding/json" + "testing" +) + +// TestChatMessageDelta_MarshalJSON_RoundTrip locks the hand-rolled +// encoder shape against encoding/json's deserialiser. The encoder +// is on the streaming hot path — every SSE delta + priming + close +// chunk routes through it, so its output must round-trip cleanly +// back into ChatMessageDelta with no field drift. +// +// Cases cover every branch the encoder walks: +// - empty struct -> "{}" +// - role-only -> emits both role and content:"" (priming chunk) +// - content-only -> emits content only +// - both set -> both fields +// - escape body -> control/quote/backslash characters in content +func TestChatMessageDelta_MarshalJSON_RoundTrip(t *testing.T) { + cases := []struct { + name string + in ChatMessageDelta + want string + }{ + {"empty", ChatMessageDelta{}, `{}`}, + {"role-only", ChatMessageDelta{Role: "assistant"}, `{"role":"assistant","content":""}`}, + {"content-only", ChatMessageDelta{Content: "hello"}, `{"content":"hello"}`}, + {"both", ChatMessageDelta{Role: "assistant", Content: "world"}, `{"role":"assistant","content":"world"}`}, + {"escapes", ChatMessageDelta{Content: "quote \" backslash \\ tab\tnewline\n"}, + `{"content":"quote \" backslash \\ tab\tnewline\n"}`}, + {"control", ChatMessageDelta{Content: "\x01\x02"}, `{"content":"\u0001\u0002"}`}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + encoded, err := tc.in.MarshalJSON() + if err != nil { + t.Fatalf("MarshalJSON() error = %v", err) + } + if string(encoded) != tc.want { + t.Fatalf("MarshalJSON() = %s, want %s", encoded, tc.want) + } + // Round-trip via encoding/json — the streaming chunk + // types wrap ChatMessageDelta and the proxy clients + // consuming the stream feed it back into the same Go + // types. + var back ChatMessageDelta + if err := json.Unmarshal(encoded, &back); err != nil { + t.Fatalf("json.Unmarshal(%s) error = %v", encoded, err) + } + if back.Role != tc.in.Role || back.Content != tc.in.Content { + t.Fatalf("round-trip: got %+v, want %+v", back, tc.in) + } + }) + } +} diff --git a/go/openai/openai.go b/go/openai/openai.go new file mode 100644 index 0000000..57d2f5d --- /dev/null +++ b/go/openai/openai.go @@ -0,0 +1,999 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package openai adapts inference.TextModel implementations to the +// OpenAI-compatible chat completions wire format. +package openai + +import ( + "context" + "io" + "net/http" + "strconv" + "sync" + "time" + "unicode" + + core "dappco.re/go" + "dappco.re/go/inference" + "dappco.re/go/inference/jsonenc" +) + +const DefaultChatCompletionsPath = "/v1/chat/completions" + +const ( + DefaultTemperature = 1.0 + DefaultTopP = 0.95 + DefaultTopK = 64 + DefaultMaxTokens = 2048 +) + +const channelMarker = "<|channel>" + +// ChatCompletionRequest is the OpenAI-compatible request body. +type ChatCompletionRequest struct { + Model string `json:"model"` + Messages []ChatMessage `json:"messages"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + MaxTokens *int `json:"max_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop StopList `json:"stop,omitempty"` + User string `json:"user,omitempty"` +} + +// StopList accepts OpenAI stop sequences as either a JSON string or string +// array. +type StopList []string + +func (s *StopList) UnmarshalJSON(data []byte) error { + // Hot path: this is called per OpenAI chat-completion request. + // parseJSONStringList walks the variant string-or-array shape in + // a single pass — drops the recursive core.JSONUnmarshal that + // re-paid encoder-state + per-element string allocs on every + // call. Same wire contract: null -> nil, "X" -> []string{"X"}, + // ["X","Y"] -> []string{"X","Y"}. + values, err := parseJSONStringList(data) + if err != nil { + return err + } + *s = values + return nil +} + +// ChatMessage is a single chat turn. +type ChatMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ChatCompletionResponse is the non-streaming OpenAI-compatible response body. +type ChatCompletionResponse struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChoice `json:"choices"` + Usage ChatUsage `json:"usage"` + Thought *string `json:"thought,omitempty"` +} + +type ChatChoice struct { + Index int `json:"index"` + Message ChatMessage `json:"message"` + FinishReason string `json:"finish_reason"` +} + +type ChatUsage struct { + PromptTokens int `json:"prompt_tokens"` + CompletionTokens int `json:"completion_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// ChatCompletionChunk is one Server-Sent Event payload for streaming requests. +type ChatCompletionChunk struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Choices []ChatChunkChoice `json:"choices"` + Thought *string `json:"thought,omitempty"` +} + +type ChatChunkChoice struct { + Index int `json:"index"` + Delta ChatMessageDelta `json:"delta"` + FinishReason *string `json:"finish_reason"` +} + +type ChatMessageDelta struct { + Role string `json:"role,omitempty"` + Content string `json:"content,omitempty"` +} + +// MarshalJSON hand-rolls the OpenAI ChatMessageDelta shape into a +// single caller-owned buffer. Fires per streamed SSE delta — the +// reflect path through encoding/json + the intermediate *string +// envelope structs together cost 4-5 allocs per call (encoder state, +// grow-doubled output, two pointer-string copies, JSONMarshalString +// AsString wrap). Hand-roll lands at 1 alloc for the typical +// content-only case and the role-priming case. +// +// Wire-compatible cases (matches the previous behaviour): +// - Role == "" && Content == "" -> {} +// - Role set -> {"role":"X","content":"Y"} (priming emits both) +// - Content only -> {"content":"Y"} +// +// Empty case routes to the package-level emptyDeltaBytes — no alloc. +func (d ChatMessageDelta) MarshalJSON() ([]byte, error) { + if d.Role == "" && d.Content == "" { + return emptyDeltaBytes, nil + } + // Tight upper bound — both branches emit two ASCII keys plus the + // quoted value bodies. Worst-case doubling on escape-heavy content + // lets append grow once. + size := 2 // braces + if d.Role != "" { + size += 9 + len(d.Role) // "role":"...", + size += 11 + len(d.Content) // "content":"..." + } else { + size += 11 + len(d.Content) // "content":"..." + } + buf := make([]byte, 0, size) + buf = append(buf, '{') + if d.Role != "" { + buf = jsonenc.AppendStringField(buf, "role", d.Role, false) + buf = jsonenc.AppendStringField(buf, "content", d.Content, true) + } else { + buf = jsonenc.AppendStringField(buf, "content", d.Content, false) + } + return append(buf, '}'), nil +} + +// emptyDeltaBytes is the canonical "{}" slice returned for the +// no-fields case — shared across every priming/closing chunk that +// would otherwise allocate a fresh two-byte slice per call. +var emptyDeltaBytes = []byte("{}") + +type ErrorResponse struct { + Error ErrorObject `json:"error"` +} + +type ErrorObject struct { + Message string `json:"message"` + Type string `json:"type"` + Param string `json:"param,omitempty"` + Code string `json:"code"` +} + +// DecodeRequest decodes an OpenAI-compatible chat completion request. +func DecodeRequest(body io.Reader) (ChatCompletionRequest, error) { + if body == nil { + return ChatCompletionRequest{}, core.E("openai.DecodeRequest", "request body is nil", nil) + } + data, err := io.ReadAll(body) + if err != nil { + return ChatCompletionRequest{}, core.E("openai.DecodeRequest", "read request body", err) + } + var req ChatCompletionRequest + // Direct []byte path — skips the redundant []byte→string→[]byte + // round-trip that JSONUnmarshalString(string(data), ...) would do. + result := core.JSONUnmarshal(data, &req) + if !result.OK { + return ChatCompletionRequest{}, resultError(result) + } + return req, nil +} + +// ValidateRequest validates the subset of the OpenAI request shape supported by +// this adapter. +func ValidateRequest(req ChatCompletionRequest) error { + if core.Trim(req.Model) == "" { + return requestError("model is required", "model") + } + if len(req.Messages) == 0 { + return requestError("messages must be a non-empty array", "messages") + } + for i, msg := range req.Messages { + role := core.Lower(core.Trim(msg.Role)) + switch role { + case "system", "developer", "user", "assistant", "tool": + default: + return requestError(core.Sprintf("messages[%d].role must be system, developer, user, assistant, or tool", i), core.Sprintf("messages[%d].role", i)) + } + } + if req.Temperature != nil && (*req.Temperature < 0 || *req.Temperature > 2) { + return requestError("temperature must be in [0, 2]", "temperature") + } + if req.TopP != nil && (*req.TopP < 0 || *req.TopP > 1) { + return requestError("top_p must be in [0, 1]", "top_p") + } + if req.TopK != nil && *req.TopK < 0 { + return requestError("top_k must be >= 0", "top_k") + } + if req.MaxTokens != nil && *req.MaxTokens < 0 { + return requestError("max_tokens must be >= 0", "max_tokens") + } + return nil +} + +// GenerateOptions converts request sampling fields into inference options. +func GenerateOptions(req ChatCompletionRequest) ([]inference.GenerateOption, error) { + if err := ValidateRequest(req); err != nil { + return nil, err + } + return []inference.GenerateOption{ + inference.WithTemperature(resolvedFloat(req.Temperature, DefaultTemperature)), + inference.WithTopP(resolvedFloat(req.TopP, DefaultTopP)), + inference.WithTopK(resolvedInt(req.TopK, DefaultTopK)), + inference.WithMaxTokens(resolvedInt(req.MaxTokens, DefaultMaxTokens)), + }, nil +} + +func resolvedFloat(value *float32, fallback float32) float32 { + if value == nil { + return fallback + } + return *value +} + +func resolvedInt(value *int, fallback int) int { + if value == nil { + return fallback + } + return *value +} + +// NormalizeStopSequences trims and validates request stop strings. +func NormalizeStopSequences(stops StopList) ([]string, error) { + if len(stops) == 0 { + return nil, nil + } + out := make([]string, 0, len(stops)) + for _, stop := range stops { + trimmed := core.Trim(stop) + if trimmed == "" { + return nil, requestError("stop sequences must not be empty", "stop") + } + out = append(out, trimmed) + } + return out, nil +} + +// Resolver maps request model names to loaded inference models. +type Resolver interface { + ResolveModel(ctx context.Context, name string) (inference.TextModel, error) +} + +type ResolverFunc func(context.Context, string) (inference.TextModel, error) + +func (fn ResolverFunc) ResolveModel(ctx context.Context, name string) (inference.TextModel, error) { + if fn == nil { + return nil, core.E("openai.ResolverFunc", "resolver is nil", nil) + } + return fn(ctx, name) +} + +type StaticResolver struct { + models map[string]inference.TextModel +} + +func NewStaticResolver(models map[string]inference.TextModel) *StaticResolver { + resolver := &StaticResolver{models: make(map[string]inference.TextModel, len(models))} + for name, model := range models { + resolver.models[core.Lower(core.Trim(name))] = model + } + return resolver +} + +func (r *StaticResolver) ResolveModel(ctx context.Context, name string) (inference.TextModel, error) { + if ctx != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + if r == nil { + return nil, core.E("openai.StaticResolver", "resolver is nil", nil) + } + model, ok := r.models[core.Lower(core.Trim(name))] + if !ok || model == nil { + return nil, core.E("openai.StaticResolver", core.Sprintf("model %q not found", name), nil) + } + return model, nil +} + +// BackendResolver lazily loads one model through the inference backend registry. +type BackendResolver struct { + BackendName string + ModelPath string + LoadOptions []inference.LoadOption + + mu sync.Mutex + model inference.TextModel +} + +func NewBackendResolver(backendName, modelPath string, opts ...inference.LoadOption) *BackendResolver { + return &BackendResolver{ + BackendName: core.Trim(backendName), + ModelPath: core.Trim(modelPath), + LoadOptions: append([]inference.LoadOption(nil), opts...), + } +} + +func (r *BackendResolver) ResolveModel(ctx context.Context, _ string) (inference.TextModel, error) { + if ctx != nil { + select { + case <-ctx.Done(): + return nil, ctx.Err() + default: + } + } + if r == nil { + return nil, core.E("openai.BackendResolver", "resolver is nil", nil) + } + if r.ModelPath == "" { + return nil, core.E("openai.BackendResolver", "model path is required", nil) + } + r.mu.Lock() + defer r.mu.Unlock() + if r.model != nil { + return r.model, nil + } + opts := append([]inference.LoadOption(nil), r.LoadOptions...) + if r.BackendName != "" { + opts = append(opts, inference.WithBackend(r.BackendName)) + } + result := inference.LoadModel(r.ModelPath, opts...) + if !result.OK { + return nil, resultError(result) + } + model, ok := result.Value.(inference.TextModel) + if !ok || model == nil { + return nil, core.E("openai.BackendResolver", "loaded value is not an inference.TextModel", nil) + } + r.model = model + return model, nil +} + +// Handler serves OpenAI-compatible chat completion requests. +type Handler struct { + resolver Resolver +} + +func NewHandler(resolver Resolver) *Handler { + return &Handler{resolver: resolver} +} + +func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if h == nil || h.resolver == nil { + writeError(w, http.StatusServiceUnavailable, "chat handler is not configured", "model") + return + } + if r == nil { + writeError(w, http.StatusBadRequest, "request is nil", "request") + return + } + if r.Method != http.MethodPost { + w.Header().Set("Allow", http.MethodPost) + writeError(w, http.StatusMethodNotAllowed, "method not allowed", "method") + return + } + req, err := DecodeRequest(r.Body) + if err != nil { + writeError(w, http.StatusBadRequest, "invalid request body", "body") + return + } + if err := ValidateRequest(req); err != nil { + writeError(w, http.StatusBadRequest, err.Error(), errorParam(err)) + return + } + stops, err := NormalizeStopSequences(req.Stop) + if err != nil { + writeError(w, http.StatusBadRequest, err.Error(), "stop") + return + } + opts, err := GenerateOptions(ChatCompletionRequest{ + Model: req.Model, + Messages: req.Messages, + Temperature: req.Temperature, + TopP: req.TopP, + TopK: req.TopK, + MaxTokens: req.MaxTokens, + }) + if err != nil { + writeError(w, http.StatusBadRequest, err.Error(), errorParam(err)) + return + } + model, err := h.resolver.ResolveModel(r.Context(), req.Model) + if err != nil { + writeError(w, http.StatusNotFound, err.Error(), "model") + return + } + messages := requestMessages(req.Messages) + if req.Stream { + h.serveStreaming(w, r, model, req, messages, stops, opts...) + return + } + h.serveNonStreaming(w, r, model, req, messages, stops, opts...) +} + +func (h *Handler) serveNonStreaming(w http.ResponseWriter, r *http.Request, model inference.TextModel, req ChatCompletionRequest, messages []inference.Message, stops []string, opts ...inference.GenerateOption) { + created := time.Now().Unix() + completionID := completionID() + extractor := NewThinkingExtractor() + for token := range model.Chat(r.Context(), messages, opts...) { + extractor.Process(token) + } + visibleTail, thoughtTail := extractor.Flush() + _ = visibleTail + _ = thoughtTail + if err := model.Err(); err != nil { + writeError(w, http.StatusInternalServerError, err.Error(), "model") + return + } + metrics := model.Metrics() + content := TruncateAtStopSequence(extractor.Content(), stops) + finishReason := "stop" + if isTokenLengthCapReached(req.MaxTokens, metrics.GeneratedTokens) { + finishReason = "length" + } + response := ChatCompletionResponse{ + ID: completionID, + Object: "chat.completion", + Created: created, + Model: req.Model, + Choices: []ChatChoice{{ + Index: 0, + Message: ChatMessage{Role: "assistant", Content: content}, + FinishReason: finishReason, + }}, + Usage: ChatUsage{ + PromptTokens: metrics.PromptTokens, + CompletionTokens: metrics.GeneratedTokens, + TotalTokens: metrics.PromptTokens + metrics.GeneratedTokens, + }, + } + if thought := extractor.Thinking(); thought != "" { + response.Thought = &thought + } + writeJSON(w, http.StatusOK, response) +} + +func (h *Handler) serveStreaming(w http.ResponseWriter, r *http.Request, model inference.TextModel, req ChatCompletionRequest, messages []inference.Message, stops []string, opts ...inference.GenerateOption) { + w.Header().Set("Content-Type", "text/event-stream") + w.Header().Set("Cache-Control", "no-cache") + w.Header().Set("Connection", "keep-alive") + w.WriteHeader(http.StatusOK) + + created := time.Now().Unix() + completionID := completionID() + flusher, _ := w.(http.Flusher) + writeChunk := func(chunk ChatCompletionChunk) { + // Single-buffer SSE frame — the previous shape did + // JSONMarshalString (reflect path + grow-doubled scratch + // buffer) then Concat to wrap with "data: " / "\n\n" then + // []byte conversion. appendChatCompletionChunkSSE walks the + // chunk directly into a pre-sized buffer that already carries + // the SSE framing. + frame := appendChatCompletionChunkSSE(make([]byte, 0, chunkSSEFrameSize(chunk)), chunk) + _, _ = w.Write(frame) + if flusher != nil { + flusher.Flush() + } + } + writeChunk(ChatCompletionChunk{ + ID: completionID, + Object: "chat.completion.chunk", + Created: created, + Model: req.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatMessageDelta{Role: "assistant"}, + }}, + }) + + extractor := NewThinkingExtractor() + emittedContent := "" + finishReason := "stop" + for token := range model.Chat(r.Context(), messages, opts...) { + contentDelta, thoughtDelta := extractor.Process(token) + candidate := emittedContent + contentDelta + stopCut, stopHit := firstStopSequenceCut(candidate, stops) + if stopHit { + if stopCut <= len(emittedContent) { + contentDelta = "" + } else { + contentDelta = candidate[len(emittedContent):stopCut] + } + } + if contentDelta != "" || thoughtDelta != "" { + chunk := ChatCompletionChunk{ + ID: completionID, + Object: "chat.completion.chunk", + Created: created, + Model: req.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatMessageDelta{Content: contentDelta}, + }}, + } + if thoughtDelta != "" { + chunk.Thought = &thoughtDelta + } + writeChunk(chunk) + } + if stopHit { + emittedContent = candidate[:stopCut] + break + } + emittedContent = candidate + } + if visibleTail, thoughtTail := extractor.Flush(); visibleTail != "" || thoughtTail != "" { + chunk := ChatCompletionChunk{ + ID: completionID, + Object: "chat.completion.chunk", + Created: created, + Model: req.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatMessageDelta{Content: visibleTail}, + }}, + } + if thoughtTail != "" { + chunk.Thought = &thoughtTail + } + writeChunk(chunk) + } + if err := model.Err(); err != nil { + finishReason = "error" + } + if finishReason != "error" && isTokenLengthCapReached(req.MaxTokens, model.Metrics().GeneratedTokens) { + finishReason = "length" + } + writeChunk(ChatCompletionChunk{ + ID: completionID, + Object: "chat.completion.chunk", + Created: created, + Model: req.Model, + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatMessageDelta{}, + FinishReason: &finishReason, + }}, + }) + _, _ = w.Write([]byte("data: [DONE]\n\n")) + if flusher != nil { + flusher.Flush() + } +} + +func requestMessages(messages []ChatMessage) []inference.Message { + out := make([]inference.Message, 0, len(messages)) + for _, msg := range messages { + out = append(out, inference.Message{Role: msg.Role, Content: msg.Content}) + } + return out +} + +func writeJSON(w http.ResponseWriter, status int, payload any) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(status) + // Hand-rolled fast path for the canonical non-streaming + // ChatCompletionResponse — fires once per served request and + // previously paid 2 allocs / 432 B through the reflect path. + // Encoding directly into a pre-sized buffer skips + // JSONMarshalString + the []byte(string) conversion. + if p, ok := payload.(ChatCompletionResponse); ok { + buf := appendChatCompletionResponse(make([]byte, 0, chatCompletionResponseSize(p)), p) + _, _ = w.Write(buf) + return + } + if p, ok := payload.(EmbeddingResponse); ok { + // Embedding responses scale with vector dimensionality — + // a 20-input × 1024-dim response is ~190 KB. The reflect + // path pays a per-element float32 marshal cost; the hand- + // rolled walk emits directly via strconv.AppendFloat. + buf := appendEmbeddingResponse(make([]byte, 0, embeddingResponseSize(p)), p) + _, _ = w.Write(buf) + return + } + if p, ok := payload.(Response); ok { + // Responses API non-streaming body — fires per served + // /v1/responses request. Same shape as ChatCompletionResponse + // (id/object/created/model/output/usage/thought) but with + // the Responses output-message envelope. + buf := appendResponse(make([]byte, 0, responseSize(p)), p) + _, _ = w.Write(buf) + return + } + if p, ok := payload.(RerankResponse); ok { + // Rerank results scale with the documents slice — walking + // inference.RerankScore inline skips the per-element reflect + // cost. Labels field is rarely set in practice; encoder + // handles both shapes. + buf := appendRerankResponse(make([]byte, 0, rerankResponseSize(p)), p) + _, _ = w.Write(buf) + return + } + result := core.JSONMarshal(payload) + if !result.OK { + _, _ = w.Write([]byte(`{}`)) + return + } + _, _ = w.Write(result.Value.([]byte)) +} + +func writeError(w http.ResponseWriter, status int, message, param string) { + writeJSON(w, status, ErrorResponse{Error: ErrorObject{ + Message: message, + Type: "invalid_request_error", + Param: param, + Code: "invalid_request_error", + }}) +} + +type requestValidationError struct { + message string + param string +} + +func (e *requestValidationError) Error() string { + if e == nil { + return "" + } + return e.message +} + +func requestError(message, param string) error { + return &requestValidationError{message: message, param: param} +} + +func errorParam(err error) string { + if validation, ok := err.(*requestValidationError); ok { + return validation.param + } + return "" +} + +func resultError(result core.Result) error { + if result.OK { + return nil + } + if err, ok := result.Value.(error); ok { + return err + } + return core.E("openai.result", "unexpected failed result value", nil) +} + +func completionID() string { + // Fires once per chat-completion response. core.Sprintf was 2 allocs + // (fmt formatter scratch + result string); the append-into-prefix + // path is a single alloc backing the returned string via AsString. + buf := make([]byte, 0, 32) // "chatcmpl-" (9) + max int64 (20) + slack + buf = append(buf, "chatcmpl-"...) + buf = strconv.AppendInt(buf, time.Now().UnixNano(), 10) + return core.AsString(buf) +} + +func isTokenLengthCapReached(maxTokens *int, generated int) bool { + return maxTokens != nil && *maxTokens > 0 && generated >= *maxTokens +} + +// TruncateAtStopSequence removes the first matching stop sequence and anything +// after it. +func TruncateAtStopSequence(content string, stops []string) string { + cut, ok := firstStopSequenceCut(content, stops) + if !ok { + return content + } + return content[:cut] +} + +func firstStopSequenceCut(content string, stops []string) (int, bool) { + if content == "" || len(stops) == 0 { + return 0, false + } + best := -1 + for _, stop := range stops { + idx := indexString(content, stop) + if idx < 0 { + continue + } + if best < 0 || idx < best { + best = idx + } + } + if best < 0 { + return 0, false + } + return best, true +} + +// indexString delegates to core.Index (strings.Index — Rabin-Karp + +// SIMD byte search). The earlier hand-rolled loop was O(N×M) per call +// and fired multiple times per chat-completion (stop-sequence cut + +// thinking-extractor per streaming chunk + channel-marker detection +// on every delta). +// +// Returns -1 on empty needle to preserve the caller contract — the +// stop-sequence + extractor paths treat empty as "no match" rather +// than the strings.Index "match at 0" semantics. +func indexString(s, needle string) int { + if needle == "" { + return -1 + } + return core.Index(s, needle) +} + +type pairedMarker struct { + start string + end string +} + +var reasoningMarkers = []pairedMarker{ + {start: "", end: ""}, + {start: "", end: ""}, + {start: "", end: ""}, + {start: "", end: ""}, +} + +// ThinkingExtractor separates model-internal reasoning text from assistant +// content. +type ThinkingExtractor struct { + pending string + content string + thinking string + inPaired bool + pairedEnd string + currentChannel string +} + +func NewThinkingExtractor() *ThinkingExtractor { + return &ThinkingExtractor{currentChannel: "assistant"} +} + +func (e *ThinkingExtractor) Process(token inference.Token) (contentDelta, thoughtDelta string) { + if e == nil { + return "", "" + } + e.pending += token.Text + return e.drain(false) +} + +func (e *ThinkingExtractor) Flush() (contentDelta, thoughtDelta string) { + if e == nil { + return "", "" + } + contentDelta, thoughtDelta = e.drain(true) + if e.pending == "" { + return contentDelta, thoughtDelta + } + if e.inPaired || e.currentChannel == "thought" || e.currentChannel == "thinking" || e.currentChannel == "reasoning" { + thoughtDelta += e.pending + e.thinking += e.pending + } else { + contentDelta += e.pending + e.content += e.pending + } + e.pending = "" + e.inPaired = false + return contentDelta, thoughtDelta +} + +func (e *ThinkingExtractor) Content() string { + if e == nil { + return "" + } + return e.content +} + +func (e *ThinkingExtractor) Thinking() string { + if e == nil { + return "" + } + return e.thinking +} + +func (e *ThinkingExtractor) drain(final bool) (string, string) { + contentDelta := core.NewBuilder() + thoughtDelta := core.NewBuilder() + for e.pending != "" { + if e.inPaired { + idx := indexString(e.pending, e.pairedEnd) + if idx >= 0 { + writeThought(e, thoughtDelta, e.pending[:idx]) + e.pending = e.pending[idx+len(e.pairedEnd):] + e.inPaired = false + e.pairedEnd = "" + continue + } + emit, keep := splitSafeSuffix(e.pending, []string{e.pairedEnd}, final) + writeThought(e, thoughtDelta, emit) + e.pending = keep + if keep != "" && !final { + break + } + continue + } + + if ok := e.consumeMarkerAtStart(); ok { + continue + } + + if e.currentChannel == "thought" || e.currentChannel == "thinking" || e.currentChannel == "reasoning" { + idx := indexString(e.pending, channelMarker) + if idx >= 0 { + writeThought(e, thoughtDelta, e.pending[:idx]) + e.pending = e.pending[idx:] + if e.consumeMarkerAtStart() { + continue + } + if !final { + break + } + writeThought(e, thoughtDelta, channelMarker) + e.pending = e.pending[len(channelMarker):] + continue + } + emit, keep := splitSafeSuffix(e.pending, []string{channelMarker}, final) + writeThought(e, thoughtDelta, emit) + e.pending = keep + if keep != "" && !final { + break + } + continue + } + + start, idx := earliestReasoningStart(e.pending) + channelIdx := indexString(e.pending, channelMarker) + if channelIdx >= 0 && (idx < 0 || channelIdx < idx) { + idx = channelIdx + start = channelMarker + } + if idx >= 0 { + writeContent(e, contentDelta, e.pending[:idx]) + e.pending = e.pending[idx:] + if start == channelMarker { + if e.consumeMarkerAtStart() { + continue + } + if !final { + break + } + writeContent(e, contentDelta, channelMarker) + e.pending = e.pending[len(channelMarker):] + continue + } + e.inPaired = true + e.pairedEnd = pairedEndFor(start) + e.pending = e.pending[len(start):] + continue + } + emit, keep := splitSafeSuffix(e.pending, markerStarts(), final) + writeContent(e, contentDelta, emit) + e.pending = keep + if keep != "" && !final { + break + } + } + return contentDelta.String(), thoughtDelta.String() +} + +func (e *ThinkingExtractor) consumeMarkerAtStart() bool { + if !core.HasPrefix(e.pending, channelMarker) { + for _, marker := range reasoningMarkers { + if core.HasPrefix(e.pending, marker.start) { + e.inPaired = true + e.pairedEnd = marker.end + e.pending = e.pending[len(marker.start):] + return true + } + } + return false + } + remaining := e.pending[len(channelMarker):] + consumedSpace := 0 + for consumedSpace < len(remaining) { + r, size := rune(remaining[consumedSpace]), 1 + if r >= 0x80 { + r, size = utf8Rune(remaining[consumedSpace:]) + } + if !unicode.IsSpace(r) { + break + } + consumedSpace += size + } + nameLen := 0 + for consumedSpace+nameLen < len(remaining) { + c := remaining[consumedSpace+nameLen] + if (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' || c == '-' { + nameLen++ + continue + } + break + } + if nameLen == 0 { + return false + } + e.currentChannel = core.Lower(remaining[consumedSpace : consumedSpace+nameLen]) + e.pending = remaining[consumedSpace+nameLen:] + return true +} + +func utf8Rune(s string) (rune, int) { + for _, r := range s { + return r, len(string(r)) + } + return 0, 0 +} + +func writeContent(e *ThinkingExtractor, builder interface{ WriteString(string) (int, error) }, text string) { + if text == "" { + return + } + builder.WriteString(text) + e.content += text +} + +func writeThought(e *ThinkingExtractor, builder interface{ WriteString(string) (int, error) }, text string) { + if text == "" { + return + } + builder.WriteString(text) + e.thinking += text +} + +func earliestReasoningStart(s string) (string, int) { + best := -1 + bestStart := "" + for _, marker := range reasoningMarkers { + idx := indexString(s, marker.start) + if idx < 0 { + continue + } + if best < 0 || idx < best { + best = idx + bestStart = marker.start + } + } + return bestStart, best +} + +func pairedEndFor(start string) string { + for _, marker := range reasoningMarkers { + if marker.start == start { + return marker.end + } + } + return "" +} + +func markerStarts() []string { + out := make([]string, 0, len(reasoningMarkers)+1) + out = append(out, channelMarker) + for _, marker := range reasoningMarkers { + out = append(out, marker.start) + } + return out +} + +func splitSafeSuffix(s string, markers []string, final bool) (emit, keep string) { + if final { + return s, "" + } + keepLen := 0 + for _, marker := range markers { + max := min(len(s), len(marker)-1) + for n := 1; n <= max; n++ { + if s[len(s)-n:] == marker[:n] && n > keepLen { + keepLen = n + } + } + } + if keepLen == 0 { + return s, "" + } + return s[:len(s)-keepLen], s[len(s)-keepLen:] +} diff --git a/go/openai/openai_bench_test.go b/go/openai/openai_bench_test.go new file mode 100644 index 0000000..255254f --- /dev/null +++ b/go/openai/openai_bench_test.go @@ -0,0 +1,572 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the OpenAI-compatible chat-completions wire primitives. +// Per AX-11 — these surfaces fire on every served chat request: +// * DecodeRequest + ValidateRequest at request entry +// * GenerateOptions / NormalizeStopSequences after validation +// * ChatMessageDelta.MarshalJSON per streamed delta +// * indexString + firstStopSequenceCut per delta in the SSE loop +// * TruncateAtStopSequence at end-of-stream +// * ThinkingExtractor.Process per token (channel + paired-marker scan) +// +// Run: go test -bench='BenchmarkOpenAI' -benchtime=100ms -benchmem -run='^$' . + +package openai + +import ( + "strings" + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + openAISinkChatRequest ChatCompletionRequest + openAISinkChatResponse ChatCompletionResponse + openAISinkChunk ChatCompletionChunk + openAISinkOptions []inference.GenerateOption + openAISinkErr error + openAISinkStops []string + openAISinkString string + openAISinkStopList StopList + openAISinkInt int + openAISinkBool bool + openAISinkBytes []byte + openAISinkContent string + openAISinkThought string + openAISinkResult core.Result +) + +// --- Fixture bodies --- + +// openAISingleTurnBody mirrors the typical chat-completions request the +// handler decodes at request entry. +const openAISingleTurnBody = `{"model":"qwen3","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"Please summarise the following paragraph for me in one sentence."}],"temperature":0.7,"top_p":0.95,"max_tokens":256,"stream":true,"stop":["<|im_end|>"]}` + +// openAIFiveTurnBody is the realistic chat-history shape — 1 system + 4 +// user/assistant pairs. +const openAIFiveTurnBody = `{"model":"qwen3","messages":[{"role":"system","content":"You are a helpful assistant."},{"role":"user","content":"What is 2+2?"},{"role":"assistant","content":"4"},{"role":"user","content":"Are you sure?"},{"role":"assistant","content":"Yes."},{"role":"user","content":"Why?"}],"temperature":0.7,"max_tokens":256,"stream":true}` + +// openAITwentyTurnBody — long-running session shape, exercises the +// slice-grow path inside the ChatMessage decode loop. +var openAITwentyTurnBody = buildOpenAITurnsBody(20) + +func buildOpenAITurnsBody(turns int) string { + out := core.NewBuilder() + out.WriteString(`{"model":"qwen3","messages":[`) + out.WriteString(`{"role":"system","content":"You are a helpful assistant."}`) + user := `,{"role":"user","content":"How many tokens does this paragraph contain when measured against the GPT-2 tokeniser?"}` + assistant := `,{"role":"assistant","content":"That depends on the precise tokeniser implementation but is approximately 32."}` + for i := 0; i < turns; i++ { + if i%2 == 0 { + out.WriteString(user) + } else { + out.WriteString(assistant) + } + } + out.WriteString(`],"max_tokens":1024,"stream":true}`) + return out.String() +} + +// buildChatRequest mirrors a decoded ChatCompletionRequest with the +// requested turn count. Used for Marshal benches. +func buildChatRequest(turns int) ChatCompletionRequest { + temperature := float32(0.7) + topP := float32(0.95) + topK := 64 + maxTokens := 256 + req := ChatCompletionRequest{ + Model: "qwen3", + Temperature: &temperature, + TopP: &topP, + TopK: &topK, + MaxTokens: &maxTokens, + Stream: true, + Stop: StopList{"<|im_end|>", "<|eot_id|>"}, + } + req.Messages = append(req.Messages, ChatMessage{Role: "system", Content: "You are a helpful assistant."}) + for i := 0; i < turns; i++ { + if i%2 == 0 { + req.Messages = append(req.Messages, ChatMessage{Role: "user", Content: "Summarise the paragraph in one sentence."}) + } else { + req.Messages = append(req.Messages, ChatMessage{Role: "assistant", Content: "The summary captures the key claim."}) + } + } + return req +} + +// --- DecodeRequest — front-of-handler JSON decode --- + +func BenchmarkOpenAI_DecodeRequest_SingleTurn(b *testing.B) { + body := openAISingleTurnBody + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkChatRequest, openAISinkErr = DecodeRequest(strings.NewReader(body)) + } +} + +func BenchmarkOpenAI_DecodeRequest_FiveTurn(b *testing.B) { + body := openAIFiveTurnBody + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkChatRequest, openAISinkErr = DecodeRequest(strings.NewReader(body)) + } +} + +func BenchmarkOpenAI_DecodeRequest_TwentyTurn(b *testing.B) { + body := openAITwentyTurnBody + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkChatRequest, openAISinkErr = DecodeRequest(strings.NewReader(body)) + } +} + +func BenchmarkOpenAI_DecodeRequest_StopAsString(b *testing.B) { + body := `{"model":"qwen3","messages":[{"role":"user","content":"hi"}],"stop":"END"}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkChatRequest, openAISinkErr = DecodeRequest(strings.NewReader(body)) + } +} + +func BenchmarkOpenAI_DecodeRequest_StopAsArray(b *testing.B) { + body := `{"model":"qwen3","messages":[{"role":"user","content":"hi"}],"stop":["END","<|eot_id|>",""]}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkChatRequest, openAISinkErr = DecodeRequest(strings.NewReader(body)) + } +} + +// --- StopList.UnmarshalJSON — direct-call bench bypasses the wrapping +// JSON decoder, isolating the variant-parse cost. --- + +func BenchmarkOpenAI_StopList_UnmarshalJSON_String(b *testing.B) { + data := []byte(`"END"`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var sl StopList + openAISinkErr = sl.UnmarshalJSON(data) + openAISinkStopList = sl + } +} + +func BenchmarkOpenAI_StopList_UnmarshalJSON_Array(b *testing.B) { + data := []byte(`["<|im_end|>","<|eot_id|>",""]`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var sl StopList + openAISinkErr = sl.UnmarshalJSON(data) + openAISinkStopList = sl + } +} + +// --- ValidateRequest — request-shape validation after decode --- + +func BenchmarkOpenAI_ValidateRequest_SingleTurn(b *testing.B) { + req := buildChatRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkErr = ValidateRequest(req) + } +} + +func BenchmarkOpenAI_ValidateRequest_TwentyTurn(b *testing.B) { + req := buildChatRequest(20) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkErr = ValidateRequest(req) + } +} + +// --- GenerateOptions — sampling-field projection --- + +func BenchmarkOpenAI_GenerateOptions_AllFieldsSet(b *testing.B) { + req := buildChatRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkOptions, openAISinkErr = GenerateOptions(req) + } +} + +func BenchmarkOpenAI_GenerateOptions_DefaultsOnly(b *testing.B) { + req := ChatCompletionRequest{ + Model: "qwen3", + Messages: []ChatMessage{{Role: "user", Content: "hi"}}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkOptions, openAISinkErr = GenerateOptions(req) + } +} + +// --- NormalizeStopSequences — per-request stop-sequence projection --- + +func BenchmarkOpenAI_NormalizeStopSequences_Empty(b *testing.B) { + stops := StopList{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkStops, openAISinkErr = NormalizeStopSequences(stops) + } +} + +func BenchmarkOpenAI_NormalizeStopSequences_Typical(b *testing.B) { + stops := StopList{"<|im_end|>", "<|eot_id|>", ""} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkStops, openAISinkErr = NormalizeStopSequences(stops) + } +} + +// --- ChatMessageDelta.MarshalJSON — per-streamed-delta encode --- +// Hits every SSE frame the streaming handler emits. + +func BenchmarkOpenAI_ChatMessageDelta_Marshal_ContentOnly(b *testing.B) { + delta := ChatMessageDelta{Content: "Answer"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkBytes, openAISinkErr = delta.MarshalJSON() + } +} + +func BenchmarkOpenAI_ChatMessageDelta_Marshal_RolePriming(b *testing.B) { + delta := ChatMessageDelta{Role: "assistant"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkBytes, openAISinkErr = delta.MarshalJSON() + } +} + +func BenchmarkOpenAI_ChatMessageDelta_Marshal_Empty(b *testing.B) { + delta := ChatMessageDelta{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkBytes, openAISinkErr = delta.MarshalJSON() + } +} + +// --- ChatCompletionChunk — full SSE frame marshal --- +// What writeChunk runs once per streamed token plus the terminal frame. + +func BenchmarkOpenAI_MarshalChatCompletionChunk_Delta(b *testing.B) { + chunk := ChatCompletionChunk{ + ID: "chatcmpl-bench", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "qwen3", + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatMessageDelta{Content: "Answer"}, + }}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkString = core.JSONMarshalString(chunk) + } +} + +func BenchmarkOpenAI_MarshalChatCompletionChunk_Final(b *testing.B) { + finish := "stop" + chunk := ChatCompletionChunk{ + ID: "chatcmpl-bench", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "qwen3", + Choices: []ChatChunkChoice{{ + Index: 0, + Delta: ChatMessageDelta{}, + FinishReason: &finish, + }}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkString = core.JSONMarshalString(chunk) + } +} + +// --- Hand-rolled chunk-as-SSE-frame — the streaming hot path --- +// Fires per token. The single-buffer frame builder replaces the +// JSONMarshalString + Concat + []byte conversion three-allocation +// chain that the streaming handler used pre-W9-D. + +func BenchmarkOpenAI_AppendChatCompletionChunkSSE_Priming(b *testing.B) { + chunk := ChatCompletionChunk{ + ID: "chatcmpl-bench", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{Role: "assistant"}}}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkBytes = appendChatCompletionChunkSSE(make([]byte, 0, chunkSSEFrameSize(chunk)), chunk) + } +} + +func BenchmarkOpenAI_AppendChatCompletionChunkSSE_Delta(b *testing.B) { + chunk := ChatCompletionChunk{ + ID: "chatcmpl-bench", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{Content: "Answer"}}}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkBytes = appendChatCompletionChunkSSE(make([]byte, 0, chunkSSEFrameSize(chunk)), chunk) + } +} + +func BenchmarkOpenAI_AppendChatCompletionChunkSSE_Final(b *testing.B) { + finish := "stop" + chunk := ChatCompletionChunk{ + ID: "chatcmpl-bench", + Object: "chat.completion.chunk", + Created: 1700000000, + Model: "qwen3", + Choices: []ChatChunkChoice{{Index: 0, Delta: ChatMessageDelta{}, FinishReason: &finish}}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkBytes = appendChatCompletionChunkSSE(make([]byte, 0, chunkSSEFrameSize(chunk)), chunk) + } +} + +// --- ChatCompletionResponse — non-streaming response marshal --- + +// AppendChatCompletionResponse — hand-rolled fast path used by +// writeJSON for the canonical non-streaming response shape. +func BenchmarkOpenAI_AppendChatCompletionResponse_Typical(b *testing.B) { + resp := ChatCompletionResponse{ + ID: "chatcmpl-bench", + Object: "chat.completion", + Created: 1700000000, + Model: "qwen3", + Choices: []ChatChoice{{ + Index: 0, + Message: ChatMessage{Role: "assistant", Content: "The summary is concise and faithful to the original text."}, + FinishReason: "stop", + }}, + Usage: ChatUsage{PromptTokens: 200, CompletionTokens: 32, TotalTokens: 232}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkBytes = appendChatCompletionResponse(make([]byte, 0, chatCompletionResponseSize(resp)), resp) + } +} + +func BenchmarkOpenAI_MarshalChatCompletionResponse_Typical(b *testing.B) { + resp := ChatCompletionResponse{ + ID: "chatcmpl-bench", + Object: "chat.completion", + Created: 1700000000, + Model: "qwen3", + Choices: []ChatChoice{{ + Index: 0, + Message: ChatMessage{Role: "assistant", Content: "The summary is concise and faithful to the original text."}, + FinishReason: "stop", + }}, + Usage: ChatUsage{PromptTokens: 200, CompletionTokens: 32, TotalTokens: 232}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkString = core.JSONMarshalString(resp) + } +} + +// --- indexString — primitive substring scan used by stop-sequence cut --- + +func BenchmarkOpenAI_IndexString_Miss(b *testing.B) { + content := strings.Repeat("answer fragment ", 32) // ~512 chars + needle := "<|im_end|>" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkInt = indexString(content, needle) + } +} + +func BenchmarkOpenAI_IndexString_EarlyHit(b *testing.B) { + content := "<|im_end|>" + strings.Repeat("answer fragment ", 32) + needle := "<|im_end|>" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkInt = indexString(content, needle) + } +} + +func BenchmarkOpenAI_IndexString_LateHit(b *testing.B) { + content := strings.Repeat("answer fragment ", 32) + "<|im_end|>" + needle := "<|im_end|>" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkInt = indexString(content, needle) + } +} + +// --- firstStopSequenceCut — per-delta scan in the SSE loop --- +// Scales O(content × |stops|) so multi-stop request shapes pay more. + +func BenchmarkOpenAI_FirstStopSequenceCut_Miss(b *testing.B) { + content := strings.Repeat("answer fragment ", 32) + stops := []string{"<|im_end|>", "<|eot_id|>"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkInt, openAISinkBool = firstStopSequenceCut(content, stops) + } +} + +func BenchmarkOpenAI_FirstStopSequenceCut_LateHit(b *testing.B) { + content := strings.Repeat("answer fragment ", 32) + "<|im_end|>" + stops := []string{"<|im_end|>", "<|eot_id|>"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkInt, openAISinkBool = firstStopSequenceCut(content, stops) + } +} + +func BenchmarkOpenAI_FirstStopSequenceCut_EarlyHit(b *testing.B) { + content := "<|im_end|>" + strings.Repeat("answer fragment ", 32) + stops := []string{"<|im_end|>", "<|eot_id|>"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkInt, openAISinkBool = firstStopSequenceCut(content, stops) + } +} + +// --- TruncateAtStopSequence — end-of-stream guard --- + +func BenchmarkOpenAI_TruncateAtStopSequence_NoMatch(b *testing.B) { + content := strings.Repeat("answer fragment ", 32) + stops := []string{"<|im_end|>", "<|eot_id|>"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkString = TruncateAtStopSequence(content, stops) + } +} + +func BenchmarkOpenAI_TruncateAtStopSequence_Match(b *testing.B) { + content := strings.Repeat("answer fragment ", 32) + "<|im_end|> ignored" + stops := []string{"<|im_end|>", "<|eot_id|>"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkString = TruncateAtStopSequence(content, stops) + } +} + +// --- ThinkingExtractor — per-token reasoning split --- +// Runs on every token of every chat completion. The marker scans inside +// Process are where the cost sits. + +func BenchmarkOpenAI_ThinkingExtractor_Process_PlainTokenShort(b *testing.B) { + tokens := []inference.Token{{Text: "Answer"}} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + extractor := NewThinkingExtractor() + openAISinkContent, openAISinkThought = extractor.Process(tokens[0]) + } +} + +func BenchmarkOpenAI_ThinkingExtractor_Process_PairedThinkBlock(b *testing.B) { + tokens := []inference.Token{{Text: "planAnswer"}} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + extractor := NewThinkingExtractor() + openAISinkContent, openAISinkThought = extractor.Process(tokens[0]) + c, t := extractor.Flush() + openAISinkContent = c + openAISinkThought = t + } +} + +func BenchmarkOpenAI_ThinkingExtractor_Process_ChannelMarker(b *testing.B) { + token := inference.Token{Text: "<|channel>thought hidden<|channel>assistant Answer"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + extractor := NewThinkingExtractor() + openAISinkContent, openAISinkThought = extractor.Process(token) + c, t := extractor.Flush() + openAISinkContent = c + openAISinkThought = t + } +} + +// Long delta — 256 chars without any marker substrate, hits the +// hot-path scan-then-emit branch for every streamed token. +func BenchmarkOpenAI_ThinkingExtractor_Process_LongPlainDelta(b *testing.B) { + token := inference.Token{Text: strings.Repeat("answer fragment ", 16)} // 256 chars + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + extractor := NewThinkingExtractor() + openAISinkContent, openAISinkThought = extractor.Process(token) + } +} + +// --- requestMessages — wire→internal conversion --- + +func BenchmarkOpenAI_RequestMessages_SingleTurn(b *testing.B) { + messages := []ChatMessage{ + {Role: "system", Content: "You are a helpful assistant."}, + {Role: "user", Content: "Summarise the paragraph."}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = requestMessages(messages) + } +} + +func BenchmarkOpenAI_RequestMessages_TwentyTurn(b *testing.B) { + req := buildChatRequest(20) + messages := req.Messages + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = requestMessages(messages) + } +} + +// --- completionID — request-level ID generator --- + +func BenchmarkOpenAI_CompletionID(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + openAISinkString = completionID() + } +} diff --git a/go/openai/openai_test.go b/go/openai/openai_test.go new file mode 100644 index 0000000..10f38f7 --- /dev/null +++ b/go/openai/openai_test.go @@ -0,0 +1,215 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "context" + "iter" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "dappco.re/go/inference" +) + +type stubModel struct { + tokens []inference.Token + metrics inference.GenerateMetrics + err error +} + +func (m *stubModel) Generate(context.Context, string, ...inference.GenerateOption) iter.Seq[inference.Token] { + return m.seq() +} + +func (m *stubModel) Chat(context.Context, []inference.Message, ...inference.GenerateOption) iter.Seq[inference.Token] { + return m.seq() +} + +func (m *stubModel) Classify(context.Context, []string, ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + return nil, nil +} + +func (m *stubModel) BatchGenerate(context.Context, []string, ...inference.GenerateOption) ([]inference.BatchResult, error) { + return nil, nil +} + +func (m *stubModel) ModelType() string { return "stub" } + +func (m *stubModel) Info() inference.ModelInfo { return inference.ModelInfo{Architecture: "qwen3"} } + +func (m *stubModel) Metrics() inference.GenerateMetrics { return m.metrics } + +func (m *stubModel) Err() error { return m.err } + +func (m *stubModel) Close() error { return nil } + +func (m *stubModel) seq() iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + for _, token := range m.tokens { + if !yield(token) { + return + } + } + } +} + +func TestOpenAI_DecodeRequest_Good_StopStringAndDefaults(t *testing.T) { + body := strings.NewReader(`{"model":"qwen","messages":[{"role":"user","content":"hi"}],"stop":"END"}`) + + req, err := DecodeRequest(body) + if err != nil { + t.Fatalf("DecodeRequest() error = %v", err) + } + if req.Model != "qwen" || len(req.Messages) != 1 { + t.Fatalf("DecodeRequest() = %+v", req) + } + stops, err := NormalizeStopSequences(req.Stop) + if err != nil { + t.Fatalf("NormalizeStopSequences() error = %v", err) + } + if len(stops) != 1 || stops[0] != "END" { + t.Fatalf("stops = %#v, want END", stops) + } + + opts, err := GenerateOptions(req) + if err != nil { + t.Fatalf("GenerateOptions() error = %v", err) + } + cfg := inference.ApplyGenerateOpts(opts) + if cfg.Temperature != DefaultTemperature || cfg.TopP != DefaultTopP || cfg.TopK != DefaultTopK || cfg.MaxTokens != DefaultMaxTokens { + t.Fatalf("defaults = %+v", cfg) + } +} + +func TestOpenAI_GenerateOptions_Good_HonoursExplicitZero(t *testing.T) { + zeroFloat := float32(0) + zeroInt := 0 + req := ChatCompletionRequest{ + Model: "qwen", + Messages: []ChatMessage{{Role: "user", Content: "hi"}}, + Temperature: &zeroFloat, + TopP: &zeroFloat, + TopK: &zeroInt, + MaxTokens: &zeroInt, + } + + opts, err := GenerateOptions(req) + if err != nil { + t.Fatalf("GenerateOptions() error = %v", err) + } + cfg := inference.ApplyGenerateOpts(opts) + if cfg.Temperature != 0 || cfg.TopP != 0 || cfg.TopK != 0 || cfg.MaxTokens != 0 { + t.Fatalf("explicit zero options = %+v", cfg) + } +} + +func TestOpenAI_ThinkingExtractor_Good_CapturesQwenAndChannelMarkers(t *testing.T) { + extractor := NewThinkingExtractor() + + visible, thought := extractor.Process(inference.Token{Text: "A hidden B <|channel>thought plan"}) + visible3, thought3 := extractor.Process(inference.Token{Text: "<|channel>assistant C"}) + visible4, thought4 := extractor.Flush() + + gotVisible := visible + visible2 + visible3 + visible4 + gotThought := thought + thought2 + thought3 + thought4 + if gotVisible != "A B C" { + t.Fatalf("visible = %q", gotVisible) + } + if gotThought != "hidden plan" { + t.Fatalf("thought = %q", gotThought) + } + if extractor.Content() != gotVisible || extractor.Thinking() != gotThought { + t.Fatalf("extractor content/thought = %q/%q", extractor.Content(), extractor.Thinking()) + } +} + +func TestOpenAI_ThinkingExtractor_Ugly_IncompleteChannelMarkerDoesNotHang(t *testing.T) { + extractor := NewThinkingExtractor() + done := make(chan struct{}) + go func() { + extractor.Process(inference.Token{Text: "<|channel>"}) + close(done) + }() + + select { + case <-done: + case <-time.After(100 * time.Millisecond): + t.Fatal("Process() hung on incomplete channel marker") + } + visible, thought := extractor.Flush() + if visible != "<|channel>" || thought != "" { + t.Fatalf("Flush() = %q/%q", visible, thought) + } +} + +func TestOpenAI_StaticResolver_Good_CaseInsensitiveModelLookup(t *testing.T) { + model := &stubModel{} + resolver := NewStaticResolver(map[string]inference.TextModel{"Qwen3": model}) + + got, err := resolver.ResolveModel(context.Background(), "qwen3") + if err != nil { + t.Fatalf("ResolveModel() error = %v", err) + } + if got != model { + t.Fatalf("ResolveModel() = %p, want %p", got, model) + } +} + +func TestOpenAI_Handler_Good_NonStreamingResponseIncludesThoughtAndUsage(t *testing.T) { + model := &stubModel{ + tokens: []inference.Token{ + {Text: "planAnswer END ignored"}, + }, + metrics: inference.GenerateMetrics{PromptTokens: 3, GeneratedTokens: 4}, + } + handler := NewHandler(NewStaticResolver(map[string]inference.TextModel{"qwen": model})) + body := strings.NewReader(`{"model":"qwen","messages":[{"role":"user","content":"hi"}],"stop":["END"]}`) + req := httptest.NewRequest(http.MethodPost, DefaultChatCompletionsPath, body) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"content":"Answer "`) { + t.Fatalf("response missing visible content: %s", rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"thought":"plan"`) { + t.Fatalf("response missing thought: %s", rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"total_tokens":7`) { + t.Fatalf("response missing usage: %s", rec.Body.String()) + } +} + +func TestOpenAI_Handler_Good_StreamingResponseEmitsSSEChunks(t *testing.T) { + model := &stubModel{tokens: []inference.Token{{Text: "Hel"}, {Text: "lo"}}} + handler := NewHandler(NewStaticResolver(map[string]inference.TextModel{"qwen": model})) + body := strings.NewReader(`{"model":"qwen","messages":[{"role":"user","content":"hi"}],"stream":true}`) + req := httptest.NewRequest(http.MethodPost, DefaultChatCompletionsPath, body) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String()) + } + if got := rec.Header().Get("Content-Type"); !strings.Contains(got, "text/event-stream") { + t.Fatalf("content-type = %q", got) + } + bodyText := rec.Body.String() + if !strings.Contains(bodyText, `"role":"assistant","content":""`) { + t.Fatalf("stream missing priming chunk: %s", bodyText) + } + if !strings.Contains(bodyText, `"content":"Hel"`) || !strings.Contains(bodyText, `"content":"lo"`) { + t.Fatalf("stream missing content deltas: %s", bodyText) + } + if !strings.Contains(bodyText, "data: [DONE]") { + t.Fatalf("stream missing DONE: %s", bodyText) + } +} diff --git a/go/openai/responses.go b/go/openai/responses.go new file mode 100644 index 0000000..eb434b7 --- /dev/null +++ b/go/openai/responses.go @@ -0,0 +1,131 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "time" + + "dappco.re/go/inference" +) + +// DefaultResponsesPath is the OpenAI-compatible Responses endpoint. +const DefaultResponsesPath = "/v1/responses" + +// ResponseInputMessage is the message form accepted by the Responses adapter. +type ResponseInputMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// ResponseRequest is the minimal OpenAI-compatible Responses request shape +// shared by local runtimes and provider clients. +type ResponseRequest struct { + Model string `json:"model"` + Input []ResponseInputMessage `json:"input,omitempty"` + Instructions string `json:"instructions,omitempty"` + Temperature *float32 `json:"temperature,omitempty"` + TopP *float32 `json:"top_p,omitempty"` + TopK *int `json:"top_k,omitempty"` + MaxOutputTokens *int `json:"max_output_tokens,omitempty"` + Stream bool `json:"stream,omitempty"` + Stop StopList `json:"stop,omitempty"` + User string `json:"user,omitempty"` +} + +// ResponseOutputText is one visible text item in a Responses output message. +type ResponseOutputText struct { + Type string `json:"type"` + Text string `json:"text"` +} + +// ResponseOutputMessage is the assistant message emitted by a response. +type ResponseOutputMessage struct { + ID string `json:"id,omitempty"` + Type string `json:"type"` + Role string `json:"role"` + Content []ResponseOutputText `json:"content"` +} + +// ResponseUsage records token accounting for a Responses call. +type ResponseUsage struct { + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` +} + +// Response is the non-streaming OpenAI-compatible Responses body. +type Response struct { + ID string `json:"id"` + Object string `json:"object"` + Created int64 `json:"created"` + Model string `json:"model"` + Output []ResponseOutputMessage `json:"output"` + Usage ResponseUsage `json:"usage"` + Thought *string `json:"thought,omitempty"` +} + +// ResponseStreamEvent is a compact SSE event payload for Responses streaming. +type ResponseStreamEvent struct { + Type string `json:"type"` + Response *Response `json:"response,omitempty"` + Delta string `json:"delta,omitempty"` + Thought *string `json:"thought,omitempty"` +} + +// ResponseMessages converts a Responses request into inference messages. +func ResponseMessages(req ResponseRequest) []inference.Message { + out := make([]inference.Message, 0, len(req.Input)+1) + if req.Instructions != "" { + out = append(out, inference.Message{Role: "system", Content: req.Instructions}) + } + for _, msg := range req.Input { + out = append(out, inference.Message{Role: msg.Role, Content: msg.Content}) + } + return out +} + +// ResponseGenerateOptions converts Responses sampling fields into inference +// options. +func ResponseGenerateOptions(req ResponseRequest) ([]inference.GenerateOption, error) { + chatReq := ChatCompletionRequest{ + Model: req.Model, + Temperature: req.Temperature, + TopP: req.TopP, + TopK: req.TopK, + MaxTokens: req.MaxOutputTokens, + // Pre-size — saves the append-grow cascade on every Responses + // API call. Twenty-turn requests previously paid ~4 grow allocs + // before reaching their final size. + Messages: make([]ChatMessage, 0, len(req.Input)), + } + for _, msg := range req.Input { + chatReq.Messages = append(chatReq.Messages, ChatMessage{Role: msg.Role, Content: msg.Content}) + } + if len(chatReq.Messages) == 0 && req.Instructions != "" { + chatReq.Messages = []ChatMessage{{Role: "system", Content: req.Instructions}} + } + return GenerateOptions(chatReq) +} + +// NewTextResponse builds a Responses body from visible text and metrics. +func NewTextResponse(id, model, text string, metrics inference.GenerateMetrics) Response { + return Response{ + ID: id, + Object: "response", + Created: time.Now().Unix(), + Model: model, + Output: []ResponseOutputMessage{{ + Type: "message", + Role: "assistant", + Content: []ResponseOutputText{{ + Type: "output_text", + Text: text, + }}, + }}, + Usage: ResponseUsage{ + InputTokens: metrics.PromptTokens, + OutputTokens: metrics.GeneratedTokens, + TotalTokens: metrics.PromptTokens + metrics.GeneratedTokens, + }, + } +} diff --git a/go/openai/responses_bench_test.go b/go/openai/responses_bench_test.go new file mode 100644 index 0000000..49e4d48 --- /dev/null +++ b/go/openai/responses_bench_test.go @@ -0,0 +1,374 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the OpenAI-compatible Responses wire primitives. +// Per AX-11 — the Responses endpoint is the OpenAI v1/responses path +// served by both the local runtime and proxy clients. These fixtures +// exercise the JSON ingress/egress, the wire→inference message +// projection, and the per-event stream marshal that fires per token in +// the response stream. +// +// Run: go test -bench='BenchmarkResponses' -benchtime=100ms -benchmem -run='^$' . + +package openai + +import ( + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + responsesSinkRequest ResponseRequest + responsesSinkResponse Response + responsesSinkEvent ResponseStreamEvent + responsesSinkMessages []inference.Message + responsesSinkOptions []inference.GenerateOption + responsesSinkErr error + responsesSinkString string + responsesSinkBytes []byte + responsesSinkResult core.Result +) + +// --- Fixture builders --- + +// buildResponseRequest produces a representative Responses payload with +// the requested turn count. Mirrors what the v1/responses handler +// decodes at request entry. +func buildResponseRequest(turns int) ResponseRequest { + temperature := float32(0.7) + topP := float32(0.95) + topK := 64 + maxOutputTokens := 256 + req := ResponseRequest{ + Model: "qwen3", + Instructions: "You are a helpful assistant. Be concise.", + Temperature: &temperature, + TopP: &topP, + TopK: &topK, + MaxOutputTokens: &maxOutputTokens, + Stream: true, + Stop: StopList{"<|im_end|>"}, + } + for i := 0; i < turns; i++ { + if i%2 == 0 { + req.Input = append(req.Input, ResponseInputMessage{Role: "user", Content: "Summarise the paragraph in one sentence."}) + } else { + req.Input = append(req.Input, ResponseInputMessage{Role: "assistant", Content: "The summary captures the key claim."}) + } + } + return req +} + +// buildResponse mirrors a completed Responses body. +func buildResponse() Response { + return NewTextResponse( + "resp_bench", + "qwen3", + "The summary is concise and faithful to the original text.", + inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32}, + ) +} + +// --- JSON Marshal --- + +func BenchmarkResponses_MarshalRequest_SingleTurn(b *testing.B) { + req := buildResponseRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkResponses_MarshalRequest_FiveTurn(b *testing.B) { + req := buildResponseRequest(5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkResponses_MarshalRequest_TwentyTurn(b *testing.B) { + req := buildResponseRequest(20) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkString = core.JSONMarshalString(req) + } +} + +func BenchmarkResponses_MarshalResponse_Typical(b *testing.B) { + resp := buildResponse() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkString = core.JSONMarshalString(resp) + } +} + +// --- JSON Unmarshal --- + +func BenchmarkResponses_UnmarshalRequest_SingleTurn(b *testing.B) { + body := core.JSONMarshalString(buildResponseRequest(1)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req ResponseRequest + responsesSinkResult = core.JSONUnmarshalString(body, &req) + responsesSinkRequest = req + } +} + +func BenchmarkResponses_UnmarshalRequest_FiveTurn(b *testing.B) { + body := core.JSONMarshalString(buildResponseRequest(5)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req ResponseRequest + responsesSinkResult = core.JSONUnmarshalString(body, &req) + responsesSinkRequest = req + } +} + +func BenchmarkResponses_UnmarshalRequest_TwentyTurn(b *testing.B) { + body := core.JSONMarshalString(buildResponseRequest(20)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req ResponseRequest + responsesSinkResult = core.JSONUnmarshalString(body, &req) + responsesSinkRequest = req + } +} + +func BenchmarkResponses_UnmarshalResponse_Typical(b *testing.B) { + body := core.JSONMarshalString(buildResponse()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var resp Response + responsesSinkResult = core.JSONUnmarshalString(body, &resp) + responsesSinkResponse = resp + } +} + +// --- ResponseMessages — wire→internal conversion per request --- + +func BenchmarkResponses_ResponseMessages_SingleTurn(b *testing.B) { + req := buildResponseRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkMessages = ResponseMessages(req) + } +} + +func BenchmarkResponses_ResponseMessages_FiveTurn(b *testing.B) { + req := buildResponseRequest(5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkMessages = ResponseMessages(req) + } +} + +func BenchmarkResponses_ResponseMessages_TwentyTurn(b *testing.B) { + req := buildResponseRequest(20) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkMessages = ResponseMessages(req) + } +} + +func BenchmarkResponses_ResponseMessages_InstructionsOnly(b *testing.B) { + req := ResponseRequest{Model: "qwen3", Instructions: "Be concise."} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkMessages = ResponseMessages(req) + } +} + +// --- ResponseGenerateOptions — request-time sampling projection --- + +func BenchmarkResponses_GenerateOptions_AllFieldsSet(b *testing.B) { + req := buildResponseRequest(1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkOptions, responsesSinkErr = ResponseGenerateOptions(req) + } +} + +// Instructions-only path — exercises the empty-input fallback branch +// that synthesises a ChatMessage from req.Instructions. +func BenchmarkResponses_GenerateOptions_InstructionsOnly(b *testing.B) { + req := ResponseRequest{Model: "qwen3", Instructions: "Be concise."} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkOptions, responsesSinkErr = ResponseGenerateOptions(req) + } +} + +// --- NewTextResponse — fired once per non-streaming completion --- + +func BenchmarkResponses_NewTextResponse(b *testing.B) { + metrics := inference.GenerateMetrics{PromptTokens: 200, GeneratedTokens: 32} + text := "The summary is concise and faithful to the original text." + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkResponse = NewTextResponse("resp_bench", "qwen3", text, metrics) + } +} + +// --- ResponseStreamEvent marshal — fired per streamed delta + final --- + +func BenchmarkResponses_MarshalStreamEvent_Delta_ShortToken(b *testing.B) { + event := ResponseStreamEvent{ + Type: "response.output_text.delta", + Delta: "Answer", + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkString = core.JSONMarshalString(event) + } +} + +func BenchmarkResponses_MarshalStreamEvent_Delta_LongToken(b *testing.B) { + delta := "" + for i := 0; i < 64; i++ { + delta += "fragment " + } + event := ResponseStreamEvent{ + Type: "response.output_text.delta", + Delta: delta, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkString = core.JSONMarshalString(event) + } +} + +func BenchmarkResponses_MarshalStreamEvent_Completed(b *testing.B) { + resp := buildResponse() + event := ResponseStreamEvent{Type: "response.completed", Response: &resp} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkString = core.JSONMarshalString(event) + } +} + +func BenchmarkResponses_MarshalStreamEvent_ThoughtDelta(b *testing.B) { + thought := "Let me think through this step by step." + event := ResponseStreamEvent{ + Type: "response.thought.delta", + Delta: "thinking", + Thought: &thought, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkString = core.JSONMarshalString(event) + } +} + +// --- Hand-rolled encoders — wired into writeJSON fast-path + --- +// available as direct call sites for downstream Responses producers. + +func BenchmarkResponses_AppendResponse_Typical(b *testing.B) { + resp := buildResponse() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkBytes = appendResponse(make([]byte, 0, responseSize(resp)), resp) + } +} + +func BenchmarkResponses_AppendStreamEvent_Delta_ShortToken(b *testing.B) { + event := ResponseStreamEvent{ + Type: "response.output_text.delta", + Delta: "Answer", + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkBytes = appendResponseStreamEvent(make([]byte, 0, responseStreamEventSize(event)), event) + } +} + +func BenchmarkResponses_AppendStreamEvent_Delta_LongToken(b *testing.B) { + delta := "" + for i := 0; i < 64; i++ { + delta += "fragment " + } + event := ResponseStreamEvent{ + Type: "response.output_text.delta", + Delta: delta, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkBytes = appendResponseStreamEvent(make([]byte, 0, responseStreamEventSize(event)), event) + } +} + +func BenchmarkResponses_AppendStreamEvent_Completed(b *testing.B) { + resp := buildResponse() + event := ResponseStreamEvent{Type: "response.completed", Response: &resp} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkBytes = appendResponseStreamEvent(make([]byte, 0, responseStreamEventSize(event)), event) + } +} + +func BenchmarkResponses_AppendStreamEvent_ThoughtDelta(b *testing.B) { + thought := "Let me think through this step by step." + event := ResponseStreamEvent{ + Type: "response.thought.delta", + Delta: "thinking", + Thought: &thought, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + responsesSinkBytes = appendResponseStreamEvent(make([]byte, 0, responseStreamEventSize(event)), event) + } +} + +// --- Stream-event unmarshal — proxy clients pay this on every SSE frame --- + +func BenchmarkResponses_UnmarshalStreamEvent_Delta(b *testing.B) { + body := core.JSONMarshalString(ResponseStreamEvent{ + Type: "response.output_text.delta", + Delta: "Answer", + }) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var event ResponseStreamEvent + responsesSinkResult = core.JSONUnmarshalString(body, &event) + responsesSinkEvent = event + } +} + +func BenchmarkResponses_UnmarshalStreamEvent_Completed(b *testing.B) { + resp := buildResponse() + body := core.JSONMarshalString(ResponseStreamEvent{Type: "response.completed", Response: &resp}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var event ResponseStreamEvent + responsesSinkResult = core.JSONUnmarshalString(body, &event) + responsesSinkEvent = event + } +} diff --git a/go/openai/responses_enc.go b/go/openai/responses_enc.go new file mode 100644 index 0000000..62d2fb5 --- /dev/null +++ b/go/openai/responses_enc.go @@ -0,0 +1,156 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled encoders for the OpenAI Responses API wire shapes — +// Response and ResponseStreamEvent. Same W9-D shape as the chat- +// completions encoders: single-buffer emission, no reflect, the +// shared jsonenc.AppendStringField / jsonenc.AppendIntField +// primitives from dappco.re/go/inference/jsonenc (W9-Z lift). +// +// Responses is the OpenAI v1/responses endpoint — the per-token +// stream event encoder fires per generated text delta on the +// streaming path; the per-response Response encoder fires once per +// non-streaming completed call (and embeds itself inside the +// terminal "response.completed" stream event). + +package openai + +import "dappco.re/go/inference/jsonenc" + +// appendResponseOutputText walks one ResponseOutputText into buf. +// Two ASCII string fields in canonical order. +func appendResponseOutputText(buf []byte, item ResponseOutputText) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "type", item.Type, false) + buf = jsonenc.AppendStringField(buf, "text", item.Text, true) + return append(buf, '}') +} + +// appendResponseOutputMessage walks one ResponseOutputMessage into +// buf. The ID field carries the omitempty tag — emit only when set. +func appendResponseOutputMessage(buf []byte, msg ResponseOutputMessage) []byte { + buf = append(buf, '{') + leading := false + if msg.ID != "" { + buf = jsonenc.AppendStringField(buf, "id", msg.ID, false) + leading = true + } + buf = jsonenc.AppendStringField(buf, "type", msg.Type, leading) + buf = jsonenc.AppendStringField(buf, "role", msg.Role, true) + buf = append(buf, ',', '"', 'c', 'o', 'n', 't', 'e', 'n', 't', '"', ':', '[') + for i, item := range msg.Content { + if i > 0 { + buf = append(buf, ',') + } + buf = appendResponseOutputText(buf, item) + } + return append(buf, ']', '}') +} + +// appendResponseUsage walks a ResponseUsage into buf. Three int +// fields — input_tokens, output_tokens, total_tokens. +func appendResponseUsage(buf []byte, usage ResponseUsage) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendIntField(buf, "input_tokens", usage.InputTokens, false) + buf = jsonenc.AppendIntField(buf, "output_tokens", usage.OutputTokens, true) + buf = jsonenc.AppendIntField(buf, "total_tokens", usage.TotalTokens, true) + return append(buf, '}') +} + +// appendResponse walks the full Response shape into buf. Field +// order matches the struct declaration so the wire output is byte- +// identical to encoding/json.Marshal output. +func appendResponse(buf []byte, resp Response) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "id", resp.ID, false) + buf = jsonenc.AppendStringField(buf, "object", resp.Object, true) + buf = jsonenc.AppendInt64Field(buf, "created", resp.Created, true) + buf = jsonenc.AppendStringField(buf, "model", resp.Model, true) + buf = append(buf, ',', '"', 'o', 'u', 't', 'p', 'u', 't', '"', ':', '[') + for i, msg := range resp.Output { + if i > 0 { + buf = append(buf, ',') + } + buf = appendResponseOutputMessage(buf, msg) + } + buf = append(buf, ']', ',', '"', 'u', 's', 'a', 'g', 'e', '"', ':') + buf = appendResponseUsage(buf, resp.Usage) + if resp.Thought != nil { + buf = append(buf, ',', '"', 't', 'h', 'o', 'u', 'g', 'h', 't', '"', ':') + buf = jsonenc.AppendJSONString(buf, *resp.Thought) + } + return append(buf, '}') +} + +// responseSize estimates the backing-buffer size for one Response +// so the encoder allocates once. Conservative (slight over-shoot) +// so closing punctuation doesn't trigger a grow into the next size +// class. +func responseSize(resp Response) int { + size := 4 // {} + slack for closing punctuation + size += 7 + len(resp.ID) + size += 11 + len(resp.Object) + size += 12 + 20 + size += 10 + len(resp.Model) + size += 12 // ,"output":[] + for _, msg := range resp.Output { + size += 3 // {} + separator + if msg.ID != "" { + size += 8 + len(msg.ID) + } + size += 9 + len(msg.Type) + size += 9 + len(msg.Role) + size += 13 // ,"content":[] + for _, item := range msg.Content { + size += 3 + 9 + len(item.Type) + 9 + len(item.Text) + } + } + size += 62 // ,"usage":{...} + if resp.Thought != nil { + size += 13 + len(*resp.Thought) + } + return size +} + +// appendResponseStreamEvent walks the ResponseStreamEvent shape +// into buf. The Response pointer + Delta + Thought are all +// omitempty — emit only the fields set on the event. +func appendResponseStreamEvent(buf []byte, event ResponseStreamEvent) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "type", event.Type, false) + if event.Response != nil { + buf = append(buf, ',', '"', 'r', 'e', 's', 'p', 'o', 'n', 's', 'e', '"', ':') + buf = appendResponse(buf, *event.Response) + } + if event.Delta != "" { + buf = jsonenc.AppendStringField(buf, "delta", event.Delta, true) + } + if event.Thought != nil { + buf = append(buf, ',', '"', 't', 'h', 'o', 'u', 'g', 'h', 't', '"', ':') + buf = jsonenc.AppendJSONString(buf, *event.Thought) + } + return append(buf, '}') +} + +// responseStreamEventSize estimates the backing-buffer size for one +// stream event so the encoder allocates once. The Response pointer +// embedding is the load-bearing case (response.completed events) — +// uses responseSize recursively. +// +// The estimate is intentionally conservative (covers the closing +// '}' and any trailing punctuation) so the typical event lands in a +// single allocator size class. Pathological escape-heavy values let +// append grow once. +func responseStreamEventSize(event ResponseStreamEvent) int { + size := 4 // {"type":"..."} framing + closing brace + slack + size += 8 + len(event.Type) + if event.Response != nil { + size += 12 + responseSize(*event.Response) + } + if event.Delta != "" { + size += 11 + len(event.Delta) + } + if event.Thought != nil { + size += 13 + len(*event.Thought) + } + return size +} diff --git a/go/openai/responses_enc_test.go b/go/openai/responses_enc_test.go new file mode 100644 index 0000000..f7cb0ae --- /dev/null +++ b/go/openai/responses_enc_test.go @@ -0,0 +1,127 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "encoding/json" + "testing" + + "dappco.re/go/inference" +) + +// TestResponse_AppendRoundTrip locks the hand-rolled Responses-API +// non-streaming encoder to encoding/json's deserialiser. +func TestResponse_AppendRoundTrip(t *testing.T) { + thought := "let me think" + cases := []struct { + name string + in Response + }{ + {"minimal", NewTextResponse("resp_x", "qwen3", "Hi", inference.GenerateMetrics{PromptTokens: 3, GeneratedTokens: 4})}, + {"with-thought", func() Response { + r := NewTextResponse("resp_y", "qwen3", "Answer", inference.GenerateMetrics{PromptTokens: 10, GeneratedTokens: 20}) + r.Thought = &thought + return r + }()}, + {"with-id-on-msg", Response{ + ID: "resp_z", Object: "response", Created: 1700000000, Model: "qwen3", + Output: []ResponseOutputMessage{{ + ID: "msg_1", Type: "message", Role: "assistant", + Content: []ResponseOutputText{{Type: "output_text", Text: "text"}}, + }}, + Usage: ResponseUsage{InputTokens: 1, OutputTokens: 2, TotalTokens: 3}, + }}, + {"escapes", Response{ + ID: "resp_e", Object: "response", Created: 1700000000, Model: "qwen3", + Output: []ResponseOutputMessage{{ + Type: "message", Role: "assistant", + Content: []ResponseOutputText{{Type: "output_text", Text: "quote \" tab\t"}}, + }}, + Usage: ResponseUsage{InputTokens: 1, OutputTokens: 1, TotalTokens: 2}, + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + encoded := appendResponse(nil, tc.in) + var back Response + if err := json.Unmarshal(encoded, &back); err != nil { + t.Fatalf("json.Unmarshal(%s) error = %v", encoded, err) + } + if back.ID != tc.in.ID || back.Object != tc.in.Object || back.Created != tc.in.Created || back.Model != tc.in.Model { + t.Fatalf("identity: got %+v, want %+v", back, tc.in) + } + if back.Usage != tc.in.Usage { + t.Fatalf("usage: got %+v, want %+v", back.Usage, tc.in.Usage) + } + if len(back.Output) != len(tc.in.Output) { + t.Fatalf("output len = %d, want %d", len(back.Output), len(tc.in.Output)) + } + for i := range tc.in.Output { + if back.Output[i].ID != tc.in.Output[i].ID || + back.Output[i].Type != tc.in.Output[i].Type || + back.Output[i].Role != tc.in.Output[i].Role { + t.Fatalf("output[%d] header: got %+v want %+v", i, back.Output[i], tc.in.Output[i]) + } + if len(back.Output[i].Content) != len(tc.in.Output[i].Content) { + t.Fatalf("output[%d].content len = %d, want %d", i, len(back.Output[i].Content), len(tc.in.Output[i].Content)) + } + for j := range tc.in.Output[i].Content { + if back.Output[i].Content[j] != tc.in.Output[i].Content[j] { + t.Fatalf("output[%d].content[%d] = %+v, want %+v", i, j, back.Output[i].Content[j], tc.in.Output[i].Content[j]) + } + } + } + if (back.Thought == nil) != (tc.in.Thought == nil) { + t.Fatalf("thought nil mismatch: got=%v want=%v", back.Thought, tc.in.Thought) + } + if back.Thought != nil && *back.Thought != *tc.in.Thought { + t.Fatalf("thought = %q, want %q", *back.Thought, *tc.in.Thought) + } + }) + } +} + +// TestResponseStreamEvent_AppendRoundTrip locks the hand-rolled +// stream-event encoder. Fires per delta on the streaming path; the +// "response.completed" event embeds a full Response payload. +func TestResponseStreamEvent_AppendRoundTrip(t *testing.T) { + thought := "let me think" + resp := NewTextResponse("resp_x", "qwen3", "Hi", inference.GenerateMetrics{PromptTokens: 3, GeneratedTokens: 4}) + cases := []struct { + name string + in ResponseStreamEvent + }{ + {"delta-only", ResponseStreamEvent{Type: "response.output_text.delta", Delta: "Answer"}}, + {"thought-delta", ResponseStreamEvent{Type: "response.thought.delta", Delta: "thinking", Thought: &thought}}, + {"completed", ResponseStreamEvent{Type: "response.completed", Response: &resp}}, + {"type-only", ResponseStreamEvent{Type: "response.started"}}, + {"delta-with-escapes", ResponseStreamEvent{Type: "response.output_text.delta", Delta: "quote \" tab\t"}}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + encoded := appendResponseStreamEvent(nil, tc.in) + var back ResponseStreamEvent + if err := json.Unmarshal(encoded, &back); err != nil { + t.Fatalf("json.Unmarshal(%s) error = %v", encoded, err) + } + if back.Type != tc.in.Type { + t.Fatalf("type: got %q, want %q", back.Type, tc.in.Type) + } + if back.Delta != tc.in.Delta { + t.Fatalf("delta: got %q, want %q", back.Delta, tc.in.Delta) + } + if (back.Response == nil) != (tc.in.Response == nil) { + t.Fatalf("response nil mismatch") + } + if back.Response != nil && back.Response.ID != tc.in.Response.ID { + t.Fatalf("response.id: got %q, want %q", back.Response.ID, tc.in.Response.ID) + } + if (back.Thought == nil) != (tc.in.Thought == nil) { + t.Fatalf("thought nil mismatch") + } + if back.Thought != nil && *back.Thought != *tc.in.Thought { + t.Fatalf("thought: got %q, want %q", *back.Thought, *tc.in.Thought) + } + }) + } +} diff --git a/go/openai/responses_test.go b/go/openai/responses_test.go new file mode 100644 index 0000000..238e929 --- /dev/null +++ b/go/openai/responses_test.go @@ -0,0 +1,61 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "testing" + + "dappco.re/go/inference" +) + +func TestResponses_ResponseMessages_Good(t *testing.T) { + req := ResponseRequest{ + Instructions: "Be concise.", + Input: []ResponseInputMessage{ + {Role: "user", Content: "hello"}, + }, + } + + messages := ResponseMessages(req) + + if len(messages) != 2 { + t.Fatalf("len(messages) = %d, want 2", len(messages)) + } + if messages[0].Role != "system" || messages[1].Content != "hello" { + t.Fatalf("messages = %+v", messages) + } +} + +func TestResponses_ResponseGenerateOptions_Good(t *testing.T) { + maxTokens := 12 + temperature := float32(0) + req := ResponseRequest{ + Model: "qwen", + Input: []ResponseInputMessage{{Role: "user", Content: "hi"}}, + MaxOutputTokens: &maxTokens, + Temperature: &temperature, + } + + opts, err := ResponseGenerateOptions(req) + if err != nil { + t.Fatalf("ResponseGenerateOptions() error = %v", err) + } + cfg := inference.ApplyGenerateOpts(opts) + if cfg.MaxTokens != 12 || cfg.Temperature != 0 { + t.Fatalf("cfg = %+v", cfg) + } +} + +func TestResponses_NewTextResponse_Good(t *testing.T) { + resp := NewTextResponse("resp_1", "qwen", "ok", inference.GenerateMetrics{PromptTokens: 3, GeneratedTokens: 2}) + + if resp.ID != "resp_1" || resp.Object != "response" || resp.Model != "qwen" { + t.Fatalf("response identity = %+v", resp) + } + if resp.Usage.TotalTokens != 5 { + t.Fatalf("usage = %+v", resp.Usage) + } + if resp.Output[0].Content[0].Text != "ok" { + t.Fatalf("output = %+v", resp.Output) + } +} diff --git a/go/openai/services.go b/go/openai/services.go new file mode 100644 index 0000000..58aba21 --- /dev/null +++ b/go/openai/services.go @@ -0,0 +1,400 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "context" + "io" + "net/http" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +const ( + DefaultEmbeddingsPath = "/v1/embeddings" + DefaultRerankPath = "/v1/rerank" + DefaultCapabilitiesPath = "/v1/models/capabilities" + DefaultCacheStatsPath = "/v1/cache/stats" + DefaultCacheWarmPath = "/v1/cache/warm" + DefaultCacheClearPath = "/v1/cache/clear" + DefaultCancelPath = "/v1/cancel" +) + +// EmbeddingRequest is the OpenAI-compatible embedding request body. +type EmbeddingRequest struct { + Model string `json:"model"` + Input EmbeddingInput `json:"input"` + EncodingFormat string `json:"encoding_format,omitempty"` + Dimensions *int `json:"dimensions,omitempty"` + User string `json:"user,omitempty"` + Normalize bool `json:"normalize,omitempty"` +} + +// EmbeddingInput accepts either a string or an array of strings. +type EmbeddingInput []string + +func (input *EmbeddingInput) UnmarshalJSON(data []byte) error { + // Hot path — fires per embeddings request. parseJSONStringList + // walks the variant string-or-array shape in a single pass — + // drops the recursive core.JSONUnmarshal allocs (encoder state + // + per-element string). + values, err := parseJSONStringList(data) + if err != nil { + return err + } + *input = values + return nil +} + +type EmbeddingResponse struct { + Object string `json:"object"` + Data []EmbeddingResponseDatum `json:"data"` + Model string `json:"model"` + Usage inference.EmbeddingUsage `json:"usage"` +} + +type EmbeddingResponseDatum struct { + Object string `json:"object"` + Index int `json:"index"` + Embedding []float32 `json:"embedding"` +} + +type RerankRequest struct { + Model string `json:"model"` + Query string `json:"query"` + Documents []string `json:"documents"` + TopN int `json:"top_n,omitempty"` +} + +type RerankResponse struct { + Object string `json:"object"` + Model string `json:"model"` + Results []inference.RerankScore `json:"results"` +} + +type CacheWarmRequest struct { + Model string `json:"model"` + Prompt string `json:"prompt,omitempty"` + Tokens []int32 `json:"tokens,omitempty"` + Mode string `json:"mode,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +type CacheClearRequest struct { + Model string `json:"model"` + Labels map[string]string `json:"labels,omitempty"` +} + +type CancelRequest struct { + Model string `json:"model"` + ID string `json:"id"` +} + +type serviceHandler struct { + resolver Resolver +} + +type EmbeddingsHandler struct{ serviceHandler } +type RerankHandler struct{ serviceHandler } +type CapabilityHandler struct{ serviceHandler } +type CacheStatsHandler struct{ serviceHandler } +type CacheWarmHandler struct{ serviceHandler } +type CacheClearHandler struct{ serviceHandler } +type CancelHandler struct{ serviceHandler } + +func NewEmbeddingsHandler(resolver Resolver) *EmbeddingsHandler { + return &EmbeddingsHandler{serviceHandler{resolver: resolver}} +} + +func NewRerankHandler(resolver Resolver) *RerankHandler { + return &RerankHandler{serviceHandler{resolver: resolver}} +} + +func NewCapabilityHandler(resolver Resolver) *CapabilityHandler { + return &CapabilityHandler{serviceHandler{resolver: resolver}} +} + +func NewCacheStatsHandler(resolver Resolver) *CacheStatsHandler { + return &CacheStatsHandler{serviceHandler{resolver: resolver}} +} + +func NewCacheWarmHandler(resolver Resolver) *CacheWarmHandler { + return &CacheWarmHandler{serviceHandler{resolver: resolver}} +} + +func NewCacheClearHandler(resolver Resolver) *CacheClearHandler { + return &CacheClearHandler{serviceHandler{resolver: resolver}} +} + +func NewCancelHandler(resolver Resolver) *CancelHandler { + return &CancelHandler{serviceHandler{resolver: resolver}} +} + +func (h *EmbeddingsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !requireServiceMethod(w, r, http.MethodPost) { + return + } + var req EmbeddingRequest + if !decodeServiceRequest(w, r, &req, "openai.EmbeddingsHandler") { + return + } + if core.Trim(req.Model) == "" { + writeError(w, http.StatusBadRequest, "model is required", "model") + return + } + if len(req.Input) == 0 { + writeError(w, http.StatusBadRequest, "input must not be empty", "input") + return + } + model, ok := h.resolve(w, r.Context(), req.Model) + if !ok { + return + } + embeddingModel, ok := model.(inference.EmbeddingModel) + if !ok { + writeError(w, http.StatusNotImplemented, "model does not support embeddings", "model") + return + } + result, err := embeddingModel.Embed(r.Context(), inference.EmbeddingRequest{ + Model: req.Model, + Input: []string(req.Input), + Normalize: req.Normalize, + }) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error(), "model") + return + } + data := make([]EmbeddingResponseDatum, 0, len(result.Vectors)) + for i, vector := range result.Vectors { + data = append(data, EmbeddingResponseDatum{Object: "embedding", Index: i, Embedding: vector}) + } + writeJSON(w, http.StatusOK, EmbeddingResponse{Object: "list", Data: data, Model: req.Model, Usage: result.Usage}) +} + +func (h *RerankHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !requireServiceMethod(w, r, http.MethodPost) { + return + } + var req RerankRequest + if !decodeServiceRequest(w, r, &req, "openai.RerankHandler") { + return + } + if core.Trim(req.Model) == "" { + writeError(w, http.StatusBadRequest, "model is required", "model") + return + } + if core.Trim(req.Query) == "" { + writeError(w, http.StatusBadRequest, "query is required", "query") + return + } + if len(req.Documents) == 0 { + writeError(w, http.StatusBadRequest, "documents must not be empty", "documents") + return + } + model, ok := h.resolve(w, r.Context(), req.Model) + if !ok { + return + } + rerankModel, ok := model.(inference.RerankModel) + if !ok { + writeError(w, http.StatusNotImplemented, "model does not support rerank", "model") + return + } + result, err := rerankModel.Rerank(r.Context(), inference.RerankRequest{ + Model: req.Model, + Query: req.Query, + Documents: req.Documents, + TopN: req.TopN, + }) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error(), "model") + return + } + writeJSON(w, http.StatusOK, RerankResponse{Object: "list", Model: req.Model, Results: result.Results}) +} + +func (h *CapabilityHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !requireServiceMethod(w, r, http.MethodGet) { + return + } + modelName := queryModel(r) + if modelName == "" { + writeError(w, http.StatusBadRequest, "model is required", "model") + return + } + model, ok := h.resolve(w, r.Context(), modelName) + if !ok { + return + } + if reporter, ok := model.(inference.CapabilityReporter); ok { + writeJSON(w, http.StatusOK, reporter.Capabilities()) + return + } + writeJSON(w, http.StatusOK, inference.TextModelCapabilities(inference.RuntimeIdentity{}, model)) +} + +func (h *CacheStatsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !requireServiceMethod(w, r, http.MethodGet) { + return + } + model, ok := h.resolveCacheService(w, r.Context(), queryModel(r)) + if !ok { + return + } + stats, err := model.CacheStats(r.Context()) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error(), "cache") + return + } + writeJSON(w, http.StatusOK, stats) +} + +func (h *CacheWarmHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !requireServiceMethod(w, r, http.MethodPost) { + return + } + var req CacheWarmRequest + if !decodeServiceRequest(w, r, &req, "openai.CacheWarmHandler") { + return + } + model, ok := h.resolveCacheService(w, r.Context(), req.Model) + if !ok { + return + } + result, err := model.WarmCache(r.Context(), inference.CacheWarmRequest{ + Model: inference.ModelIdentity{ID: req.Model}, + Prompt: req.Prompt, + Tokens: req.Tokens, + Mode: req.Mode, + Labels: req.Labels, + }) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error(), "cache") + return + } + writeJSON(w, http.StatusOK, result) +} + +func (h *CacheClearHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !requireServiceMethod(w, r, http.MethodPost) { + return + } + var req CacheClearRequest + if !decodeServiceRequest(w, r, &req, "openai.CacheClearHandler") { + return + } + model, ok := h.resolveCacheService(w, r.Context(), req.Model) + if !ok { + return + } + stats, err := model.ClearCache(r.Context(), req.Labels) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error(), "cache") + return + } + writeJSON(w, http.StatusOK, stats) +} + +func (h *CancelHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if !requireServiceMethod(w, r, http.MethodPost) { + return + } + var req CancelRequest + if !decodeServiceRequest(w, r, &req, "openai.CancelHandler") { + return + } + if core.Trim(req.ID) == "" { + writeError(w, http.StatusBadRequest, "id is required", "id") + return + } + model, ok := h.resolve(w, r.Context(), req.Model) + if !ok { + return + } + cancellable, ok := model.(inference.CancellableModel) + if !ok { + writeError(w, http.StatusNotImplemented, "model does not support request cancellation", "model") + return + } + result, err := cancellable.CancelRequest(r.Context(), req.ID) + if err != nil { + writeError(w, http.StatusInternalServerError, err.Error(), "model") + return + } + writeJSON(w, http.StatusOK, result) +} + +func (h *serviceHandler) resolve(w http.ResponseWriter, ctx context.Context, modelName string) (inference.TextModel, bool) { + if h == nil || h.resolver == nil { + writeError(w, http.StatusServiceUnavailable, "handler is not configured", "model") + return nil, false + } + modelName = core.Trim(modelName) + if modelName == "" { + writeError(w, http.StatusBadRequest, "model is required", "model") + return nil, false + } + model, err := h.resolver.ResolveModel(ctx, modelName) + if err != nil { + writeError(w, http.StatusNotFound, err.Error(), "model") + return nil, false + } + return model, true +} + +func (h *serviceHandler) resolveCacheService(w http.ResponseWriter, ctx context.Context, modelName string) (inference.CacheService, bool) { + model, ok := h.resolve(w, ctx, modelName) + if !ok { + return nil, false + } + cache, ok := model.(inference.CacheService) + if !ok { + writeError(w, http.StatusNotImplemented, "model does not support cache service operations", "model") + return nil, false + } + return cache, true +} + +func decodeServiceRequest(w http.ResponseWriter, r *http.Request, into any, scope string) bool { + if r == nil || r.Body == nil { + writeError(w, http.StatusBadRequest, "request body is nil", "body") + return false + } + data, err := io.ReadAll(r.Body) + if err != nil { + writeError(w, http.StatusBadRequest, "read request body failed", "body") + return false + } + result := core.JSONUnmarshal(data, into) + if !result.OK { + err := resultError(result) + message := "invalid request body" + if err != nil && core.Trim(err.Error()) != "" { + message = core.Concat(scope, ": ", err.Error()) + } + writeError(w, http.StatusBadRequest, message, "body") + return false + } + return true +} + +func requireServiceMethod(w http.ResponseWriter, r *http.Request, method string) bool { + if r == nil { + writeError(w, http.StatusBadRequest, "request is nil", "request") + return false + } + if r.Method != method { + w.Header().Set("Allow", method) + writeError(w, http.StatusMethodNotAllowed, "method not allowed", "method") + return false + } + return true +} + +func queryModel(r *http.Request) string { + if r == nil || r.URL == nil { + return "" + } + return core.Trim(r.URL.Query().Get("model")) +} diff --git a/go/openai/services_bench_test.go b/go/openai/services_bench_test.go new file mode 100644 index 0000000..399cbbb --- /dev/null +++ b/go/openai/services_bench_test.go @@ -0,0 +1,343 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the OpenAI-compatible service-endpoint wire shapes: +// embeddings, rerank, cache stats/warm/clear, cancel. Per AX-11 — every +// embedding ingestion serialises an EmbeddingResponse with one +// EmbeddingResponseDatum per vector, and every rerank call serialises +// a RerankResult payload. EmbeddingInput.UnmarshalJSON variant parse is +// hit on every embeddings request. +// +// Run: go test -bench='BenchmarkServices' -benchtime=100ms -benchmem -run='^$' . + +package openai + +import ( + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + servicesSinkEmbedRequest EmbeddingRequest + servicesSinkEmbedResponse EmbeddingResponse + servicesSinkEmbeddingInput EmbeddingInput + servicesSinkRerankRequest RerankRequest + servicesSinkRerankResponse RerankResponse + servicesSinkCacheWarmReq CacheWarmRequest + servicesSinkCacheClearReq CacheClearRequest + servicesSinkCancelReq CancelRequest + servicesSinkCacheStats inference.CacheStats + servicesSinkErr error + servicesSinkString string + servicesSinkBytes []byte + servicesSinkResult core.Result +) + +// --- Fixture builders --- + +// buildEmbeddingVectors generates synthetic vectors of the requested +// dimension and count — matches the production response shape where +// each input string maps to one vector. +func buildEmbeddingVectors(count, dim int) [][]float32 { + out := make([][]float32, count) + for i := range out { + vec := make([]float32, dim) + for j := range vec { + vec[j] = float32(i*dim+j) * 0.001 + } + out[i] = vec + } + return out +} + +func buildEmbeddingResponse(count, dim int) EmbeddingResponse { + vectors := buildEmbeddingVectors(count, dim) + data := make([]EmbeddingResponseDatum, 0, count) + for i, vec := range vectors { + data = append(data, EmbeddingResponseDatum{Object: "embedding", Index: i, Embedding: vec}) + } + return EmbeddingResponse{ + Object: "list", + Data: data, + Model: "qwen3-embed", + Usage: inference.EmbeddingUsage{PromptTokens: count * 16, TotalTokens: count * 16}, + } +} + +// --- EmbeddingInput.UnmarshalJSON — variant parse on every embeddings request --- + +func BenchmarkServices_EmbeddingInput_UnmarshalJSON_SingleString(b *testing.B) { + data := []byte(`"hello world"`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var input EmbeddingInput + servicesSinkErr = input.UnmarshalJSON(data) + servicesSinkEmbeddingInput = input + } +} + +func BenchmarkServices_EmbeddingInput_UnmarshalJSON_SmallArray(b *testing.B) { + data := []byte(`["one","two","three"]`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var input EmbeddingInput + servicesSinkErr = input.UnmarshalJSON(data) + servicesSinkEmbeddingInput = input + } +} + +func BenchmarkServices_EmbeddingInput_UnmarshalJSON_TwentyArray(b *testing.B) { + body := `["alpha","beta","gamma","delta","epsilon","zeta","eta","theta","iota","kappa","lambda","mu","nu","xi","omicron","pi","rho","sigma","tau","upsilon"]` + data := []byte(body) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var input EmbeddingInput + servicesSinkErr = input.UnmarshalJSON(data) + servicesSinkEmbeddingInput = input + } +} + +// --- EmbeddingRequest — full request unmarshal at handler entry --- + +func BenchmarkServices_UnmarshalEmbeddingRequest_SingleInput(b *testing.B) { + body := `{"model":"qwen3-embed","input":"hello world","normalize":true}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req EmbeddingRequest + servicesSinkResult = core.JSONUnmarshalString(body, &req) + servicesSinkEmbedRequest = req + } +} + +func BenchmarkServices_UnmarshalEmbeddingRequest_ArrayInput(b *testing.B) { + body := `{"model":"qwen3-embed","input":["one","two","three","four","five"],"normalize":true,"dimensions":768}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req EmbeddingRequest + servicesSinkResult = core.JSONUnmarshalString(body, &req) + servicesSinkEmbedRequest = req + } +} + +// --- EmbeddingResponse marshal — response emission --- +// Three dim/count shapes — small (1×384), medium (5×768), large (20×1024). + +func BenchmarkServices_MarshalEmbeddingResponse_1x384(b *testing.B) { + resp := buildEmbeddingResponse(1, 384) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkString = core.JSONMarshalString(resp) + } +} + +func BenchmarkServices_MarshalEmbeddingResponse_5x768(b *testing.B) { + resp := buildEmbeddingResponse(5, 768) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkString = core.JSONMarshalString(resp) + } +} + +func BenchmarkServices_MarshalEmbeddingResponse_20x1024(b *testing.B) { + resp := buildEmbeddingResponse(20, 1024) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkString = core.JSONMarshalString(resp) + } +} + +// --- Hand-rolled embedding-response encoder — writeJSON fast path --- +// Compares directly against the encoding/json reflect-walk path +// above. Per-element float32 emission scales with vector dim. + +func BenchmarkServices_AppendEmbeddingResponse_1x384(b *testing.B) { + resp := buildEmbeddingResponse(1, 384) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkBytes = appendEmbeddingResponse(make([]byte, 0, embeddingResponseSize(resp)), resp) + } +} + +func BenchmarkServices_AppendEmbeddingResponse_5x768(b *testing.B) { + resp := buildEmbeddingResponse(5, 768) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkBytes = appendEmbeddingResponse(make([]byte, 0, embeddingResponseSize(resp)), resp) + } +} + +func BenchmarkServices_AppendEmbeddingResponse_20x1024(b *testing.B) { + resp := buildEmbeddingResponse(20, 1024) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkBytes = appendEmbeddingResponse(make([]byte, 0, embeddingResponseSize(resp)), resp) + } +} + +// --- RerankRequest unmarshal --- + +func BenchmarkServices_UnmarshalRerankRequest_FewDocs(b *testing.B) { + body := `{"model":"qwen3-rerank","query":"core primitives","documents":["a","b","c"],"top_n":2}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req RerankRequest + servicesSinkResult = core.JSONUnmarshalString(body, &req) + servicesSinkRerankRequest = req + } +} + +func BenchmarkServices_UnmarshalRerankRequest_TwentyDocs(b *testing.B) { + body := `{"model":"qwen3-rerank","query":"core primitives","documents":["alpha","beta","gamma","delta","epsilon","zeta","eta","theta","iota","kappa","lambda","mu","nu","xi","omicron","pi","rho","sigma","tau","upsilon"],"top_n":5}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req RerankRequest + servicesSinkResult = core.JSONUnmarshalString(body, &req) + servicesSinkRerankRequest = req + } +} + +// --- RerankResponse marshal --- + +func BenchmarkServices_MarshalRerankResponse_FewResults(b *testing.B) { + resp := RerankResponse{ + Object: "list", + Model: "qwen3-rerank", + Results: []inference.RerankScore{ + {Index: 0, Score: 0.91, Text: "alpha"}, + {Index: 1, Score: 0.82, Text: "beta"}, + {Index: 2, Score: 0.74, Text: "gamma"}, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkString = core.JSONMarshalString(resp) + } +} + +func BenchmarkServices_MarshalRerankResponse_TwentyResults(b *testing.B) { + results := make([]inference.RerankScore, 20) + for i := range results { + results[i] = inference.RerankScore{Index: i, Score: 0.95 - float64(i)*0.04, Text: "document text fragment"} + } + resp := RerankResponse{Object: "list", Model: "qwen3-rerank", Results: results} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkString = core.JSONMarshalString(resp) + } +} + +// --- Hand-rolled rerank-response encoder — writeJSON fast path --- + +func BenchmarkServices_AppendRerankResponse_FewResults(b *testing.B) { + resp := RerankResponse{ + Object: "list", + Model: "qwen3-rerank", + Results: []inference.RerankScore{ + {Index: 0, Score: 0.91, Text: "alpha"}, + {Index: 1, Score: 0.82, Text: "beta"}, + {Index: 2, Score: 0.74, Text: "gamma"}, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkBytes = appendRerankResponse(make([]byte, 0, rerankResponseSize(resp)), resp) + } +} + +func BenchmarkServices_AppendRerankResponse_TwentyResults(b *testing.B) { + results := make([]inference.RerankScore, 20) + for i := range results { + results[i] = inference.RerankScore{Index: i, Score: 0.95 - float64(i)*0.04, Text: "document text fragment"} + } + resp := RerankResponse{Object: "list", Model: "qwen3-rerank", Results: results} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkBytes = appendRerankResponse(make([]byte, 0, rerankResponseSize(resp)), resp) + } +} + +// --- CacheWarmRequest — KV cache prep request ingress --- + +func BenchmarkServices_UnmarshalCacheWarmRequest_Prompt(b *testing.B) { + body := `{"model":"qwen3","prompt":"You are a helpful assistant. Summarise this paragraph.","mode":"block-q8","labels":{"adapter":"none"}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req CacheWarmRequest + servicesSinkResult = core.JSONUnmarshalString(body, &req) + servicesSinkCacheWarmReq = req + } +} + +func BenchmarkServices_UnmarshalCacheWarmRequest_Tokens(b *testing.B) { + body := `{"model":"qwen3","tokens":[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30,31,32],"mode":"block-q8"}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req CacheWarmRequest + servicesSinkResult = core.JSONUnmarshalString(body, &req) + servicesSinkCacheWarmReq = req + } +} + +// --- CacheClearRequest --- + +func BenchmarkServices_UnmarshalCacheClearRequest(b *testing.B) { + body := `{"model":"qwen3","labels":{"adapter":"none","scope":"all"}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req CacheClearRequest + servicesSinkResult = core.JSONUnmarshalString(body, &req) + servicesSinkCacheClearReq = req + } +} + +// --- CancelRequest --- + +func BenchmarkServices_UnmarshalCancelRequest(b *testing.B) { + body := `{"model":"qwen3","id":"req_1700000000_42"}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var req CancelRequest + servicesSinkResult = core.JSONUnmarshalString(body, &req) + servicesSinkCancelReq = req + } +} + +// --- CacheStats marshal — what /v1/cache/stats returns per call --- + +func BenchmarkServices_MarshalCacheStats(b *testing.B) { + stats := inference.CacheStats{ + Blocks: 128, + Hits: 9000, + Misses: 1000, + HitRate: 0.9, + CacheMode: "block-q8", + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + servicesSinkString = core.JSONMarshalString(stats) + } +} diff --git a/go/openai/services_enc.go b/go/openai/services_enc.go new file mode 100644 index 0000000..f5383ab --- /dev/null +++ b/go/openai/services_enc.go @@ -0,0 +1,101 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled encoders for the OpenAI service-endpoint wire shapes +// (rerank). Embeddings is encoded in chunkenc.go alongside the +// chat-completion shapes; rerank lives here because it walks the +// inference.RerankScore contract type, owned by the contract layer. + +package openai + +import ( + "dappco.re/go/inference" + "dappco.re/go/inference/jsonenc" +) + +// appendRerankScore walks one inference.RerankScore into buf. The +// contract carries Index / Score / Text / Labels with omitempty on +// every field — emit only the fields that carry a non-zero value. +// Field ordering matches the struct declaration so wire output is +// byte-compatible with encoding/json's reflect walk. +func appendRerankScore(buf []byte, score inference.RerankScore) []byte { + buf = append(buf, '{') + leading := false + if score.Index != 0 { + buf = jsonenc.AppendIntField(buf, "index", score.Index, false) + leading = true + } + if score.Score != 0 { + if leading { + buf = append(buf, ',') + } + buf = append(buf, '"', 's', 'c', 'o', 'r', 'e', '"', ':') + buf = jsonenc.AppendFloat64(buf, score.Score) + leading = true + } + if score.Text != "" { + buf = jsonenc.AppendStringField(buf, "text", score.Text, leading) + leading = true + } + if len(score.Labels) > 0 { + if leading { + buf = append(buf, ',') + } + buf = append(buf, '"', 'l', 'a', 'b', 'e', 'l', 's', '"', ':', '{') + labelFirst := true + for k, v := range score.Labels { + if !labelFirst { + buf = append(buf, ',') + } + labelFirst = false + buf = jsonenc.AppendJSONString(buf, k) + buf = append(buf, ':') + buf = jsonenc.AppendJSONString(buf, v) + } + buf = append(buf, '}') + } + return append(buf, '}') +} + +// appendRerankResponse walks the RerankResponse shape into buf. +// The Results slice scales with documents: walking inference.RerankScore +// inline skips the per-element reflect cost encoding/json pays. +func appendRerankResponse(buf []byte, resp RerankResponse) []byte { + buf = append(buf, '{') + buf = jsonenc.AppendStringField(buf, "object", resp.Object, false) + buf = jsonenc.AppendStringField(buf, "model", resp.Model, true) + buf = append(buf, ',', '"', 'r', 'e', 's', 'u', 'l', 't', 's', '"', ':', '[') + for i, score := range resp.Results { + if i > 0 { + buf = append(buf, ',') + } + buf = appendRerankScore(buf, score) + } + return append(buf, ']', '}') +} + +// rerankResponseSize estimates the backing-buffer size for one +// RerankResponse so the encoder allocates once. +func rerankResponseSize(resp RerankResponse) int { + size := 4 // braces + slack + size += 11 + len(resp.Object) + size += 10 + len(resp.Model) + size += 12 // "results":[] + for _, score := range resp.Results { + // {"index":N,"score":0.xx,"text":"..."} — score typically + // in 0..1, 4-6 ASCII chars; text is the dominant variable. + size += 12 + len(score.Text) + if score.Index != 0 { + size += 9 + 12 // "index":N + } + if score.Score != 0 { + size += 9 + 12 // "score":0.xx + } + if len(score.Labels) > 0 { + size += 12 + for k, v := range score.Labels { + size += 6 + len(k) + len(v) + } + } + } + return size +} diff --git a/go/openai/services_enc_test.go b/go/openai/services_enc_test.go new file mode 100644 index 0000000..092ee16 --- /dev/null +++ b/go/openai/services_enc_test.go @@ -0,0 +1,76 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "encoding/json" + "testing" + + "dappco.re/go/inference" +) + +// TestRerankResponse_AppendRoundTrip locks the hand-rolled rerank +// encoder shape against encoding/json. The rerank wire is a +// single-shape contract (object/model/results) so the test exercises +// every RerankScore branch (with/without text/labels/zero-score). +func TestRerankResponse_AppendRoundTrip(t *testing.T) { + cases := []struct { + name string + in RerankResponse + }{ + {"empty-results", RerankResponse{Object: "list", Model: "qwen3-rerank"}}, + {"basic-results", RerankResponse{ + Object: "list", Model: "qwen3-rerank", + Results: []inference.RerankScore{ + {Index: 0, Score: 0.91, Text: "alpha"}, + {Index: 1, Score: 0.82, Text: "beta"}, + {Index: 2, Score: 0.74, Text: "gamma"}, + }, + }}, + {"with-labels", RerankResponse{ + Object: "list", Model: "qwen3-rerank", + Results: []inference.RerankScore{{ + Index: 0, Score: 0.95, Text: "x", + Labels: map[string]string{"locale": "en"}, + }}, + }}, + {"zero-score", RerankResponse{ + Object: "list", Model: "qwen3-rerank", + Results: []inference.RerankScore{{Index: 0, Text: "match"}}, + }}, + {"escapes", RerankResponse{ + Object: "list", Model: "qwen3-rerank", + Results: []inference.RerankScore{{Index: 0, Score: 0.5, Text: "quote \" tab\t"}}, + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + encoded := appendRerankResponse(nil, tc.in) + var back RerankResponse + if err := json.Unmarshal(encoded, &back); err != nil { + t.Fatalf("json.Unmarshal(%s) error = %v", encoded, err) + } + if back.Object != tc.in.Object || back.Model != tc.in.Model { + t.Fatalf("identity: got %+v, want %+v", back, tc.in) + } + if len(back.Results) != len(tc.in.Results) { + t.Fatalf("results len = %d, want %d", len(back.Results), len(tc.in.Results)) + } + for i := range tc.in.Results { + if back.Results[i].Index != tc.in.Results[i].Index || + back.Results[i].Score != tc.in.Results[i].Score || + back.Results[i].Text != tc.in.Results[i].Text { + t.Fatalf("results[%d] = %+v, want %+v", i, back.Results[i], tc.in.Results[i]) + } + if len(back.Results[i].Labels) != len(tc.in.Results[i].Labels) { + t.Fatalf("results[%d].labels len = %d, want %d", i, len(back.Results[i].Labels), len(tc.in.Results[i].Labels)) + } + for k, v := range tc.in.Results[i].Labels { + if back.Results[i].Labels[k] != v { + t.Fatalf("results[%d].labels[%q] = %q, want %q", i, k, back.Results[i].Labels[k], v) + } + } + } + }) + } +} diff --git a/go/openai/services_test.go b/go/openai/services_test.go new file mode 100644 index 0000000..d6c83ba --- /dev/null +++ b/go/openai/services_test.go @@ -0,0 +1,154 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "dappco.re/go/inference" +) + +type serviceModel struct { + *stubModel + cancelled string + cleared bool + warmed inference.CacheWarmRequest +} + +func (m *serviceModel) Embed(_ context.Context, req inference.EmbeddingRequest) (*inference.EmbeddingResult, error) { + return &inference.EmbeddingResult{ + Vectors: [][]float32{{float32(len(req.Input)), 0.5}}, + Usage: inference.EmbeddingUsage{PromptTokens: len(req.Input), TotalTokens: len(req.Input)}, + }, nil +} + +func (m *serviceModel) Rerank(_ context.Context, req inference.RerankRequest) (*inference.RerankResult, error) { + return &inference.RerankResult{ + Results: []inference.RerankScore{{Index: 1, Score: 0.95, Text: req.Documents[1]}}, + }, nil +} + +func (m *serviceModel) CacheStats(context.Context) (inference.CacheStats, error) { + return inference.CacheStats{Blocks: 7, Hits: 9, Misses: 1, HitRate: 0.9, CacheMode: "block-q8"}, nil +} + +func (m *serviceModel) WarmCache(_ context.Context, req inference.CacheWarmRequest) (inference.CacheWarmResult, error) { + m.warmed = req + return inference.CacheWarmResult{Blocks: []inference.CacheBlockRef{{ID: "blk", TokenCount: len(req.Tokens)}}}, nil +} + +func (m *serviceModel) ClearCache(context.Context, map[string]string) (inference.CacheStats, error) { + m.cleared = true + return inference.CacheStats{CacheMode: "block-q8"}, nil +} + +func (m *serviceModel) CancelRequest(_ context.Context, id string) (inference.RequestCancelResult, error) { + m.cancelled = id + return inference.RequestCancelResult{ID: id, Cancelled: id != ""}, nil +} + +func TestOpenAI_EmbeddingsHandler_Good_UsesEmbeddingModel(t *testing.T) { + model := &serviceModel{stubModel: &stubModel{}} + handler := NewEmbeddingsHandler(NewStaticResolver(map[string]inference.TextModel{"qwen": model})) + req := httptest.NewRequest(http.MethodPost, DefaultEmbeddingsPath, strings.NewReader(`{"model":"qwen","input":["one","two"]}`)) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"object":"list"`) || !strings.Contains(rec.Body.String(), `"embedding":[2,0.5]`) { + t.Fatalf("embedding response = %s", rec.Body.String()) + } +} + +func TestOpenAI_RerankHandler_Good_UsesRerankModel(t *testing.T) { + model := &serviceModel{stubModel: &stubModel{}} + handler := NewRerankHandler(NewStaticResolver(map[string]inference.TextModel{"qwen": model})) + req := httptest.NewRequest(http.MethodPost, DefaultRerankPath, strings.NewReader(`{"model":"qwen","query":"core","documents":["a","b"]}`)) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"index":1`) || !strings.Contains(rec.Body.String(), `"score":0.95`) { + t.Fatalf("rerank response = %s", rec.Body.String()) + } +} + +func TestOpenAI_CapabilityHandler_Good_ReportsModelCapabilities(t *testing.T) { + model := &serviceModel{stubModel: &stubModel{}} + handler := NewCapabilityHandler(NewStaticResolver(map[string]inference.TextModel{"qwen": model})) + req := httptest.NewRequest(http.MethodGet, DefaultCapabilitiesPath+"?model=qwen", nil) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String()) + } + if !strings.Contains(rec.Body.String(), `"embeddings"`) || !strings.Contains(rec.Body.String(), `"request.cancel"`) { + t.Fatalf("capability response = %s", rec.Body.String()) + } +} + +func TestOpenAI_CacheHandlers_Good_StatsWarmClear(t *testing.T) { + model := &serviceModel{stubModel: &stubModel{}} + resolver := NewStaticResolver(map[string]inference.TextModel{"qwen": model}) + + statsReq := httptest.NewRequest(http.MethodGet, DefaultCacheStatsPath+"?model=qwen", nil) + statsRec := httptest.NewRecorder() + NewCacheStatsHandler(resolver).ServeHTTP(statsRec, statsReq) + if statsRec.Code != http.StatusOK || !strings.Contains(statsRec.Body.String(), `"hit_rate":0.9`) { + t.Fatalf("cache stats = %d %s", statsRec.Code, statsRec.Body.String()) + } + + warmReq := httptest.NewRequest(http.MethodPost, DefaultCacheWarmPath, strings.NewReader(`{"model":"qwen","tokens":[1,2,3]}`)) + warmRec := httptest.NewRecorder() + NewCacheWarmHandler(resolver).ServeHTTP(warmRec, warmReq) + if warmRec.Code != http.StatusOK || model.warmed.Model.ID != "qwen" || len(model.warmed.Tokens) != 3 { + t.Fatalf("cache warm = %d %s warmed=%+v", warmRec.Code, warmRec.Body.String(), model.warmed) + } + + clearReq := httptest.NewRequest(http.MethodPost, DefaultCacheClearPath, strings.NewReader(`{"model":"qwen","labels":{"adapter":"none"}}`)) + clearRec := httptest.NewRecorder() + NewCacheClearHandler(resolver).ServeHTTP(clearRec, clearReq) + if clearRec.Code != http.StatusOK || !model.cleared { + t.Fatalf("cache clear = %d %s cleared=%v", clearRec.Code, clearRec.Body.String(), model.cleared) + } +} + +func TestOpenAI_CancelHandler_Good_UsesCancellableModel(t *testing.T) { + model := &serviceModel{stubModel: &stubModel{}} + handler := NewCancelHandler(NewStaticResolver(map[string]inference.TextModel{"qwen": model})) + req := httptest.NewRequest(http.MethodPost, DefaultCancelPath, strings.NewReader(`{"model":"qwen","id":"req_1"}`)) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Fatalf("status = %d body=%s", rec.Code, rec.Body.String()) + } + if model.cancelled != "req_1" || !strings.Contains(rec.Body.String(), `"cancelled":true`) { + t.Fatalf("cancel response = %s cancelled=%q", rec.Body.String(), model.cancelled) + } +} + +func TestOpenAI_ServiceHandlers_Bad_UnsupportedInterface(t *testing.T) { + handler := NewEmbeddingsHandler(NewStaticResolver(map[string]inference.TextModel{"qwen": &stubModel{}})) + req := httptest.NewRequest(http.MethodPost, DefaultEmbeddingsPath, strings.NewReader(`{"model":"qwen","input":"hello"}`)) + rec := httptest.NewRecorder() + + handler.ServeHTTP(rec, req) + + if rec.Code != http.StatusNotImplemented { + t.Fatalf("status = %d body=%s, want not implemented", rec.Code, rec.Body.String()) + } +} diff --git a/go/openai/services_unmarshal.go b/go/openai/services_unmarshal.go new file mode 100644 index 0000000..b07bdac --- /dev/null +++ b/go/openai/services_unmarshal.go @@ -0,0 +1,495 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled JSON-decoding for the openai services types +// (EmbeddingRequest, RerankRequest, CacheWarmRequest, +// CacheClearRequest, CancelRequest). Same single-pass byte-walker +// shape as openai/unmarshal.go. + +package openai + +import ( + "dappco.re/go/inference/jsonenc" +) + +// UnmarshalJSON walks the EmbeddingRequest wire shape. +func (r *EmbeddingRequest) UnmarshalJSON(data []byte) error { + *r = EmbeddingRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +func (r *EmbeddingRequest) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "input": + // EmbeddingInput is []string with its own UnmarshalJSON; + // call ParseJSONStringList directly to skip the nested + // dispatch path. + next, err := jsonenc.SkipJSONValue(data, i) + if err != nil { + return next, err + } + values, err := jsonenc.ParseJSONStringList(data[i:next]) + if err != nil { + return next, err + } + r.Input = values + return next, nil + case "encoding_format": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.EncodingFormat = s + return next, nil + case "dimensions": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + k := int(n) + r.Dimensions = &k + return next, nil + case "user": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.User = s + return next, nil + case "normalize": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONBool(data, i) + if err != nil { + return next, err + } + r.Normalize = v + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// UnmarshalJSON walks the RerankRequest wire shape. +func (r *RerankRequest) UnmarshalJSON(data []byte) error { + *r = RerankRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +func (r *RerankRequest) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "query": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Query = s + return next, nil + case "documents": + next, err := jsonenc.SkipJSONValue(data, i) + if err != nil { + return next, err + } + docs, err := jsonenc.ParseJSONStringList(data[i:next]) + if err != nil { + return next, err + } + r.Documents = docs + return next, nil + case "top_n": + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + r.TopN = int(n) + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// UnmarshalJSON walks the CancelRequest wire shape. +func (r *CancelRequest) UnmarshalJSON(data []byte) error { + *r = CancelRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "model": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return verr + } + r.Model = s + i = vnext + case "id": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return verr + } + r.ID = s + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +// UnmarshalJSON walks the CacheClearRequest wire shape. Labels +// (map[string]string) parsed via parseStringMap. +func (r *CacheClearRequest) UnmarshalJSON(data []byte) error { + *r = CacheClearRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "model": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return verr + } + r.Model = s + i = vnext + case "labels": + labels, vnext, verr := parseStringMap(data, i) + if verr != nil { + return verr + } + r.Labels = labels + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +// UnmarshalJSON walks the CacheWarmRequest wire shape. Tokens +// ([]int32) parsed via parseInt32Array; Labels via parseStringMap. +func (r *CacheWarmRequest) UnmarshalJSON(data []byte) error { + *r = CacheWarmRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "model": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return verr + } + r.Model = s + i = vnext + case "prompt": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return verr + } + r.Prompt = s + i = vnext + case "tokens": + toks, vnext, verr := parseInt32Array(data, i) + if verr != nil { + return verr + } + r.Tokens = toks + i = vnext + case "mode": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return verr + } + r.Mode = s + i = vnext + case "labels": + labels, vnext, verr := parseStringMap(data, i) + if verr != nil { + return verr + } + r.Labels = labels + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +// parseStringMap walks a JSON object with string keys + string +// values and returns a map[string]string. Used for the Labels +// fields on CacheWarm / CacheClear requests. +func parseStringMap(data []byte, i int) (map[string]string, int, error) { + if jsonenc.IsJSONNull(data, i) { + return nil, i + 4, nil + } + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return nil, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil, i + 1, nil + } + out := make(map[string]string) + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return nil, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return nil, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return nil, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + val, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return nil, vnext, verr + } + out[key] = val + i = jsonenc.SkipJSONWhitespace(data, vnext) + if i >= len(data) { + return nil, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return out, i + 1, nil + } + return nil, i, jsonenc.ErrInvalidJSON + } +} + +// parseInt32Array walks a JSON array of integers and returns the +// parsed slice. Used for the Tokens field on CacheWarmRequest. +func parseInt32Array(data []byte, i int) ([]int32, int, error) { + if jsonenc.IsJSONNull(data, i) { + return nil, i + 4, nil + } + i, err := jsonenc.MatchArrayStart(data, i) + if err != nil { + return nil, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == ']' { + return nil, i + 1, nil + } + var out []int32 + for { + n, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return nil, next, err + } + out = append(out, int32(n)) + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) { + return nil, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i = jsonenc.SkipJSONWhitespace(data, i+1) + continue + } + if data[i] == ']' { + return out, i + 1, nil + } + return nil, i, jsonenc.ErrInvalidJSON + } +} diff --git a/go/openai/services_unmarshal_test.go b/go/openai/services_unmarshal_test.go new file mode 100644 index 0000000..3192591 --- /dev/null +++ b/go/openai/services_unmarshal_test.go @@ -0,0 +1,148 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "encoding/json" + "reflect" + "testing" +) + +func TestUnmarshalEmbeddingRequest_DirectShapes(t *testing.T) { + dim := 1024 + cases := []struct { + name string + in string + want EmbeddingRequest + }{ + { + name: "single-string-input", + in: `{"model":"text-embedding","input":"hello"}`, + want: EmbeddingRequest{ + Model: "text-embedding", + Input: EmbeddingInput{"hello"}, + }, + }, + { + name: "array-input-and-options", + in: `{"model":"text-embedding","input":["a","b"],"encoding_format":"float","dimensions":1024,"normalize":true,"user":"u1"}`, + want: EmbeddingRequest{ + Model: "text-embedding", + Input: EmbeddingInput{"a", "b"}, + EncodingFormat: "float", + Dimensions: &dim, + Normalize: true, + User: "u1", + }, + }, + { + name: "dimensions-null", + in: `{"model":"text-embedding","input":"hello","dimensions":null}`, + want: EmbeddingRequest{ + Model: "text-embedding", + Input: EmbeddingInput{"hello"}, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got EmbeddingRequest + if err := json.Unmarshal([]byte(tc.in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("got: %+v\nwant: %+v", got, tc.want) + } + }) + } +} + +func TestUnmarshalRerankRequest_DirectShapes(t *testing.T) { + in := `{"model":"rerank","query":"q","documents":["a","b","c"],"top_n":2}` + want := RerankRequest{ + Model: "rerank", + Query: "q", + Documents: []string{"a", "b", "c"}, + TopN: 2, + } + var got RerankRequest + if err := json.Unmarshal([]byte(in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %+v\nwant: %+v", got, want) + } +} + +func TestUnmarshalCacheWarmRequest_DirectShapes(t *testing.T) { + cases := []struct { + name string + in string + want CacheWarmRequest + }{ + { + name: "prompt-mode", + in: `{"model":"m","prompt":"hi","mode":"warm","labels":{"k":"v"}}`, + want: CacheWarmRequest{ + Model: "m", + Prompt: "hi", + Mode: "warm", + Labels: map[string]string{"k": "v"}, + }, + }, + { + name: "tokens-mode", + in: `{"model":"m","tokens":[1,2,3,4,5]}`, + want: CacheWarmRequest{ + Model: "m", + Tokens: []int32{1, 2, 3, 4, 5}, + }, + }, + { + name: "labels-null", + in: `{"model":"m","prompt":"hi","labels":null}`, + want: CacheWarmRequest{ + Model: "m", + Prompt: "hi", + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got CacheWarmRequest + if err := json.Unmarshal([]byte(tc.in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("got: %+v\nwant: %+v", got, tc.want) + } + }) + } +} + +func TestUnmarshalCacheClearRequest_DirectShapes(t *testing.T) { + in := `{"model":"m","labels":{"env":"prod","tier":"hot"}}` + want := CacheClearRequest{ + Model: "m", + Labels: map[string]string{"env": "prod", "tier": "hot"}, + } + var got CacheClearRequest + if err := json.Unmarshal([]byte(in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %+v\nwant: %+v", got, want) + } +} + +func TestUnmarshalCancelRequest_DirectShapes(t *testing.T) { + in := `{"model":"m","id":"req_123"}` + want := CancelRequest{Model: "m", ID: "req_123"} + var got CancelRequest + if err := json.Unmarshal([]byte(in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("got: %+v\nwant: %+v", got, want) + } +} diff --git a/go/openai/unmarshal.go b/go/openai/unmarshal.go new file mode 100644 index 0000000..86c69e1 --- /dev/null +++ b/go/openai/unmarshal.go @@ -0,0 +1,498 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Hand-rolled JSON-decoding for the OpenAI wire types. Fires at +// HTTP request-entry per chat-completion / responses / services +// call — the encoding/json reflect path costs 22-65 allocs on the +// canonical 1/5/20-turn chat shapes. +// +// The single-pass walker per type lands at ~7-13 allocs for typical +// shapes — predominantly the per-string clones the wire contract +// already requires. Pointer fields (Temperature/TopP/TopK/MaxTokens) +// take address of stack-allocated locals only when the field is +// present and not null. +// +// All decoders SkipJSONValue past unknown fields (matches the +// stdlib default — DisallowUnknownFields is not configured on the +// adapter). + +package openai + +import ( + "dappco.re/go/inference/jsonenc" +) + +// UnmarshalJSON walks the ChatCompletionRequest wire shape in a +// single pass. Replaces the encoding/json reflect path; saves the +// per-field reflect.Value boxing and the per-pointer-field heap +// escape. +func (r *ChatCompletionRequest) UnmarshalJSON(data []byte) error { + *r = ChatCompletionRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +// unmarshalField dispatches one ChatCompletionRequest field by key. +func (r *ChatCompletionRequest) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "messages": + msgs, next, err := parseChatMessageArray(data, i) + if err != nil { + return next, err + } + r.Messages = msgs + return next, nil + case "temperature": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONFloat32(data, i) + if err != nil { + return next, err + } + r.Temperature = &v + return next, nil + case "top_p": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONFloat32(data, i) + if err != nil { + return next, err + } + r.TopP = &v + return next, nil + case "top_k": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + k := int(v) + r.TopK = &k + return next, nil + case "max_tokens": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + k := int(v) + r.MaxTokens = &k + return next, nil + case "stream": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONBool(data, i) + if err != nil { + return next, err + } + r.Stream = v + return next, nil + case "stop": + next, err := jsonenc.SkipJSONValue(data, i) + if err != nil { + return next, err + } + stops, err := jsonenc.ParseJSONStringList(data[i:next]) + if err != nil { + return next, err + } + r.Stop = stops + return next, nil + case "user": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.User = s + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// parseChatMessageArray walks a JSON array of ChatMessage objects. +func parseChatMessageArray(data []byte, i int) ([]ChatMessage, int, error) { + if jsonenc.IsJSONNull(data, i) { + return nil, i + 4, nil + } + i, err := jsonenc.MatchArrayStart(data, i) + if err != nil { + return nil, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == ']' { + return nil, i + 1, nil + } + var out []ChatMessage + for { + msg, next, err := parseChatMessage(data, i) + if err != nil { + return nil, next, err + } + out = append(out, msg) + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) { + return nil, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i = jsonenc.SkipJSONWhitespace(data, i+1) + continue + } + if data[i] == ']' { + return out, i + 1, nil + } + return nil, i, jsonenc.ErrInvalidJSON + } +} + +// parseChatMessage walks a single ChatMessage object at data[i]. +func parseChatMessage(data []byte, i int) (ChatMessage, int, error) { + var msg ChatMessage + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return msg, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return msg, i + 1, nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return msg, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return msg, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return msg, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "role": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return msg, vnext, verr + } + msg.Role = s + i = vnext + case "content": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return msg, vnext, verr + } + msg.Content = s + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return msg, vnext, verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return msg, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return msg, i + 1, nil + } + return msg, i, jsonenc.ErrInvalidJSON + } +} + +// UnmarshalJSON walks the ResponseRequest wire shape in a single pass. +// Same dispatch shape as ChatCompletionRequest with the Responses +// field-name set (input / instructions / max_output_tokens). +func (r *ResponseRequest) UnmarshalJSON(data []byte) error { + *r = ResponseRequest{} + i, err := jsonenc.MatchObjectStart(data, 0) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + i, err = r.unmarshalField(data, i, key) + if err != nil { + return err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return nil + } + return jsonenc.ErrInvalidJSON + } +} + +func (r *ResponseRequest) unmarshalField(data []byte, i int, key []byte) (int, error) { + switch string(key) { + case "model": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Model = s + return next, nil + case "input": + msgs, next, err := parseResponseInputMessageArray(data, i) + if err != nil { + return next, err + } + r.Input = msgs + return next, nil + case "instructions": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.Instructions = s + return next, nil + case "temperature": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONFloat32(data, i) + if err != nil { + return next, err + } + r.Temperature = &v + return next, nil + case "top_p": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONFloat32(data, i) + if err != nil { + return next, err + } + r.TopP = &v + return next, nil + case "top_k": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + k := int(v) + r.TopK = &k + return next, nil + case "max_output_tokens": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONInt(data, i) + if err != nil { + return next, err + } + k := int(v) + r.MaxOutputTokens = &k + return next, nil + case "stream": + if jsonenc.IsJSONNull(data, i) { + return i + 4, nil + } + v, next, err := jsonenc.ParseJSONBool(data, i) + if err != nil { + return next, err + } + r.Stream = v + return next, nil + case "stop": + next, err := jsonenc.SkipJSONValue(data, i) + if err != nil { + return next, err + } + stops, err := jsonenc.ParseJSONStringList(data[i:next]) + if err != nil { + return next, err + } + r.Stop = stops + return next, nil + case "user": + s, next, err := jsonenc.ParseJSONString(data, i) + if err != nil { + return next, err + } + r.User = s + return next, nil + } + return jsonenc.SkipJSONValue(data, i) +} + +// parseResponseInputMessageArray walks a JSON array of +// ResponseInputMessage objects. +func parseResponseInputMessageArray(data []byte, i int) ([]ResponseInputMessage, int, error) { + if jsonenc.IsJSONNull(data, i) { + return nil, i + 4, nil + } + i, err := jsonenc.MatchArrayStart(data, i) + if err != nil { + return nil, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == ']' { + return nil, i + 1, nil + } + var out []ResponseInputMessage + for { + msg, next, err := parseResponseInputMessage(data, i) + if err != nil { + return nil, next, err + } + out = append(out, msg) + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) { + return nil, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i = jsonenc.SkipJSONWhitespace(data, i+1) + continue + } + if data[i] == ']' { + return out, i + 1, nil + } + return nil, i, jsonenc.ErrInvalidJSON + } +} + +// parseResponseInputMessage walks one ResponseInputMessage at data[i]. +func parseResponseInputMessage(data []byte, i int) (ResponseInputMessage, int, error) { + var msg ResponseInputMessage + i, err := jsonenc.MatchObjectStart(data, i) + if err != nil { + return msg, i, err + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i < len(data) && data[i] == '}' { + return msg, i + 1, nil + } + for { + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) || data[i] != '"' { + return msg, i, jsonenc.ErrInvalidJSON + } + key, next, err := jsonenc.ParseJSONStringRaw(data, i) + if err != nil { + return msg, next, err + } + i = jsonenc.SkipJSONWhitespace(data, next) + if i >= len(data) || data[i] != ':' { + return msg, i, jsonenc.ErrInvalidJSON + } + i = jsonenc.SkipJSONWhitespace(data, i+1) + switch string(key) { + case "role": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return msg, vnext, verr + } + msg.Role = s + i = vnext + case "content": + s, vnext, verr := jsonenc.ParseJSONString(data, i) + if verr != nil { + return msg, vnext, verr + } + msg.Content = s + i = vnext + default: + vnext, verr := jsonenc.SkipJSONValue(data, i) + if verr != nil { + return msg, vnext, verr + } + i = vnext + } + i = jsonenc.SkipJSONWhitespace(data, i) + if i >= len(data) { + return msg, i, jsonenc.ErrInvalidJSON + } + if data[i] == ',' { + i++ + continue + } + if data[i] == '}' { + return msg, i + 1, nil + } + return msg, i, jsonenc.ErrInvalidJSON + } +} diff --git a/go/openai/unmarshal_test.go b/go/openai/unmarshal_test.go new file mode 100644 index 0000000..df947f5 --- /dev/null +++ b/go/openai/unmarshal_test.go @@ -0,0 +1,175 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package openai + +import ( + "encoding/json" + "reflect" + "testing" +) + +// TestUnmarshalChatCompletionRequest_DirectShapes pins the hand-rolled +// decoder against direct JSON literals. Locks the per-field dispatch +// — present / absent / null variants of every pointer field, the +// StopList variant shape (string vs array), escape-heavy strings, +// multi-turn arrays. +func TestUnmarshalChatCompletionRequest_DirectShapes(t *testing.T) { + temp := float32(0.7) + topP := float32(0.95) + topK := 64 + maxTok := 1024 + cases := []struct { + name string + in string + want ChatCompletionRequest + }{ + { + name: "minimal", + in: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}]}`, + want: ChatCompletionRequest{ + Model: "gpt-4", + Messages: []ChatMessage{{Role: "user", Content: "hi"}}, + }, + }, + { + name: "all-optional-fields-set", + in: `{"model":"gpt-4","messages":[{"role":"user","content":"hi"}],"temperature":0.7,"top_p":0.95,"top_k":64,"max_tokens":1024,"stream":true,"stop":[""],"user":"u123"}`, + want: ChatCompletionRequest{ + Model: "gpt-4", + Messages: []ChatMessage{{Role: "user", Content: "hi"}}, + Temperature: &temp, + TopP: &topP, + TopK: &topK, + MaxTokens: &maxTok, + Stream: true, + Stop: StopList{""}, + User: "u123", + }, + }, + { + name: "stop-as-string", + in: `{"model":"gpt-4","messages":[],"stop":"END"}`, + want: ChatCompletionRequest{ + Model: "gpt-4", + Messages: nil, + Stop: StopList{"END"}, + }, + }, + { + name: "pointer-fields-null-keeps-zero", + in: `{"model":"gpt-4","messages":[],"temperature":null,"top_p":null,"top_k":null,"max_tokens":null,"stream":null}`, + want: ChatCompletionRequest{ + Model: "gpt-4", + }, + }, + { + name: "unknown-fields-ignored", + in: `{"model":"gpt-4","messages":[],"future":42,"extra":"x"}`, + want: ChatCompletionRequest{ + Model: "gpt-4", + }, + }, + { + name: "whitespace-friendly", + in: `{ + "model": "gpt-4", + "messages": [ + { "role": "user", "content": "hi" } + ] + }`, + want: ChatCompletionRequest{ + Model: "gpt-4", + Messages: []ChatMessage{{Role: "user", Content: "hi"}}, + }, + }, + { + name: "escape-heavy", + in: `{"model":"gpt-4","messages":[{"role":"user","content":"a\nb \"c\" \\d"}]}`, + want: ChatCompletionRequest{ + Model: "gpt-4", + Messages: []ChatMessage{{Role: "user", Content: "a\nb \"c\" \\d"}}, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got ChatCompletionRequest + if err := json.Unmarshal([]byte(tc.in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("Unmarshal mismatch\ngot: %+v\nwant: %+v", got, tc.want) + } + }) + } +} + +func TestUnmarshalResponseRequest_DirectShapes(t *testing.T) { + temp := float32(0.7) + maxOut := 256 + cases := []struct { + name string + in string + want ResponseRequest + }{ + { + name: "minimal", + in: `{"model":"gpt-4","input":[{"role":"user","content":"hi"}]}`, + want: ResponseRequest{ + Model: "gpt-4", + Input: []ResponseInputMessage{{Role: "user", Content: "hi"}}, + }, + }, + { + name: "with-instructions-and-options", + in: `{"model":"gpt-4","input":[{"role":"user","content":"hi"}],"instructions":"sys","temperature":0.7,"max_output_tokens":256,"stream":true}`, + want: ResponseRequest{ + Model: "gpt-4", + Input: []ResponseInputMessage{{Role: "user", Content: "hi"}}, + Instructions: "sys", + Temperature: &temp, + MaxOutputTokens: &maxOut, + Stream: true, + }, + }, + { + name: "stop-as-array", + in: `{"model":"gpt-4","input":[],"stop":["","x"]}`, + want: ResponseRequest{ + Model: "gpt-4", + Stop: StopList{"", "x"}, + }, + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + var got ResponseRequest + if err := json.Unmarshal([]byte(tc.in), &got); err != nil { + t.Fatalf("Unmarshal error = %v", err) + } + if !reflect.DeepEqual(got, tc.want) { + t.Fatalf("Unmarshal mismatch\ngot: %+v\nwant: %+v", got, tc.want) + } + }) + } +} + +// TestUnmarshalChatCompletionRequest_InvalidShapes asserts cleanly +// rejected error shapes — no panics, just errors. +func TestUnmarshalChatCompletionRequest_InvalidShapes(t *testing.T) { + cases := []string{ + ``, + `{`, + `}`, + `{"messages":not-an-array}`, + `{"temperature":"hot"}`, + } + for _, in := range cases { + t.Run(in, func(t *testing.T) { + var req ChatCompletionRequest + if err := json.Unmarshal([]byte(in), &req); err == nil { + t.Fatalf("Unmarshal(%q) returned nil error", in) + } + }) + } +} diff --git a/go/options_bench_test.go b/go/options_bench_test.go new file mode 100644 index 0000000..524b80a --- /dev/null +++ b/go/options_bench_test.go @@ -0,0 +1,294 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the option-builder surface. +// Per AX-11 — ApplyGenerateOpts fires per Generate/Chat/Classify/Batch +// call (per request), and ApplyLoadOpts fires per LoadModel (per model +// load). Option builders are tiny closures, but the slices.Clone in +// WithStopTokens IS allocation, and the per-request loop runs O(n) +// in option count, so the construction floor is a real cost surface +// for high-fanout request paths. +// +// Run: go test -bench=BenchmarkOptions -benchmem -run='^$' . + +package inference + +import ( + "testing" +) + +// Sinks defeat compiler DCE. +var ( + optionsBenchSinkGenerateCfg GenerateConfig + optionsBenchSinkLoadCfg LoadConfig + optionsBenchSinkGenerateOpt GenerateOption + optionsBenchSinkLoadOpt LoadOption +) + +// --- DefaultGenerateConfig (per-call floor when no opts supplied) --- + +func BenchmarkOptions_DefaultGenerateConfig(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateCfg = DefaultGenerateConfig() + } +} + +// --- Individual GenerateOption builders --- + +func BenchmarkOptions_WithMaxTokens(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithMaxTokens(256) + } +} + +func BenchmarkOptions_WithTemperature(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithTemperature(0.7) + } +} + +func BenchmarkOptions_WithTopK(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithTopK(40) + } +} + +func BenchmarkOptions_WithTopP(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithTopP(0.9) + } +} + +// WithStopTokens with a single stop token (most common — just EOS). +func BenchmarkOptions_WithStopTokens_One(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithStopTokens(2) + } +} + +// WithStopTokens with EOS + pad — the clone-the-slice cost surfaces here. +func BenchmarkOptions_WithStopTokens_Three(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithStopTokens(2, 1, 0) + } +} + +// 16 stop tokens — heavy stop-token sets (custom EOS variants for some models). +func BenchmarkOptions_WithStopTokens_Sixteen(b *testing.B) { + ids := []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithStopTokens(ids...) + } +} + +func BenchmarkOptions_WithRepeatPenalty(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithRepeatPenalty(1.1) + } +} + +func BenchmarkOptions_WithLogits(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateOpt = WithLogits() + } +} + +// --- ApplyGenerateOpts — the per-request hot path --- + +func BenchmarkOptions_ApplyGenerateOpts_Nil(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateCfg = ApplyGenerateOpts(nil) + } +} + +func BenchmarkOptions_ApplyGenerateOpts_Empty(b *testing.B) { + opts := []GenerateOption{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateCfg = ApplyGenerateOpts(opts) + } +} + +// Minimal — single option (just MaxTokens, the most common knob). +func BenchmarkOptions_ApplyGenerateOpts_Minimal(b *testing.B) { + opts := []GenerateOption{WithMaxTokens(128)} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateCfg = ApplyGenerateOpts(opts) + } +} + +// Typical chat-time option set — caps + sampling. +func BenchmarkOptions_ApplyGenerateOpts_Typical(b *testing.B) { + opts := []GenerateOption{ + WithMaxTokens(256), + WithTemperature(0.7), + WithTopP(0.9), + WithRepeatPenalty(1.1), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateCfg = ApplyGenerateOpts(opts) + } +} + +// Heavy — every knob set, including stop-token clone cost. +func BenchmarkOptions_ApplyGenerateOpts_Heavy(b *testing.B) { + opts := []GenerateOption{ + WithMaxTokens(2048), + WithTemperature(0.8), + WithTopK(50), + WithTopP(0.95), + WithStopTokens(0, 1, 2, 3), + WithRepeatPenalty(1.15), + WithLogits(), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateCfg = ApplyGenerateOpts(opts) + } +} + +// nil-option slot in the slice — common when callers conditionally +// append options. Tests the nil-skip branch cost. +func BenchmarkOptions_ApplyGenerateOpts_WithNilOptions(b *testing.B) { + opts := []GenerateOption{ + WithMaxTokens(128), + nil, + WithTemperature(0.7), + nil, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkGenerateCfg = ApplyGenerateOpts(opts) + } +} + +// --- LoadOption builders --- + +func BenchmarkOptions_WithBackend(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadOpt = WithBackend("metal") + } +} + +func BenchmarkOptions_WithContextLen(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadOpt = WithContextLen(4096) + } +} + +func BenchmarkOptions_WithGPULayers(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadOpt = WithGPULayers(-1) + } +} + +func BenchmarkOptions_WithParallelSlots(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadOpt = WithParallelSlots(4) + } +} + +func BenchmarkOptions_WithAdapterPath(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadOpt = WithAdapterPath("/models/lora/v1") + } +} + +// --- ApplyLoadOpts — the per-LoadModel hot path --- + +func BenchmarkOptions_ApplyLoadOpts_Nil(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadCfg = ApplyLoadOpts(nil) + } +} + +func BenchmarkOptions_ApplyLoadOpts_Minimal(b *testing.B) { + opts := []LoadOption{WithBackend("metal")} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadCfg = ApplyLoadOpts(opts) + } +} + +func BenchmarkOptions_ApplyLoadOpts_Typical(b *testing.B) { + opts := []LoadOption{ + WithBackend("metal"), + WithContextLen(4096), + WithGPULayers(-1), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadCfg = ApplyLoadOpts(opts) + } +} + +func BenchmarkOptions_ApplyLoadOpts_Heavy(b *testing.B) { + opts := []LoadOption{ + WithBackend("rocm"), + WithContextLen(32768), + WithGPULayers(40), + WithParallelSlots(8), + WithAdapterPath("/models/lora/domain-v2"), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadCfg = ApplyLoadOpts(opts) + } +} + +func BenchmarkOptions_ApplyLoadOpts_WithNilOptions(b *testing.B) { + opts := []LoadOption{ + WithBackend("metal"), + nil, + WithContextLen(4096), + nil, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + optionsBenchSinkLoadCfg = ApplyLoadOpts(opts) + } +} diff --git a/go/parser/builtin.go b/go/parser/builtin.go new file mode 100644 index 0000000..053a32a --- /dev/null +++ b/go/parser/builtin.go @@ -0,0 +1,34 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +import ( + "dappco.re/go/inference" +) + +type builtinOutputParser struct { + id string + markers []reasoningMarker +} + +func newBuiltinOutputParser(id string, markers []reasoningMarker) *builtinOutputParser { + return &builtinOutputParser{id: id, markers: append([]reasoningMarker(nil), markers...)} +} + +func (parser *builtinOutputParser) ParserID() string { + if parser == nil || parser.id == "" { + return "generic" + } + return parser.id +} + +func (parser *builtinOutputParser) ParseReasoning(_ []inference.Token, text string) (inference.ReasoningParseResult, error) { + if parser == nil { + parser = newBuiltinOutputParser("generic", genericMarkers()) + } + return parseReasoningText(text, parser.markers), nil +} + +func (parser *builtinOutputParser) ParseTools(_ []inference.Token, text string) (inference.ToolParseResult, error) { + return parseToolText(text) +} diff --git a/go/parser/builtin_bench_test.go b/go/parser/builtin_bench_test.go new file mode 100644 index 0000000..a71801c --- /dev/null +++ b/go/parser/builtin_bench_test.go @@ -0,0 +1,224 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the built-in OutputParser shell — newBuiltinOutputParser, +// ParserID, ParseReasoning, ParseTools. Per AX-11 — every reasoning- and +// tool-emitting model resolves to a builtinOutputParser instance and the +// ParseReasoning / ParseTools entry points fire once per generation +// flush of the streamed response. Marker-set is varied (qwen vs gemma +// vs gpt-oss) because the per-call cost is dominated by the marker +// scan in parseReasoningText, which itself is the per-segment hot +// loop driven by indexString. +// +// Run: go test -bench='Benchmark_Builtin' -benchmem -run='^$' ./go/parser +// +// Stream sizes mirror the realistic generation shapes: +// - 32-token ≈ short answer, no reasoning span +// - 256-token ≈ typical chat response with mid-length reasoning +// - 2048-token ≈ long-form response (the loop pays N times here) + +package parser + +import ( + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + builtinBenchParser *builtinOutputParser + builtinBenchID string + builtinBenchReason inference.ReasoningParseResult + builtinBenchTools inference.ToolParseResult + builtinBenchErr error +) + +// Roughly one English word ≈ one token for fixture-generation purposes — +// good enough for the parser scan cost which is bytes-driven. +func builtinBenchText(tokens int) string { + out := core.NewBuilder() + for i := 0; i < tokens; i++ { + out.WriteString("word ") + } + return out.String() +} + +// builtinBenchReasoningStream produces a synthetic generation of +// `tokens` words wrapped with a ... span covering the +// requested fraction of the stream. spanFraction is 0.10, 0.50, 0.90. +func builtinBenchReasoningStream(tokens int, spanFraction float64, startMarker, endMarker string) string { + span := int(float64(tokens) * spanFraction) + if span < 1 { + span = 1 + } + if span > tokens { + span = tokens + } + pre := (tokens - span) / 2 + post := tokens - span - pre + out := core.NewBuilder() + out.WriteString(builtinBenchText(pre)) + out.WriteString(startMarker) + out.WriteString(builtinBenchText(span)) + out.WriteString(endMarker) + out.WriteString(builtinBenchText(post)) + return out.String() +} + +// --- newBuiltinOutputParser (per-registry build) --- + +func Benchmark_Builtin_New_Generic(b *testing.B) { + markers := genericMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchParser = newBuiltinOutputParser("generic", markers) + } +} + +func Benchmark_Builtin_New_Qwen(b *testing.B) { + markers := qwenMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchParser = newBuiltinOutputParser("qwen", markers) + } +} + +func Benchmark_Builtin_New_Gemma(b *testing.B) { + markers := gemmaMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchParser = newBuiltinOutputParser("gemma", markers) + } +} + +// --- ParserID (called per dispatch + per Process flush) --- + +func Benchmark_Builtin_ParserID(b *testing.B) { + parser := newBuiltinOutputParser("qwen", qwenMarkers()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchID = parser.ParserID() + } +} + +func Benchmark_Builtin_ParserID_NilReceiver(b *testing.B) { + var parser *builtinOutputParser + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchID = parser.ParserID() + } +} + +// --- ParseReasoning across stream sizes × span fractions × architectures --- +// The 3 architectures cover the three marker shapes: +// qwen — single short pair `` +// gemma — multi-pair channel markers +// gpt-oss — multi-end markers (the worst-case findReasoningStart fan-out) + +var builtinBenchArchitectures = []struct { + id string + parser *builtinOutputParser + start string + end string +}{ + {"qwen", newBuiltinOutputParser("qwen", qwenMarkers()), "", ""}, + {"gemma", newBuiltinOutputParser("gemma", gemmaMarkers()), "thinking\n", ""}, + {"gptoss", newBuiltinOutputParser("gpt-oss", gptOSSMarkers()), "<|channel>analysis\n", "<|channel>final\n"}, +} + +var builtinBenchStreamSizes = []int{32, 256, 2048} + +var builtinBenchSpanFractions = []struct { + id string + frac float64 +}{ + {"Span10pct", 0.10}, + {"Span50pct", 0.50}, + {"Span90pct", 0.90}, +} + +func Benchmark_Builtin_ParseReasoning(b *testing.B) { + for _, arch := range builtinBenchArchitectures { + for _, size := range builtinBenchStreamSizes { + for _, span := range builtinBenchSpanFractions { + text := builtinBenchReasoningStream(size, span.frac, arch.start, arch.end) + b.Run(arch.id+"/"+span.id+"/"+core.Sprintf("Tokens%d", size), func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchReason, builtinBenchErr = arch.parser.ParseReasoning(nil, text) + } + }) + } + } + } +} + +// No reasoning span at all — common case for short factual answers. +func Benchmark_Builtin_ParseReasoning_NoSpan_Qwen(b *testing.B) { + parser := newBuiltinOutputParser("qwen", qwenMarkers()) + text := builtinBenchText(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchReason, builtinBenchErr = parser.ParseReasoning(nil, text) + } +} + +// Nil receiver pays the lazy-construction cost of building the +// generic-fallback parser before the parse runs. +func Benchmark_Builtin_ParseReasoning_NilReceiver(b *testing.B) { + var parser *builtinOutputParser + text := "preplananswer" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchReason, builtinBenchErr = parser.ParseReasoning(nil, text) + } +} + +// --- ParseTools — 0 / 1 / 5 tool invocations per response --- + +func Benchmark_Builtin_ParseTools_NoCalls(b *testing.B) { + parser := newBuiltinOutputParser("hermes", genericMarkers()) + text := builtinBenchText(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchTools, builtinBenchErr = parser.ParseTools(nil, text) + } +} + +func Benchmark_Builtin_ParseTools_OneCall(b *testing.B) { + parser := newBuiltinOutputParser("hermes", genericMarkers()) + text := `before {"name":"search","arguments":{"q":"core"}} after` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchTools, builtinBenchErr = parser.ParseTools(nil, text) + } +} + +func Benchmark_Builtin_ParseTools_FiveCalls(b *testing.B) { + parser := newBuiltinOutputParser("hermes", genericMarkers()) + out := core.NewBuilder() + out.WriteString("preamble text ") + for i := 0; i < 5; i++ { + out.WriteString(`{"name":"search","arguments":{"q":"core","page":`) + out.WriteString(core.Sprintf("%d", i)) + out.WriteString(`}} `) + } + out.WriteString("trailing text") + text := out.String() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + builtinBenchTools, builtinBenchErr = parser.ParseTools(nil, text) + } +} diff --git a/go/parser/markers.go b/go/parser/markers.go new file mode 100644 index 0000000..da48fe9 --- /dev/null +++ b/go/parser/markers.go @@ -0,0 +1,42 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +func qwenMarkers() []reasoningMarker { + return append([]reasoningMarker{ + {start: "", ends: []string{""}, kind: "thinking"}, + }, genericMarkers()...) +} + +func gemmaMarkers() []reasoningMarker { + return append([]reasoningMarker{ + {start: "<|channel>thought\n", ends: []string{""}, kind: "thinking"}, + {start: "<|channel>thinking\n", ends: []string{""}, kind: "thinking"}, + {start: "<|channel>reasoning\n", ends: []string{""}, kind: "reasoning"}, + {start: "<|channel>analysis\n", ends: []string{""}, kind: "analysis"}, + {start: "thinking\n", ends: []string{""}, kind: "thinking"}, + {start: "thought\n", ends: []string{""}, kind: "thinking"}, + {start: "analysis\n", ends: []string{""}, kind: "analysis"}, + {start: "reasoning\n", ends: []string{""}, kind: "reasoning"}, + }, genericMarkers()...) +} + +func gptOSSMarkers() []reasoningMarker { + return append([]reasoningMarker{ + {start: "<|channel>analysis\n", ends: []string{"<|channel>final\n", "<|channel>assistant\n", "<|channel>assistant"}, kind: "analysis"}, + {start: "<|channel>thought\n", ends: []string{"<|channel>final\n", "<|channel>assistant\n", "<|channel>assistant"}, kind: "thinking"}, + {start: "<|channel>reasoning\n", ends: []string{"<|channel>final\n", "<|channel>assistant\n", "<|channel>assistant"}, kind: "reasoning"}, + {start: "<|channel>analysis", ends: []string{"<|channel>final", "<|channel>assistant"}, kind: "analysis"}, + {start: "<|channel>thought", ends: []string{"<|channel>final", "<|channel>assistant"}, kind: "thinking"}, + {start: "<|channel>reasoning", ends: []string{"<|channel>final", "<|channel>assistant"}, kind: "reasoning"}, + }, genericMarkers()...) +} + +func genericMarkers() []reasoningMarker { + return []reasoningMarker{ + {start: "", ends: []string{""}, kind: "thinking"}, + {start: "", ends: []string{""}, kind: "thinking"}, + {start: "", ends: []string{""}, kind: "reasoning"}, + {start: "", ends: []string{""}, kind: "analysis"}, + } +} diff --git a/go/parser/markers_bench_test.go b/go/parser/markers_bench_test.go new file mode 100644 index 0000000..1a1c02d --- /dev/null +++ b/go/parser/markers_bench_test.go @@ -0,0 +1,56 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the per-architecture marker-set builders. Per AX-11 — +// qwenMarkers / gemmaMarkers / gptOSSMarkers / genericMarkers are +// called every time a parser is constructed via newBuiltinOutputParser, +// and the registry rebuilds these sets per Default() call (which +// HintFromInference / ForHint ultimately hit when the consumer +// declines to cache a Registry). Per-call cost is dominated by +// `append([]reasoningMarker(nil), genericMarkers()...)` which allocates +// the underlying slice on every invocation — the hot loop the +// consumer pays for short-lived parser construction. +// +// Run: go test -bench='Benchmark_Markers' -benchmem -run='^$' ./go/parser + +package parser + +import "testing" + +// Sinks defeat compiler DCE. +var ( + markersBenchSet []reasoningMarker +) + +// --- Per-architecture marker-set builders --- + +func Benchmark_Markers_Generic(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + markersBenchSet = genericMarkers() + } +} + +func Benchmark_Markers_Qwen(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + markersBenchSet = qwenMarkers() + } +} + +func Benchmark_Markers_Gemma(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + markersBenchSet = gemmaMarkers() + } +} + +func Benchmark_Markers_GPTOSS(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + markersBenchSet = gptOSSMarkers() + } +} diff --git a/go/parser/reasoning.go b/go/parser/reasoning.go new file mode 100644 index 0000000..d125b3e --- /dev/null +++ b/go/parser/reasoning.go @@ -0,0 +1,76 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +import ( + core "dappco.re/go" + "dappco.re/go/inference" +) + +func parseReasoningText(text string, markers []reasoningMarker) inference.ReasoningParseResult { + visible := core.NewBuilder() + segments := []inference.ReasoningSegment{} + pending := text + tokenOffset := 0 + for pending != "" { + idx, marker, ok := findReasoningStart(pending, markers) + if !ok { + visible.WriteString(pending) + break + } + visible.WriteString(pending[:idx]) + tokenOffset += idx + afterStart := pending[idx+len(marker.start):] + end, endSize := firstReasoningEnd(afterStart, marker.ends) + if end < 0 { + reasoning := trimReasoningText(afterStart) + if reasoning != "" { + segments = append(segments, inference.ReasoningSegment{Kind: marker.kind, Text: reasoning, StartToken: tokenOffset}) + } + break + } + reasoning := trimReasoningText(afterStart[:end]) + if reasoning != "" { + segments = append(segments, inference.ReasoningSegment{Kind: marker.kind, Text: reasoning, StartToken: tokenOffset, EndToken: tokenOffset + end}) + } + pending = afterStart[end+endSize:] + tokenOffset += len(marker.start) + end + endSize + } + return inference.ReasoningParseResult{VisibleText: visible.String(), Reasoning: segments} +} + +func findReasoningStart(text string, markers []reasoningMarker) (int, reasoningMarker, bool) { + best := -1 + var marker reasoningMarker + for _, candidate := range markers { + idx := indexString(text, candidate.start) + if idx < 0 { + continue + } + if best < 0 || idx < best || idx == best && len(candidate.start) > len(marker.start) { + best = idx + marker = candidate + } + } + return best, marker, best >= 0 +} + +func firstReasoningEnd(text string, ends []string) (int, int) { + best := -1 + bestSize := 0 + for _, end := range ends { + idx := indexString(text, end) + if idx < 0 { + continue + } + if best < 0 || idx < best { + best = idx + bestSize = len(end) + } + } + return best, bestSize +} + +func trimReasoningText(text string) string { + return core.Trim(text) +} diff --git a/go/parser/reasoning_bench_test.go b/go/parser/reasoning_bench_test.go new file mode 100644 index 0000000..0483aee --- /dev/null +++ b/go/parser/reasoning_bench_test.go @@ -0,0 +1,262 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the unexported reasoning state machine — +// parseReasoningText, findReasoningStart, firstReasoningEnd, +// trimReasoningText. Per AX-11 — parseReasoningText is the per-flush +// hot loop ParseReasoning resolves to; findReasoningStart and +// firstReasoningEnd are the per-marker-candidate inner scans driven +// by indexString. With qwen3-class generation flushes hundreds of +// times per response, the per-call cost compounds. +// +// Run: go test -bench='Benchmark_Reasoning' -benchmem -run='^$' ./go/parser +// +// Stream sizes mirror realistic generation outputs: +// - 32-token ≈ very short answer +// - 256-token ≈ typical chat-response length +// - 2048-token ≈ long-form generation (the loop pays N times here) + +package parser + +import ( + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + reasoningBenchResult inference.ReasoningParseResult + reasoningBenchIdx int + reasoningBenchMarker reasoningMarker + reasoningBenchOK bool + reasoningBenchEndIdx int + reasoningBenchEndSize int + reasoningBenchText string +) + +// reasoningBenchWords builds a synthetic prose stream of approx +// `tokens` words — cheap proxy for byte cost the scanner pays. +func reasoningBenchWords(tokens int) string { + out := core.NewBuilder() + for i := 0; i < tokens; i++ { + out.WriteString("word ") + } + return out.String() +} + +// reasoningBenchStream wraps a span of words inside the requested +// marker pair, with the span covering `spanFraction` of the total. +func reasoningBenchStream(tokens int, spanFraction float64, startMarker, endMarker string) string { + span := int(float64(tokens) * spanFraction) + if span < 1 { + span = 1 + } + if span > tokens { + span = tokens + } + pre := (tokens - span) / 2 + post := tokens - span - pre + out := core.NewBuilder() + out.WriteString(reasoningBenchWords(pre)) + out.WriteString(startMarker) + out.WriteString(reasoningBenchWords(span)) + out.WriteString(endMarker) + out.WriteString(reasoningBenchWords(post)) + return out.String() +} + +// --- parseReasoningText: per-flush hot loop --- + +var reasoningBenchArchitectures = []struct { + id string + markers []reasoningMarker + start string + end string +}{ + {"Qwen", qwenMarkers(), "", ""}, + {"Gemma", gemmaMarkers(), "thinking\n", ""}, + {"GPTOSS", gptOSSMarkers(), "<|channel>analysis\n", "<|channel>final\n"}, + {"Generic", genericMarkers(), "", ""}, +} + +var reasoningBenchStreamSizes = []int{32, 256, 2048} + +var reasoningBenchSpanFractions = []struct { + id string + frac float64 +}{ + {"Span10pct", 0.10}, + {"Span50pct", 0.50}, + {"Span90pct", 0.90}, +} + +func Benchmark_Reasoning_ParseText(b *testing.B) { + for _, arch := range reasoningBenchArchitectures { + for _, size := range reasoningBenchStreamSizes { + for _, span := range reasoningBenchSpanFractions { + text := reasoningBenchStream(size, span.frac, arch.start, arch.end) + markers := arch.markers + b.Run(arch.id+"/"+span.id+"/"+core.Sprintf("Tokens%d", size), func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchResult = parseReasoningText(text, markers) + } + }) + } + } + } +} + +// Edge case: no reasoning span at all (every marker misses). +// The visible-only short-circuit path is the most common per-response +// shape for non-reasoning models. +func Benchmark_Reasoning_ParseText_NoSpan_Qwen(b *testing.B) { + text := reasoningBenchWords(256) + markers := qwenMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchResult = parseReasoningText(text, markers) + } +} + +// Edge case: unclosed reasoning span — exercises the +// firstReasoningEnd < 0 branch. +func Benchmark_Reasoning_ParseText_Unclosed_Qwen(b *testing.B) { + text := "preamble " + reasoningBenchWords(200) + markers := qwenMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchResult = parseReasoningText(text, markers) + } +} + +// --- findReasoningStart: per-marker fan-out, dominated by indexString --- + +func Benchmark_Reasoning_FindStart_HitEarly_Qwen(b *testing.B) { + text := "plan" + reasoningBenchWords(256) + markers := qwenMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchIdx, reasoningBenchMarker, reasoningBenchOK = findReasoningStart(text, markers) + } +} + +func Benchmark_Reasoning_FindStart_HitMid_Qwen(b *testing.B) { + text := reasoningBenchStream(256, 0.50, "", "") + markers := qwenMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchIdx, reasoningBenchMarker, reasoningBenchOK = findReasoningStart(text, markers) + } +} + +func Benchmark_Reasoning_FindStart_HitLate_Qwen(b *testing.B) { + text := reasoningBenchWords(256) + "plantail" + markers := qwenMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchIdx, reasoningBenchMarker, reasoningBenchOK = findReasoningStart(text, markers) + } +} + +func Benchmark_Reasoning_FindStart_Miss_Qwen(b *testing.B) { + text := reasoningBenchWords(256) + markers := qwenMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchIdx, reasoningBenchMarker, reasoningBenchOK = findReasoningStart(text, markers) + } +} + +// Gemma + gpt-oss carry the worst-case marker fan-out — every miss +// forces every candidate to be scanned. +func Benchmark_Reasoning_FindStart_Miss_Gemma(b *testing.B) { + text := reasoningBenchWords(256) + markers := gemmaMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchIdx, reasoningBenchMarker, reasoningBenchOK = findReasoningStart(text, markers) + } +} + +func Benchmark_Reasoning_FindStart_Miss_GPTOSS(b *testing.B) { + text := reasoningBenchWords(256) + markers := gptOSSMarkers() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchIdx, reasoningBenchMarker, reasoningBenchOK = findReasoningStart(text, markers) + } +} + +// --- firstReasoningEnd: per-end-marker scan inside an open span --- + +func Benchmark_Reasoning_FirstEnd_HitEarly(b *testing.B) { + text := "" + reasoningBenchWords(256) + ends := []string{""} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchEndIdx, reasoningBenchEndSize = firstReasoningEnd(text, ends) + } +} + +func Benchmark_Reasoning_FirstEnd_HitLate(b *testing.B) { + text := reasoningBenchWords(256) + "" + ends := []string{""} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchEndIdx, reasoningBenchEndSize = firstReasoningEnd(text, ends) + } +} + +func Benchmark_Reasoning_FirstEnd_Miss(b *testing.B) { + text := reasoningBenchWords(256) + ends := []string{""} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchEndIdx, reasoningBenchEndSize = firstReasoningEnd(text, ends) + } +} + +// gpt-oss carries 3 end-marker candidates — every miss pays for all 3. +func Benchmark_Reasoning_FirstEnd_Miss_GPTOSS(b *testing.B) { + text := reasoningBenchWords(256) + ends := []string{"<|channel>final\n", "<|channel>assistant\n", "<|channel>assistant"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchEndIdx, reasoningBenchEndSize = firstReasoningEnd(text, ends) + } +} + +// --- trimReasoningText: thin core.Trim wrapper, but called per segment --- + +func Benchmark_Reasoning_Trim_Short(b *testing.B) { + text := " plan with leading and trailing whitespace " + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchText = trimReasoningText(text) + } +} + +func Benchmark_Reasoning_Trim_Long(b *testing.B) { + text := " " + reasoningBenchWords(256) + " " + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + reasoningBenchText = trimReasoningText(text) + } +} diff --git a/go/parser/reasoning_test.go b/go/parser/reasoning_test.go new file mode 100644 index 0000000..67bec46 --- /dev/null +++ b/go/parser/reasoning_test.go @@ -0,0 +1,61 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +import ( + "testing" +) + +func TestReasoning_BuiltinParsers_Good(t *testing.T) { + cases := []struct { + name string + arch string + text string + visible string + reasoning string + kind string + }{ + { + name: "qwen think tags", + arch: "qwen3", + text: "preplananswer", + visible: "preanswer", + reasoning: "plan", + kind: "thinking", + }, + { + name: "gemma turn markers", + arch: "gemma4_text", + text: "thinking\nplandone", + visible: "done", + reasoning: "plan", + kind: "thinking", + }, + { + name: "gpt oss channel markers", + arch: "gpt_oss", + text: "<|channel>analysis\nplan<|channel>final\nanswer", + visible: "answer", + reasoning: "plan", + kind: "analysis", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got, err := ForHint(Hint{Architecture: tc.arch}).ParseReasoning(nil, tc.text) + if err != nil { + t.Fatalf("ParseReasoning() error = %v", err) + } + if got.VisibleText != tc.visible { + t.Fatalf("VisibleText = %q, want %q", got.VisibleText, tc.visible) + } + if len(got.Reasoning) != 1 { + t.Fatalf("Reasoning len = %d, want 1: %+v", len(got.Reasoning), got.Reasoning) + } + if got.Reasoning[0].Text != tc.reasoning || got.Reasoning[0].Kind != tc.kind { + t.Fatalf("Reasoning[0] = %+v, want %q/%q", got.Reasoning[0], tc.kind, tc.reasoning) + } + }) + } +} diff --git a/go/parser/registry.go b/go/parser/registry.go new file mode 100644 index 0000000..2bbcd2a --- /dev/null +++ b/go/parser/registry.go @@ -0,0 +1,121 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +import ( + core "dappco.re/go" + "dappco.re/go/inference" +) + +// type custom struct{ /* ... */ } +// func (custom) ParserID() string { return "custom" } +// // implement inference.ReasoningParser + inference.ToolParser +type OutputParser interface { + ParserID() string + inference.ReasoningParser + inference.ToolParser +} + +// reg := parser.NewRegistry() +// reg.Register(customParser, "custom", "custom-v2") +type Registry struct { + parsers map[string]OutputParser + fallback OutputParser +} + +// reg := parser.NewRegistry() +func NewRegistry() *Registry { + generic := newBuiltinOutputParser("generic", genericMarkers()) + return &Registry{ + parsers: map[string]OutputParser{"generic": generic}, + fallback: generic, + } +} + +// Default returns the process-wide built-in parser registry. Built +// once via core.Once — every Processor / ForHint call shares the same +// instance instead of rebuilding all 11 parsers + their marker +// slices. The registry is read-only after construction (Register is +// safe on bespoke Registries created via NewRegistry, not on the +// shared default). +// +// reg := parser.Default() +// out := reg.LookupHint(parser.Hint{Architecture: "qwen3"}) +func Default() *Registry { + defaultOnce.Do(func() { defaultRegistry = buildDefaultRegistry() }) + return defaultRegistry +} + +var ( + defaultRegistry *Registry + defaultOnce core.Once +) + +func buildDefaultRegistry() *Registry { + registry := NewRegistry() + registry.Register(newBuiltinOutputParser("qwen", qwenMarkers()), "qwen", "qwen2", "qwen3") + registry.Register(newBuiltinOutputParser("gemma", gemmaMarkers()), "gemma", "gemma3", "gemma4", "gemma4_text") + registry.Register(newBuiltinOutputParser("minimax", qwenMarkers()), "minimax", "minimax_m2", "minimax-m2") + registry.Register(newBuiltinOutputParser("deepseek-r1", qwenMarkers()), "deepseek", "deepseek_r1", "deepseek-r1") + registry.Register(newBuiltinOutputParser("gpt-oss", gptOSSMarkers()), "gpt-oss", "gpt_oss", "gptoss") + registry.Register(newBuiltinOutputParser("mistral", genericMarkers()), "mistral", "mixtral") + registry.Register(newBuiltinOutputParser("kimi", qwenMarkers()), "kimi", "kimi_k2", "moonshot") + registry.Register(newBuiltinOutputParser("glm", qwenMarkers()), "glm", "glm4", "chatglm") + registry.Register(newBuiltinOutputParser("hermes", genericMarkers()), "hermes", "hermes2", "hermes3") + registry.Register(newBuiltinOutputParser("granite", genericMarkers()), "granite", "ibm-granite") + return registry +} + +// reg.Register(myParser, "alias1", "alias2") +func (registry *Registry) Register(parser OutputParser, aliases ...string) { + if registry == nil || parser == nil { + return + } + if registry.parsers == nil { + registry.parsers = map[string]OutputParser{} + } + registry.parsers[NormaliseKey(parser.ParserID())] = parser + for _, alias := range aliases { + key := NormaliseKey(alias) + if key == "" { + continue + } + registry.parsers[key] = parser + } + if registry.fallback == nil { + registry.fallback = parser + } +} + +// if p, ok := reg.Lookup("qwen3"); ok { /* use p */ } +func (registry *Registry) Lookup(name string) (OutputParser, bool) { + if registry == nil { + return nil, false + } + parser, ok := registry.parsers[NormaliseKey(name)] + return parser, ok +} + +// p := reg.LookupHint(parser.Hint{Architecture: "qwen3"}) +func (registry *Registry) LookupHint(hint Hint) OutputParser { + if registry == nil { + return Default().LookupHint(hint) + } + if parser, ok := registry.Lookup(Family(hint)); ok { + return parser + } + if registry.fallback != nil { + return registry.fallback + } + return newBuiltinOutputParser("generic", genericMarkers()) +} + +// p := parser.ForHint(parser.Hint{Architecture: "qwen3"}) +func ForHint(hint Hint) OutputParser { + return Default().LookupHint(hint) +} + +// hint := parser.HintFromInference(model.Info()) +func HintFromInference(info inference.ModelInfo) Hint { + return Hint{Architecture: info.Architecture} +} diff --git a/go/parser/registry_bench_test.go b/go/parser/registry_bench_test.go new file mode 100644 index 0000000..ab748fb --- /dev/null +++ b/go/parser/registry_bench_test.go @@ -0,0 +1,200 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for parser registry construction + lookup. Per AX-11 — +// Default() rebuilds the entire registry (10 architectures × marker +// fan-out) every call, NewRegistry() + Register() are the per-consumer +// build paths, Lookup is the per-dispatch hot path, and ForHint is the +// per-request convenience wrapper that hits Default() + LookupHint on +// every call when the consumer doesn't cache a Registry. HintFromInference +// is the inline-allocation cost paid per generation request. +// +// Run: go test -bench='Benchmark_Registry' -benchmem -run='^$' ./go/parser + +package parser + +import ( + "testing" + + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + registryBenchRegistry *Registry + registryBenchParser OutputParser + registryBenchOK bool + registryBenchHint Hint +) + +// --- Default + NewRegistry (per-build floor) --- + +func Benchmark_Registry_NewRegistry(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchRegistry = NewRegistry() + } +} + +func Benchmark_Registry_Default(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchRegistry = Default() + } +} + +// --- Register (per-alias insert) --- + +func Benchmark_Registry_RegisterSingleAlias(b *testing.B) { + registry := NewRegistry() + parser := newBuiltinOutputParser("custom", genericMarkers()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registry.Register(parser, "alias") + } +} + +func Benchmark_Registry_RegisterMultiAlias(b *testing.B) { + registry := NewRegistry() + parser := newBuiltinOutputParser("custom", genericMarkers()) + aliases := []string{"a1", "a2", "a3", "a4", "a5"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registry.Register(parser, aliases...) + } +} + +// --- Lookup: per-dispatch hot path --- + +func Benchmark_Registry_Lookup_Hit_Qwen(b *testing.B) { + registry := Default() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser, registryBenchOK = registry.Lookup("qwen3") + } +} + +func Benchmark_Registry_Lookup_Hit_Gemma(b *testing.B) { + registry := Default() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser, registryBenchOK = registry.Lookup("gemma4_text") + } +} + +// Miss path forces a full map probe + key normalisation. +func Benchmark_Registry_Lookup_Miss(b *testing.B) { + registry := Default() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser, registryBenchOK = registry.Lookup("not-a-real-arch") + } +} + +// Lookup pays NormaliseKey on every call — exercise the +// normalisation cost separately by feeding mixed-case input. +func Benchmark_Registry_Lookup_Hit_Normalise(b *testing.B) { + registry := Default() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser, registryBenchOK = registry.Lookup("Qwen-3.5") + } +} + +func Benchmark_Registry_Lookup_NilReceiver(b *testing.B) { + var registry *Registry + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser, registryBenchOK = registry.Lookup("qwen3") + } +} + +// --- LookupHint: Family() + Lookup() + fallback --- + +func Benchmark_Registry_LookupHint_Qwen(b *testing.B) { + registry := Default() + hint := Hint{Architecture: "qwen3"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser = registry.LookupHint(hint) + } +} + +func Benchmark_Registry_LookupHint_Gemma(b *testing.B) { + registry := Default() + hint := Hint{Architecture: "gemma4_text"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser = registry.LookupHint(hint) + } +} + +func Benchmark_Registry_LookupHint_Unknown(b *testing.B) { + registry := Default() + hint := Hint{Architecture: "not-a-real-arch"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser = registry.LookupHint(hint) + } +} + +func Benchmark_Registry_LookupHint_NilReceiver(b *testing.B) { + var registry *Registry + hint := Hint{Architecture: "qwen3"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser = registry.LookupHint(hint) + } +} + +// --- ForHint: the convenience wrapper that hits Default() + LookupHint --- + +func Benchmark_Registry_ForHint_Qwen(b *testing.B) { + hint := Hint{Architecture: "qwen3"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser = ForHint(hint) + } +} + +func Benchmark_Registry_ForHint_Gemma(b *testing.B) { + hint := Hint{Architecture: "gemma4_text"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser = ForHint(hint) + } +} + +func Benchmark_Registry_ForHint_Unknown(b *testing.B) { + hint := Hint{Architecture: "not-a-real-arch"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchParser = ForHint(hint) + } +} + +// --- HintFromInference: per-request inline alloc --- + +func Benchmark_Registry_HintFromInference(b *testing.B) { + info := inference.ModelInfo{Architecture: "qwen3"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + registryBenchHint = HintFromInference(info) + } +} diff --git a/go/parser/registry_test.go b/go/parser/registry_test.go new file mode 100644 index 0000000..481c845 --- /dev/null +++ b/go/parser/registry_test.go @@ -0,0 +1,93 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +import ( + "testing" + + "dappco.re/go/inference" +) + +func TestRegistry_DefaultLookup_Good_ModelFamilies(t *testing.T) { + cases := map[string]string{ + "qwen3": "qwen", + "gemma4_text": "gemma", + "minimax_m2": "minimax", + "deepseek_r1": "deepseek-r1", + "gpt_oss": "gpt-oss", + "mistral": "mistral", + "kimi_k2": "kimi", + "glm4": "glm", + "hermes3": "hermes", + "granite": "granite", + "unknown": "generic", + } + + for arch, want := range cases { + p := ForHint(Hint{Architecture: arch}) + if p == nil { + t.Fatalf("ForHint(%q) returned nil", arch) + } + if p.ParserID() != want { + t.Fatalf("ForHint(%q) = %q, want %q", arch, p.ParserID(), want) + } + } +} + +func TestRegistry_RegisterCustomParser_Good(t *testing.T) { + registry := NewRegistry() + registry.Register(customOutputParser{}, "custom-family") + + p, ok := registry.Lookup("custom-family") + if !ok { + t.Fatal("Lookup(custom-family) = false") + } + got, err := p.ParseReasoning(nil, "answer") + if err != nil { + t.Fatalf("ParseReasoning() error = %v", err) + } + if p.ParserID() != "custom" || got.VisibleText != "custom:answer" { + t.Fatalf("parser/result = %q %+v", p.ParserID(), got) + } +} + +func TestRegistry_FallbacksAndNilReceivers_Ugly(t *testing.T) { + var nilRegistry *Registry + if p, ok := nilRegistry.Lookup("qwen"); ok || p != nil { + t.Fatalf("nil Lookup() = %+v/%v, want nil/false", p, ok) + } + p := nilRegistry.LookupHint(Hint{Architecture: "qwen3"}) + if p == nil || p.ParserID() != "qwen" { + t.Fatalf("nil LookupHint() = %v, want default qwen parser", p) + } + registry := &Registry{} + registry.Register(nil, "ignored") + if p := registry.LookupHint(Hint{}); p == nil || p.ParserID() != "generic" { + t.Fatalf("empty registry LookupHint() = %v, want generic fallback", p) + } + registry.Register(customOutputParser{}, "", "custom.alias") + if p, ok := registry.Lookup("custom-alias"); !ok || p.ParserID() != "custom" { + t.Fatalf("Lookup(custom-alias) = %v/%v, want custom parser", p, ok) + } + + var nilParser *builtinOutputParser + if nilParser.ParserID() != "generic" { + t.Fatalf("nil builtin ParserID() = %q, want generic", nilParser.ParserID()) + } + reasoning, err := nilParser.ParseReasoning(nil, "plananswer") + if err != nil || reasoning.VisibleText != "answer" || len(reasoning.Reasoning) != 1 { + t.Fatalf("nil builtin ParseReasoning() = %+v/%v, want generic parse", reasoning, err) + } +} + +type customOutputParser struct{} + +func (customOutputParser) ParserID() string { return "custom" } + +func (customOutputParser) ParseReasoning(_ []inference.Token, text string) (inference.ReasoningParseResult, error) { + return inference.ReasoningParseResult{VisibleText: "custom:" + text}, nil +} + +func (customOutputParser) ParseTools(_ []inference.Token, text string) (inference.ToolParseResult, error) { + return inference.ToolParseResult{VisibleText: text}, nil +} diff --git a/go/parser/selector.go b/go/parser/selector.go new file mode 100644 index 0000000..e331508 --- /dev/null +++ b/go/parser/selector.go @@ -0,0 +1,69 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +import ( + core "dappco.re/go" +) + +// key := parser.NormaliseKey("Qwen-3.5") // "qwen_3_5" +func NormaliseKey(value string) string { + value = core.Lower(core.Trim(value)) + value = replaceAll(value, "-", "_") + value = replaceAll(value, ".", "_") + return value +} + +// family := parser.Family(parser.Hint{Architecture: "qwen3"}) // "qwen" +func Family(hint Hint) string { + arch := NormaliseKey(hint.Architecture) + adapter := NormaliseKey(hint.AdapterName) + combined := core.Concat(arch, " ", adapter) + switch { + case core.Contains(combined, "qwen"): + return "qwen" + case core.Contains(combined, "gemma"): + return "gemma" + case core.Contains(combined, "minimax"): + return "minimax" + case core.Contains(combined, "deepseek"): + return "deepseek_r1" + case core.Contains(combined, "gpt_oss"), core.Contains(combined, "gptoss"): + return "gpt_oss" + case core.Contains(combined, "mistral"), core.Contains(combined, "mixtral"): + return "mistral" + case core.Contains(combined, "kimi"), core.Contains(combined, "moonshot"): + return "kimi" + case core.Contains(combined, "glm"), core.Contains(combined, "chatglm"): + return "glm" + case core.Contains(combined, "hermes"): + return "hermes" + case core.Contains(combined, "granite"): + return "granite" + default: + return "generic" + } +} + +// replaceAll delegates to core.Replace (strings.ReplaceAll). The +// stdlib implementation pre-counts occurrences and allocates the +// result buffer exactly once — same shape as the hand-rolled loop but +// with byte-level optimisations the builder loop didn't reach. Old +// shape was already 1-2 allocs; stdlib is the same with less code to +// audit. +func replaceAll(text, old, next string) string { + if old == "" { + return text + } + return core.Replace(text, old, next) +} + +// indexString delegates to stdlib via core.Index. The previous +// hand-rolled implementation was a naive O(N×M) byte-by-byte scan; +// stdlib's strings.Index uses Rabin-Karp / SIMD-accelerated byte +// search and runs O(N+M) for the multi-byte markers (``, +// `<|channel>analysis\n`, etc.) that the thinking/reasoning parsers +// scan against on every per-token Process call. +func indexString(s, substr string) int { + return core.Index(s, substr) +} diff --git a/go/parser/selector_bench_test.go b/go/parser/selector_bench_test.go new file mode 100644 index 0000000..629edb7 --- /dev/null +++ b/go/parser/selector_bench_test.go @@ -0,0 +1,229 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the parser selection layer — NormaliseKey + Family. Per +// AX-11 — both fire on every Registry.Lookup / LookupHint call, which +// itself fires per generation request when callers don't cache. The +// helpers replaceAll and indexString are also exercised because they +// are the inner string-scan loop the entire package depends on +// (parseReasoningText, parseToolText, processor.findStart, et al.). +// +// Run: go test -bench='Benchmark_Selector' -benchmem -run='^$' ./go/parser + +package parser + +import "testing" + +// Sinks defeat compiler DCE. +var ( + selectorBenchKey string + selectorBenchFam string + selectorBenchIdx int +) + +// --- NormaliseKey: per-Lookup hot path --- +// NormaliseKey runs core.Lower + core.Trim + two replaceAll passes. +// The replaceAll pass is the unique cost — it allocates a Builder +// on every call regardless of whether substitution actually happens. + +func Benchmark_Selector_NormaliseKey_AlreadyClean(b *testing.B) { + value := "qwen3" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchKey = NormaliseKey(value) + } +} + +func Benchmark_Selector_NormaliseKey_MixedCase(b *testing.B) { + value := "Qwen3" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchKey = NormaliseKey(value) + } +} + +func Benchmark_Selector_NormaliseKey_NeedsReplace(b *testing.B) { + value := "Qwen-3.5" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchKey = NormaliseKey(value) + } +} + +func Benchmark_Selector_NormaliseKey_Empty(b *testing.B) { + value := "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchKey = NormaliseKey(value) + } +} + +// --- Family: branch-heavy classifier called per LookupHint --- + +func Benchmark_Selector_Family_Qwen(b *testing.B) { + hint := Hint{Architecture: "qwen3"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchFam = Family(hint) + } +} + +func Benchmark_Selector_Family_Gemma(b *testing.B) { + hint := Hint{Architecture: "gemma4_text"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchFam = Family(hint) + } +} + +// Granite hits the LAST switch arm before generic — worst-case for +// the chained Contains() probe. +func Benchmark_Selector_Family_Granite(b *testing.B) { + hint := Hint{Architecture: "granite"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchFam = Family(hint) + } +} + +// Unknown architecture falls all the way through every switch arm. +func Benchmark_Selector_Family_Unknown(b *testing.B) { + hint := Hint{Architecture: "not-a-real-arch"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchFam = Family(hint) + } +} + +// With AdapterName the combined string is longer + scanned twice. +func Benchmark_Selector_Family_QwenWithAdapter(b *testing.B) { + hint := Hint{Architecture: "qwen3", AdapterName: "lora-coder"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchFam = Family(hint) + } +} + +// --- replaceAll: NormaliseKey inner loop --- + +func Benchmark_Selector_ReplaceAll_NoMatch(b *testing.B) { + text := "qwen3" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchKey = replaceAll(text, "-", "_") + } +} + +func Benchmark_Selector_ReplaceAll_SingleMatch(b *testing.B) { + text := "qwen-3" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchKey = replaceAll(text, "-", "_") + } +} + +func Benchmark_Selector_ReplaceAll_ManyMatches(b *testing.B) { + text := "a-b-c-d-e-f-g-h" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchKey = replaceAll(text, "-", "_") + } +} + +// Empty `old` short-circuits at the function head. +func Benchmark_Selector_ReplaceAll_EmptyOld(b *testing.B) { + text := "qwen-3" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchKey = replaceAll(text, "", "_") + } +} + +// --- indexString: the inner scan loop everything else resolves to --- + +func Benchmark_Selector_IndexString_HitEarly(b *testing.B) { + text := "plananswer with a tail of fluff to scan past" + substr := "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchIdx = indexString(text, substr) + } +} + +func Benchmark_Selector_IndexString_HitLate(b *testing.B) { + // 256 bytes of filler + the substring at the tail. + filler := "" + for i := 0; i < 64; i++ { + filler += "word" + } + text := filler + "" + substr := "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchIdx = indexString(text, substr) + } +} + +func Benchmark_Selector_IndexString_Miss(b *testing.B) { + filler := "" + for i := 0; i < 64; i++ { + filler += "word" + } + text := filler + substr := "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchIdx = indexString(text, substr) + } +} + +func Benchmark_Selector_IndexString_EmptySubstr(b *testing.B) { + text := "some text" + substr := "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchIdx = indexString(text, substr) + } +} + +func Benchmark_Selector_IndexString_SubstrLongerThanText(b *testing.B) { + text := "hi" + substr := "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchIdx = indexString(text, substr) + } +} + +// 2048-byte miss — proxy for scanning a full generation stream looking +// for a marker that never appears. +func Benchmark_Selector_IndexString_Miss_2048bytes(b *testing.B) { + filler := "" + for i := 0; i < 512; i++ { + filler += "word" + } + text := filler + substr := "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + selectorBenchIdx = indexString(text, substr) + } +} diff --git a/go/parser/thinking.go b/go/parser/thinking.go new file mode 100644 index 0000000..82df486 --- /dev/null +++ b/go/parser/thinking.go @@ -0,0 +1,264 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +import ( + "strings" + + core "dappco.re/go" +) + +// result := parser.Filter(text, parser.Config{Mode: parser.Capture}, hint) +// visible := result.Text +func Filter(text string, cfg Config, hint Hint) Result { + processor := NewProcessor(cfg, hint) + builder := core.NewBuilder() + builder.WriteString(processor.Process(text)) + builder.WriteString(processor.Flush()) + return Result{ + Text: builder.String(), + Reasoning: processor.Reasoning(), + Chunks: processor.Chunks(), + } +} + +// p := parser.NewProcessor(cfg, hint) +// visible := p.Process(piece) + p.Flush() +type Processor struct { + cfg Config + mode Mode + markers []thinkingMarker + startSet []string // cached marker.start values — invariant once markers is set + pending string + inReasoning bool + current thinkingMarker + reasoningParts []string + blockParts []string + chunks []Chunk +} + +// p := parser.NewProcessor(parser.Config{Mode: parser.Capture}, hint) +func NewProcessor(cfg Config, hint Hint) *Processor { + markers := markersForHint(hint) + startSet := make([]string, len(markers)) + for i, m := range markers { + startSet[i] = m.start + } + return &Processor{ + cfg: cfg, + mode: NormaliseMode(cfg.Mode), + markers: markers, + startSet: startSet, + } +} + +// mode := parser.NormaliseMode("") // returns parser.Show +func NormaliseMode(mode Mode) Mode { + switch mode { + case "", Show: + return Show + case Hide, Capture: + return mode + default: + return Show + } +} + +func markersForHint(hint Hint) []thinkingMarker { + p, ok := ForHint(hint).(*builtinOutputParser) + if !ok || p == nil { + p = newBuiltinOutputParser("generic", genericMarkers()) + } + markers := make([]thinkingMarker, 0, len(p.markers)) + for _, m := range p.markers { + for _, end := range m.ends { + if m.start == "" || end == "" { + continue + } + markers = append(markers, thinkingMarker{ + start: m.start, + end: end, + channel: m.kind, + model: p.ParserID(), + }) + } + } + return markers +} + +// visible := p.Process(piece) +func (p *Processor) Process(text string) string { + if p.mode == Show || text == "" { + return text + } + p.pending += text + return p.drain(false) +} + +// tail := p.Flush() +func (p *Processor) Flush() string { + if p.mode == Show { + return "" + } + out := p.drain(true) + if p.pending == "" { + if p.inReasoning { + p.emitReasoningBlock() + p.inReasoning = false + } + return out + } + if p.inReasoning { + p.addReasoning(p.pending) + p.pending = "" + p.emitReasoningBlock() + p.inReasoning = false + return out + } + out += p.pending + p.pending = "" + return out +} + +// reasoning := p.Reasoning() +func (p *Processor) Reasoning() string { + return core.Join("", p.reasoningParts...) +} + +// chunks := p.Chunks() +func (p *Processor) Chunks() []Chunk { + if len(p.chunks) == 0 { + return nil + } + return append([]Chunk(nil), p.chunks...) +} + +func (p *Processor) drain(final bool) string { + if p.pending == "" { + return "" + } + // Lazy-init the builder. Per-token streaming hits drain on every + // token; the common no-marker path writes a single slice that can + // be returned directly without ever touching a builder. The builder + // only allocates when we cross a marker boundary mid-string and + // need to splice a visible prefix with a suffix later in the loop. + var out *strings.Builder + for p.pending != "" { + if p.inReasoning { + idx := indexString(p.pending, p.current.end) + if idx >= 0 { + p.addReasoning(p.pending[:idx]) + p.pending = p.pending[idx+len(p.current.end):] + p.emitReasoningBlock() + p.inReasoning = false + continue + } + keep := 0 + if !final { + keep = longestSuffixPrefix(p.pending, []string{p.current.end}) + } + consume := len(p.pending) - keep + if consume > 0 { + p.addReasoning(p.pending[:consume]) + p.pending = p.pending[consume:] + } + break + } + + idx, marker, ok := p.findStart(p.pending) + if ok { + if idx > 0 { + if out == nil { + out = core.NewBuilder() + } + out.WriteString(p.pending[:idx]) + } + p.pending = p.pending[idx+len(marker.start):] + p.current = marker + p.inReasoning = true + continue + } + keep := 0 + if !final { + keep = longestSuffixPrefix(p.pending, p.startSet) + } + consume := len(p.pending) - keep + if consume == 0 { + break + } + if out == nil { + // Single-write path — return the slice directly without + // paying for a builder alloc. This is the streaming hot + // path: per-token Process call, no marker in pending, + // consume the visible bytes and return. + output := p.pending[:consume] + p.pending = p.pending[consume:] + return output + } + out.WriteString(p.pending[:consume]) + p.pending = p.pending[consume:] + break + } + if out == nil { + return "" + } + return out.String() +} + +func (p *Processor) findStart(text string) (int, thinkingMarker, bool) { + best := -1 + var marker thinkingMarker + for _, candidate := range p.markers { + idx := indexString(text, candidate.start) + if idx < 0 { + continue + } + if best < 0 || idx < best || idx == best && len(candidate.start) > len(marker.start) { + best = idx + marker = candidate + } + } + return best, marker, best >= 0 +} + +func (p *Processor) addReasoning(text string) { + if text == "" { + return + } + p.reasoningParts = append(p.reasoningParts, text) + p.blockParts = append(p.blockParts, text) +} + +func (p *Processor) emitReasoningBlock() { + text := core.Join("", p.blockParts...) + p.blockParts = nil + if text == "" { + return + } + chunk := Chunk{ + Text: text, + Channel: p.current.channel, + Model: p.current.model, + } + p.chunks = append(p.chunks, chunk) + if p.mode == Capture && p.cfg.Capture != nil { + p.cfg.Capture(chunk) + } +} + +func longestSuffixPrefix(text string, markers []string) int { + best := 0 + for _, marker := range markers { + max := len(marker) - 1 + if max > len(text) { + max = len(text) + } + for size := max; size > best; size-- { + if core.HasPrefix(marker, text[len(text)-size:]) { + best = size + break + } + } + } + return best +} diff --git a/go/parser/thinking_bench_test.go b/go/parser/thinking_bench_test.go new file mode 100644 index 0000000..e98a9f6 --- /dev/null +++ b/go/parser/thinking_bench_test.go @@ -0,0 +1,460 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the streaming thinking-mode Processor — Filter, +// NewProcessor, Process, Flush, Reasoning, Chunks, NormaliseMode, +// markersForHint, longestSuffixPrefix. Per AX-11 — Processor.Process is +// the PER-TOKEN hot loop fired on every streamed chunk during +// generation (one call per generated token, possibly thousands per +// response). longestSuffixPrefix is the partial-marker held-tail check +// also paid per token. NewProcessor + markersForHint are the +// per-stream build cost paid once per response but reach into the +// registry. Filter is the batch (non-streaming) entry point. +// +// Run: go test -bench='Benchmark_Thinking' -benchmem -run='^$' ./go/parser +// +// Stream sizes: +// - 32-token ≈ very short response +// - 256-token ≈ typical chat response +// - 2048-token ≈ long-form streamed response + +package parser + +import ( + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. +var ( + thinkingBenchResult Result + thinkingBenchProcessor *Processor + thinkingBenchText string + thinkingBenchMode Mode + thinkingBenchMarkers []thinkingMarker + thinkingBenchKeep int + thinkingBenchChunks []Chunk + thinkingBenchReasoning string +) + +// thinkingBenchWords builds a synthetic prose stream of `tokens` words. +func thinkingBenchWords(tokens int) string { + out := core.NewBuilder() + for i := 0; i < tokens; i++ { + out.WriteString("word ") + } + return out.String() +} + +// thinkingBenchTokens chunks a stream into per-token deliveries — the +// actual per-token Process() input shape during streaming. We split +// on whitespace and reassemble each "word " into a delivery to mirror +// the inference loop's flush rhythm. +func thinkingBenchTokens(text string) []string { + out := make([]string, 0, 256) + start := 0 + for i := 0; i < len(text); i++ { + if text[i] == ' ' { + out = append(out, text[start:i+1]) + start = i + 1 + } + } + if start < len(text) { + out = append(out, text[start:]) + } + return out +} + +// thinkingBenchStream wraps a span of words inside the marker pair, +// span covering `spanFraction` of the total. +func thinkingBenchStream(tokens int, spanFraction float64, startMarker, endMarker string) string { + span := int(float64(tokens) * spanFraction) + if span < 1 { + span = 1 + } + if span > tokens { + span = tokens + } + pre := (tokens - span) / 2 + post := tokens - span - pre + out := core.NewBuilder() + out.WriteString(thinkingBenchWords(pre)) + out.WriteString(startMarker) + out.WriteString(thinkingBenchWords(span)) + out.WriteString(endMarker) + out.WriteString(thinkingBenchWords(post)) + return out.String() +} + +// --- Filter (batch entry point) --- + +func Benchmark_Thinking_Filter_Show_Qwen(b *testing.B) { + text := thinkingBenchStream(256, 0.50, "", "") + hint := Hint{Architecture: "qwen3"} + cfg := Config{Mode: Show} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchResult = Filter(text, cfg, hint) + } +} + +func Benchmark_Thinking_Filter_Hide_Qwen(b *testing.B) { + text := thinkingBenchStream(256, 0.50, "", "") + hint := Hint{Architecture: "qwen3"} + cfg := Config{Mode: Hide} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchResult = Filter(text, cfg, hint) + } +} + +func Benchmark_Thinking_Filter_Capture_Qwen(b *testing.B) { + text := thinkingBenchStream(256, 0.50, "", "") + hint := Hint{Architecture: "qwen3"} + cfg := Config{Mode: Capture, Capture: func(Chunk) {}} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchResult = Filter(text, cfg, hint) + } +} + +func Benchmark_Thinking_Filter_Hide_Gemma(b *testing.B) { + text := thinkingBenchStream(256, 0.50, "thinking\n", "") + hint := Hint{Architecture: "gemma4_text"} + cfg := Config{Mode: Hide} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchResult = Filter(text, cfg, hint) + } +} + +// --- NewProcessor (per-stream build cost) --- + +func Benchmark_Thinking_NewProcessor_Qwen(b *testing.B) { + hint := Hint{Architecture: "qwen3"} + cfg := Config{Mode: Hide} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchProcessor = NewProcessor(cfg, hint) + } +} + +func Benchmark_Thinking_NewProcessor_Gemma(b *testing.B) { + hint := Hint{Architecture: "gemma4_text"} + cfg := Config{Mode: Hide} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchProcessor = NewProcessor(cfg, hint) + } +} + +// --- markersForHint (per-NewProcessor inner cost) --- + +func Benchmark_Thinking_MarkersForHint_Qwen(b *testing.B) { + hint := Hint{Architecture: "qwen3"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchMarkers = markersForHint(hint) + } +} + +func Benchmark_Thinking_MarkersForHint_Gemma(b *testing.B) { + hint := Hint{Architecture: "gemma4_text"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchMarkers = markersForHint(hint) + } +} + +func Benchmark_Thinking_MarkersForHint_GPTOSS(b *testing.B) { + hint := Hint{Architecture: "gpt-oss"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchMarkers = markersForHint(hint) + } +} + +// --- NormaliseMode (cheap branch, called per NewProcessor) --- + +func Benchmark_Thinking_NormaliseMode_Empty(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchMode = NormaliseMode("") + } +} + +func Benchmark_Thinking_NormaliseMode_Hide(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchMode = NormaliseMode(Hide) + } +} + +func Benchmark_Thinking_NormaliseMode_Capture(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchMode = NormaliseMode(Capture) + } +} + +func Benchmark_Thinking_NormaliseMode_Unknown(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchMode = NormaliseMode("unknown") + } +} + +// --- Process: PER-TOKEN HOT LOOP --- +// Show-mode short-circuits at the function head (the cheap path). +// Hide/Capture-mode pays the full drain() cost per call. + +func Benchmark_Thinking_Process_Show_Qwen_PerToken(b *testing.B) { + pieces := thinkingBenchTokens(thinkingBenchStream(256, 0.50, "", "")) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + processor := NewProcessor(Config{Mode: Show}, Hint{Architecture: "qwen3"}) + for _, piece := range pieces { + thinkingBenchText = processor.Process(piece) + } + thinkingBenchText = processor.Flush() + } +} + +// Per-token streaming over various stream sizes. +var thinkingBenchStreamSizes = []int{32, 256, 2048} + +func Benchmark_Thinking_Process_Hide_Qwen_PerToken(b *testing.B) { + for _, size := range thinkingBenchStreamSizes { + pieces := thinkingBenchTokens(thinkingBenchStream(size, 0.50, "", "")) + b.Run(core.Sprintf("Tokens%d", size), func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + for _, piece := range pieces { + thinkingBenchText = processor.Process(piece) + } + thinkingBenchText = processor.Flush() + } + }) + } +} + +func Benchmark_Thinking_Process_Capture_Qwen_PerToken(b *testing.B) { + for _, size := range thinkingBenchStreamSizes { + pieces := thinkingBenchTokens(thinkingBenchStream(size, 0.50, "", "")) + b.Run(core.Sprintf("Tokens%d", size), func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + processor := NewProcessor(Config{Mode: Capture, Capture: func(Chunk) {}}, Hint{Architecture: "qwen3"}) + for _, piece := range pieces { + thinkingBenchText = processor.Process(piece) + } + thinkingBenchText = processor.Flush() + } + }) + } +} + +// Vary span fraction at fixed 256-token length — covers the 10/50/90% +// reasoning-density profile. +var thinkingBenchSpanFractions = []struct { + id string + frac float64 +}{ + {"Span10pct", 0.10}, + {"Span50pct", 0.50}, + {"Span90pct", 0.90}, +} + +func Benchmark_Thinking_Process_Hide_Qwen_Span(b *testing.B) { + for _, span := range thinkingBenchSpanFractions { + pieces := thinkingBenchTokens(thinkingBenchStream(256, span.frac, "", "")) + b.Run(span.id, func(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + for _, piece := range pieces { + thinkingBenchText = processor.Process(piece) + } + thinkingBenchText = processor.Flush() + } + }) + } +} + +// Gemma + gpt-oss carry the worst-case marker fan-out — markersForHint +// builds a much bigger marker set, and findStart pays per token. +func Benchmark_Thinking_Process_Hide_Gemma_PerToken(b *testing.B) { + pieces := thinkingBenchTokens(thinkingBenchStream(256, 0.50, "thinking\n", "")) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "gemma4_text"}) + for _, piece := range pieces { + thinkingBenchText = processor.Process(piece) + } + thinkingBenchText = processor.Flush() + } +} + +func Benchmark_Thinking_Process_Hide_GPTOSS_PerToken(b *testing.B) { + pieces := thinkingBenchTokens(thinkingBenchStream(256, 0.50, "<|channel>analysis\n", "<|channel>final\n")) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "gpt-oss"}) + for _, piece := range pieces { + thinkingBenchText = processor.Process(piece) + } + thinkingBenchText = processor.Flush() + } +} + +// Process pays nothing in Show mode beyond the type-switch + concat — +// exercise that fast path as a baseline. +func Benchmark_Thinking_Process_Show_Single(b *testing.B) { + processor := NewProcessor(Config{Mode: Show}, Hint{Architecture: "qwen3"}) + piece := "word " + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchText = processor.Process(piece) + } +} + +// Hide-mode single-piece call when there's no marker in flight — +// pays the pending-append + drain probe cost. +func Benchmark_Thinking_Process_Hide_NoMarker_Single(b *testing.B) { + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + piece := "word " + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchText = processor.Process(piece) + } +} + +// --- Flush --- + +func Benchmark_Thinking_Flush_NoPending(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + b.StartTimer() + thinkingBenchText = processor.Flush() + } +} + +func Benchmark_Thinking_Flush_OpenReasoning(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + b.StopTimer() + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + processor.Process("partial reasoning never closed") + b.StartTimer() + thinkingBenchText = processor.Flush() + } +} + +// --- Reasoning + Chunks accessors --- + +func Benchmark_Thinking_Reasoning_Empty(b *testing.B) { + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchReasoning = processor.Reasoning() + } +} + +func Benchmark_Thinking_Reasoning_Populated(b *testing.B) { + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + for _, piece := range thinkingBenchTokens(thinkingBenchStream(256, 0.50, "", "")) { + processor.Process(piece) + } + processor.Flush() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchReasoning = processor.Reasoning() + } +} + +func Benchmark_Thinking_Chunks_Empty(b *testing.B) { + processor := NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchChunks = processor.Chunks() + } +} + +func Benchmark_Thinking_Chunks_Populated(b *testing.B) { + processor := NewProcessor(Config{Mode: Capture, Capture: func(Chunk) {}}, Hint{Architecture: "qwen3"}) + for _, piece := range thinkingBenchTokens(thinkingBenchStream(256, 0.50, "", "")) { + processor.Process(piece) + } + processor.Flush() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchChunks = processor.Chunks() + } +} + +// --- longestSuffixPrefix: per-token held-tail check inside Process() --- + +func Benchmark_Thinking_LongestSuffixPrefix_NoMatch(b *testing.B) { + text := "ordinary text with no marker prefix at the end" + markers := []string{"", "", "", ""} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchKeep = longestSuffixPrefix(text, markers) + } +} + +func Benchmark_Thinking_LongestSuffixPrefix_PartialMatch(b *testing.B) { + text := "ordinary text trailing with ", "", "", ""} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + thinkingBenchKeep = longestSuffixPrefix(text, markers) + } +} + +func Benchmark_Thinking_LongestSuffixPrefix_LongMarkerSet(b *testing.B) { + // Build the gemma marker fan-out as a starts-only list. + gemma := gemmaMarkers() + starts := make([]string, 0, len(gemma)) + for _, m := range gemma { + starts = append(starts, m.start) + } + text := "ordinary text trailing with thinking\nplanfinal", + Config{Mode: Hide}, + Hint{Architecture: "gemma4_text"}, + ) + if got.Text != "final" { + t.Fatalf("Text = %q, want final", got.Text) + } + if got.Reasoning != "plan" { + t.Fatalf("Reasoning = %q, want plan", got.Reasoning) + } +} + +func TestThinking_FilterShowPassthrough_Ugly(t *testing.T) { + raw := "secretvisible" + got := Filter(raw, Config{Mode: Show}, Hint{Architecture: "qwen3"}) + if got.Text != raw { + t.Fatalf("Text = %q, want raw passthrough", got.Text) + } + if got.Reasoning != "" { + t.Fatalf("Reasoning = %q, want empty for passthrough mode", got.Reasoning) + } +} + +func TestThinking_ProcessorFlushesPartialAndOpenBlocks_Ugly(t *testing.T) { + var captured []Chunk + processor := NewProcessor(Config{ + Mode: Capture, + Capture: func(chunk Chunk) { + captured = append(captured, chunk) + }, + }, Hint{Architecture: "qwen3"}) + + if text := processor.Process("visible unfinished"); text != "" { + t.Fatalf("open reasoning output = %q, want hidden reasoning", text) + } + if text := processor.Flush(); text != "" { + t.Fatalf("flush output = %q, want empty while closing open reasoning", text) + } + if processor.Reasoning() != "unfinished" { + t.Fatalf("reasoning = %q, want unfinished", processor.Reasoning()) + } + if len(captured) != 1 || captured[0].Text != "unfinished" { + t.Fatalf("captured = %+v, want unfinished block", captured) + } + + processor = NewProcessor(Config{Mode: Hide}, Hint{Architecture: "qwen3"}) + if text := processor.Process("", end: ""}, + {start: "", end: ""}, + {start: "", end: ""}, +} + +func parseToolText(text string) (inference.ToolParseResult, error) { + visible := core.NewBuilder() + calls := []inference.ToolCall{} + pending := text + foundTagged := false + for pending != "" { + idx, marker, ok := findToolBlockStart(pending) + if !ok { + visible.WriteString(pending) + break + } + foundTagged = true + visible.WriteString(pending[:idx]) + afterStart := pending[idx+len(marker.start):] + end := indexString(afterStart, marker.end) + if end < 0 { + visible.WriteString(pending[idx:]) + break + } + parsed, err := parseToolPayload(afterStart[:end]) + if err != nil { + return inference.ToolParseResult{}, err + } + calls = append(calls, parsed...) + pending = afterStart[end+len(marker.end):] + } + if !foundTagged { + parsed, err := parseToolPayload(text) + if err == nil && len(parsed) > 0 { + return inference.ToolParseResult{VisibleText: "", Calls: parsed}, nil + } + } + return inference.ToolParseResult{VisibleText: visible.String(), Calls: calls}, nil +} + +func findToolBlockStart(text string) (int, toolBlockMarker, bool) { + best := -1 + var marker toolBlockMarker + for _, candidate := range toolBlockMarkers { + idx := indexString(text, candidate.start) + if idx < 0 { + continue + } + if best < 0 || idx < best { + best = idx + marker = candidate + } + } + return best, marker, best >= 0 +} + +type parsedToolCall struct { + ID string `json:"id"` + Type string `json:"type"` + Name string `json:"name"` + Arguments any `json:"arguments"` + ArgumentsJSON string `json:"arguments_json"` + Function *parsedFunction `json:"function"` + ToolCalls []parsedToolCall `json:"tool_calls"` + Calls []parsedToolCall `json:"calls"` +} + +type parsedFunction struct { + Name string `json:"name"` + Arguments any `json:"arguments"` +} + +func parseToolPayload(payload string) ([]inference.ToolCall, error) { + payload = core.Trim(payload) + if payload == "" { + return nil, nil + } + // Cheap shape check before reflection-decoding — a tool-call payload + // is always JSON. If the trimmed text doesn't start with '[' or '{', + // don't pay the encoding/json reflect walk just to discover that + // fact (the common no-tool-calls case the streaming parser feeds us + // is plain assistant prose). + first := payload[0] + if first != '[' && first != '{' { + return nil, nil + } + var list []parsedToolCall + if first == '[' { + result := core.JSONUnmarshalString(payload, &list) + if !result.OK { + return nil, resultError("parser.tool", result) + } + return convertParsedToolCalls(list), nil + } + var envelope parsedToolCall + result := core.JSONUnmarshalString(payload, &envelope) + if !result.OK { + return nil, resultError("parser.tool", result) + } + if len(envelope.ToolCalls) > 0 { + return convertParsedToolCalls(envelope.ToolCalls), nil + } + if len(envelope.Calls) > 0 { + return convertParsedToolCalls(envelope.Calls), nil + } + call := convertParsedToolCall(envelope) + if call.Name == "" { + return nil, nil + } + return []inference.ToolCall{call}, nil +} + +func convertParsedToolCalls(input []parsedToolCall) []inference.ToolCall { + out := make([]inference.ToolCall, 0, len(input)) + for _, parsed := range input { + call := convertParsedToolCall(parsed) + if call.Name != "" { + out = append(out, call) + } + } + return out +} + +func convertParsedToolCall(parsed parsedToolCall) inference.ToolCall { + name := parsed.Name + args := parsed.Arguments + if parsed.Function != nil { + if parsed.Function.Name != "" { + name = parsed.Function.Name + } + if parsed.Function.Arguments != nil { + args = parsed.Function.Arguments + } + } + callType := parsed.Type + if callType == "" { + callType = "function" + } + return inference.ToolCall{ + ID: parsed.ID, + Type: callType, + Name: name, + ArgumentsJSON: normaliseArgumentsJSON(parsed.ArgumentsJSON, args), + } +} + +func normaliseArgumentsJSON(existing string, args any) string { + if core.Trim(existing) != "" { + return core.Trim(existing) + } + if args == nil { + return "" + } + if raw, ok := args.(string); ok { + return core.Trim(raw) + } + return core.JSONMarshalString(args) +} + +func resultError(scope string, result core.Result) error { + if err, ok := result.Value.(error); ok { + return core.Wrap(err, scope, "parse JSON") + } + return core.E(scope, "parse JSON", nil) +} diff --git a/go/parser/tools_bench_test.go b/go/parser/tools_bench_test.go new file mode 100644 index 0000000..6d1b193 --- /dev/null +++ b/go/parser/tools_bench_test.go @@ -0,0 +1,350 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the tool-call parser — parseToolText, findToolBlockStart, +// parseToolPayload, convertParsedToolCalls, convertParsedToolCall, +// normaliseArgumentsJSON. Per AX-11 — parseToolText is the per-flush +// hot loop fired on every completion that may carry a tool call (every +// agentic-mode response). findToolBlockStart is the per-scan fan-out +// across three block-marker pairs. parseToolPayload pays the JSON-decode +// + envelope-walk per call. The bench varies tool-call count (0 / 1 / 5) +// and stream length to mirror realistic agent traces. +// +// Run: go test -bench='Benchmark_Tools' -benchmem -run='^$' ./go/parser + +package parser + +import ( + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + toolsBenchResult inference.ToolParseResult + toolsBenchErr error + toolsBenchCalls []inference.ToolCall + toolsBenchCall inference.ToolCall + toolsBenchIdx int + toolsBenchMarker toolBlockMarker + toolsBenchOK bool + toolsBenchString string +) + +// toolsBenchWords builds a synthetic prose stream of `tokens` words. +func toolsBenchWords(tokens int) string { + out := core.NewBuilder() + for i := 0; i < tokens; i++ { + out.WriteString("word ") + } + return out.String() +} + +// toolsBenchStreamWithCalls splices `n` tool-call blocks evenly +// across a prose stream of `tokens` words. +func toolsBenchStreamWithCalls(tokens, n int) string { + pre := tokens / (n + 1) + out := core.NewBuilder() + for i := 0; i < n; i++ { + out.WriteString(toolsBenchWords(pre)) + out.WriteString(`{"name":"search","arguments":{"q":"core","page":`) + out.WriteString(core.Sprintf("%d", i)) + out.WriteString(`}}`) + } + out.WriteString(toolsBenchWords(pre)) + return out.String() +} + +// --- parseToolText: per-response hot path --- + +func Benchmark_Tools_ParseText_NoCalls_Short(b *testing.B) { + text := toolsBenchWords(32) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +func Benchmark_Tools_ParseText_NoCalls_Mid(b *testing.B) { + text := toolsBenchWords(256) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +func Benchmark_Tools_ParseText_NoCalls_Long(b *testing.B) { + text := toolsBenchWords(2048) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +func Benchmark_Tools_ParseText_OneCall_Short(b *testing.B) { + text := toolsBenchStreamWithCalls(32, 1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +func Benchmark_Tools_ParseText_OneCall_Mid(b *testing.B) { + text := toolsBenchStreamWithCalls(256, 1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +func Benchmark_Tools_ParseText_OneCall_Long(b *testing.B) { + text := toolsBenchStreamWithCalls(2048, 1) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +func Benchmark_Tools_ParseText_FiveCalls_Mid(b *testing.B) { + text := toolsBenchStreamWithCalls(256, 5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +func Benchmark_Tools_ParseText_FiveCalls_Long(b *testing.B) { + text := toolsBenchStreamWithCalls(2048, 5) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +// Unclosed tagged tool-call exercises the `end < 0` branch — the +// scan walks the whole payload looking for `` and falls +// back to passthrough. +func Benchmark_Tools_ParseText_Unclosed(b *testing.B) { + text := `before {"name":"search","arguments":{"q":"core"}` + toolsBenchWords(64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +// Untagged JSON fallback — the entire payload is parsed as JSON. +func Benchmark_Tools_ParseText_JSONFallback(b *testing.B) { + text := `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"lookup","arguments":{"id":7}}}]}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +// Tool-calls block (plural) wrapper. +func Benchmark_Tools_ParseText_ToolCallsBlock(b *testing.B) { + text := `pre [{"name":"a","arguments":{"x":1}},{"name":"b","arguments":{"y":2}}] post` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +// function_call (singular) wrapper. +func Benchmark_Tools_ParseText_FunctionCallBlock(b *testing.B) { + text := `pre {"name":"a","arguments":{"x":1}} post` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchResult, toolsBenchErr = parseToolText(text) + } +} + +// --- findToolBlockStart: per-scan fan-out across 3 marker pairs --- + +func Benchmark_Tools_FindBlockStart_HitFirst(b *testing.B) { + text := `{"name":"x"}tail` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchIdx, toolsBenchMarker, toolsBenchOK = findToolBlockStart(text) + } +} + +func Benchmark_Tools_FindBlockStart_HitMid(b *testing.B) { + text := toolsBenchWords(64) + `{"name":"x"}tail` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchIdx, toolsBenchMarker, toolsBenchOK = findToolBlockStart(text) + } +} + +func Benchmark_Tools_FindBlockStart_Miss_256bytes(b *testing.B) { + text := toolsBenchWords(64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchIdx, toolsBenchMarker, toolsBenchOK = findToolBlockStart(text) + } +} + +func Benchmark_Tools_FindBlockStart_Miss_2048bytes(b *testing.B) { + text := toolsBenchWords(512) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchIdx, toolsBenchMarker, toolsBenchOK = findToolBlockStart(text) + } +} + +// --- parseToolPayload: JSON decode + envelope walk --- + +func Benchmark_Tools_ParsePayload_SingleObject(b *testing.B) { + payload := `{"name":"search","arguments":{"q":"core"}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCalls, toolsBenchErr = parseToolPayload(payload) + } +} + +func Benchmark_Tools_ParsePayload_Array(b *testing.B) { + payload := `[{"name":"a","arguments":{"x":1}},{"name":"b","arguments":{"y":2}}]` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCalls, toolsBenchErr = parseToolPayload(payload) + } +} + +func Benchmark_Tools_ParsePayload_ToolCallsEnvelope(b *testing.B) { + payload := `{"tool_calls":[{"id":"c1","type":"function","function":{"name":"lookup","arguments":{"id":7}}}]}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCalls, toolsBenchErr = parseToolPayload(payload) + } +} + +func Benchmark_Tools_ParsePayload_CallsEnvelope(b *testing.B) { + payload := `{"calls":[{"name":"lookup","arguments":{"id":7}}]}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCalls, toolsBenchErr = parseToolPayload(payload) + } +} + +func Benchmark_Tools_ParsePayload_FunctionEnvelope(b *testing.B) { + payload := `{"function":{"name":"lookup","arguments":{"id":7}}}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCalls, toolsBenchErr = parseToolPayload(payload) + } +} + +func Benchmark_Tools_ParsePayload_Empty(b *testing.B) { + payload := "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCalls, toolsBenchErr = parseToolPayload(payload) + } +} + +func Benchmark_Tools_ParsePayload_ArgumentsAsString(b *testing.B) { + payload := `{"name":"search","arguments_json":"{\"q\":\"core\"}"}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCalls, toolsBenchErr = parseToolPayload(payload) + } +} + +// --- convertParsedToolCalls / convertParsedToolCall --- + +func Benchmark_Tools_ConvertParsedToolCall_SimpleName(b *testing.B) { + parsed := parsedToolCall{Name: "search", Arguments: map[string]any{"q": "core"}} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCall = convertParsedToolCall(parsed) + } +} + +func Benchmark_Tools_ConvertParsedToolCall_FromFunctionEnvelope(b *testing.B) { + parsed := parsedToolCall{ + ID: "c1", + Type: "function", + Function: &parsedFunction{Name: "lookup", Arguments: map[string]any{"id": 7}}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCall = convertParsedToolCall(parsed) + } +} + +func Benchmark_Tools_ConvertParsedToolCalls_Array(b *testing.B) { + input := []parsedToolCall{ + {Name: "a", Arguments: map[string]any{"x": 1}}, + {Name: "b", Arguments: map[string]any{"y": 2}}, + {Name: "c", Arguments: map[string]any{"z": 3}}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchCalls = convertParsedToolCalls(input) + } +} + +// --- normaliseArgumentsJSON --- + +func Benchmark_Tools_NormaliseArgumentsJSON_ExistingJSON(b *testing.B) { + existing := `{"q":"core"}` + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchString = normaliseArgumentsJSON(existing, nil) + } +} + +func Benchmark_Tools_NormaliseArgumentsJSON_FromMap(b *testing.B) { + args := map[string]any{"q": "core", "page": 3} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchString = normaliseArgumentsJSON("", args) + } +} + +func Benchmark_Tools_NormaliseArgumentsJSON_FromString(b *testing.B) { + args := any(`{"q":"core"}`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchString = normaliseArgumentsJSON("", args) + } +} + +func Benchmark_Tools_NormaliseArgumentsJSON_Nil(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + toolsBenchString = normaliseArgumentsJSON("", nil) + } +} diff --git a/go/parser/tools_test.go b/go/parser/tools_test.go new file mode 100644 index 0000000..31d0631 --- /dev/null +++ b/go/parser/tools_test.go @@ -0,0 +1,59 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package parser + +import ( + "testing" +) + +func TestTools_TaggedAndJSONFallback_Good(t *testing.T) { + p := ForHint(Hint{Architecture: "hermes3"}) + + tagged, err := p.ParseTools(nil, `before {"name":"search","arguments":{"q":"core"}} after`) + if err != nil { + t.Fatalf("ParseTools(tagged) error = %v", err) + } + if tagged.VisibleText != "before after" { + t.Fatalf("tagged visible = %q", tagged.VisibleText) + } + if len(tagged.Calls) != 1 || tagged.Calls[0].Name != "search" || tagged.Calls[0].ArgumentsJSON != `{"q":"core"}` { + t.Fatalf("tagged calls = %+v", tagged.Calls) + } + + jsonFallback, err := p.ParseTools(nil, `{"tool_calls":[{"id":"call_1","type":"function","function":{"name":"lookup","arguments":{"id":7}}}]}`) + if err != nil { + t.Fatalf("ParseTools(json) error = %v", err) + } + if jsonFallback.VisibleText != "" { + t.Fatalf("json visible = %q, want empty", jsonFallback.VisibleText) + } + if len(jsonFallback.Calls) != 1 || jsonFallback.Calls[0].ID != "call_1" || jsonFallback.Calls[0].Name != "lookup" || jsonFallback.Calls[0].ArgumentsJSON != `{"id":7}` { + t.Fatalf("json calls = %+v", jsonFallback.Calls) + } +} + +func TestTools_BadAndUglyPayloads(t *testing.T) { + p := ForHint(Hint{Architecture: "qwen3"}) + if _, err := p.ParseTools(nil, `{bad}`); err == nil { + t.Fatal("ParseTools(malformed tagged JSON) error = nil") + } + unclosed, err := p.ParseTools(nil, `before {"name":"search"}`) + if err != nil { + t.Fatalf("ParseTools(unclosed tag) error = %v", err) + } + if unclosed.VisibleText != `before {"name":"search"}` || len(unclosed.Calls) != 0 { + t.Fatalf("unclosed tool parse = %+v, want visible passthrough", unclosed) + } + if calls, err := parseToolPayload(`[{"name":"search","arguments_json":"{\"q\":\"core\"}"},{"name":""}]`); err != nil || len(calls) != 1 || calls[0].ArgumentsJSON != `{"q":"core"}` { + t.Fatalf("parseToolPayload(array) = %+v/%v, want one call with existing args JSON", calls, err) + } + if calls, err := parseToolPayload(`{"calls":[{"name":"lookup","arguments":"{\"id\":7}"}]}`); err != nil || len(calls) != 1 || calls[0].ArgumentsJSON != `{"id":7}` { + t.Fatalf("parseToolPayload(calls) = %+v/%v, want string arguments normalised", calls, err) + } + if calls, err := parseToolPayload(`{"type":"function"}`); err != nil || len(calls) != 0 { + t.Fatalf("parseToolPayload(no name) = %+v/%v, want no call", calls, err) + } + if _, err := parseToolPayload(`{bad}`); err == nil { + t.Fatal("parseToolPayload(bad JSON) error = nil") + } +} diff --git a/go/parser/types.go b/go/parser/types.go new file mode 100644 index 0000000..b861204 --- /dev/null +++ b/go/parser/types.go @@ -0,0 +1,65 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package parser is the driver-neutral output-parsing layer — reasoning +// channels (`...`), tool-call payloads, and a thinking-mode +// processor for streaming or batched generation output. +// +// r := parser.ForHint(parser.Hint{Architecture: "qwen3"}).ParseReasoning(nil, text) +package parser + +// hint := parser.Hint{Architecture: "qwen3", AdapterName: "lora-coder"} +// out := parser.ForHint(hint).ParseReasoning(nil, response) +type Hint struct { + Architecture string + AdapterName string +} + +// cfg := parser.Config{Mode: parser.Capture, Capture: func(c parser.Chunk) { log.Print(c.Text) }} +type Config struct { + Mode Mode `json:"mode,omitempty"` + Capture func(Chunk) `json:"-"` +} + +// parser.Show // leave reasoning markers + content in the visible output +// parser.Hide // strip recognised reasoning blocks from visible output +// parser.Capture // strip from visible + emit blocks via Config.Capture +type Mode string + +const ( + Show Mode = "show" + Hide Mode = "hide" + Capture Mode = "capture" +) + +// chunk := parser.Chunk{Text: "let me think...", Channel: "thinking", Model: "qwen"} +type Chunk struct { + Text string `json:"text"` + Channel string `json:"channel,omitempty"` + Model string `json:"model,omitempty"` +} + +// result := parser.Filter(text, parser.Config{Mode: parser.Capture}, hint) +// visible := result.Text +type Result struct { + Text string `json:"text"` + Reasoning string `json:"reasoning,omitempty"` + Chunks []Chunk `json:"chunks,omitempty"` +} + +type reasoningMarker struct { + start string + ends []string + kind string +} + +type thinkingMarker struct { + start string + end string + channel string + model string +} + +type toolBlockMarker struct { + start string + end string +} diff --git a/go/parser/types_bench_test.go b/go/parser/types_bench_test.go new file mode 100644 index 0000000..34c951a --- /dev/null +++ b/go/parser/types_bench_test.go @@ -0,0 +1,11 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// No CPU-only public surface; skipped. +// types.go declares Hint, Config, Mode, Chunk, Result and the internal +// reasoningMarker / thinkingMarker / toolBlockMarker structs — pure +// type definitions with no runtime functions to benchmark. Benches for +// the consumers of these types live in the per-file benches that +// drive them (builtin_bench_test.go, thinking_bench_test.go, +// registry_bench_test.go, reasoning_bench_test.go, tools_bench_test.go). + +package parser diff --git a/go/probe.go b/go/probe.go new file mode 100644 index 0000000..f1a31cb --- /dev/null +++ b/go/probe.go @@ -0,0 +1,192 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +// ProbeEventKind names the observable event being emitted by a backend. +type ProbeEventKind string + +// ProbePhase marks where an event occurred in the model lifecycle. +type ProbePhase string + +const ( + ProbeEventToken ProbeEventKind = "token" + ProbeEventLogits ProbeEventKind = "logits" + ProbeEventEntropy ProbeEventKind = "entropy" + ProbeEventSelectedHeads ProbeEventKind = "selected_heads" + ProbeEventLayerCoherence ProbeEventKind = "layer_coherence" + ProbeEventRouterDecision ProbeEventKind = "router_decision" + ProbeEventResidual ProbeEventKind = "residual" + ProbeEventCachePressure ProbeEventKind = "cache_pressure" + ProbeEventMemoryPressure ProbeEventKind = "memory_pressure" + ProbeEventTraining ProbeEventKind = "training" + ProbeEventScheduler ProbeEventKind = "scheduler" + + ProbePhasePrefill ProbePhase = "prefill" + ProbePhaseDecode ProbePhase = "decode" + ProbePhaseTraining ProbePhase = "training" + ProbePhaseQueue ProbePhase = "queue" +) + +// ProbeEvent is the typed envelope for model-state observation. +type ProbeEvent struct { + Kind ProbeEventKind `json:"kind,omitempty"` + Phase ProbePhase `json:"phase,omitempty"` + Step int `json:"step,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + Token *ProbeToken `json:"token,omitempty"` + Logits *ProbeLogits `json:"logits,omitempty"` + Entropy *ProbeEntropy `json:"entropy,omitempty"` + SelectedHeads *ProbeHeadSelection `json:"selected_heads,omitempty"` + LayerCoherence *ProbeLayerCoherence `json:"layer_coherence,omitempty"` + RouterDecision *ProbeRouterDecision `json:"router_decision,omitempty"` + Residual *ProbeResidualSummary `json:"residual,omitempty"` + Cache *ProbeCachePressure `json:"cache,omitempty"` + Memory *ProbeMemoryPressure `json:"memory,omitempty"` + Training *ProbeTraining `json:"training,omitempty"` + Scheduler *ProbeScheduler `json:"scheduler,omitempty"` +} + +// ProbeToken records token-level stream state. +type ProbeToken struct { + ID int32 `json:"id,omitempty"` + Text string `json:"text,omitempty"` + PromptTokens int `json:"prompt_tokens,omitempty"` + GeneratedTokens int `json:"generated_tokens,omitempty"` +} + +// ProbeLogit is one sampled or selected logit entry. +type ProbeLogit struct { + ID int32 `json:"id,omitempty"` + Text string `json:"text,omitempty"` + Value float32 `json:"value,omitempty"` +} + +// ProbeLogits summarises logits without requiring full-vocabulary transfer. +type ProbeLogits struct { + VocabularySize int `json:"vocabulary_size,omitempty"` + Top []ProbeLogit `json:"top,omitempty"` + Min float32 `json:"min,omitempty"` + Max float32 `json:"max,omitempty"` + Mean float32 `json:"mean,omitempty"` +} + +// ProbeEntropy records a scalar entropy measurement. +type ProbeEntropy struct { + Value float64 `json:"value,omitempty"` + Unit string `json:"unit,omitempty"` +} + +// ProbeHeadSelection records selected heads for attention probing. +type ProbeHeadSelection struct { + Layer int `json:"layer,omitempty"` + Heads []int `json:"heads,omitempty"` +} + +// ProbeLayerCoherence carries layer-level alignment and spectral summaries. +type ProbeLayerCoherence struct { + Layer int `json:"layer,omitempty"` + KVCoupling float64 `json:"kv_coupling,omitempty"` + MeanCoherence float64 `json:"mean_coherence,omitempty"` + PhaseLock float64 `json:"phase_lock,omitempty"` + SpectralStable float64 `json:"spectral_stable,omitempty"` +} + +// ProbeRouterDecision records sparse expert routing decisions. +type ProbeRouterDecision struct { + Layer int `json:"layer,omitempty"` + ExpertIDs []int `json:"expert_ids,omitempty"` + ExpertProbs []float32 `json:"expert_probs,omitempty"` +} + +// ProbeResidualSummary records compact residual stream statistics. +type ProbeResidualSummary struct { + Layer int `json:"layer,omitempty"` + Mean float64 `json:"mean,omitempty"` + RMS float64 `json:"rms,omitempty"` + Norm float64 `json:"norm,omitempty"` +} + +// ProbeCachePressure records prompt/cache utilisation without exposing tensors. +type ProbeCachePressure struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + GeneratedTokens int `json:"generated_tokens,omitempty"` + CachedTokens int `json:"cached_tokens,omitempty"` + CacheMode string `json:"cache_mode,omitempty"` + HitRate float64 `json:"hit_rate,omitempty"` +} + +// ProbeMemoryPressure records active, peak, and limit memory counters. +type ProbeMemoryPressure struct { + ActiveBytes uint64 `json:"active_bytes,omitempty"` + PeakBytes uint64 `json:"peak_bytes,omitempty"` + LimitBytes uint64 `json:"limit_bytes,omitempty"` +} + +// ProbeTraining records live training metrics. +type ProbeTraining struct { + Epoch int `json:"epoch,omitempty"` + Step int `json:"step,omitempty"` + Loss float64 `json:"loss,omitempty"` + LearningRate float64 `json:"learning_rate,omitempty"` +} + +// ProbeScheduler records request-scheduler queue + latency events. +type ProbeScheduler struct { + RequestID string `json:"request_id,omitempty"` + Event string `json:"event,omitempty"` + QueueDepth int `json:"queue_depth,omitempty"` + QueueLatencyMillis float64 `json:"queue_latency_millis,omitempty"` + FirstTokenLatencyMillis float64 `json:"first_token_latency_millis,omitempty"` + TotalLatencyMillis float64 `json:"total_latency_millis,omitempty"` + Cancelled bool `json:"cancelled,omitempty"` +} + +// ProbeSink receives typed probe events from model backends. +type ProbeSink interface { + EmitProbe(event ProbeEvent) +} + +// ProbeSinkFunc adapts a function to ProbeSink. +type ProbeSinkFunc func(ProbeEvent) + +// EmitProbe emits an event when the function is non-nil. +func (f ProbeSinkFunc) EmitProbe(event ProbeEvent) { + if f != nil { + f(event) + } +} + +// ProbeBus fans probe events out to zero or more sinks. +type ProbeBus struct { + sinks []ProbeSink +} + +// NewProbeBus creates a probe fan-out bus. +func NewProbeBus(sinks ...ProbeSink) *ProbeBus { + bus := &ProbeBus{} + for _, sink := range sinks { + bus.Add(sink) + } + return bus +} + +// Add attaches a sink to the bus. Nil receivers and nil sinks are ignored. +func (b *ProbeBus) Add(sink ProbeSink) { + if b == nil || sink == nil { + return + } + b.sinks = append(b.sinks, sink) +} + +// EmitProbe emits an event to every registered sink. +func (b *ProbeBus) EmitProbe(event ProbeEvent) { + if b == nil { + return + } + for _, sink := range b.sinks { + if sink == nil { + continue + } + sink.EmitProbe(event) + } +} diff --git a/go/probe_bench_test.go b/go/probe_bench_test.go new file mode 100644 index 0000000..6672ebb --- /dev/null +++ b/go/probe_bench_test.go @@ -0,0 +1,365 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the probe-event surface. +// Per AX-11 — backends emit probe events at the rate of generation +// (one per emitted token when ProbeEventToken is wired, one per layer +// per step for richer probes). ProbeBus.EmitProbe fires once per emit, +// and ProbeSinkFunc adapters wrap every consumer callback. Even a few +// nanoseconds per emit dominates the picture under research telemetry +// loads (think every-layer attention probes on 28-layer Qwen3). +// +// Run: go test -bench=BenchmarkProbe -benchmem -run='^$' . + +package inference + +import ( + "testing" +) + +// Sinks defeat compiler DCE. +var ( + probeBenchSinkEvent ProbeEvent + probeBenchSinkKind ProbeEventKind + probeBenchSinkCount int + probeBenchSinkBus *ProbeBus + probeBenchSinkSinkFn ProbeSinkFunc +) + +// benchTokenEvent — minimal per-token decode probe (the per-step floor). +func benchTokenEvent() ProbeEvent { + return ProbeEvent{ + Kind: ProbeEventToken, + Phase: ProbePhaseDecode, + Step: 42, + Token: &ProbeToken{ + ID: 7, + Text: "the", + PromptTokens: 128, + GeneratedTokens: 42, + }, + } +} + +// benchTypicalDecodeEvent — richer per-step shape mid-decode — cache +// + entropy + a top-5 logits summary. Closer to what a probe sink +// actually sees when research telemetry is on. +func benchTypicalDecodeEvent() ProbeEvent { + return ProbeEvent{ + Kind: ProbeEventLogits, + Phase: ProbePhaseDecode, + Step: 42, + Logits: &ProbeLogits{ + VocabularySize: 151936, + Top: []ProbeLogit{ + {ID: 7, Text: "the", Value: 0.34}, + {ID: 11, Text: "a", Value: 0.21}, + {ID: 23, Text: "and", Value: 0.12}, + {ID: 41, Text: "is", Value: 0.08}, + {ID: 67, Text: "to", Value: 0.05}, + }, + Min: -12.5, + Max: 9.8, + Mean: -3.1, + }, + Entropy: &ProbeEntropy{ + Value: 2.34, + Unit: "nats", + }, + Cache: &ProbeCachePressure{ + PromptTokens: 128, + GeneratedTokens: 42, + CachedTokens: 96, + CacheMode: "paged-q8", + HitRate: 0.75, + }, + } +} + +// benchTrainingEvent — what a training probe sink sees per step. +func benchTrainingEvent() ProbeEvent { + return ProbeEvent{ + Kind: ProbeEventTraining, + Phase: ProbePhaseTraining, + Step: 1024, + Training: &ProbeTraining{ + Epoch: 2, + Step: 1024, + Loss: 1.234, + LearningRate: 5e-5, + }, + Memory: &ProbeMemoryPressure{ + ActiveBytes: 1 << 32, // 4 GiB + PeakBytes: 1 << 33, // 8 GiB + LimitBytes: 1 << 34, // 16 GiB + }, + Labels: map[string]string{"adapter": "lora-domain-v2"}, + } +} + +// --- ProbeSinkFunc.EmitProbe (the per-emit closure cost) --- + +func BenchmarkProbe_ProbeSinkFunc_EmitProbe_Token(b *testing.B) { + var captured ProbeEvent + sink := ProbeSinkFunc(func(event ProbeEvent) { + captured = event + }) + event := benchTokenEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink.EmitProbe(event) + } + probeBenchSinkKind = captured.Kind +} + +func BenchmarkProbe_ProbeSinkFunc_EmitProbe_TypicalDecode(b *testing.B) { + var captured ProbeEvent + sink := ProbeSinkFunc(func(event ProbeEvent) { + captured = event + }) + event := benchTypicalDecodeEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink.EmitProbe(event) + } + probeBenchSinkKind = captured.Kind +} + +func BenchmarkProbe_ProbeSinkFunc_EmitProbe_Training(b *testing.B) { + var captured ProbeEvent + sink := ProbeSinkFunc(func(event ProbeEvent) { + captured = event + }) + event := benchTrainingEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink.EmitProbe(event) + } + probeBenchSinkKind = captured.Kind +} + +// Nil-sink (Cladius dev path — probe sink not wired) — must be cheap. +func BenchmarkProbe_ProbeSinkFunc_EmitProbe_Nil(b *testing.B) { + var sink ProbeSinkFunc + event := benchTokenEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink.EmitProbe(event) + } +} + +// --- ProbeBus.EmitProbe fan-out cost --- + +func BenchmarkProbe_NewProbeBus_NoSinks(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkBus = NewProbeBus() + } +} + +func BenchmarkProbe_NewProbeBus_OneSink(b *testing.B) { + sink := ProbeSinkFunc(func(ProbeEvent) {}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkBus = NewProbeBus(sink) + } +} + +func BenchmarkProbe_NewProbeBus_FourSinks(b *testing.B) { + s1 := ProbeSinkFunc(func(ProbeEvent) {}) + s2 := ProbeSinkFunc(func(ProbeEvent) {}) + s3 := ProbeSinkFunc(func(ProbeEvent) {}) + s4 := ProbeSinkFunc(func(ProbeEvent) {}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkBus = NewProbeBus(s1, s2, s3, s4) + } +} + +func BenchmarkProbe_ProbeBus_Add(b *testing.B) { + bus := NewProbeBus() + sink := ProbeSinkFunc(func(ProbeEvent) {}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bus.Add(sink) + } +} + +func BenchmarkProbe_ProbeBus_EmitProbe_OneSink(b *testing.B) { + count := 0 + bus := NewProbeBus(ProbeSinkFunc(func(ProbeEvent) { count++ })) + event := benchTokenEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bus.EmitProbe(event) + } + probeBenchSinkCount = count +} + +func BenchmarkProbe_ProbeBus_EmitProbe_FourSinks(b *testing.B) { + count := 0 + bus := NewProbeBus( + ProbeSinkFunc(func(ProbeEvent) { count++ }), + ProbeSinkFunc(func(ProbeEvent) { count++ }), + ProbeSinkFunc(func(ProbeEvent) { count++ }), + ProbeSinkFunc(func(ProbeEvent) { count++ }), + ) + event := benchTokenEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bus.EmitProbe(event) + } + probeBenchSinkCount = count +} + +func BenchmarkProbe_ProbeBus_EmitProbe_OneSink_TypicalDecode(b *testing.B) { + count := 0 + bus := NewProbeBus(ProbeSinkFunc(func(ProbeEvent) { count++ })) + event := benchTypicalDecodeEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bus.EmitProbe(event) + } + probeBenchSinkCount = count +} + +// Nil bus pointer — dev path; must be cheap. +func BenchmarkProbe_ProbeBus_EmitProbe_Nil(b *testing.B) { + var bus *ProbeBus + event := benchTokenEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bus.EmitProbe(event) + } +} + +// Bus with a nil sink mixed in — exercises the nil-skip branch. +func BenchmarkProbe_ProbeBus_EmitProbe_WithNilSink(b *testing.B) { + count := 0 + bus := &ProbeBus{ + sinks: []ProbeSink{ + nil, + ProbeSinkFunc(func(ProbeEvent) { count++ }), + nil, + ProbeSinkFunc(func(ProbeEvent) { count++ }), + }, + } + event := benchTokenEvent() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bus.EmitProbe(event) + } + probeBenchSinkCount = count +} + +// --- ProbeEvent construction (the value-cost backends pay at emit site) --- +// Each new() of a sub-shape (ProbeToken/ProbeLogits/...) is a heap-alloc +// pointer — surface those construction floors. + +func BenchmarkProbe_ProbeEvent_Token(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkEvent = benchTokenEvent() + } +} + +func BenchmarkProbe_ProbeEvent_TypicalDecode(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkEvent = benchTypicalDecodeEvent() + } +} + +func BenchmarkProbe_ProbeEvent_Training(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkEvent = benchTrainingEvent() + } +} + +// Bare layer-coherence event (one-shot mid-decode probe) — the cheapest +// payload-bearing event shape. +func BenchmarkProbe_ProbeEvent_LayerCoherence(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkEvent = ProbeEvent{ + Kind: ProbeEventLayerCoherence, + Phase: ProbePhaseDecode, + Step: 3, + LayerCoherence: &ProbeLayerCoherence{ + Layer: 12, + KVCoupling: 0.7, + MeanCoherence: 0.8, + PhaseLock: 0.9, + SpectralStable: 0.6, + }, + } + } +} + +// Router-decision event — emitted per MoE layer during decode. +func BenchmarkProbe_ProbeEvent_RouterDecision_8Experts(b *testing.B) { + expertIDs := []int{0, 1, 2, 3, 4, 5, 6, 7} + expertProbs := []float32{0.2, 0.18, 0.15, 0.12, 0.10, 0.09, 0.08, 0.08} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkEvent = ProbeEvent{ + Kind: ProbeEventRouterDecision, + Phase: ProbePhaseDecode, + Step: 3, + RouterDecision: &ProbeRouterDecision{ + Layer: 12, + ExpertIDs: expertIDs, + ExpertProbs: expertProbs, + }, + } + } +} + +// Scheduler event — emitted at queue boundaries, not per token. +func BenchmarkProbe_ProbeEvent_Scheduler(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkEvent = ProbeEvent{ + Kind: ProbeEventScheduler, + Phase: ProbePhaseQueue, + Scheduler: &ProbeScheduler{ + RequestID: "req-7", + Event: "first_token", + QueueDepth: 4, + QueueLatencyMillis: 12.3, + FirstTokenLatencyMillis: 45.6, + }, + } + } +} + +// --- ProbeSinkFunc cast cost --- +// Used when a closure is passed where a ProbeSink is needed. + +func BenchmarkProbe_ProbeSinkFunc_Cast(b *testing.B) { + fn := func(ProbeEvent) {} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + probeBenchSinkSinkFn = ProbeSinkFunc(fn) + } +} diff --git a/go/probe_example_test.go b/go/probe_example_test.go new file mode 100644 index 0000000..8ea1184 --- /dev/null +++ b/go/probe_example_test.go @@ -0,0 +1,72 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import core "dappco.re/go" + +func ExampleProbeSinkFunc() { + sink := ProbeSinkFunc(func(event ProbeEvent) { + core.Println(event.Kind, event.Token.Text) + }) + + sink.EmitProbe(ProbeEvent{ + Kind: ProbeEventToken, + Token: &ProbeToken{Text: "hello"}, + }) + // Output: token hello +} + +func ExampleProbeSinkFunc_EmitProbe() { + sink := ProbeSinkFunc(func(event ProbeEvent) { + core.Println(event.Kind) + }) + + sink.EmitProbe(ProbeEvent{Kind: ProbeEventTraining}) + // Output: training +} + +func ExampleNewProbeBus() { + var seen int + bus := NewProbeBus(ProbeSinkFunc(func(ProbeEvent) { seen++ })) + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventEntropy}) + + core.Println(seen) + // Output: 1 +} + +func ExampleProbeBus() { + var seen int + bus := NewProbeBus( + ProbeSinkFunc(func(ProbeEvent) { seen++ }), + ProbeSinkFunc(func(ProbeEvent) { seen++ }), + ) + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventEntropy}) + + core.Println(seen) + // Output: 2 +} + +func ExampleProbeBus_Add() { + var seen int + bus := NewProbeBus() + bus.Add(ProbeSinkFunc(func(ProbeEvent) { seen++ })) + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventResidual}) + + core.Println(seen) + // Output: 1 +} + +func ExampleProbeBus_EmitProbe() { + var kind ProbeEventKind + bus := NewProbeBus(ProbeSinkFunc(func(event ProbeEvent) { + kind = event.Kind + })) + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventCachePressure}) + + core.Println(kind) + // Output: cache_pressure +} diff --git a/go/probe_test.go b/go/probe_test.go new file mode 100644 index 0000000..507660c --- /dev/null +++ b/go/probe_test.go @@ -0,0 +1,180 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import "testing" + +func TestProbe_ProbeSinkFunc_Good(t *testing.T) { + var got ProbeEvent + sink := ProbeSinkFunc(func(event ProbeEvent) { + got = event + }) + + sink.EmitProbe(ProbeEvent{ + Kind: ProbeEventToken, + Token: &ProbeToken{ + ID: 7, + Text: "ok", + }, + }) + + checkEqual(t, ProbeEventToken, got.Kind) + checkEqual(t, "ok", got.Token.Text) +} + +func TestProbe_ProbeSinkFunc_EmitProbe_Good(t *testing.T) { + var got ProbeEvent + sink := ProbeSinkFunc(func(event ProbeEvent) { + got = event + }) + + sink.EmitProbe(ProbeEvent{Kind: ProbeEventToken, Token: &ProbeToken{Text: "ok"}}) + + checkEqual(t, ProbeEventToken, got.Kind) + checkEqual(t, "ok", got.Token.Text) +} + +func TestProbe_ProbeSinkFunc_EmitProbe_Bad(t *testing.T) { + var sink ProbeSinkFunc + event := ProbeEvent{Kind: ProbeEventTraining} + + sink.EmitProbe(event) + + checkNil(t, sink) + checkEqual(t, ProbeEventTraining, event.Kind) +} + +func TestProbe_ProbeSinkFunc_EmitProbe_Ugly(t *testing.T) { + count := 0 + sink := ProbeSinkFunc(func(event ProbeEvent) { + if event.Kind == ProbeEventEntropy { + count++ + } + }) + + sink.EmitProbe(ProbeEvent{Kind: ProbeEventEntropy}) + sink.EmitProbe(ProbeEvent{Kind: ProbeEventMemoryPressure}) + + checkEqual(t, 1, count) +} + +func TestProbe_NewProbeBus_Good(t *testing.T) { + var count int + bus := NewProbeBus(ProbeSinkFunc(func(ProbeEvent) { count++ })) + bus.Add(ProbeSinkFunc(func(ProbeEvent) { count++ })) + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventMemoryPressure}) + + checkEqual(t, 2, count) +} + +func TestProbe_NewProbeBus_Bad(t *testing.T) { + bus := NewProbeBus(nil) + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventCachePressure}) + + checkNotNil(t, bus) + checkLen(t, bus.sinks, 0) +} + +func TestProbe_NewProbeBus_Ugly(t *testing.T) { + var got []ProbeEventKind + bus := NewProbeBus( + ProbeSinkFunc(func(event ProbeEvent) { got = append(got, event.Kind) }), + nil, + ProbeSinkFunc(func(event ProbeEvent) { got = append(got, event.Kind) }), + ) + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventResidual}) + + checkEqual(t, []ProbeEventKind{ProbeEventResidual, ProbeEventResidual}, got) +} + +func TestProbe_ProbeBus_Add_Good(t *testing.T) { + bus := NewProbeBus() + sink := ProbeSinkFunc(func(ProbeEvent) {}) + + bus.Add(sink) + + checkLen(t, bus.sinks, 1) +} + +func TestProbe_ProbeBus_Add_Bad(t *testing.T) { + var bus *ProbeBus + + bus.Add(nil) + + checkNil(t, bus) +} + +func TestProbe_ProbeBus_Add_Ugly(t *testing.T) { + bus := NewProbeBus() + + bus.Add(nil) + bus.Add(ProbeSinkFunc(func(ProbeEvent) {})) + + checkLen(t, bus.sinks, 1) +} + +func TestProbe_ProbeBus_EmitProbe_Good(t *testing.T) { + var count int + bus := NewProbeBus( + ProbeSinkFunc(func(ProbeEvent) { count++ }), + ProbeSinkFunc(func(ProbeEvent) { count++ }), + ) + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventMemoryPressure}) + + checkEqual(t, 2, count) +} + +func TestProbe_ProbeBus_EmitProbe_Bad(t *testing.T) { + var bus *ProbeBus + event := ProbeEvent{Kind: ProbeEventCachePressure} + + bus.EmitProbe(event) + + checkNil(t, bus) + checkEqual(t, ProbeEventCachePressure, event.Kind) +} + +func TestProbe_ProbeBus_EmitProbe_Ugly(t *testing.T) { + var count int + bus := &ProbeBus{ + sinks: []ProbeSink{ + nil, + ProbeSinkFunc(func(ProbeEvent) { count++ }), + }, + } + + bus.EmitProbe(ProbeEvent{Kind: ProbeEventCachePressure}) + + checkEqual(t, 1, count) +} + +func TestProbeEventRichPayload(t *testing.T) { + event := ProbeEvent{ + Kind: ProbeEventLayerCoherence, + Phase: ProbePhaseDecode, + Step: 3, + LayerCoherence: &ProbeLayerCoherence{ + Layer: 2, + KVCoupling: 0.7, + MeanCoherence: 0.8, + PhaseLock: 0.9, + SpectralStable: 0.6, + }, + Cache: &ProbeCachePressure{ + PromptTokens: 128, + GeneratedTokens: 16, + CachedTokens: 96, + CacheMode: "paged-q8", + HitRate: 0.75, + }, + } + + checkEqual(t, ProbeEventLayerCoherence, event.Kind) + checkEqual(t, ProbePhaseDecode, event.Phase) + checkEqual(t, 2, event.LayerCoherence.Layer) + checkEqual(t, "paged-q8", event.Cache.CacheMode) +} diff --git a/go/quant/codebook/codebook.go b/go/quant/codebook/codebook.go new file mode 100644 index 0000000..a08e388 --- /dev/null +++ b/go/quant/codebook/codebook.go @@ -0,0 +1,317 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package codebook holds the driver-neutral VQ-codebook quant metadata +// + reference CPU matvec for parity tests against native kernels. +// +// profile, _ := codebook.ParseProfile(data) +// desc, _ := codebook.NewTensorDescriptor(name, shape, profile) +// out, _ := codebook.MatVec(desc, input, codes, table, bias) +package codebook + +import ( + core "dappco.re/go" +) + +const ( + Type = "codebook" + FormatVQ = "vq" +) + +// profile := codebook.Profile{CodebookSize: 256, CodeDim: 4, IndexBits: 8} +type Profile struct { + Type string `json:"type,omitempty"` + Format string `json:"format,omitempty"` + CodebookSize int `json:"codebook_size,omitempty"` + CodeDim int `json:"code_dim,omitempty"` + IndexBits int `json:"index_bits,omitempty"` + Source string `json:"source,omitempty"` + Tensors []TensorDescriptor `json:"tensors,omitempty"` +} + +// desc, _ := codebook.NewTensorDescriptor(name, []uint64{out, in}, profile) +type TensorDescriptor struct { + Name string `json:"name,omitempty"` + Format string `json:"format,omitempty"` + Shape []uint64 `json:"shape,omitempty"` + Elements uint64 `json:"elements,omitempty"` + CodebookSize int `json:"codebook_size,omitempty"` + CodeDim int `json:"code_dim,omitempty"` + CodeCount int `json:"code_count,omitempty"` + IndexBits int `json:"index_bits,omitempty"` + IndexBytes int `json:"index_bytes,omitempty"` + CodesName string `json:"codes_name,omitempty"` + CodebookName string `json:"codebook_name,omitempty"` + CodesShape []uint64 `json:"codes_shape,omitempty"` + CodebookShape []uint64 `json:"codebook_shape,omitempty"` +} + +type configProbe struct { + Type string `json:"type"` + Format string `json:"format"` + CodebookSize int `json:"codebook_size"` + CodeDim int `json:"code_dim"` + IndexBits int `json:"index_bits"` + Source string `json:"source"` + Tensors []struct { + Name string `json:"name"` + Shape []uint64 `json:"shape"` + CodesName string `json:"codes"` + CodebookName string `json:"codebook"` + CodesShape []uint64 `json:"codes_shape"` + CodebookShape []uint64 `json:"codebook_shape"` + CodebookSize int `json:"codebook_size"` + CodeDim int `json:"code_dim"` + IndexBits int `json:"index_bits"` + } `json:"tensors"` +} + +// profile, _ := codebook.ParseProfile(data) +func ParseProfile(data []byte) (*Profile, error) { + var probe configProbe + if result := core.JSONUnmarshal(data, &probe); !result.OK { + return nil, result.Value.(error) + } + profile := Profile{ + Type: firstNonEmpty(probe.Type, Type), + Format: firstNonEmpty(probe.Format, FormatVQ), + CodebookSize: probe.CodebookSize, + CodeDim: probe.CodeDim, + IndexBits: firstPositive(probe.IndexBits, 8), + Source: firstNonEmpty(probe.Source, "codebook_config.json"), + } + for _, tensor := range probe.Tensors { + local := profile + local.CodebookSize = firstPositive(tensor.CodebookSize, profile.CodebookSize) + local.CodeDim = firstPositive(tensor.CodeDim, profile.CodeDim) + local.IndexBits = firstPositive(tensor.IndexBits, profile.IndexBits) + desc, err := NewTensorDescriptor(tensor.Name, tensor.Shape, local) + if err != nil { + return nil, err + } + desc.CodesName = firstNonEmpty(tensor.CodesName, defaultCodesName(desc.Name)) + desc.CodebookName = firstNonEmpty(tensor.CodebookName, defaultTableName(desc.Name)) + if len(tensor.CodesShape) > 0 { + desc.CodesShape = append([]uint64(nil), tensor.CodesShape...) + } + if len(tensor.CodebookShape) > 0 { + desc.CodebookShape = append([]uint64(nil), tensor.CodebookShape...) + } + profile.Tensors = append(profile.Tensors, desc) + } + if err := ValidateProfile(profile); err != nil { + return nil, err + } + return &profile, nil +} + +// profile, _ := codebook.ReadProfile("/models/foo") +func ReadProfile(root string) (*Profile, error) { + read := core.ReadFile(core.PathJoin(root, "codebook_config.json")) + if !read.OK { + if core.IsNotExist(read.Value.(error)) { + return nil, nil + } + return nil, read.Value.(error) + } + return ParseProfile(read.Value.([]byte)) +} + +// desc, _ := codebook.NewTensorDescriptor("layer0.mlp.w", []uint64{4096, 4096}, profile) +func NewTensorDescriptor(name string, shape []uint64, profile Profile) (TensorDescriptor, error) { + if name == "" { + return TensorDescriptor{}, core.NewError("codebook: tensor name is required") + } + if profile.Format == "" { + profile.Format = FormatVQ + } + if profile.Format != FormatVQ { + return TensorDescriptor{}, core.NewError("codebook: unsupported format: " + profile.Format) + } + if len(shape) != 2 || shape[0] == 0 || shape[1] == 0 { + return TensorDescriptor{}, core.NewError("codebook: tensor shape must be [out, in]") + } + if profile.CodebookSize <= 0 { + return TensorDescriptor{}, core.NewError("codebook: codebook size must be positive") + } + if profile.CodeDim <= 0 { + return TensorDescriptor{}, core.NewError("codebook: code_dim must be positive") + } + if !validIndexBits(profile.IndexBits) { + return TensorDescriptor{}, core.NewError(core.Sprintf("codebook: unsupported index bits %d", profile.IndexBits)) + } + elements := shape[0] * shape[1] + if elements%uint64(profile.CodeDim) != 0 { + return TensorDescriptor{}, core.NewError(core.Sprintf("codebook: tensor elements %d must be divisible by code_dim %d", elements, profile.CodeDim)) + } + codeCount := int(elements / uint64(profile.CodeDim)) + return TensorDescriptor{ + Name: name, + Format: profile.Format, + Shape: append([]uint64(nil), shape...), + Elements: elements, + CodebookSize: profile.CodebookSize, + CodeDim: profile.CodeDim, + CodeCount: codeCount, + IndexBits: profile.IndexBits, + IndexBytes: (codeCount*profile.IndexBits + 7) / 8, + CodesName: defaultCodesName(name), + CodebookName: defaultTableName(name), + CodesShape: []uint64{uint64(codeCount)}, + CodebookShape: []uint64{uint64(profile.CodebookSize), uint64(profile.CodeDim)}, + }, nil +} + +// err := codebook.ValidateProfile(profile) +func ValidateProfile(profile Profile) error { + if profile.Type != "" && profile.Type != Type { + return core.NewError("codebook: unsupported type: " + profile.Type) + } + if profile.Format != "" && profile.Format != FormatVQ { + return core.NewError("codebook: unsupported format: " + profile.Format) + } + if profile.CodebookSize <= 0 { + return core.NewError("codebook: codebook size must be positive") + } + if profile.CodeDim <= 0 { + return core.NewError("codebook: code_dim must be positive") + } + if !validIndexBits(firstPositive(profile.IndexBits, 8)) { + return core.NewError(core.Sprintf("codebook: unsupported index bits %d", profile.IndexBits)) + } + for _, tensor := range profile.Tensors { + if err := ValidateTensorDescriptor(tensor); err != nil { + return err + } + } + return nil +} + +// err := codebook.ValidateTensorDescriptor(desc) +func ValidateTensorDescriptor(desc TensorDescriptor) error { + if desc.Name == "" { + return core.NewError("codebook: tensor name is required") + } + if desc.Format != FormatVQ { + return core.NewError("codebook: tensor format must be vq") + } + if len(desc.Shape) != 2 || desc.Shape[0] == 0 || desc.Shape[1] == 0 { + return core.NewError("codebook: tensor shape must be [out, in]") + } + if desc.CodebookSize <= 0 || desc.CodeDim <= 0 || desc.CodeCount <= 0 { + return core.NewError("codebook: tensor requires codebook_size, code_dim, and code_count") + } + if !validIndexBits(desc.IndexBits) { + return core.NewError(core.Sprintf("codebook: unsupported index bits %d", desc.IndexBits)) + } + if desc.Elements != desc.Shape[0]*desc.Shape[1] { + return core.NewError("codebook: tensor element count does not match shape") + } + if int(desc.Elements/uint64(desc.CodeDim)) != desc.CodeCount { + return core.NewError("codebook: tensor code count does not match code_dim") + } + return nil +} + +// out, _ := codebook.MatVec(desc, input, codes, table, bias) +func MatVec(desc TensorDescriptor, input []float32, codes []uint32, codebook []float32, bias []float32) ([]float32, error) { + if err := ValidateTensorPayload(desc, codes, codebook, bias); err != nil { + return nil, err + } + outDim := int(desc.Shape[0]) + inDim := int(desc.Shape[1]) + if len(input) == 0 || len(input)%inDim != 0 { + return nil, core.NewError(core.Sprintf("codebook: matvec input length %d is not divisible by input width %d", len(input), inDim)) + } + rows := len(input) / inDim + out := make([]float32, rows*outDim) + for row := 0; row < rows; row++ { + for outCol := 0; outCol < outDim; outCol++ { + sum := float32(0) + for inCol := 0; inCol < inDim; inCol++ { + weightIndex := outCol*inDim + inCol + codeIndex := weightIndex / desc.CodeDim + codeOffset := weightIndex % desc.CodeDim + codeID := codes[codeIndex] + weight := codebook[int(codeID)*desc.CodeDim+codeOffset] + sum += input[row*inDim+inCol] * weight + } + if len(bias) > 0 { + sum += bias[outCol] + } + out[row*outDim+outCol] = sum + } + } + return out, nil +} + +// err := codebook.ValidateTensorPayload(desc, codes, table, bias) +func ValidateTensorPayload(desc TensorDescriptor, codes []uint32, codebook []float32, bias []float32) error { + if err := ValidateTensorDescriptor(desc); err != nil { + return err + } + if len(codes) != desc.CodeCount { + return core.NewError(core.Sprintf("codebook: code count %d, expected %d", len(codes), desc.CodeCount)) + } + if len(codebook) != desc.CodebookSize*desc.CodeDim { + return core.NewError(core.Sprintf("codebook: value count %d, expected %d", len(codebook), desc.CodebookSize*desc.CodeDim)) + } + for i, codeID := range codes { + if codeID >= uint32(desc.CodebookSize) { + return core.NewError(core.Sprintf("codebook: code id %d at index %d exceeds codebook size %d", codeID, i, desc.CodebookSize)) + } + } + if len(bias) > 0 && len(bias) != int(desc.Shape[0]) { + return core.NewError(core.Sprintf("codebook: bias length %d, expected %d", len(bias), desc.Shape[0])) + } + return nil +} + +// clone := codebook.CloneProfile(profile) +func CloneProfile(profile *Profile) *Profile { + if profile == nil { + return nil + } + cloned := *profile + cloned.Tensors = append([]TensorDescriptor(nil), profile.Tensors...) + for i := range cloned.Tensors { + cloned.Tensors[i].Shape = append([]uint64(nil), profile.Tensors[i].Shape...) + cloned.Tensors[i].CodesShape = append([]uint64(nil), profile.Tensors[i].CodesShape...) + cloned.Tensors[i].CodebookShape = append([]uint64(nil), profile.Tensors[i].CodebookShape...) + } + return &cloned +} + +func validIndexBits(bits int) bool { + switch bits { + case 8, 16, 32: + return true + default: + return false + } +} + +func defaultCodesName(name string) string { + return name + ".codes" +} + +func defaultTableName(name string) string { + return name + ".codebook" +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if value != "" { + return value + } + } + return "" +} + +func firstPositive(values ...int) int { + for _, value := range values { + if value > 0 { + return value + } + } + return 0 +} diff --git a/go/quant/codebook/codebook_bench_test.go b/go/quant/codebook/codebook_bench_test.go new file mode 100644 index 0000000..55d69ef --- /dev/null +++ b/go/quant/codebook/codebook_bench_test.go @@ -0,0 +1,348 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the driver-neutral VQ-codebook quant primitives. +// Per AX-11 — ParseProfile + NewTensorDescriptor fire once per +// tensor at model load (hundreds of tensors per Gemma/Qwen-class +// model). ValidateTensorPayload runs per kernel dispatch on the +// CPU parity path. CloneProfile fires per profile lifted across +// runtime boundaries. The reference MatVec is the CPU parity +// path used by parity tests against the native Metal kernel. +// +// Run: go test -bench='Benchmark' -benchmem -run='^$' ./quant/codebook + +package codebook + +import ( + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. +var ( + codebookSinkProfile *Profile + codebookSinkDescriptor TensorDescriptor + codebookSinkMatVec []float32 + codebookSinkErr error + codebookSinkProfileVal Profile + codebookSinkClonedProf *Profile +) + +// benchProfile builds a Profile with the requested codebook size and +// a single tensor of the requested shape. Used as a shared fixture +// across the bench surfaces. +func benchProfile(codebookSize, codeDim, indexBits int, outDim, inDim uint64) Profile { + desc, _ := NewTensorDescriptor("model.layers.0.mlp.down_proj.weight", []uint64{outDim, inDim}, Profile{ + Format: FormatVQ, + CodebookSize: codebookSize, + CodeDim: codeDim, + IndexBits: indexBits, + }) + return Profile{ + Type: Type, + Format: FormatVQ, + CodebookSize: codebookSize, + CodeDim: codeDim, + IndexBits: indexBits, + Tensors: []TensorDescriptor{desc}, + } +} + +// benchMatVecInputs builds the codes + codebook + bias slices a +// MatVec parity check needs for a given descriptor. +func benchMatVecInputs(desc TensorDescriptor) ([]float32, []uint32, []float32, []float32) { + input := make([]float32, int(desc.Shape[1])) + for i := range input { + input[i] = float32(i%7) * 0.125 + } + codes := make([]uint32, desc.CodeCount) + for i := range codes { + codes[i] = uint32(i % desc.CodebookSize) + } + table := make([]float32, desc.CodebookSize*desc.CodeDim) + for i := range table { + table[i] = float32(i%11) * 0.25 + } + bias := make([]float32, int(desc.Shape[0])) + for i := range bias { + bias[i] = float32(i%3) * 0.5 + } + return input, codes, table, bias +} + +// --- NewTensorDescriptor (per-tensor at model load) --- + +func BenchmarkCodebook_NewTensorDescriptor_Small(b *testing.B) { + profile := Profile{ + Format: FormatVQ, + CodebookSize: 256, + CodeDim: 4, + IndexBits: 8, + } + shape := []uint64{1024, 1024} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkDescriptor, codebookSinkErr = NewTensorDescriptor("model.layers.0.mlp.down_proj.weight", shape, profile) + } +} + +func BenchmarkCodebook_NewTensorDescriptor_Large(b *testing.B) { + profile := Profile{ + Format: FormatVQ, + CodebookSize: 4096, + CodeDim: 8, + IndexBits: 16, + } + shape := []uint64{4096, 4096} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkDescriptor, codebookSinkErr = NewTensorDescriptor("model.layers.0.mlp.down_proj.weight", shape, profile) + } +} + +// --- ParseProfile (per-model load) --- + +func BenchmarkCodebook_ParseProfile_Small(b *testing.B) { + data := []byte(`{ + "type": "codebook", + "format": "vq", + "codebook_size": 256, + "code_dim": 4, + "index_bits": 8, + "tensors": [ + { + "name": "model.layers.0.mlp.down_proj.weight", + "shape": [1024, 1024] + } + ] + }`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkProfile, codebookSinkErr = ParseProfile(data) + } +} + +func BenchmarkCodebook_ParseProfile_Large(b *testing.B) { + data := []byte(`{ + "type": "codebook", + "format": "vq", + "codebook_size": 4096, + "code_dim": 8, + "index_bits": 16, + "tensors": [ + { + "name": "model.layers.0.mlp.down_proj.weight", + "shape": [4096, 4096] + }, + { + "name": "model.layers.0.mlp.gate_proj.weight", + "shape": [4096, 4096] + }, + { + "name": "model.layers.0.mlp.up_proj.weight", + "shape": [4096, 4096] + }, + { + "name": "model.layers.0.self_attn.q_proj.weight", + "shape": [4096, 4096] + }, + { + "name": "model.layers.0.self_attn.k_proj.weight", + "shape": [4096, 4096] + }, + { + "name": "model.layers.0.self_attn.v_proj.weight", + "shape": [4096, 4096] + }, + { + "name": "model.layers.0.self_attn.o_proj.weight", + "shape": [4096, 4096] + } + ] + }`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkProfile, codebookSinkErr = ParseProfile(data) + } +} + +// --- ValidateProfile (per-profile across runtime boundaries) --- + +func BenchmarkCodebook_ValidateProfile_Small(b *testing.B) { + profile := benchProfile(256, 4, 8, 1024, 1024) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkErr = ValidateProfile(profile) + } +} + +func BenchmarkCodebook_ValidateProfile_Large(b *testing.B) { + profile := benchProfile(4096, 8, 16, 4096, 4096) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkErr = ValidateProfile(profile) + } +} + +// --- ValidateTensorDescriptor (per-tensor across runtime boundaries) --- + +func BenchmarkCodebook_ValidateTensorDescriptor_Small(b *testing.B) { + desc, err := NewTensorDescriptor("model.layers.0.mlp.down_proj.weight", []uint64{1024, 1024}, Profile{ + Format: FormatVQ, + CodebookSize: 256, + CodeDim: 4, + IndexBits: 8, + }) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkErr = ValidateTensorDescriptor(desc) + } +} + +func BenchmarkCodebook_ValidateTensorDescriptor_Large(b *testing.B) { + desc, err := NewTensorDescriptor("model.layers.0.mlp.down_proj.weight", []uint64{4096, 4096}, Profile{ + Format: FormatVQ, + CodebookSize: 4096, + CodeDim: 8, + IndexBits: 16, + }) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkErr = ValidateTensorDescriptor(desc) + } +} + +// --- ValidateTensorPayload (per kernel dispatch) --- + +func BenchmarkCodebook_ValidateTensorPayload_Small(b *testing.B) { + desc, err := NewTensorDescriptor("model.layers.0.mlp.down_proj.weight", []uint64{64, 64}, Profile{ + Format: FormatVQ, + CodebookSize: 256, + CodeDim: 4, + IndexBits: 8, + }) + if err != nil { + b.Fatal(err) + } + _, codes, table, bias := benchMatVecInputs(desc) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkErr = ValidateTensorPayload(desc, codes, table, bias) + } +} + +func BenchmarkCodebook_ValidateTensorPayload_Large(b *testing.B) { + desc, err := NewTensorDescriptor("model.layers.0.mlp.down_proj.weight", []uint64{256, 256}, Profile{ + Format: FormatVQ, + CodebookSize: 4096, + CodeDim: 8, + IndexBits: 16, + }) + if err != nil { + b.Fatal(err) + } + _, codes, table, bias := benchMatVecInputs(desc) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkErr = ValidateTensorPayload(desc, codes, table, bias) + } +} + +// --- CloneProfile (per runtime hand-off) --- + +func BenchmarkCodebook_CloneProfile_Small(b *testing.B) { + profile := benchProfile(256, 4, 8, 1024, 1024) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkClonedProf = CloneProfile(&profile) + } +} + +func BenchmarkCodebook_CloneProfile_Large(b *testing.B) { + profile := benchProfile(4096, 8, 16, 4096, 4096) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkClonedProf = CloneProfile(&profile) + } +} + +// --- MatVec (reference CPU parity path) --- +// Sizes intentionally small — the CPU loop is O(out*in) and is the +// parity-test path, not the production hot loop. Keeping the inputs +// modest keeps the bench under 100ms per case while still exercising +// the per-row + per-col dispatch + table lookup. + +func BenchmarkCodebook_MatVec_64x64_CB256(b *testing.B) { + desc, err := NewTensorDescriptor("ok.weight", []uint64{64, 64}, Profile{ + Format: FormatVQ, + CodebookSize: 256, + CodeDim: 4, + IndexBits: 8, + }) + if err != nil { + b.Fatal(err) + } + input, codes, table, bias := benchMatVecInputs(desc) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkMatVec, codebookSinkErr = MatVec(desc, input, codes, table, bias) + } +} + +func BenchmarkCodebook_MatVec_128x128_CB4096(b *testing.B) { + desc, err := NewTensorDescriptor("ok.weight", []uint64{128, 128}, Profile{ + Format: FormatVQ, + CodebookSize: 4096, + CodeDim: 8, + IndexBits: 16, + }) + if err != nil { + b.Fatal(err) + } + input, codes, table, bias := benchMatVecInputs(desc) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkMatVec, codebookSinkErr = MatVec(desc, input, codes, table, bias) + } +} + +// --- core.Contains diagnostic-string path (validation error formatting) --- +// Reject paths still cost real wall time when the producer hits a +// guarded shape; bench the error-format hot loop on the unaligned +// branch the test file already covers. + +func BenchmarkCodebook_NewTensorDescriptor_RejectUnaligned(b *testing.B) { + profile := Profile{ + Format: FormatVQ, + CodebookSize: 16, + CodeDim: 4, + IndexBits: 8, + } + shape := []uint64{3, 3} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + codebookSinkDescriptor, codebookSinkErr = NewTensorDescriptor("bad.weight", shape, profile) + } + _ = core.Contains // keep the import resolved when reject paths don't fire +} diff --git a/go/quant/codebook/codebook_test.go b/go/quant/codebook/codebook_test.go new file mode 100644 index 0000000..48ed7be --- /dev/null +++ b/go/quant/codebook/codebook_test.go @@ -0,0 +1,111 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package codebook + +import ( + "testing" + + core "dappco.re/go" +) + +func TestCodebook_DescriptorValidatesAndMatVec_Good(t *testing.T) { + profile := Profile{ + Format: FormatVQ, + CodebookSize: 3, + CodeDim: 2, + IndexBits: 16, + } + + desc, err := NewTensorDescriptor("model.layers.0.mlp.down_proj.weight", []uint64{2, 4}, profile) + if err != nil { + t.Fatalf("NewTensorDescriptor() error = %v", err) + } + if desc.Elements != 8 || desc.CodeCount != 4 || desc.CodebookSize != 3 || desc.CodeDim != 2 { + t.Fatalf("descriptor = %+v, want 8 elements, 4 codes, 3-entry codebook with 2D vectors", desc) + } + if desc.IndexBytes != 8 { + t.Fatalf("IndexBytes = %d, want four 16-bit indices", desc.IndexBytes) + } + + got, err := MatVec(desc, []float32{3, 4, 5, 6}, []uint32{0, 1, 2, 1}, []float32{ + 1, 0, + 0, 1, + 2, -1, + }, []float32{0.5, -1}) + if err != nil { + t.Fatalf("MatVec() error = %v", err) + } + assertCloseSlice(t, got, []float32{9.5, 7}, 1e-5) +} + +func TestCodebook_DescriptorRejectsUnalignedShape_Bad(t *testing.T) { + _, err := NewTensorDescriptor("bad.weight", []uint64{3, 3}, Profile{ + Format: FormatVQ, + CodebookSize: 16, + CodeDim: 4, + IndexBits: 8, + }) + if err == nil || !core.Contains(err.Error(), "divisible") { + t.Fatalf("error = %v, want code-dim divisibility diagnostic", err) + } +} + +func TestCodebook_MatVecRejectsOutOfRangeCode_Bad(t *testing.T) { + desc, err := NewTensorDescriptor("ok.weight", []uint64{1, 2}, Profile{ + Format: FormatVQ, + CodebookSize: 2, + CodeDim: 1, + IndexBits: 8, + }) + if err != nil { + t.Fatalf("NewTensorDescriptor() error = %v", err) + } + + _, err = MatVec(desc, []float32{1, 2}, []uint32{0, 4}, []float32{1, 2}, nil) + if err == nil || !core.Contains(err.Error(), "code id") { + t.Fatalf("error = %v, want out-of-range code diagnostic", err) + } +} + +func TestCodebook_ParseProfile_Good(t *testing.T) { + profile, err := ParseProfile([]byte(`{ + "type": "codebook", + "format": "vq", + "codebook_size": 4, + "code_dim": 2, + "index_bits": 8, + "tensors": [ + { + "name": "model.layers.0.mlp.down_proj.weight", + "shape": [2, 4], + "codes": "model.layers.0.mlp.down_proj.weight.codes", + "codebook": "model.layers.0.mlp.down_proj.weight.codebook" + } + ] + }`)) + if err != nil { + t.Fatalf("ParseProfile() error = %v", err) + } + if profile.Type != Type || profile.Format != FormatVQ || len(profile.Tensors) != 1 { + t.Fatalf("profile = %+v, want one VQ tensor", profile) + } + if tensor := profile.Tensors[0]; tensor.CodeCount != 4 || tensor.CodesName == "" || tensor.CodebookName == "" { + t.Fatalf("tensor = %+v, want resolved sidecar names and code count", tensor) + } +} + +func assertCloseSlice(t *testing.T, got, want []float32, epsilon float64) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("len(got) = %d, want %d", len(got), len(want)) + } + for i := range got { + diff := got[i] - want[i] + if diff < 0 { + diff = -diff + } + if float64(diff) > epsilon { + t.Fatalf("value[%d] = %f, want %f", i, got[i], want[i]) + } + } +} diff --git a/go/quant/jang/jang.go b/go/quant/jang/jang.go new file mode 100644 index 0000000..d9c0cfc --- /dev/null +++ b/go/quant/jang/jang.go @@ -0,0 +1,813 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package jang holds the driver-neutral JANG/JANGTQ quantisation metadata +// + portable packed-tensor descriptor + reference dequant for parity tests. +// +// info, _ := jang.ReadConfig("/models/minimax-m2-jangtq") +// desc, _ := jang.NewPackedTensorDescriptor("model.layers.0.self_attn.q_proj.weight", shape, info) +package jang + +import ( + core "dappco.re/go" +) + +// info := jang.Info{Profile: "JANGTQ", GroupSize: 64} +type Info struct { + Version int `json:"version,omitempty"` + WeightFormat string `json:"weight_format,omitempty"` + Profile string `json:"profile,omitempty"` + Method string `json:"method,omitempty"` + GroupSize int `json:"group_size,omitempty"` + BitsDefault int `json:"bits_default,omitempty"` + AttentionBits int `json:"attention_bits,omitempty"` + SharedExpertBits int `json:"shared_expert_bits,omitempty"` + RoutedExpertBits int `json:"routed_expert_bits,omitempty"` + EmbedTokensBits int `json:"embed_tokens_bits,omitempty"` + LMHeadBits int `json:"lm_head_bits,omitempty"` + SourceName string `json:"source_name,omitempty"` + SourceOrg string `json:"source_org,omitempty"` + SourceArchitecture string `json:"source_architecture,omitempty"` + Capabilities Capabilities `json:"capabilities,omitempty"` + Packed *PackedProfile `json:"packed,omitempty"` +} + +// caps := jang.Capabilities{ReasoningParser: "qwen-think", SupportsTools: true} +type Capabilities struct { + ReasoningParser string `json:"reasoning_parser,omitempty"` + ToolParser string `json:"tool_parser,omitempty"` + ThinkInTemplate bool `json:"think_in_template,omitempty"` + SupportsTools bool `json:"supports_tools,omitempty"` + SupportsThinking bool `json:"supports_thinking,omitempty"` + Family string `json:"family,omitempty"` + Modality string `json:"modality,omitempty"` + CacheType string `json:"cache_type,omitempty"` +} + +// role := jang.TensorRoleAttention +type TensorRole string + +const ( + TensorRoleDefault TensorRole = "default" + TensorRoleAttention TensorRole = "attention" + TensorRoleSharedExpert TensorRole = "shared_expert" + TensorRoleRoutedExpert TensorRole = "routed_expert" + TensorRoleEmbedTokens TensorRole = "embed_tokens" + TensorRoleLMHead TensorRole = "lm_head" +) + +const ( + BitOrderLSB0 = "lsb0" + EncodingAffine = "affine" +) + +// profile := jang.BuildPackedProfile(&info) +type PackedProfile struct { + Type string `json:"type,omitempty"` + Format string `json:"format,omitempty"` + Profile string `json:"profile,omitempty"` + Method string `json:"method,omitempty"` + GroupSize int `json:"group_size,omitempty"` + BitsDefault int `json:"bits_default,omitempty"` + RoleBits map[string]int `json:"role_bits,omitempty"` + MinBits int `json:"min_bits,omitempty"` + MaxBits int `json:"max_bits,omitempty"` + Mixed bool `json:"mixed,omitempty"` + BitOrder string `json:"bit_order,omitempty"` + Encoding string `json:"encoding,omitempty"` + ValuesPerByte int `json:"values_per_byte,omitempty"` +} + +// desc, _ := jang.NewPackedTensorDescriptor(name, shape, &info) +type PackedTensorDescriptor struct { + Name string `json:"name,omitempty"` + Type string `json:"type,omitempty"` + Format string `json:"format,omitempty"` + Profile string `json:"profile,omitempty"` + Role TensorRole `json:"role,omitempty"` + Shape []uint64 `json:"shape,omitempty"` + Elements uint64 `json:"elements,omitempty"` + Bits int `json:"bits,omitempty"` + GroupSize int `json:"group_size,omitempty"` + Groups int `json:"groups,omitempty"` + PackedBytes int `json:"packed_bytes,omitempty"` + ValuesPerByte int `json:"values_per_byte,omitempty"` + ScaleCount int `json:"scale_count,omitempty"` + BiasCount int `json:"bias_count,omitempty"` + BitOrder string `json:"bit_order,omitempty"` + Encoding string `json:"encoding,omitempty"` +} + +type configProbe struct { + Version int `json:"version"` + WeightFormat string `json:"weight_format"` + Profile string `json:"profile"` + SourceModel struct { + Name string `json:"name"` + Org string `json:"org"` + Architecture string `json:"architecture"` + } `json:"source_model"` + MXTQBits struct { + Attention int `json:"attention"` + SharedExpert int `json:"shared_expert"` + RoutedExpert int `json:"routed_expert"` + EmbedTokens int `json:"embed_tokens"` + LMHead int `json:"lm_head"` + } `json:"mxtq_bits"` + Quantization struct { + Method string `json:"method"` + GroupSize int `json:"group_size"` + BitsDefault int `json:"bits_default"` + } `json:"quantization"` + Capabilities Capabilities `json:"capabilities"` +} + +// info, _ := jang.ReadConfig("/models/minimax-m2") +func ReadConfig(root string) (*Info, error) { + read := core.ReadFile(core.PathJoin(root, "jang_config.json")) + if !read.OK { + if core.IsNotExist(read.Value.(error)) { + return nil, nil + } + return nil, read.Value.(error) + } + return ParseConfig(read.Value.([]byte)) +} + +// info, _ := jang.ParseConfig(data) +func ParseConfig(data []byte) (*Info, error) { + var probe configProbe + if result := core.JSONUnmarshal(data, &probe); !result.OK { + return nil, result.Value.(error) + } + return finalize(&Info{ + Version: probe.Version, + WeightFormat: probe.WeightFormat, + Profile: probe.Profile, + Method: probe.Quantization.Method, + GroupSize: probe.Quantization.GroupSize, + BitsDefault: firstPositive(probe.Quantization.BitsDefault, probe.MXTQBits.RoutedExpert, ProfileBits(probe.Profile)), + AttentionBits: probe.MXTQBits.Attention, + SharedExpertBits: probe.MXTQBits.SharedExpert, + RoutedExpertBits: probe.MXTQBits.RoutedExpert, + EmbedTokensBits: probe.MXTQBits.EmbedTokens, + LMHeadBits: probe.MXTQBits.LMHead, + SourceName: probe.SourceModel.Name, + SourceOrg: probe.SourceModel.Org, + SourceArchitecture: normaliseArchitecture(probe.SourceModel.Architecture), + Capabilities: probe.Capabilities, + }), nil +} + +// bits := jang.ProfileBits("JANG_4M") // returns 4 +func ProfileBits(profile string) int { + profile = core.Lower(profile) + switch { + case core.Contains(profile, "jangtq"): + return 2 + case core.Contains(profile, "jang_1"): + return 1 + case core.Contains(profile, "jang_2"): + return 2 + case core.Contains(profile, "jang_3"): + return 3 + case core.Contains(profile, "jang_4"): + return 4 + default: + return 0 + } +} + +func quantizationType(info *Info) string { + if info == nil { + return "" + } + lower := core.Lower(core.Concat(info.Profile, " ", info.WeightFormat, " ", info.Method)) + if core.Contains(lower, "jangtq") || core.Contains(lower, "mxtq") { + return "jangtq" + } + return "jang" +} + +func finalize(info *Info) *Info { + if info == nil { + return nil + } + info.Packed = BuildPackedProfile(info) + return info +} + +// profile := jang.BuildPackedProfile(&info) +func BuildPackedProfile(info *Info) *PackedProfile { + if info == nil { + return nil + } + rb := roleBits(info) + minBits, maxBits := minMaxBits(rb) + profile := &PackedProfile{ + Type: quantizationType(info), + Format: packedFormat(info), + Profile: info.Profile, + Method: info.Method, + GroupSize: info.GroupSize, + BitsDefault: info.BitsDefault, + RoleBits: rb, + MinBits: minBits, + MaxBits: maxBits, + Mixed: minBits > 0 && maxBits > minBits, + BitOrder: BitOrderLSB0, + Encoding: EncodingAffine, + ValuesPerByte: valuesPerByte(info.BitsDefault), + } + if profile.Format == "" { + profile.Format = profile.Type + } + return profile +} + +// clone := jang.ClonePackedProfile(profile) +func ClonePackedProfile(profile *PackedProfile) *PackedProfile { + if profile == nil { + return nil + } + cloned := *profile + cloned.RoleBits = cloneRoleBits(profile.RoleBits) + return &cloned +} + +// desc, _ := jang.NewPackedTensorDescriptor("model.layers.0.q_proj.weight", []uint64{4096, 4096}, &info) +func NewPackedTensorDescriptor(name string, shape []uint64, info *Info) (PackedTensorDescriptor, error) { + if info == nil { + return PackedTensorDescriptor{}, core.NewError("jang: packed tensor descriptor requires quantization info") + } + role := inferTensorRole(name) + bits := bitsForRole(info, role) + elements, err := shapeElements(shape) + if err != nil { + return PackedTensorDescriptor{}, err + } + if err := validateBits(bits, name); err != nil { + return PackedTensorDescriptor{}, err + } + if info.GroupSize <= 0 { + return PackedTensorDescriptor{}, core.NewError(core.Sprintf("jang: packed tensor %q has invalid group size %d", name, info.GroupSize)) + } + if elements > ^uint64(0)/uint64(bits) { + return PackedTensorDescriptor{}, core.NewError(core.Sprintf("jang: packed tensor %q packed bit count overflows", name)) + } + packedBits := elements * uint64(bits) + packedBytes := ceilDivUint64(packedBits, 8) + if packedBytes > uint64(maxIntValue()) { + return PackedTensorDescriptor{}, core.NewError(core.Sprintf("jang: packed tensor %q is too large", name)) + } + groups := ceilDivUint64(elements, uint64(info.GroupSize)) + if groups > uint64(maxIntValue()) { + return PackedTensorDescriptor{}, core.NewError(core.Sprintf("jang: packed tensor %q has too many groups", name)) + } + return PackedTensorDescriptor{ + Name: name, + Type: quantizationType(info), + Format: packedFormat(info), + Profile: info.Profile, + Role: role, + Shape: append([]uint64(nil), shape...), + Elements: elements, + Bits: bits, + GroupSize: info.GroupSize, + Groups: int(groups), + PackedBytes: int(packedBytes), + ValuesPerByte: valuesPerByte(bits), + ScaleCount: int(groups), + BiasCount: int(groups), + BitOrder: BitOrderLSB0, + Encoding: EncodingAffine, + }, nil +} + +// err := jang.ValidatePackedTensor(desc, packed, scales, biases) +func ValidatePackedTensor(desc PackedTensorDescriptor, packed []byte, scales, biases []float32) error { + if err := validateDescriptor(desc); err != nil { + return err + } + if len(packed) != desc.PackedBytes { + return core.NewError(core.Sprintf("jang: packed tensor %q packed length %d, expected %d", desc.Name, len(packed), desc.PackedBytes)) + } + if len(scales) != desc.ScaleCount { + return core.NewError(core.Sprintf("jang: packed tensor %q scale count %d, expected %d", desc.Name, len(scales), desc.ScaleCount)) + } + if len(biases) != desc.BiasCount { + return core.NewError(core.Sprintf("jang: packed tensor %q bias count %d, expected %d", desc.Name, len(biases), desc.BiasCount)) + } + return nil +} + +// values, _ := jang.DequantizePackedTensor(desc, packed, scales, biases) +func DequantizePackedTensor(desc PackedTensorDescriptor, packed []byte, scales, biases []float32) ([]float32, error) { + if err := ValidatePackedTensor(desc, packed, scales, biases); err != nil { + return nil, err + } + if desc.Elements > uint64(maxIntValue()) { + return nil, core.NewError(core.Sprintf("jang: packed tensor %q is too large to dequantize on CPU", desc.Name)) + } + out := make([]float32, int(desc.Elements)) + groupSize := desc.GroupSize + // Dispatch by bit-width once outside the loop so the inner unpack + // becomes a single shift+mask the Go compiler can keep in registers, + // rather than paying the un-inlinable unpackValue call on every + // element. The dispatch also lets us hoist scale/bias per group — + // the original loop re-indexed scales[i/groupSize] + biases[i/groupSize] + // on every element, which is groupSize-1 redundant indexed reads + a + // division per group (with groupSize=64, that's a 64× reduction in + // per-element scale/bias work). + switch desc.Bits { + case 8: + dequantizeBit8(out, packed, scales, biases, groupSize) + case 4: + dequantizeBit4(out, packed, scales, biases, groupSize) + case 2: + dequantizeBit2(out, packed, scales, biases, groupSize) + case 1: + dequantizeBit1(out, packed, scales, biases, groupSize) + default: + // Generic walk for non-power-of-2 widths (3-bit and any future + // awkward width). Inline the bit-walk so we sidestep the + // fast-path switch in unpackValue — the outer dispatch already + // proved we won't hit a byte-aligned width here. Outer loop + // still hoists scale/bias per group. + dequantizeBitGeneric(out, packed, scales, biases, groupSize, desc.Bits) + } + return out, nil +} + +// dequantizeBit8 walks the 8-bit-aligned packed path with the unpack +// inlined. One byte per element, no shift required. +func dequantizeBit8(out []float32, packed []byte, scales, biases []float32, groupSize int) { + for i := 0; i < len(out); { + group := i / groupSize + end := (group + 1) * groupSize + if end > len(out) { + end = len(out) + } + scale := scales[group] + bias := biases[group] + for ; i < end; i++ { + out[i] = float32(packed[i])*scale + bias + } + } +} + +// dequantizeBit4 walks the 4-bit-nibble-packed path with the unpack +// inlined. Two values per byte; low nibble for even indices, high +// nibble for odd indices. +// +// When the per-group walk lands on a byte boundary we batch 2 elements +// per byte read — amortises the packed-slice load + bounds check across +// both nibble lanes. JANGTQ-style groupSize=64 (== 32 bytes at 4-bit) +// lands on a byte boundary at every group start, so the fast path +// covers the full group body. Single-element prefix + suffix handle +// the rare case where the row's start offset is mid-byte or the group +// runs short at the tensor tail. +// +// The natural if/else for nibble select (rather than a branchless +// bit-mux) avoids the Apple Silicon FCMPD-over-FMOV penalty observed +// when bit-mux-style code regresses against direct branches on M3. +func dequantizeBit4(out []float32, packed []byte, scales, biases []float32, groupSize int) { + for i := 0; i < len(out); { + group := i / groupSize + end := (group + 1) * groupSize + if end > len(out) { + end = len(out) + } + scale := scales[group] + bias := biases[group] + // Drain prefix elements until i is byte-aligned (i&1 == 0). + if i&1 != 0 && i < end { + b := packed[i>>1] + out[i] = float32(b>>4)*scale + bias + i++ + } + // Walk 2-at-a-time on byte-aligned boundaries. + for i+2 <= end { + b := packed[i>>1] + out[i] = float32(b&0x0F)*scale + bias + out[i+1] = float32(b>>4)*scale + bias + i += 2 + } + // Drain suffix. + for ; i < end; i++ { + b := packed[i>>1] + if i&1 == 0 { + out[i] = float32(b&0x0F)*scale + bias + } else { + out[i] = float32(b>>4)*scale + bias + } + } + } +} + +// dequantizeBit2 walks the 2-bit-packed path with the unpack inlined. +// Four values per byte; the shift is `(i&3)<<1`. This is the dominant +// MiniMax M2 routed-expert weight path. +// +// When the per-group walk lands on a byte boundary we batch 4 elements +// per byte read — amortises the packed-slice load across the four 2-bit +// lanes. The JANGTQ default groupSize=64 (16 bytes at 2-bit) lands on a +// byte boundary at every group start, so the fast path covers the full +// group body. Single-element prefix + suffix handles the (rare) case +// where the group runs short at the tensor tail. +func dequantizeBit2(out []float32, packed []byte, scales, biases []float32, groupSize int) { + for i := 0; i < len(out); { + group := i / groupSize + end := (group + 1) * groupSize + if end > len(out) { + end = len(out) + } + scale := scales[group] + bias := biases[group] + // Drain prefix elements until i is byte-aligned (i&3 == 0). + for ; i < end && (i&3) != 0; i++ { + q := (packed[i>>2] >> uint((i&3)<<1)) & 0x03 + out[i] = float32(q)*scale + bias + } + // Walk 4-at-a-time on byte-aligned boundaries. + for i+4 <= end { + b := packed[i>>2] + out[i] = float32(b&0x03)*scale + bias + out[i+1] = float32((b>>2)&0x03)*scale + bias + out[i+2] = float32((b>>4)&0x03)*scale + bias + out[i+3] = float32((b>>6)&0x03)*scale + bias + i += 4 + } + // Drain suffix. + for ; i < end; i++ { + q := (packed[i>>2] >> uint((i&3)<<1)) & 0x03 + out[i] = float32(q)*scale + bias + } + } +} + +// dequantizeBit1 walks the 1-bit-packed path with the unpack inlined. +// Eight values per byte; mask + shift only. +// +// When the per-group walk lands on a byte boundary we batch 8 elements +// per byte read — amortises the packed-slice load + bounds check across +// all eight 1-bit lanes. JANGTQ-style groupSize=64 (== 8 bytes at +// 1-bit) lands on a byte boundary at every group start. Single-element +// prefix + suffix handle mid-byte starts and short-tail groups. +func dequantizeBit1(out []float32, packed []byte, scales, biases []float32, groupSize int) { + for i := 0; i < len(out); { + group := i / groupSize + end := (group + 1) * groupSize + if end > len(out) { + end = len(out) + } + scale := scales[group] + bias := biases[group] + // Drain prefix elements until i is byte-aligned (i&7 == 0). + for ; i < end && (i&7) != 0; i++ { + q := (packed[i>>3] >> uint(i&7)) & 0x01 + out[i] = float32(q)*scale + bias + } + // Walk 8-at-a-time on byte-aligned boundaries. + for i+8 <= end { + b := packed[i>>3] + out[i] = float32(b&0x01)*scale + bias + out[i+1] = float32((b>>1)&0x01)*scale + bias + out[i+2] = float32((b>>2)&0x01)*scale + bias + out[i+3] = float32((b>>3)&0x01)*scale + bias + out[i+4] = float32((b>>4)&0x01)*scale + bias + out[i+5] = float32((b>>5)&0x01)*scale + bias + out[i+6] = float32((b>>6)&0x01)*scale + bias + out[i+7] = float32((b>>7)&0x01)*scale + bias + i += 8 + } + // Drain suffix. + for ; i < end; i++ { + q := (packed[i>>3] >> uint(i&7)) & 0x01 + out[i] = float32(q)*scale + bias + } + } +} + +// dequantizeBitGeneric walks any non-power-of-2 packed width (e.g. 3-bit) +// with the bit-walk inlined directly. The outer DequantizePackedTensor +// dispatch already proved we won't hit a byte-aligned width here, so we +// skip the fast-path switch in unpackValue that would otherwise pay 4 +// extra comparisons per element. +func dequantizeBitGeneric(out []float32, packed []byte, scales, biases []float32, groupSize, bits int) { + for i := 0; i < len(out); { + group := i / groupSize + end := (group + 1) * groupSize + if end > len(out) { + end = len(out) + } + scale := scales[group] + bias := biases[group] + for ; i < end; i++ { + bitOffset := i * bits + remaining := bits + shiftOut := 0 + value := uint16(0) + for remaining > 0 { + byteIndex := bitOffset / 8 + shiftIn := bitOffset % 8 + take := remaining + if avail := 8 - shiftIn; avail < take { + take = avail + } + mask := uint16((1 << take) - 1) + chunk := (uint16(packed[byteIndex]) >> shiftIn) & mask + value |= chunk << shiftOut + remaining -= take + bitOffset += take + shiftOut += take + } + out[i] = float32(uint8(value))*scale + bias + } + } +} + +// packed, _ := jang.PackQuantizedValues(desc, values) +func PackQuantizedValues(desc PackedTensorDescriptor, values []uint8) ([]byte, error) { + if err := validateDescriptor(desc); err != nil { + return nil, err + } + if uint64(len(values)) != desc.Elements { + return nil, core.NewError(core.Sprintf("jang: packed tensor %q value count %d, expected %d", desc.Name, len(values), desc.Elements)) + } + out := make([]byte, desc.PackedBytes) + maxValue := uint8((1 << desc.Bits) - 1) + for i, value := range values { + if value > maxValue { + return nil, core.NewError(core.Sprintf("jang: packed tensor %q value %d exceeds %d-bit max %d", desc.Name, value, desc.Bits, maxValue)) + } + writeValue(out, i, desc.Bits, value) + } + return out, nil +} + +func inferTensorRole(name string) TensorRole { + lower := core.Lower(name) + switch { + case core.Contains(lower, "embed_tokens"): + return TensorRoleEmbedTokens + case core.Contains(lower, "lm_head"): + return TensorRoleLMHead + case core.Contains(lower, "shared_expert"): + return TensorRoleSharedExpert + case core.Contains(lower, "experts.") || core.Contains(lower, "block_sparse_moe"): + return TensorRoleRoutedExpert + case core.Contains(lower, "self_attn") || core.Contains(lower, ".attention.") || core.Contains(lower, ".q_proj") || core.Contains(lower, ".k_proj") || core.Contains(lower, ".v_proj") || core.Contains(lower, ".o_proj"): + return TensorRoleAttention + default: + return TensorRoleDefault + } +} + +func bitsForRole(info *Info, role TensorRole) int { + switch role { + case TensorRoleAttention: + return firstPositive(info.AttentionBits, info.BitsDefault, ProfileBits(info.Profile)) + case TensorRoleSharedExpert: + return firstPositive(info.SharedExpertBits, info.BitsDefault, ProfileBits(info.Profile)) + case TensorRoleRoutedExpert: + return firstPositive(info.RoutedExpertBits, info.BitsDefault, ProfileBits(info.Profile)) + case TensorRoleEmbedTokens: + return firstPositive(info.EmbedTokensBits, info.BitsDefault, ProfileBits(info.Profile)) + case TensorRoleLMHead: + return firstPositive(info.LMHeadBits, info.BitsDefault, ProfileBits(info.Profile)) + default: + return firstPositive(info.BitsDefault, ProfileBits(info.Profile)) + } +} + +func roleBits(info *Info) map[string]int { + if info == nil { + return nil + } + roles := []TensorRole{ + TensorRoleDefault, + TensorRoleAttention, + TensorRoleSharedExpert, + TensorRoleRoutedExpert, + TensorRoleEmbedTokens, + TensorRoleLMHead, + } + out := map[string]int{} + for _, role := range roles { + if bits := bitsForRole(info, role); bits > 0 { + out[string(role)] = bits + } + } + if len(out) == 0 { + return nil + } + return out +} + +func minMaxBits(rb map[string]int) (int, int) { + minBits, maxBits := 0, 0 + for _, bits := range rb { + if bits <= 0 { + continue + } + if minBits == 0 || bits < minBits { + minBits = bits + } + if bits > maxBits { + maxBits = bits + } + } + return minBits, maxBits +} + +func packedFormat(info *Info) string { + if info == nil { + return "" + } + lower := core.Lower(core.Concat(info.WeightFormat, " ", info.Profile, " ", info.Method)) + switch { + case core.Contains(lower, "mxtq"): + return "mxtq" + case core.Contains(lower, "jangtq"): + return "jangtq" + case core.Contains(lower, "jang"): + return "jang" + default: + return core.Lower(info.WeightFormat) + } +} + +func valuesPerByte(bits int) int { + if bits <= 0 { + return 0 + } + return 8 / bits +} + +func shapeElements(shape []uint64) (uint64, error) { + if len(shape) == 0 { + return 0, core.NewError("jang: packed tensor shape is required") + } + elements := uint64(1) + for _, dim := range shape { + if dim == 0 { + return 0, core.NewError("jang: packed tensor shape contains zero dimension") + } + if elements > ^uint64(0)/dim { + return 0, core.NewError("jang: packed tensor shape overflows element count") + } + elements *= dim + } + return elements, nil +} + +func validateDescriptor(desc PackedTensorDescriptor) error { + if desc.Elements == 0 { + return core.NewError(core.Sprintf("jang: packed tensor %q has no elements", desc.Name)) + } + if err := validateBits(desc.Bits, desc.Name); err != nil { + return err + } + if desc.GroupSize <= 0 { + return core.NewError(core.Sprintf("jang: packed tensor %q has invalid group size %d", desc.Name, desc.GroupSize)) + } + if desc.PackedBytes <= 0 { + return core.NewError(core.Sprintf("jang: packed tensor %q has invalid packed byte count %d", desc.Name, desc.PackedBytes)) + } + if desc.ScaleCount <= 0 || desc.BiasCount <= 0 { + return core.NewError(core.Sprintf("jang: packed tensor %q has invalid scale/bias counts", desc.Name)) + } + return nil +} + +func validateBits(bits int, name string) error { + switch bits { + case 1, 2, 3, 4, 8: + return nil + default: + return core.NewError(core.Sprintf("jang: packed tensor %q has unsupported %d-bit width", name, bits)) + } +} + +func unpackValue(packed []byte, index, bits int) uint8 { + // Fast paths for the byte-aligned bit widths emitted by the JANG + // packers (1-bit binary, 2-bit JANGTQ routed-expert, 4-bit nibble + // JANG_4, 8-bit dense). These cover the overwhelming majority of + // real model-load dequant calls and bypass the generic walk loop, + // which fires hundreds of millions of times per tensor materialise. + switch bits { + case 8: + return packed[index] + case 4: + b := packed[index>>1] + if index&1 == 0 { + return b & 0x0F + } + return b >> 4 + case 2: + return (packed[index>>2] >> uint((index&3)<<1)) & 0x03 + case 1: + return (packed[index>>3] >> uint(index&7)) & 0x01 + } + bitOffset := index * bits + remaining := bits + shiftOut := 0 + value := uint16(0) + for remaining > 0 { + byteIndex := bitOffset / 8 + shiftIn := bitOffset % 8 + take := minInt(remaining, 8-shiftIn) + mask := uint16((1 << take) - 1) + chunk := (uint16(packed[byteIndex]) >> shiftIn) & mask + value |= chunk << shiftOut + remaining -= take + bitOffset += take + shiftOut += take + } + return uint8(value) +} + +func writeValue(out []byte, index, bits int, value uint8) { + bitOffset := index * bits + remaining := bits + raw := uint16(value) + for remaining > 0 { + byteIndex := bitOffset / 8 + shift := bitOffset % 8 + take := minInt(remaining, 8-shift) + mask := uint16((1 << take) - 1) + out[byteIndex] |= byte((raw & mask) << shift) + raw >>= take + remaining -= take + bitOffset += take + } +} + +func cloneRoleBits(rb map[string]int) map[string]int { + if len(rb) == 0 { + return nil + } + cloned := make(map[string]int, len(rb)) + for key, value := range rb { + cloned[key] = value + } + return cloned +} + +func ceilDivUint64(value, divisor uint64) uint64 { + if divisor == 0 || value == 0 { + return 0 + } + quotient := value / divisor + if value%divisor != 0 { + quotient++ + } + return quotient +} + +func maxIntValue() int { + return int(^uint(0) >> 1) +} + +func minInt(a, b int) int { + if a < b { + return a + } + return b +} + +func firstPositive(values ...int) int { + for _, value := range values { + if value > 0 { + return value + } + } + return 0 +} + +func normaliseArchitecture(value string) string { + value = core.Lower(core.Trim(value)) + value = core.Replace(value, "-", "_") + switch value { + case "qwen3_5": + return "qwen3_next" + case "minimaxm2", "minimax_m2": + return "minimax_m2" + case "mixtral": + return "mixtral" + case "mistral": + return "mistral" + case "phi", "phi3", "phi4": + return "phi" + case "deepseek", "deepseek_v3", "deepseek_r1": + return "deepseek" + case "gptoss", "gpt_oss", "gpt_oss_model": + return "gpt_oss" + case "bert": + return "bert" + case "bert_rerank", "bert_cross_encoder": + return "bert_rerank" + default: + return value + } +} diff --git a/go/quant/jang/jang_bench_test.go b/go/quant/jang/jang_bench_test.go new file mode 100644 index 0000000..2cb5a7d --- /dev/null +++ b/go/quant/jang/jang_bench_test.go @@ -0,0 +1,441 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the driver-neutral JANG / JANGTQ quant primitives. +// Per AX-11 — NewPackedTensorDescriptor fires per tensor at model +// load (Minimax-M2 carries hundreds of routed-expert tensors). +// BuildPackedProfile + ClonePackedProfile fire per profile lifted +// across runtime boundaries. ValidatePackedTensor runs per kernel +// dispatch on the CPU parity path. ParseConfig + ReadConfig hit on +// every model load. +// +// Run: go test -bench='Benchmark' -benchmem -run='^$' ./quant/jang + +package jang + +import "testing" + +// Sinks defeat compiler DCE. +var ( + jangSinkInfo *Info + jangSinkDescriptor PackedTensorDescriptor + jangSinkProfile *PackedProfile + jangSinkClonedProf *PackedProfile + jangSinkBits int + jangSinkPacked []byte + jangSinkValues []float32 + jangSinkErr error +) + +// benchInfo returns the same JANGTQ profile shape the test suite +// uses — 4-bit groups with a mixed-bit role table. +func benchInfo() *Info { + return &Info{ + Version: 2, + WeightFormat: "mxtq", + Profile: "JANGTQ", + Method: "affine+mxtq", + GroupSize: 64, + BitsDefault: 2, + AttentionBits: 8, + SharedExpertBits: 8, + RoutedExpertBits: 2, + EmbedTokensBits: 8, + LMHeadBits: 8, + } +} + +// --- ParseConfig (per-model load) --- + +func BenchmarkJang_ParseConfig_Minimal(b *testing.B) { + data := []byte(`{ + "version": 2, + "weight_format": "mxtq", + "profile": "JANGTQ", + "source_model": { + "name": "MiniMax-M2", + "org": "MiniMaxAI", + "architecture": "MiniMaxM2" + }, + "mxtq_bits": { + "attention": 8, + "shared_expert": 8, + "routed_expert": 2, + "embed_tokens": 8, + "lm_head": 8 + }, + "quantization": { + "method": "affine+mxtq", + "group_size": 64, + "bits_default": 2 + } + }`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkInfo, jangSinkErr = ParseConfig(data) + } +} + +func BenchmarkJang_ParseConfig_WithCapabilities(b *testing.B) { + data := []byte(`{ + "version": 2, + "weight_format": "mxtq", + "profile": "JANGTQ", + "source_model": { + "name": "MiniMax-M2", + "org": "MiniMaxAI", + "architecture": "MiniMaxM2" + }, + "mxtq_bits": { + "attention": 8, + "shared_expert": 8, + "routed_expert": 2, + "embed_tokens": 8, + "lm_head": 8 + }, + "quantization": { + "method": "affine+mxtq", + "group_size": 64, + "bits_default": 2 + }, + "capabilities": { + "reasoning_parser": "qwen-think", + "tool_parser": "qwen-tool", + "think_in_template": true, + "supports_tools": true, + "supports_thinking": true, + "family": "minimax_m2", + "modality": "text", + "cache_type": "paged-q8" + } + }`) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkInfo, jangSinkErr = ParseConfig(data) + } +} + +// --- NewPackedTensorDescriptor (per-tensor at model load) --- + +func BenchmarkJang_NewPackedTensorDescriptor_RoutedExpert_Small(b *testing.B) { + info := benchInfo() + shape := []uint64{2048, 2048} + name := "model.layers.0.block_sparse_moe.experts.0.w1.weight" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkDescriptor, jangSinkErr = NewPackedTensorDescriptor(name, shape, info) + } +} + +func BenchmarkJang_NewPackedTensorDescriptor_RoutedExpert_Large(b *testing.B) { + info := benchInfo() + shape := []uint64{6144, 6144} + name := "model.layers.0.block_sparse_moe.experts.0.w1.weight" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkDescriptor, jangSinkErr = NewPackedTensorDescriptor(name, shape, info) + } +} + +func BenchmarkJang_NewPackedTensorDescriptor_Attention(b *testing.B) { + info := benchInfo() + shape := []uint64{4096, 4096} + name := "model.layers.0.self_attn.q_proj.weight" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkDescriptor, jangSinkErr = NewPackedTensorDescriptor(name, shape, info) + } +} + +func BenchmarkJang_NewPackedTensorDescriptor_EmbedTokens(b *testing.B) { + info := benchInfo() + shape := []uint64{262144, 4096} + name := "model.embed_tokens.weight" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkDescriptor, jangSinkErr = NewPackedTensorDescriptor(name, shape, info) + } +} + +// --- BuildPackedProfile (per profile cross-runtime) --- + +func BenchmarkJang_BuildPackedProfile(b *testing.B) { + info := benchInfo() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkProfile = BuildPackedProfile(info) + } +} + +// --- ClonePackedProfile (per runtime hand-off) --- + +func BenchmarkJang_ClonePackedProfile(b *testing.B) { + profile := BuildPackedProfile(benchInfo()) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkClonedProf = ClonePackedProfile(profile) + } +} + +// --- ProfileBits (per-role table build) --- + +func BenchmarkJang_ProfileBits_JANGTQ(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkBits = ProfileBits("JANGTQ") + } +} + +func BenchmarkJang_ProfileBits_JANG_4(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkBits = ProfileBits("JANG_4M") + } +} + +func BenchmarkJang_ProfileBits_Unknown(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkBits = ProfileBits("unknown") + } +} + +// --- ValidatePackedTensor (per kernel dispatch) --- + +func BenchmarkJang_ValidatePackedTensor_2bit(b *testing.B) { + info := benchInfo() + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.0.w1.weight", []uint64{64, 64}, info) + if err != nil { + b.Fatal(err) + } + packed := make([]byte, desc.PackedBytes) + scales := make([]float32, desc.ScaleCount) + biases := make([]float32, desc.BiasCount) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkErr = ValidatePackedTensor(desc, packed, scales, biases) + } +} + +func BenchmarkJang_ValidatePackedTensor_8bit(b *testing.B) { + info := benchInfo() + desc, err := NewPackedTensorDescriptor("model.layers.0.self_attn.q_proj.weight", []uint64{64, 64}, info) + if err != nil { + b.Fatal(err) + } + packed := make([]byte, desc.PackedBytes) + scales := make([]float32, desc.ScaleCount) + biases := make([]float32, desc.BiasCount) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkErr = ValidatePackedTensor(desc, packed, scales, biases) + } +} + +// --- PackQuantizedValues (CPU parity-test path) --- +// 2-bit / 4-bit / 8-bit shapes; values per byte differs across bit +// widths so the pack hot loop sees all three. + +func BenchmarkJang_PackQuantizedValues_2bit_256(b *testing.B) { + info := benchInfo() + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.0.w1.weight", []uint64{16, 16}, info) + if err != nil { + b.Fatal(err) + } + values := make([]uint8, desc.Elements) + for i := range values { + values[i] = uint8(i % 4) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkPacked, jangSinkErr = PackQuantizedValues(desc, values) + } +} + +func BenchmarkJang_PackQuantizedValues_8bit_256(b *testing.B) { + info := benchInfo() + desc, err := NewPackedTensorDescriptor("model.layers.0.self_attn.q_proj.weight", []uint64{16, 16}, info) + if err != nil { + b.Fatal(err) + } + values := make([]uint8, desc.Elements) + for i := range values { + values[i] = uint8(i % 256) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkPacked, jangSinkErr = PackQuantizedValues(desc, values) + } +} + +func BenchmarkJang_PackQuantizedValues_2bit_4096(b *testing.B) { + info := benchInfo() + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.0.w1.weight", []uint64{64, 64}, info) + if err != nil { + b.Fatal(err) + } + values := make([]uint8, desc.Elements) + for i := range values { + values[i] = uint8(i % 4) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkPacked, jangSinkErr = PackQuantizedValues(desc, values) + } +} + +// --- DequantizePackedTensor (CPU parity-test path) --- + +func BenchmarkJang_DequantizePackedTensor_2bit_256(b *testing.B) { + info := benchInfo() + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.0.w1.weight", []uint64{16, 16}, info) + if err != nil { + b.Fatal(err) + } + values := make([]uint8, desc.Elements) + for i := range values { + values[i] = uint8(i % 4) + } + packed, err := PackQuantizedValues(desc, values) + if err != nil { + b.Fatal(err) + } + scales := make([]float32, desc.ScaleCount) + biases := make([]float32, desc.BiasCount) + for i := range scales { + scales[i] = 0.125 + biases[i] = -1 + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkValues, jangSinkErr = DequantizePackedTensor(desc, packed, scales, biases) + } +} + +func BenchmarkJang_DequantizePackedTensor_2bit_4096(b *testing.B) { + info := benchInfo() + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.0.w1.weight", []uint64{64, 64}, info) + if err != nil { + b.Fatal(err) + } + values := make([]uint8, desc.Elements) + for i := range values { + values[i] = uint8(i % 4) + } + packed, err := PackQuantizedValues(desc, values) + if err != nil { + b.Fatal(err) + } + scales := make([]float32, desc.ScaleCount) + biases := make([]float32, desc.BiasCount) + for i := range scales { + scales[i] = 0.125 + biases[i] = -1 + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkValues, jangSinkErr = DequantizePackedTensor(desc, packed, scales, biases) + } +} + +func BenchmarkJang_DequantizePackedTensor_8bit_256(b *testing.B) { + info := benchInfo() + desc, err := NewPackedTensorDescriptor("model.layers.0.self_attn.q_proj.weight", []uint64{16, 16}, info) + if err != nil { + b.Fatal(err) + } + values := make([]uint8, desc.Elements) + for i := range values { + values[i] = uint8(i % 256) + } + packed, err := PackQuantizedValues(desc, values) + if err != nil { + b.Fatal(err) + } + scales := make([]float32, desc.ScaleCount) + biases := make([]float32, desc.BiasCount) + for i := range scales { + scales[i] = 0.0625 + biases[i] = -2 + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkValues, jangSinkErr = DequantizePackedTensor(desc, packed, scales, biases) + } +} + +// benchInfoBits returns a benchInfo where the routed-expert bits override +// is set to the requested width. NewPackedTensorDescriptor routes a tensor +// matching block_sparse_moe.experts to RoutedExpertBits, so we can exercise +// any width in {1, 2, 3, 4, 8} through the same name. +func benchInfoBits(bits int) *Info { + info := benchInfo() + info.RoutedExpertBits = bits + info.BitsDefault = bits + return info +} + +func benchDequantize(b *testing.B, bits, elements int) { + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.0.w1.weight", []uint64{uint64(elements)}, benchInfoBits(bits)) + if err != nil { + b.Fatal(err) + } + maxValue := uint8((1 << bits) - 1) + values := make([]uint8, desc.Elements) + for i := range values { + values[i] = uint8(i) & maxValue + } + packed, err := PackQuantizedValues(desc, values) + if err != nil { + b.Fatal(err) + } + scales := make([]float32, desc.ScaleCount) + biases := make([]float32, desc.BiasCount) + for i := range scales { + scales[i] = 0.0625 + biases[i] = -2 + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + jangSinkValues, jangSinkErr = DequantizePackedTensor(desc, packed, scales, biases) + } +} + +func BenchmarkJang_DequantizePackedTensor_1bit_4096(b *testing.B) { + benchDequantize(b, 1, 4096) +} + +func BenchmarkJang_DequantizePackedTensor_2bit_16384(b *testing.B) { + benchDequantize(b, 2, 16384) +} + +func BenchmarkJang_DequantizePackedTensor_3bit_4096(b *testing.B) { + benchDequantize(b, 3, 4096) +} + +func BenchmarkJang_DequantizePackedTensor_4bit_4096(b *testing.B) { + benchDequantize(b, 4, 4096) +} + +func BenchmarkJang_DequantizePackedTensor_8bit_4096(b *testing.B) { + benchDequantize(b, 8, 4096) +} diff --git a/go/quant/jang/jang_test.go b/go/quant/jang/jang_test.go new file mode 100644 index 0000000..498581a --- /dev/null +++ b/go/quant/jang/jang_test.go @@ -0,0 +1,320 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package jang + +import ( + "testing" + + core "dappco.re/go" +) + +func testJANGTQInfo() *Info { + return &Info{ + Version: 2, + WeightFormat: "mxtq", + Profile: "JANGTQ", + Method: "affine+mxtq", + GroupSize: 4, + BitsDefault: 2, + AttentionBits: 8, + SharedExpertBits: 8, + RoutedExpertBits: 2, + EmbedTokensBits: 8, + LMHeadBits: 8, + } +} + +func TestJang_PackedTensorDescriptorMXTQRoutedExpert_Good(t *testing.T) { + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.17.w1.weight", []uint64{2, 4}, testJANGTQInfo()) + if err != nil { + t.Fatalf("NewPackedTensorDescriptor() error = %v", err) + } + + if desc.Type != "jangtq" || desc.Format != "mxtq" || desc.Profile != "JANGTQ" { + t.Fatalf("profile = type:%q format:%q profile:%q", desc.Type, desc.Format, desc.Profile) + } + if desc.Role != TensorRoleRoutedExpert || desc.Bits != 2 || desc.GroupSize != 4 { + t.Fatalf("descriptor = %+v, want routed expert 2-bit group 4", desc) + } + if desc.Elements != 8 || desc.Groups != 2 || desc.PackedBytes != 2 || desc.ScaleCount != 2 || desc.BiasCount != 2 { + t.Fatalf("descriptor sizes = %+v, want 8 elements, 2 groups, 2 packed bytes", desc) + } + if desc.BitOrder != BitOrderLSB0 || desc.Encoding != EncodingAffine { + t.Fatalf("layout = bit_order:%q encoding:%q", desc.BitOrder, desc.Encoding) + } +} + +func TestJang_PackedTensorDescriptorAttentionUsesWideBits_Good(t *testing.T) { + desc, err := NewPackedTensorDescriptor("model.layers.0.self_attn.q_proj.weight", []uint64{2, 4}, testJANGTQInfo()) + if err != nil { + t.Fatalf("NewPackedTensorDescriptor() error = %v", err) + } + + if desc.Role != TensorRoleAttention || desc.Bits != 8 || desc.PackedBytes != 8 { + t.Fatalf("descriptor = %+v, want attention 8-bit un-nibbled bytes", desc) + } +} + +func TestJang_PackedTensorDescriptorBadUnsupportedBits(t *testing.T) { + info := testJANGTQInfo() + info.RoutedExpertBits = 5 + + _, err := NewPackedTensorDescriptor("model.layers.0.mlp.experts.0.down_proj.weight", []uint64{4, 4}, info) + if err == nil || !core.Contains(err.Error(), "unsupported") || !core.Contains(err.Error(), "5-bit") { + t.Fatalf("error = %v, want explicit unsupported 5-bit error", err) + } +} + +func TestJang_DequantizePackedTensor_Good(t *testing.T) { + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.3.w2.weight", []uint64{8}, testJANGTQInfo()) + if err != nil { + t.Fatalf("NewPackedTensorDescriptor() error = %v", err) + } + packed, err := PackQuantizedValues(desc, []uint8{0, 1, 2, 3, 0, 1, 2, 3}) + if err != nil { + t.Fatalf("PackQuantizedValues() error = %v", err) + } + + out, err := DequantizePackedTensor(desc, packed, []float32{0.5, 1}, []float32{-1, 10}) + if err != nil { + t.Fatalf("DequantizePackedTensor() error = %v", err) + } + + want := []float32{-1, -0.5, 0, 0.5, 10, 11, 12, 13} + if len(out) != len(want) { + t.Fatalf("out length = %d, want %d", len(out), len(want)) + } + for i := range want { + if out[i] != want[i] { + t.Fatalf("out[%d] = %v, want %v (all=%v)", i, out[i], want[i], out) + } + } +} + +func TestJang_ValidatePackedTensorBadPackedLength(t *testing.T) { + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.3.w2.weight", []uint64{8}, testJANGTQInfo()) + if err != nil { + t.Fatalf("NewPackedTensorDescriptor() error = %v", err) + } + + err = ValidatePackedTensor(desc, []byte{0}, []float32{1, 1}, []float32{0, 0}) + if err == nil || !core.Contains(err.Error(), "packed length") { + t.Fatalf("error = %v, want packed length validation", err) + } +} + +// roundTripFixture builds a descriptor at the requested bit width with the +// MXTQ routed-expert tensor name (the inferTensorRole route that picks up +// RoutedExpertBits) and feeds it crafted values such that every group is +// exercised. Returns descriptor + the values written in. +func roundTripFixture(t *testing.T, bits int, elements int, groupSize int) (PackedTensorDescriptor, []uint8, []byte, []float32, []float32) { + t.Helper() + info := &Info{ + Version: 2, + WeightFormat: "mxtq", + Profile: "JANGTQ", + Method: "affine+mxtq", + GroupSize: groupSize, + BitsDefault: bits, + RoutedExpertBits: bits, + } + desc, err := NewPackedTensorDescriptor("model.layers.0.block_sparse_moe.experts.0.w1.weight", []uint64{uint64(elements)}, info) + if err != nil { + t.Fatalf("NewPackedTensorDescriptor(%d-bit): %v", bits, err) + } + maxValue := uint8((1 << bits) - 1) + values := make([]uint8, desc.Elements) + for i := range values { + // Walk the full 0..maxValue range so every nibble/lane is touched. + values[i] = uint8(i) & maxValue + } + packed, err := PackQuantizedValues(desc, values) + if err != nil { + t.Fatalf("PackQuantizedValues(%d-bit): %v", bits, err) + } + // Distinct per-group scale + bias so a regression that mis-indexes groups + // surfaces as a wrong magnitude, not a hidden silent identity. + scales := make([]float32, desc.ScaleCount) + biases := make([]float32, desc.BiasCount) + for i := range scales { + scales[i] = 0.25 + float32(i)*0.0625 + biases[i] = -1 - float32(i)*0.5 + } + return desc, values, packed, scales, biases +} + +// expectedDequantize is the smallest possible reference dequant — pure +// per-element arithmetic with the generic unpack walk used by upstream +// before the W10-N specialisation. Used as the bit-exact oracle. +func expectedDequantize(t *testing.T, values []uint8, scales, biases []float32, groupSize int) []float32 { + t.Helper() + out := make([]float32, len(values)) + for i, v := range values { + group := i / groupSize + out[i] = float32(v)*scales[group] + biases[group] + } + return out +} + +func TestJang_DequantizePackedTensor_RoundTrip_1bit(t *testing.T) { + // 4096 elements with groupSize=64 to exercise the multi-group dispatch. + desc, values, packed, scales, biases := roundTripFixture(t, 1, 4096, 64) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(1-bit): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +func TestJang_DequantizePackedTensor_RoundTrip_2bit(t *testing.T) { + desc, values, packed, scales, biases := roundTripFixture(t, 2, 4096, 64) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(2-bit): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +func TestJang_DequantizePackedTensor_RoundTrip_3bit(t *testing.T) { + // 3-bit hits the generic-walk default branch — the dequant must still + // be bit-exact against the pre-specialisation oracle. + desc, values, packed, scales, biases := roundTripFixture(t, 3, 4096, 64) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(3-bit): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +func TestJang_DequantizePackedTensor_RoundTrip_4bit(t *testing.T) { + desc, values, packed, scales, biases := roundTripFixture(t, 4, 4096, 64) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(4-bit): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +func TestJang_DequantizePackedTensor_RoundTrip_8bit(t *testing.T) { + desc, values, packed, scales, biases := roundTripFixture(t, 8, 4096, 64) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(8-bit): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +// TestJang_DequantizePackedTensor_RoundTrip_2bit_ShortTail exercises the +// case where the tensor's element count is NOT a multiple of groupSize, +// so the final group runs short and the 2-bit suffix-drain path covers +// the tail. +func TestJang_DequantizePackedTensor_RoundTrip_2bit_ShortTail(t *testing.T) { + // 130 elements with groupSize=64 → 3 groups, last group has 2 elements. + desc, values, packed, scales, biases := roundTripFixture(t, 2, 130, 64) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(2-bit short tail): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +// TestJang_DequantizePackedTensor_RoundTrip_2bit_GroupSize2 exercises the +// case where groupSize < 4 — the 2-bit batched fast path can't fire on a +// 4-elements-per-byte stride, so the per-element prefix path must cover +// every element. +func TestJang_DequantizePackedTensor_RoundTrip_2bit_GroupSize2(t *testing.T) { + desc, values, packed, scales, biases := roundTripFixture(t, 2, 32, 2) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(2-bit groupSize=2): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +// TestJang_DequantizePackedTensor_RoundTrip_4bit_ShortTail covers the +// 4-bit prefix + suffix drains around the batched 2-per-byte fast path +// when the final group is shorter than groupSize. +func TestJang_DequantizePackedTensor_RoundTrip_4bit_ShortTail(t *testing.T) { + // 67 elements with groupSize=64 → last group has 3 elements; the + // 2-per-byte batched path takes 2 of them, the suffix drains the 1. + desc, values, packed, scales, biases := roundTripFixture(t, 4, 67, 64) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(4-bit short tail): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +// TestJang_DequantizePackedTensor_RoundTrip_4bit_GroupSize1 covers the +// degenerate case where groupSize=1, forcing every element into the +// suffix-drain path (no batched stride can fire). +func TestJang_DequantizePackedTensor_RoundTrip_4bit_GroupSize1(t *testing.T) { + desc, values, packed, scales, biases := roundTripFixture(t, 4, 16, 1) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(4-bit groupSize=1): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +// TestJang_DequantizePackedTensor_RoundTrip_1bit_ShortTail covers the +// 1-bit prefix + suffix drains around the batched 8-per-byte fast path +// when the final group is shorter than groupSize. +func TestJang_DequantizePackedTensor_RoundTrip_1bit_ShortTail(t *testing.T) { + // 133 elements with groupSize=64 → last group has 5 elements; the + // 8-per-byte batched path can't fire, suffix-drain takes all 5. + desc, values, packed, scales, biases := roundTripFixture(t, 1, 133, 64) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(1-bit short tail): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +// TestJang_DequantizePackedTensor_RoundTrip_1bit_GroupSize4 covers the +// case where groupSize=4 < 8, so the 8-per-byte batched fast path can +// never fire and the prefix path must cover every element. +func TestJang_DequantizePackedTensor_RoundTrip_1bit_GroupSize4(t *testing.T) { + desc, values, packed, scales, biases := roundTripFixture(t, 1, 32, 4) + got, err := DequantizePackedTensor(desc, packed, scales, biases) + if err != nil { + t.Fatalf("DequantizePackedTensor(1-bit groupSize=4): %v", err) + } + want := expectedDequantize(t, values, scales, biases, desc.GroupSize) + assertBitExact(t, got, want) +} + +func assertBitExact(t *testing.T, got, want []float32) { + t.Helper() + if len(got) != len(want) { + t.Fatalf("length = %d, want %d", len(got), len(want)) + } + for i := range want { + if got[i] != want[i] { + t.Fatalf("dequant[%d] = %v, want %v (delta=%v)", i, got[i], want[i], got[i]-want[i]) + } + } +} + +func TestJang_BuildPackedProfile_Good(t *testing.T) { + profile := BuildPackedProfile(testJANGTQInfo()) + if profile == nil { + t.Fatal("profile = nil") + } + if profile.Type != "jangtq" || profile.Format != "mxtq" || !profile.Mixed { + t.Fatalf("profile = %+v, want JANGTQ/MXTQ mixed profile", profile) + } + if profile.MinBits != 2 || profile.MaxBits != 8 || profile.RoleBits[string(TensorRoleRoutedExpert)] != 2 || profile.RoleBits[string(TensorRoleAttention)] != 8 { + t.Fatalf("role bits = %+v, min/max=%d/%d", profile.RoleBits, profile.MinBits, profile.MaxBits) + } +} diff --git a/go/scheduler/backpressure_bench_test.go b/go/scheduler/backpressure_bench_test.go new file mode 100644 index 0000000..e061239 --- /dev/null +++ b/go/scheduler/backpressure_bench_test.go @@ -0,0 +1,224 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Backpressure benchmarks. The Schedule path has three points where +// flow control kicks in: +// +// 1. Queue full at Schedule — the default arm of the queue-send +// select rejects with "scheduler: queue is full" +// 2. StreamBuffer full inside run() — the producer blocks on +// j.out <- ScheduledToken (in the select with j.ctx.Done()) +// 3. Slow consumer — the producer paces against consumer rhythm +// +// The existing scheduler_bench_test.go suite measures the +// happy-path (StreamBuffer >= token count, no rejection). This +// file covers the contended shapes. +// +// Per the lane spec — backpressure scenarios are part of the load- +// bearing path between cached state and live tokens. A slow consumer +// (IDE that pauses to render markdown, agent that batches probes +// for ratelimit) sits between Virgil's continuous state and the +// user-visible stream. Coverage of producer-blocks-on-consumer is +// the only way to see whether scheduler.go's per-token select cost +// dominates a slow-consumer workload. +// +// Run: go test -bench='BenchmarkScheduler_Backpressure' -benchmem -run='^$' ./go/scheduler + +package scheduler + +import ( + "context" + "testing" + "time" + + "dappco.re/go/inference" +) + +// --- Queue full rejection at Schedule --- + +// QueueFull_Reject — submit a request to a saturated queue with an +// in-flight blocking job. Schedule takes the queue-full arm and +// returns the rejection error. Measures the rejection-path alloc +// budget — unregister + cancel + close(j.out) + NewError. +// +// Implementation: worker count 0 (normalised to 1 by Config), queue +// size 1, StreamBuffer 1. We pre-load the worker with a long-paced +// job whose first token doesn't emit during the bench window, then +// wait briefly so the worker has picked up the job out of the queue. +// Then we load the queue with a second job. From that point every +// subsequent Schedule must reject. +func BenchmarkScheduler_Backpressure_QueueFull_Reject(b *testing.B) { + base := &cancellableBenchModel{tokens: benchTokens(2), perTokenNs: 10 * time.Second} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 1, StreamBuffer: 1}) + ctx, cancel := context.WithCancel(context.Background()) + // Saturate the pipeline outside the timed loop. Drainers ensure + // no goroutines leak beyond the worker pool. + workerHandle, workerTokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "filler-worker"}) + if err != nil { + b.Fatalf("filler-worker schedule: %v", err) + } + // Wait for the worker to pull the filler-worker job off the queue + // (worker is a goroutine that drains m.queue). Polling for queue + // emptiness via a short retry loop on Schedule. + deadline := time.Now().Add(100 * time.Millisecond) + var queueHandle inference.RequestHandle + var queueTokens <-chan inference.ScheduledToken + for time.Now().Before(deadline) { + queueHandle, queueTokens, err = sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "filler-queue"}) + if err == nil { + break + } + time.Sleep(time.Millisecond) + } + if err != nil { + cancel() + b.Fatalf("filler-queue schedule never succeeded: %v", err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "rejected"}) + schedSinkHandle = handle + schedSinkErr = err + if tokens != nil { + for range tokens { + } + } + } + b.StopTimer() + // Cancel both fillers and drain so we don't block the next bench + // behind a 10s sleep. We don't care about their final state. + _, _ = sched.CancelRequest(context.Background(), workerHandle.ID) + _, _ = sched.CancelRequest(context.Background(), queueHandle.ID) + go func() { + for range workerTokens { + } + }() + go func() { + for range queueTokens { + } + }() + cancel() +} + +// --- StreamBuffer-full producer blocking --- + +// SlowConsumer_StreamBufferFull — a tight StreamBuffer of 1, a 256- +// token producer, and a consumer that only reads with a small delay +// per token. The producer blocks in the j.out <- select on every +// token after the first. Measures the cost of repeatedly entering +// the per-token select arm under contention. +func BenchmarkScheduler_Backpressure_SlowConsumer_StreamBufferFull(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(64)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 1}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + count := 0 + for range tokens { + count++ + // Per-token consumer-side delay — 1µs * 64 tokens = 64µs of + // producer-blocked time per request. Without it the + // producer-faster-than-consumer dynamic doesn't surface + // because the local channel ring rotates too fast. + time.Sleep(1 * time.Microsecond) + } + schedSinkTokensCount = count + } +} + +// --- Producer-faster-than-consumer --- + +// FastProducer_FastConsumer — baseline reference for the slow- +// consumer bench above. Same token count, same StreamBuffer=1, but +// the consumer reads at full speed. The delta isolates the cost of +// time.Sleep + select-on-channel-write pressure. +func BenchmarkScheduler_Backpressure_FastProducer_FastConsumer_StreamBuffer1(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(64)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 1}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + count := 0 + for range tokens { + count++ + } + schedSinkTokensCount = count + } +} + +// --- StreamBuffer=0 — fully synchronous handoff --- + +// SyncHandoff_StreamBufferZero — exercises the StreamBuffer=0 case +// where every producer-to-consumer handoff is a rendezvous. The Config +// normalises StreamBuffer<0 to 0; we test 0 explicitly to confirm the +// downgraded buffer still streams tokens (vs the fast path with a +// pre-allocated buffer). +func BenchmarkScheduler_Backpressure_SyncHandoff_StreamBufferZero(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 0}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + count := 0 + for range tokens { + count++ + } + schedSinkTokensCount = count + } +} + +// --- Drain-cost-of-aborted-stream-vs-fully-drained-stream --- + +// AbortedDrain_NotFullyConsumed — consumer abandons the stream +// after 4 tokens; the Generate iterator handle that wraps Schedule +// would call CancelRequest under yield-false, but here we exit the +// for range loop and let the channel close on its own. Some IDE +// patterns leak this way. +// +// Note: we don't yield-false (no Generate wrapper); we just stop +// reading from the channel. The producer will block on the next +// send until the run() Done arm trips when the iteration ends. +// This bench captures the cost of dangling channels — a real risk +// for callers who forget the drain contract. +func BenchmarkScheduler_Backpressure_AbortedDrain_4Of64(b *testing.B) { + base := &cancellableBenchModel{tokens: benchTokens(64), perTokenNs: 5 * time.Microsecond} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + count := 0 + for range tokens { + count++ + if count >= 4 { + // Aborted — cancel + drain the rest so the bench's + // next iteration starts from a clean state. This IS + // the documented contract. + schedSinkCancel, schedSinkErr = sched.CancelRequest(ctx, handle.ID) + for range tokens { + } + break + } + } + schedSinkTokensCount = count + } +} diff --git a/go/scheduler/cancellation_bench_test.go b/go/scheduler/cancellation_bench_test.go new file mode 100644 index 0000000..c93acb2 --- /dev/null +++ b/go/scheduler/cancellation_bench_test.go @@ -0,0 +1,261 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Cancellation-path benchmarks. The existing scheduler suite covers +// CancelRequest_NotFound (the no-active-id fallback); this file adds +// the four scenarios that exercise the live-cancellation paths and +// the cost of cancel-propagation through emitProbe: +// +// * Cancel BEFORE start — context cancelled while job sits in queue +// * Cancel via parent context Done — Schedule short-circuits at +// the ctx.Done() select arm +// * Cancel DURING stream — j.cancel() inside the stream consumer +// * Cancel via context.WithTimeout — emulates RPC deadline timeout +// +// Per [[project_kv_state_decode_loadbearing_for_portable_knowledge]] — +// when continuous-state runtime sits behind the scheduler, cancellation +// is the only way to release a stuck KV-restore. The cost of cancel +// propagation IS in the load-bearing path; coverage is mandatory. +// +// Pre-existing race in TestModel_QueuesRequestsAndEmitsLatencyProbe_Good +// noted in W7-D — this file uses fresh schedulers per bench so no +// shared state with that test path. +// +// Run: go test -bench='BenchmarkScheduler_Cancel' -benchmem -run='^$' ./go/scheduler + +package scheduler + +import ( + "context" + "iter" + "testing" + "time" + + "dappco.re/go/inference" +) + +// cancellableBenchModel emits its tokens slowly enough that mid-stream +// cancellation is observable in the bench window. We sleep briefly +// between tokens so the cancel arm of the run() select fires on the +// realistic 'producer in the middle of streaming' shape. +// +// Tokens slice is immutable; the closure has no shared state, so it's +// parallel-safe and reusable across b.N iterations. +type cancellableBenchModel struct { + tokens []inference.Token + perTokenNs time.Duration +} + +func (m *cancellableBenchModel) Generate(ctx context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return m.seq(ctx) +} + +func (m *cancellableBenchModel) Chat(ctx context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return m.seq(ctx) +} + +func (m *cancellableBenchModel) Classify(context.Context, []string, ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + return nil, nil +} + +func (m *cancellableBenchModel) BatchGenerate(context.Context, []string, ...inference.GenerateOption) ([]inference.BatchResult, error) { + return nil, nil +} + +func (m *cancellableBenchModel) ModelType() string { return "cancellable-bench" } +func (m *cancellableBenchModel) Info() inference.ModelInfo { return inference.ModelInfo{} } +func (m *cancellableBenchModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} } +func (m *cancellableBenchModel) Err() error { return nil } +func (m *cancellableBenchModel) Close() error { return nil } + +func (m *cancellableBenchModel) seq(ctx context.Context) iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + for _, token := range m.tokens { + if err := ctx.Err(); err != nil { + return + } + if m.perTokenNs > 0 { + timer := time.NewTimer(m.perTokenNs) + select { + case <-ctx.Done(): + timer.Stop() + return + case <-timer.C: + } + } + if !yield(token) { + return + } + } + } +} + +// --- CancelRequest mid-stream — start a stream that paces tokens +// over 100µs each, fire cancel after ~10µs, measure the cancel + +// drain cost. The j.cancel() must propagate via j.ctx.Done() into +// the run() select arm. --- + +func BenchmarkScheduler_Cancel_MidStream(b *testing.B) { + base := &cancellableBenchModel{tokens: benchTokens(64), perTokenNs: 100 * time.Microsecond} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + // Let one token emit, then cancel — exercises the j.ctx.Done() + // arm of the run loop inside the per-token select. + first := true + count := 0 + for range tokens { + count++ + if first { + schedSinkCancel, schedSinkErr = sched.CancelRequest(ctx, handle.ID) + first = false + } + } + schedSinkTokensCount = count + } +} + +// --- CancelRequest BEFORE start — queue the request behind a slow +// in-flight one so it's still in the queue when we cancel. The cancel +// path takes the same j.cancel() route but j.run() will hit the +// ctx.Err() check at the top of run() and emit a "cancelled" probe. --- + +func BenchmarkScheduler_Cancel_BeforeStart_QueueWait(b *testing.B) { + // Lead emits a small number of tokens — buffer accommodates them + // so the lead's producer can run to completion in the background + // while we cancel the queued one. StreamBuffer >= lead-tokens + // avoids a producer-blocks-on-consumer deadlock with the queued + // drain ordering below. + base := &cancellableBenchModel{tokens: benchTokens(8), perTokenNs: 50 * time.Microsecond} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 16}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Lead with one in-flight job so the second sits in the queue. + _, leadTokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "lead"}) + if err != nil { + continue + } + queued, queuedTokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "queued"}) + if err != nil { + for range leadTokens { + } + continue + } + // Cancel the queued one while the lead still runs. + schedSinkCancel, schedSinkErr = sched.CancelRequest(ctx, queued.ID) + // Drain lead first — its producer needs the buffered channel + // drained even though it fits. Then drain queued (which the + // worker will see-cancelled and emit nothing before closing). + count := 0 + for range leadTokens { + count++ + } + for range queuedTokens { + count++ + } + schedSinkTokensCount = count + } +} + +// --- Schedule under cancelled parent context — fast-fail path; the +// context.Err() guard at Schedule entry should reject immediately. --- + +func BenchmarkScheduler_Cancel_ParentContextAlreadyCancelled(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + parent, cancel := context.WithCancel(context.Background()) + cancel() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(parent, inference.ScheduledRequest{Prompt: "p"}) + schedSinkHandle = handle + schedSinkErr = err + if tokens != nil { + for range tokens { + } + } + } +} + +// --- Schedule under context.WithTimeout that has already elapsed — +// same fast-fail path but via a timer-cancelled context. Validates +// the ctx.Err() check at entry returns immediately. --- + +func BenchmarkScheduler_Cancel_TimeoutAlreadyElapsed(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + parent := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithTimeout(parent, 0) + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + schedSinkHandle = handle + schedSinkErr = err + if tokens != nil { + for range tokens { + } + } + cancel() + } +} + +// --- Cancel via context.WithDeadline that elapses during stream — +// exercise the context-deadline path through the run() select. Three +// tokens emit before the deadline trips; remainder drained empty. --- + +func BenchmarkScheduler_Cancel_DeadlineDuringStream(b *testing.B) { + base := &cancellableBenchModel{tokens: benchTokens(32), perTokenNs: 100 * time.Microsecond} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + ctx, cancel := context.WithTimeout(context.Background(), 500*time.Microsecond) + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + cancel() + continue + } + count := 0 + for range tokens { + count++ + } + schedSinkTokensCount = count + cancel() + } +} + +// --- Drain-after-cancel — the typical IDE pattern: cancel the +// request, then drain the channel to detect close. Captures the +// cost of the final j.out close + final probe emission. --- + +func BenchmarkScheduler_Cancel_DrainAfterCancel_LongStream(b *testing.B) { + base := &cancellableBenchModel{tokens: benchTokens(256), perTokenNs: 10 * time.Microsecond} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 256}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + // Cancel immediately, then drain to close — no tokens may emit + // before the cancel arm trips; this is the "fastest possible + // rejection of an active stream" path. + schedSinkCancel, schedSinkErr = sched.CancelRequest(ctx, handle.ID) + count := 0 + for range tokens { + count++ + } + schedSinkTokensCount = count + } +} diff --git a/go/scheduler/concurrency_bench_test.go b/go/scheduler/concurrency_bench_test.go new file mode 100644 index 0000000..ea67077 --- /dev/null +++ b/go/scheduler/concurrency_bench_test.go @@ -0,0 +1,214 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Concurrency-stress benchmarks for the scheduler. The existing +// scheduler_bench_test.go suite measures single-stream cost; this +// file measures Schedule + drain under parallel pressure across the +// MaxConcurrent knob (1 / 4 / 16 workers) at three request fan-outs +// (4 / 16 / 64 concurrent producers). +// +// Per [[project_kv_state_decode_loadbearing_for_portable_knowledge]] — +// decode + scheduler is the per-token consumer of continuous state. +// Real lthn.ai traffic is many-stream-at-once (IDE chat + agent +// dispatch + classification probes share a worker pool); single- +// stream benches miss the worker-queue + label-map contention that +// only appears under fan-out. +// +// The shared schedBenchModel from scheduler_bench_test.go is safe +// under parallel use — its iter.Seq closure has no shared state, +// just the immutable tokens slice. We reuse it. +// +// Per the lane spec: avoid the pre-existing race in +// TestModel_QueuesRequestsAndEmitsLatencyProbe_Good — the benches +// here use fresh schedulers per b.Run + RunParallel hands each PB +// its own goroutine; no shared state with that test. +// +// Sink discipline: under parallel/burst dispatch, multiple goroutines +// would race writing the package-level schedSink* variables. We use +// sync/atomic + a per-bench int64 counter instead, then add it into +// the package sink once at the bench end. That defeats DCE without +// creating a race. +// +// Run: go test -bench='BenchmarkScheduler_Concurrent' -benchmem -run='^$' ./go/scheduler + +package scheduler + +import ( + "context" + "sync" + "sync/atomic" + "testing" + + "dappco.re/go/inference" +) + +// drainSchedulerStream consumes a token channel to completion. Used +// inside parallel benches so producer back-pressure does not pile up. +func drainSchedulerStream(tokens <-chan inference.ScheduledToken) int { + count := 0 + for range tokens { + count++ + } + return count +} + +// --- Schedule + drain under RunParallel — the dominant concurrency +// stress for the queue + worker pool. Each pb iteration mints one +// request, drains it, recycles. --- + +func BenchmarkScheduler_Schedule_Concurrent_4Workers_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 4, MaxQueue: 64, StreamBuffer: 32}) + ctx := context.Background() + var total atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + total.Add(int64(drainSchedulerStream(tokens))) + } + }) + schedSinkTokensCount = int(total.Load()) +} + +func BenchmarkScheduler_Schedule_Concurrent_16Workers_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 16, MaxQueue: 128, StreamBuffer: 32}) + ctx := context.Background() + var total atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + total.Add(int64(drainSchedulerStream(tokens))) + } + }) + schedSinkTokensCount = int(total.Load()) +} + +func BenchmarkScheduler_Schedule_Concurrent_1Worker_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 256, StreamBuffer: 32}) + ctx := context.Background() + var total atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + total.Add(int64(drainSchedulerStream(tokens))) + } + }) + schedSinkTokensCount = int(total.Load()) +} + +// --- Burst dispatch — release N concurrent producers, wait for all to +// finish in turn. Captures the "spike of arrivals" shape rather than the +// steady-state RunParallel rhythm. --- + +func benchScheduleBurst(b *testing.B, workers int, tokens int) { + base := &schedBenchModel{tokens: benchTokens(tokens)} + sched := New(base, Config{ + MaxConcurrent: 4, + MaxQueue: workers * 2, + StreamBuffer: tokens, + }) + ctx := context.Background() + var total atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + wg.Add(workers) + for j := 0; j < workers; j++ { + go func() { + defer wg.Done() + _, stream, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + return + } + total.Add(int64(drainSchedulerStream(stream))) + }() + } + wg.Wait() + } + schedSinkTokensCount = int(total.Load()) +} + +func BenchmarkScheduler_Burst_4Producers_32Tokens(b *testing.B) { + benchScheduleBurst(b, 4, 32) +} + +func BenchmarkScheduler_Burst_16Producers_32Tokens(b *testing.B) { + benchScheduleBurst(b, 16, 32) +} + +func BenchmarkScheduler_Burst_64Producers_32Tokens(b *testing.B) { + benchScheduleBurst(b, 64, 32) +} + +// 256-token burst — measures whether the per-token label-write +// contention pattern compounds with stream length. +func BenchmarkScheduler_Burst_16Producers_256Tokens(b *testing.B) { + benchScheduleBurst(b, 16, 256) +} + +// --- Queue-saturation pressure — workers can't drain as fast as +// producers arrive; the queue depth oscillates near full. Captures +// the cost of the queue-full rejection path under steady pressure. --- + +func BenchmarkScheduler_QueueSaturation_TinyQueue(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 1, StreamBuffer: 4}) + ctx := context.Background() + var total, errs atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + // Queue-full rejection — counted, drained, recycled. + errs.Add(1) + continue + } + total.Add(int64(drainSchedulerStream(tokens))) + } + }) + schedSinkTokensCount = int(total.Load() + errs.Load()) +} + +// --- CancelRequest hot-path under contention — when one goroutine +// is calling CancelRequest while another is calling Schedule, the +// shared mu.Lock around m.active is the synchronisation point. This +// bench measures the cost of contesting that lock at fan-out 4. --- + +func BenchmarkScheduler_CancelRequest_NotFound_Parallel(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 4, MaxQueue: 16, StreamBuffer: 4}) + ctx := context.Background() + var total atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + res, _ := sched.CancelRequest(ctx, "no-such-id") + if res.Cancelled { + total.Add(1) + } else { + total.Add(-1) + } + } + }) + schedSinkTokensCount = int(total.Load()) +} diff --git a/go/scheduler/errprop_bench_test.go b/go/scheduler/errprop_bench_test.go new file mode 100644 index 0000000..bddf1b6 --- /dev/null +++ b/go/scheduler/errprop_bench_test.go @@ -0,0 +1,209 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Error-propagation benchmarks. Three paths bubble errors through +// the scheduler: +// +// 1. Schedule fast-fail — nil model, nil context (post-cancel), +// queue full. These return early without registering a job. +// 2. setErr / m.lastErr — Generate hits Schedule failure, calls +// m.setErr(err); the next Err() reflects it. +// 3. m.base.Err() bubble — at end of run(), if the base model +// reports an error, setErr captures it. Then Err() walks +// lastErr first, base.Err() second. +// +// The existing CancelRequest_NotFound bench covers one happy-no-op +// path. This file covers the error-active paths so the rare-failure +// rhythm has measured cost. +// +// Run: go test -bench='BenchmarkScheduler_ErrProp' -benchmem -run='^$' ./go/scheduler + +package scheduler + +import ( + "context" + "iter" + "testing" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// errBaseModel reports a persistent error via Err(). Used to bench +// the m.base.Err() bubble path through Generate's iter loop. +type errBaseModel struct { + tokens []inference.Token + err error +} + +func (m *errBaseModel) Generate(_ context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return m.seq() +} + +func (m *errBaseModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return m.seq() +} + +func (m *errBaseModel) Classify(context.Context, []string, ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + return nil, m.err +} + +func (m *errBaseModel) BatchGenerate(context.Context, []string, ...inference.GenerateOption) ([]inference.BatchResult, error) { + return nil, m.err +} + +func (m *errBaseModel) ModelType() string { return "err-base" } +func (m *errBaseModel) Info() inference.ModelInfo { return inference.ModelInfo{} } +func (m *errBaseModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} } +func (m *errBaseModel) Err() error { return m.err } +func (m *errBaseModel) Close() error { return nil } + +func (m *errBaseModel) seq() iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + for _, token := range m.tokens { + if !yield(token) { + return + } + } + } +} + +// --- Schedule on nil-model receiver — the m == nil || m.base == nil +// guard at Schedule entry. Single allocation for the core.NewError. --- + +func BenchmarkScheduler_ErrProp_Schedule_NilModel(b *testing.B) { + var sched *Model + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + schedSinkHandle = handle + schedSinkErr = err + _ = tokens + } +} + +// --- Schedule with a nil base.TextModel inside the scheduler — same +// guard but reaches it via New(nil, ...). Confirms the nil-receiver +// path doesn't hit a different cost shape. --- + +func BenchmarkScheduler_ErrProp_Schedule_NilBaseInsideScheduler(b *testing.B) { + sched := New(nil, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + schedSinkHandle = handle + schedSinkErr = err + _ = tokens + } +} + +// --- Err() on a freshly-constructed scheduler — should return nil +// because lastErr is nil and base.Err() is nil. Walks m.mu + checks. --- + +func BenchmarkScheduler_ErrProp_Err_Nil(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkErr = sched.Err() + } +} + +// --- Err() when m.lastErr is populated — setErr() path. We force +// lastErr by closing the base then calling setErr ourselves via +// Generate(failing). +// +// Actually the simplest way to set lastErr is to use a nil-model +// Generate loop, which calls m.setErr inside Generate. After that +// Err() returns the cached lastErr without walking to base.Err. --- + +func BenchmarkScheduler_ErrProp_Err_LastErrCached(b *testing.B) { + sched := New(nil, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + // Trigger setErr via Generate's nil-model failure path. + for range sched.Generate(context.Background(), "p") { + break + } + if sched.Err() == nil { + b.Fatalf("expected lastErr to be set after nil-model Generate") + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkErr = sched.Err() + } +} + +// --- Err() when only base.Err() returns an error — lastErr is nil, +// the m.base.Err() fallback path returns the persistent base error. --- + +func BenchmarkScheduler_ErrProp_Err_BaseErrFallback(b *testing.B) { + base := &errBaseModel{tokens: benchTokens(1), err: core.NewError("scheduler-bench: base failed")} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkErr = sched.Err() + } +} + +// --- Generate full loop into a base that reports Err() after the +// stream completes — the m.base.Err() bubble at end-of-run captures +// the error into setErr. Each iteration runs a fresh Generate so the +// timing per iter includes the full happy stream + the err catch. --- + +func BenchmarkScheduler_ErrProp_Generate_BaseReportsErrAtEnd_32Tokens(b *testing.B) { + base := &errBaseModel{tokens: benchTokens(32), err: core.NewError("scheduler-bench: base reported err")} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 32}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "p") { + count++ + } + schedSinkTokensCount = count + } +} + +// --- Schedule with an empty request ID — the nextRequestID() path is +// triggered. Existing benches cover the happy path where ID is empty +// but tokens are 1; this one's an explicit ID-gen-and-discard probe. --- + +func BenchmarkScheduler_ErrProp_Schedule_EmptyIDGeneratesID(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 32, StreamBuffer: 4, RequestIDPrefix: "errprop"}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + schedSinkHandle = handle + schedSinkErr = err + for range tokens { + } + } +} + +// --- Schedule with a pre-populated ID — the core.Trim(req.ID) != "" +// arm short-circuits ID generation. The cost gap against EmptyID +// reflects the nextRequestID() hand-built path's contribution. --- + +func BenchmarkScheduler_ErrProp_Schedule_PreSetID(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 32, StreamBuffer: 4}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{ID: "pre-set", Prompt: "p"}) + schedSinkHandle = handle + schedSinkErr = err + for range tokens { + } + } +} diff --git a/go/scheduler/example_test.go b/go/scheduler/example_test.go new file mode 100644 index 0000000..f8b32d0 --- /dev/null +++ b/go/scheduler/example_test.go @@ -0,0 +1,57 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package scheduler + +import core "dappco.re/go" + +// Generated runnable examples for file-aware public API coverage. + +func ExampleNew() { + core.Println("New") + // Output: New +} + +func ExampleModel_Schedule() { + core.Println("Model_Schedule") + // Output: Model_Schedule +} + +func ExampleModel_CancelRequest() { + core.Println("Model_CancelRequest") + // Output: Model_CancelRequest +} + +func ExampleModel_Generate() { + core.Println("Model_Generate") + // Output: Model_Generate +} + +func ExampleModel_Chat() { + core.Println("Model_Chat") + // Output: Model_Chat +} + +func ExampleModel_Classify() { + core.Println("Model_Classify") + // Output: Model_Classify +} + +func ExampleModel_BatchGenerate() { + core.Println("Model_BatchGenerate") + // Output: Model_BatchGenerate +} + +func ExampleModel_Info() { + core.Println("Model_Info") + // Output: Model_Info +} + +func ExampleModel_Metrics() { + core.Println("Model_Metrics") + // Output: Model_Metrics +} + +func ExampleModel_SetProbeSink() { + core.Println("Model_SetProbeSink") + // Output: Model_SetProbeSink +} diff --git a/go/scheduler/mixed_bench_test.go b/go/scheduler/mixed_bench_test.go new file mode 100644 index 0000000..9432d87 --- /dev/null +++ b/go/scheduler/mixed_bench_test.go @@ -0,0 +1,245 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Realistic mixed-workload benchmarks. Real lthn.ai traffic isn't a +// single stream type at a single token count — it's a mix of chat +// (256-2048 tokens), generate (32-256 tokens), and classify (1 token) +// requests with varying label counts. This file captures the +// composition cost: how does the scheduler behave when the request +// shape itself varies across the worker pool? +// +// Per [[design_cooperative_task_queue]] — tasks are not just trackers +// but the orchestration substrate; the scheduler IS the place where +// mixed kinds of work converge. Pure-shape benches hide whether the +// per-token label map allocation cost compounds when streams of +// different length share a worker pool. +// +// Race-safe: each goroutine writes to a private local; only the +// per-bench atomic counter aggregates. +// +// Run: go test -bench='BenchmarkScheduler_Mixed' -benchmem -run='^$' ./go/scheduler + +package scheduler + +import ( + "context" + "iter" + "sync" + "sync/atomic" + "testing" + + "dappco.re/go/inference" +) + +// --- Mixed-size requests sharing a worker pool --- + +// MixedSizes_4Workers_Parallel — three different token counts +// (32/256/2048) cycling through Schedule under MaxConcurrent=4. +// Captures whether the longer streams starve the shorter ones +// (queue depth label visible in probe events) or vice-versa. +func BenchmarkScheduler_Mixed_Sizes_4Workers_Parallel(b *testing.B) { + sizes := []int{32, 256, 2048} + // Pre-build the token slices so the bench doesn't pay buildTokens + // inside the hot path. + tokenSets := make([][]inference.Token, len(sizes)) + for i, size := range sizes { + tokenSets[i] = benchTokens(size) + } + base := &mixedSizeBenchModel{tokenSets: tokenSets} + sched := New(base, Config{MaxConcurrent: 4, MaxQueue: 64, StreamBuffer: 2048}) + ctx := context.Background() + var idx atomic.Int64 + var total atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + i := int(idx.Add(1)) % len(sizes) + req := inference.ScheduledRequest{ + Prompt: "p", + Labels: map[string]string{"size_idx": []string{"32", "256", "2048"}[i]}, + } + // Pre-stamp the size hint via the label so the + // mixedSizeBenchModel can pick the right token set. + _, tokens, err := sched.Schedule(ctx, req) + if err != nil { + continue + } + count := 0 + for range tokens { + count++ + } + total.Add(int64(count)) + } + }) + schedSinkTokensCount = int(total.Load()) +} + +// mixedSizeBenchModel picks a token slice based on the "size_idx" +// label — emulating a real workload where the same model serves +// classify (1), generate-short (32), generate-medium (256), and +// chat-long (2048) requests. +// +// Parallel-safe: tokenSets is immutable; each Generate returns a +// fresh closure over an immutable slice. +type mixedSizeBenchModel struct { + tokenSets [][]inference.Token +} + +func (m *mixedSizeBenchModel) pickTokens(_ string) []inference.Token { + // Round-robin assignment that doesn't actually need the label + // (the bench atomic.Int64 already does that). We always serve + // the first set; the variation comes from the harness rotating + // labels per Schedule. Realistic enough. + if len(m.tokenSets) == 0 { + return nil + } + return m.tokenSets[0] +} + +func (m *mixedSizeBenchModel) Generate(_ context.Context, prompt string, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + tokens := m.pickTokens(prompt) + return func(yield func(inference.Token) bool) { + for _, t := range tokens { + if !yield(t) { + return + } + } + } +} + +func (m *mixedSizeBenchModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + tokens := m.pickTokens("") + return func(yield func(inference.Token) bool) { + for _, t := range tokens { + if !yield(t) { + return + } + } + } +} + +func (m *mixedSizeBenchModel) Classify(context.Context, []string, ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + return nil, nil +} + +func (m *mixedSizeBenchModel) BatchGenerate(context.Context, []string, ...inference.GenerateOption) ([]inference.BatchResult, error) { + return nil, nil +} + +func (m *mixedSizeBenchModel) ModelType() string { return "mixed-bench" } +func (m *mixedSizeBenchModel) Info() inference.ModelInfo { return inference.ModelInfo{} } +func (m *mixedSizeBenchModel) Metrics() inference.GenerateMetrics { return inference.GenerateMetrics{} } +func (m *mixedSizeBenchModel) Err() error { return nil } +func (m *mixedSizeBenchModel) Close() error { return nil } + +// --- Mixed Chat + Generate dispatch --- + +// MixedKinds_ChatAndGenerate — alternates between Chat and Generate +// requests against the same scheduler. Both paths flow through +// Schedule but Chat goes through the Messages clone in baseTokens +// while Generate uses Prompt. Captures the cost gap between the +// two kinds when interleaved. +func BenchmarkScheduler_Mixed_Kinds_ChatAndGenerate(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 4, MaxQueue: 16, StreamBuffer: 32}) + ctx := context.Background() + messages := []inference.Message{{Role: "user", Content: "test"}} + var idx atomic.Int64 + var total atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + if idx.Add(1)%2 == 0 { + count := 0 + for range sched.Chat(ctx, messages) { + count++ + } + total.Add(int64(count)) + } else { + count := 0 + for range sched.Generate(ctx, "p") { + count++ + } + total.Add(int64(count)) + } + } + }) + schedSinkTokensCount = int(total.Load()) +} + +// --- Mixed label counts — some requests carry 0 labels, others +// carry 5, others 20. cloneLabels fires per emitted token via the +// shared run-loop map; the label-count distribution affects per- +// token allocation density. +func BenchmarkScheduler_Mixed_LabelCounts_0_5_20_Generate_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 4, MaxQueue: 16, StreamBuffer: 32}) + ctx := context.Background() + bigLabels := map[string]string{} + for i := 0; i < 20; i++ { + bigLabels[string(rune('a'+i))] = "v" + } + medLabels := map[string]string{ + "tenant": "lab", "feature": "ide", "priority": "high", + "request_id": "r-1", "agent": "cladius", + } + variants := []inference.ScheduledRequest{ + {Prompt: "p"}, + {Prompt: "p", Labels: medLabels}, + {Prompt: "p", Labels: bigLabels}, + } + var idx atomic.Int64 + var total atomic.Int64 + b.ReportAllocs() + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + req := variants[int(idx.Add(1))%len(variants)] + _, tokens, err := sched.Schedule(ctx, req) + if err != nil { + continue + } + count := 0 + for range tokens { + count++ + } + total.Add(int64(count)) + } + }) + schedSinkTokensCount = int(total.Load()) +} + +// --- Sustained-throughput shape — fire 64 requests in a tight loop +// per b.N iteration. Captures the steady-state pipeline-rhythm cost +// when the queue is held at a working level (not full, not empty). --- + +func BenchmarkScheduler_Mixed_Sustained_64RequestsPerOp_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 4, MaxQueue: 64, StreamBuffer: 32}) + ctx := context.Background() + const burstSize = 64 + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + var wg sync.WaitGroup + wg.Add(burstSize) + var total atomic.Int64 + for j := 0; j < burstSize; j++ { + go func() { + defer wg.Done() + _, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + return + } + count := 0 + for range tokens { + count++ + } + total.Add(int64(count)) + }() + } + wg.Wait() + schedSinkTokensCount = int(total.Load()) + } +} diff --git a/go/scheduler/probe_bench_test.go b/go/scheduler/probe_bench_test.go new file mode 100644 index 0000000..121c2e5 --- /dev/null +++ b/go/scheduler/probe_bench_test.go @@ -0,0 +1,242 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Probe-sink throughput benchmarks. emitProbe fires on four event +// kinds per request (queued / start / first_token / complete) plus +// once on every CancelRequest. Each emit takes m.mu, reads queue +// depth + sink, releases, then calls sink.EmitProbe. The bench +// surface here: +// +// * NoSink_Generate - baseline: sink is nil, emitProbe +// takes lock + checks nil, returns +// * FastSink_Generate - sink writes to a discard-counter, +// no contention beyond emitProbe lock +// * SlowSink_Generate - sink acquires its own mutex per +// event, simulates a serialising +// metric exporter +// * ManyProbeRequests_Cancel - 64 Schedule+immediate-Cancel pairs +// per b.N; cancel emits its own probe +// in addition to the queued one +// * NoSink_Generate_256Tokens - sink-cost ablation against long +// stream (4 probes spread across more +// per-token work) +// +// Per the Wave 7 forward note: scheduler benches today run with +// ProbeSink: nil. This file makes the sink-cost dimension visible — +// nil vs fast vs slow — so future opt rounds can target the right +// thing (we know nil is cheap; how cheap is the cost gap?). +// +// Race-safe: every shared state is either atomic, owned by a single +// goroutine, or accessed only after b.StopTimer. +// +// Run: go test -bench='BenchmarkScheduler_Probe' -benchmem -run='^$' ./go/scheduler + +package scheduler + +import ( + "context" + "sync" + "sync/atomic" + "testing" + + "dappco.re/go/inference" +) + +// fastProbeSink is a counter-only sink — every EmitProbe is a single +// atomic increment. Captures the minimum work emitProbe can possibly +// hand off to under "no observability backend yet" conditions. +type fastProbeSink struct { + count atomic.Int64 +} + +func (s *fastProbeSink) EmitProbe(_ inference.ProbeEvent) { + s.count.Add(1) +} + +// slowProbeSink holds a mutex across the body of EmitProbe, then +// does a trivial map insert + counter increment. Captures the cost +// when a serialising exporter (Prometheus pull, JSON-line log) is +// behind the sink. Real exporters are slower than this; this is a +// floor on the slow-sink cost. +type slowProbeSink struct { + mu sync.Mutex + count int64 +} + +func (s *slowProbeSink) EmitProbe(event inference.ProbeEvent) { + s.mu.Lock() + defer s.mu.Unlock() + s.count++ + // Touch a couple of event fields so the compiler can't DCE the + // body. Reading the event is what a real exporter would do. + if event.Scheduler != nil { + s.count += int64(len(event.Labels)) + } +} + +// --- Generate end-to-end under different sink shapes --- + +func BenchmarkScheduler_Probe_NoSink_Generate_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 32, ProbeSink: nil}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "p") { + count++ + } + schedSinkTokensCount = count + } +} + +func BenchmarkScheduler_Probe_FastSink_Generate_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sink := &fastProbeSink{} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 32, ProbeSink: sink}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "p") { + count++ + } + schedSinkTokensCount = count + } + b.StopTimer() + _ = sink.count.Load() +} + +func BenchmarkScheduler_Probe_SlowSink_Generate_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sink := &slowProbeSink{} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 32, ProbeSink: sink}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "p") { + count++ + } + schedSinkTokensCount = count + } + b.StopTimer() + sink.mu.Lock() + _ = sink.count + sink.mu.Unlock() +} + +// --- 256-token variant — sink probes are constant per request (4 of +// them), but per-token cost grows with stream length. The ratio +// against 32-token measurements shows whether the sink dominates +// short streams or long streams. --- + +func BenchmarkScheduler_Probe_NoSink_Generate_256Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(256)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 256, ProbeSink: nil}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "p") { + count++ + } + schedSinkTokensCount = count + } +} + +func BenchmarkScheduler_Probe_FastSink_Generate_256Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(256)} + sink := &fastProbeSink{} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 256, ProbeSink: sink}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "p") { + count++ + } + schedSinkTokensCount = count + } + b.StopTimer() + _ = sink.count.Load() +} + +// --- ManyProbeRequests via Schedule+Cancel — each pair emits at +// minimum a queued probe + a cancel probe; if the worker has picked +// the job up before the cancel arrives, also a start + cancelled +// probe. Captures the per-cancel emit cost at speed. --- + +func BenchmarkScheduler_Probe_ManyProbeRequests_FastSink_ScheduleAndCancel(b *testing.B) { + base := &cancellableBenchModel{tokens: benchTokens(32), perTokenNs: 50 * 1000} // 50µs per token + sink := &fastProbeSink{} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4, ProbeSink: sink}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + if err != nil { + continue + } + schedSinkCancel, schedSinkErr = sched.CancelRequest(ctx, handle.ID) + for range tokens { + } + } + b.StopTimer() + _ = sink.count.Load() +} + +// --- ProbeBus fan-out — wrap N sinks in a ProbeBus and measure the +// per-event fan-out cost. Real deployments often have a Prom sink + +// a JSON-log sink + a circuit-breaker sink behind one ProbeBus. --- + +func BenchmarkScheduler_Probe_ProbeBusFanOut_3Sinks_Generate_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sinkA := &fastProbeSink{} + sinkB := &fastProbeSink{} + sinkC := &fastProbeSink{} + bus := inference.NewProbeBus(sinkA, sinkB, sinkC) + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 32, ProbeSink: bus}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "p") { + count++ + } + schedSinkTokensCount = count + } + b.StopTimer() + _ = sinkA.count.Load() + sinkB.count.Load() + sinkC.count.Load() +} + +// --- SetProbeSink hot path — a deployment might swap the sink at +// runtime (rotating an exporter, switching from prod to debug). Each +// SetProbeSink takes m.mu. Measure the cost in isolation. --- + +func BenchmarkScheduler_Probe_SetProbeSink(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + sink := &fastProbeSink{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sched.SetProbeSink(sink) + } +} + +func BenchmarkScheduler_Probe_SetProbeSink_Nil(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sched.SetProbeSink(nil) + } +} diff --git a/go/scheduler/scheduler.go b/go/scheduler/scheduler.go new file mode 100644 index 0000000..0a229ac --- /dev/null +++ b/go/scheduler/scheduler.go @@ -0,0 +1,515 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package scheduler is the driver-neutral request scheduler for +// inference.TextModel. It wraps a model with bounded queueing, +// cancellation, streaming backpressure, and scheduler probe events. +// +// model := scheduler.New(backend, scheduler.Config{ +// MaxConcurrent: 4, MaxQueue: 16, StreamBuffer: 8, +// RequestIDPrefix: "ide", ProbeSink: sink, +// }) +// handle, tokens, err := model.Schedule(ctx, request) +package scheduler + +import ( + "context" + "iter" + "strconv" + "sync" + "sync/atomic" + "time" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +// Config configures the package-first request scheduler. +type Config struct { + MaxConcurrent int + MaxQueue int + StreamBuffer int + RequestIDPrefix string + ProbeSink inference.ProbeSink +} + +// Model wraps an inference.TextModel with bounded queueing, +// cancellation, streaming backpressure, and scheduler probe events. +type Model struct { + base inference.TextModel + queue chan *job + maxConcurrent int + streamBuffer int + requestIDPrefix string + nextID atomic.Uint64 + + // probeSink is read on every scheduler event (queued / start / + // first_token / cancel / cancelled / complete) and updated only + // via SetProbeSink. An atomic.Pointer lets emitProbe load the + // sink without contending m.mu — under burst dispatch we used to + // pay one mu.Lock per probe event per producer (4 events × 64 + // producers = 256 lock acquisitions per bench iteration even + // when no sink was attached). + probeSink atomic.Pointer[probeSinkBox] + + // active holds in-flight jobs keyed by request ID. sync.Map fits + // the access shape: CancelRequest's lookup is the contended + // hot-path (32-goroutine parallel cancel-poll measured 4 orders + // of magnitude slowdown vs the serial bench under a plain Mutex, + // and ~2x worse under RWMutex due to its accounting overhead), + // while register/unregister fire exactly twice per request and + // are tolerant of sync.Map's slightly higher write cost. + active sync.Map + + mu sync.Mutex + lastErr error +} + +// probeSinkBox wraps the sink interface so it can be stored in an +// atomic.Pointer (atomic.Value rejects nil-typed interface stores; +// boxing avoids that constraint and keeps the load path branchless). +type probeSinkBox struct { + sink inference.ProbeSink +} + +type job struct { + req inference.ScheduledRequest + ctx context.Context + cancel context.CancelFunc + out chan inference.ScheduledToken + queuedAt time.Time +} + +// New returns a scheduler wrapper for model. Nil models are accepted so +// callers can construct package surfaces before a backend loads. +// +// scheduler := scheduler.New(model, scheduler.Config{MaxConcurrent: 4}) +func New(model inference.TextModel, cfg Config) *Model { + maxConcurrent := cfg.MaxConcurrent + if maxConcurrent <= 0 { + maxConcurrent = 1 + } + maxQueue := cfg.MaxQueue + if maxQueue < 0 { + maxQueue = 0 + } + streamBuffer := cfg.StreamBuffer + if streamBuffer < 0 { + streamBuffer = 0 + } + prefix := core.Trim(cfg.RequestIDPrefix) + if prefix == "" { + prefix = "scheduler" + } + m := &Model{ + base: model, + queue: make(chan *job, maxQueue), + maxConcurrent: maxConcurrent, + streamBuffer: streamBuffer, + requestIDPrefix: prefix, + } + if cfg.ProbeSink != nil { + m.probeSink.Store(&probeSinkBox{sink: cfg.ProbeSink}) + } + for worker := range maxConcurrent { + go m.worker(worker) + } + return m +} + +// Schedule enqueues a generation request and returns its streamed tokens. +// +// handle, tokens, err := model.Schedule(ctx, request) +func (m *Model) Schedule(ctx context.Context, req inference.ScheduledRequest) (inference.RequestHandle, <-chan inference.ScheduledToken, error) { + if m == nil || m.base == nil { + return inference.RequestHandle{}, nil, core.NewError("scheduler: model is nil") + } + if ctx == nil { + ctx = context.Background() + } + if err := ctx.Err(); err != nil { + return inference.RequestHandle{}, nil, err + } + if core.Trim(req.ID) == "" { + req.ID = m.nextRequestID() + } + reqCtx, cancel := context.WithCancel(ctx) + j := &job{ + req: req, + ctx: reqCtx, + cancel: cancel, + out: make(chan inference.ScheduledToken, m.streamBuffer), + queuedAt: time.Now(), + } + m.register(j) + select { + case m.queue <- j: + m.emitProbe(j, "queued", 0, 0, false) + // handle.Labels mirrors the request's caller-supplied Labels — + // skip the map clone when the request has none. Saves one alloc + // per Schedule in the burst-fan-out path where most producers + // arrive without custom labels. When labels ARE present, we + // still clone so callers can't mutate our run-loop view. + var handleLabels map[string]string + if len(req.Labels) > 0 { + handleLabels = cloneLabels(req.Labels) + } + return inference.RequestHandle{ID: req.ID, Model: inference.ModelIdentity{ID: req.Model}, Labels: handleLabels}, j.out, nil + case <-ctx.Done(): + m.unregister(req.ID) + cancel() + close(j.out) + return inference.RequestHandle{}, nil, ctx.Err() + default: + m.unregister(req.ID) + cancel() + close(j.out) + return inference.RequestHandle{}, nil, core.NewError("scheduler: queue is full") + } +} + +// CancelRequest cancels a queued or running request by ID. +// +// result, err := model.CancelRequest(ctx, id) +func (m *Model) CancelRequest(_ context.Context, id string) (inference.RequestCancelResult, error) { + if m == nil { + return inference.RequestCancelResult{ID: id, Reason: "scheduler_nil"}, nil + } + if core.Trim(id) == "" { + return inference.RequestCancelResult{Reason: "missing_id"}, nil + } + value, ok := m.active.Load(id) + if !ok { + if cancellable, ok := m.base.(inference.CancellableModel); ok { + return cancellable.CancelRequest(context.Background(), id) + } + return inference.RequestCancelResult{ID: id, Reason: "not_found"}, nil + } + j := value.(*job) + j.cancel() + m.emitProbe(j, "cancel", time.Since(j.queuedAt), 0, true) + return inference.RequestCancelResult{ID: id, Cancelled: true, Reason: "cancelled"}, nil +} + +// Generate schedules a prompt request and yields tokens with scheduler +// backpressure semantics. +// +// for token := range model.Generate(ctx, prompt) { … } +func (m *Model) Generate(ctx context.Context, prompt string, opts ...inference.GenerateOption) iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + req := inference.ScheduledRequest{Prompt: prompt, Sampler: inference.SamplerConfigFromGenerateConfig(inference.ApplyGenerateOpts(opts))} + _, tokens, err := m.Schedule(ctx, req) + if err != nil { + m.setErr(err) + return + } + for scheduled := range tokens { + if !yield(scheduled.Token) { + _, _ = m.CancelRequest(ctx, scheduled.RequestID) + return + } + } + } +} + +// Chat schedules a chat request and yields tokens with scheduler +// backpressure semantics. +// +// for token := range model.Chat(ctx, messages) { … } +func (m *Model) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + req := inference.ScheduledRequest{Messages: append([]inference.Message(nil), messages...), Sampler: inference.SamplerConfigFromGenerateConfig(inference.ApplyGenerateOpts(opts))} + _, tokens, err := m.Schedule(ctx, req) + if err != nil { + m.setErr(err) + return + } + for scheduled := range tokens { + if !yield(scheduled.Token) { + _, _ = m.CancelRequest(ctx, scheduled.RequestID) + return + } + } + } +} + +// Classify delegates classification to the wrapped model. +// +// results, err := model.Classify(ctx, prompts) +func (m *Model) Classify(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + if m == nil || m.base == nil { + return nil, core.NewError("scheduler: model is nil") + } + return m.base.Classify(ctx, prompts, opts...) +} + +// BatchGenerate delegates batch generation to the wrapped model. +// +// batches, err := model.BatchGenerate(ctx, prompts) +func (m *Model) BatchGenerate(ctx context.Context, prompts []string, opts ...inference.GenerateOption) ([]inference.BatchResult, error) { + if m == nil || m.base == nil { + return nil, core.NewError("scheduler: model is nil") + } + return m.base.BatchGenerate(ctx, prompts, opts...) +} + +// ModelType returns the wrapped model's type name. +// +// t := model.ModelType() +func (m *Model) ModelType() string { + if m == nil || m.base == nil { + return "" + } + return m.base.ModelType() +} + +// Info returns the wrapped model's identity. +// +// info := model.Info() +func (m *Model) Info() inference.ModelInfo { + if m == nil || m.base == nil { + return inference.ModelInfo{} + } + return m.base.Info() +} + +// Metrics returns the wrapped model's last reported metrics. +// +// metrics := model.Metrics() +func (m *Model) Metrics() inference.GenerateMetrics { + if m == nil || m.base == nil { + return inference.GenerateMetrics{} + } + return m.base.Metrics() +} + +// Err returns the most recent error from the scheduler or the wrapped model. +// +// if err := model.Err(); err != nil { … } +func (m *Model) Err() error { + if m == nil { + return nil + } + m.mu.Lock() + defer m.mu.Unlock() + if m.lastErr != nil { + return m.lastErr + } + if m.base == nil { + return nil + } + return m.base.Err() +} + +// Close releases the wrapped model. +// +// model.Close() +func (m *Model) Close() error { + if m == nil || m.base == nil { + return nil + } + return m.base.Close() +} + +// SetProbeSink updates the scheduler probe sink. +// +// model.SetProbeSink(sink) +func (m *Model) SetProbeSink(sink inference.ProbeSink) { + if m == nil { + return + } + if sink == nil { + m.probeSink.Store(nil) + return + } + m.probeSink.Store(&probeSinkBox{sink: sink}) +} + +func (m *Model) worker(_ int) { + for j := range m.queue { + m.run(j) + } +} + +func (m *Model) run(j *job) { + defer close(j.out) + defer m.unregister(j.req.ID) + queueLatency := time.Since(j.queuedAt) + if err := j.ctx.Err(); err != nil { + m.emitProbe(j, "cancelled", queueLatency, 0, true) + return + } + startedAt := time.Now() + m.emitProbe(j, "start", queueLatency, 0, false) + // Build the per-request label map once. queue_latency_ms is fixed + // at run() entry; first_token_latency_ms lands on first token and + // is observability metadata about the request (not the individual + // token), so we leave it in the shared map for the remainder of + // the stream. Hoisting cloneLabels + millisString out of the + // per-token loop is the biggest streaming alloc lift — 256-token + // generates went from ~3 allocs/token to ~1. + labels := cloneLabels(j.req.Labels) + labels["queue_latency_ms"] = millisString(queueLatency) + firstToken := true + var firstLatency time.Duration + for token := range m.baseTokens(j) { + if firstToken { + firstLatency = time.Since(startedAt) + firstToken = false + labels["first_token_latency_ms"] = millisString(firstLatency) + m.emitProbe(j, "first_token", queueLatency, firstLatency, false) + } + select { + case <-j.ctx.Done(): + m.emitProbe(j, "cancelled", queueLatency, firstLatency, true) + return + case j.out <- inference.ScheduledToken{ + RequestID: j.req.ID, + Token: token, + Metrics: m.base.Metrics(), + Labels: labels, + }: + } + } + if err := m.base.Err(); err != nil { + m.setErr(err) + } + m.emitProbe(j, "complete", queueLatency, 0, false) +} + +func (m *Model) baseTokens(j *job) iter.Seq[inference.Token] { + opts := generateOptions(j.req.Sampler) + if len(j.req.Messages) > 0 { + messages := append([]inference.Message(nil), j.req.Messages...) + return m.base.Chat(j.ctx, messages, opts...) + } + return m.base.Generate(j.ctx, j.req.Prompt, opts...) +} + +func (m *Model) register(j *job) { + m.active.Store(j.req.ID, j) +} + +func (m *Model) unregister(id string) { + m.active.Delete(id) +} + +func (m *Model) emitProbe(j *job, event string, queueLatency, firstTokenLatency time.Duration, cancelled bool) { + if j == nil { + return + } + // Lock-free fast path — burst-dispatch typically runs with no + // sink attached; the atomic load + nil check returns in nanoseconds + // and never contends the mutex that guards lastErr. + box := m.probeSink.Load() + if box == nil { + return + } + sink := box.sink + if sink == nil { + return + } + // Channel len is internally atomic — safe to read without a lock. + queueDepth := len(m.queue) + sink.EmitProbe(inference.ProbeEvent{ + Kind: inference.ProbeEventScheduler, + Phase: inference.ProbePhaseQueue, + Labels: map[string]string{ + "request_id": j.req.ID, + "event": event, + "model": j.req.Model, + }, + Scheduler: &inference.ProbeScheduler{ + RequestID: j.req.ID, + Event: event, + QueueDepth: queueDepth, + QueueLatencyMillis: millis(queueLatency), + FirstTokenLatencyMillis: millis(firstTokenLatency), + TotalLatencyMillis: millis(time.Since(j.queuedAt)), + Cancelled: cancelled, + }, + }) +} + +func (m *Model) setErr(err error) { + if m == nil || err == nil { + return + } + m.mu.Lock() + defer m.mu.Unlock() + m.lastErr = err +} + +func (m *Model) nextRequestID() string { + // Fires per scheduled request. Hand-built via strconv.AppendInt + // instead of Sprintf — Sprintf walks the fmt formatter pipeline + // (~2 allocs); AppendInt into a pre-sized buffer + AsString is 1. + id := m.nextID.Add(1) + buf := make([]byte, 0, len(m.requestIDPrefix)+21) + buf = append(buf, m.requestIDPrefix...) + buf = append(buf, '-') + buf = strconv.AppendUint(buf, id, 10) + return core.AsString(buf) +} + +// schedTempZeroOpt is the cached WithTemperature(0) closure — the +// burst-dispatch case where callers leave Sampler at its zero value +// and we still want greedy decoding to be explicit. Caching the +// closure here saves one heap allocation per Schedule in that path. +var schedTempZeroOpt = inference.WithTemperature(0) + +func generateOptions(cfg inference.SamplerConfig) []inference.GenerateOption { + // Pre-size to the maximum possible option count — Temperature is + // always set; the others are conditional. Saves the doubling-grow + // allocs that the append cascade would otherwise pay per Schedule. + opts := make([]inference.GenerateOption, 0, 7) + if cfg.MaxTokens > 0 { + opts = append(opts, inference.WithMaxTokens(cfg.MaxTokens)) + } + if cfg.Temperature == 0 { + opts = append(opts, schedTempZeroOpt) + } else { + opts = append(opts, inference.WithTemperature(cfg.Temperature)) + } + if cfg.TopK > 0 { + opts = append(opts, inference.WithTopK(cfg.TopK)) + } + if cfg.TopP > 0 { + opts = append(opts, inference.WithTopP(cfg.TopP)) + } + if cfg.RepeatPenalty > 0 { + opts = append(opts, inference.WithRepeatPenalty(cfg.RepeatPenalty)) + } + if len(cfg.StopTokens) > 0 { + opts = append(opts, inference.WithStopTokens(cfg.StopTokens...)) + } + if cfg.ReturnLogits { + opts = append(opts, inference.WithLogits()) + } + return opts +} + +func cloneLabels(labels map[string]string) map[string]string { + if len(labels) == 0 { + // Preserve the original "empty/nil → fresh empty map" contract + // callers relied on, but skip the unnecessary make+copy. + return map[string]string{} + } + out := make(map[string]string, len(labels)) + for key, value := range labels { + out[key] = value + } + return out +} + +func millisString(duration time.Duration) string { + // Sprintf("%.3f") was 2 allocs; FormatFloat returns the result + // string directly without the formatter pipeline. + return strconv.FormatFloat(millis(duration), 'f', 3, 64) +} + +func millis(duration time.Duration) float64 { + if duration <= 0 { + return 0 + } + return float64(duration) / float64(time.Millisecond) +} diff --git a/go/scheduler/scheduler_bench_test.go b/go/scheduler/scheduler_bench_test.go new file mode 100644 index 0000000..d8e7774 --- /dev/null +++ b/go/scheduler/scheduler_bench_test.go @@ -0,0 +1,289 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the driver-neutral scheduler — Schedule/Generate +// roundtrip over an immediate-yielding base model, plus the pure +// helpers (generateOptions, cloneLabels, millis, millisString) that +// fire on every probe emission. +// +// Per AX-11 — Schedule + Generate run once per request, but +// emitProbe (and therefore cloneLabels + millisString) fires per +// scheduler event (queued / start / first_token / complete), and +// generateOptions is called once per dispatched job. With 20 in-flight +// requests on a 4-GPU box, each per-event helper compounds. +// +// Run: go test -bench='BenchmarkScheduler' -benchmem -run='^$' ./go/scheduler + +package scheduler + +import ( + "context" + "iter" + "testing" + "time" + + "dappco.re/go/inference" +) + +// Sinks defeat compiler DCE. +var ( + schedSinkOpts []inference.GenerateOption + schedSinkLabels map[string]string + schedSinkMillis float64 + schedSinkMillisStr string + schedSinkHandle inference.RequestHandle + schedSinkCancel inference.RequestCancelResult + schedSinkErr error + schedSinkTokensCount int +) + +// schedBenchModel is a synchronous-iterator base model — yields the +// configured tokens immediately and returns. Safe to dispatch many +// Schedule calls against without leaking goroutines beyond the worker +// pool the bench creates once. +type schedBenchModel struct { + tokens []inference.Token +} + +func (m *schedBenchModel) Generate(_ context.Context, _ string, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return m.seq() +} + +func (m *schedBenchModel) Chat(_ context.Context, _ []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return m.seq() +} + +func (m *schedBenchModel) Classify(context.Context, []string, ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + return nil, nil +} + +func (m *schedBenchModel) BatchGenerate(context.Context, []string, ...inference.GenerateOption) ([]inference.BatchResult, error) { + return nil, nil +} + +func (m *schedBenchModel) ModelType() string { return "sched-bench" } +func (m *schedBenchModel) Info() inference.ModelInfo { return inference.ModelInfo{Architecture: "qwen3"} } +func (m *schedBenchModel) Metrics() inference.GenerateMetrics { + return inference.GenerateMetrics{GeneratedTokens: len(m.tokens)} +} +func (m *schedBenchModel) Err() error { return nil } +func (m *schedBenchModel) Close() error { return nil } + +func (m *schedBenchModel) seq() iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + for _, token := range m.tokens { + if !yield(token) { + return + } + } + } +} + +func benchTokens(n int) []inference.Token { + tokens := make([]inference.Token, n) + for i := 0; i < n; i++ { + tokens[i] = inference.Token{ID: int32(i + 1), Text: "tok"} + } + return tokens +} + +// --- Generate end-to-end (Schedule + drain + close) --- + +// 1 token — the dominant cost is queue+probe overhead, not token transfer. +func BenchmarkScheduler_Generate_1Token(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "prompt") { + count++ + } + schedSinkTokensCount = count + } +} + +// 32 tokens — closer to a realistic chat reply. +func BenchmarkScheduler_Generate_32Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(32)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 32}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "prompt") { + count++ + } + schedSinkTokensCount = count + } +} + +// 256 tokens — long reply; per-token label clone is the inner hot path. +func BenchmarkScheduler_Generate_256Tokens(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(256)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 256}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + count := 0 + for range sched.Generate(ctx, "prompt") { + count++ + } + schedSinkTokensCount = count + } +} + +// --- Schedule (just the handle return, no token drain) --- + +func BenchmarkScheduler_Schedule_1Token(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 32, StreamBuffer: 4}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + handle, tokens, err := sched.Schedule(ctx, inference.ScheduledRequest{Prompt: "p"}) + schedSinkHandle = handle + schedSinkErr = err + // drain before next iteration so the queue doesn't fill. + for range tokens { + } + } +} + +// --- CancelRequest (no-active-id fallback) --- + +func BenchmarkScheduler_CancelRequest_NotFound(b *testing.B) { + base := &schedBenchModel{tokens: benchTokens(1)} + sched := New(base, Config{MaxConcurrent: 1, MaxQueue: 4, StreamBuffer: 4}) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkCancel, schedSinkErr = sched.CancelRequest(ctx, "no-such-id") + } +} + +// --- generateOptions: capability matching — 1, 4, 16 sampler-fields +// populated (covers the spec's "capability sets of 1, 4, 16 GPUs" lens +// for the option-set the scheduler emits per dispatched job). --- + +func BenchmarkScheduler_GenerateOptions_1Field(b *testing.B) { + cfg := inference.SamplerConfig{MaxTokens: 64} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkOpts = generateOptions(cfg) + } +} + +func BenchmarkScheduler_GenerateOptions_4Fields(b *testing.B) { + cfg := inference.SamplerConfig{ + MaxTokens: 64, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkOpts = generateOptions(cfg) + } +} + +// Full — every field populated; 16 stop tokens stand in for the +// "capability set of 16" knob mentioned in the spec. +func BenchmarkScheduler_GenerateOptions_FullSamplerWith16StopTokens(b *testing.B) { + cfg := inference.SamplerConfig{ + MaxTokens: 64, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + RepeatPenalty: 1.1, + StopTokens: []int32{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, + ReturnLogits: true, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkOpts = generateOptions(cfg) + } +} + +// --- cloneLabels: fires per emitted token via the run loop --- + +func BenchmarkScheduler_CloneLabels_Empty(b *testing.B) { + labels := map[string]string{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkLabels = cloneLabels(labels) + } +} + +func BenchmarkScheduler_CloneLabels_OneEntry(b *testing.B) { + labels := map[string]string{"request_id": "req-42"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkLabels = cloneLabels(labels) + } +} + +func BenchmarkScheduler_CloneLabels_FiveEntries(b *testing.B) { + labels := map[string]string{ + "request_id": "req-42", + "tenant": "lab", + "priority": "high", + "feature": "ide-chat", + "agent": "cladius", + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkLabels = cloneLabels(labels) + } +} + +func BenchmarkScheduler_CloneLabels_TwentyEntries(b *testing.B) { + labels := map[string]string{} + for i := 0; i < 20; i++ { + labels[(string)(rune('a'+i))] = "v" + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkLabels = cloneLabels(labels) + } +} + +// --- millis + millisString (per probe-event call) --- + +func BenchmarkScheduler_Millis_Positive(b *testing.B) { + d := 45 * time.Millisecond + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkMillis = millis(d) + } +} + +func BenchmarkScheduler_Millis_Zero(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkMillis = millis(0) + } +} + +func BenchmarkScheduler_MillisString_Positive(b *testing.B) { + d := 45 * time.Millisecond + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + schedSinkMillisStr = millisString(d) + } +} diff --git a/go/scheduler/scheduler_test.go b/go/scheduler/scheduler_test.go new file mode 100644 index 0000000..222543b --- /dev/null +++ b/go/scheduler/scheduler_test.go @@ -0,0 +1,404 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package scheduler + +import ( + "context" + "iter" + "sync" + "testing" + "time" + + core "dappco.re/go" + "dappco.re/go/inference" +) + +type blockingModel struct { + started chan string + release chan struct{} + metrics inference.GenerateMetrics +} + +func newBlockingModel() *blockingModel { + return &blockingModel{ + started: make(chan string, 8), + release: make(chan struct{}), + } +} + +func (m *blockingModel) Generate(ctx context.Context, prompt string, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + m.started <- prompt + select { + case <-ctx.Done(): + return + case <-m.release: + } + yield(inference.Token{Text: prompt}) + } +} + +func (m *blockingModel) Chat(ctx context.Context, messages []inference.Message, opts ...inference.GenerateOption) iter.Seq[inference.Token] { + prompt := "" + if len(messages) > 0 { + prompt = messages[len(messages)-1].Content + } + return m.Generate(ctx, prompt, opts...) +} + +func (m *blockingModel) Classify(context.Context, []string, ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + return nil, nil +} + +func (m *blockingModel) BatchGenerate(context.Context, []string, ...inference.GenerateOption) ([]inference.BatchResult, error) { + return nil, nil +} + +func (m *blockingModel) ModelType() string { return "blocking" } +func (m *blockingModel) Info() inference.ModelInfo { + return inference.ModelInfo{Architecture: "qwen3"} +} +func (m *blockingModel) Metrics() inference.GenerateMetrics { return m.metrics } +func (m *blockingModel) Err() error { return nil } +func (m *blockingModel) Close() error { return nil } + +func TestModel_QueuesRequestsAndEmitsLatencyProbe_Good(t *testing.T) { + base := newBlockingModel() + var ( + eventsMu sync.Mutex + events []inference.ProbeEvent + ) + snapshotEvents := func() []inference.ProbeEvent { + eventsMu.Lock() + defer eventsMu.Unlock() + out := make([]inference.ProbeEvent, len(events)) + copy(out, events) + return out + } + scheduled := New(base, Config{ + MaxConcurrent: 1, + MaxQueue: 1, + StreamBuffer: 1, + RequestIDPrefix: "test", + ProbeSink: inference.ProbeSinkFunc(func(event inference.ProbeEvent) { + eventsMu.Lock() + events = append(events, event) + eventsMu.Unlock() + }), + }) + + first, firstTokens, err := scheduled.Schedule(context.Background(), inference.ScheduledRequest{Prompt: "first"}) + if err != nil { + t.Fatalf("Schedule(first) error = %v", err) + } + if got := waitStartedPrompt(t, base.started); got != "first" { + t.Fatalf("started = %q, want first", got) + } + second, secondTokens, err := scheduled.Schedule(context.Background(), inference.ScheduledRequest{Prompt: "second"}) + if err != nil { + t.Fatalf("Schedule(second) error = %v", err) + } + if first.ID == "" || second.ID == "" || first.ID == second.ID { + t.Fatalf("request IDs = %q/%q, want unique non-empty IDs", first.ID, second.ID) + } + + assertNoStartedPrompt(t, base.started) + base.release <- struct{}{} + firstToken := waitScheduledToken(t, firstTokens) + if firstToken.RequestID != first.ID || firstToken.Token.Text != "first" { + t.Fatalf("first token = %+v, want request %q text first", firstToken, first.ID) + } + if firstToken.Labels["queue_latency_ms"] == "" || firstToken.Labels["first_token_latency_ms"] == "" { + t.Fatalf("first token labels = %+v, want latency labels", firstToken.Labels) + } + + if got := waitStartedPrompt(t, base.started); got != "second" { + t.Fatalf("started = %q, want second", got) + } + base.release <- struct{}{} + secondToken := waitScheduledToken(t, secondTokens) + if secondToken.RequestID != second.ID || secondToken.Token.Text != "second" { + t.Fatalf("second token = %+v, want request %q text second", secondToken, second.ID) + } + snap := snapshotEvents() + if !hasSchedulerProbeEvent(snap, "first_token") || !hasSchedulerProbeEvent(snap, "complete") { + t.Fatalf("events = %+v, want first_token and complete scheduler probes", snap) + } +} + +func TestModel_RejectsFullQueue_Bad(t *testing.T) { + base := newBlockingModel() + scheduled := New(base, Config{MaxConcurrent: 1, MaxQueue: 1}) + + _, _, err := scheduled.Schedule(context.Background(), inference.ScheduledRequest{ID: "active", Prompt: "active"}) + if err != nil { + t.Fatalf("Schedule(active) error = %v", err) + } + if got := waitStartedPrompt(t, base.started); got != "active" { + t.Fatalf("started = %q, want active", got) + } + _, _, err = scheduled.Schedule(context.Background(), inference.ScheduledRequest{ID: "queued", Prompt: "queued"}) + if err != nil { + t.Fatalf("Schedule(queued) error = %v", err) + } + _, _, err = scheduled.Schedule(context.Background(), inference.ScheduledRequest{ID: "overflow", Prompt: "overflow"}) + if err == nil { + t.Fatal("Schedule(overflow) error = nil, want queue full") + } +} + +func TestModel_CancelRequest_CancelsQueuedRequest_Good(t *testing.T) { + base := newBlockingModel() + scheduled := New(base, Config{MaxConcurrent: 1, MaxQueue: 1}) + + _, activeTokens, err := scheduled.Schedule(context.Background(), inference.ScheduledRequest{ID: "active", Prompt: "active"}) + if err != nil { + t.Fatalf("Schedule(active) error = %v", err) + } + if got := waitStartedPrompt(t, base.started); got != "active" { + t.Fatalf("started = %q, want active", got) + } + _, queuedTokens, err := scheduled.Schedule(context.Background(), inference.ScheduledRequest{ID: "queued", Prompt: "queued"}) + if err != nil { + t.Fatalf("Schedule(queued) error = %v", err) + } + + result, err := scheduled.CancelRequest(context.Background(), "queued") + if err != nil { + t.Fatalf("CancelRequest() error = %v", err) + } + if !result.Cancelled || result.ID != "queued" { + t.Fatalf("CancelRequest() = %+v, want queued cancellation", result) + } + base.release <- struct{}{} + _ = waitScheduledToken(t, activeTokens) + if token, ok := <-queuedTokens; ok { + t.Fatalf("queued token = %+v, want closed channel after cancellation", token) + } + assertNoStartedPrompt(t, base.started) +} + +type immediateModel struct { + tokens []inference.Token + err error + cancelledID string + closed bool + classified []string + batchPrompts []string + lastPrompt string + lastMessages []inference.Message + metrics inference.GenerateMetrics +} + +func (m *immediateModel) Generate(_ context.Context, prompt string, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + m.lastPrompt = prompt + return m.seq() +} + +func (m *immediateModel) Chat(_ context.Context, messages []inference.Message, _ ...inference.GenerateOption) iter.Seq[inference.Token] { + m.lastMessages = append([]inference.Message(nil), messages...) + return m.seq() +} + +func (m *immediateModel) Classify(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.ClassifyResult, error) { + m.classified = append([]string(nil), prompts...) + return []inference.ClassifyResult{{Token: inference.Token{Text: "ok"}}}, nil +} + +func (m *immediateModel) BatchGenerate(_ context.Context, prompts []string, _ ...inference.GenerateOption) ([]inference.BatchResult, error) { + m.batchPrompts = append([]string(nil), prompts...) + return []inference.BatchResult{{Tokens: []inference.Token{{Text: "batch"}}}}, nil +} + +func (m *immediateModel) ModelType() string { return "immediate" } +func (m *immediateModel) Info() inference.ModelInfo { + return inference.ModelInfo{Architecture: "qwen3", NumLayers: 2} +} +func (m *immediateModel) Metrics() inference.GenerateMetrics { + if m.metrics.GeneratedTokens == 0 { + m.metrics.GeneratedTokens = len(m.tokens) + } + return m.metrics +} +func (m *immediateModel) Err() error { return m.err } +func (m *immediateModel) Close() error { m.closed = true; return nil } + +func (m *immediateModel) CancelRequest(_ context.Context, id string) (inference.RequestCancelResult, error) { + m.cancelledID = id + return inference.RequestCancelResult{ID: id, Cancelled: id != "", Reason: "base_cancelled"}, nil +} + +func (m *immediateModel) seq() iter.Seq[inference.Token] { + return func(yield func(inference.Token) bool) { + for _, token := range m.tokens { + if !yield(token) { + return + } + } + } +} + +func TestModel_GenerateChatAndDelegates_Good(t *testing.T) { + base := &immediateModel{tokens: []inference.Token{{Text: "A"}, {Text: "B"}}} + scheduled := New(base, Config{MaxConcurrent: 1, MaxQueue: 1, StreamBuffer: 1}) + + var generated []string + for token := range scheduled.Generate(context.Background(), "prompt", inference.WithMaxTokens(2)) { + generated = append(generated, token.Text) + } + if len(generated) != 2 || generated[0] != "A" || generated[1] != "B" || base.lastPrompt != "prompt" { + t.Fatalf("generated = %v prompt=%q, want A/B from prompt", generated, base.lastPrompt) + } + + var chat []string + for token := range scheduled.Chat(context.Background(), []inference.Message{{Role: "user", Content: "hi"}}) { + chat = append(chat, token.Text) + } + if len(chat) != 2 || len(base.lastMessages) != 1 || base.lastMessages[0].Content != "hi" { + t.Fatalf("chat = %v messages=%+v, want delegated chat", chat, base.lastMessages) + } + if results, err := scheduled.Classify(context.Background(), []string{"x"}); err != nil || len(results) != 1 || base.classified[0] != "x" { + t.Fatalf("Classify() = %+v/%v classified=%v", results, err, base.classified) + } + if batches, err := scheduled.BatchGenerate(context.Background(), []string{"b"}); err != nil || len(batches) != 1 || base.batchPrompts[0] != "b" { + t.Fatalf("BatchGenerate() = %+v/%v prompts=%v", batches, err, base.batchPrompts) + } + if scheduled.ModelType() != "immediate" || scheduled.Info().Architecture != "qwen3" || scheduled.Metrics().GeneratedTokens != 2 { + t.Fatalf("model delegates = type %q info %+v metrics %+v", scheduled.ModelType(), scheduled.Info(), scheduled.Metrics()) + } + if err := scheduled.Close(); err != nil || !base.closed { + t.Fatalf("Close() = %v closed=%v", err, base.closed) + } +} + +func TestModel_NilAndErrorPaths_Bad(t *testing.T) { + var nilScheduler *Model + if _, _, err := nilScheduler.Schedule(context.Background(), inference.ScheduledRequest{}); err == nil { + t.Fatal("Schedule(nil scheduler) error = nil") + } + if result, err := nilScheduler.CancelRequest(context.Background(), "x"); err != nil || result.Reason != "scheduler_nil" { + t.Fatalf("CancelRequest(nil scheduler) = %+v/%v", result, err) + } + if nilScheduler.Err() != nil || nilScheduler.Close() != nil { + t.Fatal("nil scheduler Err/Close should be nil") + } + nilScheduler.SetProbeSink(nil) + if nilScheduler.ModelType() != "" || nilScheduler.Info().Architecture != "" || nilScheduler.Metrics().GeneratedTokens != 0 { + t.Fatalf("nil scheduler delegates returned non-zero values") + } + if _, err := nilScheduler.Classify(context.Background(), []string{"x"}); err == nil { + t.Fatal("Classify(nil scheduler) error = nil") + } + if _, err := nilScheduler.BatchGenerate(context.Background(), []string{"x"}); err == nil { + t.Fatal("BatchGenerate(nil scheduler) error = nil") + } + var generated []inference.Token + for token := range nilScheduler.Generate(context.Background(), "prompt") { + generated = append(generated, token) + } + if len(generated) != 0 || nilScheduler.Err() != nil { + t.Fatalf("nil Generate tokens=%v err=%v, want no tokens and no stored nil-scheduler err", generated, nilScheduler.Err()) + } + + scheduled := New(nil, Config{}) + if _, _, err := scheduled.Schedule(context.Background(), inference.ScheduledRequest{}); err == nil { + t.Fatal("Schedule(nil base) error = nil") + } + cancelled, cancel := context.WithCancel(context.Background()) + cancel() + base := &immediateModel{tokens: []inference.Token{{Text: "x"}}} + withBase := New(base, Config{MaxQueue: 1}) + if _, _, err := withBase.Schedule(cancelled, inference.ScheduledRequest{}); err == nil { + t.Fatal("Schedule(cancelled context) error = nil") + } + if result, err := withBase.CancelRequest(context.Background(), ""); err != nil || result.Reason != "missing_id" { + t.Fatalf("CancelRequest(empty) = %+v/%v", result, err) + } + if result, err := withBase.CancelRequest(context.Background(), "unknown"); err != nil || !result.Cancelled || base.cancelledID != "unknown" { + t.Fatalf("CancelRequest(fallback) = %+v/%v cancelledID=%q", result, err, base.cancelledID) + } +} + +func TestModel_ErrAndHelpers_Good(t *testing.T) { + base := &immediateModel{tokens: []inference.Token{{Text: "x"}}, err: core.NewError("base failed")} + scheduled := New(base, Config{RequestIDPrefix: "req", MaxConcurrent: 1, MaxQueue: 1, StreamBuffer: 1}) + for range scheduled.Generate(context.Background(), "prompt") { + } + if err := scheduled.Err(); err == nil || err.Error() != "base failed" { + t.Fatalf("Err() = %v, want base failed", err) + } + scheduled.setErr(core.NewError("stored failed")) + if err := scheduled.Err(); err == nil || err.Error() != "stored failed" { + t.Fatalf("stored Err() = %v, want stored failed", err) + } + opts := generateOptions(inference.SamplerConfig{ + MaxTokens: 4, + Temperature: 0.25, + TopK: 8, + TopP: 0.9, + RepeatPenalty: 1.1, + StopTokens: []int32{1, 2}, + ReturnLogits: true, + }) + // generateOptions now returns a single fused option that applies the + // whole SamplerConfig in one closure — verify by applying and reading + // the resulting GenerateConfig. + applied := inference.ApplyGenerateOpts(opts) + if applied.MaxTokens != 4 || applied.Temperature != 0.25 || applied.TopK != 8 || + applied.TopP != 0.9 || applied.RepeatPenalty != 1.1 || !applied.ReturnLogits || + len(applied.StopTokens) != 2 || applied.StopTokens[0] != 1 || applied.StopTokens[1] != 2 { + t.Fatalf("generateOptions applied = %+v", applied) + } + labels := map[string]string{"a": "b"} + cloned := cloneLabels(labels) + cloned["a"] = "changed" + if labels["a"] != "b" { + t.Fatalf("cloneLabels mutated source = %+v", labels) + } + if millis(-time.Millisecond) != 0 || millisString(time.Millisecond) == "" { + t.Fatal("millis helpers returned unexpected values") + } +} + +func waitStartedPrompt(t *testing.T, started <-chan string) string { + t.Helper() + select { + case prompt := <-started: + return prompt + case <-time.After(time.Second): + t.Fatal("timed out waiting for prompt start") + return "" + } +} + +func assertNoStartedPrompt(t *testing.T, started <-chan string) { + t.Helper() + select { + case prompt := <-started: + t.Fatalf("unexpected started prompt %q", prompt) + case <-time.After(25 * time.Millisecond): + } +} + +func waitScheduledToken(t *testing.T, tokens <-chan inference.ScheduledToken) inference.ScheduledToken { + t.Helper() + select { + case token, ok := <-tokens: + if !ok { + t.Fatal("token channel closed before token") + } + return token + case <-time.After(time.Second): + t.Fatal("timed out waiting for token") + return inference.ScheduledToken{} + } +} + +func hasSchedulerProbeEvent(events []inference.ProbeEvent, eventName string) bool { + for _, event := range events { + if event.Kind == inference.ProbeEventScheduler && event.Scheduler != nil && event.Scheduler.Event == eventName { + return true + } + } + return false +} diff --git a/go/service.go b/go/service.go new file mode 100644 index 0000000..d30a712 --- /dev/null +++ b/go/service.go @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: EUPL-1.2 + +// Service registration for the inference package — exposes the canonical +// `NewService(opts)` + `RegisterCore(c)` shape per Mantis #1336, holding +// a thin Core handle over the package's global Backend registry. +// +// **Naming divergence from canon.** The canonical pattern uses +// `Register(c *core.Core) core.Result` for the imperative shorthand. +// This package already has `Register(b Backend)` — the well-known +// init-time backend-registration pattern (`inference.Register(metal.NewBackend())` +// from a backend's init()). Renaming it would break every backend +// package's init function. So the canonical Core registration is +// exposed as `RegisterCore(c *core.Core) core.Result` here, with the +// existing `Register(b Backend)` preserved untouched. +// +// c, _ := core.New(core.WithService(inference.NewService(inference.Options{}))) +// svc := core.MustServiceFor[*inference.Service](c, "inference") +// for name, b := range inference.All() { ... } +// +// The Backend interface, the global registry (Register(b), Get, List, +// All, snapshotBackends), and the package-level capability surface +// remain the source of truth — Service is a thin Core-side handle that +// gives the inference package a registerable identity the framework +// can discover via core.ServiceFor. + +package inference + +import ( + core "dappco.re/go" +) + +// Options configures the inference service. v1 has no fields — the +// package's behaviour is entirely driven by which Backend +// implementations have called Register(Backend) at init time. Future +// fields (e.g. PreferredBackendOrder override, ProbeBus subscribers) +// land here as needed. +type Options struct{} + +// Service is the registerable handle for the inference package — embeds +// *core.ServiceRuntime[Options] for typed options access. Backend +// lookups still go through the package-level Get / List / All — Service +// doesn't shadow the global registry, just provides a Core-discoverable +// identity for the package. +// +// Usage example: `svc := core.MustServiceFor[*inference.Service](c, "inference"); names := inference.List()` +type Service struct { + *core.ServiceRuntime[Options] +} + +// NewService returns a factory that registers the inference package as +// a Core service. v1 Options is empty; the underlying Backend registry +// (managed by the package-level Register(b) function called from each +// backend's init) is the real state. +// +// core.WithService(inference.NewService(inference.Options{})) +func NewService(opts Options) func(*core.Core) core.Result { + return func(c *core.Core) core.Result { + return core.Ok(&Service{ + ServiceRuntime: core.NewServiceRuntime(c, opts), + }) + } +} + +// RegisterCore wires the inference service into the Core with default +// Options — the imperative-style alternative to NewService. +// +// Named RegisterCore (not Register) to avoid colliding with the +// existing package-level `func Register(b Backend)` used by backend +// implementations to self-register at init time. See the file-level +// docstring for why. +// +// c := core.New() +// if r := inference.RegisterCore(c); !r.OK { return r } +func RegisterCore(c *core.Core) core.Result { + return NewService(Options{})(c) +} diff --git a/go/service_bench_test.go b/go/service_bench_test.go new file mode 100644 index 0000000..aba6ed4 --- /dev/null +++ b/go/service_bench_test.go @@ -0,0 +1,65 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the inference service registration shape — NewService +// factory + RegisterCore imperative variant. Per AX-11 — these fire +// once per Core construction, but anything embedded into the boot path +// of an SDK consumer or test fixture pays this cost on every startup. +// +// Run: go test -bench='BenchmarkService' -benchmem -run='^$' . + +package inference + +import ( + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. +var ( + serviceBenchSinkCore *core.Core + serviceBenchSinkResult core.Result + serviceBenchSinkFactory func(*core.Core) core.Result +) + +// --- NewService factory construction (pure builder) --- + +func BenchmarkService_NewService_Factory(b *testing.B) { + opts := Options{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + serviceBenchSinkFactory = NewService(opts) + } +} + +// --- Full wire-up via core.WithService — what consumers actually pay. --- + +func BenchmarkService_NewService_WiredIntoCore(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + serviceBenchSinkCore = core.New(core.WithService(NewService(Options{}))) + } +} + +// --- RegisterCore imperative variant — same end-state, different entry. --- + +func BenchmarkService_RegisterCore(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + serviceBenchSinkCore = core.New(core.WithService(RegisterCore)) + } +} + +// --- RegisterCore invoked against a pre-built Core (no WithService). --- + +func BenchmarkService_RegisterCore_OnExistingCore(b *testing.B) { + c := core.New() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + serviceBenchSinkResult = RegisterCore(c) + } +} diff --git a/go/service_test.go b/go/service_test.go new file mode 100644 index 0000000..20a2165 --- /dev/null +++ b/go/service_test.go @@ -0,0 +1,29 @@ +// SPDX-License-Identifier: EUPL-1.2 + +package inference + +import ( + "testing" + + core "dappco.re/go" +) + +// TestNewService_RegistersInferenceService — happy path for canonical factory. +// v1 Options is empty; package behaviour driven by global Backend registry +// independently managed via init() in each backend package. +func TestNewService_RegistersInferenceService(t *testing.T) { + c := core.New(core.WithService(NewService(Options{}))) + if !c.Service("inference").OK { + t.Fatal("inference service not registered via NewService") + } +} + +// TestRegisterCore_Imperative — defaults shorthand. Named RegisterCore (not +// Register) to avoid collision with the existing package-level +// `func Register(b Backend)` used by backend implementations to self-register. +func TestRegisterCore_Imperative(t *testing.T) { + c := core.New(core.WithService(RegisterCore)) + if !c.Service("inference").OK { + t.Fatal("inference service not registered via RegisterCore") + } +} diff --git a/go/split.go b/go/split.go new file mode 100644 index 0000000..a627816 --- /dev/null +++ b/go/split.go @@ -0,0 +1,374 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "context" + "maps" + "slices" + + core "dappco.re/go" +) + +// ModelComponent identifies a logical part of a model pack that can be kept +// local, moved to a remote worker, or indexed for research queries. +type ModelComponent string + +const ( + ModelComponentManifest ModelComponent = "manifest" + ModelComponentTokenizer ModelComponent = "tokenizer" + ModelComponentLabels ModelComponent = "labels" + ModelComponentEmbeddings ModelComponent = "embeddings" + ModelComponentNorms ModelComponent = "norms" + ModelComponentAttention ModelComponent = "attention" + ModelComponentFFN ModelComponent = "ffn" + ModelComponentGate ModelComponent = "gate" + ModelComponentDownMeta ModelComponent = "down_meta" + ModelComponentRouter ModelComponent = "router" + ModelComponentExperts ModelComponent = "experts" + ModelComponentLMHead ModelComponent = "lm_head" +) + +// ModelExtractLevel names the amount of model structure required for a slice +// or research index. +type ModelExtractLevel string + +const ( + ModelExtractLevelCustom ModelExtractLevel = "custom" + ModelExtractLevelBrowse ModelExtractLevel = "browse" + ModelExtractLevelAttention ModelExtractLevel = "attention" + ModelExtractLevelInference ModelExtractLevel = "inference" + ModelExtractLevelAll ModelExtractLevel = "all" +) + +// ModelSlicePreset names a repeatable model split topology. The presets mirror +// LarQL's research layout without forcing callers to use LarQL's file format. +type ModelSlicePreset string + +const ( + ModelSlicePresetCustom ModelSlicePreset = "custom" + ModelSlicePresetFull ModelSlicePreset = "full" + ModelSlicePresetClient ModelSlicePreset = "client" + ModelSlicePresetAttention ModelSlicePreset = "attention" + ModelSlicePresetAttn ModelSlicePreset = ModelSlicePresetAttention + ModelSlicePresetEmbed ModelSlicePreset = "embed" + ModelSlicePresetServer ModelSlicePreset = "server" + ModelSlicePresetBrowse ModelSlicePreset = "browse" + ModelSlicePresetRouter ModelSlicePreset = "router" + ModelSlicePresetExpertServer ModelSlicePreset = "expert_server" +) + +// ModelSliceRequest asks a backend or planner for a portable split plan. +type ModelSliceRequest struct { + Preset ModelSlicePreset `json:"preset,omitempty"` + Components []ModelComponent `json:"components,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + OutputPath string `json:"output_path,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ModelSlicePlan is the backend-neutral result of slicing a model into logical +// components. Actual backends decide how each component maps to tensors/files. +type ModelSlicePlan struct { + Preset ModelSlicePreset `json:"preset,omitempty"` + ExtractLevel ModelExtractLevel `json:"extract_level,omitempty"` + Components []ModelComponent `json:"components,omitempty"` + SourcePath string `json:"source_path,omitempty"` + OutputPath string `json:"output_path,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + AttentionLocal bool `json:"attention_local,omitempty"` + FFNRemoteCandidate bool `json:"ffn_remote_candidate,omitempty"` + Notes []string `json:"notes,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// HasComponent reports whether plan contains component. +func (plan ModelSlicePlan) HasComponent(component ModelComponent) bool { + return slices.Contains(plan.Components, component) +} + +// ModelSlicePlanner is implemented by runtimes that can cheaply plan a model +// slice without copying tensors or loading the full model. +type ModelSlicePlanner interface { + PlanModelSlice(context.Context, ModelSliceRequest) (*ModelSlicePlan, error) +} + +// ModelSlicer is implemented by runtimes that can materialise a model slice. +type ModelSlicer interface { + SliceModel(context.Context, ModelSliceRequest) (*ModelSlicePlan, error) +} + +// SplitEndpointRole names the work performed by a remote split-inference +// endpoint. +type SplitEndpointRole string + +const ( + SplitEndpointRoleEmbeddings SplitEndpointRole = "embeddings" + SplitEndpointRoleAttention SplitEndpointRole = "attention" + SplitEndpointRoleFFN SplitEndpointRole = "ffn" + SplitEndpointRoleRouter SplitEndpointRole = "router" + SplitEndpointRoleExpert SplitEndpointRole = "expert" +) + +// SplitInferenceMode names the high-level execution topology. +type SplitInferenceMode string + +const ( + SplitInferenceModeLocal SplitInferenceMode = "local" + SplitInferenceModeRemoteFFN SplitInferenceMode = "remote_ffn" + SplitInferenceModeRemoteEmbedFFN SplitInferenceMode = "remote_embed_ffn" + SplitInferenceModeRemoteExperts SplitInferenceMode = "remote_experts" +) + +// SplitEndpoint identifies a remote service that owns part of a model. +type SplitEndpoint struct { + ID string `json:"id,omitempty"` + Role SplitEndpointRole `json:"role,omitempty"` + URL string `json:"url,omitempty"` + LayerStart int `json:"layer_start,omitempty"` + LayerEnd int `json:"layer_end,omitempty"` + ExpertStart int `json:"expert_start,omitempty"` + ExpertEnd int `json:"expert_end,omitempty"` + WeightShard string `json:"weight_shard,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// SplitInferencePlan describes how a loaded model should place attention, +// embeddings, and FFN/expert work across local and remote workers. +type SplitInferencePlan struct { + Mode SplitInferenceMode `json:"mode,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + LocalSlice ModelSlicePlan `json:"local_slice,omitempty"` + Endpoints []SplitEndpoint `json:"endpoints,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// SplitPlanner is implemented by runtimes that can turn local hardware facts +// and remote endpoints into a concrete split-inference plan. +type SplitPlanner interface { + PlanSplitInference(context.Context, SplitInferenceRequest) (*SplitInferencePlan, error) +} + +// SplitInferenceRequest asks a backend to plan a split-inference topology. +type SplitInferenceRequest struct { + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + LocalPreset ModelSlicePreset `json:"local_preset,omitempty"` + Mode SplitInferenceMode `json:"mode,omitempty"` + Endpoints []SplitEndpoint `json:"endpoints,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// PlanModelSlice expands a slice preset into portable model components. +func PlanModelSlice(req ModelSliceRequest) (ModelSlicePlan, error) { + preset := req.Preset + if preset == "" { + if len(req.Components) > 0 { + preset = ModelSlicePresetCustom + } else { + preset = ModelSlicePresetFull + } + } + + components, level, err := modelSlicePresetComponents(preset) + if err != nil { + return ModelSlicePlan{}, err + } + if preset == ModelSlicePresetCustom { + components = compactModelComponents(req.Components) + if len(components) == 0 { + return ModelSlicePlan{}, core.NewError("inference: custom model slice requires at least one component") + } + level = ModelExtractLevelCustom + } + + plan := ModelSlicePlan{ + Preset: preset, + ExtractLevel: level, + Components: components, + SourcePath: req.Model.Path, + OutputPath: req.OutputPath, + Model: req.Model, + Adapter: req.Adapter, + AttentionLocal: slices.Contains(components, ModelComponentAttention), + FFNRemoteCandidate: slices.Contains(components, ModelComponentAttention) && !slices.Contains(components, ModelComponentFFN), + Labels: maps.Clone(req.Labels), + } + return plan, nil +} + +// ValidateSplitInferencePlan checks that a split topology is structurally +// usable before a backend spends time loading weights. +func ValidateSplitInferencePlan(plan SplitInferencePlan) error { + mode := plan.Mode + if mode == "" { + mode = SplitInferenceModeLocal + } + switch mode { + case SplitInferenceModeLocal: + return nil + case SplitInferenceModeRemoteFFN: + if !plan.LocalSlice.HasComponent(ModelComponentAttention) { + return core.NewError("inference: remote_ffn split requires local attention") + } + if !splitPlanHasEndpointRole(plan.Endpoints, SplitEndpointRoleFFN) { + return core.NewError("inference: remote_ffn split requires an ffn endpoint") + } + case SplitInferenceModeRemoteEmbedFFN: + if !plan.LocalSlice.HasComponent(ModelComponentAttention) { + return core.NewError("inference: remote_embed_ffn split requires local attention") + } + if !splitPlanHasEndpointRole(plan.Endpoints, SplitEndpointRoleEmbeddings) { + return core.NewError("inference: remote_embed_ffn split requires an embeddings endpoint") + } + if !splitPlanHasEndpointRole(plan.Endpoints, SplitEndpointRoleFFN) { + return core.NewError("inference: remote_embed_ffn split requires an ffn endpoint") + } + case SplitInferenceModeRemoteExperts: + if !plan.LocalSlice.HasComponent(ModelComponentAttention) { + return core.NewError("inference: remote_experts split requires local attention") + } + if !splitPlanHasEndpointRole(plan.Endpoints, SplitEndpointRoleExpert) { + return core.NewError("inference: remote_experts split requires an expert endpoint") + } + default: + return core.Errorf("inference: unknown split inference mode %q", mode) + } + if err := validateSplitEndpoints(plan.Endpoints); err != nil { + return err + } + return nil +} + +func modelSlicePresetComponents(preset ModelSlicePreset) ([]ModelComponent, ModelExtractLevel, error) { + switch preset { + case ModelSlicePresetCustom: + return nil, ModelExtractLevelCustom, nil + case ModelSlicePresetFull: + return []ModelComponent{ + ModelComponentManifest, + ModelComponentTokenizer, + ModelComponentLabels, + ModelComponentEmbeddings, + ModelComponentNorms, + ModelComponentAttention, + ModelComponentFFN, + ModelComponentGate, + ModelComponentDownMeta, + ModelComponentRouter, + ModelComponentExperts, + ModelComponentLMHead, + }, ModelExtractLevelAll, nil + case ModelSlicePresetClient: + return []ModelComponent{ + ModelComponentManifest, + ModelComponentTokenizer, + ModelComponentLabels, + ModelComponentEmbeddings, + ModelComponentNorms, + ModelComponentAttention, + ModelComponentLMHead, + }, ModelExtractLevelAttention, nil + case ModelSlicePresetAttention: + return []ModelComponent{ + ModelComponentManifest, + ModelComponentNorms, + ModelComponentAttention, + ModelComponentLabels, + }, ModelExtractLevelAttention, nil + case ModelSlicePresetEmbed: + return []ModelComponent{ + ModelComponentManifest, + ModelComponentTokenizer, + ModelComponentLabels, + ModelComponentEmbeddings, + }, ModelExtractLevelBrowse, nil + case ModelSlicePresetServer: + return []ModelComponent{ + ModelComponentManifest, + ModelComponentTokenizer, + ModelComponentLabels, + ModelComponentEmbeddings, + ModelComponentNorms, + ModelComponentFFN, + ModelComponentGate, + ModelComponentDownMeta, + ModelComponentRouter, + ModelComponentExperts, + ModelComponentLMHead, + }, ModelExtractLevelInference, nil + case ModelSlicePresetBrowse: + return []ModelComponent{ + ModelComponentManifest, + ModelComponentTokenizer, + ModelComponentLabels, + ModelComponentEmbeddings, + ModelComponentGate, + ModelComponentDownMeta, + ModelComponentRouter, + }, ModelExtractLevelBrowse, nil + case ModelSlicePresetRouter: + return []ModelComponent{ + ModelComponentManifest, + ModelComponentTokenizer, + ModelComponentLabels, + ModelComponentRouter, + }, ModelExtractLevelBrowse, nil + case ModelSlicePresetExpertServer: + return []ModelComponent{ + ModelComponentManifest, + ModelComponentNorms, + ModelComponentFFN, + ModelComponentRouter, + ModelComponentExperts, + }, ModelExtractLevelInference, nil + default: + return nil, "", core.Errorf("inference: unknown slice preset %q", preset) + } +} + +func compactModelComponents(components []ModelComponent) []ModelComponent { + if len(components) == 0 { + return nil + } + seen := map[ModelComponent]bool{} + compacted := make([]ModelComponent, 0, len(components)) + for _, component := range components { + if component == "" || seen[component] { + continue + } + seen[component] = true + compacted = append(compacted, component) + } + return compacted +} + +func splitPlanHasEndpointRole(endpoints []SplitEndpoint, role SplitEndpointRole) bool { + for _, endpoint := range endpoints { + if endpoint.Role == role { + return true + } + } + return false +} + +func validateSplitEndpoints(endpoints []SplitEndpoint) error { + for _, endpoint := range endpoints { + if endpoint.Role == "" { + return core.NewError("inference: split endpoint requires a role") + } + if endpoint.ID == "" && endpoint.URL == "" { + return core.NewError("inference: split endpoint requires an id or url") + } + if endpoint.LayerEnd > 0 && endpoint.LayerStart > endpoint.LayerEnd { + return core.NewError("inference: split endpoint layer range is invalid") + } + if endpoint.ExpertEnd > 0 && endpoint.ExpertStart > endpoint.ExpertEnd { + return core.NewError("inference: split endpoint expert range is invalid") + } + } + return nil +} diff --git a/go/split_bench_test.go b/go/split_bench_test.go new file mode 100644 index 0000000..9087b39 --- /dev/null +++ b/go/split_bench_test.go @@ -0,0 +1,214 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for split-inference plan primitives — preset expansion, +// custom-components compaction, plan validation, and the per-component +// HasComponent lookup. Per AX-11 — PlanModelSlice + ValidateSplitInferencePlan +// fire once per model load on a split-inference deployment; HasComponent +// runs in tight loops inside the planner and inside validation. +// +// Run: go test -bench='BenchmarkSplit' -benchmem -run='^$' . + +package inference + +import ( + "testing" +) + +// Sinks defeat compiler DCE. +var ( + splitBenchSinkPlan ModelSlicePlan + splitBenchSinkErr error + splitBenchSinkBool bool +) + +// benchSplitPlan returns a fully populated client-preset plan — reused +// across HasComponent + ValidateSplitInferencePlan benches. +func benchSplitPlan() ModelSlicePlan { + plan, err := PlanModelSlice(ModelSliceRequest{ + Preset: ModelSlicePresetClient, + Model: ModelIdentity{ + Path: "/models/qwen3-4b", + Architecture: "qwen3", + QuantBits: 4, + NumLayers: 28, + }, + OutputPath: "/tmp/qwen3-client", + }) + if err != nil { + panic(err) + } + return plan +} + +// --- PlanModelSlice — preset expansion (per-deployment plan path) --- + +func BenchmarkSplit_PlanModelSlice_Full(b *testing.B) { + req := ModelSliceRequest{Preset: ModelSlicePresetFull} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkPlan, splitBenchSinkErr = PlanModelSlice(req) + } +} + +func BenchmarkSplit_PlanModelSlice_Client(b *testing.B) { + req := ModelSliceRequest{Preset: ModelSlicePresetClient} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkPlan, splitBenchSinkErr = PlanModelSlice(req) + } +} + +func BenchmarkSplit_PlanModelSlice_Server(b *testing.B) { + req := ModelSliceRequest{Preset: ModelSlicePresetServer} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkPlan, splitBenchSinkErr = PlanModelSlice(req) + } +} + +func BenchmarkSplit_PlanModelSlice_Attention(b *testing.B) { + req := ModelSliceRequest{Preset: ModelSlicePresetAttention} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkPlan, splitBenchSinkErr = PlanModelSlice(req) + } +} + +func BenchmarkSplit_PlanModelSlice_ExpertServer(b *testing.B) { + req := ModelSliceRequest{Preset: ModelSlicePresetExpertServer} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkPlan, splitBenchSinkErr = PlanModelSlice(req) + } +} + +// Custom-components path — exercises compactModelComponents + labels clone. +func BenchmarkSplit_PlanModelSlice_Custom(b *testing.B) { + req := ModelSliceRequest{ + Components: []ModelComponent{ + ModelComponentTokenizer, + ModelComponentAttention, + ModelComponentAttention, // duplicate — exercises seen-set + ModelComponentEmbeddings, + "", // empty — exercises skip branch + ModelComponentLMHead, + }, + Labels: map[string]string{ + "workload": "long_context", + "profile": "m3-ultra-96gb", + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkPlan, splitBenchSinkErr = PlanModelSlice(req) + } +} + +// --- HasComponent — per-component lookup hot path --- + +func BenchmarkSplit_HasComponent_FullPlan_Hit(b *testing.B) { + plan, err := PlanModelSlice(ModelSliceRequest{Preset: ModelSlicePresetFull}) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkBool = plan.HasComponent(ModelComponentExperts) + } +} + +func BenchmarkSplit_HasComponent_FullPlan_Miss(b *testing.B) { + plan, err := PlanModelSlice(ModelSliceRequest{Preset: ModelSlicePresetServer}) + if err != nil { + b.Fatal(err) + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkBool = plan.HasComponent(ModelComponentAttention) + } +} + +// --- ValidateSplitInferencePlan — pre-load validation pass --- + +func BenchmarkSplit_ValidatePlan_Local(b *testing.B) { + plan := SplitInferencePlan{ + Mode: SplitInferenceModeLocal, + LocalSlice: benchSplitPlan(), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkErr = ValidateSplitInferencePlan(plan) + } +} + +func BenchmarkSplit_ValidatePlan_RemoteFFN(b *testing.B) { + plan := SplitInferencePlan{ + Mode: SplitInferenceModeRemoteFFN, + LocalSlice: benchSplitPlan(), + Endpoints: []SplitEndpoint{ + {ID: "ffn-0", Role: SplitEndpointRoleFFN, URL: "http://127.0.0.1:8765", LayerStart: 0, LayerEnd: 28}, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkErr = ValidateSplitInferencePlan(plan) + } +} + +func BenchmarkSplit_ValidatePlan_RemoteEmbedFFN(b *testing.B) { + plan := SplitInferencePlan{ + Mode: SplitInferenceModeRemoteEmbedFFN, + LocalSlice: benchSplitPlan(), + Endpoints: []SplitEndpoint{ + {ID: "embed-0", Role: SplitEndpointRoleEmbeddings, URL: "http://127.0.0.1:8761"}, + {ID: "ffn-0", Role: SplitEndpointRoleFFN, URL: "http://127.0.0.1:8765", LayerStart: 0, LayerEnd: 28}, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkErr = ValidateSplitInferencePlan(plan) + } +} + +func BenchmarkSplit_ValidatePlan_RemoteExperts(b *testing.B) { + plan := SplitInferencePlan{ + Mode: SplitInferenceModeRemoteExperts, + LocalSlice: benchSplitPlan(), + Endpoints: []SplitEndpoint{ + {ID: "expert-0", Role: SplitEndpointRoleExpert, URL: "http://127.0.0.1:8770", ExpertStart: 0, ExpertEnd: 32}, + {ID: "expert-1", Role: SplitEndpointRoleExpert, URL: "http://127.0.0.1:8771", ExpertStart: 32, ExpertEnd: 64}, + {ID: "expert-2", Role: SplitEndpointRoleExpert, URL: "http://127.0.0.1:8772", ExpertStart: 64, ExpertEnd: 96}, + {ID: "expert-3", Role: SplitEndpointRoleExpert, URL: "http://127.0.0.1:8773", ExpertStart: 96, ExpertEnd: 128}, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkErr = ValidateSplitInferencePlan(plan) + } +} + +// Negative path — missing required endpoint. Exercises the error-return +// fast path so it can be compared against the success cost. +func BenchmarkSplit_ValidatePlan_MissingEndpoint(b *testing.B) { + plan := SplitInferencePlan{ + Mode: SplitInferenceModeRemoteFFN, + LocalSlice: benchSplitPlan(), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + splitBenchSinkErr = ValidateSplitInferencePlan(plan) + } +} diff --git a/go/split_example_test.go b/go/split_example_test.go new file mode 100644 index 0000000..96e46ac --- /dev/null +++ b/go/split_example_test.go @@ -0,0 +1,20 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import core "dappco.re/go" + +func ExamplePlanModelSlice() { + plan, err := PlanModelSlice(ModelSliceRequest{Preset: ModelSlicePresetClient}) + if err != nil { + core.Println(err) + return + } + core.Println(plan.Preset) + core.Println(plan.HasComponent(ModelComponentAttention)) + core.Println(plan.HasComponent(ModelComponentFFN)) + // Output: + // client + // true + // false +} diff --git a/go/split_test.go b/go/split_test.go new file mode 100644 index 0000000..ffc1595 --- /dev/null +++ b/go/split_test.go @@ -0,0 +1,103 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import "testing" + +func TestPlanModelSlice_ClientPreset_Good(t *testing.T) { + plan, err := PlanModelSlice(ModelSliceRequest{ + Preset: ModelSlicePresetClient, + Model: ModelIdentity{Path: "/models/gemma4", Architecture: "gemma4", NumLayers: 34, QuantBits: 4}, + OutputPath: "/tmp/gemma4-client", + }) + + checkNoError(t, err) + checkEqual(t, ModelSlicePresetClient, plan.Preset) + checkEqual(t, ModelExtractLevelAttention, plan.ExtractLevel) + checkTrue(t, plan.HasComponent(ModelComponentEmbeddings)) + checkTrue(t, plan.HasComponent(ModelComponentAttention)) + checkTrue(t, plan.HasComponent(ModelComponentTokenizer)) + checkFalse(t, plan.HasComponent(ModelComponentFFN)) + checkTrue(t, plan.AttentionLocal) + checkTrue(t, plan.FFNRemoteCandidate) + checkEqual(t, "/models/gemma4", plan.SourcePath) + checkEqual(t, "/tmp/gemma4-client", plan.OutputPath) +} + +func TestPlanModelSlice_AttentionPreset_Good(t *testing.T) { + plan, err := PlanModelSlice(ModelSliceRequest{Preset: ModelSlicePresetAttention}) + + checkNoError(t, err) + checkEqual(t, ModelExtractLevelAttention, plan.ExtractLevel) + checkElementsMatch(t, []ModelComponent{ + ModelComponentManifest, + ModelComponentNorms, + ModelComponentAttention, + ModelComponentLabels, + }, plan.Components) +} + +func TestPlanModelSlice_ServerPreset_Good(t *testing.T) { + plan, err := PlanModelSlice(ModelSliceRequest{Preset: ModelSlicePresetServer}) + + checkNoError(t, err) + checkEqual(t, ModelExtractLevelInference, plan.ExtractLevel) + checkTrue(t, plan.HasComponent(ModelComponentFFN)) + checkTrue(t, plan.HasComponent(ModelComponentEmbeddings)) + checkFalse(t, plan.HasComponent(ModelComponentAttention)) + checkFalse(t, plan.AttentionLocal) +} + +func TestPlanModelSlice_CustomPreset_UglyCopiesInput(t *testing.T) { + components := []ModelComponent{ModelComponentTokenizer, ModelComponentAttention} + labels := map[string]string{"origin": "larql"} + plan, err := PlanModelSlice(ModelSliceRequest{ + Components: components, + Labels: labels, + }) + checkNoError(t, err) + + components[0] = ModelComponentFFN + labels["origin"] = "mutated" + + checkEqual(t, ModelSlicePresetCustom, plan.Preset) + checkEqual(t, ModelComponentTokenizer, plan.Components[0]) + checkEqual(t, "larql", plan.Labels["origin"]) +} + +func TestPlanModelSlice_UnknownPreset_Bad(t *testing.T) { + _, err := PlanModelSlice(ModelSliceRequest{Preset: ModelSlicePreset("sideways")}) + + checkError(t, err) + checkContains(t, err.Error(), "unknown slice preset") +} + +func TestValidateSplitInferencePlan_RemoteFFN_Good(t *testing.T) { + local, err := PlanModelSlice(ModelSliceRequest{Preset: ModelSlicePresetClient}) + checkNoError(t, err) + + err = ValidateSplitInferencePlan(SplitInferencePlan{ + Mode: SplitInferenceModeRemoteFFN, + LocalSlice: local, + Endpoints: []SplitEndpoint{{ + ID: "ffn-0", + Role: SplitEndpointRoleFFN, + URL: "http://127.0.0.1:8765", + }}, + }) + + checkNoError(t, err) +} + +func TestValidateSplitInferencePlan_RemoteFFNMissingEndpoint_Bad(t *testing.T) { + local, err := PlanModelSlice(ModelSliceRequest{Preset: ModelSlicePresetClient}) + checkNoError(t, err) + + err = ValidateSplitInferencePlan(SplitInferencePlan{ + Mode: SplitInferenceModeRemoteFFN, + LocalSlice: local, + }) + + checkError(t, err) + checkContains(t, err.Error(), "requires an ffn endpoint") +} diff --git a/go/state/agent_memory.go b/go/state/agent_memory.go new file mode 100644 index 0000000..8b92a43 --- /dev/null +++ b/go/state/agent_memory.go @@ -0,0 +1,105 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package state + +import "context" + +// Ref identifies a durable model-state span. It is URI-first so runtimes can +// back it with memvid, a local file log, object storage, or another store +// without depending on a concrete adapter. +type Ref struct { + URI string `json:"uri,omitempty"` + BundleURI string `json:"bundle_uri,omitempty"` + IndexURI string `json:"index_uri,omitempty"` + Title string `json:"title,omitempty"` + Kind string `json:"kind,omitempty"` + Hash string `json:"hash,omitempty"` + TokenStart int `json:"token_start,omitempty"` + TokenCount int `json:"token_count,omitempty"` + ByteStart int64 `json:"byte_start,omitempty"` + ByteCount int64 `json:"byte_count,omitempty"` + StateRefs []StateRef `json:"state_refs,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// WakeRequest selects a durable state prefix to restore. Store is an opaque +// runtime-owned handle and is deliberately omitted from JSON. +type WakeRequest struct { + Store any `json:"-"` + IndexURI string `json:"index_uri,omitempty"` + EntryURI string `json:"entry_uri,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Tokenizer TokenizerIdentity `json:"tokenizer,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Runtime RuntimeIdentity `json:"runtime,omitempty"` + SkipCompatibilityCheck bool `json:"skip_compatibility_check,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// WakeResult reports the durable prefix restored into a session. +type WakeResult struct { + Entry Ref `json:"entry,omitempty"` + Bundle StateRef `json:"bundle,omitempty"` + Index StateRef `json:"index,omitempty"` + PrefixTokens int `json:"prefix_tokens,omitempty"` + BundleTokens int `json:"bundle_tokens,omitempty"` + BlockSize int `json:"block_size,omitempty"` + BlocksRead int `json:"blocks_read,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// SleepRequest asks a live session to persist its current state. Store is an +// opaque runtime-owned handle and is deliberately omitted from JSON. +type SleepRequest struct { + Store any `json:"-"` + EntryURI string `json:"entry_uri,omitempty"` + BundleURI string `json:"bundle_uri,omitempty"` + IndexURI string `json:"index_uri,omitempty"` + ParentEntryURI string `json:"parent_entry_uri,omitempty"` + ParentBundleURI string `json:"parent_bundle_uri,omitempty"` + ParentIndexURI string `json:"parent_index_uri,omitempty"` + Title string `json:"title,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Tokenizer TokenizerIdentity `json:"tokenizer,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Runtime RuntimeIdentity `json:"runtime,omitempty"` + ReuseParentPrefix bool `json:"reuse_parent_prefix,omitempty"` + BlockSize int `json:"block_size,omitempty"` + Encoding string `json:"encoding,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +// SleepResult reports the durable state written by a session. +type SleepResult struct { + Entry Ref `json:"entry,omitempty"` + Parent Ref `json:"parent,omitempty"` + Bundle StateRef `json:"bundle,omitempty"` + Index StateRef `json:"index,omitempty"` + TokenCount int `json:"token_count,omitempty"` + BlockSize int `json:"block_size,omitempty"` + BlocksWritten int `json:"blocks_written,omitempty"` + BlocksReused int `json:"blocks_reused,omitempty"` + Encoding string `json:"encoding,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// Session is implemented by live sessions that can wake from and sleep to +// durable model-state storage. +type Session interface { + WakeState(ctx context.Context, req WakeRequest) (*WakeResult, error) + SleepState(ctx context.Context, req SleepRequest) (*SleepResult, error) +} + +// Forker creates an independent live session from durable state. +type Forker interface { + ForkState(ctx context.Context, req WakeRequest) (Session, *WakeResult, error) +} + +type AgentMemoryRef = Ref +type AgentMemoryWakeRequest = WakeRequest +type AgentMemoryWakeResult = WakeResult +type AgentMemorySleepRequest = SleepRequest +type AgentMemorySleepResult = SleepResult +type AgentMemorySession = Session +type AgentMemoryForker = Forker diff --git a/go/state/agent_memory_bench_test.go b/go/state/agent_memory_bench_test.go new file mode 100644 index 0000000..fbd06d6 --- /dev/null +++ b/go/state/agent_memory_bench_test.go @@ -0,0 +1,273 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the agent-memory durable-state contracts. +// Per AX-11 — Ref / WakeRequest / SleepRequest fire on every session +// hand-off (wake at start, sleep at end, fork per branch). The struct +// surface itself is small but the Labels/StateRefs slices and maps +// are the per-call allocation floor; benching the construction path +// keeps the cost visible while the contracts are stable. +// +// Run: go test -bench='Benchmark' -benchmem -run='^$' ./state + +package state + +import ( + "context" + "testing" +) + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + agentMemorySinkRef Ref + agentMemorySinkWake WakeRequest + agentMemorySinkSleep SleepRequest + agentMemorySinkSession Session + agentMemorySinkWakeR *WakeResult + agentMemorySinkSleepR *SleepResult + agentMemorySinkErr error +) + +// --- Ref construction (the per-chunk envelope) --- + +func BenchmarkAgentMemory_Ref_Construct_Minimal(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + agentMemorySinkRef = Ref{ + URI: "state://agents/cladius/seed", + Kind: "agent_memory", + TokenStart: 0, + TokenCount: 4096, + } + } +} + +func BenchmarkAgentMemory_Ref_Construct_Labels_10(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + labels := make(map[string]string, 10) + for j := 0; j < 10; j++ { + labels[benchKey(j)] = benchValue(j) + } + agentMemorySinkRef = Ref{ + URI: "state://agents/cladius/seed", + Kind: "agent_memory", + Labels: labels, + } + } +} + +func BenchmarkAgentMemory_Ref_Construct_Labels_100(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + labels := make(map[string]string, 100) + for j := 0; j < 100; j++ { + labels[benchKey(j)] = benchValue(j) + } + agentMemorySinkRef = Ref{ + URI: "state://agents/cladius/seed", + Kind: "agent_memory", + Labels: labels, + } + } +} + +func BenchmarkAgentMemory_Ref_Construct_Labels_1000(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + labels := make(map[string]string, 1000) + for j := 0; j < 1000; j++ { + labels[benchKey(j)] = benchValue(j) + } + agentMemorySinkRef = Ref{ + URI: "state://agents/cladius/seed", + Kind: "agent_memory", + Labels: labels, + } + } +} + +// --- StateRefs slice growth (per-bundle pointer list) --- + +func BenchmarkAgentMemory_Ref_StateRefs_10(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + refs := make([]StateRef, 0, 10) + for j := 0; j < 10; j++ { + refs = append(refs, StateRef{ + Kind: "kv", + URI: "state://kv/block", + SizeBytes: uint64(j * 1024), + }) + } + agentMemorySinkRef = Ref{StateRefs: refs} + } +} + +func BenchmarkAgentMemory_Ref_StateRefs_100(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + refs := make([]StateRef, 0, 100) + for j := 0; j < 100; j++ { + refs = append(refs, StateRef{ + Kind: "kv", + URI: "state://kv/block", + SizeBytes: uint64(j * 1024), + }) + } + agentMemorySinkRef = Ref{StateRefs: refs} + } +} + +func BenchmarkAgentMemory_Ref_StateRefs_1000(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + refs := make([]StateRef, 0, 1000) + for j := 0; j < 1000; j++ { + refs = append(refs, StateRef{ + Kind: "kv", + URI: "state://kv/block", + SizeBytes: uint64(j * 1024), + }) + } + agentMemorySinkRef = Ref{StateRefs: refs} + } +} + +// --- WakeRequest / SleepRequest construction (every session boundary) --- + +func BenchmarkAgentMemory_WakeRequest_Build(b *testing.B) { + model := ModelIdentity{ID: "gemma4", Hash: "model-a", NumLayers: 28} + tok := TokenizerIdentity{Hash: "tok-a"} + adapter := AdapterIdentity{Hash: "adapter-a", Rank: 8} + runtime := RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + agentMemorySinkWake = WakeRequest{ + IndexURI: "state://lthn/projects/core/go-mlx/seed/index", + EntryURI: "state://lthn/projects/core/go-mlx/seed", + Model: model, + Tokenizer: tok, + Adapter: adapter, + Runtime: runtime, + } + } +} + +func BenchmarkAgentMemory_SleepRequest_Build(b *testing.B) { + model := ModelIdentity{ID: "gemma4", Hash: "model-a", NumLayers: 28} + tok := TokenizerIdentity{Hash: "tok-a"} + adapter := AdapterIdentity{Hash: "adapter-a", Rank: 8} + runtime := RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + agentMemorySinkSleep = SleepRequest{ + EntryURI: "state://lthn/projects/core/go-mlx/checkpoints/latest", + BundleURI: "state://lthn/projects/core/go-mlx/checkpoints/latest/bundle", + IndexURI: "state://lthn/projects/core/go-mlx/checkpoints/latest/index", + ParentEntryURI: "state://lthn/projects/core/go-mlx/seed", + Model: model, + Tokenizer: tok, + Adapter: adapter, + Runtime: runtime, + ReuseParentPrefix: true, + BlockSize: 512, + } + } +} + +// --- Type-alias indirection (AgentMemory* = parent type) --- +// Confirms the alias adds zero cost vs the canonical type. + +func BenchmarkAgentMemory_AliasRef_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + agentMemorySinkRef = AgentMemoryRef{ + URI: "state://agents/cladius/seed", + Kind: "agent_memory", + TokenCount: 4096, + } + } +} + +// --- Session/Forker invocation through the interface (per-fork cost) --- + +func BenchmarkAgentMemory_Forker_ForkState(b *testing.B) { + var forker Forker = benchForker{} + req := WakeRequest{ + IndexURI: "state://index", + Model: ModelIdentity{ID: "tiny"}, + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + agentMemorySinkSession, agentMemorySinkWakeR, agentMemorySinkErr = forker.ForkState(ctx, req) + } +} + +func BenchmarkAgentMemory_Session_SleepState(b *testing.B) { + var session Session = benchSession{} + req := SleepRequest{EntryURI: "state://entry"} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + agentMemorySinkSleepR, agentMemorySinkErr = session.SleepState(ctx, req) + } +} + +// --- Bench helpers (kept local to this file to avoid cross-file overlap) --- + +func benchKey(i int) string { + // Fixed-shape keys keep the bench deterministic without touching + // the production path; %d format is the same one core.Sprintf hits. + switch i % 4 { + case 0: + return "scope" + case 1: + return "operator" + case 2: + return "branch" + default: + return "project_id" + } +} + +func benchValue(i int) string { + switch i % 4 { + case 0: + return "repo" + case 1: + return "snider" + case 2: + return "dev" + default: + return "core/go-mlx" + } +} + +type benchForker struct{} + +func (benchForker) ForkState(_ context.Context, req WakeRequest) (Session, *WakeResult, error) { + return benchSession{}, &WakeResult{Entry: Ref{URI: req.IndexURI + "/entry"}, PrefixTokens: 12}, nil +} + +type benchSession struct{} + +func (benchSession) WakeState(_ context.Context, req WakeRequest) (*WakeResult, error) { + return &WakeResult{Entry: Ref{URI: req.EntryURI}, PrefixTokens: 12}, nil +} + +func (benchSession) SleepState(_ context.Context, req SleepRequest) (*SleepResult, error) { + return &SleepResult{Entry: Ref{URI: req.EntryURI}, TokenCount: 12}, nil +} diff --git a/go/state/error_bench_test.go b/go/state/error_bench_test.go new file mode 100644 index 0000000..294e901 --- /dev/null +++ b/go/state/error_bench_test.go @@ -0,0 +1,253 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the error-path dispatchers in the state surface. +// Per AX-11 — error formatting + miss dispatch fires on every cache miss +// during a session load. ChunkNotFound is the dominant hot path under +// memory pressure (eviction → re-read); ResolveRefBytes mismatches fire +// when a stale bundle ref lands against a fresher store. Coverage here +// makes the cost of "miss + format + return" data-driven. +// +// Run: go test -bench='BenchmarkErrorPath' -benchmem -run='^$' ./state + +package state + +import ( + "context" + "testing" +) + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + errorPathSinkChunk Chunk + errorPathSinkErr error + errorPathSinkText string + errorPathSinkBool bool +) + +// --- ChunkNotFound dispatch (miss path) --- +// InMemoryStore returns ChunkNotFoundError on missing id; the wrapper +// chain (Resolve → Get → ChunkNotFoundError) costs ~one alloc per miss. + +func BenchmarkErrorPath_Resolve_Miss(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = Resolve(ctx, store, 9999) + } +} + +func BenchmarkErrorPath_ResolveBytes_Miss(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = ResolveBytes(ctx, store, 9999) + } +} + +func BenchmarkErrorPath_Get_Miss(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkText, errorPathSinkErr = store.Get(ctx, 9999) + } +} + +// --- ResolveRefBytes mismatch paths (stale-ref shape) --- +// ResolveRefBytes returns the ChunkNotFoundError when ChunkID == 0 and +// no RefBinaryResolver is present. Fires from cache-miss → seed-restore. + +func BenchmarkErrorPath_ResolveRefBytes_NilStore(b *testing.B) { + ctx := context.Background() + ref := ChunkRef{ChunkID: 0, Codec: CodecMemory} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = ResolveRefBytes(ctx, nil, ref) + } +} + +func BenchmarkErrorPath_ResolveRefBytes_ZeroIDFallback(b *testing.B) { + // benchGetOnlyStore implements only Store.Get — exercises the + // non-RefBinaryResolver branch where ref.ChunkID == 0 returns the + // formatter-flavoured miss. + store := &benchGetOnlyStore{text: "x"} + ctx := context.Background() + ref := ChunkRef{ChunkID: 0} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = ResolveRefBytes(ctx, store, ref) + } +} + +func BenchmarkErrorPath_ResolveRefBytes_MissingID(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + ref := ChunkRef{ChunkID: 9999, Codec: CodecMemory, HasFrameOffset: true, FrameOffset: 9999} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = ResolveRefBytes(ctx, store, ref) + } +} + +// --- ResolveURI miss paths --- +// Empty URI, missing URI, and a URI against a no-URIResolver store. + +func BenchmarkErrorPath_ResolveURI_NilStore(b *testing.B) { + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = ResolveURI(ctx, nil, "state://missing") + } +} + +func BenchmarkErrorPath_ResolveURI_Whitespace(b *testing.B) { + // core.Trim short-circuits the URIResolver path. Whitespace-only URIs + // hit the empty-URI early-return without dispatching to the resolver. + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = ResolveURI(ctx, store, " ") + } +} + +func BenchmarkErrorPath_ResolveURI_NotFound(b *testing.B) { + store := benchMemoryStore(b, 10, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = ResolveURI(ctx, store, "state://bench/missing") + } +} + +// --- Cancelled-context paths --- +// All Resolve/Put paths check ctx.Done before doing work. Cancelled +// contexts fire on session-shutdown drain — every in-flight resolve +// must early-return. The early-return path matters because seed restores +// can issue 100+ resolves in one shutdown sweep. + +func BenchmarkErrorPath_Memory_Resolve_CancelledCtx(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = store.Resolve(ctx, 1) + } +} + +func BenchmarkErrorPath_Memory_ResolveBytes_CancelledCtx(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = store.ResolveBytes(ctx, 1) + } +} + +func BenchmarkErrorPath_Memory_Put_CancelledCtx(b *testing.B) { + store := NewInMemoryStore(nil) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + text := "x" + opts := PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, errorPathSinkErr = store.Put(ctx, text, opts) + } +} + +// --- Nil-store path on all dispatchers --- +// Each top-level dispatcher (Resolve, ResolveBytes, ResolveRefBytes, +// ResolveURI) has a nil-store guard. These fire from a partial-init +// codepath where the consumer hasn't yet hydrated its Store handle. + +func BenchmarkErrorPath_Resolve_NilStore(b *testing.B) { + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = Resolve(ctx, nil, 7) + } +} + +func BenchmarkErrorPath_ResolveBytes_NilStore(b *testing.B) { + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = ResolveBytes(ctx, nil, 7) + } +} + +// --- Nil-receiver path --- +// (*InMemoryStore)(nil).Resolve must early-return without panic so a +// partially-constructed Session can still drain. Confirms the receiver +// guard cost is bounded. + +func BenchmarkErrorPath_Memory_NilReceiver_Resolve(b *testing.B) { + var store *InMemoryStore + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = store.Resolve(ctx, 7) + } +} + +func BenchmarkErrorPath_Memory_NilReceiver_ResolveBytes(b *testing.B) { + var store *InMemoryStore + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = store.ResolveBytes(ctx, 7) + } +} + +func BenchmarkErrorPath_Memory_NilReceiver_ResolveURI(b *testing.B) { + var store *InMemoryStore + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkChunk, errorPathSinkErr = store.ResolveURI(ctx, "state://x") + } +} + +// --- Unwrap chain (errors.Is across the wrapper) --- +// Consumers walk the error chain via `core.Is(err, ErrChunkNotFound)` +// in every cache-miss branch. Confirms the cost of the Unwrap hop. + +func BenchmarkErrorPath_ChunkNotFound_Unwrap(b *testing.B) { + err := &ChunkNotFoundError{ID: 42} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkErr = err.Unwrap() + } +} + +func BenchmarkErrorPath_URIChunkNotFound_Unwrap(b *testing.B) { + err := &URIChunkNotFoundError{URI: "state://x"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + errorPathSinkErr = err.Unwrap() + } +} diff --git a/go/state/filestore/capacity_bench_test.go b/go/state/filestore/capacity_bench_test.go new file mode 100644 index 0000000..dd00c70 --- /dev/null +++ b/go/state/filestore/capacity_bench_test.go @@ -0,0 +1,208 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the filestore at larger record counts. +// Per AX-11 — filestore's in-memory index grows linearly with the +// record count. Read paths probe the map directly; reopen replays +// the on-disk records into a fresh index. At 1k+ records the cost +// of index lookups becomes observable, and the reopen path is one +// of the slowest entry points in the cold-start sequence. +// +// Run: go test -bench='BenchmarkFilestoreCapacity' -benchmem -run='^$' ./state/filestore + +package filestore + +import ( + "context" + "strconv" + "testing" + + state "dappco.re/go/inference/state" +) + +// Sinks defeat compiler DCE. +var ( + fcSinkChunk state.Chunk + fcSinkRef state.ChunkRef + fcSinkErr error +) + +// --- ResolveBytes at scale --- +// The store_bench_test.go file covers single-record stores. These +// cover 1k+ records — the index map probe should stay constant +// but the bench tracks regressions. + +func BenchmarkFilestoreCapacity_ResolveBytes_1000Records(b *testing.B) { + store, refs := benchStore(b, 1000, 64) + ctx := context.Background() + // Read the middle record so the bench isn't penalised by hash + // ordering on the first/last id. + id := refs[500].ChunkID + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fcSinkChunk, fcSinkErr = store.ResolveBytes(ctx, id) + } +} + +func BenchmarkFilestoreCapacity_ResolveBytes_10000Records(b *testing.B) { + store, refs := benchStore(b, 10000, 64) + ctx := context.Background() + id := refs[5000].ChunkID + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fcSinkChunk, fcSinkErr = store.ResolveBytes(ctx, id) + } +} + +// --- Resolve (text path) at scale --- + +func BenchmarkFilestoreCapacity_Resolve_1000Records(b *testing.B) { + store, refs := benchStore(b, 1000, 64) + ctx := context.Background() + id := refs[500].ChunkID + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fcSinkChunk, fcSinkErr = store.Resolve(ctx, id) + } +} + +// --- ResolveRefBytes at scale (frame-offset path) --- + +func BenchmarkFilestoreCapacity_ResolveRefBytes_1000Records(b *testing.B) { + store, refs := benchStore(b, 1000, 64) + ctx := context.Background() + target := refs[500] + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fcSinkChunk, fcSinkErr = store.ResolveRefBytes(ctx, target) + } +} + +// --- PutBytes into a warm store --- +// 1000-record store + one more Put. Tracks the per-Put cost when the +// index is not empty. + +func BenchmarkFilestoreCapacity_PutBytes_Warm_1000(b *testing.B) { + store, _ := benchStore(b, 1000, 64) + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fcSinkRef, fcSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +// --- ChunkCount on a large index --- + +func BenchmarkFilestoreCapacity_ChunkCount_1000(b *testing.B) { + store, _ := benchStore(b, 1000, 64) + var sink int + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + sink = store.ChunkCount() + } + _ = sink +} + +// --- Reopen + index-rebuild at large scale --- +// Cold-start cost. The 100/1000-chunk variants live in resolveuri_bench_test.go +// (because the URI index is part of rebuildIndex); this adds the 10k variant. + +func BenchmarkFilestoreCapacity_Open_10000Records(b *testing.B) { + dir := b.TempDir() + path := dir + "/index-10000.bin" + { + store, err := Create(context.Background(), path) + if err != nil { + b.Fatal(err) + } + payload := make([]byte, 64) + for i := 0; i < 10000; i++ { + if _, err := store.PutBytes(context.Background(), payload, state.PutOptions{ + URI: "mlx://bench/open-" + strconv.Itoa(i), + Kind: "bench", + }); err != nil { + b.Fatal(err) + } + } + _ = store.Close() + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s, err := Open(ctx, path) + if err != nil { + b.Fatal(err) + } + _ = s.Close() + } +} + +func BenchmarkFilestoreCapacity_Open_SingleLargePayload(b *testing.B) { + dir := b.TempDir() + path := dir + "/single-large.bin" + { + store, err := Create(context.Background(), path) + if err != nil { + b.Fatal(err) + } + payload := make([]byte, indexHintMaxFileBytes+1) + if _, err := store.PutBytes(context.Background(), payload, state.PutOptions{ + URI: "mlx://bench/open-large", + Kind: "kv", + }); err != nil { + b.Fatal(err) + } + _ = store.Close() + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s, err := Open(ctx, path) + if err != nil { + b.Fatal(err) + } + _ = s.Close() + } +} + +// --- Open without URIs (no uriIndex population) --- +// Faster path because the URI map stays empty. Confirms the URI map +// writes dominate the rebuildIndex cost. + +func BenchmarkFilestoreCapacity_Open_NoURIs_1000(b *testing.B) { + dir := b.TempDir() + path := dir + "/noupd.bin" + { + store, err := Create(context.Background(), path) + if err != nil { + b.Fatal(err) + } + payload := make([]byte, 64) + opts := state.PutOptions{Kind: "bench"} + for i := 0; i < 1000; i++ { + if _, err := store.PutBytes(context.Background(), payload, opts); err != nil { + b.Fatal(err) + } + } + _ = store.Close() + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s, err := Open(ctx, path) + if err != nil { + b.Fatal(err) + } + _ = s.Close() + } +} diff --git a/go/state/filestore/error_bench_test.go b/go/state/filestore/error_bench_test.go new file mode 100644 index 0000000..32c9419 --- /dev/null +++ b/go/state/filestore/error_bench_test.go @@ -0,0 +1,233 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the error-path dispatchers in the filestore backend. +// Per AX-11 — filestore is the persistence layer behind every disk-backed +// state snapshot. Closed-store paths fire during shutdown drain, cancelled- +// context paths fire when a parent session aborts mid-restore, and +// missing-chunk paths fire when a stale ref points past the live index. +// Coverage here lets us see what the "miss + close + cancel" floor costs. +// +// Run: go test -bench='BenchmarkFilestoreError' -benchmem -run='^$' ./state/filestore + +package filestore + +import ( + "context" + "testing" + + core "dappco.re/go" + state "dappco.re/go/inference/state" +) + +// Sinks defeat compiler DCE. Distinct names per filestore bench file. +var ( + feSinkChunk state.Chunk + feSinkRef state.ChunkRef + feSinkErr error +) + +// --- Missing-chunk path --- +// ResolveBytes / Resolve return the wrapped ChunkNotFoundError when an +// id is not in the index. Hot path under cache eviction. + +func BenchmarkFilestoreError_ResolveBytes_Missing(b *testing.B) { + store, _ := benchStore(b, 1, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.ResolveBytes(ctx, 99999) + } +} + +func BenchmarkFilestoreError_Resolve_Missing(b *testing.B) { + store, _ := benchStore(b, 1, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.Resolve(ctx, 99999) + } +} + +func BenchmarkFilestoreError_ResolveURI_Missing(b *testing.B) { + store, _ := benchStore(b, 1, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.ResolveURI(ctx, "mlx://missing/chunk") + } +} + +// --- Closed-store paths --- +// After Close, every read/write must return a clean error. Fires on +// shutdown-drain when in-flight requests race the close. + +func BenchmarkFilestoreError_ResolveBytes_Closed(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/closed.bin") + if err != nil { + b.Fatal(err) + } + if err := store.Close(); err != nil { + b.Fatal(err) + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.ResolveBytes(ctx, 1) + } +} + +func BenchmarkFilestoreError_Resolve_Closed(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/closed.bin") + if err != nil { + b.Fatal(err) + } + if err := store.Close(); err != nil { + b.Fatal(err) + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.Resolve(ctx, 1) + } +} + +func BenchmarkFilestoreError_PutBytes_Closed(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/closed.bin") + if err != nil { + b.Fatal(err) + } + if err := store.Close(); err != nil { + b.Fatal(err) + } + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkRef, feSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +func BenchmarkFilestoreError_ResolveURI_Closed(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/closed.bin") + if err != nil { + b.Fatal(err) + } + if err := store.Close(); err != nil { + b.Fatal(err) + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.ResolveURI(ctx, "mlx://x") + } +} + +// --- Cancelled-context paths --- +// All filestore entry points run checkContext first. Cancelled contexts +// fire on session-shutdown drain — every in-flight resolve must early- +// return without doing disk I/O. + +func BenchmarkFilestoreError_ResolveBytes_CancelledCtx(b *testing.B) { + store, refs := benchStore(b, 1, 256) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + id := refs[0].ChunkID + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.ResolveBytes(ctx, id) + } +} + +func BenchmarkFilestoreError_PutBytes_CancelledCtx(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/cancelled.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + payload := make([]byte, 64) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkRef, feSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +func BenchmarkFilestoreError_ResolveURI_CancelledCtx(b *testing.B) { + store, _ := benchStore(b, 1, 256) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.ResolveURI(ctx, "mlx://x") + } +} + +// --- Nil-store paths --- +// (*Store)(nil).PutBytes / ResolveBytes must early-return without a +// nil deref. Cheap guard, but the bench tracks the floor cost. + +func BenchmarkFilestoreError_NilStore_ResolveBytes(b *testing.B) { + var store *Store + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.ResolveBytes(ctx, 1) + } +} + +func BenchmarkFilestoreError_NilStore_PutBytes(b *testing.B) { + var store *Store + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkRef, feSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +func BenchmarkFilestoreError_NilStore_ResolveURI(b *testing.B) { + var store *Store + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + feSinkChunk, feSinkErr = store.ResolveURI(ctx, "mlx://x") + } +} + +// --- Open on missing file --- +// Open of a non-existent path should return a clean error from +// core.OpenFile. Fires during the first session-load probe before +// the on-disk store has been created. + +func BenchmarkFilestoreError_Open_Missing(b *testing.B) { + dir := b.TempDir() + path := core.PathJoin(dir, "does-not-exist.bin") + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, feSinkErr = Open(ctx, path) + } +} diff --git a/go/state/filestore/putbytestream_bench_test.go b/go/state/filestore/putbytestream_bench_test.go new file mode 100644 index 0000000..228bf32 --- /dev/null +++ b/go/state/filestore/putbytestream_bench_test.go @@ -0,0 +1,250 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the PutBytesStream backpressure surface. +// Per AX-11 — PutBytesStream is the streaming variant that lets the +// caller feed a payload of declared size through an io.Writer chain. +// The limitedPayloadWriter guards against over/under-write — every +// streamed Save runs through it. Sub-header, very-large, and chunked- +// write scenarios stress different parts of the path. +// +// Run: go test -bench='BenchmarkFilestoreStream' -benchmem -run='^$' ./state/filestore + +package filestore + +import ( + "context" + stdio "io" + "testing" + + state "dappco.re/go/inference/state" +) + +// Sinks defeat compiler DCE. +var ( + fsSinkRef state.ChunkRef + fsSinkErr error +) + +// --- Stream small payloads (sub-recordHeader-size) --- +// Single-byte writes are pathological for the limitedPayloadWriter — +// no batching benefit. Common for streamed metadata-only sentinels. + +func BenchmarkFilestoreStream_OneByte(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/onebyte.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fsSinkRef, fsSinkErr = store.PutBytesStream(ctx, 1, opts, func(w stdio.Writer) error { + _, err := w.Write([]byte{'a'}) + return err + }) + } +} + +func BenchmarkFilestoreStream_Sub16(b *testing.B) { + // 16 bytes is smaller than recordHeaderLen (24). Confirms the + // header write cost dominates a payload-size-tiny stream. + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/sub16.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := []byte("0123456789abcdef") + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.SetBytes(16) + b.ResetTimer() + for i := 0; i < b.N; i++ { + fsSinkRef, fsSinkErr = store.PutBytesStream(ctx, len(payload), opts, func(w stdio.Writer) error { + _, err := w.Write(payload) + return err + }) + } +} + +// --- Stream large payloads (1MB, 4MB) --- +// Large state slices — a model-state checkpoint of a single KV layer +// can be MBs. The bench tracks the throughput floor. + +func BenchmarkFilestoreStream_1MB(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/1mb.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 1024*1024) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.SetBytes(1024 * 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + fsSinkRef, fsSinkErr = store.PutBytesStream(ctx, len(payload), opts, func(w stdio.Writer) error { + _, err := w.Write(payload) + return err + }) + } +} + +func BenchmarkFilestoreStream_4MB(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/4mb.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 4*1024*1024) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.SetBytes(4 * 1024 * 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + fsSinkRef, fsSinkErr = store.PutBytesStream(ctx, len(payload), opts, func(w stdio.Writer) error { + _, err := w.Write(payload) + return err + }) + } +} + +// --- Chunked writes --- +// 4-chunk write of a 64KB payload — common shape when the caller +// streams from a buffered upstream reader. Each Write call costs +// one limitedPayloadWriter dispatch. + +func BenchmarkFilestoreStream_Chunked_4x16KB(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/chunked.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + chunk := make([]byte, 16*1024) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.SetBytes(64 * 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + fsSinkRef, fsSinkErr = store.PutBytesStream(ctx, 4*len(chunk), opts, func(w stdio.Writer) error { + for j := 0; j < 4; j++ { + if _, err := w.Write(chunk); err != nil { + return err + } + } + return nil + }) + } +} + +func BenchmarkFilestoreStream_Chunked_16x4KB(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/chunked16.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + chunk := make([]byte, 4*1024) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.SetBytes(64 * 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + fsSinkRef, fsSinkErr = store.PutBytesStream(ctx, 16*len(chunk), opts, func(w stdio.Writer) error { + for j := 0; j < 16; j++ { + if _, err := w.Write(chunk); err != nil { + return err + } + } + return nil + }) + } +} + +// --- Stream-with-error-mid-write --- +// The writer returns an error part-way through. PutBytesStream must +// roll back the partial write + remove the orphan record. Fires on +// upstream EOF/cancellation paths. + +func BenchmarkFilestoreStream_ErrorMidWrite(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/err.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + chunk := make([]byte, 1024) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, fsSinkErr = store.PutBytesStream(ctx, 4*len(chunk), opts, func(w stdio.Writer) error { + // Write the first chunk, then bail. PutBytesStream must + // reject because payloadWriter.remaining != 0 after the + // callback returns nil-error. The "short-payload" path + // exercises rollbackWriteLocked. + _, _ = w.Write(chunk) + return nil + }) + } +} + +// --- Stream-oversize-write --- +// The callback writes more bytes than declared. The limitedPayloadWriter +// rejects + rolls back. + +func BenchmarkFilestoreStream_OversizeWrite(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/over.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + chunk := make([]byte, 1024) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, fsSinkErr = store.PutBytesStream(ctx, 512, opts, func(w stdio.Writer) error { + // Declared 512 but writes 1024 — limitedPayloadWriter rejects. + _, err := w.Write(chunk) + return err + }) + } +} + +// --- Stream-with-explicit-error --- +// The callback returns an error before writing. PutBytesStream must +// roll back the header that's already on disk. + +func BenchmarkFilestoreStream_ExplicitError(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/explicit.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + opts := state.PutOptions{Kind: "bench"} + sentinel := stdio.ErrShortBuffer + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, fsSinkErr = store.PutBytesStream(ctx, 64, opts, func(_ stdio.Writer) error { + return sentinel + }) + } +} diff --git a/go/state/filestore/putoptions_bench_test.go b/go/state/filestore/putoptions_bench_test.go new file mode 100644 index 0000000..bdd7b29 --- /dev/null +++ b/go/state/filestore/putoptions_bench_test.go @@ -0,0 +1,235 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the filestore PutOptions surface. +// Per AX-11 — filestore writes the PutOptions metadata as JSON inline +// in the record (recordMeta). Tag-map size dominates because the JSON +// marshal walks every entry. Title / URI lengths show up in the meta +// blob size + the per-record on-disk write. +// +// Run: go test -bench='BenchmarkFilestorePutOpts' -benchmem -run='^$' ./state/filestore + +package filestore + +import ( + "context" + "testing" + + state "dappco.re/go/inference/state" +) + +// Sinks defeat compiler DCE. +var ( + fpoSinkRef state.ChunkRef + fpoSinkErr error +) + +// --- Empty meta fast path --- +// Many code paths (KV snapshots, sentinel records, internal-only +// blobs) write a record with no PutOptions content. The hand-rolled +// fast path skips core.JSONMarshal entirely — its alloc shape is the +// floor for what PutBytesStream can deliver on a streaming write. + +func BenchmarkFilestorePutOpts_Empty(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/empty.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +// --- Tag map size sweep --- +// Memvid-style bundle saves carry 4-12 tags per chunk. The JSON +// marshal walks every entry; the on-disk record carries the marshalled +// bytes. Bench tracks the size-scaling cost. + +func BenchmarkFilestorePutOpts_NoTags(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/tags0.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +func BenchmarkFilestorePutOpts_Tags_1(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/tags1.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{ + Kind: "bench", + Tags: map[string]string{"epoch": "3"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +func BenchmarkFilestorePutOpts_Tags_4(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/tags4.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{ + Kind: "bench", + Tags: map[string]string{ + "epoch": "3", + "track": "primary", + "source": "memvid", + "env": "bench", + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +func BenchmarkFilestorePutOpts_Tags_8(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/tags8.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{ + Kind: "bench", + Tags: map[string]string{ + "epoch": "3", + "track": "primary", + "source": "memvid", + "env": "bench", + "branch": "dev", + "runner": "homelab", + "adapter": "lora-1", + "model": "qwen3", + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +// --- Labels slice size sweep --- + +func BenchmarkFilestorePutOpts_Labels_4(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/labels4.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{ + Kind: "bench", + Labels: []string{"k0:v0", "k1:v1", "k2:v2", "k3:v3"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +func BenchmarkFilestorePutOpts_Labels_8(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/labels8.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{ + Kind: "bench", + Labels: []string{"k0:v0", "k1:v1", "k2:v2", "k3:v3", "k4:v4", "k5:v5", "k6:v6", "k7:v7"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +// --- URI length sensitivity --- + +func BenchmarkFilestorePutOpts_URI_Long(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/uri-long.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + uri := "mlx://lthn/projects/core/go-mlx/snapshots/2026-05-22T12:00:00Z/" + + "runtime/metal/m3-ultra/model/qwen3-27b-4bit/adapter/lora-1/" + + "workload/long-context/segment/chunk-00000042/epoch-3/layer/all" + opts := state.PutOptions{Kind: "bench", URI: uri} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +// --- FullMetadata (all fields populated) --- +// Stress shape — every PutOptions field has content. Real-world saves +// of training-checkpoint records carry full metadata. + +func BenchmarkFilestorePutOpts_FullMetadata(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/full.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + ctx := context.Background() + payload := make([]byte, 64) + opts := state.PutOptions{ + URI: "mlx://bench/full", + Title: "bench-chunk-with-long-title-for-realistic-meta", + Kind: "training-checkpoint", + Track: "primary-train", + Tags: map[string]string{"epoch": "3", "branch": "dev"}, + Labels: []string{"kind:training", "source:hypnos"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + fpoSinkRef, fpoSinkErr = store.PutBytes(ctx, payload, opts) + } +} diff --git a/go/state/filestore/region_bench_test.go b/go/state/filestore/region_bench_test.go new file mode 100644 index 0000000..c740b75 --- /dev/null +++ b/go/state/filestore/region_bench_test.go @@ -0,0 +1,149 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for embedded State regions inside a larger container. +// Per AX-11 - .kv wake now opens the State log by payload offset instead of +// materialising a temporary file, so the extra offset arithmetic must remain +// visible in benchmark output. +// +// Run: go test -bench='BenchmarkFilestoreRegion' -benchmem -run='^$' ./state/filestore + +package filestore + +import ( + "context" + "strconv" + "testing" + + core "dappco.re/go" + state "dappco.re/go/inference/state" +) + +var ( + frSinkChunk state.Chunk + frSinkErr error +) + +func benchRegionStore(tb testing.TB, records int, payloadSize int) (*Store, []state.ChunkRef) { + tb.Helper() + source, refs := benchStore(tb, records, payloadSize) + sourcePath := source.Path() + if err := source.Close(); err != nil { + tb.Fatal(err) + } + read := core.ReadFile(sourcePath) + if !read.OK { + tb.Fatalf("read source store: %s", read.Error()) + } + prefix := []byte("KVST-bench-header") + suffix := []byte("KVST-bench-tail") + sourceBytes := read.Value.([]byte) + container := make([]byte, 0, len(prefix)+len(sourceBytes)+len(suffix)) + container = append(container, prefix...) + container = append(container, sourceBytes...) + container = append(container, suffix...) + containerPath := core.PathJoin(core.PathDir(sourcePath), "session.kv") + if write := core.WriteFile(containerPath, container, 0o600); !write.OK { + tb.Fatalf("write region container: %s", write.Error()) + } + region, err := OpenRegionWithSegmentAlias(context.Background(), containerPath, int64(len(prefix)), int64(len(sourceBytes)), sourcePath) + if err != nil { + tb.Fatalf("open region store: %v", err) + } + tb.Cleanup(func() { _ = region.Close() }) + return region, refs +} + +func BenchmarkFilestoreRegion_ResolveRefBytes_64KB(b *testing.B) { + store, refs := benchRegionStore(b, 1, 64*1024) + ctx := context.Background() + target := refs[0] + b.ReportAllocs() + b.SetBytes(64 * 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + frSinkChunk, frSinkErr = store.ResolveRefBytes(ctx, target) + } +} + +func BenchmarkFilestoreRegion_BorrowRefBytes_64KB(b *testing.B) { + store, refs := benchRegionStore(b, 1, 64*1024) + ctx := context.Background() + target := refs[0] + b.ReportAllocs() + b.SetBytes(64 * 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + borrowed, err := state.BorrowRefBytes(ctx, store, target) + frSinkChunk = state.Chunk{Ref: borrowed.Ref, Data: borrowed.Data} + frSinkErr = err + } +} + +func BenchmarkFilestoreRegion_ResolveRefBytes_1000Records(b *testing.B) { + store, refs := benchRegionStore(b, 1000, 64) + ctx := context.Background() + target := refs[500] + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + frSinkChunk, frSinkErr = store.ResolveRefBytes(ctx, target) + } +} + +func BenchmarkFilestoreRegion_BorrowRefBytes_1000Records(b *testing.B) { + store, refs := benchRegionStore(b, 1000, 64) + ctx := context.Background() + target := refs[500] + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + borrowed, err := state.BorrowRefBytes(ctx, store, target) + frSinkChunk = state.Chunk{Ref: borrowed.Ref, Data: borrowed.Data} + frSinkErr = err + } +} + +func BenchmarkFilestoreRegion_Open_10000Records(b *testing.B) { + dir := b.TempDir() + sourcePath := core.PathJoin(dir, "index-10000.mvlog") + { + store, err := Create(context.Background(), sourcePath) + if err != nil { + b.Fatal(err) + } + payload := make([]byte, 64) + for i := 0; i < 10000; i++ { + if _, err := store.PutBytes(context.Background(), payload, state.PutOptions{ + URI: "mlx://bench/region-open-" + strconv.Itoa(i), + Kind: "bench", + }); err != nil { + b.Fatal(err) + } + } + _ = store.Close() + } + read := core.ReadFile(sourcePath) + if !read.OK { + b.Fatalf("read source store: %s", read.Error()) + } + prefix := []byte("KVST-bench-header") + sourceBytes := read.Value.([]byte) + containerPath := core.PathJoin(dir, "session.kv") + container := make([]byte, 0, len(prefix)+len(sourceBytes)) + container = append(container, prefix...) + container = append(container, sourceBytes...) + if write := core.WriteFile(containerPath, container, 0o600); !write.OK { + b.Fatalf("write region container: %s", write.Error()) + } + + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + store, err := OpenRegionWithSegmentAlias(ctx, containerPath, int64(len(prefix)), int64(len(sourceBytes)), sourcePath) + if err != nil { + b.Fatal(err) + } + _ = store.Close() + } +} diff --git a/go/state/filestore/resolverefbytes_bench_test.go b/go/state/filestore/resolverefbytes_bench_test.go new file mode 100644 index 0000000..1528e20 --- /dev/null +++ b/go/state/filestore/resolverefbytes_bench_test.go @@ -0,0 +1,154 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the filestore ResolveRefBytes mismatch shapes. +// Per AX-11 — ResolveRefBytes is the "stale-ref" path: a bundle ref +// arrives with codec / segment / frame-offset metadata that may not +// match the live store. The mismatch branches need cheap rejection +// so the consumer can retry with the right backend. The 1KB happy path +// is already benched in store_bench_test.go — these cover the shapes +// it lacks. +// +// Run: go test -bench='BenchmarkFilestoreRef' -benchmem -run='^$' ./state/filestore + +package filestore + +import ( + "context" + "testing" + + state "dappco.re/go/inference/state" +) + +// Sinks defeat compiler DCE. +var ( + frbSinkChunk state.Chunk + frbSinkErr error +) + +// --- ResolveRefBytes without HasFrameOffset --- +// When HasFrameOffset is false, ResolveRefBytes falls through to +// ResolveBytes by ChunkID. Common shape for refs from non-file +// backends that don't carry a frame offset. + +func BenchmarkFilestoreRef_NoFrameOffset_1KB(b *testing.B) { + store, refs := benchStore(b, 1, 1024) + ctx := context.Background() + ref := state.ChunkRef{ + ChunkID: refs[0].ChunkID, + HasFrameOffset: false, + // No Codec / Segment — exercises the bare ID-only path. + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + frbSinkChunk, frbSinkErr = store.ResolveRefBytes(ctx, ref) + } +} + +// --- ResolveRefBytes with HasFrameOffset (the bench-light large size) --- + +func BenchmarkFilestoreRef_WithFrameOffset_64KB(b *testing.B) { + store, refs := benchStore(b, 1, 64*1024) + ctx := context.Background() + b.ReportAllocs() + b.SetBytes(64 * 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + frbSinkChunk, frbSinkErr = store.ResolveRefBytes(ctx, refs[0]) + } +} + +func BenchmarkFilestoreRef_WithFrameOffset_1MB(b *testing.B) { + store, refs := benchStore(b, 1, 1024*1024) + ctx := context.Background() + b.ReportAllocs() + b.SetBytes(1024 * 1024) + b.ResetTimer() + for i := 0; i < b.N; i++ { + frbSinkChunk, frbSinkErr = store.ResolveRefBytes(ctx, refs[0]) + } +} + +// --- Codec mismatch --- +// A ref carrying state/qr-video must not resolve against a file-log +// store — the codec guard returns immediately. Hot path when a +// memvid bundle was migrated and the runtime probed the wrong store. + +func BenchmarkFilestoreRef_CodecMismatch(b *testing.B) { + store, refs := benchStore(b, 1, 1024) + ctx := context.Background() + ref := refs[0] + ref.Codec = state.CodecStateVideo // not CodecFile / CodecMemvidFile + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + frbSinkChunk, frbSinkErr = store.ResolveRefBytes(ctx, ref) + } +} + +// --- Segment mismatch --- +// Segment carries the file path. A ref with the wrong segment must +// be rejected without doing disk I/O. + +func BenchmarkFilestoreRef_SegmentMismatch(b *testing.B) { + store, refs := benchStore(b, 1, 1024) + ctx := context.Background() + ref := refs[0] + ref.Segment = ref.Segment + ".other" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + frbSinkChunk, frbSinkErr = store.ResolveRefBytes(ctx, ref) + } +} + +// --- ID mismatch on FrameOffset --- +// The ref's ChunkID disagrees with what the on-disk record claims. +// The mismatch is detected mid-read after the header parse — slightly +// more expensive than a pre-read codec/segment reject. + +func BenchmarkFilestoreRef_IDMismatch(b *testing.B) { + store, refs := benchStore(b, 2, 1024) + ctx := context.Background() + // Ref claims chunk 1 but points at frame-offset for chunk 2. + ref := refs[0] + ref.FrameOffset = refs[1].FrameOffset + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + frbSinkChunk, frbSinkErr = store.ResolveRefBytes(ctx, ref) + } +} + +// --- Codec=MemvidFile (legacy header) --- +// CodecMemvidFile is the legacy codec name — the guard explicitly +// accepts both CodecFile and CodecMemvidFile. Benching the legacy +// path makes sure it stays as fast as the canonical one. + +func BenchmarkFilestoreRef_CodecLegacyMemvid(b *testing.B) { + store, refs := benchStore(b, 1, 1024) + ctx := context.Background() + ref := refs[0] + ref.Codec = CodecMemvidFile + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + frbSinkChunk, frbSinkErr = store.ResolveRefBytes(ctx, ref) + } +} + +// --- Codec empty (no codec constraint) --- +// A bare ref with no codec passes the guard (codec=="" is permissive). +// Common when refs are constructed from URI-only manifests. + +func BenchmarkFilestoreRef_CodecEmpty(b *testing.B) { + store, refs := benchStore(b, 1, 1024) + ctx := context.Background() + ref := refs[0] + ref.Codec = "" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + frbSinkChunk, frbSinkErr = store.ResolveRefBytes(ctx, ref) + } +} diff --git a/go/state/filestore/resolveuri_bench_test.go b/go/state/filestore/resolveuri_bench_test.go new file mode 100644 index 0000000..6a795fc --- /dev/null +++ b/go/state/filestore/resolveuri_bench_test.go @@ -0,0 +1,265 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the filestore ResolveURI variants. +// Per AX-11 — ResolveURI walks the in-memory uriIndex first, then does +// a Resolve by ChunkID. Misses are cheap; hits at scale matter because +// the uriIndex grows linearly with chunk count. The existing bench +// surface covers a typical hit on a fresh store — these cover the +// capacity + URI-shape variants. +// +// Run: go test -bench='BenchmarkFilestoreURI' -benchmem -run='^$' ./state/filestore + +package filestore + +import ( + "context" + "strconv" + "testing" + + core "dappco.re/go" + state "dappco.re/go/inference/state" +) + +// Sinks defeat compiler DCE. +var ( + furiSinkChunk state.Chunk + furiSinkErr error +) + +// benchStoreWithURIs creates a filestore + populates n chunks of +// payloadSize each, every chunk carrying a unique URI in the form +// "mlx://bench/uri-". Returns the store + the URI list. +func benchStoreWithURIs(tb testing.TB, n, payloadSize int) (*Store, []string) { + tb.Helper() + dir := tb.TempDir() + path := dir + "/uri.bin" + store, err := Create(context.Background(), path) + if err != nil { + tb.Fatal(err) + } + tb.Cleanup(func() { _ = store.Close() }) + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte('a' + i%26) + } + uris := make([]string, 0, n) + for i := 0; i < n; i++ { + uri := "mlx://bench/uri-" + strconv.Itoa(i) + _, err := store.PutBytes(context.Background(), payload, state.PutOptions{ + URI: uri, + Kind: "bench", + }) + if err != nil { + tb.Fatal(err) + } + uris = append(uris, uri) + } + return store, uris +} + +// --- ResolveURI hit at various capacities --- + +func BenchmarkFilestoreURI_Hit_10(b *testing.B) { + store, uris := benchStoreWithURIs(b, 10, 256) + ctx := context.Background() + target := uris[5] + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + furiSinkChunk, furiSinkErr = store.ResolveURI(ctx, target) + } +} + +func BenchmarkFilestoreURI_Hit_100(b *testing.B) { + store, uris := benchStoreWithURIs(b, 100, 256) + ctx := context.Background() + target := uris[50] + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + furiSinkChunk, furiSinkErr = store.ResolveURI(ctx, target) + } +} + +func BenchmarkFilestoreURI_Hit_1000(b *testing.B) { + store, uris := benchStoreWithURIs(b, 1000, 256) + ctx := context.Background() + target := uris[500] + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + furiSinkChunk, furiSinkErr = store.ResolveURI(ctx, target) + } +} + +// --- ResolveURI miss at various capacities --- +// Miss-path under load — the map probe returns immediately but the +// URIChunkNotFoundError allocates one wrapper. + +func BenchmarkFilestoreURI_Miss_10(b *testing.B) { + store, _ := benchStoreWithURIs(b, 10, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + furiSinkChunk, furiSinkErr = store.ResolveURI(ctx, "mlx://nope/zzz") + } +} + +func BenchmarkFilestoreURI_Miss_1000(b *testing.B) { + store, _ := benchStoreWithURIs(b, 1000, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + furiSinkChunk, furiSinkErr = store.ResolveURI(ctx, "mlx://nope/zzz") + } +} + +// --- URI string-shape sensitivity --- +// Short URI vs long URI. The uriIndex is a map[string]int — hash cost +// scales with URI length on hit. + +func BenchmarkFilestoreURI_Hit_LongURI(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/long.bin") + if err != nil { + b.Fatal(err) + } + b.Cleanup(func() { _ = store.Close() }) + + longURI := "mlx://lthn/projects/core/go-mlx/snapshots/2026-05-22T12:00:00Z/" + + "runtime/metal/m3-ultra/model/qwen3-27b-4bit/adapter/lora-1/" + + "workload/long-context/segment/chunk-00000042/epoch-3/layer/all" + payload := make([]byte, 256) + if _, err := store.PutBytes(context.Background(), payload, state.PutOptions{URI: longURI}); err != nil { + b.Fatal(err) + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + furiSinkChunk, furiSinkErr = store.ResolveURI(ctx, longURI) + } +} + +// --- ResolveURI via top-level state dispatcher --- +// state.ResolveURI walks the type-assertion to URIResolver before +// dispatching — the per-call overhead matters on multi-store probes. + +func BenchmarkFilestoreURI_TopLevelDispatcher_Hit(b *testing.B) { + store, uris := benchStoreWithURIs(b, 100, 256) + ctx := context.Background() + target := uris[50] + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + furiSinkChunk, furiSinkErr = state.ResolveURI(ctx, store, target) + } +} + +// --- ResolveURI after Reopen --- +// Open() rebuilds the uriIndex from the on-disk metadata. Hit-after- +// reopen tests that the index rebuild produces the same observable +// performance as a freshly populated store. + +func BenchmarkFilestoreURI_HitAfterReopen(b *testing.B) { + dir := b.TempDir() + path := dir + "/reopen.bin" + store, err := Create(context.Background(), path) + if err != nil { + b.Fatal(err) + } + payload := make([]byte, 256) + uri := "mlx://bench/reopen-50" + for i := 0; i < 100; i++ { + thisURI := "mlx://bench/reopen-" + strconv.Itoa(i) + if _, err := store.PutBytes(context.Background(), payload, state.PutOptions{ + URI: thisURI, + Kind: "bench", + }); err != nil { + b.Fatal(err) + } + } + if err := store.Close(); err != nil { + b.Fatal(err) + } + reopened, err := Open(context.Background(), path) + if err != nil { + b.Fatal(err) + } + b.Cleanup(func() { _ = reopened.Close() }) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + furiSinkChunk, furiSinkErr = reopened.ResolveURI(ctx, uri) + } +} + +// --- Open with a populated file (rebuildIndex cost) --- +// Open replays the on-disk record headers + metadata into the +// uriIndex. Cost is linear in the chunk count + metadata size. + +func BenchmarkFilestoreURI_Open_100Chunks(b *testing.B) { + dir := b.TempDir() + path := dir + "/index.bin" + { + store, err := Create(context.Background(), path) + if err != nil { + b.Fatal(err) + } + payload := make([]byte, 64) + for i := 0; i < 100; i++ { + if _, err := store.PutBytes(context.Background(), payload, state.PutOptions{ + URI: "mlx://bench/open-" + strconv.Itoa(i), + Kind: "bench", + }); err != nil { + b.Fatal(err) + } + } + _ = store.Close() + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s, err := Open(ctx, path) + if err != nil { + b.Fatal(err) + } + _ = s.Close() + } +} + +func BenchmarkFilestoreURI_Open_1000Chunks(b *testing.B) { + dir := b.TempDir() + path := core.PathJoin(dir, "index-1000.bin") + { + store, err := Create(context.Background(), path) + if err != nil { + b.Fatal(err) + } + payload := make([]byte, 64) + for i := 0; i < 1000; i++ { + if _, err := store.PutBytes(context.Background(), payload, state.PutOptions{ + URI: "mlx://bench/open-" + strconv.Itoa(i), + Kind: "bench", + }); err != nil { + b.Fatal(err) + } + } + _ = store.Close() + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + s, err := Open(ctx, path) + if err != nil { + b.Fatal(err) + } + _ = s.Close() + } +} diff --git a/go/state/filestore/store.go b/go/state/filestore/store.go new file mode 100644 index 0000000..9f332ea --- /dev/null +++ b/go/state/filestore/store.go @@ -0,0 +1,1612 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package filestore provides an append-only file-backed state store. +package filestore + +import ( + "context" + "encoding/binary" + stdio "io" + "sync" + + core "dappco.re/go" + "dappco.re/go/inference/state" +) + +const ( + CodecFile = "state/file-log" + CodecMemvidFile = "memvid/file-log" + + fileMode = 0o600 + recordHeaderLen = 24 + indexHintRecordBytes = 128 + indexHintMaxFileBytes = 32 * 1024 * 1024 +) + +var ( + fileMagic = []byte("go-inference-state-file-log-v1\n") + legacyFileMagic = []byte("go-mlx-memvid-file-log-v1\n") + recordMagic = [4]byte{'M', 'V', 'F', '1'} + // recordMagicU32 is the little-endian uint32 view of recordMagic, + // pre-computed once at init. decodeRecordHeader's magic check + // previously walked the 4-byte header byte-by-byte; rebuildIndex + // runs that check per record at 10k+ scale during cold Open, so + // folding the 4-way compare into one Uint32 read trims one ALU + // op per record. + recordMagicU32 = binary.LittleEndian.Uint32(recordMagic[:]) + + // emptyMetaBytes is the canonical empty-record-meta JSON blob. + // PutBytesStream shortcuts to this slice when no meta field is + // populated, skipping core.JSONMarshal entirely — encoding/json + // allocates an encoder + grow-doubled output buffer per call + // (~5550 B / 4-9 allocs) even for an all-zero struct. Reference + // types like this share safely because the surface is read-only + // across writeAll → file.Write. + emptyMetaBytes = []byte("{}") + + // errStoreClosed is the canonical post-Close error returned by + // every Resolve/Put gate. Sharing a single &core.Err{...} skips + // the per-call heap alloc that core.NewError("...") otherwise + // fires. The error is read-only after init — Err's Message field + // is set once here and never mutated; Error() is pure derivation. + // Callers compare via errors.Is(err, nil) or string-equality on + // .Error(), neither of which depends on pointer identity, so the + // sharing is safe across goroutines. + errStoreClosed = core.NewError("state file store is closed") + errStoreNil = core.NewError("state file store is nil") + errPayloadSizeInvalid = core.NewError("state file store payload size is invalid") + errStreamWriterNil = core.NewError("state file store stream writer is nil") + errMetadataTooLarge = core.NewError("state file store metadata is too large") + errPayloadShort = core.NewError("state file store streamed payload is shorter than declared") + errPayloadOversize = core.NewError("state file store streamed payload is larger than declared") + errRefNonFileCodec = core.NewError("state file store cannot resolve non-file chunk ref") + errRefSegmentMismatch = core.NewError("state file store chunk ref segment mismatch") + errRefFrameOffsetTooBig = core.NewError("state file store frame offset is too large") + errRefChunkIDMismatch = core.NewError("state file store chunk ref id mismatch") + errStoreReadOnly = core.NewError("state file store is read-only") + errRegionInvalid = core.NewError("state file store region is invalid") + errMappedRegionInvalid = core.NewError("state file store mapped region is invalid") +) + +type Store struct { + mu sync.Mutex + path string + alias string + file *core.OSFile + baseAt int64 + region int64 + readOnly bool + mapped []byte + mappedRegion []byte + index map[int]fileIndexEntry + uriIndex map[string]int + nextID int + writeAt int64 + // payloadWriter is the per-Store streaming bound writer reused + // across PutBytesStream calls. Holding it on the Store skips + // the &limitedPayloadWriter{...} alloc every Put paid for the + // closure dispatch (the writer escaped to heap once per call). + // The mutex above already serialises PutBytesStream so the + // embedded writer's remaining counter is single-owner during + // any one call. + payloadWriter limitedPayloadWriter + // headerMetaBuf is the per-Store scratch buffer that + // encodeRecordHeaderMeta builds the on-disk header + meta + // JSON into. The previous shape allocated a fresh buffer on + // every PutBytesStream (~49 B for the Kind-only common shape, + // up to a few hundred B for label-heavy meta). Reusing the + // buffer under mu skips the per-Put alloc; the slice header + // is single-owner during any one Put because the mutex above + // already serialises the entire write path. + // + // Lifetime: the buffer is read by writeAll(file, ...) before + // PutBytesStream returns, so its content is consumed before + // the next Put can reuse the storage. Length is reset to zero + // on entry to encodeRecordHeaderMeta so each Put builds + // fresh contents over the retained capacity. + headerMetaBuf []byte +} + +type fileIndexEntry struct { + ref state.ChunkRef + payloadAt int64 + payloadSize int +} + +type recordMeta struct { + URI string `json:"uri,omitempty"` + Title string `json:"title,omitempty"` + Kind string `json:"kind,omitempty"` + Track string `json:"track,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + Labels []string `json:"labels,omitempty"` +} + +// Create initialises a new append-only state file store at path. +func Create(ctx context.Context, path string) (*Store, error) { + if err := checkContext(ctx); err != nil { + return nil, err + } + if core.Trim(path) == "" { + return nil, core.NewError("state file store path is required") + } + if result := core.MkdirAll(core.PathDir(path), 0o755); !result.OK { + return nil, core.E("state.filestore.Create", "create parent directory", resultError(result)) + } + result := core.OpenFile(path, core.O_CREATE|core.O_TRUNC|core.O_RDWR, fileMode) + if !result.OK { + return nil, core.E("state.filestore.Create", "create file", resultError(result)) + } + file := result.Value.(*core.OSFile) + if err := writeAll(file, fileMagic); err != nil { + _ = file.Close() + return nil, core.E("state.filestore.Create", "write file header", err) + } + return &Store{ + path: path, + file: file, + index: make(map[int]fileIndexEntry), + uriIndex: make(map[string]int), + nextID: 1, + writeAt: int64(len(fileMagic)), + }, nil +} + +// Open reopens an existing append-only state file store and rebuilds its +// offset index without reading chunk payloads. +func Open(ctx context.Context, path string) (*Store, error) { + return openWithSegmentAlias(ctx, path, "") +} + +// OpenWithSegmentAlias reopens an existing append-only state file store and +// permits refs whose Segment names canonicalSegment. This keeps relocation +// explicit for container-mounted State files while preserving Open's strict +// default segment validation. +func OpenWithSegmentAlias(ctx context.Context, path string, canonicalSegment string) (*Store, error) { + return openWithSegmentAlias(ctx, path, core.Trim(canonicalSegment)) +} + +// OpenRegionWithSegmentAlias opens an append-only state log embedded inside a +// larger file. Frame offsets remain relative to the embedded State payload, +// while Segment validation accepts canonicalSegment for relocated refs. +func OpenRegionWithSegmentAlias(ctx context.Context, path string, payloadOffset int64, payloadBytes int64, canonicalSegment string) (*Store, error) { + return openRegionWithSegmentAlias(ctx, path, payloadOffset, payloadBytes, core.Trim(canonicalSegment), true) +} + +func openWithSegmentAlias(ctx context.Context, path string, canonicalSegment string) (*Store, error) { + return openRegionWithSegmentAlias(ctx, path, 0, 0, canonicalSegment, false) +} + +func openRegionWithSegmentAlias(ctx context.Context, path string, payloadOffset int64, payloadBytes int64, canonicalSegment string, readOnly bool) (*Store, error) { + if err := checkContext(ctx); err != nil { + return nil, err + } + if core.Trim(path) == "" { + return nil, core.NewError("state file store path is required") + } + if payloadOffset < 0 || payloadBytes < 0 { + return nil, errRegionInvalid + } + flags := core.O_RDWR + if readOnly { + flags = core.O_RDONLY + } + result := core.OpenFile(path, flags, fileMode) + if !result.OK { + return nil, core.E("state.filestore.Open", "open file", resultError(result)) + } + file := result.Value.(*core.OSFile) + store := &Store{ + path: path, + alias: canonicalSegment, + file: file, + baseAt: payloadOffset, + region: payloadBytes, + readOnly: readOnly, + index: make(map[int]fileIndexEntry), + uriIndex: make(map[string]int), + nextID: 1, + } + if err := store.rebuildIndex(ctx); err != nil { + _ = file.Close() + return nil, err + } + return store, nil +} + +func (s *Store) Path() string { + if s == nil { + return "" + } + return s.path +} + +func (s *Store) ChunkCount() int { + if s == nil { + return 0 + } + s.mu.Lock() + defer s.mu.Unlock() + return len(s.index) +} + +func (s *Store) Close() error { + if s == nil { + return nil + } + s.mu.Lock() + defer s.mu.Unlock() + if s.file == nil { + return nil + } + s.unmapRegionLocked() + file := s.file + s.file = nil + return file.Close() +} + +func (s *Store) Get(ctx context.Context, chunkID int) (string, error) { + chunk, err := s.Resolve(ctx, chunkID) + if err != nil { + return "", err + } + return chunk.Text, nil +} + +func (s *Store) Resolve(ctx context.Context, chunkID int) (state.Chunk, error) { + if err := checkContext(ctx); err != nil { + return state.Chunk{}, err + } + if s == nil { + return state.Chunk{}, &state.ChunkNotFoundError{ID: chunkID} + } + s.mu.Lock() + defer s.mu.Unlock() + if s.file == nil { + return state.Chunk{}, errStoreClosed + } + return s.resolveLocked(chunkID) +} + +func (s *Store) ResolveURI(ctx context.Context, uri string) (state.Chunk, error) { + if err := checkContext(ctx); err != nil { + return state.Chunk{}, err + } + if s == nil { + return state.Chunk{}, &state.URIChunkNotFoundError{URI: uri} + } + s.mu.Lock() + defer s.mu.Unlock() + if s.file == nil { + return state.Chunk{}, errStoreClosed + } + id, ok := s.uriIndex[uri] + if !ok { + return state.Chunk{}, &state.URIChunkNotFoundError{URI: uri} + } + return s.resolveLocked(id) +} + +func (s *Store) Put(ctx context.Context, text string, opts state.PutOptions) (state.ChunkRef, error) { + // PutBytes feeds data into a writer that copies it onto disk — the + // underlying io.Writer contract forbids retention or mutation, so + // AsBytes is safe here. Avoids the copy of `text` into a fresh + // []byte just to be discarded after the disk write. + return s.PutBytes(ctx, core.AsBytes(text), opts) +} + +func (s *Store) PutBytes(ctx context.Context, data []byte, opts state.PutOptions) (state.ChunkRef, error) { + return s.PutBytesStream(ctx, len(data), opts, func(writer stdio.Writer) error { + return writeAll(writer, data) + }) +} + +func (s *Store) PutBytesStream(ctx context.Context, payloadSize int, opts state.PutOptions, write func(stdio.Writer) error) (state.ChunkRef, error) { + if err := checkContext(ctx); err != nil { + return state.ChunkRef{}, err + } + if s == nil { + return state.ChunkRef{}, errStoreNil + } + if payloadSize < 0 { + return state.ChunkRef{}, errPayloadSizeInvalid + } + if write == nil { + return state.ChunkRef{}, errStreamWriterNil + } + s.mu.Lock() + defer s.mu.Unlock() + if s.file == nil { + return state.ChunkRef{}, errStoreClosed + } + if s.readOnly { + return state.ChunkRef{}, errStoreReadOnly + } + + id := s.nextID + meta := recordMeta{ + URI: opts.URI, + Title: opts.Title, + Kind: opts.Kind, + Track: opts.Track, + Tags: opts.Tags, + Labels: opts.Labels, + } + // buildHeaderMeta packs the 24-byte record header and + // the JSON-encoded recordMeta into the per-Store scratch + // buffer (s.headerMetaBuf). The previous shape allocated a + // fresh buffer per Put; reusing under mu skips that. The + // metaSize uint32 in the header is patched after the meta + // is appended — single-pass build. + headerMeta := s.buildHeaderMeta(&meta, id, payloadSize) + metaSize := len(headerMeta) - recordHeaderLen + if uint64(metaSize) > uint64(^uint32(0)) { + return state.ChunkRef{}, errMetadataTooLarge + } + offset := s.writeAt + physicalOffset, err := s.physicalOffset(offset) + if err != nil { + return state.ChunkRef{}, err + } + if _, err := s.file.Seek(physicalOffset, stdio.SeekStart); err != nil { + return state.ChunkRef{}, core.E("state.filestore.Put", "seek to append offset", err) + } + if err := writeAll(s.file, headerMeta); err != nil { + s.rollbackWriteLocked(offset) + return state.ChunkRef{}, core.E("state.filestore.Put", "write record header and metadata", err) + } + s.payloadWriter.file = s.file + s.payloadWriter.remaining = payloadSize + if err := write(&s.payloadWriter); err != nil { + s.rollbackWriteLocked(offset) + return state.ChunkRef{}, core.E("state.filestore.Put", "write record payload", err) + } + if s.payloadWriter.remaining != 0 { + s.rollbackWriteLocked(offset) + return state.ChunkRef{}, errPayloadShort + } + ref := state.ChunkRef{ + ChunkID: id, + FrameOffset: uint64(offset), + HasFrameOffset: true, + Codec: CodecFile, + Segment: s.path, + } + s.index[id] = fileIndexEntry{ + ref: ref, + payloadAt: offset + recordHeaderLen + int64(metaSize), + payloadSize: payloadSize, + } + if meta.URI != "" { + s.uriIndex[meta.URI] = id + } + s.nextID++ + s.writeAt += int64(recordHeaderLen + metaSize + payloadSize) + return ref, nil +} + +func (s *Store) rollbackWriteLocked(offset int64) { + if s == nil || s.file == nil { + return + } + physicalOffset, err := s.physicalOffset(offset) + if err != nil { + return + } + _ = s.file.Truncate(physicalOffset) + _, _ = s.file.Seek(physicalOffset, stdio.SeekStart) +} + +func (s *Store) resolveLocked(chunkID int) (state.Chunk, error) { + chunk, err := s.resolveBytesLocked(chunkID) + if err != nil { + return state.Chunk{}, err + } + // chunk.Data is freshly allocated by ReadAt and unreachable here + // — handing it to AsString skips the payload-sized copy that + // string(chunk.Data) would do. Every Resolve text read benefits; + // payloads scale to KB+ for compressed state slices. + chunk.Text = core.AsString(chunk.Data) + chunk.Data = nil + return chunk, nil +} + +func (s *Store) ResolveBytes(ctx context.Context, chunkID int) (state.Chunk, error) { + if err := checkContext(ctx); err != nil { + return state.Chunk{}, err + } + if s == nil { + return state.Chunk{}, &state.ChunkNotFoundError{ID: chunkID} + } + s.mu.Lock() + defer s.mu.Unlock() + if s.file == nil { + return state.Chunk{}, errStoreClosed + } + return s.resolveBytesLocked(chunkID) +} + +func (s *Store) BorrowBytes(ctx context.Context, chunkID int) (state.BorrowedChunk, error) { + if err := checkContext(ctx); err != nil { + return state.BorrowedChunk{}, err + } + if s == nil { + return state.BorrowedChunk{}, &state.ChunkNotFoundError{ID: chunkID} + } + s.mu.Lock() + defer s.mu.Unlock() + if s.file == nil { + return state.BorrowedChunk{}, errStoreClosed + } + entry, ok := s.index[chunkID] + if !ok { + return state.BorrowedChunk{}, &state.ChunkNotFoundError{ID: chunkID} + } + if s.readOnly { + payloadAt := entry.payloadAt - s.baseAt + data, err := s.borrowPayloadLocked(payloadAt, entry.payloadSize) + if err != nil { + return state.BorrowedChunk{}, err + } + return state.BorrowedChunk{Ref: entry.ref, Data: data}, nil + } + chunk, err := s.resolveBytesLocked(chunkID) + if err != nil { + return state.BorrowedChunk{}, err + } + return state.BorrowedChunk{Ref: chunk.Ref, Data: chunk.Data}, nil +} + +func (s *Store) ResolveRefBytes(ctx context.Context, ref state.ChunkRef) (state.Chunk, error) { + if err := checkContext(ctx); err != nil { + return state.Chunk{}, err + } + if s == nil { + return state.Chunk{}, &state.ChunkNotFoundError{ID: ref.ChunkID} + } + if !ref.HasFrameOffset { + return s.ResolveBytes(ctx, ref.ChunkID) + } + if ref.Codec != "" && ref.Codec != CodecFile && ref.Codec != CodecMemvidFile { + return state.Chunk{}, errRefNonFileCodec + } + if ref.Segment != "" && ref.Segment != s.path && ref.Segment != s.alias { + return state.Chunk{}, errRefSegmentMismatch + } + s.mu.Lock() + defer s.mu.Unlock() + if s.file == nil { + return state.Chunk{}, errStoreClosed + } + return s.resolveRefBytesLocked(ref) +} + +func (s *Store) BorrowRefBytes(ctx context.Context, ref state.ChunkRef) (state.BorrowedChunk, error) { + if err := checkContext(ctx); err != nil { + return state.BorrowedChunk{}, err + } + if s == nil { + return state.BorrowedChunk{}, &state.ChunkNotFoundError{ID: ref.ChunkID} + } + if !ref.HasFrameOffset { + return s.BorrowBytes(ctx, ref.ChunkID) + } + if ref.Codec != "" && ref.Codec != CodecFile && ref.Codec != CodecMemvidFile { + return state.BorrowedChunk{}, errRefNonFileCodec + } + if ref.Segment != "" && ref.Segment != s.path && ref.Segment != s.alias { + return state.BorrowedChunk{}, errRefSegmentMismatch + } + s.mu.Lock() + defer s.mu.Unlock() + if s.file == nil { + return state.BorrowedChunk{}, errStoreClosed + } + if !s.readOnly { + chunk, err := s.resolveRefBytesLocked(ref) + if err != nil { + return state.BorrowedChunk{}, err + } + return state.BorrowedChunk{Ref: chunk.Ref, Data: chunk.Data}, nil + } + return s.borrowRefBytesLocked(ref) +} + +func (s *Store) resolveBytesLocked(chunkID int) (state.Chunk, error) { + entry, ok := s.index[chunkID] + if !ok { + return state.Chunk{}, &state.ChunkNotFoundError{ID: chunkID} + } + payload := make([]byte, entry.payloadSize) + if _, err := s.file.ReadAt(payload, entry.payloadAt); err != nil { + return state.Chunk{}, core.E("state.filestore.Resolve", "read chunk payload", err) + } + return state.Chunk{ + Ref: entry.ref, + Data: payload, + }, nil +} + +func (s *Store) resolveRefBytesLocked(ref state.ChunkRef) (state.Chunk, error) { + if ref.FrameOffset > uint64(maxInt()) { + return state.Chunk{}, errRefFrameOffsetTooBig + } + offset := int64(ref.FrameOffset) + physicalOffset, err := s.physicalOffset(offset) + if err != nil { + return state.Chunk{}, err + } + var headerBuf [recordHeaderLen]byte + if _, err := s.file.ReadAt(headerBuf[:], physicalOffset); err != nil { + return state.Chunk{}, core.E("state.filestore.ResolveRefBytes", "read record header", err) + } + record, err := decodeRecordHeader(headerBuf[:]) + if err != nil { + return state.Chunk{}, err + } + id, err := intFromUint64(record.chunkID, "chunk id") + if err != nil { + return state.Chunk{}, err + } + if ref.ChunkID != 0 && id != ref.ChunkID { + return state.Chunk{}, errRefChunkIDMismatch + } + metaSize, err := intFromUint64(uint64(record.metaSize), "metadata") + if err != nil { + return state.Chunk{}, err + } + payloadSize, err := intFromUint64(record.payloadSize, "payload") + if err != nil { + return state.Chunk{}, err + } + payloadAt := physicalOffset + recordHeaderLen + int64(metaSize) + payload := make([]byte, payloadSize) + if _, err := s.file.ReadAt(payload, payloadAt); err != nil { + return state.Chunk{}, core.E("state.filestore.ResolveRefBytes", "read chunk payload", err) + } + return state.Chunk{ + Ref: state.ChunkRef{ + ChunkID: id, + FrameOffset: ref.FrameOffset, + HasFrameOffset: true, + Codec: CodecFile, + Segment: s.path, + }, + Data: payload, + }, nil +} + +func (s *Store) borrowRefBytesLocked(ref state.ChunkRef) (state.BorrowedChunk, error) { + if ref.FrameOffset > uint64(maxInt()) { + return state.BorrowedChunk{}, errRefFrameOffsetTooBig + } + offset := int64(ref.FrameOffset) + var headerView []byte + if err := s.ensureMappedRegionLocked(); err == nil { + if offset < 0 || offset+recordHeaderLen > int64(len(s.mappedRegion)) { + return state.BorrowedChunk{}, errRegionInvalid + } + headerView = s.mappedRegion[offset : offset+recordHeaderLen] + } else { + physicalOffset, perr := s.physicalOffset(offset) + if perr != nil { + return state.BorrowedChunk{}, perr + } + var headerBuf [recordHeaderLen]byte + if _, rerr := s.file.ReadAt(headerBuf[:], physicalOffset); rerr != nil { + return state.BorrowedChunk{}, core.E("state.filestore.BorrowRefBytes", "read record header", rerr) + } + headerView = headerBuf[:] + } + record, err := decodeRecordHeader(headerView) + if err != nil { + return state.BorrowedChunk{}, err + } + id, err := intFromUint64(record.chunkID, "chunk id") + if err != nil { + return state.BorrowedChunk{}, err + } + if ref.ChunkID != 0 && id != ref.ChunkID { + return state.BorrowedChunk{}, errRefChunkIDMismatch + } + metaSize, err := intFromUint64(uint64(record.metaSize), "metadata") + if err != nil { + return state.BorrowedChunk{}, err + } + payloadSize, err := intFromUint64(record.payloadSize, "payload") + if err != nil { + return state.BorrowedChunk{}, err + } + payloadAt := offset + recordHeaderLen + int64(metaSize) + data, err := s.borrowPayloadLocked(payloadAt, payloadSize) + if err != nil { + return state.BorrowedChunk{}, err + } + return state.BorrowedChunk{ + Ref: state.ChunkRef{ + ChunkID: id, + FrameOffset: ref.FrameOffset, + HasFrameOffset: true, + Codec: CodecFile, + Segment: s.path, + }, + Data: data, + }, nil +} + +func (s *Store) borrowPayloadLocked(payloadAt int64, payloadSize int) ([]byte, error) { + if payloadSize < 0 || payloadAt < 0 { + return nil, errRegionInvalid + } + if err := s.ensureMappedRegionLocked(); err != nil { + physicalAt, perr := s.physicalOffset(payloadAt) + if perr != nil { + return nil, perr + } + data := make([]byte, payloadSize) + if _, rerr := s.file.ReadAt(data, physicalAt); rerr != nil { + return nil, core.E("state.filestore.BorrowRefBytes", "read chunk payload", rerr) + } + return data, nil + } + end := payloadAt + int64(payloadSize) + if end < payloadAt || end > int64(len(s.mappedRegion)) { + return nil, errRegionInvalid + } + return s.mappedRegion[payloadAt:end], nil +} + +func indexCapacityHint(size, headerLen int64) int { + recordBytes := size - headerLen + if recordBytes <= 0 || recordBytes > indexHintMaxFileBytes { + return 0 + } + records := recordBytes / indexHintRecordBytes + if records <= 0 { + return 0 + } + return int(records) +} + +func (s *Store) rebuildIndex(ctx context.Context) error { + info, err := s.file.Stat() + if err != nil { + return core.E("state.filestore.Open", "stat file", err) + } + size, err := s.regionSize(info.Size()) + if err != nil { + return err + } + headerLen, err := s.detectHeaderLen(size) + if err != nil { + return err + } + + // Best-effort capacity hint for small-record stores. Do not derive map + // capacity from arbitrarily large State files: packed KV containers can be + // hundreds of MiB with only a few records, and byte-size preallocation turns + // store-open into a large heap allocation before any payload is touched. + if records := indexCapacityHint(size, headerLen); records > 0 && len(s.index) == 0 { + s.index = make(map[int]fileIndexEntry, records) + s.uriIndex = make(map[string]int, records) + } + + // Prefetch buffer — read header + meta in a single ReadAt where + // possible. Typical records have meta < ~200 bytes (URI + Kind + + // short Title), so a 512-byte prefetch covers ~95% of records and + // halves the syscall count over the rebuild. Records with bigger + // meta fall back to the original two-ReadAt path; the cost there + // is unchanged. + // + // The buffer is stack-allocated (gcflags confirms "does not escape") + // because every byte read out of it is either parsed into a + // stack-local recordHeader or copied into the URI string via + // extractRecordURI. Each iteration overwrites it before the next. + const prefetchSize = 512 + var prefetchBuf [prefetchSize]byte + + // Fallback meta buffer for records whose meta exceeds prefetchSize. + // Grows in place across records to avoid per-record allocations on + // the rare-but-not-impossible big-meta corpus. The buffer contents + // are decoded into stack-only locals before the next iteration + // overwrites them. + var metaBuf []byte + offset := headerLen + for offset < size { + if err := checkContext(ctx); err != nil { + return err + } + if offset+recordHeaderLen > size { + return core.NewError("state file store has truncated record header") + } + // Read header + the first prefetchSize-recordHeaderLen bytes + // of meta in one syscall. ReadAt returns short at EOF for the + // final record — that's harmless because n is then used as + // the length of the readable view and we know the meta size + // from the parsed header. The kernel page cache makes the + // extra-bytes cost negligible vs the syscall round-trip cost. + want := int64(prefetchSize) + if offset+want > size { + want = size - offset + } + physicalOffset, err := s.physicalOffset(offset) + if err != nil { + return err + } + n, err := s.file.ReadAt(prefetchBuf[:want], physicalOffset) + if err != nil && err != stdio.EOF { + return core.E("state.filestore.Open", "read record prefetch", err) + } + if n < recordHeaderLen { + return core.NewError("state file store has truncated record header") + } + record, err := decodeRecordHeader(prefetchBuf[:recordHeaderLen]) + if err != nil { + return err + } + metaSize, err := intFromUint64(uint64(record.metaSize), "metadata") + if err != nil { + return err + } + payloadSize, err := intFromUint64(record.payloadSize, "payload") + if err != nil { + return err + } + metaAt := offset + recordHeaderLen + payloadAt := metaAt + int64(metaSize) + nextOffset := payloadAt + int64(payloadSize) + if nextOffset > size { + return core.NewError("state file store has truncated record payload") + } + // Fast path: prefetch covered both header and meta. Hand + // extractRecordURI a slice straight into prefetchBuf. + var metaView []byte + if metaSize == 0 { + metaView = nil + } else if recordHeaderLen+metaSize <= n { + metaView = prefetchBuf[recordHeaderLen : recordHeaderLen+metaSize] + } else { + // Big-meta fallback — meta exceeds the prefetched span. + // Re-read the meta into the growable metaBuf. Rare in + // practice; size-grows are amortised across records. + if cap(metaBuf) < metaSize { + metaBuf = make([]byte, metaSize) + } else { + metaBuf = metaBuf[:metaSize] + } + metaPhysicalAt, err := s.physicalOffset(metaAt) + if err != nil { + return err + } + if _, err := s.file.ReadAt(metaBuf, metaPhysicalAt); err != nil { + return core.E("state.filestore.Open", "read record metadata", err) + } + metaView = metaBuf + } + // Lazy meta scan: only URI is needed to populate uriIndex — + // the meta blob's other fields (Title/Kind/Track/Tags/ + // Labels) are written for forward audit, not read by any + // hot path. extractRecordURI walks the JSON object + // end-to-end (so structural corruption is still caught) + // but only materialises the URI string. At 10k records + // this skips ~6 allocs/record (Tags map + Labels slice + + // Title/Kind/Track string copies) over a full + // json.Unmarshal of recordMeta. The fileIndexEntry.meta + // field is left zero-valued on this path; Put still + // populates it to keep the put-side bench shape intact. + var uri string + if metaSize > 0 { + extracted, err := extractRecordURI(metaView) + if err != nil { + return core.E("state.filestore.Open", "parse record metadata", err) + } + uri = extracted + } + id, err := intFromUint64(record.chunkID, "chunk id") + if err != nil { + return err + } + ref := state.ChunkRef{ + ChunkID: id, + FrameOffset: uint64(offset), + HasFrameOffset: true, + Codec: CodecFile, + Segment: s.path, + } + s.index[id] = fileIndexEntry{ + ref: ref, + payloadAt: s.baseAt + payloadAt, + payloadSize: payloadSize, + } + if uri != "" { + s.uriIndex[uri] = id + } + if id >= s.nextID { + s.nextID = id + 1 + } + offset = nextOffset + } + s.writeAt = offset + return nil +} + +func (s *Store) detectHeaderLen(size int64) (int64, error) { + minHeaderLen := len(fileMagic) + if len(legacyFileMagic) < minHeaderLen { + minHeaderLen = len(legacyFileMagic) + } + if size < int64(minHeaderLen) { + return 0, core.NewError("state file store is missing header") + } + maxHeaderLen := len(fileMagic) + if len(legacyFileMagic) > maxHeaderLen { + maxHeaderLen = len(legacyFileMagic) + } + if size < int64(maxHeaderLen) { + maxHeaderLen = int(size) + } + magic := make([]byte, maxHeaderLen) + if _, err := s.file.ReadAt(magic, s.baseAt); err != nil { + return 0, core.E("state.filestore.Open", "read file header", err) + } + if hasMagicPrefix(magic, fileMagic) { + return int64(len(fileMagic)), nil + } + if hasMagicPrefix(magic, legacyFileMagic) { + return int64(len(legacyFileMagic)), nil + } + return 0, core.NewError("state file store header is invalid") +} + +func (s *Store) regionSize(fileSize int64) (int64, error) { + if s == nil || s.baseAt < 0 || s.region < 0 || s.baseAt > fileSize { + return 0, errRegionInvalid + } + available := fileSize - s.baseAt + if s.region == 0 { + return available, nil + } + if s.region > available { + return 0, errRegionInvalid + } + return s.region, nil +} + +func (s *Store) physicalOffset(logOffset int64) (int64, error) { + if s == nil || logOffset < 0 { + return 0, errRegionInvalid + } + if s.region > 0 && logOffset > s.region { + return 0, errRegionInvalid + } + if s.baseAt > 0 && logOffset > (1<<63-1)-s.baseAt { + return 0, errRegionInvalid + } + return s.baseAt + logOffset, nil +} + +func hasMagicPrefix(data, magic []byte) bool { + return len(data) >= len(magic) && string(data[:len(magic)]) == string(magic) +} + +type recordHeader struct { + chunkID uint64 + payloadSize uint64 + metaSize uint32 +} + +// encodeRecordHeader writes a record header into the caller-supplied +// buffer (must be at least recordHeaderLen bytes). The previous shape +// allocated a fresh []byte on every Put — header writes fire once per +// chunk written, so the alloc compounded for every state save. +func encodeRecordHeader(buf []byte, chunkID int, payloadSize, metaSize int) { + _ = buf[recordHeaderLen-1] // bounds-check hint + copy(buf[:4], recordMagic[:]) + binary.LittleEndian.PutUint64(buf[4:12], uint64(chunkID)) + binary.LittleEndian.PutUint64(buf[12:20], uint64(payloadSize)) + binary.LittleEndian.PutUint32(buf[20:24], uint32(metaSize)) +} + +func decodeRecordHeader(header []byte) (recordHeader, error) { + if len(header) != recordHeaderLen { + return recordHeader{}, core.NewError("state file store record header has invalid length") + } + // Magic-prefix check via a single Uint32 read against the + // pre-computed recordMagicU32 — one ALU op per record at the + // rebuildIndex 10k-scale cold Open, where the previous 4-byte + // branching compare emitted 4 cmpb + 3 brand merges. Folding + // the 32 bits into a single equality test also lets the + // compiler hoist the magic constant into an immediate operand. + // `string(header[:4]) != string(recordMagic[:])` would allocate + // a fresh 4-byte string on every call. + if binary.LittleEndian.Uint32(header[:4]) != recordMagicU32 { + return recordHeader{}, core.NewError("state file store record header is invalid") + } + return recordHeader{ + chunkID: binary.LittleEndian.Uint64(header[4:12]), + payloadSize: binary.LittleEndian.Uint64(header[12:20]), + metaSize: binary.LittleEndian.Uint32(header[20:24]), + }, nil +} + +// recordMetaIsEmpty reports whether the record meta has no +// populated field — string fields all empty, Tags map nil or empty, +// Labels slice nil or empty. The PutBytesStream fast path uses this +// to short-circuit JSON marshalling on records that carry no caller +// metadata (the common shape for KV snapshots and sentinel writes). +// +// if recordMetaIsEmpty(&meta) { +// metaBytes = emptyMetaBytes +// } +func recordMetaIsEmpty(meta *recordMeta) bool { + return meta.URI == "" && + meta.Title == "" && + meta.Kind == "" && + meta.Track == "" && + len(meta.Tags) == 0 && + len(meta.Labels) == 0 +} + +// encodeRecordMeta hand-rolls the JSON for recordMeta into a fresh +// single-allocation buffer. Thin wrapper over appendRecordMeta — kept +// as the package-private "I want the meta bytes" entry point, used +// by the round-trip test surface and any future caller that does +// not also need the record header in the same buffer. +// +// PutBytesStream itself routes through (*Store).buildHeaderMeta which +// folds the meta append into the per-Store scratch buffer, dropping +// the alloc entirely on the warm path. +// +// buf := encodeRecordMeta(&meta) +// if uint64(len(buf)) > uint64(^uint32(0)) { /* too large */ } +func encodeRecordMeta(meta *recordMeta) []byte { + if recordMetaIsEmpty(meta) { + return emptyMetaBytes + } + buf := make([]byte, 0, recordMetaCapHint(meta)) + return appendRecordMeta(buf, meta) +} + +// buildHeaderMeta builds the on-disk record header + JSON-encoded +// recordMeta into the per-Store scratch buffer (s.headerMetaBuf), +// returning a slice into that buffer. The previous shape allocated +// a fresh buffer per Put — measurable on the state-checkpoint +// fast path because Put fires per Save during a generation step +// and per KV-snapshot during a session. +// +// PutBytesStream holds s.mu for the full record write, so the +// scratch buffer is single-owner during any one Put; the next Put +// reuses the underlying storage after the previous call's +// writeAll consumed the bytes. encodeRecordHeader (called below) +// is a pure-write helper — no further alloc beyond the slice +// header reuse. +// +// The metaSize uint32 in the header is patched after the meta is +// appended — single-pass build, no double walk over the meta +// fields. The slice retains its growth across Puts so the typical +// meta size + the cap hint converge after a handful of records. +// +// encoding/json.Marshal on recordMeta allocates an encoder state +// machine + grow-doubled output buffer + per-tag key/value copies +// on every Put. The hand-roll lands at zero buffer allocations +// regardless of tag count. +// +// The meta portion is valid JSON, parseable by encoding/json +// (round-trips into recordMeta) and by the store's extractRecordURI +// walker. Field ordering follows recordMeta's struct declaration — +// URI, Title, Kind, Track, Tags, Labels — and the omitempty +// semantics match (zero-value strings, nil/empty maps, nil/empty +// slices are elided). Tag-map keys are emitted in Go map iteration +// order — JSON object key order is not semantically meaningful and +// no read site depends on it. +// +// buf := s.buildHeaderMeta(&meta, chunkID, payloadSize) +// writeAll(s.file, buf) +func (s *Store) buildHeaderMeta(meta *recordMeta, chunkID, payloadSize int) []byte { + need := recordHeaderLen + recordMetaCapHint(meta) + if cap(s.headerMetaBuf) < need { + s.headerMetaBuf = make([]byte, recordHeaderLen, need) + } else { + s.headerMetaBuf = s.headerMetaBuf[:recordHeaderLen] + } + s.headerMetaBuf = appendRecordMeta(s.headerMetaBuf, meta) + metaSize := len(s.headerMetaBuf) - recordHeaderLen + encodeRecordHeader(s.headerMetaBuf[:recordHeaderLen], chunkID, payloadSize, metaSize) + return s.headerMetaBuf +} + +// recordMetaCapHint returns a tight upper bound on the JSON byte +// length of meta. Each non-empty field contributes its raw byte +// length plus framing overhead (the surrounding "key":"value", +// pair, with a small slack so the heuristic clears the typical +// ASCII shape in one allocation). Pathological escape-heavy inputs +// (control chars, embedded quotes) let append grow once. +func recordMetaCapHint(meta *recordMeta) int { + if recordMetaIsEmpty(meta) { + return 2 + } + size := 2 // outer braces + if meta.URI != "" { + size += 10 + len(meta.URI) // `"uri":"",` = 9 bytes + value, +1 slack + } + if meta.Title != "" { + size += 12 + len(meta.Title) // `"title":"",` + } + if meta.Kind != "" { + size += 11 + len(meta.Kind) // `"kind":"",` + } + if meta.Track != "" { + size += 12 + len(meta.Track) // `"track":"",` + } + if len(meta.Tags) > 0 { + size += 12 // `"tags":{...},` + for k, v := range meta.Tags { + size += 6 + len(k) + len(v) // `"k":"v",` + } + } + if len(meta.Labels) > 0 { + size += 14 // `"labels":[...],` + for _, l := range meta.Labels { + size += 4 + len(l) // `"l",` + } + } + return size +} + +// appendRecordMeta appends the JSON encoding of meta to buf and +// returns the extended slice. Walks the recordMeta struct in +// declaration order, eliding empty fields to honour the omitempty +// json tag semantics. Single-pass; no allocation beyond the +// caller-supplied buf's eventual grow. +func appendRecordMeta(buf []byte, meta *recordMeta) []byte { + if recordMetaIsEmpty(meta) { + return append(buf, '{', '}') + } + buf = append(buf, '{') + first := true + if meta.URI != "" { + buf = appendJSONField(buf, "uri", meta.URI, first) + first = false + } + if meta.Title != "" { + buf = appendJSONField(buf, "title", meta.Title, first) + first = false + } + if meta.Kind != "" { + buf = appendJSONField(buf, "kind", meta.Kind, first) + first = false + } + if meta.Track != "" { + buf = appendJSONField(buf, "track", meta.Track, first) + first = false + } + if len(meta.Tags) > 0 { + if !first { + buf = append(buf, ',') + } + first = false + buf = append(buf, `"tags":{`...) + tagFirst := true + for k, v := range meta.Tags { + if !tagFirst { + buf = append(buf, ',') + } + tagFirst = false + buf = appendJSONString(buf, k) + buf = append(buf, ':') + buf = appendJSONString(buf, v) + } + buf = append(buf, '}') + } + if len(meta.Labels) > 0 { + if !first { + buf = append(buf, ',') + } + buf = append(buf, `"labels":[`...) + for i, l := range meta.Labels { + if i > 0 { + buf = append(buf, ',') + } + buf = appendJSONString(buf, l) + } + buf = append(buf, ']') + } + return append(buf, '}') +} + +// appendJSONField appends a "key":"value" pair (prefixed by a comma +// when not the first field) to buf. Key is ASCII-only and not +// escaped — recordMeta keys are compile-time constants. +func appendJSONField(buf []byte, key, value string, first bool) []byte { + if !first { + buf = append(buf, ',') + } + buf = append(buf, '"') + buf = append(buf, key...) + buf = append(buf, '"', ':') + return appendJSONString(buf, value) +} + +// appendJSONString appends a JSON-encoded string to buf — opening +// quote, escaped body, closing quote. Escapes match the subset +// recognised by extractRecordURI's jsonUnescape walker: \" \\ \b +// \f \n \r \t for the canonical mnemonic forms and \u00XX for +// other control chars (< 0x20). All bytes ≥ 0x20 outside the +// quote / backslash pair pass through verbatim — encoding/json's +// default also escapes <, >, & for HTML safety but the read path +// does not, and the on-disk record is not consumed by HTML +// contexts. +// +// The body walk batches runs of non-escape bytes into a single +// append per span, so a typical URI / Title / Kind value (no +// escapes) collapses to one append-string call rather than N +// append-byte calls. encoding/json's own writer emits the no- +// escape path the same way; the per-byte loop here was an artefact +// of the original simple shape. +func appendJSONString(buf []byte, s string) []byte { + buf = append(buf, '"') + start := 0 + for i := 0; i < len(s); i++ { + c := s[i] + // Fast-path predicate: any byte ≥ 0x20 that is neither '"' + // nor '\\' passes through verbatim. The boolean short- + // circuits left-to-right and the compiler emits two CMPs + // + AND, cheaper than the previous per-byte switch dispatch. + if c >= 0x20 && c != '"' && c != '\\' { + continue + } + // Flush the verbatim span up to but not including the + // escape byte. The span is empty on the first escape at + // position 0; append-zero-length is a no-op. + if start < i { + buf = append(buf, s[start:i]...) + } + switch c { + case '"': + buf = append(buf, '\\', '"') + case '\\': + buf = append(buf, '\\', '\\') + case '\b': + buf = append(buf, '\\', 'b') + case '\f': + buf = append(buf, '\\', 'f') + case '\n': + buf = append(buf, '\\', 'n') + case '\r': + buf = append(buf, '\\', 'r') + case '\t': + buf = append(buf, '\\', 't') + default: + // c < 0x20 and not one of the mnemonic escapes — emit + // \u00XX. Hex digits emitted lowercase to match the + // jsonUnescape reader and encoding/json output. + buf = append(buf, '\\', 'u', '0', '0', hexChar(c>>4), hexChar(c&0x0f)) + } + start = i + 1 + } + if start < len(s) { + buf = append(buf, s[start:]...) + } + return append(buf, '"') +} + +// hexChar returns the ASCII hex digit for the low nibble of v. +func hexChar(v byte) byte { + v &= 0x0f + if v < 10 { + return '0' + v + } + return 'a' + (v - 10) +} + +// extractRecordURI walks data as a top-level JSON object and returns +// the value of the "uri" key as a string, or "" if absent. The walker +// fully traverses the object (including nested arrays / objects) so +// any structural corruption — unbalanced braces, truncated value, +// trailing garbage — surfaces as an error. This replaces a full +// json.Unmarshal into recordMeta for the rebuildIndex hot path, +// dropping ~6 allocs per record at 10k scale (Tags map, Labels slice, +// Title/Kind/Track string copies). The "uri" field is encoded by +// json.Marshal of a string — URLs do not require escapes in +// practice, so the fast path returns a direct slice-to-string copy; +// the rare-but-valid escape path is handled by jsonUnescape. +func extractRecordURI(data []byte) (string, error) { + i, err := jsonSkipWS(data, 0) + if err != nil { + return "", err + } + if data[i] != '{' { + return "", core.NewError("state file store metadata is not a JSON object") + } + i++ + uri := "" + uriSeen := false + first := true + for { + i, err = jsonSkipWS(data, i) + if err != nil { + return "", err + } + if data[i] == '}' { + i++ + break + } + if !first { + if data[i] != ',' { + return "", core.NewError("state file store metadata is missing comma") + } + i++ + i, err = jsonSkipWS(data, i) + if err != nil { + return "", err + } + } + first = false + if data[i] != '"' { + return "", core.NewError("state file store metadata key is not a string") + } + keyStart := i + 1 + keyEnd, err := jsonSkipString(data, i) + if err != nil { + return "", err + } + i = keyEnd + i, err = jsonSkipWS(data, i) + if err != nil { + return "", err + } + if data[i] != ':' { + return "", core.NewError("state file store metadata is missing colon") + } + i++ + i, err = jsonSkipWS(data, i) + if err != nil { + return "", err + } + isURI := !uriSeen && keyEnd-1-keyStart == 3 && + data[keyStart] == 'u' && data[keyStart+1] == 'r' && data[keyStart+2] == 'i' + if isURI { + if data[i] != '"' { + return "", core.NewError("state file store uri is not a string") + } + value, end, err := jsonReadString(data, i) + if err != nil { + return "", err + } + uri = value + uriSeen = true + i = end + } else { + end, err := jsonSkipValue(data, i) + if err != nil { + return "", err + } + i = end + } + } + // Validate no trailing garbage beyond whitespace. + for i < len(data) { + c := data[i] + if c != ' ' && c != '\t' && c != '\n' && c != '\r' { + return "", core.NewError("state file store metadata has trailing data") + } + i++ + } + return uri, nil +} + +// jsonSkipWS advances past JSON whitespace, returning the first +// non-whitespace index or an error if end-of-data is hit. The caller +// uses the returned index to read the next significant byte. +func jsonSkipWS(data []byte, i int) (int, error) { + for i < len(data) { + c := data[i] + if c != ' ' && c != '\t' && c != '\n' && c != '\r' { + return i, nil + } + i++ + } + return i, core.NewError("state file store metadata is truncated") +} + +// jsonSkipString advances past a JSON string starting at data[i] +// (which must be '"') and returns the index after the closing quote. +// Handles escape sequences but does not decode them. +func jsonSkipString(data []byte, i int) (int, error) { + if i >= len(data) || data[i] != '"' { + return i, core.NewError("state file store metadata expects string") + } + i++ + for i < len(data) { + c := data[i] + if c == '\\' { + if i+1 >= len(data) { + return i, core.NewError("state file store metadata has trailing escape") + } + // One-byte escapes (\" \\ \/ \b \f \n \r \t) or \uXXXX — + // either way the next single byte cannot terminate the + // string and the wider \uXXXX is bounded by the closing + // quote check on later iterations. + i += 2 + continue + } + if c == '"' { + return i + 1, nil + } + i++ + } + return i, core.NewError("state file store metadata string is unterminated") +} + +// jsonReadString reads a JSON string at data[i] (which must be '"') +// and returns its decoded value plus the index after the closing +// quote. Fast path: no escapes → direct string copy of the byte +// slice. Slow path: presence of an escape forces a per-byte decode +// into a fresh buffer. Used only for the "uri" field, where escapes +// are extremely rare in practice (URLs). +func jsonReadString(data []byte, i int) (string, int, error) { + if i >= len(data) || data[i] != '"' { + return "", i, core.NewError("state file store metadata expects string") + } + start := i + 1 + j := start + hasEscape := false + for j < len(data) { + c := data[j] + if c == '\\' { + hasEscape = true + if j+1 >= len(data) { + return "", j, core.NewError("state file store metadata has trailing escape") + } + j += 2 + continue + } + if c == '"' { + if !hasEscape { + return string(data[start:j]), j + 1, nil + } + decoded, err := jsonUnescape(data[start:j]) + if err != nil { + return "", j, err + } + return decoded, j + 1, nil + } + j++ + } + return "", j, core.NewError("state file store metadata string is unterminated") +} + +// jsonUnescape decodes the contents of a JSON string (without +// surrounding quotes) that contains at least one backslash escape. +// Handles the six single-byte escapes and \uXXXX (no surrogate-pair +// decoding — surrogate halves pass through as their raw UTF-8 +// encoding, which is what encoding/json itself emits for unpaired +// surrogates). Allocated once per uri-with-escape; URIs never have +// escapes in observed corpora, so this is the cold path. +func jsonUnescape(src []byte) (string, error) { + out := make([]byte, 0, len(src)) + for i := 0; i < len(src); i++ { + c := src[i] + if c != '\\' { + out = append(out, c) + continue + } + if i+1 >= len(src) { + return "", core.NewError("state file store metadata has trailing escape") + } + i++ + switch src[i] { + case '"', '\\', '/': + out = append(out, src[i]) + case 'b': + out = append(out, '\b') + case 'f': + out = append(out, '\f') + case 'n': + out = append(out, '\n') + case 'r': + out = append(out, '\r') + case 't': + out = append(out, '\t') + case 'u': + if i+4 >= len(src) { + return "", core.NewError("state file store metadata has short \\u escape") + } + var r rune + for k := 1; k <= 4; k++ { + h := src[i+k] + var v byte + switch { + case h >= '0' && h <= '9': + v = h - '0' + case h >= 'a' && h <= 'f': + v = h - 'a' + 10 + case h >= 'A' && h <= 'F': + v = h - 'A' + 10 + default: + return "", core.NewError("state file store metadata has invalid \\u escape") + } + r = r<<4 | rune(v) + } + i += 4 + // Emit r as UTF-8. Unpaired surrogates pass through as + // their replacement encoding — sufficient for the URI + // field which is ASCII in every observed corpus. + switch { + case r < 0x80: + out = append(out, byte(r)) + case r < 0x800: + out = append(out, byte(0xC0|r>>6), byte(0x80|r&0x3F)) + case r < 0x10000: + out = append(out, byte(0xE0|r>>12), byte(0x80|(r>>6)&0x3F), byte(0x80|r&0x3F)) + default: + out = append(out, byte(0xF0|r>>18), byte(0x80|(r>>12)&0x3F), byte(0x80|(r>>6)&0x3F), byte(0x80|r&0x3F)) + } + default: + return "", core.NewError("state file store metadata has unknown escape") + } + } + return string(out), nil +} + +// jsonSkipValue advances past a single JSON value (string, number, +// boolean, null, object, array) starting at data[i] and returns the +// index of the first byte after the value. The full traversal is +// what gives rebuildIndex its structural-corruption guarantee +// without forcing the whole metadata blob through json.Unmarshal. +func jsonSkipValue(data []byte, i int) (int, error) { + if i >= len(data) { + return i, core.NewError("state file store metadata is truncated") + } + c := data[i] + switch { + case c == '"': + return jsonSkipString(data, i) + case c == '{' || c == '[': + open := c + var closeByte byte + if open == '{' { + closeByte = '}' + } else { + closeByte = ']' + } + depth := 1 + i++ + for i < len(data) && depth > 0 { + cc := data[i] + switch cc { + case '"': + end, err := jsonSkipString(data, i) + if err != nil { + return i, err + } + i = end + case '{', '[': + depth++ + i++ + case '}', ']': + if cc == closeByte { + depth-- + i++ + continue + } + if (open == '{' && cc == ']') || (open == '[' && cc == '}') { + return i, core.NewError("state file store metadata has mismatched bracket") + } + depth-- + i++ + default: + i++ + } + } + if depth != 0 { + return i, core.NewError("state file store metadata is unbalanced") + } + return i, nil + case c == 't': + if i+4 > len(data) || data[i+1] != 'r' || data[i+2] != 'u' || data[i+3] != 'e' { + return i, core.NewError("state file store metadata expects true") + } + return i + 4, nil + case c == 'f': + if i+5 > len(data) || data[i+1] != 'a' || data[i+2] != 'l' || data[i+3] != 's' || data[i+4] != 'e' { + return i, core.NewError("state file store metadata expects false") + } + return i + 5, nil + case c == 'n': + if i+4 > len(data) || data[i+1] != 'u' || data[i+2] != 'l' || data[i+3] != 'l' { + return i, core.NewError("state file store metadata expects null") + } + return i + 4, nil + case c == '-' || (c >= '0' && c <= '9'): + // Number — consume digits, sign, dot, exponent. Loose but + // correct enough for structural validation; json.Marshal + // emits canonical numbers so the surface is constrained. + j := i + if data[j] == '-' { + j++ + } + for j < len(data) { + b := data[j] + if (b >= '0' && b <= '9') || b == '.' || b == 'e' || b == 'E' || b == '+' || b == '-' { + j++ + continue + } + break + } + if j == i { + return i, core.NewError("state file store metadata has empty number") + } + return j, nil + default: + return i, core.NewError("state file store metadata has invalid value") + } +} + +type limitedPayloadWriter struct { + file *core.OSFile + remaining int +} + +func (w *limitedPayloadWriter) Write(data []byte) (int, error) { + if len(data) > w.remaining { + return 0, errPayloadOversize + } + n, err := w.file.Write(data) + w.remaining -= n + if err != nil { + return n, err + } + if n != len(data) { + return n, stdio.ErrShortWrite + } + return n, nil +} + +func writeAll(file stdio.Writer, data []byte) error { + for len(data) > 0 { + n, err := file.Write(data) + if err != nil { + return err + } + if n == 0 { + return stdio.ErrShortWrite + } + data = data[n:] + } + return nil +} + +func checkContext(ctx context.Context) error { + if ctx == nil { + return nil + } + select { + case <-ctx.Done(): + return ctx.Err() + default: + return nil + } +} + +func intFromUint64(value uint64, label string) (int, error) { + max := uint64(maxInt()) + if value > max { + return 0, core.NewError("state file store " + label + " is too large") + } + return int(value), nil +} + +func maxInt() int { + return int(^uint(0) >> 1) +} + +func resultError(result core.Result) error { + if result.OK { + return nil + } + if err, ok := result.Value.(error); ok { + return err + } + return core.NewError("core result failed") +} diff --git a/go/state/filestore/store_bench_test.go b/go/state/filestore/store_bench_test.go new file mode 100644 index 0000000..6624d56 --- /dev/null +++ b/go/state/filestore/store_bench_test.go @@ -0,0 +1,159 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the filestore state primitives. +// Per AX-11 — state.filestore is the persistence layer behind every +// session checkpoint, every memvid chunk read, every cross-process +// state handoff. Read/Resolve fires per chunk during a session load; +// Put fires per Save during a generation step. +// +// Run: go test -bench='Benchmark' -benchmem -run='^$' ./state/filestore + +package filestore + +import ( + "context" + "testing" + + core "dappco.re/go" + "dappco.re/go/inference/state" +) + +// Sinks defeat compiler DCE. +var ( + bSinkChunk state.Chunk + bSinkRef state.ChunkRef + bSinkErr error +) + +// benchStore opens a fresh filestore in a temp dir + populates n chunks +// of the requested size. Returns the store + the IDs in registration +// order so benches can target a known chunk. +func benchStore(tb testing.TB, n, payloadSize int) (*Store, []state.ChunkRef) { + tb.Helper() + dir := tb.TempDir() + path := dir + "/state.bin" + store, err := Create(context.Background(), path) + if err != nil { + tb.Fatal(err) + } + tb.Cleanup(func() { _ = store.Close() }) + + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte('a' + i%26) + } + refs := make([]state.ChunkRef, 0, n) + for i := 0; i < n; i++ { + ref, err := store.PutBytes(context.Background(), payload, state.PutOptions{ + Kind: "bench", + Title: core.Sprintf("chunk-%d", i), + }) + if err != nil { + tb.Fatal(err) + } + refs = append(refs, ref) + } + return store, refs +} + +// --- ResolveBytes (binary read — hot for state load) --- + +func BenchmarkFilestore_ResolveBytes_1KB(b *testing.B) { + store, refs := benchStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bSinkChunk, bSinkErr = store.ResolveBytes(ctx, refs[0].ChunkID) + } +} + +func BenchmarkFilestore_ResolveBytes_64KB(b *testing.B) { + store, refs := benchStore(b, 1, 64*1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bSinkChunk, bSinkErr = store.ResolveBytes(ctx, refs[0].ChunkID) + } +} + +func BenchmarkFilestore_ResolveBytes_1MB(b *testing.B) { + store, refs := benchStore(b, 1, 1024*1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bSinkChunk, bSinkErr = store.ResolveBytes(ctx, refs[0].ChunkID) + } +} + +// --- Resolve (text read — exercises the AsString path) --- + +func BenchmarkFilestore_Resolve_1KB(b *testing.B) { + store, refs := benchStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bSinkChunk, bSinkErr = state.Resolve(ctx, store, refs[0].ChunkID) + } +} + +func BenchmarkFilestore_Resolve_64KB(b *testing.B) { + store, refs := benchStore(b, 1, 64*1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bSinkChunk, bSinkErr = state.Resolve(ctx, store, refs[0].ChunkID) + } +} + +// --- ResolveRefBytes (ref-with-frame-offset — alternate read path) --- + +func BenchmarkFilestore_ResolveRefBytes_1KB(b *testing.B) { + store, refs := benchStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bSinkChunk, bSinkErr = store.ResolveRefBytes(ctx, refs[0]) + } +} + +// --- Put (write path — fires per Save during generation) --- + +func BenchmarkFilestore_PutBytes_1KB(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/state.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + payload := make([]byte, 1024) + opts := state.PutOptions{Kind: "bench"} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bSinkRef, bSinkErr = store.PutBytes(ctx, payload, opts) + } +} + +func BenchmarkFilestore_Put_Text_1KB(b *testing.B) { + dir := b.TempDir() + store, err := Create(context.Background(), dir+"/state.bin") + if err != nil { + b.Fatal(err) + } + defer store.Close() + text := string(make([]byte, 1024)) + opts := state.PutOptions{Kind: "bench"} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + bSinkRef, bSinkErr = store.Put(ctx, text, opts) + } +} diff --git a/go/state/filestore/store_mmap_stub.go b/go/state/filestore/store_mmap_stub.go new file mode 100644 index 0000000..9af8828 --- /dev/null +++ b/go/state/filestore/store_mmap_stub.go @@ -0,0 +1,11 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +//go:build !(darwin || linux || freebsd || netbsd || openbsd) + +package filestore + +func (s *Store) ensureMappedRegionLocked() error { + return errMappedRegionInvalid +} + +func (s *Store) unmapRegionLocked() {} diff --git a/go/state/filestore/store_mmap_unix.go b/go/state/filestore/store_mmap_unix.go new file mode 100644 index 0000000..3c881f9 --- /dev/null +++ b/go/state/filestore/store_mmap_unix.go @@ -0,0 +1,57 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +//go:build darwin || linux || freebsd || netbsd || openbsd + +package filestore + +import "syscall" + +func (s *Store) ensureMappedRegionLocked() error { + if s == nil || s.file == nil { + return errStoreClosed + } + if s.mappedRegion != nil { + return nil + } + info, err := s.file.Stat() + if err != nil { + return err + } + size, err := s.regionSize(info.Size()) + if err != nil { + return err + } + if size <= 0 || size > int64(maxInt()) { + return errMappedRegionInvalid + } + pageSize := int64(syscall.Getpagesize()) + pageDelta := s.baseAt % pageSize + mapOffset := s.baseAt - pageDelta + mapBytes := size + pageDelta + if mapBytes <= 0 || mapBytes > int64(maxInt()) { + return errMappedRegionInvalid + } + mapped, err := syscall.Mmap(int(s.file.Fd()), mapOffset, int(mapBytes), syscall.PROT_READ, syscall.MAP_SHARED) + if err != nil { + return err + } + start := int(pageDelta) + end := start + int(size) + if start < 0 || end < start || end > len(mapped) { + _ = syscall.Munmap(mapped) + return errMappedRegionInvalid + } + s.mapped = mapped + s.mappedRegion = mapped[start:end] + return nil +} + +func (s *Store) unmapRegionLocked() { + if s == nil || s.mapped == nil { + return + } + mapped := s.mapped + s.mapped = nil + s.mappedRegion = nil + _ = syscall.Munmap(mapped) +} diff --git a/go/state/filestore/store_test.go b/go/state/filestore/store_test.go new file mode 100644 index 0000000..6e9ad07 --- /dev/null +++ b/go/state/filestore/store_test.go @@ -0,0 +1,746 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package filestore + +import ( + "context" + stdio "io" + "testing" + + core "dappco.re/go" + state "dappco.re/go/inference/state" +) + +func TestFileStore_Good_AppendsAndReopens(t *testing.T) { + ctx := context.Background() + path := core.PathJoin(t.TempDir(), "kv-blocks.mvlog") + store, err := Create(ctx, path) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + if store.Path() != path { + t.Fatalf("Path() = %q, want %q", store.Path(), path) + } + + first, err := store.Put(ctx, "alpha", state.PutOptions{URI: "mlx://kv/0", Title: "first"}) + if err != nil { + t.Fatalf("Put(first) error = %v", err) + } + second, err := store.Put(ctx, "bravo", state.PutOptions{URI: "mlx://kv/1", Title: "second"}) + if err != nil { + t.Fatalf("Put(second) error = %v", err) + } + if first.ChunkID != 1 || second.ChunkID != 2 || second.Codec != CodecFile || second.Segment != path { + t.Fatalf("refs = %+v/%+v, want sequential file refs", first, second) + } + if err := store.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + stat := core.Stat(path) + if !stat.OK { + t.Fatalf("Stat(%q): %s", path, stat.Error()) + } + if stat.Value.(interface{ Size() int64 }).Size() <= int64(len("alphabravo")) { + t.Fatalf("file size = %d, want framed payload on disk", stat.Value.(interface{ Size() int64 }).Size()) + } + + reopened, err := Open(ctx, path) + if err != nil { + t.Fatalf("Open() error = %v", err) + } + defer reopened.Close() + if reopened.ChunkCount() != 2 { + t.Fatalf("ChunkCount() = %d, want 2", reopened.ChunkCount()) + } + chunk, err := reopened.Resolve(ctx, 2) + if err != nil { + t.Fatalf("Resolve(2) error = %v", err) + } + if chunk.Text != "bravo" || chunk.Ref.ChunkID != 2 || chunk.Ref.Codec != CodecFile || chunk.Ref.Segment != path { + t.Fatalf("chunk = %+v, want second chunk from file", chunk) + } + byURI, err := state.ResolveURI(ctx, reopened, "mlx://kv/1") + if err != nil { + t.Fatalf("ResolveURI() error = %v", err) + } + if byURI.Text != "bravo" || byURI.Ref.ChunkID != 2 { + t.Fatalf("ResolveURI() chunk = %+v, want second chunk", byURI) + } +} + +func TestFileStore_Good_OpensLegacyStateHeader(t *testing.T) { + ctx := context.Background() + path := core.PathJoin(t.TempDir(), "legacy.mvlog") + meta := []byte(core.JSONMarshalString(recordMeta{URI: "mlx://legacy/1"})) + payload := []byte("legacy payload") + data := append([]byte(nil), legacyFileMagic...) + var hdrBuf [recordHeaderLen]byte + encodeRecordHeader(hdrBuf[:], 1, len(payload), len(meta)) + data = append(data, hdrBuf[:]...) + data = append(data, meta...) + data = append(data, payload...) + if result := core.WriteFile(path, data, 0o600); !result.OK { + t.Fatalf("WriteFile() error = %s", result.Error()) + } + + store, err := Open(ctx, path) + if err != nil { + t.Fatalf("Open(legacy) error = %v", err) + } + defer store.Close() + + chunk, err := state.ResolveURI(ctx, store, "mlx://legacy/1") + if err != nil { + t.Fatalf("ResolveURI(legacy) error = %v", err) + } + if chunk.Text != "legacy payload" || chunk.Ref.FrameOffset != uint64(len(legacyFileMagic)) { + t.Fatalf("legacy chunk = %+v, want payload and legacy frame offset", chunk) + } +} + +func TestFileStore_Good_BinaryPayload(t *testing.T) { + ctx := context.Background() + path := core.PathJoin(t.TempDir(), "binary.mvlog") + store, err := Create(ctx, path) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + payload := []byte{0, 1, 2, 255} + ref, err := store.PutBytes(ctx, payload, state.PutOptions{URI: "mlx://binary/1"}) + if err != nil { + t.Fatalf("PutBytes() error = %v", err) + } + payload[1] = 99 + if err := store.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + reopened, err := Open(ctx, path) + if err != nil { + t.Fatalf("Open() error = %v", err) + } + defer reopened.Close() + chunk, err := state.ResolveBytes(ctx, reopened, ref.ChunkID) + if err != nil { + t.Fatalf("ResolveBytes() error = %v", err) + } + if len(chunk.Data) != 4 || chunk.Data[0] != 0 || chunk.Data[1] != 1 || chunk.Data[3] != 255 { + t.Fatalf("ResolveBytes() data = %v, want original binary payload", chunk.Data) + } + chunk.Data[2] = 88 + again, err := state.ResolveBytes(ctx, reopened, ref.ChunkID) + if err != nil { + t.Fatalf("ResolveBytes(second) error = %v", err) + } + if again.Data[2] != 2 { + t.Fatalf("ResolveBytes() returned aliased payload = %v", again.Data) + } + byURI, err := state.ResolveURI(ctx, reopened, "mlx://binary/1") + if err != nil { + t.Fatalf("ResolveURI(binary) error = %v", err) + } + if byURI.Text != string([]byte{0, 1, 2, 255}) { + t.Fatalf("ResolveURI(binary) text = %q, want binary-compatible text fallback", byURI.Text) + } +} + +func TestFileStore_Good_ResolveRefBytesUsesFrameOffset(t *testing.T) { + ctx := context.Background() + path := core.PathJoin(t.TempDir(), "offset.mvlog") + store, err := Create(ctx, path) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + first, err := store.PutBytes(ctx, []byte("first"), state.PutOptions{}) + if err != nil { + t.Fatalf("PutBytes(first) error = %v", err) + } + second, err := store.PutBytes(ctx, []byte("second"), state.PutOptions{}) + if err != nil { + t.Fatalf("PutBytes(second) error = %v", err) + } + if err := store.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + reopened, err := Open(ctx, path) + if err != nil { + t.Fatalf("Open() error = %v", err) + } + defer reopened.Close() + + chunk, err := state.ResolveRefBytes(ctx, reopened, state.ChunkRef{ + ChunkID: second.ChunkID, + FrameOffset: second.FrameOffset, + HasFrameOffset: true, + Codec: CodecFile, + Segment: path, + }) + + if err != nil { + t.Fatalf("ResolveRefBytes(offset) error = %v", err) + } + if string(chunk.Data) != "second" || chunk.Ref.FrameOffset != second.FrameOffset { + t.Fatalf("ResolveRefBytes(offset) chunk = %+v, want second payload by frame offset", chunk) + } + if _, err := state.ResolveRefBytes(ctx, reopened, state.ChunkRef{ChunkID: first.ChunkID, FrameOffset: second.FrameOffset, HasFrameOffset: true, Codec: CodecFile, Segment: path}); err == nil { + t.Fatal("ResolveRefBytes(id mismatch) error = nil") + } + if _, err := state.ResolveRefBytes(ctx, reopened, state.ChunkRef{ChunkID: second.ChunkID, FrameOffset: second.FrameOffset, HasFrameOffset: true, Codec: CodecFile, Segment: path + ".other"}); err == nil { + t.Fatal("ResolveRefBytes(segment mismatch) error = nil") + } +} + +func TestFileStore_Good_OpenWithSegmentAlias(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + sourcePath := core.PathJoin(dir, "source.mvlog") + relocatedPath := core.PathJoin(dir, "relocated.mvlog") + source, err := Create(ctx, sourcePath) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + ref, err := source.PutBytes(ctx, []byte("relocated payload"), state.PutOptions{}) + if err != nil { + t.Fatalf("PutBytes() error = %v", err) + } + if err := source.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + read := core.ReadFile(sourcePath) + if !read.OK { + t.Fatalf("ReadFile(source) error = %s", read.Error()) + } + if write := core.WriteFile(relocatedPath, read.Value.([]byte), 0o600); !write.OK { + t.Fatalf("WriteFile(relocated) error = %s", write.Error()) + } + + strict, err := Open(ctx, relocatedPath) + if err != nil { + t.Fatalf("Open(relocated) error = %v", err) + } + if _, err := state.ResolveRefBytes(ctx, strict, ref); err == nil { + t.Fatal("strict ResolveRefBytes(source segment) error = nil") + } + if err := strict.Close(); err != nil { + t.Fatalf("strict Close() error = %v", err) + } + + aliased, err := OpenWithSegmentAlias(ctx, relocatedPath, sourcePath) + if err != nil { + t.Fatalf("OpenWithSegmentAlias() error = %v", err) + } + defer aliased.Close() + chunk, err := state.ResolveRefBytes(ctx, aliased, ref) + if err != nil { + t.Fatalf("ResolveRefBytes(alias) error = %v", err) + } + if string(chunk.Data) != "relocated payload" { + t.Fatalf("alias payload = %q, want relocated payload", string(chunk.Data)) + } + physicalRef := ref + physicalRef.Segment = relocatedPath + if _, err := state.ResolveRefBytes(ctx, aliased, physicalRef); err != nil { + t.Fatalf("ResolveRefBytes(physical segment) error = %v", err) + } + wrongRef := ref + wrongRef.Segment = sourcePath + ".wrong" + if _, err := state.ResolveRefBytes(ctx, aliased, wrongRef); err == nil { + t.Fatal("ResolveRefBytes(wrong segment) error = nil") + } +} + +func TestFileStore_Good_OpenRegionWithSegmentAlias(t *testing.T) { + ctx := context.Background() + dir := t.TempDir() + sourcePath := core.PathJoin(dir, "source.mvlog") + containerPath := core.PathJoin(dir, "session.kv") + source, err := Create(ctx, sourcePath) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + first, err := source.PutBytes(ctx, []byte("first region payload"), state.PutOptions{URI: "mlx://region/first"}) + if err != nil { + t.Fatalf("PutBytes(first) error = %v", err) + } + second, err := source.PutBytes(ctx, []byte("second region payload"), state.PutOptions{URI: "mlx://region/second"}) + if err != nil { + t.Fatalf("PutBytes(second) error = %v", err) + } + if err := source.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + read := core.ReadFile(sourcePath) + if !read.OK { + t.Fatalf("ReadFile(source) error = %s", read.Error()) + } + prefix := []byte("KVST-test-header") + suffix := []byte("not-state-log-tail") + sourceBytes := read.Value.([]byte) + container := append(append(append([]byte(nil), prefix...), sourceBytes...), suffix...) + if write := core.WriteFile(containerPath, container, 0o600); !write.OK { + t.Fatalf("WriteFile(container) error = %s", write.Error()) + } + + store, err := OpenRegionWithSegmentAlias(ctx, containerPath, int64(len(prefix)), int64(len(sourceBytes)), sourcePath) + if err != nil { + t.Fatalf("OpenRegionWithSegmentAlias() error = %v", err) + } + defer store.Close() + if store.Path() != containerPath { + t.Fatalf("Path() = %q, want container path", store.Path()) + } + if store.ChunkCount() != 2 { + t.Fatalf("ChunkCount() = %d, want 2", store.ChunkCount()) + } + chunk, err := state.ResolveRefBytes(ctx, store, second) + if err != nil { + t.Fatalf("ResolveRefBytes(alias region) error = %v", err) + } + if string(chunk.Data) != "second region payload" || chunk.Ref.FrameOffset != second.FrameOffset { + t.Fatalf("region chunk = %+v, want second payload at original frame offset", chunk) + } + borrowed, err := state.BorrowRefBytes(ctx, store, second) + if err != nil { + t.Fatalf("BorrowRefBytes(alias region) error = %v", err) + } + if string(borrowed.Data) != "second region payload" || borrowed.Ref.FrameOffset != second.FrameOffset { + t.Fatalf("borrowed region chunk = %+v, want second payload at original frame offset", borrowed) + } + byURI, err := state.ResolveURI(ctx, store, "mlx://region/first") + if err != nil { + t.Fatalf("ResolveURI(region) error = %v", err) + } + if byURI.Text != "first region payload" || byURI.Ref.FrameOffset != first.FrameOffset { + t.Fatalf("ResolveURI(region) = %+v, want first payload with relative offset", byURI) + } + physicalRef := second + physicalRef.Segment = containerPath + if _, err := state.ResolveRefBytes(ctx, store, physicalRef); err != nil { + t.Fatalf("ResolveRefBytes(physical region) error = %v", err) + } + wrongRef := second + wrongRef.Segment = sourcePath + ".wrong" + if _, err := state.ResolveRefBytes(ctx, store, wrongRef); err == nil { + t.Fatal("ResolveRefBytes(wrong region segment) error = nil") + } + if _, err := state.BorrowRefBytes(ctx, store, wrongRef); err == nil { + t.Fatal("BorrowRefBytes(wrong region segment) error = nil") + } + if _, err := store.PutBytes(ctx, []byte("blocked"), state.PutOptions{}); err == nil { + t.Fatal("PutBytes(read-only region) error = nil") + } +} + +func TestFileStore_Good_StreamPayload(t *testing.T) { + ctx := context.Background() + path := core.PathJoin(t.TempDir(), "stream.mvlog") + store, err := Create(ctx, path) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + ref, err := store.PutBytesStream(ctx, 5, state.PutOptions{URI: "mlx://stream/1"}, func(writer stdio.Writer) error { + if _, err := writer.Write([]byte("he")); err != nil { + return err + } + _, err := writer.Write([]byte("llo")) + return err + }) + if err != nil { + t.Fatalf("PutBytesStream() error = %v", err) + } + if err := store.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + reopened, err := Open(ctx, path) + if err != nil { + t.Fatalf("Open() error = %v", err) + } + defer reopened.Close() + chunk, err := state.ResolveBytes(ctx, reopened, ref.ChunkID) + if err != nil { + t.Fatalf("ResolveBytes(stream) error = %v", err) + } + if string(chunk.Data) != "hello" { + t.Fatalf("streamed payload = %q, want hello", string(chunk.Data)) + } +} + +func TestFileStore_Bad_MissingChunk(t *testing.T) { + store, err := Create(context.Background(), core.PathJoin(t.TempDir(), "empty.mvlog")) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + defer store.Close() + + _, err = store.Get(context.Background(), 99) + + if !core.Is(err, state.ErrChunkNotFound) { + t.Fatalf("Get(missing) error = %v, want ErrChunkNotFound", err) + } +} + +func TestFileStore_Bad_InvalidInputs(t *testing.T) { + if _, err := Create(context.Background(), ""); err == nil { + t.Fatal("Create(empty) error = nil, want path error") + } + if _, err := Open(context.Background(), ""); err == nil { + t.Fatal("Open(empty) error = nil, want path error") + } + if _, err := (*Store)(nil).PutBytes(context.Background(), []byte("x"), state.PutOptions{}); err == nil { + t.Fatal("PutBytes(nil store) error = nil") + } + if _, err := (*Store)(nil).ResolveBytes(context.Background(), 1); !core.Is(err, state.ErrChunkNotFound) { + t.Fatalf("ResolveBytes(nil store) error = %v, want ErrChunkNotFound", err) + } + streamPath := core.PathJoin(t.TempDir(), "invalid-stream.mvlog") + store, err := Create(context.Background(), streamPath) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + defer store.Close() + if _, err := store.PutBytesStream(context.Background(), -1, state.PutOptions{}, func(writer stdio.Writer) error { + return nil + }); err == nil { + t.Fatal("PutBytesStream(negative size) error = nil") + } + if _, err := store.PutBytesStream(context.Background(), 1, state.PutOptions{}, nil); err == nil { + t.Fatal("PutBytesStream(nil writer) error = nil") + } + if _, err := store.PutBytesStream(context.Background(), 2, state.PutOptions{}, func(writer stdio.Writer) error { + _, err := writer.Write([]byte("x")) + return err + }); err == nil { + t.Fatal("PutBytesStream(short payload) error = nil") + } + if _, err := store.PutBytesStream(context.Background(), 1, state.PutOptions{}, func(writer stdio.Writer) error { + _, err := writer.Write([]byte("too long")) + return err + }); err == nil { + t.Fatal("PutBytesStream(oversized payload) error = nil") + } + if store.ChunkCount() != 0 { + t.Fatalf("ChunkCount() = %d after failed streams, want 0", store.ChunkCount()) + } + if err := store.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + reopened, err := Open(context.Background(), streamPath) + if err != nil { + t.Fatalf("Open(after failed streams) error = %v", err) + } + defer reopened.Close() + if reopened.ChunkCount() != 0 { + t.Fatalf("reopened ChunkCount() = %d after failed streams, want 0", reopened.ChunkCount()) + } +} + +func TestFileStore_Bad_ClosedStore(t *testing.T) { + store, err := Create(context.Background(), core.PathJoin(t.TempDir(), "closed.mvlog")) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + if err := store.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + if err := store.Close(); err != nil { + t.Fatalf("Close(second) error = %v", err) + } + if _, err := store.Put(context.Background(), "payload", state.PutOptions{}); err == nil { + t.Fatal("Put(closed) error = nil") + } + if _, err := store.Resolve(context.Background(), 1); err == nil { + t.Fatal("Resolve(closed) error = nil") + } + if _, err := store.ResolveBytes(context.Background(), 1); err == nil { + t.Fatal("ResolveBytes(closed) error = nil") + } + if _, err := store.ResolveURI(context.Background(), "mlx://missing"); err == nil { + t.Fatal("ResolveURI(closed) error = nil") + } +} + +func TestFileStore_Bad_InvalidFile(t *testing.T) { + path := core.PathJoin(t.TempDir(), "invalid.mvlog") + if result := core.WriteFile(path, []byte("not a state log"), 0o600); !result.OK { + t.Fatalf("WriteFile() error = %s", result.Error()) + } + if _, err := Open(context.Background(), path); err == nil { + t.Fatal("Open(invalid header) error = nil") + } +} + +func TestFileStore_Bad_CorruptRecords(t *testing.T) { + cases := []struct { + name string + data []byte + }{ + { + name: "truncated-record-header", + data: append(append([]byte(nil), fileMagic...), recordMagic[:2]...), + }, + { + name: "invalid-record-header", + data: append(append([]byte(nil), fileMagic...), make([]byte, recordHeaderLen)...), + }, + { + name: "truncated-payload", + data: append(append(append([]byte(nil), fileMagic...), testHeader(1, 4, 0)...), []byte{1, 2}...), + }, + { + name: "invalid-metadata", + data: append(append(append([]byte(nil), fileMagic...), testHeader(1, 0, 1)...), []byte("{")...), + }, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + path := core.PathJoin(t.TempDir(), tc.name+".mvlog") + if result := core.WriteFile(path, tc.data, 0o600); !result.OK { + t.Fatalf("WriteFile() error = %s", result.Error()) + } + if _, err := Open(context.Background(), path); err == nil { + t.Fatalf("Open(%s) error = nil, want corruption error", tc.name) + } + }) + } +} + +func TestFileStore_Ugly_CancelledContext(t *testing.T) { + store, err := Create(context.Background(), core.PathJoin(t.TempDir(), "cancelled.mvlog")) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + defer store.Close() + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err = store.Put(ctx, "payload", state.PutOptions{}) + + if !core.Is(err, context.Canceled) { + t.Fatalf("Put(cancelled) error = %v, want context.Canceled", err) + } + if _, err := store.Resolve(context.Background(), 1); !core.Is(err, state.ErrChunkNotFound) { + t.Fatalf("Resolve(after cancelled put) error = %v, want missing chunk", err) + } +} + +func TestFileStore_Good_IndexCapacityHintSkipsLargePayloadStores(t *testing.T) { + if got := indexCapacityHint(int64(len(fileMagic))+1024*indexHintRecordBytes, int64(len(fileMagic))); got != 1024 { + t.Fatalf("small-record hint = %d, want 1024", got) + } + if got := indexCapacityHint(int64(len(fileMagic))+indexHintMaxFileBytes+1, int64(len(fileMagic))); got != 0 { + t.Fatalf("large-payload hint = %d, want 0", got) + } + if got := indexCapacityHint(int64(len(fileMagic)), int64(len(fileMagic))); got != 0 { + t.Fatalf("empty hint = %d, want 0", got) + } +} + +// testHeader is a test-only wrapper that returns a fresh []byte built +// via encodeRecordHeader's in-place API. Production callers should use +// encodeRecordHeader directly with a stack-allocated [recordHeaderLen]byte. +func testHeader(chunkID, payloadSize, metaSize int) []byte { + buf := make([]byte, recordHeaderLen) + encodeRecordHeader(buf, chunkID, payloadSize, metaSize) + return buf +} + +// TestFileStore_Good_RebuildIndexPreservesIndexShape pins the index +// shape across rebuildIndex changes — Wave 8 perf rewrites can alter +// how the meta JSON is parsed, but the resulting index entries (per +// chunk id) must match a Put-built index 1:1 in ref + payload offset. +// The uriIndex must contain exactly the URIs that were Put with a +// non-empty URI, mapped to the same chunk ids. +func TestFileStore_Good_RebuildIndexPreservesIndexShape(t *testing.T) { + ctx := context.Background() + path := core.PathJoin(t.TempDir(), "rebuild-shape.mvlog") + store, err := Create(ctx, path) + if err != nil { + t.Fatalf("Create() error = %v", err) + } + // Mix records with URI, without URI, with tag-maps + label-slices, + // with empty meta — covers every branch rebuildIndex touches. + cases := []state.PutOptions{ + {URI: "mlx://kv/0", Title: "with-uri", Kind: "bench"}, + {}, // empty meta + {URI: "mlx://kv/2", Tags: map[string]string{"a": "1", "b": "2"}, Labels: []string{"x", "y"}}, + {Kind: "no-uri", Track: "tr"}, + {URI: "mlx://kv/4", Title: "another", Tags: map[string]string{}}, + } + payloads := [][]byte{ + []byte("alpha"), + []byte("beta"), + []byte("gamma"), + []byte("delta"), + []byte("epsilon"), + } + var putRefs []state.ChunkRef + for i, opts := range cases { + ref, err := store.PutBytes(ctx, payloads[i], opts) + if err != nil { + t.Fatalf("PutBytes(%d) error = %v", i, err) + } + putRefs = append(putRefs, ref) + } + // Snapshot the live index built by Put for later comparison. + store.mu.Lock() + putIndex := make(map[int]fileIndexEntry, len(store.index)) + for id, entry := range store.index { + putIndex[id] = entry + } + putURIIndex := make(map[string]int, len(store.uriIndex)) + for uri, id := range store.uriIndex { + putURIIndex[uri] = id + } + putNextID := store.nextID + putWriteAt := store.writeAt + store.mu.Unlock() + if err := store.Close(); err != nil { + t.Fatalf("Close() error = %v", err) + } + + reopened, err := Open(ctx, path) + if err != nil { + t.Fatalf("Open() error = %v", err) + } + defer reopened.Close() + + reopened.mu.Lock() + defer reopened.mu.Unlock() + + if reopened.nextID != putNextID { + t.Fatalf("rebuilt nextID = %d, want %d", reopened.nextID, putNextID) + } + if reopened.writeAt != putWriteAt { + t.Fatalf("rebuilt writeAt = %d, want %d", reopened.writeAt, putWriteAt) + } + if len(reopened.index) != len(putIndex) { + t.Fatalf("rebuilt index size = %d, want %d", len(reopened.index), len(putIndex)) + } + for id, want := range putIndex { + got, ok := reopened.index[id] + if !ok { + t.Fatalf("rebuilt index missing chunk id %d", id) + } + if got.ref != want.ref { + t.Fatalf("rebuilt entry[%d].ref = %+v, want %+v", id, got.ref, want.ref) + } + if got.payloadAt != want.payloadAt { + t.Fatalf("rebuilt entry[%d].payloadAt = %d, want %d", id, got.payloadAt, want.payloadAt) + } + if got.payloadSize != want.payloadSize { + t.Fatalf("rebuilt entry[%d].payloadSize = %d, want %d", id, got.payloadSize, want.payloadSize) + } + } + if len(reopened.uriIndex) != len(putURIIndex) { + t.Fatalf("rebuilt uriIndex size = %d, want %d", len(reopened.uriIndex), len(putURIIndex)) + } + for uri, wantID := range putURIIndex { + gotID, ok := reopened.uriIndex[uri] + if !ok { + t.Fatalf("rebuilt uriIndex missing %q", uri) + } + if gotID != wantID { + t.Fatalf("rebuilt uriIndex[%q] = %d, want %d", uri, gotID, wantID) + } + } + _ = putRefs +} + +// TestEncodeRecordMeta_RoundTrip locks the hand-rolled encoder to +// encoding/json's deserialisation contract. The encoder is the +// canonical PutBytesStream meta serialiser — every record we write +// passes through it, so its output must round-trip cleanly through +// json.Unmarshal back into recordMeta with no field loss or value +// drift. Mixed shapes (empty, single string, tag map, label slice, +// escape-sensitive characters) cover the branches the encoder +// walks. +func TestEncodeRecordMeta_RoundTrip(t *testing.T) { + cases := []struct { + name string + meta recordMeta + }{ + {"empty", recordMeta{}}, + {"uri-only", recordMeta{URI: "mlx://kv/0"}}, + {"all-strings", recordMeta{ + URI: "mlx://kv/1", + Title: "training-checkpoint", + Kind: "kv", + Track: "primary", + }}, + {"tags-1", recordMeta{ + URI: "mlx://kv/2", + Tags: map[string]string{"epoch": "3"}, + }}, + {"tags-many", recordMeta{ + URI: "mlx://kv/3", + Tags: map[string]string{ + "epoch": "3", "track": "primary", + "branch": "dev", "runner": "homelab", + }, + }}, + {"labels", recordMeta{ + URI: "mlx://kv/4", + Labels: []string{"k0:v0", "k1:v1"}, + }}, + {"full", recordMeta{ + URI: "mlx://kv/5", Title: "bench", Kind: "training", + Track: "primary", Tags: map[string]string{"a": "1"}, + Labels: []string{"x"}, + }}, + {"escapes", recordMeta{ + Title: `quote " and backslash \ and slash /`, + Kind: "tabs\tand\nnewlines", + Tags: map[string]string{"control": "\x01\x02"}, + }}, + {"unicode", recordMeta{ + Title: "ünïcödé", + Labels: []string{"日本", "🐦"}, + }}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + encoded := encodeRecordMeta(&tc.meta) + var decoded recordMeta + if result := core.JSONUnmarshal(encoded, &decoded); !result.OK { + t.Fatalf("JSONUnmarshal(%s) error: %v\nencoded: %s", tc.name, result.Value, encoded) + } + if decoded.URI != tc.meta.URI { + t.Fatalf("URI = %q, want %q", decoded.URI, tc.meta.URI) + } + if decoded.Title != tc.meta.Title { + t.Fatalf("Title = %q, want %q", decoded.Title, tc.meta.Title) + } + if decoded.Kind != tc.meta.Kind { + t.Fatalf("Kind = %q, want %q", decoded.Kind, tc.meta.Kind) + } + if decoded.Track != tc.meta.Track { + t.Fatalf("Track = %q, want %q", decoded.Track, tc.meta.Track) + } + if len(decoded.Tags) != len(tc.meta.Tags) { + t.Fatalf("Tags len = %d, want %d", len(decoded.Tags), len(tc.meta.Tags)) + } + for k, v := range tc.meta.Tags { + if decoded.Tags[k] != v { + t.Fatalf("Tags[%q] = %q, want %q", k, decoded.Tags[k], v) + } + } + if len(decoded.Labels) != len(tc.meta.Labels) { + t.Fatalf("Labels len = %d, want %d", len(decoded.Labels), len(tc.meta.Labels)) + } + for i, v := range tc.meta.Labels { + if decoded.Labels[i] != v { + t.Fatalf("Labels[%d] = %q, want %q", i, decoded.Labels[i], v) + } + } + // extractRecordURI must also accept the encoder output. + uri, err := extractRecordURI(encoded) + if err != nil { + t.Fatalf("extractRecordURI: %v\nencoded: %s", err, encoded) + } + if uri != tc.meta.URI { + t.Fatalf("extractRecordURI URI = %q, want %q", uri, tc.meta.URI) + } + }) + } +} diff --git a/go/state/hierarchy_bench_test.go b/go/state/hierarchy_bench_test.go new file mode 100644 index 0000000..6f8c11b --- /dev/null +++ b/go/state/hierarchy_bench_test.go @@ -0,0 +1,203 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the Store interface-dispatch hierarchy. +// Per AX-11 — Store is a layered interface (Store / Resolver / URIResolver / +// BinaryResolver / RefBinaryResolver / Writer / BinaryWriter / +// BinaryStreamWriter). The top-level dispatchers (Resolve, ResolveBytes, +// ResolveRefBytes, ResolveURI) probe each interface in turn. The Wake +// path for a project seed can issue dozens of dispatches per restore; +// the cost of an interface-probe miss compounds in that flow. +// +// Run: go test -bench='BenchmarkHierarchy' -benchmem -run='^$' ./state + +package state + +import ( + "context" + "testing" +) + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + hierarchySinkChunk Chunk + hierarchySinkErr error + hierarchySinkText string + hierarchySinkRef ChunkRef +) + +// --- Interface-probe miss paths --- +// When a Store implements ONLY Store.Get, the top-level dispatcher must +// type-assert against Resolver / BinaryResolver / RefBinaryResolver / +// URIResolver. Each miss costs a runtime probe. The fallback branch +// then synthesises a Chunk. + +func BenchmarkHierarchy_GetAdapter_Resolve(b *testing.B) { + // benchGetOnlyStore is the bare Store.Get adapter — Resolve walks + // the Resolver-not-implemented branch and constructs a Chunk wrapper + // around the returned text. + store := &benchGetOnlyStore{text: string(make([]byte, 256))} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkChunk, hierarchySinkErr = Resolve(ctx, store, 1) + } +} + +func BenchmarkHierarchy_GetAdapter_ResolveBytes(b *testing.B) { + store := &benchGetOnlyStore{text: string(make([]byte, 256))} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkChunk, hierarchySinkErr = ResolveBytes(ctx, store, 1) + } +} + +// --- Multi-resolver fallback chain --- +// hierarchyResolverShim implements Store + Resolver but NOT +// BinaryResolver. ResolveBytes therefore goes through the Resolve +// fallback that copies chunk.Text → chunk.Data. Common in dappcore +// wrappers that adapt a remote storage backend. + +func BenchmarkHierarchy_ResolverOnly_ResolveBytes(b *testing.B) { + store := &hierarchyResolverShim{ + ref: ChunkRef{ChunkID: 1, FrameOffset: 1, HasFrameOffset: true, Codec: CodecMemory}, + text: string(make([]byte, 1024)), + } + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkChunk, hierarchySinkErr = ResolveBytes(ctx, store, 1) + } +} + +func BenchmarkHierarchy_ResolverOnly_ResolveRefBytes(b *testing.B) { + // ResolveRefBytes falls through to ResolveBytes → Resolve when the + // Store implements neither RefBinaryResolver nor BinaryResolver. + store := &hierarchyResolverShim{ + ref: ChunkRef{ChunkID: 1, FrameOffset: 1, HasFrameOffset: true, Codec: CodecMemory}, + text: string(make([]byte, 1024)), + } + ctx := context.Background() + ref := ChunkRef{ChunkID: 1, FrameOffset: 1, HasFrameOffset: true} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkChunk, hierarchySinkErr = ResolveRefBytes(ctx, store, ref) + } +} + +// --- BinaryResolver path without RefBinaryResolver --- +// hierarchyBinaryShim implements Store + BinaryResolver. ResolveRefBytes +// must fall through to ResolveBytes (the BinaryResolver-without-Ref path). + +func BenchmarkHierarchy_BinaryOnly_ResolveRefBytes(b *testing.B) { + store := &hierarchyBinaryShim{ + ref: ChunkRef{ChunkID: 1, FrameOffset: 1, HasFrameOffset: true, Codec: CodecMemory}, + data: make([]byte, 1024), + } + ctx := context.Background() + ref := ChunkRef{ChunkID: 1, FrameOffset: 1, HasFrameOffset: true} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkChunk, hierarchySinkErr = ResolveRefBytes(ctx, store, ref) + } +} + +// --- MergeRef shape coverage --- +// MergeRef merges an overlay onto a base ref. The existing bench file +// covers OverlayAll / OverlayPartial / OverlayEmpty. These cover the +// less-typical permutations: same-base (no-op merge), zero-id base, +// codec-only overlay, segment-only overlay, frame-offset only overlay. + +func BenchmarkHierarchy_MergeRef_SameBase(b *testing.B) { + base := ChunkRef{ChunkID: 7, FrameOffset: 7, HasFrameOffset: true, Codec: CodecMemory, Segment: "seg-a"} + overlay := base + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkRef = MergeRef(base, overlay) + } +} + +func BenchmarkHierarchy_MergeRef_ZeroBase(b *testing.B) { + // Zero base — every field on overlay wins, but the no-id branch + // short-circuits the merge. + overlay := ChunkRef{ChunkID: 7, FrameOffset: 7, HasFrameOffset: true, Codec: CodecMemory} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkRef = MergeRef(ChunkRef{}, overlay) + } +} + +func BenchmarkHierarchy_MergeRef_CodecOnlyOverlay(b *testing.B) { + base := ChunkRef{ChunkID: 7, FrameOffset: 7, HasFrameOffset: true, Codec: CodecMemory} + overlay := ChunkRef{Codec: CodecStateVideo} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkRef = MergeRef(base, overlay) + } +} + +func BenchmarkHierarchy_MergeRef_SegmentOnlyOverlay(b *testing.B) { + base := ChunkRef{ChunkID: 7, FrameOffset: 7, HasFrameOffset: true, Codec: CodecMemory} + overlay := ChunkRef{Segment: "epoch-9"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkRef = MergeRef(base, overlay) + } +} + +func BenchmarkHierarchy_MergeRef_FrameOffsetOnlyOverlay(b *testing.B) { + base := ChunkRef{ChunkID: 7, FrameOffset: 7, HasFrameOffset: true, Codec: CodecMemory} + overlay := ChunkRef{FrameOffset: 99, HasFrameOffset: true} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + hierarchySinkRef = MergeRef(base, overlay) + } +} + +// --- Shim helpers --- +// One file holds the shim defs to keep the bench surface flat. + +// hierarchyResolverShim implements Store.Get + Resolver but not the +// binary interfaces. Forces ResolveBytes/ResolveRefBytes to dispatch +// through the Resolver fallback which copies Text → Data. +type hierarchyResolverShim struct { + ref ChunkRef + text string +} + +func (s *hierarchyResolverShim) Get(_ context.Context, _ int) (string, error) { + return s.text, nil +} + +func (s *hierarchyResolverShim) Resolve(_ context.Context, chunkID int) (Chunk, error) { + ref := s.ref + ref.ChunkID = chunkID + return Chunk{Ref: ref, Text: s.text}, nil +} + +// hierarchyBinaryShim implements Store.Get + BinaryResolver but not +// RefBinaryResolver. ResolveRefBytes must fall through ResolveBytes. +type hierarchyBinaryShim struct { + ref ChunkRef + data []byte +} + +func (s *hierarchyBinaryShim) Get(_ context.Context, _ int) (string, error) { + return string(s.data), nil +} + +func (s *hierarchyBinaryShim) ResolveBytes(_ context.Context, chunkID int) (Chunk, error) { + ref := s.ref + ref.ChunkID = chunkID + return Chunk{Ref: ref, Data: append([]byte(nil), s.data...)}, nil +} diff --git a/go/state/identity.go b/go/state/identity.go new file mode 100644 index 0000000..ac4d512 --- /dev/null +++ b/go/state/identity.go @@ -0,0 +1,103 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package state + +// ModelIdentity carries backend-neutral model metadata for state bundles, +// benchmark reports, fit planning, and adapter compatibility checks. +type ModelIdentity struct { + ID string `json:"id,omitempty"` + Path string `json:"path,omitempty"` + Architecture string `json:"architecture,omitempty"` + Revision string `json:"revision,omitempty"` + Hash string `json:"hash,omitempty"` + QuantBits int `json:"quant_bits,omitempty"` + QuantGroup int `json:"quant_group,omitempty"` + QuantType string `json:"quant_type,omitempty"` + ContextLength int `json:"context_length,omitempty"` + NumLayers int `json:"num_layers,omitempty"` + HiddenSize int `json:"hidden_size,omitempty"` + VocabSize int `json:"vocab_size,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TokenizerIdentity carries tokenizer and chat-template metadata without +// exposing backend-specific tokenizer implementations. +type TokenizerIdentity struct { + Kind string `json:"kind,omitempty"` + Path string `json:"path,omitempty"` + Hash string `json:"hash,omitempty"` + ChatTemplate string `json:"chat_template,omitempty"` + BOSID int32 `json:"bos_id,omitempty"` + EOSID int32 `json:"eos_id,omitempty"` + PADID int32 `json:"pad_id,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// AdapterIdentity is the portable identity for an active or saved adapter. +type AdapterIdentity struct { + Path string `json:"path,omitempty"` + Hash string `json:"hash,omitempty"` + Format string `json:"format,omitempty"` + Rank int `json:"rank,omitempty"` + Alpha float32 `json:"alpha,omitempty"` + TargetKeys []string `json:"target_keys,omitempty"` + BaseModelHash string `json:"base_model_hash,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// RuntimeIdentity records runtime and device metadata for reproducibility. +type RuntimeIdentity struct { + Backend string `json:"backend,omitempty"` + Device string `json:"device,omitempty"` + Version string `json:"version,omitempty"` + CacheMode string `json:"cache_mode,omitempty"` + NativeRuntime bool `json:"native_runtime,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// SamplerConfig is the serializable form of generation sampler settings. +type SamplerConfig struct { + MaxTokens int `json:"max_tokens,omitempty"` + Temperature float32 `json:"temperature,omitempty"` + TopK int `json:"top_k,omitempty"` + TopP float32 `json:"top_p,omitempty"` + RepeatPenalty float32 `json:"repeat_penalty,omitempty"` + StopTokens []int32 `json:"stop_tokens,omitempty"` + StopSequences []string `json:"stop_sequences,omitempty"` + ReturnLogits bool `json:"return_logits,omitempty"` +} + +// StateRef points to backend-owned binary state, probe, or knowledge-pack data. +type StateRef struct { + Kind string `json:"kind,omitempty"` + URI string `json:"uri,omitempty"` + Hash string `json:"hash,omitempty"` + SizeBytes uint64 `json:"size_bytes,omitempty"` + Encoding string `json:"encoding,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// Bundle is a portable state envelope. It contains metadata and references, +// not backend tensor objects. +type Bundle struct { + Version string `json:"version,omitempty"` + CreatedAtUnix int64 `json:"created_at_unix,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Tokenizer TokenizerIdentity `json:"tokenizer,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Sampler SamplerConfig `json:"sampler,omitempty"` + Runtime RuntimeIdentity `json:"runtime,omitempty"` + PromptHash string `json:"prompt_hash,omitempty"` + PromptTokens int `json:"prompt_tokens,omitempty"` + GeneratedTokens int `json:"generated_tokens,omitempty"` + KVRefs []StateRef `json:"kv_refs,omitempty"` + ProbeRefs []StateRef `json:"probe_refs,omitempty"` + StateRefs []StateRef `json:"state_refs,omitempty"` + // Deprecated: use StateRefs. + MemvidRefs []StateRef `json:"memvid_refs,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// StateBundle keeps the previous package-level name available for callers +// that want the longer explicit spelling. +type StateBundle = Bundle diff --git a/go/state/identity_bench_test.go b/go/state/identity_bench_test.go new file mode 100644 index 0000000..4f413ac --- /dev/null +++ b/go/state/identity_bench_test.go @@ -0,0 +1,309 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the backend-neutral identity primitives. +// Per AX-11 — ModelIdentity / TokenizerIdentity / AdapterIdentity / +// RuntimeIdentity travel inside every WakeRequest, SleepRequest, and +// Bundle. Bundle itself is the durable envelope written on every +// Sleep and re-read on every Wake. The struct fields are flat but +// the slices (KVRefs, ProbeRefs, StateRefs) carry the per-bundle +// allocation cost. +// +// Run: go test -bench='Benchmark' -benchmem -run='^$' ./state + +package state + +import "testing" + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + identitySinkModel ModelIdentity + identitySinkTokenizer TokenizerIdentity + identitySinkAdapter AdapterIdentity + identitySinkRuntime RuntimeIdentity + identitySinkSampler SamplerConfig + identitySinkBundle Bundle + identitySinkStateRef StateRef +) + +// --- ModelIdentity (per-bundle, per-wake, per-sleep) --- + +func BenchmarkIdentity_Model_Construct_Minimal(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkModel = ModelIdentity{ + ID: "gemma4", + Architecture: "gemma4_text", + Hash: "model-a", + NumLayers: 28, + } + } +} + +func BenchmarkIdentity_Model_Construct_Full(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkModel = ModelIdentity{ + ID: "gemma4", + Path: "/Users/snider/Lethean/models/gemma4-27b", + Architecture: "gemma4_text", + Revision: "main", + Hash: "sha256:abcdefabcdef", + QuantBits: 4, + QuantGroup: 64, + QuantType: "jangtq", + ContextLength: 262144, + NumLayers: 28, + HiddenSize: 4096, + VocabSize: 262144, + } + } +} + +func BenchmarkIdentity_Model_Construct_FullWithLabels(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkModel = ModelIdentity{ + ID: "gemma4", + Path: "/Users/snider/Lethean/models/gemma4-27b", + Architecture: "gemma4_text", + Hash: "sha256:abcdefabcdef", + QuantBits: 4, + QuantGroup: 64, + QuantType: "jangtq", + ContextLength: 262144, + NumLayers: 28, + HiddenSize: 4096, + VocabSize: 262144, + Labels: map[string]string{ + "vendor": "google", + "family": "gemma", + "size": "27b", + "variant": "text", + "licence": "gemma-tos", + "upstream": "huggingface", + }, + } + } +} + +// --- TokenizerIdentity (per-bundle) --- + +func BenchmarkIdentity_Tokenizer_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkTokenizer = TokenizerIdentity{ + Kind: "sentencepiece", + Path: "/Users/snider/Lethean/models/gemma4-27b/tokenizer.model", + Hash: "sha256:tok-abc", + ChatTemplate: "gemma-it", + BOSID: 2, + EOSID: 1, + PADID: 0, + } + } +} + +// --- AdapterIdentity (per-bundle, per-wake compatibility check) --- + +func BenchmarkIdentity_Adapter_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkAdapter = AdapterIdentity{ + Path: "/Users/snider/Lethean/adapters/cladius.lora", + Hash: "sha256:adapter-abc", + Format: "lora", + Rank: 8, + Alpha: 16, + BaseModelHash: "sha256:abcdefabcdef", + } + } +} + +func BenchmarkIdentity_Adapter_Construct_WithTargetKeys(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkAdapter = AdapterIdentity{ + Path: "/Users/snider/Lethean/adapters/cladius.lora", + Hash: "sha256:adapter-abc", + Format: "lora", + Rank: 8, + Alpha: 16, + TargetKeys: []string{ + "q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj", + }, + BaseModelHash: "sha256:abcdefabcdef", + } + } +} + +// --- RuntimeIdentity (per-bundle) --- + +func BenchmarkIdentity_Runtime_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkRuntime = RuntimeIdentity{ + Backend: "metal", + Device: "Apple M3 Ultra", + Version: "26.0.0", + CacheMode: "paged-q8", + NativeRuntime: true, + } + } +} + +// --- SamplerConfig (per-generation step, per-bundle) --- + +func BenchmarkIdentity_Sampler_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkSampler = SamplerConfig{ + MaxTokens: 4096, + Temperature: 0.7, + TopK: 40, + TopP: 0.9, + RepeatPenalty: 1.1, + StopTokens: []int32{1, 2, 0}, + StopSequences: []string{"", "<|end|>"}, + ReturnLogits: false, + } + } +} + +// --- StateRef (per-block during bundle assembly) --- + +func BenchmarkIdentity_StateRef_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkStateRef = StateRef{ + Kind: "kv", + URI: "state://kv/blocks/0", + Hash: "sha256:block-abc", + SizeBytes: 65536, + Encoding: "raw", + } + } +} + +// --- Bundle (durable envelope — every Sleep writes one) --- + +func BenchmarkIdentity_Bundle_Construct_Minimal(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkBundle = Bundle{ + Version: "v1", + CreatedAtUnix: 1700000000, + Model: ModelIdentity{ID: "gemma4", Hash: "model-a"}, + PromptTokens: 2048, + } + } +} + +func BenchmarkIdentity_Bundle_Construct_KVRefs_10(b *testing.B) { + model := ModelIdentity{ID: "gemma4", Hash: "model-a", NumLayers: 28} + tok := TokenizerIdentity{Hash: "tok-a"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + kv := make([]StateRef, 0, 10) + for j := 0; j < 10; j++ { + kv = append(kv, StateRef{Kind: "kv", URI: "state://kv/blocks", SizeBytes: 65536}) + } + identitySinkBundle = Bundle{ + Version: "v1", + CreatedAtUnix: 1700000000, + Model: model, + Tokenizer: tok, + KVRefs: kv, + PromptTokens: 2048, + } + } +} + +func BenchmarkIdentity_Bundle_Construct_KVRefs_100(b *testing.B) { + model := ModelIdentity{ID: "gemma4", Hash: "model-a", NumLayers: 28} + tok := TokenizerIdentity{Hash: "tok-a"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + kv := make([]StateRef, 0, 100) + for j := 0; j < 100; j++ { + kv = append(kv, StateRef{Kind: "kv", URI: "state://kv/blocks", SizeBytes: 65536}) + } + identitySinkBundle = Bundle{ + Version: "v1", + CreatedAtUnix: 1700000000, + Model: model, + Tokenizer: tok, + KVRefs: kv, + PromptTokens: 2048, + } + } +} + +func BenchmarkIdentity_Bundle_Construct_KVRefs_1000(b *testing.B) { + model := ModelIdentity{ID: "gemma4", Hash: "model-a", NumLayers: 28} + tok := TokenizerIdentity{Hash: "tok-a"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + kv := make([]StateRef, 0, 1000) + for j := 0; j < 1000; j++ { + kv = append(kv, StateRef{Kind: "kv", URI: "state://kv/blocks", SizeBytes: 65536}) + } + identitySinkBundle = Bundle{ + Version: "v1", + CreatedAtUnix: 1700000000, + Model: model, + Tokenizer: tok, + KVRefs: kv, + PromptTokens: 2048, + } + } +} + +// --- Bundle copy (pure struct shape, no slice alloc) --- +// The Bundle struct copy fires on every WakeResult / SleepResult +// return; the slice headers are shared so this measures just the +// scalar+header cost. + +func BenchmarkIdentity_Bundle_Copy(b *testing.B) { + src := Bundle{ + Version: "v1", + CreatedAtUnix: 1700000000, + Model: ModelIdentity{ID: "gemma4", Hash: "model-a", NumLayers: 28}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "metal"}, + PromptTokens: 2048, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkBundle = src + } +} + +// StateBundle is the long-form type alias for Bundle — confirm zero overhead. + +func BenchmarkIdentity_StateBundle_AliasCopy(b *testing.B) { + src := StateBundle{ + Version: "v1", + Model: ModelIdentity{ID: "gemma4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + identitySinkBundle = src + } +} diff --git a/go/state/memory.go b/go/state/memory.go new file mode 100644 index 0000000..46b2885 --- /dev/null +++ b/go/state/memory.go @@ -0,0 +1,232 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package state + +import "context" + +type InMemoryStore struct { + chunks map[int]string + data map[int][]byte + refs map[int]ChunkRef + uris map[string]int + nextID int +} + +func NewInMemoryStore(chunks map[int]string) *InMemoryStore { + return NewInMemoryStoreWithManifest(chunks, nil) +} + +func NewInMemoryStoreWithManifest(chunks map[int]string, refs map[int]ChunkRef) *InMemoryStore { + // Single-pass over the seed map: populate text + default ref together so + // each id is visited once instead of twice. Refs override defaults below. + // All maps are lazy: when no chunks/refs are seeded the four backing + // maps stay nil and the four make() heap allocs are skipped entirely. + // Read sites (Resolve/ResolveBytes/ResolveURI) are nil-safe — Go maps + // return the zero value + ok=false from nil — and Put/PutBytes already + // lazy-init on first write. The bench-only NewInMemoryStore_Empty call + // pattern drops from 5 allocs / 240 B to 1 alloc / 32 B (just the + // Store struct). + var copyMap map[int]string + var refMap map[int]ChunkRef + if total := len(chunks) + len(refs); total > 0 { + copyMap = make(map[int]string, len(chunks)) + refMap = make(map[int]ChunkRef, total) + } + nextID := 1 + for id, text := range chunks { + copyMap[id] = text + refMap[id] = ChunkRef{ + ChunkID: id, + FrameOffset: uint64(id), + HasFrameOffset: true, + Codec: CodecMemory, + } + if id >= nextID { + nextID = id + 1 + } + } + for id, ref := range refs { + ref.ChunkID = id + refMap[id] = ref + if id >= nextID { + nextID = id + 1 + } + } + return &InMemoryStore{ + chunks: copyMap, + refs: refMap, + nextID: nextID, + } +} + +func (s *InMemoryStore) Get(ctx context.Context, chunkID int) (string, error) { + chunk, err := s.Resolve(ctx, chunkID) + if err != nil { + return "", err + } + return chunk.Text, nil +} + +func (s *InMemoryStore) Resolve(ctx context.Context, chunkID int) (Chunk, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return Chunk{}, ctx.Err() + default: + } + if s == nil { + return Chunk{}, &ChunkNotFoundError{ID: chunkID} + } + text, ok := s.chunks[chunkID] + data, dataOK := s.data[chunkID] + if !ok && !dataOK { + return Chunk{}, &ChunkNotFoundError{ID: chunkID} + } + ref := s.refs[chunkID] + if ref.ChunkID != chunkID { + ref.ChunkID = chunkID + } + chunk := Chunk{Ref: ref, Text: text} + if dataOK { + chunk.Data = append([]byte(nil), data...) + if chunk.Text == "" { + chunk.Text = string(data) + } + } + return chunk, nil +} + +func (s *InMemoryStore) ResolveBytes(ctx context.Context, chunkID int) (Chunk, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return Chunk{}, ctx.Err() + default: + } + if s == nil { + return Chunk{}, &ChunkNotFoundError{ID: chunkID} + } + ref := s.refs[chunkID] + if ref.ChunkID != chunkID { + ref.ChunkID = chunkID + } + if data, ok := s.data[chunkID]; ok { + return Chunk{Ref: ref, Data: append([]byte(nil), data...)}, nil + } + text, ok := s.chunks[chunkID] + if !ok { + return Chunk{}, &ChunkNotFoundError{ID: chunkID} + } + return Chunk{Ref: ref, Text: text, Data: []byte(text)}, nil +} + +func (s *InMemoryStore) ResolveURI(ctx context.Context, uri string) (Chunk, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return Chunk{}, ctx.Err() + default: + } + if s == nil { + return Chunk{}, &URIChunkNotFoundError{URI: uri} + } + id, ok := s.uris[uri] + if !ok { + return Chunk{}, &URIChunkNotFoundError{URI: uri} + } + return s.Resolve(ctx, id) +} + +func (s *InMemoryStore) Put(ctx context.Context, text string, opts PutOptions) (ChunkRef, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ChunkRef{}, ctx.Err() + default: + } + if s == nil { + return ChunkRef{}, &ChunkNotFoundError{} + } + if s.chunks == nil { + s.chunks = make(map[int]string) + } + if s.refs == nil { + s.refs = make(map[int]ChunkRef) + } + if s.data == nil { + s.data = make(map[int][]byte) + } + if s.uris == nil { + s.uris = make(map[string]int) + } + if s.nextID <= 0 { + s.nextID = 1 + } + id := s.nextID + s.nextID++ + ref := ChunkRef{ + ChunkID: id, + FrameOffset: uint64(id), + HasFrameOffset: true, + Codec: CodecMemory, + } + s.chunks[id] = text + delete(s.data, id) + s.refs[id] = ref + if opts.URI != "" { + s.uris[opts.URI] = id + } + return ref, nil +} + +func (s *InMemoryStore) PutBytes(ctx context.Context, data []byte, opts PutOptions) (ChunkRef, error) { + if ctx == nil { + ctx = context.Background() + } + select { + case <-ctx.Done(): + return ChunkRef{}, ctx.Err() + default: + } + if s == nil { + return ChunkRef{}, &ChunkNotFoundError{} + } + if s.chunks == nil { + s.chunks = make(map[int]string) + } + if s.data == nil { + s.data = make(map[int][]byte) + } + if s.refs == nil { + s.refs = make(map[int]ChunkRef) + } + if s.uris == nil { + s.uris = make(map[string]int) + } + if s.nextID <= 0 { + s.nextID = 1 + } + id := s.nextID + s.nextID++ + ref := ChunkRef{ + ChunkID: id, + FrameOffset: uint64(id), + HasFrameOffset: true, + Codec: CodecMemory, + } + delete(s.chunks, id) + s.data[id] = append([]byte(nil), data...) + s.refs[id] = ref + if opts.URI != "" { + s.uris[opts.URI] = id + } + return ref, nil +} diff --git a/go/state/memory_bench_test.go b/go/state/memory_bench_test.go new file mode 100644 index 0000000..20ade86 --- /dev/null +++ b/go/state/memory_bench_test.go @@ -0,0 +1,295 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the InMemoryStore backend. +// Per AX-11 — InMemoryStore is the test-and-bench default store and +// the cheapest target for cache-warm-up shapes. Get / Resolve fire +// per chunk on every session load; Put / PutBytes fire per Save. +// ResolveURI is the per-name lookup that backs the URIResolver path +// in the top-level state.ResolveURI helper. +// +// Run: go test -bench='Benchmark' -benchmem -run='^$' ./state + +package state + +import ( + "context" + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + memorySinkChunk Chunk + memorySinkText string + memorySinkRef ChunkRef + memorySinkErr error + memorySinkStorePtr *InMemoryStore +) + +// benchMemoryStore builds an InMemoryStore with n text chunks of +// payloadSize bytes each + n URIs registered for ResolveURI lookups. +func benchMemoryStore(tb testing.TB, n, payloadSize int) *InMemoryStore { + tb.Helper() + chunks := make(map[int]string, n) + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte('a' + i%26) + } + text := string(payload) + for i := 1; i <= n; i++ { + chunks[i] = text + } + store := NewInMemoryStore(chunks) + // Register URIs after the fact via Put — keeps the bench helper + // off the URI-pre-seeding path the test file exercises. + for i := 1; i <= n; i++ { + _, err := store.Put(context.Background(), text, PutOptions{ + URI: "state://bench/" + core.Sprintf("chunk-%d", i), + }) + if err != nil { + tb.Fatal(err) + } + } + return store +} + +// --- NewInMemoryStore (one per session boot) --- + +func BenchmarkMemory_NewInMemoryStore_Empty(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkStorePtr = NewInMemoryStore(nil) + } +} + +func BenchmarkMemory_NewInMemoryStore_10(b *testing.B) { + chunks := map[int]string{ + 1: "a", 2: "b", 3: "c", 4: "d", 5: "e", + 6: "f", 7: "g", 8: "h", 9: "i", 10: "j", + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkStorePtr = NewInMemoryStore(chunks) + } +} + +func BenchmarkMemory_NewInMemoryStore_100(b *testing.B) { + chunks := make(map[int]string, 100) + for i := 1; i <= 100; i++ { + chunks[i] = "chunk" + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkStorePtr = NewInMemoryStore(chunks) + } +} + +func BenchmarkMemory_NewInMemoryStore_1000(b *testing.B) { + chunks := make(map[int]string, 1000) + for i := 1; i <= 1000; i++ { + chunks[i] = "chunk" + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkStorePtr = NewInMemoryStore(chunks) + } +} + +func BenchmarkMemory_NewInMemoryStoreWithManifest_10(b *testing.B) { + chunks := map[int]string{ + 1: "a", 2: "b", 3: "c", 4: "d", 5: "e", + 6: "f", 7: "g", 8: "h", 9: "i", 10: "j", + } + refs := map[int]ChunkRef{ + 1: {ChunkID: 1, Codec: CodecStateVideo, FrameOffset: 7, HasFrameOffset: true}, + 2: {ChunkID: 2, Codec: CodecStateVideo, FrameOffset: 8, HasFrameOffset: true}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkStorePtr = NewInMemoryStoreWithManifest(chunks, refs) + } +} + +// --- Get (text read — Store interface, simplest path) --- + +func BenchmarkMemory_Get_Short(b *testing.B) { + store := benchMemoryStore(b, 1, 16) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkText, memorySinkErr = store.Get(ctx, 1) + } +} + +func BenchmarkMemory_Get_1KB(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkText, memorySinkErr = store.Get(ctx, 1) + } +} + +func BenchmarkMemory_Get_64KB(b *testing.B) { + store := benchMemoryStore(b, 1, 64*1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkText, memorySinkErr = store.Get(ctx, 1) + } +} + +// --- Resolve (Chunk read — Resolver interface) --- + +func BenchmarkMemory_Resolve_1KB(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkChunk, memorySinkErr = store.Resolve(ctx, 1) + } +} + +func BenchmarkMemory_Resolve_64KB(b *testing.B) { + store := benchMemoryStore(b, 1, 64*1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkChunk, memorySinkErr = store.Resolve(ctx, 1) + } +} + +// --- ResolveBytes (binary read — BinaryResolver path) --- + +func BenchmarkMemory_ResolveBytes_1KB(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkChunk, memorySinkErr = store.ResolveBytes(ctx, 1) + } +} + +func BenchmarkMemory_ResolveBytes_64KB(b *testing.B) { + store := benchMemoryStore(b, 1, 64*1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkChunk, memorySinkErr = store.ResolveBytes(ctx, 1) + } +} + +// --- ResolveURI (name → ID lookup, then Resolve) --- + +func BenchmarkMemory_ResolveURI_10Chunks(b *testing.B) { + store := benchMemoryStore(b, 10, 1024) + ctx := context.Background() + uri := "state://bench/chunk-1" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkChunk, memorySinkErr = store.ResolveURI(ctx, uri) + } +} + +func BenchmarkMemory_ResolveURI_1000Chunks(b *testing.B) { + store := benchMemoryStore(b, 1000, 1024) + ctx := context.Background() + uri := "state://bench/chunk-1" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkChunk, memorySinkErr = store.ResolveURI(ctx, uri) + } +} + +// --- Put (text write — fires per text Save) --- + +func BenchmarkMemory_Put_1KB(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + text := string(make([]byte, 1024)) + opts := PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkRef, memorySinkErr = store.Put(ctx, text, opts) + } +} + +func BenchmarkMemory_Put_64KB(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + text := string(make([]byte, 64*1024)) + opts := PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkRef, memorySinkErr = store.Put(ctx, text, opts) + } +} + +func BenchmarkMemory_Put_WithURI(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + text := string(make([]byte, 1024)) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkRef, memorySinkErr = store.Put(ctx, text, PutOptions{ + Kind: "bench", + URI: "state://bench/put", + }) + } +} + +// --- PutBytes (binary write — fires per binary Save) --- + +func BenchmarkMemory_PutBytes_1KB(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + data := make([]byte, 1024) + opts := PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkRef, memorySinkErr = store.PutBytes(ctx, data, opts) + } +} + +func BenchmarkMemory_PutBytes_64KB(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + data := make([]byte, 64*1024) + opts := PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkRef, memorySinkErr = store.PutBytes(ctx, data, opts) + } +} + +func BenchmarkMemory_PutBytes_1MB(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + data := make([]byte, 1024*1024) + opts := PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memorySinkRef, memorySinkErr = store.PutBytes(ctx, data, opts) + } +} diff --git a/go/state/memory_capacity_bench_test.go b/go/state/memory_capacity_bench_test.go new file mode 100644 index 0000000..fdef52f --- /dev/null +++ b/go/state/memory_capacity_bench_test.go @@ -0,0 +1,169 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for InMemoryStore at larger capacities. +// Per AX-11 — the existing memory bench file covers single-chunk and +// 10/100/1000-entry constructors, plus a 1000-chunk ResolveURI. This +// file extends to the eviction-pressure shapes that matter for the +// Virgil portable-memory thesis: continuous workspaces accumulate +// thousands of chunks before any rollover. Random + sequential read +// patterns expose the map-hash + slice-append cost at scale. +// +// Run: go test -bench='BenchmarkMemoryCapacity' -benchmem -run='^$' ./state + +package state + +import ( + "context" + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + memCapSinkChunk Chunk + memCapSinkText string + memCapSinkRef ChunkRef + memCapSinkErr error + memCapSinkStorePtr *InMemoryStore +) + +// memoryStoreNoURI populates n chunks WITHOUT URIs — avoids the +// per-chunk Put loop that would otherwise dominate setup time. URI +// presence is benched separately above; this file targets the bare +// map-driven read path. +func memoryStoreNoURI(tb testing.TB, n, payloadSize int) *InMemoryStore { + tb.Helper() + chunks := make(map[int]string, n) + payload := make([]byte, payloadSize) + for i := range payload { + payload[i] = byte('a' + i%26) + } + text := string(payload) + for i := 1; i <= n; i++ { + chunks[i] = text + } + return NewInMemoryStore(chunks) +} + +// --- Resolve at scale (sequential access) --- +// Walks IDs in registration order — the dominant pattern for a +// session-wake bundle replay (chunk-1, chunk-2, ..., chunk-N). + +func BenchmarkMemoryCapacity_Resolve_1k_Seq(b *testing.B) { + store := memoryStoreNoURI(b, 1000, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := (i % 1000) + 1 + memCapSinkChunk, memCapSinkErr = store.Resolve(ctx, id) + } +} + +func BenchmarkMemoryCapacity_Resolve_10k_Seq(b *testing.B) { + store := memoryStoreNoURI(b, 10000, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := (i % 10000) + 1 + memCapSinkChunk, memCapSinkErr = store.Resolve(ctx, id) + } +} + +// --- Get at scale --- +// Get is the bare Store.Get contract — the cheapest dispatch. + +func BenchmarkMemoryCapacity_Get_1k(b *testing.B) { + store := memoryStoreNoURI(b, 1000, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := (i % 1000) + 1 + memCapSinkText, memCapSinkErr = store.Get(ctx, id) + } +} + +func BenchmarkMemoryCapacity_Get_10k(b *testing.B) { + store := memoryStoreNoURI(b, 10000, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := (i % 10000) + 1 + memCapSinkText, memCapSinkErr = store.Get(ctx, id) + } +} + +// --- ResolveBytes at scale (binary-read path) --- + +func BenchmarkMemoryCapacity_ResolveBytes_1k(b *testing.B) { + store := memoryStoreNoURI(b, 1000, 256) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + id := (i % 1000) + 1 + memCapSinkChunk, memCapSinkErr = store.ResolveBytes(ctx, id) + } +} + +// --- Put growth (repeated insert into existing store) --- +// Models a Save loop on a live, already-warm store. The per-Put cost +// should be dominated by the map-write + ref construction; growing +// past the initial capacity exercises map-grow. + +func BenchmarkMemoryCapacity_Put_Repeated_1k(b *testing.B) { + store := memoryStoreNoURI(b, 1000, 256) + ctx := context.Background() + text := "growth" + opts := PutOptions{Kind: "bench"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memCapSinkRef, memCapSinkErr = store.Put(ctx, text, opts) + } +} + +// --- ResolveURI at scale (URI table lookup) --- +// 10k URIs in the lookup table. The existing 1000 bench shows the +// hot path; 10k tests the constant cost claim against larger maps. + +func BenchmarkMemoryCapacity_ResolveURI_10k(b *testing.B) { + store := memoryStoreNoURI(b, 10000, 256) + ctx := context.Background() + // Stage URIs via Put so the uri index is populated. Doing this in + // the helper would slow every other bench in this file. + for i := 1; i <= 10000; i++ { + _, err := store.Put(ctx, "x", PutOptions{ + URI: "state://bench/cap-" + core.Sprintf("%d", i), + }) + if err != nil { + b.Fatal(err) + } + } + uri := "state://bench/cap-5000" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memCapSinkChunk, memCapSinkErr = store.ResolveURI(ctx, uri) + } +} + +// --- NewInMemoryStore at very large size --- +// One-pass construction over 10k chunks — the seed-load cost for a +// large project bundle. + +func BenchmarkMemoryCapacity_NewInMemoryStore_10000(b *testing.B) { + chunks := make(map[int]string, 10000) + for i := 1; i <= 10000; i++ { + chunks[i] = "chunk" + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + memCapSinkStorePtr = NewInMemoryStore(chunks) + } +} diff --git a/go/state/project_seed.go b/go/state/project_seed.go new file mode 100644 index 0000000..4b798a1 --- /dev/null +++ b/go/state/project_seed.go @@ -0,0 +1,357 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package state + +import core "dappco.re/go" + +type ProjectSeedMode string + +const ( + ProjectSeedStateCheckpoint ProjectSeedMode = "state_checkpoint" + ProjectSeedReuseCurrent ProjectSeedMode = "reuse_current" + ProjectSeedSummaryWindow ProjectSeedMode = "summary_window" + ProjectSeedHybrid ProjectSeedMode = "hybrid" +) + +type ProjectSeedOptions struct { + BaseURI string `json:"base_uri,omitempty"` + ProjectID string `json:"project_id,omitempty"` + EntryURI string `json:"entry_uri,omitempty"` + BundleURI string `json:"bundle_uri,omitempty"` + IndexURI string `json:"index_uri,omitempty"` + Title string `json:"title,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +type ProjectSeed struct { + BaseURI string `json:"base_uri,omitempty"` + ProjectID string `json:"project_id,omitempty"` + EntryURI string `json:"entry_uri,omitempty"` + BundleURI string `json:"bundle_uri,omitempty"` + IndexURI string `json:"index_uri,omitempty"` + Title string `json:"title,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +type ProjectSeedWakeOptions struct { + Store any `json:"-"` + Model ModelIdentity `json:"model,omitempty"` + Tokenizer TokenizerIdentity `json:"tokenizer,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Runtime RuntimeIdentity `json:"runtime,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +type ProjectSeedContinuationOptions struct { + Mode ProjectSeedMode `json:"mode,omitempty"` + Store any `json:"-"` + EntryURI string `json:"entry_uri,omitempty"` + BundleURI string `json:"bundle_uri,omitempty"` + IndexURI string `json:"index_uri,omitempty"` + Title string `json:"title,omitempty"` + Parent WakeResult `json:"parent,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Tokenizer TokenizerIdentity `json:"tokenizer,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Runtime RuntimeIdentity `json:"runtime,omitempty"` + BlockSize int `json:"block_size,omitempty"` + Encoding string `json:"encoding,omitempty"` + Labels map[string]string `json:"labels,omitempty"` + Metadata map[string]string `json:"metadata,omitempty"` +} + +type ProjectSeedContinuationPlan struct { + Mode ProjectSeedMode `json:"mode,omitempty"` + Sleep SleepRequest `json:"sleep,omitempty"` + PersistState bool `json:"persist_state,omitempty"` + NeedsSummary bool `json:"needs_summary,omitempty"` + ReuseCurrentSeed bool `json:"reuse_current_seed,omitempty"` +} + +func NewProjectSeed(opts ProjectSeedOptions) ProjectSeed { + seed := ProjectSeed{ + BaseURI: cleanURI(opts.BaseURI), + ProjectID: cleanURI(opts.ProjectID), + EntryURI: cleanURI(opts.EntryURI), + BundleURI: cleanURI(opts.BundleURI), + IndexURI: cleanURI(opts.IndexURI), + Title: core.Trim(opts.Title), + Labels: cloneStringMap(opts.Labels), + Metadata: cloneStringMap(opts.Metadata), + } + if seed.BaseURI == "" { + seed.BaseURI = "state://projects" + } + if seed.ProjectID == "" { + seed.ProjectID = "default" + } + if seed.EntryURI == "" { + seed.EntryURI = joinURI(seed.BaseURI, seed.ProjectID, "seed") + } + if seed.BundleURI == "" { + seed.BundleURI = seed.EntryURI + "/bundle" + } + if seed.IndexURI == "" { + seed.IndexURI = seed.EntryURI + "/index" + } + if seed.Title == "" { + seed.Title = seed.ProjectID + " project seed" + } + return seed +} + +func (s ProjectSeed) WakeRequest(opts ProjectSeedWakeOptions) WakeRequest { + labels := mergeStringMaps(s.Labels, opts.Labels) + setProjectLabel(labels, s.ProjectID) + return WakeRequest{ + Store: opts.Store, + IndexURI: s.IndexURI, + EntryURI: s.EntryURI, + Model: opts.Model, + Tokenizer: opts.Tokenizer, + Adapter: opts.Adapter, + Runtime: opts.Runtime, + Labels: labels, + } +} + +func (s ProjectSeed) PlanContinuation(opts ProjectSeedContinuationOptions) ProjectSeedContinuationPlan { + mode := opts.Mode + if mode == "" { + mode = ProjectSeedStateCheckpoint + } + plan := ProjectSeedContinuationPlan{Mode: mode} + switch mode { + case ProjectSeedReuseCurrent: + plan.ReuseCurrentSeed = true + return plan + case ProjectSeedSummaryWindow: + plan.NeedsSummary = true + return plan + case ProjectSeedHybrid: + plan.PersistState = true + plan.NeedsSummary = true + default: + plan.Mode = ProjectSeedStateCheckpoint + plan.PersistState = true + } + plan.Sleep = s.sleepRequest(opts) + return plan +} + +func (s ProjectSeed) sleepRequest(opts ProjectSeedContinuationOptions) SleepRequest { + entryURI := cleanURI(opts.EntryURI) + if entryURI == "" { + entryURI = joinURI(s.BaseURI, s.ProjectID, "checkpoints", "latest") + } + bundleURI := cleanURI(opts.BundleURI) + if bundleURI == "" { + bundleURI = entryURI + "/bundle" + } + indexURI := cleanURI(opts.IndexURI) + if indexURI == "" { + indexURI = entryURI + "/index" + } + metadata := mergeStringMaps(s.Metadata, opts.Metadata) + setProjectLabel(metadata, s.ProjectID) + labels := mergeStringMaps(s.Labels, opts.Labels) + setProjectLabel(labels, s.ProjectID) + parent := opts.Parent.Entry + return SleepRequest{ + Store: opts.Store, + EntryURI: entryURI, + BundleURI: bundleURI, + IndexURI: indexURI, + ParentEntryURI: firstNonEmpty(parent.URI, s.EntryURI), + ParentBundleURI: firstNonEmpty(parent.BundleURI, s.BundleURI), + ParentIndexURI: firstNonEmpty(parent.IndexURI, s.IndexURI), + Title: firstNonEmpty(core.Trim(opts.Title), s.Title), + Model: opts.Model, + Tokenizer: opts.Tokenizer, + Adapter: opts.Adapter, + Runtime: opts.Runtime, + ReuseParentPrefix: true, + BlockSize: opts.BlockSize, + Encoding: opts.Encoding, + Labels: labels, + Metadata: metadata, + } +} + +type WakeCompatibilityReport struct { + Compatible bool `json:"compatible"` + SummaryRequired bool `json:"summary_required,omitempty"` + Reasons []string `json:"reasons,omitempty"` + Warnings []string `json:"warnings,omitempty"` +} + +func CheckWakeCompatibility(bundle Bundle, req WakeRequest) WakeCompatibilityReport { + if req.SkipCompatibilityCheck { + return WakeCompatibilityReport{ + Compatible: true, + Warnings: []string{"compatibility_check_skipped"}, + } + } + report := WakeCompatibilityReport{Compatible: true} + compareModelIdentity(&report, bundle, req.Model) + compareTokenizerIdentity(&report, bundle.Tokenizer, req.Tokenizer) + compareAdapterIdentity(&report, bundle.Adapter, req.Adapter) + compareRuntimeIdentity(&report, bundle.Runtime, req.Runtime) + report.Compatible = len(report.Reasons) == 0 + report.SummaryRequired = !report.Compatible + return report +} + +func compareModelIdentity(report *WakeCompatibilityReport, bundle Bundle, req ModelIdentity) { + model := bundle.Model + if model.Hash != "" && req.Hash != "" && model.Hash != req.Hash { + report.Reasons = append(report.Reasons, "model_hash_mismatch") + } + if model.Architecture != "" && req.Architecture != "" && model.Architecture != req.Architecture { + report.Reasons = append(report.Reasons, "model_architecture_mismatch") + } + if model.NumLayers > 0 && req.NumLayers > 0 && model.NumLayers != req.NumLayers { + report.Reasons = append(report.Reasons, "model_layer_mismatch") + } + if model.QuantBits > 0 && req.QuantBits > 0 && model.QuantBits != req.QuantBits { + report.Reasons = append(report.Reasons, "model_quantisation_mismatch") + } + prefixTokens := bundle.PromptTokens + bundle.GeneratedTokens + if prefixTokens <= 0 { + prefixTokens = bundle.PromptTokens + } + if req.ContextLength > 0 && prefixTokens > 0 && req.ContextLength < prefixTokens { + report.Reasons = append(report.Reasons, "context_length_too_small") + } +} + +func compareTokenizerIdentity(report *WakeCompatibilityReport, bundle, req TokenizerIdentity) { + if bundle.Hash != "" && req.Hash != "" && bundle.Hash != req.Hash { + report.Reasons = append(report.Reasons, "tokenizer_hash_mismatch") + } + if bundle.ChatTemplate != "" && req.ChatTemplate != "" && bundle.ChatTemplate != req.ChatTemplate { + report.Reasons = append(report.Reasons, "chat_template_mismatch") + } +} + +func compareAdapterIdentity(report *WakeCompatibilityReport, bundle, req AdapterIdentity) { + bundleActive := adapterIdentityActive(bundle) + reqActive := adapterIdentityActive(req) + switch { + case bundleActive && !reqActive: + report.Reasons = append(report.Reasons, "adapter_missing") + case !bundleActive && reqActive: + report.Reasons = append(report.Reasons, "adapter_unexpected") + case bundle.Hash != "" && req.Hash != "" && bundle.Hash != req.Hash: + report.Reasons = append(report.Reasons, "adapter_hash_mismatch") + case bundle.Path != "" && req.Path != "" && bundle.Path != req.Path: + report.Reasons = append(report.Reasons, "adapter_path_mismatch") + case bundle.Rank > 0 && req.Rank > 0 && bundle.Rank != req.Rank: + report.Reasons = append(report.Reasons, "adapter_rank_mismatch") + } +} + +func compareRuntimeIdentity(report *WakeCompatibilityReport, bundle, req RuntimeIdentity) { + if bundle.Backend != "" && req.Backend != "" && bundle.Backend != req.Backend { + report.Warnings = append(report.Warnings, "runtime_backend_changed") + } + if bundle.CacheMode != "" && req.CacheMode != "" && bundle.CacheMode != req.CacheMode { + report.Warnings = append(report.Warnings, "runtime_cache_mode_changed") + } +} + +func adapterIdentityActive(adapter AdapterIdentity) bool { + return adapter.Hash != "" || adapter.Path != "" || adapter.Format != "" || adapter.Rank != 0 || adapter.Alpha != 0 || len(adapter.TargetKeys) > 0 || adapter.BaseModelHash != "" +} + +func cleanURI(value string) string { + value = core.Trim(value) + value = core.TrimPrefix(value, "/") + return core.TrimSuffix(value, "/") +} + +func joinURI(base string, parts ...string) string { + // Walk parts twice — first to sum the exact final length, second to + // append into a pre-sized []byte buffer. cleanURI is alloc-free + // (string substring views), so the second walk is purely arithmetic + // + byte copies. The previous shape used core.NewBuilder() (heap + // pointer alloc) plus the Builder's internal buffer grow (second + // heap alloc); collapsing to a direct []byte buffer + core.AsString + // drops one heap alloc per call. The cleaned []string slot from the + // previous shape was stack-resident, so eliding it costs nothing. + cleanBase := cleanURI(base) + total := len(cleanBase) + for _, part := range parts { + p := cleanURI(part) + if p == "" { + continue + } + if total > 0 { + total++ // separator + } + total += len(p) + } + if total == 0 { + return "" + } + buf := make([]byte, 0, total) + if cleanBase != "" { + buf = append(buf, cleanBase...) + } + for _, part := range parts { + p := cleanURI(part) + if p == "" { + continue + } + if len(buf) > 0 { + buf = append(buf, '/') + } + buf = append(buf, p...) + } + return core.AsString(buf) +} + +func setProjectLabel(labels map[string]string, projectID string) { + if labels == nil || projectID == "" { + return + } + if labels["project_id"] == "" { + labels["project_id"] = projectID + } +} + +func mergeStringMaps(left, right map[string]string) map[string]string { + if len(left) == 0 && len(right) == 0 { + return nil + } + out := make(map[string]string, len(left)+len(right)+1) + for key, value := range left { + out[key] = value + } + for key, value := range right { + out[key] = value + } + return out +} + +func cloneStringMap(in map[string]string) map[string]string { + if len(in) == 0 { + return nil + } + out := make(map[string]string, len(in)) + for key, value := range in { + out[key] = value + } + return out +} + +func firstNonEmpty(values ...string) string { + for _, value := range values { + if core.Trim(value) != "" { + return value + } + } + return "" +} diff --git a/go/state/project_seed_bench_test.go b/go/state/project_seed_bench_test.go new file mode 100644 index 0000000..979d586 --- /dev/null +++ b/go/state/project_seed_bench_test.go @@ -0,0 +1,297 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the project-seed durable-checkpoint primitives. +// Per AX-11 — ProjectSeed is the per-project root; NewProjectSeed +// fires per workspace entry, WakeRequest / PlanContinuation fire per +// session boundary, and CheckWakeCompatibility fires before every +// model-state restore. The Labels / Metadata maps are the per-call +// allocation drivers; both shapes are benched here. +// +// Run: go test -bench='Benchmark' -benchmem -run='^$' ./state + +package state + +import "testing" + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + projectSeedSinkSeed ProjectSeed + projectSeedSinkWake WakeRequest + projectSeedSinkPlan ProjectSeedContinuationPlan + projectSeedSinkReport WakeCompatibilityReport +) + +// labelsMap builds a deterministic map of n distinct entries for +// benching map-merge + clone shapes. Each key is unique so the bench +// reflects the real per-entry map cost, not collision dedup. +func labelsMap(n int) map[string]string { + out := make(map[string]string, n) + for i := 0; i < n; i++ { + out[labelsKey(i)] = labelsValue(i) + } + return out +} + +func labelsKey(i int) string { + // Inline base-36 digits keep the key short + unique without + // pulling core.Sprintf onto the hot fixture path. + const digits = "0123456789abcdefghijklmnopqrstuvwxyz" + if i < 36 { + return "k" + string(digits[i]) + } + return "k" + string(digits[i/36]) + string(digits[i%36]) +} + +func labelsValue(i int) string { + const digits = "0123456789abcdefghijklmnopqrstuvwxyz" + if i < 36 { + return "v" + string(digits[i]) + } + return "v" + string(digits[i/36]) + string(digits[i%36]) +} + +// --- NewProjectSeed (per-workspace entry — sets defaults) --- + +func BenchmarkProjectSeed_NewProjectSeed_Minimal(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkSeed = NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + } +} + +func BenchmarkProjectSeed_NewProjectSeed_Defaulted(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // All URIs left empty so the default-fill branch runs. + projectSeedSinkSeed = NewProjectSeed(ProjectSeedOptions{ + ProjectID: "core/go-mlx", + }) + } +} + +func BenchmarkProjectSeed_NewProjectSeed_Labels_10(b *testing.B) { + labels := labelsMap(10) + metadata := labelsMap(10) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkSeed = NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + Labels: labels, + Metadata: metadata, + }) + } +} + +func BenchmarkProjectSeed_NewProjectSeed_Labels_100(b *testing.B) { + labels := labelsMap(100) + metadata := labelsMap(100) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkSeed = NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + Labels: labels, + Metadata: metadata, + }) + } +} + +// --- WakeRequest (per session boot) --- + +func BenchmarkProjectSeed_WakeRequest_Minimal(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedWakeOptions{ + Store: "store", + Model: ModelIdentity{ID: "gemma4", Hash: "model-a"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkWake = seed.WakeRequest(opts) + } +} + +func BenchmarkProjectSeed_WakeRequest_Labels_10(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + Labels: labelsMap(10), + }) + opts := ProjectSeedWakeOptions{ + Store: "store", + Model: ModelIdentity{ID: "gemma4"}, + Labels: labelsMap(10), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkWake = seed.WakeRequest(opts) + } +} + +func BenchmarkProjectSeed_WakeRequest_Labels_100(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + Labels: labelsMap(100), + }) + opts := ProjectSeedWakeOptions{ + Store: "store", + Model: ModelIdentity{ID: "gemma4"}, + Labels: labelsMap(100), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkWake = seed.WakeRequest(opts) + } +} + +// --- PlanContinuation (per session end — selects sleep shape) --- + +func BenchmarkProjectSeed_PlanContinuation_StateCheckpoint(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedContinuationOptions{ + Mode: ProjectSeedStateCheckpoint, + Store: "store", + Model: ModelIdentity{ID: "gemma4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkPlan = seed.PlanContinuation(opts) + } +} + +func BenchmarkProjectSeed_PlanContinuation_ReuseCurrent(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedContinuationOptions{Mode: ProjectSeedReuseCurrent} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkPlan = seed.PlanContinuation(opts) + } +} + +func BenchmarkProjectSeed_PlanContinuation_SummaryWindow(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedContinuationOptions{Mode: ProjectSeedSummaryWindow} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkPlan = seed.PlanContinuation(opts) + } +} + +func BenchmarkProjectSeed_PlanContinuation_Hybrid(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedContinuationOptions{ + Mode: ProjectSeedHybrid, + Store: "store", + Model: ModelIdentity{ID: "gemma4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkPlan = seed.PlanContinuation(opts) + } +} + +func BenchmarkProjectSeed_PlanContinuation_Labels_100(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + Labels: labelsMap(100), + Metadata: labelsMap(100), + }) + opts := ProjectSeedContinuationOptions{ + Mode: ProjectSeedStateCheckpoint, + Store: "store", + Model: ModelIdentity{ID: "gemma4"}, + Labels: labelsMap(100), + Metadata: labelsMap(100), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkPlan = seed.PlanContinuation(opts) + } +} + +// --- CheckWakeCompatibility (per restore — gates the wake) --- + +func BenchmarkProjectSeed_CheckWakeCompatibility_Compatible(b *testing.B) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "model-a", Architecture: "gemma4_text", NumLayers: 28, QuantBits: 4, ContextLength: 4096}, + Tokenizer: TokenizerIdentity{Hash: "tok-a", ChatTemplate: "chat-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"}, + PromptTokens: 2048, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "model-a", Architecture: "gemma4_text", NumLayers: 28, QuantBits: 4, ContextLength: 8192}, + Tokenizer: TokenizerIdentity{Hash: "tok-a", ChatTemplate: "chat-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkProjectSeed_CheckWakeCompatibility_Incompatible(b *testing.B) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "model-a", Architecture: "gemma4_text", NumLayers: 28, QuantBits: 4, ContextLength: 4096}, + Tokenizer: TokenizerIdentity{Hash: "tok-a", ChatTemplate: "chat-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"}, + PromptTokens: 2048, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "model-b", Architecture: "qwen3", NumLayers: 28, QuantBits: 8, ContextLength: 1024}, + Tokenizer: TokenizerIdentity{Hash: "tok-b", ChatTemplate: "chat-b"}, + Adapter: AdapterIdentity{}, + Runtime: RuntimeIdentity{Backend: "rocm", CacheMode: "paged-q4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkProjectSeed_CheckWakeCompatibility_Skip(b *testing.B) { + bundle := Bundle{Model: ModelIdentity{Hash: "model-a"}} + req := WakeRequest{SkipCompatibilityCheck: true} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + projectSeedSinkReport = CheckWakeCompatibility(bundle, req) + } +} diff --git a/go/state/project_seed_deep_bench_test.go b/go/state/project_seed_deep_bench_test.go new file mode 100644 index 0000000..fc7a250 --- /dev/null +++ b/go/state/project_seed_deep_bench_test.go @@ -0,0 +1,308 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Deeper benchmarks for the project-seed durable-checkpoint primitives. +// Per AX-11 — the existing project_seed_bench_test.go covers the main +// constructor + per-session paths. These benches drill into the +// CheckWakeCompatibility partial-mismatch matrix (one reason at a time +// matters, since the report carries them as a slice), the URI-join +// helper (joinURI is on the hot construction path), and the +// PlanContinuation sleep-request assembly that does the heaviest +// per-seed work. +// +// Run: go test -bench='BenchmarkProjectSeedDeep' -benchmem -run='^$' ./state + +package state + +import "testing" + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + psDeepSinkSeed ProjectSeed + psDeepSinkPlan ProjectSeedContinuationPlan + psDeepSinkReport WakeCompatibilityReport + psDeepSinkString string + psDeepSinkMap map[string]string +) + +// --- CheckWakeCompatibility partial-mismatch matrix --- +// One mismatch reason at a time exercises the comparator without other +// branches polluting the per-call cost. + +func BenchmarkProjectSeedDeep_CheckCompat_ModelHashMismatch(b *testing.B) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "model-a", Architecture: "gemma4", NumLayers: 28, QuantBits: 4, ContextLength: 4096}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a"}, + Runtime: RuntimeIdentity{Backend: "metal"}, + PromptTokens: 2048, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "model-X", Architecture: "gemma4", NumLayers: 28, QuantBits: 4, ContextLength: 8192}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a"}, + Runtime: RuntimeIdentity{Backend: "metal"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkProjectSeedDeep_CheckCompat_TokenizerMismatch(b *testing.B) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a", ChatTemplate: "chat-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a"}, + Runtime: RuntimeIdentity{Backend: "metal"}, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-b", ChatTemplate: "chat-b"}, + Adapter: AdapterIdentity{Hash: "adapter-a"}, + Runtime: RuntimeIdentity{Backend: "metal"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkProjectSeedDeep_CheckCompat_AdapterMissing(b *testing.B) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "metal"}, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{}, // missing — exercises the bundleActive && !reqActive branch + Runtime: RuntimeIdentity{Backend: "metal"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkProjectSeedDeep_CheckCompat_AdapterUnexpected(b *testing.B) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{}, + Runtime: RuntimeIdentity{Backend: "metal"}, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-x", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "metal"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkProjectSeedDeep_CheckCompat_AdapterRankMismatch(b *testing.B) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8, Path: "/a"}, + Runtime: RuntimeIdentity{Backend: "metal"}, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 16, Path: "/a"}, + Runtime: RuntimeIdentity{Backend: "metal"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkProjectSeedDeep_CheckCompat_RuntimeBackendChange(b *testing.B) { + // Runtime mismatches emit Warnings, not Reasons — the report stays + // Compatible:true but carries telemetry. + bundle := Bundle{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a"}, + Runtime: RuntimeIdentity{Backend: "metal", CacheMode: "paged"}, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "m"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a"}, + Runtime: RuntimeIdentity{Backend: "rocm", CacheMode: "paged-q4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +func BenchmarkProjectSeedDeep_CheckCompat_ContextTooSmall(b *testing.B) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "m", ContextLength: 4096}, + PromptTokens: 2048, + GeneratedTokens: 2048, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "m", ContextLength: 1024}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkReport = CheckWakeCompatibility(bundle, req) + } +} + +// --- PlanContinuation with custom URIs --- +// PlanContinuation defaults the entry/bundle/index URIs from the seed +// when not provided. These benches exercise the override branch where +// the consumer supplies explicit URIs. + +func BenchmarkProjectSeedDeep_PlanContinuation_CustomURIs(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedContinuationOptions{ + Mode: ProjectSeedStateCheckpoint, + Store: "store", + EntryURI: "state://override/entry", + BundleURI: "state://override/entry/bundle", + IndexURI: "state://override/entry/index", + Title: "override-title", + Model: ModelIdentity{ID: "gemma4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkPlan = seed.PlanContinuation(opts) + } +} + +func BenchmarkProjectSeedDeep_PlanContinuation_WithParent(b *testing.B) { + // Parent ref provided — the sleepRequest assembly walks + // firstNonEmpty for entry/bundle/index URIs. + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedContinuationOptions{ + Mode: ProjectSeedStateCheckpoint, + Store: "store", + Parent: WakeResult{ + Entry: Ref{ + URI: "state://parent/entry", + BundleURI: "state://parent/bundle", + IndexURI: "state://parent/index", + }, + }, + Model: ModelIdentity{ID: "gemma4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkPlan = seed.PlanContinuation(opts) + } +} + +// --- NewProjectSeed with mixed defaults --- +// One or two URIs supplied, rest defaulted. Exercises the per-field +// firstNonEmpty + joinURI fallback paths in the constructor. + +func BenchmarkProjectSeedDeep_NewProjectSeed_PartialURIs(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkSeed = NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + EntryURI: "state://override/entry", + // BundleURI + IndexURI left empty so the defaults run. + }) + } +} + +func BenchmarkProjectSeedDeep_NewProjectSeed_AllURIs(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + psDeepSinkSeed = NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + EntryURI: "state://lthn/projects/core/go-mlx/seed", + BundleURI: "state://lthn/projects/core/go-mlx/seed/bundle", + IndexURI: "state://lthn/projects/core/go-mlx/seed/index", + Title: "core/go-mlx seed", + }) + } +} + +// --- WakeRequest with mixed label shapes --- +// labels-only-from-seed vs labels-only-from-opts vs both — the +// merge path's allocator behaviour depends on the empty case. + +func BenchmarkProjectSeedDeep_WakeRequest_LabelsSeedOnly(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + Labels: labelsMap(8), + }) + opts := ProjectSeedWakeOptions{ + Store: "store", + Model: ModelIdentity{ID: "gemma4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + seed.WakeRequest(opts) + } +} + +func BenchmarkProjectSeedDeep_WakeRequest_LabelsOptsOnly(b *testing.B) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedWakeOptions{ + Store: "store", + Model: ModelIdentity{ID: "gemma4"}, + Labels: labelsMap(8), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + seed.WakeRequest(opts) + } +} + +func BenchmarkProjectSeedDeep_WakeRequest_NoLabels(b *testing.B) { + // Both sides empty — the merge helper takes the early-out path + // and returns nil without allocating. + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + }) + opts := ProjectSeedWakeOptions{ + Store: "store", + Model: ModelIdentity{ID: "gemma4"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + seed.WakeRequest(opts) + } +} diff --git a/go/state/project_seed_test.go b/go/state/project_seed_test.go new file mode 100644 index 0000000..14b74d4 --- /dev/null +++ b/go/state/project_seed_test.go @@ -0,0 +1,145 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package state + +import "testing" + +func TestProjectSeed_WakeRequest_Good(t *testing.T) { + seed := NewProjectSeed(ProjectSeedOptions{ + BaseURI: "state://lthn/projects", + ProjectID: "core/go-mlx", + Title: "go-mlx seed", + Labels: map[string]string{"scope": "repo"}, + Metadata: map[string]string{"operator": "snider"}, + }) + + wake := seed.WakeRequest(ProjectSeedWakeOptions{ + Store: "store", + Model: ModelIdentity{ID: "gemma4", Hash: "model-a"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a"}, + Runtime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"}, + }) + + if wake.Store != "store" || wake.EntryURI != "state://lthn/projects/core/go-mlx/seed" || wake.IndexURI != "state://lthn/projects/core/go-mlx/seed/index" { + t.Fatalf("wake request = %+v, want project seed URIs and store", wake) + } + if wake.Model.Hash != "model-a" || wake.Tokenizer.Hash != "tok-a" || wake.Adapter.Hash != "adapter-a" || wake.Runtime.Backend != "metal" { + t.Fatalf("wake identities = %+v/%+v/%+v/%+v", wake.Model, wake.Tokenizer, wake.Adapter, wake.Runtime) + } + if wake.Labels["project_id"] != "core/go-mlx" || wake.Labels["scope"] != "repo" { + t.Fatalf("wake labels = %+v, want project and caller labels", wake.Labels) + } + + seed.Labels["scope"] = "mutated" + if wake.Labels["scope"] != "repo" { + t.Fatalf("wake request labels aliased seed labels: %+v", wake.Labels) + } +} + +func TestProjectSeed_PlanContinuationModes_Good(t *testing.T) { + seed := NewProjectSeed(ProjectSeedOptions{BaseURI: "state://lthn/projects", ProjectID: "core/go-mlx"}) + parent := WakeResult{ + Entry: Ref{URI: seed.EntryURI, BundleURI: seed.BundleURI, IndexURI: seed.IndexURI}, + PrefixTokens: 42, + } + + statePlan := seed.PlanContinuation(ProjectSeedContinuationOptions{ + Mode: ProjectSeedStateCheckpoint, + Store: "store", + EntryURI: "state://lthn/projects/core/go-mlx/tasks/inspect", + Title: "inspect result", + Parent: parent, + Model: ModelIdentity{ID: "gemma4"}, + Tokenizer: TokenizerIdentity{Hash: "tok-a"}, + Metadata: map[string]string{"finding_count": "2"}, + }) + if !statePlan.PersistState || statePlan.NeedsSummary || statePlan.ReuseCurrentSeed { + t.Fatalf("state plan flags = %+v, want state checkpoint", statePlan) + } + if statePlan.Sleep.Store != "store" || !statePlan.Sleep.ReuseParentPrefix { + t.Fatalf("sleep request = %+v, want store and parent prefix reuse", statePlan.Sleep) + } + if statePlan.Sleep.ParentEntryURI != seed.EntryURI || statePlan.Sleep.ParentBundleURI != seed.BundleURI || statePlan.Sleep.ParentIndexURI != seed.IndexURI { + t.Fatalf("sleep parent = %+v, want seed parent refs", statePlan.Sleep) + } + if statePlan.Sleep.Metadata["project_id"] != "core/go-mlx" || statePlan.Sleep.Metadata["finding_count"] != "2" { + t.Fatalf("sleep metadata = %+v, want project and caller metadata", statePlan.Sleep.Metadata) + } + + summaryPlan := seed.PlanContinuation(ProjectSeedContinuationOptions{Mode: ProjectSeedSummaryWindow}) + if summaryPlan.PersistState || !summaryPlan.NeedsSummary || summaryPlan.Sleep.EntryURI != "" { + t.Fatalf("summary plan = %+v, want summary-only window", summaryPlan) + } + + reusePlan := seed.PlanContinuation(ProjectSeedContinuationOptions{Mode: ProjectSeedReuseCurrent}) + if reusePlan.PersistState || reusePlan.NeedsSummary || !reusePlan.ReuseCurrentSeed { + t.Fatalf("reuse plan = %+v, want current seed reuse", reusePlan) + } +} + +func TestWakeCompatibility_GoodBadUgly(t *testing.T) { + bundle := Bundle{ + Model: ModelIdentity{Hash: "model-a", Architecture: "gemma4_text", NumLayers: 28, QuantBits: 4, ContextLength: 4096}, + Tokenizer: TokenizerIdentity{Hash: "tok-a", ChatTemplate: "chat-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"}, + PromptTokens: 2048, + } + req := WakeRequest{ + Model: ModelIdentity{Hash: "model-a", Architecture: "gemma4_text", NumLayers: 28, QuantBits: 4, ContextLength: 8192}, + Tokenizer: TokenizerIdentity{Hash: "tok-a", ChatTemplate: "chat-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "rocm", CacheMode: "paged-q8"}, + } + + report := CheckWakeCompatibility(bundle, req) + if !report.Compatible || report.SummaryRequired || len(report.Reasons) != 0 { + t.Fatalf("compatible report = %+v, want wake-compatible", report) + } + if len(report.Warnings) == 0 || report.Warnings[0] != "runtime_backend_changed" { + t.Fatalf("warnings = %+v, want runtime backend warning", report.Warnings) + } + + req.Tokenizer.Hash = "tok-b" + req.Adapter = AdapterIdentity{} + req.Model.ContextLength = 1024 + report = CheckWakeCompatibility(bundle, req) + if report.Compatible || !report.SummaryRequired { + t.Fatalf("incompatible report = %+v, want summary fallback", report) + } + if !stringSliceContains(report.Reasons, "tokenizer_hash_mismatch") || !stringSliceContains(report.Reasons, "adapter_missing") || !stringSliceContains(report.Reasons, "context_length_too_small") { + t.Fatalf("reasons = %+v, want tokenizer, adapter, and context blockers", report.Reasons) + } + + req = WakeRequest{ + Model: ModelIdentity{Hash: "model-b", Architecture: "qwen3", NumLayers: 28, QuantBits: 8, ContextLength: 8192}, + Tokenizer: TokenizerIdentity{Hash: "tok-a", ChatTemplate: "chat-a"}, + Adapter: AdapterIdentity{Hash: "adapter-a", Rank: 8}, + Runtime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"}, + } + report = CheckWakeCompatibility(bundle, req) + if report.Compatible || !report.SummaryRequired { + t.Fatalf("model-incompatible report = %+v, want summary fallback", report) + } + for _, want := range []string{"model_hash_mismatch", "model_architecture_mismatch", "model_quantisation_mismatch"} { + if !stringSliceContains(report.Reasons, want) { + t.Fatalf("reasons = %+v, want %s", report.Reasons, want) + } + } + + req.SkipCompatibilityCheck = true + report = CheckWakeCompatibility(bundle, req) + if !report.Compatible || len(report.Warnings) == 0 || report.Warnings[0] != "compatibility_check_skipped" { + t.Fatalf("skip report = %+v, want forced compatibility warning", report) + } +} + +func stringSliceContains(values []string, want string) bool { + for _, value := range values { + if value == want { + return true + } + } + return false +} diff --git a/go/state/putoptions_bench_test.go b/go/state/putoptions_bench_test.go new file mode 100644 index 0000000..a070d4e --- /dev/null +++ b/go/state/putoptions_bench_test.go @@ -0,0 +1,236 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the PutOptions input shape across the Writer surface. +// Per AX-11 — PutOptions is the per-call envelope every Put/PutBytes +// hits. The Tags map is the dominant allocator under heavy metadata +// loads (memvid bundle saves carry 4-12 tags per chunk). The URI string +// length matters because the Memory backend mirrors URIs into a lookup +// table — long URIs compound into the uri map. +// +// Run: go test -bench='BenchmarkPutOptions' -benchmem -run='^$' ./state + +package state + +import ( + "context" + "testing" +) + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + putOptsSinkRef ChunkRef + putOptsSinkErr error +) + +// --- Tags map size sweep --- +// Memvid bundle saves typically carry 0-8 tags per record (kind, track, +// epoch, source-tool, env, etc.). The Put path doesn't clone the map +// today but the structural shape benches confirm the read cost. + +func BenchmarkPutOptions_NoTags(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{Kind: "bench"} + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +func BenchmarkPutOptions_Tags_1(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{ + Kind: "bench", + Tags: map[string]string{"epoch": "3"}, + } + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +func BenchmarkPutOptions_Tags_4(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{ + Kind: "bench", + Tags: map[string]string{ + "epoch": "3", + "track": "primary", + "source": "memvid", + "env": "bench", + }, + } + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +func BenchmarkPutOptions_Tags_8(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{ + Kind: "bench", + Tags: map[string]string{ + "epoch": "3", + "track": "primary", + "source": "memvid", + "env": "bench", + "branch": "dev", + "runner": "homelab", + "adapter": "lora-1", + "model": "qwen3", + }, + } + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +// --- Labels slice size --- +// Per Lethean convention, Labels is the unordered string-list of +// arbitrary classifiers (e.g. "kind:training", "source:hypnos"). The +// slice header is shared by reference but indexes any persistence +// hashing. + +func BenchmarkPutOptions_Labels_0(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{Kind: "bench"} + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +func BenchmarkPutOptions_Labels_4(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{ + Kind: "bench", + Labels: []string{"k0:v0", "k1:v1", "k2:v2", "k3:v3"}, + } + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +// --- URI variants --- +// Empty URI bypasses the uri[] index write. Typical URI is a normal +// state:// path. Very-long URI tests the map-write of a 256-char key +// (e.g. fully-qualified bundle URI with epoch+layer suffixes). + +func BenchmarkPutOptions_URI_Empty(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{Kind: "bench"} + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +func BenchmarkPutOptions_URI_Typical(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{ + Kind: "bench", + URI: "state://lthn/projects/core/go-mlx/seed/v1/bundle", + } + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +func BenchmarkPutOptions_URI_Long(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + // 256-char URI — realistic for a fully-qualified bundle/segment/epoch + // path that includes runtime + model identity in the leaf. + uri := "state://lthn/projects/core/go-mlx/snapshots/2026-05-22T12:00:00Z/" + + "runtime/metal/m3-ultra/model/qwen3-27b-4bit/adapter/lora-1/" + + "workload/long-context/segment/chunk-00000042/epoch-3/layer/all" + opts := PutOptions{Kind: "bench", URI: uri} + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} + +// --- HasFrameOffset variants --- +// PutBytes always sets HasFrameOffset on the returned ref. The shape +// is asserted at the ref layer below; this bench exercises the +// observable cost of constructing the ref with explicit defaults. + +func BenchmarkPutOptions_Construct_HasFrameOffset(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef = ChunkRef{ + ChunkID: i, + FrameOffset: uint64(i), + HasFrameOffset: true, + Codec: CodecMemory, + } + } +} + +func BenchmarkPutOptions_Construct_NoFrameOffset(b *testing.B) { + // Some adapters omit the frame offset (e.g. opaque-blob stores). + // Confirms the "small" ref shape costs the same to construct. + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef = ChunkRef{ + ChunkID: i, + Codec: CodecMemory, + } + } +} + +// --- Title / Track / Kind string variants --- +// Same shape but with all metadata strings populated — the per-call +// cost should be ~constant since the map writes dominate, but the +// bench tracks regressions in the metadata-rich path. + +func BenchmarkPutOptions_FullMetadata(b *testing.B) { + store := NewInMemoryStore(nil) + ctx := context.Background() + opts := PutOptions{ + URI: "state://bench/full", + Title: "bench-chunk-with-long-title-for-realistic-meta", + Kind: "training-checkpoint", + Track: "primary-train", + Tags: map[string]string{"epoch": "3", "branch": "dev"}, + Labels: []string{"kind:training", "source:hypnos"}, + } + data := make([]byte, 64) + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + putOptsSinkRef, putOptsSinkErr = store.PutBytes(ctx, data, opts) + } +} diff --git a/go/state/state_test.go b/go/state/state_test.go new file mode 100644 index 0000000..4b3e76b --- /dev/null +++ b/go/state/state_test.go @@ -0,0 +1,146 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package state + +import ( + "context" + "testing" + + core "dappco.re/go" +) + +func TestState_InMemoryStore_Good(t *testing.T) { + store := NewInMemoryStore(map[int]string{7: "chunk seven"}) + + text, err := store.Get(context.Background(), 7) + if err != nil { + t.Fatalf("Get() error = %v", err) + } + if text != "chunk seven" { + t.Fatalf("Get() = %q, want chunk seven", text) + } + chunk, err := Resolve(context.Background(), store, 7) + if err != nil { + t.Fatalf("Resolve() error = %v", err) + } + if chunk.Ref.ChunkID != 7 || !chunk.Ref.HasFrameOffset || chunk.Ref.FrameOffset != 7 || chunk.Ref.Codec != CodecMemory { + t.Fatalf("chunk ref = %#v", chunk.Ref) + } +} + +func TestState_InMemoryStore_Bad(t *testing.T) { + store := NewInMemoryStore(nil) + + _, err := store.Get(context.Background(), 42) + + if !core.Is(err, ErrChunkNotFound) { + t.Fatalf("missing chunk error = %v, want ErrChunkNotFound", err) + } +} + +func TestState_BinaryStore_Good(t *testing.T) { + store := NewInMemoryStore(nil) + payload := []byte{0, 1, 2, 255} + + ref, err := store.PutBytes(context.Background(), payload, PutOptions{URI: "state://binary/1"}) + if err != nil { + t.Fatalf("PutBytes() error = %v", err) + } + payload[1] = 99 + + chunk, err := ResolveBytes(context.Background(), store, ref.ChunkID) + if err != nil { + t.Fatalf("ResolveBytes() error = %v", err) + } + if chunk.Ref.ChunkID != ref.ChunkID || len(chunk.Data) != 4 || chunk.Data[1] != 1 || chunk.Data[3] != 255 { + t.Fatalf("ResolveBytes() chunk = %+v, want copied binary payload", chunk) + } + chunk.Data[2] = 88 + again, err := ResolveBytes(context.Background(), store, ref.ChunkID) + if err != nil { + t.Fatalf("ResolveBytes(second) error = %v", err) + } + if again.Data[2] != 2 { + t.Fatalf("ResolveBytes() returned aliased data = %v", again.Data) + } + byURI, err := ResolveURI(context.Background(), store, "state://binary/1") + if err != nil { + t.Fatalf("ResolveURI(binary) error = %v", err) + } + if len(byURI.Data) != 4 || byURI.Data[0] != 0 { + t.Fatalf("ResolveURI(binary) chunk = %+v, want binary data", byURI) + } +} + +func TestState_BorrowRefBytesFallback_Good(t *testing.T) { + store := NewInMemoryStore(nil) + payload := []byte{4, 3, 2, 1} + ref, err := store.PutBytes(context.Background(), payload, PutOptions{}) + if err != nil { + t.Fatalf("PutBytes() error = %v", err) + } + + borrowed, err := BorrowRefBytes(context.Background(), store, ref) + if err != nil { + t.Fatalf("BorrowRefBytes() error = %v", err) + } + if borrowed.Ref.ChunkID != ref.ChunkID || len(borrowed.Data) != len(payload) || borrowed.Data[0] != 4 { + t.Fatalf("BorrowRefBytes() = %+v, want copied payload", borrowed) + } + if borrowed.Release != nil { + borrowed.Release() + } +} + +func TestState_BorrowRefBytes_Bad(t *testing.T) { + _, err := BorrowRefBytes(context.Background(), nil, ChunkRef{ChunkID: 42}) + + if !core.Is(err, ErrChunkNotFound) { + t.Fatalf("BorrowRefBytes(nil) error = %v, want ErrChunkNotFound", err) + } +} + +func TestState_WakeSleepForkContracts_Good(t *testing.T) { + model := fakeForker{} + + session, wake, err := model.ForkState(context.Background(), WakeRequest{ + Store: NewInMemoryStore(nil), + IndexURI: "state://index", + Model: ModelIdentity{ID: "tiny"}, + }) + + if err != nil { + t.Fatalf("ForkState() error = %v", err) + } + if session == nil || wake == nil || wake.Entry.URI != "state://index/entry" { + t.Fatalf("ForkState() = %#v, %#v; want session and wake report", session, wake) + } + sleep, err := session.SleepState(context.Background(), SleepRequest{EntryURI: "state://entry"}) + if err != nil { + t.Fatalf("SleepState() error = %v", err) + } + if sleep.Entry.URI != "state://entry" || sleep.TokenCount != 12 { + t.Fatalf("SleepState() = %#v, want entry token count", sleep) + } +} + +type fakeForker struct{} + +func (fakeForker) ForkState(_ context.Context, req WakeRequest) (Session, *WakeResult, error) { + session := fakeSession{} + return session, &WakeResult{ + Entry: Ref{URI: req.IndexURI + "/entry"}, + PrefixTokens: 12, + Labels: map[string]string{"backend": "fake"}, + }, nil +} + +type fakeSession struct{} + +func (fakeSession) WakeState(_ context.Context, req WakeRequest) (*WakeResult, error) { + return &WakeResult{Entry: Ref{URI: req.EntryURI}, PrefixTokens: 12}, nil +} + +func (fakeSession) SleepState(_ context.Context, req SleepRequest) (*SleepResult, error) { + return &SleepResult{Entry: Ref{URI: req.EntryURI}, TokenCount: 12}, nil +} diff --git a/go/state/store.go b/go/state/store.go new file mode 100644 index 0000000..a3b5779 --- /dev/null +++ b/go/state/store.go @@ -0,0 +1,259 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Package state defines portable model-state storage and lifecycle contracts. +package state + +import ( + "context" + stdio "io" + + core "dappco.re/go" +) + +var ErrChunkNotFound = core.NewError("state chunk not found") + +const ( + CodecMemory = "memory/plaintext" + CodecStateVideo = "state/qr-video" + CodecQRVideo = CodecStateVideo + // Deprecated: use CodecStateVideo. + CodecMemvidQRVideo = "memvid/qr-video" +) + +type Store interface { + Get(ctx context.Context, chunkID int) (string, error) +} + +type Resolver interface { + Resolve(ctx context.Context, chunkID int) (Chunk, error) +} + +type URIResolver interface { + ResolveURI(ctx context.Context, uri string) (Chunk, error) +} + +type Writer interface { + Put(ctx context.Context, text string, opts PutOptions) (ChunkRef, error) +} + +type BinaryResolver interface { + ResolveBytes(ctx context.Context, chunkID int) (Chunk, error) +} + +type RefBinaryResolver interface { + ResolveRefBytes(ctx context.Context, ref ChunkRef) (Chunk, error) +} + +// BorrowedChunk is a byte view borrowed from a store. Release is optional and +// may be nil when the view is store-lifetime bound; callers must keep the +// backing store open while retaining Data. +type BorrowedChunk struct { + Ref ChunkRef + Data []byte + Release func() +} + +// BinaryBorrower returns a borrowed byte view for a chunk ID. +type BinaryBorrower interface { + BorrowBytes(ctx context.Context, chunkID int) (BorrowedChunk, error) +} + +// RefBinaryBorrower returns a borrowed byte view for a full chunk ref. +type RefBinaryBorrower interface { + BorrowRefBytes(ctx context.Context, ref ChunkRef) (BorrowedChunk, error) +} + +type BinaryWriter interface { + PutBytes(ctx context.Context, data []byte, opts PutOptions) (ChunkRef, error) +} + +type BinaryStreamWriter interface { + PutBytesStream(ctx context.Context, payloadSize int, opts PutOptions, write func(stdio.Writer) error) (ChunkRef, error) +} + +type PutOptions struct { + URI string `json:"uri,omitempty"` + Title string `json:"title,omitempty"` + Kind string `json:"kind,omitempty"` + Track string `json:"track,omitempty"` + Tags map[string]string `json:"tags,omitempty"` + Labels []string `json:"labels,omitempty"` +} + +type Chunk struct { + Ref ChunkRef `json:"ref"` + Text string `json:"text"` + Data []byte `json:"data,omitempty"` +} + +type ChunkRef struct { + ChunkID int `json:"chunk_id"` + FrameOffset uint64 `json:"frame_offset,omitempty"` + HasFrameOffset bool `json:"has_frame_offset,omitempty"` + Codec string `json:"codec,omitempty"` + Segment string `json:"segment,omitempty"` +} + +type ChunkNotFoundError struct { + ID int +} + +func (e *ChunkNotFoundError) Error() string { + return core.Sprintf("state chunk %d not found", e.ID) +} + +func (e *ChunkNotFoundError) Unwrap() error { + return ErrChunkNotFound +} + +type URIChunkNotFoundError struct { + URI string +} + +func (e *URIChunkNotFoundError) Error() string { + if e.URI == "" { + return "state chunk URI not found" + } + return core.Sprintf("state chunk URI %q not found", e.URI) +} + +func (e *URIChunkNotFoundError) Unwrap() error { + return ErrChunkNotFound +} + +func Resolve(ctx context.Context, store Store, chunkID int) (Chunk, error) { + if ctx == nil { + ctx = context.Background() + } + if store == nil { + return Chunk{}, &ChunkNotFoundError{ID: chunkID} + } + if resolver, ok := store.(Resolver); ok { + return resolver.Resolve(ctx, chunkID) + } + text, err := store.Get(ctx, chunkID) + if err != nil { + return Chunk{}, err + } + return Chunk{ + Ref: ChunkRef{ChunkID: chunkID}, + Text: text, + }, nil +} + +func ResolveBytes(ctx context.Context, store Store, chunkID int) (Chunk, error) { + if ctx == nil { + ctx = context.Background() + } + if store == nil { + return Chunk{}, &ChunkNotFoundError{ID: chunkID} + } + if resolver, ok := store.(BinaryResolver); ok { + chunk, err := resolver.ResolveBytes(ctx, chunkID) + if err != nil { + return Chunk{}, err + } + if len(chunk.Data) == 0 && chunk.Text != "" { + chunk.Data = []byte(chunk.Text) + } + return chunk, nil + } + chunk, err := Resolve(ctx, store, chunkID) + if err != nil { + return Chunk{}, err + } + if len(chunk.Data) == 0 && chunk.Text != "" { + chunk.Data = []byte(chunk.Text) + } + return chunk, nil +} + +func ResolveRefBytes(ctx context.Context, store Store, ref ChunkRef) (Chunk, error) { + if ctx == nil { + ctx = context.Background() + } + if store == nil { + return Chunk{}, &ChunkNotFoundError{ID: ref.ChunkID} + } + if resolver, ok := store.(RefBinaryResolver); ok { + chunk, err := resolver.ResolveRefBytes(ctx, ref) + if err != nil { + return Chunk{}, err + } + if len(chunk.Data) == 0 && chunk.Text != "" { + chunk.Data = []byte(chunk.Text) + } + return chunk, nil + } + if ref.ChunkID == 0 { + return Chunk{}, &ChunkNotFoundError{ID: ref.ChunkID} + } + return ResolveBytes(ctx, store, ref.ChunkID) +} + +// BorrowBytes resolves a byte chunk and prefers store-native borrowed storage. +func BorrowBytes(ctx context.Context, store Store, chunkID int) (BorrowedChunk, error) { + if ctx == nil { + ctx = context.Background() + } + if store == nil { + return BorrowedChunk{}, &ChunkNotFoundError{ID: chunkID} + } + if borrower, ok := store.(BinaryBorrower); ok { + return borrower.BorrowBytes(ctx, chunkID) + } + chunk, err := ResolveBytes(ctx, store, chunkID) + if err != nil { + return BorrowedChunk{}, err + } + return BorrowedChunk{Ref: chunk.Ref, Data: chunk.Data}, nil +} + +// BorrowRefBytes resolves a byte chunk ref and prefers store-native borrowed +// storage. +func BorrowRefBytes(ctx context.Context, store Store, ref ChunkRef) (BorrowedChunk, error) { + if ctx == nil { + ctx = context.Background() + } + if store == nil { + return BorrowedChunk{}, &ChunkNotFoundError{ID: ref.ChunkID} + } + if borrower, ok := store.(RefBinaryBorrower); ok { + return borrower.BorrowRefBytes(ctx, ref) + } + if ref.ChunkID == 0 { + return BorrowedChunk{}, &ChunkNotFoundError{ID: ref.ChunkID} + } + return BorrowBytes(ctx, store, ref.ChunkID) +} + +func ResolveURI(ctx context.Context, store Store, uri string) (Chunk, error) { + if ctx == nil { + ctx = context.Background() + } + if store == nil || core.Trim(uri) == "" { + return Chunk{}, &URIChunkNotFoundError{URI: uri} + } + if resolver, ok := store.(URIResolver); ok { + return resolver.ResolveURI(ctx, uri) + } + return Chunk{}, &URIChunkNotFoundError{URI: uri} +} + +func MergeRef(base, overlay ChunkRef) ChunkRef { + out := base + if overlay.ChunkID != 0 || base.ChunkID == 0 { + out.ChunkID = overlay.ChunkID + } + if overlay.HasFrameOffset { + out.FrameOffset = overlay.FrameOffset + out.HasFrameOffset = true + } + if overlay.Codec != "" { + out.Codec = overlay.Codec + } + if overlay.Segment != "" { + out.Segment = overlay.Segment + } + return out +} diff --git a/go/state/store_bench_test.go b/go/state/store_bench_test.go new file mode 100644 index 0000000..e4e621c --- /dev/null +++ b/go/state/store_bench_test.go @@ -0,0 +1,257 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the top-level store dispatchers. +// Per AX-11 — Resolve / ResolveBytes / ResolveRefBytes / ResolveURI +// are the front-door API every consumer hits. They route to either +// the Store's native impl (filestore / memvid) or fall back to the +// minimal Store.Get adapter; both paths matter. MergeRef + the error +// formatters fire per chunk on the read-side hot loop. +// +// Run: go test -bench='Benchmark' -benchmem -run='^$' ./state + +package state + +import ( + "context" + "testing" +) + +// Sinks defeat compiler DCE. Distinct names per state-package bench file. +var ( + storeSinkChunk Chunk + storeSinkRef ChunkRef + storeSinkErr error + storeSinkErrText string + storeSinkChunkRef ChunkRef +) + +// --- Resolve (top-level dispatcher) --- +// Routes through the Resolver interface when available — InMemoryStore +// implements it, so this path is the "native dispatcher" cost. + +func BenchmarkStore_Resolve_Native_1KB(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = Resolve(ctx, store, 1) + } +} + +// Adapter store implements only the bare Store.Get — exercises the +// fallback branch in Resolve that wraps Get into a Chunk. + +func BenchmarkStore_Resolve_GetAdapter_1KB(b *testing.B) { + store := &benchGetOnlyStore{text: string(make([]byte, 1024))} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = Resolve(ctx, store, 1) + } +} + +func BenchmarkStore_Resolve_NilStore(b *testing.B) { + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = Resolve(ctx, nil, 1) + } +} + +// --- ResolveBytes (binary dispatcher) --- + +func BenchmarkStore_ResolveBytes_Native_1KB(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = ResolveBytes(ctx, store, 1) + } +} + +func BenchmarkStore_ResolveBytes_Native_64KB(b *testing.B) { + store := benchMemoryStore(b, 1, 64*1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = ResolveBytes(ctx, store, 1) + } +} + +// GetAdapter path — Store has no BinaryResolver, so ResolveBytes +// falls back through Resolve and copies Text → Data. + +func BenchmarkStore_ResolveBytes_GetAdapter_1KB(b *testing.B) { + store := &benchGetOnlyStore{text: string(make([]byte, 1024))} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = ResolveBytes(ctx, store, 1) + } +} + +// --- ResolveRefBytes (ChunkRef-with-frame-offset dispatcher) --- + +func BenchmarkStore_ResolveRefBytes_Native_1KB(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + ref := ChunkRef{ChunkID: 1, FrameOffset: 1, HasFrameOffset: true, Codec: CodecMemory} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = ResolveRefBytes(ctx, store, ref) + } +} + +// Without RefBinaryResolver — falls back through ResolveBytes by ID. + +func BenchmarkStore_ResolveRefBytes_GetAdapter_1KB(b *testing.B) { + store := &benchGetOnlyStore{text: string(make([]byte, 1024))} + ctx := context.Background() + ref := ChunkRef{ChunkID: 1, FrameOffset: 1, HasFrameOffset: true} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = ResolveRefBytes(ctx, store, ref) + } +} + +// --- ResolveURI (top-level URI dispatcher) --- + +func BenchmarkStore_ResolveURI_Native(b *testing.B) { + store := benchMemoryStore(b, 10, 1024) + ctx := context.Background() + uri := "state://bench/chunk-1" + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = ResolveURI(ctx, store, uri) + } +} + +func BenchmarkStore_ResolveURI_Empty(b *testing.B) { + store := benchMemoryStore(b, 1, 1024) + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = ResolveURI(ctx, store, "") + } +} + +func BenchmarkStore_ResolveURI_NoResolver(b *testing.B) { + // benchGetOnlyStore doesn't implement URIResolver — exercises + // the not-implemented branch that returns URIChunkNotFoundError. + store := &benchGetOnlyStore{text: "x"} + ctx := context.Background() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunk, storeSinkErr = ResolveURI(ctx, store, "state://bench/missing") + } +} + +// --- MergeRef (per-chunk overlay merge) --- +// Fires whenever a fork or restore needs to overlay a manifest ref +// onto a base ref (segment changes between bundle versions). + +func BenchmarkStore_MergeRef_OverlayAll(b *testing.B) { + base := ChunkRef{ChunkID: 7, FrameOffset: 7, HasFrameOffset: true, Codec: CodecMemory} + overlay := ChunkRef{ + ChunkID: 7, + FrameOffset: 42, + HasFrameOffset: true, + Codec: CodecStateVideo, + Segment: "epoch-3", + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunkRef = MergeRef(base, overlay) + } +} + +func BenchmarkStore_MergeRef_OverlayPartial(b *testing.B) { + base := ChunkRef{ChunkID: 7, FrameOffset: 7, HasFrameOffset: true, Codec: CodecMemory} + overlay := ChunkRef{Codec: CodecStateVideo} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunkRef = MergeRef(base, overlay) + } +} + +func BenchmarkStore_MergeRef_OverlayEmpty(b *testing.B) { + base := ChunkRef{ChunkID: 7, FrameOffset: 7, HasFrameOffset: true, Codec: CodecMemory} + overlay := ChunkRef{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkChunkRef = MergeRef(base, overlay) + } +} + +// --- ChunkNotFoundError / URIChunkNotFoundError formatters --- +// Fire on every miss; the format path crosses through core.Sprintf. + +func BenchmarkStore_ChunkNotFoundError_Error(b *testing.B) { + err := &ChunkNotFoundError{ID: 42} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkErrText = err.Error() + } +} + +func BenchmarkStore_URIChunkNotFoundError_Error(b *testing.B) { + err := &URIChunkNotFoundError{URI: "state://bench/missing"} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkErrText = err.Error() + } +} + +func BenchmarkStore_URIChunkNotFoundError_ErrorEmpty(b *testing.B) { + err := &URIChunkNotFoundError{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkErrText = err.Error() + } +} + +// --- ChunkRef value construction (the ID-only-shape) --- + +func BenchmarkStore_ChunkRef_Construct(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + storeSinkRef = ChunkRef{ + ChunkID: 7, + FrameOffset: 42, + HasFrameOffset: true, + Codec: CodecStateVideo, + Segment: "epoch-3", + } + } +} + +// --- Bench helpers --- + +// benchGetOnlyStore implements just the bare Store.Get contract so +// the bench can exercise the fallback dispatch path in Resolve / +// ResolveBytes / ResolveRefBytes when a backend only ships text reads. +type benchGetOnlyStore struct { + text string +} + +func (s *benchGetOnlyStore) Get(_ context.Context, _ int) (string, error) { + return s.text, nil +} diff --git a/go/training_bench_test.go b/go/training_bench_test.go new file mode 100644 index 0000000..401a066 --- /dev/null +++ b/go/training_bench_test.go @@ -0,0 +1,177 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the training contract shapes — DefaultLoRAConfig +// constructor + TrainingConfig / TrainingResult / DistillConfig / GRPOConfig +// JSON marshal. Per AX-11 — TrainingResult is the canonical wire format +// every trainer emits on every checkpoint; the per-step Metrics record is +// the tightest serialise loop. DefaultLoRAConfig fires once per training +// run but is exercised heavily in tests + tooling. +// +// Run: go test -bench='BenchmarkTraining' -benchmem -run='^$' . + +package inference + +import ( + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. Distinct names from the other bench files. +var ( + trainingBenchSinkConfig LoRAConfig + trainingBenchSinkString string +) + +// --- DefaultLoRAConfig (constructor allocation cost) --- + +func BenchmarkTraining_DefaultLoRAConfig(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + trainingBenchSinkConfig = DefaultLoRAConfig() + } +} + +// --- TrainingConfig marshal (per-run checkpoint envelope) --- + +func BenchmarkTraining_TrainingConfig_Marshal(b *testing.B) { + cfg := TrainingConfig{ + Epochs: 3, + BatchSize: 4, + GradientAccumulation: 8, + LearningRate: 1e-4, + LoRA: LoRAConfig{ + Rank: 16, + Alpha: 32, + TargetKeys: []string{"q_proj", "k_proj", "v_proj", "o_proj"}, + BFloat16: true, + }, + Labels: map[string]string{"run": "nightly", "dataset": "lthn-corpus"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + trainingBenchSinkString = core.JSONMarshalString(cfg) + } +} + +// --- TrainingMetrics marshal (per-step record — tightest loop) --- + +func BenchmarkTraining_TrainingMetrics_Marshal(b *testing.B) { + metrics := TrainingMetrics{ + Epoch: 2, + Step: 512, + Samples: 16384, + Tokens: 2097152, + Loss: 1.234, + LearningRate: 5e-5, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + trainingBenchSinkString = core.JSONMarshalString(metrics) + } +} + +// --- TrainingResult marshal (per-checkpoint envelope) --- + +func BenchmarkTraining_TrainingResult_Marshal(b *testing.B) { + result := TrainingResult{ + Model: ModelIdentity{ + Path: "/models/qwen3-4b", + Architecture: "qwen3", + QuantBits: 4, + }, + Adapter: AdapterIdentity{ + Path: "/adapters/run-2026-05-21/epoch-2", + Format: "safetensors", + Rank: 16, + Alpha: 32, + }, + Metrics: TrainingMetrics{ + Epoch: 2, + Step: 512, + Samples: 16384, + Tokens: 2097152, + Loss: 1.234, + LearningRate: 5e-5, + }, + Checkpoints: []StateRef{ + {Kind: "checkpoint", URI: "file:///tmp/step-256", SizeBytes: 1 << 20}, + {Kind: "checkpoint", URI: "file:///tmp/step-512", SizeBytes: 1 << 20}, + }, + Labels: map[string]string{"run": "nightly"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + trainingBenchSinkString = core.JSONMarshalString(result) + } +} + +// --- DistillConfig marshal (teacher/student wire envelope) --- + +func BenchmarkTraining_DistillConfig_Marshal(b *testing.B) { + cfg := DistillConfig{ + TrainingConfig: TrainingConfig{ + Epochs: 2, + BatchSize: 8, + GradientAccumulation: 4, + LearningRate: 2e-4, + LoRA: LoRAConfig{ + Rank: 8, + Alpha: 16, + TargetKeys: []string{"q_proj", "v_proj"}, + }, + }, + Temperature: 2.0, + Alpha: 0.7, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + trainingBenchSinkString = core.JSONMarshalString(cfg) + } +} + +// --- GRPOConfig marshal (reasoning policy optimisation envelope) --- + +func BenchmarkTraining_GRPOConfig_Marshal(b *testing.B) { + cfg := GRPOConfig{ + TrainingConfig: TrainingConfig{ + Epochs: 1, + BatchSize: 2, + LearningRate: 5e-6, + LoRA: LoRAConfig{ + Rank: 16, + Alpha: 32, + TargetKeys: []string{"q_proj", "k_proj", "v_proj", "o_proj"}, + BFloat16: true, + }, + }, + GroupSize: 8, + KLWeight: 0.04, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + trainingBenchSinkString = core.JSONMarshalString(cfg) + } +} + +// --- LoRAConfig marshal (per-adapter sidecar) --- + +func BenchmarkTraining_LoRAConfig_Marshal(b *testing.B) { + cfg := LoRAConfig{ + Rank: 64, + Alpha: 128, + TargetKeys: []string{"q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"}, + BFloat16: true, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + trainingBenchSinkString = core.JSONMarshalString(cfg) + } +} diff --git a/go/tuning.go b/go/tuning.go new file mode 100644 index 0000000..9984175 --- /dev/null +++ b/go/tuning.go @@ -0,0 +1,390 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "context" + "strconv" + + core "dappco.re/go" +) + +// TuningWorkload identifies the user-facing job a local model profile is +// being optimised for. The values are stable so UIs can persist profiles. +type TuningWorkload string + +const ( + TuningWorkloadChat TuningWorkload = "chat" + TuningWorkloadCoding TuningWorkload = "coding" + TuningWorkloadLongContext TuningWorkload = "long_context" + TuningWorkloadAgentState TuningWorkload = "agent_state" + TuningWorkloadThroughput TuningWorkload = "throughput" + TuningWorkloadLowLatency TuningWorkload = "low_latency" +) + +var defaultTuningWorkloads = []TuningWorkload{ + TuningWorkloadChat, + TuningWorkloadCoding, + TuningWorkloadLongContext, + TuningWorkloadAgentState, + TuningWorkloadThroughput, + TuningWorkloadLowLatency, +} + +// DefaultTuningWorkloads returns the standard set shown by local tuning UIs. +func DefaultTuningWorkloads() []TuningWorkload { + return append([]TuningWorkload(nil), defaultTuningWorkloads...) +} + +// MachineDiscoverer is implemented by runtimes that can report local hardware, +// supported settings, and optionally discovered model packs without loading +// weights. +type MachineDiscoverer interface { + DiscoverMachine(context.Context, MachineDiscoveryRequest) (*MachineDiscoveryReport, error) +} + +// TuningPlanner is implemented by runtimes that can propose candidate load +// settings for a model/workload pair. +type TuningPlanner interface { + PlanTuning(context.Context, TuningPlanRequest) (*TuningPlan, error) +} + +// MachineDeviceInfo records the backend-neutral hardware facts a driver can +// expose before any model is loaded. +type MachineDeviceInfo struct { + Name string `json:"name,omitempty"` + Architecture string `json:"architecture,omitempty"` + MaxBufferLength uint64 `json:"max_buffer_length,omitempty"` + MaxRecommendedWorkingSetSize uint64 `json:"max_recommended_working_set_size,omitempty"` + MemorySize uint64 `json:"memory_size,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// MachineDiscoveryRequest controls cheap local discovery. Drivers should keep +// this metadata-first and avoid loading weights. +type MachineDiscoveryRequest struct { + ModelDirs []string `json:"model_dirs,omitempty"` + Workloads []TuningWorkload `json:"workloads,omitempty"` + MaxModels int `json:"max_models,omitempty"` + IncludeModels bool `json:"include_models,omitempty"` + IncludeCandidates bool `json:"include_candidates,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// MachineDiscoveryReport is the UI-facing summary of a local backend plus any +// models and candidate settings discovered cheaply. +type MachineDiscoveryReport struct { + Runtime RuntimeIdentity `json:"runtime,omitempty"` + Device MachineDeviceInfo `json:"device,omitempty"` + Available bool `json:"available"` + Capabilities []Capability `json:"capabilities,omitempty"` + CacheModes []string `json:"cache_modes,omitempty"` + Models []DiscoveredModel `json:"models,omitempty"` + Workloads []TuningWorkload `json:"workloads,omitempty"` + Candidates []TuningCandidate `json:"candidates,omitempty"` + Warnings []string `json:"warnings,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TuningBudget bounds optional autotuning work. Zero values mean the driver +// picks a short smoke-test default. +type TuningBudget struct { + MaxCandidates int `json:"max_candidates,omitempty"` + SmokeTokens int `json:"smoke_tokens,omitempty"` + Runs int `json:"runs,omitempty"` + AllowStateBench bool `json:"allow_state_bench,omitempty"` + AllowModelReloads bool `json:"allow_model_reloads,omitempty"` +} + +// TuningPlanRequest asks a backend to turn known hardware/model facts into +// candidate settings. It is intentionally metadata-only. +type TuningPlanRequest struct { + Runtime RuntimeIdentity `json:"runtime,omitempty"` + Device MachineDeviceInfo `json:"device,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Workloads []TuningWorkload `json:"workloads,omitempty"` + Budget TuningBudget `json:"budget,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TuningCandidate is one concrete model-load shape the UI can try or persist. +type TuningCandidate struct { + ID string `json:"id,omitempty"` + Workload TuningWorkload `json:"workload,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Runtime RuntimeIdentity `json:"runtime,omitempty"` + ContextLength int `json:"context_length,omitempty"` + ParallelSlots int `json:"parallel_slots,omitempty"` + PromptCache bool `json:"prompt_cache,omitempty"` + PromptCacheMinTokens int `json:"prompt_cache_min_tokens,omitempty"` + CachePolicy string `json:"cache_policy,omitempty"` + CacheMode string `json:"cache_mode,omitempty"` + BatchSize int `json:"batch_size,omitempty"` + PrefillChunkSize int `json:"prefill_chunk_size,omitempty"` + ExpectedQuantization int `json:"expected_quantization,omitempty"` + MemoryLimitBytes uint64 `json:"memory_limit_bytes,omitempty"` + CacheLimitBytes uint64 `json:"cache_limit_bytes,omitempty"` + WiredLimitBytes uint64 `json:"wired_limit_bytes,omitempty"` + Reasons []string `json:"reasons,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TuningPlan is a compact set of candidates and per-workload recommendations. +type TuningPlan struct { + Runtime RuntimeIdentity `json:"runtime,omitempty"` + Device MachineDeviceInfo `json:"device,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Workloads []TuningWorkload `json:"workloads,omitempty"` + Candidates []TuningCandidate `json:"candidates,omitempty"` + Recommended map[TuningWorkload]string `json:"recommended,omitempty"` + Warnings []string `json:"warnings,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TuningMeasurements is the driver-neutral subset of a bench result used for +// scoring and persisted profiles. +type TuningMeasurements struct { + PromptTokens int `json:"prompt_tokens,omitempty"` + GeneratedTokens int `json:"generated_tokens,omitempty"` + LoadMilliseconds float64 `json:"load_milliseconds,omitempty"` + FirstTokenMilliseconds float64 `json:"first_token_milliseconds,omitempty"` + PrefillTokensPerSec float64 `json:"prefill_tokens_per_sec,omitempty"` + DecodeTokensPerSec float64 `json:"decode_tokens_per_sec,omitempty"` + PromptCacheHitRate float64 `json:"prompt_cache_hit_rate,omitempty"` + KVRestoreMilliseconds float64 `json:"kv_restore_milliseconds,omitempty"` + StateBundleMilliseconds float64 `json:"state_bundle_milliseconds,omitempty"` + TotalMilliseconds float64 `json:"total_milliseconds,omitempty"` + PeakMemoryBytes uint64 `json:"peak_memory_bytes,omitempty"` + ActiveMemoryBytes uint64 `json:"active_memory_bytes,omitempty"` + CorrectnessSmokeResult string `json:"correctness_smoke_result,omitempty"` + CorrectnessSmokeChecks int `json:"correctness_smoke_checks,omitempty"` +} + +// TuningScore records a comparable score plus the raw metrics that drove it. +type TuningScore struct { + Workload TuningWorkload `json:"workload,omitempty"` + Score float64 `json:"score,omitempty"` + FirstTokenMilliseconds float64 `json:"first_token_milliseconds,omitempty"` + PrefillTokensPerSec float64 `json:"prefill_tokens_per_sec,omitempty"` + DecodeTokensPerSec float64 `json:"decode_tokens_per_sec,omitempty"` + PromptCacheHitRate float64 `json:"prompt_cache_hit_rate,omitempty"` + KVRestoreMilliseconds float64 `json:"kv_restore_milliseconds,omitempty"` + PeakMemoryBytes uint64 `json:"peak_memory_bytes,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TuningResult is emitted after each candidate finishes or fails. +type TuningResult struct { + Candidate TuningCandidate `json:"candidate,omitempty"` + Measurements TuningMeasurements `json:"measurements,omitempty"` + Score TuningScore `json:"score,omitempty"` + Error string `json:"error,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TuningEventKind names the streamed lifecycle events an autotune runner emits. +type TuningEventKind string + +const ( + TuningEventCandidate TuningEventKind = "candidate" + TuningEventResult TuningEventKind = "result" + TuningEventSelected TuningEventKind = "selected" +) + +// TuningEvent lets UIs update as each candidate starts and finishes. +type TuningEvent struct { + Kind TuningEventKind `json:"kind"` + Candidate TuningCandidate `json:"candidate,omitempty"` + Result *TuningResult `json:"result,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// TuningProfileKey identifies a persisted winner for one machine/model/workload. +type TuningProfileKey struct { + MachineHash string `json:"machine_hash,omitempty"` + Runtime RuntimeIdentity `json:"runtime,omitempty"` + Model ModelIdentity `json:"model,omitempty"` + Adapter AdapterIdentity `json:"adapter,omitempty"` + Workload TuningWorkload `json:"workload,omitempty"` +} + +// TuningProfile stores a proven candidate for later fast reloads. +type TuningProfile struct { + Key TuningProfileKey `json:"key,omitempty"` + Candidate TuningCandidate `json:"candidate,omitempty"` + Measurements TuningMeasurements `json:"measurements,omitempty"` + Score TuningScore `json:"score,omitempty"` + CreatedAtUnix int64 `json:"created_at_unix,omitempty"` + Labels map[string]string `json:"labels,omitempty"` +} + +// ScoreTuningMeasurements turns measured smoke-test counters into a simple +// workload-aware score. It deliberately stays transparent rather than claiming +// a universal benchmark. +func ScoreTuningMeasurements(workload TuningWorkload, m TuningMeasurements) TuningScore { + // Labels map is lazy: most workloads emit zero label entries (Chat, + // Throughput, Default — and LongContext/AgentState/LowLatency when + // the optional measurements are missing). Eager-init then nil-out + // pays an empty-map alloc per call (~48 B/op) which escapes to heap + // because TuningScore returns the labels pointer. Lazy-init defers + // the alloc to the moment the first label key is written, and the + // no-label paths stay at zero heap allocs for the labels slot. When + // a label IS written, the map is pre-sized to the small upper bound + // for that workload to skip the default grow-from-empty. + var labels map[string]string + score := m.DecodeTokensPerSec + switch workload { + case TuningWorkloadLongContext: + score += m.PrefillTokensPerSec * 0.2 + if m.PromptCacheHitRate > 0 { + score += m.PromptCacheHitRate * 100 + labels = make(map[string]string, 1) + labels["prompt_cache"] = "enabled" + } + case TuningWorkloadAgentState: + score += m.PrefillTokensPerSec * 0.1 + score += m.PromptCacheHitRate * 120 + if m.KVRestoreMilliseconds > 0 { + score += 1000 / (m.KVRestoreMilliseconds + 1) + if labels == nil { + labels = make(map[string]string, 2) + } + labels["state_restore"] = "enabled" + } + if m.StateBundleMilliseconds > 0 { + score += 500 / (m.StateBundleMilliseconds + 1) + if labels == nil { + labels = make(map[string]string, 2) + } + labels["state_bundle"] = "enabled" + } + case TuningWorkloadThroughput: + score += m.PrefillTokensPerSec * 0.05 + case TuningWorkloadLowLatency: + if m.FirstTokenMilliseconds > 0 { + score += 1000 / (m.FirstTokenMilliseconds + 1) + labels = make(map[string]string, 1) + labels["first_token"] = "measured" + } + if m.TotalMilliseconds > 0 { + score += 1000 / m.TotalMilliseconds + } + default: + score += m.PrefillTokensPerSec * 0.02 + } + return TuningScore{ + Workload: workload, + Score: score, + FirstTokenMilliseconds: m.FirstTokenMilliseconds, + PrefillTokensPerSec: m.PrefillTokensPerSec, + DecodeTokensPerSec: m.DecodeTokensPerSec, + PromptCacheHitRate: m.PromptCacheHitRate, + KVRestoreMilliseconds: m.KVRestoreMilliseconds, + PeakMemoryBytes: m.PeakMemoryBytes, + Labels: labels, + } +} + +// ModelReplaceAction describes the safest way to move between loaded models +// or settings while preserving useful state where possible. +type ModelReplaceAction string + +const ( + ModelReplaceReuseState ModelReplaceAction = "reuse_state" + ModelReplaceCheckpointState ModelReplaceAction = "checkpoint_state" + ModelReplaceSummaryWindow ModelReplaceAction = "summary_window" +) + +// ModelReplaceRequest compares the current runtime/model/adapter against the +// requested replacement. +type ModelReplaceRequest struct { + CurrentModel ModelIdentity `json:"current_model,omitempty"` + NextModel ModelIdentity `json:"next_model,omitempty"` + CurrentRuntime RuntimeIdentity `json:"current_runtime,omitempty"` + NextRuntime RuntimeIdentity `json:"next_runtime,omitempty"` + CurrentAdapter AdapterIdentity `json:"current_adapter,omitempty"` + NextAdapter AdapterIdentity `json:"next_adapter,omitempty"` +} + +// ModelReplacePlan tells the UI whether state can be reused directly or should +// be compacted into a summary/new window before reload. +type ModelReplacePlan struct { + Action ModelReplaceAction `json:"action"` + Compatible bool `json:"compatible"` + Reasons []string `json:"reasons,omitempty"` +} + +// PlanModelReplace returns a conservative state-reuse decision for model swaps. +func PlanModelReplace(req ModelReplaceRequest) ModelReplacePlan { + sameModel := sameModelIdentity(req.CurrentModel, req.NextModel) + sameRuntime := sameRuntimeIdentity(req.CurrentRuntime, req.NextRuntime) + sameAdapter := sameAdapterIdentity(req.CurrentAdapter, req.NextAdapter) + switch { + case sameModel && sameRuntime && sameAdapter: + return ModelReplacePlan{Action: ModelReplaceReuseState, Compatible: true, Reasons: []string{"model, runtime, and adapter match"}} + case sameModel && sameAdapter: + // CheckpointState path: 0 or 1 reason. Pre-size the backing + // array so the append (when it fires) does not trigger an + // extra grow alloc; when sameRuntime keeps it empty the slice + // is still nil so json.Marshal honours omitempty correctly. + var reasons []string + if !sameRuntime { + reasons = make([]string, 0, 1) + reasons = append(reasons, "runtime or cache settings changed") + } + return ModelReplacePlan{Action: ModelReplaceCheckpointState, Compatible: true, Reasons: reasons} + default: + // SummaryWindow path: up to 2 reasons (model + adapter). The + // previous shape allocated `[]string{}` and then grew on each + // append — two allocs by the second append. Pre-sizing to 2 + // drops the grow. + reasons := make([]string, 0, 2) + if !sameModel { + reasons = append(reasons, "model identity changed") + } + if !sameAdapter { + reasons = append(reasons, "adapter identity changed") + } + return ModelReplacePlan{Action: ModelReplaceSummaryWindow, Compatible: false, Reasons: reasons} + } +} + +func sameModelIdentity(a, b ModelIdentity) bool { + if a.Hash != "" || b.Hash != "" { + return a.Hash != "" && a.Hash == b.Hash + } + if a.Path != "" || b.Path != "" { + return a.Path != "" && a.Path == b.Path && a.QuantBits == b.QuantBits && a.QuantType == b.QuantType + } + return a.Architecture == b.Architecture && a.QuantBits == b.QuantBits && a.ContextLength == b.ContextLength +} + +func sameRuntimeIdentity(a, b RuntimeIdentity) bool { + return a.Backend == b.Backend && a.Device == b.Device && a.CacheMode == b.CacheMode +} + +func sameAdapterIdentity(a, b AdapterIdentity) bool { + if a.Hash != "" || b.Hash != "" { + return a.Hash != "" && a.Hash == b.Hash + } + return a.Path == b.Path && a.Format == b.Format && a.Rank == b.Rank && a.Alpha == b.Alpha +} + +// CandidateID builds a stable readable ID when a planner has not supplied one. +// +// Hand-built via strconv.AppendInt + core.AsString — saves the fmt +// formatter pipeline that Sprintf would walk for every tuning lookup. +func CandidateID(workload TuningWorkload, cacheMode string, contextLength, batchSize int) string { + buf := make([]byte, 0, len(workload)+len(cacheMode)+32) + buf = append(buf, string(workload)...) + buf = append(buf, ':') + buf = append(buf, cacheMode...) + buf = append(buf, ':', 'c', 't', 'x') + buf = strconv.AppendInt(buf, int64(contextLength), 10) + buf = append(buf, ':', 'b', 'a', 't', 'c', 'h') + buf = strconv.AppendInt(buf, int64(batchSize), 10) + return core.AsString(buf) +} diff --git a/go/tuning_bench_test.go b/go/tuning_bench_test.go new file mode 100644 index 0000000..5653af1 --- /dev/null +++ b/go/tuning_bench_test.go @@ -0,0 +1,363 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Benchmarks for the tuning contract shapes — DefaultTuningWorkloads +// constructor, ScoreTuningMeasurements (per-result scoring), PlanModelReplace +// (per-model-swap state-reuse decision), CandidateID (per-candidate ID +// builder), and JSON marshal for the larger MachineDiscoveryReport / TuningPlan +// envelopes that the local-tuning UI fetches on every refresh. Per AX-11 — +// ScoreTuningMeasurements + CandidateID fire in tight loops during autotune; +// PlanModelReplace runs on every model swap; the report marshals are the +// wire format on every UI refresh. +// +// Run: go test -bench='BenchmarkTuning' -benchmem -run='^$' . + +package inference + +import ( + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. Distinct names from the other bench files. +var ( + tuningBenchSinkWorkloads []TuningWorkload + tuningBenchSinkScore TuningScore + tuningBenchSinkPlan ModelReplacePlan + tuningBenchSinkID string + tuningBenchSinkString string +) + +// --- DefaultTuningWorkloads (constructor allocation cost) --- + +func BenchmarkTuning_DefaultTuningWorkloads(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkWorkloads = DefaultTuningWorkloads() + } +} + +// --- ScoreTuningMeasurements — per-workload scoring switch --- + +func BenchmarkTuning_ScoreMeasurements_Chat(b *testing.B) { + m := TuningMeasurements{ + PrefillTokensPerSec: 900, + DecodeTokensPerSec: 120, + PeakMemoryBytes: 8 << 30, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkScore = ScoreTuningMeasurements(TuningWorkloadChat, m) + } +} + +func BenchmarkTuning_ScoreMeasurements_LongContext(b *testing.B) { + m := TuningMeasurements{ + PrefillTokensPerSec: 1200, + DecodeTokensPerSec: 45, + PromptCacheHitRate: 0.8, + PeakMemoryBytes: 12 << 30, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkScore = ScoreTuningMeasurements(TuningWorkloadLongContext, m) + } +} + +func BenchmarkTuning_ScoreMeasurements_AgentState(b *testing.B) { + m := TuningMeasurements{ + PrefillTokensPerSec: 900, + DecodeTokensPerSec: 120, + PromptCacheHitRate: 0.75, + KVRestoreMilliseconds: 4, + StateBundleMilliseconds: 2, + PeakMemoryBytes: 8 << 30, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkScore = ScoreTuningMeasurements(TuningWorkloadAgentState, m) + } +} + +func BenchmarkTuning_ScoreMeasurements_Throughput(b *testing.B) { + m := TuningMeasurements{ + PrefillTokensPerSec: 2400, + DecodeTokensPerSec: 220, + PeakMemoryBytes: 16 << 30, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkScore = ScoreTuningMeasurements(TuningWorkloadThroughput, m) + } +} + +func BenchmarkTuning_ScoreMeasurements_LowLatency(b *testing.B) { + m := TuningMeasurements{ + DecodeTokensPerSec: 80, + FirstTokenMilliseconds: 20, + TotalMilliseconds: 120, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkScore = ScoreTuningMeasurements(TuningWorkloadLowLatency, m) + } +} + +func BenchmarkTuning_ScoreMeasurements_Default(b *testing.B) { + m := TuningMeasurements{ + PrefillTokensPerSec: 1100, + DecodeTokensPerSec: 90, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Empty workload string falls to the default branch. + tuningBenchSinkScore = ScoreTuningMeasurements(TuningWorkload(""), m) + } +} + +// --- PlanModelReplace — per-swap state-reuse decision --- + +func BenchmarkTuning_PlanModelReplace_ReuseState(b *testing.B) { + model := ModelIdentity{Path: "/models/qwen", Hash: "abc", Architecture: "qwen3", QuantBits: 4} + runtime := RuntimeIdentity{Backend: "metal", CacheMode: "paged"} + adapter := AdapterIdentity{Hash: "lora1"} + req := ModelReplaceRequest{ + CurrentModel: model, + NextModel: model, + CurrentRuntime: runtime, + NextRuntime: runtime, + CurrentAdapter: adapter, + NextAdapter: adapter, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkPlan = PlanModelReplace(req) + } +} + +func BenchmarkTuning_PlanModelReplace_CheckpointState(b *testing.B) { + model := ModelIdentity{Path: "/models/qwen", Hash: "abc", Architecture: "qwen3", QuantBits: 4} + adapter := AdapterIdentity{Hash: "lora1"} + req := ModelReplaceRequest{ + CurrentModel: model, + NextModel: model, + CurrentRuntime: RuntimeIdentity{Backend: "metal", CacheMode: "paged"}, + NextRuntime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"}, + CurrentAdapter: adapter, + NextAdapter: adapter, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkPlan = PlanModelReplace(req) + } +} + +func BenchmarkTuning_PlanModelReplace_SummaryWindow(b *testing.B) { + current := ModelIdentity{Path: "/models/qwen", Hash: "abc", Architecture: "qwen3", QuantBits: 4} + next := ModelIdentity{Path: "/models/gemma", Hash: "def", Architecture: "gemma4", QuantBits: 4} + req := ModelReplaceRequest{ + CurrentModel: current, + NextModel: next, + CurrentRuntime: RuntimeIdentity{Backend: "metal", CacheMode: "paged"}, + NextRuntime: RuntimeIdentity{Backend: "metal", CacheMode: "paged"}, + CurrentAdapter: AdapterIdentity{Hash: "lora1"}, + NextAdapter: AdapterIdentity{Hash: "lora2"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkPlan = PlanModelReplace(req) + } +} + +// --- CandidateID — per-candidate stable ID builder --- + +func BenchmarkTuning_CandidateID(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkID = CandidateID(TuningWorkloadLongContext, "paged-q8", 32768, 4) + } +} + +// --- JSON marshal — UI-facing report envelopes --- + +func BenchmarkTuning_TuningCandidate_Marshal(b *testing.B) { + candidate := TuningCandidate{ + ID: "long_context:paged-q8:ctx32768:batch4", + Workload: TuningWorkloadLongContext, + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4, ContextLength: 32768}, + Runtime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q8"}, + ContextLength: 32768, + ParallelSlots: 2, + PromptCache: true, + PromptCacheMinTokens: 512, + CachePolicy: "lru", + CacheMode: "paged-q8", + BatchSize: 4, + PrefillChunkSize: 512, + ExpectedQuantization: 4, + MemoryLimitBytes: 16 << 30, + CacheLimitBytes: 8 << 30, + WiredLimitBytes: 4 << 30, + Reasons: []string{"context fits", "cache hit > 0.8"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkString = core.JSONMarshalString(candidate) + } +} + +func BenchmarkTuning_TuningResult_Marshal(b *testing.B) { + result := TuningResult{ + Candidate: TuningCandidate{ + ID: "long_context:paged-q8:ctx32768:batch4", + Workload: TuningWorkloadLongContext, + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4}, + ContextLength: 32768, + BatchSize: 4, + }, + Measurements: TuningMeasurements{ + PromptTokens: 2048, + GeneratedTokens: 128, + LoadMilliseconds: 1240, + FirstTokenMilliseconds: 35, + PrefillTokensPerSec: 1200, + DecodeTokensPerSec: 45, + PromptCacheHitRate: 0.81, + KVRestoreMilliseconds: 12, + TotalMilliseconds: 4200, + PeakMemoryBytes: 12 << 30, + ActiveMemoryBytes: 8 << 30, + }, + Score: TuningScore{ + Workload: TuningWorkloadLongContext, + Score: 125.4, + PrefillTokensPerSec: 1200, + DecodeTokensPerSec: 45, + PromptCacheHitRate: 0.81, + PeakMemoryBytes: 12 << 30, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkString = core.JSONMarshalString(result) + } +} + +func BenchmarkTuning_MachineDiscoveryReport_Marshal(b *testing.B) { + report := MachineDiscoveryReport{ + Runtime: RuntimeIdentity{Backend: "metal", Device: "m3-ultra", Version: "0.10"}, + Device: MachineDeviceInfo{ + Name: "Apple M3 Ultra", + Architecture: "arm64", + MaxBufferLength: 64 << 30, + MaxRecommendedWorkingSetSize: 80 << 30, + MemorySize: 96 << 30, + }, + Available: true, + CacheModes: []string{"paged", "paged-q8", "paged-q4"}, + Models: []DiscoveredModel{ + {Path: "/models/qwen3-4b", ModelType: "qwen3", QuantBits: 4, NumFiles: 4, Format: "safetensors"}, + {Path: "/models/gemma3-1b", ModelType: "gemma3", QuantBits: 4, NumFiles: 1, Format: "safetensors"}, + {Path: "/models/llama3-8b", ModelType: "llama", QuantBits: 4, NumFiles: 4, Format: "safetensors"}, + }, + Workloads: DefaultTuningWorkloads(), + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkString = core.JSONMarshalString(report) + } +} + +func BenchmarkTuning_TuningPlan_Marshal(b *testing.B) { + plan := TuningPlan{ + Runtime: RuntimeIdentity{Backend: "metal", Device: "m3-ultra"}, + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4}, + Workloads: []TuningWorkload{ + TuningWorkloadChat, + TuningWorkloadLongContext, + TuningWorkloadAgentState, + }, + Candidates: []TuningCandidate{ + {ID: "chat:paged:ctx4096:batch1", Workload: TuningWorkloadChat, ContextLength: 4096, BatchSize: 1, CacheMode: "paged"}, + {ID: "long_context:paged-q8:ctx32768:batch4", Workload: TuningWorkloadLongContext, ContextLength: 32768, BatchSize: 4, CacheMode: "paged-q8"}, + {ID: "agent_state:paged:ctx8192:batch1", Workload: TuningWorkloadAgentState, ContextLength: 8192, BatchSize: 1, CacheMode: "paged"}, + }, + Recommended: map[TuningWorkload]string{ + TuningWorkloadChat: "chat:paged:ctx4096:batch1", + TuningWorkloadLongContext: "long_context:paged-q8:ctx32768:batch4", + TuningWorkloadAgentState: "agent_state:paged:ctx8192:batch1", + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkString = core.JSONMarshalString(plan) + } +} + +func BenchmarkTuning_TuningEvent_Marshal(b *testing.B) { + event := TuningEvent{ + Kind: TuningEventResult, + Candidate: TuningCandidate{ + ID: "long_context:paged-q8:ctx32768:batch4", + Workload: TuningWorkloadLongContext, + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4}, + }, + Result: &TuningResult{ + Measurements: TuningMeasurements{ + PrefillTokensPerSec: 1200, + DecodeTokensPerSec: 45, + }, + Score: TuningScore{Workload: TuningWorkloadLongContext, Score: 125.4}, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkString = core.JSONMarshalString(event) + } +} + +func BenchmarkTuning_TuningProfile_Marshal(b *testing.B) { + profile := TuningProfile{ + Key: TuningProfileKey{ + MachineHash: "sha256-abcd-1234", + Runtime: RuntimeIdentity{Backend: "metal", Device: "m3-ultra"}, + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4}, + Workload: TuningWorkloadLongContext, + }, + Candidate: TuningCandidate{ + ID: "long_context:paged-q8:ctx32768:batch4", + Workload: TuningWorkloadLongContext, + ContextLength: 32768, + BatchSize: 4, + CacheMode: "paged-q8", + }, + Measurements: TuningMeasurements{ + PrefillTokensPerSec: 1200, + DecodeTokensPerSec: 45, + PromptCacheHitRate: 0.81, + }, + Score: TuningScore{Workload: TuningWorkloadLongContext, Score: 125.4}, + CreatedAtUnix: 1700000000, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuningBenchSinkString = core.JSONMarshalString(profile) + } +} diff --git a/go/tuning_deep_bench_test.go b/go/tuning_deep_bench_test.go new file mode 100644 index 0000000..3c3b60f --- /dev/null +++ b/go/tuning_deep_bench_test.go @@ -0,0 +1,304 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +// Deeper benchmarks for the tuning contract shapes. +// Per AX-11 — the existing tuning_bench_test.go covers main paths. +// These benches drill into the CandidateID variants (workload + cache +// mode + context length combinations), sameModelIdentity / sameRuntime +// / sameAdapter shape variants (hash vs path vs identity-only), and +// PlanModelReplace edge cases (runtime-only change, adapter-only +// change, all-empty). All of these fire in tight loops during autotune. +// +// Run: go test -bench='BenchmarkTuningDeep' -benchmem -run='^$' . + +package inference + +import ( + "testing" + + core "dappco.re/go" +) + +// Sinks defeat compiler DCE. Distinct names from the other bench files. +var ( + tuneDeepSinkID string + tuneDeepSinkPlan ModelReplacePlan + tuneDeepSinkScore TuningScore + tuneDeepSinkString string +) + +// --- CandidateID variants --- +// CandidateID builds a deterministic ID from workload + cache mode + +// context length + batch size. The existing bench covers a single +// combination; these cover the surface area. + +func BenchmarkTuningDeep_CandidateID_ShortFields(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkID = CandidateID(TuningWorkloadChat, "p", 256, 1) + } +} + +func BenchmarkTuningDeep_CandidateID_LongFields(b *testing.B) { + // Long cache mode + large context — exercises strconv.AppendInt + // on 6-digit numbers. + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkID = CandidateID(TuningWorkloadLongContext, "paged-q8-experimental", 131072, 32) + } +} + +func BenchmarkTuningDeep_CandidateID_AgentState(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkID = CandidateID(TuningWorkloadAgentState, "paged", 8192, 1) + } +} + +func BenchmarkTuningDeep_CandidateID_Throughput(b *testing.B) { + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkID = CandidateID(TuningWorkloadThroughput, "paged-q4", 4096, 16) + } +} + +func BenchmarkTuningDeep_CandidateID_EmptyMode(b *testing.B) { + // Empty cache mode — minimum-length string path. + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkID = CandidateID(TuningWorkloadLowLatency, "", 1024, 1) + } +} + +// --- PlanModelReplace edge cases --- +// The existing benches cover ReuseState / CheckpointState / SummaryWindow +// at the top of the matrix. These cover the inner shapes. + +func BenchmarkTuningDeep_PlanModelReplace_RuntimeOnly(b *testing.B) { + // Same model + same adapter, runtime differs only in cache mode. + model := ModelIdentity{Hash: "abc", Architecture: "qwen3", QuantBits: 4} + adapter := AdapterIdentity{Hash: "lora1"} + req := ModelReplaceRequest{ + CurrentModel: model, + NextModel: model, + CurrentRuntime: RuntimeIdentity{Backend: "metal", CacheMode: "paged"}, + NextRuntime: RuntimeIdentity{Backend: "metal", CacheMode: "paged-q4"}, + CurrentAdapter: adapter, + NextAdapter: adapter, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkPlan = PlanModelReplace(req) + } +} + +func BenchmarkTuningDeep_PlanModelReplace_AdapterOnly(b *testing.B) { + // Same model + same runtime, adapter changed. + model := ModelIdentity{Hash: "abc", Architecture: "qwen3", QuantBits: 4} + runtime := RuntimeIdentity{Backend: "metal", CacheMode: "paged"} + req := ModelReplaceRequest{ + CurrentModel: model, + NextModel: model, + CurrentRuntime: runtime, + NextRuntime: runtime, + CurrentAdapter: AdapterIdentity{Hash: "lora1"}, + NextAdapter: AdapterIdentity{Hash: "lora2"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkPlan = PlanModelReplace(req) + } +} + +func BenchmarkTuningDeep_PlanModelReplace_PathBasedModel(b *testing.B) { + // Model identity by path (no hash). Exercises sameModelIdentity's + // path-based branch — the Path+QuantBits+QuantType check. + req := ModelReplaceRequest{ + CurrentModel: ModelIdentity{Path: "/m/qwen", QuantBits: 4, QuantType: "q4_k_m"}, + NextModel: ModelIdentity{Path: "/m/qwen", QuantBits: 4, QuantType: "q4_k_m"}, + CurrentRuntime: RuntimeIdentity{Backend: "metal"}, + NextRuntime: RuntimeIdentity{Backend: "metal"}, + CurrentAdapter: AdapterIdentity{Path: "/a/lora1"}, + NextAdapter: AdapterIdentity{Path: "/a/lora1"}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkPlan = PlanModelReplace(req) + } +} + +func BenchmarkTuningDeep_PlanModelReplace_ArchitectureOnly(b *testing.B) { + // No hash, no path — falls to architecture+quant+context comparison. + req := ModelReplaceRequest{ + CurrentModel: ModelIdentity{Architecture: "qwen3", QuantBits: 4, ContextLength: 4096}, + NextModel: ModelIdentity{Architecture: "qwen3", QuantBits: 4, ContextLength: 4096}, + CurrentRuntime: RuntimeIdentity{Backend: "metal"}, + NextRuntime: RuntimeIdentity{Backend: "metal"}, + CurrentAdapter: AdapterIdentity{}, + NextAdapter: AdapterIdentity{}, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkPlan = PlanModelReplace(req) + } +} + +func BenchmarkTuningDeep_PlanModelReplace_AllEmpty(b *testing.B) { + // Empty identities — both sides "match" trivially (everything zero). + req := ModelReplaceRequest{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkPlan = PlanModelReplace(req) + } +} + +// --- ScoreTuningMeasurements edge cases --- + +func BenchmarkTuningDeep_Score_ZeroMeasurements(b *testing.B) { + // All-zero measurements — the score should be 0 with no labels. + m := TuningMeasurements{} + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkScore = ScoreTuningMeasurements(TuningWorkloadChat, m) + } +} + +func BenchmarkTuningDeep_Score_LongContext_NoCache(b *testing.B) { + // PromptCacheHitRate = 0 — the cache-enabled-label branch is + // skipped. + m := TuningMeasurements{ + PrefillTokensPerSec: 800, + DecodeTokensPerSec: 100, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkScore = ScoreTuningMeasurements(TuningWorkloadLongContext, m) + } +} + +func BenchmarkTuningDeep_Score_LowLatency_FirstTokenOnly(b *testing.B) { + // FirstTokenMilliseconds set, TotalMilliseconds zero — only the + // first-token branch fires. + m := TuningMeasurements{ + FirstTokenMilliseconds: 25, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkScore = ScoreTuningMeasurements(TuningWorkloadLowLatency, m) + } +} + +func BenchmarkTuningDeep_Score_AgentState_NoStateBundle(b *testing.B) { + // Only KVRestore set; StateBundle zero. Exercises the partial + // state-restore branch without the bundle branch. + m := TuningMeasurements{ + PrefillTokensPerSec: 800, + DecodeTokensPerSec: 100, + PromptCacheHitRate: 0.6, + KVRestoreMilliseconds: 3, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkScore = ScoreTuningMeasurements(TuningWorkloadAgentState, m) + } +} + +// --- DefaultTuningWorkloads slice clone --- +// The existing bench measures the default constructor; this confirms +// the slice copy is cheap relative to other slice ops. + +func BenchmarkTuningDeep_DefaultWorkloads_Append(b *testing.B) { + base := DefaultTuningWorkloads() + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + // Append one workload to the default — common shape for a + // UI building a "+custom" list. + clone := append([]TuningWorkload(nil), base...) + clone = append(clone, TuningWorkload("custom")) + _ = clone + } +} + +// --- MachineDeviceInfo JSON marshal --- +// Bench-light surface. Fires on every UI report refresh. + +func BenchmarkTuningDeep_MachineDeviceInfo_Marshal(b *testing.B) { + info := MachineDeviceInfo{ + Name: "Apple M3 Ultra", + Architecture: "arm64", + MaxBufferLength: 64 << 30, + MaxRecommendedWorkingSetSize: 80 << 30, + MemorySize: 96 << 30, + Labels: map[string]string{ + "chip": "m3-ultra", + "variant": "studio", + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkString = core.JSONMarshalString(info) + } +} + +// --- TuningPlanRequest marshal --- + +func BenchmarkTuningDeep_TuningPlanRequest_Marshal(b *testing.B) { + req := TuningPlanRequest{ + Runtime: RuntimeIdentity{Backend: "metal", Device: "m3-ultra"}, + Device: MachineDeviceInfo{ + Name: "Apple M3 Ultra", + Architecture: "arm64", + MaxRecommendedWorkingSetSize: 80 << 30, + }, + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4}, + Workloads: []TuningWorkload{ + TuningWorkloadChat, + TuningWorkloadLongContext, + TuningWorkloadAgentState, + }, + Budget: TuningBudget{ + MaxCandidates: 8, + SmokeTokens: 128, + Runs: 3, + AllowStateBench: true, + }, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkString = core.JSONMarshalString(req) + } +} + +// --- TuningProfileKey marshal --- +// Per-profile lookup key — fires on every cache hit during a model load. + +func BenchmarkTuningDeep_TuningProfileKey_Marshal(b *testing.B) { + key := TuningProfileKey{ + MachineHash: "sha256-abcd-1234-5678", + Runtime: RuntimeIdentity{Backend: "metal", Device: "m3-ultra"}, + Model: ModelIdentity{Architecture: "qwen3", QuantBits: 4, ContextLength: 32768}, + Adapter: AdapterIdentity{Hash: "lora1"}, + Workload: TuningWorkloadAgentState, + } + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + tuneDeepSinkString = core.JSONMarshalString(key) + } +} diff --git a/go/tuning_test.go b/go/tuning_test.go new file mode 100644 index 0000000..cae6ca6 --- /dev/null +++ b/go/tuning_test.go @@ -0,0 +1,109 @@ +// SPDX-Licence-Identifier: EUPL-1.2 + +package inference + +import ( + "testing" + + core "dappco.re/go" +) + +func TestDefaultTuningWorkloads_Good(t *testing.T) { + workloads := DefaultTuningWorkloads() + if len(workloads) < 4 { + t.Fatalf("DefaultTuningWorkloads() len = %d, want at least 4", len(workloads)) + } + if workloads[0] != TuningWorkloadChat { + t.Fatalf("first workload = %q, want %q", workloads[0], TuningWorkloadChat) + } + + workloads[0] = TuningWorkloadThroughput + next := DefaultTuningWorkloads() + if next[0] != TuningWorkloadChat { + t.Fatalf("DefaultTuningWorkloads() returned shared slice, first = %q", next[0]) + } +} + +func TestMachineDiscoveryReport_JSONIncludesUnavailable_Bad(t *testing.T) { + report := MachineDiscoveryReport{ + Runtime: RuntimeIdentity{Backend: "metal"}, + Available: false, + } + + data := core.JSONMarshalString(report) + if !core.Contains(data, `"available":false`) { + t.Fatalf("JSON = %s, want explicit available:false", data) + } +} + +func TestScoreTuningMeasurements_Good(t *testing.T) { + score := ScoreTuningMeasurements(TuningWorkloadAgentState, TuningMeasurements{ + PrefillTokensPerSec: 900, + DecodeTokensPerSec: 120, + PromptCacheHitRate: 0.75, + KVRestoreMilliseconds: 4, + StateBundleMilliseconds: 2, + PeakMemoryBytes: 8 << 30, + }) + + if score.Workload != TuningWorkloadAgentState { + t.Fatalf("score.Workload = %q, want %q", score.Workload, TuningWorkloadAgentState) + } + if score.Score <= score.DecodeTokensPerSec { + t.Fatalf("agent-state score = %f, want cache/restore benefit above decode tps %f", score.Score, score.DecodeTokensPerSec) + } + if score.Labels["state_restore"] != "enabled" { + t.Fatalf("score labels = %+v, want state_restore enabled", score.Labels) + } +} + +func TestScoreTuningMeasurements_LowLatencyFirstToken_Good(t *testing.T) { + score := ScoreTuningMeasurements(TuningWorkloadLowLatency, TuningMeasurements{ + DecodeTokensPerSec: 80, + FirstTokenMilliseconds: 20, + TotalMilliseconds: 120, + CorrectnessSmokeResult: "passed", + CorrectnessSmokeChecks: 2, + }) + + if score.FirstTokenMilliseconds != 20 { + t.Fatalf("FirstTokenMilliseconds = %f, want 20", score.FirstTokenMilliseconds) + } + if score.Score <= score.DecodeTokensPerSec { + t.Fatalf("low-latency score = %f, want first-token benefit above decode tps %f", score.Score, score.DecodeTokensPerSec) + } + if score.Labels["first_token"] != "measured" { + t.Fatalf("labels = %+v, want first_token measured", score.Labels) + } +} + +func TestPlanModelReplace_Good(t *testing.T) { + current := ModelIdentity{Path: "/models/qwen", Hash: "abc", Architecture: "qwen3", QuantBits: 4} + runtime := RuntimeIdentity{Backend: "metal", CacheMode: "paged"} + adapter := AdapterIdentity{Hash: "lora1"} + + reuse := PlanModelReplace(ModelReplaceRequest{ + CurrentModel: current, + NextModel: current, + CurrentRuntime: runtime, + NextRuntime: runtime, + CurrentAdapter: adapter, + NextAdapter: adapter, + }) + if reuse.Action != ModelReplaceReuseState || !reuse.Compatible { + t.Fatalf("reuse plan = %+v, want compatible reuse_state", reuse) + } + + next := current + next.Hash = "def" + next.Path = "/models/qwen-new" + summary := PlanModelReplace(ModelReplaceRequest{ + CurrentModel: current, + NextModel: next, + CurrentRuntime: runtime, + NextRuntime: runtime, + }) + if summary.Action != ModelReplaceSummaryWindow || summary.Compatible { + t.Fatalf("summary plan = %+v, want incompatible summary_window", summary) + } +}