|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import logging |
6 | | -from typing import TYPE_CHECKING, Any, Literal |
| 6 | +from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast |
7 | 7 |
|
8 | 8 | from mellea.plugins.base import MelleaBasePayload, PluginViolationError |
9 | 9 | from mellea.plugins.context import build_global_context |
@@ -175,13 +175,17 @@ def deregister_session_plugins(session_id: str) -> None: |
175 | 175 | logger.debug("Plugin %s already unregistered", name, exc_info=True) |
176 | 176 |
|
177 | 177 |
|
| 178 | +# Hooks return the same payload they received. Use this to accurately reflect that typing. |
| 179 | +_MelleaBasePayload = TypeVar("_MelleaBasePayload", bound=MelleaBasePayload) |
| 180 | + |
| 181 | + |
178 | 182 | async def invoke_hook( |
179 | 183 | hook_type: HookType, |
180 | | - payload: MelleaBasePayload, |
| 184 | + payload: _MelleaBasePayload, |
181 | 185 | *, |
182 | 186 | backend: Backend | None = None, |
183 | 187 | **context_fields: Any, |
184 | | -) -> tuple[Any | None, MelleaBasePayload]: |
| 188 | +) -> tuple[Any | None, _MelleaBasePayload]: |
185 | 189 | """Invoke a hook if plugins are configured. |
186 | 190 |
|
187 | 191 | Returns ``(result, possibly-modified-payload)``. |
@@ -241,7 +245,12 @@ async def invoke_hook( |
241 | 245 | plugin_name=v.plugin_name or "", |
242 | 246 | ) |
243 | 247 |
|
244 | | - modified = ( |
245 | | - result.modified_payload if result and result.modified_payload else payload |
| 248 | + # `result` doesn't type the returned payload correctly. |
| 249 | + # If the modified payload exists, cast it as the correct type here, |
| 250 | + # else return the original payload. |
| 251 | + modified: _MelleaBasePayload = ( |
| 252 | + cast(_MelleaBasePayload, result.modified_payload) |
| 253 | + if result and result.modified_payload |
| 254 | + else payload |
246 | 255 | ) |
247 | 256 | return result, modified |
0 commit comments