diff --git a/.gitignore b/.gitignore index 5eb9616c8c..e63b3eeca5 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,4 @@ GenieData/ .worktrees/ dashboard/bun.lock +docs/plans/ diff --git a/README_zh.md b/README_zh.md index 425719faba..9b265b48e7 100644 --- a/README_zh.md +++ b/README_zh.md @@ -47,9 +47,9 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、 3. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。 4. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。 5. 📦 插件扩展,已有 1000+ 个插件可一键安装。 -6. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) 隔离化环境,安全地执行任何代码、调用 Shell、会话级资源复用。 +6. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) 提供隔离沙盒环境,支持安全执行代码、调用 Shell,并在会话内复用资源。 7. 💻 WebUI 支持。 -8. 🌈 Web ChatUI 支持,ChatUI 内置代理沙盒、网页搜索等。 +8. 🌈 Web ChatUI 支持,内置 Agent Sandbox、网页搜索等能力。 9. 🌐 国际化(i18n)支持。 diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 3f74f0ec9b..4764c08480 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -146,6 +146,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): REPEATED_TOOL_NOTICE_L1_THRESHOLD = 3 REPEATED_TOOL_NOTICE_L2_THRESHOLD = 4 REPEATED_TOOL_NOTICE_L3_THRESHOLD = 5 + REPEATED_TOOL_NOTICE_EXEMPT_TOOL_NAMES = frozenset({"astrbot_execute_shell"}) REPEATED_TOOL_NOTICE_L1_TEMPLATE = ( "\n\n[SYSTEM NOTICE] By the way, you have executed the same tool " "`{tool_name}` {streak} times consecutively. Double-check whether another " @@ -667,6 +668,9 @@ def _track_tool_call_streak(self, tool_name: str) -> int: return self._same_tool_streak def _build_repeated_tool_call_guidance(self, tool_name: str, streak: int) -> str: + if tool_name in self.REPEATED_TOOL_NOTICE_EXEMPT_TOOL_NAMES: + return "" + if streak < self.REPEATED_TOOL_NOTICE_L1_THRESHOLD: return "" diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index de5caad554..6d512926e1 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -20,6 +20,7 @@ from astrbot.core.astr_main_agent_resources import ( BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT, ) +from astrbot.core.computer.sandbox_tool_binding import tool_available_in_runtime from astrbot.core.cron.events import CronMessageEvent from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( @@ -31,9 +32,6 @@ from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools from astrbot.core.tools.computer_tools import ( - CuaKeyboardTypeTool, - CuaMouseClickTool, - CuaScreenshotTool, ExecuteShellTool, FileDownloadTool, FileEditTool, @@ -43,8 +41,12 @@ GrepTool, LocalPythonTool, PythonTool, + SandboxLifecycleTool, + SandboxOperationTool, + SandboxQueryTool, ) from astrbot.core.tools.message_tools import SendMessageToUserTool +from astrbot.core.tools.registry import get_builtin_tool_config_rule from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.history_saver import persist_agent_history from astrbot.core.utils.image_ref_utils import is_supported_image_ref @@ -52,6 +54,14 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): + _runtime_computer_tools_cache: dict[ + tuple[int, str, str], dict[str, FunctionTool] + ] = {} + + @classmethod + def clear_runtime_computer_tools_cache(cls, provider_id: str | None = None) -> None: + cls._runtime_computer_tools_cache.clear() + @classmethod def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]: if image_urls_raw is None: @@ -192,8 +202,14 @@ def _get_runtime_computer_tools( booter: str | None = None, ) -> dict[str, FunctionTool]: booter = "" if booter is None else str(booter).lower() + cache_key = (id(tool_mgr), runtime, booter) + if cache_key in cls._runtime_computer_tools_cache: + return cls._runtime_computer_tools_cache[cache_key] if runtime == "sandbox": shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool) + sandbox_query_tool = tool_mgr.get_builtin_tool(SandboxQueryTool) + sandbox_lifecycle_tool = tool_mgr.get_builtin_tool(SandboxLifecycleTool) + sandbox_operation_tool = tool_mgr.get_builtin_tool(SandboxOperationTool) python_tool = tool_mgr.get_builtin_tool(PythonTool) upload_tool = tool_mgr.get_builtin_tool(FileUploadTool) download_tool = tool_mgr.get_builtin_tool(FileDownloadTool) @@ -203,6 +219,9 @@ def _get_runtime_computer_tools( grep_tool = tool_mgr.get_builtin_tool(GrepTool) tools = { shell_tool.name: shell_tool, + sandbox_query_tool.name: sandbox_query_tool, + sandbox_lifecycle_tool.name: sandbox_lifecycle_tool, + sandbox_operation_tool.name: sandbox_operation_tool, python_tool.name: python_tool, upload_tool.name: upload_tool, download_tool.name: download_tool, @@ -211,17 +230,7 @@ def _get_runtime_computer_tools( edit_tool.name: edit_tool, grep_tool.name: grep_tool, } - if booter == "cua": - screenshot_tool = tool_mgr.get_builtin_tool(CuaScreenshotTool) - mouse_click_tool = tool_mgr.get_builtin_tool(CuaMouseClickTool) - keyboard_type_tool = tool_mgr.get_builtin_tool(CuaKeyboardTypeTool) - tools.update( - { - screenshot_tool.name: screenshot_tool, - mouse_click_tool.name: mouse_click_tool, - keyboard_type_tool.name: keyboard_type_tool, - } - ) + cls._runtime_computer_tools_cache[cache_key] = tools return tools if runtime == "local": shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool) @@ -230,7 +239,7 @@ def _get_runtime_computer_tools( write_tool = tool_mgr.get_builtin_tool(FileWriteTool) edit_tool = tool_mgr.get_builtin_tool(FileEditTool) grep_tool = tool_mgr.get_builtin_tool(GrepTool) - return { + tools = { shell_tool.name: shell_tool, python_tool.name: python_tool, read_tool.name: read_tool, @@ -238,8 +247,29 @@ def _get_runtime_computer_tools( edit_tool.name: edit_tool, grep_tool.name: grep_tool, } + cls._runtime_computer_tools_cache[cache_key] = tools + return tools return {} + @staticmethod + def _tool_available_for_runtime_config(tool: FunctionTool, runtime: str) -> bool: + if not tool_available_in_runtime(tool, runtime): + return False + rule = get_builtin_tool_config_rule(tool.name) + if rule is None: + return True + conditions = rule.evaluate( + {"provider_settings": {"computer_use_runtime": runtime}} + ) + runtime_conditions = [ + condition + for condition in conditions + if str(condition.get("key")) == "provider_settings.computer_use_runtime" + ] + if not runtime_conditions: + return True + return all(bool(condition.get("matched")) for condition in runtime_conditions) + @classmethod def _build_handoff_toolset( cls, @@ -259,17 +289,18 @@ def _build_handoff_toolset( runtime_computer_tools = cls._get_runtime_computer_tools( runtime, tool_mgr, - provider_settings.get("sandbox", {}).get("booter"), ) # Keep persona semantics aligned with the main agent: tools=None means # "all tools", including runtime computer-use tools. if tools is None: toolset = ToolSet() - for registered_tool in llm_tools.func_list: + for registered_tool in getattr(tool_mgr, "func_list", llm_tools.func_list): if isinstance(registered_tool, HandoffTool): continue - if registered_tool.active: + if registered_tool.active and cls._tool_available_for_runtime_config( + registered_tool, runtime + ): toolset.add_tool(registered_tool) for runtime_tool in runtime_computer_tools.values(): toolset.add_tool(runtime_tool) @@ -281,14 +312,20 @@ def _build_handoff_toolset( toolset = ToolSet() for tool_name_or_obj in tools: if isinstance(tool_name_or_obj, str): - registered_tool = llm_tools.get_func(tool_name_or_obj) - if registered_tool and registered_tool.active: + registered_tool = tool_mgr.get_func(tool_name_or_obj) + if ( + registered_tool + and registered_tool.active + and cls._tool_available_for_runtime_config(registered_tool, runtime) + ): toolset.add_tool(registered_tool) continue runtime_tool = runtime_computer_tools.get(tool_name_or_obj) if runtime_tool: toolset.add_tool(runtime_tool) - elif isinstance(tool_name_or_obj, FunctionTool): + elif isinstance( + tool_name_or_obj, FunctionTool + ) and cls._tool_available_for_runtime_config(tool_name_or_obj, runtime): toolset.add_tool(tool_name_or_obj) return None if toolset.empty() else toolset @@ -538,11 +575,16 @@ async def _wake_main_agent_for_background_result( message_type=session.message_type, ) cron_event.role = event.role + session_config = ctx.get_config(umo=event.unified_msg_origin) + provider_settings = session_config.get("provider_settings", {}) config = MainAgentBuildConfig( tool_call_timeout=run_context.tool_call_timeout, - streaming_response=ctx.get_config() - .get("provider_settings", {}) - .get("stream", False), + streaming_response=provider_settings.get("stream", False), + computer_use_runtime=str( + provider_settings.get("computer_use_runtime", "local") + ), + sandbox_cfg=provider_settings.get("sandbox", {}), + provider_settings=provider_settings, ) req = ProviderRequest() diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index cee6e9e27d..0c5ea2770f 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -15,7 +15,7 @@ from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.message import TextPart -from astrbot.core.agent.tool import ToolSet +from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS from astrbot.core.astr_agent_run_util import AgentRunner @@ -24,10 +24,13 @@ CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT, LIVE_MODE_SYSTEM_PROMPT, LLM_SAFETY_MODE_SYSTEM_PROMPT, + SANDBOX_GUI_PROMPT, SANDBOX_MODE_PROMPT, TOOL_CALL_PROMPT, TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, ) +from astrbot.core.computer import computer_client +from astrbot.core.computer.sandbox_tool_binding import tool_available_in_runtime from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import File, Image, Record, Reply, Video from astrbot.core.persona_error_reply import ( @@ -47,32 +50,18 @@ from astrbot.core.star.star import star_registry from astrbot.core.star.star_handler import star_map from astrbot.core.tools.computer_tools import ( - AnnotateExecutionTool, - BrowserBatchExecTool, - BrowserExecTool, - CreateSkillCandidateTool, - CreateSkillPayloadTool, - CuaKeyboardTypeTool, - CuaMouseClickTool, - CuaScreenshotTool, - EvaluateSkillCandidateTool, ExecuteShellTool, FileDownloadTool, FileEditTool, FileReadTool, FileUploadTool, FileWriteTool, - GetExecutionHistoryTool, - GetSkillPayloadTool, GrepTool, - ListSkillCandidatesTool, - ListSkillReleasesTool, LocalPythonTool, - PromoteSkillCandidateTool, PythonTool, - RollbackSkillReleaseTool, - RunBrowserSkillTool, - SyncSkillReleaseTool, + SandboxLifecycleTool, + SandboxOperationTool, + SandboxQueryTool, normalize_umo_for_workspace, ) from astrbot.core.tools.cron_tools import FutureTaskTool @@ -81,6 +70,7 @@ retrieve_knowledge_base, ) from astrbot.core.tools.message_tools import SendMessageToUserTool +from astrbot.core.tools.registry import get_builtin_tool_config_rule from astrbot.core.tools.web_search_tools import ( BaiduWebSearchTool, BochaWebSearchTool, @@ -451,6 +441,34 @@ def _filter_skills_for_current_config( return filtered +def _tool_available_for_current_runtime(tool: FunctionTool, cfg: dict) -> bool: + runtime = str(cfg.get("computer_use_runtime", "local")) + if not tool_available_in_runtime(tool, runtime): + return False + rule = get_builtin_tool_config_rule(tool.name) + if rule is None: + return True + conditions = rule.evaluate({"provider_settings": cfg}) + runtime_conditions = [ + condition + for condition in conditions + if str(condition.get("key")) == "provider_settings.computer_use_runtime" + ] + if not runtime_conditions: + return True + return all(bool(condition.get("matched")) for condition in runtime_conditions) + + +def _filter_tools_for_current_config( + toolset: ToolSet, cfg: dict, session_id: str +) -> ToolSet: + filtered = ToolSet() + for tool in toolset: + if _tool_available_for_current_runtime(tool, cfg): + filtered.add_tool(tool) + return filtered + + async def _ensure_persona_and_skills( req: ProviderRequest, cfg: dict, @@ -479,6 +497,7 @@ async def _ensure_persona_and_skills( if req.system_prompt is None: req.system_prompt = "" + session_id = event.unified_msg_origin if persona: # Inject persona system prompt @@ -492,7 +511,12 @@ async def _ensure_persona_and_skills( # Inject skills prompt runtime = cfg.get("computer_use_runtime", "local") skill_manager = SkillManager() - skills = skill_manager.list_skills(active_only=True, runtime=runtime) + current_provider = computer_client.get_current_sandbox_provider_id(session_id) + skills = skill_manager.list_skills( + active_only=True, + runtime=runtime, + provider_id=current_provider, + ) skills = _filter_skills_for_current_config(skills, cfg) if skills: @@ -511,10 +535,20 @@ async def _ensure_persona_and_skills( "If you need to use these capabilities, ask the user to enable Computer Use in the AstrBot WebUI -> Config." ) tmgr = plugin_context.get_llm_tool_manager() + persona_tools_configured = bool(persona and persona.get("tools") is not None) + req._persona_tools_configured = persona_tools_configured + req._persona_allowed_tool_names = ( + {str(tool_name) for tool_name in persona.get("tools", [])} + if persona_tools_configured + else None + ) # inject toolset in the persona if (persona and persona.get("tools") is None) or not persona: persona_toolset = tmgr.get_full_tool_set() + persona_toolset = _filter_tools_for_current_config( + persona_toolset, cfg, session_id + ) for tool in list(persona_toolset): if not tool.active: persona_toolset.remove_tool(tool.name) @@ -523,7 +557,11 @@ async def _ensure_persona_and_skills( if persona["tools"]: for tool_name in persona["tools"]: tool = tmgr.get_func(tool_name) - if tool and tool.active: + if ( + tool + and tool.active + and _tool_available_for_current_runtime(tool, cfg) + ): persona_toolset.add_tool(tool) if not req.func_tool: req.func_tool = persona_toolset @@ -545,13 +583,15 @@ async def _ensure_persona_and_skills( if a.get("enabled", True) is False: continue persona_tools = None + persona_tools_configured = False pid = a.get("persona_id") if pid: persona = plugin_context.persona_manager.get_persona_v3_by_id(pid) if persona is not None: persona_tools = persona.get("tools") + persona_tools_configured = "tools" in persona tools = a.get("tools", []) - if persona_tools is not None: + if persona_tools_configured: tools = persona_tools if tools is None: assigned_tools.update( @@ -559,6 +599,7 @@ async def _ensure_persona_and_skills( tool.name for tool in tmgr.func_list if not isinstance(tool, HandoffTool) + and _tool_available_for_current_runtime(tool, cfg) ] ) continue @@ -951,7 +992,9 @@ async def _decorate_llm_request( _apply_workspace_extra_prompt(event, req) -def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: +def _plugin_tool_fix( + event: AstrMessageEvent, req: ProviderRequest, cfg: dict | None = None +) -> None: """根据事件中的插件设置,过滤请求中的工具列表。 注意:没有 handler_module_path 的工具(如 MCP 工具)会被保留, @@ -979,6 +1022,10 @@ def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: new_tool_set.add_tool(tool) req.func_tool = new_tool_set + if req.func_tool is not None and cfg is not None: + session_id = req.session_id or event.unified_msg_origin + req.func_tool = _filter_tools_for_current_config(req.func_tool, cfg, session_id) + async def _handle_webchat( event: AstrMessageEvent, req: ProviderRequest, prov: Provider @@ -1037,100 +1084,40 @@ def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) - def _apply_sandbox_tools( config: MainAgentBuildConfig, req: ProviderRequest, - session_id: str, ) -> None: if req.func_tool is None: req.func_tool = ToolSet() if req.system_prompt is None: req.system_prompt = "" - booter = config.sandbox_cfg.get("booter", "shipyard_neo") - if booter == "shipyard": - ep = config.sandbox_cfg.get("shipyard_endpoint", "") - at = config.sandbox_cfg.get("shipyard_access_token", "") - if not ep or not at: - logger.error("Shipyard sandbox configuration is incomplete.") - return - os.environ["SHIPYARD_ENDPOINT"] = ep - os.environ["SHIPYARD_ACCESS_TOKEN"] = at + allowed_tool_names = getattr(req, "_persona_allowed_tool_names", None) + persona_tools_configured = bool(getattr(req, "_persona_tools_configured", False)) tool_mgr = llm_tools - req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(PythonTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileUploadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileDownloadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool)) - if booter == "shipyard_neo": - # Neo-specific path rule: filesystem tools operate relative to sandbox - # workspace root. Do not prepend "/workspace". - req.system_prompt += ( - "\n[Shipyard Neo File Path Rule]\n" - "When using sandbox filesystem tools (upload/download/read/write/list/delete), " - "always pass paths relative to the sandbox workspace root. " - "Example: use `baidu_homepage.png` instead of `/workspace/baidu_homepage.png`.\n" - ) - - req.system_prompt += ( - "\n[Neo Skill Lifecycle Workflow]\n" - "When user asks to create/update a reusable skill in Neo mode, use lifecycle tools instead of directly writing local skill folders.\n" - "Preferred sequence:\n" - "1) Use `astrbot_create_skill_payload` to store canonical payload content and get `payload_ref`.\n" - "2) Use `astrbot_create_skill_candidate` with `skill_key` + `source_execution_ids` (and optional `payload_ref`) to create a candidate.\n" - "3) Use `astrbot_promote_skill_candidate` to release: `stage=canary` for trial; `stage=stable` for production.\n" - "For stable release, set `sync_to_local=true` to sync `payload.skill_markdown` into local `SKILL.md`.\n" - "Do not treat ad-hoc generated files as reusable Neo skills unless they are captured via payload/candidate/release.\n" - "To update an existing skill, create a new payload/candidate and promote a new release version; avoid patching old local folders directly.\n" - ) + added_tool = False - # Determine sandbox capabilities from an already-booted session. - # If no session exists yet (first request), capabilities is None - # and we register all tools conservatively. - from astrbot.core.computer.computer_client import session_booter - - sandbox_capabilities: list[str] | None = None - existing_booter = session_booter.get(session_id) - if existing_booter is not None: - sandbox_capabilities = getattr(existing_booter, "capabilities", None) - - # Browser tools: only register if profile supports browser - # (or if capabilities are unknown because sandbox hasn't booted yet) - if sandbox_capabilities is None or "browser" in sandbox_capabilities: - req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserExecTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserBatchExecTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(RunBrowserSkillTool)) - - # Neo-specific tools (always available for shipyard_neo) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetExecutionHistoryTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(AnnotateExecutionTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillPayloadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetSkillPayloadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillCandidateTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillCandidatesTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(EvaluateSkillCandidateTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(PromoteSkillCandidateTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillReleasesTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(RollbackSkillReleaseTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(SyncSkillReleaseTool)) - - if booter == "cua": - req.system_prompt += ( - "\n[CUA Desktop Control]\n" - "Use `astrbot_execute_shell` with `background=true` to launch GUI apps. " - 'Use Firefox for browser tasks, for example `firefox "https://example.com"`. ' - "After each visible step, call `astrbot_cua_screenshot` with " - "`send_to_user=true` and `return_image_to_llm=true` so the user can " - "monitor progress. When typing, inspect the screenshot first and confirm " - "the target field is focused and empty or safe to append to. Use " - "`astrbot_cua_mouse_click` for coordinates and `astrbot_cua_keyboard_type` " - "for text input; use text=`\\n` for Enter.\n" + def add_sandbox_tool(tool_cls) -> None: + nonlocal added_tool + tool = tool_mgr.get_builtin_tool(tool_cls) + if persona_tools_configured and tool.name not in allowed_tool_names: + return + req.func_tool.add_tool(tool) + added_tool = True + + add_sandbox_tool(ExecuteShellTool) + add_sandbox_tool(SandboxQueryTool) + add_sandbox_tool(SandboxLifecycleTool) + add_sandbox_tool(SandboxOperationTool) + add_sandbox_tool(PythonTool) + add_sandbox_tool(FileUploadTool) + add_sandbox_tool(FileDownloadTool) + add_sandbox_tool(FileReadTool) + add_sandbox_tool(FileWriteTool) + add_sandbox_tool(FileEditTool) + add_sandbox_tool(GrepTool) + if added_tool: + req.system_prompt = ( + f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}{SANDBOX_GUI_PROMPT}\n" ) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CuaScreenshotTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CuaMouseClickTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CuaKeyboardTypeTool)) - - req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n" def _proactive_cron_job_tools(req: ProviderRequest, plugin_context: Context) -> None: @@ -1484,14 +1471,14 @@ async def build_main_agent( if not req.session_id: req.session_id = event.unified_msg_origin - _plugin_tool_fix(event, req) + _plugin_tool_fix(event, req, config.provider_settings) await _apply_web_search_tools(event, req, plugin_context) if config.llm_safety_mode: _apply_llm_safety_mode(config, req) if config.computer_use_runtime == "sandbox": - _apply_sandbox_tools(config, req, req.session_id) + _apply_sandbox_tools(config, req) elif config.computer_use_runtime == "local": _apply_local_env_tools(req, plugin_context) diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 4efa0e5a6d..9a235088cc 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -13,6 +13,17 @@ SANDBOX_MODE_PROMPT = ( "You have access to a sandboxed environment and can execute shell commands and Python code securely." + " You can manage sandbox lifecycle, including listing sandbox providers, listing sandboxes, checking the current sandbox, creating a new sandbox, switching sandboxes, releasing sandbox occupancy, taking over a sandbox, destroying a sandbox, and copying files between sandboxes." + " Before creating a new sandbox, always check the current sandbox first." + " If there is no current sandbox, list sandboxes and inspect each sandbox's access field for this session." + " Prefer reusing access.status=current first, then access.status=idle. Never treat status=running alone as reusable." + " If access.status=occupied or access.can_switch=false, another active session controls that sandbox; do not switch to it unless the user explicitly asks to take it over." + " If you need a different provider, use astrbot_sandbox_query with action=list_providers first and pass provider_id explicitly to astrbot_sandbox_lifecycle with action=create." + " You can create a new sandbox only when the user explicitly asks for a fresh or separate environment, or when no existing sandbox can be reused safely." + " Each successful sandbox operation that accesses a sandbox automatically renews this session's lease to now plus the configured sandbox lease timeout." + " Sandbox-bound tool results include lease metadata such as lease_expires_at and lease_expires_in_seconds." + " When this session's lease expires, this session no longer has a current sandbox; use list_sandboxes and then switch, takeover, or create before continuing sandbox work." + " For long-running work, monitor lease metadata; if the remaining time is low before a long idle period or external wait, call astrbot_sandbox_lifecycle with action=renew_lease." # "Your have extended skills library, such as PDF processing, image generation, data analysis, etc. " # "Before handling complex tasks, please retrieve and review the documentation in the in /app/skills/ directory. " # "If the current task matches the description of a specific skill, prioritize following the workflow defined by that skill." @@ -22,6 +33,13 @@ # "Use shell commands such as grep, sed, awk to extract relevant information from the documentation as needed.\n" ) +SANDBOX_GUI_PROMPT = ( + " When working with GUI-capable sandboxes, use astrbot_sandbox_operation with action=capture_screenshot to show progress whenever it is helpful, especially after each meaningful GUI step." + " If the screenshot should be visible to the user, set send_to_user=true in that same capture_screenshot call instead of taking a screenshot first and then calling send_message_to_user separately." + " Set return_image_to_llm=true only when you need to inspect the screenshot yourself before deciding the next step." + " If the task is completed successfully, also send a final result screenshot with send_to_user=true to show the outcome clearly." +) + TOOL_CALL_PROMPT = ( "When using tools: " "never return an empty response; " diff --git a/astrbot/core/computer/booters/base.py b/astrbot/core/computer/booters/base.py index ec1af5cdc8..96c967c463 100644 --- a/astrbot/core/computer/booters/base.py +++ b/astrbot/core/computer/booters/base.py @@ -37,12 +37,10 @@ def gui(self) -> GUIComponent | None: async def boot(self, session_id: str) -> None: ... async def shutdown(self, **kwargs) -> None: - """Shut down the computer sandbox. + """Close the current runtime connection without deleting sandbox resources. - Subclasses may accept extra keyword arguments for - type-specific cleanup (e.g. ``delete_sandbox`` for - ShipyardNeoBooter). The default implementation ignores - them. + Subclasses may accept extra keyword arguments for type-specific cleanup. + The default implementation ignores them. """ ... diff --git a/astrbot/core/computer/booters/bay_manager.py b/astrbot/core/computer/booters/bay_manager.py deleted file mode 100644 index 61ccc1b3a5..0000000000 --- a/astrbot/core/computer/booters/bay_manager.py +++ /dev/null @@ -1,259 +0,0 @@ -"""Manage Bay container lifecycle for zero-config Shipyard Neo integration. - -When no Bay endpoint is configured, AstrBot can automatically start a Bay -container using the Docker socket (like BoxliteBooter does for Ship -containers). -""" - -from __future__ import annotations - -import asyncio -import io -import json -import tarfile -from typing import Any - -import aiodocker -import aiohttp - -from astrbot.api import logger - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -BAY_IMAGE = "ghcr.io/astrbotdevs/shipyard-neo-bay:latest" -BAY_CONTAINER_NAME = "astrbot-bay" -BAY_LABEL = "astrbot.bay.managed" -BAY_PORT = 8114 -HEALTH_TIMEOUT_S = 60 -HEALTH_POLL_INTERVAL_S = 2 - - -class BayContainerManager: - """Start / reuse / stop a Bay container via Docker Engine API.""" - - def __init__( - self, - image: str = BAY_IMAGE, - host_port: int = BAY_PORT, - ) -> None: - self._image = image - self._host_port = host_port - self._docker: aiodocker.Docker | None = None - self._container: Any = None - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - async def ensure_running(self) -> str: - """Make sure a Bay container is running. Returns the endpoint URL. - - If a container labelled ``astrbot.bay.managed`` already exists - and is running, it will be reused. Otherwise a new container is - created from *self._image*. - """ - try: - self._docker = aiodocker.Docker() - except Exception as exc: - raise RuntimeError( - "Failed to connect to Docker daemon. " - "Ensure Docker is installed and running, or configure " - "an explicit Bay endpoint instead of auto-start mode." - ) from exc - - # 1. Look for an existing managed container - existing = await self._find_managed_container() - if existing is not None: - state = existing["State"] - if state.get("Running"): - cid = existing["Id"][:12] - logger.info("[BayManager] Reusing existing Bay container: %s", cid) - self._container = await self._docker.containers.get(existing["Id"]) - return f"http://127.0.0.1:{self._host_port}" - else: - # Container exists but stopped — restart it - logger.info("[BayManager] Restarting stopped Bay container") - container = await self._docker.containers.get(existing["Id"]) - await container.start() - self._container = container - return f"http://127.0.0.1:{self._host_port}" - - # 2. Pull image if needed - await self._pull_image_if_needed() - - # 3. Create and start container - logger.info( - "[BayManager] Starting Bay container: image=%s, port=%d", - self._image, - self._host_port, - ) - config = { - "Image": self._image, - "Labels": {BAY_LABEL: "true"}, - "Env": [ - "BAY_SERVER__HOST=0.0.0.0", - f"BAY_SERVER__PORT={BAY_PORT}", - "BAY_DATA_DIR=/app/data", - # allow_anonymous=false → auto-provisions API key - "BAY_SECURITY__ALLOW_ANONYMOUS=false", - ], - "HostConfig": { - "PortBindings": { - f"{BAY_PORT}/tcp": [{"HostPort": str(self._host_port)}], - }, - "Binds": [ - # Bay needs Docker socket to create sandbox containers - "/var/run/docker.sock:/var/run/docker.sock", - ], - "RestartPolicy": {"Name": "unless-stopped"}, - }, - } - self._container = await self._docker.containers.create_or_replace( - BAY_CONTAINER_NAME, config - ) - await self._container.start() - logger.info("[BayManager] Bay container started: %s", BAY_CONTAINER_NAME) - - return f"http://127.0.0.1:{self._host_port}" - - async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: - """Block until Bay's ``/health`` endpoint returns 200.""" - url = f"http://127.0.0.1:{self._host_port}/health" - loop = asyncio.get_running_loop() - deadline = loop.time() + timeout - last_error: str = "" - - async with aiohttp.ClientSession() as session: - while loop.time() < deadline: - try: - async with session.get( - url, timeout=aiohttp.ClientTimeout(total=3) - ) as resp: - if resp.status == 200: - logger.info("[BayManager] Bay is healthy") - return - last_error = f"HTTP {resp.status}" - except Exception as exc: - last_error = str(exc) - - await asyncio.sleep(HEALTH_POLL_INTERVAL_S) - - raise TimeoutError( - f"Bay did not become healthy within {timeout}s (last error: {last_error})" - ) - - async def read_credentials(self) -> str: - """Read auto-provisioned API key from Bay container. - - Bay writes ``credentials.json`` to its data directory when - ``allow_anonymous=false`` and no explicit API key is set. - """ - if self._container is None: - return "" - - try: - # Read credentials.json from container filesystem - tar_stream = await self._container.get_archive("/app/data/credentials.json") - # get_archive returns (tar_data, stat) - tar_data = tar_stream - - if isinstance(tar_data, dict): - raw = tar_data.get("data", b"") - elif isinstance(tar_data, tuple): - # (stream, stat_info) - raw = b"" - stream = tar_data[0] - if hasattr(stream, "read"): - raw = await stream.read() - elif isinstance(stream, bytes): - raw = stream - else: - # It might be a chunked response - chunks = [] - async for chunk in stream: - chunks.append(chunk) - raw = b"".join(chunks) - else: - raw = tar_data if isinstance(tar_data, bytes) else b"" - - if not raw: - logger.debug("[BayManager] Empty tar response from container") - return "" - - tario = io.BytesIO(raw) - with tarfile.open(fileobj=tario) as tar: - for member in tar.getmembers(): - f = tar.extractfile(member) - if f: - creds = json.loads(f.read().decode("utf-8")) - api_key = creds.get("api_key", "") - if api_key: - masked = ( - f"{api_key[:8]}..." - if len(api_key) >= 10 - else "redacted" - ) - logger.info( - "[BayManager] Auto-discovered Bay API key: %s", - masked, - ) - return api_key - except Exception as exc: - logger.debug( - "[BayManager] Failed to read credentials from container: %s", exc - ) - - return "" - - async def close_client(self) -> None: - """Close the Docker client without stopping the container. - - The Bay container stays running for reuse by future sessions. - """ - if self._docker is not None: - await self._docker.close() - self._docker = None - - async def stop(self) -> None: - """Stop and remove the managed Bay container.""" - if self._container is not None: - try: - await self._container.stop() - await self._container.delete(force=True) - logger.info("[BayManager] Bay container stopped and removed") - except Exception as exc: - logger.debug("[BayManager] Error stopping Bay container: %s", exc) - finally: - self._container = None - - await self.close_client() - - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - async def _find_managed_container(self) -> dict | None: - """Find an existing container with our management label.""" - assert self._docker is not None - containers = await self._docker.containers.list( - all=True, - filters=json.dumps({"label": [f"{BAY_LABEL}=true"]}), - ) - if containers: - # Inspect first match to get full state - return await containers[0].show() - return None - - async def _pull_image_if_needed(self) -> None: - """Pull the Bay image if it doesn't exist locally.""" - assert self._docker is not None - try: - await self._docker.images.inspect(self._image) - logger.debug("[BayManager] Image %s already exists", self._image) - except aiodocker.exceptions.DockerError: - logger.info("[BayManager] Pulling image %s ...", self._image) - # Pull with progress logging - await self._docker.images.pull(self._image) - logger.info("[BayManager] Image %s pulled successfully", self._image) diff --git a/astrbot/core/computer/booters/boxlite.py b/astrbot/core/computer/booters/boxlite.py deleted file mode 100644 index aa3ca59761..0000000000 --- a/astrbot/core/computer/booters/boxlite.py +++ /dev/null @@ -1,194 +0,0 @@ -import asyncio -import random -from typing import Any - -import aiohttp -import boxlite -from shipyard import FileSystemComponent as ShipyardFileSystemComponent -from shipyard.python import PythonComponent as ShipyardPythonComponent -from shipyard.shell import ShellComponent as ShipyardShellComponent - -from astrbot.api import logger - -from ..olayer import FileSystemComponent, PythonComponent, ShellComponent -from .base import ComputerBooter -from .shipyard import ShipyardFileSystemWrapper - - -class MockShipyardSandboxClient: - def __init__(self, sb_url: str) -> None: - self.sb_url = sb_url.rstrip("/") - - async def _exec_operation( - self, - ship_id: str, - operation_type: str, - payload: dict[str, Any], - session_id: str, - ) -> dict[str, Any]: - async with aiohttp.ClientSession() as session: - headers = {"X-SESSION-ID": session_id} - async with session.post( - f"{self.sb_url}/{operation_type}", - json=payload, - headers=headers, - ) as response: - if response.status == 200: - return await response.json() - else: - error_text = await response.text() - raise Exception( - f"Failed to exec operation: {response.status} {error_text}" - ) - - async def upload_file(self, path: str, remote_path: str) -> dict: - """Upload a file to the sandbox""" - url = f"http://{self.sb_url}/upload" - - try: - # Read file content - with open(path, "rb") as f: - file_content = f.read() - - # Create multipart form data - data = aiohttp.FormData() - data.add_field( - "file", - file_content, - filename=remote_path.split("/")[-1], - content_type="application/octet-stream", - ) - data.add_field("file_path", remote_path) - - timeout = aiohttp.ClientTimeout(total=120) # 2 minutes for file upload - - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(url, data=data) as response: - if response.status == 200: - logger.info( - "[Computer] File uploaded to Boxlite sandbox: %s", - remote_path, - ) - return { - "success": True, - "message": "File uploaded successfully", - "file_path": remote_path, - } - else: - error_text = await response.text() - return { - "success": False, - "error": f"Server returned {response.status}: {error_text}", - "message": "File upload failed", - } - - except aiohttp.ClientError as e: - logger.error(f"Failed to upload file: {e}") - return { - "success": False, - "error": f"Connection error: {str(e)}", - "message": "File upload failed", - } - except asyncio.TimeoutError: - return { - "success": False, - "error": "File upload timeout", - "message": "File upload failed", - } - except FileNotFoundError: - logger.error(f"File not found: {path}") - return { - "success": False, - "error": f"File not found: {path}", - "message": "File upload failed", - } - except Exception as e: - logger.error(f"Unexpected error uploading file: {e}") - return { - "success": False, - "error": f"Internal error: {str(e)}", - "message": "File upload failed", - } - - async def wait_healthy(self, ship_id: str, session_id: str) -> None: - """Mock wait healthy""" - loop = 60 - while loop > 0: - try: - logger.info( - f"Checking health for sandbox {ship_id} on {self.sb_url}..." - ) - url = f"{self.sb_url}/health" - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - logger.info(f"Sandbox {ship_id} is healthy") - return - except Exception: - await asyncio.sleep(1) - loop -= 1 - - -class BoxliteBooter(ComputerBooter): - async def boot(self, session_id: str) -> None: - logger.info( - f"Booting(Boxlite) for session: {session_id}, this may take a while..." - ) - random_port = random.randint(20000, 30000) - self.box = boxlite.SimpleBox( - image="soulter/shipyard-ship", - memory_mib=512, - cpus=1, - ports=[ - { - "host_port": random_port, - "guest_port": 8123, - } - ], - ) - await self.box.start() - logger.info(f"Boxlite booter started for session: {session_id}") - self.mocked = MockShipyardSandboxClient( - sb_url=f"http://127.0.0.1:{random_port}" - ) - self._python = ShipyardPythonComponent( - client=self.mocked, # type: ignore - ship_id=self.box.id, - session_id=session_id, - ) - self._shell = ShipyardShellComponent( - client=self.mocked, # type: ignore - ship_id=self.box.id, - session_id=session_id, - ) - self._ship_fs = ShipyardFileSystemComponent( - client=self.mocked, # type: ignore - ship_id=self.box.id, - session_id=session_id, - ) - self._fs = ShipyardFileSystemWrapper( - _shipyard_fs=self._ship_fs, _shipyard_shell=self._shell - ) - - await self.mocked.wait_healthy(self.box.id, session_id) - - async def shutdown(self) -> None: - logger.info(f"Shutting down Boxlite booter for ship: {self.box.id}") - self.box.shutdown() - logger.info(f"Boxlite booter for ship: {self.box.id} stopped") - - @property - def fs(self) -> FileSystemComponent: - return self._fs - - @property - def python(self) -> PythonComponent: - return self._python - - @property - def shell(self) -> ShellComponent: - return self._shell - - async def upload_file(self, path: str, file_name: str) -> dict: - """Upload file to sandbox""" - return await self.mocked.upload_file(path, file_name) diff --git a/astrbot/core/computer/booters/cua.py b/astrbot/core/computer/booters/cua.py deleted file mode 100644 index 151b4c0e04..0000000000 --- a/astrbot/core/computer/booters/cua.py +++ /dev/null @@ -1,878 +0,0 @@ -from __future__ import annotations - -import base64 -import inspect -import shlex -from dataclasses import asdict, dataclass, is_dataclass -from pathlib import Path -from typing import Any - -from astrbot.api import logger - -from ..olayer import FileSystemComponent, GUIComponent, PythonComponent, ShellComponent -from .base import ComputerBooter -from .cua_defaults import CUA_CONFIG_KEYS, CUA_DEFAULT_CONFIG -from .shipyard_search_file_util import search_files_via_shell - -_POSIX_OS_TYPES = {"linux", "darwin", "macos"} - -_CUA_BACKGROUND_LAUNCHER = """ -import subprocess, sys, time - -p = subprocess.Popen( - ["sh", "-lc", sys.argv[1]], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - start_new_session=True, -) -sys.stdout.write(str(p.pid) + "\\n") -sys.stdout.flush() -time.sleep(0.2) -code = p.poll() -sys.exit(0 if code is None else code) -""".strip() - - -async def _maybe_await(value: Any) -> Any: - if inspect.isawaitable(value): - return await value - return value - - -def build_cua_booter_kwargs(sandbox_cfg: dict[str, Any]) -> dict[str, Any]: - return { - name: sandbox_cfg.get(config_key, CUA_DEFAULT_CONFIG[name]) - for name, config_key in CUA_CONFIG_KEYS.items() - } - - -async def _write_base64_via_shell( - shell: ShellComponent, - path: str, - data: bytes, -) -> dict[str, Any]: - encoded = base64.b64encode(data).decode("ascii") - decoder = ( - "import base64,pathlib,sys; " - "pathlib.Path(sys.argv[1]).write_bytes(base64.b64decode(sys.stdin.read()))" - ) - return await shell.exec( - f"python3 -c {shlex.quote(decoder)} {shlex.quote(path)} <<'EOF'\n{encoded}\nEOF" - ) - - -@dataclass(slots=True) -class ProcessResult: - stdout: str - stderr: str - exit_code: int | None - success: bool - - -def _maybe_model_dump(value: Any) -> dict[str, Any]: - if isinstance(value, dict): - return value - if is_dataclass(value) and not isinstance(value, type): - return asdict(value) - if hasattr(value, "model_dump"): - dumped = value.model_dump() - if isinstance(dumped, dict): - return dumped - if hasattr(value, "dict"): - dumped = value.dict() - if isinstance(dumped, dict): - return dumped - attr_payload = { - key: getattr(value, key) - for key in ( - "stdout", - "stderr", - "output", - "error", - "returncode", - "return_code", - "exit_code", - "success", - ) - if hasattr(value, key) - } - if attr_payload: - return attr_payload - return {} - - -def _slice_content_by_lines( - content: str, - *, - offset: int | None = None, - limit: int | None = None, -) -> str: - lines = content.splitlines(keepends=True) - start = 0 if offset is None else offset - selected = lines[start:] if limit is None else lines[start : start + limit] - return "".join(selected) - - -def _normalize_process_result(raw: Any) -> ProcessResult: - """Best-effort normalization for the process shapes returned by CUA SDKs.""" - payload = _maybe_model_dump(raw) - if not payload and isinstance(raw, str): - payload = {"stdout": raw} - - def first_text(*keys: str) -> str: - for key in keys: - value = payload.get(key) - if value is not None: - return str(value) - return "" - - stdout = first_text("stdout", "output") - stderr = first_text("stderr", "error") - exit_code = payload.get("exit_code") - if exit_code is None: - exit_code = payload.get("returncode") - if exit_code is None: - exit_code = payload.get("return_code") - if exit_code is None: - exit_code = 0 if not stderr else 1 - success = bool(payload.get("success", not stderr and exit_code in (0, None))) - return ProcessResult( - stdout=stdout, - stderr=stderr, - exit_code=exit_code, - success=success, - ) - - -def _is_missing_python3_error(stderr: str) -> bool: - lowered = stderr.lower() - return "python3" in lowered and ( - "not found" in lowered - or "command not found" in lowered - or "no such file" in lowered - ) - - -def _python3_requirement_error(operation: str, stderr: str) -> str: - return f"CUA {operation} requires python3 in the sandbox image: {stderr}" - - -def _normalize_with_python3_requirement(raw: Any, operation: str) -> ProcessResult: - proc = _normalize_process_result(raw) - if proc.stderr and _is_missing_python3_error(proc.stderr): - return ProcessResult( - stdout=proc.stdout, - stderr=_python3_requirement_error(operation, proc.stderr), - exit_code=proc.exit_code, - success=proc.success, - ) - return proc - - -async def _exec_python3_or_error( - shell: ShellComponent, - code: str, - *, - operation: str, - timeout: int | None = 30, -) -> ProcessResult: - result = await shell.exec(f"python3 - <<'PY'\n{code}\nPY", timeout=timeout) - return _normalize_with_python3_requirement(result, operation) - - -def _is_posix_os_type(os_type: str) -> bool: - return os_type.lower() in _POSIX_OS_TYPES - - -def _posix_fs_error_message(os_type: str) -> str: - return ( - "CUA filesystem shell fallback is only supported for POSIX images; " - f"os_type={os_type!r} does not support the required shell commands." - ) - - -def _non_posix_filesystem_result(path: str, os_type: str) -> dict[str, Any]: - error = _posix_fs_error_message(os_type) - return {"success": False, "path": path, "error": error, "message": error} - - -def _raise_non_posix_filesystem_error(os_type: str) -> None: - raise RuntimeError(_posix_fs_error_message(os_type)) - - -def _resolve_component_method( - component: Any, - method_names: str | tuple[str, ...], -) -> Any | None: - if component is None: - return None - names = (method_names,) if isinstance(method_names, str) else method_names - for method_name in names: - method = getattr(component, method_name, None) - if method is not None: - return method - return None - - -def _missing_component_method_error( - component_name: str, - method_names: str | tuple[str, ...], -) -> RuntimeError: - names = (method_names,) if isinstance(method_names, str) else method_names - candidates = ", ".join(f"{component_name}.{name}" for name in names) - return RuntimeError( - f"CUA sandbox does not provide any of: {candidates}. " - "Please check the installed CUA SDK version and sandbox backend." - ) - - -def _has_component_method(root: Any, component_name: str, method_name: str) -> bool: - component = getattr(root, component_name, None) - return getattr(component, method_name, None) is not None - - -def _resolve_files_components(sandbox: Any) -> tuple[Any, ...]: - components: list[Any] = [] - seen_ids: set[int] = set() - for name in ("files", "filesystem"): - component = getattr(sandbox, name, None) - if component is None: - continue - component_id = id(component) - if component_id in seen_ids: - continue - seen_ids.add(component_id) - components.append(component) - return tuple(components) - - -def _resolve_files_method( - components: tuple[Any, ...], - method_names: str | tuple[str, ...], -) -> Any | None: - for component in components: - method = _resolve_component_method(component, method_names) - if method is not None: - return method - return None - - -def _normalize_native_upload_result(raw: Any, file_name: str) -> dict[str, Any]: - payload = _maybe_model_dump(raw) - if not payload: - return {"success": True, "file_path": file_name} - if "file_path" not in payload and "path" not in payload: - payload["file_path"] = file_name - if "success" not in payload: - payload["success"] = not bool(payload.get("error") or payload.get("stderr")) - return payload - - -class CuaShellComponent(ShellComponent): - def __init__(self, sandbox: Any, os_type: str = "linux") -> None: - self._sandbox = sandbox - self._os_type = os_type.lower() - shell = sandbox.shell - self._exec_raw = getattr(shell, "exec", None) or getattr(shell, "run", None) - if self._exec_raw is None: - raise RuntimeError("CUA sandbox shell must provide `.exec` or `.run`.") - - async def exec( - self, - command: str, - cwd: str | None = None, - env: dict[str, str] | None = None, - timeout: int | None = 30, - shell: bool = True, - background: bool = False, - ) -> dict[str, Any]: - if not shell: - return { - "stdout": "", - "stderr": "error: only shell mode is supported in CUA booter.", - "exit_code": 2, - "success": False, - } - - kwargs: dict[str, Any] = {} - if cwd is not None: - kwargs["cwd"] = cwd - if timeout is not None: - kwargs["timeout"] = timeout - if env: - kwargs["env"] = env - if background: - if not _is_posix_os_type(self._os_type): - return { - "stdout": "", - "stderr": "error: background shell execution is only supported for POSIX CUA images.", - "exit_code": 2, - "success": False, - } - command = _build_cua_background_command(command) - - result = await _maybe_await(self._exec_raw(command, **kwargs)) - proc = ( - _normalize_with_python3_requirement(result, "background execution") - if background - else _normalize_process_result(result) - ) - response = { - "stdout": proc.stdout, - "stderr": proc.stderr, - "exit_code": proc.exit_code, - "success": proc.success, - } - if background: - try: - response["pid"] = int(proc.stdout.strip().splitlines()[-1]) - except Exception: - response["pid"] = None - return response - - -def _build_cua_background_command(command: str) -> str: - return f"python3 -c {shlex.quote(_CUA_BACKGROUND_LAUNCHER)} {shlex.quote(command)}" - - -class CuaPythonComponent(PythonComponent): - def __init__(self, sandbox: Any, os_type: str = "linux") -> None: - self._sandbox = sandbox - self._os_type = os_type - python = getattr(sandbox, "python", None) - self._python_exec = None - if python is not None: - self._python_exec = getattr(python, "exec", None) or getattr( - python, "run", None - ) - - async def exec( - self, - code: str, - kernel_id: str | None = None, - timeout: int = 30, - silent: bool = False, - ) -> dict[str, Any]: - _ = kernel_id - if self._python_exec is not None: - result = await _maybe_await(self._python_exec(code, timeout=timeout)) - proc = _normalize_process_result(result) - else: - shell = CuaShellComponent(self._sandbox, os_type=self._os_type) - proc = await _exec_python3_or_error( - shell, - code, - operation="Python execution fallback", - timeout=timeout, - ) - - output_text = "" if silent else proc.stdout - error_text = proc.stderr - return { - "success": proc.success if not silent else not bool(error_text), - "data": { - "output": {"text": output_text, "images": []}, - "error": error_text, - }, - "output": output_text, - "error": error_text, - } - - -def _write_result(path: str, result: dict[str, Any]) -> dict[str, Any]: - stderr = result.get("stderr", "") - if stderr and _is_missing_python3_error(stderr): - result = { - **result, - "stderr": _python3_requirement_error("filesystem write fallback", stderr), - } - if result.get("stderr") or result.get("success") is False: - return {"success": False, "path": path, **result} - return {"success": True, "path": path, **result} - - -class CuaFileSystemComponent(FileSystemComponent): - def __init__( - self, sandbox: Any, os_type: str = CUA_DEFAULT_CONFIG["os_type"] - ) -> None: - self._shell = CuaShellComponent(sandbox, os_type=os_type) - self._fs_components = _resolve_files_components(sandbox) - self._os_type = os_type.lower() - self._fallback = _PosixShellFileSystem(self._shell, self._os_type) - - async def create_file( - self, - path: str, - content: str = "", - mode: int = 0o644, - ) -> dict[str, Any]: - write_result = await self.write_file(path, content) - if not write_result.get("success"): - return {**write_result, "mode": mode, "mode_applied": False} - return {"success": True, "path": path, "mode": mode, "mode_applied": False} - - async def read_file( - self, - path: str, - encoding: str = "utf-8", - offset: int | None = None, - limit: int | None = None, - ) -> dict[str, Any]: - read_file = _resolve_files_method( - self._fs_components, ("read_file", "read_text") - ) - if read_file is None: - return await self._fallback.read_file(path, encoding, offset, limit) - else: - content = await _maybe_await(read_file(path)) - if isinstance(content, bytes): - content = content.decode(encoding, errors="replace") - return { - "success": True, - "path": path, - "content": _slice_content_by_lines( - str(content), offset=offset, limit=limit - ), - } - - async def write_file( - self, - path: str, - content: str, - mode: str = "w", - encoding: str = "utf-8", - ) -> dict[str, Any]: - _ = mode - write_file = _resolve_files_method( - self._fs_components, ("write_file", "write_text") - ) - if write_file is None: - return await self._fallback.write_file(path, content, mode, encoding) - else: - await _maybe_await(write_file(path, content)) - return {"success": True, "path": path} - - async def delete_file(self, path: str) -> dict[str, Any]: - delete = _resolve_files_method( - self._fs_components, ("delete", "delete_file", "remove") - ) - if delete is None: - return await self._fallback.delete_file(path) - else: - await _maybe_await(delete(path)) - return {"success": True, "path": path} - - async def list_dir( - self, - path: str = ".", - show_hidden: bool = False, - ) -> dict[str, Any]: - list_dir = _resolve_files_method(self._fs_components, ("list_dir", "list")) - if list_dir is not None: - entries = await _maybe_await(list_dir(path)) - return {"success": True, "path": path, "entries": entries} - return await self._fallback.list_dir(path, show_hidden) - - async def search_files( - self, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - ) -> dict[str, Any]: - return await self._fallback.search_files( - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - - async def edit_file( - self, - path: str, - old_string: str, - new_string: str, - replace_all: bool = False, - encoding: str = "utf-8", - ) -> dict[str, Any]: - read_result = await self.read_file(path, encoding=encoding) - if not read_result.get("success"): - return read_result - content = read_result.get("content", "") - occurrences = content.count(old_string) - if occurrences == 0: - return { - "success": False, - "error": "old string not found in file", - "replacements": 0, - } - updated = content.replace(old_string, new_string, -1 if replace_all else 1) - write_result = await self.write_file(path, updated, encoding=encoding) - if not write_result.get("success"): - return write_result - return { - "success": True, - "path": path, - "replacements": occurrences if replace_all else 1, - } - - -class _PosixShellFileSystem(FileSystemComponent): - def __init__(self, shell: CuaShellComponent, os_type: str) -> None: - self._shell = shell - self._os_type = os_type.lower() - - def _ensure_posix(self, path: str) -> dict[str, Any] | None: - if _is_posix_os_type(self._os_type): - return None - return _non_posix_filesystem_result(path, self._os_type) - - async def read_file( - self, - path: str, - encoding: str = "utf-8", - offset: int | None = None, - limit: int | None = None, - ) -> dict[str, Any]: - _ = encoding - if error := self._ensure_posix(path): - return error - result = await self._shell.exec(f"cat {shlex.quote(path)}") - if result.get("stderr"): - return {"success": False, "path": path, "error": result["stderr"]} - return { - "success": True, - "path": path, - "content": _slice_content_by_lines( - str(result.get("stdout", "")), offset=offset, limit=limit - ), - } - - async def write_file( - self, - path: str, - content: str, - mode: str = "w", - encoding: str = "utf-8", - ) -> dict[str, Any]: - _ = mode - if error := self._ensure_posix(path): - return error - result = await _write_base64_via_shell( - self._shell, path, content.encode(encoding) - ) - return _write_result(path, result) - - async def delete_file(self, path: str) -> dict[str, Any]: - if error := self._ensure_posix(path): - return error - result = await self._shell.exec(f"rm -rf {shlex.quote(path)}") - if result.get("stderr"): - return {"success": False, "path": path, "error": result["stderr"]} - return {"success": True, "path": path} - - async def list_dir( - self, - path: str = ".", - show_hidden: bool = False, - ) -> dict[str, Any]: - if error := self._ensure_posix(path): - return error - return await _list_dir_via_shell(self._shell, path, show_hidden) - - async def search_files( - self, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - ) -> dict[str, Any]: - search_path = path or "." - if error := self._ensure_posix(search_path): - return error - return await search_files_via_shell( - self._shell, - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - - -async def _list_dir_via_shell( - shell: CuaShellComponent, - path: str, - show_hidden: bool, -) -> dict[str, Any]: - flags = "-1A" if show_hidden else "-1" - result = await shell.exec(f"ls {flags} {shlex.quote(path)}") - stdout = result.get("stdout", "") - return { - "success": not bool(result.get("stderr")), - "path": path, - "entries": [line for line in stdout.splitlines() if line.strip()], - "error": result.get("stderr", ""), - } - - -class CuaGUIComponent(GUIComponent): - def __init__(self, sandbox: Any) -> None: - self._sandbox = sandbox - mouse = getattr(sandbox, "mouse", None) - keyboard = getattr(sandbox, "keyboard", None) - self._click = _resolve_component_method(mouse, "click") - self._type_text = _resolve_component_method(keyboard, "type") - self._press_key = _resolve_component_method( - keyboard, ("press", "key_press", "press_key") - ) - - async def screenshot(self, path: str | None = None) -> dict[str, Any]: - raw = await self._sandbox.screenshot() - data = _screenshot_to_bytes(raw) - if path: - Path(path).parent.mkdir(parents=True, exist_ok=True) - Path(path).write_bytes(data) - return { - "success": True, - "path": path, - "mime_type": "image/png", - "base64": base64.b64encode(data).decode("ascii"), - } - - async def click(self, x: int, y: int, button: str = "left") -> dict[str, Any]: - if self._click is None: - raise _missing_component_method_error("mouse", "click") - result = await _maybe_await(self._click(x, y, button=button)) - payload = _maybe_model_dump(result) - return {"success": bool(payload.get("success", True)), **payload} - - async def type_text(self, text: str) -> dict[str, Any]: - if self._type_text is None: - raise _missing_component_method_error("keyboard", "type") - result = await _maybe_await(self._type_text(text)) - payload = _maybe_model_dump(result) - return {"success": bool(payload.get("success", True)), **payload} - - async def press_key(self, key: str) -> dict[str, Any]: - if self._press_key is None: - raise _missing_component_method_error( - "keyboard", ("press", "key_press", "press_key") - ) - result = await _maybe_await(self._press_key(key)) - payload = _maybe_model_dump(result) - return {"success": bool(payload.get("success", True)), **payload} - - -def _screenshot_to_bytes(raw: Any) -> bytes: - def from_str(value: str) -> bytes: - if value.startswith("data:image"): - value = value.split(",", 1)[1] - try: - return base64.b64decode(value, validate=True) - except Exception: - candidate = Path(value) - if candidate.is_file(): - return candidate.read_bytes() - return value.encode("utf-8") - - if isinstance(raw, (bytes, bytearray)): - return bytes(raw) - if isinstance(raw, str): - return from_str(raw) - if hasattr(raw, "save"): - import io - - output = io.BytesIO() - raw.save(output, format="PNG") - return output.getvalue() - payload = _maybe_model_dump(raw) - for key in ("data", "base64", "image"): - value = payload.get(key) - if value: - return _screenshot_to_bytes(value) - raise TypeError(f"Unsupported CUA screenshot result: {type(raw)!r}") - - -@dataclass(slots=True) -class _CuaRuntime: - sandbox_cm: Any - sandbox: Any - shell: CuaShellComponent - python: CuaPythonComponent - fs: CuaFileSystemComponent - gui: CuaGUIComponent | None - - -class CuaBooter(ComputerBooter): - def __init__( - self, - image: str = CUA_DEFAULT_CONFIG["image"], - os_type: str = CUA_DEFAULT_CONFIG["os_type"], - ttl: int = CUA_DEFAULT_CONFIG["ttl"], - telemetry_enabled: bool = CUA_DEFAULT_CONFIG["telemetry_enabled"], - local: bool = CUA_DEFAULT_CONFIG["local"], - api_key: str = CUA_DEFAULT_CONFIG["api_key"], - ) -> None: - self.image = image - self.os_type = os_type - self.ttl = ttl - self.telemetry_enabled = telemetry_enabled - self.local = local - self.api_key = api_key - self._runtime: _CuaRuntime | None = None - - async def boot(self, session_id: str) -> None: - _ = session_id - try: - from cua import Image, Sandbox - except ImportError as exc: - raise RuntimeError( - "CUA sandbox support requires the optional `cua` package. " - "Install it with `pip install cua` in the AstrBot environment." - ) from exc - - image_obj = self._build_image(Image) - ephemeral_kwargs = self._build_ephemeral_kwargs(Sandbox.ephemeral) - sandbox_cm = Sandbox.ephemeral(image_obj, **ephemeral_kwargs) - sandbox = await sandbox_cm.__aenter__() - try: - self._runtime = _CuaRuntime( - sandbox_cm=sandbox_cm, - sandbox=sandbox, - shell=CuaShellComponent(sandbox, os_type=self.os_type), - python=CuaPythonComponent(sandbox, os_type=self.os_type), - fs=CuaFileSystemComponent(sandbox, os_type=self.os_type), - gui=CuaGUIComponent(sandbox), - ) - except Exception: - await sandbox_cm.__aexit__(None, None, None) - self._runtime = None - raise - logger.info( - "[Computer] CUA sandbox booted: image=%s, os_type=%s", - self.image, - self.os_type, - ) - - def _build_image(self, image_cls: Any) -> Any: - image_name = (self.image or self.os_type or "linux").strip().lower() - factory = getattr(image_cls, image_name, None) - if callable(factory): - return factory() - os_factory = getattr(image_cls, (self.os_type or "linux").strip().lower(), None) - if callable(os_factory): - return os_factory() - return image_name - - def _build_ephemeral_kwargs(self, ephemeral: Any) -> dict[str, Any]: - try: - parameters = inspect.signature(ephemeral).parameters - except (TypeError, ValueError): - return {} - kwargs: dict[str, Any] = {} - if "ttl" in parameters: - kwargs["ttl"] = self.ttl - if "telemetry_enabled" in parameters: - kwargs["telemetry_enabled"] = self.telemetry_enabled - if "local" in parameters: - kwargs["local"] = self.local - if "api_key" in parameters and self.api_key: - kwargs["api_key"] = self.api_key - return kwargs - - async def shutdown(self) -> None: - if self._runtime is not None: - await self._runtime.sandbox_cm.__aexit__(None, None, None) - self._runtime = None - - @property - def capabilities(self) -> tuple[str, ...] | None: - capabilities = ["python", "shell", "filesystem"] - if self._runtime is None: - return tuple(capabilities) - - sandbox = self._runtime.sandbox - has_screenshot = getattr(sandbox, "screenshot", None) is not None - has_mouse = _has_component_method(sandbox, "mouse", "click") - has_keyboard = _has_component_method(sandbox, "keyboard", "type") - if has_screenshot or has_mouse or has_keyboard: - capabilities.append("gui") - if has_screenshot: - capabilities.append("screenshot") - if has_mouse: - capabilities.append("mouse") - if has_keyboard: - capabilities.append("keyboard") - return tuple(capabilities) - - @property - def fs(self) -> FileSystemComponent: - if self._runtime is None: - raise RuntimeError("CuaBooter is not initialized.") - return self._runtime.fs - - @property - def python(self) -> PythonComponent: - if self._runtime is None: - raise RuntimeError("CuaBooter is not initialized.") - return self._runtime.python - - @property - def shell(self) -> ShellComponent: - if self._runtime is None: - raise RuntimeError("CuaBooter is not initialized.") - return self._runtime.shell - - @property - def gui(self) -> GUIComponent | None: - return None if self._runtime is None else self._runtime.gui - - async def upload_file(self, path: str, file_name: str) -> dict: - local_path = Path(path) - if not local_path.is_file(): - return {"success": False, "error": f"File not found: {path}"} - sandbox = None if self._runtime is None else self._runtime.sandbox - if sandbox is not None and hasattr(sandbox, "upload_file"): - return _maybe_model_dump( - await sandbox.upload_file(str(local_path), file_name) - ) - files_components = () if sandbox is None else _resolve_files_components(sandbox) - upload = _resolve_files_method(files_components, "upload") - if upload is not None: - result = await _maybe_await(upload(str(local_path), file_name)) - return _normalize_native_upload_result(result, file_name) - write_bytes = _resolve_files_method(files_components, "write_bytes") - if write_bytes is not None: - result = await _maybe_await(write_bytes(file_name, local_path.read_bytes())) - return _normalize_native_upload_result(result, file_name) - if not _is_posix_os_type(self.os_type): - return _non_posix_filesystem_result(file_name, self.os_type) - result = await _write_base64_via_shell( - self.shell, file_name, local_path.read_bytes() - ) - return { - "success": not bool(result.get("stderr")), - "file_path": file_name, - **result, - } - - async def download_file(self, remote_path: str, local_path: str) -> None: - sandbox = None if self._runtime is None else self._runtime.sandbox - if sandbox is not None and hasattr(sandbox, "download_file"): - await sandbox.download_file(remote_path, local_path) - return - if not _is_posix_os_type(self.os_type): - _raise_non_posix_filesystem_error(self.os_type) - result = await self.shell.exec(f"base64 {shlex.quote(remote_path)}") - if result.get("stderr"): - raise RuntimeError(result["stderr"]) - Path(local_path).parent.mkdir(parents=True, exist_ok=True) - Path(local_path).write_bytes(base64.b64decode(result.get("stdout", ""))) - - async def available(self) -> bool: - return self._runtime is not None diff --git a/astrbot/core/computer/booters/cua_defaults.py b/astrbot/core/computer/booters/cua_defaults.py deleted file mode 100644 index a36c6e6546..0000000000 --- a/astrbot/core/computer/booters/cua_defaults.py +++ /dev/null @@ -1,18 +0,0 @@ -CUA_DEFAULT_CONFIG = { - "image": "linux", - "os_type": "linux", - "ttl": 3600, - "idle_timeout": 0, - "telemetry_enabled": False, - "local": True, - "api_key": "", -} - -CUA_CONFIG_KEYS = { - "image": "cua_image", - "os_type": "cua_os_type", - "ttl": "cua_ttl", - "telemetry_enabled": "cua_telemetry_enabled", - "local": "cua_local", - "api_key": "cua_api_key", -} diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index 1fb7b5cf7a..8ac4dd6a4d 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -20,7 +20,8 @@ from ..olayer import FileSystemComponent, PythonComponent, ShellComponent from .base import ComputerBooter -from .shipyard_search_file_util import _truncate_long_lines + +_MAX_SEARCH_LINE_COLUMNS = 1000 _BLOCKED_COMMAND_PATTERNS = [ " rm -rf ", @@ -83,6 +84,23 @@ def _decode_shell_output(output: bytes | None) -> str: return _decode_bytes_with_fallback(output, preferred_encoding="utf-8") +def _truncate_long_lines(text: str) -> str: + output_lines: list[str] = [] + for line in text.splitlines(keepends=True): + line_ending = "" + line_body = line + if line.endswith("\r\n"): + line_body = line[:-2] + line_ending = "\r\n" + elif line.endswith("\n") or line.endswith("\r"): + line_body = line[:-1] + line_ending = line[-1] + if len(line_body) > _MAX_SEARCH_LINE_COLUMNS: + line_body = line_body[:_MAX_SEARCH_LINE_COLUMNS] + output_lines.append(f"{line_body}{line_ending}") + return "".join(output_lines) + + @dataclass class LocalShellComponent(ShellComponent): async def exec( diff --git a/astrbot/core/computer/booters/shell_background.py b/astrbot/core/computer/booters/shell_background.py deleted file mode 100644 index 6fe94c133a..0000000000 --- a/astrbot/core/computer/booters/shell_background.py +++ /dev/null @@ -1,18 +0,0 @@ -import shlex - -_BACKGROUND_SPAWN_SCRIPT = ( - "import subprocess, sys; " - "p = subprocess.Popen(" - "['bash', '-lc', sys.argv[1]], " - "stdin=subprocess.DEVNULL, " - "stdout=subprocess.DEVNULL, " - "stderr=subprocess.DEVNULL, " - "start_new_session=True, " - "close_fds=True" - "); " - "print(p.pid)" -) - - -def build_detached_shell_command(command: str) -> str: - return f"python3 -c {shlex.quote(_BACKGROUND_SPAWN_SCRIPT)} {shlex.quote(command)}" diff --git a/astrbot/core/computer/booters/shipyard.py b/astrbot/core/computer/booters/shipyard.py deleted file mode 100644 index a8375544da..0000000000 --- a/astrbot/core/computer/booters/shipyard.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -import shlex -from typing import Any - -from shipyard import FileSystemComponent as ShipyardFileSystemComponent -from shipyard import ShipyardClient, Spec - -from astrbot.api import logger - -from ..olayer import FileSystemComponent, PythonComponent, ShellComponent -from .base import ComputerBooter -from .shell_background import build_detached_shell_command -from .shipyard_search_file_util import search_files_via_shell - - -def _maybe_model_dump(value: Any) -> dict[str, Any]: - if isinstance(value, dict): - return value - if hasattr(value, "model_dump"): - dumped = value.model_dump() - if isinstance(dumped, dict): - return dumped - return {} - - -class ShipyardShellWrapper: - def __init__(self, _shipyard_shell: ShellComponent): - self._shell = _shipyard_shell - - async def exec( - self, - command: str, - cwd: str | None = None, - env: dict[str, str] | None = None, - timeout: int | None = 300, - shell: bool = True, - background: bool = False, - ) -> dict[str, Any]: - if not shell: - return { - "stdout": "", - "stderr": "error: only shell mode is supported in shipyard booter.", - "exit_code": 2, - "success": False, - } - - run_command = command - if env: - env_prefix = " ".join( - f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items()) - ) - run_command = f"{env_prefix} {run_command}" - - if background: - run_command = build_detached_shell_command(run_command) - - result = await self._shell.exec( - run_command, - timeout=timeout or 300, - cwd=cwd, - ) - payload = _maybe_model_dump(result) - - stdout = payload.get("output", payload.get("stdout", "")) or "" - stderr = payload.get("error", payload.get("stderr", "")) or "" - exit_code = payload.get("exit_code") - if background: - pid: int | None = None - try: - pid = int(str(stdout).strip().splitlines()[-1]) - except Exception: - pid = None - return { - "pid": pid, - "stdout": ( - f"Command is running in the background. pid={pid}" - if pid is not None - else "Command was submitted in the background." - ), - "stderr": stderr, - "exit_code": exit_code, - "success": bool(payload.get("success", not stderr)), - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "command": payload.get("command"), - } - - return { - "stdout": stdout, - "stderr": stderr, - "exit_code": exit_code, - "success": bool(payload.get("success", not stderr)), - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "command": payload.get("command"), - } - - -class ShipyardFileSystemWrapper: - def __init__( - self, _shipyard_fs: ShipyardFileSystemComponent, _shipyard_shell: ShellComponent - ): - self._fs = _shipyard_fs - self._shell = _shipyard_shell - - async def create_file( - self, path: str, content: str = "", mode: int = 420 - ) -> dict[str, Any]: - return await self._fs.create_file(path=path, content=content, mode=mode) - - async def read_file( - self, - path: str, - encoding: str = "utf-8", - offset: int | None = None, - limit: int | None = None, - ) -> dict[str, Any]: - return await self._fs.read_file( - path=path, encoding=encoding, offset=offset, limit=limit - ) - - async def write_file( - self, path: str, content: str, mode: str = "w", encoding: str = "utf-8" - ) -> dict[str, Any]: - return await self._fs.write_file( - path=path, content=content, mode=mode, encoding=encoding - ) - - async def list_dir( - self, path: str = ".", show_hidden: bool = False - ) -> dict[str, Any]: - return await self._fs.list_dir(path=path, show_hidden=show_hidden) - - async def delete_file(self, path: str) -> dict[str, Any]: - return await self._fs.delete_file(path=path) - - async def search_files( - self, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - ) -> dict[str, Any]: - return await search_files_via_shell( - self._shell, - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - - async def edit_file( - self, - path: str, - old_string: str, - new_string: str, - replace_all: bool = False, - encoding: str = "utf-8", - ) -> dict[str, Any]: - return await self._fs.edit_file( - path=path, - old_string=old_string, - new_string=new_string, - replace_all=replace_all, - encoding=encoding, - ) - - -class ShipyardBooter(ComputerBooter): - def __init__( - self, - endpoint_url: str, - access_token: str, - ttl: int = 3600, - session_num: int = 10, - ) -> None: - self._sandbox_client = ShipyardClient( - endpoint_url=endpoint_url, access_token=access_token - ) - self._ttl = ttl - self._session_num = session_num - - async def boot(self, session_id: str) -> None: - ship = await self._sandbox_client.create_ship( - ttl=self._ttl, - spec=Spec(cpus=1.0, memory="512m"), - max_session_num=self._session_num, - session_id=session_id, - ) - logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}") - self._ship = ship - self._shell = ShipyardShellWrapper(self._ship.shell) - self._fs = ShipyardFileSystemWrapper(self._ship.fs, self._shell) - - async def shutdown(self) -> None: - logger.info("[Computer] Shipyard booter shutdown.") - - @property - def fs(self) -> FileSystemComponent: - return self._fs - - @property - def python(self) -> PythonComponent: - return self._ship.python - - @property - def shell(self) -> ShellComponent: - return self._shell - - async def upload_file(self, path: str, file_name: str) -> dict: - """Upload file to sandbox""" - result = await self._ship.upload_file(path, file_name) - logger.info("[Computer] File uploaded to Shipyard sandbox: %s", file_name) - return result - - async def download_file(self, remote_path: str, local_path: str): - """Download file from sandbox.""" - result = await self._ship.download_file(remote_path, local_path) - logger.info( - "[Computer] File downloaded from Shipyard sandbox: %s -> %s", - remote_path, - local_path, - ) - return result - - async def available(self) -> bool: - """Check if the sandbox is available.""" - try: - ship_id = self._ship.id - data = await self._sandbox_client.get_ship(ship_id) - if not data: - logger.info( - "[Computer] Shipyard sandbox health check: id=%s, healthy=False (no data)", - ship_id, - ) - return False - health = bool(data.get("status", 0) == 1) - logger.info( - "[Computer] Shipyard sandbox health check: id=%s, healthy=%s", - ship_id, - health, - ) - return health - except Exception as e: - logger.error(f"Error checking Shipyard sandbox availability: {e}") - return False diff --git a/astrbot/core/computer/booters/shipyard_neo.py b/astrbot/core/computer/booters/shipyard_neo.py deleted file mode 100644 index dd982960f4..0000000000 --- a/astrbot/core/computer/booters/shipyard_neo.py +++ /dev/null @@ -1,702 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import shlex -from typing import Any, cast - -from astrbot.api import logger - -from ..olayer import ( - BrowserComponent, - FileSystemComponent, - PythonComponent, - ShellComponent, -) -from .base import ComputerBooter -from .shell_background import build_detached_shell_command -from .shipyard_search_file_util import search_files_via_shell - -try: - from shipyard_neo import BayClient - from shipyard_neo.sandbox import Sandbox -except ImportError: - logger.warning( - "shipyard_neo_sdk is not installed. ShipyardNeoBooter will not work without it." - ) - - -def _maybe_model_dump(value: Any) -> dict[str, Any]: - if isinstance(value, dict): - return value - if hasattr(value, "model_dump"): - dumped = value.model_dump() - if isinstance(dumped, dict): - return dumped - return {} - - -def _slice_content_by_lines( - content: str, - *, - offset: int | None = None, - limit: int | None = None, -) -> str: - lines = content.splitlines(keepends=True) - start = 0 if offset is None else offset - selected = lines[start:] if limit is None else lines[start : start + limit] - return "".join(selected) - - -class NeoPythonComponent(PythonComponent): - def __init__(self, sandbox: Sandbox) -> None: - self._sandbox = sandbox - - async def exec( - self, - code: str, - kernel_id: str | None = None, - timeout: int = 30, - silent: bool = False, - ) -> dict[str, Any]: - _ = kernel_id # Bay runtime does not expose kernel_id in current SDK. - result = await self._sandbox.python.exec(code, timeout=timeout) - payload = _maybe_model_dump(result) - - output_text = payload.get("output", "") or "" - error_text = payload.get("error", "") or "" - data = payload.get("data") if isinstance(payload.get("data"), dict) else {} - rich_output = data.get("output") if isinstance(data.get("output"), dict) else {} - if not isinstance(rich_output.get("images"), list): - rich_output["images"] = [] - if "text" not in rich_output: - rich_output["text"] = output_text - - if silent: - rich_output["text"] = "" - - return { - "success": bool(payload.get("success", error_text == "")), - "data": { - "output": rich_output, - "error": error_text, - }, - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "code": payload.get("code"), - "output": output_text, - "error": error_text, - } - - -class NeoShellComponent(ShellComponent): - def __init__(self, sandbox: Sandbox) -> None: - self._sandbox = sandbox - - async def exec( - self, - command: str, - cwd: str | None = None, - env: dict[str, str] | None = None, - timeout: int | None = 300, - shell: bool = True, - background: bool = False, - ) -> dict[str, Any]: - if not shell: - return { - "stdout": "", - "stderr": "error: only shell mode is supported in shipyard_neo booter.", - "exit_code": 2, - "success": False, - } - - run_command = command - if env: - env_prefix = " ".join( - f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items()) - ) - run_command = f"{env_prefix} {run_command}" - - if background: - run_command = build_detached_shell_command(run_command) - - result = await self._sandbox.shell.exec( - run_command, - timeout=timeout or 300, - cwd=cwd, - ) - payload = _maybe_model_dump(result) - - stdout = payload.get("output", "") or "" - stderr = payload.get("error", "") or "" - exit_code = payload.get("exit_code") - if background: - pid: int | None = None - try: - pid = int(stdout.strip().splitlines()[-1]) - except Exception: - pid = None - return { - "pid": pid, - "stdout": ( - f"Command is running in the background. pid={pid}" - if pid is not None - else "Command was submitted in the background." - ), - "stderr": stderr, - "exit_code": exit_code, - "success": bool(payload.get("success", not stderr)), - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "command": payload.get("command"), - } - - return { - "stdout": stdout, - "stderr": stderr, - "exit_code": exit_code, - "success": bool(payload.get("success", not stderr)), - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "command": payload.get("command"), - } - - -class NeoFileSystemComponent(FileSystemComponent): - def __init__(self, sandbox: Sandbox, shell: ShellComponent) -> None: - self._sandbox = sandbox - self._shell = shell - - async def create_file( - self, - path: str, - content: str = "", - mode: int = 0o644, - ) -> dict[str, Any]: - _ = mode - await self._sandbox.filesystem.write_file(path, content) - return {"success": True, "path": path} - - async def read_file( - self, - path: str, - encoding: str = "utf-8", - offset: int | None = None, - limit: int | None = None, - ) -> dict[str, Any]: - _ = encoding - content = await self._sandbox.filesystem.read_file(path) - return { - "success": True, - "path": path, - "content": _slice_content_by_lines( - content, - offset=offset, - limit=limit, - ), - } - - async def search_files( - self, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - ) -> dict[str, Any]: - return await search_files_via_shell( - self._shell, - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - - async def edit_file( - self, - path: str, - old_string: str, - new_string: str, - replace_all: bool = False, - encoding: str = "utf-8", - ) -> dict[str, Any]: - _ = encoding - content = await self._sandbox.filesystem.read_file(path) - occurrences = content.count(old_string) - if occurrences == 0: - return { - "success": False, - "error": "old string not found in file", - "replacements": 0, - } - if replace_all: - updated = content.replace(old_string, new_string) - replacements = occurrences - else: - updated = content.replace(old_string, new_string, 1) - replacements = 1 - await self._sandbox.filesystem.write_file(path, updated) - return { - "success": True, - "path": path, - "replacements": replacements, - } - - async def write_file( - self, - path: str, - content: str, - mode: str = "w", - encoding: str = "utf-8", - ) -> dict[str, Any]: - _ = mode - _ = encoding - await self._sandbox.filesystem.write_file(path, content) - return {"success": True, "path": path} - - async def delete_file(self, path: str) -> dict[str, Any]: - await self._sandbox.filesystem.delete(path) - return {"success": True, "path": path} - - async def list_dir( - self, - path: str = ".", - show_hidden: bool = False, - ) -> dict[str, Any]: - entries = await self._sandbox.filesystem.list_dir(path) - data = [] - for entry in entries: - item = _maybe_model_dump(entry) - if not show_hidden and str(item.get("name", "")).startswith("."): - continue - data.append(item) - return {"success": True, "path": path, "entries": data} - - -class NeoBrowserComponent(BrowserComponent): - def __init__(self, sandbox: Sandbox) -> None: - self._sandbox = sandbox - - async def exec( - self, - cmd: str, - timeout: int = 30, - description: str | None = None, - tags: str | None = None, - learn: bool = False, - include_trace: bool = False, - ) -> dict[str, Any]: - result = await self._sandbox.browser.exec( - cmd, - timeout=timeout, - description=description, - tags=tags, - learn=learn, - include_trace=include_trace, - ) - return _maybe_model_dump(result) - - async def exec_batch( - self, - commands: list[str], - timeout: int = 60, - stop_on_error: bool = True, - description: str | None = None, - tags: str | None = None, - learn: bool = False, - include_trace: bool = False, - ) -> dict[str, Any]: - result = await self._sandbox.browser.exec_batch( - commands, - timeout=timeout, - stop_on_error=stop_on_error, - description=description, - tags=tags, - learn=learn, - include_trace=include_trace, - ) - return _maybe_model_dump(result) - - async def run_skill( - self, - skill_key: str, - timeout: int = 60, - stop_on_error: bool = True, - include_trace: bool = False, - description: str | None = None, - tags: str | None = None, - ) -> dict[str, Any]: - result = await self._sandbox.browser.run_skill( - skill_key=skill_key, - timeout=timeout, - stop_on_error=stop_on_error, - include_trace=include_trace, - description=description, - tags=tags, - ) - return _maybe_model_dump(result) - - -class ShipyardNeoBooter(ComputerBooter): - """Booter backed by Shipyard Neo (Bay). - - If *endpoint_url* is empty or set to ``"__auto__"``, Bay will be - started automatically as a Docker container (like Boxlite does for - Ship containers). - """ - - AUTO_SENTINEL = "__auto__" - DEFAULT_PROFILE = "python-default" - - def __init__( - self, - endpoint_url: str, - access_token: str, - profile: str = "", - ttl: int = 3600, - ) -> None: - self._endpoint_url = endpoint_url - self._access_token = access_token - self._profile = profile.strip() if profile else "" - self._ttl = ttl - self._client: BayClient | None = None - self._sandbox: Sandbox | None = None - self._bay_manager: Any = None # BayContainerManager when auto-started - self._fs: FileSystemComponent | None = None - self._python: PythonComponent | None = None - self._shell: ShellComponent | None = None - self._browser: BrowserComponent | None = None - - @property - def bay_client(self) -> Any: - return self._client - - @property - def sandbox(self) -> Any: - return self._sandbox - - @property - def capabilities(self) -> tuple[str, ...] | None: - """Sandbox capabilities from the Bay profile. - - Returns an immutable tuple after :meth:`boot`; ``None`` before boot. - """ - if self._sandbox is None: - return None - caps = getattr(self._sandbox, "capabilities", None) - return tuple(caps) if caps is not None else None - - @property - def is_auto_mode(self) -> bool: - """True when Bay should be auto-started.""" - ep = (self._endpoint_url or "").strip() - return not ep or ep == self.AUTO_SENTINEL - - async def boot(self, session_id: str) -> None: - _ = session_id - - # --- Auto-start Bay if needed --- - if self.is_auto_mode: - from .bay_manager import BayContainerManager - - # Clean up previous manager if re-booting - if self._bay_manager is not None: - await self._bay_manager.close_client() - - logger.info("[Computer] Neo auto-start mode: launching Bay container") - self._bay_manager = BayContainerManager() - self._endpoint_url = await self._bay_manager.ensure_running() - await self._bay_manager.wait_healthy() - # Read auto-provisioned credentials - if not self._access_token: - self._access_token = await self._bay_manager.read_credentials() - logger.info("[Computer] Bay auto-started at %s", self._endpoint_url) - - if not self._endpoint_url or not self._access_token: - if self._bay_manager is not None: - raise ValueError( - "Bay container started but credentials could not be read. " - "Ensure Bay generated credentials.json, or set access_token manually." - ) - raise ValueError( - "Shipyard Neo sandbox configuration is incomplete. " - "Set endpoint (default http://127.0.0.1:8114) and access token, " - "or ensure Bay's credentials.json is accessible for auto-discovery." - ) - - self._client = BayClient( - endpoint_url=self._endpoint_url, - access_token=self._access_token, - ) - await self._client.__aenter__() - - # Resolve profile: user-specified > smart selection > default. - # An empty profile means auto-select; any non-empty profile must be - # honoured as an explicit choice, including "python-default". - resolved_profile = await self._resolve_profile(self._client) - - self._sandbox = await self._client.create_sandbox( - profile=resolved_profile, - ttl=self._ttl, - ) - - # --- Readiness gate: wait until sandbox session is READY --- - await self._wait_until_ready(self._sandbox) - - self._shell = NeoShellComponent(self._sandbox) - self._fs = NeoFileSystemComponent(self._sandbox, self._shell) - self._python = NeoPythonComponent(self._sandbox) - - caps = self.capabilities or () - self._browser = ( - NeoBrowserComponent(self._sandbox) if "browser" in caps else None - ) - - logger.info( - "Got Shipyard Neo sandbox: %s (profile=%s, capabilities=%s, auto=%s)", - self._sandbox.id, - resolved_profile, - list(caps), - bool(self._bay_manager), - ) - - async def _wait_until_ready(self, sandbox: Sandbox) -> None: - """Poll sandbox status until READY, or raise on FAILED / timeout. - - Covers both warm-pool hits (near-instant) and cold starts (up to 180s). - On FAILED, EXPIRED, or timeout the sandbox is deleted before raising - so no orphan resources leak on Bay. - """ - READINESS_TIMEOUT = 180 # seconds - POLL_INTERVAL = 2 # seconds - - sandbox_id = sandbox.id - deadline = asyncio.get_running_loop().time() + READINESS_TIMEOUT - - while True: - await sandbox.refresh() - status = getattr(sandbox.status, "value", str(sandbox.status)) - - if status == "ready": - logger.info( - "[Computer] Sandbox %s is ready (profile=%s)", - sandbox_id, - sandbox.profile, - ) - return - - if status in {"failed", "expired"}: - logger.error( - "[Computer] Sandbox %s reached terminal state: %s", - sandbox_id, - status, - ) - try: - await sandbox.delete() - except Exception as del_err: - logger.warning( - "[Computer] Failed to delete failed sandbox %s: %s", - sandbox_id, - del_err, - ) - raise RuntimeError( - f"Sandbox {sandbox_id} is in terminal state: {status}" - ) - - remaining = deadline - asyncio.get_running_loop().time() - if remaining <= 0: - logger.error( - "[Computer] Sandbox %s did not become ready within %ds " - "(last status: %s)", - sandbox_id, - READINESS_TIMEOUT, - status, - ) - try: - await sandbox.delete() - except Exception as del_err: - logger.warning( - "[Computer] Failed to delete timed-out sandbox %s: %s", - sandbox_id, - del_err, - ) - raise TimeoutError( - f"Sandbox {sandbox_id} did not become ready within " - f"{READINESS_TIMEOUT}s (last status: {status})" - ) - - logger.debug( - "[Computer] Sandbox %s status=%s, waiting...", - sandbox_id, - status, - ) - await asyncio.sleep(POLL_INTERVAL) - - async def _resolve_profile(self, client: Any) -> str: - """Pick the best profile for this session. - - Resolution order: - 1. User-specified profile (non-empty) → use as-is. - 2. Query ``GET /v1/profiles`` and pick the profile with the most - capabilities, preferring profiles that include ``"browser"``. - 3. Fall back to :attr:`DEFAULT_PROFILE`. - - Auth errors (401/403) are re-raised immediately — they indicate a - misconfigured token, and silently falling back would just delay the - real failure to ``create_sandbox``. - """ - # User explicitly set a profile → honour it. - if self._profile: - logger.info("[Computer] Using user-specified profile: %s", self._profile) - return self._profile - - # Query Bay for available profiles - from shipyard_neo.errors import ForbiddenError, UnauthorizedError - - try: - profile_list = await client.list_profiles() - profiles = profile_list.items - except (UnauthorizedError, ForbiddenError): - raise # auth errors must not be silenced - except Exception as exc: - logger.warning( - "[Computer] Failed to query Bay profiles, falling back to %s: %s", - self.DEFAULT_PROFILE, - exc, - ) - return self.DEFAULT_PROFILE - - if not profiles: - return self.DEFAULT_PROFILE - - def _score(p: Any) -> tuple[int, int]: - """(has_browser, capability_count) — higher is better.""" - caps = getattr(p, "capabilities", []) or [] - return (1 if "browser" in caps else 0, len(caps)) - - best = max(profiles, key=_score) - chosen = getattr(best, "id", self.DEFAULT_PROFILE) - - if chosen != self.DEFAULT_PROFILE: - caps = getattr(best, "capabilities", []) - logger.info( - "[Computer] Auto-selected profile %s (capabilities=%s)", - chosen, - caps, - ) - - return chosen - - async def shutdown(self, *, delete_sandbox: bool = False) -> None: - if self._client is not None: - sandbox_id = getattr(self._sandbox, "id", "unknown") - - # Delete sandbox on Bay BEFORE closing the HTTP client. - # This is critical for cleanup — calling delete after - # __aexit__ would fail because the httpx session is already - # torn down. - if delete_sandbox and self._sandbox is not None: - try: - logger.info( - "[Computer] Deleting Shipyard Neo sandbox: id=%s", sandbox_id - ) - await self._sandbox.delete() - logger.info( - "[Computer] Shipyard Neo sandbox deleted: id=%s", sandbox_id - ) - except Exception as e: - logger.warning( - "[Computer] Failed to delete sandbox %s (may already be " - "cleaned up by Bay GC): %s", - sandbox_id, - e, - ) - - logger.info( - "[Computer] Shutting down Shipyard Neo sandbox client: id=%s", - sandbox_id, - ) - await self._client.__aexit__(None, None, None) - self._client = None - self._sandbox = None - logger.info( - "[Computer] Shipyard Neo sandbox client shut down: id=%s", sandbox_id - ) - - # NOTE: We intentionally do NOT stop the Bay container here. - # It stays running for reuse by future sessions. The user can - # stop it manually or via ``BayContainerManager.stop()``. - if self._bay_manager is not None: - await self._bay_manager.close_client() - - @property - def fs(self) -> FileSystemComponent: - if self._fs is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - return self._fs - - @property - def python(self) -> PythonComponent: - if self._python is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - return self._python - - @property - def shell(self) -> ShellComponent: - if self._shell is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - return self._shell - - @property - def browser(self) -> BrowserComponent: - if self._browser is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - return self._browser - - async def upload_file(self, path: str, file_name: str) -> dict: - if self._sandbox is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - with open(path, "rb") as f: - content = f.read() - remote_path = file_name.lstrip("/") - await self._sandbox.filesystem.upload(remote_path, content) - logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path) - return { - "success": True, - "message": "File uploaded successfully", - "file_path": remote_path, - } - - async def download_file(self, remote_path: str, local_path: str) -> None: - if self._sandbox is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - content = await self._sandbox.filesystem.download(remote_path.lstrip("/")) - local_dir = os.path.dirname(local_path) - if local_dir: - os.makedirs(local_dir, exist_ok=True) - with open(local_path, "wb") as f: - f.write(cast(bytes, content)) - logger.info( - "[Computer] File downloaded from Neo sandbox: %s -> %s", - remote_path, - local_path, - ) - - async def available(self) -> bool: - if self._sandbox is None: - return False - try: - await self._sandbox.refresh() - status = getattr(self._sandbox.status, "value", str(self._sandbox.status)) - healthy = status not in {"failed", "expired"} - logger.info( - "[Computer] Neo sandbox health check: id=%s, status=%s, healthy=%s", - getattr(self._sandbox, "id", "unknown"), - status, - healthy, - ) - return healthy - except Exception as e: - logger.error(f"Error checking Shipyard Neo sandbox availability: {e}") - return False diff --git a/astrbot/core/computer/booters/shipyard_search_file_util.py b/astrbot/core/computer/booters/shipyard_search_file_util.py deleted file mode 100644 index cdd41de82e..0000000000 --- a/astrbot/core/computer/booters/shipyard_search_file_util.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import shlex -from typing import Any - -from ..olayer import ShellComponent - -_MAX_SEARCH_LINE_COLUMNS = 1000 - - -def _truncate_long_lines(text: str) -> str: - output_lines: list[str] = [] - for line in text.splitlines(keepends=True): - line_ending = "" - line_body = line - if line.endswith("\r\n"): - line_body = line[:-2] - line_ending = "\r\n" - elif line.endswith("\n") or line.endswith("\r"): - line_body = line[:-1] - line_ending = line[-1] - - if len(line_body) > _MAX_SEARCH_LINE_COLUMNS: - line_body = line_body[:_MAX_SEARCH_LINE_COLUMNS] - - output_lines.append(f"{line_body}{line_ending}") - return "".join(output_lines) - - -def _build_rg_command( - *, - pattern: str, - path: str, - glob: str | None, - after_context: int | None, - before_context: int | None, -) -> list[str]: - command = [ - "rg", - "--color=never", - "-n", - "--max-columns", - str(_MAX_SEARCH_LINE_COLUMNS), - "-e", - pattern, - ] - if glob: - command.extend(["-g", glob]) - if after_context is not None: - command.extend(["-A", str(after_context)]) - if before_context is not None: - command.extend(["-B", str(before_context)]) - command.extend(["--", path]) - return command - - -def _build_grep_command( - *, - pattern: str, - path: str, - glob: str | None, - after_context: int | None, - before_context: int | None, -) -> list[str]: - command = ["grep", "-R", "-H", "-n", "-e", pattern] - if glob: - command.append(f"--include={glob}") - if after_context is not None: - command.extend(["-A", str(after_context)]) - if before_context is not None: - command.extend(["-B", str(before_context)]) - command.extend(["--", path]) - return command - - -def _quote_command(command: list[str]) -> str: - return shlex.join(command) - - -def build_search_command( - *, - pattern: str, - path: str, - glob: str | None, - after_context: int | None, - before_context: int | None, -) -> str: - rg_command = _quote_command( - _build_rg_command( - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - ) - grep_command = _quote_command( - _build_grep_command( - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - ) - return ( - "if command -v rg >/dev/null 2>&1; then " - f"{rg_command}; " - "elif command -v grep >/dev/null 2>&1; then " - f"{grep_command}; " - "else " - "echo 'Neither rg nor grep is available in the sandbox.' >&2; " - "exit 127; " - "fi" - ) - - -async def search_files_via_shell( - shell: ShellComponent, - *, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - timeout: int = 30, -) -> dict[str, Any]: - command = build_search_command( - pattern=pattern, - path=path or ".", - glob=glob, - after_context=after_context, - before_context=before_context, - ) - result = await shell.exec(command, timeout=timeout) - stdout = _truncate_long_lines(str(result.get("stdout", "") or "")) - stderr = str(result.get("stderr", "") or "") - exit_code = result.get("exit_code") - if exit_code in (0, None): - return {"success": True, "content": stdout} - if exit_code == 1: - return {"success": True, "content": ""} - return { - "success": False, - "content": "", - "error": stderr or f"command exited with code {exit_code}", - "exit_code": exit_code, - } diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index 9be646265e..5984f48c41 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -1,13 +1,12 @@ import asyncio import json -import os import shutil -import time import uuid -from dataclasses import dataclass from pathlib import Path from astrbot.api import logger +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.provider.register import llm_tools from astrbot.core.skills.skill_manager import SANDBOX_SKILLS_ROOT, SkillManager from astrbot.core.star.context import Context from astrbot.core.utils.astrbot_path import ( @@ -17,74 +16,312 @@ from .booters.base import ComputerBooter from .booters.local import LocalBooter - -session_booter: dict[str, ComputerBooter] = {} -local_booter: ComputerBooter | None = None +from .sandbox_manager import SandboxManager +from .sandbox_models import SandboxStatus +from .sandbox_provider import SandboxProvider +from .sandbox_registry import SandboxRegistry +from .sandbox_tool_binding import mark_tool_as_sandbox_provider_tool + +local_booter: LocalBooter | None = None +sandbox_registry = SandboxRegistry() +sandbox_manager = SandboxManager(registry=sandbox_registry, providers={}) _MANAGED_SKILLS_FILE = ".astrbot_managed_skills.json" +_SANDBOX_SKILLS_SYNC_LOCK = asyncio.Lock() +# Tracks tools registered per provider so core can remove them on unregister. +_provider_tools: dict[str, list[FunctionTool]] = {} -@dataclass(slots=True) -class _CUAIdleState: - expires_at: float - task: asyncio.Task +def _sandbox_provider_info(provider_id: str, provider: SandboxProvider) -> dict: + return { + "provider_id": provider_id, + "capabilities": sorted(getattr(provider, "capabilities", set())), + "tool_names": sorted(getattr(provider, "tool_names", set())), + "system_prompt": str(getattr(provider, "system_prompt", "") or ""), + } -cua_idle_state: dict[str, _CUAIdleState] = {} +def _has_managed_sandboxes_for_provider(provider_id: str) -> bool: + return any( + record.get("managed") and record.get("provider") == provider_id + for record in sandbox_manager.registry.list_sandboxes() + ) -def _get_cua_idle_timeout(config: dict) -> float: - sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) - value = sandbox_cfg.get("cua_idle_timeout", 0) - try: - timeout = float(value) - except (TypeError, ValueError): - return 0.0 - return max(timeout, 0.0) +def register_sandbox_provider( + provider: SandboxProvider, + *, + replace: bool = False, + tools: list[FunctionTool] | None = None, +) -> None: + """Register a plugin-provided sandbox runtime. + + Args: + provider: The sandbox provider instance. + replace: If ``True``, replace an existing provider with the same ID. + tools: Optional list of provider-specific tools to register with the + global LLM tool manager. Core will automatically unregister these + tools when the provider is unregistered. + """ + if not provider.provider_id: + raise ValueError("Sandbox provider_id must be a non-empty string.") + if provider.provider_id in sandbox_manager.providers and not replace: + raise RuntimeError( + f"Sandbox provider {provider.provider_id} is already registered" + ) -def _clear_cua_idle_state(session_id: str) -> None: - state = cua_idle_state.pop(session_id, None) - if state is not None and not state.task.done(): - state.task.cancel() + # Clean up previous tools when replacing. + if replace and provider.provider_id in sandbox_manager.providers: + _unregister_provider_tools(provider.provider_id) + sandbox_manager.providers[provider.provider_id] = provider -def _schedule_cua_idle_cleanup(session_id: str, timeout: float) -> None: - _clear_cua_idle_state(session_id) - if timeout <= 0: - return - expires_at = time.monotonic() + timeout + if tools: + registered: list[FunctionTool] = [] + for tool in tools: + mark_tool_as_sandbox_provider_tool(tool, provider.provider_id) + llm_tools.func_list.append(tool) + registered.append(tool) + _provider_tools[provider.provider_id] = registered + logger.info( + "Sandbox provider %s registered with %d tool(s)", + provider.provider_id, + len(registered), + ) + else: + logger.info("Sandbox provider %s registered", provider.provider_id) - async def _expire_when_idle() -> None: - try: - remaining = expires_at - time.monotonic() - if remaining > 0: - await asyncio.sleep(remaining) - state = cua_idle_state.get(session_id) - if state is None or state.expires_at != expires_at: - return +def unregister_sandbox_provider(provider_id: str, *, force: bool = False) -> None: + if not force and _has_managed_sandboxes_for_provider(provider_id): + raise RuntimeError( + f"Sandbox provider {provider_id} has active managed sandboxes; " + "destroy them or pass force=True before unregistering." + ) + + if force: + # Synchronously clear registry and memory state for this provider's + # sandboxes. Async destroy_booter is best-effort via background task. + _cleanup_provider_sandboxes_sync(provider_id) + + _unregister_provider_tools(provider_id) + sandbox_manager.providers.pop(provider_id, None) - booter = session_booter.get(session_id) + +def _unregister_provider_tools(provider_id: str) -> None: + registered = _provider_tools.pop(provider_id, []) + if registered: + registered_ids = {id(tool) for tool in registered} + llm_tools.func_list = [ + tool for tool in llm_tools.func_list if id(tool) not in registered_ids + ] + from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor + + FunctionToolExecutor.clear_runtime_computer_tools_cache(provider_id) + if registered: + logger.info( + "Unregistered %d tool(s) for sandbox provider %s", + len(registered), + provider_id, + ) + + +def _cleanup_provider_sandboxes_sync(provider_id: str) -> None: + """Synchronous cleanup of a provider's managed sandboxes on unregister. + + Temporary registry records and in-memory state are removed immediately. If + a temporary booter is alive and an event loop is running, an async + destroy_booter task is spawned as a best-effort cleanup. Persistent records + are preserved and their live booters are only shut down to close the current + runtime connection. + """ + import asyncio + + for record in list(sandbox_manager.registry.list_sandboxes()): + if not record.get("managed") or record.get("provider") != provider_id: + continue + sandbox_id = record["sandbox_id"] + if record.get("retention_policy") == "persistent": + booter = sandbox_manager.session_booter.pop(sandbox_id, None) + sandbox_manager.clear_idle_state(sandbox_id) + sandbox_manager.drop_boot_lock(sandbox_id) if booter is not None: try: - await booter.shutdown() - except Exception as shutdown_err: - logger.warning( - "[Computer] Failed to shutdown idle CUA sandbox for session %s: %s", - session_id, - shutdown_err, + loop = asyncio.get_running_loop() + loop.create_task(_safe_shutdown_booter(booter, record)) + except RuntimeError: + pass # no running event loop + continue + booter = sandbox_manager.session_booter.pop(sandbox_id, None) + sandbox_manager.clear_idle_state(sandbox_id) + sandbox_manager.registry.delete_sandbox(sandbox_id) + sandbox_manager.drop_boot_lock(sandbox_id) + if booter is not None: + try: + loop = asyncio.get_running_loop() + provider = sandbox_manager.providers.get(provider_id) + if provider is not None: + loop.create_task(_safe_destroy_booter(provider, booter, record)) + except RuntimeError: + pass # no running event loop + try: + sandbox_manager.registry.save() + except Exception as exc: + logger.warning( + "[Computer] Failed to save registry after force-unregister: %s", + exc, + ) + from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor + + FunctionToolExecutor.clear_runtime_computer_tools_cache(provider_id) + logger.info( + "Force-unregistered sandbox provider %s: sandboxes cleaned up", + provider_id, + ) + + +async def cleanup_sandbox_provider(provider_id: str) -> None: + """Destroy all sandboxes owned by a provider before unregistering it.""" + provider = sandbox_manager.providers.get(provider_id) + removed = 0 + preserved = 0 + handled_sandbox_ids: set[str] = set() + + def _pop_live_booter(sandbox_id: str): + booter = sandbox_manager.session_booter.pop(sandbox_id, None) + sandbox_manager.clear_idle_state(sandbox_id) + sandbox_manager.drop_boot_lock(sandbox_id) + return booter + + for record in list(sandbox_manager.registry.list_sandboxes()): + if not record.get("managed") or record.get("provider") != provider_id: + continue + sandbox_id = record["sandbox_id"] + handled_sandbox_ids.add(sandbox_id) + booter = _pop_live_booter(sandbox_id) + if record.get("retention_policy") == "persistent": + if booter is not None: + await _safe_shutdown_booter(booter, record) + preserved += 1 + continue + if booter is not None and provider is not None: + if not await _safe_destroy_booter(provider, booter, record): + sandbox_manager.session_booter[sandbox_id] = booter + sandbox_manager.registry.update_sandbox_status( + sandbox_id, + "error", + ) + continue + sandbox_manager.registry.delete_sandbox(sandbox_id) + removed += 1 + + for sandbox_id, booter in list(sandbox_manager.session_booter.items()): + booter_provider = getattr(booter, "provider_id", None) + if str(booter_provider or "") != provider_id: + continue + if sandbox_id in handled_sandbox_ids: + continue + record = sandbox_manager.registry.get_sandbox(sandbox_id) or { + "sandbox_id": sandbox_id, + "provider": provider_id, + "managed": True, + "retention_policy": "temporary", + } + sandbox_manager.session_booter.pop(sandbox_id, None) + sandbox_manager.clear_idle_state(sandbox_id) + sandbox_manager.drop_boot_lock(sandbox_id) + if provider is not None: + if not await _safe_destroy_booter(provider, booter, record): + sandbox_manager.session_booter[sandbox_id] = booter + if sandbox_manager.registry.get_sandbox(sandbox_id) is not None: + sandbox_manager.registry.update_sandbox_status( + sandbox_id, + "error", ) - finally: - session_booter.pop(session_id, None) - except asyncio.CancelledError: - raise - finally: - state = cua_idle_state.get(session_id) - if state is not None and state.expires_at == expires_at: - cua_idle_state.pop(session_id, None) + continue + if sandbox_manager.registry.get_sandbox(sandbox_id) is not None: + sandbox_manager.registry.delete_sandbox(sandbox_id) + removed += 1 + try: + await sandbox_manager.save_registry_async() + except Exception as exc: + logger.warning( + "[Computer] Failed to save registry after provider cleanup: %s", + exc, + ) + logger.info( + "Provider sandbox cleanup completed: provider=%s removed_temporary=%d preserved_persistent=%d", + provider_id, + removed, + preserved, + ) + + +def detach_sandbox_provider(provider_id: str) -> None: + """Remove a provider and its registered tools without touching sandboxes.""" + _unregister_provider_tools(provider_id) + sandbox_manager.providers.pop(provider_id, None) + + +async def _safe_destroy_booter( + provider: SandboxProvider, booter: ComputerBooter, record: dict +) -> bool: + try: + await provider.destroy_booter(booter, record) + return True + except Exception as exc: + logger.warning( + "Background destroy_booter failed for sandbox %s: %s", + record.get("sandbox_id"), + exc, + ) + return False + + +async def _safe_shutdown_booter(booter: ComputerBooter, record: dict) -> None: + try: + await booter.shutdown() + except Exception as exc: + logger.warning( + "Background shutdown failed for sandbox %s: %s", + record.get("sandbox_id"), + exc, + ) + + +def get_sandbox_provider_info(provider_id: str) -> dict | None: + provider = sandbox_manager.providers.get(provider_id) + if provider is None: + return None + return _sandbox_provider_info(provider_id, provider) + + +def get_current_sandbox_provider_id(session_id: str) -> str | None: + current_sandbox_id = sandbox_manager.registry.get_current_sandbox_id(session_id) + if not current_sandbox_id: + return None + current_record = sandbox_manager.registry.get_sandbox(current_sandbox_id) + if current_record is None: + return None + if current_record.get("status") in { + SandboxStatus.STOPPING, + SandboxStatus.STOPPED, + SandboxStatus.ERROR, + }: + return None + provider_id = str(current_record.get("provider") or "").strip() + return provider_id or None + - task = asyncio.create_task(_expire_when_idle()) - cua_idle_state[session_id] = _CUAIdleState(expires_at=expires_at, task=task) +def list_sandbox_providers() -> list[dict]: + return [ + _sandbox_provider_info(provider_id, provider) + for provider_id, provider in sorted(sandbox_manager.providers.items()) + ] + + +async def cleanup_managed_sandboxes() -> None: + await sandbox_manager.cleanup_managed_sandboxes() def _list_local_skill_dirs(skills_root: Path) -> list[Path]: @@ -131,65 +368,6 @@ def _normalize_shell_exec_result(result: object) -> dict: return {"exit_code": 0, "stdout": "", "stderr": ""} -def _discover_bay_credentials(endpoint: str) -> str: - """Try to auto-discover Bay API key from credentials.json. - - Search order: - 1. BAY_DATA_DIR env var - 2. Mono-repo relative path: ../pkgs/bay/ (dev layout) - 3. Current working directory - - Returns: - API key string, or empty string if not found. - """ - candidates: list[Path] = [] - - # 1. BAY_DATA_DIR env var - bay_data_dir = os.environ.get("BAY_DATA_DIR") - if bay_data_dir: - candidates.append(Path(bay_data_dir) / "credentials.json") - - # 2. Mono-repo layout: AstrBot/../pkgs/bay/credentials.json - astrbot_root = Path(__file__).resolve().parents[3] # astrbot/core/computer/ → root - candidates.append(astrbot_root.parent / "pkgs" / "bay" / "credentials.json") - - # 3. Current working directory - candidates.append(Path.cwd() / "credentials.json") - - for cred_path in candidates: - if not cred_path.is_file(): - continue - try: - data = json.loads(cred_path.read_text()) - api_key = data.get("api_key", "") - if api_key: - # Optionally verify endpoint matches - cred_endpoint = data.get("endpoint", "") - if ( - cred_endpoint - and endpoint - and cred_endpoint.rstrip("/") != endpoint.rstrip("/") - ): - logger.warning( - "[Computer] credentials.json endpoint mismatch: " - "file=%s, configured=%s — using key anyway", - cred_endpoint, - endpoint, - ) - masked_key = f"{api_key[:4]}..." if len(api_key) >= 6 else "redacted" - logger.info( - "[Computer] Auto-discovered Bay API key from %s (prefix=%s)", - cred_path, - masked_key, - ) - return api_key - except (json.JSONDecodeError, OSError) as exc: - logger.debug("[Computer] Failed to read %s: %s", cred_path, exc) - - logger.debug("[Computer] No Bay credentials.json found in search paths") - return "" - - def _build_python_exec_command(script: str) -> str: return ( "if command -v python3 >/dev/null 2>&1; then PYBIN=python3; " @@ -201,7 +379,7 @@ def _build_python_exec_command(script: str) -> str: ) -def _build_apply_sync_command() -> str: +def _build_apply_sync_command(zip_name: str = "skills.zip") -> str: """Build shell command for sync stage only. This stage mutates sandbox files (managed skill replacement) but does not scan @@ -215,7 +393,7 @@ def _build_apply_sync_command() -> str: from pathlib import Path root = Path({SANDBOX_SKILLS_ROOT!r}) -zip_path = root / "skills.zip" +zip_path = root / {zip_name!r} tmp_extract = Path(f"{{root}}_tmp_extract") managed_file = root / {_MANAGED_SKILLS_FILE!r} @@ -435,16 +613,22 @@ def _decode_sync_payload(stdout: str) -> dict | None: return None -def _update_sandbox_skills_cache(payload: dict | None) -> None: +def _update_sandbox_skills_cache( + payload: dict | None, + provider_id: str | None = None, +) -> None: if not isinstance(payload, dict): return skills = payload.get("skills", []) if not isinstance(skills, list): return - SkillManager().set_sandbox_skills_cache(skills) + SkillManager().set_sandbox_skills_cache(skills, provider_id=provider_id) -async def _apply_skills_to_sandbox(booter: ComputerBooter) -> None: +async def _apply_skills_to_sandbox( + booter: ComputerBooter, + zip_name: str = "skills.zip", +) -> None: """Apply local skill bundle to sandbox filesystem only. This function is intentionally limited to file mutation. Metadata scanning is @@ -452,7 +636,7 @@ async def _apply_skills_to_sandbox(booter: ComputerBooter) -> None: """ logger.info("[Computer] Skill sync phase=apply start") apply_result = _normalize_shell_exec_result( - await booter.shell.exec(_build_apply_sync_command()) + await booter.shell.exec(_build_apply_sync_command(zip_name)) ) if not _shell_exec_succeeded(apply_result): detail = _format_exec_error_detail(apply_result) @@ -480,63 +664,70 @@ async def _scan_sandbox_skills(booter: ComputerBooter) -> dict | None: return payload -async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: +async def _sync_skills_to_sandbox( + booter: ComputerBooter, + provider_id: str | None = None, +) -> None: """Sync local skills to sandbox and refresh cache. Backward-compatible orchestrator: keep historical behavior while internally splitting into `apply` and `scan` phases. """ - sync_skill_dirs = _collect_sync_skill_dirs() + async with _SANDBOX_SKILLS_SYNC_LOCK: + sync_skill_dirs = _collect_sync_skill_dirs() - temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) - zip_base = temp_dir / "skills_bundle" - zip_path = zip_base.with_suffix(".zip") - bundle_root = temp_dir / f"skills_bundle_{uuid.uuid4().hex}" + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + zip_base = temp_dir / f"skills_bundle_{uuid.uuid4().hex}" + zip_path = zip_base.with_suffix(".zip") + bundle_root = temp_dir / f"{zip_base.name}_contents" + remote_zip_name = f"{zip_base.name}.zip" + remote_zip = Path(SANDBOX_SKILLS_ROOT) / remote_zip_name - try: - if sync_skill_dirs: - if zip_path.exists(): - zip_path.unlink() - if bundle_root.exists(): - shutil.rmtree(bundle_root) - bundle_root.mkdir(parents=True) - for skill_name, skill_dir in sync_skill_dirs: - shutil.copytree(skill_dir, bundle_root / skill_name) - shutil.make_archive(str(zip_base), "zip", str(bundle_root)) - remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip" - logger.info("Uploading skills bundle to sandbox...") - await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}") - upload_result = await booter.upload_file(str(zip_path), str(remote_zip)) - if not upload_result.get("success", False): - raise RuntimeError("Failed to upload skills bundle to sandbox.") - else: + try: + if sync_skill_dirs: + bundle_root.mkdir(parents=True) + for skill_name, skill_dir in sync_skill_dirs: + shutil.copytree(skill_dir, bundle_root / skill_name) + shutil.make_archive(str(zip_base), "zip", str(bundle_root)) + logger.info("Uploading skills bundle to sandbox...") + await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}") + upload_result = await booter.upload_file(str(zip_path), str(remote_zip)) + if not upload_result.get("success", False): + raise RuntimeError("Failed to upload skills bundle to sandbox.") + else: + logger.info( + "No local skills found. Keeping sandbox built-ins and refreshing metadata." + ) + await booter.shell.exec( + f"rm -f {SANDBOX_SKILLS_ROOT}/{remote_zip_name}" + ) + + # Keep backward-compatible behavior while splitting lifecycle into two + # observable phases: apply (filesystem mutation) + scan (metadata read). + await _apply_skills_to_sandbox(booter, zip_name=remote_zip_name) + payload = await _scan_sandbox_skills(booter) + _update_sandbox_skills_cache(payload, provider_id=provider_id) + managed = ( + payload.get("managed_skills", []) if isinstance(payload, dict) else [] + ) logger.info( - "No local skills found. Keeping sandbox built-ins and refreshing metadata." + "[Computer] Sandbox skill sync complete: managed=%d", + len(managed), ) - await booter.shell.exec(f"rm -f {SANDBOX_SKILLS_ROOT}/skills.zip") - - # Keep backward-compatible behavior while splitting lifecycle into two - # observable phases: apply (filesystem mutation) + scan (metadata read). - await _apply_skills_to_sandbox(booter) - payload = await _scan_sandbox_skills(booter) - _update_sandbox_skills_cache(payload) - managed = payload.get("managed_skills", []) if isinstance(payload, dict) else [] - logger.info( - "[Computer] Sandbox skill sync complete: managed=%d", - len(managed), - ) - finally: - if bundle_root.exists(): - try: - shutil.rmtree(bundle_root) - except Exception: - logger.warning(f"Failed to remove temp skills bundle: {bundle_root}") - if zip_path.exists(): - try: - zip_path.unlink() - except Exception: - logger.warning(f"Failed to remove temp skills zip: {zip_path}") + finally: + if bundle_root.exists(): + try: + shutil.rmtree(bundle_root) + except Exception: + logger.warning( + f"Failed to remove temp skills bundle: {bundle_root}" + ) + if zip_path.exists(): + try: + zip_path.unlink() + except Exception: + logger.warning(f"Failed to remove temp skills zip: {zip_path}") async def get_booter( @@ -551,126 +742,55 @@ async def get_booter( elif runtime == "none": raise RuntimeError("Sandbox runtime is disabled by configuration.") - sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) - booter_type = sandbox_cfg.get("booter", "shipyard_neo") - cua_idle_timeout = _get_cua_idle_timeout(config) if booter_type == "cua" else 0.0 - - if session_id in session_booter: - booter = session_booter[session_id] - if not await booter.available(): - # Clean up old booter before rebuilding so sandbox resources - # on Bay (containers, volumes, networks) are not leaked. - # Only ShipyardNeoBooter supports delete_sandbox; other booters - # (local, boxlite, cua, etc.) are not backed by a remote sandbox - # manager and don't need it. - try: - if booter_type == "shipyard_neo": - await booter.shutdown(delete_sandbox=True) - else: - await booter.shutdown() - except Exception as shutdown_err: - logger.warning( - "[Computer] Error shutting down stale booter for session %s: %s", - session_id, - shutdown_err, - ) - _clear_cua_idle_state(session_id) - session_booter.pop(session_id, None) - if session_id not in session_booter: - uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex - logger.info( - f"[Computer] Initializing booter: type={booter_type}, session={session_id}" - ) - if booter_type == "shipyard": - from .booters.shipyard import ShipyardBooter - - ep = sandbox_cfg.get("shipyard_endpoint", "") - token = sandbox_cfg.get("shipyard_access_token", "") - ttl = sandbox_cfg.get("shipyard_ttl", 3600) - max_sessions = sandbox_cfg.get("shipyard_max_sessions", 10) - - client = ShipyardBooter( - endpoint_url=ep, access_token=token, ttl=ttl, session_num=max_sessions - ) - elif booter_type == "shipyard_neo": - from .booters.shipyard_neo import ShipyardNeoBooter - - ep = sandbox_cfg.get("shipyard_neo_endpoint", "") - token = sandbox_cfg.get("shipyard_neo_access_token", "") - ttl = sandbox_cfg.get("shipyard_neo_ttl", 3600) - profile = sandbox_cfg.get("shipyard_neo_profile", "python-default") - - # Auto-discover token from Bay's credentials.json if not configured - if not token: - token = _discover_bay_credentials(ep) - - logger.info( - f"[Computer] Shipyard Neo config: endpoint={ep}, profile={profile}, ttl={ttl}" - ) - client = ShipyardNeoBooter( - endpoint_url=ep, - access_token=token, - profile=profile, - ttl=ttl, - ) - elif booter_type == "cua": - from .booters.cua import CuaBooter, build_cua_booter_kwargs - - cua_kwargs = build_cua_booter_kwargs(sandbox_cfg) - logger.info( - f"[Computer] CUA config: image={cua_kwargs['image']}, " - f"os_type={cua_kwargs['os_type']}, ttl={cua_kwargs['ttl']}" + current_sandbox_id = sandbox_manager.registry.get_current_sandbox_id(session_id) + if current_sandbox_id: + current_record = sandbox_manager.registry.get_sandbox(current_sandbox_id) + if current_record and current_record.get("managed"): + return await sandbox_manager.get_observer_booter_by_id( + current_sandbox_id, + session_id, + require_lease=True, + context=context, ) - client = CuaBooter(**cua_kwargs) - elif booter_type == "boxlite": - from .booters.boxlite import BoxliteBooter - client = BoxliteBooter() - else: - raise ValueError(f"Unknown booter type: {booter_type}") - - try: - await client.boot(uuid_str) - logger.info( - f"[Computer] Sandbox booted successfully: type={booter_type}, session={session_id}" - ) - await _sync_skills_to_sandbox(client) - except Exception as e: - logger.error(f"Error booting sandbox for session {session_id}: {e}") - try: - if booter_type == "shipyard_neo": - await client.shutdown(delete_sandbox=True) - else: - await client.shutdown() - except Exception as shutdown_error: - logger.warning( - "Failed to shutdown sandbox after boot error for session %s: %s", - session_id, - shutdown_error, - ) - _clear_cua_idle_state(session_id) - raise e + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + provider_id = str(sandbox_cfg.get("booter", "")).strip() + if not provider_id: + raise ValueError( + "Sandbox provider is not configured. Install and enable a sandbox provider plugin, then select it in provider_settings.sandbox.booter." + ) - session_booter[session_id] = client - if booter_type == "cua": - _schedule_cua_idle_cleanup(session_id, cua_idle_timeout) - return session_booter[session_id] + logger.info( + f"[Computer] Initializing sandbox provider: provider={provider_id}, session={session_id}" + ) + if provider_id in sandbox_manager.providers: + return await sandbox_manager.get_or_create_booter( + context, + session_id, + provider_id, + ) + raise ValueError( + f"Unknown sandbox provider: {provider_id}. Install and enable a sandbox provider plugin, then select it in provider_settings.sandbox.booter." + ) async def sync_skills_to_active_sandboxes() -> None: """Best-effort skills synchronization for all active sandbox sessions.""" + active_booters = list(sandbox_manager.session_booter.items()) logger.info( - "[Computer] Syncing skills to %d active sandbox(es)", len(session_booter) + "[Computer] Syncing skills to %d active sandbox(es)", len(active_booters) ) - for session_id, booter in list(session_booter.items()): + for sandbox_id, booter in active_booters: + record = sandbox_manager.registry.get_sandbox(sandbox_id) or {} + provider_id = str(record.get("provider") or "").strip() or None try: - if not await booter.available(): + if not await sandbox_manager.booter_available(booter): continue - await _sync_skills_to_sandbox(booter) + await _sync_skills_to_sandbox(booter, provider_id=provider_id) except Exception as e: logger.warning( - "Failed to sync skills to sandbox for session %s: %s", - session_id, + "Failed to sync skills to sandbox for sandbox %s: %s", + sandbox_id, e, ) diff --git a/astrbot/core/computer/sandbox_manager.py b/astrbot/core/computer/sandbox_manager.py new file mode 100644 index 0000000000..3847a0eebc --- /dev/null +++ b/astrbot/core/computer/sandbox_manager.py @@ -0,0 +1,1969 @@ +from __future__ import annotations + +import asyncio +import inspect +import math +import time +import uuid +from dataclasses import dataclass + +from astrbot.api import logger +from astrbot.core.computer.booters.base import ComputerBooter +from astrbot.core.computer.sandbox_models import SandboxRecord, SandboxStatus +from astrbot.core.computer.sandbox_provider import SandboxProvider +from astrbot.core.computer.sandbox_registry import SandboxRegistry +from astrbot.core.computer.sandbox_timeouts import ( + DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS, + expires_at_from_timeout, + get_provider_sandbox_config, + idle_cleanup_at_from_record, + lease_is_active, + resolve_sandbox_timeout, +) +from astrbot.core.star.context import Context + +SANDBOX_LEASE_SECONDS = int(DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS) +MAX_SANDBOX_LEASE_ATTEMPTS = 3 +MAX_IDLE_DESTROY_ATTEMPTS = 3 +SANDBOX_TTL_DESTROY_RETRY_SECONDS = 60 + + +@dataclass(slots=True) +class SandboxIdleState: + expires_at: float + task: asyncio.Task + + +@dataclass(slots=True) +class SandboxExpirationState: + expires_at: float + monotonic_expires_at: float + task: asyncio.Task + + +class SandboxManager: + def __init__( + self, + *, + registry: SandboxRegistry, + providers: dict[str, SandboxProvider], + ) -> None: + self.registry = registry + self.providers = providers + self.session_booter: dict[str, ComputerBooter] = {} + self.idle_state: dict[str, SandboxIdleState] = {} + self.expiration_state: dict[str, SandboxExpirationState] = {} + self.boot_locks: dict[str, asyncio.Lock] = {} + self.created_hook_inflight: set[str] = set() + self.pending_boot_tasks: dict[str, asyncio.Task] = {} + self.pending_destroy_tasks: dict[str, asyncio.Task] = {} + + def _ensure_unique_sandbox_name( + self, sandbox_name: str, *, exclude_sandbox_id: str | None = None + ) -> str: + normalized_name = str(sandbox_name).strip() + for record in self.registry.list_sandboxes(): + if record.get("sandbox_id") == exclude_sandbox_id: + continue + if str(record.get("sandbox_name") or "").strip() == normalized_name: + raise RuntimeError(f"Sandbox name '{normalized_name}' already exists") + return normalized_name + + def _created_sandbox_name(self, sandbox_id: str, sandbox_name: str | None) -> str: + if sandbox_name is None: + return sandbox_id + normalized_name = str(sandbox_name).strip() + if not normalized_name: + return sandbox_id + return self._ensure_unique_sandbox_name(normalized_name) + + def save_registry(self) -> None: + try: + self.registry.save() + except Exception as exc: + logger.warning("[Computer] Failed to save sandbox registry: %s", exc) + raise + + async def save_registry_async(self) -> None: + try: + await self.registry.save_async() + except Exception as exc: + logger.warning("[Computer] Failed to save sandbox registry: %s", exc) + raise + + async def _defer_lifecycle_task_start(self) -> None: + # Let the request that queued this lifecycle work finish before a + # provider boot/destroy path gets a chance to monopolize the event loop. + await asyncio.sleep(0) + + def _sandbox_boot_lock(self, sandbox_id: str) -> asyncio.Lock: + lock = self.boot_locks.get(sandbox_id) + if lock is None: + lock = asyncio.Lock() + self.boot_locks[sandbox_id] = lock + return lock + + def _lease_timeout(self, context: Context | None, session_id: str) -> float: + sandbox_cfg = get_provider_sandbox_config(context, session_id) + return resolve_sandbox_timeout( + sandbox_cfg, + "sandbox_lease_timeout", + aliases=("lease_timeout",), + default=DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS, + ) + + def _idle_timeout(self, context: Context | None, session_id: str) -> float: + sandbox_cfg = get_provider_sandbox_config(context, session_id) + return resolve_sandbox_timeout( + sandbox_cfg, + "sandbox_idle_timeout", + default=0.0, + ) + + def _expires_at( + self, context: Context | None, session_id: str, idle_timeout: float + ) -> float | None: + if idle_timeout > 0: + return None + sandbox_cfg = get_provider_sandbox_config(context, session_id) + ttl = resolve_sandbox_timeout( + sandbox_cfg, + "sandbox_ttl", + default=0.0, + ) + return expires_at_from_timeout(ttl) + + def _sandbox_policy_timeouts( + self, context: Context | None, session_id: str + ) -> tuple[float, float | None]: + idle_timeout = self._idle_timeout(context, session_id) + return idle_timeout, self._expires_at(context, session_id, idle_timeout) + + def _max_sandboxes(self, context: Context | None, session_id: str) -> int: + sandbox_cfg = get_provider_sandbox_config(context, session_id) + try: + max_sandboxes = int(sandbox_cfg.get("max_sandboxes", 10)) + except (TypeError, ValueError): + return 10 + if max_sandboxes < 0: + return 0 + return max_sandboxes + + def _ensure_under_max_sandboxes( + self, context: Context | None, session_id: str + ) -> None: + max_sandboxes = self._max_sandboxes(context, session_id) + if max_sandboxes <= 0: + return + managed_count = sum( + 1 for record in self.registry.list_sandboxes() if record.get("managed") + ) + if managed_count >= max_sandboxes: + raise RuntimeError( + f"Sandbox limit reached. Maximum managed sandboxes: {max_sandboxes}." + ) + + def drop_boot_lock(self, sandbox_id: str) -> None: + self.boot_locks.pop(sandbox_id, None) + + def clear_runtime_state(self, sandbox_id: str) -> None: + self.session_booter.pop(sandbox_id, None) + self.clear_idle_state(sandbox_id) + self.clear_expiration_state(sandbox_id) + self.created_hook_inflight.discard(sandbox_id) + + def clear_runtime_state_and_drop_lock(self, sandbox_id: str) -> None: + self.clear_runtime_state(sandbox_id) + self.drop_boot_lock(sandbox_id) + + def clear_all_runtime_state(self) -> None: + for sandbox_id in list(self.session_booter): + self.clear_runtime_state(sandbox_id) + for sandbox_id in list(self.idle_state): + self.clear_runtime_state_and_drop_lock(sandbox_id) + for sandbox_id in list(self.expiration_state): + self.clear_runtime_state(sandbox_id) + self.boot_locks.clear() + + async def cancel_pending_boot_task(self, sandbox_id: str) -> None: + task = self.pending_boot_tasks.pop(sandbox_id, None) + if task is None: + return + task.cancel() + try: + done, _pending = await asyncio.wait({task}, timeout=1) + if not done: + logger.warning( + "[Computer] Timed out waiting for pending sandbox boot task cancellation: %s", + sandbox_id, + ) + return + await task + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning( + "[Computer] Pending sandbox boot task ended with error for %s: %s", + sandbox_id, + exc, + ) + + async def wait_pending_destroy_task( + self, sandbox_id: str, *, timeout: float | None = 1 + ) -> None: + task = self.pending_destroy_tasks.get(sandbox_id) + if task is None: + return + try: + if timeout is None: + await asyncio.shield(task) + else: + await asyncio.wait_for(asyncio.shield(task), timeout=timeout) + except TimeoutError: + if not task.done(): + logger.warning( + "[Computer] Timed out waiting for pending sandbox destroy task: %s", + sandbox_id, + ) + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning( + "[Computer] Pending sandbox destroy task ended with error for %s: %s", + sandbox_id, + exc, + ) + finally: + if task.done(): + self.pending_destroy_tasks.pop(sandbox_id, None) + + def get_provider(self, provider_id: str) -> SandboxProvider: + provider = self.providers.get(provider_id) + if provider is None: + raise RuntimeError(f"Provider {provider_id} is not supported") + return provider + + def build_record_payload( + self, + *, + sandbox_id: str, + sandbox_name: str, + session_id: str, + provider_id: str, + idle_timeout: float, + expires_at: float | None, + connect_info: dict, + is_default: bool = False, + status: str = SandboxStatus.RUNNING, + ) -> dict: + return { + "sandbox_id": sandbox_id, + "sandbox_name": sandbox_name, + "provider": provider_id, + "managed": True, + "created_by_astrbot": True, + "owner_user_id": session_id, + "owner_session_id": session_id, + "connect_info": connect_info, + "capabilities": sorted( + getattr(self.get_provider(provider_id), "capabilities", set()) + ), + "tool_names": sorted( + getattr(self.get_provider(provider_id), "tool_names", set()) + ), + "is_default": is_default, + "idle_timeout": idle_timeout, + "expires_at": expires_at, + "status": status, + } + + def new_sandbox_id(self, provider_id: str) -> str: + return f"{provider_id}-{uuid.uuid4().hex[:12]}" + + def get_default_sandbox_id(self, provider_id: str) -> str | None: + default_sandbox_id = self.registry.get_default_sandbox_id(provider_id) + if default_sandbox_id: + record = self.registry.get_sandbox(default_sandbox_id) + if record and record.get("provider") == provider_id: + return default_sandbox_id + for record in self.registry.list_sandboxes(): + if record.get("managed") and record.get("provider") == provider_id: + return record["sandbox_id"] + return None + + async def booter_available(self, booter: ComputerBooter) -> bool: + available = getattr(booter, "available", None) + if available is None: + return False + if getattr(available, "__isabstractmethod__", False): + return False + result = available() if callable(available) else available + if inspect.isawaitable(result): + result = await result + if result is None: + return False + return bool(result) + + def acquire_lease( + self, sandbox_id: str, session_id: str, *, ttl: float | None = None + ) -> bool: + return self.registry.acquire_lease( + sandbox_id=sandbox_id, + session_id=session_id, + user_id=session_id, + ttl=DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS if ttl is None else ttl, + ) + + def sandbox_has_active_lease(self, sandbox_id: str) -> bool: + record = self.registry.get_sandbox(sandbox_id) + if record is None: + return False + return lease_is_active( + record.get("controller_session_id"), + record.get("lease_expires_at"), + ) + + def _release_expired_lease(self, record: dict) -> dict: + sandbox_id = record.get("sandbox_id") + controller_session_id = record.get("controller_session_id") + if not sandbox_id or not controller_session_id: + return record + if lease_is_active(controller_session_id, record.get("lease_expires_at")): + return record + released = self.registry.release_lease(sandbox_id) or record + if self.registry.get_current_sandbox_id(controller_session_id) == sandbox_id: + self.registry.set_current_sandbox_id(controller_session_id, None) + return released + + def sandbox_controlled_by_other_session( + self, sandbox_id: str, session_id: str + ) -> bool: + record = self.registry.get_sandbox(sandbox_id) + if record is None: + return False + controller_session_id = record.get("controller_session_id") + if not controller_session_id or controller_session_id == session_id: + return False + return lease_is_active(controller_session_id, record.get("lease_expires_at")) + + async def _upsert_new_sandbox_record( + self, context: Context, session_id: str, provider_id: str, create_config: dict + ) -> str: + self._ensure_under_max_sandboxes(context, session_id) + provider = self.get_provider(provider_id) + sandbox_id = self.new_sandbox_id(provider_id) + idle_timeout, expires_at = self._sandbox_policy_timeouts(context, session_id) + self.registry.upsert_sandbox( + **self.build_record_payload( + sandbox_id=sandbox_id, + sandbox_name=sandbox_id, + session_id=session_id, + provider_id=provider_id, + idle_timeout=idle_timeout, + expires_at=expires_at, + connect_info=provider.build_connect_info( + sandbox_id, + {**create_config, "sandbox_id": sandbox_id}, + ), + ) + ) + await self.save_registry_async() + return sandbox_id + + def _find_idle_provider_sandbox_id( + self, provider_id: str, *, exclude: set[str] | None = None + ) -> str | None: + excluded = exclude or set() + for record in self.registry.list_sandboxes(): + sandbox_id = record.get("sandbox_id") + if not sandbox_id or sandbox_id in excluded: + continue + if not record.get("managed") or record.get("provider") != provider_id: + continue + if record.get("status") != SandboxStatus.RUNNING: + continue + if self.sandbox_has_active_lease(sandbox_id): + continue + if sandbox_id not in self.session_booter: + continue + return sandbox_id + return None + + @staticmethod + def _sandbox_can_be_bootstrapped(record: dict) -> bool: + status = record.get("status") + if status == SandboxStatus.RUNNING: + return True + return bool( + record.get("retention_policy") == "persistent" + and status == SandboxStatus.UNKNOWN + ) + + async def get_or_create_booter( + self, context: Context, session_id: str, provider_id: str + ) -> ComputerBooter: + provider = self.get_provider(provider_id) + create_config = provider.build_create_config(context, session_id) + idle_timeout, expires_at = self._sandbox_policy_timeouts(context, session_id) + lease_timeout = self._lease_timeout(context, session_id) + + current_sandbox_id = self.registry.get_current_sandbox_id(session_id) + current_record = self.registry.get_sandbox(current_sandbox_id) + excluded_stale_current_ids: set[str] = set() + if current_sandbox_id and ( + current_record is None or current_record.get("provider") != provider_id + ): + if ( + current_record + and current_record.get("controller_session_id") == session_id + ): + self.registry.release_lease(current_sandbox_id) + self.registry.set_current_sandbox_id(session_id, None) + await self.save_registry_async() + current_sandbox_id = None + current_record = None + if current_sandbox_id and current_record: + status = current_record.get("status") + current_controller_session_id = current_record.get("controller_session_id") + if status == SandboxStatus.RUNNING and ( + current_controller_session_id != session_id + or not lease_is_active( + current_controller_session_id, + current_record.get("lease_expires_at"), + ) + ): + self._release_expired_lease(current_record) + self.registry.set_current_sandbox_id(session_id, None) + await self.save_registry_async() + excluded_stale_current_ids.add(current_sandbox_id) + current_sandbox_id = None + current_record = None + status = None + if current_sandbox_id and current_record: + status = current_record.get("status") + if status == SandboxStatus.CREATING: + pending_boot_task = self.pending_boot_tasks.get(current_sandbox_id) + if pending_boot_task is not None: + await asyncio.shield(pending_boot_task) + current_record = self.registry.get_sandbox(current_sandbox_id) + status = current_record.get("status") if current_record else None + if status in { + SandboxStatus.CREATING, + SandboxStatus.RESTORING, + SandboxStatus.STOPPING, + SandboxStatus.ERROR, + }: + if current_record.get("controller_session_id") == session_id: + self.registry.release_lease(current_sandbox_id) + self.registry.set_current_sandbox_id(session_id, None) + await self.save_registry_async() + current_sandbox_id = None + current_record = None + elif ( + current_record.get("retention_policy") == "persistent" + and status == SandboxStatus.UNKNOWN + and current_sandbox_id not in self.session_booter + ): + current_record = await self._revive_persistent_booter_if_needed( + current_record, current_sandbox_id, session_id, context + ) + if ( + current_sandbox_id + and current_record + and current_record.get("provider") == provider_id + and current_sandbox_id in self.session_booter + ): + if not self.acquire_lease( + current_sandbox_id, session_id, ttl=lease_timeout + ): + self.registry.set_current_sandbox_id(session_id, None) + await self.save_registry_async() + else: + booter = self.session_booter[current_sandbox_id] + if await self.booter_available(booter): + self.registry.touch_sandbox(current_sandbox_id) + await self.save_registry_async() + self.schedule_lifecycle_cleanup( + current_sandbox_id, + idle_timeout, + current_record.get("expires_at"), + ) + return booter + self.clear_runtime_state(current_sandbox_id) + self.registry.release_lease(current_sandbox_id) + await self.save_registry_async() + + created_target_record = False + target_sandbox_id = self.get_default_sandbox_id(provider_id) + if target_sandbox_id in excluded_stale_current_ids: + target_sandbox_id = None + target_record = self.registry.get_sandbox(target_sandbox_id) + if ( + target_sandbox_id + and target_record + and target_record.get("provider") == provider_id + and target_record.get("retention_policy") == "persistent" + and target_record.get("status") == SandboxStatus.UNKNOWN + ): + target_record = await self._revive_persistent_booter_if_needed( + target_record, target_sandbox_id, session_id, context + ) + elif target_record and not self._sandbox_can_be_bootstrapped(target_record): + target_sandbox_id = None + + if target_sandbox_id is None: + self._ensure_under_max_sandboxes(context, session_id) + target_sandbox_id = self.new_sandbox_id(provider_id) + created_target_record = True + record = self.registry.upsert_sandbox( + **self.build_record_payload( + sandbox_id=target_sandbox_id, + sandbox_name=target_sandbox_id, + session_id=session_id, + provider_id=provider_id, + idle_timeout=idle_timeout, + expires_at=expires_at, + connect_info=provider.build_connect_info( + target_sandbox_id, + {**create_config, "sandbox_id": target_sandbox_id}, + ), + is_default=True, + ) + ) + self.registry.set_default_sandbox_id(record["sandbox_id"]) + await self.save_registry_async() + + if self.sandbox_controlled_by_other_session(target_sandbox_id, session_id): + reusable_sandbox_id = self._find_idle_provider_sandbox_id( + provider_id, exclude={target_sandbox_id} + ) + if reusable_sandbox_id is not None: + target_sandbox_id = reusable_sandbox_id + created_target_record = False + else: + target_sandbox_id = await self._upsert_new_sandbox_record( + context, session_id, provider_id, create_config + ) + created_target_record = True + + for _attempt in range(MAX_SANDBOX_LEASE_ATTEMPTS): + async with self._sandbox_boot_lock(target_sandbox_id): + target_record = self.registry.get_sandbox(target_sandbox_id) + if target_record and not self._sandbox_can_be_bootstrapped( + target_record + ): + target_sandbox_id = await self._upsert_new_sandbox_record( + context, session_id, provider_id, create_config + ) + created_target_record = True + continue + + if target_sandbox_id in self.session_booter and not self.acquire_lease( + target_sandbox_id, session_id, ttl=lease_timeout + ): + target_sandbox_id = await self._upsert_new_sandbox_record( + context, session_id, provider_id, create_config + ) + created_target_record = True + continue + + if target_sandbox_id in self.session_booter: + booter = self.session_booter[target_sandbox_id] + if await self.booter_available(booter): + break + self.clear_runtime_state(target_sandbox_id) + self.registry.release_lease(target_sandbox_id) + self.registry.update_sandbox_status( + target_sandbox_id, SandboxStatus.UNKNOWN + ) + await self.save_registry_async() + + if not self.acquire_lease( + target_sandbox_id, session_id, ttl=lease_timeout + ): + target_sandbox_id = await self._upsert_new_sandbox_record( + context, session_id, provider_id, create_config + ) + created_target_record = True + continue + + try: + client = await provider.create_booter( + context, session_id, target_sandbox_id, create_config + ) + except Exception: + if created_target_record: + self.registry.delete_sandbox(target_sandbox_id) + else: + self.registry.release_lease(target_sandbox_id) + self.registry.update_sandbox_status( + target_sandbox_id, SandboxStatus.UNKNOWN + ) + self.clear_runtime_state(target_sandbox_id) + await self.save_registry_async() + raise + setattr(client, "sandbox_id", target_sandbox_id) + setattr(client, "provider_id", provider_id) + self.session_booter[target_sandbox_id] = client + break + else: + raise RuntimeError( + "Could not acquire sandbox lease after multiple attempts" + ) + + await self._finalize_created_booter( + provider, + target_sandbox_id, + session_id=session_id, + idle_timeout=idle_timeout, + remove_record_on_failure=created_target_record, + ) + await self._invoke_sandbox_created_hook(provider, target_sandbox_id) + return self.session_booter[target_sandbox_id] + + async def _finalize_created_booter( + self, + provider: SandboxProvider, + sandbox_id: str, + *, + session_id: str | None = None, + idle_timeout: float, + remove_record_on_failure: bool = False, + ) -> None: + """Common post-creation steps: persist, idle cleanup, skill sync, hooks.""" + booter = self.session_booter.get(sandbox_id) + record = self.registry.get_sandbox(sandbox_id) or {} + update_connect_info_after_boot = getattr( + provider, "update_connect_info_after_boot", None + ) + if booter is not None and callable(update_connect_info_after_boot): + connect_info = update_connect_info_after_boot(record, booter) + if connect_info is not None: + self.registry.update_sandbox_config( + sandbox_id, connect_info=connect_info + ) + self.registry.touch_sandbox(sandbox_id) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.RUNNING) + if session_id is not None: + self.registry.set_current_sandbox_id(session_id, sandbox_id) + try: + await self.save_registry_async() + except Exception: + if booter is not None: + try: + await provider.destroy_booter( + booter, self.registry.get_sandbox(sandbox_id) or {} + ) + except Exception as destroy_err: + logger.warning( + "[Computer] Failed to rollback sandbox %s after registry save error: %s", + sandbox_id, + destroy_err, + ) + self.clear_runtime_state(sandbox_id) + if remove_record_on_failure: + self.registry.delete_sandbox(sandbox_id) + else: + self.registry.release_lease(sandbox_id) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + if session_id is not None: + self.registry.set_current_sandbox_id(session_id, None) + raise + record = self.registry.get_sandbox(sandbox_id) or {} + self.schedule_lifecycle_cleanup( + sandbox_id, idle_timeout, record.get("expires_at") + ) + + # Auto-sync skills unless the provider opts out. Best-effort: a sync + # failure is logged but does not destroy the already-created sandbox. + if getattr(provider, "auto_sync_skills", True): + booter = self.session_booter.get(sandbox_id) + if booter is not None and hasattr(booter, "shell"): + try: + await self._sync_skills_to_booter( + booter, + provider_id=getattr(provider, "provider_id", None), + ) + except Exception as sync_err: + logger.warning( + "[Computer] Auto skill sync failed for %s: %s", + sandbox_id, + sync_err, + ) + + async def _invoke_sandbox_created_hook( + self, provider: SandboxProvider, sandbox_id: str + ) -> None: + """Invoke provider's on_sandbox_created hook if present. + + Each sandbox only fires the hook once, guarded by a persistent flag in + the registry record so that dashboard-created sandboxes still receive + the hook when they are first leased via switch/takeover. + + The flag is only set on success so that a transient hook failure can + be retried on the next lease operation. The check-and-set is protected + by the sandbox boot lock to prevent duplicate triggers under concurrent + lease operations. + """ + if not hasattr(provider, "on_sandbox_created"): + async with self._sandbox_boot_lock(sandbox_id): + if not self.registry.has_created_hook_fired(sandbox_id): + self.registry.mark_created_hook_fired(sandbox_id) + await self.save_registry_async() + return + + async with self._sandbox_boot_lock(sandbox_id): + record = self.registry.get_sandbox(sandbox_id) or {} + if ( + record.get("created_hook_fired") + or sandbox_id in self.created_hook_inflight + ): + return + self.created_hook_inflight.add(sandbox_id) + + should_mark_fired = False + try: + await provider.on_sandbox_created(record) + should_mark_fired = True + except Exception as hook_err: + logger.warning( + "[Computer] on_sandbox_created hook failed for %s: %s", + sandbox_id, + hook_err, + ) + return + finally: + async with self._sandbox_boot_lock(sandbox_id): + if should_mark_fired: + if not self.registry.has_created_hook_fired(sandbox_id): + self.registry.mark_created_hook_fired(sandbox_id) + await self.save_registry_async() + self.created_hook_inflight.discard(sandbox_id) + + async def create_sandbox_uncontrolled( + self, + context: Context, + session_id: str, + provider_id: str, + sandbox_name: str | None = None, + ) -> dict: + provider = self.get_provider(provider_id) + sandbox_id = self.new_sandbox_id(provider_id) + sandbox_name = self._created_sandbox_name(sandbox_id, sandbox_name) + self._ensure_under_max_sandboxes(context, session_id) + create_config = provider.build_create_config(context, session_id) + idle_timeout, expires_at = self._sandbox_policy_timeouts(context, session_id) + async with self._sandbox_boot_lock(sandbox_id): + record = self.registry.upsert_sandbox( + **self.build_record_payload( + sandbox_id=sandbox_id, + sandbox_name=sandbox_name, + session_id=session_id, + provider_id=provider_id, + idle_timeout=idle_timeout, + expires_at=expires_at, + connect_info=provider.build_connect_info( + sandbox_name, + {**create_config, "sandbox_id": sandbox_id}, + ), + status=SandboxStatus.CREATING, + ) + ) + try: + client = await provider.create_booter( + context, session_id, sandbox_id, create_config + ) + except asyncio.CancelledError: + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.ERROR) + self.clear_runtime_state(sandbox_id) + await self.save_registry_async() + raise + except Exception: + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.ERROR) + self.clear_runtime_state(sandbox_id) + await self.save_registry_async() + raise + setattr(client, "sandbox_id", sandbox_id) + setattr(client, "provider_id", provider_id) + self.session_booter[sandbox_id] = client + await self._finalize_created_booter( + provider, + sandbox_id, + session_id=None, + idle_timeout=idle_timeout, + remove_record_on_failure=True, + ) + return self.registry.get_sandbox(sandbox_id) or record + + async def create_sandbox_uncontrolled_deferred( + self, + context: Context, + session_id: str, + provider_id: str, + sandbox_name: str | None = None, + ) -> dict: + provider = self.get_provider(provider_id) + sandbox_id = self.new_sandbox_id(provider_id) + sandbox_name = self._created_sandbox_name(sandbox_id, sandbox_name) + self._ensure_under_max_sandboxes(context, session_id) + create_config = provider.build_create_config(context, session_id) + idle_timeout, expires_at = self._sandbox_policy_timeouts(context, session_id) + async with self._sandbox_boot_lock(sandbox_id): + record = self.registry.upsert_sandbox( + **self.build_record_payload( + sandbox_id=sandbox_id, + sandbox_name=sandbox_name, + session_id=session_id, + provider_id=provider_id, + idle_timeout=idle_timeout, + expires_at=expires_at, + connect_info=provider.build_connect_info( + sandbox_name, + {**create_config, "sandbox_id": sandbox_id}, + ), + status=SandboxStatus.CREATING, + ) + ) + await self.save_registry_async() + + task = asyncio.create_task( + self._boot_sandbox_uncontrolled_deferred( + context=context, + session_id=session_id, + provider=provider, + sandbox_id=sandbox_id, + create_config=create_config, + idle_timeout=idle_timeout, + ) + ) + self.pending_boot_tasks[sandbox_id] = task + + return self.registry.get_sandbox(sandbox_id) or record + + async def _boot_sandbox_uncontrolled_deferred( + self, + *, + context: Context, + session_id: str, + provider: SandboxProvider, + sandbox_id: str, + create_config: dict, + idle_timeout: float, + ) -> None: + try: + await self._defer_lifecycle_task_start() + async with self._sandbox_boot_lock(sandbox_id): + current = self.registry.get_sandbox(sandbox_id) + if current is None or current.get("status") != SandboxStatus.CREATING: + return + + try: + client = await provider.create_booter( + context, session_id, sandbox_id, create_config + ) + except asyncio.CancelledError: + raise + except Exception as boot_err: + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.ERROR) + await self.save_registry_async() + logger.warning( + "[Computer] Deferred sandbox boot failed: sandbox_id=%s session_id=%s error=%s", + sandbox_id, + session_id, + boot_err, + ) + return + + current = self.registry.get_sandbox(sandbox_id) + if current is None or current.get("status") != SandboxStatus.CREATING: + try: + cleanup_record = self.registry.get_sandbox(sandbox_id) or {} + await provider.destroy_booter(client, cleanup_record) + except Exception as destroy_err: + logger.warning( + "[Computer] Deferred sandbox cleanup failed after record removal: sandbox_id=%s error=%s", + sandbox_id, + destroy_err, + ) + return + + setattr(client, "sandbox_id", sandbox_id) + setattr(client, "provider_id", provider.provider_id) + self.session_booter[sandbox_id] = client + await self._finalize_created_booter( + provider, + sandbox_id, + session_id=None, + idle_timeout=idle_timeout, + remove_record_on_failure=True, + ) + finally: + self.pending_boot_tasks.pop(sandbox_id, None) + + async def create_sandbox( + self, + context: Context, + session_id: str, + provider_id: str, + sandbox_name: str | None = None, + ) -> dict: + sandbox = await self.create_sandbox_uncontrolled( + context, session_id, provider_id, sandbox_name + ) + sandbox_id = sandbox["sandbox_id"] + lease_timeout = self._lease_timeout(context, session_id) + if not self.acquire_lease(sandbox_id, session_id, ttl=lease_timeout): + provider = self.get_provider(sandbox.get("provider", "")) + await self._destroy_sandbox_cleanup(provider, sandbox_id, sandbox) + raise RuntimeError(f"Sandbox {sandbox_id} is busy") + await self._set_current_sandbox_after_lease(session_id, sandbox_id, sandbox) + provider = self.get_provider(sandbox.get("provider", "")) + # Reset idle cleanup after lease acquisition. The uncontrolled + # creation path already schedules cleanup, but a slow skill-sync or + # short idle_timeout could let the timer expire before the lease is + # acquired. Re-scheduling here guarantees a full idle window. + idle_timeout = sandbox.get("idle_timeout") or 0 + self.schedule_lifecycle_cleanup( + sandbox_id, float(idle_timeout), sandbox.get("expires_at") + ) + await self._invoke_sandbox_created_hook(provider, sandbox_id) + return self.registry.get_sandbox(sandbox_id) or sandbox + + def list_sandboxes(self) -> list[dict]: + records = [] + for record in self.registry.list_sandboxes(): + if not record.get("managed"): + continue + if "booter_type" in record: + record = SandboxRecord.from_dict(record).to_dict() + record = self._release_expired_lease(record) + provider = self.providers.get(record.get("provider")) + updated = dict(record) + updated["capabilities"] = sorted( + getattr(provider, "capabilities", record.get("capabilities", [])) + if provider + else record.get("capabilities", []) + ) + updated["tool_names"] = sorted( + getattr(provider, "tool_names", record.get("tool_names", [])) + if provider + else record.get("tool_names", []) + ) + if self.sandbox_has_active_lease(updated["sandbox_id"]): + updated["idle_cleanup_at"] = None + else: + updated["idle_cleanup_at"] = idle_cleanup_at_from_record( + last_used_at=updated.get("last_used_at"), + idle_timeout=updated.get("idle_timeout"), + ) + records.append(updated) + return records + + async def list_sandboxes_checked(self) -> list[dict]: + changed = False + for record in self.registry.list_sandboxes(): + sandbox_id = record.get("sandbox_id") + if not sandbox_id or not record.get("managed"): + continue + booter = self.session_booter.get(sandbox_id) + if booter is None: + continue + try: + available = await self.booter_available(booter) + except Exception as exc: + logger.warning( + "[Computer] Sandbox health check failed for %s: %s", + sandbox_id, + exc, + ) + available = False + if available: + continue + self.clear_runtime_state_and_drop_lock(sandbox_id) + next_status = ( + SandboxStatus.UNKNOWN + if record.get("retention_policy") == "persistent" + else SandboxStatus.ERROR + ) + self.registry.update_sandbox_status(sandbox_id, next_status) + changed = True + if changed: + await self.save_registry_async() + return self.list_sandboxes() + + def set_default_sandbox(self, sandbox_id: str) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + self.registry.set_default_sandbox_id(sandbox_id) + self.save_registry() + return self.registry.get_sandbox(sandbox_id) or record + + def update_sandbox_config( + self, + sandbox_id: str, + *, + sandbox_name: str | None = None, + idle_timeout: int | float | None, + expires_at: int | float | None, + retention_policy: str, + ) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + provider_id = record.get("provider", "") + provider = self.providers.get(provider_id) + if retention_policy not in {"temporary", "persistent"}: + raise RuntimeError("retention_policy must be temporary or persistent") + if retention_policy == "persistent" and provider is None: + raise RuntimeError(f"Provider {provider_id} is not available") + if ( + retention_policy == "persistent" + and provider is not None + and not getattr(provider, "supports_persistent_reconnect", False) + ): + raise RuntimeError( + f"Provider {record.get('provider')} does not support persistent sandboxes" + ) + if retention_policy == "persistent": + idle_timeout = None + expires_at = None + elif idle_timeout and float(idle_timeout) > 0: + expires_at = None + updates = { + "idle_timeout": idle_timeout, + "expires_at": expires_at, + "retention_policy": retention_policy, + } + if sandbox_name is not None: + normalized_name = str(sandbox_name).strip() + if not normalized_name: + raise ValueError("sandbox_name must be a non-empty string") + normalized_name = self._ensure_unique_sandbox_name( + normalized_name, exclude_sandbox_id=sandbox_id + ) + updates["sandbox_name"] = normalized_name + if provider is not None: + updates["connect_info"] = provider.update_connect_info( + record, + sandbox_name=normalized_name, + ) + updated = self.registry.update_sandbox_config(sandbox_id, **updates) + if retention_policy == "persistent": + self.clear_idle_state(sandbox_id) + self.clear_expiration_state(sandbox_id) + else: + self.schedule_lifecycle_cleanup( + sandbox_id, float(idle_timeout or 0), expires_at + ) + self.save_registry() + return updated or record + + def set_sandbox_retention_policy( + self, + context: Context | None, + session_id: str, + sandbox_id: str, + retention_policy: str, + *, + sandbox_name: str | None = None, + ) -> dict: + idle_timeout: float | None + expires_at: float | None + if retention_policy == "persistent": + idle_timeout = None + expires_at = None + else: + idle_timeout, expires_at = self._sandbox_policy_timeouts( + context, session_id + ) + return self.update_sandbox_config( + sandbox_id, + sandbox_name=sandbox_name, + idle_timeout=idle_timeout, + expires_at=expires_at, + retention_policy=retention_policy, + ) + + async def _revive_persistent_booter_if_needed( + self, + record: dict, + sandbox_id: str, + session_id: str | None, + context: Context | None, + ) -> dict: + if ( + context is None + or record.get("retention_policy") != "persistent" + or record.get("status") + not in {SandboxStatus.RUNNING, SandboxStatus.UNKNOWN} + ): + return record + + provider = self.get_provider(record.get("provider", "")) + if not getattr(provider, "supports_persistent_reconnect", False): + return record + + create_session_id = str( + record.get("owner_session_id") or session_id or "dashboard" + ) + create_config = provider.build_create_config(context, create_session_id) + connect_info = record.get("connect_info") or {} + create_config = { + **create_config, + "persistent_name": str( + connect_info.get("persistent_name") or sandbox_id + ).strip(), + "resume": True, + } + existing_runtime_id = connect_info.get("sandbox_id") + if existing_runtime_id: + create_config["sandbox_id"] = existing_runtime_id + existing_host_port = connect_info.get("host_port") + if existing_host_port: + create_config["host_port"] = existing_host_port + + async with self._sandbox_boot_lock(sandbox_id): + current = self.registry.get_sandbox(sandbox_id) + booter = self.session_booter.get(sandbox_id) + if ( + booter is None + and current is not None + and current.get("status") + in { + SandboxStatus.RUNNING, + SandboxStatus.UNKNOWN, + } + ): + previous_status = current.get("status") or SandboxStatus.UNKNOWN + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.RESTORING) + await self.save_registry_async() + try: + client = await provider.create_booter( + context, + create_session_id, + sandbox_id, + create_config, + ) + except asyncio.CancelledError: + latest = self.registry.get_sandbox(sandbox_id) + if ( + latest is not None + and latest.get("status") == SandboxStatus.RESTORING + ): + self.registry.update_sandbox_status(sandbox_id, previous_status) + await self.save_registry_async() + raise + except Exception: + latest = self.registry.get_sandbox(sandbox_id) + if ( + latest is not None + and latest.get("status") == SandboxStatus.RESTORING + ): + self.registry.update_sandbox_status(sandbox_id, previous_status) + await self.save_registry_async() + raise + setattr(client, "sandbox_id", sandbox_id) + setattr(client, "provider_id", provider.provider_id) + self.session_booter[sandbox_id] = client + await self._finalize_created_booter( + provider, + sandbox_id, + session_id=None, + idle_timeout=( + 0 + if record.get("retention_policy") == "persistent" + else self._idle_timeout(context, create_session_id) + ), + ) + return self.registry.get_sandbox(sandbox_id) or record + + async def switch_current_sandbox_checked( + self, session_id: str, sandbox_id: str, context: Context | None = None + ) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + record = await self._revive_persistent_booter_if_needed( + record, sandbox_id, session_id, context + ) + booter = self.session_booter.get(sandbox_id) + if booter is None: + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + if not await self.booter_available(booter): + self.session_booter.pop(sandbox_id, None) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + await self.save_registry_async() + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + lease_timeout = self._lease_timeout(context, session_id) + if not self.acquire_lease(sandbox_id, session_id, ttl=lease_timeout): + raise RuntimeError(f"Sandbox {sandbox_id} is busy") + result = await self._set_current_sandbox_after_lease( + session_id, sandbox_id, record + ) + provider = self.get_provider(record.get("provider", "")) + await self._invoke_sandbox_created_hook(provider, sandbox_id) + return result + + async def _set_current_sandbox_after_lease( + self, session_id: str, sandbox_id: str, record: dict + ) -> dict: + previous_sandbox_id = self.registry.get_current_sandbox_id(session_id) + if previous_sandbox_id and previous_sandbox_id != sandbox_id: + previous = self.registry.get_sandbox(previous_sandbox_id) + if previous and previous.get("controller_session_id") == session_id: + self.registry.release_lease(previous_sandbox_id) + self.registry.set_current_sandbox_id(session_id, sandbox_id) + self.registry.touch_sandbox(sandbox_id) + await self.save_registry_async() + return self.registry.get_sandbox(sandbox_id) or record + + def get_current_sandbox(self, session_id: str) -> dict: + sandbox_id = self.registry.get_current_sandbox_id(session_id) + sandbox = self.registry.get_sandbox(sandbox_id) if sandbox_id else None + if sandbox: + controller_session_id = sandbox.get("controller_session_id") + if controller_session_id != session_id or not lease_is_active( + controller_session_id, sandbox.get("lease_expires_at") + ): + self._release_expired_lease(sandbox) + self.registry.set_current_sandbox_id(session_id, None) + self.save_registry() + sandbox_id = None + sandbox = None + if sandbox: + provider = self.providers.get(sandbox.get("provider")) + if provider: + sandbox = dict(sandbox) + sandbox["capabilities"] = sorted( + getattr(provider, "capabilities", sandbox.get("capabilities", [])) + ) + sandbox["tool_names"] = sorted( + getattr(provider, "tool_names", sandbox.get("tool_names", [])) + ) + return { + "current_sandbox_id": sandbox_id, + "sandbox": sandbox, + } + + def release_current_sandbox( + self, session_id: str, sandbox_id: str | None = None + ) -> dict: + target_sandbox_id = sandbox_id or self.registry.get_current_sandbox_id( + session_id + ) + if target_sandbox_id is None: + raise RuntimeError("No current sandbox") + record = self.registry.get_sandbox(target_sandbox_id) + if record is None: + raise RuntimeError(f"Sandbox {target_sandbox_id} not found") + controller_session_id = record.get("controller_session_id") + if ( + controller_session_id + and controller_session_id != session_id + and self.sandbox_has_active_lease(target_sandbox_id) + ): + raise RuntimeError( + f"Sandbox {target_sandbox_id} is controlled by another session" + ) + released = self.registry.release_lease(target_sandbox_id) or record + if self.registry.get_current_sandbox_id(session_id) == target_sandbox_id: + self.registry.set_current_sandbox_id(session_id, None) + self.save_registry() + return released + + def force_release_sandbox(self, sandbox_id: str) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None: + raise RuntimeError(f"Sandbox {sandbox_id} not found") + controller_session_id = record.get("controller_session_id") + released = self.registry.release_lease(sandbox_id) or record + if controller_session_id: + if ( + self.registry.get_current_sandbox_id(controller_session_id) + == sandbox_id + ): + self.registry.set_current_sandbox_id(controller_session_id, None) + self.save_registry() + return released + + async def renew_current_sandbox_lease( + self, + session_id: str, + ttl_seconds: int | float | None = None, + context: Context | None = None, + ) -> dict: + current = self.get_current_sandbox(session_id) + sandbox_id = current.get("current_sandbox_id") + if sandbox_id is None: + raise RuntimeError("No current sandbox") + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + status = record.get("status") + if status == SandboxStatus.CREATING: + raise RuntimeError(f"Sandbox {sandbox_id} is still being created") + if status == SandboxStatus.RESTORING: + raise RuntimeError(f"Sandbox {sandbox_id} is being restored") + if status == SandboxStatus.STOPPING: + raise RuntimeError(f"Sandbox {sandbox_id} is being destroyed") + if status == SandboxStatus.STOPPED: + raise RuntimeError(f"Sandbox {sandbox_id} has been destroyed") + if status == SandboxStatus.ERROR: + raise RuntimeError( + f"Sandbox {sandbox_id} encountered an error during creation" + ) + if status != SandboxStatus.RUNNING: + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + booter = self.session_booter.get(sandbox_id) + if booter is None: + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + if not await self.booter_available(booter): + self.session_booter.pop(sandbox_id, None) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + await self.save_registry_async() + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + controller_session_id = record.get("controller_session_id") + if controller_session_id and controller_session_id != session_id: + raise RuntimeError(f"Sandbox {sandbox_id} is controlled by another session") + ttl = ( + self._lease_timeout(context, session_id) + if ttl_seconds is None + else float(ttl_seconds) + ) + if not math.isfinite(ttl): + raise RuntimeError("ttl_seconds must be finite") + if ttl < 0: + raise RuntimeError("ttl_seconds must be non-negative") + if not self.acquire_lease(sandbox_id, session_id, ttl=ttl): + raise RuntimeError(f"Sandbox {sandbox_id} is busy") + self.registry.touch_sandbox(sandbox_id) + self.save_registry() + return self.registry.get_sandbox(sandbox_id) or record + + async def takeover_sandbox( + self, + session_id: str, + sandbox_id: str, + context: Context | None = None, + ) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + record = await self._revive_persistent_booter_if_needed( + record, sandbox_id, session_id, context + ) + booter = self.session_booter.get(sandbox_id) + status = record.get("status") + if booter is None: + if status == SandboxStatus.CREATING: + raise RuntimeError(f"Sandbox {sandbox_id} is still being created") + if status == SandboxStatus.RESTORING: + raise RuntimeError(f"Sandbox {sandbox_id} is being restored") + if status == SandboxStatus.STOPPING: + raise RuntimeError(f"Sandbox {sandbox_id} is being destroyed") + if status == SandboxStatus.STOPPED: + raise RuntimeError(f"Sandbox {sandbox_id} has been destroyed") + if status == SandboxStatus.ERROR: + raise RuntimeError( + f"Sandbox {sandbox_id} encountered an error during creation" + ) + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + if not await self.booter_available(booter): + self.clear_runtime_state(sandbox_id) + next_status = ( + SandboxStatus.UNKNOWN + if record.get("retention_policy") == "persistent" + else SandboxStatus.ERROR + ) + self.registry.update_sandbox_status(sandbox_id, next_status) + await self.save_registry_async() + raise RuntimeError( + f"Sandbox {sandbox_id} is unavailable (booter health check failed)" + ) + previous_controller_session_id = record.get("controller_session_id") + updated = ( + self.registry.takeover_lease( + sandbox_id=sandbox_id, + session_id=session_id, + user_id=session_id, + ttl=self._lease_timeout(context, session_id), + ) + or record + ) + updated = await self._set_current_sandbox_after_lease( + session_id, sandbox_id, updated + ) + if ( + previous_controller_session_id + and previous_controller_session_id != session_id + and self.registry.get_current_sandbox_id(previous_controller_session_id) + == sandbox_id + ): + self.registry.set_current_sandbox_id(previous_controller_session_id, None) + await self.save_registry_async() + provider = self.get_provider(record.get("provider", "")) + await self._invoke_sandbox_created_hook(provider, sandbox_id) + return updated + + async def _destroy_sandbox_cleanup( + self, + provider: SandboxProvider, + sandbox_id: str, + record: dict, + ) -> None: + destroy_err: Exception | None = None + async with self._sandbox_boot_lock(sandbox_id): + current = self.registry.get_sandbox(sandbox_id) or record + booter = self.session_booter.get(sandbox_id) + if booter is not None: + try: + await provider.destroy_booter(booter, current) + except Exception as exc: + destroy_err = exc + logger.warning( + "[Computer] destroy_booter failed for %s: %s", + sandbox_id, + exc, + ) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.ERROR) + await self.save_registry_async() + else: + self.clear_runtime_state(sandbox_id) + if destroy_err is None: + self.registry.delete_sandbox(sandbox_id) + await self.save_registry_async() + + self.drop_boot_lock(sandbox_id) + + if destroy_err is not None: + raise destroy_err + + if hasattr(provider, "on_sandbox_destroyed"): + try: + await provider.on_sandbox_destroyed(record) + except Exception as hook_err: + logger.warning( + "[Computer] on_sandbox_destroyed hook failed for %s: %s", + sandbox_id, + hook_err, + ) + + async def destroy_sandbox(self, session_id: str, sandbox_id: str) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + if record.get("status") == SandboxStatus.STOPPING: + return record + controller_session_id = record.get("controller_session_id") + if ( + controller_session_id + and controller_session_id != session_id + and self.sandbox_has_active_lease(sandbox_id) + ): + raise RuntimeError(f"Sandbox {sandbox_id} is controlled by another session") + try: + provider = self.get_provider(record.get("provider", "")) + except RuntimeError: + if record.get("retention_policy") != "persistent": + raise + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + await self.save_registry_async() + return record + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.STOPPING) + await self.save_registry_async() + await self.cancel_pending_boot_task(sandbox_id) + await self._destroy_sandbox_cleanup(provider, sandbox_id, record) + return record + + async def destroy_sandbox_deferred(self, session_id: str, sandbox_id: str) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + if record.get("status") == SandboxStatus.STOPPING: + return record + controller_session_id = record.get("controller_session_id") + if ( + controller_session_id + and controller_session_id != session_id + and self.sandbox_has_active_lease(sandbox_id) + ): + raise RuntimeError(f"Sandbox {sandbox_id} is controlled by another session") + try: + provider = self.get_provider(record.get("provider", "")) + except RuntimeError: + if record.get("retention_policy") != "persistent": + raise + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + await self.save_registry_async() + return record + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.STOPPING) + await self.save_registry_async() + + async def _run_destroy_cleanup() -> None: + try: + await self._defer_lifecycle_task_start() + await self.cancel_pending_boot_task(sandbox_id) + await self._destroy_sandbox_cleanup(provider, sandbox_id, record) + finally: + self.pending_destroy_tasks.pop(sandbox_id, None) + + task = asyncio.create_task(_run_destroy_cleanup()) + self.pending_destroy_tasks[sandbox_id] = task + return self.registry.get_sandbox(sandbox_id) or record + + async def get_observer_booter_by_id( + self, + sandbox_id: str, + session_id: str | None = None, + *, + require_lease: bool = True, + context: Context | None = None, + ) -> ComputerBooter: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + controlled_by_other = bool( + session_id + and self.sandbox_controlled_by_other_session(sandbox_id, session_id) + ) + if controlled_by_other and require_lease: + raise RuntimeError(f"Sandbox {sandbox_id} is controlled by another session") + booter = self.session_booter.get(sandbox_id) + record = await self._revive_persistent_booter_if_needed( + record, sandbox_id, session_id, context + ) + booter = self.session_booter.get(sandbox_id) + status = record.get("status") + if booter is None: + if status == SandboxStatus.CREATING: + raise RuntimeError(f"Sandbox {sandbox_id} is still being created") + if status == SandboxStatus.RESTORING: + raise RuntimeError(f"Sandbox {sandbox_id} is being restored") + if status == SandboxStatus.STOPPING: + raise RuntimeError(f"Sandbox {sandbox_id} is being destroyed") + if status == SandboxStatus.STOPPED: + raise RuntimeError(f"Sandbox {sandbox_id} has been destroyed") + if status == SandboxStatus.ERROR: + raise RuntimeError( + f"Sandbox {sandbox_id} encountered an error during creation" + ) + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + if not await self.booter_available(booter): + self.session_booter.pop(sandbox_id, None) + next_status = ( + SandboxStatus.UNKNOWN + if record.get("retention_policy") == "persistent" + else SandboxStatus.ERROR + ) + self.registry.update_sandbox_status(sandbox_id, next_status) + await self.save_registry_async() + raise RuntimeError( + f"Sandbox {sandbox_id} is unavailable (booter health check failed)" + ) + if require_lease and session_id: + lease_timeout = self._lease_timeout(context, session_id) + if not self.acquire_lease(sandbox_id, session_id, ttl=lease_timeout): + raise RuntimeError(f"Sandbox {sandbox_id} is busy") + record = self.registry.get_sandbox(sandbox_id) or record + # Only touch lifecycle when the caller actually holds the lease (or + # the sandbox is unclaimed). Pure observer access must not reset + # idle timers for sandboxes controlled by other sessions. + if session_id and record.get("controller_session_id") == session_id: + self.registry.touch_sandbox(sandbox_id) + await self.save_registry_async() + idle_timeout = record.get("idle_timeout") or 0 + self.schedule_lifecycle_cleanup( + sandbox_id, float(idle_timeout), record.get("expires_at") + ) + return booter + + async def reconcile_on_startup(self) -> None: + for sandbox_id in list(self.pending_boot_tasks): + await self.cancel_pending_boot_task(sandbox_id) + for sandbox_id in list(self.pending_destroy_tasks): + await self.wait_pending_destroy_task(sandbox_id, timeout=None) + self.registry.load() + self.registry.reconcile_startup() + self.clear_all_runtime_state() + + # Validate persistent sandbox records against provider reality. + # If a provider reports that its persistent sandbox no longer exists + # externally, remove the stale registry record so the dashboard does + # not show ghost entries. + for record in list(self.registry.list_sandboxes()): + if record.get("retention_policy") != "persistent": + continue + try: + provider = self.get_provider(record.get("provider", "")) + except RuntimeError: + sandbox_id = record["sandbox_id"] + logger.info( + "[Computer] Provider for persistent sandbox %s is unavailable; keeping registry record", + sandbox_id, + ) + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + continue + if not getattr(provider, "supports_persistent_reconnect", False): + continue + check_exists = getattr(provider, "check_persistent_sandbox_exists", None) + if check_exists is None: + continue + try: + exists = await check_exists(record) + except Exception as exc: + logger.warning( + "[Computer] Failed to check persistent sandbox %s existence: %s", + record.get("sandbox_id"), + exc, + ) + continue + if not exists: + sandbox_id = record["sandbox_id"] + if not getattr(provider, "prune_missing_persistent_records", False): + logger.info( + "[Computer] Persistent sandbox %s was not confirmed externally; keeping registry record as unknown", + sandbox_id, + ) + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.UNKNOWN + ) + continue + logger.info( + "[Computer] Persistent sandbox %s no longer exists externally; removing registry record", + sandbox_id, + ) + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + + await self.save_registry_async() + + async def restore_persistent_sandboxes( + self, + context: Context, + *, + per_sandbox_timeout: float | None = None, + ) -> tuple[int, int]: + restored = 0 + deleted = 0 + for record in self.registry.list_sandboxes(): + sandbox_id = record["sandbox_id"] + if not record.get("managed"): + continue + if record.get("retention_policy") != "persistent": + continue + if record.get("status") not in { + SandboxStatus.RUNNING, + SandboxStatus.UNKNOWN, + }: + continue + try: + restore_coro = self._revive_persistent_booter_if_needed( + record=record, + sandbox_id=sandbox_id, + session_id=str(record.get("owner_session_id") or "dashboard"), + context=context, + ) + if per_sandbox_timeout is None: + await restore_coro + else: + await asyncio.wait_for(restore_coro, timeout=per_sandbox_timeout) + restored += 1 + except asyncio.TimeoutError: + self.session_booter.pop(sandbox_id, None) + self.clear_idle_state(sandbox_id) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + deleted += 1 + logger.warning( + "[Computer] Persistent sandbox restore timed out; keeping registry record as unknown: %s", + sandbox_id, + ) + except Exception as exc: + logger.warning( + "[Computer] Failed to restore persistent sandbox %s: %s", + sandbox_id, + exc, + ) + return restored, deleted + + async def cleanup_managed_sandboxes(self) -> None: + for sandbox_id in list(self.pending_boot_tasks): + await self.cancel_pending_boot_task(sandbox_id) + for sandbox_id in list(self.pending_destroy_tasks): + await self.wait_pending_destroy_task(sandbox_id, timeout=None) + managed_records = [ + record + for record in self.list_sandboxes() + if record["sandbox_id"] not in self.pending_destroy_tasks + ] + for record in managed_records: + sandbox_id = record["sandbox_id"] + if record.get("retention_policy") == "persistent": + booter = self.session_booter.get(sandbox_id) + if booter is not None: + try: + await booter.shutdown() + except Exception as shutdown_err: + logger.warning( + "[Computer] Failed to close persistent sandbox runtime %s: %s", + sandbox_id, + shutdown_err, + ) + self.clear_runtime_state_and_drop_lock(sandbox_id) + continue + provider = None + try: + provider = self.get_provider(record.get("provider", "")) + except RuntimeError as provider_error: + logger.warning( + "[Computer] Provider unavailable for sandbox %s: %s", + sandbox_id, + provider_error, + ) + booter = self.session_booter.get(sandbox_id) + if booter is not None: + if provider is not None: + try: + await provider.destroy_booter(booter, record) + except Exception as shutdown_err: + logger.warning( + "[Computer] Failed to shutdown managed sandbox %s: %s", + sandbox_id, + shutdown_err, + ) + # Always pop the booter so memory is freed even when the + # provider has already been unregistered. + self.clear_runtime_state(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + self.clear_runtime_state(sandbox_id) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + + def clear_idle_state(self, sandbox_id: str) -> None: + state = self.idle_state.pop(sandbox_id, None) + if ( + state is not None + and not state.task.done() + and state.task is not asyncio.current_task() + ): + state.task.cancel() + + def clear_expiration_state(self, sandbox_id: str) -> None: + state = self.expiration_state.pop(sandbox_id, None) + if ( + state is not None + and not state.task.done() + and state.task is not asyncio.current_task() + ): + state.task.cancel() + + def schedule_idle_cleanup(self, sandbox_id: str, timeout: float) -> None: + self.clear_idle_state(sandbox_id) + if timeout <= 0: + return + self.registry.touch_sandbox(sandbox_id) + expires_at = time.monotonic() + timeout + task = asyncio.create_task( + self._expire_when_idle(sandbox_id, timeout, expires_at) + ) + self.idle_state[sandbox_id] = SandboxIdleState(expires_at=expires_at, task=task) + + def schedule_ttl_cleanup(self, sandbox_id: str, expires_at: float | None) -> None: + self.clear_expiration_state(sandbox_id) + if expires_at is None: + return + monotonic_expires_at = time.monotonic() + max( + 0.0, float(expires_at) - time.time() + ) + task = asyncio.create_task( + self._expire_at_fixed_time( + sandbox_id, float(expires_at), monotonic_expires_at + ) + ) + self.expiration_state[sandbox_id] = SandboxExpirationState( + expires_at=float(expires_at), + monotonic_expires_at=monotonic_expires_at, + task=task, + ) + + def schedule_lifecycle_cleanup( + self, + sandbox_id: str, + idle_timeout: float, + expires_at: float | None, + ) -> None: + if idle_timeout > 0: + self.clear_expiration_state(sandbox_id) + self.schedule_idle_cleanup(sandbox_id, idle_timeout) + return + self.clear_idle_state(sandbox_id) + self.schedule_ttl_cleanup(sandbox_id, expires_at) + + async def _expire_at_fixed_time( + self, + sandbox_id: str, + expires_at: float, + monotonic_expires_at: float, + ) -> None: + current_task = asyncio.current_task() + destroy_attempts = 0 + try: + while True: + remaining = monotonic_expires_at - time.monotonic() + if remaining > 0: + await asyncio.sleep(remaining) + state = self.expiration_state.get(sandbox_id) + if ( + state is None + or state.task is not current_task + or state.expires_at != float(expires_at) + ): + return + record = self.registry.get_sandbox(sandbox_id) + if record is None: + self.session_booter.pop(sandbox_id, None) + return + if float(record.get("expires_at") or 0) != float(expires_at): + return + booter = self.session_booter.get(sandbox_id) + if booter is not None: + try: + provider = self.get_provider(record.get("provider", "")) + await provider.destroy_booter(booter, record) + except Exception as shutdown_err: + logger.warning( + "[Computer] Failed to shutdown expired sandbox %s: %s", + sandbox_id, + shutdown_err, + ) + destroy_attempts += 1 + if destroy_attempts < MAX_IDLE_DESTROY_ATTEMPTS: + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.UNKNOWN + ) + await self.save_registry_async() + await asyncio.sleep(SANDBOX_TTL_DESTROY_RETRY_SECONDS) + continue + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.ERROR + ) + await self.save_registry_async() + return + self.clear_runtime_state(sandbox_id) + if record.get("retention_policy") == "persistent": + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.STOPPED + ) + else: + self.registry.delete_sandbox(sandbox_id) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + return + finally: + state = self.expiration_state.get(sandbox_id) + if ( + state is not None + and state.task is current_task + and state.expires_at == float(expires_at) + ): + self.expiration_state.pop(sandbox_id, None) + + async def _expire_when_idle( + self, sandbox_id: str, timeout: float, initial_expires_at: float + ) -> None: + current_expires_at = initial_expires_at + destroy_attempts = 0 + try: + while True: + remaining = current_expires_at - time.monotonic() + if remaining > 0: + await asyncio.sleep(remaining) + state = self.idle_state.get(sandbox_id) + if state is None or state.expires_at != current_expires_at: + return + record = self.registry.get_sandbox(sandbox_id) + if record is None: + self.session_booter.pop(sandbox_id, None) + return + if self.sandbox_has_active_lease(sandbox_id): + current_expires_at = time.monotonic() + timeout + self.idle_state[sandbox_id] = SandboxIdleState( + expires_at=current_expires_at, task=state.task + ) + continue + if record.get("retention_policy") == "persistent": + return + booter = self.session_booter.get(sandbox_id) + if booter is not None: + try: + provider = self.get_provider(record.get("provider", "")) + self.session_booter.pop(sandbox_id, None) + await provider.destroy_booter(booter, record) + except Exception as shutdown_err: + logger.warning( + "[Computer] Failed to shutdown idle sandbox %s: %s", + sandbox_id, + shutdown_err, + ) + try: + booter_available = await self.booter_available(booter) + except Exception: + booter_available = False + if booter_available: + destroy_attempts += 1 + if destroy_attempts < MAX_IDLE_DESTROY_ATTEMPTS: + self.session_booter[sandbox_id] = booter + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.UNKNOWN + ) + await self.save_registry_async() + # Retry cleanup after the normal timeout instead of + # leaving the sandbox without any scheduled cleanup. + current_expires_at = time.monotonic() + timeout + self.idle_state[sandbox_id] = SandboxIdleState( + expires_at=current_expires_at, task=state.task + ) + continue + logger.warning( + "[Computer] Giving up on idle sandbox %s after %d destroy attempts", + sandbox_id, + destroy_attempts, + ) + self.session_booter[sandbox_id] = booter + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.ERROR + ) + self.idle_state.pop(sandbox_id, None) + await self.save_registry_async() + return + self.clear_runtime_state(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + return + self.registry.delete_sandbox(sandbox_id) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + return + except asyncio.CancelledError: + raise + finally: + state = self.idle_state.get(sandbox_id) + if state is not None and state.expires_at == current_expires_at: + self.idle_state.pop(sandbox_id, None) + + @staticmethod + async def _sync_skills_to_booter( + booter: ComputerBooter, + provider_id: str | None = None, + ) -> None: + """Delay-import wrapper to avoid circular imports.""" + from astrbot.core.computer.computer_client import _sync_skills_to_sandbox + + await _sync_skills_to_sandbox(booter, provider_id=provider_id) diff --git a/astrbot/core/computer/sandbox_models.py b/astrbot/core/computer/sandbox_models.py new file mode 100644 index 0000000000..32b6622030 --- /dev/null +++ b/astrbot/core/computer/sandbox_models.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from astrbot.core.computer.sandbox_timeouts import lease_is_active + + +class SandboxRetentionPolicy(str, Enum): + TEMPORARY = "temporary" + PERSISTENT = "persistent" + + +class SandboxStatus(str, Enum): + CREATING = "creating" + RESTORING = "restoring" + RUNNING = "running" + ERROR = "error" + STOPPING = "stopping" + STOPPED = "stopped" + UNKNOWN = "unknown" + + +@dataclass(slots=True) +class SandboxRecord: + sandbox_id: str + sandbox_name: str + provider: str + managed: bool + created_by_astrbot: bool + is_default: bool = False + owner_user_id: str | None = None + owner_session_id: str | None = None + created_by_user_id: str | None = None + created_by_session_id: str | None = None + controller_user_id: str | None = None + controller_session_id: str | None = None + lease_expires_at: float | None = None + last_used_at: float | None = None + idle_timeout: int | float | None = None + expires_at: float | None = None + retention_policy: SandboxRetentionPolicy = SandboxRetentionPolicy.TEMPORARY + status: SandboxStatus = SandboxStatus.RUNNING + connect_info: dict[str, Any] = field(default_factory=dict) + capabilities: list[str] = field(default_factory=list) + tool_names: list[str] = field(default_factory=list) + labels: dict[str, Any] = field(default_factory=dict) + notes: str | None = None + created_hook_fired: bool = False + + @staticmethod + def _required_string(data: dict[str, Any], field_name: str) -> str: + value = data[field_name] + if not isinstance(value, str): + raise ValueError(f"{field_name} must be a non-empty string") + value = value.strip() + if not value: + raise ValueError(f"{field_name} must be a non-empty string") + return value + + @classmethod + def _required_provider(cls, data: dict[str, Any]) -> str: + if "provider" in data: + return cls._required_string(data, "provider") + return cls._required_string(data, "booter_type") + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SandboxRecord: + return cls( + sandbox_id=cls._required_string(data, "sandbox_id"), + sandbox_name=cls._required_string(data, "sandbox_name"), + provider=cls._required_provider(data), + managed=bool(data["managed"]), + created_by_astrbot=bool(data["created_by_astrbot"]), + is_default=bool(data.get("is_default", False)), + owner_user_id=data.get("owner_user_id"), + owner_session_id=data.get("owner_session_id"), + created_by_user_id=data.get("created_by_user_id") + or data.get("owner_user_id"), + created_by_session_id=data.get("created_by_session_id") + or data.get("owner_session_id"), + controller_user_id=data.get("controller_user_id"), + controller_session_id=data.get("controller_session_id"), + lease_expires_at=data.get("lease_expires_at"), + last_used_at=data.get("last_used_at"), + idle_timeout=data.get("idle_timeout"), + expires_at=data.get("expires_at"), + retention_policy=SandboxRetentionPolicy( + data.get("retention_policy", SandboxRetentionPolicy.TEMPORARY) + ), + status=SandboxStatus(data.get("status", SandboxStatus.RUNNING)), + connect_info=dict(data.get("connect_info") or {}), + capabilities=sorted( + str(item) for item in data.get("capabilities", []) if item + ), + tool_names=sorted(str(item) for item in data.get("tool_names", []) if item), + labels=dict(data.get("labels") or {}), + notes=data.get("notes"), + created_hook_fired=bool(data.get("created_hook_fired", False)), + ) + + def to_dict(self) -> dict[str, Any]: + return { + "sandbox_id": self.sandbox_id, + "sandbox_name": self.sandbox_name, + "provider": self.provider, + "managed": self.managed, + "created_by_astrbot": self.created_by_astrbot, + "is_default": self.is_default, + "owner_user_id": self.owner_user_id, + "owner_session_id": self.owner_session_id, + "created_by_user_id": self.created_by_user_id, + "created_by_session_id": self.created_by_session_id, + "controller_user_id": self.controller_user_id, + "controller_session_id": self.controller_session_id, + "lease_expires_at": self.lease_expires_at, + "last_used_at": self.last_used_at, + "idle_timeout": self.idle_timeout, + "expires_at": self.expires_at, + "retention_policy": self.retention_policy.value, + "status": self.status.value, + "connect_info": dict(self.connect_info), + "capabilities": list(self.capabilities), + "tool_names": list(self.tool_names), + "labels": dict(self.labels), + "notes": self.notes, + "created_hook_fired": self.created_hook_fired, + } + + def has_active_lease(self, *, now: float | None = None) -> bool: + return lease_is_active( + self.controller_session_id, self.lease_expires_at, now=now + ) + + def is_controlled_by(self, session_id: str, *, now: float | None = None) -> bool: + return self.controller_session_id == session_id and self.has_active_lease( + now=now + ) diff --git a/astrbot/core/computer/sandbox_provider.py b/astrbot/core/computer/sandbox_provider.py new file mode 100644 index 0000000000..1930f06911 --- /dev/null +++ b/astrbot/core/computer/sandbox_provider.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from astrbot.core.computer.booters.base import ComputerBooter +from astrbot.core.star.context import Context + + +class SandboxProvider(Protocol): + """Protocol for plugin-provided sandbox runtime providers. + + Required attributes: + provider_id: Unique provider identifier (e.g. "browser", "python_sandbox"). + capabilities: Set of capability strings (e.g. {"shell", "python", "gui"}). + tool_names: Set of tool names this provider contributes to the LLM. + + Optional attributes (core uses ``getattr`` with safe fallbacks): + provider_api_version: Provider API compatibility version. Defaults to "1.0". + system_prompt: Runtime-specific instructions exposed in provider metadata. + plugin_config: Plugin-specific configuration dict. Implementations are + encouraged to accept this as an ``__init__`` parameter so the + provider is fully initialized at construction time. + auto_sync_skills: If ``False``, core will skip automatic skill sync after + booting a sandbox for this provider. Defaults to ``True``. + prune_missing_persistent_records: If ``True``, startup reconciliation may + delete persistent registry records when the provider confirms the + external sandbox is missing. Defaults to ``False`` to avoid data loss + from transient reconnect failures. + """ + + provider_id: str + capabilities: set[str] + tool_names: set[str] + system_prompt: str = "" + plugin_config: dict[str, Any] | None = None + provider_api_version: str = "1.0" + auto_sync_skills: bool = True + supports_persistent_reconnect: bool = False + prune_missing_persistent_records: bool = False + + def build_create_config(self, context: Context, session_id: str) -> dict: ... + + def build_connect_info(self, sandbox_name: str, config: dict) -> dict: ... + + def update_connect_info(self, record: dict, *, sandbox_name: str) -> dict: ... + + def update_connect_info_after_boot( + self, record: dict, booter: ComputerBooter + ) -> dict | None: ... + + async def create_booter( + self, + context: Context, + session_id: str, + sandbox_id: str, + config: dict, + ) -> ComputerBooter: ... + + async def destroy_booter(self, booter: ComputerBooter, record: dict) -> None: ... + + # Optional lifecycle hooks -- core checks ``hasattr`` before invoking. + + async def on_sandbox_created(self, record: dict) -> None: + """Called after a sandbox is successfully created and leased.""" + + async def on_sandbox_destroyed(self, record: dict) -> None: + """Called after a sandbox is destroyed and removed from registry.""" diff --git a/astrbot/core/computer/sandbox_registry.py b/astrbot/core/computer/sandbox_registry.py new file mode 100644 index 0000000000..f867495ba0 --- /dev/null +++ b/astrbot/core/computer/sandbox_registry.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import asyncio +import json +import time +from copy import deepcopy +from pathlib import Path +from typing import Any + +from astrbot.api import logger +from astrbot.core.computer.sandbox_models import SandboxRecord, SandboxStatus +from astrbot.core.computer.sandbox_timeouts import ( + lease_expires_at_from_timeout, + lease_is_active, +) +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +_UNSET = object() +_SCHEMA_VERSION = 1 + + +def _default_registry_payload() -> dict[str, Any]: + return { + "schema_version": _SCHEMA_VERSION, + "default_sandbox_id": None, + "default_sandbox_ids": {}, + "sandboxes": {}, + "session_current": {}, + } + + +def _coerce_schema_version(value: Any) -> int: + try: + version = int(value) + except (TypeError, ValueError): + return _SCHEMA_VERSION + return version if version > 0 else _SCHEMA_VERSION + + +class SandboxRegistry: + def __init__(self, storage_path: str | Path | None = None): + if storage_path is None: + storage_path = Path(get_astrbot_data_path()) / "sandbox_registry.json" + self.storage_path = Path(storage_path) + self._payload = _default_registry_payload() + self._save_lock = asyncio.Lock() + + @property + def default_sandbox_id(self) -> str | None: + return self._payload["default_sandbox_id"] + + def get_default_sandbox_id(self, provider: str) -> str | None: + sandbox_id = self._payload.get("default_sandbox_ids", {}).get(provider) + if sandbox_id and sandbox_id in self._payload["sandboxes"]: + return sandbox_id + if self._payload["default_sandbox_id"]: + record = self.get_sandbox(self._payload["default_sandbox_id"]) + if record and record.get("provider") == provider: + return self._payload["default_sandbox_id"] + return None + + def get_sandbox(self, sandbox_id: str | None) -> dict[str, Any] | None: + if sandbox_id is None: + return None + record = self._payload["sandboxes"].get(sandbox_id) + return deepcopy(record) if record is not None else None + + def list_sandboxes(self) -> list[dict[str, Any]]: + return [deepcopy(item) for item in self._payload["sandboxes"].values()] + + def set_default_sandbox_id(self, sandbox_id: str | None) -> None: + old_default = self._payload["default_sandbox_id"] + self._payload["default_sandbox_id"] = sandbox_id + if sandbox_id and sandbox_id in self._payload["sandboxes"]: + record = self._payload["sandboxes"][sandbox_id] + provider = record.get("provider") + if provider: + old_provider_default = self._payload.setdefault( + "default_sandbox_ids", {} + ).get(provider) + if ( + old_provider_default + and old_provider_default in self._payload["sandboxes"] + ): + self._payload["sandboxes"][old_provider_default]["is_default"] = ( + False + ) + self._payload["default_sandbox_ids"][provider] = sandbox_id + record["is_default"] = True + elif old_default and old_default in self._payload["sandboxes"]: + self._payload["sandboxes"][old_default]["is_default"] = False + + def get_current_sandbox_id(self, session_id: str) -> str | None: + return self._payload["session_current"].get(session_id) + + def set_current_sandbox_id(self, session_id: str, sandbox_id: str | None) -> None: + if sandbox_id is None: + self._payload["session_current"].pop(session_id, None) + else: + self._payload["session_current"][session_id] = sandbox_id + + def upsert_sandbox( + self, + *, + sandbox_id: str, + sandbox_name: str, + provider: str, + managed: bool, + created_by_astrbot: bool, + owner_user_id: str | None, + owner_session_id: str | None, + connect_info: dict[str, Any], + is_default: bool | object = _UNSET, + status: str | object = _UNSET, + idle_timeout: int | float | None | object = _UNSET, + expires_at: float | None | object = _UNSET, + retention_policy: str | object = _UNSET, + last_used_at: float | None | object = _UNSET, + controller_user_id: str | None | object = _UNSET, + controller_session_id: str | None | object = _UNSET, + lease_expires_at: float | None | object = _UNSET, + labels: dict[str, Any] | None | object = _UNSET, + capabilities: list[str] | set[str] | None | object = _UNSET, + tool_names: list[str] | set[str] | None | object = _UNSET, + notes: str | None | object = _UNSET, + ) -> dict[str, Any]: + record = self._payload["sandboxes"].get(sandbox_id, {}) + record.update( + { + "sandbox_id": sandbox_id, + "sandbox_name": sandbox_name, + "provider": provider, + "managed": managed, + "created_by_astrbot": created_by_astrbot, + "owner_user_id": owner_user_id, + "owner_session_id": owner_session_id, + "created_by_user_id": owner_user_id, + "created_by_session_id": owner_session_id, + "connect_info": deepcopy(connect_info), + } + ) + defaults = { + "controller_user_id": None, + "controller_session_id": None, + "lease_expires_at": None, + "last_used_at": None, + "idle_timeout": None, + "expires_at": None, + "retention_policy": "temporary", + "status": "running", + "is_default": False, + "labels": {}, + "capabilities": [], + "tool_names": [], + "notes": None, + "created_hook_fired": False, + } + updates = { + "controller_user_id": controller_user_id, + "controller_session_id": controller_session_id, + "lease_expires_at": lease_expires_at, + "last_used_at": last_used_at, + "idle_timeout": idle_timeout, + "expires_at": expires_at, + "retention_policy": retention_policy, + "status": status, + "is_default": is_default, + "labels": deepcopy(labels) if labels is not _UNSET else _UNSET, + "capabilities": sorted(capabilities) + if capabilities is not _UNSET + else _UNSET, + "tool_names": sorted(tool_names) if tool_names is not _UNSET else _UNSET, + "notes": notes, + "created_hook_fired": _UNSET, + } + for field_name, default_value in defaults.items(): + value = updates[field_name] + if value is _UNSET: + record.setdefault(field_name, deepcopy(default_value)) + else: + record[field_name] = value + record = SandboxRecord.from_dict(record).to_dict() + self._payload["sandboxes"][sandbox_id] = record + if is_default is True or ( + managed and self._payload["default_sandbox_id"] is None + ): + self.set_default_sandbox_id(sandbox_id) + return deepcopy(record) + + def delete_sandbox(self, sandbox_id: str) -> None: + was_default = self._payload["default_sandbox_id"] == sandbox_id + deleted = self._payload["sandboxes"].pop(sandbox_id, None) + if deleted: + provider = deleted.get("provider") + if ( + provider + and self._payload.get("default_sandbox_ids", {}).get(provider) + == sandbox_id + ): + self._payload["default_sandbox_ids"].pop(provider, None) + for candidate_id, candidate in self._payload["sandboxes"].items(): + if ( + candidate.get("managed") + and candidate.get("provider") == provider + ): + self.set_default_sandbox_id(candidate_id) + break + if was_default: + self._payload["default_sandbox_id"] = None + for candidate_id, candidate in self._payload["sandboxes"].items(): + if candidate.get("managed"): + self.set_default_sandbox_id(candidate_id) + break + stale_sessions = [ + session_id + for session_id, current_id in self._payload["session_current"].items() + if current_id == sandbox_id + ] + for session_id in stale_sessions: + self._payload["session_current"].pop(session_id, None) + + def touch_sandbox( + self, sandbox_id: str, *, ts: float | None = None + ) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + record["last_used_at"] = ts if ts is not None else time.time() + return deepcopy(record) + + def update_sandbox_config( + self, + sandbox_id: str, + *, + sandbox_name: str | object = _UNSET, + connect_info: dict[str, Any] | object = _UNSET, + idle_timeout: int | float | None | object = _UNSET, + expires_at: int | float | None | object = _UNSET, + retention_policy: str | object = _UNSET, + ) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + if sandbox_name is not _UNSET: + name = str(sandbox_name).strip() + if not name: + raise ValueError("sandbox_name must be a non-empty string") + record["sandbox_name"] = name + if connect_info is not _UNSET: + record["connect_info"] = deepcopy(connect_info) + if idle_timeout is not _UNSET: + record["idle_timeout"] = idle_timeout + if expires_at is not _UNSET: + record["expires_at"] = expires_at + if retention_policy is not _UNSET: + record["retention_policy"] = retention_policy + return deepcopy(record) + + def update_sandbox_status( + self, sandbox_id: str, status: str + ) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + record["status"] = getattr(status, "value", status) + return deepcopy(record) + + def has_created_hook_fired(self, sandbox_id: str) -> bool: + record = self._payload["sandboxes"].get(sandbox_id) + return bool(record and record.get("created_hook_fired")) + + def mark_created_hook_fired(self, sandbox_id: str) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + record["created_hook_fired"] = True + return deepcopy(record) + + def acquire_lease( + self, + *, + sandbox_id: str, + session_id: str, + user_id: str | None, + ttl: int | float, + now: float | None = None, + ) -> bool: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return False + current_time = time.time() if now is None else now + controller_session_id = record.get("controller_session_id") + lease_expires_at = record.get("lease_expires_at") + if lease_is_active( + controller_session_id, lease_expires_at, now=current_time + ) and (controller_session_id != session_id): + return False + record["controller_session_id"] = session_id + record["controller_user_id"] = user_id + next_lease_expires_at = lease_expires_at_from_timeout(ttl, now=current_time) + if controller_session_id == session_id and lease_is_active( + controller_session_id, lease_expires_at, now=current_time + ): + if lease_expires_at is None: + next_lease_expires_at = None + elif next_lease_expires_at is not None: + next_lease_expires_at = max( + float(lease_expires_at), next_lease_expires_at + ) + record["lease_expires_at"] = next_lease_expires_at + return True + + def release_lease(self, sandbox_id: str) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + record["controller_session_id"] = None + record["controller_user_id"] = None + record["lease_expires_at"] = None + return deepcopy(record) + + def takeover_lease( + self, + *, + sandbox_id: str, + session_id: str, + user_id: str | None, + ttl: int | float, + now: float | None = None, + ) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + current_time = time.time() if now is None else now + record["controller_session_id"] = session_id + record["controller_user_id"] = user_id + record["lease_expires_at"] = lease_expires_at_from_timeout( + ttl, now=current_time + ) + return deepcopy(record) + + def reconcile_startup(self) -> None: + self._payload["session_current"] = {} + for sandbox_id, record in list(self._payload["sandboxes"].items()): + if record.get("retention_policy") != "persistent": + self._payload["sandboxes"].pop(sandbox_id, None) + continue + record["controller_session_id"] = None + record["controller_user_id"] = None + record["lease_expires_at"] = None + if record.get("status") == SandboxStatus.RUNNING: + record["status"] = SandboxStatus.UNKNOWN.value + elif record.get("status") in { + SandboxStatus.CREATING, + SandboxStatus.RESTORING, + }: + record["status"] = SandboxStatus.ERROR.value + self._prune_default_references() + + def _prune_default_references(self) -> None: + sandboxes = self._payload["sandboxes"] + default_sandbox_id = self._payload.get("default_sandbox_id") + if default_sandbox_id not in sandboxes: + self._payload["default_sandbox_id"] = None + default_sandbox_ids = self._payload.get("default_sandbox_ids") or {} + valid_default_sandbox_ids = { + provider: sandbox_id + for provider, sandbox_id in default_sandbox_ids.items() + if sandbox_id in sandboxes + and sandboxes[sandbox_id].get("provider") == provider + } + self._payload["default_sandbox_ids"] = valid_default_sandbox_ids + for record in sandboxes.values(): + record["is_default"] = False + if self._payload["default_sandbox_id"]: + sandboxes[self._payload["default_sandbox_id"]]["is_default"] = True + for sandbox_id in valid_default_sandbox_ids.values(): + if sandbox_id in sandboxes: + sandboxes[sandbox_id]["is_default"] = True + + def load(self) -> None: + if not self.storage_path.exists(): + self._payload = _default_registry_payload() + return + try: + payload = json.loads(self.storage_path.read_text(encoding="utf-8")) + except Exception as exc: + logger.warning("Failed to load sandbox registry: %s", exc) + self._payload = _default_registry_payload() + return + if not isinstance(payload, dict): + logger.warning("Failed to load sandbox registry: payload is not an object") + self._payload = _default_registry_payload() + return + self._payload = _default_registry_payload() + self._payload["schema_version"] = _coerce_schema_version( + payload.get("schema_version") + ) + self._payload.update({key: payload.get(key) for key in self._payload}) + self._payload["schema_version"] = _coerce_schema_version( + self._payload.get("schema_version") + ) + self._payload["default_sandbox_ids"] = dict( + self._payload.get("default_sandbox_ids") or {} + ) + self._payload["sandboxes"] = dict(self._payload.get("sandboxes") or {}) + self._payload["session_current"] = dict( + self._payload.get("session_current") or {} + ) + + def _write_payload(self, payload: dict[str, Any]) -> None: + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + temp_path = self.storage_path.with_name( + f"{self.storage_path.name}.{time.time_ns()}.tmp" + ) + try: + temp_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True), + encoding="utf-8", + ) + temp_path.replace(self.storage_path) + finally: + if temp_path.exists(): + temp_path.unlink() + + def save(self) -> None: + self._write_payload(deepcopy(self._payload)) + + async def save_async(self) -> None: + async with self._save_lock: + payload = deepcopy(self._payload) + await asyncio.to_thread(self._write_payload, payload) diff --git a/astrbot/core/computer/sandbox_timeouts.py b/astrbot/core/computer/sandbox_timeouts.py new file mode 100644 index 0000000000..d1641b51b3 --- /dev/null +++ b/astrbot/core/computer/sandbox_timeouts.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import math +import time +from collections.abc import Mapping +from typing import Any + +DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS = 600.0 + + +def _coerce_timeout(value: Any, default: float) -> float: + if isinstance(value, bool): + return default + try: + timeout = float(value) + except (TypeError, ValueError): + return default + if not math.isfinite(timeout) or timeout < 0: + return default + return timeout + + +def resolve_sandbox_timeout( + config: Mapping[str, Any], + key: str, + *, + aliases: tuple[str, ...] = (), + default: float, +) -> float: + for candidate in (key, *aliases): + if candidate in config: + return _coerce_timeout(config.get(candidate), default) + return default + + +def lease_is_active( + controller_session_id: str | None, + lease_expires_at: float | None, + *, + now: float | None = None, +) -> bool: + if not controller_session_id: + return False + if lease_expires_at is None: + return True + current_time = time.time() if now is None else now + return float(lease_expires_at) > current_time + + +def lease_expires_at_from_timeout( + timeout: float | int | None, + *, + now: float | None = None, +) -> float | None: + if timeout is None: + return None + current_time = time.time() if now is None else now + normalized = _coerce_timeout(timeout, DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS) + if normalized <= 0: + return None + return current_time + normalized + + +def expires_at_from_timeout( + timeout: float | int | None, + *, + now: float | None = None, +) -> float | None: + return lease_expires_at_from_timeout(timeout, now=now) + + +def idle_cleanup_at_from_record( + *, + last_used_at: float | None, + idle_timeout: float | int | None, + now: float | None = None, +) -> float | None: + if last_used_at is None: + return None + current_timeout = _coerce_timeout(idle_timeout, 0.0) + if current_timeout <= 0: + return None + candidate = float(last_used_at) + current_timeout + return candidate + + +def get_provider_sandbox_config(context: Any, session_id: str) -> dict[str, Any]: + if context is None: + return {} + get_config = getattr(context, "get_config", None) + if not callable(get_config): + return {} + config = get_config(umo=session_id) + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + return sandbox_cfg if isinstance(sandbox_cfg, dict) else {} diff --git a/astrbot/core/computer/sandbox_tool_binding.py b/astrbot/core/computer/sandbox_tool_binding.py new file mode 100644 index 0000000000..21f49733c0 --- /dev/null +++ b/astrbot/core/computer/sandbox_tool_binding.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.tools.registry import ( + BuiltinToolConfigRule, + build_builtin_tool_config_rule, +) + +TFunctionTool = type[FunctionTool] + +_sandbox_provider_tool_config_rules: dict[str, BuiltinToolConfigRule] = {} + + +def tool_available_in_runtime(tool: Any, runtime: str) -> bool: + """Return whether a tool should be exposed for the computer-use runtime. + + Provider-specific sandbox tools are registered once when their provider is + enabled. They are visible to all sandbox sessions and hidden from local/none + runtimes. + """ + tool_provider = getattr(tool, "sandbox_provider_id", None) + if not tool_provider: + return True + return runtime == "sandbox" + + +def mark_tool_as_sandbox_provider_tool(tool: Any, provider_id: str) -> Any: + provider_id = _normalize_provider_id(provider_id) + tool.sandbox_provider_id = provider_id + marker = f"[Sandbox provider-specific tool: {provider_id}]" + description = str(getattr(tool, "description", "") or "") + if marker not in description: + tool.description = ( + f"{marker} This tool only works when the current sandbox uses provider " + f"'{provider_id}'. If the current sandbox uses another provider, switch or " + f"create a '{provider_id}' sandbox first. {description}" + ).strip() + return tool + + +def sandbox_provider_tool( + provider_id: str, + *, + config: dict[str, Any] | None = None, +) -> Callable[[TFunctionTool], TFunctionTool]: + """Mark a FunctionTool class as belonging to a sandbox provider. + + Sandbox provider tools are plugin-owned tools with sandbox runtime semantics. + They are not AstrBot Core builtin tools, but may still expose config tags in + the dashboard through the same condition format as builtin tools. + """ + + normalized_provider_id = _normalize_provider_id(provider_id) + + def _register(cls: TFunctionTool) -> TFunctionTool: + tool_name = _resolve_tool_name(cls) + cls.sandbox_provider_id = normalized_provider_id + if config is not None: + _sandbox_provider_tool_config_rules[tool_name] = ( + build_builtin_tool_config_rule(config) + ) + return cls + + return _register + + +def get_sandbox_provider_tool_config_statuses( + tool_name: str, + config_entries: list[dict[str, Any]], +) -> list[dict[str, Any]]: + rule = _sandbox_provider_tool_config_rules.get(tool_name) + if rule is None: + return [] + + statuses: list[dict[str, Any]] = [] + for entry in config_entries: + config = entry.get("config") + if not isinstance(config, dict): + continue + + conditions = rule.evaluate(config) + enabled = bool(conditions) and all( + bool(condition.get("matched")) for condition in conditions + ) + statuses.append( + { + "conf_id": entry.get("conf_id"), + "conf_name": entry.get("conf_name"), + "enabled": enabled, + "matched_conditions": [ + condition for condition in conditions if condition.get("matched") + ], + "failed_conditions": [ + condition + for condition in conditions + if not condition.get("matched") + ], + } + ) + return statuses + + +def _resolve_tool_name(tool_cls: type[FunctionTool]) -> str: + tool_name = getattr(tool_cls, "name", None) + if isinstance(tool_name, str) and tool_name: + return tool_name + + dataclass_fields = getattr(tool_cls, "__dataclass_fields__", {}) + name_field = dataclass_fields.get("name") + if name_field is not None and isinstance(name_field.default, str): + return name_field.default + + raise ValueError( + f"Sandbox provider tool class {tool_cls.__module__}.{tool_cls.__name__} does not define a valid name.", + ) + + +def _normalize_provider_id(provider_id: str | None) -> str: + return "" if provider_id is None else str(provider_id).strip().lower() diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 4d62becb55..56ffd7943d 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -17,6 +17,8 @@ DASHBOARD_INITIAL_PASSWORD_ENV = "ASTRBOT_DASHBOARD_INITIAL_PASSWORD" logger = logging.getLogger("astrbot") +CORE_COMPUTER_RUNTIME_IDS = {"local", "sandbox", "none"} + class RateLimitStrategy(enum.Enum): STALL = "stall" @@ -76,6 +78,7 @@ def __init__( ) # 检查配置完整性,并插入 has_new = self.check_config_integrity(default_config, conf) + has_new |= self._migrate_legacy_sandbox_runtime(conf) if ( "dashboard" in conf and isinstance(conf["dashboard"], dict) @@ -154,6 +157,33 @@ def _parse_schema(schema: dict, conf: dict) -> None: return conf + def _migrate_legacy_sandbox_runtime(self, conf: dict) -> bool: + provider_settings = conf.get("provider_settings") + if not isinstance(provider_settings, dict): + return False + + runtime = provider_settings.get("computer_use_runtime") + if runtime in CORE_COMPUTER_RUNTIME_IDS or not runtime: + return False + + # Older configs stored sandbox provider IDs directly as the runtime. + # Preserve that value as the selected sandbox booter without teaching + # core about concrete provider names. + sandbox_cfg = provider_settings.get("sandbox") + if not isinstance(sandbox_cfg, dict): + sandbox_cfg = {} + provider_settings["sandbox"] = sandbox_cfg + + if not sandbox_cfg.get("booter"): + sandbox_cfg["booter"] = runtime + logger.info( + "Config key migrated: provider_settings.computer_use_runtime %s -> sandbox", + runtime, + ) + + provider_settings["computer_use_runtime"] = "sandbox" + return True + def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): """检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" has_new = False diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index d060ce1c3d..eeb11b9496 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,7 +2,6 @@ import os -from astrbot.core.computer.booters.cua_defaults import CUA_DEFAULT_CONFIG from astrbot.core.utils.astrbot_path import get_astrbot_data_path VERSION = "4.25.5" @@ -169,21 +168,17 @@ "computer_use_runtime": "none", "computer_use_require_admin": True, "sandbox": { - "booter": "shipyard_neo", - "shipyard_endpoint": "", - "shipyard_access_token": "", - "shipyard_ttl": 3600, - "shipyard_max_sessions": 10, - "shipyard_neo_endpoint": "", - "shipyard_neo_access_token": "", - "shipyard_neo_profile": "python-default", - "shipyard_neo_ttl": 3600, - "cua_image": CUA_DEFAULT_CONFIG["image"], - "cua_os_type": CUA_DEFAULT_CONFIG["os_type"], - "cua_idle_timeout": CUA_DEFAULT_CONFIG["idle_timeout"], - "cua_telemetry_enabled": CUA_DEFAULT_CONFIG["telemetry_enabled"], - "cua_local": CUA_DEFAULT_CONFIG["local"], - "cua_api_key": CUA_DEFAULT_CONFIG["api_key"], + "booter": "", + "sandbox_lease_timeout": 600, + "sandbox_idle_timeout": 1800, + "sandbox_ttl": 3600, + "max_sandboxes": 10, + "member_permissions": { + "create": False, + "set_retention_policy": False, + "takeover": False, + "destroy": False, + }, }, "image_compress_enabled": True, "image_compress_options": { @@ -3369,143 +3364,80 @@ "hint": "开启后,需要 AstrBot 管理员权限才能调用使用电脑能力。在平台配置->管理员中可添加管理员。使用 /sid 指令查看管理员 ID。", }, "provider_settings.sandbox.booter": { - "description": "沙箱环境驱动器", - "type": "string", - "options": ["shipyard_neo", "shipyard", "cua"], - "labels": ["Shipyard Neo", "Shipyard", "CUA"], - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - }, - }, - "provider_settings.sandbox.shipyard_neo_endpoint": { - "description": "Shipyard Neo API Endpoint", - "type": "string", - "hint": "Shipyard Neo(Bay) 服务的 API 地址,默认 http://127.0.0.1:8114。", - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", - }, - }, - "provider_settings.sandbox.shipyard_neo_access_token": { - "description": "Shipyard Neo Access Token", - "type": "string", - "hint": "Bay 的 API Key(sk-bay-...)。留空时自动从 credentials.json 发现。", - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", - }, - }, - "provider_settings.sandbox.shipyard_neo_profile": { - "description": "Shipyard Neo Profile", + "description": "沙箱驱动", "type": "string", - "hint": "Shipyard Neo 沙箱 profile,如 python-default。留空时自动选择能力更完整的 profile。", + "options": [], + "labels": [], "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", }, }, - "provider_settings.sandbox.shipyard_neo_ttl": { - "description": "Shipyard Neo Sandbox TTL", + "provider_settings.sandbox.sandbox_lease_timeout": { + "description": "沙箱占用超时", "type": "int", - "hint": "Shipyard Neo 沙箱生存时间(秒)。", + "hint": "单位为秒。每次 Agent 成功访问沙盒时,都会自动将本会话的沙盒租约续到当前时间 + 此时长。默认 600 秒;到期后当前会话不再绑定该沙盒,其他会话可接管。`0` 表示租约不会自动过期,需手动释放。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", }, }, - "provider_settings.sandbox.cua_image": { - "description": "CUA Image", - "type": "string", - "hint": "CUA 沙箱镜像/系统类型,默认 linux。可填写 linux、macos、windows、android,具体取决于 CUA SDK 支持。", + "provider_settings.sandbox.sandbox_idle_timeout": { + "description": "沙箱空闲回收时间", + "type": "int", + "hint": "单位为秒。`0` 表示不启用空闲回收,此时才会启用沙箱存活时间。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", }, }, - "provider_settings.sandbox.cua_os_type": { - "description": "CUA OS Type", - "type": "string", - "options": ["linux", "macos", "windows", "android"], - "labels": ["Linux", "macOS", "Windows", "Android"], - "hint": "CUA 沙箱操作系统类型,默认 linux。", + "provider_settings.sandbox.sandbox_ttl": { + "description": "沙箱存活时间", + "type": "int", + "hint": "单位为秒。仅在空闲回收时间为 `0` 时生效;`0` 表示不自动销毁。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", }, }, - "provider_settings.sandbox.cua_idle_timeout": { - "description": "CUA Idle Timeout", + "provider_settings.sandbox.max_sandboxes": { + "description": "最大沙箱数量", "type": "int", - "hint": "Idle timeout for CUA sandbox sessions in seconds. When greater than 0, AstrBot proactively shuts down an idle CUA sandbox after that amount of inactivity; 0 disables it.", + "hint": "全局托管沙箱数量上限,默认 10。`0` 表示不限制。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", }, }, - "provider_settings.sandbox.cua_telemetry_enabled": { - "description": "CUA Telemetry", + "provider_settings.sandbox.member_permissions.create": { + "description": "允许普通用户创建沙箱", "type": "bool", - "hint": "是否允许 CUA SDK 发送遥测数据。默认关闭。", + "hint": "允许普通用户创建新的托管沙箱。普通用户的创建请求仍会受到最大沙箱数量限制。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", + "provider_settings.computer_use_require_admin": False, }, }, - "provider_settings.sandbox.cua_local": { - "description": "CUA Local Sandbox", + "provider_settings.sandbox.member_permissions.set_retention_policy": { + "description": "允许普通用户修改沙箱保留策略", "type": "bool", - "hint": "是否优先使用 CUA 本地沙箱。默认开启,避免云端沙箱要求 CUA_API_KEY。关闭后可使用 CUA 云端沙箱。", + "hint": "允许普通用户在临时沙箱和持久沙箱策略之间切换。持久沙箱会保留环境以便后续复用。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", + "provider_settings.computer_use_require_admin": False, }, }, - "provider_settings.sandbox.cua_api_key": { - "description": "CUA API Key", - "type": "string", - "hint": "CUA 云端沙箱 API Key。仅在关闭本地沙箱时需要。也可以通过 CUA_API_KEY 环境变量提供。", - "obvious_hint": True, - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", - "provider_settings.sandbox.cua_local": False, - }, - }, - "provider_settings.sandbox.shipyard_endpoint": { - "description": "Shipyard API Endpoint", - "type": "string", - "hint": "Shipyard 服务的 API 访问地址。", - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard", - }, - "_special": "check_shipyard_connection", - }, - "provider_settings.sandbox.shipyard_access_token": { - "description": "Shipyard Access Token", - "type": "string", - "hint": "用于访问 Shipyard 服务的访问令牌。", - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard", - }, - }, - "provider_settings.sandbox.shipyard_ttl": { - "description": "Shipyard Session TTL", - "type": "int", - "hint": "Shipyard 会话的生存时间(秒)。", + "provider_settings.sandbox.member_permissions.takeover": { + "description": "允许普通用户强占沙箱", + "type": "bool", + "hint": "允许普通用户强制接管被其他会话占用的沙箱。此操作会转移沙箱控制权,建议谨慎开启。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard", + "provider_settings.computer_use_require_admin": False, }, }, - "provider_settings.sandbox.shipyard_max_sessions": { - "description": "Shipyard Max Sessions", - "type": "int", - "hint": "Shipyard 最大会话数量。", + "provider_settings.sandbox.member_permissions.destroy": { + "description": "允许普通用户删除沙箱", + "type": "bool", + "hint": "允许普通用户删除自己可访问的托管沙箱。删除后沙箱环境和对应记录都会被移除。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard", + "provider_settings.computer_use_require_admin": False, }, }, }, diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index c325a2ea38..1173342970 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -19,6 +19,7 @@ from astrbot.api import logger, sp from astrbot.core import LogBroker, LogManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.computer import computer_client from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.cron import CronJobManager @@ -60,6 +61,7 @@ def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None: self.cron_manager: CronJobManager | None = None self.temp_dir_cleaner: TempDirCleaner | None = None self._default_chat_provider_warning_emitted = False + self._persistent_restore_task: asyncio.Task | None = None # 设置代理 proxy_config = self.astrbot_config.get("http_proxy", "") @@ -242,6 +244,17 @@ async def initialize(self) -> None: # 扫描、注册插件、实例化插件类 await self.plugin_manager.reload() + # Reconcile sandbox registry on startup to clear stale state and + # remove persistent records whose underlying resources no longer exist. + try: + await computer_client.sandbox_manager.reconcile_on_startup() + except Exception as e: + logger.warning( + "Sandbox startup reconciliation failed: %s", + e, + exc_info=True, + ) + # 根据配置实例化各个 Provider self._default_chat_provider_warning_emitted = False await self.provider_manager.initialize() @@ -276,6 +289,41 @@ async def initialize(self) -> None: asyncio.create_task(update_llm_metadata()) + async def _restore_persistent_sandboxes_background(self) -> None: + try: + # Do not let persistent sandbox recovery compete with the main + # startup path. Recovery is best-effort and should never delay the + # process becoming ready. + await asyncio.sleep(0) + ( + restored, + deleted, + ) = await computer_client.sandbox_manager.restore_persistent_sandboxes( + self.star_context, + per_sandbox_timeout=30.0, + ) + logger.info( + "Persistent sandbox restore finished: restored=%d deleted=%d", + restored, + deleted, + ) + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning( + "Persistent sandbox restore failed: %s", + e, + exc_info=True, + ) + + def _schedule_persistent_sandbox_restore(self) -> None: + if self._persistent_restore_task is not None: + return + self._persistent_restore_task = asyncio.create_task( + self._restore_persistent_sandboxes_background(), + name="persistent-sandbox-restore", + ) + def _load(self) -> None: """加载事件总线和任务并初始化.""" # 创建一个异步任务来执行事件总线的 dispatch() 方法 @@ -339,6 +387,7 @@ async def start(self) -> None: """ self._load() logger.info("AstrBot started.") + self._schedule_persistent_sandbox_restore() # 执行启动完成事件钩子 handlers = star_handlers_registry.get_handlers_by_event_type( @@ -368,6 +417,24 @@ async def stop(self) -> None: if self.cron_manager: await self.cron_manager.shutdown() + persistent_restore_task = getattr(self, "_persistent_restore_task", None) + if persistent_restore_task is not None: + persistent_restore_task.cancel() + try: + await persistent_restore_task + except asyncio.CancelledError: + pass + self._persistent_restore_task = None + + try: + await computer_client.cleanup_managed_sandboxes() + except Exception as e: + logger.warning( + "Managed sandbox cleanup during shutdown failed: %s", + e, + exc_info=True, + ) + for plugin in self.plugin_manager.context.get_all_stars(): try: await self.plugin_manager._terminate_plugin(plugin) @@ -399,6 +466,30 @@ async def stop(self) -> None: async def restart(self) -> None: """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" + for task in getattr(self, "curr_tasks", []): + task.cancel() + + if self.cron_manager: + await self.cron_manager.shutdown() + + persistent_restore_task = getattr(self, "_persistent_restore_task", None) + if persistent_restore_task is not None: + persistent_restore_task.cancel() + try: + await persistent_restore_task + except asyncio.CancelledError: + pass + self._persistent_restore_task = None + + try: + await computer_client.cleanup_managed_sandboxes() + except Exception as e: + logger.warning( + "Managed sandbox cleanup during restart failed: %s", + e, + exc_info=True, + ) + await self.provider_manager.terminate() await self.platform_manager.terminate() await self.kb_manager.terminate() diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 4b642d8ce5..b8706d6d6c 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,4 +1,5 @@ import asyncio +import os import re from collections.abc import AsyncGenerator @@ -31,6 +32,13 @@ def __init__( super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot + @staticmethod + def _parse_session_id_int(session_id: str | None) -> int | None: + try: + return int(str(session_id).strip()) + except (TypeError, ValueError): + return None + @staticmethod async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict: """修复部分字段""" @@ -86,6 +94,37 @@ async def _parse_onebot_json(message_chain: MessageChain): ret.append(d) return ret + @staticmethod + async def _upload_file_segment( + bot: CQHttp, + segment: File, + is_group: bool, + session_id: str | None, + ) -> None: + session_id_int = AiocqhttpMessageEvent._parse_session_id_int(session_id) + if not isinstance(session_id_int, int): + raise ValueError( + f"无法发送文件:缺少有效的数字 session_id({session_id})", + ) + + file_path = await segment.get_file(allow_return_url=True) + if not file_path: + raise ValueError("无法发送文件:文件路径为空") + name = segment.name or os.path.basename(str(file_path)) or "file" + + if is_group: + await bot.upload_group_file( + group_id=session_id_int, + file=file_path, + name=name, + ) + else: + await bot.upload_private_file( + user_id=session_id_int, + file=file_path, + name=name, + ) + @classmethod async def _dispatch_send( cls, @@ -95,10 +134,8 @@ async def _dispatch_send( session_id: str | None, messages: list[dict], ) -> None: - # session_id 必须是纯数字字符串 - session_id_int = ( - int(session_id) if session_id and session_id.isdigit() else None - ) + # session_id 必须是数字字符串 + session_id_int = cls._parse_session_id_int(session_id) if is_group and isinstance(session_id_int, int): await bot.send_group_msg(group_id=session_id_int, message=messages) @@ -156,8 +193,11 @@ async def send_message( payload["user_id"] = session_id await bot.call_action("send_private_forward_msg", **payload) elif isinstance(seg, File): - d = await cls._from_segment_to_dict(seg) - await cls._dispatch_send(bot, event, is_group, session_id, [d]) + try: + await cls._upload_file_segment(bot, seg, is_group, session_id) + except Exception: + messages = await cls._parse_onebot_json(MessageChain([seg])) + await cls._dispatch_send(bot, event, is_group, session_id, messages) else: messages = await cls._parse_onebot_json(MessageChain([seg])) if not messages: diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index f3b3cb77c2..8fc32e9f1f 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -423,6 +423,18 @@ def get_builtin_tool(self, tool: str | type[FuncTool]) -> FuncTool: self.builtin_func_list[tool_cls] = builtin_tool return builtin_tool + def clear_builtin_tool_cache_by_module_prefix( + self, module_prefix: str + ) -> list[str]: + removed: list[str] = [] + for tool_cls in tuple(self.builtin_func_list): + if not getattr(tool_cls, "__module__", "").startswith(module_prefix): + continue + tool = self.builtin_func_list.pop(tool_cls, None) + tool_name = get_builtin_tool_name(tool_cls) or getattr(tool, "name", None) + removed.append(tool_name or tool_cls.__name__) + return removed + def iter_builtin_tools(self) -> list[FuncTool]: ensure_builtin_tools_loaded() return [ diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ae4001fcd6..8780e48176 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -586,6 +586,14 @@ async def load_provider(self, provider_config: dict) -> None: return if provider_config.get("provider_type", "") == "agent_runner": return + if "type" not in provider_config: + logger.warning( + "Provider %s has no adapter type after merging provider source %s, skipping. " + "This is likely a stale provider model config; remove it in the dashboard.", + provider_config.get("id"), + provider_config.get("provider_source_id"), + ) + return logger.info( "Loading model %s(%s) ...", diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index 838301c044..c13c079582 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -6,6 +6,7 @@ import shlex import shutil import tempfile +import threading import uuid import zipfile from dataclasses import dataclass @@ -29,6 +30,11 @@ _SANDBOX_SKILLS_CACHE_VERSION = 1 _SKILL_NAME_RE = re.compile(r"^[\w.-]+$") +_SANDBOX_SKILLS_CACHE_LOCK = threading.RLock() + + +def _normalize_cache_provider_id(provider_id: str | None) -> str: + return str(provider_id or "").strip().lower() def _normalize_skill_name(name: str | None) -> str: @@ -353,22 +359,38 @@ def _save_config(self, config: dict) -> None: def _load_sandbox_skills_cache(self) -> dict: if not os.path.exists(self.sandbox_skills_cache_path): - return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []} + return { + "version": _SANDBOX_SKILLS_CACHE_VERSION, + "skills": [], + "providers": {}, + } try: with open(self.sandbox_skills_cache_path, encoding="utf-8") as f: data = json.load(f) if not isinstance(data, dict): - return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []} + return { + "version": _SANDBOX_SKILLS_CACHE_VERSION, + "skills": [], + "providers": {}, + } skills = data.get("skills", []) if not isinstance(skills, list): skills = [] + providers = data.get("providers", {}) + if not isinstance(providers, dict): + providers = {} return { "version": int(data.get("version", _SANDBOX_SKILLS_CACHE_VERSION)), "skills": skills, + "providers": providers, "updated_at": data.get("updated_at"), } except Exception: - return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []} + return { + "version": _SANDBOX_SKILLS_CACHE_VERSION, + "skills": [], + "providers": {}, + } def _save_sandbox_skills_cache(self, cache: dict) -> None: cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION @@ -376,7 +398,11 @@ def _save_sandbox_skills_cache(self, cache: dict) -> None: with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f: json.dump(cache, f, ensure_ascii=False, indent=2) - def set_sandbox_skills_cache(self, skills: list[dict]) -> None: + def set_sandbox_skills_cache( + self, + skills: list[dict], + provider_id: str | None = None, + ) -> None: """Persist sandbox skill metadata discovered from runtime side.""" deduped: dict[str, dict[str, str]] = {} for item in skills: @@ -394,16 +420,72 @@ def set_sandbox_skills_cache(self, skills: list[dict]) -> None: "description": description, "path": path, } - cache = { - "version": _SANDBOX_SKILLS_CACHE_VERSION, - "skills": [deduped[name] for name in sorted(deduped)], - } - self._save_sandbox_skills_cache(cache) + provider_key = _normalize_cache_provider_id(provider_id) + skills_payload = [deduped[name] for name in sorted(deduped)] + with _SANDBOX_SKILLS_CACHE_LOCK: + cache = self._load_sandbox_skills_cache() + providers = cache.get("providers", {}) + if not isinstance(providers, dict): + providers = {} + if provider_key: + providers[provider_key] = { + "skills": skills_payload, + } + else: + cache["skills"] = skills_payload + providers["default"] = { + "skills": skills_payload, + } + cache = { + "version": _SANDBOX_SKILLS_CACHE_VERSION, + "skills": cache.get("skills", []), + "providers": providers, + } + self._save_sandbox_skills_cache(cache) + + def _sandbox_cache_skills_for_provider( + self, cache: dict, provider_id: str | None + ) -> list[dict]: + provider_key = _normalize_cache_provider_id(provider_id) + providers = cache.get("providers", {}) + if provider_key and isinstance(providers, dict): + provider_cache = providers.get(provider_key) + if isinstance(provider_cache, dict): + skills = provider_cache.get("skills", []) + return skills if isinstance(skills, list) else [] + return [] + + skills = cache.get("skills", []) + if isinstance(skills, list): + return skills + return [] def get_sandbox_skills_cache_status(self) -> dict[str, object]: cache = self._load_sandbox_skills_cache() - skills = cache.get("skills", []) - count = len(skills) if isinstance(skills, list) else 0 + count = 0 + seen: set[str] = set() + for item in cache.get("skills", []): + if not isinstance(item, dict): + continue + name = str(item.get("name", "")).strip() + if name and name not in seen: + seen.add(name) + count += 1 + providers = cache.get("providers", {}) + if isinstance(providers, dict): + for provider_cache in providers.values(): + if not isinstance(provider_cache, dict): + continue + skills = provider_cache.get("skills", []) + if not isinstance(skills, list): + continue + for item in skills: + if not isinstance(item, dict): + continue + name = str(item.get("name", "")).strip() + if name and name not in seen: + seen.add(name) + count += 1 return { "exists": os.path.exists(self.sandbox_skills_cache_path), "ready": count > 0, @@ -416,6 +498,7 @@ def list_skills( *, active_only: bool = False, runtime: str = "local", + provider_id: str | None = None, show_sandbox_path: bool = True, ) -> list[SkillInfo]: """List all skills. @@ -432,7 +515,9 @@ def list_skills( sandbox_cached_paths: dict[str, str] = {} sandbox_cached_descriptions: dict[str, str] = {} cache_for_paths = self._load_sandbox_skills_cache() - for item in cache_for_paths.get("skills", []): + for item in self._sandbox_cache_skills_for_provider( + cache_for_paths, provider_id + ): if not isinstance(item, dict): continue name = str(item.get("name", "") or "").strip() @@ -527,8 +612,10 @@ def list_skills( ) if runtime == "sandbox": - cache = self._load_sandbox_skills_cache() - for item in cache.get("skills", []): + cache = self._sandbox_cache_skills_for_provider( + self._load_sandbox_skills_cache(), provider_id + ) + for item in cache: if not isinstance(item, dict): continue skill_name = str(item.get("name", "")).strip() @@ -574,14 +661,24 @@ def is_sandbox_only_skill(self, name: str) -> bool: if skill_md_exists: return False cache = self._load_sandbox_skills_cache() - skills = cache.get("skills", []) - if not isinstance(skills, list): - return False - for item in skills: + for item in cache.get("skills", []): if not isinstance(item, dict): continue if str(item.get("name", "")).strip() == name: return True + providers = cache.get("providers", {}) + if isinstance(providers, dict): + for provider_cache in providers.values(): + if not isinstance(provider_cache, dict): + continue + skills = provider_cache.get("skills", []) + if not isinstance(skills, list): + continue + for item in skills: + if not isinstance(item, dict): + continue + if str(item.get("name", "")).strip() == name: + return True return False def is_plugin_skill(self, name: str) -> bool: @@ -598,22 +695,47 @@ def set_skill_active(self, name: str, active: bool) -> None: self._save_config(config) def _remove_skill_from_sandbox_cache(self, name: str) -> None: - cache = self._load_sandbox_skills_cache() - skills = cache.get("skills", []) - if not isinstance(skills, list): - return - - filtered = [ - item - for item in skills - if not ( - isinstance(item, dict) and str(item.get("name", "")).strip() == name - ) - ] - - if len(filtered) != len(skills): - cache["skills"] = filtered - self._save_sandbox_skills_cache(cache) + with _SANDBOX_SKILLS_CACHE_LOCK: + cache = self._load_sandbox_skills_cache() + changed = False + skills = cache.get("skills", []) + if isinstance(skills, list): + filtered = [ + item + for item in skills + if not ( + isinstance(item, dict) + and str(item.get("name", "")).strip() == name + ) + ] + if len(filtered) != len(skills): + cache["skills"] = filtered + changed = True + + providers = cache.get("providers", {}) + if isinstance(providers, dict): + for provider_key, provider_cache in list(providers.items()): + if not isinstance(provider_cache, dict): + continue + provider_skills = provider_cache.get("skills", []) + if not isinstance(provider_skills, list): + continue + filtered = [ + item + for item in provider_skills + if not ( + isinstance(item, dict) + and str(item.get("name", "")).strip() == name + ) + ] + if len(filtered) != len(provider_skills): + provider_cache["skills"] = filtered + providers[provider_key] = provider_cache + changed = True + + if changed: + cache["providers"] = providers + self._save_sandbox_skills_cache(cache) def delete_skill(self, name: str) -> None: if self.is_sandbox_only_skill(name): diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 824c3b653b..738f6b557e 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -31,6 +31,7 @@ from astrbot.core.config.default import VERSION from astrbot.core.platform.register import unregister_platform_adapters_by_module from astrbot.core.provider.register import llm_tools +from astrbot.core.tools.registry import unregister_builtin_tools_by_module_prefix from astrbot.core.utils.astrbot_path import ( get_astrbot_config_path, get_astrbot_path, @@ -745,6 +746,14 @@ def _cleanup_plugin_state(self, dir_name: str, is_reserved: bool = False) -> Non llm_tools.func_list.remove(tool) logger.info(f"清理工具: {tool.name}") + for tool_name in llm_tools.clear_builtin_tool_cache_by_module_prefix( + module_prefix + ): + logger.info(f"清理内置工具缓存: {tool_name}") + + for tool_name in unregister_builtin_tools_by_module_prefix(module_prefix): + logger.info(f"清理内置工具注册: {tool_name}") + for adapter_name in unregister_platform_adapters_by_module(module_prefix): logger.info(f"清理平台适配器: {adapter_name}") @@ -1667,6 +1676,18 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str) -> Non # module_path is like "data.plugins.my_plugin.main", extract prefix like "data.plugins.my_plugin" module_prefix = ".".join(plugin_module_path.split(".")[:-1]) if module_prefix: + for tool_name in llm_tools.clear_builtin_tool_cache_by_module_prefix( + module_prefix + ): + logger.info( + f"移除了插件 {plugin_name} 的内置工具缓存 {tool_name}", + ) + + for tool_name in unregister_builtin_tools_by_module_prefix(module_prefix): + logger.info( + f"移除了插件 {plugin_name} 的内置工具注册 {tool_name}", + ) + unregistered_adapters = unregister_platform_adapters_by_module( module_prefix ) diff --git a/astrbot/core/tools/computer_tools/__init__.py b/astrbot/core/tools/computer_tools/__init__.py index f90c2e1de8..16610e9e74 100644 --- a/astrbot/core/tools/computer_tools/__init__.py +++ b/astrbot/core/tools/computer_tools/__init__.py @@ -1,8 +1,3 @@ -from .cua import ( - CuaKeyboardTypeTool, - CuaMouseClickTool, - CuaScreenshotTool, -) from .fs import ( FileDownloadTool, FileEditTool, @@ -12,52 +7,23 @@ GrepTool, ) from .python import LocalPythonTool, PythonTool +from .sandbox import SandboxLifecycleTool, SandboxOperationTool, SandboxQueryTool from .shell import ExecuteShellTool -from .shipyard_neo import ( - AnnotateExecutionTool, - BrowserBatchExecTool, - BrowserExecTool, - CreateSkillCandidateTool, - CreateSkillPayloadTool, - EvaluateSkillCandidateTool, - GetExecutionHistoryTool, - GetSkillPayloadTool, - ListSkillCandidatesTool, - ListSkillReleasesTool, - PromoteSkillCandidateTool, - RollbackSkillReleaseTool, - RunBrowserSkillTool, - SyncSkillReleaseTool, -) from .util import check_admin_permission, normalize_umo_for_workspace __all__ = [ - "AnnotateExecutionTool", - "BrowserBatchExecTool", - "BrowserExecTool", - "CreateSkillCandidateTool", - "CreateSkillPayloadTool", - "CuaKeyboardTypeTool", - "CuaMouseClickTool", - "CuaScreenshotTool", - "EvaluateSkillCandidateTool", "ExecuteShellTool", "FileDownloadTool", "FileEditTool", "FileReadTool", "FileUploadTool", "FileWriteTool", - "GetExecutionHistoryTool", - "GetSkillPayloadTool", "GrepTool", - "ListSkillCandidatesTool", - "ListSkillReleasesTool", "LocalPythonTool", - "PromoteSkillCandidateTool", "PythonTool", - "RollbackSkillReleaseTool", - "RunBrowserSkillTool", - "SyncSkillReleaseTool", + "SandboxQueryTool", + "SandboxLifecycleTool", + "SandboxOperationTool", "normalize_umo_for_workspace", "check_admin_permission", ] diff --git a/astrbot/core/tools/computer_tools/cua.py b/astrbot/core/tools/computer_tools/cua.py deleted file mode 100644 index 7b37a55086..0000000000 --- a/astrbot/core/tools/computer_tools/cua.py +++ /dev/null @@ -1,177 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -import mcp - -from astrbot.api import FunctionTool -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import ToolExecResult -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.computer.computer_client import get_booter -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.tools.computer_tools.util import check_admin_permission -from astrbot.core.tools.registry import builtin_tool -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -_CUA_TOOL_CONFIG = { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", -} - - -def _to_json(data: Any) -> str: - return json.dumps(data, ensure_ascii=False, default=str) - - -def _exception_detail(error: Exception) -> str: - return str(error) or type(error).__name__ - - -async def _get_gui_component(context: ContextWrapper[AstrAgentContext]) -> Any: - booter = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - gui = getattr(booter, "gui", None) - if gui is None: - raise RuntimeError( - "Current sandbox booter does not support CUA GUI capability. " - "Please switch sandbox booter to cua." - ) - return gui - - -@builtin_tool(config=_CUA_TOOL_CONFIG) -@dataclass -class CuaScreenshotTool(FunctionTool): - name: str = "astrbot_cua_screenshot" - description: str = ( - "Capture a screenshot from the CUA sandbox and optionally send it to the user." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "send_to_user": { - "type": "boolean", - "description": "Whether to send the screenshot image to the current conversation.", - "default": True, - }, - "return_image_to_llm": { - "type": "boolean", - "description": "Whether to include the screenshot image content in the tool result for model inspection.", - "default": True, - }, - }, - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - send_to_user: bool = True, - return_image_to_llm: bool = True, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Taking CUA screenshots"): - return err - try: - gui = await _get_gui_component(context) - path = _new_screenshot_path(context.context.event.unified_msg_origin) - result = await gui.screenshot(path) - payload = {"success": True, **result, "path": path} - if send_to_user: - await context.context.event.send(MessageChain().file_image(path)) - payload["sent_to_user"] = True - image_data = payload.pop("base64", "") - content: list[mcp.types.TextContent | mcp.types.ImageContent] = [ - mcp.types.TextContent(type="text", text=_to_json(payload)) - ] - if return_image_to_llm: - content.append( - mcp.types.ImageContent( - type="image", - data=str(image_data), - mimeType=str(payload.get("mime_type", "image/png")), - ) - ) - return mcp.types.CallToolResult(content=content) - except Exception as e: - return f"Error taking CUA screenshot: {_exception_detail(e)}" - - -@builtin_tool(config=_CUA_TOOL_CONFIG) -@dataclass -class CuaMouseClickTool(FunctionTool): - name: str = "astrbot_cua_mouse_click" - description: str = "Click a coordinate in the CUA sandbox desktop." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "x": {"type": "integer", "description": "X coordinate."}, - "y": {"type": "integer", "description": "Y coordinate."}, - "button": { - "type": "string", - "description": "Mouse button, usually left, right, or middle.", - "default": "left", - }, - }, - "required": ["x", "y"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - x: int, - y: int, - button: str = "left", - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using CUA mouse"): - return err - try: - gui = await _get_gui_component(context) - return _to_json(await gui.click(x, y, button=button)) - except Exception as e: - return f"Error clicking CUA desktop: {_exception_detail(e)}" - - -@builtin_tool(config=_CUA_TOOL_CONFIG) -@dataclass -class CuaKeyboardTypeTool(FunctionTool): - name: str = "astrbot_cua_keyboard_type" - description: str = "Type text into the CUA sandbox desktop." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to type."}, - }, - "required": ["text"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - text: str, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using CUA keyboard"): - return err - try: - gui = await _get_gui_component(context) - return _to_json(await gui.type_text(text)) - except Exception as e: - return f"Error typing in CUA desktop: {_exception_detail(e)}" - - -def _new_screenshot_path(umo: str) -> str: - safe_prefix = uuid.uuid5(uuid.NAMESPACE_DNS, umo).hex[:12] - screenshot_dir = Path(get_astrbot_temp_path()) / "cua_screenshots" - screenshot_dir.mkdir(parents=True, exist_ok=True) - return str(screenshot_dir / f"{safe_prefix}-{uuid.uuid4().hex}.png") diff --git a/astrbot/core/tools/computer_tools/fs.py b/astrbot/core/tools/computer_tools/fs.py index 5660022fd0..a44dc1f432 100644 --- a/astrbot/core/tools/computer_tools/fs.py +++ b/astrbot/core/tools/computer_tools/fs.py @@ -70,6 +70,10 @@ _IMAGE_FILE_SUFFIXES = {".bmp", ".gif", ".jpeg", ".jpg", ".png", ".webp"} +def _remote_basename(path: str) -> str: + return path.replace("\\", "/").rstrip("/").split("/")[-1] + + def _restricted_env_path_labels(umo: str, *, include_plugin_skills: bool) -> list[str]: """Labels for the allowed directories in a local(not sandbox) and restricted(not admin) environment""" normalized_umo = normalize_umo_for_workspace(umo) @@ -772,7 +776,7 @@ async def call( context.context.event.unified_msg_origin, ) try: - name = os.path.basename(remote_path) + name = _remote_basename(remote_path) or os.path.basename(remote_path) local_path = os.path.join( get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}" @@ -784,7 +788,9 @@ async def call( if also_send_to_user: try: - name = os.path.basename(local_path) + # Keep the user-facing filename stable; the local temp path + # still carries a random prefix to avoid collisions. + name = _remote_basename(remote_path) or os.path.basename(local_path) if Path(local_path).suffix.lower() in _IMAGE_FILE_SUFFIXES: message_component = Image.fromFileSystem(local_path) sent_as = "image" diff --git a/astrbot/core/tools/computer_tools/sandbox.py b/astrbot/core/tools/computer_tools/sandbox.py new file mode 100644 index 0000000000..29421dcabd --- /dev/null +++ b/astrbot/core/tools/computer_tools/sandbox.py @@ -0,0 +1,779 @@ +import json +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path + +import mcp + +from astrbot.api import FunctionTool +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.computer import computer_client +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..registry import builtin_tool +from .util import check_admin_permission + +_SANDBOX_RUNTIME_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": "sandbox", +} + + +def _dump(data) -> str: + return json.dumps(data, ensure_ascii=False, default=str) + + +def _remote_basename(path: str) -> str: + return path.replace("\\", "/").rstrip("/").split("/")[-1] + + +def _format_agent_time(value: int | float | None) -> str | None: + if value is None: + return None + if isinstance(value, bool): + return str(value) + if not isinstance(value, (int, float)): + return str(value) + return ( + datetime.fromtimestamp(float(value)) + .astimezone() + .strftime("%Y-%m-%d %H:%M:%S %Z") + ) + + +def _format_sandbox_for_agent(value): + if isinstance(value, list): + return [_format_sandbox_for_agent(item) for item in value] + if not isinstance(value, dict): + return value + formatted = {} + for key, item in value.items(): + if key.endswith("_at"): + formatted[key] = _format_agent_time(item) + else: + formatted[key] = _format_sandbox_for_agent(item) + return formatted + + +def _lease_metadata_for_agent( + context: ContextWrapper[AstrAgentContext], + sandbox_id: str, + record: dict | None = None, +) -> dict | None: + if not sandbox_id: + return None + manager = _sandbox_manager() + if record is None: + registry = getattr(manager, "registry", None) + get_sandbox = getattr(registry, "get_sandbox", None) + record = get_sandbox(sandbox_id) if callable(get_sandbox) else None + if not record: + return None + lease_expires_at = record.get("lease_expires_at") + lease_expires_in_seconds = None + if isinstance(lease_expires_at, (int, float)) and not isinstance( + lease_expires_at, bool + ): + lease_expires_in_seconds = max(0, int(float(lease_expires_at) - time.time())) + return { + "sandbox_id": sandbox_id, + "lease_expires_at": _format_agent_time(lease_expires_at), + "lease_expires_in_seconds": lease_expires_in_seconds, + "auto_renew_interval_seconds": _lease_timeout_for_agent(context), + } + + +def _lease_timeout_for_agent(context: ContextWrapper[AstrAgentContext]) -> float: + manager = _sandbox_manager() + lease_timeout = getattr(manager, "_lease_timeout", None) + if callable(lease_timeout): + return lease_timeout( + context.context.context, + context.context.event.unified_msg_origin, + ) + return float(_sandbox_config(context).get("sandbox_lease_timeout", 600)) + + +def _attach_lease_metadata( + payload: dict, + context: ContextWrapper[AstrAgentContext], + sandbox_id: str | None, + record: dict | None = None, +) -> dict: + if not sandbox_id: + return payload + lease = _lease_metadata_for_agent(context, sandbox_id, record) + if lease is not None: + payload["lease"] = lease + return payload + + +def _sandbox_manager(): + return computer_client.sandbox_manager + + +def _current_provider_id(context: ContextWrapper[AstrAgentContext]) -> str: + plugin_context = context.context.context + session_id = context.context.event.unified_msg_origin + config = plugin_context.get_config(umo=session_id) + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + return str(sandbox_cfg.get("booter", "")).strip() + + +def _is_admin(context: ContextWrapper[AstrAgentContext]) -> bool: + return context.context.event.role == "admin" + + +def _sandbox_config(context: ContextWrapper[AstrAgentContext]) -> dict: + config = context.context.context.get_config( + umo=context.context.event.unified_msg_origin + ) + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + return sandbox_cfg if isinstance(sandbox_cfg, dict) else {} + + +def _member_sandbox_permission_enabled( + context: ContextWrapper[AstrAgentContext], permission: str +) -> bool: + permissions = _sandbox_config(context).get("member_permissions", {}) + if not isinstance(permissions, dict): + return False + return bool(permissions.get(permission, False)) + + +def _check_basic_sandbox_permission( + context: ContextWrapper[AstrAgentContext], operation_name: str +) -> str | None: + return check_admin_permission(context, operation_name) + + +def _check_member_sandbox_permission( + context: ContextWrapper[AstrAgentContext], operation_name: str, permission: str +) -> str | None: + if permission_error := check_admin_permission(context, operation_name): + return permission_error + if _is_admin(context) or _member_sandbox_permission_enabled(context, permission): + return None + return ( + f"error: Permission denied. {operation_name} is disabled for non-admin users " + "by sandbox member permission settings." + ) + + +def _visible_to_session(record: dict, session_id: str) -> bool: + return record.get("controller_session_id") == session_id or _is_idle_sandbox(record) + + +def _is_idle_sandbox(record: dict) -> bool: + controller_session_id = record.get("controller_session_id") + if not controller_session_id: + return True + lease_expires_at = record.get("lease_expires_at") + return bool(lease_expires_at and lease_expires_at <= time.time()) + + +def _sandbox_status_for_session(record: dict, session_id: str) -> str: + controller_session_id = record.get("controller_session_id") + if controller_session_id == session_id: + return "current" + if controller_session_id and not _is_idle_sandbox(record): + return "occupied" + return "idle" + + +def _redact_sandbox_for_session(record: dict, session_id: str, *, admin: bool) -> dict: + visible = dict(record) + visible["access"] = { + "status": _sandbox_status_for_session(record, session_id), + "can_switch": _visible_to_session(record, session_id), + "occupied": not _is_idle_sandbox(record), + } + if admin: + return visible + visible.pop("connect_info", None) + visible["owner_session_id"] = None + visible["owner_user_id"] = None + visible["created_by_session_id"] = None + visible["created_by_user_id"] = None + if record.get("controller_session_id") != session_id: + visible["controller_session_id"] = None + visible["controller_user_id"] = None + return visible + + +def _sandbox_access_denied( + context: ContextWrapper[AstrAgentContext], record: dict | None +) -> str | None: + if record is None or _is_admin(context): + return None + session_id = context.context.event.unified_msg_origin + if _visible_to_session(record, session_id): + return None + return "error: Permission denied. This sandbox belongs to another session." + + +async def _query_list_sandboxes( + context: ContextWrapper[AstrAgentContext], +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Listing sandboxes" + ): + return permission_error + session_id = context.context.event.unified_msg_origin + manager = _sandbox_manager() + list_checked = getattr(manager, "list_sandboxes_checked", None) + if callable(list_checked): + sandboxes = await list_checked() + else: + sandboxes = manager.list_sandboxes() + sandboxes = [ + _redact_sandbox_for_session(record, session_id, admin=_is_admin(context)) + for record in sandboxes + ] + return _dump({"sandboxes": _format_sandbox_for_agent(sandboxes)}) + + +async def _query_list_providers( + context: ContextWrapper[AstrAgentContext], +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Listing sandbox providers" + ): + return permission_error + return _dump({"providers": computer_client.list_sandbox_providers()}) + + +async def _query_get_current( + context: ContextWrapper[AstrAgentContext], +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Getting current sandbox" + ): + return permission_error + session_id = context.context.event.unified_msg_origin + current = _sandbox_manager().get_current_sandbox(session_id) + payload = _format_sandbox_for_agent(current) + return _dump( + _attach_lease_metadata( + payload, + context, + current.get("current_sandbox_id"), + current.get("sandbox"), + ) + ) + + +async def _lifecycle_create( + context: ContextWrapper[AstrAgentContext], + sandbox_name: str = "", + provider_id: str = "", +) -> ToolExecResult: + if permission_error := _check_member_sandbox_permission( + context, "Creating sandbox", "create" + ): + return permission_error + + plugin_context = context.context.context + session_id = context.context.event.unified_msg_origin + requested_provider_id = str(provider_id).strip().lower() + if requested_provider_id: + provider_id = requested_provider_id + else: + provider_id = _current_provider_id(context) + if not provider_id: + return "Error creating sandbox: sandbox booter is not configured." + manager = _sandbox_manager() + if provider_id not in manager.providers: + providers = computer_client.list_sandbox_providers() + available = ", ".join(p["provider_id"] for p in providers) or "none" + return ( + f"Error creating sandbox: sandbox provider '{provider_id}' is not " + f"available. Available providers: {available}." + ) + + try: + sandbox = await manager.create_sandbox( + plugin_context, + session_id, + provider_id, + sandbox_name=sandbox_name.strip() or None, + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error creating sandbox: {detail}" + + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_switch( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Switching sandbox" + ): + return permission_error + if not sandbox_id: + return "Error switching sandbox: sandbox_id is required." + session_id = context.context.event.unified_msg_origin + manager = _sandbox_manager() + record = manager.registry.get_sandbox(sandbox_id) + if permission_error := _sandbox_access_denied(context, record): + return permission_error + try: + sandbox = await manager.switch_current_sandbox_checked( + session_id, sandbox_id, context=context.context.context + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error switching sandbox: {detail}" + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_release( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Releasing sandbox" + ): + return permission_error + session_id = context.context.event.unified_msg_origin + try: + sandbox = _sandbox_manager().release_current_sandbox( + session_id, sandbox_id.strip() or None + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error releasing sandbox: {detail}" + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_set_retention( + context: ContextWrapper[AstrAgentContext], + retention_policy: str = "", + sandbox_id: str = "", + sandbox_name: str = "", +) -> ToolExecResult: + if permission_error := _check_member_sandbox_permission( + context, "Changing sandbox retention policy", "set_retention_policy" + ): + return permission_error + if not retention_policy: + return "Error changing sandbox retention policy: retention_policy is required." + manager = _sandbox_manager() + session_id = context.context.event.unified_msg_origin + target_sandbox_id = sandbox_id.strip() + if not target_sandbox_id: + current = manager.get_current_sandbox(session_id) + target_sandbox_id = current.get("current_sandbox_id") or "" + if not target_sandbox_id: + return "Error changing sandbox retention policy: No current sandbox" + record = manager.registry.get_sandbox(target_sandbox_id) + if permission_error := _sandbox_access_denied(context, record): + return permission_error + try: + sandbox = manager.set_sandbox_retention_policy( + context.context.context, + session_id, + target_sandbox_id, + retention_policy.strip().lower(), + sandbox_name=sandbox_name.strip() or None, + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error changing sandbox retention policy: {detail}" + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_renew_lease( + context: ContextWrapper[AstrAgentContext], + ttl_seconds: int | float | None = None, +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Renewing sandbox lease" + ): + return permission_error + session_id = context.context.event.unified_msg_origin + try: + sandbox = await _sandbox_manager().renew_current_sandbox_lease( + session_id, ttl_seconds=ttl_seconds, context=context.context.context + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error renewing sandbox lease: {detail}" + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_takeover( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> ToolExecResult: + if permission_error := _check_member_sandbox_permission( + context, "Taking over sandbox", "takeover" + ): + return permission_error + if not sandbox_id: + return "Error taking over sandbox: sandbox_id is required." + session_id = context.context.event.unified_msg_origin + try: + sandbox = await _sandbox_manager().takeover_sandbox( + session_id, sandbox_id, context=context.context.context + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error taking over sandbox: {detail}" + return _dump({"sandbox": _format_sandbox_for_agent(sandbox)}) + + +async def _lifecycle_destroy( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> ToolExecResult: + if permission_error := _check_member_sandbox_permission( + context, "Destroying sandbox", "destroy" + ): + return permission_error + if not sandbox_id: + return "Error destroying sandbox: sandbox_id is required." + session_id = context.context.event.unified_msg_origin + try: + sandbox = await _sandbox_manager().destroy_sandbox(session_id, sandbox_id) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error destroying sandbox: {detail}" + return _dump({"sandbox": _format_sandbox_for_agent(sandbox)}) + + +def _current_sandbox_id_for_operation( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> str: + target_sandbox_id = sandbox_id.strip() + if target_sandbox_id: + return target_sandbox_id + current = _sandbox_manager().get_current_sandbox( + context.context.event.unified_msg_origin + ) + return str(current.get("current_sandbox_id") or "").strip() + + +async def _operation_capture_screenshot( + context: ContextWrapper[AstrAgentContext], + sandbox_id: str = "", + send_to_user: bool = False, + return_image_to_llm: bool = False, +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Sandbox screenshot capture" + ): + return permission_error + target_sandbox_id = _current_sandbox_id_for_operation(context, sandbox_id) + if not target_sandbox_id: + return "Error taking sandbox screenshot: No current sandbox" + try: + booter = await _sandbox_manager().get_observer_booter_by_id( + target_sandbox_id, + context.context.event.unified_msg_origin, + context=context.context.context, + ) + gui = getattr(booter, "gui", None) + if gui is None: + return f"Error taking sandbox screenshot: sandbox {target_sandbox_id} does not support screenshots." + screenshot_dir = Path(get_astrbot_temp_path()) / "sandbox_screenshots" + screenshot_dir.mkdir(parents=True, exist_ok=True) + path = str(screenshot_dir / f"{uuid.uuid4().hex}.png") + result = await gui.screenshot(path) + payload = {"sandbox_id": target_sandbox_id, "path": path, **result} + if send_to_user: + await context.context.event.send(MessageChain().file_image(path)) + payload["sent_to_user"] = True + image_data = payload.pop("base64", "") + payload = _attach_lease_metadata(payload, context, target_sandbox_id) + if return_image_to_llm: + content: list[mcp.types.TextContent | mcp.types.ImageContent] = [ + mcp.types.TextContent(type="text", text=_dump(payload)) + ] + if image_data: + content.append( + mcp.types.ImageContent( + type="image", + data=str(image_data), + mimeType=str(payload.get("mime_type", "image/png")), + ) + ) + return mcp.types.CallToolResult(content=content) + if image_data: + payload["base64"] = image_data + return _dump(payload) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error taking sandbox screenshot: {detail}" + + +async def _operation_copy_file( + context: ContextWrapper[AstrAgentContext], + source_sandbox_id: str = "", + source_path: str = "", + target_sandbox_id: str = "", + target_path: str = "", +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Copying files between sandboxes" + ): + return permission_error + if not all([source_sandbox_id, source_path, target_sandbox_id, target_path]): + return "Error copying file between sandboxes: source_sandbox_id, source_path, target_sandbox_id, and target_path are required." + try: + manager = _sandbox_manager() + session_id = context.context.event.unified_msg_origin + source = await manager.get_observer_booter_by_id( + source_sandbox_id, session_id, context=context.context.context + ) + target = await manager.get_observer_booter_by_id( + target_sandbox_id, session_id, context=context.context.context + ) + temp_dir = Path(get_astrbot_temp_path()) / "sandbox_copy" + temp_dir.mkdir(parents=True, exist_ok=True) + local_path = temp_dir / f"{uuid.uuid4().hex}-{_remote_basename(target_path)}" + try: + await source.download_file(source_path, str(local_path)) + upload_result = await target.upload_file(str(local_path), target_path) + finally: + try: + local_path.unlink(missing_ok=True) + except OSError: + pass + return _dump( + _attach_lease_metadata( + { + "source_sandbox_id": source_sandbox_id, + "source_path": source_path, + "target_sandbox_id": target_sandbox_id, + "target_path": target_path, + "upload_result": upload_result, + }, + context, + target_sandbox_id, + ) + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error copying file between sandboxes: {detail}" + + +@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG) +@dataclass +class SandboxQueryTool(FunctionTool): + name: str = "astrbot_sandbox_query" + description: str = ( + "Query managed sandboxes, the current sandbox, or loaded sandbox providers. " + "Actions: list_sandboxes has no extra parameters; get_current has no extra parameters; " + "list_providers has no extra parameters. Use list_sandboxes before creating a new sandbox " + "when you need to find a reusable one." + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["list_sandboxes", "get_current", "list_providers"], + "description": "Query action to perform.", + } + }, + "required": ["action"], + } + ) + + async def call( + self, context: ContextWrapper[AstrAgentContext], action: str + ) -> ToolExecResult: + match action: + case "list_sandboxes": + return await _query_list_sandboxes(context) + case "get_current": + return await _query_get_current(context) + case "list_providers": + return await _query_list_providers(context) + return f"Error querying sandbox: unsupported action '{action}'." + + +@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG) +@dataclass +class SandboxLifecycleTool(FunctionTool): + name: str = "astrbot_sandbox_lifecycle" + description: str = ( + "Manage sandbox lifecycle and session occupancy: create, switch, release, " + "renew lease, set retention, takeover, or destroy a sandbox. " + "Actions: create accepts sandbox_name and provider_id; switch requires sandbox_id; " + "release accepts optional sandbox_id; renew_lease accepts ttl_seconds; " + "set_retention requires retention_policy and accepts sandbox_id/sandbox_name; " + "takeover requires sandbox_id; destroy requires sandbox_id." + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": [ + "create", + "switch", + "release", + "renew_lease", + "set_retention", + "takeover", + "destroy", + ], + "description": "Lifecycle action to perform.", + }, + "sandbox_id": { + "type": "string", + "description": "Target sandbox ID for switch, release, retention, takeover, or destroy.", + }, + "sandbox_name": { + "type": "string", + "description": "Optional sandbox name for create or set_retention.", + }, + "provider_id": { + "type": "string", + "description": "Optional provider ID for create. Defaults to configured sandbox booter.", + }, + "retention_policy": { + "type": "string", + "enum": ["persistent", "temporary"], + "description": "Target retention policy for set_retention.", + }, + "ttl_seconds": { + "type": "number", + "description": "Optional lease duration for renew_lease.", + }, + }, + "required": ["action"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + action: str, + sandbox_id: str = "", + sandbox_name: str = "", + provider_id: str = "", + retention_policy: str = "", + ttl_seconds: int | float | None = None, + ) -> ToolExecResult: + match action: + case "create": + return await _lifecycle_create(context, sandbox_name, provider_id) + case "switch": + return await _lifecycle_switch(context, sandbox_id) + case "release": + return await _lifecycle_release(context, sandbox_id) + case "renew_lease": + return await _lifecycle_renew_lease(context, ttl_seconds) + case "set_retention": + return await _lifecycle_set_retention( + context, retention_policy, sandbox_id, sandbox_name + ) + case "takeover": + return await _lifecycle_takeover(context, sandbox_id) + case "destroy": + return await _lifecycle_destroy(context, sandbox_id) + return f"Error managing sandbox lifecycle: unsupported action '{action}'." + + +@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG) +@dataclass +class SandboxOperationTool(FunctionTool): + name: str = "astrbot_sandbox_operation" + description: str = ( + "Run standard sandbox operations. Actions: capture_screenshot accepts sandbox_id, " + "send_to_user, and return_image_to_llm; copy_file requires source_sandbox_id, " + "source_path, target_sandbox_id, and target_path." + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["capture_screenshot", "copy_file"], + "description": "Sandbox operation to perform.", + }, + "sandbox_id": { + "type": "string", + "description": "Target sandbox ID for capture_screenshot. Defaults to current sandbox.", + }, + "send_to_user": { + "type": "boolean", + "description": "Whether capture_screenshot should send the image to the current conversation.", + "default": False, + }, + "return_image_to_llm": { + "type": "boolean", + "description": "Whether capture_screenshot should include image content in the tool result for model inspection.", + "default": False, + }, + "source_sandbox_id": { + "type": "string", + "description": "Source sandbox ID for copy_file.", + }, + "source_path": { + "type": "string", + "description": "Source path for copy_file.", + }, + "target_sandbox_id": { + "type": "string", + "description": "Target sandbox ID for copy_file.", + }, + "target_path": { + "type": "string", + "description": "Target path for copy_file.", + }, + }, + "required": ["action"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + action: str, + sandbox_id: str = "", + send_to_user: bool = False, + return_image_to_llm: bool = False, + source_sandbox_id: str = "", + source_path: str = "", + target_sandbox_id: str = "", + target_path: str = "", + ) -> ToolExecResult: + match action: + case "capture_screenshot": + return await _operation_capture_screenshot( + context, sandbox_id, send_to_user, return_image_to_llm + ) + case "copy_file": + return await _operation_copy_file( + context, + source_sandbox_id, + source_path, + target_sandbox_id, + target_path, + ) + return f"Error running sandbox operation: unsupported action '{action}'." diff --git a/astrbot/core/tools/computer_tools/shipyard_neo/__init__.py b/astrbot/core/tools/computer_tools/shipyard_neo/__init__.py deleted file mode 100644 index 9228c86354..0000000000 --- a/astrbot/core/tools/computer_tools/shipyard_neo/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool -from .neo_skills import ( - AnnotateExecutionTool, - CreateSkillCandidateTool, - CreateSkillPayloadTool, - EvaluateSkillCandidateTool, - GetExecutionHistoryTool, - GetSkillPayloadTool, - ListSkillCandidatesTool, - ListSkillReleasesTool, - PromoteSkillCandidateTool, - RollbackSkillReleaseTool, - SyncSkillReleaseTool, -) - -__all__ = [ - "AnnotateExecutionTool", - "BrowserBatchExecTool", - "BrowserExecTool", - "CreateSkillCandidateTool", - "CreateSkillPayloadTool", - "EvaluateSkillCandidateTool", - "GetExecutionHistoryTool", - "GetSkillPayloadTool", - "ListSkillCandidatesTool", - "ListSkillReleasesTool", - "PromoteSkillCandidateTool", - "RollbackSkillReleaseTool", - "RunBrowserSkillTool", - "SyncSkillReleaseTool", -] diff --git a/astrbot/core/tools/computer_tools/shipyard_neo/browser.py b/astrbot/core/tools/computer_tools/shipyard_neo/browser.py deleted file mode 100644 index b4b7f4fd06..0000000000 --- a/astrbot/core/tools/computer_tools/shipyard_neo/browser.py +++ /dev/null @@ -1,204 +0,0 @@ -import json -from dataclasses import dataclass, field -from typing import Any - -from astrbot.api import FunctionTool -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import ToolExecResult -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.computer.computer_client import get_booter -from astrbot.core.tools.computer_tools.util import check_admin_permission -from astrbot.core.tools.registry import builtin_tool - -_SHIPYARD_NEO_TOOL_CONFIG = { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", -} - - -def _to_json(data: Any) -> str: - return json.dumps(data, ensure_ascii=False, default=str) - - -async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> Any: - booter = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - browser = getattr(booter, "browser", None) - if browser is None: - raise RuntimeError( - "Current sandbox booter does not support browser capability. " - "Please switch to shipyard_neo." - ) - return browser - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class BrowserExecTool(FunctionTool): - name: str = "astrbot_execute_browser" - description: str = "Execute one browser automation command in the sandbox." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "cmd": {"type": "string", "description": "Browser command to execute."}, - "timeout": {"type": "integer", "default": 30}, - "description": { - "type": "string", - "description": "Optional execution description.", - }, - "tags": {"type": "string", "description": "Optional tags."}, - "learn": { - "type": "boolean", - "description": "Whether to mark execution as learn evidence.", - "default": False, - }, - "include_trace": { - "type": "boolean", - "description": "Whether to include trace_ref in response.", - "default": False, - }, - }, - "required": ["cmd"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - cmd: str, - timeout: int = 30, - description: str | None = None, - tags: str | None = None, - learn: bool = False, - include_trace: bool = False, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using browser tools"): - return err - try: - browser = await _get_browser_component(context) - result = await browser.exec( - cmd=cmd, - timeout=timeout, - description=description, - tags=tags, - learn=learn, - include_trace=include_trace, - ) - return _to_json(result) - except Exception as e: - return f"Error executing browser command: {str(e)}" - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class BrowserBatchExecTool(FunctionTool): - name: str = "astrbot_execute_browser_batch" - description: str = "Execute a browser command batch in the sandbox." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "commands": { - "type": "array", - "items": {"type": "string"}, - "description": "Ordered browser commands.", - }, - "timeout": {"type": "integer", "default": 60}, - "stop_on_error": {"type": "boolean", "default": True}, - "description": { - "type": "string", - "description": "Optional execution description.", - }, - "tags": {"type": "string", "description": "Optional tags."}, - "learn": { - "type": "boolean", - "description": "Whether to mark execution as learn evidence.", - "default": False, - }, - "include_trace": { - "type": "boolean", - "description": "Whether to include trace_ref in response.", - "default": False, - }, - }, - "required": ["commands"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - commands: list[str], - timeout: int = 60, - stop_on_error: bool = True, - description: str | None = None, - tags: str | None = None, - learn: bool = False, - include_trace: bool = False, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using browser tools"): - return err - try: - browser = await _get_browser_component(context) - result = await browser.exec_batch( - commands=commands, - timeout=timeout, - stop_on_error=stop_on_error, - description=description, - tags=tags, - learn=learn, - include_trace=include_trace, - ) - return _to_json(result) - except Exception as e: - return f"Error executing browser batch command: {str(e)}" - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class RunBrowserSkillTool(FunctionTool): - name: str = "astrbot_run_browser_skill" - description: str = "Run a released browser skill in the sandbox by skill_key." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "skill_key": {"type": "string"}, - "timeout": {"type": "integer", "default": 60}, - "stop_on_error": {"type": "boolean", "default": True}, - "include_trace": {"type": "boolean", "default": False}, - "description": {"type": "string"}, - "tags": {"type": "string"}, - }, - "required": ["skill_key"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - skill_key: str, - timeout: int = 60, - stop_on_error: bool = True, - include_trace: bool = False, - description: str | None = None, - tags: str | None = None, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using browser tools"): - return err - try: - browser = await _get_browser_component(context) - result = await browser.run_skill( - skill_key=skill_key, - timeout=timeout, - stop_on_error=stop_on_error, - include_trace=include_trace, - description=description, - tags=tags, - ) - return _to_json(result) - except Exception as e: - return f"Error running browser skill: {str(e)}" diff --git a/astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py b/astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py deleted file mode 100644 index e2c4f59093..0000000000 --- a/astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py +++ /dev/null @@ -1,556 +0,0 @@ -import json -from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field -from typing import Any - -from astrbot.api import FunctionTool -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import ToolExecResult -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.computer.computer_client import get_booter -from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager -from astrbot.core.tools.computer_tools.util import check_admin_permission -from astrbot.core.tools.registry import builtin_tool - -_SHIPYARD_NEO_TOOL_CONFIG = { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", -} - - -def _to_jsonable(model_like: Any) -> Any: - if isinstance(model_like, dict): - return model_like - if isinstance(model_like, list): - return [_to_jsonable(i) for i in model_like] - if hasattr(model_like, "model_dump"): - return _to_jsonable(model_like.model_dump()) - return model_like - - -def _to_json_text(data: Any) -> str: - return json.dumps(_to_jsonable(data), ensure_ascii=False, default=str) - - -async def _get_neo_context( - context: ContextWrapper[AstrAgentContext], -) -> tuple[Any, Any]: - booter = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - client = getattr(booter, "bay_client", None) - sandbox = getattr(booter, "sandbox", None) - if client is None or sandbox is None: - raise RuntimeError( - "Current sandbox booter does not support Neo skill lifecycle APIs. " - "Please switch to shipyard_neo." - ) - return client, sandbox - - -@dataclass -class NeoSkillToolBase(FunctionTool): - error_prefix: str = "Error" - - async def _run( - self, - context: ContextWrapper[AstrAgentContext], - neo_call: Callable[[Any, Any], Awaitable[Any]], - error_action: str, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using skill lifecycle tools"): - return err - try: - client, sandbox = await _get_neo_context(context) - result = await neo_call(client, sandbox) - return _to_json_text(result) - except Exception as e: - return f"{self.error_prefix} {error_action}: {str(e)}" - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class GetExecutionHistoryTool(NeoSkillToolBase): - name: str = "astrbot_get_execution_history" - description: str = "Get execution history from current sandbox." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "exec_type": {"type": "string"}, - "success_only": {"type": "boolean", "default": False}, - "limit": {"type": "integer", "default": 100}, - "offset": {"type": "integer", "default": 0}, - "tags": {"type": "string"}, - "has_notes": {"type": "boolean", "default": False}, - "has_description": {"type": "boolean", "default": False}, - }, - "required": [], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - exec_type: str | None = None, - success_only: bool = False, - limit: int = 100, - offset: int = 0, - tags: str | None = None, - has_notes: bool = False, - has_description: bool = False, - ) -> ToolExecResult: - return await self._run( - context, - lambda _client, sandbox: sandbox.get_execution_history( - exec_type=exec_type, - success_only=success_only, - limit=limit, - offset=offset, - tags=tags, - has_notes=has_notes, - has_description=has_description, - ), - error_action="getting execution history", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class AnnotateExecutionTool(NeoSkillToolBase): - name: str = "astrbot_annotate_execution" - description: str = "Annotate one execution history record." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "execution_id": {"type": "string"}, - "description": {"type": "string"}, - "tags": {"type": "string"}, - "notes": {"type": "string"}, - }, - "required": ["execution_id"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - execution_id: str, - description: str | None = None, - tags: str | None = None, - notes: str | None = None, - ) -> ToolExecResult: - return await self._run( - context, - lambda _client, sandbox: sandbox.annotate_execution( - execution_id=execution_id, - description=description, - tags=tags, - notes=notes, - ), - error_action="annotating execution", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class CreateSkillPayloadTool(NeoSkillToolBase): - name: str = "astrbot_create_skill_payload" - description: str = ( - "Step 1/3 for Neo skill authoring: create immutable payload content and return payload_ref. " - "Use this to store skill_markdown and structured metadata; do NOT write local skill folders directly." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "payload": { - "anyOf": [ - {"type": "object"}, - {"type": "array", "items": {"type": "object"}}, - ], - "description": ( - "Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. " - "This only stores content and returns payload_ref; it does not create a candidate or release." - ), - }, - "kind": { - "type": "string", - "description": "Payload kind.", - "default": "astrbot_skill_v1", - }, - }, - "required": ["payload"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - payload: dict[str, Any] | list[Any], - kind: str = "astrbot_skill_v1", - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.create_payload( - payload=payload, - kind=kind, - ), - error_action="creating skill payload", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class GetSkillPayloadTool(NeoSkillToolBase): - name: str = "astrbot_get_skill_payload" - description: str = "Get one skill payload by payload_ref." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "payload_ref": {"type": "string"}, - }, - "required": ["payload_ref"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - payload_ref: str, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.get_payload(payload_ref), - error_action="getting skill payload", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class CreateSkillCandidateTool(NeoSkillToolBase): - name: str = "astrbot_create_skill_candidate" - description: str = ( - "Step 2/3 for Neo skill authoring: create a candidate by binding execution evidence " - "(source_execution_ids) with skill identity (skill_key) and optional payload_ref." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "skill_key": { - "type": "string", - "description": "Stable logical identifier, e.g. image-collage-9grid.", - }, - "source_execution_ids": { - "type": "array", - "items": {"type": "string"}, - "description": "Execution evidence IDs captured from sandbox history.", - }, - "scenario_key": { - "type": "string", - "description": "Optional scenario namespace for grouping candidates.", - }, - "payload_ref": { - "type": "string", - "description": "Optional payload reference created by astrbot_create_skill_payload.", - }, - }, - "required": ["skill_key", "source_execution_ids"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - skill_key: str, - source_execution_ids: list[str], - scenario_key: str | None = None, - payload_ref: str | None = None, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.create_candidate( - skill_key=skill_key, - source_execution_ids=source_execution_ids, - scenario_key=scenario_key, - payload_ref=payload_ref, - ), - error_action="creating skill candidate", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class ListSkillCandidatesTool(NeoSkillToolBase): - name: str = "astrbot_list_skill_candidates" - description: str = "List skill candidates." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "status": {"type": "string"}, - "skill_key": {"type": "string"}, - "limit": {"type": "integer", "default": 100}, - "offset": {"type": "integer", "default": 0}, - }, - "required": [], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - status: str | None = None, - skill_key: str | None = None, - limit: int = 100, - offset: int = 0, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.list_candidates( - status=status, - skill_key=skill_key, - limit=limit, - offset=offset, - ), - error_action="listing skill candidates", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class EvaluateSkillCandidateTool(NeoSkillToolBase): - name: str = "astrbot_evaluate_skill_candidate" - description: str = "Evaluate a skill candidate." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "candidate_id": {"type": "string"}, - "passed": {"type": "boolean"}, - "score": {"type": "number"}, - "benchmark_id": {"type": "string"}, - "report": {"type": "string"}, - }, - "required": ["candidate_id", "passed"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - candidate_id: str, - passed: bool, - score: float | None = None, - benchmark_id: str | None = None, - report: str | None = None, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.evaluate_candidate( - candidate_id, - passed=passed, - score=score, - benchmark_id=benchmark_id, - report=report, - ), - error_action="evaluating skill candidate", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class PromoteSkillCandidateTool(NeoSkillToolBase): - name: str = "astrbot_promote_skill_candidate" - description: str = ( - "Step 3/3 for Neo skill authoring: promote candidate to canary/stable release. " - "If stage=stable and sync_to_local=true, payload.skill_markdown is synced to local SKILL.md automatically." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "candidate_id": {"type": "string"}, - "stage": { - "type": "string", - "description": "Release stage: canary/stable", - "default": "canary", - }, - "sync_to_local": { - "type": "boolean", - "description": ( - "Only used with stage=stable. true means sync payload.skill_markdown to local SKILL.md; " - "false means release remains Neo-side only." - ), - "default": True, - }, - }, - "required": ["candidate_id"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - candidate_id: str, - stage: str = "canary", - sync_to_local: bool = True, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using skill lifecycle tools"): - return err - if stage not in {"canary", "stable"}: - return "Error promoting skill candidate: stage must be canary or stable." - - try: - client, _sandbox = await _get_neo_context(context) - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.promote_with_optional_sync( - client, - candidate_id=candidate_id, - stage=stage, - sync_to_local=sync_to_local, - ) - if result.get("sync_error"): - rollback_json = result.get("rollback") - if rollback_json: - return ( - "Error promoting skill candidate: stable release synced failed; " - f"auto rollback succeeded. sync_error={result['sync_error']}; " - f"rollback={_to_json_text(rollback_json)}" - ) - return _to_json_text( - { - "release": result.get("release"), - "sync": result.get("sync"), - "rollback": result.get("rollback"), - } - ) - except Exception as e: - return f"Error promoting skill candidate: {str(e)}" - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class ListSkillReleasesTool(NeoSkillToolBase): - name: str = "astrbot_list_skill_releases" - description: str = "List skill releases." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "skill_key": {"type": "string"}, - "active_only": {"type": "boolean", "default": False}, - "stage": {"type": "string"}, - "limit": {"type": "integer", "default": 100}, - "offset": {"type": "integer", "default": 0}, - }, - "required": [], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - skill_key: str | None = None, - active_only: bool = False, - stage: str | None = None, - limit: int = 100, - offset: int = 0, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.list_releases( - skill_key=skill_key, - active_only=active_only, - stage=stage, - limit=limit, - offset=offset, - ), - error_action="listing skill releases", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class RollbackSkillReleaseTool(NeoSkillToolBase): - name: str = "astrbot_rollback_skill_release" - description: str = "Rollback one skill release." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "release_id": {"type": "string"}, - }, - "required": ["release_id"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - release_id: str, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.rollback_release(release_id), - error_action="rolling back skill release", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class SyncSkillReleaseTool(NeoSkillToolBase): - name: str = "astrbot_sync_skill_release" - description: str = ( - "Sync stable Neo release payload to local SKILL.md and update mapping metadata." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "release_id": {"type": "string"}, - "skill_key": {"type": "string"}, - "require_stable": {"type": "boolean", "default": True}, - }, - "required": [], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - release_id: str | None = None, - skill_key: str | None = None, - require_stable: bool = True, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: _sync_release_to_dict( - client, - release_id=release_id, - skill_key=skill_key, - require_stable=require_stable, - ), - error_action="syncing skill release", - ) - - -async def _sync_release_to_dict( - client: Any, - *, - release_id: str | None, - skill_key: str | None, - require_stable: bool, -) -> dict[str, str]: - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.sync_release( - client, - release_id=release_id, - skill_key=skill_key, - require_stable=require_stable, - ) - return sync_mgr.sync_result_to_dict(result) diff --git a/astrbot/core/tools/computer_tools/util.py b/astrbot/core/tools/computer_tools/util.py index a3930b4c6a..07b4630c3c 100644 --- a/astrbot/core/tools/computer_tools/util.py +++ b/astrbot/core/tools/computer_tools/util.py @@ -41,3 +41,15 @@ def check_admin_permission( f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command." ) return None + + +def check_strict_admin_permission( + context: ContextWrapper[AstrAgentContext], operation_name: str +) -> str | None: + if context.context.event.role != "admin": + return ( + f"error: Permission denied. {operation_name} is only allowed for admin users. " + "Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature. " + f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command." + ) + return None diff --git a/astrbot/core/tools/registry.py b/astrbot/core/tools/registry.py index c3b10d2295..196a8fe00f 100644 --- a/astrbot/core/tools/registry.py +++ b/astrbot/core/tools/registry.py @@ -19,6 +19,7 @@ _builtin_tool_classes_by_name: dict[str, type[FunctionTool]] = {} _builtin_tool_names_by_class: dict[type[FunctionTool], str] = {} +_builtin_tool_names_by_module_prefix: dict[str, tuple[str, ...]] = {} _builtin_tools_loaded = False _MISSING = object() @@ -118,6 +119,10 @@ def _build_rule_from_config_map( return BuiltinToolConfigRule(conditions=tuple(conditions)) +def build_builtin_tool_config_rule(config_map: dict[str, Any]) -> BuiltinToolConfigRule: + return _build_rule_from_config_map(config_map) + + def _evaluate_send_message_tool(config: dict[str, Any]) -> list[dict[str, Any]]: platform_configs = config.get("platform", []) if not isinstance(platform_configs, list): @@ -238,6 +243,51 @@ def _register(cls: TFunctionTool) -> TFunctionTool: return _register(tool_cls) +def unregister_builtin_tool_class(tool_cls: type[FunctionTool]) -> str | None: + tool_name = _builtin_tool_names_by_class.pop(tool_cls, None) + if tool_name is None: + return None + existing = _builtin_tool_classes_by_name.get(tool_name) + if existing is tool_cls: + _builtin_tool_classes_by_name.pop(tool_name, None) + _BUILTIN_TOOL_CONFIG_RULES.pop(tool_name, None) + return tool_name + + +def _iter_builtin_tool_names_by_module_prefix(module_prefix: str) -> tuple[str, ...]: + return tuple( + tool_name + for tool_cls, tool_name in _builtin_tool_names_by_class.items() + if getattr(tool_cls, "__module__", "").startswith(module_prefix) + ) + + +def register_builtin_tools_by_module_prefix(module_prefix: str) -> list[str]: + ensure_builtin_tools_loaded() + tool_names = _iter_builtin_tool_names_by_module_prefix(module_prefix) + _builtin_tool_names_by_module_prefix[module_prefix] = tool_names + return list(tool_names) + + +def unregister_builtin_tools_by_module_prefix(module_prefix: str) -> list[str]: + recorded_tool_names = _builtin_tool_names_by_module_prefix.pop(module_prefix, ()) + tool_names = recorded_tool_names or _iter_builtin_tool_names_by_module_prefix( + module_prefix + ) + + removed: list[str] = [] + for tool_name in tool_names: + tool_cls = _builtin_tool_classes_by_name.get(tool_name) + if tool_cls is None: + continue + if not getattr(tool_cls, "__module__", "").startswith(module_prefix): + continue + removed_tool_name = unregister_builtin_tool_class(tool_cls) + if removed_tool_name is not None: + removed.append(removed_tool_name) + return removed + + def ensure_builtin_tools_loaded() -> None: global _builtin_tools_loaded if _builtin_tools_loaded: @@ -318,6 +368,7 @@ def get_builtin_tool_config_tags( __all__ = [ "builtin_tool", + "build_builtin_tool_config_rule", "ensure_builtin_tools_loaded", "get_builtin_tool_config_rule", "get_builtin_tool_config_statuses", @@ -325,4 +376,7 @@ def get_builtin_tool_config_tags( "get_builtin_tool_class", "get_builtin_tool_name", "iter_builtin_tool_classes", + "register_builtin_tools_by_module_prefix", + "unregister_builtin_tool_class", + "unregister_builtin_tools_by_module_prefix", ] diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 40b899620d..7846ff215e 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -128,6 +128,52 @@ def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None: logger.info("Provider-source structure migration completed") +def _prune_invalid_provider_source_models(conf: AstrBotConfig) -> None: + """Remove stale provider model entries that cannot resolve a source type. + + New-style provider model entries intentionally keep only model-level fields and + inherit adapter fields such as `type` from provider_sources. Older configs can + contain entries that reference a removed source id; those can never be loaded + and would otherwise produce repeated startup errors. + """ + providers = conf.get("provider", []) + if not isinstance(providers, list): + return + + provider_sources = conf.get("provider_sources", []) + source_types = { + source.get("id"): source.get("type") + for source in provider_sources + if isinstance(source, dict) + } + + kept = [] + removed_ids = [] + for provider in providers: + if not isinstance(provider, dict): + kept.append(provider) + continue + if provider.get("type"): + kept.append(provider) + continue + source_id = provider.get("provider_source_id") + if source_id and source_types.get(source_id): + kept.append(provider) + continue + removed_ids.append(str(provider.get("id") or source_id or "")) + + if not removed_ids: + return + + conf["provider"] = kept + conf.save_config() + logger.info( + "Pruned %d invalid provider model config(s) with missing adapter type: %s", + len(removed_ids), + ", ".join(removed_ids), + ) + + async def migra( db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager ) -> None: @@ -181,3 +227,9 @@ async def migra( except Exception as e: logger.error(f"Migration for provider-source structure failed: {e!s}") logger.error(traceback.format_exc()) + + try: + _prune_invalid_provider_source_models(astrbot_config) + except Exception as e: + logger.error(f"Migration for invalid provider-source models failed: {e!s}") + logger.error(traceback.format_exc()) diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index fbbd0c7a08..512e2cf1a3 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -14,6 +14,7 @@ from .persona import PersonaRoute from .platform import PlatformRoute from .plugin import PluginRoute +from .sandbox import SandboxRoute from .session_management import SessionManagementRoute from .skills import SkillsRoute from .stat import StatRoute @@ -40,6 +41,7 @@ "PlatformRoute", "PluginRoute", "SessionManagementRoute", + "SandboxRoute", "StatRoute", "StaticFileRoute", "SubAgentRoute", diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index aebe26047c..bfa01daa9e 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -9,6 +9,7 @@ from quart import jsonify, make_response, request from astrbot.core import astrbot_config, file_token_service, logger +from astrbot.core.computer import computer_client from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.default import ( CONFIG_METADATA_2, @@ -284,58 +285,27 @@ def _protected_2fa_config_changed(old_config: dict, new_config: dict) -> bool: ) +def _normalize_unavailable_sandbox_booter(config: dict) -> dict: + sandbox = config.get("provider_settings", {}).get("sandbox", {}) + if not isinstance(sandbox, dict): + return config + booter = str(sandbox.get("booter") or "").strip() + if not booter: + return config + provider_ids = { + str(provider.get("provider_id") or "") + for provider in computer_client.list_sandbox_providers() + } + if booter not in provider_ids: + sandbox["booter"] = "" + return config + + async def _validate_neo_connectivity( post_config: dict, ) -> str | None: - """Check if Bay is reachable when Shipyard Neo sandbox is configured. - - Returns a warning message string if Bay isn't reachable, or None if - everything looks fine (or Neo isn't configured). - """ - ps = post_config.get("provider_settings", {}) - runtime = ps.get("computer_use_runtime", "none") - sandbox = ps.get("sandbox", {}) - booter = sandbox.get("booter", "") - - # Only check when sandbox mode + shipyard_neo is selected - if runtime != "sandbox" or booter != "shipyard_neo": - return None - - endpoint = sandbox.get("shipyard_neo_endpoint", "").rstrip("/") - if not endpoint: - return "⚠️ Shipyard Neo endpoint 未设置" - - access_token = sandbox.get("shipyard_neo_access_token", "") - if not access_token: - # Try auto-discovery - from astrbot.core.computer.computer_client import _discover_bay_credentials - - access_token = _discover_bay_credentials(endpoint) - - if not access_token: - return ( - "⚠️ 未找到 Bay API Key。请填写访问令牌," - "或确保 Bay 的 credentials.json 可被自动发现。" - ) - - # Connectivity check - import aiohttp - - health_url = f"{endpoint}/health" - try: - async with aiohttp.ClientSession() as session: - async with session.get( - health_url, - timeout=aiohttp.ClientTimeout(total=5), - ) as resp: - if resp.status != 200: - return ( - f"⚠️ Bay 健康检查失败 (HTTP {resp.status})," - f"请确认 Bay 正在运行: {endpoint}" - ) - except Exception: - return f"⚠️ 无法连接 Bay ({endpoint}),请确认 Bay 已启动。" - + """Concrete sandbox providers own their connectivity checks.""" + del post_config return None @@ -632,8 +602,11 @@ async def delete_ucr(self): async def get_default_config(self): """获取默认配置文件""" - metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3) - return Response().ok({"config": DEFAULT_CONFIG, "metadata": metadata}).__dict__ + metadata = ConfigMetadataI18n.convert_to_i18n_keys( + self._inject_sandbox_provider_options(copy.deepcopy(CONFIG_METADATA_3)) + ) + config = _normalize_unavailable_sandbox_booter(copy.deepcopy(DEFAULT_CONFIG)) + return Response().ok({"config": config, "metadata": metadata}).__dict__ async def get_abconf_list(self): """获取所有 AstrBot 配置文件的列表""" @@ -664,15 +637,23 @@ async def get_abconf(self): try: if system_config: - abconf = self.acm.confs["default"] + abconf = _normalize_unavailable_sandbox_booter( + copy.deepcopy(dict(self.acm.confs["default"])) + ) metadata = ConfigMetadataI18n.convert_to_i18n_keys( - CONFIG_METADATA_3_SYSTEM + self._inject_sandbox_provider_options( + copy.deepcopy(CONFIG_METADATA_3_SYSTEM) + ) ) return Response().ok({"config": abconf, "metadata": metadata}).__dict__ if abconf_id is None: raise ValueError("abconf_id cannot be None") - abconf = self.acm.confs[abconf_id] - metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3) + abconf = _normalize_unavailable_sandbox_booter( + copy.deepcopy(dict(self.acm.confs[abconf_id])) + ) + metadata = ConfigMetadataI18n.convert_to_i18n_keys( + self._inject_sandbox_provider_options(copy.deepcopy(CONFIG_METADATA_3)) + ) return Response().ok({"config": abconf, "metadata": metadata}).__dict__ except ValueError as e: return Response().error(str(e)).__dict__ @@ -1588,12 +1569,29 @@ async def _get_astrbot_config(self): if provider.default_config_tmpl: provider_default_tmpl[provider.type] = provider.default_config_tmpl + self._inject_sandbox_provider_options(metadata) + return { "metadata": metadata, "config": config, "platform_i18n_translations": platform_i18n_translations, } + def _inject_sandbox_provider_options(self, metadata: dict) -> dict: + try: + items = metadata["ai_group"]["metadata"]["agent_computer_use"]["items"] + booter = items.get("provider_settings.sandbox.booter") + except KeyError: + return metadata + if not isinstance(booter, dict): + return metadata + + providers = computer_client.list_sandbox_providers() + options = [provider["provider_id"] for provider in providers] + booter["options"] = options + booter["labels"] = options.copy() + return metadata + async def _get_plugin_config(self, plugin_name: str): ret: dict = {"metadata": None, "config": None, "i18n": {}} diff --git a/astrbot/dashboard/routes/sandbox.py b/astrbot/dashboard/routes/sandbox.py new file mode 100644 index 0000000000..786ef113f2 --- /dev/null +++ b/astrbot/dashboard/routes/sandbox.py @@ -0,0 +1,311 @@ +import traceback + +from quart import jsonify, request + +from astrbot.core import logger +from astrbot.core.computer import computer_client +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + +from .route import Response, Route, RouteContext +from .sandbox_helpers import ( + demo_mode_response, + is_demo_mode, + is_sandbox_limit_error, + is_sandbox_name_conflict, + is_sandbox_user_error, + sanitize_shell_timeout, +) + + +class SandboxRoute(Route): + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.core_lifecycle = core_lifecycle + self.routes = [ + ("/sandbox/providers", ("GET", self.list_providers)), + ("/sandbox", ("GET", self.list_sandboxes)), + ("/sandbox/current", ("GET", self.get_current_sandbox)), + ("/sandbox/current", ("DELETE", self.release_current_sandbox)), + ("/sandbox", ("POST", self.create_sandbox)), + ("/sandbox//switch", ("POST", self.switch_sandbox)), + ("/sandbox//takeover", ("POST", self.takeover_sandbox)), + ("/sandbox//default", ("POST", self.set_default_sandbox)), + ("/sandbox//shell", ("POST", self.run_shell)), + ("/sandbox//screenshot", ("POST", self.capture_screenshot)), + ("/sandbox/", ("PATCH", self.update_sandbox)), + ("/sandbox/", ("DELETE", self.destroy_sandbox)), + ] + self.register_routes() + + def _session_id(self) -> str: + return request.args.get("session_id") or "dashboard" + + async def list_providers(self): + try: + config = self.core_lifecycle.star_context.get_config(umo=self._session_id()) + sandbox_config = config.get("provider_settings", {}).get("sandbox", {}) + default_provider_id = "" + if isinstance(sandbox_config, dict): + configured_provider_id = str(sandbox_config.get("booter") or "").strip() + if computer_client.get_sandbox_provider_info(configured_provider_id): + default_provider_id = configured_provider_id + return jsonify( + Response() + .ok( + data={ + "providers": computer_client.list_sandbox_providers(), + "default_provider_id": default_provider_id, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to list sandbox providers: {e!s}").__dict__ + ) + + async def list_sandboxes(self): + try: + return jsonify( + Response() + .ok( + data={ + "sandboxes": await computer_client.sandbox_manager.list_sandboxes_checked() + } + ) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to list sandboxes: {e!s}").__dict__ + ) + + async def get_current_sandbox(self): + try: + return jsonify( + Response() + .ok( + data=computer_client.sandbox_manager.get_current_sandbox( + self._session_id() + ) + ) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to get current sandbox: {e!s}").__dict__ + ) + + async def create_sandbox(self): + if is_demo_mode(): + return demo_mode_response() + try: + data = await request.get_json(silent=True) or {} + provider_id = str(data.get("provider_id") or "").strip() + if not provider_id: + return jsonify(Response().error("provider_id is required").__dict__) + sandbox = await computer_client.sandbox_manager.create_sandbox_uncontrolled_deferred( + self.core_lifecycle.star_context, + self._session_id(), + provider_id, + sandbox_name=data.get("sandbox_name"), + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except RuntimeError as e: + if is_sandbox_name_conflict(e) or is_sandbox_limit_error(e): + logger.warning(str(e)) + return jsonify(Response().error(str(e)).__dict__) + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to create sandbox: {e!s}").__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to create sandbox: {e!s}").__dict__ + ) + + async def switch_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox = ( + await computer_client.sandbox_manager.switch_current_sandbox_checked( + self._session_id(), + sandbox_id, + context=self.core_lifecycle.star_context, + ) + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to switch sandbox: {e!s}").__dict__ + ) + + async def release_current_sandbox(self): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox_id = request.args.get("sandbox_id") + if sandbox_id: + sandbox = computer_client.sandbox_manager.force_release_sandbox( + sandbox_id + ) + else: + sandbox = computer_client.sandbox_manager.release_current_sandbox( + self._session_id() + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to release sandbox: {e!s}").__dict__ + ) + + async def takeover_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox = await computer_client.sandbox_manager.takeover_sandbox( + self._session_id(), sandbox_id, context=self.core_lifecycle.star_context + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to takeover sandbox: {e!s}").__dict__ + ) + + async def set_default_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox = computer_client.sandbox_manager.set_default_sandbox(sandbox_id) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to set default sandbox: {e!s}").__dict__ + ) + + async def run_shell(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + data = await request.get_json(silent=True) or {} + command = str(data.get("command") or "").strip() + if not command: + return jsonify(Response().error("command is required").__dict__) + # Dashboard shell access is an administrative operation; it does + # not need a lease so admins can operate any sandbox at any time. + booter = await computer_client.sandbox_manager.get_observer_booter_by_id( + sandbox_id, + self._session_id(), + require_lease=False, + context=self.core_lifecycle.star_context, + ) + shell = getattr(booter, "shell", None) + if shell is None: + return jsonify( + Response().error("Sandbox does not support shell.").__dict__ + ) + result = await shell.exec( + command, + cwd=data.get("cwd"), + env=data.get("env"), + timeout=sanitize_shell_timeout(data.get("timeout", 300)), + shell=data.get("shell", True), + ) + return jsonify(Response().ok(data={"result": result}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to run sandbox shell: {e!s}").__dict__ + ) + + async def capture_screenshot(self, sandbox_id: str): + try: + data = await request.get_json(silent=True) or {} + # Dashboard screenshot is a read-only observer operation; it does + # not need a lease and must not reset the sandbox idle timer. + booter = await computer_client.sandbox_manager.get_observer_booter_by_id( + sandbox_id, + self._session_id(), + require_lease=False, + context=self.core_lifecycle.star_context, + ) + gui = getattr(booter, "gui", None) + if gui is None: + return jsonify( + Response().error("Sandbox does not support screenshots.").__dict__ + ) + screenshot = await gui.screenshot(path=data.get("path")) + return jsonify(Response().ok(data={"screenshot": screenshot}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response() + .error(f"Failed to capture sandbox screenshot: {e!s}") + .__dict__ + ) + + async def update_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + data = await request.get_json(silent=True) or {} + current_sandbox = computer_client.sandbox_manager.registry.get_sandbox( + sandbox_id + ) + retention_policy = data.get( + "retention_policy", + current_sandbox.get("retention_policy", "temporary") + if current_sandbox + else "temporary", + ) + idle_timeout = data.get( + "idle_timeout", + current_sandbox.get("idle_timeout") if current_sandbox else None, + ) + expires_at = data.get( + "expires_at", + current_sandbox.get("expires_at") if current_sandbox else None, + ) + sandbox = computer_client.sandbox_manager.update_sandbox_config( + sandbox_id, + sandbox_name=data.get("sandbox_name"), + idle_timeout=idle_timeout, + expires_at=expires_at, + retention_policy=retention_policy, + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + if is_sandbox_user_error(e): + logger.info("Failed to update sandbox: %s", e) + else: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to update sandbox: {e!s}").__dict__ + ) + + async def destroy_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox = await computer_client.sandbox_manager.destroy_sandbox_deferred( + self._session_id(), sandbox_id + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to destroy sandbox: {e!s}").__dict__ + ) diff --git a/astrbot/dashboard/routes/sandbox_helpers.py b/astrbot/dashboard/routes/sandbox_helpers.py new file mode 100644 index 0000000000..45a61777da --- /dev/null +++ b/astrbot/dashboard/routes/sandbox_helpers.py @@ -0,0 +1,54 @@ +import math + +from quart import jsonify + +from astrbot.core import DEMO_MODE + +from .route import Response + + +def is_demo_mode() -> bool: + return DEMO_MODE + + +def demo_mode_response(): + return jsonify( + Response() + .error("You are not permitted to do this operation in demo mode") + .__dict__ + ) + + +def is_sandbox_name_conflict(error: Exception) -> bool: + return isinstance(error, RuntimeError) and str(error).startswith("Sandbox name ") + + +def is_sandbox_limit_error(error: Exception) -> bool: + return isinstance(error, RuntimeError) and str(error).startswith( + "Sandbox limit reached" + ) + + +def is_sandbox_user_error(error: Exception) -> bool: + if not isinstance(error, (RuntimeError, ValueError)): + return False + message = str(error) + return ( + is_sandbox_name_conflict(error) + or is_sandbox_limit_error(error) + or "does not support persistent sandboxes" in message + or "retention_policy must be" in message + or "sandbox_name must be" in message + ) + + +def sanitize_shell_timeout(value, default: float = 300) -> float: + if isinstance(value, bool): + return default + try: + timeout = float(value) + except (TypeError, ValueError): + return default + if not math.isfinite(timeout) or timeout <= 0: + return default + return timeout diff --git a/astrbot/dashboard/routes/skills.py b/astrbot/dashboard/routes/skills.py index c86598212e..9a395a2255 100644 --- a/astrbot/dashboard/routes/skills.py +++ b/astrbot/dashboard/routes/skills.py @@ -2,44 +2,17 @@ import re import shutil import traceback -from collections.abc import Awaitable, Callable from pathlib import Path -from typing import Any from quart import request, send_file from astrbot.core import DEMO_MODE, logger -from astrbot.core.computer.computer_client import ( - _discover_bay_credentials, - sync_skills_to_active_sandboxes, -) -from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager +from astrbot.core.computer.computer_client import sync_skills_to_active_sandboxes from astrbot.core.skills.skill_manager import SkillManager from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from .route import Response, Route, RouteContext - -def _to_jsonable(value: Any) -> Any: - if isinstance(value, dict): - return {k: _to_jsonable(v) for k, v in value.items()} - if isinstance(value, list): - return [_to_jsonable(v) for v in value] - if hasattr(value, "model_dump"): - return _to_jsonable(value.model_dump()) - return value - - -def _to_bool(value: Any, default: bool = False) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - if isinstance(value, str): - return value.strip().lower() in {"1", "true", "yes", "y", "on"} - return bool(value) - - _SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$") _SKILL_FILE_MAX_BYTES = 512 * 1024 _EDITABLE_SKILL_FILE_SUFFIXES = { @@ -87,15 +60,6 @@ def __init__(self, context: RouteContext, core_lifecycle) -> None: ], "/skills/update": ("POST", self.update_skill), "/skills/delete": ("POST", self.delete_skill), - "/skills/neo/candidates": ("GET", self.get_neo_candidates), - "/skills/neo/releases": ("GET", self.get_neo_releases), - "/skills/neo/payload": ("GET", self.get_neo_payload), - "/skills/neo/evaluate": ("POST", self.evaluate_neo_candidate), - "/skills/neo/promote": ("POST", self.promote_neo_candidate), - "/skills/neo/rollback": ("POST", self.rollback_neo_release), - "/skills/neo/sync": ("POST", self.sync_neo_release), - "/skills/neo/delete-candidate": ("POST", self.delete_neo_candidate), - "/skills/neo/delete-release": ("POST", self.delete_neo_release), } self.register_routes() @@ -179,58 +143,6 @@ def _serialize_skill_file_entry( ), } - def _get_neo_client_config(self) -> tuple[str, str]: - provider_settings = self.core_lifecycle.astrbot_config.get( - "provider_settings", - {}, - ) - sandbox = provider_settings.get("sandbox", {}) - endpoint = sandbox.get("shipyard_neo_endpoint", "") - access_token = sandbox.get("shipyard_neo_access_token", "") - - # Auto-discover token from Bay's credentials.json if not configured - if not access_token and endpoint: - access_token = _discover_bay_credentials(endpoint) - - if not endpoint or not access_token: - raise ValueError( - "Shipyard Neo endpoint or access token not configured. " - "Set them in Dashboard or ensure Bay's credentials.json is accessible." - ) - return endpoint, access_token - - async def _delete_neo_release( - self, client: Any, release_id: str, reason: str | None - ): - return await client.skills.delete_release(release_id, reason=reason) - - async def _delete_neo_candidate( - self, client: Any, candidate_id: str, reason: str | None - ): - return await client.skills.delete_candidate(candidate_id, reason=reason) - - async def _with_neo_client( - self, - operation: Callable[[Any], Awaitable[dict]], - ) -> dict: - try: - endpoint, access_token = self._get_neo_client_config() - - from shipyard_neo import BayClient - - async with BayClient( - endpoint_url=endpoint, - access_token=access_token, - ) as client: - return await operation(client) - except ValueError as e: - # Config not ready — expected when Neo isn't set up yet - logger.debug("[Neo] %s", e) - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - async def get_skills(self): try: provider_settings = self.core_lifecycle.astrbot_config.get( @@ -707,254 +619,3 @@ async def delete_skill(self): except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - - async def get_neo_candidates(self): - logger.info("[Neo] GET /skills/neo/candidates requested.") - status = request.args.get("status") - skill_key = request.args.get("skill_key") - limit = int(request.args.get("limit", 100)) - offset = int(request.args.get("offset", 0)) - - async def _do(client): - candidates = await client.skills.list_candidates( - status=status, - skill_key=skill_key, - limit=limit, - offset=offset, - ) - result = _to_jsonable(candidates) - total = result.get("total", "?") if isinstance(result, dict) else "?" - logger.info(f"[Neo] Candidates fetched: total={total}") - return Response().ok(result).__dict__ - - return await self._with_neo_client(_do) - - async def get_neo_releases(self): - logger.info("[Neo] GET /skills/neo/releases requested.") - skill_key = request.args.get("skill_key") - stage = request.args.get("stage") - active_only = _to_bool(request.args.get("active_only"), False) - limit = int(request.args.get("limit", 100)) - offset = int(request.args.get("offset", 0)) - - async def _do(client): - releases = await client.skills.list_releases( - skill_key=skill_key, - active_only=active_only, - stage=stage, - limit=limit, - offset=offset, - ) - result = _to_jsonable(releases) - total = result.get("total", "?") if isinstance(result, dict) else "?" - logger.info(f"[Neo] Releases fetched: total={total}") - return Response().ok(result).__dict__ - - return await self._with_neo_client(_do) - - async def get_neo_payload(self): - logger.info("[Neo] GET /skills/neo/payload requested.") - payload_ref = request.args.get("payload_ref", "") - if not payload_ref: - return Response().error("Missing payload_ref").__dict__ - - async def _do(client): - payload = await client.skills.get_payload(payload_ref) - logger.info(f"[Neo] Payload fetched: ref={payload_ref}") - return Response().ok(_to_jsonable(payload)).__dict__ - - return await self._with_neo_client(_do) - - async def evaluate_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/evaluate requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - passed_value = data.get("passed") - if not candidate_id or passed_value is None: - return Response().error("Missing candidate_id or passed").__dict__ - passed = _to_bool(passed_value, False) - - async def _do(client): - result = await client.skills.evaluate_candidate( - candidate_id, - passed=passed, - score=data.get("score"), - benchmark_id=data.get("benchmark_id"), - report=data.get("report"), - ) - logger.info( - f"[Neo] Candidate evaluated: id={candidate_id}, passed={passed}" - ) - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def promote_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/promote requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - stage = data.get("stage", "canary") - sync_to_local = _to_bool(data.get("sync_to_local"), True) - if not candidate_id: - return Response().error("Missing candidate_id").__dict__ - if stage not in {"canary", "stable"}: - return Response().error("Invalid stage, must be canary/stable").__dict__ - - async def _do(client): - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.promote_with_optional_sync( - client, - candidate_id=candidate_id, - stage=stage, - sync_to_local=sync_to_local, - ) - release_json = result.get("release") - logger.info(f"[Neo] Candidate promoted: id={candidate_id}, stage={stage}") - - sync_json = result.get("sync") - did_sync_to_local = bool(sync_json) - if did_sync_to_local: - logger.info( - f"[Neo] Stable release synced to local: skill={sync_json.get('local_skill_name', '')}" - ) - - if result.get("sync_error"): - resp = Response().error( - "Stable promote synced failed and has been rolled back. " - f"sync_error={result['sync_error']}" - ) - resp.data = { - "release": release_json, - "rollback": result.get("rollback"), - } - return resp.__dict__ - - # Try to push latest local skills to all active sandboxes. - if not did_sync_to_local: - try: - await sync_skills_to_active_sandboxes() - except Exception: - logger.warning("Failed to sync skills to active sandboxes.") - - return Response().ok({"release": release_json, "sync": sync_json}).__dict__ - - return await self._with_neo_client(_do) - - async def rollback_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/rollback requested.") - data = await request.get_json() - release_id = data.get("release_id") - if not release_id: - return Response().error("Missing release_id").__dict__ - - async def _do(client): - result = await client.skills.rollback_release(release_id) - logger.info(f"[Neo] Release rolled back: id={release_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def sync_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/sync requested.") - data = await request.get_json() - release_id = data.get("release_id") - skill_key = data.get("skill_key") - require_stable = _to_bool(data.get("require_stable"), True) - if not release_id and not skill_key: - return Response().error("Missing release_id or skill_key").__dict__ - - async def _do(client): - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.sync_release( - client, - release_id=release_id, - skill_key=skill_key, - require_stable=require_stable, - ) - logger.info( - f"[Neo] Release synced to local: skill={result.local_skill_name}, " - f"release_id={result.release_id}" - ) - return ( - Response() - .ok( - { - "skill_key": result.skill_key, - "local_skill_name": result.local_skill_name, - "release_id": result.release_id, - "candidate_id": result.candidate_id, - "payload_ref": result.payload_ref, - "map_path": result.map_path, - "synced_at": result.synced_at, - } - ) - .__dict__ - ) - - return await self._with_neo_client(_do) - - async def delete_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/delete-candidate requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - reason = data.get("reason") - if not candidate_id: - return Response().error("Missing candidate_id").__dict__ - - async def _do(client): - result = await self._delete_neo_candidate(client, candidate_id, reason) - logger.info(f"[Neo] Candidate deleted: id={candidate_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def delete_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/delete-release requested.") - data = await request.get_json() - release_id = data.get("release_id") - reason = data.get("reason") - if not release_id: - return Response().error("Missing release_id").__dict__ - - async def _do(client): - result = await self._delete_neo_release(client, release_id, reason) - logger.info(f"[Neo] Release deleted: id={release_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 2ad80687ef..63864c65a6 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -4,6 +4,9 @@ from astrbot.core import logger, sp from astrbot.core.agent.mcp_client import MCPTool, validate_mcp_stdio_config +from astrbot.core.computer.sandbox_tool_binding import ( + get_sandbox_provider_tool_config_statuses, +) from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.star import star_map from astrbot.core.tools.registry import get_builtin_tool_config_statuses @@ -483,7 +486,21 @@ async def get_tool_list(self): readonly = False builtin_config_statuses = [] builtin_config_tags = [] - if self.tool_mgr.is_builtin_tool(tool.name): + sandbox_provider_id = getattr(tool, "sandbox_provider_id", None) + if sandbox_provider_id: + origin = "sandbox" + origin_name = str(sandbox_provider_id) + readonly = True + builtin_config_statuses = get_sandbox_provider_tool_config_statuses( + tool.name, + config_entries, + ) + builtin_config_tags = [ + status + for status in builtin_config_statuses + if status["enabled"] + ] + elif self.tool_mgr.is_builtin_tool(tool.name): origin = "builtin" origin_name = "AstrBot Core" readonly = True @@ -544,6 +561,18 @@ async def toggle_tool(self): .__dict__ ) + tool = next( + (t for t in self.tool_mgr.func_list if t.name == tool_name), None + ) + if getattr(tool, "sandbox_provider_id", None): + return ( + Response() + .error( + "Sandbox provider tools are read-only and cannot be toggled." + ) + .__dict__ + ) + if self.tool_mgr.is_builtin_tool(tool_name): return ( Response() @@ -604,6 +633,18 @@ async def update_tool_permission(self): .__dict__ ) + tool = next( + (t for t in self.tool_mgr.func_list if t.name == tool_name), None + ) + if getattr(tool, "sandbox_provider_id", None): + return ( + Response() + .error( + "Sandbox provider tools do not support per-tool permission configuration." + ) + .__dict__ + ) + if self.tool_mgr.is_builtin_tool(tool_name): return ( Response() diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 9888da8f5f..0960b5f401 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -8,6 +8,7 @@ from datetime import datetime from pathlib import Path from typing import Protocol, cast +from urllib.parse import urlsplit import jwt import psutil @@ -41,6 +42,7 @@ from .routes.live_chat import LiveChatRoute from .routes.platform import PlatformRoute from .routes.route import Response, RouteContext +from .routes.sandbox import SandboxRoute from .routes.session_management import SessionManagementRoute from .routes.subagent import SubAgentRoute from .routes.t2i import T2iRoute @@ -300,6 +302,7 @@ def __init__( db, core_lifecycle, ) + self.sandbox_route = SandboxRoute(self.context, core_lifecycle) self.persona_route = PersonaRoute(self.context, db, core_lifecycle) self.cron_route = CronRoute(self.context, core_lifecycle) self.t2i_route = T2iRoute(self.context, core_lifecycle) @@ -435,6 +438,9 @@ async def auth_middleware(self): if not isinstance(username, str) or not username.strip(): raise jwt.InvalidTokenError("missing username in token payload") g.username = username + sandbox_origin_error = self._sandbox_cookie_origin_error() + if sandbox_origin_error is not None: + return sandbox_origin_error except jwt.ExpiredSignatureError: r = jsonify(Response().error("Token 过期").__dict__) r.status_code = 401 @@ -444,6 +450,38 @@ async def auth_middleware(self): r.status_code = 401 return r + def _sandbox_cookie_origin_error(self): + if request.method not in {"POST", "PATCH", "DELETE"}: + return None + if not request.path.startswith("/api/sandbox"): + return None + auth_header = request.headers.get("Authorization", "").strip() + if auth_header.startswith("Bearer "): + return None + + origin = request.headers.get("Origin", "").strip() + referer = request.headers.get("Referer", "").strip() + candidate = origin or referer + if not candidate: + return None + if self._is_same_dashboard_origin(candidate): + return None + + r = jsonify(Response().error("Origin is not allowed").__dict__) + r.status_code = 403 + return r + + @staticmethod + def _is_same_dashboard_origin(candidate: str) -> bool: + parsed_candidate = urlsplit(candidate) + if not parsed_candidate.scheme or not parsed_candidate.netloc: + return False + parsed_request = urlsplit(str(request.url_root)) + return ( + parsed_candidate.scheme == parsed_request.scheme + and parsed_candidate.netloc == parsed_request.netloc + ) + def _get_request_client_ip(self, current_request) -> str: if bool(self.config.get("dashboard", {}).get("trust_proxy_headers", False)): forwarded_for = str( diff --git a/dashboard/pnpm-workspace.yaml b/dashboard/pnpm-workspace.yaml index a23eae0866..1bdd797c8c 100644 --- a/dashboard/pnpm-workspace.yaml +++ b/dashboard/pnpm-workspace.yaml @@ -1,3 +1,6 @@ allowBuilds: esbuild: true vue-demi: true +onlyBuiltDependencies: + - esbuild + - vue-demi diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css b/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css index be565ba238..c143cf7979 100644 --- a/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css +++ b/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css @@ -1,4 +1,4 @@ -/* Auto-generated MDI subset – 272 icons */ +/* Auto-generated MDI subset – 273 icons */ /* Do not edit manually. Run: pnpm run subset-icons */ @font-face { @@ -316,6 +316,10 @@ content: "\F1C2B"; } +.mdi-cube-outline::before { + content: "\F01A7"; +} + .mdi-cursor-default-click::before { content: "\F0CFD"; } diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff index 8e57a70b4f..42491d6927 100644 Binary files a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff and b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff differ diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 index 107a267095..a388b33bff 100644 Binary files a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 and b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 differ diff --git a/dashboard/src/components/extension/SkillsSection.vue b/dashboard/src/components/extension/SkillsSection.vue index 2ad1bddffe..56cb3f2db5 100644 --- a/dashboard/src/components/extension/SkillsSection.vue +++ b/dashboard/src/components/extension/SkillsSection.vue @@ -1521,19 +1521,9 @@ export default { }; const loadNeoAvailability = async () => { - try { - const res = await axios.get("/api/config/get"); - const config = res?.data?.data?.config || {}; - const providerSettings = config?.provider_settings || {}; - const currentRuntime = - providerSettings?.computer_use_runtime || "local"; - const booter = providerSettings?.sandbox?.booter || ""; - neoEnabled.value = - currentRuntime === "sandbox" && booter === "shipyard_neo"; - } catch (_err) { - neoEnabled.value = false; - } - + // Core no longer ships /api/skills/neo/* routes; Neo skill management is + // provided by external sandbox provider plugins instead of this page. + neoEnabled.value = false; neoUnavailableMessage.value = tm("skills.neoRuntimeRequired"); if (!neoEnabled.value && mode.value === "neo") { mode.value = "local"; diff --git a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue index 75e59f30a9..937fe4cbe2 100644 --- a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue +++ b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue @@ -68,7 +68,7 @@ const formatCondition = (condition: ToolConfigCondition) => { }; const enabledConfigTags = (tool: ToolItem): BuiltinToolConfigTag[] => { - if (tool.origin !== 'builtin') return []; + if (tool.origin !== 'builtin' && tool.origin !== 'sandbox') return []; return (tool.builtin_config_tags || []).filter(tag => tag.enabled); }; @@ -89,6 +89,23 @@ const getPermissionLabel = (permission?: string): string => { return tmTool('functionTools.table.permissionEveryone'); } }; + +const getOriginLabel = (origin?: string): string => { + switch (origin) { + case 'builtin': + return tmTool('functionTools.table.originBuiltin'); + case 'mcp': + return tmTool('functionTools.table.originMcp'); + case 'plugin': + return tmTool('functionTools.table.originPlugin'); + case 'sandbox': + return tmTool('functionTools.table.originSandbox'); + case 'unknown': + return tmTool('functionTools.table.originUnknown'); + default: + return origin || '-'; + } +}; @@ -148,7 +165,7 @@ const getPermissionLabel = (permission?: string): string => { - {{ item.origin || '-' }} + {{ getOriginLabel(item.origin) }} diff --git a/dashboard/src/i18n/locales/en-US/core/navigation.json b/dashboard/src/i18n/locales/en-US/core/navigation.json index e0e39639ba..25cc446577 100644 --- a/dashboard/src/i18n/locales/en-US/core/navigation.json +++ b/dashboard/src/i18n/locales/en-US/core/navigation.json @@ -20,6 +20,7 @@ }, "conversation": "Conversations", "sessionManagement": "Custom Rules", + "sandboxes": "Sandboxes", "console": "Console", "trace": "Trace", "alkaid": "Alkaid Lab", diff --git a/dashboard/src/i18n/locales/en-US/features/config-metadata.json b/dashboard/src/i18n/locales/en-US/features/config-metadata.json index a79746c9db..2ae298e7d1 100644 --- a/dashboard/src/i18n/locales/en-US/features/config-metadata.json +++ b/dashboard/src/i18n/locales/en-US/features/config-metadata.json @@ -168,63 +168,41 @@ }, "sandbox": { "booter": { - "description": "Sandbox Environment Driver" + "description": "Default Sandbox Driver" }, - "shipyard_neo_endpoint": { - "description": "Shipyard Neo API Endpoint", - "hint": "Bay API address, default http://127.0.0.1:8114." + "sandbox_lease_timeout": { + "description": "Sandbox lease timeout", + "hint": "Seconds. Each successful sandbox access renews this session's lease to now plus this duration. Default 600 seconds. When it expires, this session no longer has a current sandbox and other sessions can take over. `0` means the lease never expires and must be released manually." }, - "shipyard_neo_access_token": { - "description": "Shipyard Neo Access Token", - "hint": "Bay API Key (sk-bay-...). Leave empty for auto-discovery from credentials.json." + "sandbox_idle_timeout": { + "description": "Sandbox idle cleanup timeout", + "hint": "Seconds. Starts counting after the sandbox is released. If it stays unclaimed longer than this, it is cleaned up automatically, default 1800 seconds. `0` disables idle cleanup." }, - "shipyard_neo_profile": { - "description": "Shipyard Neo Profile", - "hint": "Sandbox profile for Shipyard Neo, e.g. python-default. Leave empty to auto-select the most capable profile." + "sandbox_ttl": { + "description": "Sandbox lifetime", + "hint": "Seconds. Only active when idle cleanup is `0`. Counts from creation time and forces cleanup when it expires. `0` disables automatic expiration." }, - "shipyard_neo_ttl": { - "description": "Shipyard Neo Sandbox TTL", - "hint": "Sandbox time-to-live in seconds." + "max_sandboxes": { + "description": "Maximum sandboxes", + "hint": "Global managed sandbox limit, default 10. `0` means unlimited." }, - "cua_image": { - "description": "CUA Image", - "hint": "CUA sandbox image or OS type. Defaults to linux. Supported values depend on the installed CUA SDK." - }, - "cua_os_type": { - "description": "CUA OS Type", - "hint": "CUA sandbox operating system type. Defaults to linux." - }, - "cua_idle_timeout": { - "description": "CUA Idle Timeout", - "hint": "Idle timeout for CUA sandbox sessions in seconds. When greater than 0, AstrBot proactively shuts down an idle CUA sandbox after that amount of inactivity; 0 disables it." - }, - "cua_telemetry_enabled": { - "description": "CUA Telemetry", - "hint": "Allow the CUA SDK to send telemetry data. Disabled by default." - }, - "cua_local": { - "description": "CUA Local Sandbox", - "hint": "Prefer a local CUA sandbox. Enabled by default to avoid requiring CUA_API_KEY for cloud sandboxes. Disable this to use CUA cloud sandboxes." - }, - "cua_api_key": { - "description": "CUA API Key", - "hint": "CUA cloud sandbox API key. Required only when local sandbox is disabled. You can also provide it via the CUA_API_KEY environment variable." - }, - "shipyard_endpoint": { - "description": "Shipyard API Endpoint", - "hint": "API access address for Shipyard service." - }, - "shipyard_access_token": { - "description": "Shipyard Access Token", - "hint": "Access token for accessing Shipyard service." - }, - "shipyard_ttl": { - "description": "Shipyard Session TTL", - "hint": "Session time-to-live in seconds." - }, - "shipyard_max_sessions": { - "description": "Shipyard Max Sessions", - "hint": "Maximum number of Shipyard sessions an instance can handle." + "member_permissions": { + "create": { + "description": "Allow members to create sandboxes", + "hint": "Lets non-admin users create new managed sandboxes. Creation still respects the maximum sandbox limit." + }, + "set_retention_policy": { + "description": "Allow members to change sandbox retention", + "hint": "Lets non-admin users switch sandboxes between temporary and persistent retention. Persistent sandboxes keep their environment for later reuse." + }, + "takeover": { + "description": "Allow members to take over sandboxes", + "hint": "Lets non-admin users force takeover sandboxes occupied by other sessions. This transfers control of the sandbox, so enable it carefully." + }, + "destroy": { + "description": "Allow members to delete sandboxes", + "hint": "Lets non-admin users delete managed sandboxes they can access. Deletion removes both the sandbox environment and its managed record." + } } } } diff --git a/dashboard/src/i18n/locales/en-US/features/sandbox.json b/dashboard/src/i18n/locales/en-US/features/sandbox.json new file mode 100644 index 0000000000..c94fdf7172 --- /dev/null +++ b/dashboard/src/i18n/locales/en-US/features/sandbox.json @@ -0,0 +1,126 @@ +{ + "title": "Sandbox Management", + "subtitle": "Inspect and operate managed sandboxes across providers.", + "actions": { + "create": "Create", + "refresh": "Refresh", + "inspect": "Inspect", + "setDefault": "Set default", + "configure": "Configure", + "switch": "Switch", + "takeover": "Take over", + "console": "Console", + "release": "Release", + "screenshot": "Screenshot", + "destroy": "Destroy", + "cancel": "Cancel", + "close": "Close", + "save": "Save" + }, + "metrics": { + "total": "Managed sandboxes", + "providers": "Providers", + "busy": "Occupied", + "default": "Provider defaults" + }, + "labels": { + "default": "Default", + "busy": "Occupied", + "available": "Idle", + "noController": "No controller", + "unknown": "Unknown", + "none": "None", + "temporary": "Temporary", + "persistent": "Persistent", + "creating": "Creating", + "restoring": "Restoring", + "running": "Running", + "error": "Error", + "stopping": "Stopping", + "stopped": "Stopped", + "unknownStatus": "Unknown status: {status}" + }, + "headers": { + "sandbox": "Sandbox", + "provider": "Provider", + "capabilities": "Capabilities", + "status": "Status", + "lastUsed": "Last used", + "actions": "Actions" + }, + "fields": { + "provider": "Provider", + "capabilities": "Capabilities", + "toolNames": "Tools", + "status": "Status", + "owner": "Owner session", + "controller": "Controller session", + "retentionPolicy": "Retention policy", + "occupiedUntil": "Occupied until", + "idleCleanupAt": "Idle cleanup at", + "expiresAt": "Expiration cleanup at", + "connectInfo": "Connection info" + }, + "empty": { + "title": "No managed sandboxes", + "subtitle": "Create a managed sandbox or wait for a provider to register one." + }, + "create": { + "title": "Create sandbox", + "name": "Sandbox name", + "providerHint": "This provider is visible in the dashboard model but creation is not wired in this phase." + }, + "screenshot": { + "title": "Sandbox screenshot", + "noPreview": "Screenshot captured. Preview rendering will be added later." + }, + "console": { + "title": "Sandbox console", + "notice": "The console does not change sandbox occupancy, but commands directly affect this sandbox environment.", + "command": "Shell command", + "run": "Run command", + "output": "Command output", + "empty": "No commands executed yet.", + "running": "Running...", + "dangerConfirm": "This command may destroy sandbox data: {command}\nContinue?" + }, + "config": { + "title": "Sandbox configuration", + "name": "Sandbox name", + "nameRequired": "Sandbox name is required", + "idleTimeout": "Idle timeout (seconds)", + "idleTimeoutHint": "With the temporary policy, sandboxes can be cleaned after this idle duration. Empty or 0 disables idle cleanup.", + "expiresAt": "Fixed expiration time", + "expiresAtHint": "Optional. The sandbox can be cleaned after this time." + }, + "tooltips": { + "takeover": "Bind the current dashboard session to this sandbox and obtain control. If another controller exists, control is transferred.", + "console": "Open a shell console without changing the sandbox occupancy." + }, + "destroyConfirm": { + "title": "Destroy sandbox?", + "message": "Destroy {name}? This shuts down the sandbox and removes its managed record." + }, + "messages": { + "loadFailed": "Failed to load sandboxes.", + "operationFailed": "Sandbox operation failed.", + "created": "Sandbox created.", + "createSubmitted": "Sandbox creation request submitted.", + "createReady": "Sandbox is now running.", + "createFailed": "Sandbox creation failed.", + "createUnknown": "Sandbox state is unknown. Refresh manually or inspect the backend.", + "createRefreshUnstable": "Sandbox status refresh became unstable during creation. Refresh manually later.", + "createNotVisible": "Sandbox {sandboxId} is not visible in the list yet. Refresh manually later.", + "createTimedOut": "Sandbox {sandboxId} is still processing. Refresh manually later.", + "createUnexpectedStatus": "Sandbox creation ended with status: {status}.", + "maxSandboxesReached": "Maximum sandbox count reached: {max}. Release or destroy an unused sandbox first.", + "defaultSet": "Default sandbox updated.", + "configSaved": "Sandbox configuration saved.", + "released": "Sandbox occupancy released.", + "takeover": "Current session now controls this sandbox.", + "destroyed": "Sandbox destroyed.", + "destroyRefreshUnstable": "Sandbox status refresh became unstable during destroy. Refresh manually later.", + "destroyTimedOut": "Sandbox {sandboxId} is still being destroyed. Refresh manually later.", + "screenshot": "Screenshot captured." + } +} diff --git a/dashboard/src/i18n/locales/en-US/features/tool-use.json b/dashboard/src/i18n/locales/en-US/features/tool-use.json index 52c27bbfed..324b376137 100644 --- a/dashboard/src/i18n/locales/en-US/features/tool-use.json +++ b/dashboard/src/i18n/locales/en-US/features/tool-use.json @@ -55,6 +55,11 @@ "permissionMember": "All Users", "permissionEveryone": "All Users", "permissionBuiltin": "System Builtin", + "originBuiltin": "Builtin", + "originMcp": "MCP", + "originPlugin": "Plugin", + "originSandbox": "Sandbox", + "originUnknown": "Unknown", "actions": "Actions" }, "configTags": { diff --git a/dashboard/src/i18n/locales/ru-RU/core/navigation.json b/dashboard/src/i18n/locales/ru-RU/core/navigation.json index 315e386ec2..d9864371c6 100644 --- a/dashboard/src/i18n/locales/ru-RU/core/navigation.json +++ b/dashboard/src/i18n/locales/ru-RU/core/navigation.json @@ -1,50 +1,48 @@ -{ - "welcome": "Добро пожаловать", +{ "dashboard": "Статистика", - "platforms": "Боты", - "providers": "Провайдеры моделей", - "commands": "Команды", - "persona": "Персонажи", - "subagent": "Субагенты", - "toolUse": "Инструменты MCP", - "extension": "Плагины", - "extensionTabs": { - "installed": "Плагины AstrBot", - "market": "Магазин плагинов", - "mcp": "Серверы MCP", - "skills": "Навыки", - "components": "Управление поведением" - }, - "config": "Конфигурация", - "chat": "Чат", - "cron": "Запланированные задачи", - "conversation": "Данные диалогов", - "sessionManagement": "Пользовательские правила", - "console": "Логи платформы", - "trace": "Трассировка", - "alkaid": "Alkaid Lab", + "extension": "Расширения", + "source": "Платформы", + "provider": "Поставщики", + "agent": "Агенты", + "modelProvider": "Поставщики моделей", + "serviceProvider": "Поставщики сервисов", "knowledgeBase": "База знаний", - "about": "О программе", + "mcp": "MCP", "settings": "Настройки", - "changelog": "Журнал изменений", + "configs": "Конфигурация", + "logs": "Логи", + "terminal": "Терминал", + "about": "О программе", + "theme": "Тема", + "language": "Язык", + "lightMode": "Светлый режим", + "darkMode": "Темный режим", + "logout": "Выйти", "documentation": "Документация", - "faq": "FAQ", - "github": "GitHub", - "drag": "Перетащить", - "groups": { - "more": "Дополнительно" - }, - "changelogDialog": { - "title": "Журнал изменений", - "loading": "Загрузка...", - "error": "Ошибка загрузки", - "notFound": "Журнал изменений для этой версии не найден", - "selectVersion": "Выберите версию", - "current": "Текущая" - }, + "support": "Поддержка", + "profile": "Профиль", + "notifications": "Уведомления", + "search": "Поиск", + "back": "Назад", + "menu": "Меню", + "close": "Закрыть", + "toggleMenu": "Переключить меню", + "collapse": "Свернуть", + "expand": "Развернуть", + "refresh": "Обновить", + "home": "Главная", + "breadcrumb": "Навигационная цепочка", + "pageTitle": "Заголовок страницы", + "navigation": "Навигация", + "primaryNavigation": "Основная навигация", + "secondaryNavigation": "Вторичная навигация", + "userMenu": "Меню пользователя", + "quickActions": "Быстрые действия", + "recentItems": "Недавние элементы", + "favorites": "Избранное", "configTabs": { "normal": "Обычная конфигурация", "system": "Системная конфигурация" }, "pluginWebui": "Страницы плагинов" -} \ No newline at end of file +} diff --git a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json index b08662b7ba..181e720252 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json +++ b/dashboard/src/i18n/locales/ru-RU/features/config-metadata.json @@ -170,61 +170,39 @@ "booter": { "description": "Драйвер среды песочницы" }, - "shipyard_neo_endpoint": { - "description": "Эндпоинт Shipyard Neo API", - "hint": "Адрес Bay API, по умолчанию http://127.0.0.1:8114." + "sandbox_lease_timeout": { + "description": "Таймаут владения песочницей", + "hint": "В секундах. Каждый успешный доступ к песочнице продлевает владение этой сессии до текущего времени плюс это значение. По умолчанию 600 секунд. После истечения у текущей сессии больше нет текущей песочницы, и другие сессии могут получить контроль. `0` означает, что владение не истекает и должно быть освобождено вручную." }, - "shipyard_neo_access_token": { - "description": "Токен доступа Shipyard Neo", - "hint": "Ключ Bay API (sk-bay-...). Оставьте пустым для автопоиска в credentials.json." + "sandbox_idle_timeout": { + "description": "Таймаут очистки простоя песочницы", + "hint": "В секундах. `0` отключает очистку простоя; TTL-очистка применяется только в этом случае." }, - "shipyard_neo_profile": { - "description": "Профиль Shipyard Neo", - "hint": "Профиль песочницы, например, python-default. Оставьте пустым для автоматического выбора самого функционального профиля." + "sandbox_ttl": { + "description": "Время жизни песочницы", + "hint": "В секундах. Действует только когда очистка простоя равна `0`; `0` отключает автоматическое истечение." }, - "shipyard_neo_ttl": { - "description": "TTL песочницы Shipyard Neo", - "hint": "Время жизни песочницы в секундах." + "max_sandboxes": { + "description": "Максимум песочниц", + "hint": "Глобальный лимит управляемых песочниц, по умолчанию 10. `0` означает без ограничения." }, - "cua_image": { - "description": "Образ CUA", - "hint": "Образ или тип ОС песочницы CUA. По умолчанию linux. Поддерживаемые значения зависят от установленного CUA SDK." - }, - "cua_os_type": { - "description": "Тип ОС CUA", - "hint": "Тип операционной системы песочницы CUA. По умолчанию linux." - }, - "cua_idle_timeout": { - "description": "Таймаут простоя CUA", - "hint": "Таймаут простоя сессии песочницы CUA в секундах. Если значение больше 0, AstrBot автоматически завершит неактивную песочницу CUA после указанного времени бездействия; 0 отключает автоочистку." - }, - "cua_telemetry_enabled": { - "description": "Телеметрия CUA", - "hint": "Разрешить CUA SDK отправлять телеметрию. По умолчанию выключено." - }, - "cua_local": { - "description": "Локальная песочница CUA", - "hint": "Предпочитать локальную песочницу CUA. Включено по умолчанию, чтобы не требовать CUA_API_KEY для облачных песочниц. Отключите для использования облачных песочниц CUA." - }, - "cua_api_key": { - "description": "CUA API Key", - "hint": "API key для облачной песочницы CUA. Требуется только если локальная песочница отключена. Также можно передать через переменную окружения CUA_API_KEY." - }, - "shipyard_endpoint": { - "description": "Эндпоинт Shipyard API", - "hint": "Адрес API для доступа к сервису Shipyard." - }, - "shipyard_access_token": { - "description": "Токен доступа Shipyard", - "hint": "Токен доступа для работы с сервисом Shipyard." - }, - "shipyard_ttl": { - "description": "TTL сессии Shipyard", - "hint": "Время жизни сессии в секундах." - }, - "shipyard_max_sessions": { - "description": "Макс. количество сессий Shipyard", - "hint": "Максимальное количество сессий Shipyard, которое может поддерживать экземпляр." + "member_permissions": { + "create": { + "description": "Разрешить обычным пользователям создавать песочницы", + "hint": "Позволяет обычным пользователям создавать новые управляемые песочницы. Создание всё равно учитывает общий лимит песочниц." + }, + "set_retention_policy": { + "description": "Разрешить обычным пользователям менять политику хранения", + "hint": "Позволяет обычным пользователям переключать песочницы между временной и постоянной политикой хранения. Постоянные песочницы сохраняют окружение для повторного использования." + }, + "takeover": { + "description": "Разрешить обычным пользователям захватывать песочницы", + "hint": "Позволяет обычным пользователям принудительно захватывать песочницы, занятые другими сессиями. Это передаёт управление песочницей, поэтому включайте настройку осторожно." + }, + "destroy": { + "description": "Разрешить обычным пользователям удалять песочницы", + "hint": "Позволяет обычным пользователям удалять доступные им управляемые песочницы. Удаление убирает и окружение песочницы, и её запись." + } } } } diff --git a/dashboard/src/i18n/locales/ru-RU/features/sandbox.json b/dashboard/src/i18n/locales/ru-RU/features/sandbox.json new file mode 100644 index 0000000000..e47f73ab02 --- /dev/null +++ b/dashboard/src/i18n/locales/ru-RU/features/sandbox.json @@ -0,0 +1,126 @@ +{ + "title": "Sandbox Management", + "subtitle": "Inspect and operate managed sandboxes across providers.", + "actions": { + "create": "Create", + "refresh": "Refresh", + "inspect": "Inspect", + "setDefault": "Set default", + "configure": "Configure", + "switch": "Switch", + "takeover": "Take over", + "console": "Console", + "release": "Release", + "screenshot": "Screenshot", + "destroy": "Destroy", + "cancel": "Cancel", + "close": "Close", + "save": "Save" + }, + "metrics": { + "total": "Managed sandboxes", + "providers": "Providers", + "busy": "Occupied", + "default": "Provider defaults" + }, + "labels": { + "default": "Default", + "busy": "Occupied", + "available": "Idle", + "noController": "No controller", + "unknown": "Unknown", + "none": "None", + "temporary": "Temporary", + "persistent": "Persistent", + "creating": "Creating", + "restoring": "Restoring", + "running": "Running", + "error": "Error", + "stopping": "Stopping", + "stopped": "Stopped", + "unknownStatus": "Unknown status: {status}" + }, + "headers": { + "sandbox": "Sandbox", + "provider": "Provider", + "capabilities": "Capabilities", + "status": "Status", + "lastUsed": "Last used", + "actions": "Actions" + }, + "fields": { + "provider": "Provider", + "capabilities": "Capabilities", + "toolNames": "Tools", + "status": "Status", + "owner": "Owner session", + "controller": "Controller session", + "retentionPolicy": "Retention policy", + "occupiedUntil": "Occupied until", + "idleCleanupAt": "Idle cleanup at", + "expiresAt": "Expiration cleanup at", + "connectInfo": "Connection info" + }, + "empty": { + "title": "No managed sandboxes", + "subtitle": "Create a managed sandbox or wait for a provider to register one." + }, + "create": { + "title": "Create sandbox", + "name": "Sandbox name", + "providerHint": "This provider is visible in the dashboard model but creation is not wired in this phase." + }, + "screenshot": { + "title": "Sandbox screenshot", + "noPreview": "Screenshot captured. Preview rendering will be added later." + }, + "console": { + "title": "Sandbox console", + "notice": "The console does not change sandbox occupancy, but commands directly affect this sandbox environment.", + "command": "Shell command", + "run": "Run command", + "output": "Command output", + "empty": "No commands executed yet.", + "running": "Running...", + "dangerConfirm": "This command may destroy sandbox data: {command}\nContinue?" + }, + "config": { + "title": "Sandbox configuration", + "name": "Sandbox name", + "nameRequired": "Sandbox name is required", + "idleTimeout": "Idle timeout (seconds)", + "idleTimeoutHint": "With the temporary policy, sandboxes can be cleaned after this idle duration. Empty or 0 disables idle cleanup.", + "expiresAt": "Fixed expiration time", + "expiresAtHint": "Optional. The sandbox can be cleaned after this time." + }, + "tooltips": { + "takeover": "Bind the current dashboard session to this sandbox and obtain control. If another controller exists, control is transferred.", + "console": "Open a shell console without changing the sandbox occupancy." + }, + "destroyConfirm": { + "title": "Destroy sandbox?", + "message": "Destroy {name}? This shuts down the sandbox and removes its managed record." + }, + "messages": { + "loadFailed": "Failed to load sandboxes.", + "operationFailed": "Sandbox operation failed.", + "created": "Sandbox created.", + "createSubmitted": "Sandbox creation request submitted.", + "createReady": "Sandbox is now running.", + "createFailed": "Sandbox creation failed.", + "createUnknown": "Sandbox state is unknown. Refresh manually or inspect the backend.", + "createRefreshUnstable": "Sandbox status refresh became unstable during creation. Refresh manually later.", + "createNotVisible": "Sandbox {sandboxId} is not visible in the list yet. Refresh manually later.", + "createTimedOut": "Sandbox {sandboxId} is still processing. Refresh manually later.", + "createUnexpectedStatus": "Sandbox creation ended with status: {status}.", + "maxSandboxesReached": "Достигнут максимальный лимит песочниц: {max}. Сначала освободите или удалите ненужную песочницу.", + "defaultSet": "Default sandbox updated.", + "configSaved": "Sandbox configuration saved.", + "released": "Sandbox occupancy released.", + "takeover": "Current session now controls this sandbox.", + "destroyed": "Sandbox destroyed.", + "destroyRefreshUnstable": "Sandbox status refresh became unstable during destroy. Refresh manually later.", + "destroyTimedOut": "Sandbox {sandboxId} is still being destroyed. Refresh manually later.", + "screenshot": "Screenshot captured." + } +} diff --git a/dashboard/src/i18n/locales/ru-RU/features/tool-use.json b/dashboard/src/i18n/locales/ru-RU/features/tool-use.json index a4e2a583e8..8cbc9076d3 100644 --- a/dashboard/src/i18n/locales/ru-RU/features/tool-use.json +++ b/dashboard/src/i18n/locales/ru-RU/features/tool-use.json @@ -55,6 +55,11 @@ "permissionMember": "Все", "permissionEveryone": "Все", "permissionBuiltin": "Встроенный", + "originBuiltin": "Встроенный", + "originMcp": "MCP", + "originPlugin": "Плагин", + "originSandbox": "Песочница", + "originUnknown": "Неизвестно", "actions": "Действия" }, "configTags": { diff --git a/dashboard/src/i18n/locales/zh-CN/core/navigation.json b/dashboard/src/i18n/locales/zh-CN/core/navigation.json index 163ab35149..30ce11ac0f 100644 --- a/dashboard/src/i18n/locales/zh-CN/core/navigation.json +++ b/dashboard/src/i18n/locales/zh-CN/core/navigation.json @@ -20,6 +20,7 @@ "cron": "未来任务", "conversation": "对话数据", "sessionManagement": "自定义规则", + "sandboxes": "沙盒管理", "console": "平台日志", "trace": "追踪", "alkaid": "Alkaid", diff --git a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json index 75ce4fd931..991ca9028f 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json +++ b/dashboard/src/i18n/locales/zh-CN/features/config-metadata.json @@ -158,11 +158,11 @@ }, "agent_computer_use": { "description": "使用电脑能力", - "hint": "让 AstrBot 访问和使用本机环境或者隔离的沙盒环境,以执行更复杂的任务。详见: [电脑使用](https://docs.astrbot.app/use/computer.html)。", + "hint": "让 AstrBot 使用本机环境或隔离沙盒来执行 Shell、Python、文件操作等任务。详见:[电脑使用](https://docs.astrbot.app/use/computer.html)。", "provider_settings": { "computer_use_runtime": { "description": "运行环境", - "hint": "允许 Agent 访问的环境。local 为本机环境,sandbox 为沙箱环境,none 为不允许任何环境。" + "hint": "Agent 可以使用的执行环境。local 表示本机环境,sandbox 表示隔离沙盒,none 表示不开放电脑能力。" }, "computer_use_require_admin": { "description": "需要 AstrBot 管理员权限", @@ -170,63 +170,41 @@ }, "sandbox": { "booter": { - "description": "沙箱环境驱动器" + "description": "默认沙盒驱动" }, - "shipyard_neo_endpoint": { - "description": "Shipyard Neo API Endpoint", - "hint": "Shipyard Neo(Bay) 服务的 API 地址,默认 http://127.0.0.1:8114。" + "sandbox_lease_timeout": { + "description": "沙盒占用超时", + "hint": "单位为秒。每次 Agent 成功访问沙盒时,都会自动将本会话的沙盒租约续到当前时间 + 此时长。默认 600 秒;到期后当前会话不再绑定该沙盒,其他会话可接管。`0` 表示租约不会自动过期,需手动释放。" }, - "shipyard_neo_access_token": { - "description": "Shipyard Neo 访问令牌", - "hint": "Bay 的 API Key(sk-bay-...)。留空时自动从 credentials.json 发现。" + "sandbox_idle_timeout": { + "description": "沙盒空闲清理时间", + "hint": "单位为秒。沙盒释放后开始计算,超过这个时间仍未被再次占用就会自动清理,默认 1800 秒。`0` 表示不启用空闲清理。" }, - "shipyard_neo_profile": { - "description": "Shipyard Neo Profile", - "hint": "Shipyard Neo 沙箱 profile,例如 python-default。留空时自动选择能力更完整的 profile。" + "sandbox_ttl": { + "description": "沙盒存活时间", + "hint": "单位为秒。仅在空闲清理为 `0` 时生效,从创建开始计时,到期后强制清理。`0` 表示不自动过期。" }, - "shipyard_neo_ttl": { - "description": "Shipyard Neo Sandbox 存活时间(秒)", - "hint": "Shipyard Neo 沙箱的生存时间(秒)。" + "max_sandboxes": { + "description": "最大沙盒数量", + "hint": "全局托管沙盒数量上限,默认 10。`0` 表示不限制。" }, - "cua_image": { - "description": "CUA 镜像", - "hint": "CUA 沙箱镜像/系统类型,默认 linux。可填写 linux、macos、windows、android,具体取决于 CUA SDK 支持。" - }, - "cua_os_type": { - "description": "CUA 操作系统类型", - "hint": "CUA 沙箱操作系统类型,默认 linux。" - }, - "cua_idle_timeout": { - "description": "CUA 空闲超时(秒)", - "hint": "CUA 沙箱空闲超时时间(秒)。大于 0 时,AstrBot 会在会话空闲达到该时长后主动关闭 CUA 沙箱;0 表示禁用。" - }, - "cua_telemetry_enabled": { - "description": "CUA 遥测", - "hint": "是否允许 CUA SDK 发送遥测数据。默认关闭。" - }, - "cua_local": { - "description": "CUA 本地沙箱", - "hint": "是否优先使用 CUA 本地沙箱。默认开启,避免云端沙箱要求 CUA_API_KEY。关闭后可使用 CUA 云端沙箱。" - }, - "cua_api_key": { - "description": "CUA API Key", - "hint": "CUA 云端沙箱 API Key。仅在关闭本地沙箱时需要。也可以通过 CUA_API_KEY 环境变量提供。" - }, - "shipyard_endpoint": { - "description": "Shipyard API Endpoint", - "hint": "Shipyard 服务的 API 访问地址。" - }, - "shipyard_access_token": { - "description": "Shipyard 访问令牌", - "hint": "用于访问 Shipyard 服务的访问令牌。" - }, - "shipyard_ttl": { - "description": "Shipyard Ship 存活时间(秒)", - "hint": "Shipyard 会话的生存时间(秒)。" - }, - "shipyard_max_sessions": { - "description": "Shipyard Ship 会话复用上限", - "hint": "决定了一个实例承载的最大会话数量。" + "member_permissions": { + "create": { + "description": "允许普通用户创建沙盒", + "hint": "允许普通用户创建新的托管沙盒。创建请求仍会受到最大沙盒数量限制。" + }, + "set_retention_policy": { + "description": "允许普通用户修改沙盒保留策略", + "hint": "允许普通用户在临时沙盒和持久沙盒之间切换。持久沙盒会保留环境,方便后续复用。" + }, + "takeover": { + "description": "允许普通用户接管沙盒", + "hint": "允许普通用户接管被其他会话占用的沙盒。此操作会转移占用权,建议谨慎开启。" + }, + "destroy": { + "description": "允许普通用户删除沙盒", + "hint": "允许普通用户删除自己可访问的托管沙盒。删除后,沙盒环境和对应记录都会被移除。" + } } } } diff --git a/dashboard/src/i18n/locales/zh-CN/features/sandbox.json b/dashboard/src/i18n/locales/zh-CN/features/sandbox.json new file mode 100644 index 0000000000..92eac1e395 --- /dev/null +++ b/dashboard/src/i18n/locales/zh-CN/features/sandbox.json @@ -0,0 +1,126 @@ +{ + "title": "沙盒管理", + "subtitle": "统一查看和操作不同沙盒驱动下的托管沙盒。", + "actions": { + "create": "创建", + "refresh": "刷新", + "inspect": "查看", + "setDefault": "设为默认", + "configure": "配置", + "switch": "切换", + "takeover": "接管", + "console": "控制台", + "release": "释放", + "screenshot": "截图", + "destroy": "销毁", + "cancel": "取消", + "close": "关闭", + "save": "保存" + }, + "metrics": { + "total": "托管沙盒", + "providers": "驱动数量", + "busy": "占用中", + "default": "默认沙盒" + }, + "labels": { + "default": "默认", + "busy": "占用中", + "available": "空闲", + "noController": "未被占用", + "unknown": "未知", + "none": "无", + "temporary": "临时", + "persistent": "持久", + "creating": "创建中", + "restoring": "恢复中", + "running": "运行中", + "error": "异常", + "stopping": "销毁中", + "stopped": "已停止", + "unknownStatus": "未知状态:{status}" + }, + "headers": { + "sandbox": "沙盒", + "provider": "驱动", + "capabilities": "能力", + "status": "状态", + "lastUsed": "最后使用", + "actions": "操作" + }, + "fields": { + "provider": "驱动", + "capabilities": "能力", + "toolNames": "工具", + "status": "状态", + "owner": "所属会话", + "controller": "占用会话", + "retentionPolicy": "保留策略", + "occupiedUntil": "占用到期时间", + "idleCleanupAt": "空闲清理时间", + "expiresAt": "到期清理时间", + "connectInfo": "连接信息" + }, + "empty": { + "title": "暂无托管沙盒", + "subtitle": "可以创建新的托管沙盒,或先安装并启用沙盒驱动。" + }, + "create": { + "title": "创建沙盒", + "name": "沙盒名称", + "providerHint": "这个沙盒驱动已经可以展示在管理页,但暂时还不能从这里创建新沙盒。" + }, + "screenshot": { + "title": "沙盒截图", + "noPreview": "截图已完成。预览渲染将在后续补齐。" + }, + "console": { + "title": "沙盒控制台", + "notice": "控制台不会接管沙盒,但命令会直接影响里面的环境。", + "command": "Shell 命令", + "run": "执行命令", + "output": "执行结果", + "empty": "尚未执行命令。", + "running": "执行中...", + "dangerConfirm": "该命令可能破坏沙盒数据:{command}\n确定继续执行吗?" + }, + "config": { + "title": "沙盒配置", + "name": "沙盒名称", + "nameRequired": "沙盒名称不能为空", + "idleTimeout": "空闲超时(秒)", + "idleTimeoutHint": "临时策略下,沙盒在空闲达到该时间后可被清理,0 或空表示不启用空闲清理。", + "expiresAt": "固定过期时间", + "expiresAtHint": "可选。到达该时间后沙盒可被清理。" + }, + "tooltips": { + "takeover": "让当前控制台会话占用这个沙盒;如果它已经被别的会话占用,会转移占用权。", + "console": "打开命令行控制台,不接管沙盒。" + }, + "destroyConfirm": { + "title": "确认销毁沙盒", + "message": "确定要销毁 {name} 吗?这会关闭沙盒,并删除对应的托管记录。" + }, + "messages": { + "loadFailed": "加载沙盒失败。", + "operationFailed": "沙盒操作失败。", + "created": "沙盒已创建。", + "createSubmitted": "创建请求已提交。", + "createReady": "沙盒已进入运行中。", + "createFailed": "沙盒创建失败。", + "createUnknown": "沙盒状态未知,请手动刷新或检查后端。", + "createRefreshUnstable": "创建期间刷新状态不稳定,请稍后手动刷新确认。", + "createNotVisible": "沙盒 {sandboxId} 暂未出现在列表中,请稍后手动刷新确认。", + "createTimedOut": "沙盒 {sandboxId} 仍在处理中,请稍后手动刷新确认。", + "createUnexpectedStatus": "沙盒创建流程结束,当前状态:{status}。", + "maxSandboxesReached": "托管沙盒数量已达到上限:{max}。请先释放或销毁不需要的沙盒。", + "defaultSet": "默认沙盒已更新。", + "configSaved": "沙盒配置已保存。", + "released": "沙盒占用已释放。", + "takeover": "当前会话已接管此沙盒。", + "destroyed": "沙盒已销毁。", + "destroyRefreshUnstable": "销毁期间刷新状态不稳定,请稍后手动刷新确认。", + "destroyTimedOut": "沙盒 {sandboxId} 仍在销毁中,请稍后手动刷新确认。", + "screenshot": "截图已完成。" + } +} diff --git a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json index a2f1402618..6a8444cd4f 100644 --- a/dashboard/src/i18n/locales/zh-CN/features/tool-use.json +++ b/dashboard/src/i18n/locales/zh-CN/features/tool-use.json @@ -55,6 +55,11 @@ "permissionMember": "全部用户", "permissionEveryone": "全部用户", "permissionBuiltin": "系统内置", + "originBuiltin": "内置", + "originMcp": "MCP", + "originPlugin": "插件", + "originSandbox": "沙盒", + "originUnknown": "未知", "actions": "操作" }, "configTags": { diff --git a/dashboard/src/i18n/translations.ts b/dashboard/src/i18n/translations.ts index 9d27b41d7f..63fe69da87 100644 --- a/dashboard/src/i18n/translations.ts +++ b/dashboard/src/i18n/translations.ts @@ -37,6 +37,7 @@ import zhCNPersona from './locales/zh-CN/features/persona.json'; import zhCNMigration from './locales/zh-CN/features/migration.json'; import zhCNCommand from './locales/zh-CN/features/command.json'; import zhCNSubagent from './locales/zh-CN/features/subagent.json'; +import zhCNSandbox from './locales/zh-CN/features/sandbox.json'; import zhCNWelcome from './locales/zh-CN/features/welcome.json'; import zhCNErrors from './locales/zh-CN/messages/errors.json'; @@ -79,6 +80,7 @@ import enUSPersona from './locales/en-US/features/persona.json'; import enUSMigration from './locales/en-US/features/migration.json'; import enUSCommand from './locales/en-US/features/command.json'; import enUSSubagent from './locales/en-US/features/subagent.json'; +import enUSSandbox from './locales/en-US/features/sandbox.json'; import enUSWelcome from './locales/en-US/features/welcome.json'; import enUSErrors from './locales/en-US/messages/errors.json'; @@ -121,6 +123,7 @@ import ruRUPersona from './locales/ru-RU/features/persona.json'; import ruRUMigration from './locales/ru-RU/features/migration.json'; import ruRUCommand from './locales/ru-RU/features/command.json'; import ruRUSubagent from './locales/ru-RU/features/subagent.json'; +import ruRUSandbox from './locales/ru-RU/features/sandbox.json'; import ruRUWelcome from './locales/ru-RU/features/welcome.json'; import ruRUErrors from './locales/ru-RU/messages/errors.json'; @@ -171,6 +174,7 @@ export const translations = { migration: zhCNMigration, command: zhCNCommand, subagent: zhCNSubagent, + sandbox: zhCNSandbox, welcome: zhCNWelcome }, messages: { @@ -221,6 +225,7 @@ export const translations = { migration: enUSMigration, command: enUSCommand, subagent: enUSSubagent, + sandbox: enUSSandbox, welcome: enUSWelcome }, messages: { @@ -271,6 +276,7 @@ export const translations = { migration: ruRUMigration, command: ruRUCommand, subagent: ruRUSubagent, + sandbox: ruRUSandbox, welcome: ruRUWelcome }, messages: { diff --git a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts index f4359e4b5e..2cbc551f56 100644 --- a/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts +++ b/dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts @@ -119,6 +119,11 @@ const sidebarItem: menu[] = [ icon: 'mdi-vector-link', to: '/subagent' }, + { + title: 'core.navigation.sandboxes', + icon: 'mdi-cube-outline', + to: '/sandboxes' + }, { title: 'core.navigation.dashboard', icon: 'mdi-view-dashboard', diff --git a/dashboard/src/router/MainRoutes.ts b/dashboard/src/router/MainRoutes.ts index acb80d720a..335467de5a 100644 --- a/dashboard/src/router/MainRoutes.ts +++ b/dashboard/src/router/MainRoutes.ts @@ -89,6 +89,11 @@ const MainRoutes = { path: '/subagent', component: () => import('@/views/SubAgentPage.vue') }, + { + name: 'Sandboxes', + path: '/sandboxes', + component: () => import('@/views/SandboxManagementPage.vue') + }, { name: 'CronJobs', path: '/cron', diff --git a/dashboard/src/views/SandboxManagementPage.vue b/dashboard/src/views/SandboxManagementPage.vue new file mode 100644 index 0000000000..63c5e54651 --- /dev/null +++ b/dashboard/src/views/SandboxManagementPage.vue @@ -0,0 +1,1402 @@ + + + + + + {{ tm('title') }} + {{ tm('subtitle') }} + + + + {{ tm('actions.create') }} + + + {{ tm('actions.refresh') }} + + + + + + + {{ tm('metrics.total') }} + {{ sandboxes.length }} + + + {{ tm('metrics.providers') }} + {{ providerCount }} + + + {{ tm('metrics.busy') }} + {{ busyCount }} + + + {{ tm('metrics.default') }} + {{ defaultCount }} + + + + + + + + + + {{ item.sandbox_name || item.sandbox_id }} + {{ tm('labels.default') }} + + {{ item.sandbox_id }} + + + + + + {{ item.provider || tm('labels.unknown') }} + + + + + + + {{ capability }} + + - + + + + + + + {{ statusLabel(item) }} + + {{ item.controller_session_id }} + + + + + + + {{ formatTime(item.last_used_at) }} + + + + + {{ tm('actions.inspect') }} + {{ tm('actions.setDefault') }} + {{ tm('actions.configure') }} + + {{ tm('actions.console') }} + {{ tm('tooltips.console') }} + + {{ tm('actions.release') }} + {{ tm('actions.screenshot') }} + {{ tm('actions.destroy') }} + + + + + + mdi-cube-outline + {{ tm('empty.title') }} + {{ tm('empty.subtitle') }} + + + + + + + + + + {{ selectedSandboxRecord.sandbox_name || selectedSandboxRecord.sandbox_id }} + {{ selectedSandboxRecord.sandbox_id }} + + + + + + + + + + + + + + + {{ tm('fields.connectInfo') }} + {{ JSON.stringify(selectedSandboxRecord.connect_info || {}, null, 2) }} + + + + + + + + + {{ tm('console.title') }} + {{ consoleSandbox.sandbox_name || consoleSandbox.sandbox_id }} + + + + {{ hasController(consoleSandbox) ? tm('labels.busy') : tm('labels.available') }} + + + + + + + {{ tm('console.notice') }} + + + + + {{ consoleSandbox.provider || 'sandbox' }} + {{ consoleSandbox.controller_session_id || tm('labels.noController') }} + + + {{ tm('console.empty') }} + + {{ displayConsoleCwd(entry.cwd) }} $ {{ entry.command }} + {{ entry.stdout }} + {{ entry.stderr }} + {{ tm('console.running') }} + exit_code: {{ entry.exitCode }} + + + + {{ displayConsoleCwd(consoleCwd) }} $ + + + + + + + + + + {{ tm('create.title') }} + + + + + + + {{ tm('actions.cancel') }} + + {{ tm('actions.create') }} + + + + + + + + {{ tm('config.title') }} + + + + + + + + + + + + {{ tm('actions.cancel') }} + {{ tm('actions.save') }} + + + + + + + {{ tm('screenshot.title') }} + + + + + + {{ tm('screenshot.noPreview') }} + + + + + {{ tm('actions.close') }} + + + + + + + {{ tm('destroyConfirm.title') }} + + + {{ tm('destroyConfirm.message', { name: destroySandboxTarget?.sandbox_name || destroySandboxTarget?.sandbox_id || '-' }) }} + + + + + {{ tm('actions.cancel') }} + {{ tm('actions.destroy') }} + + + + + + {{ snackbar.message }} + + {{ tm('actions.close') }} + + + + + + + + diff --git a/dashboard/src/views/sandbox/consoleUtils.ts b/dashboard/src/views/sandbox/consoleUtils.ts new file mode 100644 index 0000000000..87c1b1eebe --- /dev/null +++ b/dashboard/src/views/sandbox/consoleUtils.ts @@ -0,0 +1,11 @@ +const DANGEROUS_CONSOLE_COMMAND_PATTERNS = [ + /(^|[;&|]\s*)rm\s+(?:-[\w-]*r[\w-]*f[\w-]*|-[\w-]*f[\w-]*r[\w-]*)\s+(?:--\s+)?(?:\/(?:\S*)?|~(?:\S*)?|\$HOME(?:\S*)?)(?:\s|$)/i, + /(^|[;&|]\s*)mkfs(?:\.[\w-]+)?\s+/, + /(^|[;&|]\s*)dd\s+[^\n]*(?:of=\/dev\/|of=\/)/, + /(^|[;&|]\s*):\(\)\s*\{\s*:\|:\s*&\s*\}\s*;/, +] + +export function isDangerousConsoleCommand(command: string) { + const normalized = command.trim() + return DANGEROUS_CONSOLE_COMMAND_PATTERNS.some((pattern) => pattern.test(normalized)) +} diff --git a/dashboard/src/views/sandbox/types.ts b/dashboard/src/views/sandbox/types.ts new file mode 100644 index 0000000000..76a01046f3 --- /dev/null +++ b/dashboard/src/views/sandbox/types.ts @@ -0,0 +1,53 @@ +export type SandboxRecord = { + sandbox_id: string + sandbox_name?: string + provider?: string + managed?: boolean + created_by_astrbot?: boolean + is_default?: boolean + owner_session_id?: string | null + controller_session_id?: string | null + lease_expires_at?: number | null + last_used_at?: number | null + idle_timeout?: number | null + idle_cleanup_at?: number | null + expires_at?: number | null + retention_policy?: string | null + status?: string + connect_info?: Record + capabilities?: string[] + tool_names?: string[] +} + +export type LoadSandboxesResult = { + ok: boolean + records: SandboxRecord[] + error?: string +} + +export type ProviderOption = { + title: string + value: string +} + +export type SandboxProviderInfo = { + provider_id: string +} + +export type SandboxAction = + | 'setDefault' + | 'configure' + | 'console' + | 'release' + | 'screenshot' + | 'destroy' + +export type ConsoleHistoryEntry = { + id: number + cwd: string + command: string + stdout: string + stderr: string + exitCode: unknown + running?: boolean +} diff --git a/docs/.vitepress/config.mjs b/docs/.vitepress/config.mjs index cd62409ca9..fd087ca9ae 100644 --- a/docs/.vitepress/config.mjs +++ b/docs/.vitepress/config.mjs @@ -191,6 +191,7 @@ export default defineConfig({ { text: "接收消息事件", link: "/guides/listen-message-event" }, { text: "发送消息", link: "/guides/send-message" }, { text: "插件配置", link: "/guides/plugin-config" }, + { text: "沙盒运行时插件", link: "/guides/sandbox-runtime" }, { text: "插件 Pages", link: "/guides/plugin-pages" }, { text: "插件国际化", link: "/guides/plugin-i18n" }, { text: "调用 AI", link: "/guides/ai" }, @@ -436,6 +437,7 @@ export default defineConfig({ { text: "Listen to Message Events", link: "/guides/listen-message-event" }, { text: "Send Messages", link: "/guides/send-message" }, { text: "Plugin Configuration", link: "/guides/plugin-config" }, + { text: "Sandbox Runtime Plugin", link: "/guides/sandbox-runtime" }, { text: "Plugin Pages", link: "/guides/plugin-pages" }, { text: "Plugin Internationalization", link: "/guides/plugin-i18n" }, { text: "AI", link: "/guides/ai" }, diff --git a/docs/en/dev/star/guides/plugin-config.md b/docs/en/dev/star/guides/plugin-config.md index cf05c94818..566e451d69 100644 --- a/docs/en/dev/star/guides/plugin-config.md +++ b/docs/en/dev/star/guides/plugin-config.md @@ -234,3 +234,5 @@ class ConfigPlugin(Star): ## Configuration Updates When you update the Schema across different versions, AstrBot will recursively inspect the configuration items in the Schema, automatically adding default values for missing items and removing those that no longer exist. + +Note that `default` is only applied when creating a new config file or when a field is missing from an existing config. If a field already exists in `data/config/_config.json`, changing the Schema `default` later will not overwrite that saved value. This is intentional so plugin upgrades do not silently replace user-edited settings. diff --git a/docs/en/dev/star/guides/sandbox-runtime.md b/docs/en/dev/star/guides/sandbox-runtime.md new file mode 100644 index 0000000000..a354d5c1ff --- /dev/null +++ b/docs/en/dev/star/guides/sandbox-runtime.md @@ -0,0 +1,166 @@ +# Building a Sandbox Runtime Plugin + +A sandbox runtime plugin teaches AstrBot how to start and connect to a sandbox service. The plugin usually contains a provider, a booter/client, a config schema, and optional tools for features such as screenshots or browser control. + +If you are migrating an existing runtime, focus on the config boundary first: AstrBot Core handles routing, reuse, and cleanup, while the plugin owns the actual endpoint, token, image, and timeout settings. + +Start with this structure: + +```text +data/plugins// + main.py + metadata.yaml + _conf_schema.json + provider.py + booters/ + tools/ +``` + +Use the files like this: + +- `main.py`: register the provider, and register any extra tools. +- `provider.py`: adapt your runtime to AstrBot's sandbox provider methods. +- `booters/`: put the client code that starts, connects to, and shuts down the sandbox. +- `tools/`: add optional runtime tools such as screenshot, mouse, keyboard, browser, or lifecycle helpers. +- `_conf_schema.json`: define the settings shown in WebUI. See [Plugin Configuration](./plugin-config.md) for the schema format. + +## 1. Register the provider + +In `main.py`, create your provider and register it when the plugin loads. Pass the plugin config into the provider so `provider.py` can read values from `_conf_schema.json`. + +```python +from astrbot.api.star import Context, Star, register +from astrbot.core.computer.computer_client import ( + register_sandbox_provider, + unregister_sandbox_provider, +) + +from .provider import MySandboxProvider + + +@register("astrbot_sandbox_demo", "AstrBot Team", "Demo sandbox provider", "0.1.0") +class DemoSandboxPlugin(Star): + def __init__(self, context: Context, config=None) -> None: + super().__init__(context) + self.provider = MySandboxProvider() + self.provider.plugin_config = config or {} + register_sandbox_provider(self.provider, replace=True) + + async def terminate(self) -> None: + unregister_sandbox_provider(self.provider.provider_id, force=True) +``` + +Use a stable name for the plugin directory, `metadata.yaml`, and `@register(...)`. Keeping them aligned makes the generated config file easy to find. + +## 2. Implement `provider.py` + +AstrBot calls the provider whenever it needs to create, reuse, rename, or destroy a sandbox. Implement these fields and methods: + +- `provider_id` +- `capabilities` +- `tool_names` +- `system_prompt` +- `build_create_config(context, session_id)` +- `build_connect_info(sandbox_name, config)` +- `update_connect_info(record, *, sandbox_name)` +- `create_booter(context, session_id, sandbox_id, config)` +- `destroy_booter(booter, record)` + +## 2.1 How config migration works + +If older versions already stored settings in `provider_settings.sandbox`, treat that as a compatibility input and move new editable values into the plugin config: + +- Put new user-facing fields in `_conf_schema.json` first. +- Use `build_create_config()` to merge plugin config with any legacy overrides. +- Keep `provider_settings.sandbox` as a transition layer only. +- Treat `tool_names`, `capabilities`, and `system_prompt` as runtime capability declarations rather than user config entries. + +This is a minimal provider skeleton: + +```python +class MySandboxProvider: + provider_id = "demo" + capabilities = {"shell", "python", "filesystem"} + tool_names = set() + system_prompt = ( + "When using this sandbox provider, follow its runtime-specific path, " + "GUI, browser, and lifecycle rules." + ) + + def build_create_config(self, context, session_id): + config = context.get_config(umo=session_id) + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + plugin_cfg = getattr(self, "plugin_config", None) or {} + return { + "endpoint_url": sandbox_cfg.get( + "demo_endpoint", plugin_cfg.get("demo_endpoint", "") + ), + "ttl": sandbox_cfg.get("demo_ttl", plugin_cfg.get("demo_ttl", 3600)), + } + + def build_connect_info(self, sandbox_name, config): + return {"name": sandbox_name, **config} + + def update_connect_info(self, record, *, sandbox_name): + info = dict(record.get("connect_info") or {}) + info["name"] = sandbox_name + return info + + async def create_booter(self, context, session_id, sandbox_id, config): + booter = MyBooter(**config) + await booter.boot(session_id) + return booter + + async def destroy_booter(self, booter, record): + await booter.shutdown() +``` + +## 3. Add runtime config + +Create `_conf_schema.json` for values that users should edit in WebUI, such as API endpoints, access tokens, profiles, image names, or timeouts. + +```json +{ + "demo_endpoint": { + "description": "Demo API Endpoint", + "type": "string", + "default": "", + "hint": "API endpoint for the demo sandbox service." + }, + "demo_ttl": { + "description": "Sandbox TTL", + "type": "int", + "default": 3600, + "hint": "Sandbox lifetime in seconds." + } +} +``` + +AstrBot stores the saved values in `data/config/_config.json` and passes them to the plugin constructor as `config`. + +If your provider still supports older values under `provider_settings.sandbox`, read them in `build_create_config()` as overrides on top of the plugin config. New provider settings should normally live in `_conf_schema.json`. + +## 4. Add optional tools + +If your runtime exposes extra abilities, register those tools in the plugin and list the tool names in `provider.tool_names`. + +Common examples: + +- screenshot tools +- mouse / keyboard tools +- browser tools +- runtime-specific lifecycle helpers + +AstrBot uses `tool_names` when mounting tools in sandbox mode. Make sure the names match the tools you register in `main.py`. + +Use `system_prompt` as provider metadata for stable runtime-specific instructions. Core exposes it through provider info so dashboards or higher-level integrations can show the provider rules, but it is not automatically appended to every model request. + +## 5. Try it locally + +After adding the plugin under `data/plugins//`, start AstrBot and check these items: + +- The plugin loads without import errors. +- The WebUI config page shows fields from `_conf_schema.json`. +- The sandbox runtime selector includes your `provider_id`. +- Creating a sandbox calls `create_booter()`. +- Stopping or unloading the plugin calls `terminate()` and unregisters the provider. diff --git a/docs/en/use/astrbot-agent-sandbox.md b/docs/en/use/astrbot-agent-sandbox.md index 31115d20de..47a112fbb0 100644 --- a/docs/en/use/astrbot-agent-sandbox.md +++ b/docs/en/use/astrbot-agent-sandbox.md @@ -3,32 +3,140 @@ > [!TIP] > This feature is currently in technical preview and may have some bugs. If you encounter any issues, please submit an issue on [GitHub](https://github.com/AstrBotDevs/AstrBot/issues). -Starting from version `v4.12.0`, AstrBot introduced the Agent sandbox environment to replace the previous code executor functionality. The sandbox environment provides Agents with safer and more flexible code execution and automation capabilities. +Starting from version `v4.12.0`, AstrBot introduced the Agent sandbox environment to replace the previous code executor functionality. It lets Agents run shell, Python, file, and desktop automation tasks in an isolated environment instead of directly on the AstrBot host.  -## Enabling the Sandbox Environment +## Installing and Enabling the Sandbox Environment + +If you are migrating existing settings, start with the config mapping. Sandbox drivers are now shipped as separate plugins, and AstrBot Core only handles routing, reuse, and cleanup. What you need to update is the `Computer Use Runtime`, the `Sandbox Driver`, and the driver-specific settings. + +Starting with the current version, concrete sandbox drivers such as `Shipyard Neo`, `BoxLite`, `Shipyard`, and `CUA` are shipped as **separate plugins**, not built into AstrBot Core by default. + +That means enabling sandbox mode always has two steps: + +1. Install the sandbox plugin you want to use +2. Then select and configure that driver in the AstrBot WebUI + +Install and load the sandbox driver plugin first. The driver only appears in the WebUI after its plugin is installed and loaded. If you only switch `Computer Use Runtime` to `sandbox` without installing a matching plugin, there will be no driver to select or configure. AstrBot currently supports the following sandbox drivers: -- `Shipyard Neo` (recommended) -- `Shipyard` (legacy option, still supported) +- [`Shipyard Neo`](https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo) (recommended for long-running and multi-user usage) +- [`BoxLite`](https://github.com/AstrBotDevs/astrbot_sandbox_boxlite) (a lightweight local sandbox for shell, Python, and file operations only) +- [`Shipyard`](https://github.com/AstrBotDevs/astrbot_sandbox_shipyard) (legacy option, still supported) +- [`CUA`](https://github.com/AstrBotDevs/astrbot_sandbox_cua) (local or cloud computer-use sandbox, suitable for desktop interaction tasks) + +If you just want to get started, choose based on the task: + +- Use `Shipyard Neo` for stable shell, Python, filesystem access, and Skills sync. +- Use `BoxLite` for a lightweight local runtime. +- Use `CUA` when you need screenshots, mouse clicks, or keyboard input. +- Keep using `Shipyard` only if you already have an old deployment. + +Recommended installation method: open the WebUI plugin management page, click the `+` button in the lower-right corner, choose URL installation, and enter the plugin repository URL: + +```text +https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo +https://github.com/AstrBotDevs/astrbot_sandbox_boxlite +https://github.com/AstrBotDevs/astrbot_sandbox_shipyard +https://github.com/AstrBotDevs/astrbot_sandbox_cua +``` + +If the plugin panel is not available in your deployment, you can manually install plugins under `data/plugins`. This is not recommended as the default path, because updates, enable/disable actions, and status checks are easier from the plugin panel. + +```bash +git clone https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo.git data/plugins/astrbot_sandbox_shipyard_neo +git clone https://github.com/AstrBotDevs/astrbot_sandbox_boxlite.git data/plugins/astrbot_sandbox_boxlite +git clone https://github.com/AstrBotDevs/astrbot_sandbox_shipyard.git data/plugins/astrbot_sandbox_shipyard +git clone https://github.com/AstrBotDevs/astrbot_sandbox_cua.git data/plugins/astrbot_sandbox_cua +``` -In the current AstrBot console, go to **AI Settings** -> **Agent Computer Use** and select: +After installation, restart AstrBot or reload plugins from the plugin management page. + +Then open the AstrBot console, go to **AI Settings** -> **Agent Computer Use**, and select: - `Computer Use Runtime` = `sandbox` -- `Sandbox Driver` = `Shipyard Neo` or `Shipyard` +- `Sandbox Driver` = `Shipyard Neo`, `BoxLite`, `Shipyard`, or `CUA` -`Shipyard Neo` is now the default driver. It consists of Bay, Ship, and Gull: +`Shipyard Neo` is the recommended default driver. It consists of Bay, Ship, and Gull: - **Bay**: the control-plane API responsible for creating and managing sandboxes - **Ship**: provides Python / Shell / filesystem capabilities - **Gull**: provides browser automation capabilities -For `Shipyard Neo`, the workspace root is fixed at `/workspace`. When using filesystem tools in AstrBot, you should pass **paths relative to the workspace root**, for example `reports/result.txt`, not `/workspace/reports/result.txt`. +For `Shipyard Neo`, the workspace root is fixed at `/workspace`. When using filesystem tools in AstrBot, pass **paths relative to the workspace root**, for example `reports/result.txt`, not `/workspace/reports/result.txt`. > [!TIP] -> Browser capability is not available in every `Shipyard Neo` profile. AstrBot only mounts browser-related tools when the selected profile supports the `browser` capability. A typical example is `browser-python`. +> Browser capability is not available in every `Shipyard Neo` profile. AstrBot only mounts browser-related tools when the selected profile supports the `browser` capability. A common example is `browser-python`. + +## Managed sandboxes, leases, and retention + +After sandbox mode is enabled, the WebUI sandbox page shows the sandboxes managed by AstrBot. These terms are worth keeping separate: + +- **Managed sandbox**: a sandbox recorded and managed by AstrBot. It may come from `Shipyard Neo`, `BoxLite`, `CUA`, or legacy `Shipyard`. +- **Default sandbox**: the sandbox AstrBot tries to reuse first for a driver. +- **Occupied**: a session is currently controlling the sandbox. Other sessions cannot use it directly unless takeover is allowed. +- **Lease**: how long the current session keeps control. The default is 600 seconds. Agents can renew an active lease, and later normal tool calls will not shorten a longer active lease. When the lease expires, the session no longer has a current sandbox; the agent must list sandboxes and switch, take over, or create one before continuing sandbox work. +- **Temporary sandbox**: can be cleaned up after it is released and stays idle or reaches its expiry time. +- **Persistent sandbox**: keeps its environment for reuse. It can still be occupied or released, but release alone does not delete it. + +In the sandbox page, `Last used` means the last time AstrBot occupied, renewed, or switched to the sandbox. `Occupied until` is the lease expiry for the current controlling session. + +Driver TTL and AstrBot leases are different. For example, `CUA Sandbox TTL` controls the lifetime of the CUA instance, while AstrBot's `Sandbox lease timeout` controls how long one message session keeps control of it. + +## CUA Runtime + +`CUA` is a sandbox runtime designed for computer-use scenarios. It can create Linux, macOS, Windows, Android, and other sandbox types through a unified Python SDK, and exposes shell, screenshot, mouse, keyboard, and filesystem interfaces. + +Before configuring the `CUA` driver in AstrBot, install the plugin from the WebUI plugin management page. Click the `+` button in the lower-right corner and install it from this URL: + +```text +https://github.com/AstrBotDevs/astrbot_sandbox_cua +``` + +If the plugin panel is not available, install it manually under `data/plugins`: + +```bash +git clone https://github.com/AstrBotDevs/astrbot_sandbox_cua.git data/plugins/astrbot_sandbox_cua +``` + +Then restart AstrBot or reload plugins from the plugin management page. + +If you run AstrBot from source or inside a virtual environment, install the CUA Python dependency in the same Python environment used by AstrBot: + +```bash +pip install cua +``` + +If you use `uv` to manage the AstrBot environment, run: + +```bash +uv pip install cua +``` + +CUA also has runtime-specific prerequisites: + +- Local Linux containers usually require Docker. +- Local Linux/Windows VMs usually require QEMU or the corresponding local CUA runtime. +- macOS VMs usually depend on CUA/Lume-related runtimes. +- Cloud CUA requires a valid CUA API key. + +For host requirements, image support, and local runtime installation details, refer to the [official CUA documentation](https://cua.ai/docs). + +After the dependency is installed, configure CUA in the AstrBot WebUI: + +- `Computer Use Runtime` = `sandbox` +- `Sandbox Driver` = `CUA` + +Common CUA settings include: + +- `CUA Image`: target image, for example `linux`, `macos`, `windows`, or `android` +- `CUA OS Type`: OS type used by AstrBot for fallback decisions +- `CUA Sandbox TTL`: sandbox lifetime in seconds +- `CUA Telemetry Enabled`: whether CUA telemetry is enabled +- `CUA Local Runtime`: whether to prefer local runtime +- `CUA API Key`: only needed when using cloud runtime ## Performance Requirements @@ -40,6 +148,18 @@ We recommend that your host machine have at least 2 CPUs, 4 GB of memory, and sw ### Deploy Shipyard Neo Separately (Recommended) +Before configuring `Shipyard Neo` in AstrBot, install its plugin from the WebUI plugin management page. Click the `+` button in the lower-right corner and install it from this URL: + +```text +https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo +``` + +If the plugin panel is not available, install it manually under `data/plugins`: + +```bash +git clone https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo.git data/plugins/astrbot_sandbox_shipyard_neo +``` + If you plan to use `Shipyard Neo` for the long term, it is generally better to **deploy it separately on a machine with more resources**, such as your homelab, a LAN server, or a dedicated cloud host, and then let AstrBot connect to Bay remotely. The reason is that `Shipyard Neo` can become fairly resource-heavy when browser capability is enabled, because it needs to run a full browser runtime. On resource-constrained cloud servers, deploying AstrBot and `Shipyard Neo` on the same machine usually puts significant pressure on CPU and memory, which can negatively affect both stability and overall experience. @@ -269,6 +389,43 @@ From AstrBot's perspective, the current implementation caches the sandbox booter For more detailed explanations of TTL and persistence behavior, see the later sections on “`Shipyard Neo Sandbox TTL`” and “Data Persistence in the Sandbox Environment”. +## BoxLite Driver + +`BoxLite` is a lighter local sandbox driver for scenarios that only need shell, Python, and file operations. It does not provide browser or GUI-specific tools. + +### Install the BoxLite Plugin + +Before configuring `BoxLite`, install it from the WebUI plugin management page. Click the `+` button in the lower-right corner and install it from this URL. `BoxLite` currently depends on part of the legacy `Shipyard` plugin code, so install the `Shipyard` plugin as well. + +```text +https://github.com/AstrBotDevs/astrbot_sandbox_boxlite +https://github.com/AstrBotDevs/astrbot_sandbox_shipyard +``` + +If the plugin panel is not available, install both plugins manually under `data/plugins`: + +```bash +git clone https://github.com/AstrBotDevs/astrbot_sandbox_boxlite.git data/plugins/astrbot_sandbox_boxlite +git clone https://github.com/AstrBotDevs/astrbot_sandbox_shipyard.git data/plugins/astrbot_sandbox_shipyard +``` + +Then restart AstrBot or reload plugins from the plugin management page. + +### Configure BoxLite in AstrBot + +Open the WebUI: + +- `Config -> General Config -> Computer Use` + +Then set: + +- `Computer Use Runtime` = `sandbox` +- `Sandbox Driver` = `BoxLite` + +`BoxLite` does not currently expose any extra driver-specific configuration fields. Once the plugin is enabled, it is ready to use. + +Then restart AstrBot or reload plugins from the plugin management page. + ## Legacy Option: Shipyard The following content describes the older `Shipyard` driver. It is kept for compatibility with existing legacy deployments. diff --git a/docs/zh/dev/star/guides/plugin-config.md b/docs/zh/dev/star/guides/plugin-config.md index 4016f70ba9..f9b7aa99fd 100644 --- a/docs/zh/dev/star/guides/plugin-config.md +++ b/docs/zh/dev/star/guides/plugin-config.md @@ -233,3 +233,5 @@ class ConfigPlugin(Star): ## 配置更新 您在发布不同版本更新 Schema 时,AstrBot 会递归检查 Schema 的配置项,自动为缺失的配置项添加默认值、移除不存在的配置项。 + +需要注意的是,`default` 只会用于“新建配置文件”或“已有配置中缺失的字段”。如果某个字段已经存在于 `data/config/_config.json` 中,那么后续即使您修改了 Schema 里的 `default`,AstrBot 也不会自动覆盖这个已保存的值。这是为了避免升级插件时意外覆盖用户已经手动修改过的配置。 diff --git a/docs/zh/dev/star/guides/sandbox-runtime.md b/docs/zh/dev/star/guides/sandbox-runtime.md new file mode 100644 index 0000000000..76bb07d48a --- /dev/null +++ b/docs/zh/dev/star/guides/sandbox-runtime.md @@ -0,0 +1,197 @@ +# 制作沙盒运行时插件 + +沙盒运行时插件负责告诉 AstrBot 如何启动并连接到一个沙盒服务。一个插件通常包含驱动实现、booter/client、配置 schema,以及截图、浏览器控制这类可选工具。 + +如果你要把一个现有运行时迁移成插件,先把配置边界分清:AstrBot Core 只管“什么时候用、用哪个、怎么回收”,真正要填哪些地址、令牌、镜像和超时,应该放在插件自己的配置里。 + +## 先看整体结构 + +一个沙盒运行时插件通常会被拆成四块: + +- **provider**:告诉 AstrBot 这个运行时能做什么、怎么创建沙盒、怎么销毁沙盒。 +- **booter / client**:负责和真实沙盒服务交互,创建、连接和关闭实例。 +- **tools**:补充截图、鼠标、键盘、浏览器或生命周期类工具。 +- **配置 schema**:把用户需要改的项暴露到 WebUI。 + +分工看起来不复杂,但把边界拆清楚之后,后面做迁移、排障和补新能力都会轻松很多。 + +可以先按这个结构创建目录: + +```text +data/plugins// + main.py + metadata.yaml + _conf_schema.json + provider.py + booters/ + tools/ +``` + +各文件大致这样分工: + +- `main.py`:注册沙盒驱动,也可以注册额外工具。 +- `provider.py`:把你的运行时适配成 AstrBot 的沙盒驱动。 +- `booters/`:放启动、连接、关闭沙盒的 client 代码。 +- `tools/`:放截图、鼠标、键盘、浏览器或生命周期这类可选工具。 +- `_conf_schema.json`:定义 WebUI 中展示的配置项。schema 格式见[插件配置](./plugin-config.md)。 + +## 1. 注册沙盒驱动 + +在 `main.py` 中创建驱动实例,并在插件加载时注册它。把插件配置传给驱动,这样 `provider.py` 就能读取 `_conf_schema.json` 生成的配置。 + +```python +from astrbot.api.star import Context, Star, register +from astrbot.core.computer.computer_client import ( + register_sandbox_provider, + unregister_sandbox_provider, +) + +from .provider import MySandboxProvider + + +@register("astrbot_sandbox_demo", "AstrBot Team", "Demo 沙盒驱动", "0.1.0") +class DemoSandboxPlugin(Star): + def __init__(self, context: Context, config=None) -> None: + super().__init__(context) + self.provider = MySandboxProvider() + self.provider.plugin_config = config or {} + register_sandbox_provider(self.provider, replace=True) + + async def terminate(self) -> None: + unregister_sandbox_provider(self.provider.provider_id, force=True) +``` + +建议插件目录名、`metadata.yaml` 里的 `name`、`@register(...)` 的名字保持一致,后续查找生成的配置文件会更直观。 + +## 2. 实现 `provider.py` + +AstrBot 在创建、复用、重命名、销毁沙盒时会调用这个驱动。你需要实现这些字段和方法: + +- `provider_id` +- `capabilities` +- `tool_names` +- `system_prompt` +- `build_create_config(context, session_id)` +- `build_connect_info(sandbox_name, config)` +- `update_connect_info(record, *, sandbox_name)` +- `create_booter(context, session_id, sandbox_id, config)` +- `destroy_booter(booter, record)` + +## 2.1 配置迁移怎么做 + +如果旧版本已经把配置写在 `provider_settings.sandbox` 里,迁移时可以先把它当成兼容输入,再逐步迁到插件配置: + +- 新的可编辑项优先放到 `_conf_schema.json`。 +- `build_create_config()` 负责把插件配置和旧配置合并成真正的创建参数。 +- `provider_settings.sandbox` 只适合作为过渡层,不建议继续扩展新字段。 +- `tool_names`、`capabilities`、`system_prompt` 不是用户配置入口,它们更像运行时能力声明。 + +下面是一个最小骨架: + +```python +class MySandboxProvider: + provider_id = "demo" + capabilities = {"shell", "python", "filesystem"} + tool_names = set() + system_prompt = ( + "When using this sandbox provider, follow its runtime-specific path, " + "GUI, browser, and lifecycle rules." + ) + + def build_create_config(self, context, session_id): + config = context.get_config(umo=session_id) + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + plugin_cfg = getattr(self, "plugin_config", None) or {} + return { + "endpoint_url": sandbox_cfg.get( + "demo_endpoint", plugin_cfg.get("demo_endpoint", "") + ), + "ttl": sandbox_cfg.get("demo_ttl", plugin_cfg.get("demo_ttl", 3600)), + } + + def build_connect_info(self, sandbox_name, config): + return {"name": sandbox_name, **config} + + def update_connect_info(self, record, *, sandbox_name): + info = dict(record.get("connect_info") or {}) + info["name"] = sandbox_name + return info + + async def create_booter(self, context, session_id, sandbox_id, config): + booter = MyBooter(**config) + await booter.boot(session_id) + return booter + + async def destroy_booter(self, booter, record): + await booter.shutdown() +``` + +## 3. 添加运行时配置 + +如果用户需要在 WebUI 中填写 API 地址、访问令牌、profile、镜像名或超时时间,把这些字段写进 `_conf_schema.json`。 + +```json +{ + "demo_endpoint": { + "description": "Demo API Endpoint", + "type": "string", + "default": "", + "hint": "API endpoint for the demo sandbox service." + }, + "demo_ttl": { + "description": "Sandbox TTL", + "type": "int", + "default": 3600, + "hint": "Sandbox lifetime in seconds." + } +} +``` + +AstrBot 会把用户保存的值写到 `data/config/_config.json`,并在插件实例化时作为 `config` 传入。 + +如果驱动还需要兼容旧的 `provider_settings.sandbox` 配置,可以在 `build_create_config()` 中把它们作为插件配置之上的 override。新的驱动配置通常放在 `_conf_schema.json` 即可。 + +## 4. 添加可选工具 + +如果你的运行时除了 shell/python/filesystem 外还提供其他能力,在插件里注册对应工具,并把工具名写进 `provider.tool_names`。 + +常见工具包括: + +- 截图工具 +- 鼠标 / 键盘工具 +- 浏览器工具 +- 运行时专属生命周期工具 + +AstrBot 在沙盒模式下挂载工具时会读取 `tool_names`。这里的名字要和 `main.py` 中注册的工具名一致。 + +`system_prompt` 可以作为 provider 元数据保存稳定的运行时提示词。Core 会在 provider info 中暴露它,方便 WebUI 或上层集成展示驱动规则,但不会自动把它追加到每次模型请求里。 + +## 4.1 一次沙盒请求是怎么走的 + +可以把它理解成下面这条链路: + +1. 插件加载时,`main.py` 先把 provider 注册到 AstrBot。 +2. 用户发起沙盒请求后,AstrBot 先看当前会话有没有可复用的沙盒。 +3. 如果能复用,AstrBot 直接接回现有 booter。 +4. 如果不能复用,AstrBot 会调用 provider 生成创建配置,再让 booter/client 去启动新沙盒。 +5. 沙盒创建成功后,AstrBot 会把它写回 registry,并按 capability 和 `tool_names` 挂载对应工具。 + +这里最容易忽略的一点是:AstrBot 并不知道某个运行时自己该怎么工作,它只认识 provider 提供的抽象能力。路径规则、浏览器约束、持久化目录、是否支持某些工具,这些都应该由 provider 讲清楚。 + +## 4.2 迁移旧运行时时的几个注意点 + +- 如果旧配置里已经有 `provider_settings.sandbox`,可以先把它当成兼容层,新的可编辑项尽量放到 `_conf_schema.json`。 +- `tool_names` 要和 `main.py` 里实际注册的工具名一致,不然工具挂载会对不上。 +- `system_prompt` 如果提供,最好直接写清楚路径规则和使用约束,别只写“请遵守运行时规则”这种空话。 +- 如果运行时有持久化数据,最好在文档里明确写出它的根目录、生命周期和是否会在重建后保留。 +- 迁移完成后,先验证最基础的创建、复用、销毁,再去看截图、浏览器和上传下载这些扩展能力。 + +## 5. 本地试跑 + +把插件放到 `data/plugins//` 后,启动 AstrBot,重点检查这些点: + +- 插件加载时没有 import error。 +- WebUI 配置页能看到 `_conf_schema.json` 里的字段。 +- 沙盒驱动选择项里能看到你的 `provider_id`。 +- 创建沙盒时会调用 `create_booter()`。 +- 停止或卸载插件时会调用 `terminate()` 并注销驱动。 diff --git a/docs/zh/use/astrbot-agent-sandbox.md b/docs/zh/use/astrbot-agent-sandbox.md index fd80691411..5679328af7 100644 --- a/docs/zh/use/astrbot-agent-sandbox.md +++ b/docs/zh/use/astrbot-agent-sandbox.md @@ -3,39 +3,99 @@ > [!TIP] > 此功能目前处于技术预览阶段,可能会存在一些 Bug。如果您遇到了问题,请在 [GitHub](https://github.com/AstrBotDevs/AstrBot/issues) 上提交 issue。 -在 `v4.12.0` 版本及之后,AstrBot 引入了 Agent 沙盒环境,以替代之前的代码执行器功能。沙盒环境给 Agent 提供了更安全、更灵活的代码执行和自动化操作能力。 +从 `v4.12.0` 开始,AstrBot 引入了 Agent 沙盒环境,用来替代之前的代码执行器功能。它让 Agent 在隔离环境里运行 Shell、Python、文件操作或桌面自动化任务,不必直接碰 AstrBot 主机。 + +如果你是从旧配置迁移过来的,先看清楚配置是怎么对应的。现在沙盒运行时已经拆成独立插件,AstrBot Core 只负责调度、复用、占用和清理;你需要做的是:先装驱动插件,再把 `Computer Use Runtime`、`沙盒驱动` 和驱动自己的参数填好。  -## 启用沙盒环境 +## 安装与启用沙盒环境 + +### 先看这 3 个配置 + +先安装并加载沙盒驱动插件,WebUI 里才会出现对应的驱动选项。只把 `Computer Use Runtime` 改成 `sandbox` 还不够;如果插件没装好,后面就没法选择和配置这个沙盒驱动。 + +插件安装完成后,再回到 WebUI 配置: + +1. `Computer Use Runtime` 设为 `sandbox`。 +2. 在 `沙盒驱动` 里选择已经安装的驱动,例如 `Shipyard Neo`、`BoxLite`、`Shipyard` 或 `CUA`。 +3. 按所选驱动补齐对应配置,比如 `Shipyard Neo API Endpoint`、`Shipyard Neo Access Token`、`CUA Image`、`CUA Sandbox TTL`。 + +这几个驱动现在都以独立插件的形式提供,所以顺序一定是:先装插件,重启 AstrBot 或重载插件,再回到 WebUI 选择和配置沙盒。 + +当前可用的沙盒驱动包括: + +- [`Shipyard Neo`](https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo)(推荐,适合长期运行和多人使用) +- [`BoxLite`](https://github.com/AstrBotDevs/astrbot_sandbox_boxlite)(轻量本地沙盒,适合只需要 Shell、Python 和文件操作的场景) +- [`Shipyard`](https://github.com/AstrBotDevs/astrbot_sandbox_shipyard)(旧方案,仍可继续使用) +- [`CUA`](https://github.com/AstrBotDevs/astrbot_sandbox_cua)(本地或云端电脑使用沙盒,适合需要桌面操作的场景) -目前,AstrBot 的沙盒环境驱动器支持: +如果你只是想先跑通沙盒,建议这样选: + +- 需要稳定的 Shell、Python、文件系统和 Skills 同步:优先选 `Shipyard Neo`。 +- 只想要轻量本地执行环境:选 `BoxLite`。 +- 需要截图、鼠标、键盘这类桌面操作:选 `CUA`。 +- 还在用旧部署:继续用 `Shipyard`,但新环境不建议再从它开始。 + +推荐安装方式:打开 WebUI 的“插件管理”页面,点击右下角 `+`,选择通过 URL 安装插件,然后填入对应的插件仓库地址: + +```text +https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo +https://github.com/AstrBotDevs/astrbot_sandbox_boxlite +https://github.com/AstrBotDevs/astrbot_sandbox_shipyard +https://github.com/AstrBotDevs/astrbot_sandbox_cua +``` + +如果你的部署环境无法使用插件面板,也可以手动安装到 `data/plugins`。这种方式不推荐作为首选,因为后续更新、启停和状态查看都没有插件面板顺手。 + +```bash +git clone https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo.git data/plugins/astrbot_sandbox_shipyard_neo +git clone https://github.com/AstrBotDevs/astrbot_sandbox_boxlite.git data/plugins/astrbot_sandbox_boxlite +git clone https://github.com/AstrBotDevs/astrbot_sandbox_shipyard.git data/plugins/astrbot_sandbox_shipyard +git clone https://github.com/AstrBotDevs/astrbot_sandbox_cua.git data/plugins/astrbot_sandbox_cua +``` -- `Shipyard Neo`(当前推荐) -- `Shipyard`(旧方案,仍可继续使用) -- `CUA`(本地或云端电脑使用沙盒,适合需要桌面操作的场景) +安装完成后,重启 AstrBot,或者在插件管理页重新加载插件。 -在当前版本的 AstrBot 控制台中,可在“AI 配置” -> “Agent Computer Use”中选择: +然后打开 AstrBot 控制台,在“AI 配置” -> “Agent Computer Use”中选择: - `Computer Use Runtime` = `sandbox` -- `沙箱环境驱动器` = `Shipyard Neo`、`Shipyard` 或 `CUA` +- `沙盒驱动` = `Shipyard Neo`、`BoxLite`、`Shipyard` 或 `CUA` -其中,`Shipyard Neo` 是当前默认驱动器。它由 Bay、Ship、Gull 三部分组成: +其中,`Shipyard Neo` 是当前推荐的默认驱动。它由 Bay、Ship、Gull 三部分组成: -- **Bay**:控制面 API,负责创建和管理 sandbox +- **Bay**:控制面 API,负责创建和管理沙盒 - **Ship**:负责 Python / Shell / 文件系统能力 - **Gull**:负责浏览器自动化能力 -对于 `Shipyard Neo`,工作区根目录固定为 `/workspace`。在 AstrBot 中调用文件系统工具时,应当传入**相对于工作区根目录**的路径,例如 `reports/result.txt`,而不是 `/workspace/reports/result.txt`。 +对于 `Shipyard Neo`,工作区根目录固定为 `/workspace`。在 AstrBot 中调用文件系统工具时,请传入**相对于工作区根目录**的路径,例如 `reports/result.txt`,不要写成 `/workspace/reports/result.txt`。 > [!TIP] -> `Shipyard Neo` 下浏览器能力并不是所有 profile 都有。只有 profile 支持 `browser` capability 时,AstrBot 才会挂载浏览器相关工具。典型 profile 如 `browser-python`。 +> `Shipyard Neo` 的浏览器能力不是所有 profile 都支持。只有 profile 带有 `browser` capability 时,AstrBot 才会挂载浏览器相关工具。一个常见示例是 `browser-python`。 + +## 托管沙盒、占用和保留策略 + +启用沙盒后,你可以在 WebUI 的“沙盒管理”页面看到 AstrBot 管理的沙盒。这里有几个容易混淆的词: + +- **托管沙盒**:AstrBot 记录并管理的沙盒。它可能来自 `Shipyard Neo`、`BoxLite`、`CUA` 或旧版 `Shipyard`。 +- **默认沙盒**:某个驱动下优先复用的沙盒。Agent 没有明确指定时,会优先尝试复用它。 +- **占用**:某个会话正在控制这个沙盒。占用期间,其他会话不能直接抢用,除非管理员或配置允许接管。 +- **占用租约**:占用的有效时间。默认 600 秒。Agent 可以调用续租工具延长有效租约;续租后,后续普通工具调用不会把更长的租约缩短。租约过期后,当前会话不再绑定该沙盒;Agent 需要先查看沙盒列表,再切换、接管或创建沙盒后继续工作。 +- **临时沙盒**:释放后可以按空闲时间或过期时间自动清理。 +- **持久沙盒**:环境会保留,适合准备好依赖后反复使用。它仍然会有“占用”状态,只是释放后不会因为临时清理策略被删掉。 -## CUA 运行时 +排障时可以先看这两项: -`CUA` 是一个面向电脑使用(Computer Use)的沙盒运行时。它可以通过统一的 Python SDK 创建 Linux、macOS、Windows、Android 等不同类型的沙盒,并暴露 Shell、截图、鼠标、键盘、文件系统等接口。 +- `最后使用`:最近一次由 Agent 占用、续租或切换到这个沙盒的时间。 +- `占用到期时间`:当前会话控制权的到期时间。它和驱动自己的 TTL 不是一回事。 -在 AstrBot 中选择 `CUA` 驱动器后,Agent 可以在 CUA sandbox 中使用: +驱动 TTL 通常表示沙盒实例本身最多保留多久;占用租约只表示当前会话还能控制它多久。比如 CUA 的 `CUA Sandbox TTL` 管的是 CUA 实例生命周期,而 AstrBot 的 `沙盒占用超时` 管的是会话占用时间。 + +## CUA 驱动 + +`CUA` 面向电脑使用(Computer Use)场景。它可以通过统一的 Python SDK 创建 Linux、macOS、Windows、Android 等不同类型的沙盒,并提供 Shell、截图、鼠标、键盘和文件系统等接口。 + +在 AstrBot 中选择 `CUA` 驱动后,Agent 可以在 CUA 沙盒中使用: - Shell 工具 - Python 工具 @@ -46,41 +106,57 @@ - 沙盒文件上传与下载工具 > [!NOTE] -> CUA 是可选运行时,AstrBot 默认安装不会强制安装它。如果选择了 `CUA` 但当前 Python 环境没有安装 `cua` 包,启动沙盒时会提示安装缺失。 +> CUA 是可选驱动,AstrBot 默认不会强制安装它。如果你选择了 `CUA`,但当前 Python 环境里没有安装 `cua` 包,启动沙盒时会提示缺少依赖。 + +### 先安装 CUA 插件 + +在配置 `CUA` 之前,推荐先在 WebUI 的“插件管理”页面点击右下角 `+`,通过 URL 安装插件: + +```text +https://github.com/AstrBotDevs/astrbot_sandbox_cua +``` + +如果无法使用插件面板,再手动安装到 `data/plugins`: + +```bash +git clone https://github.com/AstrBotDevs/astrbot_sandbox_cua.git data/plugins/astrbot_sandbox_cua +``` + +然后重启 AstrBot,或者在插件管理页重新加载插件。 ### 安装 CUA 依赖 -如果您通过源码或虚拟环境运行 AstrBot,请在 AstrBot 使用的 Python 环境中安装 CUA: +如果你是通过源码或虚拟环境运行 AstrBot,请在 AstrBot 使用的 Python 环境里安装 CUA: ```bash pip install cua ``` -如果您使用 `uv` 管理 AstrBot 环境,可在 AstrBot 项目目录中执行: +如果你使用 `uv` 管理 AstrBot 环境,可以在 AstrBot 项目目录中执行: ```bash uv pip install cua ``` -CUA 本身还依赖具体运行方式: +CUA 还依赖具体的运行方式: - 本地 Linux 容器通常需要 Docker 可用。 - 本地 Linux/Windows VM 通常需要 QEMU 或 CUA 对应的本地运行时。 - macOS VM 通常依赖 CUA/Lume 相关运行时。 - 云端 CUA 需要可用的 CUA API Key。 -具体宿主机要求、镜像支持情况和本地运行时安装方式,请参考 [CUA 官方文档](https://cua.ai/docs)。 +具体的宿主机要求、镜像支持情况,以及本地运行时的安装方式,请参考 [CUA 官方文档](https://cua.ai/docs)。 ### 在 AstrBot 中配置 CUA -进入 WebUI: +打开 WebUI: - `配置 -> 普通配置 -> 使用电脑能力` 然后设置: - `Computer Use Runtime` = `sandbox` -- `沙箱环境驱动器` = `CUA` +- `沙盒驱动` = `CUA` CUA 相关配置项包括: @@ -91,22 +167,22 @@ CUA 相关配置项包括: - `CUA Local Runtime`:是否使用本地运行时。默认开启。关闭后会按 CUA SDK 的云端方式创建沙盒。 - `CUA API Key`:云端 CUA 所需的 API Key。仅在使用云端运行时时填写。 -一个最小本地 Linux 容器配置通常是: +一个最小的本地 Linux 容器配置通常是: ```text Computer Use Runtime = sandbox -沙箱环境驱动器 = CUA +沙盒驱动 = CUA CUA Image = linux CUA OS Type = linux CUA Local Runtime = true CUA Sandbox TTL = 3600 ``` -如果使用云端 CUA,可改为: +如果使用云端 CUA,可以改成: ```text Computer Use Runtime = sandbox -沙箱环境驱动器 = CUA +沙盒驱动 = CUA CUA Image = linux CUA OS Type = linux CUA Local Runtime = false @@ -114,14 +190,14 @@ CUA API Key = ``` > [!WARNING] -> 不要把 CUA API Key 写入公开日志、截图或 issue。AstrBot 的运行日志不会输出该字段,但部署平台、Shell 历史和容器环境变量仍需自行保护。 +> 不要把 CUA API Key 写进公开日志、截图或 issue。AstrBot 的运行日志不会输出这个字段,但部署平台、Shell 历史和容器环境变量仍然需要你自己保护好。 ### 使用 CUA 时的注意事项 - `linux` 镜像通常适合 Shell、Python、文件系统和桌面自动化测试。 -- 非 POSIX 镜像(如 `windows`、`android`)不一定支持 `sh`、`cat`、`ls`、`rm`、`base64` 等命令。AstrBot 对需要这些命令的 fallback 操作会返回明确错误。 -- 如果需要在 CUA sandbox 中打开浏览器或 GUI 程序,通常应使用 Shell 后台执行,例如显式传入 `background=true`,避免命令阻塞后续工具调用。 -- 直接把 sandbox 内的文件路径发送给用户通常不可行。应优先使用 AstrBot 的沙盒下载工具,将文件下载到 AstrBot 临时目录后再发送。 +- 非 POSIX 镜像(如 `windows`、`android`)不一定支持 `sh`、`cat`、`ls`、`rm`、`base64` 等命令。AstrBot 对依赖这些命令的 fallback 操作会返回明确错误。 +- 如果你需要在 CUA 沙盒里打开浏览器或 GUI 程序,通常应通过 Shell 后台执行,例如显式传入 `background=true`,避免命令阻塞后续工具调用。 +- 直接把沙盒内的文件路径发给用户通常不可行。应优先使用 AstrBot 的沙盒下载工具,把文件下载到 AstrBot 临时目录后再发送。 - CUA 与 Shipyard Neo 的 workspace 语义不同。Shipyard Neo 固定使用 `/workspace`;CUA 的工作目录和文件路径取决于镜像与运行时。 ### 何时选择 CUA @@ -132,19 +208,31 @@ CUA API Key = - 需要测试不同 OS 镜像中的行为,例如 Linux、Windows、Android。 - 已经在本机或云端部署好 CUA 运行环境。 -如果只是需要稳定的 Python/Shell/文件系统沙盒,且不需要桌面 GUI 操作,通常优先选择 `Shipyard Neo`。它与 AstrBot 的 workspace、Skills 同步和长期运行模式更贴合。 +如果你只是需要稳定的 Python/Shell/文件系统沙盒,而且不需要桌面 GUI 操作,通常优先选择 `Shipyard Neo`。它和 AstrBot 的 workspace、Skills 同步,以及长期运行模式更匹配。 ## 性能要求 AstrBot 给每个沙盒环境限制最高 1 CPU 和 512 MB 内存。 -我们建议您的宿主机至少有 2 个 CPU 和 4 GB 内存,并开启 Swap,以保证多个沙盒环境实例可以稳定运行。 +建议宿主机至少有 2 个 CPU 和 4 GB 内存,并开启 Swap。这样同时跑多个沙盒时更稳。 ## 推荐:使用 Shipyard Neo ### 单独部署 Shipyard Neo(推荐) -如果您准备长期使用 `Shipyard Neo`,更推荐将它**单独部署在一台资源更充足的机器上**,例如您的 homelab、局域网服务器,或独立云主机,然后再让 AstrBot 远程接入 Bay。 +在 AstrBot 侧配置 `Shipyard Neo` 之前,推荐先在 WebUI 的“插件管理”页面点击右下角 `+`,通过 URL 安装插件: + +```text +https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo +``` + +如果无法使用插件面板,再手动安装到 `data/plugins`: + +```bash +git clone https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo.git data/plugins/astrbot_sandbox_shipyard_neo +``` + +如果你准备长期使用 `Shipyard Neo`,更推荐将它**单独部署在一台资源更充足的机器上**,例如 homelab、局域网服务器,或独立云主机,然后再让 AstrBot 远程接入 Bay。 原因是:`Shipyard Neo` 在启用浏览器能力时需要运行较重的浏览器运行时。对于资源紧张的云服务器,把 AstrBot 和 `Shipyard Neo` 部署在同一台机器上,通常会让 CPU 和内存压力都比较大,稳定性和体验都不理想。 @@ -160,7 +248,7 @@ docker compose up -d 部署完成后: - Bay 默认监听在 `http://:8114` -- 在 AstrBot 控制台中选择 `Shipyard Neo` 驱动器 +- 在 AstrBot 控制台中选择 `Shipyard Neo` 驱动 - `Shipyard Neo API Endpoint` 填写对应地址,例如 `http://:8114` - `Shipyard Neo Access Token` 填写 Bay API Key;如果 AstrBot 能访问 Bay 的 `credentials.json`,也可以留空让 AstrBot 自动发现 @@ -175,7 +263,7 @@ docker compose up -d # Bay Production Config - Docker Compose (container_network mode) # # Bay 运行在 Docker 容器中,并通过共享 Docker 网络与 Ship/Gull 容器通信。 -# 这种模式下,sandbox 容器不需要向宿主机暴露端口。 +# 这种模式下,沙盒容器不需要向宿主机暴露端口。 # # 部署前至少需要修改: # 1. security.api_key —— 设置强随机密钥 @@ -197,7 +285,7 @@ driver: # 当前默认使用 Docker 驱动 type: docker - # 创建新 sandbox 时是否拉取镜像。 + # 创建新沙盒时是否拉取镜像。 # 生产环境通常建议 always,以便拿到最新镜像。 image_pull_policy: always @@ -212,7 +300,7 @@ driver: # 共享网络名,必须与 docker-compose.yaml 中的网络一致 network: "bay-network" - # 是否将 sandbox 容器端口暴露到宿主机。 + # 是否将沙盒容器端口暴露到宿主机。 # 生产环境建议关闭,以减少攻击面。 publish_ports: false host_port: null @@ -222,7 +310,7 @@ cargo: root_path: "/var/lib/bay/cargos" # 默认工作区大小限制(MB) default_size_limit_mb: 1024 - # Cargo 挂载到 sandbox 内的路径。AstrBot/Neo 的工作区根目录就是这里。 + # Cargo 挂载到沙盒内的路径。AstrBot/Neo 的工作区根目录就是这里。 mount_path: "/workspace" security: @@ -232,15 +320,15 @@ security: allow_anonymous: false # 容器代理环境变量注入。 -# 启用后,Bay 会把 HTTP(S)_PROXY 和 NO_PROXY 注入到 sandbox 容器。 +# 启用后,Bay 会把 HTTP(S)_PROXY 和 NO_PROXY 注入到沙盒容器。 proxy: enabled: false # http_proxy: "http://proxy.example.com:7890" # https_proxy: "http://proxy.example.com:7890" # no_proxy: "my-internal.service" -# Warm Pool:预热一批待命 sandbox,减少冷启动延迟。 -# 当用户创建 sandbox 时,Bay 会优先尝试领取一个已预热实例。 +# Warm Pool:预热一批待命沙盒,减少冷启动延迟。 +# 当用户创建沙盒时,Bay 会优先尝试领取一个已预热实例。 warm_pool: enabled: true # 预热队列 worker 数量 @@ -368,13 +456,50 @@ gc: - **Session**:实际运行中的容器会话,可被停止或重建 - **Cargo**:持久化工作区卷,挂载到 `/workspace` -对 AstrBot 而言,当前会按请求的 `session_id` 维度缓存沙箱 booter;在主 Agent 默认流程下,这个 `session_id` 通常等于消息会话标识 `unified_msg_origin`。因此,同一消息会话的后续请求通常会继续复用同一个 Neo sandbox;如果沙箱失效,则会自动重建。 +AstrBot 会按请求的 `session_id` 缓存沙盒 booter。在主 Agent 默认流程下,这个 `session_id` 通常等于消息会话标识 `unified_msg_origin`。因此,同一消息会话的后续请求通常会复用同一个 Neo 沙盒;如果沙盒失效,AstrBot 会自动重建。 关于 TTL 与数据持久化的更详细说明,请参考下文的“关于 `Shipyard Neo Sandbox TTL`”与“关于沙盒环境的数据持久化”小节。 +## BoxLite 驱动 + +`BoxLite` 是一个更轻量的本地沙盒驱动,适合只需要 Shell、Python 和文件操作的场景,不提供浏览器或 GUI 专用工具。 + +### 安装 BoxLite 插件 + +在配置 `BoxLite` 之前,推荐先在 WebUI 的“插件管理”页面点击右下角 `+`,通过 URL 安装 `BoxLite` 插件。当前 `BoxLite` 还依赖旧版 `Shipyard` 插件中的部分代码,因此也需要安装 `Shipyard` 插件。 + +```text +https://github.com/AstrBotDevs/astrbot_sandbox_boxlite +https://github.com/AstrBotDevs/astrbot_sandbox_shipyard +``` + +如果无法使用插件面板,再手动安装到 `data/plugins`: + +```bash +git clone https://github.com/AstrBotDevs/astrbot_sandbox_boxlite.git data/plugins/astrbot_sandbox_boxlite +git clone https://github.com/AstrBotDevs/astrbot_sandbox_shipyard.git data/plugins/astrbot_sandbox_shipyard +``` + +然后重启 AstrBot,或者在插件管理页重新加载插件。 + +### 在 AstrBot 中配置 BoxLite + +打开 WebUI: + +- `配置 -> 普通配置 -> 使用电脑能力` + +然后设置: + +- `Computer Use Runtime` = `sandbox` +- `沙盒驱动` = `BoxLite` + +`BoxLite` 当前没有额外的驱动级配置项,启用插件后即可使用。 + +然后重启 AstrBot,或在插件管理页重新加载插件。 + ## 旧方案:Shipyard -以下内容为旧版 `Shipyard` 驱动器的部署与配置说明,仍然保留,供兼容旧部署方案时参考。 +以下内容是旧版 `Shipyard` 驱动的部署与配置说明,保留给仍在使用旧部署方案的用户参考。 ### 使用 Docker Compose 部署 AstrBot 和 Shipyard @@ -418,8 +543,8 @@ docker pull soulter/shipyard-ship:latest 在 AstrBot 控制台,进入 “AI 配置” -> “Agent Computer Use”。 1. 将 `Computer Use Runtime` 设为 `sandbox` -2. 在 `沙箱环境驱动器` 中选择 `Shipyard Neo` 或 `Shipyard` -3. 根据驱动器填写对应配置项 +2. 在 `沙盒驱动` 中选择 `Shipyard Neo` 或 `Shipyard` +3. 根据驱动填写对应配置项 4. 点击右下角“保存” ### 配置 Shipyard Neo @@ -436,7 +561,7 @@ docker pull soulter/shipyard-ship:latest - 例如 `python-default`、`browser-python` - 如果留空,AstrBot 会优先尝试选择能力更完整、且优先带有 `browser` capability 的 profile,失败时再回退到 `python-default` - `Shipyard Neo Sandbox TTL` - - sandbox 生命周期上限,默认值为 3600 秒(1 小时) + - 沙盒生命周期上限,默认值为 3600 秒(1 小时) ### 配置 Shipyard(旧方案) @@ -456,11 +581,13 @@ docker pull soulter/shipyard-ship:latest 在 `Shipyard Neo` 中: -- TTL 表示 sandbox 生命周期上限 +- TTL 表示沙盒生命周期上限 - profile 还会定义一个独立的空闲超时(`idle_timeout`) - AstrBot 发起能力调用时,通常会刷新空闲超时,而不是直接延长 TTL - `keepalive` 只会延长空闲超时,不会自动启动新的 session,也不会延长 TTL +换句话说,TTL 更像“这个沙盒最多能活多久”,空闲超时更像“这个沙盒多久没人用就可以收掉”。两者不是一回事,排障时最好分开看。 + ## 关于 `Shipyard Ship 存活时间(秒)` 以下说明仅适用于旧版 `Shipyard`: diff --git a/docs/zh/use/computer.md b/docs/zh/use/computer.md index 7da8dd5d17..eb54d6fe82 100644 --- a/docs/zh/use/computer.md +++ b/docs/zh/use/computer.md @@ -92,16 +92,20 @@ data/workspaces/{normalized_umo}/notes/todo.txt 中配置。用户可通过 `/sid` 获取自己的 ID。 -## Sandbox 模式 +## 沙盒模式 `sandbox` 模式会把执行动作放到隔离环境中,而不是直接在 AstrBot 主机上运行。 在沙盒中,Agent 仍然可以使用 Shell、Python、文件系统工具;如果所选沙盒 profile 支持 `browser` capability,还会挂载浏览器自动化工具。 -沙盒环境驱动器可在 `配置 -> 普通配置 -> 使用电脑能力` 的沙箱配置中选择。当前常用选项包括: +沙盒由 AstrBot 托管后,会有“占用”和“保留策略”两层状态。占用表示某个会话暂时控制这个沙盒;占用租约到期后,当前会话不再绑定该沙盒,其他会话可以重新占用或接管。保留策略决定沙盒释放后是保留下来复用,还是按空闲/过期规则清理。 -- `Shipyard Neo`:AstrBot 推荐的远程/独立部署沙盒服务,适合长期运行和多人使用。 -- `CUA`:基于 [CUA](https://github.com/trycua/cua) 的本地或云端电脑使用沙盒,可提供桌面截图、鼠标、键盘、Shell、Python 和文件系统能力。 +沙盒驱动需要先安装对应插件。推荐在 WebUI 的“插件管理”页面点击右下角 `+`,输入插件仓库地址安装;插件安装并加载后,才可以在 `配置 -> 普通配置 -> 使用电脑能力` 的沙盒配置中选择和填写参数。当前可用选项包括: + +- [`Shipyard Neo`](https://github.com/AstrBotDevs/astrbot_sandbox_shipyard_neo):AstrBot 推荐的远程/独立部署沙盒服务,适合长期运行和多人使用。 +- [`BoxLite`](https://github.com/AstrBotDevs/astrbot_sandbox_boxlite):轻量本地沙盒,适合只需要 Shell、Python 和文件操作的场景。 +- [`Shipyard`](https://github.com/AstrBotDevs/astrbot_sandbox_shipyard):旧方案,仍可继续使用。 +- [`CUA`](https://github.com/AstrBotDevs/astrbot_sandbox_cua):本地或云端电脑使用沙盒,适合需要桌面截图、鼠标和键盘操作的场景。 使用 `Shipyard Neo` 时,沙盒 workspace 根目录通常是: @@ -123,7 +127,7 @@ result.txt 使用 `CUA` 时,工作目录和可用命令取决于所选 CUA image 与运行方式。Linux CUA 容器通常提供类 Unix Shell;Windows、Android 等非 POSIX 镜像不保证支持 `sh`、`ls`、`rm`、`base64` 等命令,AstrBot 会对部分 shell fallback 操作返回明确错误。 -沙盒部署、驱动器选择、CUA 配置、profile、TTL、数据持久化、浏览器能力等内容请参考:[Agent 沙盒环境](/use/astrbot-agent-sandbox)。 +沙盒部署、驱动选择、CUA 配置、profile、占用租约、TTL、数据持久化、浏览器能力等内容请参考:[Agent 沙盒环境](/use/astrbot-agent-sandbox)。 > [!NOTE] > 即使在 `sandbox` 模式下,“需要 AstrBot 管理员权限”仍会影响 Shell、Python、浏览器、上传下载等工具的调用权限。具体权限取决于你的配置。 diff --git a/docs/zh/use/skills.md b/docs/zh/use/skills.md index de7b7a97e2..d029fb5a59 100644 --- a/docs/zh/use/skills.md +++ b/docs/zh/use/skills.md @@ -29,10 +29,9 @@ Skills 提供了 Agent 操作说明书,并且内容通常包含 Python 代码 目前,AstrBot 提供两种执行环境: - Local(Agent 将在你的 AstrBot 运行环境中运行。**请谨慎使用,因为这会允许 Agent 在你的环境执行任意代码,可能带来安全风险**) -- Sandbox (Agent 在隔离化的沙盒环境中运行。**需要先启动 AstrBot 沙盒模式**,请参考:[沙盒模式](/use/astrbot-agent-sandbox),如果这个模式下不启动沙盒模式,将不会将 Skills 传给 Agent) +- Sandbox:Agent 在隔离沙盒中运行。需要先启用 AstrBot 沙盒模式,参考:[沙盒模式](/use/astrbot-agent-sandbox)。如果没有启用沙盒模式,Skills 不会传给 Agent。 你可以在 `配置` 页面 - 使用电脑能力 中选择默认的执行环境。 > [!NOTE] > 需要说明的是,如果您使用 Local 作为执行环境,AstrBot 目前仅允许 **AstrBot 管理员**请求时才真正让 Agent 操作你的本地环境,普通用户将会被禁止,Agent 将无法通过 Shell、Python 等 Tool 在本地环境执行代码,会收到相应的权限限制提示,如 `Sorry, I cannot execute code on your local environment due to permission restrictions.`。 - diff --git a/tests/test_computer_config.py b/tests/test_computer_config.py index 26f72991c3..e0af81c3aa 100644 --- a/tests/test_computer_config.py +++ b/tests/test_computer_config.py @@ -1,176 +1,12 @@ -"""Tests for _discover_bay_credentials() auto-discovery and _log_computer_config_changes().""" +"""Tests for _log_computer_config_changes().""" from __future__ import annotations -import json -import logging -from pathlib import Path from unittest.mock import patch -import pytest - -from astrbot.core.computer.computer_client import _discover_bay_credentials from astrbot.dashboard.routes.config import _log_computer_config_changes -# ═══════════════════════════════════════════════════════════════ -# _discover_bay_credentials -# ═══════════════════════════════════════════════════════════════ - - -class TestDiscoverBayCredentials: - """Test Bay API key auto-discovery from credentials.json.""" - - def _write_creds( - self, - path: Path, - api_key: str = "sk-bay-abc123", - endpoint: str = "http://127.0.0.1:8114", - ) -> None: - """Helper: write a credentials.json file.""" - path.parent.mkdir(parents=True, exist_ok=True) - path.write_text( - json.dumps( - { - "api_key": api_key, - "endpoint": endpoint, - "generated_at": "2026-02-17T00:00:00+00:00", - } - ) - ) - - def test_discover_from_bay_data_dir_env( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """BAY_DATA_DIR env var takes highest priority.""" - data_dir = tmp_path / "bay_data" - cred_file = data_dir / "credentials.json" - self._write_creds(cred_file, api_key="sk-bay-from-env-dir") - monkeypatch.setenv("BAY_DATA_DIR", str(data_dir)) - - result = _discover_bay_credentials("http://127.0.0.1:8114") - assert result == "sk-bay-from-env-dir" - - def test_discover_from_cwd( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Falls back to current working directory.""" - cred_file = tmp_path / "credentials.json" - self._write_creds(cred_file, api_key="sk-bay-from-cwd") - monkeypatch.chdir(tmp_path) - monkeypatch.delenv("BAY_DATA_DIR", raising=False) - - result = _discover_bay_credentials("http://127.0.0.1:8114") - assert result == "sk-bay-from-cwd" - - def test_returns_empty_when_no_credentials_found( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Returns empty string when no credentials.json exists anywhere.""" - monkeypatch.chdir(tmp_path) - monkeypatch.delenv("BAY_DATA_DIR", raising=False) - - result = _discover_bay_credentials("http://127.0.0.1:8114") - assert result == "" - - def test_skips_empty_api_key( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Skips credentials.json when api_key is empty.""" - cred_file = tmp_path / "credentials.json" - self._write_creds(cred_file, api_key="") - monkeypatch.chdir(tmp_path) - monkeypatch.delenv("BAY_DATA_DIR", raising=False) - - result = _discover_bay_credentials("http://127.0.0.1:8114") - assert result == "" - - def test_skips_malformed_json( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Handles malformed JSON gracefully.""" - cred_file = tmp_path / "credentials.json" - cred_file.parent.mkdir(parents=True, exist_ok=True) - cred_file.write_text("not valid json {{{") - monkeypatch.chdir(tmp_path) - monkeypatch.delenv("BAY_DATA_DIR", raising=False) - - result = _discover_bay_credentials("http://127.0.0.1:8114") - assert result == "" - - @patch("astrbot.core.computer.computer_client.logger") - def test_endpoint_mismatch_still_returns_key( - self, mock_logger, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Returns key even if endpoint doesn't match, but logs a warning.""" - data_dir = tmp_path / "bay_data" - cred_file = data_dir / "credentials.json" - self._write_creds( - cred_file, api_key="sk-bay-mismatch", endpoint="http://other-host:9000" - ) - monkeypatch.setenv("BAY_DATA_DIR", str(data_dir)) - - result = _discover_bay_credentials("http://127.0.0.1:8114") - - assert result == "sk-bay-mismatch" - mock_logger.warning.assert_called_once() - warning_msg = mock_logger.warning.call_args[0][0] - assert "endpoint mismatch" in warning_msg - - def test_endpoint_match_no_warning( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """No warning when endpoints match.""" - data_dir = tmp_path / "bay_data" - cred_file = data_dir / "credentials.json" - self._write_creds( - cred_file, api_key="sk-bay-match", endpoint="http://127.0.0.1:8114" - ) - monkeypatch.setenv("BAY_DATA_DIR", str(data_dir)) - - with patch("astrbot.core.computer.computer_client.logger") as mock_logger: - result = _discover_bay_credentials("http://127.0.0.1:8114") - - assert result == "sk-bay-match" - mock_logger.warning.assert_not_called() - - def test_bay_data_dir_priority_over_cwd( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """BAY_DATA_DIR takes priority over cwd.""" - env_dir = tmp_path / "env_dir" - cwd_dir = tmp_path / "cwd_dir" - self._write_creds(env_dir / "credentials.json", api_key="sk-bay-env-wins") - self._write_creds(cwd_dir / "credentials.json", api_key="sk-bay-cwd-loses") - monkeypatch.setenv("BAY_DATA_DIR", str(env_dir)) - monkeypatch.chdir(cwd_dir) - - result = _discover_bay_credentials("http://127.0.0.1:8114") - assert result == "sk-bay-env-wins" - - def test_trailing_slash_normalization( - self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch - ) -> None: - """Trailing slashes on endpoints are normalized before comparison.""" - data_dir = tmp_path / "bay_data" - cred_file = data_dir / "credentials.json" - self._write_creds( - cred_file, api_key="sk-bay-slash", endpoint="http://127.0.0.1:8114/" - ) - monkeypatch.setenv("BAY_DATA_DIR", str(data_dir)) - - with patch("astrbot.core.computer.computer_client.logger") as mock_logger: - result = _discover_bay_credentials("http://127.0.0.1:8114") - - assert result == "sk-bay-slash" - mock_logger.warning.assert_not_called() - - -# ═══════════════════════════════════════════════════════════════ -# _log_computer_config_changes -# ═══════════════════════════════════════════════════════════════ - - class TestLogComputerConfigChanges: """Test config change detection and logging.""" @@ -184,7 +20,10 @@ def test_logs_runtime_change(self, mock_logger) -> None: mock_logger.info.assert_called() call_args = [str(c) for c in mock_logger.info.call_args_list] - assert any("computer_use_runtime" in c and "none" in c and "sandbox" in c for c in call_args) + assert any( + "computer_use_runtime" in c and "none" in c and "sandbox" in c + for c in call_args + ) @patch("astrbot.dashboard.routes.config.logger") def test_no_log_when_runtime_unchanged(self, mock_logger) -> None: @@ -199,8 +38,8 @@ def test_no_log_when_runtime_unchanged(self, mock_logger) -> None: @patch("astrbot.dashboard.routes.config.logger") def test_logs_sandbox_key_change(self, mock_logger) -> None: """Detects sandbox sub-key change.""" - old = {"provider_settings": {"sandbox": {"booter": "shipyard"}}} - new = {"provider_settings": {"sandbox": {"booter": "shipyard_neo"}}} + old = {"provider_settings": {"sandbox": {"booter": "provider_a"}}} + new = {"provider_settings": {"sandbox": {"booter": "provider_b"}}} _log_computer_config_changes(old, new) @@ -210,20 +49,20 @@ def test_logs_sandbox_key_change(self, mock_logger) -> None: for call in mock_logger.info.call_args_list: args = call[0] # positional args: (fmt, key, old_val, new_val) if len(args) >= 4 and args[1] == "booter": - assert args[2] == "shipyard" - assert args[3] == "shipyard_neo" + assert args[2] == "provider_a" + assert args[3] == "provider_b" found = True break - assert found, f"Expected booter change in log calls: {mock_logger.info.call_args_list}" + assert found, ( + f"Expected booter change in log calls: {mock_logger.info.call_args_list}" + ) @patch("astrbot.dashboard.routes.config.logger") def test_masks_token_values(self, mock_logger) -> None: """Token/secret values are masked in log output.""" - old = {"provider_settings": {"sandbox": {"shipyard_neo_access_token": ""}}} + old = {"provider_settings": {"sandbox": {"sandbox_access_token": ""}}} new = { - "provider_settings": { - "sandbox": {"shipyard_neo_access_token": "sk-bay-secret123"} - } + "provider_settings": {"sandbox": {"sandbox_access_token": "sk-secret123"}} } _log_computer_config_changes(old, new) @@ -231,17 +70,13 @@ def test_masks_token_values(self, mock_logger) -> None: mock_logger.info.assert_called() call_args_str = str(mock_logger.info.call_args_list) assert "***" in call_args_str - assert "sk-bay-secret123" not in call_args_str + assert "sk-secret123" not in call_args_str @patch("astrbot.dashboard.routes.config.logger") def test_masks_empty_token_as_empty_label(self, mock_logger) -> None: """Empty token values show as '(empty)' not '***'.""" - old = { - "provider_settings": { - "sandbox": {"shipyard_neo_access_token": "old-key"} - } - } - new = {"provider_settings": {"sandbox": {"shipyard_neo_access_token": ""}}} + old = {"provider_settings": {"sandbox": {"sandbox_access_token": "old-key"}}} + new = {"provider_settings": {"sandbox": {"sandbox_access_token": ""}}} _log_computer_config_changes(old, new) @@ -256,8 +91,8 @@ def test_no_log_when_nothing_changed(self, mock_logger) -> None: "provider_settings": { "computer_use_runtime": "sandbox", "sandbox": { - "booter": "shipyard_neo", - "shipyard_neo_endpoint": "http://127.0.0.1:8114", + "booter": "provider_a", + "sandbox_endpoint": "http://127.0.0.1:8114", }, } } @@ -283,7 +118,7 @@ def test_detects_new_sandbox_key(self, mock_logger) -> None: old = {"provider_settings": {"sandbox": {}}} new = { "provider_settings": { - "sandbox": {"shipyard_neo_endpoint": "http://127.0.0.1:8114"} + "sandbox": {"sandbox_endpoint": "http://127.0.0.1:8114"} } } @@ -291,14 +126,14 @@ def test_detects_new_sandbox_key(self, mock_logger) -> None: mock_logger.info.assert_called() call_args_str = str(mock_logger.info.call_args_list) - assert "shipyard_neo_endpoint" in call_args_str + assert "sandbox_endpoint" in call_args_str @patch("astrbot.dashboard.routes.config.logger") def test_detects_removed_sandbox_key(self, mock_logger) -> None: """Detects a removed sandbox key.""" old = { "provider_settings": { - "sandbox": {"shipyard_neo_endpoint": "http://127.0.0.1:8114"} + "sandbox": {"sandbox_endpoint": "http://127.0.0.1:8114"} } } new = {"provider_settings": {"sandbox": {}}} @@ -307,15 +142,13 @@ def test_detects_removed_sandbox_key(self, mock_logger) -> None: mock_logger.info.assert_called() call_args_str = str(mock_logger.info.call_args_list) - assert "shipyard_neo_endpoint" in call_args_str + assert "sandbox_endpoint" in call_args_str @patch("astrbot.dashboard.routes.config.logger") def test_secret_key_masked(self, mock_logger) -> None: """Any key containing 'secret' is also masked.""" old = {"provider_settings": {"sandbox": {"my_secret_key": ""}}} - new = { - "provider_settings": {"sandbox": {"my_secret_key": "very-secret-value"}} - } + new = {"provider_settings": {"sandbox": {"my_secret_key": "very-secret-value"}}} _log_computer_config_changes(old, new) diff --git a/tests/test_computer_fs_tools.py b/tests/test_computer_fs_tools.py index eaf72ec66e..40b819c9f2 100644 --- a/tests/test_computer_fs_tools.py +++ b/tests/test_computer_fs_tools.py @@ -5,6 +5,7 @@ import zipfile from types import SimpleNamespace from typing import Any +from unittest.mock import AsyncMock import pytest from mcp.types import CallToolResult, ImageContent @@ -95,6 +96,130 @@ async def _fake_get_booter(_ctx, _umo): return workspace +@pytest.mark.asyncio +async def test_sandbox_file_download_keeps_original_filename( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +): + temp_root = tmp_path / "temp" + temp_root.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr( + fs_tools, + "get_astrbot_temp_path", + lambda: str(temp_root), + ) + + booter = SimpleNamespace(download_file=AsyncMock()) + + async def _fake_get_booter(_ctx, _umo): + return booter + + monkeypatch.setattr(fs_tools, "get_booter", _fake_get_booter) + + context = _make_sandbox_context() + result = await fs_tools.FileDownloadTool().call( + context, + remote_path="reports/sandbox_evaluation_report.md", + also_send_to_user=True, + ) + + assert "sandbox_evaluation_report.md" in result + sent_chain = context.context.event.send.await_args.args[0] + sent_file = sent_chain.chain[0] + assert sent_file.name == "sandbox_evaluation_report.md" + + +@pytest.mark.asyncio +async def test_sandbox_file_download_handles_windows_remote_filename( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +): + temp_root = tmp_path / "temp" + temp_root.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr( + fs_tools, + "get_astrbot_temp_path", + lambda: str(temp_root), + ) + + booter = SimpleNamespace(download_file=AsyncMock()) + + async def _fake_get_booter(_ctx, _umo): + return booter + + monkeypatch.setattr(fs_tools, "get_booter", _fake_get_booter) + + context = _make_sandbox_context() + result = await fs_tools.FileDownloadTool().call( + context, + remote_path=r"C:\Users\AstrBot\report.txt", + also_send_to_user=True, + ) + + assert "report.txt" in result + sent_chain = context.context.event.send.await_args.args[0] + sent_file = sent_chain.chain[0] + assert sent_file.name == "report.txt" + + +@pytest.mark.asyncio +async def test_sandbox_file_download_strips_trailing_remote_slash( + monkeypatch: pytest.MonkeyPatch, + tmp_path, +): + temp_root = tmp_path / "temp" + temp_root.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr( + fs_tools, + "get_astrbot_temp_path", + lambda: str(temp_root), + ) + + booter = SimpleNamespace(download_file=AsyncMock()) + + async def _fake_get_booter(_ctx, _umo): + return booter + + monkeypatch.setattr(fs_tools, "get_booter", _fake_get_booter) + + context = _make_sandbox_context() + result = await fs_tools.FileDownloadTool().call( + context, + remote_path="reports/export/", + also_send_to_user=True, + ) + + assert "export" in result + sent_chain = context.context.event.send.await_args.args[0] + sent_file = sent_chain.chain[0] + assert sent_file.name == "export" + + +def _make_sandbox_context( + *, + role: str = "admin", + umo: str = "qq:friend:user-1", +): + config_holder = SimpleNamespace( + get_config=lambda umo=None: { + "provider_settings": { + "computer_use_require_admin": True, + "computer_use_runtime": "sandbox", + } + } + ) + event = SimpleNamespace( + role=role, + unified_msg_origin=umo, + send=AsyncMock(), + ) + astr_ctx = SimpleNamespace(context=config_holder, event=event) + return ContextWrapper(context=astr_ctx) + + def _make_large_text() -> str: return "".join(f"line-{index:05d}-{'x' * 48}\n" for index in range(6000)) diff --git a/tests/test_computer_skill_sync.py b/tests/test_computer_skill_sync.py index 0bac69f9f1..c4b78a3f2a 100644 --- a/tests/test_computer_skill_sync.py +++ b/tests/test_computer_skill_sync.py @@ -59,7 +59,8 @@ def test_sync_skills_keeps_builtin_skills_when_local_is_empty( captured = {"skills": None} - def _fake_set_cache(self, skills): + def _fake_set_cache(self, skills, provider_id=None): + _ = provider_id captured["skills"] = skills monkeypatch.setattr( @@ -85,7 +86,10 @@ def _fake_set_cache(self, skills): asyncio.run(computer_client._sync_skills_to_sandbox(cast(ComputerBooter, booter))) assert booter.uploads == [] - assert any(cmd == "rm -f skills/skills.zip" for cmd in booter.shell.commands) + assert any( + cmd.startswith("rm -f skills/skills_bundle_") and cmd.endswith(".zip") + for cmd in booter.shell.commands + ) assert captured["skills"] == [ { "name": "python-sandbox", @@ -108,7 +112,8 @@ def test_sync_skills_uses_managed_strategy_instead_of_wiping_all( captured = {"skills": None} - def _fake_set_cache(self, skills): + def _fake_set_cache(self, skills, provider_id=None): + _ = provider_id captured["skills"] = skills monkeypatch.setattr( @@ -130,7 +135,8 @@ def _fake_set_cache(self, skills): asyncio.run(computer_client._sync_skills_to_sandbox(cast(ComputerBooter, booter))) assert len(booter.uploads) == 1 - assert booter.uploads[0][1] == "skills/skills.zip" + assert booter.uploads[0][1].startswith("skills/skills_bundle_") + assert booter.uploads[0][1].endswith(".zip") assert not any( "find skills -mindepth 1 -delete" in cmd for cmd in booter.shell.commands ) @@ -158,7 +164,8 @@ def test_sync_skills_includes_plugin_provided_skills( captured = {"skills": None} - def _fake_set_cache(self, skills): + def _fake_set_cache(self, skills, provider_id=None): + _ = provider_id captured["skills"] = skills monkeypatch.setattr( @@ -184,7 +191,8 @@ def _fake_set_cache(self, skills): asyncio.run(computer_client._sync_skills_to_sandbox(cast(ComputerBooter, booter))) assert len(booter.uploads) == 1 - assert booter.uploads[0][1] == "skills/skills.zip" + assert booter.uploads[0][1].startswith("skills/skills_bundle_") + assert booter.uploads[0][1].endswith(".zip") assert captured["skills"] == [ { "name": "demo-skill", @@ -194,6 +202,61 @@ def _fake_set_cache(self, skills): ] +def test_sync_skills_uses_unique_temp_zip_per_concurrent_sync( + monkeypatch, + tmp_path: Path, +): + skills_root = tmp_path / "skills" + temp_root = tmp_path / "temp" + skills_root.mkdir(parents=True, exist_ok=True) + temp_root.mkdir(parents=True, exist_ok=True) + skill_dir = skills_root / "custom-agent-skill" + skill_dir.mkdir(parents=True, exist_ok=True) + skill_dir.joinpath("SKILL.md").write_text("# demo", encoding="utf-8") + + captured: list[tuple[str, str]] = [] + + def _fake_set_cache(self, skills, provider_id=None): + _ = skills, provider_id + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.get_astrbot_skills_path", + lambda: str(skills_root), + ) + monkeypatch.setattr( + "astrbot.core.computer.computer_client.get_astrbot_temp_path", + lambda: str(temp_root), + ) + monkeypatch.setattr( + "astrbot.core.computer.computer_client.SkillManager.set_sandbox_skills_cache", + _fake_set_cache, + ) + + class _ConcurrentBooter(_FakeBooter): + async def upload_file(self, path: str, file_name: str) -> dict: + captured.append((path, file_name)) + return await super().upload_file(path, file_name) + + booter_a = _ConcurrentBooter( + '{"skills":[{"name":"custom-agent-skill","description":"","path":"skills/custom-agent-skill/SKILL.md"}]}' + ) + booter_b = _ConcurrentBooter( + '{"skills":[{"name":"custom-agent-skill","description":"","path":"skills/custom-agent-skill/SKILL.md"}]}' + ) + + async def _run_concurrent_syncs(): + await asyncio.gather( + computer_client._sync_skills_to_sandbox(cast(ComputerBooter, booter_a)), + computer_client._sync_skills_to_sandbox(cast(ComputerBooter, booter_b)), + ) + + asyncio.run(_run_concurrent_syncs()) + + assert len(captured) == 2 + assert captured[0][0] != captured[1][0] + assert captured[0][1] != captured[1][1] + + def test_build_scan_command_frontmatter_newline_is_escaped_literal(): command = computer_client._build_scan_command() script = _extract_embedded_python(command) diff --git a/tests/test_computer_tool_permissions.py b/tests/test_computer_tool_permissions.py deleted file mode 100644 index 07f7983da3..0000000000 --- a/tests/test_computer_tool_permissions.py +++ /dev/null @@ -1,100 +0,0 @@ -import json -from types import SimpleNamespace - -import pytest - -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.tools.computer_tools.shipyard_neo.browser import BrowserExecTool -from astrbot.core.tools.computer_tools.shipyard_neo.neo_skills import ( - GetExecutionHistoryTool, -) - - -class _FakeBrowser: - async def exec(self, **kwargs): - return { - "ok": True, - "cmd": kwargs["cmd"], - } - - -class _FakeSandbox: - async def get_execution_history(self, **kwargs): - return { - "items": [], - "limit": kwargs["limit"], - } - - -def _make_run_context(require_admin: bool, role: str = "member") -> ContextWrapper: - config_holder = SimpleNamespace( - get_config=lambda umo: { # noqa: ARG005 - "provider_settings": { - "computer_use_require_admin": require_admin, - } - } - ) - event = SimpleNamespace( - role=role, - unified_msg_origin="qq_official:friend:user-1", - get_sender_id=lambda: "user-1", - ) - astr_ctx = SimpleNamespace(context=config_holder, event=event) - return ContextWrapper(context=astr_ctx) - - -@pytest.mark.asyncio -async def test_browser_tool_allows_non_admin_when_admin_requirement_disabled( - monkeypatch, -): - async def _fake_get_booter(_ctx, _session_id): - return SimpleNamespace(browser=_FakeBrowser()) - - monkeypatch.setattr( - "astrbot.core.tools.computer_tools.shipyard_neo.browser.get_booter", - _fake_get_booter, - ) - - result = await BrowserExecTool().call( - _make_run_context(require_admin=False), - cmd="open https://example.com", - ) - - assert json.loads(result)["ok"] is True - - -@pytest.mark.asyncio -async def test_neo_skill_tool_allows_non_admin_when_admin_requirement_disabled( - monkeypatch, -): - async def _fake_get_booter(_ctx, _session_id): - return SimpleNamespace( - bay_client=object(), - sandbox=_FakeSandbox(), - ) - - monkeypatch.setattr( - "astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.get_booter", - _fake_get_booter, - ) - - result = await GetExecutionHistoryTool().call( - _make_run_context(require_admin=False), - limit=5, - ) - - payload = json.loads(result) - assert payload["items"] == [] - assert payload["limit"] == 5 - - -@pytest.mark.asyncio -async def test_browser_tool_still_denies_non_admin_when_admin_requirement_enabled(): - result = await BrowserExecTool().call( - _make_run_context(require_admin=True), - cmd="open https://example.com", - ) - - assert "Permission denied" in result - assert "Using browser tools is only allowed for admin users" in result - assert "User's ID is: user-1" in result diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py index 596cdb25d8..984b16e176 100644 --- a/tests/test_dashboard.py +++ b/tests/test_dashboard.py @@ -4,12 +4,12 @@ import os import re import shutil -import sys import uuid import zipfile from datetime import datetime from pathlib import Path from types import SimpleNamespace +from unittest.mock import AsyncMock from urllib.parse import parse_qs, urlsplit, urlunsplit import pyotp @@ -49,6 +49,28 @@ create_mock_updater_update, ) + +class FakeSandboxProvider: + provider_id = "dashboard-generic" + capabilities = {"shell", "filesystem"} + tool_names = {"dashboard_generic_tool"} + + def build_create_config(self, context, session_id): + return {} + + def build_connect_info(self, sandbox_name, config): + return {"name": sandbox_name} + + def update_connect_info(self, record, *, sandbox_name): + return {"name": sandbox_name} + + async def create_booter(self, context, session_id, sandbox_id, config): + return SimpleNamespace(available=lambda: True, shutdown=lambda: None) + + async def destroy_booter(self, booter, record): + return None + + _TEST_DASHBOARD_PASSWORD = "AstrbotTest123" PLUGIN_PAGE_DEMO_NAME = "astrbot_plugin_page_demo" PLUGIN_PAGE_DEMO_PAGE_NAME = "bridge-demo" @@ -201,137 +223,1085 @@ async def core_lifecycle_td(tmp_path_factory): pass -@pytest.fixture(scope="module") -def app(core_lifecycle_td: AstrBotCoreLifecycle): - """Creates a Quart app instance for testing.""" - shutdown_event = asyncio.Event() - # The db instance is already part of the core_lifecycle_td - server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event) - server.app._dashboard_server = server # expose for test cleanup - return server.app +@pytest.fixture(scope="module") +def app(core_lifecycle_td: AstrBotCoreLifecycle): + """Creates a Quart app instance for testing.""" + shutdown_event = asyncio.Event() + # The db instance is already part of the core_lifecycle_td + server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event) + server.app._dashboard_server = server # expose for test cleanup + return server.app + + +def _resolve_dashboard_password(core_lifecycle_td: AstrBotCoreLifecycle) -> str: + """Return a login password compatible with both hashed and plain defaults.""" + generated_password = getattr(core_lifecycle_td, "_dashboard_plain_password", None) + if generated_password: + return generated_password + password = core_lifecycle_td.astrbot_config["dashboard"]["pbkdf2_password"] + if isinstance(password, str) and password.startswith("pbkdf2_sha256$"): + return "astrbot" + return password + + +def test_dashboard_uses_bundled_dist_when_data_dist_is_stale( + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch, + tmp_path, +): + data_dir = tmp_path / "data" + user_dist = data_dir / "dist" + bundled_dist = tmp_path / "bundled-dist" + user_dist.mkdir(parents=True) + bundled_dist.mkdir() + + monkeypatch.setattr( + "astrbot.dashboard.server.get_astrbot_data_path", + lambda: str(data_dir), + ) + monkeypatch.setattr( + "astrbot.dashboard.server.get_bundled_dashboard_dist_path", + lambda: bundled_dist, + ) + monkeypatch.setattr( + "astrbot.dashboard.server.should_use_bundled_dashboard_dist", + lambda *_args, **_kwargs: True, + ) + + shutdown_event = asyncio.Event() + server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event) + + assert server.data_path == str(bundled_dist) + + +async def _set_dashboard_password_change_required( + core_lifecycle_td: AstrBotCoreLifecycle, + required: bool, +) -> None: + await set_password_change_required( + core_lifecycle_td.db, + core_lifecycle_td.astrbot_config, + required, + ) + + +async def _restore_dashboard_password_state( + core_lifecycle_td: AstrBotCoreLifecycle, + dashboard_config: dict, +) -> None: + core_lifecycle_td.astrbot_config["dashboard"] = dashboard_config + await set_password_change_required( + core_lifecycle_td.db, + core_lifecycle_td.astrbot_config, + False, + ) + await set_password_storage_upgraded( + core_lifecycle_td.db, + core_lifecycle_td.astrbot_config, + bool(dashboard_config.get("pbkdf2_password")), + ) + + +@pytest_asyncio.fixture(scope="module") +async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): + """Handles login and returns an authenticated header.""" + test_client = app.test_client() + response = await test_client.post( + "/api/auth/login", + json={ + "username": core_lifecycle_td.astrbot_config["dashboard"]["username"], + "password": _resolve_dashboard_password(core_lifecycle_td), + }, + ) + data = await response.get_json() + assert data["status"] == "ok" + token = data["data"]["token"] + return {"Authorization": f"Bearer {token}"} + + +@pytest.mark.asyncio +async def test_auth_login( + app: Quart, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch: pytest.MonkeyPatch, +): + """Tests the login functionality with both wrong and correct credentials.""" + monkeypatch.setitem(app.config, "DASHBOARD_JWT_COOKIE_SECURE", False) + + test_client = app.test_client() + response = await test_client.post( + "/api/auth/login", + json={"username": "wrong", "password": "password"}, + ) + data = await response.get_json() + assert data["status"] == "error" + + response = await test_client.post( + "/api/auth/login", + json={ + "username": core_lifecycle_td.astrbot_config["dashboard"]["username"], + "password": _resolve_dashboard_password(core_lifecycle_td), + }, + ) + data = await response.get_json() + assert data["status"] == "ok" and "token" in data["data"] + set_cookie_headers = response.headers.getlist("Set-Cookie") + jwt_cookie_header = next( + (value for value in set_cookie_headers if DASHBOARD_JWT_COOKIE_NAME in value), + "", + ) + assert jwt_cookie_header + assert "HttpOnly" in jwt_cookie_header + assert "SameSite=Strict" in jwt_cookie_header + assert "Secure" not in jwt_cookie_header + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_lists_generic_providers( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager(registry=SandboxRegistry(), providers={}) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + computer_client.register_sandbox_provider(provider) + + test_client = app.test_client() + response = await test_client.get( + "/api/sandbox/providers", headers=authenticated_header + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["providers"] == [ + { + "provider_id": "dashboard-generic", + "capabilities": ["filesystem", "shell"], + "tool_names": ["dashboard_generic_tool"], + "system_prompt": "", + } + ] + assert data["data"]["default_provider_id"] == "" + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_provider_list_includes_configured_default( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, + core_lifecycle_td: AstrBotCoreLifecycle, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager(registry=SandboxRegistry(), providers={}) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + computer_client.register_sandbox_provider(provider) + monkeypatch.setattr( + core_lifecycle_td.star_context, + "get_config", + lambda umo=None: { + "provider_settings": { + "sandbox": {"booter": provider.provider_id}, + } + }, + ) + + test_client = app.test_client() + response = await test_client.get( + "/api/sandbox/providers", headers=authenticated_header + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["default_provider_id"] == provider.provider_id + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_provider_list_omits_disabled_plugins( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager(registry=SandboxRegistry(), providers={}) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + test_client = app.test_client() + response = await test_client.get( + "/api/sandbox/providers", headers=authenticated_header + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["providers"] == [] + + +@pytest.mark.asyncio +async def test_config_metadata_includes_registered_sandbox_providers( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager(registry=SandboxRegistry(), providers={}) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + computer_client.register_sandbox_provider(provider) + + test_client = app.test_client() + response = await test_client.get( + "/api/config/abconf?id=default", headers=authenticated_header + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + metadata_text = str(data["data"]["metadata"]) + assert "provider_settings.sandbox.booter" in metadata_text + assert "dashboard-generic" in metadata_text + + +@pytest.mark.asyncio +async def test_config_abconf_clears_unavailable_sandbox_booter_for_display( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, + core_lifecycle_td: AstrBotCoreLifecycle, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager(registry=SandboxRegistry(), providers={}) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + original_booter = core_lifecycle_td.astrbot_config["provider_settings"][ + "sandbox" + ].get("booter") + core_lifecycle_td.astrbot_config["provider_settings"]["sandbox"]["booter"] = ( + "shipyard" + ) + + try: + test_client = app.test_client() + response = await test_client.get( + "/api/config/abconf?id=default", headers=authenticated_header + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["config"]["provider_settings"]["sandbox"]["booter"] == "" + finally: + core_lifecycle_td.astrbot_config["provider_settings"]["sandbox"]["booter"] = ( + original_booter + ) + + +@pytest.mark.asyncio +async def test_config_save_preserves_unavailable_sandbox_booter( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, + core_lifecycle_td: AstrBotCoreLifecycle, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager(registry=SandboxRegistry(), providers={}) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + original_config = copy.deepcopy(dict(core_lifecycle_td.astrbot_config)) + post_config = copy.deepcopy(original_config) + post_config["provider_settings"]["computer_use_runtime"] = "sandbox" + post_config["provider_settings"]["sandbox"]["booter"] = "shipyard" + + try: + test_client = app.test_client() + response = await test_client.post( + "/api/config/astrbot/update", + headers=authenticated_header, + json={"conf_id": "default", "config": post_config}, + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert ( + core_lifecycle_td.astrbot_config["provider_settings"]["sandbox"]["booter"] + == "shipyard" + ) + finally: + core_lifecycle_td.astrbot_config.save_config(original_config) + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_lists_managed_sandboxes( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Sandbox 1"}, + ) + + test_client = app.test_client() + response = await test_client.get("/api/sandbox", headers=authenticated_header) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["sandboxes"][0]["sandbox_id"] == "sandbox-1" + assert data["data"]["sandboxes"][0]["capabilities"] == [ + "filesystem", + "shell", + ] + assert data["data"]["sandboxes"][0]["tool_names"] == [ + "dashboard_generic_tool", + ] + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_create_does_not_auto_occupy_sandbox( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + test_client = app.test_client() + response = await test_client.post( + "/api/sandbox?session_id=dashboard", + json={"provider_id": provider.provider_id, "sandbox_name": "Named"}, + headers=authenticated_header, + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["sandbox"]["sandbox_name"] == "Named" + assert data["data"]["sandbox"]["status"] == "creating" + assert data["data"]["sandbox"]["controller_session_id"] is None + assert manager.get_current_sandbox("dashboard")["current_sandbox_id"] is None + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_blocks_mutations_in_demo_mode( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + shell = SimpleNamespace(exec=AsyncMock(return_value={"stdout": "ran"})) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="dashboard", + owner_session_id="dashboard", + connect_info={"name": "Sandbox 1"}, + status="running", + ) + manager.session_booter["sandbox-1"] = SimpleNamespace( + shell=shell, + available=lambda: True, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr( + "astrbot.dashboard.routes.sandbox_helpers.DEMO_MODE", True, raising=False + ) + + test_client = app.test_client() + create_response = await test_client.post( + "/api/sandbox?session_id=dashboard", + json={"provider_id": provider.provider_id, "sandbox_name": "Blocked"}, + headers=authenticated_header, + ) + shell_response = await test_client.post( + "/api/sandbox/sandbox-1/shell?session_id=dashboard", + json={"command": "echo blocked"}, + headers=authenticated_header, + ) + + for task in list(manager.pending_boot_tasks.values()): + task.cancel() + if manager.pending_boot_tasks: + await asyncio.gather( + *manager.pending_boot_tasks.values(), return_exceptions=True + ) + + create_data = await create_response.get_json() + shell_data = await shell_response.get_json() + + assert create_response.status_code == 200 + assert shell_response.status_code == 200 + assert create_data["status"] == "error" + assert shell_data["status"] == "error" + assert "demo mode" in create_data["message"] + assert "demo mode" in shell_data["message"] + assert [sandbox["sandbox_id"] for sandbox in manager.list_sandboxes()] == [ + "sandbox-1" + ] + shell.exec.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_create_rejects_cross_origin_cookie_auth( + app: Quart, + core_lifecycle_td: AstrBotCoreLifecycle, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + test_client = app.test_client() + login_response = await test_client.post( + "/api/auth/login", + json={ + "username": core_lifecycle_td.astrbot_config["dashboard"]["username"], + "password": _resolve_dashboard_password(core_lifecycle_td), + }, + ) + assert login_response.status_code == 200 + + response = await test_client.post( + "/api/sandbox?session_id=dashboard", + json={"provider_id": provider.provider_id, "sandbox_name": "Named"}, + headers={"Origin": "http://evil.localhost:3000"}, + ) + data = await response.get_json() + + assert response.status_code == 403 + assert data["status"] == "error" + assert "Origin" in data["message"] + assert manager.list_sandboxes() == [] + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_create_allows_cross_origin_authorization_header( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + test_client = app.test_client() + response = await test_client.post( + "/api/sandbox?session_id=dashboard", + json={"provider_id": provider.provider_id, "sandbox_name": "Named"}, + headers={ + **authenticated_header, + "Origin": "http://api-client.example", + }, + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["sandbox"]["sandbox_name"] == "Named" + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_create_rejects_duplicate_name( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + logged_errors = [] + monkeypatch.setattr( + "astrbot.dashboard.routes.sandbox.logger.error", + lambda *args, **kwargs: logged_errors.append((args, kwargs)), + ) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Named", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Named"}, + ) + + test_client = app.test_client() + response = await test_client.post( + "/api/sandbox?session_id=dashboard", + json={"provider_id": provider.provider_id, "sandbox_name": "Named"}, + headers=authenticated_header, + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "error" + assert data["message"] == "Sandbox name 'Named' already exists" + assert logged_errors == [] + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_create_reports_max_sandbox_limit( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, + core_lifecycle_td: AstrBotCoreLifecycle, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="dashboard", + owner_session_id="dashboard", + connect_info={"name": "Sandbox 1"}, + ) + monkeypatch.setattr( + core_lifecycle_td.star_context, + "get_config", + lambda umo=None: { + "provider_settings": { + "sandbox": {"max_sandboxes": 1}, + } + }, + ) + + test_client = app.test_client() + response = await test_client.post( + "/api/sandbox?session_id=dashboard", + json={"provider_id": provider.provider_id, "sandbox_name": "Second"}, + headers=authenticated_header, + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "error" + assert "Sandbox limit reached" in data["message"] + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_sets_default_sandbox( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + for sandbox_id in ("sandbox-1", "sandbox-2"): + manager.registry.upsert_sandbox( + sandbox_id=sandbox_id, + sandbox_name=sandbox_id, + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": sandbox_id}, + ) + + test_client = app.test_client() + response = await test_client.post( + "/api/sandbox/sandbox-2/default", headers=authenticated_header + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["sandbox"]["sandbox_id"] == "sandbox-2" + assert data["data"]["sandbox"]["is_default"] is True + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_patch_preserves_existing_retention_policy( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + provider.supports_persistent_reconnect = True + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Sandbox 1"}, + retention_policy="persistent", + status="running", + ) + + test_client = app.test_client() + response = await test_client.patch( + "/api/sandbox/sandbox-1", + json={"sandbox_name": "Renamed"}, + headers=authenticated_header, + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["sandbox"]["sandbox_name"] == "Renamed" + assert data["data"]["sandbox"]["retention_policy"] == "persistent" + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_patch_name_preserves_temporary_lifecycle_fields( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Sandbox 1"}, + retention_policy="temporary", + idle_timeout=0, + expires_at=1234567890.0, + status="running", + ) + + test_client = app.test_client() + response = await test_client.patch( + "/api/sandbox/sandbox-1", + json={"sandbox_name": "Renamed"}, + headers=authenticated_header, + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["sandbox"]["sandbox_name"] == "Renamed" + assert data["data"]["sandbox"]["retention_policy"] == "temporary" + assert data["data"]["sandbox"]["idle_timeout"] == 0 + assert data["data"]["sandbox"]["expires_at"] == 1234567890.0 + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_force_releases_busy_sandbox( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + controller_user_id="webchat", + controller_session_id="webchat:friend:user", + lease_expires_at=9999999999, + connect_info={"name": "Sandbox 1"}, + ) + + test_client = app.test_client() + response = await test_client.delete( + "/api/sandbox/current?session_id=dashboard&sandbox_id=sandbox-1", + headers=authenticated_header, + ) + data = await response.get_json() + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["sandbox"]["controller_session_id"] is None + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_runs_shell_in_managed_sandbox( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + class FakeShell: + async def exec(self, command, cwd=None, env=None, timeout=300, shell=True): + return { + "command": command, + "cwd": cwd, + "env": env, + "timeout": timeout, + "shell": shell, + "stdout": "ok\n", + "stderr": "", + "exit_code": 0, + } + + async def available(): + return True + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Sandbox 1"}, + ) + manager.session_booter["sandbox-1"] = SimpleNamespace( + available=available, shell=FakeShell() + ) + test_client = app.test_client() + response = await test_client.post( + "/api/sandbox/sandbox-1/shell", + json={"command": "pwd", "cwd": "/workspace", "timeout": 5}, + headers=authenticated_header, + ) + data = await response.get_json() -def _resolve_dashboard_password(core_lifecycle_td: AstrBotCoreLifecycle) -> str: - """Return a login password compatible with both hashed and plain defaults.""" - generated_password = getattr(core_lifecycle_td, "_dashboard_plain_password", None) - if generated_password: - return generated_password - password = core_lifecycle_td.astrbot_config["dashboard"]["pbkdf2_password"] - if isinstance(password, str) and password.startswith("pbkdf2_sha256$"): - return "astrbot" - return password + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["result"]["command"] == "pwd" + assert data["data"]["result"]["cwd"] == "/workspace" + assert data["data"]["result"]["timeout"] == 5 -def test_dashboard_uses_bundled_dist_when_data_dist_is_stale( - core_lifecycle_td: AstrBotCoreLifecycle, - monkeypatch, - tmp_path, +@pytest.mark.asyncio +async def test_sandbox_dashboard_shell_uses_default_timeout_for_invalid_value( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, ): - data_dir = tmp_path / "data" - user_dist = data_dir / "dist" - bundled_dist = tmp_path / "bundled-dist" - user_dist.mkdir(parents=True) - bundled_dist.mkdir() + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry - monkeypatch.setattr( - "astrbot.dashboard.server.get_astrbot_data_path", - lambda: str(data_dir), + class FakeShell: + async def exec(self, command, cwd=None, env=None, timeout=300, shell=True): + return {"timeout": timeout, "stdout": "ok\n", "stderr": "", "exit_code": 0} + + async def available(): + return True + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} ) - monkeypatch.setattr( - "astrbot.dashboard.server.get_bundled_dashboard_dist_path", - lambda: bundled_dist, + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Sandbox 1"}, ) - monkeypatch.setattr( - "astrbot.dashboard.server.should_use_bundled_dashboard_dist", - lambda *_args, **_kwargs: True, + manager.session_booter["sandbox-1"] = SimpleNamespace( + available=available, shell=FakeShell() ) - shutdown_event = asyncio.Event() - server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event) + test_client = app.test_client() + response = await test_client.post( + "/api/sandbox/sandbox-1/shell", + json={"command": "pwd", "timeout": "not-a-number"}, + headers=authenticated_header, + ) + data = await response.get_json() - assert server.data_path == str(bundled_dist) + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["result"]["timeout"] == 300 -async def _set_dashboard_password_change_required( - core_lifecycle_td: AstrBotCoreLifecycle, - required: bool, -) -> None: - await set_password_change_required( - core_lifecycle_td.db, - core_lifecycle_td.astrbot_config, - required, - ) +@pytest.mark.asyncio +async def test_sandbox_dashboard_shell_bypasses_lease_for_admin_access( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + """Dashboard shell is an administrative operation and must bypass lease.""" + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + class FakeShell: + async def exec(self, command, cwd=None, env=None, timeout=300, shell=True): + return { + "command": command, + "stdout": "ok\n", + "stderr": "", + "exit_code": 0, + } + async def available(): + return True -async def _restore_dashboard_password_state( - core_lifecycle_td: AstrBotCoreLifecycle, - dashboard_config: dict, -) -> None: - core_lifecycle_td.astrbot_config["dashboard"] = dashboard_config - await set_password_change_required( - core_lifecycle_td.db, - core_lifecycle_td.astrbot_config, - False, + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} ) - await set_password_storage_upgraded( - core_lifecycle_td.db, - core_lifecycle_td.astrbot_config, - bool(dashboard_config.get("pbkdf2_password")), + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + controller_user_id="webchat", + controller_session_id="webchat:friend:user", + lease_expires_at=9999999999, + connect_info={"name": "Sandbox 1"}, + ) + manager.session_booter["sandbox-1"] = SimpleNamespace( + available=available, shell=FakeShell() ) - -@pytest_asyncio.fixture(scope="module") -async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): - """Handles login and returns an authenticated header.""" test_client = app.test_client() response = await test_client.post( - "/api/auth/login", - json={ - "username": core_lifecycle_td.astrbot_config["dashboard"]["username"], - "password": _resolve_dashboard_password(core_lifecycle_td), - }, + "/api/sandbox/sandbox-1/shell?session_id=dashboard", + json={"command": "pwd"}, + headers=authenticated_header, ) data = await response.get_json() + + assert response.status_code == 200 assert data["status"] == "ok" - token = data["data"]["token"] - return {"Authorization": f"Bearer {token}"} + assert data["data"]["result"]["stdout"] == "ok\n" @pytest.mark.asyncio -async def test_auth_login( +async def test_sandbox_dashboard_captures_managed_sandbox_screenshot( app: Quart, - core_lifecycle_td: AstrBotCoreLifecycle, + authenticated_header: dict, monkeypatch: pytest.MonkeyPatch, ): - """Tests the login functionality with both wrong and correct credentials.""" - monkeypatch.setitem(app.config, "DASHBOARD_JWT_COOKIE_SECURE", False) + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + class FakeGui: + async def screenshot(self, path=None): + return {"mime_type": "image/png", "base64": "abc", "path": path} + + async def available(): + return True + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Sandbox 1"}, + ) + manager.session_booter["sandbox-1"] = SimpleNamespace( + available=available, gui=FakeGui() + ) test_client = app.test_client() response = await test_client.post( - "/api/auth/login", - json={"username": "wrong", "password": "password"}, + "/api/sandbox/sandbox-1/screenshot", + json={"path": "/tmp/screen.png"}, + headers=authenticated_header, ) data = await response.get_json() - assert data["status"] == "error" + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["screenshot"] == { + "mime_type": "image/png", + "base64": "abc", + "path": "/tmp/screen.png", + } + + +@pytest.mark.asyncio +async def test_sandbox_dashboard_screenshot_bypasses_lease_for_monitoring( + app: Quart, + authenticated_header: dict, + monkeypatch: pytest.MonkeyPatch, +): + """Dashboard screenshot is read-only observer access and must not need a lease.""" + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + class FakeGui: + async def screenshot(self, path=None): + return {"mime_type": "image/png", "base64": "abc", "path": path} + + async def available(): + return True + + provider = FakeSandboxProvider() + manager = SandboxManager( + registry=SandboxRegistry(), providers={provider.provider_id: provider} + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="sandbox-1", + sandbox_name="Sandbox 1", + provider=provider.provider_id, + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + controller_user_id="webchat", + controller_session_id="webchat:friend:user", + lease_expires_at=9999999999, + connect_info={"name": "Sandbox 1"}, + ) + manager.session_booter["sandbox-1"] = SimpleNamespace( + available=available, gui=FakeGui() + ) + + test_client = app.test_client() response = await test_client.post( - "/api/auth/login", - json={ - "username": core_lifecycle_td.astrbot_config["dashboard"]["username"], - "password": _resolve_dashboard_password(core_lifecycle_td), - }, + "/api/sandbox/sandbox-1/screenshot?session_id=dashboard", + json={"path": "/tmp/screen.png"}, + headers=authenticated_header, ) data = await response.get_json() - assert data["status"] == "ok" and "token" in data["data"] - set_cookie_headers = response.headers.getlist("Set-Cookie") - jwt_cookie_header = next( - (value for value in set_cookie_headers if DASHBOARD_JWT_COOKIE_NAME in value), - "", - ) - assert jwt_cookie_header - assert "HttpOnly" in jwt_cookie_header - assert "SameSite=Strict" in jwt_cookie_header - assert "Secure" not in jwt_cookie_header + + assert response.status_code == 200 + assert data["status"] == "ok" + assert data["data"]["screenshot"]["base64"] == "abc" @pytest.mark.asyncio @@ -875,6 +1845,26 @@ async def test_config_save_rejects_recovery_code_for_protected_totp_changes( ) +@pytest.mark.asyncio +async def test_validate_neo_connectivity_noops_for_plugin_managed_provider_config(): + from astrbot.dashboard.routes.config import _validate_neo_connectivity + + warning = await _validate_neo_connectivity( + { + "provider_settings": { + "computer_use_runtime": "sandbox", + "sandbox": { + "booter": "shipyard_neo", + "shipyard_neo_endpoint": "http://127.0.0.1:65535", + "shipyard_neo_access_token": "", + }, + } + } + ) + + assert warning is None + + @pytest.mark.asyncio async def test_auth_totp_setup_with_valid_code_returns_recovery_code( app: Quart, @@ -2552,185 +3542,13 @@ async def mock_pip_install(*args, **kwargs): assert data["message"] == "install failed" -class _FakeNeoSkills: - async def list_candidates(self, **kwargs): - _ = kwargs - return [ - { - "id": "cand-1", - "skill_key": "neo.demo", - "status": "evaluated_pass", - "payload_ref": "pref-1", - } - ] - - async def list_releases(self, **kwargs): - _ = kwargs - return [ - { - "id": "rel-1", - "skill_key": "neo.demo", - "candidate_id": "cand-1", - "stage": "stable", - "active": True, - } - ] - - async def get_payload(self, payload_ref: str): - return { - "payload_ref": payload_ref, - "payload": {"skill_markdown": "# Demo"}, - } - - async def evaluate_candidate(self, candidate_id: str, **kwargs): - return {"candidate_id": candidate_id, **kwargs} - - async def promote_candidate(self, candidate_id: str, stage: str = "canary"): - return { - "id": "rel-2", - "skill_key": "neo.demo", - "candidate_id": candidate_id, - "stage": stage, - } - - async def rollback_release(self, release_id: str): - return {"id": "rb-1", "rolled_back_release_id": release_id} - - -class _FakeNeoBayClient: - def __init__(self, endpoint_url: str, access_token: str): - self.endpoint_url = endpoint_url - self.access_token = access_token - self.skills = _FakeNeoSkills() - - async def __aenter__(self): - return self - - async def __aexit__(self, exc_type, exc, tb): - _ = exc_type, exc, tb - return False - - @pytest.mark.asyncio -async def test_neo_skills_routes( +async def test_core_dashboard_does_not_ship_neo_skill_routes( app: Quart, - authenticated_header: dict, - core_lifecycle_td: AstrBotCoreLifecycle, - monkeypatch, ): - provider_settings = core_lifecycle_td.astrbot_config.setdefault( - "provider_settings", {} - ) - sandbox = provider_settings.setdefault("sandbox", {}) - sandbox["shipyard_neo_endpoint"] = "http://neo.test" - sandbox["shipyard_neo_access_token"] = "neo-token" - - fake_shipyard_neo_module = SimpleNamespace(BayClient=_FakeNeoBayClient) - monkeypatch.setitem(sys.modules, "shipyard_neo", fake_shipyard_neo_module) - - async def _fake_sync_release(self, client, **kwargs): - _ = self, client, kwargs - return SimpleNamespace( - skill_key="neo.demo", - local_skill_name="neo_demo", - release_id="rel-2", - candidate_id="cand-1", - payload_ref="pref-1", - map_path="data/skills/neo_skill_map.json", - synced_at="2026-01-01T00:00:00Z", - ) - - async def _fake_sync_skills_to_active_sandboxes(): - return - - monkeypatch.setattr( - "astrbot.dashboard.routes.skills.NeoSkillSyncManager.sync_release", - _fake_sync_release, - ) - monkeypatch.setattr( - "astrbot.dashboard.routes.skills.sync_skills_to_active_sandboxes", - _fake_sync_skills_to_active_sandboxes, - ) - - test_client = app.test_client() - - response = await test_client.get( - "/api/skills/neo/candidates", headers=authenticated_header - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - assert isinstance(data["data"], list) - assert data["data"][0]["id"] == "cand-1" - - response = await test_client.get( - "/api/skills/neo/releases", headers=authenticated_header - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - assert isinstance(data["data"], list) - assert data["data"][0]["id"] == "rel-1" - - response = await test_client.get( - "/api/skills/neo/payload?payload_ref=pref-1", headers=authenticated_header - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - assert data["data"]["payload_ref"] == "pref-1" - - response = await test_client.post( - "/api/skills/neo/evaluate", - json={"candidate_id": "cand-1", "passed": True, "score": 0.95}, - headers=authenticated_header, - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - assert data["data"]["candidate_id"] == "cand-1" - assert data["data"]["passed"] is True - - response = await test_client.post( - "/api/skills/neo/evaluate", - json={"candidate_id": "cand-1", "passed": "false", "score": 0.0}, - headers=authenticated_header, - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - assert data["data"]["passed"] is False - - response = await test_client.post( - "/api/skills/neo/promote", - json={"candidate_id": "cand-1", "stage": "stable"}, - headers=authenticated_header, - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - assert data["data"]["release"]["id"] == "rel-2" - assert data["data"]["sync"]["local_skill_name"] == "neo_demo" - - response = await test_client.post( - "/api/skills/neo/rollback", - json={"release_id": "rel-2"}, - headers=authenticated_header, - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - assert data["data"]["rolled_back_release_id"] == "rel-2" - - response = await test_client.post( - "/api/skills/neo/sync", - json={"release_id": "rel-2"}, - headers=authenticated_header, - ) - assert response.status_code == 200 - data = await response.get_json() - assert data["status"] == "ok" - assert data["data"]["skill_key"] == "neo.demo" + assert "/api/skills/neo/candidates" not in { + rule.rule for rule in app.url_map.iter_rules() + } @pytest.mark.asyncio diff --git a/tests/test_neo_skill_tools.py b/tests/test_neo_skill_tools.py deleted file mode 100644 index 076da00945..0000000000 --- a/tests/test_neo_skill_tools.py +++ /dev/null @@ -1,88 +0,0 @@ -from __future__ import annotations - -import asyncio -from types import SimpleNamespace - -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.tools.computer_tools.shipyard_neo.neo_skills import ( - PromoteSkillCandidateTool, -) - - -class _FakeSkills: - def __init__(self): - self.rollback_called_with = None - - async def promote_candidate(self, candidate_id: str, stage: str = "canary"): - assert candidate_id == "cand-1" - assert stage == "stable" - return { - "id": "sr-1", - "skill_key": "k1", - "candidate_id": candidate_id, - "stage": stage, - } - - async def rollback_release(self, release_id: str): - self.rollback_called_with = release_id - return {"id": "rb-1", "rollback_of": release_id} - - -class _FakeClient: - def __init__(self): - self.skills = _FakeSkills() - - -class _FakeBooter: - def __init__(self): - self.bay_client = _FakeClient() - self.sandbox = object() - - -def test_promote_stable_sync_failure_auto_rolls_back(monkeypatch): - async def _fake_get_booter(_ctx, _session_id): - return _FakeBooter() - - async def _fake_sync_release(self, client, **kwargs): - _ = self, client, kwargs - raise ValueError("sync failed") - - monkeypatch.setattr( - "astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.get_booter", - _fake_get_booter, - ) - monkeypatch.setattr( - "astrbot.core.tools.computer_tools.shipyard_neo.neo_skills.NeoSkillSyncManager.sync_release", - _fake_sync_release, - ) - - event = SimpleNamespace( - role="admin", - unified_msg_origin="session-1", - get_sender_id=lambda: "admin-user", - ) - astr_ctx = SimpleNamespace( - context=SimpleNamespace( - get_config=lambda umo: { # noqa: ARG005 - "provider_settings": { - "computer_use_require_admin": True, - } - } - ), - event=event, - ) - run_ctx = ContextWrapper(context=astr_ctx) - - tool = PromoteSkillCandidateTool() - result = asyncio.run( - tool.call( - run_ctx, - candidate_id="cand-1", - stage="stable", - sync_to_local=True, - ) - ) - - assert isinstance(result, str) - assert "auto rollback succeeded" in result - assert "sync failed" in result diff --git a/tests/test_plugin_manager.py b/tests/test_plugin_manager.py index 632d312999..b12964cdb6 100644 --- a/tests/test_plugin_manager.py +++ b/tests/test_plugin_manager.py @@ -1,14 +1,21 @@ import asyncio import json import os +from dataclasses import dataclass, field from pathlib import Path from typing import Any, cast import pytest import yaml +from astrbot.core.agent.tool import FunctionTool from astrbot.core.star import star_manager as star_manager_module from astrbot.core.star.star_manager import PluginDependencyInstallError, PluginManager +from astrbot.core.tools.registry import ( + builtin_tool, + get_builtin_tool_class, + unregister_builtin_tools_by_module_prefix, +) from astrbot.core.utils.pip_installer import PipInstallError from astrbot.core.utils.requirements_utils import MissingRequirementsPlan @@ -162,6 +169,19 @@ def _clear_star_runtime_state(): star_manager_module.star_handlers_registry.clear() +def _register_plugin_builtin_tool(): + @dataclass + class PluginBuiltinTool(FunctionTool): + name: str = "astrbot_test_plugin_builtin_tool" + description: str = "Plugin builtin tool used to test reload cleanup." + parameters: dict = field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + + PluginBuiltinTool.__module__ = "data.plugins.helloworld.tools.browser" + return builtin_tool(PluginBuiltinTool) + + def _build_load_mock(events): async def mock_load(specified_dir_name=None, ignore_version_check=False): del ignore_version_check @@ -527,6 +547,37 @@ async def mock_load( assert unbound == plugin_names +@pytest.mark.asyncio +async def test_unbind_plugin_unregisters_plugin_builtin_tools( + plugin_manager_pm: PluginManager, +): + _clear_star_runtime_state() + plugin_module_path = "data.plugins.helloworld.main" + metadata = star_manager_module.StarMetadata( + name=TEST_PLUGIN_NAME, + root_dir_name=TEST_PLUGIN_DIR, + module_path=plugin_module_path, + ) + star_manager_module.star_map[plugin_module_path] = metadata + star_manager_module.star_registry.append(metadata) + tool_cls = _register_plugin_builtin_tool() + star_manager_module.llm_tools.get_builtin_tool(tool_cls) + + try: + assert get_builtin_tool_class("astrbot_test_plugin_builtin_tool") is tool_cls + + await plugin_manager_pm._unbind_plugin(TEST_PLUGIN_NAME, plugin_module_path) + + assert get_builtin_tool_class("astrbot_test_plugin_builtin_tool") is None + assert tool_cls not in star_manager_module.llm_tools.builtin_func_list + finally: + unregister_builtin_tools_by_module_prefix("data.plugins.helloworld") + star_manager_module.llm_tools.clear_builtin_tool_cache_by_module_prefix( + "data.plugins.helloworld" + ) + _clear_star_runtime_state() + + @pytest.mark.asyncio async def test_load_reports_unregistered_plugin_without_index_error( plugin_manager_pm: PluginManager, monkeypatch diff --git a/tests/test_profile_aware_tools.py b/tests/test_profile_aware_tools.py deleted file mode 100644 index 86468c3451..0000000000 --- a/tests/test_profile_aware_tools.py +++ /dev/null @@ -1,295 +0,0 @@ -"""Tests for profile-aware sandbox selection and conditional tool registration.""" - -from __future__ import annotations - -from types import SimpleNamespace -from unittest.mock import patch - -import pytest - - -# ═══════════════════════════════════════════════════════════════ -# ShipyardNeoBooter.capabilities -# ═══════════════════════════════════════════════════════════════ - - -class TestShipyardNeoBooterCapabilities: - """Test capabilities property on ShipyardNeoBooter.""" - - def _make_booter(self, sandbox_caps: list[str] | None = None): - from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter - - booter = ShipyardNeoBooter( - endpoint_url="http://localhost:8114", - access_token="sk-bay-test", - ) - if sandbox_caps is not None: - booter._sandbox = SimpleNamespace(capabilities=sandbox_caps) - return booter - - def test_none_before_boot(self): - booter = self._make_booter() - assert booter.capabilities is None - - def test_returns_tuple_after_boot(self): - booter = self._make_booter(["python", "shell", "filesystem"]) - assert booter.capabilities == ("python", "shell", "filesystem") - assert isinstance(booter.capabilities, tuple) - - def test_includes_browser_when_present(self): - booter = self._make_booter(["python", "shell", "filesystem", "browser"]) - assert "browser" in booter.capabilities - - def test_no_browser_when_absent(self): - booter = self._make_booter(["python", "shell", "filesystem"]) - assert "browser" not in booter.capabilities - - def test_returns_immutable(self): - """Verify capabilities returns an immutable tuple.""" - booter = self._make_booter(["python"]) - caps = booter.capabilities - assert isinstance(caps, tuple) - with pytest.raises(AttributeError): - caps.append("mutated") # type: ignore[attr-defined] - - -# ═══════════════════════════════════════════════════════════════ -# _apply_sandbox_tools — conditional browser tool registration -# ═══════════════════════════════════════════════════════════════ - - -def _make_config(booter_type: str = "shipyard_neo"): - return SimpleNamespace( - sandbox_cfg={"booter": booter_type}, - ) - - -def _make_req(): - return SimpleNamespace(func_tool=None, system_prompt="") - - -def _import_apply_sandbox_tools(): - """Import _apply_sandbox_tools, skipping if circular-import fails.""" - try: - from astrbot.core.astr_main_agent import _apply_sandbox_tools - - return _apply_sandbox_tools - except ImportError: - pytest.skip("Cannot import _apply_sandbox_tools (circular import in test env)") - - -class TestApplySandboxToolsConditional: - """Verify browser tools are conditionally registered.""" - - def _tool_names(self, req) -> set[str]: - """Extract tool names from a request's func_tool.""" - if req.func_tool is None: - return set() - return {t.name for t in req.func_tool.tools} - - def test_no_session_registers_all(self): - """First request (no booted session) → all tools including browser.""" - fn = _import_apply_sandbox_tools() - config = _make_config("shipyard_neo") - req = _make_req() - - with patch( - "astrbot.core.computer.computer_client.session_booter", {} - ): - fn(config, req, "session-1") - - names = self._tool_names(req) - assert "astrbot_execute_browser" in names - assert "astrbot_execute_browser_batch" in names - assert "astrbot_run_browser_skill" in names - - def test_with_browser_capability(self): - """Booted session with browser capability → browser tools registered.""" - fn = _import_apply_sandbox_tools() - config = _make_config("shipyard_neo") - req = _make_req() - fake_booter = SimpleNamespace( - capabilities=["python", "shell", "filesystem", "browser"] - ) - - with patch( - "astrbot.core.computer.computer_client.session_booter", - {"session-1": fake_booter}, - ): - fn(config, req, "session-1") - - names = self._tool_names(req) - assert "astrbot_execute_browser" in names - - def test_without_browser_capability(self): - """Booted session WITHOUT browser capability → browser tools NOT registered.""" - fn = _import_apply_sandbox_tools() - config = _make_config("shipyard_neo") - req = _make_req() - fake_booter = SimpleNamespace( - capabilities=["python", "shell", "filesystem"] - ) - - with patch( - "astrbot.core.computer.computer_client.session_booter", - {"session-1": fake_booter}, - ): - fn(config, req, "session-1") - - names = self._tool_names(req) - assert "astrbot_execute_browser" not in names - assert "astrbot_execute_browser_batch" not in names - assert "astrbot_run_browser_skill" not in names - # Skill tools should still be registered - assert "astrbot_get_execution_history" in names - - def test_skill_tools_always_registered(self): - """Skill lifecycle tools are registered regardless of capabilities.""" - fn = _import_apply_sandbox_tools() - config = _make_config("shipyard_neo") - req = _make_req() - fake_booter = SimpleNamespace(capabilities=["python"]) - - with patch( - "astrbot.core.computer.computer_client.session_booter", - {"session-1": fake_booter}, - ): - fn(config, req, "session-1") - - names = self._tool_names(req) - assert "astrbot_create_skill_candidate" in names - assert "astrbot_promote_skill_candidate" in names - - -# ═══════════════════════════════════════════════════════════════ -# _resolve_profile -# ═══════════════════════════════════════════════════════════════ - - -class TestResolveProfile: - """Test smart profile selection logic.""" - - def _make_booter(self, profile: str = ""): - from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter - - return ShipyardNeoBooter( - endpoint_url="http://localhost:8114", - access_token="sk-bay-test", - profile=profile, - ) - - @pytest.mark.asyncio - async def test_user_specified_profile_honoured(self): - """User explicitly sets a non-default profile → use it directly.""" - booter = self._make_booter(profile="browser-python") - client = SimpleNamespace() # list_profiles should NOT be called - result = await booter._resolve_profile(client) - assert result == "browser-python" - - @pytest.mark.asyncio - async def test_user_specified_default_profile_honoured(self): - """User explicitly sets python-default → use it directly.""" - booter = self._make_booter(profile="python-default") - client = SimpleNamespace() # list_profiles should NOT be called - result = await booter._resolve_profile(client) - assert result == "python-default" - - @pytest.mark.asyncio - async def test_selects_browser_profile(self): - """When profile is empty, prefer an available profile with browser.""" - - async def _mock_list_profiles(): - return SimpleNamespace( - items=[ - SimpleNamespace( - id="python-default", - capabilities=["python", "shell", "filesystem"], - ), - SimpleNamespace( - id="browser-python", - capabilities=["python", "shell", "filesystem", "browser"], - ), - ] - ) - - booter = self._make_booter() - client = SimpleNamespace(list_profiles=_mock_list_profiles) - result = await booter._resolve_profile(client) - assert result == "browser-python" - - @pytest.mark.asyncio - async def test_falls_back_to_default_on_api_error(self): - """API error → graceful fallback to python-default.""" - - async def _failing_list_profiles(): - raise ConnectionError("Bay unreachable") - - booter = self._make_booter() - client = SimpleNamespace(list_profiles=_failing_list_profiles) - result = await booter._resolve_profile(client) - assert result == "python-default" - - @pytest.mark.asyncio - async def test_falls_back_on_empty_profiles(self): - """Empty profile list → python-default.""" - - async def _empty_list_profiles(): - return SimpleNamespace(items=[]) - - booter = self._make_booter() - client = SimpleNamespace(list_profiles=_empty_list_profiles) - result = await booter._resolve_profile(client) - assert result == "python-default" - - @pytest.mark.asyncio - async def test_single_profile_selected(self): - """Only one profile available → use it.""" - - async def _single_profile(): - return SimpleNamespace( - items=[ - SimpleNamespace( - id="python-data", - capabilities=["python", "shell", "filesystem"], - ), - ] - ) - - booter = self._make_booter() - client = SimpleNamespace(list_profiles=_single_profile) - result = await booter._resolve_profile(client) - assert result == "python-data" - - @pytest.mark.asyncio - async def test_auth_error_not_silenced(self): - """UnauthorizedError must propagate, not be downgraded to fallback.""" - from shipyard_neo.errors import UnauthorizedError - - async def _unauthorized_list_profiles(): - raise UnauthorizedError("bad token") - - booter = self._make_booter() - client = SimpleNamespace(list_profiles=_unauthorized_list_profiles) - with pytest.raises(UnauthorizedError): - await booter._resolve_profile(client) - - -# ═══════════════════════════════════════════════════════════════ -# ComputerBooter base class -# ═══════════════════════════════════════════════════════════════ - - -class TestBaseComputerBooter: - """Verify base class defaults.""" - - def test_capabilities_default_none(self): - from astrbot.core.computer.booters.base import ComputerBooter - - booter = ComputerBooter() - assert booter.capabilities is None - - def test_browser_default_none(self): - from astrbot.core.computer.booters.base import ComputerBooter - - booter = ComputerBooter() - assert booter.browser is None diff --git a/tests/test_sandbox_frontend_contract.py b/tests/test_sandbox_frontend_contract.py new file mode 100644 index 0000000000..4cb98efc19 --- /dev/null +++ b/tests/test_sandbox_frontend_contract.py @@ -0,0 +1,348 @@ +from pathlib import Path + +ROOT = Path(__file__).resolve().parents[1] + + +def test_sandbox_management_page_exists(): + assert (ROOT / "dashboard/src/views/SandboxManagementPage.vue").is_file() + + +def test_main_routes_include_sandboxes_page(): + content = (ROOT / "dashboard/src/router/MainRoutes.ts").read_text(encoding="utf-8") + + assert "name: 'Sandboxes'" in content + assert "path: '/sandboxes'" in content + assert "SandboxManagementPage.vue" in content + + +def test_sidebar_includes_sandboxes_navigation(): + content = ( + ROOT / "dashboard/src/layouts/full/vertical-sidebar/sidebarItem.ts" + ).read_text(encoding="utf-8") + + assert "core.navigation.sandboxes" in content + assert "to: '/sandboxes'" in content + + +def test_sandbox_management_page_uses_current_sandbox_api_prefix(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "/api/sandboxes" not in content + assert "/api/sandbox" in content + + +def test_sandbox_management_page_does_not_gate_destroy_by_provider_capability(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "hasCapability(item, 'destroy')" not in content + + +def test_sandbox_management_page_disables_destroy_for_occupied_sandboxes(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "case 'destroy':" in content + assert "return status !== 'stopping' && !item.controller_session_id" in content + + +def test_sandbox_management_page_replaces_console_history_after_command_updates(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "consoleHistory.value = [...consoleHistory.value]" in content + + +def test_sandbox_management_page_release_is_not_limited_to_dashboard_controller(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "item.controller_session_id === 'dashboard'" not in content + assert "case 'release':" in content + assert "return status !== 'stopping' && !!item.controller_session_id" in content + + +def test_sandbox_management_page_uses_backend_create_record_without_local_status_guess(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "status: 'creating' as const" not in content + assert "startCreatePolling(created.sandbox_id, created)" in content + assert "upsertSandboxRecord(placeholder)" in content + assert ( + "const index = sandboxes.value.findIndex((item) => item.sandbox_id === record.sandbox_id)" + in content + ) + assert "next[index] = record" in content + assert "sandboxes.value = next" in content + + +def test_sandbox_management_page_keeps_create_button_available_while_other_sandboxes_are_creating(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert 'prepend-icon="mdi-plus" @click="createDialog = true"' in content + assert 'prepend-icon="mdi-plus" :disabled="creatingRequestPending"' not in content + assert ':disabled="createFlowActive"' not in content + assert ':disabled="!hasProviderOptions || creatingRequestPending"' in content + + +def test_sandbox_management_page_tracks_multiple_pending_creates_instead_of_single_id(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "pendingCreateSandboxes" in content + assert "pendingCreateSandboxId" not in content + + +def test_sandbox_management_page_loads_provider_options_from_api(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "const providerOptions = [" not in content + assert "axios.get('/api/sandbox/providers'" in content + assert "providerOptions.value = providers.map(" in content + assert "defaultProviderId" in content + assert ( + "providerOptions.value.find((option) => option.value === defaultProviderId)" + in content + ) + + +def test_sandbox_management_page_does_not_show_legacy_provider_hint(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "tm('create.providerHint')" not in content + + +def test_sandbox_management_page_does_not_allow_configure_while_creating_or_restoring(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "case 'configure':" in content + assert ( + "return status !== 'creating' && status !== 'restoring' && status !== 'stopping'" + in content + ) + + +def test_sandbox_management_page_does_not_toast_running_after_create(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "toast(tm('messages.createReady'))" not in content + + +def test_sandbox_management_page_destroy_closes_dialog_before_backend_cleanup(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "const targetId = target.sandbox_id" in content + assert "destroyDialog.value = false" in content + assert "startDestroyPolling(targetId)" in content + assert "const res = await axios.delete(sandboxApiPath(targetId)" in content + assert "status: 'stopping'" not in content + assert "upsertSandboxRecord(stoppingRecord)" not in content + assert "destroying" not in content + assert ( + "const sandbox = res.data.data?.sandbox as SandboxRecord | undefined" in content + ) + assert "upsertSandboxRecord(sandbox)" in content + assert "void loadSandboxes({ silent: true })" in content + assert "destroyQueued" not in content + + +def test_sandbox_management_page_starts_destroy_polling_only_after_backend_accepts(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert ( + "if (res.data.status === 'ok') {\n startDestroyPolling(targetId)" + in content + ) + assert "stopDestroyPollingForSandbox(targetId)" in content + + +def test_sandbox_management_page_polls_until_destroyed_sandbox_disappears(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "function startDestroyPolling" in content + assert "pendingDestroySandboxes" in content + assert ( + "const record = result.records.find((item) => item.sandbox_id === trackedSandboxId)" + in content + ) + assert "if (!record) {" in content + assert "finishDestroyPolling(trackedSandboxId)" in content + assert "removeSandboxRecord(sandboxId)" in content + assert "startDestroyPolling(targetId)" in content + + +def test_sandbox_management_page_splits_running_into_busy_and_available_labels(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "return hasController(item) ? 'busy' : 'available'" in content + + +def test_sandbox_management_page_shows_controller_session_in_status_tooltip(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert 'v-tooltip v-if="item.controller_session_id" activator="parent"' in content + assert "{{ item.controller_session_id }}" in content + assert 'class="text-caption text-medium-emphasis mt-1"' not in content + + +def test_sandbox_management_page_confirms_dangerous_console_commands(): + page = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + console_utils = (ROOT / "dashboard/src/views/sandbox/consoleUtils.ts").read_text( + encoding="utf-8" + ) + + assert "import { isDangerousConsoleCommand } from './sandbox/consoleUtils'" in page + assert "isDangerousConsoleCommand(command) && !window.confirm" in page + assert "window.confirm(tm('console.dangerConfirm'" in page + assert "function isDangerousConsoleCommand" in console_utils + assert "rm\\s+(?:-" in console_utils + assert "(?:--\\s+)?" in console_utils + + +def test_sandbox_management_page_displays_console_cwd_relative_to_sandbox_home(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "if (cwd === '/workspace') return '~'" in content + assert ( + "if (cwd.startsWith('/workspace/')) return `~${cwd.slice('/workspace'.length)}`" + in content + ) + assert "cwd.match(/^\\/home\\/[^/]+(.*)$/)" in content + assert "return suffix ? `~${suffix}` : '~'" in content + + +def test_sandbox_management_page_strips_console_cwd_markers_from_output(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "function stripConsoleCwdMarkers" in content + assert "stripConsoleCwdMarkers(stdout)" in content + assert "stripConsoleCwdMarkers(visibleStdout)" in content + assert "!line.includes('__ASTRBOT_CWD__')" in content + + +def test_sandbox_management_page_records_console_api_errors_in_history(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert ( + "throw new Error(res.data.message || tm('messages.operationFailed'))" in content + ) + assert "entry.stderr = normalizeTerminalOutput(e?.message || String(e))" in content + + +def test_sandbox_management_page_console_cwd_prefix_does_not_hide_failed_cd(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "? `cd ${quoteForShell(cwd)}; `" in content + assert "? `cd ${quoteForShell(cwd)} && `" not in content + + +def test_sandbox_management_page_surfaces_unknown_status_key(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "tm('labels.unknownStatus', { status: key })" in content + + +def test_sandbox_management_page_localizes_max_sandbox_limit_errors(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + zh = (ROOT / "dashboard/src/i18n/locales/zh-CN/features/sandbox.json").read_text( + encoding="utf-8" + ) + en = (ROOT / "dashboard/src/i18n/locales/en-US/features/sandbox.json").read_text( + encoding="utf-8" + ) + ru = (ROOT / "dashboard/src/i18n/locales/ru-RU/features/sandbox.json").read_text( + encoding="utf-8" + ) + + assert "function localizedSandboxError" in content + assert "Sandbox limit reached" in content + assert "messages.maxSandboxesReached" in content + assert "maxSandboxesReached" in zh + assert "maxSandboxesReached" in en + assert "maxSandboxesReached" in ru + + +def test_sandbox_management_page_does_not_render_legacy_booter_type(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "booter_type" not in content + assert "showBooterTypeCaption" not in content + assert "provider-summary" not in content + + +def test_sandbox_management_page_has_dedicated_capabilities_column(): + content = (ROOT / "dashboard/src/views/SandboxManagementPage.vue").read_text( + encoding="utf-8" + ) + + assert "key: 'capabilities'" in content + assert 'v-for="capability in item.capabilities || []"' in content + assert "tm('headers.capabilities')" in content + + +def test_sandbox_i18n_uses_status_and_idle_labels(): + zh = (ROOT / "dashboard/src/i18n/locales/zh-CN/features/sandbox.json").read_text( + encoding="utf-8" + ) + en = (ROOT / "dashboard/src/i18n/locales/en-US/features/sandbox.json").read_text( + encoding="utf-8" + ) + ru = (ROOT / "dashboard/src/i18n/locales/ru-RU/features/sandbox.json").read_text( + encoding="utf-8" + ) + + assert '"status": "状态"' in zh + assert '"available": "空闲"' in zh + assert '"unknownStatus": "未知状态:{status}"' in zh + assert '"status": "Status"' in en + assert '"available": "Idle"' in en + assert '"unknownStatus": "Unknown status: {status}"' in en + assert '"dangerConfirm"' in zh + assert '"dangerConfirm"' in en + assert '"unknownStatus"' in ru + assert '"dangerConfirm"' in ru diff --git a/tests/test_sandbox_plugin_schema_contract.py b/tests/test_sandbox_plugin_schema_contract.py new file mode 100644 index 0000000000..4064a94692 --- /dev/null +++ b/tests/test_sandbox_plugin_schema_contract.py @@ -0,0 +1,263 @@ +import json +from pathlib import Path + +import pytest + +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.config.default import CONFIG_METADATA_3, DEFAULT_CONFIG + +ROOT = Path(__file__).resolve().parents[1] +SHIPYARD_COMPOSE = (ROOT / "compose-with-shipyard.yml").read_text(encoding="utf-8") +SANDBOX_TIMEOUT_KEYS = { + "sandbox_ttl", + "sandbox_idle_timeout", + "sandbox_lease_timeout", + "cua_ttl", + "cua_idle_timeout", + "shipyard_ttl", + "shipyard_idle_timeout", + "shipyard_neo_ttl", +} + + +def _require_plugin_files(*relative_paths: str) -> None: + missing = [path for path in relative_paths if not (ROOT / path).is_file()] + if missing: + pytest.skip(f"sandbox plugin repository files are not present: {missing}") + + +def _load_schema(plugin_name: str) -> dict: + schema_path = ROOT / "data/plugins" / plugin_name / "_conf_schema.json" + if not schema_path.is_file(): + pytest.skip(f"sandbox plugin schema is not present: {schema_path}") + return json.loads(schema_path.read_text(encoding="utf-8")) + + +def _assert_no_plugin_timeout_schema(schema: dict) -> None: + assert not (SANDBOX_TIMEOUT_KEYS & set(schema)) + + +def _read_plugin_file(plugin_name: str, filename: str) -> str: + path = ROOT / "data/plugins" / plugin_name / filename + if not path.is_file(): + pytest.skip(f"sandbox plugin file is not present: {path}") + return path.read_text(encoding="utf-8") + + +@pytest.mark.parametrize( + ("plugin_name", "description"), + [ + ("astrbot_sandbox_cua", "为 AstrBot 提供 CUA 沙盒运行时。"), + ("astrbot_sandbox_boxlite", "为 AstrBot 提供 Boxlite 本地沙盒运行时。"), + ("astrbot_sandbox_shipyard", "为 AstrBot 提供 Shipyard 沙盒运行时。"), + ("astrbot_sandbox_shipyard_neo", "为 AstrBot 提供 Shipyard Neo 沙盒运行时。"), + ], +) +def test_sandbox_plugin_metadata_is_localized(plugin_name: str, description: str): + metadata = _read_plugin_file(plugin_name, "metadata.yaml") + main_py = _read_plugin_file(plugin_name, "main.py") + + assert f"desc: {description}" in metadata + assert f'"{description}"' in main_py + assert "sandbox runtime provider for AstrBot" not in metadata + assert "sandbox runtime provider for AstrBot" not in main_py + + +def test_core_sandbox_timeout_defaults_live_in_bot_config(): + sandbox = DEFAULT_CONFIG["provider_settings"]["sandbox"] + + assert sandbox["sandbox_ttl"] == 3600 + assert sandbox["sandbox_idle_timeout"] == 1800 + assert sandbox["sandbox_lease_timeout"] == 600 + + +def test_dashboard_schema_exposes_sandbox_lease_timeout(): + schema = CONFIG_METADATA_3["ai_group"]["metadata"]["agent_computer_use"]["items"] + + lease_timeout = schema["provider_settings.sandbox.sandbox_lease_timeout"] + assert lease_timeout["type"] == "int" + assert "每次 Agent 成功访问沙盒" in lease_timeout["hint"] + assert "默认 600 秒" in lease_timeout["hint"] + assert "当前会话不再绑定该沙盒" in lease_timeout["hint"] + assert "其他会话可接管" in lease_timeout["hint"] + + +def test_cua_schema_defaults_match_documented_hints(): + schema = _load_schema("astrbot_sandbox_cua") + + _assert_no_plugin_timeout_schema(schema) + assert schema["cua_image"]["default"] == "linux" + assert schema["cua_os_type"]["default"] == "linux" + assert schema["cua_telemetry_enabled"]["default"] is False + assert schema["cua_local"]["default"] is True + assert schema["cua_api_key"]["default"] == "" + assert schema["cua_local"]["description"] == "CUA 本地沙箱" + assert "默认开启" in schema["cua_local"]["hint"] + + +def test_shipyard_schema_is_localized_and_has_defaults(): + schema = _load_schema("astrbot_sandbox_shipyard") + + _assert_no_plugin_timeout_schema(schema) + assert schema["shipyard_endpoint"]["description"] == "Shipyard API 地址" + assert schema["shipyard_endpoint"]["default"] == "http://127.0.0.1:8156" + assert schema["shipyard_auto_start"]["default"] is True + assert schema["shipyard_bay_image"]["default"] == "soulter/shipyard-bay:latest" + assert schema["shipyard_ship_image"]["default"] == "soulter/shipyard-ship:latest" + assert schema["shipyard_bay_image"]["default"] in SHIPYARD_COMPOSE + assert schema["shipyard_ship_image"]["default"] in SHIPYARD_COMPOSE + assert schema["shipyard_access_token"]["description"] == "Shipyard 访问令牌" + assert schema["shipyard_access_token"]["default"] == "" + assert schema["shipyard_max_sessions"]["description"] == "Shipyard 最大会话数" + assert schema["shipyard_max_sessions"]["default"] == 10 + + +def test_shipyard_neo_schema_is_localized_and_has_defaults(): + schema = _load_schema("astrbot_sandbox_shipyard_neo") + + _assert_no_plugin_timeout_schema(schema) + assert schema["shipyard_neo_endpoint"]["description"] == "Shipyard Neo API 地址" + assert schema["shipyard_neo_endpoint"]["default"] == "http://127.0.0.1:8114" + assert schema["shipyard_neo_access_token"]["description"] == "Shipyard Neo 访问令牌" + assert schema["shipyard_neo_access_token"]["default"] == "" + assert schema["shipyard_neo_profile"]["description"] == "Shipyard Neo Profile" + assert schema["shipyard_neo_profile"]["default"] == "python-default" + + +def test_shipyard_neo_plugin_does_not_duplicate_builtin_tool_registration(): + _require_plugin_files( + "data/plugins/astrbot_sandbox_shipyard_neo/main.py", + "data/plugins/astrbot_sandbox_shipyard_neo/tools/shipyard_neo/browser.py", + ) + content = (ROOT / "data/plugins/astrbot_sandbox_shipyard_neo/main.py").read_text( + encoding="utf-8" + ) + + assert "tools=[" not in content + + +def test_cua_provider_falls_back_to_local_mode_without_api_key(monkeypatch): + _require_plugin_files("data/plugins/astrbot_sandbox_cua/provider.py") + from data.plugins.astrbot_sandbox_cua.provider import CuaSandboxProvider + + monkeypatch.delenv("CUA_API_KEY", raising=False) + + class FakeContext: + def get_config(self, umo=None): + del umo + return {"provider_settings": {"sandbox": {}}} + + provider = CuaSandboxProvider() + provider.plugin_config = { + "cua_local": False, + "cua_api_key": "", + "cua_image": "linux", + "cua_os_type": "linux", + } + + config = provider.build_create_config(FakeContext(), "dashboard") + + assert config["local"] is True + assert config["api_key"] == "" + + +def test_cua_provider_connect_info_tracks_persistent_runtime_name(): + _require_plugin_files("data/plugins/astrbot_sandbox_cua/provider.py") + from data.plugins.astrbot_sandbox_cua.provider import CuaSandboxProvider + + provider = CuaSandboxProvider() + + info = provider.build_connect_info( + "Named", + { + "local": True, + "image": "linux", + "os_type": "linux", + "persistent_name": "cua-persistent-1", + }, + ) + + assert info["name"] == "Named" + assert info["persistent_name"] == "cua-persistent-1" + + +def test_existing_plugin_config_keeps_saved_values_when_schema_defaults_change( + tmp_path, +): + config_path = tmp_path / "plugin_config.json" + config_path.write_text('{"flag": false, "ttl": 0}', encoding="utf-8") + + config = AstrBotConfig( + config_path=str(config_path), + schema={ + "flag": {"type": "bool", "default": True}, + "ttl": {"type": "int", "default": 3600}, + }, + ) + + assert config["flag"] is False + assert config["ttl"] == 0 + + +def test_cua_adapter_uses_core_screenshot_operation(): + _require_plugin_files("data/plugins/astrbot_sandbox_cua/tools/cua.py") + from data.plugins.astrbot_sandbox_cua.provider import CuaSandboxProvider + from data.plugins.astrbot_sandbox_cua.tools import cua as cua_tools + + provider = CuaSandboxProvider() + + assert "screenshot" in provider.capabilities + assert "astrbot_cua_screenshot" not in provider.tool_names + assert provider.tool_names == { + "astrbot_cua_keyboard_type", + "astrbot_cua_mouse_click", + } + assert not hasattr(cua_tools, "CuaScreenshotTool") + + +def test_core_sandbox_screenshot_operation_can_return_image_to_llm(): + from astrbot.core.tools.computer_tools.sandbox import SandboxOperationTool + + tool = SandboxOperationTool() + properties = tool.parameters["properties"] + + assert "capture_screenshot" in properties["action"]["enum"] + assert properties["send_to_user"]["description"] + assert properties["return_image_to_llm"]["default"] is False + + +@pytest.mark.asyncio +async def test_shipyard_neo_execution_history_ignores_empty_optional_filters( + monkeypatch, +): + _require_plugin_files( + "data/plugins/astrbot_sandbox_shipyard_neo/tools/shipyard_neo/neo_skills.py" + ) + from data.plugins.astrbot_sandbox_shipyard_neo.tools.shipyard_neo import ( + neo_skills, + ) + + captured = {} + + class Sandbox: + async def get_execution_history(self, **kwargs): + captured.update(kwargs) + return [] + + async def fake_get_neo_context(context): + return object(), Sandbox() + + monkeypatch.setattr(neo_skills, "_get_neo_context", fake_get_neo_context) + monkeypatch.setattr(neo_skills, "check_admin_permission", lambda *args: None) + + result = await neo_skills.GetExecutionHistoryTool().call( + object(), + exec_type="", + tags="", + limit=5, + offset=0, + ) + + assert result == "[]" + assert captured["exec_type"] is None + assert captured["tags"] is None diff --git a/tests/test_shipyard_neo_booter.py b/tests/test_shipyard_neo_booter.py deleted file mode 100644 index b0d7ecc01d..0000000000 --- a/tests/test_shipyard_neo_booter.py +++ /dev/null @@ -1,344 +0,0 @@ -"""Tests for ShipyardNeoBooter — readiness gate, shutdown cleanup, and rebuild recovery.""" - -from __future__ import annotations - -import asyncio -from types import SimpleNamespace -from unittest.mock import AsyncMock, patch - -import pytest - - -# ═══════════════════════════════════════════════════════════════ -# _wait_until_ready -# ═══════════════════════════════════════════════════════════════ - - -def _make_sandbox_mock(statuses: list[str], *, delete_side_effect=None): - """Build a sandbox mock that returns *statuses* in order on refresh(). - - After the list is exhausted subsequent refresh() calls return the last status. - """ - call_count = 0 - - async def _refresh(): - nonlocal call_count - idx = min(call_count, len(statuses) - 1) - call_count += 1 - s = statuses[idx] - sandbox.status = SimpleNamespace(value=s) - - sandbox = SimpleNamespace( - id="sandbox-test-1", - profile="python-default", - status=SimpleNamespace(value=statuses[0]), - refresh=_refresh, - delete=AsyncMock(side_effect=delete_side_effect), - ) - return sandbox - - -class TestWaitUntilReady: - def _make_booter(self): - from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter - - return ShipyardNeoBooter( - endpoint_url="http://localhost:8114", - access_token="sk-bay-test", - ) - - @pytest.mark.asyncio - async def test_already_ready_returns_immediately(self): - """Sandbox is READY on first poll → instant return (warm hit).""" - booter = self._make_booter() - sandbox = _make_sandbox_mock(["ready"]) - - await booter._wait_until_ready(sandbox) - - sandbox.delete.assert_not_called() - - @pytest.mark.asyncio - async def test_starting_then_ready(self): - """Sandbox transitions STARTING → READY within timeout.""" - booter = self._make_booter() - sandbox = _make_sandbox_mock(["starting", "starting", "ready"]) - - await booter._wait_until_ready(sandbox) - - sandbox.delete.assert_not_called() - - @pytest.mark.asyncio - async def test_failed_deletes_and_raises(self): - """Sandbox reaches FAILED → delete called → RuntimeError raised.""" - booter = self._make_booter() - sandbox = _make_sandbox_mock(["starting", "failed"]) - - with pytest.raises(RuntimeError, match="terminal state"): - await booter._wait_until_ready(sandbox) - - sandbox.delete.assert_awaited_once() - - @pytest.mark.asyncio - async def test_expired_deletes_and_raises(self): - """Sandbox reaches EXPIRED → delete called → RuntimeError raised.""" - booter = self._make_booter() - sandbox = _make_sandbox_mock(["starting", "expired"]) - - with pytest.raises(RuntimeError, match="terminal state"): - await booter._wait_until_ready(sandbox) - - sandbox.delete.assert_awaited_once() - - @pytest.mark.asyncio - async def test_timeout_deletes_and_raises(self): - """Sandbox never reaches READY → delete called → TimeoutError raised.""" - booter = self._make_booter() - # Return 'idle' every time to simulate a stuck sandbox - sandbox = _make_sandbox_mock(["idle"]) - - # Override the deadline so we don't actually sleep 180s - original_time = asyncio.get_running_loop().time - - call_idx = 0 - - def _fake_time(): - nonlocal call_idx - # After one tick, jump past the deadline - if call_idx == 0: - call_idx += 1 - return original_time() - # Return a value beyond the 180s timeout - return original_time() + 200 - - with patch( - "astrbot.core.computer.booters.shipyard_neo.asyncio.get_running_loop" - ) as mock_loop: - mock_loop.return_value.time = _fake_time - - with pytest.raises(TimeoutError, match="did not become ready"): - await booter._wait_until_ready(sandbox) - - sandbox.delete.assert_awaited_once() - - @pytest.mark.asyncio - async def test_delete_failure_during_cleanup_is_safe(self): - """If sandbox.delete() itself throws, the original error is still raised.""" - booter = self._make_booter() - sandbox = _make_sandbox_mock( - ["failed"], - delete_side_effect=RuntimeError("Bay unreachable"), - ) - - with pytest.raises(RuntimeError, match="terminal state"): - await booter._wait_until_ready(sandbox) - - sandbox.delete.assert_awaited_once() - - -# ═══════════════════════════════════════════════════════════════ -# shutdown -# ═══════════════════════════════════════════════════════════════ - - -class TestShutdown: - def _make_booter(self): - from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter - - return ShipyardNeoBooter( - endpoint_url="http://localhost:8114", - access_token="sk-bay-test", - ) - - @pytest.mark.asyncio - async def test_delete_sandbox_true_calls_delete(self): - """delete_sandbox=True → sandbox.delete() called, then client closed.""" - booter = self._make_booter() - sandbox = SimpleNamespace( - id="sandbox-test-1", - delete=AsyncMock(), - ) - client = SimpleNamespace( - __aexit__=AsyncMock(), - ) - booter._sandbox = sandbox # type: ignore[assignment] - booter._client = client # type: ignore[assignment] - - await booter.shutdown(delete_sandbox=True) - - sandbox.delete.assert_awaited_once() - client.__aexit__.assert_awaited_once() - assert booter._client is None - assert booter._sandbox is None - - @pytest.mark.asyncio - async def test_delete_sandbox_false_does_not_call_delete(self): - """delete_sandbox=False (default) → sandbox.delete() NOT called.""" - booter = self._make_booter() - sandbox = SimpleNamespace( - id="sandbox-test-1", - delete=AsyncMock(), - ) - client = SimpleNamespace( - __aexit__=AsyncMock(), - ) - booter._sandbox = sandbox # type: ignore[assignment] - booter._client = client # type: ignore[assignment] - - await booter.shutdown() # default delete_sandbox=False - - sandbox.delete.assert_not_called() - client.__aexit__.assert_awaited_once() - assert booter._client is None - assert booter._sandbox is None - - @pytest.mark.asyncio - async def test_delete_failure_still_closes_client(self): - """If sandbox.delete() throws, HTTP client is still torn down.""" - booter = self._make_booter() - sandbox = SimpleNamespace( - id="sandbox-test-1", - delete=AsyncMock(side_effect=RuntimeError("Bay gone")), - ) - client = SimpleNamespace( - __aexit__=AsyncMock(), - ) - booter._sandbox = sandbox # type: ignore[assignment] - booter._client = client # type: ignore[assignment] - - # Should not raise — delete failure is logged but swallowed - await booter.shutdown(delete_sandbox=True) - - sandbox.delete.assert_awaited_once() - client.__aexit__.assert_awaited_once() - assert booter._client is None - assert booter._sandbox is None - - @pytest.mark.asyncio - async def test_no_client_is_noop(self): - """shutdown() on an uninitialised booter is a no-op.""" - booter = self._make_booter() - # _client is None by default - await booter.shutdown(delete_sandbox=True) - # No exception → ok - - -# ═══════════════════════════════════════════════════════════════ -# get_booter rebuild path -# ═══════════════════════════════════════════════════════════════ - - -class TestGetBooterRebuild: - """Verify that stale ShipyardNeoBooter instances are cleaned up on rebuild.""" - - def _make_fake_context(self, booter_type: str = "shipyard_neo"): - """Build a context-like object for get_booter().""" - _cfg = { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": booter_type, - "shipyard_neo_endpoint": "http://bay:8114", - "shipyard_neo_access_token": "sk-test", - "shipyard_neo_ttl": 3600, - "shipyard_neo_profile": "python-default", - }, - } - } - return SimpleNamespace( - get_config=lambda umo=None: _cfg, - ) - - @pytest.mark.asyncio - async def test_stale_neo_booter_calls_shutdown_with_delete(self, monkeypatch): - """A stale ShipyardNeoBooter gets shutdown(delete_sandbox=True) on eviction.""" - from astrbot.core.computer import computer_client - from astrbot.core.computer.booters.shipyard_neo import ShipyardNeoBooter - - ctx = self._make_fake_context() - - stale = ShipyardNeoBooter( - endpoint_url="http://bay:8114", access_token="sk-test" - ) - stale._sandbox = SimpleNamespace(id="stale-sandbox") # type: ignore[assignment] - stale._client = SimpleNamespace(__aexit__=AsyncMock()) # type: ignore[assignment] - stale._sandbox.refresh = AsyncMock(side_effect=RuntimeError("sandbox gone")) # type: ignore[union-attr] - # available() will return False because refresh() throws - stale.shutdown = AsyncMock() - - monkeypatch.setitem(computer_client.session_booter, "session-1", stale) - - from astrbot.core.computer.computer_client import get_booter - - # get_booter should evict stale and rebuild. - # We need to mock the entire rebuild path so it doesn't actually - # try to connect to Bay. - async def _fake_boot(_self, _sid): - _self._sandbox = SimpleNamespace( # type: ignore[assignment] - id="new-sandbox", - refresh=AsyncMock(), - status=SimpleNamespace(value="ready"), - capabilities=["python", "shell", "filesystem"], - ) - _self._client = SimpleNamespace() # type: ignore[assignment] - _self._shell = SimpleNamespace() # type: ignore[assignment] - _self._fs = SimpleNamespace() # type: ignore[assignment] - _self._python = SimpleNamespace() # type: ignore[assignment] - - with patch.object( - ShipyardNeoBooter, "boot", _fake_boot - ), patch( - "astrbot.core.computer.computer_client._sync_skills_to_sandbox", - AsyncMock(), - ): - await get_booter(ctx, "session-1") - - stale.shutdown.assert_awaited_once_with(delete_sandbox=True) - # Old entry should be replaced - new_booter = computer_client.session_booter.get("session-1") - assert new_booter is not None - assert new_booter is not stale - - @pytest.mark.asyncio - async def test_stale_non_neo_booter_calls_plain_shutdown(self, monkeypatch): - """Non-neo booter (e.g. shipyard) → plain shutdown() without delete_sandbox.""" - from astrbot.core.computer import computer_client - - ctx = self._make_fake_context(booter_type="shipyard") - - stale = SimpleNamespace(shutdown=AsyncMock()) - stale.available = AsyncMock(return_value=False) - - monkeypatch.setitem(computer_client.session_booter, "session-1", stale) - - # Patch ShipyardBooter entirely to skip its __init__ validation - class _FakeShipyardBooter: - def __init__(self, **kwargs): - pass - - async def boot(self, _sid): - self._sandbox = SimpleNamespace( # type: ignore[assignment] - refresh=AsyncMock(), - status=SimpleNamespace(value="ready"), - ) - self._shell = SimpleNamespace() # type: ignore[assignment] - self._fs = SimpleNamespace() # type: ignore[assignment] - self._python = SimpleNamespace() # type: ignore[assignment] - - async def shutdown(self, **kwargs): - pass - - with patch( - "astrbot.core.computer.booters.shipyard.ShipyardBooter", - _FakeShipyardBooter, - ), patch( - "astrbot.core.computer.computer_client._sync_skills_to_sandbox", - AsyncMock(), - ): - from astrbot.core.computer.computer_client import get_booter - - await get_booter(ctx, "session-1") - - stale.shutdown.assert_awaited_once() - # No delete_sandbox kwarg for non-neo booters - call_kwargs = stale.shutdown.call_args.kwargs - assert call_kwargs == {} diff --git a/tests/test_skill_manager_sandbox_cache.py b/tests/test_skill_manager_sandbox_cache.py index 35fb608118..f267e40142 100644 --- a/tests/test_skill_manager_sandbox_cache.py +++ b/tests/test_skill_manager_sandbox_cache.py @@ -1,5 +1,8 @@ from __future__ import annotations +import json +import threading +from typing import Any from pathlib import Path import pytest @@ -61,6 +64,125 @@ def test_list_skills_merges_local_and_sandbox_cache(monkeypatch, tmp_path: Path) assert by_name["python-sandbox"].path == "/app/skills/python-sandbox/SKILL.md" +def test_sandbox_cache_isolated_by_provider(monkeypatch, tmp_path: Path): + data_dir = tmp_path / "data" + temp_dir = tmp_path / "temp" + skills_root = tmp_path / "skills" + data_dir.mkdir(parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) + skills_root.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr( + "astrbot.core.skills.skill_manager.get_astrbot_data_path", + lambda: str(data_dir), + ) + monkeypatch.setattr( + "astrbot.core.skills.skill_manager.get_astrbot_temp_path", + lambda: str(temp_dir), + ) + + mgr = SkillManager(skills_root=str(skills_root)) + mgr.set_sandbox_skills_cache( + [ + { + "name": "python-sandbox", + "description": "ship built-in", + "path": "/home/ship_1e53ee8e/workspace/skills/python-sandbox/SKILL.md", + } + ], + provider_id="shipyard", + ) + + skills = mgr.list_skills(runtime="sandbox", provider_id="cua") + + assert all(skill.name != "python-sandbox" for skill in skills) + + +def test_sandbox_cache_updates_are_serialized_across_instances( + monkeypatch, + tmp_path: Path, +): + data_dir = tmp_path / "data" + temp_dir = tmp_path / "temp" + skills_root = tmp_path / "skills" + data_dir.mkdir(parents=True, exist_ok=True) + temp_dir.mkdir(parents=True, exist_ok=True) + skills_root.mkdir(parents=True, exist_ok=True) + + monkeypatch.setattr( + "astrbot.core.skills.skill_manager.get_astrbot_data_path", + lambda: str(data_dir), + ) + monkeypatch.setattr( + "astrbot.core.skills.skill_manager.get_astrbot_temp_path", + lambda: str(temp_dir), + ) + + mgr_a = SkillManager(skills_root=str(skills_root)) + mgr_b = SkillManager(skills_root=str(skills_root)) + + original_save = SkillManager._save_sandbox_skills_cache + save_started = threading.Event() + release_first_save = threading.Event() + save_calls = 0 + save_calls_lock = threading.Lock() + + def blocking_save(self, cache: dict[str, Any]) -> None: + nonlocal save_calls + with save_calls_lock: + save_calls += 1 + current_call = save_calls + if current_call == 1: + save_started.set() + assert release_first_save.wait(timeout=1) + original_save(self, cache) + + monkeypatch.setattr( + "astrbot.core.skills.skill_manager.SkillManager._save_sandbox_skills_cache", + blocking_save, + ) + + thread_a = threading.Thread( + target=mgr_a.set_sandbox_skills_cache, + kwargs={ + "skills": [ + { + "name": "ship-skill", + "description": "ship", + "path": "/home/ship/workspace/skills/ship-skill/SKILL.md", + } + ], + "provider_id": "shipyard", + }, + ) + thread_b = threading.Thread( + target=mgr_b.set_sandbox_skills_cache, + kwargs={ + "skills": [ + { + "name": "cua-skill", + "description": "cua", + "path": "/home/cua/workspace/skills/cua-skill/SKILL.md", + } + ], + "provider_id": "cua", + }, + ) + + thread_a.start() + assert save_started.wait(timeout=1) + thread_b.start() + release_first_save.set() + thread_a.join(timeout=1) + thread_b.join(timeout=1) + + assert not thread_a.is_alive() + assert not thread_b.is_alive() + + cache = json.loads((data_dir / "sandbox_skills_cache.json").read_text(encoding="utf-8")) + assert set(cache["providers"]) == {"shipyard", "cua"} + + def test_sandbox_cached_skill_respects_active_and_display_path( monkeypatch, tmp_path: Path, @@ -154,4 +276,3 @@ def test_sandbox_and_local_path_resolution_with_show_sandbox_path_false( assert local_skill_path.is_relative_to(skills_root) assert local_skill_path == skills_root / "custom-local" / "SKILL.md" assert by_name["python-sandbox"].path == "/app/skills/python-sandbox/SKILL.md" - diff --git a/tests/test_tool_loop_agent_runner.py b/tests/test_tool_loop_agent_runner.py index b4464680fb..5355ba4a30 100644 --- a/tests/test_tool_loop_agent_runner.py +++ b/tests/test_tool_loop_agent_runner.py @@ -960,6 +960,44 @@ async def test_same_tool_streak_resets_after_switching_tools( assert level_2_notice in content +@pytest.mark.asyncio +async def test_repeated_shell_tool_results_do_not_include_guidance( + runner, mock_tool_executor, mock_hooks +): + runner_cls = type(runner) + total_calls = runner_cls.REPEATED_TOOL_NOTICE_L3_THRESHOLD + provider = SequentialToolProvider(["astrbot_execute_shell"] * total_calls) + tool = FunctionTool( + name="astrbot_execute_shell", + description="Execute shell commands", + parameters={"type": "object", "properties": {"command": {"type": "string"}}}, + handler=AsyncMock(), + ) + request = ProviderRequest( + prompt="Run several shell commands", + func_tool=ToolSet(tools=[tool]), + contexts=[], + ) + + await runner.reset( + provider=provider, + request=request, + run_context=ContextWrapper(context=None), + tool_executor=mock_tool_executor, + agent_hooks=mock_hooks, + streaming=False, + ) + + async for _ in runner.step_until_done(total_calls + 1): + pass + + tool_messages = [ + m for m in runner.run_context.messages if getattr(m, "role", None) == "tool" + ] + assert len(tool_messages) == total_calls + assert all("SYSTEM NOTICE" not in str(message.content) for message in tool_messages) + + @pytest.mark.asyncio async def test_fallback_provider_used_when_primary_raises( runner, provider_request, mock_tool_executor, mock_hooks diff --git a/tests/unit/test_aiocqhttp_poke.py b/tests/unit/test_aiocqhttp_poke.py index dff886b91e..d8ec79fe78 100644 --- a/tests/unit/test_aiocqhttp_poke.py +++ b/tests/unit/test_aiocqhttp_poke.py @@ -49,3 +49,127 @@ async def test_aiocqhttp_send_message_dispatches_onebot_v11_poke_payload(): group_id=123456, message=[{"type": "poke", "data": {"type": "126", "id": "2003"}}], ) + + +@pytest.mark.asyncio +async def test_aiocqhttp_send_group_file_uses_upload_action(tmp_path): + bot = AsyncMock() + file_path = tmp_path / "report.md" + file_path.write_text("report", encoding="utf-8") + chain = MessageChain([Comp.File(name="report.md", file=str(file_path))]) + + await AiocqhttpMessageEvent.send_message( + bot=bot, + message_chain=chain, + event=None, + is_group=True, + session_id="123456", + ) + + bot.upload_group_file.assert_awaited_once_with( + group_id=123456, + file=str(file_path), + name="report.md", + ) + bot.send_group_msg.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_aiocqhttp_send_group_file_accepts_whitespace_session_id(tmp_path): + bot = AsyncMock() + file_path = tmp_path / "report.md" + file_path.write_text("report", encoding="utf-8") + chain = MessageChain([Comp.File(name="report.md", file=str(file_path))]) + + await AiocqhttpMessageEvent.send_message( + bot=bot, + message_chain=chain, + event=None, + is_group=True, + session_id=" 123456 ", + ) + + bot.upload_group_file.assert_awaited_once_with( + group_id=123456, + file=str(file_path), + name="report.md", + ) + + +@pytest.mark.asyncio +async def test_aiocqhttp_send_private_file_uses_upload_action(tmp_path): + bot = AsyncMock() + file_path = tmp_path / "report.md" + file_path.write_text("report", encoding="utf-8") + chain = MessageChain([Comp.File(name="report.md", file=str(file_path))]) + + await AiocqhttpMessageEvent.send_message( + bot=bot, + message_chain=chain, + event=None, + is_group=False, + session_id="654321", + ) + + bot.upload_private_file.assert_awaited_once_with( + user_id=654321, + file=str(file_path), + name="report.md", + ) + bot.send_private_msg.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_aiocqhttp_send_group_file_falls_back_to_file_segment(tmp_path): + bot = AsyncMock() + bot.upload_group_file.side_effect = RuntimeError("unsupported action") + file_path = tmp_path / "report.md" + file_path.write_text("report", encoding="utf-8") + chain = MessageChain([Comp.File(name="report.md", file=str(file_path))]) + + await AiocqhttpMessageEvent.send_message( + bot=bot, + message_chain=chain, + event=None, + is_group=True, + session_id="123456", + ) + + bot.upload_group_file.assert_awaited_once() + bot.send_group_msg.assert_awaited_once_with( + group_id=123456, + message=[ + { + "type": "file", + "data": {"name": "report.md", "file": file_path.as_uri()}, + } + ], + ) + + +@pytest.mark.asyncio +async def test_aiocqhttp_send_private_file_falls_back_to_file_segment(tmp_path): + bot = AsyncMock() + bot.upload_private_file.side_effect = RuntimeError("unsupported action") + file_path = tmp_path / "report.md" + file_path.write_text("report", encoding="utf-8") + chain = MessageChain([Comp.File(name="report.md", file=str(file_path))]) + + await AiocqhttpMessageEvent.send_message( + bot=bot, + message_chain=chain, + event=None, + is_group=False, + session_id="654321", + ) + + bot.upload_private_file.assert_awaited_once() + bot.send_private_msg.assert_awaited_once_with( + user_id=654321, + message=[ + { + "type": "file", + "data": {"name": "report.md", "file": file_path.as_uri()}, + } + ], + ) diff --git a/tests/unit/test_astr_agent_tool_exec.py b/tests/unit/test_astr_agent_tool_exec.py index 5fab9fe0a2..ccc6b11e57 100644 --- a/tests/unit/test_astr_agent_tool_exec.py +++ b/tests/unit/test_astr_agent_tool_exec.py @@ -6,17 +6,28 @@ from astrbot.core.agent.run_context import ContextWrapper from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor from astrbot.core.message.components import Image +from astrbot.core.provider.register import llm_tools class _DummyEvent: def __init__(self, message_components: list[object] | None = None) -> None: self.unified_msg_origin = "webchat:FriendMessage:webchat!user!session" self.message_obj = SimpleNamespace(message=message_components or []) + self.role = "member" def get_extra(self, _key: str): return None +class _NoopRunner: + async def step_until_done(self, _limit: int): + if False: + yield None + + def get_final_llm_resp(self): + return None + + class _DummyTool: def __init__(self) -> None: self.name = "transfer_to_subagent" @@ -224,6 +235,178 @@ async def _fake_tool_loop_agent(**kwargs): assert captured["image_urls"] == ["https://example.com/raw.png"] +@pytest.mark.asyncio +async def test_build_handoff_toolset_uses_registered_provider_tools_only( + monkeypatch: pytest.MonkeyPatch, +): + from astrbot.core.agent.tool import FunctionTool + from astrbot.core.computer import computer_client + + registered_provider_tool = FunctionTool( + name="provider_a_screenshot", + parameters={"type": "object", "properties": {}}, + description="Provider A screenshot", + ) + registered_provider_tool.sandbox_provider_id = "provider_a" + unregistered_provider_tool = FunctionTool( + name="provider_b_tool", + parameters={"type": "object", "properties": {}}, + description="Provider B tool", + ) + unregistered_provider_tool.sandbox_provider_id = "provider_b" + + previous_tools = list(llm_tools.func_list) + FunctionToolExecutor._runtime_computer_tools_cache.clear() + llm_tools.func_list = [registered_provider_tool] + + tool_mgr = SimpleNamespace( + get_builtin_tool=lambda cls, **kwargs: cls(**kwargs), + get_func=lambda name: { + "provider_a_screenshot": registered_provider_tool, + "provider_b_tool": unregistered_provider_tool, + }.get(name), + ) + context = SimpleNamespace( + get_config=lambda **_kwargs: { + "provider_settings": { + "computer_use_runtime": "sandbox", + "sandbox": {"booter": "provider_a"}, + } + }, + get_llm_tool_manager=lambda: tool_mgr, + ) + event = _DummyEvent([]) + run_context = ContextWrapper(context=SimpleNamespace(event=event, context=context)) + monkeypatch.setattr( + computer_client, + "list_sandbox_providers", + lambda: [ + {"provider_id": "provider_b", "tool_names": ["provider_b_tool"]}, + ], + ) + + try: + toolset = FunctionToolExecutor._build_handoff_toolset(run_context, None) + assert toolset is not None + assert "astrbot_sandbox_query" in toolset.names() + assert "provider_a_screenshot" in toolset.names() + assert "provider_b_tool" not in toolset.names() + finally: + llm_tools.func_list = previous_tools + FunctionToolExecutor._runtime_computer_tools_cache.clear() + + +@pytest.mark.asyncio +async def test_build_handoff_toolset_uses_scoped_tool_manager_for_all_tools(): + from astrbot.core.agent.tool import FunctionTool + + allowed_tool = FunctionTool( + name="allowed_tool", + parameters={"type": "object", "properties": {}}, + description="allowed", + ) + disallowed_tool = FunctionTool( + name="disallowed_tool", + parameters={"type": "object", "properties": {}}, + description="disallowed", + ) + previous_tools = list(llm_tools.func_list) + FunctionToolExecutor._runtime_computer_tools_cache.clear() + llm_tools.func_list = [allowed_tool, disallowed_tool] + tool_mgr = SimpleNamespace(func_list=[allowed_tool], get_func=lambda name: None) + context = SimpleNamespace( + get_config=lambda **_kwargs: { + "provider_settings": {"computer_use_runtime": "none"} + }, + get_llm_tool_manager=lambda: tool_mgr, + ) + run_context = ContextWrapper( + context=SimpleNamespace(event=_DummyEvent([]), context=context) + ) + + try: + toolset = FunctionToolExecutor._build_handoff_toolset(run_context, None) + assert toolset is not None + assert "allowed_tool" in toolset.names() + assert "disallowed_tool" not in toolset.names() + finally: + llm_tools.func_list = previous_tools + FunctionToolExecutor._runtime_computer_tools_cache.clear() + + +def test_clear_runtime_computer_tools_cache_provider_id_clears_all_entries(): + FunctionToolExecutor._runtime_computer_tools_cache = { + (1, "sandbox", ""): {}, + (2, "local", "other"): {}, + } + + FunctionToolExecutor.clear_runtime_computer_tools_cache("generic") + + assert FunctionToolExecutor._runtime_computer_tools_cache == {} + + +@pytest.mark.asyncio +async def test_background_wake_preserves_computer_runtime_config( + monkeypatch: pytest.MonkeyPatch, +): + captured: dict = {} + + async def _fake_get_session_conv(**_kwargs): + return SimpleNamespace(history="[]") + + async def _fake_build_main_agent(**kwargs): + captured["config"] = kwargs["config"] + return SimpleNamespace(agent_runner=_NoopRunner()) + + async def _fake_persist_agent_history(*_args, **_kwargs): + return None + + provider_settings = { + "computer_use_runtime": "sandbox", + "sandbox": {"booter": "generic", "max_sandboxes": 2}, + "stream": True, + } + context = SimpleNamespace( + conversation_manager=SimpleNamespace(), + get_config=lambda **_kwargs: {"provider_settings": provider_settings}, + get_llm_tool_manager=lambda: SimpleNamespace( + get_builtin_tool=lambda _cls: SimpleNamespace( + name="send_message_to_user", active=True + ) + ), + ) + run_context = ContextWrapper( + context=SimpleNamespace(event=_DummyEvent([]), context=context), + tool_call_timeout=17, + ) + + monkeypatch.setattr( + "astrbot.core.astr_main_agent._get_session_conv", _fake_get_session_conv + ) + monkeypatch.setattr( + "astrbot.core.astr_main_agent.build_main_agent", _fake_build_main_agent + ) + monkeypatch.setattr( + "astrbot.core.astr_agent_tool_exec.persist_agent_history", + _fake_persist_agent_history, + ) + + await FunctionToolExecutor._wake_main_agent_for_background_result( + run_context, + task_id="task-id", + tool_name="tool", + result_text="result", + tool_args={}, + note="note", + summary_name="tool", + ) + + config = captured["config"] + assert config.computer_use_runtime == "sandbox" + assert config.sandbox_cfg == {"booter": "generic", "max_sandboxes": 2} + assert config.provider_settings == provider_settings + + @pytest.mark.asyncio async def test_collect_handoff_image_urls_keeps_extensionless_existing_event_file( monkeypatch: pytest.MonkeyPatch, diff --git a/tests/unit/test_astr_main_agent.py b/tests/unit/test_astr_main_agent.py index 31c80e09ea..70ec70883e 100644 --- a/tests/unit/test_astr_main_agent.py +++ b/tests/unit/test_astr_main_agent.py @@ -1,7 +1,6 @@ """Tests for astr_main_agent module.""" import datetime -import os from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -817,6 +816,219 @@ async def test_ensure_tools_from_persona(self, mock_event, mock_context): assert req.func_tool is not None + def test_apply_sandbox_tools_respects_explicit_empty_persona_tool_list(self): + module = ama + req = ProviderRequest(func_tool=ToolSet()) + req._persona_tools_configured = True + req._persona_allowed_tool_names = set() + config = module.MainAgentBuildConfig( + tool_call_timeout=60, + computer_use_runtime="sandbox", + provider_settings={"computer_use_runtime": "sandbox"}, + ) + + module._apply_sandbox_tools(config, req) + + assert req.func_tool is not None + assert req.func_tool.empty() + + @pytest.mark.asyncio + async def test_ensure_persona_keeps_all_sandbox_provider_tools_in_sandbox_runtime( + self, mock_event, mock_context + ): + module = ama + provider_tool = FunctionTool( + name="provider_a_screenshot", + parameters={"type": "object", "properties": {}}, + description="provider screenshot", + ) + provider_tool.sandbox_provider_id = "provider_a" + generic_tool = FunctionTool( + name="regular_tool", + parameters={"type": "object", "properties": {}}, + description="regular", + ) + tmgr = mock_context.get_llm_tool_manager.return_value + tmgr.get_full_tool_set.return_value = ToolSet([provider_tool, generic_tool]) + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=(None, None, None, False) + ) + + req = ProviderRequest() + req.conversation = MagicMock(persona_id=None) + + await module._ensure_persona_and_skills( + req, + { + "computer_use_runtime": "sandbox", + "sandbox": {"booter": "provider_b"}, + }, + mock_context, + mock_event, + ) + + assert req.func_tool is not None + assert "regular_tool" in req.func_tool.names() + assert "provider_a_screenshot" in req.func_tool.names() + + @pytest.mark.asyncio + async def test_ensure_persona_keeps_current_sandbox_provider_tools( + self, mock_event, mock_context + ): + module = ama + provider_tool = FunctionTool( + name="provider_a_screenshot", + parameters={"type": "object", "properties": {}}, + description="provider screenshot", + ) + provider_tool.sandbox_provider_id = "provider_a" + tmgr = mock_context.get_llm_tool_manager.return_value + tmgr.get_full_tool_set.return_value = ToolSet([provider_tool]) + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=(None, None, None, False) + ) + + req = ProviderRequest() + req.conversation = MagicMock(persona_id=None) + + with patch( + "astrbot.core.computer.computer_client.get_current_sandbox_provider_id", + return_value="provider_a", + ): + await module._ensure_persona_and_skills( + req, + { + "computer_use_runtime": "sandbox", + "sandbox": {"booter": "provider_a"}, + }, + mock_context, + mock_event, + ) + + assert req.func_tool is not None + assert "provider_a_screenshot" in req.func_tool.names() + + @pytest.mark.asyncio + async def test_ensure_persona_keeps_all_provider_tools_with_bound_sandbox( + self, monkeypatch, mock_event, mock_context + ): + module = ama + provider_a_tool = FunctionTool( + name="provider_a_screenshot", + parameters={"type": "object", "properties": {}}, + description="Provider A screenshot", + ) + provider_a_tool.sandbox_provider_id = "provider_a" + provider_b_tool = FunctionTool( + name="provider_b_tool", + parameters={"type": "object", "properties": {}}, + description="Provider B tool", + ) + provider_b_tool.sandbox_provider_id = "provider_b" + tmgr = mock_context.get_llm_tool_manager.return_value + tmgr.get_full_tool_set.return_value = ToolSet( + [provider_a_tool, provider_b_tool] + ) + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=(None, None, None, False) + ) + + from astrbot.core.computer.computer_client import sandbox_manager + + monkeypatch.setattr( + sandbox_manager.registry, + "get_current_sandbox_id", + lambda session_id: ( + "provider-b-1" if session_id == mock_event.unified_msg_origin else None + ), + ) + monkeypatch.setattr( + sandbox_manager.registry, + "get_sandbox", + lambda sandbox_id: ( + {"sandbox_id": sandbox_id, "provider": "provider_b"} + if sandbox_id == "provider-b-1" + else None + ), + ) + + req = ProviderRequest() + req.conversation = MagicMock(persona_id=None) + + await module._ensure_persona_and_skills( + req, + { + "computer_use_runtime": "sandbox", + "sandbox": {"booter": "provider_a"}, + }, + mock_context, + mock_event, + ) + + assert req.func_tool is not None + assert "provider_b_tool" in req.func_tool.names() + assert "provider_a_screenshot" in req.func_tool.names() + + @pytest.mark.asyncio + async def test_handoff_all_tools_filters_other_sandbox_provider_tools( + self, mock_event, mock_context + ): + module = ama + provider_tool = FunctionTool( + name="provider_a_screenshot", + parameters={"type": "object", "properties": {}}, + description="provider screenshot", + ) + provider_tool.sandbox_provider_id = "provider_a" + generic_tool = FunctionTool( + name="regular_tool", + parameters={"type": "object", "properties": {}}, + description="regular", + ) + tmgr = mock_context.get_llm_tool_manager.return_value + tmgr.func_list = [provider_tool, generic_tool] + tmgr.get_full_tool_set.return_value = ToolSet([provider_tool, generic_tool]) + mock_context.persona_manager.resolve_selected_persona = AsyncMock( + return_value=(None, None, None, False) + ) + mock_context.persona_manager.get_persona_v3_by_id = MagicMock( + return_value={"name": "default", "tools": None} + ) + handoff = MagicMock() + handoff.name = "transfer_to_planner" + mock_context.subagent_orchestrator = MagicMock(handoffs=[handoff]) + mock_context.get_config.return_value = { + "subagent_orchestrator": { + "main_enable": True, + "remove_main_duplicate_tools": True, + "agents": [ + { + "name": "planner", + "enabled": True, + "persona_id": "default", + } + ], + } + } + + req = ProviderRequest() + req.conversation = MagicMock(persona_id=None) + + await module._ensure_persona_and_skills( + req, + { + "computer_use_runtime": "sandbox", + "sandbox": {"booter": "provider_b"}, + }, + mock_context, + mock_event, + ) + + assert req.func_tool is not None + assert "transfer_to_planner" in req.func_tool.names() + assert "regular_tool" not in req.func_tool.names() + assert "provider_a_screenshot" not in req.func_tool.names() + @pytest.mark.asyncio async def test_subagent_dedupe_uses_default_persona_tools( self, mock_event, mock_context @@ -1025,6 +1237,89 @@ def test_plugin_tool_fix_preserves_tools_without_plugin_origin(self, mock_event) assert "transfer_to_demo_agent" in req.func_tool.names() + def test_plugin_tool_fix_keeps_provider_specific_tools_in_sandbox_runtime( + self, mock_event + ): + module = ama + provider_tool = FunctionTool( + name="astrbot_cua_mouse_click", + description="provider-specific", + parameters={"type": "object", "properties": {}}, + handler_module_path=None, + active=True, + ) + provider_tool.sandbox_provider_id = "cua" + generic_tool = FunctionTool( + name="astrbot_sandbox_query", + description="generic", + parameters={"type": "object", "properties": {}}, + handler_module_path=None, + active=True, + ) + + tool_set = ToolSet() + tool_set.add_tool(provider_tool) + tool_set.add_tool(generic_tool) + req = ProviderRequest(func_tool=tool_set, session_id="session-a") + mock_event.plugins_name = ["other_plugin"] + mock_event.unified_msg_origin = "session-a" + + with ( + patch("astrbot.core.astr_main_agent.star_map"), + patch( + "astrbot.core.computer.computer_client.get_current_sandbox_provider_id", + return_value=None, + ), + ): + module._plugin_tool_fix( + mock_event, + req, + {"computer_use_runtime": "sandbox"}, + ) + + assert "astrbot_sandbox_query" in req.func_tool.names() + assert "astrbot_cua_mouse_click" in req.func_tool.names() + + def test_plugin_tool_fix_hides_provider_specific_tools_outside_sandbox_runtime( + self, mock_event + ): + module = ama + cua_tool = FunctionTool( + name="astrbot_cua_keyboard_type", + description="cua", + parameters={"type": "object", "properties": {}}, + handler_module_path=None, + active=True, + ) + cua_tool.sandbox_provider_id = "cua" + neo_tool = FunctionTool( + name="astrbot_execute_browser", + description="neo", + parameters={"type": "object", "properties": {}}, + handler_module_path=None, + active=True, + ) + neo_tool.sandbox_provider_id = "shipyard_neo" + + tool_set = ToolSet() + tool_set.add_tool(cua_tool) + tool_set.add_tool(neo_tool) + req = ProviderRequest(func_tool=tool_set, session_id="session-a") + mock_event.plugins_name = ["other_plugin"] + mock_event.unified_msg_origin = "session-a" + + with ( + patch("astrbot.core.astr_main_agent.star_map"), + ): + module._plugin_tool_fix( + mock_event, + req, + {"computer_use_runtime": "local"}, + ) + + assert "astrbot_cua_keyboard_type" not in req.func_tool.names() + assert "astrbot_execute_browser" not in req.func_tool.names() + class TestBuildMainAgent: """Tests for build_main_agent function.""" @@ -1924,7 +2219,7 @@ def test_apply_sandbox_tools_creates_toolset_if_none(self, mock_context): ) req = ProviderRequest(prompt="Test", func_tool=None) - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req) assert req.func_tool is not None assert isinstance(req.func_tool, ToolSet) @@ -1939,13 +2234,16 @@ def test_apply_sandbox_tools_adds_required_tools(self, mock_context): ) req = ProviderRequest(prompt="Test", func_tool=None) - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req) tool_names = req.func_tool.names() assert "astrbot_execute_shell" in tool_names assert "astrbot_execute_ipython" in tool_names assert "astrbot_upload_file" in tool_names assert "astrbot_download_file" in tool_names + assert "astrbot_sandbox_query" in tool_names + assert "astrbot_sandbox_lifecycle" in tool_names + assert "astrbot_sandbox_operation" in tool_names def test_apply_sandbox_tools_adds_sandbox_prompt(self, mock_context): """Test that sandbox mode prompt is added to system_prompt.""" @@ -1957,149 +2255,189 @@ def test_apply_sandbox_tools_adds_sandbox_prompt(self, mock_context): ) req = ProviderRequest(prompt="Test", system_prompt="Original prompt") - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req) assert "sandboxed environment" in req.system_prompt + assert "send screenshots to the user to show progress" not in req.system_prompt - def test_apply_sandbox_tools_with_cua_adds_gui_guidance(self, mock_context): - """Test that CUA sandbox guidance nudges reliable GUI workflows.""" + def test_apply_sandbox_tools_preserves_existing_toolset(self, mock_context): + """Test that existing tools are preserved when adding sandbox tools.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", - sandbox_cfg={"booter": "cua"}, + sandbox_cfg={}, ) - req = ProviderRequest(prompt="Test", system_prompt="Original prompt") + existing_toolset = ToolSet() + existing_tool = MagicMock() + existing_tool.name = "existing_tool" + existing_toolset.add_tool(existing_tool) + req = ProviderRequest(prompt="Test", func_tool=existing_toolset) - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req) - assert req.func_tool is not None - tool_names = req.func_tool.names() - assert "astrbot_cua_screenshot" in tool_names - assert "astrbot_cua_mouse_click" in tool_names - assert "astrbot_cua_keyboard_type" in tool_names - assert "astrbot_cua_key_press" not in tool_names - - assert "Firefox" in req.system_prompt - assert "background=true" in req.system_prompt - assert 'firefox "https://example.com"' in req.system_prompt - assert "astrbot_cua_screenshot" in req.system_prompt - assert "astrbot_cua_key_press" not in req.system_prompt - assert "return_image_to_llm" in req.system_prompt - assert "astrbot_execute_shell" in req.system_prompt - assert "\\n" in req.system_prompt - assert "send_to_user=true" in req.system_prompt - assert "focused and empty or safe to append" in req.system_prompt + assert "existing_tool" in req.func_tool.names() + assert "astrbot_execute_shell" in req.func_tool.names() - def test_apply_sandbox_tools_with_shipyard_booter(self, monkeypatch, mock_context): - """Test sandbox tools with shipyard booter configuration.""" + def test_apply_sandbox_tools_appends_to_existing_system_prompt(self, mock_context): + """Test that sandbox prompt is appended to existing system prompt.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", - sandbox_cfg={ - "booter": "shipyard", - "shipyard_endpoint": "https://shipyard.example.com", - "shipyard_access_token": "test-token", - }, + sandbox_cfg={}, ) - req = ProviderRequest(prompt="Test", func_tool=None) - - monkeypatch.delenv("SHIPYARD_ENDPOINT", raising=False) - monkeypatch.delenv("SHIPYARD_ACCESS_TOKEN", raising=False) + req = ProviderRequest(prompt="Test", system_prompt="Base prompt") - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req) - assert os.environ.get("SHIPYARD_ENDPOINT") == "https://shipyard.example.com" - assert os.environ.get("SHIPYARD_ACCESS_TOKEN") == "test-token" + assert req.system_prompt.startswith("Base prompt") + assert "sandboxed environment" in req.system_prompt - def test_apply_sandbox_tools_shipyard_missing_endpoint(self, mock_context): - """Test that shipyard config is skipped when endpoint is missing.""" + def test_apply_sandbox_tools_with_none_system_prompt(self, mock_context): + """Test that sandbox prompt is applied when system_prompt is None.""" module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", - sandbox_cfg={ - "booter": "shipyard", - "shipyard_endpoint": "", - "shipyard_access_token": "test-token", - }, + sandbox_cfg={}, ) - req = ProviderRequest(prompt="Test", func_tool=None) + req = ProviderRequest(prompt="Test", system_prompt=None) - with patch("astrbot.core.astr_main_agent.logger") as mock_logger: - module._apply_sandbox_tools(config, req, "session-123") + module._apply_sandbox_tools(config, req) - mock_logger.error.assert_called_once() + assert isinstance(req.system_prompt, str) + assert "sandboxed environment" in req.system_prompt + assert "check the current sandbox first" in req.system_prompt + assert "listing sandbox providers" in req.system_prompt + assert "inspect each sandbox's access field" in req.system_prompt + assert "Never treat status=running alone as reusable" in req.system_prompt + assert "access.status=occupied" in req.system_prompt + assert "fresh or separate environment" in req.system_prompt + assert "send screenshots to the user to show progress" not in req.system_prompt assert ( - "Shipyard sandbox configuration is incomplete" - in mock_logger.error.call_args[0][0] + "astrbot_sandbox_operation with action=capture_screenshot" + in req.system_prompt ) + assert "send_to_user=true" in req.system_prompt + assert "send_message_to_user separately" in req.system_prompt + assert "automatically renews this session's lease" in req.system_prompt + assert "now plus the configured sandbox lease timeout" in req.system_prompt + assert "lease_expires_in_seconds" in req.system_prompt + assert "no longer has a current sandbox" in req.system_prompt - def test_apply_sandbox_tools_shipyard_missing_access_token(self, mock_context): - """Test that shipyard config is skipped when access token is missing.""" + def test_apply_sandbox_tools_does_not_scan_provider_tool_names(self, mock_context): module = ama config = module.MainAgentBuildConfig( tool_call_timeout=60, computer_use_runtime="sandbox", - sandbox_cfg={ - "booter": "shipyard", - "shipyard_endpoint": "https://shipyard.example.com", - "shipyard_access_token": "", - }, + sandbox_cfg={"booter": "provider_a"}, ) - req = ProviderRequest(prompt="Test", func_tool=None) + req = ProviderRequest(prompt="Test", system_prompt="Base prompt") + req.session_id = "session-a" - with patch("astrbot.core.astr_main_agent.logger") as mock_logger: - module._apply_sandbox_tools(config, req, "session-123") + with ( + patch( + "astrbot.core.computer.computer_client.list_sandbox_providers", + return_value=[ + { + "provider_id": "provider_a", + "tool_names": ["provider_a_screenshot"], + } + ], + ), + patch( + "astrbot.core.provider.register.llm_tools.get_func", + side_effect=AssertionError("provider tools must be registered once"), + ), + ): + module._apply_sandbox_tools(config, req) - mock_logger.error.assert_called_once() + assert "provider_a_screenshot" not in req.func_tool.names() + assert "send screenshots to the user to show progress" not in req.system_prompt - def test_apply_sandbox_tools_preserves_existing_toolset(self, mock_context): - """Test that existing tools are preserved when adding sandbox tools.""" + def test_registered_provider_tools_are_included_by_persona_toolset( + self, mock_context + ): module = ama - config = module.MainAgentBuildConfig( - tool_call_timeout=60, - computer_use_runtime="sandbox", - sandbox_cfg={}, + cfg = {"computer_use_runtime": "sandbox"} + toolset = ToolSet() + + provider_tool = FunctionTool( + name="provider_a_screenshot", + parameters={"type": "object", "properties": {}}, + description="Provider A screenshot", ) - existing_toolset = ToolSet() - existing_tool = MagicMock() - existing_tool.name = "existing_tool" - existing_toolset.add_tool(existing_tool) - req = ProviderRequest(prompt="Test", func_tool=existing_toolset) + provider_tool.sandbox_provider_id = "provider_a" + toolset.add_tool(provider_tool) - module._apply_sandbox_tools(config, req, "session-123") + filtered = module._filter_tools_for_current_config(toolset, cfg, "session-a") - assert "existing_tool" in req.func_tool.names() - assert "astrbot_execute_shell" in req.func_tool.names() + assert "provider_a_screenshot" in filtered.names() - def test_apply_sandbox_tools_appends_to_existing_system_prompt(self, mock_context): - """Test that sandbox prompt is appended to existing system prompt.""" + def test_filter_tools_for_current_config_applies_builtin_runtime_rules(self): module = ama - config = module.MainAgentBuildConfig( - tool_call_timeout=60, - computer_use_runtime="sandbox", - sandbox_cfg={}, + toolset = ToolSet() + toolset.add_tool(module.LocalPythonTool()) + toolset.add_tool(module.SandboxLifecycleTool()) + + sandbox_filtered = module._filter_tools_for_current_config( + toolset, {"computer_use_runtime": "sandbox"}, "session-a" + ) + none_filtered = module._filter_tools_for_current_config( + toolset, {"computer_use_runtime": "none"}, "session-a" ) - req = ProviderRequest(prompt="Test", system_prompt="Base prompt") - module._apply_sandbox_tools(config, req, "session-123") + assert "astrbot_execute_python" not in sandbox_filtered.names() + assert "astrbot_sandbox_lifecycle" in sandbox_filtered.names() + assert "astrbot_execute_python" not in none_filtered.names() + assert "astrbot_sandbox_lifecycle" not in none_filtered.names() - assert req.system_prompt.startswith("Base prompt") - assert "sandboxed environment" in req.system_prompt + def test_handoff_runtime_computer_tools_include_sandbox_lifecycle_tools(self): + tool_mgr = MagicMock() - def test_apply_sandbox_tools_with_none_system_prompt(self, mock_context): - """Test that sandbox prompt is applied when system_prompt is None.""" - module = ama - config = module.MainAgentBuildConfig( - tool_call_timeout=60, - computer_use_runtime="sandbox", - sandbox_cfg={}, - ) - req = ProviderRequest(prompt="Test", system_prompt=None) + class NamedTool: + def __init__(self, name): + self.name = name + self.active = True - module._apply_sandbox_tools(config, req, "session-123") + def get_builtin_tool(cls, **kwargs): + del kwargs + return cls() - assert isinstance(req.system_prompt, str) - assert "sandboxed environment" in req.system_prompt + tool_mgr.get_builtin_tool.side_effect = get_builtin_tool + tool_mgr.get_func.side_effect = lambda name: NamedTool(name) + + with patch( + "astrbot.core.computer.computer_client.list_sandbox_providers", + return_value=[], + ): + tools = ama.FunctionToolExecutor._get_runtime_computer_tools( + "sandbox", tool_mgr, "provider_a" + ) + + assert "astrbot_sandbox_query" in tools + assert "astrbot_sandbox_lifecycle" in tools + assert "astrbot_sandbox_operation" in tools + + def test_runtime_computer_tools_are_cached_per_runtime_and_booter(self): + tool_mgr = MagicMock() + + def get_builtin_tool(cls, **kwargs): + del kwargs + return cls() + + tool_mgr.get_builtin_tool.side_effect = get_builtin_tool + tool_mgr.get_func.side_effect = lambda name: None + ama.FunctionToolExecutor._runtime_computer_tools_cache.clear() + + first = ama.FunctionToolExecutor._get_runtime_computer_tools( + "sandbox", tool_mgr, "provider_a" + ) + call_count = tool_mgr.get_builtin_tool.call_count + second = ama.FunctionToolExecutor._get_runtime_computer_tools( + "sandbox", tool_mgr, "provider_a" + ) + + assert first is second + assert tool_mgr.get_builtin_tool.call_count == call_count diff --git a/tests/unit/test_builtin_tool_registry.py b/tests/unit/test_builtin_tool_registry.py new file mode 100644 index 0000000000..95d2169f9a --- /dev/null +++ b/tests/unit/test_builtin_tool_registry.py @@ -0,0 +1,44 @@ +from dataclasses import dataclass, field + +from astrbot.api import FunctionTool +from astrbot.core.tools.registry import ( + builtin_tool, + register_builtin_tools_by_module_prefix, + get_builtin_tool_class, + get_builtin_tool_name, + unregister_builtin_tools_by_module_prefix, +) + + +def _register_test_tool(): + @builtin_tool + @dataclass + class ExampleTool(FunctionTool): + name: str = "astrbot_test_reload_tool" + description: str = "Test tool for registry reload behavior." + parameters: dict = field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + + return ExampleTool + + +def test_builtin_tool_registry_can_unregister_and_reregister_same_name(): + tool_cls_1 = _register_test_tool() + registered = register_builtin_tools_by_module_prefix(__name__) + + assert registered == ["astrbot_test_reload_tool"] + assert get_builtin_tool_class("astrbot_test_reload_tool") is tool_cls_1 + assert get_builtin_tool_name(tool_cls_1) == "astrbot_test_reload_tool" + + removed = unregister_builtin_tools_by_module_prefix(__name__) + + assert removed == ["astrbot_test_reload_tool"] + assert get_builtin_tool_class("astrbot_test_reload_tool") is None + assert get_builtin_tool_name(tool_cls_1) is None + + tool_cls_2 = _register_test_tool() + assert tool_cls_2 is not tool_cls_1 + assert get_builtin_tool_class("astrbot_test_reload_tool") is tool_cls_2 + + unregister_builtin_tools_by_module_prefix(__name__) diff --git a/tests/unit/test_computer.py b/tests/unit/test_computer.py index 71b31a301a..2213ab6a6e 100644 --- a/tests/unit/test_computer.py +++ b/tests/unit/test_computer.py @@ -1,10 +1,9 @@ """Tests for astrbot/core/computer module. -This module tests the ComputerClient, Booter implementations (local, shipyard, boxlite), +This module tests the ComputerClient, local booter implementation, filesystem operations, Python execution, shell execution, and security restrictions. """ -import sys from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -372,118 +371,6 @@ def test_base_class_is_protocol(self): assert hasattr(booter, "available") -class TestShipyardBooter: - """Tests for ShipyardBooter.""" - - @pytest.mark.asyncio - async def test_shipyard_booter_init(self): - """Test ShipyardBooter initialization.""" - with patch("astrbot.core.computer.booters.shipyard.ShipyardClient"): - from astrbot.core.computer.booters.shipyard import ShipyardBooter - - booter = ShipyardBooter( - endpoint_url="http://localhost:8080", - access_token="test_token", - ttl=3600, - session_num=10, - ) - assert booter._ttl == 3600 - assert booter._session_num == 10 - - @pytest.mark.asyncio - async def test_shipyard_booter_boot(self): - """Test ShipyardBooter boot method.""" - mock_ship = MagicMock() - mock_ship.id = "test-ship-id" - mock_ship.fs = MagicMock() - mock_ship.python = MagicMock() - mock_ship.shell = MagicMock() - - mock_client = MagicMock() - mock_client.create_ship = AsyncMock(return_value=mock_ship) - - with patch( - "astrbot.core.computer.booters.shipyard.ShipyardClient", - return_value=mock_client, - ): - from astrbot.core.computer.booters.shipyard import ShipyardBooter - - booter = ShipyardBooter( - endpoint_url="http://localhost:8080", - access_token="test_token", - ) - await booter.boot("test-session") - assert booter._ship == mock_ship - - @pytest.mark.asyncio - async def test_shipyard_available_healthy(self): - """Test ShipyardBooter available when healthy.""" - mock_ship = MagicMock() - mock_ship.id = "test-ship-id" - - mock_client = MagicMock() - mock_client.get_ship = AsyncMock(return_value={"status": 1}) - - with patch( - "astrbot.core.computer.booters.shipyard.ShipyardClient", - return_value=mock_client, - ): - from astrbot.core.computer.booters.shipyard import ShipyardBooter - - booter = ShipyardBooter( - endpoint_url="http://localhost:8080", - access_token="test_token", - ) - booter._ship = mock_ship - booter._sandbox_client = mock_client - - result = await booter.available() - assert result is True - - @pytest.mark.asyncio - async def test_shipyard_available_unhealthy(self): - """Test ShipyardBooter available when unhealthy.""" - mock_ship = MagicMock() - mock_ship.id = "test-ship-id" - - mock_client = MagicMock() - mock_client.get_ship = AsyncMock(return_value={"status": 0}) - - with patch( - "astrbot.core.computer.booters.shipyard.ShipyardClient", - return_value=mock_client, - ): - from astrbot.core.computer.booters.shipyard import ShipyardBooter - - booter = ShipyardBooter( - endpoint_url="http://localhost:8080", - access_token="test_token", - ) - booter._ship = mock_ship - booter._sandbox_client = mock_client - - result = await booter.available() - assert result is False - - -class TestBoxliteBooter: - """Tests for BoxliteBooter.""" - - @pytest.mark.asyncio - async def test_boxlite_booter_init(self): - """Test BoxliteBooter can be instantiated via __new__.""" - # Need to mock boxlite module before importing - mock_boxlite = MagicMock() - mock_boxlite.SimpleBox = MagicMock() - - with patch.dict(sys.modules, {"boxlite": mock_boxlite}): - from astrbot.core.computer.booters.boxlite import BoxliteBooter - - # Just verify class exists and can be instantiated (boot is async) - booter = BoxliteBooter.__new__(BoxliteBooter) - assert booter is not None - - class TestComputerClient: """Tests for computer_client module functions.""" @@ -503,67 +390,11 @@ def test_get_local_booter(self): # Reset for other tests computer_client.local_booter = None - @pytest.mark.asyncio - async def test_get_booter_shipyard(self): - """Test get_booter with shipyard type.""" - from astrbot.core.computer import computer_client - from astrbot.core.computer.booters.shipyard import ShipyardBooter - - # Clear session booter - computer_client.session_booter.clear() - - mock_context = MagicMock() - mock_config = MagicMock() - mock_config.get = lambda key, default=None: { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": "shipyard", - "shipyard_endpoint": "http://localhost:8080", - "shipyard_access_token": "test_token", - "shipyard_ttl": 3600, - "shipyard_max_sessions": 10, - }, - } - }.get(key, default) - mock_context.get_config = MagicMock(return_value=mock_config) - - # Mock the ShipyardBooter - mock_ship = MagicMock() - mock_ship.id = "test-ship-id" - mock_ship.fs = MagicMock() - mock_ship.python = MagicMock() - mock_ship.shell = MagicMock() - - mock_booter = MagicMock() - mock_booter.boot = AsyncMock() - mock_booter.available = AsyncMock(return_value=True) - mock_booter.shell = MagicMock() - mock_booter.upload_file = AsyncMock(return_value={"success": True}) - - with ( - patch.object(ShipyardBooter, "boot", new=AsyncMock()), - patch( - "astrbot.core.computer.computer_client._sync_skills_to_sandbox", - AsyncMock(), - ), - ): - # Directly set the booter in the session - computer_client.session_booter["test-session-id"] = mock_booter - - booter = await computer_client.get_booter(mock_context, "test-session-id") - assert booter is mock_booter - - # Cleanup - computer_client.session_booter.clear() - @pytest.mark.asyncio async def test_get_booter_unknown_type(self): - """Test get_booter with unknown booter type raises ValueError.""" + """Test get_booter with unknown sandbox provider raises ValueError.""" from astrbot.core.computer import computer_client - computer_client.session_booter.clear() - mock_context = MagicMock() mock_config = MagicMock() mock_config.get = lambda key, default=None: { @@ -578,60 +409,13 @@ async def test_get_booter_unknown_type(self): with pytest.raises(ValueError) as exc_info: await computer_client.get_booter(mock_context, "test-session-id") - assert "Unknown booter type" in str(exc_info.value) - - @pytest.mark.asyncio - async def test_get_booter_reuses_existing(self): - """Test get_booter reuses existing booter for same session.""" - from astrbot.core.computer import computer_client - from astrbot.core.computer.booters.shipyard import ShipyardBooter - - computer_client.session_booter.clear() - - mock_context = MagicMock() - mock_config = MagicMock() - mock_config.get = lambda key, default=None: { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": "shipyard", - "shipyard_endpoint": "http://localhost:8080", - "shipyard_access_token": "test_token", - }, - } - }.get(key, default) - mock_context.get_config = MagicMock(return_value=mock_config) - - mock_booter = MagicMock() - mock_booter.boot = AsyncMock() - mock_booter.available = AsyncMock(return_value=True) - mock_booter.shell = MagicMock() - mock_booter.upload_file = AsyncMock(return_value={"success": True}) - - with ( - patch.object(ShipyardBooter, "boot", new=AsyncMock()), - patch( - "astrbot.core.computer.computer_client._sync_skills_to_sandbox", - AsyncMock(), - ), - ): - # Pre-set the booter - computer_client.session_booter["test-session"] = mock_booter - - booter1 = await computer_client.get_booter(mock_context, "test-session") - booter2 = await computer_client.get_booter(mock_context, "test-session") - assert booter1 is booter2 - - # Cleanup - computer_client.session_booter.clear() + assert "Unknown sandbox provider" in str(exc_info.value) + assert "Install and enable a sandbox provider plugin" in str(exc_info.value) @pytest.mark.asyncio - async def test_get_booter_rebuild_unavailable(self): - """Test get_booter rebuilds when existing booter is unavailable.""" + async def test_get_booter_empty_sandbox_provider_hint(self): + """Test get_booter with empty sandbox booter gives actionable error.""" from astrbot.core.computer import computer_client - from astrbot.core.computer.booters.shipyard import ShipyardBooter - - computer_client.session_booter.clear() mock_context = MagicMock() mock_config = MagicMock() @@ -639,54 +423,22 @@ async def test_get_booter_rebuild_unavailable(self): "provider_settings": { "computer_use_runtime": "sandbox", "sandbox": { - "booter": "shipyard", - "shipyard_endpoint": "http://localhost:8080", - "shipyard_access_token": "test_token", + "booter": "", }, } }.get(key, default) mock_context.get_config = MagicMock(return_value=mock_config) - mock_unavailable_booter = MagicMock(spec=ShipyardBooter) - mock_unavailable_booter.available = AsyncMock(return_value=False) - - mock_new_booter = MagicMock(spec=ShipyardBooter) - mock_new_booter.boot = AsyncMock() - - with ( - patch( - "astrbot.core.computer.booters.shipyard.ShipyardBooter", - return_value=mock_new_booter, - ) as mock_booter_cls, - patch( - "astrbot.core.computer.computer_client._sync_skills_to_sandbox", - AsyncMock(), - ), - ): - session_id = "test-session-rebuild" - # Pre-set the unavailable booter - computer_client.session_booter[session_id] = mock_unavailable_booter - - # get_booter should detect the booter is unavailable and create a new one - new_booter_instance = await computer_client.get_booter( - mock_context, session_id - ) - - # Assert that a new booter was created and is now in the session - mock_booter_cls.assert_called_once() - mock_new_booter.boot.assert_awaited_once() - assert new_booter_instance is mock_new_booter - assert computer_client.session_booter[session_id] is mock_new_booter - - # Cleanup - computer_client.session_booter.clear() + with pytest.raises(ValueError) as exc_info: + await computer_client.get_booter(mock_context, "test-session-id") + assert "Sandbox provider is not configured" in str(exc_info.value) class TestSyncSkillsToSandbox: """Tests for _sync_skills_to_sandbox function.""" @pytest.mark.asyncio - async def test_sync_skills_no_skills_dir(self): + async def test_sync_skills_no_skills_dir(self, tmp_path): """Test sync does nothing when skills directory doesn't exist.""" from astrbot.core.computer import computer_client @@ -694,21 +446,15 @@ async def test_sync_skills_no_skills_dir(self): mock_booter.shell.exec = AsyncMock() mock_booter.upload_file = AsyncMock(return_value={"success": True}) - with ( - patch( - "astrbot.core.computer.computer_client.get_astrbot_skills_path", - return_value="/nonexistent/path", - ), - patch( - "astrbot.core.computer.computer_client.os.path.isdir", - return_value=False, - ), + with patch( + "astrbot.core.computer.computer_client.get_astrbot_skills_path", + return_value=str(tmp_path / "missing"), ): await computer_client._sync_skills_to_sandbox(mock_booter) mock_booter.upload_file.assert_not_called() @pytest.mark.asyncio - async def test_sync_skills_empty_dir(self): + async def test_sync_skills_empty_dir(self, tmp_path): """Test sync does nothing when skills directory is empty.""" from astrbot.core.computer import computer_client @@ -716,25 +462,18 @@ async def test_sync_skills_empty_dir(self): mock_booter.shell.exec = AsyncMock() mock_booter.upload_file = AsyncMock(return_value={"success": True}) - with ( - patch( - "astrbot.core.computer.computer_client.get_astrbot_skills_path", - return_value="/tmp/empty", - ), - patch( - "astrbot.core.computer.computer_client.os.path.isdir", - return_value=True, - ), - patch( - "astrbot.core.computer.computer_client.Path.iterdir", - return_value=iter([]), - ), + empty_skills = tmp_path / "empty" + empty_skills.mkdir() + + with patch( + "astrbot.core.computer.computer_client.get_astrbot_skills_path", + return_value=str(empty_skills), ): await computer_client._sync_skills_to_sandbox(mock_booter) mock_booter.upload_file.assert_not_called() @pytest.mark.asyncio - async def test_sync_skills_success(self): + async def test_sync_skills_success(self, tmp_path): """Test successful skills sync.""" from astrbot.core.computer import computer_client @@ -742,36 +481,21 @@ async def test_sync_skills_success(self): mock_booter.shell.exec = AsyncMock(return_value={"exit_code": 0}) mock_booter.upload_file = AsyncMock(return_value={"success": True}) - mock_skill_file = MagicMock() - mock_skill_file.name = "skill.py" - mock_skill_file.__str__ = lambda: "/tmp/skills/skill.py" + skills_dir = tmp_path / "skills" + demo_skill = skills_dir / "demo_skill" + demo_skill.mkdir(parents=True) + (demo_skill / "SKILL.md").write_text("# Demo", encoding="utf-8") + temp_dir = tmp_path / "temp" + temp_dir.mkdir() with ( patch( "astrbot.core.computer.computer_client.get_astrbot_skills_path", - return_value="/tmp/skills", - ), - patch( - "astrbot.core.computer.computer_client.os.path.isdir", - return_value=True, - ), - patch( - "astrbot.core.computer.computer_client.Path.iterdir", - return_value=iter([mock_skill_file]), + return_value=str(skills_dir), ), patch( "astrbot.core.computer.computer_client.get_astrbot_temp_path", - return_value="/tmp", - ), - patch( - "astrbot.core.computer.computer_client.shutil.make_archive", - ), - patch( - "astrbot.core.computer.computer_client.os.path.exists", - return_value=True, - ), - patch( - "astrbot.core.computer.computer_client.os.remove", + return_value=str(temp_dir), ), ): # Should not raise diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 7afe82ebed..6625f18d97 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -87,6 +87,29 @@ def test_init_loads_existing_file(self, temp_config_path, minimal_default_config assert config.platform_settings["unique_session"] is True assert config.provider_settings["enable"] is False + def test_init_migrates_legacy_computer_runtime(self, temp_config_path): + """Legacy runtime values are migrated to sandbox + booter.""" + existing_config = { + "config_version": 2, + "provider_settings": { + "computer_use_runtime": "cua", + "sandbox": {}, + }, + } + with open(temp_config_path, "w", encoding="utf-8-sig") as f: + json.dump(existing_config, f) + + config = AstrBotConfig(config_path=temp_config_path) + + assert config.provider_settings["computer_use_runtime"] == "sandbox" + assert config.provider_settings["sandbox"]["booter"] == "cua" + + with open(temp_config_path, encoding="utf-8-sig") as f: + saved_config = json.load(f) + + assert saved_config["provider_settings"]["computer_use_runtime"] == "sandbox" + assert saved_config["provider_settings"]["sandbox"]["booter"] == "cua" + def test_first_deploy_flag(self, temp_config_path, minimal_default_config): """Test first_deploy flag is set for new config.""" config = AstrBotConfig( @@ -440,6 +463,31 @@ def test_remove_unknown_config_keys(self, temp_config_path, minimal_default_conf assert "unknown_key" not in config + def test_remove_unknown_sandbox_provider_config_keys(self, temp_config_path): + default_config = { + "provider_settings": { + "sandbox": {"booter": ""}, + }, + } + existing_config = { + "provider_settings": { + "sandbox": { + "booter": "generic_provider", + "cua_image": "linux", + "shipyard_neo_endpoint": "http://localhost:8000", + "shipyard_access_token": "token", + }, + }, + } + with open(temp_config_path, "w", encoding="utf-8-sig") as f: + json.dump(existing_config, f) + + config = AstrBotConfig( + config_path=temp_config_path, default_config=default_config + ) + + assert config["provider_settings"]["sandbox"] == {"booter": "generic_provider"} + def test_nested_config_validation(self, temp_config_path): """Test validation of nested config structures.""" default_config = { diff --git a/tests/unit/test_core_lifecycle.py b/tests/unit/test_core_lifecycle.py index 1fc8035e48..65a4b3fa48 100644 --- a/tests/unit/test_core_lifecycle.py +++ b/tests/unit/test_core_lifecycle.py @@ -259,6 +259,141 @@ async def test_subagent_orchestrator_error_is_logged( ) +class TestAstrBotCoreLifecycleSandboxRestore: + @pytest.mark.asyncio + async def test_initialize_does_not_restore_persistent_sandboxes( + self, mock_log_broker, mock_db, mock_astrbot_config + ): + lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db) + + mock_db.initialize = AsyncMock() + mock_html_renderer = MagicMock() + mock_html_renderer.initialize = AsyncMock() + mock_umop_config_router = MagicMock() + mock_umop_config_router.initialize = AsyncMock() + mock_astrbot_config_mgr = MagicMock() + mock_astrbot_config_mgr.default_conf = {} + mock_astrbot_config_mgr.confs = {} + mock_persona_mgr = MagicMock(initialize=AsyncMock()) + mock_provider_manager = MagicMock(initialize=AsyncMock()) + mock_platform_manager = MagicMock(initialize=AsyncMock()) + mock_conversation_manager = MagicMock() + mock_platform_message_history_manager = MagicMock() + mock_kb_manager = MagicMock(initialize=AsyncMock()) + mock_cron_manager = MagicMock() + mock_star_context = MagicMock(_register_tasks=[]) + mock_plugin_manager = MagicMock(reload=AsyncMock()) + mock_pipeline_scheduler = MagicMock(initialize=AsyncMock()) + mock_astrbot_updator = MagicMock() + mock_event_bus = MagicMock() + with ( + patch("astrbot.core.core_lifecycle.astrbot_config", mock_astrbot_config), + patch("astrbot.core.core_lifecycle.html_renderer", mock_html_renderer), + patch( + "astrbot.core.core_lifecycle.UmopConfigRouter", + return_value=mock_umop_config_router, + ), + patch( + "astrbot.core.core_lifecycle.AstrBotConfigManager", + return_value=mock_astrbot_config_mgr, + ), + patch( + "astrbot.core.core_lifecycle.PersonaManager", + return_value=mock_persona_mgr, + ), + patch( + "astrbot.core.core_lifecycle.ProviderManager", + return_value=mock_provider_manager, + ), + patch( + "astrbot.core.core_lifecycle.PlatformManager", + return_value=mock_platform_manager, + ), + patch( + "astrbot.core.core_lifecycle.ConversationManager", + return_value=mock_conversation_manager, + ), + patch( + "astrbot.core.core_lifecycle.PlatformMessageHistoryManager", + return_value=mock_platform_message_history_manager, + ), + patch( + "astrbot.core.core_lifecycle.KnowledgeBaseManager", + return_value=mock_kb_manager, + ), + patch( + "astrbot.core.core_lifecycle.CronJobManager", + return_value=mock_cron_manager, + ), + patch( + "astrbot.core.core_lifecycle.Context", return_value=mock_star_context + ), + patch( + "astrbot.core.core_lifecycle.PluginManager", + return_value=mock_plugin_manager, + ), + patch( + "astrbot.core.core_lifecycle.PipelineScheduler", + return_value=mock_pipeline_scheduler, + ), + patch( + "astrbot.core.core_lifecycle.AstrBotUpdator", + return_value=mock_astrbot_updator, + ), + patch("astrbot.core.core_lifecycle.EventBus", return_value=mock_event_bus), + patch("astrbot.core.core_lifecycle.migra", new_callable=AsyncMock), + patch( + "astrbot.core.core_lifecycle.update_llm_metadata", + new_callable=AsyncMock, + ), + patch( + "astrbot.core.core_lifecycle.computer_client.sandbox_manager.reconcile_on_startup", + new_callable=AsyncMock, + ), + patch( + "astrbot.core.core_lifecycle.computer_client.sandbox_manager.restore_persistent_sandboxes", + new_callable=AsyncMock, + ) as restore_persistent, + ): + await asyncio.wait_for(lifecycle.initialize(), timeout=1) + + restore_persistent.assert_not_called() + + @pytest.mark.asyncio + async def test_start_schedules_persistent_sandbox_restore_without_waiting( + self, mock_log_broker, mock_db + ): + lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db) + restore_started = asyncio.Event() + restore_finished = asyncio.Event() + lifecycle.star_context = MagicMock(_register_tasks=[]) + lifecycle.curr_tasks = [] + lifecycle.event_bus = MagicMock() + + async def dispatch(): + return None + + lifecycle.event_bus.dispatch = dispatch + lifecycle.cron_manager = None + lifecycle.temp_dir_cleaner = None + + async def restore_persistent(_context, **_kwargs): + restore_started.set() + await restore_finished.wait() + return 1, 0 + + with patch( + "astrbot.core.core_lifecycle.computer_client.sandbox_manager.restore_persistent_sandboxes", + restore_persistent, + ): + await asyncio.wait_for(lifecycle.start(), timeout=1) + assert lifecycle._persistent_restore_task is not None + await asyncio.wait_for(restore_started.wait(), timeout=1) + assert not restore_finished.is_set() + restore_finished.set() + await asyncio.wait_for(lifecycle._persistent_restore_task, timeout=1) + + class TestAstrBotCoreLifecycleDefaultChatProviderWarning: """Tests for startup warning when default chat provider is unset.""" @@ -886,18 +1021,58 @@ async def test_restart_terminates_managers_and_starts_thread( lifecycle.astrbot_updator = MagicMock() - with patch("astrbot.core.core_lifecycle.threading.Thread") as mock_thread: + with ( + patch( + "astrbot.core.core_lifecycle.computer_client.cleanup_managed_sandboxes", + new_callable=AsyncMock, + ) as mock_cleanup, + patch("astrbot.core.core_lifecycle.threading.Thread") as mock_thread, + ): await lifecycle.restart() # Verify managers were terminated lifecycle.provider_manager.terminate.assert_awaited_once() lifecycle.platform_manager.terminate.assert_awaited_once() lifecycle.kb_manager.terminate.assert_awaited_once() + mock_cleanup.assert_awaited_once() # Verify thread was started mock_thread.assert_called_once() mock_thread.return_value.start.assert_called_once() + @pytest.mark.asyncio + async def test_restart_cleans_managed_sandboxes_before_reboot( + self, mock_log_broker, mock_db + ): + """Test that restart cleans managed sandboxes before rebooting.""" + lifecycle = AstrBotCoreLifecycle(mock_log_broker, mock_db) + + lifecycle.provider_manager = MagicMock() + lifecycle.provider_manager.terminate = AsyncMock() + + lifecycle.platform_manager = MagicMock() + lifecycle.platform_manager.terminate = AsyncMock() + + lifecycle.kb_manager = MagicMock() + lifecycle.kb_manager.terminate = AsyncMock() + + lifecycle.dashboard_shutdown_event = asyncio.Event() + + with ( + patch( + "astrbot.core.core_lifecycle.computer_client.cleanup_managed_sandboxes", + new_callable=AsyncMock, + ) as mock_cleanup, + patch("astrbot.core.core_lifecycle.threading.Thread") as mock_thread, + ): + lifecycle.astrbot_updator = MagicMock() + + await lifecycle.restart() + + mock_cleanup.assert_awaited_once() + mock_thread.assert_called_once() + mock_thread.return_value.start.assert_called_once() + class TestAstrBotCoreLifecycleLoadPipelineScheduler: """Tests for AstrBotCoreLifecycle.load_pipeline_scheduler method.""" diff --git a/tests/unit/test_cua_computer_use.py b/tests/unit/test_cua_computer_use.py deleted file mode 100644 index 3a092146c0..0000000000 --- a/tests/unit/test_cua_computer_use.py +++ /dev/null @@ -1,1812 +0,0 @@ -import asyncio -import base64 -import json -import shlex -from pathlib import Path - -import mcp -import pytest - -from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor -from astrbot.core.config.default import CONFIG_METADATA_3 -from astrbot.core.provider.func_tool_manager import FunctionToolManager - - -class FakeContext: - def __init__(self, config: dict): - self._config = config - - def get_config(self, umo: str | None = None): - return self._config - - -def _clear_cua_session_state(computer_client, session_id: str) -> None: - computer_client.session_booter.pop(session_id, None) - state = getattr(computer_client, "cua_idle_state", {}).pop(session_id, None) - if state is not None and not state.task.done(): - state.task.cancel() - - -class FakeShell: - def __init__(self): - self.commands = [] - - async def run(self, command: str, **kwargs): - self.commands.append((command, kwargs)) - return {"stdout": "ok", "stderr": "", "exit_code": 0} - - -class ProcessShapeShell: - async def run(self, command: str, **kwargs): - return {"output": "shape-ok", "returncode": 0} - - -class CommandResultShapeShell: - def __init__(self, stdout: str = "shape-ok", stderr: str = "", returncode: int = 0): - self.commands = [] - self.stdout = stdout - self.stderr = stderr - self.returncode = returncode - - @property - def success(self): - return self.returncode == 0 - - async def run(self, command: str, **kwargs): - self.commands.append((command, kwargs)) - return self - - -class FakePython: - async def run(self, code: str, **kwargs): - return {"output": "42", "error": ""} - - -class FakeFilesystem: - def __init__(self): - self.files = {} - - async def write_file(self, path: str, content: str): - self.files[path] = content - - async def read_file(self, path: str): - return self.files[path] - - async def delete(self, path: str): - self.files.pop(path, None) - - async def list_dir(self, path: str): - return [path] - - -class FakeFiles: - def __init__(self): - self.uploads = [] - self.byte_writes = [] - self.text_writes = [] - self.text_reads = {} - - async def upload(self, local_path: str, remote_path: str): - self.uploads.append((local_path, remote_path)) - - async def write_bytes(self, path: str, content: bytes): - self.byte_writes.append((path, content)) - - async def write_text(self, path: str, content: str): - self.text_writes.append((path, content)) - self.text_reads[path] = content - - async def read_text(self, path: str): - return self.text_reads[path] - - -class FakeMouse: - def __init__(self): - self.clicks = [] - - async def click(self, x: int, y: int, button: str = "left"): - self.clicks.append((x, y, button)) - return {"success": True} - - -class FakeKeyboard: - def __init__(self): - self.typed = [] - self.pressed = [] - - async def type(self, text: str): - self.typed.append(text) - return {"success": True} - - async def press(self, key: str): - self.pressed.append(key) - return {"success": True} - - -class FakeSandbox: - def __init__(self): - self.shell = FakeShell() - self.python = FakePython() - self.filesystem = FakeFilesystem() - self.mouse = FakeMouse() - self.keyboard = FakeKeyboard() - - async def screenshot(self): - return b"fake-png" - - -class SyncShell: - def __init__(self, stdout: str = "ok"): - self.commands = [] - self.stdout = stdout - - def run(self, command: str, **kwargs): - self.commands.append((command, kwargs)) - return {"stdout": self.stdout, "stderr": "", "exit_code": 0} - - -class FailingShell: - def __init__(self): - self.commands = [] - - async def run(self, command: str, **kwargs): - self.commands.append((command, kwargs)) - return { - "stdout": "", - "stderr": "python3: command not found", - "exit_code": 127, - "success": False, - } - - -class SandboxWithoutFilesystem: - def __init__(self): - self.shell = FakeShell() - self.python = FakePython() - - -class SyncPython: - def run(self, code: str, **kwargs): - return {"output": "sync", "error": ""} - - -def _agent_computer_use_items(): - return CONFIG_METADATA_3["ai_group"]["metadata"]["agent_computer_use"]["items"] - - -@pytest.mark.asyncio -async def test_get_booter_creates_cua_booter(monkeypatch): - from astrbot.core.computer import computer_client - - created = [] - - class FakeCuaBooter: - def __init__( - self, - image: str, - os_type: str, - ttl: int, - telemetry_enabled: bool, - local: bool, - api_key: str, - ): - created.append((image, os_type, ttl, telemetry_enabled, local, api_key)) - - async def boot(self, session_id: str): - self.session_id = session_id - - async def available(self): - return True - - monkeypatch.setattr( - computer_client, "_sync_skills_to_sandbox", lambda booter: asyncio.sleep(0) - ) - monkeypatch.setitem(computer_client.session_booter, "cua-test", None) - computer_client.session_booter.pop("cua-test", None) - monkeypatch.setattr( - "astrbot.core.computer.booters.cua.CuaBooter", - FakeCuaBooter, - raising=False, - ) - - ctx = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": "cua", - "cua_image": "linux", - "cua_os_type": "linux", - "cua_ttl": 120, - "cua_telemetry_enabled": False, - "cua_local": True, - "cua_api_key": "", - }, - } - } - ) - - booter = await computer_client.get_booter(ctx, "cua-test") - - assert isinstance(booter, FakeCuaBooter) - assert created == [("linux", "linux", 120, False, True, "")] - - -def test_cua_ephemeral_kwargs_include_local_when_supported(): - from astrbot.core.computer.booters.cua import CuaBooter - - def ephemeral(image, ttl=None, telemetry_enabled=None, local=None): - return image, ttl, telemetry_enabled, local - - kwargs = CuaBooter( - ttl=120, telemetry_enabled=False, local=True - )._build_ephemeral_kwargs(ephemeral) - - assert kwargs == {"ttl": 120, "telemetry_enabled": False, "local": True} - - -def test_cua_ephemeral_kwargs_include_api_key_for_cloud_when_supported(): - from astrbot.core.computer.booters.cua import CuaBooter - - def ephemeral(image, local=None, api_key=None): - return image, local, api_key - - kwargs = CuaBooter(local=False, api_key="sk-test")._build_ephemeral_kwargs( - ephemeral - ) - - assert kwargs == {"local": False, "api_key": "sk-test"} - - -def test_cua_default_config_matches_booter_defaults(): - from astrbot.core.computer.booters.cua import CUA_DEFAULT_CONFIG, CuaBooter - from astrbot.core.config.default import DEFAULT_CONFIG - - booter = CuaBooter() - sandbox_defaults = DEFAULT_CONFIG["provider_settings"]["sandbox"] - - assert booter.image == CUA_DEFAULT_CONFIG["image"] - assert booter.os_type == CUA_DEFAULT_CONFIG["os_type"] - assert booter.ttl == CUA_DEFAULT_CONFIG["ttl"] - assert booter.telemetry_enabled == CUA_DEFAULT_CONFIG["telemetry_enabled"] - assert booter.local == CUA_DEFAULT_CONFIG["local"] - assert booter.api_key == CUA_DEFAULT_CONFIG["api_key"] - assert sandbox_defaults["cua_image"] == CUA_DEFAULT_CONFIG["image"] - assert sandbox_defaults["cua_os_type"] == CUA_DEFAULT_CONFIG["os_type"] - assert "cua_ttl" not in sandbox_defaults - assert sandbox_defaults["cua_idle_timeout"] == CUA_DEFAULT_CONFIG["idle_timeout"] - assert ( - sandbox_defaults["cua_telemetry_enabled"] - == CUA_DEFAULT_CONFIG["telemetry_enabled"] - ) - assert sandbox_defaults["cua_local"] == CUA_DEFAULT_CONFIG["local"] - assert sandbox_defaults["cua_api_key"] == CUA_DEFAULT_CONFIG["api_key"] - - -@pytest.mark.asyncio -async def test_cua_config_log_does_not_include_api_key(monkeypatch): - from astrbot.core.computer import computer_client - - log_messages = [] - - class FakeCuaBooter: - def __init__(self, **kwargs): - self.kwargs = kwargs - - async def boot(self, session_id: str): - self.session_id = session_id - - async def available(self): - return True - - monkeypatch.setattr( - computer_client, "_sync_skills_to_sandbox", lambda booter: asyncio.sleep(0) - ) - monkeypatch.setitem(computer_client.session_booter, "cua-log-test", None) - computer_client.session_booter.pop("cua-log-test", None) - monkeypatch.setattr( - "astrbot.core.computer.booters.cua.CuaBooter", - FakeCuaBooter, - raising=False, - ) - monkeypatch.setattr(computer_client.logger, "info", log_messages.append) - - ctx = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": "cua", - "cua_local": False, - "cua_api_key": "sk-secret-value", - }, - } - } - ) - - await computer_client.get_booter(ctx, "cua-log-test") - - assert log_messages - assert all("sk-secret-value" not in message for message in log_messages) - assert all("api_key" not in message for message in log_messages) - - -@pytest.mark.asyncio -async def test_get_booter_shuts_down_client_when_skill_sync_fails(monkeypatch): - from astrbot.core.computer import computer_client - - shutdowns = [] - - class FakeCuaBooter: - def __init__(self, **kwargs): - self.kwargs = kwargs - - async def boot(self, session_id: str): - self.session_id = session_id - - async def shutdown(self): - shutdowns.append(self.session_id) - - async def fail_sync(booter): - raise RuntimeError("sync failed") - - monkeypatch.setattr(computer_client, "_sync_skills_to_sandbox", fail_sync) - monkeypatch.setitem(computer_client.session_booter, "cua-sync-fail", None) - computer_client.session_booter.pop("cua-sync-fail", None) - monkeypatch.setattr( - "astrbot.core.computer.booters.cua.CuaBooter", - FakeCuaBooter, - raising=False, - ) - - ctx = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": {"booter": "cua"}, - } - } - ) - - with pytest.raises(RuntimeError, match="sync failed"): - await computer_client.get_booter(ctx, "cua-sync-fail") - - assert len(shutdowns) == 1 - assert "cua-sync-fail" not in computer_client.session_booter - - -@pytest.mark.asyncio -async def test_cua_idle_timeout_shuts_down_session_proactively(monkeypatch): - from astrbot.core.computer import computer_client - - shutdowns = [] - - class FakeCuaBooter: - def __init__(self, **kwargs): - self.kwargs = kwargs - - async def boot(self, session_id: str): - self.session_id = session_id - - async def available(self): - return True - - async def shutdown(self): - shutdowns.append(self.session_id) - - monkeypatch.setattr( - computer_client, "_sync_skills_to_sandbox", lambda booter: asyncio.sleep(0) - ) - monkeypatch.setattr( - "astrbot.core.computer.booters.cua.CuaBooter", - FakeCuaBooter, - raising=False, - ) - _clear_cua_session_state(computer_client, "cua-idle-expire") - - ctx = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": "cua", - "cua_idle_timeout": 0.1, - }, - } - } - ) - - booter = await computer_client.get_booter(ctx, "cua-idle-expire") - await asyncio.sleep(0.2) - - assert shutdowns == [booter.session_id] - assert "cua-idle-expire" not in computer_client.session_booter - - -@pytest.mark.asyncio -async def test_cua_idle_timeout_refreshes_on_reuse(monkeypatch): - from astrbot.core.computer import computer_client - - shutdowns = [] - - class FakeCuaBooter: - def __init__(self, **kwargs): - self.kwargs = kwargs - - async def boot(self, session_id: str): - self.session_id = session_id - - async def available(self): - return True - - async def shutdown(self): - shutdowns.append(self.session_id) - - monkeypatch.setattr( - computer_client, "_sync_skills_to_sandbox", lambda booter: asyncio.sleep(0) - ) - monkeypatch.setattr( - "astrbot.core.computer.booters.cua.CuaBooter", - FakeCuaBooter, - raising=False, - ) - _clear_cua_session_state(computer_client, "cua-idle-refresh") - - ctx = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": "cua", - "cua_idle_timeout": 0.2, - }, - } - } - ) - - booter1 = await computer_client.get_booter(ctx, "cua-idle-refresh") - await asyncio.sleep(0.05) - booter2 = await computer_client.get_booter(ctx, "cua-idle-refresh") - await asyncio.sleep(0.05) - - assert booter2 is booter1 - assert shutdowns == [] - - await asyncio.sleep(0.25) - - assert shutdowns == [booter1.session_id] - assert "cua-idle-refresh" not in computer_client.session_booter - - -@pytest.mark.asyncio -async def test_cua_idle_timeout_zero_disables_proactive_shutdown(monkeypatch): - from astrbot.core.computer import computer_client - - shutdowns = [] - - class FakeCuaBooter: - def __init__(self, **kwargs): - self.kwargs = kwargs - - async def boot(self, session_id: str): - self.session_id = session_id - - async def available(self): - return True - - async def shutdown(self): - shutdowns.append(self.session_id) - - monkeypatch.setattr( - computer_client, "_sync_skills_to_sandbox", lambda booter: asyncio.sleep(0) - ) - monkeypatch.setattr( - "astrbot.core.computer.booters.cua.CuaBooter", - FakeCuaBooter, - raising=False, - ) - _clear_cua_session_state(computer_client, "cua-idle-disabled") - - ctx = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": "cua", - "cua_idle_timeout": 0, - }, - } - } - ) - - await computer_client.get_booter(ctx, "cua-idle-disabled") - await asyncio.sleep(0.05) - - assert shutdowns == [] - assert "cua-idle-disabled" in computer_client.session_booter - assert "cua-idle-disabled" not in computer_client.cua_idle_state - - -@pytest.mark.asyncio -async def test_non_cua_booter_does_not_schedule_idle_cleanup(monkeypatch): - from astrbot.core.computer import computer_client - - class FakeShipyardBooter: - async def available(self): - return True - - _clear_cua_session_state(computer_client, "shipyard-session") - computer_client.session_booter["shipyard-session"] = FakeShipyardBooter() - - ctx = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "sandbox": { - "booter": "shipyard", - "shipyard_endpoint": "http://localhost:8080", - "shipyard_access_token": "token", - "cua_idle_timeout": 0.01, - }, - } - } - ) - - booter = await computer_client.get_booter(ctx, "shipyard-session") - - assert isinstance(booter, FakeShipyardBooter) - assert "shipyard-session" not in computer_client.cua_idle_state - - -@pytest.mark.asyncio -async def test_cua_components_map_sdk_results(tmp_path): - from astrbot.core.computer.booters.cua import ( - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - ) - - sandbox = FakeSandbox() - - shell_result = await CuaShellComponent(sandbox).exec("echo ok", cwd="/workspace") - python_result = await CuaPythonComponent(sandbox).exec("print(42)") - fs = CuaFileSystemComponent(sandbox) - await fs.write_file("hello.txt", "hello") - read_result = await fs.read_file("hello.txt") - screenshot_path = tmp_path / "screen.png" - gui = CuaGUIComponent(sandbox) - screenshot_result = await gui.screenshot(str(screenshot_path)) - click_result = await gui.click(10, 20, button="right") - type_result = await gui.type_text("hello") - press_result = await gui.press_key("Enter") - - assert shell_result["stdout"] == "ok" - assert python_result["data"]["output"]["text"] == "42" - assert read_result["content"] == "hello" - assert screenshot_path.read_bytes() == b"fake-png" - assert screenshot_result["mime_type"] == "image/png" - assert click_result["success"] is True - assert type_result["success"] is True - assert press_result["success"] is True - assert sandbox.mouse.clicks == [(10, 20, "right")] - assert sandbox.keyboard.typed == ["hello"] - assert sandbox.keyboard.pressed == ["Enter"] - - -@pytest.mark.asyncio -async def test_cua_list_dir_returns_entries_list_for_shell_fallback(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - sandbox = FakeSandbox() - delattr(sandbox, "filesystem") - - result = await CuaFileSystemComponent(sandbox).list_dir(".") - - assert result["success"] is True - assert result["entries"] == ["ok"] - assert sandbox.shell.commands[0][0] == "ls -1 ." - - -@pytest.mark.asyncio -async def test_cua_shell_filesystem_fallback_shell_quotes_paths(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - path = "folder/it's file.txt" - sandbox = FakeSandbox() - delattr(sandbox, "filesystem") - fs = CuaFileSystemComponent(sandbox) - - await fs.read_file(path) - await fs.delete_file(path) - await fs.list_dir(path) - - assert sandbox.shell.commands[0][0] == f"cat {shlex.quote(path)}" - assert sandbox.shell.commands[1][0] == f"rm -rf {shlex.quote(path)}" - assert sandbox.shell.commands[2][0] == f"ls -1 {shlex.quote(path)}" - - -@pytest.mark.asyncio -async def test_cua_write_file_shell_fallback_uses_python_base64_decoder(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - sandbox = FakeSandbox() - delattr(sandbox, "filesystem") - - await CuaFileSystemComponent(sandbox).write_file("hello.txt", "hello") - - command = sandbox.shell.commands[0][0] - assert "python3 -c" in command - assert "base64 -d" not in command - - -@pytest.mark.asyncio -async def test_cua_create_file_reports_mode_as_informational(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - sandbox = FakeSandbox() - - result = await CuaFileSystemComponent(sandbox).create_file("hello.txt", mode=0o600) - - assert result["success"] is True - assert result["mode"] == 0o600 - assert result["mode_applied"] is False - - -@pytest.mark.asyncio -async def test_cua_write_file_shell_fallback_propagates_shell_failure(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - sandbox = FakeSandbox() - sandbox.shell = FailingShell() - delattr(sandbox, "filesystem") - - result = await CuaFileSystemComponent(sandbox).write_file("hello.txt", "hello") - - assert result["success"] is False - assert "requires python3" in result["stderr"] - assert "python3: command not found" in result["stderr"] - assert result["path"] == "hello.txt" - - -@pytest.mark.asyncio -async def test_cua_edit_file_propagates_write_failure(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - class ReadableButFailingWriteShell: - def __init__(self): - self.commands = [] - - async def run(self, command: str, **kwargs): - self.commands.append((command, kwargs)) - if command.startswith("cat "): - return {"stdout": "hello old", "stderr": "", "exit_code": 0} - return { - "stdout": "", - "stderr": "permission denied", - "exit_code": 1, - "success": False, - } - - sandbox = FakeSandbox() - sandbox.shell = ReadableButFailingWriteShell() - delattr(sandbox, "filesystem") - - result = await CuaFileSystemComponent(sandbox).edit_file("hello.txt", "old", "new") - - assert result["success"] is False - assert result["stderr"] == "permission denied" - assert result["path"] == "hello.txt" - - -@pytest.mark.asyncio -async def test_cua_list_dir_shell_fallback_returns_filename_only_entries(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - sandbox = FakeSandbox() - sandbox.shell = SyncShell("alpha.txt\nfolder\n") - delattr(sandbox, "filesystem") - - result = await CuaFileSystemComponent(sandbox).list_dir(".", show_hidden=True) - - assert result["entries"] == ["alpha.txt", "folder"] - assert sandbox.shell.commands[0][0] == "ls -1A ." - - -@pytest.mark.asyncio -async def test_cua_shell_filesystem_fallback_rejects_non_posix_os_type(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - sandbox = SandboxWithoutFilesystem() - fs = CuaFileSystemComponent(sandbox, os_type="windows") - - read_result = await fs.read_file("hello.txt") - write_result = await fs.write_file("hello.txt", "hello") - delete_result = await fs.delete_file("hello.txt") - list_result = await fs.list_dir(".") - - for result in (read_result, write_result, delete_result, list_result): - assert result["success"] is False - assert ( - "filesystem shell fallback is only supported for POSIX" in result["error"] - ) - assert sandbox.shell.commands == [] - - -@pytest.mark.asyncio -async def test_cua_shell_and_python_accept_sync_sdk_methods(): - from astrbot.core.computer.booters.cua import CuaPythonComponent, CuaShellComponent - - sandbox = FakeSandbox() - sandbox.shell = SyncShell() - sandbox.python = SyncPython() - - shell_result = await CuaShellComponent(sandbox).exec("echo ok") - python_result = await CuaPythonComponent(sandbox).exec("print('ok')") - - assert shell_result["stdout"] == "ok" - assert python_result["data"]["output"]["text"] == "sync" - - -@pytest.mark.asyncio -async def test_cua_filesystem_prefers_native_files_interface(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - sandbox = SandboxWithoutFilesystem() - sandbox.files = FakeFiles() - - fs = CuaFileSystemComponent(sandbox) - await fs.write_file("hello.txt", "hello") - result = await fs.read_file("hello.txt") - - assert sandbox.files.text_writes == [("hello.txt", "hello")] - assert result["success"] is True - assert result["content"] == "hello" - assert sandbox.shell.commands == [] - - -@pytest.mark.asyncio -async def test_cua_filesystem_uses_legacy_filesystem_when_files_lacks_method(): - from astrbot.core.computer.booters.cua import CuaFileSystemComponent - - sandbox = SandboxWithoutFilesystem() - sandbox.files = type("UploadOnlyFiles", (), {"upload": FakeFiles().upload})() - sandbox.filesystem = FakeFilesystem() - - fs = CuaFileSystemComponent(sandbox) - await fs.write_file("hello.txt", "hello") - result = await fs.read_file("hello.txt") - - assert sandbox.filesystem.files == {"hello.txt": "hello"} - assert result["success"] is True - assert result["content"] == "hello" - assert sandbox.shell.commands == [] - - -@pytest.mark.asyncio -async def test_cua_shell_normalizes_output_returncode_shape(): - from astrbot.core.computer.booters.cua import CuaShellComponent - - sandbox = FakeSandbox() - sandbox.shell = ProcessShapeShell() - - result = await CuaShellComponent(sandbox).exec("echo ok") - - assert result == { - "stdout": "shape-ok", - "stderr": "", - "exit_code": 0, - "success": True, - } - - -@pytest.mark.asyncio -async def test_cua_shell_normalizes_command_result_object_shape(): - from astrbot.core.computer.booters.cua import CuaShellComponent - - sandbox = FakeSandbox() - sandbox.shell = CommandResultShapeShell(stdout="hello\n", returncode=0) - - result = await CuaShellComponent(sandbox).exec("echo hello") - - assert result == { - "stdout": "hello\n", - "stderr": "", - "exit_code": 0, - "success": True, - } - - -@pytest.mark.asyncio -async def test_cua_shell_prefers_returncode_when_exit_code_is_none(): - from astrbot.core.computer.booters.cua import CuaShellComponent - - class ShellWithMixedExitCode: - async def run(self, command: str, **kwargs): - return { - "stdout": "", - "stderr": "", - "exit_code": None, - "returncode": 1, - } - - sandbox = FakeSandbox() - sandbox.shell = ShellWithMixedExitCode() - - result = await CuaShellComponent(sandbox).exec("false") - - assert result["exit_code"] == 1 - assert result["success"] is False - - -@pytest.mark.asyncio -async def test_cua_python_fallback_preserves_shell_command_result_stdout(): - from astrbot.core.computer.booters.cua import CuaPythonComponent - - sandbox = SandboxWithoutFilesystem() - sandbox.shell = CommandResultShapeShell(stdout="from python fallback\n") - delattr(sandbox, "python") - - result = await CuaPythonComponent(sandbox).exec("print('from python fallback')") - - assert result["success"] is True - assert result["output"] == "from python fallback\n" - assert result["data"]["output"]["text"] == "from python fallback\n" - - -@pytest.mark.asyncio -async def test_cua_shell_background_wrapper_detaches_via_python_subprocess(): - from astrbot.core.computer.booters.cua import CuaShellComponent - - sandbox = FakeSandbox() - - await CuaShellComponent(sandbox).exec( - "chromium https://example.com", background=True - ) - - command = sandbox.shell.commands[0][0] - assert command.startswith("python3 -c ") - assert "subprocess.Popen" in command - assert "start_new_session=True" in command - assert "p.pid" in command - assert "stdout=subprocess.DEVNULL" in command - assert "stderr=subprocess.DEVNULL" in command - assert "time.sleep(0.2)" in command - assert "'chromium https://example.com'" in command - assert "&" not in command - - -@pytest.mark.asyncio -async def test_cua_shell_background_rejects_non_posix_os_type(): - from astrbot.core.computer.booters.cua import CuaShellComponent - - sandbox = FakeSandbox() - - result = await CuaShellComponent(sandbox, os_type="windows").exec( - "start notepad", background=True - ) - - assert result == { - "stdout": "", - "stderr": "error: background shell execution is only supported for POSIX CUA images.", - "exit_code": 2, - "success": False, - } - assert sandbox.shell.commands == [] - - -@pytest.mark.asyncio -async def test_cua_upload_file_fallback_rejects_non_posix_os_type(tmp_path): - from astrbot.core.computer.booters.cua import ( - CuaBooter, - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - _CuaRuntime, - ) - - local_file = tmp_path / "upload.txt" - local_file.write_text("hello", encoding="utf-8") - sandbox = SandboxWithoutFilesystem() - booter = CuaBooter(os_type="windows") - booter._runtime = _CuaRuntime( - sandbox_cm=object(), - sandbox=sandbox, - shell=CuaShellComponent(sandbox, os_type="windows"), - python=CuaPythonComponent(sandbox, os_type="windows"), - fs=CuaFileSystemComponent(sandbox, os_type="windows"), - gui=CuaGUIComponent(sandbox), - ) - - result = await booter.upload_file(str(local_file), "remote.txt") - - assert result["success"] is False - assert "filesystem shell fallback is only supported for POSIX" in result["error"] - assert sandbox.shell.commands == [] - - -@pytest.mark.asyncio -async def test_cua_upload_file_prefers_native_files_upload(tmp_path): - from astrbot.core.computer.booters.cua import ( - CuaBooter, - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - _CuaRuntime, - ) - - local_file = tmp_path / "upload.txt" - local_file.write_text("hello", encoding="utf-8") - sandbox = SandboxWithoutFilesystem() - sandbox.files = FakeFiles() - booter = CuaBooter() - booter._runtime = _CuaRuntime( - sandbox_cm=object(), - sandbox=sandbox, - shell=CuaShellComponent(sandbox), - python=CuaPythonComponent(sandbox), - fs=CuaFileSystemComponent(sandbox), - gui=CuaGUIComponent(sandbox), - ) - - result = await booter.upload_file(str(local_file), "remote.txt") - - assert result["success"] is True - assert sandbox.files.uploads == [(str(local_file), "remote.txt")] - assert sandbox.shell.commands == [] - - -@pytest.mark.asyncio -async def test_cua_upload_file_uses_native_write_bytes_when_upload_missing(tmp_path): - from astrbot.core.computer.booters.cua import ( - CuaBooter, - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - _CuaRuntime, - ) - - class FilesWithoutUpload: - def __init__(self): - self.byte_writes = [] - - async def write_bytes(self, path: str, content: bytes): - self.byte_writes.append((path, content)) - - local_file = tmp_path / "upload.txt" - local_file.write_bytes(b"hello-bytes") - sandbox = SandboxWithoutFilesystem() - sandbox.files = FilesWithoutUpload() - booter = CuaBooter() - booter._runtime = _CuaRuntime( - sandbox_cm=object(), - sandbox=sandbox, - shell=CuaShellComponent(sandbox), - python=CuaPythonComponent(sandbox), - fs=CuaFileSystemComponent(sandbox), - gui=CuaGUIComponent(sandbox), - ) - - result = await booter.upload_file(str(local_file), "remote.txt") - - assert result["success"] is True - assert sandbox.files.byte_writes == [("remote.txt", b"hello-bytes")] - assert sandbox.shell.commands == [] - - -@pytest.mark.asyncio -async def test_cua_upload_file_propagates_native_upload_failure_result(tmp_path): - from astrbot.core.computer.booters.cua import ( - CuaBooter, - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - _CuaRuntime, - ) - - class FailingFilesUpload: - async def upload(self, local_path: str, remote_path: str): - return {"success": False, "error": "disk full"} - - local_file = tmp_path / "upload.txt" - local_file.write_text("hello", encoding="utf-8") - sandbox = SandboxWithoutFilesystem() - sandbox.files = FailingFilesUpload() - booter = CuaBooter() - booter._runtime = _CuaRuntime( - sandbox_cm=object(), - sandbox=sandbox, - shell=CuaShellComponent(sandbox), - python=CuaPythonComponent(sandbox), - fs=CuaFileSystemComponent(sandbox), - gui=CuaGUIComponent(sandbox), - ) - - result = await booter.upload_file(str(local_file), "remote.txt") - - assert result["success"] is False - assert result["error"] == "disk full" - - -@pytest.mark.asyncio -async def test_cua_download_file_shell_quotes_remote_path(tmp_path): - from astrbot.core.computer.booters.cua import ( - CuaBooter, - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - _CuaRuntime, - ) - - class Base64Shell(FakeShell): - async def run(self, command: str, **kwargs): - self.commands.append((command, kwargs)) - return { - "stdout": base64.b64encode(b"hello").decode(), - "stderr": "", - "exit_code": 0, - } - - sandbox = SandboxWithoutFilesystem() - sandbox.shell = Base64Shell() - booter = CuaBooter() - booter._runtime = _CuaRuntime( - sandbox_cm=object(), - sandbox=sandbox, - shell=CuaShellComponent(sandbox), - python=CuaPythonComponent(sandbox), - fs=CuaFileSystemComponent(sandbox), - gui=CuaGUIComponent(sandbox), - ) - remote_path = "folder/it's file.txt" - local_path = tmp_path / "download.txt" - - await booter.download_file(remote_path, str(local_path)) - - assert sandbox.shell.commands[0][0] == f"base64 {shlex.quote(remote_path)}" - assert local_path.read_bytes() == b"hello" - - -@pytest.mark.asyncio -async def test_cua_download_file_fallback_rejects_non_posix_os_type(tmp_path): - from astrbot.core.computer.booters.cua import ( - CuaBooter, - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - _CuaRuntime, - ) - - sandbox = SandboxWithoutFilesystem() - booter = CuaBooter(os_type="windows") - booter._runtime = _CuaRuntime( - sandbox_cm=object(), - sandbox=sandbox, - shell=CuaShellComponent(sandbox, os_type="windows"), - python=CuaPythonComponent(sandbox, os_type="windows"), - fs=CuaFileSystemComponent(sandbox, os_type="windows"), - gui=CuaGUIComponent(sandbox), - ) - - with pytest.raises(RuntimeError, match="filesystem shell fallback"): - await booter.download_file("remote.txt", str(tmp_path / "download.txt")) - - assert sandbox.shell.commands == [] - - -@pytest.mark.asyncio -async def test_cua_boot_cleans_up_sandbox_when_component_setup_fails(monkeypatch): - from astrbot.core.computer.booters import cua as cua_booter - - closed = [] - - class FakeSandboxContext: - async def __aenter__(self): - return FakeSandbox() - - async def __aexit__(self, exc_type, exc, tb): - closed.append((exc_type, exc, tb)) - - class FakeImage: - @staticmethod - def linux(): - return "linux-image" - - class FakeSandboxFactory: - @staticmethod - def ephemeral(image, **kwargs): - return FakeSandboxContext() - - class BrokenShellComponent: - def __init__(self, sandbox, os_type="linux"): - raise RuntimeError("component setup failed") - - original_import = __import__ - - def fake_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == "cua": - - class FakeCuaModule: - Image = FakeImage - Sandbox = FakeSandboxFactory - - return FakeCuaModule() - return original_import(name, globals, locals, fromlist, level) - - monkeypatch.setattr("builtins.__import__", fake_import) - monkeypatch.setattr(cua_booter, "CuaShellComponent", BrokenShellComponent) - - booter = cua_booter.CuaBooter() - - with pytest.raises(RuntimeError, match="component setup failed"): - await booter.boot("session") - - assert len(closed) == 1 - assert booter._runtime is None - - -@pytest.mark.asyncio -async def test_cua_shell_background_reports_missing_python3_requirement(): - from astrbot.core.computer.booters.cua import CuaShellComponent - - sandbox = FakeSandbox() - sandbox.shell = FailingShell() - - result = await CuaShellComponent(sandbox).exec("firefox", background=True) - - assert result["success"] is False - assert "requires python3" in result["stderr"] - assert "python3: command not found" in result["stderr"] - - -@pytest.mark.asyncio -async def test_cua_python_fallback_reports_missing_python3_requirement(): - from astrbot.core.computer.booters.cua import CuaPythonComponent - - sandbox = SandboxWithoutFilesystem() - sandbox.shell = FailingShell() - delattr(sandbox, "python") - - result = await CuaPythonComponent(sandbox).exec("print('hello')") - - assert result["success"] is False - assert "requires python3" in result["error"] - assert "python3: command not found" in result["error"] - - -@pytest.mark.asyncio -async def test_cua_gui_reports_missing_mouse_or_keyboard(): - from astrbot.core.computer.booters.cua import CuaGUIComponent - - class SandboxWithoutGuiDevices: - async def screenshot(self): - return b"fake-png" - - gui = CuaGUIComponent(SandboxWithoutGuiDevices()) - - with pytest.raises(RuntimeError, match="mouse.*click"): - await gui.click(1, 2) - - with pytest.raises(RuntimeError, match="keyboard.*type"): - await gui.type_text("hello") - - with pytest.raises(RuntimeError, match="keyboard.*press"): - await gui.press_key("Enter") - - -@pytest.mark.asyncio -async def test_cua_gui_press_error_lists_probed_methods(): - from astrbot.core.computer.booters.cua import CuaGUIComponent - - class SandboxWithoutPress: - keyboard = object() - - gui = CuaGUIComponent(SandboxWithoutPress()) - - with pytest.raises(RuntimeError) as exc_info: - await gui.press_key("Enter") - - message = str(exc_info.value) - assert "keyboard.press" in message - assert "keyboard.key_press" in message - assert "keyboard.press_key" in message - - -@pytest.mark.asyncio -async def test_cua_gui_caches_component_methods_after_initialization(): - from astrbot.core.computer.booters.cua import CuaGUIComponent - - class CountingMouse: - def __init__(self): - self.click_lookups = 0 - self.clicks = [] - - def __getattribute__(self, name): - if name == "click": - object.__getattribute__(self, "__dict__")["click_lookups"] += 1 - return object.__getattribute__(self, name) - - async def click(self, x: int, y: int, button: str = "left"): - self.clicks.append((x, y, button)) - return {"success": True} - - class Sandbox: - def __init__(self): - self.mouse = CountingMouse() - - sandbox = Sandbox() - gui = CuaGUIComponent(sandbox) - - await gui.click(1, 2) - await gui.click(3, 4, button="right") - - assert sandbox.mouse.click_lookups == 1 - assert sandbox.mouse.clicks == [(1, 2, "left"), (3, 4, "right")] - - -def test_cua_capabilities_reflect_initialized_sandbox_gui_devices(): - from astrbot.core.computer.booters.cua import ( - CuaBooter, - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - _CuaRuntime, - ) - - def set_runtime(booter, sandbox): - shell = CuaShellComponent(sandbox) - booter._runtime = _CuaRuntime( - sandbox_cm=object(), - sandbox=sandbox, - shell=shell, - python=CuaPythonComponent(sandbox), - fs=CuaFileSystemComponent(sandbox), - gui=CuaGUIComponent(sandbox), - ) - - booter = CuaBooter() - set_runtime(booter, FakeSandbox()) - - assert booter.capabilities == ( - "python", - "shell", - "filesystem", - "gui", - "screenshot", - "mouse", - "keyboard", - ) - - class ScreenshotOnlySandbox: - shell = FakeShell() - - async def screenshot(self): - return b"fake-png" - - set_runtime(booter, ScreenshotOnlySandbox()) - - assert booter.capabilities == ("python", "shell", "filesystem", "gui", "screenshot") - - -@pytest.mark.asyncio -async def test_cua_shutdown_clears_cached_components(): - from astrbot.core.computer.booters.cua import ( - CuaBooter, - CuaFileSystemComponent, - CuaGUIComponent, - CuaPythonComponent, - CuaShellComponent, - _CuaRuntime, - ) - - closed = [] - - class FakeSandboxContext: - async def __aexit__(self, exc_type, exc, tb): - closed.append(True) - - booter = CuaBooter() - sandbox = FakeSandbox() - booter._runtime = _CuaRuntime( - sandbox_cm=FakeSandboxContext(), - sandbox=sandbox, - shell=CuaShellComponent(sandbox), - python=CuaPythonComponent(sandbox), - fs=CuaFileSystemComponent(sandbox), - gui=CuaGUIComponent(sandbox), - ) - - await booter.shutdown() - - assert closed == [True] - assert await booter.available() is False - assert booter._runtime is None - - -def test_cua_tools_are_registered_as_builtin_tools(): - from astrbot.core.tools.computer_tools.cua import ( - CuaKeyboardTypeTool, - CuaMouseClickTool, - CuaScreenshotTool, - ) - - manager = FunctionToolManager() - - assert manager.get_builtin_tool(CuaScreenshotTool).name == "astrbot_cua_screenshot" - assert manager.get_builtin_tool(CuaMouseClickTool).name == "astrbot_cua_mouse_click" - assert ( - manager.get_builtin_tool(CuaKeyboardTypeTool).name - == "astrbot_cua_keyboard_type" - ) - - -def test_cua_runtime_tools_are_available_to_handoffs(): - manager = FunctionToolManager() - - tools = FunctionToolExecutor._get_runtime_computer_tools("sandbox", manager, "cua") - - assert "astrbot_cua_screenshot" in tools - assert "astrbot_cua_mouse_click" in tools - assert "astrbot_cua_keyboard_type" in tools - assert "astrbot_cua_key_press" not in tools - - -def test_runtime_tool_selection_treats_none_booter_as_empty(): - manager = FunctionToolManager() - - tools = FunctionToolExecutor._get_runtime_computer_tools("sandbox", manager, None) - - assert "astrbot_execute_shell" in tools - assert "astrbot_cua_screenshot" not in tools - - -def test_runtime_tool_selection_normalizes_cua_booter_case(): - manager = FunctionToolManager() - - tools = FunctionToolExecutor._get_runtime_computer_tools("sandbox", manager, "CUA") - - assert "astrbot_cua_screenshot" in tools - - -def test_cua_is_exposed_in_sandbox_config_metadata(): - items = _agent_computer_use_items() - booter = items["provider_settings.sandbox.booter"] - - assert "cua" in booter["options"] - assert "CUA" in booter["labels"] - assert "provider_settings.sandbox.cua_image" in items - assert "provider_settings.sandbox.cua_os_type" in items - assert "provider_settings.sandbox.cua_ttl" not in items - assert "provider_settings.sandbox.cua_idle_timeout" in items - assert "provider_settings.sandbox.cua_telemetry_enabled" in items - assert "provider_settings.sandbox.cua_local" in items - assert "provider_settings.sandbox.cua_api_key" in items - assert ( - items["provider_settings.sandbox.cua_api_key"]["condition"][ - "provider_settings.sandbox.cua_local" - ] - is False - ) - - -_PNG_BYTES = base64.b64decode( - "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mP8/x8AAwMCAO+/p9sAAAAASUVORK5CYII=" -) - - -@pytest.mark.asyncio -async def test_screenshot_tool_returns_image_and_sends_file(monkeypatch, tmp_path): - from astrbot.core.tools.computer_tools import cua as cua_tools - from astrbot.core.tools.computer_tools.cua import CuaScreenshotTool - - sent_messages = [] - - class FakeEvent: - unified_msg_origin = "umo" - role = "admin" - - async def send(self, message): - sent_messages.append(message) - - class FakeAstrContext: - event = FakeEvent() - context = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "computer_use_require_admin": True, - "sandbox": {"booter": "cua"}, - } - } - ) - - class FakeWrapper: - context = FakeAstrContext() - - class FakeGUI: - async def screenshot(self, path: str): - Path(path).write_bytes(b"fake-png") - return { - "success": True, - "path": path, - "mime_type": "image/png", - "base64": base64.b64encode(b"fake-png").decode(), - } - - class FakeBooter: - gui = FakeGUI() - - async def fake_get_booter(context, session_id): - return FakeBooter() - - monkeypatch.setattr(cua_tools, "get_booter", fake_get_booter) - monkeypatch.setattr(cua_tools, "get_astrbot_temp_path", lambda: str(tmp_path)) - - result = await CuaScreenshotTool().call(FakeWrapper(), send_to_user=True) - - assert isinstance(result, mcp.types.CallToolResult) - image_parts = [part for part in result.content if part.type == "image"] - text_parts = [part for part in result.content if part.type == "text"] - payload = json.loads(text_parts[0].text) - assert image_parts[0].data == base64.b64encode(b"fake-png").decode() - assert "base64" not in payload - assert Path(payload["path"]).exists() - assert sent_messages - - -@pytest.mark.parametrize( - "screenshot_shape", - [ - "data_url", - "path_string", - "save_object", - "base64_dict", - ], -) -@pytest.mark.asyncio -async def test_screenshot_tool_normalizes_supported_screenshot_shapes( - monkeypatch, - tmp_path, - screenshot_shape, -): - from astrbot.core.computer.booters.cua import CuaGUIComponent - from astrbot.core.tools.computer_tools import cua as cua_tools - from astrbot.core.tools.computer_tools.cua import CuaScreenshotTool - - sent_messages = [] - - class FakeEvent: - unified_msg_origin = "umo" - role = "admin" - - async def send(self, message): - sent_messages.append(message) - - class FakeAstrContext: - event = FakeEvent() - context = FakeContext( - { - "provider_settings": { - "computer_use_runtime": "sandbox", - "computer_use_require_admin": True, - "sandbox": {"booter": "cua"}, - } - } - ) - - class FakeWrapper: - context = FakeAstrContext() - - class SaveObject: - def save(self, output, format): - assert format == "PNG" - output.write(_PNG_BYTES) - - class FakeSandbox: - async def screenshot(self): - if screenshot_shape == "data_url": - encoded = base64.b64encode(_PNG_BYTES).decode() - return f"data:image/png;base64,{encoded}" - if screenshot_shape == "path_string": - source_path = tmp_path / "source.png" - source_path.write_bytes(_PNG_BYTES) - return str(source_path) - if screenshot_shape == "save_object": - return SaveObject() - return {"base64": base64.b64encode(_PNG_BYTES).decode()} - - class FakeBooter: - gui = CuaGUIComponent(FakeSandbox()) - - async def fake_get_booter(context, session_id): - return FakeBooter() - - monkeypatch.setattr(cua_tools, "get_booter", fake_get_booter) - monkeypatch.setattr(cua_tools, "get_astrbot_temp_path", lambda: str(tmp_path)) - - result = await CuaScreenshotTool().call(FakeWrapper(), send_to_user=True) - - assert isinstance(result, mcp.types.CallToolResult) - image_parts = [part for part in result.content if part.type == "image"] - text_parts = [part for part in result.content if part.type == "text"] - payload = json.loads(text_parts[0].text) - assert "base64" not in payload - assert payload["mime_type"] == "image/png" - assert Path(payload["path"]).read_bytes() == _PNG_BYTES - assert base64.b64decode(image_parts[0].data) == _PNG_BYTES - assert sent_messages - - -@pytest.mark.asyncio -async def test_screenshot_tool_can_opt_in_to_llm_image_content(monkeypatch, tmp_path): - from astrbot.core.tools.computer_tools import cua as cua_tools - from astrbot.core.tools.computer_tools.cua import CuaScreenshotTool - - class FakeEvent: - unified_msg_origin = "umo" - role = "admin" - - async def send(self, message): - pass - - class FakeAstrContext: - event = FakeEvent() - context = FakeContext( - {"provider_settings": {"computer_use_require_admin": True}} - ) - - class FakeWrapper: - context = FakeAstrContext() - - class FakeGUI: - async def screenshot(self, path: str): - Path(path).write_bytes(b"fake-png") - return { - "success": True, - "path": path, - "mime_type": "image/png", - "base64": base64.b64encode(b"fake-png").decode(), - } - - class FakeBooter: - gui = FakeGUI() - - async def fake_get_booter(context, session_id): - return FakeBooter() - - monkeypatch.setattr(cua_tools, "get_booter", fake_get_booter) - monkeypatch.setattr(cua_tools, "get_astrbot_temp_path", lambda: str(tmp_path)) - - result = await CuaScreenshotTool().call( - FakeWrapper(), send_to_user=False, return_image_to_llm=True - ) - - image_parts = [part for part in result.content if part.type == "image"] - text_parts = [part for part in result.content if part.type == "text"] - payload = json.loads(text_parts[0].text) - assert image_parts[0].data == base64.b64encode(b"fake-png").decode() - assert "base64" not in payload - - -@pytest.mark.asyncio -async def test_screenshot_tool_can_opt_out_of_llm_image_content(monkeypatch, tmp_path): - from astrbot.core.tools.computer_tools import cua as cua_tools - from astrbot.core.tools.computer_tools.cua import CuaScreenshotTool - - class FakeEvent: - unified_msg_origin = "umo" - role = "admin" - - async def send(self, message): - pass - - class FakeAstrContext: - event = FakeEvent() - context = FakeContext( - {"provider_settings": {"computer_use_require_admin": True}} - ) - - class FakeWrapper: - context = FakeAstrContext() - - class FakeGUI: - async def screenshot(self, path: str): - Path(path).write_bytes(b"fake-png") - return { - "success": True, - "path": path, - "mime_type": "image/png", - "base64": base64.b64encode(b"fake-png").decode(), - } - - class FakeBooter: - gui = FakeGUI() - - async def fake_get_booter(context, session_id): - return FakeBooter() - - monkeypatch.setattr(cua_tools, "get_booter", fake_get_booter) - monkeypatch.setattr(cua_tools, "get_astrbot_temp_path", lambda: str(tmp_path)) - - result = await CuaScreenshotTool().call( - FakeWrapper(), send_to_user=False, return_image_to_llm=False - ) - - image_parts = [part for part in result.content if part.type == "image"] - text_parts = [part for part in result.content if part.type == "text"] - payload = json.loads(text_parts[0].text) - assert image_parts == [] - assert "base64" not in payload - - -@pytest.mark.asyncio -async def test_cua_tools_return_permission_error_without_gui_lookup(monkeypatch): - from astrbot.core.tools.computer_tools import cua as cua_tools - from astrbot.core.tools.computer_tools.cua import ( - CuaKeyboardTypeTool, - CuaMouseClickTool, - CuaScreenshotTool, - ) - - sent_messages = [] - - class FakeEvent: - unified_msg_origin = "umo" - role = "member" - - async def send(self, message): - sent_messages.append(message) - - class FakeAstrContext: - event = FakeEvent() - context = FakeContext({"provider_settings": {}}) - - class FakeWrapper: - context = FakeAstrContext() - - async def fail_gui_lookup(context): - raise AssertionError("GUI lookup should not run after permission failure") - - monkeypatch.setattr(cua_tools, "check_admin_permission", lambda *args: "denied") - monkeypatch.setattr(cua_tools, "_get_gui_component", fail_gui_lookup) - - assert await CuaScreenshotTool().call(FakeWrapper()) == "denied" - assert await CuaMouseClickTool().call(FakeWrapper(), x=1, y=2) == "denied" - assert await CuaKeyboardTypeTool().call(FakeWrapper(), text="hello") == "denied" - assert sent_messages == [] - - -@pytest.mark.asyncio -async def test_cua_tools_include_exception_type_for_blank_error(monkeypatch): - from astrbot.core.tools.computer_tools import cua as cua_tools - from astrbot.core.tools.computer_tools.cua import CuaMouseClickTool - - class BlankError(Exception): - def __str__(self): - return "" - - class FakeEvent: - unified_msg_origin = "umo" - role = "admin" - - class FakeAstrContext: - event = FakeEvent() - context = FakeContext( - {"provider_settings": {"computer_use_require_admin": True}} - ) - - class FakeWrapper: - context = FakeAstrContext() - - async def fail_gui_lookup(context): - raise BlankError() - - monkeypatch.setattr(cua_tools, "_get_gui_component", fail_gui_lookup) - - assert await CuaMouseClickTool().call(FakeWrapper(), x=1, y=2) == ( - "Error clicking CUA desktop: BlankError" - ) - - -@pytest.mark.asyncio -async def test_cua_mouse_click_tool_happy_path_forwards_args_and_serializes_json( - monkeypatch, -): - from astrbot.core.tools.computer_tools import cua as cua_tools - from astrbot.core.tools.computer_tools.cua import CuaMouseClickTool - - class FakeEvent: - unified_msg_origin = "umo" - role = "admin" - - class FakeAstrContext: - event = FakeEvent() - context = FakeContext( - {"provider_settings": {"computer_use_require_admin": True}} - ) - - class FakeWrapper: - context = FakeAstrContext() - - class FakeGui: - def __init__(self): - self.clicked_args = None - - async def click(self, x: int, y: int, button: str = "left"): - self.clicked_args = (x, y, button) - return {"status": "ok", "x": x, "y": y, "button": button} - - fake_gui = FakeGui() - get_gui_called = {"value": False} - wrapper = FakeWrapper() - - async def fake_get_gui_component(context): - get_gui_called["value"] = True - assert context is wrapper - return fake_gui - - monkeypatch.setattr(cua_tools, "_get_gui_component", fake_get_gui_component) - - result = await CuaMouseClickTool().call(wrapper, x=10, y=20, button="right") - - assert get_gui_called["value"] is True - assert fake_gui.clicked_args == (10, 20, "right") - assert json.loads(result) == { - "status": "ok", - "x": 10, - "y": 20, - "button": "right", - } - - -@pytest.mark.asyncio -async def test_cua_keyboard_type_tool_happy_path_forwards_args_and_serializes_json( - monkeypatch, -): - from astrbot.core.tools.computer_tools import cua as cua_tools - from astrbot.core.tools.computer_tools.cua import CuaKeyboardTypeTool - - class FakeEvent: - unified_msg_origin = "umo" - role = "admin" - - class FakeAstrContext: - event = FakeEvent() - context = FakeContext( - {"provider_settings": {"computer_use_require_admin": True}} - ) - - class FakeWrapper: - context = FakeAstrContext() - - class FakeGui: - def __init__(self): - self.typed_text_args = None - - async def type_text(self, text: str): - self.typed_text_args = (text,) - return {"status": "ok", "text": text} - - fake_gui = FakeGui() - get_gui_called = {"value": False} - wrapper = FakeWrapper() - - async def fake_get_gui_component(context): - get_gui_called["value"] = True - assert context is wrapper - return fake_gui - - monkeypatch.setattr(cua_tools, "_get_gui_component", fake_get_gui_component) - - result = await CuaKeyboardTypeTool().call(wrapper, text="Hello CUA") - - assert get_gui_called["value"] is True - assert fake_gui.typed_text_args == ("Hello CUA",) - assert json.loads(result) == {"status": "ok", "text": "Hello CUA"} diff --git a/tests/unit/test_cua_extracted_from_core.py b/tests/unit/test_cua_extracted_from_core.py new file mode 100644 index 0000000000..eb8876d624 --- /dev/null +++ b/tests/unit/test_cua_extracted_from_core.py @@ -0,0 +1,60 @@ +import importlib.util + +from astrbot.core.config.default import DEFAULT_CONFIG + + +def test_core_no_longer_ships_concrete_sandbox_runtime_modules(): + assert importlib.util.find_spec("astrbot.core.computer.booters.cua") is None + assert ( + importlib.util.find_spec("astrbot.core.computer.booters.cua_defaults") is None + ) + assert importlib.util.find_spec("astrbot.core.tools.computer_tools.cua") is None + assert importlib.util.find_spec("astrbot.core.computer.booters.shipyard") is None + assert ( + importlib.util.find_spec("astrbot.core.computer.booters.shipyard_neo") is None + ) + assert importlib.util.find_spec("astrbot.core.computer.booters.boxlite") is None + assert importlib.util.find_spec("astrbot.core.computer.booters.bay_manager") is None + assert ( + importlib.util.find_spec("astrbot.core.computer.booters.shell_background") + is None + ) + assert ( + importlib.util.find_spec( + "astrbot.core.computer.booters.shipyard_search_file_util" + ) + is None + ) + assert ( + importlib.util.find_spec("astrbot.core.tools.computer_tools.shipyard_neo") + is None + ) + + +def test_core_default_config_does_not_include_runtime_specific_settings(): + sandbox = DEFAULT_CONFIG["provider_settings"]["sandbox"] + + assert sandbox["booter"] == "" + assert sandbox["sandbox_ttl"] == 3600 + assert sandbox["sandbox_idle_timeout"] == 1800 + assert sandbox["sandbox_lease_timeout"] == 600 + assert "cua_image" not in sandbox + assert "cua_os_type" not in sandbox + assert "cua_ttl" not in sandbox + assert "cua_telemetry_enabled" not in sandbox + assert "cua_local" not in sandbox + assert "cua_api_key" not in sandbox + assert "shipyard_endpoint" not in sandbox + assert "shipyard_neo_endpoint" not in sandbox + assert "shipyard_neo_profile" not in sandbox + + +def test_core_sandbox_config_metadata_is_provider_agnostic(): + from astrbot.core.config.default import CONFIG_METADATA_3 + + items = CONFIG_METADATA_3["ai_group"]["metadata"]["agent_computer_use"]["items"] + booter = items["provider_settings.sandbox.booter"] + + assert booter["options"] == [] + assert booter["labels"] == [] + assert not any("shipyard" in key or "cua" in key for key in items) diff --git a/tests/unit/test_func_tool_manager.py b/tests/unit/test_func_tool_manager.py index d53ed3296f..275e2e1b89 100644 --- a/tests/unit/test_func_tool_manager.py +++ b/tests/unit/test_func_tool_manager.py @@ -1,7 +1,9 @@ import json +from dataclasses import dataclass, field import pytest +from astrbot.api import FunctionTool from astrbot.core import sp from astrbot.core.provider.func_tool_manager import FunctionToolManager from astrbot.core.tools.computer_tools.shell import ExecuteShellTool @@ -49,6 +51,26 @@ def test_computer_tools_are_registered_as_builtin_tools(): assert manager.is_builtin_tool("astrbot_execute_shell") is True +def test_clear_builtin_tool_cache_by_module_prefix_removes_matching_instances(): + manager = FunctionToolManager() + + @dataclass + class ExampleTool(FunctionTool): + name: str = "astrbot_test_cached_tool" + description: str = "Cached tool for eviction testing." + parameters: dict = field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + + tool = ExampleTool() + manager.builtin_func_list[ExampleTool] = tool + + removed = manager.clear_builtin_tool_cache_by_module_prefix(__name__) + + assert removed == ["astrbot_test_cached_tool"] + assert manager.builtin_func_list == {} + + @pytest.mark.asyncio async def test_execute_shell_defaults_to_foreground(monkeypatch): from astrbot.core.tools.computer_tools import shell as shell_tools diff --git a/tests/unit/test_migra_helper.py b/tests/unit/test_migra_helper.py new file mode 100644 index 0000000000..918d543f92 --- /dev/null +++ b/tests/unit/test_migra_helper.py @@ -0,0 +1,64 @@ +import json + +from astrbot.core.config.astrbot_config import AstrBotConfig +from astrbot.core.utils.migra_helper import _prune_invalid_provider_source_models + + +def test_prune_invalid_provider_source_models_removes_unresolvable_entries(tmp_path): + config_path = tmp_path / "cmd_config.json" + config_path.write_text( + json.dumps( + { + "provider_sources": [ + { + "id": "valid_source", + "type": "openai_chat_completion", + "provider_type": "chat_completion", + }, + {"id": "broken_source"}, + ], + "provider": [ + { + "id": "valid-model", + "provider_source_id": "valid_source", + "model": "gpt-test", + "enable": True, + }, + { + "id": "missing-source-model", + "provider_source_id": "missing_source", + "model": "stale", + "enable": True, + }, + { + "id": "broken-source-model", + "provider_source_id": "broken_source", + "model": "stale", + "enable": True, + }, + { + "id": "legacy-direct-provider", + "type": "openai_chat_completion", + "model": "gpt-test", + "enable": True, + }, + ], + }, + ensure_ascii=False, + ), + encoding="utf-8-sig", + ) + + conf = AstrBotConfig( + config_path=str(config_path), + default_config={"provider_sources": [], "provider": []}, + ) + + _prune_invalid_provider_source_models(conf) + + provider_ids = [provider["id"] for provider in conf["provider"]] + assert provider_ids == ["valid-model", "legacy-direct-provider"] + + saved = json.loads(config_path.read_text(encoding="utf-8-sig")) + saved_ids = [provider["id"] for provider in saved["provider"]] + assert saved_ids == provider_ids diff --git a/tests/unit/test_sandbox_computer_client.py b/tests/unit/test_sandbox_computer_client.py new file mode 100644 index 0000000000..9140de4cf0 --- /dev/null +++ b/tests/unit/test_sandbox_computer_client.py @@ -0,0 +1,741 @@ +import asyncio +import time + +import pytest + +from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor + + +class FakeBooter: + def __init__(self): + self.shutdown_calls = 0 + + async def available(self): + return True + + async def shutdown(self): + self.shutdown_calls += 1 + + +class FakeProvider: + provider_id = "generic" + capabilities = {"shell"} + tool_names = {"generic_tool"} + system_prompt = "Use provider-specific sandbox rules." + + def __init__(self): + self.created = [] + + def build_create_config(self, context, session_id): + return {} + + def build_connect_info(self, sandbox_name, config): + return {"name": sandbox_name} + + def update_connect_info(self, record, *, sandbox_name): + return {"name": sandbox_name} + + async def create_booter(self, context, session_id, sandbox_id, config): + self.created.append((session_id, sandbox_id, config)) + return FakeBooter() + + async def destroy_booter(self, booter, record): + await booter.shutdown() + + +class OtherFakeProvider(FakeProvider): + provider_id = "other" + capabilities = {"filesystem", "python"} + + async def create_booter(self, context, session_id, sandbox_id, config): + self.created.append((session_id, sandbox_id, config)) + return OtherFakeBooter() + + +class OtherFakeBooter(FakeBooter): + pass + + +class FakeContext: + def get_config(self, umo=None): + return { + "provider_settings": { + "computer_use_runtime": "sandbox", + "sandbox": {"booter": "generic"}, + } + } + + +@pytest.mark.asyncio +async def test_registered_generic_provider_handles_booter(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeProvider() + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + monkeypatch.setattr(computer_client, "_sync_skills_to_sandbox", lambda booter: None) + + computer_client.register_sandbox_provider(provider) + booter = await computer_client.get_booter(FakeContext(), "session-a") + + assert isinstance(booter, FakeBooter) + assert computer_client.list_sandbox_providers() == [ + { + "provider_id": "generic", + "capabilities": ["shell"], + "tool_names": ["generic_tool"], + "system_prompt": "Use provider-specific sandbox rules.", + } + ] + + +def test_register_sandbox_provider_tags_provider_tools(monkeypatch, tmp_path): + from astrbot.core.agent.tool import FunctionTool + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + tool = FunctionTool( + name="generic_tool", + parameters={"type": "object", "properties": {}}, + description="generic", + ) + + computer_client.register_sandbox_provider(FakeProvider(), tools=[tool]) + + assert tool.sandbox_provider_id == "generic" + assert "Sandbox provider-specific tool: generic" in tool.description + assert "current sandbox uses provider 'generic'" in tool.description + + +def test_register_sandbox_provider_does_not_duplicate_provider_tool_description( + monkeypatch, tmp_path +): + from astrbot.core.agent.tool import FunctionTool + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + tool = FunctionTool( + name="generic_tool", + parameters={"type": "object", "properties": {}}, + description=( + "[Sandbox provider-specific tool: generic] This tool only works when " + "the current sandbox uses provider 'generic'. generic" + ), + ) + + computer_client.register_sandbox_provider(FakeProvider(), tools=[tool]) + + assert tool.description.count("Sandbox provider-specific tool: generic") == 1 + + +@pytest.mark.asyncio +async def test_get_booter_prefers_current_sandbox_over_configured_provider( + monkeypatch, tmp_path +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + generic = FakeProvider() + other = OtherFakeProvider() + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={generic.provider_id: generic, other.provider_id: other}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + current = await manager.create_sandbox(None, "session-a", "other") + + booter = await computer_client.get_booter(FakeContext(), "session-a") + + assert isinstance(booter, OtherFakeBooter) + assert manager.registry.get_current_sandbox_id("session-a") == current["sandbox_id"] + assert len(generic.created) == 0 + + +@pytest.mark.asyncio +async def test_get_booter_renews_current_sandbox_lease(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + generic = FakeProvider() + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={generic.provider_id: generic}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + current = await manager.create_sandbox(None, "session-a", "generic") + manager.registry._payload["sandboxes"][current["sandbox_id"]][ + "lease_expires_at" + ] = time.time() - 1 + + await computer_client.get_booter(FakeContext(), "session-a") + + renewed = manager.registry.get_sandbox(current["sandbox_id"]) + assert renewed["controller_session_id"] == "session-a" + assert renewed["lease_expires_at"] > time.time() + + +def test_computer_client_does_not_expose_legacy_session_cache(): + from astrbot.core.computer import computer_client + + assert not hasattr(computer_client, "session_booter") + + +def test_sandbox_tool_formats_timestamps_for_agent(): + from astrbot.core.tools.computer_tools.sandbox import _format_sandbox_for_agent + + payload = _format_sandbox_for_agent( + { + "lease_expires_at": 1778557598.4646258, + "idle_cleanup_at": None, + "expires_at": 1778559999, + "nested": [{"last_used_at": 1778550000}], + } + ) + + assert payload["lease_expires_at"] != 1778557598.4646258 + assert payload["lease_expires_at"] + assert payload["idle_cleanup_at"] is None + assert payload["expires_at"] + assert payload["nested"][0]["last_used_at"] + + +@pytest.mark.asyncio +async def test_sync_skills_uses_active_manager_booters(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + synced = [] + + async def fake_sync(booter, provider_id=None): + assert provider_id is None or provider_id == "generic" + synced.append(booter) + + manager_booter = FakeBooter() + manager.session_booter["generic-1"] = manager_booter + monkeypatch.setattr(computer_client, "_sync_skills_to_sandbox", fake_sync) + + await computer_client.sync_skills_to_active_sandboxes() + + assert synced == [manager_booter] + + +def test_register_provider_rejects_duplicate_by_default(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + computer_client.register_sandbox_provider(FakeProvider()) + + with pytest.raises(RuntimeError, match="already registered"): + computer_client.register_sandbox_provider(FakeProvider()) + + +def test_computer_client_does_not_load_registry_on_import(monkeypatch): + import importlib + + import astrbot.core.computer.sandbox_registry as sandbox_registry_module + + loads = [] + original_class = sandbox_registry_module.SandboxRegistry + + class TrackingSandboxRegistry(original_class): + def load(self): + loads.append(self.storage_path) + + monkeypatch.setattr( + sandbox_registry_module, "SandboxRegistry", TrackingSandboxRegistry + ) + + import astrbot.core.computer.computer_client as computer_client_module + + importlib.reload(computer_client_module) + + assert loads == [] + + +def test_register_provider_can_replace_when_requested(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + replacement = FakeProvider() + replacement.capabilities = {"keyboard", "mouse"} + + computer_client.register_sandbox_provider(FakeProvider()) + computer_client.register_sandbox_provider(replacement, replace=True) + + assert computer_client.get_sandbox_provider_info("generic") == { + "provider_id": "generic", + "capabilities": ["keyboard", "mouse"], + "tool_names": ["generic_tool"], + "system_prompt": "Use provider-specific sandbox rules.", + } + + +def test_unregister_provider_removes_registered_provider_tools(monkeypatch, tmp_path): + from astrbot.core.agent.tool import FunctionTool + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + FunctionToolExecutor.clear_runtime_computer_tools_cache() + tool = FunctionTool( + name="generic_tool_unregister_once", + parameters={"type": "object", "properties": {}}, + description="generic", + ) + + computer_client.register_sandbox_provider(FakeProvider(), tools=[tool]) + + assert computer_client.llm_tools.get_func("generic_tool_unregister_once") is tool + + computer_client.unregister_sandbox_provider("generic") + + assert computer_client.llm_tools.get_func("generic_tool_unregister_once") is None + + +def test_current_sandbox_provider_ignores_terminal_status(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + monkeypatch.setattr(computer_client, "sandbox_registry", manager.registry) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Terminal", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={}, + status="error", + ) + manager.registry.set_current_sandbox_id("session-a", "generic-1") + + assert computer_client.get_current_sandbox_provider_id("session-a") is None + + +def test_unregister_provider_rejects_active_managed_sandboxes(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + computer_client.register_sandbox_provider(FakeProvider()) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic 1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={}, + ) + + with pytest.raises(RuntimeError, match="active managed sandboxes"): + computer_client.unregister_sandbox_provider("generic") + + +def test_unregister_provider_allows_force(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + computer_client.register_sandbox_provider(FakeProvider()) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic 1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={}, + ) + + computer_client.unregister_sandbox_provider("generic", force=True) + + assert computer_client.get_sandbox_provider_info("generic") is None + + +def test_unregister_provider_force_preserves_persistent_sandboxes( + monkeypatch, tmp_path +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + computer_client.register_sandbox_provider(FakeProvider()) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic 1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={}, + retention_policy="persistent", + status="running", + ) + manager.session_booter["generic-1"] = object() + + computer_client.unregister_sandbox_provider("generic", force=True) + + record = manager.registry.get_sandbox("generic-1") + assert record is not None + assert record["retention_policy"] == "persistent" + assert computer_client.get_sandbox_provider_info("generic") is None + assert "generic-1" not in manager.session_booter + + +@pytest.mark.asyncio +async def test_unregister_provider_force_closes_persistent_booters( + monkeypatch, tmp_path +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + computer_client.register_sandbox_provider(FakeProvider()) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic 1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={}, + retention_policy="persistent", + status="running", + ) + persistent_booter = FakeBooter() + manager.session_booter["generic-1"] = persistent_booter + + computer_client.unregister_sandbox_provider("generic", force=True) + await asyncio.sleep(0) + + record = manager.registry.get_sandbox("generic-1") + assert record is not None + assert record["retention_policy"] == "persistent" + assert "generic-1" not in manager.session_booter + assert persistent_booter.shutdown_calls == 1 + + +def test_list_sandbox_providers_is_sorted(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + computer_client.register_sandbox_provider(OtherFakeProvider()) + computer_client.register_sandbox_provider(FakeProvider()) + + assert computer_client.list_sandbox_providers() == [ + { + "provider_id": "generic", + "capabilities": ["shell"], + "tool_names": ["generic_tool"], + "system_prompt": "Use provider-specific sandbox rules.", + }, + { + "provider_id": "other", + "capabilities": ["filesystem", "python"], + "tool_names": ["generic_tool"], + "system_prompt": "Use provider-specific sandbox rules.", + }, + ] + + +@pytest.mark.asyncio +async def test_cleanup_registered_sandbox_manager(monkeypatch, tmp_path): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeProvider() + destroyed = [] + + async def fake_destroy_booter(booter, record): + destroyed.append(record["sandbox_id"]) + await booter.shutdown() + + provider.destroy_booter = fake_destroy_booter + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={provider.provider_id: provider}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + await manager.create_sandbox(None, "session-a", "generic") + await computer_client.cleanup_managed_sandboxes() + + assert manager.list_sandboxes() == [] + + +@pytest.mark.asyncio +async def test_cleanup_sandbox_provider_destroys_temporary_and_preserves_persistent_records( + monkeypatch, tmp_path +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeProvider() + destroyed = [] + + async def fake_destroy_booter(booter, record): + destroyed.append(record["sandbox_id"]) + await booter.shutdown() + + provider.destroy_booter = fake_destroy_booter + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={provider.provider_id: provider}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + temporary = manager.registry.upsert_sandbox( + sandbox_id="generic-temp", + sandbox_name="Temp", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={}, + retention_policy="temporary", + status="running", + ) + persistent = manager.registry.upsert_sandbox( + sandbox_id="generic-persistent", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-b", + owner_session_id="session-b", + connect_info={}, + retention_policy="persistent", + status="running", + ) + temp_booter = FakeBooter() + temp_booter.provider_id = provider.provider_id + persistent_booter = FakeBooter() + persistent_booter.provider_id = provider.provider_id + manager.session_booter[temporary["sandbox_id"]] = temp_booter + manager.session_booter[persistent["sandbox_id"]] = persistent_booter + + await computer_client.cleanup_sandbox_provider("generic") + + assert manager.registry.get_sandbox("generic-temp") is None + assert manager.registry.get_sandbox("generic-persistent") is not None + assert destroyed == ["generic-temp"] + assert persistent_booter.shutdown_calls == 1 + + +@pytest.mark.asyncio +async def test_cleanup_sandbox_provider_preserves_temporary_record_when_destroy_fails( + monkeypatch, tmp_path +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeProvider() + + async def fake_destroy_booter(booter, record): + raise RuntimeError("destroy failed") + + provider.destroy_booter = fake_destroy_booter + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={provider.provider_id: provider}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + manager.registry.upsert_sandbox( + sandbox_id="generic-temp", + sandbox_name="Temp", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={}, + retention_policy="temporary", + status="running", + ) + booter = FakeBooter() + booter.provider_id = provider.provider_id + manager.session_booter["generic-temp"] = booter + + await computer_client.cleanup_sandbox_provider("generic") + + record = manager.registry.get_sandbox("generic-temp") + assert record is not None + assert record["status"] == "error" + assert manager.session_booter["generic-temp"] is booter + + +@pytest.mark.asyncio +async def test_cleanup_sandbox_provider_cleans_live_booter_without_registry_record( + monkeypatch, tmp_path +): + from astrbot.core.computer import computer_client + from astrbot.core.computer.sandbox_manager import SandboxManager + from astrbot.core.computer.sandbox_registry import SandboxRegistry + + provider = FakeProvider() + destroyed = [] + + async def fake_destroy_booter(booter, record): + destroyed.append(record["sandbox_id"]) + await booter.shutdown() + + provider.destroy_booter = fake_destroy_booter + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={provider.provider_id: provider}, + ) + monkeypatch.setattr(computer_client, "sandbox_manager", manager) + + booter = FakeBooter() + booter.provider_id = provider.provider_id + manager.session_booter["generic-orphan"] = booter + + await computer_client.cleanup_sandbox_provider("generic") + + assert destroyed == ["generic-orphan"] + assert manager.session_booter == {} + + +@pytest.mark.asyncio +async def test_core_lifecycle_stop_cleans_up_temporary_managed_sandboxes(monkeypatch): + from types import SimpleNamespace + + from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + + cleaned = [] + + async def fake_cleanup_managed_sandboxes(): + cleaned.append("called") + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.cleanup_managed_sandboxes", + fake_cleanup_managed_sandboxes, + ) + + lifecycle = object.__new__(AstrBotCoreLifecycle) + lifecycle.temp_dir_cleaner = None + lifecycle.curr_tasks = [] + lifecycle.cron_manager = None + lifecycle.provider_manager = SimpleNamespace(terminate=lambda: None) + lifecycle.platform_manager = SimpleNamespace(terminate=lambda: None) + lifecycle.kb_manager = SimpleNamespace(terminate=lambda: None) + lifecycle.dashboard_shutdown_event = SimpleNamespace(set=lambda: None) + lifecycle.plugin_manager = SimpleNamespace( + context=SimpleNamespace(get_all_stars=lambda: []), + _terminate_plugin=lambda plugin: None, + ) + + async def provider_terminate(): + return None + + async def platform_terminate(): + return None + + async def kb_terminate(): + return None + + async def terminate_plugin(plugin): + return None + + lifecycle.provider_manager.terminate = provider_terminate + lifecycle.platform_manager.terminate = platform_terminate + lifecycle.kb_manager.terminate = kb_terminate + lifecycle.plugin_manager._terminate_plugin = terminate_plugin + lifecycle._persistent_restore_task = None + + await AstrBotCoreLifecycle.stop(lifecycle) + + assert cleaned == ["called"] diff --git a/tests/unit/test_sandbox_manager.py b/tests/unit/test_sandbox_manager.py new file mode 100644 index 0000000000..c7c8ca2baf --- /dev/null +++ b/tests/unit/test_sandbox_manager.py @@ -0,0 +1,2688 @@ +import asyncio +import time + +import pytest + +from astrbot.core.computer.sandbox_manager import SandboxManager +from astrbot.core.computer.sandbox_registry import SandboxRegistry + + +class FakeBooter: + def __init__(self): + self.shutdown_calls = 0 + self.available_result = True + + async def available(self): + return self.available_result + + async def shutdown(self): + self.shutdown_calls += 1 + + +class SyncAvailableBooter: + def available(self): + return True + + +class BoolAvailableBooter: + available = True + + +class BaseDefaultAvailableBooter: + pass + + +class NoneAvailableBooter: + async def available(self): + return None + + +class UnavailablePropertyBooter: + available = False + + +class FakeProvider: + provider_id = "generic" + capabilities = {"shell", "python", "filesystem", "screenshot", "mouse", "keyboard"} + tool_names = {"astrbot_generic_screenshot"} + supports_persistent_reconnect = True + + def __init__(self): + self.created = [] + self.destroyed = [] + + def build_create_config(self, context, session_id): + return {"session_id": session_id} + + def build_connect_info(self, sandbox_name, config): + return {"name": sandbox_name, **config} + + def update_connect_info(self, record, *, sandbox_name): + info = dict(record.get("connect_info") or {}) + info["name"] = sandbox_name + return info + + async def create_booter(self, context, session_id, sandbox_id, config): + booter = FakeBooter() + self.created.append((session_id, sandbox_id, booter, config)) + return booter + + async def destroy_booter(self, booter, record): + self.destroyed.append((booter, record["sandbox_id"])) + await booter.shutdown() + + +class FakeContext: + def __init__(self, sandbox_config=None): + self._sandbox_config = sandbox_config or {} + + def get_config(self, umo): + return {"provider_settings": {"sandbox": dict(self._sandbox_config)}} + + +class RecordCapturingProvider(FakeProvider): + def __init__(self): + super().__init__() + self.boot_started = asyncio.Event() + self.allow_boot = asyncio.Event() + self.destroyed_records = [] + + async def create_booter(self, context, session_id, sandbox_id, config): + self.boot_started.set() + await self.allow_boot.wait() + return await super().create_booter(context, session_id, sandbox_id, config) + + async def destroy_booter(self, booter, record): + self.destroyed_records.append(dict(record)) + await super().destroy_booter(booter, record) + + +class SaveFailingProvider(FakeProvider): + def __init__(self): + super().__init__() + self.destroyed = [] + + +class OtherFakeProvider(FakeProvider): + provider_id = "other" + + +class BlockingDestroyProvider(FakeProvider): + def __init__(self): + super().__init__() + self.destroy_started = asyncio.Event() + self.allow_destroy = asyncio.Event() + + async def destroy_booter(self, booter, record): + self.destroy_started.set() + await self.allow_destroy.wait() + return await super().destroy_booter(booter, record) + + +class ImmediateDestroyProvider(FakeProvider): + def __init__(self): + super().__init__() + self.destroy_started = asyncio.Event() + + async def destroy_booter(self, booter, record): + self.destroy_started.set() + await super().destroy_booter(booter, record) + + +class FailsOnceDestroyProvider(FakeProvider): + def __init__(self): + super().__init__() + self.destroy_attempts = 0 + + async def destroy_booter(self, booter, record): + self.destroy_attempts += 1 + if self.destroy_attempts == 1: + raise RuntimeError("transient destroy failure") + await super().destroy_booter(booter, record) + + +class SlowCreatedHookProvider(FakeProvider): + def __init__(self): + super().__init__() + self.hook_started = asyncio.Event() + self.allow_hook = asyncio.Event() + self.hook_calls = 0 + + async def on_sandbox_created(self, record): + self.hook_calls += 1 + self.hook_started.set() + await self.allow_hook.wait() + + +class DeferredBootProvider(FakeProvider): + def __init__(self): + super().__init__() + self.boot_started = asyncio.Event() + self.allow_boot = asyncio.Event() + self.raise_on_boot = False + + async def create_booter(self, context, session_id, sandbox_id, config): + self.boot_started.set() + await self.allow_boot.wait() + if self.raise_on_boot: + raise RuntimeError("boot failed") + return await super().create_booter(context, session_id, sandbox_id, config) + + +class SlowDeferredDestroyProvider(DeferredBootProvider): + def __init__(self): + super().__init__() + self.cancelled_during_boot = asyncio.Event() + self.destroy_started = asyncio.Event() + self.allow_destroy = asyncio.Event() + self.pause_destroy = False + + async def create_booter(self, context, session_id, sandbox_id, config): + self.boot_started.set() + try: + await self.allow_boot.wait() + except asyncio.CancelledError: + self.cancelled_during_boot.set() + await self.allow_boot.wait() + if self.raise_on_boot: + raise RuntimeError("boot failed") + return await FakeProvider.create_booter( + self, context, session_id, sandbox_id, config + ) + + async def destroy_booter(self, booter, record): + self.destroy_started.set() + if self.pause_destroy: + await self.allow_destroy.wait() + return await super().destroy_booter(booter, record) + + +class DeadIdleDestroyProvider(FakeProvider): + def __init__(self): + super().__init__() + self.destroy_started = asyncio.Event() + self.destroy_calls = 0 + + async def destroy_booter(self, booter, record): + self.destroy_calls += 1 + self.destroy_started.set() + booter.available_result = False + raise RuntimeError("half-closed booter") + + +class AlwaysFailingIdleDestroyProvider(FakeProvider): + def __init__(self): + super().__init__() + self.destroy_started = asyncio.Event() + self.destroy_calls = 0 + + async def destroy_booter(self, booter, record): + self.destroy_calls += 1 + self.destroy_started.set() + raise RuntimeError("destroy failed") + + +class AlwaysFailingDestroyProvider(FakeProvider): + async def destroy_booter(self, booter, record): + raise RuntimeError("destroy failed") + + +class FailingReconnectProvider(FakeProvider): + async def create_booter(self, context, session_id, sandbox_id, config): + raise RuntimeError("boot failed") + + +class AlwaysBusyManager(SandboxManager): + def acquire_lease(self, sandbox_id: str, session_id: str, *, ttl=None): + return False + + +class MissingPersistentProvider(FakeProvider): + async def check_persistent_sandbox_exists(self, record): + return False + + +class PruningMissingPersistentProvider(MissingPersistentProvider): + prune_missing_persistent_records = True + + +class ExistingPersistentProvider(FakeProvider): + async def check_persistent_sandbox_exists(self, record): + return True + + +class ContextCapturingProvider(FakeProvider): + def __init__(self): + super().__init__() + self.contexts = [] + + async def create_booter(self, context, session_id, sandbox_id, config): + self.contexts.append(context) + return await super().create_booter(context, session_id, sandbox_id, config) + + +class ConnectInfoAfterBootProvider(FakeProvider): + def update_connect_info_after_boot(self, record, booter): + info = dict(record.get("connect_info") or {}) + info["runtime_id"] = getattr(booter, "runtime_id", "runtime-1") + return info + + +def _manager(tmp_path, provider=None): + provider = provider or FakeProvider() + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={provider.provider_id: provider}, + ) + return manager, provider + + +async def wait_until(predicate, *, timeout: float = 1.0) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if predicate(): + return + await asyncio.sleep(0.01) + raise AssertionError("condition was not met before timeout") + + +def test_manager_list_sandboxes_preserves_persisted_tool_names_without_provider( + tmp_path, +): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Generic"}, + capabilities={"shell"}, + tool_names={"persisted_tool"}, + ) + manager.providers.clear() + + sandboxes = manager.list_sandboxes() + + assert sandboxes[0]["tool_names"] == ["persisted_tool"] + + +def test_manager_list_sandboxes_does_not_emit_legacy_booter_type(tmp_path): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Generic"}, + ) + + sandboxes = manager.list_sandboxes() + + assert sandboxes[0]["provider"] == "generic" + assert "booter_type" not in sandboxes[0] + + +def test_manager_list_sandboxes_migrates_legacy_booter_type_to_provider(tmp_path): + manager, _provider = _manager(tmp_path) + manager.registry._payload["sandboxes"]["legacy-1"] = { + "sandbox_id": "legacy-1", + "sandbox_name": "Legacy", + "booter_type": "generic", + "managed": True, + "created_by_astrbot": True, + "connect_info": {"name": "Legacy"}, + } + + sandboxes = manager.list_sandboxes() + + assert sandboxes[0]["provider"] == "generic" + assert "booter_type" not in sandboxes[0] + + +@pytest.mark.asyncio +async def test_manager_creates_default_sandbox_and_reuses_available_booter(tmp_path): + manager, provider = _manager(tmp_path) + + first = await manager.get_or_create_booter(None, "session-a", "generic") + second = await manager.get_or_create_booter(None, "session-a", "generic") + + assert first is second + assert len(provider.created) == 1 + sandboxes = manager.list_sandboxes() + assert len(sandboxes) == 1 + assert sandboxes[0]["provider"] == "generic" + assert sandboxes[0]["capabilities"] == sorted(provider.capabilities) + assert sandboxes[0]["tool_names"] == sorted(provider.tool_names) + + +@pytest.mark.asyncio +async def test_manager_get_or_create_booter_stops_after_repeated_lease_failures( + tmp_path, +): + provider = FakeProvider() + manager = AlwaysBusyManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={provider.provider_id: provider}, + ) + + with pytest.raises(RuntimeError, match="Could not acquire sandbox lease"): + await manager.get_or_create_booter(None, "session-a", "generic") + + assert len(manager.registry.list_sandboxes()) <= 4 + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_returns_authoritative_registry_state( + tmp_path, +): + manager, _provider = _manager(tmp_path) + + sandbox = await manager.create_sandbox_uncontrolled(None, "session-a", "generic") + + assert sandbox["status"] == "running" + assert sandbox == manager.registry.get_sandbox(sandbox["sandbox_id"]) + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_cleans_up_on_cancellation(tmp_path): + provider = DeferredBootProvider() + manager, _provider = _manager(tmp_path, provider) + + task = asyncio.create_task( + manager.create_sandbox_uncontrolled(None, "session-a", "generic", "Named") + ) + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + records = manager.registry.list_sandboxes() + assert len(records) == 1 + assert records[0]["status"] == "error" + assert manager.session_booter == {} + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_keeps_error_record_on_boot_failure( + tmp_path, +): + provider = FailingReconnectProvider() + manager, _provider = _manager(tmp_path, provider) + + with pytest.raises(RuntimeError, match="boot failed"): + await manager.create_sandbox_uncontrolled(None, "session-a", "generic", "Named") + + records = manager.registry.list_sandboxes() + assert len(records) == 1 + assert records[0]["sandbox_name"] == "Named" + assert records[0]["status"] == "error" + assert manager.session_booter == {} + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_rejects_duplicate_name(tmp_path): + manager, _provider = _manager(tmp_path) + + await manager.create_sandbox_uncontrolled(None, "session-a", "generic", "Named") + + with pytest.raises(RuntimeError, match="Sandbox name 'Named' already exists"): + await manager.create_sandbox_uncontrolled(None, "session-a", "generic", "Named") + + +@pytest.mark.asyncio +async def test_create_sandbox_respects_global_max_sandboxes(tmp_path): + manager, _provider = _manager(tmp_path) + context = FakeContext({"max_sandboxes": 1}) + + await manager.create_sandbox(context, "session-a", "generic") + + with pytest.raises(RuntimeError, match="Sandbox limit reached"): + await manager.create_sandbox(context, "session-b", "generic") + + +@pytest.mark.asyncio +async def test_get_or_create_booter_respects_global_max_sandboxes(tmp_path): + manager, _provider = _manager(tmp_path) + context = FakeContext({"max_sandboxes": 1}) + + await manager.get_or_create_booter(context, "session-a", "generic") + + with pytest.raises(RuntimeError, match="Sandbox limit reached"): + await manager.get_or_create_booter(context, "session-b", "generic") + + +@pytest.mark.asyncio +async def test_create_sandbox_uses_default_max_sandboxes_when_config_missing(tmp_path): + manager, _provider = _manager(tmp_path) + context = FakeContext({}) + for index in range(10): + await manager.create_sandbox(context, f"session-{index}", "generic") + + with pytest.raises(RuntimeError, match="Maximum managed sandboxes: 10"): + await manager.create_sandbox(context, "session-over-limit", "generic") + + +@pytest.mark.asyncio +async def test_create_sandbox_uses_default_max_sandboxes_when_config_invalid(tmp_path): + manager, _provider = _manager(tmp_path) + context = FakeContext({"max_sandboxes": "invalid"}) + for index in range(10): + await manager.create_sandbox(context, f"session-{index}", "generic") + + with pytest.raises(RuntimeError, match="Maximum managed sandboxes: 10"): + await manager.create_sandbox(context, "session-over-limit", "generic") + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_blank_name_falls_back_to_sandbox_id( + tmp_path, +): + manager, _provider = _manager(tmp_path) + + sandbox = await manager.create_sandbox_uncontrolled( + None, "session-a", "generic", " " + ) + + assert sandbox["sandbox_name"] == sandbox["sandbox_id"] + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_passes_sandbox_id_to_connect_info(tmp_path): + manager, _provider = _manager(tmp_path) + + sandbox = await manager.create_sandbox_uncontrolled( + None, "session-a", "generic", "Display Name" + ) + + assert sandbox["sandbox_name"] == "Display Name" + assert sandbox["connect_info"]["name"] == "Display Name" + assert sandbox["connect_info"]["sandbox_id"] == sandbox["sandbox_id"] + + +@pytest.mark.asyncio +async def test_create_sandbox_updates_connect_info_after_boot(tmp_path): + provider = ConnectInfoAfterBootProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled( + None, "session-a", "generic", "Display Name" + ) + + assert sandbox["connect_info"]["runtime_id"] == "runtime-1" + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_deferred_passes_sandbox_id_to_connect_info( + tmp_path, +): + provider = DeferredBootProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Display Name" + ) + + assert sandbox["sandbox_name"] == "Display Name" + assert sandbox["connect_info"]["name"] == "Display Name" + assert sandbox["connect_info"]["sandbox_id"] == sandbox["sandbox_id"] + + provider.allow_boot.set() + await asyncio.wait_for(manager.pending_boot_tasks[sandbox["sandbox_id"]], timeout=1) + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_deferred_returns_creating_then_running( + tmp_path, +): + provider = DeferredBootProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + assert sandbox["status"] == "creating" + assert manager.session_booter == {} + + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + assert manager.registry.get_sandbox(sandbox["sandbox_id"])["status"] == "creating" + + provider.allow_boot.set() + await wait_until( + lambda: ( + (record := manager.registry.get_sandbox(sandbox["sandbox_id"])) is not None + and record["status"] == "running" + ) + ) + + assert manager.registry.get_sandbox(sandbox["sandbox_id"])["status"] == "running" + assert sandbox["sandbox_id"] in manager.session_booter + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_deferred_tracks_pending_boot_task( + tmp_path, +): + provider = DeferredBootProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + task = manager.pending_boot_tasks.get(sandbox["sandbox_id"]) + assert task is not None + assert not task.done() + + provider.allow_boot.set() + await asyncio.wait_for(task, timeout=1) + assert sandbox["sandbox_id"] not in manager.pending_boot_tasks + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_deferred_delays_boot_past_next_loop_turn( + tmp_path, +): + provider = DeferredBootProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + assert sandbox["status"] == "creating" + assert not provider.boot_started.is_set() + await asyncio.sleep(0) + assert not provider.boot_started.is_set() + + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_deferred_rejects_duplicate_name(tmp_path): + provider = DeferredBootProvider() + manager, _provider = _manager(tmp_path, provider) + + await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + with pytest.raises(RuntimeError, match="Sandbox name 'Named' already exists"): + await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_deferred_keeps_error_record_on_boot_failure( + tmp_path, +): + provider = DeferredBootProvider() + provider.raise_on_boot = True + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + provider.allow_boot.set() + await wait_until( + lambda: ( + (record := manager.registry.get_sandbox(sandbox["sandbox_id"])) is not None + and record["status"] == "error" + ) + ) + + record = manager.registry.get_sandbox(sandbox["sandbox_id"]) + assert record is not None + assert record["status"] == "error" + assert sandbox["sandbox_id"] not in manager.session_booter + + +@pytest.mark.asyncio +async def test_create_sandbox_uncontrolled_deferred_uses_fresh_record_for_cleanup( + tmp_path, +): + provider = RecordCapturingProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + manager.registry.delete_sandbox(sandbox["sandbox_id"]) + provider.allow_boot.set() + + await wait_until(lambda: not manager.pending_boot_tasks) + + assert provider.destroyed_records == [{}] + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + assert sandbox["sandbox_id"] not in manager.session_booter + + +@pytest.mark.asyncio +async def test_manager_waits_for_current_creating_sandbox_instead_of_creating_another_one( + tmp_path, +): + class CountingDeferredBootProvider(DeferredBootProvider): + def __init__(self): + super().__init__() + self.create_calls = 0 + self.second_create_started = asyncio.Event() + + async def create_booter(self, context, session_id, sandbox_id, config): + self.create_calls += 1 + if self.create_calls == 2: + self.second_create_started.set() + return await super().create_booter(context, session_id, sandbox_id, config) + + provider = CountingDeferredBootProvider() + manager, _provider = _manager(tmp_path, provider) + + created = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + manager.registry.set_current_sandbox_id("session-a", created["sandbox_id"]) + + get_booter_task = asyncio.create_task( + manager.get_or_create_booter(None, "session-a", "generic") + ) + with pytest.raises(asyncio.TimeoutError): + await asyncio.wait_for(asyncio.shield(get_booter_task), timeout=0.1) + + assert len(manager.registry.list_sandboxes()) == 1 + assert provider.create_calls == 1 + assert created["sandbox_id"] in manager.pending_boot_tasks + + provider.allow_boot.set() + booter = await get_booter_task + + assert booter is manager.session_booter[created["sandbox_id"]] + assert len(provider.created) == 1 + + +@pytest.mark.asyncio +async def test_create_sandbox_rolls_back_when_lease_acquisition_fails(tmp_path): + provider = FakeProvider() + manager = AlwaysBusyManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={provider.provider_id: provider}, + ) + + with pytest.raises(RuntimeError, match="Sandbox .* is busy"): + await manager.create_sandbox(None, "session-a", "generic", "Named") + + assert manager.registry.list_sandboxes() == [] + assert manager.session_booter == {} + assert provider.destroyed and provider.destroyed[0][1].startswith("generic-") + + +@pytest.mark.asyncio +async def test_destroy_sandbox_cancels_deferred_boot_task(tmp_path): + provider = DeferredBootProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + + destroyed = await manager.destroy_sandbox("session-a", sandbox["sandbox_id"]) + + assert destroyed["sandbox_id"] == sandbox["sandbox_id"] + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + assert sandbox["sandbox_id"] not in manager.session_booter + assert sandbox["sandbox_id"] not in manager.pending_boot_tasks + + +@pytest.mark.asyncio +async def test_destroy_sandbox_waits_for_deferred_boot_lock_before_cleanup(tmp_path): + provider = SlowDeferredDestroyProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + destroy_task = asyncio.create_task( + manager.destroy_sandbox("session-a", sandbox["sandbox_id"]) + ) + + await asyncio.wait_for(provider.cancelled_during_boot.wait(), timeout=1) + assert not destroy_task.done() + + provider.allow_boot.set() + destroyed = await asyncio.wait_for(destroy_task, timeout=1) + + await asyncio.sleep(0.05) + + assert destroyed["sandbox_id"] == sandbox["sandbox_id"] + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + assert sandbox["sandbox_id"] not in manager.session_booter + assert sandbox["sandbox_id"] not in manager.pending_boot_tasks + + +@pytest.mark.asyncio +async def test_destroy_sandbox_waits_for_deferred_boot_lock_after_cancel_timeout( + tmp_path, +): + provider = SlowDeferredDestroyProvider() + manager, _provider = _manager(tmp_path, provider) + + sandbox = await manager.create_sandbox_uncontrolled_deferred( + None, "session-a", "generic", "Named" + ) + + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + destroy_task = asyncio.create_task( + manager.destroy_sandbox("session-a", sandbox["sandbox_id"]) + ) + + await asyncio.wait_for(provider.cancelled_during_boot.wait(), timeout=1) + await asyncio.sleep(1.1) + assert not destroy_task.done() + + provider.allow_boot.set() + destroyed = await asyncio.wait_for(destroy_task, timeout=1) + + assert destroyed["sandbox_id"] == sandbox["sandbox_id"] + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + assert sandbox["sandbox_id"] not in manager.session_booter + assert sandbox["sandbox_id"] not in manager.pending_boot_tasks + + +@pytest.mark.asyncio +async def test_destroy_persistent_sandbox_removes_record(tmp_path): + manager, provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + manager.update_sandbox_config( + created["sandbox_id"], + idle_timeout=None, + expires_at=None, + retention_policy="persistent", + ) + + destroyed = await manager.destroy_sandbox("session-a", created["sandbox_id"]) + + assert destroyed["sandbox_id"] == created["sandbox_id"] + assert manager.registry.get_sandbox(created["sandbox_id"]) is None + assert provider.destroyed[0][1] == created["sandbox_id"] + + +@pytest.mark.asyncio +async def test_destroy_sandbox_preserves_record_when_provider_destroy_fails(tmp_path): + provider = AlwaysFailingDestroyProvider() + manager, _provider = _manager(tmp_path, provider) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + + with pytest.raises(RuntimeError, match="destroy failed"): + await manager.destroy_sandbox("session-a", created["sandbox_id"]) + + record = manager.registry.get_sandbox(created["sandbox_id"]) + assert record is not None + assert record["status"] == "error" + assert created["sandbox_id"] in manager.session_booter + + +@pytest.mark.asyncio +async def test_destroy_sandbox_deferred_returns_stopping_before_background_delete( + tmp_path, +): + provider = BlockingDestroyProvider() + manager, _provider = _manager(tmp_path, provider) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + + destroyed = await asyncio.wait_for( + manager.destroy_sandbox_deferred("session-a", created["sandbox_id"]), + timeout=1, + ) + + assert destroyed["status"] == "stopping" + assert destroyed["sandbox_id"] == created["sandbox_id"] + await asyncio.wait_for(provider.destroy_started.wait(), timeout=1) + + record = manager.registry.get_sandbox(created["sandbox_id"]) + assert record is not None + assert record["status"] == "stopping" + + provider.allow_destroy.set() + await wait_until( + lambda: manager.registry.get_sandbox(created["sandbox_id"]) is None + ) + + assert manager.registry.get_sandbox(created["sandbox_id"]) is None + assert provider.destroyed[0][1] == created["sandbox_id"] + + +@pytest.mark.asyncio +async def test_destroy_sandbox_deferred_delays_cleanup_past_next_loop_turn(tmp_path): + provider = ImmediateDestroyProvider() + manager, _provider = _manager(tmp_path, provider) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + + destroyed = await manager.destroy_sandbox_deferred( + "session-a", created["sandbox_id"] + ) + + assert destroyed["status"] == "stopping" + assert not provider.destroy_started.is_set() + await asyncio.sleep(0) + assert not provider.destroy_started.is_set() + + await asyncio.wait_for(provider.destroy_started.wait(), timeout=1) + + +@pytest.mark.asyncio +async def test_destroy_sandbox_deferred_tracks_pending_destroy_task(tmp_path): + provider = BlockingDestroyProvider() + manager, _provider = _manager(tmp_path, provider) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + + destroyed = await manager.destroy_sandbox_deferred( + "session-a", created["sandbox_id"] + ) + + task = manager.pending_destroy_tasks.get(destroyed["sandbox_id"]) + assert task is not None + assert not task.done() + + provider.allow_destroy.set() + await asyncio.wait_for(task, timeout=1) + assert destroyed["sandbox_id"] not in manager.pending_destroy_tasks + + +@pytest.mark.asyncio +async def test_created_hook_second_call_does_not_wait_for_first_hook(tmp_path): + provider = SlowCreatedHookProvider() + manager, _provider = _manager(tmp_path, provider) + created = await manager.create_sandbox_uncontrolled( + None, "session-a", "generic", "Named" + ) + + first = asyncio.create_task( + manager._invoke_sandbox_created_hook(provider, created["sandbox_id"]) + ) + await asyncio.wait_for(provider.hook_started.wait(), timeout=1) + + second = asyncio.create_task( + manager._invoke_sandbox_created_hook(provider, created["sandbox_id"]) + ) + await asyncio.sleep(0) + + assert second.done() + assert not first.done() + assert provider.hook_calls == 1 + + provider.allow_hook.set() + await asyncio.wait_for(first, timeout=1) + + +@pytest.mark.asyncio +async def test_clear_runtime_state_keeps_held_boot_lock(tmp_path): + manager, _provider = _manager(tmp_path) + sandbox_id = "generic-1" + + async with manager._sandbox_boot_lock(sandbox_id): + held_lock = manager.boot_locks[sandbox_id] + manager.clear_runtime_state(sandbox_id) + + assert manager.boot_locks[sandbox_id] is held_lock + + +@pytest.mark.asyncio +async def test_takeover_sandbox_keeps_held_boot_lock_on_health_failure(tmp_path): + manager, _provider = _manager(tmp_path) + sandbox_id = "generic-1" + manager.registry.upsert_sandbox( + sandbox_id=sandbox_id, + sandbox_name="Sandbox", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Sandbox"}, + status="running", + ) + booter = FakeBooter() + booter.available_result = False + manager.session_booter[sandbox_id] = booter + + async with manager._sandbox_boot_lock(sandbox_id): + held_lock = manager.boot_locks[sandbox_id] + + with pytest.raises(RuntimeError, match="booter health check failed"): + await manager.takeover_sandbox("session-b", sandbox_id) + + assert manager.boot_locks[sandbox_id] is held_lock + + +@pytest.mark.asyncio +async def test_create_sandbox_sets_current_sandbox_after_lease(tmp_path): + manager, _provider = _manager(tmp_path) + + sandbox = await manager.create_sandbox(None, "session-a", "generic") + + assert sandbox["status"] == "running" + assert ( + manager.get_current_sandbox("session-a")["current_sandbox_id"] + == sandbox["sandbox_id"] + ) + + +@pytest.mark.asyncio +async def test_get_or_create_booter_rolls_back_on_registry_save_failure(tmp_path): + provider = FakeProvider() + manager, _provider = _manager(tmp_path, provider) + save_calls = 0 + + async def fail_save(): + nonlocal save_calls + save_calls += 1 + if save_calls > 1: + raise RuntimeError("disk full") + + manager.save_registry_async = fail_save + + with pytest.raises(RuntimeError, match="disk full"): + await manager.get_or_create_booter(None, "session-a", "generic") + + assert manager.session_booter == {} + assert provider.destroyed + assert manager.list_sandboxes() == [] + assert manager.get_current_sandbox("session-a")["current_sandbox_id"] is None + + +@pytest.mark.asyncio +async def test_manager_creates_new_sandbox_when_default_busy(tmp_path): + manager, provider = _manager(tmp_path) + + await manager.get_or_create_booter(None, "session-a", "generic") + await manager.get_or_create_booter(None, "session-b", "generic") + + assert len(provider.created) == 2 + assert len(manager.list_sandboxes()) == 2 + + +@pytest.mark.asyncio +async def test_manager_reuses_idle_provider_sandbox_when_default_busy(tmp_path): + manager, provider = _manager(tmp_path) + + default = await manager.create_sandbox(None, "session-a", "generic", "Default") + idle = await manager.create_sandbox(None, "session-b", "generic", "Reusable") + manager.release_current_sandbox("session-b", idle["sandbox_id"]) + + booter = await manager.get_or_create_booter(None, "session-c", "generic") + + current_id = manager.get_current_sandbox("session-c")["current_sandbox_id"] + assert current_id == idle["sandbox_id"] + assert booter is manager.session_booter[idle["sandbox_id"]] + assert len(provider.created) == 2 + assert ( + manager.registry.get_sandbox(default["sandbox_id"])["controller_session_id"] + == "session-a" + ) + + +def test_manager_treats_expired_lease_sandbox_as_idle(tmp_path): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Expired", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Expired"}, + status="running", + controller_user_id="session-a", + controller_session_id="session-a", + lease_expires_at=time.time() - 1, + ) + manager.session_booter["generic-1"] = FakeBooter() + + assert manager._find_idle_provider_sandbox_id("generic") == "generic-1" + + +def test_manager_list_sandboxes_releases_expired_lease(tmp_path): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Expired", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Expired"}, + status="running", + controller_user_id="session-a", + controller_session_id="session-a", + lease_expires_at=time.time() - 1, + ) + manager.registry.set_current_sandbox_id("session-a", "generic-1") + + listed = manager.list_sandboxes()[0] + persisted = manager.registry.get_sandbox("generic-1") + + assert listed["controller_session_id"] is None + assert listed["controller_user_id"] is None + assert listed["lease_expires_at"] is None + assert persisted["controller_session_id"] is None + assert manager.registry.get_current_sandbox_id("session-a") is None + + +def test_manager_get_current_clears_expired_current_binding(tmp_path): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Expired", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Expired"}, + status="running", + controller_user_id="session-a", + controller_session_id="session-a", + lease_expires_at=time.time() - 1, + ) + manager.registry.set_current_sandbox_id("session-a", "generic-1") + + current = manager.get_current_sandbox("session-a") + record = manager.registry.get_sandbox("generic-1") + + assert current == {"current_sandbox_id": None, "sandbox": None} + assert record["controller_session_id"] is None + assert record["controller_user_id"] is None + assert record["lease_expires_at"] is None + assert manager.registry.get_current_sandbox_id("session-a") is None + + +def test_manager_get_current_clears_binding_taken_by_other_session(tmp_path): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Taken", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Taken"}, + status="running", + controller_user_id="session-b", + controller_session_id="session-b", + lease_expires_at=time.time() + 60, + ) + manager.registry.set_current_sandbox_id("session-a", "generic-1") + + current = manager.get_current_sandbox("session-a") + record = manager.registry.get_sandbox("generic-1") + + assert current == {"current_sandbox_id": None, "sandbox": None} + assert record["controller_session_id"] == "session-b" + assert manager.registry.get_current_sandbox_id("session-a") is None + + +@pytest.mark.asyncio +async def test_manager_creates_new_sandbox_when_current_binding_is_busy(tmp_path): + manager, provider = _manager(tmp_path) + + first = await manager.get_or_create_booter(None, "session-a", "generic") + first_sandbox_id = manager.get_current_sandbox("session-a")["current_sandbox_id"] + manager.registry.set_current_sandbox_id("session-b", first_sandbox_id) + + second = await manager.get_or_create_booter(None, "session-b", "generic") + + assert second is not first + assert len(provider.created) == 2 + assert manager.get_current_sandbox("session-b")["current_sandbox_id"] != ( + first_sandbox_id + ) + first_record = manager.registry.get_sandbox(first_sandbox_id) + assert first_record["controller_session_id"] == "session-a" + assert first_record["lease_expires_at"] > time.time() + + +@pytest.mark.asyncio +async def test_get_or_create_booter_does_not_reacquire_expired_current_binding( + tmp_path, +): + manager, provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Expired") + expired_id = created["sandbox_id"] + manager.registry.upsert_sandbox( + sandbox_id=expired_id, + sandbox_name=created["sandbox_name"], + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info=created["connect_info"], + controller_user_id="session-a", + controller_session_id="session-a", + lease_expires_at=time.time() - 1, + ) + + booter = await manager.get_or_create_booter(None, "session-a", "generic") + current_id = manager.get_current_sandbox("session-a")["current_sandbox_id"] + + assert current_id != expired_id + assert booter is manager.session_booter[current_id] + assert manager.registry.get_sandbox(expired_id)["controller_session_id"] is None + assert len(provider.created) == 2 + + +@pytest.mark.asyncio +async def test_get_or_create_booter_does_not_reuse_stopping_current_sandbox(tmp_path): + manager, provider = _manager(tmp_path) + first = await manager.create_sandbox(None, "session-a", "generic", "Stopping") + first_id = first["sandbox_id"] + first_booter = manager.session_booter[first_id] + manager.registry.update_sandbox_status(first_id, "stopping") + + next_booter = await manager.get_or_create_booter(None, "session-a", "generic") + + assert next_booter is not first_booter + assert manager.get_current_sandbox("session-a")["current_sandbox_id"] != first_id + assert len(provider.created) == 2 + + +@pytest.mark.asyncio +async def test_get_or_create_booter_revives_persistent_unknown_default(tmp_path): + manager, provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-persistent", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent", "persistent_name": "persist-1"}, + status="unknown", + retention_policy="persistent", + is_default=True, + ) + manager.registry.set_default_sandbox_id("generic-persistent") + + await manager.get_or_create_booter(object(), "session-a", "generic") + + assert len(provider.created) == 1 + assert provider.created[0][3]["resume"] is True + assert provider.created[0][3]["persistent_name"] == "persist-1" + + +@pytest.mark.asyncio +async def test_get_or_create_booter_passes_persistent_host_port(tmp_path): + manager, provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-persistent", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={ + "name": "Persistent", + "persistent_name": "persist-1", + "host_port": 23456, + }, + status="unknown", + retention_policy="persistent", + is_default=True, + ) + manager.registry.set_default_sandbox_id("generic-persistent") + + await manager.get_or_create_booter(object(), "session-a", "generic") + + assert provider.created[0][3]["host_port"] == 23456 + + +@pytest.mark.asyncio +async def test_get_or_create_booter_defaults_to_temporary_retention(tmp_path): + manager, _provider = _manager(tmp_path) + + await manager.get_or_create_booter(None, "session-a", "generic") + + sandbox_id = manager.get_current_sandbox("session-a")["current_sandbox_id"] + record = manager.registry.get_sandbox(sandbox_id) + assert record["retention_policy"] == "temporary" + + +@pytest.mark.asyncio +async def test_manager_booter_available_accepts_sync_callable_and_property(tmp_path): + manager, _provider = _manager(tmp_path) + + assert await manager.booter_available(SyncAvailableBooter()) is True + assert await manager.booter_available(BoolAvailableBooter()) is True + assert await manager.booter_available(UnavailablePropertyBooter()) is False + + +@pytest.mark.asyncio +async def test_manager_switches_releases_takes_over_and_destroys(tmp_path): + manager, provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + + assert ( + manager.get_current_sandbox("session-a")["current_sandbox_id"] + == created["sandbox_id"] + ) + released = manager.release_current_sandbox("session-a") + assert released["controller_session_id"] is None + + switched = await manager.switch_current_sandbox_checked( + "session-a", created["sandbox_id"] + ) + assert switched["controller_session_id"] == "session-a" + + taken = await manager.takeover_sandbox("session-b", created["sandbox_id"]) + assert taken["controller_session_id"] == "session-b" + assert manager.get_current_sandbox("session-a")["current_sandbox_id"] is None + + destroyed = await manager.destroy_sandbox("session-b", created["sandbox_id"]) + assert destroyed["sandbox_id"] == created["sandbox_id"] + assert provider.destroyed[0][1] == created["sandbox_id"] + assert manager.list_sandboxes() == [] + + +@pytest.mark.asyncio +async def test_manager_switch_releases_previous_sandbox_owned_by_same_session(tmp_path): + manager, _provider = _manager(tmp_path) + first = await manager.create_sandbox(None, "session-a", "generic", "First") + second = await manager.create_sandbox(None, "session-a", "generic", "Second") + + switched = await manager.switch_current_sandbox_checked( + "session-a", second["sandbox_id"] + ) + + first_record = manager.registry.get_sandbox(first["sandbox_id"]) + second_record = manager.registry.get_sandbox(second["sandbox_id"]) + assert switched["sandbox_id"] == second["sandbox_id"] + assert first_record["controller_session_id"] is None + assert first_record["lease_expires_at"] is None + assert second_record["controller_session_id"] == "session-a" + assert second_record["lease_expires_at"] > time.time() + assert ( + manager.get_current_sandbox("session-a")["current_sandbox_id"] + == second["sandbox_id"] + ) + + +@pytest.mark.asyncio +async def test_manager_create_releases_previous_sandbox_owned_by_same_session(tmp_path): + manager, _provider = _manager(tmp_path) + + first = await manager.create_sandbox(None, "session-a", "generic", "First") + second = await manager.create_sandbox(None, "session-a", "generic", "Second") + + first_record = manager.registry.get_sandbox(first["sandbox_id"]) + second_record = manager.registry.get_sandbox(second["sandbox_id"]) + assert first_record["controller_session_id"] is None + assert first_record["lease_expires_at"] is None + assert second_record["controller_session_id"] == "session-a" + assert second_record["lease_expires_at"] > time.time() + assert ( + manager.get_current_sandbox("session-a")["current_sandbox_id"] + == second["sandbox_id"] + ) + + +@pytest.mark.asyncio +async def test_manager_get_or_create_releases_previous_cross_provider_sandbox(tmp_path): + first_provider = FakeProvider() + second_provider = OtherFakeProvider() + manager = SandboxManager( + registry=SandboxRegistry(tmp_path / "sandbox_registry.json"), + providers={ + first_provider.provider_id: first_provider, + second_provider.provider_id: second_provider, + }, + ) + + first = await manager.create_sandbox(None, "session-a", "generic", "First") + await manager.get_or_create_booter(None, "session-a", "other") + + first_record = manager.registry.get_sandbox(first["sandbox_id"]) + current_id = manager.get_current_sandbox("session-a")["current_sandbox_id"] + current_record = manager.registry.get_sandbox(current_id) + assert first_record["controller_session_id"] is None + assert first_record["lease_expires_at"] is None + assert current_record["provider"] == "other" + assert current_record["controller_session_id"] == "session-a" + + +@pytest.mark.asyncio +async def test_manager_takeover_releases_previous_sandbox_owned_by_same_session( + tmp_path, +): + manager, _provider = _manager(tmp_path) + + first = await manager.create_sandbox(None, "session-a", "generic", "First") + second = await manager.create_sandbox(None, "session-b", "generic", "Second") + + taken = await manager.takeover_sandbox("session-a", second["sandbox_id"]) + + first_record = manager.registry.get_sandbox(first["sandbox_id"]) + second_record = manager.registry.get_sandbox(second["sandbox_id"]) + assert taken["sandbox_id"] == second["sandbox_id"] + assert first_record["controller_session_id"] is None + assert first_record["lease_expires_at"] is None + assert second_record["controller_session_id"] == "session-a" + assert second_record["lease_expires_at"] > time.time() + + +@pytest.mark.asyncio +async def test_manager_takeover_uses_configured_lease_timeout(tmp_path): + manager, _provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + + taken = await manager.takeover_sandbox( + "session-b", + created["sandbox_id"], + context=FakeContext({"sandbox_lease_timeout": 12}), + ) + + assert taken["controller_session_id"] == "session-b" + assert taken["lease_expires_at"] > time.time() + 10 + assert taken["lease_expires_at"] < time.time() + 20 + + +@pytest.mark.asyncio +async def test_manager_force_releases_other_session_lease(tmp_path): + manager, _provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + taken = await manager.takeover_sandbox("session-b", created["sandbox_id"]) + manager.registry.set_current_sandbox_id("session-b", created["sandbox_id"]) + + released = manager.force_release_sandbox(created["sandbox_id"]) + + assert taken["controller_session_id"] == "session-b" + assert released["controller_session_id"] is None + assert released["lease_expires_at"] is None + assert manager.get_current_sandbox("session-b")["current_sandbox_id"] is None + + +@pytest.mark.asyncio +async def test_manager_renews_current_sandbox_lease_with_requested_ttl(tmp_path): + manager, _provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + + renewed = await manager.renew_current_sandbox_lease("session-a", ttl_seconds=7200) + + assert renewed["sandbox_id"] == created["sandbox_id"] + assert renewed["controller_session_id"] == "session-a" + assert renewed["lease_expires_at"] > time.time() + 7190 + + +@pytest.mark.asyncio +async def test_manager_same_session_lease_acquire_does_not_shorten_longer_lease( + tmp_path, +): + manager, _provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + renewed = await manager.renew_current_sandbox_lease("session-a", ttl_seconds=7200) + + switched = await manager.switch_current_sandbox_checked( + "session-a", + created["sandbox_id"], + context=FakeContext({"sandbox_lease_timeout": 12}), + ) + + assert switched["lease_expires_at"] == renewed["lease_expires_at"] + assert switched["last_used_at"] >= renewed["last_used_at"] + + +@pytest.mark.asyncio +async def test_manager_renew_current_sandbox_rejects_missing_current(tmp_path): + manager, _provider = _manager(tmp_path) + + with pytest.raises(RuntimeError, match="No current sandbox"): + await manager.renew_current_sandbox_lease("session-a") + + +@pytest.mark.asyncio +async def test_manager_renew_current_sandbox_rejects_expired_current(tmp_path): + manager, _provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + manager.registry.upsert_sandbox( + sandbox_id=created["sandbox_id"], + sandbox_name=created["sandbox_name"], + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info=created["connect_info"], + controller_user_id="session-a", + controller_session_id="session-a", + lease_expires_at=time.time() - 1, + ) + + with pytest.raises(RuntimeError, match="No current sandbox"): + await manager.renew_current_sandbox_lease("session-a") + + assert manager.registry.get_current_sandbox_id("session-a") is None + + +@pytest.mark.asyncio +async def test_manager_renew_current_sandbox_allows_zero_ttl_for_permanent_lease( + tmp_path, +): + manager, _provider = _manager(tmp_path) + await manager.create_sandbox(None, "session-a", "generic", "Named") + + renewed = await manager.renew_current_sandbox_lease("session-a", ttl_seconds=0) + + assert renewed["controller_session_id"] == "session-a" + assert renewed["lease_expires_at"] is None + + +@pytest.mark.asyncio +async def test_manager_renew_current_sandbox_rejects_non_finite_ttl(tmp_path): + manager, _provider = _manager(tmp_path) + await manager.create_sandbox(None, "session-a", "generic", "Named") + + with pytest.raises(RuntimeError, match="ttl_seconds must be finite"): + await manager.renew_current_sandbox_lease("session-a", ttl_seconds=float("inf")) + + +@pytest.mark.asyncio +async def test_manager_uses_configured_lease_timeout_for_new_sandboxes(tmp_path): + manager, _provider = _manager(tmp_path) + + created = await manager.create_sandbox( + FakeContext({"sandbox_lease_timeout": 12}), + "session-a", + "generic", + "Named", + ) + + assert created["lease_expires_at"] > time.time() + 10 + + +@pytest.mark.asyncio +async def test_manager_list_sandboxes_exposes_exact_cleanup_times(tmp_path): + manager, _provider = _manager(tmp_path) + + idle_created = await manager.create_sandbox( + FakeContext({"sandbox_idle_timeout": 30, "sandbox_ttl": 0}), + "session-a", + "generic", + "Idle", + ) + ttl_created = await manager.create_sandbox( + FakeContext({"sandbox_idle_timeout": 0, "sandbox_ttl": 120}), + "session-b", + "generic", + "TTL", + ) + listed = {item["sandbox_id"]: item for item in manager.list_sandboxes()} + + assert idle_created["expires_at"] is None + assert listed[idle_created["sandbox_id"]]["idle_cleanup_at"] is None + manager.release_current_sandbox("session-a", idle_created["sandbox_id"]) + idle_listed = {item["sandbox_id"]: item for item in manager.list_sandboxes()}[ + idle_created["sandbox_id"] + ] + assert idle_listed["idle_cleanup_at"] == pytest.approx( + idle_listed["last_used_at"] + idle_listed["idle_timeout"], abs=0.01 + ) + assert ttl_created["expires_at"] > time.time() + 110 + assert listed[ttl_created["sandbox_id"]]["idle_cleanup_at"] is None + + +@pytest.mark.asyncio +async def test_manager_ttl_cleanup_removes_temporary_sandbox_when_idle_cleanup_disabled( + tmp_path, +): + manager, provider = _manager(tmp_path) + + sandbox = await manager.create_sandbox( + FakeContext({"sandbox_idle_timeout": 0, "sandbox_ttl": 0.01}), + "session-a", + "generic", + "TTL", + ) + + await asyncio.sleep(0.05) + + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + assert provider.destroyed[0][1] == sandbox["sandbox_id"] + + +@pytest.mark.asyncio +async def test_manager_ttl_cleanup_uses_monotonic_delay_when_wall_clock_moves_back( + tmp_path, + monkeypatch, +): + manager, provider = _manager(tmp_path) + original_time = time.time + + sandbox = await manager.create_sandbox( + FakeContext({"sandbox_idle_timeout": 0, "sandbox_ttl": 0.02}), + "session-a", + "generic", + "TTL", + ) + monkeypatch.setattr(time, "time", lambda: original_time() - 3600) + + await asyncio.sleep(0.08) + + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + assert provider.destroyed[0][1] == sandbox["sandbox_id"] + + +@pytest.mark.asyncio +async def test_manager_renew_current_sandbox_rejects_non_running_sandbox(tmp_path): + manager, _provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + manager.session_booter.pop(created["sandbox_id"]) + manager.registry.update_sandbox_status(created["sandbox_id"], "error") + + with pytest.raises(RuntimeError, match="encountered an error"): + await manager.renew_current_sandbox_lease("session-a") + + assert ( + manager.get_current_sandbox("session-a")["current_sandbox_id"] + == created["sandbox_id"] + ) + + +@pytest.mark.asyncio +async def test_manager_renew_current_sandbox_rejects_unavailable_booter(tmp_path): + manager, _provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + booter = manager.session_booter[created["sandbox_id"]] + booter.available_result = False + + with pytest.raises(RuntimeError, match="is not running"): + await manager.renew_current_sandbox_lease("session-a") + + assert created["sandbox_id"] not in manager.session_booter + assert manager.registry.get_sandbox(created["sandbox_id"])["status"] == "unknown" + + +@pytest.mark.asyncio +async def test_manager_blocks_observer_booter_access_from_other_session(tmp_path): + manager, _provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + + with pytest.raises(RuntimeError, match="controlled by another session"): + await manager.get_observer_booter_by_id(created["sandbox_id"], "session-b") + + +@pytest.mark.asyncio +async def test_manager_revives_persistent_sandbox_for_observer_access(tmp_path): + manager, provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + booter = await manager.get_observer_booter_by_id( + "generic-1", "dashboard", require_lease=False, context=object() + ) + + assert isinstance(booter, FakeBooter) + assert len(provider.created) == 1 + assert manager.registry.get_sandbox("generic-1")["status"] == "running" + + +@pytest.mark.asyncio +async def test_manager_revives_persistent_sandbox_for_switch_access(tmp_path): + manager, provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + switched = await manager.switch_current_sandbox_checked( + "session-a", "generic-1", context=object() + ) + + assert switched["sandbox_id"] == "generic-1" + assert len(provider.created) == 1 + assert manager.registry.get_sandbox("generic-1")["status"] == "running" + + +def test_manager_current_sandbox_uses_current_provider_tool_names(tmp_path): + manager, provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + controller_user_id="session-a", + controller_session_id="session-a", + lease_expires_at=time.time() + 60, + tool_names=["stale_provider_screenshot"], + capabilities=["stale"], + ) + manager.registry.set_current_sandbox_id("session-a", "generic-1") + + current = manager.get_current_sandbox("session-a") + + assert current["sandbox"]["tool_names"] == sorted(provider.tool_names) + assert current["sandbox"]["capabilities"] == sorted(provider.capabilities) + + +@pytest.mark.asyncio +async def test_manager_does_not_revive_destroyed_persistent_sandbox(tmp_path): + manager, provider = _manager(tmp_path) + created = await manager.create_sandbox(None, "session-a", "generic", "Named") + manager.update_sandbox_config( + created["sandbox_id"], + idle_timeout=None, + expires_at=None, + retention_policy="persistent", + ) + + await manager.destroy_sandbox("session-a", created["sandbox_id"]) + + with pytest.raises(RuntimeError, match="not found"): + await manager.get_observer_booter_by_id( + created["sandbox_id"], + "dashboard", + require_lease=False, + context=object(), + ) + + assert len(provider.created) == 1 + + +@pytest.mark.asyncio +async def test_manager_does_not_revive_persistent_sandbox_without_provider_support( + tmp_path, +): + manager, provider = _manager(tmp_path) + provider.supports_persistent_reconnect = False + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + with pytest.raises(RuntimeError, match="Sandbox generic-1 is not running"): + await manager.get_observer_booter_by_id( + "generic-1", "dashboard", require_lease=False, context=object() + ) + + assert provider.created == [] + + +@pytest.mark.asyncio +async def test_manager_persistent_reconnect_failure_restores_previous_status(tmp_path): + provider = FailingReconnectProvider() + manager, _provider = _manager(tmp_path, provider) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + with pytest.raises(RuntimeError, match="boot failed"): + await manager.get_observer_booter_by_id( + "generic-1", "dashboard", require_lease=False, context=object() + ) + + assert manager.registry.get_sandbox("generic-1")["status"] == "running" + assert manager.session_booter == {} + + +@pytest.mark.asyncio +async def test_manager_marks_persistent_reconnect_as_restoring(tmp_path): + provider = RecordCapturingProvider() + manager, _provider = _manager(tmp_path, provider) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="unknown", + retention_policy="persistent", + ) + + task = asyncio.create_task( + manager.get_observer_booter_by_id( + "generic-1", "dashboard", require_lease=False, context=object() + ) + ) + await asyncio.wait_for(provider.boot_started.wait(), timeout=1) + + assert manager.registry.get_sandbox("generic-1")["status"] == "restoring" + + provider.allow_boot.set() + await task + assert manager.registry.get_sandbox("generic-1")["status"] == "running" + + +@pytest.mark.asyncio +async def test_manager_treats_unknown_available_state_as_unavailable(tmp_path): + manager, _provider = _manager(tmp_path) + assert await manager.booter_available(BaseDefaultAvailableBooter()) is False + assert await manager.booter_available(NoneAvailableBooter()) is False + + +@pytest.mark.asyncio +async def test_manager_persistent_health_failure_marks_unknown_for_retry(tmp_path): + manager, _provider = _manager(tmp_path) + booter = FakeBooter() + booter.available_result = False + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + manager.session_booter["generic-1"] = booter + + with pytest.raises(RuntimeError, match="booter health check failed"): + await manager.get_observer_booter_by_id( + "generic-1", "dashboard", require_lease=False, context=object() + ) + + assert manager.registry.get_sandbox("generic-1")["status"] == "unknown" + assert "generic-1" not in manager.session_booter + + +@pytest.mark.asyncio +async def test_manager_checked_list_marks_stale_persistent_booter_unknown(tmp_path): + manager, _provider = _manager(tmp_path) + booter = FakeBooter() + booter.available_result = False + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + manager.session_booter["generic-1"] = booter + manager.boot_locks["generic-1"] = asyncio.Lock() + + listed = await manager.list_sandboxes_checked() + + assert listed[0]["status"] == "unknown" + assert manager.registry.get_sandbox("generic-1")["status"] == "unknown" + assert "generic-1" not in manager.session_booter + assert "generic-1" not in manager.boot_locks + + +@pytest.mark.asyncio +async def test_manager_revives_persistent_sandbox_for_tool_access(tmp_path): + manager, provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + booter = await manager.get_observer_booter_by_id( + "generic-1", "session-a", require_lease=False, context=object() + ) + + assert isinstance(booter, FakeBooter) + assert len(provider.created) == 1 + + +@pytest.mark.asyncio +async def test_manager_restores_persistent_sandboxes_on_startup(tmp_path): + provider = FakeProvider() + manager, provider = _manager(tmp_path, provider) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + await manager.restore_persistent_sandboxes(object()) + + assert "generic-1" in manager.session_booter + assert len(provider.created) == 1 + assert manager.registry.get_sandbox("generic-1")["status"] == "running" + assert "generic-1" not in manager.idle_state + + +@pytest.mark.asyncio +async def test_manager_reconcile_on_startup_removes_stale_persistent_records( + tmp_path, +): + provider = PruningMissingPersistentProvider() + manager, provider = _manager(tmp_path, provider) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + manager.registry.save() + await manager.reconcile_on_startup() + + assert manager.registry.get_sandbox("generic-1") is None + assert len(provider.created) == 0 + + +@pytest.mark.asyncio +async def test_manager_reconcile_on_startup_keeps_unconfirmed_persistent_records_by_default( + tmp_path, +): + provider = MissingPersistentProvider() + manager, provider = _manager(tmp_path, provider) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + manager.registry.save() + await manager.reconcile_on_startup() + + record = manager.registry.get_sandbox("generic-1") + assert record is not None + assert record["status"] == "unknown" + assert len(provider.created) == 0 + + +@pytest.mark.asyncio +async def test_manager_reconcile_on_startup_clears_all_runtime_state(tmp_path): + manager, _provider = _manager(tmp_path) + manager.session_booter["stale-1"] = FakeBooter() + manager._sandbox_boot_lock("stale-1") + manager.schedule_idle_cleanup("stale-1", 30) + manager.registry.save() + + await manager.reconcile_on_startup() + + assert manager.session_booter == {} + assert manager.idle_state == {} + assert manager.boot_locks == {} + + +@pytest.mark.asyncio +async def test_manager_reconcile_on_startup_waits_for_pending_destroy_tasks(tmp_path): + manager, provider = _manager(tmp_path, BlockingDestroyProvider()) + sandbox = await manager.create_sandbox(None, "session-a", "generic", "Named") + task = asyncio.create_task( + manager.destroy_sandbox_deferred("session-a", sandbox["sandbox_id"]) + ) + await asyncio.wait_for(provider.destroy_started.wait(), timeout=1) + provider.allow_destroy.set() + + await manager.reconcile_on_startup() + await task + + assert sandbox["sandbox_id"] not in manager.pending_destroy_tasks + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + + +@pytest.mark.asyncio +async def test_manager_reconcile_on_startup_keeps_persistent_records_for_missing_provider( + tmp_path, +): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="missing-1", + sandbox_name="Persistent", + provider="missing", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + manager.registry.save() + await manager.reconcile_on_startup() + + record = manager.registry.get_sandbox("missing-1") + assert record is not None + assert record["status"] == "unknown" + + +@pytest.mark.asyncio +async def test_manager_takeover_rejects_non_running_sandbox(tmp_path): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Broken", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Broken"}, + status="error", + ) + + with pytest.raises(RuntimeError, match="encountered an error"): + await manager.takeover_sandbox("session-b", "generic-1") + + record = manager.registry.get_sandbox("generic-1") + assert record["controller_session_id"] is None + + +@pytest.mark.asyncio +async def test_manager_takeover_revives_persistent_sandbox_with_context(tmp_path): + provider = ContextCapturingProvider() + manager, provider = _manager(tmp_path, provider) + context = FakeContext({}) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="unknown", + retention_policy="persistent", + ) + + await manager.takeover_sandbox("session-b", "generic-1", context=context) + + assert provider.contexts == [context] + assert "generic-1" in manager.session_booter + + +@pytest.mark.asyncio +async def test_manager_reconcile_on_startup_keeps_valid_persistent_records( + tmp_path, +): + provider = ExistingPersistentProvider() + manager, provider = _manager(tmp_path, provider) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + manager.registry.save() + await manager.reconcile_on_startup() + + record = manager.registry.get_sandbox("generic-1") + assert record is not None + assert record["status"] == "unknown" + assert len(provider.created) == 0 + + +@pytest.mark.asyncio +async def test_manager_restore_persistent_sandboxes_times_out_and_keeps_record( + tmp_path, +): + provider = FailingReconnectProvider() + manager, _provider = _manager(tmp_path, provider) + restore_started = asyncio.Event() + + async def slow_create_booter(context, session_id, sandbox_id, config): + restore_started.set() + await asyncio.sleep(1) + return await FakeProvider().create_booter( + context, session_id, sandbox_id, config + ) + + provider.create_booter = slow_create_booter + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + restored, timed_out = await manager.restore_persistent_sandboxes( + object(), per_sandbox_timeout=0.01 + ) + + assert restore_started.is_set() + assert restored == 0 + assert timed_out == 1 + record = manager.registry.get_sandbox("generic-1") + assert record is not None + assert record["status"] == "unknown" + + +@pytest.mark.asyncio +async def test_manager_restore_persistent_sandboxes_cancellation_restores_previous_status( + tmp_path, +): + provider = FailingReconnectProvider() + manager, _provider = _manager(tmp_path, provider) + restore_started = asyncio.Event() + + async def slow_create_booter(context, session_id, sandbox_id, config): + restore_started.set() + await asyncio.sleep(3600) + return await FakeProvider().create_booter( + context, session_id, sandbox_id, config + ) + + provider.create_booter = slow_create_booter + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + ) + + task = asyncio.create_task(manager.restore_persistent_sandboxes(object())) + await asyncio.wait_for(restore_started.wait(), timeout=1) + task.cancel() + + with pytest.raises(asyncio.CancelledError): + await task + + record = manager.registry.get_sandbox("generic-1") + assert record is not None + assert record["status"] == "running" + + +def test_manager_reconcile_on_startup_removes_temporary_records(tmp_path): + manager, _provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Temporary", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Temporary"}, + status="running", + retention_policy="temporary", + ) + + manager.registry.save() + manager.registry.reconcile_startup() + + assert manager.registry.get_sandbox("generic-1") is None + + +@pytest.mark.asyncio +async def test_manager_idle_cleanup_removes_temporary_sandbox(tmp_path): + manager, provider = _manager(tmp_path) + context = FakeContext({"sandbox_idle_timeout": 0.01, "sandbox_ttl": 0}) + + await manager.get_or_create_booter(context, "session-a", "generic") + sandbox_id = manager.list_sandboxes()[0]["sandbox_id"] + manager.release_current_sandbox("session-a", sandbox_id) + + await asyncio.sleep(0.05) + + assert manager.registry.get_sandbox(sandbox_id) is None + assert provider.destroyed[0][1] == sandbox_id + + +@pytest.mark.asyncio +async def test_manager_idle_cleanup_uses_scheduled_monotonic_deadline(tmp_path): + manager, provider = _manager(tmp_path) + context = FakeContext({"sandbox_idle_timeout": 0.01, "sandbox_ttl": 0}) + + sandbox = await manager.create_sandbox(context, "session-a", "generic", "Named") + sandbox_id = sandbox["sandbox_id"] + manager.release_current_sandbox("session-a", sandbox_id) + manager.registry._payload["sandboxes"][sandbox_id]["last_used_at"] = ( + time.time() + 3600 + ) + + await asyncio.sleep(0.05) + + assert manager.registry.get_sandbox(sandbox_id) is None + assert provider.destroyed[0][1] == sandbox_id + + +@pytest.mark.asyncio +async def test_manager_idle_cleanup_ignores_persistent_sandbox(tmp_path): + manager, provider = _manager(tmp_path) + context = FakeContext({"sandbox_idle_timeout": 0.01, "sandbox_ttl": 0}) + + sandbox = await manager.create_sandbox(context, "session-a", "generic", "Named") + sandbox_id = sandbox["sandbox_id"] + manager.update_sandbox_config( + sandbox_id, + idle_timeout=None, + expires_at=None, + retention_policy="persistent", + ) + manager.release_current_sandbox("session-a", sandbox_id) + + await asyncio.sleep(0.05) + + record = manager.registry.get_sandbox(sandbox_id) + assert record is not None + assert record["status"] == "running" + assert provider.destroyed == [] + + +@pytest.mark.asyncio +async def test_manager_stale_idle_cleanup_task_skips_persistent_sandbox(tmp_path): + manager, provider = _manager(tmp_path) + manager.registry.upsert_sandbox( + sandbox_id="persistent-1", + sandbox_name="Persistent", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Persistent"}, + status="running", + retention_policy="persistent", + idle_timeout=0.01, + ) + booter = FakeBooter() + manager.session_booter["persistent-1"] = booter + + manager.schedule_idle_cleanup("persistent-1", 0.01) + await asyncio.sleep(0.05) + + record = manager.registry.get_sandbox("persistent-1") + assert record is not None + assert record["status"] == "running" + assert manager.session_booter["persistent-1"] is booter + assert provider.destroyed == [] + + +@pytest.mark.asyncio +async def test_manager_idle_cleanup_does_not_retry_dead_booter(tmp_path): + provider = DeadIdleDestroyProvider() + manager, provider = _manager(tmp_path, provider) + context = FakeContext({"sandbox_idle_timeout": 0.01, "sandbox_ttl": 0}) + + sandbox = await manager.create_sandbox(context, "session-a", "generic", "Named") + manager.release_current_sandbox("session-a", sandbox["sandbox_id"]) + + await asyncio.wait_for(provider.destroy_started.wait(), timeout=1) + await asyncio.sleep(0.05) + + assert provider.destroy_calls == 1 + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + assert sandbox["sandbox_id"] not in manager.session_booter + assert sandbox["sandbox_id"] not in manager.idle_state + + +@pytest.mark.asyncio +async def test_manager_idle_cleanup_stops_retrying_after_max_attempts(tmp_path): + from astrbot.core.computer import sandbox_manager as sandbox_manager_module + + provider = AlwaysFailingIdleDestroyProvider() + manager, provider = _manager(tmp_path, provider) + context = FakeContext({"sandbox_idle_timeout": 0.01, "sandbox_ttl": 0}) + + sandbox = await manager.create_sandbox(context, "session-a", "generic", "Named") + manager.release_current_sandbox("session-a", sandbox["sandbox_id"]) + + await asyncio.wait_for(provider.destroy_started.wait(), timeout=1) + await wait_until( + lambda: sandbox["sandbox_id"] not in manager.idle_state, + timeout=1, + ) + + assert provider.destroy_calls == sandbox_manager_module.MAX_IDLE_DESTROY_ATTEMPTS + record = manager.registry.get_sandbox(sandbox["sandbox_id"]) + assert record is not None + assert record["status"] == "error" + assert sandbox["sandbox_id"] in manager.session_booter + assert sandbox["sandbox_id"] not in manager.idle_state + + +@pytest.mark.asyncio +async def test_manager_cleanup_waits_for_pending_destroy_tasks(tmp_path): + manager, provider = _manager(tmp_path, BlockingDestroyProvider()) + + sandbox = await manager.create_sandbox(None, "session-a", "generic", "Named") + task = asyncio.create_task( + manager.destroy_sandbox_deferred("session-a", sandbox["sandbox_id"]) + ) + await asyncio.wait_for(provider.destroy_started.wait(), timeout=1) + provider.allow_destroy.set() + + await manager.cleanup_managed_sandboxes() + + await task + assert sandbox["sandbox_id"] not in manager.pending_destroy_tasks + assert manager.registry.get_sandbox(sandbox["sandbox_id"]) is None + assert sandbox["sandbox_id"] not in manager.session_booter + assert provider.destroyed[0][1] == sandbox["sandbox_id"] + + +@pytest.mark.asyncio +async def test_manager_cleanup_destroys_temporary_sandboxes_and_keeps_persistent_records( + tmp_path, +): + manager, provider = _manager(tmp_path) + temporary = await manager.create_sandbox(None, "session-a", "generic") + persistent = await manager.create_sandbox(None, "session-b", "generic") + persistent_booter = manager.session_booter[persistent["sandbox_id"]] + manager.update_sandbox_config( + persistent["sandbox_id"], + idle_timeout=None, + expires_at=None, + retention_policy="persistent", + ) + + await manager.cleanup_managed_sandboxes() + + assert manager.registry.get_sandbox(temporary["sandbox_id"]) is None + assert manager.registry.get_sandbox(persistent["sandbox_id"])["status"] == "running" + assert len(provider.destroyed) == 1 + assert provider.destroyed[0][1] == temporary["sandbox_id"] + assert persistent_booter.shutdown_calls == 1 + + +@pytest.mark.asyncio +async def test_manager_cleanup_clears_persistent_runtime_memory_state(tmp_path): + manager, provider = _manager(tmp_path) + persistent = await manager.create_sandbox(None, "session-a", "generic") + persistent_booter = manager.session_booter[persistent["sandbox_id"]] + manager.update_sandbox_config( + persistent["sandbox_id"], + idle_timeout=None, + expires_at=None, + retention_policy="persistent", + ) + persistent_id = persistent["sandbox_id"] + manager._sandbox_boot_lock(persistent_id) + + await manager.cleanup_managed_sandboxes() + + assert manager.registry.get_sandbox(persistent_id) is not None + assert persistent_id not in manager.session_booter + assert persistent_id not in manager.idle_state + assert persistent_id not in manager.boot_locks + assert provider.destroyed == [] + assert persistent_booter.shutdown_calls == 1 + + +@pytest.mark.asyncio +async def test_ttl_cleanup_retries_transient_destroy_failure(monkeypatch, tmp_path): + from astrbot.core.computer import sandbox_manager as sandbox_manager_module + + manager, provider = _manager(tmp_path, FailsOnceDestroyProvider()) + monkeypatch.setattr( + sandbox_manager_module, "SANDBOX_TTL_DESTROY_RETRY_SECONDS", 0, raising=False + ) + sandbox = await manager.create_sandbox( + FakeContext({"sandbox_idle_timeout": 0, "sandbox_ttl": 60}), + "session-a", + "generic", + ) + sandbox_id = sandbox["sandbox_id"] + expired_at = time.time() - 1 + manager.registry.update_sandbox_config(sandbox_id, expires_at=expired_at) + manager.schedule_ttl_cleanup(sandbox_id, expired_at) + cleanup_task = manager.expiration_state[sandbox_id].task + + await asyncio.wait_for(cleanup_task, timeout=5) + + assert provider.destroy_attempts == 2 + assert manager.registry.get_sandbox(sandbox_id) is None + assert sandbox_id not in manager.session_booter + + +@pytest.mark.asyncio +async def test_destroy_persistent_sandbox_with_missing_provider_removes_stale_record( + tmp_path, +): + manager, _provider = _manager(tmp_path) + sandbox = await manager.create_sandbox(None, "session-a", "generic") + sandbox_id = sandbox["sandbox_id"] + manager.update_sandbox_config( + sandbox_id, + idle_timeout=None, + expires_at=None, + retention_policy="persistent", + ) + manager.providers.pop("generic") + + removed = await manager.destroy_sandbox_deferred("session-a", sandbox_id) + + assert removed["sandbox_id"] == sandbox_id + assert manager.registry.get_sandbox(sandbox_id) is None + assert sandbox_id not in manager.session_booter + + +def test_manager_update_sandbox_config_rejects_duplicate_name(tmp_path): + manager, _provider = _manager(tmp_path) + first = manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="First", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "First"}, + ) + second = manager.registry.upsert_sandbox( + sandbox_id="generic-2", + sandbox_name="Second", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Second"}, + ) + + with pytest.raises(RuntimeError, match="Sandbox name 'First' already exists"): + manager.update_sandbox_config( + second["sandbox_id"], + sandbox_name=first["sandbox_name"], + idle_timeout=None, + expires_at=None, + retention_policy="temporary", + ) + + +@pytest.mark.asyncio +async def test_manager_update_sandbox_config_clears_expires_at_when_idle_enabled( + tmp_path, +): + manager, _provider = _manager(tmp_path) + record = manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Generic"}, + idle_timeout=0, + expires_at=time.time() + 3600, + ) + + updated = manager.update_sandbox_config( + record["sandbox_id"], + idle_timeout=30, + expires_at=time.time() + 7200, + retention_policy="temporary", + ) + + assert updated["idle_timeout"] == 30 + assert updated["expires_at"] is None + assert record["sandbox_id"] in manager.idle_state + assert record["sandbox_id"] not in manager.expiration_state + + +def test_manager_update_sandbox_config_rejects_persistent_for_unsupported_provider( + tmp_path, +): + manager, provider = _manager(tmp_path) + provider.supports_persistent_reconnect = False + created = manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="First", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "First"}, + ) + + with pytest.raises( + RuntimeError, match="Provider generic does not support persistent sandboxes" + ): + manager.update_sandbox_config( + created["sandbox_id"], + sandbox_name=created["sandbox_name"], + idle_timeout=None, + expires_at=None, + retention_policy="persistent", + ) + + +def test_manager_update_sandbox_config_rejects_persistent_for_missing_provider( + tmp_path, +): + manager, _provider = _manager(tmp_path) + created = manager.registry.upsert_sandbox( + sandbox_id="missing-1", + sandbox_name="Missing", + provider="missing", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Missing"}, + ) + + with pytest.raises(RuntimeError, match="Provider missing is not available"): + manager.update_sandbox_config( + created["sandbox_id"], + sandbox_name=created["sandbox_name"], + idle_timeout=None, + expires_at=None, + retention_policy="persistent", + ) + + +@pytest.mark.asyncio +async def test_manager_set_sandbox_retention_policy_makes_sandbox_persistent(tmp_path): + manager, _provider = _manager(tmp_path) + record = manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Generic"}, + idle_timeout=30, + expires_at=time.time() + 3600, + ) + manager.schedule_idle_cleanup(record["sandbox_id"], 30) + + updated = manager.set_sandbox_retention_policy( + FakeContext({"sandbox_idle_timeout": 30, "sandbox_ttl": 120}), + "session-a", + record["sandbox_id"], + "persistent", + ) + + assert updated["retention_policy"] == "persistent" + assert updated["idle_timeout"] is None + assert updated["expires_at"] is None + assert record["sandbox_id"] not in manager.idle_state + assert record["sandbox_id"] not in manager.expiration_state + + +@pytest.mark.asyncio +async def test_manager_set_sandbox_retention_policy_makes_sandbox_temporary(tmp_path): + manager, _provider = _manager(tmp_path) + record = manager.registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Generic", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="session-a", + owner_session_id="session-a", + connect_info={"name": "Generic"}, + retention_policy="persistent", + ) + + updated = manager.set_sandbox_retention_policy( + FakeContext({"sandbox_idle_timeout": 30, "sandbox_ttl": 120}), + "session-a", + record["sandbox_id"], + "temporary", + ) + + assert updated["retention_policy"] == "temporary" + assert updated["idle_timeout"] == 30 + assert updated["expires_at"] is None + assert record["sandbox_id"] in manager.idle_state + + +def test_manager_save_registry_propagates_write_failures(tmp_path): + manager, _provider = _manager(tmp_path) + + def fail_save(): + raise OSError("disk full") + + manager.registry.save = fail_save + + with pytest.raises(OSError, match="disk full"): + manager.save_registry() + + +@pytest.mark.asyncio +async def test_manager_save_registry_async_propagates_write_failures(tmp_path): + manager, _provider = _manager(tmp_path) + + async def fail_save_async(): + raise OSError("disk full") + + manager.registry.save_async = fail_save_async + + with pytest.raises(OSError, match="disk full"): + await manager.save_registry_async() + + +@pytest.mark.asyncio +async def test_manager_observer_access_does_not_refresh_idle_timer_for_unclaimed_sandbox( + tmp_path, +): + provider = FakeProvider() + manager, _provider = _manager(tmp_path, provider) + created = await manager.create_sandbox( + FakeContext({"sandbox_idle_timeout": 30, "sandbox_ttl": 0}), + "session-a", + "generic", + ) + manager.release_current_sandbox("session-a", created["sandbox_id"]) + first_state = manager.idle_state[created["sandbox_id"]] + first_last_used_at = manager.registry.get_sandbox(created["sandbox_id"])[ + "last_used_at" + ] + + await manager.get_observer_booter_by_id( + created["sandbox_id"], "dashboard", require_lease=False + ) + + second_state = manager.idle_state[created["sandbox_id"]] + second_last_used_at = manager.registry.get_sandbox(created["sandbox_id"])[ + "last_used_at" + ] + assert second_state is first_state + assert second_last_used_at == first_last_used_at diff --git a/tests/unit/test_sandbox_models.py b/tests/unit/test_sandbox_models.py new file mode 100644 index 0000000000..1c99266e2a --- /dev/null +++ b/tests/unit/test_sandbox_models.py @@ -0,0 +1,98 @@ +import time + +import pytest + +from astrbot.core.computer.sandbox_models import SandboxRecord + + +def test_sandbox_record_round_trips_generic_capabilities_sorted(): + record = SandboxRecord.from_dict( + { + "sandbox_id": "sandbox-1", + "sandbox_name": "General Sandbox", + "provider": "generic-provider", + "managed": True, + "created_by_astrbot": True, + "capabilities": ["keyboard", "shell", "filesystem", "shell"], + } + ) + + assert record.capabilities == ["filesystem", "keyboard", "shell", "shell"] + payload = record.to_dict() + assert payload["sandbox_id"] == "sandbox-1" + assert payload["retention_policy"] == "temporary" + assert payload["status"] == "running" + assert payload["capabilities"] == ["filesystem", "keyboard", "shell", "shell"] + assert "booter_type" not in payload + + +def test_sandbox_record_aliases_owner_fields_to_created_by_fields(): + record = SandboxRecord.from_dict( + { + "sandbox_id": "sandbox-1", + "sandbox_name": "General Sandbox", + "provider": "generic-provider", + "managed": True, + "created_by_astrbot": True, + "owner_user_id": "legacy-user", + "owner_session_id": "legacy-session", + } + ) + + payload = record.to_dict() + + assert payload["created_by_user_id"] == "legacy-user" + assert payload["created_by_session_id"] == "legacy-session" + assert payload["owner_user_id"] == "legacy-user" + assert payload["owner_session_id"] == "legacy-session" + + +def test_sandbox_record_validates_required_strings(): + with pytest.raises(ValueError, match="sandbox_id"): + SandboxRecord.from_dict( + { + "sandbox_id": "", + "sandbox_name": "General Sandbox", + "provider": "generic-provider", + "managed": True, + "created_by_astrbot": True, + } + ) + + +def test_sandbox_record_reports_active_control_lease(): + now = time.time() + record = SandboxRecord.from_dict( + { + "sandbox_id": "sandbox-1", + "sandbox_name": "General Sandbox", + "provider": "generic-provider", + "managed": True, + "created_by_astrbot": True, + "controller_session_id": "session-a", + "lease_expires_at": now + 60, + } + ) + + assert record.has_active_lease(now=now) + assert record.is_controlled_by("session-a", now=now) + assert not record.is_controlled_by("session-b", now=now) + assert not record.has_active_lease(now=now + 120) + + +def test_sandbox_record_migrates_legacy_booter_type_to_provider(): + record = SandboxRecord.from_dict( + { + "sandbox_id": "sandbox-1", + "sandbox_name": "General Sandbox", + "booter_type": "legacy-provider", + "managed": True, + "created_by_astrbot": True, + } + ) + + payload = record.to_dict() + + assert record.provider == "legacy-provider" + assert payload["provider"] == "legacy-provider" + assert "booter_type" not in payload diff --git a/tests/unit/test_sandbox_provider.py b/tests/unit/test_sandbox_provider.py new file mode 100644 index 0000000000..c5ae108f1b --- /dev/null +++ b/tests/unit/test_sandbox_provider.py @@ -0,0 +1,24 @@ +from typing import Any, get_type_hints + +from astrbot.core.computer.sandbox_provider import SandboxProvider + + +def test_sandbox_provider_protocol_exposes_generic_runtime_contract(): + protocol_hints = get_type_hints(SandboxProvider) + assert protocol_hints["provider_id"] is str + assert protocol_hints["capabilities"] == set[str] + assert protocol_hints["tool_names"] == set[str] + assert protocol_hints["system_prompt"] is str + assert protocol_hints["plugin_config"] == dict[str, Any] | None + assert protocol_hints["provider_api_version"] is str + assert protocol_hints["auto_sync_skills"] is bool + assert protocol_hints["supports_persistent_reconnect"] is bool + + hints = get_type_hints(SandboxProvider.create_booter) + assert "context" in hints + assert hints["session_id"] is str + assert hints["sandbox_id"] is str + + +def test_sandbox_provider_protocol_has_no_default_existence_probe(): + assert "check_persistent_sandbox_exists" not in SandboxProvider.__dict__ diff --git a/tests/unit/test_sandbox_registry.py b/tests/unit/test_sandbox_registry.py new file mode 100644 index 0000000000..fdd8c3cf88 --- /dev/null +++ b/tests/unit/test_sandbox_registry.py @@ -0,0 +1,366 @@ +import asyncio +import json +import threading +from pathlib import Path + +import pytest + +from astrbot.core.computer.sandbox_registry import SandboxRegistry + + +def _registry(tmp_path): + return SandboxRegistry(tmp_path / "sandbox_registry.json") + + +def _upsert(registry, sandbox_id="generic-1", provider="generic"): + return registry.upsert_sandbox( + sandbox_id=sandbox_id, + sandbox_name=f"Sandbox {sandbox_id}", + provider=provider, + managed=True, + created_by_astrbot=True, + owner_user_id="user-1", + owner_session_id="session-1", + connect_info={"name": sandbox_id}, + capabilities={"shell", "python", "filesystem"}, + tool_names={"generic_tool"}, + ) + + +def test_registry_upserts_lists_and_deletes_sandboxes(tmp_path): + registry = _registry(tmp_path) + + record = _upsert(registry) + + assert record["sandbox_id"] == "generic-1" + assert record["capabilities"] == ["filesystem", "python", "shell"] + assert record["tool_names"] == ["generic_tool"] + assert registry.get_sandbox("generic-1")["sandbox_name"] == "Sandbox generic-1" + assert [item["sandbox_id"] for item in registry.list_sandboxes()] == ["generic-1"] + + registry.delete_sandbox("generic-1") + + assert registry.get_sandbox("generic-1") is None + assert registry.list_sandboxes() == [] + + +def test_registry_tracks_provider_defaults_and_current_session(tmp_path): + registry = _registry(tmp_path) + _upsert(registry, "generic-1", provider="generic") + _upsert(registry, "other-1", provider="other") + + registry.set_default_sandbox_id("generic-1") + registry.set_current_sandbox_id("session-a", "generic-1") + + assert registry.get_default_sandbox_id("generic") == "generic-1" + assert registry.get_default_sandbox_id("other") is None + assert registry.get_current_sandbox_id("session-a") == "generic-1" + + registry.delete_sandbox("generic-1") + + assert registry.get_current_sandbox_id("session-a") is None + + +def test_registry_acquires_releases_and_takes_over_leases(tmp_path): + registry = _registry(tmp_path) + _upsert(registry) + + assert registry.acquire_lease( + sandbox_id="generic-1", session_id="session-a", user_id="user-a", ttl=60, now=10 + ) + assert not registry.acquire_lease( + sandbox_id="generic-1", session_id="session-b", user_id="user-b", ttl=60, now=20 + ) + + released = registry.release_lease("generic-1") + assert released["controller_session_id"] is None + assert registry.acquire_lease( + sandbox_id="generic-1", session_id="session-b", user_id="user-b", ttl=60, now=20 + ) + + taken = registry.takeover_lease( + sandbox_id="generic-1", session_id="session-c", user_id="user-c", ttl=60, now=30 + ) + assert taken["controller_session_id"] == "session-c" + + +def test_registry_normalizes_sandbox_status_enum_values(tmp_path): + from astrbot.core.computer.sandbox_models import SandboxStatus + + registry = _registry(tmp_path) + _upsert(registry) + + updated = registry.update_sandbox_status("generic-1", SandboxStatus.UNKNOWN) + + assert updated["status"] == "unknown" + assert type(updated["status"]) is str + assert registry.get_sandbox("generic-1")["status"] == "unknown" + assert type(registry.get_sandbox("generic-1")["status"]) is str + + +def test_registry_saves_loads_and_reconciles_runtime_state(tmp_path): + registry = _registry(tmp_path) + _upsert(registry) + registry.acquire_lease( + sandbox_id="generic-1", session_id="session-a", user_id="user-a", ttl=60, now=10 + ) + registry.set_current_sandbox_id("session-a", "generic-1") + registry.save() + + loaded = _registry(tmp_path) + loaded.load() + assert loaded.get_sandbox("generic-1")["controller_session_id"] == "session-a" + + loaded.reconcile_startup() + + assert loaded.get_sandbox("generic-1") is None + assert loaded.get_current_sandbox_id("session-a") is None + + payload = json.loads((tmp_path / "sandbox_registry.json").read_text()) + assert payload["schema_version"] == 1 + assert "sandboxes" in payload + + +def test_registry_loads_legacy_payload_without_schema_version(tmp_path): + legacy_payload = { + "default_sandbox_id": None, + "default_sandbox_ids": {}, + "sandboxes": { + "generic-1": { + "sandbox_id": "generic-1", + "sandbox_name": "Legacy", + "provider": "generic", + "managed": True, + "created_by_astrbot": True, + "owner_user_id": "user-1", + "owner_session_id": "session-1", + "connect_info": {"name": "Legacy"}, + } + }, + "session_current": {}, + } + path = tmp_path / "sandbox_registry.json" + path.write_text(json.dumps(legacy_payload), encoding="utf-8") + + registry = _registry(tmp_path) + registry.load() + + assert registry._payload["schema_version"] == 1 + assert registry.get_sandbox("generic-1")["sandbox_name"] == "Legacy" + + +def test_registry_loads_non_object_payload_as_empty_registry(tmp_path): + path = tmp_path / "sandbox_registry.json" + path.write_text("[]", encoding="utf-8") + + registry = _registry(tmp_path) + registry.load() + + assert registry._payload["schema_version"] == 1 + assert registry.list_sandboxes() == [] + + +def test_registry_reconcile_startup_removes_non_persistent_creating_records( + tmp_path, +): + registry = _registry(tmp_path) + registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Sandbox generic-1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="user-1", + owner_session_id="session-1", + connect_info={"name": "generic-1"}, + status="creating", + ) + + registry.reconcile_startup() + + assert registry.get_sandbox("generic-1") is None + + +def test_registry_reconcile_startup_removes_temporary_records(tmp_path): + registry = _registry(tmp_path) + registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Sandbox generic-1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="user-1", + owner_session_id="session-1", + connect_info={"name": "generic-1"}, + status="running", + retention_policy="temporary", + ) + + registry.reconcile_startup() + + assert registry.get_sandbox("generic-1") is None + + +def test_registry_reconcile_startup_removes_non_persistent_restoring_records( + tmp_path, +): + registry = _registry(tmp_path) + registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Sandbox generic-1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="user-1", + owner_session_id="session-1", + connect_info={"name": "generic-1"}, + status="restoring", + ) + + registry.reconcile_startup() + + assert registry.get_sandbox("generic-1") is None + + +def test_registry_reconcile_startup_marks_persistent_running_unknown(tmp_path): + registry = _registry(tmp_path) + registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Sandbox generic-1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="user-1", + owner_session_id="session-1", + connect_info={"name": "generic-1"}, + status="running", + retention_policy="persistent", + ) + + registry.reconcile_startup() + + assert registry.get_sandbox("generic-1")["status"] == "unknown" + + +def test_registry_reconcile_startup_clears_stale_default_references(tmp_path): + registry = _registry(tmp_path) + registry._payload["default_sandbox_id"] = "missing-1" + registry._payload["default_sandbox_ids"] = {"generic": "missing-1"} + registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Sandbox generic-1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="user-1", + owner_session_id="session-1", + connect_info={"name": "generic-1"}, + retention_policy="persistent", + status="running", + ) + + registry.reconcile_startup() + + assert registry.default_sandbox_id is None + assert registry.get_default_sandbox_id("generic") is None + assert registry._payload["default_sandbox_ids"] == {} + + +def test_registry_reconcile_startup_clears_stale_is_default_flags(tmp_path): + registry = _registry(tmp_path) + record = registry.upsert_sandbox( + sandbox_id="generic-1", + sandbox_name="Sandbox generic-1", + provider="generic", + managed=True, + created_by_astrbot=True, + owner_user_id="user-1", + owner_session_id="session-1", + connect_info={"name": "generic-1"}, + retention_policy="persistent", + status="running", + ) + registry._payload["default_sandbox_id"] = "missing-1" + registry._payload["default_sandbox_ids"] = {} + registry._payload["sandboxes"][record["sandbox_id"]]["is_default"] = True + + registry.reconcile_startup() + + assert registry.get_sandbox("generic-1")["is_default"] is False + + +@pytest.mark.asyncio +async def test_registry_save_async_runs_save_in_worker_thread(tmp_path): + registry = _registry(tmp_path) + main_thread_id = threading.get_ident() + save_thread_id = None + + def fake_write_payload(payload): + nonlocal save_thread_id + assert payload == registry._payload + save_thread_id = threading.get_ident() + + registry._write_payload = fake_write_payload + + await registry.save_async() + + assert save_thread_id is not None + assert save_thread_id != main_thread_id + + +@pytest.mark.asyncio +async def test_registry_save_async_serializes_writes(tmp_path): + registry = _registry(tmp_path) + active_writes = 0 + max_active_writes = 0 + release_first_write = None + first_write_started = asyncio.Event() + + def fake_write_payload(payload): + nonlocal active_writes, max_active_writes, release_first_write + active_writes += 1 + max_active_writes = max(max_active_writes, active_writes) + if release_first_write is None: + release_first_write = threading.Event() + first_write_started.set() + assert release_first_write.wait(timeout=1) + active_writes -= 1 + + registry._write_payload = fake_write_payload + + first = asyncio.create_task(registry.save_async()) + await first_write_started.wait() + second = asyncio.create_task(registry.save_async()) + await asyncio.sleep(0.05) + + assert max_active_writes == 1 + release_first_write.set() + await first + await second + + +def test_registry_write_payload_replaces_temp_file_atomically(tmp_path, monkeypatch): + registry = _registry(tmp_path) + payload = { + "default_sandbox_id": None, + "default_sandbox_ids": {}, + "sandboxes": {"generic-1": {"sandbox_id": "generic-1"}}, + "session_current": {}, + } + replace_calls = [] + original_replace = Path.replace + + def track_replace(self, target): + replace_calls.append((Path(self), Path(target))) + return original_replace(self, target) + + monkeypatch.setattr(Path, "replace", track_replace) + + registry._write_payload(payload) + + assert replace_calls + source_path, target_path = replace_calls[0] + assert source_path != target_path + assert target_path == registry.storage_path + assert source_path.parent == target_path.parent + assert json.loads(registry.storage_path.read_text(encoding="utf-8")) == payload diff --git a/tests/unit/test_sandbox_timeouts.py b/tests/unit/test_sandbox_timeouts.py new file mode 100644 index 0000000000..971e99ffb2 --- /dev/null +++ b/tests/unit/test_sandbox_timeouts.py @@ -0,0 +1,74 @@ +from astrbot.core.computer.sandbox_timeouts import ( + DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS, + expires_at_from_timeout, + idle_cleanup_at_from_record, + lease_is_active, + resolve_sandbox_timeout, +) + + +def test_resolve_sandbox_timeout_prefers_generic_key_over_alias(): + config = {"sandbox_ttl": 0, "cua_ttl": 3600} + + resolved = resolve_sandbox_timeout( + config, + "sandbox_ttl", + aliases=("cua_ttl",), + default=3600, + ) + + assert resolved == 0 + + +def test_resolve_sandbox_timeout_uses_legacy_alias_when_generic_missing(): + config = {"shipyard_ttl": "120"} + + resolved = resolve_sandbox_timeout( + config, + "sandbox_ttl", + aliases=("shipyard_ttl",), + default=3600, + ) + + assert resolved == 120 + + +def test_resolve_sandbox_timeout_falls_back_for_invalid_values(): + assert ( + resolve_sandbox_timeout( + {"sandbox_lease_timeout": "forever"}, + "sandbox_lease_timeout", + default=DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS, + ) + == DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS + ) + + +def test_resolve_sandbox_timeout_rejects_boolean_values(): + assert ( + resolve_sandbox_timeout({"sandbox_ttl": True}, "sandbox_ttl", default=3600) + == 3600 + ) + + +def test_zero_lease_timeout_is_an_indefinite_active_lease(): + assert lease_is_active("session-a", None, now=100.0) is True + assert lease_is_active(None, None, now=100.0) is False + assert lease_is_active("session-a", 99.0, now=100.0) is False + assert lease_is_active("session-a", 101.0, now=100.0) is True + + +def test_expires_at_from_timeout_uses_absolute_time_and_hides_zero(): + assert expires_at_from_timeout(0, now=100.0) is None + assert expires_at_from_timeout(300, now=100.0) == 400.0 + + +def test_idle_cleanup_at_from_record_uses_last_used_time(): + assert ( + idle_cleanup_at_from_record(last_used_at=100.0, idle_timeout=0, now=100.0) + is None + ) + assert ( + idle_cleanup_at_from_record(last_used_at=100.0, idle_timeout=30, now=100.0) + == 130.0 + ) diff --git a/tests/unit/test_sandbox_tool_binding.py b/tests/unit/test_sandbox_tool_binding.py new file mode 100644 index 0000000000..a4b97461b5 --- /dev/null +++ b/tests/unit/test_sandbox_tool_binding.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass, field + +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.computer.sandbox_tool_binding import ( + get_sandbox_provider_tool_config_statuses, + sandbox_provider_tool, + tool_available_in_runtime, +) + + +class FakeTool: + def __init__( + self, name: str, active: bool = True, sandbox_provider_id: str | None = None + ): + self.name = name + self.active = active + self.sandbox_provider_id = sandbox_provider_id + + +def test_provider_scoped_tool_is_available_to_any_sandbox_runtime(): + tool = FakeTool("sandbox_tool", sandbox_provider_id="Generic") + + assert tool_available_in_runtime(tool, "sandbox") + assert not tool_available_in_runtime(tool, "local") + + +def test_unscoped_tool_is_available_to_every_runtime(): + tool = FakeTool("regular_tool") + + assert tool_available_in_runtime(tool, "sandbox") + assert tool_available_in_runtime(tool, "local") + + +def test_sandbox_provider_tool_marks_class_and_registers_config_rule(): + @sandbox_provider_tool( + "Example", + config={"provider_settings.sandbox.booter": "example"}, + ) + @dataclass + class ExampleTool(FunctionTool): + name: str = "example_sandbox_tool" + description: str = "Example" + parameters: dict = field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + + tool = ExampleTool() + + assert tool.sandbox_provider_id == "example" + + statuses = get_sandbox_provider_tool_config_statuses( + "example_sandbox_tool", + [ + { + "conf_id": "conf-a", + "conf_name": "Config A", + "config": {"provider_settings": {"sandbox": {"booter": "example"}}}, + } + ], + ) + assert statuses[0]["enabled"] is True diff --git a/tests/unit/test_sandbox_tool_consolidation.py b/tests/unit/test_sandbox_tool_consolidation.py new file mode 100644 index 0000000000..f80a9dabe9 --- /dev/null +++ b/tests/unit/test_sandbox_tool_consolidation.py @@ -0,0 +1,55 @@ +from astrbot.core.tools.computer_tools import sandbox as sandbox_tools + + +def _actions(tool_cls) -> set[str]: + return set(tool_cls().parameters["properties"]["action"]["enum"]) + + +def test_generic_sandbox_tools_are_grouped_by_intent(): + assert sandbox_tools.SandboxQueryTool().name == "astrbot_sandbox_query" + assert sandbox_tools.SandboxLifecycleTool().name == "astrbot_sandbox_lifecycle" + assert sandbox_tools.SandboxOperationTool().name == "astrbot_sandbox_operation" + + assert _actions(sandbox_tools.SandboxQueryTool) == { + "list_sandboxes", + "get_current", + "list_providers", + } + assert _actions(sandbox_tools.SandboxLifecycleTool) == { + "create", + "switch", + "release", + "renew_lease", + "set_retention", + "takeover", + "destroy", + } + assert _actions(sandbox_tools.SandboxOperationTool) == { + "capture_screenshot", + "copy_file", + } + operation_params = sandbox_tools.SandboxOperationTool().parameters["properties"] + assert "return_image_to_llm" in operation_params + assert ( + "copy_file requires source_sandbox_id" + in sandbox_tools.SandboxOperationTool().description + ) + + +def test_legacy_generic_sandbox_tools_are_not_registered(): + legacy_names = { + "ListSandboxesTool", + "ListSandboxProvidersTool", + "GetCurrentSandboxTool", + "CreateSandboxTool", + "SwitchSandboxTool", + "ReleaseSandboxTool", + "SetSandboxRetentionPolicyTool", + "KeepAliveSandboxTool", + "TakeoverSandboxTool", + "DestroySandboxTool", + "ScreenshotSandboxTool", + "CopyFileBetweenSandboxesTool", + } + + assert not any(hasattr(sandbox_tools, name) for name in legacy_names) diff --git a/tests/unit/test_sandbox_tools_permissions.py b/tests/unit/test_sandbox_tools_permissions.py new file mode 100644 index 0000000000..be9004a7f0 --- /dev/null +++ b/tests/unit/test_sandbox_tools_permissions.py @@ -0,0 +1,904 @@ +import json +import time +from pathlib import Path +from types import SimpleNamespace + +import mcp +import pytest + +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.tools.computer_tools.sandbox import ( + SandboxLifecycleTool, + SandboxOperationTool, + SandboxQueryTool, +) + + +class FakeEvent: + unified_msg_origin = "session-a" + role = "member" + + def get_sender_id(self): + return "user-a" + + +def _context(): + plugin_context = SimpleNamespace( + get_config=lambda umo=None: { + "provider_settings": {"computer_use_require_admin": True} + } + ) + return ContextWrapper( + context=SimpleNamespace(event=FakeEvent(), context=plugin_context) + ) + + +def _admin_context_without_admin_requirement(): + event = FakeEvent() + event.role = "admin" + plugin_context = SimpleNamespace( + get_config=lambda umo=None: { + "provider_settings": {"computer_use_require_admin": False} + } + ) + return ContextWrapper(context=SimpleNamespace(event=event, context=plugin_context)) + + +def _member_context_without_admin_requirement(): + plugin_context = SimpleNamespace( + get_config=lambda umo=None: { + "provider_settings": {"computer_use_require_admin": False} + } + ) + return ContextWrapper( + context=SimpleNamespace(event=FakeEvent(), context=plugin_context) + ) + + +def _member_context_with_sandbox_permissions(**permissions): + plugin_context = SimpleNamespace( + get_config=lambda umo=None: { + "provider_settings": { + "computer_use_require_admin": False, + "computer_use_runtime": "sandbox", + "sandbox": { + "booter": "generic", + "member_permissions": dict(permissions), + }, + } + } + ) + return ContextWrapper( + context=SimpleNamespace(event=FakeEvent(), context=plugin_context) + ) + + +def _sandbox_context(default_provider: str = "generic"): + plugin_context = SimpleNamespace( + get_config=lambda umo=None: { + "provider_settings": { + "computer_use_require_admin": False, + "computer_use_runtime": "sandbox", + "sandbox": {"booter": default_provider}, + } + } + ) + return ContextWrapper( + context=SimpleNamespace(event=FakeEvent(), context=plugin_context) + ) + + +@pytest.mark.asyncio +async def test_screenshot_sandbox_tool_requires_admin_permission(): + result = await SandboxOperationTool().call( + _context(), "capture_screenshot", sandbox_id="sandbox-1" + ) + + assert "Permission denied" in str(result) + + +@pytest.mark.asyncio +async def test_copy_file_between_sandboxes_tool_requires_admin_permission(): + result = await SandboxOperationTool().call( + _context(), + "copy_file", + source_sandbox_id="source-1", + source_path="/tmp/source.txt", + target_sandbox_id="target-1", + target_path="/tmp/target.txt", + ) + + assert "Permission denied" in str(result) + + +@pytest.mark.asyncio +async def test_copy_file_between_sandboxes_handles_windows_target_filename( + monkeypatch, tmp_path +): + from astrbot.core.tools.computer_tools import sandbox as sandbox_tools + + copied: dict[str, str] = {} + + class SourceBooter: + async def download_file(self, source_path, local_path): + copied["source_path"] = source_path + copied["local_path"] = local_path + Path(local_path).write_text("payload", encoding="utf-8") + + class TargetBooter: + async def upload_file(self, local_path, target_path): + copied["upload_local_path"] = local_path + copied["target_path"] = target_path + return {"ok": True} + + class Manager: + async def get_observer_booter_by_id(self, sandbox_id, *args, **kwargs): + return SourceBooter() if sandbox_id == "source-1" else TargetBooter() + + monkeypatch.setattr(sandbox_tools, "get_astrbot_temp_path", lambda: str(tmp_path)) + monkeypatch.setattr( + sandbox_tools.computer_client, + "sandbox_manager", + Manager(), + ) + + result = await SandboxOperationTool().call( + _admin_context_without_admin_requirement(), + "copy_file", + source_sandbox_id="source-1", + source_path="/tmp/source.txt", + target_sandbox_id="target-1", + target_path=r"C:\Users\AstrBot\target.txt", + ) + + assert json.loads(result)["upload_result"] == {"ok": True} + assert Path(copied["local_path"]).name.endswith("-target.txt") + assert copied["target_path"] == r"C:\Users\AstrBot\target.txt" + + +@pytest.mark.asyncio +async def test_copy_file_between_sandboxes_includes_lease_metadata( + monkeypatch, tmp_path +): + from astrbot.core.tools.computer_tools import sandbox as sandbox_tools + + expires_at = time.time() + 600 + + class SourceBooter: + async def download_file(self, source_path, local_path): + Path(local_path).write_text("payload", encoding="utf-8") + + class TargetBooter: + async def upload_file(self, local_path, target_path): + return {"ok": True} + + class Manager: + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "controller_session_id": "session-a", + "lease_expires_at": expires_at, + } + ) + + async def get_observer_booter_by_id(self, sandbox_id, *args, **kwargs): + return SourceBooter() if sandbox_id == "source-1" else TargetBooter() + + monkeypatch.setattr(sandbox_tools, "get_astrbot_temp_path", lambda: str(tmp_path)) + monkeypatch.setattr(sandbox_tools.computer_client, "sandbox_manager", Manager()) + + result = await SandboxOperationTool().call( + _sandbox_context(), + "copy_file", + source_sandbox_id="source-1", + source_path="/tmp/source.txt", + target_sandbox_id="target-1", + target_path="/tmp/target.txt", + ) + payload = json.loads(str(result)) + + assert payload["lease"]["sandbox_id"] == "target-1" + assert payload["lease"]["lease_expires_at"] + assert payload["lease"]["lease_expires_in_seconds"] > 0 + assert payload["lease"]["auto_renew_interval_seconds"] == 600 + + +@pytest.mark.asyncio +async def test_sandbox_operation_can_return_screenshot_image_to_llm( + monkeypatch, tmp_path +): + from astrbot.core.tools.computer_tools import sandbox as sandbox_tools + + class Gui: + async def screenshot(self, path): + Path(path).write_bytes(b"image") + return {"base64": "aW1hZ2U=", "mime_type": "image/png"} + + class Booter: + gui = Gui() + + class Manager: + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "controller_session_id": "session-a", + "lease_expires_at": time.time() + 600, + } + ) + + async def get_observer_booter_by_id(self, sandbox_id, *args, **kwargs): + return Booter() + + monkeypatch.setattr(sandbox_tools, "get_astrbot_temp_path", lambda: str(tmp_path)) + monkeypatch.setattr( + sandbox_tools.computer_client, + "sandbox_manager", + Manager(), + ) + + result = await SandboxOperationTool().call( + _admin_context_without_admin_requirement(), + "capture_screenshot", + sandbox_id="sandbox-1", + return_image_to_llm=True, + ) + + assert isinstance(result, mcp.types.CallToolResult) + assert isinstance(result.content[0], mcp.types.TextContent) + assert isinstance(result.content[1], mcp.types.ImageContent) + assert json.loads(result.content[0].text)["lease"]["sandbox_id"] == "sandbox-1" + assert result.content[1].data == "aW1hZ2U=" + + +@pytest.mark.asyncio +async def test_sensitive_sandbox_tools_require_strict_admin_permission(): + context = _member_context_with_sandbox_permissions(set_retention_policy=True) + + assert "Permission denied" in str( + await SandboxLifecycleTool().call(context, "takeover", sandbox_id="sandbox-1") + ) + assert "Permission denied" in str( + await SandboxLifecycleTool().call(context, "destroy", sandbox_id="sandbox-1") + ) + + +@pytest.mark.asyncio +async def test_set_sandbox_retention_policy_tool_respects_admin_requirement(): + result = await SandboxLifecycleTool().call( + _context(), + "set_retention", + retention_policy="persistent", + sandbox_id="sandbox-1", + ) + + assert "Permission denied" in str(result) + + +@pytest.mark.asyncio +async def test_readonly_sandbox_tools_respect_admin_requirement(monkeypatch): + class FakeManager: + def list_sandboxes(self): + raise AssertionError("list must be denied") + + def get_current_sandbox(self, session_id): + raise AssertionError("get current must be denied") + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + monkeypatch.setattr( + "astrbot.core.computer.computer_client.list_sandbox_providers", + lambda: (_ for _ in ()).throw(AssertionError("providers must be denied")), + ) + context = _context() + + assert "Permission denied" in str( + await SandboxQueryTool().call(context, "list_sandboxes") + ) + assert "Permission denied" in str( + await SandboxQueryTool().call(context, "list_providers") + ) + assert "Permission denied" in str( + await SandboxQueryTool().call(context, "get_current") + ) + + +@pytest.mark.asyncio +async def test_member_sandbox_management_permissions_default_to_disabled(monkeypatch): + class FakeManager: + providers = {"generic": object()} + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "controller_session_id": "session-a", + } + ) + + async def create_sandbox(self, *args, **kwargs): + raise AssertionError("create must be denied by default") + + def set_sandbox_retention_policy(self, *args, **kwargs): + raise AssertionError("retention changes must be denied by default") + + async def destroy_sandbox(self, *args, **kwargs): + raise AssertionError("destroy must be denied by default") + + async def takeover_sandbox(self, *args, **kwargs): + raise AssertionError("takeover must be denied by default") + + def get_current_sandbox(self, session_id): + return {"current_sandbox_id": "sandbox-1"} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + context = _member_context_with_sandbox_permissions() + + assert "Permission denied" in str( + await SandboxLifecycleTool().call(context, "create") + ) + assert "Permission denied" in str( + await SandboxLifecycleTool().call( + context, "set_retention", retention_policy="persistent" + ) + ) + assert "Permission denied" in str( + await SandboxLifecycleTool().call(context, "destroy", sandbox_id="sandbox-1") + ) + assert "Permission denied" in str( + await SandboxLifecycleTool().call(context, "takeover", sandbox_id="sandbox-1") + ) + + +@pytest.mark.asyncio +async def test_member_takeover_sandbox_requires_explicit_permission(monkeypatch): + calls = [] + + class FakeManager: + async def takeover_sandbox(self, session_id, sandbox_id, **kwargs): + calls.append((session_id, sandbox_id, kwargs)) + return {"sandbox_id": sandbox_id} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_with_sandbox_permissions(takeover=True), + "takeover", + sandbox_id="sandbox-1", + ) + + assert "sandbox-1" in str(result) + assert calls + + +@pytest.mark.asyncio +async def test_create_sandbox_tool_reports_max_sandbox_limit(monkeypatch): + class FakeManager: + providers = {"generic": object()} + + async def create_sandbox(self, *args, **kwargs): + raise RuntimeError("Sandbox limit reached. Maximum managed sandboxes: 10.") + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_with_sandbox_permissions(create=True), "create" + ) + + assert "Error creating sandbox" in str(result) + assert "Sandbox limit reached" in str(result) + + +@pytest.mark.asyncio +async def test_member_list_sandboxes_includes_all_sandboxes_with_status( + monkeypatch, +): + class FakeManager: + def list_sandboxes(self): + return [ + { + "sandbox_id": "owned", + "owner_session_id": "session-a", + "controller_session_id": None, + }, + { + "sandbox_id": "current", + "owner_session_id": "session-b", + "controller_session_id": "session-a", + }, + { + "sandbox_id": "other-idle", + "sandbox_name": "Other Idle", + "owner_session_id": "session-b", + "owner_user_id": "user-b", + "created_by_session_id": "session-b", + "created_by_user_id": "user-b", + "controller_session_id": None, + "connect_info": {"secret": "idle-secret"}, + "status": "running", + }, + { + "sandbox_id": "other-busy", + "sandbox_name": "Other Busy", + "owner_session_id": "session-c", + "owner_user_id": "user-c", + "created_by_session_id": "session-c", + "created_by_user_id": "user-c", + "controller_session_id": "session-c", + "controller_user_id": "user-c", + "connect_info": {"secret": "busy-secret"}, + "status": "running", + }, + ] + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxQueryTool().call( + _member_context_without_admin_requirement(), "list_sandboxes" + ) + payload = json.loads(str(result)) + by_id = {item["sandbox_id"]: item for item in payload["sandboxes"]} + + assert "owned" in str(result) + assert "current" in str(result) + assert "other-idle" in str(result) + assert "other-busy" in str(result) + assert "idle-secret" not in str(result) + assert "busy-secret" not in str(result) + assert "session-c" not in str(result) + assert "user-c" not in str(result) + assert "session-b" not in str(result) + assert "user-b" not in str(result) + assert by_id["other-idle"]["access"]["status"] == "idle" + assert by_id["other-idle"]["access"]["can_switch"] is True + assert by_id["other-busy"]["access"]["status"] == "occupied" + assert by_id["other-busy"]["access"]["can_switch"] is False + + +@pytest.mark.asyncio +async def test_list_sandboxes_includes_access_status_for_admin(monkeypatch): + class FakeManager: + def list_sandboxes(self): + return [ + { + "sandbox_id": "current", + "controller_session_id": "session-a", + "controller_user_id": "user-a", + "connect_info": {"secret": "current-secret"}, + "status": "running", + }, + { + "sandbox_id": "occupied", + "controller_session_id": "session-b", + "controller_user_id": "user-b", + "lease_expires_at": time.time() + 60, + "connect_info": {"secret": "occupied-secret"}, + "status": "running", + }, + { + "sandbox_id": "idle", + "controller_session_id": None, + "connect_info": {"secret": "idle-secret"}, + "status": "running", + }, + ] + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxQueryTool().call( + _admin_context_without_admin_requirement(), "list_sandboxes" + ) + payload = json.loads(str(result)) + by_id = {item["sandbox_id"]: item for item in payload["sandboxes"]} + + assert by_id["current"]["access"] == { + "status": "current", + "can_switch": True, + "occupied": True, + } + assert by_id["occupied"]["access"] == { + "status": "occupied", + "can_switch": False, + "occupied": True, + } + assert by_id["idle"]["access"] == { + "status": "idle", + "can_switch": True, + "occupied": False, + } + assert by_id["occupied"]["connect_info"]["secret"] == "occupied-secret" + + +@pytest.mark.asyncio +async def test_sandbox_tools_use_current_computer_client_manager(monkeypatch): + from astrbot.core.computer import computer_client + + class FakeManager: + def list_sandboxes(self): + return [{"sandbox_id": "dynamic-manager", "controller_session_id": None}] + + monkeypatch.setattr(computer_client, "sandbox_manager", FakeManager()) + + result = await SandboxQueryTool().call( + _member_context_without_admin_requirement(), "list_sandboxes" + ) + + assert "dynamic-manager" in str(result) + + +@pytest.mark.asyncio +async def test_list_sandbox_providers_tool_exposes_loaded_provider_capabilities( + monkeypatch, +): + monkeypatch.setattr( + "astrbot.core.computer.computer_client.list_sandbox_providers", + lambda: [ + { + "provider_id": "generic", + "capabilities": ["shell"], + "tool_names": ["generic_tool"], + "system_prompt": "", + } + ], + ) + + result = await SandboxQueryTool().call(_sandbox_context(), "list_providers") + payload = json.loads(str(result)) + + assert payload["providers"] == [ + { + "provider_id": "generic", + "capabilities": ["shell"], + "tool_names": ["generic_tool"], + "system_prompt": "", + } + ] + + +@pytest.mark.asyncio +async def test_get_current_sandbox_tool_formats_agent_timestamps(monkeypatch): + class FakeManager: + def get_current_sandbox(self, session_id): + return { + "current_sandbox_id": "sandbox-1", + "sandbox": { + "sandbox_id": "sandbox-1", + "retention_policy": "persistent", + "lease_expires_at": 1778557598.4646258, + }, + } + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxQueryTool().call(_sandbox_context(), "get_current") + payload = json.loads(str(result)) + + assert payload["sandbox"]["retention_policy"] == "persistent" + assert payload["sandbox"]["lease_expires_at"] != 1778557598.4646258 + assert payload["sandbox"]["lease_expires_at"] + + +@pytest.mark.asyncio +async def test_create_sandbox_tool_defaults_to_configured_provider(monkeypatch): + calls = [] + + class FakeManager: + providers = {"generic": object(), "other": object()} + + async def create_sandbox( + self, plugin_context, session_id, provider_id, *, sandbox_name=None + ): + calls.append((plugin_context, session_id, provider_id, sandbox_name)) + return {"sandbox_id": "generic-1", "provider": provider_id} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_with_sandbox_permissions(create=True), + "create", + sandbox_name="Fresh", + ) + payload = json.loads(str(result)) + + assert payload["sandbox"]["provider"] == "generic" + assert calls[0][2:] == ("generic", "Fresh") + + +@pytest.mark.asyncio +async def test_create_sandbox_tool_accepts_explicit_provider_id(monkeypatch): + calls = [] + + class FakeManager: + providers = {"generic": object(), "other": object()} + + async def create_sandbox( + self, plugin_context, session_id, provider_id, *, sandbox_name=None + ): + calls.append((plugin_context, session_id, provider_id, sandbox_name)) + return {"sandbox_id": "other-1", "provider": provider_id} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_with_sandbox_permissions(create=True), + "create", + sandbox_name="Fresh", + provider_id="other", + ) + payload = json.loads(str(result)) + + assert payload["sandbox"]["provider"] == "other" + assert calls[0][2:] == ("other", "Fresh") + + +@pytest.mark.asyncio +async def test_member_switch_sandbox_allows_idle_default_sandbox(monkeypatch): + called = [] + + class FakeManager: + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "owner_session_id": "dashboard", + "controller_session_id": None, + "is_default": True, + } + ) + + async def switch_current_sandbox_checked( + self, session_id, sandbox_id, **kwargs + ): + called.append((session_id, sandbox_id, kwargs)) + return {"sandbox_id": sandbox_id} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_without_admin_requirement(), + "switch", + sandbox_id="default-idle", + ) + + assert "default-idle" in str(result) + assert called + + +@pytest.mark.asyncio +async def test_member_switch_sandbox_allows_idle_dashboard_sandbox(monkeypatch): + called = [] + + class FakeManager: + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "owner_session_id": "dashboard", + "controller_session_id": None, + "is_default": False, + } + ) + + async def switch_current_sandbox_checked( + self, session_id, sandbox_id, **kwargs + ): + called.append((session_id, sandbox_id, kwargs)) + return {"sandbox_id": sandbox_id} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_without_admin_requirement(), + "switch", + sandbox_id="ordinary-idle", + ) + + assert "ordinary-idle" in str(result) + assert called + + +@pytest.mark.asyncio +async def test_member_switch_sandbox_allows_idle_sandbox_from_any_session(monkeypatch): + called = [] + + class FakeManager: + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "owner_session_id": "session-b", + "controller_session_id": None, + "is_default": False, + } + ) + + async def switch_current_sandbox_checked( + self, session_id, sandbox_id, **kwargs + ): + called.append((session_id, sandbox_id, kwargs)) + return {"sandbox_id": sandbox_id} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_without_admin_requirement(), + "switch", + sandbox_id="other-idle", + ) + + assert "other-idle" in str(result) + assert called + + +@pytest.mark.asyncio +async def test_member_switch_sandbox_allows_expired_lease_sandbox(monkeypatch): + called = [] + + class FakeManager: + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "owner_session_id": "session-b", + "controller_session_id": "session-b", + "lease_expires_at": time.time() - 1, + "is_default": False, + } + ) + + async def switch_current_sandbox_checked( + self, session_id, sandbox_id, **kwargs + ): + called.append((session_id, sandbox_id, kwargs)) + return {"sandbox_id": sandbox_id} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_without_admin_requirement(), + "switch", + sandbox_id="expired-id", + ) + + assert "expired-id" in str(result) + assert called + + +@pytest.mark.asyncio +async def test_member_switch_sandbox_rejects_other_session_sandbox(monkeypatch): + class FakeManager: + def registry_get(self): + return None + + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "owner_session_id": "session-b", + "controller_session_id": "session-b", + } + ) + + async def switch_current_sandbox_checked(self, *args, **kwargs): + raise AssertionError("switch must not be called for another user's sandbox") + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_without_admin_requirement(), + "switch", + sandbox_id="other-idle", + ) + + assert "Permission denied" in str(result) + + +@pytest.mark.asyncio +async def test_keep_alive_sandbox_tool_renews_current_sandbox(monkeypatch): + calls = [] + + class FakeManager: + async def renew_current_sandbox_lease( + self, session_id, ttl_seconds=None, context=None + ): + calls.append((session_id, ttl_seconds, context)) + return {"sandbox_id": "sandbox-1", "lease_expires_at": time.time() + 3600} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + context = _member_context_without_admin_requirement() + result = await SandboxLifecycleTool().call(context, "renew_lease", ttl_seconds=3600) + + assert "sandbox-1" in str(result) + assert calls == [("session-a", 3600, context.context.context)] + + payload = json.loads(str(result)) + assert payload["lease"]["sandbox_id"] == "sandbox-1" + assert payload["lease"]["lease_expires_in_seconds"] > 0 + assert payload["lease"]["auto_renew_interval_seconds"] == 600 + + +@pytest.mark.asyncio +async def test_set_sandbox_retention_policy_tool_updates_current_sandbox(monkeypatch): + calls = [] + + class FakeManager: + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "controller_session_id": "session-a", + } + ) + + def set_sandbox_retention_policy( + self, plugin_context, session_id, sandbox_id, retention_policy, **kwargs + ): + calls.append((plugin_context, session_id, sandbox_id, retention_policy)) + return {"sandbox_id": sandbox_id, "retention_policy": retention_policy} + + def get_current_sandbox(self, session_id): + return {"current_sandbox_id": "sandbox-1"} + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + context = _member_context_with_sandbox_permissions(set_retention_policy=True) + result = await SandboxLifecycleTool().call( + context, "set_retention", retention_policy="persistent" + ) + payload = json.loads(str(result)) + + assert payload["sandbox"]["retention_policy"] == "persistent" + assert calls == [(context.context.context, "session-a", "sandbox-1", "persistent")] + + +@pytest.mark.asyncio +async def test_set_sandbox_retention_policy_tool_rejects_other_session_sandbox( + monkeypatch, +): + class FakeManager: + registry = SimpleNamespace( + get_sandbox=lambda sandbox_id: { + "sandbox_id": sandbox_id, + "controller_session_id": "session-b", + "lease_expires_at": time.time() + 60, + } + ) + + def set_sandbox_retention_policy(self, *args, **kwargs): + raise AssertionError("must not update another session's sandbox") + + monkeypatch.setattr( + "astrbot.core.computer.computer_client.sandbox_manager", FakeManager() + ) + + result = await SandboxLifecycleTool().call( + _member_context_with_sandbox_permissions(set_retention_policy=True), + "set_retention", + retention_policy="persistent", + sandbox_id="sandbox-1", + ) + + assert "Permission denied" in str(result) diff --git a/tests/unit/test_tool_conflict_resolution.py b/tests/unit/test_tool_conflict_resolution.py index 8146ad8874..706ec8e2a7 100644 --- a/tests/unit/test_tool_conflict_resolution.py +++ b/tests/unit/test_tool_conflict_resolution.py @@ -4,8 +4,6 @@ with a disabled built-in tool, the MCP tool should not be removed as collateral damage. """ -import pytest - from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.provider.func_tool_manager import FunctionToolManager diff --git a/tests/unit/test_tool_permission.py b/tests/unit/test_tool_permission.py index 1a3a8a376b..d99d54cd0b 100644 --- a/tests/unit/test_tool_permission.py +++ b/tests/unit/test_tool_permission.py @@ -1,20 +1,16 @@ """Tests for per-tool permission management.""" -import asyncio import json -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import MagicMock, patch import pytest from astrbot.core import sp -from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.agent.tool import FunctionTool from astrbot.core.provider.func_tool_manager import ( FunctionToolManager, _PermissionGuardedTool, ) -from astrbot.core.tools.computer_tools.shell import ExecuteShellTool - # ── helpers ────────────────────────────────────────────────────────── @@ -298,9 +294,9 @@ def test_get_full_tool_set_wraps_non_builtin(): plugin_tools = [t for t in tool_set.tools if t.name == "my_plugin_tool"] assert plugin_tools - assert isinstance( - plugin_tools[0], _PermissionGuardedTool - ), "non-builtin tools must be wrapped" + assert isinstance(plugin_tools[0], _PermissionGuardedTool), ( + "non-builtin tools must be wrapped" + ) # ── API: get_tool_list permission fields ────────────────────────────── @@ -358,6 +354,206 @@ async def test_list_no_permission_fields_for_builtin(self): assert "permission_configured" not in target assert target["readonly"] is True + @pytest.mark.asyncio + async def test_list_labels_sandbox_provider_builtin_tools_by_provider(self): + from astrbot.core.computer.sandbox_tool_binding import ( + mark_tool_as_sandbox_provider_tool, + ) + from astrbot.core.tools.computer_tools import SandboxQueryTool + from astrbot.dashboard.routes.tools import ToolsRoute + + route = ToolsRoute.__new__(ToolsRoute) + route.core_lifecycle = MagicMock() + route.core_lifecycle.astrbot_config_mgr = MagicMock() + route.core_lifecycle.astrbot_config_mgr.get_conf_list.return_value = [] + route.core_lifecycle.astrbot_config_mgr.confs = {} + route.tool_mgr = FunctionToolManager() + + tool = route.tool_mgr.get_builtin_tool(SandboxQueryTool) + original_description = tool.description + try: + mark_tool_as_sandbox_provider_tool(tool, "shipyard_neo") + + resp = await route.get_tool_list() + data = json.loads(json.dumps(resp)) + target = next(t for t in data["data"] if t["name"] == tool.name) + + assert target["origin"] == "sandbox" + assert target["origin_name"] == "shipyard_neo" + assert target["readonly"] is True + assert "permission" not in target + assert ( + "builtin_config_tags" not in target + or target["builtin_config_tags"] == [] + ) + finally: + if hasattr(tool, "sandbox_provider_id"): + delattr(tool, "sandbox_provider_id") + tool.description = original_description + + @pytest.mark.asyncio + async def test_list_labels_sandbox_provider_plugin_tools_by_provider(self): + from astrbot.core.computer.sandbox_tool_binding import ( + mark_tool_as_sandbox_provider_tool, + ) + from astrbot.dashboard.routes.tools import ToolsRoute + + route = ToolsRoute.__new__(ToolsRoute) + route.core_lifecycle = MagicMock() + route.core_lifecycle.astrbot_config_mgr = MagicMock() + route.core_lifecycle.astrbot_config_mgr.get_conf_list.return_value = [] + route.core_lifecycle.astrbot_config_mgr.confs = {} + route.tool_mgr = FunctionToolManager() + + route.tool_mgr.func_list.append( + mark_tool_as_sandbox_provider_tool( + _dummy_tool("astrbot_cua_mouse_click"), "cua" + ) + ) + + resp = await route.get_tool_list() + data = json.loads(json.dumps(resp)) + target = next(t for t in data["data"] if t["name"] == "astrbot_cua_mouse_click") + + assert target["origin"] == "sandbox" + assert target["origin_name"] == "cua" + assert target["readonly"] is True + assert "permission" not in target + + @pytest.mark.asyncio + async def test_toggle_rejects_sandbox_provider_tools(self): + from astrbot.core.computer.sandbox_tool_binding import ( + mark_tool_as_sandbox_provider_tool, + ) + from astrbot.dashboard.routes.tools import ToolsRoute + + route = ToolsRoute.__new__(ToolsRoute) + route.tool_mgr = FunctionToolManager() + route.tool_mgr.func_list.append( + mark_tool_as_sandbox_provider_tool( + _dummy_tool("astrbot_cua_mouse_click"), "cua" + ) + ) + + mock_req = MagicMock() + mock_req.json = _make_coro( + {"name": "astrbot_cua_mouse_click", "activate": False} + ) + with patch("astrbot.dashboard.routes.tools.request", mock_req): + resp = await route.toggle_tool() + + data = json.loads(json.dumps(resp)) + assert "Sandbox provider tools are read-only" in data["message"] + + @pytest.mark.asyncio + async def test_update_permission_rejects_sandbox_provider_tools(self): + from astrbot.core.computer.sandbox_tool_binding import ( + mark_tool_as_sandbox_provider_tool, + ) + from astrbot.dashboard.routes.tools import ToolsRoute + + route = ToolsRoute.__new__(ToolsRoute) + route.tool_mgr = FunctionToolManager() + route.tool_mgr.func_list.append( + mark_tool_as_sandbox_provider_tool( + _dummy_tool("astrbot_cua_mouse_click"), "cua" + ) + ) + + mock_req = MagicMock() + mock_req.json = _make_coro( + {"name": "astrbot_cua_mouse_click", "permission": "admin"} + ) + with patch("astrbot.dashboard.routes.tools.request", mock_req): + resp = await route.update_tool_permission() + + data = json.loads(json.dumps(resp)) + assert "do not support per-tool permission configuration" in data["message"] + + @pytest.mark.asyncio + async def test_list_labels_regular_plugin_tools_by_plugin_name(self): + from types import SimpleNamespace + + from astrbot.dashboard.routes import tools as tools_route_module + from astrbot.dashboard.routes.tools import ToolsRoute + + route = ToolsRoute.__new__(ToolsRoute) + route.core_lifecycle = MagicMock() + route.core_lifecycle.astrbot_config_mgr = MagicMock() + route.core_lifecycle.astrbot_config_mgr.get_conf_list.return_value = [] + route.core_lifecycle.astrbot_config_mgr.confs = {} + route.tool_mgr = FunctionToolManager() + + plugin_tool = _dummy_tool("regular_plugin_tool") + plugin_tool.handler_module_path = "plugins.weather" + route.tool_mgr.func_list.append(plugin_tool) + + old_star = tools_route_module.star_map.get("plugins.weather") + tools_route_module.star_map["plugins.weather"] = SimpleNamespace( + name="Weather Plugin" + ) + try: + resp = await route.get_tool_list() + data = json.loads(json.dumps(resp)) + target = next(t for t in data["data"] if t["name"] == "regular_plugin_tool") + + assert target["origin"] == "plugin" + assert target["origin_name"] == "Weather Plugin" + assert target["readonly"] is False + assert target["permission"] == "member" + finally: + if old_star is None: + tools_route_module.star_map.pop("plugins.weather", None) + else: + tools_route_module.star_map["plugins.weather"] = old_star + + @pytest.mark.asyncio + async def test_list_includes_config_tags_for_sandbox_provider_tools(self): + from dataclasses import dataclass, field + + from astrbot.core.computer.sandbox_tool_binding import sandbox_provider_tool + from astrbot.dashboard.routes.tools import ToolsRoute + + @sandbox_provider_tool( + "shipyard_neo", + config={ + "provider_settings.computer_use_runtime": "sandbox", + "provider_settings.sandbox.booter": "shipyard_neo", + }, + ) + @dataclass + class FakeNeoTool(FunctionTool): + name: str = "fake_neo_tool" + description: str = "Fake Neo tool" + parameters: dict = field( + default_factory=lambda: {"type": "object", "properties": {}} + ) + + route = ToolsRoute.__new__(ToolsRoute) + route.core_lifecycle = MagicMock() + route.core_lifecycle.astrbot_config_mgr = MagicMock() + route.core_lifecycle.astrbot_config_mgr.get_conf_list.return_value = [ + {"id": "conf-a", "name": "Config A"} + ] + route.core_lifecycle.astrbot_config_mgr.confs = { + "conf-a": { + "provider_settings": { + "computer_use_runtime": "sandbox", + "sandbox": {"booter": "shipyard_neo"}, + } + } + } + route.tool_mgr = FunctionToolManager() + route.tool_mgr.func_list.append(FakeNeoTool()) + + resp = await route.get_tool_list() + data = json.loads(json.dumps(resp)) + target = next(t for t in data["data"] if t["name"] == "fake_neo_tool") + + assert target["origin"] == "sandbox" + assert target["origin_name"] == "shipyard_neo" + assert target["builtin_config_tags"][0]["conf_name"] == "Config A" + # ── API: update_tool_permission ────────────────────────────────────── @@ -380,9 +576,7 @@ async def test_set_admin_permission(self): data = json.loads(json.dumps(resp)) assert data["status"] == "ok" - stored = sp.get( - "tool_permissions", {}, scope="global", scope_id="global" - ) + stored = sp.get("tool_permissions", {}, scope="global", scope_id="global") assert stored["_default"]["target_tool"] == "admin" @pytest.mark.asyncio @@ -394,7 +588,9 @@ async def test_reject_builtin_tool(self): route.tool_mgr = FunctionToolManager() mock_req = MagicMock() - mock_req.json = _make_coro({"name": "astrbot_execute_shell", "permission": "admin"}) + mock_req.json = _make_coro( + {"name": "astrbot_execute_shell", "permission": "admin"} + ) with patch("astrbot.dashboard.routes.tools.request", mock_req): resp = await route.update_tool_permission() data = json.loads(json.dumps(resp))
{{ JSON.stringify(selectedSandboxRecord.connect_info || {}, null, 2) }}
{{ entry.stdout }}
{{ entry.stderr }}