Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,21 @@ import com.fasterxml.jackson.annotation.JsonAnyGetter
import com.fasterxml.jackson.annotation.JsonAnySetter
import com.fasterxml.jackson.annotation.JsonCreator
import com.fasterxml.jackson.annotation.JsonProperty
import com.fasterxml.jackson.core.JsonGenerator
import com.fasterxml.jackson.core.ObjectCodec
import com.fasterxml.jackson.databind.JsonNode
import com.fasterxml.jackson.databind.SerializerProvider
import com.fasterxml.jackson.databind.annotation.JsonDeserialize
import com.fasterxml.jackson.databind.annotation.JsonSerialize
import com.fasterxml.jackson.module.kotlin.jacksonTypeRef
import com.google.genai.interactions.core.BaseDeserializer
import com.google.genai.interactions.core.BaseSerializer
import com.google.genai.interactions.core.ExcludeMissing
import com.google.genai.interactions.core.JsonField
import com.google.genai.interactions.core.JsonMissing
import com.google.genai.interactions.core.JsonValue
import com.google.genai.interactions.core.checkKnown
import com.google.genai.interactions.core.allMaxBy
import com.google.genai.interactions.core.getOrThrow
import com.google.genai.interactions.core.toImmutable
import com.google.genai.interactions.errors.GeminiNextGenApiInvalidDataException
import java.util.Collections
Expand All @@ -39,16 +49,14 @@ class UserInputStep
@JsonCreator(mode = JsonCreator.Mode.DISABLED)
private constructor(
private val type: JsonValue,
private val content: JsonField<List<Content>>,
private val content: JsonField<Content>,
private val additionalProperties: MutableMap<String, JsonValue>,
) {

@JsonCreator
private constructor(
@JsonProperty("type") @ExcludeMissing type: JsonValue = JsonMissing.of(),
@JsonProperty("content")
@ExcludeMissing
content: JsonField<List<Content>> = JsonMissing.of(),
@JsonProperty("content") @ExcludeMissing content: JsonField<Content> = JsonMissing.of(),
) : this(type, content, mutableMapOf())

/**
Expand All @@ -66,14 +74,14 @@ private constructor(
* @throws GeminiNextGenApiInvalidDataException if the JSON field has an unexpected type (e.g.
* if the server responded with an unexpected value).
*/
fun content(): Optional<List<Content>> = content.getOptional("content")
fun content(): Optional<Content> = content.getOptional("content")

/**
* Returns the raw JSON value of [content].
*
* Unlike [content], this method doesn't throw if the JSON field has an unexpected type.
*/
@JsonProperty("content") @ExcludeMissing fun _content(): JsonField<List<Content>> = content
@JsonProperty("content") @ExcludeMissing fun _content(): JsonField<Content> = content

@JsonAnySetter
private fun putAdditionalProperty(key: String, value: JsonValue) {
Expand All @@ -97,13 +105,13 @@ private constructor(
class Builder internal constructor() {

private var type: JsonValue = JsonValue.from("user_input")
private var content: JsonField<MutableList<Content>>? = null
private var content: JsonField<Content> = JsonMissing.of()
private var additionalProperties: MutableMap<String, JsonValue> = mutableMapOf()

@JvmSynthetic
internal fun from(userInputStep: UserInputStep) = apply {
type = userInputStep.type
content = userInputStep.content.map { it.toMutableList() }
content = userInputStep.content
additionalProperties = userInputStep.additionalProperties.toMutableMap()
}

Expand All @@ -121,55 +129,21 @@ private constructor(
*/
fun type(type: JsonValue) = apply { this.type = type }

fun content(content: List<Content>) = content(JsonField.of(content))
fun content(content: Content) = content(JsonField.of(content))

/**
* Sets [Builder.content] to an arbitrary JSON value.
*
* You should usually call [Builder.content] with a well-typed `List<Content>` value
* instead. This method is primarily for setting the field to an undocumented or not yet
* supported value.
* You should usually call [Builder.content] with a well-typed [Content] value instead. This
* method is primarily for setting the field to an undocumented or not yet supported value.
*/
fun content(content: JsonField<List<Content>>) = apply {
this.content = content.map { it.toMutableList() }
}

/**
* Adds a single [Content] to [Builder.content].
*
* @throws IllegalStateException if the field was previously set to a non-list.
*/
fun addContent(content: Content) = apply {
this.content =
(this.content ?: JsonField.of(mutableListOf())).also {
checkKnown("content", it).add(content)
}
}

/** Alias for calling [addContent] with `Content.ofText(text)`. */
fun addContent(text: TextContent) = addContent(Content.ofText(text))

/**
* Alias for calling [addContent] with the following:
* ```java
* TextContent.builder()
* .text(text)
* .build()
* ```
*/
fun addTextContent(text: String) = addContent(TextContent.builder().text(text).build())

/** Alias for calling [addContent] with `Content.ofImage(image)`. */
fun addContent(image: ImageContent) = addContent(Content.ofImage(image))
fun content(content: JsonField<Content>) = apply { this.content = content }

/** Alias for calling [addContent] with `Content.ofAudio(audio)`. */
fun addContent(audio: AudioContent) = addContent(Content.ofAudio(audio))
/** Alias for calling [content] with `Content.ofList(list)`. */
fun contentOfList(list: List<Content>) = content(Content.ofList(list))

/** Alias for calling [addContent] with `Content.ofDocument(document)`. */
fun addContent(document: DocumentContent) = addContent(Content.ofDocument(document))

/** Alias for calling [addContent] with `Content.ofVideo(video)`. */
fun addContent(video: VideoContent) = addContent(Content.ofVideo(video))
/** Alias for calling [content] with `Content.ofString(string)`. */
fun content(string: String) = content(Content.ofString(string))

fun additionalProperties(additionalProperties: Map<String, JsonValue>) = apply {
this.additionalProperties.clear()
Expand All @@ -196,11 +170,7 @@ private constructor(
* Further updates to this [Builder] will not mutate the returned instance.
*/
fun build(): UserInputStep =
UserInputStep(
type,
(content ?: JsonMissing.of()).map { it.toImmutable() },
additionalProperties.toMutableMap(),
)
UserInputStep(type, content, additionalProperties.toMutableMap())
}

private var validated: Boolean = false
Expand All @@ -215,7 +185,7 @@ private constructor(
throw GeminiNextGenApiInvalidDataException("'type' is invalid, received $it")
}
}
content().ifPresent { it.forEach { it.validate() } }
content().ifPresent { it.validate() }
validated = true
}

Expand All @@ -235,7 +205,179 @@ private constructor(
@JvmSynthetic
internal fun validity(): Int =
type.let { if (it == JsonValue.from("user_input")) 1 else 0 } +
(content.asKnown().getOrNull()?.sumOf { it.validity().toInt() } ?: 0)
(content.asKnown().getOrNull()?.validity() ?: 0)

@JsonDeserialize(using = Content.Deserializer::class)
@JsonSerialize(using = Content.Serializer::class)
class Content
private constructor(
private val list: List<Content>? = null,
private val string: String? = null,
private val _json: JsonValue? = null,
) {

fun list(): Optional<List<Content>> = Optional.ofNullable(list)

fun string(): Optional<String> = Optional.ofNullable(string)

fun isList(): Boolean = list != null

fun isString(): Boolean = string != null

fun asList(): List<Content> = list.getOrThrow("list")

fun asString(): String = string.getOrThrow("string")

fun _json(): Optional<JsonValue> = Optional.ofNullable(_json)

fun <T> accept(visitor: Visitor<T>): T =
when {
list != null -> visitor.visitList(list)
string != null -> visitor.visitString(string)
else -> visitor.unknown(_json)
}

private var validated: Boolean = false

fun validate(): Content = apply {
if (validated) {
return@apply
}

accept(
object : Visitor<Unit> {
override fun visitList(list: List<Content>) {
list.forEach { it.validate() }
}

override fun visitString(string: String) {}
}
)
validated = true
}

fun isValid(): Boolean =
try {
validate()
true
} catch (e: GeminiNextGenApiInvalidDataException) {
false
}

/**
* Returns a score indicating how many valid values are contained in this object
* recursively.
*
* Used for best match union deserialization.
*/
@JvmSynthetic
internal fun validity(): Int =
accept(
object : Visitor<Int> {
override fun visitList(list: List<Content>) =
list.sumOf { it.validity().toInt() }

override fun visitString(string: String) = 1

override fun unknown(json: JsonValue?) = 0
}
)

override fun equals(other: Any?): Boolean {
if (this === other) {
return true
}

return other is Content && list == other.list && string == other.string
}

override fun hashCode(): Int = Objects.hash(list, string)

override fun toString(): String =
when {
list != null -> "Content{list=$list}"
string != null -> "Content{string=$string}"
_json != null -> "Content{_unknown=$_json}"
else -> throw IllegalStateException("Invalid Content")
}

companion object {

@JvmStatic fun ofList(list: List<Content>) = Content(list = list.toImmutable())

@JvmStatic fun ofString(string: String) = Content(string = string)
}

/**
* An interface that defines how to map each variant of [Content] to a value of type [T].
*/
interface Visitor<out T> {

fun visitList(list: List<Content>): T

fun visitString(string: String): T

/**
* Maps an unknown variant of [Content] to a value of type [T].
*
* An instance of [Content] can contain an unknown variant if it was deserialized from
* data that doesn't match any known variant. For example, if the SDK is on an older
* version than the API, then the API may respond with new variants that the SDK is
* unaware of.
*
* @throws GeminiNextGenApiInvalidDataException in the default implementation.
*/
fun unknown(json: JsonValue?): T {
throw GeminiNextGenApiInvalidDataException("Unknown Content: $json")
}
}

internal class Deserializer : BaseDeserializer<Content>(Content::class) {

override fun ObjectCodec.deserialize(node: JsonNode): Content {
val json = JsonValue.fromJsonNode(node)

val bestMatches =
sequenceOf(
tryDeserialize(node, jacksonTypeRef<String>())?.let {
Content(string = it, _json = json)
},
tryDeserialize(node, jacksonTypeRef<List<Content>>())?.let {
Content(list = it, _json = json)
},
)
.filterNotNull()
.allMaxBy { it.validity() }
.toList()
return when (bestMatches.size) {
// This can happen if what we're deserializing is completely incompatible with
// all the possible variants (e.g. deserializing from boolean).
0 -> Content(_json = json)
1 -> bestMatches.single()
// If there's more than one match with the highest validity, then use the first
// completely valid match, or simply the first match if none are completely
// valid.
else -> bestMatches.firstOrNull { it.isValid() } ?: bestMatches.first()
}
}
}

internal class Serializer : BaseSerializer<Content>(Content::class) {

override fun serialize(
value: Content,
generator: JsonGenerator,
provider: SerializerProvider,
) {
when {
value.list != null -> generator.writeObject(value.list)
value.string != null -> generator.writeObject(value.string)
value._json != null -> generator.writeObject(value._json)
else -> throw IllegalStateException("Invalid Content")
}
}
}
}

override fun equals(other: Any?): Boolean {
if (this === other) {
Expand Down
Loading