
Commit 782d4d0

Complete evaluation function rewrite with Inspect AI
1 parent 01c6033 commit 782d4d0


1 file changed: +24 -12 lines changed


trainer_with_eval.py

Lines changed: 24 additions & 12 deletions
@@ -196,26 +196,38 @@ async def run_evaluations(
     If EvalOps integration is enabled, results are submitted automatically.
 
     Args:
-        model_path: The path to the model checkpoint. For Tinker models, use
-            the `tinker://...` syntax as described in the docs.
-        model_name: The name of the base model used (e.g., "meta-llama/Llama-3.1-8B").
-        tasks: A list of evaluation task identifiers (e.g., "inspect_evals/ifeval").
-        renderer_name: The name of the renderer to use for message formatting.
-        threshold: A target score; used to decide whether training should continue.
-        evalops_client: Optional EvalOps client for submitting results.
-        test_suite_id: Optional test suite ID in EvalOps.
-        round_number: Optional training round number.
+        model_path: Model checkpoint path (tinker://... or mock://...).
+        model_name: Base model name (e.g., "meta-llama/Llama-3.1-8B").
+        tasks: Evaluation task identifiers.
+        renderer_name: Renderer for message formatting.
+        threshold: Target score for training continuation.
+        service_client: Tinker service client for sampling.
+        training_client: Training client for simple eval fallback.
+        evalops_client: Optional EvalOps client.
+        test_suite_id: Optional EvalOps test suite ID.
+        round_number: Current training round.
 
     Returns:
-        A float representing the aggregated evaluation score. Higher is better.
+        Aggregate evaluation score (0.0-1.0).
     """
-    if run_simple_evaluation is not None:
+    if INSPECT_AVAILABLE and run_inspect_evaluation and service_client and not model_path.startswith("mock://"):
+        try:
+            score = await run_inspect_evaluation(
+                service_client, model_path, model_name, renderer_name, tasks
+            )
+            print(f" Inspect AI evaluation: {score:.4f}")
+        except Exception as e:
+            print(f" Inspect AI failed ({e}), using simple evaluator")
+            score = run_simple_evaluation(
+                training_client, model_path, tasks, round_number=round_number or 1
+            ) if run_simple_evaluation else np.random.rand()
+    elif run_simple_evaluation is not None:
         score = run_simple_evaluation(
             training_client, model_path, tasks, round_number=round_number or 1
         )
     else:
         score = np.random.rand()
-        print(f" Using simulated score: {score:.4f} (implement real evaluation for production)")
+        print(f" Using simulated score: {score:.4f}")
 
     if evalops_client and test_suite_id:
         try:
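For context, the new branch depends on module-level names INSPECT_AVAILABLE, run_inspect_evaluation, and run_simple_evaluation that are defined outside this hunk. A minimal sketch of the kind of guarded imports that would make the fallback chain work; the module names here are hypothetical, not taken from this commit:

import numpy as np

try:
    # Optional Inspect AI integration; when this import fails the function
    # above falls back to the simple evaluator or a simulated score.
    from inspect_eval_utils import run_inspect_evaluation  # hypothetical module
    INSPECT_AVAILABLE = True
except ImportError:
    run_inspect_evaluation = None
    INSPECT_AVAILABLE = False

try:
    from simple_eval import run_simple_evaluation  # hypothetical module
except ImportError:
    run_simple_evaluation = None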
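A rough usage sketch of the updated signature, based only on the docstring in this diff; the surrounding variables, argument values, and log line are illustrative assumptions rather than part of the commit:

# Inside an async training loop (run_evaluations is a coroutine).
score = await run_evaluations(
    model_path=checkpoint_path,            # "tinker://..." or "mock://..."
    model_name="meta-llama/Llama-3.1-8B",
    tasks=["inspect_evals/ifeval"],
    renderer_name=renderer_name,
    threshold=0.8,
    service_client=service_client,
    training_client=training_client,
    evalops_client=evalops_client,         # optional
    test_suite_id=test_suite_id,           # optional
    round_number=round_num,
)
print(f"Round {round_num} aggregate score: {score:.4f}")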