Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions openeo_driver/users/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
# TODO: handle `oidc_providers` and `user_access_validation` through OpenEoBackendConfig
self._oidc_providers: Dict[str, OidcProvider] = {p.id: p for p in oidc_providers}
self._user_access_validation = user_access_validation
self._cache = TtlCache(default_ttl=10 * 60)
self._cache = TtlCache(default_ttl=15 * 60)

def public(self, f: Callable):
"""
Expand Down Expand Up @@ -112,7 +112,7 @@ def get_user_from_bearer_token(self, request: flask.Request) -> User:
cache_key = ("bearer", bearer)
if not self._cache.contains(cache_key):
user = self._get_user_from_bearer_token(bearer=bearer)
self._cache.set(cache_key, value=user, ttl=15 * 60)
self._cache.set(cache_key, value=user)
return self._cache.get(cache_key)

def _get_user_from_bearer_token(self, bearer: str) -> User:
Expand Down Expand Up @@ -210,7 +210,6 @@ def _get_oidc_provider_config(self, oidc_provider: OidcProvider) -> dict:
callback=lambda: self._oidc_provider_request(
oidc_provider.discovery_url
).json(),
ttl=15 * 60,
)

def resolve_oidc_access_token(self, oidc_provider: OidcProvider, access_token: str) -> User:
Expand Down
92 changes: 61 additions & 31 deletions openeo_driver/util/caching.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import functools
import logging
import threading
import time
from typing import Union, Tuple, Any, Callable, Dict, Optional
import warnings
from typing import Union, Tuple, Any, Callable, Optional

import cachetools

_log = logging.getLogger(__name__)

Expand All @@ -11,62 +15,88 @@

class TtlCache:
"""
Simple dictionary based, in-memory key-value cache with expiry.
In-memory key-value cache with TTL expiry and a maximum size, backed by
:class:`cachetools.TTLCache`. When the cache is full, the least-recently-used
item is evicted to make room for new entries.

Cache interactions are thread-safe. The lock is intentionally *not* held while
a cache-miss callback is executing, so slow callbacks do not block other readers.
"""

def __init__(
self, default_ttl: float = 60, _clock: Callable[[], float] = time.time
self,
default_ttl: float = 60,
*,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also make `default_ttl` a keyword-only argument.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also no longer makes sense to call it `default_ttl`: it should just be `ttl`.

max_size: int = 1000,
_clock: Callable[[], float] = time.time,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the tests now use `time_machine` for time mocking, is it still necessary to have this `_clock` attribute (which is only there for testability)?

):
self._cache: Dict[CacheKey, Tuple[Any, float]] = {}
self.default_ttl = default_ttl
self._clock = _clock
self._cache: cachetools.TTLCache = cachetools.TTLCache(maxsize=max_size, ttl=default_ttl, timer=_clock)
self._lock = threading.Lock()

def set(self, key: CacheKey, value: Any, ttl: Optional[float] = None) -> None:
"""Store item in cache"""
self._cache[key] = (value, self._clock() + (ttl or self.default_ttl))
if ttl is not None:
warnings.warn(
"Per-item ttl is deprecated and will be ignored; use default_ttl on the cache instead.",
DeprecationWarning,
stacklevel=2,
)
with self._lock:
self._cache[key] = value

def contains(self, key: CacheKey) -> bool:
"""Check whether cache contains item under given key"""
if key in self._cache:
value, expiration = self._cache[key]
if self._clock() <= expiration:
return True
del self._cache[key]
return False
"""Check whether cache contains a non-expired item under the given key."""
with self._lock:
return key in self._cache

def get(self, key: CacheKey, default=None) -> Any:
"""Get item from cache and if not available: return default value."""
# TODO: raise KeyError on cache miss?
return self._cache[key][0] if self.contains(key) else default
"""Get item from cache; return *default* on a cache miss or expiry."""
with self._lock:
return self._cache.get(key, default)

def get_or_call(
self, key: CacheKey, callback: Callable[[], Any], ttl: Optional[float] = None
) -> Any:
"""
Try to get item from cache. If not available or expired: call callback to build it and store result in cache.
Return the cached value for *key*, or call *callback* to build it on a miss.

This method allows to implement typicall cache usage pattern in a single call:
The lock is held only during cache look-up and result storage, **not** while
*callback* is executing. This means two concurrent callers may both experience
a cache miss and both invoke the callback simultaneously; the last one to finish
wins the store. This is intentional — it avoids blocking other callers during
potentially slow work.

This method allows implementing the typical cache usage pattern in a single call::

item = cache.get_or_call(
key="foo",
callback=lambda: expensive_operation(iterations=10000)
callback=lambda: expensive_operation(iterations=10000),
)

:param key: key to store item at (can be a simple string,
or something more complex like a tuple of strings/ints)
:param callback: item builder to call when item is not in cache or expired
:param ttl: optionally override default TTL
:return: item (from cache or freshly built)
:param key: cache key (a string or a tuple of strings/ints)
:param callback: callable invoked on a cache miss to produce the value
:param ttl: deprecated; has no effect and will emit a :class:`DeprecationWarning`
:return: the cached or freshly produced value
"""
if self.contains(key):
value = self._cache[key][0]
else:
value = callback()
self.set(key=key, value=value, ttl=ttl)
if ttl is not None:
warnings.warn(
"Per-item ttl is deprecated and will be ignored; use default_ttl on the cache instead.",
DeprecationWarning,
stacklevel=2,
)
with self._lock:
if key in self._cache:
return self._cache[key]
# Lock intentionally released before calling the callback.
value = callback()
with self._lock:
self._cache[key] = value
return value

def flush(self):
self._cache = {}
def flush(self) -> None:
with self._lock:
self._cache.clear()


def lru_cache_if_simple_args(func=None, *, maxsize: int = 128):
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
"markdown>3.4",
"pystac>=1.8.0", # TODO #370/#396 require more recent pystac version once py3.8 support is dropped
"antimeridian>=0.3.8", # 0.3.8 is the highest version that still supports Python 3.8
"cachetools>=5.0",
],
extras_require={
"dev": tests_require + typing_require,
Expand Down
14 changes: 14 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import logging
import os
import time
import typing
from typing import Optional
from unittest import mock

import flask
import pytest
import pythonjsonlogger.jsonlogger
from openeo.testing.util import Sleeper

import openeo_driver.config.load
import openeo_driver.dummy.dummy_config
Expand Down Expand Up @@ -144,3 +146,15 @@ def swift_url(monkeypatch):
For real environments this is used as the fallback endpoint for doing S3 requests.
"""
monkeypatch.setenv("SWIFT_URL", TEST_SWIFT_URL)


@pytest.fixture
def fast_sleep(time_machine) -> typing.Iterator[Sleeper]:
with Sleeper().patch(time_machine=time_machine) as sleeper:
yield sleeper


@pytest.fixture
def simple_time(time_machine, fast_sleep):
"""Fixture to set up simple time stepping (and fast sleep) with `time_machine`."""
time_machine.move_to(1000, tick=False)
24 changes: 8 additions & 16 deletions tests/users/test_auth.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import time

import flask
import json
import pytest
Expand Down Expand Up @@ -361,7 +363,7 @@ def userinfo(request, context):
assert internal_auth_data["access_token"] == oidc_access_token


def test_bearer_auth_oidc_caching(app, requests_mock, oidc_provider):
def test_bearer_auth_oidc_caching(app, requests_mock, oidc_provider, simple_time):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think using the "fast_sleep" fixture here would be more to the point and more self-explanatory

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did this to be consistent with the OpenEO Python client. The problem with `fast_sleep` is that it does not set a reference time: `simple_time` calls `time_machine.move_to(1000, tick=False)` and then enables `fast_sleep`. If you use only the `fast_sleep` fixture, you still need to initialize `time_machine` before using it.

def userinfo(request, context):
"""Fake OIDC /userinfo endpoint handler"""
_, _, token = request.headers["Authorization"].partition("Bearer ")
Expand All @@ -370,58 +372,48 @@ def userinfo(request, context):

userinfo = requests_mock.get(oidc_provider.issuer + "/userinfo", text=userinfo)

def set_time(time):
# TODO reusable time mocking
app.config["auth_handler"]._cache._clock = lambda: time

with app.test_client() as client:
# Note: user id is "hidden" in access token
headers = {"Authorization": f"Bearer oidc/{oidc_provider.id}/rmxje3uhs.oidcuser.o94h4oe"}
headers_other = {"Authorization": f"Bearer oidc/{oidc_provider.id}/trwe35.otheruser.fg34fsf"}

set_time(10)
resp = client.get("/personal/hello", headers=headers)
assert (resp.status_code, resp.data) == (200, b"hello oidcuser")
assert userinfo.call_count == 1
resp = client.get("/personal/hello", headers=headers_other)
assert (resp.status_code, resp.data) == (200, b"hello otheruser")
assert userinfo.call_count == 2

set_time(100)
time.sleep(100)
resp = client.get("/personal/hello", headers=headers)
assert (resp.status_code, resp.data) == (200, b"hello oidcuser")
assert userinfo.call_count == 2

set_time(10000)
time.sleep(10000)
resp = client.get("/personal/hello", headers=headers)
assert (resp.status_code, resp.data) == (200, b"hello oidcuser")
assert userinfo.call_count == 3


def test_userinfo_url_caching(app, requests_mock, oidc_provider):
def test_userinfo_url_caching(app, requests_mock, oidc_provider, simple_time):
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same note about fast_sleep

oidc_discovery_url = oidc_provider.issuer + "/.well-known/openid-configuration"
oidc_userinfo_url = oidc_provider.issuer + "/userinfo"
discovery_mock = requests_mock.get(oidc_discovery_url, json={"userinfo_endpoint": oidc_userinfo_url})
requests_mock.get(oidc_provider.issuer + "/userinfo", json={"sub": "foo"})

def set_time(time):
# TODO reusable time mocking
app.config["auth_handler"]._cache._clock = lambda: time

with app.test_client() as client:
assert discovery_mock.call_count == 0

set_time(10)
resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/dfergef"})
assert resp.status_code == 200
assert discovery_mock.call_count == 1

set_time(60)
time.sleep(60)
resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/ftreyer"})
assert resp.status_code == 200
assert discovery_mock.call_count == 1

set_time(30 * 60)
time.sleep(30 * 60)
resp = client.get("/private/hello", headers={"Authorization": f"Bearer oidc/{oidc_provider.id}/th56te"})
assert resp.status_code == 200
assert discovery_mock.call_count == 2
Expand Down
33 changes: 28 additions & 5 deletions tests/util/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,25 +35,42 @@ def test_default_ttl(self):
cache.set("foo", "bar")
clock.set(105)
assert cache.get("foo") == "bar"
clock.set(110)
clock.set(109)
assert cache.contains("foo")
assert cache.get("foo") == "bar"
clock.set(115)
assert not cache.contains("foo")
assert cache.get("foo") is None

def test_item_ttl(self):
def test_item_ttl_deprecated(self):
"""Per-item ttl is deprecated: the argument is accepted but warns and is ignored."""
clock = FakeClock()
cache = TtlCache(default_ttl=10, _clock=clock)
clock.set(100)
cache.set("foo", "bar", ttl=20)
clock.set(115)
with pytest.warns(DeprecationWarning, match="Per-item ttl is deprecated"):
cache.set("foo", "bar", ttl=20)
# Item is still stored and retrievable within the default TTL.
clock.set(105)
assert cache.contains("foo")
assert cache.get("foo") == "bar"
clock.set(125)
# After default_ttl the item expires (per-item ttl=20 is ignored).
clock.set(115)
assert not cache.contains("foo")
assert cache.get("foo") is None

def test_max_size(self):
cache = TtlCache(default_ttl=60, max_size=3)
cache.set("a", 1)
cache.set("b", 2)
cache.set("c", 3)
assert cache.contains("a")
assert cache.contains("b")
assert cache.contains("c")
# Adding a fourth item evicts the least-recently-used entry.
cache.set("d", 4)
assert cache.contains("d")
assert sum(cache.contains(k) for k in ("a", "b", "c", "d")) == 3

def test_get_or_call(self):
def calculate(_state={"x": 0}):
_state["x"] += 1
Expand All @@ -70,6 +87,12 @@ def calculate(_state={"x": 0}):
clock.set(140)
assert cache.get_or_call("foo", callback=calculate) == 3

def test_get_or_call_ttl_deprecated(self):
cache = TtlCache(default_ttl=10)
with pytest.warns(DeprecationWarning, match="Per-item ttl is deprecated"):
result = cache.get_or_call("foo", callback=lambda: 42, ttl=5)
assert result == 42

def test_get_or_call_error(self):
def calculate():
return 4 / 0
Expand Down