Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion train_dynamics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field
import os
import time

import einops
from flax.training.train_state import TrainState
Expand Down Expand Up @@ -70,6 +71,7 @@ class Args:
log_checkpoint_interval: int = 25000
log_checkpoint_keep_period: int = 20000
log_gradients: bool = False
time_measurement_interval: int = 50


args = tyro.cli(Args)
Expand Down Expand Up @@ -242,6 +244,7 @@ def train_step(state, inputs):
prefetch_buffer_size=1,
seed=args.seed,
)
step = 0
initial_state = grain_dataloader._create_initial_state()
grain_iterator = grain.DataLoaderIterator(grain_dataloader, initial_state)

Expand All @@ -268,6 +271,7 @@ def train_step(state, inputs):

# --- TRAIN LOOP ---
dataloader = (jax.make_array_from_process_local_data(videos_sharding, elem) for elem in grain_iterator) # type: ignore
start_time = time.time()
while step < args.num_steps:
for videos in dataloader:
# --- Train step ---
Expand All @@ -280,7 +284,16 @@ def train_step(state, inputs):
mask_rng=_rng_mask,
)
train_state, loss, recon, metrics = train_step(train_state, inputs)
print(f"Step {step}, loss: {loss}")
if step % args.time_measurement_interval == 0:
jax.block_until_ready(train_state)
elapsed_time = (time.time() - start_time)
avg_step_time = elapsed_time / args.time_measurement_interval
print(f"Step {step}, loss: {loss}, avg step time: {avg_step_time:.2f}s")
if args.log and jax.process_index() == 0:
wandb.log({"avg_step_time_s": avg_step_time, "step": step})
start_time = time.time()
else:
print(f"Step {step}, loss: {loss}")
step += 1

# --- Logging ---
Expand Down
14 changes: 13 additions & 1 deletion train_lam.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field
import os
import time

import einops
from flax.training import orbax_utils
Expand Down Expand Up @@ -59,6 +60,7 @@ class Args:
log_image_interval: int = 250
ckpt_dir: str = ""
log_checkpoint_interval: int = 10000
time_measurement_interval: int = 50
log_checkpoint_keep_period: int = 20000


Expand Down Expand Up @@ -257,6 +259,7 @@ def train_step(state, inputs, action_last_active):

# --- TRAIN LOOP ---
dataloader = (jax.make_array_from_process_local_data(videos_sharding, elem) for elem in grain_iterator) # type: ignore
start_time = time.time()
print(f"Starting training from step {step}...")
while step < args.num_steps:
for videos in dataloader:
Expand All @@ -267,7 +270,16 @@ def train_step(state, inputs, action_last_active):
train_state, loss, recon, action_last_active, metrics = train_step(
train_state, inputs, action_last_active
)
print(f"Step {step}, loss: {loss}")
if step % args.time_measurement_interval == 0:
jax.block_until_ready(train_state)
elapsed_time = (time.time() - start_time)
avg_step_time = elapsed_time / args.time_measurement_interval
print(f"Step {step}, loss: {loss}, avg step time: {avg_step_time:.2f}s")
if args.log and jax.process_index() == 0:
wandb.log({"avg_step_time_s": avg_step_time, "step": step})
start_time = time.time()
else:
print(f"Step {step}, loss: {loss}")
step += 1

# --- Logging ---
Expand Down
14 changes: 13 additions & 1 deletion train_tokenizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field
import os
import time

import einops
from flax.training import orbax_utils
Expand Down Expand Up @@ -60,6 +61,7 @@ class Args:
log_checkpoint_interval: int = 10000
log_checkpoint_keep_period: int = 20000
log_gradients: bool = False
time_measurement_interval: int = 50


args = tyro.cli(Args)
Expand Down Expand Up @@ -249,6 +251,7 @@ def train_step(state, inputs):
print(f"Restored dataloader and model state from step {step}")

# --- TRAIN LOOP ---
start_time = time.time()
dataloader = (jax.make_array_from_process_local_data(videos_sharding, elem) for elem in grain_iterator) # type: ignore
print(f"Starting training from step {step}...")
while step < args.num_steps:
Expand All @@ -258,7 +261,16 @@ def train_step(state, inputs):

inputs = dict(videos=videos, rng=_rng, dropout_rng=_rng_dropout)
train_state, loss, recon, metrics = train_step(train_state, inputs)
print(f"Step {step}, loss: {loss}")
if step % args.time_measurement_interval == 0:
jax.block_until_ready(train_state)
elapsed_time = (time.time() - start_time)
avg_step_time = elapsed_time / args.time_measurement_interval
print(f"Step {step}, loss: {loss}, avg step time: {avg_step_time:.2f}s")
if args.log and jax.process_index() == 0:
wandb.log({"avg_step_time_s": avg_step_time, "step": step})
start_time = time.time()
else:
print(f"Step {step}, loss: {loss}")
step += 1

# --- Logging ---
Expand Down