Skip to content

Commit cd78d87

Browse files
allen-stephencopybara-github
authored andcommitted
feat(a2a): add support for persistent task stores
Merge #5597 **Please ensure you have read the [contribution guide](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) before creating a pull request.** ### Link to Issue or Description of Change **1. Link to an existing issue (if applicable):** - Related: #4971 **2. Or, if no issue exists, describe the change:** _If applicable, please follow the issue templates to provide as much detail as possible._ **Problem:** A2A support lacked pluggable or persistent task store backends, forcing a strict default to `InMemoryTaskStore`. **Solution:** Extended `ServiceRegistry` and `ServiceFactory` to support URI-driven configuration for A2A task stores. Added built-in support for `memory://`, `postgresql://`, `mysql://`, and `sqlite://` schemes. Plumbed the options down into `to_a2a()` and `get_fast_api_app()`, ensuring connection strings are securely redacted from application logs. ### Testing Plan **Unit Tests:** - [x] I have added or updated unit tests for my change. - [x] All unit tests pass locally. Summary: Ran `pytest` against `test_agent_to_a2a.py`, `test_fast_api.py`, `test_service_registry.py`, and `test_service_factory.py`. All 131 unit tests passed successfully. **Manual End-to-End (E2E) Tests:** Compiled the package distribution wheel via `uv build` and successfully validated its installation inside a clean isolated sandbox environment via `uv pip install dist/google_adk-1.32.0-py3-none-any.whl[a2a]`. Verified that the application instantiates default in-memory stores and custom persistent URI stores ### Checklist - [x] I have read the [CONTRIBUTING.md](https://github.com/google/adk-python/blob/main/CONTRIBUTING.md) document. - [x] I have performed a self-review of my own code. - [x] I have commented my code, particularly in hard-to-understand areas. - [x] I have added tests that prove my fix is effective or that my feature works. - [x] New and existing unit tests pass locally with my changes. - [x] I have manually tested my changes end-to-end. - [x] Any dependent changes have been merged and published in downstream modules. COPYBARA_INTEGRATE_REVIEW=#5597 from allen-stephen:feat/a2a-task-store-support f53ba0f PiperOrigin-RevId: 914496800
1 parent 76b9f0b commit cd78d87

9 files changed

Lines changed: 539 additions & 46 deletions

File tree

src/google/adk/a2a/utils/agent_to_a2a.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,13 @@
1818
import logging
1919
from typing import AsyncIterator
2020
from typing import Callable
21-
from typing import Optional
22-
from typing import Union
2321

2422
from a2a.server.apps import A2AStarletteApplication
2523
from a2a.server.request_handlers import DefaultRequestHandler
2624
from a2a.server.tasks import InMemoryPushNotificationConfigStore
2725
from a2a.server.tasks import InMemoryTaskStore
2826
from a2a.server.tasks import PushNotificationConfigStore
27+
from a2a.server.tasks import TaskStore
2928
from a2a.types import AgentCard
3029
from starlette.applications import Starlette
3130

@@ -41,8 +40,8 @@
4140

4241

4342
def _load_agent_card(
44-
agent_card: Optional[Union[AgentCard, str]],
45-
) -> Optional[AgentCard]:
43+
agent_card: AgentCard | str | None,
44+
) -> AgentCard | None:
4645
"""Load agent card from various sources.
4746
4847
Args:
@@ -82,10 +81,11 @@ def to_a2a(
8281
host: str = "localhost",
8382
port: int = 8000,
8483
protocol: str = "http",
85-
agent_card: Optional[Union[AgentCard, str]] = None,
86-
push_config_store: Optional[PushNotificationConfigStore] = None,
87-
runner: Optional[Runner] = None,
88-
lifespan: Optional[Callable[[Starlette], AsyncIterator[None]]] = None,
84+
agent_card: AgentCard | str | None = None,
85+
push_config_store: PushNotificationConfigStore | None = None,
86+
task_store: TaskStore | None = None,
87+
runner: Runner | None = None,
88+
lifespan: Callable[[Starlette], AsyncIterator[None]] | None = None,
8989
) -> Starlette:
9090
"""Convert an ADK agent to a A2A Starlette application.
9191
@@ -100,6 +100,8 @@ def to_a2a(
100100
push_config_store: Optional A2A push notification config store. If not
101101
provided, an in-memory store will be created so push-notification
102102
config RPC methods are supported.
103+
task_store: Optional A2A task store for persisting task state. If not
104+
provided, an in-memory store will be created.
103105
runner: Optional pre-built Runner object. If not provided, a default
104106
runner will be created using in-memory services.
105107
lifespan: Optional async context manager for Starlette lifespan
@@ -127,6 +129,20 @@ async def lifespan(app):
127129
await app.state.db.close()
128130
129131
app = to_a2a(agent, lifespan=lifespan)
132+
133+
# Or with a persistent task store (the caller owns engine disposal):
134+
from a2a.server.tasks import DatabaseTaskStore
135+
from sqlalchemy.ext.asyncio import create_async_engine
136+
137+
engine = create_async_engine("postgresql+asyncpg://...")
138+
task_store = DatabaseTaskStore(engine=engine)
139+
140+
@asynccontextmanager
141+
async def lifespan(app):
142+
yield
143+
await engine.dispose()
144+
145+
app = to_a2a(agent, task_store=task_store, lifespan=lifespan)
130146
"""
131147
# Set up ADK logging to ensure logs are visible when using uvicorn directly
132148
adk_logger = logging.getLogger("google_adk")
@@ -145,7 +161,8 @@ async def create_runner() -> Runner:
145161
)
146162

147163
# Create A2A components
148-
task_store = InMemoryTaskStore()
164+
if task_store is None:
165+
task_store = InMemoryTaskStore()
149166

150167
agent_executor = A2aAgentExecutor(
151168
runner=runner or create_runner,

src/google/adk/cli/fast_api.py

Lines changed: 50 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
from contextlib import asynccontextmanager
1718
import importlib
1819
import json
1920
import logging
@@ -24,7 +25,6 @@
2425
from typing import Any
2526
from typing import Literal
2627
from typing import Mapping
27-
from typing import Optional
2828

2929
import click
3030
from fastapi import FastAPI
@@ -48,6 +48,7 @@
4848
from .utils.agent_change_handler import AgentChangeEventHandler
4949
from .utils.agent_loader import AgentLoader
5050
from .utils.base_agent_loader import BaseAgentLoader
51+
from .utils.service_factory import _create_task_store_from_options
5152
from .utils.service_factory import create_artifact_service_from_options
5253
from .utils.service_factory import create_memory_service_from_options
5354
from .utils.service_factory import create_session_service_from_options
@@ -75,28 +76,29 @@ def __getattr__(name: str):
7576
def get_fast_api_app(
7677
*,
7778
agents_dir: str,
78-
agent_loader: Optional[BaseAgentLoader] = None,
79-
session_service_uri: Optional[str] = None,
80-
session_db_kwargs: Optional[Mapping[str, Any]] = None,
81-
artifact_service_uri: Optional[str] = None,
82-
memory_service_uri: Optional[str] = None,
79+
agent_loader: BaseAgentLoader | None = None,
80+
session_service_uri: str | None = None,
81+
session_db_kwargs: Mapping[str, Any] | None = None,
82+
artifact_service_uri: str | None = None,
83+
memory_service_uri: str | None = None,
8384
use_local_storage: bool = True,
84-
eval_storage_uri: Optional[str] = None,
85-
allow_origins: Optional[list[str]] = None,
85+
eval_storage_uri: str | None = None,
86+
allow_origins: list[str] | None = None,
8687
web: bool,
8788
a2a: bool = False,
89+
task_store_uri: str | None = None,
8890
host: str = "127.0.0.1",
8991
port: int = 8000,
90-
url_prefix: Optional[str] = None,
92+
url_prefix: str | None = None,
9193
trace_to_cloud: bool = False,
9294
otel_to_cloud: bool = False,
9395
reload_agents: bool = False,
94-
lifespan: Optional[Lifespan[FastAPI]] = None,
95-
extra_plugins: Optional[list[str]] = None,
96-
logo_text: Optional[str] = None,
97-
logo_image_url: Optional[str] = None,
96+
lifespan: Lifespan[FastAPI] | None = None,
97+
extra_plugins: list[str] | None = None,
98+
logo_text: str | None = None,
99+
logo_image_url: str | None = None,
98100
auto_create_session: bool = False,
99-
trigger_sources: Optional[list[Literal["pubsub", "eventarc"]]] = None,
101+
trigger_sources: list[Literal["pubsub", "eventarc"]] | None = None,
100102
) -> FastAPI:
101103
"""Constructs and returns a FastAPI application for serving ADK agents.
102104
@@ -128,6 +130,8 @@ def get_fast_api_app(
128130
allow_origins: List of allowed origins for CORS.
129131
web: Whether to enable the web UI and serve its assets.
130132
a2a: Whether to enable Agent-to-Agent (A2A) protocol support.
133+
task_store_uri: URI for the A2A task store. Uses in-memory task store if
134+
None. Only used when ``a2a=True``.
131135
host: Host address for the server (defaults to 127.0.0.1).
132136
port: Port number for the server (defaults to 8000).
133137
url_prefix: Optional prefix for all URL routes.
@@ -272,6 +276,33 @@ def tear_down_observer(observer: Observer, _: AdkWebServer):
272276
web_assets_dir=ANGULAR_DIST_PATH,
273277
)
274278

279+
# Create the task store early so its engine can be disposed via the
280+
# lifespan, preventing connection pool leaks on shutdown.
281+
a2a_task_store = None
282+
if a2a:
283+
base_path = Path.cwd() / agents_dir
284+
if base_path.exists() and base_path.is_dir():
285+
a2a_task_store = _create_task_store_from_options(
286+
task_store_uri=task_store_uri,
287+
)
288+
289+
if a2a_task_store is not None and hasattr(a2a_task_store, "engine"):
290+
outer_lifespan = lifespan
291+
292+
@asynccontextmanager
293+
async def _a2a_lifespan(app_instance: FastAPI):
294+
try:
295+
if outer_lifespan:
296+
async with outer_lifespan(app_instance) as ctx:
297+
yield ctx
298+
else:
299+
yield
300+
finally:
301+
logger.info("Disposing A2A task store engine")
302+
await a2a_task_store.engine.dispose()
303+
304+
lifespan = _a2a_lifespan
305+
275306
app = adk_web_server.get_fast_api_app(
276307
lifespan=lifespan,
277308
allow_origins=allow_origins,
@@ -339,7 +370,7 @@ def _walk(node: Any) -> None:
339370
for doc in docs:
340371
_walk(doc)
341372

342-
def _parse_upload_filename(filename: Optional[str]) -> tuple[str, str]:
373+
def _parse_upload_filename(filename: str | None) -> tuple[str, str]:
343374
if not filename:
344375
raise ValueError("Upload filename is missing.")
345376
filename = _normalize_relative_path(filename)
@@ -473,7 +504,7 @@ def ensure_tmp_exists(app_name: str) -> bool:
473504

474505
@app.post("/builder/save", response_model_exclude_none=True)
475506
async def builder_build(
476-
files: list[UploadFile], tmp: Optional[bool] = False
507+
files: list[UploadFile], tmp: bool | None = False
477508
) -> bool:
478509
try:
479510
# Phase 1: parse filenames and read content into memory.
@@ -544,8 +575,8 @@ async def builder_cancel(app_name: str) -> bool:
544575
)
545576
async def get_agent_builder(
546577
app_name: str,
547-
file_path: Optional[str] = None,
548-
tmp: Optional[bool] = False,
578+
file_path: str | None = None,
579+
tmp: bool | None = False,
549580
):
550581
try:
551582
app_root = _get_app_root(app_name)
@@ -584,11 +615,10 @@ async def get_agent_builder(
584615
headers={"Cache-Control": "no-store"},
585616
)
586617

587-
if a2a:
618+
if a2a and a2a_task_store is not None:
588619
from a2a.server.apps import A2AStarletteApplication
589620
from a2a.server.request_handlers import DefaultRequestHandler
590621
from a2a.server.tasks import InMemoryPushNotificationConfigStore
591-
from a2a.server.tasks import InMemoryTaskStore
592622
from a2a.types import AgentCard
593623
from a2a.utils.constants import AGENT_CARD_WELL_KNOWN_PATH
594624

@@ -598,7 +628,6 @@ async def get_agent_builder(
598628
base_path = Path.cwd() / agents_dir
599629
# the root agents directory should be an existing folder
600630
if base_path.exists() and base_path.is_dir():
601-
a2a_task_store = InMemoryTaskStore()
602631

603632
def create_a2a_runner_loader(captured_app_name: str):
604633
"""Factory function to create A2A runner with proper closure."""

src/google/adk/cli/service_registry.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,6 @@ def my_session_factory(uri: str, **kwargs):
6969
from pathlib import Path
7070
import sys
7171
from typing import Any
72-
from typing import Optional
7372
from typing import Protocol
7473
from urllib.parse import unquote
7574
from urllib.parse import urlparse
@@ -98,6 +97,7 @@ def __init__(self):
9897
self._session_factories: dict[str, ServiceFactory] = {}
9998
self._artifact_factories: dict[str, ServiceFactory] = {}
10099
self._memory_factories: dict[str, ServiceFactory] = {}
100+
self._task_store_factories: dict[str, ServiceFactory] = {}
101101

102102
def register_session_service(
103103
self, scheme: str, factory: ServiceFactory
@@ -123,6 +123,12 @@ def register_memory_service(
123123
"""Register a factory for a custom memory service URI scheme."""
124124
self._memory_factories[scheme] = factory
125125

126+
def _register_task_store_service(
127+
self, scheme: str, factory: ServiceFactory
128+
) -> None:
129+
"""Register a factory for a custom A2A task store URI scheme."""
130+
self._task_store_factories[scheme] = factory
131+
126132
def create_session_service(
127133
self, uri: str, **kwargs
128134
) -> BaseSessionService | None:
@@ -150,6 +156,17 @@ def create_memory_service(
150156
return self._memory_factories[scheme](uri, **kwargs)
151157
return None
152158

159+
def _create_task_store_service(self, uri: str, **kwargs: Any) -> Any:
160+
"""Create A2A task store from URI using registered factories."""
161+
scheme = urlparse(uri).scheme
162+
if scheme and scheme in self._task_store_factories:
163+
return self._task_store_factories[scheme](uri, **kwargs)
164+
supported = sorted(self._task_store_factories.keys())
165+
raise ValueError(
166+
f"Unsupported A2A task store URI scheme: '{scheme}'."
167+
f" Supported schemes: {supported}"
168+
)
169+
153170

154171
def get_service_registry() -> ServiceRegistry:
155172
"""Gets the singleton ServiceRegistry instance, initializing it if needed."""
@@ -333,9 +350,42 @@ def agentengine_memory_factory(uri: str, **kwargs):
333350
registry.register_memory_service("rag", rag_memory_factory)
334351
registry.register_memory_service("agentengine", agentengine_memory_factory)
335352

353+
# -- A2A Task Store Services --
354+
def memory_task_store_factory(uri: str, **kwargs: Any) -> Any:
355+
try:
356+
from a2a.server.tasks import InMemoryTaskStore
357+
except ImportError as e:
358+
raise ImportError(
359+
"A2A task store support requires the 'a2a' package."
360+
" Install it with: pip install google-adk[a2a]"
361+
) from e
362+
363+
return InMemoryTaskStore()
364+
365+
def database_task_store_factory(uri: str, **kwargs: Any) -> Any:
366+
try:
367+
from a2a.server.tasks import DatabaseTaskStore
368+
except ImportError as e:
369+
raise ImportError(
370+
"A2A task store support requires the 'a2a' package."
371+
" Install it with: pip install google-adk[a2a]"
372+
) from e
373+
from sqlalchemy.ext.asyncio import create_async_engine
374+
375+
engine = create_async_engine(uri)
376+
return DatabaseTaskStore(engine=engine)
377+
378+
registry._register_task_store_service("memory", memory_task_store_factory)
379+
for scheme in [
380+
"postgresql+asyncpg",
381+
"mysql+aiomysql",
382+
"sqlite+aiosqlite",
383+
]:
384+
registry._register_task_store_service(scheme, database_task_store_factory)
385+
336386

337387
def _load_gcp_config(
338-
agents_dir: Optional[str], service_name: str
388+
agents_dir: str | None, service_name: str
339389
) -> tuple[str, str]:
340390
"""Loads GCP project and location from environment."""
341391
if not agents_dir:
@@ -355,7 +405,7 @@ def _load_gcp_config(
355405

356406

357407
def _parse_agent_engine_kwargs(
358-
uri_part: str, agents_dir: Optional[str]
408+
uri_part: str, agents_dir: str | None
359409
) -> dict[str, Any]:
360410
"""Helper to parse agent engine resource name."""
361411
if not uri_part:
@@ -437,5 +487,7 @@ def _register_services_from_yaml_config(
437487
registry.register_artifact_service(scheme, factory)
438488
elif service_type == "memory":
439489
registry.register_memory_service(scheme, factory)
490+
elif service_type == "task_store":
491+
registry._register_task_store_service(scheme, factory)
440492
else:
441493
logger.warning("Unknown service type in YAML: %s", service_type)

0 commit comments

Comments
 (0)