
Commit b8127e9

Implement async futures with proper batching for performance
1 parent c2e4ddf commit b8127e9

File tree

1 file changed (+23, -8 lines)


trainer_with_eval.py

Lines changed: 23 additions & 8 deletions
```diff
@@ -287,20 +287,35 @@ async def async_main(config_path: str) -> None:
     for round_idx in range(1, max_rounds + 1):
         print(f"\n=== Training round {round_idx}/{max_rounds} ===")
 
-        if get_lr_with_warmup and config.warmup_steps > 0:
-            current_lr = get_lr_with_warmup(
-                step=global_step,
+        if hasattr(training_client, 'forward_backward_async'):
+            steps_executed = await run_training_round_async(
+                training_client=training_client,
+                datums=datums,
+                batch_size=config.batch_size,
+                steps_per_round=config.steps_per_round,
                 base_lr=learning_rate,
+                step_offset=global_step,
                 warmup_steps=config.warmup_steps,
                 max_steps=config.max_steps,
                 min_lr=config.min_lr,
             )
-            print(f" Step {global_step}: LR = {current_lr:.2e}")
+            print(f" Completed {steps_executed} training steps")
+            global_step += steps_executed
         else:
-            current_lr = learning_rate
-
-        run_training_round(training_client, datums, current_lr)
-        global_step += config.steps_per_round
+            if get_lr_with_warmup and config.warmup_steps > 0:
+                current_lr = get_lr_with_warmup(
+                    step=global_step,
+                    base_lr=learning_rate,
+                    warmup_steps=config.warmup_steps,
+                    max_steps=config.max_steps,
+                    min_lr=config.min_lr,
+                )
+                print(f" Step {global_step}: LR = {current_lr:.2e}")
+            else:
+                current_lr = learning_rate
+
+            run_training_round(training_client, datums, current_lr)
+            global_step += config.steps_per_round
 
         print("Saving model checkpoint...")
         weights_uri = training_client.save_weights_for_sampler(name=f"round_{round_idx}")
```
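`run_training_round_async` is defined elsewhere in the file, so its body is not part of this diff. As a rough sketch of the pattern the commit message describes (submitting forward/backward and optimizer requests as futures, then resolving them in a batch), the helper could look like the following. Everything here is an assumption rather than the commit's actual implementation: the `forward_backward_async`/`optim_step_async` signatures, the `result_async()` method on the returned future-like objects, the in-order server-side execution of queued requests, and the inline warmup-plus-cosine schedule standing in for `get_lr_with_warmup` are all illustrative.

```python
import math


async def run_training_round_async(
    training_client, datums, batch_size, steps_per_round,
    base_lr, step_offset, warmup_steps, max_steps, min_lr,
):
    """Hypothetical async round: enqueue every step's requests without
    blocking on results, then drain the futures in one batch so request
    submission overlaps with server-side execution."""

    def lr_at(step):
        # Linear warmup, then cosine decay to min_lr; mirrors the
        # keyword arguments passed to get_lr_with_warmup in the diff.
        if step < warmup_steps:
            return base_lr * (step + 1) / warmup_steps
        progress = min(1.0, (step - warmup_steps) / max(1, max_steps - warmup_steps))
        return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

    pending = []
    steps_executed = 0
    for step in range(steps_per_round):
        start = (step * batch_size) % len(datums)
        batch = datums[start:start + batch_size]
        lr = lr_at(step_offset + step)
        # Assumed API: the *_async methods return future-like handles;
        # the real client's signatures may differ.
        fb = await training_client.forward_backward_async(batch, loss_fn="cross_entropy")
        opt = await training_client.optim_step_async(learning_rate=lr)
        pending.append((fb, opt))
        steps_executed += 1

    # Resolve all futures together; batching the awaits amortizes
    # round-trip latency across the whole round instead of paying it
    # once per step.
    for fb, opt in pending:
        await fb.result_async()
        await opt.result_async()
    return steps_executed
```

Dispatching on `hasattr(training_client, 'forward_backward_async')` keeps the script working against clients that only expose the synchronous path, which falls back to the original per-round learning-rate computation and `run_training_round` call.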
