[Bugfix] Fix for only one DIT worker and a separate VAE worker. (#443)
lihuahua123 authored Jan 24, 2025
1 parent dc75a3a commit a4757c9
Showing 7 changed files with 42 additions and 21 deletions.
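Read together, the seven diffs target one configuration: a DiT group with a single worker (every DiT parallel degree equal to 1) combined with a separate VAE worker. The Ray examples drop a now-redundant call-time argument, ray_run.sh demonstrates exactly this one-DiT-plus-one-VAE layout, the base pipeline records the engine config and stops taking the single-GPU "naive forward" shortcut when a VAE group exists, and the Flux, HunyuanDiT, and SD3-style pipelines stop assuming the parallel transformer wrapper when unpacking its outputs or passing pipeline-parallel-only arguments.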
1 change: 0 additions & 1 deletion examples/ray/ray_hunyuandit_example.py
@@ -43,7 +43,6 @@ def main():
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
     output = pipe(
-        PipelineClass=xFuserHunyuanDiTPipeline,
         height=input_config.height,
         width=input_config.width,
         prompt=input_config.prompt,
1 change: 0 additions & 1 deletion examples/ray/ray_pixartsigma_example.py
@@ -41,7 +41,6 @@ def main():
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
     output = pipe(
-        PipelineClass=xFuserPixArtSigmaPipeline,
         height=input_config.height,
         width=input_config.width,
         prompt=input_config.prompt,
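Both example changes drop the same keyword: the pipeline class is already bound when the Ray pipeline is constructed, so it no longer needs to be repeated in the generation call. A hedged sketch of the assumed construction/call pattern (the import locations and surrounding names are assumptions based on the example scripts, not a verbatim copy):

    import torch

    def build_and_run(engine_config, input_config):
        # Assumed import locations; the real example scripts may differ.
        from xfuser import xFuserHunyuanDiTPipeline
        from xfuser.ray.pipeline.pipeline_utils import RayDiffusionPipeline

        pipe = RayDiffusionPipeline.from_pretrained(
            PipelineClass=xFuserHunyuanDiTPipeline,  # the class is supplied once, here
            pretrained_model_name_or_path=engine_config.model_config.model,
            engine_config=engine_config,
            torch_dtype=torch.float16,
        )
        # After this commit the call itself carries only generation arguments:
        return pipe(
            height=input_config.height,
            width=input_config.width,
            prompt=input_config.prompt,
        )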
6 changes: 3 additions & 3 deletions examples/ray/ray_run.sh
@@ -32,10 +32,10 @@ mkdir -p ./results
 TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
 
 
-N_GPUS=3 # world size
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
+N_GPUS=2 # world size
+PARALLEL_ARGS="--pipefusion_parallel_degree 1 --ulysses_degree 1 --ring_degree 1"
 VAE_PARALLEL_SIZE=1
-DIT_PARALLEL_SIZE=2
+DIT_PARALLEL_SIZE=1
 # CFG_ARGS="--use_cfg_parallel"
 
 # By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
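The new ray_run.sh defaults describe exactly the setup this commit fixes: one DiT worker plus one dedicated VAE worker, so the Ray world size drops from 3 to 2. A minimal sketch of the arithmetic these variables are assumed to satisfy (plain Python, illustrative only; the DiT group size is taken to be the product of the parallel degrees):

    # Illustrative check of the world-size arithmetic behind the updated defaults.
    pipefusion_parallel_degree = 1
    ulysses_degree = 1
    ring_degree = 1

    dit_parallel_size = pipefusion_parallel_degree * ulysses_degree * ring_degree  # DIT_PARALLEL_SIZE=1
    vae_parallel_size = 1                                                           # VAE_PARALLEL_SIZE=1
    n_gpus = dit_parallel_size + vae_parallel_size                                  # N_GPUS=2 (world size)
    assert n_gpus == 2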
2 changes: 2 additions & 0 deletions xfuser/model_executor/pipelines/base_pipeline.py
@@ -144,6 +144,7 @@ def __init__(
         engine_config: EngineConfig,
     ):
         self.module: DiffusionPipeline
+        self.engine_config = engine_config
         self._init_runtime_state(pipeline=pipeline, engine_config=engine_config)
         self._init_fast_attn_state(pipeline=pipeline, engine_config=engine_config)
 
@@ -256,6 +257,7 @@ def use_naive_forward(self):
             and get_sequence_parallel_world_size() == 1
             and get_tensor_model_parallel_world_size() == 1
             and get_fast_attn_enable() == False
+            and get_runtime_state().parallel_config.vae_parallel_size == 0
         )
 
     @staticmethod
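The two base_pipeline.py additions work together: the pipeline now keeps a handle to engine_config (used by the per-model pipelines further down), and the "naive forward" shortcut, which effectively runs the original single-GPU pipeline and bypasses xDiT's distributed wrapping, is disabled whenever a VAE parallel group exists, even if every DiT parallel degree is 1, since that path would never involve the separate VAE worker. A condensed, illustrative version of the predicate (not the repository's exact code):

    # "Naive" single-GPU forward is allowed only when every parallel degree is 1
    # AND no dedicated VAE worker group was requested.
    def can_use_naive_forward(parallel_degrees, vae_parallel_size):
        return all(d == 1 for d in parallel_degrees.values()) and vae_parallel_size == 0

    degrees = {"pipefusion": 1, "ulysses": 1, "ring": 1, "cfg": 1, "tensor": 1}
    assert can_use_naive_forward(degrees, vae_parallel_size=0) is True
    assert can_use_naive_forward(degrees, vae_parallel_size=1) is False  # the case this commit fixes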
7 changes: 5 additions & 2 deletions xfuser/model_executor/pipelines/pipeline_flux.py
@@ -772,7 +772,7 @@ def _backbone_forward(
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-        noise_pred, encoder_hidden_states = self.transformer(
+        ret = self.transformer(
             hidden_states=latents,
             timestep=timestep / 1000,
             guidance=guidance,
@@ -783,7 +783,10 @@
             joint_attention_kwargs=self.joint_attention_kwargs,
             return_dict=False,
         )[0]
-
+        if self.engine_config.parallel_config.dit_parallel_size > 1:
+            noise_pred, encoder_hidden_states = ret
+        else:
+            noise_pred, encoder_hidden_states = ret, None
         return noise_pred, encoder_hidden_states
 
     def _scheduler_step(
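The motivation, as far as the diff shows: with dit_parallel_size == 1 the transformer is presumably the unwrapped diffusers module, so element 0 of its return_dict=False output is just the noise prediction rather than a (noise_pred, encoder_hidden_states) pair, and the old unconditional two-way unpack had nothing sensible to bind. A small, self-contained sketch of the new unpacking pattern (names are illustrative, not repository code):

    def unpack_backbone_output(ret, dit_parallel_size):
        """Mirrors the unpacking introduced by this commit (illustrative only)."""
        if dit_parallel_size > 1:
            noise_pred, encoder_hidden_states = ret        # parallel wrapper returns a pair
        else:
            noise_pred, encoder_hidden_states = ret, None  # plain transformer returns only the sample
        return noise_pred, encoder_hidden_states

    assert unpack_backbone_output(("noise", "states"), dit_parallel_size=2) == ("noise", "states")
    assert unpack_backbone_output("noise", dit_parallel_size=1) == ("noise", None)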
40 changes: 27 additions & 13 deletions xfuser/model_executor/pipelines/pipeline_hunyuandit.py
@@ -903,19 +903,33 @@ def _backbone_forward(
         )
 
         # predict the noise residual
-        noise_pred = self.transformer(
-            latents,
-            t_expand,
-            encoder_hidden_states=prompt_embeds,
-            text_embedding_mask=prompt_attention_mask,
-            encoder_hidden_states_t5=prompt_embeds_2,
-            text_embedding_mask_t5=prompt_attention_mask_2,
-            image_meta_size=add_time_ids,
-            style=style,
-            image_rotary_emb=image_rotary_emb,
-            skips=skips,
-            return_dict=False,
-        )[0]
+        if skips is not None:
+            noise_pred = self.transformer(
+                latents,
+                t_expand,
+                encoder_hidden_states=prompt_embeds,
+                text_embedding_mask=prompt_attention_mask,
+                encoder_hidden_states_t5=prompt_embeds_2,
+                text_embedding_mask_t5=prompt_attention_mask_2,
+                image_meta_size=add_time_ids,
+                style=style,
+                image_rotary_emb=image_rotary_emb,
+                skips=skips,
+                return_dict=False,
+            )[0]
+        else:
+            noise_pred = self.transformer(
+                latents,
+                t_expand,
+                encoder_hidden_states=prompt_embeds,
+                text_embedding_mask=prompt_attention_mask,
+                encoder_hidden_states_t5=prompt_embeds_2,
+                text_embedding_mask_t5=prompt_attention_mask_2,
+                image_meta_size=add_time_ids,
+                style=style,
+                image_rotary_emb=image_rotary_emb,
+                return_dict=False,
+            )[0]
 
         if is_pipeline_last_stage():
             noise_pred, _ = noise_pred.chunk(2, dim=1)
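The HunyuanDiT change follows the same logic: skips is only produced on the pipefusion/pipeline-parallel path, and with a single DiT worker the transformer presumably does not accept a skips keyword, so the call is issued without it. A self-contained sketch of that dispatch (stand-in names, not repository code):

    def call_transformer(transformer, latents, t_expand, skips=None, **common_kwargs):
        # Pass `skips` only when the pipeline-parallel path provided one.
        if skips is not None:
            return transformer(latents, t_expand, skips=skips, return_dict=False, **common_kwargs)[0]
        return transformer(latents, t_expand, return_dict=False, **common_kwargs)[0]

    def stub_transformer(latents, t_expand, return_dict=False, **kwargs):
        # Like the stock module (assumed), this stub has no `skips` parameter.
        assert "skips" not in kwargs
        return (latents,)

    assert call_transformer(stub_transformer, "latents", "t") == "latents"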
Seventh changed file (file name not shown in this view; the hunk below appears to be from a Stable Diffusion 3 pipeline's _backbone_forward):
@@ -773,14 +773,18 @@ def _backbone_forward(
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = t.expand(latents.shape[0])
 
-        noise_pred, encoder_hidden_states = self.transformer(
+        ret = self.transformer(
             hidden_states=latents,
             timestep=timestep,
             encoder_hidden_states=encoder_hidden_states,
             pooled_projections=pooled_prompt_embeds,
             joint_attention_kwargs=self.joint_attention_kwargs,
             return_dict=False,
         )[0]
+        if self.engine_config.parallel_config.dit_parallel_size > 1:
+            noise_pred, encoder_hidden_states = ret
+        else:
+            noise_pred, encoder_hidden_states = ret, None
 
         # classifier free guidance
         if is_pipeline_last_stage():
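This last hunk is the same conditional unpacking as in pipeline_flux.py above: with a single DiT worker the raw transformer output is used as the noise prediction and encoder_hidden_states is set to None; only when dit_parallel_size > 1 does the wrapped transformer return both values.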
