train_dynamics.py (14 changes: 13 additions & 1 deletion)
@@ -1,5 +1,6 @@
from dataclasses import dataclass, field
import os
from typing import Optional

import einops
from flax.training.train_state import TrainState
@@ -40,6 +41,7 @@ class Args:
wsd_decay_steps: int = 10000 # NOTE: wsd_decay_steps will only be used when using a wsd-schedule
warmup_steps: int = 5000
lr_schedule : str = "wsd" # supported options: wsd, cos
grad_clip_threshold: Optional[float] = None
# Tokenizer
tokenizer_dim: int = 512
latent_patch_dim: int = 32
@@ -131,6 +133,10 @@ def train_step(state, inputs):
"""Update state and compute metrics"""
grad_fn = jax.value_and_grad(dynamics_loss_fn, has_aux=True, allow_int=True)
(loss, (recon, metrics)), grads = grad_fn(state.params, state, inputs)
# extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
raw_grad_norm = optax.global_norm(grads)
metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm

state = state.apply_gradients(grads=grads)
if args.log_gradients:
metrics["gradients_std/"] = jax.tree.map(
@@ -232,7 +238,13 @@ def train_step(state, inputs):
args.num_steps,
args.warmup_steps,
args.wsd_decay_steps)
tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
if args.grad_clip_threshold:
tx = optax.chain(
optax.clip_by_global_norm(args.grad_clip_threshold),
optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
)
else:
tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
train_state = TrainState.create(apply_fn=genie.apply, params=init_params, tx=tx)

device_mesh_arr = create_device_mesh((num_devices,))
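The same pattern is applied in all three training scripts: when `grad_clip_threshold` is set, AdamW is wrapped in an `optax.chain` so that `clip_by_global_norm` runs first, and the logged `grad_norm` metric is the post-clip value. A minimal, self-contained sketch of that pattern (illustrative names and values, not the PR's code):

```python
# Minimal sketch of the clipping pattern used in this PR (names are illustrative).
from typing import Optional

import jax
import jax.numpy as jnp
import optax


def build_optimizer(lr: float, grad_clip_threshold: Optional[float]) -> optax.GradientTransformation:
    adamw = optax.adamw(learning_rate=lr, b1=0.9, b2=0.9, weight_decay=1e-4)
    if grad_clip_threshold:
        # Clipping runs before AdamW, so the update is computed from clipped gradients.
        return optax.chain(optax.clip_by_global_norm(grad_clip_threshold), adamw)
    return adamw


def loss_fn(params, x):
    return jnp.sum((params["w"] * x) ** 2)


params = {"w": 10.0 * jnp.ones((4,))}
threshold = 1.0
tx = build_optimizer(3e-4, grad_clip_threshold=threshold)
opt_state = tx.init(params)

loss, grads = jax.value_and_grad(loss_fn)(params, jnp.arange(4.0))
raw_grad_norm = optax.global_norm(grads)
# Log the norm the optimizer actually sees after clipping, as in train_step above.
grad_norm_metric = jnp.minimum(raw_grad_norm, threshold) if threshold else raw_grad_norm

updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)
```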
train_lam.py (15 changes: 13 additions & 2 deletions)
@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
import os
from typing import Optional

import einops
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental.mesh_utils import create_device_mesh
@@ -44,6 +44,7 @@ class Args:
warmup_steps: int = 5000
lr_schedule : str = "wsd" # supported options: wsd, cos
vq_reset_thresh: int = 50
grad_clip_threshold: Optional[float] = None
# LAM
model_dim: int = 512
latent_dim: int = 32
@@ -111,6 +112,10 @@ def train_step(state, inputs, action_last_active):
rng, inputs["rng"] = jax.random.split(inputs["rng"])
grad_fn = jax.value_and_grad(lam_loss_fn, has_aux=True, allow_int=True)
(loss, (recon, idx_counts, metrics)), grads = grad_fn(state.params, state, inputs)
# extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
raw_grad_norm = optax.global_norm(grads)
metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm

state = state.apply_gradients(grads=grads)

# --- Reset inactive latent actions ---
@@ -208,7 +213,13 @@ def train_step(state, inputs, action_last_active):
args.num_steps,
args.warmup_steps,
args.wsd_decay_steps)
tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
if args.grad_clip_threshold:
tx = optax.chain(
optax.clip_by_global_norm(args.grad_clip_threshold),
optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
)
else:
tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
train_state = TrainState.create(apply_fn=lam.apply, params=init_params, tx=tx)

# FIXME: switch to create_hybrid_device_mesh for runs spanning multiple nodes
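Why `jnp.minimum(raw_grad_norm, args.grad_clip_threshold)` is the value to log: `optax.clip_by_global_norm` rescales the entire gradient pytree so that its global norm never exceeds the threshold, i.e. the post-clip norm is `min(raw_norm, threshold)`. A small sanity check (illustrative only, not part of the diff):

```python
import jax.numpy as jnp
import optax

grads = {"a": jnp.array([3.0, 4.0]), "b": jnp.array([12.0])}  # global norm = sqrt(9 + 16 + 144) = 13
threshold = 5.0

clip = optax.clip_by_global_norm(threshold)
clipped_grads, _ = clip.update(grads, clip.init(grads))

print(optax.global_norm(grads))          # 13.0
print(optax.global_norm(clipped_grads))  # 5.0 == min(13.0, threshold)
```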
train_tokenizer.py (14 changes: 12 additions & 2 deletions)
@@ -1,8 +1,8 @@
from dataclasses import dataclass, field
import os
from typing import Optional

import einops
from flax.training import orbax_utils
from flax.training.train_state import TrainState
from jax.sharding import Mesh, PartitionSpec, NamedSharding
from jax.experimental.mesh_utils import create_device_mesh
@@ -43,6 +43,7 @@ class Args:
wsd_decay_steps: int = 20000 # NOTE: wsd_decay_steps will only be used when using a wsd-schedule
lr_schedule: str = "wsd" # supported options: wsd, cos
warmup_steps: int = 10000
grad_clip_threshold: Optional[float] = None
# Tokenizer
model_dim: int = 512
latent_dim: int = 32
@@ -116,6 +117,9 @@ def tokenizer_loss_fn(params, state, inputs):
def train_step(state, inputs):
grad_fn = jax.value_and_grad(tokenizer_loss_fn, has_aux=True, allow_int=True)
(loss, (recon, metrics)), grads = grad_fn(state.params, state, inputs)
# extract and manually clip grad norm for logging (actual clipping is done in the optax.chain)
raw_grad_norm = optax.global_norm(grads)
metrics["grad_norm"] = jnp.minimum(raw_grad_norm, args.grad_clip_threshold) if args.grad_clip_threshold else raw_grad_norm
state = state.apply_gradients(grads=grads)
if args.log_gradients:
metrics["encoder_gradients_std/"] = jax.tree.map(
@@ -205,7 +209,13 @@ def train_step(state, inputs):
args.num_steps,
args.warmup_steps,
args.wsd_decay_steps)
tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
if args.grad_clip_threshold:
tx = optax.chain(
optax.clip_by_global_norm(args.grad_clip_threshold),
optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
)
else:
tx = optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=args.dtype)
train_state = TrainState.create(apply_fn=tokenizer.apply, params=init_params, tx=tx)

# FIXME: switch to create_hybrid_device_mesh for runs spanning multiple nodes
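A possible follow-up (a suggestion, not what this diff does): `optax.identity()` is a no-op transformation, so the duplicated `optax.adamw(...)` line in each script's if/else could be avoided by building the chain unconditionally. Sketch with a hypothetical `make_tx` helper:

```python
from typing import Optional

import optax


def make_tx(lr_schedule, grad_clip_threshold: Optional[float], mu_dtype=None) -> optax.GradientTransformation:
    # optax.identity() leaves gradients untouched, so the chain is safe to build either way.
    maybe_clip = (
        optax.clip_by_global_norm(grad_clip_threshold)
        if grad_clip_threshold
        else optax.identity()
    )
    return optax.chain(
        maybe_clip,
        optax.adamw(learning_rate=lr_schedule, b1=0.9, b2=0.9, weight_decay=1e-4, mu_dtype=mu_dtype),
    )
```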