Skip to content

Commit

Permalink
Make EMA gamma and power configurable
Browse files Browse the repository at this point in the history
  • Loading branch information
deepdelirious committed Jan 4, 2025
1 parent 5bd4f3a commit d507597
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion flux_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def train(args):
)

if args.ema:
- ema = EMA(flux, beta = args.ema_beta, update_after_step=args.ema_update_after_step, update_every=args.ema_update_every, update_model_with_ema_every=args.ema_switch_every, allow_different_devices=True) if args.ema else None
+ ema = EMA(flux, beta = args.ema_beta, update_after_step=args.ema_update_after_step, update_every=args.ema_update_every, update_model_with_ema_every=args.ema_switch_every, allow_different_devices=True, inv_gamma=args.ema_gamma, power=args.ema_power) if args.ema else None

if args.gradient_checkpointing:
flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
Expand Down Expand Up @@ -941,6 +941,16 @@ def setup_parser() -> argparse.ArgumentParser:
"--no_ema_sampling",
action="store_true"
)
+ parser.add_argument(
+ "--ema_gamma",
+ type=float,
+ default=1.0
+ )
+ parser.add_argument(
+ "--ema_power",
+ type=float,
+ default=2/3
+ )
parser.add_argument(
"--no_shuffle",
action="store_true",
Expand Down

0 comments on commit d507597

Please sign in to comment.