@@ -287,20 +287,35 @@ async def async_main(config_path: str) -> None:
     for round_idx in range(1, max_rounds + 1):
         print(f"\n=== Training round {round_idx}/{max_rounds} ===")
 
-        if get_lr_with_warmup and config.warmup_steps > 0:
-            current_lr = get_lr_with_warmup(
-                step=global_step,
+        if hasattr(training_client, 'forward_backward_async'):
+            steps_executed = await run_training_round_async(
+                training_client=training_client,
+                datums=datums,
+                batch_size=config.batch_size,
+                steps_per_round=config.steps_per_round,
                 base_lr=learning_rate,
+                step_offset=global_step,
                 warmup_steps=config.warmup_steps,
                 max_steps=config.max_steps,
                 min_lr=config.min_lr,
             )
-            print(f"  Step {global_step}: LR = {current_lr:.2e}")
+            print(f"  Completed {steps_executed} training steps")
+            global_step += steps_executed
         else:
-            current_lr = learning_rate
-
-        run_training_round(training_client, datums, current_lr)
-        global_step += config.steps_per_round
+            if get_lr_with_warmup and config.warmup_steps > 0:
+                current_lr = get_lr_with_warmup(
+                    step=global_step,
+                    base_lr=learning_rate,
+                    warmup_steps=config.warmup_steps,
+                    max_steps=config.max_steps,
+                    min_lr=config.min_lr,
+                )
+                print(f"  Step {global_step}: LR = {current_lr:.2e}")
+            else:
+                current_lr = learning_rate
+
+            run_training_round(training_client, datums, current_lr)
+            global_step += config.steps_per_round
 
         print("Saving model checkpoint...")
         weights_uri = training_client.save_weights_for_sampler(name=f"round_{round_idx}")
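For reference, here is a minimal sketch of the kind of `get_lr_with_warmup` schedule the fallback branch above passes its arguments to: linear warmup to `base_lr` over `warmup_steps`, then cosine decay toward `min_lr` by `max_steps`. This is an assumed implementation based only on the keyword arguments visible in the diff, not necessarily the repository's actual function.

```python
import math


def get_lr_with_warmup(
    step: int,
    base_lr: float,
    warmup_steps: int,
    max_steps: int,
    min_lr: float = 0.0,
) -> float:
    """Linear warmup to base_lr, then cosine decay toward min_lr.

    Illustrative sketch only; it mirrors the keyword arguments used in the
    training loop above, not the repository's exact schedule.
    """
    if warmup_steps > 0 and step < warmup_steps:
        # Ramp linearly from 0 up to base_lr over the warmup window.
        return base_lr * (step + 1) / warmup_steps
    if step >= max_steps:
        # Past the schedule horizon, hold at the floor learning rate.
        return min_lr
    # Cosine decay from base_lr down to min_lr between warmup_steps and max_steps.
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

Passing `step_offset=global_step` into the async path, as the diff does, is consistent with this kind of schedule: the helper can keep the warmup/decay position continuous across rounds by offsetting its per-round step counter.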