diff --git a/AGENTS.md b/AGENTS.md index ba40082d595..184263bac42 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -71,6 +71,14 @@ merge. | Python | 3.12 | | Node | 24 | +JDK 17 is required, not just recommended. The `sbt-jacoco 3.5.0` plugin +([`project/plugins.sbt`](project/plugins.sbt)) ships JaCoCo 0.8.11, which +cannot instrument class files compiled to newer bytecode versions — under +JDK 21+ `sbt test` fails during the instrumentation pass with +`java.io.IOException: Error while instrumenting .class with JaCoCo`, +before any test runs. Point `JAVA_HOME` at a Temurin 17 install for sbt +invocations until the plugin is upgraded. + One Python venv shared across worktrees, sibling of the texera checkout: ``` @@ -90,6 +98,21 @@ in [`udf.conf`](common/config/src/main/resources/udf.conf) or `export UDF_PYTHON_PATH="$(pwd)/../venv312/bin/python"` (env var overrides). Without it, `sbt` Python-integration tests fail to launch a worker. +Backend services that touch datasets (`FileService`, +`WorkflowComputingUnitManagingService`, anything calling +`S3StorageClient`/`LakeFSStorageClient`) need MinIO + LakeFS running locally. +Bring them up with: + +```bash +docker compose -f file-service/src/main/resources/docker-compose.yml up -d minio lakefs +``` + +MinIO listens on `:9000`, LakeFS on `:8000`; both are wired to the +defaults in [`storage.conf`](common/config/src/main/resources/storage.conf). +Without these, startup fails at +[`S3StorageClient.createBucketIfNotExist`](file-service/src/main/scala/org/apache/texera/service/util/S3StorageClient.scala) +with `Connection refused` against `localhost:9000`. + [`.jvmopts`](.jvmopts) holds every `--add-opens` flag Texera needs for JDK 17+, with each group annotated by its upstream source (Kryo, Apache Arrow, Apache Pekko). 
sbt's launcher and the [`.run/`](.run) diff --git a/access-control-service/src/main/scala/org/apache/texera/service/AccessControlService.scala b/access-control-service/src/main/scala/org/apache/texera/service/AccessControlService.scala index 21d367e2bb1..c4b15deec7a 100644 --- a/access-control-service/src/main/scala/org/apache/texera/service/AccessControlService.scala +++ b/access-control-service/src/main/scala/org/apache/texera/service/AccessControlService.scala @@ -31,7 +31,8 @@ import org.apache.texera.service.resource.{ AccessControlResource, HealthCheckResource, LiteLLMModelsResource, - LiteLLMProxyResource + LiteLLMProxyResource, + OpenRouterModelsResource } import org.eclipse.jetty.server.session.SessionHandler import java.nio.file.Path @@ -69,6 +70,7 @@ class AccessControlService extends Application[AccessControlServiceConfiguration environment.jersey.register(classOf[AccessControlResource]) environment.jersey.register(classOf[LiteLLMProxyResource]) environment.jersey.register(classOf[LiteLLMModelsResource]) + environment.jersey.register(new OpenRouterModelsResource) // Register JWT authentication filter environment.jersey.register(new AuthDynamicFeature(classOf[JwtAuthFilter])) diff --git a/access-control-service/src/main/scala/org/apache/texera/service/resource/OpenRouterModelsResource.scala b/access-control-service/src/main/scala/org/apache/texera/service/resource/OpenRouterModelsResource.scala new file mode 100644 index 00000000000..291a6c976ba --- /dev/null +++ b/access-control-service/src/main/scala/org/apache/texera/service/resource/OpenRouterModelsResource.scala @@ -0,0 +1,188 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
// The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.texera.service.resource

import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import com.typesafe.scalalogging.LazyLogging
import jakarta.ws.rs.core.{MediaType, Response}
import jakarta.ws.rs.{GET, Path, Produces}
import org.apache.texera.config.{GuiConfig, LLMConfig}

import java.net.URI
import java.net.http.{HttpClient, HttpRequest, HttpResponse}
import java.nio.charset.StandardCharsets
import java.time.{Clock, Duration}
import scala.jdk.CollectionConverters.IteratorHasAsScala

/** One model entry extracted from the upstream listing by
  * [[OpenRouterModelsResource.summarizeModels]].
  *
  * @param id            value of the upstream `id` field (trimmed, never empty)
  * @param name          value of the upstream `name` field (trimmed, never empty)
  * @param contextLength upstream `context_length` when it is present and numeric
  * @param pricing       non-null entries of the upstream `pricing` object, rendered as text
  */
case class OpenRouterModelSummary(
    id: String,
    name: String,
    contextLength: Option[Long],
    pricing: Map[String, String]
)

/** Payload returned to clients, carrying the summaries plus cache bookkeeping.
  *
  * @param data                 summarized models
  * @param cachedAtEpochMillis  clock reading when `data` was fetched
  * @param expiresAtEpochMillis clock reading after which a refresh is attempted
  * @param stale                true when `data` is being re-served after a failed refresh
  * @param error                description of the refresh failure, if any
  */
case class OpenRouterModelsResponse(
    data: Seq[OpenRouterModelSummary],
    cachedAtEpochMillis: Long,
    expiresAtEpochMillis: Long,
    stale: Boolean,
    error: Option[String] = None
)

/** Error payload used when no model list can be served at all. */
case class OpenRouterModelsError(error: String)

/** HTTP and JSON helpers used by the `/models/openrouter` resource. */
object OpenRouterModelsResource {
  private val mapper: ObjectMapper = new ObjectMapper().registerModule(DefaultScalaModule)
  private val openRouterModelsUri =
    URI.create("https://openrouter.ai/api/v1/models?output_modalities=text")
  private val client = HttpClient.newBuilder().connectTimeout(Duration.ofSeconds(3)).build()

  /** Builds the GET request for the OpenRouter model listing.
    *
    * @param openRouterApiKey optional key; it is trimmed, and a bearer `Authorization`
    *                         header is added only when something non-empty remains
    * @return the request, with a 5 second timeout and JSON `Accept` header
    */
  def buildOpenRouterModelsRequest(openRouterApiKey: Option[String]): HttpRequest = {
    val base = HttpRequest
      .newBuilder(openRouterModelsUri)
      .timeout(Duration.ofSeconds(5))
      .header("Accept", "application/json")
      .header("User-Agent", "Apache-Texera")

    val usableKey = openRouterApiKey.map(_.trim).filter(_.nonEmpty)
    val withAuth = usableKey.fold(base)(key => base.header("Authorization", s"Bearer $key"))

    withAuth.GET().build()
  }

  /** Performs the blocking upstream call and returns the raw response body.
    *
    * @param openRouterApiKey optional key, defaulting to `LLMConfig.openRouterApiKey`
    * @return the response body decoded as UTF-8
    * @throws RuntimeException when the upstream status is not 2xx
    */
  def fetchOpenRouterModelsJson(
      openRouterApiKey: Option[String] = LLMConfig.openRouterApiKey
  ): String = {
    val response = client.send(
      buildOpenRouterModelsRequest(openRouterApiKey),
      HttpResponse.BodyHandlers.ofString(StandardCharsets.UTF_8)
    )
    val status = response.statusCode()
    if (status < 200 || status > 299) {
      throw new RuntimeException(s"OpenRouter returned HTTP $status")
    }
    response.body()
  }

  /** Parses the upstream JSON and keeps every entry that has a usable `id` and `name`.
    *
    * @param rawJson upstream response body
    * @return summaries in upstream order; entries lacking an id or a name are dropped
    * @throws IllegalArgumentException when the document has no `data` array
    */
  def summarizeModels(rawJson: String): Seq[OpenRouterModelSummary] = {
    val data = mapper.readTree(rawJson).path("data")
    if (!data.isArray) {
      throw new IllegalArgumentException("OpenRouter response does not contain a data array")
    }

    data
      .elements()
      .asScala
      .flatMap { model =>
        nonEmptyText(model, "id").flatMap { id =>
          nonEmptyText(model, "name").map { name =>
            OpenRouterModelSummary(
              id = id,
              name = name,
              contextLength = longValue(model, "context_length"),
              pricing = pricing(model.path("pricing"))
            )
          }
        }
      }
      .toSeq
  }

  /** Trimmed text of `fieldName`, or None when absent, JSON null, or blank. */
  private def nonEmptyText(node: JsonNode, fieldName: String): Option[String] =
    Option(node.get(fieldName))
      .collect { case n if !n.isNull && !n.isMissingNode => n.asText().trim }
      .filter(_.nonEmpty)

  /** `fieldName` as a Long, or None when absent or not a JSON number. */
  private def longValue(node: JsonNode, fieldName: String): Option[Long] =
    Option(node.get(fieldName)).collect { case n if n.isNumber => n.asLong() }

  /** Text form of every non-null entry of a JSON object; empty for anything else. */
  private def pricing(node: JsonNode): Map[String, String] =
    if (node != null && node.isObject) {
      node
        .fields()
        .asScala
        .collect {
          case entry if !(entry.getValue.isNull || entry.getValue.isMissingNode) =>
            entry.getKey -> entry.getValue.asText()
        }
        .toMap
    } else {
      Map.empty
    }
}
/** JAX-RS resource serving a cached summary of the OpenRouter model listing.
  *
  * All cache reads and writes happen inside this instance's monitor, and the
  * upstream fetch runs while that monitor is held, so concurrent requests that
  * miss the cache wait for a single fetch instead of each issuing their own.
  *
  * @param fetchModelsJson      supplies the raw upstream JSON; may throw
  * @param clock                time source for cache timestamps
  * @param cacheTtl             how long a successful fetch is served without refetching
  * @param staleFailureRetryTtl how long a stale answer is re-served after a failed refresh
  * @param isCopilotEnabled     feature gate evaluated on every request
  */
@Path("/models/openrouter")
@Produces(Array(MediaType.APPLICATION_JSON))
class OpenRouterModelsResource(
    fetchModelsJson: () => String = () => OpenRouterModelsResource.fetchOpenRouterModelsJson(),
    clock: Clock = Clock.systemUTC(),
    cacheTtl: Duration = Duration.ofHours(1),
    staleFailureRetryTtl: Duration = Duration.ofMinutes(5),
    isCopilotEnabled: () => Boolean = () => GuiConfig.guiWorkflowWorkspaceCopilotEnabled
) extends LazyLogging {

  // Only read or written while holding this instance's monitor.
  private var cachedResponse: Option[OpenRouterModelsResponse] = None

  /** Returns the model list.
    *
    * @return 403 when the copilot feature is disabled; 200 with the cached or freshly
    *         fetched list; 200 with `stale = true` when a refresh failed but an older
    *         list exists; 503 when nothing has ever been fetched successfully
    */
  @GET
  def getOpenRouterModels: Response =
    synchronized {
      if (!isCopilotEnabled()) {
        Response
          .status(Response.Status.FORBIDDEN)
          .entity("""{"error": "Copilot feature is disabled"}""")
          .build()
      } else {
        val now = clock.millis()
        cachedResponse.filter(_.expiresAtEpochMillis > now) match {
          case Some(cached) => Response.ok(cached).build()
          case None         => refresh(now)
        }
      }
    }

  /** Fetches and caches a fresh list, degrading to the stale copy (or 503) on failure. */
  private def refresh(now: Long): Response =
    try {
      val models = OpenRouterModelsResource.summarizeModels(fetchModelsJson())
      val response = OpenRouterModelsResponse(
        data = models,
        cachedAtEpochMillis = now,
        expiresAtEpochMillis = now + cacheTtl.toMillis,
        stale = false
      )
      cachedResponse = Some(response)
      Response.ok(response).build()
    } catch {
      case e: InterruptedException =>
        // The blocking HTTP call clears the interrupt flag when it throws; put it back
        // so the owning thread can still observe the interruption.
        Thread.currentThread().interrupt()
        serveStaleOrUnavailable(now, e)
      case e: Exception =>
        serveStaleOrUnavailable(now, e)
    }

  /** Re-serves the previous list flagged as stale, or answers 503 when there is none. */
  private def serveStaleOrUnavailable(now: Long, e: Exception): Response = {
    val reason = describe(e)
    logger.warn(s"Failed to fetch OpenRouter models: $reason", e)
    cachedResponse match {
      case Some(cached) =>
        // Push the expiry forward so the failing upstream is not retried on every request.
        val staleResponse = cached.copy(
          expiresAtEpochMillis = now + staleFailureRetryTtl.toMillis,
          stale = true,
          error = Some(reason)
        )
        cachedResponse = Some(staleResponse)
        Response.ok(staleResponse).build()
      case None =>
        Response
          .status(Response.Status.SERVICE_UNAVAILABLE)
          .entity(OpenRouterModelsError(s"Failed to fetch OpenRouter models: $reason"))
          .build()
    }
  }

  /** The exception message, or the exception class name when the message is null or empty. */
  private def describe(e: Exception): String =
    Option(e.getMessage).filter(_.nonEmpty).getOrElse(e.getClass.getName)
}
b/access-control-service/src/test/scala/org/apache/texera/OpenRouterModelsResourceSpec.scala new file mode 100644 index 00000000000..85731b6568a --- /dev/null +++ b/access-control-service/src/test/scala/org/apache/texera/OpenRouterModelsResourceSpec.scala @@ -0,0 +1,206 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
package org.apache.texera

import jakarta.ws.rs.core.Response
import org.apache.texera.service.resource.{
  OpenRouterModelSummary,
  OpenRouterModelsResource,
  OpenRouterModelsResponse
}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

import java.time.{Clock, Duration, Instant, ZoneOffset}

/** Exercises OpenRouterModelsResource through an injected fetch function and a fixed clock,
  * so no network access is involved.
  */
class OpenRouterModelsResourceSpec extends AnyFlatSpec with Matchers {

  private val fixtureJson =
    """
      |{
      |  "data": [
      |    {
      |      "id": "openai/gpt-4",
      |      "name": "GPT-4",
      |      "context_length": 8192,
      |      "pricing": {
      |        "prompt": "0.00003",
      |        "completion": "0.00006"
      |      }
      |    },
      |    {
      |      "id": "meta-llama/llama-3.1-8b-instruct:free",
      |      "name": "Meta: Llama 3.1 8B Instruct (free)",
      |      "pricing": {
      |        "prompt": "0",
      |        "completion": "0"
      |      }
      |    }
      |  ]
      |}
      |""".stripMargin

  private val fixedClock: Clock =
    Clock.fixed(Instant.ofEpochMilli(1000), ZoneOffset.UTC)

  private val okStatus = Response.Status.OK.getStatusCode
  private val unavailableStatus = Response.Status.SERVICE_UNAVAILABLE.getStatusCode

  /** Fetch stub that records how often it is called. The first `successfulCalls`
    * invocations return the fixture; every later one fails.
    */
  private final class CountingFetcher(successfulCalls: Int) extends (() => String) {
    var calls = 0

    override def apply(): String = {
      calls += 1
      if (calls <= successfulCalls) fixtureJson
      else throw new RuntimeException("upstream unavailable")
    }
  }

  /** Builds a resource with the copilot gate forced open. */
  private def openRouterResource(
      fetchModelsJson: () => String,
      clock: Clock = fixedClock,
      cacheTtl: Duration = Duration.ofHours(1),
      staleFailureRetryTtl: Duration = Duration.ofMinutes(5)
  ): OpenRouterModelsResource =
    new OpenRouterModelsResource(
      fetchModelsJson,
      clock,
      cacheTtl,
      staleFailureRetryTtl,
      isCopilotEnabled = () => true
    )

  "OpenRouterModelsResource" should "summarize model ids, names, context, and pricing" in {
    val response = openRouterResource(() => fixtureJson).getOpenRouterModels
    response.getStatus shouldBe okStatus

    val entity = response.getEntity.asInstanceOf[OpenRouterModelsResponse]
    entity.stale shouldBe false
    entity.cachedAtEpochMillis shouldBe 1000

    val gpt4 = OpenRouterModelSummary(
      "openai/gpt-4",
      "GPT-4",
      Some(8192),
      Map("prompt" -> "0.00003", "completion" -> "0.00006")
    )
    val llama = OpenRouterModelSummary(
      "meta-llama/llama-3.1-8b-instruct:free",
      "Meta: Llama 3.1 8B Instruct (free)",
      None,
      Map("prompt" -> "0", "completion" -> "0")
    )
    entity.data should contain theSameElementsInOrderAs Seq(gpt4, llama)
  }

  it should "serve cached models without fetching again before the TTL expires" in {
    val fetcher = new CountingFetcher(Int.MaxValue)
    val resource = openRouterResource(fetcher, fixedClock, Duration.ofHours(1))

    resource.getOpenRouterModels.getStatus shouldBe okStatus
    resource.getOpenRouterModels.getStatus shouldBe okStatus

    fetcher.calls shouldBe 1
  }

  it should "return FORBIDDEN without fetching models when copilot is disabled" in {
    val fetcher = new CountingFetcher(Int.MaxValue)
    val resource = new OpenRouterModelsResource(
      fetcher,
      fixedClock,
      isCopilotEnabled = () => false
    )

    val response = resource.getOpenRouterModels

    response.getStatus shouldBe Response.Status.FORBIDDEN.getStatusCode
    response.getEntity shouldBe """{"error": "Copilot feature is disabled"}"""
    fetcher.calls shouldBe 0
  }

  it should "return stale cached models when refresh fails after the TTL expires" in {
    // A zero TTL makes the first (successful) answer expire immediately.
    val resource = openRouterResource(new CountingFetcher(1), fixedClock, Duration.ZERO)

    resource.getOpenRouterModels.getStatus shouldBe okStatus
    val staleResponse = resource.getOpenRouterModels

    staleResponse.getStatus shouldBe okStatus
    val entity = staleResponse.getEntity.asInstanceOf[OpenRouterModelsResponse]
    entity.stale shouldBe true
    entity.error should contain("upstream unavailable")
  }

  it should "reuse stale cached models during the failure retry window" in {
    val fetcher = new CountingFetcher(1)
    val resource = openRouterResource(
      fetcher,
      fixedClock,
      Duration.ZERO,
      staleFailureRetryTtl = Duration.ofMinutes(5)
    )

    resource.getOpenRouterModels.getStatus shouldBe okStatus
    resource.getOpenRouterModels.getStatus shouldBe okStatus
    val staleResponse = resource.getOpenRouterModels

    staleResponse.getStatus shouldBe okStatus
    staleResponse.getEntity.asInstanceOf[OpenRouterModelsResponse].stale shouldBe true
    // One successful fetch plus one failed refresh; the third call is served from cache.
    fetcher.calls shouldBe 2
  }

  it should "return SERVICE_UNAVAILABLE when the first upstream fetch fails" in {
    val resource = openRouterResource(
      () => throw new RuntimeException("upstream unavailable"),
      fixedClock
    )

    resource.getOpenRouterModels.getStatus shouldBe unavailableStatus
  }

  it should "return SERVICE_UNAVAILABLE for malformed OpenRouter responses" in {
    val resource = openRouterResource(() => """{"data": {}}""")

    resource.getOpenRouterModels.getStatus shouldBe unavailableStatus
  }

  it should "build OpenRouter requests with authorization when an API key is configured" in {
    val request = OpenRouterModelsResource.buildOpenRouterModelsRequest(Some(" openrouter-key "))

    request.headers().firstValue("Authorization").orElse("") shouldBe "Bearer openrouter-key"
  }

  it should "build OpenRouter requests without authorization when no API key is configured" in {
    val request = OpenRouterModelsResource.buildOpenRouterModelsRequest(None)

    request.headers().firstValue("Authorization").isPresent shouldBe false
  }
}
b/common/config/src/main/resources/llm.conf @@ -24,4 +24,8 @@ llm { # Master key for LiteLLM authentication master-key = "" master-key = ${?LITELLM_MASTER_KEY} + + # Optional OpenRouter API key for OpenRouter API requests + openrouter-api-key = "" + openrouter-api-key = ${?OPENROUTER_API_KEY} } diff --git a/common/config/src/main/scala/org/apache/texera/config/LLMConfig.scala b/common/config/src/main/scala/org/apache/texera/config/LLMConfig.scala index a85b734bad6..42bf15d3f72 100644 --- a/common/config/src/main/scala/org/apache/texera/config/LLMConfig.scala +++ b/common/config/src/main/scala/org/apache/texera/config/LLMConfig.scala @@ -26,4 +26,8 @@ object LLMConfig { // LLM Service Configuration val baseUrl: String = conf.getString("llm.base-url") val masterKey: String = conf.getString("llm.master-key") + val openRouterApiKey: Option[String] = + if (conf.hasPath("llm.openrouter-api-key")) + Option(conf.getString("llm.openrouter-api-key")).map(_.trim).filter(_.nonEmpty) + else None } diff --git a/common/workflow-operator/build.sbt b/common/workflow-operator/build.sbt index 1c082cae96e..5fedca154b3 100644 --- a/common/workflow-operator/build.sbt +++ b/common/workflow-operator/build.sbt @@ -110,7 +110,20 @@ libraryDependencies ++= Seq( "org.apache.commons" % "commons-compress" % "1.27.1", "org.tukaani" % "xz" % "1.9", "com.univocity" % "univocity-parsers" % "2.9.1", - "org.apache.lucene" % "lucene-analyzers-common" % "8.11.4" + "org.apache.lucene" % "lucene-analyzers-common" % "8.11.4", + ("dev.langchain4j" % "langchain4j" % "1.0.1") + .exclude("com.fasterxml.jackson.core", "jackson-databind") + .exclude("com.fasterxml.jackson.core", "jackson-core") + .exclude("com.fasterxml.jackson.core", "jackson-annotations"), + ("dev.langchain4j" % "langchain4j-open-ai" % "1.0.1") + .exclude("com.fasterxml.jackson.core", "jackson-databind") + .exclude("com.fasterxml.jackson.core", "jackson-core") + .exclude("com.fasterxml.jackson.core", "jackson-annotations"), + // AI Agent 
tools: URL fetch + PDF read pipeline + "org.jsoup" % "jsoup" % "1.17.2", + "net.dankito.readability4j" % "readability4j" % "1.0.8", + "com.vladsch.flexmark" % "flexmark-html2md-converter" % "0.64.8", + "org.apache.pdfbox" % "pdfbox" % "2.0.32" ) libraryDependencies += "io.github.classgraph" % "classgraph" % "4.8.184" % Test diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala index 4e9d6c6e2cd..b190726f2eb 100644 --- a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/LogicalOp.scala @@ -34,6 +34,7 @@ import org.apache.texera.amber.core.workflow.WorkflowContext.{ DEFAULT_WORKFLOW_ID } import org.apache.texera.amber.core.workflow.{PhysicalOp, PhysicalPlan, PortIdentity} +import org.apache.texera.amber.operator.aiagent.AIAgentOpDesc import org.apache.texera.amber.operator.aggregate.AggregateOpDesc import org.apache.texera.amber.operator.cartesianProduct.CartesianProductOpDesc import org.apache.texera.amber.operator.dictionary.DictionaryMatcherOpDesc @@ -164,6 +165,7 @@ trait StateTransferFunc @JsonSubTypes( Array( new Type(value = classOf[IfOpDesc], name = "If"), + new Type(value = classOf[AIAgentOpDesc], name = "AIAgent"), new Type(value = classOf[SankeyDiagramOpDesc], name = "SankeyDiagram"), new Type(value = classOf[IcicleChartOpDesc], name = "IcicleChart"), new Type(value = classOf[FileListerSourceOpDesc], name = "FileLister"), diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aiagent/AIAgentFinalAnswerTools.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aiagent/AIAgentFinalAnswerTools.scala new file mode 100644 index 00000000000..8d7cf926e63 --- /dev/null +++ 
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.texera.amber.operator.aiagent

import dev.langchain4j.agent.tool.ToolSpecification
import dev.langchain4j.model.chat.request.json.JsonObjectSchema

import scala.jdk.CollectionConverters._

/** Builds the tool specifications through which the model submits its final answer for a row. */
object AIAgentFinalAnswerTools {
  final val SubmitTextResult = "submit_text_result"
  final val SubmitStructuredResult = "submit_structured_result"

  /** Tool specification for a single-text answer.
    *
    * @param labels allowed answers; null, blank and duplicate entries are ignored. When
    *               nothing usable remains, `response` is a free-form string, otherwise it
    *               is an enum over the remaining labels.
    * @return a specification named [[SubmitTextResult]] with one required `response` property
    */
  def textResult(labels: List[String]): ToolSpecification = {
    val normalizedLabels = normalize(labels)
    val parameters = JsonObjectSchema
      .builder()
      .description("Final text response for the current row")
    if (normalizedLabels.isEmpty) {
      parameters.addStringProperty("response", "Final text response")
    } else {
      parameters.addEnumProperty(
        "response",
        normalizedLabels.asJava,
        "Final classification label. Must exactly match one allowed label."
      )
    }
    ToolSpecification
      .builder()
      .name(SubmitTextResult)
      .description("Submit the final AI Agent text result for the current row")
      .parameters(parameters.required("response").additionalProperties(false).build())
      .build()
  }

  /** Tool specification with one required property per configured output column.
    *
    * @param fields output column definitions; a null list, null entries and entries whose
    *               column name is null or blank are ignored. Classification fields that
    *               have usable labels become enum properties; every other field is a
    *               string property.
    * @return a specification named [[SubmitStructuredResult]]
    */
  def structuredResult(fields: List[AIAgentStructuredOutputField]): ToolSpecification = {
    // Pair every usable field with its trimmed column name so the name is computed once.
    val namedFields = Option(fields)
      .getOrElse(List.empty)
      .filter(field => field != null && field.columnName != null)
      .map(field => field.columnName.trim -> field)
      .filter { case (columnName, _) => columnName.nonEmpty }

    val parameters = JsonObjectSchema
      .builder()
      .description("Final structured response for the current row")

    namedFields.foreach {
      case (columnName, field) =>
        val description = Option(field.instructions)
          .map(_.trim)
          .filter(_.nonEmpty)
          .getOrElse("Extract this value for the row")
        val labels =
          if (field.normalizedFieldType == AIAgentStructuredFieldType.Classification)
            normalize(field.classificationLabels)
          else List.empty
        if (labels.nonEmpty) {
          parameters.addEnumProperty(columnName, labels.asJava, description)
        } else {
          parameters.addStringProperty(columnName, description)
        }
    }

    // A column configured twice must still be listed only once under `required`.
    val requiredColumns = namedFields.map { case (columnName, _) => columnName }.distinct

    ToolSpecification
      .builder()
      .name(SubmitStructuredResult)
      .description("Submit the final AI Agent structured result for the current row")
      .parameters(
        parameters
          .required(requiredColumns: _*)
          .additionalProperties(false)
          .build()
      )
      .build()
  }

  /** Trims every entry and drops null, blank and repeated ones, keeping first-seen order. */
  private def normalize(values: List[String]): List[String] =
    Option(values)
      .getOrElse(List.empty)
      .flatMap(Option(_)) // deserialized lists can hold null elements
      .map(_.trim)
      .filter(_.nonEmpty)
      .distinct

}
b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aiagent/AIAgentOpDesc.scala @@ -0,0 +1,481 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.aiagent + +import com.fasterxml.jackson.annotation.{JsonFormat, JsonProperty, JsonPropertyDescription} +import com.kjetland.jackson.jsonSchema.annotations.{JsonSchemaInject, JsonSchemaTitle} +import org.apache.texera.amber.core.executor.OpExecWithClassName +import org.apache.texera.amber.core.tuple.{Attribute, AttributeType} +import org.apache.texera.amber.core.virtualidentity.{ExecutionIdentity, WorkflowIdentity} +import org.apache.texera.amber.core.workflow.{ + InputPort, + OutputPort, + PhysicalOp, + SchemaPropagationFunc +} +import org.apache.texera.amber.operator.map.MapOpDesc +import org.apache.texera.amber.operator.metadata.annotations.{AutofillAttributeNameList, UIWidget} +import org.apache.texera.amber.operator.metadata.{OperatorGroupConstants, OperatorInfo} +import org.apache.texera.amber.util.JSONUtils.objectMapper + +class AIAgentOpDesc extends MapOpDesc { + + @JsonProperty(value = "outputMode", required = false, defaultValue = "text") + @JsonSchemaTitle("Output Mode") + 
@JsonPropertyDescription( + "Controls whether the AI response is emitted as one text column or multiple structured columns" + ) + @JsonSchemaInject( + json = """{"enum": ["text", "structured"], "default": "text"}""" + ) + var outputMode: String = AIAgentOutputMode.Text + + @JsonProperty(value = "structuredOutputFields", required = false) + @JsonSchemaTitle("Structured Output Fields") + @JsonPropertyDescription( + "Define each output column and what the model should extract for it. The model returns a JSON object with one key per column." + ) + @JsonSchemaInject( + json = + """{"hideTarget": "outputMode", "hideType": "equals", "hideExpectedValue": "text", "hideOnNull": true}""" + ) + var structuredOutputFields: List[AIAgentStructuredOutputField] = List.empty + + @JsonProperty(value = "textClassificationLabels", required = false) + @JsonSchemaTitle("Text Classification Labels") + @JsonPropertyDescription( + "Optional allowed labels for text output. Leave empty for free-form text; fill this to make the text response a classification." + ) + @JsonSchemaInject( + json = + """{"hideTarget": "outputMode", "hideType": "equals", "hideExpectedValue": "structured", "hideOnNull": true, "widget": {"formlyConfig": {"type": "tags-input"}}}""" + ) + var textClassificationLabels: List[String] = List.empty + + @JsonProperty(value = "classificationLabels", required = false) + @JsonSchemaTitle("Legacy Classification Labels") + @JsonPropertyDescription( + "Deprecated. Use Text Classification Labels or structured classification fields." + ) + @JsonSchemaInject( + json = """{"hideTarget": "outputMode", "hideType": "regex", "hideExpectedValue": ".*"}""" + ) + var classificationLabels: List[String] = List.empty + + @JsonProperty(value = "confidenceColumn", required = false) + @JsonSchemaTitle("Legacy Confidence Column") + @JsonPropertyDescription( + "Deprecated. Retained so workflows saved before classification was removed still load." 
+ ) + @JsonSchemaInject( + json = """{"hideTarget": "outputMode", "hideType": "regex", "hideExpectedValue": ".*"}""" + ) + var confidenceColumn: String = "" + + @JsonProperty(value = "outputColumn", required = false, defaultValue = "ai_agent_response") + @JsonSchemaTitle("Output Column") + @JsonPropertyDescription("Column name for the text response") + @JsonSchemaInject( + json = + """{"hideTarget": "outputMode", "hideType": "equals", "hideExpectedValue": "structured"}""" + ) + var outputColumn: String = "ai_agent_response" + + @JsonProperty(value = "systemPrompt", required = false) + @JsonSchemaTitle("System Prompt") + @JsonPropertyDescription("Optional system prompt sent before each row prompt") + @JsonSchemaInject(json = UIWidget.UIWidgetTextArea) + var systemPrompt: String = "" + + @JsonProperty(value = "inputColumn", required = true) + @JsonSchemaTitle("Columns Sent to AI") + @JsonPropertyDescription("Column values sent as the user prompt for each input row") + @JsonFormat(`with` = Array(JsonFormat.Feature.ACCEPT_SINGLE_VALUE_AS_ARRAY)) + @AutofillAttributeNameList + var inputColumn: List[String] = List.empty + + @JsonProperty(value = "apiKey", required = true) + @JsonSchemaTitle("OpenRouter API Key") + @JsonPropertyDescription("OpenRouter API key") + @JsonSchemaInject(json = UIWidget.UIWidgetPassword) + var apiKey: String = _ + + @JsonProperty(value = "model", required = true, defaultValue = "openai/gpt-4o-mini") + @JsonSchemaTitle("Model") + @JsonPropertyDescription("OpenRouter model ID") + var model: String = "openai/gpt-4o-mini" + + @JsonProperty(value = "temperature", required = true) + @JsonSchemaTitle("Temperature") + @JsonPropertyDescription("Sampling temperature") + @JsonSchemaInject(json = """{"default": 0.7}""") + var temperature: Double = 0.7 + + @JsonProperty(value = "timeoutSeconds", required = true, defaultValue = "60") + @JsonSchemaTitle("Timeout Seconds") + @JsonPropertyDescription("OpenRouter request timeout in seconds") + var timeoutSeconds: 
Int = 60 + + @JsonProperty(value = "enabledTools", required = false) + @JsonSchemaTitle("Enabled Tools") + @JsonPropertyDescription( + "Optional tools the AI can call per row. read_url fetches a web page as Markdown; read_pdf extracts text from a PDF URL or path." + ) + @JsonSchemaInject( + json = + """{"type": "array", "items": {"type": "string", "enum": ["read_url", "read_pdf"]}, "uniqueItems": true, "default": []}""" + ) + var enabledTools: List[String] = List.empty + + @JsonProperty(value = "maxToolIterations", required = false) + @JsonSchemaTitle("Max Tool Iterations") + @JsonPropertyDescription( + "Maximum number of model turns when tools are enabled. Each turn either calls a tool or returns the final answer." + ) + @JsonSchemaInject( + json = """{"default": 5}""" + ) + var maxToolIterations: java.lang.Integer = 5 + + @JsonProperty(value = "urlFetchMaxChars", required = false) + @JsonSchemaTitle("URL Fetch Max Chars") + @JsonPropertyDescription("Truncate read_url Markdown output to this many characters.") + @JsonSchemaInject(json = """{"default": 50000}""") + var urlFetchMaxChars: java.lang.Integer = 50000 + + @JsonProperty(value = "pdfReadMaxChars", required = false) + @JsonSchemaTitle("PDF Read Max Chars") + @JsonPropertyDescription("Truncate read_pdf text output to this many characters.") + @JsonSchemaInject(json = """{"default": 100000}""") + var pdfReadMaxChars: java.lang.Integer = 100000 + + @JsonProperty(value = "parallelism", required = false) + @JsonSchemaTitle("Parallelism") + @JsonPropertyDescription( + "Number of parallel workers. Texera shards the input rows across this many actors, each running the agent loop independently. Each worker has its own response cache. Watch out for upstream API rate limits when raising this." 
+ ) + @JsonSchemaInject(json = """{"default": 1, "minimum": 1, "maximum": 32}""") + var parallelism: java.lang.Integer = 1 + + @JsonProperty(value = "cacheEnabled", required = false) + @JsonSchemaTitle("Cache Responses") + @JsonPropertyDescription( + "If true, identical (model, prompt, tools, structured output) requests reuse the previous response without billing." + ) + @JsonSchemaInject(json = """{"default": true}""") + var cacheEnabled: Boolean = true + + @JsonProperty(value = "emitCostColumn", required = false) + @JsonSchemaTitle("Emit Cost Column") + @JsonPropertyDescription( + "If true, append a double column with the estimated USD cost of the LLM call for the row." + ) + @JsonSchemaInject(json = """{"default": true}""") + var emitCostColumn: Boolean = true + + @JsonProperty(value = "costColumnName", required = false) + @JsonSchemaTitle("Cost Column Name") + @JsonPropertyDescription("Name of the cost column appended to the output schema.") + @JsonSchemaInject(json = """{"default": "_cost_usd"}""") + var costColumnName: String = "_cost_usd" + + @JsonProperty(value = "maxRowCostUsd", required = false) + @JsonSchemaTitle("Max Row Cost (USD)") + @JsonPropertyDescription( + "Optional hard cap. If a row's accumulated tool-loop cost exceeds this value, the row is aborted and surfaced as an error." + ) + var maxRowCostUsd: java.lang.Double = _ + + @JsonProperty(value = "emitErrorColumn", required = false) + @JsonSchemaTitle("Emit Error Column") + @JsonPropertyDescription( + "If true, append a string column containing the error message when the LLM call fails (empty on success)." 
+ ) + @JsonSchemaInject(json = """{"default": true}""") + var emitErrorColumn: Boolean = true + + @JsonProperty(value = "errorColumnName", required = false) + @JsonSchemaTitle("Error Column Name") + @JsonPropertyDescription("Name of the error column appended to the output schema.") + @JsonSchemaInject(json = """{"default": "_error"}""") + var errorColumnName: String = "_error" + + @JsonProperty(value = "mcpServers", required = false) + @JsonSchemaTitle("MCP Servers") + @JsonPropertyDescription( + "Connect Model Context Protocol servers to give the agent additional tools (e.g. Notion, Slack, GitHub, internal APIs). Tools discovered from each server are namespaced as `serverName__toolName`." + ) + var mcpServers: List[AIAgentMCPServerConfig] = List.empty + + override def getPhysicalOp( + workflowId: WorkflowIdentity, + executionId: ExecutionIdentity + ): PhysicalOp = { + validateConfig() + val workers = normalizedParallelism + val base = PhysicalOp + .oneToOnePhysicalOp( + workflowId, + executionId, + operatorIdentifier, + OpExecWithClassName( + "org.apache.texera.amber.operator.aiagent.AIAgentOpExec", + objectMapper.writeValueAsString(this) + ) + ) + val parallelized = + if (workers > 1) base.withParallelizable(true).withSuggestedWorkerNum(workers) + else base.withParallelizable(false) + parallelized + .withInputPorts(operatorInfo.inputPorts) + .withOutputPorts(operatorInfo.outputPorts) + .withPropagateSchema( + SchemaPropagationFunc(inputSchemas => { + val inputPortId = operatorInfo.inputPorts.head.id + val outputPortId = operatorInfo.outputPorts.head.id + val inputSchema = inputSchemas(inputPortId) + Map(outputPortId -> outputSchema(inputSchema)) + }) + ) + } + + def normalizedParallelism: Int = + Option(parallelism).map(_.intValue).filter(_ >= 1).map(_.min(32)).getOrElse(1) + + private def validateConfig(): Unit = { + if (apiKey == null || apiKey.trim.isEmpty) { + throw new IllegalArgumentException( + "AI Agent: OpenRouter API key is missing. 
Open this operator's property panel and paste your key into 'OpenRouter API Key'." + ) + } + if (model == null || model.trim.isEmpty) { + throw new IllegalArgumentException("AI Agent: Model is required.") + } + if (normalizedOutputMode == AIAgentOutputMode.Structured && + normalizedStructuredOutputFields.isEmpty) { + throw new IllegalArgumentException( + "AI Agent: Structured output mode requires at least one structured output field. Add fields or switch to text mode." + ) + } + if (Option(inputColumn).getOrElse(List.empty).forall(c => c == null || c.trim.isEmpty)) { + throw new IllegalArgumentException( + "AI Agent: 'Columns Sent to AI' is empty — pick at least one input column." + ) + } + } + + private def outputSchema(inputSchema: org.apache.texera.amber.core.tuple.Schema) = { + val base = normalizedOutputMode match { + case AIAgentOutputMode.Structured => + normalizedStructuredOutputColumns.foldLeft(inputSchema) { (schema, column) => + schema.add(new Attribute(column, AttributeType.STRING)) + } + case AIAgentOutputMode.Classification => + addTextOutputColumn(inputSchema) + case _ => + addTextOutputColumn(inputSchema) + } + val withCost = + if ( + emitCostColumn && + normalizedCostColumnName.nonEmpty && + !base.containsAttribute(normalizedCostColumnName) + ) { + base.add(new Attribute(normalizedCostColumnName, AttributeType.DOUBLE)) + } else base + if ( + emitErrorColumn && + normalizedErrorColumnName.nonEmpty && + !withCost.containsAttribute(normalizedErrorColumnName) + ) { + withCost.add(new Attribute(normalizedErrorColumnName, AttributeType.STRING)) + } else withCost + } + + def normalizedErrorColumnName: String = + Option(errorColumnName).map(_.trim).filter(_.nonEmpty).getOrElse("_error") + + def normalizedCostColumnName: String = + Option(costColumnName).map(_.trim).filter(_.nonEmpty).getOrElse("_cost_usd") + + def normalizedMaxRowCostUsd: Option[Double] = + Option(maxRowCostUsd).map(_.doubleValue).filter(_ > 0.0) + + private def 
addTextOutputColumn(inputSchema: org.apache.texera.amber.core.tuple.Schema) = + if (outputColumn == null || outputColumn.trim.isEmpty) { + inputSchema + } else { + inputSchema.add(new Attribute(outputColumn, AttributeType.STRING)) + } + + def normalizedStructuredOutputColumns: List[String] = + normalizedStructuredOutputFields.map(_.columnName.trim) + + def normalizedStructuredOutputFields: List[AIAgentStructuredOutputField] = + Option(structuredOutputFields) + .getOrElse(List.empty) + .filter(field => field != null && field.columnName != null && field.columnName.trim.nonEmpty) + + def normalizedClassificationLabels: List[String] = + Option(classificationLabels).getOrElse(List.empty).map(_.trim).filter(_.nonEmpty) + + def normalizedTextClassificationLabels: List[String] = + Option(textClassificationLabels).getOrElse(List.empty).map(_.trim).filter(_.nonEmpty) + + def normalizedMaxToolIterations: Int = + Option(maxToolIterations).map(_.intValue).filter(_ > 0).getOrElse(5) + + def normalizedUrlFetchMaxChars: Int = + Option(urlFetchMaxChars).map(_.intValue).filter(_ > 0).getOrElse(UrlFetcher.DefaultMaxChars) + + def normalizedPdfReadMaxChars: Int = + Option(pdfReadMaxChars).map(_.intValue).filter(_ > 0).getOrElse(PdfReader.DefaultMaxChars) + + def normalizedEnabledTools: List[String] = + Option(enabledTools).getOrElse(List.empty).map(_.trim).filter(_.nonEmpty).distinct + + def buildTools: List[AIAgentTool] = { + val builtIn = normalizedEnabledTools.flatMap { + case UrlFetchTool.Name => Some(new UrlFetchTool(normalizedUrlFetchMaxChars)) + case PdfReadTool.Name => Some(new PdfReadTool(normalizedPdfReadMaxChars)) + case _ => None + } + val mcp = normalizedMcpServers.flatMap { server => + val client = new MCPClient( + serverName = server.normalizedName, + url = server.url.trim, + bearerToken = Option(server.bearerToken).map(_.trim).filter(_.nonEmpty), + timeoutSeconds = server.normalizedTimeoutSeconds + ) + try { + client.initialize() + client.listTools().map(info => new 
MCPToolAdapter(client, info)) + } catch { + case t: Throwable => + throw new RuntimeException( + s"AI Agent: failed to connect to MCP server '${server.normalizedName}' at ${server.url.trim}: " + + s"${t.getClass.getSimpleName}: ${Option(t.getMessage).getOrElse("")}. " + + s"Check the URL and Bearer Token in the operator's MCP Servers config.", + t + ) + } + } + builtIn ++ mcp + } + + def normalizedMcpServers: List[AIAgentMCPServerConfig] = + Option(mcpServers) + .getOrElse(List.empty) + .filter(s => s != null && s.url != null && s.url.trim.nonEmpty) + + def normalizedOutputMode: String = + Option(outputMode).map(_.trim).filter(_.nonEmpty).getOrElse(AIAgentOutputMode.Text) + + override def operatorInfo: OperatorInfo = + OperatorInfo( + "AI Agent", + "Calls OpenRouter chat completions once per input row", + OperatorGroupConstants.API_GROUP, + inputPorts = List(InputPort()), + outputPorts = List(OutputPort()), + supportReconfiguration = true + ) +} + +object AIAgentOutputMode { + final val Text = "text" + final val Structured = "structured" + final val Classification = "classification" +} + +class AIAgentStructuredOutputField { + @JsonProperty(value = "fieldType", required = false, defaultValue = "text") + @JsonSchemaTitle("Field Type") + @JsonPropertyDescription("Choose free-form text or a classification label for this output column") + @JsonSchemaInject(json = """{"enum": ["text", "classification"], "default": "text"}""") + var fieldType: String = AIAgentStructuredFieldType.Text + + @JsonProperty(value = "columnName", required = true) + @JsonSchemaTitle("Column Name") + @JsonPropertyDescription("Output column name and JSON key for this extracted value") + var columnName: String = "" + + @JsonProperty(value = "instructions", required = false) + @JsonSchemaTitle("Instructions") + @JsonPropertyDescription("Describe what this column should contain for each row") + @JsonSchemaInject(json = UIWidget.UIWidgetTextArea) + var instructions: String = "" + + @JsonProperty(value 
= "classificationLabels", required = false) + @JsonSchemaTitle("Classification Labels") + @JsonPropertyDescription("Allowed labels when this structured field is a classification") + @JsonSchemaInject( + json = + """{"hideTarget": "fieldType", "hideType": "equals", "hideExpectedValue": "text", "hideOnNull": true, "widget": {"formlyConfig": {"type": "tags-input"}}}""" + ) + var classificationLabels: List[String] = List.empty + + def normalizedFieldType: String = + Option(fieldType).map(_.trim).filter(_.nonEmpty).getOrElse(AIAgentStructuredFieldType.Text) + + def normalizedClassificationLabels: List[String] = + Option(classificationLabels).getOrElse(List.empty).map(_.trim).filter(_.nonEmpty) +} + +object AIAgentStructuredFieldType { + final val Text = "text" + final val Classification = "classification" +} + +class AIAgentMCPServerConfig { + @JsonProperty(value = "name", required = false) + @JsonSchemaTitle("Server Name") + @JsonPropertyDescription( + "Short identifier used to namespace this server's tools (e.g. `notion` → `notion__search`). Letters, digits, and underscores only." + ) + var name: String = "" + + @JsonProperty(value = "url", required = true) + @JsonSchemaTitle("Server URL") + @JsonPropertyDescription( + "Streamable HTTP endpoint of the MCP server, e.g. https://mcp.notion.com/mcp" + ) + var url: String = "" + + @JsonProperty(value = "bearerToken", required = false) + @JsonSchemaTitle("Bearer Token") + @JsonPropertyDescription( + "Optional auth token sent as `Authorization: Bearer ...`. Leave blank for unauthenticated servers." 
+ ) + @JsonSchemaInject(json = UIWidget.UIWidgetPassword) + var bearerToken: String = "" + + @JsonProperty(value = "timeoutSeconds", required = false) + @JsonSchemaTitle("Timeout Seconds") + @JsonPropertyDescription("HTTP timeout for each MCP request.") + @JsonSchemaInject(json = """{"default": 30}""") + var timeoutSeconds: java.lang.Integer = 30 + + def normalizedName: String = + Option(name).map(_.trim).filter(_.nonEmpty).getOrElse("mcp") + + def normalizedTimeoutSeconds: Int = + Option(timeoutSeconds).map(_.intValue).filter(_ > 0).getOrElse(30) +} diff --git a/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aiagent/AIAgentOpExec.scala b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aiagent/AIAgentOpExec.scala new file mode 100644 index 00000000000..e8ed64b0dfd --- /dev/null +++ b/common/workflow-operator/src/main/scala/org/apache/texera/amber/operator/aiagent/AIAgentOpExec.scala @@ -0,0 +1,244 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/**
  * Per-row executor of the AI Agent operator.
  *
  * For every input tuple it builds a user prompt from the configured input
  * columns, obtains a completion (from the per-instance response cache or from
  * the chat-completion client), converts the completion into output values and
  * appends them — plus the optional cost and error columns — to the input
  * fields.
  *
  * @param descString JSON-serialized [[AIAgentOpDesc]] carrying the operator configuration
  */
class AIAgentOpExec(descString: String) extends MapOpExec {
  import scala.util.control.NonFatal

  private val desc: AIAgentOpDesc =
    objectMapper.readValue(descString, classOf[AIAgentOpDesc])
  private val client: ChatCompletionClient = new OpenRouterClient
  private val tools: List[AIAgentTool] = desc.buildTools
  private val cache: AIAgentResponseCache = new AIAgentResponseCache()

  setMapFunc(callAIAgent)

  /**
    * Closes every tool on a best-effort basis: a failure to close one tool must
    * not prevent the remaining tools from being closed. Fatal errors and
    * interrupts are deliberately NOT swallowed.
    */
  override def close(): Unit =
    tools.foreach { tool =>
      try tool.close()
      catch {
        case NonFatal(_) => ()
      }
    }

  /** Cache key covering everything that influences the response for `userPrompt`. */
  private def cacheKey(userPrompt: String): String = {
    val mcpSig = desc.normalizedMcpServers
      .map { s =>
        // Only a hash of the bearer token enters the key, never the token itself.
        val tokenHash = Option(s.bearerToken)
          .map(_.trim)
          .filter(_.nonEmpty)
          .map(AIAgentResponseCache.sha256)
          .getOrElse("")
        s"${s.normalizedName}|${s.url.trim}|$tokenHash"
      }
      .sorted
      .mkString(",")
    val toolSig = tools.map(_.name).sorted.mkString(",") + "||mcp=" + mcpSig
    val structSig = desc.normalizedOutputMode + ":" +
      desc.normalizedStructuredOutputFields
        .map(f =>
          s"${f.columnName.trim}|${f.normalizedFieldType}|${f.normalizedClassificationLabels.mkString(";")}"
        )
        .mkString(",") + ":" + desc.normalizedTextClassificationLabels.mkString(";") +
      ":" + desc.normalizedClassificationLabels.mkString(";") +
      s":operator=${desc.operatorIdentifier.id}"
    AIAgentResponseCache.key(
      desc.model,
      desc.temperature,
      AIAgentResponseCache.sha256(desc.apiKey),
      effectiveSystemPrompt,
      userPrompt,
      toolSig,
      structSig
    )
  }

  /**
    * Map function applied to each tuple. Failures of the LLM call or of the
    * response validation are turned into empty output values plus a (truncated)
    * message for the error column; prompt-construction errors propagate.
    */
  private def callAIAgent(tuple: Tuple): TupleLike = {
    val userPrompt = buildUserPrompt(tuple)
    val key: Option[String] = if (desc.cacheEnabled) Some(cacheKey(userPrompt)) else None
    val (rawFields, costUsd, errorMsg) =
      try {
        key.flatMap(cache.get) match {
          case Some(text) => (outputFields(text), 0.0, "")
          case None =>
            val result = completeForMode(userPrompt)
            // Validate/parse BEFORE caching so that an invalid response is not
            // replayed from the cache for every later identical row.
            val parsed = outputFields(result.text)
            key.foreach(k => cache.put(k, result.text))
            (parsed, result.usdCost, "")
        }
      } catch {
        case NonFatal(t) =>
          val raw = s"${t.getClass.getSimpleName}: ${Option(t.getMessage).getOrElse("")}"
          val truncated = if (raw.length > 1000) raw.substring(0, 1000) else raw
          (emptyOutputFields, 0.0, truncated)
      }
    val fields = alignWithSchema(rawFields)
    val inputSchema = tuple.getSchema
    val emittedNames = outputAttributeNames(inputSchema)
    val withCost: Seq[Any] =
      if (
        desc.emitCostColumn &&
        desc.normalizedCostColumnName.nonEmpty &&
        !emittedNames.exists(_.equalsIgnoreCase(desc.normalizedCostColumnName))
      )
        fields :+ java.lang.Double.valueOf(costUsd)
      else fields
    val namesWithCost =
      if (withCost.length > fields.length) emittedNames :+ desc.normalizedCostColumnName
      else emittedNames
    val withError: Seq[Any] =
      if (
        desc.emitErrorColumn &&
        desc.normalizedErrorColumnName.nonEmpty &&
        !namesWithCost.exists(_.equalsIgnoreCase(desc.normalizedErrorColumnName))
      ) withCost :+ errorMsg
      else withCost
    TupleLike(tuple.getFields ++ withError)
  }

  /** True when a text/classification output column is configured. */
  private def hasTextOutputColumn: Boolean =
    Option(desc.outputColumn).exists(_.trim.nonEmpty)

  /**
    * Drops the single text/classification value when no output column is
    * configured: the operator's output schema adds no column in that case, so
    * emitting the value would yield one field more than the schema declares.
    */
  private def alignWithSchema(values: Seq[Any]): Seq[Any] =
    desc.normalizedOutputMode match {
      case AIAgentOutputMode.Structured => values
      case _ if hasTextOutputColumn     => values
      case _                            => Seq.empty
    }

  /** Names of the input attributes followed by the mode-specific output columns. */
  private def outputAttributeNames(
      inputSchema: org.apache.texera.amber.core.tuple.Schema
  ): Seq[String] = {
    val inputNames = inputSchema.getAttributeNames
    desc.normalizedOutputMode match {
      case AIAgentOutputMode.Structured =>
        inputNames ++ desc.normalizedStructuredOutputColumns
      case _ if hasTextOutputColumn =>
        inputNames :+ desc.outputColumn.trim
      case _ =>
        inputNames
    }
  }

  /** Placeholder output values used when the call for a row failed. */
  private def emptyOutputFields: Seq[Any] =
    desc.normalizedOutputMode match {
      case AIAgentOutputMode.Structured =>
        desc.normalizedStructuredOutputColumns.map(_ => "")
      case _ =>
        Seq("")
    }

  /**
    * Runs the completion. A plain completion is used only when there are neither
    * tools nor a final-answer tool specification; otherwise the tool loop is used.
    */
  private def completeForMode(userPrompt: String): ChatCompletionResult = {
    val toolSpecification = desc.normalizedOutputMode match {
      case AIAgentOutputMode.Structured =>
        Some(AIAgentFinalAnswerTools.structuredResult(desc.normalizedStructuredOutputFields))
      case AIAgentOutputMode.Classification =>
        Some(AIAgentFinalAnswerTools.textResult(desc.normalizedClassificationLabels))
      case _ if desc.normalizedTextClassificationLabels.nonEmpty =>
        Some(AIAgentFinalAnswerTools.textResult(desc.normalizedTextClassificationLabels))
      case _ =>
        None
    }

    if (tools.isEmpty && toolSpecification.isEmpty) {
      client.complete(
        desc.apiKey,
        desc.model,
        effectiveSystemPrompt,
        userPrompt,
        desc.temperature,
        desc.timeoutSeconds
      )
    } else {
      client.completeWithTools(
        desc.apiKey,
        desc.model,
        effectiveSystemPrompt,
        userPrompt,
        desc.temperature,
        desc.timeoutSeconds,
        tools,
        toolSpecification,
        desc.normalizedMaxToolIterations,
        desc.normalizedMaxRowCostUsd
      )
    }
  }

  /**
    * Renders the selected input columns as `name: value` lines.
    *
    * @throws IllegalArgumentException if no column is selected or a selected
    *                                  column is missing from the tuple schema
    */
  private def buildUserPrompt(tuple: Tuple): String = {
    val inputColumns = Option(desc.inputColumn).getOrElse(List.empty).filter(_.trim.nonEmpty)
    require(inputColumns.nonEmpty, "At least one column must be sent to AI")
    inputColumns
      .map { column =>
        if (!tuple.getSchema.containsAttribute(column)) {
          throw new IllegalArgumentException(s"AI Agent references missing column: $column")
        }
        val value = Option(tuple.getField[Any](column)).map(_.toString).getOrElse("")
        s"$column: $value"
      }
      .mkString("\n")
  }

  /** User-provided system prompt followed by the mode-specific instructions. */
  private def effectiveSystemPrompt: String = {
    val basePrompt = Option(desc.systemPrompt).getOrElse("").trim
    val modePrompt = desc.normalizedOutputMode match {
      case AIAgentOutputMode.Structured =>
        val fields = desc.normalizedStructuredOutputFields
        require(fields.nonEmpty, "Structured output mode requires at least one output field")
        val fieldInstructions = fields
          .map { field =>
            val instructions = Option(field.instructions).getOrElse("").trim
            val classificationSuffix =
              if (field.normalizedFieldType == AIAgentStructuredFieldType.Classification) {
                val labels = field.normalizedClassificationLabels
                if (labels.isEmpty) {
                  " Choose a classification label."
                } else {
                  s" Choose exactly one of: ${labels.mkString(", ")}."
                }
              } else {
                ""
              }
            if (instructions.isEmpty) {
              s"- ${field.columnName.trim}: extract this value for the row.$classificationSuffix"
            } else {
              s"- ${field.columnName.trim}: $instructions$classificationSuffix"
            }
          }
          .mkString("\n")
        s"""Call the ${AIAgentFinalAnswerTools.SubmitStructuredResult} tool exactly once with the final structured result.
           |
           |Structured output fields:
           |$fieldInstructions""".stripMargin
      case AIAgentOutputMode.Classification =>
        val labels = desc.normalizedClassificationLabels
        require(labels.nonEmpty, "Classification mode requires at least one label")
        s"""Call the ${AIAgentFinalAnswerTools.SubmitTextResult} tool exactly once. The response value must exactly match one of these labels: ${labels
          .mkString(", ")}."""
      case _ if desc.normalizedTextClassificationLabels.nonEmpty =>
        s"""Call the ${AIAgentFinalAnswerTools.SubmitTextResult} tool exactly once. The response value must exactly match one of these labels: ${desc.normalizedTextClassificationLabels
          .mkString(", ")}."""
      case _ =>
        ""
    }

    List(basePrompt, modePrompt).filter(_.nonEmpty).mkString("\n\n")
  }

  /** Parses/validates the raw completion text into the mode-specific output values. */
  private def outputFields(content: String): Seq[Any] =
    desc.normalizedOutputMode match {
      case AIAgentOutputMode.Structured =>
        AIAgentOutputParser.parseStructuredFields(content, desc.normalizedStructuredOutputFields)
      case AIAgentOutputMode.Classification =>
        Seq(AIAgentOutputParser.parseTextResult(content, desc.normalizedClassificationLabels))
      case _ if desc.normalizedTextClassificationLabels.nonEmpty =>
        Seq(AIAgentOutputParser.parseTextResult(content, desc.normalizedTextClassificationLabels))
      case _ =>
        Seq(content)
    }
}
/**
  * Parsers that turn the JSON text returned by the model into output column
  * values, validating classification labels where labels are configured.
  * Every parser throws `IllegalArgumentException` when the response does not
  * have the expected shape; JSON syntax errors surface as the exception thrown
  * by `objectMapper.readTree`.
  */
object AIAgentOutputParser {

  /**
    * Extracts the string `response` property of a JSON object.
    *
    * @param response raw JSON text returned by the model
    * @param labels   allowed values; when non-empty (after trimming) the value must be one of them
    * @return the `response` value
    */
  def parseTextResult(response: String, labels: List[String]): String = {
    val root = parseJsonObject(response, "text output")
    val responseNode = root.get("response")
    if (responseNode == null || !responseNode.isTextual) {
      throw new IllegalArgumentException("Text output must contain a string response")
    }
    val value = responseNode.asText()
    val normalizedLabels = normalize(labels)
    if (normalizedLabels.nonEmpty && !normalizedLabels.contains(value)) {
      throw new IllegalArgumentException(
        s"""Text classification label "$value" is not one of: ${normalizedLabels.mkString(", ")}"""
      )
    }
    value
  }

  /**
    * Reads one value per column from a JSON object; absent or null properties
    * become the empty string, non-string values are rendered as JSON text.
    */
  def parseStructured(response: String, columns: List[String]): Seq[String] = {
    val normalizedColumns = normalize(columns)
    require(
      normalizedColumns.nonEmpty,
      "Structured output mode requires at least one output column"
    )

    val root = parseJsonObject(response, "structured output")
    normalizedColumns.map(column => jsonValueToString(root.get(column)))
  }

  /**
    * Like [[parseStructured]], but driven by field definitions: values of
    * classification fields with configured labels must match one of the labels.
    */
  def parseStructuredFields(
      response: String,
      fields: List[AIAgentStructuredOutputField]
  ): Seq[String] = {
    val normalizedFields = usableFields(fields)
    require(
      normalizedFields.nonEmpty,
      "Structured output mode requires at least one output column"
    )

    val root = parseJsonObject(response, "structured output")
    normalizedFields.map { field =>
      val value = jsonValueToString(root.get(field.columnName.trim))
      if (field.normalizedFieldType == AIAgentStructuredFieldType.Classification) {
        val labels = field.normalizedClassificationLabels
        if (labels.nonEmpty && !labels.contains(value)) {
          throw new IllegalArgumentException(
            s"""Structured classification field "${field.columnName.trim}" label "$value" is not one of: ${labels
              .mkString(", ")}"""
          )
        }
      }
      value
    }
  }

  /**
    * Parses `{"label": ..., "confidence": ...}`.
    *
    * @return the label and the confidence, which is `null` when absent or JSON null
    */
  def parseClassification(response: String, labels: List[String]): (String, java.lang.Double) = {
    val normalizedLabels = normalize(labels)
    require(normalizedLabels.nonEmpty, "Classification mode requires at least one label")

    val root = parseJsonObject(response, "classification output")
    val labelNode = root.get("label")
    if (labelNode == null || !labelNode.isTextual) {
      throw new IllegalArgumentException("Classification output must contain a string label")
    }

    val label = labelNode.asText()
    if (!normalizedLabels.contains(label)) {
      throw new IllegalArgumentException(
        s"""Classification label "$label" is not one of: ${normalizedLabels.mkString(", ")}"""
      )
    }

    val confidenceNode = root.get("confidence")
    val confidence: java.lang.Double =
      if (confidenceNode == null || confidenceNode.isNull) {
        null
      } else if (confidenceNode.isNumber) {
        confidenceNode.asDouble()
      } else {
        throw new IllegalArgumentException(
          "Classification confidence must be numeric when provided"
        )
      }

    (label, confidence)
  }

  /** Parses `response` (null is treated as empty) and insists on a JSON object. */
  private def parseJsonObject(response: String, outputName: String): JsonNode = {
    val root = objectMapper.readTree(Option(response).getOrElse(""))
    // Guard against a null root as well as any non-object node, so blank
    // content is reported as a shape error instead of a NullPointerException.
    if (root == null || !root.isObject) {
      throw new IllegalArgumentException(s"AI Agent $outputName must be a JSON object")
    }
    root
  }

  /** Absent/null -> "", string -> its text, anything else -> its JSON rendering. */
  private def jsonValueToString(node: JsonNode): String =
    if (node == null || node.isNull) {
      ""
    } else if (node.isTextual) {
      node.asText()
    } else {
      node.toString
    }

  /** Field definitions that carry a non-blank column name. */
  private def usableFields(
      fields: List[AIAgentStructuredOutputField]
  ): List[AIAgentStructuredOutputField] =
    Option(fields)
      .getOrElse(List.empty)
      .filter(field => field != null && field.columnName != null && field.columnName.trim.nonEmpty)

  /** Trims the entries and drops null and blank ones; a null list counts as empty. */
  private def normalize(values: List[String]): List[String] =
    Option(values).getOrElse(List.empty).flatMap(Option(_)).map(_.trim).filter(_.nonEmpty)
}
/** A drafted operator configuration returned by [[AIAgentPromptSuggester.suggest]]. */
case class SuggestedPromptConfig(
    systemPrompt: String,
    outputMode: String,
    outputColumn: String,
    structuredOutputFields: List[AIAgentStructuredOutputField]
)

/**
  * "Suggest a prompt" helper: from the input column schema and a one-line goal
  * it asks the model to draft a system prompt plus an output schema, forcing
  * the answer through the `submit_prompt_suggestion` tool. It is not wired to
  * any HTTP endpoint here.
  */
object AIAgentPromptSuggester {

  private val SystemPrompt =
    """You are a prompt designer for a per-row data-pipeline AI Agent.
      |Given a one-line user goal and the input row schema, produce:
      | - a concise systemPrompt instructing the model what to do for each row,
      | - an outputMode of either "text" or "structured",
      | - if outputMode is "text", an outputColumn name,
      | - if outputMode is "structured", one or more structured output fields
      | (each with columnName, fieldType "text" or "classification",
      | instructions, and optional classificationLabels).
      |Call submit_prompt_suggestion exactly once with the final JSON suggestion.""".stripMargin

  /**
    * Requests a suggestion from the model and parses it.
    *
    * @param inputColumns (name, type) pairs describing the input row
    * @return the parsed suggestion, with defaults filled in for missing parts
    */
  def suggest(
      apiKey: String,
      model: String,
      goal: String,
      inputColumns: List[(String, String)],
      timeoutSeconds: Int = 60,
      client: ChatCompletionClient = new OpenRouterClient
  ): SuggestedPromptConfig = {
    val columnsBlock =
      if (inputColumns.nonEmpty)
        inputColumns.map { case (columnName, columnType) => s"- $columnName: $columnType" }
          .mkString("\n")
      else "(no input columns)"
    val userPrompt =
      s"""Goal: ${Option(goal).getOrElse("").trim}
         |
         |Input columns:
         |$columnsBlock""".stripMargin

    val completion = client.completeWithRequiredTool(
      apiKey,
      model,
      SystemPrompt,
      userPrompt,
      0.2,
      timeoutSeconds,
      suggestionTool
    )
    parse(completion.text)
  }

  /** Specification of the tool the model must call with its suggestion. */
  private def suggestionTool: ToolSpecification = {
    val schemaBuilder = JsonObjectSchema
      .builder()
      .description("Final prompt suggestion")
      .addStringProperty("systemPrompt", "System prompt to use for each row")
      .addEnumProperty(
        "outputMode",
        List("text", "structured").asJava,
        "Either text (one column) or structured (one column per extracted field)"
      )
      .addStringProperty("outputColumn", "Output column name when outputMode is text")
      .addStringProperty(
        "structuredOutputFieldsJson",
        "JSON array of {columnName, fieldType, instructions, classificationLabels} objects when outputMode is structured. Use [] for text mode."
      )
    val parameters = schemaBuilder
      .required("systemPrompt", "outputMode", "outputColumn", "structuredOutputFieldsJson")
      .additionalProperties(false)
      .build()
    ToolSpecification
      .builder()
      .name("submit_prompt_suggestion")
      .description("Submit a drafted prompt configuration")
      .parameters(parameters)
      .build()
  }

  /** Converts the tool-call argument JSON into a [[SuggestedPromptConfig]]. */
  private def parse(json: String): SuggestedPromptConfig = {
    val root: JsonNode = objectMapper.readTree(Option(json).getOrElse("{}"))
    val systemPrompt = textOrEmpty(root.get("systemPrompt"))
    // Anything other than "structured" falls back to text mode.
    val outputMode =
      textOrEmpty(root.get("outputMode")).trim.toLowerCase match {
        case "structured" => "structured"
        case _            => "text"
      }
    val outputColumn =
      Some(textOrEmpty(root.get("outputColumn")).trim)
        .filter(_.nonEmpty)
        .getOrElse("ai_agent_response")
    val fields = parseFields(textOrEmpty(root.get("structuredOutputFieldsJson")))
    SuggestedPromptConfig(systemPrompt, outputMode, outputColumn, fields)
  }

  /** Parses the JSON array of field definitions; blank or non-array input yields no fields. */
  private def parseFields(fieldsJson: String): List[AIAgentStructuredOutputField] =
    if (fieldsJson.trim.isEmpty) Nil
    else {
      val parsed = objectMapper.readTree(fieldsJson)
      if (parsed.isArray) parsed.elements().asScala.toList.flatMap(toField)
      else Nil
    }

  /** Builds one field definition; entries without a column name are skipped. */
  private def toField(node: JsonNode): Option[AIAgentStructuredOutputField] = {
    val columnName = textOrEmpty(node.get("columnName")).trim
    if (columnName.isEmpty) None
    else {
      val field = new AIAgentStructuredOutputField
      field.columnName = columnName
      field.fieldType = textOrEmpty(node.get("fieldType")).trim.toLowerCase match {
        case "classification" => "classification"
        case _                => "text"
      }
      field.instructions = textOrEmpty(node.get("instructions"))
      val labelsNode = node.get("classificationLabels")
      field.classificationLabels =
        if (labelsNode != null && labelsNode.isArray)
          labelsNode.elements().asScala.toList.map(_.asText("")).filter(_.nonEmpty)
        else List.empty
      Some(field)
    }
  }

  /** Absent/null -> "", string -> its text, anything else -> its JSON rendering. */
  private def textOrEmpty(node: JsonNode): String =
    if (node == null || node.isNull) ""
    else if (node.isTextual) node.asText()
    else node.toString
}
/**
  * Small in-memory LRU cache from request key to response text.
  *
  * Backed by an access-ordered `LinkedHashMap`; once more than `capacity`
  * entries are stored the least recently used one is evicted. All operations
  * synchronize on the underlying map.
  *
  * @param capacity maximum number of entries kept
  */
class AIAgentResponseCache(capacity: Int = 1000) extends Serializable {
  private val map: java.util.LinkedHashMap[String, String] =
    new java.util.LinkedHashMap[String, String](capacity, 0.75f, true) {
      override def removeEldestEntry(eldest: java.util.Map.Entry[String, String]): Boolean =
        this.size() > capacity
    }

  /** Looks up `key`; a hit also marks the entry as most recently used. */
  def get(key: String): Option[String] =
    map.synchronized {
      Option(map.get(key))
    }

  /** Stores `value` under `key`, evicting the eldest entry when over capacity. */
  def put(key: String, value: String): Unit =
    map.synchronized {
      map.put(key, value)
    }

  /** Current number of cached entries. */
  def size: Int = map.synchronized(map.size())
}

object AIAgentResponseCache {

  /**
    * Derives the cache key (hex SHA-256) for a request.
    *
    * Every component is written as `<length>:<text>` before joining, so the
    * boundaries between components stay unambiguous even when a prompt itself
    * contains the separator character. Null components count as empty.
    */
  def key(
      model: String,
      temperature: Double,
      apiKeySignature: String,
      systemPrompt: String,
      userPrompt: String,
      toolSig: String,
      structSig: String
  ): String = {
    val components = Seq(
      model,
      temperature.toString,
      apiKeySignature,
      systemPrompt,
      userPrompt,
      toolSig,
      structSig
    )
    val canonical = components
      .map(component => Option(component).getOrElse(""))
      .map(text => s"${text.length}:$text")
      .mkString("|")
    sha256(canonical)
  }

  /** Lower-case hex SHA-256 of the UTF-8 bytes of `text` (null counts as empty). */
  def sha256(text: String): String = {
    val md = MessageDigest.getInstance("SHA-256")
    md.update(Option(text).getOrElse("").getBytes(java.nio.charset.StandardCharsets.UTF_8))
    md.digest().map(b => f"${b & 0xff}%02x").mkString
  }
}
/**
 * Contract for one tool the AI Agent may invoke while processing a row.
 *
 * A tool exposes a unique name, the LangChain4j specification advertised to
 * the model, and a synchronous executor that receives the model's argument
 * JSON and answers with a plain string. Failures are expected to be reported
 * through [[AIAgentToolResult.error]] instead of being thrown, so the model
 * can read the message and try something else.
 */
trait AIAgentTool extends AutoCloseable with Serializable {

  /** Name under which the tool is offered to the model. */
  def name: String

  /** LangChain4j description of the tool and its parameters. */
  def specification: ToolSpecification

  /** Runs the tool with the model-supplied argument JSON and returns its textual result. */
  def execute(argumentsJson: String): String

  /** Releases resources held by the tool; a no-op unless overridden. */
  override def close(): Unit = ()
}

/** String-level convention for marking a tool result as a success or a failure. */
object AIAgentToolResult {

  /** Marker that prefixes every failed tool result. */
  final val ErrorPrefix = "[ERROR] "

  /** A successful result is passed through unchanged. */
  def ok(message: String): String = message

  /** Wraps `message` as a failed result by prepending [[ErrorPrefix]]. */
  def error(message: String): String = ErrorPrefix + message

  /** True when `result` is non-null and carries the [[ErrorPrefix]] marker. */
  def isError(result: String): Boolean = Option(result).exists(_.startsWith(ErrorPrefix))
}
/**
 * Guards outbound fetches made on behalf of the model: only `http`/`https`
 * URLs whose host resolves exclusively to public addresses are accepted.
 */
object AIAgentUrlSafety {
  private val AllowedSchemes = Set("http", "https")

  /**
   * Parses and vets `rawUrl`.
   *
   * The URL is rejected when it is blank, uses a scheme other than http(s),
   * has no host, names `localhost` (or a `.localhost` subdomain), or when any
   * address the host resolves to is wildcard, loopback, link-local,
   * site-local, multicast, carrier-grade NAT (100.64.0.0/10) or IPv6
   * unique-local (fc00::/7).
   *
   * @return the parsed URI when every check passes
   * @throws IllegalArgumentException when a check fails or the URL is malformed
   */
  def validatePublicHttpUrl(rawUrl: String): URI = {
    require(rawUrl != null && rawUrl.trim.nonEmpty, "url is required")
    val parsed = URI.create(rawUrl.trim)

    val scheme = Option(parsed.getScheme).fold("")(_.toLowerCase)
    require(AllowedSchemes(scheme), "Only http(s) URLs are allowed")

    val host = Option(parsed.getHost).map(_.trim) match {
      case Some(h) if h.nonEmpty => h
      case _                     => throw new IllegalArgumentException("URL host is required")
    }
    require(!isLocalHostName(host), s"Private or local URL hosts are not allowed: $host")

    // Every resolved address must be public; a single private one rejects the URL.
    val resolved = InetAddress.getAllByName(host)
    require(resolved.nonEmpty, s"Could not resolve URL host: $host")
    require(
      !resolved.exists(isPrivateAddress),
      s"Private or local URL hosts are not allowed: $host"
    )
    parsed
  }

  /** True for `localhost` and any name under `.localhost`, ignoring case and one trailing dot. */
  private def isLocalHostName(host: String): Boolean = {
    val name = host.stripSuffix(".").toLowerCase
    name.endsWith(".localhost") || name == "localhost"
  }

  /** True when the address belongs to any range that must not be reached from the agent. */
  private def isPrivateAddress(address: InetAddress): Boolean = {
    val flaggedByJdk =
      address.isAnyLocalAddress ||
        address.isLoopbackAddress ||
        address.isLinkLocalAddress ||
        address.isSiteLocalAddress ||
        address.isMulticastAddress
    flaggedByJdk || isCarrierGradeNat(address) || isUniqueLocalIpv6(address)
  }

  /** IPv4 100.64.0.0/10: first octet 100, second octet in 64..127. */
  private def isCarrierGradeNat(address: InetAddress): Boolean =
    address match {
      case v4: Inet4Address =>
        val octets = v4.getAddress
        val first = octets(0) & 0xff
        val second = octets(1) & 0xff
        first == 100 && second >= 64 && second <= 127
      case _ => false
    }

  /** IPv6 fc00::/7: leading byte is 0xfc or 0xfd. */
  private def isUniqueLocalIpv6(address: InetAddress): Boolean =
    address match {
      case v6: Inet6Address =>
        val lead = v6.getAddress.head & 0xff
        lead == 0xfc || lead == 0xfd
      case _ => false
    }
}
/** One tool advertised by an MCP server: its name, description and JSON Schema for the arguments. */
case class McpToolInfo(name: String, description: String, inputSchema: JsonNode)

/**
 * Minimal Model Context Protocol client over Streamable HTTP.
 *
 * Supports the three calls needed to register and execute MCP-discovered
 * tools inside the AI Agent loop: `initialize`, `tools/list`, `tools/call`.
 * Bearer auth is sent as `Authorization: Bearer <token>` when a non-blank
 * token is configured. Session continuity is maintained by echoing the
 * `Mcp-Session-Id` response header, and once `initialize` has returned a
 * protocol version it is sent back as `MCP-Protocol-Version` on every later
 * request.
 *
 * Response bodies may arrive as plain `application/json` or as
 * `text/event-stream`; both shapes are parsed transparently.
 *
 * @param serverName      label used in log and error messages
 * @param url             MCP endpoint URL
 * @param bearerToken     optional bearer token
 * @param timeoutSeconds  connect and per-request timeout
 * @param protocolVersion protocol version offered in `initialize`
 */
class MCPClient(
    val serverName: String,
    val url: String,
    val bearerToken: Option[String] = None,
    val timeoutSeconds: Int = 30,
    val protocolVersion: String = "2025-06-18"
) extends AutoCloseable
    with LazyLogging {

  import scala.jdk.CollectionConverters._
  import scala.util.control.NonFatal

  private val httpClient: HttpClient = HttpClient
    .newBuilder()
    .connectTimeout(Duration.ofSeconds(timeoutSeconds.toLong))
    .followRedirects(HttpClient.Redirect.NORMAL)
    .build()

  private val nextId = new AtomicLong(0)
  @volatile private var sessionId: Option[String] = None
  @volatile private var initialized: Boolean = false
  // Version reported by the server in the `initialize` result; None until negotiated.
  @volatile private var negotiatedVersion: Option[String] = None

  /** Performs the `initialize` request followed by the `notifications/initialized` notification. */
  def initialize(): Unit = {
    val params = objectMapper.createObjectNode()
    params.put("protocolVersion", protocolVersion)
    params.set[ObjectNode]("capabilities", objectMapper.createObjectNode())
    val clientInfo = objectMapper.createObjectNode()
    clientInfo.put("name", "texera-aiagent")
    clientInfo.put("version", "1.0.0")
    params.set[ObjectNode]("clientInfo", clientInfo)

    val result = sendRequest("initialize", Some(params))
    // Remember the server's version before the notification so that it already carries the header.
    negotiatedVersion = Some(textField(result, "protocolVersion")).filter(_.nonEmpty)
    sendNotification("notifications/initialized")
    initialized = true
    logger.info(
      s"[MCP] initialized server=$serverName url=$url session=${sessionId.getOrElse("-")}"
    )
  }

  /**
   * Fetches the server's tool list via `tools/list`, initializing first if needed.
   * Entries without a name are dropped; a missing `inputSchema` becomes an empty object.
   */
  def listTools(): List[McpToolInfo] = {
    ensureInitialized()
    val result = sendRequest("tools/list", None)
    Option(result.get("tools")).filter(_.isArray) match {
      case Some(tools) =>
        tools
          .elements()
          .asScala
          .flatMap { tool =>
            val toolName = textField(tool, "name")
            if (toolName.isEmpty) None
            else {
              val schema =
                Option(tool.get("inputSchema")).getOrElse(objectMapper.createObjectNode())
              Some(McpToolInfo(toolName, textField(tool, "description"), schema))
            }
          }
          .toList
      case None => List.empty
    }
  }

  /**
   * Invokes `toolName` via `tools/call` and returns the concatenated `text`
   * content items of the result.
   *
   * `argumentsJson` that is blank, malformed or not a JSON object is sent as an
   * empty object. When the server flags `isError`, the text is returned wrapped
   * by [[AIAgentToolResult.error]].
   *
   * @throws RuntimeException on a non-2xx HTTP status or a JSON-RPC error response
   */
  def callTool(toolName: String, argumentsJson: String): String = {
    ensureInitialized()
    val params = objectMapper.createObjectNode()
    params.put("name", toolName)
    params.set[ObjectNode]("arguments", parseArguments(argumentsJson))

    val result = sendRequest("tools/call", Some(params))
    val isError = Option(result.get("isError")).exists(_.asBoolean(false))
    val text = Option(result.get("content"))
      .filter(_.isArray)
      .map { items =>
        items
          .elements()
          .asScala
          .filter(item => textField(item, "type") == "text")
          .map(item => textField(item, "text"))
          .mkString
      }
      .getOrElse("")
    if (isError) AIAgentToolResult.error(if (text.nonEmpty) text else "MCP tool returned isError")
    else text
  }

  /**
   * Best-effort session teardown: sends an HTTP DELETE when a session id is
   * known, then resets local state. Non-fatal failures of the DELETE are
   * logged at debug level; an interrupt is restored on the thread.
   */
  override def close(): Unit = {
    try {
      if (sessionId.isDefined) {
        try httpDelete()
        catch {
          case _: InterruptedException => Thread.currentThread().interrupt()
          case NonFatal(e) =>
            logger.debug(s"[MCP] session delete failed server=$serverName: ${e.getMessage}")
        }
      }
    } finally {
      initialized = false
      sessionId = None
      negotiatedVersion = None
    }
  }

  private def ensureInitialized(): Unit =
    if (!initialized) initialize()

  /** Reads a string field, yielding "" when the field is absent. */
  private def textField(node: JsonNode, field: String): String =
    Option(node.get(field)).map(_.asText("")).getOrElse("")

  /** Parses tool arguments, substituting an empty object for blank, malformed or non-object input. */
  private def parseArguments(argumentsJson: String): JsonNode = {
    val raw = Option(argumentsJson).map(_.trim).filter(_.nonEmpty).getOrElse("{}")
    val parsed =
      try Option(objectMapper.readTree(raw))
      catch { case NonFatal(_) => None }
    parsed.filter(_.isObject).getOrElse(objectMapper.createObjectNode())
  }

  /** Sends a JSON-RPC notification (no id, no response expected). */
  private def sendNotification(method: String): Unit = {
    val payload = objectMapper.createObjectNode()
    payload.put("jsonrpc", "2.0")
    payload.put("method", method)
    httpPost(payload.toString)
  }

  /**
   * Sends a JSON-RPC request and returns its `result` node (an empty object when absent).
   *
   * @throws RuntimeException when the response carries a non-null `error` member
   */
  private def sendRequest(method: String, params: Option[JsonNode]): JsonNode = {
    val id = nextId.incrementAndGet()
    val payload = objectMapper.createObjectNode()
    payload.put("jsonrpc", "2.0")
    payload.put("id", id)
    payload.put("method", method)
    params.foreach(p => payload.set[ObjectNode]("params", p))

    val root = parseRpcBody(httpPost(payload.toString), Some(id))
    // An explicit `"error": null` is not a failure.
    Option(root.get("error")).filterNot(_.isNull).foreach { err =>
      val code = Option(err.get("code")).map(_.asInt(0)).getOrElse(0)
      val msg = Option(err.get("message")).map(_.asText("")).getOrElse("unknown")
      throw new RuntimeException(s"MCP $serverName.$method error $code: $msg")
    }
    Option(root.get("result")).getOrElse(objectMapper.createObjectNode())
  }

  /** Request builder carrying the URL, timeout, and the auth / session / protocol-version headers. */
  private def newRequest(): HttpRequest.Builder = {
    val builder = HttpRequest
      .newBuilder()
      .uri(URI.create(url))
      .timeout(Duration.ofSeconds(timeoutSeconds.toLong))
    bearerToken
      .map(_.trim)
      .filter(_.nonEmpty)
      .foreach(token => builder.header("Authorization", s"Bearer $token"))
    sessionId.foreach(sid => builder.header("Mcp-Session-Id", sid))
    negotiatedVersion.foreach(version => builder.header("MCP-Protocol-Version", version))
    builder
  }

  /**
   * POSTs a JSON-RPC payload and returns the raw response body, recording any
   * `Mcp-Session-Id` header the server sends back.
   *
   * @throws RuntimeException on a non-2xx status
   */
  private def httpPost(body: String): String = {
    val request = newRequest()
      .header("Content-Type", "application/json")
      .header("Accept", "application/json, text/event-stream")
      .POST(BodyPublishers.ofString(body))
      .build()

    val response: HttpResponse[String] = httpClient.send(request, BodyHandlers.ofString())
    val status = response.statusCode()
    if (status < 200 || status >= 300) {
      throw new RuntimeException(s"MCP $serverName HTTP $status: ${response.body()}")
    }
    Option(response.headers().firstValue("mcp-session-id").orElse(null))
      .filter(_.nonEmpty)
      .foreach(sid => sessionId = Some(sid))
    response.body()
  }

  /** Sends the session-terminating DELETE; the response is discarded. */
  private def httpDelete(): Unit = {
    httpClient.send(newRequest().DELETE().build(), BodyHandlers.discarding())
  }

  /**
   * Extracts the JSON-RPC message from a response body.
   *
   * A body starting with `{` is parsed as-is. Otherwise the body is treated as
   * SSE: every `data:` line holding a JSON object is a candidate, and the one
   * chosen is, in order of preference, the response whose `id` equals
   * `expectedId`, the first message with a `result` or `error` member, or the
   * first candidate. This keeps server notifications that precede the response
   * on the same stream from being mistaken for it.
   */
  private def parseRpcBody(body: String, expectedId: Option[Long] = None): JsonNode = {
    val trimmed = Option(body).getOrElse("").trim
    if (trimmed.isEmpty) objectMapper.createObjectNode()
    else if (trimmed.startsWith("{")) objectMapper.readTree(trimmed)
    else {
      val candidates: List[JsonNode] = trimmed.linesIterator
        .map(_.trim)
        .filter(_.startsWith("data:"))
        .map(_.stripPrefix("data:").trim)
        .filter(_.startsWith("{"))
        .map(line => objectMapper.readTree(line))
        .toList

      def isResponse(node: JsonNode): Boolean = node.has("result") || node.has("error")
      def hasExpectedId(node: JsonNode): Boolean =
        expectedId.exists { id =>
          Option(node.get("id")).exists(n => n.isIntegralNumber && n.asLong() == id)
        }

      candidates
        .find(node => isResponse(node) && hasExpectedId(node))
        .orElse(candidates.find(isResponse))
        .orElse(candidates.headOption)
        .getOrElse(objectMapper.createObjectNode())
    }
  }
}
/**
 * Wraps a tool discovered from an MCP server as an [[AIAgentTool]] so it can
 * be dropped into the per-row tool-execution loop alongside built-in tools
 * like `read_url` / `read_pdf`.
 *
 * Tool names are namespaced by server (e.g. `notion__search`) so multiple
 * servers can expose tools with the same local name without colliding.
 *
 * Argument schemas come from the MCP server as JSON Schema and are converted
 * to LangChain4j's [[JsonObjectSchema]] so the model receives a well-typed
 * parameter spec.
 *
 * NOTE(review): [[AIAgentTool]] is `Serializable`, but this adapter holds an
 * [[MCPClient]] that wraps an HTTP client — confirm instances are never
 * actually serialized.
 *
 * @param client   connection to the MCP server that owns the tool
 * @param toolInfo the tool's name, description and input schema as listed by the server
 */
class MCPToolAdapter(
    val client: MCPClient,
    val toolInfo: McpToolInfo
) extends AIAgentTool {

  override val name: String = MCPToolAdapter.namespacedName(client.serverName, toolInfo.name)

  override val specification: ToolSpecification = ToolSpecification
    .builder()
    .name(name)
    .description(
      if (toolInfo.description.nonEmpty) toolInfo.description
      else s"MCP tool ${toolInfo.name} on server ${client.serverName}"
    )
    .parameters(MCPToolAdapter.toJsonObjectSchema(toolInfo.inputSchema))
    .build()

  /**
   * Calls the remote tool under its original (un-namespaced) name.
   *
   * Non-fatal failures are returned as an [[AIAgentToolResult.error]] string so
   * the model can react; fatal errors and thread interruption propagate so
   * that cancellation is not turned into a tool message.
   */
  override def execute(argumentsJson: String): String =
    try AIAgentToolResult.ok(client.callTool(toolInfo.name, argumentsJson))
    catch {
      case scala.util.control.NonFatal(t) =>
        AIAgentToolResult.error(
          s"${t.getClass.getSimpleName}: ${Option(t.getMessage).getOrElse("")}"
        )
    }

  /** Closes the underlying MCP client. */
  override def close(): Unit = client.close()
}

object MCPToolAdapter {

  /**
   * Joins the sanitized server and tool names with a double underscore; when
   * the server name sanitizes to nothing only the tool name is returned.
   */
  def namespacedName(serverName: String, toolName: String): String = {
    val safeServer = sanitize(serverName)
    val safeTool = sanitize(toolName)
    if (safeServer.isEmpty) safeTool else s"${safeServer}__${safeTool}"
  }

  /** Trims the input (null becomes "") and replaces every character outside `[A-Za-z0-9_]` with `_`. */
  private def sanitize(s: String): String =
    Option(s).getOrElse("").trim.replaceAll("[^A-Za-z0-9_]", "_")

  /**
   * Convert an MCP-provided JSON Schema object into a LangChain4j
   * [[JsonObjectSchema]]. Supports top-level primitive properties, nested
   * objects, arrays, and `required[]`. Unknown / unsupported types degrade to
   * string so the model can still send something rather than failing.
   *
   * The resulting schema always has `additionalProperties = false`; a null or
   * non-object input yields an object schema with no properties.
   */
  def toJsonObjectSchema(schema: JsonNode): JsonObjectSchema = {
    val builder = JsonObjectSchema.builder()
    if (schema != null && schema.isObject) {
      describe(schema).foreach(builder.description)

      Option(schema.get("properties")).filter(_.isObject).foreach { props =>
        props.fields().asScala.foreach { entry =>
          builder.addProperty(entry.getKey, toJsonSchemaElement(entry.getValue))
        }
      }

      Option(schema.get("required")).filter(_.isArray).foreach { req =>
        val names = req.elements().asScala.map(_.asText("")).filter(_.nonEmpty).toList
        if (names.nonEmpty) builder.required(names.asJava)
      }
    }
    builder.additionalProperties(false).build()
  }

  /** The schema's non-empty `description`, if it has one. */
  private def describe(schema: JsonNode): Option[String] =
    Option(schema.get("description")).map(_.asText("")).filter(_.nonEmpty)

  /**
   * Converts one property schema. The `type` keyword selects the element kind;
   * a missing, non-textual or unrecognized type, as well as a null or
   * non-object schema, falls back to a string schema. Arrays without `items`
   * are given string items.
   */
  private def toJsonSchemaElement(schema: JsonNode): JsonSchemaElement =
    if (schema == null || !schema.isObject) JsonStringSchema.builder().build()
    else {
      val description = describe(schema)
      val typeStr = Option(schema.get("type")).map(_.asText("")).getOrElse("string")
      typeStr match {
        case "integer" =>
          val b = JsonIntegerSchema.builder()
          description.foreach(b.description)
          b.build()
        case "number" =>
          val b = JsonNumberSchema.builder()
          description.foreach(b.description)
          b.build()
        case "boolean" =>
          val b = JsonBooleanSchema.builder()
          description.foreach(b.description)
          b.build()
        case "array" =>
          val b = JsonArraySchema.builder()
          description.foreach(b.description)
          val itemSchema = Option(schema.get("items"))
            .map(toJsonSchemaElement)
            .getOrElse(JsonStringSchema.builder().build())
          b.items(itemSchema)
          b.build()
        case "object" =>
          toJsonObjectSchema(schema)
        case _ =>
          // "string" and every unsupported type share the string fallback.
          val b = JsonStringSchema.builder()
          description.foreach(b.description)
          b.build()
      }
    }
}
/** Outcome of one completion: the returned text (or tool-argument JSON) and its estimated cost in USD. */
case class ChatCompletionResult(text: String, usdCost: Double)

/** Abstraction over the chat-completion backend used by the AI Agent operator. */
trait ChatCompletionClient extends Serializable {

  /** Single-turn completion without tools; returns the model's text. */
  def complete(
      apiKey: String,
      model: String,
      systemPrompt: String,
      userPrompt: String,
      temperature: Double,
      timeoutSeconds: Int
  ): ChatCompletionResult

  /** Single-turn completion in which the model is required to call `toolSpecification`. */
  def completeWithRequiredTool(
      apiKey: String,
      model: String,
      systemPrompt: String,
      userPrompt: String,
      temperature: Double,
      timeoutSeconds: Int,
      toolSpecification: ToolSpecification
  ): ChatCompletionResult

  /**
   * Multi-turn tool-execution loop for the per-row AI Agent.
   *
   * Mirrors the loop shape used by the workflow-edit assistant in
   * `agent-service`, simplified for the per-row data-in/data-out case: no
   * persistent conversation, no streaming. In [[OpenRouterClient]] the loop
   * terminates when the model calls `finalAnswerTool` (returning its arguments
   * JSON), when a turn contains no tool call at all (the turn's text is
   * returned verbatim), or when `maxIterations` is exhausted (the last turn is
   * restricted to `finalAnswerTool` with a required tool choice if one is
   * supplied).
   *
   * @param maxRowCostUsd optional cap on the accumulated cost of all turns
   */
  def completeWithTools(
      apiKey: String,
      model: String,
      systemPrompt: String,
      userPrompt: String,
      temperature: Double,
      timeoutSeconds: Int,
      tools: List[AIAgentTool],
      finalAnswerTool: Option[ToolSpecification],
      maxIterations: Int,
      maxRowCostUsd: Option[Double] = None
  ): ChatCompletionResult
}

/** [[ChatCompletionClient]] backed by OpenRouter's OpenAI-compatible API through LangChain4j. */
class OpenRouterClient extends ChatCompletionClient with LazyLogging {

  import scala.util.control.NonFatal

  /**
   * @throws IllegalArgumentException when the API key or model is blank or the timeout is not positive
   */
  override def complete(
      apiKey: String,
      model: String,
      systemPrompt: String,
      userPrompt: String,
      temperature: Double,
      timeoutSeconds: Int
  ): ChatCompletionResult = {
    // createChatModel validates apiKey, model and timeoutSeconds.
    val chatModel = createChatModel(apiKey, model, temperature, timeoutSeconds)

    val response = chatModel.chat(buildMessages(systemPrompt, userPrompt).toList.asJava)
    val text = Option(response.aiMessage().text()).getOrElse("")
    val cost = OpenRouterPricing.costFor(model, Option(response.tokenUsage()))
    ChatCompletionResult(text, cost)
  }

  /** Implemented as a one-iteration tool loop whose only tool is the required one. */
  override def completeWithRequiredTool(
      apiKey: String,
      model: String,
      systemPrompt: String,
      userPrompt: String,
      temperature: Double,
      timeoutSeconds: Int,
      toolSpecification: ToolSpecification
  ): ChatCompletionResult =
    completeWithTools(
      apiKey,
      model,
      systemPrompt,
      userPrompt,
      temperature,
      timeoutSeconds,
      tools = List.empty,
      finalAnswerTool = Some(toolSpecification),
      maxIterations = 1
    )

  /**
   * Runs up to `maxIterations` chat turns, executing requested tools between turns.
   *
   * Tool failures that are non-fatal, and requests for unknown tools, are fed
   * back to the model as [[AIAgentToolResult.error]] messages. Fatal errors and
   * thread interruption raised by a tool propagate to the caller.
   *
   * @throws IllegalArgumentException when `maxIterations` is not positive or the model settings are invalid
   * @throws RuntimeException         when the accumulated cost exceeds `maxRowCostUsd`
   */
  override def completeWithTools(
      apiKey: String,
      model: String,
      systemPrompt: String,
      userPrompt: String,
      temperature: Double,
      timeoutSeconds: Int,
      tools: List[AIAgentTool],
      finalAnswerTool: Option[ToolSpecification],
      maxIterations: Int,
      maxRowCostUsd: Option[Double] = None
  ): ChatCompletionResult = {
    require(maxIterations > 0, "maxIterations must be positive")
    val chatModel = createChatModel(apiKey, model, temperature, timeoutSeconds)
    val toolsByName: Map[String, AIAgentTool] = tools.map(t => t.name -> t).toMap
    val allSpecs: List[ToolSpecification] = tools.map(_.specification) ++ finalAnswerTool.toList

    val messages = scala.collection.mutable.ListBuffer.empty[ChatMessage]
    messages ++= buildMessages(systemPrompt, userPrompt)

    logger.info(
      s"[AIAgent] start model=$model tools=${tools.map(_.name).mkString(",")} " +
        s"finalAnswer=${finalAnswerTool.map(_.name()).getOrElse("none")} maxIter=$maxIterations"
    )

    var accumulatedCost: Double = 0.0
    var lastText: String = ""
    var iteration = 0
    // Set as soon as a turn produces the final result; replaces returning from inside a lambda.
    var outcome: Option[ChatCompletionResult] = None
    while (outcome.isEmpty && iteration < maxIterations) {
      iteration += 1
      val isLastIteration = iteration == maxIterations
      val turnStart = System.currentTimeMillis()
      val (aiMessage, turnCost) = sendChatTurnWithCost(
        chatModel,
        model,
        messages.toList,
        allSpecs,
        forceFinalAnswer = isLastIteration && finalAnswerTool.isDefined,
        finalAnswerTool
      )
      accumulatedCost += turnCost
      val turnMs = System.currentTimeMillis() - turnStart
      lastText = Option(aiMessage.text()).getOrElse("")

      val requests = Option(aiMessage.toolExecutionRequests())
        .map(_.asScala.toList)
        .getOrElse(List.empty)

      logger.info(
        s"[AIAgent] turn=$iteration/$maxIterations chatMs=$turnMs " +
          s"toolCalls=${requests.map(_.name()).mkString(",")} " +
          s"textLen=${lastText.length} costUsd=$accumulatedCost"
      )

      maxRowCostUsd.foreach { cap =>
        if (accumulatedCost > cap) {
          throw new RuntimeException(
            f"Row cost cap exceeded: $$$accumulatedCost%.6f > $$$cap%.6f"
          )
        }
      }

      val finalAnswerRequest =
        finalAnswerTool.flatMap(spec => requests.find(_.name() == spec.name()))

      finalAnswerRequest match {
        case _ if requests.isEmpty =>
          logger.info(s"[AIAgent] done turn=$iteration reason=plainText")
          outcome = Some(ChatCompletionResult(lastText, accumulatedCost))

        case Some(req) =>
          logger.info(s"[AIAgent] done turn=$iteration reason=finalAnswerTool")
          outcome = Some(ChatCompletionResult(req.arguments(), accumulatedCost))

        case None =>
          // Ordinary tool calls: run each one and append its result for the next turn.
          messages += aiMessage
          requests.foreach { request =>
            val toolStart = System.currentTimeMillis()
            val result = toolsByName.get(request.name()) match {
              case Some(tool) =>
                try tool.execute(Option(request.arguments()).getOrElse("{}"))
                catch {
                  case NonFatal(t) =>
                    AIAgentToolResult.error(
                      s"Tool ${request.name()} threw ${t.getClass.getSimpleName}: " +
                        s"${Option(t.getMessage).getOrElse("")}"
                    )
                }
              case None =>
                AIAgentToolResult.error(s"Unknown tool: ${request.name()}")
            }
            val toolMs = System.currentTimeMillis() - toolStart
            val isErr = AIAgentToolResult.isError(result)
            val argsSnippet =
              Option(request.arguments()).getOrElse("").replaceAll("\\s+", " ").take(300)
            val errSnippet = if (isErr) s" args=$argsSnippet error=${result.take(400)}" else ""
            logger.info(
              s"[AIAgent] tool=${request.name()} ms=$toolMs " +
                s"isError=$isErr resultLen=${result.length}$errSnippet"
            )
            messages += ToolExecutionResultMessage.from(request, result)
          }
      }
    }

    outcome.getOrElse {
      logger.warn(
        s"[AIAgent] done turn=$iteration reason=maxIterations textLen=${lastText.length}"
      )
      ChatCompletionResult(lastText, accumulatedCost)
    }
  }

  /**
   * Sends one chat turn and returns the model's message with the turn's cost.
   *
   * When `forceFinalAnswer` is set and a final-answer tool exists, only that
   * tool is offered and a tool call is required. Otherwise all specifications
   * are offered, and a tool call is required only when the final-answer tool
   * is the single tool available.
   */
  private def sendChatTurnWithCost(
      chatModel: OpenAiChatModel,
      model: String,
      messages: List[ChatMessage],
      toolSpecifications: List[ToolSpecification],
      forceFinalAnswer: Boolean,
      finalAnswerTool: Option[ToolSpecification]
  ): (AiMessage, Double) = {
    val builder = ChatRequest.builder().messages(messages.asJava)
    if (forceFinalAnswer && finalAnswerTool.isDefined) {
      builder
        .toolSpecifications(finalAnswerTool.get)
        .toolChoice(ToolChoice.REQUIRED)
    } else if (toolSpecifications.nonEmpty) {
      builder.toolSpecifications(toolSpecifications.asJava)
      if (finalAnswerTool.isDefined && toolSpecifications.size == 1) {
        builder.toolChoice(ToolChoice.REQUIRED)
      }
    }
    val resp = chatModel.chat(builder.build())
    val cost = OpenRouterPricing.costFor(model, Option(resp.tokenUsage()))
    (resp.aiMessage(), cost)
  }

  /**
   * Builds the LangChain4j model pointed at the OpenRouter base URL.
   *
   * @throws IllegalArgumentException when the API key or model is blank or the timeout is not positive
   */
  private def createChatModel(
      apiKey: String,
      model: String,
      temperature: Double,
      timeoutSeconds: Int
  ): OpenAiChatModel = {
    require(apiKey != null && apiKey.trim.nonEmpty, "OpenRouter API key is required")
    require(model != null && model.trim.nonEmpty, "OpenRouter model is required")
    require(timeoutSeconds > 0, "Timeout seconds must be positive")

    OpenAiChatModel
      .builder()
      .baseUrl(OpenRouterClient.OpenRouterBaseUrl)
      .apiKey(apiKey.trim)
      .modelName(model)
      .temperature(temperature)
      .timeout(Duration.ofSeconds(timeoutSeconds.toLong))
      .build()
  }

  /** Initial conversation: the system prompt (skipped when null or empty) followed by the user prompt. */
  private def buildMessages(
      systemPrompt: String,
      userPrompt: String
  ): scala.collection.mutable.ListBuffer[dev.langchain4j.data.message.ChatMessage] = {
    val messages =
      scala.collection.mutable.ListBuffer.empty[dev.langchain4j.data.message.ChatMessage]
    if (systemPrompt != null && systemPrompt.nonEmpty) {
      messages += SystemMessage.from(systemPrompt)
    }
    messages += UserMessage.from(Option(userPrompt).getOrElse(""))
    messages
  }
}

/** Token-price lookup used to estimate the USD cost of a completion. */
object OpenRouterPricing extends com.typesafe.scalalogging.LazyLogging {
  // USD per 1M tokens (prompt, completion). Fallback used when the OpenRouter
  // models API is unreachable.
  private val fallbackTable: Map[String, (Double, Double)] = Map(
    "openai/gpt-4o-mini" -> (0.15, 0.60),
    "openai/gpt-4o" -> (2.50, 10.00),
    "openai/gpt-5" -> (1.25, 10.00),
    "anthropic/claude-3.5-sonnet" -> (3.00, 15.00),
    "anthropic/claude-3.5-haiku" -> (0.80, 4.00),
    "google/gemini-2.0-flash-001" -> (0.10, 0.40),
    "meta-llama/llama-3.3-70b-instruct" -> (0.12, 0.30)
  )

  // Per-token (not per-1M) pricing fetched from OpenRouter. Lazy + cached for
  // the JVM lifetime; one network hit per worker on first use.
  @volatile private var remoteTable: Option[Map[String, (Double, Double)]] = None
  private val lock = new Object

  /**
   * Downloads the OpenRouter model list and extracts per-token prices keyed by
   * lowercase model id.
   *
   * @throws RuntimeException on a non-2xx status or a response without a `data` array
   */
  private def loadRemoteTable(): Map[String, (Double, Double)] = {
    val client = java.net.http.HttpClient
      .newBuilder()
      .connectTimeout(java.time.Duration.ofSeconds(5))
      .build()
    val req = java.net.http.HttpRequest
      .newBuilder(java.net.URI.create("https://openrouter.ai/api/v1/models"))
      .timeout(java.time.Duration.ofSeconds(10))
      .GET()
      .build()
    val resp = client.send(req, java.net.http.HttpResponse.BodyHandlers.ofString())
    if (resp.statusCode() / 100 != 2) {
      throw new RuntimeException(s"OpenRouter /models HTTP ${resp.statusCode()}")
    }
    val root = objectMapper.readTree(resp.body())
    val data = root.get("data")
    if (data == null || !data.isArray) {
      throw new RuntimeException("OpenRouter /models response missing data[]")
    }
    data
      .elements()
      .asScala
      .flatMap { node =>
        val id = Option(node.get("id")).map(_.asText("")).getOrElse("")
        Option(node.get("pricing"))
          .filter(_ => id.nonEmpty)
          .map(pricing => (id.toLowerCase, (price(pricing, "prompt"), price(pricing, "completion"))))
      }
      .toMap
  }

  /**
   * Reads one per-token price. Values that are missing, unparsable, negative
   * or non-finite count as 0.0, so a sentinel or malformed price can never
   * produce a negative or NaN cost that would defeat a cost cap.
   */
  private def price(pricing: JsonNode, field: String): Double =
    Option(pricing.get(field))
      .flatMap(n => parseDouble(n.asText("")))
      .filter(p => java.lang.Double.isFinite(p) && p >= 0.0)
      .getOrElse(0.0)

  private def parseDouble(s: String): Option[Double] =
    scala.util.Try(s.toDouble).toOption

  /**
   * Returns the remote price table, fetching it on first use. A non-fatal
   * fetch failure is logged and cached as an empty table, so only the fallback
   * table is consulted from then on.
   */
  private def pricingTable(): Map[String, (Double, Double)] =
    remoteTable.getOrElse {
      lock.synchronized {
        // Re-check under the lock: another thread may have loaded the table meanwhile.
        remoteTable.getOrElse {
          val loaded =
            try {
              val table = loadRemoteTable()
              logger.info(s"[AIAgent] loaded OpenRouter pricing for ${table.size} models")
              table
            } catch {
              case scala.util.control.NonFatal(t) =>
                logger.warn(
                  s"[AIAgent] OpenRouter pricing fetch failed: ${t.getMessage}; using fallback table"
                )
                Map.empty[String, (Double, Double)]
            }
          remoteTable = Some(loaded)
          loaded
        }
      }
    }

  /** Per-token (prompt, completion) price for `model`; unknown models cost (0.0, 0.0). */
  private def perTokenPricing(model: String): (Double, Double) = {
    val key = Option(model).map(_.trim.toLowerCase).getOrElse("")
    pricingTable().get(key) match {
      case Some(p) => p
      case None =>
        // The fallback table is per 1M tokens; convert to per-token.
        val (pIn, pOut) = fallbackTable.getOrElse(key, (0.0, 0.0))
        (pIn / 1000000.0, pOut / 1000000.0)
    }
  }

  /**
   * Estimated USD cost of one response: input tokens times the prompt price
   * plus output tokens times the completion price. Returns 0.0 when no usage
   * was reported; missing token counts are treated as 0.
   */
  def costFor(model: String, usage: Option[dev.langchain4j.model.output.TokenUsage]): Double =
    usage match {
      case Some(u) =>
        val (pIn, pOut) = perTokenPricing(model)
        val inTok = Option(u.inputTokenCount()).map(_.intValue).getOrElse(0)
        val outTok = Option(u.outputTokenCount()).map(_.intValue).getOrElse(0)
        inTok * pIn + outTok * pOut
      case None => 0.0
    }
}

object OpenRouterClient {
  final val OpenRouterBaseUrl = "https://openrouter.ai/api/v1"
  final val OpenRouterChatCompletionsUrl = "https://openrouter.ai/api/v1/chat/completions"

  /**
   * Extracts `choices[0].message.content` from a raw chat-completions response body.
   *
   * @throws RuntimeException when that path is absent or the content is not a string
   */
  def parseChatCompletionContent(responseBody: String): String = {
    val root = objectMapper.readTree(responseBody)
    extractFirstChoiceContent(root).getOrElse {
      throw new RuntimeException("OpenRouter response does not contain choices[0].message.content")
    }
  }

  /** The textual content of the first choice, or None when any step of the path is missing. */
  private def extractFirstChoiceContent(root: JsonNode): Option[String] =
    for {
      choices <- Option(root.get("choices"))
      if choices.isArray && choices.size() > 0
      choice <- Option(choices.get(0))
      message <- Option(choice.get("message"))
      content <- Option(message.get("content"))
      if content.isTextual
    } yield content.asText()
}
/**
 * Agent tool that exposes [[PdfReader.readText]] to the model under the name
 * [[PdfReadTool.Name]].
 *
 * @param maxChars character cap forwarded to `PdfReader.readText`
 */
class PdfReadTool(maxChars: Int) extends AIAgentTool {
  override val name: String = PdfReadTool.Name

  override val specification: ToolSpecification = ToolSpecification
    .builder()
    .name(name)
    .description(
      "Read text from a PDF document at a public http(s) URL. Optionally restrict to a page range (1-based, inclusive). Returns extracted text."
    )
    .parameters(
      JsonObjectSchema
        .builder()
        .addStringProperty("source", "Public http(s) URL of the PDF")
        .addProperty(
          "startPage",
          JsonIntegerSchema
            .builder()
            .description("First page to read (1-based, inclusive). Omit for page 1.")
            .build()
        )
        .addProperty(
          "endPage",
          JsonIntegerSchema
            .builder()
            .description("Last page to read (1-based, inclusive). Omit for the last page.")
            .build()
        )
        .required("source")
        .additionalProperties(false)
        .build()
    )
    .build()

  /**
   * Run the tool with the JSON arguments supplied by the model.
   *
   * A null `argumentsJson` is treated as `{}`. `startPage` / `endPage` are only
   * honoured when they are JSON numbers. Any non-fatal failure (bad JSON, unsafe
   * URL, HTTP or PDF error) is reported as an error result rather than thrown;
   * fatal JVM errors and interrupts propagate.
   *
   * @return an `AIAgentToolResult.ok` payload with the extracted text, or an
   *         `AIAgentToolResult.error` payload describing the failure
   */
  override def execute(argumentsJson: String): String = {
    try {
      val args: JsonNode = objectMapper.readTree(Option(argumentsJson).getOrElse("{}"))
      val source = Option(args.get("source")).map(_.asText("")).getOrElse("")
      if (source.trim.isEmpty) {
        AIAgentToolResult.error("Missing required argument: source")
      } else {
        val startPage = Option(args.get("startPage")).filter(_.isNumber).map(_.asInt())
        val endPage = Option(args.get("endPage")).filter(_.isNumber).map(_.asInt())
        AIAgentToolResult.ok(PdfReader.readText(source, startPage, endPage, maxChars))
      }
    } catch {
      // NonFatal keeps OutOfMemoryError, InterruptedException, etc. from being
      // disguised as an ordinary tool error.
      case scala.util.control.NonFatal(t) =>
        AIAgentToolResult.error(
          s"${t.getClass.getSimpleName}: ${Option(t.getMessage).getOrElse("")}"
        )
    }
  }
}

object PdfReadTool {
  /** Tool name advertised to the model. */
  final val Name = "read_pdf"
}
/**
 * Read a PDF from a public HTTP(S) URL and return its text content
 * for an LLM to consume.
 *
 * Self-contained on the JVM via Apache PDFBox 2.x — no external service.
 * Supports optional page-range filtering (1-based inclusive) and a hard
 * character cap with truncation marker, since real-world PDFs are routinely
 * larger than the model's context window.
 */
object PdfReader {
  final val DefaultMaxChars = 100000
  final val DefaultMaxBytes = 32 * 1024 * 1024
  final val DefaultTimeoutMs = 30000
  final val TruncationMarker = "\n\n[... truncated ...]"
  // Highest redirect depth that is still followed.
  private final val MaxRedirects = 5
  private final val UserAgent =
    "Mozilla/5.0 (compatible; TexeraAIAgent/1.0; +https://texera.io)"

  /**
   * Download the PDF at `source` and extract its text.
   *
   * @param source    URL of the PDF; must be non-blank and pass `AIAgentUrlSafety`
   * @param startPage first page (1-based); values below 1 are raised to 1, default 1
   * @param endPage   last page (1-based); values above the page count are lowered to it,
   *                  default the last page
   * @param maxChars  character cap; non-positive values fall back to [[DefaultMaxChars]]
   * @param maxBytes  download size cap in bytes
   * @param timeoutMs connect and read timeout in milliseconds
   * @return the trimmed text, cut to the cap and suffixed with [[TruncationMarker]] if longer
   * @throws IllegalArgumentException for a blank source, too many redirects, a redirect
   *                                  without Location, a non-2xx status or an oversized body
   */
  def readText(
      source: String,
      startPage: Option[Int] = None,
      endPage: Option[Int] = None,
      maxChars: Int = DefaultMaxChars,
      maxBytes: Int = DefaultMaxBytes,
      timeoutMs: Int = DefaultTimeoutMs
  ): String = {
    require(source != null && source.trim.nonEmpty, "source is required")
    val bytes = loadBytes(source.trim, maxBytes, timeoutMs)
    val effectiveMax = if (maxChars <= 0) DefaultMaxChars else maxChars
    val document = PDDocument.load(bytes)
    try {
      val stripper = new PDFTextStripper()
      val totalPages = document.getNumberOfPages
      stripper.setStartPage(startPage.map(_.max(1)).getOrElse(1))
      stripper.setEndPage(endPage.map(_.min(totalPages)).getOrElse(totalPages))
      val text = Option(stripper.getText(document)).getOrElse("").trim
      truncate(text, effectiveMax)
    } finally {
      document.close()
    }
  }

  // Validates the initial URL, then downloads it while following redirects manually.
  private def loadBytes(source: String, maxBytes: Int, timeoutMs: Int): Array[Byte] = {
    val uri = AIAgentUrlSafety.validatePublicHttpUrl(source)
    downloadBytes(uri.toString, maxBytes, timeoutMs, redirectCount = 0)
  }

  /**
   * GET `url` and return the body, following redirects by hand so that every
   * hop is re-validated by `AIAgentUrlSafety` before it is contacted.
   * The connection is always disconnected before the next hop is opened.
   */
  private def downloadBytes(
      url: String,
      maxBytes: Int,
      timeoutMs: Int,
      redirectCount: Int
  ): Array[Byte] = {
    require(redirectCount <= MaxRedirects, s"Too many redirects fetching $url")
    // The URL was validated as http(s), so the connection is an HttpURLConnection.
    val connection = new URL(url).openConnection().asInstanceOf[HttpURLConnection]
    // Left = redirect target still to be fetched, Right = downloaded body.
    val outcome: Either[String, Array[Byte]] =
      try {
        connection.setRequestProperty("User-Agent", UserAgent)
        connection.setConnectTimeout(timeoutMs)
        connection.setReadTimeout(timeoutMs)
        connection.setInstanceFollowRedirects(false)
        val responseCode = connection.getResponseCode
        if (responseCode >= 300 && responseCode < 400) {
          val location = connection.getHeaderField("Location")
          require(
            location != null && location.trim.nonEmpty,
            s"Redirect missing Location fetching $url"
          )
          // Resolve relative Location values against the current URL.
          val redirected = new URL(new URL(url), location.trim).toString
          AIAgentUrlSafety.validatePublicHttpUrl(redirected)
          Left(redirected)
        } else {
          require(responseCode >= 200 && responseCode < 300, s"HTTP $responseCode fetching $url")
          val stream: InputStream = connection.getInputStream
          try Right(readWithCap(stream, maxBytes))
          finally stream.close()
        }
      } finally {
        connection.disconnect()
      }
    outcome match {
      case Left(next)  => downloadBytes(next, maxBytes, timeoutMs, redirectCount + 1)
      case Right(body) => body
    }
  }

  // Reads the whole stream, failing as soon as more than `maxBytes` bytes were received.
  private def readWithCap(stream: InputStream, maxBytes: Int): Array[Byte] = {
    val buffer = new ByteArrayOutputStream()
    val chunk = new Array[Byte](8192)
    var total = 0
    var read = stream.read(chunk)
    while (read != -1) {
      total += read
      require(total <= maxBytes, s"PDF body exceeds maxBytes ($maxBytes)")
      buffer.write(chunk, 0, read)
      read = stream.read(chunk)
    }
    buffer.toByteArray
  }

  /**
   * Return `text` unchanged when it fits in `maxChars`; otherwise cut it,
   * drop trailing whitespace and append [[TruncationMarker]].
   * The cut moves back one char when it would fall between the two halves of
   * a surrogate pair, so the result never ends in a lone high surrogate.
   */
  private[aiagent] def truncate(text: String, maxChars: Int): String =
    if (text.length <= maxChars) text
    else {
      val cut =
        if (maxChars > 0 && Character.isHighSurrogate(text.charAt(maxChars - 1))) maxChars - 1
        else maxChars
      text.substring(0, cut).stripTrailing() + TruncationMarker
    }
}
/** Renders `{{ column }}` placeholders in a prompt template from a tuple's fields. */
object PromptTemplate {
  // `{{ name }}` with optional inner whitespace; a name starts with a letter or
  // underscore and may continue with letters, digits, `_`, `-` or `.`.
  private val PlaceholderPattern: Pattern =
    Pattern.compile("\\{\\{\\s*([A-Za-z_][A-Za-z0-9_\\-.]*)\\s*\\}\\}")

  /**
   * Substitute every placeholder in `template` with the matching field of `tuple`.
   *
   * @param template prompt text; null yields ""
   * @param tuple    row supplying the values; a null field renders as ""
   * @return the template with all placeholders replaced by the field's `toString`
   * @throws IllegalArgumentException when a placeholder names a column missing from the schema
   */
  def render(template: String, tuple: Tuple): String =
    Option(template).fold("") { text =>
      val substitute: java.util.function.Function[java.util.regex.MatchResult, String] =
        placeholder => Matcher.quoteReplacement(fieldText(tuple, placeholder.group(1)))
      PlaceholderPattern.matcher(text).replaceAll(substitute)
    }

  // Looks up one column, rejecting names the tuple's schema does not contain.
  private def fieldText(tuple: Tuple, attributeName: String): String = {
    if (!tuple.getSchema.containsAttribute(attributeName)) {
      throw new IllegalArgumentException(
        s"Prompt template references missing column: $attributeName"
      )
    }
    Option(tuple.getField[Any](attributeName)).map(_.toString).getOrElse("")
  }
}
/**
 * Agent tool that exposes [[UrlFetcher.fetchAsMarkdown]] to the model under the
 * name [[UrlFetchTool.Name]].
 *
 * @param maxChars character cap forwarded to `UrlFetcher.fetchAsMarkdown`
 */
class UrlFetchTool(maxChars: Int) extends AIAgentTool {
  override val name: String = UrlFetchTool.Name

  override val specification: ToolSpecification = ToolSpecification
    .builder()
    .name(name)
    .description(
      "Fetch a web page and return its main content as Markdown. Use when you need to read a URL that the user provided or referenced. Returns clean article text without nav/ads/footers."
    )
    .parameters(
      JsonObjectSchema
        .builder()
        .addStringProperty("url", "Absolute http(s) URL of the page to fetch")
        .required("url")
        .additionalProperties(false)
        .build()
    )
    .build()

  /**
   * Run the tool with the JSON arguments supplied by the model.
   *
   * A null `argumentsJson` is treated as `{}`. Any non-fatal failure (bad JSON,
   * unsafe URL, HTTP error) is reported as an error result rather than thrown;
   * fatal JVM errors and interrupts propagate.
   *
   * @return an `AIAgentToolResult.ok` payload with the page as Markdown, or an
   *         `AIAgentToolResult.error` payload describing the failure
   */
  override def execute(argumentsJson: String): String = {
    try {
      val args: JsonNode = objectMapper.readTree(Option(argumentsJson).getOrElse("{}"))
      val url = Option(args.get("url")).map(_.asText("")).getOrElse("")
      if (url.trim.isEmpty) {
        AIAgentToolResult.error("Missing required argument: url")
      } else {
        AIAgentToolResult.ok(UrlFetcher.fetchAsMarkdown(url, maxChars))
      }
    } catch {
      // NonFatal keeps OutOfMemoryError, InterruptedException, etc. from being
      // disguised as an ordinary tool error.
      case scala.util.control.NonFatal(t) =>
        AIAgentToolResult.error(
          s"${t.getClass.getSimpleName}: ${Option(t.getMessage).getOrElse("")}"
        )
    }
  }
}

object UrlFetchTool {
  /** Tool name advertised to the model. */
  final val Name = "read_url"
}
/**
 * Fetch a URL and return a token-efficient Markdown representation of the
 * main content, suitable to hand to an LLM.
 *
 * Pipeline:
 *   1. HTTP GET via jsoup (timeout + size cap, browser-like User-Agent)
 *   2. Readability4J to strip nav/footer/sidebar/ads and keep the article body
 *   3. flexmark-html2md to convert the cleaned HTML to Markdown
 *   4. Truncate to `maxChars` with a clear marker
 *
 * Self-contained on the JVM — no third-party service, no rate limits.
 */
object UrlFetcher extends LazyLogging {
  final val DefaultMaxChars = 50000
  final val DefaultTimeoutMs = 15000
  final val DefaultMaxBodyBytes = 5 * 1024 * 1024
  final val TruncationMarker = "\n\n[... truncated ...]"
  private final val MaxRedirects = 5
  private final val UserAgent =
    "Mozilla/5.0 (compatible; TexeraAIAgent/1.0; +https://texera.io)"

  private lazy val htmlToMd: FlexmarkHtmlConverter =
    FlexmarkHtmlConverter.builder().build()

  /**
   * Fetch `url` and return its main content as Markdown.
   *
   * @param url          page address; `https://` is prepended when no http(s) scheme is present
   * @param maxChars     character cap; non-positive values fall back to [[DefaultMaxChars]]
   * @param timeoutMs    jsoup request timeout in milliseconds
   * @param maxBodyBytes jsoup body size cap in bytes
   * @return Markdown, cut to the cap and suffixed with [[TruncationMarker]] if longer
   * @throws IllegalArgumentException for a blank url, too many redirects, a redirect
   *                                  without Location or a non-2xx status
   */
  def fetchAsMarkdown(
      url: String,
      maxChars: Int = DefaultMaxChars,
      timeoutMs: Int = DefaultTimeoutMs,
      maxBodyBytes: Int = DefaultMaxBodyBytes
  ): String = {
    require(url != null && url.trim.nonEmpty, "url is required")
    val target = ensureScheme(url.trim)
    val effectiveMax = if (maxChars <= 0) DefaultMaxChars else maxChars
    fetchAsMarkdownChecked(target, effectiveMax, timeoutMs, maxBodyBytes, 0)
  }

  /**
   * One hop of the fetch. Redirects are followed by hand (jsoup's own
   * following is disabled) so each target is validated by `AIAgentUrlSafety`
   * at the top of the next call before it is contacted.
   */
  private def fetchAsMarkdownChecked(
      target: String,
      maxChars: Int,
      timeoutMs: Int,
      maxBodyBytes: Int,
      redirectCount: Int
  ): String = {
    val targetUri = AIAgentUrlSafety.validatePublicHttpUrl(target)
    val fetchStart = System.currentTimeMillis()
    val response = Jsoup
      .connect(targetUri.toString)
      .userAgent(UserAgent)
      .timeout(timeoutMs)
      .maxBodySize(maxBodyBytes)
      .followRedirects(false)
      .ignoreContentType(false)
      // Status codes are checked below so redirects can be handled explicitly.
      .ignoreHttpErrors(true)
      .execute()
    val statusCode = response.statusCode()
    if (statusCode >= 300 && statusCode < 400) {
      require(redirectCount < MaxRedirects, s"Too many redirects fetching $target")
      val location = response.header("Location")
      require(
        location != null && location.trim.nonEmpty,
        s"Redirect missing Location fetching $target"
      )
      val redirected = targetUri.resolve(location.trim).toString
      fetchAsMarkdownChecked(redirected, maxChars, timeoutMs, maxBodyBytes, redirectCount + 1)
    } else {
      require(statusCode >= 200 && statusCode < 300, s"HTTP $statusCode fetching $target")
      val html = response.body()
      val fetchMs = System.currentTimeMillis() - fetchStart
      val out = truncate(toMarkdown(targetUri.toString, html), maxChars)
      logger.info(
        s"[UrlFetcher] url=$targetUri fetchMs=$fetchMs htmlLen=${html.length} mdLen=${out.length}"
      )
      out
    }
  }

  // Runs Readability4J over the page and converts the article HTML to trimmed Markdown,
  // falling back to the full page when no article content was extracted.
  private def toMarkdown(baseUri: String, html: String): String = {
    val article = new Readability4J(baseUri, html).parse()
    val articleHtml = Option(article.getArticleContent)
      .map(_.outerHtml())
      .filter(_.nonEmpty)
      .getOrElse(html)
    htmlToMd.convert(articleHtml).trim
  }

  /**
   * Return `text` unchanged when it fits in `maxChars`; otherwise cut it,
   * drop trailing whitespace and append [[TruncationMarker]].
   * The cut moves back one char when it would fall between the two halves of
   * a surrogate pair, so the result never ends in a lone high surrogate.
   */
  private[aiagent] def truncate(text: String, maxChars: Int): String =
    if (text.length <= maxChars) text
    else {
      val cut =
        if (maxChars > 0 && Character.isHighSurrogate(text.charAt(maxChars - 1))) maxChars - 1
        else maxChars
      text.substring(0, cut).stripTrailing() + TruncationMarker
    }

  // Prepends https:// unless the URL already has an http(s) scheme. The check ignores
  // case so "HTTP://example.com" is not turned into "https://HTTP://example.com".
  private def ensureScheme(url: String): String = {
    val hasScheme =
      url.regionMatches(true, 0, "http://", 0, 7) || url.regionMatches(true, 0, "https://", 0, 8)
    if (hasScheme) url else "https://" + url
  }
}
/**
 * Covers the generated JSON schema of [[AIAgentOpDesc]], its JSON
 * deserialization, and the output schema it propagates in each output mode.
 */
class AIAgentOpDescSpec extends AnyFlatSpec {

  private val workflowId = WorkflowIdentity(0L)
  private val executionId = ExecutionIdentity(0L)
  private val inputSchema: Schema = Schema()
    .add(new Attribute("comment", AttributeType.STRING))
    .add(new Attribute("rating", AttributeType.INTEGER))

  // Minimal descriptor: API key plus one input column.
  private def validDesc(): AIAgentOpDesc = {
    val desc = new AIAgentOpDesc
    desc.apiKey = "openrouter-key"
    desc.inputColumn = List("comment")
    desc
  }

  // JSON schema the metadata generator produces for the operator.
  private def operatorJsonSchema() =
    OperatorMetadataGenerator.generateOperatorJsonSchema(classOf[AIAgentOpDesc])

  // Builds the physical op, feeds `inputSchema` into its first input port and
  // returns the schema resolved on its first output port.
  private def propagatedOutputSchema(desc: AIAgentOpDesc) = {
    val physicalOp = desc.getPhysicalOp(workflowId, executionId)
    val inPort = physicalOp.inputPorts.keys.head
    val outPort = physicalOp.outputPorts.keys.head
    physicalOp.propagateSchema(Some(inPort -> inputSchema)).outputPorts(outPort)._3.toOption
  }

  // Appends the cost and error columns the tests expect after the configured output columns.
  private def withCostAndError(schema: Schema): Schema =
    schema
      .add(new Attribute("_cost_usd", AttributeType.DOUBLE))
      .add(new Attribute("_error", AttributeType.STRING))

  // Structured field with only the column name set.
  private def structuredField(columnName: String): AIAgentStructuredOutputField = {
    val field = new AIAgentStructuredOutputField
    field.columnName = columnName
    field
  }

  // Structured field with column name and extraction instructions.
  private def structuredField(
      columnName: String,
      instructions: String
  ): AIAgentStructuredOutputField = {
    val field = structuredField(columnName)
    field.instructions = instructions
    field
  }

  "AIAgentOpDesc metadata" should "generate a numeric temperature default" in {
    val temperatureDefault =
      operatorJsonSchema().get("properties").get("temperature").get("default")

    assert(temperatureDefault.isDouble)
    assert(temperatureDefault.asDouble() == 0.7)
  }

  it should "expose output modes and hide mode-specific fields" in {
    val properties = operatorJsonSchema().get("properties")
    val outputMode = properties.get("outputMode")
    val modeEnum = outputMode.get("enum").toString
    val structuredFields = properties.get("structuredOutputFields")

    assert(outputMode.get("default").asText() == "text")
    assert(modeEnum.contains("structured"))
    assert(!modeEnum.contains("classification"))
    assert(structuredFields.get("hideExpectedValue").asText() == "text")
    assert(structuredFields.get("title").asText() == "Structured Output Fields")
    assert(
      structuredFields
        .get("description")
        .asText()
        .contains("Define each output column and what the model should extract")
    )
    assert(
      properties.get("textClassificationLabels").get("hideExpectedValue").asText() == "structured"
    )
    assert(properties.get("outputColumn").get("hideExpectedValue").asText() == "structured")
  }

  it should "show mode-specific output fields immediately after output mode" in {
    val properties = operatorJsonSchema().get("properties")
    val allNames = properties.fieldNames().asScala.toList
    val aiAgentNames = allNames.filterNot(_ == "dummyPropertyList")

    assert(aiAgentNames.head == "outputMode")
    assert(
      aiAgentNames.slice(1, 4) == List(
        "structuredOutputFields",
        "textClassificationLabels",
        "classificationLabels"
      )
    )
    assert(properties.get("inputColumn").get("title").asText() == "Columns Sent to AI")
    assert(properties.get("inputColumn").get("autofill").asText() == "attributeNameList")
    assert(!allNames.contains("userPromptTemplate"))
  }

  it should "deserialize a legacy single input column as a one-element list" in {
    val legacyJson =
      """{"operatorType":"AIAgent","inputColumn":"text","apiKey":"openrouter-key"}"""
    val desc = objectMapper.readValue(legacyJson, classOf[AIAgentOpDesc])

    assert(desc.inputColumn == List("text"))
  }

  it should "deserialize structured output fields with column instructions" in {
    val desc = objectMapper.readValue(
      """{
        |  "operatorType": "AIAgent",
        |  "inputColumn": "text",
        |  "apiKey": "openrouter-key",
        |  "outputMode": "structured",
        |  "structuredOutputFields": [
        |    {
        |      "fieldType": "classification",
        |      "columnName": "sentiment",
        |      "instructions": "positive, neutral, or negative",
        |      "classificationLabels": ["positive", "neutral", "negative"]
        |    }
        |  ]
        |}""".stripMargin,
      classOf[AIAgentOpDesc]
    )
    val firstField = desc.normalizedStructuredOutputFields.head

    assert(desc.normalizedStructuredOutputColumns == List("sentiment"))
    assert(firstField.normalizedFieldType == AIAgentStructuredFieldType.Classification)
    assert(firstField.instructions == "positive, neutral, or negative")
    assert(
      firstField.normalizedClassificationLabels == List("positive", "neutral", "negative")
    )
  }

  it should "validate structured mode without hidden text output column" in {
    val properties =
      """
        |{
        |  "outputMode": "structured",
        |  "structuredOutputFields": [
        |    {
        |      "columnName": "sentiment",
        |      "instructions": "positive, neutral, or negative"
        |    }
        |  ],
        |  "systemPrompt": "",
        |  "inputColumn": ["comment"],
        |  "apiKey": "openrouter-key",
        |  "model": "openai/gpt-4o-mini",
        |  "temperature": 0.7,
        |  "timeoutSeconds": 60,
        |  "cacheEnabled": true,
        |  "emitCostColumn": true,
        |  "emitErrorColumn": true
        |}
        |""".stripMargin
    val validator = JsonSchemaFactory.byDefault().getJsonSchema(operatorJsonSchema())
    val report = validator.validate(objectMapper.readTree(properties))

    assert(report.isSuccess)
  }

  it should "propagate input schema while output column is blank" in {
    val desc = validDesc()
    desc.outputColumn = ""

    val expected = withCostAndError(inputSchema)

    assert(propagatedOutputSchema(desc).contains(expected))
  }

  it should "append the configured output column during schema propagation" in {
    val desc = validDesc()
    desc.outputColumn = "ai_response"

    val expected =
      withCostAndError(inputSchema.add(new Attribute("ai_response", AttributeType.STRING)))

    assert(propagatedOutputSchema(desc).contains(expected))
  }

  it should "append structured output columns during schema propagation" in {
    val desc = validDesc()
    desc.outputMode = AIAgentOutputMode.Structured
    desc.structuredOutputFields = List(
      structuredField("sentiment", "positive, neutral, or negative"),
      structuredField("reason", "short explanation for the sentiment"),
      // Blank column name: the expectation below has no column for it.
      structuredField(" ")
    )

    val expected = withCostAndError(
      inputSchema
        .add(new Attribute("sentiment", AttributeType.STRING))
        .add(new Attribute("reason", AttributeType.STRING))
    )

    assert(propagatedOutputSchema(desc).contains(expected))
  }

  // NOTE(review): the title mentions a confidence column, but the expected schema
  // only contains the label column — confirm which of the two is intended.
  it should "append classification label and confidence columns during schema propagation" in {
    val desc = validDesc()
    desc.outputMode = AIAgentOutputMode.Classification
    desc.outputColumn = "category"

    val expected =
      withCostAndError(inputSchema.add(new Attribute("category", AttributeType.STRING)))

    assert(propagatedOutputSchema(desc).contains(expected))
  }

  it should "skip cost and error columns when structured fields already use those names" in {
    val desc = validDesc()
    desc.outputMode = AIAgentOutputMode.Structured
    desc.structuredOutputFields = List(structuredField("_cost_usd"), structuredField("_error"))

    // Both names come from the structured fields, so they are STRING columns here.
    val expected = inputSchema
      .add(new Attribute("_cost_usd", AttributeType.STRING))
      .add(new Attribute("_error", AttributeType.STRING))

    assert(propagatedOutputSchema(desc).contains(expected))
  }
}
b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/AIAgentOutputParserSpec.scala new file mode 100644 index 00000000000..046ee7d1f39 --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/AIAgentOutputParserSpec.scala @@ -0,0 +1,152 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
/**
 * Covers the three `AIAgentOutputParser` entry points: free-text results,
 * structured JSON fields, and label/confidence classification.
 */
class AIAgentOutputParserSpec extends AnyFlatSpec {

  private val supportLabels = List("technical", "billing")

  // Classification-typed structured field named "sentiment" with positive/negative labels.
  private def sentimentField(): AIAgentStructuredOutputField = {
    val field = new AIAgentStructuredOutputField
    field.columnName = "sentiment"
    field.fieldType = AIAgentStructuredFieldType.Classification
    field.classificationLabels = List("positive", "negative")
    field
  }

  "AIAgentOutputParser.parseTextResult" should "extract a free-form text response" in {
    val parsed = AIAgentOutputParser.parseTextResult("""{"response":"useful summary"}""", List.empty)

    assert(parsed == "useful summary")
  }

  it should "validate text classification labels" in {
    val parsed = AIAgentOutputParser.parseTextResult("""{"response":"billing"}""", supportLabels)

    assert(parsed == "billing")
  }

  it should "fail when a text classification label is outside the allowed labels" in {
    val failure = intercept[IllegalArgumentException] {
      AIAgentOutputParser.parseTextResult("""{"response":"sales"}""", supportLabels)
    }

    assert(failure.getMessage.contains("not one of"))
  }

  "AIAgentOutputParser.parseStructured" should "extract configured JSON fields in order" in {
    val json = """{"sentiment":"positive","reason":"clear signal","score":0.91}"""
    val values = AIAgentOutputParser.parseStructured(json, List("sentiment", "reason", "score"))

    assert(values == Seq("positive", "clear signal", "0.91"))
  }

  it should "return an empty string for missing or null structured fields" in {
    val values =
      AIAgentOutputParser.parseStructured("""{"sentiment":null}""", List("sentiment", "reason"))

    assert(values == Seq("", ""))
  }

  it should "validate structured classification field labels" in {
    val values = AIAgentOutputParser.parseStructuredFields(
      """{"sentiment":"positive"}""",
      List(sentimentField())
    )

    assert(values == Seq("positive"))
  }

  it should "fail when a structured classification field label is outside the allowed labels" in {
    val failure = intercept[IllegalArgumentException] {
      AIAgentOutputParser.parseStructuredFields(
        """{"sentiment":"neutral"}""",
        List(sentimentField())
      )
    }

    assert(failure.getMessage.contains("not one of"))
  }

  it should "fail when structured output is not a JSON object" in {
    val failure = intercept[IllegalArgumentException] {
      AIAgentOutputParser.parseStructured("""["positive"]""", List("sentiment"))
    }

    assert(failure.getMessage.contains("JSON object"))
  }

  "AIAgentOutputParser.parseClassification" should "extract a valid label and confidence" in {
    val (label, confidence) = AIAgentOutputParser.parseClassification(
      """{"label":"billing","confidence":0.82}""",
      supportLabels
    )

    assert(label == "billing")
    assert(confidence == 0.82)
  }

  it should "allow a missing confidence" in {
    val (label, confidence) =
      AIAgentOutputParser.parseClassification("""{"label":"technical"}""", supportLabels)

    assert(label == "technical")
    assert(confidence == null)
  }

  it should "fail when the label is outside the allowed labels" in {
    val failure = intercept[IllegalArgumentException] {
      AIAgentOutputParser.parseClassification(
        """{"label":"sales","confidence":0.4}""",
        supportLabels
      )
    }

    assert(failure.getMessage.contains("not one of"))
  }

  it should "fail when confidence is not numeric" in {
    val failure = intercept[IllegalArgumentException] {
      AIAgentOutputParser.parseClassification(
        """{"label":"billing","confidence":"high"}""",
        supportLabels
      )
    }

    assert(failure.getMessage.contains("confidence must be numeric"))
  }
}
b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/AIAgentResponseCacheSpec.scala new file mode 100644 index 00000000000..d769a13e6c4 --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/AIAgentResponseCacheSpec.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.texera.amber.operator.aiagent + +import org.scalatest.flatspec.AnyFlatSpec + +class AIAgentResponseCacheSpec extends AnyFlatSpec { + + "AIAgentResponseCache.key" should "include the API key signature" in { + val baseArgs = ( + "openai/gpt-4o-mini", + 0.7, + "system", + "prompt", + "read_url", + "text" + ) + + val keyA = AIAgentResponseCache.key( + baseArgs._1, + baseArgs._2, + AIAgentResponseCache.sha256("key-a"), + baseArgs._3, + baseArgs._4, + baseArgs._5, + baseArgs._6 + ) + val keyB = AIAgentResponseCache.key( + baseArgs._1, + baseArgs._2, + AIAgentResponseCache.sha256("key-b"), + baseArgs._3, + baseArgs._4, + baseArgs._5, + baseArgs._6 + ) + + assert(keyA != keyB) + } + + it should "be instance-local" in { + val cacheA = new AIAgentResponseCache() + val cacheB = new AIAgentResponseCache() + cacheA.put("k", "v") + + assert(cacheA.get("k").contains("v")) + assert(cacheB.get("k").isEmpty) + } +} diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/AIAgentUrlSafetySpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/AIAgentUrlSafetySpec.scala new file mode 100644 index 00000000000..1eaf8302f8e --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/AIAgentUrlSafetySpec.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.aiagent + +import org.scalatest.flatspec.AnyFlatSpec + +class AIAgentUrlSafetySpec extends AnyFlatSpec { + + "AIAgentUrlSafety.validatePublicHttpUrl" should "allow public http and https URLs" in { + assert(AIAgentUrlSafety.validatePublicHttpUrl("https://93.184.216.34/report.pdf").getHost == "93.184.216.34") + assert(AIAgentUrlSafety.validatePublicHttpUrl("http://93.184.216.34").getScheme == "http") + } + + it should "reject non-http schemes" in { + val error = intercept[IllegalArgumentException] { + AIAgentUrlSafety.validatePublicHttpUrl("file:///etc/passwd") + } + + assert(error.getMessage.contains("http(s)")) + } + + it should "reject localhost names" in { + val error = intercept[IllegalArgumentException] { + AIAgentUrlSafety.validatePublicHttpUrl("http://localhost:9000") + } + + assert(error.getMessage.contains("Private or local")) + } + + it should "reject private IP addresses" in { + val error = intercept[IllegalArgumentException] { + AIAgentUrlSafety.validatePublicHttpUrl("http://10.0.0.5/secret.pdf") + } + + assert(error.getMessage.contains("Private or local")) + } +} diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/OpenRouterClientSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/OpenRouterClientSpec.scala new file mode 100644 index 00000000000..e189abbcd36 --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/OpenRouterClientSpec.scala @@ -0,0 +1,85 @@ +/* + 
* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.aiagent + +import org.scalatest.flatspec.AnyFlatSpec + +class OpenRouterClientSpec extends AnyFlatSpec { + + "OpenRouterClient.parseChatCompletionContent" should "extract the first message content" in { + val response = + """ + |{ + | "id": "gen-1", + | "choices": [ + | { + | "message": { + | "role": "assistant", + | "content": "The answer is row-wise." 
+ | } + | } + | ] + |} + |""".stripMargin + + assert(OpenRouterClient.parseChatCompletionContent(response) == "The answer is row-wise.") + } + + it should "extract an empty assistant message" in { + val response = + """ + |{ + | "choices": [ + | { + | "message": { + | "content": "" + | } + | } + | ] + |} + |""".stripMargin + + assert(OpenRouterClient.parseChatCompletionContent(response) == "") + } + + it should "fail when choices is empty" in { + val error = intercept[RuntimeException] { + OpenRouterClient.parseChatCompletionContent("""{"choices": []}""") + } + + assert(error.getMessage.contains("choices[0].message.content")) + } + + it should "fail when message content is missing" in { + val error = intercept[RuntimeException] { + OpenRouterClient.parseChatCompletionContent("""{"choices": [{"message": {}}]}""") + } + + assert(error.getMessage.contains("choices[0].message.content")) + } + + it should "fail when the first choice is null" in { + val error = intercept[RuntimeException] { + OpenRouterClient.parseChatCompletionContent("""{"choices": [null]}""") + } + + assert(error.getMessage.contains("choices[0].message.content")) + } +} diff --git a/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/PromptTemplateSpec.scala b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/PromptTemplateSpec.scala new file mode 100644 index 00000000000..07fa4b12d98 --- /dev/null +++ b/common/workflow-operator/src/test/scala/org/apache/texera/amber/operator/aiagent/PromptTemplateSpec.scala @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.texera.amber.operator.aiagent + +import org.apache.texera.amber.core.tuple.{Attribute, AttributeType, Schema, Tuple} +import org.scalatest.flatspec.AnyFlatSpec + +class PromptTemplateSpec extends AnyFlatSpec { + private val schema: Schema = Schema() + .add(new Attribute("question", AttributeType.STRING)) + .add(new Attribute("count", AttributeType.INTEGER)) + .add(new Attribute("note", AttributeType.STRING)) + + private val tuple: Tuple = Tuple + .builder(schema) + .add(schema.getAttribute("question"), "How are rows processed?") + .add(schema.getAttribute("count"), Integer.valueOf(3)) + .add(schema.getAttribute("note"), null) + .build() + + "PromptTemplate.render" should "replace placeholders with tuple field values" in { + val rendered = PromptTemplate.render("Q: {{question}} Count: {{count}}", tuple) + + assert(rendered == "Q: How are rows processed? Count: 3") + } + + it should "support repeated placeholders and whitespace inside delimiters" in { + val rendered = PromptTemplate.render("{{ question }} -> {{question}}", tuple) + + assert(rendered == "How are rows processed? 
-> How are rows processed?") + } + + it should "render null fields as empty strings" in { + val rendered = PromptTemplate.render("Note={{note}}.", tuple) + + assert(rendered == "Note=.") + } + + it should "leave templates without placeholders unchanged" in { + val rendered = PromptTemplate.render("Summarize this row.", tuple) + + assert(rendered == "Summarize this row.") + } + + it should "return an empty prompt for a null template" in { + assert(PromptTemplate.render(null, tuple) == "") + } + + it should "fail when a placeholder references a missing column" in { + val error = intercept[IllegalArgumentException] { + PromptTemplate.render("{{missing}}", tuple) + } + + assert(error.getMessage.contains("missing column: missing")) + } +} diff --git a/frontend/src/app/app.module.ts b/frontend/src/app/app.module.ts index 21928b77039..0f42d76bdc9 100644 --- a/frontend/src/app/app.module.ts +++ b/frontend/src/app/app.module.ts @@ -184,6 +184,7 @@ import { ComputingUnitSelectionComponent } from "./workspace/component/power-but import { NzSliderModule } from "ng-zorro-antd/slider"; import { AdminSettingsComponent } from "./dashboard/component/admin/settings/admin-settings.component"; import { FormlyRepeatDndComponent } from "./common/formly/repeat-dnd/repeat-dnd.component"; +import { TagsInputComponent } from "./common/formly/tags-input.component"; import { NzInputNumberModule } from "ng-zorro-antd/input-number"; import { NzGridModule } from "ng-zorro-antd/grid"; import { NzCheckboxModule } from "ng-zorro-antd/checkbox"; @@ -265,6 +266,7 @@ registerLocaleData(en); NzGridModule, ScrollingModule, FormlyRepeatDndComponent, + TagsInputComponent, AdminGmailComponent, PublicProjectComponent, WorkspaceComponent, diff --git a/frontend/src/app/common/formly/formly-config.ts b/frontend/src/app/common/formly/formly-config.ts index c3995abb544..8f32775d5ee 100644 --- a/frontend/src/app/common/formly/formly-config.ts +++ b/frontend/src/app/common/formly/formly-config.ts @@ -28,6 +28,8 @@ 
import { DatasetFileSelectorComponent } from "../../workspace/component/dataset- import { CollabWrapperComponent } from "./collab-wrapper/collab-wrapper/collab-wrapper.component"; import { FormlyRepeatDndComponent } from "./repeat-dnd/repeat-dnd.component"; import { DatasetVersionSelectorComponent } from "../../workspace/component/dataset-version-selector/dataset-version-selector.component"; +import { OpenRouterModelSelectorComponent } from "./openrouter-model-selector.component"; +import { TagsInputComponent } from "./tags-input.component"; /** * Configuration for using Json Schema with Formly. @@ -79,6 +81,8 @@ export const TEXERA_FORMLY_CONFIG = { { name: "codearea", component: CodeareaCustomTemplateComponent }, { name: "inputautocomplete", component: DatasetFileSelectorComponent, wrappers: ["form-field"] }, { name: "datasetversionselector", component: DatasetVersionSelectorComponent, wrappers: ["form-field"] }, + { name: "openrouter-model-selector", component: OpenRouterModelSelectorComponent, wrappers: ["form-field"] }, + { name: "tags-input", component: TagsInputComponent, wrappers: ["form-field"] }, { name: "repeat-section-dnd", component: FormlyRepeatDndComponent }, ], wrappers: [ diff --git a/frontend/src/app/common/formly/openrouter-model-selector.component.ts b/frontend/src/app/common/formly/openrouter-model-selector.component.ts new file mode 100644 index 00000000000..a013ddf1e56 --- /dev/null +++ b/frontend/src/app/common/formly/openrouter-model-selector.component.ts @@ -0,0 +1,96 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +import { NgFor } from "@angular/common"; +import { Component } from "@angular/core"; +import { ReactiveFormsModule } from "@angular/forms"; +import { FieldType, FieldTypeConfig, FormlyModule } from "@ngx-formly/core"; +import { NzOptionComponent, NzOptionGroupComponent, NzSelectComponent } from "ng-zorro-antd/select"; + +interface OpenRouterModelOption { + value: string; + label: string; + company?: string; +} + +interface OpenRouterModelGroup { + company: string; + options: OpenRouterModelOption[]; +} + +@Component({ + selector: "texera-openrouter-model-selector", + template: ` + + + + + + + `, + imports: [FormlyModule, NgFor, NzOptionComponent, NzOptionGroupComponent, NzSelectComponent, ReactiveFormsModule], +}) +export class OpenRouterModelSelectorComponent extends FieldType { + searchValue = ""; + + get modelOptions(): OpenRouterModelOption[] { + return Array.isArray(this.props.options) ? this.props.options : []; + } + + get visibleModelOptions(): OpenRouterModelOption[] { + const searchValue = this.searchValue.trim(); + if ( + !searchValue || + this.modelOptions.some( + option => option.value === searchValue || option.label.toLowerCase() === searchValue.toLowerCase() + ) + ) { + return this.modelOptions; + } + return [{ value: searchValue, label: searchValue }, ...this.modelOptions]; + } + + get visibleModelGroups(): OpenRouterModelGroup[] { + const groups = new Map(); + this.visibleModelOptions.forEach(option => { + const company = option.company ?? ""; + groups.set(company, [...(groups.get(company) ?? 
[]), option]); + }); + + return Array.from(groups.entries()).map(([company, options]) => ({ company, options })); + } + + onSearch(value: string): void { + this.searchValue = value; + } + + filterOption = (input: string, option: any): boolean => { + const searchValue = input.toLowerCase(); + const value = String(option.nzValue ?? option.value ?? "").toLowerCase(); + const label = String(option.nzLabel ?? option.label ?? "").toLowerCase(); + return value.includes(searchValue) || label.includes(searchValue); + }; +} diff --git a/frontend/src/app/common/formly/tags-input.component.ts b/frontend/src/app/common/formly/tags-input.component.ts new file mode 100644 index 00000000000..e6098780648 --- /dev/null +++ b/frontend/src/app/common/formly/tags-input.component.ts @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +import { Component } from "@angular/core"; +import { ReactiveFormsModule } from "@angular/forms"; +import { FieldType, FieldTypeConfig, FormlyModule } from "@ngx-formly/core"; +import { NzSelectComponent } from "ng-zorro-antd/select"; + +@Component({ + selector: "texera-tags-input", + template: ` + + + `, + imports: [FormlyModule, NzSelectComponent, ReactiveFormsModule], +}) +export class TagsInputComponent extends FieldType {} diff --git a/frontend/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts b/frontend/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts index c7ab561f403..bd814d4feef 100644 --- a/frontend/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts +++ b/frontend/src/app/workspace/component/property-editor/operator-property-edit-frame/operator-property-edit-frame.component.ts @@ -17,10 +17,11 @@ * under the License. 
*/ +import { HttpClient } from "@angular/common/http"; import { ChangeDetectorRef, Component, Input, OnChanges, OnDestroy, OnInit, SimpleChanges } from "@angular/core"; import { ExecuteWorkflowService } from "../../../service/execute-workflow/execute-workflow.service"; import { WorkflowStatusService } from "../../../service/workflow-status/workflow-status.service"; -import { Subject } from "rxjs"; +import { Observable, of, Subject } from "rxjs"; import { AbstractControl, FormGroup, FormsModule, ReactiveFormsModule } from "@angular/forms"; import { FormlyFieldConfig, FormlyFormOptions, FormlyModule } from "@ngx-formly/core"; import Ajv from "ajv"; @@ -50,7 +51,7 @@ import { TypeCastingDisplayComponent, } from "../typecasting-display/type-casting-display.component"; import { UntilDestroy, untilDestroyed } from "@ngneat/until-destroy"; -import { filter } from "rxjs/operators"; +import { catchError, filter, map, shareReplay } from "rxjs/operators"; import { NotificationService } from "../../../../common/service/notification/notification.service"; import { PresetWrapperComponent } from "src/app/common/formly/preset-wrapper/preset-wrapper.component"; import { WorkflowVersionService } from "../../../../dashboard/service/user/workflow-version/workflow-version.service"; @@ -70,9 +71,41 @@ import { NzIconDirective } from "ng-zorro-antd/icon"; import { NzPopoverDirective } from "ng-zorro-antd/popover"; import { NzFormDirective } from "ng-zorro-antd/form"; import { NzWaveDirective } from "ng-zorro-antd/core/wave"; +import { AppSettings } from "../../../../common/app-setting"; Quill.register("modules/cursors", QuillCursors); +interface OpenRouterModelSummary { + id: string; + name: string; + contextLength?: number; + pricing: Record; +} + +interface OpenRouterModelsResponse { + data: OpenRouterModelSummary[]; +} + +interface OpenRouterModelOption { + value: string; + label: string; + company: string; +} + +const OPENROUTER_PROVIDER_DISPLAY_NAMES: Record = { + anthropic: 
"Anthropic", + cohere: "Cohere", + deepseek: "DeepSeek", + google: "Google", + meta: "Meta", + microsoft: "Microsoft", + mistralai: "Mistral AI", + openai: "OpenAI", + perplexity: "Perplexity", + qwen: "Qwen", + xai: "xAI", +}; + /** * Property Editor uses JSON Schema to automatically generate the form from the JSON Schema of an operator. * For example, the JSON Schema of Sentiment Analysis could be: @@ -160,6 +193,7 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On public operatorVersion: string = ""; quillBinding?: QuillBinding; quill!: Quill; + private openRouterModelOptions$?: Observable; // used to tear down subscriptions that takeUntil(teardownObservable) private teardownObservable: Subject = new Subject(); @@ -173,7 +207,8 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On private changeDetectorRef: ChangeDetectorRef, private workflowVersionService: WorkflowVersionService, private workflowStatusSerivce: WorkflowStatusService, - private config: GuiConfigService + private config: GuiConfigService, + private httpClient: HttpClient ) {} ngOnChanges(changes: SimpleChanges): void { @@ -240,7 +275,9 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On // set the operator data needed this.workflowActionService.setOperatorVersion(operator.operatorID, this.currentOperatorSchema.operatorVersion); this.operatorVersion = operator.operatorVersion.slice(0, 9); - this.setFormlyFormBinding(this.currentOperatorSchema.jsonSchema); + const jsonSchema = cloneDeep(this.currentOperatorSchema.jsonSchema); + this.setFormlyFormBinding(jsonSchema); + this.populateOpenRouterModelOptions(this.currentOperatorId, jsonSchema); this.formTitle = operator.customDisplayName ?? 
this.currentOperatorSchema.additionalMetadata.userFriendlyName; this.operatorDescription = this.currentOperatorSchema.additionalMetadata.operatorDescription; /** @@ -248,6 +285,7 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On * Prevent the form directly changes the value in the texera graph without going through workflow action service. */ this.formData = cloneDeep(operator.operatorProperties); + this.normalizeAttributeNameListValues(jsonSchema, this.formData); // use ajv to initialize the default value to data according to schema, see https://ajv.js.org/#assigning-defaults // WorkflowUtil service also makes sure that the default values are filled in when operator is added from the UI @@ -255,7 +293,7 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On // 1. the operator might be added not directly from the UI, which violates the precondition // 2. the schema might change, which specifies a new default value // 3. formly doesn't emit change event when it fills in default value, causing an inconsistency between component and service - this.ajv.validate(this.currentOperatorSchema.jsonSchema, this.formData); + this.ajv.validate(jsonSchema, this.formData); // manually trigger a form change event because default value might be filled in this.onFormChanges(this.formData); @@ -549,6 +587,18 @@ export class OperatorPropertyEditFrameComponent implements OnInit, OnChanges, On }; } + if ( + this.currentOperatorSchema?.operatorType === "AIAgent" && + mappedField.key === "model" && + isDefined(mapSource.openRouterModelOptions) + ) { + mappedField.type = "openrouter-model-selector"; + mappedField.props = { + ...mappedField.props, + options: [...mapSource.openRouterModelOptions], + }; + } + // Add custom validators for attribute type if (isDefined(mapSource.attributeTypeRules)) { mappedField.validators.checkAttributeType = { @@ -736,6 +786,87 @@ export class OperatorPropertyEditFrameComponent implements OnInit, 
OnChanges, On this.formlyFields = [field]; } + private populateOpenRouterModelOptions(operatorId: string, schema: CustomJSONSchema7): void { + if (this.currentOperatorSchema?.operatorType !== "AIAgent") { + return; + } + this.fetchOpenRouterModelOptions() + .pipe(untilDestroyed(this)) + .subscribe(modelOptions => { + if ( + this.currentOperatorId !== operatorId || + this.currentOperatorSchema?.operatorType !== "AIAgent" || + modelOptions.length === 0 + ) { + return; + } + const modelSchema = schema.properties?.model; + if (typeof modelSchema === "boolean" || !modelSchema) { + return; + } + modelSchema.openRouterModelOptions = modelOptions; + this.setFormlyFormBinding(schema); + this.changeDetectorRef.detectChanges(); + }); + } + + private fetchOpenRouterModelOptions(): Observable { + if (!this.openRouterModelOptions$) { + this.openRouterModelOptions$ = this.httpClient + .get(`${AppSettings.getApiEndpoint()}/models/openrouter`) + .pipe( + map(response => + response.data + .map(model => ({ + value: model.id, + label: model.name === model.id ? model.id : `${model.name} (${model.id})`, + company: this.getOpenRouterProviderName(model.id), + })) + .sort((a, b) => { + const companyComparison = a.company.localeCompare(b.company); + if (companyComparison !== 0) { + return companyComparison; + } + return a.label.localeCompare(b.label, undefined, { numeric: true, sensitivity: "base" }); + }) + ), + catchError(() => of([])), + shareReplay(1) + ); + } + return this.openRouterModelOptions$; + } + + private getOpenRouterProviderName(modelId: string): string { + const providerSlug = modelId.split("/")[0]?.replace(/^~/, "") ?? ""; + return ( + OPENROUTER_PROVIDER_DISPLAY_NAMES[providerSlug] ?? + providerSlug + .split(/[-_]/) + .filter(Boolean) + .map(segment => segment.charAt(0).toUpperCase() + segment.slice(1)) + .join(" ") + ); + } + + private normalizeAttributeNameListValues(schema: CustomJSONSchema7, model: Record): void { + Object.entries(schema.properties ?? 
{}).forEach(([propertyName, propertySchema]) => { + if ( + typeof propertySchema === "boolean" || + propertySchema.autofill !== "attributeNameList" || + !Object.prototype.hasOwnProperty.call(model, propertyName) + ) { + return; + } + const value = model[propertyName]; + if (typeof value === "string") { + model[propertyName] = value === "" ? [] : [value]; + } else if (value === null || value === undefined) { + model[propertyName] = []; + } + }); + } + allowModifyOperatorLogic(): void { this.setInteractivity(true); } diff --git a/frontend/src/app/workspace/types/custom-json-schema.interface.ts b/frontend/src/app/workspace/types/custom-json-schema.interface.ts index 50edb681618..e189fc8dc01 100644 --- a/frontend/src/app/workspace/types/custom-json-schema.interface.ts +++ b/frontend/src/app/workspace/types/custom-json-schema.interface.ts @@ -69,4 +69,9 @@ export interface CustomJSONSchema7 extends JSONSchema7 { hideOnNull?: boolean; additionalEnumValue?: string; + openRouterModelOptions?: ReadonlyArray<{ + value: string; + label: string; + company?: string; + }>; } diff --git a/frontend/src/assets/operator_images/AIAgent.png b/frontend/src/assets/operator_images/AIAgent.png new file mode 100644 index 00000000000..3e63b549dea Binary files /dev/null and b/frontend/src/assets/operator_images/AIAgent.png differ