diff --git a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py index dc6977cf8a..aadde10ac5 100644 --- a/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py +++ b/astrbot/core/db/vec_db/faiss_impl/embedding_storage.py @@ -20,6 +20,21 @@ def __init__(self, dimension: int, path: str | None = None) -> None: base_index = faiss.IndexFlatL2(dimension) self.index = faiss.IndexIDMap(base_index) + def _add_with_ids(self, vectors: np.ndarray, ids: np.ndarray) -> None: + assert self.index is not None, "FAISS index is not initialized." + vectors = np.ascontiguousarray(vectors, dtype=np.float32) + ids = np.ascontiguousarray(ids, dtype=np.int64) + try: + self.index.add_with_ids(vectors, ids) + except TypeError as exc: + if "missing" not in str(exc): + raise + self.index.add_with_ids( + vectors.shape[0], + faiss.swig_ptr(vectors), + faiss.swig_ptr(ids), + ) + async def insert(self, vector: np.ndarray, id: int) -> None: """插入向量 @@ -35,7 +50,7 @@ async def insert(self, vector: np.ndarray, id: int) -> None: raise ValueError( f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}", ) - self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) + self._add_with_ids(vector.reshape(1, -1), np.array([id], dtype=np.int64)) await self.save_index() async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: @@ -53,7 +68,7 @@ async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: raise ValueError( f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}", ) - self.index.add_with_ids(vectors, np.array(ids)) + self._add_with_ids(vectors, np.array(ids, dtype=np.int64)) await self.save_index() async def search(self, vector: np.ndarray, k: int) -> tuple: @@ -67,8 +82,24 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: """ assert self.index is not None, "FAISS index is not initialized." + vector = np.ascontiguousarray(vector, dtype=np.float32) + if vector.ndim == 1: + vector = vector.reshape(1, -1) faiss.normalize_L2(vector) - distances, indices = self.index.search(vector, k) + try: + distances, indices = self.index.search(vector, k) + except TypeError as exc: + if "missing" not in str(exc): + raise + distances = np.empty((vector.shape[0], k), dtype=np.float32) + indices = np.empty((vector.shape[0], k), dtype=np.int64) + self.index.search( + vector.shape[0], + faiss.swig_ptr(vector), + k, + faiss.swig_ptr(distances), + faiss.swig_ptr(indices), + ) return distances, indices async def delete(self, ids: list[int]) -> None: @@ -80,7 +111,13 @@ async def delete(self, ids: list[int]) -> None: """ assert self.index is not None, "FAISS index is not initialized." id_array = np.array(ids, dtype=np.int64) - self.index.remove_ids(id_array) + try: + self.index.remove_ids(id_array) + except TypeError as exc: + if "IDSelector" not in str(exc): + raise + selector = faiss.IDSelectorBatch(id_array.size, faiss.swig_ptr(id_array)) + self.index.remove_ids(selector) await self.save_index() async def save_index(self) -> None: diff --git a/astrbot/core/provider/entities.py b/astrbot/core/provider/entities.py index 8e12683ffb..54157377dc 100644 --- a/astrbot/core/provider/entities.py +++ b/astrbot/core/provider/entities.py @@ -452,18 +452,33 @@ def to_openai_tool_calls(self) -> list[dict]: """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead.""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): + if idx >= len(self.tools_call_name): + logger.warning( + "Skipping tool call argument without matching tool name at index %s.", + idx, + ) + break + tool_name = self.tools_call_name[idx] + if not isinstance(tool_name, str) or not tool_name.strip(): + logger.warning("Skipping malformed tool call with empty tool name.") + continue + tool_call_id = ( + self.tools_call_ids[idx] + if idx < len(self.tools_call_ids) + and isinstance(self.tools_call_ids[idx], str) + and self.tools_call_ids[idx].strip() + else f"call_{uuid.uuid4().hex}" + ) payload = { - "id": self.tools_call_ids[idx], + "id": tool_call_id, "function": { - "name": self.tools_call_name[idx], + "name": tool_name.strip(), "arguments": json.dumps(tool_call_arg), }, "type": "function", } - if self.tools_call_extra_content.get(self.tools_call_ids[idx]): - payload["extra_content"] = self.tools_call_extra_content[ - self.tools_call_ids[idx] - ] + if self.tools_call_extra_content.get(tool_call_id): + payload["extra_content"] = self.tools_call_extra_content[tool_call_id] ret.append(payload) return ret @@ -471,17 +486,32 @@ def to_openai_to_calls_model(self) -> list[ToolCall]: """The same as to_openai_tool_calls but return pydantic model.""" ret = [] for idx, tool_call_arg in enumerate(self.tools_call_args): + if idx >= len(self.tools_call_name): + logger.warning( + "Skipping tool call argument without matching tool name at index %s.", + idx, + ) + break + tool_name = self.tools_call_name[idx] + if not isinstance(tool_name, str) or not tool_name.strip(): + logger.warning("Skipping malformed tool call with empty tool name.") + continue + tool_call_id = ( + self.tools_call_ids[idx] + if idx < len(self.tools_call_ids) + and isinstance(self.tools_call_ids[idx], str) + and self.tools_call_ids[idx].strip() + else f"call_{uuid.uuid4().hex}" + ) ret.append( ToolCall( - id=self.tools_call_ids[idx], + id=tool_call_id, function=ToolCall.FunctionBody( - name=self.tools_call_name[idx], + name=tool_name.strip(), arguments=json.dumps(tool_call_arg), ), # the extra_content will not serialize if it's None when calling ToolCall.model_dump() - extra_content=self.tools_call_extra_content.get( - self.tools_call_ids[idx] - ), + extra_content=self.tools_call_extra_content.get(tool_call_id), ), ) return ret diff --git a/astrbot/core/provider/sources/openai_source.py b/astrbot/core/provider/sources/openai_source.py index 8aa2778f1b..44a748df82 100644 --- a/astrbot/core/provider/sources/openai_source.py +++ b/astrbot/core/provider/sources/openai_source.py @@ -642,6 +642,19 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse: extra_body=extra_body, ) + if isinstance(completion, str): + text = self._normalize_content(completion) + if text: + logger.warning( + "OpenAI-compatible provider returned raw string completion; normalized it." + ) + return LLMResponse("assistant", completion_text=text) + raise EmptyModelOutputError( + "OpenAI-compatible provider returned an empty string completion." + ) + if isinstance(completion, dict): + completion = ChatCompletion.model_validate(completion) + if not isinstance(completion, ChatCompletion): raise Exception( f"API 返回的 completion 类型错误:{type(completion)}: {completion}。", @@ -721,7 +734,10 @@ async def _query_stream( try: state.handle_chunk(chunk) except Exception as e: - logger.error("Saving chunk state error: " + str(e)) + logger.warning( + f"Saving chunk state skipped for chunk {chunk!r}: {e}", + exc_info=True, + ) # logger.debug(f"chunk delta: {delta}") # handle the content delta reasoning = self._extract_reasoning_content(chunk) @@ -944,27 +960,78 @@ async def _parse_openai_completion( # Should be unreachable raise Exception("工具集未提供") - if tool_call.type == "function": + is_dict_tool_call = isinstance(tool_call, dict) + tool_call_type = ( + tool_call.get("type") + if is_dict_tool_call + else getattr(tool_call, "type", None) + ) + if tool_call_type == "function": + tool_call_function = ( + tool_call.get("function") + if is_dict_tool_call + else getattr(tool_call, "function", None) + ) + func_name = ( + tool_call_function.get("name") + if isinstance(tool_call_function, dict) + else getattr(tool_call_function, "name", None) + ) + if not isinstance(func_name, str) or not func_name.strip(): + logger.warning( + "Skipping malformed tool call with empty function name: %s", + tool_call, + ) + continue + func_name = func_name.strip() + tool_call_id = ( + tool_call.get("id") + if is_dict_tool_call + else getattr(tool_call, "id", None) + ) + if not isinstance(tool_call_id, str) or not tool_call_id.strip(): + tool_call_id = f"call_{uuid.uuid4().hex}" + logger.warning( + "Generated missing tool_call id for %s: %s", + func_name, + tool_call_id, + ) # workaround for #1454 - if isinstance(tool_call.function.arguments, str): + tool_call_arguments = ( + tool_call_function.get("arguments") + if isinstance(tool_call_function, dict) + else getattr(tool_call_function, "arguments", None) + ) + if isinstance(tool_call_arguments, str): try: - args = json.loads(tool_call.function.arguments) + args = json.loads(tool_call_arguments) except json.JSONDecodeError as e: - logger.error(f"解析参数失败: {e}") + logger.warning(f"解析参数失败: {e}") args = {} else: - args = tool_call.function.arguments + args = tool_call_arguments # Some API may return None for tools with no parameters if args is None: args = {} + if not isinstance(args, dict): + logger.warning( + "Tool call arguments for %s are not an object: %s", + func_name, + type(args).__name__, + ) + args = {} args_ls.append(args) - func_name_ls.append(tool_call.function.name) - tool_call_ids.append(tool_call.id) + func_name_ls.append(func_name) + tool_call_ids.append(tool_call_id) # gemini-2.5 / gemini-3 series extra_content handling - extra_content = getattr(tool_call, "extra_content", None) + extra_content = ( + tool_call.get("extra_content") + if is_dict_tool_call + else getattr(tool_call, "extra_content", None) + ) if extra_content is not None: - tool_call_extra_content_dict[tool_call.id] = extra_content + tool_call_extra_content_dict[tool_call_id] = extra_content llm_response.role = "tool" llm_response.tools_call_args = args_ls diff --git a/astrbot/core/utils/requirements_utils.py b/astrbot/core/utils/requirements_utils.py index 969976a4fc..d12c3e3348 100644 --- a/astrbot/core/utils/requirements_utils.py +++ b/astrbot/core/utils/requirements_utils.py @@ -325,9 +325,8 @@ def get_requirement_check_paths() -> list[str]: def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]: - distribution_name = ( - distribution.metadata["Name"] if "Name" in distribution.metadata else None - ) + metadata = distribution.metadata + distribution_name = metadata.get("Name") if metadata else None if not distribution_name: return None, None return canonicalize_distribution_name(distribution_name), distribution.version @@ -337,7 +336,13 @@ def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str] installed: dict[str, str] = {} try: for distribution in importlib_metadata.distributions(path=paths): - distribution_name, version = _canonical_distribution_identity(distribution) + try: + distribution_name, version = _canonical_distribution_identity( + distribution + ) + except Exception as exc: + logger.debug("Skipping unreadable distribution metadata: %s", exc) + continue if not distribution_name or not version: continue installed.setdefault(distribution_name, version)