Skip to content

Commit aac7dc8

Browse files
google-genai-botcopybara-github
authored andcommitted
feat: Add support for refusal content using "[[REFUSAL]]:" prefix
This is part of a larger chain of commits for adding chat completion API support to the Apigee model. PiperOrigin-RevId: 895386883
1 parent 9700523 commit aac7dc8

5 files changed

Lines changed: 218 additions & 16 deletions

File tree

core/src/main/java/com/google/adk/models/chat/ChatCompletionsCommon.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import com.google.genai.types.Part;
2626
import java.util.Base64;
2727
import java.util.Map;
28+
import java.util.Objects;
2829
import org.jspecify.annotations.Nullable;
2930

3031
/** Shared models for Chat Completions Request and Response. */
@@ -45,6 +46,50 @@ private ChatCompletionsCommon() {}
4546
public static final String METADATA_KEY_SYSTEM_FINGERPRINT = "system_fingerprint";
4647
public static final String METADATA_KEY_SERVICE_TIER = "service_tier";
4748

49+
/**
50+
* Prefix used to mark refusal content in a text Part, since there is no dedicated field for
51+
* refusal content in the Gemini API.
52+
*/
53+
static final String REFUSAL_PREFIX = "[[REFUSAL]]: ";
54+
55+
/**
56+
* Result of splitting a text part into its non-refusal content and refusal content. Either
57+
* component may be {@code null} when absent.
58+
*/
59+
record RefusalSplit(@Nullable String content, @Nullable String refusal) {}
60+
61+
/**
62+
* Splits a text Part value into a content portion and a refusal portion based on the {@link
63+
* #REFUSAL_PREFIX} sentinel:
64+
*
65+
* <ul>
66+
* <li>If {@code text} starts with the prefix, the entire suffix becomes the refusal and the
67+
* content is {@code null}.
68+
* <li>If {@code text} contains {@code "\n" + REFUSAL_PREFIX} (i.e., the prefix on its own line
69+
* after some content), the text is split: everything before the newline is content,
70+
* everything after the prefix is refusal.
71+
* <li>Otherwise the text is returned as content with no refusal. The prefix is intentionally
72+
* NOT recognized mid-line without a preceding newline.
73+
* </ul>
74+
*
75+
* @param text the raw text from a {@link Part#text()}.
76+
* @return a {@link RefusalSplit} with the content and refusal portions.
77+
*/
78+
static RefusalSplit parseRefusalPrefix(String text) {
79+
Objects.requireNonNull(text, "text cannot be null");
80+
if (text.startsWith(REFUSAL_PREFIX)) {
81+
return new RefusalSplit(null, text.substring(REFUSAL_PREFIX.length()));
82+
}
83+
String separator = "\n" + REFUSAL_PREFIX;
84+
int index = text.indexOf(separator);
85+
if (index >= 0) {
86+
String before = text.substring(0, index);
87+
String after = text.substring(index + separator.length());
88+
return new RefusalSplit(before.isEmpty() ? null : before, after);
89+
}
90+
return new RefusalSplit(text, null);
91+
}
92+
4893
/**
4994
* See
5095
* https://developers.openai.com/api/reference/resources/chat#(resource)%20chat.completions%20%3E%20(model)%20chat_completion_message_tool_call%20%3E%20(schema)

core/src/main/java/com/google/adk/models/chat/ChatCompletionsRequest.java

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -350,14 +350,26 @@ private static List<Message> processContent(Content content) {
350350
List<ContentPart> contentParts = new ArrayList<>();
351351
List<ChatCompletionsCommon.ToolCall> toolCalls = new ArrayList<>();
352352
List<Message> toolResponses = new ArrayList<>();
353+
List<String> refusals = new ArrayList<>();
353354

354355
content
355356
.parts()
356357
.ifPresent(
357358
parts -> {
358359
for (Part part : parts) {
359360
if (part.text().isPresent()) {
360-
contentParts.add(processTextPart(part));
361+
// Text Parts may carry refusal content prefixed with REFUSAL_PREFIX.
362+
ChatCompletionsCommon.RefusalSplit split =
363+
ChatCompletionsCommon.parseRefusalPrefix(part.text().get());
364+
if (split.content() != null) {
365+
ContentPart textPart = new ContentPart();
366+
textPart.type = "text";
367+
textPart.text = split.content();
368+
contentParts.add(textPart);
369+
}
370+
if (split.refusal() != null) {
371+
refusals.add(split.refusal());
372+
}
361373
} else if (part.inlineData().isPresent()) {
362374
contentParts.add(processInlineDataPart(part));
363375
} else if (part.fileData().isPresent()) {
@@ -381,6 +393,9 @@ private static List<Message> processContent(Content content) {
381393
if (!toolCalls.isEmpty()) {
382394
msg.toolCalls = ImmutableList.copyOf(toolCalls);
383395
}
396+
if (!refusals.isEmpty()) {
397+
msg.refusal = String.join("\n", refusals);
398+
}
384399
if (!contentParts.isEmpty()) {
385400
if (contentParts.size() == 1 && Objects.equals(contentParts.get(0).type, "text")) {
386401
msg.content = new MessageContent(contentParts.get(0).text);
@@ -394,19 +409,6 @@ private static List<Message> processContent(Content content) {
394409
}
395410
}
396411

397-
/**
398-
* Processes a text part and returns a mapped ContentPart.
399-
*
400-
* @param part The input part containing simple text.
401-
* @return The mapped text part.
402-
*/
403-
private static ContentPart processTextPart(Part part) {
404-
ContentPart textPart = new ContentPart();
405-
textPart.type = "text";
406-
textPart.text = part.text().get();
407-
return textPart;
408-
}
409-
410412
/**
411413
* Processes an inline data part and returns a mapped ContentPart.
412414
*

core/src/main/java/com/google/adk/models/chat/ChatCompletionsResponse.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ private ImmutableList<Part> mapMessageToParts(Message message) {
180180
parts.add(Part.fromText(message.content));
181181
}
182182
if (message.refusal != null) {
183-
parts.add(Part.fromText(message.refusal));
183+
parts.add(Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + message.refusal));
184184
}
185185
if (message.toolCalls != null) {
186186
parts.addAll(mapToolCallsToParts(message.toolCalls));

core/src/test/java/com/google/adk/models/chat/ChatCompletionsRequestTest.java

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,157 @@ public void testFromLlmRequest_basic() throws Exception {
245245
assertThat(request.messages.get(0).content.getValue()).isEqualTo("Hello");
246246
}
247247

248+
@Test
249+
public void testFromLlmRequest_withRefusal() throws Exception {
250+
LlmRequest llmRequest =
251+
LlmRequest.builder()
252+
.model("gemini-1.5-pro")
253+
.contents(
254+
ImmutableList.of(
255+
Content.builder()
256+
.role("model")
257+
.parts(
258+
ImmutableList.of(
259+
Part.fromText("Regular text response"),
260+
Part.fromText(
261+
ChatCompletionsCommon.REFUSAL_PREFIX + "I cannot do that.")))
262+
.build()))
263+
.build();
264+
265+
ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);
266+
267+
assertThat(request.messages).hasSize(1);
268+
ChatCompletionsRequest.Message message = request.messages.get(0);
269+
assertThat(message.role).isEqualTo("assistant");
270+
assertThat(message.refusal).isEqualTo("I cannot do that.");
271+
assertThat(message.content.getValue()).isEqualTo("Regular text response");
272+
}
273+
274+
@Test
275+
public void testFromLlmRequest_withRefusalEmbeddedAfterNewline() throws Exception {
276+
// A single Part containing both content and refusal, separated by "\n[[REFUSAL]]: ".
277+
LlmRequest llmRequest =
278+
LlmRequest.builder()
279+
.model("gemini-1.5-pro")
280+
.contents(
281+
ImmutableList.of(
282+
Content.builder()
283+
.role("model")
284+
.parts(
285+
ImmutableList.of(
286+
Part.fromText(
287+
"Partial text answer\n"
288+
+ ChatCompletionsCommon.REFUSAL_PREFIX
289+
+ "System error or refusal")))
290+
.build()))
291+
.build();
292+
293+
ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);
294+
295+
assertThat(request.messages).hasSize(1);
296+
ChatCompletionsRequest.Message message = request.messages.get(0);
297+
assertThat(message.role).isEqualTo("assistant");
298+
assertThat(message.content.getValue()).isEqualTo("Partial text answer");
299+
assertThat(message.refusal).isEqualTo("System error or refusal");
300+
}
301+
302+
@Test
303+
public void testFromLlmRequest_withMultipleRefusalsJoinedWithNewline() throws Exception {
304+
LlmRequest llmRequest =
305+
LlmRequest.builder()
306+
.model("gemini-1.5-pro")
307+
.contents(
308+
ImmutableList.of(
309+
Content.builder()
310+
.role("model")
311+
.parts(
312+
ImmutableList.of(
313+
Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + "First"),
314+
Part.fromText(ChatCompletionsCommon.REFUSAL_PREFIX + "Second")))
315+
.build()))
316+
.build();
317+
318+
ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);
319+
320+
assertThat(request.messages).hasSize(1);
321+
ChatCompletionsRequest.Message message = request.messages.get(0);
322+
assertThat(message.role).isEqualTo("assistant");
323+
assertThat(message.refusal).isEqualTo("First\nSecond");
324+
assertThat(message.content).isNull();
325+
}
326+
327+
@Test
328+
public void testFromLlmRequest_withRefusalOnlyHasNullContent() throws Exception {
329+
LlmRequest llmRequest =
330+
LlmRequest.builder()
331+
.model("gemini-1.5-pro")
332+
.contents(
333+
ImmutableList.of(
334+
Content.builder()
335+
.role("model")
336+
.parts(
337+
ImmutableList.of(
338+
Part.fromText(
339+
ChatCompletionsCommon.REFUSAL_PREFIX + "Only a refusal")))
340+
.build()))
341+
.build();
342+
343+
ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);
344+
345+
assertThat(request.messages).hasSize(1);
346+
ChatCompletionsRequest.Message message = request.messages.get(0);
347+
assertThat(message.role).isEqualTo("assistant");
348+
assertThat(message.refusal).isEqualTo("Only a refusal");
349+
assertThat(message.content).isNull();
350+
}
351+
352+
@Test
353+
public void testFromLlmRequest_withRefusalPrefixAfterEmptyContentLine() throws Exception {
354+
// Edge case: text begins with "\n[[REFUSAL]]: ..." -- empty content before the prefix.
355+
// Expectation: no content part, refusal populated.
356+
String text = "\n" + ChatCompletionsCommon.REFUSAL_PREFIX + "Refusal only";
357+
LlmRequest llmRequest =
358+
LlmRequest.builder()
359+
.model("gemini-1.5-pro")
360+
.contents(
361+
ImmutableList.of(
362+
Content.builder()
363+
.role("model")
364+
.parts(ImmutableList.of(Part.fromText(text)))
365+
.build()))
366+
.build();
367+
368+
ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);
369+
370+
assertThat(request.messages).hasSize(1);
371+
ChatCompletionsRequest.Message message = request.messages.get(0);
372+
assertThat(message.refusal).isEqualTo("Refusal only");
373+
assertThat(message.content).isNull();
374+
}
375+
376+
@Test
377+
public void testFromLlmRequest_withRefusalPrefixMidLineIsNotSplit() throws Exception {
378+
// The prefix is intentionally NOT recognized mid-line without a preceding newline.
379+
String inlineText = "foo " + ChatCompletionsCommon.REFUSAL_PREFIX + "bar";
380+
LlmRequest llmRequest =
381+
LlmRequest.builder()
382+
.model("gemini-1.5-pro")
383+
.contents(
384+
ImmutableList.of(
385+
Content.builder()
386+
.role("model")
387+
.parts(ImmutableList.of(Part.fromText(inlineText)))
388+
.build()))
389+
.build();
390+
391+
ChatCompletionsRequest request = ChatCompletionsRequest.fromLlmRequest(llmRequest, false);
392+
393+
assertThat(request.messages).hasSize(1);
394+
ChatCompletionsRequest.Message message = request.messages.get(0);
395+
assertThat(message.refusal).isNull();
396+
assertThat(message.content.getValue()).isEqualTo(inlineText);
397+
}
398+
248399
@Test
249400
public void testFromLlmRequest_withSystemInstruction() throws Exception {
250401
LlmRequest llmRequest =

core/src/test/java/com/google/adk/models/chat/ChatCompletionsResponseTest.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,6 +504,7 @@ public void testToLlmResponse_withRefusal() throws Exception {
504504
"index": 0,
505505
"message": {
506506
"role": "assistant",
507+
"content": "Partial text answer",
507508
"refusal": "System error or refusal"
508509
},
509510
"finish_reason": "stop"
@@ -521,8 +522,11 @@ public void testToLlmResponse_withRefusal() throws Exception {
521522

522523
// Content
523524
assertThat(response.content().get().role()).hasValue("model");
525+
assertThat(response.content().get().parts().get()).hasSize(2);
524526
assertThat(response.content().get().parts().get().get(0).text())
525-
.hasValue("System error or refusal");
527+
.hasValue("Partial text answer");
528+
assertThat(response.content().get().parts().get().get(1).text())
529+
.hasValue("[[REFUSAL]]: System error or refusal");
526530

527531
// Custom Metadata
528532
List<CustomMetadata> metadata = response.customMetadata().get();

0 commit comments

Comments
 (0)