Skip to content

Commit

Permalink
fix: flux sync pipeline bugs
Browse files — browse the repository at this point in the history
  • Loading branch information…
Eigensystem committed Oct 13, 2024
1 parent a354858 commit 8a08550
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 7 deletions.
24 changes: 19 additions & 5 deletions benchmark/single_node_latency_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,19 +44,33 @@ def main():
action="store_true",
help="Do not use resolution binning",
)
parser.add_argument(
"--no_use_cfg_parallel",
action="store_true",
help="Do not use split batch parallelism",
)
parser.add_argument(
"--num_inference_steps",
type=int,
default=20,
help="Number of inference steps",
)
args = parser.parse_args()
MODEL_ID = args.model_id
SIZES = args.sizes
SCRIPT = args.script
N_GPUS = args.n_gpus
NOT_USE_CFG = args.no_use_cfg_parallel
STEPS = args.num_inference_steps
RESOLUTION_BINNING = (
"--no_use_resolution_binning" if args.no_use_resolution_binning else ""
)

visited = set()
dp_degree = 1
cfg_degree_list = [1] if NOT_USE_CFG else [1, 2]
for size in SIZES:
for cfg_degree in [1, 2]:
for cfg_degree in cfg_degree_list:
model_parallel_degree = N_GPUS // cfg_degree
for i in range(int(math.log2(model_parallel_degree)) + 1):
pp_degree = int(math.pow(2, i))
Expand Down Expand Up @@ -104,10 +118,10 @@ def main():
flush=True,
)
cmd = (
f"torchrun --nproc_per_node={N_GPUS} {SCRIPT} --prompt 'A small cat' --output_type 'latent' --model {MODEL_ID} "
f"torchrun --master_port 29501 --nproc_per_node={N_GPUS} {SCRIPT} --prompt 'A small cat' --output_type 'latent' --model {MODEL_ID} "
f"--height {size} --width {size} --warmup_steps {warmup_step} "
f"{RESOLUTION_BINNING} --use_cfg_parallel --ulysses_degree {ulysses_degree} --ring_degree {ring_degree} "
f"--pipefusion_parallel_degree {pp_degree} --num_pipeline_patch {num_pipeline_patches}"
f"--pipefusion_parallel_degree {pp_degree} --num_pipeline_patch {num_pipeline_patches} --num_inference_steps {STEPS}"
)
run_command(cmd)
else:
Expand All @@ -116,10 +130,10 @@ def main():
flush=True,
)
cmd = (
f"torchrun --nproc_per_node={N_GPUS} {SCRIPT} --prompt 'A small cat' --output_type 'latent' --model {MODEL_ID} "
f"torchrun --master_port 29501 --nproc_per_node={N_GPUS} {SCRIPT} --prompt 'A small cat' --output_type 'latent' --model {MODEL_ID} "
f"--height {size} --width {size} --warmup_steps {warmup_step} "
f"{RESOLUTION_BINNING} --ulysses_degree {ulysses_degree} --ring_degree {ring_degree} "
f"--pipefusion_parallel_degree {pp_degree} --num_pipeline_patch {num_pipeline_patches} "
f"--pipefusion_parallel_degree {pp_degree} --num_pipeline_patch {num_pipeline_patches} --num_inference_steps {STEPS}"
)

run_command(cmd)
Expand Down
1 change: 1 addition & 0 deletions xfuser/model_executor/pipelines/base_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,7 @@ def prepare_run(
use_resolution_binning=input_config.use_resolution_binning,
num_inference_steps=steps,
generator=torch.Generator(device="cuda").manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps

Expand Down
5 changes: 3 additions & 2 deletions xfuser/model_executor/pipelines/pipeline_flux.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def prepare_run(
num_inference_steps=steps,
max_sequence_length=input_config.max_sequence_length,
generator=torch.Generator(device="cuda").manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps

Expand Down Expand Up @@ -501,8 +502,8 @@ def _sync_pipeline(
sp_latents_list[sp_patch_idx][
:,
get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_start_idx_local[pp_patch_idx + 1],
.pp_patches_token_start_idx_local[pp_patch_idx] : get_runtime_state()
.pp_patches_token_start_idx_local[pp_patch_idx + 1],
:,
]
for sp_patch_idx in range(sp_degree)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def prepare_run(
prompt=prompt,
num_inference_steps=steps,
generator=torch.Generator(device="cuda").manual_seed(42),
output_type=input_config.output_type,
)
get_runtime_state().runtime_config.warmup_steps = warmup_steps

Expand Down

0 comments on commit 8a08550

Please sign in to comment.