Skip to content

Commit 76b9f0b

Browse files
GWealecopybara-github
authored andcommitted
fix(cache): enforce CacheMetadata active-state invariant
Co-authored-by: George Weale <gweale@google.com> PiperOrigin-RevId: 914400439
1 parent 8dd9147 commit 76b9f0b

3 files changed

Lines changed: 43 additions & 6 deletions

File tree

src/google/adk/models/cache_metadata.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
from pydantic import BaseModel
2121
from pydantic import ConfigDict
2222
from pydantic import Field
23+
from pydantic import model_validator
2324

2425

2526
class CacheMetadata(BaseModel):
@@ -97,6 +98,16 @@ class CacheMetadata(BaseModel):
9798
),
9899
)
99100

101+
@model_validator(mode="after")
102+
def _enforce_active_state_invariant(self) -> "CacheMetadata":
103+
active = (self.cache_name, self.expire_time, self.invocations_used)
104+
if len({f is not None for f in active}) > 1:
105+
raise ValueError(
106+
"cache_name, expire_time, and invocations_used must all be set "
107+
"(active cache) or all be None (fingerprint-only state)"
108+
)
109+
return self
110+
100111
@property
101112
def expire_soon(self) -> bool:
102113
"""Check if the cache will expire soon (with 2-minute buffer)."""
@@ -113,12 +124,6 @@ def __str__(self) -> str:
113124
f"fingerprint={self.fingerprint[:8]}..."
114125
)
115126
cache_id = self.cache_name.split("/")[-1]
116-
if self.expire_time is None:
117-
return (
118-
f"Cache {cache_id}: used {self.invocations_used} invocations, "
119-
f"cached {self.contents_count} contents, "
120-
"expires unknown"
121-
)
122127
time_until_expiry_minutes = (self.expire_time - time.time()) / 60
123128
return (
124129
f"Cache {cache_id}: used {self.invocations_used} invocations, "

tests/unittests/models/test_cache_metadata.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,30 @@ def test_missing_required_fields(self):
317317
assert metadata.expire_time is None
318318
assert metadata.invocations_used is None
319319
assert metadata.created_at is None
320+
321+
def test_partial_active_state_rejected(self):
322+
"""cache_name, expire_time, invocations_used must all be set or all None."""
323+
# Only cache_name set.
324+
with pytest.raises(ValidationError, match="must all be set"):
325+
CacheMetadata(
326+
cache_name="projects/123/locations/us-central1/cachedContents/x",
327+
fingerprint="abc",
328+
contents_count=1,
329+
)
330+
331+
# cache_name + expire_time but no invocations_used.
332+
with pytest.raises(ValidationError, match="must all be set"):
333+
CacheMetadata(
334+
cache_name="projects/123/locations/us-central1/cachedContents/x",
335+
expire_time=time.time() + 1800,
336+
fingerprint="abc",
337+
contents_count=1,
338+
)
339+
340+
# invocations_used set without cache_name (e.g. construction bug).
341+
with pytest.raises(ValidationError, match="must all be set"):
342+
CacheMetadata(
343+
fingerprint="abc",
344+
invocations_used=3,
345+
contents_count=1,
346+
)

tests/unittests/sessions/test_vertex_ai_session_service.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1020,7 +1020,12 @@ async def test_append_event():
10201020
),
10211021
cache_metadata=CacheMetadata(
10221022
cache_name='test_cache_name',
1023+
expire_time=(
1024+
datetime.datetime.now(datetime.timezone.utc)
1025+
+ datetime.timedelta(minutes=30)
1026+
).timestamp(),
10231027
fingerprint='test_fingerprint',
1028+
invocations_used=1,
10241029
contents_count=1,
10251030
),
10261031
citation_metadata=genai_types.CitationMetadata(

0 commit comments

Comments
 (0)