Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 44 additions & 3 deletions src/openai/resources/vector_stores/file_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from __future__ import annotations

import asyncio
from typing import Dict, Iterable, Optional
from typing import Any, Dict, Iterable, Optional
from typing_extensions import Union, Literal
from concurrent.futures import Future, ThreadPoolExecutor, as_completed

Expand All @@ -28,6 +28,23 @@
__all__ = ["FileBatches", "AsyncFileBatches"]


def _coerce_vector_store_poll_response(
data: dict[str, Any],
*,
batch_id: str,
vector_store_id: str,
) -> dict[str, Any] | None:
if data.get("object") != "vector_store" or data.get("id") != vector_store_id:
return None

return {
**data,
"id": batch_id,
"object": "vector_store.files_batch",
"vector_store_id": vector_store_id,
}


class FileBatches(SyncAPIResource):
@cached_property
def with_raw_response(self) -> FileBatchesWithRawResponse:
Expand Down Expand Up @@ -351,7 +368,19 @@ def poll(
extra_headers=headers,
)

batch = response.parse()
data = response.parse(to=dict)
coerced_data = _coerce_vector_store_poll_response(
data, batch_id=batch_id, vector_store_id=vector_store_id
)
if coerced_data is None:
batch = response.parse()
else:
batch = response._client._process_response_data(
data=coerced_data,
cast_to=VectorStoreFileBatch,
response=response.http_response,
)

if batch.file_counts.in_progress > 0:
if not is_given(poll_interval_ms):
from_header = response.headers.get("openai-poll-after-ms")
Expand Down Expand Up @@ -739,7 +768,19 @@ async def poll(
extra_headers=headers,
)

batch = response.parse()
data = response.parse(to=dict)
coerced_data = _coerce_vector_store_poll_response(
data, batch_id=batch_id, vector_store_id=vector_store_id
)
if coerced_data is None:
batch = response.parse()
else:
batch = response._client._process_response_data(
data=coerced_data,
cast_to=VectorStoreFileBatch,
response=response.http_response,
)

if batch.file_counts.in_progress > 0:
if not is_given(poll_interval_ms):
from_header = response.headers.get("openai-poll-after-ms")
Expand Down
92 changes: 91 additions & 1 deletion tests/api_resources/vector_stores/test_file_batches.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,10 @@
import os
from typing import Any, cast

import httpx
import pytest

from openai import OpenAI, AsyncOpenAI
from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
from tests.utils import assert_matches_type
from openai._utils import assert_signatures_in_sync
from openai.pagination import SyncCursorPage, AsyncCursorPage
Expand Down Expand Up @@ -462,3 +463,92 @@ def test_create_and_poll_method_in_sync(sync: bool, client: OpenAI, async_client
checking_client.vector_stores.file_batches.create,
checking_client.vector_stores.file_batches.create_and_poll,
)


def _completed_vector_store_response() -> dict[str, object]:
return {
"id": "vs_abc123",
"created_at": 1761991501,
"file_counts": {
"cancelled": 0,
"completed": 1,
"failed": 0,
"in_progress": 0,
"total": 1,
},
"object": "vector_store",
"status": "completed",
"vector_store_id": None,
}


def test_poll_coerces_completed_vector_store_response() -> None:
    """Sync poll turns a vector-store-shaped payload into a file batch."""

    def respond(request: httpx.Request) -> httpx.Response:
        # Poll must hit the file-batch retrieve endpoint.
        assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123"
        return httpx.Response(200, json=_completed_vector_store_response())

    transport = httpx.MockTransport(respond)
    with OpenAI(
        api_key="My API Key",
        base_url=base_url,
        http_client=httpx.Client(transport=transport),
        _strict_response_validation=True,
    ) as client:
        file_batch = client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123")

    assert file_batch.id == "vsfb_abc123"
    assert file_batch.vector_store_id == "vs_abc123"
    assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])


async def test_async_poll_coerces_completed_vector_store_response() -> None:
    """Async poll turns a vector-store-shaped payload into a file batch."""

    async def respond(request: httpx.Request) -> httpx.Response:
        # Poll must hit the file-batch retrieve endpoint.
        assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123"
        return httpx.Response(200, json=_completed_vector_store_response())

    async with AsyncOpenAI(
        api_key="My API Key",
        base_url=base_url,
        http_client=httpx.AsyncClient(transport=httpx.MockTransport(respond)),
        _strict_response_validation=True,
    ) as async_client:
        file_batch = await async_client.vector_stores.file_batches.poll(
            vector_store_id="vs_abc123", batch_id="vsfb_abc123"
        )

    assert file_batch.id == "vsfb_abc123"
    assert file_batch.vector_store_id == "vs_abc123"
    assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])


def test_poll_preserves_strict_validation_for_coerced_vector_store_response() -> None:
    """Coercing the payload must not bypass strict response validation."""

    def respond(request: httpx.Request) -> httpx.Response:
        payload = _completed_vector_store_response()
        payload["created_at"] = "invalid"  # wrong type: should trip strict validation
        return httpx.Response(200, json=payload)

    client = OpenAI(
        api_key="My API Key",
        base_url=base_url,
        http_client=httpx.Client(transport=httpx.MockTransport(respond)),
        _strict_response_validation=True,
    )
    with client, pytest.raises(APIResponseValidationError):
        client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123")


async def test_async_poll_preserves_strict_validation_for_coerced_vector_store_response() -> None:
    """Async variant: coercion must not bypass strict response validation."""

    async def respond(request: httpx.Request) -> httpx.Response:
        payload = _completed_vector_store_response()
        payload["created_at"] = "invalid"  # wrong type: should trip strict validation
        return httpx.Response(200, json=payload)

    async with AsyncOpenAI(
        api_key="My API Key",
        base_url=base_url,
        http_client=httpx.AsyncClient(transport=httpx.MockTransport(respond)),
        _strict_response_validation=True,
    ) as async_client:
        with pytest.raises(APIResponseValidationError):
            await async_client.vector_stores.file_batches.poll(
                vector_store_id="vs_abc123", batch_id="vsfb_abc123"
            )