Clear cache after EMA swapping
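
Swapping the EMA weights onto the GPU for sampling and back off again leaves the freed blocks sitting in PyTorch's caching allocator as reserved memory, which can crowd out the next training-step allocation. This commit releases those blocks with torch.cuda.empty_cache() after each swap. A minimal sketch of the pattern (standard PyTorch APIs; ema, flux, and device follow the names in the diff below):

    ema.to(device)                # bring the EMA shadow weights onto the GPU
    # ... generate sample images from the EMA weights ...
    ema.to("cpu")                 # move the shadow copy back off the GPU
    flux.to(device)               # restore the live training model
    with torch.cuda.device(device):
        torch.cuda.empty_cache()  # return unused cached blocks to the driver
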
deepdelirious committed Dec 31, 2024
1 parent 11e642b commit 080ae73
Showing 2 changed files with 12 additions and 3 deletions.
11 changes: 8 additions & 3 deletions flux_train.py
@@ -284,7 +284,8 @@ def train(args):
         args.pretrained_model_name_or_path, weight_dtype, "cpu", args.disable_mmap_load_safetensors
     )
 
-    ema = EMA(flux, beta = args.ema_beta, update_after_step=args.ema_update_after_step, update_every=args.ema_update_every, update_model_with_ema_every=args.ema_switch_every, allow_different_devices=True) if args.ema else None
+    if args.ema:
+        ema = EMA(flux, beta=args.ema_beta, update_after_step=args.ema_update_after_step, update_every=args.ema_update_every, update_model_with_ema_every=args.ema_switch_every, allow_different_devices=True)
 
     if args.gradient_checkpointing:
         flux.enable_gradient_checkpointing(cpu_offload=args.cpu_offload_checkpointing)
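
The EMA wrapper's keyword arguments match the ema-pytorch package; a minimal sketch of how such a wrapper is typically driven, assuming that package (model, loader, and optimizer are placeholders, not this trainer's variables):

    from ema_pytorch import EMA

    ema = EMA(
        model,                             # live training model
        beta=0.9999,                       # EMA decay rate (args.ema_beta)
        update_after_step=100,             # skip EMA updates during warmup
        update_every=10,                   # refresh the shadow weights every N updates
        update_model_with_ema_every=1000,  # periodically copy EMA weights back into the model (the "swapping" above)
        allow_different_devices=True,      # shadow copy may sit on CPU while the model trains on GPU
    )

    for batch in loader:
        loss = model(batch)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        ema.update()                       # advance the shadow weights
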
@@ -771,7 +772,7 @@ def grad_hook(parameter: torch.Tensor):
                     num_train_epochs,
                     global_step,
                     accelerator.unwrap_model(flux),
-                    ema
+                    ema if not args.no_ema_sampling else None
                 )
                 optimizer_train_fn()

@@ -815,7 +816,7 @@ def grad_hook(parameter: torch.Tensor):
                 num_train_epochs,
                 global_step,
                 accelerator.unwrap_model(flux),
-                ema
+                ema if not args.no_ema_sampling else None
             )
 
             flux_train_utils.sample_images(
@@ -929,6 +930,10 @@ def setup_parser() -> argparse.ArgumentParser:
         type=int,
         default=None
     )
+    parser.add_argument(
+        "--no_ema_sampling",
+        action="store_true"
+    )
     parser.add_argument(
         "--no_shuffle",
         action="store_true",
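
With the new flag, EMA weights are still maintained and periodically swapped into the model, but in-training sample images are generated from the live weights. A hypothetical invocation (the EMA flag names are inferred from the args.* references above; all other training flags omitted):

    python flux_train.py \
        --ema --ema_beta 0.9999 \
        --ema_update_after_step 100 --ema_update_every 10 --ema_switch_every 1000 \
        --no_ema_sampling
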
4 changes: 4 additions & 0 deletions library/flux_train_utils.py
@@ -127,6 +127,8 @@ def sample_images(
         )
         ema.to("cpu")
         flux.to(device)
+        with torch.cuda.device(device):
+            torch.cuda.empty_cache()
     else:
         # Create a list with N elements, where each element is a list of prompt_dicts and N is the number of available processes (devices).
         # prompt_dicts are assigned to lists in process order, to try to match image creation time to the enum order. This probably only works when steps and sampler are identical.
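
The empty_cache() call matters because moving the EMA copy off the GPU frees its tensors without returning the underlying blocks to the driver; the caching allocator keeps them reserved for reuse. A small illustration with standard PyTorch introspection (not code from this repository):

    import torch

    device = torch.device("cuda:0")
    print(torch.cuda.memory_allocated(device))  # bytes held by live tensors
    print(torch.cuda.memory_reserved(device))   # bytes held by the caching allocator (>= allocated)
    with torch.cuda.device(device):
        torch.cuda.empty_cache()                # release unused cached blocks
    print(torch.cuda.memory_reserved(device))   # reserved now shrinks toward allocated
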
@@ -173,6 +175,8 @@ def sample_images(
         )
         ema.to("cpu")
         flux.to(device)
+        with torch.cuda.device(device):
+            torch.cuda.empty_cache()
 
     torch.set_rng_state(rng_state)
     if cuda_rng_state is not None:
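
For context, sample_images brackets generation with an RNG save/restore so that sampling does not perturb training randomness. A minimal sketch of that standard pattern (standard PyTorch APIs; variable names follow the context lines above):

    rng_state = torch.get_rng_state()        # save the CPU RNG state
    cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    # ... run sampling, which consumes random numbers ...

    torch.set_rng_state(rng_state)           # restore the CPU RNG state
    if cuda_rng_state is not None:
        torch.cuda.set_rng_state(cuda_rng_state)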