 from simple_eval import run_simple_evaluation
 from hyperparam_utils import get_recommended_lr, get_lr_with_warmup
 from logger import StructuredLogger
+from checkpoint_manager import CheckpointManager, find_latest_run
 except ImportError:
     TrainingConfig = None
     DataLoader = None
     run_simple_evaluation = None
     get_recommended_lr = None
     get_lr_with_warmup = None
     StructuredLogger = None
+    CheckpointManager = None
+    find_latest_run = None
 
 
 def prepare_training_data(
@@ -238,7 +241,7 @@ async def run_evaluations(
     return score
 
 
-async def async_main(config_path: str) -> None:
+async def async_main(config_path: str, resume: bool = False) -> None:
     """Main training loop with async EvalOps integration."""
     if TrainingConfig is None:
         raise ImportError("config_schema module required. Please ensure all dependencies are installed.")
@@ -248,12 +251,38 @@ async def async_main(config_path: str) -> None:
     if tinker is None and not USE_MOCK:
         raise ImportError("The `tinker` package is not installed. Please install it via `pip install tinker` or run in mock mode with TINKER_MOCK=1.")
 
-    run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+    start_round = 1
+    start_step = 0
+    start_lr = None
+    checkpoint_to_load = None
+
+    if resume and find_latest_run:
+        latest_run = find_latest_run()
+        if latest_run and CheckpointManager:
+            mgr = CheckpointManager(latest_run)
+            saved_state = mgr.load_run_state()
+            if saved_state:
+                start_round = saved_state["round_idx"] + 1
+                start_step = saved_state["global_step"]
+                start_lr = saved_state["learning_rate"]
+                checkpoint_to_load = saved_state.get("checkpoint_uri")
+                run_dir = latest_run
+                print(f"Resuming from round {start_round}, step {start_step}")
+                print(f"Loading checkpoint: {checkpoint_to_load}")
+            else:
+                print("No saved state found, starting fresh")
+                run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+        else:
+            run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+    else:
+        run_dir = Path("runs") / datetime.now().strftime("%Y%m%d_%H%M%S")
+
     run_dir.mkdir(parents=True, exist_ok=True)
 
     logger = StructuredLogger(run_dir=run_dir) if StructuredLogger else None
+    checkpoint_mgr = CheckpointManager(run_dir) if CheckpointManager else None
 
-    if logger:
+    if logger and not resume:
         logger.log_config(config.model_dump())
         print(f"Logging to {run_dir}/metrics.jsonl")
 
@@ -264,13 +293,15 @@ async def async_main(config_path: str) -> None:
     tasks = config.eval_tasks
     renderer_name = config.renderer_name
 
-    if config.use_recommended_lr and get_recommended_lr:
+    if start_lr:
+        learning_rate = start_lr
+    elif config.use_recommended_lr and get_recommended_lr:
         learning_rate = get_recommended_lr(base_model)
         print(f"Using recommended LR for {base_model}: {learning_rate:.2e}")
     else:
         learning_rate = config.learning_rate
 
-    global_step = 0
+    global_step = start_step
 
     evalops_enabled = config.evalops_enabled
     test_suite_id = config.evalops_test_suite_id
@@ -295,6 +326,10 @@ async def async_main(config_path: str) -> None:
     )
 
     tokenizer = training_client.get_tokenizer()
+
+    if checkpoint_to_load:
+        print("Loading training state from checkpoint...")
+        training_client.load_state(checkpoint_to_load)
 
     datums = prepare_training_data(
         train_file=config.train_file,
@@ -307,7 +342,7 @@ async def async_main(config_path: str) -> None:
         print("Warning: no training data loaded. Check that your training file has valid examples.")
 
     try:
-        for round_idx in range(1, max_rounds + 1):
+        for round_idx in range(start_round, max_rounds + 1):
             print(f"\n=== Training round {round_idx}/{max_rounds} ===")
 
             if hasattr(training_client, 'forward_backward_async'):
@@ -347,6 +382,15 @@ async def async_main(config_path: str) -> None:
 
             if logger:
                 logger.log_checkpoint(round_idx, state_uri)
+
+            if checkpoint_mgr:
+                checkpoint_mgr.save_run_state(
+                    round_idx=round_idx,
+                    global_step=global_step,
+                    learning_rate=learning_rate,
+                    checkpoint_uri=state_uri,
+                    config=config.model_dump(),
+                )
 
             print("Running evaluations...")
             score = await run_evaluations(
@@ -381,9 +425,10 @@ async def async_main(config_path: str) -> None:
 def main() -> None:
     parser = argparse.ArgumentParser(description="Evaluation-driven fine-tuning loop")
     parser.add_argument("--config", type=str, required=True, help="Path to configuration JSON file")
+    parser.add_argument("--resume", action="store_true", help="Resume from latest checkpoint")
     args = parser.parse_args()
 
-    asyncio.run(async_main(args.config))
+    asyncio.run(async_main(args.config, resume=args.resume))
 
 
 if __name__ == "__main__":
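Note on the resume path: it relies on a `checkpoint_manager` module exposing `CheckpointManager.save_run_state`, `CheckpointManager.load_run_state`, and `find_latest_run`. That module is not shown in this hunk; the sketch below is only an assumed minimal interface consistent with how the diff calls it, not the PR's actual implementation.

```python
# Assumed sketch of the checkpoint_manager interface used above;
# the real module in this PR may differ in details.
import json
from pathlib import Path
from typing import Any, Optional


class CheckpointManager:
    """Persist lightweight run state (round, step, LR, checkpoint URI) as JSON."""

    def __init__(self, run_dir: Path) -> None:
        self.run_dir = Path(run_dir)
        self.state_file = self.run_dir / "run_state.json"

    def save_run_state(self, **state: Any) -> None:
        # Overwrite on every round so the file always holds the latest state.
        self.state_file.write_text(json.dumps(state, indent=2))

    def load_run_state(self) -> Optional[dict]:
        # Return None when no state has been saved yet.
        if not self.state_file.exists():
            return None
        return json.loads(self.state_file.read_text())


def find_latest_run(runs_root: str = "runs") -> Optional[Path]:
    """Return the newest run directory (timestamped names sort lexically), or None."""
    root = Path(runs_root)
    if not root.exists():
        return None
    runs = sorted(p for p in root.iterdir() if p.is_dir())
    return runs[-1] if runs else None
```

With something like that in place, a run would be resumed via `python <entry_point>.py --config config.json --resume`; the entry-point filename is not visible in this diff.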