From 589328ea747ad4a994223af5789320e171ea2aa7 Mon Sep 17 00:00:00 2001 From: Google Team Member Date: Fri, 17 Apr 2026 12:30:46 -0700 Subject: [PATCH] feat: Support ChatCompletionChunk to LlmResponse conversion This is part of a larger chain of commits for adding chat completion API support to the Apigee model. PiperOrigin-RevId: 901419365 --- .../models/chat/ChatCompletionsCommon.java | 2 +- .../models/chat/ChatCompletionsResponse.java | 469 +++++++++++++++--- .../chat/ChatCompletionsResponseTest.java | 189 ++++++- 3 files changed, 600 insertions(+), 60 deletions(-) diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java index cd5b4d7bf..e26546313 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java @@ -136,7 +136,7 @@ public FunctionCall toFunctionCall(@Nullable String toolCallId) { if (name != null) { fcBuilder.name(name); } - if (arguments != null) { + if (arguments != null && !arguments.isEmpty()) { try { Map args = objectMapper.readValue(arguments, new TypeReference>() {}); diff --git a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java index c52389aa3..9645016a9 100644 --- a/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java +++ b/core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java @@ -19,16 +19,26 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.core.JsonProcessingException; +import com.fasterxml.jackson.core.type.TypeReference; +import com.fasterxml.jackson.databind.ObjectMapper; import com.google.adk.models.LlmResponse; +import com.google.common.collect.ImmutableList; import com.google.genai.types.Content; import com.google.genai.types.CustomMetadata; import com.google.genai.types.FinishReason; import com.google.genai.types.FinishReason.Known; +import com.google.genai.types.FunctionCall; import com.google.genai.types.GenerateContentResponseUsageMetadata; import com.google.genai.types.Part; -import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.TreeMap; import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * Data Transfer Objects for Chat Completion and Chat Completion Chunk API responses. @@ -36,10 +46,61 @@ *

See https://developers.openai.com/api/reference/resources/chat */ @JsonIgnoreProperties(ignoreUnknown = true) -final class ChatCompletionsResponse { +public final class ChatCompletionsResponse { private ChatCompletionsResponse() {} + static @Nullable FinishReason mapFinishReason(String reason) { + if (reason == null) { + return null; + } + return switch (reason) { + case "stop", "tool_calls" -> new FinishReason(Known.STOP.toString()); + case "length" -> new FinishReason(Known.MAX_TOKENS.toString()); + case "content_filter" -> new FinishReason(Known.SAFETY.toString()); + default -> new FinishReason(Known.OTHER.toString()); + }; + } + + static @Nullable GenerateContentResponseUsageMetadata mapUsage(Usage usage) { + if (usage == null) { + return null; + } + GenerateContentResponseUsageMetadata.Builder builder = + GenerateContentResponseUsageMetadata.builder(); + if (usage.promptTokens != null) { + builder.promptTokenCount(usage.promptTokens); + } + if (usage.completionTokens != null) { + builder.candidatesTokenCount(usage.completionTokens); + } + if (usage.totalTokens != null) { + builder.totalTokenCount(usage.totalTokens); + } + if (usage.thoughtsTokenCount != null) { + builder.thoughtsTokenCount(usage.thoughtsTokenCount); + } else if (usage.completionTokensDetails != null + && usage.completionTokensDetails.reasoningTokens != null) { + builder.thoughtsTokenCount(usage.completionTokensDetails.reasoningTokens); + } + return builder.build(); + } + + /** + * Maps the chat role string to the model role string. + * + * @param role the chat role string, or {@code null}. + * @return the model role string, or the input role if it doesn't match the assistant role. + */ + static @Nullable String mapRole(@Nullable String role) { + if (role == null) { + return null; + } + return role.equals(ChatCompletionsCommon.ROLE_ASSISTANT) + ? ChatCompletionsCommon.ROLE_MODEL + : role; + } + /** * See * https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion%20%3E%20(schema) @@ -95,49 +156,10 @@ public LlmResponse toLlmResponse() { builder.usageMetadata(mapUsage(usage)); } - List customMetadataList = buildCustomMetadata(); + ImmutableList customMetadataList = buildCustomMetadata(); return builder.customMetadata(customMetadataList).build(); } - /** - * Maps the finish reason string to a {@link FinishReason}. - * - * @param reason the finish reason string. - * @return the {@link FinishReason}, or {@code null} if the input reason is null. - */ - private @Nullable FinishReason mapFinishReason(String reason) { - if (reason == null) { - return null; - } - return switch (reason) { - case "stop", "tool_calls" -> new FinishReason(Known.STOP.toString()); - case "length" -> new FinishReason(Known.MAX_TOKENS.toString()); - case "content_filter" -> new FinishReason(Known.SAFETY.toString()); - default -> new FinishReason(Known.OTHER.toString()); - }; - } - - private GenerateContentResponseUsageMetadata mapUsage(Usage usage) { - GenerateContentResponseUsageMetadata.Builder builder = - GenerateContentResponseUsageMetadata.builder(); - if (usage.promptTokens != null) { - builder.promptTokenCount(usage.promptTokens); - } - if (usage.completionTokens != null) { - builder.candidatesTokenCount(usage.completionTokens); - } - if (usage.totalTokens != null) { - builder.totalTokenCount(usage.totalTokens); - } - if (usage.thoughtsTokenCount != null) { - builder.thoughtsTokenCount(usage.thoughtsTokenCount); - } else if (usage.completionTokensDetails != null - && usage.completionTokensDetails.reasoningTokens != null) { - builder.thoughtsTokenCount(usage.completionTokensDetails.reasoningTokens); - } - return builder.build(); - } - /** * Maps the chosen completion to a {@link Content} object. * @@ -152,14 +174,8 @@ private Content mapChoiceToContent(@Nullable Choice choice) { return contentBuilder.build(); } - private String mapRole(@Nullable String role) { - return (role != null && role.equals(ChatCompletionsCommon.ROLE_ASSISTANT)) - ? ChatCompletionsCommon.ROLE_MODEL - : role; - } - - private List mapMessageToParts(Message message) { - List parts = new ArrayList<>(); + private ImmutableList mapMessageToParts(Message message) { + ImmutableList.Builder parts = ImmutableList.builder(); if (message.content != null) { parts.add(Part.fromText(message.content)); } @@ -169,18 +185,19 @@ private List mapMessageToParts(Message message) { if (message.toolCalls != null) { parts.addAll(mapToolCallsToParts(message.toolCalls)); } - return parts; + return parts.build(); } - private List mapToolCallsToParts(List toolCalls) { - List parts = new ArrayList<>(); + private ImmutableList mapToolCallsToParts( + List toolCalls) { + ImmutableList.Builder parts = ImmutableList.builder(); for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) { Part part = toolCall.toPart(); if (part != null) { parts.add(part); } } - return parts; + return parts.build(); } /** @@ -188,8 +205,8 @@ private List mapToolCallsToParts(List tool * * @return a list of {@link CustomMetadata}, which will be empty if no relevant fields are set. */ - private List buildCustomMetadata() { - List customMetadataList = new ArrayList<>(); + private ImmutableList buildCustomMetadata() { + ImmutableList.Builder customMetadataList = ImmutableList.builder(); if (id != null) { customMetadataList.add( CustomMetadata.builder() @@ -225,7 +242,7 @@ private List buildCustomMetadata() { .stringValue(serviceTier) .build()); } - return customMetadataList; + return customMetadataList.build(); } } @@ -489,4 +506,342 @@ static class Audio { /** See class definition for more details. */ public String transcript; } + + /** Accumulates chunks into a final response. */ + static class ChatCompletionChunkCollection { + private static final ObjectMapper objectMapper = new ObjectMapper(); + private static final Logger logger = + LoggerFactory.getLogger(ChatCompletionChunkCollection.class); + + private final StringBuilder contentParts = new StringBuilder(); + private final Map toolCallParts = new TreeMap<>(); + private final Map toolCallArgsAccumulator = new HashMap<>(); + private String role = ""; + private String model = ""; + private Usage usage; + private final Map customMetadataMap = new HashMap<>(); + + private ImmutableList getCustomMetadataList() { + ImmutableList.Builder list = ImmutableList.builder(); + for (Entry entry : customMetadataMap.entrySet()) { + list.add( + CustomMetadata.builder().key(entry.getKey()).stringValue(entry.getValue()).build()); + } + return list.build(); + } + + /** + * Processes a single chunk of a chat completion response. + * + * @param chunk the chunk to process, or {@code null}. + * @return a list of {@link LlmResponse} objects generated from this chunk. + */ + public ImmutableList processChunk(ChatCompletionChunk chunk) { + if (chunk == null) { + return ImmutableList.of(); + } + + updateState(chunk); + + ImmutableList.Builder responses = ImmutableList.builder(); + if (chunk.choices == null || chunk.choices.isEmpty()) { + addGenericResponseIfSet(responses); + return responses.build(); + } + + // The ADK only supports n=1 choices. If more than 1 choice is returned, all choices + // after the first will be dropped. + if (chunk.choices.size() > 1) { + logger.error( + "Multiple choices found in streaming response but only the first one will be used."); + } + ChunkChoice choice = chunk.choices.get(0); + + ImmutableList chunkParts = mapDeltaToParts(choice); + + responses.add(buildPartialResponse(chunkParts)); + + if (choice.finishReason != null && !choice.finishReason.isEmpty()) { + responses.add(buildFinalResponse(choice)); + } + + return responses.build(); + } + + /** + * Updates the internal state (model, usage, metadata) from the chunk. + * + * @param chunk the chunk to read from. + */ + private void updateState(ChatCompletionChunk chunk) { + if (chunk.model != null) { + this.model = chunk.model; + } + if (chunk.usage != null) { + this.usage = chunk.usage; + } + + if (chunk.id != null) { + customMetadataMap.put(ChatCompletionsCommon.METADATA_KEY_ID, chunk.id); + } + if (chunk.created != null) { + customMetadataMap.put(ChatCompletionsCommon.METADATA_KEY_CREATED, chunk.created.toString()); + } + if (chunk.object != null) { + customMetadataMap.put(ChatCompletionsCommon.METADATA_KEY_OBJECT, chunk.object); + } + if (chunk.systemFingerprint != null) { + customMetadataMap.put( + ChatCompletionsCommon.METADATA_KEY_SYSTEM_FINGERPRINT, chunk.systemFingerprint); + } + if (chunk.serviceTier != null) { + customMetadataMap.put(ChatCompletionsCommon.METADATA_KEY_SERVICE_TIER, chunk.serviceTier); + } + } + + /** + * Adds a generic response to the list if usage or metadata is set but choices are empty. + * + * @param responses the list to add to. + */ + private void addGenericResponseIfSet(ImmutableList.Builder responses) { + if (this.usage != null || !customMetadataMap.isEmpty()) { + responses.add( + LlmResponse.builder() + .partial(true) + .modelVersion(this.model) + .usageMetadata(mapUsage(this.usage)) + .customMetadata(getCustomMetadataList()) + .build()); + } + } + + /** + * Maps the choice's delta to a list of parts and updates state. + * + * @param choice the choice to map. + * @return a list of {@link Part}s for this chunk. + */ + private ImmutableList mapDeltaToParts(ChunkChoice choice) { + ImmutableList.Builder chunkParts = ImmutableList.builder(); + if (choice.delta != null) { + updateRole(choice.delta.role); + appendContent(choice.delta.content, chunkParts); + appendRefusal(choice.delta.refusal, chunkParts); + appendToolCalls(choice.delta.toolCalls, chunkParts); + } + return chunkParts.build(); + } + + /** + * Updates the accumulated role if the delta contains a valid role. + * + * @param deltaRole the role string from the delta, or {@code null}. + */ + private void updateRole(@Nullable String deltaRole) { + if (deltaRole != null && !deltaRole.isEmpty()) { + String mapped = ChatCompletionsResponse.mapRole(deltaRole); + if (mapped != null) { + this.role = mapped; + } + } + } + + /** + * Appends content to the accumulator and adds it to the chunk parts. + * + * @param content the content string, or {@code null}. + * @param chunkParts the list of parts for this chunk. + */ + private void appendContent(@Nullable String content, ImmutableList.Builder chunkParts) { + if (content != null && !content.isEmpty()) { + contentParts.append(content); + chunkParts.add(Part.fromText(content)); + } + } + + /** + * Appends refusal to the accumulator and adds it to the chunk parts. + * + * @param refusal the refusal string, or {@code null}. + * @param chunkParts the list of parts for this chunk. + */ + private void appendRefusal(@Nullable String refusal, ImmutableList.Builder chunkParts) { + if (refusal != null && !refusal.isEmpty()) { + if (contentParts.length() > 0) { + contentParts.append("\n"); + } + contentParts.append(refusal); + chunkParts.add(Part.fromText(refusal)); + } + } + + /** + * Appends tool calls to the accumulator and adds them to the chunk parts. + * + * @param toolCalls the list of tool calls, or {@code null}. + * @param chunkParts the list of parts for this chunk. + */ + private void appendToolCalls( + @Nullable List toolCalls, + ImmutableList.Builder chunkParts) { + if (toolCalls != null) { + for (ChatCompletionsCommon.ToolCall toolCall : toolCalls) { + Part p = upsertToolCall(toolCall); + if (p != null) { + chunkParts.add(p); + } + } + } + } + + /** + * Builds a partial {@link LlmResponse} for the current chunk parts. + * + * @param chunkParts the parts for this chunk. + * @return the partial response. + */ + private LlmResponse buildPartialResponse(List chunkParts) { + return LlmResponse.builder() + .partial(true) + .content(Content.builder().role(this.role).parts(chunkParts).build()) + .modelVersion(this.model) + .usageMetadata(mapUsage(this.usage)) + .customMetadata(getCustomMetadataList()) + .build(); + } + + /** + * Builds the final {@link LlmResponse} with all accumulated content. + * + * @param choice the choice containing the finish reason. + * @return the final response. + */ + private LlmResponse buildFinalResponse(ChunkChoice choice) { + return LlmResponse.builder() + .content(Content.builder().role(this.role).parts(getContentParts()).build()) + .finishReason(ChatCompletionsResponse.mapFinishReason(choice.finishReason)) + .modelVersion(this.model) + .usageMetadata(mapUsage(this.usage)) + .customMetadata(getCustomMetadataList()) + .build(); + } + + /** + * Upserts a tool call from a chunk into the collection and returns the part for this chunk. + * + * @param toolCall the tool call from the chunk. + * @return the {@link Part} to emit for this chunk, or {@code null} if it cannot be converted. + */ + private Part upsertToolCall(ChatCompletionsCommon.ToolCall toolCall) { + int index = toolCall.index != null ? toolCall.index : toolCallParts.size(); + + initializeToolCallState(index); + updateAccumulatedToolCall(index, toolCall); + + return buildChunkToolCallPart(toolCall); + } + + /** + * Initializes the state for a new tool call index if it doesn't exist. + * + * @param index the index of the tool call. + */ + private void initializeToolCallState(int index) { + if (!toolCallParts.containsKey(index)) { + toolCallParts.put( + index, Part.builder().functionCall(FunctionCall.builder().build()).build()); + toolCallArgsAccumulator.put(index, new StringBuilder()); + } + } + + /** + * Updates the accumulated tool call state with data from the chunk. + * + * @param index the index of the tool call. + * @param toolCall the tool call from the chunk. + */ + private void updateAccumulatedToolCall(int index, ChatCompletionsCommon.ToolCall toolCall) { + Part part = toolCallParts.get(index); + FunctionCall.Builder fcBuilder = + part.functionCall().isPresent() + ? part.functionCall().get().toBuilder() + : FunctionCall.builder(); + + if (toolCall.id != null) { + fcBuilder.id(toolCall.id); + } + + appendFunctionDetails(fcBuilder, toolCall.function, index); + + part = toolCall.applyThoughtSignature(part); + Part updatedPart = part.toBuilder().functionCall(fcBuilder.build()).build(); + toolCallParts.put(index, updatedPart); + } + + private void appendFunctionDetails( + FunctionCall.Builder fcBuilder, ChatCompletionsCommon.Function function, int index) { + if (function == null) { + return; + } + if (function.name != null) { + fcBuilder.name(function.name); + } + if (function.arguments != null) { + toolCallArgsAccumulator.get(index).append(function.arguments); + } + } + + /** + * Builds the {@link Part} for the current chunk's tool call. + * + * @param toolCall the tool call from the chunk. + * @return the {@link Part} for this chunk. + */ + private Part buildChunkToolCallPart(ChatCompletionsCommon.ToolCall toolCall) { + Part chunkPart = toolCall.toPart(); + if (chunkPart == null) { + FunctionCall.Builder chunkFcBuilder = FunctionCall.builder(); + if (toolCall.id != null) { + chunkFcBuilder.id(toolCall.id); + } + chunkPart = Part.builder().functionCall(chunkFcBuilder.build()).build(); + chunkPart = toolCall.applyThoughtSignature(chunkPart); + } + return chunkPart; + } + + private ImmutableList getContentParts() { + ImmutableList.Builder parts = ImmutableList.builder(); + if (contentParts.length() > 0) { + parts.add(Part.fromText(contentParts.toString())); + } + + // If a server sends keys 0 and 2 but not 1 then squash the indices and + // return parts at indices 0 and 1. + ImmutableList sortedKeys = ImmutableList.sortedCopyOf(toolCallParts.keySet()); + + for (int index : sortedKeys) { + Part part = toolCallParts.get(index); + if (part != null && part.functionCall().isPresent()) { + FunctionCall fc = part.functionCall().get(); + StringBuilder argsSb = toolCallArgsAccumulator.get(index); + if (argsSb != null && argsSb.length() > 0) { + try { + Map args = + objectMapper.readValue( + argsSb.toString(), new TypeReference>() {}); + fc = fc.toBuilder().args(args).build(); + part = part.toBuilder().functionCall(fc).build(); + } catch (JsonProcessingException e) { + throw new IllegalArgumentException( + "Failed to parse final tool call arguments: " + argsSb, e); + } + } + } + parts.add(part); + } + return parts.build(); + } + } } diff --git a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java index dd1a5d85a..ad1839019 100644 --- a/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java +++ b/core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java @@ -22,10 +22,15 @@ import com.google.adk.models.LlmResponse; import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletion; import com.google.adk.models.chat.ChatCompletionsResponse.ChatCompletionChunk; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.genai.types.Content; import com.google.genai.types.CustomMetadata; +import com.google.genai.types.FinishReason; import com.google.genai.types.FinishReason.Known; import com.google.genai.types.FunctionCall; import com.google.genai.types.Part; +import java.util.Arrays; import java.util.Base64; import java.util.List; import java.util.Map; @@ -482,7 +487,6 @@ public void testToLlmResponse_thoughtSignature() throws Exception { objectMapper.readValue(json, ChatCompletion.class); LlmResponse response = completion.toLlmResponse(); - assertThat(response.content().get().parts().get().get(0).thoughtSignature().get()) .isEqualTo(Base64.getDecoder().decode("c2ln")); } @@ -646,7 +650,7 @@ public void testToolCallToPart_withThoughtSignature() throws Exception { Part part = toolCall.toPart(); assertThat(part).isNotNull(); - assertThat(part.thoughtSignature().get()).isEqualTo(Base64.getDecoder().decode("c2ln")); + assertThat(part.thoughtSignature()).hasValue(Base64.getDecoder().decode("c2ln")); } @Test @@ -687,4 +691,185 @@ public void testToLlmResponse_noChoices() throws Exception { assertThat(response.content()).isPresent(); assertThat(response.content().get().parts()).isEmpty(); } + + @Test + public void testChunkCollection_accumulatesMultipleToolCalls() throws Exception { + ChatCompletionsResponse.ChatCompletionChunkCollection collection = + new ChatCompletionsResponse.ChatCompletionChunkCollection(); + + String chunk1Json = + """ + {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_id_1","type":"function","function":{"name":"roll_die","arguments":""}}]}}]} + """; + String chunk2Json = + """ + {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\\"sides\\\":8}"}}]}}]} + """; + String chunk3Json = + """ + {"choices":[{"delta":{"tool_calls":[{"index":1,"id":"call_id_2","type":"function","function":{"name":"roll_die","arguments":""}}]}}]} + """; + String chunk4Json = + """ + {"choices":[{"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\\\"sides\\\":8}"}}]}}]} + """; + String chunk5Json = + """ + {"choices":[{"finish_reason":"tool_calls"}]} + """; + + ImmutableList unused1 = + collection.processChunk( + objectMapper.readValue(chunk1Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList unused2 = + collection.processChunk( + objectMapper.readValue(chunk2Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList unused3 = + collection.processChunk( + objectMapper.readValue(chunk3Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList unused4 = + collection.processChunk( + objectMapper.readValue(chunk4Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList responses = + collection.processChunk( + objectMapper.readValue(chunk5Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + + LlmResponse expectedFinalResponse = + LlmResponse.builder() + .content( + Content.builder() + .role("") + .parts( + Arrays.asList( + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_id_1") + .name("roll_die") + .args(ImmutableMap.of("sides", 8)) + .build()) + .build(), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("call_id_2") + .name("roll_die") + .args(ImmutableMap.of("sides", 8)) + .build()) + .build())) + .build()) + .finishReason(new FinishReason(Known.STOP.toString())) + .customMetadata(ImmutableList.of()) + .modelVersion("") + .build(); + + LlmResponse finalResponse = responses.get(1); + + assertThat(finalResponse).isEqualTo(expectedFinalResponse); + } + + @Test + public void testChunkCollection_simpleText() throws Exception { + ChatCompletionsResponse.ChatCompletionChunkCollection collection = + new ChatCompletionsResponse.ChatCompletionChunkCollection(); + + String chunk1Json = + """ + {"choices":[{"delta":{"content":"Hello "}}]} + """; + String chunk2Json = + """ + {"choices":[{"delta":{"content":"World!"}}]} + """; + String chunk3Json = + """ + {"choices":[{"finish_reason":"stop"}]} + """; + + ImmutableList unused1 = + collection.processChunk( + objectMapper.readValue(chunk1Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList unused2 = + collection.processChunk( + objectMapper.readValue(chunk2Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList responses = + collection.processChunk( + objectMapper.readValue(chunk3Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + + LlmResponse expectedFinalResponse = + LlmResponse.builder() + .content( + Content.builder() + .role("") + .parts(ImmutableList.of(Part.fromText("Hello World!"))) + .build()) + .finishReason(new FinishReason(Known.STOP.toString())) + .customMetadata(ImmutableList.of()) + .modelVersion("") + .build(); + + LlmResponse finalResponse = responses.get(1); + + assertThat(finalResponse).isEqualTo(expectedFinalResponse); + } + + @Test + public void testChunkCollection_withRefusal() throws Exception { + ChatCompletionsResponse.ChatCompletionChunkCollection collection = + new ChatCompletionsResponse.ChatCompletionChunkCollection(); + + String chunk1Json = + """ + {"choices":[{"delta":{"refusal":"I cannot do that."}}]} + """; + String chunk2Json = + """ + {"choices":[{"finish_reason":"stop"}]} + """; + + ImmutableList unused1 = + collection.processChunk( + objectMapper.readValue(chunk1Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + ImmutableList responses = + collection.processChunk( + objectMapper.readValue(chunk2Json, ChatCompletionsResponse.ChatCompletionChunk.class)); + + LlmResponse expectedFinalResponse = + LlmResponse.builder() + .content( + Content.builder() + .role("") + .parts(ImmutableList.of(Part.fromText("I cannot do that."))) + .build()) + .finishReason(new FinishReason(Known.STOP.toString())) + .customMetadata(ImmutableList.of()) + .modelVersion("") + .build(); + + LlmResponse finalResponse = responses.get(1); + + assertThat(finalResponse).isEqualTo(expectedFinalResponse); + } + + @Test + public void testChunkCollection_noChoices() throws Exception { + String json = + """ + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4" + } + """; + + ChatCompletionsResponse.ChatCompletion completion = + objectMapper.readValue(json, ChatCompletionsResponse.ChatCompletion.class); + + LlmResponse response = completion.toLlmResponse(); + + assertThat(response.modelVersion()).hasValue("gpt-4"); + assertThat(response.content()).isPresent(); + assertThat(response.content().get().parts()).isEmpty(); + } }