core LLM completion layer #2443
Conversation
Add new pkg/ai package that extracts and centralizes model interaction logic from runtime. The package reuses existing types from chat, tools, and provider packages without moving them.
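For orientation, usage as it appears in this PR's diffs looks roughly like the sketch below. The wrapper function, its parameter types, and the option semantics noted in comments are assumptions pieced together from the diff hunks, not authoritative docs:

```go
// generateTitle is a hypothetical caller; only the ai.GenerateText call shape
// is taken from this PR's diffs.
func generateTitle(ctx context.Context, models []provider.Provider, messages []chat.Message, lg *slog.Logger) (string, error) {
	return ai.GenerateText(
		ctx,
		ai.WithModels(models...),     // primary model first, fallbacks after (assumed ordering)
		ai.WithMessages(messages...),
		ai.WithRetries(2),            // per-model retry budget (see the retry discussion below)
		ai.WithRequireContent(),      // treat empty responses as retryable (assumed semantics)
		ai.WithLogger(lg),
	)
}
```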
@dgageot mind taking a look? Extracted the LLM interaction logic from the runtime.

@dgageot just a gentle reminder — any update on this PR?
aheritier
left a comment
Good refactoring — the pkg/ai package has a clean API, proper separation of concerns, and solid test coverage with golden files. The interceptor pattern is well designed.
A few items need attention before merge:
- Merge conflicts — branch is in a dirty state, needs rebase.
- `WithRetries` semantics mismatch — doc says "n + 1 attempts" but code gives exactly n (see inline).
- Token usage accumulation bug — loses data when `r.Usage != nil && r2.Usage == nil` (see inline).
- Unguarded `slices.IndexFunc` result — can panic if model ID not found (see inline).
- Cooldown behavior change — the old code only set cooldown when the primary failed with a non-retryable error; the new code always sets cooldown when any fallback succeeds. This is a semantic change that may cause undesired pinning after transient retryable failures exhaust retries.
CI is green on the last push (lint, test, build all pass).
/review
aheritier
left a comment
Inline comments attached.
```go
// errors (5xx, timeouts). The total attempts per model is n + 1.
func WithRetries(n int) Option {
	return func(c *completion) {
		c.retries = n
```
Blocking: The doc says "The total attempts per model is n + 1" but the code uses `for retry := range c.retries`, which yields exactly n iterations. `WithRetries(3)` = 3 attempts, not 4.
Either fix the doc to "The total attempts per model is n" or change the loop to `range c.retries + 1`.
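A minimal standalone check of the range-over-integer behavior (Go 1.22+), independent of this PR's types:

```go
package main

import "fmt"

func main() {
	retries := 3
	attempts := 0
	for range retries { // range-over-integer: yields exactly `retries` iterations
		attempts++
	}
	fmt.Println(attempts) // 3, i.e. n attempts rather than the documented n+1
}
```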
```go
if r2.Usage != nil && r.Usage != nil {
	r2.Usage = &chat.Usage{
		InputTokens:       r.Usage.InputTokens + r2.Usage.InputTokens,
		OutputTokens:      r.Usage.OutputTokens + r2.Usage.OutputTokens,
		CachedInputTokens: r.Usage.CachedInputTokens + r2.Usage.CachedInputTokens,
		CacheWriteTokens:  r.Usage.CacheWriteTokens + r2.Usage.CacheWriteTokens,
		ReasoningTokens:   r.Usage.ReasoningTokens + r2.Usage.ReasoningTokens,
	}
}

return r2, nil
}
```
Non-blocking: If `r.Usage != nil` but `r2.Usage == nil` (provider didn't return usage on the recursive call), the first call's token counts are silently lost. Consider:

```go
switch {
case r2.Usage != nil && r.Usage != nil:
	r2.Usage = &chat.Usage{
		InputTokens:       r.Usage.InputTokens + r2.Usage.InputTokens,
		OutputTokens:      r.Usage.OutputTokens + r2.Usage.OutputTokens,
		CachedInputTokens: r.Usage.CachedInputTokens + r2.Usage.CachedInputTokens,
		CacheWriteTokens:  r.Usage.CacheWriteTokens + r2.Usage.CacheWriteTokens,
		ReasoningTokens:   r.Usage.ReasoningTokens + r2.Usage.ReasoningTokens,
	}
case r2.Usage == nil && r.Usage != nil:
	r2.Usage = r.Usage
}
```
```go
})

// Rotate to put the responding model first.
c.models = append(c.models[idx:], c.models[:idx]...)
}

r2, err := c.generate(ctx)
```
Non-blocking: `slices.IndexFunc` returns -1 when no match is found. If `r.Model` doesn't match any entry in `c.models` (e.g. a router provider exposing a different sub-model ID), this will panic on `c.models[-1:]`.
Add a guard:

```go
idx := slices.IndexFunc(c.models, func(m provider.Provider) bool {
	return m.ID() == r.Model
})
if idx > 0 {
	c.models = append(c.models[idx:], c.models[:idx]...)
}
```
```go
for _, m := range models {
	if m.ID() == res.Model {
		usedModel = m
		break
	}
	return retryDecisionContinue
}
```
```go
if !retryable {
	slog.Error("Non-retryable error from model",
		"agent", a.Name(),
		"model", modelEntry.provider.ID(),
		"error", err)
	if !modelEntry.isFallback {
		*primaryFailedWithNonRetryable = true

// Handle cooldown state based on which model succeeded
```
Question: This now unconditionally sets cooldown whenever a fallback succeeds. The old code only set cooldown when `primaryFailedWithNonRetryable` was true (429 or 4xx). If the primary exhausts retries on transient 5xx errors and a fallback picks up, the runtime will now pin to that fallback for the full cooldown window — previously it wouldn't.
Is this intentional? If so, it's probably fine (simpler), but worth noting as a behavior change. If not, pkg/ai would need to surface why the fallback was triggered (non-retryable vs exhausted retries), e.g. as sketched below.
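One hedged sketch of how pkg/ai could surface the trigger; every name here is hypothetical, not part of this PR:

```go
// FallbackReason is a hypothetical enum distinguishing why a fallback ran.
type FallbackReason int

const (
	FallbackNonRetryable     FallbackReason = iota // primary failed hard (4xx/429)
	FallbackRetriesExhausted                       // primary burned its retries on transient 5xx errors
)

// Assumed variant of the existing callback option, which today passes only
// (from, to, err); the reason parameter and this option name are made up.
func WithModelFallbackReason(fn func(from, to provider.Provider, reason FallbackReason, err error)) Option {
	return func(c *completion) {
		c.onModelFallbackReason = fn // hypothetical field
	}
}
```

The runtime could then set cooldown only for `FallbackNonRetryable`, preserving the old behavior.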
```diff
 	models = append(models, titleModel)
 }

-if lastErr != nil {
-	return "", fmt.Errorf("generating title failed: %w", lastErr)
-}
-return "", nil
+str, err := ai.GenerateText(
+	ctx,
+	ai.WithModels(models...),
+	ai.WithMessages(messages...),
+	ai.WithRequireContent(),
+	ai.WithLogger(lg),
+)
+if err != nil {
+	return "", fmt.Errorf("generating title failed: %w", err)
+}
```
Praise: Nice simplification of the title generator. The old manual stream-drain + per-model retry loop collapses cleanly into a single `ai.GenerateText` call with `WithRequireContent`.
```go
c.messages = append(c.messages, msgs...)

if c.models[0].ID() != r.Model {
	idx := slices.IndexFunc(c.models, func(m provider.Provider) bool {
```
[HIGH] Panic: `slices.IndexFunc` result not checked for -1 before use as slice index
`slices.IndexFunc` returns -1 when no element matches. If `r.Model` is not found in `c.models` (e.g., the provider returns an unexpected model ID), `idx` will be -1 and `c.models[-1:]` will panic with an index out of range.
The outer `if c.models[0].ID() != r.Model` guard only checks that the responding model is not the first one — it does not guarantee the model exists anywhere in the slice.
Fix: check `idx` before use:

```go
idx := slices.IndexFunc(c.models, func(m provider.Provider) bool {
	return m.ID() == r.Model
})
if idx < 0 {
	idx = 0 // responding model not found; keep current order
}
c.models = append(c.models[idx:], c.models[:idx]...)
```
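A minimal standalone reproduction of the failure mode, using string IDs as a stand-in for `provider.Provider`:

```go
package main

import (
	"fmt"
	"slices"
)

func main() {
	models := []string{"model-a", "model-b"}
	// No entry matches, so IndexFunc returns -1.
	idx := slices.IndexFunc(models, func(m string) bool { return m == "model-c" })
	fmt.Println(idx) // -1

	// Rotating with idx == -1 panics: slice bounds out of range [-1:].
	rotated := append(models[idx:], models[:idx]...)
	fmt.Println(rotated) // never reached
}
```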
```go
}

if ctx.Err() != nil {
	return nil, err
```
[MEDIUM] Context cancellation returns stream error instead of `ctx.Err()`, masking cancellation
When `ctx.Err() != nil` (context was cancelled or timed out), the code returns `err` — the stream/model error — rather than `ctx.Err()`. Callers that check `errors.Is(err, context.Canceled)` or `errors.Is(err, context.DeadlineExceeded)` will not recognize this as a cancellation and may incorrectly retry or log it as a transient failure.
Note: the `select` at the bottom of the retry loop (the `case <-ctx.Done()` branch) correctly returns `ctx.Err()`, making this path inconsistent.
Fix:

```go
if ctx.Err() != nil {
	return nil, ctx.Err() // return the cancellation error, not the stream error
}
```
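A quick standalone illustration of why this matters to callers (the "stream closed" error is a made-up stand-in for whatever the provider surfaces on cancellation):

```go
package main

import (
	"context"
	"errors"
	"fmt"
)

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	cancel()

	streamErr := errors.New("stream closed") // stand-in provider error after cancellation

	// If stream() returns the provider error, callers can't see the cancellation:
	fmt.Println(errors.Is(streamErr, context.Canceled)) // false: looks transient, may retry
	// If it returns ctx.Err(), they can:
	fmt.Println(errors.Is(ctx.Err(), context.Canceled)) // true: stop cleanly
}
```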
```go
	c.onModelFallback(c.models[i-1], model, err)
}

for retry := range c.retries {
```
[MEDIUM] Off-by-one in retry loop: `for retry := range c.retries` gives `c.retries` attempts, not `c.retries + 1` as documented
The `WithRetries` option docs say:

> The total attempts per model is n + 1.

But `for retry := range c.retries` (Go 1.22+ range-over-integer) iterates from 0 to `c.retries-1`, producing exactly `c.retries` total attempts — one fewer than documented.
Examples:
- `WithRetries(1)` → 1 attempt (expected: 2)
- `WithRetries(2)` → 2 attempts (expected: 3)
- Default (`c.retries = 1`) → 1 attempt total, meaning zero retries on failure

Fix — change the loop to iterate `c.retries + 1` times:

```go
for retry := range c.retries + 1 {
	...
}
```

Or update the docs to state "n total attempts" if single-attempt behavior is intentional.
```go
// Handle tool call deltas
if len(choice.Delta.ToolCalls) > 0 {
	for _, delta := range choice.Delta.ToolCalls {
		idx, ok := toolCallIndex[delta.ID]
```
[MEDIUM] Empty-string tool call ID used as map key, merging distinct tool calls from streaming providers
`toolCallIndex[delta.ID]` uses `delta.ID` directly as a map key. In many streaming LLM protocols, the tool call ID is only sent in the first delta chunk — subsequent delta chunks for the same tool call arrive with `delta.ID == ""`. This is fine for a single tool call per stream.
However, if a streaming provider sends multiple new tool calls where more than one arrives with `delta.ID == ""` (or a provider that never sends IDs in deltas), all of them are grouped under the same empty-string key. The second new call overwrites the first call's entry in `toolCallIndex`, and their arguments get merged into a single garbled `ToolCall`.
Consider adding a guard or using the delta's `Index` field (if available from the streaming spec) as the map key:

```go
// Use index as primary key, fall back to ID
key := fmt.Sprintf("%d", delta.Index)
if delta.ID != "" {
	key = delta.ID
}
idx, ok := toolCallIndex[key]
```

Or at minimum validate that new entries (where `!ok`) have a non-empty ID, and log a warning otherwise.
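A self-contained sketch of the collision, keying purely by index for simplicity; the `ToolCallDelta` type and its fields are assumptions for illustration, not this PR's actual types:

```go
package main

import "fmt"

// ToolCallDelta is a hypothetical streaming chunk (assumed shape).
type ToolCallDelta struct {
	Index int    // position of the tool call in the stream
	ID    string // often empty on continuation chunks
	Args  string // partial JSON arguments
}

// accumulate appends argument fragments under whatever key keyFn produces.
func accumulate(deltas []ToolCallDelta, keyFn func(ToolCallDelta) string) map[string]string {
	acc := map[string]string{}
	for _, d := range deltas {
		acc[keyFn(d)] += d.Args
	}
	return acc
}

func main() {
	deltas := []ToolCallDelta{
		{Index: 0, ID: "call_a", Args: `{"query":`},
		{Index: 1, ID: "call_b", Args: `{"path":`},
		{Index: 0, Args: `"weather"}`}, // continuation, no ID
		{Index: 1, Args: `"/tmp"}`},    // continuation, no ID
	}

	// Keyed by ID alone, both continuation chunks concatenate under the ""
	// key, split off from their opening chunks:
	fmt.Println(accumulate(deltas, func(d ToolCallDelta) string { return d.ID }))

	// Keyed by stream index, each call's arguments accumulate intact:
	fmt.Println(accumulate(deltas, func(d ToolCallDelta) string {
		return fmt.Sprintf("%d", d.Index)
	}))
	// map[0:{"query":"weather"} 1:{"path":"/tmp"}]
}
```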
```go
slog.Debug("Generated session title", "session_id", sessionID, "title", result, "model", baseModel.ID())
return result, nil

models = append(models, titleModel)
```
[MEDIUM] `provider.CloneWithOptions` return value is not nil-checked before appending to `models`
`provider.CloneWithOptions` may return nil depending on its implementation (e.g., if the base model type doesn't support cloning). A nil `provider.Provider` interface value appended to `models` and passed to `ai.WithModels` would cause a nil-pointer dereference (panic) inside `ai.completion.stream()` when the interface method is called.
The `baseModel == nil` guard above only filters nil inputs; it does not protect against a nil output from `CloneWithOptions`.
Fix: add a nil check on the result:

```go
titleModel := provider.CloneWithOptions(ctx, baseModel, ...)
if titleModel == nil {
	continue
}
models = append(models, titleModel)
```
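For context, any method call through a nil interface value panics at runtime. A minimal standalone illustration, with a local `Provider` interface standing in for the real `provider.Provider`:

```go
package main

import "fmt"

// Provider is a stand-in for the real provider.Provider interface.
type Provider interface {
	ID() string
}

func main() {
	models := []Provider{nil} // what an unchecked nil clone result would append

	defer func() {
		fmt.Println("recovered:", recover()) // nil pointer dereference
	}()
	fmt.Println(models[0].ID()) // panics: method call through a nil interface
}
```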
See #2409