Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 131 additions & 23 deletions go/adk/pkg/models/bedrock.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,9 @@ func (m *BedrockModel) GenerateContent(ctx context.Context, req *model.LLMReques
// is written with the sanitized name Bedrock already knows about.
messages, systemInstruction := convertGenaiContentsToBedrockMessages(req.Contents, nameMap)

// Build inference config
var inferenceConfig *types.InferenceConfiguration
if m.Config.MaxTokens != nil || m.Config.Temperature != nil || m.Config.TopP != nil {
inferenceConfig = &types.InferenceConfiguration{}
if m.Config.MaxTokens != nil {
inferenceConfig.MaxTokens = aws.Int32(int32(*m.Config.MaxTokens))
}
if m.Config.Temperature != nil {
inferenceConfig.Temperature = aws.Float32(float32(*m.Config.Temperature))
}
if m.Config.TopP != nil {
inferenceConfig.TopP = aws.Float32(float32(*m.Config.TopP))
}
}
// temperature/top_p must not be sent when thinking is active. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html
_, thinkingEnabled := m.Config.AdditionalModelRequestFields["thinking"]
inferenceConfig := buildInferenceConfig(m.Config, thinkingEnabled)

// Build system prompt
var systemPrompt []types.SystemContentBlock
Expand Down Expand Up @@ -248,6 +237,10 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
toolCalls := make(map[int32]*streamingToolCall)
var completedToolCalls []*genai.Part

// https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html
reasoningBlocks := make(map[int32]*streamingReasoningBlock)
var completedThinkingParts []*genai.Part

// Get the event stream and read events from the channel
stream := output.GetStream()
defer stream.Close()
Expand Down Expand Up @@ -295,10 +288,22 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
if tc, ok := toolCalls[blockIdx]; ok && delta.Value.Input != nil {
tc.InputJSON += aws.ToString(delta.Value.Input)
}

case *types.ContentBlockDeltaMemberReasoningContent:
if _, ok := reasoningBlocks[blockIdx]; !ok {
reasoningBlocks[blockIdx] = &streamingReasoningBlock{}
}
rb := reasoningBlocks[blockIdx]
switch inner := delta.Value.(type) {
case *types.ReasoningContentBlockDeltaMemberText:
rb.Text.WriteString(inner.Value)
case *types.ReasoningContentBlockDeltaMemberSignature:
rb.Signature = inner.Value
}
}
}

// Handle content block stop (tool use complete)
// Handle content block stop (tool use or thinking block complete)
if stop, ok := event.(*types.ConverseStreamOutputMemberContentBlockStop); ok {
blockIdx := aws.ToInt32(stop.Value.ContentBlockIndex)
if tc, ok := toolCalls[blockIdx]; ok {
Expand All @@ -316,7 +321,18 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
Args: args,
}
completedToolCalls = append(completedToolCalls, &genai.Part{FunctionCall: functionCall})
delete(toolCalls, blockIdx) // Clean up
delete(toolCalls, blockIdx)
}
if rb, ok := reasoningBlocks[blockIdx]; ok {
part := &genai.Part{
Thought: true,
Text: rb.Text.String(),
}
if rb.Signature != "" {
part.ThoughtSignature = []byte(rb.Signature)
}
completedThinkingParts = append(completedThinkingParts, part)
delete(reasoningBlocks, blockIdx)
}
}

Expand All @@ -337,8 +353,9 @@ func (m *BedrockModel) generateStreaming(ctx context.Context, modelId string, me
}
}

// Build final response
// thinking parts first; block order must match what Bedrock returned.
finalParts := []*genai.Part{}
finalParts = append(finalParts, completedThinkingParts...)
text := aggregatedText.String()
if text != "" {
finalParts = append(finalParts, &genai.Part{Text: text})
Expand Down Expand Up @@ -366,6 +383,11 @@ type streamingToolCall struct {
InputJSON string // Accumulated JSON input
}

type streamingReasoningBlock struct {
Text strings.Builder
Signature string
}

// parseArgs parses the accumulated JSON input into a map
func (tc *streamingToolCall) parseArgs() map[string]any {
if tc.InputJSON == "" {
Expand Down Expand Up @@ -403,6 +425,20 @@ func (m *BedrockModel) generateNonStreaming(ctx context.Context, modelId string,
parts := []*genai.Part{}
if message, ok := output.Output.(*types.ConverseOutputMemberMessage); ok {
for _, block := range message.Value.Content {
// https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html
if reasoningBlock, ok := block.(*types.ContentBlockMemberReasoningContent); ok {
if textBlock, ok := reasoningBlock.Value.(*types.ReasoningContentBlockMemberReasoningText); ok {
part := &genai.Part{
Thought: true,
Text: aws.ToString(textBlock.Value.Text),
}
if textBlock.Value.Signature != nil {
part.ThoughtSignature = []byte(aws.ToString(textBlock.Value.Signature))
}
parts = append(parts, part)
}
continue
}
// Handle text content
if textBlock, ok := block.(*types.ContentBlockMemberText); ok {
parts = append(parts, &genai.Part{Text: textBlock.Value})
Expand Down Expand Up @@ -473,6 +509,15 @@ func documentToMap(doc document.Interface) map[string]any {
return result
}

const historyToolResultMaxLen = 2000

func truncateToolResult(s string, maxLen int) string {
if len(s) <= maxLen {
return s
}
return s[:maxLen] + fmt.Sprintf("\n... [truncated, %d chars omitted]", len(s)-maxLen)
}

// convertGenaiContentsToBedrockMessages converts genai.Content to Bedrock Converse API message format.
// nameMap is the original->sanitized tool name map produced by convertGenaiToolsToBedrock.
// Any FunctionCall found in the conversation history is written with the sanitized name so
Expand All @@ -486,27 +531,67 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma
idMap := make(map[string]string)
idCounter := 0

for _, content := range contents {
// Bedrock only requires thinking blocks in the last assistant turn before tool results.
// Sending them in earlier turns causes token counts to compound across long sessions.
// Truncate tool results in all turns except the most recent one carrying them.
lastThinkingIdx, lastToolResultIdx := -1, -1
for i, c := range contents {
if c == nil {
continue
}
for _, p := range c.Parts {
if p == nil {
continue
}
if p.Thought && (c.Role == "model" || c.Role == "assistant") {
lastThinkingIdx = i
}
if p.FunctionResponse != nil && c.Role == "user" {
lastToolResultIdx = i
}
}
}

for i, content := range contents {
if content == nil || len(content.Parts) == 0 {
continue
}

// Determine role
role := types.ConversationRoleUser
if content.Role == "model" || content.Role == "assistant" {
role = types.ConversationRoleAssistant
}

emitThinking := i == lastThinkingIdx
truncateTools := i != lastToolResultIdx

var contentBlocks []types.ContentBlock

for _, part := range content.Parts {
if part == nil {
continue
}

// Handle text
// Thought parts also carry Text; check Thought first. https://docs.aws.amazon.com/bedrock/latest/userguide/claude-messages-extended-thinking.html
if part.Thought {
if !emitThinking {
continue
}
reasoningText := &types.ReasoningTextBlock{
Text: aws.String(part.Text),
}
if len(part.ThoughtSignature) > 0 {
reasoningText.Signature = aws.String(string(part.ThoughtSignature))
}
contentBlocks = append(contentBlocks, &types.ContentBlockMemberReasoningContent{
Value: &types.ReasoningContentBlockMemberReasoningText{
Value: *reasoningText,
},
})
continue
}

if part.Text != "" {
// Check if this is a system message
if content.Role == "system" {
systemInstruction = part.Text
continue
Expand Down Expand Up @@ -536,10 +621,11 @@ func convertGenaiContentsToBedrockMessages(contents []*genai.Content, nameMap ma
continue
}

// Handle function response (tool result in Bedrock terminology)
if part.FunctionResponse != nil {
// Extract response content
result := extractFunctionResponseContent(part.FunctionResponse.Response)
if truncateTools {
result = truncateToolResult(result, historyToolResultMaxLen)
}
toolResult := types.ToolResultBlock{
ToolUseId: aws.String(sanitizeBedrockToolID(part.FunctionResponse.ID, idMap, &idCounter)),
Content: []types.ToolResultContentBlock{
Expand Down Expand Up @@ -641,3 +727,25 @@ func bedrockStopReasonToGenai(reason types.StopReason) genai.FinishReason {
return genai.FinishReasonStop
}
}

// buildInferenceConfig constructs the Bedrock InferenceConfiguration from a
// BedrockConfig. When thinking is enabled, temperature and top_p must be
// omitted per the Bedrock extended-thinking API contract.
func buildInferenceConfig(cfg *BedrockConfig, thinkingEnabled bool) *types.InferenceConfiguration {
if cfg.MaxTokens == nil && (thinkingEnabled || (cfg.Temperature == nil && cfg.TopP == nil)) {
return nil
}
ic := &types.InferenceConfiguration{}
if cfg.MaxTokens != nil {
ic.MaxTokens = aws.Int32(int32(*cfg.MaxTokens))
}
if !thinkingEnabled {
if cfg.Temperature != nil {
ic.Temperature = aws.Float32(float32(*cfg.Temperature))
}
if cfg.TopP != nil {
ic.TopP = aws.Float32(float32(*cfg.TopP))
}
}
return ic
}
Loading
Loading