diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py index 770306fd..5f877a17 100644 --- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py +++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py @@ -3,6 +3,7 @@ import torch import torch.distributed +import inspect from diffusers import CogVideoXPipeline from diffusers.pipelines.cogvideo.pipeline_cogvideox import ( CogVideoXPipelineOutput, @@ -404,6 +405,24 @@ def _init_sync_pipeline( ) return latents, prompt_embeds, image_rotary_emb + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.module.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.module.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + @property def interrupt(self): return self._interrupt