Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 1 addition & 8 deletions slowapi/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,13 +326,6 @@ def emit(*_):
self._fallback_storage = MemoryStorage()
self._fallback_limiter = STRATEGIES[strategy](self._fallback_storage)

def slowapi_startup(self) -> None:
"""
Starlette startup event handler that links the app with the Limiter instance.
"""
app.state.limiter = self # type: ignore
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) # type: ignore

def get_app_config(self, key: str, default_value: T = None) -> T:
"""
Place holder until we find a better way to load config from app
Expand Down Expand Up @@ -486,7 +479,7 @@ def __evaluate_limits(
failed_limit = None
limit_for_header = None
for lim in limits:
limit_scope = lim.scope or endpoint
limit_scope = lim.scope_for(request) or endpoint
if lim.is_exempt(request):
continue
if lim.methods is not None and request.method.lower() not in lim.methods:
Expand Down
14 changes: 4 additions & 10 deletions slowapi/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,18 +51,12 @@ def is_exempt(self, request: Optional[Request] = None) -> bool:
return self.exempt_when(request)
return self.exempt_when()

@property
def scope(self) -> str:
# flack.request.endpoint is the name of the function for the endpoint
# FIXME: how to get the request here?
def scope_for(self, request: Request) -> str:
if self.__scope is None:
return ""
else:
return (
self.__scope(request.endpoint) # type: ignore
if callable(self.__scope)
else self.__scope
)
if callable(self.__scope):
return self.__scope(request)
return self.__scope


class LimitGroup(object):
Expand Down
30 changes: 30 additions & 0 deletions tests/test_starlette_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,36 @@ def t2(request: Request):
# the shared limit has already been hit via t1
assert client.get("/t2").status_code == 429

def test_shared_decorator_callable_scope(self, build_starlette_app):
"""Callable scope receives the request and buckets are keyed by its return value."""
app, limiter = build_starlette_app(key_func=get_ipaddr)

def scope_from_tenant(request: Request) -> str:
return request.headers.get("X-Tenant", "default")

shared_lim = limiter.shared_limit("5/minute", scope=scope_from_tenant)

@shared_lim
def t1(request: Request):
return PlainTextResponse("test")

@shared_lim
def t2(request: Request):
return PlainTextResponse("test")

app.add_route("/t1", t1)
app.add_route("/t2", t2)

client = TestClient(app)
# tenant A burns its budget on /t1 ...
for i in range(10):
resp = client.get("/t1", headers={"X-Tenant": "A"})
assert resp.status_code == (200 if i < 5 else 429)
# ... and /t2 is also exhausted for tenant A (shared scope)
assert client.get("/t2", headers={"X-Tenant": "A"}).status_code == 429
# but tenant B has its own bucket — scope callable isolates them
assert client.get("/t2", headers={"X-Tenant": "B"}).status_code == 200

def test_multiple_decorators(self, build_starlette_app):
app, limiter = build_starlette_app(key_func=get_ipaddr)

Expand Down