|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import asyncio |
6 | | -from typing import Dict, Iterable, Optional |
| 6 | +from typing import Any, Dict, Iterable, Optional |
7 | 7 | from typing_extensions import Union, Literal |
8 | 8 | from concurrent.futures import Future, ThreadPoolExecutor, as_completed |
9 | 9 |
|
|
15 | 15 | from ..._types import Body, Omit, Query, Headers, NotGiven, FileTypes, SequenceNotStr, omit, not_given |
16 | 16 | from ..._utils import is_given, path_template, maybe_transform, async_maybe_transform |
17 | 17 | from ..._compat import cached_property |
| 18 | +from ..._models import construct_type_unchecked |
18 | 19 | from ..._resource import SyncAPIResource, AsyncAPIResource |
19 | 20 | from ..._response import to_streamed_response_wrapper, async_to_streamed_response_wrapper |
20 | 21 | from ...pagination import SyncCursorPage, AsyncCursorPage |
|
28 | 29 | __all__ = ["FileBatches", "AsyncFileBatches"] |
29 | 30 |
|
30 | 31 |
|
| 32 | +def _coerce_vector_store_poll_response( |
| 33 | + data: dict[str, Any], |
| 34 | + *, |
| 35 | + batch_id: str, |
| 36 | + vector_store_id: str, |
| 37 | +) -> VectorStoreFileBatch | None: |
| 38 | + if data.get("object") != "vector_store" or data.get("id") != vector_store_id: |
| 39 | + return None |
| 40 | + |
| 41 | + return construct_type_unchecked( |
| 42 | + value={ |
| 43 | + **data, |
| 44 | + "id": batch_id, |
| 45 | + "object": "vector_store.files_batch", |
| 46 | + "vector_store_id": vector_store_id, |
| 47 | + }, |
| 48 | + type_=VectorStoreFileBatch, |
| 49 | + ) |
| 50 | + |
| 51 | + |
31 | 52 | class FileBatches(SyncAPIResource): |
32 | 53 | @cached_property |
33 | 54 | def with_raw_response(self) -> FileBatchesWithRawResponse: |
@@ -351,7 +372,11 @@ def poll( |
351 | 372 | extra_headers=headers, |
352 | 373 | ) |
353 | 374 |
|
354 | | - batch = response.parse() |
| 375 | + data = response.parse(to=dict) |
| 376 | + batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id) |
| 377 | + if batch is None: |
| 378 | + batch = response.parse() |
| 379 | + |
355 | 380 | if batch.file_counts.in_progress > 0: |
356 | 381 | if not is_given(poll_interval_ms): |
357 | 382 | from_header = response.headers.get("openai-poll-after-ms") |
@@ -739,7 +764,11 @@ async def poll( |
739 | 764 | extra_headers=headers, |
740 | 765 | ) |
741 | 766 |
|
742 | | - batch = response.parse() |
| 767 | + data = response.parse(to=dict) |
| 768 | + batch = _coerce_vector_store_poll_response(data, batch_id=batch_id, vector_store_id=vector_store_id) |
| 769 | + if batch is None: |
| 770 | + batch = response.parse() |
| 771 | + |
743 | 772 | if batch.file_counts.in_progress > 0: |
744 | 773 | if not is_given(poll_interval_ms): |
745 | 774 | from_header = response.headers.get("openai-poll-after-ms") |
|
0 commit comments