Add OpenAI Responses API support to parallel processor #1972

Open · wants to merge 2 commits into main
examples/api_request_parallel_processor.py (96 additions, 14 deletions)
@@ -14,6 +14,7 @@
 - Throttles request and token usage, to stay under rate limits
 - Retries failed requests up to {max_attempts} times, to avoid missing data
 - Logs errors, to diagnose problems with requests
+- Supports OpenAI Responses API with proper error handling and token counting
 
 Example command to call script:
 ```
@@ -43,6 +44,7 @@
     - if omitted, results will be saved to {requests_filename}_results.jsonl
 - request_url : str, optional
     - URL of the API endpoint to call
+    - supports chat completions, embeddings, responses, and other OpenAI API endpoints
     - if omitted, will default to "https://api.openai.com/v1/embeddings"
 - api_key : str, optional
     - API key to use
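For context, the sketch below shows what a Responses API input file for this script might look like; the two `input` shapes (a plain string or a list of message objects) are the ones this PR's token counter handles. The model name and file paths are illustrative, not part of the PR.

```python
import json

# Hypothetical requests file for the Responses API endpoint: one JSON object
# per line, mirroring the two "input" shapes handled by this PR.
example_requests = [
    {"model": "gpt-4o-mini", "input": "Summarize Hamlet in two sentences."},
    {
        "model": "gpt-4o-mini",
        "input": [{"role": "user", "content": "Name three uses of tiktoken."}],
    },
]
with open("responses_requests.jsonl", "w") as f:
    for request in example_requests:
        f.write(json.dumps(request) + "\n")

# The script would then be pointed at the Responses endpoint, e.g.:
#   python examples/api_request_parallel_processor.py \
#     --requests_filepath responses_requests.jsonl \
#     --request_url https://api.openai.com/v1/responses
```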
@@ -153,13 +155,13 @@ async def process_api_requests_from_file(

     # initialize flags
     file_not_finished = True  # after file is empty, we'll skip reading it
-    logging.debug(f"Initialization complete.")
+    logging.debug("Initialization complete.")
 
     # initialize file reading
     with open(requests_filepath) as file:
         # `requests` will provide requests one at a time
         requests = file.__iter__()
-        logging.debug(f"File opened. Entering main loop")
+        logging.debug("File opened. Entering main loop")
         async with aiohttp.ClientSession() as session:  # Initialize ClientSession here
             while True:
                 # get next request (if one is not already waiting for capacity)
@@ -316,18 +318,51 @@ async def call_api(
                 url=request_url, headers=request_header, json=self.request_json
             ) as response:
                 response = await response.json()
-            if "error" in response:
-                logging.warning(
-                    f"Request {self.task_id} failed with error {response['error']}"
-                )
-                status_tracker.num_api_errors += 1
-                error = response
-                if "rate limit" in response["error"].get("message", "").lower():
-                    status_tracker.time_of_last_rate_limit_error = time.time()
-                    status_tracker.num_rate_limit_errors += 1
-                    status_tracker.num_api_errors -= (
-                        1  # rate limit errors are counted separately
-                    )
+
+            # Handle different API response formats
+            api_endpoint = api_endpoint_from_url(request_url)
+
+            # Check for errors based on API type
+            if api_endpoint == "responses":
+                # New Responses API format
+                if response.get("error") is not None:
+                    logging.warning(
+                        f"Request {self.task_id} failed with error {response['error']}"
+                    )
+                    status_tracker.num_api_errors += 1
+                    error = response
+                    # Handle rate limit errors for responses API
+                    error_obj = response.get("error", {})
+                    if isinstance(error_obj, dict):
+                        error_message = error_obj.get("message", "")
+                        if "rate limit" in error_message.lower():
+                            status_tracker.time_of_last_rate_limit_error = time.time()
+                            status_tracker.num_rate_limit_errors += 1
+                            status_tracker.num_api_errors -= 1
+                elif response.get("status") == "failed":
+                    # Handle failed status in responses API
+                    logging.warning(
+                        f"Request {self.task_id} failed with status: failed"
+                    )
+                    status_tracker.num_api_errors += 1
+                    error = response
+            else:
+                # Old API format (chat completions, embeddings, etc.)
+                if "error" in response:
+                    logging.warning(
+                        f"Request {self.task_id} failed with error {response['error']}"
+                    )
+                    status_tracker.num_api_errors += 1
+                    error = response
+                    if (
+                        response["error"]
+                        and "rate limit" in response["error"].get("message", "").lower()
+                    ):
+                        status_tracker.time_of_last_rate_limit_error = time.time()
+                        status_tracker.num_rate_limit_errors += 1
+                        status_tracker.num_api_errors -= (
+                            1  # rate limit errors are counted separately
+                        )
 
         except (
             Exception
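To make the new branching concrete, here is a sketch of the two error-payload shapes the handler distinguishes. The field values are assumptions inferred from the checks in this diff, not a complete schema.

```python
# Responses API: errors surface as a top-level "error" object and/or a
# terminal "status" of "failed" (shape assumed from the checks above).
responses_api_error = {
    "id": "resp_abc123",
    "status": "failed",
    "error": {"code": "rate_limit_exceeded", "message": "Rate limit reached."},
}

# Older endpoints (chat completions, embeddings): a bare "error" object.
legacy_error = {
    "error": {"message": "Rate limit reached.", "type": "rate_limit_error"}
}

# In both branches, an error message containing "rate limit" is re-counted:
# it increments num_rate_limit_errors instead of num_api_errors and records
# time_of_last_rate_limit_error so the main loop can pause before retrying.
```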
@@ -389,7 +424,7 @@ def num_tokens_consumed_from_request(
     api_endpoint: str,
     token_encoding_name: str,
 ):
-    """Count the number of tokens in the request. Only supports completion and embedding requests."""
+    """Count the number of tokens in the request. Supports completion, embedding, and responses requests."""
     encoding = tiktoken.get_encoding(token_encoding_name)
     # if completions request, tokens = prompt + n * max_tokens
     if api_endpoint.endswith("completions"):
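As a quick illustration of the pre-existing completions formula in the context above (prompt tokens are spent once, while `max_tokens` is reserved for each of the `n` requested completions), with made-up numbers:

```python
# Hypothetical budget for a completions request: a 10-token prompt with
# n=2 and max_tokens=100 reserves 10 + 2 * 100 = 210 tokens against the
# rate limit.
prompt_tokens = 10
n = 2
max_tokens = 100
budgeted_tokens = prompt_tokens + n * max_tokens  # 210
```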
@@ -436,6 +471,53 @@ def num_tokens_consumed_from_request(
             raise TypeError(
                 'Expecting either string or list of strings for "inputs" field in embedding request'
             )
+    # if responses request, tokens = input tokens (similar to chat completions but with "input" field)
+    elif api_endpoint == "responses":
+        input_data = request_json["input"]
+        if isinstance(input_data, str):  # single input string
+            num_tokens = len(encoding.encode(input_data))
+            return num_tokens
+        elif isinstance(
+            input_data, list
+        ):  # array of message objects (similar to chat completions)
+            num_tokens = 0
+            for item in input_data:
+                if isinstance(item, dict):
+                    # Handle message objects
+                    if "content" in item:
+                        content = item["content"]
+                        if isinstance(content, str):
+                            num_tokens += len(encoding.encode(content))
+                        elif isinstance(content, list):
+                            # Handle content array with different types (text, images, etc.)
+                            for content_item in content:
+                                if (
+                                    isinstance(content_item, dict)
+                                    and "text" in content_item
+                                ):
+                                    num_tokens += len(
+                                        encoding.encode(content_item["text"])
+                                    )
+                                elif (
+                                    isinstance(content_item, dict)
+                                    and content_item.get("type") == "input_text"
+                                ):
+                                    num_tokens += len(
+                                        encoding.encode(content_item["text"])
+                                    )
+                    # Add tokens for role and message structure overhead (similar to chat completions)
+                    num_tokens += 4  # every message follows similar structure
+                    for key, value in item.items():
+                        if key != "content" and isinstance(value, str):
+                            num_tokens += len(encoding.encode(value))
+                elif isinstance(item, str):
+                    # Handle simple string items in the array
+                    num_tokens += len(encoding.encode(item))
+            return num_tokens
+        else:
+            raise TypeError(
+                'Expecting either string or list of message objects for "input" field in responses request'
+            )
     # more logic needed to support other API calls (e.g., edits, inserts, DALL-E)
     else:
         raise NotImplementedError(
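Finally, a rough usage sketch of the new responses branch, assuming the script's default `cl100k_base` encoding and the full function defined above; the request content is illustrative.

```python
# Hypothetical Responses API request using the message-object "input" form.
request = {
    "model": "gpt-4o-mini",
    "input": [
        {
            "role": "user",
            "content": [{"type": "input_text", "text": "Hello, world!"}],
        }
    ],
}

# Per the branch above, this counts the tokens in "Hello, world!", adds 4 for
# message structure, and adds the tokens in the "role" string.
tokens = num_tokens_consumed_from_request(
    request_json=request,
    api_endpoint="responses",
    token_encoding_name="cl100k_base",
)
```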