diff --git a/src/main/java/com/google/genai/interactions/models/interactions/UserInputStep.kt b/src/main/java/com/google/genai/interactions/models/interactions/UserInputStep.kt index d66052dfa5b..88f30f9f44d 100644 --- a/src/main/java/com/google/genai/interactions/models/interactions/UserInputStep.kt +++ b/src/main/java/com/google/genai/interactions/models/interactions/UserInputStep.kt @@ -22,11 +22,21 @@ import com.fasterxml.jackson.annotation.JsonAnyGetter import com.fasterxml.jackson.annotation.JsonAnySetter import com.fasterxml.jackson.annotation.JsonCreator import com.fasterxml.jackson.annotation.JsonProperty +import com.fasterxml.jackson.core.JsonGenerator +import com.fasterxml.jackson.core.ObjectCodec +import com.fasterxml.jackson.databind.JsonNode +import com.fasterxml.jackson.databind.SerializerProvider +import com.fasterxml.jackson.databind.annotation.JsonDeserialize +import com.fasterxml.jackson.databind.annotation.JsonSerialize +import com.fasterxml.jackson.module.kotlin.jacksonTypeRef +import com.google.genai.interactions.core.BaseDeserializer +import com.google.genai.interactions.core.BaseSerializer import com.google.genai.interactions.core.ExcludeMissing import com.google.genai.interactions.core.JsonField import com.google.genai.interactions.core.JsonMissing import com.google.genai.interactions.core.JsonValue -import com.google.genai.interactions.core.checkKnown +import com.google.genai.interactions.core.allMaxBy +import com.google.genai.interactions.core.getOrThrow import com.google.genai.interactions.core.toImmutable import com.google.genai.interactions.errors.GeminiNextGenApiInvalidDataException import java.util.Collections @@ -39,16 +49,14 @@ class UserInputStep @JsonCreator(mode = JsonCreator.Mode.DISABLED) private constructor( private val type: JsonValue, - private val content: JsonField>, + private val content: JsonField, private val additionalProperties: MutableMap, ) { @JsonCreator private constructor( @JsonProperty("type") @ExcludeMissing type: JsonValue = JsonMissing.of(), - @JsonProperty("content") - @ExcludeMissing - content: JsonField> = JsonMissing.of(), + @JsonProperty("content") @ExcludeMissing content: JsonField = JsonMissing.of(), ) : this(type, content, mutableMapOf()) /** @@ -66,14 +74,14 @@ private constructor( * @throws GeminiNextGenApiInvalidDataException if the JSON field has an unexpected type (e.g. * if the server responded with an unexpected value). */ - fun content(): Optional> = content.getOptional("content") + fun content(): Optional = content.getOptional("content") /** * Returns the raw JSON value of [content]. * * Unlike [content], this method doesn't throw if the JSON field has an unexpected type. */ - @JsonProperty("content") @ExcludeMissing fun _content(): JsonField> = content + @JsonProperty("content") @ExcludeMissing fun _content(): JsonField = content @JsonAnySetter private fun putAdditionalProperty(key: String, value: JsonValue) { @@ -97,13 +105,13 @@ private constructor( class Builder internal constructor() { private var type: JsonValue = JsonValue.from("user_input") - private var content: JsonField>? = null + private var content: JsonField = JsonMissing.of() private var additionalProperties: MutableMap = mutableMapOf() @JvmSynthetic internal fun from(userInputStep: UserInputStep) = apply { type = userInputStep.type - content = userInputStep.content.map { it.toMutableList() } + content = userInputStep.content additionalProperties = userInputStep.additionalProperties.toMutableMap() } @@ -121,55 +129,21 @@ private constructor( */ fun type(type: JsonValue) = apply { this.type = type } - fun content(content: List) = content(JsonField.of(content)) + fun content(content: Content) = content(JsonField.of(content)) /** * Sets [Builder.content] to an arbitrary JSON value. * - * You should usually call [Builder.content] with a well-typed `List` value - * instead. This method is primarily for setting the field to an undocumented or not yet - * supported value. + * You should usually call [Builder.content] with a well-typed [Content] value instead. This + * method is primarily for setting the field to an undocumented or not yet supported value. */ - fun content(content: JsonField>) = apply { - this.content = content.map { it.toMutableList() } - } - - /** - * Adds a single [Content] to [Builder.content]. - * - * @throws IllegalStateException if the field was previously set to a non-list. - */ - fun addContent(content: Content) = apply { - this.content = - (this.content ?: JsonField.of(mutableListOf())).also { - checkKnown("content", it).add(content) - } - } - - /** Alias for calling [addContent] with `Content.ofText(text)`. */ - fun addContent(text: TextContent) = addContent(Content.ofText(text)) - - /** - * Alias for calling [addContent] with the following: - * ```java - * TextContent.builder() - * .text(text) - * .build() - * ``` - */ - fun addTextContent(text: String) = addContent(TextContent.builder().text(text).build()) - - /** Alias for calling [addContent] with `Content.ofImage(image)`. */ - fun addContent(image: ImageContent) = addContent(Content.ofImage(image)) + fun content(content: JsonField) = apply { this.content = content } - /** Alias for calling [addContent] with `Content.ofAudio(audio)`. */ - fun addContent(audio: AudioContent) = addContent(Content.ofAudio(audio)) + /** Alias for calling [content] with `Content.ofList(list)`. */ + fun contentOfList(list: List) = content(Content.ofList(list)) - /** Alias for calling [addContent] with `Content.ofDocument(document)`. */ - fun addContent(document: DocumentContent) = addContent(Content.ofDocument(document)) - - /** Alias for calling [addContent] with `Content.ofVideo(video)`. */ - fun addContent(video: VideoContent) = addContent(Content.ofVideo(video)) + /** Alias for calling [content] with `Content.ofString(string)`. */ + fun content(string: String) = content(Content.ofString(string)) fun additionalProperties(additionalProperties: Map) = apply { this.additionalProperties.clear() @@ -196,11 +170,7 @@ private constructor( * Further updates to this [Builder] will not mutate the returned instance. */ fun build(): UserInputStep = - UserInputStep( - type, - (content ?: JsonMissing.of()).map { it.toImmutable() }, - additionalProperties.toMutableMap(), - ) + UserInputStep(type, content, additionalProperties.toMutableMap()) } private var validated: Boolean = false @@ -215,7 +185,7 @@ private constructor( throw GeminiNextGenApiInvalidDataException("'type' is invalid, received $it") } } - content().ifPresent { it.forEach { it.validate() } } + content().ifPresent { it.validate() } validated = true } @@ -235,7 +205,179 @@ private constructor( @JvmSynthetic internal fun validity(): Int = type.let { if (it == JsonValue.from("user_input")) 1 else 0 } + - (content.asKnown().getOrNull()?.sumOf { it.validity().toInt() } ?: 0) + (content.asKnown().getOrNull()?.validity() ?: 0) + + @JsonDeserialize(using = Content.Deserializer::class) + @JsonSerialize(using = Content.Serializer::class) + class Content + private constructor( + private val list: List? = null, + private val string: String? = null, + private val _json: JsonValue? = null, + ) { + + fun list(): Optional> = Optional.ofNullable(list) + + fun string(): Optional = Optional.ofNullable(string) + + fun isList(): Boolean = list != null + + fun isString(): Boolean = string != null + + fun asList(): List = list.getOrThrow("list") + + fun asString(): String = string.getOrThrow("string") + + fun _json(): Optional = Optional.ofNullable(_json) + + fun accept(visitor: Visitor): T = + when { + list != null -> visitor.visitList(list) + string != null -> visitor.visitString(string) + else -> visitor.unknown(_json) + } + + private var validated: Boolean = false + + fun validate(): Content = apply { + if (validated) { + return@apply + } + + accept( + object : Visitor { + override fun visitList(list: List) { + list.forEach { it.validate() } + } + + override fun visitString(string: String) {} + } + ) + validated = true + } + + fun isValid(): Boolean = + try { + validate() + true + } catch (e: GeminiNextGenApiInvalidDataException) { + false + } + + /** + * Returns a score indicating how many valid values are contained in this object + * recursively. + * + * Used for best match union deserialization. + */ + @JvmSynthetic + internal fun validity(): Int = + accept( + object : Visitor { + override fun visitList(list: List) = + list.sumOf { it.validity().toInt() } + + override fun visitString(string: String) = 1 + + override fun unknown(json: JsonValue?) = 0 + } + ) + + override fun equals(other: Any?): Boolean { + if (this === other) { + return true + } + + return other is Content && list == other.list && string == other.string + } + + override fun hashCode(): Int = Objects.hash(list, string) + + override fun toString(): String = + when { + list != null -> "Content{list=$list}" + string != null -> "Content{string=$string}" + _json != null -> "Content{_unknown=$_json}" + else -> throw IllegalStateException("Invalid Content") + } + + companion object { + + @JvmStatic fun ofList(list: List) = Content(list = list.toImmutable()) + + @JvmStatic fun ofString(string: String) = Content(string = string) + } + + /** + * An interface that defines how to map each variant of [Content] to a value of type [T]. + */ + interface Visitor { + + fun visitList(list: List): T + + fun visitString(string: String): T + + /** + * Maps an unknown variant of [Content] to a value of type [T]. + * + * An instance of [Content] can contain an unknown variant if it was deserialized from + * data that doesn't match any known variant. For example, if the SDK is on an older + * version than the API, then the API may respond with new variants that the SDK is + * unaware of. + * + * @throws GeminiNextGenApiInvalidDataException in the default implementation. + */ + fun unknown(json: JsonValue?): T { + throw GeminiNextGenApiInvalidDataException("Unknown Content: $json") + } + } + + internal class Deserializer : BaseDeserializer(Content::class) { + + override fun ObjectCodec.deserialize(node: JsonNode): Content { + val json = JsonValue.fromJsonNode(node) + + val bestMatches = + sequenceOf( + tryDeserialize(node, jacksonTypeRef())?.let { + Content(string = it, _json = json) + }, + tryDeserialize(node, jacksonTypeRef>())?.let { + Content(list = it, _json = json) + }, + ) + .filterNotNull() + .allMaxBy { it.validity() } + .toList() + return when (bestMatches.size) { + // This can happen if what we're deserializing is completely incompatible with + // all the possible variants (e.g. deserializing from boolean). + 0 -> Content(_json = json) + 1 -> bestMatches.single() + // If there's more than one match with the highest validity, then use the first + // completely valid match, or simply the first match if none are completely + // valid. + else -> bestMatches.firstOrNull { it.isValid() } ?: bestMatches.first() + } + } + } + + internal class Serializer : BaseSerializer(Content::class) { + + override fun serialize( + value: Content, + generator: JsonGenerator, + provider: SerializerProvider, + ) { + when { + value.list != null -> generator.writeObject(value.list) + value.string != null -> generator.writeObject(value.string) + value._json != null -> generator.writeObject(value._json) + else -> throw IllegalStateException("Invalid Content") + } + } + } + } override fun equals(other: Any?): Boolean { if (this === other) {