diff --git a/README.md b/README.md index 19e790a9..05f77c0d 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,20 @@ We also have implemented the following parallel strategies for reference: 1. Tensor Parallelism 2. [DistriFusion](https://arxiv.org/abs/2402.19481) +

Cache Acceleration

+ +The cache methods are inspired by [TeaCache](https://github.com/ali-vilab/TeaCache.git) and [ParaAttention](https://github.com/chengzeyi/ParaAttention.git); we adapted TeaCache and First-Block-Cache to xDiT. + +These methods are not orthogonal to parallelism in xDiT: caching can only be enabled together with sequence parallelism (SP) or with no parallelism at all. + +To use this functionality, pass `--use_teacache` or `--use_fbcache`, which activate TeaCache and First-Block-Cache respectively. Currently, only the FLUX model is supported. + +The performance below was measured on 4 H20 GPUs with SP=4: +| Method | Latency | +|----------------|--------| +| Baseline | 2.02s | +| use_teacache | 1.58s | +| use_fbcache | 0.93s |

Computing Acceleration

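The `--use_teacache` / `--use_fbcache` flags documented above are wired into `examples/flux_example.py` through the new `apply_cache_on_transformer` adapter (see the per-file changes below). For orientation, here is a minimal single-GPU sketch of calling the adapter directly on a plain diffusers `FluxPipeline`; this snippet is not part of the patch, it assumes the registered `FluxTransformer2DModel` path behaves like the xFuser wrapper used in the example script, and the model id, prompt, and threshold are placeholders rather than recommendations.

```python
# Minimal sketch (not part of this patch): apply First-Block-Cache to a plain
# diffusers FLUX pipeline via the adapter added under
# xfuser/model_executor/plugins/cache_. Model id, prompt, and the 0.6 threshold
# are placeholders; the threshold simply mirrors examples/flux_example.py.
import torch
from diffusers import FluxPipeline
from xfuser.model_executor.plugins.cache_.diffusers_adapters import apply_cache_on_transformer

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

num_steps = 28
pipe.transformer = apply_cache_on_transformer(
    pipe.transformer,
    use_cache="Fb",                    # "Fb" = First-Block-Cache, "Tea" = TeaCache
    rel_l1_thresh=0.6,                 # similarity threshold for reusing cached residuals
    return_hidden_states_first=False,  # same setting as examples/flux_example.py
    num_steps=num_steps,               # TeaCache uses this to reset its accumulator
)

image = pipe("a cyberpunk city at night", num_inference_steps=num_steps).images[0]
image.save("flux_fbcache.png")
```

In the multi-GPU example, the same code path is enabled by uncommenting one of the `CACHE_ARGS` lines added to `examples/run.sh`.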
diff --git a/examples/flux_example.py b/examples/flux_example.py index ff578bbf..921a6262 100644 --- a/examples/flux_example.py +++ b/examples/flux_example.py @@ -11,6 +11,9 @@ get_data_parallel_world_size, get_runtime_state, is_dp_last_group, + get_pipeline_parallel_world_size, + get_classifier_free_guidance_world_size, + get_tensor_model_parallel_world_size, ) @@ -45,7 +48,27 @@ def main(): parameter_peak_memory = torch.cuda.max_memory_allocated(device=f"cuda:{local_rank}") pipe.prepare_run(input_config, steps=1) + from xfuser.model_executor.plugins.cache_.diffusers_adapters import apply_cache_on_transformer + use_cache = engine_args.use_teacache or engine_args.use_fbcache + if (use_cache + and get_pipeline_parallel_world_size() == 1 + and get_classifier_free_guidance_world_size() == 1 + and get_tensor_model_parallel_world_size() == 1 + ): + cache_args = { + "rel_l1_thresh": 0.6, + "return_hidden_states_first": False, + "num_steps": input_config.num_inference_steps, + } + + if engine_args.use_fbcache and engine_args.use_teacache: + cache_args["use_cache"] = "Fb" + elif engine_args.use_teacache: + cache_args["use_cache"] = "Tea" + elif engine_args.use_fbcache: + cache_args["use_cache"] = "Fb" + pipe.transformer = apply_cache_on_transformer(pipe.transformer, **cache_args) torch.cuda.reset_peak_memory_stats() start_time = time.time() output = pipe( diff --git a/examples/run.sh b/examples/run.sh index 12897814..12473309 100644 --- a/examples/run.sh +++ b/examples/run.sh @@ -10,7 +10,7 @@ declare -A MODEL_CONFIGS=( ["Pixart-alpha"]="pixartalpha_example.py /cfs/dit/PixArt-XL-2-1024-MS 20" ["Pixart-sigma"]="pixartsigma_example.py /cfs/dit/PixArt-Sigma-XL-2-2K-MS 20" ["Sd3"]="sd3_example.py /cfs/dit/stable-diffusion-3-medium-diffusers 20" - ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev 28" + ["Flux"]="flux_example.py /cfs/dit/FLUX.1-dev/ 28" ["HunyuanDiT"]="hunyuandit_example.py /cfs/dit/HunyuanDiT-v1.2-Diffusers 50" ) @@ -27,6 +27,9 @@ mkdir -p ./results # task args TASK_ARGS="--height 1024 --width 1024 --no_use_resolution_binning" +# cache args +# CACHE_ARGS="--use_teacache" +# CACHE_ARGS="--use_fbcache" # On 8 gpus, pp=2, ulysses=2, ring=1, cfg_parallel=2 (split batch) N_GPUS=8 @@ -64,3 +67,4 @@ $CFG_ARGS \ $PARALLLEL_VAE \ $COMPILE_FLAG \ $QUANTIZE_FLAG \ +$CACHE_ARGS \ diff --git a/xfuser/config/args.py b/xfuser/config/args.py index b4afabe8..7f5095f3 100644 --- a/xfuser/config/args.py +++ b/xfuser/config/args.py @@ -111,6 +111,8 @@ class xFuserArgs: window_size: int = 64 coco_path: Optional[str] = None use_cache: bool = False + use_teacache: bool = False + use_fbcache: bool = False use_fp8_t5_encoder: bool = False @staticmethod @@ -154,6 +156,16 @@ def add_cli_args(parser: FlexibleArgumentParser): action="store_true", help="Enable onediff to accelerate inference in a single card", ) + runtime_group.add_argument( + "--use_teacache", + action="store_true", + help="Enable teacache to accelerate inference in a single card", + ) + runtime_group.add_argument( + "--use_fbcache", + action="store_true", + help="Enable fbcache to accelerate inference in a single card", + ) # Parallel arguments parallel_group = parser.add_argument_group("Parallel Processing Options") diff --git a/xfuser/config/config.py b/xfuser/config/config.py index 2750b020..a87f0cc5 100644 --- a/xfuser/config/config.py +++ b/xfuser/config/config.py @@ -60,6 +60,8 @@ class RuntimeConfig: use_torch_compile: bool = False use_onediff: bool = False use_fp8_t5_encoder: bool = False + use_teacache: bool
= False + use_fbcache: bool = False def __post_init__(self): check_packages() diff --git a/xfuser/core/distributed/group_coordinator.py b/xfuser/core/distributed/group_coordinator.py index 0402756f..b3c613a6 100644 --- a/xfuser/core/distributed/group_coordinator.py +++ b/xfuser/core/distributed/group_coordinator.py @@ -196,7 +196,7 @@ def group_skip_rank(self): world_size = self.world_size return (world_size - rank_in_group - 1) % world_size - def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + def all_reduce(self, input_: torch.Tensor, op=torch._C._distributed_c10d.ReduceOp.SUM) -> torch.Tensor: """ NOTE: This operation will be applied in-place or out-of-place. Always assume this function modifies its input, but use the return @@ -206,7 +206,7 @@ def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: if self.world_size == 1: return input_ else: - torch.distributed.all_reduce(input_, group=self.device_group) + torch.distributed.all_reduce(input_, op=op, group=self.device_group) return input_ def all_gather( diff --git a/xfuser/core/distributed/parallel_state.py b/xfuser/core/distributed/parallel_state.py index 35590e5c..67b434e1 100644 --- a/xfuser/core/distributed/parallel_state.py +++ b/xfuser/core/distributed/parallel_state.py @@ -217,6 +217,7 @@ def init_distributed_environment( world_size=world_size, rank=rank, ) + torch.cuda.set_device(torch.distributed.get_rank() % torch.cuda.device_count()) # set the local rank # local_rank is not available in torch ProcessGroup, # see https://github.com/pytorch/pytorch/issues/122816 diff --git a/xfuser/model_executor/plugins/__init__.py b/xfuser/model_executor/plugins/__init__.py new file mode 100644 index 00000000..8e81242b --- /dev/null +++ b/xfuser/model_executor/plugins/__init__.py @@ -0,0 +1,4 @@ +""" +adapted from https://github.com/ali-vilab/TeaCache.git +adapted from https://github.com/chengzeyi/ParaAttention.git +""" \ No newline at end of file diff --git a/xfuser/model_executor/plugins/cache_/__init__.py b/xfuser/model_executor/plugins/cache_/__init__.py new file mode 100644 index 00000000..8e81242b --- /dev/null +++ b/xfuser/model_executor/plugins/cache_/__init__.py @@ -0,0 +1,4 @@ +""" +adapted from https://github.com/ali-vilab/TeaCache.git +adapted from https://github.com/chengzeyi/ParaAttention.git +""" \ No newline at end of file diff --git a/xfuser/model_executor/plugins/cache_/diffusers_adapters/__init__.py b/xfuser/model_executor/plugins/cache_/diffusers_adapters/__init__.py new file mode 100644 index 00000000..211afb9a --- /dev/null +++ b/xfuser/model_executor/plugins/cache_/diffusers_adapters/__init__.py @@ -0,0 +1,17 @@ +""" +adapted from https://github.com/ali-vilab/TeaCache.git +adapted from https://github.com/chengzeyi/ParaAttention.git +""" +import importlib +from typing import Type, Dict, TypeVar +from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY + + +def apply_cache_on_transformer(transformer, *args, **kwargs): + adapter_name = TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer)) + if not adapter_name: + raise ValueError(f"Unknown transformer class: {transformer.__class__.__name__}") + + adapter_module = importlib.import_module(f".{adapter_name}", __package__) + apply_cache_on_transformer_fn = getattr(adapter_module, "apply_cache_on_transformer") + return apply_cache_on_transformer_fn(transformer, *args, **kwargs) diff --git a/xfuser/model_executor/plugins/cache_/diffusers_adapters/flux.py 
b/xfuser/model_executor/plugins/cache_/diffusers_adapters/flux.py new file mode 100644 index 00000000..c3b40ba2 --- /dev/null +++ b/xfuser/model_executor/plugins/cache_/diffusers_adapters/flux.py @@ -0,0 +1,74 @@ +""" +adapted from https://github.com/ali-vilab/TeaCache.git +adapted from https://github.com/chengzeyi/ParaAttention.git +""" +import functools +import unittest + +import torch +from torch import nn +from diffusers import DiffusionPipeline, FluxTransformer2DModel +from xfuser.model_executor.plugins.cache_.diffusers_adapters.registry import TRANSFORMER_ADAPTER_REGISTRY + +from xfuser.model_executor.plugins.cache_ import utils + +def create_cached_transformer_blocks(use_cache, transformer, rel_l1_thresh, return_hidden_states_first, num_steps): + cached_transformer_class = { + "Fb": utils.FBCachedTransformerBlocks, + "Tea": utils.TeaCachedTransformerBlocks, + }.get(use_cache) + + if not cached_transformer_class: + raise ValueError(f"Unsupported use_cache value: {use_cache}") + + return cached_transformer_class( + transformer.transformer_blocks, + transformer.single_transformer_blocks, + transformer=transformer, + rel_l1_thresh=rel_l1_thresh, + return_hidden_states_first=return_hidden_states_first, + num_steps=num_steps, + name=TRANSFORMER_ADAPTER_REGISTRY.get(type(transformer)), + ) + + +def apply_cache_on_transformer( + transformer: FluxTransformer2DModel, + *, + rel_l1_thresh=0.6, + return_hidden_states_first=False, + num_steps=8, + use_cache="Fb", +): + cached_transformer_blocks = nn.ModuleList([ + create_cached_transformer_blocks(use_cache, transformer, rel_l1_thresh, return_hidden_states_first, num_steps) + ]) + + dummy_single_transformer_blocks = torch.nn.ModuleList() + + original_forward = transformer.forward + + @functools.wraps(original_forward) + def new_forward( + self, + *args, + **kwargs, + ): + with unittest.mock.patch.object( + self, + "transformer_blocks", + cached_transformer_blocks, + ), unittest.mock.patch.object( + self, + "single_transformer_blocks", + dummy_single_transformer_blocks, + ): + return original_forward( + *args, + **kwargs, + ) + + transformer.forward = new_forward.__get__(transformer) + + return transformer + diff --git a/xfuser/model_executor/plugins/cache_/diffusers_adapters/registry.py b/xfuser/model_executor/plugins/cache_/diffusers_adapters/registry.py new file mode 100644 index 00000000..2232832a --- /dev/null +++ b/xfuser/model_executor/plugins/cache_/diffusers_adapters/registry.py @@ -0,0 +1,16 @@ +""" +adapted from https://github.com/ali-vilab/TeaCache.git +adapted from https://github.com/chengzeyi/ParaAttention.git +""" +from typing import Type, Dict +from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel +from xfuser.model_executor.models.transformers.transformer_flux import xFuserFluxTransformer2DWrapper + +TRANSFORMER_ADAPTER_REGISTRY: Dict[Type, str] = {} + +def register_transformer_adapter(transformer_class: Type, adapter_name: str): + TRANSFORMER_ADAPTER_REGISTRY[transformer_class] = adapter_name + +register_transformer_adapter(FluxTransformer2DModel, "flux") +register_transformer_adapter(xFuserFluxTransformer2DWrapper, "flux") + diff --git a/xfuser/model_executor/plugins/cache_/utils.py b/xfuser/model_executor/plugins/cache_/utils.py new file mode 100644 index 00000000..9737fc0d --- /dev/null +++ b/xfuser/model_executor/plugins/cache_/utils.py @@ -0,0 +1,345 @@ +""" +adapted from https://github.com/ali-vilab/TeaCache.git +adapted from https://github.com/chengzeyi/ParaAttention.git +""" +import contextlib 
+import dataclasses +from collections import defaultdict +from typing import DefaultDict, Dict, Optional, List +from xfuser.core.distributed import ( + get_sp_group, + get_sequence_parallel_world_size, +) + +import torch +from abc import ABC, abstractmethod + + +#--------- CacheCallback ---------# +@dataclasses.dataclass +class CacheState: + transformer: None + transformer_blocks: None + single_transformer_blocks: None + cache_context: None + rel_l1_thresh: int = 0.6 + return_hidden_states_first: bool = True + use_cache: bool = False + num_steps: int = 8 + name: str = "default" + + +class CacheCallback: + def on_init_end(self, state: CacheState, **kwargs): + pass + + def on_forward_begin(self, state: CacheState, **kwargs): + pass + + def on_forward_remaining_begin(self, state: CacheState, **kwargs): + pass + + def on_forward_end(self, state: CacheState, **kwargs): + pass + + +class CallbackHandler(CacheCallback): + def __init__(self, callbacks): + self.callbacks = [] + if callbacks is not None: + for cb in callbacks: + self.add_callback(cb) + + def add_callback(self, callback): + cb = callback() if isinstance(callback, type) else callback + self.callbacks.append(cb) + + def pop_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return cb + else: + for cb in self.callbacks: + if cb == callback: + self.callbacks.remove(cb) + return cb + + def remove_callback(self, callback): + if isinstance(callback, type): + for cb in self.callbacks: + if isinstance(cb, callback): + self.callbacks.remove(cb) + return + else: + self.callbacks.remove(callback) + + def on_init_end(self, state: CacheState): + return self.call_event("on_init_end", state) + + def on_forward_begin(self, state: CacheState): + return self.call_event("on_forward_begin", state) + + def on_forward_remaining_begin(self, state: CacheState): + return self.call_event("on_forward_remaining_begin", state) + + def on_forward_end(self, state: CacheState): + return self.call_event("on_forward_end", state) + + def call_event(self, event, state, **kwargs): + for callback in self.callbacks: + getattr(callback, event)( + state, + **kwargs, + ) + + +#--------- CacheContext ---------# +@dataclasses.dataclass +class CacheContext: + buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict) + coefficients: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + self.coefficients["default"] = torch.Tensor([1, 0]).cuda() + self.coefficients["flux"] = torch.Tensor([4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]).cuda() + + def get_coef(self, name): + return self.coefficients.get(name) + + def get_buffer(self, name): + return self.buffers.get(name) + + def set_buffer(self, name, buffer): + self.buffers[name] = buffer + + def clear_buffer(self): + self.buffers.clear() + + +#--------- torch version of poly1d ---------# +class TorchPoly1D: + def __init__(self, coefficients): + self.coefficients = torch.tensor(coefficients, dtype=torch.float32) + self.degree = len(coefficients) - 1 + + def __call__(self, x): + result = torch.zeros_like(x) + for i, coef in enumerate(self.coefficients): + result += coef * (x ** (self.degree - i)) + return result + + +class CachedTransformerBlocks(torch.nn.Module, ABC): + def __init__( + self, + transformer_blocks, + single_transformer_blocks=None, + *, + transformer=None, + rel_l1_thresh=0.6, + return_hidden_states_first=True, + num_steps=-1, 
+ name="default", + callbacks: Optional[List[CacheCallback]] = None, + ): + super().__init__() + self.state = CacheState( + transformer=transformer, + transformer_blocks=transformer_blocks, + single_transformer_blocks=single_transformer_blocks, + cache_context=CacheContext(), + ) + self.state.rel_l1_thresh = rel_l1_thresh + self.state.return_hidden_states_first = return_hidden_states_first + self.state.use_cache = False + self.state.num_steps=num_steps + self.state.name=name + self.callback_handler = CallbackHandler(callbacks) + self.callback_handler.on_init_end(self.state) + + def is_parallelized(self): + if get_sequence_parallel_world_size() > 1: + return True + return False + + def all_reduce(self, input_, op): + if get_sequence_parallel_world_size() > 1: + return get_sp_group().all_reduce(input_=input_, op=op) + raise NotImplementedError("Cache method not support parrellism other than sp") + + def l1_distance_two_tensor(self, t1, t2): + mean_diff = (t1 - t2).abs().mean() + mean_t1 = t1.abs().mean() + if self.is_parallelized(): + mean_diff = self.all_reduce(mean_diff.unsqueeze(0), op=torch._C._distributed_c10d.ReduceOp.AVG)[0] + mean_t1 = self.all_reduce(mean_t1.unsqueeze(0), op=torch._C._distributed_c10d.ReduceOp.AVG)[0] + diff = mean_diff / mean_t1 + return diff + + @abstractmethod + def are_two_tensor_similar(self, t1, t2, threshold): + pass + + def run_one_block_transformer(self, block, hidden_states, encoder_hidden_states, *args, **kwargs): + hidden_states, encoder_hidden_states = block(hidden_states, encoder_hidden_states, *args, **kwargs) + return ( + (hidden_states, encoder_hidden_states) + if self.state.return_hidden_states_first + else (encoder_hidden_states, hidden_states) + ) + + @abstractmethod + def get_start_idx(self): + pass + + def get_remaining_block_result(self, hidden_states, encoder_hidden_states, *args, **kwargs): + original_hidden_states = self.state.cache_context.get_buffer("original_hidden_states") + original_encoder_hidden_states = self.state.cache_context.get_buffer("original_encoder_hidden_states") + start_idx = self.get_start_idx() + if start_idx == -1: + return (hidden_states, encoder_hidden_states) + for block in self.state.transformer_blocks[start_idx:]: + hidden_states, encoder_hidden_states = \ + self.run_one_block_transformer(block, hidden_states, encoder_hidden_states, *args, **kwargs) + if self.state.single_transformer_blocks is not None: + hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) + for block in self.state.single_transformer_blocks: + hidden_states = block(hidden_states, *args, **kwargs) + encoder_hidden_states, hidden_states = hidden_states.split( + [encoder_hidden_states.shape[1], hidden_states.shape[1] - encoder_hidden_states.shape[1]], dim=1 + ) + self.state.cache_context.set_buffer("hidden_states_residual", hidden_states - original_hidden_states) + self.state.cache_context.set_buffer("encoder_hidden_states_residual", + encoder_hidden_states - original_encoder_hidden_states) + return (hidden_states, encoder_hidden_states) + + @abstractmethod + def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs): + pass + + def forward(self, hidden_states, encoder_hidden_states, *args, **kwargs): + self.callback_handler.on_forward_begin(self.state) + modulated_inputs, prev_modulated_inputs, original_hidden_states, original_encoder_hidden_states = \ + self.get_modulated_inputs(hidden_states, encoder_hidden_states, *args, **kwargs) + + self.state.cache_context.set_buffer("original_hidden_states", 
original_hidden_states) + self.state.cache_context.set_buffer("original_encoder_hidden_states", original_encoder_hidden_states) + + self.state.use_cache = prev_modulated_inputs is not None and self.are_two_tensor_similar( + t1=prev_modulated_inputs, t2=modulated_inputs, threshold=self.state.rel_l1_thresh) + + self.callback_handler.on_forward_remaining_begin(self.state) + if self.state.use_cache: + hidden_states += self.state.cache_context.get_buffer("hidden_states_residual") + encoder_hidden_states += self.state.cache_context.get_buffer("encoder_hidden_states_residual") + else: + hidden_states, encoder_hidden_states = self.get_remaining_block_result( + original_hidden_states, original_encoder_hidden_states, *args, **kwargs) + + self.callback_handler.on_forward_end(self.state) + return ( + (hidden_states, encoder_hidden_states) + if self.state.return_hidden_states_first + else (encoder_hidden_states, hidden_states) + ) + + +class FBCachedTransformerBlocks(CachedTransformerBlocks): + def __init__( + self, + transformer_blocks, + single_transformer_blocks=None, + *, + transformer=None, + rel_l1_thresh=0.6, + return_hidden_states_first=True, + num_steps=-1, + name="default", + callbacks: Optional[List[CacheCallback]] = None, + ): + super().__init__(transformer_blocks, + single_transformer_blocks=single_transformer_blocks, + transformer=transformer, + rel_l1_thresh=rel_l1_thresh, + num_steps=num_steps, + return_hidden_states_first=return_hidden_states_first, + name=name, + callbacks=callbacks) + + def get_start_idx(self): + return 1 + + def are_two_tensor_similar(self, t1, t2, threshold): + return self.l1_distance_two_tensor(t1, t2) < threshold + + def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs): + original_hidden_states = hidden_states + first_transformer_block = self.state.transformer_blocks[0] + hidden_states, encoder_hidden_states = \ + self.run_one_block_transformer(first_transformer_block, hidden_states, encoder_hidden_states, *args, **kwargs) + first_hidden_states_residual = hidden_states - original_hidden_states + prev_first_hidden_states_residual = self.state.cache_context.get_buffer("modulated_inputs") + self.state.cache_context.set_buffer("modulated_inputs", first_hidden_states_residual) + + return first_hidden_states_residual, prev_first_hidden_states_residual, hidden_states, encoder_hidden_states + + +class TeaCachedTransformerBlocks(CachedTransformerBlocks): + def __init__( + self, + transformer_blocks, + single_transformer_blocks=None, + *, + transformer=None, + rel_l1_thresh=0.6, + return_hidden_states_first=True, + num_steps=-1, + name="default", + callbacks: Optional[List[CacheCallback]] = None, + ): + super().__init__(transformer_blocks, + single_transformer_blocks=single_transformer_blocks, + transformer=transformer, + rel_l1_thresh=rel_l1_thresh, + num_steps=num_steps, + return_hidden_states_first=return_hidden_states_first, + name=name, + callbacks=callbacks) + object.__setattr__(self.state, 'cnt', 0) + object.__setattr__(self.state, 'accumulated_rel_l1_distance', 0) + object.__setattr__(self.state, 'rescale_func', TorchPoly1D(self.state.cache_context.get_coef(self.state.name))) + + def get_start_idx(self): + return 0 + + def are_two_tensor_similar(self, t1, t2, threshold): + if self.state.cnt == 0 or self.state.cnt == self.state.num_steps-1: + self.state.accumulated_rel_l1_distance = 0 + self.state.use_cache = False + else: + diff = self.l1_distance_two_tensor(t1, t2) + self.state.accumulated_rel_l1_distance += 
self.state.rescale_func(diff) + if self.state.accumulated_rel_l1_distance < threshold: + self.state.use_cache = True + else: + self.state.use_cache = False + self.state.accumulated_rel_l1_distance = 0 + self.state.cnt += 1 + if self.state.cnt == self.state.num_steps: + self.state.cnt = 0 + return self.state.use_cache + + + def get_modulated_inputs(self, hidden_states, encoder_hidden_states, *args, **kwargs): + inp = hidden_states.clone() + temb_ = kwargs.get("temb", None) + if temb_ is not None: + temb_ = temb_.clone() + else: + raise ValueError("'temb' not found in kwargs") + modulated_inp, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.state.transformer_blocks[0].norm1(inp, emb=temb_) + previous_modulated_input = self.state.cache_context.get_buffer("modulated_inputs") + self.state.cache_context.set_buffer("modulated_inputs", modulated_inp) + return modulated_inp, previous_modulated_input, hidden_states, encoder_hidden_states \ No newline at end of file
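To make the decision rules in `xfuser/model_executor/plugins/cache_/utils.py` concrete: First-Block-Cache compares the residual of the first transformer block between consecutive steps, while TeaCache accumulates a polynomially rescaled relative L1 distance of the modulated inputs and resets the accumulator at the first and last step and whenever a full recompute happens. The standalone sketch below replays the TeaCache bookkeeping with random stand-in tensors and the `flux` coefficients from `CacheContext`; it is illustrative only and does not touch the real transformer blocks.

```python
# Standalone sketch of the TeaCache decision rule from TeaCachedTransformerBlocks:
# accumulate the rescaled relative L1 distance between consecutive modulated inputs
# and reuse cached residuals while the accumulated value stays under the threshold.
# Tensors here are random stand-ins, not real FLUX activations.
import torch

# "flux" coefficients from CacheContext.__post_init__
COEFFS = [4.98651651e+02, -2.83781631e+02, 5.58554382e+01, -3.82021401e+00, 2.64230861e-01]

def rescale(x, coefficients=COEFFS):
    # same polynomial evaluation as TorchPoly1D.__call__
    degree = len(coefficients) - 1
    return sum(c * x ** (degree - i) for i, c in enumerate(coefficients))

def rel_l1(prev, cur):
    # mirrors CachedTransformerBlocks.l1_distance_two_tensor (single-GPU case)
    return (prev - cur).abs().mean() / prev.abs().mean()

threshold, accumulated = 0.6, 0.0
prev = torch.randn(1, 4096, 64)                   # stand-in for the modulated input of step 0
for step in range(1, 8):                          # step 0 (and the last step) always recompute
    cur = prev + 0.01 * torch.randn_like(prev)    # pretend the input drifts slowly
    accumulated += float(rescale(rel_l1(prev, cur)))
    use_cache = accumulated < threshold
    if not use_cache:
        accumulated = 0.0                         # a full recompute resets the accumulator
    print(f"step {step}: accumulated={accumulated:.3f} use_cache={use_cache}")
    prev = cur
```

With the synthetic drift used here the accumulator crosses the 0.6 threshold roughly every third step; on real activations the cadence depends on the prompt, scheduler, and chosen threshold.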