-
-
Notifications
You must be signed in to change notification settings - Fork 2.4k
增强 OpenAI-compatible Provider 与 FAISS 兼容性 #8689
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
3ae1420
670c99a
83f5ae6
14901b4
d24bdb8
d7c5aa1
98784ad
7bce127
3508bd2
9f72e98
c92a3f4
4b9eddd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -20,6 +20,21 @@ def __init__(self, dimension: int, path: str | None = None) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| base_index = faiss.IndexFlatL2(dimension) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index = faiss.IndexIDMap(base_index) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| def _add_with_ids(self, vectors: np.ndarray, ids: np.ndarray) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.index is not None, "FAISS index is not initialized." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vectors = np.ascontiguousarray(vectors, dtype=np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ids = np.ascontiguousarray(ids, dtype=np.int64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index.add_with_ids(vectors, ids) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except TypeError as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "missing" not in str(exc): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index.add_with_ids( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vectors.shape[0], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| faiss.swig_ptr(vectors), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| faiss.swig_ptr(ids), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def insert(self, vector: np.ndarray, id: int) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """插入向量 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -35,7 +50,7 @@ async def insert(self, vector: np.ndarray, id: int) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"向量维度不匹配, 期望: {self.dimension}, 实际: {vector.shape[0]}", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index.add_with_ids(vector.reshape(1, -1), np.array([id])) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._add_with_ids(vector.reshape(1, -1), np.array([id], dtype=np.int64)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await self.save_index() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -53,7 +68,7 @@ async def insert_batch(self, vectors: np.ndarray, ids: list[int]) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise ValueError( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| f"向量维度不匹配, 期望: {self.dimension}, 实际: {vectors.shape[1]}", | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index.add_with_ids(vectors, np.array(ids)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self._add_with_ids(vectors, np.array(ids, dtype=np.int64)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await self.save_index() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def search(self, vector: np.ndarray, k: int) -> tuple: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -67,8 +82,24 @@ async def search(self, vector: np.ndarray, k: int) -> tuple: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.index is not None, "FAISS index is not initialized." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector = np.ascontiguousarray(vector, dtype=np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if vector.ndim == 1: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector = vector.reshape(1, -1) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| faiss.normalize_L2(vector) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| distances, indices = self.index.search(vector, k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| distances, indices = self.index.search(vector, k) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except TypeError as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "missing" not in str(exc): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| distances = np.empty((vector.shape[0], k), dtype=np.float32) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| indices = np.empty((vector.shape[0], k), dtype=np.int64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index.search( | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| vector.shape[0], | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| faiss.swig_ptr(vector), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| k, | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| faiss.swig_ptr(distances), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| faiss.swig_ptr(indices), | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
88
to
+102
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Potential Segmentation Fault / Crash in SWIG FallbackIn the When the SWIG fallback is triggered, passing To prevent this, we should ensure
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| return distances, indices | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def delete(self, ids: list[int]) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
@@ -80,7 +111,13 @@ async def delete(self, ids: list[int]) -> None: | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| assert self.index is not None, "FAISS index is not initialized." | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| id_array = np.array(ids, dtype=np.int64) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index.remove_ids(id_array) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| try: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index.remove_ids(id_array) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| except TypeError as exc: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| if "IDSelector" not in str(exc): | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| raise | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| selector = faiss.IDSelectorBatch(id_array.size, faiss.swig_ptr(id_array)) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| self.index.remove_ids(selector) | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| await self.save_index() | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| async def save_index(self) -> None: | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -452,36 +452,66 @@ def to_openai_tool_calls(self) -> list[dict]: | |
| """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead.""" | ||
| ret = [] | ||
| for idx, tool_call_arg in enumerate(self.tools_call_args): | ||
| if idx >= len(self.tools_call_name): | ||
| logger.warning( | ||
| "Skipping tool call argument without matching tool name at index %s.", | ||
| idx, | ||
| ) | ||
| break | ||
|
Comment on lines
+455
to
+460
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. suggestion (bug_risk): Silently breaking when tools_call_args is longer than tools_call_name risks dropping tool calls. By breaking when |
||
| tool_name = self.tools_call_name[idx] | ||
| if not isinstance(tool_name, str) or not tool_name.strip(): | ||
| logger.warning("Skipping malformed tool call with empty tool name.") | ||
| continue | ||
| tool_call_id = ( | ||
| self.tools_call_ids[idx] | ||
| if idx < len(self.tools_call_ids) | ||
| and isinstance(self.tools_call_ids[idx], str) | ||
| and self.tools_call_ids[idx].strip() | ||
| else f"call_{uuid.uuid4().hex}" | ||
| ) | ||
| payload = { | ||
| "id": self.tools_call_ids[idx], | ||
| "id": tool_call_id, | ||
| "function": { | ||
| "name": self.tools_call_name[idx], | ||
| "name": tool_name.strip(), | ||
| "arguments": json.dumps(tool_call_arg), | ||
| }, | ||
| "type": "function", | ||
| } | ||
| if self.tools_call_extra_content.get(self.tools_call_ids[idx]): | ||
| payload["extra_content"] = self.tools_call_extra_content[ | ||
| self.tools_call_ids[idx] | ||
| ] | ||
| if self.tools_call_extra_content.get(tool_call_id): | ||
| payload["extra_content"] = self.tools_call_extra_content[tool_call_id] | ||
| ret.append(payload) | ||
| return ret | ||
|
|
||
| def to_openai_to_calls_model(self) -> list[ToolCall]: | ||
| """The same as to_openai_tool_calls but return pydantic model.""" | ||
| ret = [] | ||
| for idx, tool_call_arg in enumerate(self.tools_call_args): | ||
| if idx >= len(self.tools_call_name): | ||
| logger.warning( | ||
| "Skipping tool call argument without matching tool name at index %s.", | ||
| idx, | ||
| ) | ||
| break | ||
| tool_name = self.tools_call_name[idx] | ||
| if not isinstance(tool_name, str) or not tool_name.strip(): | ||
| logger.warning("Skipping malformed tool call with empty tool name.") | ||
| continue | ||
| tool_call_id = ( | ||
| self.tools_call_ids[idx] | ||
| if idx < len(self.tools_call_ids) | ||
| and isinstance(self.tools_call_ids[idx], str) | ||
| and self.tools_call_ids[idx].strip() | ||
| else f"call_{uuid.uuid4().hex}" | ||
| ) | ||
| ret.append( | ||
| ToolCall( | ||
| id=self.tools_call_ids[idx], | ||
| id=tool_call_id, | ||
| function=ToolCall.FunctionBody( | ||
| name=self.tools_call_name[idx], | ||
| name=tool_name.strip(), | ||
| arguments=json.dumps(tool_call_arg), | ||
| ), | ||
| # the extra_content will not serialize if it's None when calling ToolCall.model_dump() | ||
| extra_content=self.tools_call_extra_content.get( | ||
| self.tools_call_ids[idx] | ||
| ), | ||
| extra_content=self.tools_call_extra_content.get(tool_call_id), | ||
| ), | ||
| ) | ||
| return ret | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
suggestion (bug_risk): FAISS search fallback also depends on a specific error message and assumes vector dtype/shape.
The
searchfallback currently relies on the exact substring"missing 3 required positional arguments", which is tied to a specific FAISS version and may break if the error message changes. It also calls the pointer-based overload without first ensuringvectoris contiguous float32, so non-contiguous or non-float32 inputs could causefaiss.swig_ptrmisuse. Consider mirroring_add_with_idsby normalizing vianp.ascontiguousarray(vector, dtype=np.float32)and using a more robust way to detect the legacy signature than matching a hard-coded error string.