From 0053b7b79d8bacc1c1155f34f750c68d3f661426 Mon Sep 17 00:00:00 2001
From: Mihir Mahajan
Date: Wed, 16 Jul 2025 12:03:47 +0200
Subject: [PATCH] added grad norm logging and grad clipping

---
 train_dynamics.py  | 14 +++++++++++++-
 train_lam.py       | 15 +++++++++++++--
 train_tokenizer.py | 14 ++++++++++++--
 3 files changed, 38 insertions(+), 5 deletions(-)

diff --git a/train_dynamics.py b/train_dynamics.py
index 3274fe4..0d5a62c 100644
--- a/train_dynamics.py
+++ b/train_dynamics.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
 import os
+from typing import Optional
 
 import einops
 from flax.training.train_state import TrainState
@@ -36,6 +37,7 @@ class Args:
     min_lr: float = 0.0
     max_lr: float = 3e-5
     warmup_steps: int = 5000
+    grad_clip_threshold: Optional[float] = None
     # Tokenizer
     tokenizer_dim: int = 512
     latent_patch_dim: int = 32
@@ -108,6 +110,10 @@ def train_step(state, inputs):
     """Update state and compute metrics"""
     grad_fn = jax.value_and_grad(dynamics_loss_fn, has_aux=True, allow_int=True)
     (loss, (recon, metrics)), grads = grad_fn(state.params, state, inputs)
+    # extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
+    raw_grad_norm = optax.global_norm(grads)
+    metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm
+
     state = state.apply_gradients(grads=grads)
     if args.log_gradients:
         metrics["gradients_std/"] = jax.tree.map(
@@ -201,7 +207,13 @@ def train_step(state, inputs):
     lr_schedule = optax.warmup_cosine_decay_schedule(
         args.min_lr, args.max_lr, args.warmup_steps, args.num_steps
     )
-    tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
+    if args.grad_clip_threshold:
+        tx = optax.chain(
+            optax.clip_by_global_norm(args.grad_clip_threshold),
+            optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
+        )
+    else:
+        tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
     train_state = TrainState.create(apply_fn=genie.apply, params=init_params, tx=tx)
 
     device_mesh_arr = create_device_mesh((num_devices,))
diff --git a/train_lam.py b/train_lam.py
index b0d0026..8279737 100644
--- a/train_lam.py
+++ b/train_lam.py
@@ -1,8 +1,8 @@
 from dataclasses import dataclass, field
 import os
+from typing import Optional
 
 import einops
-from flax.training import orbax_utils
 from flax.training.train_state import TrainState
 from jax.sharding import Mesh, PartitionSpec, NamedSharding
 from jax.experimental.mesh_utils import create_device_mesh
@@ -40,6 +41,7 @@ class Args:
     max_lr: float = 3e-5
     warmup_steps: int = 5000
     vq_reset_thresh: int = 50
+    grad_clip_threshold: Optional[float] = None
     # LAM
     model_dim: int = 512
     latent_dim: int = 32
@@ -105,6 +106,10 @@ def train_step(state, inputs, action_last_active):
     rng, inputs["rng"] = jax.random.split(inputs["rng"])
     grad_fn = jax.value_and_grad(lam_loss_fn, has_aux=True, allow_int=True)
     (loss, (recon, idx_counts, metrics)), grads = grad_fn(state.params, state, inputs)
+    # extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
+    raw_grad_norm = optax.global_norm(grads)
+    metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm
+
     state = state.apply_gradients(grads=grads)
 
     # --- Reset inactive latent actions ---
@@ -194,7 +199,13 @@ def train_step(state, inputs, action_last_active):
     lr_schedule = optax.warmup_cosine_decay_schedule(
         args.min_lr, args.max_lr, args.warmup_steps, args.num_steps
     )
-    tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
+    if args.grad_clip_threshold:
+        tx = optax.chain(
+            optax.clip_by_global_norm(args.grad_clip_threshold),
+            optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
+        )
+    else:
+        tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
     train_state = TrainState.create(apply_fn=lam.apply, params=init_params, tx=tx)
 
     # FIXME: switch to create_hybrid_device_mesh for runs spanning multiple nodes
diff --git a/train_tokenizer.py b/train_tokenizer.py
index bbd172b..532046e 100644
--- a/train_tokenizer.py
+++ b/train_tokenizer.py
@@ -1,8 +1,8 @@
 from dataclasses import dataclass, field
 import os
+from typing import Optional
 
 import einops
-from flax.training import orbax_utils
 from flax.training.train_state import TrainState
 from jax.sharding import Mesh, PartitionSpec, NamedSharding
 from jax.experimental.mesh_utils import create_device_mesh
@@ -39,6 +40,7 @@ class Args:
     min_lr: float = 0.0
     max_lr: float = 3e-4
     warmup_steps: int = 10000
+    grad_clip_threshold: Optional[float] = None
     # Tokenizer
     model_dim: int = 512
     latent_dim: int = 32
@@ -107,6 +108,9 @@ def tokenizer_loss_fn(params, state, inputs):
 def train_step(state, inputs):
     grad_fn = jax.value_and_grad(tokenizer_loss_fn, has_aux=True, allow_int=True)
     (loss, (recon, metrics)), grads = grad_fn(state.params, state, inputs)
+    # extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
+    raw_grad_norm = optax.global_norm(grads)
+    metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm
     state = state.apply_gradients(grads=grads)
     if args.log_gradients:
         metrics["encoder_gradients_std/"] = jax.tree.map(
@@ -188,7 +192,13 @@ def train_step(state, inputs):
     lr_schedule = optax.warmup_cosine_decay_schedule(
         args.min_lr, args.max_lr, args.warmup_steps, args.num_steps
     )
-    tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
+    if args.grad_clip_threshold:
+        tx = optax.chain(
+            optax.clip_by_global_norm(args.grad_clip_threshold),
+            optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
+        )
+    else:
+        tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4)
     train_state = TrainState.create(apply_fn=tokenizer.apply, params=init_params, tx=tx)
 
     # FIXME: switch to create_hybrid_device_mesh for runs spanning multiple nodes
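
Note: all three training scripts apply the same pattern, namely clip gradients by global norm inside the optax chain when grad_clip_threshold is set, and log the (capped) global gradient norm every step. The sketch below is a minimal, self-contained illustration of that pattern only; the toy linear model, placeholder loss, batch shapes, and the example threshold of 1.0 are assumptions for illustration, not taken from the patched scripts.

    # Minimal sketch of the clip-and-log pattern (toy model, assumed shapes/threshold).
    from typing import Optional

    import jax
    import jax.numpy as jnp
    import optax
    from flax.training.train_state import TrainState

    grad_clip_threshold: Optional[float] = 1.0  # None disables clipping, as in the patch

    def loss_fn(params, batch):
        # toy squared-error loss standing in for the real tokenizer/LAM/dynamics losses
        pred = batch["x"] @ params["w"]
        return jnp.mean((pred - batch["y"]) ** 2)

    # AdamW hyperparameters mirror the patch; clipping is applied once, inside the optimizer chain
    adamw = optax.adamw(learning_rate=3e-5, b1=0.9, b2=0.9, weight_decay=1e-4)
    if grad_clip_threshold:
        tx = optax.chain(optax.clip_by_global_norm(grad_clip_threshold), adamw)
    else:
        tx = adamw

    state = TrainState.create(apply_fn=None, params={"w": jnp.ones((4, 2))}, tx=tx)

    @jax.jit
    def train_step(state, batch):
        loss, grads = jax.value_and_grad(loss_fn)(state.params, batch)
        # log the norm the optimizer effectively uses: the raw global norm,
        # capped at the threshold whenever clipping is enabled
        raw_grad_norm = optax.global_norm(grads)
        grad_norm = (
            jnp.minimum(raw_grad_norm, grad_clip_threshold)
            if grad_clip_threshold
            else raw_grad_norm
        )
        metrics = {"loss": loss, "grad_norm": grad_norm}
        return state.apply_gradients(grads=grads), metrics

    batch = {"x": jnp.ones((8, 4)), "y": jnp.zeros((8, 2))}
    state, metrics = train_step(state, batch)  # metrics["grad_norm"] matches what the scripts log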