Skip to content

Commit c4edefb

Browse files
committed
feat: add return types to invoke_hook
1 parent 46935a8 commit c4edefb

1 file changed

Lines changed: 14 additions & 5 deletions

File tree

mellea/plugins/manager.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import logging
6-
from typing import TYPE_CHECKING, Any, Literal
6+
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
77

88
from mellea.plugins.base import MelleaBasePayload, PluginViolationError
99
from mellea.plugins.context import build_global_context
@@ -175,13 +175,17 @@ def deregister_session_plugins(session_id: str) -> None:
175175
logger.debug("Plugin %s already unregistered", name, exc_info=True)
176176

177177

178+
# Hooks return the same payload they received. Use this to accurately reflect that typing.
179+
_MelleaBasePayload = TypeVar("_MelleaBasePayload", bound=MelleaBasePayload)
180+
181+
178182
async def invoke_hook(
179183
hook_type: HookType,
180-
payload: MelleaBasePayload,
184+
payload: _MelleaBasePayload,
181185
*,
182186
backend: Backend | None = None,
183187
**context_fields: Any,
184-
) -> tuple[Any | None, MelleaBasePayload]:
188+
) -> tuple[Any | None, _MelleaBasePayload]:
185189
"""Invoke a hook if plugins are configured.
186190
187191
Returns ``(result, possibly-modified-payload)``.
@@ -241,7 +245,12 @@ async def invoke_hook(
241245
plugin_name=v.plugin_name or "",
242246
)
243247

244-
modified = (
245-
result.modified_payload if result and result.modified_payload else payload
248+
# `result` doesn't type the returned payload correctly.
249+
# If the modified payload exists, cast it as the correct type here,
250+
# else return the original payload.
251+
modified: _MelleaBasePayload = (
252+
cast(_MelleaBasePayload, result.modified_payload)
253+
if result and result.modified_payload
254+
else payload
246255
)
247256
return result, modified

0 commit comments

Comments
 (0)