Skip to content

Commit

Permalink
Gather wandb samples across processes
Browse files Browse the repository at this point in the history
  • Loading branch information
deepdelirious committed Jan 3, 2025
1 parent 65c8b62 commit 042e743
Showing 1 changed file with 23 additions and 17 deletions.
40 changes: 23 additions & 17 deletions library/flux_train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,12 @@ def sample_images(
except Exception:
pass

sample_results = []
if distributed_state.num_processes <= 1:
# If only one device is available, just use the original prompt list. We don't need to care about the distribution of prompts.
with torch.no_grad(), accelerator.autocast():
for prompt_dict in prompts:
sample_image_inference(
sample_results.append(sample_image_inference(
accelerator,
args,
flux,
Expand All @@ -104,14 +105,14 @@ def sample_images(
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
))
if ema:
model_params = [param.detach().cpu().clone() for _, param in ema.get_params_iter(flux)]
for (_, model_param), (_, ema_param) in zip(ema.get_params_iter(flux), ema.get_params_iter(ema.ema_model)):
ema_param = ema_param.to(model_param.device)
model_param.copy_(ema_param)
for prompt_dict in prompts:
sample_image_inference(
sample_results.append(sample_image_inference(
accelerator,
args,
flux,
Expand All @@ -125,7 +126,7 @@ def sample_images(
prompt_replacement,
controlnet,
file_suffix = "_ema"
)
))
for (_, model_param), original_model_param in zip(ema.get_params_iter(flux), model_params):
original_model_param = original_model_param.to(model_param.device)
model_param.data.copy_(original_model_param.data)
Expand All @@ -140,7 +141,7 @@ def sample_images(
with torch.no_grad():
with distributed_state.split_between_processes(per_process_prompts) as prompt_dict_lists:
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
sample_results.append(sample_image_inference(
accelerator,
args,
flux,
Expand All @@ -153,14 +154,14 @@ def sample_images(
sample_prompts_te_outputs,
prompt_replacement,
controlnet
)
))
if ema:
model_params = [param.detach().cpu().clone() for _, param in ema.get_params_iter(flux)]
for (_, model_param), (_, ema_param) in zip(ema.get_params_iter(flux), ema.get_params_iter(ema.ema_model)):
ema_param = ema_param.to(model_param.device)
model_param.copy_(ema_param)
for prompt_dict in prompt_dict_lists[0]:
sample_image_inference(
sample_results.append(sample_image_inference(
accelerator,
args,
flux,
Expand All @@ -174,11 +175,24 @@ def sample_images(
prompt_replacement,
controlnet,
file_suffix = "_ema"
)
))
for (_, model_param), original_model_param in zip(ema.get_params_iter(flux), model_params):
original_model_param = original_model_param.to(model_param.device)
model_param.data.copy_(original_model_param.data)
model_params = None
accelerator.wait_for_everyone()
accelerator.utils.gather_object(sample_results)

for label, prompt, image in sample_results:
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")

import wandb

# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({label: wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption



torch.set_rng_state(rng_state)
if cuda_rng_state is not None:
Expand Down Expand Up @@ -315,15 +329,7 @@ def sample_image_inference(
img_filename = f"{'' if args.output_name is None else args.output_name + '_'}{num_suffix}_{i:02d}_{ts_str}{seed_suffix}{file_suffix}.png"
image.save(os.path.join(save_dir, img_filename))

# send images to wandb if enabled
if "wandb" in [tracker.name for tracker in accelerator.trackers]:
wandb_tracker = accelerator.get_tracker("wandb")

import wandb

# not to commit images to avoid inconsistency between training and logging steps
wandb_tracker.log({f"sample_{i}{file_suffix}": wandb.Image(image, caption=prompt)}, commit=False) # positive prompt as a caption

return f"sample_{i}{file_suffix}", prompt, image

def time_shift(mu: float, sigma: float, t: torch.Tensor):
return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
Expand Down

0 comments on commit 042e743

Please sign in to comment.