[Bugfix] Fix for only one DIT worker and a separate VAE worker. #443

Merged
15 commits merged on Jan 24, 2025
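
A minimal sketch of the launch configuration this fix targets, assembled from the ray_run.sh change below (only the variables relevant to the bug are shown):

```bash
# Two GPUs in total: one DiT worker plus one dedicated VAE worker, with every
# DiT-side parallel degree set to 1.
N_GPUS=2
PARALLEL_ARGS="--pipefusion_parallel_degree 1 --ulysses_degree 1 --ring_degree 1"
VAE_PARALLEL_SIZE=1
DIT_PARALLEL_SIZE=1
```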
examples/ray/ray_hunyuandit_example.py (0 additions & 1 deletion)
@@ -43,7 +43,6 @@ def main():
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
     output = pipe(
-        PipelineClass=xFuserHunyuanDiTPipeline,
         height=input_config.height,
         width=input_config.width,
         prompt=input_config.prompt,
examples/ray/ray_pixartsigma_example.py (0 additions & 1 deletion)
@@ -41,7 +41,6 @@ def main():
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
     output = pipe(
-        PipelineClass=xFuserPixArtSigmaPipeline,
         height=input_config.height,
         width=input_config.width,
         prompt=input_config.prompt,
examples/ray/ray_run.sh (3 additions & 3 deletions)
@@ -32,10 +32,10 @@ mkdir -p ./results
 TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"


-N_GPUS=3 # world size
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
+N_GPUS=2 # world size
+PARALLEL_ARGS="--pipefusion_parallel_degree 1 --ulysses_degree 1 --ring_degree 1"
 VAE_PARALLEL_SIZE=1
-DIT_PARALLEL_SIZE=2
+DIT_PARALLEL_SIZE=1
 # CFG_ARGS="--use_cfg_parallel"

 # By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
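
The comment above notes that num_pipeline_patch defaults to the pipefusion degree; a hedged illustration of overriding it (the PIPEFUSION_ARGS variable and the --num_pipeline_patch flag are assumed from xDiT's other example scripts and are not part of this diff):

```bash
# Hypothetical tuning, not part of this PR: split the latents into 4 patches
# per pipeline stage instead of the default (equal to the pipefusion degree).
PIPEFUSION_ARGS="--num_pipeline_patch 4"
```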
xfuser/model_executor/pipelines/base_pipeline.py (2 additions & 0 deletions)
@@ -144,6 +144,7 @@ def __init__(
         engine_config: EngineConfig,
     ):
         self.module: DiffusionPipeline
+        self.engine_config = engine_config
         self._init_runtime_state(pipeline=pipeline, engine_config=engine_config)
         self._init_fast_attn_state(pipeline=pipeline, engine_config=engine_config)

@@ -256,6 +257,7 @@ def use_naive_forward(self):
             and get_sequence_parallel_world_size() == 1
             and get_tensor_model_parallel_world_size() == 1
             and get_fast_attn_enable() == False
+            and get_runtime_state().parallel_config.vae_parallel_size == 0
         )

     @staticmethod
xfuser/model_executor/pipelines/pipeline_flux.py (5 additions & 2 deletions)
@@ -772,7 +772,7 @@ def _backbone_forward(
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = t.expand(latents.shape[0]).to(latents.dtype)

-        noise_pred, encoder_hidden_states = self.transformer(
+        ret = self.transformer(
             hidden_states=latents,
             timestep=timestep / 1000,
             guidance=guidance,
@@ -783,7 +783,10 @@
             joint_attention_kwargs=self.joint_attention_kwargs,
             return_dict=False,
         )[0]

+        if self.engine_config.parallel_config.dit_parallel_size > 1:
+            noise_pred, encoder_hidden_states = ret
+        else:
+            noise_pred, encoder_hidden_states = ret, None
         return noise_pred, encoder_hidden_states

     def _scheduler_step(
xfuser/model_executor/pipelines/pipeline_hunyuandit.py (27 additions & 13 deletions)
@@ -903,19 +903,33 @@ def _backbone_forward(
         )

         # predict the noise residual
-        noise_pred = self.transformer(
-            latents,
-            t_expand,
-            encoder_hidden_states=prompt_embeds,
-            text_embedding_mask=prompt_attention_mask,
-            encoder_hidden_states_t5=prompt_embeds_2,
-            text_embedding_mask_t5=prompt_attention_mask_2,
-            image_meta_size=add_time_ids,
-            style=style,
-            image_rotary_emb=image_rotary_emb,
-            skips=skips,
-            return_dict=False,
-        )[0]
+        if skips is not None:
+            noise_pred = self.transformer(
+                latents,
+                t_expand,
+                encoder_hidden_states=prompt_embeds,
+                text_embedding_mask=prompt_attention_mask,
+                encoder_hidden_states_t5=prompt_embeds_2,
+                text_embedding_mask_t5=prompt_attention_mask_2,
+                image_meta_size=add_time_ids,
+                style=style,
+                image_rotary_emb=image_rotary_emb,
+                skips=skips,
+                return_dict=False,
+            )[0]
+        else:
+            noise_pred = self.transformer(
+                latents,
+                t_expand,
+                encoder_hidden_states=prompt_embeds,
+                text_embedding_mask=prompt_attention_mask,
+                encoder_hidden_states_t5=prompt_embeds_2,
+                text_embedding_mask_t5=prompt_attention_mask_2,
+                image_meta_size=add_time_ids,
+                style=style,
+                image_rotary_emb=image_rotary_emb,
+                return_dict=False,
+            )[0]

         if is_pipeline_last_stage():
             noise_pred, _ = noise_pred.chunk(2, dim=1)
(file not named in this excerpt)
@@ -773,14 +773,18 @@ def _backbone_forward(
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = t.expand(latents.shape[0])

-        noise_pred, encoder_hidden_states = self.transformer(
+        ret = self.transformer(
             hidden_states=latents,
             timestep=timestep,
             encoder_hidden_states=encoder_hidden_states,
             pooled_projections=pooled_prompt_embeds,
             joint_attention_kwargs=self.joint_attention_kwargs,
             return_dict=False,
         )[0]
+        if self.engine_config.parallel_config.dit_parallel_size > 1:
+            noise_pred, encoder_hidden_states = ret
+        else:
+            noise_pred, encoder_hidden_states = ret, None

         # classifier free guidance
         if is_pipeline_last_stage():