Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Flux](https://comfyanonymous.github.io/ComfyUI_examples/flux/)
- [Lumina Image 2.0](https://comfyanonymous.github.io/ComfyUI_examples/lumina2/)
- [HiDream](https://comfyanonymous.github.io/ComfyUI_examples/hidream/)
- [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Qwen Image](https://comfyanonymous.github.io/ComfyUI_examples/qwen_image/)
- Image Editing Models
- [Omnigen 2](https://comfyanonymous.github.io/ComfyUI_examples/omnigen/)
Expand All @@ -77,7 +76,6 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
- [Mochi](https://comfyanonymous.github.io/ComfyUI_examples/mochi/)
- [LTX-Video](https://comfyanonymous.github.io/ComfyUI_examples/ltxv/)
- [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
- [Nvidia Cosmos](https://comfyanonymous.github.io/ComfyUI_examples/cosmos/) and [Cosmos Predict2](https://comfyanonymous.github.io/ComfyUI_examples/cosmos_predict2/)
- [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
- [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
- Audio Models
Expand Down
34 changes: 34 additions & 0 deletions comfy_extras/nodes_easycache.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
input_change = None
do_easycache = easycache.should_do_easycache(sigmas)
if do_easycache:
easycache.check_metadata(x)
# if first cond marked this step for skipping, skip it and use appropriate cached values
if easycache.skip_current_step:
if easycache.verbose:
Expand Down Expand Up @@ -92,6 +93,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
input_change = None
do_easycache = easycache.should_do_easycache(timestep)
if do_easycache:
easycache.check_metadata(x)
if easycache.has_x_prev_subsampled():
if easycache.has_x_prev_subsampled():
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
Expand Down Expand Up @@ -194,6 +196,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
# how to deal with mismatched dims
self.allow_mismatch = True
self.cut_from_start = True
self.state_metadata = None

def is_past_end_timestep(self, timestep: float) -> bool:
return not (timestep[0] > self.end_t).item()
Expand Down Expand Up @@ -283,6 +286,17 @@ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[U
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
return self.first_cond_uuid in uuids

def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape[1:])
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False

def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
Expand All @@ -299,6 +313,7 @@ def reset(self):
del self.uuid_cache_diffs
self.uuid_cache_diffs = {}
self.total_steps_skipped = 0
self.state_metadata = None
return self

def clone(self):
Expand Down Expand Up @@ -360,6 +375,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
self.output_change_rates = []
self.approx_output_change_rates = []
self.total_steps_skipped = 0
self.state_metadata = None

def has_cache_diff(self) -> bool:
return self.cache_diff is not None
Expand Down Expand Up @@ -404,6 +420,17 @@ def apply_cache_diff(self, x: torch.Tensor):
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
self.cache_diff = output - x

def check_metadata(self, x: torch.Tensor) -> bool:
metadata = (x.device, x.dtype, x.shape)
if self.state_metadata is None:
self.state_metadata = metadata
return True
if metadata == self.state_metadata:
return True
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
self.reset()
return False

def reset(self):
self.relative_transformation_rate = 0.0
self.cumulative_change_rate = 0.0
Expand All @@ -412,7 +439,14 @@ def reset(self):
self.approx_output_change_rates = []
del self.cache_diff
self.cache_diff = None
del self.x_prev_subsampled
self.x_prev_subsampled = None
del self.output_prev_subsampled
self.output_prev_subsampled = None
del self.output_prev_norm
self.output_prev_norm = None
self.total_steps_skipped = 0
self.state_metadata = None
return self

def clone(self):
Expand Down
Loading