diff --git a/go/adk/pkg/models/bedrock.go b/go/adk/pkg/models/bedrock.go index d9db5a842e..216a95dcc4 100644 --- a/go/adk/pkg/models/bedrock.go +++ b/go/adk/pkg/models/bedrock.go @@ -171,20 +171,9 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques // is written with the sanitized name Bedrock already knows about. messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents, nameMap) - // Build inference config - var inferenceConfig *types.InferenceConfiguration - if m.Config.MaxTokens != nil || m.Config.Temperature != nil || m.Config.TopP != nil { - inferenceConfig = &types.InferenceConfiguration{} - if m.Config.MaxTokens != nil { - inferenceConfig.MaxTokens = aws.Int32(int32(*m.Config.MaxTokens)) - } - if m.Config.Temperature != nil { - inferenceConfig.Temperature = aws.Float32(float32(*m.Config.Temperature)) - } - if m.Config.TopP != nil { - inferenceConfig.TopP = aws.Float32(float32(*m.Config.TopP)) - } - } + // temperature/top_p must not be sent when thinking is active. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + _, thinkingEnabled := m.Config.AdditionalModelRequestFields["thinking"] + inferenceConfig := buildInferenceConfig(m.Config, thinkingEnabled) // Build system prompt var systemPrompt []types.SystemContentBlock @@ -248,6 +237,10 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me toolCalls := make(map[int32]*streamingToolCall) var completedToolCalls []*genai.Part + // https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + reasoningBlocks := make(map[int32]*streamingReasoningBlock) + var completedThinkingParts []*genai.Part + // Get the event stream and read events from the channel stream := output.GetStream() defer stream.Close() @@ -295,10 +288,22 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me if tc, ok := toolCalls[blockIdx]; ok && delta.Value.Input != nil { tc.InputJSON += aws.ToString(delta.Value.Input) } + + case *types.ContentBlockDeltaMemberReasoningContent: + if _, ok := reasoningBlocks[blockIdx]; !ok { + reasoningBlocks[blockIdx] = &streamingReasoningBlock{} + } + rb := reasoningBlocks[blockIdx] + switch inner := delta.Value.(type) { + case *types.ReasoningContentBlockDeltaMemberText: + rb.Text.WriteString(inner.Value) + case *types.ReasoningContentBlockDeltaMemberSignature: + rb.Signature = inner.Value + } } } - // Handle content block stop (tool use complete) + // Handle content block stop (tool use or thinking block complete) if stop, ok := event.(*types.ConverseStreamOutputMemberContentBlockStop); ok { blockIdx := aws.ToInt32(stop.Value.ContentBlockIndex) if tc, ok := toolCalls[blockIdx]; ok { @@ -316,7 +321,18 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me Args: args, } completedToolCalls = append(completedToolCalls, &genai.Part{FunctionCall: functionCall}) - delete(toolCalls, blockIdx) // Clean up + delete(toolCalls, blockIdx) + } + if rb, ok := reasoningBlocks[blockIdx]; ok { + part := &genai.Part{ + Thought: true, + Text: rb.Text.String(), + } + if rb.Signature != "" { + part.ThoughtSignature = []byte(rb.Signature) + } + completedThinkingParts = append(completedThinkingParts, part) + delete(reasoningBlocks, blockIdx) } } @@ -337,8 +353,9 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me } } - // Build final response + // thinking parts first; block order must match what Bedrock returned. finalParts := []*genai.Part{} + finalParts = append(finalParts, completedThinkingParts...) text := aggregatedText.String() if text != "" { finalParts = append(finalParts, &genai.Part{Text: text}) @@ -366,6 +383,11 @@ type streamingToolCall struct { InputJSON string // Accumulated JSON input } +type streamingReasoningBlock struct { + Text strings.Builder + Signature string +} + // parseArgs parses the accumulated JSON input into a map func (tc *streamingToolCall) parseArgs() map[string]any { if tc.InputJSON == "" { @@ -403,6 +425,20 @@ func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string, parts := []*genai.Part{} if message, ok := output.Output.(*types.ConverseOutputMemberMessage); ok { for _, block := range message.Value.Content { + // https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + if reasoningBlock, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + if textBlock, ok := reasoningBlock.Value.(*types.ReasoningContentBlockMemberReasoningText); ok { + part := &genai.Part{ + Thought: true, + Text: aws.ToString(textBlock.Value.Text), + } + if textBlock.Value.Signature != nil { + part.ThoughtSignature = []byte(aws.ToString(textBlock.Value.Signature)) + } + parts = append(parts, part) + } + continue + } // Handle text content if textBlock, ok := block.(*types.ContentBlockMemberText); ok { parts = append(parts, &genai.Part{Text: textBlock.Value}) @@ -473,6 +509,15 @@ func documentToMap(doc document.Interface) map[string]any { return result } +const historyToolResultMaxLen = 2000 + +func truncateToolResult(s string, maxLen int) string { + if len(s) <= maxLen { + return s + } + return s[:maxLen] + fmt.Sprintf("\n... [truncated, %d chars omitted]", len(s)-maxLen) +} + // convertGenaiContentsToBedrockMessages converts genai.Content to Bedrock Converse API message format. // nameMap is the original->sanitized tool name map produced by convertGenaiToolsToBedrock. // Any FunctionCall found in the conversation history is written with the sanitized name so @@ -486,17 +531,40 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma idMap := make(map[string]string) idCounter := 0 - for _, content := range contents { + // Bedrock only requires thinking blocks in the last assistant turn before tool results. + // Sending them in earlier turns causes token counts to compound across long sessions. + // Truncate tool results in all turns except the most recent one carrying them. + lastThinkingIdx, lastToolResultIdx := -1, -1 + for i, c := range contents { + if c == nil { + continue + } + for _, p := range c.Parts { + if p == nil { + continue + } + if p.Thought && (c.Role == "model" || c.Role == "assistant") { + lastThinkingIdx = i + } + if p.FunctionResponse != nil && c.Role == "user" { + lastToolResultIdx = i + } + } + } + + for i, content := range contents { if content == nil || len(content.Parts) == 0 { continue } - // Determine role role := types.ConversationRoleUser if content.Role == "model" || content.Role == "assistant" { role = types.ConversationRoleAssistant } + emitThinking := i == lastThinkingIdx + truncateTools := i != lastToolResultIdx + var contentBlocks []types.ContentBlock for _, part := range content.Parts { @@ -504,9 +572,26 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma continue } - // Handle text + // Thought parts also carry Text; check Thought first. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html + if part.Thought { + if !emitThinking { + continue + } + reasoningText := &types.ReasoningTextBlock{ + Text: aws.String(part.Text), + } + if len(part.ThoughtSignature) > 0 { + reasoningText.Signature = aws.String(string(part.ThoughtSignature)) + } + contentBlocks = append(contentBlocks, &types.ContentBlockMemberReasoningContent{ + Value: &types.ReasoningContentBlockMemberReasoningText{ + Value: *reasoningText, + }, + }) + continue + } + if part.Text != "" { - // Check if this is a system message if content.Role == "system" { systemInstruction = part.Text continue @@ -536,10 +621,11 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma continue } - // Handle function response (tool result in Bedrock terminology) if part.FunctionResponse != nil { - // Extract response content result := extractFunctionResponseContent(part.FunctionResponse.Response) + if truncateTools { + result = truncateToolResult(result, historyToolResultMaxLen) + } toolResult := types.ToolResultBlock{ ToolUseId: aws.String(sanitizeBedrockToolID(part.FunctionResponse.ID, idMap, &idCounter)), Content: []types.ToolResultContentBlock{ @@ -641,3 +727,25 @@ func bedrockStopReasonToGenai(reason types.StopReason) genai.FinishReason { return genai.FinishReasonStop } } + +// buildInferenceConfig constructs the Bedrock InferenceConfiguration from a +// BedrockConfig. When thinking is enabled, temperature and top_p must be +// omitted per the Bedrock extended-thinking API contract. +func buildInferenceConfig(cfg *BedrockConfig, thinkingEnabled bool) *types.InferenceConfiguration { + if cfg.MaxTokens == nil && (thinkingEnabled || (cfg.Temperature == nil && cfg.TopP == nil)) { + return nil + } + ic := &types.InferenceConfiguration{} + if cfg.MaxTokens != nil { + ic.MaxTokens = aws.Int32(int32(*cfg.MaxTokens)) + } + if !thinkingEnabled { + if cfg.Temperature != nil { + ic.Temperature = aws.Float32(float32(*cfg.Temperature)) + } + if cfg.TopP != nil { + ic.TopP = aws.Float32(float32(*cfg.TopP)) + } + } + return ic +} diff --git a/go/adk/pkg/models/bedrock_test.go b/go/adk/pkg/models/bedrock_test.go index de2d1c3caf..1094bd0771 100644 --- a/go/adk/pkg/models/bedrock_test.go +++ b/go/adk/pkg/models/bedrock_test.go @@ -2,6 +2,7 @@ package models import ( "encoding/json" + "strings" "testing" "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types" @@ -102,6 +103,41 @@ func TestConvertGenaiContentsToBedrockMessages(t *testing.T) { } }, }, + { + name: "thinking block preserved as ReasoningContent", + contents: []*genai.Content{ + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "let me think", ThoughtSignature: []byte("sig123")}, + {FunctionCall: &genai.FunctionCall{ID: "c1", Name: "get_weather", Args: map[string]any{"location": "Paris"}}}, + }, + }, + }, + wantMsgCount: 1, + checkMsg: func(t *testing.T, msgs []types.Message) { + if len(msgs[0].Content) != 2 { + t.Fatalf("expected 2 blocks (thinking + toolUse), got %d", len(msgs[0].Content)) + } + rb, ok := msgs[0].Content[0].(*types.ContentBlockMemberReasoningContent) + if !ok { + t.Fatalf("block 0: want *ContentBlockMemberReasoningContent, got %T", msgs[0].Content[0]) + } + rt, ok := rb.Value.(*types.ReasoningContentBlockMemberReasoningText) + if !ok { + t.Fatalf("reasoning value: want *ReasoningContentBlockMemberReasoningText, got %T", rb.Value) + } + if *rt.Value.Text != "let me think" { + t.Errorf("text: want %q, got %q", "let me think", *rt.Value.Text) + } + if *rt.Value.Signature != "sig123" { + t.Errorf("signature: want %q, got %q", "sig123", *rt.Value.Signature) + } + if _, ok := msgs[0].Content[1].(*types.ContentBlockMemberToolUse); !ok { + t.Errorf("block 1: want *ContentBlockMemberToolUse, got %T", msgs[0].Content[1]) + } + }, + }, } for _, tt := range tests { @@ -424,3 +460,215 @@ func TestStreamingToolCallParseArgs(t *testing.T) { }) } } + +func TestThinkingOnlyInLastAssistantTurn(t *testing.T) { + contents := []*genai.Content{ + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "first think", ThoughtSignature: []byte("sig1")}, + {FunctionCall: &genai.FunctionCall{ID: "c1", Name: "tool_a", Args: map[string]any{}}}, + }, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c1", Name: "tool_a", Response: map[string]any{"r": "v1"}}}}, + }, + { + Role: "model", + Parts: []*genai.Part{ + {Thought: true, Text: "second think", ThoughtSignature: []byte("sig2")}, + {FunctionCall: &genai.FunctionCall{ID: "c2", Name: "tool_b", Args: map[string]any{}}}, + }, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c2", Name: "tool_b", Response: map[string]any{"r": "v2"}}}}, + }, + } + + msgs, _ := convertGenaiContentsToBedrockMessages(contents, nil) + if len(msgs) != 4 { + t.Fatalf("want 4 messages, got %d", len(msgs)) + } + + // First assistant turn must NOT contain reasoning content. + for _, block := range msgs[0].Content { + if _, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + t.Error("first assistant turn must not contain reasoning content") + } + } + + // Last assistant turn (index 2) must contain reasoning content. + hasReasoning := false + for _, block := range msgs[2].Content { + if _, ok := block.(*types.ContentBlockMemberReasoningContent); ok { + hasReasoning = true + } + } + if !hasReasoning { + t.Error("last assistant turn must contain reasoning content") + } +} + +func TestHistoricalToolResultTruncation(t *testing.T) { + longOutput := strings.Repeat("x", historyToolResultMaxLen+500) + contents := []*genai.Content{ + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c1", Name: "tool_a", Response: map[string]any{"result": longOutput}}}}, + }, + { + Role: "user", + Parts: []*genai.Part{{FunctionResponse: &genai.FunctionResponse{ID: "c2", Name: "tool_b", Response: map[string]any{"result": longOutput}}}}, + }, + } + + msgs, _ := convertGenaiContentsToBedrockMessages(contents, nil) + if len(msgs) != 2 { + t.Fatalf("want 2 messages, got %d", len(msgs)) + } + + extractText := func(msg types.Message) string { + for _, block := range msg.Content { + if tr, ok := block.(*types.ContentBlockMemberToolResult); ok { + for _, c := range tr.Value.Content { + if txt, ok := c.(*types.ToolResultContentBlockMemberText); ok { + return txt.Value + } + } + } + } + return "" + } + + first := extractText(msgs[0]) + if len(first) >= len(longOutput) { + t.Errorf("historical tool result should be truncated, got len=%d", len(first)) + } + + last := extractText(msgs[1]) + if len(last) != len(longOutput) { + t.Errorf("latest tool result must not be truncated, got len=%d want %d", len(last), len(longOutput)) + } +} + +func TestTruncateToolResult(t *testing.T) { + cases := []struct { + name string + input string + maxLen int + wantLen int + wantMsg bool + }{ + {"no truncation needed", "short", 100, 5, false}, + {"exact boundary", strings.Repeat("a", 100), 100, 100, false}, + {"truncated", strings.Repeat("a", 150), 100, -1, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := truncateToolResult(tc.input, tc.maxLen) + if tc.wantMsg { + if len(got) <= tc.maxLen { + t.Errorf("expected truncated result longer than maxLen, got %d", len(got)) + } + if !strings.Contains(got, "truncated") { + t.Error("truncated result must contain truncation notice") + } + } else { + if len(got) != tc.wantLen { + t.Errorf("want len %d, got %d", tc.wantLen, len(got)) + } + } + }) + } +} + +func TestBuildInferenceConfig(t *testing.T) { + fTemp := float64(0.7) + fTopP := float64(0.9) + maxTok := 1000 + + tests := []struct { + name string + cfg BedrockConfig + thinkingActive bool + wantNil bool + wantTemp *float32 + wantTopP *float32 + wantMaxTokens *int32 + }{ + { + name: "thinking drops temperature and topP", + cfg: BedrockConfig{Temperature: &fTemp, TopP: &fTopP}, + thinkingActive: true, + wantNil: true, + }, + { + name: "thinking with maxTokens keeps only maxTokens", + cfg: BedrockConfig{Temperature: &fTemp, TopP: &fTopP, MaxTokens: &maxTok}, + thinkingActive: true, + wantNil: false, + wantMaxTokens: func() *int32 { v := int32(1000); return &v }(), + }, + { + name: "no thinking passes temperature and topP", + cfg: BedrockConfig{Temperature: &fTemp, TopP: &fTopP}, + thinkingActive: false, + wantNil: false, + wantTemp: func() *float32 { v := float32(0.7); return &v }(), + wantTopP: func() *float32 { v := float32(0.9); return &v }(), + }, + { + name: "all nil returns nil", + cfg: BedrockConfig{}, + thinkingActive: false, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := buildInferenceConfig(&tt.cfg, tt.thinkingActive) + if tt.wantNil { + if got != nil { + t.Fatalf("want nil, got %+v", got) + } + return + } + if got == nil { + t.Fatal("want non-nil InferenceConfiguration, got nil") + } + if tt.wantTemp == nil && got.Temperature != nil { + t.Errorf("temperature: want nil, got %v", *got.Temperature) + } + if tt.wantTemp != nil { + if got.Temperature == nil { + t.Fatalf("temperature: want %v, got nil", *tt.wantTemp) + } + if *got.Temperature != *tt.wantTemp { + t.Errorf("temperature: want %v, got %v", *tt.wantTemp, *got.Temperature) + } + } + if tt.wantTopP == nil && got.TopP != nil { + t.Errorf("topP: want nil, got %v", *got.TopP) + } + if tt.wantTopP != nil { + if got.TopP == nil { + t.Fatalf("topP: want %v, got nil", *tt.wantTopP) + } + if *got.TopP != *tt.wantTopP { + t.Errorf("topP: want %v, got %v", *tt.wantTopP, *got.TopP) + } + } + if tt.wantMaxTokens != nil { + if got.MaxTokens == nil { + t.Fatalf("maxTokens: want %v, got nil", *tt.wantMaxTokens) + } + if *got.MaxTokens != *tt.wantMaxTokens { + t.Errorf("maxTokens: want %v, got %v", *tt.wantMaxTokens, *got.MaxTokens) + } + } + }) + } +} diff --git a/go/adk/pkg/telemetry/tracing.go b/go/adk/pkg/telemetry/tracing.go index 695ec0711e..81e06d9aea 100644 --- a/go/adk/pkg/telemetry/tracing.go +++ b/go/adk/pkg/telemetry/tracing.go @@ -156,7 +156,7 @@ func newTracerProvider(ctx context.Context, res *resource.Resource) (*sdktrace.T return sdktrace.NewTracerProvider( sdktrace.WithSpanProcessor(kagentAttributesSpanProcessor{}), - sdktrace.WithBatcher(exporter), + sdktrace.WithBatcher(newTruncatingExporter(exporter)), sdktrace.WithResource(res), ), nil } diff --git a/go/adk/pkg/telemetry/truncating_exporter.go b/go/adk/pkg/telemetry/truncating_exporter.go new file mode 100644 index 0000000000..9866ecdd8a --- /dev/null +++ b/go/adk/pkg/telemetry/truncating_exporter.go @@ -0,0 +1,79 @@ +package telemetry + +import ( + "context" + "fmt" + + "go.opentelemetry.io/otel/attribute" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + "go.opentelemetry.io/otel/sdk/trace/tracetest" +) + +// maxSpanAttributeBytes caps individual string span attributes before export. +// Tempo's default gRPC message limit is 4 MB; large tool responses can exceed it. +const maxSpanAttributeBytes = 16 * 1024 + +type truncatingExporter struct { + inner sdktrace.SpanExporter +} + +func newTruncatingExporter(inner sdktrace.SpanExporter) sdktrace.SpanExporter { + return &truncatingExporter{inner: inner} +} + +func (e *truncatingExporter) ExportSpans(ctx context.Context, spans []sdktrace.ReadOnlySpan) error { + out := make([]sdktrace.ReadOnlySpan, len(spans)) + for i, s := range spans { + out[i] = truncateSpan(s) + } + return e.inner.ExportSpans(ctx, out) +} + +func (e *truncatingExporter) Shutdown(ctx context.Context) error { + return e.inner.Shutdown(ctx) +} + +func truncateSpan(s sdktrace.ReadOnlySpan) sdktrace.ReadOnlySpan { + attrs := s.Attributes() + needsCut := false + for _, a := range attrs { + if a.Value.Type() == attribute.STRING && len(a.Value.AsString()) > maxSpanAttributeBytes { + needsCut = true + break + } + } + if !needsCut { + return s + } + newAttrs := make([]attribute.KeyValue, len(attrs)) + for i, a := range attrs { + if a.Value.Type() == attribute.STRING { + v := a.Value.AsString() + if len(v) > maxSpanAttributeBytes { + newAttrs[i] = attribute.String(string(a.Key), + v[:maxSpanAttributeBytes]+fmt.Sprintf(" ...[truncated, %d bytes omitted]", len(v)-maxSpanAttributeBytes)) + continue + } + } + newAttrs[i] = a + } + stub := tracetest.SpanStub{ + Name: s.Name(), + SpanContext: s.SpanContext(), + Parent: s.Parent(), + SpanKind: s.SpanKind(), + StartTime: s.StartTime(), + EndTime: s.EndTime(), + Attributes: newAttrs, + Events: s.Events(), + Links: s.Links(), + Status: s.Status(), + DroppedAttributes: s.DroppedAttributes(), + DroppedEvents: s.DroppedEvents(), + DroppedLinks: s.DroppedLinks(), + ChildSpanCount: s.ChildSpanCount(), + Resource: s.Resource(), + InstrumentationLibrary: s.InstrumentationScope(), + } + return stub.Snapshot() +}