From 7135072c9b63dc9b7ea0ecc464bf4eccdb02e222 Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Thu, 14 May 2026 23:13:05 +0800 Subject: [PATCH 1/2] fix: preserve file batch id when polling completes --- .../resources/vector_stores/file_batches.py | 35 +++++++++++- .../vector_stores/test_file_batches.py | 56 +++++++++++++++++++ 2 files changed, 88 insertions(+), 3 deletions(-) diff --git a/src/openai/resources/vector_stores/file_batches.py b/src/openai/resources/vector_stores/file_batches.py index 4bde1a4aa6..95b7700bbc 100644 --- a/src/openai/resources/vector_stores/file_batches.py +++ b/src/openai/resources/vector_stores/file_batches.py @@ -3,7 +3,7 @@ from __future__ import annotations import asyncio -from typing import Dict, Iterable, Optional +from typing import Any, Dict, Iterable, Optional from typing_extensions import Union, Literal from concurrent.futures import Future, ThreadPoolExecutor, as_completed @@ -15,6 +15,7 @@ from ..._types import Body, Omit, Query, Headers, NotGiven, FileTypes, SequenceNotStr, omit, not_given from ..._utils import is_given, path_template, maybe_transform, async_maybe_transform from ..._compat import cached_property +from ..._models import construct_type_unchecked from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper from ...pagination import SyncCursorPage, AsyncCursorPage @@ -28,6 +29,26 @@ __all__ = ["FileBatches", "AsyncFileBatches"] +def _coerce_vector_store_poll_response( + data: dict[str, Any], + *, + batch_id: str, + vector_store_id: str, +) -> VectorStoreFileBatch | None: + if data.get("object") != "vector_store" or data.get("id") != vector_store_id: + return None + + return construct_type_unchecked( + value={ + **data, + "id": batch_id, + "object": "vector_store.files_batch", + "vector_store_id": vector_store_id, + }, + type_=VectorStoreFileBatch, + ) + + class FileBatches(SyncAPIResource): @cached_property def with_raw_response(self) -> FileBatchesWithRawResponse: @@ -351,7 +372,11 @@ def poll( extra_headers=headers, ) - batch = response.parse() + data = response.parse(to=dict) + batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id) + if batch is None: + batch = response.parse() + if batch.file_counts.in_progress > 0: if not is_given(poll_interval_ms): from_header = response.headers.get("openai-poll-after-ms") @@ -739,7 +764,11 @@ async def poll( extra_headers=headers, ) - batch = response.parse() + data = response.parse(to=dict) + batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id) + if batch is None: + batch = response.parse() + if batch.file_counts.in_progress > 0: if not is_given(poll_interval_ms): from_header = response.headers.get("openai-poll-after-ms") diff --git a/tests/api_resources/vector_stores/test_file_batches.py b/tests/api_resources/vector_stores/test_file_batches.py index c1fba534a6..ebcd689567 100644 --- a/tests/api_resources/vector_stores/test_file_batches.py +++ b/tests/api_resources/vector_stores/test_file_batches.py @@ -5,6 +5,7 @@ import os from typing import Any, cast +import httpx import pytest from openai import OpenAI, AsyncOpenAI @@ -462,3 +463,58 @@ def test_create_and_poll_method_in_sync(sync: bool, client: OpenAI, async_client checking_client.vector_stores.file_batches.create, checking_client.vector_stores.file_batches.create_and_poll, ) + + +def _completed_vector_store_response() -> dict[str, object]: + return { + "id": "vs_abc123", + "created_at": 1761991501, + "file_counts": { + "cancelled": 0, + "completed": 1, + "failed": 0, + "in_progress": 0, + "total": 1, + }, + "object": "vector_store", + "status": "completed", + "vector_store_id": None, + } + + +def test_poll_coerces_completed_vector_store_response() -> None: + def handler(request: httpx.Request) -> httpx.Response: + assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123" + return httpx.Response(200, json=_completed_vector_store_response()) + + with OpenAI( + api_key="My API Key", + base_url=base_url, + http_client=httpx.Client(transport=httpx.MockTransport(handler)), + _strict_response_validation=True, + ) as client: + file_batch = client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123") + + assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"]) + assert file_batch.id == "vsfb_abc123" + assert file_batch.vector_store_id == "vs_abc123" + + +async def test_async_poll_coerces_completed_vector_store_response() -> None: + async def handler(request: httpx.Request) -> httpx.Response: + assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123" + return httpx.Response(200, json=_completed_vector_store_response()) + + async with AsyncOpenAI( + api_key="My API Key", + base_url=base_url, + http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)), + _strict_response_validation=True, + ) as async_client: + file_batch = await async_client.vector_stores.file_batches.poll( + batch_id="vsfb_abc123", vector_store_id="vs_abc123" + ) + + assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"]) + assert file_batch.id == "vsfb_abc123" + assert file_batch.vector_store_id == "vs_abc123" From 0e79e2002af997c102d524ca5dc863aee8a78456 Mon Sep 17 00:00:00 2001 From: nightcityblade Date: Fri, 15 May 2026 11:09:28 +0800 Subject: [PATCH 2/2] fix: preserve strict validation in file batch polling --- .../resources/vector_stores/file_batches.py | 42 ++++++++++++------- .../vector_stores/test_file_batches.py | 36 +++++++++++++++- 2 files changed, 62 insertions(+), 16 deletions(-) diff --git a/src/openai/resources/vector_stores/file_batches.py b/src/openai/resources/vector_stores/file_batches.py index 95b7700bbc..7dc403b2d4 100644 --- a/src/openai/resources/vector_stores/file_batches.py +++ b/src/openai/resources/vector_stores/file_batches.py @@ -15,7 +15,6 @@ from ..._types import Body, Omit, Query, Headers, NotGiven, FileTypes, SequenceNotStr, omit, not_given from ..._utils import is_given, path_template, maybe_transform, async_maybe_transform from ..._compat import cached_property -from ..._models import construct_type_unchecked from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper from ...pagination import SyncCursorPage, AsyncCursorPage @@ -34,19 +33,16 @@ def _coerce_vector_store_poll_response( *, batch_id: str, vector_store_id: str, -) -> VectorStoreFileBatch | None: +) -> dict[str, Any] | None: if data.get("object") != "vector_store" or data.get("id") != vector_store_id: return None - return construct_type_unchecked( - value={ - **data, - "id": batch_id, - "object": "vector_store.files_batch", - "vector_store_id": vector_store_id, - }, - type_=VectorStoreFileBatch, - ) + return { + **data, + "id": batch_id, + "object": "vector_store.files_batch", + "vector_store_id": vector_store_id, + } class FileBatches(SyncAPIResource): @@ -373,9 +369,17 @@ def poll( ) data = response.parse(to=dict) - batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id) - if batch is None: + coerced_data = _coerce_vector_store_poll_response( + data, batch_id=batch_id, vector_store_id=vector_store_id + ) + if coerced_data is None: batch = response.parse() + else: + batch = response._client._process_response_data( + data=coerced_data, + cast_to=VectorStoreFileBatch, + response=response.http_response, + ) if batch.file_counts.in_progress > 0: if not is_given(poll_interval_ms): @@ -765,9 +769,17 @@ async def poll( ) data = response.parse(to=dict) - batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id) - if batch is None: + coerced_data = _coerce_vector_store_poll_response( + data, batch_id=batch_id, vector_store_id=vector_store_id + ) + if coerced_data is None: batch = response.parse() + else: + batch = response._client._process_response_data( + data=coerced_data, + cast_to=VectorStoreFileBatch, + response=response.http_response, + ) if batch.file_counts.in_progress > 0: if not is_given(poll_interval_ms): diff --git a/tests/api_resources/vector_stores/test_file_batches.py b/tests/api_resources/vector_stores/test_file_batches.py index ebcd689567..5953f32355 100644 --- a/tests/api_resources/vector_stores/test_file_batches.py +++ b/tests/api_resources/vector_stores/test_file_batches.py @@ -8,7 +8,7 @@ import httpx import pytest -from openai import OpenAI, AsyncOpenAI +from openai import OpenAI, AsyncOpenAI, APIResponseValidationError from tests.utils import assert_matches_type from openai._utils import assert_signatures_in_sync from openai.pagination import SyncCursorPage, AsyncCursorPage @@ -518,3 +518,37 @@ async def handler(request: httpx.Request) -> httpx.Response: assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"]) assert file_batch.id == "vsfb_abc123" assert file_batch.vector_store_id == "vs_abc123" + + +def test_poll_preserves_strict_validation_for_coerced_vector_store_response() -> None: + def handler(request: httpx.Request) -> httpx.Response: + data = _completed_vector_store_response() + data["created_at"] = "invalid" + return httpx.Response(200, json=data) + + with OpenAI( + api_key="My API Key", + base_url=base_url, + http_client=httpx.Client(transport=httpx.MockTransport(handler)), + _strict_response_validation=True, + ) as client: + with pytest.raises(APIResponseValidationError): + client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123") + + +async def test_async_poll_preserves_strict_validation_for_coerced_vector_store_response() -> None: + async def handler(request: httpx.Request) -> httpx.Response: + data = _completed_vector_store_response() + data["created_at"] = "invalid" + return httpx.Response(200, json=data) + + async with AsyncOpenAI( + api_key="My API Key", + base_url=base_url, + http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)), + _strict_response_validation=True, + ) as async_client: + with pytest.raises(APIResponseValidationError): + await async_client.vector_stores.file_batches.poll( + batch_id="vsfb_abc123", vector_store_id="vs_abc123" + )