Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 41 additions & 4 deletions astrbot/core/db/vec_db/faiss_impl/embedding_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ def __init__(self, dimension: int, path: str | None = None) -> None:
base_index = faiss.IndexFlatL2(dimension)
self.index = faiss.IndexIDMap(base_index)

def _add_with_ids(self, vectors: np.ndarray, ids: np.ndarray) -> None:
assert self.index is not None, "FAISS index is not initialized."
vectors = np.ascontiguousarray(vectors, dtype=np.float32)
ids = np.ascontiguousarray(ids, dtype=np.int64)
try:
self.index.add_with_ids(vectors, ids)
except TypeError as exc:
if "missing" not in str(exc):
raise
self.index.add_with_ids(
vectors.shape[0],
faiss.swig_ptr(vectors),
faiss.swig_ptr(ids),
)
Comment on lines +27 to +36

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): FAISS search fallback also depends on a specific error message and assumes vector dtype/shape.

The search fallback currently relies on the exact substring "missing 3 required positional arguments", which is tied to a specific FAISS version and may break if the error message changes. It also calls the pointer-based overload without first ensuring vector is contiguous float32, so non-contiguous or non-float32 inputs could cause faiss.swig_ptr misuse. Consider mirroring _add_with_ids by normalizing via np.ascontiguousarray(vector, dtype=np.float32) and using a more robust way to detect the legacy signature than matching a hard-coded error string.


async def insert(self, vector: np.ndarray, id: int) -> None:
"""插入向量

Expand All @@ -35,7 +50,7 @@ async def insert(self, vector: np.ndarray, id: int) -> None:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}",
)
self.index.add_with_ids(vector.reshape(1, -1), np.array([id]))
self._add_with_ids(vector.reshape(1, -1), np.array([id], dtype=np.int64))
await self.save_index()

async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
Expand All @@ -53,7 +68,7 @@ async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None:
raise ValueError(
f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}",
)
self.index.add_with_ids(vectors, np.array(ids))
self._add_with_ids(vectors, np.array(ids, dtype=np.int64))
await self.save_index()

async def search(self, vector: np.ndarray, k: int) -> tuple:
Expand All @@ -67,8 +82,24 @@ async def search(self, vector: np.ndarray, k: int) -> tuple:

"""
assert self.index is not None, "FAISS index is not initialized."
vector = np.ascontiguousarray(vector, dtype=np.float32)
if vector.ndim == 1:
vector = vector.reshape(1, -1)
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
try:
distances, indices = self.index.search(vector, k)
except TypeError as exc:
if "missing" not in str(exc):
raise
distances = np.empty((vector.shape[0], k), dtype=np.float32)
indices = np.empty((vector.shape[0], k), dtype=np.int64)
self.index.search(
vector.shape[0],
faiss.swig_ptr(vector),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(indices),
)
Comment on lines 88 to +102

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

Potential Segmentation Fault / Crash in SWIG Fallback

In the search method, if vector is a 1D array (e.g., shape (dimension,)), vector.shape[0] will represent the dimension of the vector (e.g., 1536) rather than the number of query vectors n (which should be 1).

When the SWIG fallback is triggered, passing vector.shape[0] as n to self.index.search will cause FAISS to read out of bounds and crash with a segmentation fault.

To prevent this, we should ensure vector is contiguous, float32, and reshaped to 2D (shape (1, dimension)) if it is 1D.

Suggested change
faiss.normalize_L2(vector)
distances, indices = self.index.search(vector, k)
try:
distances, indices = self.index.search(vector, k)
except TypeError as exc:
if "missing 3 required positional arguments" not in str(exc):
raise
distances = np.empty((vector.shape[0], k), dtype=np.float32)
indices = np.empty((vector.shape[0], k), dtype=np.int64)
self.index.search(
vector.shape[0],
faiss.swig_ptr(vector),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(indices),
)
vector = np.ascontiguousarray(vector, dtype=np.float32)
if vector.ndim == 1:
vector = vector.reshape(1, -1)
faiss.normalize_L2(vector)
try:
distances, indices = self.index.search(vector, k)
except TypeError as exc:
if "missing 3 required positional arguments" not in str(exc):
raise
distances = np.empty((vector.shape[0], k), dtype=np.float32)
indices = np.empty((vector.shape[0], k), dtype=np.int64)
self.index.search(
vector.shape[0],
faiss.swig_ptr(vector),
k,
faiss.swig_ptr(distances),
faiss.swig_ptr(indices),
)

return distances, indices

async def delete(self, ids: list[int]) -> None:
Expand All @@ -80,7 +111,13 @@ async def delete(self, ids: list[int]) -> None:
"""
assert self.index is not None, "FAISS index is not initialized."
id_array = np.array(ids, dtype=np.int64)
self.index.remove_ids(id_array)
try:
self.index.remove_ids(id_array)
except TypeError as exc:
if "IDSelector" not in str(exc):
raise
selector = faiss.IDSelectorBatch(id_array.size, faiss.swig_ptr(id_array))
self.index.remove_ids(selector)
await self.save_index()

async def save_index(self) -> None:
Expand Down
52 changes: 41 additions & 11 deletions astrbot/core/provider/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,36 +452,66 @@ def to_openai_tool_calls(self) -> list[dict]:
"""Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
if idx >= len(self.tools_call_name):
logger.warning(
"Skipping tool call argument without matching tool name at index %s.",
idx,
)
break
Comment on lines +455 to +460

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (bug_risk): Silently breaking when tools_call_args is longer than tools_call_name risks dropping tool calls.

By breaking when idx >= len(self.tools_call_name), any extra tools_call_args entries are silently discarded, which can mask upstream data inconsistencies. Consider emitting a warning (and similarly on len(self.tools_call_ids) mismatches) so inconsistent tool-call lists are detectable rather than silently losing tool invocations.

tool_name = self.tools_call_name[idx]
if not isinstance(tool_name, str) or not tool_name.strip():
logger.warning("Skipping malformed tool call with empty tool name.")
continue
tool_call_id = (
self.tools_call_ids[idx]
if idx < len(self.tools_call_ids)
and isinstance(self.tools_call_ids[idx], str)
and self.tools_call_ids[idx].strip()
else f"call_{uuid.uuid4().hex}"
)
payload = {
"id": self.tools_call_ids[idx],
"id": tool_call_id,
"function": {
"name": self.tools_call_name[idx],
"name": tool_name.strip(),
"arguments": json.dumps(tool_call_arg),
},
"type": "function",
}
if self.tools_call_extra_content.get(self.tools_call_ids[idx]):
payload["extra_content"] = self.tools_call_extra_content[
self.tools_call_ids[idx]
]
if self.tools_call_extra_content.get(tool_call_id):
payload["extra_content"] = self.tools_call_extra_content[tool_call_id]
ret.append(payload)
return ret

def to_openai_to_calls_model(self) -> list[ToolCall]:
"""The same as to_openai_tool_calls but return pydantic model."""
ret = []
for idx, tool_call_arg in enumerate(self.tools_call_args):
if idx >= len(self.tools_call_name):
logger.warning(
"Skipping tool call argument without matching tool name at index %s.",
idx,
)
break
tool_name = self.tools_call_name[idx]
if not isinstance(tool_name, str) or not tool_name.strip():
logger.warning("Skipping malformed tool call with empty tool name.")
continue
tool_call_id = (
self.tools_call_ids[idx]
if idx < len(self.tools_call_ids)
and isinstance(self.tools_call_ids[idx], str)
and self.tools_call_ids[idx].strip()
else f"call_{uuid.uuid4().hex}"
)
ret.append(
ToolCall(
id=self.tools_call_ids[idx],
id=tool_call_id,
function=ToolCall.FunctionBody(
name=self.tools_call_name[idx],
name=tool_name.strip(),
arguments=json.dumps(tool_call_arg),
),
# the extra_content will not serialize if it's None when calling ToolCall.model_dump()
extra_content=self.tools_call_extra_content.get(
self.tools_call_ids[idx]
),
extra_content=self.tools_call_extra_content.get(tool_call_id),
),
)
return ret
Expand Down
87 changes: 77 additions & 10 deletions astrbot/core/provider/sources/openai_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,19 @@ async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
extra_body=extra_body,
)

if isinstance(completion, str):
text = self._normalize_content(completion)
if text:
logger.warning(
"OpenAI-compatible provider returned raw string completion; normalized it."
)
return LLMResponse("assistant", completion_text=text)
raise EmptyModelOutputError(
"OpenAI-compatible provider returned an empty string completion."
)
if isinstance(completion, dict):
completion = ChatCompletion.model_validate(completion)

if not isinstance(completion, ChatCompletion):
raise Exception(
f"API 返回的 completion 类型错误:{type(completion)}: {completion}。",
Expand Down Expand Up @@ -721,7 +734,10 @@ async def _query_stream(
try:
state.handle_chunk(chunk)
except Exception as e:
logger.error("Saving chunk state error: " + str(e))
logger.warning(
f"Saving chunk state skipped for chunk {chunk!r}: {e}",
exc_info=True,
)
# logger.debug(f"chunk delta: {delta}")
# handle the content delta
reasoning = self._extract_reasoning_content(chunk)
Expand Down Expand Up @@ -944,27 +960,78 @@ async def _parse_openai_completion(
# Should be unreachable
raise Exception("工具集未提供")

if tool_call.type == "function":
is_dict_tool_call = isinstance(tool_call, dict)
tool_call_type = (
tool_call.get("type")
if is_dict_tool_call
else getattr(tool_call, "type", None)
)
if tool_call_type == "function":
tool_call_function = (
tool_call.get("function")
if is_dict_tool_call
else getattr(tool_call, "function", None)
)
func_name = (
tool_call_function.get("name")
if isinstance(tool_call_function, dict)
else getattr(tool_call_function, "name", None)
)
if not isinstance(func_name, str) or not func_name.strip():
logger.warning(
"Skipping malformed tool call with empty function name: %s",
tool_call,
)
continue
func_name = func_name.strip()
tool_call_id = (
tool_call.get("id")
if is_dict_tool_call
else getattr(tool_call, "id", None)
)
if not isinstance(tool_call_id, str) or not tool_call_id.strip():
tool_call_id = f"call_{uuid.uuid4().hex}"
logger.warning(
"Generated missing tool_call id for %s: %s",
func_name,
tool_call_id,
)
# workaround for #1454
if isinstance(tool_call.function.arguments, str):
tool_call_arguments = (
tool_call_function.get("arguments")
if isinstance(tool_call_function, dict)
else getattr(tool_call_function, "arguments", None)
)
if isinstance(tool_call_arguments, str):
try:
args = json.loads(tool_call.function.arguments)
args = json.loads(tool_call_arguments)
except json.JSONDecodeError as e:
logger.error(f"解析参数失败: {e}")
logger.warning(f"解析参数失败: {e}")
args = {}
else:
args = tool_call.function.arguments
args = tool_call_arguments
# Some API may return None for tools with no parameters
if args is None:
args = {}
if not isinstance(args, dict):
logger.warning(
"Tool call arguments for %s are not an object: %s",
func_name,
type(args).__name__,
)
args = {}
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
tool_call_ids.append(tool_call.id)
func_name_ls.append(func_name)
tool_call_ids.append(tool_call_id)

# gemini-2.5 / gemini-3 series extra_content handling
extra_content = getattr(tool_call, "extra_content", None)
extra_content = (
tool_call.get("extra_content")
if is_dict_tool_call
else getattr(tool_call, "extra_content", None)
)
if extra_content is not None:
tool_call_extra_content_dict[tool_call.id] = extra_content
tool_call_extra_content_dict[tool_call_id] = extra_content

llm_response.role = "tool"
llm_response.tools_call_args = args_ls
Expand Down
13 changes: 9 additions & 4 deletions astrbot/core/utils/requirements_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,8 @@ def get_requirement_check_paths() -> list[str]:


def _canonical_distribution_identity(distribution) -> tuple[str | None, str | None]:
distribution_name = (
distribution.metadata["Name"] if "Name" in distribution.metadata else None
)
metadata = distribution.metadata
distribution_name = metadata.get("Name") if metadata else None
if not distribution_name:
return None, None
return canonicalize_distribution_name(distribution_name), distribution.version
Expand All @@ -337,7 +336,13 @@ def collect_installed_distribution_versions(paths: list[str]) -> dict[str, str]
installed: dict[str, str] = {}
try:
for distribution in importlib_metadata.distributions(path=paths):
distribution_name, version = _canonical_distribution_identity(distribution)
try:
distribution_name, version = _canonical_distribution_identity(
distribution
)
except Exception as exc:
logger.debug("Skipping unreadable distribution metadata: %s", exc)
continue
if not distribution_name or not version:
continue
installed.setdefault(distribution_name, version)
Expand Down
Loading