Skip to content

Commit 95ac779

Browse files
authored
Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling (Comfy-Org#9528)
* Fix EasyCache/LazyCache crash when tensor shape/dtype/device changes during sampling * Fix missing LazyCache check_metadata method Ensure LazyCache reset method resets all the tensor state values
1 parent 71ed4a3 commit 95ac779

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

comfy_extras/nodes_easycache.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
2828
input_change = None
2929
do_easycache = easycache.should_do_easycache(sigmas)
3030
if do_easycache:
31+
easycache.check_metadata(x)
3132
# if first cond marked this step for skipping, skip it and use appropriate cached values
3233
if easycache.skip_current_step:
3334
if easycache.verbose:
@@ -92,6 +93,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
9293
input_change = None
9394
do_easycache = easycache.should_do_easycache(timestep)
9495
if do_easycache:
96+
easycache.check_metadata(x)
9597
if easycache.has_x_prev_subsampled():
9698
if easycache.has_x_prev_subsampled():
9799
input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
@@ -194,6 +196,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
194196
# how to deal with mismatched dims
195197
self.allow_mismatch = True
196198
self.cut_from_start = True
199+
self.state_metadata = None
197200

198201
def is_past_end_timestep(self, timestep: float) -> bool:
199202
return not (timestep[0] > self.end_t).item()
@@ -283,6 +286,17 @@ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[U
283286
def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
284287
return self.first_cond_uuid in uuids
285288

289+
def check_metadata(self, x: torch.Tensor) -> bool:
290+
metadata = (x.device, x.dtype, x.shape[1:])
291+
if self.state_metadata is None:
292+
self.state_metadata = metadata
293+
return True
294+
if metadata == self.state_metadata:
295+
return True
296+
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
297+
self.reset()
298+
return False
299+
286300
def reset(self):
287301
self.relative_transformation_rate = 0.0
288302
self.cumulative_change_rate = 0.0
@@ -299,6 +313,7 @@ def reset(self):
299313
del self.uuid_cache_diffs
300314
self.uuid_cache_diffs = {}
301315
self.total_steps_skipped = 0
316+
self.state_metadata = None
302317
return self
303318

304319
def clone(self):
@@ -360,6 +375,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
360375
self.output_change_rates = []
361376
self.approx_output_change_rates = []
362377
self.total_steps_skipped = 0
378+
self.state_metadata = None
363379

364380
def has_cache_diff(self) -> bool:
365381
return self.cache_diff is not None
@@ -404,6 +420,17 @@ def apply_cache_diff(self, x: torch.Tensor):
404420
def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
405421
self.cache_diff = output - x
406422

423+
def check_metadata(self, x: torch.Tensor) -> bool:
424+
metadata = (x.device, x.dtype, x.shape)
425+
if self.state_metadata is None:
426+
self.state_metadata = metadata
427+
return True
428+
if metadata == self.state_metadata:
429+
return True
430+
logging.warn(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
431+
self.reset()
432+
return False
433+
407434
def reset(self):
408435
self.relative_transformation_rate = 0.0
409436
self.cumulative_change_rate = 0.0
@@ -412,7 +439,14 @@ def reset(self):
412439
self.approx_output_change_rates = []
413440
del self.cache_diff
414441
self.cache_diff = None
442+
del self.x_prev_subsampled
443+
self.x_prev_subsampled = None
444+
del self.output_prev_subsampled
445+
self.output_prev_subsampled = None
446+
del self.output_prev_norm
447+
self.output_prev_norm = None
415448
self.total_steps_skipped = 0
449+
self.state_metadata = None
416450
return self
417451

418452
def clone(self):

0 commit comments

Comments
 (0)