From bf0a8e7ae0cab2a27cb60d551d31d896868189a0 Mon Sep 17 00:00:00 2001 From: mesutoezdil Date: Fri, 15 May 2026 21:35:19 +0200 Subject: [PATCH 1/2] fix(bedrock): preserve thinking blocks in multi-turn tool use When extended thinking is active, Bedrock returns thinking content blocks together with tool use blocks. The API requires these blocks to be sent back unmodified in the next request. Without them, Bedrock returns a ValidationException with toolUse.input is empty. Only emit thinking blocks for the last assistant turn before tool results. Sending them in all turns causes token counts to compound across long sessions. Truncate tool results in historical turns to 2000 chars. Older turns do not need full fidelity for large kubectl or YAML outputs. Wrap the OTLP span exporter to cap string attributes at 16KB, preventing large tool responses from exceeding Tempo's 4MB gRPC message limit. Fixes #1870 Signed-off-by: mesutoezdil --- go/adk/pkg/models/bedrock.go | 132 ++++++++++++++-- go/adk/pkg/models/bedrock_test.go | 159 ++++++++++++++++++++ go/adk/pkg/telemetry/tracing.go | 2 +- go/adk/pkg/telemetry/truncating_exporter.go | 79 ++++++++++ 4 files changed, 356 insertions(+), 16 deletions(-) create mode 100644 go/adk/pkg/telemetry/truncating_exporter.go diff --git a/go/adk/pkg/models/bedrock.go b/go/adk/pkg/models/bedrock.go index d9db5a842e..aac601b88e 100644 --- a/go/adk/pkg/models/bedrock.go +++ b/go/adk/pkg/models/bedrock.go @@ -171,18 +171,23 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques // is written with the sanitized name Bedrock already knows about. messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents, nameMap) + // temperature/top_p must not be sent when thinking is active. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + _, thinkingEnabled := m.Config.AdditionalModelRequestFields["thinking"] + // Build inference config var inferenceConfig *types.InferenceConfiguration - if m.Config.MaxTokens != nil || m.Config.Temperature != nil || m.Config.TopP != nil { + if m.Config.MaxTokens != nil || (!thinkingEnabled && (m.Config.Temperature != nil || m.Config.TopP != nil)) { inferenceConfig = &types.InferenceConfiguration{} if m.Config.MaxTokens != nil { inferenceConfig.MaxTokens = aws.Int32(int32(*m.Config.MaxTokens)) } - if m.Config.Temperature != nil { - inferenceConfig.Temperature = aws.Float32(float32(*m.Config.Temperature)) - } - if m.Config.TopP != nil { - inferenceConfig.TopP = aws.Float32(float32(*m.Config.TopP)) + if !thinkingEnabled { + if m.Config.Temperature != nil { + inferenceConfig.Temperature = aws.Float32(float32(*m.Config.Temperature)) + } + if m.Config.TopP != nil { + inferenceConfig.TopP = aws.Float32(float32(*m.Config.TopP)) + } } } @@ -248,6 +253,10 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me toolCalls := make(map[int32]*streamingToolCall) var completedToolCalls []*genai.Part + // https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + reasoningBlocks := make(map[int32]*streamingReasoningBlock) + var completedThinkingParts []*genai.Part + // Get the event stream and read events from the channel stream := output.GetStream() defer stream.Close() @@ -295,10 +304,22 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me if tc, ok := toolCalls[blockIdx]; ok && delta.Value.Input != nil { tc.InputJSON += aws.ToString(delta.Value.Input) } + + case *types.ContentBlockDeltaMemberReasoningContent: + if _, ok := reasoningBlocks[blockIdx]; !ok { + reasoningBlocks[blockIdx] = &streamingReasoningBlock{} + } + rb := reasoningBlocks[blockIdx] + switch inner := delta.Value.(type) { + case *types.ReasoningContentBlockDeltaMemberText: + rb.Text.WriteString(inner.Value) + case *types.ReasoningContentBlockDeltaMemberSignature: + rb.Signature = inner.Value + } } } - // Handle content block stop (tool use complete) + // Handle content block stop (tool use or thinking block complete) if stop, ok := event.(*types.ConverseStreamOutputMemberContentBlockStop); ok { blockIdx := aws.ToInt32(stop.Value.ContentBlockIndex) if tc, ok := toolCalls[blockIdx]; ok { @@ -316,7 +337,18 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me Args: args, } completedToolCalls = append(completedToolCalls, &genai.Part{FunctionCall: functionCall}) - delete(toolCalls, blockIdx) // Clean up + delete(toolCalls, blockIdx) + } + if rb, ok := reasoningBlocks[blockIdx]; ok { + part := &genai.Part{ + Thought: true, + Text: rb.Text.String(), + } + if rb.Signature != "" { + part.ThoughtSignature = []byte(rb.Signature) + } + completedThinkingParts = append(completedThinkingParts, part) + delete(reasoningBlocks, blockIdx) } } @@ -337,8 +369,9 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me } } - // Build final response + // thinking parts first; block order must match what Bedrock returned. finalParts := []*genai.Part{} + finalParts = append(finalParts, completedThinkingParts...) text := aggregatedText.String() if text != "" { finalParts = append(finalParts, &genai.Part{Text: text}) @@ -366,6 +399,11 @@ type streamingToolCall struct { InputJSON string // Accumulated JSON input } +type streamingReasoningBlock struct { + Text strings.Builder + Signature string +} + // parseArgs parses the accumulated JSON input into a map func (tc *streamingToolCall) parseArgs() map[string]any { if tc.InputJSON == "" { @@ -403,6 +441,20 @@ func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, parts := []*genai.Part{} if message, ok := output.Output.(*types.ConverseOutputMemberMessage); ok { for _, block := range message.Value.Content { + // https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + if reasoningBlock, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + if textBlock, ok := reasoningBlock.Value.(*types.ReasoningContentBlockMemberReasoningText); ok { + part := &genai.Part{ + Thought: true, + Text: aws.ToString(textBlock.Value.Text), + } + if textBlock.Value.Signature != nil { + part.ThoughtSignature = []byte(aws.ToString(textBlock.Value.Signature)) + } + parts = append(parts, part) + } + continue + } // Handle text content if textBlock, ok := block.(*types.ContentBlockMemberText); ok { parts = append(parts, &genai.Part{Text: textBlock.Value}) @@ -473,6 +525,15 @@ func documentToMap(doc document.Interface) map[string]any { return result } +const historyToolResultMaxLen = 2000 + +func truncateToolResult(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + fmt.Sprintf("\n... [truncated, %d chars omitted]", len(s)-maxLen) +} + // convertGenaiContentsToBedrockMessages converts genai.Content to Bedrock Converse API message format. // nameMap is the original->sanitized tool name map produced by convertGenaiToolsToBedrock. // Any FunctionCall found in the conversation history is written with the sanitized name so @@ -486,17 +547,40 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma idMap := make(map[string]string) idCounter := 0 - for _, content := range contents { + // Bedrock only requires thinking blocks in the last assistant turn before tool results. + // Sending them in earlier turns causes token counts to compound across long sessions. + // Truncate tool results in all turns except the most recent one carrying them. + lastThinkingIdx, lastToolResultIdx := -1, -1 + for i, c := range contents { + if c == nil { + continue + } + for _, p := range c.Parts { + if p == nil { + continue + } + if p.Thought && (c.Role == "model" || c.Role == "assistant") { + lastThinkingIdx = i + } + if p.FunctionResponse != nil && c.Role == "user" { + lastToolResultIdx = i + } + } + } + + for i, content := range contents { if content == nil || len(content.Parts) == 0 { continue } - // Determine role role := types.ConversationRoleUser if content.Role == "model" || content.Role == "assistant" { role = types.ConversationRoleAssistant } + emitThinking := i == lastThinkingIdx + truncateTools := i != lastToolResultIdx + var contentBlocks []types.ContentBlock for _, part := range content.Parts { @@ -504,9 +588,26 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma continue } - // Handle text + // Thought parts also carry Text; check Thought first. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + if part.Thought { + if !emitThinking { + continue + } + reasoningText := &types.ReasoningTextBlock{ + Text: aws.String(part.Text), + } + if len(part.ThoughtSignature) > 0 { + reasoningText.Signature = aws.String(string(part.ThoughtSignature)) + } + contentBlocks = append(contentBlocks, &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberReasoningText{ + Value: *reasoningText, + }, + }) + continue + } + if part.Text != "" { - // Check if this is a system message if content.Role == "system" { systemInstruction = part.Text continue @@ -536,10 +637,11 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma continue } - // Handle function response (tool result in Bedrock terminology) if part.FunctionResponse != nil { - // Extract response content result := extractFunctionResponseContent(part.FunctionResponse.Response) + if truncateTools { + result = truncateToolResult(result, historyToolResultMaxLen) + } toolResult := types.ToolResultBlock{ ToolUseId: aws.String(sanitizeBedrockToolID(part.FunctionResponse.ID, idMap, &idCounter)), Content: []types.ToolResultContentBlock{ diff --git a/go/adk/pkg/models/bedrock_test.go b/go/adk/pkg/models/bedrock_test.go index de2d1c3caf..7ddc6e1474 100644 --- a/go/adk/pkg/models/bedrock_test.go +++ b/go/adk/pkg/models/bedrock_test.go @@ -2,6 +2,7 @@ package models import ( "encoding/json" + "strings" "testing" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" @@ -102,6 +103,41 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) { } }, }, + { + name: "thinking block preserved as ReasoningContent", + contents: []*genai.Content{ + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "let me think", ThoughtSignature: []byte("sig123")}, + {FunctionCall: &genai.FunctionCall{ID: "c1", Name: "get_weather", Args: map[string]any{"location": "Paris"}}}, + }, + }, + }, + wantMsgCount: 1, + checkMsg: func(t *testing.T, msgs []types.Message) { + if len(msgs[0].Content) != 2 { + t.Fatalf("expected 2 blocks (thinking + toolUse), got %d", len(msgs[0].Content)) + } + rb, ok := msgs[0].Content[0].(*types.ContentBlockMemberReasoningContent) + if !ok { + t.Fatalf("block 0: want *ContentBlockMemberReasoningContent, got %T", msgs[0].Content[0]) + } + rt, ok := rb.Value.(*types.ReasoningContentBlockMemberReasoningText) + if !ok { + t.Fatalf("reasoning value: want *ReasoningContentBlockMemberReasoningText, got %T", rb.Value) + } + if *rt.Value.Text != "let me think" { + t.Errorf("text: want %q, got %q", "let me think", *rt.Value.Text) + } + if *rt.Value.Signature != "sig123" { + t.Errorf("signature: want %q, got %q", "sig123", *rt.Value.Signature) + } + if _, ok := msgs[0].Content[1].(*types.ContentBlockMemberToolUse); !ok { + t.Errorf("block 1: want *ContentBlockMemberToolUse, got %T", msgs[0].Content[1]) + } + }, + }, } for _, tt := range tests { @@ -424,3 +460,126 @@ func TestStreamingToolCallParseArgs(t *testing.T) { }) } } + +func TestThinkingOnlyInLastAssistantTurn(t *testing.T) { + contents := []*genai.Content{ + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "first think", ThoughtSignature: []byte("sig1")}, + {FunctionCall: &genai.FunctionCall{ID: "c1", Name: "tool_a", Args: map[string]any{}}}, + }, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c1", Name: "tool_a", Response: map[string]any{"r": "v1"}}}}, + }, + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "second think", ThoughtSignature: []byte("sig2")}, + {FunctionCall: &genai.FunctionCall{ID: "c2", Name: "tool_b", Args: map[string]any{}}}, + }, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c2", Name: "tool_b", Response: map[string]any{"r": "v2"}}}}, + }, + } + + msgs, _ := convertGenaiContentsToBedrockMessages(contents, nil) + if len(msgs) != 4 { + t.Fatalf("want 4 messages, got %d", len(msgs)) + } + + // First assistant turn must NOT contain reasoning content. + for _, block := range msgs[0].Content { + if _, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + t.Error("first assistant turn must not contain reasoning content") + } + } + + // Last assistant turn (index 2) must contain reasoning content. + hasReasoning := false + for _, block := range msgs[2].Content { + if _, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + hasReasoning = true + } + } + if !hasReasoning { + t.Error("last assistant turn must contain reasoning content") + } +} + +func TestHistoricalToolResultTruncation(t *testing.T) { + longOutput := strings.Repeat("x", historyToolResultMaxLen+500) + contents := []*genai.Content{ + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c1", Name: "tool_a", Response: map[string]any{"result": longOutput}}}}, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c2", Name: "tool_b", Response: map[string]any{"result": longOutput}}}}, + }, + } + + msgs, _ := convertGenaiContentsToBedrockMessages(contents, nil) + if len(msgs) != 2 { + t.Fatalf("want 2 messages, got %d", len(msgs)) + } + + extractText := func(msg types.Message) string { + for _, block := range msg.Content { + if tr, ok := block.(*types.ContentBlockMemberToolResult); ok { + for _, c := range tr.Value.Content { + if txt, ok := c.(*types.ToolResultContentBlockMemberText); ok { + return txt.Value + } + } + } + } + return "" + } + + first := extractText(msgs[0]) + if len(first) >= len(longOutput) { + t.Errorf("historical tool result should be truncated, got len=%d", len(first)) + } + + last := extractText(msgs[1]) + if len(last) != len(longOutput) { + t.Errorf("latest tool result must not be truncated, got len=%d want %d", len(last), len(longOutput)) + } +} + +func TestTruncateToolResult(t *testing.T) { + cases := []struct { + name string + input string + maxLen int + wantLen int + wantMsg bool + }{ + {"no truncation needed", "short", 100, 5, false}, + {"exact boundary", strings.Repeat("a", 100), 100, 100, false}, + {"truncated", strings.Repeat("a", 150), 100, -1, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := truncateToolResult(tc.input, tc.maxLen) + if tc.wantMsg { + if len(got) <= tc.maxLen { + t.Errorf("expected truncated result longer than maxLen, got %d", len(got)) + } + if !strings.Contains(got, "truncated") { + t.Error("truncated result must contain truncation notice") + } + } else { + if len(got) != tc.wantLen { + t.Errorf("want len %d, got %d", tc.wantLen, len(got)) + } + } + }) + } +} diff --git a/go/adk/pkg/telemetry/tracing.go b/go/adk/pkg/telemetry/tracing.go index 695ec0711e..81e06d9aea 100644 --- a/go/adk/pkg/telemetry/tracing.go +++ b/go/adk/pkg/telemetry/tracing.go @@ -156,7 +156,7 @@ func newTracerProvider(ctx context.Context, res *resource.Resource) (*sdktrace.T return sdktrace.NewTracerProvider( sdktrace.WithSpanProcessor(kagentAttributesSpanProcessor{}), - sdktrace.WithBatcher(exporter), + sdktrace.WithBatcher(newTruncatingExporter(exporter)), sdktrace.WithResource(res), ), nil } diff --git a/go/adk/pkg/telemetry/truncating_exporter.go b/go/adk/pkg/telemetry/truncating_exporter.go new file mode 100644 index 0000000000..9866ecdd8a --- /dev/null +++ b/go/adk/pkg/telemetry/truncating_exporter.go @@ -0,0 +1,79 @@ +package telemetry + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +// maxSpanAttributeBytes caps individual string span attributes before export. +// Tempo's default gRPC message limit is 4 MB; large tool responses can exceed it. +const maxSpanAttributeBytes = 16 * 1024 + +type truncatingExporter struct { + inner sdktrace.SpanExporter +} + +func newTruncatingExporter(inner sdktrace.SpanExporter) sdktrace.SpanExporter { + return &truncatingExporter{inner: inner} +} + +func (e *truncatingExporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error { + out := make([]sdktrace.ReadOnlySpan, len(spans)) + for i, s := range spans { + out[i] = truncateSpan(s) + } + return e.inner.ExportSpans(ctx, out) +} + +func (e *truncatingExporter) Shutdown(ctx context.Context) error { + return e.inner.Shutdown(ctx) +} + +func truncateSpan(s sdktrace.ReadOnlySpan) sdktrace.ReadOnlySpan { + attrs := s.Attributes() + needsCut := false + for _, a := range attrs { + if a.Value.Type() == attribute.STRING && len(a.Value.AsString()) > maxSpanAttributeBytes { + needsCut = true + break + } + } + if !needsCut { + return s + } + newAttrs := make([]attribute.KeyValue, len(attrs)) + for i, a := range attrs { + if a.Value.Type() == attribute.STRING { + v := a.Value.AsString() + if len(v) > maxSpanAttributeBytes { + newAttrs[i] = attribute.String(string(a.Key), + v[:maxSpanAttributeBytes]+fmt.Sprintf(" ...[truncated, %d bytes omitted]", len(v)-maxSpanAttributeBytes)) + continue + } + } + newAttrs[i] = a + } + stub := tracetest.SpanStub{ + Name: s.Name(), + SpanContext: s.SpanContext(), + Parent: s.Parent(), + SpanKind: s.SpanKind(), + StartTime: s.StartTime(), + EndTime: s.EndTime(), + Attributes: newAttrs, + Events: s.Events(), + Links: s.Links(), + Status: s.Status(), + DroppedAttributes: s.DroppedAttributes(), + DroppedEvents: s.DroppedEvents(), + DroppedLinks: s.DroppedLinks(), + ChildSpanCount: s.ChildSpanCount(), + Resource: s.Resource(), + InstrumentationLibrary: s.InstrumentationScope(), + } + return stub.Snapshot() +} From 9ef365602949ea942d9cc793616a7993e151cdd0 Mon Sep 17 00:00:00 2001 From: mesutoezdil Date: Fri, 22 May 2026 18:53:30 +0300 Subject: [PATCH 2/2] refactor(bedrock): extract buildInferenceConfig and add unit tests Extract inference config construction into a testable helper so that the temperature/top_p exclusion logic for extended thinking can be exercised directly without mocking the AWS Bedrock client. Signed-off-by: mesutoezdil --- go/adk/pkg/models/bedrock.go | 40 ++++++++------ go/adk/pkg/models/bedrock_test.go | 89 +++++++++++++++++++++++++++++++ 2 files changed, 112 insertions(+), 17 deletions(-) diff --git a/go/adk/pkg/models/bedrock.go b/go/adk/pkg/models/bedrock.go index aac601b88e..216a95dcc4 100644 --- a/go/adk/pkg/models/bedrock.go +++ b/go/adk/pkg/models/bedrock.go @@ -173,23 +173,7 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques // temperature/top_p must not be sent when thinking is active. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html _, thinkingEnabled := m.Config.AdditionalModelRequestFields["thinking"] - - // Build inference config - var inferenceConfig *types.InferenceConfiguration - if m.Config.MaxTokens != nil || (!thinkingEnabled && (m.Config.Temperature != nil || m.Config.TopP != nil)) { - inferenceConfig = &types.InferenceConfiguration{} - if m.Config.MaxTokens != nil { - inferenceConfig.MaxTokens = aws.Int32(int32(*m.Config.MaxTokens)) - } - if !thinkingEnabled { - if m.Config.Temperature != nil { - inferenceConfig.Temperature = aws.Float32(float32(*m.Config.Temperature)) - } - if m.Config.TopP != nil { - inferenceConfig.TopP = aws.Float32(float32(*m.Config.TopP)) - } - } - } + inferenceConfig := buildInferenceConfig(m.Config, thinkingEnabled) // Build system prompt var systemPrompt []types.SystemContentBlock @@ -743,3 +727,25 @@ func bedrockStopReasonToGenai(reason types.StopReason) genai.FinishReason { return genai.FinishReasonStop } } + +// buildInferenceConfig constructs the Bedrock InferenceConfiguration from a +// BedrockConfig. When thinking is enabled, temperature and top_p must be +// omitted per the Bedrock extended-thinking API contract. +func buildInferenceConfig(cfg *BedrockConfig, thinkingEnabled bool) *types.InferenceConfiguration { + if cfg.MaxTokens == nil && (thinkingEnabled || (cfg.Temperature == nil && cfg.TopP == nil)) { + return nil + } + ic := &types.InferenceConfiguration{} + if cfg.MaxTokens != nil { + ic.MaxTokens = aws.Int32(int32(*cfg.MaxTokens)) + } + if !thinkingEnabled { + if cfg.Temperature != nil { + ic.Temperature = aws.Float32(float32(*cfg.Temperature)) + } + if cfg.TopP != nil { + ic.TopP = aws.Float32(float32(*cfg.TopP)) + } + } + return ic +} diff --git a/go/adk/pkg/models/bedrock_test.go b/go/adk/pkg/models/bedrock_test.go index 7ddc6e1474..1094bd0771 100644 --- a/go/adk/pkg/models/bedrock_test.go +++ b/go/adk/pkg/models/bedrock_test.go @@ -583,3 +583,92 @@ func TestTruncateToolResult(t *testing.T) { }) } } + +func TestBuildInferenceConfig(t *testing.T) { + fTemp := float64(0.7) + fTopP := float64(0.9) + maxTok := 1000 + + tests := []struct { + name string + cfg BedrockConfig + thinkingActive bool + wantNil bool + wantTemp *float32 + wantTopP *float32 + wantMaxTokens *int32 + }{ + { + name: "thinking drops temperature and topP", + cfg: BedrockConfig{Temperature: &fTemp, TopP: &fTopP}, + thinkingActive: true, + wantNil: true, + }, + { + name: "thinking with maxTokens keeps only maxTokens", + cfg: BedrockConfig{Temperature: &fTemp, TopP: &fTopP, MaxTokens: &maxTok}, + thinkingActive: true, + wantNil: false, + wantMaxTokens: func() *int32 { v := int32(1000); return &v }(), + }, + { + name: "no thinking passes temperature and topP", + cfg: BedrockConfig{Temperature: &fTemp, TopP: &fTopP}, + thinkingActive: false, + wantNil: false, + wantTemp: func() *float32 { v := float32(0.7); return &v }(), + wantTopP: func() *float32 { v := float32(0.9); return &v }(), + }, + { + name: "all nil returns nil", + cfg: BedrockConfig{}, + thinkingActive: false, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildInferenceConfig(&tt.cfg, tt.thinkingActive) + if tt.wantNil { + if got != nil { + t.Fatalf("want nil, got %+v", got) + } + return + } + if got == nil { + t.Fatal("want non-nil InferenceConfiguration, got nil") + } + if tt.wantTemp == nil && got.Temperature != nil { + t.Errorf("temperature: want nil, got %v", *got.Temperature) + } + if tt.wantTemp != nil { + if got.Temperature == nil { + t.Fatalf("temperature: want %v, got nil", *tt.wantTemp) + } + if *got.Temperature != *tt.wantTemp { + t.Errorf("temperature: want %v, got %v", *tt.wantTemp, *got.Temperature) + } + } + if tt.wantTopP == nil && got.TopP != nil { + t.Errorf("topP: want nil, got %v", *got.TopP) + } + if tt.wantTopP != nil { + if got.TopP == nil { + t.Fatalf("topP: want %v, got nil", *tt.wantTopP) + } + if *got.TopP != *tt.wantTopP { + t.Errorf("topP: want %v, got %v", *tt.wantTopP, *got.TopP) + } + } + if tt.wantMaxTokens != nil { + if got.MaxTokens == nil { + t.Fatalf("maxTokens: want %v, got nil", *tt.wantMaxTokens) + } + if *got.MaxTokens != *tt.wantMaxTokens { + t.Errorf("maxTokens: want %v, got %v", *tt.wantMaxTokens, *got.MaxTokens) + } + } + }) + } +}