diff --git a/openeo_driver/users/auth.py b/openeo_driver/users/auth.py index d1606003..303723b4 100644 --- a/openeo_driver/users/auth.py +++ b/openeo_driver/users/auth.py @@ -50,7 +50,7 @@ def __init__( # TODO: handle `oidc_providers` and `user_access_validation` through OpenEoBackendConfig self._oidc_providers: Dict[str, OidcProvider] = {p.id: p for p in oidc_providers} self._user_access_validation = user_access_validation - self._cache = TtlCache(default_ttl=10 * 60) + self._cache = TtlCache(default_ttl=15 * 60) def public(self, f: Callable): """ @@ -112,7 +112,7 @@ def get_user_from_bearer_token(self, request: flask.Request) -> User: cache_key = ("bearer", bearer) if not self._cache.contains(cache_key): user = self._get_user_from_bearer_token(bearer=bearer) - self._cache.set(cache_key, value=user, ttl=15 * 60) + self._cache.set(cache_key, value=user) return self._cache.get(cache_key) def _get_user_from_bearer_token(self, bearer: str) -> User: @@ -210,7 +210,6 @@ def _get_oidc_provider_config(self, oidc_provider: OidcProvider) -> dict: callback=lambda: self._oidc_provider_request( oidc_provider.discovery_url ).json(), - ttl=15 * 60, ) def resolve_oidc_access_token(self, oidc_provider: OidcProvider, access_token: str) -> User: diff --git a/openeo_driver/util/caching.py b/openeo_driver/util/caching.py index 2c3ed2b8..c5d490e2 100644 --- a/openeo_driver/util/caching.py +++ b/openeo_driver/util/caching.py @@ -1,7 +1,11 @@ import functools import logging +import threading import time -from typing import Union, Tuple, Any, Callable, Dict, Optional +import warnings +from typing import Union, Tuple, Any, Callable, Optional + +import cachetools _log = logging.getLogger(__name__) @@ -11,62 +15,88 @@ class TtlCache: """ - Simple dictionary based, in-memory key-value cache with expiry. + In-memory key-value cache with TTL expiry and a maximum size, backed by + :class:`cachetools.TTLCache`. When the cache is full, the least-recently-used + item is evicted to make room for new entries. + + Cache interactions are thread-safe. The lock is intentionally *not* held while + a cache-miss callback is executing, so slow callbacks do not block other readers. """ def __init__( - self, default_ttl: float = 60, _clock: Callable[[], float] = time.time + self, + default_ttl: float = 60, + *, + max_size: int = 1000, + _clock: Callable[[], float] = time.time, ): - self._cache: Dict[CacheKey, Tuple[Any, float]] = {} self.default_ttl = default_ttl - self._clock = _clock + self._cache: cachetools.TTLCache = cachetools.TTLCache(maxsize=max_size, ttl=default_ttl, timer=_clock) + self._lock = threading.Lock() def set(self, key: CacheKey, value: Any, ttl: Optional[float] = None) -> None: """Store item in cache""" - self._cache[key] = (value, self._clock() + (ttl or self.default_ttl)) + if ttl is not None: + warnings.warn( + "Per-item ttl is deprecated and will be ignored; use default_ttl on the cache instead.", + DeprecationWarning, + stacklevel=2, + ) + with self._lock: + self._cache[key] = value def contains(self, key: CacheKey) -> bool: - """Check whether cache contains item under given key""" - if key in self._cache: - value, expiration = self._cache[key] - if self._clock() <= expiration: - return True - del self._cache[key] - return False + """Check whether cache contains a non-expired item under the given key.""" + with self._lock: + return key in self._cache def get(self, key: CacheKey, default=None) -> Any: - """Get item from cache and if not available: return default value.""" - # TODO: raise KeyError on cache miss? - return self._cache[key][0] if self.contains(key) else default + """Get item from cache; return *default* on a cache miss or expiry.""" + with self._lock: + return self._cache.get(key, default) def get_or_call( self, key: CacheKey, callback: Callable[[], Any], ttl: Optional[float] = None ) -> Any: """ - Try to get item from cache. If not available or expired: call callback to build it and store result in cache. + Return the cached value for *key*, or call *callback* to build it on a miss. - This method allows to implement typicall cache usage pattern in a single call: + The lock is held only during cache look-up and result storage, **not** while + *callback* is executing. This means two concurrent callers may both experience + a cache miss and both invoke the callback simultaneously; the last one to finish + wins the store. This is intentional — it avoids blocking other callers during + potentially slow work. + + This method allows implementing the typical cache usage pattern in a single call:: item = cache.get_or_call( key="foo", - callback=lambda: expensive_operation(iterations=10000) + callback=lambda: expensive_operation(iterations=10000), ) - :param key: key to store item at (can be a simple string, - or something more complex like a tuple of strings/ints) - :param callback: item builder to call when item is not in cache or expired - :param ttl: optionally override default TTL - :return: item (from cache or freshly built) + :param key: cache key (a string or a tuple of strings/ints) + :param callback: callable invoked on a cache miss to produce the value + :param ttl: deprecated; has no effect and will emit a :class:`DeprecationWarning` + :return: the cached or freshly produced value """ - if self.contains(key): - value = self._cache[key][0] - else: - value = callback() - self.set(key=key, value=value, ttl=ttl) + if ttl is not None: + warnings.warn( + "Per-item ttl is deprecated and will be ignored; use default_ttl on the cache instead.", + DeprecationWarning, + stacklevel=2, + ) + with self._lock: + if key in self._cache: + return self._cache[key] + # Lock intentionally released before calling the callback. + value = callback() + with self._lock: + self._cache[key] = value return value - def flush(self): - self._cache = {} + def flush(self) -> None: + with self._lock: + self._cache.clear() def lru_cache_if_simple_args(func=None, *, maxsize: int = 128): diff --git a/setup.py b/setup.py index 5d5423fc..a6ef418f 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,7 @@ "markdown>3.4", "pystac>=1.8.0", # TODO #370/#396 require more recent pystac version once py3.8 support is dropped "antimeridian>=0.3.8", # 0.3.8 is the highest version that still supports Python 3.8 + "cachetools>=5.0", ], extras_require={ "dev": tests_require + typing_require, diff --git a/tests/conftest.py b/tests/conftest.py index b52aa354..d1554000 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,12 +3,14 @@ import logging import os import time +import typing from typing import Optional from unittest import mock import flask import pytest import pythonjsonlogger.jsonlogger +from openeo.testing.util import Sleeper import openeo_driver.config.load import openeo_driver.dummy.dummy_config @@ -144,3 +146,15 @@ def swift_url(monkeypatch): For real environments this is used as the fallback endpoint for doing S3 requests. """ monkeypatch.setenv("SWIFT_URL", TEST_SWIFT_URL) + + +@pytest.fixture +def fast_sleep(time_machine) -> typing.Iterator[Sleeper]: + with Sleeper().patch(time_machine=time_machine) as sleeper: + yield sleeper + + +@pytest.fixture +def simple_time(time_machine, fast_sleep): + """Fixture to set up simple time stepping (and fast sleep) with `time_machine`.""" + time_machine.move_to(1000, tick=False) diff --git a/tests/users/test_auth.py b/tests/users/test_auth.py index 7daa10d5..9d4eaa95 100644 --- a/tests/users/test_auth.py +++ b/tests/users/test_auth.py @@ -1,3 +1,5 @@ +import time + import flask import json import pytest @@ -361,7 +363,7 @@ def userinfo(request, context): assert internal_auth_data["access_token"] == oidc_access_token -def test_bearer_auth_oidc_caching(app, requests_mock, oidc_provider): +def test_bearer_auth_oidc_caching(app, requests_mock, oidc_provider, simple_time): def userinfo(request, context): """Fake OIDC /userinfo endpoint handler""" _, _, token = request.headers["Authorization"].partition("Bearer ") @@ -370,16 +372,11 @@ def userinfo(request, context): userinfo = requests_mock.get(oidc_provider.issuer + "/userinfo", text=userinfo) - def set_time(time): - # TODO reusable time mocking - app.config["auth_handler"]._cache._clock = lambda: time - with app.test_client() as client: # Note: user id is "hidden" in access token headers = {"Authorization": f"Bearer oidc/{oidc_provider.id}/rmxje3uhs.oidcuser.o94h4oe"} headers_other = {"Authorization": f"Bearer oidc/{oidc_provider.id}/trwe35.otheruser.fg34fsf"} - set_time(10) resp = client.get("/personal/hello", headers=headers) assert (resp.status_code, resp.data) == (200, b"hello oidcuser") assert userinfo.call_count == 1 @@ -387,41 +384,36 @@ def set_time(time): assert (resp.status_code, resp.data) == (200, b"hello otheruser") assert userinfo.call_count == 2 - set_time(100) + time.sleep(100) resp = client.get("/personal/hello", headers=headers) assert (resp.status_code, resp.data) == (200, b"hello oidcuser") assert userinfo.call_count == 2 - set_time(10000) + time.sleep(10000) resp = client.get("/personal/hello", headers=headers) assert (resp.status_code, resp.data) == (200, b"hello oidcuser") assert userinfo.call_count == 3 -def test_userinfo_url_caching(app, requests_mock, oidc_provider): +def test_userinfo_url_caching(app, requests_mock, oidc_provider, simple_time): oidc_discovery_url = oidc_provider.issuer + "/.well-known/openid-configuration" oidc_userinfo_url = oidc_provider.issuer + "/userinfo" discovery_mock = requests_mock.get(oidc_discovery_url, json={"userinfo_endpoint": oidc_userinfo_url}) requests_mock.get(oidc_provider.issuer + "/userinfo", json={"sub": "foo"}) - def set_time(time): - # TODO reusable time mocking - app.config["auth_handler"]._cache._clock = lambda: time - with app.test_client() as client: assert discovery_mock.call_count == 0 - set_time(10) resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/dfergef"}) assert resp.status_code == 200 assert discovery_mock.call_count == 1 - set_time(60) + time.sleep(60) resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/ftreyer"}) assert resp.status_code == 200 assert discovery_mock.call_count == 1 - set_time(30 * 60) + time.sleep(30 * 60) resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/th56te"}) assert resp.status_code == 200 assert discovery_mock.call_count == 2 diff --git a/tests/util/test_caching.py b/tests/util/test_caching.py index 64bff5a3..cb7a0e6b 100644 --- a/tests/util/test_caching.py +++ b/tests/util/test_caching.py @@ -35,25 +35,42 @@ def test_default_ttl(self): cache.set("foo", "bar") clock.set(105) assert cache.get("foo") == "bar" - clock.set(110) + clock.set(109) assert cache.contains("foo") assert cache.get("foo") == "bar" clock.set(115) assert not cache.contains("foo") assert cache.get("foo") is None - def test_item_ttl(self): + def test_item_ttl_deprecated(self): + """Per-item ttl is deprecated: the argument is accepted but warns and is ignored.""" clock = FakeClock() cache = TtlCache(default_ttl=10, _clock=clock) clock.set(100) - cache.set("foo", "bar", ttl=20) - clock.set(115) + with pytest.warns(DeprecationWarning, match="Per-item ttl is deprecated"): + cache.set("foo", "bar", ttl=20) + # Item is still stored and retrievable within the default TTL. + clock.set(105) assert cache.contains("foo") assert cache.get("foo") == "bar" - clock.set(125) + # After default_ttl the item expires (per-item ttl=20 is ignored). + clock.set(115) assert not cache.contains("foo") assert cache.get("foo") is None + def test_max_size(self): + cache = TtlCache(default_ttl=60, max_size=3) + cache.set("a", 1) + cache.set("b", 2) + cache.set("c", 3) + assert cache.contains("a") + assert cache.contains("b") + assert cache.contains("c") + # Adding a fourth item evicts the least-recently-used entry. + cache.set("d", 4) + assert cache.contains("d") + assert sum(cache.contains(k) for k in ("a", "b", "c", "d")) == 3 + def test_get_or_call(self): def calculate(_state={"x": 0}): _state["x"] += 1 @@ -70,6 +87,12 @@ def calculate(_state={"x": 0}): clock.set(140) assert cache.get_or_call("foo", callback=calculate) == 3 + def test_get_or_call_ttl_deprecated(self): + cache = TtlCache(default_ttl=10) + with pytest.warns(DeprecationWarning, match="Per-item ttl is deprecated"): + result = cache.get_or_call("foo", callback=lambda: 42, ttl=5) + assert result == 42 + def test_get_or_call_error(self): def calculate(): return 4 / 0