Skip to content

Commit 7135072

Browse files
author
nightcityblade
committed
fix: preserve file batch id when polling completes
1 parent 38d75d7 commit 7135072

2 files changed

Lines changed: 88 additions & 3 deletions

File tree

src/openai/resources/vector_stores/file_batches.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
import asyncio
6-
from typing import Dict, Iterable, Optional
6+
from typing import Any, Dict, Iterable, Optional
77
from typing_extensions import Union, Literal
88
from concurrent.futures import Future, ThreadPoolExecutor, as_completed
99

@@ -15,6 +15,7 @@
1515
from ..._types import Body, Omit, Query, Headers, NotGiven, FileTypes, SequenceNotStr, omit, not_given
1616
from ..._utils import is_given, path_template, maybe_transform, async_maybe_transform
1717
from ..._compat import cached_property
18+
from ..._models import construct_type_unchecked
1819
from ..._resource import SyncAPIResource, AsyncAPIResource
1920
from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper
2021
from ...pagination import SyncCursorPage, AsyncCursorPage
@@ -28,6 +29,26 @@
2829
__all__ = ["FileBatches", "AsyncFileBatches"]
2930

3031

32+
def _coerce_vector_store_poll_response(
33+
data: dict[str, Any],
34+
*,
35+
batch_id: str,
36+
vector_store_id: str,
37+
) -> VectorStoreFileBatch | None:
38+
if data.get("object") != "vector_store" or data.get("id") != vector_store_id:
39+
return None
40+
41+
return construct_type_unchecked(
42+
value={
43+
**data,
44+
"id": batch_id,
45+
"object": "vector_store.files_batch",
46+
"vector_store_id": vector_store_id,
47+
},
48+
type_=VectorStoreFileBatch,
49+
)
50+
51+
3152
class FileBatches(SyncAPIResource):
3253
@cached_property
3354
def with_raw_response(self) -> FileBatchesWithRawResponse:
@@ -351,7 +372,11 @@ def poll(
351372
extra_headers=headers,
352373
)
353374

354-
batch = response.parse()
375+
data = response.parse(to=dict)
376+
batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id)
377+
if batch is None:
378+
batch = response.parse()
379+
355380
if batch.file_counts.in_progress > 0:
356381
if not is_given(poll_interval_ms):
357382
from_header = response.headers.get("openai-poll-after-ms")
@@ -739,7 +764,11 @@ async def poll(
739764
extra_headers=headers,
740765
)
741766

742-
batch = response.parse()
767+
data = response.parse(to=dict)
768+
batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id)
769+
if batch is None:
770+
batch = response.parse()
771+
743772
if batch.file_counts.in_progress > 0:
744773
if not is_given(poll_interval_ms):
745774
from_header = response.headers.get("openai-poll-after-ms")

tests/api_resources/vector_stores/test_file_batches.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import os
66
from typing import Any, cast
77

8+
import httpx
89
import pytest
910

1011
from openai import OpenAI, AsyncOpenAI
@@ -462,3 +463,58 @@ def test_create_and_poll_method_in_sync(sync: bool, client: OpenAI, async_client
462463
checking_client.vector_stores.file_batches.create,
463464
checking_client.vector_stores.file_batches.create_and_poll,
464465
)
466+
467+
468+
def _completed_vector_store_response() -> dict[str, object]:
469+
return {
470+
"id": "vs_abc123",
471+
"created_at": 1761991501,
472+
"file_counts": {
473+
"cancelled": 0,
474+
"completed": 1,
475+
"failed": 0,
476+
"in_progress": 0,
477+
"total": 1,
478+
},
479+
"object": "vector_store",
480+
"status": "completed",
481+
"vector_store_id": None,
482+
}
483+
484+
485+
def test_poll_coerces_completed_vector_store_response() -> None:
486+
def handler(request: httpx.Request) -> httpx.Response:
487+
assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123"
488+
return httpx.Response(200, json=_completed_vector_store_response())
489+
490+
with OpenAI(
491+
api_key="My API Key",
492+
base_url=base_url,
493+
http_client=httpx.Client(transport=httpx.MockTransport(handler)),
494+
_strict_response_validation=True,
495+
) as client:
496+
file_batch = client.vector_stores.file_batches.poll(batch_id="vsfb_abc123", vector_store_id="vs_abc123")
497+
498+
assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])
499+
assert file_batch.id == "vsfb_abc123"
500+
assert file_batch.vector_store_id == "vs_abc123"
501+
502+
503+
async def test_async_poll_coerces_completed_vector_store_response() -> None:
504+
async def handler(request: httpx.Request) -> httpx.Response:
505+
assert request.url.path == "/vector_stores/vs_abc123/file_batches/vsfb_abc123"
506+
return httpx.Response(200, json=_completed_vector_store_response())
507+
508+
async with AsyncOpenAI(
509+
api_key="My API Key",
510+
base_url=base_url,
511+
http_client=httpx.AsyncClient(transport=httpx.MockTransport(handler)),
512+
_strict_response_validation=True,
513+
) as async_client:
514+
file_batch = await async_client.vector_stores.file_batches.poll(
515+
batch_id="vsfb_abc123", vector_store_id="vs_abc123"
516+
)
517+
518+
assert_matches_type(VectorStoreFileBatch, file_batch, path=["response"])
519+
assert file_batch.id == "vsfb_abc123"
520+
assert file_batch.vector_store_id == "vs_abc123"

0 commit comments

Comments
 (0)