diff --git a/.gitignore b/.gitignore index 5eb9616c8c..e63b3eeca5 100644 --- a/.gitignore +++ b/.gitignore @@ -64,3 +64,4 @@ GenieData/ .worktrees/ dashboard/bun.lock +docs/plans/ diff --git a/README_zh.md b/README_zh.md index 425719faba..9b265b48e7 100644 --- a/README_zh.md +++ b/README_zh.md @@ -47,9 +47,9 @@ AstrBot 是一个开源的一站式 Agentic 个人和群聊助手,可在 QQ、 3. 🤖 支持接入 Dify、阿里云百炼、Coze 等智能体平台。 4. 🌐 多平台,支持 QQ、企业微信、飞书、钉钉、微信公众号、Telegram、Slack 以及[更多](#支持的消息平台)。 5. 📦 插件扩展,已有 1000+ 个插件可一键安装。 -6. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) 隔离化环境,安全地执行任何代码、调用 Shell、会话级资源复用。 +6. 🛡️ [Agent Sandbox](https://docs.astrbot.app/use/astrbot-agent-sandbox.html) 提供隔离沙盒环境,支持安全执行代码、调用 Shell,并在会话内复用资源。 7. 💻 WebUI 支持。 -8. 🌈 Web ChatUI 支持,ChatUI 内置代理沙盒、网页搜索等。 +8. 🌈 Web ChatUI 支持,内置 Agent Sandbox、网页搜索等能力。 9. 🌐 国际化(i18n)支持。
diff --git a/astrbot/core/agent/runners/tool_loop_agent_runner.py b/astrbot/core/agent/runners/tool_loop_agent_runner.py index 3f74f0ec9b..4764c08480 100644 --- a/astrbot/core/agent/runners/tool_loop_agent_runner.py +++ b/astrbot/core/agent/runners/tool_loop_agent_runner.py @@ -146,6 +146,7 @@ class ToolLoopAgentRunner(BaseAgentRunner[TContext]): REPEATED_TOOL_NOTICE_L1_THRESHOLD = 3 REPEATED_TOOL_NOTICE_L2_THRESHOLD = 4 REPEATED_TOOL_NOTICE_L3_THRESHOLD = 5 + REPEATED_TOOL_NOTICE_EXEMPT_TOOL_NAMES = frozenset({"astrbot_execute_shell"}) REPEATED_TOOL_NOTICE_L1_TEMPLATE = ( "\n\n[SYSTEM NOTICE] By the way, you have executed the same tool " "`{tool_name}` {streak} times consecutively. Double-check whether another " @@ -667,6 +668,9 @@ def _track_tool_call_streak(self, tool_name: str) -> int: return self._same_tool_streak def _build_repeated_tool_call_guidance(self, tool_name: str, streak: int) -> str: + if tool_name in self.REPEATED_TOOL_NOTICE_EXEMPT_TOOL_NAMES: + return "" + if streak < self.REPEATED_TOOL_NOTICE_L1_THRESHOLD: return "" diff --git a/astrbot/core/astr_agent_tool_exec.py b/astrbot/core/astr_agent_tool_exec.py index de5caad554..6d512926e1 100644 --- a/astrbot/core/astr_agent_tool_exec.py +++ b/astrbot/core/astr_agent_tool_exec.py @@ -20,6 +20,7 @@ from astrbot.core.astr_main_agent_resources import ( BACKGROUND_TASK_RESULT_WOKE_SYSTEM_PROMPT, ) +from astrbot.core.computer.sandbox_tool_binding import tool_available_in_runtime from astrbot.core.cron.events import CronMessageEvent from astrbot.core.message.components import Image from astrbot.core.message.message_event_result import ( @@ -31,9 +32,6 @@ from astrbot.core.provider.entites import ProviderRequest from astrbot.core.provider.register import llm_tools from astrbot.core.tools.computer_tools import ( - CuaKeyboardTypeTool, - CuaMouseClickTool, - CuaScreenshotTool, ExecuteShellTool, FileDownloadTool, FileEditTool, @@ -43,8 +41,12 @@ GrepTool, LocalPythonTool, PythonTool, + SandboxLifecycleTool, + SandboxOperationTool, + SandboxQueryTool, ) from astrbot.core.tools.message_tools import SendMessageToUserTool +from astrbot.core.tools.registry import get_builtin_tool_config_rule from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.history_saver import persist_agent_history from astrbot.core.utils.image_ref_utils import is_supported_image_ref @@ -52,6 +54,14 @@ class FunctionToolExecutor(BaseFunctionToolExecutor[AstrAgentContext]): + _runtime_computer_tools_cache: dict[ + tuple[int, str, str], dict[str, FunctionTool] + ] = {} + + @classmethod + def clear_runtime_computer_tools_cache(cls, provider_id: str | None = None) -> None: + cls._runtime_computer_tools_cache.clear() + @classmethod def _collect_image_urls_from_args(cls, image_urls_raw: T.Any) -> list[str]: if image_urls_raw is None: @@ -192,8 +202,14 @@ def _get_runtime_computer_tools( booter: str | None = None, ) -> dict[str, FunctionTool]: booter = "" if booter is None else str(booter).lower() + cache_key = (id(tool_mgr), runtime, booter) + if cache_key in cls._runtime_computer_tools_cache: + return cls._runtime_computer_tools_cache[cache_key] if runtime == "sandbox": shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool) + sandbox_query_tool = tool_mgr.get_builtin_tool(SandboxQueryTool) + sandbox_lifecycle_tool = tool_mgr.get_builtin_tool(SandboxLifecycleTool) + sandbox_operation_tool = tool_mgr.get_builtin_tool(SandboxOperationTool) python_tool = tool_mgr.get_builtin_tool(PythonTool) upload_tool = tool_mgr.get_builtin_tool(FileUploadTool) download_tool = tool_mgr.get_builtin_tool(FileDownloadTool) @@ -203,6 +219,9 @@ def _get_runtime_computer_tools( grep_tool = tool_mgr.get_builtin_tool(GrepTool) tools = { shell_tool.name: shell_tool, + sandbox_query_tool.name: sandbox_query_tool, + sandbox_lifecycle_tool.name: sandbox_lifecycle_tool, + sandbox_operation_tool.name: sandbox_operation_tool, python_tool.name: python_tool, upload_tool.name: upload_tool, download_tool.name: download_tool, @@ -211,17 +230,7 @@ def _get_runtime_computer_tools( edit_tool.name: edit_tool, grep_tool.name: grep_tool, } - if booter == "cua": - screenshot_tool = tool_mgr.get_builtin_tool(CuaScreenshotTool) - mouse_click_tool = tool_mgr.get_builtin_tool(CuaMouseClickTool) - keyboard_type_tool = tool_mgr.get_builtin_tool(CuaKeyboardTypeTool) - tools.update( - { - screenshot_tool.name: screenshot_tool, - mouse_click_tool.name: mouse_click_tool, - keyboard_type_tool.name: keyboard_type_tool, - } - ) + cls._runtime_computer_tools_cache[cache_key] = tools return tools if runtime == "local": shell_tool = tool_mgr.get_builtin_tool(ExecuteShellTool) @@ -230,7 +239,7 @@ def _get_runtime_computer_tools( write_tool = tool_mgr.get_builtin_tool(FileWriteTool) edit_tool = tool_mgr.get_builtin_tool(FileEditTool) grep_tool = tool_mgr.get_builtin_tool(GrepTool) - return { + tools = { shell_tool.name: shell_tool, python_tool.name: python_tool, read_tool.name: read_tool, @@ -238,8 +247,29 @@ def _get_runtime_computer_tools( edit_tool.name: edit_tool, grep_tool.name: grep_tool, } + cls._runtime_computer_tools_cache[cache_key] = tools + return tools return {} + @staticmethod + def _tool_available_for_runtime_config(tool: FunctionTool, runtime: str) -> bool: + if not tool_available_in_runtime(tool, runtime): + return False + rule = get_builtin_tool_config_rule(tool.name) + if rule is None: + return True + conditions = rule.evaluate( + {"provider_settings": {"computer_use_runtime": runtime}} + ) + runtime_conditions = [ + condition + for condition in conditions + if str(condition.get("key")) == "provider_settings.computer_use_runtime" + ] + if not runtime_conditions: + return True + return all(bool(condition.get("matched")) for condition in runtime_conditions) + @classmethod def _build_handoff_toolset( cls, @@ -259,17 +289,18 @@ def _build_handoff_toolset( runtime_computer_tools = cls._get_runtime_computer_tools( runtime, tool_mgr, - provider_settings.get("sandbox", {}).get("booter"), ) # Keep persona semantics aligned with the main agent: tools=None means # "all tools", including runtime computer-use tools. if tools is None: toolset = ToolSet() - for registered_tool in llm_tools.func_list: + for registered_tool in getattr(tool_mgr, "func_list", llm_tools.func_list): if isinstance(registered_tool, HandoffTool): continue - if registered_tool.active: + if registered_tool.active and cls._tool_available_for_runtime_config( + registered_tool, runtime + ): toolset.add_tool(registered_tool) for runtime_tool in runtime_computer_tools.values(): toolset.add_tool(runtime_tool) @@ -281,14 +312,20 @@ def _build_handoff_toolset( toolset = ToolSet() for tool_name_or_obj in tools: if isinstance(tool_name_or_obj, str): - registered_tool = llm_tools.get_func(tool_name_or_obj) - if registered_tool and registered_tool.active: + registered_tool = tool_mgr.get_func(tool_name_or_obj) + if ( + registered_tool + and registered_tool.active + and cls._tool_available_for_runtime_config(registered_tool, runtime) + ): toolset.add_tool(registered_tool) continue runtime_tool = runtime_computer_tools.get(tool_name_or_obj) if runtime_tool: toolset.add_tool(runtime_tool) - elif isinstance(tool_name_or_obj, FunctionTool): + elif isinstance( + tool_name_or_obj, FunctionTool + ) and cls._tool_available_for_runtime_config(tool_name_or_obj, runtime): toolset.add_tool(tool_name_or_obj) return None if toolset.empty() else toolset @@ -538,11 +575,16 @@ async def _wake_main_agent_for_background_result( message_type=session.message_type, ) cron_event.role = event.role + session_config = ctx.get_config(umo=event.unified_msg_origin) + provider_settings = session_config.get("provider_settings", {}) config = MainAgentBuildConfig( tool_call_timeout=run_context.tool_call_timeout, - streaming_response=ctx.get_config() - .get("provider_settings", {}) - .get("stream", False), + streaming_response=provider_settings.get("stream", False), + computer_use_runtime=str( + provider_settings.get("computer_use_runtime", "local") + ), + sandbox_cfg=provider_settings.get("sandbox", {}), + provider_settings=provider_settings, ) req = ProviderRequest() diff --git a/astrbot/core/astr_main_agent.py b/astrbot/core/astr_main_agent.py index cee6e9e27d..0c5ea2770f 100644 --- a/astrbot/core/astr_main_agent.py +++ b/astrbot/core/astr_main_agent.py @@ -15,7 +15,7 @@ from astrbot.core.agent.handoff import HandoffTool from astrbot.core.agent.mcp_client import MCPTool from astrbot.core.agent.message import TextPart -from astrbot.core.agent.tool import ToolSet +from astrbot.core.agent.tool import FunctionTool, ToolSet from astrbot.core.astr_agent_context import AgentContextWrapper, AstrAgentContext from astrbot.core.astr_agent_hooks import MAIN_AGENT_HOOKS from astrbot.core.astr_agent_run_util import AgentRunner @@ -24,10 +24,13 @@ CHATUI_SPECIAL_DEFAULT_PERSONA_PROMPT, LIVE_MODE_SYSTEM_PROMPT, LLM_SAFETY_MODE_SYSTEM_PROMPT, + SANDBOX_GUI_PROMPT, SANDBOX_MODE_PROMPT, TOOL_CALL_PROMPT, TOOL_CALL_PROMPT_SKILLS_LIKE_MODE, ) +from astrbot.core.computer import computer_client +from astrbot.core.computer.sandbox_tool_binding import tool_available_in_runtime from astrbot.core.conversation_mgr import Conversation from astrbot.core.message.components import File, Image, Record, Reply, Video from astrbot.core.persona_error_reply import ( @@ -47,32 +50,18 @@ from astrbot.core.star.star import star_registry from astrbot.core.star.star_handler import star_map from astrbot.core.tools.computer_tools import ( - AnnotateExecutionTool, - BrowserBatchExecTool, - BrowserExecTool, - CreateSkillCandidateTool, - CreateSkillPayloadTool, - CuaKeyboardTypeTool, - CuaMouseClickTool, - CuaScreenshotTool, - EvaluateSkillCandidateTool, ExecuteShellTool, FileDownloadTool, FileEditTool, FileReadTool, FileUploadTool, FileWriteTool, - GetExecutionHistoryTool, - GetSkillPayloadTool, GrepTool, - ListSkillCandidatesTool, - ListSkillReleasesTool, LocalPythonTool, - PromoteSkillCandidateTool, PythonTool, - RollbackSkillReleaseTool, - RunBrowserSkillTool, - SyncSkillReleaseTool, + SandboxLifecycleTool, + SandboxOperationTool, + SandboxQueryTool, normalize_umo_for_workspace, ) from astrbot.core.tools.cron_tools import FutureTaskTool @@ -81,6 +70,7 @@ retrieve_knowledge_base, ) from astrbot.core.tools.message_tools import SendMessageToUserTool +from astrbot.core.tools.registry import get_builtin_tool_config_rule from astrbot.core.tools.web_search_tools import ( BaiduWebSearchTool, BochaWebSearchTool, @@ -451,6 +441,34 @@ def _filter_skills_for_current_config( return filtered +def _tool_available_for_current_runtime(tool: FunctionTool, cfg: dict) -> bool: + runtime = str(cfg.get("computer_use_runtime", "local")) + if not tool_available_in_runtime(tool, runtime): + return False + rule = get_builtin_tool_config_rule(tool.name) + if rule is None: + return True + conditions = rule.evaluate({"provider_settings": cfg}) + runtime_conditions = [ + condition + for condition in conditions + if str(condition.get("key")) == "provider_settings.computer_use_runtime" + ] + if not runtime_conditions: + return True + return all(bool(condition.get("matched")) for condition in runtime_conditions) + + +def _filter_tools_for_current_config( + toolset: ToolSet, cfg: dict, session_id: str +) -> ToolSet: + filtered = ToolSet() + for tool in toolset: + if _tool_available_for_current_runtime(tool, cfg): + filtered.add_tool(tool) + return filtered + + async def _ensure_persona_and_skills( req: ProviderRequest, cfg: dict, @@ -479,6 +497,7 @@ async def _ensure_persona_and_skills( if req.system_prompt is None: req.system_prompt = "" + session_id = event.unified_msg_origin if persona: # Inject persona system prompt @@ -492,7 +511,12 @@ async def _ensure_persona_and_skills( # Inject skills prompt runtime = cfg.get("computer_use_runtime", "local") skill_manager = SkillManager() - skills = skill_manager.list_skills(active_only=True, runtime=runtime) + current_provider = computer_client.get_current_sandbox_provider_id(session_id) + skills = skill_manager.list_skills( + active_only=True, + runtime=runtime, + provider_id=current_provider, + ) skills = _filter_skills_for_current_config(skills, cfg) if skills: @@ -511,10 +535,20 @@ async def _ensure_persona_and_skills( "If you need to use these capabilities, ask the user to enable Computer Use in the AstrBot WebUI -> Config." ) tmgr = plugin_context.get_llm_tool_manager() + persona_tools_configured = bool(persona and persona.get("tools") is not None) + req._persona_tools_configured = persona_tools_configured + req._persona_allowed_tool_names = ( + {str(tool_name) for tool_name in persona.get("tools", [])} + if persona_tools_configured + else None + ) # inject toolset in the persona if (persona and persona.get("tools") is None) or not persona: persona_toolset = tmgr.get_full_tool_set() + persona_toolset = _filter_tools_for_current_config( + persona_toolset, cfg, session_id + ) for tool in list(persona_toolset): if not tool.active: persona_toolset.remove_tool(tool.name) @@ -523,7 +557,11 @@ async def _ensure_persona_and_skills( if persona["tools"]: for tool_name in persona["tools"]: tool = tmgr.get_func(tool_name) - if tool and tool.active: + if ( + tool + and tool.active + and _tool_available_for_current_runtime(tool, cfg) + ): persona_toolset.add_tool(tool) if not req.func_tool: req.func_tool = persona_toolset @@ -545,13 +583,15 @@ async def _ensure_persona_and_skills( if a.get("enabled", True) is False: continue persona_tools = None + persona_tools_configured = False pid = a.get("persona_id") if pid: persona = plugin_context.persona_manager.get_persona_v3_by_id(pid) if persona is not None: persona_tools = persona.get("tools") + persona_tools_configured = "tools" in persona tools = a.get("tools", []) - if persona_tools is not None: + if persona_tools_configured: tools = persona_tools if tools is None: assigned_tools.update( @@ -559,6 +599,7 @@ async def _ensure_persona_and_skills( tool.name for tool in tmgr.func_list if not isinstance(tool, HandoffTool) + and _tool_available_for_current_runtime(tool, cfg) ] ) continue @@ -951,7 +992,9 @@ async def _decorate_llm_request( _apply_workspace_extra_prompt(event, req) -def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: +def _plugin_tool_fix( + event: AstrMessageEvent, req: ProviderRequest, cfg: dict | None = None +) -> None: """根据事件中的插件设置,过滤请求中的工具列表。 注意:没有 handler_module_path 的工具(如 MCP 工具)会被保留, @@ -979,6 +1022,10 @@ def _plugin_tool_fix(event: AstrMessageEvent, req: ProviderRequest) -> None: new_tool_set.add_tool(tool) req.func_tool = new_tool_set + if req.func_tool is not None and cfg is not None: + session_id = req.session_id or event.unified_msg_origin + req.func_tool = _filter_tools_for_current_config(req.func_tool, cfg, session_id) + async def _handle_webchat( event: AstrMessageEvent, req: ProviderRequest, prov: Provider @@ -1037,100 +1084,40 @@ def _apply_llm_safety_mode(config: MainAgentBuildConfig, req: ProviderRequest) - def _apply_sandbox_tools( config: MainAgentBuildConfig, req: ProviderRequest, - session_id: str, ) -> None: if req.func_tool is None: req.func_tool = ToolSet() if req.system_prompt is None: req.system_prompt = "" - booter = config.sandbox_cfg.get("booter", "shipyard_neo") - if booter == "shipyard": - ep = config.sandbox_cfg.get("shipyard_endpoint", "") - at = config.sandbox_cfg.get("shipyard_access_token", "") - if not ep or not at: - logger.error("Shipyard sandbox configuration is incomplete.") - return - os.environ["SHIPYARD_ENDPOINT"] = ep - os.environ["SHIPYARD_ACCESS_TOKEN"] = at + allowed_tool_names = getattr(req, "_persona_allowed_tool_names", None) + persona_tools_configured = bool(getattr(req, "_persona_tools_configured", False)) tool_mgr = llm_tools - req.func_tool.add_tool(tool_mgr.get_builtin_tool(ExecuteShellTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(PythonTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileUploadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileDownloadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileReadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileWriteTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(FileEditTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(GrepTool)) - if booter == "shipyard_neo": - # Neo-specific path rule: filesystem tools operate relative to sandbox - # workspace root. Do not prepend "/workspace". - req.system_prompt += ( - "\n[Shipyard Neo File Path Rule]\n" - "When using sandbox filesystem tools (upload/download/read/write/list/delete), " - "always pass paths relative to the sandbox workspace root. " - "Example: use `baidu_homepage.png` instead of `/workspace/baidu_homepage.png`.\n" - ) - - req.system_prompt += ( - "\n[Neo Skill Lifecycle Workflow]\n" - "When user asks to create/update a reusable skill in Neo mode, use lifecycle tools instead of directly writing local skill folders.\n" - "Preferred sequence:\n" - "1) Use `astrbot_create_skill_payload` to store canonical payload content and get `payload_ref`.\n" - "2) Use `astrbot_create_skill_candidate` with `skill_key` + `source_execution_ids` (and optional `payload_ref`) to create a candidate.\n" - "3) Use `astrbot_promote_skill_candidate` to release: `stage=canary` for trial; `stage=stable` for production.\n" - "For stable release, set `sync_to_local=true` to sync `payload.skill_markdown` into local `SKILL.md`.\n" - "Do not treat ad-hoc generated files as reusable Neo skills unless they are captured via payload/candidate/release.\n" - "To update an existing skill, create a new payload/candidate and promote a new release version; avoid patching old local folders directly.\n" - ) + added_tool = False - # Determine sandbox capabilities from an already-booted session. - # If no session exists yet (first request), capabilities is None - # and we register all tools conservatively. - from astrbot.core.computer.computer_client import session_booter - - sandbox_capabilities: list[str] | None = None - existing_booter = session_booter.get(session_id) - if existing_booter is not None: - sandbox_capabilities = getattr(existing_booter, "capabilities", None) - - # Browser tools: only register if profile supports browser - # (or if capabilities are unknown because sandbox hasn't booted yet) - if sandbox_capabilities is None or "browser" in sandbox_capabilities: - req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserExecTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(BrowserBatchExecTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(RunBrowserSkillTool)) - - # Neo-specific tools (always available for shipyard_neo) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetExecutionHistoryTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(AnnotateExecutionTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillPayloadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(GetSkillPayloadTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CreateSkillCandidateTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillCandidatesTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(EvaluateSkillCandidateTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(PromoteSkillCandidateTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(ListSkillReleasesTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(RollbackSkillReleaseTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(SyncSkillReleaseTool)) - - if booter == "cua": - req.system_prompt += ( - "\n[CUA Desktop Control]\n" - "Use `astrbot_execute_shell` with `background=true` to launch GUI apps. " - 'Use Firefox for browser tasks, for example `firefox "https://example.com"`. ' - "After each visible step, call `astrbot_cua_screenshot` with " - "`send_to_user=true` and `return_image_to_llm=true` so the user can " - "monitor progress. When typing, inspect the screenshot first and confirm " - "the target field is focused and empty or safe to append to. Use " - "`astrbot_cua_mouse_click` for coordinates and `astrbot_cua_keyboard_type` " - "for text input; use text=`\\n` for Enter.\n" + def add_sandbox_tool(tool_cls) -> None: + nonlocal added_tool + tool = tool_mgr.get_builtin_tool(tool_cls) + if persona_tools_configured and tool.name not in allowed_tool_names: + return + req.func_tool.add_tool(tool) + added_tool = True + + add_sandbox_tool(ExecuteShellTool) + add_sandbox_tool(SandboxQueryTool) + add_sandbox_tool(SandboxLifecycleTool) + add_sandbox_tool(SandboxOperationTool) + add_sandbox_tool(PythonTool) + add_sandbox_tool(FileUploadTool) + add_sandbox_tool(FileDownloadTool) + add_sandbox_tool(FileReadTool) + add_sandbox_tool(FileWriteTool) + add_sandbox_tool(FileEditTool) + add_sandbox_tool(GrepTool) + if added_tool: + req.system_prompt = ( + f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}{SANDBOX_GUI_PROMPT}\n" ) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CuaScreenshotTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CuaMouseClickTool)) - req.func_tool.add_tool(tool_mgr.get_builtin_tool(CuaKeyboardTypeTool)) - - req.system_prompt = f"{req.system_prompt or ''}\n{SANDBOX_MODE_PROMPT}\n" def _proactive_cron_job_tools(req: ProviderRequest, plugin_context: Context) -> None: @@ -1484,14 +1471,14 @@ async def build_main_agent( if not req.session_id: req.session_id = event.unified_msg_origin - _plugin_tool_fix(event, req) + _plugin_tool_fix(event, req, config.provider_settings) await _apply_web_search_tools(event, req, plugin_context) if config.llm_safety_mode: _apply_llm_safety_mode(config, req) if config.computer_use_runtime == "sandbox": - _apply_sandbox_tools(config, req, req.session_id) + _apply_sandbox_tools(config, req) elif config.computer_use_runtime == "local": _apply_local_env_tools(req, plugin_context) diff --git a/astrbot/core/astr_main_agent_resources.py b/astrbot/core/astr_main_agent_resources.py index 4efa0e5a6d..9a235088cc 100644 --- a/astrbot/core/astr_main_agent_resources.py +++ b/astrbot/core/astr_main_agent_resources.py @@ -13,6 +13,17 @@ SANDBOX_MODE_PROMPT = ( "You have access to a sandboxed environment and can execute shell commands and Python code securely." + " You can manage sandbox lifecycle, including listing sandbox providers, listing sandboxes, checking the current sandbox, creating a new sandbox, switching sandboxes, releasing sandbox occupancy, taking over a sandbox, destroying a sandbox, and copying files between sandboxes." + " Before creating a new sandbox, always check the current sandbox first." + " If there is no current sandbox, list sandboxes and inspect each sandbox's access field for this session." + " Prefer reusing access.status=current first, then access.status=idle. Never treat status=running alone as reusable." + " If access.status=occupied or access.can_switch=false, another active session controls that sandbox; do not switch to it unless the user explicitly asks to take it over." + " If you need a different provider, use astrbot_sandbox_query with action=list_providers first and pass provider_id explicitly to astrbot_sandbox_lifecycle with action=create." + " You can create a new sandbox only when the user explicitly asks for a fresh or separate environment, or when no existing sandbox can be reused safely." + " Each successful sandbox operation that accesses a sandbox automatically renews this session's lease to now plus the configured sandbox lease timeout." + " Sandbox-bound tool results include lease metadata such as lease_expires_at and lease_expires_in_seconds." + " When this session's lease expires, this session no longer has a current sandbox; use list_sandboxes and then switch, takeover, or create before continuing sandbox work." + " For long-running work, monitor lease metadata; if the remaining time is low before a long idle period or external wait, call astrbot_sandbox_lifecycle with action=renew_lease." # "Your have extended skills library, such as PDF processing, image generation, data analysis, etc. " # "Before handling complex tasks, please retrieve and review the documentation in the in /app/skills/ directory. " # "If the current task matches the description of a specific skill, prioritize following the workflow defined by that skill." @@ -22,6 +33,13 @@ # "Use shell commands such as grep, sed, awk to extract relevant information from the documentation as needed.\n" ) +SANDBOX_GUI_PROMPT = ( + " When working with GUI-capable sandboxes, use astrbot_sandbox_operation with action=capture_screenshot to show progress whenever it is helpful, especially after each meaningful GUI step." + " If the screenshot should be visible to the user, set send_to_user=true in that same capture_screenshot call instead of taking a screenshot first and then calling send_message_to_user separately." + " Set return_image_to_llm=true only when you need to inspect the screenshot yourself before deciding the next step." + " If the task is completed successfully, also send a final result screenshot with send_to_user=true to show the outcome clearly." +) + TOOL_CALL_PROMPT = ( "When using tools: " "never return an empty response; " diff --git a/astrbot/core/computer/booters/base.py b/astrbot/core/computer/booters/base.py index ec1af5cdc8..96c967c463 100644 --- a/astrbot/core/computer/booters/base.py +++ b/astrbot/core/computer/booters/base.py @@ -37,12 +37,10 @@ def gui(self) -> GUIComponent | None: async def boot(self, session_id: str) -> None: ... async def shutdown(self, **kwargs) -> None: - """Shut down the computer sandbox. + """Close the current runtime connection without deleting sandbox resources. - Subclasses may accept extra keyword arguments for - type-specific cleanup (e.g. ``delete_sandbox`` for - ShipyardNeoBooter). The default implementation ignores - them. + Subclasses may accept extra keyword arguments for type-specific cleanup. + The default implementation ignores them. """ ... diff --git a/astrbot/core/computer/booters/bay_manager.py b/astrbot/core/computer/booters/bay_manager.py deleted file mode 100644 index 61ccc1b3a5..0000000000 --- a/astrbot/core/computer/booters/bay_manager.py +++ /dev/null @@ -1,259 +0,0 @@ -"""Manage Bay container lifecycle for zero-config Shipyard Neo integration. - -When no Bay endpoint is configured, AstrBot can automatically start a Bay -container using the Docker socket (like BoxliteBooter does for Ship -containers). -""" - -from __future__ import annotations - -import asyncio -import io -import json -import tarfile -from typing import Any - -import aiodocker -import aiohttp - -from astrbot.api import logger - -# --------------------------------------------------------------------------- -# Constants -# --------------------------------------------------------------------------- - -BAY_IMAGE = "ghcr.io/astrbotdevs/shipyard-neo-bay:latest" -BAY_CONTAINER_NAME = "astrbot-bay" -BAY_LABEL = "astrbot.bay.managed" -BAY_PORT = 8114 -HEALTH_TIMEOUT_S = 60 -HEALTH_POLL_INTERVAL_S = 2 - - -class BayContainerManager: - """Start / reuse / stop a Bay container via Docker Engine API.""" - - def __init__( - self, - image: str = BAY_IMAGE, - host_port: int = BAY_PORT, - ) -> None: - self._image = image - self._host_port = host_port - self._docker: aiodocker.Docker | None = None - self._container: Any = None - - # ------------------------------------------------------------------ - # Public API - # ------------------------------------------------------------------ - - async def ensure_running(self) -> str: - """Make sure a Bay container is running. Returns the endpoint URL. - - If a container labelled ``astrbot.bay.managed`` already exists - and is running, it will be reused. Otherwise a new container is - created from *self._image*. - """ - try: - self._docker = aiodocker.Docker() - except Exception as exc: - raise RuntimeError( - "Failed to connect to Docker daemon. " - "Ensure Docker is installed and running, or configure " - "an explicit Bay endpoint instead of auto-start mode." - ) from exc - - # 1. Look for an existing managed container - existing = await self._find_managed_container() - if existing is not None: - state = existing["State"] - if state.get("Running"): - cid = existing["Id"][:12] - logger.info("[BayManager] Reusing existing Bay container: %s", cid) - self._container = await self._docker.containers.get(existing["Id"]) - return f"http://127.0.0.1:{self._host_port}" - else: - # Container exists but stopped — restart it - logger.info("[BayManager] Restarting stopped Bay container") - container = await self._docker.containers.get(existing["Id"]) - await container.start() - self._container = container - return f"http://127.0.0.1:{self._host_port}" - - # 2. Pull image if needed - await self._pull_image_if_needed() - - # 3. Create and start container - logger.info( - "[BayManager] Starting Bay container: image=%s, port=%d", - self._image, - self._host_port, - ) - config = { - "Image": self._image, - "Labels": {BAY_LABEL: "true"}, - "Env": [ - "BAY_SERVER__HOST=0.0.0.0", - f"BAY_SERVER__PORT={BAY_PORT}", - "BAY_DATA_DIR=/app/data", - # allow_anonymous=false → auto-provisions API key - "BAY_SECURITY__ALLOW_ANONYMOUS=false", - ], - "HostConfig": { - "PortBindings": { - f"{BAY_PORT}/tcp": [{"HostPort": str(self._host_port)}], - }, - "Binds": [ - # Bay needs Docker socket to create sandbox containers - "/var/run/docker.sock:/var/run/docker.sock", - ], - "RestartPolicy": {"Name": "unless-stopped"}, - }, - } - self._container = await self._docker.containers.create_or_replace( - BAY_CONTAINER_NAME, config - ) - await self._container.start() - logger.info("[BayManager] Bay container started: %s", BAY_CONTAINER_NAME) - - return f"http://127.0.0.1:{self._host_port}" - - async def wait_healthy(self, timeout: int = HEALTH_TIMEOUT_S) -> None: - """Block until Bay's ``/health`` endpoint returns 200.""" - url = f"http://127.0.0.1:{self._host_port}/health" - loop = asyncio.get_running_loop() - deadline = loop.time() + timeout - last_error: str = "" - - async with aiohttp.ClientSession() as session: - while loop.time() < deadline: - try: - async with session.get( - url, timeout=aiohttp.ClientTimeout(total=3) - ) as resp: - if resp.status == 200: - logger.info("[BayManager] Bay is healthy") - return - last_error = f"HTTP {resp.status}" - except Exception as exc: - last_error = str(exc) - - await asyncio.sleep(HEALTH_POLL_INTERVAL_S) - - raise TimeoutError( - f"Bay did not become healthy within {timeout}s (last error: {last_error})" - ) - - async def read_credentials(self) -> str: - """Read auto-provisioned API key from Bay container. - - Bay writes ``credentials.json`` to its data directory when - ``allow_anonymous=false`` and no explicit API key is set. - """ - if self._container is None: - return "" - - try: - # Read credentials.json from container filesystem - tar_stream = await self._container.get_archive("/app/data/credentials.json") - # get_archive returns (tar_data, stat) - tar_data = tar_stream - - if isinstance(tar_data, dict): - raw = tar_data.get("data", b"") - elif isinstance(tar_data, tuple): - # (stream, stat_info) - raw = b"" - stream = tar_data[0] - if hasattr(stream, "read"): - raw = await stream.read() - elif isinstance(stream, bytes): - raw = stream - else: - # It might be a chunked response - chunks = [] - async for chunk in stream: - chunks.append(chunk) - raw = b"".join(chunks) - else: - raw = tar_data if isinstance(tar_data, bytes) else b"" - - if not raw: - logger.debug("[BayManager] Empty tar response from container") - return "" - - tario = io.BytesIO(raw) - with tarfile.open(fileobj=tario) as tar: - for member in tar.getmembers(): - f = tar.extractfile(member) - if f: - creds = json.loads(f.read().decode("utf-8")) - api_key = creds.get("api_key", "") - if api_key: - masked = ( - f"{api_key[:8]}..." - if len(api_key) >= 10 - else "redacted" - ) - logger.info( - "[BayManager] Auto-discovered Bay API key: %s", - masked, - ) - return api_key - except Exception as exc: - logger.debug( - "[BayManager] Failed to read credentials from container: %s", exc - ) - - return "" - - async def close_client(self) -> None: - """Close the Docker client without stopping the container. - - The Bay container stays running for reuse by future sessions. - """ - if self._docker is not None: - await self._docker.close() - self._docker = None - - async def stop(self) -> None: - """Stop and remove the managed Bay container.""" - if self._container is not None: - try: - await self._container.stop() - await self._container.delete(force=True) - logger.info("[BayManager] Bay container stopped and removed") - except Exception as exc: - logger.debug("[BayManager] Error stopping Bay container: %s", exc) - finally: - self._container = None - - await self.close_client() - - # ------------------------------------------------------------------ - # Private helpers - # ------------------------------------------------------------------ - - async def _find_managed_container(self) -> dict | None: - """Find an existing container with our management label.""" - assert self._docker is not None - containers = await self._docker.containers.list( - all=True, - filters=json.dumps({"label": [f"{BAY_LABEL}=true"]}), - ) - if containers: - # Inspect first match to get full state - return await containers[0].show() - return None - - async def _pull_image_if_needed(self) -> None: - """Pull the Bay image if it doesn't exist locally.""" - assert self._docker is not None - try: - await self._docker.images.inspect(self._image) - logger.debug("[BayManager] Image %s already exists", self._image) - except aiodocker.exceptions.DockerError: - logger.info("[BayManager] Pulling image %s ...", self._image) - # Pull with progress logging - await self._docker.images.pull(self._image) - logger.info("[BayManager] Image %s pulled successfully", self._image) diff --git a/astrbot/core/computer/booters/boxlite.py b/astrbot/core/computer/booters/boxlite.py deleted file mode 100644 index aa3ca59761..0000000000 --- a/astrbot/core/computer/booters/boxlite.py +++ /dev/null @@ -1,194 +0,0 @@ -import asyncio -import random -from typing import Any - -import aiohttp -import boxlite -from shipyard import FileSystemComponent as ShipyardFileSystemComponent -from shipyard.python import PythonComponent as ShipyardPythonComponent -from shipyard.shell import ShellComponent as ShipyardShellComponent - -from astrbot.api import logger - -from ..olayer import FileSystemComponent, PythonComponent, ShellComponent -from .base import ComputerBooter -from .shipyard import ShipyardFileSystemWrapper - - -class MockShipyardSandboxClient: - def __init__(self, sb_url: str) -> None: - self.sb_url = sb_url.rstrip("/") - - async def _exec_operation( - self, - ship_id: str, - operation_type: str, - payload: dict[str, Any], - session_id: str, - ) -> dict[str, Any]: - async with aiohttp.ClientSession() as session: - headers = {"X-SESSION-ID": session_id} - async with session.post( - f"{self.sb_url}/{operation_type}", - json=payload, - headers=headers, - ) as response: - if response.status == 200: - return await response.json() - else: - error_text = await response.text() - raise Exception( - f"Failed to exec operation: {response.status} {error_text}" - ) - - async def upload_file(self, path: str, remote_path: str) -> dict: - """Upload a file to the sandbox""" - url = f"http://{self.sb_url}/upload" - - try: - # Read file content - with open(path, "rb") as f: - file_content = f.read() - - # Create multipart form data - data = aiohttp.FormData() - data.add_field( - "file", - file_content, - filename=remote_path.split("/")[-1], - content_type="application/octet-stream", - ) - data.add_field("file_path", remote_path) - - timeout = aiohttp.ClientTimeout(total=120) # 2 minutes for file upload - - async with aiohttp.ClientSession(timeout=timeout) as session: - async with session.post(url, data=data) as response: - if response.status == 200: - logger.info( - "[Computer] File uploaded to Boxlite sandbox: %s", - remote_path, - ) - return { - "success": True, - "message": "File uploaded successfully", - "file_path": remote_path, - } - else: - error_text = await response.text() - return { - "success": False, - "error": f"Server returned {response.status}: {error_text}", - "message": "File upload failed", - } - - except aiohttp.ClientError as e: - logger.error(f"Failed to upload file: {e}") - return { - "success": False, - "error": f"Connection error: {str(e)}", - "message": "File upload failed", - } - except asyncio.TimeoutError: - return { - "success": False, - "error": "File upload timeout", - "message": "File upload failed", - } - except FileNotFoundError: - logger.error(f"File not found: {path}") - return { - "success": False, - "error": f"File not found: {path}", - "message": "File upload failed", - } - except Exception as e: - logger.error(f"Unexpected error uploading file: {e}") - return { - "success": False, - "error": f"Internal error: {str(e)}", - "message": "File upload failed", - } - - async def wait_healthy(self, ship_id: str, session_id: str) -> None: - """Mock wait healthy""" - loop = 60 - while loop > 0: - try: - logger.info( - f"Checking health for sandbox {ship_id} on {self.sb_url}..." - ) - url = f"{self.sb_url}/health" - async with aiohttp.ClientSession() as session: - async with session.get(url) as response: - if response.status == 200: - logger.info(f"Sandbox {ship_id} is healthy") - return - except Exception: - await asyncio.sleep(1) - loop -= 1 - - -class BoxliteBooter(ComputerBooter): - async def boot(self, session_id: str) -> None: - logger.info( - f"Booting(Boxlite) for session: {session_id}, this may take a while..." - ) - random_port = random.randint(20000, 30000) - self.box = boxlite.SimpleBox( - image="soulter/shipyard-ship", - memory_mib=512, - cpus=1, - ports=[ - { - "host_port": random_port, - "guest_port": 8123, - } - ], - ) - await self.box.start() - logger.info(f"Boxlite booter started for session: {session_id}") - self.mocked = MockShipyardSandboxClient( - sb_url=f"http://127.0.0.1:{random_port}" - ) - self._python = ShipyardPythonComponent( - client=self.mocked, # type: ignore - ship_id=self.box.id, - session_id=session_id, - ) - self._shell = ShipyardShellComponent( - client=self.mocked, # type: ignore - ship_id=self.box.id, - session_id=session_id, - ) - self._ship_fs = ShipyardFileSystemComponent( - client=self.mocked, # type: ignore - ship_id=self.box.id, - session_id=session_id, - ) - self._fs = ShipyardFileSystemWrapper( - _shipyard_fs=self._ship_fs, _shipyard_shell=self._shell - ) - - await self.mocked.wait_healthy(self.box.id, session_id) - - async def shutdown(self) -> None: - logger.info(f"Shutting down Boxlite booter for ship: {self.box.id}") - self.box.shutdown() - logger.info(f"Boxlite booter for ship: {self.box.id} stopped") - - @property - def fs(self) -> FileSystemComponent: - return self._fs - - @property - def python(self) -> PythonComponent: - return self._python - - @property - def shell(self) -> ShellComponent: - return self._shell - - async def upload_file(self, path: str, file_name: str) -> dict: - """Upload file to sandbox""" - return await self.mocked.upload_file(path, file_name) diff --git a/astrbot/core/computer/booters/cua.py b/astrbot/core/computer/booters/cua.py deleted file mode 100644 index 151b4c0e04..0000000000 --- a/astrbot/core/computer/booters/cua.py +++ /dev/null @@ -1,878 +0,0 @@ -from __future__ import annotations - -import base64 -import inspect -import shlex -from dataclasses import asdict, dataclass, is_dataclass -from pathlib import Path -from typing import Any - -from astrbot.api import logger - -from ..olayer import FileSystemComponent, GUIComponent, PythonComponent, ShellComponent -from .base import ComputerBooter -from .cua_defaults import CUA_CONFIG_KEYS, CUA_DEFAULT_CONFIG -from .shipyard_search_file_util import search_files_via_shell - -_POSIX_OS_TYPES = {"linux", "darwin", "macos"} - -_CUA_BACKGROUND_LAUNCHER = """ -import subprocess, sys, time - -p = subprocess.Popen( - ["sh", "-lc", sys.argv[1]], - stdin=subprocess.DEVNULL, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - start_new_session=True, -) -sys.stdout.write(str(p.pid) + "\\n") -sys.stdout.flush() -time.sleep(0.2) -code = p.poll() -sys.exit(0 if code is None else code) -""".strip() - - -async def _maybe_await(value: Any) -> Any: - if inspect.isawaitable(value): - return await value - return value - - -def build_cua_booter_kwargs(sandbox_cfg: dict[str, Any]) -> dict[str, Any]: - return { - name: sandbox_cfg.get(config_key, CUA_DEFAULT_CONFIG[name]) - for name, config_key in CUA_CONFIG_KEYS.items() - } - - -async def _write_base64_via_shell( - shell: ShellComponent, - path: str, - data: bytes, -) -> dict[str, Any]: - encoded = base64.b64encode(data).decode("ascii") - decoder = ( - "import base64,pathlib,sys; " - "pathlib.Path(sys.argv[1]).write_bytes(base64.b64decode(sys.stdin.read()))" - ) - return await shell.exec( - f"python3 -c {shlex.quote(decoder)} {shlex.quote(path)} <<'EOF'\n{encoded}\nEOF" - ) - - -@dataclass(slots=True) -class ProcessResult: - stdout: str - stderr: str - exit_code: int | None - success: bool - - -def _maybe_model_dump(value: Any) -> dict[str, Any]: - if isinstance(value, dict): - return value - if is_dataclass(value) and not isinstance(value, type): - return asdict(value) - if hasattr(value, "model_dump"): - dumped = value.model_dump() - if isinstance(dumped, dict): - return dumped - if hasattr(value, "dict"): - dumped = value.dict() - if isinstance(dumped, dict): - return dumped - attr_payload = { - key: getattr(value, key) - for key in ( - "stdout", - "stderr", - "output", - "error", - "returncode", - "return_code", - "exit_code", - "success", - ) - if hasattr(value, key) - } - if attr_payload: - return attr_payload - return {} - - -def _slice_content_by_lines( - content: str, - *, - offset: int | None = None, - limit: int | None = None, -) -> str: - lines = content.splitlines(keepends=True) - start = 0 if offset is None else offset - selected = lines[start:] if limit is None else lines[start : start + limit] - return "".join(selected) - - -def _normalize_process_result(raw: Any) -> ProcessResult: - """Best-effort normalization for the process shapes returned by CUA SDKs.""" - payload = _maybe_model_dump(raw) - if not payload and isinstance(raw, str): - payload = {"stdout": raw} - - def first_text(*keys: str) -> str: - for key in keys: - value = payload.get(key) - if value is not None: - return str(value) - return "" - - stdout = first_text("stdout", "output") - stderr = first_text("stderr", "error") - exit_code = payload.get("exit_code") - if exit_code is None: - exit_code = payload.get("returncode") - if exit_code is None: - exit_code = payload.get("return_code") - if exit_code is None: - exit_code = 0 if not stderr else 1 - success = bool(payload.get("success", not stderr and exit_code in (0, None))) - return ProcessResult( - stdout=stdout, - stderr=stderr, - exit_code=exit_code, - success=success, - ) - - -def _is_missing_python3_error(stderr: str) -> bool: - lowered = stderr.lower() - return "python3" in lowered and ( - "not found" in lowered - or "command not found" in lowered - or "no such file" in lowered - ) - - -def _python3_requirement_error(operation: str, stderr: str) -> str: - return f"CUA {operation} requires python3 in the sandbox image: {stderr}" - - -def _normalize_with_python3_requirement(raw: Any, operation: str) -> ProcessResult: - proc = _normalize_process_result(raw) - if proc.stderr and _is_missing_python3_error(proc.stderr): - return ProcessResult( - stdout=proc.stdout, - stderr=_python3_requirement_error(operation, proc.stderr), - exit_code=proc.exit_code, - success=proc.success, - ) - return proc - - -async def _exec_python3_or_error( - shell: ShellComponent, - code: str, - *, - operation: str, - timeout: int | None = 30, -) -> ProcessResult: - result = await shell.exec(f"python3 - <<'PY'\n{code}\nPY", timeout=timeout) - return _normalize_with_python3_requirement(result, operation) - - -def _is_posix_os_type(os_type: str) -> bool: - return os_type.lower() in _POSIX_OS_TYPES - - -def _posix_fs_error_message(os_type: str) -> str: - return ( - "CUA filesystem shell fallback is only supported for POSIX images; " - f"os_type={os_type!r} does not support the required shell commands." - ) - - -def _non_posix_filesystem_result(path: str, os_type: str) -> dict[str, Any]: - error = _posix_fs_error_message(os_type) - return {"success": False, "path": path, "error": error, "message": error} - - -def _raise_non_posix_filesystem_error(os_type: str) -> None: - raise RuntimeError(_posix_fs_error_message(os_type)) - - -def _resolve_component_method( - component: Any, - method_names: str | tuple[str, ...], -) -> Any | None: - if component is None: - return None - names = (method_names,) if isinstance(method_names, str) else method_names - for method_name in names: - method = getattr(component, method_name, None) - if method is not None: - return method - return None - - -def _missing_component_method_error( - component_name: str, - method_names: str | tuple[str, ...], -) -> RuntimeError: - names = (method_names,) if isinstance(method_names, str) else method_names - candidates = ", ".join(f"{component_name}.{name}" for name in names) - return RuntimeError( - f"CUA sandbox does not provide any of: {candidates}. " - "Please check the installed CUA SDK version and sandbox backend." - ) - - -def _has_component_method(root: Any, component_name: str, method_name: str) -> bool: - component = getattr(root, component_name, None) - return getattr(component, method_name, None) is not None - - -def _resolve_files_components(sandbox: Any) -> tuple[Any, ...]: - components: list[Any] = [] - seen_ids: set[int] = set() - for name in ("files", "filesystem"): - component = getattr(sandbox, name, None) - if component is None: - continue - component_id = id(component) - if component_id in seen_ids: - continue - seen_ids.add(component_id) - components.append(component) - return tuple(components) - - -def _resolve_files_method( - components: tuple[Any, ...], - method_names: str | tuple[str, ...], -) -> Any | None: - for component in components: - method = _resolve_component_method(component, method_names) - if method is not None: - return method - return None - - -def _normalize_native_upload_result(raw: Any, file_name: str) -> dict[str, Any]: - payload = _maybe_model_dump(raw) - if not payload: - return {"success": True, "file_path": file_name} - if "file_path" not in payload and "path" not in payload: - payload["file_path"] = file_name - if "success" not in payload: - payload["success"] = not bool(payload.get("error") or payload.get("stderr")) - return payload - - -class CuaShellComponent(ShellComponent): - def __init__(self, sandbox: Any, os_type: str = "linux") -> None: - self._sandbox = sandbox - self._os_type = os_type.lower() - shell = sandbox.shell - self._exec_raw = getattr(shell, "exec", None) or getattr(shell, "run", None) - if self._exec_raw is None: - raise RuntimeError("CUA sandbox shell must provide `.exec` or `.run`.") - - async def exec( - self, - command: str, - cwd: str | None = None, - env: dict[str, str] | None = None, - timeout: int | None = 30, - shell: bool = True, - background: bool = False, - ) -> dict[str, Any]: - if not shell: - return { - "stdout": "", - "stderr": "error: only shell mode is supported in CUA booter.", - "exit_code": 2, - "success": False, - } - - kwargs: dict[str, Any] = {} - if cwd is not None: - kwargs["cwd"] = cwd - if timeout is not None: - kwargs["timeout"] = timeout - if env: - kwargs["env"] = env - if background: - if not _is_posix_os_type(self._os_type): - return { - "stdout": "", - "stderr": "error: background shell execution is only supported for POSIX CUA images.", - "exit_code": 2, - "success": False, - } - command = _build_cua_background_command(command) - - result = await _maybe_await(self._exec_raw(command, **kwargs)) - proc = ( - _normalize_with_python3_requirement(result, "background execution") - if background - else _normalize_process_result(result) - ) - response = { - "stdout": proc.stdout, - "stderr": proc.stderr, - "exit_code": proc.exit_code, - "success": proc.success, - } - if background: - try: - response["pid"] = int(proc.stdout.strip().splitlines()[-1]) - except Exception: - response["pid"] = None - return response - - -def _build_cua_background_command(command: str) -> str: - return f"python3 -c {shlex.quote(_CUA_BACKGROUND_LAUNCHER)} {shlex.quote(command)}" - - -class CuaPythonComponent(PythonComponent): - def __init__(self, sandbox: Any, os_type: str = "linux") -> None: - self._sandbox = sandbox - self._os_type = os_type - python = getattr(sandbox, "python", None) - self._python_exec = None - if python is not None: - self._python_exec = getattr(python, "exec", None) or getattr( - python, "run", None - ) - - async def exec( - self, - code: str, - kernel_id: str | None = None, - timeout: int = 30, - silent: bool = False, - ) -> dict[str, Any]: - _ = kernel_id - if self._python_exec is not None: - result = await _maybe_await(self._python_exec(code, timeout=timeout)) - proc = _normalize_process_result(result) - else: - shell = CuaShellComponent(self._sandbox, os_type=self._os_type) - proc = await _exec_python3_or_error( - shell, - code, - operation="Python execution fallback", - timeout=timeout, - ) - - output_text = "" if silent else proc.stdout - error_text = proc.stderr - return { - "success": proc.success if not silent else not bool(error_text), - "data": { - "output": {"text": output_text, "images": []}, - "error": error_text, - }, - "output": output_text, - "error": error_text, - } - - -def _write_result(path: str, result: dict[str, Any]) -> dict[str, Any]: - stderr = result.get("stderr", "") - if stderr and _is_missing_python3_error(stderr): - result = { - **result, - "stderr": _python3_requirement_error("filesystem write fallback", stderr), - } - if result.get("stderr") or result.get("success") is False: - return {"success": False, "path": path, **result} - return {"success": True, "path": path, **result} - - -class CuaFileSystemComponent(FileSystemComponent): - def __init__( - self, sandbox: Any, os_type: str = CUA_DEFAULT_CONFIG["os_type"] - ) -> None: - self._shell = CuaShellComponent(sandbox, os_type=os_type) - self._fs_components = _resolve_files_components(sandbox) - self._os_type = os_type.lower() - self._fallback = _PosixShellFileSystem(self._shell, self._os_type) - - async def create_file( - self, - path: str, - content: str = "", - mode: int = 0o644, - ) -> dict[str, Any]: - write_result = await self.write_file(path, content) - if not write_result.get("success"): - return {**write_result, "mode": mode, "mode_applied": False} - return {"success": True, "path": path, "mode": mode, "mode_applied": False} - - async def read_file( - self, - path: str, - encoding: str = "utf-8", - offset: int | None = None, - limit: int | None = None, - ) -> dict[str, Any]: - read_file = _resolve_files_method( - self._fs_components, ("read_file", "read_text") - ) - if read_file is None: - return await self._fallback.read_file(path, encoding, offset, limit) - else: - content = await _maybe_await(read_file(path)) - if isinstance(content, bytes): - content = content.decode(encoding, errors="replace") - return { - "success": True, - "path": path, - "content": _slice_content_by_lines( - str(content), offset=offset, limit=limit - ), - } - - async def write_file( - self, - path: str, - content: str, - mode: str = "w", - encoding: str = "utf-8", - ) -> dict[str, Any]: - _ = mode - write_file = _resolve_files_method( - self._fs_components, ("write_file", "write_text") - ) - if write_file is None: - return await self._fallback.write_file(path, content, mode, encoding) - else: - await _maybe_await(write_file(path, content)) - return {"success": True, "path": path} - - async def delete_file(self, path: str) -> dict[str, Any]: - delete = _resolve_files_method( - self._fs_components, ("delete", "delete_file", "remove") - ) - if delete is None: - return await self._fallback.delete_file(path) - else: - await _maybe_await(delete(path)) - return {"success": True, "path": path} - - async def list_dir( - self, - path: str = ".", - show_hidden: bool = False, - ) -> dict[str, Any]: - list_dir = _resolve_files_method(self._fs_components, ("list_dir", "list")) - if list_dir is not None: - entries = await _maybe_await(list_dir(path)) - return {"success": True, "path": path, "entries": entries} - return await self._fallback.list_dir(path, show_hidden) - - async def search_files( - self, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - ) -> dict[str, Any]: - return await self._fallback.search_files( - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - - async def edit_file( - self, - path: str, - old_string: str, - new_string: str, - replace_all: bool = False, - encoding: str = "utf-8", - ) -> dict[str, Any]: - read_result = await self.read_file(path, encoding=encoding) - if not read_result.get("success"): - return read_result - content = read_result.get("content", "") - occurrences = content.count(old_string) - if occurrences == 0: - return { - "success": False, - "error": "old string not found in file", - "replacements": 0, - } - updated = content.replace(old_string, new_string, -1 if replace_all else 1) - write_result = await self.write_file(path, updated, encoding=encoding) - if not write_result.get("success"): - return write_result - return { - "success": True, - "path": path, - "replacements": occurrences if replace_all else 1, - } - - -class _PosixShellFileSystem(FileSystemComponent): - def __init__(self, shell: CuaShellComponent, os_type: str) -> None: - self._shell = shell - self._os_type = os_type.lower() - - def _ensure_posix(self, path: str) -> dict[str, Any] | None: - if _is_posix_os_type(self._os_type): - return None - return _non_posix_filesystem_result(path, self._os_type) - - async def read_file( - self, - path: str, - encoding: str = "utf-8", - offset: int | None = None, - limit: int | None = None, - ) -> dict[str, Any]: - _ = encoding - if error := self._ensure_posix(path): - return error - result = await self._shell.exec(f"cat {shlex.quote(path)}") - if result.get("stderr"): - return {"success": False, "path": path, "error": result["stderr"]} - return { - "success": True, - "path": path, - "content": _slice_content_by_lines( - str(result.get("stdout", "")), offset=offset, limit=limit - ), - } - - async def write_file( - self, - path: str, - content: str, - mode: str = "w", - encoding: str = "utf-8", - ) -> dict[str, Any]: - _ = mode - if error := self._ensure_posix(path): - return error - result = await _write_base64_via_shell( - self._shell, path, content.encode(encoding) - ) - return _write_result(path, result) - - async def delete_file(self, path: str) -> dict[str, Any]: - if error := self._ensure_posix(path): - return error - result = await self._shell.exec(f"rm -rf {shlex.quote(path)}") - if result.get("stderr"): - return {"success": False, "path": path, "error": result["stderr"]} - return {"success": True, "path": path} - - async def list_dir( - self, - path: str = ".", - show_hidden: bool = False, - ) -> dict[str, Any]: - if error := self._ensure_posix(path): - return error - return await _list_dir_via_shell(self._shell, path, show_hidden) - - async def search_files( - self, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - ) -> dict[str, Any]: - search_path = path or "." - if error := self._ensure_posix(search_path): - return error - return await search_files_via_shell( - self._shell, - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - - -async def _list_dir_via_shell( - shell: CuaShellComponent, - path: str, - show_hidden: bool, -) -> dict[str, Any]: - flags = "-1A" if show_hidden else "-1" - result = await shell.exec(f"ls {flags} {shlex.quote(path)}") - stdout = result.get("stdout", "") - return { - "success": not bool(result.get("stderr")), - "path": path, - "entries": [line for line in stdout.splitlines() if line.strip()], - "error": result.get("stderr", ""), - } - - -class CuaGUIComponent(GUIComponent): - def __init__(self, sandbox: Any) -> None: - self._sandbox = sandbox - mouse = getattr(sandbox, "mouse", None) - keyboard = getattr(sandbox, "keyboard", None) - self._click = _resolve_component_method(mouse, "click") - self._type_text = _resolve_component_method(keyboard, "type") - self._press_key = _resolve_component_method( - keyboard, ("press", "key_press", "press_key") - ) - - async def screenshot(self, path: str | None = None) -> dict[str, Any]: - raw = await self._sandbox.screenshot() - data = _screenshot_to_bytes(raw) - if path: - Path(path).parent.mkdir(parents=True, exist_ok=True) - Path(path).write_bytes(data) - return { - "success": True, - "path": path, - "mime_type": "image/png", - "base64": base64.b64encode(data).decode("ascii"), - } - - async def click(self, x: int, y: int, button: str = "left") -> dict[str, Any]: - if self._click is None: - raise _missing_component_method_error("mouse", "click") - result = await _maybe_await(self._click(x, y, button=button)) - payload = _maybe_model_dump(result) - return {"success": bool(payload.get("success", True)), **payload} - - async def type_text(self, text: str) -> dict[str, Any]: - if self._type_text is None: - raise _missing_component_method_error("keyboard", "type") - result = await _maybe_await(self._type_text(text)) - payload = _maybe_model_dump(result) - return {"success": bool(payload.get("success", True)), **payload} - - async def press_key(self, key: str) -> dict[str, Any]: - if self._press_key is None: - raise _missing_component_method_error( - "keyboard", ("press", "key_press", "press_key") - ) - result = await _maybe_await(self._press_key(key)) - payload = _maybe_model_dump(result) - return {"success": bool(payload.get("success", True)), **payload} - - -def _screenshot_to_bytes(raw: Any) -> bytes: - def from_str(value: str) -> bytes: - if value.startswith("data:image"): - value = value.split(",", 1)[1] - try: - return base64.b64decode(value, validate=True) - except Exception: - candidate = Path(value) - if candidate.is_file(): - return candidate.read_bytes() - return value.encode("utf-8") - - if isinstance(raw, (bytes, bytearray)): - return bytes(raw) - if isinstance(raw, str): - return from_str(raw) - if hasattr(raw, "save"): - import io - - output = io.BytesIO() - raw.save(output, format="PNG") - return output.getvalue() - payload = _maybe_model_dump(raw) - for key in ("data", "base64", "image"): - value = payload.get(key) - if value: - return _screenshot_to_bytes(value) - raise TypeError(f"Unsupported CUA screenshot result: {type(raw)!r}") - - -@dataclass(slots=True) -class _CuaRuntime: - sandbox_cm: Any - sandbox: Any - shell: CuaShellComponent - python: CuaPythonComponent - fs: CuaFileSystemComponent - gui: CuaGUIComponent | None - - -class CuaBooter(ComputerBooter): - def __init__( - self, - image: str = CUA_DEFAULT_CONFIG["image"], - os_type: str = CUA_DEFAULT_CONFIG["os_type"], - ttl: int = CUA_DEFAULT_CONFIG["ttl"], - telemetry_enabled: bool = CUA_DEFAULT_CONFIG["telemetry_enabled"], - local: bool = CUA_DEFAULT_CONFIG["local"], - api_key: str = CUA_DEFAULT_CONFIG["api_key"], - ) -> None: - self.image = image - self.os_type = os_type - self.ttl = ttl - self.telemetry_enabled = telemetry_enabled - self.local = local - self.api_key = api_key - self._runtime: _CuaRuntime | None = None - - async def boot(self, session_id: str) -> None: - _ = session_id - try: - from cua import Image, Sandbox - except ImportError as exc: - raise RuntimeError( - "CUA sandbox support requires the optional `cua` package. " - "Install it with `pip install cua` in the AstrBot environment." - ) from exc - - image_obj = self._build_image(Image) - ephemeral_kwargs = self._build_ephemeral_kwargs(Sandbox.ephemeral) - sandbox_cm = Sandbox.ephemeral(image_obj, **ephemeral_kwargs) - sandbox = await sandbox_cm.__aenter__() - try: - self._runtime = _CuaRuntime( - sandbox_cm=sandbox_cm, - sandbox=sandbox, - shell=CuaShellComponent(sandbox, os_type=self.os_type), - python=CuaPythonComponent(sandbox, os_type=self.os_type), - fs=CuaFileSystemComponent(sandbox, os_type=self.os_type), - gui=CuaGUIComponent(sandbox), - ) - except Exception: - await sandbox_cm.__aexit__(None, None, None) - self._runtime = None - raise - logger.info( - "[Computer] CUA sandbox booted: image=%s, os_type=%s", - self.image, - self.os_type, - ) - - def _build_image(self, image_cls: Any) -> Any: - image_name = (self.image or self.os_type or "linux").strip().lower() - factory = getattr(image_cls, image_name, None) - if callable(factory): - return factory() - os_factory = getattr(image_cls, (self.os_type or "linux").strip().lower(), None) - if callable(os_factory): - return os_factory() - return image_name - - def _build_ephemeral_kwargs(self, ephemeral: Any) -> dict[str, Any]: - try: - parameters = inspect.signature(ephemeral).parameters - except (TypeError, ValueError): - return {} - kwargs: dict[str, Any] = {} - if "ttl" in parameters: - kwargs["ttl"] = self.ttl - if "telemetry_enabled" in parameters: - kwargs["telemetry_enabled"] = self.telemetry_enabled - if "local" in parameters: - kwargs["local"] = self.local - if "api_key" in parameters and self.api_key: - kwargs["api_key"] = self.api_key - return kwargs - - async def shutdown(self) -> None: - if self._runtime is not None: - await self._runtime.sandbox_cm.__aexit__(None, None, None) - self._runtime = None - - @property - def capabilities(self) -> tuple[str, ...] | None: - capabilities = ["python", "shell", "filesystem"] - if self._runtime is None: - return tuple(capabilities) - - sandbox = self._runtime.sandbox - has_screenshot = getattr(sandbox, "screenshot", None) is not None - has_mouse = _has_component_method(sandbox, "mouse", "click") - has_keyboard = _has_component_method(sandbox, "keyboard", "type") - if has_screenshot or has_mouse or has_keyboard: - capabilities.append("gui") - if has_screenshot: - capabilities.append("screenshot") - if has_mouse: - capabilities.append("mouse") - if has_keyboard: - capabilities.append("keyboard") - return tuple(capabilities) - - @property - def fs(self) -> FileSystemComponent: - if self._runtime is None: - raise RuntimeError("CuaBooter is not initialized.") - return self._runtime.fs - - @property - def python(self) -> PythonComponent: - if self._runtime is None: - raise RuntimeError("CuaBooter is not initialized.") - return self._runtime.python - - @property - def shell(self) -> ShellComponent: - if self._runtime is None: - raise RuntimeError("CuaBooter is not initialized.") - return self._runtime.shell - - @property - def gui(self) -> GUIComponent | None: - return None if self._runtime is None else self._runtime.gui - - async def upload_file(self, path: str, file_name: str) -> dict: - local_path = Path(path) - if not local_path.is_file(): - return {"success": False, "error": f"File not found: {path}"} - sandbox = None if self._runtime is None else self._runtime.sandbox - if sandbox is not None and hasattr(sandbox, "upload_file"): - return _maybe_model_dump( - await sandbox.upload_file(str(local_path), file_name) - ) - files_components = () if sandbox is None else _resolve_files_components(sandbox) - upload = _resolve_files_method(files_components, "upload") - if upload is not None: - result = await _maybe_await(upload(str(local_path), file_name)) - return _normalize_native_upload_result(result, file_name) - write_bytes = _resolve_files_method(files_components, "write_bytes") - if write_bytes is not None: - result = await _maybe_await(write_bytes(file_name, local_path.read_bytes())) - return _normalize_native_upload_result(result, file_name) - if not _is_posix_os_type(self.os_type): - return _non_posix_filesystem_result(file_name, self.os_type) - result = await _write_base64_via_shell( - self.shell, file_name, local_path.read_bytes() - ) - return { - "success": not bool(result.get("stderr")), - "file_path": file_name, - **result, - } - - async def download_file(self, remote_path: str, local_path: str) -> None: - sandbox = None if self._runtime is None else self._runtime.sandbox - if sandbox is not None and hasattr(sandbox, "download_file"): - await sandbox.download_file(remote_path, local_path) - return - if not _is_posix_os_type(self.os_type): - _raise_non_posix_filesystem_error(self.os_type) - result = await self.shell.exec(f"base64 {shlex.quote(remote_path)}") - if result.get("stderr"): - raise RuntimeError(result["stderr"]) - Path(local_path).parent.mkdir(parents=True, exist_ok=True) - Path(local_path).write_bytes(base64.b64decode(result.get("stdout", ""))) - - async def available(self) -> bool: - return self._runtime is not None diff --git a/astrbot/core/computer/booters/cua_defaults.py b/astrbot/core/computer/booters/cua_defaults.py deleted file mode 100644 index a36c6e6546..0000000000 --- a/astrbot/core/computer/booters/cua_defaults.py +++ /dev/null @@ -1,18 +0,0 @@ -CUA_DEFAULT_CONFIG = { - "image": "linux", - "os_type": "linux", - "ttl": 3600, - "idle_timeout": 0, - "telemetry_enabled": False, - "local": True, - "api_key": "", -} - -CUA_CONFIG_KEYS = { - "image": "cua_image", - "os_type": "cua_os_type", - "ttl": "cua_ttl", - "telemetry_enabled": "cua_telemetry_enabled", - "local": "cua_local", - "api_key": "cua_api_key", -} diff --git a/astrbot/core/computer/booters/local.py b/astrbot/core/computer/booters/local.py index 1fb7b5cf7a..8ac4dd6a4d 100644 --- a/astrbot/core/computer/booters/local.py +++ b/astrbot/core/computer/booters/local.py @@ -20,7 +20,8 @@ from ..olayer import FileSystemComponent, PythonComponent, ShellComponent from .base import ComputerBooter -from .shipyard_search_file_util import _truncate_long_lines + +_MAX_SEARCH_LINE_COLUMNS = 1000 _BLOCKED_COMMAND_PATTERNS = [ " rm -rf ", @@ -83,6 +84,23 @@ def _decode_shell_output(output: bytes | None) -> str: return _decode_bytes_with_fallback(output, preferred_encoding="utf-8") +def _truncate_long_lines(text: str) -> str: + output_lines: list[str] = [] + for line in text.splitlines(keepends=True): + line_ending = "" + line_body = line + if line.endswith("\r\n"): + line_body = line[:-2] + line_ending = "\r\n" + elif line.endswith("\n") or line.endswith("\r"): + line_body = line[:-1] + line_ending = line[-1] + if len(line_body) > _MAX_SEARCH_LINE_COLUMNS: + line_body = line_body[:_MAX_SEARCH_LINE_COLUMNS] + output_lines.append(f"{line_body}{line_ending}") + return "".join(output_lines) + + @dataclass class LocalShellComponent(ShellComponent): async def exec( diff --git a/astrbot/core/computer/booters/shell_background.py b/astrbot/core/computer/booters/shell_background.py deleted file mode 100644 index 6fe94c133a..0000000000 --- a/astrbot/core/computer/booters/shell_background.py +++ /dev/null @@ -1,18 +0,0 @@ -import shlex - -_BACKGROUND_SPAWN_SCRIPT = ( - "import subprocess, sys; " - "p = subprocess.Popen(" - "['bash', '-lc', sys.argv[1]], " - "stdin=subprocess.DEVNULL, " - "stdout=subprocess.DEVNULL, " - "stderr=subprocess.DEVNULL, " - "start_new_session=True, " - "close_fds=True" - "); " - "print(p.pid)" -) - - -def build_detached_shell_command(command: str) -> str: - return f"python3 -c {shlex.quote(_BACKGROUND_SPAWN_SCRIPT)} {shlex.quote(command)}" diff --git a/astrbot/core/computer/booters/shipyard.py b/astrbot/core/computer/booters/shipyard.py deleted file mode 100644 index a8375544da..0000000000 --- a/astrbot/core/computer/booters/shipyard.py +++ /dev/null @@ -1,249 +0,0 @@ -from __future__ import annotations - -import shlex -from typing import Any - -from shipyard import FileSystemComponent as ShipyardFileSystemComponent -from shipyard import ShipyardClient, Spec - -from astrbot.api import logger - -from ..olayer import FileSystemComponent, PythonComponent, ShellComponent -from .base import ComputerBooter -from .shell_background import build_detached_shell_command -from .shipyard_search_file_util import search_files_via_shell - - -def _maybe_model_dump(value: Any) -> dict[str, Any]: - if isinstance(value, dict): - return value - if hasattr(value, "model_dump"): - dumped = value.model_dump() - if isinstance(dumped, dict): - return dumped - return {} - - -class ShipyardShellWrapper: - def __init__(self, _shipyard_shell: ShellComponent): - self._shell = _shipyard_shell - - async def exec( - self, - command: str, - cwd: str | None = None, - env: dict[str, str] | None = None, - timeout: int | None = 300, - shell: bool = True, - background: bool = False, - ) -> dict[str, Any]: - if not shell: - return { - "stdout": "", - "stderr": "error: only shell mode is supported in shipyard booter.", - "exit_code": 2, - "success": False, - } - - run_command = command - if env: - env_prefix = " ".join( - f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items()) - ) - run_command = f"{env_prefix} {run_command}" - - if background: - run_command = build_detached_shell_command(run_command) - - result = await self._shell.exec( - run_command, - timeout=timeout or 300, - cwd=cwd, - ) - payload = _maybe_model_dump(result) - - stdout = payload.get("output", payload.get("stdout", "")) or "" - stderr = payload.get("error", payload.get("stderr", "")) or "" - exit_code = payload.get("exit_code") - if background: - pid: int | None = None - try: - pid = int(str(stdout).strip().splitlines()[-1]) - except Exception: - pid = None - return { - "pid": pid, - "stdout": ( - f"Command is running in the background. pid={pid}" - if pid is not None - else "Command was submitted in the background." - ), - "stderr": stderr, - "exit_code": exit_code, - "success": bool(payload.get("success", not stderr)), - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "command": payload.get("command"), - } - - return { - "stdout": stdout, - "stderr": stderr, - "exit_code": exit_code, - "success": bool(payload.get("success", not stderr)), - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "command": payload.get("command"), - } - - -class ShipyardFileSystemWrapper: - def __init__( - self, _shipyard_fs: ShipyardFileSystemComponent, _shipyard_shell: ShellComponent - ): - self._fs = _shipyard_fs - self._shell = _shipyard_shell - - async def create_file( - self, path: str, content: str = "", mode: int = 420 - ) -> dict[str, Any]: - return await self._fs.create_file(path=path, content=content, mode=mode) - - async def read_file( - self, - path: str, - encoding: str = "utf-8", - offset: int | None = None, - limit: int | None = None, - ) -> dict[str, Any]: - return await self._fs.read_file( - path=path, encoding=encoding, offset=offset, limit=limit - ) - - async def write_file( - self, path: str, content: str, mode: str = "w", encoding: str = "utf-8" - ) -> dict[str, Any]: - return await self._fs.write_file( - path=path, content=content, mode=mode, encoding=encoding - ) - - async def list_dir( - self, path: str = ".", show_hidden: bool = False - ) -> dict[str, Any]: - return await self._fs.list_dir(path=path, show_hidden=show_hidden) - - async def delete_file(self, path: str) -> dict[str, Any]: - return await self._fs.delete_file(path=path) - - async def search_files( - self, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - ) -> dict[str, Any]: - return await search_files_via_shell( - self._shell, - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - - async def edit_file( - self, - path: str, - old_string: str, - new_string: str, - replace_all: bool = False, - encoding: str = "utf-8", - ) -> dict[str, Any]: - return await self._fs.edit_file( - path=path, - old_string=old_string, - new_string=new_string, - replace_all=replace_all, - encoding=encoding, - ) - - -class ShipyardBooter(ComputerBooter): - def __init__( - self, - endpoint_url: str, - access_token: str, - ttl: int = 3600, - session_num: int = 10, - ) -> None: - self._sandbox_client = ShipyardClient( - endpoint_url=endpoint_url, access_token=access_token - ) - self._ttl = ttl - self._session_num = session_num - - async def boot(self, session_id: str) -> None: - ship = await self._sandbox_client.create_ship( - ttl=self._ttl, - spec=Spec(cpus=1.0, memory="512m"), - max_session_num=self._session_num, - session_id=session_id, - ) - logger.info(f"Got sandbox ship: {ship.id} for session: {session_id}") - self._ship = ship - self._shell = ShipyardShellWrapper(self._ship.shell) - self._fs = ShipyardFileSystemWrapper(self._ship.fs, self._shell) - - async def shutdown(self) -> None: - logger.info("[Computer] Shipyard booter shutdown.") - - @property - def fs(self) -> FileSystemComponent: - return self._fs - - @property - def python(self) -> PythonComponent: - return self._ship.python - - @property - def shell(self) -> ShellComponent: - return self._shell - - async def upload_file(self, path: str, file_name: str) -> dict: - """Upload file to sandbox""" - result = await self._ship.upload_file(path, file_name) - logger.info("[Computer] File uploaded to Shipyard sandbox: %s", file_name) - return result - - async def download_file(self, remote_path: str, local_path: str): - """Download file from sandbox.""" - result = await self._ship.download_file(remote_path, local_path) - logger.info( - "[Computer] File downloaded from Shipyard sandbox: %s -> %s", - remote_path, - local_path, - ) - return result - - async def available(self) -> bool: - """Check if the sandbox is available.""" - try: - ship_id = self._ship.id - data = await self._sandbox_client.get_ship(ship_id) - if not data: - logger.info( - "[Computer] Shipyard sandbox health check: id=%s, healthy=False (no data)", - ship_id, - ) - return False - health = bool(data.get("status", 0) == 1) - logger.info( - "[Computer] Shipyard sandbox health check: id=%s, healthy=%s", - ship_id, - health, - ) - return health - except Exception as e: - logger.error(f"Error checking Shipyard sandbox availability: {e}") - return False diff --git a/astrbot/core/computer/booters/shipyard_neo.py b/astrbot/core/computer/booters/shipyard_neo.py deleted file mode 100644 index dd982960f4..0000000000 --- a/astrbot/core/computer/booters/shipyard_neo.py +++ /dev/null @@ -1,702 +0,0 @@ -from __future__ import annotations - -import asyncio -import os -import shlex -from typing import Any, cast - -from astrbot.api import logger - -from ..olayer import ( - BrowserComponent, - FileSystemComponent, - PythonComponent, - ShellComponent, -) -from .base import ComputerBooter -from .shell_background import build_detached_shell_command -from .shipyard_search_file_util import search_files_via_shell - -try: - from shipyard_neo import BayClient - from shipyard_neo.sandbox import Sandbox -except ImportError: - logger.warning( - "shipyard_neo_sdk is not installed. ShipyardNeoBooter will not work without it." - ) - - -def _maybe_model_dump(value: Any) -> dict[str, Any]: - if isinstance(value, dict): - return value - if hasattr(value, "model_dump"): - dumped = value.model_dump() - if isinstance(dumped, dict): - return dumped - return {} - - -def _slice_content_by_lines( - content: str, - *, - offset: int | None = None, - limit: int | None = None, -) -> str: - lines = content.splitlines(keepends=True) - start = 0 if offset is None else offset - selected = lines[start:] if limit is None else lines[start : start + limit] - return "".join(selected) - - -class NeoPythonComponent(PythonComponent): - def __init__(self, sandbox: Sandbox) -> None: - self._sandbox = sandbox - - async def exec( - self, - code: str, - kernel_id: str | None = None, - timeout: int = 30, - silent: bool = False, - ) -> dict[str, Any]: - _ = kernel_id # Bay runtime does not expose kernel_id in current SDK. - result = await self._sandbox.python.exec(code, timeout=timeout) - payload = _maybe_model_dump(result) - - output_text = payload.get("output", "") or "" - error_text = payload.get("error", "") or "" - data = payload.get("data") if isinstance(payload.get("data"), dict) else {} - rich_output = data.get("output") if isinstance(data.get("output"), dict) else {} - if not isinstance(rich_output.get("images"), list): - rich_output["images"] = [] - if "text" not in rich_output: - rich_output["text"] = output_text - - if silent: - rich_output["text"] = "" - - return { - "success": bool(payload.get("success", error_text == "")), - "data": { - "output": rich_output, - "error": error_text, - }, - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "code": payload.get("code"), - "output": output_text, - "error": error_text, - } - - -class NeoShellComponent(ShellComponent): - def __init__(self, sandbox: Sandbox) -> None: - self._sandbox = sandbox - - async def exec( - self, - command: str, - cwd: str | None = None, - env: dict[str, str] | None = None, - timeout: int | None = 300, - shell: bool = True, - background: bool = False, - ) -> dict[str, Any]: - if not shell: - return { - "stdout": "", - "stderr": "error: only shell mode is supported in shipyard_neo booter.", - "exit_code": 2, - "success": False, - } - - run_command = command - if env: - env_prefix = " ".join( - f"{k}={shlex.quote(str(v))}" for k, v in sorted(env.items()) - ) - run_command = f"{env_prefix} {run_command}" - - if background: - run_command = build_detached_shell_command(run_command) - - result = await self._sandbox.shell.exec( - run_command, - timeout=timeout or 300, - cwd=cwd, - ) - payload = _maybe_model_dump(result) - - stdout = payload.get("output", "") or "" - stderr = payload.get("error", "") or "" - exit_code = payload.get("exit_code") - if background: - pid: int | None = None - try: - pid = int(stdout.strip().splitlines()[-1]) - except Exception: - pid = None - return { - "pid": pid, - "stdout": ( - f"Command is running in the background. pid={pid}" - if pid is not None - else "Command was submitted in the background." - ), - "stderr": stderr, - "exit_code": exit_code, - "success": bool(payload.get("success", not stderr)), - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "command": payload.get("command"), - } - - return { - "stdout": stdout, - "stderr": stderr, - "exit_code": exit_code, - "success": bool(payload.get("success", not stderr)), - "execution_id": payload.get("execution_id"), - "execution_time_ms": payload.get("execution_time_ms"), - "command": payload.get("command"), - } - - -class NeoFileSystemComponent(FileSystemComponent): - def __init__(self, sandbox: Sandbox, shell: ShellComponent) -> None: - self._sandbox = sandbox - self._shell = shell - - async def create_file( - self, - path: str, - content: str = "", - mode: int = 0o644, - ) -> dict[str, Any]: - _ = mode - await self._sandbox.filesystem.write_file(path, content) - return {"success": True, "path": path} - - async def read_file( - self, - path: str, - encoding: str = "utf-8", - offset: int | None = None, - limit: int | None = None, - ) -> dict[str, Any]: - _ = encoding - content = await self._sandbox.filesystem.read_file(path) - return { - "success": True, - "path": path, - "content": _slice_content_by_lines( - content, - offset=offset, - limit=limit, - ), - } - - async def search_files( - self, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - ) -> dict[str, Any]: - return await search_files_via_shell( - self._shell, - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - - async def edit_file( - self, - path: str, - old_string: str, - new_string: str, - replace_all: bool = False, - encoding: str = "utf-8", - ) -> dict[str, Any]: - _ = encoding - content = await self._sandbox.filesystem.read_file(path) - occurrences = content.count(old_string) - if occurrences == 0: - return { - "success": False, - "error": "old string not found in file", - "replacements": 0, - } - if replace_all: - updated = content.replace(old_string, new_string) - replacements = occurrences - else: - updated = content.replace(old_string, new_string, 1) - replacements = 1 - await self._sandbox.filesystem.write_file(path, updated) - return { - "success": True, - "path": path, - "replacements": replacements, - } - - async def write_file( - self, - path: str, - content: str, - mode: str = "w", - encoding: str = "utf-8", - ) -> dict[str, Any]: - _ = mode - _ = encoding - await self._sandbox.filesystem.write_file(path, content) - return {"success": True, "path": path} - - async def delete_file(self, path: str) -> dict[str, Any]: - await self._sandbox.filesystem.delete(path) - return {"success": True, "path": path} - - async def list_dir( - self, - path: str = ".", - show_hidden: bool = False, - ) -> dict[str, Any]: - entries = await self._sandbox.filesystem.list_dir(path) - data = [] - for entry in entries: - item = _maybe_model_dump(entry) - if not show_hidden and str(item.get("name", "")).startswith("."): - continue - data.append(item) - return {"success": True, "path": path, "entries": data} - - -class NeoBrowserComponent(BrowserComponent): - def __init__(self, sandbox: Sandbox) -> None: - self._sandbox = sandbox - - async def exec( - self, - cmd: str, - timeout: int = 30, - description: str | None = None, - tags: str | None = None, - learn: bool = False, - include_trace: bool = False, - ) -> dict[str, Any]: - result = await self._sandbox.browser.exec( - cmd, - timeout=timeout, - description=description, - tags=tags, - learn=learn, - include_trace=include_trace, - ) - return _maybe_model_dump(result) - - async def exec_batch( - self, - commands: list[str], - timeout: int = 60, - stop_on_error: bool = True, - description: str | None = None, - tags: str | None = None, - learn: bool = False, - include_trace: bool = False, - ) -> dict[str, Any]: - result = await self._sandbox.browser.exec_batch( - commands, - timeout=timeout, - stop_on_error=stop_on_error, - description=description, - tags=tags, - learn=learn, - include_trace=include_trace, - ) - return _maybe_model_dump(result) - - async def run_skill( - self, - skill_key: str, - timeout: int = 60, - stop_on_error: bool = True, - include_trace: bool = False, - description: str | None = None, - tags: str | None = None, - ) -> dict[str, Any]: - result = await self._sandbox.browser.run_skill( - skill_key=skill_key, - timeout=timeout, - stop_on_error=stop_on_error, - include_trace=include_trace, - description=description, - tags=tags, - ) - return _maybe_model_dump(result) - - -class ShipyardNeoBooter(ComputerBooter): - """Booter backed by Shipyard Neo (Bay). - - If *endpoint_url* is empty or set to ``"__auto__"``, Bay will be - started automatically as a Docker container (like Boxlite does for - Ship containers). - """ - - AUTO_SENTINEL = "__auto__" - DEFAULT_PROFILE = "python-default" - - def __init__( - self, - endpoint_url: str, - access_token: str, - profile: str = "", - ttl: int = 3600, - ) -> None: - self._endpoint_url = endpoint_url - self._access_token = access_token - self._profile = profile.strip() if profile else "" - self._ttl = ttl - self._client: BayClient | None = None - self._sandbox: Sandbox | None = None - self._bay_manager: Any = None # BayContainerManager when auto-started - self._fs: FileSystemComponent | None = None - self._python: PythonComponent | None = None - self._shell: ShellComponent | None = None - self._browser: BrowserComponent | None = None - - @property - def bay_client(self) -> Any: - return self._client - - @property - def sandbox(self) -> Any: - return self._sandbox - - @property - def capabilities(self) -> tuple[str, ...] | None: - """Sandbox capabilities from the Bay profile. - - Returns an immutable tuple after :meth:`boot`; ``None`` before boot. - """ - if self._sandbox is None: - return None - caps = getattr(self._sandbox, "capabilities", None) - return tuple(caps) if caps is not None else None - - @property - def is_auto_mode(self) -> bool: - """True when Bay should be auto-started.""" - ep = (self._endpoint_url or "").strip() - return not ep or ep == self.AUTO_SENTINEL - - async def boot(self, session_id: str) -> None: - _ = session_id - - # --- Auto-start Bay if needed --- - if self.is_auto_mode: - from .bay_manager import BayContainerManager - - # Clean up previous manager if re-booting - if self._bay_manager is not None: - await self._bay_manager.close_client() - - logger.info("[Computer] Neo auto-start mode: launching Bay container") - self._bay_manager = BayContainerManager() - self._endpoint_url = await self._bay_manager.ensure_running() - await self._bay_manager.wait_healthy() - # Read auto-provisioned credentials - if not self._access_token: - self._access_token = await self._bay_manager.read_credentials() - logger.info("[Computer] Bay auto-started at %s", self._endpoint_url) - - if not self._endpoint_url or not self._access_token: - if self._bay_manager is not None: - raise ValueError( - "Bay container started but credentials could not be read. " - "Ensure Bay generated credentials.json, or set access_token manually." - ) - raise ValueError( - "Shipyard Neo sandbox configuration is incomplete. " - "Set endpoint (default http://127.0.0.1:8114) and access token, " - "or ensure Bay's credentials.json is accessible for auto-discovery." - ) - - self._client = BayClient( - endpoint_url=self._endpoint_url, - access_token=self._access_token, - ) - await self._client.__aenter__() - - # Resolve profile: user-specified > smart selection > default. - # An empty profile means auto-select; any non-empty profile must be - # honoured as an explicit choice, including "python-default". - resolved_profile = await self._resolve_profile(self._client) - - self._sandbox = await self._client.create_sandbox( - profile=resolved_profile, - ttl=self._ttl, - ) - - # --- Readiness gate: wait until sandbox session is READY --- - await self._wait_until_ready(self._sandbox) - - self._shell = NeoShellComponent(self._sandbox) - self._fs = NeoFileSystemComponent(self._sandbox, self._shell) - self._python = NeoPythonComponent(self._sandbox) - - caps = self.capabilities or () - self._browser = ( - NeoBrowserComponent(self._sandbox) if "browser" in caps else None - ) - - logger.info( - "Got Shipyard Neo sandbox: %s (profile=%s, capabilities=%s, auto=%s)", - self._sandbox.id, - resolved_profile, - list(caps), - bool(self._bay_manager), - ) - - async def _wait_until_ready(self, sandbox: Sandbox) -> None: - """Poll sandbox status until READY, or raise on FAILED / timeout. - - Covers both warm-pool hits (near-instant) and cold starts (up to 180s). - On FAILED, EXPIRED, or timeout the sandbox is deleted before raising - so no orphan resources leak on Bay. - """ - READINESS_TIMEOUT = 180 # seconds - POLL_INTERVAL = 2 # seconds - - sandbox_id = sandbox.id - deadline = asyncio.get_running_loop().time() + READINESS_TIMEOUT - - while True: - await sandbox.refresh() - status = getattr(sandbox.status, "value", str(sandbox.status)) - - if status == "ready": - logger.info( - "[Computer] Sandbox %s is ready (profile=%s)", - sandbox_id, - sandbox.profile, - ) - return - - if status in {"failed", "expired"}: - logger.error( - "[Computer] Sandbox %s reached terminal state: %s", - sandbox_id, - status, - ) - try: - await sandbox.delete() - except Exception as del_err: - logger.warning( - "[Computer] Failed to delete failed sandbox %s: %s", - sandbox_id, - del_err, - ) - raise RuntimeError( - f"Sandbox {sandbox_id} is in terminal state: {status}" - ) - - remaining = deadline - asyncio.get_running_loop().time() - if remaining <= 0: - logger.error( - "[Computer] Sandbox %s did not become ready within %ds " - "(last status: %s)", - sandbox_id, - READINESS_TIMEOUT, - status, - ) - try: - await sandbox.delete() - except Exception as del_err: - logger.warning( - "[Computer] Failed to delete timed-out sandbox %s: %s", - sandbox_id, - del_err, - ) - raise TimeoutError( - f"Sandbox {sandbox_id} did not become ready within " - f"{READINESS_TIMEOUT}s (last status: {status})" - ) - - logger.debug( - "[Computer] Sandbox %s status=%s, waiting...", - sandbox_id, - status, - ) - await asyncio.sleep(POLL_INTERVAL) - - async def _resolve_profile(self, client: Any) -> str: - """Pick the best profile for this session. - - Resolution order: - 1. User-specified profile (non-empty) → use as-is. - 2. Query ``GET /v1/profiles`` and pick the profile with the most - capabilities, preferring profiles that include ``"browser"``. - 3. Fall back to :attr:`DEFAULT_PROFILE`. - - Auth errors (401/403) are re-raised immediately — they indicate a - misconfigured token, and silently falling back would just delay the - real failure to ``create_sandbox``. - """ - # User explicitly set a profile → honour it. - if self._profile: - logger.info("[Computer] Using user-specified profile: %s", self._profile) - return self._profile - - # Query Bay for available profiles - from shipyard_neo.errors import ForbiddenError, UnauthorizedError - - try: - profile_list = await client.list_profiles() - profiles = profile_list.items - except (UnauthorizedError, ForbiddenError): - raise # auth errors must not be silenced - except Exception as exc: - logger.warning( - "[Computer] Failed to query Bay profiles, falling back to %s: %s", - self.DEFAULT_PROFILE, - exc, - ) - return self.DEFAULT_PROFILE - - if not profiles: - return self.DEFAULT_PROFILE - - def _score(p: Any) -> tuple[int, int]: - """(has_browser, capability_count) — higher is better.""" - caps = getattr(p, "capabilities", []) or [] - return (1 if "browser" in caps else 0, len(caps)) - - best = max(profiles, key=_score) - chosen = getattr(best, "id", self.DEFAULT_PROFILE) - - if chosen != self.DEFAULT_PROFILE: - caps = getattr(best, "capabilities", []) - logger.info( - "[Computer] Auto-selected profile %s (capabilities=%s)", - chosen, - caps, - ) - - return chosen - - async def shutdown(self, *, delete_sandbox: bool = False) -> None: - if self._client is not None: - sandbox_id = getattr(self._sandbox, "id", "unknown") - - # Delete sandbox on Bay BEFORE closing the HTTP client. - # This is critical for cleanup — calling delete after - # __aexit__ would fail because the httpx session is already - # torn down. - if delete_sandbox and self._sandbox is not None: - try: - logger.info( - "[Computer] Deleting Shipyard Neo sandbox: id=%s", sandbox_id - ) - await self._sandbox.delete() - logger.info( - "[Computer] Shipyard Neo sandbox deleted: id=%s", sandbox_id - ) - except Exception as e: - logger.warning( - "[Computer] Failed to delete sandbox %s (may already be " - "cleaned up by Bay GC): %s", - sandbox_id, - e, - ) - - logger.info( - "[Computer] Shutting down Shipyard Neo sandbox client: id=%s", - sandbox_id, - ) - await self._client.__aexit__(None, None, None) - self._client = None - self._sandbox = None - logger.info( - "[Computer] Shipyard Neo sandbox client shut down: id=%s", sandbox_id - ) - - # NOTE: We intentionally do NOT stop the Bay container here. - # It stays running for reuse by future sessions. The user can - # stop it manually or via ``BayContainerManager.stop()``. - if self._bay_manager is not None: - await self._bay_manager.close_client() - - @property - def fs(self) -> FileSystemComponent: - if self._fs is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - return self._fs - - @property - def python(self) -> PythonComponent: - if self._python is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - return self._python - - @property - def shell(self) -> ShellComponent: - if self._shell is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - return self._shell - - @property - def browser(self) -> BrowserComponent: - if self._browser is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - return self._browser - - async def upload_file(self, path: str, file_name: str) -> dict: - if self._sandbox is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - with open(path, "rb") as f: - content = f.read() - remote_path = file_name.lstrip("/") - await self._sandbox.filesystem.upload(remote_path, content) - logger.info("[Computer] File uploaded to Neo sandbox: %s", remote_path) - return { - "success": True, - "message": "File uploaded successfully", - "file_path": remote_path, - } - - async def download_file(self, remote_path: str, local_path: str) -> None: - if self._sandbox is None: - raise RuntimeError("ShipyardNeoBooter is not initialized.") - content = await self._sandbox.filesystem.download(remote_path.lstrip("/")) - local_dir = os.path.dirname(local_path) - if local_dir: - os.makedirs(local_dir, exist_ok=True) - with open(local_path, "wb") as f: - f.write(cast(bytes, content)) - logger.info( - "[Computer] File downloaded from Neo sandbox: %s -> %s", - remote_path, - local_path, - ) - - async def available(self) -> bool: - if self._sandbox is None: - return False - try: - await self._sandbox.refresh() - status = getattr(self._sandbox.status, "value", str(self._sandbox.status)) - healthy = status not in {"failed", "expired"} - logger.info( - "[Computer] Neo sandbox health check: id=%s, status=%s, healthy=%s", - getattr(self._sandbox, "id", "unknown"), - status, - healthy, - ) - return healthy - except Exception as e: - logger.error(f"Error checking Shipyard Neo sandbox availability: {e}") - return False diff --git a/astrbot/core/computer/booters/shipyard_search_file_util.py b/astrbot/core/computer/booters/shipyard_search_file_util.py deleted file mode 100644 index cdd41de82e..0000000000 --- a/astrbot/core/computer/booters/shipyard_search_file_util.py +++ /dev/null @@ -1,148 +0,0 @@ -from __future__ import annotations - -import shlex -from typing import Any - -from ..olayer import ShellComponent - -_MAX_SEARCH_LINE_COLUMNS = 1000 - - -def _truncate_long_lines(text: str) -> str: - output_lines: list[str] = [] - for line in text.splitlines(keepends=True): - line_ending = "" - line_body = line - if line.endswith("\r\n"): - line_body = line[:-2] - line_ending = "\r\n" - elif line.endswith("\n") or line.endswith("\r"): - line_body = line[:-1] - line_ending = line[-1] - - if len(line_body) > _MAX_SEARCH_LINE_COLUMNS: - line_body = line_body[:_MAX_SEARCH_LINE_COLUMNS] - - output_lines.append(f"{line_body}{line_ending}") - return "".join(output_lines) - - -def _build_rg_command( - *, - pattern: str, - path: str, - glob: str | None, - after_context: int | None, - before_context: int | None, -) -> list[str]: - command = [ - "rg", - "--color=never", - "-n", - "--max-columns", - str(_MAX_SEARCH_LINE_COLUMNS), - "-e", - pattern, - ] - if glob: - command.extend(["-g", glob]) - if after_context is not None: - command.extend(["-A", str(after_context)]) - if before_context is not None: - command.extend(["-B", str(before_context)]) - command.extend(["--", path]) - return command - - -def _build_grep_command( - *, - pattern: str, - path: str, - glob: str | None, - after_context: int | None, - before_context: int | None, -) -> list[str]: - command = ["grep", "-R", "-H", "-n", "-e", pattern] - if glob: - command.append(f"--include={glob}") - if after_context is not None: - command.extend(["-A", str(after_context)]) - if before_context is not None: - command.extend(["-B", str(before_context)]) - command.extend(["--", path]) - return command - - -def _quote_command(command: list[str]) -> str: - return shlex.join(command) - - -def build_search_command( - *, - pattern: str, - path: str, - glob: str | None, - after_context: int | None, - before_context: int | None, -) -> str: - rg_command = _quote_command( - _build_rg_command( - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - ) - grep_command = _quote_command( - _build_grep_command( - pattern=pattern, - path=path, - glob=glob, - after_context=after_context, - before_context=before_context, - ) - ) - return ( - "if command -v rg >/dev/null 2>&1; then " - f"{rg_command}; " - "elif command -v grep >/dev/null 2>&1; then " - f"{grep_command}; " - "else " - "echo 'Neither rg nor grep is available in the sandbox.' >&2; " - "exit 127; " - "fi" - ) - - -async def search_files_via_shell( - shell: ShellComponent, - *, - pattern: str, - path: str | None = None, - glob: str | None = None, - after_context: int | None = None, - before_context: int | None = None, - timeout: int = 30, -) -> dict[str, Any]: - command = build_search_command( - pattern=pattern, - path=path or ".", - glob=glob, - after_context=after_context, - before_context=before_context, - ) - result = await shell.exec(command, timeout=timeout) - stdout = _truncate_long_lines(str(result.get("stdout", "") or "")) - stderr = str(result.get("stderr", "") or "") - exit_code = result.get("exit_code") - if exit_code in (0, None): - return {"success": True, "content": stdout} - if exit_code == 1: - return {"success": True, "content": ""} - return { - "success": False, - "content": "", - "error": stderr or f"command exited with code {exit_code}", - "exit_code": exit_code, - } diff --git a/astrbot/core/computer/computer_client.py b/astrbot/core/computer/computer_client.py index 9be646265e..5984f48c41 100644 --- a/astrbot/core/computer/computer_client.py +++ b/astrbot/core/computer/computer_client.py @@ -1,13 +1,12 @@ import asyncio import json -import os import shutil -import time import uuid -from dataclasses import dataclass from pathlib import Path from astrbot.api import logger +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.provider.register import llm_tools from astrbot.core.skills.skill_manager import SANDBOX_SKILLS_ROOT, SkillManager from astrbot.core.star.context import Context from astrbot.core.utils.astrbot_path import ( @@ -17,74 +16,312 @@ from .booters.base import ComputerBooter from .booters.local import LocalBooter - -session_booter: dict[str, ComputerBooter] = {} -local_booter: ComputerBooter | None = None +from .sandbox_manager import SandboxManager +from .sandbox_models import SandboxStatus +from .sandbox_provider import SandboxProvider +from .sandbox_registry import SandboxRegistry +from .sandbox_tool_binding import mark_tool_as_sandbox_provider_tool + +local_booter: LocalBooter | None = None +sandbox_registry = SandboxRegistry() +sandbox_manager = SandboxManager(registry=sandbox_registry, providers={}) _MANAGED_SKILLS_FILE = ".astrbot_managed_skills.json" +_SANDBOX_SKILLS_SYNC_LOCK = asyncio.Lock() +# Tracks tools registered per provider so core can remove them on unregister. +_provider_tools: dict[str, list[FunctionTool]] = {} -@dataclass(slots=True) -class _CUAIdleState: - expires_at: float - task: asyncio.Task +def _sandbox_provider_info(provider_id: str, provider: SandboxProvider) -> dict: + return { + "provider_id": provider_id, + "capabilities": sorted(getattr(provider, "capabilities", set())), + "tool_names": sorted(getattr(provider, "tool_names", set())), + "system_prompt": str(getattr(provider, "system_prompt", "") or ""), + } -cua_idle_state: dict[str, _CUAIdleState] = {} +def _has_managed_sandboxes_for_provider(provider_id: str) -> bool: + return any( + record.get("managed") and record.get("provider") == provider_id + for record in sandbox_manager.registry.list_sandboxes() + ) -def _get_cua_idle_timeout(config: dict) -> float: - sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) - value = sandbox_cfg.get("cua_idle_timeout", 0) - try: - timeout = float(value) - except (TypeError, ValueError): - return 0.0 - return max(timeout, 0.0) +def register_sandbox_provider( + provider: SandboxProvider, + *, + replace: bool = False, + tools: list[FunctionTool] | None = None, +) -> None: + """Register a plugin-provided sandbox runtime. + + Args: + provider: The sandbox provider instance. + replace: If ``True``, replace an existing provider with the same ID. + tools: Optional list of provider-specific tools to register with the + global LLM tool manager. Core will automatically unregister these + tools when the provider is unregistered. + """ + if not provider.provider_id: + raise ValueError("Sandbox provider_id must be a non-empty string.") + if provider.provider_id in sandbox_manager.providers and not replace: + raise RuntimeError( + f"Sandbox provider {provider.provider_id} is already registered" + ) -def _clear_cua_idle_state(session_id: str) -> None: - state = cua_idle_state.pop(session_id, None) - if state is not None and not state.task.done(): - state.task.cancel() + # Clean up previous tools when replacing. + if replace and provider.provider_id in sandbox_manager.providers: + _unregister_provider_tools(provider.provider_id) + sandbox_manager.providers[provider.provider_id] = provider -def _schedule_cua_idle_cleanup(session_id: str, timeout: float) -> None: - _clear_cua_idle_state(session_id) - if timeout <= 0: - return - expires_at = time.monotonic() + timeout + if tools: + registered: list[FunctionTool] = [] + for tool in tools: + mark_tool_as_sandbox_provider_tool(tool, provider.provider_id) + llm_tools.func_list.append(tool) + registered.append(tool) + _provider_tools[provider.provider_id] = registered + logger.info( + "Sandbox provider %s registered with %d tool(s)", + provider.provider_id, + len(registered), + ) + else: + logger.info("Sandbox provider %s registered", provider.provider_id) - async def _expire_when_idle() -> None: - try: - remaining = expires_at - time.monotonic() - if remaining > 0: - await asyncio.sleep(remaining) - state = cua_idle_state.get(session_id) - if state is None or state.expires_at != expires_at: - return +def unregister_sandbox_provider(provider_id: str, *, force: bool = False) -> None: + if not force and _has_managed_sandboxes_for_provider(provider_id): + raise RuntimeError( + f"Sandbox provider {provider_id} has active managed sandboxes; " + "destroy them or pass force=True before unregistering." + ) + + if force: + # Synchronously clear registry and memory state for this provider's + # sandboxes. Async destroy_booter is best-effort via background task. + _cleanup_provider_sandboxes_sync(provider_id) + + _unregister_provider_tools(provider_id) + sandbox_manager.providers.pop(provider_id, None) - booter = session_booter.get(session_id) + +def _unregister_provider_tools(provider_id: str) -> None: + registered = _provider_tools.pop(provider_id, []) + if registered: + registered_ids = {id(tool) for tool in registered} + llm_tools.func_list = [ + tool for tool in llm_tools.func_list if id(tool) not in registered_ids + ] + from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor + + FunctionToolExecutor.clear_runtime_computer_tools_cache(provider_id) + if registered: + logger.info( + "Unregistered %d tool(s) for sandbox provider %s", + len(registered), + provider_id, + ) + + +def _cleanup_provider_sandboxes_sync(provider_id: str) -> None: + """Synchronous cleanup of a provider's managed sandboxes on unregister. + + Temporary registry records and in-memory state are removed immediately. If + a temporary booter is alive and an event loop is running, an async + destroy_booter task is spawned as a best-effort cleanup. Persistent records + are preserved and their live booters are only shut down to close the current + runtime connection. + """ + import asyncio + + for record in list(sandbox_manager.registry.list_sandboxes()): + if not record.get("managed") or record.get("provider") != provider_id: + continue + sandbox_id = record["sandbox_id"] + if record.get("retention_policy") == "persistent": + booter = sandbox_manager.session_booter.pop(sandbox_id, None) + sandbox_manager.clear_idle_state(sandbox_id) + sandbox_manager.drop_boot_lock(sandbox_id) if booter is not None: try: - await booter.shutdown() - except Exception as shutdown_err: - logger.warning( - "[Computer] Failed to shutdown idle CUA sandbox for session %s: %s", - session_id, - shutdown_err, + loop = asyncio.get_running_loop() + loop.create_task(_safe_shutdown_booter(booter, record)) + except RuntimeError: + pass # no running event loop + continue + booter = sandbox_manager.session_booter.pop(sandbox_id, None) + sandbox_manager.clear_idle_state(sandbox_id) + sandbox_manager.registry.delete_sandbox(sandbox_id) + sandbox_manager.drop_boot_lock(sandbox_id) + if booter is not None: + try: + loop = asyncio.get_running_loop() + provider = sandbox_manager.providers.get(provider_id) + if provider is not None: + loop.create_task(_safe_destroy_booter(provider, booter, record)) + except RuntimeError: + pass # no running event loop + try: + sandbox_manager.registry.save() + except Exception as exc: + logger.warning( + "[Computer] Failed to save registry after force-unregister: %s", + exc, + ) + from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor + + FunctionToolExecutor.clear_runtime_computer_tools_cache(provider_id) + logger.info( + "Force-unregistered sandbox provider %s: sandboxes cleaned up", + provider_id, + ) + + +async def cleanup_sandbox_provider(provider_id: str) -> None: + """Destroy all sandboxes owned by a provider before unregistering it.""" + provider = sandbox_manager.providers.get(provider_id) + removed = 0 + preserved = 0 + handled_sandbox_ids: set[str] = set() + + def _pop_live_booter(sandbox_id: str): + booter = sandbox_manager.session_booter.pop(sandbox_id, None) + sandbox_manager.clear_idle_state(sandbox_id) + sandbox_manager.drop_boot_lock(sandbox_id) + return booter + + for record in list(sandbox_manager.registry.list_sandboxes()): + if not record.get("managed") or record.get("provider") != provider_id: + continue + sandbox_id = record["sandbox_id"] + handled_sandbox_ids.add(sandbox_id) + booter = _pop_live_booter(sandbox_id) + if record.get("retention_policy") == "persistent": + if booter is not None: + await _safe_shutdown_booter(booter, record) + preserved += 1 + continue + if booter is not None and provider is not None: + if not await _safe_destroy_booter(provider, booter, record): + sandbox_manager.session_booter[sandbox_id] = booter + sandbox_manager.registry.update_sandbox_status( + sandbox_id, + "error", + ) + continue + sandbox_manager.registry.delete_sandbox(sandbox_id) + removed += 1 + + for sandbox_id, booter in list(sandbox_manager.session_booter.items()): + booter_provider = getattr(booter, "provider_id", None) + if str(booter_provider or "") != provider_id: + continue + if sandbox_id in handled_sandbox_ids: + continue + record = sandbox_manager.registry.get_sandbox(sandbox_id) or { + "sandbox_id": sandbox_id, + "provider": provider_id, + "managed": True, + "retention_policy": "temporary", + } + sandbox_manager.session_booter.pop(sandbox_id, None) + sandbox_manager.clear_idle_state(sandbox_id) + sandbox_manager.drop_boot_lock(sandbox_id) + if provider is not None: + if not await _safe_destroy_booter(provider, booter, record): + sandbox_manager.session_booter[sandbox_id] = booter + if sandbox_manager.registry.get_sandbox(sandbox_id) is not None: + sandbox_manager.registry.update_sandbox_status( + sandbox_id, + "error", ) - finally: - session_booter.pop(session_id, None) - except asyncio.CancelledError: - raise - finally: - state = cua_idle_state.get(session_id) - if state is not None and state.expires_at == expires_at: - cua_idle_state.pop(session_id, None) + continue + if sandbox_manager.registry.get_sandbox(sandbox_id) is not None: + sandbox_manager.registry.delete_sandbox(sandbox_id) + removed += 1 + try: + await sandbox_manager.save_registry_async() + except Exception as exc: + logger.warning( + "[Computer] Failed to save registry after provider cleanup: %s", + exc, + ) + logger.info( + "Provider sandbox cleanup completed: provider=%s removed_temporary=%d preserved_persistent=%d", + provider_id, + removed, + preserved, + ) + + +def detach_sandbox_provider(provider_id: str) -> None: + """Remove a provider and its registered tools without touching sandboxes.""" + _unregister_provider_tools(provider_id) + sandbox_manager.providers.pop(provider_id, None) + + +async def _safe_destroy_booter( + provider: SandboxProvider, booter: ComputerBooter, record: dict +) -> bool: + try: + await provider.destroy_booter(booter, record) + return True + except Exception as exc: + logger.warning( + "Background destroy_booter failed for sandbox %s: %s", + record.get("sandbox_id"), + exc, + ) + return False + + +async def _safe_shutdown_booter(booter: ComputerBooter, record: dict) -> None: + try: + await booter.shutdown() + except Exception as exc: + logger.warning( + "Background shutdown failed for sandbox %s: %s", + record.get("sandbox_id"), + exc, + ) + + +def get_sandbox_provider_info(provider_id: str) -> dict | None: + provider = sandbox_manager.providers.get(provider_id) + if provider is None: + return None + return _sandbox_provider_info(provider_id, provider) + + +def get_current_sandbox_provider_id(session_id: str) -> str | None: + current_sandbox_id = sandbox_manager.registry.get_current_sandbox_id(session_id) + if not current_sandbox_id: + return None + current_record = sandbox_manager.registry.get_sandbox(current_sandbox_id) + if current_record is None: + return None + if current_record.get("status") in { + SandboxStatus.STOPPING, + SandboxStatus.STOPPED, + SandboxStatus.ERROR, + }: + return None + provider_id = str(current_record.get("provider") or "").strip() + return provider_id or None + - task = asyncio.create_task(_expire_when_idle()) - cua_idle_state[session_id] = _CUAIdleState(expires_at=expires_at, task=task) +def list_sandbox_providers() -> list[dict]: + return [ + _sandbox_provider_info(provider_id, provider) + for provider_id, provider in sorted(sandbox_manager.providers.items()) + ] + + +async def cleanup_managed_sandboxes() -> None: + await sandbox_manager.cleanup_managed_sandboxes() def _list_local_skill_dirs(skills_root: Path) -> list[Path]: @@ -131,65 +368,6 @@ def _normalize_shell_exec_result(result: object) -> dict: return {"exit_code": 0, "stdout": "", "stderr": ""} -def _discover_bay_credentials(endpoint: str) -> str: - """Try to auto-discover Bay API key from credentials.json. - - Search order: - 1. BAY_DATA_DIR env var - 2. Mono-repo relative path: ../pkgs/bay/ (dev layout) - 3. Current working directory - - Returns: - API key string, or empty string if not found. - """ - candidates: list[Path] = [] - - # 1. BAY_DATA_DIR env var - bay_data_dir = os.environ.get("BAY_DATA_DIR") - if bay_data_dir: - candidates.append(Path(bay_data_dir) / "credentials.json") - - # 2. Mono-repo layout: AstrBot/../pkgs/bay/credentials.json - astrbot_root = Path(__file__).resolve().parents[3] # astrbot/core/computer/ → root - candidates.append(astrbot_root.parent / "pkgs" / "bay" / "credentials.json") - - # 3. Current working directory - candidates.append(Path.cwd() / "credentials.json") - - for cred_path in candidates: - if not cred_path.is_file(): - continue - try: - data = json.loads(cred_path.read_text()) - api_key = data.get("api_key", "") - if api_key: - # Optionally verify endpoint matches - cred_endpoint = data.get("endpoint", "") - if ( - cred_endpoint - and endpoint - and cred_endpoint.rstrip("/") != endpoint.rstrip("/") - ): - logger.warning( - "[Computer] credentials.json endpoint mismatch: " - "file=%s, configured=%s — using key anyway", - cred_endpoint, - endpoint, - ) - masked_key = f"{api_key[:4]}..." if len(api_key) >= 6 else "redacted" - logger.info( - "[Computer] Auto-discovered Bay API key from %s (prefix=%s)", - cred_path, - masked_key, - ) - return api_key - except (json.JSONDecodeError, OSError) as exc: - logger.debug("[Computer] Failed to read %s: %s", cred_path, exc) - - logger.debug("[Computer] No Bay credentials.json found in search paths") - return "" - - def _build_python_exec_command(script: str) -> str: return ( "if command -v python3 >/dev/null 2>&1; then PYBIN=python3; " @@ -201,7 +379,7 @@ def _build_python_exec_command(script: str) -> str: ) -def _build_apply_sync_command() -> str: +def _build_apply_sync_command(zip_name: str = "skills.zip") -> str: """Build shell command for sync stage only. This stage mutates sandbox files (managed skill replacement) but does not scan @@ -215,7 +393,7 @@ def _build_apply_sync_command() -> str: from pathlib import Path root = Path({SANDBOX_SKILLS_ROOT!r}) -zip_path = root / "skills.zip" +zip_path = root / {zip_name!r} tmp_extract = Path(f"{{root}}_tmp_extract") managed_file = root / {_MANAGED_SKILLS_FILE!r} @@ -435,16 +613,22 @@ def _decode_sync_payload(stdout: str) -> dict | None: return None -def _update_sandbox_skills_cache(payload: dict | None) -> None: +def _update_sandbox_skills_cache( + payload: dict | None, + provider_id: str | None = None, +) -> None: if not isinstance(payload, dict): return skills = payload.get("skills", []) if not isinstance(skills, list): return - SkillManager().set_sandbox_skills_cache(skills) + SkillManager().set_sandbox_skills_cache(skills, provider_id=provider_id) -async def _apply_skills_to_sandbox(booter: ComputerBooter) -> None: +async def _apply_skills_to_sandbox( + booter: ComputerBooter, + zip_name: str = "skills.zip", +) -> None: """Apply local skill bundle to sandbox filesystem only. This function is intentionally limited to file mutation. Metadata scanning is @@ -452,7 +636,7 @@ async def _apply_skills_to_sandbox(booter: ComputerBooter) -> None: """ logger.info("[Computer] Skill sync phase=apply start") apply_result = _normalize_shell_exec_result( - await booter.shell.exec(_build_apply_sync_command()) + await booter.shell.exec(_build_apply_sync_command(zip_name)) ) if not _shell_exec_succeeded(apply_result): detail = _format_exec_error_detail(apply_result) @@ -480,63 +664,70 @@ async def _scan_sandbox_skills(booter: ComputerBooter) -> dict | None: return payload -async def _sync_skills_to_sandbox(booter: ComputerBooter) -> None: +async def _sync_skills_to_sandbox( + booter: ComputerBooter, + provider_id: str | None = None, +) -> None: """Sync local skills to sandbox and refresh cache. Backward-compatible orchestrator: keep historical behavior while internally splitting into `apply` and `scan` phases. """ - sync_skill_dirs = _collect_sync_skill_dirs() + async with _SANDBOX_SKILLS_SYNC_LOCK: + sync_skill_dirs = _collect_sync_skill_dirs() - temp_dir = Path(get_astrbot_temp_path()) - temp_dir.mkdir(parents=True, exist_ok=True) - zip_base = temp_dir / "skills_bundle" - zip_path = zip_base.with_suffix(".zip") - bundle_root = temp_dir / f"skills_bundle_{uuid.uuid4().hex}" + temp_dir = Path(get_astrbot_temp_path()) + temp_dir.mkdir(parents=True, exist_ok=True) + zip_base = temp_dir / f"skills_bundle_{uuid.uuid4().hex}" + zip_path = zip_base.with_suffix(".zip") + bundle_root = temp_dir / f"{zip_base.name}_contents" + remote_zip_name = f"{zip_base.name}.zip" + remote_zip = Path(SANDBOX_SKILLS_ROOT) / remote_zip_name - try: - if sync_skill_dirs: - if zip_path.exists(): - zip_path.unlink() - if bundle_root.exists(): - shutil.rmtree(bundle_root) - bundle_root.mkdir(parents=True) - for skill_name, skill_dir in sync_skill_dirs: - shutil.copytree(skill_dir, bundle_root / skill_name) - shutil.make_archive(str(zip_base), "zip", str(bundle_root)) - remote_zip = Path(SANDBOX_SKILLS_ROOT) / "skills.zip" - logger.info("Uploading skills bundle to sandbox...") - await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}") - upload_result = await booter.upload_file(str(zip_path), str(remote_zip)) - if not upload_result.get("success", False): - raise RuntimeError("Failed to upload skills bundle to sandbox.") - else: + try: + if sync_skill_dirs: + bundle_root.mkdir(parents=True) + for skill_name, skill_dir in sync_skill_dirs: + shutil.copytree(skill_dir, bundle_root / skill_name) + shutil.make_archive(str(zip_base), "zip", str(bundle_root)) + logger.info("Uploading skills bundle to sandbox...") + await booter.shell.exec(f"mkdir -p {SANDBOX_SKILLS_ROOT}") + upload_result = await booter.upload_file(str(zip_path), str(remote_zip)) + if not upload_result.get("success", False): + raise RuntimeError("Failed to upload skills bundle to sandbox.") + else: + logger.info( + "No local skills found. Keeping sandbox built-ins and refreshing metadata." + ) + await booter.shell.exec( + f"rm -f {SANDBOX_SKILLS_ROOT}/{remote_zip_name}" + ) + + # Keep backward-compatible behavior while splitting lifecycle into two + # observable phases: apply (filesystem mutation) + scan (metadata read). + await _apply_skills_to_sandbox(booter, zip_name=remote_zip_name) + payload = await _scan_sandbox_skills(booter) + _update_sandbox_skills_cache(payload, provider_id=provider_id) + managed = ( + payload.get("managed_skills", []) if isinstance(payload, dict) else [] + ) logger.info( - "No local skills found. Keeping sandbox built-ins and refreshing metadata." + "[Computer] Sandbox skill sync complete: managed=%d", + len(managed), ) - await booter.shell.exec(f"rm -f {SANDBOX_SKILLS_ROOT}/skills.zip") - - # Keep backward-compatible behavior while splitting lifecycle into two - # observable phases: apply (filesystem mutation) + scan (metadata read). - await _apply_skills_to_sandbox(booter) - payload = await _scan_sandbox_skills(booter) - _update_sandbox_skills_cache(payload) - managed = payload.get("managed_skills", []) if isinstance(payload, dict) else [] - logger.info( - "[Computer] Sandbox skill sync complete: managed=%d", - len(managed), - ) - finally: - if bundle_root.exists(): - try: - shutil.rmtree(bundle_root) - except Exception: - logger.warning(f"Failed to remove temp skills bundle: {bundle_root}") - if zip_path.exists(): - try: - zip_path.unlink() - except Exception: - logger.warning(f"Failed to remove temp skills zip: {zip_path}") + finally: + if bundle_root.exists(): + try: + shutil.rmtree(bundle_root) + except Exception: + logger.warning( + f"Failed to remove temp skills bundle: {bundle_root}" + ) + if zip_path.exists(): + try: + zip_path.unlink() + except Exception: + logger.warning(f"Failed to remove temp skills zip: {zip_path}") async def get_booter( @@ -551,126 +742,55 @@ async def get_booter( elif runtime == "none": raise RuntimeError("Sandbox runtime is disabled by configuration.") - sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) - booter_type = sandbox_cfg.get("booter", "shipyard_neo") - cua_idle_timeout = _get_cua_idle_timeout(config) if booter_type == "cua" else 0.0 - - if session_id in session_booter: - booter = session_booter[session_id] - if not await booter.available(): - # Clean up old booter before rebuilding so sandbox resources - # on Bay (containers, volumes, networks) are not leaked. - # Only ShipyardNeoBooter supports delete_sandbox; other booters - # (local, boxlite, cua, etc.) are not backed by a remote sandbox - # manager and don't need it. - try: - if booter_type == "shipyard_neo": - await booter.shutdown(delete_sandbox=True) - else: - await booter.shutdown() - except Exception as shutdown_err: - logger.warning( - "[Computer] Error shutting down stale booter for session %s: %s", - session_id, - shutdown_err, - ) - _clear_cua_idle_state(session_id) - session_booter.pop(session_id, None) - if session_id not in session_booter: - uuid_str = uuid.uuid5(uuid.NAMESPACE_DNS, session_id).hex - logger.info( - f"[Computer] Initializing booter: type={booter_type}, session={session_id}" - ) - if booter_type == "shipyard": - from .booters.shipyard import ShipyardBooter - - ep = sandbox_cfg.get("shipyard_endpoint", "") - token = sandbox_cfg.get("shipyard_access_token", "") - ttl = sandbox_cfg.get("shipyard_ttl", 3600) - max_sessions = sandbox_cfg.get("shipyard_max_sessions", 10) - - client = ShipyardBooter( - endpoint_url=ep, access_token=token, ttl=ttl, session_num=max_sessions - ) - elif booter_type == "shipyard_neo": - from .booters.shipyard_neo import ShipyardNeoBooter - - ep = sandbox_cfg.get("shipyard_neo_endpoint", "") - token = sandbox_cfg.get("shipyard_neo_access_token", "") - ttl = sandbox_cfg.get("shipyard_neo_ttl", 3600) - profile = sandbox_cfg.get("shipyard_neo_profile", "python-default") - - # Auto-discover token from Bay's credentials.json if not configured - if not token: - token = _discover_bay_credentials(ep) - - logger.info( - f"[Computer] Shipyard Neo config: endpoint={ep}, profile={profile}, ttl={ttl}" - ) - client = ShipyardNeoBooter( - endpoint_url=ep, - access_token=token, - profile=profile, - ttl=ttl, - ) - elif booter_type == "cua": - from .booters.cua import CuaBooter, build_cua_booter_kwargs - - cua_kwargs = build_cua_booter_kwargs(sandbox_cfg) - logger.info( - f"[Computer] CUA config: image={cua_kwargs['image']}, " - f"os_type={cua_kwargs['os_type']}, ttl={cua_kwargs['ttl']}" + current_sandbox_id = sandbox_manager.registry.get_current_sandbox_id(session_id) + if current_sandbox_id: + current_record = sandbox_manager.registry.get_sandbox(current_sandbox_id) + if current_record and current_record.get("managed"): + return await sandbox_manager.get_observer_booter_by_id( + current_sandbox_id, + session_id, + require_lease=True, + context=context, ) - client = CuaBooter(**cua_kwargs) - elif booter_type == "boxlite": - from .booters.boxlite import BoxliteBooter - client = BoxliteBooter() - else: - raise ValueError(f"Unknown booter type: {booter_type}") - - try: - await client.boot(uuid_str) - logger.info( - f"[Computer] Sandbox booted successfully: type={booter_type}, session={session_id}" - ) - await _sync_skills_to_sandbox(client) - except Exception as e: - logger.error(f"Error booting sandbox for session {session_id}: {e}") - try: - if booter_type == "shipyard_neo": - await client.shutdown(delete_sandbox=True) - else: - await client.shutdown() - except Exception as shutdown_error: - logger.warning( - "Failed to shutdown sandbox after boot error for session %s: %s", - session_id, - shutdown_error, - ) - _clear_cua_idle_state(session_id) - raise e + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + provider_id = str(sandbox_cfg.get("booter", "")).strip() + if not provider_id: + raise ValueError( + "Sandbox provider is not configured. Install and enable a sandbox provider plugin, then select it in provider_settings.sandbox.booter." + ) - session_booter[session_id] = client - if booter_type == "cua": - _schedule_cua_idle_cleanup(session_id, cua_idle_timeout) - return session_booter[session_id] + logger.info( + f"[Computer] Initializing sandbox provider: provider={provider_id}, session={session_id}" + ) + if provider_id in sandbox_manager.providers: + return await sandbox_manager.get_or_create_booter( + context, + session_id, + provider_id, + ) + raise ValueError( + f"Unknown sandbox provider: {provider_id}. Install and enable a sandbox provider plugin, then select it in provider_settings.sandbox.booter." + ) async def sync_skills_to_active_sandboxes() -> None: """Best-effort skills synchronization for all active sandbox sessions.""" + active_booters = list(sandbox_manager.session_booter.items()) logger.info( - "[Computer] Syncing skills to %d active sandbox(es)", len(session_booter) + "[Computer] Syncing skills to %d active sandbox(es)", len(active_booters) ) - for session_id, booter in list(session_booter.items()): + for sandbox_id, booter in active_booters: + record = sandbox_manager.registry.get_sandbox(sandbox_id) or {} + provider_id = str(record.get("provider") or "").strip() or None try: - if not await booter.available(): + if not await sandbox_manager.booter_available(booter): continue - await _sync_skills_to_sandbox(booter) + await _sync_skills_to_sandbox(booter, provider_id=provider_id) except Exception as e: logger.warning( - "Failed to sync skills to sandbox for session %s: %s", - session_id, + "Failed to sync skills to sandbox for sandbox %s: %s", + sandbox_id, e, ) diff --git a/astrbot/core/computer/sandbox_manager.py b/astrbot/core/computer/sandbox_manager.py new file mode 100644 index 0000000000..3847a0eebc --- /dev/null +++ b/astrbot/core/computer/sandbox_manager.py @@ -0,0 +1,1969 @@ +from __future__ import annotations + +import asyncio +import inspect +import math +import time +import uuid +from dataclasses import dataclass + +from astrbot.api import logger +from astrbot.core.computer.booters.base import ComputerBooter +from astrbot.core.computer.sandbox_models import SandboxRecord, SandboxStatus +from astrbot.core.computer.sandbox_provider import SandboxProvider +from astrbot.core.computer.sandbox_registry import SandboxRegistry +from astrbot.core.computer.sandbox_timeouts import ( + DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS, + expires_at_from_timeout, + get_provider_sandbox_config, + idle_cleanup_at_from_record, + lease_is_active, + resolve_sandbox_timeout, +) +from astrbot.core.star.context import Context + +SANDBOX_LEASE_SECONDS = int(DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS) +MAX_SANDBOX_LEASE_ATTEMPTS = 3 +MAX_IDLE_DESTROY_ATTEMPTS = 3 +SANDBOX_TTL_DESTROY_RETRY_SECONDS = 60 + + +@dataclass(slots=True) +class SandboxIdleState: + expires_at: float + task: asyncio.Task + + +@dataclass(slots=True) +class SandboxExpirationState: + expires_at: float + monotonic_expires_at: float + task: asyncio.Task + + +class SandboxManager: + def __init__( + self, + *, + registry: SandboxRegistry, + providers: dict[str, SandboxProvider], + ) -> None: + self.registry = registry + self.providers = providers + self.session_booter: dict[str, ComputerBooter] = {} + self.idle_state: dict[str, SandboxIdleState] = {} + self.expiration_state: dict[str, SandboxExpirationState] = {} + self.boot_locks: dict[str, asyncio.Lock] = {} + self.created_hook_inflight: set[str] = set() + self.pending_boot_tasks: dict[str, asyncio.Task] = {} + self.pending_destroy_tasks: dict[str, asyncio.Task] = {} + + def _ensure_unique_sandbox_name( + self, sandbox_name: str, *, exclude_sandbox_id: str | None = None + ) -> str: + normalized_name = str(sandbox_name).strip() + for record in self.registry.list_sandboxes(): + if record.get("sandbox_id") == exclude_sandbox_id: + continue + if str(record.get("sandbox_name") or "").strip() == normalized_name: + raise RuntimeError(f"Sandbox name '{normalized_name}' already exists") + return normalized_name + + def _created_sandbox_name(self, sandbox_id: str, sandbox_name: str | None) -> str: + if sandbox_name is None: + return sandbox_id + normalized_name = str(sandbox_name).strip() + if not normalized_name: + return sandbox_id + return self._ensure_unique_sandbox_name(normalized_name) + + def save_registry(self) -> None: + try: + self.registry.save() + except Exception as exc: + logger.warning("[Computer] Failed to save sandbox registry: %s", exc) + raise + + async def save_registry_async(self) -> None: + try: + await self.registry.save_async() + except Exception as exc: + logger.warning("[Computer] Failed to save sandbox registry: %s", exc) + raise + + async def _defer_lifecycle_task_start(self) -> None: + # Let the request that queued this lifecycle work finish before a + # provider boot/destroy path gets a chance to monopolize the event loop. + await asyncio.sleep(0) + + def _sandbox_boot_lock(self, sandbox_id: str) -> asyncio.Lock: + lock = self.boot_locks.get(sandbox_id) + if lock is None: + lock = asyncio.Lock() + self.boot_locks[sandbox_id] = lock + return lock + + def _lease_timeout(self, context: Context | None, session_id: str) -> float: + sandbox_cfg = get_provider_sandbox_config(context, session_id) + return resolve_sandbox_timeout( + sandbox_cfg, + "sandbox_lease_timeout", + aliases=("lease_timeout",), + default=DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS, + ) + + def _idle_timeout(self, context: Context | None, session_id: str) -> float: + sandbox_cfg = get_provider_sandbox_config(context, session_id) + return resolve_sandbox_timeout( + sandbox_cfg, + "sandbox_idle_timeout", + default=0.0, + ) + + def _expires_at( + self, context: Context | None, session_id: str, idle_timeout: float + ) -> float | None: + if idle_timeout > 0: + return None + sandbox_cfg = get_provider_sandbox_config(context, session_id) + ttl = resolve_sandbox_timeout( + sandbox_cfg, + "sandbox_ttl", + default=0.0, + ) + return expires_at_from_timeout(ttl) + + def _sandbox_policy_timeouts( + self, context: Context | None, session_id: str + ) -> tuple[float, float | None]: + idle_timeout = self._idle_timeout(context, session_id) + return idle_timeout, self._expires_at(context, session_id, idle_timeout) + + def _max_sandboxes(self, context: Context | None, session_id: str) -> int: + sandbox_cfg = get_provider_sandbox_config(context, session_id) + try: + max_sandboxes = int(sandbox_cfg.get("max_sandboxes", 10)) + except (TypeError, ValueError): + return 10 + if max_sandboxes < 0: + return 0 + return max_sandboxes + + def _ensure_under_max_sandboxes( + self, context: Context | None, session_id: str + ) -> None: + max_sandboxes = self._max_sandboxes(context, session_id) + if max_sandboxes <= 0: + return + managed_count = sum( + 1 for record in self.registry.list_sandboxes() if record.get("managed") + ) + if managed_count >= max_sandboxes: + raise RuntimeError( + f"Sandbox limit reached. Maximum managed sandboxes: {max_sandboxes}." + ) + + def drop_boot_lock(self, sandbox_id: str) -> None: + self.boot_locks.pop(sandbox_id, None) + + def clear_runtime_state(self, sandbox_id: str) -> None: + self.session_booter.pop(sandbox_id, None) + self.clear_idle_state(sandbox_id) + self.clear_expiration_state(sandbox_id) + self.created_hook_inflight.discard(sandbox_id) + + def clear_runtime_state_and_drop_lock(self, sandbox_id: str) -> None: + self.clear_runtime_state(sandbox_id) + self.drop_boot_lock(sandbox_id) + + def clear_all_runtime_state(self) -> None: + for sandbox_id in list(self.session_booter): + self.clear_runtime_state(sandbox_id) + for sandbox_id in list(self.idle_state): + self.clear_runtime_state_and_drop_lock(sandbox_id) + for sandbox_id in list(self.expiration_state): + self.clear_runtime_state(sandbox_id) + self.boot_locks.clear() + + async def cancel_pending_boot_task(self, sandbox_id: str) -> None: + task = self.pending_boot_tasks.pop(sandbox_id, None) + if task is None: + return + task.cancel() + try: + done, _pending = await asyncio.wait({task}, timeout=1) + if not done: + logger.warning( + "[Computer] Timed out waiting for pending sandbox boot task cancellation: %s", + sandbox_id, + ) + return + await task + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning( + "[Computer] Pending sandbox boot task ended with error for %s: %s", + sandbox_id, + exc, + ) + + async def wait_pending_destroy_task( + self, sandbox_id: str, *, timeout: float | None = 1 + ) -> None: + task = self.pending_destroy_tasks.get(sandbox_id) + if task is None: + return + try: + if timeout is None: + await asyncio.shield(task) + else: + await asyncio.wait_for(asyncio.shield(task), timeout=timeout) + except TimeoutError: + if not task.done(): + logger.warning( + "[Computer] Timed out waiting for pending sandbox destroy task: %s", + sandbox_id, + ) + except asyncio.CancelledError: + pass + except Exception as exc: + logger.warning( + "[Computer] Pending sandbox destroy task ended with error for %s: %s", + sandbox_id, + exc, + ) + finally: + if task.done(): + self.pending_destroy_tasks.pop(sandbox_id, None) + + def get_provider(self, provider_id: str) -> SandboxProvider: + provider = self.providers.get(provider_id) + if provider is None: + raise RuntimeError(f"Provider {provider_id} is not supported") + return provider + + def build_record_payload( + self, + *, + sandbox_id: str, + sandbox_name: str, + session_id: str, + provider_id: str, + idle_timeout: float, + expires_at: float | None, + connect_info: dict, + is_default: bool = False, + status: str = SandboxStatus.RUNNING, + ) -> dict: + return { + "sandbox_id": sandbox_id, + "sandbox_name": sandbox_name, + "provider": provider_id, + "managed": True, + "created_by_astrbot": True, + "owner_user_id": session_id, + "owner_session_id": session_id, + "connect_info": connect_info, + "capabilities": sorted( + getattr(self.get_provider(provider_id), "capabilities", set()) + ), + "tool_names": sorted( + getattr(self.get_provider(provider_id), "tool_names", set()) + ), + "is_default": is_default, + "idle_timeout": idle_timeout, + "expires_at": expires_at, + "status": status, + } + + def new_sandbox_id(self, provider_id: str) -> str: + return f"{provider_id}-{uuid.uuid4().hex[:12]}" + + def get_default_sandbox_id(self, provider_id: str) -> str | None: + default_sandbox_id = self.registry.get_default_sandbox_id(provider_id) + if default_sandbox_id: + record = self.registry.get_sandbox(default_sandbox_id) + if record and record.get("provider") == provider_id: + return default_sandbox_id + for record in self.registry.list_sandboxes(): + if record.get("managed") and record.get("provider") == provider_id: + return record["sandbox_id"] + return None + + async def booter_available(self, booter: ComputerBooter) -> bool: + available = getattr(booter, "available", None) + if available is None: + return False + if getattr(available, "__isabstractmethod__", False): + return False + result = available() if callable(available) else available + if inspect.isawaitable(result): + result = await result + if result is None: + return False + return bool(result) + + def acquire_lease( + self, sandbox_id: str, session_id: str, *, ttl: float | None = None + ) -> bool: + return self.registry.acquire_lease( + sandbox_id=sandbox_id, + session_id=session_id, + user_id=session_id, + ttl=DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS if ttl is None else ttl, + ) + + def sandbox_has_active_lease(self, sandbox_id: str) -> bool: + record = self.registry.get_sandbox(sandbox_id) + if record is None: + return False + return lease_is_active( + record.get("controller_session_id"), + record.get("lease_expires_at"), + ) + + def _release_expired_lease(self, record: dict) -> dict: + sandbox_id = record.get("sandbox_id") + controller_session_id = record.get("controller_session_id") + if not sandbox_id or not controller_session_id: + return record + if lease_is_active(controller_session_id, record.get("lease_expires_at")): + return record + released = self.registry.release_lease(sandbox_id) or record + if self.registry.get_current_sandbox_id(controller_session_id) == sandbox_id: + self.registry.set_current_sandbox_id(controller_session_id, None) + return released + + def sandbox_controlled_by_other_session( + self, sandbox_id: str, session_id: str + ) -> bool: + record = self.registry.get_sandbox(sandbox_id) + if record is None: + return False + controller_session_id = record.get("controller_session_id") + if not controller_session_id or controller_session_id == session_id: + return False + return lease_is_active(controller_session_id, record.get("lease_expires_at")) + + async def _upsert_new_sandbox_record( + self, context: Context, session_id: str, provider_id: str, create_config: dict + ) -> str: + self._ensure_under_max_sandboxes(context, session_id) + provider = self.get_provider(provider_id) + sandbox_id = self.new_sandbox_id(provider_id) + idle_timeout, expires_at = self._sandbox_policy_timeouts(context, session_id) + self.registry.upsert_sandbox( + **self.build_record_payload( + sandbox_id=sandbox_id, + sandbox_name=sandbox_id, + session_id=session_id, + provider_id=provider_id, + idle_timeout=idle_timeout, + expires_at=expires_at, + connect_info=provider.build_connect_info( + sandbox_id, + {**create_config, "sandbox_id": sandbox_id}, + ), + ) + ) + await self.save_registry_async() + return sandbox_id + + def _find_idle_provider_sandbox_id( + self, provider_id: str, *, exclude: set[str] | None = None + ) -> str | None: + excluded = exclude or set() + for record in self.registry.list_sandboxes(): + sandbox_id = record.get("sandbox_id") + if not sandbox_id or sandbox_id in excluded: + continue + if not record.get("managed") or record.get("provider") != provider_id: + continue + if record.get("status") != SandboxStatus.RUNNING: + continue + if self.sandbox_has_active_lease(sandbox_id): + continue + if sandbox_id not in self.session_booter: + continue + return sandbox_id + return None + + @staticmethod + def _sandbox_can_be_bootstrapped(record: dict) -> bool: + status = record.get("status") + if status == SandboxStatus.RUNNING: + return True + return bool( + record.get("retention_policy") == "persistent" + and status == SandboxStatus.UNKNOWN + ) + + async def get_or_create_booter( + self, context: Context, session_id: str, provider_id: str + ) -> ComputerBooter: + provider = self.get_provider(provider_id) + create_config = provider.build_create_config(context, session_id) + idle_timeout, expires_at = self._sandbox_policy_timeouts(context, session_id) + lease_timeout = self._lease_timeout(context, session_id) + + current_sandbox_id = self.registry.get_current_sandbox_id(session_id) + current_record = self.registry.get_sandbox(current_sandbox_id) + excluded_stale_current_ids: set[str] = set() + if current_sandbox_id and ( + current_record is None or current_record.get("provider") != provider_id + ): + if ( + current_record + and current_record.get("controller_session_id") == session_id + ): + self.registry.release_lease(current_sandbox_id) + self.registry.set_current_sandbox_id(session_id, None) + await self.save_registry_async() + current_sandbox_id = None + current_record = None + if current_sandbox_id and current_record: + status = current_record.get("status") + current_controller_session_id = current_record.get("controller_session_id") + if status == SandboxStatus.RUNNING and ( + current_controller_session_id != session_id + or not lease_is_active( + current_controller_session_id, + current_record.get("lease_expires_at"), + ) + ): + self._release_expired_lease(current_record) + self.registry.set_current_sandbox_id(session_id, None) + await self.save_registry_async() + excluded_stale_current_ids.add(current_sandbox_id) + current_sandbox_id = None + current_record = None + status = None + if current_sandbox_id and current_record: + status = current_record.get("status") + if status == SandboxStatus.CREATING: + pending_boot_task = self.pending_boot_tasks.get(current_sandbox_id) + if pending_boot_task is not None: + await asyncio.shield(pending_boot_task) + current_record = self.registry.get_sandbox(current_sandbox_id) + status = current_record.get("status") if current_record else None + if status in { + SandboxStatus.CREATING, + SandboxStatus.RESTORING, + SandboxStatus.STOPPING, + SandboxStatus.ERROR, + }: + if current_record.get("controller_session_id") == session_id: + self.registry.release_lease(current_sandbox_id) + self.registry.set_current_sandbox_id(session_id, None) + await self.save_registry_async() + current_sandbox_id = None + current_record = None + elif ( + current_record.get("retention_policy") == "persistent" + and status == SandboxStatus.UNKNOWN + and current_sandbox_id not in self.session_booter + ): + current_record = await self._revive_persistent_booter_if_needed( + current_record, current_sandbox_id, session_id, context + ) + if ( + current_sandbox_id + and current_record + and current_record.get("provider") == provider_id + and current_sandbox_id in self.session_booter + ): + if not self.acquire_lease( + current_sandbox_id, session_id, ttl=lease_timeout + ): + self.registry.set_current_sandbox_id(session_id, None) + await self.save_registry_async() + else: + booter = self.session_booter[current_sandbox_id] + if await self.booter_available(booter): + self.registry.touch_sandbox(current_sandbox_id) + await self.save_registry_async() + self.schedule_lifecycle_cleanup( + current_sandbox_id, + idle_timeout, + current_record.get("expires_at"), + ) + return booter + self.clear_runtime_state(current_sandbox_id) + self.registry.release_lease(current_sandbox_id) + await self.save_registry_async() + + created_target_record = False + target_sandbox_id = self.get_default_sandbox_id(provider_id) + if target_sandbox_id in excluded_stale_current_ids: + target_sandbox_id = None + target_record = self.registry.get_sandbox(target_sandbox_id) + if ( + target_sandbox_id + and target_record + and target_record.get("provider") == provider_id + and target_record.get("retention_policy") == "persistent" + and target_record.get("status") == SandboxStatus.UNKNOWN + ): + target_record = await self._revive_persistent_booter_if_needed( + target_record, target_sandbox_id, session_id, context + ) + elif target_record and not self._sandbox_can_be_bootstrapped(target_record): + target_sandbox_id = None + + if target_sandbox_id is None: + self._ensure_under_max_sandboxes(context, session_id) + target_sandbox_id = self.new_sandbox_id(provider_id) + created_target_record = True + record = self.registry.upsert_sandbox( + **self.build_record_payload( + sandbox_id=target_sandbox_id, + sandbox_name=target_sandbox_id, + session_id=session_id, + provider_id=provider_id, + idle_timeout=idle_timeout, + expires_at=expires_at, + connect_info=provider.build_connect_info( + target_sandbox_id, + {**create_config, "sandbox_id": target_sandbox_id}, + ), + is_default=True, + ) + ) + self.registry.set_default_sandbox_id(record["sandbox_id"]) + await self.save_registry_async() + + if self.sandbox_controlled_by_other_session(target_sandbox_id, session_id): + reusable_sandbox_id = self._find_idle_provider_sandbox_id( + provider_id, exclude={target_sandbox_id} + ) + if reusable_sandbox_id is not None: + target_sandbox_id = reusable_sandbox_id + created_target_record = False + else: + target_sandbox_id = await self._upsert_new_sandbox_record( + context, session_id, provider_id, create_config + ) + created_target_record = True + + for _attempt in range(MAX_SANDBOX_LEASE_ATTEMPTS): + async with self._sandbox_boot_lock(target_sandbox_id): + target_record = self.registry.get_sandbox(target_sandbox_id) + if target_record and not self._sandbox_can_be_bootstrapped( + target_record + ): + target_sandbox_id = await self._upsert_new_sandbox_record( + context, session_id, provider_id, create_config + ) + created_target_record = True + continue + + if target_sandbox_id in self.session_booter and not self.acquire_lease( + target_sandbox_id, session_id, ttl=lease_timeout + ): + target_sandbox_id = await self._upsert_new_sandbox_record( + context, session_id, provider_id, create_config + ) + created_target_record = True + continue + + if target_sandbox_id in self.session_booter: + booter = self.session_booter[target_sandbox_id] + if await self.booter_available(booter): + break + self.clear_runtime_state(target_sandbox_id) + self.registry.release_lease(target_sandbox_id) + self.registry.update_sandbox_status( + target_sandbox_id, SandboxStatus.UNKNOWN + ) + await self.save_registry_async() + + if not self.acquire_lease( + target_sandbox_id, session_id, ttl=lease_timeout + ): + target_sandbox_id = await self._upsert_new_sandbox_record( + context, session_id, provider_id, create_config + ) + created_target_record = True + continue + + try: + client = await provider.create_booter( + context, session_id, target_sandbox_id, create_config + ) + except Exception: + if created_target_record: + self.registry.delete_sandbox(target_sandbox_id) + else: + self.registry.release_lease(target_sandbox_id) + self.registry.update_sandbox_status( + target_sandbox_id, SandboxStatus.UNKNOWN + ) + self.clear_runtime_state(target_sandbox_id) + await self.save_registry_async() + raise + setattr(client, "sandbox_id", target_sandbox_id) + setattr(client, "provider_id", provider_id) + self.session_booter[target_sandbox_id] = client + break + else: + raise RuntimeError( + "Could not acquire sandbox lease after multiple attempts" + ) + + await self._finalize_created_booter( + provider, + target_sandbox_id, + session_id=session_id, + idle_timeout=idle_timeout, + remove_record_on_failure=created_target_record, + ) + await self._invoke_sandbox_created_hook(provider, target_sandbox_id) + return self.session_booter[target_sandbox_id] + + async def _finalize_created_booter( + self, + provider: SandboxProvider, + sandbox_id: str, + *, + session_id: str | None = None, + idle_timeout: float, + remove_record_on_failure: bool = False, + ) -> None: + """Common post-creation steps: persist, idle cleanup, skill sync, hooks.""" + booter = self.session_booter.get(sandbox_id) + record = self.registry.get_sandbox(sandbox_id) or {} + update_connect_info_after_boot = getattr( + provider, "update_connect_info_after_boot", None + ) + if booter is not None and callable(update_connect_info_after_boot): + connect_info = update_connect_info_after_boot(record, booter) + if connect_info is not None: + self.registry.update_sandbox_config( + sandbox_id, connect_info=connect_info + ) + self.registry.touch_sandbox(sandbox_id) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.RUNNING) + if session_id is not None: + self.registry.set_current_sandbox_id(session_id, sandbox_id) + try: + await self.save_registry_async() + except Exception: + if booter is not None: + try: + await provider.destroy_booter( + booter, self.registry.get_sandbox(sandbox_id) or {} + ) + except Exception as destroy_err: + logger.warning( + "[Computer] Failed to rollback sandbox %s after registry save error: %s", + sandbox_id, + destroy_err, + ) + self.clear_runtime_state(sandbox_id) + if remove_record_on_failure: + self.registry.delete_sandbox(sandbox_id) + else: + self.registry.release_lease(sandbox_id) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + if session_id is not None: + self.registry.set_current_sandbox_id(session_id, None) + raise + record = self.registry.get_sandbox(sandbox_id) or {} + self.schedule_lifecycle_cleanup( + sandbox_id, idle_timeout, record.get("expires_at") + ) + + # Auto-sync skills unless the provider opts out. Best-effort: a sync + # failure is logged but does not destroy the already-created sandbox. + if getattr(provider, "auto_sync_skills", True): + booter = self.session_booter.get(sandbox_id) + if booter is not None and hasattr(booter, "shell"): + try: + await self._sync_skills_to_booter( + booter, + provider_id=getattr(provider, "provider_id", None), + ) + except Exception as sync_err: + logger.warning( + "[Computer] Auto skill sync failed for %s: %s", + sandbox_id, + sync_err, + ) + + async def _invoke_sandbox_created_hook( + self, provider: SandboxProvider, sandbox_id: str + ) -> None: + """Invoke provider's on_sandbox_created hook if present. + + Each sandbox only fires the hook once, guarded by a persistent flag in + the registry record so that dashboard-created sandboxes still receive + the hook when they are first leased via switch/takeover. + + The flag is only set on success so that a transient hook failure can + be retried on the next lease operation. The check-and-set is protected + by the sandbox boot lock to prevent duplicate triggers under concurrent + lease operations. + """ + if not hasattr(provider, "on_sandbox_created"): + async with self._sandbox_boot_lock(sandbox_id): + if not self.registry.has_created_hook_fired(sandbox_id): + self.registry.mark_created_hook_fired(sandbox_id) + await self.save_registry_async() + return + + async with self._sandbox_boot_lock(sandbox_id): + record = self.registry.get_sandbox(sandbox_id) or {} + if ( + record.get("created_hook_fired") + or sandbox_id in self.created_hook_inflight + ): + return + self.created_hook_inflight.add(sandbox_id) + + should_mark_fired = False + try: + await provider.on_sandbox_created(record) + should_mark_fired = True + except Exception as hook_err: + logger.warning( + "[Computer] on_sandbox_created hook failed for %s: %s", + sandbox_id, + hook_err, + ) + return + finally: + async with self._sandbox_boot_lock(sandbox_id): + if should_mark_fired: + if not self.registry.has_created_hook_fired(sandbox_id): + self.registry.mark_created_hook_fired(sandbox_id) + await self.save_registry_async() + self.created_hook_inflight.discard(sandbox_id) + + async def create_sandbox_uncontrolled( + self, + context: Context, + session_id: str, + provider_id: str, + sandbox_name: str | None = None, + ) -> dict: + provider = self.get_provider(provider_id) + sandbox_id = self.new_sandbox_id(provider_id) + sandbox_name = self._created_sandbox_name(sandbox_id, sandbox_name) + self._ensure_under_max_sandboxes(context, session_id) + create_config = provider.build_create_config(context, session_id) + idle_timeout, expires_at = self._sandbox_policy_timeouts(context, session_id) + async with self._sandbox_boot_lock(sandbox_id): + record = self.registry.upsert_sandbox( + **self.build_record_payload( + sandbox_id=sandbox_id, + sandbox_name=sandbox_name, + session_id=session_id, + provider_id=provider_id, + idle_timeout=idle_timeout, + expires_at=expires_at, + connect_info=provider.build_connect_info( + sandbox_name, + {**create_config, "sandbox_id": sandbox_id}, + ), + status=SandboxStatus.CREATING, + ) + ) + try: + client = await provider.create_booter( + context, session_id, sandbox_id, create_config + ) + except asyncio.CancelledError: + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.ERROR) + self.clear_runtime_state(sandbox_id) + await self.save_registry_async() + raise + except Exception: + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.ERROR) + self.clear_runtime_state(sandbox_id) + await self.save_registry_async() + raise + setattr(client, "sandbox_id", sandbox_id) + setattr(client, "provider_id", provider_id) + self.session_booter[sandbox_id] = client + await self._finalize_created_booter( + provider, + sandbox_id, + session_id=None, + idle_timeout=idle_timeout, + remove_record_on_failure=True, + ) + return self.registry.get_sandbox(sandbox_id) or record + + async def create_sandbox_uncontrolled_deferred( + self, + context: Context, + session_id: str, + provider_id: str, + sandbox_name: str | None = None, + ) -> dict: + provider = self.get_provider(provider_id) + sandbox_id = self.new_sandbox_id(provider_id) + sandbox_name = self._created_sandbox_name(sandbox_id, sandbox_name) + self._ensure_under_max_sandboxes(context, session_id) + create_config = provider.build_create_config(context, session_id) + idle_timeout, expires_at = self._sandbox_policy_timeouts(context, session_id) + async with self._sandbox_boot_lock(sandbox_id): + record = self.registry.upsert_sandbox( + **self.build_record_payload( + sandbox_id=sandbox_id, + sandbox_name=sandbox_name, + session_id=session_id, + provider_id=provider_id, + idle_timeout=idle_timeout, + expires_at=expires_at, + connect_info=provider.build_connect_info( + sandbox_name, + {**create_config, "sandbox_id": sandbox_id}, + ), + status=SandboxStatus.CREATING, + ) + ) + await self.save_registry_async() + + task = asyncio.create_task( + self._boot_sandbox_uncontrolled_deferred( + context=context, + session_id=session_id, + provider=provider, + sandbox_id=sandbox_id, + create_config=create_config, + idle_timeout=idle_timeout, + ) + ) + self.pending_boot_tasks[sandbox_id] = task + + return self.registry.get_sandbox(sandbox_id) or record + + async def _boot_sandbox_uncontrolled_deferred( + self, + *, + context: Context, + session_id: str, + provider: SandboxProvider, + sandbox_id: str, + create_config: dict, + idle_timeout: float, + ) -> None: + try: + await self._defer_lifecycle_task_start() + async with self._sandbox_boot_lock(sandbox_id): + current = self.registry.get_sandbox(sandbox_id) + if current is None or current.get("status") != SandboxStatus.CREATING: + return + + try: + client = await provider.create_booter( + context, session_id, sandbox_id, create_config + ) + except asyncio.CancelledError: + raise + except Exception as boot_err: + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.ERROR) + await self.save_registry_async() + logger.warning( + "[Computer] Deferred sandbox boot failed: sandbox_id=%s session_id=%s error=%s", + sandbox_id, + session_id, + boot_err, + ) + return + + current = self.registry.get_sandbox(sandbox_id) + if current is None or current.get("status") != SandboxStatus.CREATING: + try: + cleanup_record = self.registry.get_sandbox(sandbox_id) or {} + await provider.destroy_booter(client, cleanup_record) + except Exception as destroy_err: + logger.warning( + "[Computer] Deferred sandbox cleanup failed after record removal: sandbox_id=%s error=%s", + sandbox_id, + destroy_err, + ) + return + + setattr(client, "sandbox_id", sandbox_id) + setattr(client, "provider_id", provider.provider_id) + self.session_booter[sandbox_id] = client + await self._finalize_created_booter( + provider, + sandbox_id, + session_id=None, + idle_timeout=idle_timeout, + remove_record_on_failure=True, + ) + finally: + self.pending_boot_tasks.pop(sandbox_id, None) + + async def create_sandbox( + self, + context: Context, + session_id: str, + provider_id: str, + sandbox_name: str | None = None, + ) -> dict: + sandbox = await self.create_sandbox_uncontrolled( + context, session_id, provider_id, sandbox_name + ) + sandbox_id = sandbox["sandbox_id"] + lease_timeout = self._lease_timeout(context, session_id) + if not self.acquire_lease(sandbox_id, session_id, ttl=lease_timeout): + provider = self.get_provider(sandbox.get("provider", "")) + await self._destroy_sandbox_cleanup(provider, sandbox_id, sandbox) + raise RuntimeError(f"Sandbox {sandbox_id} is busy") + await self._set_current_sandbox_after_lease(session_id, sandbox_id, sandbox) + provider = self.get_provider(sandbox.get("provider", "")) + # Reset idle cleanup after lease acquisition. The uncontrolled + # creation path already schedules cleanup, but a slow skill-sync or + # short idle_timeout could let the timer expire before the lease is + # acquired. Re-scheduling here guarantees a full idle window. + idle_timeout = sandbox.get("idle_timeout") or 0 + self.schedule_lifecycle_cleanup( + sandbox_id, float(idle_timeout), sandbox.get("expires_at") + ) + await self._invoke_sandbox_created_hook(provider, sandbox_id) + return self.registry.get_sandbox(sandbox_id) or sandbox + + def list_sandboxes(self) -> list[dict]: + records = [] + for record in self.registry.list_sandboxes(): + if not record.get("managed"): + continue + if "booter_type" in record: + record = SandboxRecord.from_dict(record).to_dict() + record = self._release_expired_lease(record) + provider = self.providers.get(record.get("provider")) + updated = dict(record) + updated["capabilities"] = sorted( + getattr(provider, "capabilities", record.get("capabilities", [])) + if provider + else record.get("capabilities", []) + ) + updated["tool_names"] = sorted( + getattr(provider, "tool_names", record.get("tool_names", [])) + if provider + else record.get("tool_names", []) + ) + if self.sandbox_has_active_lease(updated["sandbox_id"]): + updated["idle_cleanup_at"] = None + else: + updated["idle_cleanup_at"] = idle_cleanup_at_from_record( + last_used_at=updated.get("last_used_at"), + idle_timeout=updated.get("idle_timeout"), + ) + records.append(updated) + return records + + async def list_sandboxes_checked(self) -> list[dict]: + changed = False + for record in self.registry.list_sandboxes(): + sandbox_id = record.get("sandbox_id") + if not sandbox_id or not record.get("managed"): + continue + booter = self.session_booter.get(sandbox_id) + if booter is None: + continue + try: + available = await self.booter_available(booter) + except Exception as exc: + logger.warning( + "[Computer] Sandbox health check failed for %s: %s", + sandbox_id, + exc, + ) + available = False + if available: + continue + self.clear_runtime_state_and_drop_lock(sandbox_id) + next_status = ( + SandboxStatus.UNKNOWN + if record.get("retention_policy") == "persistent" + else SandboxStatus.ERROR + ) + self.registry.update_sandbox_status(sandbox_id, next_status) + changed = True + if changed: + await self.save_registry_async() + return self.list_sandboxes() + + def set_default_sandbox(self, sandbox_id: str) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + self.registry.set_default_sandbox_id(sandbox_id) + self.save_registry() + return self.registry.get_sandbox(sandbox_id) or record + + def update_sandbox_config( + self, + sandbox_id: str, + *, + sandbox_name: str | None = None, + idle_timeout: int | float | None, + expires_at: int | float | None, + retention_policy: str, + ) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + provider_id = record.get("provider", "") + provider = self.providers.get(provider_id) + if retention_policy not in {"temporary", "persistent"}: + raise RuntimeError("retention_policy must be temporary or persistent") + if retention_policy == "persistent" and provider is None: + raise RuntimeError(f"Provider {provider_id} is not available") + if ( + retention_policy == "persistent" + and provider is not None + and not getattr(provider, "supports_persistent_reconnect", False) + ): + raise RuntimeError( + f"Provider {record.get('provider')} does not support persistent sandboxes" + ) + if retention_policy == "persistent": + idle_timeout = None + expires_at = None + elif idle_timeout and float(idle_timeout) > 0: + expires_at = None + updates = { + "idle_timeout": idle_timeout, + "expires_at": expires_at, + "retention_policy": retention_policy, + } + if sandbox_name is not None: + normalized_name = str(sandbox_name).strip() + if not normalized_name: + raise ValueError("sandbox_name must be a non-empty string") + normalized_name = self._ensure_unique_sandbox_name( + normalized_name, exclude_sandbox_id=sandbox_id + ) + updates["sandbox_name"] = normalized_name + if provider is not None: + updates["connect_info"] = provider.update_connect_info( + record, + sandbox_name=normalized_name, + ) + updated = self.registry.update_sandbox_config(sandbox_id, **updates) + if retention_policy == "persistent": + self.clear_idle_state(sandbox_id) + self.clear_expiration_state(sandbox_id) + else: + self.schedule_lifecycle_cleanup( + sandbox_id, float(idle_timeout or 0), expires_at + ) + self.save_registry() + return updated or record + + def set_sandbox_retention_policy( + self, + context: Context | None, + session_id: str, + sandbox_id: str, + retention_policy: str, + *, + sandbox_name: str | None = None, + ) -> dict: + idle_timeout: float | None + expires_at: float | None + if retention_policy == "persistent": + idle_timeout = None + expires_at = None + else: + idle_timeout, expires_at = self._sandbox_policy_timeouts( + context, session_id + ) + return self.update_sandbox_config( + sandbox_id, + sandbox_name=sandbox_name, + idle_timeout=idle_timeout, + expires_at=expires_at, + retention_policy=retention_policy, + ) + + async def _revive_persistent_booter_if_needed( + self, + record: dict, + sandbox_id: str, + session_id: str | None, + context: Context | None, + ) -> dict: + if ( + context is None + or record.get("retention_policy") != "persistent" + or record.get("status") + not in {SandboxStatus.RUNNING, SandboxStatus.UNKNOWN} + ): + return record + + provider = self.get_provider(record.get("provider", "")) + if not getattr(provider, "supports_persistent_reconnect", False): + return record + + create_session_id = str( + record.get("owner_session_id") or session_id or "dashboard" + ) + create_config = provider.build_create_config(context, create_session_id) + connect_info = record.get("connect_info") or {} + create_config = { + **create_config, + "persistent_name": str( + connect_info.get("persistent_name") or sandbox_id + ).strip(), + "resume": True, + } + existing_runtime_id = connect_info.get("sandbox_id") + if existing_runtime_id: + create_config["sandbox_id"] = existing_runtime_id + existing_host_port = connect_info.get("host_port") + if existing_host_port: + create_config["host_port"] = existing_host_port + + async with self._sandbox_boot_lock(sandbox_id): + current = self.registry.get_sandbox(sandbox_id) + booter = self.session_booter.get(sandbox_id) + if ( + booter is None + and current is not None + and current.get("status") + in { + SandboxStatus.RUNNING, + SandboxStatus.UNKNOWN, + } + ): + previous_status = current.get("status") or SandboxStatus.UNKNOWN + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.RESTORING) + await self.save_registry_async() + try: + client = await provider.create_booter( + context, + create_session_id, + sandbox_id, + create_config, + ) + except asyncio.CancelledError: + latest = self.registry.get_sandbox(sandbox_id) + if ( + latest is not None + and latest.get("status") == SandboxStatus.RESTORING + ): + self.registry.update_sandbox_status(sandbox_id, previous_status) + await self.save_registry_async() + raise + except Exception: + latest = self.registry.get_sandbox(sandbox_id) + if ( + latest is not None + and latest.get("status") == SandboxStatus.RESTORING + ): + self.registry.update_sandbox_status(sandbox_id, previous_status) + await self.save_registry_async() + raise + setattr(client, "sandbox_id", sandbox_id) + setattr(client, "provider_id", provider.provider_id) + self.session_booter[sandbox_id] = client + await self._finalize_created_booter( + provider, + sandbox_id, + session_id=None, + idle_timeout=( + 0 + if record.get("retention_policy") == "persistent" + else self._idle_timeout(context, create_session_id) + ), + ) + return self.registry.get_sandbox(sandbox_id) or record + + async def switch_current_sandbox_checked( + self, session_id: str, sandbox_id: str, context: Context | None = None + ) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + record = await self._revive_persistent_booter_if_needed( + record, sandbox_id, session_id, context + ) + booter = self.session_booter.get(sandbox_id) + if booter is None: + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + if not await self.booter_available(booter): + self.session_booter.pop(sandbox_id, None) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + await self.save_registry_async() + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + lease_timeout = self._lease_timeout(context, session_id) + if not self.acquire_lease(sandbox_id, session_id, ttl=lease_timeout): + raise RuntimeError(f"Sandbox {sandbox_id} is busy") + result = await self._set_current_sandbox_after_lease( + session_id, sandbox_id, record + ) + provider = self.get_provider(record.get("provider", "")) + await self._invoke_sandbox_created_hook(provider, sandbox_id) + return result + + async def _set_current_sandbox_after_lease( + self, session_id: str, sandbox_id: str, record: dict + ) -> dict: + previous_sandbox_id = self.registry.get_current_sandbox_id(session_id) + if previous_sandbox_id and previous_sandbox_id != sandbox_id: + previous = self.registry.get_sandbox(previous_sandbox_id) + if previous and previous.get("controller_session_id") == session_id: + self.registry.release_lease(previous_sandbox_id) + self.registry.set_current_sandbox_id(session_id, sandbox_id) + self.registry.touch_sandbox(sandbox_id) + await self.save_registry_async() + return self.registry.get_sandbox(sandbox_id) or record + + def get_current_sandbox(self, session_id: str) -> dict: + sandbox_id = self.registry.get_current_sandbox_id(session_id) + sandbox = self.registry.get_sandbox(sandbox_id) if sandbox_id else None + if sandbox: + controller_session_id = sandbox.get("controller_session_id") + if controller_session_id != session_id or not lease_is_active( + controller_session_id, sandbox.get("lease_expires_at") + ): + self._release_expired_lease(sandbox) + self.registry.set_current_sandbox_id(session_id, None) + self.save_registry() + sandbox_id = None + sandbox = None + if sandbox: + provider = self.providers.get(sandbox.get("provider")) + if provider: + sandbox = dict(sandbox) + sandbox["capabilities"] = sorted( + getattr(provider, "capabilities", sandbox.get("capabilities", [])) + ) + sandbox["tool_names"] = sorted( + getattr(provider, "tool_names", sandbox.get("tool_names", [])) + ) + return { + "current_sandbox_id": sandbox_id, + "sandbox": sandbox, + } + + def release_current_sandbox( + self, session_id: str, sandbox_id: str | None = None + ) -> dict: + target_sandbox_id = sandbox_id or self.registry.get_current_sandbox_id( + session_id + ) + if target_sandbox_id is None: + raise RuntimeError("No current sandbox") + record = self.registry.get_sandbox(target_sandbox_id) + if record is None: + raise RuntimeError(f"Sandbox {target_sandbox_id} not found") + controller_session_id = record.get("controller_session_id") + if ( + controller_session_id + and controller_session_id != session_id + and self.sandbox_has_active_lease(target_sandbox_id) + ): + raise RuntimeError( + f"Sandbox {target_sandbox_id} is controlled by another session" + ) + released = self.registry.release_lease(target_sandbox_id) or record + if self.registry.get_current_sandbox_id(session_id) == target_sandbox_id: + self.registry.set_current_sandbox_id(session_id, None) + self.save_registry() + return released + + def force_release_sandbox(self, sandbox_id: str) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None: + raise RuntimeError(f"Sandbox {sandbox_id} not found") + controller_session_id = record.get("controller_session_id") + released = self.registry.release_lease(sandbox_id) or record + if controller_session_id: + if ( + self.registry.get_current_sandbox_id(controller_session_id) + == sandbox_id + ): + self.registry.set_current_sandbox_id(controller_session_id, None) + self.save_registry() + return released + + async def renew_current_sandbox_lease( + self, + session_id: str, + ttl_seconds: int | float | None = None, + context: Context | None = None, + ) -> dict: + current = self.get_current_sandbox(session_id) + sandbox_id = current.get("current_sandbox_id") + if sandbox_id is None: + raise RuntimeError("No current sandbox") + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + status = record.get("status") + if status == SandboxStatus.CREATING: + raise RuntimeError(f"Sandbox {sandbox_id} is still being created") + if status == SandboxStatus.RESTORING: + raise RuntimeError(f"Sandbox {sandbox_id} is being restored") + if status == SandboxStatus.STOPPING: + raise RuntimeError(f"Sandbox {sandbox_id} is being destroyed") + if status == SandboxStatus.STOPPED: + raise RuntimeError(f"Sandbox {sandbox_id} has been destroyed") + if status == SandboxStatus.ERROR: + raise RuntimeError( + f"Sandbox {sandbox_id} encountered an error during creation" + ) + if status != SandboxStatus.RUNNING: + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + booter = self.session_booter.get(sandbox_id) + if booter is None: + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + if not await self.booter_available(booter): + self.session_booter.pop(sandbox_id, None) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + await self.save_registry_async() + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + controller_session_id = record.get("controller_session_id") + if controller_session_id and controller_session_id != session_id: + raise RuntimeError(f"Sandbox {sandbox_id} is controlled by another session") + ttl = ( + self._lease_timeout(context, session_id) + if ttl_seconds is None + else float(ttl_seconds) + ) + if not math.isfinite(ttl): + raise RuntimeError("ttl_seconds must be finite") + if ttl < 0: + raise RuntimeError("ttl_seconds must be non-negative") + if not self.acquire_lease(sandbox_id, session_id, ttl=ttl): + raise RuntimeError(f"Sandbox {sandbox_id} is busy") + self.registry.touch_sandbox(sandbox_id) + self.save_registry() + return self.registry.get_sandbox(sandbox_id) or record + + async def takeover_sandbox( + self, + session_id: str, + sandbox_id: str, + context: Context | None = None, + ) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + record = await self._revive_persistent_booter_if_needed( + record, sandbox_id, session_id, context + ) + booter = self.session_booter.get(sandbox_id) + status = record.get("status") + if booter is None: + if status == SandboxStatus.CREATING: + raise RuntimeError(f"Sandbox {sandbox_id} is still being created") + if status == SandboxStatus.RESTORING: + raise RuntimeError(f"Sandbox {sandbox_id} is being restored") + if status == SandboxStatus.STOPPING: + raise RuntimeError(f"Sandbox {sandbox_id} is being destroyed") + if status == SandboxStatus.STOPPED: + raise RuntimeError(f"Sandbox {sandbox_id} has been destroyed") + if status == SandboxStatus.ERROR: + raise RuntimeError( + f"Sandbox {sandbox_id} encountered an error during creation" + ) + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + if not await self.booter_available(booter): + self.clear_runtime_state(sandbox_id) + next_status = ( + SandboxStatus.UNKNOWN + if record.get("retention_policy") == "persistent" + else SandboxStatus.ERROR + ) + self.registry.update_sandbox_status(sandbox_id, next_status) + await self.save_registry_async() + raise RuntimeError( + f"Sandbox {sandbox_id} is unavailable (booter health check failed)" + ) + previous_controller_session_id = record.get("controller_session_id") + updated = ( + self.registry.takeover_lease( + sandbox_id=sandbox_id, + session_id=session_id, + user_id=session_id, + ttl=self._lease_timeout(context, session_id), + ) + or record + ) + updated = await self._set_current_sandbox_after_lease( + session_id, sandbox_id, updated + ) + if ( + previous_controller_session_id + and previous_controller_session_id != session_id + and self.registry.get_current_sandbox_id(previous_controller_session_id) + == sandbox_id + ): + self.registry.set_current_sandbox_id(previous_controller_session_id, None) + await self.save_registry_async() + provider = self.get_provider(record.get("provider", "")) + await self._invoke_sandbox_created_hook(provider, sandbox_id) + return updated + + async def _destroy_sandbox_cleanup( + self, + provider: SandboxProvider, + sandbox_id: str, + record: dict, + ) -> None: + destroy_err: Exception | None = None + async with self._sandbox_boot_lock(sandbox_id): + current = self.registry.get_sandbox(sandbox_id) or record + booter = self.session_booter.get(sandbox_id) + if booter is not None: + try: + await provider.destroy_booter(booter, current) + except Exception as exc: + destroy_err = exc + logger.warning( + "[Computer] destroy_booter failed for %s: %s", + sandbox_id, + exc, + ) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.ERROR) + await self.save_registry_async() + else: + self.clear_runtime_state(sandbox_id) + if destroy_err is None: + self.registry.delete_sandbox(sandbox_id) + await self.save_registry_async() + + self.drop_boot_lock(sandbox_id) + + if destroy_err is not None: + raise destroy_err + + if hasattr(provider, "on_sandbox_destroyed"): + try: + await provider.on_sandbox_destroyed(record) + except Exception as hook_err: + logger.warning( + "[Computer] on_sandbox_destroyed hook failed for %s: %s", + sandbox_id, + hook_err, + ) + + async def destroy_sandbox(self, session_id: str, sandbox_id: str) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + if record.get("status") == SandboxStatus.STOPPING: + return record + controller_session_id = record.get("controller_session_id") + if ( + controller_session_id + and controller_session_id != session_id + and self.sandbox_has_active_lease(sandbox_id) + ): + raise RuntimeError(f"Sandbox {sandbox_id} is controlled by another session") + try: + provider = self.get_provider(record.get("provider", "")) + except RuntimeError: + if record.get("retention_policy") != "persistent": + raise + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + await self.save_registry_async() + return record + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.STOPPING) + await self.save_registry_async() + await self.cancel_pending_boot_task(sandbox_id) + await self._destroy_sandbox_cleanup(provider, sandbox_id, record) + return record + + async def destroy_sandbox_deferred(self, session_id: str, sandbox_id: str) -> dict: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + if record.get("status") == SandboxStatus.STOPPING: + return record + controller_session_id = record.get("controller_session_id") + if ( + controller_session_id + and controller_session_id != session_id + and self.sandbox_has_active_lease(sandbox_id) + ): + raise RuntimeError(f"Sandbox {sandbox_id} is controlled by another session") + try: + provider = self.get_provider(record.get("provider", "")) + except RuntimeError: + if record.get("retention_policy") != "persistent": + raise + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + await self.save_registry_async() + return record + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.STOPPING) + await self.save_registry_async() + + async def _run_destroy_cleanup() -> None: + try: + await self._defer_lifecycle_task_start() + await self.cancel_pending_boot_task(sandbox_id) + await self._destroy_sandbox_cleanup(provider, sandbox_id, record) + finally: + self.pending_destroy_tasks.pop(sandbox_id, None) + + task = asyncio.create_task(_run_destroy_cleanup()) + self.pending_destroy_tasks[sandbox_id] = task + return self.registry.get_sandbox(sandbox_id) or record + + async def get_observer_booter_by_id( + self, + sandbox_id: str, + session_id: str | None = None, + *, + require_lease: bool = True, + context: Context | None = None, + ) -> ComputerBooter: + record = self.registry.get_sandbox(sandbox_id) + if record is None or not record.get("managed"): + raise RuntimeError(f"Sandbox {sandbox_id} not found") + controlled_by_other = bool( + session_id + and self.sandbox_controlled_by_other_session(sandbox_id, session_id) + ) + if controlled_by_other and require_lease: + raise RuntimeError(f"Sandbox {sandbox_id} is controlled by another session") + booter = self.session_booter.get(sandbox_id) + record = await self._revive_persistent_booter_if_needed( + record, sandbox_id, session_id, context + ) + booter = self.session_booter.get(sandbox_id) + status = record.get("status") + if booter is None: + if status == SandboxStatus.CREATING: + raise RuntimeError(f"Sandbox {sandbox_id} is still being created") + if status == SandboxStatus.RESTORING: + raise RuntimeError(f"Sandbox {sandbox_id} is being restored") + if status == SandboxStatus.STOPPING: + raise RuntimeError(f"Sandbox {sandbox_id} is being destroyed") + if status == SandboxStatus.STOPPED: + raise RuntimeError(f"Sandbox {sandbox_id} has been destroyed") + if status == SandboxStatus.ERROR: + raise RuntimeError( + f"Sandbox {sandbox_id} encountered an error during creation" + ) + raise RuntimeError(f"Sandbox {sandbox_id} is not running") + if not await self.booter_available(booter): + self.session_booter.pop(sandbox_id, None) + next_status = ( + SandboxStatus.UNKNOWN + if record.get("retention_policy") == "persistent" + else SandboxStatus.ERROR + ) + self.registry.update_sandbox_status(sandbox_id, next_status) + await self.save_registry_async() + raise RuntimeError( + f"Sandbox {sandbox_id} is unavailable (booter health check failed)" + ) + if require_lease and session_id: + lease_timeout = self._lease_timeout(context, session_id) + if not self.acquire_lease(sandbox_id, session_id, ttl=lease_timeout): + raise RuntimeError(f"Sandbox {sandbox_id} is busy") + record = self.registry.get_sandbox(sandbox_id) or record + # Only touch lifecycle when the caller actually holds the lease (or + # the sandbox is unclaimed). Pure observer access must not reset + # idle timers for sandboxes controlled by other sessions. + if session_id and record.get("controller_session_id") == session_id: + self.registry.touch_sandbox(sandbox_id) + await self.save_registry_async() + idle_timeout = record.get("idle_timeout") or 0 + self.schedule_lifecycle_cleanup( + sandbox_id, float(idle_timeout), record.get("expires_at") + ) + return booter + + async def reconcile_on_startup(self) -> None: + for sandbox_id in list(self.pending_boot_tasks): + await self.cancel_pending_boot_task(sandbox_id) + for sandbox_id in list(self.pending_destroy_tasks): + await self.wait_pending_destroy_task(sandbox_id, timeout=None) + self.registry.load() + self.registry.reconcile_startup() + self.clear_all_runtime_state() + + # Validate persistent sandbox records against provider reality. + # If a provider reports that its persistent sandbox no longer exists + # externally, remove the stale registry record so the dashboard does + # not show ghost entries. + for record in list(self.registry.list_sandboxes()): + if record.get("retention_policy") != "persistent": + continue + try: + provider = self.get_provider(record.get("provider", "")) + except RuntimeError: + sandbox_id = record["sandbox_id"] + logger.info( + "[Computer] Provider for persistent sandbox %s is unavailable; keeping registry record", + sandbox_id, + ) + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + continue + if not getattr(provider, "supports_persistent_reconnect", False): + continue + check_exists = getattr(provider, "check_persistent_sandbox_exists", None) + if check_exists is None: + continue + try: + exists = await check_exists(record) + except Exception as exc: + logger.warning( + "[Computer] Failed to check persistent sandbox %s existence: %s", + record.get("sandbox_id"), + exc, + ) + continue + if not exists: + sandbox_id = record["sandbox_id"] + if not getattr(provider, "prune_missing_persistent_records", False): + logger.info( + "[Computer] Persistent sandbox %s was not confirmed externally; keeping registry record as unknown", + sandbox_id, + ) + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.UNKNOWN + ) + continue + logger.info( + "[Computer] Persistent sandbox %s no longer exists externally; removing registry record", + sandbox_id, + ) + self.clear_runtime_state_and_drop_lock(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + + await self.save_registry_async() + + async def restore_persistent_sandboxes( + self, + context: Context, + *, + per_sandbox_timeout: float | None = None, + ) -> tuple[int, int]: + restored = 0 + deleted = 0 + for record in self.registry.list_sandboxes(): + sandbox_id = record["sandbox_id"] + if not record.get("managed"): + continue + if record.get("retention_policy") != "persistent": + continue + if record.get("status") not in { + SandboxStatus.RUNNING, + SandboxStatus.UNKNOWN, + }: + continue + try: + restore_coro = self._revive_persistent_booter_if_needed( + record=record, + sandbox_id=sandbox_id, + session_id=str(record.get("owner_session_id") or "dashboard"), + context=context, + ) + if per_sandbox_timeout is None: + await restore_coro + else: + await asyncio.wait_for(restore_coro, timeout=per_sandbox_timeout) + restored += 1 + except asyncio.TimeoutError: + self.session_booter.pop(sandbox_id, None) + self.clear_idle_state(sandbox_id) + self.registry.update_sandbox_status(sandbox_id, SandboxStatus.UNKNOWN) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + deleted += 1 + logger.warning( + "[Computer] Persistent sandbox restore timed out; keeping registry record as unknown: %s", + sandbox_id, + ) + except Exception as exc: + logger.warning( + "[Computer] Failed to restore persistent sandbox %s: %s", + sandbox_id, + exc, + ) + return restored, deleted + + async def cleanup_managed_sandboxes(self) -> None: + for sandbox_id in list(self.pending_boot_tasks): + await self.cancel_pending_boot_task(sandbox_id) + for sandbox_id in list(self.pending_destroy_tasks): + await self.wait_pending_destroy_task(sandbox_id, timeout=None) + managed_records = [ + record + for record in self.list_sandboxes() + if record["sandbox_id"] not in self.pending_destroy_tasks + ] + for record in managed_records: + sandbox_id = record["sandbox_id"] + if record.get("retention_policy") == "persistent": + booter = self.session_booter.get(sandbox_id) + if booter is not None: + try: + await booter.shutdown() + except Exception as shutdown_err: + logger.warning( + "[Computer] Failed to close persistent sandbox runtime %s: %s", + sandbox_id, + shutdown_err, + ) + self.clear_runtime_state_and_drop_lock(sandbox_id) + continue + provider = None + try: + provider = self.get_provider(record.get("provider", "")) + except RuntimeError as provider_error: + logger.warning( + "[Computer] Provider unavailable for sandbox %s: %s", + sandbox_id, + provider_error, + ) + booter = self.session_booter.get(sandbox_id) + if booter is not None: + if provider is not None: + try: + await provider.destroy_booter(booter, record) + except Exception as shutdown_err: + logger.warning( + "[Computer] Failed to shutdown managed sandbox %s: %s", + sandbox_id, + shutdown_err, + ) + # Always pop the booter so memory is freed even when the + # provider has already been unregistered. + self.clear_runtime_state(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + self.clear_runtime_state(sandbox_id) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + + def clear_idle_state(self, sandbox_id: str) -> None: + state = self.idle_state.pop(sandbox_id, None) + if ( + state is not None + and not state.task.done() + and state.task is not asyncio.current_task() + ): + state.task.cancel() + + def clear_expiration_state(self, sandbox_id: str) -> None: + state = self.expiration_state.pop(sandbox_id, None) + if ( + state is not None + and not state.task.done() + and state.task is not asyncio.current_task() + ): + state.task.cancel() + + def schedule_idle_cleanup(self, sandbox_id: str, timeout: float) -> None: + self.clear_idle_state(sandbox_id) + if timeout <= 0: + return + self.registry.touch_sandbox(sandbox_id) + expires_at = time.monotonic() + timeout + task = asyncio.create_task( + self._expire_when_idle(sandbox_id, timeout, expires_at) + ) + self.idle_state[sandbox_id] = SandboxIdleState(expires_at=expires_at, task=task) + + def schedule_ttl_cleanup(self, sandbox_id: str, expires_at: float | None) -> None: + self.clear_expiration_state(sandbox_id) + if expires_at is None: + return + monotonic_expires_at = time.monotonic() + max( + 0.0, float(expires_at) - time.time() + ) + task = asyncio.create_task( + self._expire_at_fixed_time( + sandbox_id, float(expires_at), monotonic_expires_at + ) + ) + self.expiration_state[sandbox_id] = SandboxExpirationState( + expires_at=float(expires_at), + monotonic_expires_at=monotonic_expires_at, + task=task, + ) + + def schedule_lifecycle_cleanup( + self, + sandbox_id: str, + idle_timeout: float, + expires_at: float | None, + ) -> None: + if idle_timeout > 0: + self.clear_expiration_state(sandbox_id) + self.schedule_idle_cleanup(sandbox_id, idle_timeout) + return + self.clear_idle_state(sandbox_id) + self.schedule_ttl_cleanup(sandbox_id, expires_at) + + async def _expire_at_fixed_time( + self, + sandbox_id: str, + expires_at: float, + monotonic_expires_at: float, + ) -> None: + current_task = asyncio.current_task() + destroy_attempts = 0 + try: + while True: + remaining = monotonic_expires_at - time.monotonic() + if remaining > 0: + await asyncio.sleep(remaining) + state = self.expiration_state.get(sandbox_id) + if ( + state is None + or state.task is not current_task + or state.expires_at != float(expires_at) + ): + return + record = self.registry.get_sandbox(sandbox_id) + if record is None: + self.session_booter.pop(sandbox_id, None) + return + if float(record.get("expires_at") or 0) != float(expires_at): + return + booter = self.session_booter.get(sandbox_id) + if booter is not None: + try: + provider = self.get_provider(record.get("provider", "")) + await provider.destroy_booter(booter, record) + except Exception as shutdown_err: + logger.warning( + "[Computer] Failed to shutdown expired sandbox %s: %s", + sandbox_id, + shutdown_err, + ) + destroy_attempts += 1 + if destroy_attempts < MAX_IDLE_DESTROY_ATTEMPTS: + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.UNKNOWN + ) + await self.save_registry_async() + await asyncio.sleep(SANDBOX_TTL_DESTROY_RETRY_SECONDS) + continue + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.ERROR + ) + await self.save_registry_async() + return + self.clear_runtime_state(sandbox_id) + if record.get("retention_policy") == "persistent": + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.STOPPED + ) + else: + self.registry.delete_sandbox(sandbox_id) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + return + finally: + state = self.expiration_state.get(sandbox_id) + if ( + state is not None + and state.task is current_task + and state.expires_at == float(expires_at) + ): + self.expiration_state.pop(sandbox_id, None) + + async def _expire_when_idle( + self, sandbox_id: str, timeout: float, initial_expires_at: float + ) -> None: + current_expires_at = initial_expires_at + destroy_attempts = 0 + try: + while True: + remaining = current_expires_at - time.monotonic() + if remaining > 0: + await asyncio.sleep(remaining) + state = self.idle_state.get(sandbox_id) + if state is None or state.expires_at != current_expires_at: + return + record = self.registry.get_sandbox(sandbox_id) + if record is None: + self.session_booter.pop(sandbox_id, None) + return + if self.sandbox_has_active_lease(sandbox_id): + current_expires_at = time.monotonic() + timeout + self.idle_state[sandbox_id] = SandboxIdleState( + expires_at=current_expires_at, task=state.task + ) + continue + if record.get("retention_policy") == "persistent": + return + booter = self.session_booter.get(sandbox_id) + if booter is not None: + try: + provider = self.get_provider(record.get("provider", "")) + self.session_booter.pop(sandbox_id, None) + await provider.destroy_booter(booter, record) + except Exception as shutdown_err: + logger.warning( + "[Computer] Failed to shutdown idle sandbox %s: %s", + sandbox_id, + shutdown_err, + ) + try: + booter_available = await self.booter_available(booter) + except Exception: + booter_available = False + if booter_available: + destroy_attempts += 1 + if destroy_attempts < MAX_IDLE_DESTROY_ATTEMPTS: + self.session_booter[sandbox_id] = booter + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.UNKNOWN + ) + await self.save_registry_async() + # Retry cleanup after the normal timeout instead of + # leaving the sandbox without any scheduled cleanup. + current_expires_at = time.monotonic() + timeout + self.idle_state[sandbox_id] = SandboxIdleState( + expires_at=current_expires_at, task=state.task + ) + continue + logger.warning( + "[Computer] Giving up on idle sandbox %s after %d destroy attempts", + sandbox_id, + destroy_attempts, + ) + self.session_booter[sandbox_id] = booter + self.registry.update_sandbox_status( + sandbox_id, SandboxStatus.ERROR + ) + self.idle_state.pop(sandbox_id, None) + await self.save_registry_async() + return + self.clear_runtime_state(sandbox_id) + self.registry.delete_sandbox(sandbox_id) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + return + self.registry.delete_sandbox(sandbox_id) + self.drop_boot_lock(sandbox_id) + await self.save_registry_async() + return + except asyncio.CancelledError: + raise + finally: + state = self.idle_state.get(sandbox_id) + if state is not None and state.expires_at == current_expires_at: + self.idle_state.pop(sandbox_id, None) + + @staticmethod + async def _sync_skills_to_booter( + booter: ComputerBooter, + provider_id: str | None = None, + ) -> None: + """Delay-import wrapper to avoid circular imports.""" + from astrbot.core.computer.computer_client import _sync_skills_to_sandbox + + await _sync_skills_to_sandbox(booter, provider_id=provider_id) diff --git a/astrbot/core/computer/sandbox_models.py b/astrbot/core/computer/sandbox_models.py new file mode 100644 index 0000000000..32b6622030 --- /dev/null +++ b/astrbot/core/computer/sandbox_models.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from astrbot.core.computer.sandbox_timeouts import lease_is_active + + +class SandboxRetentionPolicy(str, Enum): + TEMPORARY = "temporary" + PERSISTENT = "persistent" + + +class SandboxStatus(str, Enum): + CREATING = "creating" + RESTORING = "restoring" + RUNNING = "running" + ERROR = "error" + STOPPING = "stopping" + STOPPED = "stopped" + UNKNOWN = "unknown" + + +@dataclass(slots=True) +class SandboxRecord: + sandbox_id: str + sandbox_name: str + provider: str + managed: bool + created_by_astrbot: bool + is_default: bool = False + owner_user_id: str | None = None + owner_session_id: str | None = None + created_by_user_id: str | None = None + created_by_session_id: str | None = None + controller_user_id: str | None = None + controller_session_id: str | None = None + lease_expires_at: float | None = None + last_used_at: float | None = None + idle_timeout: int | float | None = None + expires_at: float | None = None + retention_policy: SandboxRetentionPolicy = SandboxRetentionPolicy.TEMPORARY + status: SandboxStatus = SandboxStatus.RUNNING + connect_info: dict[str, Any] = field(default_factory=dict) + capabilities: list[str] = field(default_factory=list) + tool_names: list[str] = field(default_factory=list) + labels: dict[str, Any] = field(default_factory=dict) + notes: str | None = None + created_hook_fired: bool = False + + @staticmethod + def _required_string(data: dict[str, Any], field_name: str) -> str: + value = data[field_name] + if not isinstance(value, str): + raise ValueError(f"{field_name} must be a non-empty string") + value = value.strip() + if not value: + raise ValueError(f"{field_name} must be a non-empty string") + return value + + @classmethod + def _required_provider(cls, data: dict[str, Any]) -> str: + if "provider" in data: + return cls._required_string(data, "provider") + return cls._required_string(data, "booter_type") + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> SandboxRecord: + return cls( + sandbox_id=cls._required_string(data, "sandbox_id"), + sandbox_name=cls._required_string(data, "sandbox_name"), + provider=cls._required_provider(data), + managed=bool(data["managed"]), + created_by_astrbot=bool(data["created_by_astrbot"]), + is_default=bool(data.get("is_default", False)), + owner_user_id=data.get("owner_user_id"), + owner_session_id=data.get("owner_session_id"), + created_by_user_id=data.get("created_by_user_id") + or data.get("owner_user_id"), + created_by_session_id=data.get("created_by_session_id") + or data.get("owner_session_id"), + controller_user_id=data.get("controller_user_id"), + controller_session_id=data.get("controller_session_id"), + lease_expires_at=data.get("lease_expires_at"), + last_used_at=data.get("last_used_at"), + idle_timeout=data.get("idle_timeout"), + expires_at=data.get("expires_at"), + retention_policy=SandboxRetentionPolicy( + data.get("retention_policy", SandboxRetentionPolicy.TEMPORARY) + ), + status=SandboxStatus(data.get("status", SandboxStatus.RUNNING)), + connect_info=dict(data.get("connect_info") or {}), + capabilities=sorted( + str(item) for item in data.get("capabilities", []) if item + ), + tool_names=sorted(str(item) for item in data.get("tool_names", []) if item), + labels=dict(data.get("labels") or {}), + notes=data.get("notes"), + created_hook_fired=bool(data.get("created_hook_fired", False)), + ) + + def to_dict(self) -> dict[str, Any]: + return { + "sandbox_id": self.sandbox_id, + "sandbox_name": self.sandbox_name, + "provider": self.provider, + "managed": self.managed, + "created_by_astrbot": self.created_by_astrbot, + "is_default": self.is_default, + "owner_user_id": self.owner_user_id, + "owner_session_id": self.owner_session_id, + "created_by_user_id": self.created_by_user_id, + "created_by_session_id": self.created_by_session_id, + "controller_user_id": self.controller_user_id, + "controller_session_id": self.controller_session_id, + "lease_expires_at": self.lease_expires_at, + "last_used_at": self.last_used_at, + "idle_timeout": self.idle_timeout, + "expires_at": self.expires_at, + "retention_policy": self.retention_policy.value, + "status": self.status.value, + "connect_info": dict(self.connect_info), + "capabilities": list(self.capabilities), + "tool_names": list(self.tool_names), + "labels": dict(self.labels), + "notes": self.notes, + "created_hook_fired": self.created_hook_fired, + } + + def has_active_lease(self, *, now: float | None = None) -> bool: + return lease_is_active( + self.controller_session_id, self.lease_expires_at, now=now + ) + + def is_controlled_by(self, session_id: str, *, now: float | None = None) -> bool: + return self.controller_session_id == session_id and self.has_active_lease( + now=now + ) diff --git a/astrbot/core/computer/sandbox_provider.py b/astrbot/core/computer/sandbox_provider.py new file mode 100644 index 0000000000..1930f06911 --- /dev/null +++ b/astrbot/core/computer/sandbox_provider.py @@ -0,0 +1,67 @@ +from __future__ import annotations + +from typing import Any, Protocol + +from astrbot.core.computer.booters.base import ComputerBooter +from astrbot.core.star.context import Context + + +class SandboxProvider(Protocol): + """Protocol for plugin-provided sandbox runtime providers. + + Required attributes: + provider_id: Unique provider identifier (e.g. "browser", "python_sandbox"). + capabilities: Set of capability strings (e.g. {"shell", "python", "gui"}). + tool_names: Set of tool names this provider contributes to the LLM. + + Optional attributes (core uses ``getattr`` with safe fallbacks): + provider_api_version: Provider API compatibility version. Defaults to "1.0". + system_prompt: Runtime-specific instructions exposed in provider metadata. + plugin_config: Plugin-specific configuration dict. Implementations are + encouraged to accept this as an ``__init__`` parameter so the + provider is fully initialized at construction time. + auto_sync_skills: If ``False``, core will skip automatic skill sync after + booting a sandbox for this provider. Defaults to ``True``. + prune_missing_persistent_records: If ``True``, startup reconciliation may + delete persistent registry records when the provider confirms the + external sandbox is missing. Defaults to ``False`` to avoid data loss + from transient reconnect failures. + """ + + provider_id: str + capabilities: set[str] + tool_names: set[str] + system_prompt: str = "" + plugin_config: dict[str, Any] | None = None + provider_api_version: str = "1.0" + auto_sync_skills: bool = True + supports_persistent_reconnect: bool = False + prune_missing_persistent_records: bool = False + + def build_create_config(self, context: Context, session_id: str) -> dict: ... + + def build_connect_info(self, sandbox_name: str, config: dict) -> dict: ... + + def update_connect_info(self, record: dict, *, sandbox_name: str) -> dict: ... + + def update_connect_info_after_boot( + self, record: dict, booter: ComputerBooter + ) -> dict | None: ... + + async def create_booter( + self, + context: Context, + session_id: str, + sandbox_id: str, + config: dict, + ) -> ComputerBooter: ... + + async def destroy_booter(self, booter: ComputerBooter, record: dict) -> None: ... + + # Optional lifecycle hooks -- core checks ``hasattr`` before invoking. + + async def on_sandbox_created(self, record: dict) -> None: + """Called after a sandbox is successfully created and leased.""" + + async def on_sandbox_destroyed(self, record: dict) -> None: + """Called after a sandbox is destroyed and removed from registry.""" diff --git a/astrbot/core/computer/sandbox_registry.py b/astrbot/core/computer/sandbox_registry.py new file mode 100644 index 0000000000..f867495ba0 --- /dev/null +++ b/astrbot/core/computer/sandbox_registry.py @@ -0,0 +1,432 @@ +from __future__ import annotations + +import asyncio +import json +import time +from copy import deepcopy +from pathlib import Path +from typing import Any + +from astrbot.api import logger +from astrbot.core.computer.sandbox_models import SandboxRecord, SandboxStatus +from astrbot.core.computer.sandbox_timeouts import ( + lease_expires_at_from_timeout, + lease_is_active, +) +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + +_UNSET = object() +_SCHEMA_VERSION = 1 + + +def _default_registry_payload() -> dict[str, Any]: + return { + "schema_version": _SCHEMA_VERSION, + "default_sandbox_id": None, + "default_sandbox_ids": {}, + "sandboxes": {}, + "session_current": {}, + } + + +def _coerce_schema_version(value: Any) -> int: + try: + version = int(value) + except (TypeError, ValueError): + return _SCHEMA_VERSION + return version if version > 0 else _SCHEMA_VERSION + + +class SandboxRegistry: + def __init__(self, storage_path: str | Path | None = None): + if storage_path is None: + storage_path = Path(get_astrbot_data_path()) / "sandbox_registry.json" + self.storage_path = Path(storage_path) + self._payload = _default_registry_payload() + self._save_lock = asyncio.Lock() + + @property + def default_sandbox_id(self) -> str | None: + return self._payload["default_sandbox_id"] + + def get_default_sandbox_id(self, provider: str) -> str | None: + sandbox_id = self._payload.get("default_sandbox_ids", {}).get(provider) + if sandbox_id and sandbox_id in self._payload["sandboxes"]: + return sandbox_id + if self._payload["default_sandbox_id"]: + record = self.get_sandbox(self._payload["default_sandbox_id"]) + if record and record.get("provider") == provider: + return self._payload["default_sandbox_id"] + return None + + def get_sandbox(self, sandbox_id: str | None) -> dict[str, Any] | None: + if sandbox_id is None: + return None + record = self._payload["sandboxes"].get(sandbox_id) + return deepcopy(record) if record is not None else None + + def list_sandboxes(self) -> list[dict[str, Any]]: + return [deepcopy(item) for item in self._payload["sandboxes"].values()] + + def set_default_sandbox_id(self, sandbox_id: str | None) -> None: + old_default = self._payload["default_sandbox_id"] + self._payload["default_sandbox_id"] = sandbox_id + if sandbox_id and sandbox_id in self._payload["sandboxes"]: + record = self._payload["sandboxes"][sandbox_id] + provider = record.get("provider") + if provider: + old_provider_default = self._payload.setdefault( + "default_sandbox_ids", {} + ).get(provider) + if ( + old_provider_default + and old_provider_default in self._payload["sandboxes"] + ): + self._payload["sandboxes"][old_provider_default]["is_default"] = ( + False + ) + self._payload["default_sandbox_ids"][provider] = sandbox_id + record["is_default"] = True + elif old_default and old_default in self._payload["sandboxes"]: + self._payload["sandboxes"][old_default]["is_default"] = False + + def get_current_sandbox_id(self, session_id: str) -> str | None: + return self._payload["session_current"].get(session_id) + + def set_current_sandbox_id(self, session_id: str, sandbox_id: str | None) -> None: + if sandbox_id is None: + self._payload["session_current"].pop(session_id, None) + else: + self._payload["session_current"][session_id] = sandbox_id + + def upsert_sandbox( + self, + *, + sandbox_id: str, + sandbox_name: str, + provider: str, + managed: bool, + created_by_astrbot: bool, + owner_user_id: str | None, + owner_session_id: str | None, + connect_info: dict[str, Any], + is_default: bool | object = _UNSET, + status: str | object = _UNSET, + idle_timeout: int | float | None | object = _UNSET, + expires_at: float | None | object = _UNSET, + retention_policy: str | object = _UNSET, + last_used_at: float | None | object = _UNSET, + controller_user_id: str | None | object = _UNSET, + controller_session_id: str | None | object = _UNSET, + lease_expires_at: float | None | object = _UNSET, + labels: dict[str, Any] | None | object = _UNSET, + capabilities: list[str] | set[str] | None | object = _UNSET, + tool_names: list[str] | set[str] | None | object = _UNSET, + notes: str | None | object = _UNSET, + ) -> dict[str, Any]: + record = self._payload["sandboxes"].get(sandbox_id, {}) + record.update( + { + "sandbox_id": sandbox_id, + "sandbox_name": sandbox_name, + "provider": provider, + "managed": managed, + "created_by_astrbot": created_by_astrbot, + "owner_user_id": owner_user_id, + "owner_session_id": owner_session_id, + "created_by_user_id": owner_user_id, + "created_by_session_id": owner_session_id, + "connect_info": deepcopy(connect_info), + } + ) + defaults = { + "controller_user_id": None, + "controller_session_id": None, + "lease_expires_at": None, + "last_used_at": None, + "idle_timeout": None, + "expires_at": None, + "retention_policy": "temporary", + "status": "running", + "is_default": False, + "labels": {}, + "capabilities": [], + "tool_names": [], + "notes": None, + "created_hook_fired": False, + } + updates = { + "controller_user_id": controller_user_id, + "controller_session_id": controller_session_id, + "lease_expires_at": lease_expires_at, + "last_used_at": last_used_at, + "idle_timeout": idle_timeout, + "expires_at": expires_at, + "retention_policy": retention_policy, + "status": status, + "is_default": is_default, + "labels": deepcopy(labels) if labels is not _UNSET else _UNSET, + "capabilities": sorted(capabilities) + if capabilities is not _UNSET + else _UNSET, + "tool_names": sorted(tool_names) if tool_names is not _UNSET else _UNSET, + "notes": notes, + "created_hook_fired": _UNSET, + } + for field_name, default_value in defaults.items(): + value = updates[field_name] + if value is _UNSET: + record.setdefault(field_name, deepcopy(default_value)) + else: + record[field_name] = value + record = SandboxRecord.from_dict(record).to_dict() + self._payload["sandboxes"][sandbox_id] = record + if is_default is True or ( + managed and self._payload["default_sandbox_id"] is None + ): + self.set_default_sandbox_id(sandbox_id) + return deepcopy(record) + + def delete_sandbox(self, sandbox_id: str) -> None: + was_default = self._payload["default_sandbox_id"] == sandbox_id + deleted = self._payload["sandboxes"].pop(sandbox_id, None) + if deleted: + provider = deleted.get("provider") + if ( + provider + and self._payload.get("default_sandbox_ids", {}).get(provider) + == sandbox_id + ): + self._payload["default_sandbox_ids"].pop(provider, None) + for candidate_id, candidate in self._payload["sandboxes"].items(): + if ( + candidate.get("managed") + and candidate.get("provider") == provider + ): + self.set_default_sandbox_id(candidate_id) + break + if was_default: + self._payload["default_sandbox_id"] = None + for candidate_id, candidate in self._payload["sandboxes"].items(): + if candidate.get("managed"): + self.set_default_sandbox_id(candidate_id) + break + stale_sessions = [ + session_id + for session_id, current_id in self._payload["session_current"].items() + if current_id == sandbox_id + ] + for session_id in stale_sessions: + self._payload["session_current"].pop(session_id, None) + + def touch_sandbox( + self, sandbox_id: str, *, ts: float | None = None + ) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + record["last_used_at"] = ts if ts is not None else time.time() + return deepcopy(record) + + def update_sandbox_config( + self, + sandbox_id: str, + *, + sandbox_name: str | object = _UNSET, + connect_info: dict[str, Any] | object = _UNSET, + idle_timeout: int | float | None | object = _UNSET, + expires_at: int | float | None | object = _UNSET, + retention_policy: str | object = _UNSET, + ) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + if sandbox_name is not _UNSET: + name = str(sandbox_name).strip() + if not name: + raise ValueError("sandbox_name must be a non-empty string") + record["sandbox_name"] = name + if connect_info is not _UNSET: + record["connect_info"] = deepcopy(connect_info) + if idle_timeout is not _UNSET: + record["idle_timeout"] = idle_timeout + if expires_at is not _UNSET: + record["expires_at"] = expires_at + if retention_policy is not _UNSET: + record["retention_policy"] = retention_policy + return deepcopy(record) + + def update_sandbox_status( + self, sandbox_id: str, status: str + ) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + record["status"] = getattr(status, "value", status) + return deepcopy(record) + + def has_created_hook_fired(self, sandbox_id: str) -> bool: + record = self._payload["sandboxes"].get(sandbox_id) + return bool(record and record.get("created_hook_fired")) + + def mark_created_hook_fired(self, sandbox_id: str) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + record["created_hook_fired"] = True + return deepcopy(record) + + def acquire_lease( + self, + *, + sandbox_id: str, + session_id: str, + user_id: str | None, + ttl: int | float, + now: float | None = None, + ) -> bool: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return False + current_time = time.time() if now is None else now + controller_session_id = record.get("controller_session_id") + lease_expires_at = record.get("lease_expires_at") + if lease_is_active( + controller_session_id, lease_expires_at, now=current_time + ) and (controller_session_id != session_id): + return False + record["controller_session_id"] = session_id + record["controller_user_id"] = user_id + next_lease_expires_at = lease_expires_at_from_timeout(ttl, now=current_time) + if controller_session_id == session_id and lease_is_active( + controller_session_id, lease_expires_at, now=current_time + ): + if lease_expires_at is None: + next_lease_expires_at = None + elif next_lease_expires_at is not None: + next_lease_expires_at = max( + float(lease_expires_at), next_lease_expires_at + ) + record["lease_expires_at"] = next_lease_expires_at + return True + + def release_lease(self, sandbox_id: str) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + record["controller_session_id"] = None + record["controller_user_id"] = None + record["lease_expires_at"] = None + return deepcopy(record) + + def takeover_lease( + self, + *, + sandbox_id: str, + session_id: str, + user_id: str | None, + ttl: int | float, + now: float | None = None, + ) -> dict[str, Any] | None: + record = self._payload["sandboxes"].get(sandbox_id) + if record is None: + return None + current_time = time.time() if now is None else now + record["controller_session_id"] = session_id + record["controller_user_id"] = user_id + record["lease_expires_at"] = lease_expires_at_from_timeout( + ttl, now=current_time + ) + return deepcopy(record) + + def reconcile_startup(self) -> None: + self._payload["session_current"] = {} + for sandbox_id, record in list(self._payload["sandboxes"].items()): + if record.get("retention_policy") != "persistent": + self._payload["sandboxes"].pop(sandbox_id, None) + continue + record["controller_session_id"] = None + record["controller_user_id"] = None + record["lease_expires_at"] = None + if record.get("status") == SandboxStatus.RUNNING: + record["status"] = SandboxStatus.UNKNOWN.value + elif record.get("status") in { + SandboxStatus.CREATING, + SandboxStatus.RESTORING, + }: + record["status"] = SandboxStatus.ERROR.value + self._prune_default_references() + + def _prune_default_references(self) -> None: + sandboxes = self._payload["sandboxes"] + default_sandbox_id = self._payload.get("default_sandbox_id") + if default_sandbox_id not in sandboxes: + self._payload["default_sandbox_id"] = None + default_sandbox_ids = self._payload.get("default_sandbox_ids") or {} + valid_default_sandbox_ids = { + provider: sandbox_id + for provider, sandbox_id in default_sandbox_ids.items() + if sandbox_id in sandboxes + and sandboxes[sandbox_id].get("provider") == provider + } + self._payload["default_sandbox_ids"] = valid_default_sandbox_ids + for record in sandboxes.values(): + record["is_default"] = False + if self._payload["default_sandbox_id"]: + sandboxes[self._payload["default_sandbox_id"]]["is_default"] = True + for sandbox_id in valid_default_sandbox_ids.values(): + if sandbox_id in sandboxes: + sandboxes[sandbox_id]["is_default"] = True + + def load(self) -> None: + if not self.storage_path.exists(): + self._payload = _default_registry_payload() + return + try: + payload = json.loads(self.storage_path.read_text(encoding="utf-8")) + except Exception as exc: + logger.warning("Failed to load sandbox registry: %s", exc) + self._payload = _default_registry_payload() + return + if not isinstance(payload, dict): + logger.warning("Failed to load sandbox registry: payload is not an object") + self._payload = _default_registry_payload() + return + self._payload = _default_registry_payload() + self._payload["schema_version"] = _coerce_schema_version( + payload.get("schema_version") + ) + self._payload.update({key: payload.get(key) for key in self._payload}) + self._payload["schema_version"] = _coerce_schema_version( + self._payload.get("schema_version") + ) + self._payload["default_sandbox_ids"] = dict( + self._payload.get("default_sandbox_ids") or {} + ) + self._payload["sandboxes"] = dict(self._payload.get("sandboxes") or {}) + self._payload["session_current"] = dict( + self._payload.get("session_current") or {} + ) + + def _write_payload(self, payload: dict[str, Any]) -> None: + self.storage_path.parent.mkdir(parents=True, exist_ok=True) + temp_path = self.storage_path.with_name( + f"{self.storage_path.name}.{time.time_ns()}.tmp" + ) + try: + temp_path.write_text( + json.dumps(payload, ensure_ascii=False, indent=2, sort_keys=True), + encoding="utf-8", + ) + temp_path.replace(self.storage_path) + finally: + if temp_path.exists(): + temp_path.unlink() + + def save(self) -> None: + self._write_payload(deepcopy(self._payload)) + + async def save_async(self) -> None: + async with self._save_lock: + payload = deepcopy(self._payload) + await asyncio.to_thread(self._write_payload, payload) diff --git a/astrbot/core/computer/sandbox_timeouts.py b/astrbot/core/computer/sandbox_timeouts.py new file mode 100644 index 0000000000..d1641b51b3 --- /dev/null +++ b/astrbot/core/computer/sandbox_timeouts.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import math +import time +from collections.abc import Mapping +from typing import Any + +DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS = 600.0 + + +def _coerce_timeout(value: Any, default: float) -> float: + if isinstance(value, bool): + return default + try: + timeout = float(value) + except (TypeError, ValueError): + return default + if not math.isfinite(timeout) or timeout < 0: + return default + return timeout + + +def resolve_sandbox_timeout( + config: Mapping[str, Any], + key: str, + *, + aliases: tuple[str, ...] = (), + default: float, +) -> float: + for candidate in (key, *aliases): + if candidate in config: + return _coerce_timeout(config.get(candidate), default) + return default + + +def lease_is_active( + controller_session_id: str | None, + lease_expires_at: float | None, + *, + now: float | None = None, +) -> bool: + if not controller_session_id: + return False + if lease_expires_at is None: + return True + current_time = time.time() if now is None else now + return float(lease_expires_at) > current_time + + +def lease_expires_at_from_timeout( + timeout: float | int | None, + *, + now: float | None = None, +) -> float | None: + if timeout is None: + return None + current_time = time.time() if now is None else now + normalized = _coerce_timeout(timeout, DEFAULT_SANDBOX_LEASE_TIMEOUT_SECONDS) + if normalized <= 0: + return None + return current_time + normalized + + +def expires_at_from_timeout( + timeout: float | int | None, + *, + now: float | None = None, +) -> float | None: + return lease_expires_at_from_timeout(timeout, now=now) + + +def idle_cleanup_at_from_record( + *, + last_used_at: float | None, + idle_timeout: float | int | None, + now: float | None = None, +) -> float | None: + if last_used_at is None: + return None + current_timeout = _coerce_timeout(idle_timeout, 0.0) + if current_timeout <= 0: + return None + candidate = float(last_used_at) + current_timeout + return candidate + + +def get_provider_sandbox_config(context: Any, session_id: str) -> dict[str, Any]: + if context is None: + return {} + get_config = getattr(context, "get_config", None) + if not callable(get_config): + return {} + config = get_config(umo=session_id) + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + return sandbox_cfg if isinstance(sandbox_cfg, dict) else {} diff --git a/astrbot/core/computer/sandbox_tool_binding.py b/astrbot/core/computer/sandbox_tool_binding.py new file mode 100644 index 0000000000..21f49733c0 --- /dev/null +++ b/astrbot/core/computer/sandbox_tool_binding.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from astrbot.core.agent.tool import FunctionTool +from astrbot.core.tools.registry import ( + BuiltinToolConfigRule, + build_builtin_tool_config_rule, +) + +TFunctionTool = type[FunctionTool] + +_sandbox_provider_tool_config_rules: dict[str, BuiltinToolConfigRule] = {} + + +def tool_available_in_runtime(tool: Any, runtime: str) -> bool: + """Return whether a tool should be exposed for the computer-use runtime. + + Provider-specific sandbox tools are registered once when their provider is + enabled. They are visible to all sandbox sessions and hidden from local/none + runtimes. + """ + tool_provider = getattr(tool, "sandbox_provider_id", None) + if not tool_provider: + return True + return runtime == "sandbox" + + +def mark_tool_as_sandbox_provider_tool(tool: Any, provider_id: str) -> Any: + provider_id = _normalize_provider_id(provider_id) + tool.sandbox_provider_id = provider_id + marker = f"[Sandbox provider-specific tool: {provider_id}]" + description = str(getattr(tool, "description", "") or "") + if marker not in description: + tool.description = ( + f"{marker} This tool only works when the current sandbox uses provider " + f"'{provider_id}'. If the current sandbox uses another provider, switch or " + f"create a '{provider_id}' sandbox first. {description}" + ).strip() + return tool + + +def sandbox_provider_tool( + provider_id: str, + *, + config: dict[str, Any] | None = None, +) -> Callable[[TFunctionTool], TFunctionTool]: + """Mark a FunctionTool class as belonging to a sandbox provider. + + Sandbox provider tools are plugin-owned tools with sandbox runtime semantics. + They are not AstrBot Core builtin tools, but may still expose config tags in + the dashboard through the same condition format as builtin tools. + """ + + normalized_provider_id = _normalize_provider_id(provider_id) + + def _register(cls: TFunctionTool) -> TFunctionTool: + tool_name = _resolve_tool_name(cls) + cls.sandbox_provider_id = normalized_provider_id + if config is not None: + _sandbox_provider_tool_config_rules[tool_name] = ( + build_builtin_tool_config_rule(config) + ) + return cls + + return _register + + +def get_sandbox_provider_tool_config_statuses( + tool_name: str, + config_entries: list[dict[str, Any]], +) -> list[dict[str, Any]]: + rule = _sandbox_provider_tool_config_rules.get(tool_name) + if rule is None: + return [] + + statuses: list[dict[str, Any]] = [] + for entry in config_entries: + config = entry.get("config") + if not isinstance(config, dict): + continue + + conditions = rule.evaluate(config) + enabled = bool(conditions) and all( + bool(condition.get("matched")) for condition in conditions + ) + statuses.append( + { + "conf_id": entry.get("conf_id"), + "conf_name": entry.get("conf_name"), + "enabled": enabled, + "matched_conditions": [ + condition for condition in conditions if condition.get("matched") + ], + "failed_conditions": [ + condition + for condition in conditions + if not condition.get("matched") + ], + } + ) + return statuses + + +def _resolve_tool_name(tool_cls: type[FunctionTool]) -> str: + tool_name = getattr(tool_cls, "name", None) + if isinstance(tool_name, str) and tool_name: + return tool_name + + dataclass_fields = getattr(tool_cls, "__dataclass_fields__", {}) + name_field = dataclass_fields.get("name") + if name_field is not None and isinstance(name_field.default, str): + return name_field.default + + raise ValueError( + f"Sandbox provider tool class {tool_cls.__module__}.{tool_cls.__name__} does not define a valid name.", + ) + + +def _normalize_provider_id(provider_id: str | None) -> str: + return "" if provider_id is None else str(provider_id).strip().lower() diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 4d62becb55..56ffd7943d 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -17,6 +17,8 @@ DASHBOARD_INITIAL_PASSWORD_ENV = "ASTRBOT_DASHBOARD_INITIAL_PASSWORD" logger = logging.getLogger("astrbot") +CORE_COMPUTER_RUNTIME_IDS = {"local", "sandbox", "none"} + class RateLimitStrategy(enum.Enum): STALL = "stall" @@ -76,6 +78,7 @@ def __init__( ) # 检查配置完整性,并插入 has_new = self.check_config_integrity(default_config, conf) + has_new |= self._migrate_legacy_sandbox_runtime(conf) if ( "dashboard" in conf and isinstance(conf["dashboard"], dict) @@ -154,6 +157,33 @@ def _parse_schema(schema: dict, conf: dict) -> None: return conf + def _migrate_legacy_sandbox_runtime(self, conf: dict) -> bool: + provider_settings = conf.get("provider_settings") + if not isinstance(provider_settings, dict): + return False + + runtime = provider_settings.get("computer_use_runtime") + if runtime in CORE_COMPUTER_RUNTIME_IDS or not runtime: + return False + + # Older configs stored sandbox provider IDs directly as the runtime. + # Preserve that value as the selected sandbox booter without teaching + # core about concrete provider names. + sandbox_cfg = provider_settings.get("sandbox") + if not isinstance(sandbox_cfg, dict): + sandbox_cfg = {} + provider_settings["sandbox"] = sandbox_cfg + + if not sandbox_cfg.get("booter"): + sandbox_cfg["booter"] = runtime + logger.info( + "Config key migrated: provider_settings.computer_use_runtime %s -> sandbox", + runtime, + ) + + provider_settings["computer_use_runtime"] = "sandbox" + return True + def check_config_integrity(self, refer_conf: dict, conf: dict, path=""): """检查配置完整性,如果有新的配置项或顺序不一致则返回 True""" has_new = False diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index d060ce1c3d..eeb11b9496 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,7 +2,6 @@ import os -from astrbot.core.computer.booters.cua_defaults import CUA_DEFAULT_CONFIG from astrbot.core.utils.astrbot_path import get_astrbot_data_path VERSION = "4.25.5" @@ -169,21 +168,17 @@ "computer_use_runtime": "none", "computer_use_require_admin": True, "sandbox": { - "booter": "shipyard_neo", - "shipyard_endpoint": "", - "shipyard_access_token": "", - "shipyard_ttl": 3600, - "shipyard_max_sessions": 10, - "shipyard_neo_endpoint": "", - "shipyard_neo_access_token": "", - "shipyard_neo_profile": "python-default", - "shipyard_neo_ttl": 3600, - "cua_image": CUA_DEFAULT_CONFIG["image"], - "cua_os_type": CUA_DEFAULT_CONFIG["os_type"], - "cua_idle_timeout": CUA_DEFAULT_CONFIG["idle_timeout"], - "cua_telemetry_enabled": CUA_DEFAULT_CONFIG["telemetry_enabled"], - "cua_local": CUA_DEFAULT_CONFIG["local"], - "cua_api_key": CUA_DEFAULT_CONFIG["api_key"], + "booter": "", + "sandbox_lease_timeout": 600, + "sandbox_idle_timeout": 1800, + "sandbox_ttl": 3600, + "max_sandboxes": 10, + "member_permissions": { + "create": False, + "set_retention_policy": False, + "takeover": False, + "destroy": False, + }, }, "image_compress_enabled": True, "image_compress_options": { @@ -3369,143 +3364,80 @@ "hint": "开启后,需要 AstrBot 管理员权限才能调用使用电脑能力。在平台配置->管理员中可添加管理员。使用 /sid 指令查看管理员 ID。", }, "provider_settings.sandbox.booter": { - "description": "沙箱环境驱动器", - "type": "string", - "options": ["shipyard_neo", "shipyard", "cua"], - "labels": ["Shipyard Neo", "Shipyard", "CUA"], - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - }, - }, - "provider_settings.sandbox.shipyard_neo_endpoint": { - "description": "Shipyard Neo API Endpoint", - "type": "string", - "hint": "Shipyard Neo(Bay) 服务的 API 地址,默认 http://127.0.0.1:8114。", - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", - }, - }, - "provider_settings.sandbox.shipyard_neo_access_token": { - "description": "Shipyard Neo Access Token", - "type": "string", - "hint": "Bay 的 API Key(sk-bay-...)。留空时自动从 credentials.json 发现。", - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", - }, - }, - "provider_settings.sandbox.shipyard_neo_profile": { - "description": "Shipyard Neo Profile", + "description": "沙箱驱动", "type": "string", - "hint": "Shipyard Neo 沙箱 profile,如 python-default。留空时自动选择能力更完整的 profile。", + "options": [], + "labels": [], "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", }, }, - "provider_settings.sandbox.shipyard_neo_ttl": { - "description": "Shipyard Neo Sandbox TTL", + "provider_settings.sandbox.sandbox_lease_timeout": { + "description": "沙箱占用超时", "type": "int", - "hint": "Shipyard Neo 沙箱生存时间(秒)。", + "hint": "单位为秒。每次 Agent 成功访问沙盒时,都会自动将本会话的沙盒租约续到当前时间 + 此时长。默认 600 秒;到期后当前会话不再绑定该沙盒,其他会话可接管。`0` 表示租约不会自动过期,需手动释放。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", }, }, - "provider_settings.sandbox.cua_image": { - "description": "CUA Image", - "type": "string", - "hint": "CUA 沙箱镜像/系统类型,默认 linux。可填写 linux、macos、windows、android,具体取决于 CUA SDK 支持。", + "provider_settings.sandbox.sandbox_idle_timeout": { + "description": "沙箱空闲回收时间", + "type": "int", + "hint": "单位为秒。`0` 表示不启用空闲回收,此时才会启用沙箱存活时间。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", }, }, - "provider_settings.sandbox.cua_os_type": { - "description": "CUA OS Type", - "type": "string", - "options": ["linux", "macos", "windows", "android"], - "labels": ["Linux", "macOS", "Windows", "Android"], - "hint": "CUA 沙箱操作系统类型,默认 linux。", + "provider_settings.sandbox.sandbox_ttl": { + "description": "沙箱存活时间", + "type": "int", + "hint": "单位为秒。仅在空闲回收时间为 `0` 时生效;`0` 表示不自动销毁。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", }, }, - "provider_settings.sandbox.cua_idle_timeout": { - "description": "CUA Idle Timeout", + "provider_settings.sandbox.max_sandboxes": { + "description": "最大沙箱数量", "type": "int", - "hint": "Idle timeout for CUA sandbox sessions in seconds. When greater than 0, AstrBot proactively shuts down an idle CUA sandbox after that amount of inactivity; 0 disables it.", + "hint": "全局托管沙箱数量上限,默认 10。`0` 表示不限制。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", }, }, - "provider_settings.sandbox.cua_telemetry_enabled": { - "description": "CUA Telemetry", + "provider_settings.sandbox.member_permissions.create": { + "description": "允许普通用户创建沙箱", "type": "bool", - "hint": "是否允许 CUA SDK 发送遥测数据。默认关闭。", + "hint": "允许普通用户创建新的托管沙箱。普通用户的创建请求仍会受到最大沙箱数量限制。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", + "provider_settings.computer_use_require_admin": False, }, }, - "provider_settings.sandbox.cua_local": { - "description": "CUA Local Sandbox", + "provider_settings.sandbox.member_permissions.set_retention_policy": { + "description": "允许普通用户修改沙箱保留策略", "type": "bool", - "hint": "是否优先使用 CUA 本地沙箱。默认开启,避免云端沙箱要求 CUA_API_KEY。关闭后可使用 CUA 云端沙箱。", + "hint": "允许普通用户在临时沙箱和持久沙箱策略之间切换。持久沙箱会保留环境以便后续复用。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", + "provider_settings.computer_use_require_admin": False, }, }, - "provider_settings.sandbox.cua_api_key": { - "description": "CUA API Key", - "type": "string", - "hint": "CUA 云端沙箱 API Key。仅在关闭本地沙箱时需要。也可以通过 CUA_API_KEY 环境变量提供。", - "obvious_hint": True, - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", - "provider_settings.sandbox.cua_local": False, - }, - }, - "provider_settings.sandbox.shipyard_endpoint": { - "description": "Shipyard API Endpoint", - "type": "string", - "hint": "Shipyard 服务的 API 访问地址。", - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard", - }, - "_special": "check_shipyard_connection", - }, - "provider_settings.sandbox.shipyard_access_token": { - "description": "Shipyard Access Token", - "type": "string", - "hint": "用于访问 Shipyard 服务的访问令牌。", - "condition": { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard", - }, - }, - "provider_settings.sandbox.shipyard_ttl": { - "description": "Shipyard Session TTL", - "type": "int", - "hint": "Shipyard 会话的生存时间(秒)。", + "provider_settings.sandbox.member_permissions.takeover": { + "description": "允许普通用户强占沙箱", + "type": "bool", + "hint": "允许普通用户强制接管被其他会话占用的沙箱。此操作会转移沙箱控制权,建议谨慎开启。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard", + "provider_settings.computer_use_require_admin": False, }, }, - "provider_settings.sandbox.shipyard_max_sessions": { - "description": "Shipyard Max Sessions", - "type": "int", - "hint": "Shipyard 最大会话数量。", + "provider_settings.sandbox.member_permissions.destroy": { + "description": "允许普通用户删除沙箱", + "type": "bool", + "hint": "允许普通用户删除自己可访问的托管沙箱。删除后沙箱环境和对应记录都会被移除。", "condition": { "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard", + "provider_settings.computer_use_require_admin": False, }, }, }, diff --git a/astrbot/core/core_lifecycle.py b/astrbot/core/core_lifecycle.py index c325a2ea38..1173342970 100644 --- a/astrbot/core/core_lifecycle.py +++ b/astrbot/core/core_lifecycle.py @@ -19,6 +19,7 @@ from astrbot.api import logger, sp from astrbot.core import LogBroker, LogManager from astrbot.core.astrbot_config_mgr import AstrBotConfigManager +from astrbot.core.computer import computer_client from astrbot.core.config.default import VERSION from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.cron import CronJobManager @@ -60,6 +61,7 @@ def __init__(self, log_broker: LogBroker, db: BaseDatabase) -> None: self.cron_manager: CronJobManager | None = None self.temp_dir_cleaner: TempDirCleaner | None = None self._default_chat_provider_warning_emitted = False + self._persistent_restore_task: asyncio.Task | None = None # 设置代理 proxy_config = self.astrbot_config.get("http_proxy", "") @@ -242,6 +244,17 @@ async def initialize(self) -> None: # 扫描、注册插件、实例化插件类 await self.plugin_manager.reload() + # Reconcile sandbox registry on startup to clear stale state and + # remove persistent records whose underlying resources no longer exist. + try: + await computer_client.sandbox_manager.reconcile_on_startup() + except Exception as e: + logger.warning( + "Sandbox startup reconciliation failed: %s", + e, + exc_info=True, + ) + # 根据配置实例化各个 Provider self._default_chat_provider_warning_emitted = False await self.provider_manager.initialize() @@ -276,6 +289,41 @@ async def initialize(self) -> None: asyncio.create_task(update_llm_metadata()) + async def _restore_persistent_sandboxes_background(self) -> None: + try: + # Do not let persistent sandbox recovery compete with the main + # startup path. Recovery is best-effort and should never delay the + # process becoming ready. + await asyncio.sleep(0) + ( + restored, + deleted, + ) = await computer_client.sandbox_manager.restore_persistent_sandboxes( + self.star_context, + per_sandbox_timeout=30.0, + ) + logger.info( + "Persistent sandbox restore finished: restored=%d deleted=%d", + restored, + deleted, + ) + except asyncio.CancelledError: + raise + except Exception as e: + logger.warning( + "Persistent sandbox restore failed: %s", + e, + exc_info=True, + ) + + def _schedule_persistent_sandbox_restore(self) -> None: + if self._persistent_restore_task is not None: + return + self._persistent_restore_task = asyncio.create_task( + self._restore_persistent_sandboxes_background(), + name="persistent-sandbox-restore", + ) + def _load(self) -> None: """加载事件总线和任务并初始化.""" # 创建一个异步任务来执行事件总线的 dispatch() 方法 @@ -339,6 +387,7 @@ async def start(self) -> None: """ self._load() logger.info("AstrBot started.") + self._schedule_persistent_sandbox_restore() # 执行启动完成事件钩子 handlers = star_handlers_registry.get_handlers_by_event_type( @@ -368,6 +417,24 @@ async def stop(self) -> None: if self.cron_manager: await self.cron_manager.shutdown() + persistent_restore_task = getattr(self, "_persistent_restore_task", None) + if persistent_restore_task is not None: + persistent_restore_task.cancel() + try: + await persistent_restore_task + except asyncio.CancelledError: + pass + self._persistent_restore_task = None + + try: + await computer_client.cleanup_managed_sandboxes() + except Exception as e: + logger.warning( + "Managed sandbox cleanup during shutdown failed: %s", + e, + exc_info=True, + ) + for plugin in self.plugin_manager.context.get_all_stars(): try: await self.plugin_manager._terminate_plugin(plugin) @@ -399,6 +466,30 @@ async def stop(self) -> None: async def restart(self) -> None: """重启 AstrBot 核心生命周期管理类, 终止各个管理器并重新加载平台实例""" + for task in getattr(self, "curr_tasks", []): + task.cancel() + + if self.cron_manager: + await self.cron_manager.shutdown() + + persistent_restore_task = getattr(self, "_persistent_restore_task", None) + if persistent_restore_task is not None: + persistent_restore_task.cancel() + try: + await persistent_restore_task + except asyncio.CancelledError: + pass + self._persistent_restore_task = None + + try: + await computer_client.cleanup_managed_sandboxes() + except Exception as e: + logger.warning( + "Managed sandbox cleanup during restart failed: %s", + e, + exc_info=True, + ) + await self.provider_manager.terminate() await self.platform_manager.terminate() await self.kb_manager.terminate() diff --git a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py index 4b642d8ce5..b8706d6d6c 100644 --- a/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py +++ b/astrbot/core/platform/sources/aiocqhttp/aiocqhttp_message_event.py @@ -1,4 +1,5 @@ import asyncio +import os import re from collections.abc import AsyncGenerator @@ -31,6 +32,13 @@ def __init__( super().__init__(message_str, message_obj, platform_meta, session_id) self.bot = bot + @staticmethod + def _parse_session_id_int(session_id: str | None) -> int | None: + try: + return int(str(session_id).strip()) + except (TypeError, ValueError): + return None + @staticmethod async def _from_segment_to_dict(segment: BaseMessageComponent) -> dict: """修复部分字段""" @@ -86,6 +94,37 @@ async def _parse_onebot_json(message_chain: MessageChain): ret.append(d) return ret + @staticmethod + async def _upload_file_segment( + bot: CQHttp, + segment: File, + is_group: bool, + session_id: str | None, + ) -> None: + session_id_int = AiocqhttpMessageEvent._parse_session_id_int(session_id) + if not isinstance(session_id_int, int): + raise ValueError( + f"无法发送文件:缺少有效的数字 session_id({session_id})", + ) + + file_path = await segment.get_file(allow_return_url=True) + if not file_path: + raise ValueError("无法发送文件:文件路径为空") + name = segment.name or os.path.basename(str(file_path)) or "file" + + if is_group: + await bot.upload_group_file( + group_id=session_id_int, + file=file_path, + name=name, + ) + else: + await bot.upload_private_file( + user_id=session_id_int, + file=file_path, + name=name, + ) + @classmethod async def _dispatch_send( cls, @@ -95,10 +134,8 @@ async def _dispatch_send( session_id: str | None, messages: list[dict], ) -> None: - # session_id 必须是纯数字字符串 - session_id_int = ( - int(session_id) if session_id and session_id.isdigit() else None - ) + # session_id 必须是数字字符串 + session_id_int = cls._parse_session_id_int(session_id) if is_group and isinstance(session_id_int, int): await bot.send_group_msg(group_id=session_id_int, message=messages) @@ -156,8 +193,11 @@ async def send_message( payload["user_id"] = session_id await bot.call_action("send_private_forward_msg", **payload) elif isinstance(seg, File): - d = await cls._from_segment_to_dict(seg) - await cls._dispatch_send(bot, event, is_group, session_id, [d]) + try: + await cls._upload_file_segment(bot, seg, is_group, session_id) + except Exception: + messages = await cls._parse_onebot_json(MessageChain([seg])) + await cls._dispatch_send(bot, event, is_group, session_id, messages) else: messages = await cls._parse_onebot_json(MessageChain([seg])) if not messages: diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index f3b3cb77c2..8fc32e9f1f 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -423,6 +423,18 @@ def get_builtin_tool(self, tool: str | type[FuncTool]) -> FuncTool: self.builtin_func_list[tool_cls] = builtin_tool return builtin_tool + def clear_builtin_tool_cache_by_module_prefix( + self, module_prefix: str + ) -> list[str]: + removed: list[str] = [] + for tool_cls in tuple(self.builtin_func_list): + if not getattr(tool_cls, "__module__", "").startswith(module_prefix): + continue + tool = self.builtin_func_list.pop(tool_cls, None) + tool_name = get_builtin_tool_name(tool_cls) or getattr(tool, "name", None) + removed.append(tool_name or tool_cls.__name__) + return removed + def iter_builtin_tools(self) -> list[FuncTool]: ensure_builtin_tools_loaded() return [ diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index ae4001fcd6..8780e48176 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -586,6 +586,14 @@ async def load_provider(self, provider_config: dict) -> None: return if provider_config.get("provider_type", "") == "agent_runner": return + if "type" not in provider_config: + logger.warning( + "Provider %s has no adapter type after merging provider source %s, skipping. " + "This is likely a stale provider model config; remove it in the dashboard.", + provider_config.get("id"), + provider_config.get("provider_source_id"), + ) + return logger.info( "Loading model %s(%s) ...", diff --git a/astrbot/core/skills/skill_manager.py b/astrbot/core/skills/skill_manager.py index 838301c044..c13c079582 100644 --- a/astrbot/core/skills/skill_manager.py +++ b/astrbot/core/skills/skill_manager.py @@ -6,6 +6,7 @@ import shlex import shutil import tempfile +import threading import uuid import zipfile from dataclasses import dataclass @@ -29,6 +30,11 @@ _SANDBOX_SKILLS_CACHE_VERSION = 1 _SKILL_NAME_RE = re.compile(r"^[\w.-]+$") +_SANDBOX_SKILLS_CACHE_LOCK = threading.RLock() + + +def _normalize_cache_provider_id(provider_id: str | None) -> str: + return str(provider_id or "").strip().lower() def _normalize_skill_name(name: str | None) -> str: @@ -353,22 +359,38 @@ def _save_config(self, config: dict) -> None: def _load_sandbox_skills_cache(self) -> dict: if not os.path.exists(self.sandbox_skills_cache_path): - return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []} + return { + "version": _SANDBOX_SKILLS_CACHE_VERSION, + "skills": [], + "providers": {}, + } try: with open(self.sandbox_skills_cache_path, encoding="utf-8") as f: data = json.load(f) if not isinstance(data, dict): - return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []} + return { + "version": _SANDBOX_SKILLS_CACHE_VERSION, + "skills": [], + "providers": {}, + } skills = data.get("skills", []) if not isinstance(skills, list): skills = [] + providers = data.get("providers", {}) + if not isinstance(providers, dict): + providers = {} return { "version": int(data.get("version", _SANDBOX_SKILLS_CACHE_VERSION)), "skills": skills, + "providers": providers, "updated_at": data.get("updated_at"), } except Exception: - return {"version": _SANDBOX_SKILLS_CACHE_VERSION, "skills": []} + return { + "version": _SANDBOX_SKILLS_CACHE_VERSION, + "skills": [], + "providers": {}, + } def _save_sandbox_skills_cache(self, cache: dict) -> None: cache["version"] = _SANDBOX_SKILLS_CACHE_VERSION @@ -376,7 +398,11 @@ def _save_sandbox_skills_cache(self, cache: dict) -> None: with open(self.sandbox_skills_cache_path, "w", encoding="utf-8") as f: json.dump(cache, f, ensure_ascii=False, indent=2) - def set_sandbox_skills_cache(self, skills: list[dict]) -> None: + def set_sandbox_skills_cache( + self, + skills: list[dict], + provider_id: str | None = None, + ) -> None: """Persist sandbox skill metadata discovered from runtime side.""" deduped: dict[str, dict[str, str]] = {} for item in skills: @@ -394,16 +420,72 @@ def set_sandbox_skills_cache(self, skills: list[dict]) -> None: "description": description, "path": path, } - cache = { - "version": _SANDBOX_SKILLS_CACHE_VERSION, - "skills": [deduped[name] for name in sorted(deduped)], - } - self._save_sandbox_skills_cache(cache) + provider_key = _normalize_cache_provider_id(provider_id) + skills_payload = [deduped[name] for name in sorted(deduped)] + with _SANDBOX_SKILLS_CACHE_LOCK: + cache = self._load_sandbox_skills_cache() + providers = cache.get("providers", {}) + if not isinstance(providers, dict): + providers = {} + if provider_key: + providers[provider_key] = { + "skills": skills_payload, + } + else: + cache["skills"] = skills_payload + providers["default"] = { + "skills": skills_payload, + } + cache = { + "version": _SANDBOX_SKILLS_CACHE_VERSION, + "skills": cache.get("skills", []), + "providers": providers, + } + self._save_sandbox_skills_cache(cache) + + def _sandbox_cache_skills_for_provider( + self, cache: dict, provider_id: str | None + ) -> list[dict]: + provider_key = _normalize_cache_provider_id(provider_id) + providers = cache.get("providers", {}) + if provider_key and isinstance(providers, dict): + provider_cache = providers.get(provider_key) + if isinstance(provider_cache, dict): + skills = provider_cache.get("skills", []) + return skills if isinstance(skills, list) else [] + return [] + + skills = cache.get("skills", []) + if isinstance(skills, list): + return skills + return [] def get_sandbox_skills_cache_status(self) -> dict[str, object]: cache = self._load_sandbox_skills_cache() - skills = cache.get("skills", []) - count = len(skills) if isinstance(skills, list) else 0 + count = 0 + seen: set[str] = set() + for item in cache.get("skills", []): + if not isinstance(item, dict): + continue + name = str(item.get("name", "")).strip() + if name and name not in seen: + seen.add(name) + count += 1 + providers = cache.get("providers", {}) + if isinstance(providers, dict): + for provider_cache in providers.values(): + if not isinstance(provider_cache, dict): + continue + skills = provider_cache.get("skills", []) + if not isinstance(skills, list): + continue + for item in skills: + if not isinstance(item, dict): + continue + name = str(item.get("name", "")).strip() + if name and name not in seen: + seen.add(name) + count += 1 return { "exists": os.path.exists(self.sandbox_skills_cache_path), "ready": count > 0, @@ -416,6 +498,7 @@ def list_skills( *, active_only: bool = False, runtime: str = "local", + provider_id: str | None = None, show_sandbox_path: bool = True, ) -> list[SkillInfo]: """List all skills. @@ -432,7 +515,9 @@ def list_skills( sandbox_cached_paths: dict[str, str] = {} sandbox_cached_descriptions: dict[str, str] = {} cache_for_paths = self._load_sandbox_skills_cache() - for item in cache_for_paths.get("skills", []): + for item in self._sandbox_cache_skills_for_provider( + cache_for_paths, provider_id + ): if not isinstance(item, dict): continue name = str(item.get("name", "") or "").strip() @@ -527,8 +612,10 @@ def list_skills( ) if runtime == "sandbox": - cache = self._load_sandbox_skills_cache() - for item in cache.get("skills", []): + cache = self._sandbox_cache_skills_for_provider( + self._load_sandbox_skills_cache(), provider_id + ) + for item in cache: if not isinstance(item, dict): continue skill_name = str(item.get("name", "")).strip() @@ -574,14 +661,24 @@ def is_sandbox_only_skill(self, name: str) -> bool: if skill_md_exists: return False cache = self._load_sandbox_skills_cache() - skills = cache.get("skills", []) - if not isinstance(skills, list): - return False - for item in skills: + for item in cache.get("skills", []): if not isinstance(item, dict): continue if str(item.get("name", "")).strip() == name: return True + providers = cache.get("providers", {}) + if isinstance(providers, dict): + for provider_cache in providers.values(): + if not isinstance(provider_cache, dict): + continue + skills = provider_cache.get("skills", []) + if not isinstance(skills, list): + continue + for item in skills: + if not isinstance(item, dict): + continue + if str(item.get("name", "")).strip() == name: + return True return False def is_plugin_skill(self, name: str) -> bool: @@ -598,22 +695,47 @@ def set_skill_active(self, name: str, active: bool) -> None: self._save_config(config) def _remove_skill_from_sandbox_cache(self, name: str) -> None: - cache = self._load_sandbox_skills_cache() - skills = cache.get("skills", []) - if not isinstance(skills, list): - return - - filtered = [ - item - for item in skills - if not ( - isinstance(item, dict) and str(item.get("name", "")).strip() == name - ) - ] - - if len(filtered) != len(skills): - cache["skills"] = filtered - self._save_sandbox_skills_cache(cache) + with _SANDBOX_SKILLS_CACHE_LOCK: + cache = self._load_sandbox_skills_cache() + changed = False + skills = cache.get("skills", []) + if isinstance(skills, list): + filtered = [ + item + for item in skills + if not ( + isinstance(item, dict) + and str(item.get("name", "")).strip() == name + ) + ] + if len(filtered) != len(skills): + cache["skills"] = filtered + changed = True + + providers = cache.get("providers", {}) + if isinstance(providers, dict): + for provider_key, provider_cache in list(providers.items()): + if not isinstance(provider_cache, dict): + continue + provider_skills = provider_cache.get("skills", []) + if not isinstance(provider_skills, list): + continue + filtered = [ + item + for item in provider_skills + if not ( + isinstance(item, dict) + and str(item.get("name", "")).strip() == name + ) + ] + if len(filtered) != len(provider_skills): + provider_cache["skills"] = filtered + providers[provider_key] = provider_cache + changed = True + + if changed: + cache["providers"] = providers + self._save_sandbox_skills_cache(cache) def delete_skill(self, name: str) -> None: if self.is_sandbox_only_skill(name): diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 824c3b653b..738f6b557e 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -31,6 +31,7 @@ from astrbot.core.config.default import VERSION from astrbot.core.platform.register import unregister_platform_adapters_by_module from astrbot.core.provider.register import llm_tools +from astrbot.core.tools.registry import unregister_builtin_tools_by_module_prefix from astrbot.core.utils.astrbot_path import ( get_astrbot_config_path, get_astrbot_path, @@ -745,6 +746,14 @@ def _cleanup_plugin_state(self, dir_name: str, is_reserved: bool = False) -> Non llm_tools.func_list.remove(tool) logger.info(f"清理工具: {tool.name}") + for tool_name in llm_tools.clear_builtin_tool_cache_by_module_prefix( + module_prefix + ): + logger.info(f"清理内置工具缓存: {tool_name}") + + for tool_name in unregister_builtin_tools_by_module_prefix(module_prefix): + logger.info(f"清理内置工具注册: {tool_name}") + for adapter_name in unregister_platform_adapters_by_module(module_prefix): logger.info(f"清理平台适配器: {adapter_name}") @@ -1667,6 +1676,18 @@ async def _unbind_plugin(self, plugin_name: str, plugin_module_path: str) -> Non # module_path is like "data.plugins.my_plugin.main", extract prefix like "data.plugins.my_plugin" module_prefix = ".".join(plugin_module_path.split(".")[:-1]) if module_prefix: + for tool_name in llm_tools.clear_builtin_tool_cache_by_module_prefix( + module_prefix + ): + logger.info( + f"移除了插件 {plugin_name} 的内置工具缓存 {tool_name}", + ) + + for tool_name in unregister_builtin_tools_by_module_prefix(module_prefix): + logger.info( + f"移除了插件 {plugin_name} 的内置工具注册 {tool_name}", + ) + unregistered_adapters = unregister_platform_adapters_by_module( module_prefix ) diff --git a/astrbot/core/tools/computer_tools/__init__.py b/astrbot/core/tools/computer_tools/__init__.py index f90c2e1de8..16610e9e74 100644 --- a/astrbot/core/tools/computer_tools/__init__.py +++ b/astrbot/core/tools/computer_tools/__init__.py @@ -1,8 +1,3 @@ -from .cua import ( - CuaKeyboardTypeTool, - CuaMouseClickTool, - CuaScreenshotTool, -) from .fs import ( FileDownloadTool, FileEditTool, @@ -12,52 +7,23 @@ GrepTool, ) from .python import LocalPythonTool, PythonTool +from .sandbox import SandboxLifecycleTool, SandboxOperationTool, SandboxQueryTool from .shell import ExecuteShellTool -from .shipyard_neo import ( - AnnotateExecutionTool, - BrowserBatchExecTool, - BrowserExecTool, - CreateSkillCandidateTool, - CreateSkillPayloadTool, - EvaluateSkillCandidateTool, - GetExecutionHistoryTool, - GetSkillPayloadTool, - ListSkillCandidatesTool, - ListSkillReleasesTool, - PromoteSkillCandidateTool, - RollbackSkillReleaseTool, - RunBrowserSkillTool, - SyncSkillReleaseTool, -) from .util import check_admin_permission, normalize_umo_for_workspace __all__ = [ - "AnnotateExecutionTool", - "BrowserBatchExecTool", - "BrowserExecTool", - "CreateSkillCandidateTool", - "CreateSkillPayloadTool", - "CuaKeyboardTypeTool", - "CuaMouseClickTool", - "CuaScreenshotTool", - "EvaluateSkillCandidateTool", "ExecuteShellTool", "FileDownloadTool", "FileEditTool", "FileReadTool", "FileUploadTool", "FileWriteTool", - "GetExecutionHistoryTool", - "GetSkillPayloadTool", "GrepTool", - "ListSkillCandidatesTool", - "ListSkillReleasesTool", "LocalPythonTool", - "PromoteSkillCandidateTool", "PythonTool", - "RollbackSkillReleaseTool", - "RunBrowserSkillTool", - "SyncSkillReleaseTool", + "SandboxQueryTool", + "SandboxLifecycleTool", + "SandboxOperationTool", "normalize_umo_for_workspace", "check_admin_permission", ] diff --git a/astrbot/core/tools/computer_tools/cua.py b/astrbot/core/tools/computer_tools/cua.py deleted file mode 100644 index 7b37a55086..0000000000 --- a/astrbot/core/tools/computer_tools/cua.py +++ /dev/null @@ -1,177 +0,0 @@ -from __future__ import annotations - -import json -import uuid -from dataclasses import dataclass, field -from pathlib import Path -from typing import Any - -import mcp - -from astrbot.api import FunctionTool -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import ToolExecResult -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.computer.computer_client import get_booter -from astrbot.core.message.message_event_result import MessageChain -from astrbot.core.tools.computer_tools.util import check_admin_permission -from astrbot.core.tools.registry import builtin_tool -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -_CUA_TOOL_CONFIG = { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "cua", -} - - -def _to_json(data: Any) -> str: - return json.dumps(data, ensure_ascii=False, default=str) - - -def _exception_detail(error: Exception) -> str: - return str(error) or type(error).__name__ - - -async def _get_gui_component(context: ContextWrapper[AstrAgentContext]) -> Any: - booter = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - gui = getattr(booter, "gui", None) - if gui is None: - raise RuntimeError( - "Current sandbox booter does not support CUA GUI capability. " - "Please switch sandbox booter to cua." - ) - return gui - - -@builtin_tool(config=_CUA_TOOL_CONFIG) -@dataclass -class CuaScreenshotTool(FunctionTool): - name: str = "astrbot_cua_screenshot" - description: str = ( - "Capture a screenshot from the CUA sandbox and optionally send it to the user." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "send_to_user": { - "type": "boolean", - "description": "Whether to send the screenshot image to the current conversation.", - "default": True, - }, - "return_image_to_llm": { - "type": "boolean", - "description": "Whether to include the screenshot image content in the tool result for model inspection.", - "default": True, - }, - }, - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - send_to_user: bool = True, - return_image_to_llm: bool = True, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Taking CUA screenshots"): - return err - try: - gui = await _get_gui_component(context) - path = _new_screenshot_path(context.context.event.unified_msg_origin) - result = await gui.screenshot(path) - payload = {"success": True, **result, "path": path} - if send_to_user: - await context.context.event.send(MessageChain().file_image(path)) - payload["sent_to_user"] = True - image_data = payload.pop("base64", "") - content: list[mcp.types.TextContent | mcp.types.ImageContent] = [ - mcp.types.TextContent(type="text", text=_to_json(payload)) - ] - if return_image_to_llm: - content.append( - mcp.types.ImageContent( - type="image", - data=str(image_data), - mimeType=str(payload.get("mime_type", "image/png")), - ) - ) - return mcp.types.CallToolResult(content=content) - except Exception as e: - return f"Error taking CUA screenshot: {_exception_detail(e)}" - - -@builtin_tool(config=_CUA_TOOL_CONFIG) -@dataclass -class CuaMouseClickTool(FunctionTool): - name: str = "astrbot_cua_mouse_click" - description: str = "Click a coordinate in the CUA sandbox desktop." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "x": {"type": "integer", "description": "X coordinate."}, - "y": {"type": "integer", "description": "Y coordinate."}, - "button": { - "type": "string", - "description": "Mouse button, usually left, right, or middle.", - "default": "left", - }, - }, - "required": ["x", "y"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - x: int, - y: int, - button: str = "left", - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using CUA mouse"): - return err - try: - gui = await _get_gui_component(context) - return _to_json(await gui.click(x, y, button=button)) - except Exception as e: - return f"Error clicking CUA desktop: {_exception_detail(e)}" - - -@builtin_tool(config=_CUA_TOOL_CONFIG) -@dataclass -class CuaKeyboardTypeTool(FunctionTool): - name: str = "astrbot_cua_keyboard_type" - description: str = "Type text into the CUA sandbox desktop." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "text": {"type": "string", "description": "Text to type."}, - }, - "required": ["text"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - text: str, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using CUA keyboard"): - return err - try: - gui = await _get_gui_component(context) - return _to_json(await gui.type_text(text)) - except Exception as e: - return f"Error typing in CUA desktop: {_exception_detail(e)}" - - -def _new_screenshot_path(umo: str) -> str: - safe_prefix = uuid.uuid5(uuid.NAMESPACE_DNS, umo).hex[:12] - screenshot_dir = Path(get_astrbot_temp_path()) / "cua_screenshots" - screenshot_dir.mkdir(parents=True, exist_ok=True) - return str(screenshot_dir / f"{safe_prefix}-{uuid.uuid4().hex}.png") diff --git a/astrbot/core/tools/computer_tools/fs.py b/astrbot/core/tools/computer_tools/fs.py index 5660022fd0..a44dc1f432 100644 --- a/astrbot/core/tools/computer_tools/fs.py +++ b/astrbot/core/tools/computer_tools/fs.py @@ -70,6 +70,10 @@ _IMAGE_FILE_SUFFIXES = {".bmp", ".gif", ".jpeg", ".jpg", ".png", ".webp"} +def _remote_basename(path: str) -> str: + return path.replace("\\", "/").rstrip("/").split("/")[-1] + + def _restricted_env_path_labels(umo: str, *, include_plugin_skills: bool) -> list[str]: """Labels for the allowed directories in a local(not sandbox) and restricted(not admin) environment""" normalized_umo = normalize_umo_for_workspace(umo) @@ -772,7 +776,7 @@ async def call( context.context.event.unified_msg_origin, ) try: - name = os.path.basename(remote_path) + name = _remote_basename(remote_path) or os.path.basename(remote_path) local_path = os.path.join( get_astrbot_temp_path(), f"sandbox_{uuid.uuid4().hex[:4]}_{name}" @@ -784,7 +788,9 @@ async def call( if also_send_to_user: try: - name = os.path.basename(local_path) + # Keep the user-facing filename stable; the local temp path + # still carries a random prefix to avoid collisions. + name = _remote_basename(remote_path) or os.path.basename(local_path) if Path(local_path).suffix.lower() in _IMAGE_FILE_SUFFIXES: message_component = Image.fromFileSystem(local_path) sent_as = "image" diff --git a/astrbot/core/tools/computer_tools/sandbox.py b/astrbot/core/tools/computer_tools/sandbox.py new file mode 100644 index 0000000000..29421dcabd --- /dev/null +++ b/astrbot/core/tools/computer_tools/sandbox.py @@ -0,0 +1,779 @@ +import json +import time +import uuid +from dataclasses import dataclass, field +from datetime import datetime +from pathlib import Path + +import mcp + +from astrbot.api import FunctionTool +from astrbot.core.agent.run_context import ContextWrapper +from astrbot.core.agent.tool import ToolExecResult +from astrbot.core.astr_agent_context import AstrAgentContext +from astrbot.core.computer import computer_client +from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.utils.astrbot_path import get_astrbot_temp_path + +from ..registry import builtin_tool +from .util import check_admin_permission + +_SANDBOX_RUNTIME_TOOL_CONFIG = { + "provider_settings.computer_use_runtime": "sandbox", +} + + +def _dump(data) -> str: + return json.dumps(data, ensure_ascii=False, default=str) + + +def _remote_basename(path: str) -> str: + return path.replace("\\", "/").rstrip("/").split("/")[-1] + + +def _format_agent_time(value: int | float | None) -> str | None: + if value is None: + return None + if isinstance(value, bool): + return str(value) + if not isinstance(value, (int, float)): + return str(value) + return ( + datetime.fromtimestamp(float(value)) + .astimezone() + .strftime("%Y-%m-%d %H:%M:%S %Z") + ) + + +def _format_sandbox_for_agent(value): + if isinstance(value, list): + return [_format_sandbox_for_agent(item) for item in value] + if not isinstance(value, dict): + return value + formatted = {} + for key, item in value.items(): + if key.endswith("_at"): + formatted[key] = _format_agent_time(item) + else: + formatted[key] = _format_sandbox_for_agent(item) + return formatted + + +def _lease_metadata_for_agent( + context: ContextWrapper[AstrAgentContext], + sandbox_id: str, + record: dict | None = None, +) -> dict | None: + if not sandbox_id: + return None + manager = _sandbox_manager() + if record is None: + registry = getattr(manager, "registry", None) + get_sandbox = getattr(registry, "get_sandbox", None) + record = get_sandbox(sandbox_id) if callable(get_sandbox) else None + if not record: + return None + lease_expires_at = record.get("lease_expires_at") + lease_expires_in_seconds = None + if isinstance(lease_expires_at, (int, float)) and not isinstance( + lease_expires_at, bool + ): + lease_expires_in_seconds = max(0, int(float(lease_expires_at) - time.time())) + return { + "sandbox_id": sandbox_id, + "lease_expires_at": _format_agent_time(lease_expires_at), + "lease_expires_in_seconds": lease_expires_in_seconds, + "auto_renew_interval_seconds": _lease_timeout_for_agent(context), + } + + +def _lease_timeout_for_agent(context: ContextWrapper[AstrAgentContext]) -> float: + manager = _sandbox_manager() + lease_timeout = getattr(manager, "_lease_timeout", None) + if callable(lease_timeout): + return lease_timeout( + context.context.context, + context.context.event.unified_msg_origin, + ) + return float(_sandbox_config(context).get("sandbox_lease_timeout", 600)) + + +def _attach_lease_metadata( + payload: dict, + context: ContextWrapper[AstrAgentContext], + sandbox_id: str | None, + record: dict | None = None, +) -> dict: + if not sandbox_id: + return payload + lease = _lease_metadata_for_agent(context, sandbox_id, record) + if lease is not None: + payload["lease"] = lease + return payload + + +def _sandbox_manager(): + return computer_client.sandbox_manager + + +def _current_provider_id(context: ContextWrapper[AstrAgentContext]) -> str: + plugin_context = context.context.context + session_id = context.context.event.unified_msg_origin + config = plugin_context.get_config(umo=session_id) + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + return str(sandbox_cfg.get("booter", "")).strip() + + +def _is_admin(context: ContextWrapper[AstrAgentContext]) -> bool: + return context.context.event.role == "admin" + + +def _sandbox_config(context: ContextWrapper[AstrAgentContext]) -> dict: + config = context.context.context.get_config( + umo=context.context.event.unified_msg_origin + ) + sandbox_cfg = config.get("provider_settings", {}).get("sandbox", {}) + return sandbox_cfg if isinstance(sandbox_cfg, dict) else {} + + +def _member_sandbox_permission_enabled( + context: ContextWrapper[AstrAgentContext], permission: str +) -> bool: + permissions = _sandbox_config(context).get("member_permissions", {}) + if not isinstance(permissions, dict): + return False + return bool(permissions.get(permission, False)) + + +def _check_basic_sandbox_permission( + context: ContextWrapper[AstrAgentContext], operation_name: str +) -> str | None: + return check_admin_permission(context, operation_name) + + +def _check_member_sandbox_permission( + context: ContextWrapper[AstrAgentContext], operation_name: str, permission: str +) -> str | None: + if permission_error := check_admin_permission(context, operation_name): + return permission_error + if _is_admin(context) or _member_sandbox_permission_enabled(context, permission): + return None + return ( + f"error: Permission denied. {operation_name} is disabled for non-admin users " + "by sandbox member permission settings." + ) + + +def _visible_to_session(record: dict, session_id: str) -> bool: + return record.get("controller_session_id") == session_id or _is_idle_sandbox(record) + + +def _is_idle_sandbox(record: dict) -> bool: + controller_session_id = record.get("controller_session_id") + if not controller_session_id: + return True + lease_expires_at = record.get("lease_expires_at") + return bool(lease_expires_at and lease_expires_at <= time.time()) + + +def _sandbox_status_for_session(record: dict, session_id: str) -> str: + controller_session_id = record.get("controller_session_id") + if controller_session_id == session_id: + return "current" + if controller_session_id and not _is_idle_sandbox(record): + return "occupied" + return "idle" + + +def _redact_sandbox_for_session(record: dict, session_id: str, *, admin: bool) -> dict: + visible = dict(record) + visible["access"] = { + "status": _sandbox_status_for_session(record, session_id), + "can_switch": _visible_to_session(record, session_id), + "occupied": not _is_idle_sandbox(record), + } + if admin: + return visible + visible.pop("connect_info", None) + visible["owner_session_id"] = None + visible["owner_user_id"] = None + visible["created_by_session_id"] = None + visible["created_by_user_id"] = None + if record.get("controller_session_id") != session_id: + visible["controller_session_id"] = None + visible["controller_user_id"] = None + return visible + + +def _sandbox_access_denied( + context: ContextWrapper[AstrAgentContext], record: dict | None +) -> str | None: + if record is None or _is_admin(context): + return None + session_id = context.context.event.unified_msg_origin + if _visible_to_session(record, session_id): + return None + return "error: Permission denied. This sandbox belongs to another session." + + +async def _query_list_sandboxes( + context: ContextWrapper[AstrAgentContext], +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Listing sandboxes" + ): + return permission_error + session_id = context.context.event.unified_msg_origin + manager = _sandbox_manager() + list_checked = getattr(manager, "list_sandboxes_checked", None) + if callable(list_checked): + sandboxes = await list_checked() + else: + sandboxes = manager.list_sandboxes() + sandboxes = [ + _redact_sandbox_for_session(record, session_id, admin=_is_admin(context)) + for record in sandboxes + ] + return _dump({"sandboxes": _format_sandbox_for_agent(sandboxes)}) + + +async def _query_list_providers( + context: ContextWrapper[AstrAgentContext], +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Listing sandbox providers" + ): + return permission_error + return _dump({"providers": computer_client.list_sandbox_providers()}) + + +async def _query_get_current( + context: ContextWrapper[AstrAgentContext], +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Getting current sandbox" + ): + return permission_error + session_id = context.context.event.unified_msg_origin + current = _sandbox_manager().get_current_sandbox(session_id) + payload = _format_sandbox_for_agent(current) + return _dump( + _attach_lease_metadata( + payload, + context, + current.get("current_sandbox_id"), + current.get("sandbox"), + ) + ) + + +async def _lifecycle_create( + context: ContextWrapper[AstrAgentContext], + sandbox_name: str = "", + provider_id: str = "", +) -> ToolExecResult: + if permission_error := _check_member_sandbox_permission( + context, "Creating sandbox", "create" + ): + return permission_error + + plugin_context = context.context.context + session_id = context.context.event.unified_msg_origin + requested_provider_id = str(provider_id).strip().lower() + if requested_provider_id: + provider_id = requested_provider_id + else: + provider_id = _current_provider_id(context) + if not provider_id: + return "Error creating sandbox: sandbox booter is not configured." + manager = _sandbox_manager() + if provider_id not in manager.providers: + providers = computer_client.list_sandbox_providers() + available = ", ".join(p["provider_id"] for p in providers) or "none" + return ( + f"Error creating sandbox: sandbox provider '{provider_id}' is not " + f"available. Available providers: {available}." + ) + + try: + sandbox = await manager.create_sandbox( + plugin_context, + session_id, + provider_id, + sandbox_name=sandbox_name.strip() or None, + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error creating sandbox: {detail}" + + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_switch( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Switching sandbox" + ): + return permission_error + if not sandbox_id: + return "Error switching sandbox: sandbox_id is required." + session_id = context.context.event.unified_msg_origin + manager = _sandbox_manager() + record = manager.registry.get_sandbox(sandbox_id) + if permission_error := _sandbox_access_denied(context, record): + return permission_error + try: + sandbox = await manager.switch_current_sandbox_checked( + session_id, sandbox_id, context=context.context.context + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error switching sandbox: {detail}" + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_release( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Releasing sandbox" + ): + return permission_error + session_id = context.context.event.unified_msg_origin + try: + sandbox = _sandbox_manager().release_current_sandbox( + session_id, sandbox_id.strip() or None + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error releasing sandbox: {detail}" + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_set_retention( + context: ContextWrapper[AstrAgentContext], + retention_policy: str = "", + sandbox_id: str = "", + sandbox_name: str = "", +) -> ToolExecResult: + if permission_error := _check_member_sandbox_permission( + context, "Changing sandbox retention policy", "set_retention_policy" + ): + return permission_error + if not retention_policy: + return "Error changing sandbox retention policy: retention_policy is required." + manager = _sandbox_manager() + session_id = context.context.event.unified_msg_origin + target_sandbox_id = sandbox_id.strip() + if not target_sandbox_id: + current = manager.get_current_sandbox(session_id) + target_sandbox_id = current.get("current_sandbox_id") or "" + if not target_sandbox_id: + return "Error changing sandbox retention policy: No current sandbox" + record = manager.registry.get_sandbox(target_sandbox_id) + if permission_error := _sandbox_access_denied(context, record): + return permission_error + try: + sandbox = manager.set_sandbox_retention_policy( + context.context.context, + session_id, + target_sandbox_id, + retention_policy.strip().lower(), + sandbox_name=sandbox_name.strip() or None, + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error changing sandbox retention policy: {detail}" + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_renew_lease( + context: ContextWrapper[AstrAgentContext], + ttl_seconds: int | float | None = None, +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Renewing sandbox lease" + ): + return permission_error + session_id = context.context.event.unified_msg_origin + try: + sandbox = await _sandbox_manager().renew_current_sandbox_lease( + session_id, ttl_seconds=ttl_seconds, context=context.context.context + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error renewing sandbox lease: {detail}" + payload = {"sandbox": _format_sandbox_for_agent(sandbox)} + return _dump( + _attach_lease_metadata(payload, context, sandbox.get("sandbox_id"), sandbox) + ) + + +async def _lifecycle_takeover( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> ToolExecResult: + if permission_error := _check_member_sandbox_permission( + context, "Taking over sandbox", "takeover" + ): + return permission_error + if not sandbox_id: + return "Error taking over sandbox: sandbox_id is required." + session_id = context.context.event.unified_msg_origin + try: + sandbox = await _sandbox_manager().takeover_sandbox( + session_id, sandbox_id, context=context.context.context + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error taking over sandbox: {detail}" + return _dump({"sandbox": _format_sandbox_for_agent(sandbox)}) + + +async def _lifecycle_destroy( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> ToolExecResult: + if permission_error := _check_member_sandbox_permission( + context, "Destroying sandbox", "destroy" + ): + return permission_error + if not sandbox_id: + return "Error destroying sandbox: sandbox_id is required." + session_id = context.context.event.unified_msg_origin + try: + sandbox = await _sandbox_manager().destroy_sandbox(session_id, sandbox_id) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error destroying sandbox: {detail}" + return _dump({"sandbox": _format_sandbox_for_agent(sandbox)}) + + +def _current_sandbox_id_for_operation( + context: ContextWrapper[AstrAgentContext], sandbox_id: str = "" +) -> str: + target_sandbox_id = sandbox_id.strip() + if target_sandbox_id: + return target_sandbox_id + current = _sandbox_manager().get_current_sandbox( + context.context.event.unified_msg_origin + ) + return str(current.get("current_sandbox_id") or "").strip() + + +async def _operation_capture_screenshot( + context: ContextWrapper[AstrAgentContext], + sandbox_id: str = "", + send_to_user: bool = False, + return_image_to_llm: bool = False, +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Sandbox screenshot capture" + ): + return permission_error + target_sandbox_id = _current_sandbox_id_for_operation(context, sandbox_id) + if not target_sandbox_id: + return "Error taking sandbox screenshot: No current sandbox" + try: + booter = await _sandbox_manager().get_observer_booter_by_id( + target_sandbox_id, + context.context.event.unified_msg_origin, + context=context.context.context, + ) + gui = getattr(booter, "gui", None) + if gui is None: + return f"Error taking sandbox screenshot: sandbox {target_sandbox_id} does not support screenshots." + screenshot_dir = Path(get_astrbot_temp_path()) / "sandbox_screenshots" + screenshot_dir.mkdir(parents=True, exist_ok=True) + path = str(screenshot_dir / f"{uuid.uuid4().hex}.png") + result = await gui.screenshot(path) + payload = {"sandbox_id": target_sandbox_id, "path": path, **result} + if send_to_user: + await context.context.event.send(MessageChain().file_image(path)) + payload["sent_to_user"] = True + image_data = payload.pop("base64", "") + payload = _attach_lease_metadata(payload, context, target_sandbox_id) + if return_image_to_llm: + content: list[mcp.types.TextContent | mcp.types.ImageContent] = [ + mcp.types.TextContent(type="text", text=_dump(payload)) + ] + if image_data: + content.append( + mcp.types.ImageContent( + type="image", + data=str(image_data), + mimeType=str(payload.get("mime_type", "image/png")), + ) + ) + return mcp.types.CallToolResult(content=content) + if image_data: + payload["base64"] = image_data + return _dump(payload) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error taking sandbox screenshot: {detail}" + + +async def _operation_copy_file( + context: ContextWrapper[AstrAgentContext], + source_sandbox_id: str = "", + source_path: str = "", + target_sandbox_id: str = "", + target_path: str = "", +) -> ToolExecResult: + if permission_error := _check_basic_sandbox_permission( + context, "Copying files between sandboxes" + ): + return permission_error + if not all([source_sandbox_id, source_path, target_sandbox_id, target_path]): + return "Error copying file between sandboxes: source_sandbox_id, source_path, target_sandbox_id, and target_path are required." + try: + manager = _sandbox_manager() + session_id = context.context.event.unified_msg_origin + source = await manager.get_observer_booter_by_id( + source_sandbox_id, session_id, context=context.context.context + ) + target = await manager.get_observer_booter_by_id( + target_sandbox_id, session_id, context=context.context.context + ) + temp_dir = Path(get_astrbot_temp_path()) / "sandbox_copy" + temp_dir.mkdir(parents=True, exist_ok=True) + local_path = temp_dir / f"{uuid.uuid4().hex}-{_remote_basename(target_path)}" + try: + await source.download_file(source_path, str(local_path)) + upload_result = await target.upload_file(str(local_path), target_path) + finally: + try: + local_path.unlink(missing_ok=True) + except OSError: + pass + return _dump( + _attach_lease_metadata( + { + "source_sandbox_id": source_sandbox_id, + "source_path": source_path, + "target_sandbox_id": target_sandbox_id, + "target_path": target_path, + "upload_result": upload_result, + }, + context, + target_sandbox_id, + ) + ) + except Exception as e: + detail = str(e) or type(e).__name__ + return f"Error copying file between sandboxes: {detail}" + + +@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG) +@dataclass +class SandboxQueryTool(FunctionTool): + name: str = "astrbot_sandbox_query" + description: str = ( + "Query managed sandboxes, the current sandbox, or loaded sandbox providers. " + "Actions: list_sandboxes has no extra parameters; get_current has no extra parameters; " + "list_providers has no extra parameters. Use list_sandboxes before creating a new sandbox " + "when you need to find a reusable one." + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["list_sandboxes", "get_current", "list_providers"], + "description": "Query action to perform.", + } + }, + "required": ["action"], + } + ) + + async def call( + self, context: ContextWrapper[AstrAgentContext], action: str + ) -> ToolExecResult: + match action: + case "list_sandboxes": + return await _query_list_sandboxes(context) + case "get_current": + return await _query_get_current(context) + case "list_providers": + return await _query_list_providers(context) + return f"Error querying sandbox: unsupported action '{action}'." + + +@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG) +@dataclass +class SandboxLifecycleTool(FunctionTool): + name: str = "astrbot_sandbox_lifecycle" + description: str = ( + "Manage sandbox lifecycle and session occupancy: create, switch, release, " + "renew lease, set retention, takeover, or destroy a sandbox. " + "Actions: create accepts sandbox_name and provider_id; switch requires sandbox_id; " + "release accepts optional sandbox_id; renew_lease accepts ttl_seconds; " + "set_retention requires retention_policy and accepts sandbox_id/sandbox_name; " + "takeover requires sandbox_id; destroy requires sandbox_id." + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": [ + "create", + "switch", + "release", + "renew_lease", + "set_retention", + "takeover", + "destroy", + ], + "description": "Lifecycle action to perform.", + }, + "sandbox_id": { + "type": "string", + "description": "Target sandbox ID for switch, release, retention, takeover, or destroy.", + }, + "sandbox_name": { + "type": "string", + "description": "Optional sandbox name for create or set_retention.", + }, + "provider_id": { + "type": "string", + "description": "Optional provider ID for create. Defaults to configured sandbox booter.", + }, + "retention_policy": { + "type": "string", + "enum": ["persistent", "temporary"], + "description": "Target retention policy for set_retention.", + }, + "ttl_seconds": { + "type": "number", + "description": "Optional lease duration for renew_lease.", + }, + }, + "required": ["action"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + action: str, + sandbox_id: str = "", + sandbox_name: str = "", + provider_id: str = "", + retention_policy: str = "", + ttl_seconds: int | float | None = None, + ) -> ToolExecResult: + match action: + case "create": + return await _lifecycle_create(context, sandbox_name, provider_id) + case "switch": + return await _lifecycle_switch(context, sandbox_id) + case "release": + return await _lifecycle_release(context, sandbox_id) + case "renew_lease": + return await _lifecycle_renew_lease(context, ttl_seconds) + case "set_retention": + return await _lifecycle_set_retention( + context, retention_policy, sandbox_id, sandbox_name + ) + case "takeover": + return await _lifecycle_takeover(context, sandbox_id) + case "destroy": + return await _lifecycle_destroy(context, sandbox_id) + return f"Error managing sandbox lifecycle: unsupported action '{action}'." + + +@builtin_tool(config=_SANDBOX_RUNTIME_TOOL_CONFIG) +@dataclass +class SandboxOperationTool(FunctionTool): + name: str = "astrbot_sandbox_operation" + description: str = ( + "Run standard sandbox operations. Actions: capture_screenshot accepts sandbox_id, " + "send_to_user, and return_image_to_llm; copy_file requires source_sandbox_id, " + "source_path, target_sandbox_id, and target_path." + ) + parameters: dict = field( + default_factory=lambda: { + "type": "object", + "properties": { + "action": { + "type": "string", + "enum": ["capture_screenshot", "copy_file"], + "description": "Sandbox operation to perform.", + }, + "sandbox_id": { + "type": "string", + "description": "Target sandbox ID for capture_screenshot. Defaults to current sandbox.", + }, + "send_to_user": { + "type": "boolean", + "description": "Whether capture_screenshot should send the image to the current conversation.", + "default": False, + }, + "return_image_to_llm": { + "type": "boolean", + "description": "Whether capture_screenshot should include image content in the tool result for model inspection.", + "default": False, + }, + "source_sandbox_id": { + "type": "string", + "description": "Source sandbox ID for copy_file.", + }, + "source_path": { + "type": "string", + "description": "Source path for copy_file.", + }, + "target_sandbox_id": { + "type": "string", + "description": "Target sandbox ID for copy_file.", + }, + "target_path": { + "type": "string", + "description": "Target path for copy_file.", + }, + }, + "required": ["action"], + } + ) + + async def call( + self, + context: ContextWrapper[AstrAgentContext], + action: str, + sandbox_id: str = "", + send_to_user: bool = False, + return_image_to_llm: bool = False, + source_sandbox_id: str = "", + source_path: str = "", + target_sandbox_id: str = "", + target_path: str = "", + ) -> ToolExecResult: + match action: + case "capture_screenshot": + return await _operation_capture_screenshot( + context, sandbox_id, send_to_user, return_image_to_llm + ) + case "copy_file": + return await _operation_copy_file( + context, + source_sandbox_id, + source_path, + target_sandbox_id, + target_path, + ) + return f"Error running sandbox operation: unsupported action '{action}'." diff --git a/astrbot/core/tools/computer_tools/shipyard_neo/__init__.py b/astrbot/core/tools/computer_tools/shipyard_neo/__init__.py deleted file mode 100644 index 9228c86354..0000000000 --- a/astrbot/core/tools/computer_tools/shipyard_neo/__init__.py +++ /dev/null @@ -1,31 +0,0 @@ -from .browser import BrowserBatchExecTool, BrowserExecTool, RunBrowserSkillTool -from .neo_skills import ( - AnnotateExecutionTool, - CreateSkillCandidateTool, - CreateSkillPayloadTool, - EvaluateSkillCandidateTool, - GetExecutionHistoryTool, - GetSkillPayloadTool, - ListSkillCandidatesTool, - ListSkillReleasesTool, - PromoteSkillCandidateTool, - RollbackSkillReleaseTool, - SyncSkillReleaseTool, -) - -__all__ = [ - "AnnotateExecutionTool", - "BrowserBatchExecTool", - "BrowserExecTool", - "CreateSkillCandidateTool", - "CreateSkillPayloadTool", - "EvaluateSkillCandidateTool", - "GetExecutionHistoryTool", - "GetSkillPayloadTool", - "ListSkillCandidatesTool", - "ListSkillReleasesTool", - "PromoteSkillCandidateTool", - "RollbackSkillReleaseTool", - "RunBrowserSkillTool", - "SyncSkillReleaseTool", -] diff --git a/astrbot/core/tools/computer_tools/shipyard_neo/browser.py b/astrbot/core/tools/computer_tools/shipyard_neo/browser.py deleted file mode 100644 index b4b7f4fd06..0000000000 --- a/astrbot/core/tools/computer_tools/shipyard_neo/browser.py +++ /dev/null @@ -1,204 +0,0 @@ -import json -from dataclasses import dataclass, field -from typing import Any - -from astrbot.api import FunctionTool -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import ToolExecResult -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.computer.computer_client import get_booter -from astrbot.core.tools.computer_tools.util import check_admin_permission -from astrbot.core.tools.registry import builtin_tool - -_SHIPYARD_NEO_TOOL_CONFIG = { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", -} - - -def _to_json(data: Any) -> str: - return json.dumps(data, ensure_ascii=False, default=str) - - -async def _get_browser_component(context: ContextWrapper[AstrAgentContext]) -> Any: - booter = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - browser = getattr(booter, "browser", None) - if browser is None: - raise RuntimeError( - "Current sandbox booter does not support browser capability. " - "Please switch to shipyard_neo." - ) - return browser - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class BrowserExecTool(FunctionTool): - name: str = "astrbot_execute_browser" - description: str = "Execute one browser automation command in the sandbox." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "cmd": {"type": "string", "description": "Browser command to execute."}, - "timeout": {"type": "integer", "default": 30}, - "description": { - "type": "string", - "description": "Optional execution description.", - }, - "tags": {"type": "string", "description": "Optional tags."}, - "learn": { - "type": "boolean", - "description": "Whether to mark execution as learn evidence.", - "default": False, - }, - "include_trace": { - "type": "boolean", - "description": "Whether to include trace_ref in response.", - "default": False, - }, - }, - "required": ["cmd"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - cmd: str, - timeout: int = 30, - description: str | None = None, - tags: str | None = None, - learn: bool = False, - include_trace: bool = False, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using browser tools"): - return err - try: - browser = await _get_browser_component(context) - result = await browser.exec( - cmd=cmd, - timeout=timeout, - description=description, - tags=tags, - learn=learn, - include_trace=include_trace, - ) - return _to_json(result) - except Exception as e: - return f"Error executing browser command: {str(e)}" - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class BrowserBatchExecTool(FunctionTool): - name: str = "astrbot_execute_browser_batch" - description: str = "Execute a browser command batch in the sandbox." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "commands": { - "type": "array", - "items": {"type": "string"}, - "description": "Ordered browser commands.", - }, - "timeout": {"type": "integer", "default": 60}, - "stop_on_error": {"type": "boolean", "default": True}, - "description": { - "type": "string", - "description": "Optional execution description.", - }, - "tags": {"type": "string", "description": "Optional tags."}, - "learn": { - "type": "boolean", - "description": "Whether to mark execution as learn evidence.", - "default": False, - }, - "include_trace": { - "type": "boolean", - "description": "Whether to include trace_ref in response.", - "default": False, - }, - }, - "required": ["commands"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - commands: list[str], - timeout: int = 60, - stop_on_error: bool = True, - description: str | None = None, - tags: str | None = None, - learn: bool = False, - include_trace: bool = False, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using browser tools"): - return err - try: - browser = await _get_browser_component(context) - result = await browser.exec_batch( - commands=commands, - timeout=timeout, - stop_on_error=stop_on_error, - description=description, - tags=tags, - learn=learn, - include_trace=include_trace, - ) - return _to_json(result) - except Exception as e: - return f"Error executing browser batch command: {str(e)}" - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class RunBrowserSkillTool(FunctionTool): - name: str = "astrbot_run_browser_skill" - description: str = "Run a released browser skill in the sandbox by skill_key." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "skill_key": {"type": "string"}, - "timeout": {"type": "integer", "default": 60}, - "stop_on_error": {"type": "boolean", "default": True}, - "include_trace": {"type": "boolean", "default": False}, - "description": {"type": "string"}, - "tags": {"type": "string"}, - }, - "required": ["skill_key"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - skill_key: str, - timeout: int = 60, - stop_on_error: bool = True, - include_trace: bool = False, - description: str | None = None, - tags: str | None = None, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using browser tools"): - return err - try: - browser = await _get_browser_component(context) - result = await browser.run_skill( - skill_key=skill_key, - timeout=timeout, - stop_on_error=stop_on_error, - include_trace=include_trace, - description=description, - tags=tags, - ) - return _to_json(result) - except Exception as e: - return f"Error running browser skill: {str(e)}" diff --git a/astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py b/astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py deleted file mode 100644 index e2c4f59093..0000000000 --- a/astrbot/core/tools/computer_tools/shipyard_neo/neo_skills.py +++ /dev/null @@ -1,556 +0,0 @@ -import json -from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field -from typing import Any - -from astrbot.api import FunctionTool -from astrbot.core.agent.run_context import ContextWrapper -from astrbot.core.agent.tool import ToolExecResult -from astrbot.core.astr_agent_context import AstrAgentContext -from astrbot.core.computer.computer_client import get_booter -from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager -from astrbot.core.tools.computer_tools.util import check_admin_permission -from astrbot.core.tools.registry import builtin_tool - -_SHIPYARD_NEO_TOOL_CONFIG = { - "provider_settings.computer_use_runtime": "sandbox", - "provider_settings.sandbox.booter": "shipyard_neo", -} - - -def _to_jsonable(model_like: Any) -> Any: - if isinstance(model_like, dict): - return model_like - if isinstance(model_like, list): - return [_to_jsonable(i) for i in model_like] - if hasattr(model_like, "model_dump"): - return _to_jsonable(model_like.model_dump()) - return model_like - - -def _to_json_text(data: Any) -> str: - return json.dumps(_to_jsonable(data), ensure_ascii=False, default=str) - - -async def _get_neo_context( - context: ContextWrapper[AstrAgentContext], -) -> tuple[Any, Any]: - booter = await get_booter( - context.context.context, - context.context.event.unified_msg_origin, - ) - client = getattr(booter, "bay_client", None) - sandbox = getattr(booter, "sandbox", None) - if client is None or sandbox is None: - raise RuntimeError( - "Current sandbox booter does not support Neo skill lifecycle APIs. " - "Please switch to shipyard_neo." - ) - return client, sandbox - - -@dataclass -class NeoSkillToolBase(FunctionTool): - error_prefix: str = "Error" - - async def _run( - self, - context: ContextWrapper[AstrAgentContext], - neo_call: Callable[[Any, Any], Awaitable[Any]], - error_action: str, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using skill lifecycle tools"): - return err - try: - client, sandbox = await _get_neo_context(context) - result = await neo_call(client, sandbox) - return _to_json_text(result) - except Exception as e: - return f"{self.error_prefix} {error_action}: {str(e)}" - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class GetExecutionHistoryTool(NeoSkillToolBase): - name: str = "astrbot_get_execution_history" - description: str = "Get execution history from current sandbox." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "exec_type": {"type": "string"}, - "success_only": {"type": "boolean", "default": False}, - "limit": {"type": "integer", "default": 100}, - "offset": {"type": "integer", "default": 0}, - "tags": {"type": "string"}, - "has_notes": {"type": "boolean", "default": False}, - "has_description": {"type": "boolean", "default": False}, - }, - "required": [], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - exec_type: str | None = None, - success_only: bool = False, - limit: int = 100, - offset: int = 0, - tags: str | None = None, - has_notes: bool = False, - has_description: bool = False, - ) -> ToolExecResult: - return await self._run( - context, - lambda _client, sandbox: sandbox.get_execution_history( - exec_type=exec_type, - success_only=success_only, - limit=limit, - offset=offset, - tags=tags, - has_notes=has_notes, - has_description=has_description, - ), - error_action="getting execution history", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class AnnotateExecutionTool(NeoSkillToolBase): - name: str = "astrbot_annotate_execution" - description: str = "Annotate one execution history record." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "execution_id": {"type": "string"}, - "description": {"type": "string"}, - "tags": {"type": "string"}, - "notes": {"type": "string"}, - }, - "required": ["execution_id"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - execution_id: str, - description: str | None = None, - tags: str | None = None, - notes: str | None = None, - ) -> ToolExecResult: - return await self._run( - context, - lambda _client, sandbox: sandbox.annotate_execution( - execution_id=execution_id, - description=description, - tags=tags, - notes=notes, - ), - error_action="annotating execution", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class CreateSkillPayloadTool(NeoSkillToolBase): - name: str = "astrbot_create_skill_payload" - description: str = ( - "Step 1/3 for Neo skill authoring: create immutable payload content and return payload_ref. " - "Use this to store skill_markdown and structured metadata; do NOT write local skill folders directly." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "payload": { - "anyOf": [ - {"type": "object"}, - {"type": "array", "items": {"type": "object"}}, - ], - "description": ( - "Skill payload JSON. Typical schema: {skill_markdown, inputs, outputs, meta}. " - "This only stores content and returns payload_ref; it does not create a candidate or release." - ), - }, - "kind": { - "type": "string", - "description": "Payload kind.", - "default": "astrbot_skill_v1", - }, - }, - "required": ["payload"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - payload: dict[str, Any] | list[Any], - kind: str = "astrbot_skill_v1", - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.create_payload( - payload=payload, - kind=kind, - ), - error_action="creating skill payload", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class GetSkillPayloadTool(NeoSkillToolBase): - name: str = "astrbot_get_skill_payload" - description: str = "Get one skill payload by payload_ref." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "payload_ref": {"type": "string"}, - }, - "required": ["payload_ref"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - payload_ref: str, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.get_payload(payload_ref), - error_action="getting skill payload", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class CreateSkillCandidateTool(NeoSkillToolBase): - name: str = "astrbot_create_skill_candidate" - description: str = ( - "Step 2/3 for Neo skill authoring: create a candidate by binding execution evidence " - "(source_execution_ids) with skill identity (skill_key) and optional payload_ref." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "skill_key": { - "type": "string", - "description": "Stable logical identifier, e.g. image-collage-9grid.", - }, - "source_execution_ids": { - "type": "array", - "items": {"type": "string"}, - "description": "Execution evidence IDs captured from sandbox history.", - }, - "scenario_key": { - "type": "string", - "description": "Optional scenario namespace for grouping candidates.", - }, - "payload_ref": { - "type": "string", - "description": "Optional payload reference created by astrbot_create_skill_payload.", - }, - }, - "required": ["skill_key", "source_execution_ids"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - skill_key: str, - source_execution_ids: list[str], - scenario_key: str | None = None, - payload_ref: str | None = None, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.create_candidate( - skill_key=skill_key, - source_execution_ids=source_execution_ids, - scenario_key=scenario_key, - payload_ref=payload_ref, - ), - error_action="creating skill candidate", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class ListSkillCandidatesTool(NeoSkillToolBase): - name: str = "astrbot_list_skill_candidates" - description: str = "List skill candidates." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "status": {"type": "string"}, - "skill_key": {"type": "string"}, - "limit": {"type": "integer", "default": 100}, - "offset": {"type": "integer", "default": 0}, - }, - "required": [], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - status: str | None = None, - skill_key: str | None = None, - limit: int = 100, - offset: int = 0, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.list_candidates( - status=status, - skill_key=skill_key, - limit=limit, - offset=offset, - ), - error_action="listing skill candidates", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class EvaluateSkillCandidateTool(NeoSkillToolBase): - name: str = "astrbot_evaluate_skill_candidate" - description: str = "Evaluate a skill candidate." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "candidate_id": {"type": "string"}, - "passed": {"type": "boolean"}, - "score": {"type": "number"}, - "benchmark_id": {"type": "string"}, - "report": {"type": "string"}, - }, - "required": ["candidate_id", "passed"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - candidate_id: str, - passed: bool, - score: float | None = None, - benchmark_id: str | None = None, - report: str | None = None, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.evaluate_candidate( - candidate_id, - passed=passed, - score=score, - benchmark_id=benchmark_id, - report=report, - ), - error_action="evaluating skill candidate", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class PromoteSkillCandidateTool(NeoSkillToolBase): - name: str = "astrbot_promote_skill_candidate" - description: str = ( - "Step 3/3 for Neo skill authoring: promote candidate to canary/stable release. " - "If stage=stable and sync_to_local=true, payload.skill_markdown is synced to local SKILL.md automatically." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "candidate_id": {"type": "string"}, - "stage": { - "type": "string", - "description": "Release stage: canary/stable", - "default": "canary", - }, - "sync_to_local": { - "type": "boolean", - "description": ( - "Only used with stage=stable. true means sync payload.skill_markdown to local SKILL.md; " - "false means release remains Neo-side only." - ), - "default": True, - }, - }, - "required": ["candidate_id"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - candidate_id: str, - stage: str = "canary", - sync_to_local: bool = True, - ) -> ToolExecResult: - if err := check_admin_permission(context, "Using skill lifecycle tools"): - return err - if stage not in {"canary", "stable"}: - return "Error promoting skill candidate: stage must be canary or stable." - - try: - client, _sandbox = await _get_neo_context(context) - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.promote_with_optional_sync( - client, - candidate_id=candidate_id, - stage=stage, - sync_to_local=sync_to_local, - ) - if result.get("sync_error"): - rollback_json = result.get("rollback") - if rollback_json: - return ( - "Error promoting skill candidate: stable release synced failed; " - f"auto rollback succeeded. sync_error={result['sync_error']}; " - f"rollback={_to_json_text(rollback_json)}" - ) - return _to_json_text( - { - "release": result.get("release"), - "sync": result.get("sync"), - "rollback": result.get("rollback"), - } - ) - except Exception as e: - return f"Error promoting skill candidate: {str(e)}" - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class ListSkillReleasesTool(NeoSkillToolBase): - name: str = "astrbot_list_skill_releases" - description: str = "List skill releases." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "skill_key": {"type": "string"}, - "active_only": {"type": "boolean", "default": False}, - "stage": {"type": "string"}, - "limit": {"type": "integer", "default": 100}, - "offset": {"type": "integer", "default": 0}, - }, - "required": [], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - skill_key: str | None = None, - active_only: bool = False, - stage: str | None = None, - limit: int = 100, - offset: int = 0, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.list_releases( - skill_key=skill_key, - active_only=active_only, - stage=stage, - limit=limit, - offset=offset, - ), - error_action="listing skill releases", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class RollbackSkillReleaseTool(NeoSkillToolBase): - name: str = "astrbot_rollback_skill_release" - description: str = "Rollback one skill release." - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "release_id": {"type": "string"}, - }, - "required": ["release_id"], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - release_id: str, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: client.skills.rollback_release(release_id), - error_action="rolling back skill release", - ) - - -@builtin_tool(config=_SHIPYARD_NEO_TOOL_CONFIG) -@dataclass -class SyncSkillReleaseTool(NeoSkillToolBase): - name: str = "astrbot_sync_skill_release" - description: str = ( - "Sync stable Neo release payload to local SKILL.md and update mapping metadata." - ) - parameters: dict = field( - default_factory=lambda: { - "type": "object", - "properties": { - "release_id": {"type": "string"}, - "skill_key": {"type": "string"}, - "require_stable": {"type": "boolean", "default": True}, - }, - "required": [], - } - ) - - async def call( - self, - context: ContextWrapper[AstrAgentContext], - release_id: str | None = None, - skill_key: str | None = None, - require_stable: bool = True, - ) -> ToolExecResult: - return await self._run( - context, - lambda client, _sandbox: _sync_release_to_dict( - client, - release_id=release_id, - skill_key=skill_key, - require_stable=require_stable, - ), - error_action="syncing skill release", - ) - - -async def _sync_release_to_dict( - client: Any, - *, - release_id: str | None, - skill_key: str | None, - require_stable: bool, -) -> dict[str, str]: - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.sync_release( - client, - release_id=release_id, - skill_key=skill_key, - require_stable=require_stable, - ) - return sync_mgr.sync_result_to_dict(result) diff --git a/astrbot/core/tools/computer_tools/util.py b/astrbot/core/tools/computer_tools/util.py index a3930b4c6a..07b4630c3c 100644 --- a/astrbot/core/tools/computer_tools/util.py +++ b/astrbot/core/tools/computer_tools/util.py @@ -41,3 +41,15 @@ def check_admin_permission( f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command." ) return None + + +def check_strict_admin_permission( + context: ContextWrapper[AstrAgentContext], operation_name: str +) -> str | None: + if context.context.event.role != "admin": + return ( + f"error: Permission denied. {operation_name} is only allowed for admin users. " + "Tell user to set admins in `AstrBot WebUI -> Config -> General Config` by adding their user ID to the admins list if they need this feature. " + f"User's ID is: {context.context.event.get_sender_id()}. User's ID can be found by using /sid command." + ) + return None diff --git a/astrbot/core/tools/registry.py b/astrbot/core/tools/registry.py index c3b10d2295..196a8fe00f 100644 --- a/astrbot/core/tools/registry.py +++ b/astrbot/core/tools/registry.py @@ -19,6 +19,7 @@ _builtin_tool_classes_by_name: dict[str, type[FunctionTool]] = {} _builtin_tool_names_by_class: dict[type[FunctionTool], str] = {} +_builtin_tool_names_by_module_prefix: dict[str, tuple[str, ...]] = {} _builtin_tools_loaded = False _MISSING = object() @@ -118,6 +119,10 @@ def _build_rule_from_config_map( return BuiltinToolConfigRule(conditions=tuple(conditions)) +def build_builtin_tool_config_rule(config_map: dict[str, Any]) -> BuiltinToolConfigRule: + return _build_rule_from_config_map(config_map) + + def _evaluate_send_message_tool(config: dict[str, Any]) -> list[dict[str, Any]]: platform_configs = config.get("platform", []) if not isinstance(platform_configs, list): @@ -238,6 +243,51 @@ def _register(cls: TFunctionTool) -> TFunctionTool: return _register(tool_cls) +def unregister_builtin_tool_class(tool_cls: type[FunctionTool]) -> str | None: + tool_name = _builtin_tool_names_by_class.pop(tool_cls, None) + if tool_name is None: + return None + existing = _builtin_tool_classes_by_name.get(tool_name) + if existing is tool_cls: + _builtin_tool_classes_by_name.pop(tool_name, None) + _BUILTIN_TOOL_CONFIG_RULES.pop(tool_name, None) + return tool_name + + +def _iter_builtin_tool_names_by_module_prefix(module_prefix: str) -> tuple[str, ...]: + return tuple( + tool_name + for tool_cls, tool_name in _builtin_tool_names_by_class.items() + if getattr(tool_cls, "__module__", "").startswith(module_prefix) + ) + + +def register_builtin_tools_by_module_prefix(module_prefix: str) -> list[str]: + ensure_builtin_tools_loaded() + tool_names = _iter_builtin_tool_names_by_module_prefix(module_prefix) + _builtin_tool_names_by_module_prefix[module_prefix] = tool_names + return list(tool_names) + + +def unregister_builtin_tools_by_module_prefix(module_prefix: str) -> list[str]: + recorded_tool_names = _builtin_tool_names_by_module_prefix.pop(module_prefix, ()) + tool_names = recorded_tool_names or _iter_builtin_tool_names_by_module_prefix( + module_prefix + ) + + removed: list[str] = [] + for tool_name in tool_names: + tool_cls = _builtin_tool_classes_by_name.get(tool_name) + if tool_cls is None: + continue + if not getattr(tool_cls, "__module__", "").startswith(module_prefix): + continue + removed_tool_name = unregister_builtin_tool_class(tool_cls) + if removed_tool_name is not None: + removed.append(removed_tool_name) + return removed + + def ensure_builtin_tools_loaded() -> None: global _builtin_tools_loaded if _builtin_tools_loaded: @@ -318,6 +368,7 @@ def get_builtin_tool_config_tags( __all__ = [ "builtin_tool", + "build_builtin_tool_config_rule", "ensure_builtin_tools_loaded", "get_builtin_tool_config_rule", "get_builtin_tool_config_statuses", @@ -325,4 +376,7 @@ def get_builtin_tool_config_tags( "get_builtin_tool_class", "get_builtin_tool_name", "iter_builtin_tool_classes", + "register_builtin_tools_by_module_prefix", + "unregister_builtin_tool_class", + "unregister_builtin_tools_by_module_prefix", ] diff --git a/astrbot/core/utils/migra_helper.py b/astrbot/core/utils/migra_helper.py index 40b899620d..7846ff215e 100644 --- a/astrbot/core/utils/migra_helper.py +++ b/astrbot/core/utils/migra_helper.py @@ -128,6 +128,52 @@ def _migra_provider_to_source_structure(conf: AstrBotConfig) -> None: logger.info("Provider-source structure migration completed") +def _prune_invalid_provider_source_models(conf: AstrBotConfig) -> None: + """Remove stale provider model entries that cannot resolve a source type. + + New-style provider model entries intentionally keep only model-level fields and + inherit adapter fields such as `type` from provider_sources. Older configs can + contain entries that reference a removed source id; those can never be loaded + and would otherwise produce repeated startup errors. + """ + providers = conf.get("provider", []) + if not isinstance(providers, list): + return + + provider_sources = conf.get("provider_sources", []) + source_types = { + source.get("id"): source.get("type") + for source in provider_sources + if isinstance(source, dict) + } + + kept = [] + removed_ids = [] + for provider in providers: + if not isinstance(provider, dict): + kept.append(provider) + continue + if provider.get("type"): + kept.append(provider) + continue + source_id = provider.get("provider_source_id") + if source_id and source_types.get(source_id): + kept.append(provider) + continue + removed_ids.append(str(provider.get("id") or source_id or "")) + + if not removed_ids: + return + + conf["provider"] = kept + conf.save_config() + logger.info( + "Pruned %d invalid provider model config(s) with missing adapter type: %s", + len(removed_ids), + ", ".join(removed_ids), + ) + + async def migra( db, astrbot_config_mgr, umop_config_router, acm: AstrBotConfigManager ) -> None: @@ -181,3 +227,9 @@ async def migra( except Exception as e: logger.error(f"Migration for provider-source structure failed: {e!s}") logger.error(traceback.format_exc()) + + try: + _prune_invalid_provider_source_models(astrbot_config) + except Exception as e: + logger.error(f"Migration for invalid provider-source models failed: {e!s}") + logger.error(traceback.format_exc()) diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py index fbbd0c7a08..512e2cf1a3 100644 --- a/astrbot/dashboard/routes/__init__.py +++ b/astrbot/dashboard/routes/__init__.py @@ -14,6 +14,7 @@ from .persona import PersonaRoute from .platform import PlatformRoute from .plugin import PluginRoute +from .sandbox import SandboxRoute from .session_management import SessionManagementRoute from .skills import SkillsRoute from .stat import StatRoute @@ -40,6 +41,7 @@ "PlatformRoute", "PluginRoute", "SessionManagementRoute", + "SandboxRoute", "StatRoute", "StaticFileRoute", "SubAgentRoute", diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index aebe26047c..bfa01daa9e 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -9,6 +9,7 @@ from quart import jsonify, make_response, request from astrbot.core import astrbot_config, file_token_service, logger +from astrbot.core.computer import computer_client from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.config.default import ( CONFIG_METADATA_2, @@ -284,58 +285,27 @@ def _protected_2fa_config_changed(old_config: dict, new_config: dict) -> bool: ) +def _normalize_unavailable_sandbox_booter(config: dict) -> dict: + sandbox = config.get("provider_settings", {}).get("sandbox", {}) + if not isinstance(sandbox, dict): + return config + booter = str(sandbox.get("booter") or "").strip() + if not booter: + return config + provider_ids = { + str(provider.get("provider_id") or "") + for provider in computer_client.list_sandbox_providers() + } + if booter not in provider_ids: + sandbox["booter"] = "" + return config + + async def _validate_neo_connectivity( post_config: dict, ) -> str | None: - """Check if Bay is reachable when Shipyard Neo sandbox is configured. - - Returns a warning message string if Bay isn't reachable, or None if - everything looks fine (or Neo isn't configured). - """ - ps = post_config.get("provider_settings", {}) - runtime = ps.get("computer_use_runtime", "none") - sandbox = ps.get("sandbox", {}) - booter = sandbox.get("booter", "") - - # Only check when sandbox mode + shipyard_neo is selected - if runtime != "sandbox" or booter != "shipyard_neo": - return None - - endpoint = sandbox.get("shipyard_neo_endpoint", "").rstrip("/") - if not endpoint: - return "⚠️ Shipyard Neo endpoint 未设置" - - access_token = sandbox.get("shipyard_neo_access_token", "") - if not access_token: - # Try auto-discovery - from astrbot.core.computer.computer_client import _discover_bay_credentials - - access_token = _discover_bay_credentials(endpoint) - - if not access_token: - return ( - "⚠️ 未找到 Bay API Key。请填写访问令牌," - "或确保 Bay 的 credentials.json 可被自动发现。" - ) - - # Connectivity check - import aiohttp - - health_url = f"{endpoint}/health" - try: - async with aiohttp.ClientSession() as session: - async with session.get( - health_url, - timeout=aiohttp.ClientTimeout(total=5), - ) as resp: - if resp.status != 200: - return ( - f"⚠️ Bay 健康检查失败 (HTTP {resp.status})," - f"请确认 Bay 正在运行: {endpoint}" - ) - except Exception: - return f"⚠️ 无法连接 Bay ({endpoint}),请确认 Bay 已启动。" - + """Concrete sandbox providers own their connectivity checks.""" + del post_config return None @@ -632,8 +602,11 @@ async def delete_ucr(self): async def get_default_config(self): """获取默认配置文件""" - metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3) - return Response().ok({"config": DEFAULT_CONFIG, "metadata": metadata}).__dict__ + metadata = ConfigMetadataI18n.convert_to_i18n_keys( + self._inject_sandbox_provider_options(copy.deepcopy(CONFIG_METADATA_3)) + ) + config = _normalize_unavailable_sandbox_booter(copy.deepcopy(DEFAULT_CONFIG)) + return Response().ok({"config": config, "metadata": metadata}).__dict__ async def get_abconf_list(self): """获取所有 AstrBot 配置文件的列表""" @@ -664,15 +637,23 @@ async def get_abconf(self): try: if system_config: - abconf = self.acm.confs["default"] + abconf = _normalize_unavailable_sandbox_booter( + copy.deepcopy(dict(self.acm.confs["default"])) + ) metadata = ConfigMetadataI18n.convert_to_i18n_keys( - CONFIG_METADATA_3_SYSTEM + self._inject_sandbox_provider_options( + copy.deepcopy(CONFIG_METADATA_3_SYSTEM) + ) ) return Response().ok({"config": abconf, "metadata": metadata}).__dict__ if abconf_id is None: raise ValueError("abconf_id cannot be None") - abconf = self.acm.confs[abconf_id] - metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3) + abconf = _normalize_unavailable_sandbox_booter( + copy.deepcopy(dict(self.acm.confs[abconf_id])) + ) + metadata = ConfigMetadataI18n.convert_to_i18n_keys( + self._inject_sandbox_provider_options(copy.deepcopy(CONFIG_METADATA_3)) + ) return Response().ok({"config": abconf, "metadata": metadata}).__dict__ except ValueError as e: return Response().error(str(e)).__dict__ @@ -1588,12 +1569,29 @@ async def _get_astrbot_config(self): if provider.default_config_tmpl: provider_default_tmpl[provider.type] = provider.default_config_tmpl + self._inject_sandbox_provider_options(metadata) + return { "metadata": metadata, "config": config, "platform_i18n_translations": platform_i18n_translations, } + def _inject_sandbox_provider_options(self, metadata: dict) -> dict: + try: + items = metadata["ai_group"]["metadata"]["agent_computer_use"]["items"] + booter = items.get("provider_settings.sandbox.booter") + except KeyError: + return metadata + if not isinstance(booter, dict): + return metadata + + providers = computer_client.list_sandbox_providers() + options = [provider["provider_id"] for provider in providers] + booter["options"] = options + booter["labels"] = options.copy() + return metadata + async def _get_plugin_config(self, plugin_name: str): ret: dict = {"metadata": None, "config": None, "i18n": {}} diff --git a/astrbot/dashboard/routes/sandbox.py b/astrbot/dashboard/routes/sandbox.py new file mode 100644 index 0000000000..786ef113f2 --- /dev/null +++ b/astrbot/dashboard/routes/sandbox.py @@ -0,0 +1,311 @@ +import traceback + +from quart import jsonify, request + +from astrbot.core import logger +from astrbot.core.computer import computer_client +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle + +from .route import Response, Route, RouteContext +from .sandbox_helpers import ( + demo_mode_response, + is_demo_mode, + is_sandbox_limit_error, + is_sandbox_name_conflict, + is_sandbox_user_error, + sanitize_shell_timeout, +) + + +class SandboxRoute(Route): + def __init__( + self, + context: RouteContext, + core_lifecycle: AstrBotCoreLifecycle, + ) -> None: + super().__init__(context) + self.core_lifecycle = core_lifecycle + self.routes = [ + ("/sandbox/providers", ("GET", self.list_providers)), + ("/sandbox", ("GET", self.list_sandboxes)), + ("/sandbox/current", ("GET", self.get_current_sandbox)), + ("/sandbox/current", ("DELETE", self.release_current_sandbox)), + ("/sandbox", ("POST", self.create_sandbox)), + ("/sandbox//switch", ("POST", self.switch_sandbox)), + ("/sandbox//takeover", ("POST", self.takeover_sandbox)), + ("/sandbox//default", ("POST", self.set_default_sandbox)), + ("/sandbox//shell", ("POST", self.run_shell)), + ("/sandbox//screenshot", ("POST", self.capture_screenshot)), + ("/sandbox/", ("PATCH", self.update_sandbox)), + ("/sandbox/", ("DELETE", self.destroy_sandbox)), + ] + self.register_routes() + + def _session_id(self) -> str: + return request.args.get("session_id") or "dashboard" + + async def list_providers(self): + try: + config = self.core_lifecycle.star_context.get_config(umo=self._session_id()) + sandbox_config = config.get("provider_settings", {}).get("sandbox", {}) + default_provider_id = "" + if isinstance(sandbox_config, dict): + configured_provider_id = str(sandbox_config.get("booter") or "").strip() + if computer_client.get_sandbox_provider_info(configured_provider_id): + default_provider_id = configured_provider_id + return jsonify( + Response() + .ok( + data={ + "providers": computer_client.list_sandbox_providers(), + "default_provider_id": default_provider_id, + } + ) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to list sandbox providers: {e!s}").__dict__ + ) + + async def list_sandboxes(self): + try: + return jsonify( + Response() + .ok( + data={ + "sandboxes": await computer_client.sandbox_manager.list_sandboxes_checked() + } + ) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to list sandboxes: {e!s}").__dict__ + ) + + async def get_current_sandbox(self): + try: + return jsonify( + Response() + .ok( + data=computer_client.sandbox_manager.get_current_sandbox( + self._session_id() + ) + ) + .__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to get current sandbox: {e!s}").__dict__ + ) + + async def create_sandbox(self): + if is_demo_mode(): + return demo_mode_response() + try: + data = await request.get_json(silent=True) or {} + provider_id = str(data.get("provider_id") or "").strip() + if not provider_id: + return jsonify(Response().error("provider_id is required").__dict__) + sandbox = await computer_client.sandbox_manager.create_sandbox_uncontrolled_deferred( + self.core_lifecycle.star_context, + self._session_id(), + provider_id, + sandbox_name=data.get("sandbox_name"), + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except RuntimeError as e: + if is_sandbox_name_conflict(e) or is_sandbox_limit_error(e): + logger.warning(str(e)) + return jsonify(Response().error(str(e)).__dict__) + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to create sandbox: {e!s}").__dict__ + ) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to create sandbox: {e!s}").__dict__ + ) + + async def switch_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox = ( + await computer_client.sandbox_manager.switch_current_sandbox_checked( + self._session_id(), + sandbox_id, + context=self.core_lifecycle.star_context, + ) + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to switch sandbox: {e!s}").__dict__ + ) + + async def release_current_sandbox(self): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox_id = request.args.get("sandbox_id") + if sandbox_id: + sandbox = computer_client.sandbox_manager.force_release_sandbox( + sandbox_id + ) + else: + sandbox = computer_client.sandbox_manager.release_current_sandbox( + self._session_id() + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to release sandbox: {e!s}").__dict__ + ) + + async def takeover_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox = await computer_client.sandbox_manager.takeover_sandbox( + self._session_id(), sandbox_id, context=self.core_lifecycle.star_context + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to takeover sandbox: {e!s}").__dict__ + ) + + async def set_default_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox = computer_client.sandbox_manager.set_default_sandbox(sandbox_id) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to set default sandbox: {e!s}").__dict__ + ) + + async def run_shell(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + data = await request.get_json(silent=True) or {} + command = str(data.get("command") or "").strip() + if not command: + return jsonify(Response().error("command is required").__dict__) + # Dashboard shell access is an administrative operation; it does + # not need a lease so admins can operate any sandbox at any time. + booter = await computer_client.sandbox_manager.get_observer_booter_by_id( + sandbox_id, + self._session_id(), + require_lease=False, + context=self.core_lifecycle.star_context, + ) + shell = getattr(booter, "shell", None) + if shell is None: + return jsonify( + Response().error("Sandbox does not support shell.").__dict__ + ) + result = await shell.exec( + command, + cwd=data.get("cwd"), + env=data.get("env"), + timeout=sanitize_shell_timeout(data.get("timeout", 300)), + shell=data.get("shell", True), + ) + return jsonify(Response().ok(data={"result": result}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to run sandbox shell: {e!s}").__dict__ + ) + + async def capture_screenshot(self, sandbox_id: str): + try: + data = await request.get_json(silent=True) or {} + # Dashboard screenshot is a read-only observer operation; it does + # not need a lease and must not reset the sandbox idle timer. + booter = await computer_client.sandbox_manager.get_observer_booter_by_id( + sandbox_id, + self._session_id(), + require_lease=False, + context=self.core_lifecycle.star_context, + ) + gui = getattr(booter, "gui", None) + if gui is None: + return jsonify( + Response().error("Sandbox does not support screenshots.").__dict__ + ) + screenshot = await gui.screenshot(path=data.get("path")) + return jsonify(Response().ok(data={"screenshot": screenshot}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response() + .error(f"Failed to capture sandbox screenshot: {e!s}") + .__dict__ + ) + + async def update_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + data = await request.get_json(silent=True) or {} + current_sandbox = computer_client.sandbox_manager.registry.get_sandbox( + sandbox_id + ) + retention_policy = data.get( + "retention_policy", + current_sandbox.get("retention_policy", "temporary") + if current_sandbox + else "temporary", + ) + idle_timeout = data.get( + "idle_timeout", + current_sandbox.get("idle_timeout") if current_sandbox else None, + ) + expires_at = data.get( + "expires_at", + current_sandbox.get("expires_at") if current_sandbox else None, + ) + sandbox = computer_client.sandbox_manager.update_sandbox_config( + sandbox_id, + sandbox_name=data.get("sandbox_name"), + idle_timeout=idle_timeout, + expires_at=expires_at, + retention_policy=retention_policy, + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + if is_sandbox_user_error(e): + logger.info("Failed to update sandbox: %s", e) + else: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to update sandbox: {e!s}").__dict__ + ) + + async def destroy_sandbox(self, sandbox_id: str): + if is_demo_mode(): + return demo_mode_response() + try: + sandbox = await computer_client.sandbox_manager.destroy_sandbox_deferred( + self._session_id(), sandbox_id + ) + return jsonify(Response().ok(data={"sandbox": sandbox}).__dict__) + except Exception as e: + logger.error(traceback.format_exc()) + return jsonify( + Response().error(f"Failed to destroy sandbox: {e!s}").__dict__ + ) diff --git a/astrbot/dashboard/routes/sandbox_helpers.py b/astrbot/dashboard/routes/sandbox_helpers.py new file mode 100644 index 0000000000..45a61777da --- /dev/null +++ b/astrbot/dashboard/routes/sandbox_helpers.py @@ -0,0 +1,54 @@ +import math + +from quart import jsonify + +from astrbot.core import DEMO_MODE + +from .route import Response + + +def is_demo_mode() -> bool: + return DEMO_MODE + + +def demo_mode_response(): + return jsonify( + Response() + .error("You are not permitted to do this operation in demo mode") + .__dict__ + ) + + +def is_sandbox_name_conflict(error: Exception) -> bool: + return isinstance(error, RuntimeError) and str(error).startswith("Sandbox name ") + + +def is_sandbox_limit_error(error: Exception) -> bool: + return isinstance(error, RuntimeError) and str(error).startswith( + "Sandbox limit reached" + ) + + +def is_sandbox_user_error(error: Exception) -> bool: + if not isinstance(error, (RuntimeError, ValueError)): + return False + message = str(error) + return ( + is_sandbox_name_conflict(error) + or is_sandbox_limit_error(error) + or "does not support persistent sandboxes" in message + or "retention_policy must be" in message + or "sandbox_name must be" in message + ) + + +def sanitize_shell_timeout(value, default: float = 300) -> float: + if isinstance(value, bool): + return default + try: + timeout = float(value) + except (TypeError, ValueError): + return default + if not math.isfinite(timeout) or timeout <= 0: + return default + return timeout diff --git a/astrbot/dashboard/routes/skills.py b/astrbot/dashboard/routes/skills.py index c86598212e..9a395a2255 100644 --- a/astrbot/dashboard/routes/skills.py +++ b/astrbot/dashboard/routes/skills.py @@ -2,44 +2,17 @@ import re import shutil import traceback -from collections.abc import Awaitable, Callable from pathlib import Path -from typing import Any from quart import request, send_file from astrbot.core import DEMO_MODE, logger -from astrbot.core.computer.computer_client import ( - _discover_bay_credentials, - sync_skills_to_active_sandboxes, -) -from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager +from astrbot.core.computer.computer_client import sync_skills_to_active_sandboxes from astrbot.core.skills.skill_manager import SkillManager from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from .route import Response, Route, RouteContext - -def _to_jsonable(value: Any) -> Any: - if isinstance(value, dict): - return {k: _to_jsonable(v) for k, v in value.items()} - if isinstance(value, list): - return [_to_jsonable(v) for v in value] - if hasattr(value, "model_dump"): - return _to_jsonable(value.model_dump()) - return value - - -def _to_bool(value: Any, default: bool = False) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - if isinstance(value, str): - return value.strip().lower() in {"1", "true", "yes", "y", "on"} - return bool(value) - - _SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$") _SKILL_FILE_MAX_BYTES = 512 * 1024 _EDITABLE_SKILL_FILE_SUFFIXES = { @@ -87,15 +60,6 @@ def __init__(self, context: RouteContext, core_lifecycle) -> None: ], "/skills/update": ("POST", self.update_skill), "/skills/delete": ("POST", self.delete_skill), - "/skills/neo/candidates": ("GET", self.get_neo_candidates), - "/skills/neo/releases": ("GET", self.get_neo_releases), - "/skills/neo/payload": ("GET", self.get_neo_payload), - "/skills/neo/evaluate": ("POST", self.evaluate_neo_candidate), - "/skills/neo/promote": ("POST", self.promote_neo_candidate), - "/skills/neo/rollback": ("POST", self.rollback_neo_release), - "/skills/neo/sync": ("POST", self.sync_neo_release), - "/skills/neo/delete-candidate": ("POST", self.delete_neo_candidate), - "/skills/neo/delete-release": ("POST", self.delete_neo_release), } self.register_routes() @@ -179,58 +143,6 @@ def _serialize_skill_file_entry( ), } - def _get_neo_client_config(self) -> tuple[str, str]: - provider_settings = self.core_lifecycle.astrbot_config.get( - "provider_settings", - {}, - ) - sandbox = provider_settings.get("sandbox", {}) - endpoint = sandbox.get("shipyard_neo_endpoint", "") - access_token = sandbox.get("shipyard_neo_access_token", "") - - # Auto-discover token from Bay's credentials.json if not configured - if not access_token and endpoint: - access_token = _discover_bay_credentials(endpoint) - - if not endpoint or not access_token: - raise ValueError( - "Shipyard Neo endpoint or access token not configured. " - "Set them in Dashboard or ensure Bay's credentials.json is accessible." - ) - return endpoint, access_token - - async def _delete_neo_release( - self, client: Any, release_id: str, reason: str | None - ): - return await client.skills.delete_release(release_id, reason=reason) - - async def _delete_neo_candidate( - self, client: Any, candidate_id: str, reason: str | None - ): - return await client.skills.delete_candidate(candidate_id, reason=reason) - - async def _with_neo_client( - self, - operation: Callable[[Any], Awaitable[dict]], - ) -> dict: - try: - endpoint, access_token = self._get_neo_client_config() - - from shipyard_neo import BayClient - - async with BayClient( - endpoint_url=endpoint, - access_token=access_token, - ) as client: - return await operation(client) - except ValueError as e: - # Config not ready — expected when Neo isn't set up yet - logger.debug("[Neo] %s", e) - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - async def get_skills(self): try: provider_settings = self.core_lifecycle.astrbot_config.get( @@ -707,254 +619,3 @@ async def delete_skill(self): except Exception as e: logger.error(traceback.format_exc()) return Response().error(str(e)).__dict__ - - async def get_neo_candidates(self): - logger.info("[Neo] GET /skills/neo/candidates requested.") - status = request.args.get("status") - skill_key = request.args.get("skill_key") - limit = int(request.args.get("limit", 100)) - offset = int(request.args.get("offset", 0)) - - async def _do(client): - candidates = await client.skills.list_candidates( - status=status, - skill_key=skill_key, - limit=limit, - offset=offset, - ) - result = _to_jsonable(candidates) - total = result.get("total", "?") if isinstance(result, dict) else "?" - logger.info(f"[Neo] Candidates fetched: total={total}") - return Response().ok(result).__dict__ - - return await self._with_neo_client(_do) - - async def get_neo_releases(self): - logger.info("[Neo] GET /skills/neo/releases requested.") - skill_key = request.args.get("skill_key") - stage = request.args.get("stage") - active_only = _to_bool(request.args.get("active_only"), False) - limit = int(request.args.get("limit", 100)) - offset = int(request.args.get("offset", 0)) - - async def _do(client): - releases = await client.skills.list_releases( - skill_key=skill_key, - active_only=active_only, - stage=stage, - limit=limit, - offset=offset, - ) - result = _to_jsonable(releases) - total = result.get("total", "?") if isinstance(result, dict) else "?" - logger.info(f"[Neo] Releases fetched: total={total}") - return Response().ok(result).__dict__ - - return await self._with_neo_client(_do) - - async def get_neo_payload(self): - logger.info("[Neo] GET /skills/neo/payload requested.") - payload_ref = request.args.get("payload_ref", "") - if not payload_ref: - return Response().error("Missing payload_ref").__dict__ - - async def _do(client): - payload = await client.skills.get_payload(payload_ref) - logger.info(f"[Neo] Payload fetched: ref={payload_ref}") - return Response().ok(_to_jsonable(payload)).__dict__ - - return await self._with_neo_client(_do) - - async def evaluate_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/evaluate requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - passed_value = data.get("passed") - if not candidate_id or passed_value is None: - return Response().error("Missing candidate_id or passed").__dict__ - passed = _to_bool(passed_value, False) - - async def _do(client): - result = await client.skills.evaluate_candidate( - candidate_id, - passed=passed, - score=data.get("score"), - benchmark_id=data.get("benchmark_id"), - report=data.get("report"), - ) - logger.info( - f"[Neo] Candidate evaluated: id={candidate_id}, passed={passed}" - ) - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def promote_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/promote requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - stage = data.get("stage", "canary") - sync_to_local = _to_bool(data.get("sync_to_local"), True) - if not candidate_id: - return Response().error("Missing candidate_id").__dict__ - if stage not in {"canary", "stable"}: - return Response().error("Invalid stage, must be canary/stable").__dict__ - - async def _do(client): - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.promote_with_optional_sync( - client, - candidate_id=candidate_id, - stage=stage, - sync_to_local=sync_to_local, - ) - release_json = result.get("release") - logger.info(f"[Neo] Candidate promoted: id={candidate_id}, stage={stage}") - - sync_json = result.get("sync") - did_sync_to_local = bool(sync_json) - if did_sync_to_local: - logger.info( - f"[Neo] Stable release synced to local: skill={sync_json.get('local_skill_name', '')}" - ) - - if result.get("sync_error"): - resp = Response().error( - "Stable promote synced failed and has been rolled back. " - f"sync_error={result['sync_error']}" - ) - resp.data = { - "release": release_json, - "rollback": result.get("rollback"), - } - return resp.__dict__ - - # Try to push latest local skills to all active sandboxes. - if not did_sync_to_local: - try: - await sync_skills_to_active_sandboxes() - except Exception: - logger.warning("Failed to sync skills to active sandboxes.") - - return Response().ok({"release": release_json, "sync": sync_json}).__dict__ - - return await self._with_neo_client(_do) - - async def rollback_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/rollback requested.") - data = await request.get_json() - release_id = data.get("release_id") - if not release_id: - return Response().error("Missing release_id").__dict__ - - async def _do(client): - result = await client.skills.rollback_release(release_id) - logger.info(f"[Neo] Release rolled back: id={release_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def sync_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/sync requested.") - data = await request.get_json() - release_id = data.get("release_id") - skill_key = data.get("skill_key") - require_stable = _to_bool(data.get("require_stable"), True) - if not release_id and not skill_key: - return Response().error("Missing release_id or skill_key").__dict__ - - async def _do(client): - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.sync_release( - client, - release_id=release_id, - skill_key=skill_key, - require_stable=require_stable, - ) - logger.info( - f"[Neo] Release synced to local: skill={result.local_skill_name}, " - f"release_id={result.release_id}" - ) - return ( - Response() - .ok( - { - "skill_key": result.skill_key, - "local_skill_name": result.local_skill_name, - "release_id": result.release_id, - "candidate_id": result.candidate_id, - "payload_ref": result.payload_ref, - "map_path": result.map_path, - "synced_at": result.synced_at, - } - ) - .__dict__ - ) - - return await self._with_neo_client(_do) - - async def delete_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/delete-candidate requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - reason = data.get("reason") - if not candidate_id: - return Response().error("Missing candidate_id").__dict__ - - async def _do(client): - result = await self._delete_neo_candidate(client, candidate_id, reason) - logger.info(f"[Neo] Candidate deleted: id={candidate_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def delete_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/delete-release requested.") - data = await request.get_json() - release_id = data.get("release_id") - reason = data.get("reason") - if not release_id: - return Response().error("Missing release_id").__dict__ - - async def _do(client): - result = await self._delete_neo_release(client, release_id, reason) - logger.info(f"[Neo] Release deleted: id={release_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py index 2ad80687ef..63864c65a6 100644 --- a/astrbot/dashboard/routes/tools.py +++ b/astrbot/dashboard/routes/tools.py @@ -4,6 +4,9 @@ from astrbot.core import logger, sp from astrbot.core.agent.mcp_client import MCPTool, validate_mcp_stdio_config +from astrbot.core.computer.sandbox_tool_binding import ( + get_sandbox_provider_tool_config_statuses, +) from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.star import star_map from astrbot.core.tools.registry import get_builtin_tool_config_statuses @@ -483,7 +486,21 @@ async def get_tool_list(self): readonly = False builtin_config_statuses = [] builtin_config_tags = [] - if self.tool_mgr.is_builtin_tool(tool.name): + sandbox_provider_id = getattr(tool, "sandbox_provider_id", None) + if sandbox_provider_id: + origin = "sandbox" + origin_name = str(sandbox_provider_id) + readonly = True + builtin_config_statuses = get_sandbox_provider_tool_config_statuses( + tool.name, + config_entries, + ) + builtin_config_tags = [ + status + for status in builtin_config_statuses + if status["enabled"] + ] + elif self.tool_mgr.is_builtin_tool(tool.name): origin = "builtin" origin_name = "AstrBot Core" readonly = True @@ -544,6 +561,18 @@ async def toggle_tool(self): .__dict__ ) + tool = next( + (t for t in self.tool_mgr.func_list if t.name == tool_name), None + ) + if getattr(tool, "sandbox_provider_id", None): + return ( + Response() + .error( + "Sandbox provider tools are read-only and cannot be toggled." + ) + .__dict__ + ) + if self.tool_mgr.is_builtin_tool(tool_name): return ( Response() @@ -604,6 +633,18 @@ async def update_tool_permission(self): .__dict__ ) + tool = next( + (t for t in self.tool_mgr.func_list if t.name == tool_name), None + ) + if getattr(tool, "sandbox_provider_id", None): + return ( + Response() + .error( + "Sandbox provider tools do not support per-tool permission configuration." + ) + .__dict__ + ) + if self.tool_mgr.is_builtin_tool(tool_name): return ( Response() diff --git a/astrbot/dashboard/server.py b/astrbot/dashboard/server.py index 9888da8f5f..0960b5f401 100644 --- a/astrbot/dashboard/server.py +++ b/astrbot/dashboard/server.py @@ -8,6 +8,7 @@ from datetime import datetime from pathlib import Path from typing import Protocol, cast +from urllib.parse import urlsplit import jwt import psutil @@ -41,6 +42,7 @@ from .routes.live_chat import LiveChatRoute from .routes.platform import PlatformRoute from .routes.route import Response, RouteContext +from .routes.sandbox import SandboxRoute from .routes.session_management import SessionManagementRoute from .routes.subagent import SubAgentRoute from .routes.t2i import T2iRoute @@ -300,6 +302,7 @@ def __init__( db, core_lifecycle, ) + self.sandbox_route = SandboxRoute(self.context, core_lifecycle) self.persona_route = PersonaRoute(self.context, db, core_lifecycle) self.cron_route = CronRoute(self.context, core_lifecycle) self.t2i_route = T2iRoute(self.context, core_lifecycle) @@ -435,6 +438,9 @@ async def auth_middleware(self): if not isinstance(username, str) or not username.strip(): raise jwt.InvalidTokenError("missing username in token payload") g.username = username + sandbox_origin_error = self._sandbox_cookie_origin_error() + if sandbox_origin_error is not None: + return sandbox_origin_error except jwt.ExpiredSignatureError: r = jsonify(Response().error("Token 过期").__dict__) r.status_code = 401 @@ -444,6 +450,38 @@ async def auth_middleware(self): r.status_code = 401 return r + def _sandbox_cookie_origin_error(self): + if request.method not in {"POST", "PATCH", "DELETE"}: + return None + if not request.path.startswith("/api/sandbox"): + return None + auth_header = request.headers.get("Authorization", "").strip() + if auth_header.startswith("Bearer "): + return None + + origin = request.headers.get("Origin", "").strip() + referer = request.headers.get("Referer", "").strip() + candidate = origin or referer + if not candidate: + return None + if self._is_same_dashboard_origin(candidate): + return None + + r = jsonify(Response().error("Origin is not allowed").__dict__) + r.status_code = 403 + return r + + @staticmethod + def _is_same_dashboard_origin(candidate: str) -> bool: + parsed_candidate = urlsplit(candidate) + if not parsed_candidate.scheme or not parsed_candidate.netloc: + return False + parsed_request = urlsplit(str(request.url_root)) + return ( + parsed_candidate.scheme == parsed_request.scheme + and parsed_candidate.netloc == parsed_request.netloc + ) + def _get_request_client_ip(self, current_request) -> str: if bool(self.config.get("dashboard", {}).get("trust_proxy_headers", False)): forwarded_for = str( diff --git a/dashboard/pnpm-workspace.yaml b/dashboard/pnpm-workspace.yaml index a23eae0866..1bdd797c8c 100644 --- a/dashboard/pnpm-workspace.yaml +++ b/dashboard/pnpm-workspace.yaml @@ -1,3 +1,6 @@ allowBuilds: esbuild: true vue-demi: true +onlyBuiltDependencies: + - esbuild + - vue-demi diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css b/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css index be565ba238..c143cf7979 100644 --- a/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css +++ b/dashboard/src/assets/mdi-subset/materialdesignicons-subset.css @@ -1,4 +1,4 @@ -/* Auto-generated MDI subset – 272 icons */ +/* Auto-generated MDI subset – 273 icons */ /* Do not edit manually. Run: pnpm run subset-icons */ @font-face { @@ -316,6 +316,10 @@ content: "\F1C2B"; } +.mdi-cube-outline::before { + content: "\F01A7"; +} + .mdi-cursor-default-click::before { content: "\F0CFD"; } diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff index 8e57a70b4f..42491d6927 100644 Binary files a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff and b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff differ diff --git a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 index 107a267095..a388b33bff 100644 Binary files a/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 and b/dashboard/src/assets/mdi-subset/materialdesignicons-webfont-subset.woff2 differ diff --git a/dashboard/src/components/extension/SkillsSection.vue b/dashboard/src/components/extension/SkillsSection.vue index 2ad1bddffe..56cb3f2db5 100644 --- a/dashboard/src/components/extension/SkillsSection.vue +++ b/dashboard/src/components/extension/SkillsSection.vue @@ -1521,19 +1521,9 @@ export default { }; const loadNeoAvailability = async () => { - try { - const res = await axios.get("/api/config/get"); - const config = res?.data?.data?.config || {}; - const providerSettings = config?.provider_settings || {}; - const currentRuntime = - providerSettings?.computer_use_runtime || "local"; - const booter = providerSettings?.sandbox?.booter || ""; - neoEnabled.value = - currentRuntime === "sandbox" && booter === "shipyard_neo"; - } catch (_err) { - neoEnabled.value = false; - } - + // Core no longer ships /api/skills/neo/* routes; Neo skill management is + // provided by external sandbox provider plugins instead of this page. + neoEnabled.value = false; neoUnavailableMessage.value = tm("skills.neoRuntimeRequired"); if (!neoEnabled.value && mode.value === "neo") { mode.value = "local"; diff --git a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue index 75e59f30a9..937fe4cbe2 100644 --- a/dashboard/src/components/extension/componentPanel/components/ToolTable.vue +++ b/dashboard/src/components/extension/componentPanel/components/ToolTable.vue @@ -68,7 +68,7 @@ const formatCondition = (condition: ToolConfigCondition) => { }; const enabledConfigTags = (tool: ToolItem): BuiltinToolConfigTag[] => { - if (tool.origin !== 'builtin') return []; + if (tool.origin !== 'builtin' && tool.origin !== 'sandbox') return []; return (tool.builtin_config_tags || []).filter(tag => tag.enabled); }; @@ -89,6 +89,23 @@ const getPermissionLabel = (permission?: string): string => { return tmTool('functionTools.table.permissionEveryone'); } }; + +const getOriginLabel = (origin?: string): string => { + switch (origin) { + case 'builtin': + return tmTool('functionTools.table.originBuiltin'); + case 'mcp': + return tmTool('functionTools.table.originMcp'); + case 'plugin': + return tmTool('functionTools.table.originPlugin'); + case 'sandbox': + return tmTool('functionTools.table.originSandbox'); + case 'unknown': + return tmTool('functionTools.table.originUnknown'); + default: + return origin || '-'; + } +};