diff --git a/astrbot/cli/commands/cmd_conf.py b/astrbot/cli/commands/cmd_conf.py index ac626e0d11..30dee61349 100644 --- a/astrbot/cli/commands/cmd_conf.py +++ b/astrbot/cli/commands/cmd_conf.py @@ -7,7 +7,7 @@ from astrbot.core.utils.auth_password import ( hash_dashboard_password, - hash_legacy_dashboard_password, + hash_md5_dashboard_password, validate_dashboard_password, ) @@ -147,7 +147,7 @@ def _set_dashboard_password(config: dict[str, Any], raw_password: str) -> None: _set_nested_item( config, "dashboard.password", - hash_legacy_dashboard_password(raw_password), + hash_md5_dashboard_password(raw_password), ) _set_nested_item(config, "dashboard.password_storage_upgraded", True) _set_nested_item(config, "dashboard.password_change_required", False) diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 4d62becb55..c42088ba9e 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -7,7 +7,7 @@ from astrbot.core.utils.auth_password import ( generate_dashboard_password, hash_dashboard_password, - hash_legacy_dashboard_password, + hash_md5_dashboard_password, validate_dashboard_password, ) @@ -64,11 +64,11 @@ def __init__( conf_str = conf_str[1:] conf = json.loads(conf_str) dashboard_conf = conf.get("dashboard") - legacy_dashboard_password_change_required = bool( + stored_dashboard_password_change_required = bool( isinstance(dashboard_conf, dict) and dashboard_conf.get("password_change_required", False) ) - if legacy_dashboard_password_change_required: + if stored_dashboard_password_change_required: object.__setattr__( self, "_dashboard_password_change_required_from_config", @@ -87,7 +87,7 @@ def __init__( elif ( "dashboard" in conf and isinstance(conf["dashboard"], dict) - and legacy_dashboard_password_change_required + and stored_dashboard_password_change_required and conf["dashboard"].get("pbkdf2_password") ): self._reset_generated_dashboard_password(conf) @@ -103,9 +103,7 @@ def _reset_generated_dashboard_password(self, conf: dict) -> None: conf["dashboard"]["pbkdf2_password"] = hash_dashboard_password( generated_password ) - conf["dashboard"]["password"] = hash_legacy_dashboard_password( - generated_password - ) + conf["dashboard"]["password"] = hash_md5_dashboard_password(generated_password) conf["dashboard"]["password_storage_upgraded"] = True conf["dashboard"]["password_change_required"] = True object.__setattr__( diff --git a/astrbot/core/platform/platform.py b/astrbot/core/platform/platform.py index b32891096e..af2b1a0b5e 100644 --- a/astrbot/core/platform/platform.py +++ b/astrbot/core/platform/platform.py @@ -157,7 +157,7 @@ async def webhook_callback(self, request: Any) -> Any: 当 Dashboard 收到 /api/platform/webhook/{uuid} 请求时,会调用此方法。 Args: - request: Quart 请求对象 + request: webhook 请求对象 Returns: 响应内容,格式取决于具体平台的要求 diff --git a/astrbot/core/platform/sources/lark/server.py b/astrbot/core/platform/sources/lark/server.py index 52177ebb0c..e83ab5a2fc 100644 --- a/astrbot/core/platform/sources/lark/server.py +++ b/astrbot/core/platform/sources/lark/server.py @@ -132,7 +132,7 @@ async def handle_callback(self, request) -> tuple[dict, int] | dict: """处理 webhook 回调,可被统一 webhook 入口复用 Args: - request: Quart 请求对象 + request: webhook 请求对象 Returns: 响应数据 diff --git a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py index 275f01c720..2c9f6cabab 100644 --- a/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py +++ b/astrbot/core/platform/sources/qqofficial_webhook/qo_webhook_server.py @@ -5,12 +5,12 @@ from binascii import Error as BinasciiError from typing import cast -import quart from botpy import BotAPI, BotHttp, BotWebSocket, Client, ConnectionSession, Token from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives.asymmetric import ed25519 from astrbot.api import logger +from astrbot.core.platform.webhook_server import FastAPIWebhookServer # remove logger handler for handler in logging.root.handlers[:]: @@ -90,7 +90,7 @@ def __init__( self.api: BotAPI = BotAPI(http=self.http) self.token = Token(self.appid, self.secret) - self.server = quart.Quart(__name__) + self.server = FastAPIWebhookServer("qq-official-webhook") self.server.add_url_rule( "/astrbot-qo-webhook/callback", view_func=self.callback, @@ -159,15 +159,15 @@ def pop_extra_data(self, message_id: str) -> dict: """Pop and return extra fields cached from the raw webhook payload for a given message ID.""" return self._extra_data_cache.pop(message_id, {}) - async def callback(self): + async def callback(self, request): """内部服务器的回调入口""" - return await self.handle_callback(quart.request) + return await self.handle_callback(request) async def handle_callback(self, request) -> dict: """处理 webhook 回调,可被统一 webhook 入口复用 Args: - request: Quart 请求对象 + request: FastAPI webhook request 对象 Returns: 响应数据 diff --git a/astrbot/core/platform/sources/slack/client.py b/astrbot/core/platform/sources/slack/client.py index efd7a6f3d2..97f5f26ae1 100644 --- a/astrbot/core/platform/sources/slack/client.py +++ b/astrbot/core/platform/sources/slack/client.py @@ -2,11 +2,10 @@ import hashlib import hmac import json -import logging from collections.abc import Callable from typing import cast -from quart import Quart, Response, request +from fastapi.responses import Response from slack_sdk.socket_mode.aiohttp import SocketModeClient from slack_sdk.socket_mode.async_client import AsyncBaseSocketModeClient from slack_sdk.socket_mode.request import SocketModeRequest @@ -14,10 +13,11 @@ from slack_sdk.web.async_client import AsyncWebClient from astrbot.api import logger +from astrbot.core.platform.webhook_server import FastAPIWebhookServer class SlackWebhookClient: - """Slack Webhook 模式客户端,使用 Quart 作为 Web 服务器""" + """Slack Webhook 模式客户端,使用 FastAPI 作为 Web 服务器""" def __init__( self, @@ -35,20 +35,16 @@ def __init__( self.path = path self.event_handler = event_handler - self.app = Quart(__name__) + self.app = FastAPIWebhookServer("slack-webhook") self._setup_routes() - # 禁用 Quart 的默认日志输出 - logging.getLogger("quart.app").setLevel(logging.WARNING) - logging.getLogger("quart.serving").setLevel(logging.WARNING) - self.shutdown_event = asyncio.Event() def _setup_routes(self) -> None: """设置路由""" @self.app.route(self.path, methods=["POST"]) - async def slack_events(): + async def slack_events(request): """内部服务器的 POST 回调入口""" return await self.handle_callback(request) @@ -61,7 +57,7 @@ async def handle_callback(self, req): """处理 Slack 回调请求,可被统一 webhook 入口复用 Args: - req: Quart 请求对象 + req: webhook 请求对象 Returns: Response 对象或字典 @@ -75,7 +71,7 @@ async def handle_callback(self, req): timestamp = req.headers.get("X-Slack-Request-Timestamp") signature = req.headers.get("X-Slack-Signature") if not timestamp or not signature: - return Response("Missing headers", status=400) + return Response("Missing headers", status_code=400) # Calculate the HMAC signature sig_basestring = f"v0:{timestamp}:{body.decode('utf-8')}" my_signature = ( @@ -89,7 +85,7 @@ async def handle_callback(self, req): # Verify the signature if not hmac.compare_digest(my_signature, signature): logger.warning("Slack request signature verification failed") - return Response("Invalid signature", status=400) + return Response("Invalid signature", status_code=400) logger.info(f"Received Slack event: {event_data}") # 处理 URL 验证事件 @@ -99,11 +95,11 @@ async def handle_callback(self, req): if self.event_handler and event_data.get("type") == "event_callback": await self.event_handler(event_data) - return Response("", status=200) + return Response("", status_code=200) except Exception as e: logger.error(f"处理 Slack 事件时出错: {e}") - return Response("Internal Server Error", status=500) + return Response("Internal Server Error", status_code=500) async def start(self) -> None: """启动 Webhook 服务器""" diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index 31436ebf2e..ee25994c50 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -8,7 +8,6 @@ from typing import Any, cast from urllib.parse import unquote -import quart from requests import Response from wechatpy.enterprise import WeChatClient, parse_message from wechatpy.enterprise.crypto import WeChatCrypto @@ -28,6 +27,7 @@ ) from astrbot.core import logger from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.platform.webhook_server import FastAPIWebhookServer from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.media_utils import convert_audio_to_wav from astrbot.core.utils.webhook_utils import log_webhook_info @@ -65,7 +65,7 @@ def _extract_wecom_media_filename(disposition: str | None) -> str | None: class WecomServer: def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: - self.server = quart.Quart(__name__) + self.server = FastAPIWebhookServer("wecom-webhook") self.port = int(cast(str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.server.add_url_rule( @@ -89,15 +89,15 @@ def __init__(self, event_queue: asyncio.Queue, config: dict) -> None: self.callback: Callable[[BaseMessage], Awaitable[None]] | None = None self.shutdown_event = asyncio.Event() - async def verify(self): + async def verify(self, request): """内部服务器的 GET 验证入口""" - return await self.handle_verify(quart.request) + return await self.handle_verify(request) async def handle_verify(self, request) -> str: """处理验证请求,可被统一 webhook 入口复用 Args: - request: Quart 请求对象 + request: FastAPI webhook request 对象 Returns: 验证响应 @@ -117,15 +117,15 @@ async def handle_verify(self, request) -> str: logger.error("验证请求有效性失败,签名异常,请检查配置。") raise - async def callback_command(self): + async def callback_command(self, request): """内部服务器的 POST 回调入口""" - return await self.handle_callback(quart.request) + return await self.handle_callback(request) async def handle_callback(self, request) -> str: """处理回调请求,可被统一 webhook 入口复用 Args: - request: Quart 请求对象 + request: FastAPI webhook request 对象 Returns: 响应内容 diff --git a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py index 80ec5179e3..acf162b123 100644 --- a/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py +++ b/astrbot/core/platform/sources/wecom_ai_bot/wecomai_server.py @@ -6,9 +6,8 @@ from collections.abc import Callable from typing import Any -import quart - from astrbot.api import logger +from astrbot.core.platform.webhook_server import FastAPIWebhookServer from .wecomai_api import WecomAIBotAPIClient from .wecomai_utils import WecomAIBotConstants @@ -38,14 +37,13 @@ def __init__( self.api_client = api_client self.message_handler = message_handler - self.app = quart.Quart(__name__) + self.app = FastAPIWebhookServer("wecom-ai-bot-webhook") self._setup_routes() self.shutdown_event = asyncio.Event() def _setup_routes(self) -> None: - """设置 Quart 路由""" - # 使用 Quart 的 add_url_rule 方法添加路由 + """设置 FastAPI 路由""" self.app.add_url_rule( "/webhook/wecom-ai-bot", view_func=self.verify_url, @@ -58,15 +56,15 @@ def _setup_routes(self) -> None: methods=["POST"], ) - async def verify_url(self): + async def verify_url(self, request): """内部服务器的 GET 验证入口""" - return await self.handle_verify(quart.request) + return await self.handle_verify(request) async def handle_verify(self, request): """处理 URL 验证请求,可被统一 webhook 入口复用 Args: - request: Quart 请求对象 + request: FastAPI webhook request 对象 Returns: 验证响应元组 (content, status_code, headers) @@ -91,15 +89,15 @@ async def handle_verify(self, request): result = self.api_client.verify_url(msg_signature, timestamp, nonce, echostr) return result, 200, {"Content-Type": "text/plain"} - async def handle_message(self): + async def handle_message(self, request): """内部服务器的 POST 消息回调入口""" - return await self.handle_callback(quart.request) + return await self.handle_callback(request) async def handle_callback(self, request): """处理消息回调,可被统一 webhook 入口复用 Args: - request: Quart 请求对象 + request: FastAPI webhook request 对象 Returns: 响应元组 (content, status_code, headers) @@ -186,5 +184,5 @@ async def shutdown(self) -> None: self.shutdown_event.set() def get_app(self): - """获取 Quart 应用实例""" - return self.app + """获取 FastAPI 应用实例""" + return self.app.app diff --git a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py index 8b646e43f3..5d05e75c14 100644 --- a/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py +++ b/astrbot/core/platform/sources/weixin_official_account/weixin_offacc_adapter.py @@ -5,7 +5,6 @@ from collections.abc import Callable, Coroutine from typing import Any, cast -import quart from requests import Response from wechatpy import WeChatClient, create_reply, parse_message from wechatpy.crypto import WeChatCrypto @@ -25,6 +24,7 @@ ) from astrbot.core import logger from astrbot.core.platform.astr_message_event import MessageSesion +from astrbot.core.platform.webhook_server import FastAPIWebhookServer from astrbot.core.utils.astrbot_path import get_astrbot_temp_path from astrbot.core.utils.media_utils import convert_audio_to_wav from astrbot.core.utils.webhook_utils import log_webhook_info @@ -44,7 +44,7 @@ def __init__( config: dict, user_buffer: dict[Any, dict[str, Any]], ) -> None: - self.server = quart.Quart(__name__) + self.server = FastAPIWebhookServer("weixin-official-account-webhook") self.port = int(cast(int | str, config.get("port"))) self.callback_server_host = config.get("callback_server_host", "0.0.0.0") self.token = config.get("token") @@ -73,15 +73,15 @@ def __init__( self.user_buffer: dict[str, dict[str, Any]] = user_buffer # from_user -> state self.active_send_mode = False # 是否启用主动发送模式,启用后 callback 将直接返回回复内容,无需等待微信回调 - async def verify(self): + async def verify(self, request): """内部服务器的 GET 验证入口""" - return await self.handle_verify(quart.request) + return await self.handle_verify(request) async def handle_verify(self, request) -> str: """处理验证请求,可被统一 webhook 入口复用 Args: - request: Quart 请求对象 + request: FastAPI webhook request 对象 Returns: 验证响应 @@ -105,9 +105,9 @@ async def handle_verify(self, request) -> str: logger.error("验证请求有效性失败,签名异常,请检查配置。") return "err" - async def callback_command(self): + async def callback_command(self, request): """内部服务器的 POST 回调入口""" - return await self.handle_callback(quart.request) + return await self.handle_callback(request) def _maybe_encrypt(self, xml: str, nonce: str | None, timestamp: str | None) -> str: if xml and "" not in xml and nonce and timestamp: @@ -129,7 +129,7 @@ async def handle_callback(self, request) -> str: """处理回调请求,可被统一 webhook 入口复用 Args: - request: Quart 请求对象 + request: FastAPI webhook request 对象 Returns: 响应内容 diff --git a/astrbot/core/platform/webhook_server.py b/astrbot/core/platform/webhook_server.py new file mode 100644 index 0000000000..8c9efb7fa1 --- /dev/null +++ b/astrbot/core/platform/webhook_server.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +import inspect +from collections.abc import Callable +from typing import Any + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response +from hypercorn.asyncio import serve +from hypercorn.config import Config as HyperConfig + + +class WebhookRequest: + def __init__(self, request: Request) -> None: + self._request = request + self.args = request.query_params + self.headers = request.headers + self.method = request.method + + @property + def json(self): + return self._request.json() + + async def get_data(self) -> bytes: + return await self._request.body() + + async def get_json(self, *, force: bool = False, silent: bool = False): + try: + return await self._request.json() + except Exception: + if silent: + return None + raise + + +def _response_from_result(result: Any): + if isinstance(result, Response): + return result + + if isinstance(result, tuple): + content = result[0] if result else "" + status_code = ( + result[1] if len(result) > 1 and isinstance(result[1], int) else 200 + ) + headers = result[2] if len(result) > 2 and isinstance(result[2], dict) else None + if isinstance(content, dict | list): + return JSONResponse(content, status_code=status_code, headers=headers) + return Response( + content=content, + status_code=status_code, + headers=headers, + media_type=headers.get("Content-Type") if headers else None, + ) + + if isinstance(result, dict | list): + return JSONResponse(result) + + return result + + +class FastAPIWebhookServer: + def __init__(self, name: str) -> None: + self.app = FastAPI(title=name, docs_url=None, redoc_url=None, openapi_url=None) + + def add_url_rule( + self, + path: str, + view_func: Callable, + methods: list[str] | None = None, + ) -> None: + async def endpoint(request: Request): + if inspect.signature(view_func).parameters: + result = view_func(WebhookRequest(request)) + else: + result = view_func() + if inspect.isawaitable(result): + result = await result + return _response_from_result(result) + + self.app.add_api_route( + path, + endpoint, + methods=methods or ["GET"], + include_in_schema=False, + ) + + def route(self, path: str, methods: list[str] | None = None): + def decorator(view_func: Callable): + self.add_url_rule(path, view_func, methods) + return view_func + + return decorator + + async def run_task( + self, + *, + host: str, + port: int, + shutdown_trigger: Callable | None = None, + **_kwargs, + ) -> None: + config = HyperConfig() + config.bind = [f"{host}:{port}"] + await serve(self.app, config, shutdown_trigger=shutdown_trigger) + + async def shutdown(self) -> None: + return None diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 824c3b653b..aee7a6e3dd 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -60,8 +60,8 @@ logger.warning("未安装 watchfiles,无法实现插件的热重载。") -class PluginVersionIncompatibleError(Exception): - """Raised when plugin astrbot_version is incompatible with current AstrBot.""" +class PluginVersionUnsupportedError(Exception): + """Raised when plugin astrbot_version is not supported by current AstrBot.""" class PluginDependencyInstallError(Exception): @@ -1029,9 +1029,9 @@ async def load( ) ) if not is_valid: - raise PluginVersionIncompatibleError( + raise PluginVersionUnsupportedError( error_message - or "The plugin is not compatible with the current AstrBot version." + or "The plugin does not support the current AstrBot version." ) logger.info(metadata) @@ -1160,9 +1160,9 @@ async def load( ) ) if not is_valid: - raise PluginVersionIncompatibleError( + raise PluginVersionUnsupportedError( error_message - or "The plugin is not compatible with the current AstrBot version." + or "The plugin does not support the current AstrBot version." ) metadata.star_cls = obj diff --git a/astrbot/core/utils/auth_password.py b/astrbot/core/utils/auth_password.py index bbb5ff0462..e39e038e34 100644 --- a/astrbot/core/utils/auth_password.py +++ b/astrbot/core/utils/auth_password.py @@ -10,7 +10,7 @@ _PBKDF2_SALT_BYTES = 16 _PBKDF2_ALGORITHM = "pbkdf2_sha256" _PBKDF2_FORMAT = f"{_PBKDF2_ALGORITHM}$" -_LEGACY_MD5_LENGTH = 32 +_MD5_HASH_LENGTH = 32 _DASHBOARD_PASSWORD_MIN_LENGTH = 8 _GENERATED_DASHBOARD_PASSWORD_LENGTH = 24 DEFAULT_DASHBOARD_PASSWORD = "astrbot" @@ -47,8 +47,8 @@ def hash_dashboard_password(raw_password: str) -> str: return f"{_PBKDF2_FORMAT}{_PBKDF2_ITERATIONS}${salt}${digest}" -def hash_legacy_dashboard_password(raw_password: str) -> str: - """Return legacy MD5 hash for downgrade compatibility only.""" +def hash_md5_dashboard_password(raw_password: str) -> str: + """Return the MD5 dashboard password hash kept for stored config fallback.""" if not isinstance(raw_password, str) or raw_password == "": raise ValueError("Password cannot be empty") return hashlib.md5(raw_password.encode("utf-8")).hexdigest() @@ -71,10 +71,10 @@ def validate_dashboard_password(raw_password: str) -> None: raise ValueError("Password must include at least one digit") -def _is_legacy_md5_hash(stored: str) -> bool: +def _is_md5_hash(stored: str) -> bool: return ( isinstance(stored, str) - and len(stored) == _LEGACY_MD5_LENGTH + and len(stored) == _MD5_HASH_LENGTH and all(c in "0123456789abcdefABCDEF" for c in stored) ) @@ -84,13 +84,13 @@ def _is_pbkdf2_hash(stored: str) -> bool: def verify_dashboard_password(stored_hash: str, candidate_password: str) -> bool: - """Verify password against legacy md5 or new PBKDF2-SHA256 format.""" + """Verify password against MD5 or PBKDF2-SHA256 storage.""" if not isinstance(stored_hash, str) or not isinstance(candidate_password, str): return False - if _is_legacy_md5_hash(stored_hash): - # Keep compatibility with existing MD5-based deployments while requiring - # the real plaintext password, not the stored MD5 value itself. + if _is_md5_hash(stored_hash): + # Support existing MD5-based deployments while requiring the real + # plaintext password, not the stored MD5 value itself. candidate_md5 = hashlib.md5(candidate_password.encode("utf-8")).hexdigest() return hmac.compare_digest(stored_hash.lower(), candidate_md5.lower()) @@ -121,6 +121,6 @@ def is_default_dashboard_password(stored_hash: str) -> bool: return verify_dashboard_password(stored_hash, DEFAULT_DASHBOARD_PASSWORD) -def is_legacy_dashboard_password(stored_hash: str) -> bool: - """Check whether the password is still stored with legacy MD5.""" - return _is_legacy_md5_hash(stored_hash) +def is_md5_dashboard_password(stored_hash: str) -> bool: + """Check whether the password is still stored as MD5.""" + return _is_md5_hash(stored_hash) diff --git a/astrbot/dashboard/__init__.py b/astrbot/dashboard/__init__.py new file mode 100644 index 0000000000..0fa6e01413 --- /dev/null +++ b/astrbot/dashboard/__init__.py @@ -0,0 +1 @@ +"""Dashboard HTTP API and service layer.""" diff --git a/astrbot/dashboard/api/__init__.py b/astrbot/dashboard/api/__init__.py new file mode 100644 index 0000000000..261bb9b65a --- /dev/null +++ b/astrbot/dashboard/api/__init__.py @@ -0,0 +1 @@ +"""Dashboard HTTP API package.""" diff --git a/astrbot/dashboard/api/api_keys.py b/astrbot/dashboard/api/api_keys.py new file mode 100644 index 0000000000..06b8e4bc90 --- /dev/null +++ b/astrbot/dashboard/api/api_keys.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Request + +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import ApiKeyCreateRequest, ApiKeyIdRequest +from astrbot.dashboard.services.api_key_service import ( + ApiKeyService, + ApiKeyServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["API Keys"]) +legacy_router = APIRouter( + prefix="/api/apikey", + tags=["Dashboard API Keys"], + include_in_schema=False, +) + + +async def require_system_scope(request: Request) -> AuthContext: + return await require_scope(request, "system") + + +def get_service(request: Request) -> ApiKeyService: + return request.app.state.services.api_keys + + +def _payload_dict(payload: ApiKeyCreateRequest) -> dict: + return payload.model_dump(exclude_none=True) + + +def _raise_api_key_error(exc: ApiKeyServiceError) -> None: + raise ApiError(str(exc)) from exc + + +async def _list_api_keys(service: ApiKeyService): + try: + return ok(await service.list_api_keys()) + except ApiKeyServiceError as exc: + _raise_api_key_error(exc) + + +async def _create_api_key( + payload: ApiKeyCreateRequest, + *, + created_by: str, + service: ApiKeyService, +): + try: + return ok( + await service.create_api_key( + _payload_dict(payload), + created_by=created_by, + ) + ) + except ApiKeyServiceError as exc: + _raise_api_key_error(exc) + + +async def _revoke_api_key(key_id: str, service: ApiKeyService): + try: + if not await service.revoke_api_key(key_id): + raise ApiKeyServiceError("API key not found") + return ok() + except ApiKeyServiceError as exc: + _raise_api_key_error(exc) + + +async def _delete_api_key(key_id: str, service: ApiKeyService): + try: + if not await service.delete_api_key(key_id): + raise ApiKeyServiceError("API key not found") + return ok() + except ApiKeyServiceError as exc: + _raise_api_key_error(exc) + + +@router.get("/api-keys") +async def list_api_keys( + _auth: AuthContext = Depends(require_system_scope), + service: ApiKeyService = Depends(get_service), +): + return await _list_api_keys(service) + + +@router.post("/api-keys") +async def create_api_key( + payload: ApiKeyCreateRequest, + auth: AuthContext = Depends(require_system_scope), + service: ApiKeyService = Depends(get_service), +): + return await _create_api_key(payload, created_by=auth.username, service=service) + + +@router.post("/api-keys/{key_id}/revoke") +async def revoke_api_key( + key_id: str, + _auth: AuthContext = Depends(require_system_scope), + service: ApiKeyService = Depends(get_service), +): + return await _revoke_api_key(key_id, service) + + +@router.delete("/api-keys/{key_id}") +async def delete_api_key( + key_id: str, + _auth: AuthContext = Depends(require_system_scope), + service: ApiKeyService = Depends(get_service), +): + return await _delete_api_key(key_id, service) + + +@legacy_router.get("/list") +async def list_dashboard_api_keys( + _username: str = Depends(require_dashboard_user), + service: ApiKeyService = Depends(get_service), +): + return await _list_api_keys(service) + + +@legacy_router.post("/create") +async def create_dashboard_api_key( + payload: ApiKeyCreateRequest, + username: str = Depends(require_dashboard_user), + service: ApiKeyService = Depends(get_service), +): + return await _create_api_key(payload, created_by=username, service=service) + + +@legacy_router.post("/revoke") +async def revoke_dashboard_api_key( + payload: ApiKeyIdRequest, + _username: str = Depends(require_dashboard_user), + service: ApiKeyService = Depends(get_service), +): + return await _revoke_api_key(payload.key_id, service) + + +@legacy_router.post("/delete") +async def delete_dashboard_api_key( + payload: ApiKeyIdRequest, + _username: str = Depends(require_dashboard_user), + service: ApiKeyService = Depends(get_service), +): + return await _delete_api_key(payload.key_id, service) diff --git a/astrbot/dashboard/api/app.py b/astrbot/dashboard/api/app.py new file mode 100644 index 0000000000..f0b35819a7 --- /dev/null +++ b/astrbot/dashboard/api/app.py @@ -0,0 +1,195 @@ +from __future__ import annotations + +from types import SimpleNamespace + +from fastapi import FastAPI, HTTPException, Request +from fastapi.responses import JSONResponse + +from astrbot.core import LogBroker +from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.db import BaseDatabase +from astrbot.dashboard.responses import ApiError, error +from astrbot.dashboard.services.api_key_service import ApiKeyService +from astrbot.dashboard.services.auth_service import AuthService +from astrbot.dashboard.services.backup_service import BackupService +from astrbot.dashboard.services.chat_service import ChatService +from astrbot.dashboard.services.chatui_project_service import ChatUIProjectService +from astrbot.dashboard.services.command_service import CommandService +from astrbot.dashboard.services.config_service import ( + BotConfigService, + ConfigDisplayService, + ConfigFileService, + ConfigProfileService, + ConfigRoutingService, + ProviderConfigService, +) +from astrbot.dashboard.services.conversation_service import ConversationService +from astrbot.dashboard.services.cron_service import CronService +from astrbot.dashboard.services.file_service import FileService +from astrbot.dashboard.services.knowledge_base_service import KnowledgeBaseService +from astrbot.dashboard.services.live_chat_service import LiveChatService +from astrbot.dashboard.services.log_service import LogService +from astrbot.dashboard.services.open_api_service import OpenApiService +from astrbot.dashboard.services.persona_service import PersonaService +from astrbot.dashboard.services.platform_service import PlatformService +from astrbot.dashboard.services.plugin_page_service import PluginPageService +from astrbot.dashboard.services.plugin_service import PluginService +from astrbot.dashboard.services.session_management_service import ( + SessionManagementService, +) +from astrbot.dashboard.services.skills_service import SkillsService +from astrbot.dashboard.services.stat_service import StatService +from astrbot.dashboard.services.subagent_service import SubAgentService +from astrbot.dashboard.services.t2i_service import T2iService +from astrbot.dashboard.services.tools_service import ToolsService +from astrbot.dashboard.services.update_service import ( + DEMO_MODE, + UpdateService, + call_check_migration_needed_v4, + call_do_migration_v4, + call_download_dashboard, + call_get_dashboard_version, + call_pip_install, +) + +from .api_keys import legacy_router as legacy_api_keys_router +from .auth import legacy_router as legacy_auth_router +from .backups import legacy_router as legacy_backups_router +from .bots import legacy_router as legacy_bots_router +from .chat import legacy_router as legacy_chat_router +from .chat_projects import legacy_router as legacy_chat_projects_router +from .config_profiles import legacy_router as legacy_config_profiles_router +from .conversations import legacy_router as legacy_conversations_router +from .cron import legacy_router as legacy_cron_router +from .extensions import legacy_router as legacy_extensions_router +from .files import legacy_router as legacy_files_router +from .knowledge_bases import legacy_router as legacy_knowledge_bases_router +from .live_chat import legacy_router as legacy_live_chat_router +from .logs import legacy_router as legacy_logs_router +from .personas import legacy_router as legacy_personas_router +from .platform import legacy_router as legacy_platform_router +from .plugins import legacy_router as legacy_plugins_router +from .providers import legacy_router as legacy_providers_router +from .router import API_V1_PREFIX, build_api_router +from .sessions import legacy_router as legacy_sessions_router +from .skills import legacy_router as legacy_skills_router +from .static_files import router as static_files_router +from .stats import legacy_router as legacy_stats_router +from .subagents import legacy_router as legacy_subagents_router +from .t2i import legacy_router as legacy_t2i_router +from .tools import legacy_router as legacy_tools_router +from .updates import legacy_router as legacy_updates_router + +CLEAR_SITE_DATA_HEADERS = {"Clear-Site-Data": '"cache"'} + + +def create_dashboard_asgi_app( + *, + core_lifecycle: AstrBotCoreLifecycle, + db: BaseDatabase, + jwt_secret: str, + static_folder: str | None = None, +) -> FastAPI: + app = FastAPI( + title="AstrBot OpenAPI", + version="1.0.0", + openapi_url=f"{API_V1_PREFIX}/openapi.json", + docs_url=f"{API_V1_PREFIX}/docs", + redoc_url=f"{API_V1_PREFIX}/redoc", + ) + app.state.core_lifecycle = core_lifecycle + app.state.db = db + app.state.jwt_secret = jwt_secret + app.state.dashboard_static_folder = static_folder + log_broker = getattr(core_lifecycle, "log_broker", None) or LogBroker() + app.state.services = SimpleNamespace( + config_profiles=ConfigProfileService(core_lifecycle, db), + config_display=ConfigDisplayService(core_lifecycle), + config_files=ConfigFileService(core_lifecycle), + config_routes=ConfigRoutingService(core_lifecycle), + api_keys=ApiKeyService(db), + auth=AuthService(db, core_lifecycle.astrbot_config), + backups=BackupService(db, core_lifecycle), + chat=ChatService(db, core_lifecycle), + chat_projects=ChatUIProjectService(db), + commands=CommandService(core_lifecycle.astrbot_config, core_lifecycle), + conversations=ConversationService(db, core_lifecycle), + cron=CronService(core_lifecycle), + files=FileService(), + knowledge_bases=KnowledgeBaseService(core_lifecycle), + live_chat=LiveChatService(db, core_lifecycle), + logs=LogService(log_broker, core_lifecycle.astrbot_config), + bots=BotConfigService(core_lifecycle), + platforms=PlatformService(core_lifecycle), + providers=ProviderConfigService(core_lifecycle), + personas=PersonaService(core_lifecycle), + plugins=PluginService(core_lifecycle, core_lifecycle.plugin_manager), + plugin_pages=PluginPageService( + core_lifecycle.plugin_manager, + core_lifecycle=core_lifecycle, + ), + open_api=OpenApiService(db, core_lifecycle), + sessions=SessionManagementService(core_lifecycle, db), + skills=SkillsService(core_lifecycle), + stats=StatService(db, core_lifecycle, core_lifecycle.astrbot_config), + subagents=SubAgentService(core_lifecycle), + t2i=T2iService(core_lifecycle), + tools=ToolsService(core_lifecycle), + updates=UpdateService( + core_lifecycle.astrbot_updator, + core_lifecycle, + download_dashboard_func=call_download_dashboard, + get_dashboard_version_func=call_get_dashboard_version, + pip_install_func=call_pip_install, + check_migration_needed_func=call_check_migration_needed_v4, + do_migration_func=call_do_migration_v4, + demo_mode=DEMO_MODE, + clear_site_data_headers=CLEAR_SITE_DATA_HEADERS, + ), + ) + + @app.exception_handler(ApiError) + async def api_error_handler(_request: Request, exc: ApiError): + return JSONResponse( + error(exc.message, exc.data), + status_code=exc.status_code, + ) + + @app.exception_handler(ValueError) + async def value_error_handler(_request: Request, exc: ValueError): + return JSONResponse(error(str(exc)), status_code=400) + + @app.exception_handler(HTTPException) + async def http_error_handler(_request: Request, exc: HTTPException): + detail = exc.detail if isinstance(exc.detail, str) else "Request failed" + return JSONResponse(error(detail), status_code=exc.status_code) + + # Legacy dashboard routes keep old /api/* callers working without entering OpenAPI. + app.include_router(legacy_api_keys_router) + app.include_router(legacy_auth_router) + app.include_router(legacy_backups_router) + app.include_router(legacy_config_profiles_router) + app.include_router(legacy_bots_router) + app.include_router(legacy_providers_router) + app.include_router(legacy_chat_router) + app.include_router(legacy_chat_projects_router) + app.include_router(legacy_conversations_router) + app.include_router(legacy_cron_router) + app.include_router(legacy_extensions_router) + app.include_router(legacy_files_router) + app.include_router(legacy_knowledge_bases_router) + app.include_router(legacy_live_chat_router) + app.include_router(legacy_logs_router) + app.include_router(legacy_sessions_router) + app.include_router(legacy_skills_router) + app.include_router(legacy_stats_router) + app.include_router(legacy_subagents_router) + app.include_router(legacy_tools_router) + app.include_router(legacy_platform_router) + app.include_router(legacy_plugins_router) + app.include_router(legacy_t2i_router) + app.include_router(legacy_personas_router) + app.include_router(legacy_updates_router) + app.include_router(build_api_router()) + app.include_router(static_files_router) + return app diff --git a/astrbot/dashboard/api/auth.py b/astrbot/dashboard/api/auth.py new file mode 100644 index 0000000000..d79134e9ee --- /dev/null +++ b/astrbot/dashboard/api/auth.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +from dataclasses import dataclass + +import jwt +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse + +from astrbot.dashboard.responses import ApiError +from astrbot.dashboard.schemas import ( + AccountUpdateRequest, + AuthSetupRequest, + LoginRequest, + TotpSetupRequest, +) +from astrbot.dashboard.services.api_key_service import ApiKeyService +from astrbot.dashboard.services.auth_service import ( + ALL_OPEN_API_SCOPES, + DASHBOARD_JWT_COOKIE_MAX_AGE, + DASHBOARD_JWT_COOKIE_NAME, + TOTP_TRUSTED_DEVICE_COOKIE_NAME, + TOTP_TRUSTED_DEVICE_MAX_AGE, + AuthService, + AuthServiceResult, +) + +router = APIRouter(tags=["Auth"]) +legacy_router = APIRouter( + prefix="/api/auth", + tags=["Dashboard Auth"], + include_in_schema=False, +) + + +@dataclass(frozen=True) +class AuthContext: + username: str + scopes: list[str] + api_key_id: str | None = None + via: str = "jwt" + + +def _extract_raw_api_key(request: Request) -> str | None: + auth_header = request.headers.get("Authorization", "").strip() + if auth_header.startswith("Bearer "): + return None + if auth_header.startswith("ApiKey "): + return auth_header.removeprefix("ApiKey ").strip() + if key := request.query_params.get("api_key"): + return key.strip() + if key := request.query_params.get("key"): + return key.strip() + if key := request.headers.get("X-API-Key"): + return key.strip() + return None + + +def _get_dashboard_state_username(request: Request) -> str | None: + dashboard_g = getattr(request.state, "dashboard_g", None) + if dashboard_g is None: + return None + + username = getattr(dashboard_g, "username", None) + if username is None and hasattr(dashboard_g, "get"): + username = dashboard_g.get("username") + if isinstance(username, str) and username.strip(): + return username + return None + + +def _extract_dashboard_jwt(request: Request) -> str | None: + auth_header = request.headers.get("Authorization", "").strip() + if auth_header.startswith("Bearer "): + token = auth_header.removeprefix("Bearer ").strip() + if token: + return token + + cookie_token = request.cookies.get(DASHBOARD_JWT_COOKIE_NAME, "").strip() + if cookie_token: + return cookie_token + return None + + +async def require_dashboard_user(request: Request) -> str: + if username := _get_dashboard_state_username(request): + return username + + token = _extract_dashboard_jwt(request) + if not token: + raise ApiError("未授权", status_code=401) + + try: + payload = jwt.decode( + token, + request.app.state.jwt_secret, + algorithms=["HS256"], + ) + except jwt.ExpiredSignatureError as exc: + raise ApiError("Token 过期", status_code=401) from exc + except jwt.InvalidTokenError as exc: + raise ApiError("Token 无效", status_code=401) from exc + + username = payload.get("username") + if not isinstance(username, str) or not username.strip(): + raise ApiError("Token 无效", status_code=401) + return username + + +async def _require_api_key_scope( + request: Request, + raw_key: str, + scope: str, +) -> AuthContext: + if scope not in ALL_OPEN_API_SCOPES: + raise ApiError("Insufficient API key scope", status_code=403) + + key_hash = ApiKeyService.hash_key(raw_key) + api_key = await request.app.state.db.get_active_api_key_by_hash(key_hash) + if not api_key: + raise ApiError("Invalid API key", status_code=401) + scopes = ( + [str(scope) for scope in api_key.scopes] + if isinstance(api_key.scopes, list) + else [str(scope) for scope in ALL_OPEN_API_SCOPES] + ) + if "*" not in scopes and scope not in scopes: + raise ApiError("Insufficient API key scope", status_code=403) + await request.app.state.db.touch_api_key(api_key.key_id) + return AuthContext( + username=f"api_key:{api_key.key_id}", + scopes=scopes, + api_key_id=api_key.key_id, + via="api_key", + ) + + +async def require_scope(request: Request, scope: str) -> AuthContext: + raw_key = _extract_raw_api_key(request) + if raw_key: + return await _require_api_key_scope(request, raw_key, scope) + + auth_header = request.headers.get("Authorization", "").strip() + if not auth_header.startswith("Bearer "): + raise ApiError("Missing API key", status_code=401) + token = auth_header.removeprefix("Bearer ").strip() + try: + payload = jwt.decode( + token, + request.app.state.jwt_secret, + algorithms=["HS256"], + ) + except jwt.ExpiredSignatureError as exc: + raise ApiError("Token expired", status_code=401) from exc + except jwt.InvalidTokenError as exc: + try: + return await _require_api_key_scope(request, token, scope) + except ApiError as api_key_exc: + raise api_key_exc from exc + + username = payload.get("username") + if not isinstance(username, str) or not username.strip(): + raise ApiError("Invalid token", status_code=401) + return AuthContext(username=username, scopes=["*"], via="jwt") + + +def get_auth_service(request: Request) -> AuthService: + return request.app.state.services.auth + + +def _payload(payload) -> dict: + return payload.model_dump(exclude_none=True) + + +def _auth_result_payload(result: AuthServiceResult) -> dict: + data = result.data if result.data is not None else {} + payload = { + "status": result.status, + "message": result.message, + "data": data, + } + if result.status == "error" and result.data is None: + payload["data"] = None + return payload + + +def _use_secure_dashboard_jwt_cookie(request: Request) -> bool: + adapter = getattr(request.app.state, "dashboard_app_adapter", None) + adapter_config = getattr(adapter, "config", {}) if adapter is not None else {} + default_secure = not bool(getattr(adapter, "debug", False)) and not bool( + getattr(adapter, "testing", False) + ) + return bool( + adapter_config.get( + "DASHBOARD_JWT_COOKIE_SECURE", + default_secure, + ) + ) + + +def _set_dashboard_jwt_cookie( + request: Request, + response: JSONResponse, + token: str, +) -> None: + response.set_cookie( + DASHBOARD_JWT_COOKIE_NAME, + token, + max_age=DASHBOARD_JWT_COOKIE_MAX_AGE, + httponly=True, + samesite="strict", + secure=_use_secure_dashboard_jwt_cookie(request), + path="/", + ) + + +def _clear_dashboard_jwt_cookie(request: Request, response: JSONResponse) -> None: + response.delete_cookie( + DASHBOARD_JWT_COOKIE_NAME, + httponly=True, + samesite="strict", + secure=_use_secure_dashboard_jwt_cookie(request), + path="/", + ) + + +def _set_trusted_device_cookie( + request: Request, + response: JSONResponse, + token: str, +) -> None: + response.set_cookie( + TOTP_TRUSTED_DEVICE_COOKIE_NAME, + token, + max_age=TOTP_TRUSTED_DEVICE_MAX_AGE, + httponly=True, + samesite="strict", + secure=_use_secure_dashboard_jwt_cookie(request), + path="/api/auth", + ) + + +def _auth_service_response( + request: Request, + result: AuthServiceResult, +) -> JSONResponse: + response = JSONResponse( + _auth_result_payload(result), + status_code=result.status_code, + ) + if result.jwt_token: + _set_dashboard_jwt_cookie(request, response, result.jwt_token) + if result.trusted_device_token: + _set_trusted_device_cookie(request, response, result.trusted_device_token) + return response + + +def _has_auth_credentials(request: Request) -> bool: + auth_header = request.headers.get("Authorization", "") + return bool( + auth_header.startswith(("Bearer ", "ApiKey ")) + or request.query_params.get("api_key") + or request.query_params.get("key") + or request.headers.get("X-API-Key") + ) + + +async def require_system_scope(request: Request) -> AuthContext: + return await require_scope(request, "system") + + +async def optional_system_auth(request: Request) -> AuthContext | None: + if not _has_auth_credentials(request): + return None + return await require_system_scope(request) + + +async def _login( + request: Request, + payload: LoginRequest, + service: AuthService, +): + result = await service.login( + _payload(payload), + trusted_device_cookie_token=request.cookies.get( + TOTP_TRUSTED_DEVICE_COOKIE_NAME, + "", + ).strip(), + ) + return _auth_service_response( + request, + result, + ) + + +async def _setup_status(service: AuthService): + return _auth_service_response_from_result(await service.setup_status()) + + +def _auth_service_response_from_result(result: AuthServiceResult) -> JSONResponse: + return JSONResponse( + _auth_result_payload(result), + status_code=result.status_code, + ) + + +async def _setup( + request: Request, + payload: AuthSetupRequest, + service: AuthService, + auth: AuthContext | None, +): + if auth is None: + result = await service.setup(_payload(payload)) + else: + result = await service.setup_authenticated(_payload(payload), auth.username) + return _auth_service_response( + request, + result, + ) + + +async def _totp_setup( + request: Request, + payload: TotpSetupRequest, + service: AuthService, +): + return _auth_service_response( + request, + await service.totp_setup(_payload(payload)), + ) + + +async def _totp_recovery( + request: Request, + service: AuthService, +): + return _auth_service_response(request, await service.totp_recovery()) + + +async def _update_account( + request: Request, + payload: AccountUpdateRequest, + service: AuthService, +): + return _auth_service_response( + request, + await service.edit_account(_payload(payload)), + ) + + +@router.post("/auth/login") +async def login( + request: Request, + payload: LoginRequest, + service: AuthService = Depends(get_auth_service), +): + return await _login(request, payload, service) + + +@legacy_router.post("/login") +async def dashboard_login( + request: Request, + payload: LoginRequest, + service: AuthService = Depends(get_auth_service), +): + return await _login(request, payload, service) + + +@router.post("/auth/logout") +async def logout(request: Request): + response = JSONResponse( + {"status": "ok", "message": "已退出登录", "data": {}}, + status_code=200, + ) + _clear_dashboard_jwt_cookie(request, response) + return response + + +@legacy_router.post("/logout") +async def dashboard_logout(request: Request): + return await logout(request) + + +@router.get("/auth/setup-status") +async def setup_status( + service: AuthService = Depends(get_auth_service), +): + return _auth_service_response_from_result(await service.setup_status()) + + +@legacy_router.get("/setup-status") +async def dashboard_setup_status( + service: AuthService = Depends(get_auth_service), +): + return _auth_service_response_from_result(await service.setup_status()) + + +@router.post("/auth/setup") +async def setup( + request: Request, + payload: AuthSetupRequest, + auth: AuthContext | None = Depends(optional_system_auth), + service: AuthService = Depends(get_auth_service), +): + return await _setup(request, payload, service, auth) + + +@legacy_router.post("/setup") +async def dashboard_setup( + request: Request, + payload: AuthSetupRequest, + service: AuthService = Depends(get_auth_service), +): + return await _setup(request, payload, service, None) + + +@legacy_router.post("/setup-authenticated") +async def dashboard_setup_authenticated( + request: Request, + payload: AuthSetupRequest, + username: str = Depends(require_dashboard_user), + service: AuthService = Depends(get_auth_service), +): + auth = AuthContext(username=username, scopes=["*"], via="jwt") + return await _setup(request, payload, service, auth) + + +@router.post("/auth/totp/setup") +async def totp_setup( + request: Request, + payload: TotpSetupRequest, + _auth: AuthContext = Depends(require_system_scope), + service: AuthService = Depends(get_auth_service), +): + return await _totp_setup(request, payload, service) + + +@legacy_router.post("/totp/setup") +async def dashboard_totp_setup( + request: Request, + payload: TotpSetupRequest, + _username: str = Depends(require_dashboard_user), + service: AuthService = Depends(get_auth_service), +): + return await _totp_setup(request, payload, service) + + +@router.post("/auth/totp/recovery") +async def totp_recovery( + request: Request, + _auth: AuthContext = Depends(require_system_scope), + service: AuthService = Depends(get_auth_service), +): + return await _totp_recovery(request, service) + + +@legacy_router.post("/totp/recovery") +async def dashboard_totp_recovery( + request: Request, + _username: str = Depends(require_dashboard_user), + service: AuthService = Depends(get_auth_service), +): + return await _totp_recovery(request, service) + + +@router.patch("/auth/account") +async def update_account( + request: Request, + payload: AccountUpdateRequest, + _auth: AuthContext = Depends(require_system_scope), + service: AuthService = Depends(get_auth_service), +): + return await _update_account(request, payload, service) + + +@legacy_router.post("/account/edit") +async def dashboard_update_account( + request: Request, + payload: AccountUpdateRequest, + _username: str = Depends(require_dashboard_user), + service: AuthService = Depends(get_auth_service), +): + return await _update_account(request, payload, service) diff --git a/astrbot/dashboard/api/backups.py b/astrbot/dashboard/api/backups.py new file mode 100644 index 0000000000..0e04a1f92d --- /dev/null +++ b/astrbot/dashboard/api/backups.py @@ -0,0 +1,406 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, File, Form, Query, Request, UploadFile +from fastapi.responses import FileResponse + +from astrbot.core import logger +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import ( + BackupImportRequest, + BackupRenameRequest, + BackupUploadInitRequest, + BackupUploadSessionRequest, +) +from astrbot.dashboard.services.backup_service import ( + BackupService, + BackupServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Backups"]) +legacy_router = APIRouter( + prefix="/api/backup", + tags=["Dashboard Backups"], + include_in_schema=False, +) + + +def get_service(request: Request) -> BackupService: + return request.app.state.services.backups + + +async def require_system_scope(request: Request) -> AuthContext: + return await require_scope(request, "system") + + +def _model_dict(payload) -> dict: + return payload.model_dump(exclude_none=True) + + +def _ok_result(result): + if isinstance(result, tuple): + data, message = result + return ok(data, message) + return ok(result) + + +async def _json_or_empty(request: Request) -> dict: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +async def _run(operation, *, prefix: str): + try: + result = await run_maybe_async(operation) + return _ok_result(result) + except BackupServiceError as exc: + return error(str(exc)) + except Exception as exc: + logger.error("%s: %s", prefix, exc, exc_info=True) + return error(f"{prefix}: {exc!s}") + + +def _download_response(download) -> FileResponse: + return FileResponse( + download.path, + filename=download.filename, + media_type="application/zip", + ) + + +def _download_backup( + *, + filename: str | None, + token: str | None, + service: BackupService, +): + try: + return _download_response( + service.prepare_download( + filename=filename, + token=token, + jwt_secret=service.config.get("dashboard", {}).get("jwt_secret"), + ) + ) + except BackupServiceError as exc: + return error(str(exc)) + except Exception as exc: + logger.error("下载备份失败: %s", exc, exc_info=True) + return error(f"下载备份失败: {exc!s}") + + +@router.get("/backups") +async def list_backups( + page: int = Query(default=1), + page_size: int = Query(default=20), + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.list_backups(page=page, page_size=page_size), + prefix="获取备份列表失败", + ) + + +@legacy_router.get("/list") +async def list_dashboard_backups( + page: int = Query(default=1), + page_size: int = Query(default=20), + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.list_backups(page=page, page_size=page_size), + prefix="获取备份列表失败", + ) + + +@router.post("/backups") +async def create_backup( + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run(service.export_backup, prefix="创建备份失败") + + +@legacy_router.post("/export") +async def export_dashboard_backup( + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run(service.export_backup, prefix="创建备份失败") + + +@router.post("/backups/upload") +async def upload_backup( + file: UploadFile = File(...), + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run(lambda: service.upload_backup(file), prefix="上传备份文件失败") + + +@legacy_router.post("/upload") +async def upload_dashboard_backup( + file: UploadFile = File(...), + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run(lambda: service.upload_backup(file), prefix="上传备份文件失败") + + +@router.post("/backups/upload/init") +async def init_backup_upload( + payload: BackupUploadInitRequest, + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.upload_init(_model_dict(payload)), + prefix="初始化分片上传失败", + ) + + +@legacy_router.post("/upload/init") +async def init_dashboard_backup_upload( + payload: BackupUploadInitRequest, + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.upload_init(_model_dict(payload)), + prefix="初始化分片上传失败", + ) + + +@router.post("/backups/upload/chunk") +async def upload_backup_chunk( + upload_id: str = Form(...), + chunk_index: str = Form(...), + chunk: UploadFile = File(...), + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.upload_chunk( + upload_id=upload_id, + chunk_index_str=chunk_index, + chunk_file=chunk, + ), + prefix="上传分片失败", + ) + + +@legacy_router.post("/upload/chunk") +async def upload_dashboard_backup_chunk( + upload_id: str = Form(...), + chunk_index: str = Form(...), + chunk: UploadFile = File(...), + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.upload_chunk( + upload_id=upload_id, + chunk_index_str=chunk_index, + chunk_file=chunk, + ), + prefix="上传分片失败", + ) + + +@router.post("/backups/upload/complete") +async def complete_backup_upload( + payload: BackupUploadSessionRequest, + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.upload_complete(_model_dict(payload)), + prefix="完成分片上传失败", + ) + + +@legacy_router.post("/upload/complete") +async def complete_dashboard_backup_upload( + payload: BackupUploadSessionRequest, + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.upload_complete(_model_dict(payload)), + prefix="完成分片上传失败", + ) + + +@router.post("/backups/upload/abort") +async def abort_backup_upload( + payload: BackupUploadSessionRequest, + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.upload_abort(_model_dict(payload)), + prefix="取消上传失败", + ) + + +@legacy_router.post("/upload/abort") +async def abort_dashboard_backup_upload( + payload: BackupUploadSessionRequest, + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.upload_abort(_model_dict(payload)), + prefix="取消上传失败", + ) + + +@router.get("/backups/tasks/{task_id}") +async def get_backup_progress( + task_id: str, + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run(lambda: service.get_progress(task_id), prefix="获取任务进度失败") + + +@legacy_router.get("/progress") +async def get_dashboard_backup_progress( + task_id: str | None = Query(default=None), + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.get_progress(task_id), + prefix="获取任务进度失败", + ) + + +@router.get("/backups/{filename:path}") +async def download_backup( + filename: str, + token: str | None = Query(default=None), + service: BackupService = Depends(get_service), +): + return _download_backup(filename=filename, token=token, service=service) + + +@legacy_router.get("/download") +async def download_dashboard_backup( + filename: str | None = Query(default=None), + token: str | None = Query(default=None), + service: BackupService = Depends(get_service), +): + return _download_backup(filename=filename, token=token, service=service) + + +@router.patch("/backups/{filename:path}") +async def rename_backup( + filename: str, + payload: BackupRenameRequest, + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.rename_backup({"filename": filename, **_model_dict(payload)}), + prefix="重命名备份失败", + ) + + +@legacy_router.post("/rename") +async def rename_dashboard_backup( + payload: BackupRenameRequest, + filename: str | None = Query(default=None), + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.rename_backup({"filename": filename, **_model_dict(payload)}), + prefix="重命名备份失败", + ) + + +@router.delete("/backups/{filename:path}") +async def delete_backup( + filename: str, + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.delete_backup({"filename": filename}), + prefix="删除备份失败", + ) + + +@legacy_router.post("/delete") +async def delete_dashboard_backup( + request: Request, + filename: str | None = Query(default=None), + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.delete_backup({"filename": filename, **body}), + prefix="删除备份失败", + ) + + +@router.post("/backups/{filename:path}/check") +async def check_backup( + filename: str, + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.check_backup({"filename": filename}), + prefix="预检查备份文件失败", + ) + + +@legacy_router.post("/check") +async def check_dashboard_backup( + request: Request, + filename: str | None = Query(default=None), + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.check_backup({"filename": filename, **body}), + prefix="预检查备份文件失败", + ) + + +@router.post("/backups/{filename:path}/import") +async def import_backup( + filename: str, + payload: BackupImportRequest, + _auth: AuthContext = Depends(require_system_scope), + service: BackupService = Depends(get_service), +): + return await _run( + lambda: service.import_backup({"filename": filename, **_model_dict(payload)}), + prefix="导入备份失败", + ) + + +@legacy_router.post("/import") +async def import_dashboard_backup( + request: Request, + filename: str | None = Query(default=None), + _username: str = Depends(require_dashboard_user), + service: BackupService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.import_backup({"filename": filename, **body}), + prefix="导入备份失败", + ) diff --git a/astrbot/dashboard/api/bots.py b/astrbot/dashboard/api/bots.py new file mode 100644 index 0000000000..b96c8f3789 --- /dev/null +++ b/astrbot/dashboard/api/bots.py @@ -0,0 +1,253 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import BotConfigRequest, EnabledPatch +from astrbot.dashboard.services.config_service import BotConfigService + +from .auth import AuthContext, require_scope + +router = APIRouter(tags=["Bots"]) +legacy_router = APIRouter( + prefix="/api/config/platform", + tags=["Dashboard Bots"], + include_in_schema=False, +) + + +async def require_bot_scope(request: Request) -> AuthContext: + return await require_scope(request, "bot") + + +def get_service(request: Request) -> BotConfigService: + return request.app.state.services.bots + + +async def _json_or_empty(request: Request) -> dict: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _required_text(value: object, name: str) -> str: + text = str(value or "").strip() + if not text: + raise ValueError(f"Missing key: {name}") + return text + + +def _config_from_body(body: dict) -> dict: + config = body.get("config") + if isinstance(config, dict): + return config + return { + key: value + for key, value in body.items() + if key not in {"bot_id", "config", "enabled"} + } + + +def _alias_error(message: str): + return error(message) + + +@router.get("/bot-types") +async def list_bot_types( + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + return ok(service.list_bot_types()) + + +@router.get("/bots") +async def list_bots( + enabled: bool | None = Query(default=None), + type_: str | None = Query(default=None, alias="type"), + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + return ok(service.list_bots(enabled=enabled, type_=type_)) + + +@router.post("/bots") +async def create_bot( + payload: BotConfigRequest, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + await service.create_bot(payload.to_dashboard_config()) + return ok(message="新增平台配置成功~") + + +@router.get("/bots/stats") +async def list_bot_stats( + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + return ok(service.get_bot_stats()) + + +@router.get("/bots/by-id") +async def get_bot_by_id( + bot_id: str = Query(...), + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + return ok(service.get_bot(bot_id)) + + +@router.put("/bots/by-id") +async def update_bot_by_id( + payload: BotConfigRequest, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + bot_id = _required_text(payload.bot_id, "bot_id") + await service.update_bot( + bot_id, + payload.to_dashboard_config(fallback_id=bot_id), + ) + return ok(message="更新平台配置成功~") + + +@router.delete("/bots/by-id") +async def delete_bot_by_id( + bot_id: str = Query(...), + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + await service.delete_bot(bot_id) + return ok(message="删除平台配置成功~") + + +@router.patch("/bots/enabled") +async def set_bot_enabled_by_id( + payload: BotConfigRequest, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + bot_id = _required_text(payload.bot_id, "bot_id") + await service.set_bot_enabled(bot_id, bool(payload.enabled)) + return ok(message="更新平台配置成功~") + + +@router.post("/bots/test") +async def test_bot_by_id( + payload: BotConfigRequest, + _auth: AuthContext = Depends(require_bot_scope), +): + bot_id = _required_text(payload.bot_id, "bot_id") + return ok({"id": bot_id, "status": "unsupported"}) + + +@router.patch("/bots/{bot_id:path}/enabled") +async def set_bot_enabled( + bot_id: str, + payload: EnabledPatch, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + await service.set_bot_enabled(bot_id, payload.enabled) + return ok(message="更新平台配置成功~") + + +@router.post("/bots/{bot_id:path}/test") +async def test_bot( + bot_id: str, + _auth: AuthContext = Depends(require_bot_scope), +): + return ok({"id": bot_id, "status": "unsupported"}) + + +@router.get("/bots/{bot_id:path}") +async def get_bot( + bot_id: str, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + return ok(service.get_bot(bot_id)) + + +@router.put("/bots/{bot_id:path}") +async def update_bot( + bot_id: str, + payload: BotConfigRequest, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + await service.update_bot(bot_id, payload.to_dashboard_config(fallback_id=bot_id)) + return ok(message="更新平台配置成功~") + + +@router.delete("/bots/{bot_id:path}") +async def delete_bot( + bot_id: str, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + await service.delete_bot(bot_id) + return ok(message="删除平台配置成功~") + + +@legacy_router.get("/list") +async def list_dashboard_alias_platforms( + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + return ok({"platforms": service.list_bots()["bots"]}) + + +@legacy_router.post("/new") +async def create_dashboard_alias_platform( + payload: BotConfigRequest, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + try: + await service.create_bot(payload.to_dashboard_config()) + return ok(message="新增平台配置成功~") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/update") +async def update_dashboard_alias_platform( + request: Request, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + body = await _json_or_empty(request) + bot_id = body.get("id") + config = body.get("config") + if not bot_id or not isinstance(config, dict): + return _alias_error("参数错误") + try: + await service.update_bot( + str(bot_id), + BotConfigRequest(config=config).to_dashboard_config( + fallback_id=str(bot_id) + ), + ) + return ok(message="更新平台配置成功~") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/delete") +async def delete_dashboard_alias_platform( + request: Request, + _auth: AuthContext = Depends(require_bot_scope), + service: BotConfigService = Depends(get_service), +): + body = await _json_or_empty(request) + bot_id = body.get("id") + if not bot_id: + return _alias_error("缺少参数 id") + try: + await service.delete_bot(str(bot_id)) + return ok(message="删除平台配置成功~") + except ValueError as exc: + return _alias_error(str(exc)) diff --git a/astrbot/dashboard/api/chat.py b/astrbot/dashboard/api/chat.py new file mode 100644 index 0000000000..15e17cece2 --- /dev/null +++ b/astrbot/dashboard/api/chat.py @@ -0,0 +1,526 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import FileResponse, JSONResponse, StreamingResponse + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import ( + ChatMessagePatchRequest, + ChatMessageRegenerateRequest, + ChatSessionBatchDeleteRequest, + ChatSessionPatchRequest, + ChatThreadCreateRequest, + ChatThreadMessageRequest, +) +from astrbot.dashboard.services.chat_service import ( + ChatService, + ChatServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope +from .multipart import single_upload + +router = APIRouter(tags=["Chat"]) +legacy_router = APIRouter( + prefix="/api/chat", + tags=["Dashboard Chat"], + include_in_schema=False, +) + + +def get_service(request: Request) -> ChatService: + return request.app.state.services.chat + + +async def require_chat_scope(request: Request) -> AuthContext: + return await require_scope(request, "chat") + + +async def _json_or_empty(request: Request) -> dict[str, Any]: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +async def _json_or_none(request: Request) -> dict[str, Any] | None: + try: + data = await request.json() + except Exception: + return None + return data if isinstance(data, dict) else None + + +async def _json_body(request: Request): + try: + return await request.json() + except Exception: + return None + + +def _model_dict(payload) -> dict[str, Any]: + return payload.model_dump(exclude_unset=True, exclude_none=False) + + +async def _run(operation): + try: + result = await run_maybe_async(operation) + return ok(result) + except ChatServiceError as exc: + return error(str(exc)) + + +def _file_response(file_path: str, mimetype: str | None): + if mimetype: + return FileResponse(file_path, media_type=mimetype) + return FileResponse(file_path) + + +async def _send_chat( + *, + request: Request, + username: str, + service: ChatService, + payload: dict[str, Any] | None = None, +): + post_data = payload if payload is not None else await _json_or_none(request) + if post_data is None: + return JSONResponse(error("Missing JSON body")) + + try: + stream = await service.build_chat_stream(username, post_data) + except ChatServiceError as exc: + return JSONResponse(error(str(exc))) + + return StreamingResponse( + stream, + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Transfer-Encoding": "chunked", + "Connection": "keep-alive", + }, + ) + + +@router.get("/chat/sessions/new") +async def create_chat_session( + request: Request, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.new_session( + auth.username, + request.query_params.get("platform_id") or "webchat", + ) + ) + + +@router.post("/chat/sessions/batch-delete") +async def batch_delete_chat_sessions( + payload: ChatSessionBatchDeleteRequest, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.batch_delete_sessions_from_dashboard_payload( + auth.username, + _model_dict(payload), + ) + ) + + +@router.get("/chat/sessions/{session_id}") +async def get_chat_session( + session_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run(lambda: service.get_session(auth.username, session_id)) + + +@router.patch("/chat/sessions/{session_id}") +async def update_chat_session( + session_id: str, + payload: ChatSessionPatchRequest, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.update_session_display_name( + auth.username, + session_id, + payload.display_name, + ) + ) + + +@router.delete("/chat/sessions/{session_id}") +async def delete_chat_session( + session_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run(lambda: service.delete_webchat_session(auth.username, session_id)) + + +@router.post("/chat/sessions/{session_id}/stop") +async def stop_chat_session( + session_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run(lambda: service.stop_session(auth.username, session_id)) + + +@router.patch("/chat/sessions/{session_id}/messages/{message_id}") +async def update_chat_message( + session_id: str, + message_id: str, + payload: ChatMessagePatchRequest, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.update_message( + auth.username, + { + "session_id": session_id, + "message_id": message_id, + **_model_dict(payload), + }, + ) + ) + + +@router.post("/chat/sessions/{session_id}/messages/{message_id}/regenerate") +async def regenerate_chat_message( + session_id: str, + message_id: str, + request: Request, + payload: ChatMessageRegenerateRequest | None = None, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + body = _model_dict(payload) if payload is not None else {} + try: + chat_payload = await service.prepare_regenerate_message_payload( + auth.username, + {"session_id": session_id, "message_id": message_id, **body}, + ) + except ChatServiceError as exc: + return JSONResponse(error(str(exc))) + return await _send_chat( + request=request, + username=auth.username, + service=service, + payload=chat_payload, + ) + + +@router.get("/chat/configs") +async def chat_configs( + request: Request, + _auth: AuthContext = Depends(require_chat_scope), +): + return ok(request.app.state.services.config_profiles.list_profiles()) + + +@router.post("/chat/threads") +async def create_chat_thread( + payload: ChatThreadCreateRequest, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.create_thread(auth.username, _model_dict(payload)) + ) + + +@router.get("/chat/threads/{thread_id}") +async def get_chat_thread( + thread_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run(lambda: service.get_thread(auth.username, thread_id)) + + +@router.delete("/chat/threads/{thread_id}") +async def delete_chat_thread( + thread_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + return await _run(lambda: service.delete_thread(auth.username, thread_id)) + + +@router.post("/chat/threads/{thread_id}/messages") +async def send_chat_thread_message( + thread_id: str, + request: Request, + payload: ChatThreadMessageRequest, + auth: AuthContext = Depends(require_chat_scope), + service: ChatService = Depends(get_service), +): + try: + chat_payload = await service.prepare_thread_chat_payload( + auth.username, + {"thread_id": thread_id, **_model_dict(payload)}, + ) + except ChatServiceError as exc: + return JSONResponse(error(str(exc))) + return await _send_chat( + request=request, + username=auth.username, + service=service, + payload=chat_payload, + ) + + +@legacy_router.post("/send") +async def dashboard_send_chat( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + return await _send_chat(request=request, username=username, service=service) + + +@legacy_router.get("/new_session") +async def dashboard_new_session( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.new_session( + username, + request.query_params.get("platform_id") or "webchat", + ) + ) + + +@legacy_router.get("/sessions") +async def dashboard_get_sessions( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.get_sessions(username, request.query_params.get("platform_id")) + ) + + +@legacy_router.get("/get_session") +async def dashboard_get_session( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.get_session_from_dashboard_query( + username, + request.query_params.get("session_id"), + ) + ) + + +@legacy_router.post("/stop") +async def dashboard_stop_session( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.stop_session_from_dashboard_payload(username, body) + ) + + +@legacy_router.get("/delete_session") +async def dashboard_delete_session( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.delete_webchat_session_from_dashboard_query( + username, + request.query_params.get("session_id"), + ) + ) + + +@legacy_router.post("/batch_delete_sessions") +async def dashboard_batch_delete_sessions( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + body = await _json_body(request) + return await _run( + lambda: service.batch_delete_sessions_from_dashboard_payload(username, body) + ) + + +@legacy_router.post("/update_session_display_name") +async def dashboard_update_session_display_name( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.update_session_display_name_from_dashboard_payload( + username, + body, + ) + ) + + +@legacy_router.post("/message/edit") +async def dashboard_update_message( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_message(username, body)) + + +@legacy_router.post("/message/regenerate") +async def dashboard_regenerate_message( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + try: + payload = ( + await service.prepare_regenerate_message_payload_from_dashboard_payload( + username, + await _json_or_empty(request), + ) + ) + except ChatServiceError as exc: + return JSONResponse(error(str(exc))) + return await _send_chat( + request=request, + username=username, + service=service, + payload=payload, + ) + + +@legacy_router.post("/thread/create") +async def dashboard_create_thread( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.create_thread(username, body)) + + +@legacy_router.get("/thread/get") +async def dashboard_get_thread( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + return await _run( + lambda: service.get_thread_from_dashboard_query( + username, + request.query_params.get("thread_id"), + ) + ) + + +@legacy_router.post("/thread/send") +async def dashboard_send_thread_message( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + try: + payload = await service.prepare_thread_chat_payload_from_dashboard_payload( + username, + await _json_or_empty(request), + ) + except ChatServiceError as exc: + return JSONResponse(error(str(exc))) + return await _send_chat( + request=request, + username=username, + service=service, + payload=payload, + ) + + +@legacy_router.post("/thread/delete") +async def dashboard_delete_thread( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.delete_thread_from_dashboard_payload(username, body) + ) + + +@legacy_router.get("/get_file") +async def dashboard_get_file( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + try: + file_path, mimetype = await service.resolve_webchat_file_from_dashboard_query( + request.query_params.get("filename") + ) + return _file_response(file_path, mimetype) + except ChatServiceError as exc: + return JSONResponse(error(str(exc))) + except (FileNotFoundError, OSError): + return JSONResponse(error("File access error")) + + +@legacy_router.get("/get_attachment") +async def dashboard_get_attachment( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + try: + ( + file_path, + mimetype, + ) = await service.resolve_attachment_file_from_dashboard_query( + request.query_params.get("attachment_id") + ) + return _file_response(file_path, mimetype) + except ChatServiceError as exc: + return JSONResponse(error(str(exc))) + except (FileNotFoundError, OSError): + return JSONResponse(error("File access error")) + + +@legacy_router.post("/post_file") +async def dashboard_post_file( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ChatService = Depends(get_service), +): + try: + upload = await single_upload(request) + if upload is None: + raise ChatServiceError("Missing key: file") + return ok(await service.save_uploaded_file(upload)) + except ChatServiceError as exc: + return error(str(exc)) diff --git a/astrbot/dashboard/api/chat_projects.py b/astrbot/dashboard/api/chat_projects.py new file mode 100644 index 0000000000..a8d4ba4485 --- /dev/null +++ b/astrbot/dashboard/api/chat_projects.py @@ -0,0 +1,215 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import ChatProjectRequest +from astrbot.dashboard.services.chatui_project_service import ( + ChatUIProjectService, + ChatUIProjectServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Chat Projects"]) +legacy_router = APIRouter( + prefix="/api/chatui_project", + tags=["Dashboard Chat Projects"], + include_in_schema=False, +) + + +def get_service(request: Request) -> ChatUIProjectService: + return request.app.state.services.chat_projects + + +async def require_chat_scope(request: Request) -> AuthContext: + return await require_scope(request, "chat") + + +async def _json_or_empty(request: Request) -> dict: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _model_dict(payload) -> dict: + return payload.model_dump(exclude_none=True) + + +async def _run(operation): + try: + result = await run_maybe_async(operation) + return ok(result) + except ChatUIProjectServiceError as exc: + return error(str(exc)) + + +@router.get("/chat/projects") +async def list_chat_projects( + auth: AuthContext = Depends(require_chat_scope), + service: ChatUIProjectService = Depends(get_service), +): + return await _run(lambda: service.list_projects(auth.username)) + + +@legacy_router.get("/list") +async def list_dashboard_chat_projects( + username: str = Depends(require_dashboard_user), + service: ChatUIProjectService = Depends(get_service), +): + return await _run(lambda: service.list_projects(username)) + + +@router.post("/chat/projects") +async def create_chat_project( + payload: ChatProjectRequest, + auth: AuthContext = Depends(require_chat_scope), + service: ChatUIProjectService = Depends(get_service), +): + return await _run( + lambda: service.create_project(auth.username, _model_dict(payload)) + ) + + +@legacy_router.post("/create") +async def create_dashboard_chat_project( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatUIProjectService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.create_project(username, body)) + + +@router.get("/chat/projects/{project_id}") +async def get_chat_project( + project_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatUIProjectService = Depends(get_service), +): + return await _run(lambda: service.get_project(auth.username, project_id)) + + +@legacy_router.get("/get") +async def get_dashboard_chat_project( + project_id: str | None = Query(default=None), + username: str = Depends(require_dashboard_user), + service: ChatUIProjectService = Depends(get_service), +): + return await _run(lambda: service.get_project_from_query(username, project_id)) + + +@router.patch("/chat/projects/{project_id}") +async def update_chat_project( + project_id: str, + payload: ChatProjectRequest, + auth: AuthContext = Depends(require_chat_scope), + service: ChatUIProjectService = Depends(get_service), +): + return await _run( + lambda: service.update_project( + auth.username, + {"project_id": project_id, **_model_dict(payload)}, + ) + ) + + +@legacy_router.post("/update") +async def update_dashboard_chat_project( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatUIProjectService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_project(username, body)) + + +@router.delete("/chat/projects/{project_id}") +async def delete_chat_project( + project_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatUIProjectService = Depends(get_service), +): + return await _run(lambda: service.delete_project(auth.username, project_id)) + + +@legacy_router.get("/delete") +async def delete_dashboard_chat_project( + project_id: str | None = Query(default=None), + username: str = Depends(require_dashboard_user), + service: ChatUIProjectService = Depends(get_service), +): + return await _run(lambda: service.delete_project_from_query(username, project_id)) + + +@router.get("/chat/projects/{project_id}/sessions") +async def list_chat_project_sessions( + project_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatUIProjectService = Depends(get_service), +): + return await _run(lambda: service.get_project_sessions(auth.username, project_id)) + + +@legacy_router.get("/get_sessions") +async def list_dashboard_chat_project_sessions( + project_id: str | None = Query(default=None), + username: str = Depends(require_dashboard_user), + service: ChatUIProjectService = Depends(get_service), +): + return await _run( + lambda: service.get_project_sessions_from_query(username, project_id) + ) + + +@router.post("/chat/projects/{project_id}/sessions/{session_id}") +async def add_chat_project_session( + project_id: str, + session_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatUIProjectService = Depends(get_service), +): + return await _run( + lambda: service.add_session_to_project( + auth.username, + {"project_id": project_id, "session_id": session_id}, + ) + ) + + +@legacy_router.post("/add_session") +async def add_dashboard_chat_project_session( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatUIProjectService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.add_session_to_project(username, body)) + + +@router.delete("/chat/projects/sessions/{session_id}") +async def remove_chat_project_session( + session_id: str, + auth: AuthContext = Depends(require_chat_scope), + service: ChatUIProjectService = Depends(get_service), +): + return await _run( + lambda: service.remove_session_from_project( + auth.username, + {"session_id": session_id}, + ) + ) + + +@legacy_router.post("/remove_session") +async def remove_dashboard_chat_project_session( + request: Request, + username: str = Depends(require_dashboard_user), + service: ChatUIProjectService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.remove_session_from_project(username, body)) diff --git a/astrbot/dashboard/api/config_profiles.py b/astrbot/dashboard/api/config_profiles.py new file mode 100644 index 0000000000..943cc7d7e8 --- /dev/null +++ b/astrbot/dashboard/api/config_profiles.py @@ -0,0 +1,468 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import ( + ConfigContentRequest, + ConfigProfileCreateRequest, + ConfigRoutesReplaceRequest, + ConfigRouteUpsertRequest, + RenameRequest, +) +from astrbot.dashboard.services.config_service import ( + ConfigDisplayService, + ConfigFileService, + ConfigProfileService, + ConfigRoutingService, +) + +from .auth import AuthContext, require_scope +from .multipart import multipart_parts + +router = APIRouter(tags=["Config Profiles"]) +legacy_router = APIRouter( + prefix="/api/config", + tags=["Dashboard Config"], + include_in_schema=False, +) + + +async def require_config_scope(request: Request) -> AuthContext: + return await require_scope(request, "config") + + +def get_service(request: Request) -> ConfigProfileService: + return request.app.state.services.config_profiles + + +def get_routing_service(request: Request) -> ConfigRoutingService: + return request.app.state.services.config_routes + + +def get_display_service(request: Request) -> ConfigDisplayService: + return request.app.state.services.config_display + + +def get_file_service(request: Request) -> ConfigFileService: + return request.app.state.services.config_files + + +async def _json_or_empty(request: Request) -> dict: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _alias_error(message: str): + return error(message) + + +def _model_dict(payload) -> dict[str, Any]: + return payload.model_dump(exclude_none=True) + + +@router.get("/config-profiles/schema") +async def get_config_profile_schema( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + return ok(service.get_profile_schema()) + + +@router.get("/config-profiles") +async def list_config_profiles( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + return ok(service.list_profiles()) + + +@router.post("/config-profiles") +async def create_config_profile( + payload: ConfigProfileCreateRequest, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + return ok(await service.create_profile(payload.name, payload.config), "创建成功") + + +@router.get("/config-profiles/{config_id}") +async def get_config_profile( + config_id: str, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + return ok(service.get_profile(config_id)) + + +@router.put("/config-profiles/{config_id}") +async def update_config_profile( + config_id: str, + payload: ConfigContentRequest, + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + message = await service.update_profile( + config_id, + _model_dict(payload), + two_factor_code=request.headers.get("X-2FA-Code"), + ) + return ok(message=message or "保存成功") + + +@router.patch("/config-profiles/{config_id}") +async def rename_config_profile( + config_id: str, + payload: RenameRequest, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + service.rename_profile(config_id, payload.name) + return ok(message="更新成功") + + +@router.delete("/config-profiles/{config_id}") +async def delete_config_profile( + config_id: str, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + service.delete_profile(config_id) + return ok(message="删除成功") + + +@router.get("/system-config/schema") +async def get_system_config_schema( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + return ok(service.get_system_schema()) + + +@router.get("/system-config") +async def get_system_config( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + return ok(service.get_system_config()) + + +@router.get("/system-config/runtime") +async def get_system_config_runtime( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigDisplayService = Depends(get_display_service), +): + return ok(await service.get_configs()) + + +@router.put("/system-config") +async def update_system_config( + payload: ConfigContentRequest, + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + message = await service.update_profile( + "default", + _model_dict(payload), + two_factor_code=request.headers.get("X-2FA-Code"), + ) + return ok(message=message or "保存成功") + + +@router.get("/config-routes") +async def list_config_routes( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigRoutingService = Depends(get_routing_service), +): + return ok(service.list_routes()) + + +@router.put("/config-routes") +async def replace_config_routes( + payload: ConfigRoutesReplaceRequest, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigRoutingService = Depends(get_routing_service), +): + await service.replace_route_mapping(payload.routing) + return ok(message="更新成功") + + +@router.put("/config-routes/{umo}") +async def upsert_config_route( + umo: str, + payload: ConfigRouteUpsertRequest, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigRoutingService = Depends(get_routing_service), +): + await service.set_route(umo, payload.config_id) + return ok(message="更新成功") + + +@router.delete("/config-routes/{umo}") +async def delete_config_route( + umo: str, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigRoutingService = Depends(get_routing_service), +): + await service.delete_route_by_umo(umo) + return ok(message="删除成功") + + +@legacy_router.get("/default") +async def get_dashboard_alias_default_config( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + return ok(service.get_profile_schema()) + + +@legacy_router.get("/abconfs") +async def list_dashboard_alias_config_profiles( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + return ok(service.list_profiles()) + + +@legacy_router.post("/abconf/new") +async def create_dashboard_alias_config_profile( + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + body = await _json_or_empty(request) + try: + return ok( + await service.create_profile( + body.get("name"), + body.get("config"), + ), + "创建成功", + ) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.get("/abconf") +async def get_dashboard_alias_config_profile( + id: str | None = Query(default=None), + system_config: str = Query(default="0"), + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + if system_config.lower() == "1": + return ok(service.get_system_schema()) + if not id: + return _alias_error("缺少配置文件 ID") + try: + return ok(service.get_profile(id)) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/abconf/delete") +async def delete_dashboard_alias_config_profile( + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + body = await _json_or_empty(request) + config_id = body.get("id") + if not config_id: + return _alias_error("缺少配置文件 ID") + try: + service.delete_profile(str(config_id)) + return ok(message="删除成功") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/abconf/update") +async def rename_dashboard_alias_config_profile( + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + body = await _json_or_empty(request) + config_id = body.get("id") + if not config_id: + return _alias_error("缺少配置文件 ID") + try: + service.rename_profile(str(config_id), body.get("name")) + return ok(message="更新成功") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/astrbot/update") +async def update_dashboard_alias_astrbot_config( + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigProfileService = Depends(get_service), +): + body = await _json_or_empty(request) + config = body.get("config") + config_id = body.get("conf_id") + if not isinstance(config, dict): + return _alias_error("Invalid config payload") + if not config_id: + return _alias_error("Config file None does not exist") + try: + message = await service.update_profile( + str(config_id), + config, + two_factor_code=request.headers.get("X-2FA-Code"), + ) + return ok(message=message or "保存成功~") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.get("/get") +async def get_dashboard_alias_configs( + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigDisplayService = Depends(get_display_service), +): + try: + return ok(await service.get_configs_from_dashboard_args(request.query_params)) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/plugin/update") +async def update_dashboard_alias_plugin_configs( + request: Request, + plugin_name: str = Query(default="unknown"), + _auth: AuthContext = Depends(require_config_scope), + service: ConfigFileService = Depends(get_file_service), +): + body = await _json_or_empty(request) + try: + message = await service.save_plugin_configs_from_dashboard_payload( + body, + plugin_name=plugin_name, + ) + return ok(message=message) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/file/upload") +async def upload_dashboard_alias_config_file( + request: Request, + scope: str | None = Query(default=None), + name: str | None = Query(default=None), + key: str | None = Query(default=None), + _auth: AuthContext = Depends(require_config_scope), + service: ConfigFileService = Depends(get_file_service), +): + _, files = await multipart_parts(request) + try: + return ok( + await service.upload_config_file( + scope=scope, + name=name, + key_path=key, + files=files, + ) + ) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/file/delete") +async def delete_dashboard_alias_config_file( + request: Request, + scope: str | None = Query(default=None), + name: str | None = Query(default=None), + _auth: AuthContext = Depends(require_config_scope), + service: ConfigFileService = Depends(get_file_service), +): + body = await _json_or_empty(request) + try: + message = service.delete_config_file_from_dashboard_payload( + scope=scope or "plugin", + name=name, + payload=body, + ) + return ok(message=message) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.get("/file/get") +async def list_dashboard_alias_config_files( + scope: str | None = Query(default=None), + name: str | None = Query(default=None), + key: str | None = Query(default=None), + _auth: AuthContext = Depends(require_config_scope), + service: ConfigFileService = Depends(get_file_service), +): + try: + return ok( + service.list_config_files( + scope=scope, + name=name, + key_path=key, + ) + ) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.get("/umo_abconf_routes") +async def get_dashboard_alias_config_routes( + _auth: AuthContext = Depends(require_config_scope), + service: ConfigRoutingService = Depends(get_routing_service), +): + return ok(service.list_routes()) + + +@legacy_router.post("/umo_abconf_route/update_all") +async def update_dashboard_alias_config_routes( + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigRoutingService = Depends(get_routing_service), +): + body = await _json_or_empty(request) + try: + await service.replace_routes(body) + except ValueError: + return _alias_error("缺少或错误的路由表数据") + return ok(message="更新成功") + + +@legacy_router.post("/umo_abconf_route/update") +async def upsert_dashboard_alias_config_route( + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigRoutingService = Depends(get_routing_service), +): + body = await _json_or_empty(request) + try: + await service.upsert_route(body) + except ValueError: + return _alias_error("缺少 UMO 或配置文件 ID") + return ok(message="更新成功") + + +@legacy_router.post("/umo_abconf_route/delete") +async def delete_dashboard_alias_config_route( + request: Request, + _auth: AuthContext = Depends(require_config_scope), + service: ConfigRoutingService = Depends(get_routing_service), +): + body = await _json_or_empty(request) + try: + await service.delete_route(body) + except ValueError: + return _alias_error("缺少 UMO") + return ok(message="删除成功") diff --git a/astrbot/dashboard/api/conversations.py b/astrbot/dashboard/api/conversations.py new file mode 100644 index 0000000000..f2355d835c --- /dev/null +++ b/astrbot/dashboard/api/conversations.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Query, Request +from fastapi.responses import StreamingResponse + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import ( + ConversationBatchDeleteRequest, + ConversationExportRequest, + ConversationMessagesReplaceRequest, + ConversationPatchRequest, +) +from astrbot.dashboard.services.conversation_service import ( + ConversationExport, + ConversationService, + ConversationServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Conversations"]) +legacy_router = APIRouter( + prefix="/api/conversation", + tags=["Dashboard Conversations"], + include_in_schema=False, +) + + +def get_service(request: Request) -> ConversationService: + return request.app.state.services.conversations + + +async def require_data_scope(request: Request) -> AuthContext: + return await require_scope(request, "data") + + +async def _json_or_empty(request: Request) -> dict[str, Any]: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _model_dict(payload) -> dict[str, Any]: + return payload.model_dump(exclude_none=True) + + +def _raise_conversation_error(exc: ConversationServiceError) -> None: + raise ApiError(str(exc)) from exc + + +async def _run(operation): + try: + result = await run_maybe_async(operation) + return ok(result) + except ConversationServiceError as exc: + _raise_conversation_error(exc) + + +def _export_response(export: ConversationExport) -> StreamingResponse: + export.file_obj.seek(0) + return StreamingResponse( + export.file_obj, + media_type=export.mimetype, + headers={"Content-Disposition": f'attachment; filename="{export.filename}"'}, + ) + + +async def _export_conversations( + payload: dict[str, Any], + service: ConversationService, +): + try: + return _export_response(await service.export_conversations(payload)) + except ConversationServiceError as exc: + _raise_conversation_error(exc) + + +async def _list_conversations( + service: ConversationService, + *, + page: int, + page_size: int, + platforms: str, + message_types: str, + search: str, + exclude_ids: str, + exclude_platforms: str, +): + return await _run( + lambda: service.list_conversations( + page=page, + page_size=page_size, + platforms=platforms, + message_types=message_types, + search_query=search, + exclude_ids=exclude_ids, + exclude_platforms=exclude_platforms, + ) + ) + + +@router.get("/conversations") +async def list_conversations( + page: int = Query(default=1), + page_size: int = Query(default=20), + platforms: str = Query(default=""), + message_types: str = Query(default=""), + search: str = Query(default=""), + exclude_ids: str = Query(default=""), + exclude_platforms: str = Query(default=""), + _auth: AuthContext = Depends(require_data_scope), + service: ConversationService = Depends(get_service), +): + return await _list_conversations( + service, + page=page, + page_size=page_size, + platforms=platforms, + message_types=message_types, + search=search, + exclude_ids=exclude_ids, + exclude_platforms=exclude_platforms, + ) + + +@router.post("/conversations/export") +async def export_conversations( + payload: ConversationExportRequest, + _auth: AuthContext = Depends(require_data_scope), + service: ConversationService = Depends(get_service), +): + return await _export_conversations(_model_dict(payload), service) + + +@router.post("/conversations/batch-delete") +async def batch_delete_conversations( + payload: ConversationBatchDeleteRequest, + _auth: AuthContext = Depends(require_data_scope), + service: ConversationService = Depends(get_service), +): + return await _run(lambda: service.delete_conversation(_model_dict(payload))) + + +@router.put("/conversations/{conversation_id:path}/messages") +async def replace_conversation_messages( + conversation_id: str, + payload: ConversationMessagesReplaceRequest, + user_id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_data_scope), + service: ConversationService = Depends(get_service), +): + body = _model_dict(payload) + body_user_id = body.pop("user_id", None) or user_id + if "messages" in body and "history" not in body: + body["history"] = body.pop("messages") + return await _run( + lambda: service.update_history( + {"user_id": body_user_id, "cid": conversation_id, **body} + ) + ) + + +@router.get("/conversations/{conversation_id:path}") +async def get_conversation( + conversation_id: str, + user_id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_data_scope), + service: ConversationService = Depends(get_service), +): + return await _run( + lambda: service.get_conversation_detail( + {"user_id": user_id, "cid": conversation_id} + ) + ) + + +@router.patch("/conversations/{conversation_id:path}") +async def update_conversation( + conversation_id: str, + payload: ConversationPatchRequest, + user_id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_data_scope), + service: ConversationService = Depends(get_service), +): + body = _model_dict(payload) + body_user_id = body.pop("user_id", None) or user_id + return await _run( + lambda: service.update_conversation( + {"user_id": body_user_id, "cid": conversation_id, **body} + ) + ) + + +@router.delete("/conversations/{conversation_id:path}") +async def delete_conversation( + conversation_id: str, + user_id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_data_scope), + service: ConversationService = Depends(get_service), +): + return await _run( + lambda: service.delete_conversation( + {"user_id": user_id, "cid": conversation_id} + ) + ) + + +@legacy_router.get("/list") +async def list_dashboard_conversations( + page: int = Query(default=1), + page_size: int = Query(default=20), + platforms: str = Query(default=""), + message_types: str = Query(default=""), + search: str = Query(default=""), + exclude_ids: str = Query(default=""), + exclude_platforms: str = Query(default=""), + _username: str = Depends(require_dashboard_user), + service: ConversationService = Depends(get_service), +): + return await _list_conversations( + service, + page=page, + page_size=page_size, + platforms=platforms, + message_types=message_types, + search=search, + exclude_ids=exclude_ids, + exclude_platforms=exclude_platforms, + ) + + +@legacy_router.post("/detail") +async def get_dashboard_conversation_detail( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ConversationService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.get_conversation_detail(body)) + + +@legacy_router.post("/update") +async def update_dashboard_conversation( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ConversationService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_conversation(body)) + + +@legacy_router.post("/delete") +async def delete_dashboard_conversation( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ConversationService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.delete_conversation(body)) + + +@legacy_router.post("/update_history") +async def update_dashboard_conversation_history( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ConversationService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_history(body)) + + +@legacy_router.post("/export") +async def export_dashboard_conversations( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ConversationService = Depends(get_service), +): + return await _export_conversations(await _json_or_empty(request), service) diff --git a/astrbot/dashboard/api/cron.py b/astrbot/dashboard/api/cron.py new file mode 100644 index 0000000000..1b1f9cc432 --- /dev/null +++ b/astrbot/dashboard/api/cron.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import CronJobRequest +from astrbot.dashboard.services.cron_service import CronService, CronServiceError + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Cron"]) +legacy_router = APIRouter( + prefix="/api/cron", + tags=["Dashboard Cron"], + include_in_schema=False, +) + + +async def require_system_scope(request: Request) -> AuthContext: + return await require_scope(request, "system") + + +def get_service(request: Request) -> CronService: + return request.app.state.services.cron + + +def _payload_dict(payload: CronJobRequest) -> dict: + return payload.model_dump(exclude_none=True) + + +def _raise_cron_error(exc: CronServiceError) -> None: + raise ApiError(str(exc)) from exc + + +async def _list_jobs(job_type: str | None, service: CronService): + try: + return ok(await service.list_jobs(job_type)) + except CronServiceError as exc: + _raise_cron_error(exc) + + +async def _create_job(payload: CronJobRequest, service: CronService): + try: + return ok(await service.create_job(_payload_dict(payload))) + except CronServiceError as exc: + _raise_cron_error(exc) + + +async def _update_job(job_id: str, payload: CronJobRequest, service: CronService): + try: + return ok(await service.update_job(job_id, _payload_dict(payload))) + except CronServiceError as exc: + _raise_cron_error(exc) + + +async def _delete_job(job_id: str, service: CronService): + try: + await service.delete_job(job_id) + return ok(message="deleted") + except CronServiceError as exc: + _raise_cron_error(exc) + + +async def _run_job(job_id: str, service: CronService): + try: + await service.run_job_now(job_id) + return ok(message="started") + except CronServiceError as exc: + _raise_cron_error(exc) + + +@router.get("/cron/jobs") +async def list_cron_jobs( + job_type: str | None = Query(default=None, alias="type"), + _auth: AuthContext = Depends(require_system_scope), + service: CronService = Depends(get_service), +): + return await _list_jobs(job_type, service) + + +@router.post("/cron/jobs") +async def create_cron_job( + payload: CronJobRequest, + _auth: AuthContext = Depends(require_system_scope), + service: CronService = Depends(get_service), +): + return await _create_job(payload, service) + + +@router.patch("/cron/jobs/{job_id}") +async def update_cron_job( + job_id: str, + payload: CronJobRequest, + _auth: AuthContext = Depends(require_system_scope), + service: CronService = Depends(get_service), +): + return await _update_job(job_id, payload, service) + + +@router.delete("/cron/jobs/{job_id}") +async def delete_cron_job( + job_id: str, + _auth: AuthContext = Depends(require_system_scope), + service: CronService = Depends(get_service), +): + return await _delete_job(job_id, service) + + +@router.post("/cron/jobs/{job_id}/run") +async def run_cron_job( + job_id: str, + _auth: AuthContext = Depends(require_system_scope), + service: CronService = Depends(get_service), +): + return await _run_job(job_id, service) + + +@legacy_router.get("/jobs") +async def list_dashboard_cron_jobs( + job_type: str | None = Query(default=None, alias="type"), + _username: str = Depends(require_dashboard_user), + service: CronService = Depends(get_service), +): + return await _list_jobs(job_type, service) + + +@legacy_router.post("/jobs") +async def create_dashboard_cron_job( + payload: CronJobRequest, + _username: str = Depends(require_dashboard_user), + service: CronService = Depends(get_service), +): + return await _create_job(payload, service) + + +@legacy_router.patch("/jobs/{job_id}") +async def update_dashboard_cron_job( + job_id: str, + payload: CronJobRequest, + _username: str = Depends(require_dashboard_user), + service: CronService = Depends(get_service), +): + return await _update_job(job_id, payload, service) + + +@legacy_router.delete("/jobs/{job_id}") +async def delete_dashboard_cron_job( + job_id: str, + _username: str = Depends(require_dashboard_user), + service: CronService = Depends(get_service), +): + return await _delete_job(job_id, service) + + +@legacy_router.post("/jobs/{job_id}/run") +async def run_dashboard_cron_job( + job_id: str, + _username: str = Depends(require_dashboard_user), + service: CronService = Depends(get_service), +): + return await _run_job(job_id, service) diff --git a/astrbot/dashboard/api/extensions.py b/astrbot/dashboard/api/extensions.py new file mode 100644 index 0000000000..2c88a66ccc --- /dev/null +++ b/astrbot/dashboard/api/extensions.py @@ -0,0 +1,180 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Request + +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import ( + CommandPermissionRequest, + CommandRenameRequest, + CommandToggleRequest, + CommandUpdateRequest, +) +from astrbot.dashboard.services.command_service import ( + CommandService, + CommandServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Extension Components"]) +legacy_router = APIRouter( + prefix="/api", + tags=["Dashboard Extension Components"], + include_in_schema=False, +) + + +def get_command_service(request: Request) -> CommandService: + return request.app.state.services.commands + + +async def require_tool_scope(request: Request) -> AuthContext: + return await require_scope(request, "tool") + + +def _raise_command_error(exc: CommandServiceError) -> None: + raise ApiError(str(exc)) from exc + + +async def _list_commands(config_id: str | None, service: CommandService): + try: + return ok(await service.list_commands(config_id or "")) + except CommandServiceError as exc: + _raise_command_error(exc) + + +async def _list_command_conflicts(service: CommandService): + try: + return ok(await service.list_conflicts()) + except CommandServiceError as exc: + _raise_command_error(exc) + + +async def _toggle_command(payload: CommandToggleRequest, service: CommandService): + try: + return ok( + await service.toggle_command(payload.handler_full_name, payload.enabled) + ) + except CommandServiceError as exc: + _raise_command_error(exc) + + +async def _rename_command(payload: CommandRenameRequest, service: CommandService): + try: + return ok( + await service.rename_command( + payload.handler_full_name, + payload.new_name, + aliases=payload.aliases, + ) + ) + except CommandServiceError as exc: + _raise_command_error(exc) + + +async def _update_command_permission( + payload: CommandPermissionRequest, + service: CommandService, +): + try: + return ok( + await service.update_permission( + payload.handler_full_name, payload.permission + ) + ) + except CommandServiceError as exc: + _raise_command_error(exc) + + +@router.get("/commands") +async def list_commands( + config_id: str | None = None, + _auth: AuthContext = Depends(require_tool_scope), + service: CommandService = Depends(get_command_service), +): + return await _list_commands(config_id, service) + + +@router.get("/commands/conflicts") +async def list_command_conflicts( + _auth: AuthContext = Depends(require_tool_scope), + service: CommandService = Depends(get_command_service), +): + return await _list_command_conflicts(service) + + +@router.patch("/commands/{command_id:path}") +async def update_command( + command_id: str, + payload: CommandUpdateRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: CommandService = Depends(get_command_service), +): + if payload.enabled is not None: + return await _toggle_command( + CommandToggleRequest( + handler_full_name=command_id, + enabled=payload.enabled, + ), + service, + ) + if payload.alias is not None: + return await _rename_command( + CommandRenameRequest( + handler_full_name=command_id, + new_name=payload.alias, + aliases=payload.aliases, + ), + service, + ) + return await _update_command_permission( + CommandPermissionRequest( + handler_full_name=command_id, + permission=payload.permission_group or "", + ), + service, + ) + + +@legacy_router.get("/commands") +async def list_dashboard_commands( + config_id: str | None = None, + _username: str = Depends(require_dashboard_user), + service: CommandService = Depends(get_command_service), +): + return await _list_commands(config_id, service) + + +@legacy_router.get("/commands/conflicts") +async def list_dashboard_command_conflicts( + _username: str = Depends(require_dashboard_user), + service: CommandService = Depends(get_command_service), +): + return await _list_command_conflicts(service) + + +@legacy_router.post("/commands/toggle") +async def toggle_dashboard_command( + payload: CommandToggleRequest, + _username: str = Depends(require_dashboard_user), + service: CommandService = Depends(get_command_service), +): + return await _toggle_command(payload, service) + + +@legacy_router.post("/commands/rename") +async def rename_dashboard_command( + payload: CommandRenameRequest, + _username: str = Depends(require_dashboard_user), + service: CommandService = Depends(get_command_service), +): + return await _rename_command(payload, service) + + +@legacy_router.post("/commands/permission") +async def update_dashboard_command_permission( + payload: CommandPermissionRequest, + _username: str = Depends(require_dashboard_user), + service: CommandService = Depends(get_command_service), +): + return await _update_command_permission(payload, service) diff --git a/astrbot/dashboard/api/files.py b/astrbot/dashboard/api/files.py new file mode 100644 index 0000000000..86ab241c6d --- /dev/null +++ b/astrbot/dashboard/api/files.py @@ -0,0 +1,130 @@ +from __future__ import annotations + +from pathlib import Path + +from fastapi import APIRouter, Depends, File, HTTPException, Query, Request, UploadFile +from fastapi.responses import FileResponse + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.services.chat_service import ChatService, ChatServiceError +from astrbot.dashboard.services.file_service import FileService, FileServiceError + +from .auth import AuthContext, require_scope + +router = APIRouter(tags=["Files"]) +legacy_router = APIRouter(prefix="/api", include_in_schema=False) + + +def get_service(request: Request) -> FileService: + return request.app.state.services.files + + +def get_chat_service(request: Request) -> ChatService: + return request.app.state.services.chat + + +async def require_file_scope(request: Request) -> AuthContext: + return await require_scope(request, "file") + + +class UploadFileAdapter: + def __init__(self, file: UploadFile) -> None: + self.file = file + self.filename = file.filename + self.content_type = file.content_type + + async def save(self, target_path: str) -> None: + Path(target_path).write_bytes(await self.file.read()) + + +async def _serve_token_file(file_token: str, service: FileService): + try: + return FileResponse(await service.resolve_token_file(file_token)) + except FileServiceError as exc: + raise HTTPException(status_code=404) from exc + + +def _file_response(file_path: str, mimetype: str | None = None) -> FileResponse: + if mimetype: + return FileResponse(file_path, media_type=mimetype) + return FileResponse(file_path) + + +async def _run_file(operation, *, error_message: str = "File access error"): + try: + result = await run_maybe_async(operation) + return result + except ChatServiceError as exc: + return error(str(exc)) + except (FileNotFoundError, OSError): + return error(error_message) + + +async def _upload_file(file: UploadFile, service: ChatService): + result = await _run_file( + lambda: service.save_uploaded_file(UploadFileAdapter(file)) + ) + if isinstance(result, dict) and result.get("status") == "error": + return result + return ok(result) + + +@router.get("/files/tokens/{file_token}") +async def get_token_file( + file_token: str, + service: FileService = Depends(get_service), +): + return await _serve_token_file(file_token, service) + + +@router.post("/files") +async def upload_file( + file: UploadFile = File(...), + _auth: AuthContext = Depends(require_file_scope), + service: ChatService = Depends(get_chat_service), +): + return await _upload_file(file, service) + + +@router.get("/files/content") +async def get_file_by_name( + filename: str | None = Query(default=None), + _auth: AuthContext = Depends(require_file_scope), + service: ChatService = Depends(get_chat_service), +): + result = await _run_file(lambda: service.resolve_webchat_file(filename)) + if isinstance(result, dict) and result.get("status") == "error": + return result + file_path, mimetype = result + return _file_response(file_path, mimetype) + + +@router.get("/files/{attachment_id}") +@router.get("/files/{attachment_id}/content") +async def get_file( + attachment_id: str, + _auth: AuthContext = Depends(require_file_scope), + service: ChatService = Depends(get_chat_service), +): + result = await _run_file(lambda: service.resolve_attachment_file(attachment_id)) + if isinstance(result, dict) and result.get("status") == "error": + return result + file_path, mimetype = result + return _file_response(file_path, mimetype) + + +@router.delete("/files/{attachment_id}") +async def delete_file( + attachment_id: str, + _auth: AuthContext = Depends(require_file_scope), +): + return ok({"attachment_id": attachment_id}) + + +@legacy_router.get("/file/{file_token}") +async def get_dashboard_token_file( + file_token: str, + service: FileService = Depends(get_service), +): + return await _serve_token_file(file_token, service) diff --git a/astrbot/dashboard/api/knowledge_bases.py b/astrbot/dashboard/api/knowledge_bases.py new file mode 100644 index 0000000000..c6f62235dd --- /dev/null +++ b/astrbot/dashboard/api/knowledge_bases.py @@ -0,0 +1,505 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from fastapi import APIRouter, Depends, Request + +from astrbot.core import logger +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import ( + KnowledgeBaseImportRequest, + KnowledgeBaseRequest, + KnowledgeBaseRetrieveRequest, + KnowledgeBaseUrlImportRequest, +) +from astrbot.dashboard.services.knowledge_base_service import ( + KnowledgeBaseService, + KnowledgeBaseServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope +from .multipart import multipart_parts + +router = APIRouter(tags=["Knowledge Bases"]) +legacy_router = APIRouter( + prefix="/api/kb", + tags=["Dashboard Knowledge Bases"], + include_in_schema=False, +) + + +def get_service(request: Request) -> KnowledgeBaseService: + return request.app.state.services.knowledge_bases + + +async def require_kb_scope(request: Request) -> AuthContext: + return await require_scope(request, "kb") + + +async def _json_or_empty(request: Request) -> dict[str, Any]: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _to_int(value: Any, default: int) -> int: + try: + return int(value) + except (TypeError, ValueError): + return default + + +def _model_dict(payload) -> dict[str, Any]: + if payload is None: + return {} + if hasattr(payload, "model_dump"): + return payload.model_dump(exclude_none=True) + return payload if isinstance(payload, dict) else {} + + +async def _run(operation, *, prefix: str): + try: + result = await run_maybe_async(operation) + if isinstance(result, tuple): + data, message = result + return ok(data, message) + return ok(result) + except (KnowledgeBaseServiceError, ValueError) as exc: + return error(str(exc)) + except Exception as exc: + logger.error("%s: %s", prefix, exc, exc_info=True) + return error(f"{prefix}: {exc!s}") + + +async def _run_json( + request: Request, + operation: Callable[[dict[str, Any]], Any], + *, + prefix: str, +): + body = await _json_or_empty(request) + return await _run(lambda: operation(body), prefix=prefix) + + +@router.get("/knowledge-bases") +async def list_knowledge_bases( + request: Request, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.list_kbs( + page=_to_int(request.query_params.get("page"), 1), + page_size=_to_int(request.query_params.get("page_size"), 20), + ), + prefix="获取知识库列表失败", + ) + + +@router.post("/knowledge-bases") +async def create_knowledge_base( + payload: KnowledgeBaseRequest, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.create_kb(_model_dict(payload)), + prefix="创建知识库失败", + ) + + +@router.get("/knowledge-bases/tasks/{task_id}") +async def get_knowledge_base_task( + task_id: str, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.get_upload_progress(task_id), + prefix="获取上传进度失败", + ) + + +@router.get("/knowledge-bases/{kb_id}") +async def get_knowledge_base( + kb_id: str, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run(lambda: service.get_kb(kb_id), prefix="获取知识库详情失败") + + +@router.put("/knowledge-bases/{kb_id}") +async def update_knowledge_base( + kb_id: str, + payload: KnowledgeBaseRequest, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + body = _model_dict(payload) + return await _run( + lambda: service.update_kb({"kb_id": kb_id, **body}), + prefix="更新知识库失败", + ) + + +@router.delete("/knowledge-bases/{kb_id}") +async def delete_knowledge_base( + kb_id: str, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.delete_kb({"kb_id": kb_id}), prefix="删除知识库失败" + ) + + +@router.get("/knowledge-bases/{kb_id}/stats") +async def get_knowledge_base_stats( + kb_id: str, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.get_kb_stats(kb_id), + prefix="获取知识库统计失败", + ) + + +@router.get("/knowledge-bases/{kb_id}/documents") +async def list_knowledge_base_documents( + kb_id: str, + request: Request, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.list_documents( + kb_id=kb_id, + page=_to_int(request.query_params.get("page"), 1), + page_size=_to_int(request.query_params.get("page_size"), 100), + ), + prefix="获取文档列表失败", + ) + + +@router.post("/knowledge-bases/{kb_id}/documents") +async def upload_knowledge_base_document( + kb_id: str, + request: Request, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + async def _operation(): + form_data, files = await multipart_parts(request, extra_form={"kb_id": kb_id}) + return await service.upload_document( + content_type=request.headers.get("content-type"), + form_data=form_data, + files=files, + ) + + return await _run(_operation, prefix="上传文档失败") + + +@router.post("/knowledge-bases/{kb_id}/documents/import") +async def import_knowledge_base_documents( + kb_id: str, + payload: KnowledgeBaseImportRequest, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + body = _model_dict(payload) + return await _run( + lambda: service.import_documents({"kb_id": kb_id, **body}), + prefix="导入文档失败", + ) + + +@router.post("/knowledge-bases/{kb_id}/documents/import-url") +async def import_knowledge_base_document_url( + kb_id: str, + payload: KnowledgeBaseUrlImportRequest, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + body = _model_dict(payload) + return await _run( + lambda: service.upload_document_from_url({"kb_id": kb_id, **body}), + prefix="从URL上传文档失败", + ) + + +@router.get("/knowledge-bases/{kb_id}/documents/{document_id}") +async def get_knowledge_base_document( + kb_id: str, + document_id: str, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.get_document(kb_id=kb_id, doc_id=document_id), + prefix="获取文档详情失败", + ) + + +@router.delete("/knowledge-bases/{kb_id}/documents/{document_id}") +async def delete_knowledge_base_document( + kb_id: str, + document_id: str, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.delete_document({"kb_id": kb_id, "doc_id": document_id}), + prefix="删除文档失败", + ) + + +@router.get("/knowledge-bases/{kb_id}/chunks") +async def list_knowledge_base_chunks( + kb_id: str, + request: Request, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + document_id = request.query_params.get("document_id") or request.query_params.get( + "doc_id" + ) + return await _run( + lambda: service.list_chunks( + kb_id=kb_id, + doc_id=document_id, + page=_to_int(request.query_params.get("page"), 1), + page_size=_to_int(request.query_params.get("page_size"), 100), + ), + prefix="获取块列表失败", + ) + + +@router.delete("/knowledge-bases/{kb_id}/chunks/{chunk_id}") +async def delete_knowledge_base_chunk( + kb_id: str, + chunk_id: str, + request: Request, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + document_id = request.query_params.get("document_id") or request.query_params.get( + "doc_id" + ) + return await _run( + lambda: service.delete_chunk( + {"kb_id": kb_id, "chunk_id": chunk_id, "doc_id": document_id} + ), + prefix="删除文本块失败", + ) + + +@router.post("/knowledge-bases/{kb_id}/retrieve") +async def retrieve_knowledge_base( + kb_id: str, + payload: KnowledgeBaseRetrieveRequest, + _auth: AuthContext = Depends(require_kb_scope), + service: KnowledgeBaseService = Depends(get_service), +): + body = _model_dict(payload) + return await _run( + lambda: service.retrieve({"kb_id": kb_id, **body}), + prefix="检索失败", + ) + + +@legacy_router.get("/list") +async def dashboard_list_kbs( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.list_kbs( + page=_to_int(request.query_params.get("page"), 1), + page_size=_to_int(request.query_params.get("page_size"), 20), + ), + prefix="获取知识库列表失败", + ) + + +@legacy_router.post("/create") +async def dashboard_create_kb( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run_json(request, service.create_kb, prefix="创建知识库失败") + + +@legacy_router.get("/get") +async def dashboard_get_kb( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.get_kb(request.query_params.get("kb_id")), + prefix="获取知识库详情失败", + ) + + +@legacy_router.post("/update") +async def dashboard_update_kb( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run_json(request, service.update_kb, prefix="更新知识库失败") + + +@legacy_router.post("/delete") +async def dashboard_delete_kb( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run_json(request, service.delete_kb, prefix="删除知识库失败") + + +@legacy_router.get("/stats") +async def dashboard_get_kb_stats( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.get_kb_stats(request.query_params.get("kb_id")), + prefix="获取知识库统计失败", + ) + + +@legacy_router.get("/document/list") +async def dashboard_list_documents( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.list_documents( + kb_id=request.query_params.get("kb_id"), + page=_to_int(request.query_params.get("page"), 1), + page_size=_to_int(request.query_params.get("page_size"), 100), + ), + prefix="获取文档列表失败", + ) + + +@legacy_router.post("/document/upload") +async def dashboard_upload_document( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + async def _operation(): + form_data, files = await multipart_parts(request) + return await service.upload_document( + content_type=request.headers.get("content-type"), + form_data=form_data, + files=files, + ) + + return await _run(_operation, prefix="上传文档失败") + + +@legacy_router.post("/document/import") +async def dashboard_import_documents( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run_json(request, service.import_documents, prefix="导入文档失败") + + +@legacy_router.post("/document/upload/url") +async def dashboard_upload_document_from_url( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run_json( + request, + service.upload_document_from_url, + prefix="从URL上传文档失败", + ) + + +@legacy_router.get("/document/upload/progress") +async def dashboard_get_upload_progress( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.get_upload_progress(request.query_params.get("task_id")), + prefix="获取上传进度失败", + ) + + +@legacy_router.get("/document/get") +async def dashboard_get_document( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.get_document( + kb_id=request.query_params.get("kb_id"), + doc_id=request.query_params.get("doc_id"), + ), + prefix="获取文档详情失败", + ) + + +@legacy_router.post("/document/delete") +async def dashboard_delete_document( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run_json(request, service.delete_document, prefix="删除文档失败") + + +@legacy_router.get("/chunk/list") +async def dashboard_list_chunks( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run( + lambda: service.list_chunks( + kb_id=request.query_params.get("kb_id"), + doc_id=request.query_params.get("doc_id"), + page=_to_int(request.query_params.get("page"), 1), + page_size=_to_int(request.query_params.get("page_size"), 100), + ), + prefix="获取块列表失败", + ) + + +@legacy_router.post("/chunk/delete") +async def dashboard_delete_chunk( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run_json(request, service.delete_chunk, prefix="删除文本块失败") + + +@legacy_router.post("/retrieve") +async def dashboard_retrieve( + request: Request, + _username: str = Depends(require_dashboard_user), + service: KnowledgeBaseService = Depends(get_service), +): + return await _run_json(request, service.retrieve, prefix="检索失败") diff --git a/astrbot/dashboard/api/live_chat.py b/astrbot/dashboard/api/live_chat.py new file mode 100644 index 0000000000..470e6abf56 --- /dev/null +++ b/astrbot/dashboard/api/live_chat.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +from fastapi import APIRouter, WebSocket + +from astrbot.dashboard.services.live_chat_service import LiveChatService + +router = APIRouter(tags=["Live Chat"]) +legacy_router = APIRouter( + prefix="/api", + tags=["Dashboard Live Chat"], + include_in_schema=False, +) + + +def get_service(websocket: WebSocket) -> LiveChatService: + return websocket.app.state.services.live_chat + + +async def _run_live_chat_ws( + websocket: WebSocket, + *, + force_ct: str | None, +) -> None: + await websocket.accept() + service = get_service(websocket) + await service.run_websocket_session( + token=websocket.query_params.get("token"), + force_ct=force_ct, + receive_json=websocket.receive_json, + send_json=websocket.send_json, + close=websocket.close, + ) + + +@router.websocket("/live-chat/ws") +async def live_chat_ws(websocket: WebSocket) -> None: + await _run_live_chat_ws(websocket, force_ct="live") + + +@router.websocket("/unified-chat/ws") +async def unified_chat_ws(websocket: WebSocket) -> None: + await _run_live_chat_ws(websocket, force_ct=None) + + +@legacy_router.websocket("/live_chat/ws") +async def dashboard_live_chat_ws(websocket: WebSocket) -> None: + await _run_live_chat_ws(websocket, force_ct="live") + + +@legacy_router.websocket("/unified_chat/ws") +async def dashboard_unified_chat_ws(websocket: WebSocket) -> None: + await _run_live_chat_ws(websocket, force_ct=None) diff --git a/astrbot/dashboard/api/logs.py b/astrbot/dashboard/api/logs.py new file mode 100644 index 0000000000..de5e6b5729 --- /dev/null +++ b/astrbot/dashboard/api/logs.py @@ -0,0 +1,131 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Header, Request +from fastapi.responses import StreamingResponse + +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import TraceSettingsRequest +from astrbot.dashboard.services.log_service import LogService, LogServiceError + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Logs"]) +legacy_router = APIRouter( + prefix="/api", + tags=["Dashboard Logs"], + include_in_schema=False, +) + + +async def require_system_scope(request: Request) -> AuthContext: + return await require_scope(request, "system") + + +def get_service(request: Request) -> LogService: + return request.app.state.services.logs + + +def _raise_log_error(exc: LogServiceError) -> None: + raise ApiError(str(exc)) from exc + + +def _log_stream_response(last_event_id: str | None, service: LogService): + return StreamingResponse( + service.stream_log_events(last_event_id), + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Connection": "keep-alive", + "Transfer-Encoding": "chunked", + }, + ) + + +def _get_log_history(service: LogService): + try: + return ok(service.get_log_history()) + except LogServiceError as exc: + _raise_log_error(exc) + + +def _get_trace_settings(service: LogService): + try: + return ok(service.get_trace_settings()) + except LogServiceError as exc: + _raise_log_error(exc) + + +def _update_trace_settings(payload: TraceSettingsRequest, service: LogService): + try: + message = service.update_trace_settings(payload.model_dump(exclude_none=True)) + return ok(message=message) + except LogServiceError as exc: + _raise_log_error(exc) + + +@router.get("/logs/history") +async def get_log_history( + _auth: AuthContext = Depends(require_system_scope), + service: LogService = Depends(get_service), +): + return _get_log_history(service) + + +@router.get("/logs/live") +async def live_logs( + last_event_id: str | None = Header(default=None, alias="Last-Event-ID"), + _auth: AuthContext = Depends(require_system_scope), + service: LogService = Depends(get_service), +): + return _log_stream_response(last_event_id, service) + + +@router.get("/trace/settings") +async def get_trace_settings( + _auth: AuthContext = Depends(require_system_scope), + service: LogService = Depends(get_service), +): + return _get_trace_settings(service) + + +@router.put("/trace/settings") +async def update_trace_settings( + payload: TraceSettingsRequest, + _auth: AuthContext = Depends(require_system_scope), + service: LogService = Depends(get_service), +): + return _update_trace_settings(payload, service) + + +@legacy_router.get("/log-history") +async def get_dashboard_log_history( + _username: str = Depends(require_dashboard_user), + service: LogService = Depends(get_service), +): + return _get_log_history(service) + + +@legacy_router.get("/live-log") +async def get_dashboard_live_logs( + last_event_id: str | None = Header(default=None, alias="Last-Event-ID"), + _username: str = Depends(require_dashboard_user), + service: LogService = Depends(get_service), +): + return _log_stream_response(last_event_id, service) + + +@legacy_router.get("/trace/settings") +async def get_dashboard_trace_settings( + _username: str = Depends(require_dashboard_user), + service: LogService = Depends(get_service), +): + return _get_trace_settings(service) + + +@legacy_router.post("/trace/settings") +async def update_dashboard_trace_settings( + payload: TraceSettingsRequest, + _username: str = Depends(require_dashboard_user), + service: LogService = Depends(get_service), +): + return _update_trace_settings(payload, service) diff --git a/astrbot/dashboard/api/multipart.py b/astrbot/dashboard/api/multipart.py new file mode 100644 index 0000000000..4979ddfb98 --- /dev/null +++ b/astrbot/dashboard/api/multipart.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from collections.abc import Callable +from pathlib import Path +from typing import Any + +from fastapi import Request +from starlette.datastructures import UploadFile as StarletteUploadFile + + +class UploadFileAdapter: + def __init__(self, upload_file: StarletteUploadFile) -> None: + self._upload_file = upload_file + self.filename = upload_file.filename + self.content_type = upload_file.content_type + self.headers = upload_file.headers + self.content_length = self._resolve_content_length() + + def _resolve_content_length(self) -> int | None: + try: + raw = self.headers.get("content-length") + return int(raw) if raw else None + except (TypeError, ValueError): + return None + + async def save(self, destination: str | Path) -> None: + path = Path(destination) + try: + await self._upload_file.seek(0) + except Exception: + pass + with path.open("wb") as output: + while True: + chunk = await self._upload_file.read(1024 * 1024) + if not chunk: + break + output.write(chunk) + + +class MultiDict: + def __init__(self, pairs: list[tuple[str, Any]]) -> None: + self._pairs = pairs + + def get(self, key: str, default: Any = None, type: Callable | None = None): + for item_key, item_value in reversed(self._pairs): + if item_key != key: + continue + if type is None: + return item_value + try: + return type(item_value) + except (TypeError, ValueError): + return default + return default + + def getlist(self, key: str) -> list[Any]: + return [item_value for item_key, item_value in self._pairs if item_key == key] + + def keys(self): + return dict.fromkeys(item_key for item_key, _ in self._pairs).keys() + + def values(self): + return [self[key] for key in self.keys()] + + def __contains__(self, key: str) -> bool: + return any(item_key == key for item_key, _ in self._pairs) + + def __getitem__(self, key: str): + value = self.get(key) + if value is None and key not in self: + raise KeyError(key) + return value + + def __bool__(self) -> bool: + return bool(self._pairs) + + +async def multipart_parts( + request: Request, + *, + extra_form: dict[str, Any] | None = None, +) -> tuple[MultiDict, MultiDict]: + form = await request.form() + form_pairs: list[tuple[str, Any]] = [] + file_pairs: list[tuple[str, Any]] = [] + for key, value in form.multi_items(): + if isinstance(value, StarletteUploadFile): + file_pairs.append((key, UploadFileAdapter(value))) + else: + form_pairs.append((key, value)) + form_data = MultiDict(form_pairs) + for key, value in (extra_form or {}).items(): + if value is not None and key not in form_data: + form_pairs.append((key, value)) + return MultiDict(form_pairs), MultiDict(file_pairs) + + +async def single_upload( + request: Request, + *, + field_name: str = "file", +) -> UploadFileAdapter | None: + _, files = await multipart_parts(request) + upload = files.get(field_name) + if isinstance(upload, UploadFileAdapter): + return upload + return None diff --git a/astrbot/dashboard/api/open_api.py b/astrbot/dashboard/api/open_api.py new file mode 100644 index 0000000000..cca8dfaf78 --- /dev/null +++ b/astrbot/dashboard/api/open_api.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Request, WebSocket +from fastapi.responses import FileResponse, JSONResponse, StreamingResponse + +from astrbot.dashboard.responses import ApiError, error, ok +from astrbot.dashboard.schemas import ImMessageRequest, OpenApiChatRequest +from astrbot.dashboard.services.chat_service import ( + ChatService, + ChatServiceError, + extract_web_search_refs, +) +from astrbot.dashboard.services.open_api_service import ( + OpenApiService, + OpenApiServiceError, + OpenApiWebSocketChatBridge, +) + +from .auth import AuthContext, require_scope +from .multipart import single_upload + +router = APIRouter(tags=["Open API"]) + + +async def require_im_scope(request: Request) -> AuthContext: + return await require_scope(request, "im") + + +async def require_chat_scope(request: Request) -> AuthContext: + return await require_scope(request, "chat") + + +async def require_config_scope(request: Request) -> AuthContext: + return await require_scope(request, "config") + + +async def require_file_scope(request: Request) -> AuthContext: + return await require_scope(request, "file") + + +def get_service(request: Request) -> OpenApiService: + return request.app.state.services.open_api + + +def get_chat_service(request: Request) -> ChatService: + return request.app.state.services.chat + + +def _model_dict(payload) -> dict[str, Any]: + if payload is None: + return {} + if hasattr(payload, "model_dump"): + return payload.model_dump(exclude_unset=True, exclude_none=False) + return payload if isinstance(payload, dict) else {} + + +def _open_api_error(message: str) -> JSONResponse: + return JSONResponse(error(message)) + + +def _get_chat_config_list(service: OpenApiService) -> list[dict]: + return service.get_chat_config_list() + + +async def _build_streaming_chat_response( + chat_service: ChatService, + username: str, + post_data: dict[str, Any], +) -> StreamingResponse | JSONResponse: + try: + stream = await chat_service.build_chat_stream(username, post_data) + except ChatServiceError as exc: + return _open_api_error(str(exc)) + + return StreamingResponse( + stream, + media_type="text/event-stream", + headers={ + "Cache-Control": "no-cache", + "Transfer-Encoding": "chunked", + "Connection": "keep-alive", + }, + ) + + +async def _open_api_chat_response( + post_data: dict[str, Any], + auth: AuthContext, + open_api_service: OpenApiService, + chat_service: ChatService, +) -> StreamingResponse | JSONResponse: + if auth.via != "api_key": + return await _build_streaming_chat_response( + chat_service, + auth.username, + post_data, + ) + + try: + ( + effective_username, + session_id, + config_id, + ) = await open_api_service.prepare_chat_send( + post_data, + _get_chat_config_list(open_api_service), + ) + except OpenApiServiceError as exc: + return _open_api_error(str(exc)) + + config_err = await open_api_service.update_session_config_route( + username=effective_username, + session_id=session_id, + config_id=config_id, + ) + if config_err: + return _open_api_error(config_err) + + return await _build_streaming_chat_response( + chat_service, + effective_username, + post_data, + ) + + +async def _insert_webchat_user_message( + service: OpenApiService, + session_id: str, + effective_username: str, + message_parts: list, +) -> None: + await service.insert_webchat_user_message( + session_id=session_id, + effective_username=effective_username, + message_parts=message_parts, + ) + + +def _build_chat_ws_bridge( + open_api_service: OpenApiService, + chat_service: ChatService, +) -> OpenApiWebSocketChatBridge: + return OpenApiWebSocketChatBridge( + build_user_message_parts=lambda message: chat_service.build_user_message_parts( + message if isinstance(message, str | list) else str(message), + ), + create_attachment_from_file=chat_service.create_attachment_from_file, + extract_web_search_refs=extract_web_search_refs, + insert_user_message=lambda session_id, effective_username, message_parts: ( + _insert_webchat_user_message( + open_api_service, + session_id, + effective_username, + message_parts, + ) + ), + save_bot_message=chat_service.save_bot_message, + ) + + +def _extract_ws_api_key(websocket: WebSocket) -> str | None: + if key := websocket.query_params.get("api_key"): + return key.strip() + if key := websocket.query_params.get("key"): + return key.strip() + if key := websocket.headers.get("X-API-Key"): + return key.strip() + + auth_header = websocket.headers.get("Authorization", "").strip() + if auth_header.startswith("Bearer "): + return auth_header.removeprefix("Bearer ").strip() + if auth_header.startswith("ApiKey "): + return auth_header.removeprefix("ApiKey ").strip() + return None + + +@router.post("/chat") +async def chat( + payload: OpenApiChatRequest, + auth: AuthContext = Depends(require_chat_scope), + service: OpenApiService = Depends(get_service), + chat_service: ChatService = Depends(get_chat_service), +): + return await _open_api_chat_response( + _model_dict(payload), + auth, + service, + chat_service, + ) + + +@router.get("/chat/sessions") +async def chat_sessions( + request: Request, + auth: AuthContext = Depends(require_chat_scope), + service: OpenApiService = Depends(get_service), + chat_service: ChatService = Depends(get_chat_service), +): + if auth.via != "api_key": + try: + return ok( + await chat_service.get_sessions( + auth.username, + request.query_params.get("platform_id"), + ) + ) + except ChatServiceError as exc: + return error(str(exc)) + + try: + return ok( + await service.get_chat_sessions_from_dashboard_query( + username=request.query_params.get("username"), + page=request.query_params.get("page", 1), + page_size=request.query_params.get("page_size", 20), + platform_id=request.query_params.get("platform_id"), + ) + ) + except OpenApiServiceError as exc: + return error(str(exc)) + + +@router.get("/configs", include_in_schema=False) +async def get_chat_configs( + _auth: AuthContext = Depends(require_config_scope), + service: OpenApiService = Depends(get_service), +): + return ok(service.get_chat_configs()) + + +@router.post("/file", include_in_schema=False) +async def upload_open_api_file( + request: Request, + _auth: AuthContext = Depends(require_file_scope), + chat_service: ChatService = Depends(get_chat_service), +): + try: + upload = await single_upload(request) + if upload is None: + raise ChatServiceError("Missing key: file") + return ok(await chat_service.save_uploaded_file(upload)) + except ChatServiceError as exc: + return error(str(exc)) + + +@router.get("/file", include_in_schema=False) +async def get_open_api_file( + request: Request, + _auth: AuthContext = Depends(require_file_scope), + chat_service: ChatService = Depends(get_chat_service), +): + try: + file_path, mimetype = await chat_service.resolve_attachment_file( + request.query_params.get("attachment_id") + ) + return FileResponse(file_path, media_type=mimetype) + except ChatServiceError as exc: + return _open_api_error(str(exc)) + except (FileNotFoundError, OSError): + return _open_api_error("File access error") + + +@router.websocket("/chat/ws") +async def chat_ws(websocket: WebSocket) -> None: + await websocket.accept() + service: OpenApiService = websocket.app.state.services.open_api + chat_service: ChatService = websocket.app.state.services.chat + await service.run_chat_websocket( + raw_api_key=_extract_ws_api_key(websocket), + receive_json=websocket.receive_json, + send_json=websocket.send_json, + close=lambda code, reason: websocket.close(code=code, reason=reason), + conf_list=_get_chat_config_list(service), + chat_bridge=_build_chat_ws_bridge(service, chat_service), + ) + + +@router.post("/im/messages") +async def send_im_message( + payload: ImMessageRequest, + _auth: AuthContext = Depends(require_im_scope), + service: OpenApiService = Depends(get_service), +): + body = _model_dict(payload) + try: + await service.send_message(body) + except OpenApiServiceError as exc: + raise ApiError(str(exc)) from exc + + return ok() + + +@router.post("/im/message", include_in_schema=False) +async def send_im_message_alias( + payload: ImMessageRequest, + auth: AuthContext = Depends(require_im_scope), + service: OpenApiService = Depends(get_service), +): + return await send_im_message(payload, auth, service) + + +@router.get("/im/bots") +async def list_im_bots( + _request: Request, + _auth: AuthContext = Depends(require_im_scope), + service: OpenApiService = Depends(get_service), +): + return ok(service.get_bots()) diff --git a/astrbot/dashboard/api/personas.py b/astrbot/dashboard/api/personas.py new file mode 100644 index 0000000000..ae83e5452b --- /dev/null +++ b/astrbot/dashboard/api/personas.py @@ -0,0 +1,334 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import ( + PersonaByIdRequest, + PersonaFolderRequest, + PersonaMoveRequest, + PersonaReorderRequest, + PersonaRequest, +) +from astrbot.dashboard.services.persona_service import ( + PersonaService, + PersonaServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Personas"]) +legacy_router = APIRouter( + prefix="/api/persona", + tags=["Dashboard Personas"], + include_in_schema=False, +) + + +def get_service(request: Request) -> PersonaService: + return request.app.state.services.personas + + +async def require_persona_scope(request: Request) -> AuthContext: + return await require_scope(request, "persona") + + +async def _json_or_empty(request: Request) -> dict[str, Any]: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _model_dict(payload) -> dict[str, Any]: + return payload.model_dump(exclude_none=True) + + +def _raise_persona_error(exc: PersonaServiceError | ValueError) -> None: + raise ApiError(str(exc)) from exc + + +async def _run(operation): + try: + result = await run_maybe_async(operation) + return ok(result) + except (PersonaServiceError, ValueError) as exc: + _raise_persona_error(exc) + + +@router.get("/personas/tree") +async def persona_tree( + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(service.get_folder_tree) + + +@router.get("/personas") +async def list_personas( + request: Request, + folder_id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run( + lambda: service.list_personas(folder_id, "folder_id" in request.query_params) + ) + + +@router.post("/personas") +async def create_persona( + payload: PersonaRequest, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.create_persona(_model_dict(payload))) + + +@router.get("/personas/by-id") +async def get_persona_by_id( + persona_id: str = Query(...), + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.get_persona_detail({"persona_id": persona_id})) + + +@router.put("/personas/by-id") +async def update_persona_by_id( + payload: PersonaByIdRequest, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.update_persona(_model_dict(payload))) + + +@router.delete("/personas/by-id") +async def delete_persona_by_id( + persona_id: str = Query(...), + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.delete_persona({"persona_id": persona_id})) + + +@router.post("/personas/move") +async def move_persona( + payload: PersonaMoveRequest, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.move_persona(_model_dict(payload))) + + +@router.post("/personas/reorder") +async def reorder_personas( + payload: PersonaReorderRequest, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.reorder_items(_model_dict(payload))) + + +@router.get("/persona-folders") +async def list_persona_folders( + parent_id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.list_folders(parent_id)) + + +@router.post("/persona-folders") +async def create_persona_folder( + payload: PersonaFolderRequest, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.create_folder(_model_dict(payload))) + + +@router.put("/persona-folders/{folder_id:path}") +async def update_persona_folder( + folder_id: str, + payload: PersonaFolderRequest, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run( + lambda: service.update_folder({"folder_id": folder_id, **_model_dict(payload)}) + ) + + +@router.delete("/persona-folders/{folder_id:path}") +async def delete_persona_folder( + folder_id: str, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.delete_folder({"folder_id": folder_id})) + + +@router.get("/personas/{persona_id:path}") +async def get_persona( + persona_id: str, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.get_persona_detail({"persona_id": persona_id})) + + +@router.put("/personas/{persona_id:path}") +async def update_persona( + persona_id: str, + payload: PersonaRequest, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run( + lambda: service.update_persona( + {"persona_id": persona_id, **_model_dict(payload)} + ) + ) + + +@router.delete("/personas/{persona_id:path}") +async def delete_persona( + persona_id: str, + _auth: AuthContext = Depends(require_persona_scope), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.delete_persona({"persona_id": persona_id})) + + +@legacy_router.get("/list") +async def list_dashboard_personas( + request: Request, + folder_id: str | None = Query(default=None), + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + return await _run( + lambda: service.list_personas(folder_id, "folder_id" in request.query_params) + ) + + +@legacy_router.post("/detail") +async def get_dashboard_persona_detail( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.get_persona_detail(body)) + + +@legacy_router.post("/create") +async def create_dashboard_persona( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.create_persona(body)) + + +@legacy_router.post("/update") +async def update_dashboard_persona( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_persona(body)) + + +@legacy_router.post("/delete") +async def delete_dashboard_persona( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.delete_persona(body)) + + +@legacy_router.post("/move") +async def move_dashboard_persona( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.move_persona(body)) + + +@legacy_router.post("/reorder") +async def reorder_dashboard_personas( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.reorder_items(body)) + + +@legacy_router.get("/folder/list") +async def list_dashboard_persona_folders( + parent_id: str | None = Query(default=None), + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + return await _run(lambda: service.list_folders(parent_id)) + + +@legacy_router.get("/folder/tree") +async def get_dashboard_persona_folder_tree( + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + return await _run(service.get_folder_tree) + + +@legacy_router.post("/folder/detail") +async def get_dashboard_persona_folder_detail( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.get_folder_detail(body)) + + +@legacy_router.post("/folder/create") +async def create_dashboard_persona_folder( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.create_folder(body)) + + +@legacy_router.post("/folder/update") +async def update_dashboard_persona_folder( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_folder(body)) + + +@legacy_router.post("/folder/delete") +async def delete_dashboard_persona_folder( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PersonaService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.delete_folder(body)) diff --git a/astrbot/dashboard/api/platform.py b/astrbot/dashboard/api/platform.py new file mode 100644 index 0000000000..c6a6f7f551 --- /dev/null +++ b/astrbot/dashboard/api/platform.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Request + +from astrbot.dashboard.asgi_runtime import DashboardRequest +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import BotRegistrationRequest +from astrbot.dashboard.services.platform_service import ( + PlatformService, + PlatformServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Platforms"]) +legacy_router = APIRouter( + prefix="/api/platform", + tags=["Dashboard Platforms"], + include_in_schema=False, +) + + +def get_service(request: Request) -> PlatformService: + return request.app.state.services.platforms + + +async def require_config_scope(request: Request) -> AuthContext: + return await require_scope(request, "config") + + +async def _json_or_empty(request: Request) -> dict[str, Any]: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _raise_platform_error(exc: PlatformServiceError) -> None: + raise ApiError(str(exc), status_code=exc.status_code) from exc + + +def _model_dict(payload) -> dict[str, Any]: + return payload.model_dump(exclude_none=True) + + +async def _run(operation): + try: + result = await run_maybe_async(operation) + return ok(result) + except PlatformServiceError as exc: + _raise_platform_error(exc) + + +@router.post("/bot-types/{bot_type}/registration") +async def register_bot_type( + bot_type: str, + payload: BotRegistrationRequest, + _auth: AuthContext = Depends(require_config_scope), + service: PlatformService = Depends(get_service), +): + return await _run( + lambda: service.handle_platform_registration(bot_type, _model_dict(payload)) + ) + + +@router.get("/webhooks/platforms/{webhook_uuid}") +async def verify_platform_webhook( + webhook_uuid: str, + request: Request, + service: PlatformService = Depends(get_service), +): + return await _run( + lambda: service.handle_webhook_callback(webhook_uuid, DashboardRequest(request)) + ) + + +@router.post("/webhooks/platforms/{webhook_uuid}") +async def receive_platform_webhook( + webhook_uuid: str, + request: Request, + service: PlatformService = Depends(get_service), +): + return await _run( + lambda: service.handle_webhook_callback(webhook_uuid, DashboardRequest(request)) + ) + + +@legacy_router.api_route("/webhook/{webhook_uuid}", methods=["GET", "POST"]) +async def dashboard_platform_webhook( + webhook_uuid: str, + request: Request, + service: PlatformService = Depends(get_service), +): + return await _run( + lambda: service.handle_webhook_callback(webhook_uuid, DashboardRequest(request)) + ) + + +@legacy_router.get("/stats") +async def get_dashboard_platform_stats( + _username: str = Depends(require_dashboard_user), + service: PlatformService = Depends(get_service), +): + return await _run(service.get_platform_stats) + + +@legacy_router.post("/registration/{platform_type}") +async def handle_dashboard_platform_registration( + platform_type: str, + request: Request, + _username: str = Depends(require_dashboard_user), + service: PlatformService = Depends(get_service), +): + payload = await _json_or_empty(request) + return await _run( + lambda: service.handle_platform_registration(platform_type, payload) + ) diff --git a/astrbot/dashboard/api/plugins.py b/astrbot/dashboard/api/plugins.py new file mode 100644 index 0000000000..551309e883 --- /dev/null +++ b/astrbot/dashboard/api/plugins.py @@ -0,0 +1,1416 @@ +from __future__ import annotations + +import re +from collections.abc import Callable +from typing import Any +from urllib.parse import quote + +from fastapi import APIRouter, Body, Depends, Query, Request +from fastapi.responses import PlainTextResponse, Response + +from astrbot.core import logger +from astrbot.dashboard.asgi_runtime import ( + DashboardRequestState, + call_request_view, +) +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import ok +from astrbot.dashboard.schemas import ( + EnabledPatch, + PluginByIdRequest, + PluginConfigFileDeleteRequest, + PluginConfigPayload, + PluginConfigUpdateRequest, + PluginEnabledRequest, + PluginInstallRequest, + PluginSourceRequest, + PluginUninstallRequest, + PluginUpdateRequest, + PluginVersionSupportRequest, +) +from astrbot.dashboard.services.config_service import ( + ConfigDisplayService, + ConfigFileService, +) +from astrbot.dashboard.services.plugin_page_service import ( + PluginPageContentPayload, + PluginPageService, + PluginPageServiceError, +) +from astrbot.dashboard.services.plugin_service import ( + PLUGIN_OPERATION_FAILED_MESSAGE, + PluginService, + PluginServiceError, + PluginServiceWarning, +) + +from .auth import AuthContext, require_dashboard_user, require_scope +from .multipart import multipart_parts + +router = APIRouter(tags=["Plugins"]) +legacy_router = APIRouter(tags=["Dashboard Plugins"], include_in_schema=False) + + +async def require_plugin_scope(request: Request) -> AuthContext: + return await require_scope(request, "plugin") + + +def get_service(request: Request) -> PluginService: + return request.app.state.services.plugins + + +def get_page_service(request: Request) -> PluginPageService: + return request.app.state.services.plugin_pages + + +def get_config_display_service(request: Request) -> ConfigDisplayService: + return request.app.state.services.config_display + + +def get_config_file_service(request: Request) -> ConfigFileService: + return request.app.state.services.config_files + + +async def _json_or_empty(request: Request) -> dict[str, Any]: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _required_text(value: object, name: str) -> str: + text = str(value or "").strip() + if not text: + raise ValueError(f"Missing key: {name}") + return text + + +def _plugin_id_from_body(body: dict[str, Any]) -> str: + return _required_text(body.get("plugin_id"), "plugin_id") + + +def _model_dict(payload) -> dict[str, Any]: + if payload is None: + return {} + if hasattr(payload, "model_dump"): + return payload.model_dump(exclude_none=True) + return payload if isinstance(payload, dict) else {} + + +def _service_ok(result): + if isinstance(result, tuple): + data, message = result + return ok(data, message) + return ok(result) + + +async def _run_service(operation, *, log_label: str | None = None): + try: + result = await run_maybe_async(operation) + return _service_ok(result) + except PluginServiceWarning as exc: + return { + "status": "warning", + "message": exc.public_message, + "data": exc.data, + } + except PluginServiceError as exc: + return {"status": "error", "message": exc.public_message, "data": {}} + except Exception: + if log_label: + logger.error("%s failed", log_label, exc_info=True) + else: + logger.error("Plugin service operation failed", exc_info=True) + return { + "status": "error", + "message": PLUGIN_OPERATION_FAILED_MESSAGE, + "data": {}, + } + + +async def _run_json( + request: Request, + operation: Callable[[dict[str, Any]], Any], + *, + log_label: str | None = None, +): + body = await _json_or_empty(request) + return await _run_service(lambda: operation(body), log_label=log_label) + + +def _normalize_plugin_api_route(route: str) -> str: + route = route.strip() + if not route.startswith("/"): + route = f"/{route}" + return route + + +def _plugin_api_route_pattern(route: str) -> str: + normalized = _normalize_plugin_api_route(route) + chunks = [] + pos = 0 + for match in re.finditer(r"<(?:(path):)?([A-Za-z_][A-Za-z0-9_]*)>", normalized): + chunks.append(re.escape(normalized[pos : match.start()])) + name = match.group(2) + chunks.append(f"(?P<{name}>.*)" if match.group(1) else f"(?P<{name}>[^/]+)") + pos = match.end() + chunks.append(re.escape(normalized[pos:])) + return "".join(chunks) + + +def _match_registered_web_api(registered_web_apis, subpath: str, method: str): + request_path = f"/{subpath.lstrip('/')}" + request_method = method.upper() + + for route, view_handler, methods, _ in registered_web_apis: + allowed_methods = [item.upper() for item in methods] + if request_method not in allowed_methods: + continue + + pattern = _plugin_api_route_pattern(route) + matched = re.fullmatch(pattern, request_path) + if matched: + return view_handler, matched.groupdict() + return None + + +def _plugin_extension_legacy_path(plugin_path: str, request: Request) -> str: + encoded_path = quote(plugin_path.lstrip("/"), safe="/:@!$&'()*+,;=-._~") + path = f"/api/plug/{encoded_path}" + if request.url.query: + return f"{path}?{request.url.query}" + return path + + +async def _call_plugin_extension( + plugin_path: str, + request: Request, + username: str, +): + registered_web_apis = ( + request.app.state.core_lifecycle.star_context.registered_web_apis + ) + matched_api = _match_registered_web_api( + registered_web_apis, + plugin_path, + request.method, + ) + if not matched_api: + return {"status": "error", "message": "未找到该路由", "data": {}} + + view_handler, path_values = matched_api + app_adapter = getattr(request.app.state, "dashboard_app_adapter", None) + if app_adapter is None: + return await run_maybe_async(lambda: view_handler(**path_values)) + + g_obj = DashboardRequestState() + g_obj.username = username + return await call_request_view( + request, + app_adapter, + view_handler, + path_values, + g_obj=g_obj, + quart_compat_path=_plugin_extension_legacy_path(plugin_path, request), + ) + + +def _get_request_locale(request: Request, default: str = "zh-CN") -> str: + raw_locale = request.headers.get("Accept-Language", "").strip() + locale = raw_locale.split(",", 1)[0].split(";", 1)[0].strip() + if not locale or len(locale) > 32: + return default + return locale + + +def _get_request_theme(request: Request) -> str | None: + theme = request.query_params.get("theme", "").strip() + return theme if theme in ("dark", "light") else None + + +def _plugin_page_error_response(status_code: int, message: str): + return PlainTextResponse( + message, + status_code=status_code, + headers={ + "Cache-Control": "no-store", + "Referrer-Policy": "no-referrer", + }, + ) + + +def _plugin_page_payload_response(payload: PluginPageContentPayload): + return Response( + content=payload.content, + media_type=payload.content_type, + headers=PluginPageService.build_security_headers(), + ) + + +async def _serve_plugin_page_content( + *, + request: Request, + page_service: PluginPageService, + username: str | None, + plugin_id: str, + page_name: str, + asset_path: str, +): + try: + payload = await page_service.serve_page_content( + plugin_name=plugin_id, + page_name=page_name, + asset_path=asset_path, + asset_token=request.query_params.get("asset_token", "").strip(), + username=username, + locale=_get_request_locale(request), + theme=_get_request_theme(request), + ) + except PluginPageServiceError as exc: + return _plugin_page_error_response(exc.status_code, exc.public_message) + return _plugin_page_payload_response(payload) + + +async def _serve_plugin_page_bridge_sdk( + *, + request: Request, + page_service: PluginPageService, +): + try: + payload = await page_service.serve_bridge_sdk( + asset_token=request.query_params.get("asset_token", "").strip(), + locale=_get_request_locale(request), + theme=_get_request_theme(request), + ) + except PluginPageServiceError as exc: + return _plugin_page_error_response(exc.status_code, exc.public_message) + return _plugin_page_payload_response(payload) + + +async def _get_plugin_page_entry_config( + *, + request: Request, + page_service: PluginPageService, + username: str | None, + plugin_id: str | None, + page_name: str | None, +): + try: + return ok( + await page_service.get_plugin_page_entry_config( + plugin_name=plugin_id, + page_name=page_name, + username=username, + locale=_get_request_locale(request), + ) + ) + except PluginPageServiceError as exc: + return {"status": "error", "message": exc.public_message, "data": {}} + + +async def _list_plugins( + *, + request: Request, + service: PluginService, + page_service: PluginPageService, +): + return await _run_service( + service.list_plugins_from_dashboard_query( + plugin_name=request.query_params.get("name") + or request.query_params.get("plugin_id"), + logo_token_resolver=service.get_plugin_logo_token, + installed_at_resolver=service.get_plugin_installed_at, + discover_pages=page_service.discover_plugin_pages, + ), + log_label="/api/plugin/get", + ) + + +async def _get_plugin_detail( + *, + plugin_id: str | None, + service: PluginService, + page_service: PluginPageService, +): + return await _run_service( + service.get_plugin_detail( + plugin_name=plugin_id, + logo_token_resolver=service.get_plugin_logo_token, + installed_at_resolver=service.get_plugin_installed_at, + serialize_pages=page_service.serialize_plugin_pages, + ), + log_label="/api/plugin/detail", + ) + + +async def _install_plugin_upload( + request: Request, + service: PluginService, + *, + log_label: str, +): + async def operation(): + form, files = await multipart_parts(request) + upload_file = files.get("file") + if upload_file is None: + raise PluginServiceError("缺少插件文件") + return await service.install_plugin_upload_from_dashboard_form( + upload_file=upload_file, + ignore_version_check=form.get("ignore_version_check", "false"), + ) + + return await _run_service(operation, log_label=log_label) + + +@router.get("/plugins/extensions/{plugin_path:path}") +async def get_plugin_extension_route( + plugin_path: str, + request: Request, + auth: AuthContext = Depends(require_plugin_scope), +): + return await _call_plugin_extension(plugin_path, request, auth.username) + + +@router.post("/plugins/extensions/{plugin_path:path}") +async def post_plugin_extension_route( + plugin_path: str, + request: Request, + auth: AuthContext = Depends(require_plugin_scope), +): + return await _call_plugin_extension(plugin_path, request, auth.username) + + +@router.put("/plugins/extensions/{plugin_path:path}") +async def put_plugin_extension_route( + plugin_path: str, + request: Request, + auth: AuthContext = Depends(require_plugin_scope), +): + return await _call_plugin_extension(plugin_path, request, auth.username) + + +@router.patch("/plugins/extensions/{plugin_path:path}") +async def patch_plugin_extension_route( + plugin_path: str, + request: Request, + auth: AuthContext = Depends(require_plugin_scope), +): + return await _call_plugin_extension(plugin_path, request, auth.username) + + +@router.delete("/plugins/extensions/{plugin_path:path}") +async def delete_plugin_extension_route( + plugin_path: str, + request: Request, + auth: AuthContext = Depends(require_plugin_scope), +): + return await _call_plugin_extension(plugin_path, request, auth.username) + + +@router.get("/plugins/failed") +async def list_failed_plugins( + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service(service.get_failed_plugins) + + +@router.post("/plugins/update") +async def update_plugins( + payload: PluginUpdateRequest, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + body = _model_dict(payload) + if body.get("plugin_id"): + plugin_id = _plugin_id_from_body(body) + return await _run_service( + service.update_plugin( + { + "name": plugin_id, + **{key: value for key, value in body.items() if key != "plugin_id"}, + } + ), + log_label="/api/plugin/update", + ) + return await _run_service( + service.update_all_plugins( + { + **body, + "names": body.get("names") or body.get("plugin_ids") or [], + } + ), + log_label="/api/plugin/update-all", + ) + + +async def _check_plugin_version_support_payload( + payload: dict[str, Any], + service: PluginService, +): + return await _run_service( + lambda: service.check_plugin_version_support(payload), + log_label="/api/plugin/version-support/check", + ) + + +async def _check_plugin_version_support_request( + request: Request, + service: PluginService, +): + return await _check_plugin_version_support_payload( + await _json_or_empty(request), + service, + ) + + +@router.post("/plugins/version-support/check") +async def check_plugin_version_support( + payload: PluginVersionSupportRequest, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _check_plugin_version_support_payload(_model_dict(payload), service) + + +@router.post("/plugins/install/github") +async def install_plugin_from_github( + payload: PluginInstallRequest, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + body = _model_dict(payload) + repository = str(body.get("repository") or body.get("url") or "").strip() + if repository and not repository.startswith(("http://", "https://")): + repository = f"https://github.com/{repository}" + install_payload = { + "url": repository, + "proxy": body.get("proxy"), + "ignore_version_check": body.get("ignore_version_check", False), + } + if body.get("download_url"): + install_payload["download_url"] = body["download_url"] + return await _run_service( + service.install_plugin(install_payload), + log_label="/api/plugin/install", + ) + + +@router.post("/plugins/install/url") +async def install_plugin_from_url( + payload: PluginInstallRequest | None = Body(default=None), + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + body = _model_dict(payload) + url = str(body.get("url") or body.get("repository") or "").strip() + download_url = str(body.get("download_url") or url).strip() + return await _run_service( + service.install_plugin( + { + "url": url or download_url, + "download_url": download_url, + "proxy": body.get("proxy"), + "ignore_version_check": body.get("ignore_version_check", False), + } + ), + log_label="/api/plugin/install", + ) + + +@router.post("/plugins/install/upload") +async def install_plugin_from_upload( + request: Request, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _install_plugin_upload( + request, + service, + log_label="/api/plugin/install-upload", + ) + + +@router.get("/plugins/market") +async def list_plugin_market( + request: Request, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service( + service.get_online_plugins_from_dashboard_query( + custom_registry=request.query_params.get("custom_registry"), + force_refresh=request.query_params.get("force_refresh", "false"), + ), + log_label="/api/plugin/market_list", + ) + + +@router.get("/plugins/market/categories") +async def list_plugin_market_categories( + _auth: AuthContext = Depends(require_plugin_scope), +): + return ok({"categories": []}) + + +@router.get("/plugin-sources") +async def list_plugin_sources( + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return ok({"sources": await service.get_custom_sources()}) + + +@router.post("/plugin-sources") +async def create_plugin_source( + payload: PluginSourceRequest, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return ok( + {"sources": await service.create_custom_source(_model_dict(payload))}, + message="保存成功", + ) + + +@router.put("/plugin-sources") +async def replace_plugin_sources( + payload: PluginSourceRequest, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return ok( + {"sources": await service.replace_custom_sources(_model_dict(payload))}, + message="保存成功", + ) + + +@router.delete("/plugin-sources/by-id") +async def delete_plugin_source_by_id( + source_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return ok( + {"sources": await service.delete_custom_source(source_id)}, + message="保存成功", + ) + + +@router.delete("/plugin-sources/{source_id}") +async def delete_plugin_source( + source_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return ok( + {"sources": await service.delete_custom_source(source_id)}, + message="保存成功", + ) + + +@router.get("/plugins/page-bridge-sdk.js") +async def get_plugin_page_bridge_sdk( + request: Request, + _auth: AuthContext = Depends(require_plugin_scope), + page_service: PluginPageService = Depends(get_page_service), +): + return await _serve_plugin_page_bridge_sdk( + request=request, + page_service=page_service, + ) + + +@router.get("/plugins") +async def list_plugins( + request: Request, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), + page_service: PluginPageService = Depends(get_page_service), +): + return await _list_plugins( + request=request, + service=service, + page_service=page_service, + ) + + +@router.get("/plugins/by-id") +async def get_plugin_by_id( + plugin_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), + page_service: PluginPageService = Depends(get_page_service), +): + return await _get_plugin_detail( + plugin_id=plugin_id, + service=service, + page_service=page_service, + ) + + +@router.delete("/plugins/by-id") +async def uninstall_plugin_by_id( + payload: PluginUninstallRequest | None = None, + plugin_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + body = _model_dict(payload) + return await _run_service( + service.uninstall_plugin({"name": plugin_id, **body}), + log_label="/api/plugin/uninstall", + ) + + +@router.get("/plugins/config") +async def get_plugin_config_by_id( + plugin_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigDisplayService = Depends(get_config_display_service), +): + return ok({"plugin_name": plugin_id, **await service.get_configs(plugin_id)}) + + +@router.put("/plugins/config") +async def update_plugin_config_by_id( + payload: PluginConfigUpdateRequest, + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigFileService = Depends(get_config_file_service), +): + body = _model_dict(payload) + plugin_id = _plugin_id_from_body(body) + config = body.get("config") + config = config if isinstance(config, dict) else {} + return ok( + message=await service.save_plugin_configs_from_dashboard_payload( + config, + plugin_name=plugin_id, + ) + ) + + +@router.get("/plugins/config/schema") +async def get_plugin_config_schema_by_id( + plugin_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigDisplayService = Depends(get_config_display_service), +): + return ok({"plugin_name": plugin_id, **await service.get_configs(plugin_id)}) + + +@router.get("/plugins/config-files") +async def list_plugin_config_files_by_id( + plugin_id: str = Query(...), + config_key: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigFileService = Depends(get_config_file_service), +): + return ok( + service.list_config_files( + scope="plugin", + name=plugin_id, + key_path=config_key, + ) + ) + + +@router.post("/plugins/config-files") +async def upload_plugin_config_files_by_id( + request: Request, + plugin_id: str = Query(...), + config_key: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigFileService = Depends(get_config_file_service), +): + _, files = await multipart_parts(request) + return ok( + await service.upload_config_file( + scope="plugin", + name=plugin_id, + key_path=config_key, + files=files, + ) + ) + + +@router.delete("/plugins/config-files") +async def delete_plugin_config_file_by_id( + payload: PluginConfigFileDeleteRequest | None = None, + plugin_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigFileService = Depends(get_config_file_service), +): + return ok( + message=service.delete_config_file_from_dashboard_payload( + scope="plugin", + name=plugin_id, + payload=_model_dict(payload), + ) + ) + + +@router.get("/plugins/readme") +async def get_plugin_readme_by_id( + plugin_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service( + lambda: service.get_plugin_readme(plugin_id), + log_label="/api/plugin/readme", + ) + + +@router.get("/plugins/changelog") +async def get_plugin_changelog_by_id( + plugin_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service( + lambda: service.get_plugin_changelog(plugin_id), + log_label="/api/plugin/changelog", + ) + + +@router.post("/plugins/reload") +async def reload_plugin_by_id( + payload: PluginByIdRequest, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + plugin_id = _plugin_id_from_body(_model_dict(payload)) + return await _run_service( + service.reload_plugin({"name": plugin_id}), + log_label="/api/plugin/reload", + ) + + +@router.patch("/plugins/enabled") +async def set_plugin_enabled_by_id( + payload: PluginEnabledRequest, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + body = _model_dict(payload) + plugin_id = _plugin_id_from_body(body) + return await _run_service( + service.set_plugin_enabled( + {"name": plugin_id}, enabled=bool(body.get("enabled")) + ), + log_label="/api/plugin/on" if body.get("enabled") else "/api/plugin/off", + ) + + +@router.get("/plugins/pages") +async def list_plugin_pages_by_id( + plugin_id: str = Query(...), + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), + page_service: PluginPageService = Depends(get_page_service), +): + return await _get_plugin_detail( + plugin_id=plugin_id, + service=service, + page_service=page_service, + ) + + +@router.get("/plugins/page") +async def get_plugin_page_by_id( + request: Request, + plugin_id: str = Query(...), + page_name: str = Query(...), + auth: AuthContext = Depends(require_plugin_scope), + page_service: PluginPageService = Depends(get_page_service), +): + return await _get_plugin_page_entry_config( + request=request, + page_service=page_service, + username=auth.username, + plugin_id=plugin_id, + page_name=page_name, + ) + + +@router.get("/plugins/page/assets") +async def get_plugin_page_asset_by_id( + request: Request, + plugin_id: str = Query(...), + page_name: str = Query(...), + asset_path: str = Query(...), + auth: AuthContext = Depends(require_plugin_scope), + page_service: PluginPageService = Depends(get_page_service), +): + return await _serve_plugin_page_content( + request=request, + page_service=page_service, + username=auth.username, + plugin_id=plugin_id, + page_name=page_name, + asset_path=asset_path, + ) + + +@router.get("/plugins/{plugin_id}") +async def get_plugin( + plugin_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), + page_service: PluginPageService = Depends(get_page_service), +): + return await _get_plugin_detail( + plugin_id=plugin_id, + service=service, + page_service=page_service, + ) + + +@router.delete("/plugins/{plugin_id}") +async def uninstall_plugin( + plugin_id: str, + payload: PluginUninstallRequest | None = None, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + body = _model_dict(payload) + return await _run_service( + service.uninstall_plugin({"name": plugin_id, **body}), + log_label="/api/plugin/uninstall", + ) + + +@router.delete("/plugins/failed/{plugin_id}") +async def uninstall_failed_plugin( + plugin_id: str, + payload: PluginUninstallRequest | None = None, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + body = _model_dict(payload) + return await _run_service( + service.uninstall_failed_plugin({"dir_name": plugin_id, **body}), + log_label="/api/plugin/uninstall-failed", + ) + + +@router.post("/plugins/failed/{plugin_id}/reload") +async def reload_failed_plugin( + plugin_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service( + service.reload_failed_plugin({"dir_name": plugin_id}), + log_label="/api/plugin/reload-failed", + ) + + +@router.get("/plugins/{plugin_id}/config") +async def get_plugin_config( + plugin_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigDisplayService = Depends(get_config_display_service), +): + return ok({"plugin_name": plugin_id, **await service.get_configs(plugin_id)}) + + +@router.put("/plugins/{plugin_id}/config") +async def update_plugin_config( + plugin_id: str, + payload: PluginConfigPayload, + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigFileService = Depends(get_config_file_service), +): + body = _model_dict(payload) + config = body.get("config") + config = config if isinstance(config, dict) else body + return ok( + message=await service.save_plugin_configs_from_dashboard_payload( + config, + plugin_name=plugin_id, + ) + ) + + +@router.get("/plugins/{plugin_id}/config/schema") +async def get_plugin_config_schema( + plugin_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigDisplayService = Depends(get_config_display_service), +): + return ok({"plugin_name": plugin_id, **await service.get_configs(plugin_id)}) + + +@router.get("/plugins/{plugin_id}/config-files/{config_key:path}") +async def list_plugin_config_files( + plugin_id: str, + config_key: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigFileService = Depends(get_config_file_service), +): + return ok( + service.list_config_files( + scope="plugin", + name=plugin_id, + key_path=config_key, + ) + ) + + +@router.post("/plugins/{plugin_id}/config-files/{config_key:path}") +async def upload_plugin_config_files( + plugin_id: str, + config_key: str, + request: Request, + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigFileService = Depends(get_config_file_service), +): + _, files = await multipart_parts(request) + return ok( + await service.upload_config_file( + scope="plugin", + name=plugin_id, + key_path=config_key, + files=files, + ) + ) + + +@router.delete("/plugins/{plugin_id}/config-files") +async def delete_plugin_config_file( + plugin_id: str, + payload: PluginConfigFileDeleteRequest | None = None, + _auth: AuthContext = Depends(require_plugin_scope), + service: ConfigFileService = Depends(get_config_file_service), +): + return ok( + message=service.delete_config_file_from_dashboard_payload( + scope="plugin", + name=plugin_id, + payload=_model_dict(payload), + ) + ) + + +@router.get("/plugins/{plugin_id}/readme") +async def get_plugin_readme( + plugin_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service( + lambda: service.get_plugin_readme(plugin_id), + log_label="/api/plugin/readme", + ) + + +@router.get("/plugins/{plugin_id}/changelog") +async def get_plugin_changelog( + plugin_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service( + lambda: service.get_plugin_changelog(plugin_id), + log_label="/api/plugin/changelog", + ) + + +@router.post("/plugins/{plugin_id}/reload") +async def reload_plugin( + plugin_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service( + service.reload_plugin({"name": plugin_id}), + log_label="/api/plugin/reload", + ) + + +@router.patch("/plugins/{plugin_id}/enabled") +async def set_plugin_enabled( + plugin_id: str, + payload: EnabledPatch, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + return await _run_service( + service.set_plugin_enabled({"name": plugin_id}, enabled=payload.enabled), + log_label="/api/plugin/on" if payload.enabled else "/api/plugin/off", + ) + + +@router.post("/plugins/{plugin_id}/update") +async def update_plugin( + plugin_id: str, + payload: PluginUpdateRequest | None = None, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), +): + body = _model_dict(payload) + return await _run_service( + service.update_plugin({"name": plugin_id, **body}), + log_label="/api/plugin/update", + ) + + +@router.get("/plugins/{plugin_id}/pages") +async def list_plugin_pages( + plugin_id: str, + _auth: AuthContext = Depends(require_plugin_scope), + service: PluginService = Depends(get_service), + page_service: PluginPageService = Depends(get_page_service), +): + return await _get_plugin_detail( + plugin_id=plugin_id, + service=service, + page_service=page_service, + ) + + +@router.get("/plugins/{plugin_id}/pages/{page_name}") +async def get_plugin_page( + plugin_id: str, + page_name: str, + request: Request, + auth: AuthContext = Depends(require_plugin_scope), + page_service: PluginPageService = Depends(get_page_service), +): + return await _get_plugin_page_entry_config( + request=request, + page_service=page_service, + username=auth.username, + plugin_id=plugin_id, + page_name=page_name, + ) + + +@router.get("/plugins/{plugin_id}/pages/{page_name}/assets/{asset_path:path}") +async def get_plugin_page_asset( + plugin_id: str, + page_name: str, + asset_path: str, + request: Request, + auth: AuthContext = Depends(require_plugin_scope), + page_service: PluginPageService = Depends(get_page_service), +): + return await _serve_plugin_page_content( + request=request, + page_service=page_service, + username=auth.username, + plugin_id=plugin_id, + page_name=page_name, + asset_path=asset_path, + ) + + +@legacy_router.get("/api/plugin/get") +async def dashboard_list_plugins( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), + page_service: PluginPageService = Depends(get_page_service), +): + return await _list_plugins( + request=request, + service=service, + page_service=page_service, + ) + + +@legacy_router.get("/api/plugin/detail") +async def dashboard_get_plugin_detail( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), + page_service: PluginPageService = Depends(get_page_service), +): + return await _get_plugin_detail( + plugin_id=request.query_params.get("name"), + service=service, + page_service=page_service, + ) + + +@legacy_router.post("/api/plugin/check-compat") +async def dashboard_check_plugin_version_support( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _check_plugin_version_support_request(request, service) + + +@legacy_router.get("/api/plugin/page/entry") +async def dashboard_get_plugin_page_entry_config( + request: Request, + username: str = Depends(require_dashboard_user), + page_service: PluginPageService = Depends(get_page_service), +): + return await _get_plugin_page_entry_config( + request=request, + page_service=page_service, + username=username, + plugin_id=request.query_params.get("name"), + page_name=request.query_params.get("page"), + ) + + +@legacy_router.post("/api/plugin/install") +async def dashboard_install_plugin( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + service.install_plugin, + log_label="/api/plugin/install", + ) + + +@legacy_router.post("/api/plugin/install-upload") +async def dashboard_install_plugin_upload( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _install_plugin_upload( + request, + service, + log_label="/api/plugin/install-upload", + ) + + +@legacy_router.post("/api/plugin/update") +async def dashboard_update_plugin( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + service.update_plugin, + log_label="/api/plugin/update", + ) + + +@legacy_router.post("/api/plugin/update-all") +async def dashboard_update_all_plugins( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + service.update_all_plugins, + log_label="/api/plugin/update-all", + ) + + +@legacy_router.post("/api/plugin/uninstall") +async def dashboard_uninstall_plugin( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + service.uninstall_plugin, + log_label="/api/plugin/uninstall", + ) + + +@legacy_router.post("/api/plugin/uninstall-failed") +async def dashboard_uninstall_failed_plugin( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + service.uninstall_failed_plugin, + log_label="/api/plugin/uninstall-failed", + ) + + +@legacy_router.get("/api/plugin/market_list") +async def dashboard_list_plugin_market( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_service( + service.get_online_plugins_from_dashboard_query( + custom_registry=request.query_params.get("custom_registry"), + force_refresh=request.query_params.get("force_refresh", "false"), + ), + log_label="/api/plugin/market_list", + ) + + +@legacy_router.post("/api/plugin/off") +async def dashboard_disable_plugin( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + lambda data: service.set_plugin_enabled(data, enabled=False), + log_label="/api/plugin/off", + ) + + +@legacy_router.post("/api/plugin/on") +async def dashboard_enable_plugin( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + lambda data: service.set_plugin_enabled(data, enabled=True), + log_label="/api/plugin/on", + ) + + +@legacy_router.post("/api/plugin/reload-failed") +async def dashboard_reload_failed_plugin( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + service.reload_failed_plugin, + log_label="/api/plugin/reload-failed", + ) + + +@legacy_router.post("/api/plugin/reload") +async def dashboard_reload_plugin( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + service.reload_plugin, + log_label="/api/plugin/reload", + ) + + +@legacy_router.get("/api/plugin/readme") +async def dashboard_get_plugin_readme( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_service( + lambda: service.get_plugin_readme(request.query_params.get("name")), + log_label="/api/plugin/readme", + ) + + +@legacy_router.get("/api/plugin/changelog") +async def dashboard_get_plugin_changelog( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_service( + lambda: service.get_plugin_changelog(request.query_params.get("name")), + log_label="/api/plugin/changelog", + ) + + +@legacy_router.get("/api/plugin/source/get") +async def dashboard_get_custom_source( + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_service(service.get_custom_sources) + + +@legacy_router.post("/api/plugin/source/save") +async def dashboard_save_custom_source( + request: Request, + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_json( + request, + service.save_custom_sources, + log_label="/api/plugin/source/save", + ) + + +@legacy_router.get("/api/plugin/source/get-failed-plugins") +async def dashboard_get_failed_plugins( + _username: str = Depends(require_dashboard_user), + service: PluginService = Depends(get_service), +): + return await _run_service(service.get_failed_plugins) + + +@legacy_router.get("/api/plugin/page/bridge-sdk.js") +async def dashboard_get_plugin_page_bridge_sdk( + request: Request, + _username: str = Depends(require_dashboard_user), + page_service: PluginPageService = Depends(get_page_service), +): + return await _serve_plugin_page_bridge_sdk( + request=request, + page_service=page_service, + ) + + +@legacy_router.get("/api/plugin/page/content/{plugin_id}/{page_name}/") +async def dashboard_get_plugin_page_entry( + plugin_id: str, + page_name: str, + request: Request, + username: str = Depends(require_dashboard_user), + page_service: PluginPageService = Depends(get_page_service), +): + return await _serve_plugin_page_content( + request=request, + page_service=page_service, + username=username, + plugin_id=plugin_id, + page_name=page_name, + asset_path="", + ) + + +@legacy_router.get("/api/plugin/page/content/{plugin_id}/{page_name}/{asset_path:path}") +async def dashboard_get_plugin_page_asset( + plugin_id: str, + page_name: str, + asset_path: str, + request: Request, + username: str = Depends(require_dashboard_user), + page_service: PluginPageService = Depends(get_page_service), +): + return await _serve_plugin_page_content( + request=request, + page_service=page_service, + username=username, + plugin_id=plugin_id, + page_name=page_name, + asset_path=asset_path, + ) + + +@legacy_router.api_route("/api/plug/{plugin_path:path}", methods=["GET", "POST"]) +async def dashboard_plugin_extension_route( + plugin_path: str, + request: Request, + username: str = Depends(require_dashboard_user), +): + return await _call_plugin_extension(plugin_path, request, username) diff --git a/astrbot/dashboard/api/providers.py b/astrbot/dashboard/api/providers.py new file mode 100644 index 0000000000..c47ab489c9 --- /dev/null +++ b/astrbot/dashboard/api/providers.py @@ -0,0 +1,593 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import ( + EnabledPatch, + ProviderConfigRequest, + ProviderSourceRequest, +) +from astrbot.dashboard.services.config_service import ProviderConfigService + +from .auth import AuthContext, require_scope + +router = APIRouter(tags=["Providers"]) +legacy_router = APIRouter( + prefix="/api/config", + tags=["Dashboard Providers"], + include_in_schema=False, +) + + +async def require_provider_scope(request: Request) -> AuthContext: + return await require_scope(request, "provider") + + +def get_service(request: Request) -> ProviderConfigService: + return request.app.state.services.providers + + +async def _json_or_empty(request: Request) -> dict: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _required_text(value: object, name: str) -> str: + text = str(value or "").strip() + if not text: + raise ValueError(f"Missing key: {name}") + return text + + +def _model_dict(payload) -> dict: + if payload is None: + return {} + if hasattr(payload, "model_dump"): + return payload.model_dump(exclude_none=True) + return payload if isinstance(payload, dict) else {} + + +def _config_from_body(body: dict) -> dict: + config = body.get("config") + if isinstance(config, dict): + return config + return { + key: value + for key, value in body.items() + if key + not in { + "provider_id", + "source_id", + "config", + "enabled", + "provider_config", + } + } + + +def _provider_config_for_dimension( + service: ProviderConfigService, + provider_id: str, + body: dict, +) -> dict: + provider = service.get_provider(provider_id, merged=True) + base_config = provider.get("provider") if isinstance(provider, dict) else {} + if not isinstance(base_config, dict): + base_config = {} + provider_config = body.get("provider_config") + if isinstance(provider_config, dict): + return {**base_config, **provider_config} + return base_config + + +def _alias_error(message: str): + return error(message) + + +@router.get("/providers/schema") +async def get_provider_schema( + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.get_provider_schema()) + + +@router.get("/provider-sources") +async def list_provider_sources( + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.list_provider_sources()) + + +@router.post("/provider-sources") +async def create_provider_source( + payload: ProviderSourceRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + config = payload.to_dashboard_config() + source_id = config.get("id") + if not source_id: + raise ValueError("Provider source config must have an 'id' field") + await service.upsert_provider_source(source_id, config) + return ok(message="更新 provider source 成功") + + +@router.get("/provider-sources/by-id") +async def get_provider_source_by_id( + source_id: str = Query(...), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.get_provider_source(source_id)) + + +@router.put("/provider-sources/by-id") +async def upsert_provider_source_by_id( + payload: ProviderSourceRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + source_id = _required_text(payload.source_id, "source_id") + await service.upsert_provider_source( + source_id, + payload.to_dashboard_config(fallback_id=source_id), + ) + return ok(message="更新 provider source 成功") + + +@router.delete("/provider-sources/by-id") +async def delete_provider_source_by_id( + source_id: str = Query(...), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.delete_provider_source(source_id) + return ok(message="删除 provider source 成功") + + +@router.get("/provider-sources/models") +async def list_provider_source_models_by_id( + source_id: str = Query(...), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(await service.list_provider_source_models(source_id)) + + +@router.get("/provider-sources/providers") +async def list_providers_by_source_id( + source_id: str = Query(...), + capability: str | None = Query(default=None), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.list_providers(capability=capability, source_id=source_id)) + + +@router.post("/provider-sources/providers") +async def create_provider_in_source_by_id( + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + source_id = _required_text(payload.source_id, "source_id") + await service.create_provider( + payload.to_dashboard_config(source_id=source_id), + source_id, + ) + return ok(message="新增服务提供商配置成功") + + +@router.get("/provider-sources/{source_id:path}/models") +async def list_provider_source_models( + source_id: str, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(await service.list_provider_source_models(source_id)) + + +@router.get("/provider-sources/{source_id:path}/providers") +async def list_providers_by_source( + source_id: str, + capability: str | None = Query(default=None), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.list_providers(capability=capability, source_id=source_id)) + + +@router.post("/provider-sources/{source_id:path}/providers") +async def create_provider_in_source( + source_id: str, + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.create_provider( + payload.to_dashboard_config(source_id=source_id), source_id + ) + return ok(message="新增服务提供商配置成功") + + +@router.get("/provider-sources/{source_id:path}") +async def get_provider_source( + source_id: str, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.get_provider_source(source_id)) + + +@router.put("/provider-sources/{source_id:path}") +async def upsert_provider_source( + source_id: str, + payload: ProviderSourceRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.upsert_provider_source( + source_id, + payload.to_dashboard_config(), + ) + return ok(message="更新 provider source 成功") + + +@router.delete("/provider-sources/{source_id:path}") +async def delete_provider_source( + source_id: str, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.delete_provider_source(source_id) + return ok(message="删除 provider source 成功") + + +@router.get("/providers") +async def list_providers( + capability: str | None = Query(default=None), + source_id: str | None = Query(default=None), + enabled: bool | None = Query(default=None), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok( + service.list_providers( + capability=capability, + source_id=source_id, + enabled=enabled, + ) + ) + + +@router.post("/providers") +async def create_provider( + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.create_provider(payload.to_dashboard_config()) + return ok(message="新增服务提供商配置成功") + + +@router.get("/providers/by-id") +async def get_provider_by_id( + provider_id: str = Query(...), + merged: bool = Query(default=False), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.get_provider(provider_id, merged=merged)) + + +@router.put("/providers/by-id") +async def update_provider_by_id( + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + provider_id = _required_text(payload.provider_id, "provider_id") + await service.update_provider( + provider_id, + payload.to_dashboard_config(fallback_id=provider_id), + ) + return ok(message="更新成功,已经实时生效~") + + +@router.delete("/providers/by-id") +async def delete_provider_by_id( + provider_id: str = Query(...), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.delete_provider(provider_id) + return ok(message="删除成功,已经实时生效。") + + +@router.patch("/providers/enabled") +async def set_provider_enabled_by_id( + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + provider_id = _required_text(payload.provider_id, "provider_id") + await service.set_provider_enabled(provider_id, bool(payload.enabled)) + return ok(message="更新成功,已经实时生效~") + + +@router.post("/providers/test") +async def test_provider_by_id( + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + provider_id = _required_text(payload.provider_id, "provider_id") + return ok(await service.test_provider(provider_id)) + + +@router.post("/providers/embedding-dimension") +async def get_embedding_dimension_by_id( + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + body = _model_dict(payload) + provider_id = _required_text(payload.provider_id, "provider_id") + return ok( + await service.get_embedding_dimension( + _provider_config_for_dimension(service, provider_id, body) + ) + ) + + +@router.patch("/providers/{provider_id:path}/enabled") +async def set_provider_enabled( + provider_id: str, + payload: EnabledPatch, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.set_provider_enabled(provider_id, payload.enabled) + return ok(message="更新成功,已经实时生效~") + + +@router.post("/providers/{provider_id:path}/test") +async def test_provider( + provider_id: str, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(await service.test_provider(provider_id)) + + +@router.post("/providers/{provider_id:path}/embedding-dimension") +async def get_embedding_dimension( + provider_id: str, + payload: ProviderConfigRequest | None = None, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + body = _model_dict(payload) + return ok( + await service.get_embedding_dimension( + _provider_config_for_dimension(service, provider_id, body) + ) + ) + + +@router.get("/providers/{provider_id:path}") +async def get_provider( + provider_id: str, + merged: bool = Query(default=False), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.get_provider(provider_id, merged=merged)) + + +@router.put("/providers/{provider_id:path}") +async def update_provider( + provider_id: str, + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.update_provider( + provider_id, + payload.to_dashboard_config(fallback_id=provider_id), + ) + return ok(message="更新成功,已经实时生效~") + + +@router.delete("/providers/{provider_id:path}") +async def delete_provider( + provider_id: str, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + await service.delete_provider(provider_id) + return ok(message="删除成功,已经实时生效。") + + +@legacy_router.get("/provider/template") +async def get_dashboard_alias_provider_template( + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + return ok(service.get_provider_schema()) + + +@legacy_router.get("/provider/list") +async def list_dashboard_alias_providers( + provider_type: str | None = Query(default=None), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + if not provider_type: + return _alias_error("缺少参数 provider_type") + providers = [] + seen_ids = set() + for item in provider_type.split(","): + for provider in service.list_providers(capability=item)["providers"]: + provider_id = provider.get("id") + if provider_id in seen_ids: + continue + seen_ids.add(provider_id) + providers.append(provider) + return ok(providers) + + +@legacy_router.post("/provider/new") +async def create_dashboard_alias_provider( + payload: ProviderConfigRequest, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + try: + await service.create_provider(payload.to_dashboard_config()) + return ok(message="新增服务提供商配置成功") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/provider/update") +async def update_dashboard_alias_provider( + request: Request, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + body = await _json_or_empty(request) + provider_id = body.get("id") + config = body.get("config") + if not provider_id or not isinstance(config, dict): + return _alias_error("参数错误") + try: + await service.update_provider( + str(provider_id), + ProviderConfigRequest(config=config).to_dashboard_config( + fallback_id=str(provider_id), + ), + ) + return ok(message="更新成功,已经实时生效~") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/provider/delete") +async def delete_dashboard_alias_provider( + request: Request, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + body = await _json_or_empty(request) + provider_id = body.get("id") + if not provider_id: + return _alias_error("缺少参数 id") + try: + await service.delete_provider(str(provider_id)) + return ok(message="删除成功,已经实时生效。") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.get("/provider/check_one") +async def check_dashboard_alias_provider( + id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + if not id: + return _alias_error("Missing provider_id parameter") + try: + return ok(await service.test_provider(id)) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.get("/provider/model_list") +async def list_dashboard_alias_provider_models( + provider_id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + try: + return ok(await service.list_provider_models_for_dashboard(provider_id)) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/provider/get_embedding_dim") +async def get_dashboard_alias_provider_embedding_dimension( + request: Request, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + body = await _json_or_empty(request) + try: + return ok(await service.get_embedding_dimension_from_dashboard_payload(body)) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.get("/provider_sources/models") +async def list_dashboard_alias_provider_source_models( + source_id: str | None = Query(default=None), + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + if not source_id: + return _alias_error("缺少参数 source_id") + try: + data = await service.list_provider_source_models(source_id) + data.pop("provider_source_id", None) + return ok(data) + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/provider_sources/update") +async def update_dashboard_alias_provider_source( + request: Request, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + body = await _json_or_empty(request) + source_id = body.get("original_id") + config = body.get("config") or body + if not source_id: + return _alias_error("缺少 original_id") + if not isinstance(config, dict): + return _alias_error("缺少或错误的配置数据") + try: + await service.upsert_provider_source( + str(source_id), + ProviderSourceRequest(config=config).to_dashboard_config( + fallback_id=str(source_id), + ), + ) + return ok(message="更新 provider source 成功") + except ValueError as exc: + return _alias_error(str(exc)) + + +@legacy_router.post("/provider_sources/delete") +async def delete_dashboard_alias_provider_source( + request: Request, + _auth: AuthContext = Depends(require_provider_scope), + service: ProviderConfigService = Depends(get_service), +): + body = await _json_or_empty(request) + source_id = body.get("id") + if not source_id: + return _alias_error("缺少 provider_source_id") + try: + await service.delete_provider_source(str(source_id)) + return ok(message="删除 provider source 成功") + except ValueError as exc: + return _alias_error(str(exc)) diff --git a/astrbot/dashboard/api/router.py b/astrbot/dashboard/api/router.py new file mode 100644 index 0000000000..b59ea62858 --- /dev/null +++ b/astrbot/dashboard/api/router.py @@ -0,0 +1,63 @@ +"""FastAPI HTTP API surface for the AstrBot dashboard.""" + +from fastapi import APIRouter + +from .api_keys import router as api_keys_router +from .auth import router as auth_router +from .backups import router as backups_router +from .bots import router as bots_router +from .chat import router as chat_router +from .chat_projects import router as chat_projects_router +from .config_profiles import router as config_profiles_router +from .conversations import router as conversations_router +from .cron import router as cron_router +from .extensions import router as extensions_router +from .files import router as files_router +from .knowledge_bases import router as knowledge_bases_router +from .live_chat import router as live_chat_router +from .logs import router as logs_router +from .open_api import router as open_api_router +from .personas import router as personas_router +from .platform import router as platform_router +from .plugins import router as plugins_router +from .providers import router as providers_router +from .sessions import router as sessions_router +from .skills import router as skills_router +from .stats import router as stats_router +from .subagents import router as subagents_router +from .t2i import router as t2i_router +from .tools import router as tools_router +from .updates import router as updates_router + +API_V1_PREFIX = "/api/v1" + + +def build_api_router() -> APIRouter: + router = APIRouter(prefix=API_V1_PREFIX) + router.include_router(auth_router) + router.include_router(backups_router) + router.include_router(config_profiles_router) + router.include_router(api_keys_router) + router.include_router(bots_router) + router.include_router(providers_router) + router.include_router(plugins_router) + router.include_router(chat_router) + router.include_router(chat_projects_router) + router.include_router(conversations_router) + router.include_router(cron_router) + router.include_router(files_router) + router.include_router(knowledge_bases_router) + router.include_router(extensions_router) + router.include_router(skills_router) + router.include_router(sessions_router) + router.include_router(subagents_router) + router.include_router(logs_router) + router.include_router(stats_router) + router.include_router(tools_router) + router.include_router(platform_router) + router.include_router(t2i_router) + router.include_router(personas_router) + router.include_router(updates_router) + router.include_router(open_api_router) + router.include_router(live_chat_router) + return router diff --git a/astrbot/dashboard/api/sessions.py b/astrbot/dashboard/api/sessions.py new file mode 100644 index 0000000000..42ede8ea5e --- /dev/null +++ b/astrbot/dashboard/api/sessions.py @@ -0,0 +1,414 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.core import logger +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import ( + BatchSessionProviderRequest, + BatchSessionServiceRequest, + SessionGroupRequest, + SessionRuleRequest, + UmoListRequest, +) +from astrbot.dashboard.services.session_management_service import ( + SessionManagementService, + SessionManagementServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Sessions"]) +legacy_router = APIRouter( + prefix="/api/session", + tags=["Dashboard Sessions"], + include_in_schema=False, +) + + +def get_service(request: Request) -> SessionManagementService: + return request.app.state.services.sessions + + +async def require_data_scope(request: Request) -> AuthContext: + return await require_scope(request, "data") + + +async def _json_or_empty(request: Request) -> dict: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _service_error(exc: SessionManagementServiceError) -> dict: + return error(str(exc)) + + +def _unexpected_error(prefix: str, exc: Exception) -> dict: + logger.error(f"{prefix}: {exc!s}") + return error(f"{prefix}: {exc!s}") + + +async def _run(operation, *, label: str) -> dict: + try: + result = await run_maybe_async(operation) + return ok(result) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error(label, exc) + + +async def _run_dashboard_json( + request: Request, + operation, + *, + label: str, +) -> dict: + body = await _json_or_empty(request) + return await _run(lambda: operation(body), label=label) + + +@router.get("/sessions") +async def list_sessions( + page: int = Query(1), + page_size: int = Query(20), + search: str = Query(""), + message_type: str = Query("all"), + platform: str = Query(""), + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok( + await service.list_all_umos_with_status( + page=page, + page_size=page_size, + search=search.strip(), + message_type=message_type, + platform=platform, + ) + ) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("获取会话状态列表失败", exc) + + +@router.get("/sessions/active-umos") +async def list_active_umos( + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok(await service.list_active_umos()) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("获取 UMO 列表失败", exc) + + +@router.get("/sessions/rules") +async def list_session_rules( + page: int = Query(1), + page_size: int = Query(10), + search: str = Query(""), + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok( + await service.list_session_rules( + page=page, + page_size=page_size, + search=search.strip(), + ) + ) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("获取规则列表失败", exc) + + +@router.post("/sessions/rules") +async def update_session_rule( + payload: SessionRuleRequest, + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok( + await service.update_session_rule(payload.model_dump(exclude_none=True)) + ) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("更新会话规则失败", exc) + + +@router.post("/sessions/rules/delete") +async def delete_session_rule( + payload: UmoListRequest, + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok( + await service.delete_session_rules(payload.model_dump(exclude_none=True)) + ) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("删除会话规则失败", exc) + + +@router.patch("/sessions/provider") +async def update_session_provider( + payload: BatchSessionProviderRequest, + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok( + await service.batch_update_provider(payload.model_dump(exclude_none=True)) + ) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("批量更新 Provider 失败", exc) + + +@router.patch("/sessions/service") +async def update_session_service( + payload: BatchSessionServiceRequest, + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok( + await service.batch_update_service(payload.model_dump(exclude_none=True)) + ) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("批量更新服务状态失败", exc) + + +@router.get("/session-groups") +async def list_session_groups( + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok(service.list_groups()) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("获取分组列表失败", exc) + + +@router.post("/session-groups") +async def create_session_group( + payload: SessionGroupRequest, + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok(service.create_group(payload.model_dump(exclude_none=True))) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("创建分组失败", exc) + + +@router.put("/session-groups/{group_id}") +async def update_session_group( + group_id: str, + payload: SessionGroupRequest, + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + body = payload.model_dump(exclude_none=True) + return ok(service.update_group({"group_id": group_id, **body})) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("更新分组失败", exc) + + +@router.delete("/session-groups/{group_id}") +async def delete_session_group( + group_id: str, + _auth: AuthContext = Depends(require_data_scope), + service: SessionManagementService = Depends(get_service), +): + try: + return ok(service.delete_group({"group_id": group_id})) + except SessionManagementServiceError as exc: + return _service_error(exc) + except Exception as exc: + return _unexpected_error("删除分组失败", exc) + + +@legacy_router.get("/list-rule") +async def list_dashboard_session_rules( + page: int = Query(1), + page_size: int = Query(10), + search: str = Query(""), + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run( + lambda: service.list_session_rules( + page=page, + page_size=page_size, + search=search.strip(), + ), + label="获取规则列表失败", + ) + + +@legacy_router.post("/update-rule") +async def update_dashboard_session_rule( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run_dashboard_json( + request, + service.update_session_rule, + label="更新会话规则失败", + ) + + +@legacy_router.post("/delete-rule") +async def delete_dashboard_session_rule( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run_dashboard_json( + request, + service.delete_session_rule, + label="删除会话规则失败", + ) + + +@legacy_router.post("/batch-delete-rule") +async def batch_delete_dashboard_session_rule( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run_dashboard_json( + request, + service.batch_delete_session_rule, + label="批量删除会话规则失败", + ) + + +@legacy_router.get("/active-umos") +async def list_dashboard_active_umos( + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run(service.list_active_umos, label="获取 UMO 列表失败") + + +@legacy_router.get("/list-all-with-status") +async def list_dashboard_umos_with_status( + page: int = Query(1), + page_size: int = Query(20), + search: str = Query(""), + message_type: str = Query("all"), + platform: str = Query(""), + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run( + lambda: service.list_all_umos_with_status( + page=page, + page_size=page_size, + search=search.strip(), + message_type=message_type, + platform=platform, + ), + label="获取会话状态列表失败", + ) + + +@legacy_router.post("/batch-update-service") +async def batch_update_dashboard_session_service( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run_dashboard_json( + request, + service.batch_update_service, + label="批量更新服务状态失败", + ) + + +@legacy_router.post("/batch-update-provider") +async def batch_update_dashboard_session_provider( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run_dashboard_json( + request, + service.batch_update_provider, + label="批量更新 Provider 失败", + ) + + +@legacy_router.get("/groups") +async def list_dashboard_session_groups( + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run(service.list_groups, label="获取分组列表失败") + + +@legacy_router.post("/group/create") +async def create_dashboard_session_group( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run_dashboard_json( + request, + service.create_group, + label="创建分组失败", + ) + + +@legacy_router.post("/group/update") +async def update_dashboard_session_group( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run_dashboard_json( + request, + service.update_group, + label="更新分组失败", + ) + + +@legacy_router.post("/group/delete") +async def delete_dashboard_session_group( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SessionManagementService = Depends(get_service), +): + return await _run_dashboard_json( + request, + service.delete_group, + label="删除分组失败", + ) diff --git a/astrbot/dashboard/api/skills.py b/astrbot/dashboard/api/skills.py new file mode 100644 index 0000000000..c8925c8e4a --- /dev/null +++ b/astrbot/dashboard/api/skills.py @@ -0,0 +1,558 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import FileResponse + +from astrbot.core import logger +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import error, ok +from astrbot.dashboard.schemas import ( + SkillByNameUpdateRequest, + SkillFileUpdateRequest, + SkillNeoRequest, + SkillUpdateRequest, +) +from astrbot.dashboard.services.skills_service import ( + SkillArchive, + SkillsOperationResult, + SkillsService, + SkillsServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope +from .multipart import multipart_parts, single_upload + +router = APIRouter(tags=["Skills"]) +legacy_router = APIRouter( + prefix="/api/skills", + tags=["Dashboard Skills"], + include_in_schema=False, +) + + +def get_service(request: Request) -> SkillsService: + return request.app.state.services.skills + + +async def require_skill_scope(request: Request) -> AuthContext: + return await require_scope(request, "skill") + + +async def _json_or_empty(request: Request) -> dict[str, Any]: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _required_text(value: object, name: str) -> str: + text = str(value or "").strip() + if not text: + raise ValueError(f"Missing key: {name}") + return text + + +def _model_dict(payload) -> dict[str, Any]: + if payload is None: + return {} + if hasattr(payload, "model_dump"): + return payload.model_dump(exclude_none=True) + return payload if isinstance(payload, dict) else {} + + +def _serialize_result(result: SkillsOperationResult): + if result.ok: + return ok(result.data, result.message) + return error(result.message or "", result.data) + + +async def _run(operation, *, trace: bool = True): + try: + result = await run_maybe_async(operation) + if isinstance(result, SkillsOperationResult): + return _serialize_result(result) + return ok(result) + except SkillsServiceError as exc: + return error(str(exc)) + except Exception as exc: + logger.error(str(exc), exc_info=trace) + return error(str(exc)) + + +def _archive_response(archive: SkillArchive): + return FileResponse( + archive.path, + filename=archive.filename, + media_type="application/zip", + ) + + +async def _download_skill(service: SkillsService, name: str): + try: + return _archive_response(service.prepare_skill_archive(name)) + except SkillsServiceError as exc: + return error(str(exc)) + except Exception as exc: + logger.error(str(exc), exc_info=True) + return error(str(exc)) + + +@router.get("/skills") +async def list_skills( + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(service.get_skills) + + +@router.post("/skills") +async def upload_skill( + request: Request, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + async def _operation(): + return await service.upload_skill(await single_upload(request)) + + return await _run(_operation) + + +@router.post("/skills/batch") +async def upload_skills_batch( + request: Request, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + async def _operation(): + _, files = await multipart_parts(request) + return await service.batch_upload_skills(files.getlist("files")) + + return await _run(_operation) + + +@router.patch("/skills/by-name") +async def update_skill_by_name( + payload: SkillByNameUpdateRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + skill_name = _required_text(payload.skill_name, "skill_name") + return await _run( + lambda: service.update_skill( + { + "name": skill_name, + "active": payload.active_value(), + } + ) + ) + + +@router.delete("/skills/by-name") +async def delete_skill_by_name( + skill_name: str, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.delete_skill({"name": skill_name})) + + +@router.get("/skills/archive") +async def download_skill_by_name( + skill_name: str, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _download_skill(service, skill_name) + + +@router.get("/skills/files") +async def list_skill_files_by_name( + request: Request, + skill_name: str, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run( + lambda: service.list_skill_files( + skill_name, + request.query_params.get("path", ""), + ) + ) + + +@router.get("/skills/file") +async def get_skill_file_by_name( + skill_name: str, + path: str, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.get_skill_file(skill_name, path)) + + +@router.put("/skills/file") +async def update_skill_file_by_name( + payload: SkillFileUpdateRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + skill_name = _required_text(payload.skill_name, "skill_name") + path = _required_text(payload.path, "path") + return await _run( + lambda: service.update_skill_file( + { + "name": skill_name, + "path": path, + "content": payload.content, + } + ) + ) + + +@router.get("/skills/{skill_name:path}/archive") +async def download_skill( + skill_name: str, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _download_skill(service, skill_name) + + +@router.get("/skills/{skill_name:path}/files") +async def list_skill_files( + skill_name: str, + request: Request, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run( + lambda: service.list_skill_files( + skill_name, + request.query_params.get("path", ""), + ) + ) + + +@router.get("/skills/{skill_name:path}/files/{file_path:path}") +async def get_skill_file( + skill_name: str, + file_path: str, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.get_skill_file(skill_name, file_path)) + + +@router.put("/skills/{skill_name:path}/files/{file_path:path}") +async def update_skill_file( + skill_name: str, + file_path: str, + request: Request, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + content = (await request.body()).decode("utf-8") + return await _run( + lambda: service.update_skill_file( + {"name": skill_name, "path": file_path, "content": content} + ) + ) + + +@router.patch("/skills/{skill_name:path}") +async def update_skill( + skill_name: str, + payload: SkillUpdateRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run( + lambda: service.update_skill( + { + "name": skill_name, + "active": payload.active_value(), + } + ) + ) + + +@router.delete("/skills/{skill_name:path}") +async def delete_skill( + skill_name: str, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.delete_skill({"name": skill_name})) + + +@router.get("/skills/neo/candidates") +async def list_neo_skill_candidates( + request: Request, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run( + service.get_neo_candidates( + dict(request.query_params), + ) + ) + + +@router.get("/skills/neo/releases") +async def list_neo_skill_releases( + request: Request, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run( + service.get_neo_releases( + dict(request.query_params), + ) + ) + + +@router.get("/skills/neo/payload") +async def get_neo_skill_payload( + request: Request, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(service.get_neo_payload(dict(request.query_params))) + + +@router.post("/skills/neo/evaluate") +async def evaluate_neo_skill_candidate( + payload: SkillNeoRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.evaluate_neo_candidate(_model_dict(payload))) + + +@router.post("/skills/neo/promote") +async def promote_neo_skill_candidate( + payload: SkillNeoRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.promote_neo_candidate(_model_dict(payload))) + + +@router.post("/skills/neo/rollback") +async def rollback_neo_skill_release( + payload: SkillNeoRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.rollback_neo_release(_model_dict(payload))) + + +@router.post("/skills/neo/sync") +async def sync_neo_skill_release( + payload: SkillNeoRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.sync_neo_release(_model_dict(payload))) + + +@router.post("/skills/neo/candidates/delete") +async def delete_neo_skill_candidate( + payload: SkillNeoRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.delete_neo_candidate(_model_dict(payload))) + + +@router.post("/skills/neo/releases/delete") +async def delete_neo_skill_release( + payload: SkillNeoRequest, + _auth: AuthContext = Depends(require_skill_scope), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.delete_neo_release(_model_dict(payload))) + + +@legacy_router.get("") +async def list_dashboard_skills( + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + return await _run(service.get_skills) + + +@legacy_router.post("/upload") +async def upload_dashboard_skill( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + async def _operation(): + return await service.upload_skill(await single_upload(request)) + + return await _run(_operation) + + +@legacy_router.post("/batch-upload") +async def batch_upload_dashboard_skills( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + async def _operation(): + _, files = await multipart_parts(request) + return await service.batch_upload_skills(files.getlist("files")) + + return await _run(_operation) + + +@legacy_router.get("/download") +async def download_dashboard_skill( + name: str, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + return await _download_skill(service, name) + + +@legacy_router.get("/files") +async def list_dashboard_skill_files( + request: Request, + name: str, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + return await _run( + lambda: service.list_skill_files(name, request.query_params.get("path", "")) + ) + + +@legacy_router.get("/file") +async def get_dashboard_skill_file( + name: str, + path: str = "SKILL.md", + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + return await _run(lambda: service.get_skill_file(name, path)) + + +@legacy_router.post("/file") +async def update_dashboard_skill_file( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_skill_file(body)) + + +@legacy_router.post("/update") +async def update_dashboard_skill( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_skill(body)) + + +@legacy_router.post("/delete") +async def delete_dashboard_skill( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.delete_skill(body)) + + +@legacy_router.get("/neo/candidates") +async def list_dashboard_neo_skill_candidates( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + return await _run(service.get_neo_candidates(dict(request.query_params))) + + +@legacy_router.get("/neo/releases") +async def list_dashboard_neo_skill_releases( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + return await _run(service.get_neo_releases(dict(request.query_params))) + + +@legacy_router.get("/neo/payload") +async def get_dashboard_neo_skill_payload( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + return await _run(service.get_neo_payload(dict(request.query_params))) + + +@legacy_router.post("/neo/evaluate") +async def evaluate_dashboard_neo_skill_candidate( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.evaluate_neo_candidate(body)) + + +@legacy_router.post("/neo/promote") +async def promote_dashboard_neo_skill_candidate( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.promote_neo_candidate(body)) + + +@legacy_router.post("/neo/rollback") +async def rollback_dashboard_neo_skill_release( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.rollback_neo_release(body)) + + +@legacy_router.post("/neo/sync") +async def sync_dashboard_neo_skill_release( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.sync_neo_release(body)) + + +@legacy_router.post("/neo/delete-candidate") +async def delete_dashboard_neo_skill_candidate( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.delete_neo_candidate(body)) + + +@legacy_router.post("/neo/delete-release") +async def delete_dashboard_neo_skill_release( + request: Request, + _username: str = Depends(require_dashboard_user), + service: SkillsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.delete_neo_release(body)) diff --git a/astrbot/dashboard/api/static_files.py b/astrbot/dashboard/api/static_files.py new file mode 100644 index 0000000000..02aaa3ca69 --- /dev/null +++ b/astrbot/dashboard/api/static_files.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import FileResponse, PlainTextResponse + +from astrbot.dashboard.services.static_file_service import StaticFileService + +router = APIRouter(include_in_schema=False) +service = StaticFileService() + + +def _static_folder(request: Request) -> str | None: + return getattr(request.app.state, "dashboard_static_folder", None) + + +def _not_found_response() -> PlainTextResponse: + return PlainTextResponse(service.get_not_found_message(), status_code=404) + + +async def serve_index(request: Request): + index_file = service.resolve_index_file(_static_folder(request)) + if index_file is None: + return _not_found_response() + return FileResponse(index_file) + + +async def serve_static_file(request: Request, static_path: str): + if request.url.path.startswith("/api"): + raise HTTPException(status_code=404) + + file_path = service.resolve_static_file(_static_folder(request), static_path) + if file_path is None: + return _not_found_response() + return FileResponse(file_path) + + +for index_route in service.list_index_routes(): + router.add_api_route(index_route, serve_index, methods=["GET"]) + +router.add_api_route("/{static_path:path}", serve_static_file, methods=["GET"]) diff --git a/astrbot/dashboard/api/stats.py b/astrbot/dashboard/api/stats.py new file mode 100644 index 0000000000..6bb87817b0 --- /dev/null +++ b/astrbot/dashboard/api/stats.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import GhProxyTestRequest, StorageCleanupRequest +from astrbot.dashboard.services.stat_service import StatService, StatServiceError + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["System Stats"]) +legacy_router = APIRouter( + prefix="/api/stat", + tags=["Dashboard System Stats"], + include_in_schema=False, +) + + +async def require_system_scope(request: Request) -> AuthContext: + return await require_scope(request, "system") + + +def get_service(request: Request) -> StatService: + return request.app.state.services.stats + + +def _raise_stat_error(exc: StatServiceError) -> None: + raise ApiError(str(exc)) from exc + + +async def _run(operation): + try: + result = await run_maybe_async(operation) + return ok(result) + except StatServiceError as exc: + _raise_stat_error(exc) + + +def _parse_int(value: object, default: int, name: str) -> int: + if value is None: + return default + if not isinstance(value, int | float | str | bytes | bytearray): + raise ApiError(f"{name} must be an integer") + try: + return int(value) + except (TypeError, ValueError) as exc: + raise ApiError(f"{name} must be an integer") from exc + + +@router.get("/stats") +async def get_stats( + offset_sec: int = Query(default=86400), + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(service.get_stat(offset_sec)) + + +@router.get("/stats/provider-tokens") +async def get_provider_token_stats( + days: int = Query(default=1), + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(service.get_provider_token_stats(days)) + + +@router.get("/stats/version") +async def get_version( + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(service.get_version()) + + +@router.get("/stats/first-notice") +async def get_first_notice( + locale: str | None = None, + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(lambda: service.get_first_notice(locale)) + + +@router.post("/stats/ghproxy/test") +async def test_ghproxy_connection( + payload: GhProxyTestRequest, + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(service.test_ghproxy_connection(payload.proxy_url)) + + +@router.get("/changelogs") +async def list_changelog_versions( + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(service.list_changelog_versions) + + +@router.get("/changelogs/{version}") +async def get_changelog( + version: str, + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(lambda: service.get_changelog(version)) + + +@router.get("/stats/start-time") +async def get_start_time( + service: StatService = Depends(get_service), +): + return await _run(service.get_start_time) + + +@router.get("/stats/storage") +async def get_storage_status( + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(service.get_storage_status()) + + +@router.post("/stats/storage/cleanup") +async def cleanup_storage( + payload: StorageCleanupRequest, + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(service.cleanup_storage(payload.target)) + + +@router.post("/system/restart") +async def restart_system( + _auth: AuthContext = Depends(require_system_scope), + service: StatService = Depends(get_service), +): + return await _run(service.restart_core()) + + +@legacy_router.get("/get") +async def get_dashboard_stats( + offset_sec: int | None = Query(default=86400), + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(service.get_stat(_parse_int(offset_sec, 86400, "offset_sec"))) + + +@legacy_router.get("/provider-tokens") +async def get_dashboard_provider_token_stats( + days: int | None = Query(default=1), + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(service.get_provider_token_stats(_parse_int(days, 1, "days"))) + + +@legacy_router.get("/version") +async def get_dashboard_version( + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(service.get_version()) + + +@legacy_router.get("/start-time") +async def get_dashboard_start_time( + service: StatService = Depends(get_service), +): + return await _run(service.get_start_time) + + +@legacy_router.post("/restart-core") +async def restart_dashboard_core( + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(service.restart_core()) + + +@legacy_router.post("/test-ghproxy-connection") +async def test_dashboard_ghproxy_connection( + payload: GhProxyTestRequest, + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(service.test_ghproxy_connection(payload.proxy_url)) + + +@legacy_router.get("/changelog") +async def get_dashboard_changelog( + version: str | None = None, + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(lambda: service.get_changelog(version)) + + +@legacy_router.get("/changelog/list") +async def list_dashboard_changelog_versions( + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(service.list_changelog_versions) + + +@legacy_router.get("/first-notice") +async def get_dashboard_first_notice( + locale: str | None = None, + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(lambda: service.get_first_notice(locale)) + + +@legacy_router.get("/storage") +async def get_dashboard_storage_status( + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + return await _run(service.get_storage_status()) + + +@legacy_router.post("/storage/cleanup") +async def cleanup_dashboard_storage( + payload: StorageCleanupRequest | None = None, + _username: str = Depends(require_dashboard_user), + service: StatService = Depends(get_service), +): + target = payload.target if payload is not None else "all" + return await _run(service.cleanup_storage(target)) diff --git a/astrbot/dashboard/api/subagents.py b/astrbot/dashboard/api/subagents.py new file mode 100644 index 0000000000..c26bf4ad57 --- /dev/null +++ b/astrbot/dashboard/api/subagents.py @@ -0,0 +1,107 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Request + +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import SubAgentConfigRequest +from astrbot.dashboard.services.subagent_service import ( + SubAgentService, + SubAgentServiceError, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Subagents"]) +legacy_router = APIRouter( + prefix="/api/subagent", + tags=["Dashboard Subagents"], + include_in_schema=False, +) + + +async def require_config_scope(request: Request) -> AuthContext: + return await require_scope(request, "config") + + +def get_service(request: Request) -> SubAgentService: + return request.app.state.services.subagents + + +def _payload_dict(payload: SubAgentConfigRequest) -> dict: + return payload.model_dump(exclude_none=True) + + +def _raise_subagent_error(exc: SubAgentServiceError) -> None: + raise ApiError(str(exc)) from exc + + +async def _get_config(service: SubAgentService): + try: + return ok(service.get_config()) + except SubAgentServiceError as exc: + _raise_subagent_error(exc) + + +async def _update_config(payload: SubAgentConfigRequest, service: SubAgentService): + try: + await service.update_config(_payload_dict(payload)) + return ok(message="保存成功") + except SubAgentServiceError as exc: + _raise_subagent_error(exc) + + +async def _get_available_tools(service: SubAgentService): + try: + return ok(service.get_available_tools()) + except SubAgentServiceError as exc: + _raise_subagent_error(exc) + + +@router.get("/subagents/config") +async def get_subagent_config( + _auth: AuthContext = Depends(require_config_scope), + service: SubAgentService = Depends(get_service), +): + return await _get_config(service) + + +@router.put("/subagents/config") +async def update_subagent_config( + payload: SubAgentConfigRequest, + _auth: AuthContext = Depends(require_config_scope), + service: SubAgentService = Depends(get_service), +): + return await _update_config(payload, service) + + +@router.get("/subagents/available-tools") +async def get_subagent_tools( + _auth: AuthContext = Depends(require_config_scope), + service: SubAgentService = Depends(get_service), +): + return await _get_available_tools(service) + + +@legacy_router.get("/config") +async def get_dashboard_subagent_config( + _username: str = Depends(require_dashboard_user), + service: SubAgentService = Depends(get_service), +): + return await _get_config(service) + + +@legacy_router.post("/config") +async def update_dashboard_subagent_config( + payload: SubAgentConfigRequest, + _username: str = Depends(require_dashboard_user), + service: SubAgentService = Depends(get_service), +): + return await _update_config(payload, service) + + +@legacy_router.get("/available-tools") +async def get_dashboard_subagent_tools( + _username: str = Depends(require_dashboard_user), + service: SubAgentService = Depends(get_service), +): + return await _get_available_tools(service) diff --git a/astrbot/dashboard/api/t2i.py b/astrbot/dashboard/api/t2i.py new file mode 100644 index 0000000000..3edc66fd73 --- /dev/null +++ b/astrbot/dashboard/api/t2i.py @@ -0,0 +1,235 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Request +from fastapi.responses import JSONResponse + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import T2iActiveTemplateRequest, T2iTemplateRequest +from astrbot.dashboard.services.t2i_service import T2iService, T2iServiceError + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Text To Image"]) +legacy_router = APIRouter( + prefix="/api/t2i", + tags=["Dashboard Text To Image"], + include_in_schema=False, +) + + +def get_service(request: Request) -> T2iService: + return request.app.state.services.t2i + + +async def require_config_scope(request: Request) -> AuthContext: + return await require_scope(request, "config") + + +async def _json_or_empty(request: Request) -> dict: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _raise_t2i_error(exc: T2iServiceError) -> None: + raise ApiError(str(exc), status_code=exc.status_code) from exc + + +def _response( + data=None, + *, + message: str | None = None, + status_code: int = 200, +): + payload = ok(data, message) + if status_code == 200: + return payload + return JSONResponse(payload, status_code=status_code) + + +async def _run( + operation, + *, + message: str | None = None, + status_code: int = 200, + result_as_message: bool = False, +): + try: + result = await run_maybe_async(operation) + if isinstance(result, tuple): + payload, result_message = result + return _response(payload, message=result_message) + if result_as_message: + return _response(message=str(result), status_code=status_code) + return _response(result, message=message, status_code=status_code) + except T2iServiceError as exc: + _raise_t2i_error(exc) + + +@router.get("/t2i/templates") +async def list_t2i_templates( + _auth: AuthContext = Depends(require_config_scope), + service: T2iService = Depends(get_service), +): + return await _run(service.list_templates) + + +@router.post("/t2i/templates") +async def create_t2i_template( + payload: T2iTemplateRequest, + _auth: AuthContext = Depends(require_config_scope), + service: T2iService = Depends(get_service), +): + return await _run( + lambda: service.create_template(payload.name, payload.content), + message="Template created successfully.", + status_code=201, + ) + + +@router.get("/t2i/templates/active") +async def get_active_t2i_template( + _auth: AuthContext = Depends(require_config_scope), + service: T2iService = Depends(get_service), +): + return await _run(service.get_active_template) + + +@router.put("/t2i/templates/active") +async def set_active_t2i_template( + payload: T2iActiveTemplateRequest, + _auth: AuthContext = Depends(require_config_scope), + service: T2iService = Depends(get_service), +): + return await _run( + lambda: service.set_active_template(payload.name), + result_as_message=True, + ) + + +@router.post("/t2i/templates/default/reset") +async def reset_default_t2i_template( + _auth: AuthContext = Depends(require_config_scope), + service: T2iService = Depends(get_service), +): + return await _run( + service.reset_default_template, + result_as_message=True, + ) + + +@router.get("/t2i/templates/{name:path}") +async def get_t2i_template( + name: str, + _auth: AuthContext = Depends(require_config_scope), + service: T2iService = Depends(get_service), +): + return await _run(lambda: service.get_template(name)) + + +@router.put("/t2i/templates/{name:path}") +async def update_t2i_template( + name: str, + payload: T2iTemplateRequest, + _auth: AuthContext = Depends(require_config_scope), + service: T2iService = Depends(get_service), +): + return await _run(lambda: service.update_template(name, payload.content)) + + +@router.delete("/t2i/templates/{name:path}") +async def delete_t2i_template( + name: str, + _auth: AuthContext = Depends(require_config_scope), + service: T2iService = Depends(get_service), +): + return await _run( + lambda: service.delete_template(name), + message="Template deleted successfully.", + ) + + +@legacy_router.get("/templates") +async def list_dashboard_t2i_templates( + _username: str = Depends(require_dashboard_user), + service: T2iService = Depends(get_service), +): + return await _run(service.list_templates) + + +@legacy_router.get("/templates/active") +async def get_dashboard_active_t2i_template( + _username: str = Depends(require_dashboard_user), + service: T2iService = Depends(get_service), +): + return await _run(service.get_active_template) + + +@legacy_router.post("/templates/create") +async def create_dashboard_t2i_template( + request: Request, + _username: str = Depends(require_dashboard_user), + service: T2iService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.create_template(body.get("name"), body.get("content")), + message="Template created successfully.", + status_code=201, + ) + + +@legacy_router.post("/templates/reset_default") +async def reset_dashboard_default_t2i_template( + _username: str = Depends(require_dashboard_user), + service: T2iService = Depends(get_service), +): + return await _run(service.reset_default_template, result_as_message=True) + + +@legacy_router.post("/templates/set_active") +async def set_dashboard_active_t2i_template( + request: Request, + _username: str = Depends(require_dashboard_user), + service: T2iService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.set_active_template(body.get("name")), + result_as_message=True, + ) + + +@legacy_router.get("/templates/{name:path}") +async def get_dashboard_t2i_template( + name: str, + _username: str = Depends(require_dashboard_user), + service: T2iService = Depends(get_service), +): + return await _run(lambda: service.get_template(name)) + + +@legacy_router.put("/templates/{name:path}") +async def update_dashboard_t2i_template( + name: str, + request: Request, + _username: str = Depends(require_dashboard_user), + service: T2iService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run(lambda: service.update_template(name, body.get("content"))) + + +@legacy_router.delete("/templates/{name:path}") +async def delete_dashboard_t2i_template( + name: str, + _username: str = Depends(require_dashboard_user), + service: T2iService = Depends(get_service), +): + return await _run( + lambda: service.delete_template(name), + message="Template deleted successfully.", + ) diff --git a/astrbot/dashboard/api/tools.py b/astrbot/dashboard/api/tools.py new file mode 100644 index 0000000000..f71de53f2b --- /dev/null +++ b/astrbot/dashboard/api/tools.py @@ -0,0 +1,429 @@ +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Query, Request + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.responses import ApiError, ok +from astrbot.dashboard.schemas import ( + McpServerByNameRequest, + McpServerRequest, + ModelScopeSyncRequest, + ToolEnabledRequest, + ToolPermissionRequest, +) +from astrbot.dashboard.services.tools_service import ToolsService, ToolsServiceError + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Extension Components"]) +legacy_router = APIRouter( + prefix="/api", + tags=["Dashboard Extension Components"], + include_in_schema=False, +) + + +def get_service(request: Request) -> ToolsService: + return request.app.state.services.tools + + +async def require_tool_scope(request: Request) -> AuthContext: + return await require_scope(request, "tool") + + +async def _json_or_empty(request: Request) -> dict[str, Any]: + try: + data = await request.json() + except Exception: + return {} + return data if isinstance(data, dict) else {} + + +def _required_text(value: object, name: str) -> str: + text = str(value or "").strip() + if not text: + raise ApiError(f"Missing key: {name}") + return text + + +def _model_dict(payload: McpServerRequest | McpServerByNameRequest) -> dict[str, Any]: + return payload.model_dump(exclude_none=True) + + +def _normalize_server_config(body: dict[str, Any], id_key: str) -> dict[str, Any]: + config = body.get("config") + if isinstance(config, dict): + normalized = dict(config) + else: + normalized = { + key: value + for key, value in body.items() + if key not in {id_key, "config", "enabled", "mcp_server_config"} + } + if "enabled" in body and "active" not in normalized: + normalized["active"] = body["enabled"] + return normalized + + +def _server_name_from_body(body: dict[str, Any]) -> str: + return _required_text(body.get("server_name") or body.get("name"), "server_name") + + +def _test_config_body( + service: ToolsService, + server_name: str, + body: dict[str, Any], +) -> dict[str, Any]: + config = body.get("mcp_server_config") or body.get("config") + if isinstance(config, dict): + return dict(config) + + stored_config = service.get_mcp_server_config(server_name) + if stored_config is not None: + return stored_config + + return {"name": server_name} + + +def _raise_tools_error(exc: ToolsServiceError) -> None: + raise ApiError(str(exc)) from exc + + +async def _run( + operation, *, result_as_message: bool = False, message: str | None = None +): + try: + result = await run_maybe_async(operation) + if result_as_message: + return ok(None, str(result)) + return ok(result, message) + except ToolsServiceError as exc: + _raise_tools_error(exc) + + +async def _toggle_tool( + tool_id: str, + enabled: bool, + service: ToolsService, +): + return await _run( + lambda: service.toggle_tool({"name": tool_id, "activate": enabled}), + result_as_message=True, + ) + + +async def _update_tool_permission( + tool_id: str, + permission: str, + service: ToolsService, +): + return await _run( + lambda: service.update_tool_permission( + {"name": tool_id, "permission": permission} + ), + result_as_message=True, + ) + + +async def _create_mcp_server(body: dict[str, Any], service: ToolsService): + if "enabled" in body and "active" not in body: + body["active"] = body.pop("enabled") + return await _run( + lambda: service.add_mcp_server(body), + result_as_message=True, + ) + + +async def _update_mcp_server( + server_name: str, + body: dict[str, Any], + service: ToolsService, +): + config = _normalize_server_config(body, "server_name") + config.setdefault("name", server_name) + config.setdefault("oldName", server_name) + return await _run( + lambda: service.update_mcp_server(config), + result_as_message=True, + ) + + +async def _delete_mcp_server(server_name: str, service: ToolsService): + return await _run( + lambda: service.delete_mcp_server({"name": server_name}), + result_as_message=True, + ) + + +async def _test_mcp_server( + server_name: str, + body: dict[str, Any], + service: ToolsService, +): + config = _test_config_body(service, server_name, body) + return await _run( + lambda: service.test_mcp_connection( + {"name": server_name, "mcp_server_config": config} + ), + message="🎉 MCP server is available!", + ) + + +async def _sync_modelscope_mcp_servers( + access_token: str, + service: ToolsService, +): + return await _run( + lambda: service.sync_provider( + { + "name": "modelscope", + "access_token": access_token, + } + ), + result_as_message=True, + ) + + +@router.get("/tools") +async def list_tools( + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + return await _run(service.get_tool_list) + + +@router.patch("/tools/{tool_id:path}/enabled") +async def set_tool_enabled( + tool_id: str, + payload: ToolEnabledRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + return await _toggle_tool(tool_id, payload.enabled, service) + + +@router.patch("/tools/{tool_id:path}/permission") +async def set_tool_permission( + tool_id: str, + payload: ToolPermissionRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + return await _update_tool_permission(tool_id, payload.permission, service) + + +@router.get("/mcp/servers") +async def list_mcp_servers( + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + return await _run(service.get_mcp_servers) + + +@router.post("/mcp/servers") +async def create_mcp_server( + payload: McpServerRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + return await _create_mcp_server(_model_dict(payload), service) + + +@router.put("/mcp/servers/by-name") +async def update_mcp_server_by_name( + payload: McpServerByNameRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + body = _model_dict(payload) + return await _update_mcp_server(payload.server_name, body, service) + + +@router.delete("/mcp/servers/by-name") +async def delete_mcp_server_by_name( + server_name: str = Query(...), + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + return await _delete_mcp_server(server_name, service) + + +@router.patch("/mcp/servers/enabled") +async def set_mcp_server_enabled_by_name( + payload: McpServerByNameRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + body = _model_dict(payload) + return await _update_mcp_server(payload.server_name, body, service) + + +@router.post("/mcp/servers/test") +async def test_mcp_server_by_name( + payload: McpServerByNameRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + body = _model_dict(payload) + return await _test_mcp_server(payload.server_name, body, service) + + +@router.patch("/mcp/servers/{server_name:path}/enabled") +async def set_mcp_server_enabled( + server_name: str, + payload: ToolEnabledRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + return await _update_mcp_server( + server_name, + {"server_name": server_name, "enabled": payload.enabled}, + service, + ) + + +@router.post("/mcp/servers/{server_name:path}/test") +async def test_mcp_server( + server_name: str, + payload: McpServerRequest | None = None, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + body = _model_dict(payload) if payload is not None else {} + return await _test_mcp_server(server_name, body, service) + + +@router.put("/mcp/servers/{server_name:path}") +async def update_mcp_server( + server_name: str, + payload: McpServerRequest, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + body = _model_dict(payload) + return await _update_mcp_server(server_name, body, service) + + +@router.delete("/mcp/servers/{server_name:path}") +async def delete_mcp_server( + server_name: str, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + return await _delete_mcp_server(server_name, service) + + +@router.post("/mcp/providers/modelscope/sync") +async def sync_modelscope_mcp_servers( + payload: ModelScopeSyncRequest | None = None, + _auth: AuthContext = Depends(require_tool_scope), + service: ToolsService = Depends(get_service), +): + access_token = payload.access_token if payload is not None else "" + return await _sync_modelscope_mcp_servers(access_token or "", service) + + +@legacy_router.get("/tools/list") +async def list_dashboard_tools( + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + return await _run(service.get_tool_list) + + +@legacy_router.post("/tools/toggle-tool") +async def toggle_dashboard_tool( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + body = await _json_or_empty(request) + tool_id = _required_text(body.get("name"), "name") + return await _toggle_tool(tool_id, bool(body.get("activate")), service) + + +@legacy_router.post("/tools/permission") +async def update_dashboard_tool_permission( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + body = await _json_or_empty(request) + tool_id = _required_text(body.get("name"), "name") + return await _update_tool_permission( + tool_id, + str(body.get("permission") or ""), + service, + ) + + +@legacy_router.get("/tools/mcp/servers") +async def list_dashboard_mcp_servers( + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + return await _run(service.get_mcp_servers) + + +@legacy_router.post("/tools/mcp/add") +async def add_dashboard_mcp_server( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + return await _create_mcp_server(await _json_or_empty(request), service) + + +@legacy_router.post("/tools/mcp/update") +async def update_dashboard_mcp_server( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _update_mcp_server(_server_name_from_body(body), body, service) + + +@legacy_router.post("/tools/mcp/delete") +async def delete_dashboard_mcp_server( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _delete_mcp_server(_required_text(body.get("name"), "name"), service) + + +@legacy_router.post("/tools/mcp/test") +async def test_dashboard_mcp_connection( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + body = await _json_or_empty(request) + server_name = str(body.get("name") or "") + config = body.get("mcp_server_config") or body.get("config") or body + return await _run( + lambda: service.test_mcp_connection( + { + "name": server_name, + "mcp_server_config": config, + } + ), + message="🎉 MCP server is available!", + ) + + +@legacy_router.post("/tools/mcp/sync-provider") +async def sync_dashboard_mcp_provider( + request: Request, + _username: str = Depends(require_dashboard_user), + service: ToolsService = Depends(get_service), +): + body = await _json_or_empty(request) + return await _run( + lambda: service.sync_provider(body), + result_as_message=True, + ) diff --git a/astrbot/dashboard/api/updates.py b/astrbot/dashboard/api/updates.py new file mode 100644 index 0000000000..05b8a4d251 --- /dev/null +++ b/astrbot/dashboard/api/updates.py @@ -0,0 +1,194 @@ +from __future__ import annotations + +from fastapi import APIRouter, Depends, Query, Request +from fastapi.responses import JSONResponse + +from astrbot.dashboard.async_utils import run_maybe_async +from astrbot.dashboard.schemas import MigrationRequest, PipInstallRequest, UpdateRequest +from astrbot.dashboard.services.update_service import ( + UpdateService, + UpdateServiceError, + UpdateServiceResult, +) + +from .auth import AuthContext, require_dashboard_user, require_scope + +router = APIRouter(tags=["Updates"]) +legacy_router = APIRouter( + prefix="/api/update", + tags=["Dashboard Updates"], + include_in_schema=False, +) + + +def get_service(request: Request) -> UpdateService: + return request.app.state.services.updates + + +async def require_system_scope(request: Request) -> AuthContext: + return await require_scope(request, "system") + + +def _model_dict(payload) -> dict: + return payload.model_dump(exclude_none=True) + + +def _result_payload(result: UpdateServiceResult) -> dict: + if result.status == "success": + return { + "status": "success", + "message": result.message, + "data": result.data, + } + return { + "status": "ok", + "message": result.message, + "data": {} if result.data is None else result.data, + } + + +def _service_response(result: UpdateServiceResult) -> JSONResponse: + return JSONResponse( + _result_payload(result), + status_code=200, + headers=result.headers or None, + ) + + +def _service_error(exc: UpdateServiceError) -> JSONResponse: + return JSONResponse( + {"status": "error", "message": str(exc), "data": None}, + status_code=200, + ) + + +async def _run(operation) -> JSONResponse: + try: + result = await run_maybe_async(operation) + return _service_response(result) + except UpdateServiceError as exc: + return _service_error(exc) + + +@router.get("/updates/check") +async def check_updates( + update_type: str | None = Query(default=None, alias="type"), + _auth: AuthContext = Depends(require_system_scope), + service: UpdateService = Depends(get_service), +): + return await _run(lambda: service.check_update(update_type)) + + +@legacy_router.get("/check") +async def check_dashboard_updates( + update_type: str | None = Query(default=None, alias="type"), + _username: str = Depends(require_dashboard_user), + service: UpdateService = Depends(get_service), +): + return await _run(lambda: service.check_update(update_type)) + + +@router.get("/updates/releases") +async def update_releases( + _auth: AuthContext = Depends(require_system_scope), + service: UpdateService = Depends(get_service), +): + return await _run(service.get_releases) + + +@legacy_router.get("/releases") +async def dashboard_update_releases( + _username: str = Depends(require_dashboard_user), + service: UpdateService = Depends(get_service), +): + return await _run(service.get_releases) + + +@router.get("/updates/progress/{task_id}") +async def update_progress( + task_id: str, + _auth: AuthContext = Depends(require_system_scope), + service: UpdateService = Depends(get_service), +): + return await _run(lambda: service.get_update_progress(task_id)) + + +@legacy_router.get("/progress") +async def dashboard_update_progress( + progress_id: str | None = Query(default=None, alias="id"), + _username: str = Depends(require_dashboard_user), + service: UpdateService = Depends(get_service), +): + return await _run(lambda: service.get_update_progress(progress_id or "")) + + +@router.post("/updates/core") +async def update_core( + payload: UpdateRequest, + _auth: AuthContext = Depends(require_system_scope), + service: UpdateService = Depends(get_service), +): + return await _run(lambda: service.update_project(_model_dict(payload))) + + +@legacy_router.post("/do") +async def update_dashboard_core( + payload: UpdateRequest, + _username: str = Depends(require_dashboard_user), + service: UpdateService = Depends(get_service), +): + return await _run(lambda: service.update_project(_model_dict(payload))) + + +@router.post("/updates/dashboard") +async def update_dashboard( + _auth: AuthContext = Depends(require_system_scope), + service: UpdateService = Depends(get_service), +): + return await _run(service.update_dashboard) + + +@legacy_router.post("/dashboard") +async def update_dashboard_assets( + _username: str = Depends(require_dashboard_user), + service: UpdateService = Depends(get_service), +): + return await _run(service.update_dashboard) + + +@router.post("/pip/install") +async def install_pip_package( + payload: PipInstallRequest, + _auth: AuthContext = Depends(require_system_scope), + service: UpdateService = Depends(get_service), +): + return await _run(lambda: service.install_pip_package(_model_dict(payload))) + + +@legacy_router.post("/pip-install") +async def install_dashboard_pip_package( + payload: PipInstallRequest, + _username: str = Depends(require_dashboard_user), + service: UpdateService = Depends(get_service), +): + return await _run(lambda: service.install_pip_package(_model_dict(payload))) + + +@router.post("/migrations") +async def run_migration( + payload: MigrationRequest | None = None, + _auth: AuthContext = Depends(require_system_scope), + service: UpdateService = Depends(get_service), +): + body = _model_dict(payload) if payload is not None else {} + return await _run(lambda: service.do_migration_v4(body)) + + +@legacy_router.post("/migration") +async def run_dashboard_migration( + payload: MigrationRequest | None = None, + _username: str = Depends(require_dashboard_user), + service: UpdateService = Depends(get_service), +): + body = _model_dict(payload) if payload is not None else {} + return await _run(lambda: service.do_migration_v4(body)) diff --git a/astrbot/dashboard/asgi_runtime.py b/astrbot/dashboard/asgi_runtime.py new file mode 100644 index 0000000000..d985ec4b82 --- /dev/null +++ b/astrbot/dashboard/asgi_runtime.py @@ -0,0 +1,710 @@ +from __future__ import annotations + +import contextvars +import inspect +import re +from collections.abc import Callable, Iterable +from contextlib import asynccontextmanager, contextmanager +from pathlib import Path +from typing import Any + +import httpx +from fastapi import FastAPI, HTTPException, Request, WebSocket +from fastapi.encoders import jsonable_encoder +from fastapi.responses import FileResponse, JSONResponse, Response +from starlette.datastructures import UploadFile as StarletteUploadFile +from starlette.responses import StreamingResponse + +_request_var: contextvars.ContextVar[DashboardRequest] = contextvars.ContextVar( + "dashboard_request" +) +_websocket_var: contextvars.ContextVar[DashboardWebSocket] = contextvars.ContextVar( + "dashboard_websocket" +) +_g_var: contextvars.ContextVar[DashboardRequestState] = contextvars.ContextVar( + "dashboard_g" +) +_app_var: contextvars.ContextVar[FastAPIAppAdapter] = contextvars.ContextVar( + "dashboard_app" +) + + +class RequestArgs: + def __init__(self, values) -> None: + self._values = values + + def get(self, key: str, default: Any = None, type: Callable | None = None): + value = self._values.get(key, default) + if value is default or type is None: + return value + try: + return type(value) + except (TypeError, ValueError): + return default + + +class RequestMultiDict: + def __init__(self, pairs: list[tuple[str, Any]]) -> None: + self._pairs = pairs + + def get(self, key: str, default: Any = None, type: Callable | None = None): + for item_key, item_value in reversed(self._pairs): + if item_key != key: + continue + if type is None: + return item_value + try: + return type(item_value) + except (TypeError, ValueError): + return default + return default + + def getlist(self, key: str) -> list[Any]: + return [item_value for item_key, item_value in self._pairs if item_key == key] + + def keys(self): + return dict.fromkeys(item_key for item_key, _ in self._pairs).keys() + + def values(self): + return [self[key] for key in self.keys()] + + def items(self): + return [(key, self[key]) for key in self.keys()] + + def __contains__(self, key: str) -> bool: + return any(item_key == key for item_key, _ in self._pairs) + + def __getitem__(self, key: str): + value = self.get(key) + if value is None and key not in self: + raise KeyError(key) + return value + + def __bool__(self) -> bool: + return bool(self._pairs) + + +class RequestUploadFile: + def __init__(self, upload_file: StarletteUploadFile) -> None: + self._upload_file = upload_file + self.filename = upload_file.filename + self.content_type = upload_file.content_type + self.headers = upload_file.headers + self.content_length = self._resolve_content_length() + + def _resolve_content_length(self) -> int | None: + try: + raw = self.headers.get("content-length") + return int(raw) if raw else None + except (TypeError, ValueError): + return None + + async def save(self, destination: str | Path) -> None: + path = Path(destination) + try: + await self._upload_file.seek(0) + except Exception: + pass + with path.open("wb") as output: + while True: + chunk = await self._upload_file.read(1024 * 1024) + if not chunk: + break + output.write(chunk) + + def __getattr__(self, key: str): + return getattr(self._upload_file, key) + + +class DashboardRequestState: + def __init__(self) -> None: + self._values: dict[str, Any] = {} + + def get(self, key: str, default: Any = None): + return self._values.get(key, default) + + def __getattr__(self, key: str): + try: + return self._values[key] + except KeyError as exc: + raise AttributeError(key) from exc + + def __setattr__(self, key: str, value: Any) -> None: + if key == "_values": + super().__setattr__(key, value) + return + self._values[key] = value + + +class DashboardRequest: + def __init__(self, request: Request) -> None: + self._request = request + self.args = RequestArgs(request.query_params) + self.headers = request.headers + self.cookies = request.cookies + self.method = request.method + self.path = request.url.path + self.content_type = request.headers.get("content-type") + self.remote_addr = request.client.host if request.client else None + self._form_cache: RequestMultiDict | None = None + self._files_cache: RequestMultiDict | None = None + + @property + def json(self): + return self.get_json() + + @property + def files(self): + return self._load_files() + + @property + def form(self): + return self._load_form() + + async def get_json(self, silent: bool = False): + try: + return await self._request.json() + except Exception: + if silent: + return None + raise + + async def _load_form_parts(self) -> None: + if self._form_cache is not None and self._files_cache is not None: + return + form = await self._request.form() + form_pairs: list[tuple[str, Any]] = [] + file_pairs: list[tuple[str, Any]] = [] + for key, value in form.multi_items(): + if isinstance(value, StarletteUploadFile): + file_pairs.append((key, RequestUploadFile(value))) + else: + form_pairs.append((key, value)) + self._form_cache = RequestMultiDict(form_pairs) + self._files_cache = RequestMultiDict(file_pairs) + + async def _load_form(self) -> RequestMultiDict: + await self._load_form_parts() + assert self._form_cache is not None + return self._form_cache + + async def _load_files(self) -> RequestMultiDict: + await self._load_form_parts() + assert self._files_cache is not None + return self._files_cache + + +class DashboardWebSocket: + def __init__(self, websocket: WebSocket) -> None: + self._websocket = websocket + self.args = RequestArgs(websocket.query_params) + self.headers = websocket.headers + + async def accept(self) -> None: + await self._websocket.accept() + + async def receive_json(self): + return await self._websocket.receive_json() + + async def send_json(self, payload: Any) -> None: + await self._websocket.send_json(payload) + + async def close(self, code: int = 1000, reason: str | None = None) -> None: + await self._websocket.close(code=code, reason=reason or "") + + +class AdapterTestHeaders: + def __init__(self, headers: httpx.Headers) -> None: + self._headers = headers + + def getlist(self, key: str) -> list[str]: + values = self._headers.get_list(key) + if key.lower() == "set-cookie": + return [value.replace('=""', "=") for value in values] + return values + + def get(self, key: str, default: Any = None): + value = self._headers.get(key, default) + if isinstance(value, str) and key.lower() == "set-cookie": + return value.replace('=""', "=") + return value + + def __getitem__(self, key: str): + return self._headers[key] + + def __contains__(self, key: str) -> bool: + return key in self._headers + + +class AdapterTestResponse: + def __init__(self, response: httpx.Response) -> None: + self._response = response + self.status_code = response.status_code + self.headers = AdapterTestHeaders(response.headers) + self.data = response.content + self.content = response.content + self.text = response.text + + async def get_json(self): + return self._response.json() + + async def get_data(self): + return self._response.content + + +class AdapterTestClient: + def __init__(self, app: FastAPI) -> None: + self._client = httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://testserver", + ) + + @staticmethod + def _is_file_storage(value: Any) -> bool: + return hasattr(value, "stream") and hasattr(value, "filename") + + @classmethod + def _file_tuple(cls, value: Any): + stream = value.stream + if hasattr(stream, "seek"): + stream.seek(0) + content = stream.read() + filename = getattr(value, "filename", "upload.bin") + content_type = getattr(value, "content_type", None) + return filename, content, content_type + + @classmethod + def _normalize_data(cls, data: Any): + if not isinstance(data, dict): + return data, None + + form: dict[str, Any] = {} + files: list[tuple[str, tuple]] = [] + for key, value in data.items(): + if cls._is_file_storage(value): + files.append((key, cls._file_tuple(value))) + continue + if isinstance(value, Iterable) and not isinstance( + value, str | bytes | dict + ): + values = list(value) + if values and all(cls._is_file_storage(item) for item in values): + files.extend((key, cls._file_tuple(item)) for item in values) + continue + form[key] = value + return form, files or None + + @classmethod + def _normalize_files(cls, files: Any): + if isinstance(files, dict): + items = files.items() + elif isinstance(files, Iterable) and not isinstance(files, str | bytes): + items = files + else: + return files + + normalized_files: list[tuple[str, Any]] = [] + for key, value in items: + if cls._is_file_storage(value): + normalized_files.append((key, cls._file_tuple(value))) + continue + if isinstance(value, Iterable) and not isinstance( + value, str | bytes | dict + ): + values = list(value) + if values and all(cls._is_file_storage(item) for item in values): + normalized_files.extend( + (key, cls._file_tuple(item)) for item in values + ) + continue + normalized_files.append((key, value)) + return normalized_files + + async def request(self, method: str, url: str, **kwargs): + data = kwargs.pop("data", None) + if data is not None and "files" not in kwargs: + normalized_data, files = self._normalize_data(data) + kwargs["data"] = normalized_data + if files: + kwargs["files"] = files + elif data is not None: + kwargs["data"] = data + if "files" in kwargs: + kwargs["files"] = self._normalize_files(kwargs["files"]) + response = await self._client.request(method, url, **kwargs) + return AdapterTestResponse(response) + + async def get(self, url: str, **kwargs): + return await self.request("GET", url, **kwargs) + + async def post(self, url: str, **kwargs): + return await self.request("POST", url, **kwargs) + + async def put(self, url: str, **kwargs): + return await self.request("PUT", url, **kwargs) + + async def patch(self, url: str, **kwargs): + return await self.request("PATCH", url, **kwargs) + + async def delete(self, url: str, **kwargs): + return await self.request("DELETE", url, **kwargs) + + +class _ContextProxy: + def __init__(self, var) -> None: + self._var = var + + def __getattr__(self, key: str): + return getattr(self._var.get(), key) + + def __setattr__(self, key: str, value: Any) -> None: + if key == "_var": + super().__setattr__(key, value) + return + setattr(self._var.get(), key, value) + + +request = _ContextProxy(_request_var) +websocket = _ContextProxy(_websocket_var) +g = _ContextProxy(_g_var) +current_app = _ContextProxy(_app_var) + + +@contextmanager +def bind_request_context( + request_: Request, + app: FastAPIAppAdapter, + g_obj: DashboardRequestState | None = None, +): + token_request = _request_var.set(DashboardRequest(request_)) + token_g = _g_var.set( + g_obj or getattr(request_.state, "dashboard_g", DashboardRequestState()) + ) + token_app = _app_var.set(app) + try: + yield _g_var.get() + finally: + _app_var.reset(token_app) + _g_var.reset(token_g) + _request_var.reset(token_request) + + +@contextmanager +def bind_websocket_context( + websocket_: WebSocket, + app: FastAPIAppAdapter, + g_obj: DashboardRequestState | None = None, +): + token_websocket = _websocket_var.set(DashboardWebSocket(websocket_)) + token_g = _g_var.set( + g_obj or getattr(websocket_.state, "dashboard_g", DashboardRequestState()) + ) + token_app = _app_var.set(app) + try: + yield + finally: + _app_var.reset(token_app) + _g_var.reset(token_g) + _websocket_var.reset(token_websocket) + + +def jsonify(payload: Any = None): + return JSONResponse(payload if payload is not None else {}) + + +async def make_response(*args): + if not args: + return Response() + content = args[0] + status_code = args[1] if len(args) > 1 and isinstance(args[1], int) else None + headers = args[1] if len(args) > 1 and isinstance(args[1], dict) else None + if len(args) > 2 and isinstance(args[2], dict): + headers = args[2] + if isinstance(content, Response): + if status_code is not None: + content.status_code = status_code + if headers: + content.headers.update(headers) + return content + if hasattr(content, "__aiter__"): + return StreamingResponse( + content, + status_code=status_code or 200, + headers=headers, + ) + return Response( + content=content, + status_code=status_code or 200, + headers=headers, + ) + + +async def send_file(path: str | Path, mimetype: str | None = None, **kwargs): + filename = kwargs.get("attachment_filename") or kwargs.get("download_name") + as_attachment = bool(kwargs.get("as_attachment")) + return FileResponse( + path, + media_type=mimetype, + filename=filename if as_attachment else None, + ) + + +def abort(status_code: int): + raise HTTPException(status_code=status_code) + + +def _convert_rule(path: str) -> str: + converted = re.sub(r"", r"{\1:path}", path) + converted = re.sub(r"<([A-Za-z_][A-Za-z0-9_]*)>", r"{\1}", converted) + return converted + + +async def _call_view(view_func: Callable, path_params: dict[str, Any]): + result = view_func(**path_params) + if inspect.isawaitable(result): + result = await result + return await _coerce_view_result(result) + + +async def _coerce_view_result(result: Any): + if isinstance(result, Response): + return result + if _is_quart_response(result): + return await _quart_response_to_starlette(result) + + if isinstance(result, tuple): + content = result[0] if result else None + status_code = next((item for item in result[1:] if isinstance(item, int)), 200) + headers = next( + (item for item in result[1:] if isinstance(item, dict)), + None, + ) + if content is not None and isinstance(content, Response): + content.status_code = status_code + if headers: + content.headers.update(headers) + return content + if _is_quart_response(content): + return await _quart_response_to_starlette( + content, + status_code=status_code, + extra_headers=headers, + ) + return _response_from_content(content, status_code=status_code, headers=headers) + + if isinstance(result, dict | list): + return JSONResponse(jsonable_encoder(result)) + return result + + +def _response_from_content( + content: Any, + *, + status_code: int, + headers: dict[str, str] | None = None, +): + if isinstance(content, dict | list): + return JSONResponse( + jsonable_encoder(content), + status_code=status_code, + headers=headers, + ) + return Response( + content=content, + status_code=status_code, + headers=headers, + ) + + +def _is_quart_response(value: Any) -> bool: + return ( + hasattr(value, "get_data") + and inspect.iscoroutinefunction(value.get_data) + and hasattr(value, "headers") + and hasattr(value, "status_code") + ) + + +def _response_header_pairs(headers: Any) -> list[tuple[str, str]]: + if headers is None: + return [] + if hasattr(headers, "to_wsgi_list"): + return [(str(key), str(value)) for key, value in headers.to_wsgi_list()] + if hasattr(headers, "items"): + return [(str(key), str(value)) for key, value in headers.items()] + return [(str(key), str(value)) for key, value in headers] + + +async def _quart_response_to_starlette( + quart_response: Any, + *, + status_code: int | None = None, + extra_headers: dict[str, str] | None = None, +) -> Response: + content = await quart_response.get_data() + response = Response( + content=content, + status_code=status_code or int(quart_response.status_code), + ) + pairs = _response_header_pairs(quart_response.headers) + if extra_headers: + pairs.extend((str(key), str(value)) for key, value in extra_headers.items()) + response.raw_headers = [ + (key.lower().encode("latin-1"), value.encode("latin-1")) for key, value in pairs + ] + return response + + +@asynccontextmanager +async def bind_quart_request_context( + request_: Request, + app: FastAPIAppAdapter, + *, + path: str | None = None, + g_obj: DashboardRequestState | None = None, +): + try: + from quart import g as quart_g + except ImportError: + yield + return + + quart_app = app.get_quart_compat_app() + headers = { + key.decode("latin-1"): value.decode("latin-1") + for key, value in request_.scope.get("headers", []) + } + body = await request_.body() + request_path = path or str(request_.url.path) + if "?" not in request_path and request_.url.query: + request_path = f"{request_path}?{request_.url.query}" + + async with quart_app.test_request_context( + request_path, + method=request_.method, + headers=headers, + data=body, + scheme=request_.url.scheme, + root_path=request_.scope.get("root_path", ""), + scope_base={ + "client": request_.scope.get("client"), + "server": request_.scope.get("server"), + }, + ): + if g_obj is not None: + for key, value in getattr(g_obj, "_values", {}).items(): + setattr(quart_g, key, value) + yield + + +async def call_request_view( + request_: Request, + app: FastAPIAppAdapter, + view_func: Callable, + path_params: dict[str, Any] | None = None, + g_obj: DashboardRequestState | None = None, + quart_compat_path: str | None = None, +): + with bind_request_context(request_, app, g_obj): + async with bind_quart_request_context( + request_, + app, + path=quart_compat_path, + g_obj=g_obj, + ): + return await _call_view(view_func, path_params or {}) + + +async def call_websocket_view( + websocket_: WebSocket, + app: FastAPIAppAdapter, + view_func: Callable, + path_params: dict[str, Any] | None = None, + *, + accept: bool = True, +): + if accept: + await websocket_.accept() + with bind_websocket_context(websocket_, app): + return await _call_view(view_func, path_params or {}) + + +class FastAPIAppAdapter: + def __init__(self, app: FastAPI, static_folder: str | None = None) -> None: + self._app = app + self.static_folder = static_folder + self._dashboard_server: Any | None = None + self.config: dict[str, Any] = {} + self.debug = False + self.testing = False + self.name = "dashboard" + self._quart_compat_app: Any | None = None + + def get_quart_compat_app(self): + if self._quart_compat_app is None: + from quart import Quart + + self._quart_compat_app = Quart("astrbot_dashboard_plugin_compat") + self._quart_compat_app.json.sort_keys = False + return self._quart_compat_app + + def add_url_rule( + self, + path: str, + view_func: Callable, + methods: list[str] | None = None, + endpoint: str | None = None, + ) -> None: + route_path = _convert_rule(path) + methods = methods or ["GET"] + + async def endpoint_func(request_: Request): + with bind_request_context(request_, self): + return await _call_view(view_func, dict(request_.path_params)) + + self._app.add_api_route( + route_path, + endpoint_func, + methods=methods, + name=endpoint, + include_in_schema=False, + ) + + def websocket(self, path: str): + route_path = _convert_rule(path) + + def decorator(view_func: Callable): + async def endpoint_func(websocket_: WebSocket): + return await call_websocket_view( + websocket_, + self, + view_func, + dict(websocket_.path_params), + ) + + self._app.add_api_websocket_route( + route_path, + endpoint_func, + name=getattr(view_func, "__name__", None), + ) + return view_func + + return decorator + + def errorhandler(self, _status_code: int): + def decorator(func: Callable): + return func + + return decorator + + async def send_static_file(self, filename: str): + if not self.static_folder: + raise HTTPException(status_code=404) + return FileResponse(Path(self.static_folder) / filename) + + def test_client(self): + self.testing = True + return AdapterTestClient(self._app) + + +AdapterResponse = Response diff --git a/astrbot/dashboard/async_utils.py b/astrbot/dashboard/async_utils.py new file mode 100644 index 0000000000..fbad60c9c2 --- /dev/null +++ b/astrbot/dashboard/async_utils.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +import inspect +from collections.abc import Awaitable, Callable +from typing import Any, TypeVar, cast + +T = TypeVar("T") + + +async def resolve_maybe_awaitable(value: T | Awaitable[T]) -> T: + while inspect.isawaitable(value): + value = await cast(Awaitable[T], value) + return cast(T, value) + + +async def run_maybe_async( + operation: Callable[[], T | Awaitable[T]] | T | Awaitable[T], +) -> T: + result: Any = operation() if callable(operation) else operation + return await resolve_maybe_awaitable(result) diff --git a/astrbot/dashboard/password_state.py b/astrbot/dashboard/password_state.py index b55c0866a7..471780ef50 100644 --- a/astrbot/dashboard/password_state.py +++ b/astrbot/dashboard/password_state.py @@ -2,8 +2,8 @@ from astrbot.core.db import BaseDatabase from astrbot.core.utils.auth_password import ( hash_dashboard_password, - hash_legacy_dashboard_password, - is_legacy_dashboard_password, + hash_md5_dashboard_password, + is_md5_dashboard_password, ) PASSWORD_STORAGE_UPGRADED_KEY = "password_storage_upgraded" @@ -83,12 +83,12 @@ def get_dashboard_password_hash(config: AstrBotConfig, *, upgraded: bool) -> str if upgraded and _has_usable_pbkdf2_password(config): return config["dashboard"].get("pbkdf2_password", "") - legacy_password = config["dashboard"].get("password", "") - if upgraded and not is_legacy_dashboard_password(legacy_password): + md5_password = config["dashboard"].get("password", "") + if upgraded and not is_md5_dashboard_password(md5_password): return "" - return legacy_password + return md5_password def set_dashboard_password_hashes(config: AstrBotConfig, raw_password: str) -> None: config["dashboard"]["pbkdf2_password"] = hash_dashboard_password(raw_password) - config["dashboard"]["password"] = hash_legacy_dashboard_password(raw_password) + config["dashboard"]["password"] = hash_md5_dashboard_password(raw_password) diff --git a/astrbot/dashboard/plugin_page_auth.py b/astrbot/dashboard/plugin_page_auth.py index f2571b3eef..cb60df1a94 100644 --- a/astrbot/dashboard/plugin_page_auth.py +++ b/astrbot/dashboard/plugin_page_auth.py @@ -1,7 +1,5 @@ from urllib.parse import unquote -from quart import request - PLUGIN_PAGE_CONTENT_PREFIX = "/api/plugin/page/content/" PLUGIN_PAGE_BRIDGE_PATH = "/api/plugin/page/bridge-sdk.js" PLUGIN_PAGE_TOKEN_TYPE = "plugin_page_asset" @@ -19,8 +17,8 @@ def is_asset_token(payload: dict) -> bool: return payload.get("token_type") == PLUGIN_PAGE_TOKEN_TYPE @staticmethod - def extract_asset_token() -> str | None: - query_asset_token = request.args.get("asset_token", "").strip() + def extract_asset_token(query_params) -> str | None: + query_asset_token = query_params.get("asset_token", "").strip() return query_asset_token or None @staticmethod diff --git a/astrbot/dashboard/responses.py b/astrbot/dashboard/responses.py new file mode 100644 index 0000000000..7eea773834 --- /dev/null +++ b/astrbot/dashboard/responses.py @@ -0,0 +1,22 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + + +@dataclass +class ApiError(Exception): + message: str + status_code: int = 400 + data: Any = None + + +def ok(data: Any = None, message: str | None = None) -> dict[str, Any]: + return {"status": "ok", "message": message, "data": {} if data is None else data} + + +def error(message: str, data: Any = None) -> dict[str, Any]: + payload: dict[str, Any] = {"status": "error", "message": message} + if data is not None: + payload["data"] = data + return payload diff --git a/astrbot/dashboard/routes/__init__.py b/astrbot/dashboard/routes/__init__.py deleted file mode 100644 index fbbd0c7a08..0000000000 --- a/astrbot/dashboard/routes/__init__.py +++ /dev/null @@ -1,49 +0,0 @@ -from .api_key import ApiKeyRoute -from .auth import AuthRoute -from .backup import BackupRoute -from .chat import ChatRoute -from .chatui_project import ChatUIProjectRoute -from .command import CommandRoute -from .config import ConfigRoute -from .conversation import ConversationRoute -from .cron import CronRoute -from .file import FileRoute -from .knowledge_base import KnowledgeBaseRoute -from .log import LogRoute -from .open_api import OpenApiRoute -from .persona import PersonaRoute -from .platform import PlatformRoute -from .plugin import PluginRoute -from .session_management import SessionManagementRoute -from .skills import SkillsRoute -from .stat import StatRoute -from .static_file import StaticFileRoute -from .subagent import SubAgentRoute -from .tools import ToolsRoute -from .update import UpdateRoute - -__all__ = [ - "ApiKeyRoute", - "AuthRoute", - "BackupRoute", - "ChatRoute", - "ChatUIProjectRoute", - "CommandRoute", - "ConfigRoute", - "ConversationRoute", - "CronRoute", - "FileRoute", - "KnowledgeBaseRoute", - "LogRoute", - "OpenApiRoute", - "PersonaRoute", - "PlatformRoute", - "PluginRoute", - "SessionManagementRoute", - "StatRoute", - "StaticFileRoute", - "SubAgentRoute", - "ToolsRoute", - "SkillsRoute", - "UpdateRoute", -] diff --git a/astrbot/dashboard/routes/api_key.py b/astrbot/dashboard/routes/api_key.py deleted file mode 100644 index 4b957fe8ea..0000000000 --- a/astrbot/dashboard/routes/api_key.py +++ /dev/null @@ -1,143 +0,0 @@ -import hashlib -import secrets -from datetime import datetime, timedelta, timezone - -from quart import g, request - -from astrbot.core.db import BaseDatabase -from astrbot.core.utils.datetime_utils import normalize_datetime_utc - -from .route import Response, Route, RouteContext - -ALL_OPEN_API_SCOPES = ("chat", "config", "file", "im") - - -class ApiKeyRoute(Route): - def __init__(self, context: RouteContext, db: BaseDatabase) -> None: - super().__init__(context) - self.db = db - self.routes = { - "/apikey/list": ("GET", self.list_api_keys), - "/apikey/create": ("POST", self.create_api_key), - "/apikey/revoke": ("POST", self.revoke_api_key), - "/apikey/delete": ("POST", self.delete_api_key), - } - self.register_routes() - - @staticmethod - def _normalize_utc(dt: datetime | None) -> datetime | None: - return normalize_datetime_utc(dt) - - @classmethod - def _serialize_datetime(cls, dt: datetime | None) -> str | None: - normalized = cls._normalize_utc(dt) - if normalized is None: - return None - return normalized.astimezone().isoformat() - - @staticmethod - def _hash_key(raw_key: str) -> str: - return hashlib.pbkdf2_hmac( - "sha256", - raw_key.encode("utf-8"), - b"astrbot_api_key", - 100_000, - ).hex() - - @staticmethod - def _serialize_api_key(key) -> dict: - expires_at = ApiKeyRoute._normalize_utc(key.expires_at) - return { - "key_id": key.key_id, - "name": key.name, - "key_prefix": key.key_prefix, - "scopes": key.scopes or [], - "created_by": key.created_by, - "created_at": ApiKeyRoute._serialize_datetime(key.created_at), - "updated_at": ApiKeyRoute._serialize_datetime(key.updated_at), - "last_used_at": ApiKeyRoute._serialize_datetime(key.last_used_at), - "expires_at": ApiKeyRoute._serialize_datetime(key.expires_at), - "revoked_at": ApiKeyRoute._serialize_datetime(key.revoked_at), - "is_revoked": key.revoked_at is not None, - "is_expired": bool(expires_at and expires_at < datetime.now(timezone.utc)), - } - - async def list_api_keys(self): - keys = await self.db.list_api_keys() - return ( - Response().ok(data=[self._serialize_api_key(key) for key in keys]).__dict__ - ) - - async def create_api_key(self): - post_data = await request.json or {} - - name = str(post_data.get("name", "")).strip() or "Untitled API Key" - scopes = post_data.get("scopes") - if scopes is None: - normalized_scopes = list(ALL_OPEN_API_SCOPES) - elif isinstance(scopes, list): - normalized_scopes = [ - scope - for scope in scopes - if isinstance(scope, str) and scope in ALL_OPEN_API_SCOPES - ] - normalized_scopes = list(dict.fromkeys(normalized_scopes)) - if not normalized_scopes: - return Response().error("At least one valid scope is required").__dict__ - else: - return Response().error("Invalid scopes").__dict__ - - expires_at = None - expires_in_days = post_data.get("expires_in_days") - if expires_in_days is not None: - try: - expires_in_days_int = int(expires_in_days) - except (TypeError, ValueError): - return Response().error("expires_in_days must be an integer").__dict__ - if expires_in_days_int <= 0: - return ( - Response().error("expires_in_days must be greater than 0").__dict__ - ) - expires_at = datetime.now(timezone.utc) + timedelta( - days=expires_in_days_int - ) - - raw_key = f"abk_{secrets.token_urlsafe(32)}" - key_hash = self._hash_key(raw_key) - key_prefix = raw_key[:12] - created_by = g.get("username", "unknown") - - api_key = await self.db.create_api_key( - name=name, - key_hash=key_hash, - key_prefix=key_prefix, - scopes=normalized_scopes, # type: ignore - created_by=created_by, - expires_at=expires_at, - ) - - payload = self._serialize_api_key(api_key) - payload["api_key"] = raw_key - return Response().ok(data=payload).__dict__ - - async def revoke_api_key(self): - post_data = await request.json or {} - key_id = post_data.get("key_id") - if not key_id: - return Response().error("Missing key: key_id").__dict__ - - success = await self.db.revoke_api_key(key_id) - if not success: - return Response().error("API key not found").__dict__ - return Response().ok().__dict__ - - async def delete_api_key(self): - post_data = await request.json or {} - key_id = post_data.get("key_id") - if not key_id: - return Response().error("Missing key: key_id").__dict__ - - success = await self.db.delete_api_key(key_id) - if not success: - return Response().error("API key not found").__dict__ - return Response().ok().__dict__ diff --git a/astrbot/dashboard/routes/backup.py b/astrbot/dashboard/routes/backup.py deleted file mode 100644 index ecc5dbfc80..0000000000 --- a/astrbot/dashboard/routes/backup.py +++ /dev/null @@ -1,1106 +0,0 @@ -"""备份管理 API 路由""" - -import asyncio -import json -import os -import re -import shutil -import time -import traceback -import uuid -import zipfile -from datetime import datetime -from pathlib import Path - -import jwt -from quart import request, send_file - -from astrbot.core import logger -from astrbot.core.backup.exporter import AstrBotExporter -from astrbot.core.backup.importer import AstrBotImporter -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.db import BaseDatabase -from astrbot.core.utils.astrbot_path import ( - get_astrbot_backups_path, - get_astrbot_data_path, -) - -from .route import Response, Route, RouteContext - -# 分片上传常量 -CHUNK_SIZE = 1024 * 1024 # 1MB -UPLOAD_EXPIRE_SECONDS = 3600 # 上传会话过期时间(1小时) - - -def secure_filename(filename: str) -> str: - """清洗文件名,移除路径遍历字符和危险字符 - - Args: - filename: 原始文件名 - - Returns: - 安全的文件名 - """ - # 跨平台处理:先将反斜杠替换为正斜杠,再取文件名 - filename = filename.replace("\\", "/") - # 仅保留文件名部分,移除路径 - filename = os.path.basename(filename) - - # 替换路径遍历字符 - filename = filename.replace("..", "_") - - # 仅保留字母、数字、下划线、连字符、点 - filename = re.sub(r"[^\w\-.]", "_", filename) - - # 移除前导点(隐藏文件)和尾部点 - filename = filename.strip(".") - - # 如果文件名为空或只包含下划线,生成一个默认名称 - if not filename or filename.replace("_", "") == "": - filename = "backup" - - return filename - - -def generate_unique_filename(original_filename: str) -> str: - """生成唯一的文件名,在原文件名后添加时间戳后缀避免重名 - - Args: - original_filename: 原始文件名(已清洗) - - Returns: - 添加了时间戳后缀的唯一文件名,格式为 {原文件名}_{YYYYMMDD_HHMMSS}.{扩展名} - """ - name, ext = os.path.splitext(original_filename) - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - return f"{name}_{timestamp}{ext}" - - -class BackupRoute(Route): - """备份管理路由 - - 提供备份导出、导入、列表等 API 接口 - """ - - def __init__( - self, - context: RouteContext, - db: BaseDatabase, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.db = db - self.core_lifecycle = core_lifecycle - self.backup_dir = get_astrbot_backups_path() - self.data_dir = get_astrbot_data_path() - self.chunks_dir = os.path.join(self.backup_dir, ".chunks") - - # 任务状态跟踪 - self.backup_tasks: dict[str, dict] = {} - self.backup_progress: dict[str, dict] = {} - - # 分片上传会话跟踪 - # upload_id -> {filename, total_chunks, received_chunks, last_activity, chunk_dir} - self.upload_sessions: dict[str, dict] = {} - - # 后台清理任务句柄 - self._cleanup_task: asyncio.Task | None = None - - # 注册路由 - self.routes = { - "/backup/list": ("GET", self.list_backups), - "/backup/export": ("POST", self.export_backup), - "/backup/upload": ("POST", self.upload_backup), # 上传文件(兼容小文件) - "/backup/upload/init": ("POST", self.upload_init), # 分片上传初始化 - "/backup/upload/chunk": ("POST", self.upload_chunk), # 上传分片 - "/backup/upload/complete": ("POST", self.upload_complete), # 完成分片上传 - "/backup/upload/abort": ("POST", self.upload_abort), # 取消上传 - "/backup/check": ("POST", self.check_backup), # 预检查 - "/backup/import": ("POST", self.import_backup), # 确认导入 - "/backup/progress": ("GET", self.get_progress), - "/backup/download": ("GET", self.download_backup), - "/backup/delete": ("POST", self.delete_backup), - "/backup/rename": ("POST", self.rename_backup), # 重命名备份 - } - self.register_routes() - - def _init_task(self, task_id: str, task_type: str, status: str = "pending") -> None: - """初始化任务状态""" - self.backup_tasks[task_id] = { - "type": task_type, - "status": status, - "result": None, - "error": None, - } - self.backup_progress[task_id] = { - "status": status, - "stage": "waiting", - "current": 0, - "total": 100, - "message": "", - } - - def _set_task_result( - self, - task_id: str, - status: str, - result: dict | None = None, - error: str | None = None, - ) -> None: - """设置任务结果""" - if task_id in self.backup_tasks: - self.backup_tasks[task_id]["status"] = status - self.backup_tasks[task_id]["result"] = result - self.backup_tasks[task_id]["error"] = error - if task_id in self.backup_progress: - self.backup_progress[task_id]["status"] = status - - def _update_progress( - self, - task_id: str, - *, - status: str | None = None, - stage: str | None = None, - current: int | None = None, - total: int | None = None, - message: str | None = None, - ) -> None: - """更新任务进度""" - if task_id not in self.backup_progress: - return - p = self.backup_progress[task_id] - if status is not None: - p["status"] = status - if stage is not None: - p["stage"] = stage - if current is not None: - p["current"] = current - if total is not None: - p["total"] = total - if message is not None: - p["message"] = message - - def _make_progress_callback(self, task_id: str): - """创建进度回调函数""" - - async def _callback( - stage: str, current: int, total: int, message: str = "" - ) -> None: - self._update_progress( - task_id, - status="processing", - stage=stage, - current=current, - total=total, - message=message, - ) - - return _callback - - def _ensure_cleanup_task_started(self) -> None: - """确保后台清理任务已启动(在异步上下文中延迟启动)""" - if self._cleanup_task is None or self._cleanup_task.done(): - try: - self._cleanup_task = asyncio.create_task( - self._cleanup_expired_uploads() - ) - except RuntimeError: - # 如果没有运行中的事件循环,跳过(等待下次异步调用时启动) - pass - - async def _cleanup_expired_uploads(self) -> None: - """定期清理过期的上传会话 - - 基于 last_activity 字段判断过期,避免清理活跃的上传会话。 - """ - while True: - try: - await asyncio.sleep(300) # 每5分钟检查一次 - current_time = time.time() - expired_sessions = [] - - for upload_id, session in self.upload_sessions.items(): - # 使用 last_activity 判断过期,而非 created_at - last_activity = session.get("last_activity", session["created_at"]) - if current_time - last_activity > UPLOAD_EXPIRE_SECONDS: - expired_sessions.append(upload_id) - - for upload_id in expired_sessions: - await self._cleanup_upload_session(upload_id) - logger.info(f"清理过期的上传会话: {upload_id}") - - except asyncio.CancelledError: - # 任务被取消,正常退出 - break - except Exception as e: - logger.error(f"清理过期上传会话失败: {e}") - - async def _cleanup_upload_session(self, upload_id: str) -> None: - """清理上传会话""" - if upload_id in self.upload_sessions: - session = self.upload_sessions[upload_id] - chunk_dir = session.get("chunk_dir") - if chunk_dir and os.path.exists(chunk_dir): - try: - shutil.rmtree(chunk_dir) - except Exception as e: - logger.warning(f"清理分片目录失败: {e}") - del self.upload_sessions[upload_id] - - def _get_backup_manifest(self, zip_path: str) -> dict | None: - """从备份文件读取 manifest.json - - Args: - zip_path: ZIP 文件路径 - - Returns: - dict | None: manifest 内容,如果不是有效备份则返回 None - """ - try: - with zipfile.ZipFile(zip_path, "r") as zf: - if "manifest.json" in zf.namelist(): - manifest_data = zf.read("manifest.json") - return json.loads(manifest_data.decode("utf-8")) - else: - # 没有 manifest.json,不是有效的 AstrBot 备份 - return None - except Exception as e: - logger.debug(f"读取备份 manifest 失败: {e}") - return None # 无法读取,不是有效备份 - - async def list_backups(self): - # 确保后台清理任务已启动 - self._ensure_cleanup_task_started() - - """获取备份列表 - - Query 参数: - - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) - """ - try: - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) - - # 确保备份目录存在 - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) - - # 获取所有备份文件 - backup_files = [] - for filename in os.listdir(self.backup_dir): - # 只处理 .zip 文件,排除隐藏文件和目录 - if not filename.endswith(".zip") or filename.startswith("."): - continue - - file_path = os.path.join(self.backup_dir, filename) - if not os.path.isfile(file_path): - continue - - # 读取 manifest.json 获取备份信息 - # 如果返回 None,说明不是有效的 AstrBot 备份,跳过 - manifest = self._get_backup_manifest(file_path) - if manifest is None: - logger.debug(f"跳过无效备份文件: {filename}") - continue - - stat = os.stat(file_path) - backup_files.append( - { - "filename": filename, - "size": stat.st_size, - "created_at": stat.st_mtime, - "type": manifest.get( - "origin", "exported" - ), # 老版本没有 origin 默认为 exported - "astrbot_version": manifest.get("astrbot_version", "未知"), - "exported_at": manifest.get("exported_at"), - } - ) - - # 按创建时间倒序排序 - backup_files.sort(key=lambda x: x["created_at"], reverse=True) - - # 分页 - start = (page - 1) * page_size - end = start + page_size - items = backup_files[start:end] - - return ( - Response() - .ok( - { - "items": items, - "total": len(backup_files), - "page": page, - "page_size": page_size, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取备份列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取备份列表失败: {e!s}").__dict__ - - async def export_backup(self): - """创建备份 - - 返回: - - task_id: 任务ID,用于查询导出进度 - """ - try: - # 生成任务ID - task_id = str(uuid.uuid4()) - - # 初始化任务状态 - self._init_task(task_id, "export", "pending") - - # 启动后台导出任务 - asyncio.create_task(self._background_export_task(task_id)) - - return ( - Response() - .ok( - { - "task_id": task_id, - "message": "export task created, processing in background", - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"创建备份失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"创建备份失败: {e!s}").__dict__ - - async def _background_export_task(self, task_id: str) -> None: - """后台导出任务""" - try: - self._update_progress(task_id, status="processing", message="正在初始化...") - - # 获取知识库管理器 - kb_manager = getattr(self.core_lifecycle, "kb_manager", None) - - exporter = AstrBotExporter( - main_db=self.db, - kb_manager=kb_manager, - config_path=os.path.join(self.data_dir, "cmd_config.json"), - ) - - # 创建进度回调 - progress_callback = self._make_progress_callback(task_id) - - # 执行导出 - zip_path = await exporter.export_all( - output_dir=self.backup_dir, - progress_callback=progress_callback, - ) - - # 设置成功结果 - self._set_task_result( - task_id, - "completed", - result={ - "filename": os.path.basename(zip_path), - "path": zip_path, - "size": os.path.getsize(zip_path), - }, - ) - except Exception as e: - logger.error(f"后台导出任务 {task_id} 失败: {e}") - logger.error(traceback.format_exc()) - self._set_task_result(task_id, "failed", error=str(e)) - - async def upload_backup(self): - """上传备份文件 - - 将备份文件上传到服务器,返回保存的文件名。 - 上传后应调用 check_backup 进行预检查。 - - Form Data: - - file: 备份文件 (.zip) - - 返回: - - filename: 保存的文件名 - """ - try: - files = await request.files - if "file" not in files: - return Response().error("缺少备份文件").__dict__ - - file = files["file"] - if not file.filename or not file.filename.endswith(".zip"): - return Response().error("请上传 ZIP 格式的备份文件").__dict__ - - # 清洗文件名并生成唯一名称,防止路径遍历和覆盖 - safe_filename = secure_filename(file.filename) - unique_filename = generate_unique_filename(safe_filename) - - # 保存上传的文件 - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) - zip_path = os.path.join(self.backup_dir, unique_filename) - await file.save(zip_path) - - logger.info( - f"上传的备份文件已保存: {unique_filename} (原始名称: {file.filename})" - ) - - return ( - Response() - .ok( - { - "filename": unique_filename, - "original_filename": file.filename, - "size": os.path.getsize(zip_path), - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"上传备份文件失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"上传备份文件失败: {e!s}").__dict__ - - async def upload_init(self): - """初始化分片上传 - - 创建一个上传会话,返回 upload_id 供后续分片上传使用。 - - JSON Body: - - filename: 原始文件名 - - total_size: 文件总大小(字节) - - 返回: - - upload_id: 上传会话 ID - - chunk_size: 分片大小(由后端决定) - - total_chunks: 分片总数(由后端根据 total_size 和 chunk_size 计算) - """ - try: - data = await request.json - filename = data.get("filename") - total_size = data.get("total_size", 0) - - if not filename: - return Response().error("缺少 filename 参数").__dict__ - - if not filename.endswith(".zip"): - return Response().error("请上传 ZIP 格式的备份文件").__dict__ - - if total_size <= 0: - return Response().error("无效的文件大小").__dict__ - - # 由后端计算分片总数,确保前后端一致 - import math - - total_chunks = math.ceil(total_size / CHUNK_SIZE) - - # 生成上传 ID - upload_id = str(uuid.uuid4()) - - # 创建分片存储目录 - chunk_dir = os.path.join(self.chunks_dir, upload_id) - Path(chunk_dir).mkdir(parents=True, exist_ok=True) - - # 清洗文件名 - safe_filename = secure_filename(filename) - unique_filename = generate_unique_filename(safe_filename) - - # 创建上传会话 - current_time = time.time() - self.upload_sessions[upload_id] = { - "filename": unique_filename, - "original_filename": filename, - "total_size": total_size, - "total_chunks": total_chunks, - "received_chunks": set(), - "created_at": current_time, - "last_activity": current_time, # 用于判断会话是否活跃 - "chunk_dir": chunk_dir, - } - - logger.info( - f"初始化分片上传: upload_id={upload_id}, " - f"filename={unique_filename}, total_chunks={total_chunks}" - ) - - return ( - Response() - .ok( - { - "upload_id": upload_id, - "chunk_size": CHUNK_SIZE, - "total_chunks": total_chunks, - "filename": unique_filename, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"初始化分片上传失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"初始化分片上传失败: {e!s}").__dict__ - - async def upload_chunk(self): - """上传分片 - - 上传单个分片数据。 - - Form Data: - - upload_id: 上传会话 ID - - chunk_index: 分片索引(从 0 开始) - - chunk: 分片数据 - - 返回: - - received: 已接收的分片数量 - - total: 分片总数 - """ - try: - form = await request.form - files = await request.files - - upload_id = form.get("upload_id") - chunk_index_str = form.get("chunk_index") - - if not upload_id or chunk_index_str is None: - return Response().error("缺少必要参数").__dict__ - - try: - chunk_index = int(chunk_index_str) - except ValueError: - return Response().error("无效的分片索引").__dict__ - - if "chunk" not in files: - return Response().error("缺少分片数据").__dict__ - - # 验证上传会话 - if upload_id not in self.upload_sessions: - return Response().error("上传会话不存在或已过期").__dict__ - - session = self.upload_sessions[upload_id] - - # 验证分片索引 - if chunk_index < 0 or chunk_index >= session["total_chunks"]: - return Response().error("分片索引超出范围").__dict__ - - # 保存分片 - chunk_file = files["chunk"] - chunk_path = os.path.join(session["chunk_dir"], f"{chunk_index}.part") - await chunk_file.save(chunk_path) - - # 记录已接收的分片,并更新最后活动时间 - session["received_chunks"].add(chunk_index) - session["last_activity"] = time.time() # 刷新活动时间,防止活跃上传被清理 - - received_count = len(session["received_chunks"]) - total_chunks = session["total_chunks"] - - logger.debug( - f"接收分片: upload_id={upload_id}, " - f"chunk={chunk_index + 1}/{total_chunks}" - ) - - return ( - Response() - .ok( - { - "received": received_count, - "total": total_chunks, - "chunk_index": chunk_index, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"上传分片失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"上传分片失败: {e!s}").__dict__ - - def _mark_backup_as_uploaded(self, zip_path: str) -> None: - """修改备份文件的 manifest.json,将 origin 设置为 uploaded - - 使用 zipfile 的 append 模式添加新的 manifest.json, - ZIP 规范中后添加的同名文件会覆盖先前的文件。 - - Args: - zip_path: ZIP 文件路径 - """ - try: - # 读取原有 manifest - manifest = {"origin": "uploaded", "uploaded_at": datetime.now().isoformat()} - with zipfile.ZipFile(zip_path, "r") as zf: - if "manifest.json" in zf.namelist(): - manifest_data = zf.read("manifest.json") - manifest = json.loads(manifest_data.decode("utf-8")) - manifest["origin"] = "uploaded" - manifest["uploaded_at"] = datetime.now().isoformat() - - # 使用 append 模式添加新的 manifest.json - # ZIP 规范中,后添加的同名文件会覆盖先前的 - with zipfile.ZipFile(zip_path, "a") as zf: - new_manifest = json.dumps(manifest, ensure_ascii=False, indent=2) - zf.writestr("manifest.json", new_manifest) - - logger.debug(f"已标记备份为上传来源: {zip_path}") - except Exception as e: - logger.warning(f"标记备份来源失败: {e}") - - async def upload_complete(self): - """完成分片上传 - - 合并所有分片为完整文件。 - - JSON Body: - - upload_id: 上传会话 ID - - 返回: - - filename: 合并后的文件名 - - size: 文件大小 - """ - try: - data = await request.json - upload_id = data.get("upload_id") - - if not upload_id: - return Response().error("缺少 upload_id 参数").__dict__ - - # 验证上传会话 - if upload_id not in self.upload_sessions: - return Response().error("上传会话不存在或已过期").__dict__ - - session = self.upload_sessions[upload_id] - - # 检查是否所有分片都已接收 - received = session["received_chunks"] - total = session["total_chunks"] - - if len(received) != total: - missing = set(range(total)) - received - return ( - Response() - .error(f"分片不完整,缺少: {sorted(missing)[:10]}...") - .__dict__ - ) - - # 合并分片 - chunk_dir = session["chunk_dir"] - filename = session["filename"] - - Path(self.backup_dir).mkdir(parents=True, exist_ok=True) - output_path = os.path.join(self.backup_dir, filename) - - try: - with open(output_path, "wb") as outfile: - for i in range(total): - chunk_path = os.path.join(chunk_dir, f"{i}.part") - with open(chunk_path, "rb") as chunk_file: - # 分块读取,避免内存溢出 - while True: - data_block = chunk_file.read(8192) - if not data_block: - break - outfile.write(data_block) - - file_size = os.path.getsize(output_path) - - # 标记备份为上传来源(修改 manifest.json 中的 origin 字段) - self._mark_backup_as_uploaded(output_path) - - logger.info( - f"分片上传完成: {filename}, size={file_size}, chunks={total}" - ) - - # 清理分片目录 - await self._cleanup_upload_session(upload_id) - - return ( - Response() - .ok( - { - "filename": filename, - "original_filename": session["original_filename"], - "size": file_size, - } - ) - .__dict__ - ) - except Exception as e: - # 如果合并失败,删除不完整的文件 - if os.path.exists(output_path): - os.remove(output_path) - raise e - - except Exception as e: - logger.error(f"完成分片上传失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"完成分片上传失败: {e!s}").__dict__ - - async def upload_abort(self): - """取消分片上传 - - 取消上传并清理已上传的分片。 - - JSON Body: - - upload_id: 上传会话 ID - """ - try: - data = await request.json - upload_id = data.get("upload_id") - - if not upload_id: - return Response().error("缺少 upload_id 参数").__dict__ - - if upload_id not in self.upload_sessions: - # 会话已不存在,可能已过期或已完成 - return Response().ok(message="上传已取消").__dict__ - - # 清理会话 - await self._cleanup_upload_session(upload_id) - - logger.info(f"取消分片上传: {upload_id}") - - return Response().ok(message="上传已取消").__dict__ - except Exception as e: - logger.error(f"取消上传失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"取消上传失败: {e!s}").__dict__ - - async def check_backup(self): - """预检查备份文件 - - 检查备份文件的版本兼容性,返回确认信息。 - 用户确认后调用 import_backup 执行导入。 - - JSON Body: - - filename: 已上传的备份文件名 - - 返回: - - ImportPreCheckResult: 预检查结果 - """ - try: - data = await request.json - filename = data.get("filename") - if not filename: - return Response().error("缺少 filename 参数").__dict__ - - # 安全检查 - 防止路径遍历 - if ".." in filename or "/" in filename or "\\" in filename: - return Response().error("无效的文件名").__dict__ - - zip_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(zip_path): - return Response().error(f"备份文件不存在: {filename}").__dict__ - - # 获取知识库管理器(用于构造 importer) - kb_manager = getattr(self.core_lifecycle, "kb_manager", None) - - importer = AstrBotImporter( - main_db=self.db, - kb_manager=kb_manager, - config_path=os.path.join(self.data_dir, "cmd_config.json"), - ) - - # 执行预检查 - check_result = importer.pre_check(zip_path) - - return Response().ok(check_result.to_dict()).__dict__ - except Exception as e: - logger.error(f"预检查备份文件失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"预检查备份文件失败: {e!s}").__dict__ - - async def import_backup(self): - """执行备份导入 - - 在用户确认后执行实际的导入操作。 - 需要先调用 upload_backup 上传文件,再调用 check_backup 预检查。 - - JSON Body: - - filename: 已上传的备份文件名(必填) - - confirmed: 用户已确认(必填,必须为 true) - - 返回: - - task_id: 任务ID,用于查询导入进度 - """ - try: - data = await request.json - filename = data.get("filename") - confirmed = data.get("confirmed", False) - - if not filename: - return Response().error("缺少 filename 参数").__dict__ - - if not confirmed: - return ( - Response() - .error("请先确认导入。导入将会清空并覆盖现有数据,此操作不可撤销。") - .__dict__ - ) - - # 安全检查 - 防止路径遍历 - if ".." in filename or "/" in filename or "\\" in filename: - return Response().error("无效的文件名").__dict__ - - zip_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(zip_path): - return Response().error(f"备份文件不存在: {filename}").__dict__ - - # 生成任务ID - task_id = str(uuid.uuid4()) - - # 初始化任务状态 - self._init_task(task_id, "import", "pending") - - # 启动后台导入任务 - asyncio.create_task(self._background_import_task(task_id, zip_path)) - - return ( - Response() - .ok( - { - "task_id": task_id, - "message": "import task created, processing in background", - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"导入备份失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"导入备份失败: {e!s}").__dict__ - - async def _background_import_task(self, task_id: str, zip_path: str) -> None: - """后台导入任务""" - try: - self._update_progress(task_id, status="processing", message="正在初始化...") - - # 获取知识库管理器 - kb_manager = getattr(self.core_lifecycle, "kb_manager", None) - - importer = AstrBotImporter( - main_db=self.db, - kb_manager=kb_manager, - config_path=os.path.join(self.data_dir, "cmd_config.json"), - ) - - # 创建进度回调 - progress_callback = self._make_progress_callback(task_id) - - # 执行导入 - result = await importer.import_all( - zip_path=zip_path, - mode="replace", - progress_callback=progress_callback, - ) - - # 设置结果 - if result.success: - self._set_task_result( - task_id, - "completed", - result=result.to_dict(), - ) - else: - self._set_task_result( - task_id, - "failed", - error="; ".join(result.errors), - ) - except Exception as e: - logger.error(f"后台导入任务 {task_id} 失败: {e}") - logger.error(traceback.format_exc()) - self._set_task_result(task_id, "failed", error=str(e)) - - async def get_progress(self): - """获取任务进度 - - Query 参数: - - task_id: 任务 ID (必填) - """ - try: - task_id = request.args.get("task_id") - if not task_id: - return Response().error("缺少参数 task_id").__dict__ - - if task_id not in self.backup_tasks: - return Response().error("找不到该任务").__dict__ - - task_info = self.backup_tasks[task_id] - status = task_info["status"] - - response_data = { - "task_id": task_id, - "type": task_info["type"], - "status": status, - } - - # 如果任务正在处理,返回进度信息 - if status == "processing" and task_id in self.backup_progress: - response_data["progress"] = self.backup_progress[task_id] - - # 如果任务完成,返回结果 - if status == "completed": - response_data["result"] = task_info["result"] - - # 如果任务失败,返回错误信息 - if status == "failed": - response_data["error"] = task_info["error"] - - return Response().ok(response_data).__dict__ - except Exception as e: - logger.error(f"获取任务进度失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取任务进度失败: {e!s}").__dict__ - - async def download_backup(self): - """下载备份文件 - - Query 参数: - - filename: 备份文件名 (必填) - - token: JWT token (必填,用于浏览器原生下载鉴权) - - 注意: 此路由已被添加到 auth_middleware 白名单中, - 使用 URL 参数中的 token 进行鉴权,以支持浏览器原生下载。 - """ - try: - filename = request.args.get("filename") - token = request.args.get("token") - - if not filename: - return Response().error("缺少参数 filename").__dict__ - - if not token: - return Response().error("缺少参数 token").__dict__ - - # 验证 JWT token - try: - jwt_secret = self.config.get("dashboard", {}).get("jwt_secret") - if not jwt_secret: - return Response().error("服务器配置错误").__dict__ - - # Verify JWT token with strict security options - jwt.decode( - token, - jwt_secret, - algorithms=["HS256"], - options={ - "require": ["exp"], # Require expiration claim - "verify_signature": True, # Explicitly verify signature - "verify_exp": True, # Verify expiration - }, - ) - except jwt.ExpiredSignatureError: - return Response().error("Token 已过期,请刷新页面后重试").__dict__ - except jwt.InvalidTokenError: - return Response().error("Token 无效").__dict__ - - # 安全检查 - 防止路径遍历 - if ".." in filename or "/" in filename or "\\" in filename: - return Response().error("无效的文件名").__dict__ - - file_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(file_path): - return Response().error("备份文件不存在").__dict__ - - return await send_file( - file_path, - as_attachment=True, - attachment_filename=filename, - conditional=True, # 启用 Range 请求支持(断点续传) - ) - except Exception as e: - logger.error(f"下载备份失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"下载备份失败: {e!s}").__dict__ - - async def delete_backup(self): - """删除备份文件 - - Body: - - filename: 备份文件名 (必填) - """ - try: - data = await request.json - filename = data.get("filename") - if not filename: - return Response().error("缺少参数 filename").__dict__ - - # 安全检查 - 防止路径遍历 - if ".." in filename or "/" in filename or "\\" in filename: - return Response().error("无效的文件名").__dict__ - - file_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(file_path): - return Response().error("备份文件不存在").__dict__ - - os.remove(file_path) - return Response().ok(message="删除备份成功").__dict__ - except Exception as e: - logger.error(f"删除备份失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除备份失败: {e!s}").__dict__ - - async def rename_backup(self): - """重命名备份文件 - - Body: - - filename: 当前文件名 (必填) - - new_name: 新文件名 (必填,不含扩展名) - """ - try: - data = await request.json - filename = data.get("filename") - new_name = data.get("new_name") - - if not filename: - return Response().error("缺少参数 filename").__dict__ - - if not new_name: - return Response().error("缺少参数 new_name").__dict__ - - # 安全检查 - 防止路径遍历 - if ".." in filename or "/" in filename or "\\" in filename: - return Response().error("无效的文件名").__dict__ - - # 清洗新文件名(移除路径和危险字符) - new_name = secure_filename(new_name) - - # 移除新文件名中的扩展名(如果有的话) - if new_name.endswith(".zip"): - new_name = new_name[:-4] - - # 验证新文件名不为空 - if not new_name or new_name.replace("_", "") == "": - return Response().error("新文件名无效").__dict__ - - # 强制使用 .zip 扩展名 - new_filename = f"{new_name}.zip" - - # 检查原文件是否存在 - old_path = os.path.join(self.backup_dir, filename) - if not os.path.exists(old_path): - return Response().error("备份文件不存在").__dict__ - - # 检查新文件名是否已存在 - new_path = os.path.join(self.backup_dir, new_filename) - if os.path.exists(new_path): - return Response().error(f"文件名 '{new_filename}' 已存在").__dict__ - - # 执行重命名 - os.rename(old_path, new_path) - - logger.info(f"备份文件重命名: {filename} -> {new_filename}") - - return ( - Response() - .ok( - { - "old_filename": filename, - "new_filename": new_filename, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"重命名备份失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"重命名备份失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/chatui_project.py b/astrbot/dashboard/routes/chatui_project.py deleted file mode 100644 index 6ba570f552..0000000000 --- a/astrbot/dashboard/routes/chatui_project.py +++ /dev/null @@ -1,246 +0,0 @@ -from quart import g, request - -from astrbot.core.db import BaseDatabase -from astrbot.core.utils.datetime_utils import to_utc_isoformat - -from .route import Response, Route, RouteContext - - -class ChatUIProjectRoute(Route): - def __init__(self, context: RouteContext, db: BaseDatabase) -> None: - super().__init__(context) - self.routes = { - "/chatui_project/create": ("POST", self.create_project), - "/chatui_project/list": ("GET", self.list_projects), - "/chatui_project/get": ("GET", self.get_project), - "/chatui_project/update": ("POST", self.update_chatui_project), - "/chatui_project/delete": ("GET", self.delete_project), - "/chatui_project/add_session": ("POST", self.add_session_to_project), - "/chatui_project/remove_session": ( - "POST", - self.remove_session_from_project, - ), - "/chatui_project/get_sessions": ("GET", self.get_project_sessions), - } - self.db = db - self.register_routes() - - async def create_project(self): - """Create a new ChatUI project.""" - username = g.get("username", "guest") - post_data = await request.json - - title = post_data.get("title") - emoji = post_data.get("emoji", "📁") - description = post_data.get("description") - - if not title: - return Response().error("Missing key: title").__dict__ - - project = await self.db.create_chatui_project( - creator=username, - title=title, - emoji=emoji, - description=description, - ) - - return ( - Response() - .ok( - data={ - "project_id": project.project_id, - "title": project.title, - "emoji": project.emoji, - "description": project.description, - "created_at": to_utc_isoformat(project.created_at), - "updated_at": to_utc_isoformat(project.updated_at), - } - ) - .__dict__ - ) - - async def list_projects(self): - """Get all ChatUI projects for the current user.""" - username = g.get("username", "guest") - - projects = await self.db.get_chatui_projects_by_creator(creator=username) - - projects_data = [ - { - "project_id": project.project_id, - "title": project.title, - "emoji": project.emoji, - "description": project.description, - "created_at": to_utc_isoformat(project.created_at), - "updated_at": to_utc_isoformat(project.updated_at), - } - for project in projects - ] - - return Response().ok(data=projects_data).__dict__ - - async def get_project(self): - """Get a specific ChatUI project.""" - project_id = request.args.get("project_id") - if not project_id: - return Response().error("Missing key: project_id").__dict__ - - username = g.get("username", "guest") - - project = await self.db.get_chatui_project_by_id(project_id) - if not project: - return Response().error(f"Project {project_id} not found").__dict__ - - # Verify ownership - if project.creator != username: - return Response().error("Permission denied").__dict__ - - return ( - Response() - .ok( - data={ - "project_id": project.project_id, - "title": project.title, - "emoji": project.emoji, - "description": project.description, - "created_at": to_utc_isoformat(project.created_at), - "updated_at": to_utc_isoformat(project.updated_at), - } - ) - .__dict__ - ) - - async def update_chatui_project(self): - """Update a ChatUI project.""" - post_data = await request.json - - project_id = post_data.get("project_id") - title = post_data.get("title") - emoji = post_data.get("emoji") - description = post_data.get("description") - - if not project_id: - return Response().error("Missing key: project_id").__dict__ - - username = g.get("username", "guest") - - # Verify ownership - project = await self.db.get_chatui_project_by_id(project_id) - if not project: - return Response().error(f"Project {project_id} not found").__dict__ - if project.creator != username: - return Response().error("Permission denied").__dict__ - - await self.db.update_chatui_project( - project_id=project_id, - title=title, - emoji=emoji, - description=description, - ) - - return Response().ok().__dict__ - - async def delete_project(self): - """Delete a ChatUI project.""" - project_id = request.args.get("project_id") - if not project_id: - return Response().error("Missing key: project_id").__dict__ - - username = g.get("username", "guest") - - # Verify ownership - project = await self.db.get_chatui_project_by_id(project_id) - if not project: - return Response().error(f"Project {project_id} not found").__dict__ - if project.creator != username: - return Response().error("Permission denied").__dict__ - - await self.db.delete_chatui_project(project_id) - - return Response().ok().__dict__ - - async def add_session_to_project(self): - """Add a session to a project.""" - post_data = await request.json - - session_id = post_data.get("session_id") - project_id = post_data.get("project_id") - - if not session_id: - return Response().error("Missing key: session_id").__dict__ - if not project_id: - return Response().error("Missing key: project_id").__dict__ - - username = g.get("username", "guest") - - # Verify project ownership - project = await self.db.get_chatui_project_by_id(project_id) - if not project: - return Response().error(f"Project {project_id} not found").__dict__ - if project.creator != username: - return Response().error("Permission denied").__dict__ - - # Verify session ownership - session = await self.db.get_platform_session_by_id(session_id) - if not session: - return Response().error(f"Session {session_id} not found").__dict__ - if session.creator != username: - return Response().error("Permission denied").__dict__ - - await self.db.add_session_to_project(session_id, project_id) - - return Response().ok().__dict__ - - async def remove_session_from_project(self): - """Remove a session from its project.""" - post_data = await request.json - - session_id = post_data.get("session_id") - - if not session_id: - return Response().error("Missing key: session_id").__dict__ - - username = g.get("username", "guest") - - # Verify session ownership - session = await self.db.get_platform_session_by_id(session_id) - if not session: - return Response().error(f"Session {session_id} not found").__dict__ - if session.creator != username: - return Response().error("Permission denied").__dict__ - - await self.db.remove_session_from_project(session_id) - - return Response().ok().__dict__ - - async def get_project_sessions(self): - """Get all sessions in a project.""" - project_id = request.args.get("project_id") - if not project_id: - return Response().error("Missing key: project_id").__dict__ - - username = g.get("username", "guest") - - # Verify project ownership - project = await self.db.get_chatui_project_by_id(project_id) - if not project: - return Response().error(f"Project {project_id} not found").__dict__ - if project.creator != username: - return Response().error("Permission denied").__dict__ - - sessions = await self.db.get_project_sessions(project_id) - - sessions_data = [ - { - "session_id": session.session_id, - "platform_id": session.platform_id, - "creator": session.creator, - "display_name": session.display_name, - "is_group": session.is_group, - "created_at": to_utc_isoformat(session.created_at), - "updated_at": to_utc_isoformat(session.updated_at), - } - for session in sessions - ] - - return Response().ok(data=sessions_data).__dict__ diff --git a/astrbot/dashboard/routes/command.py b/astrbot/dashboard/routes/command.py deleted file mode 100644 index 1921fa4a44..0000000000 --- a/astrbot/dashboard/routes/command.py +++ /dev/null @@ -1,117 +0,0 @@ -from quart import request - -from astrbot.core.star.command_management import ( - list_command_conflicts, - list_commands, -) -from astrbot.core.star.command_management import ( - rename_command as rename_command_service, -) -from astrbot.core.star.command_management import ( - toggle_command as toggle_command_service, -) -from astrbot.core.star.command_management import ( - update_command_permission as update_command_permission_service, -) - -from .route import Response, Route, RouteContext - - -class CommandRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle=None) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - self.routes = { - "/commands": ("GET", self.get_commands), - "/commands/conflicts": ("GET", self.get_conflicts), - "/commands/toggle": ("POST", self.toggle_command), - "/commands/rename": ("POST", self.rename_command), - "/commands/permission": ("POST", self.update_permission), - } - self.register_routes() - - async def get_commands(self): - commands = await list_commands() - summary = { - "total": len(commands), - "disabled": len([cmd for cmd in commands if not cmd["enabled"]]), - "conflicts": len([cmd for cmd in commands if cmd.get("has_conflict")]), - } - # 优先从指定 config_id 的配置中读取唤醒词,否则使用默认配置 - config_id = request.args.get("config_id", "").strip() - wake_prefix = self.config.get("wake_prefix", ["/"]) - if config_id and self.core_lifecycle: - acm = getattr(self.core_lifecycle, "astrbot_config_mgr", None) - if acm and config_id in acm.confs: - wake_prefix = acm.confs[config_id].get("wake_prefix", wake_prefix) - return ( - Response() - .ok({"items": commands, "summary": summary, "wake_prefix": wake_prefix}) - .__dict__ - ) - - async def get_conflicts(self): - conflicts = await list_command_conflicts() - return Response().ok(conflicts).__dict__ - - async def toggle_command(self): - data = await request.get_json() - handler_full_name = data.get("handler_full_name") - enabled = data.get("enabled") - - if handler_full_name is None or enabled is None: - return Response().error("handler_full_name 与 enabled 均为必填。").__dict__ - - if isinstance(enabled, str): - enabled = enabled.lower() in ("1", "true", "yes", "on") - - try: - await toggle_command_service(handler_full_name, bool(enabled)) - except ValueError as exc: - return Response().error(str(exc)).__dict__ - - payload = await _get_command_payload(handler_full_name) - return Response().ok(payload).__dict__ - - async def rename_command(self): - data = await request.get_json() - handler_full_name = data.get("handler_full_name") - new_name = data.get("new_name") - aliases = data.get("aliases") - - if not handler_full_name or not new_name: - return Response().error("handler_full_name 与 new_name 均为必填。").__dict__ - - try: - await rename_command_service(handler_full_name, new_name, aliases=aliases) - except ValueError as exc: - return Response().error(str(exc)).__dict__ - - payload = await _get_command_payload(handler_full_name) - return Response().ok(payload).__dict__ - - async def update_permission(self): - data = await request.get_json() - handler_full_name = data.get("handler_full_name") - permission = data.get("permission") - - if not handler_full_name or not permission: - return ( - Response().error("handler_full_name 与 permission 均为必填。").__dict__ - ) - - try: - await update_command_permission_service(handler_full_name, permission) - except ValueError as exc: - return Response().error(str(exc)).__dict__ - - payload = await _get_command_payload(handler_full_name) - return Response().ok(payload).__dict__ - - -async def _get_command_payload(handler_full_name: str): - commands = await list_commands() - for cmd in commands: - if cmd["handler_full_name"] == handler_full_name: - return cmd - return {} diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py deleted file mode 100644 index aebe26047c..0000000000 --- a/astrbot/dashboard/routes/config.py +++ /dev/null @@ -1,1657 +0,0 @@ -import asyncio -import copy -import inspect -import os -import traceback -from pathlib import Path -from typing import Any - -from quart import jsonify, make_response, request - -from astrbot.core import astrbot_config, file_token_service, logger -from astrbot.core.config.astrbot_config import AstrBotConfig -from astrbot.core.config.default import ( - CONFIG_METADATA_2, - CONFIG_METADATA_3, - CONFIG_METADATA_3_SYSTEM, - DEFAULT_CONFIG, - DEFAULT_VALUE_MAP, -) -from astrbot.core.config.i18n_utils import ConfigMetadataI18n -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.platform.register import platform_cls_map, platform_registry -from astrbot.core.provider import Provider -from astrbot.core.provider.register import provider_registry -from astrbot.core.star.star import StarMetadata, star_registry -from astrbot.core.utils.astrbot_path import ( - get_astrbot_plugin_data_path, -) -from astrbot.core.utils.llm_metadata import LLM_METADATAS -from astrbot.core.utils.totp import ( - TwoFactorCodeType, - is_totp_enabled, - revoke_user_trusted_devices, - set_pending_totp_secret, - verify_configured_2fa_code, -) -from astrbot.core.utils.webhook_utils import ensure_platform_webhook_config - -from .route import Response, Route, RouteContext -from .util import ( - config_key_to_folder, - get_schema_item, - normalize_rel_path, - sanitize_filename, -) - -MAX_FILE_BYTES = 500 * 1024 * 1024 -PROTECTED_2FA_CONFIG_PATHS = ( - ("dashboard", "totp", "enable"), - ("dashboard", "totp", "secret"), - ("dashboard", "totp", "recovery_code_hash"), -) -TWO_FACTOR_CODE_HEADER = "X-2FA-Code" - - -def try_cast(value: Any, type_: str): - if type_ == "int": - try: - return int(value) - except (ValueError, TypeError): - return None - elif ( - type_ == "float" - and isinstance(value, str) - and value.replace(".", "", 1).isdigit() - ) or (type_ == "float" and isinstance(value, int)): - return float(value) - elif type_ == "float": - try: - return float(value) - except (ValueError, TypeError): - return None - - -def _expect_type(value, expected_type, path_key, errors, expected_name=None) -> bool: - if not isinstance(value, expected_type): - errors.append( - f"错误的类型 {path_key}: 期望是 {expected_name or expected_type.__name__}, " - f"得到了 {type(value).__name__}" - ) - return False - return True - - -def _validate_template_list(value, meta, path_key, errors, validate_fn) -> None: - if not _expect_type(value, list, path_key, errors, "list"): - return - - templates = meta.get("templates") - if not isinstance(templates, dict): - templates = {} - - for idx, item in enumerate(value): - item_path = f"{path_key}[{idx}]" - if not _expect_type(item, dict, item_path, errors, "dict"): - continue - - template_key = item.get("__template_key") or item.get("template") - if not template_key: - errors.append(f"缺少模板选择 {item_path}: 需要 __template_key") - continue - - template_meta = templates.get(template_key) - if not template_meta: - errors.append(f"未知模板 {item_path}: {template_key}") - continue - - validate_fn( - item, - template_meta.get("items", {}), - path=f"{path_key}.templates.{template_key}.", - ) - - -def validate_config(data, schema: dict, is_core: bool) -> tuple[list[str], dict]: - errors = [] - - def validate(data: dict, metadata: dict = schema, path="") -> None: - for key, value in data.items(): - if key not in metadata: - continue - meta = metadata[key] - if "type" not in meta: - logger.debug(f"配置项 {path}{key} 没有类型定义, 跳过校验") - continue - # null 转换 - if value is None: - data[key] = DEFAULT_VALUE_MAP[meta["type"]] - continue - - if meta["type"] == "template_list": - _validate_template_list(value, meta, f"{path}{key}", errors, validate) - continue - - if meta["type"] == "file": - if not _expect_type(value, list, f"{path}{key}", errors, "list"): - continue - for idx, item in enumerate(value): - if not isinstance(item, str): - errors.append( - f"Invalid type {path}{key}[{idx}]: expected string, got {type(item).__name__}", - ) - continue - normalized = normalize_rel_path(item) - if not normalized or not normalized.startswith("files/"): - errors.append( - f"Invalid file path {path}{key}[{idx}]: {item}", - ) - continue - key_path = f"{path}{key}" - expected_folder = config_key_to_folder(key_path) - expected_prefix = f"files/{expected_folder}/" - if not normalized.startswith(expected_prefix): - errors.append( - f"Invalid file path {path}{key}[{idx}]: {item}", - ) - continue - value[idx] = normalized - continue - - if meta["type"] == "list" and not isinstance(value, list): - errors.append( - f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}", - ) - elif ( - meta["type"] == "list" - and isinstance(value, list) - and value - and "items" in meta - and isinstance(value[0], dict) - ): - # 当前仅针对 list[dict] 的情况进行类型校验,以适配 AstrBot 中 platform、provider 的配置 - for item in value: - validate(item, meta["items"], path=f"{path}{key}.") - elif meta["type"] == "object" and isinstance(value, dict): - validate(value, meta["items"], path=f"{path}{key}.") - - if meta["type"] == "int" and not isinstance(value, int): - casted = try_cast(value, "int") - if casted is None: - errors.append( - f"错误的类型 {path}{key}: 期望是 int, 得到了 {type(value).__name__}", - ) - data[key] = casted - elif meta["type"] == "float" and not isinstance(value, float): - casted = try_cast(value, "float") - if casted is None: - errors.append( - f"错误的类型 {path}{key}: 期望是 float, 得到了 {type(value).__name__}", - ) - data[key] = casted - elif meta["type"] == "bool" and not isinstance(value, bool): - errors.append( - f"错误的类型 {path}{key}: 期望是 bool, 得到了 {type(value).__name__}", - ) - elif meta["type"] in ["string", "text"] and not isinstance(value, str): - errors.append( - f"错误的类型 {path}{key}: 期望是 string, 得到了 {type(value).__name__}", - ) - elif meta["type"] == "list" and not isinstance(value, list): - errors.append( - f"错误的类型 {path}{key}: 期望是 list, 得到了 {type(value).__name__}", - ) - elif meta["type"] == "object" and not isinstance(value, dict): - errors.append( - f"错误的类型 {path}{key}: 期望是 dict, 得到了 {type(value).__name__}", - ) - - if is_core: - meta_all = { - **schema["platform_group"]["metadata"], - **schema["provider_group"]["metadata"], - **schema["misc_config_group"]["metadata"], - } - validate(data, meta_all) - else: - validate(data, schema) - - return errors, data - - -def _log_computer_config_changes(old_config: dict, new_config: dict) -> None: - """Compare and log Computer/sandbox configuration changes.""" - old_ps = old_config.get("provider_settings", {}) - new_ps = new_config.get("provider_settings", {}) - - # Check computer_use_runtime - old_runtime = old_ps.get("computer_use_runtime", "none") - new_runtime = new_ps.get("computer_use_runtime", "none") - if old_runtime != new_runtime: - logger.info( - "[Computer] Config changed: computer_use_runtime %s -> %s", - old_runtime, - new_runtime, - ) - - # Check sandbox sub-keys - old_sandbox = old_ps.get("sandbox", {}) - new_sandbox = new_ps.get("sandbox", {}) - all_keys = set(old_sandbox.keys()) | set(new_sandbox.keys()) - for key in sorted(all_keys): - old_val = old_sandbox.get(key) - new_val = new_sandbox.get(key) - if old_val != new_val: - # Mask tokens/secrets in log output - if "token" in key or "secret" in key: - old_display = "***" if old_val else "(empty)" - new_display = "***" if new_val else "(empty)" - else: - old_display = old_val - new_display = new_val - logger.info( - "[Computer] Config changed: sandbox.%s %s -> %s", - key, - old_display, - new_display, - ) - - -def _get_nested_value(data: dict, path: tuple[str, ...]) -> Any: - current = data - for key in path: - if not isinstance(current, dict) or key not in current: - return None - current = current[key] - return current - - -def _set_nested_value(data: dict, path: tuple[str, ...], value: Any) -> None: - current = data - for key in path[:-1]: - next_value = current.get(key) - if not isinstance(next_value, dict): - next_value = {} - current[key] = next_value - current = next_value - current[path[-1]] = value - - -def _protected_2fa_config_changed(old_config: dict, new_config: dict) -> bool: - return any( - _get_nested_value(old_config, path) != _get_nested_value(new_config, path) - for path in PROTECTED_2FA_CONFIG_PATHS - ) - - -async def _validate_neo_connectivity( - post_config: dict, -) -> str | None: - """Check if Bay is reachable when Shipyard Neo sandbox is configured. - - Returns a warning message string if Bay isn't reachable, or None if - everything looks fine (or Neo isn't configured). - """ - ps = post_config.get("provider_settings", {}) - runtime = ps.get("computer_use_runtime", "none") - sandbox = ps.get("sandbox", {}) - booter = sandbox.get("booter", "") - - # Only check when sandbox mode + shipyard_neo is selected - if runtime != "sandbox" or booter != "shipyard_neo": - return None - - endpoint = sandbox.get("shipyard_neo_endpoint", "").rstrip("/") - if not endpoint: - return "⚠️ Shipyard Neo endpoint 未设置" - - access_token = sandbox.get("shipyard_neo_access_token", "") - if not access_token: - # Try auto-discovery - from astrbot.core.computer.computer_client import _discover_bay_credentials - - access_token = _discover_bay_credentials(endpoint) - - if not access_token: - return ( - "⚠️ 未找到 Bay API Key。请填写访问令牌," - "或确保 Bay 的 credentials.json 可被自动发现。" - ) - - # Connectivity check - import aiohttp - - health_url = f"{endpoint}/health" - try: - async with aiohttp.ClientSession() as session: - async with session.get( - health_url, - timeout=aiohttp.ClientTimeout(total=5), - ) as resp: - if resp.status != 200: - return ( - f"⚠️ Bay 健康检查失败 (HTTP {resp.status})," - f"请确认 Bay 正在运行: {endpoint}" - ) - except Exception: - return f"⚠️ 无法连接 Bay ({endpoint}),请确认 Bay 已启动。" - - return None - - -def save_config( - post_config: dict, config: AstrBotConfig, is_core: bool = False -) -> None: - """验证并保存配置""" - errors = None - - # Snapshot old Computer config for change detection - if is_core: - _log_computer_config_changes(dict(config), post_config) - - try: - if is_core: - errors, post_config = validate_config( - post_config, - CONFIG_METADATA_2, - is_core, - ) - else: - errors, post_config = validate_config( - post_config, getattr(config, "schema", {}), is_core - ) - except BaseException as e: - logger.error(traceback.format_exc()) - logger.warning(f"验证配置时出现异常: {e}") - raise ValueError(f"验证配置时出现异常: {e}") - if errors: - raise ValueError(f"格式校验未通过: {errors}") - - config.save_config(post_config) - - -class ConfigRoute(Route): - def __init__( - self, - context: RouteContext, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - self.config: AstrBotConfig = core_lifecycle.astrbot_config - self.db = core_lifecycle.db - self._logo_token_cache = {} # 缓存logo token,避免重复注册 - self.acm = core_lifecycle.astrbot_config_mgr - self.ucr = core_lifecycle.umop_config_router - self.routes = { - "/config/abconf/new": ("POST", self.create_abconf), - "/config/abconf": ("GET", self.get_abconf), - "/config/abconfs": ("GET", self.get_abconf_list), - "/config/abconf/delete": ("POST", self.delete_abconf), - "/config/abconf/update": ("POST", self.update_abconf), - "/config/umo_abconf_routes": ("GET", self.get_uc_table), - "/config/umo_abconf_route/update_all": ("POST", self.update_ucr_all), - "/config/umo_abconf_route/update": ("POST", self.update_ucr), - "/config/umo_abconf_route/delete": ("POST", self.delete_ucr), - "/config/get": ("GET", self.get_configs), - "/config/default": ("GET", self.get_default_config), - "/config/astrbot/update": ("POST", self.post_astrbot_configs), - "/config/plugin/update": ("POST", self.post_plugin_configs), - "/config/file/upload": ("POST", self.upload_config_file), - "/config/file/delete": ("POST", self.delete_config_file), - "/config/file/get": ("GET", self.get_config_file_list), - "/config/platform/new": ("POST", self.post_new_platform), - "/config/platform/update": ("POST", self.post_update_platform), - "/config/platform/delete": ("POST", self.post_delete_platform), - "/config/platform/list": ("GET", self.get_platform_list), - "/config/provider/new": ("POST", self.post_new_provider), - "/config/provider/update": ("POST", self.post_update_provider), - "/config/provider/delete": ("POST", self.post_delete_provider), - "/config/provider/template": ("GET", self.get_provider_template), - "/config/provider/check_one": ("GET", self.check_one_provider_status), - "/config/provider/list": ("GET", self.get_provider_config_list), - "/config/provider/model_list": ("GET", self.get_provider_model_list), - "/config/provider/get_embedding_dim": ("POST", self.get_embedding_dim), - "/config/provider_sources/models": ( - "GET", - self.get_provider_source_models, - ), - "/config/provider_sources/update": ( - "POST", - self.update_provider_source, - ), - "/config/provider_sources/delete": ( - "POST", - self.delete_provider_source, - ), - } - self.register_routes() - - async def delete_provider_source(self): - """删除 provider_source,并更新关联的 providers""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - provider_source_id = post_data.get("id") - if not provider_source_id: - return Response().error("缺少 provider_source_id").__dict__ - - provider_sources = self.config.get("provider_sources", []) - target_idx = next( - ( - i - for i, ps in enumerate(provider_sources) - if ps.get("id") == provider_source_id - ), - -1, - ) - - if target_idx == -1: - return Response().error("未找到对应的 provider source").__dict__ - - # 删除 provider_source - del provider_sources[target_idx] - - # 写回配置 - self.config["provider_sources"] = provider_sources - - # 删除引用了该 provider_source 的 providers - await self.core_lifecycle.provider_manager.delete_provider( - provider_source_id=provider_source_id - ) - - try: - save_config(self.config, self.config, is_core=True) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - return Response().ok(message="删除 provider source 成功").__dict__ - - async def update_provider_source(self): - """更新或新增 provider_source,并重载关联的 providers""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - new_source_config = post_data.get("config") or post_data - original_id = post_data.get("original_id") - if not original_id: - return Response().error("缺少 original_id").__dict__ - - if not isinstance(new_source_config, dict): - return Response().error("缺少或错误的配置数据").__dict__ - - # 确保配置中有 id 字段 - if not new_source_config.get("id"): - new_source_config["id"] = original_id - - provider_sources = self.config.get("provider_sources", []) - - for ps in provider_sources: - if ps.get("id") == new_source_config["id"] and ps.get("id") != original_id: - return ( - Response() - .error( - f"Provider source ID '{new_source_config['id']}' exists already, please try another ID.", - ) - .__dict__ - ) - - # 查找旧的 provider_source,若不存在则追加为新配置 - target_idx = next( - (i for i, ps in enumerate(provider_sources) if ps.get("id") == original_id), - -1, - ) - - old_id = original_id - if target_idx == -1: - provider_sources.append(new_source_config) - else: - old_id = provider_sources[target_idx].get("id") - provider_sources[target_idx] = new_source_config - - # 更新引用了该 provider_source 的 providers - affected_providers = [] - for provider in self.config.get("provider", []): - if provider.get("provider_source_id") == old_id: - provider["provider_source_id"] = new_source_config["id"] - affected_providers.append(provider) - - # 写回配置 - self.config["provider_sources"] = provider_sources - - try: - save_config(self.config, self.config, is_core=True) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - # 重载受影响的 providers,使新的 source 配置生效 - reload_errors = [] - prov_mgr = self.core_lifecycle.provider_manager - for provider in affected_providers: - try: - await prov_mgr.reload(provider) - except Exception as e: - logger.error(traceback.format_exc()) - reload_errors.append(f"{provider.get('id')}: {e}") - - if reload_errors: - return ( - Response() - .error("更新成功,但部分提供商重载失败: " + ", ".join(reload_errors)) - .__dict__ - ) - - return Response().ok(message="更新 provider source 成功").__dict__ - - async def get_provider_template(self): - provider_metadata = ConfigMetadataI18n.convert_to_i18n_keys( - { - "provider_group": { - "metadata": { - "provider": CONFIG_METADATA_2["provider_group"]["metadata"][ - "provider" - ] - } - } - } - ) - config_schema = { - "provider": provider_metadata["provider_group"]["metadata"]["provider"] - } - data = { - "config_schema": config_schema, - "providers": astrbot_config["provider"], - "provider_sources": astrbot_config["provider_sources"], - } - return Response().ok(data=data).__dict__ - - async def get_uc_table(self): - """获取 UMOP 配置路由表""" - return Response().ok({"routing": self.ucr.umop_to_conf_id}).__dict__ - - async def update_ucr_all(self): - """更新 UMOP 配置路由表的全部内容""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - new_routing = post_data.get("routing", None) - - if not new_routing or not isinstance(new_routing, dict): - return Response().error("缺少或错误的路由表数据").__dict__ - - try: - await self.ucr.update_routing_data(new_routing) - return Response().ok(message="更新成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"更新路由表失败: {e!s}").__dict__ - - async def update_ucr(self): - """更新 UMOP 配置路由表""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - umo = post_data.get("umo", None) - conf_id = post_data.get("conf_id", None) - - if not umo or not conf_id: - return Response().error("缺少 UMO 或配置文件 ID").__dict__ - - try: - await self.ucr.update_route(umo, conf_id) - return Response().ok(message="更新成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"更新路由表失败: {e!s}").__dict__ - - async def delete_ucr(self): - """删除 UMOP 配置路由表中的一项""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - umo = post_data.get("umo", None) - - if not umo: - return Response().error("缺少 UMO").__dict__ - - try: - if umo in self.ucr.umop_to_conf_id: - del self.ucr.umop_to_conf_id[umo] - await self.ucr.update_routing_data(self.ucr.umop_to_conf_id) - return Response().ok(message="删除成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"删除路由表项失败: {e!s}").__dict__ - - async def get_default_config(self): - """获取默认配置文件""" - metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3) - return Response().ok({"config": DEFAULT_CONFIG, "metadata": metadata}).__dict__ - - async def get_abconf_list(self): - """获取所有 AstrBot 配置文件的列表""" - abconf_list = self.acm.get_conf_list() - return Response().ok({"info_list": abconf_list}).__dict__ - - async def create_abconf(self): - """创建新的 AstrBot 配置文件""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - name = post_data.get("name", None) - config = post_data.get("config", DEFAULT_CONFIG) - - try: - conf_id = self.acm.create_conf(name=name, config=config) - await self.core_lifecycle.reload_pipeline_scheduler(conf_id) - return Response().ok(message="创建成功", data={"conf_id": conf_id}).__dict__ - except ValueError as e: - return Response().error(str(e)).__dict__ - - async def get_abconf(self): - """获取指定 AstrBot 配置文件""" - abconf_id = request.args.get("id") - system_config = request.args.get("system_config", "0").lower() == "1" - if not abconf_id and not system_config: - return Response().error("缺少配置文件 ID").__dict__ - - try: - if system_config: - abconf = self.acm.confs["default"] - metadata = ConfigMetadataI18n.convert_to_i18n_keys( - CONFIG_METADATA_3_SYSTEM - ) - return Response().ok({"config": abconf, "metadata": metadata}).__dict__ - if abconf_id is None: - raise ValueError("abconf_id cannot be None") - abconf = self.acm.confs[abconf_id] - metadata = ConfigMetadataI18n.convert_to_i18n_keys(CONFIG_METADATA_3) - return Response().ok({"config": abconf, "metadata": metadata}).__dict__ - except ValueError as e: - return Response().error(str(e)).__dict__ - - async def delete_abconf(self): - """删除指定 AstrBot 配置文件""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - conf_id = post_data.get("id") - if not conf_id: - return Response().error("缺少配置文件 ID").__dict__ - - try: - success = self.acm.delete_conf(conf_id) - if success: - self.core_lifecycle.pipeline_scheduler_mapping.pop(conf_id, None) - return Response().ok(message="删除成功").__dict__ - return Response().error("删除失败").__dict__ - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"删除配置文件失败: {e!s}").__dict__ - - async def update_abconf(self): - """更新指定 AstrBot 配置文件信息""" - post_data = await request.json - if not post_data: - return Response().error("缺少配置数据").__dict__ - - conf_id = post_data.get("id") - if not conf_id: - return Response().error("缺少配置文件 ID").__dict__ - - name = post_data.get("name") - - try: - success = self.acm.update_conf_info(conf_id, name=name) - if success: - return Response().ok(message="更新成功").__dict__ - return Response().error("更新失败").__dict__ - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"更新配置文件失败: {e!s}").__dict__ - - async def _test_single_provider(self, provider): - """辅助函数:测试单个 provider 的可用性""" - meta = provider.meta() - provider_name = provider.provider_config.get("id", "Unknown Provider") - provider_capability_type = meta.provider_type - - status_info = { - "id": getattr(meta, "id", "Unknown ID"), - "model": getattr(meta, "model", "Unknown Model"), - "type": provider_capability_type.value, - "name": provider_name, - "status": "unavailable", # 默认为不可用 - "error": None, - } - logger.debug( - f"Attempting to check provider: {status_info['name']} (ID: {status_info['id']}, Type: {status_info['type']}, Model: {status_info['model']})", - ) - - try: - await provider.test() - status_info["status"] = "available" - logger.info( - f"Provider {status_info['name']} (ID: {status_info['id']}) is available.", - ) - except Exception as e: - error_message = str(e) - status_info["error"] = error_message - logger.warning( - f"Provider {status_info['name']} (ID: {status_info['id']}) is unavailable. Error: {error_message}", - ) - logger.debug( - f"Traceback for {status_info['name']}:\n{traceback.format_exc()}", - ) - - return status_info - - def _error_response( - self, - message: str, - status_code: int = 500, - log_fn=logger.error, - ): - log_fn(message) - # 记录更详细的traceback信息,但只在是严重错误时 - if status_code == 500: - log_fn(traceback.format_exc()) - return Response().error(message).__dict__ - - async def check_one_provider_status(self): - """API: check a single LLM Provider's status by id""" - provider_id = request.args.get("id") - if not provider_id: - return self._error_response( - "Missing provider_id parameter", - 400, - logger.warning, - ) - - logger.info(f"API call: /config/provider/check_one id={provider_id}") - try: - prov_mgr = self.core_lifecycle.provider_manager - target = prov_mgr.inst_map.get(provider_id) - - if not target: - logger.warning( - f"Provider with id '{provider_id}' not found in provider_manager.", - ) - return ( - Response() - .error(f"Provider with id '{provider_id}' not found") - .__dict__ - ) - - result = await self._test_single_provider(target) - return Response().ok(result).__dict__ - - except Exception as e: - return self._error_response( - f"Critical error checking provider {provider_id}: {e}", - 500, - ) - - async def get_configs(self): - # plugin_name 为空时返回 AstrBot 配置 - # 否则返回指定 plugin_name 的插件配置 - plugin_name = request.args.get("plugin_name", None) - if not plugin_name: - return Response().ok(await self._get_astrbot_config()).__dict__ - return Response().ok(await self._get_plugin_config(plugin_name)).__dict__ - - async def get_provider_config_list(self): - provider_type = request.args.get("provider_type", None) - if not provider_type: - return Response().error("缺少参数 provider_type").__dict__ - provider_type_ls = provider_type.split(",") - provider_list = [] - ps = self.core_lifecycle.provider_manager.providers_config - p_source_pt = { - psrc["id"]: psrc.get("provider_type", "chat_completion") - for psrc in self.core_lifecycle.provider_manager.provider_sources_config - } - for provider in ps: - ps_id = provider.get("provider_source_id", None) - if ( - ps_id - and ps_id in p_source_pt - and p_source_pt[ps_id] in provider_type_ls - ): - # chat - prov = self.core_lifecycle.provider_manager.get_merged_provider_config( - provider - ) - provider_list.append(prov) - elif not ps_id and provider.get("provider_type", "") in provider_type_ls: - # agent runner, embedding, etc - provider_list.append(provider) - return Response().ok(provider_list).__dict__ - - async def get_provider_model_list(self): - """获取指定提供商的模型列表""" - provider_id = request.args.get("provider_id", None) - if not provider_id: - return Response().error("缺少参数 provider_id").__dict__ - - prov_mgr = self.core_lifecycle.provider_manager - provider = prov_mgr.inst_map.get(provider_id, None) - if not provider: - return Response().error(f"未找到 ID 为 {provider_id} 的提供商").__dict__ - if not isinstance(provider, Provider): - return ( - Response() - .error(f"提供商 {provider_id} 类型不支持获取模型列表") - .__dict__ - ) - - try: - models = await provider.get_models() - models = models or [] - - metadata_map = {} - for model_id in models: - meta = LLM_METADATAS.get(model_id) - if meta: - metadata_map[model_id] = meta - - ret = { - "models": models, - "provider_id": provider_id, - "model_metadata": metadata_map, - } - return Response().ok(ret).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def get_embedding_dim(self): - """获取嵌入模型的维度""" - post_data = await request.json - provider_config = post_data.get("provider_config", None) - if not provider_config: - return Response().error("缺少参数 provider_config").__dict__ - - try: - # 动态导入 EmbeddingProvider - from astrbot.core.provider.provider import EmbeddingProvider - from astrbot.core.provider.register import provider_cls_map - - # 获取 provider 类型 - provider_type = provider_config.get("type", None) - if not provider_type: - return Response().error("provider_config 缺少 type 字段").__dict__ - - # 首次添加某类提供商时,provider_cls_map 可能尚未注册该适配器 - if provider_type not in provider_cls_map: - try: - self.core_lifecycle.provider_manager.dynamic_import_provider( - provider_type, - ) - except ImportError: - logger.error(traceback.format_exc()) - return ( - Response() - .error( - "提供商适配器加载失败,请检查提供商类型配置或查看服务端日志" - ) - .__dict__ - ) - - # 获取对应的 provider 类 - if provider_type not in provider_cls_map: - return ( - Response() - .error(f"未找到适用于 {provider_type} 的提供商适配器") - .__dict__ - ) - - provider_metadata = provider_cls_map[provider_type] - cls_type = provider_metadata.cls_type - - if not cls_type: - return Response().error(f"无法找到 {provider_type} 的类").__dict__ - - # 实例化 provider - inst = cls_type(provider_config, {}) - - # 检查是否是 EmbeddingProvider - if not isinstance(inst, EmbeddingProvider): - return Response().error("提供商不是 EmbeddingProvider 类型").__dict__ - - init_fn = getattr(inst, "initialize", None) - if inspect.iscoroutinefunction(init_fn): - await init_fn() - - # 通过实际请求验证当前 embedding_dimensions 是否可用 - vec = await inst.get_embedding("echo") - dim = len(vec) - - logger.info( - f"检测到 {provider_config.get('id', 'unknown')} 的嵌入向量维度为 {dim}", - ) - - return Response().ok({"embedding_dimensions": dim}).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"获取嵌入维度失败: {e!s}").__dict__ - - async def get_provider_source_models(self): - """获取指定 provider_source 支持的模型列表 - - 本质上会临时初始化一个 Provider 实例,调用 get_models() 获取模型列表,然后销毁实例 - """ - provider_source_id = request.args.get("source_id") - if not provider_source_id: - return Response().error("缺少参数 source_id").__dict__ - - try: - from astrbot.core.provider.register import provider_cls_map - - # 从配置中查找对应的 provider_source - provider_sources = self.config.get("provider_sources", []) - provider_source = None - for ps in provider_sources: - if ps.get("id") == provider_source_id: - provider_source = ps - break - - if not provider_source: - return ( - Response() - .error(f"未找到 ID 为 {provider_source_id} 的 provider_source") - .__dict__ - ) - - # 获取 provider 类型 - provider_type = provider_source.get("type", None) - if not provider_type: - return Response().error("provider_source 缺少 type 字段").__dict__ - - try: - self.core_lifecycle.provider_manager.dynamic_import_provider( - provider_type - ) - except ImportError as e: - logger.error(traceback.format_exc()) - return Response().error(f"动态导入提供商适配器失败: {e!s}").__dict__ - - # 获取对应的 provider 类 - if provider_type not in provider_cls_map: - return ( - Response() - .error(f"未找到适用于 {provider_type} 的提供商适配器") - .__dict__ - ) - - provider_metadata = provider_cls_map[provider_type] - cls_type = provider_metadata.cls_type - - if not cls_type: - return Response().error(f"无法找到 {provider_type} 的类").__dict__ - - # 检查是否是 Provider 类型 - if not issubclass(cls_type, Provider): - return ( - Response() - .error(f"提供商 {provider_type} 不支持获取模型列表") - .__dict__ - ) - - # 临时实例化 provider - inst = cls_type(provider_source, {}) - - # 如果有 initialize 方法,调用它 - init_fn = getattr(inst, "initialize", None) - if inspect.iscoroutinefunction(init_fn): - await init_fn() - - # 获取模型列表 - models = await inst.get_models() - models = models or [] - - metadata_map = {} - for model_id in models: - meta = LLM_METADATAS.get(model_id) - if meta: - metadata_map[model_id] = meta - - # 销毁实例(如果有 terminate 方法) - terminate_fn = getattr(inst, "terminate", None) - if inspect.iscoroutinefunction(terminate_fn): - await terminate_fn() - - return ( - Response() - .ok({"models": models, "model_metadata": metadata_map}) - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"获取模型列表失败: {e!s}").__dict__ - - async def get_platform_list(self): - """获取所有平台的列表""" - platform_list = [] - for platform in self.config["platform"]: - platform_list.append(platform) - return Response().ok({"platforms": platform_list}).__dict__ - - async def post_astrbot_configs(self): - data = await request.json - if not isinstance(data, dict): - return Response().error("Invalid request payload").__dict__ - config = data.get("config", None) - conf_id = data.get("conf_id", None) - - try: - if not isinstance(config, dict): - return Response().error("Invalid config payload").__dict__ - - if conf_id not in self.acm.confs: - raise ValueError(f"Config file {conf_id} does not exist") - - # 不更新 provider_sources, provider, platform - # 这些配置有单独的接口进行更新 - if conf_id == "default": - no_update_keys = ["provider_sources", "provider", "platform"] - for key in no_update_keys: - config[key] = self.acm.default_conf[key] - - current_config = self.acm.confs[conf_id] - protected_2fa_changed = _protected_2fa_config_changed( - current_config, config - ) - verified_2fa = None - if await self._requires_config_2fa(current_config, protected_2fa_changed): - verified_2fa = await self._verify_config_2fa(current_config) - if not verified_2fa: - return await self._config_2fa_required_response() - - if not _get_nested_value(config, ("dashboard", "totp", "enable")): - _set_nested_value(config, ("dashboard", "totp", "secret"), "") - _set_nested_value( - config, ("dashboard", "totp", "recovery_code_hash"), "" - ) - - set_pending_totp_secret(None) - await self._save_astrbot_configs(config, conf_id) - if protected_2fa_changed: - await revoke_user_trusted_devices(self.db) - await self.core_lifecycle.reload_pipeline_scheduler(conf_id) - - # Non-blocking Bay connectivity check - warning = await _validate_neo_connectivity(config) - if warning: - return Response().ok(None, f"保存成功。{warning}").__dict__ - return Response().ok(None, "保存成功~").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def _requires_config_2fa( - self, current_config: dict, protected_2fa_changed: bool - ) -> bool: - if not is_totp_enabled(current_config): - return False - if not protected_2fa_changed: - return False - return True - - async def _verify_config_2fa( - self, current_config: dict - ) -> TwoFactorCodeType | None: - code = request.headers.get(TWO_FACTOR_CODE_HEADER, "").strip() - if not code: - return None - - return await verify_configured_2fa_code( - current_config, code, include_pending=True, allow_recovery=False - ) - - async def _config_2fa_required_response(self): - response = await make_response( - jsonify( - { - "status": "error", - "data": { - "totp_required": True, - }, - } - ) - ) - response.status_code = 401 - return response - - async def post_plugin_configs(self): - post_configs = await request.json - plugin_name = request.args.get("plugin_name", "unknown") - try: - await self._save_plugin_configs(post_configs, plugin_name) - await self.core_lifecycle.plugin_manager.reload(plugin_name) - return ( - Response() - .ok(None, f"保存插件 {plugin_name} 成功~ 机器人正在热重载插件。") - .__dict__ - ) - except Exception as e: - return Response().error(str(e)).__dict__ - - def _get_plugin_metadata_by_name(self, plugin_name: str) -> StarMetadata | None: - for plugin_md in star_registry: - if plugin_md.name == plugin_name: - return plugin_md - return None - - def _resolve_config_file_scope( - self, - ) -> tuple[str, str, str, StarMetadata, AstrBotConfig]: - """将请求参数解析为一个明确的配置作用域。 - - 当前支持的 scope: - - scope=plugin:name=,key= - """ - - scope = request.args.get("scope") or "plugin" - name = request.args.get("name") - key_path = request.args.get("key") - - if scope != "plugin": - raise ValueError(f"Unsupported scope: {scope}") - if not name or not key_path: - raise ValueError("Missing name or key parameter") - - md = self._get_plugin_metadata_by_name(name) - if not md or not md.config: - raise ValueError(f"Plugin {name} not found or has no config") - - return scope, name, key_path, md, md.config - - async def upload_config_file(self): - """上传文件到插件数据目录(用于某个 file 类型配置项)。""" - - try: - scope, name, key_path, md, config = self._resolve_config_file_scope() - except ValueError as e: - return Response().error(str(e)).__dict__ - - meta = get_schema_item(getattr(config, "schema", None), key_path) - if not meta or meta.get("type") != "file": - return Response().error("Config item not found or not file type").__dict__ - - file_types = meta.get("file_types") - allowed_exts: list[str] = [] - if isinstance(file_types, list): - allowed_exts = [ - str(ext).lstrip(".").lower() for ext in file_types if str(ext).strip() - ] - - files = await request.files - if not files: - return Response().error("No files uploaded").__dict__ - - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) - plugin_root_path = (storage_root_path / name).resolve(strict=False) - try: - plugin_root_path.relative_to(storage_root_path) - except ValueError: - return Response().error("Invalid name parameter").__dict__ - plugin_root_path.mkdir(parents=True, exist_ok=True) - - uploaded: list[str] = [] - folder = config_key_to_folder(key_path) - errors: list[str] = [] - for file in files.values(): - filename = sanitize_filename(file.filename or "") - if not filename: - errors.append("Invalid filename") - continue - - file_size = getattr(file, "content_length", None) - if isinstance(file_size, int) and file_size > MAX_FILE_BYTES: - errors.append(f"File too large: {filename}") - continue - - ext = os.path.splitext(filename)[1].lstrip(".").lower() - if allowed_exts and ext not in allowed_exts: - errors.append(f"Unsupported file type: {filename}") - continue - - rel_path = f"files/{folder}/{filename}" - save_path = (plugin_root_path / rel_path).resolve(strict=False) - try: - save_path.relative_to(plugin_root_path) - except ValueError: - errors.append(f"Invalid path: {filename}") - continue - - save_path.parent.mkdir(parents=True, exist_ok=True) - await file.save(str(save_path)) - if save_path.is_file() and save_path.stat().st_size > MAX_FILE_BYTES: - save_path.unlink() - errors.append(f"File too large: {filename}") - continue - uploaded.append(rel_path) - - if not uploaded: - return ( - Response() - .error( - "Upload failed: " + ", ".join(errors) - if errors - else "Upload failed", - ) - .__dict__ - ) - - return Response().ok({"uploaded": uploaded, "errors": errors}).__dict__ - - async def delete_config_file(self): - """删除插件数据目录中的文件。""" - - scope = request.args.get("scope") or "plugin" - name = request.args.get("name") - if not name: - return Response().error("Missing name parameter").__dict__ - if scope != "plugin": - return Response().error(f"Unsupported scope: {scope}").__dict__ - - data = await request.get_json() - rel_path = data.get("path") if isinstance(data, dict) else None - rel_path = normalize_rel_path(rel_path) - if not rel_path or not rel_path.startswith("files/"): - return Response().error("Invalid path parameter").__dict__ - - md = self._get_plugin_metadata_by_name(name) - if not md: - return Response().error(f"Plugin {name} not found").__dict__ - - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) - plugin_root_path = (storage_root_path / name).resolve(strict=False) - try: - plugin_root_path.relative_to(storage_root_path) - except ValueError: - return Response().error("Invalid name parameter").__dict__ - target_path = (plugin_root_path / rel_path).resolve(strict=False) - try: - target_path.relative_to(plugin_root_path) - except ValueError: - return Response().error("Invalid path parameter").__dict__ - if target_path.is_file(): - target_path.unlink() - - return Response().ok(None, "Deleted").__dict__ - - async def get_config_file_list(self): - """获取配置项对应目录下的文件列表。""" - - try: - _, name, key_path, _, config = self._resolve_config_file_scope() - except ValueError as e: - return Response().error(str(e)).__dict__ - - meta = get_schema_item(getattr(config, "schema", None), key_path) - if not meta or meta.get("type") != "file": - return Response().error("Config item not found or not file type").__dict__ - - storage_root_path = Path(get_astrbot_plugin_data_path()).resolve(strict=False) - plugin_root_path = (storage_root_path / name).resolve(strict=False) - try: - plugin_root_path.relative_to(storage_root_path) - except ValueError: - return Response().error("Invalid name parameter").__dict__ - - folder = config_key_to_folder(key_path) - target_dir = (plugin_root_path / "files" / folder).resolve(strict=False) - try: - target_dir.relative_to(plugin_root_path) - except ValueError: - return Response().error("Invalid path parameter").__dict__ - - if not target_dir.exists() or not target_dir.is_dir(): - return Response().ok({"files": []}).__dict__ - - files: list[str] = [] - for path in target_dir.rglob("*"): - if not path.is_file(): - continue - try: - rel_path = path.relative_to(plugin_root_path).as_posix() - except ValueError: - continue - if rel_path.startswith("files/"): - files.append(rel_path) - - return Response().ok({"files": files}).__dict__ - - async def post_new_platform(self): - new_platform_config = await request.json - - # 如果是支持统一 webhook 模式的平台,生成 webhook_uuid - ensure_platform_webhook_config(new_platform_config) - - self.config["platform"].append(new_platform_config) - try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.platform_manager.load_platform( - new_platform_config, - ) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "新增平台配置成功~").__dict__ - - async def post_new_provider(self): - new_provider_config = await request.json - - try: - await self.core_lifecycle.provider_manager.create_provider( - new_provider_config - ) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "新增服务提供商配置成功").__dict__ - - async def post_update_platform(self): - update_platform_config = await request.json - origin_platform_id = update_platform_config.get("id", None) - new_config = update_platform_config.get("config", None) - if not origin_platform_id or not new_config: - return Response().error("参数错误").__dict__ - - if origin_platform_id != new_config.get("id", None): - return Response().error("机器人名称不允许修改").__dict__ - - # 如果是支持统一 webhook 模式的平台,且启用了统一 webhook 模式,确保有 webhook_uuid - ensure_platform_webhook_config(new_config) - - for i, platform in enumerate(self.config["platform"]): - if platform["id"] == origin_platform_id: - self.config["platform"][i] = new_config - break - else: - return Response().error("未找到对应平台").__dict__ - - try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.platform_manager.reload(new_config) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "更新平台配置成功~").__dict__ - - async def post_update_provider(self): - update_provider_config = await request.json - origin_provider_id = update_provider_config.get("id", None) - new_config = update_provider_config.get("config", None) - if not origin_provider_id or not new_config: - return Response().error("参数错误").__dict__ - - try: - await self.core_lifecycle.provider_manager.update_provider( - origin_provider_id, new_config - ) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "更新成功,已经实时生效~").__dict__ - - async def post_delete_platform(self): - platform_id = await request.json - platform_id = platform_id.get("id") - for i, platform in enumerate(self.config["platform"]): - if platform["id"] == platform_id: - del self.config["platform"][i] - break - else: - return Response().error("未找到对应平台").__dict__ - try: - save_config(self.config, self.config, is_core=True) - await self.core_lifecycle.platform_manager.terminate_platform(platform_id) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "删除平台配置成功~").__dict__ - - async def post_delete_provider(self): - provider_id = await request.json - provider_id = provider_id.get("id", "") - if not provider_id: - return Response().error("缺少参数 id").__dict__ - - try: - await self.core_lifecycle.provider_manager.delete_provider( - provider_id=provider_id - ) - except Exception as e: - return Response().error(str(e)).__dict__ - return Response().ok(None, "删除成功,已经实时生效。").__dict__ - - async def get_llm_tools(self): - """获取函数调用工具。包含了本地加载的以及 MCP 服务的工具""" - tool_mgr = self.core_lifecycle.provider_manager.llm_tools - tools = tool_mgr.get_func_desc_openai_style() - return Response().ok(tools).__dict__ - - async def _register_platform_logo(self, platform, platform_default_tmpl) -> None: - """注册平台logo文件并生成访问令牌""" - if not platform.logo_path: - return - - try: - # 检查缓存 - cache_key = f"{platform.name}:{platform.logo_path}" - if cache_key in self._logo_token_cache: - cached_token = self._logo_token_cache[cache_key] - # 确保platform_default_tmpl[platform.name]存在且为字典 - if platform.name not in platform_default_tmpl or not isinstance( - platform_default_tmpl[platform.name], dict - ): - platform_default_tmpl[platform.name] = {} - platform_default_tmpl[platform.name]["logo_token"] = cached_token - logger.debug(f"Using cached logo token for platform {platform.name}") - return - - # 获取平台适配器类 - platform_cls = platform_cls_map.get(platform.name) - if not platform_cls: - logger.warning(f"Platform class not found for {platform.name}") - return - - # 获取插件目录路径 - module_file = inspect.getfile(platform_cls) - plugin_dir = os.path.dirname(module_file) - - # 解析logo文件路径 - logo_file_path = os.path.join(plugin_dir, platform.logo_path) - - # 检查文件是否存在并注册令牌 - if os.path.exists(logo_file_path): - logo_token = await file_token_service.register_file( - logo_file_path, - timeout=3600, - ) - - # 确保platform_default_tmpl[platform.name]存在且为字典 - if platform.name not in platform_default_tmpl or not isinstance( - platform_default_tmpl[platform.name], dict - ): - platform_default_tmpl[platform.name] = {} - - platform_default_tmpl[platform.name]["logo_token"] = logo_token - - # 缓存token - self._logo_token_cache[cache_key] = logo_token - - logger.debug(f"Logo token registered for platform {platform.name}") - else: - logger.warning( - f"Platform {platform.name} logo file not found: {logo_file_path}", - ) - - except (ImportError, AttributeError) as e: - logger.warning( - f"Failed to import required modules for platform {platform.name}: {e}", - ) - except OSError as e: - logger.warning(f"File system error for platform {platform.name} logo: {e}") - except Exception as e: - logger.warning( - f"Unexpected error registering logo for platform {platform.name}: {e}", - ) - - def _inject_platform_metadata_with_i18n( - self, platform, metadata, platform_i18n_translations: dict - ): - """将配置元数据注入到 metadata 中并处理国际化键转换。""" - metadata["platform_group"]["metadata"]["platform"].setdefault("items", {}) - platform_items_to_inject = copy.deepcopy(platform.config_metadata) - - if platform.i18n_resources: - i18n_prefix = f"platform_group.platform.{platform.name}" - - for lang, lang_data in platform.i18n_resources.items(): - platform_i18n_translations.setdefault(lang, {}).setdefault( - "platform_group", {} - ).setdefault("platform", {})[platform.name] = lang_data - - for field_key, field_value in platform_items_to_inject.items(): - for key in ("description", "hint", "labels"): - if key in field_value: - field_value[key] = f"{i18n_prefix}.{field_key}.{key}" - - metadata["platform_group"]["metadata"]["platform"]["items"].update( - platform_items_to_inject - ) - - async def _get_astrbot_config(self): - config = self.config - metadata = copy.deepcopy(CONFIG_METADATA_2) - platform_i18n = ConfigMetadataI18n.convert_to_i18n_keys( - { - "platform_group": { - "metadata": { - "platform": metadata["platform_group"]["metadata"]["platform"] - } - } - } - ) - metadata["platform_group"]["metadata"]["platform"] = platform_i18n[ - "platform_group" - ]["metadata"]["platform"] - - # 平台适配器的默认配置模板注入 - platform_default_tmpl = metadata["platform_group"]["metadata"]["platform"][ - "config_template" - ] - - # 收集平台的 i18n 翻译数据 - platform_i18n_translations = {} - - # 收集需要注册logo的平台 - logo_registration_tasks = [] - for platform in platform_registry: - if platform.default_config_tmpl: - platform_default_tmpl[platform.name] = copy.deepcopy( - platform.default_config_tmpl - ) - - # 注入配置元数据(在 convert_to_i18n_keys 之后,使用国际化键) - if platform.config_metadata: - self._inject_platform_metadata_with_i18n( - platform, metadata, platform_i18n_translations - ) - - # 收集logo注册任务 - if platform.logo_path: - logo_registration_tasks.append( - self._register_platform_logo(platform, platform_default_tmpl), - ) - - # 并行执行logo注册 - if logo_registration_tasks: - await asyncio.gather(*logo_registration_tasks, return_exceptions=True) - - # 服务提供商的默认配置模板注入 - provider_default_tmpl = metadata["provider_group"]["metadata"]["provider"][ - "config_template" - ] - for provider in provider_registry: - if provider.default_config_tmpl: - provider_default_tmpl[provider.type] = provider.default_config_tmpl - - return { - "metadata": metadata, - "config": config, - "platform_i18n_translations": platform_i18n_translations, - } - - async def _get_plugin_config(self, plugin_name: str): - ret: dict = {"metadata": None, "config": None, "i18n": {}} - - for plugin_md in star_registry: - if plugin_md.name == plugin_name: - if not plugin_md.config: - break - ret["config"] = ( - plugin_md.config - ) # 这是自定义的 Dict 类(AstrBotConfig) - ret["metadata"] = { - plugin_name: { - "description": f"{plugin_name} 配置", - "type": "object", - "items": plugin_md.config.schema, # 初始化时通过 __setattr__ 存入了 schema - }, - } - ret["i18n"] = plugin_md.i18n - break - - return ret - - async def _save_astrbot_configs( - self, post_configs: dict, conf_id: str | None = None - ) -> None: - try: - if conf_id not in self.acm.confs: - raise ValueError(f"配置文件 {conf_id} 不存在") - astrbot_config = self.acm.confs[conf_id] - - # 保留服务端的 t2i_active_template 值 - if "t2i_active_template" in astrbot_config: - post_configs["t2i_active_template"] = astrbot_config[ - "t2i_active_template" - ] - - save_config(post_configs, astrbot_config, is_core=True) - except Exception as e: - raise e - - async def _save_plugin_configs(self, post_configs: dict, plugin_name: str) -> None: - md = None - for plugin_md in star_registry: - if plugin_md.name == plugin_name: - md = plugin_md - - if not md: - raise ValueError(f"插件 {plugin_name} 不存在") - if not md.config: - raise ValueError(f"插件 {plugin_name} 没有注册配置") - assert md.config is not None - - try: - errors, post_configs = validate_config( - post_configs, getattr(md.config, "schema", {}), is_core=False - ) - if errors: - raise ValueError(f"格式校验未通过: {errors}") - md.config.save_config(post_configs) - except Exception as e: - raise e diff --git a/astrbot/dashboard/routes/conversation.py b/astrbot/dashboard/routes/conversation.py deleted file mode 100644 index e15837fd8d..0000000000 --- a/astrbot/dashboard/routes/conversation.py +++ /dev/null @@ -1,402 +0,0 @@ -import json -import traceback -from dataclasses import asdict -from datetime import datetime -from io import BytesIO - -from quart import request, send_file - -from astrbot.core import logger -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.db import BaseDatabase -from astrbot.core.umo_alias import build_umo_alias_map, parse_umo, serialize_umo_alias - -from .route import Response, Route, RouteContext - - -class ConversationRoute(Route): - def __init__( - self, - context: RouteContext, - db_helper: BaseDatabase, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.routes = { - "/conversation/list": ("GET", self.list_conversations), - "/conversation/detail": ( - "POST", - self.get_conv_detail, - ), - "/conversation/update": ("POST", self.upd_conv), - "/conversation/delete": ("POST", self.del_conv), - "/conversation/update_history": ( - "POST", - self.update_history, - ), - "/conversation/export": ("POST", self.export_conversations), - } - self.db_helper = db_helper - self.conv_mgr = core_lifecycle.conversation_manager - self.core_lifecycle = core_lifecycle - self.register_routes() - - def _build_umo_info(self, umo: str | None, alias_map: dict) -> dict: - umo_str = umo or "" - return { - "umo": umo_str, - **parse_umo(umo_str), - **serialize_umo_alias(alias_map.get(umo_str), umo_str), - } - - def _serialize_conversation(self, conversation, alias_map: dict) -> dict: - return { - **asdict(conversation), - "umo_info": self._build_umo_info(conversation.user_id, alias_map), - } - - async def list_conversations(self): - """获取对话列表,支持分页、排序和筛选""" - try: - # 获取分页参数 - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) - - # 获取筛选参数 - platforms = request.args.get("platforms", "") - message_types = request.args.get("message_types", "") - search_query = request.args.get("search", "") - exclude_ids = request.args.get("exclude_ids", "") - exclude_platforms = request.args.get("exclude_platforms", "") - - # 转换为列表 - platform_list = platforms.split(",") if platforms else [] - message_type_list = message_types.split(",") if message_types else [] - exclude_id_list = exclude_ids.split(",") if exclude_ids else [] - exclude_platform_list = ( - exclude_platforms.split(",") if exclude_platforms else [] - ) - - page = max(page, 1) - if page_size < 1: - page_size = 20 - page_size = min(page_size, 100) - - try: - ( - conversations, - total_count, - ) = await self.conv_mgr.get_filtered_conversations( - page=page, - page_size=page_size, - platforms=platform_list, - message_types=message_type_list, - search_query=search_query, - exclude_ids=exclude_id_list, - exclude_platforms=exclude_platform_list, - ) - except Exception as e: - logger.error(f"数据库查询出错: {e!s}\n{traceback.format_exc()}") - return Response().error(f"数据库查询出错: {e!s}").__dict__ - - # 计算总页数 - total_pages = ( - (total_count + page_size - 1) // page_size if total_count > 0 else 1 - ) - umos = sorted({conv.user_id for conv in conversations if conv.user_id}) - alias_map = build_umo_alias_map(await self.db_helper.get_umo_aliases(umos)) - - result = { - "conversations": [ - self._serialize_conversation(conversation, alias_map) - for conversation in conversations - ], - "pagination": { - "page": page, - "page_size": page_size, - "total": total_count, - "total_pages": total_pages, - }, - } - return Response().ok(result).__dict__ - - except Exception as e: - error_msg = f"获取对话列表失败: {e!s}\n{traceback.format_exc()}" - logger.error(error_msg) - return Response().error(f"获取对话列表失败: {e!s}").__dict__ - - async def get_conv_detail(self): - """获取指定对话详情(通过POST请求)""" - try: - data = await request.get_json() - user_id = data.get("user_id") - cid = data.get("cid") - - if not user_id or not cid: - return Response().error("缺少必要参数: user_id 和 cid").__dict__ - - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin=user_id, - conversation_id=cid, - ) - if not conversation: - return Response().error("对话不存在").__dict__ - - alias_map = build_umo_alias_map( - await self.db_helper.get_umo_aliases([user_id]) - ) - return ( - Response() - .ok( - { - "user_id": user_id, - "cid": cid, - "title": conversation.title, - "persona_id": conversation.persona_id, - "history": conversation.history, - "created_at": conversation.created_at, - "updated_at": conversation.updated_at, - "umo_info": self._build_umo_info(user_id, alias_map), - }, - ) - .__dict__ - ) - - except Exception as e: - logger.error(f"获取对话详情失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"获取对话详情失败: {e!s}").__dict__ - - async def upd_conv(self): - """更新对话信息(标题和角色ID)""" - try: - data = await request.get_json() - user_id = data.get("user_id") - cid = data.get("cid") - title = data.get("title") - - if not user_id or not cid: - return Response().error("缺少必要参数: user_id 和 cid").__dict__ - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin=user_id, - conversation_id=cid, - ) - if not conversation: - return Response().error("对话不存在").__dict__ - - persona_id = data.get("persona_id", conversation.persona_id) - - if title is not None or persona_id is not None: - await self.conv_mgr.update_conversation( - unified_msg_origin=user_id, - conversation_id=cid, - title=title, - persona_id=persona_id, - ) - return Response().ok({"message": "对话信息更新成功"}).__dict__ - - except Exception as e: - logger.error(f"更新对话信息失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"更新对话信息失败: {e!s}").__dict__ - - async def del_conv(self): - """删除对话""" - try: - data = await request.get_json() - - # 检查是否是批量删除 - if "conversations" in data: - # 批量删除 - conversations = data.get("conversations", []) - if not conversations: - return ( - Response().error("批量删除时conversations参数不能为空").__dict__ - ) - - deleted_count = 0 - failed_items = [] - - for conv in conversations: - user_id = conv.get("user_id") - cid = conv.get("cid") - - if not user_id or not cid: - failed_items.append( - f"user_id:{user_id}, cid:{cid} - 缺少必要参数", - ) - continue - - try: - await self.core_lifecycle.conversation_manager.delete_conversation( - unified_msg_origin=user_id, - conversation_id=cid, - ) - deleted_count += 1 - except Exception as e: - failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}") - - message = f"成功删除 {deleted_count} 个对话" - if failed_items: - message += f",失败 {len(failed_items)} 个" - - return ( - Response() - .ok( - { - "message": message, - "deleted_count": deleted_count, - "failed_count": len(failed_items), - "failed_items": failed_items, - }, - ) - .__dict__ - ) - # 单个删除 - user_id = data.get("user_id") - cid = data.get("cid") - - if not user_id or not cid: - return Response().error("缺少必要参数: user_id 和 cid").__dict__ - - await self.core_lifecycle.conversation_manager.delete_conversation( - unified_msg_origin=user_id, - conversation_id=cid, - ) - return Response().ok({"message": "对话删除成功"}).__dict__ - - except Exception as e: - logger.error(f"删除对话失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"删除对话失败: {e!s}").__dict__ - - async def update_history(self): - """更新对话历史内容""" - try: - data = await request.get_json() - user_id = data.get("user_id") - cid = data.get("cid") - history = data.get("history") - - if not user_id or not cid: - return Response().error("缺少必要参数: user_id 和 cid").__dict__ - - if history is None: - return Response().error("缺少必要参数: history").__dict__ - - # 历史记录必须是合法的 JSON 字符串 - try: - if isinstance(history, list): - history = json.dumps(history) - else: - # 验证是否为有效的 JSON 字符串 - json.loads(history) - except json.JSONDecodeError: - return ( - Response().error("history 必须是有效的 JSON 字符串或数组").__dict__ - ) - - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin=user_id, - conversation_id=cid, - ) - if not conversation: - return Response().error("对话不存在").__dict__ - - history = json.loads(history) if isinstance(history, str) else history - - await self.conv_mgr.update_conversation( - unified_msg_origin=user_id, - conversation_id=cid, - history=history, - ) - - return Response().ok({"message": "对话历史更新成功"}).__dict__ - - except Exception as e: - logger.error(f"更新对话历史失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"更新对话历史失败: {e!s}").__dict__ - - async def export_conversations(self): - """批量导出对话为 JSONL 格式""" - try: - data = await request.get_json() - conversations_to_export = data.get("conversations", []) - - if not conversations_to_export: - return Response().error("导出列表不能为空").__dict__ - - # 收集所有对话的内容 - jsonl_lines = [] - exported_count = 0 - failed_items = [] - - for conv_info in conversations_to_export: - user_id = conv_info.get("user_id") - cid = conv_info.get("cid") - - if not user_id or not cid: - failed_items.append( - f"user_id:{user_id}, cid:{cid} - 缺少必要参数", - ) - continue - - try: - conversation = await self.conv_mgr.get_conversation( - unified_msg_origin=user_id, - conversation_id=cid, - ) - - if not conversation: - failed_items.append( - f"user_id:{user_id}, cid:{cid} - 对话不存在" - ) - continue - - # 解析对话内容 (history is always a JSON string from _convert_conv_from_v2_to_v1) - content = json.loads(conversation.history) - - # 创建导出记录 - export_record = { - "cid": cid, - "user_id": user_id, - "platform_id": conversation.platform_id, - "title": conversation.title, - "persona_id": conversation.persona_id, - "created_at": conversation.created_at, - "updated_at": conversation.updated_at, - "content": content, - } - - # 将记录转换为 JSON 字符串并添加到 JSONL - jsonl_lines.append(json.dumps(export_record, ensure_ascii=False)) - exported_count += 1 - - except Exception as e: - failed_items.append(f"user_id:{user_id}, cid:{cid} - {e!s}") - logger.error( - f"导出对话失败: user_id={user_id}, cid={cid}, error={e!s}" - ) - - if exported_count == 0: - return Response().error("没有成功导出任何对话").__dict__ - - # 创建 JSONL 内容 - jsonl_content = "\n".join(jsonl_lines) - - # 创建一个内存文件对象 - file_obj = BytesIO(jsonl_content.encode("utf-8")) - file_obj.seek(0) - - # 生成文件名 - timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - filename = f"astrbot_conversations_export_{timestamp}.jsonl" - - # 返回文件流 - return await send_file( - file_obj, - mimetype="application/jsonl", - as_attachment=True, - attachment_filename=filename, - ) - - except Exception as e: - logger.error(f"批量导出对话失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"批量导出对话失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/cron.py b/astrbot/dashboard/routes/cron.py deleted file mode 100644 index 85dbc25095..0000000000 --- a/astrbot/dashboard/routes/cron.py +++ /dev/null @@ -1,301 +0,0 @@ -import asyncio -import traceback -from datetime import datetime, timezone - -from quart import jsonify, request - -from astrbot.core import logger -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle - -from .route import Response, Route, RouteContext - - -class CronRoute(Route): - def __init__( - self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle - ) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - self._background_tasks: set[asyncio.Task] = set() - self.routes = [ - ("/cron/jobs", ("GET", self.list_jobs)), - ("/cron/jobs", ("POST", self.create_job)), - ("/cron/jobs/", ("PATCH", self.update_job)), - ("/cron/jobs/", ("DELETE", self.delete_job)), - ("/cron/jobs//run", ("POST", self.run_job_now)), - ] - self.register_routes() - - def _serialize_job(self, job) -> dict: - data = job.model_dump() if hasattr(job, "model_dump") else job.__dict__ - for k in ["created_at", "updated_at", "last_run_at", "next_run_time"]: - v = data.get(k) - if isinstance(v, datetime): - # Attach UTC - if v.tzinfo is None: - v = v.replace(tzinfo=timezone.utc) - data[k] = v.isoformat() - # expose note explicitly for UI (prefer payload.note then description) - payload = data.get("payload") or {} - data["note"] = payload.get("note") or data.get("description") or "" - data["run_at"] = payload.get("run_at") - data["run_once"] = data.get("run_once", False) - # status is internal; hide to avoid implying one-time completion for recurring jobs - data.pop("status", None) - return data - - async def list_jobs(self): - try: - cron_mgr = self.core_lifecycle.cron_manager - if cron_mgr is None: - return jsonify( - Response().error("Cron manager not initialized").__dict__ - ) - job_type = request.args.get("type") - jobs = await cron_mgr.list_jobs(job_type) - data = [self._serialize_job(j) for j in jobs] - return jsonify(Response().ok(data=data).__dict__) - except Exception as e: # noqa: BLE001 - logger.error(traceback.format_exc()) - return jsonify(Response().error(f"Failed to list jobs: {e!s}").__dict__) - - async def create_job(self): - try: - cron_mgr = self.core_lifecycle.cron_manager - if cron_mgr is None: - return jsonify( - Response().error("Cron manager not initialized").__dict__ - ) - - payload = await request.json - if not isinstance(payload, dict): - return jsonify(Response().error("Invalid payload").__dict__) - - name = payload.get("name") or "active_agent_task" - cron_expression = payload.get("cron_expression") - note = payload.get("note") or payload.get("description") or name - session = str(payload.get("session") or "").strip() - persona_id = payload.get("persona_id") - provider_id = payload.get("provider_id") - timezone = payload.get("timezone") - enabled = bool(payload.get("enabled", True)) - run_once = bool(payload.get("run_once", False)) - run_at = payload.get("run_at") - - if run_once and not run_at: - return jsonify( - Response().error("run_at is required when run_once=true").__dict__ - ) - if (not run_once) and not cron_expression: - return jsonify( - Response() - .error("cron_expression is required when run_once=false") - .__dict__ - ) - if run_once and cron_expression: - cron_expression = None # ignore cron when run_once specified - run_at_dt = None - if run_at: - try: - run_at_dt = datetime.fromisoformat(str(run_at)) - except Exception: - return jsonify( - Response().error("run_at must be ISO datetime").__dict__ - ) - - job_payload = { - "session": session, - "note": note, - "persona_id": persona_id, - "provider_id": provider_id, - "run_at": run_at, - "origin": "api", - } - - job = await cron_mgr.add_active_job( - name=name, - cron_expression=cron_expression, - payload=job_payload, - description=note, - timezone=timezone, - enabled=enabled, - run_once=run_once, - run_at=run_at_dt, - ) - - return jsonify(Response().ok(data=self._serialize_job(job)).__dict__) - except Exception as e: # noqa: BLE001 - logger.error(traceback.format_exc()) - return jsonify(Response().error(f"Failed to create job: {e!s}").__dict__) - - async def update_job(self, job_id: str): - try: - cron_mgr = self.core_lifecycle.cron_manager - if cron_mgr is None: - return jsonify( - Response().error("Cron manager not initialized").__dict__ - ) - - payload = await request.json - if not isinstance(payload, dict): - return jsonify(Response().error("Invalid payload").__dict__) - - job = await cron_mgr.db.get_cron_job(job_id) - if not job: - return jsonify(Response().error("Job not found").__dict__) - - updates = {} - if "name" in payload: - name = str(payload.get("name") or "").strip() - if not name: - return jsonify(Response().error("name cannot be empty").__dict__) - updates["name"] = name - - if "enabled" in payload: - updates["enabled"] = bool(payload.get("enabled")) - - if "timezone" in payload: - timezone = payload.get("timezone") - updates["timezone"] = str(timezone).strip() or None - - next_run_once = ( - bool(payload.get("run_once")) - if "run_once" in payload - else bool(job.run_once) - ) - - if job.job_type == "active_agent": - merged_payload = ( - dict(job.payload) if isinstance(job.payload, dict) else {} - ) - if "payload" in payload and isinstance(payload.get("payload"), dict): - merged_payload.update(payload["payload"]) - - if "session" in payload: - session = str(payload.get("session") or "").strip() - if session: - merged_payload["session"] = session - else: - merged_payload.pop("session", None) - - note_updated = False - if "note" in payload: - note = str(payload.get("note") or "").strip() - if not note: - return jsonify( - Response().error("note cannot be empty").__dict__ - ) - merged_payload["note"] = note - updates["description"] = note - note_updated = True - elif "description" in payload: - description = str(payload.get("description") or "").strip() - if not description: - return jsonify( - Response().error("description cannot be empty").__dict__ - ) - updates["description"] = description - merged_payload["note"] = description - note_updated = True - - if not note_updated and updates.get("description") is None: - existing_note = str( - merged_payload.get("note") or job.description or "" - ).strip() - if existing_note: - merged_payload["note"] = existing_note - - next_cron_expression = ( - payload.get("cron_expression") - if "cron_expression" in payload - else job.cron_expression - ) - if next_cron_expression is not None: - next_cron_expression = str(next_cron_expression).strip() or None - - run_at_raw = ( - payload.get("run_at") - if "run_at" in payload - else merged_payload.get("run_at") - ) - run_at_iso = None - if run_at_raw: - try: - run_at_iso = datetime.fromisoformat(str(run_at_raw)).isoformat() - except Exception: - return jsonify( - Response().error("run_at must be ISO datetime").__dict__ - ) - - if next_run_once: - if not run_at_iso: - return jsonify( - Response() - .error("run_at is required when run_once=true") - .__dict__ - ) - next_cron_expression = None - merged_payload["run_at"] = run_at_iso - else: - if not next_cron_expression: - return jsonify( - Response() - .error("cron_expression is required when run_once=false") - .__dict__ - ) - merged_payload.pop("run_at", None) - - updates["run_once"] = next_run_once - updates["cron_expression"] = next_cron_expression - updates["payload"] = merged_payload - else: - if "cron_expression" in payload: - cron_expression = str(payload.get("cron_expression") or "").strip() - if not cron_expression: - return jsonify( - Response().error("cron_expression cannot be empty").__dict__ - ) - updates["cron_expression"] = cron_expression - - if "description" in payload: - description = str(payload.get("description") or "").strip() - updates["description"] = description or None - - job = await cron_mgr.update_job(job_id, **updates) - if not job: - return jsonify(Response().error("Job not found").__dict__) - return jsonify(Response().ok(data=self._serialize_job(job)).__dict__) - except Exception as e: # noqa: BLE001 - logger.error(traceback.format_exc()) - return jsonify(Response().error(f"Failed to update job: {e!s}").__dict__) - - async def delete_job(self, job_id: str): - try: - cron_mgr = self.core_lifecycle.cron_manager - if cron_mgr is None: - return jsonify( - Response().error("Cron manager not initialized").__dict__ - ) - await cron_mgr.delete_job(job_id) - return jsonify(Response().ok(message="deleted").__dict__) - except Exception as e: # noqa: BLE001 - logger.error(traceback.format_exc()) - return jsonify(Response().error(f"Failed to delete job: {e!s}").__dict__) - - async def run_job_now(self, job_id: str): - try: - cron_mgr = self.core_lifecycle.cron_manager - if cron_mgr is None: - return jsonify( - Response().error("Cron manager not initialized").__dict__ - ) - job = await cron_mgr.db.get_cron_job(job_id) - if not job: - return jsonify(Response().error("Job not found").__dict__) - task = asyncio.create_task(cron_mgr.run_job_now(job_id)) - self._background_tasks.add(task) - task.add_done_callback(self._background_tasks.discard) - return jsonify(Response().ok(message="started").__dict__) - except Exception as e: # noqa: BLE001 - logger.error(traceback.format_exc()) - return jsonify(Response().error(f"Failed to run job: {e!s}").__dict__) diff --git a/astrbot/dashboard/routes/file.py b/astrbot/dashboard/routes/file.py deleted file mode 100644 index 1880150bf0..0000000000 --- a/astrbot/dashboard/routes/file.py +++ /dev/null @@ -1,24 +0,0 @@ -from quart import abort, send_file - -from astrbot.core import file_token_service - -from .route import Route, RouteContext - - -class FileRoute(Route): - def __init__( - self, - context: RouteContext, - ) -> None: - super().__init__(context) - self.routes = { - "/file/": ("GET", self.serve_file), - } - self.register_routes() - - async def serve_file(self, file_token: str): - try: - file_path = await file_token_service.handle_file(file_token) - return await send_file(file_path) - except (FileNotFoundError, KeyError): - return abort(404) diff --git a/astrbot/dashboard/routes/knowledge_base.py b/astrbot/dashboard/routes/knowledge_base.py deleted file mode 100644 index 1b6f7a435d..0000000000 --- a/astrbot/dashboard/routes/knowledge_base.py +++ /dev/null @@ -1,1288 +0,0 @@ -"""知识库管理 API 路由""" - -import asyncio -import os -import traceback -import uuid -from typing import Any - -import aiofiles -from quart import request - -from astrbot.core import logger -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.provider.provider import EmbeddingProvider, RerankProvider -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from ..utils import generate_tsne_visualization -from .route import Response, Route, RouteContext - - -class KnowledgeBaseRoute(Route): - """知识库管理路由 - - 提供知识库、文档、检索、会话配置等 API 接口 - """ - - def __init__( - self, - context: RouteContext, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - self.kb_manager = None # 延迟初始化 - self.kb_db = None - self.session_config_db = None # 会话配置数据库 - self.retrieval_manager = None - self.upload_progress = {} # 存储上传进度 {task_id: {status, file_index, file_total, stage, current, total}} - self.upload_tasks = {} # 存储后台上传任务 {task_id: {"status", "result", "error"}} - - # 注册路由 - self.routes = { - # 知识库管理 - "/kb/list": ("GET", self.list_kbs), - "/kb/create": ("POST", self.create_kb), - "/kb/get": ("GET", self.get_kb), - "/kb/update": ("POST", self.update_kb), - "/kb/delete": ("POST", self.delete_kb), - "/kb/stats": ("GET", self.get_kb_stats), - # 文档管理 - "/kb/document/list": ("GET", self.list_documents), - "/kb/document/upload": ("POST", self.upload_document), - "/kb/document/import": ("POST", self.import_documents), - "/kb/document/upload/url": ("POST", self.upload_document_from_url), - "/kb/document/upload/progress": ("GET", self.get_upload_progress), - "/kb/document/get": ("GET", self.get_document), - "/kb/document/delete": ("POST", self.delete_document), - # # 块管理 - "/kb/chunk/list": ("GET", self.list_chunks), - "/kb/chunk/delete": ("POST", self.delete_chunk), - # # 多媒体管理 - # "/kb/media/list": ("GET", self.list_media), - # "/kb/media/delete": ("POST", self.delete_media), - # 检索 - "/kb/retrieve": ("POST", self.retrieve), - } - self.register_routes() - - def _get_kb_manager(self): - return self.core_lifecycle.kb_manager - - def _init_task(self, task_id: str, status: str = "pending") -> None: - self.upload_tasks[task_id] = { - "status": status, - "result": None, - "error": None, - } - - def _set_task_result( - self, task_id: str, status: str, result: Any = None, error: str | None = None - ) -> None: - self.upload_tasks[task_id] = { - "status": status, - "result": result, - "error": error, - } - if task_id in self.upload_progress: - self.upload_progress[task_id]["status"] = status - - def _update_progress( - self, - task_id: str, - *, - status: str | None = None, - file_index: int | None = None, - file_name: str | None = None, - stage: str | None = None, - current: int | None = None, - total: int | None = None, - ) -> None: - if task_id not in self.upload_progress: - return - p = self.upload_progress[task_id] - if status is not None: - p["status"] = status - if file_index is not None: - p["file_index"] = file_index - if file_name is not None: - p["file_name"] = file_name - if stage is not None: - p["stage"] = stage - if current is not None: - p["current"] = current - if total is not None: - p["total"] = total - - def _make_progress_callback(self, task_id: str, file_idx: int, file_name: str): - async def _callback(stage: str, current: int, total: int) -> None: - self._update_progress( - task_id, - status="processing", - file_index=file_idx, - file_name=file_name, - stage=stage, - current=current, - total=total, - ) - - return _callback - - @staticmethod - def _format_failed_doc_error(file_name: str, error: Exception) -> str: - message = str(error).strip() or "上传失败:发生未知错误。" - if message.startswith(file_name): - return message - return f"{file_name}: {message}" - - async def _background_upload_task( - self, - task_id: str, - kb_helper, - files_to_upload: list, - chunk_size: int, - chunk_overlap: int, - batch_size: int, - tasks_limit: int, - max_retries: int, - ) -> None: - """后台上传任务""" - try: - # 初始化任务状态 - self._init_task(task_id, status="processing") - self.upload_progress[task_id] = { - "status": "processing", - "file_index": 0, - "file_total": len(files_to_upload), - "stage": "waiting", - "current": 0, - "total": 100, - } - - uploaded_docs = [] - failed_docs = [] - - for file_idx, file_info in enumerate(files_to_upload): - try: - # 更新整体进度 - self._update_progress( - task_id, - status="processing", - file_index=file_idx, - file_name=file_info["file_name"], - stage="parsing", - current=0, - total=100, - ) - - # 创建进度回调函数 - progress_callback = self._make_progress_callback( - task_id, file_idx, file_info["file_name"] - ) - - doc = await kb_helper.upload_document( - file_name=file_info["file_name"], - file_content=file_info["file_content"], - file_type=file_info["file_type"], - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - progress_callback=progress_callback, - ) - - uploaded_docs.append(doc.model_dump()) - except Exception as e: - logger.error(f"上传文档 {file_info['file_name']} 失败: {e}") - failed_docs.append( - { - "file_name": file_info["file_name"], - "error": self._format_failed_doc_error( - file_info["file_name"], e - ), - }, - ) - - # 更新任务完成状态 - result = { - "task_id": task_id, - "uploaded": uploaded_docs, - "failed": failed_docs, - "total": len(files_to_upload), - "success_count": len(uploaded_docs), - "failed_count": len(failed_docs), - } - - self._set_task_result(task_id, "completed", result=result) - - except Exception as e: - logger.error(f"后台上传任务 {task_id} 失败: {e}") - logger.error(traceback.format_exc()) - self._set_task_result(task_id, "failed", error=str(e)) - - async def _background_import_task( - self, - task_id: str, - kb_helper, - documents: list, - batch_size: int, - tasks_limit: int, - max_retries: int, - ) -> None: - """后台导入预切片文档任务""" - try: - # 初始化任务状态 - self._init_task(task_id, status="processing") - self.upload_progress[task_id] = { - "status": "processing", - "file_index": 0, - "file_total": len(documents), - "stage": "waiting", - "current": 0, - "total": 100, - } - - uploaded_docs = [] - failed_docs = [] - - for file_idx, doc_info in enumerate(documents): - file_name = doc_info.get("file_name", f"imported_doc_{file_idx}") - chunks = doc_info.get("chunks", []) - - try: - # 更新整体进度 - self._update_progress( - task_id, - status="processing", - file_index=file_idx, - file_name=file_name, - stage="importing", - current=0, - total=100, - ) - - # 创建进度回调函数 - progress_callback = self._make_progress_callback( - task_id, file_idx, file_name - ) - - # 调用 upload_document,传入 pre_chunked_text - doc = await kb_helper.upload_document( - file_name=file_name, - file_content=None, # 预切片模式下不需要原始内容 - file_type=doc_info.get("file_type") - or ( - file_name.rsplit(".", 1)[-1].lower() - if "." in file_name - else "txt" - ), - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - progress_callback=progress_callback, - pre_chunked_text=chunks, - ) - - uploaded_docs.append(doc.model_dump()) - except Exception as e: - logger.error(f"导入文档 {file_name} 失败: {e}") - failed_docs.append( - { - "file_name": file_name, - "error": self._format_failed_doc_error(file_name, e), - }, - ) - - # 更新任务完成状态 - result = { - "task_id": task_id, - "uploaded": uploaded_docs, - "failed": failed_docs, - "total": len(documents), - "success_count": len(uploaded_docs), - "failed_count": len(failed_docs), - } - - self._set_task_result(task_id, "completed", result=result) - - except Exception as e: - logger.error(f"后台导入任务 {task_id} 失败: {e}") - logger.error(traceback.format_exc()) - self._set_task_result(task_id, "failed", error=str(e)) - - async def list_kbs(self): - """获取知识库列表 - - Query 参数: - - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) - - refresh_stats: 是否刷新统计信息 (默认 false,首次加载时可设为 true) - """ - try: - kb_manager = self._get_kb_manager() - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) - - kbs = await kb_manager.list_kbs() - - # 转换为字典列表 - kb_list = [] - for kb in kbs: - kb_dict = kb.model_dump() - # include init_error from KBHelper if present - kb_helper = await kb_manager.get_kb(kb.kb_id) - if kb_helper and kb_helper.init_error: - kb_dict["init_error"] = kb_helper.init_error - kb_list.append(kb_dict) - - return ( - Response() - .ok({"items": kb_list, "page": page, "page_size": page_size}) - .__dict__ - ) - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取知识库列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取知识库列表失败: {e!s}").__dict__ - - async def create_kb(self): - """创建知识库 - - Body: - - kb_name: 知识库名称 (必填) - - description: 描述 (可选) - - emoji: 图标 (可选) - - embedding_provider_id: 嵌入模型提供商ID (可选) - - rerank_provider_id: 重排序模型提供商ID (可选) - - chunk_size: 分块大小 (可选, 默认512) - - chunk_overlap: 块重叠大小 (可选, 默认50) - - top_k_dense: 密集检索数量 (可选, 默认50) - - top_k_sparse: 稀疏检索数量 (可选, 默认50) - - top_m_final: 最终返回数量 (可选, 默认5) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - kb_name = data.get("kb_name") - if not kb_name: - return Response().error("知识库名称不能为空").__dict__ - - description = data.get("description") - emoji = data.get("emoji") - embedding_provider_id = data.get("embedding_provider_id") - rerank_provider_id = data.get("rerank_provider_id") - chunk_size = data.get("chunk_size") - chunk_overlap = data.get("chunk_overlap") - top_k_dense = data.get("top_k_dense") - top_k_sparse = data.get("top_k_sparse") - top_m_final = data.get("top_m_final") - - # pre-check embedding dim - if not embedding_provider_id: - return Response().error("缺少参数 embedding_provider_id").__dict__ - prv = await kb_manager.provider_manager.get_provider_by_id( - embedding_provider_id, - ) # type: ignore - if not prv or not isinstance(prv, EmbeddingProvider): - return ( - Response().error(f"嵌入模型不存在或类型错误({type(prv)})").__dict__ - ) - try: - vec = await prv.get_embedding("astrbot") - if len(vec) != prv.get_dim(): - raise ValueError( - f"嵌入向量维度不匹配,实际是 {len(vec)},然而配置是 {prv.get_dim()}", - ) - except Exception as e: - return Response().error(f"测试嵌入模型失败: {e!s}").__dict__ - # pre-check rerank - if rerank_provider_id: - rerank_prv: RerankProvider = ( - await kb_manager.provider_manager.get_provider_by_id( - rerank_provider_id, - ) - ) # type: ignore - if not rerank_prv: - return Response().error("重排序模型不存在").__dict__ - # 检查重排序模型可用性 - try: - res = await rerank_prv.rerank( - query="astrbot", - documents=["astrbot knowledge base"], - ) - if not res: - raise ValueError("重排序模型返回结果异常") - except Exception as e: - return ( - Response() - .error(f"测试重排序模型失败: {e!s},请检查平台日志输出。") - .__dict__ - ) - - kb_helper = await kb_manager.create_kb( - kb_name=kb_name, - description=description, - emoji=emoji, - embedding_provider_id=embedding_provider_id, - rerank_provider_id=rerank_provider_id, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - top_k_dense=top_k_dense, - top_k_sparse=top_k_sparse, - top_m_final=top_m_final, - ) - kb = kb_helper.kb - - return Response().ok(kb.model_dump(), "创建知识库成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"创建知识库失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"创建知识库失败: {e!s}").__dict__ - - async def get_kb(self): - """获取知识库详情 - - Query 参数: - - kb_id: 知识库 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - kb = kb_helper.kb - - return Response().ok(kb.model_dump()).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取知识库详情失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取知识库详情失败: {e!s}").__dict__ - - async def update_kb(self): - """更新知识库 - - Body: - - kb_id: 知识库 ID (必填) - - kb_name: 新的知识库名称 (可选) - - description: 新的描述 (可选) - - emoji: 新的图标 (可选) - - embedding_provider_id: 新的嵌入模型提供商ID (可选) - - rerank_provider_id: 新的重排序模型提供商ID (可选) - - chunk_size: 分块大小 (可选) - - chunk_overlap: 块重叠大小 (可选) - - top_k_dense: 密集检索数量 (可选) - - top_k_sparse: 稀疏检索数量 (可选) - - top_m_final: 最终返回数量 (可选) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - kb_name = data.get("kb_name") - description = data.get("description") - emoji = data.get("emoji") - embedding_provider_id = data.get("embedding_provider_id") - rerank_provider_id = data.get("rerank_provider_id") - chunk_size = data.get("chunk_size") - chunk_overlap = data.get("chunk_overlap") - top_k_dense = data.get("top_k_dense") - top_k_sparse = data.get("top_k_sparse") - top_m_final = data.get("top_m_final") - - # 检查是否至少提供了一个更新字段 - if all( - v is None - for v in [ - kb_name, - description, - emoji, - embedding_provider_id, - rerank_provider_id, - chunk_size, - chunk_overlap, - top_k_dense, - top_k_sparse, - top_m_final, - ] - ): - return Response().error("至少需要提供一个更新字段").__dict__ - - kb_helper = await kb_manager.update_kb( - kb_id=kb_id, - kb_name=kb_name, - description=description, - emoji=emoji, - embedding_provider_id=embedding_provider_id, - rerank_provider_id=rerank_provider_id, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - top_k_dense=top_k_dense, - top_k_sparse=top_k_sparse, - top_m_final=top_m_final, - ) - - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - kb = kb_helper.kb - return Response().ok(kb.model_dump(), "更新知识库成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"更新知识库失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"更新知识库失败: {e!s}").__dict__ - - async def delete_kb(self): - """删除知识库 - - Body: - - kb_id: 知识库 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - success = await kb_manager.delete_kb(kb_id) - if not success: - return Response().error("知识库不存在").__dict__ - - return Response().ok(message="删除知识库成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除知识库失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除知识库失败: {e!s}").__dict__ - - async def get_kb_stats(self): - """获取知识库统计信息 - - Query 参数: - - kb_id: 知识库 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - kb = kb_helper.kb - - stats = { - "kb_id": kb.kb_id, - "kb_name": kb.kb_name, - "doc_count": kb.doc_count, - "chunk_count": kb.chunk_count, - "created_at": kb.created_at.isoformat(), - "updated_at": kb.updated_at.isoformat(), - } - - return Response().ok(stats).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取知识库统计失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取知识库统计失败: {e!s}").__dict__ - - # ===== 文档管理 API ===== - - async def list_documents(self): - """获取文档列表 - - Query 参数: - - kb_id: 知识库 ID (必填) - - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 100, type=int) - - offset = (page - 1) * page_size - limit = page_size - - doc_list = await kb_helper.list_documents(offset=offset, limit=limit) - - doc_list = [doc.model_dump() for doc in doc_list] - - return ( - Response() - .ok({"items": doc_list, "page": page, "page_size": page_size}) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取文档列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取文档列表失败: {e!s}").__dict__ - - async def upload_document(self): - """上传文档 - - 支持两种方式: - 1. multipart/form-data 文件上传(支持多文件,最多10个) - 2. JSON 格式 base64 编码上传(支持多文件,最多10个) - - Form Data (multipart/form-data): - - kb_id: 知识库 ID (必填) - - file: 文件对象 (必填,可多个,字段名为 file, file1, file2, ... 或 files[]) - - JSON Body (application/json): - - kb_id: 知识库 ID (必填) - - files: 文件数组 (必填) - - file_name: 文件名 (必填) - - file_content: base64 编码的文件内容 (必填) - - 返回: - - task_id: 任务ID,用于查询上传进度和结果 - """ - try: - kb_manager = self._get_kb_manager() - - # 检查 Content-Type - content_type = request.content_type - kb_id = None - chunk_size = None - chunk_overlap = None - batch_size = 32 - tasks_limit = 3 - max_retries = 3 - files_to_upload = [] # 存储待上传的文件信息列表 - - if content_type and "multipart/form-data" not in content_type: - return ( - Response().error("Content-Type 须为 multipart/form-data").__dict__ - ) - form_data = await request.form - files = await request.files - - kb_id = form_data.get("kb_id") - chunk_size = int(form_data.get("chunk_size", 512)) - chunk_overlap = int(form_data.get("chunk_overlap", 50)) - batch_size = int(form_data.get("batch_size", 32)) - tasks_limit = int(form_data.get("tasks_limit", 3)) - max_retries = int(form_data.get("max_retries", 3)) - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - # 收集所有文件 - file_list = [] - # 支持 file, file1, file2, ... 或 files[] 格式 - for key in files.keys(): - if key == "file" or key.startswith("file") or key == "files[]": - file_items = files.getlist(key) - file_list.extend(file_items) - - if not file_list: - return Response().error("缺少文件").__dict__ - - # 限制文件数量 - if len(file_list) > 10: - return Response().error("最多只能上传10个文件").__dict__ - - # 处理每个文件 - for file in file_list: - file_name = file.filename - - # 保存到临时文件 - temp_file_path = os.path.join( - get_astrbot_temp_path(), - f"kb_upload_{uuid.uuid4()}_{file_name}", - ) - await file.save(temp_file_path) - - try: - # 异步读取文件内容 - async with aiofiles.open(temp_file_path, "rb") as f: - file_content = await f.read() - - # 提取文件类型 - file_type = ( - file_name.rsplit(".", 1)[-1].lower() if "." in file_name else "" - ) - - files_to_upload.append( - { - "file_name": file_name, - "file_content": file_content, - "file_type": file_type, - }, - ) - finally: - # 清理临时文件 - if os.path.exists(temp_file_path): - os.remove(temp_file_path) - - # 获取知识库 - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - # 生成任务ID - task_id = str(uuid.uuid4()) - - # 初始化任务状态 - self._init_task(task_id, status="pending") - - # 启动后台任务 - asyncio.create_task( - self._background_upload_task( - task_id=task_id, - kb_helper=kb_helper, - files_to_upload=files_to_upload, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - ), - ) - - return ( - Response() - .ok( - { - "task_id": task_id, - "file_count": len(files_to_upload), - "message": "task created, processing in background", - }, - ) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"上传文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"上传文档失败: {e!s}").__dict__ - - def _validate_import_request(self, data: dict): - kb_id = data.get("kb_id") - if not kb_id: - raise ValueError("缺少参数 kb_id") - - documents = data.get("documents") - if not documents or not isinstance(documents, list): - raise ValueError("缺少参数 documents 或格式错误") - - for doc in documents: - if "file_name" not in doc or "chunks" not in doc: - raise ValueError("文档格式错误,必须包含 file_name 和 chunks") - if not isinstance(doc["chunks"], list): - raise ValueError("chunks 必须是列表") - if not all( - isinstance(chunk, str) and chunk.strip() for chunk in doc["chunks"] - ): - raise ValueError("chunks 必须是非空字符串列表") - - batch_size = data.get("batch_size", 32) - tasks_limit = data.get("tasks_limit", 3) - max_retries = data.get("max_retries", 3) - return kb_id, documents, batch_size, tasks_limit, max_retries - - async def import_documents(self): - """导入预切片文档 - - Body: - - kb_id: 知识库 ID (必填) - - documents: 文档列表 (必填) - - file_name: 文件名 (必填) - - chunks: 切片列表 (必填, list[str]) - - file_type: 文件类型 (可选, 默认从文件名推断或为 txt) - - batch_size: 批处理大小 (可选, 默认32) - - tasks_limit: 并发任务限制 (可选, 默认3) - - max_retries: 最大重试次数 (可选, 默认3) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id, documents, batch_size, tasks_limit, max_retries = ( - self._validate_import_request(data) - ) - - # 获取知识库 - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - # 生成任务ID - task_id = str(uuid.uuid4()) - - # 初始化任务状态 - self._init_task(task_id, status="pending") - - # 启动后台任务 - asyncio.create_task( - self._background_import_task( - task_id=task_id, - kb_helper=kb_helper, - documents=documents, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - ), - ) - - return ( - Response() - .ok( - { - "task_id": task_id, - "doc_count": len(documents), - "message": "import task created, processing in background", - }, - ) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"导入文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"导入文档失败: {e!s}").__dict__ - - async def get_upload_progress(self): - """获取上传进度和结果 - - Query 参数: - - task_id: 任务 ID (必填) - - 返回状态: - - pending: 任务待处理 - - processing: 任务处理中 - - completed: 任务完成 - - failed: 任务失败 - """ - try: - task_id = request.args.get("task_id") - if not task_id: - return Response().error("缺少参数 task_id").__dict__ - - # 检查任务是否存在 - if task_id not in self.upload_tasks: - return Response().error("找不到该任务").__dict__ - - task_info = self.upload_tasks[task_id] - status = task_info["status"] - - # 构建返回数据 - response_data = { - "task_id": task_id, - "status": status, - } - - # 如果任务正在处理,返回进度信息 - if status == "processing" and task_id in self.upload_progress: - response_data["progress"] = self.upload_progress[task_id] - - # 如果任务完成,返回结果 - if status == "completed": - response_data["result"] = task_info["result"] - # 清理已完成的任务 - # del self.upload_tasks[task_id] - # if task_id in self.upload_progress: - # del self.upload_progress[task_id] - - # 如果任务失败,返回错误信息 - if status == "failed": - response_data["error"] = task_info["error"] - - return Response().ok(response_data).__dict__ - - except Exception as e: - logger.error(f"获取上传进度失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取上传进度失败: {e!s}").__dict__ - - async def get_document(self): - """获取文档详情 - - Query 参数: - - doc_id: 文档 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - doc_id = request.args.get("doc_id") - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - doc = await kb_helper.get_document(doc_id) - if not doc: - return Response().error("文档不存在").__dict__ - - return Response().ok(doc.model_dump()).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取文档详情失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取文档详情失败: {e!s}").__dict__ - - async def delete_document(self): - """删除文档 - - Body: - - kb_id: 知识库 ID (必填) - - doc_id: 文档 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - doc_id = data.get("doc_id") - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - await kb_helper.delete_document(doc_id) - return Response().ok(message="删除文档成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除文档失败: {e!s}").__dict__ - - async def delete_chunk(self): - """删除文本块 - - Body: - - kb_id: 知识库 ID (必填) - - chunk_id: 块 ID (必填) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - chunk_id = data.get("chunk_id") - if not chunk_id: - return Response().error("缺少参数 chunk_id").__dict__ - doc_id = data.get("doc_id") - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - await kb_helper.delete_chunk(chunk_id, doc_id) - return Response().ok(message="删除文本块成功").__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除文本块失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"删除文本块失败: {e!s}").__dict__ - - async def list_chunks(self): - """获取块列表 - - Query 参数: - - kb_id: 知识库 ID (必填) - - page: 页码 (默认 1) - - page_size: 每页数量 (默认 20) - """ - try: - kb_manager = self._get_kb_manager() - kb_id = request.args.get("kb_id") - doc_id = request.args.get("doc_id") - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 100, type=int) - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - if not doc_id: - return Response().error("缺少参数 doc_id").__dict__ - kb_helper = await kb_manager.get_kb(kb_id) - offset = (page - 1) * page_size - limit = page_size - if not kb_helper: - return Response().error("知识库不存在").__dict__ - chunk_list = await kb_helper.get_chunks_by_doc_id( - doc_id=doc_id, - offset=offset, - limit=limit, - ) - return ( - Response() - .ok( - data={ - "items": chunk_list, - "page": page, - "page_size": page_size, - "total": await kb_helper.get_chunk_count_by_doc_id(doc_id), - }, - ) - .__dict__ - ) - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"获取块列表失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"获取块列表失败: {e!s}").__dict__ - - # ===== 检索 API ===== - - async def retrieve(self): - """检索知识库 - - Body: - - query: 查询文本 (必填) - - kb_ids: 知识库 ID 列表 (必填) - - top_k: 返回结果数量 (可选, 默认 5) - - debug: 是否启用调试模式,返回 t-SNE 可视化图片 (可选, 默认 False) - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - query = data.get("query") - kb_names = data.get("kb_names") - debug = data.get("debug", False) - - if not query: - return Response().error("缺少参数 query").__dict__ - if not kb_names or not isinstance(kb_names, list): - return Response().error("缺少参数 kb_names 或格式错误").__dict__ - - top_k = data.get("top_k", 5) - - results = await kb_manager.retrieve( - query=query, - kb_names=kb_names, - top_m_final=top_k, - ) - result_list = [] - if results: - result_list = results["results"] - - response_data = { - "results": result_list, - "total": len(result_list), - "query": query, - } - - # Debug 模式:生成 t-SNE 可视化 - if debug: - try: - img_base64 = await generate_tsne_visualization( - query, - kb_names, - kb_manager, - ) - if img_base64: - response_data["visualization"] = img_base64 - except Exception as e: - logger.error(f"生成 t-SNE 可视化失败: {e}") - logger.error(traceback.format_exc()) - response_data["visualization_error"] = str(e) - - return Response().ok(response_data).__dict__ - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"检索失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"检索失败: {e!s}").__dict__ - - async def upload_document_from_url(self): - """从 URL 上传文档 - - Body: - - kb_id: 知识库 ID (必填) - - url: 要提取内容的网页 URL (必填) - - chunk_size: 分块大小 (可选, 默认512) - - chunk_overlap: 块重叠大小 (可选, 默认50) - - batch_size: 批处理大小 (可选, 默认32) - - tasks_limit: 并发任务限制 (可选, 默认3) - - max_retries: 最大重试次数 (可选, 默认3) - - 返回: - - task_id: 任务ID,用于查询上传进度和结果 - """ - try: - kb_manager = self._get_kb_manager() - data = await request.json - - kb_id = data.get("kb_id") - if not kb_id: - return Response().error("缺少参数 kb_id").__dict__ - - url = data.get("url") - if not url: - return Response().error("缺少参数 url").__dict__ - - chunk_size = data.get("chunk_size", 512) - chunk_overlap = data.get("chunk_overlap", 50) - batch_size = data.get("batch_size", 32) - tasks_limit = data.get("tasks_limit", 3) - max_retries = data.get("max_retries", 3) - enable_cleaning = data.get("enable_cleaning", False) - cleaning_provider_id = data.get("cleaning_provider_id") - - # 获取知识库 - kb_helper = await kb_manager.get_kb(kb_id) - if not kb_helper: - return Response().error("知识库不存在").__dict__ - - # 生成任务ID - task_id = str(uuid.uuid4()) - - # 初始化任务状态 - self._init_task(task_id, status="pending") - - # 启动后台任务 - asyncio.create_task( - self._background_upload_from_url_task( - task_id=task_id, - kb_helper=kb_helper, - url=url, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - enable_cleaning=enable_cleaning, - cleaning_provider_id=cleaning_provider_id, - ), - ) - - return ( - Response() - .ok( - { - "task_id": task_id, - "url": url, - "message": "URL upload task created, processing in background", - }, - ) - .__dict__ - ) - - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"从URL上传文档失败: {e}") - logger.error(traceback.format_exc()) - return Response().error(f"从URL上传文档失败: {e!s}").__dict__ - - async def _background_upload_from_url_task( - self, - task_id: str, - kb_helper, - url: str, - chunk_size: int, - chunk_overlap: int, - batch_size: int, - tasks_limit: int, - max_retries: int, - enable_cleaning: bool, - cleaning_provider_id: str | None, - ) -> None: - """后台上传URL任务""" - try: - # 初始化任务状态 - self._init_task(task_id, status="processing") - self.upload_progress[task_id] = { - "status": "processing", - "file_index": 0, - "file_total": 1, - "file_name": f"URL: {url}", - "stage": "extracting", - "current": 0, - "total": 100, - } - - # 创建进度回调函数 - progress_callback = self._make_progress_callback(task_id, 0, f"URL: {url}") - - # 上传文档 - doc = await kb_helper.upload_from_url( - url=url, - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - batch_size=batch_size, - tasks_limit=tasks_limit, - max_retries=max_retries, - progress_callback=progress_callback, - enable_cleaning=enable_cleaning, - cleaning_provider_id=cleaning_provider_id, - ) - - # 更新任务完成状态 - result = { - "task_id": task_id, - "uploaded": [doc.model_dump()], - "failed": [], - "total": 1, - "success_count": 1, - "failed_count": 0, - } - - self._set_task_result(task_id, "completed", result=result) - - except Exception as e: - logger.error(f"后台上传URL任务 {task_id} 失败: {e}") - logger.error(traceback.format_exc()) - self._set_task_result(task_id, "failed", error=str(e)) diff --git a/astrbot/dashboard/routes/log.py b/astrbot/dashboard/routes/log.py deleted file mode 100644 index e7eebef6e6..0000000000 --- a/astrbot/dashboard/routes/log.py +++ /dev/null @@ -1,144 +0,0 @@ -import asyncio -import json -import time -from collections.abc import AsyncGenerator -from typing import cast - -from quart import Response as QuartResponse -from quart import make_response, request - -from astrbot.core import LogBroker, logger - -from .route import Response, Route, RouteContext - - -def _format_log_sse(log: dict, ts: float) -> str: - """辅助函数:格式化 SSE 消息""" - payload = { - "type": "log", - **log, - } - return f"id: {ts}\ndata: {json.dumps(payload, ensure_ascii=False)}\n\n" - - -class LogRoute(Route): - def __init__(self, context: RouteContext, log_broker: LogBroker) -> None: - super().__init__(context) - self.log_broker = log_broker - self.app.add_url_rule("/api/live-log", view_func=self.log, methods=["GET"]) - self.app.add_url_rule( - "/api/log-history", - view_func=self.log_history, - methods=["GET"], - ) - self.app.add_url_rule( - "/api/trace/settings", - view_func=self.get_trace_settings, - methods=["GET"], - ) - self.app.add_url_rule( - "/api/trace/settings", - view_func=self.update_trace_settings, - methods=["POST"], - ) - - async def _replay_cached_logs( - self, last_event_id: str - ) -> AsyncGenerator[str, None]: - """辅助生成器:重放缓存的日志""" - try: - last_ts = float(last_event_id) - cached_logs = list(self.log_broker.log_cache) - - for log_item in cached_logs: - log_ts = float(log_item.get("time", 0)) - - if log_ts > last_ts: - yield _format_log_sse(log_item, log_ts) - - except ValueError: - pass - except Exception as e: - logger.error(f"Log SSE 补发历史错误: {e}") - - async def log(self) -> QuartResponse: - last_event_id = request.headers.get("Last-Event-ID") - - async def stream(): - queue = None - try: - if last_event_id: - async for event in self._replay_cached_logs(last_event_id): - yield event - - queue = self.log_broker.register() - while True: - message = await queue.get() - current_ts = message.get("time", time.time()) - yield _format_log_sse(message, current_ts) - - except asyncio.CancelledError: - pass - except Exception as e: - logger.error(f"Log SSE 连接错误: {e}") - finally: - if queue: - self.log_broker.unregister(queue) - - response = cast( - QuartResponse, - await make_response( - stream(), - { - "Content-Type": "text/event-stream", - "Cache-Control": "no-cache", - "Connection": "keep-alive", - "Transfer-Encoding": "chunked", - }, - ), - ) - response.timeout = None # type: ignore - return response - - async def log_history(self): - """获取日志历史""" - try: - logs = list(self.log_broker.log_cache) - return ( - Response() - .ok( - data={ - "logs": logs, - }, - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取日志历史失败: {e}") - return Response().error(f"获取日志历史失败: {e}").__dict__ - - async def get_trace_settings(self): - """获取 Trace 设置""" - try: - trace_enable = self.config.get("trace_enable", True) - return Response().ok(data={"trace_enable": trace_enable}).__dict__ - except Exception as e: - logger.error(f"获取 Trace 设置失败: {e}") - return Response().error(f"获取 Trace 设置失败: {e}").__dict__ - - async def update_trace_settings(self): - """更新 Trace 设置""" - try: - data = await request.json - if data is None: - return Response().error("请求数据为空").__dict__ - - trace_enable = data.get("trace_enable") - if trace_enable is not None: - self.config["trace_enable"] = bool(trace_enable) - self.config.save_config() - - return Response().ok(message="Trace 设置已更新").__dict__ - except Exception as e: - logger.error(f"更新 Trace 设置失败: {e}") - return Response().error(f"更新 Trace 设置失败: {e}").__dict__ diff --git a/astrbot/dashboard/routes/persona.py b/astrbot/dashboard/routes/persona.py deleted file mode 100644 index 8a805d4322..0000000000 --- a/astrbot/dashboard/routes/persona.py +++ /dev/null @@ -1,497 +0,0 @@ -import traceback - -from quart import request - -from astrbot.core import logger -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.db import BaseDatabase -from astrbot.core.sentinels import NOT_GIVEN - -from .route import Response, Route, RouteContext - - -class PersonaRoute(Route): - def __init__( - self, - context: RouteContext, - db_helper: BaseDatabase, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.routes = { - "/persona/list": ("GET", self.list_personas), - "/persona/detail": ("POST", self.get_persona_detail), - "/persona/create": ("POST", self.create_persona), - "/persona/update": ("POST", self.update_persona), - "/persona/delete": ("POST", self.delete_persona), - "/persona/move": ("POST", self.move_persona), - "/persona/reorder": ("POST", self.reorder_items), - # Folder routes - "/persona/folder/list": ("GET", self.list_folders), - "/persona/folder/tree": ("GET", self.get_folder_tree), - "/persona/folder/detail": ("POST", self.get_folder_detail), - "/persona/folder/create": ("POST", self.create_folder), - "/persona/folder/update": ("POST", self.update_folder), - "/persona/folder/delete": ("POST", self.delete_folder), - } - self.db_helper = db_helper - self.persona_mgr = core_lifecycle.persona_mgr - self.register_routes() - - async def list_personas(self): - """获取所有人格列表""" - try: - # 支持按文件夹筛选 - folder_id = request.args.get("folder_id") - if folder_id is not None: - personas = await self.persona_mgr.get_personas_by_folder( - folder_id if folder_id else None - ) - else: - personas = await self.persona_mgr.get_all_personas() - return ( - Response() - .ok( - [ - { - "persona_id": persona.persona_id, - "system_prompt": persona.system_prompt, - "begin_dialogs": persona.begin_dialogs or [], - "tools": persona.tools, - "skills": persona.skills, - "custom_error_message": persona.custom_error_message, - "folder_id": persona.folder_id, - "sort_order": persona.sort_order, - "created_at": persona.created_at.isoformat() - if persona.created_at - else None, - "updated_at": persona.updated_at.isoformat() - if persona.updated_at - else None, - } - for persona in personas - ], - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取人格列表失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"获取人格列表失败: {e!s}").__dict__ - - async def get_persona_detail(self): - """获取指定人格的详细信息""" - try: - data = await request.get_json() - persona_id = data.get("persona_id") - - if not persona_id: - return Response().error("缺少必要参数: persona_id").__dict__ - - persona = await self.persona_mgr.get_persona(persona_id) - if not persona: - return Response().error("人格不存在").__dict__ - - return ( - Response() - .ok( - { - "persona_id": persona.persona_id, - "system_prompt": persona.system_prompt, - "begin_dialogs": persona.begin_dialogs or [], - "tools": persona.tools, - "skills": persona.skills, - "custom_error_message": persona.custom_error_message, - "folder_id": persona.folder_id, - "sort_order": persona.sort_order, - "created_at": persona.created_at.isoformat() - if persona.created_at - else None, - "updated_at": persona.updated_at.isoformat() - if persona.updated_at - else None, - }, - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取人格详情失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"获取人格详情失败: {e!s}").__dict__ - - async def create_persona(self): - """创建新人格""" - try: - data = await request.get_json() - persona_id = data.get("persona_id", "").strip() - system_prompt = data.get("system_prompt", "").strip() - begin_dialogs = data.get("begin_dialogs", []) - tools = data.get("tools") - skills = data.get("skills") - custom_error_message = data.get("custom_error_message") - folder_id = data.get("folder_id") # None 表示根目录 - sort_order = data.get("sort_order", 0) - - if not persona_id: - return Response().error("人格ID不能为空").__dict__ - - if not system_prompt: - return Response().error("系统提示词不能为空").__dict__ - - if custom_error_message is not None: - if not isinstance(custom_error_message, str): - return Response().error("自定义报错回复信息必须是字符串").__dict__ - custom_error_message = custom_error_message.strip() or None - - # 验证 begin_dialogs 格式 - if begin_dialogs and len(begin_dialogs) % 2 != 0: - return ( - Response() - .error("预设对话数量必须为偶数(用户和助手轮流对话)") - .__dict__ - ) - - persona = await self.persona_mgr.create_persona( - persona_id=persona_id, - system_prompt=system_prompt, - begin_dialogs=begin_dialogs if begin_dialogs else None, - tools=tools if tools else None, - skills=skills if skills else None, - custom_error_message=custom_error_message, - folder_id=folder_id, - sort_order=sort_order, - ) - - return ( - Response() - .ok( - { - "message": "人格创建成功", - "persona": { - "persona_id": persona.persona_id, - "system_prompt": persona.system_prompt, - "begin_dialogs": persona.begin_dialogs or [], - "tools": persona.tools or [], - "skills": persona.skills or [], - "custom_error_message": persona.custom_error_message, - "folder_id": persona.folder_id, - "sort_order": persona.sort_order, - "created_at": persona.created_at.isoformat() - if persona.created_at - else None, - "updated_at": persona.updated_at.isoformat() - if persona.updated_at - else None, - }, - }, - ) - .__dict__ - ) - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"创建人格失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"创建人格失败: {e!s}").__dict__ - - async def update_persona(self): - """更新人格信息""" - try: - data = await request.get_json() - persona_id = data.get("persona_id") - system_prompt = data.get("system_prompt") - begin_dialogs = data.get("begin_dialogs") - has_tools = "tools" in data - tools = data.get("tools") - has_skills = "skills" in data - skills = data.get("skills") - has_custom_error_message = "custom_error_message" in data - custom_error_message = data.get("custom_error_message") - - if not persona_id: - return Response().error("缺少必要参数: persona_id").__dict__ - - if has_custom_error_message: - if custom_error_message is not None and not isinstance( - custom_error_message, str - ): - return Response().error("自定义报错回复信息必须是字符串").__dict__ - if isinstance(custom_error_message, str): - custom_error_message = custom_error_message.strip() or None - - # 验证 begin_dialogs 格式 - if begin_dialogs is not None and len(begin_dialogs) % 2 != 0: - return ( - Response() - .error("预设对话数量必须为偶数(用户和助手轮流对话)") - .__dict__ - ) - - update_kwargs = { - "persona_id": persona_id, - "system_prompt": system_prompt, - "begin_dialogs": begin_dialogs, - } - if has_tools: - update_kwargs["tools"] = tools - if has_skills: - update_kwargs["skills"] = skills - if has_custom_error_message: - update_kwargs["custom_error_message"] = custom_error_message - - await self.persona_mgr.update_persona(**update_kwargs) - - return Response().ok({"message": "人格更新成功"}).__dict__ - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"更新人格失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"更新人格失败: {e!s}").__dict__ - - async def delete_persona(self): - """删除人格""" - try: - data = await request.get_json() - persona_id = data.get("persona_id") - - if not persona_id: - return Response().error("缺少必要参数: persona_id").__dict__ - - await self.persona_mgr.delete_persona(persona_id) - - return Response().ok({"message": "人格删除成功"}).__dict__ - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"删除人格失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"删除人格失败: {e!s}").__dict__ - - async def move_persona(self): - """移动人格到指定文件夹""" - try: - data = await request.get_json() - persona_id = data.get("persona_id") - folder_id = data.get("folder_id") # None 表示移动到根目录 - - if not persona_id: - return Response().error("缺少必要参数: persona_id").__dict__ - - await self.persona_mgr.move_persona_to_folder(persona_id, folder_id) - - return Response().ok({"message": "人格移动成功"}).__dict__ - except ValueError as e: - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(f"移动人格失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"移动人格失败: {e!s}").__dict__ - - # ==== - # Folder Routes - # ==== - - async def list_folders(self): - """获取文件夹列表""" - try: - parent_id = request.args.get("parent_id") - # 空字符串视为 None(根目录) - if parent_id == "": - parent_id = None - folders = await self.persona_mgr.get_folders(parent_id) - return ( - Response() - .ok( - [ - { - "folder_id": folder.folder_id, - "name": folder.name, - "parent_id": folder.parent_id, - "description": folder.description, - "sort_order": folder.sort_order, - "created_at": folder.created_at.isoformat() - if folder.created_at - else None, - "updated_at": folder.updated_at.isoformat() - if folder.updated_at - else None, - } - for folder in folders - ], - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取文件夹列表失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"获取文件夹列表失败: {e!s}").__dict__ - - async def get_folder_tree(self): - """获取文件夹树形结构""" - try: - tree = await self.persona_mgr.get_folder_tree() - return Response().ok(tree).__dict__ - except Exception as e: - logger.error(f"获取文件夹树失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"获取文件夹树失败: {e!s}").__dict__ - - async def get_folder_detail(self): - """获取指定文件夹的详细信息""" - try: - data = await request.get_json() - folder_id = data.get("folder_id") - - if not folder_id: - return Response().error("缺少必要参数: folder_id").__dict__ - - folder = await self.persona_mgr.get_folder(folder_id) - if not folder: - return Response().error("文件夹不存在").__dict__ - - return ( - Response() - .ok( - { - "folder_id": folder.folder_id, - "name": folder.name, - "parent_id": folder.parent_id, - "description": folder.description, - "sort_order": folder.sort_order, - "created_at": folder.created_at.isoformat() - if folder.created_at - else None, - "updated_at": folder.updated_at.isoformat() - if folder.updated_at - else None, - }, - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取文件夹详情失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"获取文件夹详情失败: {e!s}").__dict__ - - async def create_folder(self): - """创建文件夹""" - try: - data = await request.get_json() - name = data.get("name", "").strip() - parent_id = data.get("parent_id") - description = data.get("description") - sort_order = data.get("sort_order", 0) - - if not name: - return Response().error("文件夹名称不能为空").__dict__ - - folder = await self.persona_mgr.create_folder( - name=name, - parent_id=parent_id, - description=description, - sort_order=sort_order, - ) - - return ( - Response() - .ok( - { - "message": "文件夹创建成功", - "folder": { - "folder_id": folder.folder_id, - "name": folder.name, - "parent_id": folder.parent_id, - "description": folder.description, - "sort_order": folder.sort_order, - "created_at": folder.created_at.isoformat() - if folder.created_at - else None, - "updated_at": folder.updated_at.isoformat() - if folder.updated_at - else None, - }, - }, - ) - .__dict__ - ) - except Exception as e: - logger.error(f"创建文件夹失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"创建文件夹失败: {e!s}").__dict__ - - async def update_folder(self): - """更新文件夹信息""" - try: - data = await request.get_json() - folder_id = data.get("folder_id") - name = data.get("name") - parent_id = data.get("parent_id") if "parent_id" in data else NOT_GIVEN - description = ( - data.get("description") if "description" in data else NOT_GIVEN - ) - sort_order = data.get("sort_order") - - if not folder_id: - return Response().error("缺少必要参数: folder_id").__dict__ - - await self.persona_mgr.update_folder( - folder_id=folder_id, - name=name, - parent_id=parent_id, - description=description, - sort_order=sort_order, - ) - - return Response().ok({"message": "文件夹更新成功"}).__dict__ - except Exception as e: - logger.error(f"更新文件夹失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"更新文件夹失败: {e!s}").__dict__ - - async def delete_folder(self): - """删除文件夹""" - try: - data = await request.get_json() - folder_id = data.get("folder_id") - - if not folder_id: - return Response().error("缺少必要参数: folder_id").__dict__ - - await self.persona_mgr.delete_folder(folder_id) - - return Response().ok({"message": "文件夹删除成功"}).__dict__ - except Exception as e: - logger.error(f"删除文件夹失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"删除文件夹失败: {e!s}").__dict__ - - async def reorder_items(self): - """批量更新排序顺序 - - 请求体格式: - { - "items": [ - {"id": "persona_id_1", "type": "persona", "sort_order": 0}, - {"id": "persona_id_2", "type": "persona", "sort_order": 1}, - {"id": "folder_id_1", "type": "folder", "sort_order": 0}, - ... - ] - } - """ - try: - data = await request.get_json() - items = data.get("items", []) - - if not items: - return Response().error("items 不能为空").__dict__ - - # 验证每个 item 的格式 - for item in items: - if not all(k in item for k in ("id", "type", "sort_order")): - return ( - Response() - .error("每个 item 必须包含 id, type, sort_order 字段") - .__dict__ - ) - if item["type"] not in ("persona", "folder"): - return ( - Response() - .error("type 字段必须是 'persona' 或 'folder'") - .__dict__ - ) - - await self.persona_mgr.batch_update_sort_order(items) - - return Response().ok({"message": "排序更新成功"}).__dict__ - except Exception as e: - logger.error(f"更新排序失败: {e!s}\n{traceback.format_exc()}") - return Response().error(f"更新排序失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/platform.py b/astrbot/dashboard/routes/platform.py deleted file mode 100644 index e302658584..0000000000 --- a/astrbot/dashboard/routes/platform.py +++ /dev/null @@ -1,285 +0,0 @@ -"""统一 Webhook 路由 - -提供统一的 webhook 回调入口,支持多个平台使用同一端口接收回调。 -""" - -import secrets -import string - -from quart import request - -from astrbot.core import logger -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.platform import Platform -from astrbot.core.platform.sources.dingtalk.app_registration import ( - poll_dingtalk_app_registration_once, - request_dingtalk_app_registration, -) -from astrbot.core.platform.sources.lark.app_registration import ( - poll_app_registration_once, - request_app_registration, -) -from astrbot.core.platform.sources.lark.bot_info import request_lark_bot_info -from astrbot.core.platform.sources.weixin_oc.login_registration import ( - poll_weixin_oc_login_once, - request_weixin_oc_login_qr, -) - -from .route import Response, Route, RouteContext - - -def _random_platform_id_suffix() -> str: - return "_" + "".join(secrets.choice(string.ascii_lowercase) for _ in range(4)) - - -class PlatformRoute(Route): - """统一 Webhook 路由""" - - def __init__( - self, - context: RouteContext, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - self.platform_manager = core_lifecycle.platform_manager - - self._register_webhook_routes() - - def _register_webhook_routes(self) -> None: - """注册 webhook 路由""" - # 统一 webhook 入口,支持 GET 和 POST - self.app.add_url_rule( - "/api/platform/webhook/", - view_func=self.unified_webhook_callback, - methods=["GET", "POST"], - ) - - # 平台统计信息接口 - self.app.add_url_rule( - "/api/platform/stats", - view_func=self.get_platform_stats, - methods=["GET"], - ) - - self.app.add_url_rule( - "/api/platform/registration/", - view_func=self.handle_platform_registration, - methods=["POST"], - ) - - async def unified_webhook_callback(self, webhook_uuid: str): - """统一 webhook 回调入口 - - Args: - webhook_uuid: 平台配置中的 webhook_uuid - - Returns: - 根据平台适配器返回相应的响应 - """ - # 根据 webhook_uuid 查找对应的平台 - platform_adapter = self._find_platform_by_uuid(webhook_uuid) - - if not platform_adapter: - logger.warning(f"未找到 webhook_uuid 为 {webhook_uuid} 的平台") - return Response().error("未找到对应平台").__dict__, 404 - - # 调用平台适配器的 webhook_callback 方法 - try: - result = await platform_adapter.webhook_callback(request) - return result - except NotImplementedError: - logger.error( - f"平台 {platform_adapter.meta().name} 未实现 webhook_callback 方法" - ) - return Response().error("平台未支持统一 Webhook 模式").__dict__, 500 - except Exception as e: - logger.error(f"处理 webhook 回调时发生错误: {e}", exc_info=True) - return Response().error("处理回调失败").__dict__, 500 - - def _find_platform_by_uuid(self, webhook_uuid: str) -> Platform | None: - """根据 webhook_uuid 查找对应的平台适配器 - - Args: - webhook_uuid: webhook UUID - - Returns: - 平台适配器实例,未找到则返回 None - """ - for platform in self.platform_manager.platform_insts: - if platform.config.get("webhook_uuid") == webhook_uuid: - if platform.unified_webhook(): - return platform - return None - - async def get_platform_stats(self): - """获取所有平台的统计信息 - - Returns: - 包含平台统计信息的响应 - """ - try: - stats = self.platform_manager.get_all_stats() - return Response().ok(stats).__dict__ - except Exception as e: - logger.error(f"获取平台统计信息失败: {e}", exc_info=True) - return Response().error(f"获取统计信息失败: {e}").__dict__, 500 - - async def handle_platform_registration(self, platform_type: str): - """Handle dashboard one-click platform registration actions.""" - try: - payload = await request.get_json(silent=True) or {} - action = str(payload.get("action", "")).strip().lower() - if not action: - return Response().error("Missing action").__dict__, 400 - - platform_config = payload.get("platform_config") - if not isinstance(platform_config, dict): - platform_config = {} - - if platform_type == "lark": - return await self._handle_lark_registration( - action, - payload, - platform_config, - ) - if platform_type == "weixin_oc": - return await self._handle_weixin_oc_registration( - action, - payload, - platform_config, - ) - if platform_type == "dingtalk": - return await self._handle_dingtalk_registration(action, payload) - - return Response().error( - f"Unsupported platform registration: {platform_type}" - ).__dict__, 404 - except Exception as e: - logger.error(f"处理平台一键创建请求失败: {e}", exc_info=True) - return Response().error(str(e)).__dict__, 500 - - async def _handle_lark_registration( - self, - action: str, - payload: dict, - platform_config: dict, - ): - domain = str(platform_config.get("domain") or "").strip() - - if action == "start": - registration = await request_app_registration(domain) - return ( - Response() - .ok( - { - "status": "pending", - "device_code": registration.device_code, - "registration_code": registration.device_code, - "user_code": registration.user_code, - "verification_uri": registration.verification_uri, - "verification_uri_complete": registration.verification_uri_complete, - "expires_in": registration.expires_in, - "interval": registration.interval, - } - ) - .__dict__ - ) - - if action == "poll": - device_code = str( - payload.get("device_code") or payload.get("registration_code") or "" - ).strip() - if not device_code: - return Response().error("Missing device_code").__dict__, 400 - result = await poll_app_registration_once( - domain=domain, - device_code=device_code, - ) - if result.get("status") == "created": - try: - bot_info = await request_lark_bot_info( - domain=str(result.get("domain") or domain), - app_id=str(result.get("app_id") or ""), - app_secret=str(result.get("app_secret") or ""), - ) - if bot_info.app_name: - result["bot_name"] = bot_info.app_name - if bot_info.open_id: - result["bot_open_id"] = bot_info.open_id - except Exception as e: - logger.error(f"获取飞书机器人信息失败: {e}", exc_info=True) - return Response().ok(result).__dict__ - - return Response().error(f"Unsupported action: {action}").__dict__, 400 - - async def _handle_dingtalk_registration(self, action: str, payload: dict): - if action == "start": - registration = await request_dingtalk_app_registration() - return ( - Response() - .ok( - { - "status": "pending", - "device_code": registration.device_code, - "registration_code": registration.device_code, - "user_code": registration.user_code, - "verification_uri": registration.verification_uri, - "verification_uri_complete": registration.verification_uri_complete, - "expires_in": registration.expires_in, - "interval": registration.interval, - } - ) - .__dict__ - ) - - if action == "poll": - device_code = str( - payload.get("device_code") or payload.get("registration_code") or "" - ).strip() - if not device_code: - return Response().error("Missing device_code").__dict__, 400 - result = await poll_dingtalk_app_registration_once(device_code) - if result.get("status") == "created": - result["platform_id_suffix"] = _random_platform_id_suffix() - return Response().ok(result).__dict__ - - return Response().error(f"Unsupported action: {action}").__dict__, 400 - - async def _handle_weixin_oc_registration( - self, - action: str, - payload: dict, - platform_config: dict, - ): - if action == "start": - registration = await request_weixin_oc_login_qr(platform_config) - return ( - Response() - .ok( - { - "status": "pending", - "registration_code": registration.qrcode, - "qrcode": registration.qrcode, - "qrcode_img_content": registration.qrcode_img_content, - "interval": registration.interval, - } - ) - .__dict__ - ) - - if action == "poll": - qrcode = str( - payload.get("qrcode") or payload.get("registration_code") or "" - ).strip() - if not qrcode: - return Response().error("Missing qrcode").__dict__, 400 - result = await poll_weixin_oc_login_once( - platform_config=platform_config, - qrcode=qrcode, - ) - if result.get("status") == "created": - result["platform_id_suffix"] = _random_platform_id_suffix() - return Response().ok(result).__dict__ - - return Response().error(f"Unsupported action: {action}").__dict__, 400 diff --git a/astrbot/dashboard/routes/plugin.py b/astrbot/dashboard/routes/plugin.py deleted file mode 100644 index ff785a0379..0000000000 --- a/astrbot/dashboard/routes/plugin.py +++ /dev/null @@ -1,2132 +0,0 @@ -import asyncio -import hashlib -import json -import mimetypes -import os -import posixpath -import re -import ssl -import traceback -from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from pathlib import Path -from typing import cast -from urllib.parse import parse_qsl, quote, urlencode, urlsplit, urlunsplit - -import aiofiles -import aiohttp -import certifi -import jwt -from aiofiles import ospath as aio_ospath -from quart import Response as QuartResponse -from quart import g, make_response, request - -from astrbot.api import sp -from astrbot.core import DEMO_MODE, file_token_service, logger -from astrbot.core.computer.computer_client import sync_skills_to_active_sandboxes -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.skills.skill_manager import SkillManager -from astrbot.core.star.filter.command import CommandFilter -from astrbot.core.star.filter.command_group import CommandGroupFilter -from astrbot.core.star.filter.permission import PermissionTypeFilter -from astrbot.core.star.filter.regex import RegexFilter -from astrbot.core.star.star import StarMetadata -from astrbot.core.star.star_handler import EventType, star_handlers_registry -from astrbot.core.star.star_manager import ( - PluginManager, - PluginVersionIncompatibleError, -) -from astrbot.core.utils.astrbot_path import ( - get_astrbot_data_path, - get_astrbot_temp_path, -) - -from .route import Response, Route, RouteContext - -PLUGIN_UPDATE_CONCURRENCY = ( - 3 # limit concurrent updates to avoid overwhelming plugin sources -) -_PLUGIN_PAGE_BRIDGE_FILE = ( - Path(__file__).resolve().parent.parent / "plugin_page_bridge.js" -) -_HTML_ASSET_ATTR_RE = re.compile( - r"(?Psrc|href)=(?P[\"\'])(?P.*?)(?P=quote)", - re.IGNORECASE, -) -_CSS_URL_RE = re.compile( - r"url\(\s*(?P[\"\']?)(?P.*?)(?P=quote)\s*\)", - re.IGNORECASE, -) -_JS_DYNAMIC_IMPORT_RE = re.compile( - r"(?P\bimport\s*\(\s*)(?P[\"\'])(?P.*?)(?P=quote)(?P\s*\))", - re.IGNORECASE, -) -_JS_MODULE_FROM_RE = re.compile( - r"(?P\b(?:import|export)\s+(?:[^;]*?\s+from\s+))(?P[\"\'])(?P.*?)(?P=quote)", - re.IGNORECASE | re.DOTALL, -) -_JS_SIDE_EFFECT_IMPORT_RE = re.compile( - r"(?P\bimport\s+)(?P[\"\'])(?P[^\"'\r\n]+)(?P=quote)", - re.IGNORECASE, -) -_PLUGIN_PAGE_ASSET_TOKEN_TYPE = "plugin_page_asset" -_PLUGIN_PAGE_ASSET_TOKEN_TTL_SECONDS = 60 -_PLUGIN_PAGE_ROOT_DIR_NAME = "pages" -_PLUGIN_PAGE_ENTRY_FILE_NAME = "index.html" - - -def _normalize_plugin_page_asset_path(asset_path: str) -> str: - return PluginRoute._normalize_plugin_page_path(asset_path, allow_empty=True) - - -PLUGIN_COMPONENT_TYPE_ORDER = { - "page": 0, - "skill": 1, - "command": 2, - "llm_tool": 3, - "listener": 4, - "hook": 5, -} - - -@dataclass -class PluginPage: - name: str - title: str - entry_file: str = _PLUGIN_PAGE_ENTRY_FILE_NAME - - -@dataclass -class RegistrySource: - urls: list[str] - cache_file: str - md5_url: str | None # None means "no remote MD5, always treat cache as stale" - - -class PluginRoute(Route): - def __init__( - self, - context: RouteContext, - core_lifecycle: AstrBotCoreLifecycle, - plugin_manager: PluginManager, - ) -> None: - super().__init__(context) - self.routes = { - "/plugin/get": ("GET", self.get_plugins), - "/plugin/detail": ("GET", self.get_plugin_detail), - "/plugin/check-compat": ("POST", self.check_plugin_compatibility), - "/plugin/page/entry": ("GET", self.get_plugin_page_entry_config), - "/plugin/install": ("POST", self.install_plugin), - "/plugin/install-upload": ("POST", self.install_plugin_upload), - "/plugin/update": ("POST", self.update_plugin), - "/plugin/update-all": ("POST", self.update_all_plugins), - "/plugin/uninstall": ("POST", self.uninstall_plugin), - "/plugin/uninstall-failed": ("POST", self.uninstall_failed_plugin), - "/plugin/market_list": ("GET", self.get_online_plugins), - "/plugin/off": ("POST", self.off_plugin), - "/plugin/on": ("POST", self.on_plugin), - "/plugin/reload-failed": ("POST", self.reload_failed_plugins), - "/plugin/reload": ("POST", self.reload_plugins), - "/plugin/readme": ("GET", self.get_plugin_readme), - "/plugin/changelog": ("GET", self.get_plugin_changelog), - "/plugin/source/get": ("GET", self.get_custom_source), - "/plugin/source/save": ("POST", self.save_custom_source), - "/plugin/source/get-failed-plugins": ("GET", self.get_failed_plugins), - } - self.core_lifecycle = core_lifecycle - self.plugin_manager = plugin_manager - self.register_routes() - self.app.add_url_rule( - "/api/plugin/page/content///", - endpoint="plugin_page_content_entry", - view_func=self.get_plugin_page_entry, - methods=["GET"], - ) - self.app.add_url_rule( - "/api/plugin/page/content///", - endpoint="plugin_page_content_asset", - view_func=self.get_plugin_page_asset, - methods=["GET"], - ) - self.app.add_url_rule( - "/api/plugin/page/bridge-sdk.js", - endpoint="plugin_page_bridge_sdk", - view_func=self.get_plugin_page_bridge_sdk, - methods=["GET"], - ) - - self.translated_event_type = { - EventType.AdapterMessageEvent: "平台消息下发时", - EventType.OnLLMRequestEvent: "LLM 请求时", - EventType.OnLLMResponseEvent: "LLM 响应后", - EventType.OnAgentBeginEvent: "Agent 开始运行时", - EventType.OnAgentDoneEvent: "Agent 运行完成后", - EventType.OnDecoratingResultEvent: "回复消息前", - EventType.OnCallingFuncToolEvent: "函数工具", - EventType.OnAfterMessageSentEvent: "发送消息后", - EventType.OnPluginErrorEvent: "插件报错时", - } - - self._logo_cache = {} - - async def get_plugin_page_entry(self, plugin_name: str, page_name: str): - return await self._serve_plugin_page_content(plugin_name, page_name, "") - - async def get_plugin_page_asset( - self, - plugin_name: str, - page_name: str, - asset_path: str, - ): - return await self._serve_plugin_page_content( - plugin_name, - page_name, - asset_path, - ) - - async def get_plugin_page_bridge_sdk(self): - if not await aio_ospath.isfile(str(_PLUGIN_PAGE_BRIDGE_FILE)): - return await self._plugin_page_error_response( - 404, "Plugin Page bridge SDK not found" - ) - bridge_js = await self._read_plugin_page_text(_PLUGIN_PAGE_BRIDGE_FILE) - initial_context = self._get_plugin_page_initial_context() - if initial_context: - context_json = json.dumps(initial_context, ensure_ascii=False) - bridge_js += ( - f"\n;window.AstrBotPluginPage?.__setInitialContext({context_json});\n" - ) - response = cast( - QuartResponse, - await make_response( - bridge_js, {"Content-Type": "application/javascript; charset=utf-8"} - ), - ) - return self._apply_plugin_page_security_headers(response) - - def _get_plugin_metadata_by_name(self, plugin_name: str) -> StarMetadata | None: - for plugin in self.plugin_manager.context.get_all_stars(): - if plugin.name == plugin_name: - return plugin - return None - - @staticmethod - def _get_by_path(source: dict | None, key: str): - if not isinstance(source, dict) or not key: - return None - current = source - for part in key.split("."): - if not isinstance(current, dict) or part not in current: - return None - current = current[part] - return current - - @staticmethod - def _get_request_locale(default: str = "zh-CN") -> str: - raw_locale = request.headers.get("Accept-Language", "").strip() - locale = raw_locale.split(",", 1)[0].split(";", 1)[0].strip() - if not locale or len(locale) > 32: - return default - return locale - - @staticmethod - def _get_request_theme() -> str | None: - theme = request.args.get("theme", "").strip() - return theme if theme in ("dark", "light") else None - - @staticmethod - def _apply_theme_to_html(html: str, theme: str) -> str: - def _replace_html_tag(m: re.Match) -> str: - attrs = m.group(1) or "" - attrs = re.sub( - r'\s+data-theme\s*=\s*["\'][^"\']*["\']', - "", - attrs, - flags=re.IGNORECASE, - ) - return f'' - - html = re.sub( - r"]*)>", - _replace_html_tag, - html, - count=1, - flags=re.IGNORECASE, - ) - - meta_tag = f'' - - html = re.sub( - r']*name\s*=\s*["\']color-scheme["\'][^>]*>', - "", - html, - flags=re.IGNORECASE, - ) - - head_match = re.search(r"]*>", html, re.IGNORECASE) - if head_match: - html = html.replace( - head_match.group(0), f"{head_match.group(0)}{meta_tag}", 1 - ) - else: - html = re.sub( - r"(]*>)", - rf"\1{meta_tag}", - html, - count=1, - flags=re.IGNORECASE, - ) - return html - - def _get_plugin_page_initial_context(self) -> dict | None: - asset_token = request.args.get("asset_token", "").strip() - if not asset_token: - return None - jwt_secret = self.config.get("dashboard", {}).get("jwt_secret") - if not isinstance(jwt_secret, str) or not jwt_secret.strip(): - return None - - try: - payload = jwt.decode(asset_token, jwt_secret, algorithms=["HS256"]) - except jwt.InvalidTokenError: - return None - if payload.get("token_type") != _PLUGIN_PAGE_ASSET_TOKEN_TYPE: - return None - - plugin_name = payload.get("plugin_name") - page_name = payload.get("page_name") - if not isinstance(plugin_name, str) or not isinstance(page_name, str): - return None - - plugin = self._get_plugin_metadata_by_name(plugin_name) - if not plugin: - return None - - locale = ( - payload.get("locale") - if isinstance(payload.get("locale"), str) - else self._get_request_locale() - ) - plugin_i18n = plugin.i18n or {} - try: - plugin_root = self._get_plugin_root_dir(plugin) - fresh_i18n = PluginManager._load_plugin_i18n(str(plugin_root)) - if fresh_i18n: - plugin_i18n = fresh_i18n - except (OSError, ValueError): - pass - - locale_data = plugin_i18n.get(locale) - display_name = ( - self._get_by_path(locale_data, "metadata.display_name") - or plugin.display_name - or plugin.name - ) - page_title = ( - self._get_by_path(locale_data, f"pages.{page_name}.title") or page_name - ) - - theme = self._get_request_theme() - - return { - "pluginName": plugin.name, - "displayName": display_name, - "pageName": page_name, - "pageTitle": page_title, - "locale": locale, - "i18n": plugin_i18n, - "isDark": theme == "dark", - } - - @staticmethod - def _normalize_plugin_page_path( - raw_path: str, - *, - base_dir: str | None = None, - allow_empty: bool = False, - ) -> str: - path = raw_path.replace("\\", "/").strip() - if base_dir: - path = posixpath.join(base_dir, path) - normalized = posixpath.normpath(path) - if normalized in {"", "."}: - if allow_empty: - return "" - raise ValueError("Invalid plugin Page asset path") - if ( - normalized.startswith("../") - or normalized == ".." - or normalized.startswith("/") - ): - raise ValueError("Invalid plugin Page asset path") - return normalized - - @staticmethod - def _normalize_plugin_page_name(raw_name: str) -> str: - page_name = raw_name.strip() - if not page_name: - raise ValueError("Invalid plugin Page name") - normalized = posixpath.normpath(page_name.replace("\\", "/")) - if ( - normalized != page_name - or normalized in {".", ".."} - or normalized.startswith(".") - or "/" in page_name - or "\\" in page_name - ): - raise ValueError("Invalid plugin Page name") - return page_name - - def _get_plugin_root_dir(self, plugin: StarMetadata) -> Path: - if not plugin.root_dir_name: - raise FileNotFoundError("Plugin directory metadata is missing") - - base_dir = Path( - self.plugin_manager.reserved_plugin_path - if plugin.reserved - else self.plugin_manager.plugin_store_path - ).resolve(strict=False) - plugin_root = (base_dir / plugin.root_dir_name).resolve(strict=False) - plugin_root.relative_to(base_dir) - return plugin_root - - async def _resolve_plugin_pages_root( - self, - plugin: StarMetadata, - ) -> Path: - plugin_root = self._get_plugin_root_dir(plugin) - pages_root = (plugin_root / _PLUGIN_PAGE_ROOT_DIR_NAME).resolve(strict=False) - pages_root.relative_to(plugin_root) - if pages_root == plugin_root: - raise FileNotFoundError("Plugin Pages root directory is invalid") - if not await aio_ospath.isdir(str(pages_root)): - raise FileNotFoundError("Plugin Pages root directory does not exist") - return pages_root - - async def _discover_plugin_pages(self, plugin: StarMetadata) -> list[PluginPage]: - try: - pages_root = await self._resolve_plugin_pages_root(plugin) - except (FileNotFoundError, ValueError): - return [] - - pages: list[PluginPage] = [] - try: - page_dirs = sorted( - (item for item in pages_root.iterdir() if item.is_dir()), - key=lambda item: item.name.lower(), - ) - except OSError: - return [] - - for page_dir in page_dirs: - try: - page_name = self._normalize_plugin_page_name(page_dir.name) - except ValueError: - continue - entry_path = page_dir / _PLUGIN_PAGE_ENTRY_FILE_NAME - if not await aio_ospath.isfile(str(entry_path)): - continue - pages.append( - PluginPage( - name=page_name, - title=page_name, - entry_file=_PLUGIN_PAGE_ENTRY_FILE_NAME, - ) - ) - return pages - - async def _get_plugin_page( - self, - plugin: StarMetadata, - page_name: str, - ) -> PluginPage: - normalized_name = self._normalize_plugin_page_name(page_name) - for page in await self._discover_plugin_pages(plugin): - if page.name == normalized_name: - return page - raise FileNotFoundError("Plugin Page entry not found") - - async def _resolve_plugin_page_root( - self, - plugin: StarMetadata, - page_name: str, - ) -> Path: - normalized_name = self._normalize_plugin_page_name(page_name) - pages_root = await self._resolve_plugin_pages_root(plugin) - page_root = (pages_root / normalized_name).resolve(strict=False) - page_root.relative_to(pages_root) - if not await aio_ospath.isdir(str(page_root)): - raise FileNotFoundError("Plugin Page root directory does not exist") - return page_root - - async def _resolve_plugin_page_file( - self, - plugin: StarMetadata, - page_name: str, - asset_path: str, - ) -> Path: - page = await self._get_plugin_page(plugin, page_name) - page_root = await self._resolve_plugin_page_root(plugin, page.name) - target_name = _normalize_plugin_page_asset_path(asset_path) or page.entry_file - target_path = (page_root / target_name).resolve(strict=False) - target_path.relative_to(page_root) - if not await aio_ospath.isfile(str(target_path)): - raise FileNotFoundError("Plugin Page asset not found") - return target_path - - @staticmethod - def _is_rewritable_asset_url(raw_url: str) -> bool: - value = raw_url.strip() - lower = value.lower() - if not value: - return False - if value.startswith(("#", "/#")): - return False - if lower.startswith( - ( - "http://", - "https://", - "//", - "data:", - "javascript:", - "mailto:", - "tel:", - "blob:", - ) - ): - return False - return True - - @staticmethod - def _resolve_referenced_asset_path( - base_asset_path: str, - referenced_url: str, - ) -> str: - parts = urlsplit(referenced_url) - referenced_path = parts.path.strip() - if not referenced_path: - raise ValueError("Plugin Page referenced asset path is empty") - base_dir = posixpath.dirname(base_asset_path) if base_asset_path else "" - normalized = PluginRoute._normalize_plugin_page_path( - referenced_path, - base_dir=base_dir, - ) - if not normalized: - raise ValueError("Plugin Page referenced asset path is invalid") - return normalized - - def _build_plugin_page_asset_url( - self, - plugin_name: str, - page_name: str, - asset_path: str, - original_query: str = "", - original_fragment: str = "", - extra_query_params: dict[str, str] | None = None, - ) -> str: - path = self._build_plugin_page_content_path(plugin_name, page_name, asset_path) - query_dict = dict(parse_qsl(original_query, keep_blank_values=True)) - if extra_query_params: - for key, value in extra_query_params.items(): - if value: - query_dict[key] = value - query = urlencode(query_dict) - return urlunsplit( - ( - "", - "", - path, - query, - original_fragment, - ) - ) - - @staticmethod - def _build_plugin_page_content_path( - plugin_name: str, - page_name: str, - asset_path: str = "", - ) -> str: - encoded_plugin_name = quote(plugin_name, safe="") - encoded_page_name = quote( - PluginRoute._normalize_plugin_page_name(page_name), - safe="", - ) - if not asset_path: - return ( - f"/api/plugin/page/content/{encoded_plugin_name}/{encoded_page_name}/" - ) - safe_asset_path = _normalize_plugin_page_asset_path(asset_path) - encoded_path = "/".join( - quote(part, safe="") for part in safe_asset_path.split("/") - ) - return ( - f"/api/plugin/page/content/{encoded_plugin_name}/" - f"{encoded_page_name}/{encoded_path}" - ) - - @staticmethod - def _get_plugin_page_bridge_sdk_url( - extra_query_params: dict[str, str] | None = None, - ) -> str: - query = urlencode(extra_query_params or {}) - return urlunsplit( - ( - "", - "", - "/api/plugin/page/bridge-sdk.js", - query, - "", - ) - ) - - @staticmethod - def _is_js_relative_module_specifier(raw_url: str) -> bool: - value = raw_url.strip() - return value.startswith(("./", "../", "/")) - - def _rewrite_relative_asset_url( - self, - raw_url: str, - base_asset_path: str, - plugin_name: str, - page_name: str, - extra_query_params: dict[str, str] | None = None, - ) -> str | None: - candidate = raw_url.strip() - if not self._is_rewritable_asset_url(candidate): - return None - parts = urlsplit(candidate) - asset_path = self._resolve_referenced_asset_path(base_asset_path, candidate) - return self._build_plugin_page_asset_url( - plugin_name, - page_name, - asset_path, - original_query=parts.query, - original_fragment=parts.fragment, - extra_query_params=extra_query_params, - ) - - def _rewrite_plugin_page_html( - self, - html_text: str, - plugin_name: str, - page_name: str, - entry_asset_path: str, - extra_query_params: dict[str, str] | None = None, - ) -> str: - def replace_attr(match: re.Match[str]) -> str: - raw_url = match.group("url") - attr = match.group("attr") - quote_char = match.group("quote") - - if raw_url.strip() == "/api/plugin/page/bridge-sdk.js": - url = self._get_plugin_page_bridge_sdk_url(extra_query_params) - return f"{attr}={quote_char}{url}{quote_char}" - - if not self._is_rewritable_asset_url(raw_url): - return match.group(0) - - try: - rewritten_url = self._rewrite_relative_asset_url( - raw_url, - entry_asset_path, - plugin_name, - page_name, - extra_query_params=extra_query_params, - ) - if not rewritten_url: - return match.group(0) - return f"{attr}={quote_char}{rewritten_url}{quote_char}" - except ValueError: - return match.group(0) - - rewritten_html = _HTML_ASSET_ATTR_RE.sub(replace_attr, html_text) - theme = self._get_request_theme() - if theme: - rewritten_html = self._apply_theme_to_html(rewritten_html, theme) - if "/api/plugin/page/bridge-sdk.js" not in rewritten_html: - bridge_tag = f'' - if "" in rewritten_html: - rewritten_html = rewritten_html.replace( - "", f"{bridge_tag}", 1 - ) - else: - rewritten_html += bridge_tag - return rewritten_html - - def _rewrite_plugin_page_css( - self, - css_text: str, - plugin_name: str, - page_name: str, - css_asset_path: str, - extra_query_params: dict[str, str] | None = None, - ) -> str: - def replace_url(match: re.Match[str]) -> str: - raw_url = match.group("url").strip() - quote_char = match.group("quote") or "" - try: - rewritten_url = self._rewrite_relative_asset_url( - raw_url, - css_asset_path, - plugin_name, - page_name, - extra_query_params=extra_query_params, - ) - if not rewritten_url: - return match.group(0) - return f"url({quote_char}{rewritten_url}{quote_char})" - except ValueError: - return match.group(0) - - return _CSS_URL_RE.sub(replace_url, css_text) - - def _rewrite_plugin_page_js( - self, - js_text: str, - plugin_name: str, - page_name: str, - js_asset_path: str, - extra_query_params: dict[str, str] | None = None, - ) -> str: - def rewrite_specifier(raw_url: str) -> str: - if not self._is_js_relative_module_specifier(raw_url): - return raw_url - if not self._is_rewritable_asset_url(raw_url): - return raw_url - rewritten = self._rewrite_relative_asset_url( - raw_url, - js_asset_path, - plugin_name, - page_name, - extra_query_params=extra_query_params, - ) - return rewritten or raw_url - - def replace_dynamic(match: re.Match[str]) -> str: - raw_url = match.group("url") - try: - rewritten = rewrite_specifier(raw_url) - except ValueError: - return match.group(0) - return ( - f"{match.group('prefix')}{match.group('quote')}{rewritten}" - f"{match.group('quote')}{match.group('suffix')}" - ) - - def replace_from(match: re.Match[str]) -> str: - raw_url = match.group("url") - try: - rewritten = rewrite_specifier(raw_url) - except ValueError: - return match.group(0) - return f"{match.group('prefix')}{match.group('quote')}{rewritten}{match.group('quote')}" - - rewritten_js = _JS_DYNAMIC_IMPORT_RE.sub(replace_dynamic, js_text) - rewritten_js = _JS_MODULE_FROM_RE.sub(replace_from, rewritten_js) - - def replace_side_effect(match: re.Match[str]) -> str: - raw_url = match.group("url") - if raw_url.startswith(("{", "*")): - return match.group(0) - try: - rewritten = rewrite_specifier(raw_url) - except ValueError: - return match.group(0) - return f"{match.group('prefix')}{match.group('quote')}{rewritten}{match.group('quote')}" - - return _JS_SIDE_EFFECT_IMPORT_RE.sub(replace_side_effect, rewritten_js) - - @staticmethod - async def _read_plugin_page_text(file_path: Path) -> str: - async with aiofiles.open(file_path, encoding="utf-8") as file: - return await file.read() - - @staticmethod - async def _read_plugin_page_binary(file_path: Path) -> bytes: - async with aiofiles.open(file_path, mode="rb") as file: - return await file.read() - - @staticmethod - def _guess_plugin_page_mime_type(file_path: Path) -> str: - return mimetypes.guess_type(file_path.name)[0] or "application/octet-stream" - - async def _serialize_plugin_page( - self, - plugin: StarMetadata, - page_name: str, - *, - include_content_path: bool = False, - ) -> dict | None: - plugin_name = plugin.name.strip() if isinstance(plugin.name, str) else "" - if not plugin_name: - return None - try: - page = await self._get_plugin_page(plugin, page_name) - await self._resolve_plugin_page_file(plugin, page.name, "") - except (FileNotFoundError, ValueError): - return None - - page_data = { - "name": page.name, - "title": page.title, - "i18n_key": f"pages.{page.name}", - } - if include_content_path: - asset_token = ( - self._issue_plugin_page_asset_token(plugin_name, page.name) or "" - ) - extra_query_params = {"asset_token": asset_token} if asset_token else None - page_data["content_path"] = self._build_plugin_page_asset_url( - plugin_name, - page.name, - "", - extra_query_params=extra_query_params, - ) - return page_data - - async def _serialize_plugin_pages(self, plugin: StarMetadata) -> list[dict]: - pages = [] - for page in await self._discover_plugin_pages(plugin): - page_data = await self._serialize_plugin_page(plugin, page.name) - if page_data: - pages.append(page_data) - return pages - - def _issue_plugin_page_asset_token( - self, - plugin_name: str, - page_name: str, - ) -> str | None: - jwt_secret = self.config.get("dashboard", {}).get("jwt_secret") - if not isinstance(jwt_secret, str) or not jwt_secret.strip(): - return None - - username = getattr(g, "username", None) - if not isinstance(username, str) or not username.strip(): - return None - - now = datetime.now(timezone.utc) - payload = { - "username": username, - "token_type": _PLUGIN_PAGE_ASSET_TOKEN_TYPE, - "plugin_name": plugin_name, - "page_name": page_name, - "locale": self._get_request_locale(), - "iat": now, - "exp": now + timedelta(seconds=_PLUGIN_PAGE_ASSET_TOKEN_TTL_SECONDS), - } - return cast(str, jwt.encode(payload, jwt_secret, algorithm="HS256")) - - def _prepare_plugin_page_query_params( - self, - plugin_name: str, - page_name: str, - ) -> dict[str, str] | None: - asset_token = request.args.get("asset_token", "").strip() - if not asset_token: - asset_token = ( - self._issue_plugin_page_asset_token(plugin_name, page_name) or "" - ) - theme = self._get_request_theme() - - if not asset_token and not theme: - return None - - params: dict[str, str] = {} - if asset_token: - params["asset_token"] = asset_token - if theme: - params["theme"] = theme - return params - - @staticmethod - async def _plugin_page_error_response(status_code: int, message: str): - response = await make_response(message, status_code) - response.headers["Cache-Control"] = "no-store" - response.headers["Content-Type"] = "text/plain; charset=utf-8" - response.headers["Referrer-Policy"] = "no-referrer" - return response - - @staticmethod - def _apply_plugin_page_security_headers(response: QuartResponse) -> QuartResponse: - response.headers["Cache-Control"] = "no-store" - response.headers["Referrer-Policy"] = "no-referrer" - response.headers["X-Content-Type-Options"] = "nosniff" - response.headers["Cross-Origin-Resource-Policy"] = "cross-origin" - # Sandboxed iframes without allow-same-origin load ES modules with Origin: null. - # CORS read access is allowed here; JWT/asset_token still protects the assets. - response.headers["Access-Control-Allow-Origin"] = "*" - - # When running under the AstrBot Launcher the dashboard is embedded in a - # cross-origin iframe (the Tauri webview). Since frame-ancestors and - # X-Frame-Options inspect the *entire* ancestor chain, enforcing them here - # would block plugin pages from loading inside the nested iframe. - csp = "object-src 'none'; base-uri 'self'" - if os.environ.get("ASTRBOT_LAUNCHER") not in ("1", "true"): - response.headers["X-Frame-Options"] = "SAMEORIGIN" - csp = f"frame-ancestors 'self'; {csp}" - response.headers["Content-Security-Policy"] = csp - - return response - - async def _serve_plugin_page_html_asset( - self, - file_path: Path, - plugin_name: str, - page_name: str, - asset_path: str, - extra_query_params: dict[str, str] | None, - ): - html_text = await self._read_plugin_page_text(file_path) - rewritten_html = self._rewrite_plugin_page_html( - html_text, - plugin_name, - page_name, - asset_path, - extra_query_params=extra_query_params, - ) - response = cast( - QuartResponse, - await make_response( - rewritten_html, {"Content-Type": "text/html; charset=utf-8"} - ), - ) - return self._apply_plugin_page_security_headers(response) - - async def _serve_plugin_page_css_asset( - self, - file_path: Path, - plugin_name: str, - page_name: str, - asset_path: str, - extra_query_params: dict[str, str] | None, - ): - css_text = await self._read_plugin_page_text(file_path) - rewritten_css = self._rewrite_plugin_page_css( - css_text, - plugin_name, - page_name, - asset_path, - extra_query_params=extra_query_params, - ) - response = cast( - QuartResponse, - await make_response( - rewritten_css, {"Content-Type": "text/css; charset=utf-8"} - ), - ) - return self._apply_plugin_page_security_headers(response) - - async def _serve_plugin_page_js_asset( - self, - file_path: Path, - plugin_name: str, - page_name: str, - asset_path: str, - extra_query_params: dict[str, str] | None, - ): - js_text = await self._read_plugin_page_text(file_path) - rewritten_js = self._rewrite_plugin_page_js( - js_text, - plugin_name, - page_name, - asset_path, - extra_query_params=extra_query_params, - ) - response = cast( - QuartResponse, - await make_response( - rewritten_js, - {"Content-Type": "application/javascript; charset=utf-8"}, - ), - ) - return self._apply_plugin_page_security_headers(response) - - async def _serve_plugin_page_static_asset(self, file_path: Path): - raw_bytes = await self._read_plugin_page_binary(file_path) - response = cast( - QuartResponse, - await make_response( - raw_bytes, - {"Content-Type": self._guess_plugin_page_mime_type(file_path)}, - ), - ) - return self._apply_plugin_page_security_headers(response) - - async def _serve_plugin_page_content( - self, - plugin_name: str, - page_name: str, - asset_path: str, - ): - plugin = self._get_plugin_metadata_by_name(plugin_name) - if not plugin: - return await self._plugin_page_error_response(404, "Plugin not found") - if not plugin.activated: - return await self._plugin_page_error_response(403, "Plugin is disabled") - - try: - page = await self._get_plugin_page(plugin, page_name) - file_path = await self._resolve_plugin_page_file( - plugin, - page.name, - asset_path, - ) - except (FileNotFoundError, ValueError): - return await self._plugin_page_error_response( - 404, "Plugin Page asset not found" - ) - - extra_query_params = self._prepare_plugin_page_query_params( - plugin_name, - page.name, - ) - served_asset_path = asset_path or page.entry_file - suffix = file_path.suffix.lower() - handlers = { - ".html": self._serve_plugin_page_html_asset, - ".css": self._serve_plugin_page_css_asset, - ".js": self._serve_plugin_page_js_asset, - ".mjs": self._serve_plugin_page_js_asset, - } - handler = handlers.get(suffix) - if handler: - return await handler( - file_path, - plugin_name, - page.name, - served_asset_path, - extra_query_params, - ) - return await self._serve_plugin_page_static_asset(file_path) - - async def _sync_skills_after_plugin_change(self) -> None: - try: - await sync_skills_to_active_sandboxes() - except Exception: - logger.warning("Failed to sync plugin-provided skills to active sandboxes.") - - async def check_plugin_compatibility(self): - try: - data = await request.get_json() - version_spec = data.get("astrbot_version", "") - is_valid, message = self.plugin_manager._validate_astrbot_version_specifier( - version_spec - ) - return ( - Response() - .ok( - { - "compatible": is_valid, - "message": message, - "astrbot_version": version_spec, - } - ) - .__dict__ - ) - except Exception as e: - return Response().error(str(e)).__dict__ - - async def get_plugin_page_entry_config(self): - plugin_name = request.args.get("name") - if not plugin_name: - return Response().error("缺少插件名").__dict__ - page_name = request.args.get("page") - if not page_name: - return Response().error("缺少 Page 名称").__dict__ - - for plugin in self.plugin_manager.context.get_all_stars(): - if plugin.name != plugin_name: - continue - if not plugin.activated: - return Response().error("插件未启用").__dict__ - - page = await self._serialize_plugin_page( - plugin, - page_name, - include_content_path=True, - ) - if not page: - return Response().error("插件 Page 不存在").__dict__ - return Response().ok(page).__dict__ - - return Response().error("插件不存在").__dict__ - - async def reload_failed_plugins(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - try: - data = await request.get_json() - dir_name = data.get("dir_name") # 这里拿的是目录名,不是插件名 - - if not dir_name: - return Response().error("缺少插件目录名").__dict__ - - # 调用 star_manager.py 中的函数 - # 注意:传入的是目录名 - success, err = await self.plugin_manager.reload_failed_plugin(dir_name) - - if success: - await self._sync_skills_after_plugin_change() - return Response().ok(None, f"插件 {dir_name} 重载成功。").__dict__ - else: - return Response().error(f"重载失败: {err}").__dict__ - - except Exception as e: - logger.error(f"/api/plugin/reload-failed: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ - - async def reload_plugins(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - data = await request.get_json() - plugin_name = data.get("name", None) - try: - success, message = await self.plugin_manager.reload(plugin_name) - if not success: - return Response().error(message or "插件重载失败").__dict__ - await self._sync_skills_after_plugin_change() - return Response().ok(None, "重载成功。").__dict__ - except Exception as e: - logger.error(f"/api/plugin/reload: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ - - async def get_online_plugins(self): - custom = request.args.get("custom_registry") - force_refresh = request.args.get("force_refresh", "false").lower() == "true" - - # 构建注册表源信息 - source = self._build_registry_source(custom) - - # 如果不是强制刷新,先检查缓存是否有效 - cached_data = None - if not force_refresh: - # 先检查MD5是否匹配,如果匹配则使用缓存 - if await self._is_cache_valid(source): - cached_data = self._load_plugin_cache(source.cache_file) - if cached_data: - logger.debug("缓存MD5匹配,使用缓存的插件市场数据") - return Response().ok(cached_data).__dict__ - - # 尝试获取远程数据 - remote_data = None - ssl_context = ssl.create_default_context(cafile=certifi.where()) - connector = aiohttp.TCPConnector(ssl=ssl_context) - - for url in source.urls: - try: - async with ( - aiohttp.ClientSession( - trust_env=True, - connector=connector, - ) as session, - session.get(url) as response, - ): - if response.status == 200: - try: - remote_data = await response.json() - except aiohttp.ContentTypeError: - remote_text = await response.text() - remote_data = json.loads(remote_text) - - # 检查远程数据是否为空 - if not remote_data or ( - isinstance(remote_data, dict) and len(remote_data) == 0 - ): - logger.warning(f"远程插件市场数据为空: {url}") - continue # 继续尝试其他URL或使用缓存 - - logger.info( - f"成功获取远程插件市场数据,包含 {len(remote_data)} 个插件" - ) - # 获取最新的MD5并保存到缓存 - current_md5 = await self._fetch_remote_md5(source.md5_url) - self._save_plugin_cache( - source.cache_file, - remote_data, - current_md5, - ) - return Response().ok(remote_data).__dict__ - logger.error(f"请求 {url} 失败,状态码:{response.status}") - except Exception as e: - logger.error(f"请求 {url} 失败,错误:{e}") - - # 如果远程获取失败,尝试使用缓存数据 - if not cached_data: - cached_data = self._load_plugin_cache(source.cache_file) - - if cached_data: - logger.warning("远程插件市场数据获取失败,使用缓存数据") - return Response().ok(cached_data, "使用缓存数据,可能不是最新版本").__dict__ - - return Response().error("获取插件列表失败,且没有可用的缓存数据").__dict__ - - def _build_registry_source(self, custom_url: str | None) -> RegistrySource: - """构建注册表源信息""" - data_dir = get_astrbot_data_path() - if custom_url: - # 对自定义URL生成一个安全的文件名 - url_hash = hashlib.md5(custom_url.encode()).hexdigest()[:8] - cache_file = os.path.join(data_dir, f"plugins_custom_{url_hash}.json") - - # 更安全的后缀处理方式 - if custom_url.endswith(".json"): - md5_url = custom_url[:-5] + "-md5.json" - else: - md5_url = custom_url + "-md5.json" - - urls = [custom_url] - else: - cache_file = os.path.join(data_dir, "plugins.json") - md5_url = "https://api.soulter.top/astrbot/plugins-md5" - urls = [ - "https://api.soulter.top/astrbot/plugins", - "https://github.com/AstrBotDevs/AstrBot_Plugins_Collection/raw/refs/heads/main/plugin_cache_original.json", - ] - return RegistrySource(urls=urls, cache_file=cache_file, md5_url=md5_url) - - def _load_cached_md5(self, cache_file: str) -> str | None: - """从缓存文件中加载MD5""" - if not os.path.exists(cache_file): - return None - - try: - with open(cache_file, encoding="utf-8") as f: - cache_data = json.load(f) - return cache_data.get("md5") - except Exception as e: - logger.warning(f"Failed to load cached MD5: {e}") - return None - - async def _fetch_remote_md5(self, md5_url: str | None) -> str | None: - """获取远程MD5""" - if not md5_url: - return None - - try: - ssl_context = ssl.create_default_context(cafile=certifi.where()) - connector = aiohttp.TCPConnector(ssl=ssl_context) - - async with ( - aiohttp.ClientSession( - trust_env=True, - connector=connector, - ) as session, - session.get(md5_url) as response, - ): - if response.status == 200: - data = await response.json() - return data.get("md5", "") - except Exception as e: - logger.debug(f"Failed to fetch remote MD5: {e}") - return None - - async def _is_cache_valid(self, source: RegistrySource) -> bool: - """检查缓存是否有效(基于MD5)""" - try: - cached_md5 = self._load_cached_md5(source.cache_file) - if not cached_md5: - logger.debug("MD5 not found in cache, treating cache as invalid") - return False - - remote_md5 = await self._fetch_remote_md5(source.md5_url) - if remote_md5 is None: - logger.warning( - "Cannot fetch remote MD5, using cache without validation" - ) - return True # 如果无法获取远程MD5,认为缓存有效 - - is_valid = cached_md5 == remote_md5 - logger.debug( - f"Plugin cache: local={cached_md5}, remote={remote_md5}, effective={is_valid}", - ) - return is_valid - - except Exception as e: - logger.warning(f"检查缓存有效性失败: {e}") - return False - - def _load_plugin_cache(self, cache_file: str): - """加载本地缓存的插件市场数据""" - try: - if os.path.exists(cache_file): - with open(cache_file, encoding="utf-8") as f: - cache_data = json.load(f) - # 检查缓存是否有效 - if "data" in cache_data and "timestamp" in cache_data: - logger.debug( - f"Loading cached file: {cache_file}, Cache time: {cache_data['timestamp']}", - ) - return cache_data["data"] - except Exception as e: - logger.warning(f"Failed to load plugin market cache: {e}") - return None - - def _save_plugin_cache(self, cache_file: str, data, md5: str | None = None) -> None: - """保存插件市场数据到本地缓存""" - try: - # 确保目录存在 - os.makedirs(os.path.dirname(cache_file), exist_ok=True) - - cache_data = { - "timestamp": datetime.now().isoformat(), - "data": data, - "md5": md5 or "", - } - - with open(cache_file, "w", encoding="utf-8") as f: - json.dump(cache_data, f, ensure_ascii=False, indent=2) - logger.debug(f"Cached plugin market data: {cache_file}, MD5: {md5}") - except Exception as e: - logger.warning(f"Failed to save plugin market cache: {e}") - - async def get_plugin_logo_token(self, logo_path: str): - try: - if token := self._logo_cache.get(logo_path): - if not await file_token_service.check_token_expired(token): - return self._logo_cache[logo_path] - token = await file_token_service.register_file(logo_path, timeout=300) - self._logo_cache[logo_path] = token - return token - except Exception as e: - logger.warning(f"获取插件 Logo 失败: {e}") - return None - - def _resolve_plugin_dir(self, plugin) -> Path | None: - if not plugin.root_dir_name: - return None - - base_dir = Path( - self.plugin_manager.reserved_plugin_path - if plugin.reserved - else self.plugin_manager.plugin_store_path - ) - plugin_dir = base_dir / plugin.root_dir_name - if not plugin_dir.is_dir(): - return None - return plugin_dir - - def _get_plugin_installed_at(self, plugin) -> str | None: - plugin_dir = self._resolve_plugin_dir(plugin) - if plugin_dir is None: - return None - - try: - return datetime.fromtimestamp( - plugin_dir.stat().st_mtime, - timezone.utc, - ).isoformat() - except OSError as exc: - logger.warning(f"获取插件安装时间失败 {plugin.name}: {exc!s}") - return None - - async def get_plugins(self): - _plugin_resp = [] - plugin_name = request.args.get("name") - - plugins = [ - p - for p in self.plugin_manager.context.get_all_stars() - if not (plugin_name and p.name != plugin_name) - ] - - async def process_plugin(plugin): - logo_url = None - if plugin.logo_path: - logo_url = await self.get_plugin_logo_token(plugin.logo_path) - pages = await self._discover_plugin_pages(plugin) - return plugin, logo_url, pages - - results = await asyncio.gather(*(process_plugin(p) for p in plugins)) - - for plugin, logo_url, pages in results: - _t = { - "name": plugin.name, - "marketplace_name": (plugin.name or "").replace("_", "-"), - "repo": "" if plugin.repo is None else str(plugin.repo), - "author": plugin.author, - "desc": plugin.desc, - "version": plugin.version, - "reserved": plugin.reserved, - "activated": plugin.activated, - "online_vesion": "", - "display_name": plugin.display_name, - "logo": f"/api/file/{logo_url}" if logo_url else None, - "support_platforms": plugin.support_platforms, - "astrbot_version": plugin.astrbot_version, - "installed_at": self._get_plugin_installed_at(plugin), - "i18n": plugin.i18n, - "pages": [p.name for p in pages], - } - # 检查是否为全空的幽灵插件 - if not any( - [ - plugin.name, - plugin.author, - plugin.desc, - plugin.version, - plugin.display_name, - ] - ): - continue - _plugin_resp.append(_t) - return ( - Response() - .ok(_plugin_resp, message=self.plugin_manager.failed_plugin_info) - .__dict__ - ) - - async def get_plugin_detail(self): - plugin_name = request.args.get("name") - if not plugin_name: - return Response().error("缺少插件名").__dict__ - - for plugin in self.plugin_manager.context.get_all_stars(): - if plugin.name != plugin_name: - continue - - logo_url = None - if plugin.logo_path: - logo_url = await self.get_plugin_logo_token(plugin.logo_path) - - return ( - Response() - .ok( - { - "name": plugin.name, - "marketplace_name": (plugin.name or "").replace("_", "-"), - "repo": "" if plugin.repo is None else str(plugin.repo), - "author": plugin.author, - "desc": plugin.desc, - "version": plugin.version, - "reserved": plugin.reserved, - "activated": plugin.activated, - "online_vesion": "", - "components": await self.get_plugin_components_info(plugin), - "display_name": plugin.display_name, - "logo": f"/api/file/{logo_url}" if logo_url else None, - "support_platforms": plugin.support_platforms, - "astrbot_version": plugin.astrbot_version, - "installed_at": self._get_plugin_installed_at(plugin), - "i18n": plugin.i18n, - } - ) - .__dict__ - ) - - return Response().error("插件不存在").__dict__ - - async def get_failed_plugins(self): - """专门获取加载失败的插件列表(字典格式)""" - return Response().ok(self.plugin_manager.failed_plugin_dict).__dict__ - - async def get_plugin_components_info(self, plugin): - """Build plugin components for the dashboard.""" - page_components = await self.get_plugin_page_components(plugin) - handler_components = await self.get_plugin_handler_components( - plugin.star_handler_full_names, - ) - components = [ - *page_components, - *self.get_plugin_skill_components(plugin), - *handler_components, - ] - return sorted( - components, - key=lambda item: PLUGIN_COMPONENT_TYPE_ORDER.get(item["type"], 99), - ) - - async def get_plugin_page_components(self, plugin) -> list[dict]: - pages = await self._serialize_plugin_pages(plugin) - return [ - { - "type": "page", - "name": page["title"], - "title": page["title"], - "page_name": page["name"], - "i18n_key": page["i18n_key"], - "description": "Plugin Page entry", - "plugin_name": plugin.name, - "plugin_marketplace_name": (plugin.name or "").replace("_", "-"), - } - for page in pages - ] - - async def get_plugin_handler_components(self, handler_full_names: list[str]): - """Build behavior components from registered handlers.""" - components = [] - - for handler_full_name in handler_full_names: - info = {} - handler = star_handlers_registry.star_handlers_map.get( - handler_full_name, - None, - ) - if handler is None: - continue - info["event_type"] = handler.event_type.name - info["event_type_h"] = self.translated_event_type.get( - handler.event_type, - handler.event_type.name, - ) - info["handler_full_name"] = handler.handler_full_name - info["description"] = handler.desc or "无描述" - info["handler_name"] = handler.handler_name - - component_type = "hook" - component = None - if handler.event_type == EventType.AdapterMessageEvent: - # 处理平台适配器消息事件 - has_admin = False - for event_filter in ( - handler.event_filters - ): # 正常handler就只有 1~2 个 filter,因此这里时间复杂度不会太高 - if isinstance(event_filter, CommandFilter): - component_type = "command" - info["display_type"] = "指令" - info["cmd"] = self._get_command_filter_display_name( - event_filter - ) - component = self._build_command_filter_component( - event_filter, - handler.desc, - ) - elif isinstance(event_filter, CommandGroupFilter): - component_type = "command" - info["display_type"] = "指令组" - info["cmd"] = event_filter.get_complete_command_names()[0] - info["cmd"] = info["cmd"].strip() - component = self._build_command_group_component( - event_filter, - handler.desc, - ) - elif isinstance(event_filter, RegexFilter): - component_type = "command" - info["display_type"] = "正则匹配" - info["cmd"] = event_filter.regex_str - component = { - "type": "command", - "name": event_filter.regex_str, - "description": handler.desc or "无描述", - "match": "regex", - } - elif isinstance(event_filter, PermissionTypeFilter): - has_admin = True - info["has_admin"] = has_admin - if "cmd" not in info: - info["cmd"] = "未知" - if "display_type" not in info: - info["display_type"] = "事件监听器" - component_type = "listener" - else: - info["cmd"] = "自动触发" - info["display_type"] = "无" - if handler.event_type == EventType.OnCallingFuncToolEvent: - component_type = "llm_tool" - - if component is None: - component = { - "type": component_type, - "name": handler.handler_name or handler.event_type.name, - "description": handler.desc or "无描述", - } - else: - component["type"] = component_type - - if component_type == "command": - component["event_type"] = info["event_type"] - component["event_type_h"] = info["event_type_h"] - component["handler_name"] = info["handler_name"] - component["has_admin"] = info.get("has_admin", False) - if "display_type" in info: - component["display_type"] = info["display_type"] - if "cmd" in info: - component["command"] = info["cmd"] - else: - component.update(info) - components.append(component) - - return self._merge_command_components(components) - - def get_plugin_skill_components(self, plugin): - """Build skill components provided by this plugin.""" - plugin_names = { - str(name) - for name in (plugin.root_dir_name, plugin.name) - if str(name or "").strip() - } - if not plugin_names: - return [] - - try: - skills = SkillManager().list_skills( - active_only=False, - runtime="local", - show_sandbox_path=False, - ) - except Exception as exc: - logger.warning(f"获取插件 Skills 失败 {plugin.name}: {exc!s}") - return [] - - components = [] - for skill in skills: - if skill.source_type != "plugin" or skill.plugin_name not in plugin_names: - continue - components.append( - { - "type": "skill", - "name": skill.name, - "description": skill.description or "无描述", - "path": skill.path, - } - ) - return components - - def _get_command_filter_display_name(self, command_filter: CommandFilter) -> str: - return command_filter.get_complete_command_names()[0].strip() - - def _get_command_description( - self, - command_filter: CommandFilter | CommandGroupFilter, - fallback: str = "", - ) -> str: - handler_md = getattr(command_filter, "handler_md", None) - desc = getattr(handler_md, "desc", "") if handler_md else "" - return desc or fallback or "无描述" - - def _build_command_filter_component( - self, - command_filter: CommandFilter, - fallback_desc: str = "", - ) -> dict: - parts = self._get_command_filter_display_name(command_filter).split() - if not parts: - parts = [command_filter.command_name] - component = { - "type": "command", - "name": parts[-1], - "description": self._get_command_description( - command_filter, - fallback_desc, - ), - } - return self._wrap_command_component(parts[:-1], component) - - def _build_command_group_component( - self, - command_group_filter: CommandGroupFilter, - fallback_desc: str = "", - ) -> dict: - parts = command_group_filter.get_complete_command_names()[0].strip().split() - if not parts: - parts = [command_group_filter.group_name] - subcommands = [ - self._build_command_group_child(sub_filter) - for sub_filter in command_group_filter.sub_command_filters - ] - component = { - "type": "command", - "name": parts[-1], - "description": self._get_command_description( - command_group_filter, - fallback_desc, - ), - } - if subcommands: - component["subcommands"] = subcommands - return self._wrap_command_component(parts[:-1], component) - - def _build_command_group_child( - self, - command_filter: CommandFilter | CommandGroupFilter, - ) -> dict: - if isinstance(command_filter, CommandGroupFilter): - component = { - "name": command_filter.group_name, - "description": self._get_command_description(command_filter), - } - subcommands = [ - self._build_command_group_child(sub_filter) - for sub_filter in command_filter.sub_command_filters - ] - if subcommands: - component["subcommands"] = subcommands - return component - - return { - "name": command_filter.command_name, - "description": self._get_command_description(command_filter), - } - - def _wrap_command_component(self, parent_names: list[str], component: dict) -> dict: - for parent_name in reversed(parent_names): - component = { - "type": "command", - "name": parent_name, - "description": "无描述", - "subcommands": [component], - } - return component - - def _merge_command_components(self, components: list[dict]) -> list[dict]: - merged: list[dict] = [] - for component in components: - if component.get("type") != "command": - merged.append(component) - continue - existing = next( - ( - item - for item in merged - if item.get("type") == "command" - and item.get("name") == component.get("name") - and item.get("match") == component.get("match") - ), - None, - ) - if existing is None: - merged.append(component) - continue - self._merge_command_component(existing, component) - return merged - - def _merge_command_component(self, target: dict, source: dict) -> None: - if target.get("description") == "无描述" and source.get("description"): - target["description"] = source["description"] - for key, value in source.items(): - if key in {"subcommands", "description"}: - continue - target.setdefault(key, value) - - source_subcommands = source.get("subcommands") - if not isinstance(source_subcommands, list): - return - target_subcommands = target.setdefault("subcommands", []) - for source_subcommand in source_subcommands: - if not isinstance(source_subcommand, dict): - continue - existing = next( - ( - item - for item in target_subcommands - if isinstance(item, dict) - and item.get("name") == source_subcommand.get("name") - ), - None, - ) - if existing is None: - target_subcommands.append(source_subcommand) - continue - self._merge_command_component(existing, source_subcommand) - - async def install_plugin(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - post_data = await request.get_json() - repo_url = post_data["url"] - download_url = str(post_data.get("download_url") or "").strip() - ignore_version_check = bool(post_data.get("ignore_version_check", False)) - - proxy: str = post_data.get("proxy", None) - if proxy: - proxy = proxy.removesuffix("/") - - try: - logger.info(f"正在安装插件 {repo_url}") - plugin_info = await self.plugin_manager.install_plugin( - repo_url, - proxy, - ignore_version_check=ignore_version_check, - download_url=download_url, - ) - # self.core_lifecycle.restart() - await self._sync_skills_after_plugin_change() - logger.info(f"安装插件 {repo_url} 成功。") - return Response().ok(plugin_info, "安装成功。").__dict__ - except PluginVersionIncompatibleError as e: - return { - "status": "warning", - "message": str(e), - "data": { - "warning_type": "astrbot_version_incompatible", - "can_ignore": True, - }, - } - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def install_plugin_upload(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - try: - file = await request.files - file = file["file"] - form_data = await request.form - ignore_version_check = ( - str(form_data.get("ignore_version_check", "false")).lower() == "true" - ) - logger.info(f"正在安装用户上传的插件 {file.filename}") - file_path = os.path.join( - get_astrbot_temp_path(), - f"plugin_upload_{file.filename}", - ) - await file.save(file_path) - plugin_info = await self.plugin_manager.install_plugin_from_file( - file_path, - ignore_version_check=ignore_version_check, - ) - # self.core_lifecycle.restart() - await self._sync_skills_after_plugin_change() - logger.info(f"安装插件 {file.filename} 成功") - return Response().ok(plugin_info, "安装成功。").__dict__ - except PluginVersionIncompatibleError as e: - return { - "status": "warning", - "message": str(e), - "data": { - "warning_type": "astrbot_version_incompatible", - "can_ignore": True, - }, - } - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def uninstall_plugin(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - post_data = await request.get_json() - plugin_name = post_data["name"] - delete_config = post_data.get("delete_config", False) - delete_data = post_data.get("delete_data", False) - try: - logger.info(f"正在卸载插件 {plugin_name}") - await self.plugin_manager.uninstall_plugin( - plugin_name, - delete_config=delete_config, - delete_data=delete_data, - ) - await self._sync_skills_after_plugin_change() - logger.info(f"卸载插件 {plugin_name} 成功") - return Response().ok(None, "卸载成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def uninstall_failed_plugin(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - post_data = await request.get_json() - dir_name = post_data.get("dir_name", "") - delete_config = post_data.get("delete_config", False) - delete_data = post_data.get("delete_data", False) - if not dir_name: - return Response().error("缺少失败插件目录名").__dict__ - - try: - logger.info(f"正在卸载失败插件 {dir_name}") - await self.plugin_manager.uninstall_failed_plugin( - dir_name, - delete_config=delete_config, - delete_data=delete_data, - ) - await self._sync_skills_after_plugin_change() - logger.info(f"卸载失败插件 {dir_name} 成功") - return Response().ok(None, "卸载成功").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def update_plugin(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - post_data = await request.get_json() - plugin_name = post_data["name"] - proxy: str = post_data.get("proxy", None) - download_url = str(post_data.get("download_url") or "").strip() - try: - logger.info(f"正在更新插件 {plugin_name}") - await self.plugin_manager.update_plugin( - plugin_name, proxy, download_url=download_url - ) - # self.core_lifecycle.restart() - await self.plugin_manager.reload(plugin_name) - await self._sync_skills_after_plugin_change() - logger.info(f"更新插件 {plugin_name} 成功。") - return Response().ok(None, "更新成功。").__dict__ - except Exception as e: - logger.error(f"/api/plugin/update: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ - - async def update_all_plugins(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - post_data = await request.get_json() - plugin_names: list[str] = post_data.get("names") or [] - proxy: str = post_data.get("proxy", "") - download_urls: dict[str, str] = post_data.get("download_urls") or {} - - if not isinstance(plugin_names, list) or not plugin_names: - return Response().error("插件列表不能为空").__dict__ - if not isinstance(download_urls, dict): - download_urls = {} - - results = [] - sem = asyncio.Semaphore(PLUGIN_UPDATE_CONCURRENCY) - - async def _update_one(name: str): - async with sem: - try: - logger.info(f"批量更新插件 {name}") - download_url = str(download_urls.get(name) or "").strip() - await self.plugin_manager.update_plugin( - name, proxy, download_url=download_url - ) - return {"name": name, "status": "ok", "message": "更新成功"} - except Exception as e: - logger.error( - f"/api/plugin/update-all: 更新插件 {name} 失败: {traceback.format_exc()}", - ) - return {"name": name, "status": "error", "message": str(e)} - - raw_results = await asyncio.gather( - *(_update_one(name) for name in plugin_names), - return_exceptions=True, - ) - for name, result in zip(plugin_names, raw_results): - if isinstance(result, asyncio.CancelledError): - raise result - if isinstance(result, BaseException): - results.append( - {"name": name, "status": "error", "message": str(result)} - ) - else: - results.append(result) - - failed = [r for r in results if r["status"] == "error"] - if len(failed) < len(results): - await self._sync_skills_after_plugin_change() - message = ( - "批量更新完成,全部成功。" - if not failed - else f"批量更新完成,其中 {len(failed)}/{len(results)} 个插件失败。" - ) - - return Response().ok({"results": results}, message).__dict__ - - async def off_plugin(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - post_data = await request.get_json() - plugin_name = post_data["name"] - try: - await self.plugin_manager.turn_off_plugin(plugin_name) - await self._sync_skills_after_plugin_change() - logger.info(f"停用插件 {plugin_name} 。") - return Response().ok(None, "停用成功。").__dict__ - except Exception as e: - logger.error(f"/api/plugin/off: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ - - async def on_plugin(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - post_data = await request.get_json() - plugin_name = post_data["name"] - try: - await self.plugin_manager.turn_on_plugin(plugin_name) - await self._sync_skills_after_plugin_change() - logger.info(f"启用插件 {plugin_name} 。") - return Response().ok(None, "启用成功。").__dict__ - except Exception as e: - logger.error(f"/api/plugin/on: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ - - async def get_plugin_readme(self): - plugin_name = request.args.get("name") - - if not plugin_name: - logger.warning("插件名称为空") - return Response().error("插件名称不能为空").__dict__ - - plugin_obj = None - for plugin in self.plugin_manager.context.get_all_stars(): - if plugin.name == plugin_name: - plugin_obj = plugin - break - - if not plugin_obj: - logger.warning(f"插件 {plugin_name} 不存在") - return Response().error(f"插件 {plugin_name} 不存在").__dict__ - - if not plugin_obj.root_dir_name: - logger.warning(f"插件 {plugin_name} 目录不存在") - return Response().error(f"插件 {plugin_name} 目录不存在").__dict__ - - if plugin_obj.reserved: - plugin_dir = os.path.join( - self.plugin_manager.reserved_plugin_path, - plugin_obj.root_dir_name, - ) - else: - plugin_dir = os.path.join( - self.plugin_manager.plugin_store_path, - plugin_obj.root_dir_name, - ) - - if not os.path.isdir(plugin_dir): - logger.warning(f"无法找到插件目录: {plugin_dir}") - return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ - - readme_path = os.path.join(plugin_dir, "README.md") - - if not os.path.isfile(readme_path): - logger.warning(f"插件 {plugin_name} 没有README文件") - return Response().error(f"插件 {plugin_name} 没有README文件").__dict__ - - try: - with open(readme_path, encoding="utf-8") as f: - readme_content = f.read() - - return ( - Response() - .ok({"content": readme_content}, "成功获取README内容") - .__dict__ - ) - except Exception as e: - logger.error(f"/api/plugin/readme: {traceback.format_exc()}") - return Response().error(f"读取README文件失败: {e!s}").__dict__ - - async def get_plugin_changelog(self): - """获取插件更新日志 - - 读取插件目录下的 CHANGELOG.md 文件内容。 - """ - plugin_name = request.args.get("name") - logger.debug(f"正在获取插件 {plugin_name} 的更新日志") - - if not plugin_name: - logger.warning("插件名称为空") - return Response().error("插件名称不能为空").__dict__ - - # 查找插件 - plugin_obj = None - for plugin in self.plugin_manager.context.get_all_stars(): - if plugin.name == plugin_name: - plugin_obj = plugin - break - - if not plugin_obj: - logger.warning(f"插件 {plugin_name} 不存在") - return Response().error(f"插件 {plugin_name} 不存在").__dict__ - - if not plugin_obj.root_dir_name: - logger.warning(f"插件 {plugin_name} 目录不存在") - return Response().error(f"插件 {plugin_name} 目录不存在").__dict__ - - if plugin_obj.reserved: - plugin_dir = os.path.join( - self.plugin_manager.reserved_plugin_path, - plugin_obj.root_dir_name, - ) - else: - plugin_dir = os.path.join( - self.plugin_manager.plugin_store_path, - plugin_obj.root_dir_name, - ) - - if not os.path.isdir(plugin_dir): - logger.warning(f"无法找到插件目录: {plugin_dir}") - return Response().error(f"无法找到插件 {plugin_name} 的目录").__dict__ - - # 尝试多种可能的文件名 - changelog_names = ["CHANGELOG.md", "changelog.md", "CHANGELOG", "changelog"] - for name in changelog_names: - changelog_path = os.path.join(plugin_dir, name) - if os.path.isfile(changelog_path): - try: - with open(changelog_path, encoding="utf-8") as f: - changelog_content = f.read() - return ( - Response() - .ok({"content": changelog_content}, "成功获取更新日志") - .__dict__ - ) - except Exception as e: - logger.error(f"/api/plugin/changelog: {traceback.format_exc()}") - return Response().error(f"读取更新日志失败: {e!s}").__dict__ - - # 没有找到 changelog 文件,返回 ok 但 content 为 null - logger.warning(f"插件 {plugin_name} 没有更新日志文件") - return Response().ok({"content": None}, "该插件没有更新日志文件").__dict__ - - async def get_custom_source(self): - """获取自定义插件源""" - sources = await sp.global_get("custom_plugin_sources", []) - return Response().ok(sources).__dict__ - - async def save_custom_source(self): - """保存自定义插件源""" - try: - data = await request.get_json() - sources = data.get("sources", []) - if not isinstance(sources, list): - return Response().error("sources fields must be a list").__dict__ - - await sp.global_put("custom_plugin_sources", sources) - return Response().ok(None, "保存成功").__dict__ - except Exception as e: - logger.error(f"/api/plugin/source/save: {traceback.format_exc()}") - return Response().error(str(e)).__dict__ diff --git a/astrbot/dashboard/routes/route.py b/astrbot/dashboard/routes/route.py deleted file mode 100644 index 53c6234439..0000000000 --- a/astrbot/dashboard/routes/route.py +++ /dev/null @@ -1,59 +0,0 @@ -from dataclasses import dataclass - -from quart import Quart - -from astrbot.core.config.astrbot_config import AstrBotConfig - - -@dataclass -class RouteContext: - config: AstrBotConfig - app: Quart - - -class Route: - routes: list | dict - - def __init__(self, context: RouteContext) -> None: - self.app = context.app - self.config = context.config - - def register_routes(self) -> None: - def _add_rule(path, method, func) -> None: - # 统一添加 /api 前缀 - full_path = f"/api{path}" - self.app.add_url_rule(full_path, view_func=func, methods=[method]) - - # 兼容字典和列表两种格式 - routes_to_register = ( - self.routes.items() if isinstance(self.routes, dict) else self.routes - ) - - for route, definition in routes_to_register: - # 兼容一个路由多个方法 - if isinstance(definition, list): - for method, func in definition: - _add_rule(route, method, func) - else: - method, func = definition - _add_rule(route, method, func) - - -@dataclass -class Response: - status: str | None = None - message: str | None = None - data: dict | list | None = None - - def error(self, message: str): - self.status = "error" - self.message = message - return self - - def ok(self, data: dict | list | None = None, message: str | None = None): - self.status = "ok" - if data is None: - data = {} - self.data = data - self.message = message - return self diff --git a/astrbot/dashboard/routes/session_management.py b/astrbot/dashboard/routes/session_management.py deleted file mode 100644 index 688515f4e3..0000000000 --- a/astrbot/dashboard/routes/session_management.py +++ /dev/null @@ -1,965 +0,0 @@ -from quart import request -from sqlalchemy.ext.asyncio import AsyncSession -from sqlmodel import col, select - -from astrbot.core import logger, sp -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.db import BaseDatabase -from astrbot.core.db.po import ConversationV2, Preference -from astrbot.core.provider.entities import ProviderType -from astrbot.core.umo_alias import build_umo_alias_map, parse_umo, serialize_umo_alias - -from .route import Response, Route, RouteContext - -AVAILABLE_SESSION_RULE_KEYS = [ - "session_service_config", - "session_plugin_config", - "kb_config", - f"provider_perf_{ProviderType.CHAT_COMPLETION.value}", - f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}", - f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}", -] - - -class SessionManagementRoute(Route): - def __init__( - self, - context: RouteContext, - db_helper: BaseDatabase, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.db_helper = db_helper - self.routes = { - "/session/list-rule": ("GET", self.list_session_rule), - "/session/update-rule": ("POST", self.update_session_rule), - "/session/delete-rule": ("POST", self.delete_session_rule), - "/session/batch-delete-rule": ("POST", self.batch_delete_session_rule), - "/session/active-umos": ("GET", self.list_umos), - "/session/list-all-with-status": ("GET", self.list_all_umos_with_status), - "/session/batch-update-service": ("POST", self.batch_update_service), - "/session/batch-update-provider": ("POST", self.batch_update_provider), - # 分组管理 API - "/session/groups": ("GET", self.list_groups), - "/session/group/create": ("POST", self.create_group), - "/session/group/update": ("POST", self.update_group), - "/session/group/delete": ("POST", self.delete_group), - } - self.conv_mgr = core_lifecycle.conversation_manager - self.core_lifecycle = core_lifecycle - self.register_routes() - - @staticmethod - def _is_group_umo(umo: str) -> bool: - umo_lower = umo.lower() - return ":group:" in umo_lower or ":groupmessage:" in umo_lower - - @staticmethod - def _is_private_umo(umo: str) -> bool: - umo_lower = umo.lower() - return ( - ":private:" in umo_lower - or ":friend:" in umo_lower - or ":friendmessage:" in umo_lower - ) - - async def _list_known_umos(self) -> list[str]: - async with self.db_helper.get_db() as session: - session: AsyncSession - result = await session.execute(select(ConversationV2.user_id).distinct()) - umos = {str(row[0]) for row in result.fetchall() if row[0]} - - aliases = await self.db_helper.get_umo_aliases() - umos.update(str(alias.umo) for alias in aliases if alias.umo) - return sorted(umos) - - async def _get_umo_alias_map(self, umos: list[str]) -> dict: - return build_umo_alias_map(await self.db_helper.get_umo_aliases(umos)) - - def _build_umo_info(self, umo: str | None, alias_map: dict) -> dict: - umo_str = umo or "" - return { - "umo": umo_str, - **parse_umo(umo_str), - **serialize_umo_alias(alias_map.get(umo_str), umo_str), - } - - async def _get_umos_by_scope( - self, - scope: str, - group_id: str = "", - ) -> list[str]: - if scope == "custom_group": - if not group_id: - raise ValueError("请指定分组 ID") - groups = self._get_groups() - if group_id not in groups: - raise ValueError(f"分组 '{group_id}' 不存在") - return groups[group_id].get("umos", []) - - all_umos = await self._list_known_umos() - if scope == "group": - return [umo for umo in all_umos if self._is_group_umo(umo)] - if scope == "private": - return [umo for umo in all_umos if self._is_private_umo(umo)] - if scope == "all": - return all_umos - return [] - - async def _get_umo_rules( - self, page: int = 1, page_size: int = 10, search: str = "" - ) -> tuple[dict, int]: - """获取所有带有自定义规则的 umo 及其规则内容(支持分页和搜索)。 - - 如果某个 umo 在 preference 中有以下字段,则表示有自定义规则: - - 1. session_service_config (包含了 是否启用这个umo, 这个umo是否启用 llm, 这个umo是否启用tts, umo自定义名称。) - 2. session_plugin_config (包含了 这个 umo 的 plugin set) - 3. provider_perf_{ProviderType.value} (包含了这个 umo 所选择使用的 provider 信息) - 4. kb_config (包含了这个 umo 的知识库相关配置) - - Args: - page: 页码,从 1 开始 - page_size: 每页数量 - search: 搜索关键词,匹配 umo 或 custom_name - - Returns: - tuple[dict, int]: (umo_rules, total) - 分页后的 umo 规则和总数 - """ - umo_rules = {} - async with self.db_helper.get_db() as session: - session: AsyncSession - result = await session.execute( - select(Preference).where( - col(Preference.scope) == "umo", - col(Preference.key).in_(AVAILABLE_SESSION_RULE_KEYS), - ) - ) - prefs = result.scalars().all() - for pref in prefs: - umo_id = pref.scope_id - if umo_id not in umo_rules: - umo_rules[umo_id] = {} - if pref.key == "session_plugin_config" and umo_id in pref.value["val"]: - umo_rules[umo_id][pref.key] = pref.value["val"][umo_id] - else: - umo_rules[umo_id][pref.key] = pref.value["val"] - - alias_map = await self._get_umo_alias_map(list(umo_rules.keys())) - - # 搜索过滤 - if search: - search_lower = search.lower() - filtered_rules = {} - for umo_id, rules in umo_rules.items(): - # 匹配 umo - if search_lower in umo_id.lower(): - filtered_rules[umo_id] = rules - continue - # 匹配 custom_name - svc_config = rules.get("session_service_config", {}) - custom_name = svc_config.get("custom_name", "") if svc_config else "" - if custom_name and search_lower in custom_name.lower(): - filtered_rules[umo_id] = rules - continue - - alias_info = serialize_umo_alias(alias_map.get(umo_id), umo_id) - if any( - search_lower in alias_info[key].lower() - for key in ("auto_name", "user_alias", "display_name") - if alias_info.get(key) - ): - filtered_rules[umo_id] = rules - umo_rules = filtered_rules - - # 获取总数 - total = len(umo_rules) - - # 分页处理 - all_umo_ids = list(umo_rules.keys()) - start_idx = (page - 1) * page_size - end_idx = start_idx + page_size - paginated_umo_ids = all_umo_ids[start_idx:end_idx] - - # 只返回分页后的数据 - paginated_rules = {umo_id: umo_rules[umo_id] for umo_id in paginated_umo_ids} - - return paginated_rules, total - - async def list_session_rule(self): - """获取所有自定义的规则(支持分页和搜索) - - 返回已配置规则的 umo 列表及其规则内容,以及可用的 personas 和 providers - - Query 参数: - page: 页码,默认为 1 - page_size: 每页数量,默认为 10 - search: 搜索关键词,匹配 umo 或 custom_name - """ - try: - # 获取分页和搜索参数 - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 10, type=int) - search = request.args.get("search", "", type=str).strip() - - # 参数校验 - if page < 1: - page = 1 - if page_size < 1: - page_size = 10 - if page_size > 100: - page_size = 100 - - umo_rules, total = await self._get_umo_rules( - page=page, page_size=page_size, search=search - ) - - # 构建规则列表 - rules_list = [] - alias_map = await self._get_umo_alias_map(list(umo_rules.keys())) - for umo, rules in umo_rules.items(): - rule_info = { - "rules": rules, - **self._build_umo_info(umo, alias_map), - } - rules_list.append(rule_info) - - # 获取可用的 providers 和 personas - provider_manager = self.core_lifecycle.provider_manager - persona_mgr = self.core_lifecycle.persona_mgr - - available_personas = [ - {"name": p["name"], "prompt": p.get("prompt", "")} - for p in persona_mgr.personas_v3 - ] - - available_chat_providers = [ - { - "id": p.meta().id, - "name": p.meta().id, - "model": p.meta().model, - } - for p in provider_manager.provider_insts - ] - - available_stt_providers = [ - { - "id": p.meta().id, - "name": p.meta().id, - "model": p.meta().model, - } - for p in provider_manager.stt_provider_insts - ] - - available_tts_providers = [ - { - "id": p.meta().id, - "name": p.meta().id, - "model": p.meta().model, - } - for p in provider_manager.tts_provider_insts - ] - - # 获取可用的插件列表(排除 reserved 的系统插件) - plugin_manager = self.core_lifecycle.plugin_manager - available_plugins = [ - { - "name": p.name, - "display_name": p.display_name or p.name, - "desc": p.desc, - } - for p in plugin_manager.context.get_all_stars() - if not p.reserved and p.name - ] - - # 获取可用的知识库列表 - available_kbs = [] - kb_manager = self.core_lifecycle.kb_manager - if kb_manager: - try: - kbs = await kb_manager.list_kbs() - available_kbs = [ - { - "kb_id": kb.kb_id, - "kb_name": kb.kb_name, - "emoji": kb.emoji, - } - for kb in kbs - ] - except Exception as e: - logger.warning(f"获取知识库列表失败: {e!s}") - - return ( - Response() - .ok( - { - "rules": rules_list, - "total": total, - "page": page, - "page_size": page_size, - "available_personas": available_personas, - "available_chat_providers": available_chat_providers, - "available_stt_providers": available_stt_providers, - "available_tts_providers": available_tts_providers, - "available_plugins": available_plugins, - "available_kbs": available_kbs, - "available_rule_keys": AVAILABLE_SESSION_RULE_KEYS, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取规则列表失败: {e!s}") - return Response().error(f"获取规则列表失败: {e!s}").__dict__ - - async def update_session_rule(self): - """更新某个 umo 的自定义规则 - - 请求体: - { - "umo": "平台:消息类型:会话ID", - "rule_key": "session_service_config" | "session_plugin_config" | "kb_config" | "provider_perf_xxx", - "rule_value": {...} // 规则值,具体结构根据 rule_key 不同而不同 - } - """ - try: - data = await request.get_json() - umo = data.get("umo") - rule_key = data.get("rule_key") - rule_value = data.get("rule_value") - - if not umo: - return Response().error("缺少必要参数: umo").__dict__ - if not rule_key: - return Response().error("缺少必要参数: rule_key").__dict__ - if rule_key not in AVAILABLE_SESSION_RULE_KEYS: - return Response().error(f"不支持的规则键: {rule_key}").__dict__ - - if rule_key == "session_plugin_config": - rule_value = { - umo: rule_value, - } - - # 使用 shared preferences 更新规则 - await sp.session_put(umo, rule_key, rule_value) - - return ( - Response() - .ok({"message": f"规则 {rule_key} 已更新", "umo": umo}) - .__dict__ - ) - except Exception as e: - logger.error(f"更新会话规则失败: {e!s}") - return Response().error(f"更新会话规则失败: {e!s}").__dict__ - - async def delete_session_rule(self): - """删除某个 umo 的自定义规则 - - 请求体: - { - "umo": "平台:消息类型:会话ID", - "rule_key": "session_service_config" | "session_plugin_config" | ... (可选,不传则删除所有规则) - } - """ - try: - data = await request.get_json() - umo = data.get("umo") - rule_key = data.get("rule_key") - - if not umo: - return Response().error("缺少必要参数: umo").__dict__ - - if rule_key: - # 删除单个规则 - if rule_key not in AVAILABLE_SESSION_RULE_KEYS: - return Response().error(f"不支持的规则键: {rule_key}").__dict__ - await sp.session_remove(umo, rule_key) - return ( - Response() - .ok({"message": f"规则 {rule_key} 已删除", "umo": umo}) - .__dict__ - ) - else: - # 删除该 umo 的所有规则 - await sp.clear_async("umo", umo) - return Response().ok({"message": "所有规则已删除", "umo": umo}).__dict__ - except Exception as e: - logger.error(f"删除会话规则失败: {e!s}") - return Response().error(f"删除会话规则失败: {e!s}").__dict__ - - async def batch_delete_session_rule(self): - """批量删除多个 umo 的自定义规则 - - 请求体: - { - "umos": ["平台:消息类型:会话ID", ...], // 可选 - "scope": "all" | "group" | "private" | "custom_group", // 可选,批量范围 - "group_id": "分组ID", // 当 scope 为 custom_group 时必填 - "rule_key": "session_service_config" | ... (可选,不传则删除所有规则) - } - """ - - try: - data = await request.get_json() - umos = data.get("umos", []) - scope = data.get("scope", "") - group_id = data.get("group_id", "") - rule_key = data.get("rule_key") - - # 如果指定了 scope,获取符合条件的所有 umo - if scope and not umos: - try: - umos = await self._get_umos_by_scope(scope, group_id) - except ValueError as e: - return Response().error(str(e)).__dict__ - - if not umos: - return Response().error("缺少必要参数: umos 或有效的 scope").__dict__ - - if not isinstance(umos, list): - return Response().error("参数 umos 必须是数组").__dict__ - - if rule_key and rule_key not in AVAILABLE_SESSION_RULE_KEYS: - return Response().error(f"不支持的规则键: {rule_key}").__dict__ - - # 批量删除 - success_count = 0 - failed_umos = [] - for umo in umos: - try: - if rule_key: - await sp.session_remove(umo, rule_key) - else: - await sp.clear_async("umo", umo) - success_count += 1 - except Exception as e: - logger.error(f"删除 umo {umo} 的规则失败: {e!s}") - failed_umos.append(umo) - - message = f"已删除 {success_count} 条规则" - if rule_key: - message = f"已删除 {success_count} 条 {rule_key} 规则" - - if failed_umos: - return ( - Response() - .ok( - { - "message": f"{message},{len(failed_umos)} 条删除失败", - "success_count": success_count, - "failed_umos": failed_umos, - } - ) - .__dict__ - ) - else: - return ( - Response() - .ok( - { - "message": message, - "success_count": success_count, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"批量删除会话规则失败: {e!s}") - return Response().error(f"批量删除会话规则失败: {e!s}").__dict__ - - async def list_umos(self): - """List known UMOs from conversations and alias records. - - Returns both the legacy string list and structured display metadata. - """ - try: - umos = await self._list_known_umos() - alias_map = await self._get_umo_alias_map(umos) - umo_infos = [self._build_umo_info(umo, alias_map) for umo in umos] - - return Response().ok({"umos": umos, "umo_infos": umo_infos}).__dict__ - except Exception as e: - logger.error(f"获取 UMO 列表失败: {e!s}") - return Response().error(f"获取 UMO 列表失败: {e!s}").__dict__ - - async def list_all_umos_with_status(self): - """获取所有有对话记录的 UMO 及其服务状态(支持分页、搜索、筛选) - - Query 参数: - page: 页码,默认为 1 - page_size: 每页数量,默认为 20 - search: 搜索关键词 - message_type: 筛选消息类型 (group/private/all) - platform: 筛选平台 - """ - try: - page = request.args.get("page", 1, type=int) - page_size = request.args.get("page_size", 20, type=int) - search = request.args.get("search", "", type=str).strip() - message_type = request.args.get("message_type", "all", type=str) - platform = request.args.get("platform", "", type=str) - - if page < 1: - page = 1 - if page_size < 1: - page_size = 20 - if page_size > 100: - page_size = 100 - - all_umos = await self._list_known_umos() - alias_map = await self._get_umo_alias_map(all_umos) - - # 获取所有 umo 的规则配置 - umo_rules, _ = await self._get_umo_rules(page=1, page_size=99999, search="") - - # 构建带状态的 umo 列表 - umos_with_status = [] - for umo in all_umos: - umo_info = self._build_umo_info(umo, alias_map) - umo_platform = umo_info["platform"] - umo_message_type = umo_info["message_type"] - - # 筛选消息类型 - if message_type != "all": - if message_type == "group" and umo_message_type not in [ - "group", - "GroupMessage", - ]: - continue - if message_type == "private" and umo_message_type not in [ - "private", - "FriendMessage", - "friend", - ]: - continue - - # 筛选平台 - if platform and umo_platform != platform: - continue - - # 获取服务配置 - rules = umo_rules.get(umo, {}) - svc_config = rules.get("session_service_config", {}) - - custom_name = svc_config.get("custom_name", "") if svc_config else "" - session_enabled = ( - svc_config.get("session_enabled", True) if svc_config else True - ) - llm_enabled = ( - svc_config.get("llm_enabled", True) if svc_config else True - ) - tts_enabled = ( - svc_config.get("tts_enabled", True) if svc_config else True - ) - - # 搜索过滤 - if search: - search_lower = search.lower() - search_targets = [ - umo, - custom_name, - umo_info["auto_name"], - umo_info["user_alias"], - umo_info["display_name"], - ] - if not any( - search_lower in target.lower() - for target in search_targets - if target - ): - continue - - # 获取 provider 配置 - chat_provider_key = ( - f"provider_perf_{ProviderType.CHAT_COMPLETION.value}" - ) - tts_provider_key = f"provider_perf_{ProviderType.TEXT_TO_SPEECH.value}" - stt_provider_key = f"provider_perf_{ProviderType.SPEECH_TO_TEXT.value}" - - umos_with_status.append( - { - **umo_info, - "custom_name": custom_name, - "session_enabled": session_enabled, - "llm_enabled": llm_enabled, - "tts_enabled": tts_enabled, - "has_rules": umo in umo_rules, - "chat_provider": rules.get(chat_provider_key), - "tts_provider": rules.get(tts_provider_key), - "stt_provider": rules.get(stt_provider_key), - } - ) - - # 分页 - total = len(umos_with_status) - start_idx = (page - 1) * page_size - end_idx = start_idx + page_size - paginated = umos_with_status[start_idx:end_idx] - - # 获取可用的平台列表 - platforms = list({u["platform"] for u in umos_with_status}) - - # 获取可用的 providers - provider_manager = self.core_lifecycle.provider_manager - available_chat_providers = [ - {"id": p.meta().id, "name": p.meta().id, "model": p.meta().model} - for p in provider_manager.provider_insts - ] - available_tts_providers = [ - {"id": p.meta().id, "name": p.meta().id, "model": p.meta().model} - for p in provider_manager.tts_provider_insts - ] - available_stt_providers = [ - {"id": p.meta().id, "name": p.meta().id, "model": p.meta().model} - for p in provider_manager.stt_provider_insts - ] - - return ( - Response() - .ok( - { - "sessions": paginated, - "total": total, - "page": page, - "page_size": page_size, - "platforms": platforms, - "available_chat_providers": available_chat_providers, - "available_tts_providers": available_tts_providers, - "available_stt_providers": available_stt_providers, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"获取会话状态列表失败: {e!s}") - return Response().error(f"获取会话状态列表失败: {e!s}").__dict__ - - async def batch_update_service(self): - """批量更新多个 UMO 的服务状态 (LLM/TTS/Session) - - 请求体: - { - "umos": ["平台:消息类型:会话ID", ...], // 可选,如果不传则根据 scope 筛选 - "scope": "all" | "group" | "private" | "custom_group", // 可选,批量范围 - "group_id": "分组ID", // 当 scope 为 custom_group 时必填 - "llm_enabled": true/false/null, // 可选,null表示不修改 - "tts_enabled": true/false/null, // 可选 - "session_enabled": true/false/null // 可选 - } - """ - try: - data = await request.get_json() - umos = data.get("umos", []) - scope = data.get("scope", "") - group_id = data.get("group_id", "") - llm_enabled = data.get("llm_enabled") - tts_enabled = data.get("tts_enabled") - session_enabled = data.get("session_enabled") - - # 如果没有任何修改 - if llm_enabled is None and tts_enabled is None and session_enabled is None: - return Response().error("至少需要指定一个要修改的状态").__dict__ - - # 如果指定了 scope,获取符合条件的所有 umo - if scope and not umos: - try: - umos = await self._get_umos_by_scope(scope, group_id) - except ValueError as e: - return Response().error(str(e)).__dict__ - - if not umos: - return Response().error("没有找到符合条件的会话").__dict__ - - # 批量更新 - success_count = 0 - failed_umos = [] - - for umo in umos: - try: - # 获取现有配置 - session_config = ( - sp.get("session_service_config", {}, scope="umo", scope_id=umo) - or {} - ) - - # 更新状态 - if llm_enabled is not None: - session_config["llm_enabled"] = llm_enabled - if tts_enabled is not None: - session_config["tts_enabled"] = tts_enabled - if session_enabled is not None: - session_config["session_enabled"] = session_enabled - - # 保存 - sp.put( - "session_service_config", - session_config, - scope="umo", - scope_id=umo, - ) - success_count += 1 - except Exception as e: - logger.error(f"更新 {umo} 服务状态失败: {e!s}") - failed_umos.append(umo) - - status_changes = [] - if llm_enabled is not None: - status_changes.append(f"LLM={'启用' if llm_enabled else '禁用'}") - if tts_enabled is not None: - status_changes.append(f"TTS={'启用' if tts_enabled else '禁用'}") - if session_enabled is not None: - status_changes.append(f"会话={'启用' if session_enabled else '禁用'}") - - return ( - Response() - .ok( - { - "message": f"已更新 {success_count} 个会话 ({', '.join(status_changes)})", - "success_count": success_count, - "failed_count": len(failed_umos), - "failed_umos": failed_umos, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"批量更新服务状态失败: {e!s}") - return Response().error(f"批量更新服务状态失败: {e!s}").__dict__ - - async def batch_update_provider(self): - """批量更新多个 UMO 的 Provider 配置 - - 请求体: - { - "umos": ["平台:消息类型:会话ID", ...], // 可选 - "scope": "all" | "group" | "private", // 可选 - "provider_type": "chat_completion" | "text_to_speech" | "speech_to_text", - "provider_id": "provider_id" - } - """ - try: - data = await request.get_json() - umos = data.get("umos", []) - scope = data.get("scope", "") - provider_type = data.get("provider_type") - provider_id = data.get("provider_id") - - if not provider_type or not provider_id: - return ( - Response() - .error("缺少必要参数: provider_type, provider_id") - .__dict__ - ) - - # 转换 provider_type - provider_type_map = { - "chat_completion": ProviderType.CHAT_COMPLETION, - "text_to_speech": ProviderType.TEXT_TO_SPEECH, - "speech_to_text": ProviderType.SPEECH_TO_TEXT, - } - if provider_type not in provider_type_map: - return ( - Response() - .error(f"不支持的 provider_type: {provider_type}") - .__dict__ - ) - - provider_type_enum = provider_type_map[provider_type] - - # 如果指定了 scope,获取符合条件的所有 umo - group_id = data.get("group_id", "") - if scope and not umos: - try: - umos = await self._get_umos_by_scope(scope, group_id) - except ValueError as e: - return Response().error(str(e)).__dict__ - - if not umos: - return Response().error("没有找到符合条件的会话").__dict__ - - # 批量更新 - success_count = 0 - failed_umos = [] - provider_manager = self.core_lifecycle.provider_manager - - for umo in umos: - try: - await provider_manager.set_provider( - provider_id=provider_id, - provider_type=provider_type_enum, - umo=umo, - ) - success_count += 1 - except Exception as e: - logger.error(f"更新 {umo} Provider 失败: {e!s}") - failed_umos.append(umo) - - return ( - Response() - .ok( - { - "message": f"已更新 {success_count} 个会话的 {provider_type} 为 {provider_id}", - "success_count": success_count, - "failed_count": len(failed_umos), - "failed_umos": failed_umos, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"批量更新 Provider 失败: {e!s}") - return Response().error(f"批量更新 Provider 失败: {e!s}").__dict__ - - # ==================== 分组管理 API ==================== - - def _get_groups(self) -> dict: - """获取所有分组""" - return sp.get("session_groups", {}) - - def _save_groups(self, groups: dict) -> None: - """保存分组""" - sp.put("session_groups", groups) - - async def list_groups(self): - """获取所有分组列表""" - try: - groups = self._get_groups() - # 转换为列表格式,方便前端使用 - groups_list = [] - for group_id, group_data in groups.items(): - groups_list.append( - { - "id": group_id, - "name": group_data.get("name", ""), - "umos": group_data.get("umos", []), - "umo_count": len(group_data.get("umos", [])), - } - ) - return Response().ok({"groups": groups_list}).__dict__ - except Exception as e: - logger.error(f"获取分组列表失败: {e!s}") - return Response().error(f"获取分组列表失败: {e!s}").__dict__ - - async def create_group(self): - """创建新分组""" - try: - data = await request.json - name = data.get("name", "").strip() - umos = data.get("umos", []) - - if not name: - return Response().error("分组名称不能为空").__dict__ - - groups = self._get_groups() - - # 生成唯一 ID - import uuid - - group_id = str(uuid.uuid4())[:8] - - groups[group_id] = { - "name": name, - "umos": umos, - } - - self._save_groups(groups) - - return ( - Response() - .ok( - { - "message": f"分组 '{name}' 创建成功", - "group": { - "id": group_id, - "name": name, - "umos": umos, - "umo_count": len(umos), - }, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"创建分组失败: {e!s}") - return Response().error(f"创建分组失败: {e!s}").__dict__ - - async def update_group(self): - """更新分组(改名、增删成员)""" - try: - data = await request.json - group_id = data.get("id") - name = data.get("name") - umos = data.get("umos") - add_umos = data.get("add_umos", []) - remove_umos = data.get("remove_umos", []) - - if not group_id: - return Response().error("分组 ID 不能为空").__dict__ - - groups = self._get_groups() - - if group_id not in groups: - return Response().error(f"分组 '{group_id}' 不存在").__dict__ - - group = groups[group_id] - - # 更新名称 - if name is not None: - group["name"] = name.strip() - - # 直接设置 umos 列表 - if umos is not None: - group["umos"] = umos - else: - # 增量更新 - current_umos = set(group.get("umos", [])) - if add_umos: - current_umos.update(add_umos) - if remove_umos: - current_umos.difference_update(remove_umos) - group["umos"] = list(current_umos) - - self._save_groups(groups) - - return ( - Response() - .ok( - { - "message": f"分组 '{group['name']}' 更新成功", - "group": { - "id": group_id, - "name": group["name"], - "umos": group["umos"], - "umo_count": len(group["umos"]), - }, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(f"更新分组失败: {e!s}") - return Response().error(f"更新分组失败: {e!s}").__dict__ - - async def delete_group(self): - """删除分组""" - try: - data = await request.json - group_id = data.get("id") - - if not group_id: - return Response().error("分组 ID 不能为空").__dict__ - - groups = self._get_groups() - - if group_id not in groups: - return Response().error(f"分组 '{group_id}' 不存在").__dict__ - - group_name = groups[group_id].get("name", group_id) - del groups[group_id] - - self._save_groups(groups) - - return Response().ok({"message": f"分组 '{group_name}' 已删除"}).__dict__ - except Exception as e: - logger.error(f"删除分组失败: {e!s}") - return Response().error(f"删除分组失败: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/skills.py b/astrbot/dashboard/routes/skills.py deleted file mode 100644 index c86598212e..0000000000 --- a/astrbot/dashboard/routes/skills.py +++ /dev/null @@ -1,960 +0,0 @@ -import os -import re -import shutil -import traceback -from collections.abc import Awaitable, Callable -from pathlib import Path -from typing import Any - -from quart import request, send_file - -from astrbot.core import DEMO_MODE, logger -from astrbot.core.computer.computer_client import ( - _discover_bay_credentials, - sync_skills_to_active_sandboxes, -) -from astrbot.core.skills.neo_skill_sync import NeoSkillSyncManager -from astrbot.core.skills.skill_manager import SkillManager -from astrbot.core.utils.astrbot_path import get_astrbot_temp_path - -from .route import Response, Route, RouteContext - - -def _to_jsonable(value: Any) -> Any: - if isinstance(value, dict): - return {k: _to_jsonable(v) for k, v in value.items()} - if isinstance(value, list): - return [_to_jsonable(v) for v in value] - if hasattr(value, "model_dump"): - return _to_jsonable(value.model_dump()) - return value - - -def _to_bool(value: Any, default: bool = False) -> bool: - if value is None: - return default - if isinstance(value, bool): - return value - if isinstance(value, str): - return value.strip().lower() in {"1", "true", "yes", "y", "on"} - return bool(value) - - -_SKILL_NAME_RE = re.compile(r"^[A-Za-z0-9._-]+$") -_SKILL_FILE_MAX_BYTES = 512 * 1024 -_EDITABLE_SKILL_FILE_SUFFIXES = { - ".css", - ".html", - ".ini", - ".js", - ".json", - ".md", - ".py", - ".sh", - ".toml", - ".ts", - ".txt", - ".yaml", - ".yml", -} -_EDITABLE_SKILL_FILENAMES = {"Dockerfile", "Makefile"} - - -def _next_available_temp_path(temp_dir: str, filename: str) -> str: - stem = Path(filename).stem - suffix = Path(filename).suffix - candidate = filename - index = 1 - while os.path.exists(os.path.join(temp_dir, candidate)): - candidate = f"{stem}_{index}{suffix}" - index += 1 - return os.path.join(temp_dir, candidate) - - -class SkillsRoute(Route): - def __init__(self, context: RouteContext, core_lifecycle) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - self.routes = { - "/skills": ("GET", self.get_skills), - "/skills/upload": ("POST", self.upload_skill), - "/skills/batch-upload": ("POST", self.batch_upload_skills), - "/skills/download": ("GET", self.download_skill), - "/skills/files": ("GET", self.list_skill_files), - "/skills/file": [ - ("GET", self.get_skill_file), - ("POST", self.update_skill_file), - ], - "/skills/update": ("POST", self.update_skill), - "/skills/delete": ("POST", self.delete_skill), - "/skills/neo/candidates": ("GET", self.get_neo_candidates), - "/skills/neo/releases": ("GET", self.get_neo_releases), - "/skills/neo/payload": ("GET", self.get_neo_payload), - "/skills/neo/evaluate": ("POST", self.evaluate_neo_candidate), - "/skills/neo/promote": ("POST", self.promote_neo_candidate), - "/skills/neo/rollback": ("POST", self.rollback_neo_release), - "/skills/neo/sync": ("POST", self.sync_neo_release), - "/skills/neo/delete-candidate": ("POST", self.delete_neo_candidate), - "/skills/neo/delete-release": ("POST", self.delete_neo_release), - } - self.register_routes() - - def _resolve_local_skill_dir(self, name: str) -> Path: - skill_name = str(name or "").strip() - if not skill_name: - raise ValueError("Missing skill name") - if not _SKILL_NAME_RE.match(skill_name): - raise ValueError("Invalid skill name") - - skill_mgr = SkillManager() - if skill_mgr.is_sandbox_only_skill(skill_name): - raise PermissionError( - "Sandbox preset skill cannot be opened from local skill files." - ) - - plugin_skill_dir = skill_mgr._get_plugin_skill_dir(skill_name) - if plugin_skill_dir is not None: - return plugin_skill_dir.resolve(strict=True) - - skills_root = Path(skill_mgr.skills_root).resolve(strict=True) - skill_dir = (skills_root / skill_name).resolve(strict=True) - if not skill_dir.is_relative_to(skills_root): - raise PermissionError("Invalid skill path") - if not skill_dir.is_dir() or not (skill_dir / "SKILL.md").exists(): - raise FileNotFoundError("Local skill not found") - return skill_dir - - def _resolve_skill_relative_path( - self, - skill_dir: Path, - relative_path: str | None, - *, - expect_file: bool, - ) -> Path: - raw_path = str(relative_path or ".").strip() or "." - normalized = Path(raw_path.replace("\\", "/")) - if normalized.is_absolute() or ".." in normalized.parts: - raise ValueError("Invalid relative path") - - target = (skill_dir / normalized).resolve(strict=True) - if not target.is_relative_to(skill_dir): - raise PermissionError("Path escapes skill directory") - if expect_file and not target.is_file(): - raise FileNotFoundError("Skill file not found") - if not expect_file and not target.is_dir(): - raise FileNotFoundError("Skill directory not found") - return target - - @staticmethod - def _skill_relative_path(skill_dir: Path, target: Path) -> str: - rel = target.relative_to(skill_dir).as_posix() - return "" if rel == "." else rel - - @staticmethod - def _is_editable_skill_file(path: Path) -> bool: - return ( - path.name in _EDITABLE_SKILL_FILENAMES - or path.suffix.lower() in _EDITABLE_SKILL_FILE_SUFFIXES - ) - - def _serialize_skill_file_entry( - self, - skill_dir: Path, - path: Path, - *, - readonly: bool = False, - ) -> dict: - stat = path.stat() - is_dir = path.is_dir() - return { - "name": path.name, - "path": self._skill_relative_path(skill_dir, path), - "type": "directory" if is_dir else "file", - "size": 0 if is_dir else stat.st_size, - "editable": ( - not readonly - and (not is_dir) - and self._is_editable_skill_file(path) - and stat.st_size <= _SKILL_FILE_MAX_BYTES - ), - } - - def _get_neo_client_config(self) -> tuple[str, str]: - provider_settings = self.core_lifecycle.astrbot_config.get( - "provider_settings", - {}, - ) - sandbox = provider_settings.get("sandbox", {}) - endpoint = sandbox.get("shipyard_neo_endpoint", "") - access_token = sandbox.get("shipyard_neo_access_token", "") - - # Auto-discover token from Bay's credentials.json if not configured - if not access_token and endpoint: - access_token = _discover_bay_credentials(endpoint) - - if not endpoint or not access_token: - raise ValueError( - "Shipyard Neo endpoint or access token not configured. " - "Set them in Dashboard or ensure Bay's credentials.json is accessible." - ) - return endpoint, access_token - - async def _delete_neo_release( - self, client: Any, release_id: str, reason: str | None - ): - return await client.skills.delete_release(release_id, reason=reason) - - async def _delete_neo_candidate( - self, client: Any, candidate_id: str, reason: str | None - ): - return await client.skills.delete_candidate(candidate_id, reason=reason) - - async def _with_neo_client( - self, - operation: Callable[[Any], Awaitable[dict]], - ) -> dict: - try: - endpoint, access_token = self._get_neo_client_config() - - from shipyard_neo import BayClient - - async with BayClient( - endpoint_url=endpoint, - access_token=access_token, - ) as client: - return await operation(client) - except ValueError as e: - # Config not ready — expected when Neo isn't set up yet - logger.debug("[Neo] %s", e) - return Response().error(str(e)).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def get_skills(self): - try: - provider_settings = self.core_lifecycle.astrbot_config.get( - "provider_settings", {} - ) - runtime = provider_settings.get("computer_use_runtime", "local") - skill_mgr = SkillManager() - skills = skill_mgr.list_skills( - active_only=False, runtime=runtime, show_sandbox_path=False - ) - return ( - Response() - .ok( - { - "skills": [skill.__dict__ for skill in skills], - "runtime": runtime, - "sandbox_cache": skill_mgr.get_sandbox_skills_cache_status(), - } - ) - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def upload_skill(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - temp_path = None - try: - files = await request.files - file = files.get("file") - if not file: - return Response().error("Missing file").__dict__ - filename = os.path.basename(file.filename or "skill.zip") - if not filename.lower().endswith(".zip"): - return Response().error("Only .zip files are supported").__dict__ - - temp_dir = get_astrbot_temp_path() - os.makedirs(temp_dir, exist_ok=True) - skill_mgr = SkillManager() - temp_path = _next_available_temp_path(temp_dir, filename) - await file.save(temp_path) - - try: - try: - skill_name = skill_mgr.install_skill_from_zip( - temp_path, overwrite=False, skill_name_hint=Path(filename).stem - ) - except TypeError: - # Backward compatibility for callers that do not accept skill_name_hint - skill_name = skill_mgr.install_skill_from_zip( - temp_path, overwrite=False - ) - except Exception: - # Keep behavior consistent with previous implementation - # and bubble up install errors (including duplicates). - raise - - try: - await sync_skills_to_active_sandboxes() - except Exception: - logger.warning("Failed to sync uploaded skills to active sandboxes.") - - return ( - Response() - .ok({"name": skill_name}, "Skill uploaded successfully.") - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - finally: - if temp_path and os.path.exists(temp_path): - try: - os.remove(temp_path) - except Exception: - logger.warning(f"Failed to remove temp skill file: {temp_path}") - - async def batch_upload_skills(self): - """批量上传多个 skill ZIP 文件""" - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - try: - files = await request.files - file_list = files.getlist("files") - - if not file_list: - return Response().error("No files provided").__dict__ - - succeeded = [] - failed = [] - skipped = [] - skill_mgr = SkillManager() - temp_dir = get_astrbot_temp_path() - os.makedirs(temp_dir, exist_ok=True) - - for file in file_list: - filename = os.path.basename(file.filename or "unknown.zip") - temp_path = None - - try: - if not filename.lower().endswith(".zip"): - failed.append( - { - "filename": filename, - "error": "Only .zip files are supported", - } - ) - continue - - temp_path = _next_available_temp_path(temp_dir, filename) - await file.save(temp_path) - - try: - skill_name = skill_mgr.install_skill_from_zip( - temp_path, - overwrite=False, - skill_name_hint=Path(filename).stem, - ) - except TypeError: - # Backward compatibility for monkeypatched implementations in tests - try: - skill_name = skill_mgr.install_skill_from_zip( - temp_path, overwrite=False - ) - except FileExistsError: - skipped.append( - { - "filename": filename, - "name": Path(filename).stem, - "error": "Skill already exists.", - } - ) - skill_name = None - except FileExistsError: - skipped.append( - { - "filename": filename, - "name": Path(filename).stem, - "error": "Skill already exists.", - } - ) - skill_name = None - - if skill_name is None: - continue - succeeded.append({"filename": filename, "name": skill_name}) - - except Exception as e: - failed.append({"filename": filename, "error": str(e)}) - finally: - if temp_path and os.path.exists(temp_path): - try: - os.remove(temp_path) - except Exception: - pass - - if succeeded: - try: - await sync_skills_to_active_sandboxes() - except Exception: - logger.warning( - "Failed to sync uploaded skills to active sandboxes." - ) - - total = len(file_list) - success_count = len(succeeded) - skipped_count = len(skipped) - failed_count = len(failed) - - if failed_count == 0 and success_count == total: - message = f"All {total} skill(s) uploaded successfully." - return ( - Response() - .ok( - { - "total": total, - "succeeded": succeeded, - "failed": failed, - "skipped": skipped, - }, - message, - ) - .__dict__ - ) - if failed_count == 0 and success_count == 0: - message = f"All {total} file(s) were skipped." - return ( - Response() - .ok( - { - "total": total, - "succeeded": succeeded, - "failed": failed, - "skipped": skipped, - }, - message, - ) - .__dict__ - ) - if success_count == 0 and skipped_count == 0: - message = f"Upload failed for all {total} file(s)." - resp = Response().error(message) - resp.data = { - "total": total, - "succeeded": succeeded, - "failed": failed, - "skipped": skipped, - } - return resp.__dict__ - - message = f"Partial success: {success_count}/{total} skill(s) uploaded." - return ( - Response() - .ok( - { - "total": total, - "succeeded": succeeded, - "failed": failed, - "skipped": skipped, - }, - message, - ) - .__dict__ - ) - - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def download_skill(self): - try: - name = str(request.args.get("name") or "").strip() - if not name: - return Response().error("Missing skill name").__dict__ - if not _SKILL_NAME_RE.match(name): - return Response().error("Invalid skill name").__dict__ - - skill_mgr = SkillManager() - if skill_mgr.is_sandbox_only_skill(name): - return ( - Response() - .error( - "Sandbox preset skill cannot be downloaded from local skill files." - ) - .__dict__ - ) - if skill_mgr.is_plugin_skill(name): - return ( - Response() - .error( - "Plugin-provided skill cannot be downloaded from local skill files." - ) - .__dict__ - ) - - skill_dir = Path(skill_mgr.skills_root) / name - skill_md = skill_dir / "SKILL.md" - if not skill_dir.is_dir() or not skill_md.exists(): - return Response().error("Local skill not found").__dict__ - - export_dir = Path(get_astrbot_temp_path()) / "skill_exports" - export_dir.mkdir(parents=True, exist_ok=True) - zip_base = export_dir / name - zip_path = zip_base.with_suffix(".zip") - if zip_path.exists(): - zip_path.unlink() - - shutil.make_archive( - str(zip_base), - "zip", - root_dir=str(skill_mgr.skills_root), - base_dir=name, - ) - - return await send_file( - str(zip_path), - as_attachment=True, - attachment_filename=f"{name}.zip", - conditional=True, - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def list_skill_files(self): - try: - name = str(request.args.get("name") or "").strip() - relative_path = request.args.get("path", "") - readonly = SkillManager().is_plugin_skill(name) - skill_dir = self._resolve_local_skill_dir(name) - target_dir = self._resolve_skill_relative_path( - skill_dir, - relative_path, - expect_file=False, - ) - - entries = [] - for entry in sorted( - target_dir.iterdir(), - key=lambda item: (not item.is_dir(), item.name.lower()), - ): - try: - resolved = entry.resolve(strict=True) - except OSError: - continue - if not resolved.is_relative_to(skill_dir): - continue - if not resolved.is_dir() and not resolved.is_file(): - continue - entries.append( - self._serialize_skill_file_entry( - skill_dir, - resolved, - readonly=readonly, - ) - ) - - return ( - Response() - .ok( - { - "name": name, - "path": self._skill_relative_path(skill_dir, target_dir), - "entries": entries, - } - ) - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def get_skill_file(self): - try: - name = str(request.args.get("name") or "").strip() - relative_path = request.args.get("path", "SKILL.md") - skill_dir = self._resolve_local_skill_dir(name) - target_file = self._resolve_skill_relative_path( - skill_dir, - relative_path, - expect_file=True, - ) - if not self._is_editable_skill_file(target_file): - return Response().error("Unsupported file type").__dict__ - - size = target_file.stat().st_size - if size > _SKILL_FILE_MAX_BYTES: - return Response().error("File is too large").__dict__ - - try: - content = target_file.read_text(encoding="utf-8") - except UnicodeDecodeError: - return Response().error("File is not valid UTF-8 text").__dict__ - - return ( - Response() - .ok( - { - "name": name, - "path": self._skill_relative_path(skill_dir, target_file), - "content": content, - "size": size, - "editable": not SkillManager().is_plugin_skill(name), - } - ) - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def update_skill_file(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - try: - data = await request.get_json() - name = str(data.get("name") or "").strip() - relative_path = data.get("path", "SKILL.md") - content = data.get("content") - if not isinstance(content, str): - return Response().error("Missing file content").__dict__ - - encoded = content.encode("utf-8") - if len(encoded) > _SKILL_FILE_MAX_BYTES: - return Response().error("File content is too large").__dict__ - - skill_dir = self._resolve_local_skill_dir(name) - if SkillManager().is_plugin_skill(name): - return Response().error("Plugin-provided skill is read-only.").__dict__ - target_file = self._resolve_skill_relative_path( - skill_dir, - relative_path, - expect_file=True, - ) - if not self._is_editable_skill_file(target_file): - return Response().error("Unsupported file type").__dict__ - - target_file.write_text(content, encoding="utf-8") - - try: - await sync_skills_to_active_sandboxes() - except Exception: - logger.warning("Failed to sync edited skills to active sandboxes.") - - return ( - Response() - .ok( - { - "name": name, - "path": self._skill_relative_path(skill_dir, target_file), - "size": len(encoded), - } - ) - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def update_skill(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - try: - data = await request.get_json() - name = data.get("name") - active = data.get("active", True) - if not name: - return Response().error("Missing skill name").__dict__ - SkillManager().set_skill_active(name, bool(active)) - return Response().ok({"name": name, "active": bool(active)}).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def delete_skill(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - try: - data = await request.get_json() - name = data.get("name") - if not name: - return Response().error("Missing skill name").__dict__ - SkillManager().delete_skill(name) - try: - await sync_skills_to_active_sandboxes() - except Exception: - logger.warning("Failed to sync deleted skills to active sandboxes.") - return Response().ok({"name": name}).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(str(e)).__dict__ - - async def get_neo_candidates(self): - logger.info("[Neo] GET /skills/neo/candidates requested.") - status = request.args.get("status") - skill_key = request.args.get("skill_key") - limit = int(request.args.get("limit", 100)) - offset = int(request.args.get("offset", 0)) - - async def _do(client): - candidates = await client.skills.list_candidates( - status=status, - skill_key=skill_key, - limit=limit, - offset=offset, - ) - result = _to_jsonable(candidates) - total = result.get("total", "?") if isinstance(result, dict) else "?" - logger.info(f"[Neo] Candidates fetched: total={total}") - return Response().ok(result).__dict__ - - return await self._with_neo_client(_do) - - async def get_neo_releases(self): - logger.info("[Neo] GET /skills/neo/releases requested.") - skill_key = request.args.get("skill_key") - stage = request.args.get("stage") - active_only = _to_bool(request.args.get("active_only"), False) - limit = int(request.args.get("limit", 100)) - offset = int(request.args.get("offset", 0)) - - async def _do(client): - releases = await client.skills.list_releases( - skill_key=skill_key, - active_only=active_only, - stage=stage, - limit=limit, - offset=offset, - ) - result = _to_jsonable(releases) - total = result.get("total", "?") if isinstance(result, dict) else "?" - logger.info(f"[Neo] Releases fetched: total={total}") - return Response().ok(result).__dict__ - - return await self._with_neo_client(_do) - - async def get_neo_payload(self): - logger.info("[Neo] GET /skills/neo/payload requested.") - payload_ref = request.args.get("payload_ref", "") - if not payload_ref: - return Response().error("Missing payload_ref").__dict__ - - async def _do(client): - payload = await client.skills.get_payload(payload_ref) - logger.info(f"[Neo] Payload fetched: ref={payload_ref}") - return Response().ok(_to_jsonable(payload)).__dict__ - - return await self._with_neo_client(_do) - - async def evaluate_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/evaluate requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - passed_value = data.get("passed") - if not candidate_id or passed_value is None: - return Response().error("Missing candidate_id or passed").__dict__ - passed = _to_bool(passed_value, False) - - async def _do(client): - result = await client.skills.evaluate_candidate( - candidate_id, - passed=passed, - score=data.get("score"), - benchmark_id=data.get("benchmark_id"), - report=data.get("report"), - ) - logger.info( - f"[Neo] Candidate evaluated: id={candidate_id}, passed={passed}" - ) - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def promote_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/promote requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - stage = data.get("stage", "canary") - sync_to_local = _to_bool(data.get("sync_to_local"), True) - if not candidate_id: - return Response().error("Missing candidate_id").__dict__ - if stage not in {"canary", "stable"}: - return Response().error("Invalid stage, must be canary/stable").__dict__ - - async def _do(client): - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.promote_with_optional_sync( - client, - candidate_id=candidate_id, - stage=stage, - sync_to_local=sync_to_local, - ) - release_json = result.get("release") - logger.info(f"[Neo] Candidate promoted: id={candidate_id}, stage={stage}") - - sync_json = result.get("sync") - did_sync_to_local = bool(sync_json) - if did_sync_to_local: - logger.info( - f"[Neo] Stable release synced to local: skill={sync_json.get('local_skill_name', '')}" - ) - - if result.get("sync_error"): - resp = Response().error( - "Stable promote synced failed and has been rolled back. " - f"sync_error={result['sync_error']}" - ) - resp.data = { - "release": release_json, - "rollback": result.get("rollback"), - } - return resp.__dict__ - - # Try to push latest local skills to all active sandboxes. - if not did_sync_to_local: - try: - await sync_skills_to_active_sandboxes() - except Exception: - logger.warning("Failed to sync skills to active sandboxes.") - - return Response().ok({"release": release_json, "sync": sync_json}).__dict__ - - return await self._with_neo_client(_do) - - async def rollback_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/rollback requested.") - data = await request.get_json() - release_id = data.get("release_id") - if not release_id: - return Response().error("Missing release_id").__dict__ - - async def _do(client): - result = await client.skills.rollback_release(release_id) - logger.info(f"[Neo] Release rolled back: id={release_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def sync_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/sync requested.") - data = await request.get_json() - release_id = data.get("release_id") - skill_key = data.get("skill_key") - require_stable = _to_bool(data.get("require_stable"), True) - if not release_id and not skill_key: - return Response().error("Missing release_id or skill_key").__dict__ - - async def _do(client): - sync_mgr = NeoSkillSyncManager() - result = await sync_mgr.sync_release( - client, - release_id=release_id, - skill_key=skill_key, - require_stable=require_stable, - ) - logger.info( - f"[Neo] Release synced to local: skill={result.local_skill_name}, " - f"release_id={result.release_id}" - ) - return ( - Response() - .ok( - { - "skill_key": result.skill_key, - "local_skill_name": result.local_skill_name, - "release_id": result.release_id, - "candidate_id": result.candidate_id, - "payload_ref": result.payload_ref, - "map_path": result.map_path, - "synced_at": result.synced_at, - } - ) - .__dict__ - ) - - return await self._with_neo_client(_do) - - async def delete_neo_candidate(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/delete-candidate requested.") - data = await request.get_json() - candidate_id = data.get("candidate_id") - reason = data.get("reason") - if not candidate_id: - return Response().error("Missing candidate_id").__dict__ - - async def _do(client): - result = await self._delete_neo_candidate(client, candidate_id, reason) - logger.info(f"[Neo] Candidate deleted: id={candidate_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) - - async def delete_neo_release(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - logger.info("[Neo] POST /skills/neo/delete-release requested.") - data = await request.get_json() - release_id = data.get("release_id") - reason = data.get("reason") - if not release_id: - return Response().error("Missing release_id").__dict__ - - async def _do(client): - result = await self._delete_neo_release(client, release_id, reason) - logger.info(f"[Neo] Release deleted: id={release_id}") - return Response().ok(_to_jsonable(result)).__dict__ - - return await self._with_neo_client(_do) diff --git a/astrbot/dashboard/routes/static_file.py b/astrbot/dashboard/routes/static_file.py deleted file mode 100644 index 3a18cf82f2..0000000000 --- a/astrbot/dashboard/routes/static_file.py +++ /dev/null @@ -1,37 +0,0 @@ -from .route import Route, RouteContext - - -class StaticFileRoute(Route): - def __init__(self, context: RouteContext) -> None: - super().__init__(context) - - index_ = [ - "/", - "/auth/login", - "/config", - "/logs", - "/extension", - "/dashboard/default", - "/alkaid", - "/alkaid/knowledge-base", - "/alkaid/long-term-memory", - "/alkaid/other", - "/console", - "/chat", - "/settings", - "/platforms", - "/providers", - "/about", - "/extension-marketplace", - "/conversation", - "/tool-use", - ] - for i in index_: - self.app.add_url_rule(i, view_func=self.index) - - @self.app.errorhandler(404) - async def page_not_found(e) -> str: - return "404 Not found。如果你初次使用打开面板发现 404, 请参考文档: https://docs.astrbot.app/faq.html。如果你正在测试回调地址可达性,显示这段文字说明测试成功了。" - - async def index(self): - return await self.app.send_static_file("index.html") diff --git a/astrbot/dashboard/routes/subagent.py b/astrbot/dashboard/routes/subagent.py deleted file mode 100644 index e3d77f73ad..0000000000 --- a/astrbot/dashboard/routes/subagent.py +++ /dev/null @@ -1,117 +0,0 @@ -import traceback - -from quart import jsonify, request - -from astrbot.core import logger -from astrbot.core.agent.handoff import HandoffTool -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle - -from .route import Response, Route, RouteContext - - -class SubAgentRoute(Route): - def __init__( - self, - context: RouteContext, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - # NOTE: dict cannot hold duplicate keys; use list form to register multiple - # methods for the same path. - self.routes = [ - ("/subagent/config", ("GET", self.get_config)), - ("/subagent/config", ("POST", self.update_config)), - ("/subagent/available-tools", ("GET", self.get_available_tools)), - ] - self.register_routes() - - async def get_config(self): - try: - cfg = self.core_lifecycle.astrbot_config - data = cfg.get("subagent_orchestrator") - - # First-time access: return a sane default instead of erroring. - if not isinstance(data, dict): - data = { - "main_enable": False, - "remove_main_duplicate_tools": False, - "agents": [], - } - - # Backward compatibility: older config used `enable`. - if ( - isinstance(data, dict) - and "main_enable" not in data - and "enable" in data - ): - data["main_enable"] = bool(data.get("enable", False)) - - # Ensure required keys exist. - data.setdefault("main_enable", False) - data.setdefault("remove_main_duplicate_tools", False) - data.setdefault("agents", []) - - # Backward/forward compatibility: ensure each agent contains provider_id. - # None means follow global/default provider settings. - if isinstance(data.get("agents"), list): - for a in data["agents"]: - if isinstance(a, dict): - a.setdefault("provider_id", None) - a.setdefault("persona_id", None) - return jsonify(Response().ok(data=data).__dict__) - except Exception as e: - logger.error(traceback.format_exc()) - return jsonify(Response().error(f"获取 subagent 配置失败: {e!s}").__dict__) - - async def update_config(self): - try: - data = await request.json - if not isinstance(data, dict): - return jsonify(Response().error("配置必须为 JSON 对象").__dict__) - - cfg = self.core_lifecycle.astrbot_config - cfg["subagent_orchestrator"] = data - - # Persist to cmd_config.json - # AstrBotConfigManager does not expose a `save()` method; persist via AstrBotConfig. - cfg.save_config() - - # Reload dynamic handoff tools if orchestrator exists - orch = getattr(self.core_lifecycle, "subagent_orchestrator", None) - if orch is not None: - await orch.reload_from_config(data) - - return jsonify(Response().ok(message="保存成功").__dict__) - except Exception as e: - logger.error(traceback.format_exc()) - return jsonify(Response().error(f"保存 subagent 配置失败: {e!s}").__dict__) - - async def get_available_tools(self): - """Return all registered tools (name/description/parameters/active/origin). - - UI can use this to build a multi-select list for subagent tool assignment. - """ - try: - tool_mgr = self.core_lifecycle.provider_manager.llm_tools - tools_dict = [] - for tool in tool_mgr.func_list: - # Prevent recursive routing: subagents should not be able to select - # the handoff (transfer_to_*) tools as their own mounted tools. - if isinstance(tool, HandoffTool): - continue - if tool.handler_module_path == "core.subagent_orchestrator": - continue - tools_dict.append( - { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - "active": tool.active, - "handler_module_path": tool.handler_module_path, - } - ) - return jsonify(Response().ok(data=tools_dict).__dict__) - except Exception as e: - logger.error(traceback.format_exc()) - return jsonify(Response().error(f"获取可用工具失败: {e!s}").__dict__) diff --git a/astrbot/dashboard/routes/t2i.py b/astrbot/dashboard/routes/t2i.py deleted file mode 100644 index 634828e955..0000000000 --- a/astrbot/dashboard/routes/t2i.py +++ /dev/null @@ -1,237 +0,0 @@ -# astrbot/dashboard/routes/t2i.py - -from dataclasses import asdict - -from quart import jsonify, request - -from astrbot.core import logger -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.utils.t2i.template_manager import TemplateManager - -from .route import Response, Route, RouteContext - - -class T2iRoute(Route): - def __init__( - self, context: RouteContext, core_lifecycle: AstrBotCoreLifecycle - ) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - self.config = core_lifecycle.astrbot_config - self.manager = TemplateManager() - # 使用列表保证路由注册顺序,避免 / 路由优先匹配 /reset_default - self.routes = [ - ("/t2i/templates", ("GET", self.list_templates)), - ("/t2i/templates/active", ("GET", self.get_active_template)), - ("/t2i/templates/create", ("POST", self.create_template)), - ("/t2i/templates/reset_default", ("POST", self.reset_default_template)), - ("/t2i/templates/set_active", ("POST", self.set_active_template)), - # 动态路由应该在静态路由之后注册 - ( - "/t2i/templates/", - [ - ("GET", self.get_template), - ("PUT", self.update_template), - ("DELETE", self.delete_template), - ], - ), - ] - self.register_routes() - - async def _reload_all_pipeline_schedulers(self) -> None: - """热重载所有配置对应的 pipeline scheduler。""" - for conf_id in self.core_lifecycle.astrbot_config_mgr.confs: - await self.core_lifecycle.reload_pipeline_scheduler(conf_id) - - async def _sync_active_template_to_all_configs(self, name: str) -> None: - """同步当前激活模板到所有配置文件,并热重载对应流水线。""" - for config in self.core_lifecycle.astrbot_config_mgr.confs.values(): - config["t2i_active_template"] = name - config.save_config() - await self._reload_all_pipeline_schedulers() - - async def list_templates(self): - """获取所有T2I模板列表""" - try: - templates = self.manager.list_templates() - return jsonify(asdict(Response().ok(data=templates))) - except Exception as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 500 - return response - - async def get_active_template(self): - """获取当前激活的T2I模板""" - try: - active_template = self.config.get("t2i_active_template", "base") - return jsonify( - asdict(Response().ok(data={"active_template": active_template})), - ) - except Exception as e: - logger.error("Error in get_active_template", exc_info=True) - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 500 - return response - - async def get_template(self, name: str): - """获取指定名称的T2I模板内容""" - try: - content = self.manager.get_template(name) - return jsonify( - asdict(Response().ok(data={"name": name, "content": content})), - ) - except FileNotFoundError: - response = jsonify(asdict(Response().error("Template not found"))) - response.status_code = 404 - return response - except Exception as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 500 - return response - - async def create_template(self): - """创建一个新的T2I模板""" - try: - data = await request.json - name = data.get("name") - content = data.get("content") - if not name or not content: - response = jsonify( - asdict(Response().error("Name and content are required.")), - ) - response.status_code = 400 - return response - name = name.strip() - - self.manager.create_template(name, content) - response = jsonify( - asdict( - Response().ok( - data={"name": name}, - message="Template created successfully.", - ), - ), - ) - response.status_code = 201 - return response - except FileExistsError: - response = jsonify( - asdict(Response().error("Template with this name already exists.")), - ) - response.status_code = 409 - return response - except ValueError as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 400 - return response - except Exception as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 500 - return response - - async def update_template(self, name: str): - """更新一个已存在的T2I模板""" - try: - name = name.strip() - data = await request.json - content = data.get("content") - if content is None: - response = jsonify(asdict(Response().error("Content is required."))) - response.status_code = 400 - return response - - self.manager.update_template(name, content) - - # 检查更新的是否为当前激活的模板,如果是,则热重载 - active_template = self.config.get("t2i_active_template", "base") - if name == active_template: - await self._reload_all_pipeline_schedulers() - message = f"模板 '{name}' 已更新并重新加载。" - else: - message = f"模板 '{name}' 已更新。" - - return jsonify(asdict(Response().ok(data={"name": name}, message=message))) - except ValueError as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 400 - return response - except Exception as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 500 - return response - - async def delete_template(self, name: str): - """删除一个T2I模板""" - try: - name = name.strip() - self.manager.delete_template(name) - return jsonify( - asdict(Response().ok(message="Template deleted successfully.")), - ) - except FileNotFoundError: - response = jsonify(asdict(Response().error("Template not found."))) - response.status_code = 404 - return response - except ValueError as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 400 - return response - except Exception as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 500 - return response - - async def set_active_template(self): - """设置当前活动的T2I模板""" - try: - data = await request.json - name = data.get("name") - if not name: - response = jsonify(asdict(Response().error("模板名称(name)不能为空。"))) - response.status_code = 400 - return response - - # 验证模板文件是否存在 - self.manager.get_template(name) - - # 更新所有配置并热重载以应用更改 - await self._sync_active_template_to_all_configs(name) - - return jsonify(asdict(Response().ok(message=f"模板 '{name}' 已成功应用。"))) - - except FileNotFoundError: - response = jsonify( - asdict(Response().error(f"模板 '{name}' 不存在,无法应用。")), - ) - response.status_code = 404 - return response - except Exception as e: - logger.error("Error in set_active_template", exc_info=True) - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 500 - return response - - async def reset_default_template(self): - """重置默认的'base'模板""" - try: - self.manager.reset_default_template() - - # 更新所有配置,将激活模板也重置为'base' - await self._sync_active_template_to_all_configs("base") - - return jsonify( - asdict( - Response().ok( - message="Default template has been reset and activated.", - ), - ), - ) - except FileNotFoundError as e: - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 404 - return response - except Exception as e: - logger.error("Error in reset_default_template", exc_info=True) - response = jsonify(asdict(Response().error(str(e)))) - response.status_code = 500 - return response diff --git a/astrbot/dashboard/routes/tools.py b/astrbot/dashboard/routes/tools.py deleted file mode 100644 index 2ad80687ef..0000000000 --- a/astrbot/dashboard/routes/tools.py +++ /dev/null @@ -1,644 +0,0 @@ -import traceback - -from quart import request - -from astrbot.core import logger, sp -from astrbot.core.agent.mcp_client import MCPTool, validate_mcp_stdio_config -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.star import star_map -from astrbot.core.tools.registry import get_builtin_tool_config_statuses - -from .route import Response, Route, RouteContext - -DEFAULT_MCP_CONFIG = {"mcpServers": {}} - - -class EmptyMcpServersError(ValueError): - """Raised when mcpServers is empty.""" - - pass - - -def _extract_mcp_server_config(mcp_servers_value: object) -> dict: - """Extract server configuration from user-submitted mcpServers field. - - Raises: - ValueError: Invalid configuration - """ - if not isinstance(mcp_servers_value, dict): - raise ValueError("mcpServers must be a JSON object") - if not mcp_servers_value: - raise EmptyMcpServersError("mcpServers configuration cannot be empty") - key_0 = next(iter(mcp_servers_value)) - extracted = mcp_servers_value[key_0] - if not isinstance(extracted, dict): - raise ValueError( - "Invalid mcpServers format. Ensure each key in mcpServers is a server name, " - "and each value is an object containing fields like command/url." - ) - return extracted - - -class ToolsRoute(Route): - def __init__( - self, - context: RouteContext, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.core_lifecycle = core_lifecycle - self.routes = { - "/tools/mcp/servers": ("GET", self.get_mcp_servers), - "/tools/mcp/add": ("POST", self.add_mcp_server), - "/tools/mcp/update": ("POST", self.update_mcp_server), - "/tools/mcp/delete": ("POST", self.delete_mcp_server), - "/tools/mcp/test": ("POST", self.test_mcp_connection), - "/tools/list": ("GET", self.get_tool_list), - "/tools/toggle-tool": ("POST", self.toggle_tool), - "/tools/permission": ("POST", self.update_tool_permission), - "/tools/mcp/sync-provider": ("POST", self.sync_provider), - } - self.register_routes() - self.tool_mgr = self.core_lifecycle.provider_manager.llm_tools - - @staticmethod - def _get_tool_permission(tool_name: str) -> tuple[str, bool]: - """Return (effective_permission, configured) for a tool. - - ``configured`` is True when the permission was explicitly set via the - dashboard rather than being a fallback default. - """ - perms_store = sp.get("tool_permissions", {}, scope="global", scope_id="global") - defaults = ( - perms_store.get("_default", {}) if isinstance(perms_store, dict) else {} - ) - if tool_name in defaults: - return defaults[tool_name], True - return "member", False - - def _rollback_mcp_server(self, name: str) -> bool: - try: - rollback_config = self.tool_mgr.load_mcp_config() - if name in rollback_config["mcpServers"]: - rollback_config["mcpServers"].pop(name) - return self.tool_mgr.save_mcp_config(rollback_config) - return True - except Exception: - logger.error(traceback.format_exc()) - return False - - async def get_mcp_servers(self): - try: - config = self.tool_mgr.load_mcp_config() - servers = [] - mcp_servers = config.get("mcpServers", {}) - - if not isinstance(mcp_servers, dict): - logger.warning( - f"Invalid MCP server config type: {type(mcp_servers).__name__}. Expected object/dict; skipped all MCP servers." - ) - mcp_servers = {} - - # 获取所有服务器并添加它们的工具列表 - for name, server_config in mcp_servers.items(): - if not isinstance(server_config, dict): - logger.warning( - f"Invalid config for MCP server '{name}' (type: {type(server_config).__name__}); skipped." - ) - continue - - server_info = { - "name": name, - "active": server_config.get("active", True), - } - - # 复制所有配置字段 - for key, value in server_config.items(): - if key != "active": # active 已经处理 - server_info[key] = value - - # 如果MCP客户端已初始化,从客户端获取工具名称 - for name_key, runtime in self.tool_mgr.mcp_server_runtime_view.items(): - if name_key == name: - mcp_client = runtime.client - server_info["tools"] = [tool.name for tool in mcp_client.tools] - server_info["errlogs"] = mcp_client.server_errlogs - break - else: - server_info["tools"] = [] - - servers.append(server_info) - - return Response().ok(servers).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Failed to get MCP server list: {e!s}").__dict__ - - async def add_mcp_server(self): - try: - server_data = await request.json - - name = server_data.get("name", "") - - # 检查必填字段 - if not name: - return Response().error("Server name cannot be empty").__dict__ - - # 移除特殊字段并检查配置是否有效 - has_valid_config = False - server_config = {"active": server_data.get("active", True)} - - # 复制所有配置字段 - for key, value in server_data.items(): - if key not in ["name", "active", "tools", "errlogs"]: # 排除特殊字段 - if key == "mcpServers": - try: - server_config = _extract_mcp_server_config( - server_data["mcpServers"] - ) - except ValueError as e: - return Response().error(f"{e!s}").__dict__ - else: - server_config[key] = value - has_valid_config = True - - if not has_valid_config: - return ( - Response() - .error("A valid server configuration is required") - .__dict__ - ) - - try: - validate_mcp_stdio_config(server_config) - except ValueError as e: - return Response().error(f"{e!s}").__dict__ - - config = self.tool_mgr.load_mcp_config() - - if name in config["mcpServers"]: - return Response().error(f"Server {name} already exists").__dict__ - - try: - await self.tool_mgr.test_mcp_server_connection(server_config) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"MCP connection test failed: {e!s}").__dict__ - - config["mcpServers"][name] = server_config - - if self.tool_mgr.save_mcp_config(config): - try: - await self.tool_mgr.enable_mcp_server( - name, - server_config, - timeout=30, - ) - except TimeoutError: - rollback_ok = self._rollback_mcp_server(name) - err_msg = f"Timed out while enabling MCP server {name}." - if not rollback_ok: - err_msg += " Configuration rollback failed. Please check the config manually." - return Response().error(err_msg).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - rollback_ok = self._rollback_mcp_server(name) - err_msg = f"Failed to enable MCP server {name}: {e!s}" - if not rollback_ok: - err_msg += " Configuration rollback failed. Please check the config manually." - return Response().error(err_msg).__dict__ - return ( - Response() - .ok(None, f"Successfully added MCP server {name}") - .__dict__ - ) - return Response().error("Failed to save configuration").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Failed to add MCP server: {e!s}").__dict__ - - async def update_mcp_server(self): - try: - server_data = await request.json - - name = server_data.get("name", "") - old_name = server_data.get("oldName") or name - - if not name: - return Response().error("Server name cannot be empty").__dict__ - - config = self.tool_mgr.load_mcp_config() - - if old_name not in config["mcpServers"]: - return Response().error(f"Server {old_name} does not exist").__dict__ - - is_rename = name != old_name - - if name in config["mcpServers"] and is_rename: - return Response().error(f"Server {name} already exists").__dict__ - - # 获取活动状态 - old_config = config["mcpServers"][old_name] - if isinstance(old_config, dict): - old_active = old_config.get("active", True) - else: - old_active = True - active = server_data.get("active", old_active) - - # 创建新的配置对象 - server_config = {"active": active} - - # 仅更新活动状态的特殊处理 - only_update_active = True - - # 复制所有配置字段 - for key, value in server_data.items(): - if key not in [ - "name", - "active", - "tools", - "errlogs", - "oldName", - ]: # 排除特殊字段 - if key == "mcpServers": - try: - server_config = _extract_mcp_server_config( - server_data["mcpServers"] - ) - except ValueError as e: - return Response().error(f"{e!s}").__dict__ - else: - server_config[key] = value - only_update_active = False - - # 如果只更新活动状态,保留原始配置 - if only_update_active and isinstance(old_config, dict): - for key, value in old_config.items(): - if key != "active": # 除了active之外的所有字段都保留 - server_config[key] = value - - try: - validate_mcp_stdio_config(server_config) - except ValueError as e: - return Response().error(f"{e!s}").__dict__ - - # config["mcpServers"][name] = server_config - if is_rename: - config["mcpServers"].pop(old_name) - config["mcpServers"][name] = server_config - else: - config["mcpServers"][name] = server_config - - if self.tool_mgr.save_mcp_config(config): - # 处理MCP客户端状态变化 - if active: - if ( - old_name in self.tool_mgr.mcp_server_runtime_view - or not only_update_active - or is_rename - ): - try: - await self.tool_mgr.disable_mcp_server(old_name, timeout=10) - except TimeoutError as e: - return ( - Response() - .error( - f"Timed out while disabling MCP server {old_name} before enabling: {e!s}" - ) - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return ( - Response() - .error( - f"Failed to disable MCP server {old_name} before enabling: {e!s}" - ) - .__dict__ - ) - try: - await self.tool_mgr.enable_mcp_server( - name, - config["mcpServers"][name], - timeout=30, - ) - except TimeoutError: - return ( - Response() - .error(f"Timed out while enabling MCP server {name}.") - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return ( - Response() - .error(f"Failed to enable MCP server {name}: {e!s}") - .__dict__ - ) - # 如果要停用服务器 - elif old_name in self.tool_mgr.mcp_server_runtime_view: - try: - await self.tool_mgr.disable_mcp_server(old_name, timeout=10) - except TimeoutError: - return ( - Response() - .error(f"Timed out while disabling MCP server {old_name}.") - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return ( - Response() - .error(f"Failed to disable MCP server {old_name}: {e!s}") - .__dict__ - ) - - return ( - Response() - .ok(None, f"Successfully updated MCP server {name}") - .__dict__ - ) - return Response().error("Failed to save configuration").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Failed to update MCP server: {e!s}").__dict__ - - async def delete_mcp_server(self): - try: - server_data = await request.json - name = server_data.get("name", "") - - if not name: - return Response().error("Server name cannot be empty").__dict__ - - config = self.tool_mgr.load_mcp_config() - - if name not in config["mcpServers"]: - return Response().error(f"Server {name} does not exist").__dict__ - - del config["mcpServers"][name] - - if self.tool_mgr.save_mcp_config(config): - if name in self.tool_mgr.mcp_server_runtime_view: - try: - await self.tool_mgr.disable_mcp_server(name, timeout=10) - except TimeoutError: - return ( - Response() - .error(f"Timed out while disabling MCP server {name}.") - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return ( - Response() - .error(f"Failed to disable MCP server {name}: {e!s}") - .__dict__ - ) - return ( - Response() - .ok(None, f"Successfully deleted MCP server {name}") - .__dict__ - ) - return Response().error("Failed to save configuration").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Failed to delete MCP server: {e!s}").__dict__ - - async def test_mcp_connection(self): - """Test MCP server connection.""" - try: - server_data = await request.json - config = server_data.get("mcp_server_config", None) - - if not isinstance(config, dict) or not config: - return Response().error("Invalid MCP server configuration").__dict__ - - if "mcpServers" in config: - mcp_servers = config["mcpServers"] - if isinstance(mcp_servers, dict) and len(mcp_servers) > 1: - return ( - Response() - .error( - "Only one MCP server configuration can be tested at a time" - ) - .__dict__ - ) - try: - config = _extract_mcp_server_config(mcp_servers) - except EmptyMcpServersError: - return ( - Response() - .error("MCP server configuration cannot be empty") - .__dict__ - ) - except ValueError as e: - return Response().error(f"{e!s}").__dict__ - elif not config: - return ( - Response() - .error("MCP server configuration cannot be empty") - .__dict__ - ) - - try: - validate_mcp_stdio_config(config) - except ValueError as e: - return Response().error(f"{e!s}").__dict__ - - tools_name = await self.tool_mgr.test_mcp_server_connection(config) - return ( - Response() - .ok(data=tools_name, message="🎉 MCP server is available!") - .__dict__ - ) - - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Failed to test MCP connection: {e!s}").__dict__ - - async def get_tool_list(self): - """Get all registered tools.""" - try: - tools = list(self.tool_mgr.func_list) - existing_names = {tool.name for tool in tools} - for tool in self.tool_mgr.iter_builtin_tools(): - if tool.name not in existing_names: - tools.append(tool) - - conf_list = self.core_lifecycle.astrbot_config_mgr.get_conf_list() - conf_name_map = {conf["id"]: conf["name"] for conf in conf_list} - config_entries = [] - for conf_id, conf in self.core_lifecycle.astrbot_config_mgr.confs.items(): - config_entries.append( - { - "conf_id": conf_id, - "conf_name": conf_name_map.get(conf_id, conf_id), - "config": conf, - } - ) - - tools_dict = [] - for tool in tools: - readonly = False - builtin_config_statuses = [] - builtin_config_tags = [] - if self.tool_mgr.is_builtin_tool(tool.name): - origin = "builtin" - origin_name = "AstrBot Core" - readonly = True - builtin_config_statuses = get_builtin_tool_config_statuses( - tool.name, - config_entries, - ) - builtin_config_tags = [ - status - for status in builtin_config_statuses - if status["enabled"] - ] - elif isinstance(tool, MCPTool): - origin = "mcp" - origin_name = tool.mcp_server_name - elif tool.handler_module_path and star_map.get( - tool.handler_module_path - ): - star = star_map[tool.handler_module_path] - origin = "plugin" - origin_name = star.name - else: - origin = "unknown" - origin_name = "unknown" - - tool_info = { - "name": tool.name, - "description": tool.description, - "parameters": tool.parameters, - "active": tool.active, - "origin": origin, - "origin_name": origin_name, - "readonly": readonly, - "builtin_config_statuses": builtin_config_statuses, - "builtin_config_tags": builtin_config_tags, - } - if not readonly: - perm, configured = self._get_tool_permission(tool.name) - tool_info["permission"] = perm - tool_info["permission_configured"] = configured - tools_dict.append(tool_info) - return Response().ok(data=tools_dict).__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Failed to get tool list: {e!s}").__dict__ - - async def toggle_tool(self): - """Activate or deactivate a specified tool.""" - try: - data = await request.json - tool_name = data.get("name") - action = data.get("activate") # True or False - - if not tool_name or action is None: - return ( - Response() - .error("Missing required parameters: name or activate") - .__dict__ - ) - - if self.tool_mgr.is_builtin_tool(tool_name): - return ( - Response() - .error("Builtin tools are read-only and cannot be toggled.") - .__dict__ - ) - - if action: - try: - ok = self.tool_mgr.activate_llm_tool(tool_name, star_map=star_map) - except ValueError as e: - return Response().error(f"Failed to activate tool: {e!s}").__dict__ - else: - ok = self.tool_mgr.deactivate_llm_tool(tool_name) - - if ok: - return Response().ok(None, "Operation successful.").__dict__ - return ( - Response() - .error(f"Tool {tool_name} does not exist or the operation failed.") - .__dict__ - ) - - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Failed to operate tool: {e!s}").__dict__ - - async def sync_provider(self): - """Sync MCP provider configuration.""" - try: - data = await request.json - provider_name = data.get("name") # modelscope, or others - match provider_name: - case "modelscope": - access_token = data.get("access_token", "") - await self.tool_mgr.sync_modelscope_mcp_servers(access_token) - case _: - return ( - Response().error(f"Unknown provider: {provider_name}").__dict__ - ) - - return Response().ok(message="Sync completed").__dict__ - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Sync failed: {e!s}").__dict__ - - async def update_tool_permission(self): - """Set or remove the permission level of a registered tool.""" - try: - data = await request.json - tool_name = data.get("name") - permission = data.get("permission") # "admin" | "member" - - if not tool_name or permission not in ("admin", "member"): - return ( - Response() - .error("name and permission (admin or member) are required") - .__dict__ - ) - - if self.tool_mgr.is_builtin_tool(tool_name): - return ( - Response() - .error( - "Builtin tools do not support per-tool permission configuration." - ) - .__dict__ - ) - - # Verify the tool is known - if not any(t.name == tool_name for t in self.tool_mgr.func_list): - return Response().error(f"Tool '{tool_name}' not found").__dict__ - - perms_store = sp.get( - "tool_permissions", {}, scope="global", scope_id="global" - ) - if not isinstance(perms_store, dict): - perms_store = {} - defaults = perms_store.get("_default", {}) - if not isinstance(defaults, dict): - defaults = {} - defaults[tool_name] = permission - perms_store["_default"] = defaults - sp.put( - "tool_permissions", - perms_store, - scope="global", - scope_id="global", - ) - - return ( - Response() - .ok(None, f"Tool '{tool_name}' permission set to {permission}") - .__dict__ - ) - except Exception as e: - logger.error(traceback.format_exc()) - return Response().error(f"Failed to update tool permission: {e!s}").__dict__ diff --git a/astrbot/dashboard/routes/update.py b/astrbot/dashboard/routes/update.py deleted file mode 100644 index 210eb21005..0000000000 --- a/astrbot/dashboard/routes/update.py +++ /dev/null @@ -1,354 +0,0 @@ -import traceback -import uuid - -from quart import request - -from astrbot.core import DEMO_MODE, logger, pip_installer -from astrbot.core.config.default import VERSION -from astrbot.core.core_lifecycle import AstrBotCoreLifecycle -from astrbot.core.db.migration.helper import check_migration_needed_v4, do_migration_v4 -from astrbot.core.updator import AstrBotUpdator -from astrbot.core.utils.io import download_dashboard, get_dashboard_version - -from .route import Response, Route, RouteContext - -CLEAR_SITE_DATA_HEADERS = {"Clear-Site-Data": '"cache"'} - - -class UpdateRoute(Route): - def __init__( - self, - context: RouteContext, - astrbot_updator: AstrBotUpdator, - core_lifecycle: AstrBotCoreLifecycle, - ) -> None: - super().__init__(context) - self.routes = { - "/update/check": ("GET", self.check_update), - "/update/progress": ("GET", self.get_update_progress), - "/update/releases": ("GET", self.get_releases), - "/update/do": ("POST", self.update_project), - "/update/dashboard": ("POST", self.update_dashboard), - "/update/pip-install": ("POST", self.install_pip_package), - "/update/migration": ("POST", self.do_migration), - } - self.astrbot_updator = astrbot_updator - self.core_lifecycle = core_lifecycle - self.update_progress: dict[str, dict] = {} - self.register_routes() - - def _init_update_progress(self, progress_id: str, version: str) -> None: - self.update_progress[progress_id] = { - "id": progress_id, - "status": "running", - "stage": "preparing", - "version": version or "latest", - "message": "正在准备更新...", - "overall_percent": 0, - "stages": { - "dashboard": self._empty_stage("pending"), - "core": self._empty_stage("pending"), - }, - } - - @staticmethod - def _empty_stage(status: str = "pending") -> dict: - return { - "status": status, - "downloaded": 0, - "total": 0, - "percent": 0, - "speed": 0, - } - - def _set_update_stage( - self, - progress_id: str, - stage: str, - status: str, - message: str, - overall_percent: int | None = None, - ) -> None: - progress = self.update_progress.get(progress_id) - if not progress: - return - progress["stage"] = stage - progress["message"] = message - progress["stages"].setdefault(stage, self._empty_stage()) - progress["stages"][stage]["status"] = status - if overall_percent is not None: - progress["overall_percent"] = overall_percent - - @staticmethod - def _normalize_percent(value) -> int: - try: - percent = float(value or 0) - except (TypeError, ValueError): - return 0 - if percent <= 1: - percent *= 100 - return max(0, min(100, int(percent))) - - def _make_progress_callback( - self, - progress_id: str, - stage: str, - stage_start: int, - stage_weight: int, - ): - def _callback(payload: dict) -> None: - progress = self.update_progress.get(progress_id) - if not progress: - return - stage_percent = self._normalize_percent(payload.get("percent")) - progress["stage"] = stage - progress["stages"][stage] = { - "status": "running" if stage_percent < 100 else "done", - "downloaded": payload.get("downloaded", 0), - "total": payload.get("total", 0), - "percent": stage_percent, - "speed": payload.get("speed", 0), - } - progress["overall_percent"] = min( - 99, - stage_start + int(stage_percent * stage_weight / 100), - ) - - return _callback - - async def get_update_progress(self): - progress_id = request.args.get("id", "") - if not progress_id: - return Response().error("缺少参数 id。").__dict__ - progress = self.update_progress.get(progress_id) - if not progress: - return ( - Response() - .ok( - {"id": progress_id, "status": "idle"}, - "没有正在进行的更新。", - ) - .__dict__ - ) - return Response().ok(progress).__dict__ - - async def do_migration(self): - need_migration = await check_migration_needed_v4(self.core_lifecycle.db) - if not need_migration: - return Response().ok(None, "不需要进行迁移。").__dict__ - try: - data = await request.json - pim = data.get("platform_id_map", {}) - await do_migration_v4( - self.core_lifecycle.db, - pim, - self.core_lifecycle.astrbot_config, - ) - return Response().ok(None, "迁移成功。").__dict__ - except Exception as e: - logger.error(f"迁移失败: {traceback.format_exc()}") - return Response().error(f"迁移失败: {e!s}").__dict__ - - async def check_update(self): - type_ = request.args.get("type", None) - - try: - dv = await get_dashboard_version() - if type_ == "dashboard": - return ( - Response() - .ok({"has_new_version": dv != f"v{VERSION}", "current_version": dv}) - .__dict__ - ) - ret = await self.astrbot_updator.check_update(None, None, False) - return Response( - status="success", - message=str(ret) if ret is not None else "已经是最新版本了。", - data={ - "version": f"v{VERSION}", - "has_new_version": ret is not None, - "dashboard_version": dv, - "dashboard_has_new_version": bool(dv and dv != f"v{VERSION}"), - }, - ).__dict__ - except Exception as e: - logger.warning(f"检查更新失败: {e!s} (不影响除项目更新外的正常使用)") - return Response().error(e.__str__()).__dict__ - - async def get_releases(self): - try: - ret = await self.astrbot_updator.get_releases() - return Response().ok(ret).__dict__ - except Exception as e: - logger.error(f"/api/update/releases: {traceback.format_exc()}") - return Response().error(e.__str__()).__dict__ - - async def update_project(self): - data = await request.json - version = data.get("version", "") - reboot = data.get("reboot", True) - progress_id = data.get("progress_id") or uuid.uuid4().hex - if version == "" or version == "latest": - latest = True - version = "" - else: - latest = False - - proxy: str = data.get("proxy", None) - if proxy: - proxy = proxy.removesuffix("/") - - self._init_update_progress(progress_id, version) - try: - self._set_update_stage( - progress_id, - "dashboard", - "running", - "正在下载 WebUI...", - 0, - ) - await download_dashboard( - latest=latest, - version=version, - proxy=proxy, - progress_callback=self._make_progress_callback( - progress_id, - "dashboard", - 0, - 45, - ), - ) - self._set_update_stage( - progress_id, - "dashboard", - "done", - "WebUI 下载完成。", - 45, - ) - - self._set_update_stage( - progress_id, - "core", - "running", - "正在下载 AstrBot 项目代码...", - 45, - ) - await self.astrbot_updator.update( - latest=latest, - version=version, - proxy=proxy, - progress_callback=self._make_progress_callback( - progress_id, - "core", - 45, - 45, - ), - ) - self._set_update_stage( - progress_id, - "core", - "done", - "项目代码下载完成。", - 90, - ) - - # pip 更新依赖 - self._set_update_stage( - progress_id, - "dependencies", - "running", - "正在更新依赖...", - 92, - ) - logger.info("更新依赖中...") - try: - await pip_installer.install(requirements_path="requirements.txt") - except Exception as e: - logger.error(f"更新依赖失败: {e}") - self._set_update_stage( - progress_id, - "dependencies", - "done", - "依赖更新完成。", - 96, - ) - - if reboot: - self._set_update_stage( - progress_id, - "restart", - "running", - "更新成功,正在准备重启...", - 98, - ) - await self.core_lifecycle.restart() - self.update_progress[progress_id].update( - { - "status": "success", - "stage": "done", - "message": "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。", - "overall_percent": 100, - }, - ) - ret = ( - Response() - .ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。") - .__dict__ - ) - return ret, 200, CLEAR_SITE_DATA_HEADERS - self.update_progress[progress_id].update( - { - "status": "success", - "stage": "done", - "message": "更新成功,AstrBot 将在下次启动时应用新的代码。", - "overall_percent": 100, - }, - ) - ret = ( - Response() - .ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。") - .__dict__ - ) - return ret, 200, CLEAR_SITE_DATA_HEADERS - except Exception as e: - self.update_progress[progress_id].update( - { - "status": "error", - "message": e.__str__(), - }, - ) - logger.error(f"/api/update_project: {traceback.format_exc()}") - return Response().error(e.__str__()).__dict__ - - async def update_dashboard(self): - try: - try: - await download_dashboard(version=f"v{VERSION}", latest=False) - except Exception as e: - logger.error(f"下载管理面板文件失败: {e}。") - return Response().error(f"下载管理面板文件失败: {e}").__dict__ - ret = Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__ - return ret, 200, CLEAR_SITE_DATA_HEADERS - except Exception as e: - logger.error(f"/api/update_dashboard: {traceback.format_exc()}") - return Response().error(e.__str__()).__dict__ - - async def install_pip_package(self): - if DEMO_MODE: - return ( - Response() - .error("You are not permitted to do this operation in demo mode") - .__dict__ - ) - - data = await request.json - package = data.get("package", "") - mirror = data.get("mirror", None) - if not package: - return Response().error("缺少参数 package 或不合法。").__dict__ - try: - await pip_installer.install(package, mirror=mirror) - return Response().ok(None, "安装成功。").__dict__ - except Exception as e: - logger.error(f"/api/update_pip: {traceback.format_exc()}") - return Response().error(e.__str__()).__dict__ diff --git a/astrbot/dashboard/routes/util.py b/astrbot/dashboard/routes/util.py deleted file mode 100644 index d08af03eed..0000000000 --- a/astrbot/dashboard/routes/util.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Dashboard 路由工具集。 - -这里放一些 dashboard routes 可复用的小工具函数。 - -目前主要用于「配置文件上传(file 类型配置项)」功能: -- 清洗/规范化用户可控的文件名与相对路径 -- 将配置 key 映射到配置项独立子目录 -""" - -import os - - -def get_schema_item(schema: dict | None, key_path: str) -> dict | None: - """按 dot-path 获取 schema 的节点。 - - 同时支持: - - 扁平 schema(直接 key 命中) - - 嵌套 object schema({type: "object", items: {...}}) - - template_list schema(.templates. \ No newline at end of file + diff --git a/dashboard/src/views/stats/StatsPage.vue b/dashboard/src/views/stats/StatsPage.vue index 251971baf2..71133c0e2d 100644 --- a/dashboard/src/views/stats/StatsPage.vue +++ b/dashboard/src/views/stats/StatsPage.vue @@ -204,9 +204,9 @@ ', + encoding="utf-8", + ) + (static_folder / "favicon.svg").write_text("", encoding="utf-8") + (assets_folder / "index-demo.js").write_text( + "window.__astrbotStaticTest = true;", + encoding="utf-8", + ) + (tmp_path / "secret.txt").write_text("outside static root", encoding="utf-8") + + app = create_dashboard_asgi_app( + core_lifecycle=fake_core_lifecycle, + db=fake_db, + jwt_secret=JWT_SECRET, + static_folder=str(static_folder), + ) + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient( + transport=transport, + base_url="http://testserver", + ) as client: + asset_response = await client.get("/assets/index-demo.js") + favicon_response = await client.get("/favicon.svg") + page_response = await client.get("/config") + missing_response = await client.get("/assets/missing.js") + traversal_response = await client.get("/assets/%2E%2E/%2E%2E/secret.txt") + api_response = await client.get("/api/not-found") + + assert asset_response.status_code == 200 + assert "window.__astrbotStaticTest" in asset_response.text + assert favicon_response.status_code == 200 + assert favicon_response.text == "" + assert page_response.status_code == 200 + assert "/assets/index-demo.js" in page_response.text + assert missing_response.status_code == 404 + assert traversal_response.status_code == 404 + assert api_response.status_code == 404 + + +@pytest.mark.asyncio +async def test_v1_openapi_uses_pydantic_request_bodies( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.get("/api/v1/openapi.json") + + assert response.status_code == 200 + spec = response.json() + schemas = spec["components"]["schemas"] + assert "BotRegistrationRequest" in schemas + assert "ConfigContentRequest" in schemas + + bot_registration = spec["paths"]["/api/v1/bot-types/{bot_type}/registration"][ + "post" + ] + assert bot_registration["parameters"][0]["name"] == "bot_type" + assert bot_registration["requestBody"]["content"]["application/json"]["schema"][ + "$ref" + ].endswith("/BotRegistrationRequest") + + config_profile_update = spec["paths"]["/api/v1/config-profiles/{config_id}"]["put"] + assert config_profile_update["requestBody"]["content"]["application/json"][ + "schema" + ]["$ref"].endswith("/ConfigContentRequest") + + system_config_update = spec["paths"]["/api/v1/system-config"]["put"] + assert system_config_update["requestBody"]["content"]["application/json"]["schema"][ + "$ref" + ].endswith("/ConfigContentRequest") + + +@pytest.mark.asyncio +async def test_v1_conversation_path_id_allows_slash(asgi_client: httpx.AsyncClient): + response = await asgi_client.get( + "/api/v1/conversations/conversation%2Fwith%2Fslash", + params={"user_id": "webchat:FriendMessage:webchat!user!session-1"}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "ok" + assert payload["data"]["cid"] == "conversation/with/slash" + + +@pytest.mark.asyncio +async def test_dashboard_alias_conversation_detail_uses_fastapi_service( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.post( + "/api/conversation/detail", + json={ + "user_id": "webchat:FriendMessage:webchat!user!session-1", + "cid": "conversation/with/slash", + }, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + payload = response.json() + assert payload["status"] == "ok" + assert payload["data"]["cid"] == "conversation/with/slash" + + +@pytest.mark.asyncio +async def test_v1_bots_matches_dashboard_platform_alias_list( + asgi_client: httpx.AsyncClient, +): + headers = _jwt_headers() + + dashboard_alias_response = await asgi_client.get( + "/api/config/platform/list", + headers=headers, + ) + v1_response = await asgi_client.get("/api/v1/bots", headers=headers) + + assert dashboard_alias_response.status_code == 200 + assert v1_response.status_code == 200 + dashboard_alias_data = dashboard_alias_response.json() + v1_data = v1_response.json() + assert dashboard_alias_data["status"] == "ok" + assert v1_data["status"] == "ok" + assert v1_data["data"]["bots"] == dashboard_alias_data["data"]["platforms"] + + +@pytest.mark.asyncio +async def test_v1_bot_stats_match_platform_manager(asgi_client: httpx.AsyncClient): + response = await asgi_client.get("/api/v1/bots/stats", headers=_jwt_headers()) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["data"]["platforms"] == [{"id": "webchat-main", "status": "running"}] + + +@pytest.mark.asyncio +async def test_v1_config_routes_can_replace_all_routes( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, +): + routing = { + "webchat-main:private:*": "default", + "webchat-main:group:demo": "group-conf", + } + + response = await asgi_client.put( + "/api/v1/config-routes", + headers=_jwt_headers(), + json={"routing": routing}, + ) + + assert response.status_code == 200 + assert response.json()["status"] == "ok" + assert fake_core_lifecycle.umop_config_router.umop_to_conf_id == routing + + list_response = await asgi_client.get( + "/api/v1/config-routes", + headers=_jwt_headers(), + ) + assert list_response.status_code == 200 + assert list_response.json()["data"]["routing"] == routing + + +@pytest.mark.asyncio +async def test_v1_active_umos_uses_session_service( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.get( + "/api/v1/sessions/active-umos", + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["data"]["umos"] == ["webchat:FriendMessage:webchat!user!session-1"] + assert data["data"]["umo_infos"][0]["platform"] == "webchat" + + +@pytest.mark.asyncio +async def test_v1_system_config_update_preserves_independent_bot_provider_sections( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, + monkeypatch: pytest.MonkeyPatch, +): + def fake_save_config(post_config: dict, config: FakeAstrBotConfig, is_core=False): + config.save_config(post_config) + + monkeypatch.setattr(config_service, "save_config", fake_save_config) + + original_platform = copy.deepcopy(fake_core_lifecycle.astrbot_config["platform"]) + original_provider_sources = copy.deepcopy( + fake_core_lifecycle.astrbot_config["provider_sources"] + ) + original_providers = copy.deepcopy(fake_core_lifecycle.astrbot_config["provider"]) + payload = copy.deepcopy(fake_core_lifecycle.astrbot_config) + payload["platform"] = [] + payload["provider_sources"] = [] + payload["provider"] = [] + payload["provider_settings"] = {"default_provider_id": "gpt-mini"} + + response = await asgi_client.put( + "/api/v1/system-config", + headers=_jwt_headers(), + json=payload, + ) + + assert response.status_code == 200 + assert response.json()["status"] == "ok" + assert fake_core_lifecycle.astrbot_config["platform"] == original_platform + assert ( + fake_core_lifecycle.astrbot_config["provider_sources"] + == original_provider_sources + ) + assert fake_core_lifecycle.astrbot_config["provider"] == original_providers + assert fake_core_lifecycle.astrbot_config["provider_settings"] == { + "default_provider_id": "gpt-mini" + } + assert fake_core_lifecycle.reloaded_config_ids == ["default"] + + +@pytest.mark.asyncio +async def test_v1_system_config_returns_system_metadata( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.get( + "/api/v1/system-config", + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert "system_group" in data["data"]["metadata"] + assert "platform_group" not in data["data"]["metadata"] + + +@pytest.mark.asyncio +async def test_v1_providers_matches_dashboard_provider_alias_list( + asgi_client: httpx.AsyncClient, +): + headers = _jwt_headers() + + dashboard_alias_response = await asgi_client.get( + "/api/config/provider/list?provider_type=chat_completion", + headers=headers, + ) + v1_response = await asgi_client.get( + "/api/v1/providers?capability=chat", + headers=headers, + ) + + assert dashboard_alias_response.status_code == 200 + assert v1_response.status_code == 200 + dashboard_alias_data = dashboard_alias_response.json() + v1_data = v1_response.json() + assert dashboard_alias_data["status"] == "ok" + assert v1_data["status"] == "ok" + assert v1_data["data"]["providers"] == dashboard_alias_data["data"] + + +@pytest.mark.asyncio +async def test_v1_provider_source_rename_updates_provider_refs( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr( + "astrbot.dashboard.services.config_service.save_config", + lambda *_args, **_kwargs: None, + ) + + response = await asgi_client.put( + "/api/v1/provider-sources/openai-source", + json={ + "config": { + "id": "openai-renamed", + "type": "openai_chat_completion", + "provider_type": "chat_completion", + "api_base": "https://api.example.test/v1", + "key": ["test-key"], + } + }, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + assert response.json()["status"] == "ok" + config = fake_core_lifecycle.astrbot_config + assert config["provider_sources"][0]["id"] == "openai-renamed" + assert config["provider"][0]["provider_source_id"] == "openai-renamed" + assert ( + fake_core_lifecycle.provider_manager.provider_sources_config[0]["id"] + == "openai-renamed" + ) + assert fake_core_lifecycle.provider_manager.reloaded_providers == [ + config["provider"][0] + ] + + +@pytest.mark.asyncio +async def test_v1_provider_update_keeps_dashboard_id_rename_behavior( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, +): + response = await asgi_client.put( + "/api/v1/providers/gpt-mini", + json={ + "config": { + "id": "gpt-renamed", + "provider_source_id": "openai-source", + "model": "gpt-4o-mini", + "enable": True, + } + }, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + assert response.json()["status"] == "ok" + config = fake_core_lifecycle.astrbot_config + assert config["provider"][0]["id"] == "gpt-renamed" + assert fake_core_lifecycle.provider_manager.reloaded_providers == [ + config["provider"][0] + ] + + +@pytest.mark.asyncio +async def test_v1_create_standalone_provider_matches_dashboard_alias_capability( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, +): + response = await asgi_client.post( + "/api/v1/providers", + json={ + "config": { + "id": "tts-main", + "type": "edge_tts", + "provider_type": "text_to_speech", + "enable": True, + } + }, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + assert response.json()["status"] == "ok" + assert fake_core_lifecycle.astrbot_config["provider"][-1] == { + "id": "tts-main", + "type": "edge_tts", + "provider_type": "text_to_speech", + "enable": True, + } + + +@pytest.mark.asyncio +async def test_v1_safe_provider_routes_accept_slash_ids( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(config_service, "save_config", lambda *_args, **_kwargs: None) + + source_id = "https://example.com/source" + provider_id = "qianxun/kimi-k2-0905-preview" + config = fake_core_lifecycle.astrbot_config + config["provider_sources"].append( + { + "id": source_id, + "type": "openai_chat_completion", + "provider_type": "chat_completion", + "api_base": "https://api.example.test/v1", + "key": ["test-key"], + } + ) + config["provider"].append( + { + "id": provider_id, + "provider_source_id": source_id, + "model": "kimi-k2-0905-preview", + "enable": True, + } + ) + provider_instance = FakeProviderInstance(provider_id) + fake_core_lifecycle.provider_manager.inst_map[provider_id] = provider_instance + + async def fake_list_models(_service, requested_source_id: str): + return {"provider_source_id": requested_source_id, "models": ["model/a"]} + + monkeypatch.setattr( + config_service.ProviderConfigService, + "list_provider_source_models", + fake_list_models, + ) + + headers = _jwt_headers() + get_response = await asgi_client.get( + "/api/v1/providers/by-id", + params={"provider_id": provider_id, "merged": True}, + headers=headers, + ) + schema_response = await asgi_client.get( + "/api/v1/providers/schema", + headers=headers, + ) + path_test_response = await asgi_client.post( + "/api/v1/providers/qianxun%2Fkimi-k2-0905-preview/test", + headers=headers, + ) + safe_test_response = await asgi_client.post( + "/api/v1/providers/test", + json={"provider_id": provider_id}, + headers=headers, + ) + enabled_response = await asgi_client.patch( + "/api/v1/providers/enabled", + json={"provider_id": provider_id, "enabled": False}, + headers=headers, + ) + embedding_response = await asgi_client.post( + "/api/v1/providers/embedding-dimension", + json={"provider_id": provider_id, "provider_config": {"model": "model/a"}}, + headers=headers, + ) + source_models_response = await asgi_client.get( + "/api/v1/provider-sources/models", + params={"source_id": source_id}, + headers=headers, + ) + source_providers_response = await asgi_client.get( + "/api/v1/provider-sources/providers", + params={"source_id": source_id}, + headers=headers, + ) + + assert get_response.status_code == 200 + assert get_response.json()["data"]["provider"]["id"] == provider_id + assert schema_response.status_code == 200 + assert "config_schema" in schema_response.json()["data"] + assert path_test_response.status_code == 200 + assert path_test_response.json()["data"]["status"] == "available" + assert safe_test_response.status_code == 200 + assert safe_test_response.json()["data"]["status"] == "available" + assert provider_instance.tested is True + assert enabled_response.status_code == 200 + assert config["provider"][-1]["enable"] is False + assert embedding_response.status_code == 400 + assert embedding_response.json()["status"] == "error" + assert embedding_response.json()["message"] in { + "提供商适配器加载失败,请检查提供商类型配置或查看服务端日志", + "提供商不是 EmbeddingProvider 类型", + } + assert source_models_response.status_code == 200 + assert source_models_response.json()["data"]["provider_source_id"] == source_id + assert source_providers_response.status_code == 200 + assert source_providers_response.json()["data"]["providers"][0]["id"] == provider_id + + +@pytest.mark.asyncio +async def test_v1_safe_bot_routes_accept_slash_ids( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(config_service, "save_config", lambda *_args, **_kwargs: None) + + bot_id = "group/a" + fake_core_lifecycle.astrbot_config["platform"].append( + {"id": bot_id, "type": "webchat", "enable": True} + ) + headers = _jwt_headers() + + get_response = await asgi_client.get( + "/api/v1/bots/by-id", + params={"bot_id": bot_id}, + headers=headers, + ) + enabled_response = await asgi_client.patch( + "/api/v1/bots/enabled", + json={"bot_id": bot_id, "enabled": False}, + headers=headers, + ) + test_response = await asgi_client.post( + "/api/v1/bots/test", + json={"bot_id": bot_id}, + headers=headers, + ) + delete_response = await asgi_client.delete( + "/api/v1/bots/by-id", + params={"bot_id": bot_id}, + headers=headers, + ) + + assert get_response.status_code == 200 + assert get_response.json()["data"]["bot"]["id"] == bot_id + assert enabled_response.status_code == 200 + assert fake_core_lifecycle.platform_reload_configs[-1]["id"] == bot_id + assert fake_core_lifecycle.platform_reload_configs[-1]["enable"] is False + assert test_response.status_code == 200 + assert test_response.json()["data"] == {"id": bot_id, "status": "unsupported"} + assert delete_response.status_code == 200 + assert fake_core_lifecycle.terminated_platform_ids == [bot_id] + + +@pytest.mark.asyncio +async def test_v1_bot_scope_accepts_api_key( + asgi_client: httpx.AsyncClient, + fake_db: FakeDb, +): + config_key = "abk_fastapi_v1_config" + fake_db.add_api_key(config_key, scopes=["config"]) + + config_response = await asgi_client.get( + "/api/v1/bots", + headers={"X-API-Key": config_key}, + ) + + assert config_response.status_code == 403 + + bot_key = "abk_fastapi_v1_bot" + fake_db.add_api_key(bot_key, scopes=["bot"]) + + response = await asgi_client.get( + "/api/v1/bots", + headers={"X-API-Key": bot_key}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert isinstance(data["data"]["bots"], list) + assert fake_db.touched_key_ids == ["config-key"] + + +@pytest.mark.asyncio +async def test_dashboard_alias_route_still_works_through_asgi_app( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.get("/api/stat/start-time") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["data"]["start_time"] == 1234567890 + + +@pytest.mark.asyncio +async def test_v1_plugins_accept_api_key( + asgi_client: httpx.AsyncClient, + fake_db: FakeDb, +): + raw_key = "abk_fastapi_v1_plugin" + fake_db.add_api_key(raw_key, scopes=["plugin"]) + + response = await asgi_client.get( + "/api/v1/plugins", + headers={"X-API-Key": raw_key}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert [item["name"] for item in data["data"]] == ["astrbot_plugin_demo"] + + +@pytest.mark.asyncio +async def test_tool_permission_routes_call_service( + asgi_app: FastAPI, + asgi_client: httpx.AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + captured_payloads = [] + tools_service = asgi_app.state.services.tools + + def fake_update_tool_permission(payload): + captured_payloads.append(copy.deepcopy(payload)) + return f"permission set for {payload['name']}" + + monkeypatch.setattr( + tools_service, + "update_tool_permission", + fake_update_tool_permission, + ) + + v1_response = await asgi_client.patch( + "/api/v1/tools/plugin/foo/permission", + json={"permission": "admin"}, + headers=_jwt_headers(), + ) + legacy_response = await asgi_client.post( + "/api/tools/permission", + json={"name": "legacy_tool", "permission": "member"}, + headers=_jwt_headers(), + ) + + assert v1_response.status_code == 200 + assert v1_response.json()["status"] == "ok" + assert legacy_response.status_code == 200 + assert legacy_response.json()["status"] == "ok" + assert captured_payloads == [ + {"name": "plugin/foo", "permission": "admin"}, + {"name": "legacy_tool", "permission": "member"}, + ] + + +@pytest.mark.asyncio +async def test_v1_plugin_enabled_patch_calls_service( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, +): + response = await asgi_client.patch( + "/api/v1/plugins/astrbot_plugin_demo/enabled", + json={"enabled": False}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["message"] == "停用成功。" + plugin = fake_core_lifecycle.plugin_manager.context.get_all_stars()[0] + assert plugin.activated is False + + +@pytest.mark.asyncio +async def test_v1_plugin_version_support_check_uses_service( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.post( + "/api/v1/plugins/version-support/check", + json={"plugin_ids": ["astrbot_plugin_demo"]}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["data"] == { + "supported": True, + "message": "supported: ", + "astrbot_version": "", + } + + +@pytest.mark.asyncio +async def test_v1_plugin_url_install_accepts_download_url_and_missing_body( + asgi_app: FastAPI, + asgi_client: httpx.AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + captured_payloads = [] + plugin_service = asgi_app.state.services.plugins + + async def fake_install_plugin(payload): + captured_payloads.append(payload) + if not payload.get("url"): + raise RuntimeError("missing url") + return {"name": "astrbot_plugin_demo"}, "安装成功。" + + monkeypatch.setattr(plugin_service, "install_plugin", fake_install_plugin) + + response = await asgi_client.post( + "/api/v1/plugins/install/url", + json={ + "url": "https://github.com/AstrBotDevs/astrbot-plugin-demo", + "download_url": "https://cdn.example/plugin.zip", + "ignore_version_check": True, + }, + headers=_jwt_headers(), + ) + empty_body_response = await asgi_client.post( + "/api/v1/plugins/install/url", + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert captured_payloads[0] == { + "url": "https://github.com/AstrBotDevs/astrbot-plugin-demo", + "download_url": "https://cdn.example/plugin.zip", + "proxy": None, + "ignore_version_check": True, + } + assert empty_body_response.status_code == 200 + empty_body_data = empty_body_response.json() + assert empty_body_data["status"] == "error" + assert empty_body_data["message"] == "插件操作失败,请查看服务端日志。" + assert "missing url" not in str(empty_body_data) + + +@pytest.mark.asyncio +async def test_v1_plugin_update_all_hides_internal_exceptions( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.post( + "/api/v1/plugins/update", + json={"plugin_ids": ["astrbot_plugin_demo"]}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + result = data["data"]["results"][0] + assert result["status"] == "error" + assert result["message"] == "更新失败,请查看服务端日志。" + assert "AttributeError" not in str(data) + assert "update_plugin" not in str(data) + + +@pytest.mark.asyncio +async def test_v1_plugin_extension_maps_nested_plugin_path( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.post( + "/api/v1/plugins/extensions/astrbot_plugin_demo/api/action", + json={"value": "demo"}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["data"] == { + "plugin_path": "astrbot_plugin_demo/api/action", + "method": "POST", + "payload": {"value": "demo"}, + "username": "fastapi-v1-test", + } + + +@pytest.mark.asyncio +async def test_v1_plugin_extension_supports_quart_request_context( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, +): + from quart import g as quart_g + from quart import jsonify as quart_jsonify + from quart import request as quart_request + + async def quart_plugin_extension(item_id: str): + return quart_jsonify( + { + "status": "ok", + "data": { + "item_id": item_id, + "path": quart_request.path, + "method": quart_request.method, + "source": quart_request.args.get("source"), + "payload": await quart_request.get_json(), + "username": quart_g.username, + }, + } + ) + + fake_core_lifecycle.star_context.registered_web_apis = [ + ("/quart/", quart_plugin_extension, ["POST"], "quart") + ] + + response = await asgi_client.post( + "/api/v1/plugins/extensions/quart/demo-item?source=v1", + json={"value": "demo"}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + assert response.headers["content-type"].startswith("application/json") + data = response.json() + assert data["status"] == "ok" + assert data["data"] == { + "item_id": "demo-item", + "path": "/api/plug/quart/demo-item", + "method": "POST", + "source": "v1", + "payload": {"value": "demo"}, + "username": "fastapi-v1-test", + } + + +@pytest.mark.asyncio +async def test_v1_plugin_config_file_routes_reach_service_layer( + asgi_client: httpx.AsyncClient, +): + headers = _jwt_headers() + + list_response = await asgi_client.get( + "/api/v1/plugins/astrbot_plugin_demo/config-files/assets", + headers=headers, + ) + upload_response = await asgi_client.post( + "/api/v1/plugins/astrbot_plugin_demo/config-files/assets", + json={"filename": "demo.txt"}, + headers=headers, + ) + delete_response = await asgi_client.request( + "DELETE", + "/api/v1/plugins/astrbot_plugin_demo/config-files", + json={"path": "demo.txt"}, + headers=headers, + ) + + assert list_response.status_code == 400 + assert list_response.json()["status"] == "error" + assert upload_response.status_code == 400 + assert upload_response.json()["status"] == "error" + assert delete_response.status_code == 400 + assert delete_response.json()["status"] == "error" + + +@pytest.mark.asyncio +async def test_v1_safe_plugin_routes_accept_slash_ids( + asgi_app: FastAPI, + asgi_client: httpx.AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + plugin_id = "plugin/foo" + headers = _jwt_headers() + plugin_service = asgi_app.state.services.plugins + config_display_service = asgi_app.state.services.config_display + config_file_service = asgi_app.state.services.config_files + + async def fake_get_plugin_detail(**kwargs): + return {"name": kwargs["plugin_name"]} + + async def fake_set_plugin_enabled(data, *, enabled: bool): + return {"payload": {"name": data["name"], "enabled": enabled}} + + async def fake_update_plugin(data): + return {"payload": data} + + def fake_get_plugin_readme(name: str): + return {"name": name, "content": "readme"}, "ok" + + async def fake_get_configs(name: str): + return {"schema": {"name": name}} + + def fake_list_config_files(*, scope: str, name: str, key_path: str): + return {"scope": scope, "name": name, "key": key_path} + + monkeypatch.setattr(plugin_service, "get_plugin_detail", fake_get_plugin_detail) + monkeypatch.setattr(plugin_service, "set_plugin_enabled", fake_set_plugin_enabled) + monkeypatch.setattr(plugin_service, "update_plugin", fake_update_plugin) + monkeypatch.setattr(plugin_service, "get_plugin_readme", fake_get_plugin_readme) + monkeypatch.setattr(config_display_service, "get_configs", fake_get_configs) + monkeypatch.setattr( + config_file_service, + "list_config_files", + fake_list_config_files, + ) + + detail_response = await asgi_client.get( + "/api/v1/plugins/by-id", + params={"plugin_id": plugin_id}, + headers=headers, + ) + enabled_response = await asgi_client.patch( + "/api/v1/plugins/enabled", + json={"plugin_id": plugin_id, "enabled": False}, + headers=headers, + ) + update_response = await asgi_client.post( + "/api/v1/plugins/update", + json={"plugin_id": plugin_id, "reinstall": True}, + headers=headers, + ) + readme_response = await asgi_client.get( + "/api/v1/plugins/readme", + params={"plugin_id": plugin_id}, + headers=headers, + ) + schema_response = await asgi_client.get( + "/api/v1/plugins/config/schema", + params={"plugin_id": plugin_id}, + headers=headers, + ) + config_files_response = await asgi_client.get( + "/api/v1/plugins/config-files", + params={"plugin_id": plugin_id, "config_key": "assets/path"}, + headers=headers, + ) + + assert detail_response.status_code == 200 + assert detail_response.json()["data"]["name"] == plugin_id + assert enabled_response.status_code == 200 + assert enabled_response.json()["data"]["payload"] == { + "name": plugin_id, + "enabled": False, + } + assert update_response.status_code == 200 + assert update_response.json()["data"]["payload"] == { + "name": plugin_id, + "reinstall": True, + } + assert readme_response.status_code == 200 + assert readme_response.json()["data"]["name"] == plugin_id + assert schema_response.status_code == 200 + assert schema_response.json()["data"]["plugin_name"] == plugin_id + assert config_files_response.status_code == 200 + assert config_files_response.json()["data"] == { + "scope": "plugin", + "name": plugin_id, + "key": "assets/path", + } + + +@pytest.mark.asyncio +async def test_v1_safe_plugin_source_delete_accepts_slash_ids( + asgi_client: httpx.AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + source_id = "https://example.com/source" + sources = [{"id": source_id}, {"id": "keep"}] + + async def fake_global_get(_key, _default=None): + return list(sources) + + async def fake_global_put(_key, value): + sources[:] = value + + monkeypatch.setattr( + "astrbot.dashboard.services.plugin_service.sp.global_get", + fake_global_get, + ) + monkeypatch.setattr( + "astrbot.dashboard.services.plugin_service.sp.global_put", + fake_global_put, + ) + + response = await asgi_client.delete( + "/api/v1/plugin-sources/by-id", + params={"source_id": source_id}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + assert response.json()["data"]["sources"] == [{"id": "keep"}] + + +@pytest.mark.asyncio +async def test_v1_command_patch_updates_service( + asgi_app: FastAPI, + asgi_client: httpx.AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + async def fake_toggle(handler_full_name: str | None, enabled): + return { + "handler_full_name": handler_full_name, + "enabled": enabled, + } + + monkeypatch.setattr( + asgi_app.state.services.commands, + "toggle_command", + fake_toggle, + ) + + response = await asgi_client.patch( + "/api/v1/commands/plugin.handler", + json={"enabled": False}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["data"] == { + "handler_full_name": "plugin.handler", + "enabled": False, + } + + +@pytest.mark.asyncio +async def test_v1_bot_type_registration_uses_platform_service( + asgi_app: FastAPI, + asgi_client: httpx.AsyncClient, + monkeypatch: pytest.MonkeyPatch, +): + async def fake_registration(platform_type: str, payload: dict): + return {"platform_type": platform_type, "payload": payload} + + monkeypatch.setattr( + asgi_app.state.services.platforms, + "handle_platform_registration", + fake_registration, + ) + + response = await asgi_client.post( + "/api/v1/bot-types/webchat/registration", + json={"registration_code": "abc123"}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["data"] == { + "platform_type": "webchat", + "payload": {"registration_code": "abc123"}, + } + + +@pytest.mark.asyncio +async def test_v1_token_file_is_public( + asgi_client: httpx.AsyncClient, + tmp_path: Path, +): + token_file = tmp_path / "token-file.txt" + token_file.write_text("token:demo-token", encoding="utf-8") + file_token = await file_token_service.register_file(str(token_file), timeout=60) + + response = await asgi_client.get(f"/api/v1/files/tokens/{file_token}") + + assert response.status_code == 200 + assert response.text == "token:demo-token" + assert response.headers["content-type"].startswith("text/plain") + + +def test_v1_openapi_alias_websocket_routes_are_mounted(asgi_app): + websocket_paths = { + route.path + for route in asgi_app.router.routes + if "websocket" in route.__class__.__name__.lower() + } + + assert "/api/v1/chat/ws" in websocket_paths + assert "/api/v1/live-chat/ws" in websocket_paths + assert "/api/v1/unified-chat/ws" in websocket_paths + + +def test_dashboard_config_aliases_are_registered_on_fastapi(asgi_app): + http_paths = { + route.path + for route in asgi_app.router.routes + if "route" in route.__class__.__name__.lower() + } + + assert "/api/config/platform/list" in http_paths + assert "/api/config/provider/list" in http_paths + assert "/api/config/provider_sources/update" in http_paths + + +@pytest.mark.asyncio +async def test_v1_mcp_enabled_patch_updates_stored_active_flag( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, +): + response = await asgi_client.patch( + "/api/v1/mcp/servers/demo-server/enabled", + json={"enabled": False}, + headers=_jwt_headers(), + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["message"] == "Successfully updated MCP server demo-server" + mcp_servers = fake_core_lifecycle.provider_manager.llm_tools.config["mcpServers"] + assert mcp_servers["demo-server"]["active"] is False + + +@pytest.mark.asyncio +async def test_v1_safe_mcp_routes_accept_slash_server_names( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, +): + server_name = "modelscope/demo" + headers = _jwt_headers() + fake_tools = fake_core_lifecycle.provider_manager.llm_tools + + enabled_response = await asgi_client.patch( + "/api/v1/mcp/servers/enabled", + json={"server_name": server_name, "enabled": False}, + headers=headers, + ) + assert enabled_response.status_code == 200 + assert fake_tools.config["mcpServers"][server_name]["active"] is False + + test_response = await asgi_client.post( + "/api/v1/mcp/servers/test", + json={"server_name": server_name}, + headers=headers, + ) + assert test_response.status_code == 200 + assert test_response.json()["data"] == ["demo_tool"] + assert fake_tools.tested_configs[-1] == { + "active": False, + "url": "https://example.com/modelscope-demo", + } + + delete_response = await asgi_client.delete( + "/api/v1/mcp/servers/by-name", + params={"server_name": server_name}, + headers=headers, + ) + assert delete_response.status_code == 200 + assert server_name not in fake_tools.config["mcpServers"] + + sync_response = await asgi_client.post( + "/api/v1/mcp/providers/modelscope/sync", + json={"access_token": "token"}, + headers=headers, + ) + assert sync_response.status_code == 200 + assert fake_tools.synced_modelscope_tokens == ["token"] + + +@pytest.mark.asyncio +async def test_v1_skills_reject_developer_api_key_scope( + asgi_app: FastAPI, + asgi_client: httpx.AsyncClient, + fake_db: FakeDb, + monkeypatch: pytest.MonkeyPatch, +): + raw_key = "abk_fastapi_v1_skill" + fake_db.add_api_key(raw_key, scopes=["skill"]) + monkeypatch.setattr( + asgi_app.state.services.skills, + "get_skills", + lambda: {"skills": [{"name": "demo_skill"}]}, + ) + + response = await asgi_client.get( + "/api/v1/skills", + headers={"X-API-Key": raw_key}, + ) + + assert response.status_code == 403 + data = response.json() + assert data["status"] == "error" + assert data["message"] == "Insufficient API key scope" + + +@pytest.mark.asyncio +async def test_v1_safe_skill_routes_accept_slash_names( + asgi_app: FastAPI, + asgi_client: httpx.AsyncClient, + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, +): + skill_name = "skill/foo" + headers = _jwt_headers() + skill_service = asgi_app.state.services.skills + archive_path = tmp_path / "skill.zip" + archive_path.write_bytes(b"zip") + + async def fake_update_skill(data): + return {"payload": data} + + async def fake_delete_skill(data): + return {"payload": data} + + def fake_prepare_skill_archive(name: str): + assert name == skill_name + return SkillArchive(path=archive_path, filename="skill.zip") + + def fake_list_skill_files(name: str, path: str): + return {"name": name, "path": path} + + def fake_get_skill_file(name: str, path: str): + return {"name": name, "path": path} + + async def fake_update_skill_file(data): + return {"payload": data} + + monkeypatch.setattr(skill_service, "update_skill", fake_update_skill) + monkeypatch.setattr(skill_service, "delete_skill", fake_delete_skill) + monkeypatch.setattr( + skill_service, + "prepare_skill_archive", + fake_prepare_skill_archive, + ) + monkeypatch.setattr(skill_service, "list_skill_files", fake_list_skill_files) + monkeypatch.setattr(skill_service, "get_skill_file", fake_get_skill_file) + monkeypatch.setattr(skill_service, "update_skill_file", fake_update_skill_file) + + enabled_response = await asgi_client.patch( + "/api/v1/skills/by-name", + json={"skill_name": skill_name, "enabled": False}, + headers=headers, + ) + archive_response = await asgi_client.get( + "/api/v1/skills/archive", + params={"skill_name": skill_name}, + headers=headers, + ) + files_response = await asgi_client.get( + "/api/v1/skills/files", + params={"skill_name": skill_name, "path": "src"}, + headers=headers, + ) + file_response = await asgi_client.get( + "/api/v1/skills/file", + params={"skill_name": skill_name, "path": "src/main.py"}, + headers=headers, + ) + update_file_response = await asgi_client.put( + "/api/v1/skills/file", + json={"skill_name": skill_name, "path": "src/main.py", "content": "print(1)"}, + headers=headers, + ) + delete_response = await asgi_client.delete( + "/api/v1/skills/by-name", + params={"skill_name": skill_name}, + headers=headers, + ) + + assert enabled_response.status_code == 200 + assert enabled_response.json()["data"]["payload"] == { + "name": skill_name, + "active": False, + } + assert archive_response.status_code == 200 + assert archive_response.content == b"zip" + assert files_response.status_code == 200 + assert files_response.json()["data"] == {"name": skill_name, "path": "src"} + assert file_response.status_code == 200 + assert file_response.json()["data"] == { + "name": skill_name, + "path": "src/main.py", + } + assert update_file_response.status_code == 200 + assert update_file_response.json()["data"]["payload"] == { + "name": skill_name, + "path": "src/main.py", + "content": "print(1)", + } + assert delete_response.status_code == 200 + assert delete_response.json()["data"]["payload"] == {"name": skill_name} + + +@pytest.mark.asyncio +async def test_v1_safe_persona_routes_accept_slash_ids( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, +): + persona_id = "persona/foo" + headers = _jwt_headers() + persona_mgr = fake_core_lifecycle.persona_mgr + + detail_response = await asgi_client.get( + "/api/v1/personas/by-id", + params={"persona_id": persona_id}, + headers=headers, + ) + update_response = await asgi_client.put( + "/api/v1/personas/by-id", + json={"persona_id": persona_id, "name": "Demo Persona"}, + headers=headers, + ) + delete_response = await asgi_client.delete( + "/api/v1/personas/by-id", + params={"persona_id": persona_id}, + headers=headers, + ) + + assert detail_response.status_code == 200 + assert detail_response.json()["data"]["persona_id"] == persona_id + assert detail_response.json()["data"]["system_prompt"] == "Demo persona" + assert update_response.status_code == 200 + assert update_response.json()["data"] == {"message": "人格更新成功"} + assert delete_response.status_code == 200 + assert delete_response.json()["data"] == {"message": "人格删除成功"} + assert persona_id not in persona_mgr.personas + + +@pytest.mark.asyncio +async def test_v1_im_routes_use_im_scope_and_running_platform( + asgi_client: httpx.AsyncClient, + fake_core_lifecycle, + fake_db: FakeDb, +): + raw_key = "abk_fastapi_v1_im" + fake_db.add_api_key(raw_key, scopes=["im"]) + + bots_response = await asgi_client.get( + "/api/v1/im/bots", + headers={"X-API-Key": raw_key}, + ) + send_response = await asgi_client.post( + "/api/v1/im/messages", + json={ + "umo": "webchat-main:FriendMessage:test-session", + "message": "hello", + }, + headers={"X-API-Key": raw_key}, + ) + + assert bots_response.status_code == 200 + assert send_response.status_code == 200 + assert bots_response.json()["data"]["bot_ids"] == ["webchat-main"] + sent_messages = fake_core_lifecycle.platform_manager.fake_platform.sent_messages + assert len(sent_messages) == 1 + session, message_chain = sent_messages[0] + assert str(session) == "webchat-main:FriendMessage:test-session" + assert message_chain.chain[0].text == "hello" + + +@pytest.mark.asyncio +async def test_v1_platform_webhook_is_public_route( + asgi_client: httpx.AsyncClient, +): + response = await asgi_client.post( + "/api/v1/webhooks/platforms/demo-hook", + json={"challenge": "ping"}, + ) + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["data"] == { + "webhook_uuid": "demo-hook", + "method": "POST", + "payload": {"challenge": "ping"}, + } diff --git a/tests/test_kb_import.py b/tests/test_kb_import.py index 8795b06da1..702c5a7286 100644 --- a/tests/test_kb_import.py +++ b/tests/test_kb_import.py @@ -3,7 +3,6 @@ import pytest import pytest_asyncio -from quart import Quart from astrbot.core import LogBroker from astrbot.core.core_lifecycle import AstrBotCoreLifecycle @@ -13,10 +12,11 @@ from astrbot.core.knowledge_base.models import KBDocument from astrbot.core.utils.auth_password import ( hash_dashboard_password, - hash_legacy_dashboard_password, + hash_md5_dashboard_password, ) -from astrbot.dashboard.routes.knowledge_base import KnowledgeBaseRoute +from astrbot.dashboard.asgi_runtime import FastAPIAppAdapter from astrbot.dashboard.server import AstrBotDashboard +from astrbot.dashboard.services.knowledge_base_service import KnowledgeBaseService _TEST_DASHBOARD_PASSWORD = "AstrbotTest123" @@ -63,7 +63,7 @@ async def core_lifecycle_td(tmp_path_factory): hash_dashboard_password(dashboard_password) ) core_lifecycle.astrbot_config["dashboard"]["password"] = ( - hash_legacy_dashboard_password(dashboard_password) + hash_md5_dashboard_password(dashboard_password) ) object.__setattr__( core_lifecycle, @@ -84,7 +84,7 @@ async def core_lifecycle_td(tmp_path_factory): @pytest.fixture(scope="module") def app(core_lifecycle_td: AstrBotCoreLifecycle): - """Creates a Quart app instance for testing.""" + """Creates a FastAPIAppAdapter app instance for testing.""" shutdown_event = asyncio.Event() server = AstrBotDashboard(core_lifecycle_td, core_lifecycle_td.db, shutdown_event) return server.app @@ -101,7 +101,9 @@ def _resolve_dashboard_password(core_lifecycle_td: AstrBotCoreLifecycle) -> str: @pytest_asyncio.fixture(scope="module") -async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecycle): +async def authenticated_header( + app: FastAPIAppAdapter, core_lifecycle_td: AstrBotCoreLifecycle +): """Handles login and returns an authenticated header.""" test_client = app.test_client() response = await test_client.post( @@ -119,7 +121,9 @@ async def authenticated_header(app: Quart, core_lifecycle_td: AstrBotCoreLifecyc @pytest.mark.asyncio async def test_import_documents( - app: Quart, authenticated_header: dict, core_lifecycle_td: AstrBotCoreLifecycle + app: FastAPIAppAdapter, + authenticated_header: dict, + core_lifecycle_td: AstrBotCoreLifecycle, ): """Tests the import documents functionality.""" test_client = app.test_client() @@ -198,12 +202,12 @@ async def test_import_documents_returns_friendly_failure_message( details={"expected_contents": 2, "actual_vectors": 1}, ) - route = KnowledgeBaseRoute.__new__(KnowledgeBaseRoute) - route.upload_progress = {} - route.upload_tasks = {} + service = KnowledgeBaseService.__new__(KnowledgeBaseService) + service.upload_progress = {} + service.upload_tasks = {} - await KnowledgeBaseRoute._background_import_task( - route, + await KnowledgeBaseService.background_import_task( + service, task_id="task-1", kb_helper=kb_helper, documents=[{"file_name": "broken.txt", "chunks": ["chunk1", "chunk2"]}], @@ -212,8 +216,8 @@ async def test_import_documents_returns_friendly_failure_message( max_retries=3, ) - assert route.upload_tasks["task-1"]["status"] == "completed" - result = route.upload_tasks["task-1"]["result"] + assert service.upload_tasks["task-1"]["status"] == "completed" + result = service.upload_tasks["task-1"]["result"] assert result["success_count"] == 0 assert result["failed_count"] == 1 assert result["failed"][0]["file_name"] == "broken.txt" @@ -227,7 +231,9 @@ async def test_import_documents_returns_friendly_failure_message( @pytest.mark.asyncio -async def test_import_documents_invalid_input(app: Quart, authenticated_header: dict): +async def test_import_documents_invalid_input( + app: FastAPIAppAdapter, authenticated_header: dict +): """Tests import documents with invalid input.""" test_client = app.test_client() diff --git a/tests/unit/test_dashboard_util.py b/tests/unit/test_dashboard_util.py index 48a10020af..8458c86f05 100644 --- a/tests/unit/test_dashboard_util.py +++ b/tests/unit/test_dashboard_util.py @@ -1,7 +1,6 @@ """Tests for dashboard route utility helpers.""" -from astrbot.dashboard.routes.config import validate_config -from astrbot.dashboard.routes.util import get_schema_item +from astrbot.dashboard.services.config_service import get_schema_item, validate_config def test_get_schema_item_template_list_file_item(): diff --git a/tests/unit/test_live_chat_service.py b/tests/unit/test_live_chat_service.py new file mode 100644 index 0000000000..031c30458f --- /dev/null +++ b/tests/unit/test_live_chat_service.py @@ -0,0 +1,142 @@ +from types import SimpleNamespace + +import pytest +from starlette.websockets import WebSocketDisconnect + +from astrbot.dashboard.services.live_chat_service import LiveChatService + + +def _service() -> LiveChatService: + core_lifecycle = SimpleNamespace( + astrbot_config={"dashboard": {"jwt_secret": "test-secret"}}, + plugin_manager=SimpleNamespace(), + platform_message_history_manager=SimpleNamespace(), + ) + return LiveChatService(SimpleNamespace(), core_lifecycle) + + +@pytest.mark.asyncio +async def test_run_websocket_session_closes_when_token_is_missing(): + service = _service() + closed: list[tuple[int, str]] = [] + + async def close(code: int, reason: str) -> None: + closed.append((code, reason)) + + async def receive_json() -> dict: + raise AssertionError("receive_json should not be called") + + async def send_json(payload: dict) -> None: + raise AssertionError(f"send_json should not be called: {payload}") + + await service.run_websocket_session( + token=None, + force_ct=None, + receive_json=receive_json, + send_json=send_json, + close=close, + ) + + assert closed == [(1008, "Missing authentication token")] + assert service.sessions == {} + + +@pytest.mark.asyncio +async def test_run_websocket_session_routes_messages_and_cleans_session(monkeypatch): + service = _service() + messages = iter( + [ + {"ct": "chat", "t": "bind", "session_id": "chat-session"}, + {"t": "start_speaking", "stamp": "s1"}, + ] + ) + routed: list[tuple[str, str, dict]] = [] + + monkeypatch.setattr(service, "authenticate_token", lambda _token: "alice") + + async def handle_chat_message(session, message, _send_json) -> None: + routed.append(("chat", session.username, message)) + + async def handle_live_message(session, message, _send_json) -> None: + routed.append(("live", session.username, message)) + + monkeypatch.setattr(service, "handle_chat_message", handle_chat_message) + monkeypatch.setattr(service, "handle_live_message", handle_live_message) + + async def receive_json() -> dict: + try: + return next(messages) + except StopIteration as exc: + raise RuntimeError("disconnect") from exc + + async def send_json(_payload: dict) -> None: + pass + + async def close(_code: int, _reason: str) -> None: + raise AssertionError("close should not be called") + + await service.run_websocket_session( + token="valid", + force_ct=None, + receive_json=receive_json, + send_json=send_json, + close=close, + ) + + assert [(kind, username) for kind, username, _ in routed] == [ + ("chat", "alice"), + ("live", "alice"), + ] + assert service.sessions == {} + + +@pytest.mark.asyncio +async def test_run_websocket_session_handles_disconnect_without_error_log( + monkeypatch, +): + service = _service() + messages = iter([{"ct": "chat", "t": "bind", "session_id": "chat-session"}]) + routed: list[dict] = [] + + monkeypatch.setattr(service, "authenticate_token", lambda _token: "alice") + + async def handle_chat_message(session, message, _send_json) -> None: + routed.append({"username": session.username, "message": message}) + + monkeypatch.setattr(service, "handle_chat_message", handle_chat_message) + + async def receive_json() -> dict: + try: + return next(messages) + except StopIteration as exc: + raise WebSocketDisconnect(1006) from exc + + async def send_json(_payload: dict) -> None: + pass + + async def close(_code: int, _reason: str) -> None: + raise AssertionError("close should not be called") + + def fail_error_log(*_args, **_kwargs) -> None: + raise AssertionError("disconnect should not be logged as an error") + + monkeypatch.setattr( + "astrbot.dashboard.services.live_chat_service.logger.error", + fail_error_log, + ) + + await service.run_websocket_session( + token="valid", + force_ct=None, + receive_json=receive_json, + send_json=send_json, + close=close, + ) + + assert routed == [ + { + "username": "alice", + "message": {"ct": "chat", "t": "bind", "session_id": "chat-session"}, + } + ] + assert service.sessions == {} diff --git a/tests/unit/test_open_api_service_ws.py b/tests/unit/test_open_api_service_ws.py new file mode 100644 index 0000000000..8c44e74905 --- /dev/null +++ b/tests/unit/test_open_api_service_ws.py @@ -0,0 +1,133 @@ +from types import SimpleNamespace + +import pytest + +from astrbot.dashboard.services.open_api_service import ( + OpenApiService, + OpenApiWebSocketChatBridge, +) + + +def _service() -> OpenApiService: + core_lifecycle = SimpleNamespace( + platform_manager=SimpleNamespace(platform_insts=[]), + platform_message_history_manager=None, + ) + return OpenApiService(SimpleNamespace(), core_lifecycle) + + +def _bridge() -> OpenApiWebSocketChatBridge: + async def build_user_message_parts(_message): + return [] + + async def create_attachment_from_file(_filename, _attach_type): + return None + + async def insert_user_message(_session_id, _effective_username, _message_parts): + pass + + async def save_bot_message(_session_id, _message_parts, _agent_stats, _refs): + return None + + return OpenApiWebSocketChatBridge( + build_user_message_parts=build_user_message_parts, + create_attachment_from_file=create_attachment_from_file, + extract_web_search_refs=lambda _text, _parts: {}, + insert_user_message=insert_user_message, + save_bot_message=save_bot_message, + ) + + +@pytest.mark.asyncio +async def test_run_chat_websocket_closes_when_api_key_is_invalid(monkeypatch): + service = _service() + sent: list[dict] = [] + closed: list[tuple[int, str]] = [] + + async def authenticate_api_key(_raw_key): + return False, "Invalid API key" + + monkeypatch.setattr(service, "authenticate_api_key", authenticate_api_key) + + async def receive_json(): + raise AssertionError("receive_json should not be called") + + async def send_json(payload: dict) -> None: + sent.append(payload) + + async def close(code: int, reason: str) -> None: + closed.append((code, reason)) + + await service.run_chat_websocket( + raw_api_key="bad", + receive_json=receive_json, + send_json=send_json, + close=close, + conf_list=[], + chat_bridge=_bridge(), + ) + + assert sent == [ + {"type": "error", "code": "UNAUTHORIZED", "data": "Invalid API key"} + ] + assert closed == [(1008, "Invalid API key")] + + +@pytest.mark.asyncio +async def test_run_chat_websocket_handles_control_messages(monkeypatch): + service = _service() + messages = iter( + [ + ["not", "an", "object"], + {"t": "ping"}, + {"t": "unknown"}, + {"t": "send", "message": "hello"}, + ] + ) + sent: list[dict] = [] + handled: list[dict] = [] + + async def authenticate_api_key(_raw_key): + return True, None + + async def handle_chat_ws_send(**kwargs): + handled.append(kwargs["post_data"]) + + monkeypatch.setattr(service, "authenticate_api_key", authenticate_api_key) + monkeypatch.setattr(service, "handle_chat_ws_send", handle_chat_ws_send) + + async def receive_json(): + try: + return next(messages) + except StopIteration as exc: + raise RuntimeError("disconnect") from exc + + async def send_json(payload: dict) -> None: + sent.append(payload) + + async def close(_code: int, _reason: str) -> None: + raise AssertionError("close should not be called") + + await service.run_chat_websocket( + raw_api_key="good", + receive_json=receive_json, + send_json=send_json, + close=close, + conf_list=[], + chat_bridge=_bridge(), + ) + + assert sent == [ + { + "type": "error", + "code": "INVALID_MESSAGE", + "data": "message must be an object", + }, + {"type": "pong"}, + { + "type": "error", + "code": "INVALID_MESSAGE", + "data": "Unsupported message type: unknown", + }, + ] + assert handled == [{"t": "send", "message": "hello"}] diff --git a/tests/unit/test_tool_permission.py b/tests/unit/test_tool_permission.py index 1a3a8a376b..5d8581023e 100644 --- a/tests/unit/test_tool_permission.py +++ b/tests/unit/test_tool_permission.py @@ -1,36 +1,20 @@ """Tests for per-tool permission management.""" -import asyncio -import json -from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch +from types import SimpleNamespace import pytest from astrbot.core import sp -from astrbot.core.agent.tool import FunctionTool, ToolSet +from astrbot.core.agent.tool import FunctionTool from astrbot.core.provider.func_tool_manager import ( FunctionToolManager, _PermissionGuardedTool, ) -from astrbot.core.tools.computer_tools.shell import ExecuteShellTool - +from astrbot.dashboard.services.tools_service import ToolsService, ToolsServiceError # ── helpers ────────────────────────────────────────────────────────── -def _make_coro(value: object): - """Return a fresh coroutine object that resolves to ``value``. - - Used to mock Quart's ``await request.json`` where ``json`` is an - async property — assigning a coroutine lets ``await`` work directly.""" - - async def _inner(): - return value - - return _inner() - - def _make_context(role: str = "member", sender_id: str = "user_123"): """Return a mock context object suitable for tool permission checks.""" @@ -70,6 +54,16 @@ def _clear_tool_permissions() -> None: sp.put("tool_permissions", {}, scope="global", scope_id="global") +def _make_tools_service() -> ToolsService: + tool_mgr = FunctionToolManager() + config_mgr = SimpleNamespace(get_conf_list=lambda: [], confs={}) + core_lifecycle = SimpleNamespace( + provider_manager=SimpleNamespace(llm_tools=tool_mgr), + astrbot_config_mgr=config_mgr, + ) + return ToolsService(core_lifecycle) + + # ── _default_permission ────────────────────────────────────────────── @@ -298,26 +292,17 @@ def test_get_full_tool_set_wraps_non_builtin(): plugin_tools = [t for t in tool_set.tools if t.name == "my_plugin_tool"] assert plugin_tools - assert isinstance( - plugin_tools[0], _PermissionGuardedTool - ), "non-builtin tools must be wrapped" + assert isinstance(plugin_tools[0], _PermissionGuardedTool), ( + "non-builtin tools must be wrapped" + ) # ── API: get_tool_list permission fields ────────────────────────────── class TestGetToolListPermission: - @pytest.mark.asyncio - async def test_list_includes_permission_fields_for_non_builtin(self): - from astrbot.dashboard.routes.tools import ToolsRoute - - # Minimal stubs to avoid full core lifecycle init - route = ToolsRoute.__new__(ToolsRoute) - route.core_lifecycle = MagicMock() - route.core_lifecycle.astrbot_config_mgr = MagicMock() - route.core_lifecycle.astrbot_config_mgr.get_conf_list.return_value = [] - route.core_lifecycle.astrbot_config_mgr.confs = {} - route.tool_mgr = FunctionToolManager() + def test_list_includes_permission_fields_for_non_builtin(self): + service = _make_tools_service() sp.put( "tool_permissions", @@ -326,10 +311,8 @@ async def test_list_includes_permission_fields_for_non_builtin(self): scope_id="global", ) try: - route.tool_mgr.func_list.append(_dummy_tool("my_plugin_tool")) - resp = await route.get_tool_list() - data = json.loads(json.dumps(resp)) # simulate json serialisation - tools = data["data"] + service.tool_mgr.func_list.append(_dummy_tool("my_plugin_tool")) + tools = service.get_tool_list() target = next(t for t in tools if t["name"] == "my_plugin_tool") assert target["permission"] == "admin" @@ -338,20 +321,9 @@ async def test_list_includes_permission_fields_for_non_builtin(self): finally: _clear_tool_permissions() - @pytest.mark.asyncio - async def test_list_no_permission_fields_for_builtin(self): - from astrbot.dashboard.routes.tools import ToolsRoute - - route = ToolsRoute.__new__(ToolsRoute) - route.core_lifecycle = MagicMock() - route.core_lifecycle.astrbot_config_mgr = MagicMock() - route.core_lifecycle.astrbot_config_mgr.get_conf_list.return_value = [] - route.core_lifecycle.astrbot_config_mgr.confs = {} - route.tool_mgr = FunctionToolManager() - - resp = await route.get_tool_list() - data = json.loads(json.dumps(resp)) - tools = data["data"] + def test_list_no_permission_fields_for_builtin(self): + service = _make_tools_service() + tools = service.get_tool_list() target = next(t for t in tools if t["name"] == "astrbot_execute_shell") assert "permission" not in target @@ -363,72 +335,40 @@ async def test_list_no_permission_fields_for_builtin(self): class TestUpdateToolPermission: - @pytest.mark.asyncio - async def test_set_admin_permission(self): - from astrbot.dashboard.routes.tools import ToolsRoute - - route = ToolsRoute.__new__(ToolsRoute) - route.core_lifecycle = MagicMock() - route.tool_mgr = FunctionToolManager() - route.tool_mgr.func_list.append(_dummy_tool("target_tool")) + def test_set_admin_permission(self): + service = _make_tools_service() + service.tool_mgr.func_list.append(_dummy_tool("target_tool")) _clear_tool_permissions() - mock_req = MagicMock() - mock_req.json = _make_coro({"name": "target_tool", "permission": "admin"}) - with patch("astrbot.dashboard.routes.tools.request", mock_req): - resp = await route.update_tool_permission() - data = json.loads(json.dumps(resp)) - assert data["status"] == "ok" - - stored = sp.get( - "tool_permissions", {}, scope="global", scope_id="global" + message = service.update_tool_permission( + {"name": "target_tool", "permission": "admin"} ) + assert "target_tool" in message + + stored = sp.get("tool_permissions", {}, scope="global", scope_id="global") assert stored["_default"]["target_tool"] == "admin" - @pytest.mark.asyncio - async def test_reject_builtin_tool(self): - from astrbot.dashboard.routes.tools import ToolsRoute - - route = ToolsRoute.__new__(ToolsRoute) - route.core_lifecycle = MagicMock() - route.tool_mgr = FunctionToolManager() - - mock_req = MagicMock() - mock_req.json = _make_coro({"name": "astrbot_execute_shell", "permission": "admin"}) - with patch("astrbot.dashboard.routes.tools.request", mock_req): - resp = await route.update_tool_permission() - data = json.loads(json.dumps(resp)) - assert data["status"] == "error" - assert "builtin" in str(data["message"]).lower() - - @pytest.mark.asyncio - async def test_reject_unknown_tool(self): - from astrbot.dashboard.routes.tools import ToolsRoute - - route = ToolsRoute.__new__(ToolsRoute) - route.core_lifecycle = MagicMock() - route.tool_mgr = FunctionToolManager() - - mock_req = MagicMock() - mock_req.json = _make_coro({"name": "ghost_tool", "permission": "admin"}) - with patch("astrbot.dashboard.routes.tools.request", mock_req): - resp = await route.update_tool_permission() - data = json.loads(json.dumps(resp)) - assert data["status"] == "error" - assert "not found" in str(data["message"]).lower() - - @pytest.mark.asyncio - async def test_reject_invalid_permission_value(self): - from astrbot.dashboard.routes.tools import ToolsRoute - - route = ToolsRoute.__new__(ToolsRoute) - route.core_lifecycle = MagicMock() - route.tool_mgr = FunctionToolManager() - route.tool_mgr.func_list.append(_dummy_tool("target_tool")) - - mock_req = MagicMock() - mock_req.json = _make_coro({"name": "target_tool", "permission": "everyone"}) - with patch("astrbot.dashboard.routes.tools.request", mock_req): - resp = await route.update_tool_permission() - data = json.loads(json.dumps(resp)) - assert data["status"] == "error" + def test_reject_builtin_tool(self): + service = _make_tools_service() + + with pytest.raises(ToolsServiceError, match="Builtin tools"): + service.update_tool_permission( + {"name": "astrbot_execute_shell", "permission": "admin"} + ) + + def test_reject_unknown_tool(self): + service = _make_tools_service() + + with pytest.raises(ToolsServiceError, match="not found"): + service.update_tool_permission( + {"name": "ghost_tool", "permission": "admin"} + ) + + def test_reject_invalid_permission_value(self): + service = _make_tools_service() + service.tool_mgr.func_list.append(_dummy_tool("target_tool")) + + with pytest.raises(ToolsServiceError, match="permission"): + service.update_tool_permission( + {"name": "target_tool", "permission": "everyone"} + ) diff --git a/tests/unit/test_upload_filename_sanitization.py b/tests/unit/test_upload_filename_sanitization.py index 88374669ec..ddb031a1eb 100644 --- a/tests/unit/test_upload_filename_sanitization.py +++ b/tests/unit/test_upload_filename_sanitization.py @@ -1,22 +1,22 @@ """Tests for upload filename sanitization.""" -from astrbot.dashboard.routes.chat import _sanitize_upload_filename +from astrbot.dashboard.services.chat_service import sanitize_upload_filename def test_sanitize_upload_filename_strips_posix_traversal(): - assert _sanitize_upload_filename("../../outside.txt") == "outside.txt" + assert sanitize_upload_filename("../../outside.txt") == "outside.txt" def test_sanitize_upload_filename_strips_windows_traversal(): - assert _sanitize_upload_filename(r"..\\..\\outside.txt") == "outside.txt" + assert sanitize_upload_filename(r"..\\..\\outside.txt") == "outside.txt" def test_sanitize_upload_filename_strips_fakepath(): - assert _sanitize_upload_filename(r"C:\\fakepath\\photo.png") == "photo.png" + assert sanitize_upload_filename(r"C:\\fakepath\\photo.png") == "photo.png" def test_sanitize_upload_filename_falls_back_for_empty_values(): - generated = _sanitize_upload_filename("") + generated = sanitize_upload_filename("") assert generated assert generated not in {".", ".."} @@ -25,8 +25,7 @@ def test_sanitize_upload_filename_falls_back_for_empty_values(): def test_sanitize_upload_filename_removes_embedded_null_bytes(): - assert _sanitize_upload_filename("evil\x00.txt") == "evil.txt" - assert _sanitize_upload_filename("\x00leading.txt") == "leading.txt" - assert _sanitize_upload_filename("trailing\x00.txt\x00") == "trailing.txt" - assert _sanitize_upload_filename("mid\x00dle.txt") == "middle.txt" - + assert sanitize_upload_filename("evil\x00.txt") == "evil.txt" + assert sanitize_upload_filename("\x00leading.txt") == "leading.txt" + assert sanitize_upload_filename("trailing\x00.txt\x00") == "trailing.txt" + assert sanitize_upload_filename("mid\x00dle.txt") == "middle.txt"