diff --git a/jasmine/train_dynamics.py b/jasmine/train_dynamics.py index 06cd966..62d84cd 100644 --- a/jasmine/train_dynamics.py +++ b/jasmine/train_dynamics.py @@ -82,7 +82,7 @@ class Args: mask_limit: float = 0.5 z_loss_weight: float = 0.0 param_dtype = jnp.float32 - dtype = jnp.bfloat16 + dtype = jnp.float32 use_flash_attention: bool = True use_gt_actions: bool = False # Logging