diff --git a/astrbot/core/cron/manager.py b/astrbot/core/cron/manager.py index fde2ad5cd8..3e79519a6d 100644 --- a/astrbot/core/cron/manager.py +++ b/astrbot/core/cron/manager.py @@ -14,9 +14,11 @@ from astrbot.core.cron.events import CronMessageEvent from astrbot.core.db import BaseDatabase from astrbot.core.db.po import CronJob +from astrbot.core.pipeline.context_utils import call_event_hook from astrbot.core.platform.message_session import MessageSession from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.entites import ProviderRequest +from astrbot.core.star.star_handler import EventType from astrbot.core.utils.history_saver import persist_agent_history if TYPE_CHECKING: @@ -377,14 +379,31 @@ async def _woke_main_agent( self.ctx.get_llm_tool_manager().get_builtin_tool(SendMessageToUserTool) ) + await call_event_hook(cron_event, EventType.OnWaitingLLMRequestEvent) + result = await build_main_agent( - event=cron_event, plugin_context=self.ctx, config=config, req=req + event=cron_event, + plugin_context=self.ctx, + config=config, + req=req, + apply_reset=False, ) if not result: logger.error("Failed to build main agent for cron job.") return runner = result.agent_runner + req = result.provider_request + reset_coro = result.reset_coro + + if await call_event_hook(cron_event, EventType.OnLLMRequestEvent, req): + if reset_coro: + reset_coro.close() + return + + if reset_coro: + await reset_coro + async for _ in runner.step_until_done(30): # agent will send message to user via using tools pass diff --git a/tests/unit/test_cron_manager.py b/tests/unit/test_cron_manager.py index 9dd3fc34dc..ffb8e62a57 100644 --- a/tests/unit/test_cron_manager.py +++ b/tests/unit/test_cron_manager.py @@ -5,8 +5,12 @@ import pytest +from astrbot.core.astr_main_agent import MainAgentBuildResult from astrbot.core.cron.manager import CronJobManager, CronJobSchedulingError from astrbot.core.db.po import CronJob +from astrbot.core.platform.message_session import MessageSession +from astrbot.core.platform.message_type import MessageType +from astrbot.core.provider.entities import ProviderRequest @pytest.fixture @@ -503,3 +507,80 @@ def test_get_next_run_time_nonexistent(self, cron_manager): next_run = cron_manager._get_next_run_time("non-existent") assert next_run is None + + +class TestWokeMainAgent: + """Tests for cron-triggered main-agent wake flow.""" + + @pytest.mark.asyncio + async def test_woke_main_agent_calls_llm_request_hook_before_reset( + self, cron_manager, mock_context + ): + """Test cron path mirrors the normal request-hook timing.""" + + async def reset_stub(): + return None + + async def empty_steps(*args, **kwargs): + if False: + yield None + + cron_manager.ctx = mock_context + mock_context.get_config.return_value = { + "admins_id": [], + "provider_settings": {"tool_call_timeout": 120}, + } + mock_context.conversation_manager.get_curr_conversation_id = AsyncMock( + return_value="conv-id" + ) + mock_context.conversation_manager.get_conversation = AsyncMock( + return_value=MagicMock(history="[]", cid="conv-id", persona_id=None) + ) + mock_context.get_llm_tool_manager.return_value = MagicMock() + + req = ProviderRequest(prompt="scheduled") + runner = MagicMock() + runner.step_until_done = empty_steps + runner.get_final_llm_resp.return_value = None + build_result = MainAgentBuildResult( + agent_runner=runner, + provider_request=req, + provider=MagicMock(), + reset_coro=reset_stub(), + ) + + with ( + patch( + "astrbot.core.astr_main_agent.build_main_agent", + AsyncMock(return_value=build_result), + ) as mock_build_main_agent, + patch( + "astrbot.core.cron.manager.call_event_hook", + AsyncMock(side_effect=[False, True]), + create=True, + ) as mock_call_event_hook, + patch( + "astrbot.core.cron.manager.persist_agent_history", + AsyncMock(), + ), + ): + await cron_manager._woke_main_agent( + message="hello", + session_str=MessageSession( + platform_name="cron", + message_type=MessageType.OTHER_MESSAGE, + session_id="test-session", + ), + extras={"cron_job": {}, "cron_payload": {}}, + ) + + assert mock_build_main_agent.await_args.kwargs["apply_reset"] is False + assert ( + mock_call_event_hook.await_args_list[0].args[1].name + == "OnWaitingLLMRequestEvent" + ) + assert ( + mock_call_event_hook.await_args_list[1].args[1].name + == "OnLLMRequestEvent" + ) + assert mock_call_event_hook.await_args_list[1].args[2] is req