Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 43 additions & 13 deletions src/agents/extensions/memory/dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,19 +323,49 @@ async def add_items(self, items: list[TResponseInputItem]) -> None:
continue
raise

# Update metadata
metadata = {
"session_id": self.session_id,
"created_at": str(int(time.time())),
"updated_at": str(int(time.time())),
}
await self._dapr_client.save_state(
store_name=self._state_store_name,
key=self._metadata_key,
value=json.dumps(metadata),
state_metadata=self._get_metadata(),
options=self._get_state_options(),
)
# Update metadata, preserving created_at across subsequent writes.
# Use first-write concurrency with the read ETag so a concurrent write
# that already established `created_at` can't be clobbered by a stale
# read that saw no metadata.
now = str(int(time.time()))
meta_attempt = 0
while True:
meta_attempt += 1
existing_meta_response = await self._dapr_client.get_state(
store_name=self._state_store_name,
key=self._metadata_key,
state_metadata=self._get_read_metadata(),
)
created_at = now
if existing_meta_response.data:
try:
existing_meta = json.loads(existing_meta_response.data.decode("utf-8"))
if isinstance(existing_meta, dict) and existing_meta.get("created_at"):
created_at = str(existing_meta["created_at"])
except (json.JSONDecodeError, UnicodeDecodeError, AttributeError):
# Corrupt metadata — start fresh with current timestamp.
pass
metadata = {
"session_id": self.session_id,
"created_at": created_at,
"updated_at": now,
}
meta_etag = getattr(existing_meta_response, "etag", None) or None
try:
await self._dapr_client.save_state(
store_name=self._state_store_name,
key=self._metadata_key,
value=json.dumps(metadata),
etag=meta_etag,
state_metadata=self._get_metadata(),
options=self._get_state_options(concurrency=Concurrency.first_write),
)
break
except Exception as error:
should_retry = await self._handle_concurrency_conflict(error, meta_attempt)
if should_retry:
continue
raise

async def pop_item(self) -> TResponseInputItem | None:
"""Remove and return the most recent item from the session.
Expand Down
26 changes: 26 additions & 0 deletions tests/extensions/memory/test_dapr_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,32 @@ async def test_add_empty_items_list(fake_dapr_client: FakeDaprClient):
await session.close()


async def test_metadata_preserves_created_at(fake_dapr_client: FakeDaprClient):
"""add_items must preserve created_at across writes; only updated_at advances."""
session = await _create_test_session(fake_dapr_client)
try:
await session.add_items([{"role": "user", "content": "first"}])
first_meta_raw = fake_dapr_client._state[session._metadata_key].decode("utf-8")
first_meta = json.loads(first_meta_raw)
first_created = first_meta["created_at"]
first_updated = first_meta["updated_at"]

# Wait one second so timestamps are guaranteed to differ.
import time as _time

_time.sleep(1)

await session.add_items([{"role": "user", "content": "second"}])
second_meta = json.loads(fake_dapr_client._state[session._metadata_key].decode("utf-8"))

assert second_meta["created_at"] == first_created, (
"created_at must be preserved across add_items calls"
)
assert int(second_meta["updated_at"]) >= int(first_updated)
finally:
await session.close()


async def test_unicode_content(fake_dapr_client: FakeDaprClient):
"""Test that session correctly stores and retrieves unicode/non-ASCII content."""
session = await _create_test_session(fake_dapr_client)
Expand Down