-
Notifications
You must be signed in to change notification settings - Fork 8
feature: Bounded TTL Caching #508
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,11 @@ | ||
| import functools | ||
| import logging | ||
| import threading | ||
| import time | ||
| from typing import Union, Tuple, Any, Callable, Dict, Optional | ||
| import warnings | ||
| from typing import Union, Tuple, Any, Callable, Optional | ||
|
|
||
| import cachetools | ||
|
|
||
| _log = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -11,62 +15,88 @@ | |
|
|
||
| class TtlCache: | ||
| """ | ||
| Simple dictionary based, in-memory key-value cache with expiry. | ||
| In-memory key-value cache with TTL expiry and a maximum size, backed by | ||
| :class:`cachetools.TTLCache`. When the cache is full, the least-recently-used | ||
| item is evicted to make room for new entries. | ||
|
|
||
| Cache interactions are thread-safe. The lock is intentionally *not* held while | ||
| a cache-miss callback is executing, so slow callbacks do not block other readers. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, default_ttl: float = 60, _clock: Callable[[], float] = time.time | ||
| self, | ||
| default_ttl: float = 60, | ||
| *, | ||
| max_size: int = 1000, | ||
| _clock: Callable[[], float] = time.time, | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Given that the tests now use time_machine to time mocking, is it still necessary to have this |
||
| ): | ||
| self._cache: Dict[CacheKey, Tuple[Any, float]] = {} | ||
| self.default_ttl = default_ttl | ||
| self._clock = _clock | ||
| self._cache: cachetools.TTLCache = cachetools.TTLCache(maxsize=max_size, ttl=default_ttl, timer=_clock) | ||
| self._lock = threading.Lock() | ||
|
|
||
| def set(self, key: CacheKey, value: Any, ttl: Optional[float] = None) -> None: | ||
| """Store item in cache""" | ||
| self._cache[key] = (value, self._clock() + (ttl or self.default_ttl)) | ||
| if ttl is not None: | ||
| warnings.warn( | ||
| "Per-item ttl is deprecated and will be ignored; use default_ttl on the cache instead.", | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| with self._lock: | ||
| self._cache[key] = value | ||
|
|
||
| def contains(self, key: CacheKey) -> bool: | ||
| """Check whether cache contains item under given key""" | ||
| if key in self._cache: | ||
| value, expiration = self._cache[key] | ||
| if self._clock() <= expiration: | ||
| return True | ||
| del self._cache[key] | ||
| return False | ||
| """Check whether cache contains a non-expired item under the given key.""" | ||
| with self._lock: | ||
| return key in self._cache | ||
|
|
||
| def get(self, key: CacheKey, default=None) -> Any: | ||
| """Get item from cache and if not available: return default value.""" | ||
| # TODO: raise KeyError on cache miss? | ||
| return self._cache[key][0] if self.contains(key) else default | ||
| """Get item from cache; return *default* on a cache miss or expiry.""" | ||
| with self._lock: | ||
| return self._cache.get(key, default) | ||
|
|
||
| def get_or_call( | ||
| self, key: CacheKey, callback: Callable[[], Any], ttl: Optional[float] = None | ||
| ) -> Any: | ||
| """ | ||
| Try to get item from cache. If not available or expired: call callback to build it and store result in cache. | ||
| Return the cached value for *key*, or call *callback* to build it on a miss. | ||
|
|
||
| This method allows to implement typicall cache usage pattern in a single call: | ||
| The lock is held only during cache look-up and result storage, **not** while | ||
| *callback* is executing. This means two concurrent callers may both experience | ||
| a cache miss and both invoke the callback simultaneously; the last one to finish | ||
| wins the store. This is intentional — it avoids blocking other callers during | ||
| potentially slow work. | ||
|
|
||
| This method allows implementing the typical cache usage pattern in a single call:: | ||
|
|
||
| item = cache.get_or_call( | ||
| key="foo", | ||
| callback=lambda: expensive_operation(iterations=10000) | ||
| callback=lambda: expensive_operation(iterations=10000), | ||
| ) | ||
|
|
||
| :param key: key to store item at (can be a simple string, | ||
| or something more complex like a tuple of strings/ints) | ||
| :param callback: item builder to call when item is not in cache or expired | ||
| :param ttl: optionally override default TTL | ||
| :return: item (from cache or freshly built) | ||
| :param key: cache key (a string or a tuple of strings/ints) | ||
| :param callback: callable invoked on a cache miss to produce the value | ||
| :param ttl: deprecated; has no effect and will emit a :class:`DeprecationWarning` | ||
| :return: the cached or freshly produced value | ||
| """ | ||
| if self.contains(key): | ||
| value = self._cache[key][0] | ||
| else: | ||
| value = callback() | ||
| self.set(key=key, value=value, ttl=ttl) | ||
| if ttl is not None: | ||
| warnings.warn( | ||
| "Per-item ttl is deprecated and will be ignored; use default_ttl on the cache instead.", | ||
| DeprecationWarning, | ||
| stacklevel=2, | ||
| ) | ||
| with self._lock: | ||
| if key in self._cache: | ||
| return self._cache[key] | ||
| # Lock intentionally released before calling the callback. | ||
| value = callback() | ||
| with self._lock: | ||
| self._cache[key] = value | ||
| return value | ||
|
|
||
| def flush(self): | ||
| self._cache = {} | ||
| def flush(self) -> None: | ||
| with self._lock: | ||
| self._cache.clear() | ||
|
|
||
|
|
||
| def lru_cache_if_simple_args(func=None, *, maxsize: int = 128): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,3 +1,5 @@ | ||
| import time | ||
|
|
||
| import flask | ||
| import json | ||
| import pytest | ||
|
|
@@ -361,7 +363,7 @@ def userinfo(request, context): | |
| assert internal_auth_data["access_token"] == oidc_access_token | ||
|
|
||
|
|
||
| def test_bearer_auth_oidc_caching(app, requests_mock, oidc_provider): | ||
| def test_bearer_auth_oidc_caching(app, requests_mock, oidc_provider, simple_time): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think using the "fast_sleep" fixture here would be more to the point and more self-explanatory
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did is to be consistent with OpenEO Python client. The problem with fast_sleep is that it does not set a reference time. So simple_time does |
||
| def userinfo(request, context): | ||
| """Fake OIDC /userinfo endpoint handler""" | ||
| _, _, token = request.headers["Authorization"].partition("Bearer ") | ||
|
|
@@ -370,58 +372,48 @@ def userinfo(request, context): | |
|
|
||
| userinfo = requests_mock.get(oidc_provider.issuer + "/userinfo", text=userinfo) | ||
|
|
||
| def set_time(time): | ||
| # TODO reusable time mocking | ||
| app.config["auth_handler"]._cache._clock = lambda: time | ||
|
|
||
| with app.test_client() as client: | ||
| # Note: user id is "hidden" in access token | ||
| headers = {"Authorization": f"Bearer oidc/{oidc_provider.id}/rmxje3uhs.oidcuser.o94h4oe"} | ||
| headers_other = {"Authorization": f"Bearer oidc/{oidc_provider.id}/trwe35.otheruser.fg34fsf"} | ||
|
|
||
| set_time(10) | ||
| resp = client.get("/personal/hello", headers=headers) | ||
| assert (resp.status_code, resp.data) == (200, b"hello oidcuser") | ||
| assert userinfo.call_count == 1 | ||
| resp = client.get("/personal/hello", headers=headers_other) | ||
| assert (resp.status_code, resp.data) == (200, b"hello otheruser") | ||
| assert userinfo.call_count == 2 | ||
|
|
||
| set_time(100) | ||
| time.sleep(100) | ||
| resp = client.get("/personal/hello", headers=headers) | ||
| assert (resp.status_code, resp.data) == (200, b"hello oidcuser") | ||
| assert userinfo.call_count == 2 | ||
|
|
||
| set_time(10000) | ||
| time.sleep(10000) | ||
| resp = client.get("/personal/hello", headers=headers) | ||
| assert (resp.status_code, resp.data) == (200, b"hello oidcuser") | ||
| assert userinfo.call_count == 3 | ||
|
|
||
|
|
||
| def test_userinfo_url_caching(app, requests_mock, oidc_provider): | ||
| def test_userinfo_url_caching(app, requests_mock, oidc_provider, simple_time): | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same note about fast_sleep |
||
| oidc_discovery_url = oidc_provider.issuer + "/.well-known/openid-configuration" | ||
| oidc_userinfo_url = oidc_provider.issuer + "/userinfo" | ||
| discovery_mock = requests_mock.get(oidc_discovery_url, json={"userinfo_endpoint": oidc_userinfo_url}) | ||
| requests_mock.get(oidc_provider.issuer + "/userinfo", json={"sub": "foo"}) | ||
|
|
||
| def set_time(time): | ||
| # TODO reusable time mocking | ||
| app.config["auth_handler"]._cache._clock = lambda: time | ||
|
|
||
| with app.test_client() as client: | ||
| assert discovery_mock.call_count == 0 | ||
|
|
||
| set_time(10) | ||
| resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/dfergef"}) | ||
| assert resp.status_code == 200 | ||
| assert discovery_mock.call_count == 1 | ||
|
|
||
| set_time(60) | ||
| time.sleep(60) | ||
| resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/ftreyer"}) | ||
| assert resp.status_code == 200 | ||
| assert discovery_mock.call_count == 1 | ||
|
|
||
| set_time(30 * 60) | ||
| time.sleep(30 * 60) | ||
| resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/th56te"}) | ||
| assert resp.status_code == 200 | ||
| assert discovery_mock.call_count == 2 | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also make default_ttl a keyword-only argument
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And it also doesn't make sense anymore to call it
default_ttl: should just bettl