diff --git a/train_dynamics.py b/train_dynamics.py
index a8e6a2a..b0d0425 100644
--- a/train_dynamics.py
+++ b/train_dynamics.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
 import os
+from typing import Optional
 
 import einops
 from flax.training.train_state import TrainState
@@ -40,6 +41,7 @@ class Args:
     wsd_decay_steps: int = 10000 # NOTE: wsd_decay_steps will only be used when using a wsd-schedule
     warmup_steps: int = 5000
     lr_schedule : str = "wsd" # supported options: wsd, cos
+    grad_clip_threshold: Optional[float] = None
     # Tokenizer
     tokenizer_dim: int = 512
     latent_patch_dim: int = 32
@@ -131,6 +133,10 @@ def train_step(state, inputs):
     """Update state and compute metrics"""
     grad_fn = jax.value_and_grad(dynamics_loss_fn, has_aux=True, allow_int=True)
     (loss, (recon, metrics)), grads = grad_fn(state.params, state, inputs)
+    # extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
+    raw_grad_norm = optax.global_norm(grads)
+    metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm
+
     state = state.apply_gradients(grads=grads)
     if args.log_gradients:
         metrics["gradients_std/"] = jax.tree.map(
@@ -232,7 +238,13 @@
         args.num_steps,
         args.warmup_steps,
         args.wsd_decay_steps)
-    tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
+    if args.grad_clip_threshold:
+        tx = optax.chain(
+            optax.clip_by_global_norm(args.grad_clip_threshold),
+            optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
+        )
+    else:
+        tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
     train_state = TrainState.create(apply_fn=genie.apply, params=init_params, tx=tx)
 
     device_mesh_arr = create_device_mesh((num_devices,))
diff --git a/train_lam.py b/train_lam.py
index 52e3ffc..1072a3e 100644
--- a/train_lam.py
+++ b/train_lam.py
@@ -1,8 +1,8 @@
 from dataclasses import dataclass, field
 import os
+from typing import Optional
 
 import einops
-from flax.training import orbax_utils
 from flax.training.train_state import TrainState
 from jax.sharding import Mesh, PartitionSpec, NamedSharding
 from jax.experimental.mesh_utils import create_device_mesh
@@ -44,6 +44,7 @@ class Args:
     warmup_steps: int = 5000
     lr_schedule : str = "wsd" # supported options: wsd, cos
     vq_reset_thresh: int = 50
+    grad_clip_threshold: Optional[float] = None
     # LAM
     model_dim: int = 512
     latent_dim: int = 32
@@ -111,6 +112,10 @@ def train_step(state, inputs, action_last_active):
     rng, inputs["rng"] = jax.random.split(inputs["rng"])
     grad_fn = jax.value_and_grad(lam_loss_fn, has_aux=True, allow_int=True)
     (loss, (recon, idx_counts, metrics)), grads = grad_fn(state.params, state, inputs)
+    # extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
+    raw_grad_norm = optax.global_norm(grads)
+    metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm
+
     state = state.apply_gradients(grads=grads)
 
     # --- Reset inactive latent actions ---
@@ -208,7 +213,13 @@ def train_step(state, inputs, action_last_active):
         args.num_steps,
         args.warmup_steps,
         args.wsd_decay_steps)
-    tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
+    if args.grad_clip_threshold:
+        tx = optax.chain(
+            optax.clip_by_global_norm(args.grad_clip_threshold),
+            optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
+        )
+    else:
+        tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
     train_state = TrainState.create(apply_fn=lam.apply, params=init_params, tx=tx)
 
     # FIXME: switch to create_hybrid_device_mesh for runs spanning multiple nodes
diff --git a/train_tokenizer.py b/train_tokenizer.py
index 7fe160f..ce24c54 100644
--- a/train_tokenizer.py
+++ b/train_tokenizer.py
@@ -1,8 +1,8 @@
 from dataclasses import dataclass, field
 import os
+from typing import Optional
 
 import einops
-from flax.training import orbax_utils
 from flax.training.train_state import TrainState
 from jax.sharding import Mesh, PartitionSpec, NamedSharding
 from jax.experimental.mesh_utils import create_device_mesh
@@ -43,6 +43,7 @@ class Args:
     wsd_decay_steps: int = 20000 # NOTE: wsd_decay_steps will only be used when using a wsd-schedule
     lr_schedule: str = "wsd" # supported options: wsd, cos
     warmup_steps: int = 10000
+    grad_clip_threshold: Optional[float] = None
     # Tokenizer
     model_dim: int = 512
     latent_dim: int = 32
@@ -116,6 +117,9 @@ def tokenizer_loss_fn(params, state, inputs):
 def train_step(state, inputs):
     grad_fn = jax.value_and_grad(tokenizer_loss_fn, has_aux=True, allow_int=True)
     (loss, (recon, metrics)), grads = grad_fn(state.params, state, inputs)
+    # extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
+    raw_grad_norm = optax.global_norm(grads)
+    metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm
     state = state.apply_gradients(grads=grads)
     if args.log_gradients:
         metrics["encoder_gradients_std/"] = jax.tree.map(
@@ -205,7 +209,13 @@ def train_step(state, inputs):
         args.num_steps,
         args.warmup_steps,
         args.wsd_decay_steps)
-    tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
+    if args.grad_clip_threshold:
+        tx = optax.chain(
+            optax.clip_by_global_norm(args.grad_clip_threshold),
+            optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
+        )
+    else:
+        tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
     train_state = TrainState.create(apply_fn=tokenizer.apply, params=init_params, tx=tx)
 
     # FIXME: switch to create_hybrid_device_mesh for runs spanning multiple nodes
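
Note on the clipping behavior introduced above: optax.clip_by_global_norm rescales the whole gradient tree whenever its global norm exceeds the threshold, and the adamw step inside optax.chain then sees the rescaled gradients, while metrics["grad_norm"] records the norm that was effectively applied. The sketch below is a minimal standalone illustration of that pattern, not code from these scripts; the toy parameters, the learning rate, and the threshold value are hypothetical.

# Minimal sketch (assumed values): clip-then-adamw via optax.chain,
# logging the raw and effectively-applied gradient norm as in train_step.
import jax
import jax.numpy as jnp
import optax

grad_clip_threshold = 1.0  # hypothetical stand-in for Args.grad_clip_threshold
params = {"w": jnp.ones((3,)) * 10.0, "b": jnp.ones(())}  # toy parameter tree

# Same optimizer construction as the patch: clipping runs before adamw.
tx = optax.chain(
    optax.clip_by_global_norm(grad_clip_threshold),
    optax.adamw(learning_rate=3e-4, b1=0.9, b2=0.9, weight_decay=1e-4),
)
opt_state = tx.init(params)

def loss_fn(p):
    return jnp.sum(p["w"] ** 2) + p["b"] ** 2

grads = jax.grad(loss_fn)(params)

# What gets logged: the raw global norm, capped at the threshold for readability.
raw_grad_norm = optax.global_norm(grads)
logged_grad_norm = jnp.minimum(raw_grad_norm, grad_clip_threshold)

# What gets applied: the chained transform rescales the gradients before adamw.
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
print(raw_grad_norm, logged_grad_norm)

When grad_clip_threshold is left at its default of None, the chain is skipped entirely, so the optimizer path matches the previous adamw-only behavior.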