
Commit 0610100

Add Inspect AI evaluation integration and checkpoint resume

1 parent cf1c83e commit 0610100

File tree: inspect_eval.py, trainer_with_eval.py

2 files changed (+162, -7 lines)

inspect_eval.py

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
"""
Inspect AI evaluation integration with Tinker sampling.
"""

from typing import Any, Dict, List, Optional

try:
    from inspect_ai import Task, task, eval_async
    from inspect_ai.dataset import MemoryDataset, Sample
    from inspect_ai.model import GenerateConfig, Model
    from inspect_ai.scorer import match, includes
    from inspect_ai.solver import generate
    INSPECT_AVAILABLE = True
except ImportError:
    INSPECT_AVAILABLE = False

    Task = Any  # placeholder so the return annotation below still resolves

    def task(func):  # no-op decorator fallback when inspect_ai is absent
        return func

try:
    from tinker_cookbook.eval.inspect_utils import InspectAPIFromTinkerSampling
    TINKER_INSPECT_AVAILABLE = True
except ImportError:
    TINKER_INSPECT_AVAILABLE = False


QA_SAMPLES = [
    Sample(input="What is 2 + 2?", target="4"),
    Sample(input="What is the capital of France?", target="Paris"),
    Sample(input="What color is grass?", target="green"),
    Sample(input="How many days in a week?", target="7"),
    Sample(input="What is 10 x 5?", target="50"),
] if INSPECT_AVAILABLE else []  # guard: Sample is undefined when inspect_ai is missing


@task
def simple_qa_task() -> Task:
    """
    Simple QA evaluation task for demo purposes.

    Tests basic factual knowledge with exact match scoring.
    """
    if not INSPECT_AVAILABLE:
        raise ImportError("inspect_ai required for this task")

    return Task(
        name="simple_qa",
        dataset=MemoryDataset(name="simple_qa", samples=QA_SAMPLES),
        solver=generate(),
        scorer=includes(),
    )


async def run_inspect_evaluation(
    service_client: Any,
    model_path: str,
    model_name: str,
    renderer_name: str,
    tasks: List[str],
) -> float:
    """
    Run Inspect AI evaluation using Tinker sampling.

    Args:
        service_client: Tinker service client.
        model_path: Path to model checkpoint (tinker:// or mock://).
        model_name: Base model name.
        renderer_name: Renderer name for message formatting.
        tasks: List of task names to evaluate.

    Returns:
        Aggregate accuracy score.
    """
    if not INSPECT_AVAILABLE:
        print(" Warning: inspect_ai not available, using fallback")
        return 0.5

    if model_path.startswith("mock://"):
        print(" Warning: Mock mode - using simulated eval")
        return 0.5

    try:
        if not TINKER_INSPECT_AVAILABLE:
            print(" Warning: tinker_cookbook.eval not available")
            results = await eval_async(
                tasks=[simple_qa_task()],
                model=f"openai/{model_name}",
            )
        else:
            sampling_client = service_client.create_sampling_client(model_path=model_path)

            api = InspectAPIFromTinkerSampling(
                renderer_name=renderer_name,
                model_name=model_name,
                sampling_client=sampling_client,
                verbose=False,
            )

            model = Model(api=api, config=GenerateConfig(max_tokens=100, temperature=0.0))

            eval_tasks = [simple_qa_task()] if "simple_qa" in tasks else []

            results = await eval_async(tasks=eval_tasks, model=model)

        if results and len(results) > 0:
            scores = [r.scores[0].value for r in results if r.scores]
            return sum(scores) / len(scores) if scores else 0.0

        return 0.0

    except Exception as e:
        print(f" Warning: Inspect AI evaluation failed: {e}")
        return 0.5
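
For orientation, here is a minimal sketch of how this helper might be driven on its own; the service-client construction, checkpoint URI, base-model name, and renderer name are illustrative placeholders, not part of this commit:

# demo_inspect_eval.py - illustrative driver, not part of this commit
import asyncio

import tinker  # assumes the tinker SDK is installed

from inspect_eval import run_inspect_evaluation


async def demo() -> None:
    service_client = tinker.ServiceClient()  # assumed client construction
    score = await run_inspect_evaluation(
        service_client=service_client,
        model_path="tinker://<checkpoint-uri>",  # placeholder checkpoint URI
        model_name="<base-model-name>",          # placeholder base model
        renderer_name="<renderer-name>",         # placeholder renderer
        tasks=["simple_qa"],
    )
    print(f"simple_qa score: {score:.2f}")


if __name__ == "__main__":
    asyncio.run(demo())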

trainer_with_eval.py

Lines changed: 52 additions & 7 deletions
@@ -66,13 +66,16 @@
    from simple_eval import run_simple_evaluation
    from hyperparam_utils import get_recommended_lr, get_lr_with_warmup
    from logger import StructuredLogger
+    from checkpoint_manager import CheckpointManager, find_latest_run
except ImportError:
    TrainingConfig = None
    DataLoader = None
    run_simple_evaluation = None
    get_recommended_lr = None
    get_lr_with_warmup = None
    StructuredLogger = None
+    CheckpointManager = None
+    find_latest_run = None


def prepare_training_data(
@@ -238,7 +241,7 @@ async def run_evaluations(
    return score


-async def async_main(config_path: str) -> None:
+async def async_main(config_path: str, resume: bool = False) -> None:
    """Main training loop with async EvalOps integration."""
    if TrainingConfig is None:
        raise ImportError("config_schema module required. Please ensure all dependencies are installed.")

@@ -248,12 +251,38 @@ async def async_main(config_path: str) -> None:
    if tinker is None and not USE_MOCK:
        raise ImportError("The `tinker` package is not installed. Please install it via `pip install tinker` or run in mock mode with TINKER_MOCK=1.")

-    run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+    start_round = 1
+    start_step = 0
+    start_lr = None
+    checkpoint_to_load = None
+
+    if resume and find_latest_run:
+        latest_run = find_latest_run()
+        if latest_run and CheckpointManager:
+            mgr = CheckpointManager(latest_run)
+            saved_state = mgr.load_run_state()
+            if saved_state:
+                start_round = saved_state["round_idx"] + 1
+                start_step = saved_state["global_step"]
+                start_lr = saved_state["learning_rate"]
+                checkpoint_to_load = saved_state.get("checkpoint_uri")
+                run_dir = latest_run
+                print(f"Resuming from round {start_round}, step {start_step}")
+                print(f"Loading checkpoint: {checkpoint_to_load}")
+            else:
+                print("No saved state found, starting fresh")
+                run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+        else:
+            run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+    else:
+        run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+
    run_dir.mkdir(parents=True, exist_ok=True)

    logger = StructuredLogger(run_dir=run_dir) if StructuredLogger else None
+    checkpoint_mgr = CheckpointManager(run_dir) if CheckpointManager else None

-    if logger:
+    if logger and not resume:
        logger.log_config(config.model_dump())
        print(f"Logging to {run_dir}/metrics.jsonl")

@@ -264,13 +293,15 @@ async def async_main(config_path: str) -> None:
    tasks = config.eval_tasks
    renderer_name = config.renderer_name

-    if config.use_recommended_lr and get_recommended_lr:
+    if start_lr:
+        learning_rate = start_lr
+    elif config.use_recommended_lr and get_recommended_lr:
        learning_rate = get_recommended_lr(base_model)
        print(f"Using recommended LR for {base_model}: {learning_rate:.2e}")
    else:
        learning_rate = config.learning_rate

-    global_step = 0
+    global_step = start_step

    evalops_enabled = config.evalops_enabled
    test_suite_id = config.evalops_test_suite_id

@@ -295,6 +326,10 @@
    )

    tokenizer = training_client.get_tokenizer()
+
+    if checkpoint_to_load:
+        print("Loading training state from checkpoint...")
+        training_client.load_state(checkpoint_to_load)

    datums = prepare_training_data(
        train_file=config.train_file,

@@ -307,7 +342,7 @@
        print("Warning: no training data loaded. Check that your training file has valid examples.")

    try:
-        for round_idx in range(1, max_rounds + 1):
+        for round_idx in range(start_round, max_rounds + 1):
            print(f"\n=== Training round {round_idx}/{max_rounds} ===")

            if hasattr(training_client, 'forward_backward_async'):
@@ -347,6 +382,15 @@

            if logger:
                logger.log_checkpoint(round_idx, state_uri)
+
+            if checkpoint_mgr:
+                checkpoint_mgr.save_run_state(
+                    round_idx=round_idx,
+                    global_step=global_step,
+                    learning_rate=learning_rate,
+                    checkpoint_uri=state_uri,
+                    config=config.model_dump(),
+                )

            print("Running evaluations...")
            score = await run_evaluations(
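
The checkpoint_manager module referenced above is not part of this commit. Based only on the call sites in this diff (find_latest_run(), CheckpointManager(run_dir), load_run_state(), save_run_state(...)), a hypothetical sketch of its interface could look like this; the run_state.json filename and the directory-sorting scheme are assumptions:

# checkpoint_manager.py - hypothetical sketch inferred from the call sites in this diff
import json
from pathlib import Path
from typing import Any, Dict, Optional


def find_latest_run(runs_dir: str = "runs") -> Optional[Path]:
    """Return the most recent run directory; timestamped names sort lexically."""
    root = Path(runs_dir)
    if not root.exists():
        return None
    runs = sorted(d for d in root.iterdir() if d.is_dir())
    return runs[-1] if runs else None


class CheckpointManager:
    """Persist and restore the loop's resume state as JSON inside a run directory."""

    def __init__(self, run_dir: Path) -> None:
        self.state_path = Path(run_dir) / "run_state.json"  # assumed filename

    def save_run_state(
        self,
        round_idx: int,
        global_step: int,
        learning_rate: float,
        checkpoint_uri: Optional[str],
        config: Dict[str, Any],
    ) -> None:
        state = {
            "round_idx": round_idx,
            "global_step": global_step,
            "learning_rate": learning_rate,
            "checkpoint_uri": checkpoint_uri,
            "config": config,
        }
        self.state_path.write_text(json.dumps(state, indent=2))

    def load_run_state(self) -> Optional[Dict[str, Any]]:
        if not self.state_path.exists():
            return None
        return json.loads(self.state_path.read_text())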
@@ -381,9 +425,10 @@
def main() -> None:
    parser = argparse.ArgumentParser(description="Evaluation‑driven fine‑tuning loop")
    parser.add_argument("--config", type=str, required=True, help="Path to configuration JSON file")
+    parser.add_argument("--resume", action="store_true", help="Resume from latest checkpoint")
    args = parser.parse_args()

-    asyncio.run(async_main(args.config))
+    asyncio.run(async_main(args.config, resume=args.resume))


if __name__ == "__main__":
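
With the new flag, a run might be resumed like this (the config filename is illustrative):

# start a fresh run; each round saves its state under runs/<timestamp>/
python trainer_with_eval.py --config config.json

# later, continue the most recent run from its saved round, step, LR, and checkpoint
python trainer_with_eval.py --config config.json --resume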
