Skip to content

Commit 0e79e20

Browse files
author
nightcityblade
committed
fix: preserve strict validation in file batch polling
1 parent 7135072 commit 0e79e20

2 files changed

Lines changed: 62 additions & 16 deletions

File tree

src/openai/resources/vector_stores/file_batches.py

Lines changed: 27 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from ..._types import Body, Omit, Query, Headers, NotGiven, FileTypes, SequenceNotStr, omit, not_given
1616
from ..._utils import is_given, path_template, maybe_transform, async_maybe_transform
1717
from ..._compat import cached_property
18-
from ..._models import construct_type_unchecked
1918
from ..._resource import SyncAPIResource, AsyncAPIResource
2019
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
2120
from ...pagination import SyncCursorPage, AsyncCursorPage
@@ -34,19 +33,16 @@ def _coerce_vector_store_poll_response(
3433
*,
3534
batch_id: str,
3635
vector_store_id: str,
37-
) -> VectorStoreFileBatch | None:
36+
) -> dict[str, Any] | None:
3837
if data.get("object") != "vector_store" or data.get("id") != vector_store_id:
3938
return None
4039

41-
return construct_type_unchecked(
42-
value={
43-
**data,
44-
"id": batch_id,
45-
"object": "vector_store.files_batch",
46-
"vector_store_id": vector_store_id,
47-
},
48-
type_=VectorStoreFileBatch,
49-
)
40+
return {
41+
**data,
42+
"id": batch_id,
43+
"object": "vector_store.files_batch",
44+
"vector_store_id": vector_store_id,
45+
}
5046

5147

5248
class FileBatches(SyncAPIResource):
@@ -373,9 +369,17 @@ def poll(
373369
)
374370

375371
data = response.parse(to=dict)
376-
batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id)
377-
if batch is None:
372+
coerced_data = _coerce_vector_store_poll_response(
373+
data, batch_id=batch_id, vector_store_id=vector_store_id
374+
)
375+
if coerced_data is None:
378376
batch = response.parse()
377+
else:
378+
batch = response._client._process_response_data(
379+
data=coerced_data,
380+
cast_to=VectorStoreFileBatch,
381+
response=response.http_response,
382+
)
379383

380384
if batch.file_counts.in_progress > 0:
381385
if not is_given(poll_interval_ms):
@@ -765,9 +769,17 @@ async def poll(
765769
)
766770

767771
data = response.parse(to=dict)
768-
batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id)
769-
if batch is None:
772+
coerced_data = _coerce_vector_store_poll_response(
773+
data, batch_id=batch_id, vector_store_id=vector_store_id
774+
)
775+
if coerced_data is None:
770776
batch = response.parse()
777+
else:
778+
batch = response._client._process_response_data(
779+
data=coerced_data,
780+
cast_to=VectorStoreFileBatch,
781+
response=response.http_response,
782+
)
771783

772784
if batch.file_counts.in_progress > 0:
773785
if not is_given(poll_interval_ms):

tests/api_resources/vector_stores/test_file_batches.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import httpx
99
import pytest
1010

11-
from openai import OpenAI, AsyncOpenAI
11+
from openai import OpenAI, AsyncOpenAI, APIResponseValidationError
1212
from tests.utils import assert_matches_type
1313
from openai._utils import assert_signatures_in_sync
1414
from openai.pagination import SyncCursorPage, AsyncCursorPage
@@ -518,3 +518,37 @@ async def handler(request: httpx.Request) -> httpx.Response:
518518
assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])
519519
assert file_batch.id == "vsfb_abc123"
520520
assert file_batch.vector_store_id == "vs_abc123"
521+
522+
523+
def test_poll_preserves_strict_validation_for_coerced_vector_store_response() -> None:
524+
def handler(request: httpx.Request) -> httpx.Response:
525+
data = _completed_vector_store_response()
526+
data["created_at"] = "invalid"
527+
return httpx.Response(200, json=data)
528+
529+
with OpenAI(
530+
api_key="My API Key",
531+
base_url=base_url,
532+
http_client=httpx.Client(transport=httpx.MockTransport(handler)),
533+
_strict_response_validation=True,
534+
) as client:
535+
with pytest.raises(APIResponseValidationError):
536+
client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123")
537+
538+
539+
async def test_async_poll_preserves_strict_validation_for_coerced_vector_store_response() -> None:
540+
async def handler(request: httpx.Request) -> httpx.Response:
541+
data = _completed_vector_store_response()
542+
data["created_at"] = "invalid"
543+
return httpx.Response(200, json=data)
544+
545+
async with AsyncOpenAI(
546+
api_key="My API Key",
547+
base_url=base_url,
548+
http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)),
549+
_strict_response_validation=True,
550+
) as async_client:
551+
with pytest.raises(APIResponseValidationError):
552+
await async_client.vector_stores.file_batches.poll(
553+
batch_id="vsfb_abc123", vector_store_id="vs_abc123"
554+
)

0 commit comments

Comments
 (0)