From 036b406f3811721c3e82f91a68e599117b6402ca Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sat, 13 Jun 2026 22:13:10 +0000 Subject: [PATCH 1/2] feat(adk): align stores with ADK 2 contract --- sqlspec/adapters/adbc/adk/store.py | 694 ++++- sqlspec/adapters/aiomysql/adk/store.py | 819 +++--- sqlspec/adapters/aiosqlite/adk/store.py | 614 ++-- sqlspec/adapters/asyncmy/adk/store.py | 810 +++--- sqlspec/adapters/asyncpg/adk/store.py | 360 ++- .../adapters/cockroach_asyncpg/adk/store.py | 429 ++- .../adapters/cockroach_psycopg/adk/store.py | 1098 +++++--- sqlspec/adapters/duckdb/adk/store.py | 685 +++-- sqlspec/adapters/mysqlconnector/adk/store.py | 1255 +++++---- sqlspec/adapters/oracledb/adk/store.py | 2492 +++++++++++------ sqlspec/adapters/psqlpy/adk/store.py | 573 ++-- sqlspec/adapters/psycopg/adk/store.py | 1123 +++++--- sqlspec/adapters/pymysql/adk/store.py | 1036 ++++--- sqlspec/adapters/spanner/adk/store.py | 596 +++- sqlspec/adapters/sqlite/adk/store.py | 747 +++-- sqlspec/extensions/adk/_config_utils.py | 30 +- sqlspec/extensions/adk/_types.py | 12 +- sqlspec/extensions/adk/converters.py | 45 +- sqlspec/extensions/adk/memory/store.py | 232 +- .../adk/migrations/0001_create_adk_tables.py | 156 +- .../adk/migrations/0002_reset_adk_tables.py | 127 + sqlspec/extensions/adk/service.py | 124 +- sqlspec/extensions/adk/store.py | 881 ++++-- .../adk/test_dialect_integration.py | 237 -- .../extensions/adk/test_dialect_support.py | 25 +- .../adbc/extensions/adk/test_edge_cases.py | 273 -- .../extensions/adk/test_event_operations.py | 396 --- .../adbc/extensions/adk/test_memory_store.py | 34 +- .../extensions/adk/test_owner_id_column.py | 32 +- .../extensions/adk/test_session_operations.py | 184 -- .../aiomysql/extensions/adk/test_store.py | 17 +- .../asyncmy/extensions/adk/test_store.py | 18 +- .../extensions/adk/test_owner_id_column.py | 8 +- .../extensions/adk/test_session_operations.py | 134 - .../adapters/contracts/_adk_cases.py | 39 +- .../adapters/contracts/adk_behaviors.py | 149 +- .../adapters/contracts/conftest.py | 206 +- .../extensions/adk/test_memory_store.py | 46 +- .../duckdb/extensions/adk/test_store.py | 158 +- .../extensions/adk/test_store.py | 18 +- .../oracledb/extensions/adk/test_inmemory.py | 30 +- .../extensions/adk/test_oracle_specific.py | 138 +- .../extensions/adk/test_owner_id_column.py | 4 +- .../spanner/extensions/adk/conftest.py | 6 +- .../spanner/extensions/adk/test_adk_store.py | 77 +- .../extensions/adk/test_memory_store.py | 46 +- .../extensions/adk/test_owner_id_column.py | 92 +- .../test_oracledb/test_oracle_adk_store.py | 18 +- .../adapters/test_psycopg/test_adk_store.py | 92 +- .../adapters/test_spanner/test_adk_store.py | 61 +- .../test_adk/test_config_resolution.py | 84 + .../extensions/test_adk/test_converters.py | 135 +- .../unit/extensions/test_adk/test_service.py | 273 +- .../extensions/test_adk/test_store_config.py | 284 +- .../test_adk/test_store_instantiation.py | 46 + 55 files changed, 11729 insertions(+), 6569 deletions(-) create mode 100644 sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py delete mode 100644 tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py delete mode 100644 tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py delete mode 100644 tests/integration/adapters/adbc/extensions/adk/test_event_operations.py delete mode 100644 tests/integration/adapters/adbc/extensions/adk/test_session_operations.py delete mode 100644 tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py create mode 100644 tests/unit/extensions/test_adk/test_config_resolution.py diff --git a/sqlspec/adapters/adbc/adk/store.py b/sqlspec/adapters/adbc/adk/store.py index ed38c7f1e..75ebe794c 100644 --- a/sqlspec/adapters/adbc/adk/store.py +++ b/sqlspec/adapters/adbc/adk/store.py @@ -1,14 +1,14 @@ """ADBC ADK store for Google Agent Development Kit session/event storage.""" import contextlib +import re from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from sqlspec.adapters.adbc.config import AdbcConfig @@ -25,25 +25,31 @@ DIALECT_SNOWFLAKE: Final = "snowflake" DIALECT_GENERIC: Final = "generic" -ADBC_TABLE_NOT_FOUND_PATTERNS: Final = ("no such table", "table or view does not exist", "relation does not exist") +ADBC_TABLE_NOT_FOUND_PATTERNS: Final = ( + "no such table", + "table or view does not exist", + "relation does not exist", + "does not exist", + "table with name", +) -class AdbcADKStore(BaseAsyncADKStore["AdbcConfig"]): +class AdbcADKStore(BaseSyncADKStore["AdbcConfig"]): """ADBC synchronous ADK store for Arrow Database Connectivity. Implements session and event storage for Google Agent Development Kit using ADBC. ADBC provides a vendor-neutral API with Arrow-native data transfer across multiple databases (PostgreSQL, SQLite, DuckDB, etc.). - Events use the new 5-column contract: session_id, invocation_id, author, - timestamp, and event_json. The full ADK Event payload is stored as a - single JSON blob in event_json using a dialect-appropriate column type + Events use the clean-break contract: id, session_id, invocation_id, + timestamp, and event_data. The full ADK Event payload is stored as a + single JSON blob in event_data using a dialect-appropriate column type (JSONB for PostgreSQL, JSON for DuckDB, VARIANT for Snowflake, TEXT for SQLite and generic fallback). Provides: - Session state management with JSON serialization - - Event history tracking via single event_json blob + - Event history tracking via single event_data blob - Atomic event insert + session state update - Timezone-aware timestamps - Foreign key constraints with cascade delete @@ -64,7 +70,97 @@ def __init__(self, config: "AdbcConfig") -> None: super().__init__(config) self._dialect = self._detect_dialect() - @property + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return self._create_session(session_id, app_name, user_id, state, owner_id) + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return self._get_session(app_name, user_id, session_id, renew_for=renew_for) + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + self._update_session_state(app_name, user_id, session_id, state) + + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return self._list_sessions(app_name, user_id) + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and associated events.""" + self._delete_session(app_name, user_id, session_id) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + self._append_event(event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return self._append_event_and_update_state( + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session.""" + return self._get_events(app_name, user_id, session_id, after_timestamp, limit) + + def delete_expired_events(self, before: datetime) -> int: + """Delete events older than a timestamp.""" + return self._delete_expired_events(before) + + def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions older than a timestamp.""" + return self._delete_idle_sessions(updated_before) + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state.""" + return self._get_app_state(app_name) + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state.""" + return self._get_user_state(app_name, user_id) + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state.""" + self._upsert_app_state(app_name, state) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state.""" + self._upsert_user_state(app_name, user_id, state) + + def get_metadata(self, key: str) -> "str | None": + """Return a metadata value.""" + return self._get_metadata(key) + + def set_metadata(self, key: str, value: str) -> None: + """Set a metadata value.""" + self._set_metadata(key, value) + def dialect(self) -> str: """Return the detected database dialect.""" return self._dialect @@ -130,6 +226,22 @@ def _serialize_json_field(self, value: Any) -> "str | None": return None return to_json(value) + def _json_storage_type(self) -> str: + if self._dialect == DIALECT_POSTGRESQL: + return "JSONB" + if self._dialect == DIALECT_DUCKDB: + return "JSON" + if self._dialect == DIALECT_SNOWFLAKE: + return "VARIANT" + return "TEXT" + + def _timestamp_storage_type(self) -> str: + if self._dialect == DIALECT_POSTGRESQL: + return "TIMESTAMPTZ" + if self._dialect == DIALECT_SNOWFLAKE: + return "TIMESTAMP_TZ" + return "TIMESTAMP" + def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": """Deserialize optional JSON field from database. @@ -143,7 +255,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return None return from_json(str(data)) # type: ignore[no-any-return] - async def _get_create_sessions_table_sql(self) -> str: + def _get_create_sessions_table_sql(self) -> str: """Get CREATE TABLE SQL for sessions with dialect dispatch. Returns: @@ -249,7 +361,7 @@ def _get_sessions_ddl_generic(self) -> str: ) """ - async def _get_create_events_table_sql(self) -> str: + def _get_create_events_table_sql(self) -> str: """Get CREATE TABLE SQL for events with dialect dispatch. Returns: @@ -273,11 +385,11 @@ def _get_events_ddl_postgresql(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -290,11 +402,11 @@ def _get_events_ddl_sqlite(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id TEXT PRIMARY KEY, session_id TEXT NOT NULL, - invocation_id TEXT NOT NULL, - author TEXT NOT NULL, + invocation_id TEXT, timestamp REAL NOT NULL, - event_json TEXT NOT NULL, + event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ @@ -307,12 +419,12 @@ def _get_events_ddl_duckdb(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSON NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + event_data JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ) """ @@ -324,11 +436,11 @@ def _get_events_ddl_snowflake(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, - invocation_id VARCHAR NOT NULL, - author VARCHAR NOT NULL, + invocation_id VARCHAR, timestamp TIMESTAMP_TZ NOT NULL DEFAULT CURRENT_TIMESTAMP(), - event_json VARIANT NOT NULL, + event_data VARIANT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ) """ @@ -341,22 +453,84 @@ def _get_events_ddl_generic(self) -> str: """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json TEXT NOT NULL, + event_data TEXT NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) """ + def _get_create_app_states_table_sql(self) -> str: + json_type = self._json_storage_type() + timestamp_type = self._timestamp_storage_type() + default = "DEFAULT CURRENT_TIMESTAMP()" if self._dialect == DIALECT_SNOWFLAKE else "DEFAULT CURRENT_TIMESTAMP" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state {json_type} NOT NULL, + update_time {timestamp_type} NOT NULL {default} + ) + """ + + def _get_create_user_states_table_sql(self) -> str: + json_type = self._json_storage_type() + timestamp_type = self._timestamp_storage_type() + default = "DEFAULT CURRENT_TIMESTAMP()" if self._dialect == DIALECT_SNOWFLAKE else "DEFAULT CURRENT_TIMESTAMP" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state {json_type} NOT NULL, + update_time {timestamp_type} NOT NULL {default}, + PRIMARY KEY (app_name, user_id) + ) + """ + + def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) + """ + + def _get_seed_metadata_sql(self) -> str: + if self._dialect in {DIALECT_POSTGRESQL, DIALECT_SQLITE, DIALECT_DUCKDB}: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT(key) DO NOTHING + """ + return f""" + INSERT INTO {self._metadata_table} (key, value) + SELECT 'schema_version', '1' + WHERE NOT EXISTS (SELECT 1 FROM {self._metadata_table} WHERE key = 'schema_version') + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": """Get DROP TABLE SQL statements. Returns: List of SQL statements to drop tables and indexes. """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" @@ -365,7 +539,7 @@ def _create_tables(self) -> None: try: self._enable_foreign_keys(cursor, conn) - cursor.execute(run_(self._get_create_sessions_table_sql)()) + cursor.execute(self._get_create_sessions_table_sql()) conn.commit() sessions_idx_app_user = ( @@ -382,7 +556,7 @@ def _create_tables(self) -> None: cursor.execute(sessions_idx_update) conn.commit() - cursor.execute(run_(self._get_create_events_table_sql)()) + cursor.execute(self._get_create_events_table_sql()) conn.commit() events_idx = ( @@ -391,13 +565,21 @@ def _create_tables(self) -> None: ) cursor.execute(events_idx) conn.commit() + + cursor.execute(self._get_create_app_states_table_sql()) + conn.commit() + + cursor.execute(self._get_create_user_states_table_sql()) + conn.commit() + + cursor.execute(self._get_create_metadata_table_sql()) + conn.commit() + + cursor.execute(self._get_seed_metadata_sql()) + conn.commit() finally: cursor.close() - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None: """Enable foreign key constraints for SQLite. @@ -405,12 +587,37 @@ def _enable_foreign_keys(self, cursor: Any, conn: Any) -> None: cursor: Database cursor. conn: Database connection. """ + if self._dialect != DIALECT_SQLITE: + return try: cursor.execute("PRAGMA foreign_keys = ON") conn.commit() except Exception: logger.debug("Foreign key enforcement not supported or already enabled") + def _format_sql(self, sql: str) -> str: + """Return SQL with dialect-appropriate positional placeholders.""" + if self._dialect != DIALECT_POSTGRESQL: + return sql + index = 0 + + def replace_placeholder(_match: Any) -> str: + nonlocal index + index += 1 + return f"${index}" + + return re.sub(r"\?", replace_placeholder, sql) + + def _execute(self, cursor: Any, sql: str, params: "tuple[Any, ...] | list[Any]") -> Any: + """Execute parameterized SQL using the current ADBC dialect's placeholder style.""" + return cursor.execute(self._format_sql(sql), params) + + def _json_placeholder(self) -> str: + """Return a JSON parameter placeholder for the current dialect.""" + if self._dialect == DIALECT_POSTGRESQL: + return "?::jsonb" + return "?" + def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: @@ -427,62 +634,84 @@ def _create_session( Created session record. """ state_json = self._serialize_state(state) + state_placeholder = self._json_placeholder() params: tuple[Any, ...] if self._owner_id_column_name: sql = f""" INSERT INTO {self._session_table} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time) - VALUES (?, ?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + VALUES (?, ?, ?, ?, {state_placeholder}, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) """ params = (session_id, app_name, user_id, owner_id, state_json) else: sql = f""" INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) - VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) + VALUES (?, ?, ?, {state_placeholder}, CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) """ params = (session_id, app_name, user_id, state_json) with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, params) + self._execute(cursor, sql, params) conn.commit() finally: cursor.close() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. + renew_for: If positive, touch the session update timestamp. Returns: Session record or None if not found. """ - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = ? - """ + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE app_name = ? AND user_id = ? AND id = ? + """ + params: tuple[Any, ...] = (app_name, user_id, session_id) + select_after_update = True + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = ? AND user_id = ? AND id = ? + """ + params = (app_name, user_id, session_id) + select_after_update = False try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + self._execute(cursor, sql, params) + if select_after_update: + conn.commit() + self._execute( + cursor, + f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = ? AND user_id = ? AND id = ? + """, + params, + ) row = cursor.fetchone() if row is None: @@ -504,57 +733,53 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. state: New state dictionary (replaces existing state). """ state_json = self._serialize_state(state) sql = f""" UPDATE {self._session_table} - SET state = ?, update_time = CURRENT_TIMESTAMP - WHERE id = ? + SET state = {self._json_placeholder()}, update_time = CURRENT_TIMESTAMP + WHERE app_name = ? AND user_id = ? AND id = ? """ with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (state_json, session_id)) + self._execute(cursor, sql, (state_json, app_name, user_id, session_id)) conn.commit() finally: cursor.close() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and all associated events (cascade). Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. """ - sql = f"DELETE FROM {self._session_table} WHERE id = ?" + delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = ?" + delete_session_sql = f"DELETE FROM {self._session_table} WHERE app_name = ? AND user_id = ? AND id = ?" with self._config.provide_connection() as conn: cursor = conn.cursor() try: self._enable_foreign_keys(cursor, conn) - cursor.execute(sql, (session_id,)) + self._execute(cursor, delete_events_sql, (session_id,)) + if self._dialect == DIALECT_DUCKDB: + conn.commit() + self._execute(cursor, delete_session_sql, (app_name, user_id, session_id)) conn.commit() finally: cursor.close() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. @@ -586,7 +811,7 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, params) + self._execute(cursor, sql, params) rows = cursor.fetchall() return [ @@ -608,34 +833,31 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _insert_event(self, event_record: "EventRecord") -> None: """Insert an event record into the events table. Args: event_record: Event record to store. """ - event_json = self._serialize_json_field(event_record["event_json"]) + event_data = self._serialize_json_field(event_record["event_data"]) sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (?, ?, ?, ?, ?) + id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, {self._json_placeholder()}) """ with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute( + self._execute( + cursor, sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_json, + event_data, ), ) conn.commit() @@ -643,7 +865,15 @@ def _insert_event(self, event_record: "EventRecord") -> None: cursor.close() def _append_event_and_update_state( - self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + self, + event_record: "EventRecord", + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically insert an event and update the session's durable state. @@ -655,43 +885,72 @@ def _append_event_and_update_state( Args: event_record: Event record to store. + app_name: Application name. + user_id: User identifier. session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot (``temp:`` keys already stripped by the service layer). + app_state: Optional app-scoped state snapshot. + user_state: Optional user-scoped state snapshot. """ insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (?, ?, ?, ?, ?) + id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, {self._json_placeholder()}) """ update_sql = f""" UPDATE {self._session_table} - SET state = ?, update_time = CURRENT_TIMESTAMP - WHERE id = ? + SET state = {self._json_placeholder()}, update_time = CURRENT_TIMESTAMP + WHERE app_name = ? AND user_id = ? AND id = ? """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? + """ + delete_app_state_sql = f"DELETE FROM {self._app_state_table} WHERE app_name = ?" + insert_app_state_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, {self._json_placeholder()}, ?) + """ + delete_user_state_sql = f"DELETE FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + insert_user_state_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, {self._json_placeholder()}, ?) """ state_json = self._serialize_state(state) - event_json = self._serialize_json_field(event_record["event_json"]) + event_data = self._serialize_json_field(event_record["event_data"]) with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute( + self._execute( + cursor, insert_sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_json, + event_data, ), ) - cursor.execute(update_sql, (state_json, session_id)) - cursor.execute(select_sql, (session_id,)) + self._execute(cursor, update_sql, (state_json, app_name, user_id, session_id)) + if app_state is not None: + self._execute(cursor, delete_app_state_sql, (app_name,)) + self._execute( + cursor, + insert_app_state_sql, + (app_name, self._serialize_state(app_state), datetime.now(timezone.utc)), + ) + if user_state is not None: + self._execute(cursor, delete_user_state_sql, (app_name, user_id)) + self._execute( + cursor, + insert_user_state_sql, + (app_name, user_id, self._serialize_state(user_state), datetime.now(timezone.utc)), + ) + self._execute(cursor, select_sql, (app_name, user_id, session_id)) row = cursor.fetchone() conn.commit() except Exception: @@ -714,18 +973,19 @@ def _append_event_and_update_state( update_time=row[5], ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -733,36 +993,42 @@ def _get_events( Returns: List of event records ordered by timestamp ASC. """ - where_clauses = ["session_id = ?"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["s.app_name = ?", "s.user_id = ?", "e.session_id = ?"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > ?") + where_clauses.append("e.timestamp > ?") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" + limit_clause = f" LIMIT {limit}" if limit is not None else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, params) + self._execute(cursor, sql, params) rows = cursor.fetchall() return [ EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], + id=row[0], + session_id=row[1], + invocation_id=row[2] or "", timestamp=row[3], - event_json=self._deserialize_json_field(row[4]) or {}, + event_data=self._deserialize_json_field(row[4]) or {}, + app_name=row[5], + user_id=row[6], ) for row in rows ] @@ -774,22 +1040,156 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: datetime) -> int: + count_sql = f"SELECT COUNT(*) FROM {self._events_table} WHERE timestamp < ?" + delete_sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + self._execute(cursor, count_sql, (before,)) + row = cursor.fetchone() + count = int(row[0]) if row is not None else 0 + self._execute(cursor, delete_sql, (before,)) + conn.commit() + return count + finally: + cursor.close() + except Exception as exc: + error_msg = str(exc).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return 0 + raise + + def _delete_idle_sessions(self, updated_before: datetime) -> int: + count_sql = f"SELECT COUNT(*) FROM {self._session_table} WHERE update_time < ?" + delete_events_sql = f""" + DELETE FROM {self._events_table} + WHERE session_id IN (SELECT id FROM {self._session_table} WHERE update_time < ?) + """ + delete_sessions_sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + self._execute(cursor, count_sql, (updated_before,)) + row = cursor.fetchone() + count = int(row[0]) if row is not None else 0 + self._execute(cursor, delete_events_sql, (updated_before,)) + self._execute(cursor, delete_sessions_sql, (updated_before,)) + conn.commit() + return count + finally: + cursor.close() + except Exception as exc: + error_msg = str(exc).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return 0 + raise + + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = ?" + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + self._execute(cursor, sql, (app_name,)) + row = cursor.fetchone() + return self._deserialize_state(row[0]) if row is not None else None + finally: + cursor.close() + except Exception as exc: + error_msg = str(exc).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + self._execute(cursor, sql, (app_name, user_id)) + row = cursor.fetchone() + return self._deserialize_state(row[0]) if row is not None else None + finally: + cursor.close() + except Exception as exc: + error_msg = str(exc).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + delete_sql = f"DELETE FROM {self._app_state_table} WHERE app_name = ?" + insert_sql = ( + f"INSERT INTO {self._app_state_table} (app_name, state, update_time) " + f"VALUES (?, {self._json_placeholder()}, ?)" + ) + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + self._execute(cursor, delete_sql, (app_name,)) + self._execute(cursor, insert_sql, (app_name, self._serialize_state(state), datetime.now(timezone.utc))) + conn.commit() + finally: + cursor.close() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + delete_sql = f"DELETE FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + insert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, {self._json_placeholder()}, ?) + """ + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + self._execute(cursor, delete_sql, (app_name, user_id)) + self._execute( + cursor, insert_sql, (app_name, user_id, self._serialize_state(state), datetime.now(timezone.utc)) + ) + conn.commit() + finally: + cursor.close() + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = ?" + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + self._execute(cursor, sql, (key,)) + row = cursor.fetchone() + return row[0] if row is not None else None + finally: + cursor.close() + except Exception as exc: + error_msg = str(exc).lower() + if any(pattern in error_msg for pattern in ADBC_TABLE_NOT_FOUND_PATTERNS): + return None + raise + + def _set_metadata(self, key: str, value: str) -> None: + delete_sql = f"DELETE FROM {self._metadata_table} WHERE key = ?" + insert_sql = f"INSERT INTO {self._metadata_table} (key, value) VALUES (?, ?)" + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + self._execute(cursor, delete_sql, (key,)) + self._execute(cursor, insert_sql, (key, value)) + conn.commit() + finally: + cursor.close() def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - -class AdbcADKMemoryStore(BaseAsyncADKMemoryStore["AdbcConfig"]): +class AdbcADKMemoryStore(BaseSyncADKMemoryStore["AdbcConfig"]): """ADBC synchronous ADK memory store for Arrow Database Connectivity.""" __slots__ = ("_dialect",) @@ -802,6 +1202,28 @@ def __init__(self, config: "AdbcConfig") -> None: def dialect(self) -> str: return self._dialect + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + def _detect_dialect(self) -> str: driver_name = self._config.connection_config.get("driver_name", "").lower() if "postgres" in driver_name: @@ -834,7 +1256,7 @@ def _decode_timestamp(self, value: Any) -> datetime: return datetime.fromisoformat(value) return datetime.fromisoformat(str(value)) - async def _get_create_memory_table_sql(self) -> str: + def _get_create_memory_table_sql(self) -> str: if self._dialect == DIALECT_POSTGRESQL: return self._get_memory_ddl_postgresql() if self._dialect == DIALECT_SQLITE: @@ -945,7 +1367,7 @@ def _create_tables(self) -> None: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(run_(self._get_create_memory_table_sql)()) + cursor.execute(self._get_create_memory_table_sql()) conn.commit() idx_app_user = ( @@ -963,10 +1385,6 @@ def _create_tables(self) -> None: finally: cursor.close() - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1073,10 +1491,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1117,12 +1531,6 @@ def _search_entries( return self._rows_to_records(rows) - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB} if use_returning: @@ -1142,10 +1550,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: cutoff = self._encode_timestamp(datetime.now(timezone.utc) - timedelta(days=days)) use_returning = self._dialect in {DIALECT_SQLITE, DIALECT_POSTGRESQL, DIALECT_DUCKDB} @@ -1166,10 +1570,6 @@ def _delete_entries_older_than(self, days: int) -> int: finally: cursor.close() - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: diff --git a/sqlspec/adapters/aiomysql/adk/store.py b/sqlspec/adapters/aiomysql/adk/store.py index df127aaa9..8fdb568e9 100644 --- a/sqlspec/adapters/aiomysql/adk/store.py +++ b/sqlspec/adapters/aiomysql/adk/store.py @@ -11,7 +11,7 @@ from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.aiomysql.config import AiomysqlConfig from sqlspec.extensions.adk import MemoryRecord @@ -23,129 +23,28 @@ class AiomysqlADKStore(BaseAsyncADKStore["AiomysqlConfig"]): - """MySQL/MariaDB ADK store using aiomysql driver. - - Implements session and event storage for Google Agent Development Kit - using MySQL/MariaDB via the aiomysql driver. Provides: - - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) - - Atomic event-append + state-update in one transaction - - Microsecond-precision timestamps - - Foreign key constraints with cascade delete - - Efficient upserts using ON DUPLICATE KEY UPDATE - """ + """MySQL/MariaDB ADK store using aiomysql driver.""" __slots__ = () def __init__(self, config: "AiomysqlConfig") -> None: - """Initialize aiomysql ADK store. - - Args: - config: AiomysqlConfig instance. - """ + """Initialize aiomysql ADK store.""" super().__init__(config) - def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": - """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. - - MySQL ignores inline REFERENCES syntax in column definitions. - This method extracts the column definition and creates a separate - FOREIGN KEY constraint. - - Args: - column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - - Returns: - Tuple of (column_definition, foreign_key_constraint) - """ - references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) - - if not references_match: - return (column_ddl.strip(), "") - - col_def = column_ddl[: references_match.start()].strip() - fk_clause = references_match.group(1).strip() - col_name = col_def.split()[0] - fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" - - return (col_def, fk_constraint) - - async def _get_create_sessions_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. - """ - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - async def _get_create_events_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_json JSON NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get MySQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" + """Create all ADK session tables if they don't exist.""" async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - """ + """Create a new session.""" state_json = to_json(state) params: tuple[Any, ...] @@ -169,21 +68,33 @@ async def create_session( await cursor.execute(sql, params) await conn.commit() - return await self.get_session(session_id) # type: ignore[return-value] - - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result - Args: - session_id: Session identifier. + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by scoped identifiers.""" + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s + """ + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (app_name, user_id, session_id)) + await conn.commit() - Returns: - Session record or None if not found. - """ sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ try: @@ -191,74 +102,35 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": self._config.provide_connection() as conn, AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, ): - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (app_name, user_id, session_id)) row = await cursor.fetchone() if row is None: return None - session_id_val, app_name, user_id, state_json, create_time, update_time = row - - return SessionRecord( - id=session_id_val, - app_name=app_name, - user_id=user_id, - state=from_json(state_json) if isinstance(state_json, str) else state_json, - create_time=create_time, - update_time=update_time, - ) - except pymysql.err.ProgrammingError as e: - if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return _session_record_from_row(row) + except pymysql.err.ProgrammingError as exc: + if _is_mysql_table_missing(exc): return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - """ - state_json = to_json(state) - + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s - """ - - async with ( - self._config.provide_connection() as conn, - AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, - ): - await cursor.execute(sql, (state_json, session_id)) - await conn.commit() - - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ - sql = f"DELETE FROM {self._session_table} WHERE id = %s" async with ( self._config.provide_connection() as conn, AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, ): - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (to_json(state), app_name, user_id, session_id)) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - """ + """List sessions for an app, optionally filtered by user.""" if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -266,7 +138,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis WHERE app_name = %s ORDER BY update_time DESC """ - params: tuple[str, ...] = (app_name,) + params: tuple[Any, ...] = (app_name,) else: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -283,151 +155,129 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis ): await cursor.execute(sql, params) rows = await cursor.fetchall() - - return [ - SessionRecord( - id=row[0], - app_name=row[1], - user_id=row[2], - state=from_json(row[3]) if isinstance(row[3], str) else row[3], - create_time=row[4], - update_time=row[5], - ) - for row in rows - ] - except pymysql.err.ProgrammingError as e: - if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return [_session_record_from_row(row) for row in rows] + except pymysql.err.ProgrammingError as exc: + if _is_mysql_table_missing(exc): return [] raise - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and all associated events.""" + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" - Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_json). - """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (app_name, user_id, session_id)) + await conn.commit() + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ async with ( self._config.provide_connection() as conn, AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, ): - await cursor.execute( - sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) + await cursor.execute(sql, _event_insert_params(event_record)) await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - MySQL doesn't support UPDATE...RETURNING; we follow the UPDATE with a - SELECT inside the same transaction so callers get the refreshed row - in a single round-trip pair (no separate connection acquisition). - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. - """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json - state_json = to_json(state) - + """Atomically append an event and update session + scoped state.""" insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ - update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ - select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ async with ( self._config.provide_connection() as conn, AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, ): - await cursor.execute( - insert_sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) - await cursor.execute(update_sql, (state_json, session_id)) - await cursor.execute(select_sql, (session_id,)) - row = await cursor.fetchone() - await conn.commit() - - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) + try: + await cursor.execute(update_sql, (to_json(state), app_name, user_id, session_id)) + await cursor.execute(select_sql, (app_name, user_id, session_id)) + row = await cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) + await cursor.execute( + insert_sql, + ( + event_record["id"], + app_name, + user_id, + session_id, + event_record["invocation_id"], + event_record["timestamp"], + _json_for_storage(event_record["event_data"]), + ), + ) + if app_state is not None: + await cursor.execute( + _mysql_upsert_app_state_sql(self._app_state_table), (app_name, to_json(app_state)) + ) + if user_state is not None: + await cursor.execute( + _mysql_upsert_user_state_sql(self._user_state_table), (app_name, user_id, to_json(user_state)) + ) + await conn.commit() + except Exception: + await conn.rollback() + raise - state_value = row[3] - return SessionRecord( - id=row[0], - app_name=row[1], - user_id=row[2], - state=from_json(state_value) if isinstance(state_value, str) else state_value, - create_time=row[4], - update_time=row[5], - ) + return _session_record_from_row(row) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + """Get events for a session.""" + if limit == 0: + return [] + where_clauses = ["app_name = %s", "user_id = %s", "session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: where_clauses.append("timestamp > %s") params.append(after_timestamp) - - where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" + limit_clause = "" + if limit is not None: + limit_clause = " LIMIT %s" + params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data FROM {self._events_table} - WHERE {where_clause} + WHERE {" AND ".join(where_clauses)} ORDER BY timestamp ASC{limit_clause} """ @@ -438,41 +288,171 @@ async def get_events( ): await cursor.execute(sql, params) rows = await cursor.fetchall() - - return [ - EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], - timestamp=row[3], - event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], - ) - for row in rows - ] - except pymysql.err.ProgrammingError as e: - if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return [_event_record_from_row(row) for row in rows] + except pymysql.err.ProgrammingError as exc: + if _is_mysql_table_missing(exc): return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" -def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": - """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except pymysql.err.ProgrammingError as exc: + if _is_mysql_table_missing(exc): + return 0 + raise - Args: - column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE". + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the threshold.""" + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" - Returns: - Tuple of (column_definition, foreign_key_constraint). - """ - references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) - if not references_match: - return (column_ddl.strip(), "") + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (updated_before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except pymysql.err.ProgrammingError as exc: + if _is_mysql_table_missing(exc): + return 0 + raise - col_def = column_ddl[: references_match.start()].strip() - fk_clause = references_match.group(1).strip() - col_name = col_def.split()[0] - fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" - return (col_def, fk_constraint) + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" + + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (app_name,)) + row = await cursor.fetchone() + return _json_dict(row[0]) if row is not None else None + except pymysql.err.ProgrammingError as exc: + if _is_mysql_table_missing(exc): + return None + raise + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" + + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (app_name, user_id)) + row = await cursor.fetchone() + return _json_dict(row[0]) if row is not None else None + except pymysql.err.ProgrammingError as exc: + if _is_mysql_table_missing(exc): + return None + raise + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(_mysql_upsert_app_state_sql(self._app_state_table), (app_name, to_json(state))) + await conn.commit() + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute( + _mysql_upsert_user_state_sql(self._user_state_table), (app_name, user_id, to_json(state)) + ) + await conn.commit() + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" + + try: + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(sql, (key,)) + row = await cursor.fetchone() + return str(row[0]) if row is not None else None + except pymysql.err.ProgrammingError as exc: + if _is_mysql_table_missing(exc): + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + async with ( + self._config.provide_connection() as conn, + AiomysqlCursor(conn, cursor_class=AiomysqlRawCursor) as cursor, + ): + await cursor.execute(_mysql_upsert_metadata_sql(self._metadata_table), (key, value)) + await conn.commit() + + async def _get_create_sessions_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for sessions.""" + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) + + async def _get_create_events_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for events.""" + return _mysql_events_ddl(self._events_table, self._session_table) + + async def _get_create_app_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for app-scoped state.""" + return _mysql_app_state_ddl(self._app_state_table) + + async def _get_create_user_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for user-scoped state.""" + return _mysql_user_state_ddl(self._user_state_table) + + async def _get_create_metadata_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for ADK metadata.""" + return _mysql_metadata_ddl(self._metadata_table) + + async def _get_seed_metadata_sql(self) -> str: + """Get MySQL metadata seed SQL.""" + return f"INSERT IGNORE INTO {self._metadata_table} (`key`, value) VALUES ('schema_version', '1')" + + def _get_drop_app_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for ADK metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements.""" + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] class AiomysqlADKMemoryStore(BaseAsyncADKMemoryStore["AiomysqlConfig"]): @@ -484,42 +464,6 @@ def __init__(self, config: "AiomysqlConfig") -> None: """Initialize aiomysql memory store.""" super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for memory entries.""" - owner_id_line = "" - fk_constraint = "" - if self._owner_id_column_ddl: - col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_line = f",\n {col_def}" - if fk_def: - fk_constraint = f",\n {fk_def}" - - fts_index = "" - if self._use_fts: - fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content_json JSON NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSON, - inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), - INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get MySQL DROP TABLE SQL statements.""" - return [f"DROP TABLE IF EXISTS {self._memory_table}"] - async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" if not self._enabled: @@ -664,3 +608,206 @@ async def delete_entries_older_than(self, days: int) -> int: await cursor.execute(sql, (days,)) await conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def _get_create_memory_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for memory entries.""" + owner_id_line = "" + fk_constraint = "" + if self._owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) + owner_id_line = f",\n {col_def}" + if fk_def: + fk_constraint = f",\n {fk_def}" + + fts_index = "" + if self._use_fts: + fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + content_json JSON NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSON, + inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), + INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements.""" + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + + +def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": + """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + + Args: + column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE". + + Returns: + Tuple of (column_definition, foreign_key_constraint). + """ + references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) + if not references_match: + return (column_ddl.strip(), "") + + col_def = column_ddl[: references_match.start()].strip() + fk_clause = references_match.group(1).strip() + col_name = col_def.split()[0] + fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" + return (col_def, fk_constraint) + + +def _is_mysql_table_missing(exc: BaseException) -> bool: + args = getattr(exc, "args", ()) + return "doesn't exist" in str(exc) or bool(args and args[0] == MYSQL_TABLE_NOT_FOUND_ERROR) + + +def _json_for_storage(value: Any) -> str: + return value if isinstance(value, str) else to_json(value) + + +def _json_dict(value: Any) -> "dict[str, Any]": + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, (bytes, str)): + return cast("dict[str, Any]", from_json(value)) + return cast("dict[str, Any]", value) + + +def _session_record_from_row(row: Any) -> SessionRecord: + return SessionRecord( + id=row[0], app_name=row[1], user_id=row[2], state=_json_dict(row[3]), create_time=row[4], update_time=row[5] + ) + + +def _event_record_from_row(row: Any) -> EventRecord: + return EventRecord( + id=row[0], + app_name=row[1], + user_id=row[2], + session_id=row[3], + invocation_id=row[4], + timestamp=row[5], + event_data=_json_dict(row[6]), + ) + + +def _event_insert_params(event_record: EventRecord) -> "tuple[Any, ...]": + return ( + event_record["id"], + event_record["app_name"], + event_record["user_id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + _json_for_storage(event_record["event_data"]), + ) + + +def _mysql_sessions_ddl(session_table: str, owner_id_column_ddl: "str | None") -> str: + owner_id_line = "" + fk_constraint = "" + if owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(owner_id_column_ddl) + owner_id_line = f"\n {col_def}," + if fk_def: + fk_constraint = f",\n {fk_def}" + + return f""" + CREATE TABLE IF NOT EXISTS {session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL,{owner_id_line} + state JSON NOT NULL, + create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{session_table}_app_user (app_name, user_id), + INDEX idx_{session_table}_update_time (update_time DESC){fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_events_ddl(events_table: str, session_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {events_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + event_data JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, + INDEX idx_{events_table}_scope (app_name, user_id, session_id, timestamp ASC), + INDEX idx_{events_table}_session (session_id, timestamp ASC) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_app_state_ddl(app_state_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_user_state_ddl(user_state_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_metadata_ddl(metadata_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {metadata_table} ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_upsert_app_state_sql(app_state_table: str) -> str: + return f""" + INSERT INTO {app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + +def _mysql_upsert_user_state_sql(user_state_table: str) -> str: + return f""" + INSERT INTO {user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + +def _mysql_upsert_metadata_sql(metadata_table: str) -> str: + return f""" + INSERT INTO {metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ + + +def _raise_session_not_found(session_id: str) -> None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index 72e5d0d53..5216b6b1e 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -1,7 +1,7 @@ """Aiosqlite async ADK store for Google Agent Development Kit session/event storage.""" import sqlite3 -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final, cast from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord @@ -48,76 +48,16 @@ def __init__(self, config: "AiosqliteConfig") -> None: """ super().__init__(config) - async def _apply_pragmas(self, connection: Any) -> None: - """Apply PRAGMA optimization profile for this connection. - - Args: - connection: Aiosqlite connection. - """ - await connection.execute("PRAGMA foreign_keys = ON") - await connection.execute("PRAGMA cache_size = -64000") - await connection.execute("PRAGMA mmap_size = 30000000") - await connection.execute("PRAGMA journal_size_limit = 67108864") - - async def _get_create_sessions_table_sql(self) -> str: - """Get SQLite CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. - """ - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id TEXT PRIMARY KEY, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL{owner_id_line}, - state TEXT NOT NULL DEFAULT '{{}}', - create_time REAL NOT NULL, - update_time REAL NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - """ - - async def _get_create_events_table_sql(self) -> str: - """Get SQLite CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - invocation_id TEXT, - author TEXT, - timestamp REAL NOT NULL, - event_data TEXT NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get SQLite DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: """Create both sessions and events tables if they don't exist.""" async with self._config.provide_session() as driver: await self._apply_pragmas(driver.connection) await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -162,25 +102,39 @@ async def create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - async def get_session(self, session_id: str) -> "SessionRecord | None": + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. + renew_for: If positive, touch the session update timestamp while reading. Returns: Session record or None if not found. """ + params = (app_name, user_id, session_id) sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ try: async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) - cursor = await conn.execute(sql, (session_id,)) + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f""" + UPDATE {self._session_table} + SET update_time = ? + WHERE app_name = ? AND user_id = ? AND id = ? + """ + await conn.execute(update_sql, (_datetime_to_julian(datetime.now(timezone.utc)), *params)) + await conn.commit() + cursor = await conn.execute(sql, params) row = await cursor.fetchone() if row is None: @@ -199,10 +153,12 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. state: New state dictionary (replaces existing state). """ @@ -212,12 +168,12 @@ async def update_session_state(self, session_id: str, state: "dict[str, Any]") - sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) - await conn.execute(sql, (state_json, now_julian, session_id)) + await conn.execute(sql, (state_json, now_julian, app_name, user_id, session_id)) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -269,36 +225,34 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis return [] raise - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and all associated events (cascade). Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. """ - sql = f"DELETE FROM {self._session_table} WHERE id = ?" + sql = f"DELETE FROM {self._session_table} WHERE app_name = ? AND user_id = ? AND id = ?" async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) - await conn.execute(sql, (session_id,)) + await conn.execute(sql, (app_name, user_id, session_id)) await conn.commit() async def append_event(self, event_record: EventRecord) -> None: """Append an event to a session. Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. + event_record: Event record to store. """ - import uuid - timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - event_data_json = to_json(event_record["event_json"]) - event_id = str(uuid.uuid4()) + event_data_json = to_json(event_record["event_data"]) sql = f""" INSERT INTO {self._events_table} ( - id, session_id, invocation_id, author, timestamp, event_data - ) VALUES (?, ?, ?, ?, ?, ?) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?, ?) """ async with self._config.provide_connection() as conn: @@ -306,10 +260,11 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.execute( sql, ( - event_id, + event_record["id"], + event_record["app_name"], + event_record["user_id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], timestamp_julian, event_data_json, ), @@ -317,7 +272,15 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically append an event and update the session's durable state. @@ -327,47 +290,78 @@ async def append_event_and_update_state( Args: event_record: Event record to store. + app_name: Application name for scoped state. + user_id: User identifier for scoped state. session_id: Session identifier whose state should be updated. state: Post-append durable state snapshot (temp: keys already stripped by the service layer). + app_state: App-scoped state snapshot to upsert when changed. + user_state: User-scoped state snapshot to upsert when changed. """ - import uuid - timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - event_data_json = to_json(event_record["event_json"]) + event_data_json = to_json(event_record["event_data"]) now_julian = _datetime_to_julian(datetime.now(timezone.utc)) state_json = to_json(state) - event_id = str(uuid.uuid4()) insert_sql = f""" INSERT INTO {self._events_table} ( - id, session_id, invocation_id, author, timestamp, event_data - ) VALUES (?, ?, ?, ?, ?, ?) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?, ?) """ update_sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + async with self._config.provide_connection() as conn: await self._apply_pragmas(conn) - await conn.execute( - insert_sql, - ( - event_id, - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - timestamp_julian, - event_data_json, - ), - ) - cursor = await conn.execute(update_sql, (state_json, now_julian, session_id)) - row = await cursor.fetchone() - await conn.commit() + try: + cursor = await conn.execute(update_sql, (state_json, now_julian, app_name, user_id, session_id)) + row = await cursor.fetchone() + if row is not None: + await conn.execute( + insert_sql, + ( + event_record["id"], + app_name, + user_id, + event_record["session_id"], + event_record["invocation_id"], + timestamp_julian, + event_data_json, + ), + ) + if app_state is not None: + await conn.execute(app_upsert_sql, (app_name, to_json(app_state), now_julian)) + if user_state is not None: + await conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now_julian)) + except Exception: + await conn.rollback() + raise + else: + if row is None: + await conn.rollback() + else: + await conn.commit() if row is None: msg = f"Session {session_id} not found during append_event_and_update_state." @@ -383,11 +377,18 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -395,8 +396,11 @@ async def get_events( Returns: List of event records ordered by timestamp ASC. """ - where_clauses = ["session_id = ?"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["app_name = ?", "user_id = ?", "session_id = ?"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: where_clauses.append("timestamp > ?") @@ -406,7 +410,7 @@ async def get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, invocation_id, author, timestamp, event_data + SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -420,11 +424,13 @@ async def get_events( return [ EventRecord( - session_id=row[1], - invocation_id=row[2], - author=row[3], - timestamp=_julian_to_datetime(row[4]), - event_json=from_json(row[5]) if row[5] else {}, + id=row[0], + app_name=row[1], + user_id=row[2], + session_id=row[3], + invocation_id=row[4], + timestamp=_julian_to_datetime(row[5]), + event_data=from_json(row[6]) if row[6] else {}, ) for row in rows ] @@ -433,102 +439,276 @@ async def get_events( return [] raise + async def delete_expired_events(self, before: datetime) -> int: + """Delete events older than the given timestamp.""" + sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" -class AiosqliteADKMemoryStore(BaseAsyncADKMemoryStore["AiosqliteConfig"]): - """Aiosqlite ADK memory store using asynchronous SQLite driver. + try: + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (_datetime_to_julian(before),)) + deleted_count = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + await conn.commit() + return deleted_count + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return 0 + raise - Implements memory entry storage for Google Agent Development Kit - using SQLite via the asynchronous aiosqlite driver. Provides: - - Session memory storage with JSON as TEXT - - Simple LIKE search (simple strategy) - - Optional FTS5 full-text search (sqlite_fts5 strategy) - - Julian Day timestamps (REAL) for efficient date operations - - Deduplication via event_id unique constraint - - Efficient upserts using INSERT OR IGNORE + async def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time predates the given threshold.""" + sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" - Args: - config: AiosqliteConfig with extension_config["adk"] settings. - """ + try: + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (_datetime_to_julian(updated_before),)) + deleted_count = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + await conn.commit() + return deleted_count + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return 0 + raise - __slots__ = () + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = ?" - def __init__(self, config: "AiosqliteConfig") -> None: - """Initialize Aiosqlite ADK memory store. + try: + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (app_name,)) + row = await cursor.fetchone() + return from_json(row[0]) if row is not None and row[0] else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f""" + SELECT state + FROM {self._user_state_table} + WHERE app_name = ? AND user_id = ? + """ + + try: + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (app_name, user_id)) + row = await cursor.fetchone() + return from_json(row[0]) if row is not None and row[0] else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + await conn.execute(sql, (app_name, to_json(state), _datetime_to_julian(datetime.now(timezone.utc)))) + await conn.commit() + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + await conn.execute( + sql, (app_name, user_id, to_json(state), _datetime_to_julian(datetime.now(timezone.utc))) + ) + await conn.commit() + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = ?" + + try: + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + cursor = await conn.execute(sql, (key,)) + row = await cursor.fetchone() + return str(row[0]) if row is not None else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + """ + + async with self._config.provide_connection() as conn: + await self._apply_pragmas(conn) + await conn.execute(sql, (key, value)) + await conn.commit() + + async def _apply_pragmas(self, connection: Any) -> None: + """Apply PRAGMA optimization profile for this connection. Args: - config: AiosqliteConfig instance. + connection: Aiosqlite connection. """ - super().__init__(config) + await connection.execute("PRAGMA foreign_keys = ON") + await connection.execute("PRAGMA cache_size = -64000") + await connection.execute("PRAGMA mmap_size = 30000000") + await connection.execute("PRAGMA journal_size_limit = 67108864") - async def _get_create_memory_table_sql(self) -> str: - """Get SQLite CREATE TABLE SQL for memory entries. + async def _get_create_sessions_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for sessions. Returns: - SQL statement to create memory table with indexes. + SQL statement to create adk_session table with indexes. """ owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" - fts_table = "" - if self._use_fts: - fts_table = f""" - CREATE VIRTUAL TABLE IF NOT EXISTS {self._memory_table}_fts USING fts5( - content_text, - content={self._memory_table}, - content_rowid=rowid + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id TEXT PRIMARY KEY, + app_name TEXT NOT NULL, + user_id TEXT NOT NULL{owner_id_line}, + state TEXT NOT NULL DEFAULT '{{}}', + create_time REAL NOT NULL, + update_time REAL NOT NULL ); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + """ - CREATE TRIGGER IF NOT EXISTS {self._memory_table}_ai AFTER INSERT ON {self._memory_table} BEGIN - INSERT INTO {self._memory_table}_fts(rowid, content_text) VALUES (new.rowid, new.content_text); - END; - - CREATE TRIGGER IF NOT EXISTS {self._memory_table}_ad AFTER DELETE ON {self._memory_table} BEGIN - INSERT INTO {self._memory_table}_fts({self._memory_table}_fts, rowid, content_text) - VALUES('delete', old.rowid, old.content_text); - END; - - CREATE TRIGGER IF NOT EXISTS {self._memory_table}_au AFTER UPDATE ON {self._memory_table} BEGIN - INSERT INTO {self._memory_table}_fts({self._memory_table}_fts, rowid, content_text) - VALUES('delete', old.rowid, old.content_text); - INSERT INTO {self._memory_table}_fts(rowid, content_text) VALUES (new.rowid, new.content_text); - END; - """ - + async def _get_create_events_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for events.""" return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( + CREATE TABLE IF NOT EXISTS {self._events_table} ( id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, app_name TEXT NOT NULL, user_id TEXT NOT NULL, - event_id TEXT NOT NULL UNIQUE, - author TEXT{owner_id_line}, + session_id TEXT NOT NULL, + invocation_id TEXT, timestamp REAL NOT NULL, - content_json TEXT NOT NULL, - content_text TEXT NOT NULL, - metadata_json TEXT, - inserted_at REAL NOT NULL + event_data TEXT NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(app_name, user_id, session_id, timestamp ASC); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_invocation + ON {self._events_table}(invocation_id); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_timestamp + ON {self._events_table}(timestamp ASC); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); + async def _get_create_app_states_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for app-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name TEXT PRIMARY KEY, + state TEXT NOT NULL DEFAULT '{{}}', + update_time REAL NOT NULL + ); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_table} + async def _get_create_user_states_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for user-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name TEXT NOT NULL, + user_id TEXT NOT NULL, + state TEXT NOT NULL DEFAULT '{{}}', + update_time REAL NOT NULL, + PRIMARY KEY (app_name, user_id) + ); """ - def _get_drop_memory_table_sql(self) -> "list[str]": + async def _get_create_metadata_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for ADK internal metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get SQLite SQL to seed the ADK schema-version metadata row.""" + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT(key) DO NOTHING; + """ + + def _get_drop_app_states_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": """Get SQLite DROP TABLE SQL statements.""" - statements = [f"DROP TABLE IF EXISTS {self._memory_table}"] - if self._use_fts: - statements.extend([ - f"DROP TABLE IF EXISTS {self._memory_table}_fts", - f"DROP TRIGGER IF EXISTS {self._memory_table}_ai", - f"DROP TRIGGER IF EXISTS {self._memory_table}_ad", - f"DROP TRIGGER IF EXISTS {self._memory_table}_au", - ]) - return statements + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class AiosqliteADKMemoryStore(BaseAsyncADKMemoryStore["AiosqliteConfig"]): + """Aiosqlite ADK memory store using asynchronous SQLite driver. + + Implements memory entry storage for Google Agent Development Kit + using SQLite via the asynchronous aiosqlite driver. Provides: + - Session memory storage with JSON as TEXT + - Simple LIKE search (simple strategy) + - Optional FTS5 full-text search (sqlite_fts5 strategy) + - Julian Day timestamps (REAL) for efficient date operations + - Deduplication via event_id unique constraint + - Efficient upserts using INSERT OR IGNORE + + Args: + config: AiosqliteConfig with extension_config["adk"] settings. + """ + + __slots__ = () + + def __init__(self, config: "AiosqliteConfig") -> None: + """Initialize Aiosqlite ADK memory store. + + Args: + config: AiosqliteConfig instance. + """ + super().__init__(config) async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist. @@ -675,6 +855,76 @@ async def delete_entries_older_than(self, days: int) -> int: await conn.commit() return cursor.rowcount + async def _get_create_memory_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for memory entries. + + Returns: + SQL statement to create memory table with indexes. + """ + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_table = "" + if self._use_fts: + fts_table = f""" + CREATE VIRTUAL TABLE IF NOT EXISTS {self._memory_table}_fts USING fts5( + content_text, + content={self._memory_table}, + content_rowid=rowid + ); + + CREATE TRIGGER IF NOT EXISTS {self._memory_table}_ai AFTER INSERT ON {self._memory_table} BEGIN + INSERT INTO {self._memory_table}_fts(rowid, content_text) VALUES (new.rowid, new.content_text); + END; + + CREATE TRIGGER IF NOT EXISTS {self._memory_table}_ad AFTER DELETE ON {self._memory_table} BEGIN + INSERT INTO {self._memory_table}_fts({self._memory_table}_fts, rowid, content_text) + VALUES('delete', old.rowid, old.content_text); + END; + + CREATE TRIGGER IF NOT EXISTS {self._memory_table}_au AFTER UPDATE ON {self._memory_table} BEGIN + INSERT INTO {self._memory_table}_fts({self._memory_table}_fts, rowid, content_text) + VALUES('delete', old.rowid, old.content_text); + INSERT INTO {self._memory_table}_fts(rowid, content_text) VALUES (new.rowid, new.content_text); + END; + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id TEXT PRIMARY KEY, + session_id TEXT NOT NULL, + app_name TEXT NOT NULL, + user_id TEXT NOT NULL, + event_id TEXT NOT NULL UNIQUE, + author TEXT{owner_id_line}, + timestamp REAL NOT NULL, + content_json TEXT NOT NULL, + content_text TEXT NOT NULL, + metadata_json TEXT, + inserted_at REAL NOT NULL + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_table} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get SQLite DROP TABLE SQL statements.""" + statements = [f"DROP TABLE IF EXISTS {self._memory_table}"] + if self._use_fts: + statements.extend([ + f"DROP TABLE IF EXISTS {self._memory_table}_fts", + f"DROP TRIGGER IF EXISTS {self._memory_table}_ai", + f"DROP TRIGGER IF EXISTS {self._memory_table}_ad", + f"DROP TRIGGER IF EXISTS {self._memory_table}_au", + ]) + return statements + def _datetime_to_julian(dt: datetime) -> float: """Convert datetime to Julian Day number for SQLite storage. diff --git a/sqlspec/adapters/asyncmy/adk/store.py b/sqlspec/adapters/asyncmy/adk/store.py index 00339d017..0c1dbe2a8 100644 --- a/sqlspec/adapters/asyncmy/adk/store.py +++ b/sqlspec/adapters/asyncmy/adk/store.py @@ -10,7 +10,7 @@ from sqlspec.utils.serializers import from_json, to_json if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.asyncmy.config import AsyncmyConfig from sqlspec.extensions.adk import MemoryRecord @@ -22,230 +22,98 @@ class AsyncmyADKStore(BaseAsyncADKStore["AsyncmyConfig"]): - """MySQL/MariaDB ADK store using AsyncMy driver. - - Implements session and event storage for Google Agent Development Kit - using MySQL/MariaDB via the AsyncMy driver. Provides: - - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) - - Atomic event-append + state-update in one transaction - - Microsecond-precision timestamps - - Foreign key constraints with cascade delete - - Efficient upserts using ON DUPLICATE KEY UPDATE - """ + """MySQL/MariaDB ADK store using AsyncMy driver.""" __slots__ = () def __init__(self, config: "AsyncmyConfig") -> None: - """Initialize AsyncMy ADK store. - - Args: - config: AsyncmyConfig instance. - """ + """Initialize AsyncMy ADK store.""" super().__init__(config) - def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": - """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. - - MySQL ignores inline REFERENCES syntax in column definitions. - This method extracts the column definition and creates a separate - FOREIGN KEY constraint. - - Args: - column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" - - Returns: - Tuple of (column_definition, foreign_key_constraint) - """ - references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) - - if not references_match: - return (column_ddl.strip(), "") - - col_def = column_ddl[: references_match.start()].strip() - fk_clause = references_match.group(1).strip() - col_name = col_def.split()[0] - fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" - - return (col_def, fk_constraint) - - async def _get_create_sessions_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for sessions. - - Returns: - SQL statement to create adk_sessions table with indexes. - """ - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - async def _get_create_events_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for events. - - Returns: - SQL statement to create adk_events table with indexes. - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_json JSON NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get MySQL DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" + """Create all ADK session tables if they don't exist.""" async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - """ - state_json = to_json(state) - + """Create a new session.""" params: tuple[Any, ...] if self._owner_id_column_name: sql = f""" INSERT INTO {self._session_table} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time) VALUES (%s, %s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) """ - params = (session_id, app_name, user_id, owner_id, state_json) + params = (session_id, app_name, user_id, owner_id, to_json(state)) else: sql = f""" INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) VALUES (%s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) """ - params = (session_id, app_name, user_id, state_json) + params = (session_id, app_name, user_id, to_json(state)) async with self._config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute(sql, params) await conn.commit() - return await self.get_session(session_id) # type: ignore[return-value] - - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - """ - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by scoped identifiers.""" try: async with self._config.provide_connection() as conn, conn.cursor() as cursor: - await cursor.execute(sql, (session_id,)) - row = await cursor.fetchone() - - if row is None: - return None - - session_id_val, app_name, user_id, state_json, create_time, update_time = row - - return SessionRecord( - id=session_id_val, - app_name=app_name, - user_id=user_id, - state=from_json(state_json) if isinstance(state_json, str) else state_json, - create_time=create_time, - update_time=update_time, + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + await cursor.execute( + f""" + UPDATE {self._session_table} + SET update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + ) + await conn.commit() + + await cursor.execute( + f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), ) - except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue][reportAttributeAccessIssue] - if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + row = await cursor.fetchone() + return _session_record_from_row(row) if row is not None else None + except asyncmy.errors.ProgrammingError as exc: # pyright: ignore[reportAttributeAccessIssue] + if _is_mysql_table_missing(exc): return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - """ - state_json = to_json(state) - + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s - """ - - async with self._config.provide_connection() as conn, conn.cursor() as cursor: - await cursor.execute(sql, (state_json, session_id)) - await conn.commit() - - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ - sql = f"DELETE FROM {self._session_table} WHERE id = %s" async with self._config.provide_connection() as conn, conn.cursor() as cursor: - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (to_json(state), app_name, user_id, session_id)) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - """ + """List sessions for an app, optionally filtered by user.""" if user_id is None: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -253,7 +121,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis WHERE app_name = %s ORDER BY update_time DESC """ - params: tuple[str, ...] = (app_name,) + params: tuple[Any, ...] = (app_name,) else: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -267,145 +135,120 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async with self._config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute(sql, params) rows = await cursor.fetchall() - - return [ - SessionRecord( - id=row[0], - app_name=row[1], - user_id=row[2], - state=from_json(row[3]) if isinstance(row[3], str) else row[3], - create_time=row[4], - update_time=row[5], - ) - for row in rows - ] - except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue] - if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return [_session_record_from_row(row) for row in rows] + except asyncmy.errors.ProgrammingError as exc: # pyright: ignore[reportAttributeAccessIssue] + if _is_mysql_table_missing(exc): return [] raise - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and all associated events.""" + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" - Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_json). - """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (app_name, user_id, session_id)) + await conn.commit() + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ async with self._config.provide_connection() as conn, conn.cursor() as cursor: - await cursor.execute( - sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) + await cursor.execute(sql, _event_insert_params(event_record)) await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - MySQL doesn't support UPDATE...RETURNING; we follow the UPDATE with a - SELECT inside the same transaction so callers get the refreshed row - in a single round-trip pair (no separate connection acquisition). - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. - """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json - state_json = to_json(state) - + """Atomically append an event and update session + scoped state.""" insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ - update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ - select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ async with self._config.provide_connection() as conn, conn.cursor() as cursor: - await cursor.execute( - insert_sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) - await cursor.execute(update_sql, (state_json, session_id)) - await cursor.execute(select_sql, (session_id,)) - row = await cursor.fetchone() - await conn.commit() - - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) + try: + await cursor.execute(update_sql, (to_json(state), app_name, user_id, session_id)) + await cursor.execute(select_sql, (app_name, user_id, session_id)) + row = await cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) + await cursor.execute( + insert_sql, + ( + event_record["id"], + app_name, + user_id, + session_id, + event_record["invocation_id"], + event_record["timestamp"], + _json_for_storage(event_record["event_data"]), + ), + ) + if app_state is not None: + await cursor.execute( + _mysql_upsert_app_state_sql(self._app_state_table), (app_name, to_json(app_state)) + ) + if user_state is not None: + await cursor.execute( + _mysql_upsert_user_state_sql(self._user_state_table), (app_name, user_id, to_json(user_state)) + ) + await conn.commit() + except Exception: + await conn.rollback() + raise - state_value = row[3] - return SessionRecord( - id=row[0], - app_name=row[1], - user_id=row[2], - state=from_json(state_value) if isinstance(state_value, str) else state_value, - create_time=row[4], - update_time=row[5], - ) + return _session_record_from_row(row) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + """Get events for a session.""" + if limit == 0: + return [] + where_clauses = ["app_name = %s", "user_id = %s", "session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: where_clauses.append("timestamp > %s") params.append(after_timestamp) - - where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" + limit_clause = "" + if limit is not None: + limit_clause = " LIMIT %s" + params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data FROM {self._events_table} - WHERE {where_clause} + WHERE {" AND ".join(where_clauses)} ORDER BY timestamp ASC{limit_clause} """ @@ -413,41 +256,147 @@ async def get_events( async with self._config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute(sql, params) rows = await cursor.fetchall() - - return [ - EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], - timestamp=row[3], - event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], - ) - for row in rows - ] - except asyncmy.errors.ProgrammingError as e: # pyright: ignore[reportAttributeAccessIssue] - if "doesn't exist" in str(e) or e.args[0] == MYSQL_TABLE_NOT_FOUND_ERROR: + return [_event_record_from_row(row) for row in rows] + except asyncmy.errors.ProgrammingError as exc: # pyright: ignore[reportAttributeAccessIssue] + if _is_mysql_table_missing(exc): return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" -def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": - """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except asyncmy.errors.ProgrammingError as exc: # pyright: ignore[reportAttributeAccessIssue] + if _is_mysql_table_missing(exc): + return 0 + raise - Args: - column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE". + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the threshold.""" + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" - Returns: - Tuple of (column_definition, foreign_key_constraint). - """ - references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) - if not references_match: - return (column_ddl.strip(), "") + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (updated_before,)) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except asyncmy.errors.ProgrammingError as exc: # pyright: ignore[reportAttributeAccessIssue] + if _is_mysql_table_missing(exc): + return 0 + raise - col_def = column_ddl[: references_match.start()].strip() - fk_clause = references_match.group(1).strip() - col_name = col_def.split()[0] - fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" - return (col_def, fk_constraint) + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (app_name,)) + row = await cursor.fetchone() + return _json_dict(row[0]) if row is not None else None + except asyncmy.errors.ProgrammingError as exc: # pyright: ignore[reportAttributeAccessIssue] + if _is_mysql_table_missing(exc): + return None + raise + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (app_name, user_id)) + row = await cursor.fetchone() + return _json_dict(row[0]) if row is not None else None + except asyncmy.errors.ProgrammingError as exc: # pyright: ignore[reportAttributeAccessIssue] + if _is_mysql_table_missing(exc): + return None + raise + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(_mysql_upsert_app_state_sql(self._app_state_table), (app_name, to_json(state))) + await conn.commit() + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute( + _mysql_upsert_user_state_sql(self._user_state_table), (app_name, user_id, to_json(state)) + ) + await conn.commit() + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(sql, (key,)) + row = await cursor.fetchone() + return str(row[0]) if row is not None else None + except asyncmy.errors.ProgrammingError as exc: # pyright: ignore[reportAttributeAccessIssue] + if _is_mysql_table_missing(exc): + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + async with self._config.provide_connection() as conn, conn.cursor() as cursor: + await cursor.execute(_mysql_upsert_metadata_sql(self._metadata_table), (key, value)) + await conn.commit() + + async def _get_create_sessions_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for sessions.""" + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) + + async def _get_create_events_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for events.""" + return _mysql_events_ddl(self._events_table, self._session_table) + + async def _get_create_app_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for app-scoped state.""" + return _mysql_app_state_ddl(self._app_state_table) + + async def _get_create_user_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for user-scoped state.""" + return _mysql_user_state_ddl(self._user_state_table) + + async def _get_create_metadata_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for ADK metadata.""" + return _mysql_metadata_ddl(self._metadata_table) + + async def _get_seed_metadata_sql(self) -> str: + """Get MySQL metadata seed SQL.""" + return f"INSERT IGNORE INTO {self._metadata_table} (`key`, value) VALUES ('schema_version', '1')" + + def _get_drop_app_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for ADK metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements.""" + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] class AsyncmyADKMemoryStore(BaseAsyncADKMemoryStore["AsyncmyConfig"]): @@ -459,42 +408,6 @@ def __init__(self, config: "AsyncmyConfig") -> None: """Initialize AsyncMy memory store.""" super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for memory entries.""" - owner_id_line = "" - fk_constraint = "" - if self._owner_id_column_ddl: - col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_line = f",\n {col_def}" - if fk_def: - fk_constraint = f",\n {fk_def}" - - fts_index = "" - if self._use_fts: - fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content_json JSON NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSON, - inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), - INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get MySQL DROP TABLE SQL statements.""" - return [f"DROP TABLE IF EXISTS {self._memory_table}"] - async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" if not self._enabled: @@ -630,3 +543,206 @@ async def delete_entries_older_than(self, days: int) -> int: await cursor.execute(sql, (days,)) await conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def _get_create_memory_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for memory entries.""" + owner_id_line = "" + fk_constraint = "" + if self._owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) + owner_id_line = f",\n {col_def}" + if fk_def: + fk_constraint = f",\n {fk_def}" + + fts_index = "" + if self._use_fts: + fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + content_json JSON NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSON, + inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), + INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements.""" + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + + +def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": + """Parse owner ID column DDL for MySQL FOREIGN KEY syntax. + + Args: + column_ddl: Column DDL like "tenant_id BIGINT NOT NULL REFERENCES tenants(id) ON DELETE CASCADE". + + Returns: + Tuple of (column_definition, foreign_key_constraint). + """ + references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) + if not references_match: + return (column_ddl.strip(), "") + + col_def = column_ddl[: references_match.start()].strip() + fk_clause = references_match.group(1).strip() + col_name = col_def.split()[0] + fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" + return (col_def, fk_constraint) + + +def _is_mysql_table_missing(exc: BaseException) -> bool: + args = getattr(exc, "args", ()) + return "doesn't exist" in str(exc) or bool(args and args[0] == MYSQL_TABLE_NOT_FOUND_ERROR) + + +def _json_for_storage(value: Any) -> str: + return value if isinstance(value, str) else to_json(value) + + +def _json_dict(value: Any) -> "dict[str, Any]": + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, (bytes, str)): + return cast("dict[str, Any]", from_json(value)) + return cast("dict[str, Any]", value) + + +def _session_record_from_row(row: Any) -> SessionRecord: + return SessionRecord( + id=row[0], app_name=row[1], user_id=row[2], state=_json_dict(row[3]), create_time=row[4], update_time=row[5] + ) + + +def _event_record_from_row(row: Any) -> EventRecord: + return EventRecord( + id=row[0], + app_name=row[1], + user_id=row[2], + session_id=row[3], + invocation_id=row[4], + timestamp=row[5], + event_data=_json_dict(row[6]), + ) + + +def _event_insert_params(event_record: EventRecord) -> "tuple[Any, ...]": + return ( + event_record["id"], + event_record["app_name"], + event_record["user_id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + _json_for_storage(event_record["event_data"]), + ) + + +def _mysql_sessions_ddl(session_table: str, owner_id_column_ddl: "str | None") -> str: + owner_id_line = "" + fk_constraint = "" + if owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(owner_id_column_ddl) + owner_id_line = f"\n {col_def}," + if fk_def: + fk_constraint = f",\n {fk_def}" + + return f""" + CREATE TABLE IF NOT EXISTS {session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL,{owner_id_line} + state JSON NOT NULL, + create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{session_table}_app_user (app_name, user_id), + INDEX idx_{session_table}_update_time (update_time DESC){fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_events_ddl(events_table: str, session_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {events_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + event_data JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, + INDEX idx_{events_table}_scope (app_name, user_id, session_id, timestamp ASC), + INDEX idx_{events_table}_session (session_id, timestamp ASC) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_app_state_ddl(app_state_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_user_state_ddl(user_state_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_metadata_ddl(metadata_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {metadata_table} ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_upsert_app_state_sql(app_state_table: str) -> str: + return f""" + INSERT INTO {app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + +def _mysql_upsert_user_state_sql(user_state_table: str) -> str: + return f""" + INSERT INTO {user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + +def _mysql_upsert_metadata_sql(metadata_table: str) -> str: + return f""" + INSERT INTO {metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ + + +def _raise_session_not_found(session_id: str) -> None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) diff --git a/sqlspec/adapters/asyncpg/adk/store.py b/sqlspec/adapters/asyncpg/adk/store.py index 8c5563b95..0d4d7776c 100644 --- a/sqlspec/adapters/asyncpg/adk/store.py +++ b/sqlspec/adapters/asyncpg/adk/store.py @@ -9,7 +9,7 @@ from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.asyncpg.config import AsyncpgConfig from sqlspec.extensions.adk import MemoryRecord @@ -25,11 +25,11 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via asyncpg. Events are stored as a single JSONB blob - (``event_json``) alongside indexed scalar columns for efficient querying. + (``event_data``) alongside indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Full-fidelity event storage via ``event_json`` JSONB column + - Full-fidelity event storage via ``event_data`` JSONB column - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete @@ -46,54 +46,14 @@ class AsyncpgADKStore(BaseAsyncADKStore[AsyncConfigT]): def __init__(self, config: AsyncConfigT) -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -113,18 +73,34 @@ async def create_session( """ await conn.execute(sql, session_id, app_name, user_id, state) - return await self.get_session(session_id) # type: ignore[return-value] + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = $1 - """ + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE app_name = $1 AND user_id = $2 AND id = $3 + RETURNING id, app_name, user_id, state, create_time, update_time + """ + params = [app_name, user_id, session_id] + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = $1 AND user_id = $2 AND id = $3 + """ + params = [app_name, user_id, session_id] try: async with self._config.provide_connection() as conn: - row = await conn.fetchrow(sql, session_id) + row = await conn.fetchrow(sql, *params) if row is None: return None @@ -140,21 +116,21 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": except asyncpg.exceptions.UndefinedTableError: return None - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 """ async with self._config.provide_connection() as conn: - await conn.execute(sql, state, session_id) + await conn.execute(sql, state, app_name, user_id, session_id) - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = $1" + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = $1 AND user_id = $2 AND id = $3" async with self._config.provide_connection() as conn: - await conn.execute(sql, session_id) + await conn.execute(sql, app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -195,49 +171,74 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ async with self._config.provide_connection() as conn: await conn.execute( sql, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ async with self._config.provide_connection() as conn, conn.transaction(): await conn.execute( insert_sql, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ) - row = await conn.fetchrow(update_sql, state, session_id) - - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) + row = await conn.fetchrow(update_sql, state, app_name, user_id, session_id) + if row is None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) + if app_state is not None: + await conn.execute(app_upsert_sql, app_name, app_state) + if user_state is not None: + await conn.execute(user_upsert_sql, app_name, user_id, user_state) return SessionRecord( id=row["id"], @@ -249,25 +250,34 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = $1"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["s.app_name = $1", "s.user_id = $2", "e.session_id = $3"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append(f"timestamp > ${len(params) + 1}") + where_clauses.append(f"e.timestamp > ${len(params) + 1}") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT ${len(params) + 1}" if limit else "" - if limit: + limit_clause = f" LIMIT ${len(params) + 1}" if limit is not None else "" + if limit is not None: params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -276,17 +286,197 @@ async def get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] except asyncpg.exceptions.UndefinedTableError: return [] + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < $1" + + try: + async with self._config.provide_connection() as conn: + result = await conn.execute(sql, before) + return int(result.split()[-1]) if result else 0 + except asyncpg.exceptions.UndefinedTableError: + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < $1" + + try: + async with self._config.provide_connection() as conn: + result = await conn.execute(sql, updated_before) + return int(result.split()[-1]) if result else 0 + except asyncpg.exceptions.UndefinedTableError: + return 0 + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = $1" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, app_name) + return row["state"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = $1 AND user_id = $2" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, app_name, user_id) + return row["state"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = $1" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, key) + return row["value"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ($1, $2) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, key, value) + + async def _get_create_sessions_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) WITH (fillfactor = 80); + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + class AsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["AsyncpgConfig"]): """PostgreSQL ADK memory store using asyncpg driver. diff --git a/sqlspec/adapters/cockroach_asyncpg/adk/store.py b/sqlspec/adapters/cockroach_asyncpg/adk/store.py index 882621e83..e2ed54458 100644 --- a/sqlspec/adapters/cockroach_asyncpg/adk/store.py +++ b/sqlspec/adapters/cockroach_asyncpg/adk/store.py @@ -9,7 +9,7 @@ from sqlspec.utils.logging import get_logger if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.cockroach_asyncpg.config import CockroachAsyncpgConfig from sqlspec.extensions.adk import MemoryRecord @@ -21,76 +21,21 @@ class CockroachAsyncpgADKStore(BaseAsyncADKStore["CockroachAsyncpgConfig"]): - """CockroachDB ADK store using asyncpg driver. - - Implements session and event storage for Google Agent Development Kit - using CockroachDB via asyncpg in PostgreSQL compatibility mode. - Events are stored as a single JSONB blob (``event_json``) alongside - indexed scalar columns for efficient querying. - - CockroachDB-specific differences from native PostgreSQL: - - No FILLFACTOR (CockroachDB uses different storage engine) - - No BRIN indexes (different physical storage layout) - - GIN/Inverted indexes on JSONB are fully supported (v23.1+) - - Native tsvector/tsquery FTS with GIN is supported (v23.1+) - """ + """CockroachDB ADK store using asyncpg driver.""" __slots__ = () def __init__(self, config: "CockroachAsyncpgConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json - ON {self._events_table} USING GIN (event_json); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -112,22 +57,34 @@ async def create_session( async with self._config.provide_connection() as conn: await conn.execute(sql, *params) - result = await self.get_session(session_id) + result = await self.get_session(app_name, user_id, session_id) if result is None: msg = "Session creation failed" raise RuntimeError(msg) return result - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = $1 - """ + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE app_name = $1 AND user_id = $2 AND id = $3 + RETURNING id, app_name, user_id, state, create_time, update_time + """ + params = (app_name, user_id, session_id) + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = $1 AND user_id = $2 AND id = $3 + """ + params = (app_name, user_id, session_id) try: async with self._config.provide_connection() as conn: - row = await conn.fetchrow(sql, session_id) + row = await conn.fetchrow(sql, *params) if row is None: return None @@ -142,21 +99,15 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": except asyncpg.exceptions.UndefinedTableError: return None - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 """ async with self._config.provide_connection() as conn: - await conn.execute(sql, state, session_id) - - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = $1" - - async with self._config.provide_connection() as conn: - await conn.execute(sql, session_id) + await conn.execute(sql, state, app_name, user_id, session_id) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -194,52 +145,85 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis for row in rows ] + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = $1 AND user_id = $2 AND id = $3" + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, user_id, session_id) + async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ async with self._config.provide_connection() as conn: await conn.execute( sql, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 - RETURNING id, app_name, user_id, state, create_time, update_time + WHERE app_name = $2 AND user_id = $3 AND id = $4 + """ + select_sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = $1 AND user_id = $2 AND id = $3 + """ + app_upsert_sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + """ + user_upsert_sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) """ async with self._config.provide_connection() as conn, conn.transaction(): await conn.execute( insert_sql, + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ) - row = await conn.fetchrow(update_sql, state, session_id) - - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) + result = await conn.execute(update_sql, state, app_name, user_id, session_id) + if result == "UPDATE 0": + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) + if app_state is not None: + await conn.execute(app_upsert_sql, app_name, app_state) + if user_state is not None: + await conn.execute(user_upsert_sql, app_name, user_id, user_state) + row = await conn.fetchrow(select_sql, app_name, user_id, session_id) + if row is None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) return SessionRecord( id=row["id"], @@ -251,25 +235,34 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = $1"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["s.app_name = $1", "s.user_id = $2", "e.session_id = $3"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append(f"timestamp > ${len(params) + 1}") + where_clauses.append(f"e.timestamp > ${len(params) + 1}") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT ${len(params) + 1}" if limit else "" - if limit: + limit_clause = f" LIMIT ${len(params) + 1}" if limit is not None else "" + if limit is not None: params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -280,61 +273,199 @@ async def get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < $1" -class CockroachAsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["CockroachAsyncpgConfig"]): - """CockroachDB ADK memory store using asyncpg driver.""" + try: + async with self._config.provide_connection() as conn: + result = await conn.execute(sql, before) + return int(result.split()[-1]) if result else 0 + except asyncpg.exceptions.UndefinedTableError: + return 0 - __slots__ = () + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < $1" - def __init__(self, config: "CockroachAsyncpgConfig") -> None: - super().__init__(config) + try: + async with self._config.provide_connection() as conn: + result = await conn.execute(sql, updated_before) + return int(result.split()[-1]) if result else 0 + except asyncpg.exceptions.UndefinedTableError: + return 0 - async def _get_create_memory_table_sql(self) -> str: + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = $1" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, app_name) + return row["state"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = $1 AND user_id = $2" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, app_name, user_id) + return row["state"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, app_name, user_id, state) + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = $1" + + try: + async with self._config.provide_connection() as conn: + row = await conn.fetchrow(sql, key) + return row["value"] if row is not None else None + except asyncpg.exceptions.UndefinedTableError: + return None + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + UPSERT INTO {self._metadata_table} (key, value) + VALUES ($1, $2) + """ + + async with self._config.provide_connection() as conn: + await conn.execute(sql, key, value) + + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( + CREATE TABLE IF NOT EXISTS {self._events_table} ( id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data + ON {self._events_table} USING GIN (event_data); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( app_name VARCHAR(128) NOT NULL, user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) ); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING """ - def _get_drop_memory_table_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._memory_table}"] + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class CockroachAsyncpgADKMemoryStore(BaseAsyncADKMemoryStore["CockroachAsyncpgConfig"]): + """CockroachDB ADK memory store using asyncpg driver.""" + + __slots__ = () + + def __init__(self, config: "CockroachAsyncpgConfig") -> None: + super().__init__(config) async def create_tables(self) -> None: if not self._enabled: @@ -463,3 +594,41 @@ async def delete_entries_older_than(self, days: int) -> int: async with self._config.provide_connection() as conn: result = await conn.execute(sql) return int(result.split()[-1]) if result else 0 + + async def _get_create_memory_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {self._memory_table}"] diff --git a/sqlspec/adapters/cockroach_psycopg/adk/store.py b/sqlspec/adapters/cockroach_psycopg/adk/store.py index b9fd44aa7..cadb5a0d7 100644 --- a/sqlspec/adapters/cockroach_psycopg/adk/store.py +++ b/sqlspec/adapters/cockroach_psycopg/adk/store.py @@ -1,18 +1,18 @@ """CockroachDB ADK store for Google Agent Development Kit session/event storage (psycopg).""" -from typing import TYPE_CHECKING, Any, cast +from typing import TYPE_CHECKING, Any, NoReturn, cast from psycopg import errors from psycopg import sql as pg_sql +from psycopg.rows import dict_row from psycopg.types.json import Jsonb -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore from sqlspec.utils.logging import get_logger -from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.cockroach_psycopg.config import CockroachPsycopgAsyncConfig, CockroachPsycopgSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -29,82 +29,26 @@ class CockroachPsycopgAsyncADKStore(BaseAsyncADKStore["CockroachPsycopgAsyncConfig"]): - """CockroachDB ADK store using psycopg async driver. - - Implements session and event storage for Google Agent Development Kit - using CockroachDB via psycopg in PostgreSQL compatibility mode. - Events are stored as a single JSONB blob (``event_json``) alongside - indexed scalar columns for efficient querying. - - CockroachDB-specific differences from native PostgreSQL: - - No FILLFACTOR (CockroachDB uses different storage engine) - - SQL strings require ``.encode()`` for cockroach-psycopg driver - - GIN/Inverted indexes on JSONB are fully supported (v23.1+) - - Native tsvector/tsquery FTS with GIN is supported (v23.1+) - """ + """CockroachDB ADK store using psycopg async driver.""" __slots__ = () def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json - ON {self._events_table} USING GIN (event_json); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = Jsonb(state) - params: tuple[Any, ...] if self._owner_id_column_name: sql = f""" @@ -119,26 +63,36 @@ async def create_session( """ params = (session_id, app_name, user_id, state_json) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), params) await conn.commit() - result = await self.get_session(session_id) + result = await self.get_session(app_name, user_id, session_id) if result is None: msg = "Session creation failed" raise RuntimeError(msg) return result - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE app_name = %s AND user_id = %s AND id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """ + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = %s AND user_id = %s AND id = %s + """ try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (session_id,)) + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (app_name, user_id, session_id)) row = await cur.fetchone() if row is None: @@ -155,22 +109,15 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (Jsonb(state), session_id)) - await conn.commit() - - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" - - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql.encode(), (session_id,)) + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (Jsonb(state), app_name, user_id, session_id)) await conn.commit() async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": @@ -192,41 +139,47 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis params = (app_name, user_id) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), params) rows = await cur.fetchall() - - return [ - SessionRecord( - id=row["id"], - app_name=row["app_name"], - user_id=row["user_id"], - state=row["state"], - create_time=row["create_time"], - update_time=row["update_time"], - ) - for row in rows - ] except errors.UndefinedTable: return [] + return [ + SessionRecord( + id=row["id"], + app_name=row["app_name"], + user_id=row["user_id"], + state=row["state"], + create_time=row["create_time"], + update_time=row["update_time"], + ) + for row in rows + ] + + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (app_name, user_id, session_id)) + await conn.commit() + async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute( sql.encode(), ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), @@ -234,42 +187,63 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ update_sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """ - - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute( - insert_sql.encode(), - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - jsonb_value, - ), - ) - await cur.execute(update_sql.encode(), (Jsonb(state), session_id)) - row = await cur.fetchone() + app_upsert_sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + """ + user_upsert_sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + """ + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + try: + await cur.execute( + insert_sql.encode(), + ( + event_record["id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + jsonb_value, + ), + ) + await cur.execute(update_sql.encode(), (Jsonb(state), app_name, user_id, session_id)) + row = await cur.fetchone() + if row is None: + _raise_missing_session(session_id) + if app_state is not None: + await cur.execute(app_upsert_sql.encode(), (app_name, Jsonb(app_state))) + if user_state is not None: + await cur.execute(user_upsert_sql.encode(), (app_name, user_id, Jsonb(user_state))) + except Exception: + await conn.rollback() + raise await conn.commit() - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) - return SessionRecord( id=row["id"], app_name=row["app_name"], @@ -280,66 +254,339 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = " LIMIT %s" if limit else "" - if limit: + limit_clause = " LIMIT %s" if limit is not None else "" + if limit is not None: params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(sql.encode(), tuple(params)) rows = await cur.fetchall() - - return [ - EventRecord( - session_id=row["session_id"], - invocation_id=row["invocation_id"], - author=row["author"], - timestamp=row["timestamp"], - event_json=row["event_json"], - ) - for row in rows - ] except errors.UndefinedTable: return [] + return [ + EventRecord( + id=row["id"], + session_id=row["session_id"], + invocation_id=row["invocation_id"], + timestamp=row["timestamp"], + event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], + ) + for row in rows + ] + + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (before,)) + await conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (updated_before,)) + await conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (app_name,)) + row = await cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (app_name, user_id)) + row = await cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + """ + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (app_name, Jsonb(state))) + await conn.commit() -class CockroachPsycopgSyncADKStore(BaseAsyncADKStore["CockroachPsycopgSyncConfig"]): - """CockroachDB ADK store using psycopg sync driver. + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + """ - Implements session and event storage for Google Agent Development Kit - using CockroachDB via psycopg in PostgreSQL compatibility mode (sync). - Events are stored as a single JSONB blob (``event_json``) alongside - indexed scalar columns for efficient querying. + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (app_name, user_id, Jsonb(state))) + await conn.commit() - CockroachDB-specific differences from native PostgreSQL: - - No FILLFACTOR (CockroachDB uses different storage engine) - - SQL strings require ``.encode()`` for cockroach-psycopg driver - - GIN/Inverted indexes on JSONB are fully supported (v23.1+) - """ + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = %s" + + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (key,)) + row = await cur.fetchone() + return row["value"] if row is not None else None + except errors.UndefinedTable: + return None + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + UPSERT INTO {self._metadata_table} (key, value) + VALUES (%s, %s) + """ + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(sql.encode(), (key, value)) + await conn.commit() + + async def _get_create_sessions_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data + ON {self._events_table} USING GIN (event_data); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ); + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class CockroachPsycopgSyncADKStore(BaseSyncADKStore["CockroachPsycopgSyncConfig"]): + """CockroachDB ADK store using psycopg sync driver.""" __slots__ = () def __init__(self, config: "CockroachPsycopgSyncConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return self._create_session(session_id, app_name, user_id, state, owner_id) + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return self._get_session(app_name, user_id, session_id, renew_for=renew_for) + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + self._update_session_state(app_name, user_id, session_id, state) + + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return self._list_sessions(app_name, user_id) + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and associated events.""" + self._delete_session(app_name, user_id, session_id) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + self._append_event(event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update session + scoped state.""" + return self._append_event_and_update_state( + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session.""" + return self._get_events(app_name, user_id, session_id, after_timestamp, limit) + + def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return self._delete_expired_events(before) + + def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return self._delete_idle_sessions(updated_before) + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return self._get_app_state(app_name) + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return self._get_user_state(app_name, user_id) + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + self._upsert_app_state(app_name, state) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + self._upsert_user_state(app_name, user_id, state) + + def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return self._get_metadata(key) + + def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + self._set_metadata(key, value) + + def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -365,41 +612,90 @@ async def _get_create_sessions_table_sql(self) -> str: WHERE state != '{{}}'::jsonb; """ - async def _get_create_events_table_sql(self) -> str: + def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_json - ON {self._events_table} USING GIN (event_json); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_event_data + ON {self._events_table} USING GIN (event_data); + """ + + def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + """ + + def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ); + """ + + def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING """ + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_sessions_table_sql)()) - driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_sessions_table_sql()) + driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(self._get_create_app_states_table_sql()) + driver.execute_script(self._get_create_user_states_table_sql()) + driver.execute_script(self._get_create_metadata_table_sql()) + driver.execute_script(self._get_seed_metadata_sql()) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: state_json = Jsonb(state) - params: tuple[Any, ...] if self._owner_id_column_name: sql = f""" @@ -414,32 +710,36 @@ def _create_session( """ params = (session_id, app_name, user_id, state_json) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), params) conn.commit() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Session creation failed" raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE app_name = %s AND user_id = %s AND id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """ + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = %s AND user_id = %s AND id = %s + """ try: - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (session_id,)) + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (app_name, user_id, session_id)) row = cur.fetchone() if row is None: @@ -456,36 +756,24 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (Jsonb(state), session_id)) + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (Jsonb(state), app_name, user_id, session_id)) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - - def _delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(sql.encode(), (session_id,)) + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (app_name, user_id, session_id)) conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" @@ -505,7 +793,7 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses params = (app_name, user_id) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), params) rows = cur.fetchall() @@ -523,46 +811,85 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses except errors.UndefinedTable: return [] - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - - def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - insert_sql = f""" + def _insert_event(self, event_record: EventRecord) -> None: + sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """ - update_sql = f""" - UPDATE {self._session_table} - SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s - RETURNING id, app_name, user_id, state, create_time, update_time - """ + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute( - insert_sql.encode(), + sql.encode(), ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), ) - cur.execute(update_sql.encode(), (Jsonb(state), session_id)) - row = cur.fetchone() conn.commit() - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) + def _append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + insert_sql = f""" + INSERT INTO {self._events_table} ( + id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s) + """ + update_sql = f""" + UPDATE {self._session_table} + SET state = %s, update_time = CURRENT_TIMESTAMP + WHERE app_name = %s AND user_id = %s AND id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """ + app_upsert_sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + """ + user_upsert_sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + """ + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value + + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + try: + cur.execute( + insert_sql.encode(), + ( + event_record["id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + jsonb_value, + ), + ) + cur.execute(update_sql.encode(), (Jsonb(state), app_name, user_id, session_id)) + row = cur.fetchone() + if row is None: + _raise_missing_session(session_id) + if app_state is not None: + cur.execute(app_upsert_sql.encode(), (app_name, Jsonb(app_state))) + if user_state is not None: + cur.execute(user_upsert_sql.encode(), (app_name, user_id, Jsonb(user_state))) + except Exception: + conn.rollback() + raise + conn.commit() return SessionRecord( id=row["id"], @@ -573,134 +900,154 @@ def _append_event_and_update_state( update_time=row["update_time"], ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - - def _insert_event(self, event_record: EventRecord) -> None: - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) - """ - - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute( - sql.encode(), - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - jsonb_value, - ), - ) - conn.commit() - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["s.app_name = %s", "s.user_id = %s", "e.session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append("e.timestamp > %s") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = " LIMIT %s" if limit else "" + limit_clause = " LIMIT %s" if limit is not None else "" + if limit is not None: + params.append(limit) + sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ - if limit: - params.append(limit) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(sql.encode(), tuple(params)) rows = cur.fetchall() return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] except errors.UndefinedTable: return [] - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < %s" - def _append_event(self, event_record: EventRecord) -> None: - """Synchronous implementation of append_event.""" - self._insert_event(event_record) + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (before,)) + conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < %s" + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (updated_before,)) + conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 -class CockroachPsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgAsyncConfig"]): - """CockroachDB ADK memory store using psycopg async driver.""" + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = %s" - __slots__ = () + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (app_name,)) + row = cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None - def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: - super().__init__(config) + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = %s AND user_id = %s" - async def _get_create_memory_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (app_name, user_id)) + row = cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + """ - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (app_name, Jsonb(state))) + conn.commit() - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + UPSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (app_name, user_id, Jsonb(state))) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = %s" + + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (key,)) + row = cur.fetchone() + return row["value"] if row is not None else None + except errors.UndefinedTable: + return None + + def _set_metadata(self, key: str, value: str) -> None: + sql = f""" + UPSERT INTO {self._metadata_table} (key, value) + VALUES (%s, %s) """ - def _get_drop_memory_table_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._memory_table}"] + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(sql.encode(), (key, value)) + conn.commit() + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + self._insert_event(event_record) + + +class CockroachPsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgAsyncConfig"]): + """CockroachDB ADK memory store using psycopg async driver.""" + + __slots__ = () + + def __init__(self, config: "CockroachPsycopgAsyncConfig") -> None: + super().__init__(config) async def create_tables(self) -> None: if not self._enabled: @@ -820,8 +1167,46 @@ async def delete_entries_older_than(self, days: int) -> int: await conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + async def _get_create_memory_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + -class CockroachPsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsycopgSyncConfig"]): +class CockroachPsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["CockroachPsycopgSyncConfig"]): """CockroachDB ADK memory store using psycopg sync driver.""" __slots__ = () @@ -829,7 +1214,29 @@ class CockroachPsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["CockroachPsyco def __init__(self, config: "CockroachPsycopgSyncConfig") -> None: super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + + def _get_create_memory_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -872,11 +1279,7 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_memory_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_memory_table_sql()) def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: @@ -921,10 +1324,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec inserted_count += cur.rowcount return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -966,12 +1365,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -983,10 +1376,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1001,10 +1390,6 @@ def _delete_entries_older_than(self, days: int) -> int: conn.commit() return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _build_insert_params(entry: "MemoryRecord") -> "tuple[object, ...]": return ( @@ -1037,3 +1422,8 @@ def _build_insert_params_with_owner(entry: "MemoryRecord", owner_id: "object | N Jsonb(entry["metadata_json"]) if entry["metadata_json"] is not None else None, entry["inserted_at"], ) + + +def _raise_missing_session(session_id: str) -> NoReturn: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) diff --git a/sqlspec/adapters/duckdb/adk/store.py b/sqlspec/adapters/duckdb/adk/store.py index 3ab9dc68b..b2713e738 100644 --- a/sqlspec/adapters/duckdb/adk/store.py +++ b/sqlspec/adapters/duckdb/adk/store.py @@ -8,14 +8,13 @@ """ import contextlib -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final, cast -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: from sqlspec.adapters.duckdb.config import DuckDBConfig @@ -29,14 +28,14 @@ DUCKDB_TABLE_NOT_FOUND_ERROR: Final = "does not exist" -class DuckdbADKStore(BaseAsyncADKStore["DuckDBConfig"]): +class DuckdbADKStore(BaseSyncADKStore["DuckDBConfig"]): """DuckDB ADK store for Google Agent Development Kit. Implements session and event storage for Google Agent Development Kit - using DuckDB's synchronous driver with async wrappers via ``async_()``. + using DuckDB's synchronous driver with a synchronous public API. Provides: - Session state management with native JSON type - - Event history with single JSON blob (event_json) plus indexed scalars + - Event history with single JSON blob (event_data) plus indexed scalars - Native TIMESTAMPTZ type support - Manual cascade delete (DuckDB has no FK CASCADE) - Columnar storage for analytical queries @@ -55,7 +54,170 @@ def __init__(self, config: "DuckDBConfig") -> None: """ super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: + def create_tables(self) -> None: + """Create both sessions and events tables if they don't exist.""" + self._create_tables() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session. + + Args: + session_id: Unique session identifier. + app_name: Application name. + user_id: User identifier. + state: Initial session state. + owner_id: Optional owner ID value for owner_id_column (if configured). + + Returns: + Created session record. + """ + return self._create_session(session_id, app_name, user_id, state, owner_id) + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + renew_for: If positive, touch the session update timestamp. + + Returns: + Session record or None if not found. + """ + return self._get_session(app_name, user_id, session_id, renew_for=renew_for) + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + state: New state dictionary (replaces existing state). + """ + self._update_session_state(app_name, user_id, session_id, state) + + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app, optionally filtered by user. + + Args: + app_name: Application name. + user_id: User identifier. If None, lists all sessions for the app. + + Returns: + List of session records ordered by update_time DESC. + """ + return self._list_sessions(app_name, user_id) + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and all associated events. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + """ + self._delete_session(app_name, user_id, session_id) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session. + + Args: + event_record: Event record to store. + """ + self._append_event(event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update the session's durable state. + + The event insert and state update succeed together or fail together + within a single DuckDB transaction; the updated SessionRecord is + returned via UPDATE...RETURNING. + + Args: + event_record: Event record to store. + app_name: Application name. + user_id: User identifier. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (``temp:`` keys already + stripped by the service layer). + app_state: Optional app-scoped state snapshot. + user_state: Optional user-scoped state snapshot. + """ + return self._append_event_and_update_state( + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. + """ + return self._get_events(app_name, user_id, session_id, after_timestamp, limit) + + def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than a timestamp.""" + return self._delete_expired_events(before) + + def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions older than a timestamp.""" + return self._delete_idle_sessions(updated_before) + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state.""" + return self._get_app_state(app_name) + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state.""" + return self._get_user_state(app_name, user_id) + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state.""" + self._upsert_app_state(app_name, state) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state.""" + self._upsert_user_state(app_name, user_id, state) + + def get_metadata(self, key: str) -> "str | None": + """Return a metadata value.""" + return self._get_metadata(key) + + def set_metadata(self, key: str, value: str) -> None: + """Set a metadata value.""" + self._set_metadata(key, value) + + def _get_create_sessions_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for sessions. Returns: @@ -78,37 +240,99 @@ async def _get_create_sessions_table_sql(self) -> str: CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time ON {self._session_table}(update_time DESC); """ - async def _get_create_events_table_sql(self) -> str: + def _get_create_events_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for events. Returns: - SQL statement to create adk_events table with indexes. + SQL statement to create adk_event table with indexes. """ return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, - invocation_id VARCHAR NOT NULL, - author VARCHAR NOT NULL, + invocation_id VARCHAR, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ + def _get_create_app_states_table_sql(self) -> str: + """Get DuckDB CREATE TABLE SQL for app-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """ + + def _get_create_user_states_table_sql(self) -> str: + """Get DuckDB CREATE TABLE SQL for user-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR NOT NULL, + user_id VARCHAR NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) + """ + + def _get_create_metadata_table_sql(self) -> str: + """Get DuckDB CREATE TABLE SQL for internal ADK metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR PRIMARY KEY, + value VARCHAR NOT NULL + ) + """ + + def _get_seed_metadata_sql(self) -> str: + """Get DuckDB SQL for seeding the schema metadata row.""" + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT(key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + """Get DuckDB DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get DuckDB DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get DuckDB DROP TABLE SQL for internal ADK metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": """Get DuckDB DROP TABLE SQL statements. Returns: List of SQL statements to drop tables and indexes. """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: """Synchronous implementation of create_tables.""" with self._config.provide_connection() as conn: conn.execute(self.__get_create_sessions_table_sql_sync()) conn.execute(self.__get_create_events_table_sql_sync()) + conn.execute(self._get_create_app_states_table_sql()) + conn.execute(self._get_create_user_states_table_sql()) + conn.execute(self._get_create_metadata_table_sql()) + conn.execute(self._get_seed_metadata_sql()) + conn.commit() def __get_create_sessions_table_sql_sync(self) -> str: """Synchronous version of DDL generation for use in _create_tables.""" @@ -133,20 +357,16 @@ def __get_create_events_table_sql_sync(self) -> str: """Synchronous version of DDL generation for use in _create_tables.""" return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR PRIMARY KEY, session_id VARCHAR NOT NULL, - invocation_id VARCHAR NOT NULL, - author VARCHAR NOT NULL, + invocation_id VARCHAR, timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSON NOT NULL, + event_data JSON NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ); CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session ON {self._events_table}(session_id, timestamp ASC); """ - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" - await async_(self._create_tables)() - def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: @@ -177,34 +397,29 @@ def _create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - """ - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Synchronous implementation of get_session.""" - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = ? - """ + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = ? + WHERE app_name = ? AND user_id = ? AND id = ? + RETURNING id, app_name, user_id, state, create_time, update_time + """ + params: list[Any] = [datetime.now(timezone.utc), app_name, user_id, session_id] + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = ? AND user_id = ? AND id = ? + """ + params = [app_name, user_id, session_id] try: with self._config.provide_connection() as conn: - cursor = conn.execute(sql, (session_id,)) + cursor = conn.execute(sql, params) row = cursor.fetchone() if row is None: @@ -227,18 +442,7 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - """ - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" now = datetime.now(timezone.utc) state_json = to_json(state) @@ -246,40 +450,23 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ with self._config.provide_connection() as conn: - conn.execute(sql, (state_json, now, session_id)) + conn.execute(sql, (state_json, now, app_name, user_id, session_id)) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - """ - await async_(self._update_session_state)(session_id, state) - - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Synchronous implementation of delete_session.""" delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = ?" - delete_session_sql = f"DELETE FROM {self._session_table} WHERE id = ?" + delete_session_sql = f"DELETE FROM {self._session_table} WHERE app_name = ? AND user_id = ? AND id = ?" with self._config.provide_connection() as conn: conn.execute(delete_events_sql, (session_id,)) - conn.execute(delete_session_sql, (session_id,)) + conn.execute(delete_session_sql, (app_name, user_id, session_id)) conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events. - - Args: - session_id: Session identifier. - """ - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": """Synchronous implementation of list_sessions.""" if user_id is None: @@ -320,25 +507,13 @@ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[S return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - """ - return await async_(self._list_sessions)(app_name, user_id) - def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" - event_json_str = to_json(event_record["event_json"]) + event_data_str = to_json(event_record["event_data"]) sql = f""" INSERT INTO {self._events_table} - (session_id, invocation_id, author, timestamp, event_json) + (id, session_id, invocation_id, timestamp, event_data) VALUES (?, ?, ?, ?, ?) """ @@ -346,64 +521,86 @@ def _append_event(self, event_record: EventRecord) -> None: conn.execute( sql, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_json_str, + event_data_str, ), ) conn.commit() - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_json). - """ - await async_(self._append_event)(event_record) - def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Synchronous implementation of append_event_and_update_state.""" now = datetime.now(timezone.utc) state_json = to_json(state) - event_json_str = to_json(event_record["event_json"]) + event_data_str = to_json(event_record["event_data"]) insert_sql = f""" INSERT INTO {self._events_table} - (session_id, invocation_id, author, timestamp, event_json) + (id, session_id, invocation_id, timestamp, event_data) VALUES (?, ?, ?, ?, ?) """ update_sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ with self._config.provide_connection() as conn: - conn.execute( - insert_sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) - cursor = conn.execute(update_sql, (state_json, now, session_id)) - row = cursor.fetchone() - conn.commit() - - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) - + try: + conn.execute("BEGIN TRANSACTION") + conn.execute( + insert_sql, + ( + event_record["id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + event_data_str, + ), + ) + cursor = conn.execute(update_sql, (state_json, now, app_name, user_id, session_id)) + row = cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) + assert row is not None + if app_state is not None: + conn.execute(app_upsert_sql, (app_name, to_json(app_state), now)) + if user_state is not None: + conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now)) + conn.commit() + except Exception: + conn.rollback() + raise + + assert row is not None session_id_val, app_name, user_id, state_data, create_time, update_time = row return SessionRecord( id=session_id_val, @@ -414,42 +611,34 @@ def _append_event_and_update_state( update_time=update_time, ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - The event insert and state update succeed together or fail together - within a single DuckDB transaction; the updated SessionRecord is - returned via UPDATE...RETURNING. - - Args: - event_record: Event record to store (5-key shape). - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). - """ - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Synchronous implementation of get_events.""" - where_clauses = ["session_id = ?"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["s.app_name = ?", "s.user_id = ?", "e.session_id = ?"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > ?") + where_clauses.append("e.timestamp > ?") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" + limit_clause = f" LIMIT {limit}" if limit is not None else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -459,11 +648,13 @@ def _get_events( return [ EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], + id=row[0], + session_id=row[1], + invocation_id=row[2], timestamp=row[3], - event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], + event_data=from_json(row[4]) if isinstance(row[4], str) else row[4], + app_name=row[5], + user_id=row[6], ) for row in rows ] @@ -472,27 +663,122 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session. + def _delete_expired_events(self, before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) FROM {self._events_table} WHERE timestamp < ?" + delete_sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. + try: + with self._config.provide_connection() as conn: + count_row = conn.execute(count_sql, (before,)).fetchone() + count = int(count_row[0]) if count_row is not None else 0 + conn.execute(delete_sql, (before,)) + conn.commit() + return count + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return 0 + raise - Returns: - List of event records ordered by timestamp ASC. + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) FROM {self._session_table} WHERE update_time < ?" + delete_events_sql = f""" + DELETE FROM {self._events_table} + WHERE session_id IN (SELECT id FROM {self._session_table} WHERE update_time < ?) + """ + delete_sessions_sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" + + try: + with self._config.provide_connection() as conn: + count_row = conn.execute(count_sql, (updated_before,)).fetchone() + count = int(count_row[0]) if count_row is not None else 0 + conn.execute(delete_events_sql, (updated_before,)) + conn.execute(delete_sessions_sql, (updated_before,)) + conn.commit() + return count + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return 0 + raise + + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = ?" + try: + with self._config.provide_connection() as conn: + row = conn.execute(sql, (app_name,)).fetchone() + if row is None: + return None + return cast("dict[str, Any]", from_json(row[0]) if isinstance(row[0], str) else row[0]) + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = ? AND user_id = ?" + try: + with self._config.provide_connection() as conn: + row = conn.execute(sql, (app_name, user_id)).fetchone() + if row is None: + return None + return cast("dict[str, Any]", from_json(row[0]) if isinstance(row[0], str) else row[0]) + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + now = datetime.now(timezone.utc) + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + with self._config.provide_connection() as conn: + conn.execute(sql, (app_name, to_json(state), now)) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + now = datetime.now(timezone.utc) + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time """ - return await async_(self._get_events)(session_id, after_timestamp, limit) + with self._config.provide_connection() as conn: + conn.execute(sql, (app_name, user_id, to_json(state), now)) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = ?" + try: + with self._config.provide_connection() as conn: + row = conn.execute(sql, (key,)).fetchone() + return row[0] if row is not None else None + except Exception as e: + if DUCKDB_TABLE_NOT_FOUND_ERROR in str(e): + return None + raise + def _set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + """ + with self._config.provide_connection() as conn: + conn.execute(sql, (key, value)) + conn.commit() -class DuckdbADKMemoryStore(BaseAsyncADKMemoryStore["DuckDBConfig"]): + +class DuckdbADKMemoryStore(BaseSyncADKMemoryStore["DuckDBConfig"]): """DuckDB ADK memory store using synchronous DuckDB driver with async wrappers. Implements memory entry storage for Google Agent Development Kit - using DuckDB's synchronous driver with async wrappers via ``async_()``. + using DuckDB's synchronous driver with a synchronous public API. Provides: - Session memory storage with native JSON type - Simple ILIKE search or BM25 full-text search via FTS extension @@ -515,6 +801,35 @@ def __init__(self, config: "DuckDBConfig") -> None: """ super().__init__(config) + def create_tables(self) -> None: + """Create the memory table and indexes if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication. + + After successful inserts, refreshes the FTS index if FTS is enabled. + """ + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query. + + When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results. + Falls back to ILIKE for simple substring matching. + """ + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + def _ensure_fts_extension(self, conn: Any) -> bool: """Ensure the DuckDB FTS extension is available for this connection.""" with contextlib.suppress(Exception): @@ -560,7 +875,7 @@ def _refresh_fts_index(self, conn: Any) -> None: except Exception as exc: logger.debug("Failed to refresh DuckDB FTS index: %s", exc) - async def _get_create_memory_table_sql(self) -> str: + def _get_create_memory_table_sql(self) -> str: """Get DuckDB CREATE TABLE SQL for memory entries. Returns: @@ -635,10 +950,6 @@ def __get_create_memory_table_sql_sync(self) -> str: ON {self._memory_table}(session_id); """ - async def create_tables(self) -> None: - """Create the memory table and indexes if they don't exist.""" - await async_(self._create_tables)() - def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Synchronous implementation of insert_memory_entries.""" if not self._enabled: @@ -709,13 +1020,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication. - - After successful inserts, refreshes the FTS index if FTS is enabled. - """ - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -771,16 +1075,6 @@ def _search_entries( records.append(record) return records - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query. - - When FTS is enabled, uses ``match_bm25()`` for BM25-ranked results. - Falls back to ILIKE for simple substring matching. - """ - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: """Synchronous implementation of delete_entries_by_session.""" if not self._enabled: @@ -796,10 +1090,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: self._refresh_fts_index(conn) return deleted_count - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: """Synchronous implementation of delete_entries_older_than.""" if not self._enabled: @@ -819,6 +1109,7 @@ def _delete_entries_older_than(self, days: int) -> int: self._refresh_fts_index(conn) return deleted_count - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) + +def _raise_session_not_found(session_id: str) -> None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) diff --git a/sqlspec/adapters/mysqlconnector/adk/store.py b/sqlspec/adapters/mysqlconnector/adk/store.py index dfeb4f549..54b1408ff 100644 --- a/sqlspec/adapters/mysqlconnector/adk/store.py +++ b/sqlspec/adapters/mysqlconnector/adk/store.py @@ -5,13 +5,12 @@ import mysql.connector -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.mysqlconnector.config import MysqlConnectorAsyncConfig, MysqlConnectorSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -28,56 +27,38 @@ class MysqlConnectorAsyncADKStore(BaseAsyncADKStore["MysqlConnectorAsyncConfig"]): - """MySQL/MariaDB ADK store using mysql-connector async driver. - - Provides: - - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) - - Atomic event-append + state-update in one transaction - - Microsecond-precision timestamps - - Foreign key constraints with cascade delete - """ + """MySQL/MariaDB ADK store using mysql-connector async driver.""" __slots__ = () def __init__(self, config: "MysqlConnectorAsyncConfig") -> None: super().__init__(config) - def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": - return _parse_owner_id_column_for_mysql(column_ddl) - - async def _get_create_sessions_table_sql(self) -> str: - return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) - - async def _get_create_events_table_sql(self) -> str: - return _mysql_events_ddl(self._events_table, self._session_table) - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - state_json = to_json(state) - params: tuple[Any, ...] if self._owner_id_column_name: sql = f""" INSERT INTO {self._session_table} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time) VALUES (%s, %s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) """ - params = (session_id, app_name, user_id, owner_id, state_json) + params = (session_id, app_name, user_id, owner_id, to_json(state)) else: sql = f""" INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) VALUES (%s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) """ - params = (session_id, app_name, user_id, state_json) + params = (session_id, app_name, user_id, to_json(state)) async with self._config.provide_connection() as conn: cursor = await conn.cursor() @@ -87,66 +68,58 @@ async def create_session( await cursor.close() await conn.commit() - return await self.get_session(session_id) # type: ignore[return-value] - - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": try: async with self._config.provide_connection() as conn: cursor = await conn.cursor() try: - await cursor.execute(sql, (session_id,)) + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + await cursor.execute( + f""" + UPDATE {self._session_table} + SET update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + ) + await conn.commit() + + await cursor.execute( + f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + ) row = await cursor.fetchone() finally: await cursor.close() - if row is None: - return None - - session_id_val, app_name_val, user_id_val, state_json, create_time_val, update_time_val = row - - return SessionRecord( - id=cast("str", session_id_val), - app_name=cast("str", app_name_val), - user_id=cast("str", user_id_val), - state=from_json(state_json) if isinstance(state_json, str) else cast("dict[str, Any]", state_json), - create_time=cast("datetime", create_time_val), - update_time=cast("datetime", update_time_val), - ) + return _session_record_from_row(row) if row is not None else None except mysql.connector.Error as exc: - if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + if _is_mysql_table_missing(exc): return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - state_json = to_json(state) - + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ - async with self._config.provide_connection() as conn: cursor = await conn.cursor() try: - await cursor.execute(sql, (state_json, session_id)) - finally: - await cursor.close() - await conn.commit() - - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" - - async with self._config.provide_connection() as conn: - cursor = await conn.cursor() - try: - await cursor.execute(sql, (session_id,)) + await cursor.execute(sql, (to_json(state), app_name, user_id, session_id)) finally: await cursor.close() await conn.commit() @@ -159,7 +132,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis WHERE app_name = %s ORDER BY update_time DESC """ - params: tuple[str, ...] = (app_name,) + params: tuple[Any, ...] = (app_name,) else: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -177,153 +150,123 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis rows = await cursor.fetchall() finally: await cursor.close() - - return [ - SessionRecord( - id=cast("str", row[0]), - app_name=cast("str", row[1]), - user_id=cast("str", row[2]), - state=from_json(row[3]) if isinstance(row[3], str) else cast("dict[str, Any]", row[3]), - create_time=cast("datetime", row[4]), - update_time=cast("datetime", row[5]), - ) - for row in rows - ] + return [_session_record_from_row(row) for row in rows] except mysql.connector.Error as exc: - if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + if _is_mysql_table_missing(exc): return [] raise - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record with 5 keys (session_id, invocation_id, - author, timestamp, event_json). - """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (app_name, user_id, session_id)) + finally: + await cursor.close() + await conn.commit() + async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ - async with self._config.provide_connection() as conn: cursor = await conn.cursor() try: - await cursor.execute( - sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) + await cursor.execute(sql, _event_insert_params(event_record)) finally: await cursor.close() await conn.commit() async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a - SELECT inside the same transaction so callers get the refreshed row - without acquiring a second connection. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. - """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json - state_json = to_json(state) - insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ - update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ - select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ - async with self._config.provide_connection() as conn: cursor = await conn.cursor() try: + await cursor.execute(update_sql, (to_json(state), app_name, user_id, session_id)) + await cursor.execute(select_sql, (app_name, user_id, session_id)) + row = await cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) await cursor.execute( insert_sql, ( - event_record["session_id"], + event_record["id"], + app_name, + user_id, + session_id, event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_json_str, + _json_for_storage(event_record["event_data"]), ), ) - await cursor.execute(update_sql, (state_json, session_id)) - await cursor.execute(select_sql, (session_id,)) - row = await cursor.fetchone() + if app_state is not None: + await cursor.execute( + _mysql_upsert_app_state_sql(self._app_state_table), (app_name, to_json(app_state)) + ) + if user_state is not None: + await cursor.execute( + _mysql_upsert_user_state_sql(self._user_state_table), (app_name, user_id, to_json(user_state)) + ) + await conn.commit() + except Exception: + await conn.rollback() + raise finally: await cursor.close() - await conn.commit() - - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) - - state_value = row[3] - return SessionRecord( - id=cast("str", row[0]), - app_name=cast("str", row[1]), - user_id=cast("str", row[2]), - state=from_json(state_value) if isinstance(state_value, str) else cast("dict[str, Any]", state_value), - create_time=cast("datetime", row[4]), - update_time=cast("datetime", row[5]), - ) + return _session_record_from_row(row) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + if limit == 0: + return [] + where_clauses = ["app_name = %s", "user_id = %s", "session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: where_clauses.append("timestamp > %s") params.append(after_timestamp) - - where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT {limit}" if limit else "" + limit_clause = "" + if limit is not None: + limit_clause = " LIMIT %s" + params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data FROM {self._events_table} - WHERE {where_clause} + WHERE {" AND ".join(where_clauses)} ORDER BY timestamp ASC{limit_clause} """ @@ -335,41 +278,54 @@ async def get_events( rows = await cursor.fetchall() finally: await cursor.close() - - return [ - EventRecord( - session_id=cast("str", row[0]), - invocation_id=cast("str", row[1]), - author=cast("str", row[2]), - timestamp=cast("datetime", row[3]), - event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), - ) - for row in rows - ] + return [_event_record_from_row(row) for row in rows] except mysql.connector.Error as exc: - if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + if _is_mysql_table_missing(exc): return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + return await _mysqlconnector_async_delete_by_timestamp(self, self._events_table, "timestamp", before) -class MysqlConnectorSyncADKStore(BaseAsyncADKStore["MysqlConnectorSyncConfig"]): - """MySQL/MariaDB ADK store using mysql-connector sync driver. + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + return await _mysqlconnector_async_delete_by_timestamp(self, self._session_table, "update_time", updated_before) - Provides: - - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) - - Atomic event-create + state-update in one transaction - - Microsecond-precision timestamps - - Foreign key constraints with cascade delete - """ + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + return await _mysqlconnector_async_get_state(self, self._app_state_table, "app_name = %s", (app_name,)) - __slots__ = () + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + return await _mysqlconnector_async_get_state( + self, self._user_state_table, "app_name = %s AND user_id = %s", (app_name, user_id) + ) - def __init__(self, config: "MysqlConnectorSyncConfig") -> None: - super().__init__(config) + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + await _mysqlconnector_async_execute_commit( + self, _mysql_upsert_app_state_sql(self._app_state_table), (app_name, to_json(state)) + ) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + await _mysqlconnector_async_execute_commit( + self, _mysql_upsert_user_state_sql(self._user_state_table), (app_name, user_id, to_json(state)) + ) + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" + try: + async with self._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (key,)) + row = await cursor.fetchone() + finally: + await cursor.close() + return str(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if _is_mysql_table_missing(exc): + return None + raise - def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": - return _parse_owner_id_column_for_mysql(column_ddl) + async def set_metadata(self, key: str, value: str) -> None: + await _mysqlconnector_async_execute_commit(self, _mysql_upsert_metadata_sql(self._metadata_table), (key, value)) async def _get_create_sessions_table_sql(self) -> str: return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) @@ -377,36 +333,161 @@ async def _get_create_sessions_table_sql(self) -> str: async def _get_create_events_table_sql(self) -> str: return _mysql_events_ddl(self._events_table, self._session_table) + async def _get_create_app_states_table_sql(self) -> str: + return _mysql_app_state_ddl(self._app_state_table) + + async def _get_create_user_states_table_sql(self) -> str: + return _mysql_user_state_ddl(self._user_state_table) + + async def _get_create_metadata_table_sql(self) -> str: + return _mysql_metadata_ddl(self._metadata_table) + + async def _get_seed_metadata_sql(self) -> str: + return f"INSERT IGNORE INTO {self._metadata_table} (`key`, value) VALUES ('schema_version', '1')" + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] - def _create_tables(self) -> None: - with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_sessions_table_sql)()) - driver.execute_script(run_(self._get_create_events_table_sql)()) - async def create_tables(self) -> None: +class MysqlConnectorSyncADKStore(BaseSyncADKStore["MysqlConnectorSyncConfig"]): + """MySQL/MariaDB ADK store using mysql-connector sync driver.""" + + __slots__ = () + + def __init__(self, config: "MysqlConnectorSyncConfig") -> None: + super().__init__(config) + + def create_tables(self) -> None: """Create tables if they don't exist.""" - await async_(self._create_tables)() + self._create_tables() - def _create_session( + def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: - state_json = to_json(state) + """Create a new session.""" + return self._create_session(session_id, app_name, user_id, state, owner_id) + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return self._get_session(app_name, user_id, session_id, renew_for=renew_for) + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + self._update_session_state(app_name, user_id, session_id, state) + + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return self._list_sessions(app_name, user_id) + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and associated events.""" + self._delete_session(app_name, user_id, session_id) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + self._append_event(event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return self._append_event_and_update_state( + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session.""" + return self._get_events(app_name, user_id, session_id, after_timestamp, limit) + + def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return self._delete_expired_events(before) + + def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the threshold.""" + return self._delete_idle_sessions(updated_before) + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return self._get_app_state(app_name) + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return self._get_user_state(app_name, user_id) + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + self._upsert_app_state(app_name, state) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + self._upsert_user_state(app_name, user_id, state) + + def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return self._get_metadata(key) + + def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + self._set_metadata(key, value) + + def _create_tables(self) -> None: + with self._config.provide_session() as driver: + driver.execute_script(self._get_create_sessions_table_sql()) + driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(self._get_create_app_states_table_sql()) + driver.execute_script(self._get_create_user_states_table_sql()) + driver.execute_script(self._get_create_metadata_table_sql()) + driver.execute_script(self._get_seed_metadata_sql()) + + def _create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: params: tuple[Any, ...] if self._owner_id_column_name: sql = f""" INSERT INTO {self._session_table} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time) VALUES (%s, %s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) """ - params = (session_id, app_name, user_id, owner_id, state_json) + params = (session_id, app_name, user_id, owner_id, to_json(state)) else: sql = f""" INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) VALUES (%s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) """ - params = (session_id, app_name, user_id, state_json) + params = (session_id, app_name, user_id, to_json(state)) with self._config.provide_connection() as conn: cursor = conn.cursor() @@ -416,92 +497,61 @@ def _create_session( cursor.close() conn.commit() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ - + def _get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": try: with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + cursor.execute( + f""" + UPDATE {self._session_table} + SET update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + ) + conn.commit() + + cursor.execute( + f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + ) row = cursor.fetchone() finally: cursor.close() - - if row is None: - return None - - session_id_val, app_name_val, user_id_val, state_json, create_time_val, update_time_val = row - - return SessionRecord( - id=cast("str", session_id_val), - app_name=cast("str", app_name_val), - user_id=cast("str", user_id_val), - state=from_json(state_json) if isinstance(state_json, str) else cast("dict[str, Any]", state_json), - create_time=cast("datetime", create_time_val), - update_time=cast("datetime", update_time_val), - ) + return _session_record_from_row(row) if row is not None else None except mysql.connector.Error as exc: - if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + if _is_mysql_table_missing(exc): return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - state_json = to_json(state) - + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ - - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute(sql, (state_json, session_id)) - finally: - cursor.close() - conn.commit() - - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - - def _delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" - with self._config.provide_connection() as conn: cursor = conn.cursor() try: - cursor.execute(sql, (session_id,)) + cursor.execute(sql, (to_json(state), app_name, user_id, session_id)) finally: cursor.close() conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: sql = f""" @@ -510,7 +560,7 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses WHERE app_name = %s ORDER BY update_time DESC """ - params: tuple[str, ...] = (app_name,) + params: tuple[Any, ...] = (app_name,) else: sql = f""" SELECT id, app_name, user_id, state, create_time, update_time @@ -528,160 +578,125 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses rows = cursor.fetchall() finally: cursor.close() - - return [ - SessionRecord( - id=cast("str", row[0]), - app_name=cast("str", row[1]), - user_id=cast("str", row[2]), - state=from_json(row[3]) if isinstance(row[3], str) else cast("dict[str, Any]", row[3]), - create_time=cast("datetime", row[4]), - update_time=cast("datetime", row[5]), - ) - for row in rows - ] + return [_session_record_from_row(row) for row in rows] except mysql.connector.Error as exc: - if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + if _is_mysql_table_missing(exc): return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - - def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically create an event and update the session's durable state. - - MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a - SELECT inside the same transaction so callers get the refreshed row - without acquiring a second connection. + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, user_id, session_id)) + finally: + cursor.close() + conn.commit() - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. + def _append_event(self, event_record: EventRecord) -> None: + sql = f""" + INSERT INTO {self._events_table} ( + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json - state_json = to_json(state) + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, _event_insert_params(event_record)) + finally: + cursor.close() + conn.commit() + def _append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) """ - update_sql = f""" UPDATE {self._session_table} - SET state = %s - WHERE id = %s + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s """ - select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """ with self._config.provide_connection() as conn: cursor = conn.cursor() try: + cursor.execute(update_sql, (to_json(state), app_name, user_id, session_id)) + cursor.execute(select_sql, (app_name, user_id, session_id)) + row = cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) cursor.execute( insert_sql, ( - event_record["session_id"], + event_record["id"], + app_name, + user_id, + session_id, event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_json_str, + _json_for_storage(event_record["event_data"]), ), ) - cursor.execute(update_sql, (state_json, session_id)) - cursor.execute(select_sql, (session_id,)) - row = cursor.fetchone() + if app_state is not None: + cursor.execute(_mysql_upsert_app_state_sql(self._app_state_table), (app_name, to_json(app_state))) + if user_state is not None: + cursor.execute( + _mysql_upsert_user_state_sql(self._user_state_table), (app_name, user_id, to_json(user_state)) + ) + except Exception: + conn.rollback() + raise finally: cursor.close() conn.commit() - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) - - state_value = row[3] - return SessionRecord( - id=cast("str", row[0]), - app_name=cast("str", row[1]), - user_id=cast("str", row[2]), - state=from_json(state_value) if isinstance(state_value, str) else cast("dict[str, Any]", state_value), - create_time=cast("datetime", row[4]), - update_time=cast("datetime", row[5]), - ) - - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - - def _insert_event(self, event_record: EventRecord) -> None: - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json - - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) - """ - - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute( - sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) - finally: - cursor.close() - conn.commit() + return _session_record_from_row(row) def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - """List events for a session ordered by timestamp. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + if limit == 0: + return [] + where_clauses = ["app_name = %s", "user_id = %s", "session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: where_clauses.append("timestamp > %s") params.append(after_timestamp) + limit_clause = "" + if limit is not None: + limit_clause = " LIMIT %s" + params.append(limit) - where_clause = " AND ".join(where_clauses) - limit_clause = " LIMIT %s" if limit else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json + SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data FROM {self._events_table} - WHERE {where_clause} + WHERE {" AND ".join(where_clauses)} ORDER BY timestamp ASC{limit_clause} """ - if limit: - params.append(limit) try: with self._config.provide_connection() as conn: @@ -691,35 +706,90 @@ def _get_events( rows = cursor.fetchall() finally: cursor.close() - - return [ - EventRecord( - session_id=cast("str", row[0]), - invocation_id=cast("str", row[1]), - author=cast("str", row[2]), - timestamp=cast("datetime", row[3]), - event_json=from_json(row[4]) if isinstance(row[4], str) else cast("dict[str, Any]", row[4]), - ) - for row in rows - ] + return [_event_record_from_row(row) for row in rows] except mysql.connector.Error as exc: - if "doesn't exist" in str(exc) or getattr(exc, "errno", None) == MYSQL_TABLE_NOT_FOUND_ERROR: + if _is_mysql_table_missing(exc): return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + return _mysqlconnector_sync_delete_by_timestamp(self, self._events_table, "timestamp", before) - def _append_event(self, event_record: EventRecord) -> None: - """Synchronous implementation of append_event.""" - self._insert_event(event_record) + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + return _mysqlconnector_sync_delete_by_timestamp(self, self._session_table, "update_time", updated_before) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + return _mysqlconnector_sync_get_state(self, self._app_state_table, "app_name = %s", (app_name,)) + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + return _mysqlconnector_sync_get_state( + self, self._user_state_table, "app_name = %s AND user_id = %s", (app_name, user_id) + ) + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + _mysqlconnector_sync_execute_commit( + self, _mysql_upsert_app_state_sql(self._app_state_table), (app_name, to_json(state)) + ) + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + _mysqlconnector_sync_execute_commit( + self, _mysql_upsert_user_state_sql(self._user_state_table), (app_name, user_id, to_json(state)) + ) + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE `key` = %s" + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key,)) + row = cursor.fetchone() + finally: + cursor.close() + return str(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if _is_mysql_table_missing(exc): + return None + raise + + def _set_metadata(self, key: str, value: str) -> None: + _mysqlconnector_sync_execute_commit(self, _mysql_upsert_metadata_sql(self._metadata_table), (key, value)) + + def _get_create_sessions_table_sql(self) -> str: + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) + + def _get_create_events_table_sql(self) -> str: + return _mysql_events_ddl(self._events_table, self._session_table) + + def _get_create_app_states_table_sql(self) -> str: + return _mysql_app_state_ddl(self._app_state_table) + + def _get_create_user_states_table_sql(self) -> str: + return _mysql_user_state_ddl(self._user_state_table) + + def _get_create_metadata_table_sql(self) -> str: + return _mysql_metadata_ddl(self._metadata_table) + + def _get_seed_metadata_sql(self) -> str: + return f"INSERT IGNORE INTO {self._metadata_table} (`key`, value) VALUES ('schema_version', '1')" + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] class MysqlConnectorAsyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorAsyncConfig"]): @@ -730,40 +800,6 @@ class MysqlConnectorAsyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorA def __init__(self, config: "MysqlConnectorAsyncConfig") -> None: super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - owner_id_line = "" - fk_constraint = "" - if self._owner_id_column_ddl: - col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_line = f",\n {col_def}" - if fk_def: - fk_constraint = f",\n {fk_def}" - - fts_index = "" - if self._use_fts: - fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - content_json JSON NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSON, - inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), - INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._memory_table}"] - async def create_tables(self) -> None: if not self._enabled: return @@ -910,8 +946,42 @@ async def delete_entries_older_than(self, days: int) -> int: finally: await cursor.close() + async def _get_create_memory_table_sql(self) -> str: + owner_id_line = "" + fk_constraint = "" + if self._owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(self._owner_id_column_ddl) + owner_id_line = f",\n {col_def}" + if fk_def: + fk_constraint = f",\n {fk_def}" + + fts_index = "" + if self._use_fts: + fts_index = f",\n FULLTEXT INDEX idx_{self._memory_table}_fts (content_text)" -class MysqlConnectorSyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorSyncConfig"]): + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + content_json JSON NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSON, + inserted_at TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + INDEX idx_{self._memory_table}_app_user_time (app_name, user_id, timestamp), + INDEX idx_{self._memory_table}_session (session_id){fts_index}{fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + + +class MysqlConnectorSyncADKMemoryStore(BaseSyncADKMemoryStore["MysqlConnectorSyncConfig"]): """MySQL/MariaDB ADK memory store using mysql-connector sync driver.""" __slots__ = () @@ -919,7 +989,29 @@ class MysqlConnectorSyncADKMemoryStore(BaseAsyncADKMemoryStore["MysqlConnectorSy def __init__(self, config: "MysqlConnectorSyncConfig") -> None: super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + + def _get_create_memory_table_sql(self) -> str: owner_id_line = "" fk_constraint = "" if self._owner_id_column_ddl: @@ -958,11 +1050,7 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_memory_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_memory_table_sql()) def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: @@ -1030,10 +1118,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec conn.commit() return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -1074,12 +1158,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1095,10 +1173,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -1117,10 +1191,6 @@ def _delete_entries_older_than(self, days: int) -> int: finally: cursor.close() - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": references_match = re.search(r"\s+REFERENCES\s+(.+)", column_ddl, re.IGNORECASE) @@ -1134,42 +1204,255 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": return (col_def, fk_constraint) +def _is_mysql_table_missing(exc: BaseException) -> bool: + args = getattr(exc, "args", ()) + errno = getattr(exc, "errno", None) + return ( + errno == MYSQL_TABLE_NOT_FOUND_ERROR + or "doesn't exist" in str(exc) + or bool(args and args[0] == MYSQL_TABLE_NOT_FOUND_ERROR) + ) + + +def _json_for_storage(value: Any) -> str: + return value if isinstance(value, str) else to_json(value) + + +def _json_dict(value: Any) -> "dict[str, Any]": + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, (bytes, str)): + return cast("dict[str, Any]", from_json(value)) + return cast("dict[str, Any]", value) + + +def _session_record_from_row(row: Any) -> SessionRecord: + return SessionRecord( + id=row[0], app_name=row[1], user_id=row[2], state=_json_dict(row[3]), create_time=row[4], update_time=row[5] + ) + + +def _event_record_from_row(row: Any) -> EventRecord: + return EventRecord( + id=row[0], + app_name=row[1], + user_id=row[2], + session_id=row[3], + invocation_id=row[4], + timestamp=row[5], + event_data=_json_dict(row[6]), + ) + + +def _event_insert_params(event_record: EventRecord) -> "tuple[Any, ...]": + return ( + event_record["id"], + event_record["app_name"], + event_record["user_id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + _json_for_storage(event_record["event_data"]), + ) + + +def _raise_session_not_found(session_id: str) -> None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) + + +async def _mysqlconnector_async_delete_by_timestamp( + store: MysqlConnectorAsyncADKStore, table_name: str, column_name: str, threshold: "datetime" +) -> int: + sql = f"DELETE FROM {table_name} WHERE {column_name} < %s" + try: + async with store._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, (threshold,)) + rowcount = cursor.rowcount + finally: + await cursor.close() + await conn.commit() + except mysql.connector.Error as exc: + if _is_mysql_table_missing(exc): + return 0 + raise + else: + return rowcount if rowcount and rowcount > 0 else 0 + + +async def _mysqlconnector_async_get_state( + store: MysqlConnectorAsyncADKStore, table_name: str, where_clause: str, params: "tuple[Any, ...]" +) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {table_name} WHERE {where_clause} LIMIT 1" + try: + async with store._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, params) + row = await cursor.fetchone() + finally: + await cursor.close() + return _json_dict(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if _is_mysql_table_missing(exc): + return None + raise + + +async def _mysqlconnector_async_execute_commit( + store: MysqlConnectorAsyncADKStore, sql: str, params: "tuple[Any, ...]" +) -> None: + async with store._config.provide_connection() as conn: + cursor = await conn.cursor() + try: + await cursor.execute(sql, params) + finally: + await cursor.close() + await conn.commit() + + +def _mysqlconnector_sync_delete_by_timestamp( + store: MysqlConnectorSyncADKStore, table_name: str, column_name: str, threshold: "datetime" +) -> int: + sql = f"DELETE FROM {table_name} WHERE {column_name} < %s" + try: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (threshold,)) + rowcount = cursor.rowcount + finally: + cursor.close() + conn.commit() + except mysql.connector.Error as exc: + if _is_mysql_table_missing(exc): + return 0 + raise + else: + return rowcount if rowcount and rowcount > 0 else 0 + + +def _mysqlconnector_sync_get_state( + store: MysqlConnectorSyncADKStore, table_name: str, where_clause: str, params: "tuple[Any, ...]" +) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {table_name} WHERE {where_clause} LIMIT 1" + try: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, params) + row = cursor.fetchone() + finally: + cursor.close() + return _json_dict(row[0]) if row is not None else None + except mysql.connector.Error as exc: + if _is_mysql_table_missing(exc): + return None + raise + + +def _mysqlconnector_sync_execute_commit(store: MysqlConnectorSyncADKStore, sql: str, params: "tuple[Any, ...]") -> None: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, params) + finally: + cursor.close() + conn.commit() + + def _mysql_sessions_ddl(session_table: str, owner_id_column_ddl: "str | None") -> str: - """Generate shared MySQL sessions CREATE TABLE DDL.""" - owner_id_col = "" + owner_id_line = "" fk_constraint = "" - if owner_id_column_ddl: col_def, fk_def = _parse_owner_id_column_for_mysql(owner_id_column_ddl) - owner_id_col = f"{col_def}," + owner_id_line = f"\n {col_def}," if fk_def: fk_constraint = f",\n {fk_def}" return f""" - CREATE TABLE IF NOT EXISTS {session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{session_table}_app_user (app_name, user_id), - INDEX idx_{session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + CREATE TABLE IF NOT EXISTS {session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL,{owner_id_line} + state JSON NOT NULL, + create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{session_table}_app_user (app_name, user_id), + INDEX idx_{session_table}_update_time (update_time DESC){fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ def _mysql_events_ddl(events_table: str, session_table: str) -> str: - """Generate shared MySQL events CREATE TABLE DDL (post clean-break, 5 columns).""" return f""" - CREATE TABLE IF NOT EXISTS {events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_json JSON NOT NULL, - FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, - INDEX idx_{events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ + CREATE TABLE IF NOT EXISTS {events_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + event_data JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, + INDEX idx_{events_table}_scope (app_name, user_id, session_id, timestamp ASC), + INDEX idx_{events_table}_session (session_id, timestamp ASC) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_app_state_ddl(app_state_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_user_state_ddl(user_state_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_metadata_ddl(metadata_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {metadata_table} ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_upsert_app_state_sql(app_state_table: str) -> str: + return f""" + INSERT INTO {app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + +def _mysql_upsert_user_state_sql(user_state_table: str) -> str: + return f""" + INSERT INTO {user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + +def _mysql_upsert_metadata_sql(metadata_table: str) -> str: + return f""" + INSERT INTO {metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ diff --git a/sqlspec/adapters/oracledb/adk/store.py b/sqlspec/adapters/oracledb/adk/store.py index 5c6ae6700..ff9a918a6 100644 --- a/sqlspec/adapters/oracledb/adk/store.py +++ b/sqlspec/adapters/oracledb/adk/store.py @@ -2,7 +2,7 @@ from decimal import Decimal from enum import Enum -from typing import TYPE_CHECKING, Any, Final, cast +from typing import TYPE_CHECKING, Any, Final, NoReturn, cast import oracledb @@ -12,15 +12,14 @@ OracledbSyncDataDictionary, OracleVersionInfo, ) -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_, run_ from sqlspec.utils.type_guards import is_async_readable, is_readable if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.oracledb.config import OracleAsyncConfig, OracleSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -37,7 +36,6 @@ logger = get_logger("sqlspec.adapters.oracledb.adk.store") - ORACLE_TABLE_NOT_FOUND_ERROR: Final = 942 ORACLE_MIN_JSON_NATIVE_VERSION: Final = 21 ORACLE_MIN_JSON_NATIVE_COMPATIBLE: Final = 20 @@ -60,6 +58,12 @@ "archive_high": "COLUMN STORE COMPRESS FOR ARCHIVE HIGH", } ORACLE_DUPLICATE_KEY_ERROR: Final = 1 +ORACLE_DEFAULT_SESSION_TABLE: Final = "adk_session" +ORACLE_DEFAULT_EVENTS_TABLE: Final = "adk_event" +ORACLE_DEFAULT_APP_STATE_TABLE: Final = "adk_app_state" +ORACLE_DEFAULT_USER_STATE_TABLE: Final = "adk_user_state" +ORACLE_DEFAULT_METADATA_TABLE: Final = "adk_internal_metadata" +OracleDatabaseError: Final[type[Exception]] = cast("type[Exception]", oracledb.DatabaseError) class JSONStorageType(str, Enum): @@ -84,7 +88,7 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Implements session and event storage for Google Agent Development Kit using Oracle Database via the python-oracledb async driver. Provides: - Session state management with version-specific JSON storage - - Full-fidelity event storage via ``event_json`` column + - Full-fidelity event storage via ``event_data`` column - Atomic ``append_event_and_update_state`` for durable session mutations - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Foreign key constraints with cascade delete @@ -92,6 +96,15 @@ class OracleAsyncADKStore(BaseAsyncADKStore["OracleAsyncConfig"]): Args: config: OracleAsyncConfig with extension_config["adk"] settings. + + Notes: + - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) + - event_data stored as JSON (21c+) or BLOB (older versions) + - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps + - Named parameters using :param_name + - State merging handled at application level + - owner_id_column supports NUMBER, VARCHAR2, RAW for Oracle FK types + - Configuration is read from config.extension_config["adk"] """ __slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info") @@ -101,168 +114,797 @@ def __init__(self, config: "OracleAsyncConfig") -> None: Args: config: OracleAsyncConfig instance. + + Notes: + Configuration is read from config.extension_config["adk"]: + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") + - owner_id_column: Optional owner FK column DDL (default: None) + - in_memory: Enable INMEMORY PRIORITY HIGH clause (default: False) """ super().__init__(config) + _configure_oracle_adk_session_tables(self, config) self._json_storage_type: JSONStorageType | None = None self._oracle_version_info: OracleVersionInfo | None = None adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) - async def _get_create_sessions_table_sql(self) -> str: - """Get Oracle CREATE TABLE SQL for sessions table. + async def create_tables(self) -> None: + """Create both sessions and events tables if they don't exist. - Auto-detects optimal JSON storage type based on Oracle version. - Result is cached to minimize database queries. + Notes: + Detects Oracle version to determine optimal JSON storage type. + Uses version-appropriate table schema. """ storage_type = await self._detect_json_storage_type() - return self._get_create_sessions_table_sql_for_type(storage_type) - - async def _get_create_events_table_sql(self) -> str: - """Get Oracle CREATE TABLE SQL for events table. + logger.debug("Creating ADK tables with storage type: %s", storage_type) - Auto-detects optimal JSON storage type based on Oracle version. - Result is cached to minimize database queries. - """ - storage_type = await self._detect_json_storage_type() - return self._get_create_events_table_sql_for_type(storage_type) + async with self._config.provide_session() as driver: + await driver.execute_script(self._get_create_sessions_table_sql_for_type(storage_type)) - async def _detect_json_storage_type(self) -> JSONStorageType: - """Detect the appropriate JSON storage type based on Oracle version. + await driver.execute_script(self._get_create_events_table_sql_for_type(storage_type)) + await driver.execute_script(self._get_create_app_states_table_sql_for_type(storage_type)) + await driver.execute_script(self._get_create_user_states_table_sql_for_type(storage_type)) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) + await driver.commit() - Returns: - Appropriate JSONStorageType for this Oracle version. - """ - if self._json_storage_type is not None: - return self._json_storage_type + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session. - version_info = await self._get_version_info() - self._json_storage_type = _storage_type_from_version(version_info) - return self._json_storage_type + Args: + session_id: Unique session identifier. + app_name: Application name. + user_id: User identifier. + state: Initial session state. + owner_id: Optional owner ID value for owner_id_column (if configured). - async def _get_version_info(self) -> "OracleVersionInfo | None": - """Return cached Oracle version info using Oracle data dictionary.""" + Returns: + Created session record. - if self._oracle_version_info is not None: - return self._oracle_version_info + Notes: + Uses SYSTIMESTAMP for create_time and update_time. + State is serialized using version-appropriate format. + owner_id is ignored if owner_id_column not configured. + """ + state_data = await self._serialize_state(state) - async with self._config.provide_session() as driver: - dictionary = OracledbAsyncDataDictionary() - self._oracle_version_info = await dictionary.get_version(driver) + if self._owner_id_column_name: + sql = f""" + INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time, {self._owner_id_column_name}) + VALUES (:id, :app_name, :user_id, :state, SYSTIMESTAMP, SYSTIMESTAMP, :owner_id) + """ + params = { + "id": session_id, + "app_name": app_name, + "user_id": user_id, + "state": state_data, + "owner_id": owner_id, + } + else: + sql = f""" + INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) + VALUES (:id, :app_name, :user_id, :state, SYSTIMESTAMP, SYSTIMESTAMP) + """ + params = {"id": session_id, "app_name": app_name, "user_id": user_id, "state": state_data} - if self._oracle_version_info is None: - logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage") + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, params) + await conn.commit() - return self._oracle_version_info + return await self.get_session(app_name, user_id, session_id) # type: ignore[return-value] - async def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes": - """Serialize state dictionary to appropriate format based on storage type. + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID. Args: - state: State dictionary to serialize. + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: - JSON string for JSON_NATIVE, bytes for BLOB types. + Session record or None if not found. + + Notes: + Oracle returns datetime objects for TIMESTAMP columns. + State is deserialized using version-appropriate format. """ - storage_type = await self._detect_json_storage_type() - if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(state) + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + await cursor.execute( + f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE app_name = :app_name AND user_id = :user_id AND id = :id", + {"app_name": app_name, "user_id": user_id, "id": session_id}, + ) + await conn.commit() - return to_json(state, as_bytes=True) + await cursor.execute( + f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = :app_name AND user_id = :user_id AND id = :id + """, + {"app_name": app_name, "user_id": user_id, "id": session_id}, + ) + row = await cursor.fetchone() - async def _deserialize_state(self, data: Any) -> "dict[str, Any]": - """Deserialize state data from database format. + if row is None: + return None - Args: - data: Data from database (may be LOB, str, bytes, or dict). + session_id_val, app_name, user_id, state_data, create_time, update_time = row - Returns: - Deserialized state dictionary. - """ - if is_async_readable(data): - data = await data.read() - elif is_readable(data): - data = data.read() + state = await self._deserialize_state(state_data) - if isinstance(data, dict): - return cast("dict[str, Any]", _coerce_decimal_values(data)) + return SessionRecord( + id=session_id_val, + app_name=app_name, + user_id=user_id, + state=state, + create_time=create_time, + update_time=update_time, + ) + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise - if isinstance(data, bytes): - return from_json(data) # type: ignore[no-any-return] + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state. - if isinstance(data, str): - return from_json(data) # type: ignore[no-any-return] + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + state: New state dictionary (replaces existing state). - return from_json(str(data)) # type: ignore[no-any-return] + Notes: + This replaces the entire state dictionary. + Updates update_time to current timestamp. + State is serialized using version-appropriate format. + """ + state_data = await self._serialize_state(state) - async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": - """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values.""" - if data is None: - return None - return await self._deserialize_state(data) + sql = f""" + UPDATE {self._session_table} + SET state = :state, update_time = SYSTIMESTAMP + WHERE app_name = :app_name AND user_id = :user_id AND id = :id + """ - async def _serialize_event_json(self, event_json: Any) -> "str | bytes": - """Serialize event_json to the configured Oracle JSON storage format.""" - storage_type = await self._detect_json_storage_type() - if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(event_json) - return to_json(event_json, as_bytes=True) + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"state": state_data, "app_name": app_name, "user_id": user_id, "id": session_id}) + await conn.commit() - async def _read_event_json(self, data: Any) -> str: - """Read event_json from database, handling LOB types. + async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app, optionally filtered by user. Args: - data: Data from database (may be LOB, str, or dict). + app_name: Application name. + user_id: User identifier. If None, lists all sessions for the app. Returns: - JSON string. + List of session records ordered by update_time DESC. + + Notes: + Uses composite index on (app_name, user_id) when user_id is provided. + State is deserialized using version-appropriate format. """ - if is_async_readable(data): - data = await data.read() - elif is_readable(data): - data = data.read() - if isinstance(data, dict): - return to_json(data) + if user_id is None: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = :app_name + ORDER BY update_time DESC + """ + params = {"app_name": app_name} + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = :app_name AND user_id = :user_id + ORDER BY update_time DESC + """ + params = {"app_name": app_name, "user_id": user_id} - if isinstance(data, bytes): - return data.decode("utf-8") + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, params) + rows = await cursor.fetchall() - return str(data) + results = [] + for row in rows: + state = await self._deserialize_state(row[3]) - def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str: - """Get Oracle CREATE TABLE SQL for sessions with specified storage type. + results.append( + SessionRecord( + id=row[0], + app_name=row[1], + user_id=row[2], + state=state, + create_time=row[4], + update_time=row[5], + ) + ) + return results + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return [] + raise + + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and all associated events (cascade). Args: - storage_type: JSON storage type to use. + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. - Returns: - SQL statement to create adk_sessions table. + Notes: + Foreign key constraint ensures events are cascade-deleted. """ - if storage_type == JSONStorageType.JSON_NATIVE: - state_column = "state JSON NOT NULL" - elif storage_type == JSONStorageType.BLOB_JSON: - state_column = "state BLOB CHECK (state IS JSON) NOT NULL" - else: - state_column = "state BLOB NOT NULL" + sql = f"DELETE FROM {self._session_table} WHERE app_name = :app_name AND user_id = :user_id AND id = :id" - owner_id_column_sql = f", {self._owner_id_column_ddl}" if self._owner_id_column_ddl else "" - table_clauses = _oracle_table_feature_clauses( - self._config, - "session", - in_memory=self._in_memory, - hash_partition_key="id", - range_partition_key="create_time", + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) + await conn.commit() + + async def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session. + + Args: + event_record: Event record. + """ + sql = f""" + INSERT INTO {self._events_table} ( + id, session_id, invocation_id, timestamp, event_data + ) VALUES ( + :id, :session_id, :invocation_id, :timestamp, :event_data ) + """ - return f""" - BEGIN - EXECUTE IMMEDIATE 'CREATE TABLE {self._session_table} ( - id VARCHAR2(128) PRIMARY KEY, - app_name VARCHAR2(128) NOT NULL, - user_id VARCHAR2(128) NOT NULL, - {state_column}, + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute( + sql, + { + "id": event_record["id"], + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "timestamp": event_record["timestamp"], + "event_data": await self._serialize_event_data(event_record["event_data"]), + }, + ) + await conn.commit() + + async def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update session + scoped state. + + All writes are executed within a single transaction so they succeed or + fail together. + """ + insert_sql = f""" + INSERT INTO {self._events_table} ( + id, session_id, invocation_id, timestamp, event_data + ) VALUES ( + :id, :session_id, :invocation_id, :timestamp, :event_data + ) + """ + + state_data = await self._serialize_state(state) + update_sql = f""" + UPDATE {self._session_table} + SET state = :state, update_time = SYSTIMESTAMP + WHERE app_name = :app_name AND user_id = :user_id AND id = :id + """ + + select_sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = :app_name AND user_id = :user_id AND id = :id + """ + + app_upsert_sql = f""" + MERGE INTO {self._app_state_table} target + USING (SELECT :app_name AS app_name, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, source.state, SYSTIMESTAMP) + """ + + user_upsert_sql = f""" + MERGE INTO {self._user_state_table} target + USING (SELECT :app_name AS app_name, :user_id AS user_id, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name AND target.user_id = source.user_id) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, source.state, SYSTIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + try: + await cursor.execute( + update_sql, {"state": state_data, "app_name": app_name, "user_id": user_id, "id": session_id} + ) + await cursor.execute(select_sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) + row = await cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) + await cursor.execute( + insert_sql, + { + "id": event_record["id"], + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "timestamp": event_record["timestamp"], + "event_data": await self._serialize_event_data(event_record["event_data"]), + }, + ) + if app_state is not None: + await cursor.execute( + app_upsert_sql, {"app_name": app_name, "state": await self._serialize_state(app_state)} + ) + if user_state is not None: + await cursor.execute( + user_upsert_sql, + {"app_name": app_name, "user_id": user_id, "state": await self._serialize_state(user_state)}, + ) + await conn.commit() + except Exception: + await conn.rollback() + raise + + session_id_val, row_app_name, row_user_id, state_data_row, create_time, update_time = row + return SessionRecord( + id=session_id_val, + app_name=row_app_name, + user_id=row_user_id, + state=await self._deserialize_state(state_data_row), + create_time=create_time, + update_time=update_time, + ) + + async def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. + + Returns: + List of event records ordered by timestamp ASC. + """ + + if limit == 0: + return [] + + where_clauses = ["s.app_name = :app_name", "s.user_id = :user_id", "e.session_id = :session_id"] + params: dict[str, Any] = {"app_name": app_name, "user_id": user_id, "session_id": session_id} + + if after_timestamp is not None: + where_clauses.append("e.timestamp > :after_timestamp") + params["after_timestamp"] = after_timestamp + + where_clause = " AND ".join(where_clauses) + limit_clause = "" + if limit is not None: + limit_clause = f" FETCH FIRST {limit} ROWS ONLY" + + sql = f""" + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id + WHERE {where_clause} + ORDER BY e.timestamp ASC{limit_clause} + """ + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, params) + rows = await cursor.fetchall() + + return [ + EventRecord( + id=row[0], + session_id=row[1], + invocation_id=_oracle_text_value(row[2]), + timestamp=row[3], + event_data=await self._deserialize_json_field(row[4]) or {}, + app_name=row[5], + user_id=row[6], + ) + for row in rows + ] + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return [] + raise + + async def delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < :before" + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"before": before}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < :updated_before" + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"updated_before": updated_before}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = :app_name" + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"app_name": app_name}) + row = await cursor.fetchone() + return await self._deserialize_state(row[0]) if row is not None else None + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + sql = f""" + SELECT state + FROM {self._user_state_table} + WHERE app_name = :app_name AND user_id = :user_id + """ + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"app_name": app_name, "user_id": user_id}) + row = await cursor.fetchone() + return await self._deserialize_state(row[0]) if row is not None else None + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + sql = f""" + MERGE INTO {self._app_state_table} target + USING (SELECT :app_name AS app_name, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, source.state, SYSTIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"app_name": app_name, "state": await self._serialize_state(state)}) + await conn.commit() + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + sql = f""" + MERGE INTO {self._user_state_table} target + USING (SELECT :app_name AS app_name, :user_id AS user_id, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name AND target.user_id = source.user_id) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, source.state, SYSTIMESTAMP) + """ + + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute( + sql, {"app_name": app_name, "user_id": user_id, "state": await self._serialize_state(state)} + ) + await conn.commit() + + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = :key" + + try: + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"key": key}) + row = await cursor.fetchone() + return str(row[0]) if row is not None else None + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + sql = f""" + MERGE INTO {self._metadata_table} target + USING (SELECT :key AS key, :value AS value FROM DUAL) source + ON (target.key = source.key) + WHEN MATCHED THEN + UPDATE SET target.value = source.value + WHEN NOT MATCHED THEN + INSERT (key, value) + VALUES (source.key, source.value) + """ + + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"key": key, "value": value}) + await conn.commit() + + async def _get_create_sessions_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for sessions table. + + Auto-detects optimal JSON storage type based on Oracle version. + Result is cached to minimize database queries. + """ + storage_type = await self._detect_json_storage_type() + return self._get_create_sessions_table_sql_for_type(storage_type) + + async def _get_create_events_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for events table. + + Auto-detects optimal JSON storage type based on Oracle version. + Result is cached to minimize database queries. + """ + storage_type = await self._detect_json_storage_type() + return self._get_create_events_table_sql_for_type(storage_type) + + async def _get_create_app_states_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for app-scoped state.""" + storage_type = await self._detect_json_storage_type() + return self._get_create_app_states_table_sql_for_type(storage_type) + + async def _get_create_user_states_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for user-scoped state.""" + storage_type = await self._detect_json_storage_type() + return self._get_create_user_states_table_sql_for_type(storage_type) + + async def _get_create_metadata_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for ADK internal metadata.""" + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._metadata_table} ( + key VARCHAR2(128) PRIMARY KEY, + value VARCHAR2(512) NOT NULL + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + async def _get_seed_metadata_sql(self) -> str: + """Get Oracle SQL to seed the ADK schema-version metadata row.""" + return f""" + BEGIN + INSERT INTO {self._metadata_table} (key, value) + SELECT 'schema_version', '1' + FROM DUAL + WHERE NOT EXISTS ( + SELECT 1 FROM {self._metadata_table} WHERE key = 'schema_version' + ); + END; + """ + + async def _detect_json_storage_type(self) -> JSONStorageType: + """Detect the appropriate JSON storage type based on Oracle version. + + Returns: + Appropriate JSONStorageType for this Oracle version. + + Notes: + Queries product_component_version to determine Oracle version. + - Oracle 21c+ with compatible >= 20: Native JSON type + - Oracle 12c+: BLOB with IS JSON constraint + - Oracle 11g and earlier: plain BLOB + + Result is cached in self._json_storage_type. + """ + if self._json_storage_type is not None: + return self._json_storage_type + + version_info = await self._get_version_info() + self._json_storage_type = _storage_type_from_version(version_info) + return self._json_storage_type + + async def _get_version_info(self) -> "OracleVersionInfo | None": + """Return cached Oracle version info using Oracle data dictionary.""" + + if self._oracle_version_info is not None: + return self._oracle_version_info + + async with self._config.provide_session() as driver: + dictionary = OracledbAsyncDataDictionary() + self._oracle_version_info = await dictionary.get_version(driver) + + if self._oracle_version_info is None: + logger.warning("Could not detect Oracle version, defaulting to BLOB_JSON storage") + + return self._oracle_version_info + + async def _serialize_state(self, state: "dict[str, Any]") -> "str | bytes": + """Serialize state dictionary to appropriate format based on storage type. + + Args: + state: State dictionary to serialize. + + Returns: + JSON string for JSON_NATIVE, bytes for BLOB types. + """ + storage_type = await self._detect_json_storage_type() + + if storage_type == JSONStorageType.JSON_NATIVE: + return to_json(state) + + return to_json(state, as_bytes=True) + + async def _deserialize_state(self, data: Any) -> "dict[str, Any]": + """Deserialize state data from database format. + + Args: + data: Data from database (may be LOB, str, bytes, or dict). + + Returns: + Deserialized state dictionary. + + Notes: + Handles LOB reading if data has read() method. + Oracle JSON type may return dict directly. + """ + if is_async_readable(data): + data = await data.read() + elif is_readable(data): + data = data.read() + + if isinstance(data, dict): + return cast("dict[str, Any]", _coerce_decimal_values(data)) + + if isinstance(data, bytes): + return from_json(data) # type: ignore[no-any-return] + + if isinstance(data, str): + return from_json(data) # type: ignore[no-any-return] + + return from_json(str(data)) # type: ignore[no-any-return] + + async def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": + """Deserialize JSON payloads from Oracle JSON/BLOB/LOB values.""" + if data is None: + return None + return await self._deserialize_state(data) + + async def _serialize_event_data(self, event_data: Any) -> "str | bytes": + """Serialize event_data to the configured Oracle JSON storage format.""" + storage_type = await self._detect_json_storage_type() + event_data = _normalize_event_data_for_storage(event_data) + if storage_type == JSONStorageType.JSON_NATIVE: + return to_json(event_data) + return to_json(event_data, as_bytes=True) + + async def _read_event_data(self, data: Any) -> str: + """Read event_data from database, handling LOB types. + + Args: + data: Data from database (may be LOB, str, or dict). + + Returns: + JSON string. + """ + if is_async_readable(data): + data = await data.read() + elif is_readable(data): + data = data.read() + + if isinstance(data, dict): + return to_json(data) + + if isinstance(data, bytes): + return data.decode("utf-8") + + return str(data) + + def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for sessions with specified storage type. + + Args: + storage_type: JSON storage type to use. + + Returns: + SQL statement to create adk_session table. + """ + if storage_type == JSONStorageType.JSON_NATIVE: + state_column = "state JSON NOT NULL" + elif storage_type == JSONStorageType.BLOB_JSON: + state_column = "state BLOB CHECK (state IS JSON) NOT NULL" + else: + state_column = "state BLOB NOT NULL" + + owner_id_column_sql = f", {self._owner_id_column_ddl}" if self._owner_id_column_ddl else "" + table_clauses = _oracle_table_feature_clauses( + self._config, + "session", + in_memory=self._in_memory, + hash_partition_key="id", + range_partition_key="create_time", + ) + + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._session_table} ( + id VARCHAR2(128) PRIMARY KEY, + app_name VARCHAR2(128) NOT NULL, + user_id VARCHAR2(128) NOT NULL, + {state_column}, create_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL{owner_id_column_sql} ){table_clauses}'; @@ -297,17 +939,16 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for events with specified storage type. - The events table uses the new 5-column contract: session_id, invocation_id, - author, timestamp, and event_json. The event_json column stores the full - ADK Event as JSON (21c+) or BLOB (older versions). + The events table stores the full ADK Event in ``event_data`` and + keeps scalar event fields indexed for efficient scoped reads. Args: storage_type: JSON storage type to use. Returns: - SQL statement to create adk_events table. + SQL statement to create adk_event table. """ - event_json_col = _event_json_column_ddl(storage_type) + event_data_col = _event_data_column_ddl(storage_type) table_clauses = _oracle_table_feature_clauses( self._config, "events", @@ -319,11 +960,11 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - return f""" BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( + id VARCHAR2(128) PRIMARY KEY, session_id VARCHAR2(128) NOT NULL, invocation_id VARCHAR2(256), - author VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {event_json_col}, + {event_data_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ){table_clauses}'; @@ -339,418 +980,182 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - ON {self._events_table}(session_id, timestamp ASC)'; EXCEPTION WHEN OTHERS THEN - IF SQLCODE != -955 THEN - RAISE; - END IF; - END; - """ - - def _get_drop_tables_sql(self) -> "list[str]": - """Get Oracle DROP TABLE SQL statements. - - Returns: - List of SQL statements to drop tables and indexes. - """ - return [ - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP INDEX idx_{self._events_table}_session'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -1418 THEN - RAISE; - END IF; - END; - """, - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP INDEX idx_{self._session_table}_update_time'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -1418 THEN - RAISE; - END IF; - END; - """, - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP INDEX idx_{self._session_table}_app_user'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -1418 THEN - RAISE; - END IF; - END; - """, - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP TABLE {self._events_table}'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -942 THEN - RAISE; - END IF; - END; - """, - f""" - BEGIN - EXECUTE IMMEDIATE 'DROP TABLE {self._session_table}'; - EXCEPTION - WHEN OTHERS THEN - IF SQLCODE != -942 THEN - RAISE; - END IF; - END; - """, - ] - - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" - storage_type = await self._detect_json_storage_type() - logger.debug("Creating ADK tables with storage type: %s", storage_type) - - async with self._config.provide_session() as driver: - await driver.execute_script(self._get_create_sessions_table_sql_for_type(storage_type)) - - await driver.execute_script(self._get_create_events_table_sql_for_type(storage_type)) - - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner_id_column (if configured). - - Returns: - Created session record. - """ - state_data = await self._serialize_state(state) - - if self._owner_id_column_name: - sql = f""" - INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time, {self._owner_id_column_name}) - VALUES (:id, :app_name, :user_id, :state, SYSTIMESTAMP, SYSTIMESTAMP, :owner_id) - """ - params = { - "id": session_id, - "app_name": app_name, - "user_id": user_id, - "state": state_data, - "owner_id": owner_id, - } - else: - sql = f""" - INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) - VALUES (:id, :app_name, :user_id, :state, SYSTIMESTAMP, SYSTIMESTAMP) - """ - params = {"id": session_id, "app_name": app_name, "user_id": user_id, "state": state_data} - - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, params) - await conn.commit() - - return await self.get_session(session_id) # type: ignore[return-value] - - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - """ - - try: - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute( - f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = :id - """, - {"id": session_id}, - ) - row = await cursor.fetchone() - - if row is None: - return None - - session_id_val, app_name, user_id, state_data, create_time, update_time = row - - state = await self._deserialize_state(state_data) - - return SessionRecord( - id=session_id_val, - app_name=app_name, - user_id=user_id, - state=state, - create_time=create_time, - update_time=update_time, - ) - except oracledb.DatabaseError as e: - error_obj = e.args[0] if e.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return None - raise - - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - """ - state_data = await self._serialize_state(state) - - sql = f""" - UPDATE {self._session_table} - SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id - """ - - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"state": state_data, "id": session_id}) - await conn.commit() - - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - """ - sql = f"DELETE FROM {self._session_table} WHERE id = :id" - - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"id": session_id}) - await conn.commit() - - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - """ - - if user_id is None: - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE app_name = :app_name - ORDER BY update_time DESC - """ - params = {"app_name": app_name} - else: - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE app_name = :app_name AND user_id = :user_id - ORDER BY update_time DESC - """ - params = {"app_name": app_name, "user_id": user_id} - - try: - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, params) - rows = await cursor.fetchall() - - results = [] - for row in rows: - state = await self._deserialize_state(row[3]) - - results.append( - SessionRecord( - id=row[0], - app_name=row[1], - user_id=row[2], - state=state, - create_time=row[4], - update_time=row[5], - ) - ) - return results - except oracledb.DatabaseError as e: - error_obj = e.args[0] if e.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return [] - raise + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. + BEGIN + EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._events_table}_invocation + ON {self._events_table}(invocation_id)'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; - Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. - """ - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json - ) + BEGIN + EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._events_table}_timestamp + ON {self._events_table}(timestamp ASC)'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; """ - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute( - sql, - { - "session_id": event_record["session_id"], - "invocation_id": event_record["invocation_id"], - "author": event_record["author"], - "timestamp": event_record["timestamp"], - "event_json": await self._serialize_event_json(event_record["event_json"]), - }, - ) - await conn.commit() + def _get_create_app_states_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for app-scoped state with specified storage type.""" + state_column = _json_column_ddl("state", storage_type) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state. + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._app_state_table} ( + app_name VARCHAR2(128) PRIMARY KEY, + {state_column}, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ - Both the event insert and session state update are executed within a - single transaction so they succeed or fail together. The refreshed - SessionRecord is read inside the same transaction (Oracle's RETURNING - INTO requires output bind variables which complicate async cursor - handling, so a SELECT-after-UPDATE is used instead). + def _get_create_user_states_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for user-scoped state with specified storage type.""" + state_column = _json_column_ddl("state", storage_type) - Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). - """ - insert_sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json - ) + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._user_state_table} ( + app_name VARCHAR2(128) NOT NULL, + user_id VARCHAR2(128) NOT NULL, + {state_column}, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, + PRIMARY KEY (app_name, user_id) + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; """ - state_data = await self._serialize_state(state) - update_sql = f""" - UPDATE {self._session_table} - SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id + def _get_drop_app_states_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._app_state_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; """ - select_sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = :id + def _get_drop_user_states_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._user_state_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; """ - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute( - insert_sql, - { - "session_id": event_record["session_id"], - "invocation_id": event_record["invocation_id"], - "author": event_record["author"], - "timestamp": event_record["timestamp"], - "event_json": await self._serialize_event_json(event_record["event_json"]), - }, - ) - await cursor.execute(update_sql, {"state": state_data, "id": session_id}) - await cursor.execute(select_sql, {"id": session_id}) - row = await cursor.fetchone() - await conn.commit() - - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) - - session_id_val, app_name, user_id, state_data_row, create_time, update_time = row - return SessionRecord( - id=session_id_val, - app_name=app_name, - user_id=user_id, - state=await self._deserialize_state(state_data_row), - create_time=create_time, - update_time=update_time, - ) - - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session. - - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. - - Returns: - List of event records ordered by timestamp ASC. + def _get_drop_metadata_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._metadata_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; """ - where_clauses = ["session_id = :session_id"] - params: dict[str, Any] = {"session_id": session_id} - - if after_timestamp is not None: - where_clauses.append("timestamp > :after_timestamp") - params["after_timestamp"] = after_timestamp + def _get_drop_tables_sql(self) -> "list[str]": + """Get Oracle DROP TABLE SQL statements. - where_clause = " AND ".join(where_clauses) - limit_clause = "" - if limit: - limit_clause = f" FETCH FIRST {limit} ROWS ONLY" + Returns: + List of SQL statements to drop tables and indexes. - sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} - WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + Notes: + Order matters: drop events table (child) before sessions (parent). + Oracle automatically drops indexes when dropping tables. """ - - try: - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, params) - rows = await cursor.fetchall() - - return [ - EventRecord( - session_id=row[0], - invocation_id=_oracle_text_value(row[1]), - author=_oracle_text_value(row[2]), - timestamp=row[3], - event_json=await self._deserialize_json_field(row[4]) or {}, - ) - for row in rows - ] - except oracledb.DatabaseError as e: - error_obj = e.args[0] if e.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return [] - raise + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP INDEX idx_{self._events_table}_session'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -1418 THEN + RAISE; + END IF; + END; + """, + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP INDEX idx_{self._session_table}_update_time'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -1418 THEN + RAISE; + END IF; + END; + """, + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP INDEX idx_{self._session_table}_app_user'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -1418 THEN + RAISE; + END IF; + END; + """, + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._events_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """, + f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._session_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """, + ] -class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): +class OracleSyncADKStore(BaseSyncADKStore["OracleSyncConfig"]): """Oracle synchronous ADK store using oracledb sync driver. Implements session and event storage for Google Agent Development Kit using Oracle Database via the python-oracledb synchronous driver. Provides: - Session state management with version-specific JSON storage - - Full-fidelity event storage via ``event_json`` column + - Full-fidelity event storage via ``event_data`` column - Atomic ``create_event_and_update_state`` for durable session mutations - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps - Foreign key constraints with cascade delete @@ -758,6 +1163,15 @@ class OracleSyncADKStore(BaseAsyncADKStore["OracleSyncConfig"]): Args: config: OracleSyncConfig with extension_config["adk"] settings. + + Notes: + - JSON storage type detected based on Oracle version (21c+, 12c+, legacy) + - event_data stored as JSON (21c+) or BLOB (older versions) + - TIMESTAMP WITH TIME ZONE for timezone-aware timestamps + - Named parameters using :param_name + - State merging handled at application level + - owner_id_column supports NUMBER, VARCHAR2, RAW for Oracle FK types + - Configuration is read from config.extension_config["adk"] """ __slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info") @@ -767,37 +1181,183 @@ def __init__(self, config: "OracleSyncConfig") -> None: Args: config: OracleSyncConfig instance. + + Notes: + Configuration is read from config.extension_config["adk"]: + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") + - owner_id_column: Optional owner FK column DDL (default: None) + - in_memory: Enable INMEMORY PRIORITY HIGH clause (default: False) """ super().__init__(config) + _configure_oracle_adk_session_tables(self, config) self._json_storage_type: JSONStorageType | None = None self._oracle_version_info: OracleVersionInfo | None = None adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) - async def _get_create_sessions_table_sql(self) -> str: + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return self._create_session(session_id, app_name, user_id, state, owner_id) + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return self._get_session(app_name, user_id, session_id, renew_for=renew_for) + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + self._update_session_state(app_name, user_id, session_id, state) + + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return self._list_sessions(app_name, user_id) + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and associated events.""" + self._delete_session(app_name, user_id, session_id) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + self._append_event(event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update session + scoped state.""" + return self._append_event_and_update_state( + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session.""" + return self._get_events(app_name, user_id, session_id, after_timestamp, limit) + + def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return self._delete_expired_events(before) + + def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return self._delete_idle_sessions(updated_before) + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return self._get_app_state(app_name) + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return self._get_user_state(app_name, user_id) + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + self._upsert_app_state(app_name, state) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + self._upsert_user_state(app_name, user_id, state) + + def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return self._get_metadata(key) + + def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + self._set_metadata(key, value) + + def _get_create_sessions_table_sql(self) -> str: """Get Oracle CREATE TABLE SQL for sessions table. Auto-detects optimal JSON storage type based on Oracle version. Result is cached to minimize database queries. """ - storage_type = self._detect_json_storage_type() - return self._get_create_sessions_table_sql_for_type(storage_type) - - async def _get_create_events_table_sql(self) -> str: - """Get Oracle CREATE TABLE SQL for events table. + storage_type = self._detect_json_storage_type() + return self._get_create_sessions_table_sql_for_type(storage_type) + + def _get_create_events_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for events table. + + Auto-detects optimal JSON storage type based on Oracle version. + Result is cached to minimize database queries. + """ + storage_type = self._detect_json_storage_type() + return self._get_create_events_table_sql_for_type(storage_type) + + def _get_create_app_states_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for app-scoped state.""" + storage_type = self._detect_json_storage_type() + return self._get_create_app_states_table_sql_for_type(storage_type) + + def _get_create_user_states_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for user-scoped state.""" + storage_type = self._detect_json_storage_type() + return self._get_create_user_states_table_sql_for_type(storage_type) + + def _get_create_metadata_table_sql(self) -> str: + """Get Oracle CREATE TABLE SQL for ADK internal metadata.""" + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._metadata_table} ( + key VARCHAR2(128) PRIMARY KEY, + value VARCHAR2(512) NOT NULL + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ - Auto-detects optimal JSON storage type based on Oracle version. - Result is cached to minimize database queries. + def _get_seed_metadata_sql(self) -> str: + """Get Oracle SQL to seed the ADK schema-version metadata row.""" + return f""" + BEGIN + INSERT INTO {self._metadata_table} (key, value) + SELECT 'schema_version', '1' + FROM DUAL + WHERE NOT EXISTS ( + SELECT 1 FROM {self._metadata_table} WHERE key = 'schema_version' + ); + END; """ - storage_type = self._detect_json_storage_type() - return self._get_create_events_table_sql_for_type(storage_type) def _detect_json_storage_type(self) -> JSONStorageType: """Detect the appropriate JSON storage type based on Oracle version. Returns: Appropriate JSONStorageType for this Oracle version. + + Notes: + Queries product_component_version to determine Oracle version. + - Oracle 21c+ with compatible >= 20: Native JSON type + - Oracle 12c+: BLOB with IS JSON constraint + - Oracle 11g and earlier: plain BLOB + + Result is cached in self._json_storage_type. """ if self._json_storage_type is not None: return self._json_storage_type @@ -845,6 +1405,10 @@ def _deserialize_state(self, data: Any) -> "dict[str, Any]": Returns: Deserialized state dictionary. + + Notes: + Handles LOB reading if data has read() method. + Oracle JSON type may return dict directly. """ if is_readable(data): data = data.read() @@ -866,15 +1430,16 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return None return self._deserialize_state(data) - def _serialize_event_json(self, event_json: Any) -> "str | bytes": - """Serialize event_json to the configured Oracle JSON storage format.""" + def _serialize_event_data(self, event_data: Any) -> "str | bytes": + """Serialize event_data to the configured Oracle JSON storage format.""" storage_type = self._detect_json_storage_type() + event_data = _normalize_event_data_for_storage(event_data) if storage_type == JSONStorageType.JSON_NATIVE: - return to_json(event_json) - return to_json(event_json, as_bytes=True) + return to_json(event_data) + return to_json(event_data, as_bytes=True) - def _read_event_json(self, data: Any) -> str: - """Read event_json from database, handling LOB types. + def _read_event_data(self, data: Any) -> str: + """Read event_data from database, handling LOB types. Args: data: Data from database (may be LOB, str, or dict). @@ -900,7 +1465,7 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) storage_type: JSON storage type to use. Returns: - SQL statement to create adk_sessions table. + SQL statement to create adk_session table. """ if storage_type == JSONStorageType.JSON_NATIVE: state_column = "state JSON NOT NULL" @@ -914,7 +1479,7 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) self._config, "session", in_memory=self._in_memory, - hash_partition_key="id", + hash_partition_key="session_id", range_partition_key="create_time", ) @@ -959,17 +1524,16 @@ def _get_create_sessions_table_sql_for_type(self, storage_type: JSONStorageType) def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) -> str: """Get Oracle CREATE TABLE SQL for events with specified storage type. - The events table uses the new 5-column contract: session_id, invocation_id, - author, timestamp, and event_json. The event_json column stores the full - ADK Event as JSON (21c+) or BLOB (older versions). + The events table stores the full ADK Event in ``event_data`` and + keeps scalar event fields indexed for efficient scoped reads. Args: storage_type: JSON storage type to use. Returns: - SQL statement to create adk_events table. + SQL statement to create adk_event table. """ - event_json_col = _event_json_column_ddl(storage_type) + event_data_col = _event_data_column_ddl(storage_type) table_clauses = _oracle_table_feature_clauses( self._config, "events", @@ -981,11 +1545,11 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - return f""" BEGIN EXECUTE IMMEDIATE 'CREATE TABLE {self._events_table} ( + id VARCHAR2(128) PRIMARY KEY, session_id VARCHAR2(128) NOT NULL, invocation_id VARCHAR2(256), - author VARCHAR2(256), timestamp TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, - {event_json_col}, + {event_data_col}, CONSTRAINT fk_{self._events_table}_session FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ){table_clauses}'; @@ -1005,6 +1569,102 @@ def _get_create_events_table_sql_for_type(self, storage_type: JSONStorageType) - RAISE; END IF; END; + + BEGIN + EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._events_table}_invocation + ON {self._events_table}(invocation_id)'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + + BEGIN + EXECUTE IMMEDIATE 'CREATE INDEX idx_{self._events_table}_timestamp + ON {self._events_table}(timestamp ASC)'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + def _get_create_app_states_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for app-scoped state with specified storage type.""" + state_column = _json_column_ddl("state", storage_type) + + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._app_state_table} ( + app_name VARCHAR2(128) PRIMARY KEY, + {state_column}, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + def _get_create_user_states_table_sql_for_type(self, storage_type: JSONStorageType) -> str: + """Get Oracle CREATE TABLE SQL for user-scoped state with specified storage type.""" + state_column = _json_column_ddl("state", storage_type) + + return f""" + BEGIN + EXECUTE IMMEDIATE 'CREATE TABLE {self._user_state_table} ( + app_name VARCHAR2(128) NOT NULL, + user_id VARCHAR2(128) NOT NULL, + {state_column}, + update_time TIMESTAMP WITH TIME ZONE DEFAULT SYSTIMESTAMP NOT NULL, + PRIMARY KEY (app_name, user_id) + )'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -955 THEN + RAISE; + END IF; + END; + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._app_state_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """ + + def _get_drop_user_states_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._user_state_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; + """ + + def _get_drop_metadata_table_sql(self) -> str: + return f""" + BEGIN + EXECUTE IMMEDIATE 'DROP TABLE {self._metadata_table}'; + EXCEPTION + WHEN OTHERS THEN + IF SQLCODE != -942 THEN + RAISE; + END IF; + END; """ def _get_drop_tables_sql(self) -> "list[str]": @@ -1012,8 +1672,15 @@ def _get_drop_tables_sql(self) -> "list[str]": Returns: List of SQL statements to drop tables and indexes. + + Notes: + Order matters: drop events table (child) before sessions (parent). + Oracle automatically drops indexes when dropping tables. """ return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), f""" BEGIN EXECUTE IMMEDIATE 'DROP INDEX idx_{self._events_table}_session'; @@ -1067,7 +1734,12 @@ def _get_drop_tables_sql(self) -> "list[str]": ] def _create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" + """Create both sessions and events tables if they don't exist. + + Notes: + Detects Oracle version to determine optimal JSON storage type. + Uses version-appropriate table schema. + """ storage_type = self._detect_json_storage_type() logger.info("Creating ADK tables with storage type: %s", storage_type) @@ -1077,10 +1749,11 @@ def _create_tables(self) -> None: events_sql = SQL(self._get_create_events_table_sql_for_type(storage_type)) driver.execute_script(events_sql) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(SQL(self._get_create_app_states_table_sql_for_type(storage_type))) + driver.execute_script(SQL(self._get_create_user_states_table_sql_for_type(storage_type))) + driver.execute_script(SQL(self._get_create_metadata_table_sql())) + driver.execute_script(SQL(self._get_seed_metadata_sql())) + driver.commit() def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -1096,6 +1769,11 @@ def _create_session( Returns: Created session record. + + Notes: + Uses SYSTIMESTAMP for create_time and update_time. + State is serialized using version-appropriate format. + owner_id is ignored if owner_id_column not configured. """ state_data = self._serialize_state(state) @@ -1123,38 +1801,48 @@ def _create_session( cursor.execute(sql, params) conn.commit() - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Get session by ID. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. + renew_for: If positive, touch update_time while reading. Returns: Session record or None if not found. + + Notes: + Oracle returns datetime objects for TIMESTAMP columns. + State is deserialized using version-appropriate format. """ sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ try: with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute(sql, {"id": session_id}) + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + cursor.execute( + f"UPDATE {self._session_table} SET update_time = SYSTIMESTAMP WHERE app_name = :app_name AND user_id = :user_id AND id = :id", + {"app_name": app_name, "user_id": user_id, "id": session_id}, + ) + conn.commit() + + cursor.execute(sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) row = cursor.fetchone() if row is None: @@ -1172,57 +1860,39 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": create_time=create_time, update_time=update_time, ) - except oracledb.DatabaseError as e: + except OracleDatabaseError as e: error_obj = e.args[0] if e.args else None if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. state: New state dictionary (replaces existing state). + + Notes: + This replaces the entire state dictionary. + Updates update_time to current timestamp. + State is serialized using version-appropriate format. """ state_data = self._serialize_state(state) sql = f""" UPDATE {self._session_table} SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id - """ - - with self._config.provide_connection() as conn: - cursor = conn.cursor() - cursor.execute(sql, {"state": state_data, "id": session_id}) - conn.commit() - - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - - def _delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ - sql = f"DELETE FROM {self._session_table} WHERE id = :id" with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute(sql, {"id": session_id}) + cursor.execute(sql, {"state": state_data, "app_name": app_name, "user_id": user_id, "id": session_id}) conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": """List sessions for an app, optionally filtered by user. @@ -1232,6 +1902,10 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses Returns: List of session records ordered by update_time DESC. + + Notes: + Uses composite index on (app_name, user_id) when user_id is provided. + State is deserialized using version-appropriate format. """ if user_id is None: @@ -1272,37 +1946,71 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses ) ) return results - except oracledb.DatabaseError as e: + except OracleDatabaseError as e: error_obj = e.args[0] if e.args else None if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and all associated events (cascade). - def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically create an event and update the session's durable state. + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. - Both the event insert and session state update are executed within a - single transaction so they succeed or fail together; the refreshed - SessionRecord is read inside the same transaction. + Notes: + Foreign key constraint ensures events are cascade-deleted. + """ + sql = f"DELETE FROM {self._session_table} WHERE app_name = :app_name AND user_id = :user_id AND id = :id" - Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) + conn.commit() + + def _append_event(self, event_record: EventRecord) -> None: + """Synchronous implementation of append_event.""" + sql = f""" + INSERT INTO {self._events_table} ( + id, session_id, invocation_id, timestamp, event_data + ) VALUES ( + :id, :session_id, :invocation_id, :timestamp, :event_data + ) """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute( + sql, + { + "id": event_record["id"], + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "timestamp": event_record["timestamp"], + "event_data": self._serialize_event_data(event_record["event_data"]), + }, + ) + conn.commit() + + def _append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically create an event and update session + scoped state.""" insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json + :id, :session_id, :invocation_id, :timestamp, :event_data ) """ @@ -1310,58 +2018,92 @@ def _append_event_and_update_state( update_sql = f""" UPDATE {self._session_table} SET state = :state, update_time = SYSTIMESTAMP - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id """ select_sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = :id + WHERE app_name = :app_name AND user_id = :user_id AND id = :id + """ + + app_upsert_sql = f""" + MERGE INTO {self._app_state_table} target + USING (SELECT :app_name AS app_name, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, source.state, SYSTIMESTAMP) + """ + + user_upsert_sql = f""" + MERGE INTO {self._user_state_table} target + USING (SELECT :app_name AS app_name, :user_id AS user_id, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name AND target.user_id = source.user_id) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, source.state, SYSTIMESTAMP) """ with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute( - insert_sql, - { - "session_id": event_record["session_id"], - "invocation_id": event_record["invocation_id"], - "author": event_record["author"], - "timestamp": event_record["timestamp"], - "event_json": self._serialize_event_json(event_record["event_json"]), - }, - ) - cursor.execute(update_sql, {"state": state_data, "id": session_id}) - cursor.execute(select_sql, {"id": session_id}) - row = cursor.fetchone() - conn.commit() - - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) + try: + cursor.execute( + update_sql, {"state": state_data, "app_name": app_name, "user_id": user_id, "id": session_id} + ) + cursor.execute(select_sql, {"app_name": app_name, "user_id": user_id, "id": session_id}) + row = cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) + cursor.execute( + insert_sql, + { + "id": event_record["id"], + "session_id": event_record["session_id"], + "invocation_id": event_record["invocation_id"], + "timestamp": event_record["timestamp"], + "event_data": self._serialize_event_data(event_record["event_data"]), + }, + ) + if app_state is not None: + cursor.execute(app_upsert_sql, {"app_name": app_name, "state": self._serialize_state(app_state)}) + if user_state is not None: + cursor.execute( + user_upsert_sql, + {"app_name": app_name, "user_id": user_id, "state": self._serialize_state(user_state)}, + ) + conn.commit() + except Exception: + conn.rollback() + raise - session_id_val, app_name, user_id, state_data_row, create_time, update_time = row + session_id_val, row_app_name, row_user_id, state_data_row, create_time, update_time = row return SessionRecord( id=session_id_val, - app_name=app_name, - user_id=user_id, + app_name=row_app_name, + user_id=row_user_id, state=self._deserialize_state(state_data_row), create_time=create_time, update_time=update_time, ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """List events for a session ordered by timestamp. Args: + app_name: Application name. + user_id: User identifier. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -1370,78 +2112,186 @@ def _get_events( List of event records ordered by timestamp ASC. """ - where_clauses = ["session_id = :session_id"] - params: dict[str, Any] = {"session_id": session_id} + if limit == 0: + return [] + + where_clauses = ["s.app_name = :app_name", "s.user_id = :user_id", "e.session_id = :session_id"] + params: dict[str, Any] = {"app_name": app_name, "user_id": user_id, "session_id": session_id} if after_timestamp is not None: - where_clauses.append("timestamp > :after_timestamp") + where_clauses.append("e.timestamp > :after_timestamp") params["after_timestamp"] = after_timestamp where_clause = " AND ".join(where_clauses) - limit_clause = f" FETCH FIRST {limit} ROWS ONLY" if limit else "" + limit_clause = f" FETCH FIRST {limit} ROWS ONLY" if limit is not None else "" sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} + """ + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, params) + rows = cursor.fetchall() + + return [ + EventRecord( + id=row[0], + session_id=row[1], + invocation_id=_oracle_text_value(row[2]), + timestamp=row[3], + event_data=self._deserialize_json_field(row[4]) or {}, + app_name=row[5], + user_id=row[6], + ) + for row in rows + ] + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return [] + raise + + def _delete_expired_events(self, before: "datetime") -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < :before" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"before": before}) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < :updated_before" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"updated_before": updated_before}) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return 0 + raise + + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_app_state.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = :app_name" + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name}) + row = cursor.fetchone() + return self._deserialize_state(row[0]) if row is not None else None + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_user_state.""" + sql = f""" + SELECT state + FROM {self._user_state_table} + WHERE app_name = :app_name AND user_id = :user_id + """ + + try: + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name, "user_id": user_id}) + row = cursor.fetchone() + return self._deserialize_state(row[0]) if row is not None else None + except OracleDatabaseError as e: + error_obj = e.args[0] if e.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_app_state.""" + sql = f""" + MERGE INTO {self._app_state_table} target + USING (SELECT :app_name AS app_name, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, state, update_time) + VALUES (source.app_name, source.state, SYSTIMESTAMP) + """ + + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name, "state": self._serialize_state(state)}) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_user_state.""" + sql = f""" + MERGE INTO {self._user_state_table} target + USING (SELECT :app_name AS app_name, :user_id AS user_id, :state AS state FROM DUAL) source + ON (target.app_name = source.app_name AND target.user_id = source.user_id) + WHEN MATCHED THEN + UPDATE SET target.state = source.state, target.update_time = SYSTIMESTAMP + WHEN NOT MATCHED THEN + INSERT (app_name, user_id, state, update_time) + VALUES (source.app_name, source.user_id, source.state, SYSTIMESTAMP) """ + with self._config.provide_connection() as conn: + cursor = conn.cursor() + cursor.execute(sql, {"app_name": app_name, "user_id": user_id, "state": self._serialize_state(state)}) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + """Synchronous implementation of get_metadata.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = :key" + try: with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute(sql, params) - rows = cursor.fetchall() - - return [ - EventRecord( - session_id=row[0], - invocation_id=_oracle_text_value(row[1]), - author=_oracle_text_value(row[2]), - timestamp=row[3], - event_json=self._deserialize_json_field(row[4]) or {}, - ) - for row in rows - ] - except oracledb.DatabaseError as e: + cursor.execute(sql, {"key": key}) + row = cursor.fetchone() + return str(row[0]) if row is not None else None + except OracleDatabaseError as e: error_obj = e.args[0] if e.args else None if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return [] + return None raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) - - def _append_event(self, event_record: EventRecord) -> None: - """Synchronous implementation of append_event.""" + def _set_metadata(self, key: str, value: str) -> None: + """Synchronous implementation of set_metadata.""" sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES ( - :session_id, :invocation_id, :author, :timestamp, :event_json - ) + MERGE INTO {self._metadata_table} target + USING (SELECT :key AS key, :value AS value FROM DUAL) source + ON (target.key = source.key) + WHEN MATCHED THEN + UPDATE SET target.value = source.value + WHEN NOT MATCHED THEN + INSERT (key, value) + VALUES (source.key, source.value) """ with self._config.provide_connection() as conn: cursor = conn.cursor() - cursor.execute( - sql, - { - "session_id": event_record["session_id"], - "invocation_id": event_record["invocation_id"], - "author": event_record["author"], - "timestamp": event_record["timestamp"], - "event_json": self._serialize_event_json(event_record["event_json"]), - }, - ) + cursor.execute(sql, {"key": key, "value": value}) conn.commit() - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - class OracleAsyncADKMemoryStore(BaseAsyncADKMemoryStore["OracleAsyncConfig"]): """Oracle ADK memory store using async oracledb driver.""" @@ -1455,6 +2305,98 @@ def __init__(self, config: "OracleAsyncConfig") -> None: adk_config = config.extension_config.get("adk", {}) self._in_memory: bool = bool(adk_config.get("in_memory", False)) + async def create_tables(self) -> None: + if not self._enabled: + return + + async with self._config.provide_session() as driver: + await driver.execute_script(await self._get_create_memory_table_sql()) + + async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + if not self._enabled: + msg = "Memory store is disabled" + raise RuntimeError(msg) + + if not entries: + return 0 + + owner_column = f", {self._owner_id_column_name}" if self._owner_id_column_name else "" + owner_param = ", :owner_id" if self._owner_id_column_name else "" + sql = f""" + INSERT INTO {self._memory_table} ( + id, session_id, app_name, user_id, event_id, author{owner_column}, + timestamp, content_json, content_text, metadata_json, inserted_at + ) VALUES ( + :id, :session_id, :app_name, :user_id, :event_id, :author{owner_param}, + :timestamp, :content_json, :content_text, :metadata_json, :inserted_at + ) + """ + + inserted_count = 0 + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + for entry in entries: + content_json = await self._serialize_json_field(entry["content_json"]) + metadata_json = await self._serialize_json_field(entry["metadata_json"]) + params = { + "id": entry["id"], + "session_id": entry["session_id"], + "app_name": entry["app_name"], + "user_id": entry["user_id"], + "event_id": entry["event_id"], + "author": entry["author"], + "timestamp": entry["timestamp"], + "content_json": content_json, + "content_text": entry["content_text"], + "metadata_json": metadata_json, + "inserted_at": entry["inserted_at"], + } + if self._owner_id_column_name: + params["owner_id"] = str(owner_id) if owner_id is not None else None + if await self._execute_insert_entry(cursor, sql, params): + inserted_count += 1 + await conn.commit() + + return inserted_count + + async def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + if not self._enabled: + msg = "Memory store is disabled" + raise RuntimeError(msg) + + effective_limit = limit if limit is not None else self._max_results + + try: + if self._use_fts: + return await self._search_entries_fts(query, app_name, user_id, effective_limit) + return await self._search_entries_simple(query, app_name, user_id, effective_limit) + except OracleDatabaseError as exc: + error_obj = exc.args[0] if exc.args else None + if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: + return [] + raise + + async def delete_entries_by_session(self, session_id: str) -> int: + sql = f"DELETE FROM {self._memory_table} WHERE session_id = :session_id" + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"session_id": session_id}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + + async def delete_entries_older_than(self, days: int) -> int: + sql = f""" + DELETE FROM {self._memory_table} + WHERE inserted_at < SYSTIMESTAMP - NUMTODSINTERVAL(:days, 'DAY') + """ + async with self._config.provide_connection() as conn: + cursor = conn.cursor() + await cursor.execute(sql, {"days": days}) + await conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + async def _detect_json_storage_type(self) -> "JSONStorageType": if self._json_storage_type is not None: return self._json_storage_type @@ -1615,90 +2557,17 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """, ] - async def create_tables(self) -> None: - if not self._enabled: - return - - async with self._config.provide_session() as driver: - await driver.execute_script(await self._get_create_memory_table_sql()) - async def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") -> bool: """Execute an insert and skip duplicate key errors.""" try: await cursor.execute(sql, params) - except oracledb.DatabaseError as exc: + except OracleDatabaseError as exc: error_obj = exc.args[0] if exc.args else None if error_obj and error_obj.code == ORACLE_DUPLICATE_KEY_ERROR: return False raise return True - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - if not self._enabled: - msg = "Memory store is disabled" - raise RuntimeError(msg) - - if not entries: - return 0 - - owner_column = f", {self._owner_id_column_name}" if self._owner_id_column_name else "" - owner_param = ", :owner_id" if self._owner_id_column_name else "" - sql = f""" - INSERT INTO {self._memory_table} ( - id, session_id, app_name, user_id, event_id, author{owner_column}, - timestamp, content_json, content_text, metadata_json, inserted_at - ) VALUES ( - :id, :session_id, :app_name, :user_id, :event_id, :author{owner_param}, - :timestamp, :content_json, :content_text, :metadata_json, :inserted_at - ) - """ - - inserted_count = 0 - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - for entry in entries: - content_json = await self._serialize_json_field(entry["content_json"]) - metadata_json = await self._serialize_json_field(entry["metadata_json"]) - params = { - "id": entry["id"], - "session_id": entry["session_id"], - "app_name": entry["app_name"], - "user_id": entry["user_id"], - "event_id": entry["event_id"], - "author": entry["author"], - "timestamp": entry["timestamp"], - "content_json": content_json, - "content_text": entry["content_text"], - "metadata_json": metadata_json, - "inserted_at": entry["inserted_at"], - } - if self._owner_id_column_name: - params["owner_id"] = str(owner_id) if owner_id is not None else None - if await self._execute_insert_entry(cursor, sql, params): - inserted_count += 1 - await conn.commit() - - return inserted_count - - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - if not self._enabled: - msg = "Memory store is disabled" - raise RuntimeError(msg) - - effective_limit = limit if limit is not None else self._max_results - - try: - if self._use_fts: - return await self._search_entries_fts(query, app_name, user_id, effective_limit) - return await self._search_entries_simple(query, app_name, user_id, effective_limit) - except oracledb.DatabaseError as exc: - error_obj = exc.args[0] if exc.args else None - if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: - return [] - raise - async def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -1745,25 +2614,6 @@ async def _search_entries_simple(self, query: str, app_name: str, user_id: str, rows = await cursor.fetchall() return await self._rows_to_records(rows) - async def delete_entries_by_session(self, session_id: str) -> int: - sql = f"DELETE FROM {self._memory_table} WHERE session_id = :session_id" - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"session_id": session_id}) - await conn.commit() - return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - - async def delete_entries_older_than(self, days: int) -> int: - sql = f""" - DELETE FROM {self._memory_table} - WHERE inserted_at < SYSTIMESTAMP - NUMTODSINTERVAL(:days, 'DAY') - """ - async with self._config.provide_connection() as conn: - cursor = conn.cursor() - await cursor.execute(sql, {"days": days}) - await conn.commit() - return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: @@ -1788,7 +2638,7 @@ async def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": return records -class OracleSyncADKMemoryStore(BaseAsyncADKMemoryStore["OracleSyncConfig"]): +class OracleSyncADKMemoryStore(BaseSyncADKMemoryStore["OracleSyncConfig"]): """Oracle ADK memory store using sync oracledb driver.""" __slots__ = ("_in_memory", "_json_storage_type", "_oracle_version_info") @@ -1800,6 +2650,28 @@ def __init__(self, config: "OracleSyncConfig") -> None: adk_config = config.extension_config.get("adk", {}) self._in_memory = bool(adk_config.get("in_memory", False)) + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + def _detect_json_storage_type(self) -> "JSONStorageType": if self._json_storage_type is not None: return self._json_storage_type @@ -1839,7 +2711,7 @@ def _deserialize_json_field(self, data: Any) -> "dict[str, Any] | None": return _extract_json_value(data) - async def _get_create_memory_table_sql(self) -> str: + def _get_create_memory_table_sql(self) -> str: storage_type = self._detect_json_storage_type() return self._get_create_memory_table_sql_for_type(storage_type) @@ -1965,17 +2837,13 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_memory_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_memory_table_sql()) def _execute_insert_entry(self, cursor: Any, sql: str, params: "dict[str, Any]") -> bool: """Execute an insert and skip duplicate key errors.""" try: cursor.execute(sql, params) - except oracledb.DatabaseError as exc: + except OracleDatabaseError as exc: error_obj = exc.args[0] if exc.args else None if error_obj and error_obj.code == ORACLE_DUPLICATE_KEY_ERROR: return False @@ -2029,10 +2897,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -2046,18 +2910,12 @@ def _search_entries( if self._use_fts: return self._search_entries_fts(query, app_name, user_id, effective_limit) return self._search_entries_simple(query, app_name, user_id, effective_limit) - except oracledb.DatabaseError as exc: + except OracleDatabaseError as exc: error_obj = exc.args[0] if exc.args else None if error_obj and error_obj.code == ORACLE_TABLE_NOT_FOUND_ERROR: return [] raise - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -2112,10 +2970,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: sql = f""" DELETE FROM {self._memory_table} @@ -2127,10 +2981,6 @@ def _delete_entries_older_than(self, days: int) -> int: conn.commit() return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": records: list[MemoryRecord] = [] for row in rows: @@ -2155,6 +3005,37 @@ def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": return records +def _configure_oracle_adk_session_tables(store: Any, config: Any) -> None: + """Apply Oracle clean-break ADK table names independent of shared-base drift.""" + + adk_config = _get_oracle_adk_config(config) + table_names = { + "_session_table": str(adk_config.get("session_table") or ORACLE_DEFAULT_SESSION_TABLE), + "_events_table": str(adk_config.get("events_table") or ORACLE_DEFAULT_EVENTS_TABLE), + "_app_state_table": str(adk_config.get("app_state_table") or ORACLE_DEFAULT_APP_STATE_TABLE), + "_user_state_table": str(adk_config.get("user_state_table") or ORACLE_DEFAULT_USER_STATE_TABLE), + "_metadata_table": str(adk_config.get("metadata_table") or ORACLE_DEFAULT_METADATA_TABLE), + } + for attribute_name, table_name in table_names.items(): + _validate_oracle_identifier(table_name, "table name") + setattr(store, attribute_name, table_name) + + +def _normalize_event_data_for_storage(event_data: Any) -> Any: + """Return event data without ADK 2.2-invalid durable ``actions: null``.""" + + if isinstance(event_data, dict) and event_data.get("actions") is None: + normalized = dict(event_data) + normalized.pop("actions", None) + return normalized + return event_data + + +def _raise_session_not_found(session_id: str) -> NoReturn: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) + + def _coerce_decimal_values(value: Any) -> Any: if isinstance(value, Decimal): return float(value) @@ -2210,13 +3091,22 @@ def _extract_json_value(data: Any) -> "dict[str, Any]": return from_json(str(data)) # type: ignore[no-any-return] -def _event_json_column_ddl(storage_type: JSONStorageType) -> str: - """Return the DDL fragment for the event_json column.""" +def _event_data_column_ddl(storage_type: JSONStorageType) -> str: + """Return the DDL fragment for the event_data column.""" + if storage_type == JSONStorageType.JSON_NATIVE: + return "event_data JSON NOT NULL" + if storage_type == JSONStorageType.BLOB_JSON: + return "event_data BLOB CHECK (event_data IS JSON) NOT NULL" + return "event_data BLOB NOT NULL" + + +def _json_column_ddl(column_name: str, storage_type: JSONStorageType) -> str: + """Return an Oracle JSON column DDL fragment for the configured storage type.""" if storage_type == JSONStorageType.JSON_NATIVE: - return "event_json JSON NOT NULL" + return f"{column_name} JSON NOT NULL" if storage_type == JSONStorageType.BLOB_JSON: - return "event_json BLOB CHECK (event_json IS JSON) NOT NULL" - return "event_json BLOB NOT NULL" + return f"{column_name} BLOB CHECK ({column_name} IS JSON) NOT NULL" + return f"{column_name} BLOB NOT NULL" def _get_oracle_adk_config(config: Any) -> dict[str, Any]: diff --git a/sqlspec/adapters/psqlpy/adk/store.py b/sqlspec/adapters/psqlpy/adk/store.py index 54d6de0e7..6c7f5768a 100644 --- a/sqlspec/adapters/psqlpy/adk/store.py +++ b/sqlspec/adapters/psqlpy/adk/store.py @@ -1,7 +1,7 @@ """Psqlpy ADK store for Google Agent Development Kit session/event storage.""" import re -from typing import TYPE_CHECKING, Any, Final +from typing import TYPE_CHECKING, Any, Final, NoReturn import psqlpy.exceptions @@ -11,7 +11,7 @@ from sqlspec.utils.type_guards import has_query_result_metadata if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.psqlpy.config import PsqlpyConfig from sqlspec.extensions.adk import MemoryRecord @@ -22,6 +22,11 @@ logger = get_logger("sqlspec.adapters.psqlpy.adk.store") POSTGRES_TABLE_NOT_FOUND_SQLSTATE: Final = "42P01" +PSQLPY_TABLE_MISSING_ERRORS: Final[tuple[type[Exception], ...]] = ( + psqlpy.exceptions.DatabaseError, + psqlpy.exceptions.ConnectionExecuteError, +) +PSQLPY_STATUS_REGEX: Final[re.Pattern[str]] = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE) class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]): @@ -29,12 +34,12 @@ class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via the high-performance Rust-based psqlpy driver. - Events are stored as a single JSONB blob (``event_json``) alongside + Events are stored as a single JSONB blob (``event_data``) alongside indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Full-fidelity event storage via ``event_json`` JSONB column + - Full-fidelity event storage via ``event_data`` JSONB column - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete @@ -50,54 +55,14 @@ class PsqlpyADKStore(BaseAsyncADKStore["PsqlpyConfig"]): def __init__(self, config: "PsqlpyConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -117,18 +82,32 @@ async def create_session( """ await conn.execute(sql, [session_id, app_name, user_id, state]) - return await self.get_session(session_id) # type: ignore[return-value] + res = await self.get_session(app_name, user_id, session_id) + if res is None: + msg = "Failed to retrieve created session." + raise RuntimeError(msg) + return res - async def get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = $1 - """ + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + sql = f""" + UPDATE {self._session_table} + SET update_time = CURRENT_TIMESTAMP + WHERE app_name = $1 AND user_id = $2 AND id = $3 + RETURNING id, app_name, user_id, state, create_time, update_time + """ + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {self._session_table} + WHERE app_name = $1 AND user_id = $2 AND id = $3 + """ try: async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - result = await conn.fetch(sql, [session_id]) + result = await conn.fetch(sql, [app_name, user_id, session_id]) rows: list[dict[str, Any]] = result.result() if result else [] if not rows: @@ -143,27 +122,20 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": create_time=row["create_time"], update_time=row["update_time"], ) - except psqlpy.exceptions.DatabaseError as e: - error_msg = str(e).lower() - if "does not exist" in error_msg or "relation" in error_msg: + except PSQLPY_TABLE_MISSING_ERRORS as e: + if _is_table_missing_error(e): return None raise - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 """ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - await conn.execute(sql, [state, session_id]) - - async def delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = $1" - - async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - await conn.execute(sql, [session_id]) + await conn.execute(sql, [state, app_name, user_id, session_id]) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -199,16 +171,21 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis ) for row in rows ] - except psqlpy.exceptions.DatabaseError as e: - error_msg = str(e).lower() - if "does not exist" in error_msg or "relation" in error_msg: + except PSQLPY_TABLE_MISSING_ERRORS as e: + if _is_table_missing_error(e): return [] raise + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {self._session_table} WHERE app_name = $1 AND user_id = $2 AND id = $3" + + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute(sql, [app_name, user_id, session_id]) + async def append_event(self, event_record: EventRecord) -> None: sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ @@ -216,46 +193,76 @@ async def append_event(self, event_record: EventRecord) -> None: await conn.execute( sql, [ + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], - event_record["event_json"], + event_record["event_data"], ], ) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_sql = f""" INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES ($1, $2, $3, $4, $5) """ update_sql = f""" UPDATE {self._session_table} SET state = $1, update_time = CURRENT_TIMESTAMP - WHERE id = $2 + WHERE app_name = $2 AND user_id = $3 AND id = $4 RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - await conn.execute( - insert_sql, - [ - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_record["event_json"], - ], - ) - result = await conn.fetch(update_sql, [state, session_id]) - rows: list[dict[str, Any]] = result.result() if result else [] - - if not rows: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) + try: + await conn.execute("BEGIN") + await conn.execute( + insert_sql, + [ + event_record["id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + event_record["event_data"], + ], + ) + result = await conn.fetch(update_sql, [state, app_name, user_id, session_id]) + rows: list[dict[str, Any]] = result.result() if result else [] + if not rows: + _raise_missing_session(session_id) + if app_state is not None: + await conn.execute(app_upsert_sql, [app_name, app_state]) + if user_state is not None: + await conn.execute(user_upsert_sql, [app_name, user_id, user_state]) + except Exception: + await conn.execute("ROLLBACK") + raise + await conn.execute("COMMIT") row = rows[0] return SessionRecord( @@ -268,25 +275,34 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = $1"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["s.app_name = $1", "s.user_id = $2", "e.session_id = $3"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append(f"timestamp > ${len(params) + 1}") + where_clauses.append(f"e.timestamp > ${len(params) + 1}") params.append(after_timestamp) where_clause = " AND ".join(where_clauses) - limit_clause = f" LIMIT ${len(params) + 1}" if limit else "" - if limit: + limit_clause = f" LIMIT ${len(params) + 1}" if limit is not None else "" + if limit is not None: params.append(limit) sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ try: @@ -296,72 +312,229 @@ async def get_events( return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] - except psqlpy.exceptions.DatabaseError as e: - error_msg = str(e).lower() - if "does not exist" in error_msg or "relation" in error_msg: + except PSQLPY_TABLE_MISSING_ERRORS as e: + if _is_table_missing_error(e): return [] raise + async def delete_expired_events(self, before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) AS count FROM {self._events_table} WHERE timestamp < $1" + delete_sql = f"DELETE FROM {self._events_table} WHERE timestamp < $1" -PSQLPY_STATUS_REGEX: Final[re.Pattern[str]] = re.compile(r"^([A-Z]+)(?:\s+(\d+))?\s+(\d+)$", re.IGNORECASE) + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + count_result = await conn.fetch(count_sql, [before]) + count_rows: list[dict[str, Any]] = count_result.result() if count_result else [] + count = int(count_rows[0]["count"]) if count_rows else 0 + await conn.execute(delete_sql, [before]) + return count + except PSQLPY_TABLE_MISSING_ERRORS as e: + if _is_table_missing_error(e): + return 0 + raise + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + count_sql = f"SELECT COUNT(*) AS count FROM {self._session_table} WHERE update_time < $1" + delete_sql = f"DELETE FROM {self._session_table} WHERE update_time < $1" -class PsqlpyADKMemoryStore(BaseAsyncADKMemoryStore["PsqlpyConfig"]): - """PostgreSQL ADK memory store using Psqlpy driver.""" + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + count_result = await conn.fetch(count_sql, [updated_before]) + count_rows: list[dict[str, Any]] = count_result.result() if count_result else [] + count = int(count_rows[0]["count"]) if count_rows else 0 + await conn.execute(delete_sql, [updated_before]) + return count + except PSQLPY_TABLE_MISSING_ERRORS as e: + if _is_table_missing_error(e): + return 0 + raise - __slots__ = () + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = $1" - def __init__(self, config: "PsqlpyConfig") -> None: - """Initialize Psqlpy memory store.""" - super().__init__(config) + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, [app_name]) + rows: list[dict[str, Any]] = result.result() if result else [] + return rows[0]["state"] if rows else None + except PSQLPY_TABLE_MISSING_ERRORS as e: + if _is_table_missing_error(e): + return None + raise - async def _get_create_memory_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for memory entries.""" + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._user_state_table} WHERE app_name = $1 AND user_id = $2" + + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, [app_name, user_id]) + rows: list[dict[str, Any]] = result.result() if result else [] + return rows[0]["state"] if rows else None + except PSQLPY_TABLE_MISSING_ERRORS as e: + if _is_table_missing_error(e): + return None + raise + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES ($1, $2, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute(sql, [app_name, state]) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES ($1, $2, $3, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """ + + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute(sql, [app_name, user_id, state]) + + async def get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = $1" + + try: + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, [key]) + rows: list[dict[str, Any]] = result.result() if result else [] + return rows[0]["value"] if rows else None + except PSQLPY_TABLE_MISSING_ERRORS as e: + if _is_table_missing_error(e): + return None + raise + + async def set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ($1, $2) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """ + + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + await conn.execute(sql, [key, value]) + + async def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( + CREATE TABLE IF NOT EXISTS {self._events_table} ( id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( app_name VARCHAR(128) NOT NULL, user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) WITH (fillfactor = 80); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING """ - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements.""" - return [f"DROP TABLE IF EXISTS {self._memory_table}"] + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class PsqlpyADKMemoryStore(BaseAsyncADKMemoryStore["PsqlpyConfig"]): + """PostgreSQL ADK memory store using Psqlpy driver.""" + + __slots__ = () + + def __init__(self, config: "PsqlpyConfig") -> None: + """Initialize Psqlpy memory store.""" + super().__init__(config) async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" @@ -465,42 +638,6 @@ async def search_entries( return [] raise - async def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": - sql = f""" - SELECT id, session_id, app_name, user_id, event_id, author, - timestamp, content_json, content_text, metadata_json, inserted_at, - ts_rank(to_tsvector('english', content_text), plainto_tsquery('english', $1)) as rank - FROM {self._memory_table} - WHERE app_name = $2 - AND user_id = $3 - AND to_tsvector('english', content_text) @@ plainto_tsquery('english', $1) - ORDER BY rank DESC, timestamp DESC - LIMIT $4 - """ - params = [query, app_name, user_id, limit] - async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - result = await conn.fetch(sql, params) - rows: list[dict[str, Any]] = result.result() if result else [] - return _rows_to_records(rows) - - async def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": - sql = f""" - SELECT id, session_id, app_name, user_id, event_id, author, - timestamp, content_json, content_text, metadata_json, inserted_at - FROM {self._memory_table} - WHERE app_name = $1 - AND user_id = $2 - AND content_text ILIKE $3 - ORDER BY timestamp DESC - LIMIT $4 - """ - pattern = f"%{query}%" - params = [app_name, user_id, pattern, limit] - async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] - result = await conn.fetch(sql, params) - rows: list[dict[str, Any]] = result.result() if result else [] - return _rows_to_records(rows) - async def delete_entries_by_session(self, session_id: str) -> int: """Delete all memory entries for a specific session.""" count_sql = f"SELECT COUNT(*) AS count FROM {self._memory_table} WHERE session_id = $1" @@ -543,6 +680,82 @@ async def delete_entries_older_than(self, days: int) -> int: return 0 raise + async def _get_create_memory_table_sql(self) -> str: + """Get PostgreSQL CREATE TABLE SQL for memory entries.""" + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get PostgreSQL DROP TABLE SQL statements.""" + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + + async def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": + sql = f""" + SELECT id, session_id, app_name, user_id, event_id, author, + timestamp, content_json, content_text, metadata_json, inserted_at, + ts_rank(to_tsvector('english', content_text), plainto_tsquery('english', $1)) as rank + FROM {self._memory_table} + WHERE app_name = $2 + AND user_id = $3 + AND to_tsvector('english', content_text) @@ plainto_tsquery('english', $1) + ORDER BY rank DESC, timestamp DESC + LIMIT $4 + """ + params = [query, app_name, user_id, limit] + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, params) + rows: list[dict[str, Any]] = result.result() if result else [] + return _rows_to_records(rows) + + async def _search_entries_simple(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": + sql = f""" + SELECT id, session_id, app_name, user_id, event_id, author, + timestamp, content_json, content_text, metadata_json, inserted_at + FROM {self._memory_table} + WHERE app_name = $1 + AND user_id = $2 + AND content_text ILIKE $3 + ORDER BY timestamp DESC + LIMIT $4 + """ + pattern = f"%{query}%" + params = [app_name, user_id, pattern, limit] + async with self._config.provide_connection() as conn: # pyright: ignore[reportAttributeAccessIssue] + result = await conn.fetch(sql, params) + rows: list[dict[str, Any]] = result.result() if result else [] + return _rows_to_records(rows) + def _extract_rows_affected(self, result: Any) -> int: """Extract rows affected from psqlpy result.""" try: @@ -589,3 +802,13 @@ def _rows_to_records(rows: "list[dict[str, Any]]") -> "list[MemoryRecord]": } for row in rows ] + + +def _is_table_missing_error(exc: Exception) -> bool: + error_msg = str(exc).lower() + return "does not exist" in error_msg or "relation" in error_msg + + +def _raise_missing_session(session_id: str) -> NoReturn: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) diff --git a/sqlspec/adapters/psycopg/adk/store.py b/sqlspec/adapters/psycopg/adk/store.py index 64677d83e..bd75af828 100644 --- a/sqlspec/adapters/psycopg/adk/store.py +++ b/sqlspec/adapters/psycopg/adk/store.py @@ -1,18 +1,18 @@ """Psycopg ADK store for Google Agent Development Kit session/event storage.""" -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NoReturn from psycopg import errors from psycopg import sql as pg_sql +from psycopg.rows import dict_row from psycopg.types.json import Jsonb -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseAsyncADKStore, BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore from sqlspec.utils.logging import get_logger -from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.psycopg.config import PsycopgAsyncConfig, PsycopgSyncConfig from sqlspec.extensions.adk import MemoryRecord @@ -28,12 +28,12 @@ class PsycopgAsyncADKStore(BaseAsyncADKStore["PsycopgAsyncConfig"]): Implements session and event storage for Google Agent Development Kit using PostgreSQL via psycopg3 with native async/await support. - Events are stored as a single JSONB blob (``event_json``) alongside + Events are stored as a single JSONB blob (``event_data``) alongside indexed scalar columns for efficient querying. Provides: - Session state management with JSONB storage - - Full-fidelity event storage via ``event_json`` JSONB column + - Full-fidelity event storage via ``event_data`` JSONB column - Atomic ``append_event_and_update_state`` for durable session mutations - Microsecond-precision timestamps with TIMESTAMPTZ - Foreign key constraints with cascade delete @@ -49,54 +49,14 @@ class PsycopgAsyncADKStore(BaseAsyncADKStore["PsycopgAsyncConfig"]): def __init__(self, config: "PsycopgAsyncConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL{owner_id_line}, - state JSONB NOT NULL DEFAULT '{{}}'::jsonb, - create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state - ON {self._session_table} USING GIN (state) - WHERE state != '{{}}'::jsonb; - """ - - async def _get_create_events_table_sql(self) -> str: - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, - timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ) WITH (fillfactor = 80); - - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_sessions_table_sql()) await driver.execute_script(await self._get_create_events_table_sql()) + await driver.execute_script(await self._get_create_app_states_table_sql()) + await driver.execute_script(await self._get_create_user_states_table_sql()) + await driver.execute_script(await self._get_create_metadata_table_sql()) + await driver.execute_script(await self._get_seed_metadata_sql()) async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -117,21 +77,37 @@ async def create_session( """).format(table=pg_sql.Identifier(self._session_table)) params = (session_id, app_name, user_id, Jsonb(state)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, params) - return await self.get_session(session_id) # type: ignore[return-value] + result = await self.get_session(app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result - async def get_session(self, session_id: str) -> "SessionRecord | None": - query = pg_sql.SQL(""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {table} - WHERE id = %s - """).format(table=pg_sql.Identifier(self._session_table)) + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + query = pg_sql.SQL(""" + UPDATE {table} + SET update_time = CURRENT_TIMESTAMP + WHERE app_name = %s AND user_id = %s AND id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """).format(table=pg_sql.Identifier(self._session_table)) + params = (app_name, user_id, session_id) + else: + query = pg_sql.SQL(""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {table} + WHERE app_name = %s AND user_id = %s AND id = %s + """).format(table=pg_sql.Identifier(self._session_table)) + params = (app_name, user_id, session_id) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(query, (session_id,)) + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, params) row = await cur.fetchone() if row is None: @@ -148,21 +124,15 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """).format(table=pg_sql.Identifier(self._session_table)) - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(query, (Jsonb(state), session_id)) - - async def delete_session(self, session_id: str) -> None: - query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) - - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(query, (session_id,)) + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (Jsonb(state), app_name, user_id, session_id)) async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -183,7 +153,7 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis params = (app_name, user_id) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, params) rows = await cur.fetchall() @@ -201,66 +171,104 @@ async def list_sessions(self, app_name: str, user_id: str | None = None) -> "lis except errors.UndefinedTable: return [] + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + query = pg_sql.SQL("DELETE FROM {table} WHERE app_name = %s AND user_id = %s AND id = %s").format( + table=pg_sql.Identifier(self._session_table) + ) + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (app_name, user_id, session_id)) + async def append_event(self, event_record: EventRecord) -> None: query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute( query, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), ) async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) update_query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute( - insert_query, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - jsonb_value, - ), - ) - await cur.execute(update_query, (Jsonb(state), session_id)) - row = await cur.fetchone() + app_upsert_query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._app_state_table)) + + user_upsert_query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._user_state_table)) + + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + try: + await cur.execute( + insert_query, + ( + event_record["id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + jsonb_value, + ), + ) + await cur.execute(update_query, (Jsonb(state), app_name, user_id, session_id)) + row = await cur.fetchone() + if row is None: + _raise_missing_session(session_id) + if app_state is not None: + await cur.execute(app_upsert_query, (app_name, Jsonb(app_state))) + if user_state is not None: + await cur.execute(user_upsert_query, (app_name, user_id, Jsonb(user_state))) + except Exception: + await conn.rollback() + raise await conn.commit() - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) - return SessionRecord( id=row["id"], app_name=row["app_name"], @@ -271,78 +279,356 @@ async def append_event_and_update_state( ) async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = [pg_sql.SQL("s.app_name = %s"), pg_sql.SQL("s.user_id = %s"), pg_sql.SQL("e.session_id = %s")] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append(pg_sql.SQL("e.timestamp > %s")) params.append(after_timestamp) - where_clause = " AND ".join(where_clauses) - if limit: + where_clause = pg_sql.SQL(" AND ").join(where_clauses) + if limit is not None: params.append(limit) query = pg_sql.SQL( """ - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {events_table} e + JOIN {session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ ).format( - table=pg_sql.Identifier(self._events_table), - where_clause=pg_sql.SQL(where_clause), # pyright: ignore[reportArgumentType] - limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), # pyright: ignore[reportArgumentType] + events_table=pg_sql.Identifier(self._events_table), + session_table=pg_sql.Identifier(self._session_table), + where_clause=where_clause, + limit_clause=pg_sql.SQL(" LIMIT %s" if limit is not None else ""), ) try: - async with self._config.provide_connection() as conn, conn.cursor() as cur: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: await cur.execute(query, tuple(params)) rows = await cur.fetchall() return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] except errors.UndefinedTable: return [] + async def delete_expired_events(self, before: "datetime") -> int: + query = pg_sql.SQL("DELETE FROM {table} WHERE timestamp < %s").format( + table=pg_sql.Identifier(self._events_table) + ) -class PsycopgSyncADKStore(BaseAsyncADKStore["PsycopgSyncConfig"]): - """PostgreSQL synchronous ADK store using Psycopg3 driver. + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (before,)) + await conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 - Implements session and event storage for Google Agent Development Kit - using PostgreSQL via psycopg3 with synchronous execution. - Events are stored as a single JSONB blob (``event_json``) alongside - indexed scalar columns for efficient querying. + async def delete_idle_sessions(self, updated_before: "datetime") -> int: + query = pg_sql.SQL("DELETE FROM {table} WHERE update_time < %s").format( + table=pg_sql.Identifier(self._session_table) + ) - Provides: - - Session state management with JSONB storage - - Full-fidelity event storage via ``event_json`` JSONB column - - Atomic ``create_event_and_update_state`` for durable session mutations - - Microsecond-precision timestamps with TIMESTAMPTZ - - Foreign key constraints with cascade delete - - GIN indexes for JSONB queries - - HOT updates with FILLFACTOR 80 + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (updated_before,)) + await conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 - Args: - config: PsycopgSyncConfig with extension_config["adk"] settings. - """ + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + query = pg_sql.SQL("SELECT state FROM {table} WHERE app_name = %s").format( + table=pg_sql.Identifier(self._app_state_table) + ) + + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (app_name,)) + row = await cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + query = pg_sql.SQL("SELECT state FROM {table} WHERE app_name = %s AND user_id = %s").format( + table=pg_sql.Identifier(self._user_state_table) + ) + + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (app_name, user_id)) + row = await cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._app_state_table)) + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (app_name, Jsonb(state))) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._user_state_table)) + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (app_name, user_id, Jsonb(state))) + + async def get_metadata(self, key: str) -> "str | None": + query = pg_sql.SQL("SELECT value FROM {table} WHERE key = %s").format( + table=pg_sql.Identifier(self._metadata_table) + ) + + try: + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (key,)) + row = await cur.fetchone() + return row["value"] if row is not None else None + except errors.UndefinedTable: + return None + + async def set_metadata(self, key: str, value: str) -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (key, value) + VALUES (%s, %s) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """).format(table=pg_sql.Identifier(self._metadata_table)) + + async with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + await cur.execute(query, (key, value)) + + async def _get_create_sessions_table_sql(self) -> str: + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL{owner_id_line}, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + create_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_state + ON {self._session_table} USING GIN (state) + WHERE state != '{{}}'::jsonb; + """ + + async def _get_create_events_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256), + timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + event_data JSONB NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ) WITH (fillfactor = 80); + + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(session_id, timestamp ASC); + """ + + async def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + """ + + async def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) WITH (fillfactor = 80); + """ + + async def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + async def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> "list[str]": + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class PsycopgSyncADKStore(BaseSyncADKStore["PsycopgSyncConfig"]): + """PostgreSQL synchronous ADK store using Psycopg3 driver.""" __slots__ = () def __init__(self, config: "PsycopgSyncConfig") -> None: super().__init__(config) - async def _get_create_sessions_table_sql(self) -> str: + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return self._create_session(session_id, app_name, user_id, state, owner_id) + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return self._get_session(app_name, user_id, session_id, renew_for=renew_for) + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + self._update_session_state(app_name, user_id, session_id, state) + + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return self._list_sessions(app_name, user_id) + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and associated events.""" + self._delete_session(app_name, user_id, session_id) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + self._append_event(event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update session + scoped state.""" + return self._append_event_and_update_state( + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session.""" + return self._get_events(app_name, user_id, session_id, after_timestamp, limit) + + def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return self._delete_expired_events(before) + + def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the given threshold.""" + return self._delete_idle_sessions(updated_before) + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return self._get_app_state(app_name) + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return self._get_user_state(app_name, user_id) + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + self._upsert_app_state(app_name, state) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + self._upsert_user_state(app_name, user_id, state) + + def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return self._get_metadata(key) + + def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + self._set_metadata(key, value) + + def _get_create_sessions_table_sql(self) -> str: owner_id_line = "" if self._owner_id_column_ddl: owner_id_line = f",\n {self._owner_id_column_ddl}" @@ -368,14 +654,14 @@ async def _get_create_sessions_table_sql(self) -> str: WHERE state != '{{}}'::jsonb; """ - async def _get_create_events_table_sql(self) -> str: + def _get_create_events_table_sql(self) -> str: return f""" CREATE TABLE IF NOT EXISTS {self._events_table} ( + id VARCHAR(128) PRIMARY KEY, session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(256) NOT NULL, + invocation_id VARCHAR(256), timestamp TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, - event_json JSONB NOT NULL, + event_data JSONB NOT NULL, FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE ) WITH (fillfactor = 80); @@ -383,17 +669,67 @@ async def _get_create_events_table_sql(self) -> str: ON {self._events_table}(session_id, timestamp ASC); """ + def _get_create_app_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ) WITH (fillfactor = 80); + """ + + def _get_create_user_states_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSONB NOT NULL DEFAULT '{{}}'::jsonb, + update_time TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (app_name, user_id) + ) WITH (fillfactor = 80); + """ + + def _get_create_metadata_table_sql(self) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ); + """ + + def _get_seed_metadata_sql(self) -> str: + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT (key) DO NOTHING + """ + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] def _create_tables(self) -> None: with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_sessions_table_sql)()) - driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_sessions_table_sql()) + driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(self._get_create_app_states_table_sql()) + driver.execute_script(self._get_create_user_states_table_sql()) + driver.execute_script(self._get_create_metadata_table_sql()) + driver.execute_script(self._get_seed_metadata_sql()) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -414,31 +750,37 @@ def _create_session( """).format(table=pg_sql.Identifier(self._session_table)) params = (session_id, app_name, user_id, Jsonb(state)) - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, params) - result = self._get_session(session_id) + result = self._get_session(app_name, user_id, session_id) if result is None: msg = "Failed to fetch created session" raise RuntimeError(msg) return result - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": - query = pg_sql.SQL(""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {table} - WHERE id = %s - """).format(table=pg_sql.Identifier(self._session_table)) + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + query = pg_sql.SQL(""" + UPDATE {table} + SET update_time = CURRENT_TIMESTAMP + WHERE app_name = %s AND user_id = %s AND id = %s + RETURNING id, app_name, user_id, state, create_time, update_time + """).format(table=pg_sql.Identifier(self._session_table)) + params = (app_name, user_id, session_id) + else: + query = pg_sql.SQL(""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {table} + WHERE app_name = %s AND user_id = %s AND id = %s + """).format(table=pg_sql.Identifier(self._session_table)) + params = (app_name, user_id, session_id) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(query, (session_id,)) + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, params) row = cur.fetchone() if row is None: @@ -455,33 +797,23 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": except errors.UndefinedTable: return None - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s """).format(table=pg_sql.Identifier(self._session_table)) - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(query, (Jsonb(state), session_id)) - - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) - - def _delete_session(self, session_id: str) -> None: - query = pg_sql.SQL("DELETE FROM {table} WHERE id = %s").format(table=pg_sql.Identifier(self._session_table)) + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (Jsonb(state), app_name, user_id, session_id)) - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute(query, (session_id,)) + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + query = pg_sql.SQL("DELETE FROM {table} WHERE app_name = %s AND user_id = %s AND id = %s").format( + table=pg_sql.Identifier(self._session_table) + ) - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (app_name, user_id, session_id)) def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": if user_id is None: @@ -502,7 +834,7 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses params = (app_name, user_id) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, params) rows = cur.fetchall() @@ -520,27 +852,23 @@ def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[Ses except errors.UndefinedTable: return [] - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - def _insert_event(self, event_record: EventRecord) -> None: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute( insert_query, ( + event_record["id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], event_record["timestamp"], jsonb_value, ), @@ -548,43 +876,73 @@ def _insert_event(self, event_record: EventRecord) -> None: conn.commit() def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: insert_query = pg_sql.SQL(""" INSERT INTO {table} ( - session_id, invocation_id, author, timestamp, event_json + id, session_id, invocation_id, timestamp, event_data ) VALUES (%s, %s, %s, %s, %s) """).format(table=pg_sql.Identifier(self._events_table)) update_query = pg_sql.SQL(""" UPDATE {table} SET state = %s, update_time = CURRENT_TIMESTAMP - WHERE id = %s + WHERE app_name = %s AND user_id = %s AND id = %s RETURNING id, app_name, user_id, state, create_time, update_time """).format(table=pg_sql.Identifier(self._session_table)) - event_json_value = event_record["event_json"] - jsonb_value = Jsonb(event_json_value) if isinstance(event_json_value, dict) else event_json_value - - with self._config.provide_connection() as conn, conn.cursor() as cur: - cur.execute( - insert_query, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - jsonb_value, - ), - ) - cur.execute(update_query, (Jsonb(state), session_id)) - row = cur.fetchone() + app_upsert_query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._app_state_table)) + + user_upsert_query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._user_state_table)) + + event_data_value = event_record["event_data"] + jsonb_value = Jsonb(event_data_value) if isinstance(event_data_value, dict) else event_data_value + + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + try: + cur.execute( + insert_query, + ( + event_record["id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + jsonb_value, + ), + ) + cur.execute(update_query, (Jsonb(state), app_name, user_id, session_id)) + row = cur.fetchone() + if row is None: + _raise_missing_session(session_id) + if app_state is not None: + cur.execute(app_upsert_query, (app_name, Jsonb(app_state))) + if user_state is not None: + cur.execute(user_upsert_query, (app_name, user_id, Jsonb(user_state))) + except Exception: + conn.rollback() + raise conn.commit() - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) - return SessionRecord( id=row["id"], app_name=row["app_name"], @@ -594,71 +952,169 @@ def _append_event_and_update_state( update_time=row["update_time"], ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = [pg_sql.SQL("s.app_name = %s"), pg_sql.SQL("s.user_id = %s"), pg_sql.SQL("e.session_id = %s")] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: - where_clauses.append("timestamp > %s") + where_clauses.append(pg_sql.SQL("e.timestamp > %s")) params.append(after_timestamp) - where_clause = " AND ".join(where_clauses) - if limit: + where_clause = pg_sql.SQL(" AND ").join(where_clauses) + if limit is not None: params.append(limit) query = pg_sql.SQL( """ - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {table} + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {events_table} e + JOIN {session_table} s ON e.session_id = s.id WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} + ORDER BY e.timestamp ASC{limit_clause} """ ).format( - table=pg_sql.Identifier(self._events_table), - where_clause=pg_sql.SQL(where_clause), # pyright: ignore[reportArgumentType] - limit_clause=pg_sql.SQL(" LIMIT %s" if limit else ""), # pyright: ignore[reportArgumentType] + events_table=pg_sql.Identifier(self._events_table), + session_table=pg_sql.Identifier(self._session_table), + where_clause=where_clause, + limit_clause=pg_sql.SQL(" LIMIT %s" if limit is not None else ""), ) try: - with self._config.provide_connection() as conn, conn.cursor() as cur: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: cur.execute(query, tuple(params)) rows = cur.fetchall() return [ EventRecord( + id=row["id"], session_id=row["session_id"], invocation_id=row["invocation_id"], - author=row["author"], timestamp=row["timestamp"], - event_json=row["event_json"], + event_data=row["event_data"], + app_name=row["app_name"], + user_id=row["user_id"], ) for row in rows ] except errors.UndefinedTable: return [] - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + def _delete_expired_events(self, before: "datetime") -> int: + query = pg_sql.SQL("DELETE FROM {table} WHERE timestamp < %s").format( + table=pg_sql.Identifier(self._events_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (before,)) + conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + def _delete_idle_sessions(self, updated_before: "datetime") -> int: + query = pg_sql.SQL("DELETE FROM {table} WHERE update_time < %s").format( + table=pg_sql.Identifier(self._session_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (updated_before,)) + conn.commit() + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + except errors.UndefinedTable: + return 0 + + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + query = pg_sql.SQL("SELECT state FROM {table} WHERE app_name = %s").format( + table=pg_sql.Identifier(self._app_state_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (app_name,)) + row = cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + query = pg_sql.SQL("SELECT state FROM {table} WHERE app_name = %s AND user_id = %s").format( + table=pg_sql.Identifier(self._user_state_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (app_name, user_id)) + row = cur.fetchone() + return row["state"] if row is not None else None + except errors.UndefinedTable: + return None + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, state, update_time) + VALUES (%s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._app_state_table)) + + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (app_name, Jsonb(state))) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, CURRENT_TIMESTAMP) + ON CONFLICT (app_name, user_id) DO UPDATE SET + state = EXCLUDED.state, + update_time = CURRENT_TIMESTAMP + """).format(table=pg_sql.Identifier(self._user_state_table)) + + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (app_name, user_id, Jsonb(state))) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + query = pg_sql.SQL("SELECT value FROM {table} WHERE key = %s").format( + table=pg_sql.Identifier(self._metadata_table) + ) + + try: + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (key,)) + row = cur.fetchone() + return row["value"] if row is not None else None + except errors.UndefinedTable: + return None + + def _set_metadata(self, key: str, value: str) -> None: + query = pg_sql.SQL(""" + INSERT INTO {table} (key, value) + VALUES (%s, %s) + ON CONFLICT (key) DO UPDATE SET value = EXCLUDED.value + """).format(table=pg_sql.Identifier(self._metadata_table)) + + with self._config.provide_connection() as conn, conn.cursor(row_factory=dict_row) as cur: + cur.execute(query, (key, value)) + conn.commit() def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) - class PsycopgAsyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgAsyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 async driver.""" @@ -669,46 +1125,6 @@ def __init__(self, config: "PsycopgAsyncConfig") -> None: """Initialize Psycopg async memory store.""" super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: - """Get PostgreSQL CREATE TABLE SQL for memory entries.""" - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" - - fts_index = "" - if self._use_fts: - fts_index = f""" - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts - ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); - """ - - return f""" - CREATE TABLE IF NOT EXISTS {self._memory_table} ( - id VARCHAR(128) PRIMARY KEY, - session_id VARCHAR(128) NOT NULL, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - event_id VARCHAR(128) NOT NULL UNIQUE, - author VARCHAR(256){owner_id_line}, - timestamp TIMESTAMPTZ NOT NULL, - content_json JSONB NOT NULL, - content_text TEXT NOT NULL, - metadata_json JSONB, - inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP - ); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time - ON {self._memory_table}(app_name, user_id, timestamp DESC); - - CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session - ON {self._memory_table}(session_id); - {fts_index} - """ - - def _get_drop_memory_table_sql(self) -> "list[str]": - """Get PostgreSQL DROP TABLE SQL statements.""" - return [f"DROP TABLE IF EXISTS {self._memory_table}"] - async def create_tables(self) -> None: """Create the memory table and indexes if they don't exist.""" if not self._enabled: @@ -782,6 +1198,69 @@ async def search_entries( except errors.UndefinedTable: return [] + async def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + sql = pg_sql.SQL("DELETE FROM {table} WHERE session_id = %s").format( + table=pg_sql.Identifier(self._memory_table) + ) + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql, (session_id,)) + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + + async def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + sql = pg_sql.SQL( + """ + DELETE FROM {table} + WHERE inserted_at < CURRENT_TIMESTAMP - {interval}::interval + """ + ).format(table=pg_sql.Identifier(self._memory_table), interval=pg_sql.Literal(f"{days} days")) + + async with self._config.provide_connection() as conn, conn.cursor() as cur: + await cur.execute(sql) + return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 + + async def _get_create_memory_table_sql(self) -> str: + """Get PostgreSQL CREATE TABLE SQL for memory entries.""" + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + fts_index = "" + if self._use_fts: + fts_index = f""" + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_fts + ON {self._memory_table} USING GIN (to_tsvector('english', content_text)); + """ + + return f""" + CREATE TABLE IF NOT EXISTS {self._memory_table} ( + id VARCHAR(128) PRIMARY KEY, + session_id VARCHAR(128) NOT NULL, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + event_id VARCHAR(128) NOT NULL UNIQUE, + author VARCHAR(256){owner_id_line}, + timestamp TIMESTAMPTZ NOT NULL, + content_json JSONB NOT NULL, + content_text TEXT NOT NULL, + metadata_json JSONB, + inserted_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP + ); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_app_user_time + ON {self._memory_table}(app_name, user_id, timestamp DESC); + + CREATE INDEX IF NOT EXISTS idx_{self._memory_table}_session + ON {self._memory_table}(session_id); + {fts_index} + """ + + def _get_drop_memory_table_sql(self) -> "list[str]": + """Get PostgreSQL DROP TABLE SQL statements.""" + return [f"DROP TABLE IF EXISTS {self._memory_table}"] + async def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = pg_sql.SQL( """ @@ -822,31 +1301,8 @@ async def _search_entries_simple(self, query: str, app_name: str, user_id: str, rows = await cur.fetchall() return _rows_to_records(rows) - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - sql = pg_sql.SQL("DELETE FROM {table} WHERE session_id = %s").format( - table=pg_sql.Identifier(self._memory_table) - ) - - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql, (session_id,)) - return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - sql = pg_sql.SQL( - """ - DELETE FROM {table} - WHERE inserted_at < CURRENT_TIMESTAMP - {interval}::interval - """ - ).format(table=pg_sql.Identifier(self._memory_table), interval=pg_sql.Literal(f"{days} days")) - - async with self._config.provide_connection() as conn, conn.cursor() as cur: - await cur.execute(sql) - return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - -class PsycopgSyncADKMemoryStore(BaseAsyncADKMemoryStore["PsycopgSyncConfig"]): +class PsycopgSyncADKMemoryStore(BaseSyncADKMemoryStore["PsycopgSyncConfig"]): """PostgreSQL ADK memory store using Psycopg3 sync driver.""" __slots__ = () @@ -855,7 +1311,29 @@ def __init__(self, config: "PsycopgSyncConfig") -> None: """Initialize Psycopg sync memory store.""" super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + + def _get_create_memory_table_sql(self) -> str: """Get PostgreSQL CREATE TABLE SQL for memory entries.""" owner_id_line = "" if self._owner_id_column_ddl: @@ -901,11 +1379,7 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_memory_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_memory_table_sql()) def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication.""" @@ -952,10 +1426,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -976,12 +1446,6 @@ def _search_entries( except errors.UndefinedTable: return [] - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = pg_sql.SQL( """ @@ -1032,10 +1496,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: cur.execute(sql, (session_id,)) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days.""" sql = pg_sql.SQL( @@ -1049,10 +1509,6 @@ def _delete_entries_older_than(self, days: int) -> int: cur.execute(sql) return cur.rowcount if cur.rowcount and cur.rowcount > 0 else 0 - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _build_insert_params(entry: "MemoryRecord") -> "tuple[object, ...]": return ( @@ -1104,3 +1560,8 @@ def _rows_to_records(rows: "list[Any]") -> "list[MemoryRecord]": } for row in rows ] + + +def _raise_missing_session(session_id: str) -> NoReturn: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) diff --git a/sqlspec/adapters/pymysql/adk/store.py b/sqlspec/adapters/pymysql/adk/store.py index aad9cd654..bac63c38e 100644 --- a/sqlspec/adapters/pymysql/adk/store.py +++ b/sqlspec/adapters/pymysql/adk/store.py @@ -5,13 +5,12 @@ import pymysql -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: - from datetime import datetime + from datetime import datetime, timedelta from sqlspec.adapters.pymysql.config import PyMysqlConfig from sqlspec.extensions.adk import MemoryRecord @@ -22,413 +21,153 @@ MYSQL_TABLE_NOT_FOUND_ERROR: Final = 1146 -class PyMysqlADKStore(BaseAsyncADKStore["PyMysqlConfig"]): - """MySQL/MariaDB ADK store using PyMySQL. - - Implements session and event storage for Google Agent Development Kit - using MySQL/MariaDB via the PyMySQL sync driver. Provides: - - Session state management with JSON storage - - Full-event JSON storage (single ``event_json`` column) - - Atomic event-create + state-update in one transaction - - Microsecond-precision timestamps - - Foreign key constraints with cascade delete - """ +class PyMysqlADKStore(BaseSyncADKStore["PyMysqlConfig"]): + """MySQL/MariaDB ADK store using PyMySQL.""" __slots__ = () def __init__(self, config: "PyMysqlConfig") -> None: super().__init__(config) - def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]": - return _parse_owner_id_column_for_mysql(column_ddl) - - async def _get_create_sessions_table_sql(self) -> str: - owner_id_col = "" - fk_constraint = "" - - if self._owner_id_column_ddl: - col_def, fk_def = self._parse_owner_id_column_for_mysql(self._owner_id_column_ddl) - owner_id_col = f"{col_def}," - if fk_def: - fk_constraint = f",\n {fk_def}" - - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id VARCHAR(128) PRIMARY KEY, - app_name VARCHAR(128) NOT NULL, - user_id VARCHAR(128) NOT NULL, - {owner_id_col} - state JSON NOT NULL, - create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), - INDEX idx_{self._session_table}_app_user (app_name, user_id), - INDEX idx_{self._session_table}_update_time (update_time DESC){fk_constraint} - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - async def _get_create_events_table_sql(self) -> str: - """Get MySQL CREATE TABLE SQL for events. - - Post clean-break schema: 5 columns only. - """ - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - session_id VARCHAR(128) NOT NULL, - invocation_id VARCHAR(256) NOT NULL, - author VARCHAR(128) NOT NULL, - timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), - event_json JSON NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE, - INDEX idx_{self._events_table}_session (session_id, timestamp ASC) - ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci - """ - - def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] - - def _create_tables(self) -> None: - with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_sessions_table_sql)()) - driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - - def _create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - state_json = to_json(state) - - params: tuple[Any, ...] - if self._owner_id_column_name: - sql = f""" - INSERT INTO {self._session_table} (id, app_name, user_id, {self._owner_id_column_name}, state, create_time, update_time) - VALUES (%s, %s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) - """ - params = (session_id, app_name, user_id, owner_id, state_json) - else: - sql = f""" - INSERT INTO {self._session_table} (id, app_name, user_id, state, create_time, update_time) - VALUES (%s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) - """ - params = (session_id, app_name, user_id, state_json) - - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute(sql, params) - finally: - cursor.close() - conn.commit() - - result = self._get_session(session_id) - if result is None: - msg = "Failed to fetch created session" - raise RuntimeError(msg) - return result + def create_tables(self) -> None: + """Create all ADK session tables if they don't exist.""" + _pymysql_create_tables(self) - async def create_session( + def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> SessionRecord: """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ - - try: - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute(sql, (session_id,)) - row = cursor.fetchone() - finally: - cursor.close() - - if row is None: - return None - - session_id_val, app_name, user_id, state_json, create_time, update_time = row - - return SessionRecord( - id=session_id_val, - app_name=app_name, - user_id=user_id, - state=from_json(state_json) if isinstance(state_json, str) else state_json, - create_time=create_time, - update_time=update_time, - ) - except pymysql.MySQLError as exc: - if "doesn't exist" in str(exc) or getattr(exc, "args", [None])[0] == MYSQL_TABLE_NOT_FOUND_ERROR: - return None - raise - - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - state_json = to_json(state) - - sql = f""" - UPDATE {self._session_table} - SET state = %s - WHERE id = %s - """ + return _pymysql_create_session(self, session_id, app_name, user_id, state, owner_id) - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute(sql, (state_json, session_id)) - finally: - cursor.close() - conn.commit() + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by scoped identifiers.""" + return _pymysql_get_session(self, app_name, user_id, session_id, renew_for=renew_for) - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state.""" - await async_(self._update_session_state)(session_id, state) + _pymysql_update_session_state(self, app_name, user_id, session_id, state) - def _delete_session(self, session_id: str) -> None: - sql = f"DELETE FROM {self._session_table} WHERE id = %s" - - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute(sql, (session_id,)) - finally: - cursor.close() - conn.commit() + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return _pymysql_list_sessions(self, app_name, user_id) - async def delete_session(self, session_id: str) -> None: + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - - def _list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - if user_id is None: - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE app_name = %s - ORDER BY update_time DESC - """ - params: tuple[str, ...] = (app_name,) - else: - sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE app_name = %s AND user_id = %s - ORDER BY update_time DESC - """ - params = (app_name, user_id) - - try: - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute(sql, params) - rows = cursor.fetchall() - finally: - cursor.close() - - return [ - SessionRecord( - id=row[0], - app_name=row[1], - user_id=row[2], - state=from_json(row[3]) if isinstance(row[3], str) else row[3], - create_time=row[4], - update_time=row[5], - ) - for row in rows - ] - except pymysql.MySQLError as exc: - if "doesn't exist" in str(exc) or getattr(exc, "args", [None])[0] == MYSQL_TABLE_NOT_FOUND_ERROR: - return [] - raise - - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) + _pymysql_delete_session(self, app_name, user_id, session_id) - def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + _pymysql_append_event(self, event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: - """Atomically create an event and update the session's durable state. + """Atomically append an event and update session + scoped state.""" + return _pymysql_append_event_and_update_state( + self, event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) - MySQL doesn't support UPDATE...RETURNING; the UPDATE is followed by a - SELECT inside the same transaction so callers get the refreshed row - without acquiring a second connection. + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session.""" + return _pymysql_get_events(self, app_name, user_id, session_id, after_timestamp, limit) - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot. - """ - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json - state_json = to_json(state) - - insert_sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) - """ + def delete_expired_events(self, before: "datetime") -> int: + """Delete events older than the given timestamp.""" + return _pymysql_delete_expired_events(self, before) - update_sql = f""" - UPDATE {self._session_table} - SET state = %s - WHERE id = %s - """ + def delete_idle_sessions(self, updated_before: "datetime") -> int: + """Delete sessions whose update_time predates the threshold.""" + return _pymysql_delete_idle_sessions(self, updated_before) - select_sql = f""" - SELECT id, app_name, user_id, state, create_time, update_time - FROM {self._session_table} - WHERE id = %s - """ + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return _pymysql_get_app_state(self, app_name) - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute( - insert_sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) - cursor.execute(update_sql, (state_json, session_id)) - cursor.execute(select_sql, (session_id,)) - row = cursor.fetchone() - finally: - cursor.close() - conn.commit() + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return _pymysql_get_user_state(self, app_name, user_id) - if row is None: - msg = f"Session {session_id} not found during append_event_and_update_state." - raise ValueError(msg) - - state_value = row[3] - return SessionRecord( - id=row[0], - app_name=row[1], - user_id=row[2], - state=from_json(state_value) if isinstance(state_value, str) else state_value, - create_time=row[4], - update_time=row[5], - ) + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + _pymysql_upsert_app_state(self, app_name, state) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + _pymysql_upsert_user_state(self, app_name, user_id, state) - def _insert_event(self, event_record: EventRecord) -> None: - event_json = event_record["event_json"] - event_json_str = to_json(event_json) if not isinstance(event_json, str) else event_json + def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return _pymysql_get_metadata(self, key) - sql = f""" - INSERT INTO {self._events_table} ( - session_id, invocation_id, author, timestamp, event_json - ) VALUES (%s, %s, %s, %s, %s) - """ + def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + _pymysql_set_metadata(self, key, value) - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute( - sql, - ( - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - event_record["timestamp"], - event_json_str, - ), - ) - finally: - cursor.close() - conn.commit() + def _get_create_sessions_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for sessions.""" + return _mysql_sessions_ddl(self._session_table, self._owner_id_column_ddl) - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """List events for a session ordered by timestamp. + def _get_create_events_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for events.""" + return _mysql_events_ddl(self._events_table, self._session_table) - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. + def _get_create_app_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for app-scoped state.""" + return _mysql_app_state_ddl(self._app_state_table) - Returns: - List of event records ordered by timestamp ASC. - """ - where_clauses = ["session_id = %s"] - params: list[Any] = [session_id] + def _get_create_user_states_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for user-scoped state.""" + return _mysql_user_state_ddl(self._user_state_table) - if after_timestamp is not None: - where_clauses.append("timestamp > %s") - params.append(after_timestamp) + def _get_create_metadata_table_sql(self) -> str: + """Get MySQL CREATE TABLE SQL for ADK metadata.""" + return _mysql_metadata_ddl(self._metadata_table) - where_clause = " AND ".join(where_clauses) - limit_clause = " LIMIT %s" if limit else "" - sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} - WHERE {where_clause} - ORDER BY timestamp ASC{limit_clause} - """ - if limit: - params.append(limit) + def _get_seed_metadata_sql(self) -> str: + """Get MySQL metadata seed SQL.""" + return f"INSERT IGNORE INTO {self._metadata_table} (`key`, value) VALUES ('schema_version', '1')" - try: - with self._config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute(sql, tuple(params)) - rows = cursor.fetchall() - finally: - cursor.close() - - return [ - EventRecord( - session_id=row[0], - invocation_id=row[1], - author=row[2], - timestamp=row[3], - event_json=from_json(row[4]) if isinstance(row[4], str) else row[4], - ) - for row in rows - ] - except pymysql.MySQLError as exc: - if "doesn't exist" in str(exc) or getattr(exc, "args", [None])[0] == MYSQL_TABLE_NOT_FOUND_ERROR: - return [] - raise + def _get_drop_app_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) + def _get_drop_user_states_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" - def _append_event(self, event_record: EventRecord) -> None: - """Synchronous implementation of append_event.""" - self._insert_event(event_record) + def _get_drop_metadata_table_sql(self) -> str: + """Get MySQL DROP TABLE SQL for ADK metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) + def _get_drop_tables_sql(self) -> "list[str]": + """Get MySQL DROP TABLE SQL statements.""" + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] -class PyMysqlADKMemoryStore(BaseAsyncADKMemoryStore["PyMysqlConfig"]): +class PyMysqlADKMemoryStore(BaseSyncADKMemoryStore["PyMysqlConfig"]): """MySQL/MariaDB ADK memory store using PyMySQL.""" __slots__ = () @@ -436,7 +175,29 @@ class PyMysqlADKMemoryStore(BaseAsyncADKMemoryStore["PyMysqlConfig"]): def __init__(self, config: "PyMysqlConfig") -> None: super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + + def _get_create_memory_table_sql(self) -> str: owner_id_line = "" fk_constraint = "" if self._owner_id_column_ddl: @@ -475,11 +236,7 @@ def _create_tables(self) -> None: return with self._config.provide_session() as driver: - driver.execute_script(run_(self._get_create_memory_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_memory_table_sql()) def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: if not self._enabled: @@ -547,10 +304,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec conn.commit() return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -591,12 +344,6 @@ def _search_entries( return [cast("MemoryRecord", dict(zip(columns, row, strict=False))) for row in rows] - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _delete_entries_by_session(self, session_id: str) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -612,10 +359,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: finally: cursor.close() - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: if not self._enabled: msg = "Memory store is disabled" @@ -634,9 +377,365 @@ def _delete_entries_older_than(self, days: int) -> int: finally: cursor.close() - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) + +def _pymysql_create_tables(store: PyMysqlADKStore) -> None: + with store._config.provide_session() as driver: + driver.execute_script(store._get_create_sessions_table_sql()) + driver.execute_script(store._get_create_events_table_sql()) + driver.execute_script(store._get_create_app_states_table_sql()) + driver.execute_script(store._get_create_user_states_table_sql()) + driver.execute_script(store._get_create_metadata_table_sql()) + driver.execute_script(store._get_seed_metadata_sql()) + + +def _pymysql_create_session( + store: PyMysqlADKStore, + session_id: str, + app_name: str, + user_id: str, + state: "dict[str, Any]", + owner_id: "Any | None" = None, +) -> SessionRecord: + params: tuple[Any, ...] + if store._owner_id_column_name: + sql = f""" + INSERT INTO {store._session_table} (id, app_name, user_id, {store._owner_id_column_name}, state, create_time, update_time) + VALUES (%s, %s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) + """ + params = (session_id, app_name, user_id, owner_id, to_json(state)) + else: + sql = f""" + INSERT INTO {store._session_table} (id, app_name, user_id, state, create_time, update_time) + VALUES (%s, %s, %s, %s, UTC_TIMESTAMP(6), UTC_TIMESTAMP(6)) + """ + params = (session_id, app_name, user_id, to_json(state)) + + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, params) + finally: + cursor.close() + conn.commit() + + result = _pymysql_get_session(store, app_name, user_id, session_id) + if result is None: + msg = "Failed to fetch created session" + raise RuntimeError(msg) + return result + + +def _pymysql_get_session( + store: PyMysqlADKStore, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None +) -> "SessionRecord | None": + try: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + if renew_for is not None and store._calculate_expires_at(renew_for) is not None: + cursor.execute( + f""" + UPDATE {store._session_table} + SET update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + ) + conn.commit() + + cursor.execute( + f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {store._session_table} + WHERE app_name = %s AND user_id = %s AND id = %s + """, + (app_name, user_id, session_id), + ) + row = cursor.fetchone() + finally: + cursor.close() + return _session_record_from_row(row) if row is not None else None + except pymysql.MySQLError as exc: + if _is_mysql_table_missing(exc): + return None + raise + + +def _pymysql_update_session_state( + store: PyMysqlADKStore, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]" +) -> None: + sql = f""" + UPDATE {store._session_table} + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s + """ + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (to_json(state), app_name, user_id, session_id)) + finally: + cursor.close() + conn.commit() + + +def _pymysql_list_sessions(store: PyMysqlADKStore, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + if user_id is None: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {store._session_table} + WHERE app_name = %s + ORDER BY update_time DESC + """ + params: tuple[Any, ...] = (app_name,) + else: + sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {store._session_table} + WHERE app_name = %s AND user_id = %s + ORDER BY update_time DESC + """ + params = (app_name, user_id) + + try: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, params) + rows = cursor.fetchall() + finally: + cursor.close() + return [_session_record_from_row(row) for row in rows] + except pymysql.MySQLError as exc: + if _is_mysql_table_missing(exc): + return [] + raise + + +def _pymysql_delete_session(store: PyMysqlADKStore, app_name: str, user_id: str, session_id: str) -> None: + sql = f"DELETE FROM {store._session_table} WHERE app_name = %s AND user_id = %s AND id = %s" + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (app_name, user_id, session_id)) + finally: + cursor.close() + conn.commit() + + +def _pymysql_append_event(store: PyMysqlADKStore, event_record: EventRecord) -> None: + sql = f""" + INSERT INTO {store._events_table} ( + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) + """ + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, _event_insert_params(event_record)) + finally: + cursor.close() + conn.commit() + + +def _pymysql_append_event_and_update_state( + store: PyMysqlADKStore, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, +) -> SessionRecord: + insert_sql = f""" + INSERT INTO {store._events_table} ( + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (%s, %s, %s, %s, %s, %s, %s) + """ + update_sql = f""" + UPDATE {store._session_table} + SET state = %s, update_time = UTC_TIMESTAMP(6) + WHERE app_name = %s AND user_id = %s AND id = %s + """ + select_sql = f""" + SELECT id, app_name, user_id, state, create_time, update_time + FROM {store._session_table} + WHERE app_name = %s AND user_id = %s AND id = %s + """ + + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(update_sql, (to_json(state), app_name, user_id, session_id)) + cursor.execute(select_sql, (app_name, user_id, session_id)) + row = cursor.fetchone() + if row is None: + _raise_session_not_found(session_id) + cursor.execute( + insert_sql, + ( + event_record["id"], + app_name, + user_id, + session_id, + event_record["invocation_id"], + event_record["timestamp"], + _json_for_storage(event_record["event_data"]), + ), + ) + if app_state is not None: + cursor.execute(_mysql_upsert_app_state_sql(store._app_state_table), (app_name, to_json(app_state))) + if user_state is not None: + cursor.execute( + _mysql_upsert_user_state_sql(store._user_state_table), (app_name, user_id, to_json(user_state)) + ) + except Exception: + conn.rollback() + raise + finally: + cursor.close() + conn.commit() + + return _session_record_from_row(row) + + +def _pymysql_get_events( + store: PyMysqlADKStore, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, +) -> "list[EventRecord]": + if limit == 0: + return [] + + where_clauses = ["app_name = %s", "user_id = %s", "session_id = %s"] + params: list[Any] = [app_name, user_id, session_id] + if after_timestamp is not None: + where_clauses.append("timestamp > %s") + params.append(after_timestamp) + limit_clause = "" + if limit is not None: + limit_clause = " LIMIT %s" + params.append(limit) + + sql = f""" + SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data + FROM {store._events_table} + WHERE {" AND ".join(where_clauses)} + ORDER BY timestamp ASC{limit_clause} + """ + + try: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, tuple(params)) + rows = cursor.fetchall() + finally: + cursor.close() + return [_event_record_from_row(row) for row in rows] + except pymysql.MySQLError as exc: + if _is_mysql_table_missing(exc): + return [] + raise + + +def _pymysql_delete_expired_events(store: PyMysqlADKStore, before: "datetime") -> int: + return _pymysql_delete_by_timestamp(store, store._events_table, "timestamp", before) + + +def _pymysql_delete_idle_sessions(store: PyMysqlADKStore, updated_before: "datetime") -> int: + return _pymysql_delete_by_timestamp(store, store._session_table, "update_time", updated_before) + + +def _pymysql_delete_by_timestamp( + store: PyMysqlADKStore, table_name: str, column_name: str, threshold: "datetime" +) -> int: + sql = f"DELETE FROM {table_name} WHERE {column_name} < %s" + try: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (threshold,)) + conn.commit() + return cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + finally: + cursor.close() + except pymysql.MySQLError as exc: + if _is_mysql_table_missing(exc): + return 0 + raise + + +def _pymysql_get_app_state(store: PyMysqlADKStore, app_name: str) -> "dict[str, Any] | None": + return _pymysql_get_state(store, store._app_state_table, "app_name = %s", (app_name,)) + + +def _pymysql_get_user_state(store: PyMysqlADKStore, app_name: str, user_id: str) -> "dict[str, Any] | None": + return _pymysql_get_state(store, store._user_state_table, "app_name = %s AND user_id = %s", (app_name, user_id)) + + +def _pymysql_get_state( + store: PyMysqlADKStore, table_name: str, where_clause: str, params: "tuple[Any, ...]" +) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {table_name} WHERE {where_clause}" + try: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, params) + row = cursor.fetchone() + finally: + cursor.close() + return _json_dict(row[0]) if row is not None else None + except pymysql.MySQLError as exc: + if _is_mysql_table_missing(exc): + return None + raise + + +def _pymysql_upsert_app_state(store: PyMysqlADKStore, app_name: str, state: "dict[str, Any]") -> None: + _pymysql_execute_commit(store, _mysql_upsert_app_state_sql(store._app_state_table), (app_name, to_json(state))) + + +def _pymysql_upsert_user_state(store: PyMysqlADKStore, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + _pymysql_execute_commit( + store, _mysql_upsert_user_state_sql(store._user_state_table), (app_name, user_id, to_json(state)) + ) + + +def _pymysql_get_metadata(store: PyMysqlADKStore, key: str) -> "str | None": + sql = f"SELECT value FROM {store._metadata_table} WHERE `key` = %s" + try: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, (key,)) + row = cursor.fetchone() + finally: + cursor.close() + return str(row[0]) if row is not None else None + except pymysql.MySQLError as exc: + if _is_mysql_table_missing(exc): + return None + raise + + +def _pymysql_set_metadata(store: PyMysqlADKStore, key: str, value: str) -> None: + _pymysql_execute_commit(store, _mysql_upsert_metadata_sql(store._metadata_table), (key, value)) + + +def _pymysql_execute_commit(store: PyMysqlADKStore, sql: str, params: "tuple[Any, ...]") -> None: + with store._config.provide_connection() as conn: + cursor = conn.cursor() + try: + cursor.execute(sql, params) + finally: + cursor.close() + conn.commit() def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": @@ -649,3 +748,150 @@ def _parse_owner_id_column_for_mysql(column_ddl: str) -> "tuple[str, str]": col_name = col_def.split()[0] fk_constraint = f"FOREIGN KEY ({col_name}) REFERENCES {fk_clause}" return (col_def, fk_constraint) + + +def _is_mysql_table_missing(exc: BaseException) -> bool: + args = getattr(exc, "args", ()) + return "doesn't exist" in str(exc) or bool(args and args[0] == MYSQL_TABLE_NOT_FOUND_ERROR) + + +def _json_for_storage(value: Any) -> str: + return value if isinstance(value, str) else to_json(value) + + +def _json_dict(value: Any) -> "dict[str, Any]": + if isinstance(value, bytearray): + value = bytes(value) + if isinstance(value, (bytes, str)): + return cast("dict[str, Any]", from_json(value)) + return cast("dict[str, Any]", value) + + +def _session_record_from_row(row: Any) -> SessionRecord: + return SessionRecord( + id=row[0], app_name=row[1], user_id=row[2], state=_json_dict(row[3]), create_time=row[4], update_time=row[5] + ) + + +def _event_record_from_row(row: Any) -> EventRecord: + return EventRecord( + id=row[0], + app_name=row[1], + user_id=row[2], + session_id=row[3], + invocation_id=row[4], + timestamp=row[5], + event_data=_json_dict(row[6]), + ) + + +def _event_insert_params(event_record: EventRecord) -> "tuple[Any, ...]": + return ( + event_record["id"], + event_record["app_name"], + event_record["user_id"], + event_record["session_id"], + event_record["invocation_id"], + event_record["timestamp"], + _json_for_storage(event_record["event_data"]), + ) + + +def _mysql_sessions_ddl(session_table: str, owner_id_column_ddl: "str | None") -> str: + owner_id_line = "" + fk_constraint = "" + if owner_id_column_ddl: + col_def, fk_def = _parse_owner_id_column_for_mysql(owner_id_column_ddl) + owner_id_line = f"\n {col_def}," + if fk_def: + fk_constraint = f",\n {fk_def}" + + return f""" + CREATE TABLE IF NOT EXISTS {session_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL,{owner_id_line} + state JSON NOT NULL, + create_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + INDEX idx_{session_table}_app_user (app_name, user_id), + INDEX idx_{session_table}_update_time (update_time DESC){fk_constraint} + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_events_ddl(events_table: str, session_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {events_table} ( + id VARCHAR(128) PRIMARY KEY, + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + session_id VARCHAR(128) NOT NULL, + invocation_id VARCHAR(256) NOT NULL, + timestamp TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + event_data JSON NOT NULL, + FOREIGN KEY (session_id) REFERENCES {session_table}(id) ON DELETE CASCADE, + INDEX idx_{events_table}_scope (app_name, user_id, session_id, timestamp ASC), + INDEX idx_{events_table}_session (session_id, timestamp ASC) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_app_state_ddl(app_state_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {app_state_table} ( + app_name VARCHAR(128) PRIMARY KEY, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_user_state_ddl(user_state_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {user_state_table} ( + app_name VARCHAR(128) NOT NULL, + user_id VARCHAR(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6) ON UPDATE CURRENT_TIMESTAMP(6), + PRIMARY KEY (app_name, user_id) + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_metadata_ddl(metadata_table: str) -> str: + return f""" + CREATE TABLE IF NOT EXISTS {metadata_table} ( + `key` VARCHAR(128) PRIMARY KEY, + value VARCHAR(512) NOT NULL + ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci + """ + + +def _mysql_upsert_app_state_sql(app_state_table: str) -> str: + return f""" + INSERT INTO {app_state_table} (app_name, state, update_time) + VALUES (%s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + +def _mysql_upsert_user_state_sql(user_state_table: str) -> str: + return f""" + INSERT INTO {user_state_table} (app_name, user_id, state, update_time) + VALUES (%s, %s, %s, UTC_TIMESTAMP(6)) + ON DUPLICATE KEY UPDATE state = VALUES(state), update_time = UTC_TIMESTAMP(6) + """ + + +def _mysql_upsert_metadata_sql(metadata_table: str) -> str: + return f""" + INSERT INTO {metadata_table} (`key`, value) + VALUES (%s, %s) + ON DUPLICATE KEY UPDATE value = VALUES(value) + """ + + +def _raise_session_not_found(session_id: str) -> None: + msg = f"Session {session_id} not found during append_event_and_update_state." + raise ValueError(msg) diff --git a/sqlspec/adapters/spanner/adk/store.py b/sqlspec/adapters/spanner/adk/store.py index b3b8e030b..aba580768 100644 --- a/sqlspec/adapters/spanner/adk/store.py +++ b/sqlspec/adapters/spanner/adk/store.py @@ -2,16 +2,15 @@ from collections.abc import Iterable from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING, Any, ClassVar, Protocol, cast +from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, cast from google.cloud.spanner_v1 import param_types from sqlspec.adapters.spanner.config import SpannerSyncConfig -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore from sqlspec.protocols import SpannerParamTypesProtocol from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: from google.cloud.spanner_v1.database import Database @@ -23,9 +22,11 @@ __all__ = ("SpannerSyncADKMemoryStore", "SpannerSyncADKStore") SPANNER_PARAM_TYPES: SpannerParamTypesProtocol = cast("SpannerParamTypesProtocol", param_types) +MIN_DROP_TABLE_TOKENS: Final = 3 +MIN_DROP_SEARCH_INDEX_TOKENS: Final = 4 -class SpannerSyncADKStore(BaseAsyncADKStore[SpannerSyncConfig]): +class SpannerSyncADKStore(BaseSyncADKStore[SpannerSyncConfig]): """Spanner ADK store backed by synchronous Spanner client.""" connector_name: ClassVar[str] = "spanner" @@ -42,9 +43,107 @@ def __init__(self, config: SpannerSyncConfig) -> None: ) self._events_row_deletion_policy = _spanner_row_deletion_policy(adk_config, "event_ttl_seconds", "timestamp") + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session.""" + return self._create_session(session_id, app_name, user_id, state, owner_id) + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID.""" + return self._get_session(app_name, user_id, session_id, renew_for=renew_for) + + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + self._update_session_state(app_name, user_id, session_id, state) + + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app.""" + return self._list_sessions(app_name, user_id) + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and associated events.""" + self._delete_session(app_name, user_id, session_id) + + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session.""" + self._append_event(event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update the session's durable state.""" + return self._append_event_and_update_state( + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session.""" + return self._get_events(app_name, user_id, session_id, after_timestamp, limit) + + def delete_expired_events(self, before: datetime) -> int: + """Delete events older than a timestamp.""" + return self._delete_expired_events(before) + + def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions older than a timestamp.""" + return self._delete_idle_sessions(updated_before) + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state.""" + return self._get_app_state(app_name) + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state.""" + return self._get_user_state(app_name, user_id) + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state.""" + self._upsert_app_state(app_name, state) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state.""" + self._upsert_user_state(app_name, user_id, state) + + def get_metadata(self, key: str) -> "str | None": + """Return a metadata value.""" + return self._get_metadata(key) + + def set_metadata(self, key: str, value: str) -> None: + """Set a metadata value.""" + self._set_metadata(key, value) + def _database(self) -> "Database": return self._config.get_database() + def _get_reset_drop_tables_sql(self) -> "list[str]": + return _filter_existing_spanner_drops(super()._get_reset_drop_tables_sql(), self._existing_tables()) + + def _existing_tables(self) -> "set[str]": + database = self._database() + return {table.table_id for table in database.list_tables()} # type: ignore[no-untyped-call] + def _run_read( self, sql: str, params: "dict[str, Any] | None" = None, types: "dict[str, Any] | None" = None ) -> "list[Any]": @@ -70,13 +169,26 @@ def _session_param_types(self, include_owner: bool) -> "dict[str, Any]": def _event_param_types(self) -> "dict[str, Any]": json_type = _json_param_type() return { + "id": SPANNER_PARAM_TYPES.STRING, "session_id": SPANNER_PARAM_TYPES.STRING, "invocation_id": SPANNER_PARAM_TYPES.STRING, - "author": SPANNER_PARAM_TYPES.STRING, "timestamp": SPANNER_PARAM_TYPES.TIMESTAMP, - "event_json": json_type, + "event_data": json_type, } + def _app_state_param_types(self) -> "dict[str, Any]": + return {"app_name": SPANNER_PARAM_TYPES.STRING, "state": _json_param_type()} + + def _user_state_param_types(self) -> "dict[str, Any]": + return { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "state": _json_param_type(), + } + + def _metadata_param_types(self) -> "dict[str, Any]": + return {"key": SPANNER_PARAM_TYPES.STRING, "value": SPANNER_PARAM_TYPES.STRING} + def _decode_state(self, raw: Any) -> Any: if isinstance(raw, str): return from_json(raw) @@ -118,23 +230,47 @@ def _create_session( "update_time": datetime.now(timezone.utc), } - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session.""" - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) + def _get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f""" + UPDATE {self._session_table} + SET update_time = PENDING_COMMIT_TIMESTAMP() + WHERE app_name = @app_name AND user_id = @user_id AND id = @id + """ + if self._shard_count > 1: + update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" + self._run_write([ + ( + update_sql, + {"app_name": app_name, "user_id": user_id, "id": session_id}, + { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "id": SPANNER_PARAM_TYPES.STRING, + }, + ) + ]) - def _get_session(self, session_id: str) -> "SessionRecord | None": sql = f""" SELECT id, app_name, user_id, state, create_time, update_time{", " + self._owner_id_column_name if self._owner_id_column_name else ""} FROM {self._session_table} - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ if self._shard_count > 1: sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" sql = f"{sql} LIMIT 1" - params = {"id": session_id} - rows = self._run_read(sql, params, {"id": SPANNER_PARAM_TYPES.STRING}) + params = {"app_name": app_name, "user_id": user_id, "id": session_id} + rows = self._run_read( + sql, + params, + { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "id": SPANNER_PARAM_TYPES.STRING, + }, + ) if not rows: return None @@ -150,25 +286,28 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": } return record - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID.""" - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - params = {"id": session_id, "state": to_json(state)} + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + params = {"app_name": app_name, "user_id": user_id, "id": session_id, "state": to_json(state)} json_type = _json_param_type() sql = f""" UPDATE {self._session_table} SET state = @state, update_time = PENDING_COMMIT_TIMESTAMP() - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ if self._shard_count > 1: sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" - self._run_write([(sql, params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type})]) - - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state.""" - await async_(self._update_session_state)(session_id, state) + self._run_write([ + ( + sql, + params, + { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "id": SPANNER_PARAM_TYPES.STRING, + "state": json_type, + }, + ) + ]) def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": sql = f""" @@ -201,26 +340,33 @@ def _list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[S records.append(record) return records - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app.""" - return await async_(self._list_sessions)(app_name, user_id) - - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: shard_clause = ( f" AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" if self._shard_count > 1 else "" ) delete_events_sql = f"DELETE FROM {self._events_table} WHERE session_id = @session_id{shard_clause}" - delete_session_sql = f"DELETE FROM {self._session_table} WHERE id = @session_id{shard_clause}" - params = {"session_id": session_id} - types = {"session_id": SPANNER_PARAM_TYPES.STRING} + delete_session_sql = ( + f"DELETE FROM {self._session_table} " + f"WHERE app_name = @app_name AND user_id = @user_id AND id = @session_id{shard_clause}" + ) + params = {"app_name": app_name, "user_id": user_id, "session_id": session_id} + types = { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "session_id": SPANNER_PARAM_TYPES.STRING, + } self._run_write([(delete_events_sql, params, types), (delete_session_sql, params, types)]) - async def delete_session(self, session_id: str) -> None: - """Delete session and associated events.""" - await async_(self._delete_session)(session_id) - def _append_event_and_update_state( - self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + self, + event_record: "EventRecord", + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Atomically insert an event and update session state in one transaction. @@ -232,79 +378,124 @@ def _append_event_and_update_state( Args: event_record: Event record to store. + app_name: Application name. + user_id: User identifier. session_id: Session whose state should be updated. state: Post-append durable state snapshot. + app_state: Optional app-scoped state snapshot. + user_state: Optional user-scoped state snapshot. """ event_params: dict[str, Any] = { + "id": event_record["id"], "session_id": event_record["session_id"], "invocation_id": event_record["invocation_id"], - "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": to_json(event_record["event_json"]), + "event_data": to_json(event_record["event_data"]), } insert_sql = f""" - INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) - VALUES (@session_id, @invocation_id, @author, @timestamp, @event_json) + INSERT INTO {self._events_table} (id, session_id, invocation_id, timestamp, event_data) + VALUES (@id, @session_id, @invocation_id, @timestamp, @event_data) """ json_type = _json_param_type() - state_params: dict[str, Any] = {"id": session_id, "state": to_json(state)} + state_params: dict[str, Any] = { + "app_name": app_name, + "user_id": user_id, + "id": session_id, + "state": to_json(state), + } update_sql = f""" UPDATE {self._session_table} SET state = @state, update_time = PENDING_COMMIT_TIMESTAMP() - WHERE id = @id + WHERE app_name = @app_name AND user_id = @user_id AND id = @id """ if self._shard_count > 1: update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@id), {self._shard_count})" - self._run_write([ + statements: list[tuple[str, dict[str, Any], dict[str, Any]]] = [ (insert_sql, event_params, self._event_param_types()), - (update_sql, state_params, {"id": SPANNER_PARAM_TYPES.STRING, "state": json_type}), - ]) - - record = self._get_session(session_id) + ( + update_sql, + state_params, + { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "id": SPANNER_PARAM_TYPES.STRING, + "state": json_type, + }, + ), + ] + if app_state is not None: + statements.append(( + f""" + INSERT OR UPDATE {self._app_state_table} (app_name, state, update_time) + VALUES (@app_name, @state, PENDING_COMMIT_TIMESTAMP()) + """, + {"app_name": app_name, "state": to_json(app_state)}, + self._app_state_param_types(), + )) + if user_state is not None: + statements.append(( + f""" + INSERT OR UPDATE {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (@app_name, @user_id, @state, PENDING_COMMIT_TIMESTAMP()) + """, + {"app_name": app_name, "user_id": user_id, "state": to_json(user_state)}, + self._user_state_param_types(), + )) + + self._run_write(statements) + + record = self._get_session(app_name, user_id, session_id) if record is None: msg = f"Session {session_id} not found during append_event_and_update_state." raise ValueError(msg) return record - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state.""" - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _insert_event(self, event_record: "EventRecord") -> None: event_params: dict[str, Any] = { + "id": event_record["id"], "session_id": event_record["session_id"], "invocation_id": event_record["invocation_id"], - "author": event_record["author"], "timestamp": event_record["timestamp"], - "event_json": to_json(event_record["event_json"]), + "event_data": to_json(event_record["event_data"]), } insert_sql = f""" - INSERT INTO {self._events_table} (session_id, invocation_id, author, timestamp, event_json) - VALUES (@session_id, @invocation_id, @author, @timestamp, @event_json) + INSERT INTO {self._events_table} (id, session_id, invocation_id, timestamp, event_data) + VALUES (@id, @session_id, @invocation_id, @timestamp, @event_data) """ self._run_write([(insert_sql, event_params, self._event_param_types())]) def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": + if limit == 0: + return [] + sql = f""" - SELECT session_id, invocation_id, author, timestamp, event_json - FROM {self._events_table} - WHERE session_id = @session_id + SELECT e.id, e.session_id, e.invocation_id, e.timestamp, e.event_data, s.app_name, s.user_id + FROM {self._events_table} e + JOIN {self._session_table} s ON e.session_id = s.id + WHERE s.app_name = @app_name AND s.user_id = @user_id AND e.session_id = @session_id """ if self._shard_count > 1: - sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" - params: dict[str, Any] = {"session_id": session_id} - types: dict[str, Any] = {"session_id": SPANNER_PARAM_TYPES.STRING} + sql = f"{sql} AND e.shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})" + params: dict[str, Any] = {"app_name": app_name, "user_id": user_id, "session_id": session_id} + types: dict[str, Any] = { + "app_name": SPANNER_PARAM_TYPES.STRING, + "user_id": SPANNER_PARAM_TYPES.STRING, + "session_id": SPANNER_PARAM_TYPES.STRING, + } if after_timestamp is not None: - sql = f"{sql} AND timestamp > @after_timestamp" + sql = f"{sql} AND e.timestamp > @after_timestamp" params["after_timestamp"] = after_timestamp types["after_timestamp"] = SPANNER_PARAM_TYPES.TIMESTAMP - sql = f"{sql} ORDER BY timestamp ASC" + sql = f"{sql} ORDER BY e.timestamp ASC" if limit is not None: sql = f"{sql} LIMIT @limit" params["limit"] = limit @@ -312,28 +503,91 @@ def _get_events( rows = self._run_read(sql, params, types) return [ { - "session_id": row[0], - "invocation_id": row[1] or "", - "author": row[2] or "", + "id": row[0], + "session_id": row[1], + "invocation_id": row[2] or "", "timestamp": row[3], - "event_json": row[4], + "event_data": self._decode_json(row[4]) or {}, + "app_name": row[5], + "user_id": row[6], } for row in rows ] - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session.""" - return await async_(self._get_events)(session_id, after_timestamp, limit) - def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" self._insert_event(event_record) - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session.""" - await async_(self._append_event)(event_record) + def _delete_expired_events(self, before: datetime) -> int: + sql = f"DELETE FROM {self._events_table} WHERE timestamp < @before" + return int( + cast("Any", self._database()).run_in_transaction( + _SpannerUpdateJob(sql, {"before": before}, {"before": SPANNER_PARAM_TYPES.TIMESTAMP}) + ) + ) + + def _delete_idle_sessions(self, updated_before: datetime) -> int: + sql = f"DELETE FROM {self._session_table} WHERE update_time < @updated_before" + return int( + cast("Any", self._database()).run_in_transaction( + _SpannerUpdateJob( + sql, {"updated_before": updated_before}, {"updated_before": SPANNER_PARAM_TYPES.TIMESTAMP} + ) + ) + ) + + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = @app_name LIMIT 1" + rows = self._run_read(sql, {"app_name": app_name}, {"app_name": SPANNER_PARAM_TYPES.STRING}) + if not rows: + return None + return self._decode_json(rows[0][0]) or {} + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + sql = f""" + SELECT state + FROM {self._user_state_table} + WHERE app_name = @app_name AND user_id = @user_id + LIMIT 1 + """ + rows = self._run_read( + sql, + {"app_name": app_name, "user_id": user_id}, + {"app_name": SPANNER_PARAM_TYPES.STRING, "user_id": SPANNER_PARAM_TYPES.STRING}, + ) + if not rows: + return None + return self._decode_json(rows[0][0]) or {} + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT OR UPDATE {self._app_state_table} (app_name, state, update_time) + VALUES (@app_name, @state, PENDING_COMMIT_TIMESTAMP()) + """ + self._run_write([(sql, {"app_name": app_name, "state": to_json(state)}, self._app_state_param_types())]) + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + sql = f""" + INSERT OR UPDATE {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (@app_name, @user_id, @state, PENDING_COMMIT_TIMESTAMP()) + """ + self._run_write([ + (sql, {"app_name": app_name, "user_id": user_id, "state": to_json(state)}, self._user_state_param_types()) + ]) + + def _get_metadata(self, key: str) -> "str | None": + sql = f"SELECT value FROM {self._metadata_table} WHERE key = @key LIMIT 1" + rows = self._run_read(sql, {"key": key}, {"key": SPANNER_PARAM_TYPES.STRING}) + if not rows: + return None + return str(rows[0][0]) + + def _set_metadata(self, key: str, value: str) -> None: + sql = f""" + INSERT OR UPDATE {self._metadata_table} (key, value) + VALUES (@key, @value) + """ + self._run_write([(sql, {"key": key, "value": value}, self._metadata_param_types())]) def _create_tables(self) -> None: database = self._database() @@ -341,18 +595,22 @@ def _create_tables(self) -> None: ddl_statements: list[str] = [] if self._session_table not in existing_tables: - ddl_statements.append(run_(self._get_create_sessions_table_sql)()) + ddl_statements.append(self._get_create_sessions_table_sql()) if self._events_table not in existing_tables: - ddl_statements.append(run_(self._get_create_events_table_sql)()) + ddl_statements.append(self._get_create_events_table_sql()) + if self._app_state_table not in existing_tables: + ddl_statements.append(self._get_create_app_states_table_sql()) + if self._user_state_table not in existing_tables: + ddl_statements.append(self._get_create_user_states_table_sql()) + if self._metadata_table not in existing_tables: + ddl_statements.append(self._get_create_metadata_table_sql()) if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] + if self._metadata_table not in existing_tables: + self._run_write([(self._get_seed_metadata_sql(), {}, {})]) - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - - async def _get_create_sessions_table_sql(self) -> str: + def _get_create_sessions_table_sql(self) -> str: owner_line = "" if self._owner_id_column_ddl: owner_line = f",\n {self._owner_id_column_ddl}" @@ -375,7 +633,7 @@ async def _get_create_sessions_table_sql(self) -> str: ) {pk}{options}{self._session_row_deletion_policy} """ - async def _get_create_events_table_sql(self) -> str: + def _get_create_events_table_sql(self) -> str: shard_column = "" pk = "PRIMARY KEY (session_id, timestamp)" if self._shard_count > 1: @@ -386,19 +644,67 @@ async def _get_create_events_table_sql(self) -> str: options = f"\nOPTIONS ({self._events_table_options})" return f""" CREATE TABLE {self._events_table} ( + id STRING(128) NOT NULL, session_id STRING(128) NOT NULL, - invocation_id STRING(256) NOT NULL, - author STRING(128) NOT NULL, + invocation_id STRING(256), timestamp TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true), - event_json JSON NOT NULL{shard_column} + event_data JSON NOT NULL{shard_column} ) {pk}{options}{self._events_row_deletion_policy} """ + def _get_create_app_states_table_sql(self) -> str: + return f""" +CREATE TABLE {self._app_state_table} ( + app_name STRING(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true) +) PRIMARY KEY (app_name) +""" + + def _get_create_user_states_table_sql(self) -> str: + return f""" +CREATE TABLE {self._user_state_table} ( + app_name STRING(128) NOT NULL, + user_id STRING(128) NOT NULL, + state JSON NOT NULL, + update_time TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true) +) PRIMARY KEY (app_name, user_id) +""" + + def _get_create_metadata_table_sql(self) -> str: + return f""" +CREATE TABLE {self._metadata_table} ( + key STRING(128) NOT NULL, + value STRING(512) NOT NULL +) PRIMARY KEY (key) +""" + + def _get_seed_metadata_sql(self) -> str: + return f""" +INSERT INTO {self._metadata_table} (key, value) +VALUES ('schema_version', '1') +""" + + def _get_drop_app_states_table_sql(self) -> str: + return f"DROP TABLE {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + return f"DROP TABLE {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE {self._metadata_table}" + def _get_drop_tables_sql(self) -> "list[str]": - return [f"DROP TABLE {self._events_table}", f"DROP TABLE {self._session_table}"] + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE {self._events_table}", + f"DROP TABLE {self._session_table}", + ] -class SpannerSyncADKMemoryStore(BaseAsyncADKMemoryStore[SpannerSyncConfig]): +class SpannerSyncADKMemoryStore(BaseSyncADKMemoryStore[SpannerSyncConfig]): """Spanner ADK memory store backed by synchronous Spanner client.""" connector_name: ClassVar[str] = "spanner" @@ -413,9 +719,38 @@ def __init__(self, config: SpannerSyncConfig) -> None: cast("dict[str, Any]", adk_config), "memory_ttl_seconds", "inserted_at" ) + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + def _database(self) -> "Database": return self._config.get_database() + def _get_reset_drop_memory_table_sql(self) -> "list[str]": + return _filter_existing_spanner_drops(super()._get_reset_drop_memory_table_sql(), self._existing_tables()) + + def _existing_tables(self) -> "set[str]": + database = self._database() + return {table.table_id for table in database.list_tables()} # type: ignore[no-untyped-call] + def _run_read( self, sql: str, params: "dict[str, Any] | None" = None, types: "dict[str, Any] | None" = None ) -> "list[Any]": @@ -464,16 +799,12 @@ def _create_tables(self) -> None: ddl_statements: list[str] = [] if self._memory_table not in existing_tables: - ddl_statements.extend(run_(self._get_create_memory_table_sql)()) + ddl_statements.extend(self._get_create_memory_table_sql()) if ddl_statements: database.update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call] - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() - - async def _get_create_memory_table_sql(self) -> "list[str]": + def _get_create_memory_table_sql(self) -> "list[str]": owner_line = "" if self._owner_id_column_ddl: owner_line = f",\n {self._owner_id_column_ddl}" @@ -585,10 +916,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec self._run_write(statements) return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _event_exists(self, event_id: str) -> bool: sql = f"SELECT event_id FROM {self._memory_table} WHERE event_id = @event_id LIMIT 1" rows = self._run_read(sql, {"event_id": event_id}, {"event_id": SPANNER_PARAM_TYPES.STRING}) @@ -607,12 +934,6 @@ def _search_entries( return self._search_entries_fts(query, app_name, user_id, effective_limit) return self._search_entries_simple(query, app_name, user_id, effective_limit) - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT id, session_id, app_name, user_id, event_id, author, @@ -662,10 +983,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: types = {"session_id": SPANNER_PARAM_TYPES.STRING} return self._execute_update(sql, params, types) - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: cutoff = datetime.now(timezone.utc) - timedelta(days=days) sql = f"DELETE FROM {self._memory_table} WHERE inserted_at < @cutoff" @@ -673,10 +990,6 @@ def _delete_entries_older_than(self, days: int) -> int: types = {"cutoff": SPANNER_PARAM_TYPES.TIMESTAMP} return self._execute_update(sql, params, types) - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _rows_to_records(self, rows: "list[Any]") -> "list[MemoryRecord]": return [ { @@ -719,6 +1032,35 @@ def _spanner_row_deletion_policy(adk_config: dict[str, Any], ttl_key: str, colum return f"\nROW DELETION POLICY (OLDER_THAN({column}, INTERVAL {ttl_days} DAY))" +def _filter_existing_spanner_drops(statements: "list[str]", existing_tables: "set[str]") -> "list[str]": + return [statement for statement in statements if _spanner_drop_statement_table(statement, existing_tables)] + + +def _spanner_drop_statement_table(statement: str, existing_tables: "set[str]") -> "str | None": + tokens = statement.strip().split() + if len(tokens) >= MIN_DROP_TABLE_TOKENS and tokens[0].upper() == "DROP" and tokens[1].upper() == "TABLE": + table_name = tokens[2] + return table_name if table_name in existing_tables else None + + index_name: str | None = None + if len(tokens) >= MIN_DROP_TABLE_TOKENS and tokens[0].upper() == "DROP" and tokens[1].upper() == "INDEX": + index_name = tokens[2] + if ( + len(tokens) >= MIN_DROP_SEARCH_INDEX_TOKENS + and tokens[0].upper() == "DROP" + and tokens[1].upper() == "SEARCH" + and tokens[2].upper() == "INDEX" + ): + index_name = tokens[3] + if index_name is None: + return None + + for table_name in existing_tables: + if index_name.startswith(f"idx_{table_name}_"): + return table_name + return None + + class _SpannerWriteJob: __slots__ = ("_statements",) @@ -741,6 +1083,18 @@ def __call__(self, transaction: "Transaction") -> None: transaction.execute_update(sql, params=params, param_types=types) # type: ignore[no-untyped-call] +class _SpannerUpdateJob: + __slots__ = ("_params", "_sql", "_types") + + def __init__(self, sql: str, params: "dict[str, Any]", types: "dict[str, Any]") -> None: + self._sql = sql + self._params = params + self._types = types + + def __call__(self, transaction: "Transaction") -> int: + return int(transaction.execute_update(self._sql, params=self._params, param_types=self._types)) # type: ignore[no-untyped-call] + + class _SpannerMemoryUpdateJob: __slots__ = ("_params", "_sql", "_types") diff --git a/sqlspec/adapters/sqlite/adk/store.py b/sqlspec/adapters/sqlite/adk/store.py index ffe738771..4959ea74d 100644 --- a/sqlspec/adapters/sqlite/adk/store.py +++ b/sqlspec/adapters/sqlite/adk/store.py @@ -1,14 +1,13 @@ """SQLite sync ADK store for Google Agent Development Kit session/event storage.""" import sqlite3 -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Final -from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord -from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore +from sqlspec.extensions.adk import BaseSyncADKStore, EventRecord, SessionRecord +from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore from sqlspec.utils.logging import get_logger from sqlspec.utils.serializers import from_json, to_json -from sqlspec.utils.sync_tools import async_, run_ if TYPE_CHECKING: import logging @@ -27,7 +26,7 @@ logger: "logging.Logger" = get_logger("sqlspec.adapters.sqlite.adk.store") -class SqliteADKStore(BaseAsyncADKStore["SqliteConfig"]): +class SqliteADKStore(BaseSyncADKStore["SqliteConfig"]): """SQLite ADK store using synchronous SQLite driver. Implements session and event storage for Google Agent Development Kit @@ -56,80 +55,189 @@ def __init__(self, config: "SqliteConfig") -> None: """ super().__init__(config) - def _apply_pragmas(self, connection: Any) -> None: - """Apply PRAGMA optimization profile for this connection. + def create_tables(self) -> None: + """Create both sessions and events tables if they don't exist.""" + self._create_tables() + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> SessionRecord: + """Create a new session. Args: - connection: SQLite connection. + session_id: Unique session identifier. + app_name: Application name. + user_id: User identifier. + state: Initial session state. + owner_id: Optional owner ID value for owner ID column. + + Returns: + Created session record. """ - connection.execute("PRAGMA foreign_keys = ON") - connection.execute("PRAGMA cache_size = -64000") - connection.execute("PRAGMA mmap_size = 30000000") - connection.execute("PRAGMA journal_size_limit = 67108864") + return self._create_session(session_id, app_name, user_id, state, owner_id) - async def _get_create_sessions_table_sql(self) -> str: - """Get SQLite CREATE TABLE SQL for sessions. + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get session by ID. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + renew_for: If positive, touch the session update timestamp while reading. Returns: - SQL statement to create adk_sessions table with indexes. + Session record or None if not found. """ - owner_id_line = "" - if self._owner_id_column_ddl: - owner_id_line = f",\n {self._owner_id_column_ddl}" + return self._get_session(app_name, user_id, session_id, renew_for=renew_for) - return f""" - CREATE TABLE IF NOT EXISTS {self._session_table} ( - id TEXT PRIMARY KEY, - app_name TEXT NOT NULL, - user_id TEXT NOT NULL{owner_id_line}, - state TEXT NOT NULL DEFAULT '{{}}', - create_time REAL NOT NULL, - update_time REAL NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user - ON {self._session_table}(app_name, user_id); - CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time - ON {self._session_table}(update_time DESC); + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + state: New state dictionary (replaces existing state). """ + self._update_session_state(app_name, user_id, session_id, state) - async def _get_create_events_table_sql(self) -> str: - """Get SQLite CREATE TABLE SQL for events. + def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": + """List sessions for an app, optionally filtered by user. + + Args: + app_name: Application name. + user_id: User identifier. If None, lists all sessions for the app. Returns: - SQL statement to create adk_events table with indexes. + List of session records ordered by update_time DESC. """ - return f""" - CREATE TABLE IF NOT EXISTS {self._events_table} ( - id TEXT PRIMARY KEY, - session_id TEXT NOT NULL, - invocation_id TEXT, - author TEXT, - timestamp REAL NOT NULL, - event_data TEXT NOT NULL, - FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE - ); - CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session - ON {self._events_table}(session_id, timestamp ASC); + return self._list_sessions(app_name, user_id) + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete session and all associated events (cascade). + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. """ + self._delete_session(app_name, user_id, session_id) - def _get_drop_tables_sql(self) -> "list[str]": - """Get SQLite DROP TABLE SQL statements. + def append_event(self, event_record: EventRecord) -> None: + """Append an event to a session. + + Args: + event_record: Event record to store. + """ + self._append_event(event_record) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> SessionRecord: + """Atomically append an event and update the session's durable state. + + Inserts the event and updates the session state + update_time in a + single transaction, returning the updated SessionRecord via RETURNING. + + Args: + event_record: Event record to store. + app_name: Application name for scoped state. + user_id: User identifier for scoped state. + session_id: Session identifier whose state should be updated. + state: Post-append durable state snapshot (temp: keys already + stripped by the service layer). + app_state: App-scoped state snapshot to upsert when changed. + user_state: User-scoped state snapshot to upsert when changed. + """ + return self._append_event_and_update_state( + event_record, app_name, user_id, session_id, state, app_state=app_state, user_state=user_state + ) + + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session. + + Args: + app_name: Application name. + user_id: User identifier. + session_id: Session identifier. + after_timestamp: Only return events after this time. + limit: Maximum number of events to return. Returns: - List of SQL statements to drop tables and indexes. + List of event records ordered by timestamp ASC. + """ + return self._get_events(app_name, user_id, session_id, after_timestamp, limit) + + def delete_expired_events(self, before: datetime) -> int: + """Delete events older than the given timestamp.""" + return self._delete_expired_events(before) + + def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time predates the given threshold.""" + return self._delete_idle_sessions(updated_before) + + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + return self._get_app_state(app_name) + + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + return self._get_user_state(app_name, user_id) + + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + self._upsert_app_state(app_name, state) + + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + self._upsert_user_state(app_name, user_id, state) + + def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + return self._get_metadata(key) + + def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + self._set_metadata(key, value) + + def _apply_pragmas(self, connection: Any) -> None: + """Apply PRAGMA optimization profile for this connection. + + Args: + connection: SQLite connection. """ - return [f"DROP TABLE IF EXISTS {self._events_table}", f"DROP TABLE IF EXISTS {self._session_table}"] + connection.execute("PRAGMA foreign_keys = ON") + connection.execute("PRAGMA cache_size = -64000") + connection.execute("PRAGMA mmap_size = 30000000") + connection.execute("PRAGMA journal_size_limit = 67108864") def _create_tables(self) -> None: """Synchronous implementation of create_tables.""" with self._config.provide_session() as driver: self._apply_pragmas(driver.connection) - driver.execute_script(run_(self._get_create_sessions_table_sql)()) - driver.execute_script(run_(self._get_create_events_table_sql)()) - - async def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_sessions_table_sql()) + driver.execute_script(self._get_create_events_table_sql()) + driver.execute_script(self._get_create_app_states_table_sql()) + driver.execute_script(self._get_create_user_states_table_sql()) + driver.execute_script(self._get_create_metadata_table_sql()) + driver.execute_script(self._get_seed_metadata_sql()) def _create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -163,35 +271,37 @@ def _create_session( id=session_id, app_name=app_name, user_id=user_id, state=state, create_time=now, update_time=now ) - async def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> SessionRecord: - """Create a new session. - - Args: - session_id: Unique session identifier. - app_name: Application name. - user_id: User identifier. - state: Initial session state. - owner_id: Optional owner ID value for owner ID column. - - Returns: - Created session record. - """ - return await async_(self._create_session)(session_id, app_name, user_id, state, owner_id) - - def _get_session(self, session_id: str) -> "SessionRecord | None": + def _get_session( + self, app_name: str, user_id: str, session_id: str, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": """Synchronous implementation of get_session.""" + params = (app_name, user_id, session_id) + update_params: tuple[Any, ...] + if renew_for is not None and self._calculate_expires_at(renew_for) is not None: + update_sql = f""" + UPDATE {self._session_table} + SET update_time = ? + WHERE app_name = ? AND user_id = ? AND id = ? + """ + now_julian = _datetime_to_julian(datetime.now(timezone.utc)) + update_params = (now_julian, app_name, user_id, session_id) + else: + update_sql = "" + update_params = () + sql = f""" SELECT id, app_name, user_id, state, create_time, update_time FROM {self._session_table} - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ try: with self._config.provide_connection() as conn: self._apply_pragmas(conn) - cursor = conn.execute(sql, (session_id,)) + if update_sql: + conn.execute(update_sql, update_params) + conn.commit() + cursor = conn.execute(sql, params) row = cursor.fetchone() if row is None: @@ -210,18 +320,7 @@ def _get_session(self, session_id: str) -> "SessionRecord | None": return None raise - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get session by ID. - - Args: - session_id: Session identifier. - - Returns: - Session record or None if not found. - """ - return await async_(self._get_session)(session_id) - - def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + def _update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Synchronous implementation of update_session_state.""" now_julian = _datetime_to_julian(datetime.now(timezone.utc)) state_json = to_json(state) @@ -229,23 +328,14 @@ def _update_session_state(self, session_id: str, state: "dict[str, Any]") -> Non sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? """ with self._config.provide_connection() as conn: self._apply_pragmas(conn) - conn.execute(sql, (state_json, now_julian, session_id)) + conn.execute(sql, (state_json, now_julian, app_name, user_id, session_id)) conn.commit() - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. - - Args: - session_id: Session identifier. - state: New state dictionary (replaces existing state). - """ - await async_(self._update_session_state)(session_id, state) - def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionRecord]": """Synchronous implementation of list_sessions.""" if user_id is None: @@ -287,115 +377,118 @@ def _list_sessions(self, app_name: str, user_id: "str | None") -> "list[SessionR return [] raise - async def list_sessions(self, app_name: str, user_id: str | None = None) -> "list[SessionRecord]": - """List sessions for an app, optionally filtered by user. - - Args: - app_name: Application name. - user_id: User identifier. If None, lists all sessions for the app. - - Returns: - List of session records ordered by update_time DESC. - """ - return await async_(self._list_sessions)(app_name, user_id) - - def _delete_session(self, session_id: str) -> None: + def _delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Synchronous implementation of delete_session.""" - sql = f"DELETE FROM {self._session_table} WHERE id = ?" + sql = f"DELETE FROM {self._session_table} WHERE app_name = ? AND user_id = ? AND id = ?" with self._config.provide_connection() as conn: self._apply_pragmas(conn) - conn.execute(sql, (session_id,)) + conn.execute(sql, (app_name, user_id, session_id)) conn.commit() - async def delete_session(self, session_id: str) -> None: - """Delete session and all associated events (cascade). - - Args: - session_id: Session identifier. - """ - await async_(self._delete_session)(session_id) - def _append_event(self, event_record: EventRecord) -> None: """Synchronous implementation of append_event.""" timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - event_data_json = to_json(event_record["event_json"]) + event_data_json = to_json(event_record["event_data"]) sql = f""" INSERT INTO {self._events_table} ( - id, session_id, invocation_id, author, timestamp, event_data - ) VALUES (?, ?, ?, ?, ?, ?) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?, ?) """ - import uuid - - event_id = str(uuid.uuid4()) - with self._config.provide_connection() as conn: self._apply_pragmas(conn) conn.execute( sql, ( - event_id, + event_record["id"], + event_record["app_name"], + event_record["user_id"], event_record["session_id"], event_record["invocation_id"], - event_record["author"], timestamp_julian, event_data_json, ), ) conn.commit() - async def append_event(self, event_record: EventRecord) -> None: - """Append an event to a session. - - Args: - event_record: Event record with 5 keys: session_id, invocation_id, - author, timestamp, event_json. - """ - await async_(self._append_event)(event_record) - def _append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> SessionRecord: """Synchronous implementation of append_event_and_update_state.""" - import uuid - timestamp_julian = _datetime_to_julian(event_record["timestamp"]) - event_data_json = to_json(event_record["event_json"]) + event_data_json = to_json(event_record["event_data"]) now_julian = _datetime_to_julian(datetime.now(timezone.utc)) state_json = to_json(state) - event_id = str(uuid.uuid4()) insert_sql = f""" INSERT INTO {self._events_table} ( - id, session_id, invocation_id, author, timestamp, event_data - ) VALUES (?, ?, ?, ?, ?, ?) + id, app_name, user_id, session_id, invocation_id, timestamp, event_data + ) VALUES (?, ?, ?, ?, ?, ?, ?) """ update_sql = f""" UPDATE {self._session_table} SET state = ?, update_time = ? - WHERE id = ? + WHERE app_name = ? AND user_id = ? AND id = ? RETURNING id, app_name, user_id, state, create_time, update_time """ + app_upsert_sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + user_upsert_sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + with self._config.provide_connection() as conn: self._apply_pragmas(conn) - conn.execute( - insert_sql, - ( - event_id, - event_record["session_id"], - event_record["invocation_id"], - event_record["author"], - timestamp_julian, - event_data_json, - ), - ) - cursor = conn.execute(update_sql, (state_json, now_julian, session_id)) - row = cursor.fetchone() - conn.commit() + try: + cursor = conn.execute(update_sql, (state_json, now_julian, app_name, user_id, session_id)) + row = cursor.fetchone() + if row is not None: + conn.execute( + insert_sql, + ( + event_record["id"], + app_name, + user_id, + event_record["session_id"], + event_record["invocation_id"], + timestamp_julian, + event_data_json, + ), + ) + if app_state is not None: + conn.execute(app_upsert_sql, (app_name, to_json(app_state), now_julian)) + if user_state is not None: + conn.execute(user_upsert_sql, (app_name, user_id, to_json(user_state), now_julian)) + except Exception: + conn.rollback() + raise + else: + if row is None: + conn.rollback() + else: + conn.commit() if row is None: msg = f"Session {session_id} not found during append_event_and_update_state." @@ -410,28 +503,20 @@ def _append_event_and_update_state( update_time=_julian_to_datetime(row[5]), ) - async def append_event_and_update_state( - self, event_record: EventRecord, session_id: str, state: "dict[str, Any]" - ) -> SessionRecord: - """Atomically append an event and update the session's durable state. - - Inserts the event and updates the session state + update_time in a - single transaction, returning the updated SessionRecord via RETURNING. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (temp: keys already - stripped by the service layer). - """ - return await async_(self._append_event_and_update_state)(event_record, session_id, state) - def _get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Synchronous implementation of get_events.""" - where_clauses = ["session_id = ?"] - params: list[Any] = [session_id] + if limit == 0: + return [] + + where_clauses = ["app_name = ?", "user_id = ?", "session_id = ?"] + params: list[Any] = [app_name, user_id, session_id] if after_timestamp is not None: where_clauses.append("timestamp > ?") @@ -441,7 +526,7 @@ def _get_events( limit_clause = f" LIMIT {limit}" if limit else "" sql = f""" - SELECT id, session_id, invocation_id, author, timestamp, event_data + SELECT id, app_name, user_id, session_id, invocation_id, timestamp, event_data FROM {self._events_table} WHERE {where_clause} ORDER BY timestamp ASC{limit_clause} @@ -455,11 +540,13 @@ def _get_events( return [ EventRecord( - session_id=row[1], - invocation_id=row[2], - author=row[3], - timestamp=_julian_to_datetime(row[4]), - event_json=from_json(row[5]) if row[5] else {}, + id=row[0], + app_name=row[1], + user_id=row[2], + session_id=row[3], + invocation_id=row[4], + timestamp=_julian_to_datetime(row[5]), + event_data=from_json(row[6]) if row[6] else {}, ) for row in rows ] @@ -468,23 +555,239 @@ def _get_events( return [] raise - async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None - ) -> "list[EventRecord]": - """Get events for a session. + def _delete_expired_events(self, before: datetime) -> int: + """Synchronous implementation of delete_expired_events.""" + sql = f"DELETE FROM {self._events_table} WHERE timestamp < ?" - Args: - session_id: Session identifier. - after_timestamp: Only return events after this time. - limit: Maximum number of events to return. + try: + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + cursor = conn.execute(sql, (_datetime_to_julian(before),)) + deleted_count = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + conn.commit() + return deleted_count + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return 0 + raise + + def _delete_idle_sessions(self, updated_before: datetime) -> int: + """Synchronous implementation of delete_idle_sessions.""" + sql = f"DELETE FROM {self._session_table} WHERE update_time < ?" + + try: + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + cursor = conn.execute(sql, (_datetime_to_julian(updated_before),)) + deleted_count = cursor.rowcount if cursor.rowcount and cursor.rowcount > 0 else 0 + conn.commit() + return deleted_count + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return 0 + raise + + def _get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_app_state.""" + sql = f"SELECT state FROM {self._app_state_table} WHERE app_name = ?" + + try: + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + cursor = conn.execute(sql, (app_name,)) + row = cursor.fetchone() + return from_json(row[0]) if row is not None and row[0] else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + def _get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Synchronous implementation of get_user_state.""" + sql = f""" + SELECT state + FROM {self._user_state_table} + WHERE app_name = ? AND user_id = ? + """ + + try: + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + cursor = conn.execute(sql, (app_name, user_id)) + row = cursor.fetchone() + return from_json(row[0]) if row is not None and row[0] else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + def _upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_app_state.""" + sql = f""" + INSERT INTO {self._app_state_table} (app_name, state, update_time) + VALUES (?, ?, ?) + ON CONFLICT(app_name) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + conn.execute(sql, (app_name, to_json(state), _datetime_to_julian(datetime.now(timezone.utc)))) + conn.commit() + + def _upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Synchronous implementation of upsert_user_state.""" + sql = f""" + INSERT INTO {self._user_state_table} (app_name, user_id, state, update_time) + VALUES (?, ?, ?, ?) + ON CONFLICT(app_name, user_id) DO UPDATE SET + state = excluded.state, + update_time = excluded.update_time + """ + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + conn.execute(sql, (app_name, user_id, to_json(state), _datetime_to_julian(datetime.now(timezone.utc)))) + conn.commit() + + def _get_metadata(self, key: str) -> "str | None": + """Synchronous implementation of get_metadata.""" + sql = f"SELECT value FROM {self._metadata_table} WHERE key = ?" + + try: + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + cursor = conn.execute(sql, (key,)) + row = cursor.fetchone() + return str(row[0]) if row is not None else None + except sqlite3.OperationalError as exc: + if SQLITE_TABLE_NOT_FOUND_ERROR in str(exc): + return None + raise + + def _set_metadata(self, key: str, value: str) -> None: + """Synchronous implementation of set_metadata.""" + sql = f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE SET value = excluded.value + """ + + with self._config.provide_connection() as conn: + self._apply_pragmas(conn) + conn.execute(sql, (key, value)) + conn.commit() + + def _get_create_sessions_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for sessions. Returns: - List of event records ordered by timestamp ASC. + SQL statement to create adk_session table with indexes. + """ + owner_id_line = "" + if self._owner_id_column_ddl: + owner_id_line = f",\n {self._owner_id_column_ddl}" + + return f""" + CREATE TABLE IF NOT EXISTS {self._session_table} ( + id TEXT PRIMARY KEY, + app_name TEXT NOT NULL, + user_id TEXT NOT NULL{owner_id_line}, + state TEXT NOT NULL DEFAULT '{{}}', + create_time REAL NOT NULL, + update_time REAL NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_app_user + ON {self._session_table}(app_name, user_id); + CREATE INDEX IF NOT EXISTS idx_{self._session_table}_update_time + ON {self._session_table}(update_time DESC); + """ + + def _get_create_events_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for events.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._events_table} ( + id TEXT PRIMARY KEY, + app_name TEXT NOT NULL, + user_id TEXT NOT NULL, + session_id TEXT NOT NULL, + invocation_id TEXT, + timestamp REAL NOT NULL, + event_data TEXT NOT NULL, + FOREIGN KEY (session_id) REFERENCES {self._session_table}(id) ON DELETE CASCADE + ); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_session + ON {self._events_table}(app_name, user_id, session_id, timestamp ASC); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_invocation + ON {self._events_table}(invocation_id); + CREATE INDEX IF NOT EXISTS idx_{self._events_table}_timestamp + ON {self._events_table}(timestamp ASC); + """ + + def _get_create_app_states_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for app-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._app_state_table} ( + app_name TEXT PRIMARY KEY, + state TEXT NOT NULL DEFAULT '{{}}', + update_time REAL NOT NULL + ); + """ + + def _get_create_user_states_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for user-scoped state.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._user_state_table} ( + app_name TEXT NOT NULL, + user_id TEXT NOT NULL, + state TEXT NOT NULL DEFAULT '{{}}', + update_time REAL NOT NULL, + PRIMARY KEY (app_name, user_id) + ); + """ + + def _get_create_metadata_table_sql(self) -> str: + """Get SQLite CREATE TABLE SQL for ADK internal metadata.""" + return f""" + CREATE TABLE IF NOT EXISTS {self._metadata_table} ( + key TEXT PRIMARY KEY, + value TEXT NOT NULL + ); + """ + + def _get_seed_metadata_sql(self) -> str: + """Get SQLite SQL to seed the ADK schema-version metadata row.""" + return f""" + INSERT INTO {self._metadata_table} (key, value) + VALUES ('schema_version', '1') + ON CONFLICT(key) DO NOTHING; """ - return await async_(self._get_events)(session_id, after_timestamp, limit) + def _get_drop_app_states_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for app-scoped state.""" + return f"DROP TABLE IF EXISTS {self._app_state_table}" + + def _get_drop_user_states_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for user-scoped state.""" + return f"DROP TABLE IF EXISTS {self._user_state_table}" + + def _get_drop_metadata_table_sql(self) -> str: + """Get SQLite DROP TABLE SQL for ADK internal metadata.""" + return f"DROP TABLE IF EXISTS {self._metadata_table}" -class SqliteADKMemoryStore(BaseAsyncADKMemoryStore["SqliteConfig"]): + def _get_drop_tables_sql(self) -> "list[str]": + """Get SQLite DROP TABLE SQL statements.""" + return [ + self._get_drop_metadata_table_sql(), + self._get_drop_user_states_table_sql(), + self._get_drop_app_states_table_sql(), + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + +class SqliteADKMemoryStore(BaseSyncADKMemoryStore["SqliteConfig"]): """SQLite ADK memory store using synchronous SQLite driver. Implements memory entry storage for Google Agent Development Kit @@ -510,7 +813,29 @@ def __init__(self, config: "SqliteConfig") -> None: """ super().__init__(config) - async def _get_create_memory_table_sql(self) -> str: + def create_tables(self) -> None: + """Create tables if they don't exist.""" + self._create_tables() + + def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: + """Bulk insert memory entries with deduplication.""" + return self._insert_memory_entries(entries, owner_id) + + def search_entries( + self, query: str, app_name: str, user_id: str, limit: "int | None" = None + ) -> "list[MemoryRecord]": + """Search memory entries by text query.""" + return self._search_entries(query, app_name, user_id, limit) + + def delete_entries_by_session(self, session_id: str) -> int: + """Delete all memory entries for a specific session.""" + return self._delete_entries_by_session(session_id) + + def delete_entries_older_than(self, days: int) -> int: + """Delete memory entries older than specified days.""" + return self._delete_entries_older_than(days) + + def _get_create_memory_table_sql(self) -> str: """Get SQLite CREATE TABLE SQL for memory entries. Returns: @@ -597,11 +922,7 @@ def _create_tables(self) -> None: with self._config.provide_session() as driver: self._enable_foreign_keys(driver.connection) - driver.execute_script(run_(self._get_create_memory_table_sql)()) - - async def create_tables(self) -> None: - """Create tables if they don't exist.""" - await async_(self._create_tables)() + driver.execute_script(self._get_create_memory_table_sql()) def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication. @@ -687,10 +1008,6 @@ def _insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "objec return inserted_count - async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: - """Bulk insert memory entries with deduplication.""" - return await async_(self._insert_memory_entries)(entries, owner_id) - def _search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None ) -> "list[MemoryRecord]": @@ -721,12 +1038,6 @@ def _search_entries( logger.warning("FTS search failed; falling back to simple search: %s", exc) return self._search_entries_simple(query, app_name, user_id, effective_limit) - async def search_entries( - self, query: str, app_name: str, user_id: str, limit: "int | None" = None - ) -> "list[MemoryRecord]": - """Search memory entries by text query.""" - return await async_(self._search_entries)(query, app_name, user_id, limit) - def _search_entries_fts(self, query: str, app_name: str, user_id: str, limit: int) -> "list[MemoryRecord]": sql = f""" SELECT m.id, m.session_id, m.app_name, m.user_id, m.event_id, m.author, @@ -798,10 +1109,6 @@ def _delete_entries_by_session(self, session_id: str) -> int: return deleted_count - async def delete_entries_by_session(self, session_id: str) -> int: - """Delete all memory entries for a specific session.""" - return await async_(self._delete_entries_by_session)(session_id) - def _delete_entries_older_than(self, days: int) -> int: """Delete memory entries older than specified days. @@ -825,10 +1132,6 @@ def _delete_entries_older_than(self, days: int) -> int: return deleted_count - async def delete_entries_older_than(self, days: int) -> int: - """Delete memory entries older than specified days.""" - return await async_(self._delete_entries_older_than)(days) - def _datetime_to_julian(dt: datetime) -> float: """Convert datetime to Julian Day number for SQLite storage. diff --git a/sqlspec/extensions/adk/_config_utils.py b/sqlspec/extensions/adk/_config_utils.py index eb63fafa6..405f9fb17 100644 --- a/sqlspec/extensions/adk/_config_utils.py +++ b/sqlspec/extensions/adk/_config_utils.py @@ -28,6 +28,9 @@ class _ADKSessionStoreConfig(TypedDict): session_table: str events_table: str + app_state_table: str + user_state_table: str + metadata_table: str owner_id_column: NotRequired[str] @@ -45,6 +48,7 @@ class _ADKArtifactStoreConfig(TypedDict): """Normalized ADK artifact store configuration.""" artifact_table: str + storage_uri: NotRequired[str] class _ADKConfigSource(Protocol): @@ -66,11 +70,12 @@ def _get_adk_session_store_config(config: _ADKConfigSource) -> _ADKSessionStoreC """Return normalized session store table settings.""" adk_config = _get_adk_config_from_extension(config) - session_table = adk_config.get("session_table") - events_table = adk_config.get("events_table") result: _ADKSessionStoreConfig = { - "session_table": str(session_table) if session_table is not None else "adk_sessions", - "events_table": str(events_table) if events_table is not None else "adk_events", + "session_table": str(adk_config.get("session_table") or "adk_session"), + "events_table": str(adk_config.get("events_table") or "adk_event"), + "app_state_table": str(adk_config.get("app_state_table") or "adk_app_state"), + "user_state_table": str(adk_config.get("user_state_table") or "adk_user_state"), + "metadata_table": str(adk_config.get("metadata_table") or "adk_internal_metadata"), } owner_id = adk_config.get("owner_id_column") if owner_id is not None: @@ -83,14 +88,11 @@ def _get_adk_memory_store_config(config: _ADKConfigSource) -> _ADKMemoryStoreCon adk_config = _get_adk_config_from_extension(config) enable_memory = adk_config.get("enable_memory") - memory_table = adk_config.get("memory_table") - use_fts = adk_config.get("memory_use_fts") max_results = adk_config.get("memory_max_results") - result: _ADKMemoryStoreConfig = { "enable_memory": bool(enable_memory) if enable_memory is not None else True, - "memory_table": str(memory_table) if memory_table is not None else "adk_memory_entries", - "use_fts": bool(use_fts) if use_fts is not None else False, + "memory_table": str(adk_config.get("memory_table") or "adk_memory"), + "use_fts": bool(adk_config.get("memory_use_fts", False)), "max_results": int(max_results) if isinstance(max_results, int) else 20, } owner_id = adk_config.get("owner_id_column") @@ -103,8 +105,11 @@ def _get_adk_artifact_store_config(config: _ADKConfigSource) -> _ADKArtifactStor """Return normalized artifact store settings.""" adk_config = _get_adk_config_from_extension(config) - artifact_table = adk_config.get("artifact_table") - return {"artifact_table": str(artifact_table) if artifact_table is not None else "adk_artifact_versions"} + result: _ADKArtifactStoreConfig = {"artifact_table": str(adk_config.get("artifact_table") or "adk_artifact")} + storage_uri = adk_config.get("artifact_storage_uri") + if storage_uri is not None: + result["storage_uri"] = str(storage_uri) + return result def _resolve_adk_store_path(config: Any, store_suffix: str) -> str: @@ -178,7 +183,8 @@ def _is_adk_memory_migration_enabled(config: Any) -> bool: include_memory = adk_config.get("include_memory_migration") if include_memory is not None: return bool(include_memory) - return bool(adk_config.get("enable_memory", True)) + enable_memory = adk_config.get("enable_memory") + return bool(enable_memory) if enable_memory is not None else True def _validate_adk_store_registration(config: Any) -> None: diff --git a/sqlspec/extensions/adk/_types.py b/sqlspec/extensions/adk/_types.py index 3f11b62f0..476995708 100644 --- a/sqlspec/extensions/adk/_types.py +++ b/sqlspec/extensions/adk/_types.py @@ -27,15 +27,17 @@ class SessionRecord(TypedDict): class EventRecord(TypedDict): """Database record for an event. - Stores the full ADK Event as a single JSON blob (``event_json``) alongside - a small number of indexed scalar columns used for query filtering. + Stores the full ADK Event as a single JSON blob (``event_data``) alongside + indexed scalar columns used for scoped query filtering. This design eliminates column drift with upstream ADK: new Event fields are - automatically captured in ``event_json`` without schema changes. + automatically captured in ``event_data`` without schema changes. """ + id: str + app_name: str + user_id: str session_id: str invocation_id: str - author: str timestamp: datetime - event_json: "dict[str, Any]" + event_data: "dict[str, Any]" diff --git a/sqlspec/extensions/adk/converters.py b/sqlspec/extensions/adk/converters.py index 5963ce620..f7b34303e 100644 --- a/sqlspec/extensions/adk/converters.py +++ b/sqlspec/extensions/adk/converters.py @@ -1,9 +1,9 @@ """Conversion functions between ADK models and database records. Implements full-event JSON storage: the entire Event is serialized via -``Event.model_dump_json(exclude_none=True)`` into a single ``event_json`` +``Event.model_dump(exclude_none=True, mode="json")`` into a single ``event_data`` column, with a small set of indexed scalar columns extracted alongside for -query performance. Reconstruction uses ``Event.model_validate_json()``. +query performance. Reconstruction uses ``Event.model_validate()``. Also provides scoped-state helpers that normalise ADK state prefixes (``app:``, ``user:``, ``temp:``) so the shared service layer can split, @@ -106,34 +106,39 @@ def record_to_session(record: SessionRecord, events: "list[EventRecord]") -> "Se # --------------------------------------------------------------------------- -def event_to_record(event: "Event", session_id: str) -> EventRecord: +def event_to_record(event: "Event", app_name: str, user_id: str, session_id: str) -> EventRecord: """Convert ADK Event to database record using full-event JSON storage. - The entire Event is serialized into ``event_json`` via Pydantic's - ``model_dump_json(exclude_none=True)``. A small number of indexed scalar - columns are extracted alongside for query performance. + The entire Event is serialized into ``event_data`` via Pydantic's + ``model_dump(exclude_none=True, mode="json")``. Indexed scalar columns are + extracted alongside for scoped filtering. Args: event: ADK Event object. + app_name: Name of the parent app. + user_id: ID of the parent user. session_id: ID of the parent session. Returns: EventRecord for database storage. """ + event_data = _normalize_event_data(event.model_dump(exclude_none=True, mode="json")) return EventRecord( + id=event.id, + app_name=app_name, + user_id=user_id, session_id=session_id, invocation_id=event.invocation_id, - author=event.author, timestamp=datetime.fromtimestamp(event.timestamp, tz=timezone.utc), - event_json=event.model_dump(exclude_none=True, mode="json"), + event_data=event_data, ) def record_to_event(record: "EventRecord") -> "Event": """Convert database record to ADK Event. - Reconstruction is lossless: the full Event is restored from - ``event_json`` via ``Event.model_validate_json()``. + Reconstruction is lossless for valid ADK payloads: the full Event is + restored from ``event_data`` via ``Event.model_validate()``. Args: record: Event database record. @@ -141,7 +146,11 @@ def record_to_event(record: "EventRecord") -> "Event": Returns: ADK Event object. """ - return Event.model_validate(record["event_json"]) + event_data = _normalize_event_data(record["event_data"]) + event_data.setdefault("id", record["id"]) + event_data.setdefault("invocation_id", record["invocation_id"]) + event_data.setdefault("timestamp", record["timestamp"].timestamp()) + return Event.model_validate(event_data) # --------------------------------------------------------------------------- @@ -213,3 +222,17 @@ def merge_scoped_state( if user_state is not None: merged.update(user_state) return merged + + +def _normalize_event_data(event_data: "dict[str, Any]") -> "dict[str, Any]": + """Return event data acceptable to ADK 2.2's Event model. + + ADK 2.2 guards an assigned ``event.actions = None`` during service writes, + but explicit ``actions: null`` does not validate as a durable Event shape. + SQLSpec therefore omits that key before storing or restoring payloads. + """ + + normalized = dict(event_data) + if normalized.get("actions") is None: + normalized.pop("actions", None) + return normalized diff --git a/sqlspec/extensions/adk/memory/store.py b/sqlspec/extensions/adk/memory/store.py index dae884220..5f93abc1b 100644 --- a/sqlspec/extensions/adk/memory/store.py +++ b/sqlspec/extensions/adk/memory/store.py @@ -23,6 +23,7 @@ VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)") MAX_TABLE_NAME_LENGTH: Final = 63 +ADK_RESET_MEMORY_TABLES: Final = ("adk_memory", "adk_memory_entries") class BaseAsyncADKMemoryStore(ABC, Generic[ConfigT]): @@ -73,14 +74,6 @@ def __init__(self, config: ConfigT) -> None: ) _validate_table_name(self._memory_table) - def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig": - """Extract ADK memory configuration from config.extension_config. - - Returns: - Dict with memory_table, use_fts, max_results, and optionally owner_id_column. - """ - return _get_adk_memory_store_config(self._config) - @property def config(self) -> ConfigT: """Return the database configuration.""" @@ -151,25 +144,6 @@ async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: " """ raise NotImplementedError - def _log_memory_table_created(self) -> None: - log_with_context( - logger, - logging.DEBUG, - "adk.memory.table.ready", - db_system=resolve_db_system(type(self).__name__), - memory_table=self._memory_table, - ) - - def _log_memory_table_skipped(self) -> None: - log_with_context( - logger, - logging.DEBUG, - "adk.memory.table.skipped", - db_system=resolve_db_system(type(self).__name__), - memory_table=self._memory_table, - reason="disabled", - ) - @abstractmethod async def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None @@ -218,6 +192,14 @@ async def delete_entries_older_than(self, days: int) -> int: """ raise NotImplementedError + def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig": + """Extract ADK memory configuration from config.extension_config. + + Returns: + Dict with memory_table, use_fts, max_results, and optionally owner_id_column. + """ + return _get_adk_memory_store_config(self._config) + @abstractmethod async def _get_create_memory_table_sql(self) -> "str | list[str]": """Get the CREATE TABLE SQL for the memory table. @@ -236,50 +218,39 @@ def _get_drop_memory_table_sql(self) -> "list[str]": """ raise NotImplementedError + def _get_reset_drop_memory_table_sql(self) -> "list[str]": + """Return memory drops needed before recreating the clean-break schema.""" + statements = list(self._get_drop_memory_table_sql()) + for table_name in ADK_RESET_MEMORY_TABLES: + statements.extend(self._get_drop_memory_table_sql_for_table(table_name)) + return _deduplicate_statements(statements) + + def _get_drop_memory_table_sql_for_table(self, table_name: str) -> "list[str]": + current_table = self._memory_table + self._memory_table = table_name + try: + return list(self._get_drop_memory_table_sql()) + finally: + self._memory_table = current_table -def _parse_owner_id_column(owner_id_column_ddl: str) -> str: - """Extract column name from owner ID column DDL definition. - - Args: - owner_id_column_ddl: Full column DDL string. - - Returns: - Column name only (first word). - - Raises: - ValueError: If DDL format is invalid. - """ - match = COLUMN_NAME_PATTERN.match(owner_id_column_ddl.strip()) - if not match: - msg = f"Invalid owner_id_column DDL: {owner_id_column_ddl!r}. Must start with column name." - raise ValueError(msg) - - return match.group(1) - - -def _validate_table_name(table_name: str) -> None: - """Validate table name for SQL safety. - - Args: - table_name: Table name to validate. - - Raises: - ValueError: If table name is invalid. - """ - if not table_name: - msg = "Table name cannot be empty" - raise ValueError(msg) - - if len(table_name) > MAX_TABLE_NAME_LENGTH: - msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})" - raise ValueError(msg) + def _log_memory_table_created(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.memory.table.ready", + db_system=resolve_db_system(type(self).__name__), + memory_table=self._memory_table, + ) - if not VALID_TABLE_NAME_PATTERN.match(table_name): - msg = ( - f"Invalid table name: {table_name!r}. " - "Must start with letter/underscore and contain only alphanumeric characters and underscores" + def _log_memory_table_skipped(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.memory.table.skipped", + db_system=resolve_db_system(type(self).__name__), + memory_table=self._memory_table, + reason="disabled", ) - raise ValueError(msg) class BaseSyncADKMemoryStore(ABC, Generic[ConfigT]): @@ -330,14 +301,6 @@ def __init__(self, config: ConfigT) -> None: ) _validate_table_name(self._memory_table) - def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig": - """Extract ADK memory configuration from config.extension_config. - - Returns: - Dict with memory_table, use_fts, max_results, and optionally owner_id_column. - """ - return _get_adk_memory_store_config(self._config) - @property def config(self) -> ConfigT: """Return the database configuration.""" @@ -408,25 +371,6 @@ def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object """ raise NotImplementedError - def _log_memory_table_created(self) -> None: - log_with_context( - logger, - logging.DEBUG, - "adk.memory.table.ready", - db_system=resolve_db_system(type(self).__name__), - memory_table=self._memory_table, - ) - - def _log_memory_table_skipped(self) -> None: - log_with_context( - logger, - logging.DEBUG, - "adk.memory.table.skipped", - db_system=resolve_db_system(type(self).__name__), - memory_table=self._memory_table, - reason="disabled", - ) - @abstractmethod def search_entries( self, query: str, app_name: str, user_id: str, limit: "int | None" = None @@ -475,6 +419,14 @@ def delete_entries_older_than(self, days: int) -> int: """ raise NotImplementedError + def _get_store_config_from_extension(self) -> "_ADKMemoryStoreConfig": + """Extract ADK memory configuration from config.extension_config. + + Returns: + Dict with memory_table, use_fts, max_results, and optionally owner_id_column. + """ + return _get_adk_memory_store_config(self._config) + @abstractmethod def _get_create_memory_table_sql(self) -> "str | list[str]": """Get the CREATE TABLE SQL for the memory table. @@ -484,6 +436,40 @@ def _get_create_memory_table_sql(self) -> "str | list[str]": """ raise NotImplementedError + def _get_reset_drop_memory_table_sql(self) -> "list[str]": + """Return memory drops needed before recreating the clean-break schema.""" + statements = list(self._get_drop_memory_table_sql()) + for table_name in ADK_RESET_MEMORY_TABLES: + statements.extend(self._get_drop_memory_table_sql_for_table(table_name)) + return _deduplicate_statements(statements) + + def _get_drop_memory_table_sql_for_table(self, table_name: str) -> "list[str]": + current_table = self._memory_table + self._memory_table = table_name + try: + return list(self._get_drop_memory_table_sql()) + finally: + self._memory_table = current_table + + def _log_memory_table_created(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.memory.table.ready", + db_system=resolve_db_system(type(self).__name__), + memory_table=self._memory_table, + ) + + def _log_memory_table_skipped(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.memory.table.skipped", + db_system=resolve_db_system(type(self).__name__), + memory_table=self._memory_table, + reason="disabled", + ) + @abstractmethod def _get_drop_memory_table_sql(self) -> "list[str]": """Get the DROP TABLE SQL statements for this database dialect. @@ -492,3 +478,59 @@ def _get_drop_memory_table_sql(self) -> "list[str]": List of SQL statements to drop the memory table and indexes. """ raise NotImplementedError + + +def _deduplicate_statements(statements: "list[str]") -> "list[str]": + seen: set[str] = set() + result: list[str] = [] + for statement in statements: + if statement in seen: + continue + result.append(statement) + seen.add(statement) + return result + + +def _parse_owner_id_column(owner_id_column_ddl: str) -> str: + """Extract column name from owner ID column DDL definition. + + Args: + owner_id_column_ddl: Full column DDL string. + + Returns: + Column name only (first word). + + Raises: + ValueError: If DDL format is invalid. + """ + match = COLUMN_NAME_PATTERN.match(owner_id_column_ddl.strip()) + if not match: + msg = f"Invalid owner_id_column DDL: {owner_id_column_ddl!r}. Must start with column name." + raise ValueError(msg) + + return match.group(1) + + +def _validate_table_name(table_name: str) -> None: + """Validate table name for SQL safety. + + Args: + table_name: Table name to validate. + + Raises: + ValueError: If table name is invalid. + """ + if not table_name: + msg = "Table name cannot be empty" + raise ValueError(msg) + + if len(table_name) > MAX_TABLE_NAME_LENGTH: + msg = f"Table name too long: {len(table_name)} chars (max {MAX_TABLE_NAME_LENGTH})" + raise ValueError(msg) + + if not VALID_TABLE_NAME_PATTERN.match(table_name): + msg = ( + f"Invalid table name: {table_name!r}. " + "Must start with letter/underscore and contain only alphanumeric characters and underscores" + ) + raise ValueError(msg) diff --git a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py index f44e6f5cf..8381a35c5 100644 --- a/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py +++ b/sqlspec/extensions/adk/migrations/0001_create_adk_tables.py @@ -1,160 +1,22 @@ -"""Create ADK session, events, and memory tables migration using store DDL definitions.""" +"""No-op migration: superseded by 0002_reset_adk_tables. -import inspect -import logging -from typing import TYPE_CHECKING, NoReturn, cast +This file used to create the legacy ADK ``sessions`` / ``events`` tables. The +ADK 2.0 clean break replaces that schema in 0002. 0001 is retained as a no-op +so installs that already applied it keep their tracking-table row; fresh +installs run it as a no-op and proceed to 0002. +""" -from sqlspec.exceptions import SQLSpecError -from sqlspec.extensions.adk._config_utils import ( - _get_adk_adapter_store_class, - _get_adk_memory_migration_store_class, - _is_adk_memory_migration_enabled, -) -from sqlspec.utils.logging import get_logger, log_with_context +from typing import TYPE_CHECKING if TYPE_CHECKING: - from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore - from sqlspec.extensions.adk.store import BaseAsyncADKStore from sqlspec.migrations.context import MigrationContext __all__ = ("down", "up") -logger = get_logger("sqlspec.migrations.adk.tables") - async def up(context: "MigrationContext | None" = None) -> "list[str]": - """Create the ADK session, events, and memory tables using store DDL definitions. - - This migration delegates to the appropriate store class to generate - dialect-specific DDL. The store classes contain the single source of - truth for table schemas. - - Args: - context: Migration context containing config. - - Returns: - List of SQL statements to execute for upgrade. - """ - if context is None or context.config is None: - _raise_missing_config() - - store_class = _get_store_class(context) - store_instance = store_class(config=context.config) - - statements = [ - await store_instance._get_create_sessions_table_sql(), # pyright: ignore[reportPrivateUsage] - await store_instance._get_create_events_table_sql(), # pyright: ignore[reportPrivateUsage] - ] - - if _is_memory_enabled(context): - memory_store_class = _get_memory_store_class(context) - if memory_store_class is not None: - memory_store = memory_store_class(config=context.config) - memory_sql = memory_store._get_create_memory_table_sql() # pyright: ignore[reportPrivateUsage] - if inspect.isawaitable(memory_sql): - memory_sql = await memory_sql - if isinstance(memory_sql, list): - statements.extend(memory_sql) - else: - statements.append(memory_sql) - log_with_context( - logger, logging.DEBUG, "adk.migration.memory.include", table_name=memory_store.memory_table - ) - - return statements - - -def _get_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKStore]": - """Get the appropriate store class based on the config's module path. - - Args: - context: Migration context containing config. - - Returns: - Store class matching the config's adapter. - """ - if not context or not context.config: - _raise_missing_config() - - return cast("type[BaseAsyncADKStore]", _get_adk_adapter_store_class(context.config, "ADKStore")) - - -def _get_memory_store_class( - context: "MigrationContext | None", -) -> "type[BaseAsyncADKMemoryStore | BaseSyncADKMemoryStore] | None": - """Get the appropriate memory store class based on the config's module path. - - Args: - context: Migration context containing config. - - Returns: - Memory store class matching the config's adapter, or None if not available. - """ - if not context or not context.config: - return None - - store_class = _get_adk_memory_migration_store_class(context.config) - if store_class is None: - log_with_context(logger, logging.DEBUG, "adk.migration.memory_store.missing") - return None - return cast("type[BaseAsyncADKMemoryStore | BaseSyncADKMemoryStore]", store_class) - - -def _is_memory_enabled(context: "MigrationContext | None") -> bool: - """Check if memory migration is enabled in the config. - - Args: - context: Migration context containing config. - - Returns: - True if memory migration should be included, False otherwise. - """ - if not context or not context.config: - return False - - return _is_adk_memory_migration_enabled(context.config) - - -def _raise_missing_config() -> NoReturn: - """Raise error when migration context has no config. - - Raises: - SQLSpecError: Always raised. - """ - msg = "Migration context must have a config to determine store class" - raise SQLSpecError(msg) + return [] async def down(context: "MigrationContext | None" = None) -> "list[str]": - """Drop the ADK session, events, and memory tables using store DDL definitions. - - This migration delegates to the appropriate store class to generate - dialect-specific DROP statements. The store classes contain the single - source of truth for table schemas. - - Args: - context: Migration context containing config. - - Returns: - List of SQL statements to execute for downgrade. - """ - if context is None or context.config is None: - _raise_missing_config() - - statements: list[str] = [] - - if _is_memory_enabled(context): - memory_store_class = _get_memory_store_class(context) - if memory_store_class is not None: - memory_store = memory_store_class(config=context.config) - memory_drop_stmts = memory_store._get_drop_memory_table_sql() # pyright: ignore[reportPrivateUsage] - statements.extend(memory_drop_stmts) - log_with_context( - logger, logging.DEBUG, "adk.migration.memory.drop.include", table_name=memory_store.memory_table - ) - - store_class = _get_store_class(context) - store_instance = store_class(config=context.config) - statements.extend(store_instance._get_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] - - return statements + return [] diff --git a/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py b/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py new file mode 100644 index 000000000..19510baf5 --- /dev/null +++ b/sqlspec/extensions/adk/migrations/0002_reset_adk_tables.py @@ -0,0 +1,127 @@ +"""Reset ADK schema to the 2.0 clean-break shape. + +Unconditionally drops any legacy ADK tables (sessions, events, app_states, +user_states, metadata, memory) then creates the new schema and seeds the +internal metadata row. The memory table is dropped unconditionally so users +moving from ``enable_memory=True`` to ``enable_memory=False`` get cleanup; it +is recreated only when memory is enabled for the current config. +""" + +import inspect +import logging +from typing import TYPE_CHECKING, NoReturn, cast + +from sqlspec.exceptions import SQLSpecError +from sqlspec.extensions.adk._config_utils import ( + _get_adk_adapter_store_class, + _get_adk_memory_migration_store_class, + _is_adk_memory_migration_enabled, +) +from sqlspec.utils.logging import get_logger, log_with_context + +if TYPE_CHECKING: + from collections.abc import Awaitable + + from sqlspec.extensions.adk.memory.store import BaseAsyncADKMemoryStore, BaseSyncADKMemoryStore + from sqlspec.extensions.adk.store import BaseAsyncADKStore, BaseSyncADKStore + from sqlspec.migrations.context import MigrationContext + +__all__ = ("down", "up") + +logger = get_logger("sqlspec.migrations.adk.reset") + + +async def up(context: "MigrationContext | None" = None) -> "list[str]": + if context is None or context.config is None: + _raise_missing_config() + + store_class = _get_store_class(context) + store_instance = store_class(config=context.config) + + statements: list[str] = [] + + memory_store_class = _get_memory_store_class(context) + if memory_store_class is not None: + memory_store = memory_store_class(config=context.config) + statements.extend(memory_store._get_reset_drop_memory_table_sql()) # pyright: ignore[reportPrivateUsage] + log_with_context(logger, logging.DEBUG, "adk.migration.reset.memory.drop", table_name=memory_store.memory_table) + + statements.extend(store_instance._get_reset_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] + + statements.extend([ + await _resolve_sql(store_instance._get_create_sessions_table_sql()), # pyright: ignore[reportPrivateUsage] + await _resolve_sql(store_instance._get_create_events_table_sql()), # pyright: ignore[reportPrivateUsage] + await _resolve_sql(store_instance._get_create_app_states_table_sql()), # pyright: ignore[reportPrivateUsage] + await _resolve_sql(store_instance._get_create_user_states_table_sql()), # pyright: ignore[reportPrivateUsage] + await _resolve_sql(store_instance._get_create_metadata_table_sql()), # pyright: ignore[reportPrivateUsage] + await _resolve_sql(store_instance._get_seed_metadata_sql()), # pyright: ignore[reportPrivateUsage] + ]) + + if _is_memory_enabled(context) and memory_store_class is not None: + memory_store = memory_store_class(config=context.config) + memory_sql = memory_store._get_create_memory_table_sql() # pyright: ignore[reportPrivateUsage] + if inspect.isawaitable(memory_sql): + memory_sql = await memory_sql + if isinstance(memory_sql, list): + statements.extend(memory_sql) + else: + statements.append(memory_sql) + log_with_context( + logger, logging.DEBUG, "adk.migration.reset.memory.create", table_name=memory_store.memory_table + ) + + return statements + + +async def down(context: "MigrationContext | None" = None) -> "list[str]": + if context is None or context.config is None: + _raise_missing_config() + + statements: list[str] = [] + store_class = _get_store_class(context) + store_instance = store_class(config=context.config) + + if _is_memory_enabled(context): + memory_store_class = _get_memory_store_class(context) + if memory_store_class is not None: + memory_store = memory_store_class(config=context.config) + statements.extend(memory_store._get_reset_drop_memory_table_sql()) # pyright: ignore[reportPrivateUsage] + + statements.extend(store_instance._get_reset_drop_tables_sql()) # pyright: ignore[reportPrivateUsage] + + return statements + + +def _raise_missing_config() -> NoReturn: + msg = "Migration context must have a config to determine store class" + raise SQLSpecError(msg) + + +def _get_store_class(context: "MigrationContext | None") -> "type[BaseAsyncADKStore | BaseSyncADKStore]": + if not context or not context.config: + _raise_missing_config() + return cast("type[BaseAsyncADKStore | BaseSyncADKStore]", _get_adk_adapter_store_class(context.config, "ADKStore")) + + +async def _resolve_sql(value: "str | Awaitable[str]") -> str: + if inspect.isawaitable(value): + return await value + return value + + +def _get_memory_store_class( + context: "MigrationContext | None", +) -> "type[BaseAsyncADKMemoryStore | BaseSyncADKMemoryStore] | None": + if not context or not context.config: + return None + store_class = _get_adk_memory_migration_store_class(context.config) + if store_class is None: + log_with_context(logger, logging.DEBUG, "adk.migration.reset.memory_store.missing") + return None + return cast("type[BaseAsyncADKMemoryStore | BaseSyncADKMemoryStore]", store_class) + + +def _is_memory_enabled(context: "MigrationContext | None") -> bool: + if not context or not context.config: + return False + return _is_adk_memory_migration_enabled(context.config) diff --git a/sqlspec/extensions/adk/service.py b/sqlspec/extensions/adk/service.py index 109fac9af..73a1f595c 100644 --- a/sqlspec/extensions/adk/service.py +++ b/sqlspec/extensions/adk/service.py @@ -1,9 +1,10 @@ """SQLSpec-backed session service for Google ADK.""" +import inspect import logging import uuid from datetime import datetime, timezone -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from google.adk.sessions.base_session_service import BaseSessionService, GetSessionConfig, ListSessionsResponse @@ -11,15 +12,22 @@ compute_update_marker, event_to_record, filter_temp_state, + merge_scoped_state, record_to_session, + split_scoped_state, ) from sqlspec.utils.logging import get_logger, log_with_context +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: + from collections.abc import Callable + from google.adk.events.event import Event from google.adk.sessions import Session - from sqlspec.extensions.adk.store import BaseAsyncADKStore + from sqlspec.extensions.adk.store import BaseAsyncADKStore, BaseSyncADKStore + + ADKStore = BaseAsyncADKStore | BaseSyncADKStore __all__ = ("SQLSpecSessionService",) @@ -36,7 +44,7 @@ class SQLSpecSessionService(BaseSessionService): store: Database store implementation. """ - def __init__(self, store: "BaseAsyncADKStore") -> None: + def __init__(self, store: "ADKStore") -> None: """Initialize the session service. Args: @@ -45,7 +53,7 @@ def __init__(self, store: "BaseAsyncADKStore") -> None: self._store = store @property - def store(self) -> "BaseAsyncADKStore": + def store(self) -> "ADKStore": """Return the database store.""" return self._store @@ -70,10 +78,25 @@ async def create_session( state = {} persisted_state = filter_temp_state(state) - - record = await self._store.create_session( - session_id=session_id, app_name=app_name, user_id=user_id, state=persisted_state + app_state_delta, user_state_delta, session_state = split_scoped_state(persisted_state) + current_app_state = await self._call_store("get_app_state", app_name) + current_user_state = await self._call_store("get_user_state", app_name, user_id) + + app_state = dict(current_app_state or {}) + if app_state_delta: + app_state.update(app_state_delta) + user_state = dict(current_user_state or {}) + if user_state_delta: + user_state.update(user_state_delta) + + record = await self._call_store( + "create_session", session_id=session_id, app_name=app_name, user_id=user_id, state=session_state ) + if app_state_delta: + await self._call_store("upsert_app_state", app_name, app_state) + if user_state_delta: + await self._call_store("upsert_user_state", app_name, user_id, user_state) + record["state"] = merge_scoped_state(record["state"], app_state, user_state) log_with_context( logger, logging.DEBUG, "adk.session.create", app_name=app_name, session_id=session_id, has_state=bool(state) ) @@ -94,7 +117,7 @@ async def get_session( Returns: Session object if found, None otherwise. """ - record = await self._store.get_session(session_id) + record = await self._call_store("get_session", app_name, user_id, session_id) if not record: log_with_context( @@ -108,6 +131,10 @@ async def get_session( ) return None + app_state = await self._call_store("get_app_state", app_name) + user_state = await self._call_store("get_user_state", app_name, user_id) + record["state"] = merge_scoped_state(record["state"], app_state, user_state) + after_timestamp = None limit = None @@ -116,7 +143,12 @@ async def get_session( after_timestamp = datetime.fromtimestamp(config.after_timestamp, tz=timezone.utc) limit = config.num_recent_events - events = await self._store.get_events(session_id=session_id, after_timestamp=after_timestamp, limit=limit) + if limit == 0: + events = [] + else: + events = await self._call_store( + "get_events", app_name, user_id, session_id, after_timestamp=after_timestamp, limit=limit + ) log_with_context( logger, logging.DEBUG, @@ -129,6 +161,24 @@ async def get_session( return record_to_session(record, events) + async def get_user_state(self, *, app_name: str, user_id: str) -> "dict[str, Any]": + """Get user-scoped state for an app and user. + + ADK's service API returns unprefixed user state keys, while SQLSpec + stores the durable state using ADK's documented ``user:`` prefix. + + Args: + app_name: Name of the application. + user_id: ID of the user. + + Returns: + User-scoped state with ``user:`` prefixes removed. + """ + state = await self._call_store("get_user_state", app_name, user_id) + if not state: + return {} + return {key.removeprefix("user:") if key.startswith("user:") else key: value for key, value in state.items()} + async def list_sessions(self, *, app_name: str, user_id: str | None = None) -> "ListSessionsResponse": """List all sessions for an app, optionally filtered by user. @@ -139,7 +189,7 @@ async def list_sessions(self, *, app_name: str, user_id: str | None = None) -> " Returns: Response containing list of sessions (without events). """ - records = await self._store.list_sessions(app_name=app_name, user_id=user_id) + records = await self._call_store("list_sessions", app_name, user_id=user_id) sessions = [record_to_session(record, events=[]) for record in records] log_with_context( @@ -161,7 +211,7 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) user_id: ID of the user. session_id: Session identifier. """ - record = await self._store.get_session(session_id) + record = await self._call_store("get_session", app_name, user_id, session_id) if not record: log_with_context( @@ -175,7 +225,7 @@ async def delete_session(self, *, app_name: str, user_id: str, session_id: str) ) return - await self._store.delete_session(session_id) + await self._call_store("delete_session", app_name, user_id, session_id) log_with_context( logger, logging.DEBUG, "adk.session.delete", app_name=app_name, session_id=session_id, deleted=True ) @@ -214,16 +264,12 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": self._apply_temp_state(session, event) event = self._trim_temp_delta_state(event) - event_record = event_to_record(event=event, session_id=session.id) - - # Build durable state: current state minus temp keys, plus the - # event's state delta (temp keys already stripped by _trim above). - durable_state = filter_temp_state(session.state) - if event.actions and event.actions.state_delta: - durable_state.update(event.actions.state_delta) + event_record = event_to_record( + event=event, app_name=session.app_name, user_id=session.user_id, session_id=session.id + ) # --- Stale-session detection --- - current_record = await self._store.get_session(session.id) + current_record = await self._call_store("get_session", session.app_name, session.user_id, session.id) if current_record is None: msg = f"Session {session.id} not found." raise ValueError(msg) @@ -243,10 +289,34 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": ) raise ValueError(msg) + state_delta = (event.actions.state_delta if event.actions else None) or {} + app_state_delta, user_state_delta, session_state_delta = split_scoped_state(filter_temp_state(state_delta)) + + _, _, session_state = split_scoped_state(filter_temp_state(session.state)) + session_state.update(session_state_delta) + + app_state = None + if app_state_delta: + app_state = dict(await self._call_store("get_app_state", session.app_name) or {}) + app_state.update(app_state_delta) + + user_state = None + if user_state_delta: + user_state = dict(await self._call_store("get_user_state", session.app_name, session.user_id) or {}) + user_state.update(user_state_delta) + # --- Persist event and state atomically --- - updated_record = await self._store.append_event_and_update_state( - event_record=event_record, session_id=session.id, state=durable_state + updated_record = await self._call_store( + "append_event_and_update_state", + event_record, + session.app_name, + session.user_id, + session.id, + session_state, + app_state=app_state, + user_state=user_state, ) + updated_record["state"] = merge_scoped_state(updated_record["state"], app_state, user_state) # Use the returned record directly — saves a round-trip vs a follow-up get_session(). session.last_update_time = updated_record["update_time"].timestamp() @@ -261,3 +331,13 @@ async def append_event(self, session: "Session", event: "Event") -> "Event": ) return event + + async def _call_store(self, method_name: str, *args: Any, **kwargs: Any) -> Any: + """Call an async store method or offload a sync store method.""" + method = getattr(self._store, method_name) + if inspect.iscoroutinefunction(method): + return await method(*args, **kwargs) + sync_method = method + if TYPE_CHECKING: + sync_method = cast("Callable[..., Any]", method) + return await async_(sync_method)(*args, **kwargs) diff --git a/sqlspec/extensions/adk/store.py b/sqlspec/extensions/adk/store.py index e891cf8bb..6a7cd32d1 100644 --- a/sqlspec/extensions/adk/store.py +++ b/sqlspec/extensions/adk/store.py @@ -1,17 +1,18 @@ -"""Base store classes for ADK session backend (sync and async).""" +"""Base store class for ADK session backends.""" +import inspect import logging import re from abc import ABC, abstractmethod -from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Final, Generic, TypeVar, cast from sqlspec.extensions.adk._config_utils import _get_adk_session_store_config from sqlspec.observability import resolve_db_system from sqlspec.utils.logging import get_logger, log_with_context +from sqlspec.utils.sync_tools import async_ if TYPE_CHECKING: - from datetime import datetime - from sqlspec.config import DatabaseConfigProtocol from sqlspec.extensions.adk._types import EventRecord, SessionRecord @@ -21,10 +22,15 @@ logger = get_logger("sqlspec.extensions.adk.store") - VALID_TABLE_NAME_PATTERN: Final = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$") COLUMN_NAME_PATTERN: Final = re.compile(r"^(\w+)") MAX_TABLE_NAME_LENGTH: Final = 63 +ADK_RESET_TABLE_PROFILES: Final = ( + ("adk_session", "adk_event", "adk_app_state", "adk_user_state", "adk_internal_metadata"), + ("adk_session", "adk_event", "adk_app_state", "adk_user_state", "adk_metadata"), + ("adk_sessions", "adk_events", "adk_app_states", "adk_user_states", "adk_internal_metadata"), + ("adk_sessions", "adk_events", "adk_app_states", "adk_user_states", "adk_metadata"), +) class BaseAsyncADKStore(ABC, Generic[ConfigT]): @@ -40,65 +46,68 @@ class BaseAsyncADKStore(ABC, Generic[ConfigT]): - Session and event CRUD operations Subclasses must implement dialect-specific SQL queries and will be created - in each adapter directory. + in each adapter directory (e.g., sqlspec/adapters/asyncpg/adk/store.py). Args: config: SQLSpec database configuration with extension_config["adk"] settings. + + Notes: + Configuration is read from config.extension_config["adk"]: + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") + - app_state_table: App-scoped state table name (default: "adk_app_state") + - user_state_table: User-scoped state table name (default: "adk_user_state") + - metadata_table: Internal metadata table name (default: "adk_internal_metadata") + - owner_id_column: Optional owner FK column DDL (default: None) """ - __slots__ = ("_config", "_events_table", "_owner_id_column_ddl", "_owner_id_column_name", "_session_table") + __slots__ = ( + "_app_state_table", + "_config", + "_events_table", + "_metadata_table", + "_owner_id_column_ddl", + "_owner_id_column_name", + "_session_table", + "_user_state_table", + ) def __init__(self, config: ConfigT) -> None: """Initialize the ADK store. Args: config: SQLSpec database configuration. + + Notes: + Reads configuration from config.extension_config["adk"]: + - session_table: Sessions table name (default: "adk_session") + - events_table: Events table name (default: "adk_event") + - app_state_table: App-scoped state table name (default: "adk_app_state") + - user_state_table: User-scoped state table name (default: "adk_user_state") + - metadata_table: Internal metadata table name (default: "adk_internal_metadata") + - owner_id_column: Optional owner FK column DDL (default: None) """ self._config = config store_config = self._get_store_config_from_extension() self._session_table: str = str(store_config["session_table"]) self._events_table: str = str(store_config["events_table"]) + self._app_state_table: str = str(store_config["app_state_table"]) + self._user_state_table: str = str(store_config["user_state_table"]) + self._metadata_table: str = str(store_config["metadata_table"]) self._owner_id_column_ddl: str | None = store_config.get("owner_id_column") self._owner_id_column_name: str | None = ( _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None ) _validate_table_name(self._session_table) _validate_table_name(self._events_table) + _validate_table_name(self._app_state_table) + _validate_table_name(self._user_state_table) + _validate_table_name(self._metadata_table) - def _get_store_config_from_extension(self) -> "dict[str, Any]": - """Extract ADK store configuration from config.extension_config. - - Returns: - Dict with session_table, events_table, and optionally owner_id_column. - """ - return dict(_get_adk_session_store_config(self._config)) - - @property - def config(self) -> ConfigT: - """Return the database configuration.""" - return self._config - - @property - def session_table(self) -> str: - """Return the sessions table name.""" - return self._session_table - - @property - def events_table(self) -> str: - """Return the events table name.""" - return self._events_table - - @property - def owner_id_column_ddl(self) -> "str | None": - """Return the full owner ID column DDL (or None if not configured).""" - return self._owner_id_column_ddl - - @property - def owner_id_column_name(self) -> "str | None": - """Return the owner ID column name only (or None if not configured).""" - return self._owner_id_column_name + async def create_tables(self) -> None: + """Create the sessions and events tables if they don't exist.""" + raise NotImplementedError - @abstractmethod async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None ) -> "SessionRecord": @@ -117,11 +126,16 @@ async def create_session( raise NotImplementedError @abstractmethod - async def get_session(self, session_id: str) -> "SessionRecord | None": - """Get a session by ID. + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get a session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. + renew_for: If positive, touch the session update timestamp while reading. Returns: Session record if found, None otherwise. @@ -129,10 +143,12 @@ async def get_session(self, session_id: str) -> "SessionRecord | None": raise NotImplementedError @abstractmethod - async def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: """Update session state. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. state: New state dictionary. """ @@ -152,10 +168,12 @@ async def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "l raise NotImplementedError @abstractmethod - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: """Delete a session and its events. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. """ raise NotImplementedError @@ -171,20 +189,41 @@ async def append_event(self, event_record: "EventRecord") -> None: @abstractmethod async def append_event_and_update_state( - self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" + self, + event_record: "EventRecord", + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> "SessionRecord": """Atomically append an event and update the session's durable state. This is the authoritative durable write boundary for post-creation - session mutations. The event insert and state update must succeed - together or fail together, and the updated session record is returned - in the same round-trip so callers don't need a follow-up read. + session mutations. The event insert, session state update, and the + optional scoped-state upserts must succeed together or fail together, + and the updated session record is returned in the same round-trip so + callers don't need a follow-up read. + + When ``app_state`` is provided (non-None), it is a full merged + app-scoped snapshot to replace/upsert for ``app_name``. When + ``user_state`` is provided, it is a full merged user-scoped snapshot to + replace/upsert for ``(app_name, user_id)``. ``None`` means that scope + was untouched by the event and must not be written. Args: event_record: Event record to store. + app_name: Application name for routing scoped-state upserts. + user_id: User identifier for routing user-scoped upserts. session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). + state: Post-append durable session-scoped state snapshot + (``temp:`` keys already stripped by the service layer). + app_state: Full app-scoped state snapshot (``app:*`` keys) to + upsert atomically, or ``None`` when untouched. + user_state: Full user-scoped state snapshot (``user:*`` keys) to + upsert atomically, or ``None`` when untouched. Returns: The updated SessionRecord reflecting the new state and update_time. @@ -197,11 +236,18 @@ async def append_event_and_update_state( @abstractmethod async def get_events( - self, session_id: str, after_timestamp: "datetime | None" = None, limit: "int | None" = None + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, ) -> "list[EventRecord]": """Get events for a session. Args: + app_name: Name of the application. + user_id: ID of the user. session_id: Session identifier. after_timestamp: Only return events after this time. limit: Maximum number of events to return. @@ -212,16 +258,227 @@ async def get_events( raise NotImplementedError @abstractmethod - async def create_tables(self) -> None: - """Create the sessions and events tables if they don't exist.""" + async def delete_expired_events(self, before: datetime) -> int: + """Delete events older than the given timestamp. + + Args: + before: Timestamp threshold; events with timestamp earlier than this value are deleted. + + Returns: + Number of event rows deleted. + """ + raise NotImplementedError + + @abstractmethod + async def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time predates the given threshold. + + Args: + updated_before: Timestamp threshold; sessions updated earlier than this value are deleted. + + Returns: + Number of session rows deleted. + """ + raise NotImplementedError + + @abstractmethod + async def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application. + + Args: + app_name: Application name. + + Returns: + App-scoped state mapping if present, otherwise ``None``. + """ + raise NotImplementedError + + @abstractmethod + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user. + + Args: + app_name: Application name. + user_id: User identifier. + + Returns: + User-scoped state mapping if present, otherwise ``None``. + """ + raise NotImplementedError + + @abstractmethod + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application. + + Args: + app_name: Application name. + state: App-scoped state mapping. + """ + raise NotImplementedError + + @abstractmethod + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user. + + Args: + app_name: Application name. + user_id: User identifier. + state: User-scoped state mapping. + """ + raise NotImplementedError + + @abstractmethod + async def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table. + + Args: + key: Metadata key. + + Returns: + Metadata value if present, otherwise ``None``. + """ + raise NotImplementedError + + @abstractmethod + async def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table. + + Args: + key: Metadata key. + value: Metadata value. + """ raise NotImplementedError + @property + def config(self) -> ConfigT: + """Return the database configuration.""" + return self._config + + @property + def session_table(self) -> str: + """Return the sessions table name.""" + return self._session_table + + @property + def events_table(self) -> str: + """Return the events table name.""" + return self._events_table + + @property + def app_state_table(self) -> str: + """Return the app-scoped state table name.""" + return self._app_state_table + + @property + def user_state_table(self) -> str: + """Return the user-scoped state table name.""" + return self._user_state_table + + @property + def metadata_table(self) -> str: + """Return the ADK metadata table name.""" + return self._metadata_table + + @property + def owner_id_column_ddl(self) -> "str | None": + """Return the full owner ID column DDL (or None if not configured).""" + return self._owner_id_column_ddl + + @property + def owner_id_column_name(self) -> "str | None": + """Return the owner ID column name only (or None if not configured).""" + return self._owner_id_column_name + async def ensure_tables(self) -> None: """Create tables and emit a standardized log entry.""" await self.create_tables() self._log_tables_created() + async def drop_tables(self) -> None: + """Drop all ADK tables managed by this store in FK-safe order.""" + await self._execute_lifecycle_scripts(self._get_drop_tables_sql()) + self._log_tables_dropped() + + async def recreate_tables(self) -> None: + """Drop and recreate all ADK tables managed by this store.""" + await self.drop_tables() + await self.ensure_tables() + self._log_tables_recreated() + + def _get_reset_drop_tables_sql(self) -> "list[str]": + """Return all table drops needed before recreating the clean-break schema.""" + statements = list(self._get_drop_tables_sql()) + for table_profile in ADK_RESET_TABLE_PROFILES: + statements.extend(self._get_drop_tables_sql_for_table_profile(table_profile)) + return _deduplicate_statements(statements) + + def _get_store_config_from_extension(self) -> "dict[str, Any]": + """Extract ADK store configuration from config.extension_config. + + Returns: + Dict with ADK table names and optionally owner_id_column. + """ + return dict(_get_adk_session_store_config(self._config)) + + def _calculate_expires_at(self, expires_in: "int | timedelta | None") -> "datetime | None": + """Calculate expiration timestamp from expires_in. + + Args: + expires_in: Seconds or timedelta until expiration. + + Returns: + UTC datetime of expiration, or None if no expiration. + """ + if expires_in is None: + return None + + expires_in_seconds = int(expires_in.total_seconds()) if isinstance(expires_in, timedelta) else expires_in + + if expires_in_seconds <= 0: + return None + + return datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds) + + def _value_to_bytes(self, value: "str | bytes") -> bytes: + """Convert value to bytes if needed. + + Args: + value: String or bytes value. + + Returns: + Value as bytes. + """ + if isinstance(value, str): + return value.encode("utf-8") + return value + + async def _execute_lifecycle_scripts(self, statements: list[str]) -> None: + """Execute lifecycle DDL scripts for async and sync-backed configs.""" + session_context = self._config.provide_session() + if hasattr(session_context, "__aenter__"): + async with cast("Any", session_context) as driver: + for statement in statements: + result = driver.execute_script(statement) + if inspect.isawaitable(result): + await result + commit = getattr(driver, "commit", None) + if callable(commit): + result = commit() + if inspect.isawaitable(result): + await result + return + + def _execute_sync() -> None: + with cast("Any", self._config.provide_session()) as driver: + for statement in statements: + driver.execute_script(statement) + commit = getattr(driver, "commit", None) + if callable(commit): + commit() + + await async_(_execute_sync)() + @abstractmethod async def _get_create_sessions_table_sql(self) -> str: """Get the CREATE TABLE SQL for the sessions table. @@ -240,6 +497,69 @@ async def _get_create_events_table_sql(self) -> str: """ raise NotImplementedError + @abstractmethod + async def _get_create_app_states_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the app-scoped state table. + + Returns: + SQL statement to create the app-scoped state table. + """ + raise NotImplementedError + + @abstractmethod + async def _get_create_user_states_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the user-scoped state table. + + Returns: + SQL statement to create the user-scoped state table. + """ + raise NotImplementedError + + @abstractmethod + async def _get_create_metadata_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the ADK internal metadata table. + + Returns: + SQL statement to create the ADK internal metadata table. + """ + raise NotImplementedError + + @abstractmethod + async def _get_seed_metadata_sql(self) -> str: + """Get the SQL statement that seeds the ADK schema-version metadata row. + + Returns: + SQL statement that records ``schema_version = 1``. + """ + raise NotImplementedError + + @abstractmethod + def _get_drop_app_states_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the app-scoped state table. + + Returns: + SQL statement to drop the app-scoped state table. + """ + raise NotImplementedError + + @abstractmethod + def _get_drop_user_states_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the user-scoped state table. + + Returns: + SQL statement to drop the user-scoped state table. + """ + raise NotImplementedError + + @abstractmethod + def _get_drop_metadata_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the ADK internal metadata table. + + Returns: + SQL statement to drop the ADK internal metadata table. + """ + raise NotImplementedError + @abstractmethod def _get_drop_tables_sql(self) -> "list[str]": """Get the DROP TABLE SQL statements for this database dialect. @@ -247,9 +567,34 @@ def _get_drop_tables_sql(self) -> "list[str]": Returns: List of SQL statements to drop the tables and all indexes. Order matters: drop events table before sessions table due to FK. + + Notes: + Should use IF EXISTS or dialect-specific error handling + to allow idempotent migrations. """ raise NotImplementedError + def _get_drop_tables_sql_for_table_profile(self, table_profile: "tuple[str, str, str, str, str]") -> "list[str]": + session_table, events_table, app_state_table, user_state_table, metadata_table = table_profile + current_session_table = self._session_table + current_events_table = self._events_table + current_app_state_table = self._app_state_table + current_user_state_table = self._user_state_table + current_table = self._metadata_table + self._session_table = session_table + self._events_table = events_table + self._app_state_table = app_state_table + self._user_state_table = user_state_table + self._metadata_table = metadata_table + try: + return list(self._get_drop_tables_sql()) + finally: + self._session_table = current_session_table + self._events_table = current_events_table + self._app_state_table = current_app_state_table + self._user_state_table = current_user_state_table + self._metadata_table = current_table + def _log_tables_created(self) -> None: log_with_context( logger, @@ -260,27 +605,48 @@ def _log_tables_created(self) -> None: events_table=self._events_table, ) + def _log_tables_dropped(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.tables.dropped", + db_system=resolve_db_system(type(self).__name__), + session_table=self._session_table, + events_table=self._events_table, + ) -class BaseSyncADKStore(ABC, Generic[ConfigT]): - """Base class for sync SQLSpec-backed ADK session stores. + def _log_tables_recreated(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.tables.recreated", + db_system=resolve_db_system(type(self).__name__), + session_table=self._session_table, + events_table=self._events_table, + ) - Implements storage operations for Google ADK sessions and events using - SQLSpec database adapters with synchronous execution. - This abstract base class provides common functionality for sync database-specific - store implementations including: - - Connection management via SQLSpec configs - - Table name validation - - Session and event CRUD operations +class BaseSyncADKStore(ABC, Generic[ConfigT]): + """Base class for sync SQLSpec-backed ADK session stores. - Subclasses must implement dialect-specific SQL queries and will be created - in each adapter directory. + Sync-backed adapters expose a real synchronous API for direct use in + synchronous applications. Async bridging belongs in ``SQLSpecSessionService`` + when Google ADK calls into a sync store from its async service surface. Args: config: SQLSpec database configuration with extension_config["adk"] settings. """ - __slots__ = ("_config", "_events_table", "_owner_id_column_ddl", "_owner_id_column_name", "_session_table") + __slots__ = ( + "_app_state_table", + "_config", + "_events_table", + "_metadata_table", + "_owner_id_column_ddl", + "_owner_id_column_name", + "_session_table", + "_user_state_table", + ) def __init__(self, config: ConfigT) -> None: """Initialize the sync ADK store. @@ -292,20 +658,124 @@ def __init__(self, config: ConfigT) -> None: store_config = self._get_store_config_from_extension() self._session_table: str = str(store_config["session_table"]) self._events_table: str = str(store_config["events_table"]) + self._app_state_table: str = str(store_config["app_state_table"]) + self._user_state_table: str = str(store_config["user_state_table"]) + self._metadata_table: str = str(store_config["metadata_table"]) self._owner_id_column_ddl: str | None = store_config.get("owner_id_column") self._owner_id_column_name: str | None = ( _parse_owner_id_column(self._owner_id_column_ddl) if self._owner_id_column_ddl else None ) _validate_table_name(self._session_table) _validate_table_name(self._events_table) + _validate_table_name(self._app_state_table) + _validate_table_name(self._user_state_table) + _validate_table_name(self._metadata_table) - def _get_store_config_from_extension(self) -> "dict[str, Any]": - """Extract ADK store configuration from config.extension_config. + @abstractmethod + def create_tables(self) -> None: + """Create the sessions and events tables if they don't exist.""" + raise NotImplementedError - Returns: - Dict with session_table, events_table, and optionally owner_id_column. - """ - return dict(_get_adk_session_store_config(self._config)) + @abstractmethod + def create_session( + self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None + ) -> "SessionRecord": + """Create a new session.""" + raise NotImplementedError + + @abstractmethod + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: "int | timedelta | None" = None + ) -> "SessionRecord | None": + """Get a session.""" + raise NotImplementedError + + @abstractmethod + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: "dict[str, Any]") -> None: + """Update session state.""" + raise NotImplementedError + + @abstractmethod + def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": + """List all sessions for an app, optionally filtered by user.""" + raise NotImplementedError + + @abstractmethod + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + """Delete a session and its events.""" + raise NotImplementedError + + @abstractmethod + def append_event(self, event_record: "EventRecord") -> None: + """Append an event to a session.""" + raise NotImplementedError + + @abstractmethod + def append_event_and_update_state( + self, + event_record: "EventRecord", + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + *, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, + ) -> "SessionRecord": + """Atomically append an event and update the session's durable state.""" + raise NotImplementedError + + @abstractmethod + def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: "datetime | None" = None, + limit: "int | None" = None, + ) -> "list[EventRecord]": + """Get events for a session.""" + raise NotImplementedError + + @abstractmethod + def delete_expired_events(self, before: datetime) -> int: + """Delete events older than the given timestamp.""" + raise NotImplementedError + + @abstractmethod + def delete_idle_sessions(self, updated_before: datetime) -> int: + """Delete sessions whose update_time predates the given threshold.""" + raise NotImplementedError + + @abstractmethod + def get_app_state(self, app_name: str) -> "dict[str, Any] | None": + """Return app-scoped state for an application.""" + raise NotImplementedError + + @abstractmethod + def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any] | None": + """Return user-scoped state for an application user.""" + raise NotImplementedError + + @abstractmethod + def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + """Insert or replace app-scoped state for an application.""" + raise NotImplementedError + + @abstractmethod + def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + """Insert or replace user-scoped state for an application user.""" + raise NotImplementedError + + @abstractmethod + def get_metadata(self, key: str) -> "str | None": + """Return a value from the ADK internal metadata table.""" + raise NotImplementedError + + @abstractmethod + def set_metadata(self, key: str, value: str) -> None: + """Set a value in the ADK internal metadata table.""" + raise NotImplementedError @property def config(self) -> ConfigT: @@ -322,6 +792,21 @@ def events_table(self) -> str: """Return the events table name.""" return self._events_table + @property + def app_state_table(self) -> str: + """Return the app-scoped state table name.""" + return self._app_state_table + + @property + def user_state_table(self) -> str: + """Return the user-scoped state table name.""" + return self._user_state_table + + @property + def metadata_table(self) -> str: + """Return the ADK metadata table name.""" + return self._metadata_table + @property def owner_id_column_ddl(self) -> "str | None": """Return the full owner ID column DDL (or None if not configured).""" @@ -332,156 +817,132 @@ def owner_id_column_name(self) -> "str | None": """Return the owner ID column name only (or None if not configured).""" return self._owner_id_column_name - @abstractmethod - def create_session( - self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None - ) -> "SessionRecord": - """Create a new session. + def ensure_tables(self) -> None: + """Create tables and emit a standardized log entry.""" - Args: - session_id: Unique identifier for the session. - app_name: Name of the application. - user_id: ID of the user. - state: Session state dictionary. - owner_id: Optional owner ID value for owner_id_column (if configured). + self.create_tables() + self._log_tables_created() - Returns: - The created session record. - """ - raise NotImplementedError + def drop_tables(self) -> None: + """Drop all ADK tables managed by this store in FK-safe order.""" + self._execute_lifecycle_scripts(self._get_drop_tables_sql()) + self._log_tables_dropped() - @abstractmethod - def get_session(self, session_id: str) -> "SessionRecord | None": - """Get a session by ID. + def recreate_tables(self) -> None: + """Drop and recreate all ADK tables managed by this store.""" + self.drop_tables() + self.ensure_tables() + self._log_tables_recreated() - Args: - session_id: Session identifier. + def _get_store_config_from_extension(self) -> "dict[str, Any]": + """Extract ADK store configuration from config.extension_config.""" + return dict(_get_adk_session_store_config(self._config)) - Returns: - Session record if found, None otherwise. - """ - raise NotImplementedError + def _calculate_expires_at(self, expires_in: "int | timedelta | None") -> "datetime | None": + """Calculate expiration timestamp from expires_in.""" + if expires_in is None: + return None - @abstractmethod - def update_session_state(self, session_id: str, state: "dict[str, Any]") -> None: - """Update session state. + expires_in_seconds = int(expires_in.total_seconds()) if isinstance(expires_in, timedelta) else expires_in - Args: - session_id: Session identifier. - state: New state dictionary. - """ - raise NotImplementedError + if expires_in_seconds <= 0: + return None - @abstractmethod - def list_sessions(self, app_name: str, user_id: "str | None" = None) -> "list[SessionRecord]": - """List all sessions for an app, optionally filtered by user. + return datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds) - Args: - app_name: Name of the application. - user_id: ID of the user. If None, returns all sessions for the app. + def _value_to_bytes(self, value: "str | bytes") -> bytes: + """Convert value to bytes if needed.""" + if isinstance(value, str): + return value.encode("utf-8") + return value - Returns: - List of session records. - """ - raise NotImplementedError + def _execute_lifecycle_scripts(self, statements: list[str]) -> None: + """Execute lifecycle DDL scripts using the sync driver session.""" + with cast("Any", self._config.provide_session()) as driver: + for statement in statements: + driver.execute_script(statement) + commit = getattr(driver, "commit", None) + if callable(commit): + commit() @abstractmethod - def delete_session(self, session_id: str) -> None: - """Delete a session and its events. - - Args: - session_id: Session identifier. - """ + def _get_create_sessions_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the sessions table.""" raise NotImplementedError @abstractmethod - def create_event( - self, - event_id: str, - session_id: str, - app_name: str, - user_id: str, - author: "str | None" = None, - actions: "bytes | None" = None, - content: "dict[str, Any] | None" = None, - **kwargs: Any, - ) -> "EventRecord": - """Create a new event. - - Args: - event_id: Unique event identifier. - session_id: Session identifier. - app_name: Application name. - user_id: User identifier. - author: Event author (user/assistant/system). - actions: Pickled actions object. - content: Event content (JSONB/JSON). - **kwargs: Additional optional fields. - - Returns: - Created event record. - """ + def _get_create_events_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the events table.""" raise NotImplementedError @abstractmethod - def create_event_and_update_state( - self, event_record: "EventRecord", session_id: str, state: "dict[str, Any]" - ) -> None: - """Atomically create an event and update the session's durable state. - - This is the authoritative durable write boundary for post-creation - session mutations. The event insert and state update must succeed - together or fail together. - - Args: - event_record: Event record to store. - session_id: Session identifier whose state should be updated. - state: Post-append durable state snapshot (``temp:`` keys already - stripped by the service layer). - """ + def _get_create_app_states_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the app-scoped state table.""" raise NotImplementedError @abstractmethod - def list_events(self, session_id: str) -> "list[EventRecord]": - """List events for a session ordered by timestamp. - - Args: - session_id: Session identifier. - - Returns: - List of event records ordered by timestamp ASC. - """ + def _get_create_user_states_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the user-scoped state table.""" raise NotImplementedError @abstractmethod - def create_tables(self) -> None: - """Create both sessions and events tables if they don't exist.""" + def _get_create_metadata_table_sql(self) -> str: + """Get the CREATE TABLE SQL for the ADK internal metadata table.""" raise NotImplementedError - def ensure_tables(self) -> None: - """Create tables and emit a standardized log entry.""" - - self.create_tables() - self._log_tables_created() + @abstractmethod + def _get_seed_metadata_sql(self) -> str: + """Get the SQL statement that seeds the ADK schema-version metadata row.""" + raise NotImplementedError @abstractmethod - def _get_create_sessions_table_sql(self) -> str: - """Get SQL to create sessions table. + def _get_drop_app_states_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the app-scoped state table.""" + raise NotImplementedError - Returns: - SQL statement to create adk_sessions table with indexes. - """ + @abstractmethod + def _get_drop_user_states_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the user-scoped state table.""" raise NotImplementedError @abstractmethod - def _get_create_events_table_sql(self) -> str: - """Get SQL to create events table. + def _get_drop_metadata_table_sql(self) -> str: + """Get the DROP TABLE SQL statement for the ADK internal metadata table.""" + raise NotImplementedError - Returns: - SQL statement to create adk_events table with indexes. - """ + @abstractmethod + def _get_drop_tables_sql(self) -> "list[str]": + """Get the DROP TABLE SQL statements for this database dialect.""" raise NotImplementedError + def _get_reset_drop_tables_sql(self) -> "list[str]": + """Return all table drops needed before recreating the clean-break schema.""" + statements = list(self._get_drop_tables_sql()) + for table_profile in ADK_RESET_TABLE_PROFILES: + statements.extend(self._get_drop_tables_sql_for_table_profile(table_profile)) + return _deduplicate_statements(statements) + + def _get_drop_tables_sql_for_table_profile(self, table_profile: "tuple[str, str, str, str, str]") -> "list[str]": + session_table, events_table, app_state_table, user_state_table, metadata_table = table_profile + current_session_table = self._session_table + current_events_table = self._events_table + current_app_state_table = self._app_state_table + current_user_state_table = self._user_state_table + current_table = self._metadata_table + self._session_table = session_table + self._events_table = events_table + self._app_state_table = app_state_table + self._user_state_table = user_state_table + self._metadata_table = metadata_table + try: + return list(self._get_drop_tables_sql()) + finally: + self._session_table = current_session_table + self._events_table = current_events_table + self._app_state_table = current_app_state_table + self._user_state_table = current_user_state_table + self._metadata_table = current_table + def _log_tables_created(self) -> None: log_with_context( logger, @@ -492,28 +953,47 @@ def _log_tables_created(self) -> None: events_table=self._events_table, ) - @abstractmethod - def _get_drop_tables_sql(self) -> "list[str]": - """Get SQL to drop tables. + def _log_tables_dropped(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.tables.dropped", + db_system=resolve_db_system(type(self).__name__), + session_table=self._session_table, + events_table=self._events_table, + ) - Returns: - List of SQL statements to drop tables and indexes. - Order matters: drop events before sessions due to FK. - """ - raise NotImplementedError + def _log_tables_recreated(self) -> None: + log_with_context( + logger, + logging.DEBUG, + "adk.tables.recreated", + db_system=resolve_db_system(type(self).__name__), + session_table=self._session_table, + events_table=self._events_table, + ) def _parse_owner_id_column(owner_id_column_ddl: str) -> str: """Extract column name from owner ID column DDL definition. Args: - owner_id_column_ddl: Full column DDL string. + owner_id_column_ddl: Full column DDL string (e.g., "user_id INTEGER REFERENCES users(id)"). Returns: Column name only (first word). Raises: ValueError: If DDL format is invalid. + + Examples: + "account_id INTEGER NOT NULL" -> "account_id" + "user_id UUID REFERENCES users(id)" -> "user_id" + "tenant VARCHAR(64) DEFAULT 'public'" -> "tenant" + + Notes: + Only the column name is parsed. The rest of the DDL is passed through + verbatim to CREATE TABLE statements. """ match = COLUMN_NAME_PATTERN.match(owner_id_column_ddl.strip()) if not match: @@ -523,6 +1003,17 @@ def _parse_owner_id_column(owner_id_column_ddl: str) -> str: return match.group(1) +def _deduplicate_statements(statements: "list[str]") -> "list[str]": + seen: set[str] = set() + result: list[str] = [] + for statement in statements: + if statement in seen: + continue + result.append(statement) + seen.add(statement) + return result + + def _validate_table_name(table_name: str) -> None: """Validate table name for SQL safety. diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py deleted file mode 100644 index e20536d83..000000000 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_integration.py +++ /dev/null @@ -1,237 +0,0 @@ -"""Integration tests for ADBC ADK store with actual database dialects. - -These tests require the actual ADBC drivers to be installed: -- adbc-driver-sqlite (default, always available) -- adbc-driver-postgresql (optional) -- adbc-driver-duckdb (optional) -- adbc-driver-snowflake (optional) - -Tests are marked with dialect-specific markers and will be skipped -if the driver is not installed. -""" - -import json -from pathlib import Path -from typing import Any - -import pytest - -from sqlspec.adapters.adbc import AdbcConfig -from sqlspec.adapters.adbc.adk import AdbcADKStore - -pytestmark = pytest.mark.adbc - - -@pytest.fixture() -async def sqlite_store(tmp_path: Path) -> Any: - """SQLite ADBC store fixture.""" - db_path = tmp_path / "sqlite_test.db" - config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) - store = AdbcADKStore(config) - await store.create_tables() - return store - - -async def test_sqlite_dialect_creates_text_columns(sqlite_store: Any) -> None: - """Test SQLite dialect creates TEXT columns for JSON.""" - with sqlite_store.config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute(f"PRAGMA table_info({sqlite_store.session_table})") - columns = cursor.fetchall() - - state_column = next(col for col in columns if col[1] == "state") - assert state_column[2] == "TEXT" - finally: - cursor.close() # type: ignore[no-untyped-call] - - -async def test_sqlite_dialect_session_operations(sqlite_store: Any) -> None: - """Test SQLite dialect with full session CRUD.""" - session_id = "sqlite-session-1" - app_name = "test-app" - user_id = "user-123" - state = {"nested": {"key": "value"}, "count": 42} - - created = await sqlite_store.create_session(session_id, app_name, user_id, state) - assert created["id"] == session_id - assert created["state"] == state - - retrieved = await sqlite_store.get_session(session_id) - assert retrieved["state"] == state - - new_state = {"updated": True} - await sqlite_store.update_session_state(session_id, new_state) - - updated = await sqlite_store.get_session(session_id) - assert updated["state"] == new_state - - -async def test_sqlite_dialect_event_operations(sqlite_store: Any) -> None: - """Test SQLite dialect with event operations.""" - session_id = "sqlite-session-events" - app_name = "test-app" - user_id = "user-123" - - await sqlite_store.create_session(session_id, app_name, user_id, {}) - - content = {"message": "Hello"} - - from datetime import datetime, timezone - - from sqlspec.extensions.adk import EventRecord - - event_record: EventRecord = { - "session_id": session_id, - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "content": content, "app_name": app_name, "user_id": user_id}, - } - await sqlite_store.append_event(event_record) - - events = await sqlite_store.get_events(session_id) - assert len(events) == 1 - retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) - assert retrieved_data["content"] == content - - -@pytest.mark.postgres -@pytest.mark.skipif(True, reason="Requires adbc-driver-postgresql and PostgreSQL server") -async def test_postgresql_dialect_creates_jsonb_columns() -> None: - """Test PostgreSQL dialect creates JSONB columns. - - This test is skipped by default. To run: - 1. Install adbc-driver-postgresql - 2. Start PostgreSQL server - 3. Update connection config - 4. Remove skipif marker - """ - config = AdbcConfig( - connection_config={"driver_name": "postgresql", "uri": "postgresql://user:pass@localhost/testdb"} - ) - store = AdbcADKStore(config) - await store.create_tables() - - with store.config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute( - f""" - SELECT data_type - FROM information_schema.columns - WHERE table_name = '{store.session_table}' - AND column_name = 'state' - """ - ) - result = cursor.fetchone() - assert result is not None - assert result[0] == "jsonb" - finally: - cursor.close() # type: ignore[no-untyped-call] # type: ignore[no-untyped-call] - - -@pytest.mark.duckdb -@pytest.mark.skipif(True, reason="Requires adbc-driver-duckdb") -async def test_duckdb_dialect_creates_json_columns(tmp_path: Path) -> None: - """Test DuckDB dialect creates JSON columns. - - This test is skipped by default. To run: - 1. Install adbc-driver-duckdb - 2. Remove skipif marker - """ - db_path = tmp_path / "duckdb_test.db" - config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": f"file:{db_path}"}) - store = AdbcADKStore(config) - await store.create_tables() - - session_id = "duckdb-session-1" - state = {"analytics": {"count": 1000, "revenue": 50000.00}} - - created = await store.create_session(session_id, "app", "user", state) - assert created["state"] == state - - -@pytest.mark.snowflake -@pytest.mark.skipif(True, reason="Requires adbc-driver-snowflake and Snowflake account") -async def test_snowflake_dialect_creates_variant_columns() -> None: - """Test Snowflake dialect creates VARIANT columns. - - This test is skipped by default. To run: - 1. Install adbc-driver-snowflake - 2. Configure Snowflake credentials - 3. Remove skipif marker - """ - config = AdbcConfig( - connection_config={ - "driver_name": "snowflake", - "uri": "snowflake://account.region/database?warehouse=wh", - "username": "user", - "password": "pass", - } - ) - store = AdbcADKStore(config) - await store.create_tables() - - with store.config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute( - f""" - SELECT data_type - FROM information_schema.columns - WHERE table_name = UPPER('{store.session_table}') - AND column_name = 'STATE' - """ - ) - result = cursor.fetchone() - assert result is not None - assert result[0] == "VARIANT" - finally: - cursor.close() # type: ignore[no-untyped-call] - - -async def test_sqlite_with_owner_id_column(tmp_path: Path) -> None: - """Test SQLite with owner ID column creates proper constraints.""" - db_path = tmp_path / "sqlite_fk_test.db" - base_config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) - - with base_config.provide_connection() as conn: - cursor = conn.cursor() - try: - cursor.execute("PRAGMA foreign_keys = ON") - cursor.execute("CREATE TABLE tenants (id INTEGER PRIMARY KEY, name TEXT)") - cursor.execute("INSERT INTO tenants (id, name) VALUES (1, 'Tenant A')") - conn.commit() - finally: - cursor.close() # type: ignore[no-untyped-call] - - config = AdbcConfig( - connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}, - extension_config={"adk": {"owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id)"}}, - ) - store = AdbcADKStore(config) - await store.create_tables() - - session = await store.create_session("s1", "app", "user", {"data": "test"}, owner_id=1) - assert session["id"] == "s1" - - retrieved = await store.get_session("s1") - assert retrieved is not None - - -async def test_generic_dialect_fallback(tmp_path: Path) -> None: - """Test generic dialect is used for unknown drivers.""" - db_path = tmp_path / "generic_test.db" - - config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) - - store = AdbcADKStore(config) - assert store.dialect in ["sqlite", "generic"] - - await store.create_tables() - - session = await store.create_session("generic-1", "app", "user", {"test": True}) - assert session["state"]["test"] is True diff --git a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py index 703d40437..c10698330 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_dialect_support.py @@ -90,58 +90,57 @@ def test_generic_sessions_ddl_contains_text() -> None: def test_postgresql_events_ddl_uses_jsonb() -> None: - """Test PostgreSQL events DDL uses JSONB for event_json.""" + """Test PostgreSQL events DDL uses JSONB for event_data.""" config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_postgresql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in ddl - assert "event_json" in ddl + assert "event_data" in ddl assert "session_id" in ddl assert "invocation_id" in ddl - assert "author" in ddl assert "timestamp" in ddl.lower() def test_sqlite_events_ddl_uses_text() -> None: - """Test SQLite events DDL uses TEXT for event_json.""" + """Test SQLite events DDL uses TEXT for event_data.""" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_sqlite() # pyright: ignore[reportPrivateUsage] assert "TEXT" in ddl - assert "event_json" in ddl + assert "event_data" in ddl assert "session_id" in ddl assert "REAL" in ddl # SQLite uses REAL for timestamps def test_duckdb_events_ddl_uses_json() -> None: - """Test DuckDB events DDL uses JSON type for event_json.""" + """Test DuckDB events DDL uses JSON type for event_data.""" config = AdbcConfig(connection_config={"driver_name": "duckdb", "uri": ":memory:"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_duckdb() # pyright: ignore[reportPrivateUsage] assert "JSON" in ddl - assert "event_json" in ddl + assert "event_data" in ddl def test_snowflake_events_ddl_uses_variant() -> None: - """Test Snowflake events DDL uses VARIANT for event_json.""" + """Test Snowflake events DDL uses VARIANT for event_data.""" config = AdbcConfig(connection_config={"driver_name": "snowflake", "uri": "snowflake://test"}) store = AdbcADKStore(config) ddl = store._get_events_ddl_snowflake() # pyright: ignore[reportPrivateUsage] assert "VARIANT" in ddl - assert "event_json" in ddl + assert "event_data" in ddl -async def test_ddl_dispatch_uses_correct_dialect() -> None: +def test_ddl_dispatch_uses_correct_dialect() -> None: """Test that DDL dispatch selects correct dialect method.""" config = AdbcConfig(connection_config={"driver_name": "postgresql", "uri": ":memory:"}) store = AdbcADKStore(config) - sessions_ddl = await store._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] + sessions_ddl = store._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in sessions_ddl - events_ddl = await store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage] + events_ddl = store._get_create_events_table_sql() # pyright: ignore[reportPrivateUsage] assert "JSONB" in events_ddl - assert "event_json" in events_ddl + assert "event_data" in events_ddl def test_owner_id_column_included_in_sessions_ddl() -> None: diff --git a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py b/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py deleted file mode 100644 index 0e11dd0bb..000000000 --- a/tests/integration/adapters/adbc/extensions/adk/test_edge_cases.py +++ /dev/null @@ -1,273 +0,0 @@ -"""Tests for ADBC ADK store edge cases and error handling.""" - -import json -from pathlib import Path -from typing import Any - -import pytest - -from sqlspec.adapters.adbc import AdbcConfig -from sqlspec.adapters.adbc.adk import AdbcADKStore - -pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] - - -@pytest.fixture() -async def adbc_store(tmp_path: Path) -> AdbcADKStore: - """Create ADBC ADK store with SQLite backend.""" - db_path = tmp_path / "test_adk.db" - config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) - store = AdbcADKStore(config) - await store.create_tables() - return store - - -async def test_create_tables_idempotent(adbc_store: Any) -> None: - """Test that create_tables can be called multiple times safely.""" - await adbc_store.create_tables() - await adbc_store.create_tables() - - -def test_table_names_validation(tmp_path: Path) -> None: - """Test that invalid table names are rejected.""" - db_path = tmp_path / "test_validation.db" - - with pytest.raises(ValueError, match="Table name cannot be empty"): - config = AdbcConfig( - connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}, - extension_config={"adk": {"session_table": "", "events_table": "events"}}, - ) - AdbcADKStore(config) - - with pytest.raises(ValueError, match="Invalid table name"): - config = AdbcConfig( - connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}, - extension_config={"adk": {"session_table": "invalid-name", "events_table": "events"}}, - ) - AdbcADKStore(config) - - with pytest.raises(ValueError, match="Invalid table name"): - config = AdbcConfig( - connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}, - extension_config={"adk": {"session_table": "1_starts_with_number", "events_table": "events"}}, - ) - AdbcADKStore(config) - - with pytest.raises(ValueError, match="Table name too long"): - long_name = "a" * 100 - config = AdbcConfig( - connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}, - extension_config={"adk": {"session_table": long_name, "events_table": "events"}}, - ) - AdbcADKStore(config) - - -async def test_operations_before_create_tables(tmp_path: Path) -> None: - """Test operations gracefully handle missing tables.""" - db_path = tmp_path / "test_no_tables.db" - config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) - store = AdbcADKStore(config) - - session = await store.get_session("nonexistent") - assert session is None - - sessions = await store.list_sessions("app", "user") - assert sessions == [] - - events = await store.get_events("session") - assert events == [] - - -async def test_custom_table_names(tmp_path: Path) -> None: - """Test using custom table names.""" - db_path = tmp_path / "test_custom.db" - config = AdbcConfig( - connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}, - extension_config={"adk": {"session_table": "custom_sessions", "events_table": "custom_events"}}, - ) - store = AdbcADKStore(config) - await store.create_tables() - - session_id = "test" - session = await store.create_session(session_id, "app", "user", {"data": "test"}) - assert session["id"] == session_id - - retrieved = await store.get_session(session_id) - assert retrieved is not None - - -async def test_unicode_in_fields(adbc_store: Any) -> None: - """Test Unicode characters in various fields.""" - session_id = "unicode-session" - app_name = "\u6d4b\u8bd5\u5e94\u7528" - user_id = "\u30e6\u30fc\u30b6\u30fc123" - state = {"message": "Hello \u4e16\u754c"} - - created_session = await adbc_store.create_session(session_id, app_name, user_id, state) - assert created_session["app_name"] == app_name - assert created_session["user_id"] == user_id - assert created_session["state"]["message"] == "Hello \u4e16\u754c" - - from datetime import datetime, timezone - - from sqlspec.extensions.adk import EventRecord - - event_record: EventRecord = { - "session_id": session_id, - "invocation_id": "", - "author": "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "unicode-event", - "content": {"text": "\u3053\u3093\u306b\u3061\u306f"}, - "app_name": app_name, - "user_id": user_id, - }, - } - await adbc_store.append_event(event_record) - - events = await adbc_store.get_events(session_id) - assert len(events) == 1 - assert events[0]["author"] == "\u30a2\u30b7\u30b9\u30bf\u30f3\u30c8" - event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) - assert event_data["content"]["text"] == "\u3053\u3093\u306b\u3061\u306f" - - -async def test_special_characters_in_json(adbc_store: Any) -> None: - """Test special characters in JSON fields.""" - session_id = "special-chars" - state = { - "quotes": 'He said "Hello"', - "backslash": "C:\\Users\\test", - "newline": "Line1\nLine2", - "tab": "Col1\tCol2", - } - - await adbc_store.create_session(session_id, "app", "user", state) - retrieved = await adbc_store.get_session(session_id) - - assert retrieved is not None - assert retrieved["state"] == state - - -async def test_very_long_strings(adbc_store: Any) -> None: - """Test handling very long strings in VARCHAR fields.""" - long_id = "x" * 127 - long_app = "a" * 127 - long_user = "u" * 127 - - session = await adbc_store.create_session(long_id, long_app, long_user, {}) - assert session["id"] == long_id - assert session["app_name"] == long_app - assert session["user_id"] == long_user - - -async def test_session_state_with_deeply_nested_data(adbc_store: Any) -> None: - """Test deeply nested JSON structures.""" - session_id = "deep-nest" - deeply_nested = {"level1": {"level2": {"level3": {"level4": {"level5": {"value": "deep"}}}}}} - - await adbc_store.create_session(session_id, "app", "user", deeply_nested) - retrieved = await adbc_store.get_session(session_id) - - assert retrieved is not None - assert retrieved["state"]["level1"]["level2"]["level3"]["level4"]["level5"]["value"] == "deep" - - -async def test_concurrent_session_updates(adbc_store: Any) -> None: - """Test multiple updates to the same session.""" - session_id = "concurrent-test" - await adbc_store.create_session(session_id, "app", "user", {"version": 1}) - - for i in range(10): - await adbc_store.update_session_state(session_id, {"version": i + 2}) - - final_session = await adbc_store.get_session(session_id) - assert final_session is not None - assert final_session["state"]["version"] == 11 - - -async def test_event_with_none_values(adbc_store: Any) -> None: - """Test creating event with explicit None values for optional fields.""" - session_id = "none-test" - await adbc_store.create_session(session_id, "app", "user", {}) - - from datetime import datetime, timezone - - from sqlspec.extensions.adk import EventRecord - - event_record: EventRecord = { - "session_id": session_id, - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "none-event", "app_name": "app", "user_id": "user"}, - } - await adbc_store.append_event(event_record) - - events = await adbc_store.get_events(session_id) - assert len(events) == 1 - assert events[0]["session_id"] == session_id - assert "event_json" in events[0] - - -async def test_list_sessions_with_same_user_different_apps(adbc_store: Any) -> None: - """Test listing sessions doesn't mix data across apps.""" - user_id = "user-123" - app1 = "app1" - app2 = "app2" - - await adbc_store.create_session("s1", app1, user_id, {}) - await adbc_store.create_session("s2", app1, user_id, {}) - await adbc_store.create_session("s3", app2, user_id, {}) - - app1_sessions = await adbc_store.list_sessions(app1, user_id) - app2_sessions = await adbc_store.list_sessions(app2, user_id) - - assert len(app1_sessions) == 2 - assert len(app2_sessions) == 1 - - -async def test_delete_nonexistent_session(adbc_store: Any) -> None: - """Test deleting a session that doesn't exist.""" - await adbc_store.delete_session("nonexistent-session") - - -async def test_update_nonexistent_session(adbc_store: Any) -> None: - """Test updating a session that doesn't exist.""" - await adbc_store.update_session_state("nonexistent-session", {"data": "test"}) - - -async def test_drop_and_recreate_tables(adbc_store: Any) -> None: - """Test dropping and recreating tables.""" - session_id = "test-session" - await adbc_store.create_session(session_id, "app", "user", {"data": "test"}) - - drop_sqls = adbc_store._get_drop_tables_sql() - with adbc_store._config.provide_connection() as conn: - cursor = conn.cursor() - try: - for sql in drop_sqls: - cursor.execute(sql) - conn.commit() - finally: - cursor.close() - - await adbc_store.create_tables() - - session = await adbc_store.get_session(session_id) - assert session is None - - -async def test_json_with_escaped_characters(adbc_store: Any) -> None: - """Test JSON serialization of escaped characters.""" - session_id = "escaped-json" - state = {"escaped": r"test\nvalue\t", "quotes": r'"quoted"'} - - await adbc_store.create_session(session_id, "app", "user", state) - retrieved = await adbc_store.get_session(session_id) - - assert retrieved is not None - assert retrieved["state"] == state diff --git a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py deleted file mode 100644 index 8e54bf766..000000000 --- a/tests/integration/adapters/adbc/extensions/adk/test_event_operations.py +++ /dev/null @@ -1,396 +0,0 @@ -"""Tests for ADBC ADK store event operations.""" - -import asyncio -import json -from datetime import datetime, timedelta, timezone -from pathlib import Path -from typing import Any - -import pytest - -from sqlspec.adapters.adbc import AdbcConfig -from sqlspec.adapters.adbc.adk import AdbcADKStore -from sqlspec.extensions.adk import EventRecord - -pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] - - -@pytest.fixture() -async def adbc_store(tmp_path: Path) -> AdbcADKStore: - """Create ADBC ADK store with SQLite backend.""" - db_path = tmp_path / "test_adk.db" - config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) - store = AdbcADKStore(config) - await store.create_tables() - return store - - -@pytest.fixture() -async def session_fixture(adbc_store: Any) -> dict[str, str]: - """Create a test session.""" - session_id = "test-session" - app_name = "test-app" - user_id = "user-123" - state = {"test": True} - await adbc_store.create_session(session_id, app_name, user_id, state) - return {"session_id": session_id, "app_name": app_name, "user_id": user_id} - - -async def test_create_event(adbc_store: Any, session_fixture: Any) -> None: - """Test creating a new event returns 5-key EventRecord.""" - event_record: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "user", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "event-1", - "content": {"message": "Hello"}, - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - } - await adbc_store.append_event(event_record) - - events = await adbc_store.get_events(session_fixture["session_id"]) - assert len(events) == 1 - assert events[0]["session_id"] == session_fixture["session_id"] - assert events[0]["author"] == "user" - assert events[0]["timestamp"] is not None - assert "event_json" in events[0] - - # Content is stored inside event_json - event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) - assert event_data["content"] == {"message": "Hello"} - - -async def test_list_events(adbc_store: Any, session_fixture: Any) -> None: - """Test listing events for a session.""" - event1: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "user", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "event-1", - "content": {"seq": 1}, - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - } - event2: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "assistant", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "event-2", - "content": {"seq": 2}, - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - } - await adbc_store.append_event(event1) - await adbc_store.append_event(event2) - - events = await adbc_store.get_events(session_fixture["session_id"]) - - assert len(events) == 2 - assert events[0]["author"] == "user" - assert events[1]["author"] == "assistant" - - -async def test_list_events_empty(adbc_store: Any, session_fixture: Any) -> None: - """Test listing events when none exist.""" - events = await adbc_store.get_events(session_fixture["session_id"]) - assert events == [] - - -async def test_event_with_all_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with all optional fields stored in event_json.""" - event_record: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "invocation-123", - "author": "assistant", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "full-event", - "content": {"text": "Response"}, - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - "branch": "main", - "grounding_metadata": {"sources": ["doc1", "doc2"]}, - "custom_metadata": {"custom": "data"}, - "partial": True, - "turn_complete": False, - "interrupted": False, - "error_code": "NONE", - "error_message": "No errors", - }, - } - await adbc_store.append_event(event_record) - - events = await adbc_store.get_events(session_fixture["session_id"]) - assert len(events) == 1 - - # Top-level indexed columns - assert events[0]["invocation_id"] == "invocation-123" - assert events[0]["author"] == "assistant" - - # Everything else is in event_json - event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) - assert event_data["content"] == {"text": "Response"} - assert event_data["branch"] == "main" - assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]} - assert event_data["custom_metadata"] == {"custom": "data"} - assert event_data["partial"] is True - assert event_data["turn_complete"] is False - assert event_data["interrupted"] is False - assert event_data["error_code"] == "NONE" - assert event_data["error_message"] == "No errors" - - -async def test_event_with_minimal_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with only required fields.""" - event_record: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "minimal-event", - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - } - await adbc_store.append_event(event_record) - - events = await adbc_store.get_events(session_fixture["session_id"]) - assert len(events) == 1 - assert events[0]["session_id"] == session_fixture["session_id"] - assert "event_json" in events[0] - - -async def test_event_json_fields(adbc_store: Any, session_fixture: Any) -> None: - """Test event JSON field serialization and deserialization via event_json.""" - complex_content = {"nested": {"data": "value"}, "list": [1, 2, 3], "null": None} - complex_grounding = {"sources": [{"title": "Doc", "url": "http://example.com"}]} - complex_custom = {"metadata": {"version": 1, "tags": ["tag1", "tag2"]}} - - event_record: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "json-event", - "content": complex_content, - "grounding_metadata": complex_grounding, - "custom_metadata": complex_custom, - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - } - await adbc_store.append_event(event_record) - - events = await adbc_store.get_events(session_fixture["session_id"]) - assert len(events) == 1 - event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) - assert event_data["content"] == complex_content - assert event_data["grounding_metadata"] == complex_grounding - assert event_data["custom_metadata"] == complex_custom - - -async def test_event_ordering(adbc_store: Any, session_fixture: Any) -> None: - """Test that events are ordered by timestamp ASC.""" - ev1: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, - } - await adbc_store.append_event(ev1) - - await asyncio.sleep(0.01) - - ev2: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, - } - await adbc_store.append_event(ev2) - - await asyncio.sleep(0.01) - - ev3: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-3", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, - } - await adbc_store.append_event(ev3) - - events = await adbc_store.get_events(session_fixture["session_id"]) - - assert len(events) == 3 - assert events[0]["timestamp"] < events[1]["timestamp"] - assert events[1]["timestamp"] < events[2]["timestamp"] - - -async def test_delete_session_cascades_events(adbc_store: Any, session_fixture: Any, tmp_path: Path) -> None: - """Test that deleting a session cascades to delete events. - - Note: SQLite with ADBC requires foreign key enforcement to be explicitly - enabled for cascade deletes to work. This test manually enables it. - """ - ev1: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, - } - ev2: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-2", "app_name": session_fixture["app_name"], "user_id": session_fixture["user_id"]}, - } - await adbc_store.append_event(ev1) - await adbc_store.append_event(ev2) - - events_before = await adbc_store.get_events(session_fixture["session_id"]) - assert len(events_before) == 2 - - await adbc_store.delete_session(session_fixture["session_id"]) - - session_after = await adbc_store.get_session(session_fixture["session_id"]) - assert session_after is None - - -async def test_event_with_empty_actions(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with empty actions bytes.""" - event_record: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "empty-actions", - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - } - await adbc_store.append_event(event_record) - - events = await adbc_store.get_events(session_fixture["session_id"]) - assert len(events) == 1 - assert "event_json" in events[0] - - -async def test_event_with_large_content(adbc_store: Any, session_fixture: Any) -> None: - """Test creating event with large content in event_json.""" - large_content = {"data": "x" * 10000} - - event_record: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "large-content", - "content": large_content, - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - } - await adbc_store.append_event(event_record) - - events = await adbc_store.get_events(session_fixture["session_id"]) - assert len(events) == 1 - event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) - assert event_data["content"] == large_content - - -async def test_append_event_preserves_existing_session_state(adbc_store: Any, session_fixture: Any) -> None: - """append_event must not overwrite the durable session state.""" - event_record: EventRecord = { - "session_id": session_fixture["session_id"], - "invocation_id": "append-only", - "author": "user", - "timestamp": datetime.now(timezone.utc), - "event_json": { - "id": "append-only-event", - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - } - - await adbc_store.append_event(event_record) - - session = await adbc_store.get_session(session_fixture["session_id"]) - assert session is not None - assert session["state"] == {"test": True} - - -async def test_get_events_applies_after_timestamp_and_limit(adbc_store: Any, session_fixture: Any) -> None: - """get_events must respect both after_timestamp and limit.""" - base_time = datetime(2026, 1, 1, tzinfo=timezone.utc) - event_records = [ - { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "user", - "timestamp": base_time, - "event_json": { - "id": "event-1", - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - }, - { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "assistant", - "timestamp": base_time + timedelta(seconds=1), - "event_json": { - "id": "event-2", - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - }, - { - "session_id": session_fixture["session_id"], - "invocation_id": "", - "author": "assistant", - "timestamp": base_time + timedelta(seconds=2), - "event_json": { - "id": "event-3", - "app_name": session_fixture["app_name"], - "user_id": session_fixture["user_id"], - }, - }, - ] - - for event_record in event_records: - await adbc_store.append_event(event_record) - - filtered_events = await adbc_store.get_events(session_fixture["session_id"], after_timestamp=base_time, limit=1) - - assert len(filtered_events) == 1 - filtered_event = filtered_events[0]["event_json"] - filtered_data = json.loads(filtered_event) if isinstance(filtered_event, str) else filtered_event - assert filtered_data["id"] == "event-2" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py b/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py index 313a37b38..9d389bb25 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_memory_store.py @@ -30,63 +30,63 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -async def _build_store(tmp_path: Path) -> AdbcADKMemoryStore: +def _build_store(tmp_path: Path) -> AdbcADKMemoryStore: db_path = tmp_path / "test_adk_memory.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKMemoryStore(config) - await store.create_tables() + store.create_tables() return store -async def test_adbc_memory_store_insert_search_dedup(tmp_path: Path) -> None: +def test_adbc_memory_store_insert_search_dedup(tmp_path: Path) -> None: """Insert memory entries, search by text, and skip duplicates.""" - store = await _build_store(tmp_path) + store = _build_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = await store.insert_memory_entries([record1, record2]) + inserted = store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = await store.search_entries(query="espresso", app_name="app", user_id="user") + results = store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = await store.insert_memory_entries([record1]) + deduped = store.insert_memory_entries([record1]) assert deduped == 0 -async def test_adbc_memory_store_delete_by_session(tmp_path: Path) -> None: +def test_adbc_memory_store_delete_by_session(tmp_path: Path) -> None: """Delete memory entries by session id.""" - store = await _build_store(tmp_path) + store = _build_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - await store.insert_memory_entries([record1, record2]) + store.insert_memory_entries([record1, record2]) - deleted = await store.delete_entries_by_session("s1") + deleted = store.delete_entries_by_session("s1") assert deleted == 1 - remaining = await store.search_entries(query="latte", app_name="app", user_id="user") + remaining = store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -async def test_adbc_memory_store_delete_older_than(tmp_path: Path) -> None: +def test_adbc_memory_store_delete_older_than(tmp_path: Path) -> None: """Delete memory entries older than a cutoff.""" - store = await _build_store(tmp_path) + store = _build_store(tmp_path) now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now) - await store.insert_memory_entries([record1, record2]) + store.insert_memory_entries([record1, record2]) - deleted = await store.delete_entries_older_than(30) + deleted = store.delete_entries_older_than(30) assert deleted == 1 - remaining = await store.search_entries(query="new", app_name="app", user_id="user") + remaining = store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py index 1b1752ae4..79559405b 100644 --- a/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/adbc/extensions/adk/test_owner_id_column.py @@ -9,7 +9,7 @@ @pytest.fixture() -async def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] +def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] """Create ADBC ADK store with owner ID column (SQLite).""" db_path = tmp_path / "test_fk.db" config = AdbcConfig( @@ -29,21 +29,21 @@ async def adbc_store_with_fk(tmp_path): # type: ignore[no-untyped-def] finally: cursor.close() # type: ignore[no-untyped-call] - await store.create_tables() + store.create_tables() return store @pytest.fixture() -async def adbc_store_no_fk(tmp_path): # type: ignore[no-untyped-def] +def adbc_store_no_fk(tmp_path): # type: ignore[no-untyped-def] """Create ADBC ADK store without owner ID column (SQLite).""" db_path = tmp_path / "test_no_fk.db" config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) store = AdbcADKStore(config) - await store.create_tables() + store.create_tables() return store -async def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-untyped-def] +def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test creating session with owner ID value.""" session_id = "test-session-1" app_name = "test-app" @@ -51,32 +51,32 @@ async def test_create_session_with_owner_id(adbc_store_with_fk): # type: ignore state = {"key": "value"} tenant_id = 1 - session = await adbc_store_with_fk.create_session(session_id, app_name, user_id, state, owner_id=tenant_id) + session = adbc_store_with_fk.create_session(session_id, app_name, user_id, state, owner_id=tenant_id) assert session["id"] == session_id assert session["state"] == state -async def test_create_session_without_owner_id_value(adbc_store_with_fk): # type: ignore[no-untyped-def] +def test_create_session_without_owner_id_value(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test creating session without providing owner ID value still works.""" session_id = "test-session-2" app_name = "test-app" user_id = "user-123" state = {"key": "value"} - session = await adbc_store_with_fk.create_session(session_id, app_name, user_id, state) + session = adbc_store_with_fk.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id -async def test_create_session_no_fk_column_configured(adbc_store_no_fk): # type: ignore[no-untyped-def] +def test_create_session_no_fk_column_configured(adbc_store_no_fk): # type: ignore[no-untyped-def] """Test creating session when no FK column configured.""" session_id = "test-session-3" app_name = "test-app" user_id = "user-123" state = {"key": "value"} - session = await adbc_store_no_fk.create_session(session_id, app_name, user_id, state) + session = adbc_store_no_fk.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id assert session["state"] == state @@ -109,17 +109,19 @@ def test_owner_id_column_complex_ddl() -> None: assert store._owner_id_column_ddl == complex_ddl # pyright: ignore[reportPrivateUsage] -async def test_multiple_tenants_isolation(adbc_store_with_fk): # type: ignore[no-untyped-def] +def test_multiple_tenants_isolation(adbc_store_with_fk): # type: ignore[no-untyped-def] """Test sessions are properly isolated by tenant.""" app_name = "test-app" user_id = "user-123" - await adbc_store_with_fk.create_session("session-tenant1", app_name, user_id, {"data": "tenant1"}, owner_id=1) - await adbc_store_with_fk.create_session("session-tenant2", app_name, user_id, {"data": "tenant2"}, owner_id=2) + adbc_store_with_fk.create_session("session-tenant1", app_name, user_id, {"data": "tenant1"}, owner_id=1) + adbc_store_with_fk.create_session("session-tenant2", app_name, user_id, {"data": "tenant2"}, owner_id=2) - retrieved1 = await adbc_store_with_fk.get_session("session-tenant1") - retrieved2 = await adbc_store_with_fk.get_session("session-tenant2") + retrieved1 = adbc_store_with_fk.get_session(app_name, user_id, "session-tenant1") + retrieved2 = adbc_store_with_fk.get_session(app_name, user_id, "session-tenant2") + assert retrieved1 is not None + assert retrieved2 is not None assert retrieved1["state"]["data"] == "tenant1" assert retrieved2["state"]["data"] == "tenant2" diff --git a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py b/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py deleted file mode 100644 index b749461d7..000000000 --- a/tests/integration/adapters/adbc/extensions/adk/test_session_operations.py +++ /dev/null @@ -1,184 +0,0 @@ -"""Tests for ADBC ADK store session operations.""" - -from pathlib import Path -from typing import Any - -import pytest - -from sqlspec.adapters.adbc import AdbcConfig -from sqlspec.adapters.adbc.adk import AdbcADKStore - -pytestmark = [pytest.mark.xdist_group("sqlite"), pytest.mark.adbc, pytest.mark.integration] - - -@pytest.fixture() -async def adbc_store(tmp_path: Path) -> AdbcADKStore: - """Create ADBC ADK store with SQLite backend.""" - db_path = tmp_path / "test_adk.db" - config = AdbcConfig(connection_config={"driver_name": "sqlite", "uri": f"file:{db_path}"}) - store = AdbcADKStore(config) - await store.create_tables() - return store - - -async def test_create_session(adbc_store: Any) -> None: - """Test creating a new session.""" - session_id = "test-session-1" - app_name = "test-app" - user_id = "user-123" - state = {"key": "value", "count": 42} - - session = await adbc_store.create_session(session_id, app_name, user_id, state) - - assert session["id"] == session_id - assert session["app_name"] == app_name - assert session["user_id"] == user_id - assert session["state"] == state - assert session["create_time"] is not None - assert session["update_time"] is not None - - -async def test_get_session(adbc_store: Any) -> None: - """Test retrieving a session by ID.""" - session_id = "test-session-2" - app_name = "test-app" - user_id = "user-123" - state = {"data": "test"} - - await adbc_store.create_session(session_id, app_name, user_id, state) - retrieved = await adbc_store.get_session(session_id) - - assert retrieved is not None - assert retrieved["id"] == session_id - assert retrieved["state"] == state - - -async def test_get_nonexistent_session(adbc_store: Any) -> None: - """Test retrieving a session that doesn't exist.""" - result = await adbc_store.get_session("nonexistent-id") - assert result is None - - -async def test_update_session_state(adbc_store: Any) -> None: - """Test updating session state.""" - session_id = "test-session-3" - app_name = "test-app" - user_id = "user-123" - initial_state = {"version": 1} - - await adbc_store.create_session(session_id, app_name, user_id, initial_state) - - new_state = {"version": 2, "updated": True} - await adbc_store.update_session_state(session_id, new_state) - - updated = await adbc_store.get_session(session_id) - assert updated is not None - assert updated["state"] == new_state - assert updated["state"] != initial_state - - -async def test_delete_session(adbc_store: Any) -> None: - """Test deleting a session.""" - session_id = "test-session-4" - app_name = "test-app" - user_id = "user-123" - state = {"data": "test"} - - await adbc_store.create_session(session_id, app_name, user_id, state) - assert await adbc_store.get_session(session_id) is not None - - await adbc_store.delete_session(session_id) - assert await adbc_store.get_session(session_id) is None - - -async def test_list_sessions(adbc_store: Any) -> None: - """Test listing sessions for an app and user.""" - app_name = "test-app" - user_id = "user-123" - - await adbc_store.create_session("session-1", app_name, user_id, {"num": 1}) - await adbc_store.create_session("session-2", app_name, user_id, {"num": 2}) - await adbc_store.create_session("session-3", "other-app", user_id, {"num": 3}) - - sessions = await adbc_store.list_sessions(app_name, user_id) - - assert len(sessions) == 2 - session_ids = {s["id"] for s in sessions} - assert session_ids == {"session-1", "session-2"} - - -async def test_list_sessions_empty(adbc_store: Any) -> None: - """Test listing sessions when none exist.""" - sessions = await adbc_store.list_sessions("nonexistent-app", "nonexistent-user") - assert sessions == [] - - -async def test_session_state_with_complex_data(adbc_store: Any) -> None: - """Test session state with nested complex data structures.""" - session_id = "complex-session" - app_name = "test-app" - user_id = "user-123" - complex_state = { - "nested": {"key": "value", "number": 42}, - "list": [1, 2, 3], - "mixed": ["string", 123, {"nested": True}], - "null_value": None, - } - - session = await adbc_store.create_session(session_id, app_name, user_id, complex_state) - assert session["state"] == complex_state - - retrieved = await adbc_store.get_session(session_id) - assert retrieved is not None - assert retrieved["state"] == complex_state - - -async def test_session_state_empty_dict(adbc_store: Any) -> None: - """Test creating session with empty state dictionary.""" - session_id = "empty-state-session" - app_name = "test-app" - user_id = "user-123" - empty_state: dict[str, Any] = {} - - session = await adbc_store.create_session(session_id, app_name, user_id, empty_state) - assert session["state"] == empty_state - - retrieved = await adbc_store.get_session(session_id) - assert retrieved is not None - assert retrieved["state"] == empty_state - - -async def test_multiple_users_same_app(adbc_store: Any) -> None: - """Test sessions for multiple users in the same app.""" - app_name = "test-app" - user1 = "user-1" - user2 = "user-2" - - await adbc_store.create_session("session-user1-1", app_name, user1, {"user": 1}) - await adbc_store.create_session("session-user1-2", app_name, user1, {"user": 1}) - await adbc_store.create_session("session-user2-1", app_name, user2, {"user": 2}) - - user1_sessions = await adbc_store.list_sessions(app_name, user1) - user2_sessions = await adbc_store.list_sessions(app_name, user2) - - assert len(user1_sessions) == 2 - assert len(user2_sessions) == 1 - assert all(s["user_id"] == user1 for s in user1_sessions) - assert all(s["user_id"] == user2 for s in user2_sessions) - - -async def test_session_ordering(adbc_store: Any) -> None: - """Test that sessions are ordered by update_time DESC.""" - app_name = "test-app" - user_id = "user-123" - - await adbc_store.create_session("session-1", app_name, user_id, {"order": 1}) - await adbc_store.create_session("session-2", app_name, user_id, {"order": 2}) - await adbc_store.create_session("session-3", app_name, user_id, {"order": 3}) - - await adbc_store.update_session_state("session-1", {"order": 1, "updated": True}) - - sessions = await adbc_store.list_sessions(app_name, user_id) - - assert len(sessions) == 3 - assert sessions[0]["id"] == "session-1" diff --git a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py index 244d11f94..cd6dc0799 100644 --- a/tests/integration/adapters/aiomysql/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiomysql/extensions/adk/test_store.py @@ -57,9 +57,8 @@ async def test_storage_types_verification(aiomysql_adk_store: AiomysqlADKStore) assert "session_id" in event_col_names assert "invocation_id" in event_col_names - assert "author" in event_col_names assert "timestamp" in event_col_names - assert "event_json" in event_col_names + assert "event_data" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in timestamp_col[2].lower(), "timestamp must be TIMESTAMP(6) for microseconds" @@ -78,15 +77,17 @@ async def test_timestamp_precision(aiomysql_adk_store: AiomysqlADKStore) -> None event_time = datetime.now(timezone.utc) event: EventRecord = { + "id": "event-micro", + "app_name": app_name, + "user_id": user_id, "session_id": session_id, "invocation_id": "inv-micro", - "author": "system", "timestamp": event_time, - "event_json": {"app_name": app_name}, + "event_data": {"app_name": app_name, "author": "system"}, } await aiomysql_adk_store.append_event(event) - events = await aiomysql_adk_store.get_events(session_id) + events = await aiomysql_adk_store.get_events(app_name, user_id, session_id) assert len(events) == 1 assert hasattr(events[0]["timestamp"], "microsecond") @@ -123,7 +124,7 @@ async def test_owner_id_constraint_enforcement(aiomysql_adk_store_with_fk: Aiomy session_id=session_id, app_name=app_name, user_id=user_id, state={"tenant": "one"}, owner_id=1 ) - session = await aiomysql_adk_store_with_fk.get_session(session_id) + session = await aiomysql_adk_store_with_fk.get_session(app_name, user_id, session_id) assert session is not None with pytest.raises(Exception): @@ -140,14 +141,14 @@ async def test_owner_id_cascade_delete(aiomysql_adk_store_with_fk: AiomysqlADKSt session_id="tenant1-session", app_name="test-app", user_id="user1", state={"data": "test"}, owner_id=1 ) - session_before = await aiomysql_adk_store_with_fk.get_session("tenant1-session") + session_before = await aiomysql_adk_store_with_fk.get_session("test-app", "user1", "tenant1-session") assert session_before is not None async with config.provide_connection() as conn, AiomysqlCursor(conn) as cursor: await cursor.execute("DELETE FROM test_tenants WHERE id = 1") await conn.commit() - session_after = await aiomysql_adk_store_with_fk.get_session("tenant1-session") + session_after = await aiomysql_adk_store_with_fk.get_session("test-app", "user1", "tenant1-session") assert session_after is None diff --git a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py index 0bc940d38..9c9b68d60 100644 --- a/tests/integration/adapters/asyncmy/extensions/adk/test_store.py +++ b/tests/integration/adapters/asyncmy/extensions/adk/test_store.py @@ -54,12 +54,10 @@ async def test_storage_types_verification(asyncmy_adk_store: AsyncmyADKStore) -> event_columns = await cursor.fetchall() event_col_names = [col[0] for col in event_columns] - # New 5-column schema: session_id, invocation_id, author, timestamp, event_json assert "session_id" in event_col_names assert "invocation_id" in event_col_names - assert "author" in event_col_names assert "timestamp" in event_col_names - assert "event_json" in event_col_names + assert "event_data" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in timestamp_col[2].lower(), "timestamp must be TIMESTAMP(6) for microseconds" @@ -78,15 +76,17 @@ async def test_timestamp_precision(asyncmy_adk_store: AsyncmyADKStore) -> None: event_time = datetime.now(timezone.utc) event: EventRecord = { + "id": "event-micro", + "app_name": app_name, + "user_id": user_id, "session_id": session_id, "invocation_id": "inv-micro", - "author": "system", "timestamp": event_time, - "event_json": {"app_name": app_name}, + "event_data": {"app_name": app_name, "author": "system"}, } await asyncmy_adk_store.append_event(event) - events = await asyncmy_adk_store.get_events(session_id) + events = await asyncmy_adk_store.get_events(app_name, user_id, session_id) assert len(events) == 1 assert hasattr(events[0]["timestamp"], "microsecond") @@ -123,7 +123,7 @@ async def test_owner_id_constraint_enforcement(asyncmy_adk_store_with_fk: Asyncm session_id=session_id, app_name=app_name, user_id=user_id, state={"tenant": "one"}, owner_id=1 ) - session = await asyncmy_adk_store_with_fk.get_session(session_id) + session = await asyncmy_adk_store_with_fk.get_session(app_name, user_id, session_id) assert session is not None with pytest.raises(Exception): @@ -140,14 +140,14 @@ async def test_owner_id_cascade_delete(asyncmy_adk_store_with_fk: AsyncmyADKStor session_id="tenant1-session", app_name="test-app", user_id="user1", state={"data": "test"}, owner_id=1 ) - session_before = await asyncmy_adk_store_with_fk.get_session("tenant1-session") + session_before = await asyncmy_adk_store_with_fk.get_session("test-app", "user1", "tenant1-session") assert session_before is not None async with config.provide_connection() as conn, conn.cursor() as cursor: await cursor.execute("DELETE FROM test_tenants WHERE id = 1") await conn.commit() - session_after = await asyncmy_adk_store_with_fk.get_session("tenant1-session") + session_after = await asyncmy_adk_store_with_fk.get_session("test-app", "user1", "tenant1-session") assert session_after is None diff --git a/tests/integration/adapters/asyncpg/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/asyncpg/extensions/adk/test_owner_id_column.py index 1e4606c9c..bdd72023d 100644 --- a/tests/integration/adapters/asyncpg/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/asyncpg/extensions/adk/test_owner_id_column.py @@ -210,15 +210,15 @@ async def test_cascade_delete_behavior(tenants_table: Any, postgres_service: Any await store.create_session("session-2", "app-1", "user-2", {"data": "test"}, owner_id=1) await store.create_session("session-3", "app-1", "user-3", {"data": "test"}, owner_id=2) - session = await store.get_session("session-1") + session = await store.get_session("app-1", "user-1", "session-1") assert session is not None async with config.provide_connection() as conn: await conn.execute("DELETE FROM tenants WHERE id = 1") - session1 = await store.get_session("session-1") - session2 = await store.get_session("session-2") - session3 = await store.get_session("session-3") + session1 = await store.get_session("app-1", "user-1", "session-1") + session2 = await store.get_session("app-1", "user-2", "session-2") + session3 = await store.get_session("app-1", "user-3", "session-3") assert session1 is None assert session2 is None diff --git a/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py b/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py deleted file mode 100644 index cc6ac7fad..000000000 --- a/tests/integration/adapters/asyncpg/extensions/adk/test_session_operations.py +++ /dev/null @@ -1,134 +0,0 @@ -"""Tests for AsyncPG ADK store session operations.""" - -from typing import Any - -import pytest - -pytestmark = [pytest.mark.xdist_group("postgres"), pytest.mark.asyncpg, pytest.mark.integration] - - -async def test_create_session(asyncpg_adk_store: Any) -> None: - """Test creating a new session.""" - session_id = "session-123" - app_name = "test-app" - user_id = "user-456" - state = {"key": "value"} - - session = await asyncpg_adk_store.create_session(session_id, app_name, user_id, state) - - assert session["id"] == session_id - assert session["app_name"] == app_name - assert session["user_id"] == user_id - assert session["state"] == state - - -async def test_get_session(asyncpg_adk_store: Any) -> None: - """Test retrieving a session by ID.""" - session_id = "session-get" - app_name = "test-app" - user_id = "user-123" - state = {"test": True} - - await asyncpg_adk_store.create_session(session_id, app_name, user_id, state) - - retrieved = await asyncpg_adk_store.get_session(session_id) - - assert retrieved is not None - assert retrieved["id"] == session_id - assert retrieved["app_name"] == app_name - assert retrieved["user_id"] == user_id - assert retrieved["state"] == state - - -async def test_get_nonexistent_session(asyncpg_adk_store: Any) -> None: - """Test retrieving a session that doesn't exist.""" - result = await asyncpg_adk_store.get_session("nonexistent") - assert result is None - - -async def test_update_session_state(asyncpg_adk_store: Any) -> None: - """Test updating session state.""" - session_id = "session-update" - app_name = "test-app" - user_id = "user-123" - initial_state = {"count": 0} - updated_state = {"count": 5, "updated": True} - - await asyncpg_adk_store.create_session(session_id, app_name, user_id, initial_state) - - await asyncpg_adk_store.update_session_state(session_id, updated_state) - - retrieved = await asyncpg_adk_store.get_session(session_id) - assert retrieved is not None - assert retrieved["state"] == updated_state - - -async def test_list_sessions(asyncpg_adk_store: Any) -> None: - """Test listing sessions for an app and user.""" - app_name = "list-test-app" - user_id = "user-list" - - await asyncpg_adk_store.create_session("session-1", app_name, user_id, {"num": 1}) - await asyncpg_adk_store.create_session("session-2", app_name, user_id, {"num": 2}) - await asyncpg_adk_store.create_session("session-3", "other-app", user_id, {"num": 3}) - - sessions = await asyncpg_adk_store.list_sessions(app_name, user_id) - - assert len(sessions) == 2 - session_ids = {s["id"] for s in sessions} - assert session_ids == {"session-1", "session-2"} - - -async def test_list_sessions_empty(asyncpg_adk_store: Any) -> None: - """Test listing sessions when none exist.""" - sessions = await asyncpg_adk_store.list_sessions("nonexistent-app", "nonexistent-user") - assert sessions == [] - - -async def test_delete_session(asyncpg_adk_store: Any) -> None: - """Test deleting a session.""" - session_id = "session-delete" - app_name = "test-app" - user_id = "user-123" - - await asyncpg_adk_store.create_session(session_id, app_name, user_id, {"test": True}) - - await asyncpg_adk_store.delete_session(session_id) - - retrieved = await asyncpg_adk_store.get_session(session_id) - assert retrieved is None - - -async def test_delete_nonexistent_session(asyncpg_adk_store: Any) -> None: - """Test deleting a session that doesn't exist doesn't raise error.""" - await asyncpg_adk_store.delete_session("nonexistent") - - -async def test_session_timestamps(asyncpg_adk_store: Any) -> None: - """Test that create_time and update_time are set correctly.""" - session_id = "session-timestamps" - session = await asyncpg_adk_store.create_session(session_id, "app", "user", {"test": True}) - - assert session["create_time"] is not None - assert session["update_time"] is not None - assert session["create_time"] == session["update_time"] - - -async def test_complex_jsonb_state(asyncpg_adk_store: Any) -> None: - """Test storing complex nested JSONB state.""" - session_id = "session-complex" - complex_state = { - "nested": {"level1": {"level2": {"data": [1, 2, 3], "flags": {"active": True, "verified": False}}}}, - "arrays": ["a", "b", "c"], - "numbers": [1, 2.5, -3], - "nulls": None, - "booleans": [True, False], - } - - session = await asyncpg_adk_store.create_session(session_id, "app", "user", complex_state) - - assert session["state"] == complex_state - - retrieved = await asyncpg_adk_store.get_session(session_id) - assert retrieved is not None - assert retrieved["state"] == complex_state diff --git a/tests/integration/adapters/contracts/_adk_cases.py b/tests/integration/adapters/contracts/_adk_cases.py index 9f4bd02e8..6528f97bf 100644 --- a/tests/integration/adapters/contracts/_adk_cases.py +++ b/tests/integration/adapters/contracts/_adk_cases.py @@ -6,8 +6,11 @@ from _pytest.mark.structures import Mark, MarkDecorator from tests.integration.adapters.contracts._cases import ( + ADBC_MARK, + COCKROACH_XDIST_MARK, DUCKDB_XDIST_MARK, MYSQL_XDIST_MARK, + ORACLE_XDIST_MARK, POSTGRES_XDIST_MARK, SQLITE_XDIST_MARK, ) @@ -58,8 +61,40 @@ class AdkStoreCaseContext: ), AdkStoreCase("asyncpg", "adk_store_asyncpg", "asyncpg", marks=(POSTGRES_XDIST_MARK, pytest.mark.anyio)), AdkStoreCase("psqlpy", "adk_store_psqlpy", "psqlpy", marks=(POSTGRES_XDIST_MARK, pytest.mark.anyio)), + AdkStoreCase("psycopg-async", "adk_store_psycopg_async", "psycopg", marks=(POSTGRES_XDIST_MARK, pytest.mark.anyio)), + AdkStoreCase("psycopg-sync", "adk_store_psycopg_sync", "psycopg", marks=(POSTGRES_XDIST_MARK, pytest.mark.anyio)), + AdkStoreCase("pymysql", "adk_store_pymysql", "pymysql", marks=(MYSQL_XDIST_MARK, pytest.mark.anyio)), + AdkStoreCase( + "cockroach-asyncpg", + "adk_store_cockroach_asyncpg", + "cockroach_asyncpg", + marks=(COCKROACH_XDIST_MARK, pytest.mark.anyio), + ), + AdkStoreCase( + "cockroach-psycopg-async", + "adk_store_cockroach_psycopg_async", + "cockroach_psycopg", + marks=(COCKROACH_XDIST_MARK, pytest.mark.anyio), + ), + AdkStoreCase( + "cockroach-psycopg-sync", + "adk_store_cockroach_psycopg_sync", + "cockroach_psycopg", + marks=(COCKROACH_XDIST_MARK, pytest.mark.anyio), + ), + AdkStoreCase("oracledb-async", "adk_store_oracle_async", "oracledb", marks=(ORACLE_XDIST_MARK, pytest.mark.anyio)), + AdkStoreCase("oracledb-sync", "adk_store_oracle_sync", "oracledb", marks=(ORACLE_XDIST_MARK, pytest.mark.anyio)), + AdkStoreCase("adbc-sqlite", "adk_store_adbc_sqlite", "adbc", marks=(ADBC_MARK, pytest.mark.anyio)), + AdkStoreCase( + "adbc-duckdb", + "adk_store_adbc_duckdb", + "adbc", + marks=(ADBC_MARK, pytest.mark.anyio), + supports_atomic_state_update=False, + ), + AdkStoreCase( + "adbc-postgres", "adk_store_adbc_postgres", "adbc", marks=(ADBC_MARK, POSTGRES_XDIST_MARK, pytest.mark.anyio) + ), ) -# NOTE: psycopg-async/sync are excluded pending sqlspec-cne7 — the psycopg ADK store read -# methods index tuple-cursor rows by string key (TypeError). asyncpg/psqlpy cover postgres here. ADK_STORE_PARAMS = tuple(pytest.param(case, id=case.id, marks=case.marks) for case in ADK_STORE_CASES) diff --git a/tests/integration/adapters/contracts/adk_behaviors.py b/tests/integration/adapters/contracts/adk_behaviors.py index 7d9b7bb2f..50da8f8f1 100644 --- a/tests/integration/adapters/contracts/adk_behaviors.py +++ b/tests/integration/adapters/contracts/adk_behaviors.py @@ -1,24 +1,40 @@ """Behavior helpers for shared ADK session/event store contract tests.""" +from collections.abc import Awaitable from datetime import datetime, timedelta, timezone -from typing import Any +from inspect import isawaitable +from typing import Any, TypeVar from sqlspec.extensions.adk import EventRecord +T = TypeVar("T") + + +async def _resolve(result: T | Awaitable[T]) -> T: + if isawaitable(result): + return await result + return result + async def _aclose(config: Any) -> None: - result = config.close_pool() - if result is not None and hasattr(result, "__await__"): - await result + await _resolve(config.close_pool()) -def _event(session_id: str, index: int, when: datetime) -> EventRecord: +def _event(app_name: str, user_id: str, session_id: str, index: int, when: datetime) -> EventRecord: return { + "id": f"event-{session_id}-{index}", + "app_name": app_name, + "user_id": user_id, "session_id": session_id, "invocation_id": f"inv-{index}", - "author": "user", "timestamp": when, - "event_json": {"id": f"event-{index}", "content": {"parts": [{"text": f"hello-{index}"}]}}, + "event_data": { + "id": f"event-{session_id}-{index}", + "invocation_id": f"inv-{index}", + "author": "user", + "timestamp": when.timestamp(), + "content": {"parts": [{"text": f"hello-{index}"}]}, + }, } @@ -26,9 +42,9 @@ async def assert_adk_create_tables_idempotent_contract(make_store: Any) -> None: """Creating ADK tables twice is a safe no-op.""" config, store = make_store() try: - await store.create_tables() - await store.create_tables() - assert await store.get_session("missing") is None + await _resolve(store.create_tables()) + await _resolve(store.create_tables()) + assert await _resolve(store.get_session("app", "user", "missing")) is None finally: await _aclose(config) @@ -37,12 +53,14 @@ async def assert_adk_session_round_trip_contract(make_store: Any) -> None: """Sessions persist empty and populated state as JSON through a round-trip.""" config, store = make_store() try: - await store.create_tables() - empty = await store.create_session("session-empty", "app", "user", {}) + await _resolve(store.create_tables()) + empty = await _resolve(store.create_session("session-empty", "app", "user", {})) assert empty["state"] == {} - created = await store.create_session("session-state", "app", "user", {"turn": 1, "nested": {"a": [1, 2]}}) - fetched = await store.get_session("session-state") + created = await _resolve( + store.create_session("session-state", "app", "user", {"turn": 1, "nested": {"a": [1, 2]}}) + ) + fetched = await _resolve(store.get_session("app", "user", "session-state")) assert created["state"] == {"turn": 1, "nested": {"a": [1, 2]}} assert fetched is not None assert fetched["state"] == {"turn": 1, "nested": {"a": [1, 2]}} @@ -54,8 +72,8 @@ async def assert_adk_get_nonexistent_session_contract(make_store: Any) -> None: """Reading a missing session returns None.""" config, store = make_store() try: - await store.create_tables() - assert await store.get_session("nope") is None + await _resolve(store.create_tables()) + assert await _resolve(store.get_session("app", "user", "nope")) is None finally: await _aclose(config) @@ -64,10 +82,10 @@ async def assert_adk_update_session_state_contract(make_store: Any) -> None: """update_session_state replaces the durable session state.""" config, store = make_store() try: - await store.create_tables() - await store.create_session("session-update", "app", "user", {"count": 0}) - await store.update_session_state("session-update", {"count": 5}) - fetched = await store.get_session("session-update") + await _resolve(store.create_tables()) + await _resolve(store.create_session("session-update", "app", "user", {"count": 0})) + await _resolve(store.update_session_state("app", "user", "session-update", {"count": 5})) + fetched = await _resolve(store.get_session("app", "user", "session-update")) assert fetched is not None assert fetched["state"] == {"count": 5} finally: @@ -78,19 +96,19 @@ async def assert_adk_list_sessions_contract(make_store: Any) -> None: """list_sessions filters by app and optional user for tenant isolation.""" config, store = make_store() try: - await store.create_tables() - await store.create_session("s1", "app", "user-a", {}) - await store.create_session("s2", "app", "user-a", {}) - await store.create_session("s3", "app", "user-b", {}) - await store.create_session("s4", "other", "user-a", {}) + await _resolve(store.create_tables()) + await _resolve(store.create_session("s1", "app", "user-a", {})) + await _resolve(store.create_session("s2", "app", "user-a", {})) + await _resolve(store.create_session("s3", "app", "user-b", {})) + await _resolve(store.create_session("s4", "other", "user-a", {})) - app_sessions = await store.list_sessions("app") + app_sessions = await _resolve(store.list_sessions("app")) assert {row["id"] for row in app_sessions} == {"s1", "s2", "s3"} - user_sessions = await store.list_sessions("app", "user-a") + user_sessions = await _resolve(store.list_sessions("app", "user-a")) assert {row["id"] for row in user_sessions} == {"s1", "s2"} - assert await store.list_sessions("empty-app") == [] + assert await _resolve(store.list_sessions("empty-app")) == [] finally: await _aclose(config) @@ -99,14 +117,16 @@ async def assert_adk_delete_session_cascade_contract(make_store: Any) -> None: """Deleting a session removes the session and its events.""" config, store = make_store() try: - await store.create_tables() + await _resolve(store.create_tables()) session_id = "session-delete" - await store.create_session(session_id, "app", "user", {}) - await store.append_event(_event(session_id, 0, datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc))) + await _resolve(store.create_session(session_id, "app", "user", {})) + await _resolve( + store.append_event(_event("app", "user", session_id, 0, datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc))) + ) - await store.delete_session(session_id) - assert await store.get_session(session_id) is None - assert await store.get_events(session_id) == [] + await _resolve(store.delete_session("app", "user", session_id)) + assert await _resolve(store.get_session("app", "user", session_id)) is None + assert await _resolve(store.get_events("app", "user", session_id)) == [] finally: await _aclose(config) @@ -115,15 +135,20 @@ async def assert_adk_append_and_get_events_contract(make_store: Any) -> None: """Appended events round-trip through get_events.""" config, store = make_store() try: - await store.create_tables() + await _resolve(store.create_tables()) session_id = "session-events" - await store.create_session(session_id, "app", "user", {}) - await store.append_event(_event(session_id, 1, datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc))) + await _resolve(store.create_session(session_id, "app", "user", {})) + await _resolve( + store.append_event(_event("app", "user", session_id, 1, datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc))) + ) - events = await store.get_events(session_id) + events = await _resolve(store.get_events("app", "user", session_id)) assert len(events) == 1 + assert events[0]["id"] == f"event-{session_id}-1" + assert events[0]["app_name"] == "app" + assert events[0]["user_id"] == "user" assert events[0]["invocation_id"] == "inv-1" - assert events[0]["event_json"] == {"id": "event-1", "content": {"parts": [{"text": "hello-1"}]}} + assert events[0]["event_data"]["content"] == {"parts": [{"text": "hello-1"}]} finally: await _aclose(config) @@ -132,20 +157,30 @@ async def assert_adk_append_event_and_update_state_contract(make_store: Any) -> """Atomic append updates state and stores the event in one round-trip.""" config, store = make_store() try: - await store.create_tables() + await _resolve(store.create_tables()) session_id = "session-atomic" - await store.create_session(session_id, "app", "user", {}) - updated = await store.append_event_and_update_state( - _event(session_id, 1, datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc)), session_id, {"turn": 1} + await _resolve(store.create_session(session_id, "app", "user", {})) + updated = await _resolve( + store.append_event_and_update_state( + _event("app", "user", session_id, 1, datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc)), + "app", + "user", + session_id, + {"turn": 1}, + app_state={"app:theme": "dark"}, + user_state={"user:locale": "en-US"}, + ) ) assert updated["state"] == {"turn": 1} - session = await store.get_session(session_id) - events = await store.get_events(session_id) + session = await _resolve(store.get_session("app", "user", session_id)) + events = await _resolve(store.get_events("app", "user", session_id)) assert session is not None assert session["state"] == {"turn": 1} assert len(events) == 1 assert events[0]["invocation_id"] == "inv-1" + assert await _resolve(store.get_app_state("app")) == {"app:theme": "dark"} + assert await _resolve(store.get_user_state("app", "user")) == {"user:locale": "en-US"} finally: await _aclose(config) @@ -154,16 +189,22 @@ async def assert_adk_get_events_filtering_contract(make_store: Any) -> None: """get_events honors after_timestamp and limit.""" config, store = make_store() try: - await store.create_tables() + await _resolve(store.create_tables()) session_id = "session-filter" - await store.create_session(session_id, "app", "user", {}) + await _resolve(store.create_session(session_id, "app", "user", {})) base = datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc) for index in range(3): - await store.append_event(_event(session_id, index, base + timedelta(seconds=index))) + await _resolve( + store.append_event(_event("app", "user", session_id, index, base + timedelta(seconds=index))) + ) - events = await store.get_events(session_id, after_timestamp=base + timedelta(milliseconds=500), limit=1) + events = await _resolve( + store.get_events("app", "user", session_id, after_timestamp=base + timedelta(milliseconds=500), limit=1) + ) assert len(events) == 1 assert events[0]["invocation_id"] == "inv-1" + + assert await _resolve(store.get_events("app", "user", session_id, limit=0)) == [] finally: await _aclose(config) @@ -172,9 +213,11 @@ async def assert_adk_reads_empty_when_tables_missing_contract(make_store: Any) - """Read paths return None/empty when the ADK tables do not exist.""" config, store = make_store() try: - assert await store.get_session("missing") is None - assert await store.list_sessions("app") == [] - assert await store.list_sessions("app", "user") == [] - assert await store.get_events("session-x") == [] + assert await _resolve(store.get_session("app", "user", "missing")) is None + assert await _resolve(store.list_sessions("app")) == [] + assert await _resolve(store.list_sessions("app", "user")) == [] + assert await _resolve(store.get_events("app", "user", "session-x")) == [] + assert await _resolve(store.get_app_state("app")) is None + assert await _resolve(store.get_user_state("app", "user")) is None finally: await _aclose(config) diff --git a/tests/integration/adapters/contracts/conftest.py b/tests/integration/adapters/contracts/conftest.py index 006571ed9..a5df51927 100644 --- a/tests/integration/adapters/contracts/conftest.py +++ b/tests/integration/adapters/contracts/conftest.py @@ -17,6 +17,7 @@ from pytest_databases.docker.postgres import PostgresService from sqlspec.adapters.adbc import AdbcConfig, AdbcDriver +from sqlspec.adapters.adbc.adk import AdbcADKStore from sqlspec.adapters.aiomysql import AiomysqlConfig, AiomysqlDriver, AiomysqlDriverFeatures from sqlspec.adapters.aiomysql.adk import AiomysqlADKStore from sqlspec.adapters.aiomysql.litestar import AiomysqlStore @@ -35,6 +36,7 @@ CockroachAsyncpgDriver, CockroachAsyncpgDriverFeatures, ) +from sqlspec.adapters.cockroach_asyncpg.adk import CockroachAsyncpgADKStore from sqlspec.adapters.cockroach_psycopg import ( CockroachPsycopgAsyncConfig, CockroachPsycopgAsyncDriver, @@ -42,6 +44,7 @@ CockroachPsycopgSyncConfig, CockroachPsycopgSyncDriver, ) +from sqlspec.adapters.cockroach_psycopg.adk import CockroachPsycopgAsyncADKStore, CockroachPsycopgSyncADKStore from sqlspec.adapters.duckdb import DuckDBConfig, DuckDBDriver, DuckDBDriverFeatures from sqlspec.adapters.duckdb.adk import DuckdbADKStore from sqlspec.adapters.duckdb.litestar import DuckdbStore @@ -62,6 +65,7 @@ OracleSyncConfig, OracleSyncDriver, ) +from sqlspec.adapters.oracledb.adk import OracleAsyncADKStore, OracleSyncADKStore from sqlspec.adapters.psqlpy import PsqlpyConfig, PsqlpyDriver, PsqlpyDriverFeatures from sqlspec.adapters.psqlpy.adk import PsqlpyADKStore from sqlspec.adapters.psqlpy.litestar import PsqlpyStore @@ -72,8 +76,10 @@ PsycopgSyncConfig, PsycopgSyncDriver, ) +from sqlspec.adapters.psycopg.adk import PsycopgAsyncADKStore, PsycopgSyncADKStore from sqlspec.adapters.psycopg.litestar import PsycopgAsyncStore, PsycopgSyncStore from sqlspec.adapters.pymysql import PyMysqlConfig, PyMysqlDriver, PyMysqlDriverFeatures +from sqlspec.adapters.pymysql.adk import PyMysqlADKStore from sqlspec.adapters.pymysql.litestar import PyMysqlStore from sqlspec.adapters.sqlite import SqliteConfig, SqliteDriver, SqliteDriverFeatures from sqlspec.adapters.sqlite.adk import SqliteADKStore @@ -1499,7 +1505,25 @@ async def contract_pymysql_store(mysql_service: MySQLService) -> "AsyncGenerator def _adk_extension_config(suffix: str) -> dict[str, Any]: - return {"adk": {"session_table": f"adk_s_{suffix}", "events_table": f"adk_e_{suffix}"}} + return { + "adk": { + "session_table": f"adk_s_{suffix}", + "events_table": f"adk_e_{suffix}", + "app_state_table": f"adk_app_{suffix}", + "user_state_table": f"adk_user_{suffix}", + "metadata_table": f"adk_meta_{suffix}", + } + } + + +def _ensure_adbc_store_driver_available(config: AdbcConfig) -> None: + try: + with config.provide_session(): + pass + except Exception as error: + if any(marker in str(error) for marker in _ADBC_DRIVER_MISSING_MARKERS): + pytest.skip(f"ADBC driver not available: {error}") + raise @pytest.fixture @@ -1637,6 +1661,186 @@ def make() -> "tuple[Any, Any]": return make +@pytest.fixture +def adk_store_psycopg_async(postgres_service: PostgresService) -> Callable[..., Any]: + """Build a fresh psycopg async ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = PsycopgAsyncConfig( + connection_config={"conninfo": _postgres_conninfo(postgres_service), "autocommit": True}, + extension_config=_adk_extension_config(suffix), + ) + return config, PsycopgAsyncADKStore(config) + + return make + + +@pytest.fixture +def adk_store_psycopg_sync(postgres_service: PostgresService) -> Callable[..., Any]: + """Build a fresh psycopg sync ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = PsycopgSyncConfig( + connection_config={"conninfo": _postgres_conninfo(postgres_service), "autocommit": True}, + extension_config=_adk_extension_config(suffix), + ) + return config, PsycopgSyncADKStore(config) + + return make + + +@pytest.fixture +def adk_store_oracle_async(oracle_23ai_service: OracleService) -> Callable[..., Any]: + """Build a fresh Oracle async ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = OracleAsyncConfig( + connection_config=_oracle_pool_params(oracle_23ai_service), extension_config=_adk_extension_config(suffix) + ) + return config, OracleAsyncADKStore(config) + + return make + + +@pytest.fixture +def adk_store_oracle_sync(oracle_23ai_service: OracleService) -> Callable[..., Any]: + """Build a fresh Oracle sync ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = OracleSyncConfig( + connection_config=_oracle_pool_params(oracle_23ai_service), extension_config=_adk_extension_config(suffix) + ) + return config, OracleSyncADKStore(config) + + return make + + +@pytest.fixture +def adk_store_pymysql(mysql_service: MySQLService) -> Callable[..., Any]: + """Build a fresh PyMySQL ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = PyMysqlConfig( + connection_config=_mysql_connection_config(mysql_service), extension_config=_adk_extension_config(suffix) + ) + return config, PyMysqlADKStore(config) + + return make + + +@pytest.fixture +def adk_store_cockroach_asyncpg(cockroachdb_service: CockroachDBService) -> Callable[..., Any]: + """Build a fresh CockroachDB asyncpg ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = CockroachAsyncpgConfig( + connection_config={ + "host": cockroachdb_service.host, + "port": cockroachdb_service.port, + "user": "root", + "password": "", + "database": cockroachdb_service.database, + "ssl": None, + "min_size": 1, + "max_size": 5, + }, + extension_config=_adk_extension_config(suffix), + ) + return config, CockroachAsyncpgADKStore(config) + + return make + + +@pytest.fixture +def adk_store_cockroach_psycopg_async(cockroachdb_service: CockroachDBService) -> Callable[..., Any]: + """Build a fresh CockroachDB psycopg async ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = CockroachPsycopgAsyncConfig( + connection_config={"conninfo": _cockroach_conninfo(cockroachdb_service)}, + extension_config=_adk_extension_config(suffix), + ) + return config, CockroachPsycopgAsyncADKStore(config) + + return make + + +@pytest.fixture +def adk_store_cockroach_psycopg_sync(cockroachdb_service: CockroachDBService) -> Callable[..., Any]: + """Build a fresh CockroachDB psycopg sync ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = CockroachPsycopgSyncConfig( + connection_config={"conninfo": _cockroach_conninfo(cockroachdb_service)}, + extension_config=_adk_extension_config(suffix), + ) + return config, CockroachPsycopgSyncADKStore(config) + + return make + + +@pytest.fixture +def adk_store_adbc_sqlite(tmp_path: Path) -> Callable[..., Any]: + """Build a fresh ADBC SQLite ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = AdbcConfig( + connection_config={"driver_name": "sqlite", "uri": f"file:{tmp_path / f'adk_{suffix}.db'}"}, + extension_config=_adk_extension_config(suffix), + ) + _ensure_adbc_store_driver_available(config) + return config, AdbcADKStore(config) + + return make + + +@pytest.fixture +def adk_store_adbc_duckdb(tmp_path: Path) -> Callable[..., Any]: + """Build a fresh ADBC DuckDB ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = AdbcConfig( + connection_config={"driver_name": "duckdb", "path": str(tmp_path / f"adk_{suffix}.duckdb")}, + extension_config=_adk_extension_config(suffix), + ) + _ensure_adbc_store_driver_available(config) + return config, AdbcADKStore(config) + + return make + + +@pytest.fixture +def adk_store_adbc_postgres(postgres_service: PostgresService) -> Callable[..., Any]: + """Build a fresh ADBC PostgreSQL ADK store with isolated tables per call.""" + + def make() -> "tuple[Any, Any]": + suffix = uuid4().hex[:8] + config = AdbcConfig( + connection_config={ + "driver_name": "postgresql", + "uri": ( + f"postgresql://{postgres_service.user}:{postgres_service.password}" + f"@{postgres_service.host}:{postgres_service.port}/{postgres_service.database}" + ), + }, + extension_config=_adk_extension_config(suffix), + ) + _ensure_adbc_store_driver_available(config) + return config, AdbcADKStore(config) + + return make + + def _resolve_adk_store_case(request: pytest.FixtureRequest, case: AdkStoreCase) -> AdkStoreCaseContext: return AdkStoreCaseContext(case=case, make_store=request.getfixturevalue(case.factory_fixture)) diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py index b5a2b6ca5..cd97a092c 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_memory_store.py @@ -31,15 +31,15 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -async def _build_store(tmp_path: Path) -> DuckdbADKMemoryStore: +def _build_store(tmp_path: Path) -> DuckdbADKMemoryStore: db_path = tmp_path / "test_adk_memory.duckdb" config = DuckDBConfig(connection_config={"database": str(db_path)}) store = DuckdbADKMemoryStore(config) - await store.create_tables() + store.create_tables() return store -async def _build_fts_store(tmp_path: Path) -> DuckdbADKMemoryStore: +def _build_fts_store(tmp_path: Path) -> DuckdbADKMemoryStore: db_path = tmp_path / "test_adk_memory_fts.duckdb" config = DuckDBConfig( connection_config={"database": str(db_path)}, extension_config={"adk": {"memory_use_fts": True}} @@ -48,74 +48,74 @@ async def _build_fts_store(tmp_path: Path) -> DuckdbADKMemoryStore: with config.provide_connection() as conn: if not store._ensure_fts_extension(conn): # pyright: ignore[reportPrivateUsage] pytest.skip("DuckDB FTS extension is unavailable") - await store.create_tables() + store.create_tables() return store -async def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path) -> None: +def test_duckdb_memory_store_insert_search_dedup(tmp_path: Path) -> None: """Insert memory entries, search by text, and skip duplicates.""" - store = await _build_store(tmp_path) + store = _build_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = await store.insert_memory_entries([record1, record2]) + inserted = store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = await store.search_entries(query="espresso", app_name="app", user_id="user") + results = store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = await store.insert_memory_entries([record1]) + deduped = store.insert_memory_entries([record1]) assert deduped == 0 -async def test_duckdb_memory_store_delete_by_session(tmp_path: Path) -> None: +def test_duckdb_memory_store_delete_by_session(tmp_path: Path) -> None: """Delete memory entries by session id.""" - store = await _build_store(tmp_path) + store = _build_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - await store.insert_memory_entries([record1, record2]) + store.insert_memory_entries([record1, record2]) - deleted = await store.delete_entries_by_session("s1") + deleted = store.delete_entries_by_session("s1") assert deleted == 1 - remaining = await store.search_entries(query="latte", app_name="app", user_id="user") + remaining = store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -async def test_duckdb_memory_store_delete_older_than(tmp_path: Path) -> None: +def test_duckdb_memory_store_delete_older_than(tmp_path: Path) -> None: """Delete memory entries older than a cutoff.""" - store = await _build_store(tmp_path) + store = _build_store(tmp_path) now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now) - await store.insert_memory_entries([record1, record2]) + store.insert_memory_entries([record1, record2]) - deleted = await store.delete_entries_older_than(30) + deleted = store.delete_entries_older_than(30) assert deleted == 1 - remaining = await store.search_entries(query="new", app_name="app", user_id="user") + remaining = store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" -async def test_duckdb_memory_store_fts_search_uses_bm25_path(tmp_path: Path) -> None: +def test_duckdb_memory_store_fts_search_uses_bm25_path(tmp_path: Path) -> None: """FTS-enabled DuckDB stores search through the BM25 index after insert refresh.""" - store = await _build_fts_store(tmp_path) + store = _build_fts_store(tmp_path) now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-fts-1", content_text="espresso roast", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-fts-2", content_text="latte foam", inserted_at=now) - await store.insert_memory_entries([record1, record2]) + store.insert_memory_entries([record1, record2]) - results = await store.search_entries(query="espresso", app_name="app", user_id="user") + results = store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-fts-1" diff --git a/tests/integration/adapters/duckdb/extensions/adk/test_store.py b/tests/integration/adapters/duckdb/extensions/adk/test_store.py index b3e38b06a..8c27fa3a0 100644 --- a/tests/integration/adapters/duckdb/extensions/adk/test_store.py +++ b/tests/integration/adapters/duckdb/extensions/adk/test_store.py @@ -7,8 +7,7 @@ concurrency, event ordering/JSON details) that is not portable across the contract matrix. """ -import json -from collections.abc import AsyncGenerator +from collections.abc import Generator from datetime import datetime, timezone from pathlib import Path @@ -22,7 +21,7 @@ @pytest.fixture -async def duckdb_adk_store(tmp_path: Path) -> "AsyncGenerator[DuckdbADKStore, None]": +def duckdb_adk_store(tmp_path: Path) -> "Generator[DuckdbADKStore, None, None]": """Create DuckDB ADK store with temporary file-based database. Args: @@ -41,25 +40,28 @@ async def duckdb_adk_store(tmp_path: Path) -> "AsyncGenerator[DuckdbADKStore, No extension_config={"adk": {"session_table": "test_sessions", "events_table": "test_events"}}, ) store = DuckdbADKStore(config) - await store.create_tables() + store.create_tables() yield store finally: if db_path.exists(): db_path.unlink() -async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: - """Test creating events with optional fields stored in event_json.""" +def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> None: + """Test creating events with optional fields stored in event_data.""" session_id = "session-008" - await duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) + duckdb_adk_store.create_session(session_id, "test-app", "user-008", {}) event_record: EventRecord = { + "id": "event-full", + "app_name": "test-app", + "user_id": "user-008", "session_id": session_id, "invocation_id": "inv-123", - "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": { + "event_data": { "id": "event-full", + "author": "assistant", "content": {"text": "Response"}, "app_name": "test-app", "user_id": "user-008", @@ -71,71 +73,70 @@ async def test_event_with_optional_fields(duckdb_adk_store: DuckdbADKStore) -> N "interrupted": False, }, } - await duckdb_adk_store.append_event(event_record) + duckdb_adk_store.append_event(event_record) - events = await duckdb_adk_store.get_events(session_id) + events = duckdb_adk_store.get_events("test-app", "user-008", session_id) assert len(events) == 1 - # The 5-key record has invocation_id as a top-level indexed column assert events[0]["invocation_id"] == "inv-123" - # Other fields are inside event_json - event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) + event_data = events[0]["event_data"] assert event_data["branch"] == "main" assert event_data["grounding_metadata"] == {"sources": ["doc1", "doc2"]} assert event_data["partial"] is True assert event_data["turn_complete"] is False -async def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: +def test_event_ordering_by_timestamp(duckdb_adk_store: DuckdbADKStore) -> None: """Test events are ordered by timestamp ascending.""" session_id = "session-009" - await duckdb_adk_store.create_session(session_id, "test-app", "user-009", {}) + duckdb_adk_store.create_session(session_id, "test-app", "user-009", {}) t1 = datetime.now(timezone.utc) t2 = datetime.now(timezone.utc) t3 = datetime.now(timezone.utc) ev_middle: EventRecord = { + "id": "event-middle", + "app_name": "test-app", + "user_id": "user-009", "session_id": session_id, "invocation_id": "", - "author": "", "timestamp": t2, - "event_json": {"id": "event-middle", "app_name": "test-app", "user_id": "user-009"}, + "event_data": {"id": "event-middle", "app_name": "test-app", "user_id": "user-009"}, } ev_last: EventRecord = { + "id": "event-last", + "app_name": "test-app", + "user_id": "user-009", "session_id": session_id, "invocation_id": "", - "author": "", "timestamp": t3, - "event_json": {"id": "event-last", "app_name": "test-app", "user_id": "user-009"}, + "event_data": {"id": "event-last", "app_name": "test-app", "user_id": "user-009"}, } ev_first: EventRecord = { + "id": "event-first", + "app_name": "test-app", + "user_id": "user-009", "session_id": session_id, "invocation_id": "", - "author": "", "timestamp": t1, - "event_json": {"id": "event-first", "app_name": "test-app", "user_id": "user-009"}, + "event_data": {"id": "event-first", "app_name": "test-app", "user_id": "user-009"}, } - await duckdb_adk_store.append_event(ev_middle) - await duckdb_adk_store.append_event(ev_last) - await duckdb_adk_store.append_event(ev_first) + duckdb_adk_store.append_event(ev_middle) + duckdb_adk_store.append_event(ev_last) + duckdb_adk_store.append_event(ev_first) - events = await duckdb_adk_store.get_events(session_id) + events = duckdb_adk_store.get_events("test-app", "user-009", session_id) assert len(events) == 3 # Events should be ordered by timestamp ASC - event_ids = [] - for e in events: - data = json.loads(e["event_json"]) if isinstance(e["event_json"], str) else e["event_json"] - event_ids.append(data["id"]) + event_ids = [event["event_data"]["id"] for event in events] assert event_ids == ["event-first", "event-middle", "event-last"] -async def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None: +def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) -> None: """Test session state with nested JSON structures.""" session_id = "session-complex" complex_state = { @@ -148,54 +149,59 @@ async def test_session_state_with_complex_data(duckdb_adk_store: DuckdbADKStore) "flags": [True, False, True], } - await duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state) + duckdb_adk_store.create_session(session_id, "test-app", "user-010", complex_state) - session = await duckdb_adk_store.get_session(session_id) + session = duckdb_adk_store.get_session("test-app", "user-010", session_id) assert session is not None assert session["state"] == complex_state assert session["state"]["user"]["preferences"]["theme"] == "dark" assert session["state"]["conversation"]["turn_count"] == 5 -async def test_event_json_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: - """Test storing and retrieving event data via event_json.""" +def test_event_data_round_trip(duckdb_adk_store: DuckdbADKStore) -> None: + """Test storing and retrieving event data via event_data.""" session_id = "session-json-rt" - await duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) + duckdb_adk_store.create_session(session_id, "test-app", "user-012", {}) event_record: EventRecord = { + "id": "event-json", + "app_name": "test-app", + "user_id": "user-012", "session_id": session_id, "invocation_id": "", - "author": "system", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-json", "content": {"data": "value"}, "app_name": "test-app", "user_id": "user-012"}, + "event_data": { + "id": "event-json", + "author": "system", + "content": {"data": "value"}, + "app_name": "test-app", + "user_id": "user-012", + }, } - await duckdb_adk_store.append_event(event_record) + duckdb_adk_store.append_event(event_record) - events = await duckdb_adk_store.get_events(session_id) + events = duckdb_adk_store.get_events("test-app", "user-012", session_id) assert len(events) == 1 - event_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] - ) - assert event_data["content"] == {"data": "value"} + assert events[0]["event_data"]["content"] == {"data": "value"} -async def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None: +def test_concurrent_session_updates(duckdb_adk_store: DuckdbADKStore) -> None: """Test multiple updates to same session.""" session_id = "session-concurrent" - await duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0}) + duckdb_adk_store.create_session(session_id, "test-app", "user-013", {"counter": 0}) for i in range(10): - session = await duckdb_adk_store.get_session(session_id) + session = duckdb_adk_store.get_session("test-app", "user-013", session_id) assert session is not None current_counter = session["state"]["counter"] - await duckdb_adk_store.update_session_state(session_id, {"counter": current_counter + 1}) + duckdb_adk_store.update_session_state("test-app", "user-013", session_id, {"counter": current_counter + 1}) - final_session = await duckdb_adk_store.get_session(session_id) + final_session = duckdb_adk_store.get_session("test-app", "user-013", session_id) assert final_session is not None assert final_session["state"]["counter"] == 10 -async def test_owner_id_column_with_integer(tmp_path: Path) -> None: +def test_owner_id_column_with_integer(tmp_path: Path) -> None: """Test owner ID column with INTEGER type.""" db_path = tmp_path / "test_owner_id_int.duckdb" try: @@ -217,12 +223,12 @@ async def test_owner_id_column_with_integer(tmp_path: Path) -> None: }, ) store = DuckdbADKStore(config_with_extension) - await store.create_tables() + store.create_tables() assert store.owner_id_column_name == "tenant_id" assert store.owner_id_column_ddl == "tenant_id INTEGER NOT NULL REFERENCES tenants(id)" - session = await store.create_session( + session = store.create_session( session_id="session-tenant-1", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=1 ) @@ -238,7 +244,7 @@ async def test_owner_id_column_with_integer(tmp_path: Path) -> None: db_path.unlink() -async def test_owner_id_column_with_ubigint(tmp_path: Path) -> None: +def test_owner_id_column_with_ubigint(tmp_path: Path) -> None: """Test owner ID column with DuckDB UBIGINT type.""" db_path = tmp_path / "test_owner_id_ubigint.duckdb" try: @@ -260,11 +266,11 @@ async def test_owner_id_column_with_ubigint(tmp_path: Path) -> None: }, ) store = DuckdbADKStore(config_with_extension) - await store.create_tables() + store.create_tables() assert store.owner_id_column_name == "owner_id" - session = await store.create_session( + session = store.create_session( session_id="session-user-1", app_name="test-app", user_id="user-001", @@ -284,7 +290,7 @@ async def test_owner_id_column_with_ubigint(tmp_path: Path) -> None: db_path.unlink() -async def test_owner_id_column_foreign_key_constraint(tmp_path: Path) -> None: +def test_owner_id_column_foreign_key_constraint(tmp_path: Path) -> None: """Test that FK constraint is enforced.""" db_path = tmp_path / "test_owner_id_constraint.duckdb" try: @@ -306,14 +312,14 @@ async def test_owner_id_column_foreign_key_constraint(tmp_path: Path) -> None: }, ) store = DuckdbADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - await store.create_session( + store.create_session( session_id="session-org-1", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=100 ) with pytest.raises(Exception) as exc_info: - await store.create_session( + store.create_session( session_id="session-org-invalid", app_name="test-app", user_id="user-002", @@ -327,7 +333,7 @@ async def test_owner_id_column_foreign_key_constraint(tmp_path: Path) -> None: db_path.unlink() -async def test_owner_id_column_without_value(tmp_path: Path) -> None: +def test_owner_id_column_without_value(tmp_path: Path) -> None: """Test creating session without owner_id when column is configured but nullable.""" db_path = tmp_path / "test_owner_id_nullable.duckdb" try: @@ -348,22 +354,22 @@ async def test_owner_id_column_without_value(tmp_path: Path) -> None: }, ) store = DuckdbADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - session = await store.create_session( + session = store.create_session( session_id="session-no-fk", app_name="test-app", user_id="user-001", state={"data": "test"}, owner_id=None ) assert session["id"] == "session-no-fk" - retrieved = await store.get_session("session-no-fk") + retrieved = store.get_session("test-app", "user-001", "session-no-fk") assert retrieved is not None finally: if db_path.exists(): db_path.unlink() -async def test_owner_id_column_with_varchar(tmp_path: Path) -> None: +def test_owner_id_column_with_varchar(tmp_path: Path) -> None: """Test owner ID column with VARCHAR type.""" db_path = tmp_path / "test_owner_id_varchar.duckdb" try: @@ -385,9 +391,9 @@ async def test_owner_id_column_with_varchar(tmp_path: Path) -> None: }, ) store = DuckdbADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - session = await store.create_session( + session = store.create_session( session_id="session-company-1", app_name="test-app", user_id="user-001", @@ -407,7 +413,7 @@ async def test_owner_id_column_with_varchar(tmp_path: Path) -> None: db_path.unlink() -async def test_owner_id_column_multiple_sessions(tmp_path: Path) -> None: +def test_owner_id_column_multiple_sessions(tmp_path: Path) -> None: """Test multiple sessions with same FK value.""" db_path = tmp_path / "test_owner_id_multiple.duckdb" try: @@ -429,10 +435,10 @@ async def test_owner_id_column_multiple_sessions(tmp_path: Path) -> None: }, ) store = DuckdbADKStore(config_with_extension) - await store.create_tables() + store.create_tables() for i in range(5): - await store.create_session( + store.create_session( session_id=f"session-dept-{i}", app_name="test-app", user_id=f"user-{i}", @@ -450,7 +456,7 @@ async def test_owner_id_column_multiple_sessions(tmp_path: Path) -> None: db_path.unlink() -async def test_owner_id_column_query_by_fk(tmp_path: Path) -> None: +def test_owner_id_column_query_by_fk(tmp_path: Path) -> None: """Test querying sessions by FK column value.""" db_path = tmp_path / "test_owner_id_query.duckdb" try: @@ -472,11 +478,11 @@ async def test_owner_id_column_query_by_fk(tmp_path: Path) -> None: }, ) store = DuckdbADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - await store.create_session("s1", "app", "u1", {"val": 1}, owner_id=1) - await store.create_session("s2", "app", "u2", {"val": 2}, owner_id=1) - await store.create_session("s3", "app", "u3", {"val": 3}, owner_id=2) + store.create_session("s1", "app", "u1", {"val": 1}, owner_id=1) + store.create_session("s2", "app", "u2", {"val": 2}, owner_id=1) + store.create_session("s3", "app", "u3", {"val": 3}, owner_id=2) with config.provide_connection() as conn: cursor = conn.execute("SELECT id FROM sessions_with_project WHERE project_id = ? ORDER BY id", (1,)) diff --git a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py index be4108bb7..f4fb61475 100644 --- a/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py +++ b/tests/integration/adapters/mysqlconnector/extensions/adk/test_store.py @@ -53,12 +53,10 @@ async def test_storage_types_verification(mysqlconnector_adk_store: MysqlConnect event_columns = await cursor.fetchall() event_col_names = [col[0] for col in event_columns] - # New 5-column schema: session_id, invocation_id, author, timestamp, event_json assert "session_id" in event_col_names assert "invocation_id" in event_col_names - assert "author" in event_col_names assert "timestamp" in event_col_names - assert "event_json" in event_col_names + assert "event_data" in event_col_names timestamp_col = next(col for col in event_columns if col[0] == "timestamp") assert "timestamp(6)" in cast("str", timestamp_col[2]).lower() @@ -78,15 +76,17 @@ async def test_timestamp_precision(mysqlconnector_adk_store: MysqlConnectorAsync event_time = datetime.now(timezone.utc) event: EventRecord = { + "id": "event-micro", + "app_name": app_name, + "user_id": user_id, "session_id": session_id, "invocation_id": "inv-micro", - "author": "system", "timestamp": event_time, - "event_json": {"app_name": app_name}, + "event_data": {"app_name": app_name, "author": "system"}, } await mysqlconnector_adk_store.append_event(event) - events = await mysqlconnector_adk_store.get_events(session_id) + events = await mysqlconnector_adk_store.get_events(app_name, user_id, session_id) assert len(events) == 1 assert hasattr(events[0]["timestamp"], "microsecond") @@ -127,7 +127,7 @@ async def test_owner_id_constraint_enforcement(mysqlconnector_adk_store_with_fk: session_id=session_id, app_name=app_name, user_id=user_id, state={"tenant": "one"}, owner_id=1 ) - session = await mysqlconnector_adk_store_with_fk.get_session(session_id) + session = await mysqlconnector_adk_store_with_fk.get_session(app_name, user_id, session_id) assert session is not None with pytest.raises(Exception): @@ -144,7 +144,7 @@ async def test_owner_id_cascade_delete(mysqlconnector_adk_store_with_fk: MysqlCo session_id="tenant1-session", app_name="test-app", user_id="user1", state={"data": "test"}, owner_id=1 ) - session_before = await mysqlconnector_adk_store_with_fk.get_session("tenant1-session") + session_before = await mysqlconnector_adk_store_with_fk.get_session("test-app", "user1", "tenant1-session") assert session_before is not None async with config.provide_connection() as conn: @@ -155,7 +155,7 @@ async def test_owner_id_cascade_delete(mysqlconnector_adk_store_with_fk: MysqlCo finally: await cursor.close() - session_after = await mysqlconnector_adk_store_with_fk.get_session("tenant1-session") + session_after = await mysqlconnector_adk_store_with_fk.get_session("test-app", "user1", "tenant1-session") assert session_after is None diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py index 5a32a911b..3a01c215a 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_inmemory.py @@ -35,7 +35,7 @@ async def test_inmemory_enabled_creates_sessions_table_with_inmemory_async( """ SELECT inmemory, inmemory_priority, inmemory_distribute FROM user_tables - WHERE table_name = 'ADK_SESSIONS' + WHERE table_name = 'ADK_SESSION' """ ) row = await cursor.fetchone() @@ -77,7 +77,7 @@ async def test_inmemory_enabled_creates_events_table_with_inmemory_async( """ SELECT inmemory, inmemory_priority, inmemory_distribute FROM user_tables - WHERE table_name = 'ADK_EVENTS' + WHERE table_name = 'ADK_EVENT' """ ) row = await cursor.fetchone() @@ -115,7 +115,7 @@ async def test_inmemory_disabled_creates_tables_without_inmemory_async(oracle_as """ SELECT inmemory, inmemory_priority, inmemory_distribute FROM user_tables - WHERE table_name IN ('ADK_SESSIONS', 'ADK_EVENTS') + WHERE table_name IN ('ADK_SESSION', 'ADK_EVENT') ORDER BY table_name """ ) @@ -153,7 +153,7 @@ async def test_inmemory_default_disabled_async(oracle_async_config: OracleAsyncC """ SELECT inmemory FROM user_tables - WHERE table_name = 'ADK_SESSIONS' + WHERE table_name = 'ADK_SESSION' """ ) row = await cursor.fetchone() @@ -214,7 +214,7 @@ async def test_inmemory_with_owner_id_column_async(oracle_async_config: OracleAs SELECT inmemory, column_name FROM user_tables t LEFT JOIN user_tab_columns c ON t.table_name = c.table_name - WHERE t.table_name = 'ADK_SESSIONS' AND (c.column_name = 'OWNER_ID' OR c.column_name IS NULL) + WHERE t.table_name = 'ADK_SESSION' AND (c.column_name = 'OWNER_ID' OR c.column_name IS NULL) """ ) rows = await cursor.fetchall() @@ -279,14 +279,14 @@ async def test_inmemory_tables_functional_async(oracle_async_config: OracleAsync assert session["id"] == session_id assert session["state"] == state - retrieved = await store.get_session(session_id) + retrieved = await store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["state"] == state updated_state = {"data": "updated", "count": 100} - await store.update_session_state(session_id, updated_state) + await store.update_session_state(app_name, user_id, session_id, updated_state) - retrieved_updated = await store.get_session(session_id) + retrieved_updated = await store.get_session(app_name, user_id, session_id) assert retrieved_updated is not None assert retrieved_updated["state"] == updated_state @@ -309,7 +309,7 @@ async def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> No ) store = OracleSyncADKStore(config) - await store.create_tables() + store.create_tables() try: with config.provide_connection() as conn: @@ -318,7 +318,7 @@ async def test_inmemory_enabled_sync(oracle_sync_config: OracleSyncConfig) -> No """ SELECT inmemory, inmemory_priority FROM user_tables - WHERE table_name IN ('ADK_SESSIONS', 'ADK_EVENTS') + WHERE table_name IN ('ADK_SESSION', 'ADK_EVENT') ORDER BY table_name """ ) @@ -350,7 +350,7 @@ async def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> N ) store = OracleSyncADKStore(config) - await store.create_tables() + store.create_tables() try: with config.provide_connection() as conn: @@ -359,7 +359,7 @@ async def test_inmemory_disabled_sync(oracle_sync_config: OracleSyncConfig) -> N """ SELECT inmemory, inmemory_priority FROM user_tables - WHERE table_name IN ('ADK_SESSIONS', 'ADK_EVENTS') + WHERE table_name IN ('ADK_SESSION', 'ADK_EVENT') """ ) rows = cursor.fetchall() @@ -389,7 +389,7 @@ async def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncCon ) store = OracleSyncADKStore(config) - await store.create_tables() + store.create_tables() try: session_id = "inmemory-sync-session" @@ -397,11 +397,11 @@ async def test_inmemory_tables_functional_sync(oracle_sync_config: OracleSyncCon user_id = "user-456" state = {"sync": True, "value": 99} - session = await store.create_session(session_id, app_name, user_id, state) + session = store.create_session(session_id, app_name, user_id, state) assert session["id"] == session_id assert session["state"] == state - retrieved = await store.get_session(session_id) + retrieved = store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["state"] == state diff --git a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py index 1e6cb1f94..3cd01057b 100644 --- a/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py +++ b/tests/integration/adapters/oracledb/extensions/adk/test_oracle_specific.py @@ -21,6 +21,21 @@ def _unique_session_id(prefix: str) -> str: return f"{prefix}-{uuid4().hex}" +def _event_record( + *, event_id: str, app_name: str, user_id: str, session_id: str, invocation_id: str, event_data: dict[str, Any] +) -> EventRecord: + """Return a clean-break EventRecord for Oracle store tests.""" + return EventRecord( + id=event_id, + app_name=app_name, + user_id=user_id, + session_id=session_id, + invocation_id=invocation_id, + timestamp=datetime.now(timezone.utc), + event_data=event_data, + ) + + def _drop_table_statements(store: object) -> "list[str]": """Return drop table statements for ADK stores.""" dropper = cast("Any", getattr(store, "_get_drop_tables_sql")) @@ -66,7 +81,7 @@ async def oracle_async_store(oracle_async_config: "OracleAsyncConfig") -> "Async async def oracle_sync_store(oracle_sync_config: "OracleSyncConfig") -> "AsyncGenerator[OracleSyncADKStore, None]": """Create a sync Oracle ADK store with tables created per test.""" store = OracleSyncADKStore(oracle_sync_config) - await store.create_tables() + store.create_tables() try: yield store finally: @@ -199,7 +214,7 @@ async def oracle_store_sync_with_fk( ) store = OracleSyncADKStore(config_with_extension) _cleanup_sync_store(store, config_with_extension) - await store.create_tables() + store.create_tables() try: yield store finally: @@ -216,14 +231,14 @@ async def test_state_lob_deserialization(oracle_async_store: "OracleAsyncADKStor session = await oracle_async_store.create_session(session_id, app_name, user_id, state) assert session["state"] == state - retrieved = await oracle_async_store.get_session(session_id) + retrieved = await oracle_async_store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["state"] == state assert retrieved["state"]["large_field"] == "x" * 10000 -async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event_json LOB data is correctly deserialized.""" +async def test_event_data_lob_deserialization(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event_data LOB data is correctly deserialized.""" session_id = _unique_session_id("event-lob") app_name = "test-app" user_id = "user-123" @@ -239,30 +254,31 @@ async def test_event_json_lob_deserialization(oracle_async_store: "OracleAsyncAD "custom_metadata": {"tags": ["tag1", "tag2"], "priority": "high"}, } - event_record: EventRecord = { - "session_id": session_id, - "invocation_id": "", - "author": "assistant", - "timestamp": datetime.now(timezone.utc), - "event_json": event_data, - } + event_record = _event_record( + event_id="event-lob", + app_name=app_name, + user_id=user_id, + session_id=session_id, + invocation_id="", + event_data={**event_data, "author": "assistant"}, + ) await oracle_async_store.append_event(event_record) - events = await oracle_async_store.get_events(session_id) + events = await oracle_async_store.get_events(app_name, user_id, session_id) assert len(events) == 1 - # event_json contains all the data + # event_data contains all the data retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) assert retrieved_data["content"] == content assert retrieved_data["grounding_metadata"] == {"sources": ["a" * 1000, "b" * 1000]} assert retrieved_data["custom_metadata"] == {"tags": ["tag1", "tag2"], "priority": "high"} -async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event_json blob is correctly stored and retrieved.""" - session_id = _unique_session_id("event-json") +async def test_event_data_storage(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test event_data blob is correctly stored and retrieved.""" + session_id = _unique_session_id("event-data") app_name = "test-app" user_id = "user-123" @@ -270,22 +286,23 @@ async def test_event_json_storage(oracle_async_store: "OracleAsyncADKStore") -> event_data = {"function": "test_func", "args": {"param": "value"}, "result": 42} - event_record: EventRecord = { - "session_id": session_id, - "invocation_id": "", - "author": "user", - "timestamp": datetime.now(timezone.utc), - "event_json": event_data, - } + event_record = _event_record( + event_id="event-data", + app_name=app_name, + user_id=user_id, + session_id=session_id, + invocation_id="", + event_data={**event_data, "author": "user"}, + ) await oracle_async_store.append_event(event_record) - events = await oracle_async_store.get_events(session_id) + events = await oracle_async_store.get_events(app_name, user_id, session_id) assert len(events) == 1 retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) - assert retrieved_data == event_data + assert retrieved_data == {**event_data, "author": "user"} async def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKStore") -> None: @@ -295,65 +312,76 @@ async def test_state_lob_deserialization_sync(oracle_sync_store: "OracleSyncADKS user_id = "user-123" state = {"large_field": "y" * 10000, "nested": {"data": [4, 5, 6]}} - session = await oracle_sync_store.create_session(session_id, app_name, user_id, state) + session = oracle_sync_store.create_session(session_id, app_name, user_id, state) assert session["state"] == state - retrieved = await oracle_sync_store.get_session(session_id) + retrieved = oracle_sync_store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["state"] == state -async def test_event_record_5_column_contract(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test the new 5-column EventRecord contract with append_event.""" - session_id = _unique_session_id("5col-session") +async def test_event_record_clean_break_contract(oracle_async_store: "OracleAsyncADKStore") -> None: + """Test the clean-break EventRecord contract with append_event.""" + session_id = _unique_session_id("event-contract") app_name = "test-app" user_id = "user-123" await oracle_async_store.create_session(session_id, app_name, user_id, {}) - event_record: EventRecord = { - "session_id": session_id, - "invocation_id": "inv-001", - "author": "assistant", - "timestamp": datetime.now(timezone.utc), - "event_json": {"content": {"text": "Hello"}, "partial": True, "turn_complete": False, "interrupted": True}, - } + event_record = _event_record( + event_id="event-contract", + app_name=app_name, + user_id=user_id, + session_id=session_id, + invocation_id="inv-001", + event_data={ + "author": "assistant", + "content": {"text": "Hello"}, + "partial": True, + "turn_complete": False, + "interrupted": True, + }, + ) await oracle_async_store.append_event(event_record) - events = await oracle_async_store.get_events(session_id) + events = await oracle_async_store.get_events(app_name, user_id, session_id) assert len(events) == 1 + assert events[0]["id"] == "event-contract" + assert events[0]["app_name"] == app_name + assert events[0]["user_id"] == user_id assert events[0]["session_id"] == session_id assert events[0]["invocation_id"] == "inv-001" - assert events[0]["author"] == "assistant" retrieved_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) + assert retrieved_data["author"] == "assistant" assert retrieved_data["partial"] is True assert retrieved_data["turn_complete"] is False assert retrieved_data["interrupted"] is True async def test_event_with_none_values(oracle_async_store: "OracleAsyncADKStore") -> None: - """Test event with minimal event_json content.""" + """Test event with minimal event_data content.""" session_id = _unique_session_id("none-session") app_name = "test-app" user_id = "user-123" await oracle_async_store.create_session(session_id, app_name, user_id, {}) - event_record: EventRecord = { - "session_id": session_id, - "invocation_id": "", - "author": "user", - "timestamp": datetime.now(timezone.utc), - "event_json": {"app_name": app_name}, - } + event_record = _event_record( + event_id="event-none-values", + app_name=app_name, + user_id=user_id, + session_id=session_id, + invocation_id="", + event_data={"app_name": app_name, "author": "user"}, + ) await oracle_async_store.append_event(event_record) - events = await oracle_async_store.get_events(session_id) + events = await oracle_async_store.get_events(app_name, user_id, session_id) assert len(events) == 1 @@ -436,7 +464,7 @@ async def test_json_fields_stored_and_retrieved(oracle_async_store: "OracleAsync session = await oracle_async_store.create_session(session_id, app_name, user_id, state) assert session["state"] == state - retrieved = await oracle_async_store.get_session(session_id) + retrieved = await oracle_async_store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["state"] == state assert retrieved["state"]["complex"]["unicode"] == "\u65e5\u672c\u8a9e\u30c6\u30b9\u30c8" @@ -450,10 +478,10 @@ async def test_create_session_with_owner_id_sync(oracle_store_sync_with_fk: "Ora state = {"data": "sync test"} owner_id = 100 - session = await oracle_store_sync_with_fk.create_session(session_id, app_name, user_id, state, owner_id=owner_id) + session = oracle_store_sync_with_fk.create_session(session_id, app_name, user_id, state, owner_id=owner_id) assert session["id"] == session_id assert session["state"] == state - retrieved = await oracle_store_sync_with_fk.get_session(session_id) + retrieved = oracle_store_sync_with_fk.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["id"] == session_id diff --git a/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py index 79970ab23..803b729e0 100644 --- a/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/psycopg/extensions/adk/test_owner_id_column.py @@ -57,7 +57,7 @@ async def psycopg_sync_store_with_fk(postgres_service: "PostgresService") -> "As }, ) store = PsycopgSyncADKStore(config) - await store.create_tables() + store.create_tables() yield store with config.provide_connection() as conn, conn.cursor() as cur: @@ -174,7 +174,7 @@ async def test_async_ddl_includes_owner_id_column(psycopg_async_store_with_fk: P async def test_sync_ddl_includes_owner_id_column(psycopg_sync_store_with_fk: PsycopgSyncADKStore) -> None: """Test that the DDL generation includes the owner_id_column.""" - ddl = await psycopg_sync_store_with_fk._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] + ddl = psycopg_sync_store_with_fk._get_create_sessions_table_sql() # pyright: ignore[reportPrivateUsage] assert "account_id VARCHAR(64) NOT NULL" in ddl assert "test_sessions_sync_fk" in ddl diff --git a/tests/integration/adapters/spanner/extensions/adk/conftest.py b/tests/integration/adapters/spanner/extensions/adk/conftest.py index 1b697c64a..9dc5cf6e6 100644 --- a/tests/integration/adapters/spanner/extensions/adk/conftest.py +++ b/tests/integration/adapters/spanner/extensions/adk/conftest.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncGenerator +from collections.abc import Generator from typing import TYPE_CHECKING import pytest @@ -29,7 +29,7 @@ def spanner_adk_config(spanner_service: SpannerService, spanner_database: "Datab @pytest.fixture -async def spanner_adk_store(spanner_adk_config: SpannerSyncConfig) -> AsyncGenerator[SpannerSyncADKStore, None]: +def spanner_adk_store(spanner_adk_config: SpannerSyncConfig) -> Generator[SpannerSyncADKStore, None, None]: store = SpannerSyncADKStore(spanner_adk_config) - await store.create_tables() + store.create_tables() yield store diff --git a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py index b7cca39f2..e47095cd6 100644 --- a/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py +++ b/tests/integration/adapters/spanner/extensions/adk/test_adk_store.py @@ -11,85 +11,88 @@ pytestmark = [pytest.mark.spanner, pytest.mark.integration] -async def test_create_and_get_session(spanner_adk_store: Any) -> None: +def test_create_and_get_session(spanner_adk_store: Any) -> None: session_id = "session-create" - await spanner_adk_store.delete_session(session_id) - created = await spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) + spanner_adk_store.delete_session("app", "user", session_id) + created = spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) assert created["id"] == session_id - fetched = await spanner_adk_store.get_session(session_id) + fetched = spanner_adk_store.get_session("app", "user", session_id) assert fetched is not None assert fetched["state"] == {"a": 1} -async def test_update_session_state(spanner_adk_store: Any) -> None: +def test_update_session_state(spanner_adk_store: Any) -> None: session_id = "session-update" - await spanner_adk_store.delete_session(session_id) - await spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) + spanner_adk_store.delete_session("app", "user", session_id) + spanner_adk_store.create_session(session_id, "app", "user", {"a": 1}) - await spanner_adk_store.update_session_state(session_id, {"a": 2, "b": True}) + spanner_adk_store.update_session_state("app", "user", session_id, {"a": 2, "b": True}) - fetched = await spanner_adk_store.get_session(session_id) + fetched = spanner_adk_store.get_session("app", "user", session_id) assert fetched is not None assert fetched["state"] == {"a": 2, "b": True} -async def test_list_sessions(spanner_adk_store: Any) -> None: - await spanner_adk_store.delete_session("session-list-1") - await spanner_adk_store.delete_session("session-list-2") - await spanner_adk_store.delete_session("session-list-3") - await spanner_adk_store.create_session("session-list-1", "app-list", "user1", {"v": 1}) - await spanner_adk_store.create_session("session-list-2", "app-list", "user1", {"v": 2}) - await spanner_adk_store.create_session("session-list-3", "app-list", "user2", {"v": 3}) +def test_list_sessions(spanner_adk_store: Any) -> None: + spanner_adk_store.delete_session("app-list", "user1", "session-list-1") + spanner_adk_store.delete_session("app-list", "user1", "session-list-2") + spanner_adk_store.delete_session("app-list", "user2", "session-list-3") + spanner_adk_store.create_session("session-list-1", "app-list", "user1", {"v": 1}) + spanner_adk_store.create_session("session-list-2", "app-list", "user1", {"v": 2}) + spanner_adk_store.create_session("session-list-3", "app-list", "user2", {"v": 3}) - sessions = await spanner_adk_store.list_sessions("app-list", "user1") + sessions = spanner_adk_store.list_sessions("app-list", "user1") session_ids = {s["id"] for s in sessions} assert session_ids == {"session-list-1", "session-list-2"} -async def test_delete_session(spanner_adk_store: Any) -> None: +def test_delete_session(spanner_adk_store: Any) -> None: session_id = "session-delete" - await spanner_adk_store.delete_session(session_id) - await spanner_adk_store.create_session(session_id, "app", "user", {"k": "v"}) - await spanner_adk_store.delete_session(session_id) + spanner_adk_store.delete_session("app", "user", session_id) + spanner_adk_store.create_session(session_id, "app", "user", {"k": "v"}) + spanner_adk_store.delete_session("app", "user", session_id) - assert await spanner_adk_store.get_session(session_id) is None + assert spanner_adk_store.get_session("app", "user", session_id) is None -async def test_create_and_list_events(spanner_adk_store: Any) -> None: +def test_create_and_list_events(spanner_adk_store: Any) -> None: session_id = "session-events" - await spanner_adk_store.delete_session(session_id) - await spanner_adk_store.create_session(session_id, "app", "user", {"x": 1}) + spanner_adk_store.delete_session("app", "user", session_id) + spanner_adk_store.create_session(session_id, "app", "user", {"x": 1}) event_one: EventRecord = { + "id": "event-1", + "app_name": "app", + "user_id": "user", "session_id": session_id, "invocation_id": "event-1", - "author": "user", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1", "content": {"msg": "hi"}, "app_name": "app", "user_id": "user"}, + "event_data": {"id": "event-1", "author": "user", "content": {"msg": "hi"}}, } event_two: EventRecord = { + "id": "event-2", + "app_name": "app", + "user_id": "user", "session_id": session_id, "invocation_id": "event-2", - "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-2", "content": {"msg": "ok"}, "app_name": "app", "user_id": "user"}, + "event_data": {"id": "event-2", "author": "assistant", "content": {"msg": "ok"}}, } - await spanner_adk_store.append_event(event_one) - await spanner_adk_store.append_event(event_two) + spanner_adk_store.append_event(event_one) + spanner_adk_store.append_event(event_two) - events = await spanner_adk_store.get_events(session_id) + events = spanner_adk_store.get_events("app", "user", session_id) assert len(events) == 2 - assert events[0]["author"] == "user" - assert events[1]["author"] == "assistant" - # Content is inside event_json in the new 5-column schema event0_data = ( - json.loads(events[0]["event_json"]) if isinstance(events[0]["event_json"], str) else events[0]["event_json"] + json.loads(events[0]["event_data"]) if isinstance(events[0]["event_data"], str) else events[0]["event_data"] ) event1_data = ( - json.loads(events[1]["event_json"]) if isinstance(events[1]["event_json"], str) else events[1]["event_json"] + json.loads(events[1]["event_data"]) if isinstance(events[1]["event_data"], str) else events[1]["event_data"] ) + assert event0_data["author"] == "user" + assert event1_data["author"] == "assistant" assert event0_data["content"] == {"msg": "hi"} assert event1_data["content"] == {"msg": "ok"} diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py index 492c4a18c..cf271c192 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_memory_store.py @@ -30,56 +30,56 @@ def _build_record(*, session_id: str, event_id: str, content_text: str, inserted ) -async def test_sqlite_memory_store_insert_search_dedup() -> None: +def test_sqlite_memory_store_insert_search_dedup() -> None: """Insert memory entries, search by text, and skip duplicates.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - await store.create_tables() + store.create_tables() now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="latte", inserted_at=now) - inserted = await store.insert_memory_entries([record1, record2]) + inserted = store.insert_memory_entries([record1, record2]) assert inserted == 2 - results = await store.search_entries(query="espresso", app_name="app", user_id="user") + results = store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-1" - deduped = await store.insert_memory_entries([record1]) + deduped = store.insert_memory_entries([record1]) assert deduped == 0 -async def test_sqlite_memory_store_fts_search() -> None: +def test_sqlite_memory_store_fts_search() -> None: """FTS-enabled memory stores search through the FTS5 virtual table.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig( connection_config={"database": tmp.name}, extension_config={"adk": {"memory_use_fts": True}} ) store = SqliteADKMemoryStore(config) - await store.create_tables() + store.create_tables() now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-fts-1", content_text="espresso roast", inserted_at=now) record2 = _build_record(session_id="s1", event_id="evt-fts-2", content_text="latte foam", inserted_at=now) - await store.insert_memory_entries([record1, record2]) + store.insert_memory_entries([record1, record2]) - results = await store.search_entries(query="espresso", app_name="app", user_id="user") + results = store.search_entries(query="espresso", app_name="app", user_id="user") assert len(results) == 1 assert results[0]["event_id"] == "evt-fts-1" -async def test_sqlite_memory_store_disabled_lifecycle() -> None: +def test_sqlite_memory_store_disabled_lifecycle() -> None: """Disabled memory stores skip table creation and reject memory operations.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig( connection_config={"database": tmp.name}, extension_config={"adk": {"enable_memory": False}} ) store = SqliteADKMemoryStore(config) - await store.create_tables() + store.create_tables() with config.provide_connection() as conn: cursor = conn.execute( @@ -92,47 +92,47 @@ async def test_sqlite_memory_store_disabled_lifecycle() -> None: now = datetime.now(timezone.utc) record = _build_record(session_id="s1", event_id="evt-disabled", content_text="espresso", inserted_at=now) with pytest.raises(RuntimeError, match="Memory store is disabled"): - await store.insert_memory_entries([record]) + store.insert_memory_entries([record]) with pytest.raises(RuntimeError, match="Memory store is disabled"): - await store.search_entries(query="espresso", app_name="app", user_id="user") + store.search_entries(query="espresso", app_name="app", user_id="user") -async def test_sqlite_memory_store_delete_by_session() -> None: +def test_sqlite_memory_store_delete_by_session() -> None: """Delete memory entries by session id.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - await store.create_tables() + store.create_tables() now = datetime.now(timezone.utc) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="espresso", inserted_at=now) record2 = _build_record(session_id="s2", event_id="evt-2", content_text="latte", inserted_at=now) - await store.insert_memory_entries([record1, record2]) + store.insert_memory_entries([record1, record2]) - deleted = await store.delete_entries_by_session("s1") + deleted = store.delete_entries_by_session("s1") assert deleted == 1 - remaining = await store.search_entries(query="latte", app_name="app", user_id="user") + remaining = store.search_entries(query="latte", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["session_id"] == "s2" -async def test_sqlite_memory_store_delete_older_than() -> None: +def test_sqlite_memory_store_delete_older_than() -> None: """Delete memory entries older than a cutoff.""" with tempfile.NamedTemporaryFile(suffix=".db") as tmp: config = SqliteConfig(connection_config={"database": tmp.name}) store = SqliteADKMemoryStore(config) - await store.create_tables() + store.create_tables() now = datetime.now(timezone.utc) old = now - timedelta(days=40) record1 = _build_record(session_id="s1", event_id="evt-1", content_text="old", inserted_at=old) record2 = _build_record(session_id="s1", event_id="evt-2", content_text="new", inserted_at=now) - await store.insert_memory_entries([record1, record2]) + store.insert_memory_entries([record1, record2]) - deleted = await store.delete_entries_older_than(30) + deleted = store.delete_entries_older_than(30) assert deleted == 1 - remaining = await store.search_entries(query="new", app_name="app", user_id="user") + remaining = store.search_entries(query="new", app_name="app", user_id="user") assert len(remaining) == 1 assert remaining[0]["event_id"] == "evt-2" diff --git a/tests/integration/adapters/sqlite/extensions/adk/test_owner_id_column.py b/tests/integration/adapters/sqlite/extensions/adk/test_owner_id_column.py index 0d69a5f61..e7e9350dd 100644 --- a/tests/integration/adapters/sqlite/extensions/adk/test_owner_id_column.py +++ b/tests/integration/adapters/sqlite/extensions/adk/test_owner_id_column.py @@ -101,7 +101,7 @@ def initial_state() -> "dict[str, Any]": return {"key": "value", "count": 0} -async def test_owner_id_column_integer_reference( +def test_owner_id_column_integer_reference( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test owner ID column with INTEGER foreign key.""" @@ -115,9 +115,9 @@ async def test_owner_id_column_integer_reference( }, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - session = await store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) + session = store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) assert session["id"] == session_id assert session["app_name"] == app_name @@ -126,13 +126,13 @@ async def test_owner_id_column_integer_reference( assert isinstance(session["create_time"], datetime) assert isinstance(session["update_time"], datetime) - retrieved = await store.get_session(session_id) + retrieved = store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["id"] == session_id assert retrieved["state"] == initial_state -async def test_owner_id_column_text_reference( +def test_owner_id_column_text_reference( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test owner ID column with TEXT foreign key.""" @@ -145,19 +145,19 @@ async def test_owner_id_column_text_reference( extension_config={"adk": {"owner_id_column": "user_ref TEXT REFERENCES users(username) ON DELETE CASCADE"}}, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - session = await store.create_session(session_id, app_name, user_id, initial_state, owner_id=username) + session = store.create_session(session_id, app_name, user_id, initial_state, owner_id=username) assert session["id"] == session_id assert session["state"] == initial_state - retrieved = await store.get_session(session_id) + retrieved = store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["id"] == session_id -async def test_owner_id_column_cascade_delete( +def test_owner_id_column_cascade_delete( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test CASCADE DELETE on owner ID column.""" @@ -171,11 +171,11 @@ async def test_owner_id_column_cascade_delete( }, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - await store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) + store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) - retrieved_before = await store.get_session(session_id) + retrieved_before = store.get_session(app_name, user_id, session_id) assert retrieved_before is not None with sqlite_config.provide_connection() as conn: @@ -183,11 +183,11 @@ async def test_owner_id_column_cascade_delete( conn.execute("DELETE FROM tenants WHERE id = ?", (tenant_id,)) conn.commit() - retrieved_after = await store.get_session(session_id) + retrieved_after = store.get_session(app_name, user_id, session_id) assert retrieved_after is None -async def test_owner_id_column_constraint_violation( +def test_owner_id_column_constraint_violation( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test FK constraint violation with invalid tenant_id.""" @@ -198,17 +198,17 @@ async def test_owner_id_column_constraint_violation( extension_config={"adk": {"owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id)"}}, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() invalid_tenant_id = 99999 with pytest.raises(Exception) as exc_info: - await store.create_session(session_id, app_name, user_id, initial_state, owner_id=invalid_tenant_id) + store.create_session(session_id, app_name, user_id, initial_state, owner_id=invalid_tenant_id) assert "FOREIGN KEY constraint failed" in str(exc_info.value) or "constraint" in str(exc_info.value).lower() -async def test_owner_id_column_not_null_constraint( +def test_owner_id_column_not_null_constraint( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test NOT NULL constraint on owner ID column.""" @@ -219,15 +219,15 @@ async def test_owner_id_column_not_null_constraint( extension_config={"adk": {"owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id)"}}, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() with pytest.raises(Exception) as exc_info: - await store.create_session(session_id, app_name, user_id, initial_state, owner_id=None) + store.create_session(session_id, app_name, user_id, initial_state, owner_id=None) assert "NOT NULL constraint failed" in str(exc_info.value) or "not null" in str(exc_info.value).lower() -async def test_owner_id_column_nullable( +def test_owner_id_column_nullable( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test nullable owner ID column.""" @@ -239,33 +239,33 @@ async def test_owner_id_column_nullable( extension_config={"adk": {"owner_id_column": "tenant_id INTEGER REFERENCES tenants(id)"}}, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - session_without_fk = await store.create_session(str(uuid.uuid4()), app_name, user_id, initial_state, owner_id=None) + session_without_fk = store.create_session(str(uuid.uuid4()), app_name, user_id, initial_state, owner_id=None) assert session_without_fk is not None - session_with_fk = await store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) + session_with_fk = store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) assert session_with_fk is not None -async def test_without_owner_id_column( +def test_without_owner_id_column( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test store without owner ID column configured.""" store = SqliteADKStore(sqlite_config) - await store.create_tables() + store.create_tables() - session = await store.create_session(session_id, app_name, user_id, initial_state) + session = store.create_session(session_id, app_name, user_id, initial_state) assert session["id"] == session_id assert session["state"] == initial_state - retrieved = await store.get_session(session_id) + retrieved = store.get_session(app_name, user_id, session_id) assert retrieved is not None assert retrieved["id"] == session_id -async def test_foreign_keys_pragma_enabled( +def test_foreign_keys_pragma_enabled( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test that PRAGMA foreign_keys = ON is properly enabled.""" @@ -277,9 +277,9 @@ async def test_foreign_keys_pragma_enabled( extension_config={"adk": {"owner_id_column": "tenant_id INTEGER NOT NULL REFERENCES tenants(id)"}}, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - await store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) + store.create_session(session_id, app_name, user_id, initial_state, owner_id=tenant_id) with sqlite_config.provide_connection() as conn: cursor = conn.execute("PRAGMA foreign_keys") @@ -287,7 +287,7 @@ async def test_foreign_keys_pragma_enabled( assert fk_enabled == 1 -async def test_multi_tenant_isolation( +def test_multi_tenant_isolation( sqlite_config: SqliteConfig, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test multi-tenant isolation with different tenant IDs.""" @@ -302,16 +302,16 @@ async def test_multi_tenant_isolation( }, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() session1_id = str(uuid.uuid4()) session2_id = str(uuid.uuid4()) - await store.create_session(session1_id, app_name, user_id, initial_state, owner_id=tenant1_id) - await store.create_session(session2_id, app_name, user_id, {"data": "tenant2"}, owner_id=tenant2_id) + store.create_session(session1_id, app_name, user_id, initial_state, owner_id=tenant1_id) + store.create_session(session2_id, app_name, user_id, {"data": "tenant2"}, owner_id=tenant2_id) - session1 = await store.get_session(session1_id) - session2 = await store.get_session(session2_id) + session1 = store.get_session(app_name, user_id, session1_id) + session2 = store.get_session(app_name, user_id, session2_id) assert session1 is not None assert session2 is not None @@ -323,14 +323,14 @@ async def test_multi_tenant_isolation( conn.execute("DELETE FROM tenants WHERE id = ?", (tenant1_id,)) conn.commit() - session1_after = await store.get_session(session1_id) - session2_after = await store.get_session(session2_id) + session1_after = store.get_session(app_name, user_id, session1_id) + session2_after = store.get_session(app_name, user_id, session2_id) assert session1_after is None assert session2_after is not None -async def test_owner_id_column_ddl_extraction(sqlite_config: SqliteConfig) -> None: +def test_owner_id_column_ddl_extraction(sqlite_config: SqliteConfig) -> None: """Test that column name is correctly extracted from DDL.""" config_with_extension = SqliteConfig( connection_config=sqlite_config.connection_config, @@ -344,7 +344,7 @@ async def test_owner_id_column_ddl_extraction(sqlite_config: SqliteConfig) -> No assert store._owner_id_column_ddl == "tenant_id INTEGER NOT NULL REFERENCES tenants(id) ON DELETE CASCADE" # pyright: ignore[reportPrivateUsage] -async def test_create_session_without_fk_when_not_required( +def test_create_session_without_fk_when_not_required( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test creating session without owner_id when column is nullable.""" @@ -355,15 +355,15 @@ async def test_create_session_without_fk_when_not_required( extension_config={"adk": {"owner_id_column": "tenant_id INTEGER REFERENCES tenants(id)"}}, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - session = await store.create_session(session_id, app_name, user_id, initial_state) + session = store.create_session(session_id, app_name, user_id, initial_state) assert session["id"] == session_id assert session["state"] == initial_state -async def test_owner_id_with_default_value( +def test_owner_id_with_default_value( sqlite_config: SqliteConfig, session_id: str, app_name: str, user_id: str, initial_state: "dict[str, Any]" ) -> None: """Test owner ID column with DEFAULT value.""" @@ -377,10 +377,10 @@ async def test_owner_id_with_default_value( }, ) store = SqliteADKStore(config_with_extension) - await store.create_tables() + store.create_tables() - session = await store.create_session(session_id, app_name, user_id, initial_state) + session = store.create_session(session_id, app_name, user_id, initial_state) assert session["id"] == session_id - retrieved = await store.get_session(session_id) + retrieved = store.get_session(app_name, user_id, session_id) assert retrieved is not None diff --git a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py index 38bbca7a8..630713232 100644 --- a/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py +++ b/tests/unit/adapters/test_oracledb/test_oracle_adk_store.py @@ -12,7 +12,7 @@ OracleSyncADKMemoryStore, OracleSyncADKStore, ) -from sqlspec.adapters.oracledb.adk.store import _event_json_column_ddl +from sqlspec.adapters.oracledb.adk.store import _event_data_column_ddl def _mock_config(adk_config: dict[str, object]) -> MagicMock: @@ -61,27 +61,27 @@ def test_oracle_sync_adk_store_deserialize_state_dict_coerces_decimal() -> None: assert result == {"state": 5.0} -def test_oracle_event_json_column_ddl_prefers_blob_over_clob() -> None: - assert _event_json_column_ddl(JSONStorageType.JSON_NATIVE) == "event_json JSON NOT NULL" - assert _event_json_column_ddl(JSONStorageType.BLOB_JSON) == "event_json BLOB CHECK (event_json IS JSON) NOT NULL" - assert _event_json_column_ddl(JSONStorageType.BLOB_PLAIN) == "event_json BLOB NOT NULL" +def test_oracle_event_data_column_ddl_prefers_blob_over_clob() -> None: + assert _event_data_column_ddl(JSONStorageType.JSON_NATIVE) == "event_data JSON NOT NULL" + assert _event_data_column_ddl(JSONStorageType.BLOB_JSON) == "event_data BLOB CHECK (event_data IS JSON) NOT NULL" + assert _event_data_column_ddl(JSONStorageType.BLOB_PLAIN) == "event_data BLOB NOT NULL" -async def test_oracle_async_adk_store_serialize_event_json_uses_blob_for_non_native() -> None: +async def test_oracle_async_adk_store_serialize_event_data_uses_blob_for_non_native() -> None: store = OracleAsyncADKStore.__new__(OracleAsyncADKStore) # type: ignore[call-arg] store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined] - result = await store._serialize_event_json({"value": 1}) # type: ignore[attr-defined] + result = await store._serialize_event_data({"value": 1}) # type: ignore[attr-defined] assert isinstance(result, bytes) assert b'"value":1' in result -def test_oracle_sync_adk_store_serialize_event_json_uses_blob_for_non_native() -> None: +def test_oracle_sync_adk_store_serialize_event_data_uses_blob_for_non_native() -> None: store = OracleSyncADKStore.__new__(OracleSyncADKStore) # type: ignore[call-arg] store._json_storage_type = JSONStorageType.BLOB_JSON # type: ignore[attr-defined] - result = store._serialize_event_json({"value": 1}) # type: ignore[attr-defined] + result = store._serialize_event_data({"value": 1}) # type: ignore[attr-defined] assert isinstance(result, bytes) assert b'"value":1' in result diff --git a/tests/unit/adapters/test_psycopg/test_adk_store.py b/tests/unit/adapters/test_psycopg/test_adk_store.py index 754fc56cd..43dfefb44 100644 --- a/tests/unit/adapters/test_psycopg/test_adk_store.py +++ b/tests/unit/adapters/test_psycopg/test_adk_store.py @@ -20,12 +20,15 @@ def __enter__(self) -> Self: def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: return None - def execute(self, query: Any, params: Any) -> None: + def execute(self, query: Any, params: Any = None) -> None: self.execute_calls.append((query, params)) def fetchall(self) -> "list[dict[str, Any]]": return self._rows + def fetchone(self) -> "dict[str, Any] | None": + return self._rows[0] if self._rows else None + class _DummyConnection: def __init__(self, cursor: _DummyCursor) -> None: @@ -38,7 +41,7 @@ def __enter__(self) -> Self: def __exit__(self, exc_type: Any, exc: Any, tb: Any) -> None: return None - def cursor(self) -> _DummyCursor: + def cursor(self, **kwargs: Any) -> _DummyCursor: return self._cursor def commit(self) -> None: @@ -63,6 +66,9 @@ def _build_store( store._config = config # type: ignore[attr-defined] store._events_table = "test_events" # type: ignore[attr-defined] store._session_table = "test_sessions" # type: ignore[attr-defined] + store._app_state_table = "test_app_state" # type: ignore[attr-defined] + store._user_state_table = "test_user_state" # type: ignore[attr-defined] + store._metadata_table = "test_metadata" # type: ignore[attr-defined] store._owner_id_column_ddl = None # type: ignore[attr-defined] store._owner_id_column_name = None # type: ignore[attr-defined] return store, cursor, connection @@ -72,18 +78,21 @@ def test_sync_append_event_inserts_without_session_update() -> None: """append_event must insert a single event without writing session state.""" store, cursor, connection = _build_store() event_record = { + "id": "event-1", + "app_name": "app", + "user_id": "user", "session_id": "session-1", "invocation_id": "", - "author": "assistant", "timestamp": datetime.now(timezone.utc), - "event_json": {"id": "event-1"}, + "event_data": {"id": "event-1"}, } store._append_event(event_record) # type: ignore[arg-type] assert len(cursor.execute_calls) == 1 _, params = cursor.execute_calls[0] - assert params[0] == "session-1" + assert params[0] == "event-1" + assert params[1] == "session-1" assert isinstance(params[4], Jsonb) assert connection.commit_called @@ -95,16 +104,79 @@ def test_sync_get_events_passes_after_timestamp_and_limit() -> None: { "session_id": "session-1", "invocation_id": "", - "author": "assistant", "timestamp": base_time, - "event_json": {"id": "event-2"}, + "event_data": {"id": "event-2"}, + "id": "event-2", + "app_name": "app", + "user_id": "user", } ] store, cursor, _ = _build_store(rows) - result = store._get_events("session-1", after_timestamp=base_time, limit=1) + result = store._get_events("app", "user", "session-1", after_timestamp=base_time, limit=1) assert len(cursor.execute_calls) == 1 _, params = cursor.execute_calls[0] - assert params == ("session-1", base_time, 1) - assert result[0]["event_json"]["id"] == "event-2" + assert params == ("app", "user", "session-1", base_time, 1) + assert result[0]["event_data"]["id"] == "event-2" + + +def test_sync_get_events_limit_zero_returns_empty_without_query() -> None: + """get_events(limit=0) must return no events without querying.""" + store, cursor, _ = _build_store() + + result = store._get_events("app", "user", "session-1", limit=0) + + assert result == [] + assert cursor.execute_calls == [] + + +def test_sync_append_event_and_update_state_writes_scoped_state_in_one_unit() -> None: + """append_event_and_update_state must use event_data and optional scoped state.""" + base_time = datetime(2026, 1, 1, tzinfo=timezone.utc) + rows = [ + { + "id": "session-1", + "app_name": "app", + "user_id": "user", + "state": {"session": True}, + "create_time": base_time, + "update_time": base_time, + } + ] + store, cursor, connection = _build_store(rows) + event_record = { + "id": "event-1", + "app_name": "app", + "user_id": "user", + "session_id": "session-1", + "invocation_id": "invoke-1", + "timestamp": base_time, + "event_data": {"id": "event-1"}, + } + + result = store._append_event_and_update_state( + event_record, # type: ignore[arg-type] + "app", + "user", + "session-1", + {"session": True}, + app_state={}, + user_state={"user:theme": "dark"}, + ) + + assert result["id"] == "session-1" + assert len(cursor.execute_calls) == 4 + _, insert_params = cursor.execute_calls[0] + _, update_params = cursor.execute_calls[1] + _, app_state_params = cursor.execute_calls[2] + _, user_state_params = cursor.execute_calls[3] + assert insert_params[0] == "event-1" + assert isinstance(insert_params[4], Jsonb) + assert getattr(update_params[0], "obj", None) == {"session": True} + assert update_params[1:4] == ("app", "user", "session-1") + assert app_state_params[0] == "app" + assert getattr(app_state_params[1], "obj", None) == {} + assert user_state_params[:2] == ("app", "user") + assert getattr(user_state_params[2], "obj", None) == {"user:theme": "dark"} + assert connection.commit_called diff --git a/tests/unit/adapters/test_spanner/test_adk_store.py b/tests/unit/adapters/test_spanner/test_adk_store.py index 2081df46d..6fc4d38d2 100644 --- a/tests/unit/adapters/test_spanner/test_adk_store.py +++ b/tests/unit/adapters/test_spanner/test_adk_store.py @@ -2,6 +2,7 @@ """Unit tests for Spanner ADK store behavior.""" from datetime import datetime, timezone +from types import SimpleNamespace from unittest.mock import MagicMock, patch from sqlspec.adapters.spanner.adk import SpannerSyncADKMemoryStore, SpannerSyncADKStore @@ -19,11 +20,13 @@ def test_insert_event_preserves_event_record_timestamp() -> None: store = SpannerSyncADKStore(_mock_config()) timestamp = datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc) event: EventRecord = { + "id": "event-1", + "app_name": "app", + "user_id": "u1", "session_id": "session-1", "invocation_id": "inv-1", - "author": "user", "timestamp": timestamp, - "event_json": {"id": "event-1"}, + "event_data": {"content": "hello"}, } with patch.object(store, "_run_write") as run_write: @@ -33,19 +36,22 @@ def test_insert_event_preserves_event_record_timestamp() -> None: sql, params, _types = statements[0] assert "@timestamp" in sql assert "PENDING_COMMIT_TIMESTAMP()" not in sql + assert params["id"] == "event-1" assert params["timestamp"] is timestamp -async def test_append_event_and_update_state_preserves_event_record_timestamp() -> None: +def test_append_event_and_update_state_preserves_event_record_timestamp() -> None: """Atomic append uses the ADK event timestamp while session update uses commit time.""" store = SpannerSyncADKStore(_mock_config()) timestamp = datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc) event: EventRecord = { + "id": "event-1", + "app_name": "app", + "user_id": "u1", "session_id": "session-1", "invocation_id": "inv-1", - "author": "user", "timestamp": timestamp, - "event_json": {"id": "event-1"}, + "event_data": {"content": "hello"}, } # Stub the post-write SELECT — the contract requires returning the refreshed record. fake_record = { @@ -58,46 +64,47 @@ async def test_append_event_and_update_state_preserves_event_record_timestamp() } with patch.object(store, "_run_write") as run_write, patch.object(store, "_get_session", return_value=fake_record): - returned = await store.append_event_and_update_state(event, "session-1", {"turn": 1}) + returned = store.append_event_and_update_state(event, "app", "u1", "session-1", {"turn": 1}) event_sql, event_params, _event_types = run_write.call_args.args[0][0] update_sql, _state_params, _state_types = run_write.call_args.args[0][1] assert "@timestamp" in event_sql assert "PENDING_COMMIT_TIMESTAMP()" not in event_sql + assert event_params["id"] == "event-1" assert event_params["timestamp"] is timestamp assert "PENDING_COMMIT_TIMESTAMP()" in update_sql assert returned == fake_record -async def test_spanner_session_table_generates_row_deletion_policy_from_retention() -> None: +def test_spanner_session_table_generates_row_deletion_policy_from_retention() -> None: store = SpannerSyncADKStore(_mock_config({"retention": {"session_ttl_seconds": 86_400}})) - sql = await store._get_create_sessions_table_sql() + sql = store._get_create_sessions_table_sql() assert "ROW DELETION POLICY (OLDER_THAN(create_time, INTERVAL 1 DAY))" in sql -async def test_spanner_events_table_rounds_retention_up_to_days() -> None: +def test_spanner_events_table_rounds_retention_up_to_days() -> None: store = SpannerSyncADKStore(_mock_config({"retention": {"event_ttl_seconds": 86_401}})) - sql = await store._get_create_events_table_sql() + sql = store._get_create_events_table_sql() assert "ROW DELETION POLICY (OLDER_THAN(timestamp, INTERVAL 2 DAY))" in sql -async def test_spanner_memory_table_generates_ttl_and_table_options() -> None: +def test_spanner_memory_table_generates_ttl_and_table_options() -> None: store = SpannerSyncADKMemoryStore( _mock_config({"memory_table_options": "locality_group = 'hot'", "retention": {"memory_ttl_seconds": 604_800}}) ) - statements = await store._get_create_memory_table_sql() + statements = store._get_create_memory_table_sql() table_sql = statements[0] assert "OPTIONS (locality_group = 'hot')" in table_sql assert "ROW DELETION POLICY (OLDER_THAN(inserted_at, INTERVAL 7 DAY))" in table_sql -async def test_spanner_memory_insert_entries_writes_clean_break_record() -> None: +def test_spanner_memory_insert_entries_writes_clean_break_record() -> None: store = SpannerSyncADKMemoryStore(_mock_config()) timestamp = datetime(2026, 5, 10, 12, 0, tzinfo=timezone.utc) entry: MemoryRecord = { @@ -115,12 +122,12 @@ async def test_spanner_memory_insert_entries_writes_clean_break_record() -> None } with patch.object(store, "_event_exists", return_value=False), patch.object(store, "_run_write") as run_write: - inserted = await store.insert_memory_entries([entry]) + inserted = store.insert_memory_entries([entry]) assert inserted == 1 statements = run_write.call_args.args[0] sql, params, _types = statements[0] - assert "INSERT INTO adk_memory_entries" in sql + assert "INSERT INTO adk_memory" in sql assert params["content_json"] == '{"text":"hello"}' assert params["metadata_json"] == '{"source":"unit"}' assert params["inserted_at"] is timestamp @@ -149,3 +156,27 @@ def test_spanner_memory_rows_to_records_decodes_json_fields() -> None: assert records[0]["content_json"] == {"text": "hello"} assert records[0]["metadata_json"] == {"source": "unit"} assert records[0]["content_text"] == "hello" + + +def test_spanner_reset_drop_tables_filters_absent_tables() -> None: + config = _mock_config() + config.get_database.return_value.list_tables.return_value = [SimpleNamespace(table_id="adk_events")] + store = SpannerSyncADKStore(config) + + statements = store._get_reset_drop_tables_sql() + + assert statements == ["DROP TABLE adk_events"] + + +def test_spanner_memory_reset_drop_tables_filters_absent_tables_and_indexes() -> None: + config = _mock_config() + config.get_database.return_value.list_tables.return_value = [SimpleNamespace(table_id="adk_memory_entries")] + store = SpannerSyncADKMemoryStore(config) + + statements = store._get_reset_drop_memory_table_sql() + + assert statements == [ + "DROP INDEX idx_adk_memory_entries_session", + "DROP INDEX idx_adk_memory_entries_app_user_time", + "DROP TABLE adk_memory_entries", + ] diff --git a/tests/unit/extensions/test_adk/test_config_resolution.py b/tests/unit/extensions/test_adk/test_config_resolution.py new file mode 100644 index 000000000..87790e7ce --- /dev/null +++ b/tests/unit/extensions/test_adk/test_config_resolution.py @@ -0,0 +1,84 @@ +"""Tests for ADK flat-config resolution.""" + +from typing import Any + +from sqlspec.config import ADKConfig +from sqlspec.extensions.adk._config_utils import ( + _get_adk_artifact_store_config, + _get_adk_memory_store_config, + _get_adk_session_store_config, + _is_adk_memory_migration_enabled, +) + + +class _Config: + extension_config: dict[str, dict[str, Any]] + + def __init__(self, adk_config: dict[str, Any]) -> None: + self.extension_config = {"adk": adk_config} + + +def test_adk_config_uses_flat_keys() -> None: + """ADKConfig is a flat TypedDict; no per-adapter or nested negotiation blocks.""" + annotations = set(ADKConfig.__annotations__) + expected_flat = {"session_table", "events_table", "memory_table", "artifact_table", "in_memory", "owner_id_column"} + forbidden_nested = {"schema", "lifecycle", "capabilities", "optimizations", "oracle", "spanner", "adbc", "bigquery"} + assert expected_flat <= annotations + assert annotations.isdisjoint(forbidden_nested) + + +def test_flat_schema_config_resolves_all_adk_table_names() -> None: + config = _Config({ + "session_table": "agent_sessions", + "events_table": "agent_events", + "app_state_table": "agent_app_states", + "user_state_table": "agent_user_states", + "metadata_table": "agent_metadata", + "owner_id_column": "tenant_id UUID", + }) + + resolved = _get_adk_session_store_config(config) + + assert resolved == { + "session_table": "agent_sessions", + "events_table": "agent_events", + "app_state_table": "agent_app_states", + "user_state_table": "agent_user_states", + "metadata_table": "agent_metadata", + "owner_id_column": "tenant_id UUID", + } + + +def test_flat_memory_config_resolves_memory_store_settings() -> None: + config = _Config({ + "enable_memory": False, + "memory_table": "agent_memories", + "memory_use_fts": True, + "memory_max_results": 50, + }) + + resolved = _get_adk_memory_store_config(config) + + assert resolved == {"enable_memory": False, "memory_table": "agent_memories", "use_fts": True, "max_results": 50} + + +def test_flat_artifact_config_resolves_table_and_storage_uri() -> None: + config = _Config({"artifact_table": "agent_artifacts", "artifact_storage_uri": "s3://bucket/adk"}) + + resolved = _get_adk_artifact_store_config(config) + + assert resolved == {"artifact_table": "agent_artifacts", "storage_uri": "s3://bucket/adk"} + + +def test_include_memory_migration_overrides_enable_memory() -> None: + config = _Config({"enable_memory": True, "include_memory_migration": False}) + + assert not _is_adk_memory_migration_enabled(config) + + +def test_include_memory_migration_defaults_to_enable_memory() -> None: + enabled = _Config({"enable_memory": True}) + disabled = _Config({"enable_memory": False}) + + assert _is_adk_memory_migration_enabled(enabled) is True + assert _is_adk_memory_migration_enabled(disabled) is False diff --git a/tests/unit/extensions/test_adk/test_converters.py b/tests/unit/extensions/test_adk/test_converters.py index 76025da6a..b8c8d160b 100644 --- a/tests/unit/extensions/test_adk/test_converters.py +++ b/tests/unit/extensions/test_adk/test_converters.py @@ -1,8 +1,8 @@ """Unit tests for ADK session/event converters and scoped state helpers. Tests the NEW contract specified in Chapter 1 of the ADK Clean-Break Overhaul: -- EventRecord has exactly 5 keys (session_id, invocation_id, author, timestamp, event_json) -- event_to_record takes only (event, session_id), not (event, session_id, app_name, user_id) +- EventRecord has exactly 7 keys (id, app_name, user_id, session_id, invocation_id, timestamp, event_data) +- event_to_record takes (event, app_name, user_id, session_id) - record_to_event uses Event.model_validate for full round-trip fidelity - filter_temp_state, split_scoped_state, merge_scoped_state for scoped state handling - session_to_record strips temp: keys from state @@ -42,12 +42,13 @@ def _make_event( event_id: str = "evt-1", invocation_id: str = "inv-1", author: str = "user", - text: "str | None" = None, - state_delta: "dict | None" = None, - branch: "str | None" = None, - partial: "bool | None" = None, - turn_complete: "bool | None" = None, - custom_metadata: "dict | None" = None, + text: str | None = None, + state_delta: dict | None = None, + branch: str | None = None, + isolation_scope: str | None = None, + partial: bool | None = None, + turn_complete: bool | None = None, + custom_metadata: dict | None = None, ) -> Event: content = types.Content(parts=[types.Part(text=text)]) if text is not None else None actions = EventActions(state_delta=state_delta or {}) @@ -59,6 +60,7 @@ def _make_event( actions=actions, timestamp=datetime(2024, 1, 15, 12, 0, 0, tzinfo=timezone.utc).timestamp(), branch=branch, + isolation_scope=isolation_scope, partial=partial, turn_complete=turn_complete, custom_metadata=custom_metadata, @@ -244,65 +246,79 @@ def test_compute_update_marker_normalizes_aware_datetime_to_utc() -> None: # --------------------------------------------------------------------------- -def test_event_to_record_only_5_keys() -> None: - """EventRecord has exactly session_id, invocation_id, author, timestamp, event_json.""" +def test_event_to_record_clean_break_keys() -> None: + """EventRecord has exactly the clean-break indexed fields plus event_data.""" event = _make_event() - record = event_to_record(event, "session-1") - assert set(record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"} + record = event_to_record(event, "app", "u1", "session-1") + assert set(record.keys()) == {"id", "app_name", "user_id", "session_id", "invocation_id", "timestamp", "event_data"} -def test_event_to_record_signature_two_args_only() -> None: - """event_to_record raises TypeError if called with extra positional args (old 4-arg signature).""" +def test_event_to_record_rejects_old_two_arg_signature() -> None: + """event_to_record rejects the old session-only identity signature.""" event = _make_event() with pytest.raises(TypeError): - event_to_record(event, "session-1", "app-name", "user-id") # type: ignore[call-arg] + event_to_record(event, "session-1") # type: ignore[call-arg] def test_event_to_record_session_id_stored_correctly() -> None: """session_id in the record matches the argument passed.""" event = _make_event(invocation_id="inv-abc", author="model") - record = event_to_record(event, "my-session-id") + record = event_to_record(event, "app", "u1", "my-session-id") assert record["session_id"] == "my-session-id" def test_event_to_record_indexed_fields_match_event() -> None: - """Indexed scalar columns (invocation_id, author, timestamp) match the source event.""" - event = _make_event(invocation_id="inv-xyz", author="tool") - record = event_to_record(event, "s1") + """Indexed scalar columns match the source event and context identity.""" + event = _make_event(event_id="evt-xyz", invocation_id="inv-xyz", author="tool") + record = event_to_record(event, "app", "u1", "s1") + assert record["id"] == "evt-xyz" + assert record["app_name"] == "app" + assert record["user_id"] == "u1" + assert record["session_id"] == "s1" assert record["invocation_id"] == "inv-xyz" - assert record["author"] == "tool" assert isinstance(record["timestamp"], datetime) + assert record["event_data"]["author"] == "tool" -def test_event_to_record_event_json_matches_model_dump() -> None: - """event_json in the record equals event.model_dump(exclude_none=True, mode='json').""" +def test_event_to_record_event_data_matches_model_dump() -> None: + """event_data in the record equals event.model_dump(exclude_none=True, mode='json').""" event = _make_event(text="hello", state_delta={"key": "val"}, custom_metadata={"foo": "bar"}) - record = event_to_record(event, "s1") + record = event_to_record(event, "app", "u1", "s1") expected_json = event.model_dump(exclude_none=True, mode="json") - assert record["event_json"] == expected_json + assert record["event_data"] == expected_json -def test_event_to_record_event_json_is_dict() -> None: - """event_json field is a plain dict (not bytes, not string).""" +def test_event_to_record_omits_assigned_none_actions() -> None: + """A defensive assigned-None actions value is not persisted as actions:null.""" event = _make_event() - record = event_to_record(event, "s1") - assert isinstance(record["event_json"], dict) + event.actions = None # type: ignore[assignment] + record = event_to_record(event, "app", "u1", "s1") -def test_event_to_record_actions_in_event_json_is_structured() -> None: - """Actions are stored as structured JSON dict in event_json, not as raw bytes.""" + assert "actions" not in record["event_data"] + + +def test_event_to_record_event_data_is_dict() -> None: + """event_data field is a plain dict (not bytes, not string).""" + event = _make_event() + record = event_to_record(event, "app", "u1", "s1") + assert isinstance(record["event_data"], dict) + + +def test_event_to_record_actions_in_event_data_is_structured() -> None: + """Actions are stored as structured JSON dict in event_data, not as raw bytes.""" event = _make_event(state_delta={"x": "y"}) - record = event_to_record(event, "s1") - event_json = record["event_json"] + record = event_to_record(event, "app", "u1", "s1") + event_data = record["event_data"] # actions should be a dict in the JSON blob - if "actions" in event_json: - assert isinstance(event_json["actions"], dict) + if "actions" in event_data: + assert isinstance(event_data["actions"], dict) def test_event_to_record_timestamp_is_datetime() -> None: """timestamp column is a datetime object with timezone.""" event = _make_event() - record = event_to_record(event, "s1") + record = event_to_record(event, "app", "u1", "s1") assert isinstance(record["timestamp"], datetime) assert record["timestamp"].tzinfo is not None @@ -315,7 +331,7 @@ def test_event_to_record_timestamp_is_datetime() -> None: def test_record_to_event_full_roundtrip_basic() -> None: """Event -> record -> Event produces an identical object for basic fields.""" original = _make_event(event_id="evt-rt", invocation_id="inv-rt", author="model") - record = event_to_record(original, "s1") + record = event_to_record(original, "app", "u1", "s1") restored = record_to_event(record) assert restored.id == original.id @@ -326,7 +342,7 @@ def test_record_to_event_full_roundtrip_basic() -> None: def test_record_to_event_roundtrip_preserves_content() -> None: """Content (parts) survives the round-trip.""" original = _make_event(text="hello world", author="model") - record = event_to_record(original, "s1") + record = event_to_record(original, "app", "u1", "s1") restored = record_to_event(record) assert restored.content is not None @@ -337,7 +353,7 @@ def test_record_to_event_roundtrip_preserves_content() -> None: def test_record_to_event_roundtrip_preserves_actions() -> None: """EventActions (state_delta) survives the round-trip.""" original = _make_event(state_delta={"key": "v1", "other": 42}) - record = event_to_record(original, "s1") + record = event_to_record(original, "app", "u1", "s1") restored = record_to_event(record) assert restored.actions is not None @@ -347,7 +363,7 @@ def test_record_to_event_roundtrip_preserves_actions() -> None: def test_record_to_event_roundtrip_preserves_custom_metadata() -> None: """custom_metadata survives the round-trip.""" original = _make_event(custom_metadata={"tag": "v2", "score": 0.9}) - record = event_to_record(original, "s1") + record = event_to_record(original, "app", "u1", "s1") restored = record_to_event(record) assert restored.custom_metadata == {"tag": "v2", "score": 0.9} @@ -356,16 +372,37 @@ def test_record_to_event_roundtrip_preserves_custom_metadata() -> None: def test_record_to_event_roundtrip_preserves_branch() -> None: """branch field survives the round-trip.""" original = _make_event(branch="feature-branch") - record = event_to_record(original, "s1") + record = event_to_record(original, "app", "u1", "s1") restored = record_to_event(record) assert restored.branch == "feature-branch" +def test_record_to_event_roundtrip_preserves_isolation_scope() -> None: + """ADK 2.2 isolation_scope survives the round-trip.""" + original = _make_event(isolation_scope="scope-1") + record = event_to_record(original, "app", "u1", "s1") + restored = record_to_event(record) + + assert restored.isolation_scope == "scope-1" + + +def test_record_to_event_normalizes_actions_null_to_default() -> None: + """Legacy or external actions:null payloads are normalized before ADK validation.""" + original = _make_event(state_delta={"key": "v1"}) + record = event_to_record(original, "app", "u1", "s1") + record["event_data"]["actions"] = None + + restored = record_to_event(record) + + assert restored.actions is not None + assert restored.actions.state_delta == {} + + def test_record_to_event_roundtrip_preserves_partial_flag() -> None: """partial flag survives the round-trip.""" original = _make_event(partial=True) - record = event_to_record(original, "s1") + record = event_to_record(original, "app", "u1", "s1") restored = record_to_event(record) assert restored.partial is True @@ -374,7 +411,7 @@ def test_record_to_event_roundtrip_preserves_partial_flag() -> None: def test_record_to_event_roundtrip_preserves_turn_complete() -> None: """turn_complete flag survives the round-trip.""" original = _make_event(turn_complete=True) - record = event_to_record(original, "s1") + record = event_to_record(original, "app", "u1", "s1") restored = record_to_event(record) assert restored.turn_complete is True @@ -384,19 +421,19 @@ def test_record_to_event_roundtrip_preserves_timestamp() -> None: """timestamp survives the round-trip within float precision.""" fixed_ts = datetime(2024, 6, 1, 10, 30, 0, tzinfo=timezone.utc).timestamp() event = Event(id="ts-evt", invocation_id="inv-1", author="user", actions=EventActions(), timestamp=fixed_ts) - record = event_to_record(event, "s1") + record = event_to_record(event, "app", "u1", "s1") restored = record_to_event(record) assert abs(restored.timestamp - fixed_ts) < 1.0 # within 1 second -def test_record_to_event_ignores_unknown_fields_in_event_json() -> None: - """Unknown event_json fields are ignored by the current ADK Event model.""" +def test_record_to_event_ignores_unknown_fields_in_event_data() -> None: + """Unknown event_data fields are ignored by the current ADK Event model.""" event = _make_event(event_id="extra-fields-evt", author="tool") - record = event_to_record(event, "s1") + record = event_to_record(event, "app", "u1", "s1") - # Inject hypothetical future ADK field into event_json - record["event_json"]["hypothetical_v3_field"] = "some_value" # type: ignore[index] + # Inject hypothetical future ADK field into event_data + record["event_data"]["hypothetical_v3_field"] = "some_value" # type: ignore[index] restored = record_to_event(record) assert restored.id == "extra-fields-evt" @@ -470,7 +507,7 @@ def test_record_to_session_with_events_round_trip() -> None: update_time=datetime.now(timezone.utc), ) event = _make_event(text="hello", author="user") - event_record = event_to_record(event, "s1") + event_record = event_to_record(event, "app", "u1", "s1") session = record_to_session(session_record, [event_record]) diff --git a/tests/unit/extensions/test_adk/test_service.py b/tests/unit/extensions/test_adk/test_service.py index d479980ce..ad1e4c040 100644 --- a/tests/unit/extensions/test_adk/test_service.py +++ b/tests/unit/extensions/test_adk/test_service.py @@ -20,8 +20,10 @@ from google.adk.events.event import Event from google.adk.events.event_actions import EventActions +from google.adk.sessions.base_session_service import GetSessionConfig from google.adk.sessions.session import Session +from sqlspec.extensions.adk._types import EventRecord, SessionRecord from sqlspec.extensions.adk.service import SQLSpecSessionService # --------------------------------------------------------------------------- @@ -41,9 +43,14 @@ def __init__(self) -> None: self.append_event_and_update_state_calls: list[dict[str, Any]] = [] self.append_event_and_update_state_called = False self.get_session_calls = 0 + self.get_events_calls: list[dict[str, Any]] = [] + self.upsert_app_state_calls: list[dict[str, Any]] = [] + self.upsert_user_state_calls: list[dict[str, Any]] = [] # Track calls to create_session self.create_session_calls: list[dict[str, Any]] = [] + self.app_state: dict[str, Any] = {} + self.user_state: dict[str, Any] = {} # Provide a get_session that returns a minimal session record self._session_record = { @@ -56,14 +63,29 @@ def __init__(self) -> None: } async def append_event_and_update_state( - self, event_record: Any, session_id: str, state: "dict[str, Any]" + self, + event_record: Any, + app_name: str, + user_id: str, + session_id: str, + state: "dict[str, Any]", + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> "dict[str, Any]": self.append_event_and_update_state_called = True self.append_event_and_update_state_calls.append({ "event_record": event_record, + "app_name": app_name, + "user_id": user_id, "session_id": session_id, "state": state, + "app_state": app_state, + "user_state": user_state, }) + if app_state is not None: + self.app_state = dict(app_state) + if user_state is not None: + self.user_state = dict(user_state) # Return the updated SessionRecord — caller no longer needs a follow-up get_session(). updated = dict(self._session_record) updated["state"] = state @@ -71,8 +93,12 @@ async def append_event_and_update_state( self._session_record = updated return updated - async def get_session(self, session_id: str) -> "dict[str, Any] | None": + async def get_session(self, app_name: str, user_id: str, session_id: str) -> "dict[str, Any] | None": self.get_session_calls += 1 + if self._session_record["app_name"] != app_name or self._session_record["user_id"] != user_id: + return None + if self._session_record["id"] != session_id: + return None return self._session_record async def create_session( @@ -84,7 +110,7 @@ async def create_session( "user_id": user_id, "state": state, }) - return { + self._session_record = { "id": session_id, "app_name": app_name, "user_id": user_id, @@ -92,21 +118,163 @@ async def create_session( "create_time": datetime.now(timezone.utc), "update_time": datetime.now(timezone.utc), } + return self._session_record + + async def get_app_state(self, app_name: str) -> "dict[str, Any]": + return dict(self.app_state) + + async def get_user_state(self, app_name: str, user_id: str) -> "dict[str, Any]": + return dict(self.user_state) + + async def upsert_app_state(self, app_name: str, state: "dict[str, Any]") -> None: + self.upsert_app_state_calls.append({"app_name": app_name, "state": state}) + self.app_state = dict(state) + + async def upsert_user_state(self, app_name: str, user_id: str, state: "dict[str, Any]") -> None: + self.upsert_user_state_calls.append({"app_name": app_name, "user_id": user_id, "state": state}) + self.user_state = dict(state) # Old method — should NOT be called by the new service async def append_event(self, event_record: Any) -> None: raise AssertionError("append_event (old method) must not be called — use append_event_and_update_state") - async def get_events(self, *, session_id: str, after_timestamp: Any = None, limit: Any = None) -> list: + async def get_events( + self, *, app_name: str, user_id: str, session_id: str, after_timestamp: Any = None, limit: Any = None + ) -> list: + self.get_events_calls.append({ + "app_name": app_name, + "user_id": user_id, + "session_id": session_id, + "after_timestamp": after_timestamp, + "limit": limit, + }) return [] async def list_sessions(self, *, app_name: str, user_id: "str | None" = None) -> list: return [] - async def delete_session(self, session_id: str) -> None: + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: pass +class SyncStore: + """Sync store proving SQLSpecSessionService owns sync-to-async bridging.""" + + def __init__(self) -> None: + self.app_state: dict[str, Any] = {} + self.user_state: dict[str, Any] = {} + self.create_session_calls: list[dict[str, Any]] = [] + self._session_record = SessionRecord( + id="s1", + app_name="app", + user_id="u1", + state={}, + create_time=datetime.now(timezone.utc), + update_time=datetime.now(timezone.utc), + ) + + def create_session( + self, session_id: str, app_name: str, user_id: str, state: dict[str, Any], owner_id: Any | None = None + ) -> SessionRecord: + self.create_session_calls.append({ + "session_id": session_id, + "app_name": app_name, + "user_id": user_id, + "state": state, + "owner_id": owner_id, + }) + self._session_record = SessionRecord( + id=session_id, + app_name=app_name, + user_id=user_id, + state=state, + create_time=datetime.now(timezone.utc), + update_time=datetime.now(timezone.utc), + ) + return self._session_record + + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: Any | None = None + ) -> SessionRecord | None: + if self._session_record["app_name"] != app_name or self._session_record["user_id"] != user_id: + return None + if self._session_record["id"] != session_id: + return None + return self._session_record + + def get_events( + self, app_name: str, user_id: str, session_id: str, after_timestamp: Any = None, limit: Any = None + ) -> list[EventRecord]: + return [] + + def get_app_state(self, app_name: str) -> dict[str, Any]: + return dict(self.app_state) + + def get_user_state(self, app_name: str, user_id: str) -> dict[str, Any]: + return dict(self.user_state) + + def upsert_app_state(self, app_name: str, state: dict[str, Any]) -> None: + self.app_state = dict(state) + + def upsert_user_state(self, app_name: str, user_id: str, state: dict[str, Any]) -> None: + self.user_state = dict(state) + + def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: dict[str, Any], + app_state: dict[str, Any] | None = None, + user_state: dict[str, Any] | None = None, + ) -> SessionRecord: + if app_state is not None: + self.app_state = dict(app_state) + if user_state is not None: + self.user_state = dict(user_state) + self._session_record = SessionRecord( + id=self._session_record["id"], + app_name=self._session_record["app_name"], + user_id=self._session_record["user_id"], + state=state, + create_time=self._session_record["create_time"], + update_time=datetime.now(timezone.utc), + ) + return self._session_record + + def list_sessions(self, app_name: str, user_id: str | None = None) -> list[SessionRecord]: + return [self._session_record] + + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + return None + + +class StaleDetectionStore(MockStore): + """Mock store that can simulate a storage-side update between load and append.""" + + def __init__(self, *, stale_marker: bool = False, stale_timestamp: bool = False) -> None: + super().__init__() + self._stale_marker = stale_marker + self._stale_timestamp = stale_timestamp + + async def get_session(self, app_name: str, user_id: str, session_id: str) -> "dict[str, Any] | None": + record = dict(self._session_record) + if self._stale_marker or self._stale_timestamp: + # Simulate a storage-side update by advancing update_time. + from datetime import timedelta + + record["update_time"] = record["update_time"] + timedelta(seconds=10) # type: ignore[operator] + return record + + +class MissingSessionStore(MockStore): + """Mock store where the session disappears between load and append.""" + + async def get_session(self, app_name: str, user_id: str, session_id: str) -> "dict[str, Any] | None": + return None + + # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- @@ -252,6 +420,7 @@ async def test_append_event_passes_correct_session_id_to_store() -> None: store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session(session_id="my-unique-session-id") + store._session_record["id"] = "my-unique-session-id" event = _make_event() await service.append_event(session, event) @@ -262,7 +431,7 @@ async def test_append_event_passes_correct_session_id_to_store() -> None: @pytest.mark.anyio async def test_append_event_event_record_has_5_keys() -> None: - """The event_record passed to the store has exactly 5 keys (new schema).""" + """The event_record passed to the store has exactly the clean-break schema keys.""" store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] session = _make_session() @@ -272,7 +441,15 @@ async def test_append_event_event_record_has_5_keys() -> None: last_call = store.append_event_and_update_state_calls[-1] event_record = last_call["event_record"] - assert set(event_record.keys()) == {"session_id", "invocation_id", "author", "timestamp", "event_json"} + assert set(event_record.keys()) == { + "id", + "app_name", + "user_id", + "session_id", + "invocation_id", + "timestamp", + "event_data", + } @pytest.mark.anyio @@ -296,7 +473,7 @@ async def test_append_event_returns_the_event() -> None: @pytest.mark.anyio async def test_create_session_strips_temp_keys_from_initial_state() -> None: - """create_session filters temp: keys before passing state to the store.""" + """create_session filters temp: keys and splits app state before storing.""" store = MockStore() service = SQLSpecSessionService(store) # type: ignore[arg-type] @@ -306,7 +483,28 @@ async def test_create_session_strips_temp_keys_from_initial_state() -> None: persisted_state = store.create_session_calls[0]["state"] assert "temp:y" not in persisted_state assert persisted_state["x"] == 1 - assert persisted_state["app:z"] == 3 + assert "app:z" not in persisted_state + assert store.upsert_app_state_calls[-1]["state"] == {"app:z": 3} + + +@pytest.mark.anyio +async def test_create_session_supports_sync_store() -> None: + """The ADK async service bridges sync stores at the service boundary.""" + store = SyncStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + session = await service.create_session( + app_name="app", + user_id="u1", + session_id="sync-session", + state={"x": 1, "user:theme": "dark", "app:mode": "prod"}, + ) + + assert session.id == "sync-session" + assert store.create_session_calls[0]["state"] == {"x": 1} + assert store.user_state == {"user:theme": "dark"} + assert store.app_state == {"app:mode": "prod"} + assert session.state == {"x": 1, "app:mode": "prod", "user:theme": "dark"} @pytest.mark.anyio @@ -354,34 +552,42 @@ async def test_create_session_uses_provided_session_id() -> None: assert session.id == "my-id" -# --------------------------------------------------------------------------- -# Stale-session detection -# --------------------------------------------------------------------------- +@pytest.mark.anyio +async def test_create_session_splits_user_state_from_session_row() -> None: + """create_session persists user: keys in the user state bucket.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] + session = await service.create_session(app_name="app", user_id="u1", state={"x": 1, "user:theme": "dark"}) -class StaleDetectionStore(MockStore): - """Mock store that can simulate a storage-side update between load and append.""" + assert store.create_session_calls[0]["state"] == {"x": 1} + assert store.upsert_user_state_calls[-1]["state"] == {"user:theme": "dark"} + assert session.state == {"x": 1, "user:theme": "dark"} - def __init__(self, *, stale_marker: bool = False, stale_timestamp: bool = False) -> None: - super().__init__() - self._stale_marker = stale_marker - self._stale_timestamp = stale_timestamp - async def get_session(self, session_id: str) -> "dict[str, Any] | None": - record = dict(self._session_record) - if self._stale_marker or self._stale_timestamp: - # Simulate a storage-side update by advancing update_time - from datetime import timedelta +@pytest.mark.anyio +async def test_get_session_num_recent_events_zero_returns_no_events_without_store_query() -> None: + """ADK documents num_recent_events=0 as returning no events.""" + store = MockStore() + service = SQLSpecSessionService(store) # type: ignore[arg-type] - record["update_time"] = record["update_time"] + timedelta(seconds=10) # type: ignore[operator] - return record + session = await service.get_session( + app_name="app", user_id="u1", session_id="s1", config=GetSessionConfig(num_recent_events=0) + ) + assert session is not None + assert session.events == [] + assert store.get_events_calls == [] -class MissingSessionStore(MockStore): - """Mock store where the session disappears between load and append.""" - async def get_session(self, session_id: str) -> "dict[str, Any] | None": - return None +@pytest.mark.anyio +async def test_get_user_state_strips_user_prefixes() -> None: + """ADK 2.2 BaseSessionService.get_user_state returns unprefixed keys.""" + store = MockStore() + store.user_state = {"user:theme": "dark", "raw": "kept"} + service = SQLSpecSessionService(store) # type: ignore[arg-type] + + assert await service.get_user_state(app_name="app", user_id="u1") == {"theme": "dark", "raw": "kept"} @pytest.mark.anyio @@ -456,7 +662,14 @@ async def test_append_event_updates_inmemory_after_persist() -> None: class FailingStore(MockStore): async def append_event_and_update_state( - self, event_record: Any, session_id: str, state: Any + self, + event_record: Any, + app_name: str, + user_id: str, + session_id: str, + state: Any, + app_state: "dict[str, Any] | None" = None, + user_state: "dict[str, Any] | None" = None, ) -> "dict[str, Any]": raise RuntimeError("Simulated DB failure") diff --git a/tests/unit/extensions/test_adk/test_store_config.py b/tests/unit/extensions/test_adk/test_store_config.py index 6d172bbac..ea428ef75 100644 --- a/tests/unit/extensions/test_adk/test_store_config.py +++ b/tests/unit/extensions/test_adk/test_store_config.py @@ -1,6 +1,7 @@ # pyright: reportPrivateUsage=false """Tests for shared ADK store configuration behavior.""" +import importlib import logging from datetime import datetime from typing import Any @@ -13,7 +14,7 @@ from sqlspec.extensions.adk.memory import MemoryRecord from sqlspec.extensions.adk.memory import store as memory_store_module from sqlspec.extensions.adk.memory.store import BaseSyncADKMemoryStore -from sqlspec.extensions.adk.store import BaseSyncADKStore +from sqlspec.extensions.adk.store import BaseAsyncADKStore, BaseSyncADKStore class _Config: @@ -29,7 +30,128 @@ def provide_session(self) -> str: return "original-session" +class _AsyncSessionStore(BaseAsyncADKStore[Any]): + async def create_session( + self, session_id: str, app_name: str, user_id: str, state: dict[str, Any], owner_id: Any | None = None + ) -> SessionRecord: + return SessionRecord( + id=session_id, + app_name=app_name, + user_id=user_id, + state=state, + create_time=datetime.now(), + update_time=datetime.now(), + ) + + async def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: Any | None = None + ) -> SessionRecord | None: + return None + + async def update_session_state(self, app_name: str, user_id: str, session_id: str, state: dict[str, Any]) -> None: + return None + + async def list_sessions(self, app_name: str, user_id: str | None = None) -> list[SessionRecord]: + return [] + + async def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + return None + + async def append_event(self, event_record: EventRecord) -> None: + return None + + async def append_event_and_update_state( + self, + event_record: EventRecord, + app_name: str, + user_id: str, + session_id: str, + state: dict[str, Any], + *, + app_state: dict[str, Any] | None = None, + user_state: dict[str, Any] | None = None, + ) -> SessionRecord: + return await self.create_session(session_id, app_name, user_id, state) + + async def get_events( + self, + app_name: str, + user_id: str, + session_id: str, + after_timestamp: datetime | None = None, + limit: int | None = None, + ) -> list[EventRecord]: + return [] + + async def delete_expired_events(self, before: datetime) -> int: + return 0 + + async def delete_idle_sessions(self, updated_before: datetime) -> int: + return 0 + + async def get_app_state(self, app_name: str) -> dict[str, Any] | None: + return None + + async def get_user_state(self, app_name: str, user_id: str) -> dict[str, Any] | None: + return None + + async def upsert_app_state(self, app_name: str, state: dict[str, Any]) -> None: + return None + + async def upsert_user_state(self, app_name: str, user_id: str, state: dict[str, Any]) -> None: + return None + + async def get_metadata(self, key: str) -> str | None: + return None + + async def set_metadata(self, key: str, value: str) -> None: + return None + + async def create_tables(self) -> None: + return None + + async def _get_create_sessions_table_sql(self) -> str: + return "" + + async def _get_create_events_table_sql(self) -> str: + return "" + + async def _get_create_app_states_table_sql(self) -> str: + return "" + + async def _get_create_user_states_table_sql(self) -> str: + return "" + + async def _get_create_metadata_table_sql(self) -> str: + return "" + + async def _get_seed_metadata_sql(self) -> str: + return "" + + def _get_drop_app_states_table_sql(self) -> str: + return "" + + def _get_drop_user_states_table_sql(self) -> str: + return "" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + + def _get_drop_tables_sql(self) -> list[str]: + return [ + self._get_drop_metadata_table_sql(), + f"DROP TABLE IF EXISTS {self._user_state_table}", + f"DROP TABLE IF EXISTS {self._app_state_table}", + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] + + class _SyncSessionStore(BaseSyncADKStore[Any]): + def __init__(self, config: _Config) -> None: + super().__init__(config) + self.create_tables_called = False + def create_session( self, session_id: str, app_name: str, user_id: str, state: dict[str, Any], owner_id: Any | None = None ) -> SessionRecord: @@ -42,54 +164,108 @@ def create_session( update_time=datetime.now(), ) - def get_session(self, session_id: str) -> SessionRecord | None: + def get_session( + self, app_name: str, user_id: str, session_id: str, *, renew_for: Any | None = None + ) -> SessionRecord | None: return None - def update_session_state(self, session_id: str, state: dict[str, Any]) -> None: + def update_session_state(self, app_name: str, user_id: str, session_id: str, state: dict[str, Any]) -> None: return None def list_sessions(self, app_name: str, user_id: str | None = None) -> list[SessionRecord]: return [] - def delete_session(self, session_id: str) -> None: + def delete_session(self, app_name: str, user_id: str, session_id: str) -> None: + return None + + def append_event(self, event_record: EventRecord) -> None: return None - def create_event( + def append_event_and_update_state( self, - event_id: str, + event_record: EventRecord, + app_name: str, + user_id: str, session_id: str, + state: dict[str, Any], + *, + app_state: dict[str, Any] | None = None, + user_state: dict[str, Any] | None = None, + ) -> SessionRecord: + return self.create_session(session_id, app_name, user_id, state) + + def get_events( + self, app_name: str, user_id: str, - author: str | None = None, - actions: bytes | None = None, - content: dict[str, Any] | None = None, - **kwargs: Any, - ) -> EventRecord: - return EventRecord( - session_id=session_id, - invocation_id=event_id, - author=author or user_id, - timestamp=datetime.now(), - event_json=content or {}, - ) + session_id: str, + after_timestamp: datetime | None = None, + limit: int | None = None, + ) -> list[EventRecord]: + return [] + + def delete_expired_events(self, before: datetime) -> int: + return 0 - def create_event_and_update_state(self, event_record: EventRecord, session_id: str, state: dict[str, Any]) -> None: + def delete_idle_sessions(self, updated_before: datetime) -> int: + return 0 + + def get_app_state(self, app_name: str) -> dict[str, Any] | None: return None - def list_events(self, session_id: str) -> list[EventRecord]: - return [] + def get_user_state(self, app_name: str, user_id: str) -> dict[str, Any] | None: + return None - def create_tables(self) -> None: + def upsert_app_state(self, app_name: str, state: dict[str, Any]) -> None: + return None + + def upsert_user_state(self, app_name: str, user_id: str, state: dict[str, Any]) -> None: + return None + + def get_metadata(self, key: str) -> str | None: return None + def set_metadata(self, key: str, value: str) -> None: + return None + + def create_tables(self) -> None: + self.create_tables_called = True + def _get_create_sessions_table_sql(self) -> str: return "" def _get_create_events_table_sql(self) -> str: return "" + def _get_create_app_states_table_sql(self) -> str: + return "" + + def _get_create_user_states_table_sql(self) -> str: + return "" + + def _get_create_metadata_table_sql(self) -> str: + return "" + + def _get_seed_metadata_sql(self) -> str: + return "" + + def _get_drop_app_states_table_sql(self) -> str: + return "" + + def _get_drop_user_states_table_sql(self) -> str: + return "" + + def _get_drop_metadata_table_sql(self) -> str: + return f"DROP TABLE IF EXISTS {self._metadata_table}" + def _get_drop_tables_sql(self) -> list[str]: - return [] + return [ + self._get_drop_metadata_table_sql(), + f"DROP TABLE IF EXISTS {self._user_state_table}", + f"DROP TABLE IF EXISTS {self._app_state_table}", + f"DROP TABLE IF EXISTS {self._events_table}", + f"DROP TABLE IF EXISTS {self._session_table}", + ] class _SyncMemoryStore(BaseSyncADKMemoryStore[Any]): @@ -116,7 +292,7 @@ def _get_create_memory_table_sql(self) -> str | list[str]: return "" def _get_drop_memory_table_sql(self) -> list[str]: - return [] + return [f"DROP TABLE IF EXISTS {self._memory_table}"] class _SyncArtifactStore(BaseSyncADKArtifactStore[Any]): @@ -148,7 +324,7 @@ def create_table(self) -> None: return None -@pytest.mark.parametrize("store_cls", [_SyncSessionStore, _SyncMemoryStore, _SyncArtifactStore]) +@pytest.mark.parametrize("store_cls", [_AsyncSessionStore, _SyncSessionStore, _SyncMemoryStore, _SyncArtifactStore]) def test_adk_base_stores_keep_original_config(store_cls: type[Any]) -> None: config = _Config() store = store_cls(config) @@ -156,6 +332,14 @@ def test_adk_base_stores_keep_original_config(store_cls: type[Any]) -> None: assert store.config is config +def test_sync_session_store_ensure_tables_runs_sync_create_tables() -> None: + store = _SyncSessionStore(_Config({"session_table": "sessions", "events_table": "events"})) + + store.ensure_tables() + + assert store.create_tables_called + + def test_sync_memory_store_logs_ready_with_log_with_context(monkeypatch: pytest.MonkeyPatch) -> None: calls: list[dict[str, Any]] = [] @@ -193,3 +377,53 @@ def fake_log_with_context(logger: Any, level: int, event: str, **context: Any) - assert calls[0]["context"]["memory_table"] == "test_memories" assert calls[0]["context"]["reason"] == "disabled" assert "db_system" in calls[0]["context"] + + +def test_session_store_reset_drop_tables_includes_legacy_metadata_table() -> None: + store = _AsyncSessionStore(_Config()) + + statements = store._get_reset_drop_tables_sql() + + assert "DROP TABLE IF EXISTS adk_internal_metadata" in statements + assert "DROP TABLE IF EXISTS adk_metadata" in statements + assert "DROP TABLE IF EXISTS adk_session" in statements + assert "DROP TABLE IF EXISTS adk_event" in statements + assert "DROP TABLE IF EXISTS adk_app_state" in statements + assert "DROP TABLE IF EXISTS adk_user_state" in statements + assert "DROP TABLE IF EXISTS adk_sessions" in statements + assert "DROP TABLE IF EXISTS adk_events" in statements + assert "DROP TABLE IF EXISTS adk_app_states" in statements + assert "DROP TABLE IF EXISTS adk_user_states" in statements + assert store.metadata_table == "adk_internal_metadata" + + +def test_session_store_reset_drop_tables_does_not_duplicate_configured_legacy_metadata_table() -> None: + store = _AsyncSessionStore(_Config({"metadata_table": "adk_metadata"})) + + statements = store._get_reset_drop_tables_sql() + + assert statements.count("DROP TABLE IF EXISTS adk_metadata") == 1 + assert store.metadata_table == "adk_metadata" + + +@pytest.mark.anyio +async def test_reset_migration_accepts_sync_session_store(monkeypatch: pytest.MonkeyPatch) -> None: + migration = importlib.import_module("sqlspec.extensions.adk.migrations.0002_reset_adk_tables") + context = type("Context", (), {"config": _Config()})() + + monkeypatch.setattr(migration, "_get_store_class", lambda _context: _SyncSessionStore) + monkeypatch.setattr(migration, "_get_memory_store_class", lambda _context: None) + + statements = await migration.up(context) + + assert "" in statements + + +def test_sync_memory_store_reset_drop_tables_uses_drop_sql() -> None: + store = _SyncMemoryStore(_Config({"memory_table": "agent_memory"})) + + assert store._get_reset_drop_memory_table_sql() == [ + "DROP TABLE IF EXISTS agent_memory", + "DROP TABLE IF EXISTS adk_memory", + "DROP TABLE IF EXISTS adk_memory_entries", + ] diff --git a/tests/unit/extensions/test_adk/test_store_instantiation.py b/tests/unit/extensions/test_adk/test_store_instantiation.py index 21a79a4fa..4ffff6ac4 100644 --- a/tests/unit/extensions/test_adk/test_store_instantiation.py +++ b/tests/unit/extensions/test_adk/test_store_instantiation.py @@ -65,6 +65,30 @@ ALL_STORE_CLASSES = SESSION_STORE_CLASSES + MEMORY_STORE_CLASSES +SYNC_SESSION_STORE_CLASSES = [ + "sqlspec.adapters.adbc.adk.AdbcADKStore", + "sqlspec.adapters.cockroach_psycopg.adk.CockroachPsycopgSyncADKStore", + "sqlspec.adapters.duckdb.adk.DuckdbADKStore", + "sqlspec.adapters.mysqlconnector.adk.MysqlConnectorSyncADKStore", + "sqlspec.adapters.oracledb.adk.OracleSyncADKStore", + "sqlspec.adapters.psycopg.adk.PsycopgSyncADKStore", + "sqlspec.adapters.pymysql.adk.PyMysqlADKStore", + "sqlspec.adapters.spanner.adk.SpannerSyncADKStore", + "sqlspec.adapters.sqlite.adk.SqliteADKStore", +] + +SYNC_MEMORY_STORE_CLASSES = [ + "sqlspec.adapters.adbc.adk.AdbcADKMemoryStore", + "sqlspec.adapters.cockroach_psycopg.adk.CockroachPsycopgSyncADKMemoryStore", + "sqlspec.adapters.duckdb.adk.DuckdbADKMemoryStore", + "sqlspec.adapters.mysqlconnector.adk.MysqlConnectorSyncADKMemoryStore", + "sqlspec.adapters.oracledb.adk.OracleSyncADKMemoryStore", + "sqlspec.adapters.psycopg.adk.PsycopgSyncADKMemoryStore", + "sqlspec.adapters.pymysql.adk.PyMysqlADKMemoryStore", + "sqlspec.adapters.spanner.adk.SpannerSyncADKMemoryStore", + "sqlspec.adapters.sqlite.adk.SqliteADKMemoryStore", +] + def _load_class(class_path: str) -> type: module_path, class_name = class_path.rsplit(".", 1) @@ -112,6 +136,28 @@ def test_store_method_signatures_match_base_contract(class_path: str) -> None: ) +@pytest.mark.parametrize("class_path", SYNC_SESSION_STORE_CLASSES) +def test_sync_session_store_contract_methods_are_sync(class_path: str) -> None: + """Sync-backed session stores expose sync methods, not async wrappers.""" + cls = _load_class(class_path) + + assert issubclass(cls, BaseSyncADKStore) + assert not issubclass(cls, BaseAsyncADKStore) + for method_name in BaseSyncADKStore.__abstractmethods__: + assert not inspect.iscoroutinefunction(getattr(cls, method_name)), f"{class_path}.{method_name} is async" + + +@pytest.mark.parametrize("class_path", SYNC_MEMORY_STORE_CLASSES) +def test_sync_memory_store_contract_methods_are_sync(class_path: str) -> None: + """Sync-backed memory stores expose sync methods, not async wrappers.""" + cls = _load_class(class_path) + + assert issubclass(cls, BaseSyncADKMemoryStore) + assert not issubclass(cls, BaseAsyncADKMemoryStore) + for method_name in BaseSyncADKMemoryStore.__abstractmethods__: + assert not inspect.iscoroutinefunction(getattr(cls, method_name)), f"{class_path}.{method_name} is async" + + def test_adk_store_registration_validator_resolves_sqlite_store_classes() -> None: """The migration registration validator resolves both SQLite ADK store classes.""" from sqlspec.adapters.sqlite import SqliteConfig From 09269e27b632bd590f9509aafc4ae842dc225f9d Mon Sep 17 00:00:00 2001 From: Cody Fincher Date: Sun, 14 Jun 2026 02:25:00 +0000 Subject: [PATCH 2/2] fix(adk): commit aiosqlite schema setup --- sqlspec/adapters/aiosqlite/adk/store.py | 2 ++ sqlspec/config.py | 4 ++-- .../adapters/aiosqlite/extensions/adk/test_store.py | 2 +- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/sqlspec/adapters/aiosqlite/adk/store.py b/sqlspec/adapters/aiosqlite/adk/store.py index 5216b6b1e..4ffb89fe1 100644 --- a/sqlspec/adapters/aiosqlite/adk/store.py +++ b/sqlspec/adapters/aiosqlite/adk/store.py @@ -58,6 +58,7 @@ async def create_tables(self) -> None: await driver.execute_script(await self._get_create_user_states_table_sql()) await driver.execute_script(await self._get_create_metadata_table_sql()) await driver.execute_script(await self._get_seed_metadata_sql()) + await driver.commit() async def create_session( self, session_id: str, app_name: str, user_id: str, state: "dict[str, Any]", owner_id: "Any | None" = None @@ -720,6 +721,7 @@ async def create_tables(self) -> None: async with self._config.provide_session() as driver: await driver.execute_script(await self._get_create_memory_table_sql()) + await driver.commit() async def insert_memory_entries(self, entries: "list[MemoryRecord]", owner_id: "object | None" = None) -> int: """Bulk insert memory entries with deduplication. diff --git a/sqlspec/config.py b/sqlspec/config.py index b51ae59cf..500472d2a 100644 --- a/sqlspec/config.py +++ b/sqlspec/config.py @@ -606,10 +606,10 @@ class ADKConfig(TypedDict): """ session_table: NotRequired[str] - """Name of the sessions table. Default: 'adk_sessions'""" + """Name of the sessions table. Default: 'adk_session'""" events_table: NotRequired[str] - """Name of the events table. Default: 'adk_events'""" + """Name of the events table. Default: 'adk_event'""" memory_table: NotRequired[str] """Name of the memory entries table. Default: 'adk_memory_entries'""" diff --git a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py index 8f64a3676..31d8af571 100644 --- a/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py +++ b/tests/integration/adapters/aiosqlite/extensions/adk/test_store.py @@ -37,7 +37,7 @@ async def test_aiosqlite_session_owner_column_is_created_when_configured(tmp_pat await store.create_session("session-owner", "app", "user", {}, owner_id="tenant-1") async with config.provide_connection() as conn: - cursor = await conn.execute("SELECT owner_id FROM adk_sessions WHERE id = ?", ("session-owner",)) + cursor = await conn.execute("SELECT owner_id FROM adk_session WHERE id = ?", ("session-owner",)) row = await cursor.fetchone() assert row == ("tenant-1",)