diff --git a/train_dynamics.py b/train_dynamics.py index fe7fc80..ca0c7dd 100644 --- a/train_dynamics.py +++ b/train_dynamics.py @@ -230,10 +230,8 @@ def train_step(state, inputs): dropout_rng=_rng_dropout, mask_rng=_rng_mask, ) - start_time = time.time() train_state, loss, recon, metrics = train_step(train_state, inputs) - elapsed_time = (time.time() - start_time) * 1000 - print(f"Step {step}, loss: {loss}, step time: {elapsed_time}ms") + print(f"Step {step}, loss: {loss}") step += 1 # --- Logging --- @@ -243,7 +241,6 @@ def train_step(state, inputs): { "loss": loss, "step": step, - "step_time_ms": elapsed_time, **metrics, } ) diff --git a/train_lam.py b/train_lam.py index bc8d50c..e830155 100644 --- a/train_lam.py +++ b/train_lam.py @@ -229,12 +229,10 @@ def train_step(state, inputs, action_last_active): videos = jax.make_array_from_process_local_data(videos_sharding, videos) inputs = dict(videos=videos, rng=_rng) - start_time = time.time() train_state, loss, recon, action_last_active, metrics = train_step( train_state, inputs, action_last_active ) - elapsed_time = (time.time() - start_time) * 1000 - print(f"Step {step}, loss: {loss}, step time: {elapsed_time}ms") + print(f"Step {step}, loss: {loss}") step += 1 # --- Logging --- @@ -244,7 +242,6 @@ def train_step(state, inputs, action_last_active): { "loss": loss, "step": step, - "step_time_ms": elapsed_time, **metrics, } ) diff --git a/train_tokenizer.py b/train_tokenizer.py index a4a9455..35d4075 100644 --- a/train_tokenizer.py +++ b/train_tokenizer.py @@ -222,10 +222,8 @@ def train_step(state, inputs): videos = jax.make_array_from_process_local_data(videos_sharding, videos) inputs = dict(videos=videos, rng=_rng, dropout_rng=_rng_dropout) - start_time = time.time() train_state, loss, recon, metrics = train_step(train_state, inputs) - elapsed_time = (time.time() - start_time) * 1000 - print(f"Step {step}, loss: {loss}, step time: {elapsed_time}ms") + print(f"Step {step}, loss: {loss}") step += 1 # --- Logging --- @@ -235,7 +233,6 @@ def train_step(state, inputs): { "loss": loss, "step": step, - "step_time_ms": elapsed_time, **metrics, } )