diff --git a/train_dynamics.py b/train_dynamics.py
index c64013f..8f191ae 100644
--- a/train_dynamics.py
+++ b/train_dynamics.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
 import os
+import time
 
 import einops
 from flax.training.train_state import TrainState
@@ -70,6 +71,7 @@ class Args:
     log_checkpoint_interval: int = 25000
     log_checkpoint_keep_period: int = 20000
     log_gradients: bool = False
+    time_measurement_interval: int = 50
 
 
 args = tyro.cli(Args)
@@ -242,6 +244,7 @@ def train_step(state, inputs):
     prefetch_buffer_size=1,
     seed=args.seed,
 )
+step = 0
 
 initial_state = grain_dataloader._create_initial_state()
 grain_iterator = grain.DataLoaderIterator(grain_dataloader, initial_state)
@@ -268,6 +271,7 @@ def train_step(state, inputs):
 
 # --- TRAIN LOOP ---
 dataloader = (jax.make_array_from_process_local_data(videos_sharding, elem) for elem in grain_iterator) # type: ignore
+start_time = time.time()
 while step < args.num_steps:
     for videos in dataloader:
         # --- Train step ---
@@ -280,7 +284,16 @@ def train_step(state, inputs):
             mask_rng=_rng_mask,
         )
         train_state, loss, recon, metrics = train_step(train_state, inputs)
-        print(f"Step {step}, loss: {loss}")
+        if step % args.time_measurement_interval == 0:
+            jax.block_until_ready(train_state)
+            elapsed_time = (time.time() - start_time)
+            avg_step_time = elapsed_time / args.time_measurement_interval
+            print(f"Step {step}, loss: {loss}, avg step time: {avg_step_time:.2f}s")
+            if args.log and jax.process_index() == 0:
+                wandb.log({"avg_step_time_s": avg_step_time, "step": step})
+            start_time = time.time()
+        else:
+            print(f"Step {step}, loss: {loss}")
         step += 1
 
         # --- Logging ---
diff --git a/train_lam.py b/train_lam.py
index d016b24..97c2aff 100644
--- a/train_lam.py
+++ b/train_lam.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
 import os
+import time
 
 import einops
 from flax.training import orbax_utils
@@ -59,6 +60,7 @@ class Args:
     log_image_interval: int = 250
     ckpt_dir: str = ""
     log_checkpoint_interval: int = 10000
+    time_measurement_interval: int = 50
     log_checkpoint_keep_period: int = 20000
 
 
@@ -257,6 +259,7 @@ def train_step(state, inputs, action_last_active):
 
 # --- TRAIN LOOP ---
 dataloader = (jax.make_array_from_process_local_data(videos_sharding, elem) for elem in grain_iterator) # type: ignore
+start_time = time.time()
 print(f"Starting training from step {step}...")
 while step < args.num_steps:
     for videos in dataloader:
@@ -267,7 +270,16 @@ def train_step(state, inputs, action_last_active):
         train_state, loss, recon, action_last_active, metrics = train_step(
             train_state, inputs, action_last_active
         )
-        print(f"Step {step}, loss: {loss}")
+        if step % args.time_measurement_interval == 0:
+            jax.block_until_ready(train_state)
+            elapsed_time = (time.time() - start_time)
+            avg_step_time = elapsed_time / args.time_measurement_interval
+            print(f"Step {step}, loss: {loss}, avg step time: {avg_step_time:.2f}s")
+            if args.log and jax.process_index() == 0:
+                wandb.log({"avg_step_time_s": avg_step_time, "step": step})
+            start_time = time.time()
+        else:
+            print(f"Step {step}, loss: {loss}")
         step += 1
 
         # --- Logging ---
diff --git a/train_tokenizer.py b/train_tokenizer.py
index c36dde4..668697c 100644
--- a/train_tokenizer.py
+++ b/train_tokenizer.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
 import os
+import time
 
 import einops
 from flax.training import orbax_utils
@@ -60,6 +61,7 @@ class Args:
     log_checkpoint_interval: int = 10000
     log_checkpoint_keep_period: int = 20000
     log_gradients: bool = False
+    time_measurement_interval: int = 50
 
 
 args = tyro.cli(Args)
@@ -249,6 +251,7 @@ def train_step(state, inputs):
     print(f"Restored dataloader and model state from step {step}")
 
 # --- TRAIN LOOP ---
+start_time = time.time()
 dataloader = (jax.make_array_from_process_local_data(videos_sharding, elem) for elem in grain_iterator) # type: ignore
 print(f"Starting training from step {step}...")
 while step < args.num_steps:
@@ -258,7 +261,16 @@ def train_step(state, inputs):
         rng, _rng, _rng_dropout = jax.random.split(rng, 3)
         inputs = dict(videos=videos, rng=_rng, dropout_rng=_rng_dropout)
         train_state, loss, recon, metrics = train_step(train_state, inputs)
-        print(f"Step {step}, loss: {loss}")
+        if step % args.time_measurement_interval == 0:
+            jax.block_until_ready(train_state)
+            elapsed_time = (time.time() - start_time)
+            avg_step_time = elapsed_time / args.time_measurement_interval
+            print(f"Step {step}, loss: {loss}, avg step time: {avg_step_time:.2f}s")
+            if args.log and jax.process_index() == 0:
+                wandb.log({"avg_step_time_s": avg_step_time, "step": step})
+            start_time = time.time()
+        else:
+            print(f"Step {step}, loss: {loss}")
         step += 1
 
         # --- Logging ---