@@ -28,6 +28,7 @@ def easycache_forward_wrapper(executor, *args, **kwargs):
2828 input_change = None
2929 do_easycache = easycache .should_do_easycache (sigmas )
3030 if do_easycache :
31+ easycache .check_metadata (x )
3132 # if first cond marked this step for skipping, skip it and use appropriate cached values
3233 if easycache .skip_current_step :
3334 if easycache .verbose :
@@ -92,6 +93,7 @@ def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
9293 input_change = None
9394 do_easycache = easycache .should_do_easycache (timestep )
9495 if do_easycache :
96+ easycache .check_metadata (x )
9597 if easycache .has_x_prev_subsampled ():
9698 if easycache .has_x_prev_subsampled ():
9799 input_change = (easycache .subsample (x , clone = False ) - easycache .x_prev_subsampled ).flatten ().abs ().mean ()
@@ -194,6 +196,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
194196 # how to deal with mismatched dims
195197 self .allow_mismatch = True
196198 self .cut_from_start = True
199+ self .state_metadata = None
197200
198201 def is_past_end_timestep (self , timestep : float ) -> bool :
199202 return not (timestep [0 ] > self .end_t ).item ()
@@ -283,6 +286,17 @@ def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[U
283286 def has_first_cond_uuid (self , uuids : list [UUID ]) -> bool :
284287 return self .first_cond_uuid in uuids
285288
def check_metadata(self, x: torch.Tensor) -> bool:
    """Verify that ``x`` still matches the device/dtype/shape this cache was built for.

    On the first call the metadata of ``x`` is recorded as the reference.
    On later calls, if the metadata differs, the cached state is stale
    (e.g. the sampler was re-run with a different resolution or dtype),
    so all cached tensors are dropped via ``self.reset()``.

    Args:
        x: Current model input tensor.

    Returns:
        True if the metadata matches (or was just recorded), False if a
        mismatch forced a state reset.
    """
    # Dim 0 (batch) is deliberately excluded so a batch-size change alone
    # does not invalidate the cache.
    metadata = (x.device, x.dtype, x.shape[1:])
    if self.state_metadata is None:
        # First observation: record the reference metadata.
        self.state_metadata = metadata
        return True
    if metadata == self.state_metadata:
        return True
    # Mismatch: cached tensors are incompatible with x; drop all cached state.
    # NOTE: logging.warn is deprecated since Python 3.3 -- use logging.warning,
    # with lazy %-style args so the message is only formatted when emitted.
    logging.warning("%s - Tensor shape, dtype or device changed, resetting state", self.name)
    self.reset()
    return False
299+
286300 def reset (self ):
287301 self .relative_transformation_rate = 0.0
288302 self .cumulative_change_rate = 0.0
@@ -299,6 +313,7 @@ def reset(self):
299313 del self .uuid_cache_diffs
300314 self .uuid_cache_diffs = {}
301315 self .total_steps_skipped = 0
316+ self .state_metadata = None
302317 return self
303318
304319 def clone (self ):
@@ -360,6 +375,7 @@ def __init__(self, reuse_threshold: float, start_percent: float, end_percent: fl
360375 self .output_change_rates = []
361376 self .approx_output_change_rates = []
362377 self .total_steps_skipped = 0
378+ self .state_metadata = None
363379
364380 def has_cache_diff (self ) -> bool :
365381 return self .cache_diff is not None
@@ -404,6 +420,17 @@ def apply_cache_diff(self, x: torch.Tensor):
404420 def update_cache_diff (self , output : torch .Tensor , x : torch .Tensor ):
405421 self .cache_diff = output - x
406422
def check_metadata(self, x: torch.Tensor) -> bool:
    """Verify that ``x`` still matches the device/dtype/shape this cache was built for.

    On the first call the metadata of ``x`` is recorded as the reference.
    On later calls, a mismatch means the cached diff tensors are stale
    (e.g. resolution, dtype or device changed between runs), so all
    cached state is dropped via ``self.reset()``.

    Args:
        x: Current model input tensor.

    Returns:
        True if the metadata matches (or was just recorded), False if a
        mismatch forced a state reset.
    """
    # Unlike the holder variant, the FULL shape (including batch) is
    # compared here, so a batch-size change also invalidates the cache.
    metadata = (x.device, x.dtype, x.shape)
    if self.state_metadata is None:
        # First observation: record the reference metadata.
        self.state_metadata = metadata
        return True
    if metadata == self.state_metadata:
        return True
    # Mismatch: cached tensors are incompatible with x; drop all cached state.
    # NOTE: logging.warn is deprecated since Python 3.3 -- use logging.warning,
    # with lazy %-style args so the message is only formatted when emitted.
    logging.warning("%s - Tensor shape, dtype or device changed, resetting state", self.name)
    self.reset()
    return False
433+
407434 def reset (self ):
408435 self .relative_transformation_rate = 0.0
409436 self .cumulative_change_rate = 0.0
@@ -412,7 +439,14 @@ def reset(self):
412439 self .approx_output_change_rates = []
413440 del self .cache_diff
414441 self .cache_diff = None
442+ del self .x_prev_subsampled
443+ self .x_prev_subsampled = None
444+ del self .output_prev_subsampled
445+ self .output_prev_subsampled = None
446+ del self .output_prev_norm
447+ self .output_prev_norm = None
415448 self .total_steps_skipped = 0
449+ self .state_metadata = None
416450 return self
417451
418452 def clone (self ):
0 commit comments