Commit a2a86e9

Merge pull request #33 from longtermrisk/fix_inspect_ai_job
Fixing inference with OpenAI API
2 parents: bc89022 + b292d94

3 files changed: 189 additions, 2 deletions
.env.example

Lines changed: 1 addition & 0 deletions
@@ -1,3 +1,4 @@
 OPENWEIGHTS_API_KEY=<your_openweights_api_key>
 HF_ORG=longtermrisk
 OW_DEFAULT_API_KEY=<optional key that will be used by vllm API deployments>
+OPENAI_API_KEY=<optional key required for examples using the OpenAI API>
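
For reference, a minimal sketch of how the new optional OPENAI_API_KEY entry would be picked up, assuming the standard python-dotenv plus os.environ pattern that the example script below also uses; the exact lookup inside OpenWeights is not shown in this diff:

import os
from dotenv import load_dotenv

load_dotenv()  # reads key=value pairs from .env into the process environment
openai_key = os.environ.get("OPENAI_API_KEY")  # None when the optional key is absent
if openai_key is None:
    print("OPENAI_API_KEY is not set; the OpenAI API examples will not run.")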
Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
"""Create an inference job with an OpenAI model and poll its results."""

import json
import logging
import os
import random
import time
from typing import Dict

from dotenv import load_dotenv

from openweights import OpenWeights
import openweights.jobs.inference


def run_inference_job_and_get_outputs(
    filepath_conversations: str,
    model_to_evaluate: str,
    wait_for_completion: bool = False,
    display_log_file: bool = False,
    n_examples_to_log: int = 0,
    inference_hyperparameters: Dict = None,
):
    load_dotenv()
    client = OpenWeights()

    # Guard against the default None so the key removal below cannot fail
    if inference_hyperparameters is None:
        inference_hyperparameters = {}

    # Upload inference file
    with open(filepath_conversations, "rb") as file:
        file = client.files.create(file, purpose="conversations")
        file_id = file["id"]

    # Drop training-only hyperparameters that an inference job does not accept
    keys_to_rm = [
        "learning_rate",
        "per_device_train_batch_size",
        "gradient_accumulation_steps",
        "max_seq_length",
        "load_in_4bit",
        "split",
    ]
    for key in keys_to_rm:
        if key in inference_hyperparameters:
            del inference_hyperparameters[key]

    # Create an inference job
    logging.info(
        f"Running inference for {model_to_evaluate} with parameters: {json.dumps(inference_hyperparameters, indent=4)}"
    )
    job = client.inference.create(
        model=model_to_evaluate,
        input_file_id=file_id,
        **inference_hyperparameters,
    )

    if isinstance(job, dict):
        if "results" in job:  # Completed OpenAI jobs
            output = job["results"]
            logging.info(f"Returning loaded outputs with length {len(output)}")
            if n_examples_to_log > 0:
                logging.info(f"Logging {n_examples_to_log} random outputs:")
                random_state = random.getstate()
                for i in random.sample(
                    range(len(output)), min(n_examples_to_log, len(output))
                ):
                    logging.info(json.dumps(output[i], indent=4))
                random.setstate(random_state)
        elif "batch_job_info" in job:  # Failed or running OpenAI batch jobs
            logging.info(f"Got batch job: {json.dumps(job, indent=4)}")
            logging.info("Retry when the OpenAI batch job is complete...")
            return None
        else:
            raise ValueError(f"Unknown job type: {type(job)}")
    else:  # Regular OpenWeights jobs
        logging.info(job)

        # Poll job status
        current_status = job["status"]
        while True:
            job = client.jobs.retrieve(job["id"])
            if job["status"] != current_status:
                # logging.info(job)
                current_status = job["status"]
            if job["status"] in ["completed", "failed", "canceled"]:
                break
            if not wait_for_completion:
                break
            time.sleep(5)

        if not wait_for_completion and job["status"] != "completed":
            logging.info(
                f"Job {job['id']} did not complete, current status: {job['status']}"
            )
            return None

        # Get log file:
        if display_log_file:
            runs = client.runs.list(job_id=job["id"])
            for run in runs:
                print(run)
                if run["log_file"]:
                    log = client.files.content(run["log_file"]).decode("utf-8")
                    print(log)
                print("---")

        # Get output
        job = client.jobs.retrieve(job["id"])
        output_file_id = job["outputs"]["file"]
        output = client.files.content(output_file_id).decode("utf-8")
        output = [json.loads(line) for line in output.splitlines() if line.strip()]

    return output


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    output = run_inference_job_and_get_outputs(
        filepath_conversations=os.path.join(
            os.path.dirname(__file__), "../tests/inference_dataset_with_prefill.jsonl"
        ),
        model_to_evaluate="openai/gpt-4.1-mini",
        inference_hyperparameters={
            "max_tokens": 1000,
            "temperature": 0.8,
            "max_model_len": 2048,
            "n_completions_per_prompt": 1,
            "use_batch": False,
        },
        n_examples_to_log=1,
    )
    print("parallel output:", output)

    output = run_inference_job_and_get_outputs(
        filepath_conversations=os.path.join(
            os.path.dirname(__file__), "../tests/inference_dataset_with_prefill.jsonl"
        ),
        model_to_evaluate="openai/gpt-4.1-mini",
        inference_hyperparameters={
            "max_tokens": 1000,
            "temperature": 0.8,
            "max_model_len": 2048,
            "n_completions_per_prompt": 1,
            "use_batch": True,
        },
        n_examples_to_log=1,
    )
    print("batch output:", output)

openweights/jobs/inference/openai_support.py

Lines changed: 42 additions & 2 deletions
@@ -22,6 +22,10 @@ def create_openai_inference_batch_request(
         import logging
         import time
 
+        logging.warning(
+            "OpenAI batch API support through OpenWeights is not tested.\nIssues include:\n- Files sent twice to OpenAI produce different file IDs. This should now be solved with the permanent caching on the function sending the file."
+        )
+
         # Initialize OpenAI client
         client = self._init_openai_client()
 
@@ -60,11 +64,23 @@ def create_openai_inference_batch_request(
         )
 
         # Check for existing batch jobs using this batch file
+        found_batch = False
         try:
             logging.info(f"Checking for existing batch jobs for file {batch_file.id}")
             existing_batches = client.batches.list()
+            # First check for completed batch jobs
+            for batch in existing_batches.data:
+                if batch.input_file_id == batch_file.id and batch.status == "completed":
+                    found_batch = True
+                    logging.info(
+                        f"Found existing batch job {batch.id} for batch file {batch_file.id}"
+                    )
+                    batch_job = client.batches.retrieve(batch.id)
+                    return self.get_batch_job_data(client, batch_job)
+            # Then check for running batch jobs
             for batch in existing_batches.data:
                 if batch.input_file_id == batch_file.id:
+                    found_batch = True
                     logging.info(
                         f"Found existing batch job {batch.id} for batch file {batch_file.id}"
                     )
@@ -73,6 +89,12 @@ def create_openai_inference_batch_request(
         except Exception as e:
             logging.error(f"Error checking existing batch jobs: {str(e)}")
 
+        if found_batch:
+            return {
+                "status": "completed",
+                "results": "Failed to retrieve batch job data",
+            }
+
         # If no existing batch found, create new batch job
         batch_job = client.batches.create(
             input_file_id=batch_file.id,
@@ -298,10 +320,28 @@ def get_batch_job_data(self, openai_client, batch_job):
         logging.info(f"Batch job status: {batch_data.status}")
         if batch_data.status == "completed":
             logging.info(f"Retrieving results for file {batch_data.output_file_id}")
-            file_data = openai_client.files.retrieve(batch_data.output_file_id)
+            file_content = openai_client.files.content(batch_data.output_file_id)
+
+            result_file_name = os.path.join(
+                os.path.dirname(os.path.dirname(__file__)),
+                "tmp.jsonl",
+            )
+            with open(result_file_name, "wb") as file:
+                file.write(file_content.content)
+
+            # Loading data from saved file
+            results = []
+            with open(result_file_name, "r") as file:
+                for line in file:
+                    # Parsing the JSON string into a dict and appending to the list of results
+                    json_object = json.loads(line.strip())
+                    results.append(json_object)
+
+            os.remove(result_file_name)
+
             return {
                 "status": "completed",
-                "results": json.loads(file_data.content),
+                "results": results,
                 "batch_job_info": json.loads(json.dumps(batch_data.model_dump())),
             }
         else:
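
Design note on the new result handling: the added code writes the downloaded batch output to a temporary tmp.jsonl file beside the package, parses it line by line, then deletes it. A sketch of an equivalent in-memory parse, offered only as an alternative and not what the commit does, using the same file_content object the committed code already holds:

# Sketch: parse the JSONL batch output directly from the downloaded bytes
results = [
    json.loads(line)
    for line in file_content.content.decode("utf-8").splitlines()
    if line.strip()
]

This avoids the temporary file and the os.remove call, at the cost of holding the decoded output in memory, which is usually acceptable for batch result files.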
