diff --git a/src/openai/resources/vector_stores/file_batches.py b/src/openai/resources/vector_stores/file_batches.py
index 4bde1a4aa6..7dc403b2d4 100644
--- a/src/openai/resources/vector_stores/file_batches.py
+++ b/src/openai/resources/vector_stores/file_batches.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import asyncio
-from typing import Dict, Iterable, Optional
+from typing import Any, Dict, Iterable, Optional
 from typing_extensions import Union, Literal
 from concurrent.futures import Future, ThreadPoolExecutor, as_completed
 
@@ -28,6 +28,23 @@
 __all__ = ["FileBatches", "AsyncFileBatches"]
 
 
+def _coerce_vector_store_poll_response(
+    data: dict[str, Any],
+    *,
+    batch_id: str,
+    vector_store_id: str,
+) -> dict[str, Any] | None:
+    if data.get("object") != "vector_store" or data.get("id") != vector_store_id:
+        return None
+
+    return {
+        **data,
+        "id": batch_id,
+        "object": "vector_store.files_batch",
+        "vector_store_id": vector_store_id,
+    }
+
+
 class FileBatches(SyncAPIResource):
     @cached_property
     def with_raw_response(self) -> FileBatchesWithRawResponse:
@@ -351,7 +368,19 @@ def poll(
                 extra_headers=headers,
             )
 
-            batch = response.parse()
+            data = response.parse(to=dict)
+            coerced_data = _coerce_vector_store_poll_response(
+                data, batch_id=batch_id, vector_store_id=vector_store_id
+            )
+            if coerced_data is None:
+                batch = response.parse()
+            else:
+                batch = response._client._process_response_data(
+                    data=coerced_data,
+                    cast_to=VectorStoreFileBatch,
+                    response=response.http_response,
+                )
+
             if batch.file_counts.in_progress > 0:
                 if not is_given(poll_interval_ms):
                     from_header = response.headers.get("openai-poll-after-ms")
@@ -739,7 +768,19 @@ async def poll(
                 extra_headers=headers,
             )
 
-            batch = response.parse()
+            data = response.parse(to=dict)
+            coerced_data = _coerce_vector_store_poll_response(
+                data, batch_id=batch_id, vector_store_id=vector_store_id
+            )
+            if coerced_data is None:
+                batch = response.parse()
+            else:
+                batch = response._client._process_response_data(
+                    data=coerced_data,
+                    cast_to=VectorStoreFileBatch,
+                    response=response.http_response,
+                )
+
             if batch.file_counts.in_progress > 0:
                 if not is_given(poll_interval_ms):
                     from_header = response.headers.get("openai-poll-after-ms")
diff --git a/tests/api_resources/vector_stores/test_file_batches.py b/tests/api_resources/vector_stores/test_file_batches.py
index c1fba534a6..5953f32355 100644
--- a/tests/api_resources/vector_stores/test_file_batches.py
+++ b/tests/api_resources/vector_stores/test_file_batches.py
@@ -5,9 +5,10 @@
 import os
 from typing import Any, cast
 
+import httpx
 import pytest
 
-from openai import OpenAI, AsyncOpenAI
+from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
 from tests.utils import assert_matches_type
 from openai._utils import assert_signatures_in_sync
 from openai.pagination import SyncCursorPage, AsyncCursorPage
@@ -462,3 +463,92 @@ def test_create_and_poll_method_in_sync(sync: bool, client: OpenAI, async_client
         checking_client.vector_stores.file_batches.create,
         checking_client.vector_stores.file_batches.create_and_poll,
     )
+
+
+def _completed_vector_store_response() -> dict[str, object]:
+    return {
+        "id": "vs_abc123",
+        "created_at": 1761991501,
+        "file_counts": {
+            "cancelled": 0,
+            "completed": 1,
+            "failed": 0,
+            "in_progress": 0,
+            "total": 1,
+        },
+        "object": "vector_store",
+        "status": "completed",
+        "vector_store_id": None,
+    }
+
+
+def test_poll_coerces_completed_vector_store_response() -> None:
+    def handler(request: httpx.Request) -> httpx.Response:
+        assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123"
+        return httpx.Response(200, json=_completed_vector_store_response())
+
+    with OpenAI(
+        api_key="My API Key",
+        base_url=base_url,
+        http_client=httpx.Client(transport=httpx.MockTransport(handler)),
+        _strict_response_validation=True,
+    ) as client:
+        file_batch = client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123")
+
+    assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])
+    assert file_batch.id == "vsfb_abc123"
+    assert file_batch.vector_store_id == "vs_abc123"
+
+
+async def test_async_poll_coerces_completed_vector_store_response() -> None:
+    async def handler(request: httpx.Request) -> httpx.Response:
+        assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123"
+        return httpx.Response(200, json=_completed_vector_store_response())
+
+    async with AsyncOpenAI(
+        api_key="My API Key",
+        base_url=base_url,
+        http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)),
+        _strict_response_validation=True,
+    ) as async_client:
+        file_batch = await async_client.vector_stores.file_batches.poll(
+            batch_id="vsfb_abc123", vector_store_id="vs_abc123"
+        )
+
+    assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])
+    assert file_batch.id == "vsfb_abc123"
+    assert file_batch.vector_store_id == "vs_abc123"
+
+
+def test_poll_preserves_strict_validation_for_coerced_vector_store_response() -> None:
+    def handler(request: httpx.Request) -> httpx.Response:
+        data = _completed_vector_store_response()
+        data["created_at"] = "invalid"
+        return httpx.Response(200, json=data)
+
+    with OpenAI(
+        api_key="My API Key",
+        base_url=base_url,
+        http_client=httpx.Client(transport=httpx.MockTransport(handler)),
+        _strict_response_validation=True,
+    ) as client:
+        with pytest.raises(APIResponseValidationError):
+            client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123")
+
+
+async def test_async_poll_preserves_strict_validation_for_coerced_vector_store_response() -> None:
+    async def handler(request: httpx.Request) -> httpx.Response:
+        data = _completed_vector_store_response()
+        data["created_at"] = "invalid"
+        return httpx.Response(200, json=data)
+
+    async with AsyncOpenAI(
+        api_key="My API Key",
+        base_url=base_url,
+        http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)),
+        _strict_response_validation=True,
+    ) as async_client:
+        with pytest.raises(APIResponseValidationError):
+            await async_client.vector_stores.file_batches.poll(
+                batch_id="vsfb_abc123", vector_store_id="vs_abc123"
+            )