import asyncio
import json
import os
+from datetime import datetime
from pathlib import Path
from typing import Dict, Any, Optional

    from data_loader import DataLoader
    from simple_eval import run_simple_evaluation
    from hyperparam_utils import get_recommended_lr, get_lr_with_warmup
+    from logger import StructuredLogger
except ImportError:
    TrainingConfig = None
    DataLoader = None
    run_simple_evaluation = None
    get_recommended_lr = None
    get_lr_with_warmup = None
+    StructuredLogger = None


def prepare_training_data(
@@ -245,6 +248,15 @@ async def async_main(config_path: str) -> None:
    if tinker is None and not USE_MOCK:
        raise ImportError("The `tinker` package is not installed. Please install it via `pip install tinker` or run in mock mode with TINKER_MOCK=1.")

+    run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+    run_dir.mkdir(parents=True, exist_ok=True)
+
+    logger = StructuredLogger(run_dir=run_dir) if StructuredLogger else None
+
+    if logger:
+        logger.log_config(config.model_dump())
+        print(f"Logging to {run_dir}/metrics.jsonl")
+
    service_client = tinker.ServiceClient()
    base_model = config.base_model
    max_rounds = config.max_rounds
@@ -332,6 +344,9 @@ async def async_main(config_path: str) -> None:
        weights_uri = training_client.save_weights_for_sampler(name=f"round_{round_idx}")
        state_uri = weights_uri.result().path if hasattr(weights_uri, 'result') else weights_uri
        print(f"Checkpoint saved at {state_uri}")
+
+        if logger:
+            logger.log_checkpoint(round_idx, state_uri)

        print("Running evaluations...")
        score = await run_evaluations(
@@ -346,6 +361,9 @@ async def async_main(config_path: str) -> None:
            round_number=round_idx,
        )
        print(f"Evaluation score: {score:.4f}")
+
+        if logger:
+            logger.log_evaluation(round_idx, score, eval_threshold, score >= eval_threshold)

        if score >= eval_threshold:
            print(f"Target met: {score:.4f} >= {eval_threshold}. Stopping.")
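The new calls assume a `StructuredLogger` class from a local `logger` module with `log_config`, `log_checkpoint`, and `log_evaluation` methods; that module is not shown in this diff. A minimal sketch of what such a logger could look like, assuming an append-only JSON-lines file at `run_dir/metrics.jsonl` (the record field names and the separate `config.json` file are illustrative choices, not taken from the actual module):

```python
import json
from datetime import datetime
from pathlib import Path
from typing import Any, Dict


class StructuredLogger:
    """Appends one JSON record per event to run_dir/metrics.jsonl."""

    def __init__(self, run_dir: Path) -> None:
        self.run_dir = Path(run_dir)
        self.metrics_path = self.run_dir / "metrics.jsonl"

    def _write(self, record: Dict[str, Any]) -> None:
        # Timestamp every record so runs can be replayed in order.
        record["timestamp"] = datetime.now().isoformat()
        with self.metrics_path.open("a") as f:
            f.write(json.dumps(record) + "\n")

    def log_config(self, config: Dict[str, Any]) -> None:
        # Keep the full config in its own file; metrics.jsonl stays per-event.
        (self.run_dir / "config.json").write_text(json.dumps(config, indent=2))

    def log_checkpoint(self, round_idx: int, state_uri: str) -> None:
        self._write({"event": "checkpoint", "round": round_idx, "state_uri": state_uri})

    def log_evaluation(self, round_idx: int, score: float, threshold: float, passed: bool) -> None:
        self._write({
            "event": "evaluation",
            "round": round_idx,
            "score": score,
            "threshold": threshold,
            "passed": passed,
        })
```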