diff --git a/examples/ray/ray_hunyuandit_example.py b/examples/ray/ray_hunyuandit_example.py
index 51938a44..423b88aa 100644
--- a/examples/ray/ray_hunyuandit_example.py
+++ b/examples/ray/ray_hunyuandit_example.py
@@ -43,7 +43,6 @@ def main():
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
     output = pipe(
-        PipelineClass=xFuserHunyuanDiTPipeline,
         height=input_config.height,
         width=input_config.width,
         prompt=input_config.prompt,
diff --git a/examples/ray/ray_pixartsigma_example.py b/examples/ray/ray_pixartsigma_example.py
index 496d5dea..df735200 100644
--- a/examples/ray/ray_pixartsigma_example.py
+++ b/examples/ray/ray_pixartsigma_example.py
@@ -41,7 +41,6 @@ def main():
     torch.cuda.reset_peak_memory_stats()
     start_time = time.time()
     output = pipe(
-        PipelineClass=xFuserPixArtSigmaPipeline,
         height=input_config.height,
         width=input_config.width,
         prompt=input_config.prompt,
diff --git a/examples/ray/ray_run.sh b/examples/ray/ray_run.sh
index 6e7a633d..93115784 100644
--- a/examples/ray/ray_run.sh
+++ b/examples/ray/ray_run.sh
@@ -32,10 +32,10 @@ mkdir -p ./results
 
 TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning"
 
-N_GPUS=3 # world size
-PARALLEL_ARGS="--pipefusion_parallel_degree 2 --ulysses_degree 1 --ring_degree 1"
+N_GPUS=2 # world size
+PARALLEL_ARGS="--pipefusion_parallel_degree 1 --ulysses_degree 1 --ring_degree 1"
 VAE_PARALLEL_SIZE=1
-DIT_PARALLEL_SIZE=2
+DIT_PARALLEL_SIZE=1
 # CFG_ARGS="--use_cfg_parallel"
 
 # By default, num_pipeline_patch = pipefusion_degree, and you can tune this parameter to achieve optimal performance.
diff --git a/xfuser/model_executor/pipelines/base_pipeline.py b/xfuser/model_executor/pipelines/base_pipeline.py
index 80564f51..ad94e1dd 100644
--- a/xfuser/model_executor/pipelines/base_pipeline.py
+++ b/xfuser/model_executor/pipelines/base_pipeline.py
@@ -144,6 +144,7 @@ def __init__(
         engine_config: EngineConfig,
     ):
         self.module: DiffusionPipeline
+        self.engine_config = engine_config
         self._init_runtime_state(pipeline=pipeline, engine_config=engine_config)
         self._init_fast_attn_state(pipeline=pipeline, engine_config=engine_config)
 
@@ -256,6 +257,7 @@ def use_naive_forward(self):
             and get_sequence_parallel_world_size() == 1
             and get_tensor_model_parallel_world_size() == 1
             and get_fast_attn_enable() == False
+            and get_runtime_state().parallel_config.vae_parallel_size == 0
         )
 
     @staticmethod
diff --git a/xfuser/model_executor/pipelines/pipeline_flux.py b/xfuser/model_executor/pipelines/pipeline_flux.py
index 00152262..1550eb05 100644
--- a/xfuser/model_executor/pipelines/pipeline_flux.py
+++ b/xfuser/model_executor/pipelines/pipeline_flux.py
@@ -772,7 +772,7 @@ def _backbone_forward(
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-        noise_pred, encoder_hidden_states = self.transformer(
+        ret = self.transformer(
             hidden_states=latents,
             timestep=timestep / 1000,
             guidance=guidance,
@@ -783,7 +783,10 @@
             joint_attention_kwargs=self.joint_attention_kwargs,
             return_dict=False,
         )[0]
-
+        if self.engine_config.parallel_config.dit_parallel_size > 1:
+            noise_pred, encoder_hidden_states = ret
+        else:
+            noise_pred, encoder_hidden_states = ret, None
         return noise_pred, encoder_hidden_states
 
     def _scheduler_step(
diff --git a/xfuser/model_executor/pipelines/pipeline_hunyuandit.py b/xfuser/model_executor/pipelines/pipeline_hunyuandit.py
index bd9d4310..b8d11812 100644
--- a/xfuser/model_executor/pipelines/pipeline_hunyuandit.py
+++ b/xfuser/model_executor/pipelines/pipeline_hunyuandit.py
@@ -903,19 +903,33 @@ def _backbone_forward(
         )
 
         # predict the noise residual
-        noise_pred = self.transformer(
-            latents,
-            t_expand,
-            encoder_hidden_states=prompt_embeds,
-            text_embedding_mask=prompt_attention_mask,
-            encoder_hidden_states_t5=prompt_embeds_2,
-            text_embedding_mask_t5=prompt_attention_mask_2,
-            image_meta_size=add_time_ids,
-            style=style,
-            image_rotary_emb=image_rotary_emb,
-            skips=skips,
-            return_dict=False,
-        )[0]
+        if skips is not None:
+            noise_pred = self.transformer(
+                latents,
+                t_expand,
+                encoder_hidden_states=prompt_embeds,
+                text_embedding_mask=prompt_attention_mask,
+                encoder_hidden_states_t5=prompt_embeds_2,
+                text_embedding_mask_t5=prompt_attention_mask_2,
+                image_meta_size=add_time_ids,
+                style=style,
+                image_rotary_emb=image_rotary_emb,
+                skips=skips,
+                return_dict=False,
+            )[0]
+        else:
+            noise_pred = self.transformer(
+                latents,
+                t_expand,
+                encoder_hidden_states=prompt_embeds,
+                text_embedding_mask=prompt_attention_mask,
+                encoder_hidden_states_t5=prompt_embeds_2,
+                text_embedding_mask_t5=prompt_attention_mask_2,
+                image_meta_size=add_time_ids,
+                style=style,
+                image_rotary_emb=image_rotary_emb,
+                return_dict=False,
+            )[0]
 
         if is_pipeline_last_stage():
             noise_pred, _ = noise_pred.chunk(2, dim=1)
diff --git a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
index 48bed7a2..79b009e0 100644
--- a/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
+++ b/xfuser/model_executor/pipelines/pipeline_stable_diffusion_3.py
@@ -773,7 +773,7 @@ def _backbone_forward(
         # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
         timestep = t.expand(latents.shape[0])
 
-        noise_pred, encoder_hidden_states = self.transformer(
+        ret = self.transformer(
             hidden_states=latents,
             timestep=timestep,
             encoder_hidden_states=encoder_hidden_states,
@@ -781,6 +781,10 @@
             joint_attention_kwargs=self.joint_attention_kwargs,
             return_dict=False,
         )[0]
+        if self.engine_config.parallel_config.dit_parallel_size > 1:
+            noise_pred, encoder_hidden_states = ret
+        else:
+            noise_pred, encoder_hidden_states = ret, None
 
         # classifier free guidance
         if is_pipeline_last_stage():
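
Note on the flux and SD3 hunks above: both inline the same unpacking pattern, which (as the branches encode) assumes that with DiT parallelism the transformer's first return value is a (noise_pred, encoder_hidden_states) pair, while the single-stage path returns only the noise prediction. A minimal standalone sketch of that pattern follows; unpack_backbone_output is a hypothetical helper for illustration only, not an xFuser API, and in the patch the branch is inlined directly in each pipeline's _backbone_forward.

import torch

def unpack_backbone_output(ret, dit_parallel_size: int):
    # Mirrors the branch inlined in pipeline_flux.py / pipeline_stable_diffusion_3.py:
    # multi-stage DiT runs hand back both tensors so encoder_hidden_states can keep
    # flowing to the next stage; single-stage runs have nothing extra to pass on.
    if dit_parallel_size > 1:
        noise_pred, encoder_hidden_states = ret
    else:
        noise_pred, encoder_hidden_states = ret, None
    return noise_pred, encoder_hidden_states

# Usage with dummy tensors standing in for the transformer output:
noise = torch.zeros(1, 16, 64)
text = torch.zeros(1, 512, 64)
out, enc = unpack_backbone_output(noise, dit_parallel_size=1)
assert out is noise and enc is None
out, enc = unpack_backbone_output((noise, text), dit_parallel_size=2)
assert out is noise and enc is text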