diff --git a/core/application/application.go b/core/application/application.go index 83057c9cd5b9..55b507eb4137 100644 --- a/core/application/application.go +++ b/core/application/application.go @@ -19,6 +19,7 @@ import ( "github.com/mudler/LocalAI/core/services/nodes" "github.com/mudler/LocalAI/core/services/routing/admission" "github.com/mudler/LocalAI/core/services/routing/billing" + "github.com/mudler/LocalAI/core/services/routing/corpus" "github.com/mudler/LocalAI/core/services/routing/pii" "github.com/mudler/LocalAI/core/services/routing/piidetector" "github.com/mudler/LocalAI/core/services/routing/router" @@ -74,6 +75,8 @@ type Application struct { mitmHostConflicts atomic.Pointer[map[string][]string] routerDecisions router.DecisionStore routerRegistry *router.Registry + routerCorpus *corpus.Manager + routerCorpusOnce sync.Once admissionLimiter *admission.Limiter watchdogMutex sync.Mutex watchdogStop chan bool @@ -524,6 +527,12 @@ func (a *Application) start() error { assistantClient.PIIRedactor = a.piiRedactor assistantClient.PIIEvents = a.piiEvents assistantClient.RouterDecisions = a.routerDecisions + // Router corpus tools — same factories the RouteModel middleware + // uses, so the assistant and the request path agree on store + // namespaces and model resolution. + assistantClient.RouterCorpus = a.RouterCorpus() + assistantClient.RouterEmbedder = a.Embedder + assistantClient.RouterVectorStore = a.VectorStore if err := holder.Initialize(a.applicationConfig.Context, assistantClient, localaitools.Options{}); err != nil { // Why log+continue instead of fail: the assistant is an optional // feature; a failure here must not take down the whole server. 
diff --git a/core/application/router_factories.go b/core/application/router_factories.go index 879c43a835ee..0c4bb936950a 100644 --- a/core/application/router_factories.go +++ b/core/application/router_factories.go @@ -1,11 +1,14 @@ package application import ( + "cmp" "context" "fmt" + "path/filepath" "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/services/routing/corpus" ) // adapterConfig resolves a model name to its runtime ModelConfig, or nil when @@ -118,3 +121,14 @@ func (l *lazyEmbedder) Embed(ctx context.Context, text string) ([]float32, error func (a *Application) VectorStore(storeName string) backend.VectorStore { return backend.NewVectorStore(a.modelLoader, a.applicationConfig, storeName) } + +// RouterCorpus returns the process-wide KNN corpus manager. Corpus +// files live under {data dir}/router-corpus (same DataPath → +// DynamicConfigsDir precedence the agent pool uses for its state). +func (a *Application) RouterCorpus() *corpus.Manager { + a.routerCorpusOnce.Do(func() { + root := cmp.Or(a.applicationConfig.DataPath, a.applicationConfig.DynamicConfigsDir, ".") + a.routerCorpus = corpus.NewManager(filepath.Join(root, "router-corpus")) + }) + return a.routerCorpus +} diff --git a/core/backend/stores.go b/core/backend/stores.go index 8b73ee17c017..781e02a94f5f 100644 --- a/core/backend/stores.go +++ b/core/backend/stores.go @@ -14,13 +14,23 @@ import ( ) // VectorStore is the narrowed KNN store used by the router's embedding -// cache. Search returns the top-1 match (cosine similarity in [-1, 1]) -// and the serialised payload, or ok=false on a clean miss. +// cache and the KNN classifier. Search returns the top-1 match (cosine +// similarity in [-1, 1]) and the serialised payload, or ok=false on a +// clean miss. SearchK returns up to k nearest neighbours ordered by +// descending similarity; an empty slice is a clean miss. 
type VectorStore interface { Search(ctx context.Context, vec []float32) (similarity float64, payload []byte, ok bool, err error) + SearchK(ctx context.Context, vec []float32, k int) ([]Neighbor, error) Insert(ctx context.Context, vec []float32, payload []byte) error } +// Neighbor is one SearchK result — the stored payload and its cosine +// similarity to the query vector. +type Neighbor struct { + Similarity float64 + Payload []byte +} + // NewVectorStore returns a VectorStore backed by the local-store // gRPC backend, namespaced by storeName so two routers don't collide. func NewVectorStore(loader *model.ModelLoader, appConfig *config.ApplicationConfig, storeName string) VectorStore { @@ -63,6 +73,35 @@ func (s *localVectorStore) Search(ctx context.Context, vec []float32) (sim float return float64(similarities[0]), values[0], true, nil } +func (s *localVectorStore) SearchK(ctx context.Context, vec []float32, k int) (neighbors []Neighbor, err error) { + start := time.Now() + outcome := "hit" + sim := 0.0 + defer func() { + s.recordTrace(start, "search", len(vec), sim, outcome, err) + }() + be, berr := s.backend(ctx) + if berr != nil { + outcome = "backend_load_error" + return nil, fmt.Errorf("vector store load: %w", berr) + } + _, values, similarities, ferr := store.Find(ctx, be, vec, k) + if ferr != nil { + outcome = "find_error" + return nil, fmt.Errorf("vector store find: %w", ferr) + } + if len(values) == 0 { + outcome = "miss" + return nil, nil + } + neighbors = make([]Neighbor, 0, len(values)) + for i, v := range values { + neighbors = append(neighbors, Neighbor{Similarity: float64(similarities[i]), Payload: v}) + } + sim = neighbors[0].Similarity + return neighbors, nil +} + func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload []byte) (err error) { start := time.Now() outcome := "ok" @@ -81,6 +120,56 @@ func (s *localVectorStore) Insert(ctx context.Context, vec []float32, payload [] return nil } +// InsertBatch upserts many vectors 
in one gRPC round-trip. Not part of +// the VectorStore interface — the corpus manager type-asserts for it +// and falls back to per-entry Insert on stores that lack it. +func (s *localVectorStore) InsertBatch(ctx context.Context, vecs [][]float32, payloads [][]byte) (err error) { + start := time.Now() + outcome := "ok" + dim := 0 + if len(vecs) > 0 { + dim = len(vecs[0]) + } + defer func() { + s.recordTrace(start, "insert_batch", dim, 0, outcome, err) + }() + be, berr := s.backend(ctx) + if berr != nil { + outcome = "backend_load_error" + return fmt.Errorf("vector store load: %w", berr) + } + if serr := store.SetCols(ctx, be, vecs, payloads); serr != nil { + outcome = "insert_error" + return serr + } + return nil +} + +// Delete removes vectors by key. Optional capability like InsertBatch; +// used by the corpus manager's Clear so a wiped corpus also leaves the +// live index. +func (s *localVectorStore) Delete(ctx context.Context, vecs [][]float32) (err error) { + start := time.Now() + outcome := "ok" + dim := 0 + if len(vecs) > 0 { + dim = len(vecs[0]) + } + defer func() { + s.recordTrace(start, "delete", dim, 0, outcome, err) + }() + be, berr := s.backend(ctx) + if berr != nil { + outcome = "backend_load_error" + return fmt.Errorf("vector store load: %w", berr) + } + if serr := store.DeleteCols(ctx, be, vecs); serr != nil { + outcome = "delete_error" + return serr + } + return nil +} + // recordTrace surfaces vector-store calls in /api/backend-traces, including // the backend-load-failure path that otherwise vanishes into an xlog.Warn. // modelName uses the store namespace (e.g. 
"router-cache-smart-router") so diff --git a/core/config/meta/registry.go b/core/config/meta/registry.go index b8200cd4115b..2fd0cd0d4f31 100644 --- a/core/config/meta/registry.go +++ b/core/config/meta/registry.go @@ -801,17 +801,19 @@ func DefaultRegistry() map[string]FieldMetaOverride { "router.classifier": { Section: "router", Label: "Classifier", - Description: "Picks a candidate by scoring every policy label against the prompt. Only \"score\" is shipped today; it asks the classifier_model to rank each label and reads off the softmax. Empty defaults to \"score\".", + Description: "How the router picks labels for a prompt. \"score\" asks the classifier_model to rank each policy label and reads off the softmax; \"colbert\" reranks policy descriptions against the prompt via a reranker model; \"knn\" votes over a curated corpus of labelled example prompts (seeded via the corpus API) and routes to the fallback when the prompt is unlike all corpus entries. Empty defaults to \"score\".", Component: "select", Options: []FieldOption{ {Value: "score", Label: "Score (Arch-Router-style)"}, + {Value: "colbert", Label: "Colbert (reranker)"}, + {Value: "knn", Label: "KNN (labelled corpus)"}, }, Order: 230, }, "router.classifier_model": { Section: "router", Label: "Classifier Model", - Description: "Loaded LocalAI model the score classifier asks to rank each policy label as a continuation. Must support the Score gRPC primitive (today: llama-cpp, vLLM) and use the ChatML template. Arch-Router-1.5B Q4_K_M is the canonical choice; any small ChatML instruct model also works at a higher activation_threshold.", + Description: "Loaded LocalAI model the score classifier asks to rank each policy label as a continuation (for colbert: the reranker model). Must support the Score gRPC primitive (today: llama-cpp, vLLM) and use the ChatML template. Arch-Router-1.5B Q4_K_M is the canonical choice; any small ChatML instruct model also works at a higher activation_threshold. 
Not used by the knn classifier.", Component: "model-select", AutocompleteProvider: ProviderModelsScore, Order: 231, @@ -903,5 +905,48 @@ func DefaultRegistry() map[string]FieldMetaOverride { Component: "input", Order: 240, }, + "router.knn.embedding_model": { + Section: "router", + Label: "KNN: Embedding Model", + Description: "Embedding model the knn classifier uses for corpus entries and incoming prompts. Required when classifier is \"knn\". Changing it invalidates stored vectors — entries recorded under a different embedder are re-embedded on load. nomic-embed-text-v1.5 is the recommended default.", + Component: "model-select", + AutocompleteProvider: ProviderModels, + Order: 241, + }, + "router.knn.k": { + Section: "router", + Label: "KNN: Neighbours (K)", + Description: "How many nearest corpus entries vote on a prompt. 0 picks the default (3). K=1 routes on the single nearest example; larger K tolerates a mislabelled exemplar but needs denser corpus coverage per label.", + Component: "number", + Min: f64(0), + Order: 242, + }, + "router.knn.similarity_threshold": { + Section: "router", + Label: "KNN: Similarity Threshold", + Description: "Cosine-similarity floor a corpus entry must clear to vote. When no entry clears it the router uses the fallback model — a prompt unlike all labelled examples is treated as undecidable rather than guessed. 0 picks the default (0.80).", + Component: "slider", + Min: f64(0), + Max: f64(1), + Step: f64(0.01), + Order: 243, + }, + "router.knn.vote_threshold": { + Section: "router", + Label: "KNN: Vote Threshold", + Description: "Similarity-weighted vote share a label needs to activate. 0 picks the default (0.5, a weighted majority). 
Lower values allow multi-label activations from minority neighbours; higher values demand near-unanimous neighbourhoods.", + Component: "slider", + Min: f64(0), + Max: f64(1), + Step: f64(0.05), + Order: 244, + }, + "router.knn.store_name": { + Section: "router", + Label: "KNN: Store Name", + Description: "Optional override for the local-store collection holding the corpus vectors. Empty defaults to \"router-corpus-\".", + Component: "input", + Order: 245, + }, } } diff --git a/core/config/model_config.go b/core/config/model_config.go index 69dda331ba27..ef02550aa720 100644 --- a/core/config/model_config.go +++ b/core/config/model_config.go @@ -337,7 +337,17 @@ type RouterConfig struct { // embeddings to past decisions, so semantically-similar prompts // reuse a classification instead of re-running the classifier // model. Omit the block to disable. See router/embedding_cache.go. + // Ignored (with a warning) for the knn classifier — that IS a + // KNN lookup already; wrapping it in another would embed twice + // for no additional information. EmbeddingCache *EmbeddingCacheConfig `yaml:"embedding_cache,omitempty" json:"embedding_cache,omitempty"` + + // KNN configures the "knn" classifier: nearest-neighbour voting + // over a curated corpus of labelled example prompts. Required when + // classifier is "knn", ignored otherwise. The corpus is seeded and + // curated through the router corpus API (never through the UI); + // see router/knn.go for the decision semantics. + KNN *RouterKNNConfig `yaml:"knn,omitempty" json:"knn,omitempty"` } // EmbeddingCacheConfig configures the L2 embedding-similarity decision @@ -371,6 +381,44 @@ type EmbeddingCacheConfig struct { StoreName string `yaml:"store_name,omitempty" json:"store_name,omitempty"` } +// RouterKNNConfig configures the knn classifier. 
It shares the +// embedding + local-store plumbing with EmbeddingCacheConfig but the +// two are deliberately separate blocks: the cache stores another +// classifier's decisions opportunistically, while the KNN corpus is +// explicit labelled ground truth — different lifecycle, different +// store namespace, different failure story. +type RouterKNNConfig struct { + // EmbeddingModel names the loaded LocalAI model used to embed + // both corpus entries and incoming probes. Required. Changing it + // invalidates the stored vectors — the corpus loader re-embeds + // entries recorded under a different embedder fingerprint. + EmbeddingModel string `yaml:"embedding_model" json:"embedding_model"` + + // K is how many nearest corpus entries vote on a probe. 0 picks + // the package default (3). K=1 reproduces exact nearest-entry + // routing; larger K tolerates mislabelled exemplars at the cost + // of needing denser corpus coverage per label region. + K int `yaml:"k,omitempty" json:"k,omitempty"` + + // SimilarityThreshold is the epistemic gate: corpus entries less + // similar than this to the probe cannot vote, and when none clear + // it the router uses the fallback model — a probe unlike all + // labelled experience is undecidable, not a guess. 0 picks the + // package default (0.80). + SimilarityThreshold float64 `yaml:"similarity_threshold,omitempty" json:"similarity_threshold,omitempty"` + + // VoteThreshold is the similarity-weighted vote share a label + // needs to activate. 0 picks the package default (0.5, a weighted + // majority). Lower values let minority-label neighbours activate + // additional labels (multi-label routing); higher values demand + // near-unanimous neighbourhoods. + VoteThreshold float64 `yaml:"vote_threshold,omitempty" json:"vote_threshold,omitempty"` + + // StoreName overrides the local-store collection holding the + // corpus vectors. Empty defaults to "router-corpus-{model name}". 
+ StoreName string `yaml:"store_name,omitempty" json:"store_name,omitempty"` +} + // RouterPolicy is one entry in the label vocabulary. The label string // is what the classifier model emits and what candidates reference in // their Labels field; the description is the natural-language hint diff --git a/core/http/endpoints/localai/api_instructions.go b/core/http/endpoints/localai/api_instructions.go index 405921e5e589..e58fd832ffe8 100644 --- a/core/http/endpoints/localai/api_instructions.go +++ b/core/http/endpoints/localai/api_instructions.go @@ -114,7 +114,7 @@ var instructionDefs = []instructionDef{ Name: "intelligent-routing", Description: "Per-model `router:` configuration that classifies requests and rewrites the served model", Tags: []string{"router"}, - Intro: "Add a `router:` block to a ModelConfig to turn it into a routing model. The block declares a classifier (today: `feature` — handcrafted rules over prompt length and code-fence presence), a list of candidates (label + downstream model + optional rule), and a fallback. When a client addresses the routing model, the RouteModel middleware invokes the classifier, picks a candidate, and rewrites input.Model — the standard model-resolution path then runs ACL, disabled-state, and per-model PII against the chosen target. Depth-1 invariant: candidates must NOT themselves carry a `router:` block; runtime check returns 500 on violation. Decisions are logged to GET /api/router/decisions and surfaced in the /app/middleware Routing tab. POST /api/router/decide is the programmatic decision-oracle: external routers (e.g. an organisation-wide router service) send `{router, input}` and receive the classifier's label set + candidate model WITHOUT LocalAI rewriting, forwarding, or recording the call. Shares the classifier cache with the in-band path so warm-up costs are paid once.", + Intro: "Add a `router:` block to a ModelConfig to turn it into a routing model. 
The block declares a classifier (`score` — a small model ranks each policy label, Arch-Router-style; `colbert` — a reranker scores policy descriptions against the prompt; `knn` — similarity-weighted vote over a curated corpus of labelled example prompts), `policies` (the label vocabulary), `candidates` (downstream model + labels it serves; first candidate whose labels cover the active set wins, so order small → large), and a `fallback`. The knn classifier needs a `knn: { embedding_model }` block instead of a classifier_model, and reads a persisted corpus seeded via POST /api/router/{name}/corpus with `{entries: [{text, labels}]}` (admin-only; texts are embedded server-side, persisted under the state dir, and NEVER returned by any endpoint — GET /api/router/{name}/corpus/stats reports label counts only, DELETE /api/router/{name}/corpus wipes it). knn routes to the fallback whenever the prompt is less similar than knn.similarity_threshold to every corpus entry — out-of-corpus prompts are treated as undecidable rather than guessed. When a client addresses the routing model, the RouteModel middleware invokes the classifier, picks a candidate, and rewrites input.Model — the standard model-resolution path then runs ACL, disabled-state, and per-model PII against the chosen target. Depth-1 invariant: candidates must NOT themselves carry a `router:` block; runtime check returns 500 on violation. Decisions are logged to GET /api/router/decisions and surfaced in the /app/middleware Routing tab. POST /api/router/decide is the programmatic decision-oracle: external routers (e.g. an organisation-wide router service) send `{router, input}` and receive the classifier's label set + candidate model WITHOUT LocalAI rewriting, forwarding, or recording the call. 
Shares the classifier cache with the in-band path so warm-up costs are paid once.", }, } diff --git a/core/http/endpoints/localai/router_corpus.go b/core/http/endpoints/localai/router_corpus.go new file mode 100644 index 000000000000..bf2a8223055c --- /dev/null +++ b/core/http/endpoints/localai/router_corpus.go @@ -0,0 +1,189 @@ +package localai + +import ( + "fmt" + "net/http" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/schema" + "github.com/mudler/LocalAI/core/services/routing/corpus" +) + +// The router corpus endpoints manage the labelled exemplar corpus +// behind the knn classifier. Corpus input is API-only by design: +// entries can contain example user content, so they are seeded and +// curated programmatically and never entered through or displayed in +// the UI — the inspection surface returns label counts, never texts. +// +// All three handlers resolve the router model by path param and +// require it to declare a `router.knn` block; the store name and +// embedding model come from that config so the API can't desync from +// what the classifier actually queries. + +// resolveKNNRouter loads the named model config and returns its router +// KNN settings (store name defaulted the same way buildClassifier +// defaults it). Echo-shaped errors for the three failure modes. 
+func resolveKNNRouter(c echo.Context, loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig) (*config.ModelConfig, string, error) { + name := c.Param("name") + if name == "" { + return nil, "", echo.NewHTTPError(http.StatusBadRequest, "router name is required") + } + cfg, err := loader.LoadModelConfigFileByNameDefaultOptions(name, appConfig) + if err != nil { + return nil, "", echo.NewHTTPError(http.StatusInternalServerError, "failed to load model config: "+err.Error()) + } + // A synthetic stub (no Name) means the model is unknown — see + // RouterDecideEndpoint for the discrimination rationale. + if cfg == nil || cfg.Name == "" { + return nil, "", echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("model %q not found", name)) + } + if cfg.Router.KNN == nil || cfg.Router.KNN.EmbeddingModel == "" { + return nil, "", echo.NewHTTPError(http.StatusBadRequest, + fmt.Sprintf("model %q has no router.knn block (set classifier: knn and knn.embedding_model first)", name)) + } + storeName := cfg.Router.KNN.StoreName + if storeName == "" { + storeName = "router-corpus-" + cfg.Name + } + return cfg, storeName, nil +} + +// RouterCorpusAddEndpoint bulk-seeds a router's KNN corpus. Entries +// are validated against the router's policy labels, embedded with the +// router's knn.embedding_model, persisted to the corpus file, and +// upserted into the live vector index — routing sees them immediately, +// no reload required. 
+// +// @Summary Seed the KNN routing corpus with labelled example prompts +// @Tags router +// @Accept json +// @Produce json +// @Param name path string true "router model name" +// @Param request body schema.RouterCorpusAddRequest true "labelled exemplars" +// @Success 200 {object} schema.RouterCorpusAddResponse +// @Failure 400 {object} map[string]string +// @Failure 404 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /api/router/{name}/corpus [post] +func RouterCorpusAddEndpoint(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, mgr *corpus.Manager, deps middleware.ClassifierDeps) echo.HandlerFunc { + return func(c echo.Context) error { + cfg, storeName, err := resolveKNNRouter(c, loader, appConfig) + if err != nil { + return err + } + var req schema.RouterCorpusAddRequest + if err := c.Bind(&req); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "invalid request body: "+err.Error()) + } + if len(req.Entries) == 0 { + return echo.NewHTTPError(http.StatusBadRequest, "entries is required") + } + + // Labels must be declared policies — the same invariant + // candidate tables are validated against. Catching it here + // keeps a typo from silently creating an unroutable label. 
+ declared := map[string]struct{}{} + for _, p := range cfg.Router.Policies { + declared[p.Label] = struct{}{} + } + entries := make([]corpus.Entry, 0, len(req.Entries)) + for i, e := range req.Entries { + for _, l := range e.Labels { + if _, ok := declared[l]; !ok { + return echo.NewHTTPError(http.StatusBadRequest, + fmt.Sprintf("entry %d: label %q is not declared in router policies", i, l)) + } + } + entries = append(entries, corpus.Entry{Text: e.Text, Labels: e.Labels}) + } + + embedder := deps.Embedder(cfg.Router.KNN.EmbeddingModel) + if embedder == nil { + return echo.NewHTTPError(http.StatusBadRequest, + fmt.Sprintf("embedding_model %q not loadable", cfg.Router.KNN.EmbeddingModel)) + } + store := deps.VectorStore(storeName) + + added, skipped, err := mgr.Add(c.Request().Context(), storeName, cfg.Router.KNN.EmbeddingModel, embedder, store, entries) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + stats, err := mgr.Stats(storeName) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, schema.RouterCorpusAddResponse{ + Router: cfg.Name, + Added: added, + Skipped: skipped, + Total: stats.Total, + LabelCounts: stats.LabelCounts, + }) + } +} + +// RouterCorpusStatsEndpoint reports corpus size and per-label counts. +// Deliberately count-only: corpus texts never leave the server. 
+// +// @Summary Inspect a router's KNN corpus (label counts only, never texts) +// @Tags router +// @Produce json +// @Param name path string true "router model name" +// @Success 200 {object} schema.RouterCorpusStatsResponse +// @Failure 400 {object} map[string]string +// @Failure 404 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /api/router/{name}/corpus/stats [get] +func RouterCorpusStatsEndpoint(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, mgr *corpus.Manager) echo.HandlerFunc { + return func(c echo.Context) error { + cfg, storeName, err := resolveKNNRouter(c, loader, appConfig) + if err != nil { + return err + } + stats, err := mgr.Stats(storeName) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, schema.RouterCorpusStatsResponse{ + Router: cfg.Name, + StoreName: storeName, + EmbeddingModel: cfg.Router.KNN.EmbeddingModel, + Total: stats.Total, + LabelCounts: stats.LabelCounts, + EmbeddingModels: stats.EmbeddingModels, + }) + } +} + +// RouterCorpusClearEndpoint wipes a router's corpus — file and live +// index. Reseed with POST afterwards; there is intentionally no +// per-entry delete (exemplars are curated as a set; partial edits by +// vector identity are how mislabelled neighbourhoods linger). 
+// +// @Summary Clear a router's KNN corpus +// @Tags router +// @Produce json +// @Param name path string true "router model name" +// @Success 200 {object} schema.RouterCorpusClearResponse +// @Failure 400 {object} map[string]string +// @Failure 404 {object} map[string]string +// @Failure 500 {object} map[string]string +// @Router /api/router/{name}/corpus [delete] +func RouterCorpusClearEndpoint(loader *config.ModelConfigLoader, appConfig *config.ApplicationConfig, mgr *corpus.Manager, deps middleware.ClassifierDeps) echo.HandlerFunc { + return func(c echo.Context) error { + cfg, storeName, err := resolveKNNRouter(c, loader, appConfig) + if err != nil { + return err + } + cleared, err := mgr.Clear(c.Request().Context(), storeName, deps.VectorStore(storeName)) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, err.Error()) + } + return c.JSON(http.StatusOK, schema.RouterCorpusClearResponse{ + Router: cfg.Name, + Cleared: cleared, + }) + } +} diff --git a/core/http/endpoints/localai/router_corpus_test.go b/core/http/endpoints/localai/router_corpus_test.go new file mode 100644 index 000000000000..bfb042730b79 --- /dev/null +++ b/core/http/endpoints/localai/router_corpus_test.go @@ -0,0 +1,210 @@ +package localai_test + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + + "github.com/labstack/echo/v4" + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/config" + "github.com/mudler/LocalAI/core/http/endpoints/localai" + "github.com/mudler/LocalAI/core/http/middleware" + "github.com/mudler/LocalAI/core/services/routing/corpus" + "github.com/mudler/LocalAI/pkg/system" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" + "gopkg.in/yaml.v3" +) + +// The corpus endpoints manage the knn classifier's labelled exemplar +// store. 
These specs pin the validation surface (knn-only, declared +// labels), the seed → stats → clear lifecycle, and the privacy +// contract: corpus texts never appear in any response body. + +type corpusTestEmbedder struct{} + +func (corpusTestEmbedder) Embed(_ context.Context, text string) ([]float32, error) { + return []float32{float32(len(text)), 1}, nil +} + +// corpusTestStore implements only the narrow VectorStore interface — +// no batch/delete fast paths — so the specs also cover the manager's +// per-entry fallbacks. +type corpusTestStore struct { + inserted int +} + +func (s *corpusTestStore) Search(_ context.Context, _ []float32) (float64, []byte, bool, error) { + return 0, nil, false, nil +} + +func (s *corpusTestStore) SearchK(_ context.Context, _ []float32, _ int) ([]backend.Neighbor, error) { + return nil, nil +} + +func (s *corpusTestStore) Insert(_ context.Context, _ []float32, _ []byte) error { + s.inserted++ + return nil +} + +func writeKNNRouter(modelDir, name string) { + cfg := &config.ModelConfig{ + Name: name, + Router: config.RouterConfig{ + Classifier: "knn", + Fallback: "small-model", + KNN: &config.RouterKNNConfig{EmbeddingModel: "embed-model"}, + Policies: []config.RouterPolicy{ + {Label: "code-generation", Description: "writing or debugging code"}, + {Label: "casual-chat", Description: "small talk"}, + }, + Candidates: []config.RouterCandidate{ + {Model: "small-model", Labels: []string{"casual-chat"}}, + {Model: "big-model", Labels: []string{"code-generation", "casual-chat"}}, + }, + }, + } + b, err := yaml.Marshal(cfg) + Expect(err).NotTo(HaveOccurred()) + Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(b), 0o644)).To(Succeed()) +} + +var _ = Describe("Router corpus endpoints", func() { + var ( + modelDir string + corpusDir string + appConfig *config.ApplicationConfig + loader *config.ModelConfigLoader + mgr *corpus.Manager + store *corpusTestStore + e *echo.Echo + ) + + const seededText = "please debug this stack trace 
for me" + + BeforeEach(func() { + d, err := os.MkdirTemp("", "router-corpus-ep-*") + Expect(err).NotTo(HaveOccurred()) + modelDir = d + corpusDir = filepath.Join(d, "corpus") + appConfig = &config.ApplicationConfig{ + Context: context.Background(), + SystemState: &system.SystemState{Model: system.Model{ModelsPath: modelDir}}, + } + loader = config.NewModelConfigLoader(modelDir) + mgr = corpus.NewManager(corpusDir) + store = &corpusTestStore{} + + deps := middleware.ClassifierDeps{ + Embedder: func(string) backend.Embedder { return corpusTestEmbedder{} }, + VectorStore: func(string) backend.VectorStore { return store }, + } + e = echo.New() + e.POST("/api/router/:name/corpus", localai.RouterCorpusAddEndpoint(loader, appConfig, mgr, deps)) + e.GET("/api/router/:name/corpus/stats", localai.RouterCorpusStatsEndpoint(loader, appConfig, mgr)) + e.DELETE("/api/router/:name/corpus", localai.RouterCorpusClearEndpoint(loader, appConfig, mgr, deps)) + }) + + AfterEach(func() { + _ = os.RemoveAll(modelDir) + }) + + do := func(method, path, body string) *httptest.ResponseRecorder { + var rd *strings.Reader + if body == "" { + rd = strings.NewReader("") + } else { + rd = strings.NewReader(body) + } + req := httptest.NewRequest(method, path, rd) + req.Header.Set("Content-Type", "application/json") + rec := httptest.NewRecorder() + e.ServeHTTP(rec, req) + return rec + } + + seedBody := `{"entries":[ + {"text":"` + seededText + `","labels":["code-generation"]}, + {"text":"how is your day going","labels":["casual-chat"]} + ]}` + + It("returns 404 for an unknown router", func() { + rec := do(http.MethodPost, "/api/router/nope/corpus", seedBody) + Expect(rec.Code).To(Equal(http.StatusNotFound)) + }) + + It("returns 400 for a router without a knn block", func() { + writeScoreRouter(modelDir, "score-router") + rec := do(http.MethodPost, "/api/router/score-router/corpus", seedBody) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + 
Expect(rec.Body.String()).To(ContainSubstring("router.knn")) + }) + + It("rejects entries with undeclared labels", func() { + writeKNNRouter(modelDir, "knn-router") + rec := do(http.MethodPost, "/api/router/knn-router/corpus", + `{"entries":[{"text":"hello","labels":["not-a-policy"]}]}`) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + Expect(rec.Body.String()).To(ContainSubstring("not declared")) + }) + + It("rejects an empty entries list", func() { + writeKNNRouter(modelDir, "knn-router") + rec := do(http.MethodPost, "/api/router/knn-router/corpus", `{"entries":[]}`) + Expect(rec.Code).To(Equal(http.StatusBadRequest)) + }) + + It("seeds, reports counts only, and clears", func() { + writeKNNRouter(modelDir, "knn-router") + + rec := do(http.MethodPost, "/api/router/knn-router/corpus", seedBody) + Expect(rec.Code).To(Equal(http.StatusOK), rec.Body.String()) + var addResp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &addResp)).To(Succeed()) + Expect(addResp["added"]).To(BeNumerically("==", 2)) + Expect(addResp["total"]).To(BeNumerically("==", 2)) + Expect(store.inserted).To(Equal(2)) + + // The seed response must not echo the entry texts back. + Expect(rec.Body.String()).NotTo(ContainSubstring(seededText)) + + rec = do(http.MethodGet, "/api/router/knn-router/corpus/stats", "") + Expect(rec.Code).To(Equal(http.StatusOK)) + var stats map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &stats)).To(Succeed()) + Expect(stats["total"]).To(BeNumerically("==", 2)) + Expect(stats["label_counts"]).To(HaveKeyWithValue("code-generation", BeNumerically("==", 1))) + Expect(stats["embedding_model"]).To(Equal("embed-model")) + // Privacy contract: the inspection surface never returns texts. 
+ Expect(rec.Body.String()).NotTo(ContainSubstring(seededText)) + + rec = do(http.MethodDelete, "/api/router/knn-router/corpus", "") + Expect(rec.Code).To(Equal(http.StatusOK)) + var clr map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &clr)).To(Succeed()) + Expect(clr["cleared"]).To(BeNumerically("==", 2)) + + rec = do(http.MethodGet, "/api/router/knn-router/corpus/stats", "") + Expect(rec.Code).To(Equal(http.StatusOK)) + var after map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &after)).To(Succeed()) + Expect(after["total"]).To(BeNumerically("==", 0)) + }) + + It("skips duplicate texts on reseed", func() { + writeKNNRouter(modelDir, "knn-router") + Expect(do(http.MethodPost, "/api/router/knn-router/corpus", seedBody).Code).To(Equal(http.StatusOK)) + rec := do(http.MethodPost, "/api/router/knn-router/corpus", seedBody) + Expect(rec.Code).To(Equal(http.StatusOK)) + var resp map[string]any + Expect(json.Unmarshal(rec.Body.Bytes(), &resp)).To(Succeed()) + Expect(resp["added"]).To(BeNumerically("==", 0)) + Expect(resp["skipped"]).To(BeNumerically("==", 2)) + Expect(resp["total"]).To(BeNumerically("==", 2)) + }) +}) diff --git a/core/http/endpoints/localai/router_decide.go b/core/http/endpoints/localai/router_decide.go index 11fddcf574f7..b13b3a7acbc0 100644 --- a/core/http/endpoints/localai/router_decide.go +++ b/core/http/endpoints/localai/router_decide.go @@ -95,15 +95,16 @@ func RouterDecideEndpoint(loader *config.ModelConfigLoader, appConfig *config.Ap } return c.JSON(http.StatusOK, schema.RouterDecideResponse{ - Router: req.Router, - Classifier: classifierName, - Labels: decision.Labels, - Candidate: candidate, - Fallback: fallback, - Score: decision.Score, - LatencyMs: decision.Latency.Milliseconds(), - Cached: decision.Cached, - CacheSimilarity: decision.CacheSimilarity, + Router: req.Router, + Classifier: classifierName, + Labels: decision.Labels, + Candidate: candidate, + Fallback: fallback, + Score: decision.Score, + LatencyMs: 
decision.Latency.Milliseconds(), + Cached: decision.Cached, + CacheSimilarity: decision.CacheSimilarity, + NearestSimilarity: decision.NearestSimilarity, }) } } diff --git a/core/http/endpoints/mcp/localai_assistant_test.go b/core/http/endpoints/mcp/localai_assistant_test.go index 8de7355c671f..11e0a1794a6b 100644 --- a/core/http/endpoints/mcp/localai_assistant_test.go +++ b/core/http/endpoints/mcp/localai_assistant_test.go @@ -167,3 +167,15 @@ var _ = Describe("LocalAIAssistantHolder", func() { Expect(exec.HasTools()).To(BeFalse()) }) }) + +func (stubClient) GetRouterCorpusStats(_ context.Context, routerModel string) (*localaitools.RouterCorpusStats, error) { + return &localaitools.RouterCorpusStats{Router: routerModel, LabelCounts: map[string]int{}}, nil +} + +func (stubClient) SeedRouterCorpus(_ context.Context, req localaitools.RouterCorpusSeedRequest) (*localaitools.RouterCorpusSeedResult, error) { + return &localaitools.RouterCorpusSeedResult{Router: req.Router, LabelCounts: map[string]int{}}, nil +} + +func (stubClient) ClearRouterCorpus(_ context.Context, routerModel string) (*localaitools.RouterCorpusClearResult, error) { + return &localaitools.RouterCorpusClearResult{Router: routerModel}, nil +} diff --git a/core/http/endpoints/openai/realtime_model.go b/core/http/endpoints/openai/realtime_model.go index 71f553980465..db2eac867ec7 100644 --- a/core/http/endpoints/openai/realtime_model.go +++ b/core/http/endpoints/openai/realtime_model.go @@ -530,6 +530,7 @@ func buildRealtimeRoutingContext(a *application.Application, sessionID string) * } deps := &middleware.ClassifierDeps{ Scorer: a.Scorer, + Corpus: a.RouterCorpus(), TokenCounter: a.TokenCounter, Embedder: a.Embedder, VectorStore: a.VectorStore, diff --git a/core/http/middleware/route_model.go b/core/http/middleware/route_model.go index 470bd05f5aa2..a7e3255fb39d 100644 --- a/core/http/middleware/route_model.go +++ b/core/http/middleware/route_model.go @@ -49,6 +49,16 @@ type RerankerFactory 
func(modelName string) backend.Reranker // lives in ModelConfig.Validate() and runs at config load/save time. type ModelConfigLookup func(modelName string) *config.ModelConfig +// CorpusLoader syncs a router's persisted KNN corpus into the live +// vector index when the knn classifier is built. Implemented by +// corpus.Manager; declared here (consumer side) so the middleware +// doesn't depend on the corpus package. Optional — when nil, the knn +// classifier serves whatever the index already holds (tests, embedded +// callers). +type CorpusLoader interface { + EnsureLoaded(ctx context.Context, storeName, embeddingModel string, embedder backend.Embedder, store backend.VectorStore) (int, error) +} + // ClassifierDeps bundles the backend factories the router middleware // needs to build a classifier and its optional L2 cache. Bundled into // one struct because RouteModel already takes many positional @@ -66,6 +76,10 @@ type ClassifierDeps struct { VectorStore VectorStoreFactory Reranker RerankerFactory + // Corpus loads the persisted KNN corpus into the vector index when + // a knn classifier is built. Optional; nil skips the load. + Corpus CorpusLoader + // ModelLookup resolves the classifier_model name to its config so // buildClassifier can reject misconfigurations that would // otherwise crash the llama-cpp backend at request time. 
Optional @@ -375,13 +389,61 @@ func buildClassifier(cfg *config.ModelConfig, deps ClassifierDeps) (router.Class rerankClassifier = rerankClassifier.WithTokenTrim(count, ctxTokens) } inner = rerankClassifier + case router.ClassifierKNN: + if rc.KNN == nil || rc.KNN.EmbeddingModel == "" { + return nil, fmt.Errorf("router classifier knn requires a knn block with embedding_model") + } + if deps.Embedder == nil || deps.VectorStore == nil { + return nil, fmt.Errorf("router classifier knn unavailable: embedder/vector-store factories not wired") + } + embedder := deps.Embedder(rc.KNN.EmbeddingModel) + if embedder == nil { + return nil, fmt.Errorf("router classifier knn: embedding_model %q not loadable", rc.KNN.EmbeddingModel) + } + storeName := rc.KNN.StoreName + if storeName == "" { + storeName = "router-corpus-" + cfg.Name + } + vstore := deps.VectorStore(storeName) + if vstore == nil { + return nil, fmt.Errorf("router classifier knn: vector store %q not loadable", storeName) + } + if deps.Corpus != nil { + // A failed load must not break routing — the classifier + // still works against whatever the index holds, and the + // epistemic gate falls back safely on an empty index. 
+ if n, err := deps.Corpus.EnsureLoaded(context.Background(), storeName, rc.KNN.EmbeddingModel, embedder, vstore); err != nil { + xlog.Warn("router: knn corpus load failed; routing continues on the live index", + "router_model", cfg.Name, "store", storeName, "error", err) + } else if n > 0 { + xlog.Info("router: knn corpus loaded", + "router_model", cfg.Name, "store", storeName, "entries", n) + } + } + knnClassifier := router.NewKNNClassifier(embedder, vstore, router.KNNClassifierOptions{ + K: rc.KNN.K, + SimilarityThreshold: rc.KNN.SimilarityThreshold, + VoteThreshold: rc.KNN.VoteThreshold, + }) + if count, ctxTokens := modelTokenTrim(rc.KNN.EmbeddingModel, deps); count != nil { + knnClassifier = knnClassifier.WithTokenTrim(count, ctxTokens) + } + inner = knnClassifier default: - return nil, fmt.Errorf("router: unknown classifier %q (supported: %s)", name, strings.Join([]string{router.ClassifierScore, router.ClassifierColbert}, ", ")) + return nil, fmt.Errorf("router: unknown classifier %q (supported: %s)", name, strings.Join([]string{router.ClassifierScore, router.ClassifierColbert, router.ClassifierKNN}, ", ")) } if rc.EmbeddingCache == nil { return inner, nil } + if name == router.ClassifierKNN { + // The knn classifier IS an embedding-KNN lookup — wrapping it in + // the embedding cache would embed every probe twice to answer the + // same question. Ignore the block rather than fail routing. + xlog.Warn("router: embedding_cache ignored for knn classifier", + "router_model", cfg.Name) + return inner, nil + } wrapped, err := wrapWithEmbeddingCache(cfg, inner, deps) if err != nil { // Caching plumbing problems must not break routing — log, @@ -428,7 +490,11 @@ func assertClassifierDeclaresScore(classifierModel string, lookup ModelConfigLoo // returns the parsed []ScorePolicy. Both Score and Rerank classifiers // take the same policy shape. 
func validateRouterPolicies(classifierName string, rc config.RouterConfig) ([]router.ScorePolicy, error) { - if rc.ClassifierModel == "" { + // The knn classifier has no scoring model — its label knowledge + // lives in the corpus. It still declares policies: they are the + // label vocabulary corpus entries are validated against and the + // "categories" surface the Routing tab shows. + if rc.ClassifierModel == "" && classifierName != router.ClassifierKNN { return nil, fmt.Errorf("router classifier %s requires classifier_model", classifierName) } if len(rc.Policies) == 0 { diff --git a/core/http/middleware/route_model_test.go b/core/http/middleware/route_model_test.go index 4a9be2b12fb3..c8f85b8f3a05 100644 --- a/core/http/middleware/route_model_test.go +++ b/core/http/middleware/route_model_test.go @@ -2,6 +2,7 @@ package middleware_test import ( "context" + "errors" "net/http" "net/http/httptest" "os" @@ -540,3 +541,185 @@ template: ` Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(body), 0o644)).To(Succeed()) } + +// --- knn classifier middleware specs --- + +// fixedEmbedder returns the same vector for every probe — the KNN +// middleware specs script the neighbourhood in the store fake, so the +// query vector itself is irrelevant. +type fixedEmbedder struct{} + +func (fixedEmbedder) Embed(_ context.Context, _ string) ([]float32, error) { + return []float32{1, 0, 0}, nil +} + +// scriptedVectorStore returns a fixed neighbour list from SearchK. The +// classifier must never call Search (top-1 is the cache's shape) nor +// Insert (the KNN corpus grows only through explicit curation) — both +// error loudly so a regression shows up as a failed spec. 
+type scriptedVectorStore struct { + neighbors []backend.Neighbor +} + +func (s *scriptedVectorStore) SearchK(_ context.Context, _ []float32, k int) ([]backend.Neighbor, error) { + if len(s.neighbors) > k { + return s.neighbors[:k], nil + } + return s.neighbors, nil +} + +func (s *scriptedVectorStore) Search(_ context.Context, _ []float32) (float64, []byte, bool, error) { + return 0, nil, false, errTestKNNSearch +} + +func (s *scriptedVectorStore) Insert(_ context.Context, _ []float32, _ []byte) error { + return errTestKNNInsert +} + +var ( + errTestKNNSearch = errors.New("knn classifier must use SearchK, not Search") + errTestKNNInsert = errors.New("knn classifier must never insert into the corpus") +) + +func corpusPayload(labels ...string) []byte { + b, err := router.EncodeCorpusEntry(labels) + Expect(err).NotTo(HaveOccurred()) + return b +} + +// newKNNRouterModel mirrors newScoreRouterModel but with the knn +// classifier: same policy vocabulary and candidate table, no +// classifier_model, corpus semantics supplied by the store fake. 
+func newKNNRouterModel(modelDir, name string) *config.ModelConfig { + cfg := &config.ModelConfig{ + Name: name, + Router: config.RouterConfig{ + Classifier: "knn", + Fallback: "qwen3-0.6b", + KNN: &config.RouterKNNConfig{EmbeddingModel: "embed-model"}, + Policies: []config.RouterPolicy{ + {Label: "code-generation", Description: "writing or debugging code"}, + {Label: "casual-chat", Description: "small talk"}, + {Label: "math-reasoning", Description: "arithmetic and word problems"}, + }, + Candidates: []config.RouterCandidate{ + {Model: "small-model", Labels: []string{"casual-chat"}}, + {Model: "big-model", Labels: []string{"code-generation", "casual-chat", "math-reasoning"}}, + }, + }, + } + Expect(os.WriteFile(filepath.Join(modelDir, name+".yaml"), []byte(toYAML(cfg)), 0o644)).To(Succeed()) + return cfg +} + +var _ = Describe("RouteModel middleware (knn classifier)", func() { + var ( + modelDir string + appConfig *config.ApplicationConfig + loader *config.ModelConfigLoader + store *fakeDecisionStore + vstore *scriptedVectorStore + storeName string + ) + + BeforeEach(func() { + d, err := os.MkdirTemp("", "router-knn-test-*") + Expect(err).NotTo(HaveOccurred()) + modelDir = d + appConfig = &config.ApplicationConfig{ + Context: context.Background(), + SystemState: &system.SystemState{Model: system.Model{ModelsPath: modelDir}}, + } + loader = config.NewModelConfigLoader(modelDir) + store = &fakeDecisionStore{} + vstore = &scriptedVectorStore{} + storeName = "" + }) + + AfterEach(func() { + _ = os.RemoveAll(modelDir) + }) + + knnDeps := func() ClassifierDeps { + return ClassifierDeps{ + Embedder: func(string) backend.Embedder { return fixedEmbedder{} }, + VectorStore: func(name string) backend.VectorStore { + storeName = name + return vstore + }, + } + } + + It("routes to the candidate covering the corpus vote", func() { + routerCfg := newKNNRouterModel(modelDir, "smart-router") + writeCandidate(modelDir, "small-model") + writeCandidate(modelDir, "big-model") + 
vstore.neighbors = []backend.Neighbor{ + {Similarity: 0.92, Payload: corpusPayload("code-generation")}, + {Similarity: 0.88, Payload: corpusPayload("code-generation")}, + } + + rec, err := runRouterWithDeps(loader, appConfig, store, routerCfg, + openAIChat("debug my Go null pointer"), knnDeps()) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Body.String()).To(Equal("served:big-model")) + Expect(storeName).To(Equal("router-corpus-smart-router"), + "corpus store must be namespaced per router, separate from the decision cache") + Expect(store.records).To(HaveLen(1)) + Expect(store.records[0].Classifier).To(Equal("knn")) + Expect(store.records[0].Label).To(ContainSubstring("code-generation")) + Expect(store.records[0].NearestSimilarity).To(BeNumerically("~", 0.92, 1e-9)) + }) + + It("falls back when the probe is out of corpus range (epistemic gate)", func() { + routerCfg := newKNNRouterModel(modelDir, "smart-router") + writeCandidate(modelDir, "small-model") + writeCandidate(modelDir, "big-model") + writeCandidate(modelDir, "qwen3-0.6b") + // Nearest labelled experience is far below the 0.80 default gate. + vstore.neighbors = []backend.Neighbor{ + {Similarity: 0.42, Payload: corpusPayload("code-generation")}, + } + + rec, err := runRouterWithDeps(loader, appConfig, store, routerCfg, + openAIChat("translate this sanskrit poem"), knnDeps()) + Expect(err).NotTo(HaveOccurred()) + Expect(rec.Body.String()).To(Equal("served:qwen3-0.6b")) + Expect(store.records).To(HaveLen(1)) + Expect(store.records[0].Label).To(Equal(router.LabelFallback)) + // The decision log must still carry the epistemic signal — how + // close the nearest corpus entry was to routing this probe. 
+ Expect(store.records[0].NearestSimilarity).To(BeNumerically("~", 0.42, 1e-9)) + }) + + It("rejects a knn router without a knn block at build time", func() { + routerCfg := newKNNRouterModel(modelDir, "smart-router") + routerCfg.Router.KNN = nil + writeCandidate(modelDir, "small-model") + writeCandidate(modelDir, "big-model") + + _, err := runRouterWithDeps(loader, appConfig, store, routerCfg, + openAIChat("hello"), knnDeps()) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("knn")) + }) + + It("ignores an embedding_cache block instead of double-embedding", func() { + routerCfg := newKNNRouterModel(modelDir, "smart-router") + routerCfg.Router.EmbeddingCache = &config.EmbeddingCacheConfig{EmbeddingModel: "embed-model"} + writeCandidate(modelDir, "small-model") + writeCandidate(modelDir, "big-model") + vstore.neighbors = []backend.Neighbor{ + {Similarity: 0.95, Payload: corpusPayload("casual-chat")}, + } + + rec, err := runRouterWithDeps(loader, appConfig, store, routerCfg, + openAIChat("hi"), knnDeps()) + Expect(err).NotTo(HaveOccurred()) + // Routed by the corpus (small-model covers casual-chat) and NOT + // wrapped: a cache wrap would have called the fake's Search, + // which errors loudly. 
+ Expect(rec.Body.String()).To(Equal("served:small-model")) + Expect(store.records[0].Cached).To(BeFalse()) + }) +}) diff --git a/core/http/react-ui/e2e/middleware-page.spec.js b/core/http/react-ui/e2e/middleware-page.spec.js index 98e011c1e1d4..a12733a99659 100644 --- a/core/http/react-ui/e2e/middleware-page.spec.js +++ b/core/http/react-ui/e2e/middleware-page.spec.js @@ -57,9 +57,31 @@ const MOCK_STATUS = { }, }, }, + { + name: 'knn-router', + classifier: 'knn', + fallback: 'qwen-7b', + policies: [ + { label: 'casual-chat', description: 'small talk' }, + { label: 'code-generation', description: 'writing or debugging code' }, + ], + candidates: [ + { model: 'qwen-3b', labels: ['casual-chat'] }, + { model: 'qwen-coder', labels: ['code-generation', 'casual-chat'] }, + ], + knn: { + embedding_model: 'nomic-embed-text-v1.5', + k: 3, + similarity_threshold: 0.80, + vote_threshold: 0.5, + store_name: 'router-corpus-knn-router', + // Counts only — the status endpoint never sends corpus texts. 
+ corpus: { total: 12, label_counts: { 'code-generation': 7, 'casual-chat': 5 } }, + }, + }, ], recent_decision_count: 1, - available_classifiers: ['score'], + available_classifiers: ['score', 'colbert', 'knn'], }, } @@ -72,6 +94,13 @@ const MOCK_DECISIONS = { cached: true, cache_similarity: 0.92, created_at: '2026-05-06T11:00:00Z', }, + { + id: 'rd_a2', correlation_id: 'corr-2', user_id: 'local', + router_model: 'knn-router', requested_model: 'knn-router', served_model: 'qwen-7b', + classifier: 'knn', label: 'fallback', score: 0, latency_ms: 9, + cached: false, nearest_similarity: 0.42, + created_at: '2026-05-06T11:01:00Z', + }, ], } @@ -229,6 +258,26 @@ test.describe('Middleware page — admin in no-auth mode', () => { await expect(page.getByText('casual-chat').first()).toBeVisible() }) + test('Routing tab renders knn corpus stats and out-of-corpus fallback detail', async ({ page }) => { + await page.goto('/app/middleware') + await page.getByRole('button', { name: /Routing/i }).click() + + // KNN router row: corpus size, K, and gate threshold in the + // Cache / corpus column. + await expect(page.getByText('knn-router').first()).toBeVisible() + await expect(page.getByText(/12 exemplars · k=3 · sim ≥ 0\.8/).first()).toBeVisible() + + // Per-label exemplar counts — counts only, never corpus texts. + await expect(page.getByText(/code-generation: 7/).first()).toBeVisible() + await expect(page.getByText(/casual-chat: 5/).first()).toBeVisible() + + // Expanding the knn fallback decision explains the epistemic gate + // and surfaces how far away the nearest labelled experience was. 
+ await page.getByText('fallback', { exact: true }).first().click() + await expect(page.getByText(/Out-of-corpus fallback/i).first()).toBeVisible() + await expect(page.getByText(/similarity 0\.42/).first()).toBeVisible() + }) + test('Routing tab renders embedding-cache stats and similarity histogram', async ({ page }) => { await page.goto('/app/middleware') await page.getByRole('button', { name: /Routing/i }).click() diff --git a/core/http/react-ui/src/pages/Middleware.jsx b/core/http/react-ui/src/pages/Middleware.jsx index 027967f694d8..0fb10ae61903 100644 --- a/core/http/react-ui/src/pages/Middleware.jsx +++ b/core/http/react-ui/src/pages/Middleware.jsx @@ -512,7 +512,9 @@ function DecisionDetail({ d }) {
{d.cached ? 'Cached decision — per-label scores not recorded (the cache stores only the resulting label set).' - : 'No per-label scores recorded for this decision (likely a fallback row).'} + : d.nearest_similarity + ? `Out-of-corpus fallback — the nearest labelled corpus entry was at similarity ${d.nearest_similarity.toFixed(2)}, below the router's gate. Seed exemplars near this kind of prompt to route it.` + : 'No per-label scores recorded for this decision (likely a fallback row).'}
) } @@ -611,7 +613,7 @@ function RoutingTab({ status, decisions }) { Model Classifier Candidates - Embedding cache + Cache / corpus Fallback @@ -632,7 +634,7 @@ function RoutingTab({ status, decisions }) { ))} - + {m.knn ? : } {m.fallback || '—'} @@ -1104,6 +1106,47 @@ function EventsTab({ events }) { // for configured caches, shows hit/miss/near-miss counters plus a // similarity histogram with a marker at the configured threshold so // admins can tell at a glance whether the threshold is well-placed. +// RouterKNNCell summarises a knn router's corpus for the Active +// routers table: embedding model, corpus size, per-label exemplar +// counts, and the epistemic-gate threshold. Counts only — corpus +// texts never reach the UI (the status endpoint doesn't send them, +// by design; seeding/curation is API-only). +function RouterKNNCell({ knn }) { + if (!knn) { + return + } + const corpus = knn.corpus || {} + const total = corpus.total || 0 + const counts = corpus.label_counts || {} + const k = knn.k || 3 + const sim = knn.similarity_threshold || 0.80 + return ( +
+
{knn.embedding_model}
+
+ {total === 0 ? ( + + empty corpus — seed via API + + ) : ( + + {total} exemplars · k={k} · sim ≥ {sim} + + )} +
+ {total > 0 && ( +
+ {Object.entries(counts).sort((a, b) => b[1] - a[1] || a[0].localeCompare(b[0])).map(([label, n]) => ( + + {label}: {n} + + ))} +
+ )} +
+ ) +} + function RouterCacheCell({ cache }) { if (!cache) { return diff --git a/core/http/routes/anthropic.go b/core/http/routes/anthropic.go index 288aa6f57b1a..4565a4e7815a 100644 --- a/core/http/routes/anthropic.go +++ b/core/http/routes/anthropic.go @@ -57,6 +57,7 @@ func RegisterAnthropicRoutes(app *echo.Echo, router.SourceAnthropic, middleware.ClassifierDeps{ Scorer: application.Scorer, + Corpus: application.RouterCorpus(), TokenCounter: application.TokenCounter, Embedder: application.Embedder, VectorStore: application.VectorStore, diff --git a/core/http/routes/middleware.go b/core/http/routes/middleware.go index 6d130863ad5e..b22c7874c4db 100644 --- a/core/http/routes/middleware.go +++ b/core/http/routes/middleware.go @@ -136,6 +136,7 @@ func RegisterMiddlewareRoutes(e *echo.Echo, app *application.Application) { app.ApplicationConfig(), middleware.ClassifierDeps{ Scorer: app.Scorer, + Corpus: app.RouterCorpus(), TokenCounter: app.TokenCounter, Embedder: app.Embedder, VectorStore: app.VectorStore, @@ -155,6 +156,35 @@ func RegisterMiddlewareRoutes(e *echo.Echo, app *application.Application) { } return decideHandler(c) }) + + // Router KNN corpus management. Corpus input/curation is API-only + // by design — entries can contain example user content, so the UI + // never sends or renders them; the stats endpoint returns label + // counts only. Admin-gated like /api/router/decide: seeding the + // corpus changes routing behaviour and Add runs embedding + // inference on arbitrary input. 
+ corpusDeps := middleware.ClassifierDeps{ + Embedder: app.Embedder, + VectorStore: app.VectorStore, + } + corpusAdd := localai.RouterCorpusAddEndpoint(app.ModelConfigLoader(), app.ApplicationConfig(), app.RouterCorpus(), corpusDeps) + corpusStats := localai.RouterCorpusStatsEndpoint(app.ModelConfigLoader(), app.ApplicationConfig(), app.RouterCorpus()) + corpusClear := localai.RouterCorpusClearEndpoint(app.ModelConfigLoader(), app.ApplicationConfig(), app.RouterCorpus(), corpusDeps) + adminGate := func(h echo.HandlerFunc) echo.HandlerFunc { + return func(c echo.Context) error { + viewer := resolveUsageUser(c, app) + if viewer == nil { + return c.JSON(http.StatusUnauthorized, map[string]string{"error": "not authenticated"}) + } + if viewer.Role != auth.RoleAdmin { + return c.JSON(http.StatusForbidden, map[string]string{"error": "admin access required"}) + } + return h(c) + } + } + e.POST("/api/router/:name/corpus", adminGate(corpusAdd)) + e.GET("/api/router/:name/corpus/stats", adminGate(corpusStats)) + e.DELETE("/api/router/:name/corpus", adminGate(corpusClear)) } // buildRouterStatus inventories every model that declares a Router @@ -210,6 +240,28 @@ func buildRouterStatus(app *application.Application) map[string]any { } entry["embedding_cache"] = cacheEntry } + if kc := cfg.Router.KNN; kc != nil { + storeName := kc.StoreName + if storeName == "" { + storeName = "router-corpus-" + cfg.Name + } + knnEntry := map[string]any{ + "embedding_model": kc.EmbeddingModel, + "k": kc.K, + "similarity_threshold": kc.SimilarityThreshold, + "vote_threshold": kc.VoteThreshold, + "store_name": storeName, + } + // Corpus stats are counts only — entry texts never reach + // the UI (or any API surface). 
+ if s, err := app.RouterCorpus().Stats(storeName); err == nil { + knnEntry["corpus"] = map[string]any{ + "total": s.Total, + "label_counts": s.LabelCounts, + } + } + entry["knn"] = knnEntry + } models = append(models, entry) } @@ -224,7 +276,7 @@ func buildRouterStatus(app *application.Application) map[string]any { "configured": hasAny, "models": models, "recent_decision_count": recentCount, - "available_classifiers": []string{router.ClassifierScore}, + "available_classifiers": []string{router.ClassifierScore, router.ClassifierColbert, router.ClassifierKNN}, } if !hasAny { out["note"] = "No router models configured. Add a `router:` block to a model YAML to enable intelligent routing." diff --git a/core/http/routes/openai.go b/core/http/routes/openai.go index 32603f5675dc..c0db390044a3 100644 --- a/core/http/routes/openai.go +++ b/core/http/routes/openai.go @@ -73,6 +73,7 @@ func RegisterOpenAIRoutes(app *echo.Echo, router.SourceChat, middleware.ClassifierDeps{ Scorer: application.Scorer, + Corpus: application.RouterCorpus(), TokenCounter: application.TokenCounter, Embedder: application.Embedder, VectorStore: application.VectorStore, diff --git a/core/schema/localai.go b/core/schema/localai.go index 41b513ce996d..2d89b061c50e 100644 --- a/core/schema/localai.go +++ b/core/schema/localai.go @@ -33,31 +33,31 @@ type GalleryResponse struct { type VideoRequest struct { BasicModelRequest - Prompt string `json:"prompt" yaml:"prompt"` // text description of the video to generate - NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"` // things to avoid in the output - StartImage string `json:"start_image" yaml:"start_image"` // URL or base64 of the first frame - EndImage string `json:"end_image" yaml:"end_image"` // URL or base64 of the last frame - Width int32 `json:"width" yaml:"width"` // output width in pixels - Height int32 `json:"height" yaml:"height"` // output height in pixels - NumFrames int32 `json:"num_frames" yaml:"num_frames"` // total number 
of frames to generate - FPS int32 `json:"fps" yaml:"fps"` // frames per second - Seconds string `json:"seconds,omitempty" yaml:"seconds,omitempty"` // duration in seconds (alternative to num_frames) - Size string `json:"size,omitempty" yaml:"size,omitempty"` // WxH shorthand (e.g. "512x512") + Prompt string `json:"prompt" yaml:"prompt"` // text description of the video to generate + NegativePrompt string `json:"negative_prompt" yaml:"negative_prompt"` // things to avoid in the output + StartImage string `json:"start_image" yaml:"start_image"` // URL or base64 of the first frame + EndImage string `json:"end_image" yaml:"end_image"` // URL or base64 of the last frame + Width int32 `json:"width" yaml:"width"` // output width in pixels + Height int32 `json:"height" yaml:"height"` // output height in pixels + NumFrames int32 `json:"num_frames" yaml:"num_frames"` // total number of frames to generate + FPS int32 `json:"fps" yaml:"fps"` // frames per second + Seconds string `json:"seconds,omitempty" yaml:"seconds,omitempty"` // duration in seconds (alternative to num_frames) + Size string `json:"size,omitempty" yaml:"size,omitempty"` // WxH shorthand (e.g. 
"512x512") InputReference string `json:"input_reference,omitempty" yaml:"input_reference,omitempty"` // reference image or video URL - Seed int32 `json:"seed" yaml:"seed"` // random seed for reproducibility - CFGScale float32 `json:"cfg_scale" yaml:"cfg_scale"` // classifier-free guidance scale - Step int32 `json:"step" yaml:"step"` // number of diffusion steps - ResponseFormat string `json:"response_format" yaml:"response_format"` // output format (url or b64_json) + Seed int32 `json:"seed" yaml:"seed"` // random seed for reproducibility + CFGScale float32 `json:"cfg_scale" yaml:"cfg_scale"` // classifier-free guidance scale + Step int32 `json:"step" yaml:"step"` // number of diffusion steps + ResponseFormat string `json:"response_format" yaml:"response_format"` // output format (url or b64_json) } // @Description TTS request body type TTSRequest struct { BasicModelRequest - Input string `json:"input" yaml:"input"` // text input - Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id - Backend string `json:"backend" yaml:"backend"` // backend engine override - Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model - Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format + Input string `json:"input" yaml:"input"` // text input + Voice string `json:"voice" yaml:"voice"` // voice audio file or speaker id + Backend string `json:"backend" yaml:"backend"` // backend engine override + Language string `json:"language,omitempty" yaml:"language,omitempty"` // (optional) language to use with TTS model + Format string `json:"response_format,omitempty" yaml:"response_format,omitempty"` // (optional) output format Stream bool `json:"stream,omitempty" yaml:"stream,omitempty"` // (optional) enable streaming TTS SampleRate int `json:"sample_rate,omitempty" yaml:"sample_rate,omitempty"` // (optional) desired output sample rate // Instructions is a 
free-form, per-request style/voice description. It maps to @@ -161,10 +161,10 @@ type SystemInformationResponse struct { type DetectionRequest struct { BasicModelRequest Image string `json:"image"` // URL or base64-encoded image to analyze - Prompt string `json:"prompt,omitempty"` // Text prompt (for SAM 3 PCS mode) - Points []float32 `json:"points,omitempty"` // Point coordinates as [x,y,label,...] triples (label: 1=pos, 0=neg) - Boxes []float32 `json:"boxes,omitempty"` // Box coordinates as [x1,y1,x2,y2,...] quads - Threshold float32 `json:"threshold,omitempty"` // Detection confidence threshold + Prompt string `json:"prompt,omitempty"` // Text prompt (for SAM 3 PCS mode) + Points []float32 `json:"points,omitempty"` // Point coordinates as [x,y,label,...] triples (label: 1=pos, 0=neg) + Boxes []float32 `json:"boxes,omitempty"` // Box coordinates as [x1,y1,x2,y2,...] quads + Threshold float32 `json:"threshold,omitempty"` // Detection confidence threshold } type DetectionResponse struct { @@ -236,14 +236,14 @@ type FaceVerifyRequest struct { } type FaceVerifyResponse struct { - Verified bool `json:"verified"` - Distance float32 `json:"distance"` - Threshold float32 `json:"threshold"` - Confidence float32 `json:"confidence"` - Model string `json:"model"` - Img1Area FacialArea `json:"img1_area"` - Img2Area FacialArea `json:"img2_area"` - ProcessingTimeMs float32 `json:"processing_time_ms,omitempty"` + Verified bool `json:"verified"` + Distance float32 `json:"distance"` + Threshold float32 `json:"threshold"` + Confidence float32 `json:"confidence"` + Model string `json:"model"` + Img1Area FacialArea `json:"img1_area"` + Img2Area FacialArea `json:"img2_area"` + ProcessingTimeMs float32 `json:"processing_time_ms,omitempty"` // Liveness fields are only populated when the request set // anti_spoofing=true. 
Pointers keep them fully absent from the // JSON response otherwise, so callers can tell "not checked" @@ -524,6 +524,64 @@ type RouterDecideResponse struct { // CacheSimilarity carries the cosine similarity of the cache hit // (0 when not cached). CacheSimilarity float64 `json:"cache_similarity,omitempty"` + // NearestSimilarity is the cosine similarity of the closest KNN + // corpus entry — populated by the knn classifier even when the + // decision fell back because the probe was out of corpus range. + // 0 for other classifiers. + NearestSimilarity float64 `json:"nearest_similarity,omitempty"` +} + +// RouterCorpusEntry is one labelled exemplar submitted to +// POST /api/router/{name}/corpus. The text is embedded server-side +// with the router's knn.embedding_model; labels must be declared in +// the router's policies. +type RouterCorpusEntry struct { + Text string `json:"text"` + Labels []string `json:"labels"` +} + +// RouterCorpusAddRequest is the input for POST /api/router/{name}/corpus — +// bulk-seeds the KNN routing corpus. Corpus input is API-only by +// design: entries may contain example user content, so they are never +// entered through (or displayed in) the UI. +type RouterCorpusAddRequest struct { + Entries []RouterCorpusEntry `json:"entries"` +} + +// RouterCorpusAddResponse reports the outcome of a corpus seed call. +type RouterCorpusAddResponse struct { + Router string `json:"router"` + // Added is how many entries were embedded, persisted, and indexed. + Added int `json:"added"` + // Skipped counts entries whose text was already in the corpus — + // duplicates are rejected rather than double-weighted. + Skipped int `json:"skipped"` + // Total is the corpus size after the call. + Total int `json:"total"` + // LabelCounts is the per-label exemplar count after the call. 
+ LabelCounts map[string]int `json:"label_counts"` +} + +// RouterCorpusStatsResponse is the inspection surface for a router's +// KNN corpus: counts and configuration only — entry texts are never +// returned by any endpoint. +type RouterCorpusStatsResponse struct { + Router string `json:"router"` + StoreName string `json:"store_name"` + EmbeddingModel string `json:"embedding_model"` + Total int `json:"total"` + LabelCounts map[string]int `json:"label_counts"` + // EmbeddingModels lists the embedder fingerprints present in the + // persisted corpus; more than one means part of the corpus is + // pending re-embedding on the next load. + EmbeddingModels []string `json:"embedding_models,omitempty"` +} + +// RouterCorpusClearResponse reports how many entries a +// DELETE /api/router/{name}/corpus removed. +type RouterCorpusClearResponse struct { + Router string `json:"router"` + Cleared int `json:"cleared"` } // PIIDecideRequest is the input for POST /api/pii/decide — the diff --git a/core/services/routing/corpus/corpus_suite_test.go b/core/services/routing/corpus/corpus_suite_test.go new file mode 100644 index 000000000000..36e2abea42ba --- /dev/null +++ b/core/services/routing/corpus/corpus_suite_test.go @@ -0,0 +1,13 @@ +package corpus_test + +import ( + "testing" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +func TestCorpus(t *testing.T) { + RegisterFailHandler(Fail) + RunSpecs(t, "Router corpus manager suite") +} diff --git a/core/services/routing/corpus/manager.go b/core/services/routing/corpus/manager.go new file mode 100644 index 000000000000..e61d8fd84bd8 --- /dev/null +++ b/core/services/routing/corpus/manager.go @@ -0,0 +1,375 @@ +// Package corpus persists and serves the labelled exemplar corpora +// behind KNN routing (router classifier "knn"). +// +// The corpus FILE is the source of truth: one JSONL file per store +// name under , each line an Entry {text, labels, vector, +// embedding_model}. 
The local-store vector backend is a pure in-memory +// index rebuilt from the file — it has no persistence of its own (its +// Load is an explicit no-op), and keeping it that way preserves the +// documented swap point for external vector backends. Vectors are +// cached in the file alongside the text so a restart re-indexes +// without re-embedding; entries recorded under a different embedding +// model are re-embedded on load. +// +// Texts in the corpus file never leave the server: Stats exposes label +// counts only, and there is deliberately no API that returns entries. +package corpus + +import ( + "bufio" + "context" + "encoding/json" + "fmt" + "os" + "path/filepath" + "sort" + "strings" + "sync" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/services/routing/router" + "github.com/mudler/xlog" +) + +// Entry is one labelled exemplar. Vector and EmbeddingModel are the +// embedding cache — absent (or stale) entries are re-embedded when the +// corpus is loaded or added to. +type Entry struct { + Text string `json:"text"` + Labels []string `json:"labels"` + Vector []float32 `json:"vector,omitempty"` + EmbeddingModel string `json:"embedding_model,omitempty"` +} + +// Stats is the inspection surface — counts only, never texts. The +// router corpus API and the Routing tab render this. +type Stats struct { + StoreName string `json:"store_name"` + Total int `json:"total"` + LabelCounts map[string]int `json:"label_counts"` + // EmbeddingModels lists the distinct embedder fingerprints present + // in the file. More than one entry means a re-embed is pending for + // part of the corpus (it happens lazily on the next load). + EmbeddingModels []string `json:"embedding_models,omitempty"` +} + +// batchIndex and deleter are optional fast paths the local-store +// implementation provides; the Manager degrades to per-entry Insert +// (and to "entries persist in the index until restart" on Clear) when +// a store doesn't. 
+type batchIndex interface { + InsertBatch(ctx context.Context, vecs [][]float32, payloads [][]byte) error +} + +type deleter interface { + Delete(ctx context.Context, vecs [][]float32) error +} + +// Manager owns the corpus files and their sync state into the +// in-memory vector index. One per process, shared by the classifier +// build path (EnsureLoaded) and the corpus API (Add/Stats/Clear). +type Manager struct { + dir string + + mu sync.Mutex + // loadedModel records which embedding model a store name was synced + // into the live index under. Guards double-loading and detects the + // embedding-model-changed-on-a-live-index case, which local-store + // cannot serve (it enforces one key length per store). + loadedModel map[string]string +} + +// NewManager roots corpus files at dir (created lazily on first +// write). +func NewManager(dir string) *Manager { + return &Manager{dir: dir, loadedModel: map[string]string{}} +} + +// path maps a store name to its corpus file. Store names come from +// YAML (store_name) or model names; sanitise so they can't escape the +// corpus dir or collide with path syntax. +func (m *Manager) path(storeName string) string { + safe := strings.Map(func(r rune) rune { + switch { + case r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z', r >= '0' && r <= '9', r == '-', r == '_', r == '.': + return r + } + return '_' + }, storeName) + return filepath.Join(m.dir, safe+".jsonl") +} + +// EnsureLoaded syncs the persisted corpus for storeName into the live +// vector index, once per (store, embedding model) per process. Entries +// recorded under a different embedding model are re-embedded and the +// file rewritten. Returns the number of entries now indexed. A missing +// file is an empty corpus, not an error. 
+func (m *Manager) EnsureLoaded(ctx context.Context, storeName, embeddingModel string, embedder backend.Embedder, store backend.VectorStore) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + if prev, ok := m.loadedModel[storeName]; ok { + if prev == embeddingModel { + return 0, nil + } + // The in-memory index already holds vectors from another + // embedder; local-store enforces a single key length, so mixing + // would fail on insert (or silently corrupt neighbourhoods when + // dimensions happen to match). A restart re-indexes cleanly. + return 0, fmt.Errorf("corpus %q was indexed with embedding model %q this process; restart LocalAI to re-index it with %q", storeName, prev, embeddingModel) + } + + entries, err := m.read(storeName) + if err != nil { + return 0, err + } + if len(entries) == 0 { + m.loadedModel[storeName] = embeddingModel + return 0, nil + } + + dirty := false + for i := range entries { + if len(entries[i].Vector) > 0 && entries[i].EmbeddingModel == embeddingModel { + continue + } + if embedder == nil { + return 0, fmt.Errorf("corpus %q has entries needing (re-)embedding but no embedder is available", storeName) + } + vec, err := embedder.Embed(ctx, entries[i].Text) + if err != nil { + return 0, fmt.Errorf("corpus %q: re-embedding entry %d: %w", storeName, i, err) + } + entries[i].Vector = vec + entries[i].EmbeddingModel = embeddingModel + dirty = true + } + if dirty { + if err := m.write(storeName, entries); err != nil { + return 0, err + } + } + + if err := insertAll(ctx, store, entries); err != nil { + return 0, err + } + m.loadedModel[storeName] = embeddingModel + return len(entries), nil +} + +// Add validates, embeds, persists, and indexes new exemplars. Entries +// whose text is already in the corpus are skipped (an exemplar's +// labels are corrected via Clear + reseed, not silent overwrite). +// Returns (added, skipped). 
The file write happens before the index +// insert: if indexing fails the entries are still durable and the next +// EnsureLoaded (or restart) syncs them. +func (m *Manager) Add(ctx context.Context, storeName, embeddingModel string, embedder backend.Embedder, store backend.VectorStore, entries []Entry) (int, int, error) { + if embedder == nil { + return 0, 0, fmt.Errorf("corpus %q: no embedder available", storeName) + } + for i, e := range entries { + if strings.TrimSpace(e.Text) == "" { + return 0, 0, fmt.Errorf("corpus entry %d: empty text", i) + } + if len(e.Labels) == 0 { + return 0, 0, fmt.Errorf("corpus entry %d: at least one label required", i) + } + } + + m.mu.Lock() + defer m.mu.Unlock() + + existing, err := m.read(storeName) + if err != nil { + return 0, 0, err + } + seen := make(map[string]struct{}, len(existing)) + for _, e := range existing { + seen[e.Text] = struct{}{} + } + + added := make([]Entry, 0, len(entries)) + skipped := 0 + for _, e := range entries { + if _, dup := seen[e.Text]; dup { + skipped++ + continue + } + seen[e.Text] = struct{}{} + vec, err := embedder.Embed(ctx, e.Text) + if err != nil { + return 0, skipped, fmt.Errorf("corpus %q: embedding %q-labelled entry: %w", storeName, e.Labels[0], err) + } + e.Vector = vec + e.EmbeddingModel = embeddingModel + added = append(added, e) + } + if len(added) == 0 { + return 0, skipped, nil + } + + if err := m.write(storeName, append(existing, added...)); err != nil { + return 0, skipped, err + } + if store != nil { + if err := insertAll(ctx, store, added); err != nil { + // Durable but not indexed — routing won't see the new + // entries until the next successful load. Surface loudly. + return len(added), skipped, fmt.Errorf("corpus %q: entries persisted but indexing failed (they will index on next load/restart): %w", storeName, err) + } + } + return len(added), skipped, nil +} + +// Stats reports label counts for the persisted corpus. Never returns +// entry texts. 
+func (m *Manager) Stats(storeName string) (Stats, error) { + m.mu.Lock() + defer m.mu.Unlock() + + entries, err := m.read(storeName) + if err != nil { + return Stats{}, err + } + s := Stats{StoreName: storeName, Total: len(entries), LabelCounts: map[string]int{}} + models := map[string]struct{}{} + for _, e := range entries { + for _, l := range e.Labels { + s.LabelCounts[l]++ + } + if e.EmbeddingModel != "" { + models[e.EmbeddingModel] = struct{}{} + } + } + for mn := range models { + s.EmbeddingModels = append(s.EmbeddingModels, mn) + } + sort.Strings(s.EmbeddingModels) + return s, nil +} + +// Clear deletes the corpus file and removes its vectors from the live +// index when the store supports deletion. When it doesn't, the index +// keeps serving stale entries until restart — the returned count and +// the warning log make that visible rather than silent. +func (m *Manager) Clear(ctx context.Context, storeName string, store backend.VectorStore) (int, error) { + m.mu.Lock() + defer m.mu.Unlock() + + entries, err := m.read(storeName) + if err != nil { + return 0, err + } + if err := os.Remove(m.path(storeName)); err != nil && !os.IsNotExist(err) { + return 0, err + } + delete(m.loadedModel, storeName) + if len(entries) == 0 || store == nil { + return len(entries), nil + } + if d, ok := store.(deleter); ok { + vecs := make([][]float32, 0, len(entries)) + for _, e := range entries { + if len(e.Vector) > 0 { + vecs = append(vecs, e.Vector) + } + } + if len(vecs) > 0 { + if err := d.Delete(ctx, vecs); err != nil { + return len(entries), fmt.Errorf("corpus %q: file cleared but live index deletion failed (stale entries served until restart): %w", storeName, err) + } + } + } else { + xlog.Warn("corpus: vector store does not support deletion; cleared corpus stays in the live index until restart", + "store", storeName, "entries", len(entries)) + } + return len(entries), nil +} + +func insertAll(ctx context.Context, store backend.VectorStore, entries []Entry) error { + 
vecs := make([][]float32, 0, len(entries)) + payloads := make([][]byte, 0, len(entries)) + for _, e := range entries { + payload, err := router.EncodeCorpusEntry(e.Labels) + if err != nil { + return err + } + vecs = append(vecs, e.Vector) + payloads = append(payloads, payload) + } + if b, ok := store.(batchIndex); ok { + return b.InsertBatch(ctx, vecs, payloads) + } + for i := range vecs { + if err := store.Insert(ctx, vecs[i], payloads[i]); err != nil { + return err + } + } + return nil +} + +// read returns the persisted entries for storeName; a missing file is +// an empty corpus. Callers hold m.mu. +func (m *Manager) read(storeName string) ([]Entry, error) { + f, err := os.Open(m.path(storeName)) + if os.IsNotExist(err) { + return nil, nil + } + if err != nil { + return nil, err + } + defer func() { _ = f.Close() }() + var entries []Entry + sc := bufio.NewScanner(f) + // Vectors inline in JSON push line length well past the default + // 64KiB scanner cap (a 4096-dim float32 vector is ~50KiB of JSON + // alone). 16MiB bounds any realistic embedding width. + sc.Buffer(make([]byte, 0, 1<<20), 16<<20) + line := 0 + for sc.Scan() { + line++ + raw := strings.TrimSpace(sc.Text()) + if raw == "" { + continue + } + var e Entry + if err := json.Unmarshal([]byte(raw), &e); err != nil { + return nil, fmt.Errorf("corpus file %s line %d: %w", m.path(storeName), line, err) + } + entries = append(entries, e) + } + if err := sc.Err(); err != nil { + return nil, err + } + return entries, nil +} + +// write atomically replaces the corpus file (tmp + rename) so a crash +// mid-write can't truncate the corpus. Callers hold m.mu. 
+func (m *Manager) write(storeName string, entries []Entry) error { + if err := os.MkdirAll(m.dir, 0o750); err != nil { + return err + } + target := m.path(storeName) + tmp, err := os.CreateTemp(m.dir, filepath.Base(target)+".tmp-*") + if err != nil { + return err + } + defer func() { _ = os.Remove(tmp.Name()) }() + w := bufio.NewWriter(tmp) + enc := json.NewEncoder(w) + for _, e := range entries { + if err := enc.Encode(e); err != nil { + _ = tmp.Close() + return err + } + } + if err := w.Flush(); err != nil { + _ = tmp.Close() + return err + } + if err := tmp.Close(); err != nil { + return err + } + return os.Rename(tmp.Name(), target) +} diff --git a/core/services/routing/corpus/manager_test.go b/core/services/routing/corpus/manager_test.go new file mode 100644 index 000000000000..9c69581bd65b --- /dev/null +++ b/core/services/routing/corpus/manager_test.go @@ -0,0 +1,240 @@ +package corpus_test + +import ( + "context" + "os" + "path/filepath" + "strings" + "sync" + + "github.com/mudler/LocalAI/core/backend" + "github.com/mudler/LocalAI/core/services/routing/corpus" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +// countingEmbedder returns a deterministic vector per (model, text) +// and counts calls, so specs can assert when re-embedding happened vs +// when the cached vectors were reused. +type countingEmbedder struct { + mu sync.Mutex + model float32 // baked into the vector so specs can tell models apart + calls int +} + +func (e *countingEmbedder) Embed(_ context.Context, text string) ([]float32, error) { + e.mu.Lock() + defer e.mu.Unlock() + e.calls += 1 + return []float32{float32(len(text)), e.model}, nil +} + +// capturingStore records index mutations. Search/SearchK are +// irrelevant to the manager and return clean misses. 
+type capturingStore struct { + mu sync.Mutex + payloads [][]byte + batches int + deleted [][]float32 +} + +func (s *capturingStore) Search(_ context.Context, _ []float32) (float64, []byte, bool, error) { + return 0, nil, false, nil +} + +func (s *capturingStore) SearchK(_ context.Context, _ []float32, _ int) ([]backend.Neighbor, error) { + return nil, nil +} + +func (s *capturingStore) Insert(_ context.Context, _ []float32, payload []byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.payloads = append(s.payloads, payload) + return nil +} + +func (s *capturingStore) InsertBatch(_ context.Context, vecs [][]float32, payloads [][]byte) error { + s.mu.Lock() + defer s.mu.Unlock() + s.batches++ + s.payloads = append(s.payloads, payloads...) + _ = vecs + return nil +} + +func (s *capturingStore) Delete(_ context.Context, vecs [][]float32) error { + s.mu.Lock() + defer s.mu.Unlock() + s.deleted = append(s.deleted, vecs...) + return nil +} + +var _ = Describe("corpus.Manager", func() { + var ( + dir string + mgr *corpus.Manager + embedder *countingEmbedder + store *capturingStore + ctx context.Context + ) + + const storeName = "router-corpus-smart-router" + + seed := []corpus.Entry{ + {Text: "debug this Go null pointer", Labels: []string{"code-generation"}}, + {Text: "what is 12 * 42?", Labels: []string{"math-reasoning"}}, + {Text: "refactor this and explain the math", Labels: []string{"code-generation", "math-reasoning"}}, + } + + BeforeEach(func() { + d, err := os.MkdirTemp("", "corpus-test-*") + Expect(err).NotTo(HaveOccurred()) + dir = d + mgr = corpus.NewManager(dir) + embedder = &countingEmbedder{model: 1} + store = &capturingStore{} + ctx = context.Background() + }) + + AfterEach(func() { + _ = os.RemoveAll(dir) + }) + + It("adds entries: embeds, persists, and indexes them", func() { + added, skipped, err := mgr.Add(ctx, storeName, "embed-1", embedder, store, seed) + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(Equal(3)) + Expect(skipped).To(BeZero()) + 
Expect(embedder.calls).To(Equal(3)) + Expect(store.payloads).To(HaveLen(3)) + Expect(store.batches).To(Equal(1), "should use the batch fast path") + + // Payloads are the label sets the classifier votes over. + Expect(string(store.payloads[0])).To(ContainSubstring("code-generation")) + + // Persisted on disk under a sanitised name. + _, err = os.Stat(filepath.Join(dir, storeName+".jsonl")) + Expect(err).NotTo(HaveOccurred()) + }) + + It("skips duplicate texts instead of double-weighting them", func() { + _, _, err := mgr.Add(ctx, storeName, "embed-1", embedder, store, seed) + Expect(err).NotTo(HaveOccurred()) + added, skipped, err := mgr.Add(ctx, storeName, "embed-1", embedder, store, seed[:2]) + Expect(err).NotTo(HaveOccurred()) + Expect(added).To(BeZero()) + Expect(skipped).To(Equal(2)) + }) + + It("rejects empty text and label-less entries", func() { + _, _, err := mgr.Add(ctx, storeName, "embed-1", embedder, store, []corpus.Entry{{Text: " ", Labels: []string{"x"}}}) + Expect(err).To(HaveOccurred()) + _, _, err = mgr.Add(ctx, storeName, "embed-1", embedder, store, []corpus.Entry{{Text: "hello", Labels: nil}}) + Expect(err).To(HaveOccurred()) + }) + + It("reloads a persisted corpus into a fresh index without re-embedding", func() { + _, _, err := mgr.Add(ctx, storeName, "embed-1", embedder, store, seed) + Expect(err).NotTo(HaveOccurred()) + + // Simulate restart: fresh manager over the same dir, empty index. + mgr2 := corpus.NewManager(dir) + store2 := &capturingStore{} + embedder2 := &countingEmbedder{model: 1} + n, err := mgr2.EnsureLoaded(ctx, storeName, "embed-1", embedder2, store2) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(store2.payloads).To(HaveLen(3)) + Expect(embedder2.calls).To(BeZero(), "cached vectors must be reused") + + // Second call is a no-op — already synced this process. 
+ n, err = mgr2.EnsureLoaded(ctx, storeName, "embed-1", embedder2, store2) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(BeZero()) + }) + + It("re-embeds entries recorded under a different embedding model", func() { + _, _, err := mgr.Add(ctx, storeName, "embed-1", embedder, store, seed) + Expect(err).NotTo(HaveOccurred()) + + mgr2 := corpus.NewManager(dir) + newEmbedder := &countingEmbedder{model: 2} + store2 := &capturingStore{} + n, err := mgr2.EnsureLoaded(ctx, storeName, "embed-2", newEmbedder, store2) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(newEmbedder.calls).To(Equal(3), "fingerprint mismatch must re-embed") + + // The rewrite is durable: a third manager loads under embed-2 + // without touching the embedder again. + mgr3 := corpus.NewManager(dir) + embedder3 := &countingEmbedder{model: 2} + n, err = mgr3.EnsureLoaded(ctx, storeName, "embed-2", embedder3, &capturingStore{}) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(embedder3.calls).To(BeZero()) + }) + + It("refuses to mix embedding models in a live index", func() { + _, _, err := mgr.Add(ctx, storeName, "embed-1", embedder, store, seed) + Expect(err).NotTo(HaveOccurred()) + _, err = mgr.EnsureLoaded(ctx, storeName, "embed-1", embedder, store) + Expect(err).NotTo(HaveOccurred()) + + _, err = mgr.EnsureLoaded(ctx, storeName, "embed-2", &countingEmbedder{model: 2}, store) + Expect(err).To(HaveOccurred()) + Expect(err.Error()).To(ContainSubstring("restart")) + }) + + It("reports label counts and never texts", func() { + _, _, err := mgr.Add(ctx, storeName, "embed-1", embedder, store, seed) + Expect(err).NotTo(HaveOccurred()) + st, err := mgr.Stats(storeName) + Expect(err).NotTo(HaveOccurred()) + Expect(st.Total).To(Equal(3)) + Expect(st.LabelCounts).To(Equal(map[string]int{ + "code-generation": 2, + "math-reasoning": 2, + })) + Expect(st.EmbeddingModels).To(Equal([]string{"embed-1"})) + }) + + It("clears the file and the live index", func() { + _, _, 
err := mgr.Add(ctx, storeName, "embed-1", embedder, store, seed) + Expect(err).NotTo(HaveOccurred()) + + n, err := mgr.Clear(ctx, storeName, store) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(Equal(3)) + Expect(store.deleted).To(HaveLen(3)) + + st, err := mgr.Stats(storeName) + Expect(err).NotTo(HaveOccurred()) + Expect(st.Total).To(BeZero()) + + // And a load after clear indexes nothing. + loaded, err := corpus.NewManager(dir).EnsureLoaded(ctx, storeName, "embed-1", embedder, &capturingStore{}) + Expect(err).NotTo(HaveOccurred()) + Expect(loaded).To(BeZero()) + }) + + It("treats a missing file as an empty corpus", func() { + n, err := mgr.EnsureLoaded(ctx, "never-seeded", "embed-1", embedder, store) + Expect(err).NotTo(HaveOccurred()) + Expect(n).To(BeZero()) + st, err := mgr.Stats("never-seeded") + Expect(err).NotTo(HaveOccurred()) + Expect(st.Total).To(BeZero()) + }) + + It("sanitises hostile store names into the corpus dir", func() { + hostile := "../../etc/passwd" + _, _, err := mgr.Add(ctx, hostile, "embed-1", embedder, store, seed[:1]) + Expect(err).NotTo(HaveOccurred()) + entries, err := os.ReadDir(dir) + Expect(err).NotTo(HaveOccurred()) + Expect(entries).To(HaveLen(1)) + Expect(strings.Contains(entries[0].Name(), "/")).To(BeFalse()) + }) +}) diff --git a/core/services/routing/router/decisions.go b/core/services/routing/router/decisions.go index d446ac29a63e..23f512813720 100644 --- a/core/services/routing/router/decisions.go +++ b/core/services/routing/router/decisions.go @@ -11,18 +11,19 @@ import ( // Prompt is NEVER stored — admins audit by Hash if they need to // dedupe recurring routing patterns. type DecisionRecord struct { - ID string `json:"id"` - CorrelationID string `json:"correlation_id"` - UserID string `json:"user_id"` - RouterModel string `json:"router_model"` // The smart-router model name the client asked for. - RequestedModel string `json:"requested_model"`// Same as RouterModel for now; reserved for chained routers. 
- ServedModel string `json:"served_model"` // The candidate the classifier picked. - Classifier string `json:"classifier"` // Classifier.Name(), e.g. "score". - Label string `json:"label"` - Score float64 `json:"score"` - LatencyMs int64 `json:"latency_ms"` - Cached bool `json:"cached"` // True when the decision came from the L2 embedding cache. - CacheSimilarity float64 `json:"cache_similarity,omitempty"` // Cosine similarity of the cache hit, 0 when not cached. + ID string `json:"id"` + CorrelationID string `json:"correlation_id"` + UserID string `json:"user_id"` + RouterModel string `json:"router_model"` // The smart-router model name the client asked for. + RequestedModel string `json:"requested_model"` // Same as RouterModel for now; reserved for chained routers. + ServedModel string `json:"served_model"` // The candidate the classifier picked. + Classifier string `json:"classifier"` // Classifier.Name(), e.g. "score". + Label string `json:"label"` + Score float64 `json:"score"` + LatencyMs int64 `json:"latency_ms"` + Cached bool `json:"cached"` // True when the decision came from the L2 embedding cache. + CacheSimilarity float64 `json:"cache_similarity,omitempty"` // Cosine similarity of the cache hit, 0 when not cached. + NearestSimilarity float64 `json:"nearest_similarity,omitempty"` // KNN classifier: similarity of the closest corpus entry, set even on fallback decisions. 0 for other classifiers. // LabelScores carries the full per-label score distribution so the // admin UI can show how close inactive labels got to the activation // threshold. Empty on cache hits (only the final label set is cached). @@ -32,8 +33,8 @@ type DecisionRecord struct { // the admin page can split realtime / chat / anthropic streams. Empty // string is treated as "chat" for backward compatibility with rows // written before the field existed. 
- Source string `json:"source,omitempty"` - CreatedAt time.Time `json:"created_at"` + Source string `json:"source,omitempty"` + CreatedAt time.Time `json:"created_at"` } // Source values for DecisionRecord.Source. Kept as constants so callers diff --git a/core/services/routing/router/embedding_cache_test.go b/core/services/routing/router/embedding_cache_test.go index e36b049c3d87..41be408a00f7 100644 --- a/core/services/routing/router/embedding_cache_test.go +++ b/core/services/routing/router/embedding_cache_test.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/services/routing/router" . "github.com/onsi/ginkgo/v2" @@ -88,6 +89,31 @@ func (s *memVectorStore) Search(_ context.Context, vec []float32) (float64, []by return 0, nil, false, nil } +// SearchK reuses the same synthetic metric as Search: 1.0 exact, 0.8 +// leading-element, 0.0 otherwise — zero-similarity entries are omitted. +func (s *memVectorStore) SearchK(_ context.Context, vec []float32, k int) ([]backend.Neighbor, error) { + s.mu.Lock() + defer s.mu.Unlock() + if s.failOps > 0 { + s.failOps-- + return nil, errors.New("store offline") + } + var exact, close []backend.Neighbor + for _, e := range s.entries { + switch { + case vecEqual(e.vec, vec): + exact = append(exact, backend.Neighbor{Similarity: 1.0, Payload: e.payload}) + case len(vec) > 0 && len(e.vec) > 0 && vec[0] == e.vec[0]: + close = append(close, backend.Neighbor{Similarity: 0.80, Payload: e.payload}) + } + } + neighbors := append(exact, close...) 
+ if len(neighbors) > k { + neighbors = neighbors[:k] + } + return neighbors, nil +} + func (s *memVectorStore) Insert(_ context.Context, vec []float32, payload []byte) error { s.mu.Lock() defer s.mu.Unlock() diff --git a/core/services/routing/router/knn.go b/core/services/routing/router/knn.go new file mode 100644 index 000000000000..ddde30f43679 --- /dev/null +++ b/core/services/routing/router/knn.go @@ -0,0 +1,222 @@ +package router + +import ( + "context" + "encoding/json" + "fmt" + "sort" + "time" + + "github.com/mudler/LocalAI/core/backend" +) + +// KNNClassifier routes by nearest-neighbour vote over a curated, +// labelled corpus of example prompts. It is the first-class form of +// what EmbeddingCacheClassifier does opportunistically: instead of +// caching another classifier's decisions, the corpus is seeded and +// curated explicitly (via the router corpus API), each entry carrying +// the policy labels a matching prompt should activate. +// +// Classify embeds the probe, fetches the K nearest corpus entries, and +// activates every label whose similarity-weighted vote share clears +// VoteThreshold. Neighbours below SimilarityThreshold are discarded +// first — that cutoff is the epistemic gate: a probe dissimilar from +// *all* labelled experience is undecidable by construction, so the +// classifier returns an empty label set and the middleware falls back +// to cfg.Router.Fallback (the assumed-best model) rather than guessing. +// +// The classifier never inserts into the corpus on its own. Routing +// outcomes only become corpus entries through explicit curation — a +// mislabelled exemplar poisons every future neighbourhood around it, +// so growth is an admin decision, not a side effect. 
+type KNNClassifier struct { + embedder backend.Embedder + store backend.VectorStore + k int + similarityThreshold float64 + voteThreshold float64 + + // budget trims the conversation to the embedder model's own context + // before embedding; nil embeds Probe.Prompt as built by the caller. + budget *lazyBudget +} + +// Defaults. K=3 keeps a weighted majority meaningful on small corpora +// while tolerating one mislabelled neighbour; the similarity default +// matches the embedding cache so one threshold intuition serves both. +// VoteThreshold 0.5 means a label activates on a weighted majority — +// with K=1 this degenerates to "the nearest entry's labels", the same +// contract the embedding cache implements. +const ( + defaultKNNK = 3 + defaultKNNSimilarity = 0.80 + defaultKNNVote = 0.5 +) + +// KNNClassifierOptions carries the tunables; zero values pick the +// package defaults above. +type KNNClassifierOptions struct { + K int + SimilarityThreshold float64 + VoteThreshold float64 +} + +// NewKNNClassifier builds a KNN classifier over the given embedder and +// vector store. Panics on nil embedder/store — same fail-fast posture +// as the other classifiers; buildClassifier validates config before +// construction. 
+func NewKNNClassifier(embedder backend.Embedder, store backend.VectorStore, opts KNNClassifierOptions) *KNNClassifier { + if embedder == nil { + panic("router/knn: embedder is required") + } + if store == nil { + panic("router/knn: vector store is required") + } + if opts.K <= 0 { + opts.K = defaultKNNK + } + if opts.SimilarityThreshold <= 0 { + opts.SimilarityThreshold = defaultKNNSimilarity + } + if opts.VoteThreshold <= 0 { + opts.VoteThreshold = defaultKNNVote + } + return &KNNClassifier{ + embedder: embedder, + store: store, + k: opts.K, + similarityThreshold: opts.SimilarityThreshold, + voteThreshold: opts.VoteThreshold, + } +} + +// WithTokenTrim wires the embedder model's own tokenizer and context so +// the probe embeds the most recent turns that fit instead of a +// caller-chosen size. nil tokenizer / non-positive context leaves +// trimming off. Returns the receiver for chaining at construction. +func (c *KNNClassifier) WithTokenTrim(tokenize func(string) (int, error), maxContextTokens int) *KNNClassifier { + c.budget = &lazyBudget{tokenize: tokenize, maxContext: maxContextTokens} + return c +} + +func (c *KNNClassifier) Name() string { return ClassifierKNN } + +func (c *KNNClassifier) Classify(ctx context.Context, p Probe) (Decision, error) { + start := time.Now() + + vec, err := c.embedder.Embed(ctx, trimmedProbeText(p, c.budget, identityRender)) + if err != nil { + return errDecision(start, fmt.Errorf("knn classifier embed: %w", err)) + } + neighbors, err := c.store.SearchK(ctx, vec, c.k) + if err != nil { + return errDecision(start, fmt.Errorf("knn classifier search: %w", err)) + } + + // Epistemic gate: only neighbours the probe is genuinely close to + // may vote. Keeping sub-threshold neighbours out of the vote (rather + // than merely gating on the best one) stops far-away corpus regions + // from diluting a clear local majority. 
+ best := 0.0 + usable := neighbors[:0] + for _, n := range neighbors { + if n.Similarity > best { + best = n.Similarity + } + if n.Similarity >= c.similarityThreshold { + usable = append(usable, n) + } + } + if len(usable) == 0 { + // Out of corpus range — empty label set routes to the fallback + // via MatchCandidate's empty-active-set contract. Surfacing the + // best similarity in the decision log tells the admin whether + // the corpus needs entries near this probe or the threshold is + // simply too tight. + return Decision{ + NearestSimilarity: best, + ActivationThreshold: c.voteThreshold, + Latency: time.Since(start), + }, nil + } + + votes := map[string]float64{} + total := 0.0 + for _, n := range usable { + entry, ok := decodeCorpusEntry(n.Payload) + if !ok { + // A corrupt payload can't vote; it still counted toward K. + continue + } + total += n.Similarity + for _, l := range entry.Labels { + votes[l] += n.Similarity + } + } + if total == 0 { + return Decision{ + NearestSimilarity: best, + ActivationThreshold: c.voteThreshold, + Latency: time.Since(start), + }, nil + } + + // Vote shares in descending order; ties broken lexicographically so + // the decision log is deterministic. 
+ labels := make([]string, 0, len(votes)) + for l := range votes { + labels = append(labels, l) + } + sort.Slice(labels, func(i, j int) bool { + if votes[labels[i]] != votes[labels[j]] { + return votes[labels[i]] > votes[labels[j]] + } + return labels[i] < labels[j] + }) + scores := make([]float64, len(labels)) + active := []string{} + for i, l := range labels { + scores[i] = votes[l] / total + if scores[i] >= c.voteThreshold { + active = append(active, l) + } + } + + d := Decision{ + Labels: active, + ActivationThreshold: c.voteThreshold, + LabelScores: NewLabelScores(labels, scores), + NearestSimilarity: best, + Latency: time.Since(start), + } + if len(active) > 0 { + d.Score = votes[active[0]] / total + } + return d, nil +} + +// corpusEntry is the stored shape of one labelled exemplar. Kept +// deliberately minimal: the vector key lives in the store, the text +// lives only in the corpus file (never returned by inspection APIs), +// so the store payload is just the label set. +type corpusEntry struct { + Labels []string `json:"labels"` +} + +// EncodeCorpusEntry serialises the labels of one corpus exemplar into +// the vector-store payload shape Classify votes over. Exported for the +// corpus loader/API in core, which owns insertion. 
func EncodeCorpusEntry(labels []string) ([]byte, error) {
	// An exemplar without labels could never vote, so refuse to encode it.
	if len(labels) == 0 {
		return nil, fmt.Errorf("corpus entry needs at least one label")
	}
	return json.Marshal(corpusEntry{Labels: labels})
}

// decodeCorpusEntry parses a vector-store payload back into a
// corpusEntry. ok is false when the payload is not valid JSON or
// carries no labels.
func decodeCorpusEntry(b []byte) (corpusEntry, bool) {
	var e corpusEntry
	if err := json.Unmarshal(b, &e); err != nil || len(e.Labels) == 0 {
		return corpusEntry{}, false
	}
	return e, true
}

// ---- core/services/routing/router/knn_test.go (package router_test) ----

// scriptedKNNStore returns a fixed neighbour list from SearchK,
// letting tests exercise the vote/gate math without a real store.
type scriptedKNNStore struct {
	neighbors []backend.Neighbor
	err       error
	lastK     int
}

// SearchK records the requested k in lastK, then returns the scripted
// error if one is set, otherwise at most k of the scripted neighbours.
func (s *scriptedKNNStore) SearchK(_ context.Context, _ []float32, k int) ([]backend.Neighbor, error) {
	s.lastK = k
	if s.err != nil {
		return nil, s.err
	}
	if len(s.neighbors) > k {
		return s.neighbors[:k], s.err
	}
	return s.neighbors, s.err
}

// Search always fails: the KNN classifier is expected to call SearchK.
func (s *scriptedKNNStore) Search(_ context.Context, _ []float32) (float64, []byte, bool, error) {
	return 0, nil, false, errors.New("knn classifier must use SearchK")
}

// Insert always fails: classification must not write to the store.
func (s *scriptedKNNStore) Insert(_ context.Context, _ []float32, _ []byte) error {
	return errors.New("knn classifier must never insert")
}

// mustEntry encodes labels into a store payload and fails the spec if
// encoding returns an error.
func mustEntry(labels ...string) []byte {
	b, err := router.EncodeCorpusEntry(labels)
	Expect(err).ToNot(HaveOccurred())
	return b
}

var _ = Describe("KNNClassifier", func() {
	var (
		embedder *fakeEmbedder
		probe    router.Probe
		ctx      context.Context
	)

	BeforeEach(func() {
		ctx = context.Background()
		embedder = &fakeEmbedder{table: map[string][]float32{"prompt": {1, 0, 0}}}
		probe = router.Probe{Prompt: "prompt"}
	})

	// classify builds a classifier over the scripted store, runs it on
	// the shared probe and asserts that no error was returned.
	classify := func(store *scriptedKNNStore, opts router.KNNClassifierOptions) router.Decision {
		c := router.NewKNNClassifier(embedder, store, opts)
		d, err := c.Classify(ctx, probe)
		Expect(err).ToNot(HaveOccurred())
		return d
	}

	It("computes similarity-weighted vote shares exactly", func() {
		// Hand-computed: usable sims 0.90 {code}, 0.85 {code,math},
		// 0.82 {math}; total = 2.57.
		// share(code) = (0.90+0.85)/2.57 = 0.68093...
		// share(math) = (0.85+0.82)/2.57 = 0.64980...
		// Both clear the 0.5 majority → both active; candidate matching
		// then requires a model labelled for both.
		store := &scriptedKNNStore{neighbors: []backend.Neighbor{
			{Similarity: 0.90, Payload: mustEntry("code")},
			{Similarity: 0.85, Payload: mustEntry("code", "math")},
			{Similarity: 0.82, Payload: mustEntry("math")},
		}}
		d := classify(store, router.KNNClassifierOptions{K: 3})
		Expect(d.Labels).To(Equal([]string{"code", "math"}))
		Expect(d.LabelScores).To(HaveLen(2))
		Expect(d.LabelScores[0].Label).To(Equal("code"))
		Expect(d.LabelScores[0].Score).To(BeNumerically("~", 1.75/2.57, 1e-9))
		Expect(d.LabelScores[1].Label).To(Equal("math"))
		Expect(d.LabelScores[1].Score).To(BeNumerically("~", 1.67/2.57, 1e-9))
		Expect(d.Score).To(BeNumerically("~", 1.75/2.57, 1e-9))
		Expect(d.NearestSimilarity).To(BeNumerically("~", 0.90, 1e-9))
		// The configured K must be what reaches the store.
		Expect(store.lastK).To(Equal(3))
	})

	It("does not activate a minority label", func() {
		// share(chat) = 0.81/2.57 < 0.5 → inactive, but still reported
		// in LabelScores so the decision log shows how close it came.
		store := &scriptedKNNStore{neighbors: []backend.Neighbor{
			{Similarity: 0.90, Payload: mustEntry("code")},
			{Similarity: 0.86, Payload: mustEntry("code")},
			{Similarity: 0.81, Payload: mustEntry("chat")},
		}}
		d := classify(store, router.KNNClassifierOptions{K: 3})
		Expect(d.Labels).To(Equal([]string{"code"}))
		Expect(d.LabelScores).To(HaveLen(2))
		Expect(d.LabelScores[1].Label).To(Equal("chat"))
		Expect(d.LabelScores[1].Score).To(BeNumerically("<", 0.5))
	})

	It("gates out-of-corpus probes to the fallback (empty labels)", func() {
		store := &scriptedKNNStore{neighbors: []backend.Neighbor{
			{Similarity: 0.55, Payload: mustEntry("code")},
			{Similarity: 0.40, Payload: mustEntry("math")},
		}}
		d := classify(store, router.KNNClassifierOptions{SimilarityThreshold: 0.80})
		Expect(d.Labels).To(BeEmpty())
		// The admin-facing epistemic signal: how far away the nearest
		// labelled experience was.
		Expect(d.NearestSimilarity).To(BeNumerically("~", 0.55, 1e-9))
	})

	It("excludes sub-threshold neighbours from the vote", func() {
		// The 0.3-sim {chat} neighbour must not dilute the local
		// majority: with it, share(code) would be 0.9/1.2 = 0.75; the
		// vote must instead be over the single usable neighbour.
		store := &scriptedKNNStore{neighbors: []backend.Neighbor{
			{Similarity: 0.90, Payload: mustEntry("code")},
			{Similarity: 0.30, Payload: mustEntry("chat")},
		}}
		d := classify(store, router.KNNClassifierOptions{K: 2, SimilarityThreshold: 0.80})
		Expect(d.Labels).To(Equal([]string{"code"}))
		Expect(d.LabelScores).To(HaveLen(1))
		Expect(d.LabelScores[0].Score).To(BeNumerically("~", 1.0, 1e-9))
	})

	It("degenerates to nearest-entry labels at K=1", func() {
		// Both labels come from the same neighbour and tie at share 1.0,
		// so the expected order is the lexicographic tie-break.
		store := &scriptedKNNStore{neighbors: []backend.Neighbor{
			{Similarity: 0.95, Payload: mustEntry("reasoning", "math")},
		}}
		d := classify(store, router.KNNClassifierOptions{K: 1})
		Expect(d.Labels).To(Equal([]string{"math", "reasoning"}))
		Expect(d.Score).To(BeNumerically("~", 1.0, 1e-9))
	})

	It("skips corrupt payloads and falls back when nothing can vote", func() {
		store := &scriptedKNNStore{neighbors: []backend.Neighbor{
			{Similarity: 0.90, Payload: []byte("not json")},
		}}
		d := classify(store, router.KNNClassifierOptions{})
		Expect(d.Labels).To(BeEmpty())
		// The corrupt neighbour still counts as the nearest one seen.
		Expect(d.NearestSimilarity).To(BeNumerically("~", 0.90, 1e-9))
	})

	It("returns the embed error so the middleware can fall back", func() {
		embedder.table = nil // unknown prompt → fakeEmbedder errors
		store := &scriptedKNNStore{}
		c := router.NewKNNClassifier(embedder, store, router.KNNClassifierOptions{})
		_, err := c.Classify(ctx, probe)
		Expect(err).To(HaveOccurred())
	})

	It("returns the store error so the middleware can fall back", func() {
		store := &scriptedKNNStore{err: errors.New("store offline")}
		c := router.NewKNNClassifier(embedder, store, router.KNNClassifierOptions{})
		_, err := c.Classify(ctx, probe)
		Expect(err).To(HaveOccurred())
	})

	It("rejects corpus entries without labels at encode time", func() {
		_, err := router.EncodeCorpusEntry(nil)
		Expect(err).To(HaveOccurred())
	})

	It("panics on missing embedder or store", func() {
		Expect(func() { router.NewKNNClassifier(nil, &scriptedKNNStore{}, router.KNNClassifierOptions{}) }).To(Panic())
		Expect(func() { router.NewKNNClassifier(embedder, nil, router.KNNClassifierOptions{}) }).To(Panic())
	})
})
ClassifierColbert = "colbert" + + // ClassifierKNN picks labels by similarity-weighted vote over a + // curated corpus of labelled example prompts, with an epistemic + // gate: probes dissimilar from all corpus entries activate no + // labels and route to the fallback. Needs an embedding model and + // a seeded corpus, not a classifier_model. See router/knn.go. + ClassifierKNN = "knn" ) // LabelFallback is the synthetic label written to the decision diff --git a/docs/content/features/middleware.md b/docs/content/features/middleware.md index 397af3c9207d..c3733891c078 100644 --- a/docs/content/features/middleware.md +++ b/docs/content/features/middleware.md @@ -334,17 +334,20 @@ silent-bypass. ### Available classifiers -LocalAI ships two classifier implementations. Pick one with `classifier:` +LocalAI ships three classifier implementations. Pick one with `classifier:` in the router YAML: | Classifier | When to use | Underlying primitive | |---|---|---| | `score` (default) | Small classifier-tuned LM (Arch-Router-style). Best when label vocabulary is well-covered by next-token continuation. | `Score` gRPC primitive (llama-cpp, vLLM). | | `colbert` | When label descriptions are abstract or short and a next-token classifier produces flat distributions. Robust on long-form policy descriptions. | rerankers backend in ColBERT mode (e.g. `bge-m3-colbert` from the gallery). | +| `knn` | When you have (or can generate) labelled example prompts — including outcome-labelled production traffic. Deterministic, auditable, cheapest per request, and the only classifier with an explicit out-of-distribution fallback. | embeddings backend + local-store KNN over a persisted, curated corpus. | -Both classifiers share the same YAML shape: `classifier_model`, -`policies`, `candidates`, `fallback`, `activation_threshold`, -`classifier_cache_size`, and the optional `embedding_cache` block. +All three share `policies`, `candidates`, `fallback`, and +`classifier_cache_size`. 
`score` and `colbert` take a +`classifier_model` (+ `activation_threshold`, optional +`embedding_cache`); `knn` instead takes a `knn:` block and a corpus +seeded through the API. ### The Score classifier @@ -425,6 +428,109 @@ underlying scoring head loads — `colbert` for late-interaction MaxSim, `cross-encoder` for cross-attention scoring. The classifier itself is indifferent; pick the head that fits your latency / quality budget. +### The KNN classifier + +The `knn` classifier routes by **similarity-weighted vote over a +curated corpus of labelled example prompts**. Where `score` and +`colbert` ask a model's opinion per request, `knn` consults recorded +experience: each corpus entry is an example prompt plus the policy +labels it should activate. It needs no classifier model — just an +embedding model and a seeded corpus. + +```yaml +router: + classifier: knn + fallback: gpt-4o-proxy # used whenever the prompt is unlike all corpus entries + knn: + embedding_model: nomic-embed-text-v1.5 + k: 3 # neighbours that vote (default 3) + similarity_threshold: 0.80 # the epistemic gate (default 0.80) + vote_threshold: 0.5 # weighted vote share a label needs (default 0.5) + # store_name: router-corpus-smart-router # default "router-corpus-<router name>" + policies: + - label: code-generation + description: writing or debugging code + - label: casual-chat + description: small talk + candidates: + - model: qwen3-0.6b + labels: [casual-chat] + - model: qwen-coder + labels: [code-generation, casual-chat] +``` + +For each request: + +1. Embed the prompt with `knn.embedding_model`. +2. Fetch the `k` nearest corpus entries (cosine similarity). +3. **Epistemic gate**: entries below `similarity_threshold` cannot + vote. If none clears it, the classifier activates **no** labels and + the router uses `fallback` — a prompt unlike all labelled + experience is treated as *undecidable*, not guessed.
The decision + log records `nearest_similarity` so you can see how far away the + closest labelled example was. +4. Each surviving neighbour votes for its labels, weighted by its + similarity; every label whose vote share clears `vote_threshold` + joins the active set. Candidate matching then proceeds exactly as + for the other classifiers. With `k: 1` this degenerates to + "nearest example's labels". + +#### Seeding and curating the corpus (API-only) + +Corpus entries may contain example user content, so they are managed +exclusively through the admin API — the UI never sends or displays +them, and no endpoint returns entry texts (inspection is label counts +only): + +```bash +# Seed labelled exemplars (embedded server-side; indexed immediately) +curl -X POST http://localhost:8080/api/router/smart-router/corpus \ + -H "Content-Type: application/json" \ + -d '{"entries": [ + {"text": "why does this segfault when I free the buffer twice", "labels": ["code-generation"]}, + {"text": "hey hows it going", "labels": ["casual-chat"]} + ]}' + +# Inspect — counts only, never texts +curl http://localhost:8080/api/router/smart-router/corpus/stats + +# Wipe (file + live index); reseed afterwards +curl -X DELETE http://localhost:8080/api/router/smart-router/corpus +``` + +Entry labels must be declared in `policies` (same invariant as +candidate labels), and duplicate texts are skipped rather than +double-weighted. Label your exemplars with *outcomes*, not topics, +when routing for difficulty: an entry recording "the small model +handled prompts like this" is exactly as useful as one recording that +it failed — grade a sample of production traffic against your +candidates and seed both. 
+ +#### Persistence + +The corpus is persisted as one JSONL file per router under +`<data-path>/router-corpus/` (text, labels, vector, embedding-model +fingerprint) — **the file is the source of truth** and survives +restarts; the local-store index is rebuilt from it at classifier build +time without re-embedding. Changing `knn.embedding_model` re-embeds +the corpus on the next load (restart LocalAI if the old index was +already live — mixed embedding spaces cannot be served). + +#### Tuning notes + +- **`similarity_threshold` is the safety knob.** Too low and the + router confidently extrapolates from unrelated exemplars; too high + and everything falls back. Watch `nearest_similarity` in the + decision log: fallback rows clustering just under the threshold mean + the corpus needs entries near that traffic (or the gate is too + tight). +- **`k` trades robustness for corpus density**: `k: 3` tolerates one + mislabelled neighbour; raise it only when every label region has + several exemplars. +- **`embedding_cache` is ignored** for `knn` (with a warning) — the + classifier is already an embedding KNN lookup; wrapping it in + another would embed twice for no additional information. + ### YAML reference ```yaml @@ -432,8 +538,9 @@ name: smart-router known_usecases: - chat router: - # `score` (Arch-Router-style next-token scoring) or `colbert` - # (rerank policy descriptions). See "Available classifiers" above. + # `score` (Arch-Router-style next-token scoring), `colbert` (rerank + # policy descriptions), or `knn` (vote over a labelled corpus). + # See "Available classifiers" above. classifier: score # A model loaded by LocalAI that supports the Score gRPC primitive @@ -553,8 +660,11 @@ For each request: The local-store collection is named `router-cache-` by default — each router gets its own collection so two routers can't -cross-contaminate. Collections persist on disk (local-store is the -canonical persistent vector backend), so the cache survives restarts. +cross-contaminate.
The collection is **in-memory only**: local-store +keeps no on-disk artefact, so the embedding cache starts empty on every +restart and re-learns from live traffic. (The KNN classifier's corpus +does NOT have this limitation — its corpus file is the source of truth +and re-indexes on startup; see "The KNN classifier" above.) #### Tuning notes @@ -595,6 +705,9 @@ canonical usage log lives in `/api/usage` and correlates by request ID. |---|---|---|---| | GET | `/api/router/status` | any | Router configuration: each router model's classifier, policies, candidates. | | GET | `/api/router/decisions` | admin | Decision log with optional filters (`correlation_id`, `user_id`, `router_model`, `limit`). | +| POST | `/api/router/{name}/corpus` | admin | Seed the KNN corpus with labelled exemplars: `{"entries": [{"text": "...", "labels": ["..."]}]}`. Embedded server-side, persisted, indexed immediately. | +| GET | `/api/router/{name}/corpus/stats` | admin | KNN corpus size and per-label counts. Counts only — entry texts are never returned. | +| DELETE | `/api/router/{name}/corpus` | admin | Wipe the KNN corpus (file + live index). | | POST | `/api/score` | admin | Direct access to the `Score` gRPC primitive — useful for offline threshold tuning. Body: `{"model": "", "prompt": "", "candidates": ["label-a", ...], "length_normalize": true}`. The llama-cpp and vLLM backends implement Score; other backends return `UNIMPLEMENTED`. | ### MCP tools @@ -603,10 +716,14 @@ canonical usage log lives in `/api/usage` and correlates by request ID. |---|---|---| | `get_router_decisions` | read | Recent decision log with optional filters. | | `get_middleware_status` | read | Includes the router section listing configured router models. | - -Mutating routing config — adding a candidate, changing the classifier -model — is YAML-only today; reload with `POST /models/reload` to pick -up edits without restarting. 
+| `get_router_corpus_stats` | read | KNN corpus size and per-label counts (never texts). | +| `seed_router_corpus` | write | Add labelled exemplars to a KNN router's corpus. | +| `clear_router_corpus` | write | Wipe a KNN router's corpus. | + +Mutating the rest of the routing config — adding a candidate, changing +the classifier model — goes through the model-config surface +(`edit_model_config` / `PATCH /api/models/config-json/:name`); reload +with `POST /models/reload` to pick up YAML edits without restarting. ### Operational notes diff --git a/pkg/mcp/localaitools/client.go b/pkg/mcp/localaitools/client.go index f6f6114bef12..71ed137ee56b 100644 --- a/pkg/mcp/localaitools/client.go +++ b/pkg/mcp/localaitools/client.go @@ -101,4 +101,16 @@ type LocalAIClient interface { // /app/middleware Routing tab and for agent-driven introspection. // Admin-required when auth is on. GetRouterDecisions(ctx context.Context, q RouterDecisionsQuery) ([]RouterDecision, error) + + // GetRouterCorpusStats reports a knn router's corpus size and + // per-label counts — counts only, texts are never exposed. + GetRouterCorpusStats(ctx context.Context, routerModel string) (*RouterCorpusStats, error) + + // SeedRouterCorpus adds labelled exemplars to a knn router's + // corpus (embedded server-side, persisted, indexed immediately). + SeedRouterCorpus(ctx context.Context, req RouterCorpusSeedRequest) (*RouterCorpusSeedResult, error) + + // ClearRouterCorpus wipes a knn router's corpus — file and live + // index. + ClearRouterCorpus(ctx context.Context, routerModel string) (*RouterCorpusClearResult, error) } diff --git a/pkg/mcp/localaitools/coverage_test.go b/pkg/mcp/localaitools/coverage_test.go index 39a2ab544a88..f222eb9438d8 100644 --- a/pkg/mcp/localaitools/coverage_test.go +++ b/pkg/mcp/localaitools/coverage_test.go @@ -26,22 +26,23 @@ import ( // the contributor explicitly acknowledges the asymmetry. var toolToHTTPRoute = map[string]string{ // Read-only tools. 
- ToolGallerySearch: "GET /models/available", - ToolListInstalledModels: "GET / (welcome JSON, ModelsConfig field)", - ToolListGalleries: "GET /models/galleries", - ToolGetJobStatus: "GET /models/jobs/:uuid", - ToolGetModelConfig: "(none) — no JSON-only REST yet; httpapi.Client returns a documented stub", - ToolListBackends: "GET /backends", - ToolListKnownBackends: "GET /backends/known", - ToolSystemInfo: "GET / (welcome JSON)", - ToolListNodes: "GET /api/nodes", - ToolVRAMEstimate: "POST /api/models/vram-estimate", - ToolGetBranding: "GET /api/branding", - ToolGetUsageStats: "GET /api/usage (or /api/usage/all when all=true)", - ToolGetPIIEvents: "GET /api/pii/events", - ToolGetMiddlewareStatus: "GET /api/middleware/status", - ToolGetRouterDecisions: "GET /api/router/decisions", - ToolListAliases: "GET /api/aliases", + ToolGallerySearch: "GET /models/available", + ToolListInstalledModels: "GET / (welcome JSON, ModelsConfig field)", + ToolListGalleries: "GET /models/galleries", + ToolGetJobStatus: "GET /models/jobs/:uuid", + ToolGetModelConfig: "(none) — no JSON-only REST yet; httpapi.Client returns a documented stub", + ToolListBackends: "GET /backends", + ToolListKnownBackends: "GET /backends/known", + ToolSystemInfo: "GET / (welcome JSON)", + ToolListNodes: "GET /api/nodes", + ToolVRAMEstimate: "POST /api/models/vram-estimate", + ToolGetBranding: "GET /api/branding", + ToolGetUsageStats: "GET /api/usage (or /api/usage/all when all=true)", + ToolGetPIIEvents: "GET /api/pii/events", + ToolGetMiddlewareStatus: "GET /api/middleware/status", + ToolGetRouterDecisions: "GET /api/router/decisions", + ToolGetRouterCorpusStats: "GET /api/router/:name/corpus/stats", + ToolListAliases: "GET /api/aliases", // Mutating tools. 
ToolInstallModel: "POST /models/apply", @@ -55,6 +56,8 @@ var toolToHTTPRoute = map[string]string{ ToolToggleModelPinned: "PUT /models/toggle-pinned/:name/:action", ToolSetBranding: "POST /api/settings (instance_name, instance_tagline)", ToolSetAlias: "PATCH /api/models/config-json/:name (swap) or POST /models/import (create)", + ToolSeedRouterCorpus: "POST /api/router/:name/corpus", + ToolClearRouterCorpus: "DELETE /api/router/:name/corpus", } // allKnownTools is the union of expectedFullCatalog (defined in diff --git a/pkg/mcp/localaitools/dto.go b/pkg/mcp/localaitools/dto.go index f8aa98eeeac6..bb12fa39a2a6 100644 --- a/pkg/mcp/localaitools/dto.go +++ b/pkg/mcp/localaitools/dto.go @@ -278,6 +278,50 @@ type RouterDecision struct { CreatedAt string `json:"created_at"` } +// RouterCorpusEntry is one labelled exemplar for seed_router_corpus. +type RouterCorpusEntry struct { + Text string `json:"text" jsonschema:"Example prompt text. Embedded server-side and persisted; NEVER returned by any tool or endpoint."` + Labels []string `json:"labels" jsonschema:"Policy labels this exemplar activates. Every label must be declared in the router's policies."` +} + +// RouterCorpusSeedRequest is the input for seed_router_corpus. +type RouterCorpusSeedRequest struct { + Router string `json:"router" jsonschema:"Router model name — the ModelConfig with classifier: knn and a router.knn block."` + Entries []RouterCorpusEntry `json:"entries" jsonschema:"Labelled exemplars to add. Duplicate texts are skipped, not double-weighted."` +} + +// RouterCorpusSeedResult reports the outcome of seed_router_corpus. +type RouterCorpusSeedResult struct { + Router string `json:"router"` + Added int `json:"added"` + Skipped int `json:"skipped"` + Total int `json:"total"` + LabelCounts map[string]int `json:"label_counts"` +} + +// RouterCorpusQuery names the router whose corpus to inspect or clear. 
type RouterCorpusQuery struct {
	// Router is the name of the router model whose corpus is targeted.
	Router string `json:"router" jsonschema:"Router model name."`
}

// RouterCorpusStats is the count-only inspection surface for a
// router's KNN corpus. Entry texts are never exposed.
type RouterCorpusStats struct {
	Router         string         `json:"router"`
	StoreName      string         `json:"store_name"`
	EmbeddingModel string         `json:"embedding_model"`
	Total          int            `json:"total"`
	LabelCounts    map[string]int `json:"label_counts"`
	// EmbeddingModels is omitted from the JSON when empty.
	// NOTE(review): presumably the distinct embedding-model
	// fingerprints found in the corpus file — confirm against
	// corpus.Manager.Stats.
	EmbeddingModels []string `json:"embedding_models,omitempty"`
}

// RouterCorpusClearResult reports how many entries clear_router_corpus
// removed.
type RouterCorpusClearResult struct {
	Router  string `json:"router"`
	Cleared int    `json:"cleared"`
}

// ---- pkg/mcp/localaitools/fakes_test.go ----

// GetRouterCorpusStats records the call and returns empty stats for
// the named router.
func (f *fakeClient) GetRouterCorpusStats(_ context.Context, routerModel string) (*RouterCorpusStats, error) {
	f.record("GetRouterCorpusStats", routerModel)
	return &RouterCorpusStats{Router: routerModel, LabelCounts: map[string]int{}}, nil
}

// SeedRouterCorpus records the call and reports every submitted entry
// as added.
func (f *fakeClient) SeedRouterCorpus(_ context.Context, req RouterCorpusSeedRequest) (*RouterCorpusSeedResult, error) {
	f.record("SeedRouterCorpus", req)
	return &RouterCorpusSeedResult{Router: req.Router, Added: len(req.Entries), LabelCounts: map[string]int{}}, nil
}

// ClearRouterCorpus records the call and reports zero cleared entries.
func (f *fakeClient) ClearRouterCorpus(_ context.Context, routerModel string) (*RouterCorpusClearResult, error) {
	f.record("ClearRouterCorpus", routerModel)
	return &RouterCorpusClearResult{Router: routerModel}, nil
}

// ---- pkg/mcp/localaitools/httpapi/client.go ----

// GetRouterCorpusStats fetches corpus statistics with
// GET /api/router/{name}/corpus/stats. The router name is path-escaped.
func (c *Client) GetRouterCorpusStats(ctx context.Context, routerModel string) (*localaitools.RouterCorpusStats, error) {
	var out localaitools.RouterCorpusStats
	path := fmt.Sprintf("/api/router/%s/corpus/stats", url.PathEscape(routerModel))
	if err := c.do(ctx, http.MethodGet, path, nil, &out); err != nil {
		return nil, err
	}
	return &out, nil
}

// SeedRouterCorpus posts labelled exemplars to
// POST /api/router/{name}/corpus.
func (c *Client) SeedRouterCorpus(ctx context.Context, req localaitools.RouterCorpusSeedRequest) (*localaitools.RouterCorpusSeedResult, error) {
	// The REST body carries entries only; the router rides in the path.
	body := struct {
		Entries []localaitools.RouterCorpusEntry `json:"entries"`
	}{Entries: req.Entries}
	var out localaitools.RouterCorpusSeedResult
	path := fmt.Sprintf("/api/router/%s/corpus", url.PathEscape(req.Router))
	if err := c.do(ctx, http.MethodPost, path, body, &out); err != nil {
		return nil, err
	}
	return &out, nil
}

// ClearRouterCorpus wipes a router's corpus with
// DELETE /api/router/{name}/corpus.
func (c *Client) ClearRouterCorpus(ctx context.Context, routerModel string) (*localaitools.RouterCorpusClearResult, error) {
	var out localaitools.RouterCorpusClearResult
	path := fmt.Sprintf("/api/router/%s/corpus", url.PathEscape(routerModel))
	if err := c.do(ctx, http.MethodDelete, path, nil, &out); err != nil {
		return nil, err
	}
	return &out, nil
}
"github.com/mudler/LocalAI/core/backend" "github.com/mudler/LocalAI/core/config" "github.com/mudler/LocalAI/core/gallery" "github.com/mudler/LocalAI/core/gallery/importers" @@ -21,6 +22,7 @@ import ( "github.com/mudler/LocalAI/core/services/galleryop" "github.com/mudler/LocalAI/core/services/modeladmin" "github.com/mudler/LocalAI/core/services/routing/billing" + "github.com/mudler/LocalAI/core/services/routing/corpus" "github.com/mudler/LocalAI/core/services/routing/pii" "github.com/mudler/LocalAI/core/services/routing/router" "github.com/mudler/LocalAI/internal" @@ -63,6 +65,16 @@ type Client struct { // returns when stats are disabled. RouterDecisions router.DecisionStore + // RouterCorpus + the two factories back the corpus tools + // (seed_router_corpus / get_router_corpus_stats / + // clear_router_corpus). nil RouterCorpus makes them return an + // "unavailable" error. The factories mirror the middleware's + // ClassifierDeps so the tools and the request path resolve models + // and store namespaces identically. + RouterCorpus *corpus.Manager + RouterEmbedder func(modelName string) backend.Embedder + RouterVectorStore func(storeName string) backend.VectorStore + modelAdmin *modeladmin.ConfigService } @@ -807,3 +819,111 @@ func capabilityFlagsOf(m *config.ModelConfig) []string { } return out } + +// resolveKNNRouter mirrors the REST endpoint's resolution: the model +// must exist and declare a router.knn block; the store name defaults +// the same way buildClassifier defaults it. 
+func (c *Client) resolveKNNRouter(routerModel string) (*config.ModelConfig, string, error) { + cfg, err := c.ConfigLoader.LoadModelConfigFileByNameDefaultOptions(routerModel, c.AppConfig) + if err != nil { + return nil, "", fmt.Errorf("load model config: %w", err) + } + if cfg == nil || cfg.Name == "" { + return nil, "", fmt.Errorf("model %q not found", routerModel) + } + if cfg.Router.KNN == nil || cfg.Router.KNN.EmbeddingModel == "" { + return nil, "", fmt.Errorf("model %q has no router.knn block (set classifier: knn and knn.embedding_model first)", routerModel) + } + storeName := cfg.Router.KNN.StoreName + if storeName == "" { + storeName = "router-corpus-" + cfg.Name + } + return cfg, storeName, nil +} + +func (c *Client) GetRouterCorpusStats(_ context.Context, routerModel string) (*localaitools.RouterCorpusStats, error) { + if c.RouterCorpus == nil { + return nil, errors.New("router corpus manager unavailable") + } + cfg, storeName, err := c.resolveKNNRouter(routerModel) + if err != nil { + return nil, err + } + stats, err := c.RouterCorpus.Stats(storeName) + if err != nil { + return nil, err + } + return &localaitools.RouterCorpusStats{ + Router: cfg.Name, + StoreName: storeName, + EmbeddingModel: cfg.Router.KNN.EmbeddingModel, + Total: stats.Total, + LabelCounts: stats.LabelCounts, + EmbeddingModels: stats.EmbeddingModels, + }, nil +} + +func (c *Client) SeedRouterCorpus(ctx context.Context, req localaitools.RouterCorpusSeedRequest) (*localaitools.RouterCorpusSeedResult, error) { + if c.RouterCorpus == nil || c.RouterEmbedder == nil || c.RouterVectorStore == nil { + return nil, errors.New("router corpus manager unavailable") + } + cfg, storeName, err := c.resolveKNNRouter(req.Router) + if err != nil { + return nil, err + } + + // Same invariant the REST endpoint enforces: labels must be + // declared policies, so a typo can't create an unroutable label. 
+ declared := map[string]struct{}{} + for _, p := range cfg.Router.Policies { + declared[p.Label] = struct{}{} + } + entries := make([]corpus.Entry, 0, len(req.Entries)) + for i, e := range req.Entries { + for _, l := range e.Labels { + if _, ok := declared[l]; !ok { + return nil, fmt.Errorf("entry %d: label %q is not declared in router policies", i, l) + } + } + entries = append(entries, corpus.Entry{Text: e.Text, Labels: e.Labels}) + } + + embedder := c.RouterEmbedder(cfg.Router.KNN.EmbeddingModel) + if embedder == nil { + return nil, fmt.Errorf("embedding_model %q not loadable", cfg.Router.KNN.EmbeddingModel) + } + added, skipped, err := c.RouterCorpus.Add(ctx, storeName, cfg.Router.KNN.EmbeddingModel, embedder, c.RouterVectorStore(storeName), entries) + if err != nil { + return nil, err + } + stats, err := c.RouterCorpus.Stats(storeName) + if err != nil { + return nil, err + } + return &localaitools.RouterCorpusSeedResult{ + Router: cfg.Name, + Added: added, + Skipped: skipped, + Total: stats.Total, + LabelCounts: stats.LabelCounts, + }, nil +} + +func (c *Client) ClearRouterCorpus(ctx context.Context, routerModel string) (*localaitools.RouterCorpusClearResult, error) { + if c.RouterCorpus == nil { + return nil, errors.New("router corpus manager unavailable") + } + cfg, storeName, err := c.resolveKNNRouter(routerModel) + if err != nil { + return nil, err + } + var store backend.VectorStore + if c.RouterVectorStore != nil { + store = c.RouterVectorStore(storeName) + } + cleared, err := c.RouterCorpus.Clear(ctx, storeName, store) + if err != nil { + return nil, err + } + return &localaitools.RouterCorpusClearResult{Router: cfg.Name, Cleared: cleared}, nil +} diff --git a/pkg/mcp/localaitools/server_test.go b/pkg/mcp/localaitools/server_test.go index 052ca1e8b8f1..b648ca608167 100644 --- a/pkg/mcp/localaitools/server_test.go +++ b/pkg/mcp/localaitools/server_test.go @@ -81,6 +81,7 @@ var expectedFullCatalog = sortedStrings( ToolGetMiddlewareStatus, 
ToolGetModelConfig, ToolGetPIIEvents, + ToolGetRouterCorpusStats, ToolGetRouterDecisions, ToolGetUsageStats, ToolImportModelURI, @@ -93,6 +94,8 @@ var expectedFullCatalog = sortedStrings( ToolListKnownBackends, ToolListNodes, ToolReloadModels, + ToolSeedRouterCorpus, + ToolClearRouterCorpus, ToolSetAlias, ToolSetBranding, ToolSystemInfo, @@ -110,6 +113,7 @@ var expectedReadOnlyCatalog = sortedStrings( ToolGetMiddlewareStatus, ToolGetModelConfig, ToolGetPIIEvents, + ToolGetRouterCorpusStats, ToolGetRouterDecisions, ToolGetUsageStats, ToolListAliases, diff --git a/pkg/mcp/localaitools/tools.go b/pkg/mcp/localaitools/tools.go index 263bd791ef3f..03f13179540e 100644 --- a/pkg/mcp/localaitools/tools.go +++ b/pkg/mcp/localaitools/tools.go @@ -8,21 +8,22 @@ package localaitools // SafetyAnchors guards that those strings stay aligned. const ( // Read-only tools. - ToolGallerySearch = "gallery_search" - ToolListInstalledModels = "list_installed_models" - ToolListGalleries = "list_galleries" - ToolGetJobStatus = "get_job_status" - ToolGetModelConfig = "get_model_config" - ToolListBackends = "list_backends" - ToolListKnownBackends = "list_known_backends" - ToolSystemInfo = "system_info" - ToolListNodes = "list_nodes" - ToolVRAMEstimate = "vram_estimate" - ToolGetBranding = "get_branding" - ToolGetUsageStats = "get_usage_stats" - ToolGetPIIEvents = "get_pii_events" - ToolGetMiddlewareStatus = "get_middleware_status" - ToolGetRouterDecisions = "get_router_decisions" + ToolGallerySearch = "gallery_search" + ToolListInstalledModels = "list_installed_models" + ToolListGalleries = "list_galleries" + ToolGetJobStatus = "get_job_status" + ToolGetModelConfig = "get_model_config" + ToolListBackends = "list_backends" + ToolListKnownBackends = "list_known_backends" + ToolSystemInfo = "system_info" + ToolListNodes = "list_nodes" + ToolVRAMEstimate = "vram_estimate" + ToolGetBranding = "get_branding" + ToolGetUsageStats = "get_usage_stats" + ToolGetPIIEvents = "get_pii_events" + 
ToolGetMiddlewareStatus = "get_middleware_status" + ToolGetRouterDecisions = "get_router_decisions" + ToolGetRouterCorpusStats = "get_router_corpus_stats" // Mutating tools — guarded by Options.DisableMutating and the // LLM-side safety prompt (see prompts/10_safety.md). @@ -37,6 +38,8 @@ const ( ToolToggleModelPinned = "toggle_model_pinned" ToolSetBranding = "set_branding" ToolSetAlias = "set_alias" + ToolSeedRouterCorpus = "seed_router_corpus" + ToolClearRouterCorpus = "clear_router_corpus" // ToolListAliases is read-only but lives here so the alias tools stay // grouped; the catalog tests assert its read-only placement. diff --git a/pkg/mcp/localaitools/tools_middleware.go b/pkg/mcp/localaitools/tools_middleware.go index 5dd8066fd4fb..d761768a80d6 100644 --- a/pkg/mcp/localaitools/tools_middleware.go +++ b/pkg/mcp/localaitools/tools_middleware.go @@ -18,7 +18,13 @@ import ( // PII detection policy lives on each detector model's pii_detection // block, edited via the model-config tools — there is no global pattern // set to mutate here anymore. -func registerMiddlewareTools(s *mcp.Server, client LocalAIClient, _ Options) { +// +// The router corpus tools manage the knn classifier's labelled +// exemplar store: get_router_corpus_stats is read-only (counts, never +// texts); seed_router_corpus / clear_router_corpus mutate routing +// behaviour and sit behind the DisableMutating gate like the +// model-config tools. +func registerMiddlewareTools(s *mcp.Server, client LocalAIClient, opts Options) { mcp.AddTool(s, &mcp.Tool{ Name: ToolGetMiddlewareStatus, Description: "Aggregated routing-module status: per-model resolved PII state and the NER detector models each one references, recent event count, plus the active router models and their classifier configs. 
Read-only.", @@ -40,4 +46,53 @@ func registerMiddlewareTools(s *mcp.Server, client LocalAIClient, _ Options) { } return jsonResult(decisions), nil, nil }) + + mcp.AddTool(s, &mcp.Tool{ + Name: ToolGetRouterCorpusStats, + Description: "Size and per-label exemplar counts of a knn router's corpus. Counts only — corpus texts are never exposed by any tool. Read-only.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, args RouterCorpusQuery) (*mcp.CallToolResult, any, error) { + if args.Router == "" { + return errorResultf("router is required"), nil, nil + } + stats, err := client.GetRouterCorpusStats(ctx, args.Router) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(stats), nil, nil + }) + + if opts.DisableMutating { + return + } + + mcp.AddTool(s, &mcp.Tool{ + Name: ToolSeedRouterCorpus, + Description: "Add labelled example prompts to a knn router's corpus. Entries are embedded server-side with the router's knn.embedding_model, persisted, and indexed immediately — routing behaviour changes right away, so confirm with the user first (safety rule 1). Labels must be declared in the router's policies; duplicate texts are skipped.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, args RouterCorpusSeedRequest) (*mcp.CallToolResult, any, error) { + if args.Router == "" { + return errorResultf("router is required"), nil, nil + } + if len(args.Entries) == 0 { + return errorResultf("entries is required"), nil, nil + } + res, err := client.SeedRouterCorpus(ctx, args) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(res), nil, nil + }) + + mcp.AddTool(s, &mcp.Tool{ + Name: ToolClearRouterCorpus, + Description: "Wipe a knn router's corpus — the persisted file and the live index. The router falls back for every prompt until reseeded. 
Destructive; requires explicit user confirmation per safety rule 1.", + }, func(ctx context.Context, _ *mcp.CallToolRequest, args RouterCorpusQuery) (*mcp.CallToolResult, any, error) { + if args.Router == "" { + return errorResultf("router is required"), nil, nil + } + res, err := client.ClearRouterCorpus(ctx, args.Router) + if err != nil { + return errorResult(err), nil, nil + } + return jsonResult(res), nil, nil + }) } diff --git a/swagger/docs.go b/swagger/docs.go index e7b6b9acf296..7ef0e1aa1cd3 100644 --- a/swagger/docs.go +++ b/swagger/docs.go @@ -1284,6 +1284,181 @@ const docTemplate = `{ } } }, + "/api/router/{name}/corpus": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "router" + ], + "summary": "Seed the KNN routing corpus with labelled example prompts", + "parameters": [ + { + "type": "string", + "description": "router model name", + "name": "name", + "in": "path", + "required": true + }, + { + "description": "labelled exemplars", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/schema.RouterCorpusAddRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schema.RouterCorpusAddResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "404": { + "description": "Not Found", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + }, + "delete": { + "produces": [ + "application/json" + ], + "tags": [ + "router" + ], + "summary": "Clear a router's KNN corpus", + "parameters": [ + { + "type": "string", + "description": "router model name", + "name": "name", + "in": "path", + "required": true + } + ], + 
"responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schema.RouterCorpusClearResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "404": { + "description": "Not Found", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } + }, + "/api/router/{name}/corpus/stats": { + "get": { + "produces": [ + "application/json" + ], + "tags": [ + "router" + ], + "summary": "Inspect a router's KNN corpus (label counts only, never texts)", + "parameters": [ + { + "type": "string", + "description": "router model name", + "name": "name", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schema.RouterCorpusStatsResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "404": { + "description": "Not Found", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } + }, "/api/traces": { "get": { "description": "Returns captured API exchange traces (request/response pairs) in reverse chronological order", @@ -6139,6 +6314,99 @@ const docTemplate = `{ } } }, + "schema.RouterCorpusAddRequest": { + "type": "object", + "properties": { + "entries": { + "type": "array", + "items": { + "$ref": "#/definitions/schema.RouterCorpusEntry" + } + } + } + }, + "schema.RouterCorpusAddResponse": { + "type": "object", + "properties": { + "added": { + "description": "Added is how many entries were embedded, 
persisted, and indexed.", + "type": "integer" + }, + "label_counts": { + "description": "LabelCounts is the per-label exemplar count after the call.", + "type": "object", + "additionalProperties": { + "type": "integer" + } + }, + "router": { + "type": "string" + }, + "skipped": { + "description": "Skipped counts entries whose text was already in the corpus —\nduplicates are rejected rather than double-weighted.", + "type": "integer" + }, + "total": { + "description": "Total is the corpus size after the call.", + "type": "integer" + } + } + }, + "schema.RouterCorpusClearResponse": { + "type": "object", + "properties": { + "cleared": { + "type": "integer" + }, + "router": { + "type": "string" + } + } + }, + "schema.RouterCorpusEntry": { + "type": "object", + "properties": { + "labels": { + "type": "array", + "items": { + "type": "string" + } + }, + "text": { + "type": "string" + } + } + }, + "schema.RouterCorpusStatsResponse": { + "type": "object", + "properties": { + "embedding_model": { + "type": "string" + }, + "embedding_models": { + "description": "EmbeddingModels lists the embedder fingerprints present in the\npersisted corpus; more than one means part of the corpus is\npending re-embedding on the next load.", + "type": "array", + "items": { + "type": "string" + } + }, + "label_counts": { + "type": "object", + "additionalProperties": { + "type": "integer" + } + }, + "router": { + "type": "string" + }, + "store_name": { + "type": "string" + }, + "total": { + "type": "integer" + } + } + }, "schema.RouterDecideRequest": { "type": "object", "properties": { @@ -6186,6 +6454,10 @@ const docTemplate = `{ "description": "LatencyMs is the classifier's wall-clock cost.", "type": "integer" }, + "nearest_similarity": { + "description": "NearestSimilarity is the cosine similarity of the closest KNN\ncorpus entry — populated by the knn classifier even when the\ndecision fell back because the probe was out of corpus range.\n0 for other classifiers.", + "type": "number" + }, 
"router": { "description": "Router echoes the requested router model.", "type": "string" diff --git a/swagger/swagger.json b/swagger/swagger.json index 4f9695bb109f..8e8b553d809d 100644 --- a/swagger/swagger.json +++ b/swagger/swagger.json @@ -1281,6 +1281,181 @@ } } }, + "/api/router/{name}/corpus": { + "post": { + "consumes": [ + "application/json" + ], + "produces": [ + "application/json" + ], + "tags": [ + "router" + ], + "summary": "Seed the KNN routing corpus with labelled example prompts", + "parameters": [ + { + "type": "string", + "description": "router model name", + "name": "name", + "in": "path", + "required": true + }, + { + "description": "labelled exemplars", + "name": "request", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/schema.RouterCorpusAddRequest" + } + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schema.RouterCorpusAddResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "404": { + "description": "Not Found", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + }, + "delete": { + "produces": [ + "application/json" + ], + "tags": [ + "router" + ], + "summary": "Clear a router's KNN corpus", + "parameters": [ + { + "type": "string", + "description": "router model name", + "name": "name", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schema.RouterCorpusClearResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "404": { + "description": "Not Found", + "schema": { + "type": "object", + 
"additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } + }, + "/api/router/{name}/corpus/stats": { + "get": { + "produces": [ + "application/json" + ], + "tags": [ + "router" + ], + "summary": "Inspect a router's KNN corpus (label counts only, never texts)", + "parameters": [ + { + "type": "string", + "description": "router model name", + "name": "name", + "in": "path", + "required": true + } + ], + "responses": { + "200": { + "description": "OK", + "schema": { + "$ref": "#/definitions/schema.RouterCorpusStatsResponse" + } + }, + "400": { + "description": "Bad Request", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "404": { + "description": "Not Found", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + }, + "500": { + "description": "Internal Server Error", + "schema": { + "type": "object", + "additionalProperties": { + "type": "string" + } + } + } + } + } + }, "/api/traces": { "get": { "description": "Returns captured API exchange traces (request/response pairs) in reverse chronological order", @@ -6136,6 +6311,99 @@ } } }, + "schema.RouterCorpusAddRequest": { + "type": "object", + "properties": { + "entries": { + "type": "array", + "items": { + "$ref": "#/definitions/schema.RouterCorpusEntry" + } + } + } + }, + "schema.RouterCorpusAddResponse": { + "type": "object", + "properties": { + "added": { + "description": "Added is how many entries were embedded, persisted, and indexed.", + "type": "integer" + }, + "label_counts": { + "description": "LabelCounts is the per-label exemplar count after the call.", + "type": "object", + "additionalProperties": { + "type": "integer" + } + }, + "router": { + "type": "string" + }, + "skipped": { + "description": "Skipped counts entries whose text was already in the corpus —\nduplicates 
are rejected rather than double-weighted.", + "type": "integer" + }, + "total": { + "description": "Total is the corpus size after the call.", + "type": "integer" + } + } + }, + "schema.RouterCorpusClearResponse": { + "type": "object", + "properties": { + "cleared": { + "type": "integer" + }, + "router": { + "type": "string" + } + } + }, + "schema.RouterCorpusEntry": { + "type": "object", + "properties": { + "labels": { + "type": "array", + "items": { + "type": "string" + } + }, + "text": { + "type": "string" + } + } + }, + "schema.RouterCorpusStatsResponse": { + "type": "object", + "properties": { + "embedding_model": { + "type": "string" + }, + "embedding_models": { + "description": "EmbeddingModels lists the embedder fingerprints present in the\npersisted corpus; more than one means part of the corpus is\npending re-embedding on the next load.", + "type": "array", + "items": { + "type": "string" + } + }, + "label_counts": { + "type": "object", + "additionalProperties": { + "type": "integer" + } + }, + "router": { + "type": "string" + }, + "store_name": { + "type": "string" + }, + "total": { + "type": "integer" + } + } + }, "schema.RouterDecideRequest": { "type": "object", "properties": { @@ -6183,6 +6451,10 @@ "description": "LatencyMs is the classifier's wall-clock cost.", "type": "integer" }, + "nearest_similarity": { + "description": "NearestSimilarity is the cosine similarity of the closest KNN\ncorpus entry — populated by the knn classifier even when the\ndecision fell back because the probe was out of corpus range.\n0 for other classifiers.", + "type": "number" + }, "router": { "description": "Router echoes the requested router model.", "type": "string" diff --git a/swagger/swagger.yaml b/swagger/swagger.yaml index cbb17d719bd0..e28591f5cb55 100644 --- a/swagger/swagger.yaml +++ b/swagger/swagger.yaml @@ -2059,6 +2059,73 @@ definitions: redacted_text: type: string type: object + schema.RouterCorpusAddRequest: + properties: + entries: + items: + $ref: 
'#/definitions/schema.RouterCorpusEntry' + type: array + type: object + schema.RouterCorpusAddResponse: + properties: + added: + description: Added is how many entries were embedded, persisted, and indexed. + type: integer + label_counts: + additionalProperties: + type: integer + description: LabelCounts is the per-label exemplar count after the call. + type: object + router: + type: string + skipped: + description: |- + Skipped counts entries whose text was already in the corpus — + duplicates are rejected rather than double-weighted. + type: integer + total: + description: Total is the corpus size after the call. + type: integer + type: object + schema.RouterCorpusClearResponse: + properties: + cleared: + type: integer + router: + type: string + type: object + schema.RouterCorpusEntry: + properties: + labels: + items: + type: string + type: array + text: + type: string + type: object + schema.RouterCorpusStatsResponse: + properties: + embedding_model: + type: string + embedding_models: + description: |- + EmbeddingModels lists the embedder fingerprints present in the + persisted corpus; more than one means part of the corpus is + pending re-embedding on the next load. + items: + type: string + type: array + label_counts: + additionalProperties: + type: integer + type: object + router: + type: string + store_name: + type: string + total: + type: integer + type: object schema.RouterDecideRequest: properties: input: @@ -2110,6 +2177,13 @@ definitions: latency_ms: description: LatencyMs is the classifier's wall-clock cost. type: integer + nearest_similarity: + description: |- + NearestSimilarity is the cosine similarity of the closest KNN + corpus entry — populated by the knn classifier even when the + decision fell back because the probe was out of corpus range. + 0 for other classifiers. + type: number router: description: Router echoes the requested router model. 
type: string @@ -3337,6 +3411,121 @@ paths: summary: Redact PII in a string by applying the configured policy. tags: - pii + /api/router/{name}/corpus: + delete: + parameters: + - description: router model name + in: path + name: name + required: true + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/schema.RouterCorpusClearResponse' + "400": + description: Bad Request + schema: + additionalProperties: + type: string + type: object + "404": + description: Not Found + schema: + additionalProperties: + type: string + type: object + "500": + description: Internal Server Error + schema: + additionalProperties: + type: string + type: object + summary: Clear a router's KNN corpus + tags: + - router + post: + consumes: + - application/json + parameters: + - description: router model name + in: path + name: name + required: true + type: string + - description: labelled exemplars + in: body + name: request + required: true + schema: + $ref: '#/definitions/schema.RouterCorpusAddRequest' + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/schema.RouterCorpusAddResponse' + "400": + description: Bad Request + schema: + additionalProperties: + type: string + type: object + "404": + description: Not Found + schema: + additionalProperties: + type: string + type: object + "500": + description: Internal Server Error + schema: + additionalProperties: + type: string + type: object + summary: Seed the KNN routing corpus with labelled example prompts + tags: + - router + /api/router/{name}/corpus/stats: + get: + parameters: + - description: router model name + in: path + name: name + required: true + type: string + produces: + - application/json + responses: + "200": + description: OK + schema: + $ref: '#/definitions/schema.RouterCorpusStatsResponse' + "400": + description: Bad Request + schema: + additionalProperties: + type: string + type: object + "404": + 
description: Not Found + schema: + additionalProperties: + type: string + type: object + "500": + description: Internal Server Error + schema: + additionalProperties: + type: string + type: object + summary: Inspect a router's KNN corpus (label counts only, never texts) + tags: + - router /api/router/decide: post: consumes: