Skip to content

Commit cf1c83e

Browse files
committed
Add structured run logging: per-run metrics directory, config snapshot, and per-round checkpoint/evaluation records via StructuredLogger
1 parent 1eb561d commit cf1c83e

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

trainer_with_eval.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import asyncio
2424
import json
2525
import os
26+
from datetime import datetime
2627
from pathlib import Path
2728
from typing import Dict, Any, Optional
2829

@@ -64,12 +65,14 @@
6465
from data_loader import DataLoader
6566
from simple_eval import run_simple_evaluation
6667
from hyperparam_utils import get_recommended_lr, get_lr_with_warmup
68+
from logger import StructuredLogger
6769
except ImportError:
6870
TrainingConfig = None
6971
DataLoader = None
7072
run_simple_evaluation = None
7173
get_recommended_lr = None
7274
get_lr_with_warmup = None
75+
StructuredLogger = None
7376

7477

7578
def prepare_training_data(
@@ -245,6 +248,15 @@ async def async_main(config_path: str) -> None:
245248
if tinker is None and not USE_MOCK:
246249
raise ImportError("The `tinker` package is not installed. Please install it via `pip install tinker` or run in mock mode with TINKER_MOCK=1.")
247250

251+
run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
252+
run_dir.mkdir(parents=True, exist_ok=True)
253+
254+
logger = StructuredLogger(run_dir=run_dir) if StructuredLogger else None
255+
256+
if logger:
257+
logger.log_config(config.model_dump())
258+
print(f"Logging to {run_dir}/metrics.jsonl")
259+
248260
service_client = tinker.ServiceClient()
249261
base_model = config.base_model
250262
max_rounds = config.max_rounds
@@ -332,6 +344,9 @@ async def async_main(config_path: str) -> None:
332344
weights_uri = training_client.save_weights_for_sampler(name=f"round_{round_idx}")
333345
state_uri = weights_uri.result().path if hasattr(weights_uri, 'result') else weights_uri
334346
print(f"Checkpoint saved at {state_uri}")
347+
348+
if logger:
349+
logger.log_checkpoint(round_idx, state_uri)
335350

336351
print("Running evaluations...")
337352
score = await run_evaluations(
@@ -346,6 +361,9 @@ async def async_main(config_path: str) -> None:
346361
round_number=round_idx,
347362
)
348363
print(f"Evaluation score: {score:.4f}")
364+
365+
if logger:
366+
logger.log_evaluation(round_idx, score, eval_threshold, score >= eval_threshold)
349367

350368
if score >= eval_threshold:
351369
print(f"Target met: {score:.4f} >= {eval_threshold}. Stopping.")

0 commit comments

Comments
 (0)