Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 117 additions & 1 deletion veadk/cloud/harness_app/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,21 +32,25 @@
uvicorn app:app --host 0.0.0.0 --port 8000
"""

import json
import os
import tempfile
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from google.adk.agents import RunConfig
from google.adk.agents.run_config import StreamingMode
from google.adk.agents.base_agent import BaseAgent
from google.adk.artifacts.in_memory_artifact_service import InMemoryArtifactService
from google.adk.auth.credential_service.in_memory_credential_service import (
InMemoryCredentialService,
)
from google.adk.cli.adk_web_server import AdkWebServer
from google.adk.cli.adk_web_server import AdkWebServer, RunAgentRequest
from google.adk.cli.utils.base_agent_loader import BaseAgentLoader
from google.adk.utils.context_utils import Aclosing
from google.adk.evaluation.local_eval_set_results_manager import (
LocalEvalSetResultsManager,
)
Expand All @@ -58,6 +62,7 @@
from veadk.a2a.utils.agent_to_a2a import to_a2a
from veadk.cloud.harness_app.agent import agent, short_term_memory
from veadk.cloud.harness_app.types import (
HarnessOverrides,
InvokeHarnessRequest,
InvokeHarnessResponse,
)
Expand Down Expand Up @@ -113,6 +118,16 @@ def list_agents_detailed(self) -> list[dict[str, Any]]:
]


class HarnessRunAgentRequest(RunAgentRequest):
"""ADK ``/run_sse`` request plus an optional once-time harness override.

When ``harness`` is set, the streaming run uses a spawned agent (base agent
cloned with the override applied); otherwise it uses the base agent.
"""

harness: HarnessOverrides | None = None


class HarnessApp:
def __init__(
self,
Expand Down Expand Up @@ -160,6 +175,7 @@ async def lifespan(app: FastAPI):
# it catches the well-known / RPC paths the ADK routes don't claim.
self.app = self._server.get_fast_api_app(lifespan=lifespan)
self.mount()
self._mount_run_sse_override()
self.app.mount("/", self._a2a_app)

def mount(self):
Expand Down Expand Up @@ -238,6 +254,106 @@ async def invoke_harness(
output=output,
)

def _mount_run_sse_override(self):
"""Override ADK's ``/run_sse`` so it honors once-time harness overrides.

ADK's default ``/run_sse`` always runs the served (base) agent. We wrap it:
when the request carries a ``harness`` override, stream a *spawned* agent
(base cloned + override applied); otherwise **delegate to ADK's original
handler unchanged** — so the no-override path is identical to stock run_sse.
"""
# Capture ADK's default /run_sse handler to delegate to when there is no
# override (keeps the base path bit-for-bit ADK behavior).
adk_run_sse = None
for r in self.app.router.routes:
if getattr(r, "path", None) == "/run_sse" and "POST" in getattr(
r, "methods", set()
):
adk_run_sse = r.endpoint
break

@self.app.post("/run_sse")
async def run_sse(req: HarnessRunAgentRequest):
if req.harness is None and adk_run_sse is not None:
# No override -> exactly ADK's default /run_sse.
return await adk_run_sse(req)
return StreamingResponse(
self._run_sse_events(req), media_type="text/event-stream"
)

# Move ours to the front so it wins (Starlette matches the first route),
# without deleting the default we delegate to.
routes = self.app.router.routes
for i, r in enumerate(routes):
if getattr(r, "path", None) == "/run_sse" and (
getattr(r, "endpoint", None) is run_sse
):
routes.insert(0, routes.pop(i))
break

async def _run_sse_events(self, req: "HarnessRunAgentRequest"):
"""Yield SSE ``data:`` lines for a run, spawning the agent on override."""
run_config = RunConfig(
streaming_mode=StreamingMode.SSE if req.streaming else StreamingMode.NONE
)
work_dir_ctx = None
try:
if req.harness is not None:
logger.info(f"run_sse once-time override: {req.harness}")
# Skills may download into a temp dir read from disk during the
# run, so keep it alive for the whole stream.
work_dir_ctx = tempfile.TemporaryDirectory(prefix="harness_run_sse_")
try:
agent = spawn_harness_agent(
self.agent,
req.harness,
download_dir=Path(work_dir_ctx.name),
)
except (SkillLoadError, ToolLoadError) as e:
logger.error(f"Once-time override failed to load: {e}")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
return
else:
agent = self.agent

runner = Runner(
agent=agent,
short_term_memory=self.short_term_memory,
app_name=req.app_name,
)
# Be self-sufficient: create the session if the caller did not.
if not await runner.session_service.get_session(
app_name=req.app_name,
user_id=req.user_id,
session_id=req.session_id,
):
await runner.session_service.create_session(
app_name=req.app_name,
user_id=req.user_id,
session_id=req.session_id,
)

async with Aclosing(
runner.run_async(
user_id=req.user_id,
session_id=req.session_id,
new_message=req.new_message,
run_config=run_config,
)
) as agen:
async for event in agen:
yield (
"data: "
+ event.model_dump_json(exclude_none=True, by_alias=True)
+ "\n\n"
)
except Exception as e:
logger.exception("run_sse failed")
yield f"data: {json.dumps({'error': str(e)})}\n\n"
finally:
if work_dir_ctx is not None:
work_dir_ctx.cleanup()

def serve(self, host: str = "0.0.0.0", port: int = 8000) -> None:
import uvicorn

Expand Down
Loading