From ec182a32ce587435cc1217f701ab3b2a5d1cd1ff Mon Sep 17 00:00:00 2001 From: LazyBusyYang Date: Wed, 8 Jan 2025 10:25:31 +0000 Subject: [PATCH] Add method to prepare extra step kwargs for scheduler in xFuserCogVideoXPipeline This update introduces the `prepare_extra_step_kwargs` method, which prepares additional keyword arguments for the scheduler step based on its signature. It checks for the presence of 'eta' and 'generator' parameters to ensure compatibility with different schedulers. This enhancement improves the flexibility and usability of the pipeline. --- .../pipelines/pipeline_cogvideox.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/xfuser/model_executor/pipelines/pipeline_cogvideox.py b/xfuser/model_executor/pipelines/pipeline_cogvideox.py index 770306fd..5f877a17 100644 --- a/xfuser/model_executor/pipelines/pipeline_cogvideox.py +++ b/xfuser/model_executor/pipelines/pipeline_cogvideox.py @@ -3,6 +3,7 @@ import torch import torch.distributed +import inspect from diffusers import CogVideoXPipeline from diffusers.pipelines.cogvideo.pipeline_cogvideox import ( CogVideoXPipelineOutput, @@ -404,6 +405,24 @@ def _init_sync_pipeline( ) return latents, prompt_embeds, image_rotary_emb + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.module.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.module.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + @property def interrupt(self): return self._interrupt