diff --git a/CHANGELOG.md b/CHANGELOG.md index 9959f87..2a8f878 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,40 @@ ### Added +- **Token usage and execution duration emission (issue #87, FWS-3).** + Every `llm_call` audit event now carries `input_tokens`, + `output_tokens`, `model`, `provider`, `duration_ms`, and `request_id` + captured directly from provider response metadata (Anthropic, OpenAI, + Ollama via the OpenAI-compatible path, OpenAI Responses). Field + naming aligns with OTel GenAI semantic conventions + (`gen_ai.usage.input_tokens` / `gen_ai.usage.output_tokens`) so audit + consumers can correlate to OTel traces without a translation table. + When a provider returns no usage (some self-hosted Ollama setups), + the event flags `tokens_unavailable: true` rather than silent zeros. + Each `tool_exec` event gains `duration_ms` plus structured arg-shape + metadata (`args_size`, `result_size`) — raw arg values are not + emitted (payload stripping is FWS-8's concern). A new + `invocation_complete` event closes every A2A invocation with total + wall-clock duration and aggregated `input_tokens_total` / + `output_tokens_total` / `llm_call_count`. A2A responses now carry + the same totals inline as `X-Forge-Tokens-In`, `X-Forge-Tokens-Out`, + `X-Forge-Duration-Ms`, `X-Forge-Model`, `X-Forge-Provider` headers + so orchestrators can enforce cost ceilings during parallel workflow + execution without subscribing to the audit stream. Headers populate + regardless of OTel-tracing state. Cost calculation is deliberately + not in Forge — Forge emits tokens, the platform applies price tables. + The new emitters route through `AuditLogger.EmitFromContext` so + workflow-correlation fields (FWS-2) auto-tag every `llm_call` / + `tool_exec` / `invocation_complete` event when the inbound request + carried orchestrator headers. Schema additivity: existing audit + consumers reading the pre-FWS-3 shape continue to work unchanged. See + `docs/security/audit-logging.md#token-usage-and-execution-duration`. + + Internal API change as part of this work: `llm.UsageInfo` field + names were renamed `PromptTokens` → `InputTokens` and + `CompletionTokens` → `OutputTokens` (JSON tags too) to align with + OTel GenAI semconv. The type is internal to `forge-core/llm` and not + consumed outside that package, so no external callers are affected. - **Workflow correlation ID threading (issue #86, FWS-2).** Forge agents now extract orchestration headers — `X-Workflow-ID`, `X-Workflow-Stage-ID`, `X-Workflow-Step-ID`, `X-Invocation-Caller` — diff --git a/docs/security/audit-logging.md b/docs/security/audit-logging.md index 5eb25d2..5324f4d 100644 --- a/docs/security/audit-logging.md +++ b/docs/security/audit-logging.md @@ -17,7 +17,9 @@ All runtime security events are emitted as structured NDJSON to stderr with corr | `tool_exec` | Tool execution start/end (with tool name) | | `egress_allowed` | Outbound request allowed (with domain, mode) | | `egress_blocked` | Outbound request blocked (with domain, mode) | -| `llm_call` | LLM API call completed (with token count) | +| `llm_call` | LLM API call completed (with `input_tokens`, `output_tokens`, `model`, `provider`, `duration_ms`, `request_id`). See [Token usage and duration](#token-usage-and-execution-duration). | +| `llm_call_cancelled` | Streaming LLM call cancelled mid-flight; carries partial token counts captured up to cancellation. | +| `invocation_complete` | A2A invocation finished (auth → dispatch → engine → response). Carries `duration_ms` (wall-clock) plus aggregated `input_tokens_total` / `output_tokens_total` / `llm_call_count` / `model` / `provider`. | | `guardrail_check` | Guardrail evaluation result | | `auth_verify` | Inbound request authenticated successfully (with `provider`, `user_id`, `org_id`, `token_kind`) | | `auth_fail` | Inbound request rejected (with `reason`, `token_kind`) | @@ -39,6 +41,53 @@ The `source` field distinguishes in-process enforcer events from subprocess prox When the inbound A2A request carries the orchestrator's correlation headers (`X-Workflow-ID`, `X-Workflow-Stage-ID`, `X-Workflow-Step-ID`, `X-Invocation-Caller`), every audit event emitted during that invocation is tagged with the matching `workflow_id` / `stage_id` / `step_id` / `invocation_caller` fields. Header names are vendor-neutral so any A2A-compatible orchestrator can populate them. Direct A2A invocations (no orchestrator) omit the fields entirely — emitted JSON is byte-identical to the pre-correlation shape. See [Workflow correlation IDs](workflow-correlation.md) for the full reference, including outbound propagation for agent-to-agent flows. +### Token usage and execution duration + +Every `llm_call` audit event carries the normalized token counts the provider returned in its response metadata, plus the wall-clock time spent in the provider call. Field naming aligns with [OTel GenAI semantic conventions](https://opentelemetry.io/docs/specs/semconv/gen-ai/) (`gen_ai.usage.input_tokens` / `gen_ai.usage.output_tokens`) so audit consumers can correlate Forge audit events with OTel traces without a translation table. + +```json +{ + "ts": "2026-06-04T15:21:09Z", + "event": "llm_call", + "correlation_id": "9b3d…", + "task_id": "task-42", + "model": "claude-sonnet-4-6", + "provider": "anthropic", + "input_tokens": 1240, + "output_tokens": 387, + "duration_ms": 2150, + "request_id": "msg_01H8…" +} +``` + +| Field | Source | Notes | +|---|---|---| +| `input_tokens` | Provider response usage | Maps to `gen_ai.usage.input_tokens` | +| `output_tokens` | Provider response usage | Maps to `gen_ai.usage.output_tokens` | +| `tokens_unavailable` | Audit emitter | `true` when both counts are zero — some self-hosted Ollama setups don't return usage; billing consumers must distinguish "not measured" from "zero tokens used" | +| `model` | Runtime model config | The model identifier the executor was configured with | +| `provider` | Runtime model config | One of `anthropic`, `openai`, `ollama`, `custom` | +| `duration_ms` | Captured at call site | Wall-clock time spent in `client.Chat`, in milliseconds | +| `request_id` | Provider response | Opaque provider call ID (Anthropic `id`, OpenAI `id`) — debug-correlation handle only, never used for billing | + +Each `tool_exec` event (phase=end) carries `duration_ms` for the tool execution plus structured arg-shape metadata (`args_size`, `result_size`) — raw arg values are deliberately not included (payload stripping is FWS-8's concern). One `invocation_complete` event closes each A2A invocation with the total wall-clock duration and aggregated token totals across all LLM calls in the invocation. + +Workflow correlation fields (`workflow_id` / `stage_id` / `step_id` / `invocation_caller` from FWS-2) also auto-tag every `llm_call` / `tool_exec` / `invocation_complete` event when the inbound request carried orchestrator headers — billing and audit consumers can attribute cost not just to a task but to a specific workflow run / stage / step. + +A2A response headers carry the same per-invocation totals inline so an orchestrator can ceiling-check cost during parallel workflow execution without subscribing to the audit stream: + +| Header | Value | +|---|---| +| `X-Forge-Tokens-In` | Sum of `input_tokens` across all LLM calls in the invocation | +| `X-Forge-Tokens-Out` | Sum of `output_tokens` across all LLM calls in the invocation | +| `X-Forge-Duration-Ms` | Wall-clock invocation duration (auth → dispatch → engine → response) | +| `X-Forge-Model` | Most-recently-used model | +| `X-Forge-Provider` | Most-recently-used provider | + +Headers populate regardless of whether OTel tracing is enabled — they're the orchestration channel, not the observability channel. + +**Cost calculation is deliberately not in Forge.** Forge emits token counts; the platform applies price tables to compute dollar amounts. Price tables change frequently and shouldn't require agent redeploys. + ### Authentication events Every inbound request to `/tasks` emits exactly one of `auth_verify` or `auth_fail`. diff --git a/forge-cli/runtime/forge_usage_headers.go b/forge-cli/runtime/forge_usage_headers.go new file mode 100644 index 0000000..7f1d2e7 --- /dev/null +++ b/forge-cli/runtime/forge_usage_headers.go @@ -0,0 +1,44 @@ +package runtime + +import ( + "net/http" + "strconv" + + coreruntime "github.com/initializ/forge/forge-core/runtime" +) + +// A2A response header names for per-invocation cost telemetry. These +// are the inline channel for orchestrator real-time cost enforcement +// during parallel workflow execution — the orchestrator can ceiling-check +// against running totals before the next stage dispatches. They populate +// regardless of whether OTel tracing is enabled. See issue #87 / FWS-3. +const ( + HeaderForgeTokensIn = "X-Forge-Tokens-In" + HeaderForgeTokensOut = "X-Forge-Tokens-Out" + HeaderForgeDurationMs = "X-Forge-Duration-Ms" + HeaderForgeModel = "X-Forge-Model" + HeaderForgeProvider = "X-Forge-Provider" +) + +// applyForgeUsageHeaders stamps the X-Forge-* invocation-usage headers +// onto the given http.Header from a usage snapshot. Headers are omitted +// for snapshots with zero LLM calls (e.g. guardrail-failed invocations +// that never reached the LLM) so the response shape mirrors what +// actually happened. +func applyForgeUsageHeaders(h http.Header, snap coreruntime.LLMUsageSnapshot) { + if snap.LLMCallCount == 0 { + // Still stamp duration so orchestrators always see a wall-clock + // figure even for short-circuited invocations. + h.Set(HeaderForgeDurationMs, strconv.FormatInt(snap.InvocationDuration.Milliseconds(), 10)) + return + } + h.Set(HeaderForgeTokensIn, strconv.Itoa(snap.InputTokens)) + h.Set(HeaderForgeTokensOut, strconv.Itoa(snap.OutputTokens)) + h.Set(HeaderForgeDurationMs, strconv.FormatInt(snap.InvocationDuration.Milliseconds(), 10)) + if snap.PrimaryModel != "" { + h.Set(HeaderForgeModel, snap.PrimaryModel) + } + if snap.PrimaryProvider != "" { + h.Set(HeaderForgeProvider, snap.PrimaryProvider) + } +} diff --git a/forge-cli/runtime/forge_usage_headers_test.go b/forge-cli/runtime/forge_usage_headers_test.go new file mode 100644 index 0000000..c768275 --- /dev/null +++ b/forge-cli/runtime/forge_usage_headers_test.go @@ -0,0 +1,80 @@ +package runtime + +import ( + "net/http" + "testing" + "time" + + coreruntime "github.com/initializ/forge/forge-core/runtime" +) + +// Regression tests for issue #87 / FWS-3 — X-Forge-* response header +// emission. Headers are the inline channel for orchestrator real-time +// cost enforcement; they populate regardless of whether OTel tracing +// is enabled. + +func TestApplyForgeUsageHeaders_StampsAllFields(t *testing.T) { + h := http.Header{} + applyForgeUsageHeaders(h, coreruntime.LLMUsageSnapshot{ + InputTokens: 450, + OutputTokens: 180, + InvocationDuration: 1234 * time.Millisecond, + PrimaryModel: "claude-sonnet-4-6", + PrimaryProvider: "anthropic", + LLMCallCount: 3, + }) + + if h.Get(HeaderForgeTokensIn) != "450" { + t.Errorf("X-Forge-Tokens-In = %q, want 450", h.Get(HeaderForgeTokensIn)) + } + if h.Get(HeaderForgeTokensOut) != "180" { + t.Errorf("X-Forge-Tokens-Out = %q, want 180", h.Get(HeaderForgeTokensOut)) + } + if h.Get(HeaderForgeDurationMs) != "1234" { + t.Errorf("X-Forge-Duration-Ms = %q, want 1234", h.Get(HeaderForgeDurationMs)) + } + if h.Get(HeaderForgeModel) != "claude-sonnet-4-6" { + t.Errorf("X-Forge-Model = %q, want claude-sonnet-4-6", h.Get(HeaderForgeModel)) + } + if h.Get(HeaderForgeProvider) != "anthropic" { + t.Errorf("X-Forge-Provider = %q, want anthropic", h.Get(HeaderForgeProvider)) + } +} + +func TestApplyForgeUsageHeaders_NoLLMCalls_StillStampsDuration(t *testing.T) { + // Short-circuited invocation (guardrail-failed before LLM dispatch): + // orchestrator still wants a wall-clock figure, but token fields + // would mislead — emit duration only. + h := http.Header{} + applyForgeUsageHeaders(h, coreruntime.LLMUsageSnapshot{ + InvocationDuration: 5 * time.Millisecond, + LLMCallCount: 0, + }) + + if h.Get(HeaderForgeDurationMs) != "5" { + t.Errorf("X-Forge-Duration-Ms must still be stamped on short-circuited invocations, got %q", h.Get(HeaderForgeDurationMs)) + } + if h.Get(HeaderForgeTokensIn) != "" || h.Get(HeaderForgeTokensOut) != "" { + t.Errorf("token headers must NOT be stamped when no LLM calls happened, got in=%q out=%q", + h.Get(HeaderForgeTokensIn), h.Get(HeaderForgeTokensOut)) + } +} + +func TestApplyForgeUsageHeaders_OmitsModelProviderWhenAbsent(t *testing.T) { + // Edge case: LLM call happened but provider/model were empty (no + // runtime attribution available). Stamp tokens + duration only — + // don't stamp empty model/provider values. + h := http.Header{} + applyForgeUsageHeaders(h, coreruntime.LLMUsageSnapshot{ + InputTokens: 50, + OutputTokens: 25, + InvocationDuration: 100 * time.Millisecond, + LLMCallCount: 1, + }) + if _, present := h[http.CanonicalHeaderKey(HeaderForgeModel)]; present { + t.Errorf("X-Forge-Model must be omitted when PrimaryModel is empty") + } + if _, present := h[http.CanonicalHeaderKey(HeaderForgeProvider)]; present { + t.Errorf("X-Forge-Provider must be omitted when PrimaryProvider is empty") + } +} diff --git a/forge-cli/runtime/runner.go b/forge-cli/runtime/runner.go index 0b7e747..59f3583 100644 --- a/forge-cli/runtime/runner.go +++ b/forge-cli/runtime/runner.go @@ -586,6 +586,7 @@ func (r *Runner) Run(ctx context.Context) error { SystemPrompt: sysPrompt, Logger: r.logger, ModelName: mc.Client.Model, + Provider: mc.Provider, MaxIterations: 100, CharBudget: charBudget, FilesDir: filepath.Join(r.cfg.WorkDir, ".forge", "files"), @@ -791,129 +792,25 @@ func (r *Runner) Run(ctx context.Context) error { func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.AgentExecutor, guardrails coreruntime.GuardrailChecker, egressClient *http.Client, auditLogger *coreruntime.AuditLogger) { store := srv.TaskStore() - // tasks/send — synchronous request + // tasks/send — synchronous request. Delegates to executeTask so the + // JSON-RPC path goes through the same audit + accumulator wiring as + // REST POST /tasks/send. See issue #87 / FWS-3. srv.RegisterHandler("tasks/send", func(ctx context.Context, id any, rawParams json.RawMessage) *a2a.JSONRPCResponse { var params a2a.SendTaskParams if err := json.Unmarshal(rawParams, ¶ms); err != nil { return a2a.NewErrorResponse(id, a2a.ErrCodeInvalidParams, "invalid params: "+err.Error()) } - r.logger.Info("tasks/send", map[string]any{"task_id": params.ID}) - - // Inject egress client and correlation/task IDs into context - correlationID := coreruntime.GenerateID() - ctx = security.WithEgressClient(ctx, egressClient) - ctx = coreruntime.WithCorrelationID(ctx, correlationID) - ctx = coreruntime.WithTaskID(ctx, params.ID) - - auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ - Event: coreruntime.AuditSessionStart, - CorrelationID: correlationID, - TaskID: params.ID, - }) - - // Load existing task to preserve conversation history, or create new. - task := store.Get(params.ID) - if task == nil { - task = &a2a.Task{ID: params.ID} - } - task.Status = a2a.TaskStatus{State: a2a.TaskStateSubmitted} - store.Put(task) - - // Guardrail check inbound - if err := guardrails.CheckInbound(¶ms.Message); err != nil { - task.Status = a2a.TaskStatus{ - State: a2a.TaskStateFailed, - Message: &a2a.Message{ - Role: a2a.MessageRoleAgent, - Parts: []a2a.Part{a2a.NewTextPart("Guardrail violation: " + err.Error())}, - }, - } - store.Put(task) - auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ - Event: coreruntime.AuditSessionEnd, - CorrelationID: correlationID, - TaskID: params.ID, - Fields: map[string]any{"state": string(a2a.TaskStateFailed)}, - }) - return a2a.NewResponse(id, task) - } - - // Append inbound user message to task history. - task.History = append(task.History, params.Message) - - // Update to working - task.Status = a2a.TaskStatus{State: a2a.TaskStateWorking} - store.Put(task) - - // Execute via executor - respMsg, err := executor.Execute(ctx, task, ¶ms.Message) + // Delegate to executeTask so JSON-RPC and REST share the same + // audit + accumulator + invocation_complete wiring (issue #87 / + // FWS-3). The dispatcher already injected WorkflowContext into + // ctx from inbound headers per issue #86 / FWS-2, so every audit + // event executeTask emits carries workflow correlation fields + // when present. + task, _, err := r.executeTask(ctx, params, store, executor, guardrails, egressClient, auditLogger) if err != nil { - r.logger.Error("execute failed", map[string]any{"task_id": params.ID, "error": err.Error()}) - task.Status = a2a.TaskStatus{ - State: a2a.TaskStateFailed, - Message: &a2a.Message{ - Role: a2a.MessageRoleAgent, - Parts: []a2a.Part{a2a.NewTextPart(err.Error())}, - }, - } - store.Put(task) - auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ - Event: coreruntime.AuditSessionEnd, - CorrelationID: correlationID, - TaskID: params.ID, - Fields: map[string]any{"state": string(a2a.TaskStateFailed)}, - }) - return a2a.NewResponse(id, task) - } - - // Guardrail check outbound - if respMsg != nil { - if err := guardrails.CheckOutbound(respMsg); err != nil { - task.Status = a2a.TaskStatus{ - State: a2a.TaskStateFailed, - Message: &a2a.Message{ - Role: a2a.MessageRoleAgent, - Parts: []a2a.Part{a2a.NewTextPart("Outbound guardrail violation: " + err.Error())}, - }, - } - store.Put(task) - auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ - Event: coreruntime.AuditSessionEnd, - CorrelationID: correlationID, - TaskID: params.ID, - Fields: map[string]any{"state": string(a2a.TaskStateFailed)}, - }) - return a2a.NewResponse(id, task) - } - } - - // Append agent response to task history. - if respMsg != nil { - task.History = append(task.History, *respMsg) + return a2a.NewErrorResponse(id, a2a.ErrCodeInternal, err.Error()) } - - // Build completed task - task.Status = a2a.TaskStatus{ - State: a2a.TaskStateCompleted, - Message: respMsg, - } - if respMsg != nil { - task.Artifacts = []a2a.Artifact{ - { - Name: "response", - Parts: respMsg.Parts, - }, - } - } - store.Put(task) - auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ - Event: coreruntime.AuditSessionEnd, - CorrelationID: correlationID, - TaskID: params.ID, - Fields: map[string]any{"state": string(task.Status.State)}, - }) - r.logger.Info("task completed", map[string]any{"task_id": params.ID, "state": string(task.Status.State)}) return a2a.NewResponse(id, task) }) @@ -927,11 +824,33 @@ func (r *Runner) registerHandlers(srv *server.Server, executor coreruntime.Agent r.logger.Info("tasks/sendSubscribe", map[string]any{"task_id": params.ID}) - // Inject egress client and correlation/task IDs into context + // Inject egress client, correlation/task IDs, and per-invocation + // usage accumulator (issue #87 / FWS-3) into context. The + // accumulator lets the AfterLLMCall hook fold each call's + // tokens/duration into running totals for the invocation_complete + // audit event emitted before this handler returns. correlationID := coreruntime.GenerateID() ctx = security.WithEgressClient(ctx, egressClient) ctx = coreruntime.WithCorrelationID(ctx, correlationID) ctx = coreruntime.WithTaskID(ctx, params.ID) + sseAcc := coreruntime.NewLLMUsageAccumulator() + ctx = coreruntime.WithLLMUsageAccumulator(ctx, sseAcc) + defer func() { + snap := sseAcc.Snapshot() + fields := map[string]any{} + if snap.LLMCallCount > 0 { + fields["input_tokens_total"] = snap.InputTokens + fields["output_tokens_total"] = snap.OutputTokens + fields["llm_call_count"] = snap.LLMCallCount + if snap.PrimaryModel != "" { + fields["model"] = snap.PrimaryModel + } + if snap.PrimaryProvider != "" { + fields["provider"] = snap.PrimaryProvider + } + } + auditLogger.EmitInvocationComplete(ctx, snap.InvocationDuration, fields) + }() auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ Event: coreruntime.AuditSessionStart, @@ -1102,13 +1021,19 @@ func (r *Runner) executeTask( guardrails coreruntime.GuardrailChecker, egressClient *http.Client, auditLogger *coreruntime.AuditLogger, -) (*a2a.Task, error) { +) (*a2a.Task, coreruntime.LLMUsageSnapshot, error) { correlationID := coreruntime.GenerateID() ctx = security.WithEgressClient(ctx, egressClient) ctx = coreruntime.WithCorrelationID(ctx, correlationID) ctx = coreruntime.WithTaskID(ctx, params.ID) - - auditLogger.Emit(coreruntime.AuditEvent{ + // Per-invocation usage accumulator so AfterLLMCall hooks can fold + // each call's tokens/duration into running totals the response + // handler reads back for X-Forge-* headers + the + // invocation_complete audit event. See issue #87 / FWS-3. + acc := coreruntime.NewLLMUsageAccumulator() + ctx = coreruntime.WithLLMUsageAccumulator(ctx, acc) + + auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ Event: coreruntime.AuditSessionStart, CorrelationID: correlationID, TaskID: params.ID, @@ -1121,6 +1046,23 @@ func (r *Runner) executeTask( task.Status = a2a.TaskStatus{State: a2a.TaskStateSubmitted} store.Put(task) + emitInvocationComplete := func() { + snap := acc.Snapshot() + fields := map[string]any{"state": string(task.Status.State)} + if snap.LLMCallCount > 0 { + fields["input_tokens_total"] = snap.InputTokens + fields["output_tokens_total"] = snap.OutputTokens + fields["llm_call_count"] = snap.LLMCallCount + if snap.PrimaryModel != "" { + fields["model"] = snap.PrimaryModel + } + if snap.PrimaryProvider != "" { + fields["provider"] = snap.PrimaryProvider + } + } + auditLogger.EmitInvocationComplete(ctx, snap.InvocationDuration, fields) + } + if err := guardrails.CheckInbound(¶ms.Message); err != nil { task.Status = a2a.TaskStatus{ State: a2a.TaskStateFailed, @@ -1130,13 +1072,14 @@ func (r *Runner) executeTask( }, } store.Put(task) - auditLogger.Emit(coreruntime.AuditEvent{ + auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ Event: coreruntime.AuditSessionEnd, CorrelationID: correlationID, TaskID: params.ID, Fields: map[string]any{"state": string(a2a.TaskStateFailed)}, }) - return task, nil + emitInvocationComplete() + return task, acc.Snapshot(), nil } task.History = append(task.History, params.Message) @@ -1154,13 +1097,14 @@ func (r *Runner) executeTask( }, } store.Put(task) - auditLogger.Emit(coreruntime.AuditEvent{ + auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ Event: coreruntime.AuditSessionEnd, CorrelationID: correlationID, TaskID: params.ID, Fields: map[string]any{"state": string(a2a.TaskStateFailed)}, }) - return task, nil + emitInvocationComplete() + return task, acc.Snapshot(), nil } if respMsg != nil { @@ -1179,7 +1123,8 @@ func (r *Runner) executeTask( TaskID: params.ID, Fields: map[string]any{"state": string(a2a.TaskStateFailed)}, }) - return task, nil + emitInvocationComplete() + return task, acc.Snapshot(), nil } } @@ -1200,14 +1145,15 @@ func (r *Runner) executeTask( } } store.Put(task) - auditLogger.Emit(coreruntime.AuditEvent{ + auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ Event: coreruntime.AuditSessionEnd, CorrelationID: correlationID, TaskID: params.ID, Fields: map[string]any{"state": string(task.Status.State)}, }) + emitInvocationComplete() r.logger.Info("task completed", map[string]any{"task_id": params.ID, "state": string(task.Status.State)}) - return task, nil + return task, acc.Snapshot(), nil } // restTaskRequest is the simplified JSON body for REST task endpoints. @@ -1238,17 +1184,18 @@ func (r *Runner) registerRESTHandlers(srv *server.Server, executor coreruntime.A Message: body.Task.Message, } - // Pull workflow correlation headers (issue #86) so audit + // Pull workflow correlation headers (issue #86 / FWS-2) so audit // events tagged via EmitFromContext carry the orchestrator's // workflow/stage/step identifiers. Absent headers → IsZero // WorkflowContext → fields omitted (backward compat). ctx := coreruntime.WithWorkflowContext(req.Context(), coreruntime.WorkflowContextFromHTTPHeaders(req.Header)) - task, err := r.executeTask(ctx, params, store, executor, guardrails, egressClient, auditLogger) + task, snap, err := r.executeTask(ctx, params, store, executor, guardrails, egressClient, auditLogger) if err != nil { writeJSON(w, http.StatusInternalServerError, map[string]string{"error": err.Error()}) return } + applyForgeUsageHeaders(w.Header(), snap) writeJSON(w, http.StatusOK, task) }) @@ -1282,8 +1229,31 @@ func (r *Runner) registerRESTHandlers(srv *server.Server, executor coreruntime.A ctx := security.WithEgressClient(req.Context(), egressClient) ctx = coreruntime.WithCorrelationID(ctx, correlationID) ctx = coreruntime.WithTaskID(ctx, params.ID) + // Pull workflow correlation headers (issue #86 / FWS-2) before + // the accumulator setup so invocation_complete inherits workflow + // tagging via EmitFromContext. ctx = coreruntime.WithWorkflowContext(ctx, coreruntime.WorkflowContextFromHTTPHeaders(req.Header)) + // Per-invocation usage accumulator + invocation_complete on exit. + // See issue #87 / FWS-3. + restSSEAcc := coreruntime.NewLLMUsageAccumulator() + ctx = coreruntime.WithLLMUsageAccumulator(ctx, restSSEAcc) + defer func() { + snap := restSSEAcc.Snapshot() + fields := map[string]any{} + if snap.LLMCallCount > 0 { + fields["input_tokens_total"] = snap.InputTokens + fields["output_tokens_total"] = snap.OutputTokens + fields["llm_call_count"] = snap.LLMCallCount + if snap.PrimaryModel != "" { + fields["model"] = snap.PrimaryModel + } + if snap.PrimaryProvider != "" { + fields["provider"] = snap.PrimaryProvider + } + } + auditLogger.EmitInvocationComplete(ctx, snap.InvocationDuration, fields) + }() auditLogger.EmitFromContext(ctx, coreruntime.AuditEvent{ Event: coreruntime.AuditSessionStart, @@ -1549,29 +1519,48 @@ func (r *Runner) registerAuditHooks(hooks *coreruntime.HookRegistry, auditLogger if hctx.Error != nil { fields["error"] = hctx.Error.Error() } + // Structured arg-shape metadata (sizes only — never raw values; + // raw-arg-value emission is FWS-8's payload-stripping concern, + // not FWS-3's). See issue #87 / FWS-3. + if hctx.ToolInput != "" { + fields["args_size"] = len(hctx.ToolInput) + } + if hctx.ToolOutput != "" { + fields["result_size"] = len(hctx.ToolOutput) + } + ms := hctx.ToolExecDuration.Milliseconds() auditLogger.Emit(coreruntime.AuditEvent{ Event: coreruntime.AuditToolExec, CorrelationID: hctx.CorrelationID, TaskID: hctx.TaskID, + DurationMs: &ms, Fields: fields, }) return nil }) - hooks.Register(coreruntime.AfterLLMCall, func(_ context.Context, hctx *coreruntime.HookContext) error { - fields := map[string]any{} - if hctx.Response != nil && hctx.Response.Usage.TotalTokens > 0 { - fields["tokens"] = hctx.Response.Usage.TotalTokens - } - if r.modelConfig != nil && r.modelConfig.Client.OrgID != "" { - fields["organization_id"] = r.modelConfig.Client.OrgID - } - auditLogger.Emit(coreruntime.AuditEvent{ - Event: coreruntime.AuditLLMCall, - CorrelationID: hctx.CorrelationID, - TaskID: hctx.TaskID, - Fields: fields, + hooks.Register(coreruntime.AfterLLMCall, func(ctx context.Context, hctx *coreruntime.HookContext) error { + var usage coreruntime.LLMUsage + var requestID string + if hctx.Response != nil { + usage.InputTokens = hctx.Response.Usage.InputTokens + usage.OutputTokens = hctx.Response.Usage.OutputTokens + usage.TotalTokens = hctx.Response.Usage.TotalTokens + requestID = hctx.Response.ID + } + auditLogger.EmitLLMCall(ctx, coreruntime.LLMCallAuditArgs{ + Model: hctx.Model, + Provider: hctx.Provider, + RequestID: requestID, + Usage: usage, + Duration: hctx.LLMCallDuration, }) + // Accumulate per-invocation usage totals so the response handler + // can populate X-Forge-Tokens-In/Out + X-Forge-Duration-Ms + + // X-Forge-Model + X-Forge-Provider headers. See issue #87 / FWS-3. + if acc := coreruntime.LLMUsageAccumulatorFromContext(ctx); acc != nil { + acc.AddLLMCall(hctx.Model, hctx.Provider, usage, hctx.LLMCallDuration) + } return nil }) } diff --git a/forge-core/llm/providers/anthropic.go b/forge-core/llm/providers/anthropic.go index 28f20ef..ae128ce 100644 --- a/forge-core/llm/providers/anthropic.go +++ b/forge-core/llm/providers/anthropic.go @@ -269,9 +269,9 @@ func (c *AnthropicClient) parseAnthropicResponse(body io.Reader) (*llm.ChatRespo ID: resp.ID, Message: msg, Usage: llm.UsageInfo{ - PromptTokens: resp.Usage.InputTokens, - CompletionTokens: resp.Usage.OutputTokens, - TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, + InputTokens: resp.Usage.InputTokens, + OutputTokens: resp.Usage.OutputTokens, + TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens, }, FinishReason: finishReason, }, nil @@ -369,7 +369,7 @@ func (c *AnthropicClient) readAnthropicStream(r io.Reader, ch chan<- llm.StreamD ch <- llm.StreamDelta{ FinishReason: finishReason, Usage: &llm.UsageInfo{ - CompletionTokens: ev.Usage.OutputTokens, + OutputTokens: ev.Usage.OutputTokens, }, } diff --git a/forge-core/llm/providers/openai.go b/forge-core/llm/providers/openai.go index a7315fc..9d7c79c 100644 --- a/forge-core/llm/providers/openai.go +++ b/forge-core/llm/providers/openai.go @@ -220,9 +220,9 @@ func (c *OpenAIClient) parseOpenAIResponse(body io.Reader) (*llm.ChatResponse, e ToolCalls: choice.Message.ToolCalls, }, Usage: llm.UsageInfo{ - PromptTokens: resp.Usage.PromptTokens, - CompletionTokens: resp.Usage.CompletionTokens, - TotalTokens: resp.Usage.TotalTokens, + InputTokens: resp.Usage.PromptTokens, + OutputTokens: resp.Usage.CompletionTokens, + TotalTokens: resp.Usage.TotalTokens, }, FinishReason: choice.FinishReason, }, nil @@ -273,9 +273,9 @@ func (c *OpenAIClient) readSSEStream(r io.Reader, ch chan<- llm.StreamDelta) { } if chunk.Usage != nil { delta.Usage = &llm.UsageInfo{ - PromptTokens: chunk.Usage.PromptTokens, - CompletionTokens: chunk.Usage.CompletionTokens, - TotalTokens: chunk.Usage.TotalTokens, + InputTokens: chunk.Usage.PromptTokens, + OutputTokens: chunk.Usage.CompletionTokens, + TotalTokens: chunk.Usage.TotalTokens, } } ch <- delta diff --git a/forge-core/llm/providers/openai_embedder.go b/forge-core/llm/providers/openai_embedder.go index 15a5b37..67565de 100644 --- a/forge-core/llm/providers/openai_embedder.go +++ b/forge-core/llm/providers/openai_embedder.go @@ -121,8 +121,8 @@ func (e *OpenAIEmbedder) Embed(ctx context.Context, req *llm.EmbeddingRequest) ( Embeddings: embeddings, Model: embResp.Model, Usage: llm.UsageInfo{ - PromptTokens: embResp.Usage.PromptTokens, - TotalTokens: embResp.Usage.TotalTokens, + InputTokens: embResp.Usage.PromptTokens, + TotalTokens: embResp.Usage.TotalTokens, }, }, nil } diff --git a/forge-core/llm/providers/responses.go b/forge-core/llm/providers/responses.go index 410561a..b2f1bf4 100644 --- a/forge-core/llm/providers/responses.go +++ b/forge-core/llm/providers/responses.go @@ -420,9 +420,9 @@ func (c *ResponsesClient) readStream(r io.Reader, ch chan<- llm.StreamDelta) { delta := llm.StreamDelta{Done: true} if ev.Response.Usage != nil { delta.Usage = &llm.UsageInfo{ - PromptTokens: ev.Response.Usage.InputTokens, - CompletionTokens: ev.Response.Usage.OutputTokens, - TotalTokens: ev.Response.Usage.TotalTokens, + InputTokens: ev.Response.Usage.InputTokens, + OutputTokens: ev.Response.Usage.OutputTokens, + TotalTokens: ev.Response.Usage.TotalTokens, } } // Determine finish reason from output diff --git a/forge-core/llm/providers/usage_extraction_test.go b/forge-core/llm/providers/usage_extraction_test.go new file mode 100644 index 0000000..c21f7f8 --- /dev/null +++ b/forge-core/llm/providers/usage_extraction_test.go @@ -0,0 +1,126 @@ +package providers + +import ( + "context" + "encoding/json" + "io" + "net/http" + "net/http/httptest" + "testing" + + "github.com/initializ/forge/forge-core/llm" +) + +// Regression tests for issue #87 / FWS-3 — every provider must +// populate the normalized UsageInfo.InputTokens / OutputTokens / +// TotalTokens from its native response shape so the audit layer can +// emit accurate llm_call events regardless of which provider served +// the call. The OpenAI-compatible path also serves Ollama. + +func TestAnthropic_PopulatesUsageWithOTelAlignedNames(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "msg_test", + "content": []map[string]any{{"type": "text", "text": "ok"}}, + "stop_reason": "end_turn", + "usage": map[string]int{"input_tokens": 42, "output_tokens": 17}, + }) + })) + defer srv.Close() + + c := NewAnthropicClient(llm.ClientConfig{APIKey: "x", BaseURL: srv.URL, Model: "claude-3-5-sonnet"}) + resp, err := c.Chat(context.Background(), &llm.ChatRequest{ + Model: "claude-3-5-sonnet", + Messages: []llm.ChatMessage{{Role: llm.RoleUser, Content: "hi"}}, + }) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.Usage.InputTokens != 42 { + t.Errorf("InputTokens = %d, want 42", resp.Usage.InputTokens) + } + if resp.Usage.OutputTokens != 17 { + t.Errorf("OutputTokens = %d, want 17", resp.Usage.OutputTokens) + } + if resp.Usage.TotalTokens != 59 { + t.Errorf("TotalTokens = %d, want 59 (Anthropic doesn't return total — provider computes input+output)", resp.Usage.TotalTokens) + } +} + +func TestOpenAI_PopulatesUsageWithOTelAlignedNames(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "chatcmpl-1", + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{"role": "assistant", "content": "ok"}, + "finish_reason": "stop", + }, + }, + "usage": map[string]int{ + "prompt_tokens": 7, + "completion_tokens": 3, + "total_tokens": 10, + }, + }) + })) + defer srv.Close() + + c := NewOpenAIClient(llm.ClientConfig{APIKey: "x", BaseURL: srv.URL, Model: "gpt-4o-mini"}) + resp, err := c.Chat(context.Background(), &llm.ChatRequest{ + Model: "gpt-4o-mini", + Messages: []llm.ChatMessage{{Role: llm.RoleUser, Content: "hi"}}, + }) + if err != nil { + t.Fatalf("Chat: %v", err) + } + // OpenAI wire format still uses prompt_tokens / completion_tokens + // (provider-specific), but the normalized UsageInfo we expose to + // audit consumers uses OTel-aligned input_tokens / output_tokens. + if resp.Usage.InputTokens != 7 { + t.Errorf("InputTokens (mapped from prompt_tokens) = %d, want 7", resp.Usage.InputTokens) + } + if resp.Usage.OutputTokens != 3 { + t.Errorf("OutputTokens (mapped from completion_tokens) = %d, want 3", resp.Usage.OutputTokens) + } + if resp.Usage.TotalTokens != 10 { + t.Errorf("TotalTokens = %d, want 10", resp.Usage.TotalTokens) + } +} + +func TestOllama_NoUsage_LeavesZerosForAuditUnavailableFlag(t *testing.T) { + // Some self-hosted Ollama models don't include token counts in the + // response. The provider must not invent values — leave zeros so + // the audit layer flags tokens_unavailable=true on the llm_call + // event rather than billing for a free call. + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + _, _ = io.Copy(io.Discard, r.Body) + _ = json.NewEncoder(w).Encode(map[string]any{ + "id": "ollama-1", + "choices": []map[string]any{ + { + "index": 0, + "message": map[string]any{"role": "assistant", "content": "ok"}, + "finish_reason": "stop", + }, + }, + // usage field deliberately absent + }) + })) + defer srv.Close() + + c := NewOllamaClient(llm.ClientConfig{BaseURL: srv.URL, Model: "llama3"}) + resp, err := c.Chat(context.Background(), &llm.ChatRequest{ + Model: "llama3", + Messages: []llm.ChatMessage{{Role: llm.RoleUser, Content: "hi"}}, + }) + if err != nil { + t.Fatalf("Chat: %v", err) + } + if resp.Usage.InputTokens != 0 || resp.Usage.OutputTokens != 0 { + t.Errorf("usage-less response should leave zeros so audit layer sets tokens_unavailable=true, got %+v", resp.Usage) + } +} diff --git a/forge-core/llm/types.go b/forge-core/llm/types.go index faa9953..88b7577 100644 --- a/forge-core/llm/types.go +++ b/forge-core/llm/types.go @@ -76,8 +76,13 @@ type StreamDelta struct { } // UsageInfo contains token usage information. +// +// Field naming aligns with OTel GenAI semantic conventions +// (gen_ai.usage.input_tokens / gen_ai.usage.output_tokens) so audit +// consumers can correlate Forge audit events with OTel traces without +// a translation table. See issue #87 / FWS-3. type UsageInfo struct { - PromptTokens int `json:"prompt_tokens"` - CompletionTokens int `json:"completion_tokens"` - TotalTokens int `json:"total_tokens"` + InputTokens int `json:"input_tokens"` + OutputTokens int `json:"output_tokens"` + TotalTokens int `json:"total_tokens"` } diff --git a/forge-core/runtime/audit.go b/forge-core/runtime/audit.go index 0eb9a39..26a9e2b 100644 --- a/forge-core/runtime/audit.go +++ b/forge-core/runtime/audit.go @@ -51,6 +51,17 @@ const ( // and the A2A 0.3.0 spec. EventAgentCardPublished = "agent_card_published" + // Lifecycle events emitted at A2A invocation boundaries. + // AuditInvocationComplete carries total wall-clock duration_ms for + // the full invocation (auth → dispatch → engine.Execute → response). + // See issue #87 / FWS-3. + AuditInvocationComplete = "invocation_complete" + + // AuditLLMCallCancelled is emitted when a streaming LLM call is + // cancelled mid-flight; carries partial usage counts captured up to + // the cancellation point. See issue #87 / FWS-3. + AuditLLMCallCancelled = "llm_call_cancelled" + // Deprecated: use EventAuthVerify. Kept as a string alias so any // audit-log consumer that grep'd for "auth_success" can be migrated. // Scheduled for removal in v0.11.0. @@ -67,6 +78,17 @@ const ( // `X-Invocation-Caller` headers from any A2A-compatible orchestrator. // Direct A2A invocations omit them entirely so the JSON shape matches // the pre-FWS-2 audit consumers. +// +// Token usage, duration, model, and provider fields (issue #87 / FWS-3) +// are populated by the LLM call site, tool execution path, and per- +// invocation lifecycle. They use *int / *int64 pointers so the JSON +// distinguishes "field absent" (nil) from "field present with zero +// value" — important for llm_call events where zero is a legitimate +// count and TokensUnavailable signals "provider did not report usage." +// +// Field naming aligns with OTel GenAI semconv (input_tokens / +// output_tokens / duration_ms) so audit consumers can correlate Forge +// audit events with OTel traces without a translation table. type AuditEvent struct { Timestamp string `json:"ts"` Event string `json:"event"` @@ -93,6 +115,27 @@ type AuditEvent struct { // or upstream agent in an agent-to-agent flow). InvocationCaller string `json:"invocation_caller,omitempty"` + // LLM call attribution (llm_call, llm_call_cancelled, invocation_complete). + Model string `json:"model,omitempty"` + Provider string `json:"provider,omitempty"` + + // Token counts captured from provider response metadata. Nil when + // the event is not an LLM call. Non-nil with zero values + a true + // TokensUnavailable flag when the provider did not return usage + // (e.g. some self-hosted Ollama setups). + InputTokens *int `json:"input_tokens,omitempty"` + OutputTokens *int `json:"output_tokens,omitempty"` + TokensUnavailable bool `json:"tokens_unavailable,omitempty"` + + // DurationMs is the wall-clock duration in milliseconds. Populated on + // llm_call, tool_exec, and invocation_complete events. + DurationMs *int64 `json:"duration_ms,omitempty"` + + // RequestID is the provider-specific call identifier (Anthropic + // `id`, OpenAI `id`, etc.) — kept as an opaque debug-correlation + // handle, never used for cost attribution. + RequestID string `json:"request_id,omitempty"` + Fields map[string]any `json:"fields,omitempty"` } @@ -162,6 +205,106 @@ func (a *AuditLogger) EmitFromContext(ctx context.Context, event AuditEvent) { a.Emit(event) } +// LLMCallAuditArgs is the shared input to AuditLogger.EmitLLMCall. The +// LLM call site captures these fields once at provider-call completion +// and the audit logger fans them out to the llm_call NDJSON event. The +// OTel tracing work (FORGE_OTEL_TRACING.md) will hook into this same +// capture point to populate gen_ai.usage.input_tokens / +// gen_ai.usage.output_tokens span attributes without re-doing the +// per-provider extraction. See issue #87 / FWS-3. +type LLMCallAuditArgs struct { + Model string + Provider string + RequestID string + Usage LLMUsage + Duration time.Duration + // Cancelled flips the emitted event from llm_call to llm_call_cancelled. + // Used for streaming calls aborted mid-flight; partial usage counts are + // still carried. + Cancelled bool +} + +// LLMUsage carries the normalized token counts an LLM call site +// captures from provider response metadata. Mirrors llm.UsageInfo but +// kept in the runtime package so the audit layer has no llm-package +// dependency. The audit emitter sets TokensUnavailable=true on the +// event when both Input and Output are zero — signal to billing +// consumers that the provider did not report usage rather than +// "the call genuinely consumed zero tokens." +type LLMUsage struct { + InputTokens int + OutputTokens int + TotalTokens int +} + +// EmitLLMCall builds and emits an llm_call (or llm_call_cancelled) +// audit event from the captured args. Routed through EmitFromContext +// so workflow-correlation fields (workflow_id / stage_id / step_id / +// invocation_caller from FWS-2) auto-tag every LLM call event when +// the inbound request carried orchestrator headers. This is the +// shared capture point that the OTel tracing work will hook into. +// See issue #87 / FWS-3. +func (a *AuditLogger) EmitLLMCall(ctx context.Context, args LLMCallAuditArgs) { + evt := AuditEvent{ + Event: AuditLLMCall, + Model: args.Model, + Provider: args.Provider, + RequestID: args.RequestID, + } + if args.Cancelled { + evt.Event = AuditLLMCallCancelled + } + in, out := args.Usage.InputTokens, args.Usage.OutputTokens + evt.InputTokens = &in + evt.OutputTokens = &out + if in == 0 && out == 0 { + evt.TokensUnavailable = true + } + d := args.Duration.Milliseconds() + evt.DurationMs = &d + a.EmitFromContext(ctx, evt) +} + +// EmitToolExec emits a tool_exec audit event tagged with the tool +// name + wall-clock duration. Routed through EmitFromContext so +// workflow-correlation fields auto-tag every tool execution when the +// inbound request was orchestrated. The Fields map may carry +// arg-shape metadata (e.g. arg sizes, types) — raw arg values are +// deliberately not emitted here; that question is FWS-8's +// payload-stripping concern, not FWS-3's. See issue #87 / FWS-3. +func (a *AuditLogger) EmitToolExec(ctx context.Context, tool string, duration time.Duration, fields map[string]any) { + d := duration.Milliseconds() + a.EmitFromContext(ctx, AuditEvent{ + Event: AuditToolExec, + DurationMs: &d, + Fields: mergeToolExecFields(tool, fields), + }) +} + +func mergeToolExecFields(tool string, fields map[string]any) map[string]any { + if fields == nil { + fields = map[string]any{} + } + fields["tool"] = tool + fields["phase"] = "end" + return fields +} + +// EmitInvocationComplete emits an invocation_complete audit event +// carrying the total wall-clock duration of the A2A invocation +// (auth → dispatch → engine.Execute → response). Routed through +// EmitFromContext so workflow-correlation fields are inherited from +// the inbound request. One event per invocation; emitted by the +// runner at the response boundary. See issue #87 / FWS-3. +func (a *AuditLogger) EmitInvocationComplete(ctx context.Context, duration time.Duration, fields map[string]any) { + d := duration.Milliseconds() + a.EmitFromContext(ctx, AuditEvent{ + Event: AuditInvocationComplete, + DurationMs: &d, + Fields: fields, + }) +} + // Context key types for correlation IDs, task IDs, and file directories. type correlationIDKey struct{} type taskIDKey struct{} diff --git a/forge-core/runtime/audit_llm_test.go b/forge-core/runtime/audit_llm_test.go new file mode 100644 index 0000000..ce0bd44 --- /dev/null +++ b/forge-core/runtime/audit_llm_test.go @@ -0,0 +1,200 @@ +package runtime + +import ( + "bytes" + "context" + "encoding/json" + "strings" + "testing" + "time" +) + +// Regression tests for issue #87 / FWS-3 — LLM call audit emission +// must carry token / duration / model / provider / request_id with +// OTel-aligned field names, distinguish "no tokens reported" +// (TokensUnavailable=true) from "zero tokens reported," and stay +// additive over the pre-FWS-3 AuditEvent shape. + +func TestEmitLLMCall_FullUsage(t *testing.T) { + var buf bytes.Buffer + audit := NewAuditLogger(&buf) + + ctx := WithCorrelationID(context.Background(), "corr-1") + ctx = WithTaskID(ctx, "task-1") + + audit.EmitLLMCall(ctx, LLMCallAuditArgs{ + Model: "claude-sonnet-4-6", + Provider: "anthropic", + RequestID: "msg_abc", + Usage: LLMUsage{InputTokens: 100, OutputTokens: 50, TotalTokens: 150}, + Duration: 120 * time.Millisecond, + }) + + var evt AuditEvent + if err := json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &evt); err != nil { + t.Fatalf("decode: %v\n%s", err, buf.String()) + } + if evt.Event != AuditLLMCall { + t.Errorf("Event = %q, want %q", evt.Event, AuditLLMCall) + } + if evt.CorrelationID != "corr-1" || evt.TaskID != "task-1" { + t.Errorf("ctx not pulled: %+v", evt) + } + if evt.Model != "claude-sonnet-4-6" || evt.Provider != "anthropic" || evt.RequestID != "msg_abc" { + t.Errorf("attribution missing: %+v", evt) + } + if evt.InputTokens == nil || *evt.InputTokens != 100 { + t.Errorf("InputTokens want 100, got %v", evt.InputTokens) + } + if evt.OutputTokens == nil || *evt.OutputTokens != 50 { + t.Errorf("OutputTokens want 50, got %v", evt.OutputTokens) + } + if evt.TokensUnavailable { + t.Errorf("TokensUnavailable should be false when counts > 0") + } + if evt.DurationMs == nil || *evt.DurationMs != 120 { + t.Errorf("DurationMs want 120, got %v", evt.DurationMs) + } +} + +func TestEmitLLMCall_TokensUnavailable_OllamaMissingUsage(t *testing.T) { + // Self-hosted setups (some Ollama models) don't return token counts. + // EmitLLMCall must flag tokens_unavailable=true rather than emit + // silent zeros that downstream billing would mistake for a free call. + var buf bytes.Buffer + audit := NewAuditLogger(&buf) + + audit.EmitLLMCall(context.Background(), LLMCallAuditArgs{ + Model: "llama3", + Provider: "ollama", + Usage: LLMUsage{InputTokens: 0, OutputTokens: 0, TotalTokens: 0}, + Duration: 50 * time.Millisecond, + }) + + var evt AuditEvent + _ = json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &evt) + if !evt.TokensUnavailable { + t.Errorf("TokensUnavailable should be true when both tokens are 0, got %+v", evt) + } + if evt.DurationMs == nil || *evt.DurationMs != 50 { + t.Errorf("DurationMs must still be set, got %v", evt.DurationMs) + } +} + +func TestEmitLLMCall_Cancelled_EmitsLLMCallCancelledEvent(t *testing.T) { + var buf bytes.Buffer + audit := NewAuditLogger(&buf) + + audit.EmitLLMCall(context.Background(), LLMCallAuditArgs{ + Model: "gpt-4", + Provider: "openai", + Usage: LLMUsage{InputTokens: 100, OutputTokens: 25, TotalTokens: 125}, + Duration: 200 * time.Millisecond, + Cancelled: true, + }) + + js := buf.String() + if !strings.Contains(js, `"event":"llm_call_cancelled"`) { + t.Errorf("Cancelled should emit llm_call_cancelled, got: %s", js) + } + if !strings.Contains(js, `"input_tokens":100`) { + t.Errorf("Cancelled event must still carry partial counts, got: %s", js) + } +} + +func TestEmitLLMCall_FieldNamesAlignWithOTelGenAI(t *testing.T) { + // FWS-3 deliverable: field naming aligns with OTel GenAI semconv + // (input_tokens / output_tokens matching gen_ai.usage.input_tokens + // / gen_ai.usage.output_tokens). Audit consumers can correlate to + // trace data without a translation table. + var buf bytes.Buffer + audit := NewAuditLogger(&buf) + audit.EmitLLMCall(context.Background(), LLMCallAuditArgs{ + Model: "claude", + Provider: "anthropic", + Usage: LLMUsage{InputTokens: 7, OutputTokens: 3, TotalTokens: 10}, + Duration: 1 * time.Millisecond, + }) + js := buf.String() + for _, want := range []string{`"input_tokens"`, `"output_tokens"`, `"duration_ms"`, `"model"`, `"provider"`} { + if !strings.Contains(js, want) { + t.Errorf("expected %s in JSON, got: %s", want, js) + } + } + // Pre-OTel-rename names must NOT appear at the audit-event level + // (legacy struct-name leakage). + for _, forbidden := range []string{`"prompt_tokens":`, `"completion_tokens":`} { + if strings.Contains(js, forbidden) { + t.Errorf("legacy field %s must not leak into llm_call audit, got: %s", forbidden, js) + } + } +} + +func TestEmit_BackwardCompat_NonLLMEventOmitsTokenFields(t *testing.T) { + // Schema additivity guarantee: events that aren't LLM calls must + // emit without input_tokens / output_tokens / duration_ms / etc. + // in the JSON. Pre-FWS-3 consumers reading session_start audit + // must see byte-identical shape. + var buf bytes.Buffer + audit := NewAuditLogger(&buf) + audit.Emit(AuditEvent{ + Event: AuditSessionStart, + CorrelationID: "corr-x", + TaskID: "task-x", + }) + js := buf.String() + for _, forbidden := range []string{`"input_tokens"`, `"output_tokens"`, `"duration_ms"`, `"model"`, `"provider"`, `"tokens_unavailable"`, `"request_id"`} { + if strings.Contains(js, forbidden) { + t.Errorf("non-LLM event should omit %s, got: %s", forbidden, js) + } + } +} + +func TestEmitToolExec_TagsDurationAndStructuredArgs(t *testing.T) { + var buf bytes.Buffer + audit := NewAuditLogger(&buf) + + audit.EmitToolExec(context.Background(), "file_read", 12*time.Millisecond, map[string]any{ + "args_size": 42, + "result_size": 1024, + }) + + var evt AuditEvent + _ = json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &evt) + if evt.Event != AuditToolExec { + t.Errorf("Event = %q, want %q", evt.Event, AuditToolExec) + } + if evt.DurationMs == nil || *evt.DurationMs != 12 { + t.Errorf("DurationMs = %v, want 12", evt.DurationMs) + } + if evt.Fields["tool"] != "file_read" { + t.Errorf("tool field missing") + } + if evt.Fields["args_size"] == nil { + t.Errorf("args_size structured arg metadata missing — raw args must NOT be present, but size MUST") + } +} + +func TestEmitInvocationComplete_CarriesWallClockDuration(t *testing.T) { + var buf bytes.Buffer + audit := NewAuditLogger(&buf) + + audit.EmitInvocationComplete(context.Background(), 950*time.Millisecond, map[string]any{ + "state": "completed", + "input_tokens_total": 200, + "output_tokens_total": 80, + "llm_call_count": 3, + }) + + var evt AuditEvent + _ = json.Unmarshal(bytes.TrimSpace(buf.Bytes()), &evt) + if evt.Event != AuditInvocationComplete { + t.Errorf("Event = %q, want %q", evt.Event, AuditInvocationComplete) + } + if evt.DurationMs == nil || *evt.DurationMs != 950 { + t.Errorf("DurationMs = %v, want 950", evt.DurationMs) + } + if v, ok := evt.Fields["llm_call_count"].(float64); !ok || v != 3 { + t.Errorf("llm_call_count missing or wrong, got %v (%T)", evt.Fields["llm_call_count"], evt.Fields["llm_call_count"]) + } +} diff --git a/forge-core/runtime/hooks.go b/forge-core/runtime/hooks.go index 844883b..f610ae5 100644 --- a/forge-core/runtime/hooks.go +++ b/forge-core/runtime/hooks.go @@ -2,6 +2,7 @@ package runtime import ( "context" + "time" "github.com/initializ/forge/forge-core/llm" ) @@ -18,6 +19,11 @@ const ( ) // HookContext carries data available to hooks at each hook point. +// +// LLMCallDuration / ToolExecDuration / Provider / Model are populated +// at the call site (loop.go) before the After* hook fires, so audit +// emitters can tag llm_call and tool_exec events with wall-clock +// timing and provider attribution. See issue #87 / FWS-3. type HookContext struct { Messages []llm.ChatMessage Response *llm.ChatResponse @@ -27,6 +33,18 @@ type HookContext struct { Error error TaskID string CorrelationID string + + // LLMCallDuration is the wall-clock time spent in the provider + // client.Chat call. Populated for AfterLLMCall hooks. + LLMCallDuration time.Duration + // Provider / Model identify the LLM provider + model used for the + // call. Populated for AfterLLMCall hooks so audit + A2A-header + // emitters can stamp attribution without re-walking config. + Provider string + Model string + // ToolExecDuration is the wall-clock time spent executing the tool. + // Populated for AfterToolExec hooks. + ToolExecDuration time.Duration } // Hook is a function invoked at a specific point in the agent loop. diff --git a/forge-core/runtime/loop.go b/forge-core/runtime/loop.go index 6f189c4..6e52f30 100644 --- a/forge-core/runtime/loop.go +++ b/forge-core/runtime/loop.go @@ -30,6 +30,7 @@ type LLMExecutor struct { store *MemoryStore logger Logger modelName string // resolved model name for context budget + provider string // resolved provider name (anthropic, openai, ollama, custom) charBudget int // resolved character budget maxToolResultChars int // computed from char budget filesDir string // directory for file_create output @@ -48,6 +49,7 @@ type LLMExecutorConfig struct { Store *MemoryStore Logger Logger ModelName string // model name for context-aware budgeting + Provider string // provider name (anthropic, openai, ollama, custom) — for audit attribution CharBudget int // explicit char budget override (0 = auto from model) FilesDir string // directory for file_create output (default: $TMPDIR/forge-files) SessionMaxAge time.Duration // max idle time before session recovery is skipped (0 = 30m default) @@ -103,6 +105,7 @@ func NewLLMExecutor(cfg LLMExecutorConfig) *LLMExecutor { store: cfg.Store, logger: logger, modelName: cfg.ModelName, + provider: cfg.Provider, charBudget: budget, maxToolResultChars: toolLimit, filesDir: cfg.FilesDir, @@ -242,12 +245,20 @@ func (e *LLMExecutor) Execute(ctx context.Context, task *a2a.Task, msg *a2a.Mess Tools: toolDefs, } + // Capture wall-clock duration of the provider call so the + // AfterLLMCall hook can stamp duration_ms on the llm_call audit + // event and X-Forge-Duration-Ms header. See issue #87 / FWS-3. + llmStart := time.Now() resp, err := e.client.Chat(ctx, req) + llmDuration := time.Since(llmStart) if err != nil { _ = e.hooks.Fire(ctx, OnError, &HookContext{ - Error: err, - TaskID: TaskIDFromContext(ctx), - CorrelationID: CorrelationIDFromContext(ctx), + Error: err, + TaskID: TaskIDFromContext(ctx), + CorrelationID: CorrelationIDFromContext(ctx), + LLMCallDuration: llmDuration, + Provider: e.provider, + Model: e.modelName, }) // Return user-friendly error (raw error is already logged via OnError hook) return nil, fmt.Errorf("something went wrong while processing your request, please try again") @@ -255,10 +266,13 @@ func (e *LLMExecutor) Execute(ctx context.Context, task *a2a.Task, msg *a2a.Mess // Fire AfterLLMCall hook if err := e.hooks.Fire(ctx, AfterLLMCall, &HookContext{ - Messages: messages, - Response: resp, - TaskID: TaskIDFromContext(ctx), - CorrelationID: CorrelationIDFromContext(ctx), + Messages: messages, + Response: resp, + TaskID: TaskIDFromContext(ctx), + CorrelationID: CorrelationIDFromContext(ctx), + LLMCallDuration: llmDuration, + Provider: e.provider, + Model: e.modelName, }); err != nil { return nil, fmt.Errorf("after LLM call hook: %w", err) } @@ -388,9 +402,21 @@ func (e *LLMExecutor) Execute(ctx context.Context, task *a2a.Task, msg *a2a.Mess retryReq := &llm.ChatRequest{ Messages: mem.Messages(), } + retryStart := time.Now() if retryResp, retryErr := e.client.Chat(ctx, retryReq); retryErr == nil && strings.TrimSpace(retryResp.Message.Content) != "" { resp = retryResp mem.Append(resp.Message) + // Fire AfterLLMCall so audit + headers capture the retry's + // usage/duration alongside the original turn. + _ = e.hooks.Fire(ctx, AfterLLMCall, &HookContext{ + Messages: mem.Messages(), + Response: retryResp, + TaskID: TaskIDFromContext(ctx), + CorrelationID: CorrelationIDFromContext(ctx), + LLMCallDuration: time.Since(retryStart), + Provider: e.provider, + Model: e.modelName, + }) } } if strings.TrimSpace(resp.Message.Content) == "" { @@ -428,8 +454,12 @@ func (e *LLMExecutor) Execute(ctx context.Context, task *a2a.Task, msg *a2a.Mess return nil, fmt.Errorf("before tool exec hook: %w", err) } - // Execute tool + // Execute tool. Capture wall-clock duration so the + // AfterToolExec hook can stamp duration_ms on the tool_exec + // audit event. See issue #87 / FWS-3. + toolStart := time.Now() result, execErr := e.tools.Execute(ctx, tc.Function.Name, json.RawMessage(tc.Function.Arguments)) + toolDuration := time.Since(toolStart) if execErr != nil { result = fmt.Sprintf("Error executing tool %s: %s", tc.Function.Name, execErr.Error()) } @@ -447,12 +477,13 @@ func (e *LLMExecutor) Execute(ctx context.Context, task *a2a.Task, msg *a2a.Mess // Fire AfterToolExec hook -- hooks may redact ToolOutput. afterHctx := &HookContext{ - ToolName: tc.Function.Name, - ToolInput: tc.Function.Arguments, - ToolOutput: result, - Error: execErr, - TaskID: TaskIDFromContext(ctx), - CorrelationID: CorrelationIDFromContext(ctx), + ToolName: tc.Function.Name, + ToolInput: tc.Function.Arguments, + ToolOutput: result, + Error: execErr, + TaskID: TaskIDFromContext(ctx), + CorrelationID: CorrelationIDFromContext(ctx), + ToolExecDuration: toolDuration, } if err := e.hooks.Fire(ctx, AfterToolExec, afterHctx); err != nil { return nil, fmt.Errorf("after tool exec hook: %w", err) diff --git a/forge-core/runtime/usage_accumulator.go b/forge-core/runtime/usage_accumulator.go new file mode 100644 index 0000000..c246b78 --- /dev/null +++ b/forge-core/runtime/usage_accumulator.go @@ -0,0 +1,112 @@ +package runtime + +import ( + "context" + "sync" + "time" +) + +// LLMUsageAccumulator aggregates per-invocation LLM usage so the A2A +// response handler can populate X-Forge-Tokens-In / X-Forge-Tokens-Out +// / X-Forge-Duration-Ms / X-Forge-Model / X-Forge-Provider headers. +// +// One accumulator is created per A2A invocation by the runner and +// stashed in context.Context. Every AfterLLMCall hook calls AddLLMCall +// to fold the current call's counts into the running totals. At +// response time the runner reads Snapshot() and stamps the headers. +// +// Headers are the orchestration channel for real-time cost enforcement +// during parallel workflow execution. They populate regardless of +// whether OTel tracing is enabled — they're the orchestration channel, +// not the observability channel. See issue #87 / FWS-3. +type LLMUsageAccumulator struct { + mu sync.Mutex + invocationStart time.Time + inputTokensSum int + outputTokensSum int + llmTimeSum time.Duration + primaryModel string + primaryProvider string + llmCallCount int + tokensUnavailHit bool +} + +// NewLLMUsageAccumulator returns a fresh accumulator with its invocation +// clock started at the time of the call. +func NewLLMUsageAccumulator() *LLMUsageAccumulator { + return &LLMUsageAccumulator{invocationStart: time.Now()} +} + +// AddLLMCall folds one LLM call's usage + duration into the running +// totals. The most-recently-added call's model + provider become the +// "primary" reported in the X-Forge-Model / X-Forge-Provider headers, +// matching the issue's spec: "the primary model used (most recent if +// multiple)". +func (a *LLMUsageAccumulator) AddLLMCall(model, provider string, usage LLMUsage, duration time.Duration) { + a.mu.Lock() + defer a.mu.Unlock() + a.inputTokensSum += usage.InputTokens + a.outputTokensSum += usage.OutputTokens + a.llmTimeSum += duration + a.llmCallCount++ + if model != "" { + a.primaryModel = model + } + if provider != "" { + a.primaryProvider = provider + } + if usage.InputTokens == 0 && usage.OutputTokens == 0 { + a.tokensUnavailHit = true + } +} + +// LLMUsageSnapshot is an immutable readout of the accumulator's totals +// at a single point in time. Returned by Snapshot for use by the A2A +// response handler. +type LLMUsageSnapshot struct { + InputTokens int + OutputTokens int + LLMTimeTotal time.Duration // sum of per-LLM-call durations + InvocationDuration time.Duration // wall-clock since accumulator creation + PrimaryModel string + PrimaryProvider string + LLMCallCount int + TokensUnavailable bool +} + +// Snapshot returns the current totals. Safe to call from a goroutine +// different from AddLLMCall callers. +func (a *LLMUsageAccumulator) Snapshot() LLMUsageSnapshot { + a.mu.Lock() + defer a.mu.Unlock() + return LLMUsageSnapshot{ + InputTokens: a.inputTokensSum, + OutputTokens: a.outputTokensSum, + LLMTimeTotal: a.llmTimeSum, + InvocationDuration: time.Since(a.invocationStart), + PrimaryModel: a.primaryModel, + PrimaryProvider: a.primaryProvider, + LLMCallCount: a.llmCallCount, + TokensUnavailable: a.tokensUnavailHit && a.inputTokensSum == 0 && a.outputTokensSum == 0, + } +} + +type llmUsageAccumulatorKey struct{} + +// WithLLMUsageAccumulator stashes a per-invocation accumulator in ctx. +// The runner creates one per A2A invocation at request entry; the +// AfterLLMCall hook reads it via LLMUsageAccumulatorFromContext and +// folds each call's counts into the totals. +func WithLLMUsageAccumulator(ctx context.Context, acc *LLMUsageAccumulator) context.Context { + return context.WithValue(ctx, llmUsageAccumulatorKey{}, acc) +} + +// LLMUsageAccumulatorFromContext returns the per-invocation +// accumulator from ctx, or nil when no accumulator was attached +// (e.g. internal cron-fire paths that don't need response headers). +func LLMUsageAccumulatorFromContext(ctx context.Context) *LLMUsageAccumulator { + if acc, ok := ctx.Value(llmUsageAccumulatorKey{}).(*LLMUsageAccumulator); ok { + return acc + } + return nil +} diff --git a/forge-core/runtime/usage_accumulator_test.go b/forge-core/runtime/usage_accumulator_test.go new file mode 100644 index 0000000..35d166b --- /dev/null +++ b/forge-core/runtime/usage_accumulator_test.go @@ -0,0 +1,132 @@ +package runtime + +import ( + "context" + "sync" + "testing" + "time" +) + +// Regression tests for issue #87 / FWS-3 — the per-invocation LLM +// usage accumulator. Tracks running totals so the A2A response handler +// can stamp X-Forge-Tokens-In/Out/Duration-Ms/Model/Provider headers +// and emit invocation_complete with aggregated counts. + +func TestLLMUsageAccumulator_AggregatesAcrossCalls(t *testing.T) { + acc := NewLLMUsageAccumulator() + acc.AddLLMCall("claude", "anthropic", LLMUsage{InputTokens: 100, OutputTokens: 50}, 50*time.Millisecond) + acc.AddLLMCall("claude", "anthropic", LLMUsage{InputTokens: 200, OutputTokens: 75}, 80*time.Millisecond) + acc.AddLLMCall("claude", "anthropic", LLMUsage{InputTokens: 50, OutputTokens: 25}, 30*time.Millisecond) + + snap := acc.Snapshot() + if snap.InputTokens != 350 { + t.Errorf("InputTokens sum = %d, want 350", snap.InputTokens) + } + if snap.OutputTokens != 150 { + t.Errorf("OutputTokens sum = %d, want 150", snap.OutputTokens) + } + if snap.LLMCallCount != 3 { + t.Errorf("LLMCallCount = %d, want 3", snap.LLMCallCount) + } +} + +func TestLLMUsageAccumulator_PrimaryIsMostRecentNonEmpty(t *testing.T) { + // Spec: X-Forge-Model / X-Forge-Provider report "the primary model + // used (most recent if multiple)." This matches the most common + // orchestration pattern where the final model decides cost class. + acc := NewLLMUsageAccumulator() + acc.AddLLMCall("claude-haiku", "anthropic", LLMUsage{InputTokens: 10, OutputTokens: 5}, time.Millisecond) + acc.AddLLMCall("gpt-4", "openai", LLMUsage{InputTokens: 20, OutputTokens: 10}, time.Millisecond) + snap := acc.Snapshot() + if snap.PrimaryModel != "gpt-4" || snap.PrimaryProvider != "openai" { + t.Errorf("Primary should be most-recent (gpt-4 / openai), got %s / %s", snap.PrimaryModel, snap.PrimaryProvider) + } +} + +func TestLLMUsageAccumulator_TokensUnavailableLatchesOnAllZero(t *testing.T) { + // If every call had no usage info (Ollama on a self-hosted model), + // the snapshot's TokensUnavailable must be true so the A2A header + // layer knows to skip X-Forge-Tokens-* (downstream billing must + // distinguish "we didn't measure" from "you used zero tokens"). + acc := NewLLMUsageAccumulator() + acc.AddLLMCall("llama3", "ollama", LLMUsage{InputTokens: 0, OutputTokens: 0}, time.Millisecond) + snap := acc.Snapshot() + if !snap.TokensUnavailable { + t.Errorf("all-zero usage must latch TokensUnavailable=true, got %+v", snap) + } +} + +func TestLLMUsageAccumulator_TokensUnavailableClearsWhenAnyCallReports(t *testing.T) { + // Mixed-provider workflow: if any call reported usage, totals are + // meaningful and TokensUnavailable should NOT latch — billing can + // use the snapshot's InputTokens/OutputTokens as the bill-from value. + acc := NewLLMUsageAccumulator() + acc.AddLLMCall("llama3", "ollama", LLMUsage{InputTokens: 0, OutputTokens: 0}, time.Millisecond) + acc.AddLLMCall("claude", "anthropic", LLMUsage{InputTokens: 100, OutputTokens: 50}, time.Millisecond) + snap := acc.Snapshot() + if snap.TokensUnavailable { + t.Errorf("partial-reporting workflow must NOT latch TokensUnavailable, got %+v", snap) + } + if snap.InputTokens != 100 || snap.OutputTokens != 50 { + t.Errorf("billable totals wrong: %+v", snap) + } +} + +func TestLLMUsageAccumulator_InvocationDurationTrackedSeparatelyFromLLMTime(t *testing.T) { + // LLM time and wall-clock invocation time are different — LLM time + // is sum-of-Chat-durations, invocation duration is end-to-end wall + // clock (includes tool execution, guardrails, audit emission). + acc := NewLLMUsageAccumulator() + time.Sleep(15 * time.Millisecond) + acc.AddLLMCall("claude", "anthropic", LLMUsage{InputTokens: 10}, 5*time.Millisecond) + snap := acc.Snapshot() + if snap.LLMTimeTotal != 5*time.Millisecond { + t.Errorf("LLMTimeTotal must be sum of per-call durations, got %v", snap.LLMTimeTotal) + } + if snap.InvocationDuration < 15*time.Millisecond { + t.Errorf("InvocationDuration must be wall-clock since accumulator creation, got %v", snap.InvocationDuration) + } +} + +func TestLLMUsageAccumulatorFromContext_MissingReturnsNil(t *testing.T) { + if acc := LLMUsageAccumulatorFromContext(context.Background()); acc != nil { + t.Errorf("missing ctx value should return nil, got %v", acc) + } +} + +func TestLLMUsageAccumulatorFromContext_RoundTrip(t *testing.T) { + acc := NewLLMUsageAccumulator() + ctx := WithLLMUsageAccumulator(context.Background(), acc) + got := LLMUsageAccumulatorFromContext(ctx) + if got != acc { + t.Errorf("ctx round-trip should return same accumulator") + } +} + +func TestLLMUsageAccumulator_ConcurrentAddSafe(t *testing.T) { + // AfterLLMCall hooks may fire from goroutines. The accumulator + // must be safe to add to concurrently or we'd lose token data to + // races — silently undercounting cost data is worse than crashing. + acc := NewLLMUsageAccumulator() + const goroutines = 50 + const callsEach = 10 + var wg sync.WaitGroup + wg.Add(goroutines) + for i := 0; i < goroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < callsEach; j++ { + acc.AddLLMCall("claude", "anthropic", LLMUsage{InputTokens: 1, OutputTokens: 1}, time.Microsecond) + } + }() + } + wg.Wait() + snap := acc.Snapshot() + want := goroutines * callsEach + if snap.InputTokens != want || snap.OutputTokens != want { + t.Errorf("concurrent add lost data: in=%d out=%d, want %d each", snap.InputTokens, snap.OutputTokens, want) + } + if snap.LLMCallCount != want { + t.Errorf("LLMCallCount = %d, want %d", snap.LLMCallCount, want) + } +}