Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Ovi/vae/edm2_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,15 @@

import numpy as np
import torch
from collections import OrderedDict

#----------------------------------------------------------------------------
# Variant of constant() that inherits dtype and device from the given
# reference tensor by default.

_constant_cache = dict()
# Cache for broadcasted constant tensors, bounded by LRU to avoid unbounded growth.
_MAX_CONSTANT_CACHE = 64
_constant_cache = OrderedDict()


def constant(value, shape=None, dtype=None, device=None, memory_format=None):
Expand All @@ -36,6 +39,9 @@ def constant(value, shape=None, dtype=None, device=None, memory_format=None):
tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape))
tensor = tensor.contiguous(memory_format=memory_format)
_constant_cache[key] = tensor
_constant_cache.move_to_end(key)
while len(_constant_cache) > _MAX_CONSTANT_CACHE:
_constant_cache.popitem(last=False)
return tensor


Expand Down
10 changes: 9 additions & 1 deletion context_windows/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,11 +235,12 @@ def create_window_mask(noise_pred_context, c, latent_video_length, context_overl
return window_mask

class WindowTracker:
def __init__(self, verbose=False):
def __init__(self, verbose=False, max_windows=64):
self.window_map = {} # Maps frame sequence to persistent ID
self.next_id = 0
self.cache_states = {} # Maps persistent ID to teacache state
self.verbose = verbose
self.max_windows = max_windows

def get_window_id(self, frames):
key = tuple(sorted(frames)) # Order-independent frame sequence
Expand All @@ -248,6 +249,13 @@ def get_window_id(self, frames):
if self.verbose:
log.info(f"New window pattern {key} -> ID {self.next_id}")
self.next_id += 1
# Prevent unbounded growth if many unique window patterns are used
if len(self.window_map) > self.max_windows:
oldest_key = next(iter(self.window_map))
oldest_id = self.window_map.pop(oldest_key)
self.cache_states.pop(oldest_id, None)
if self.verbose:
log.info(f"Evicted oldest window pattern {oldest_key} (ID {oldest_id})")
return self.window_map[key]

def get_teacache(self, window_id, base_state):
Expand Down
277 changes: 230 additions & 47 deletions custom_linear.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions gguf/gguf.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def load_gguf(model_path):
def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[], patches=None, compile_args=None):
return _replace_linear(model, compute_dtype, state_dict, prefix, patches, None, compile_args, modules_to_not_convert)

def set_lora_params_gguf(module, patches, module_prefix="", device=torch.device("cpu")):
return set_lora_params(module, patches, module_prefix, device)
def set_lora_params_gguf(module, patches, module_prefix="", device=torch.device("cpu"), force_cpu=False, _diag=None):
return set_lora_params(module, patches, module_prefix, device, force_cpu, _diag=_diag)

GGUFLinear = CustomLinear
51 changes: 43 additions & 8 deletions multitalk/multitalk_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from ..latent_preview import prepare_callback
from ..wanvideo.schedulers import get_scheduler
from .multitalk import timestep_transform, add_noise
from ..utils import log, print_memory, temporal_score_rescaling, offload_transformer, init_blockswap, match_and_blend_colors
from ..utils import log, print_memory, temporal_score_rescaling, offload_transformer, init_blockswap, match_and_blend_colors, reopen_gguf_readers
from comfy.utils import load_torch_file
from ..nodes_model_loading import load_weights
from ..HuMo.nodes import get_audio_emb_window
Expand All @@ -22,6 +22,27 @@
device = mm.get_torch_device()
offload_device = mm.unet_offload_device()

def _weights_load_signature(model, patcher, transformer, block_swap_args, device, weight_dtype, base_dtype, gguf_reader):
"""Return a hashable signature describing the current weight-load configuration.

Used to skip redundant full-model load_weights() calls when the model has
already been loaded with identical settings in a previous execution.
"""
try:
compile_args = model["compile_args"]
except KeyError:
compile_args = None
return (
str(device),
str(weight_dtype),
str(base_dtype),
str(block_swap_args) if block_swap_args is not None else None,
str(compile_args) if compile_args is not None else None,
gguf_reader is not None,
getattr(transformer, "patched_linear", False),
len(patcher.patches),
)

def multitalk_loop(self, **kwargs):
# Unpack kwargs into local variables
(latent, total_steps, steps, start_step, end_step, shift, cfg, denoise_strength,
Expand Down Expand Up @@ -113,7 +134,7 @@ def multitalk_loop(self, **kwargs):
try:
silence_path = os.path.join(script_directory, "encoded_silence.safetensors")
encoded_silence = load_torch_file(silence_path)["audio_emb"].to(dtype)
except Exception:
except:
log.warning("No encoded silence file found, padding with end of audio embedding instead.")

total_frames = len(audio_embedding[0])
Expand Down Expand Up @@ -343,14 +364,28 @@ def multitalk_loop(self, **kwargs):
del motion_add_noise, add_latent

if offloaded:
# Load weights
if transformer.patched_linear and gguf_reader is None:
load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args)
elif gguf_reader is not None: #handle GGUF
load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
# Load weights (only if configuration changed or weights were offloaded)
offloaded_sig = _weights_load_signature(model, patcher, transformer, block_swap_args, device, weight_dtype, dtype, gguf_reader)
offloaded_loaded = (getattr(transformer, "_wan_weights_load_signature", None) == offloaded_sig) and not getattr(transformer, "_wan_weights_offloaded", False)
if not offloaded_loaded:
if gguf_reader is not None: #handle GGUF
reopen_gguf_readers(gguf_reader)
load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args)
else:
load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args)
transformer._wan_weights_load_signature = offloaded_sig
transformer._wan_weights_offloaded = False
#blockswap init
init_blockswap(transformer, block_swap_args, model)

# Release the original state dict if still present (main sampler usually does this earlier,
# but keep it here for entry points that call multitalk_loop directly).
if patcher.model.get("sd") is not None:
log.info("MultiTalk: releasing patcher.model['sd'] to free memory")
patcher.model["sd"] = None
gc.collect()
mm.soft_empty_cache()

# Use the appropriate prompt for this section
if len(text_embeds["prompt_embeds"]) > 1:
prompt_index = min(iteration_count, len(text_embeds["prompt_embeds"]) - 1)
Expand Down Expand Up @@ -564,6 +599,6 @@ def multitalk_loop(self, **kwargs):
try:
print_memory(device)
torch.cuda.reset_peak_memory_stats(device)
except Exception:
except:
pass
return {"video": gen_video_samples.permute(1, 2, 3, 0), "output_path": output_path},
10 changes: 8 additions & 2 deletions nodes.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os, gc, math
from collections import OrderedDict
import torch
import torch.nn.functional as F
import hashlib
Expand Down Expand Up @@ -148,8 +149,9 @@ def create_list(self, blocks):



# In-memory cache for prompt extender output
_extender_cache = {}
# In-memory cache for prompt extender output (LRU-bounded)
_MAX_EXTENDER_CACHE = 128
_extender_cache = OrderedDict()

cache_dir = os.path.join(script_directory, 'text_embed_cache')

Expand Down Expand Up @@ -231,6 +233,7 @@ def process(self, model_name, precision, positive_prompt, negative_prompt, quant
extender_key = (orig_prompt, str(extender_args))
if extender_key in _extender_cache:
positive_prompt = _extender_cache[extender_key]
_extender_cache.move_to_end(extender_key)
log.info(f"Loaded extended prompt from in-memory cache: {positive_prompt}")
else:
from .qwen.qwen import QwenLoader, WanVideoPromptExtender
Expand All @@ -250,6 +253,9 @@ def process(self, model_name, precision, positive_prompt, negative_prompt, quant
)
log.info(f"Extended positive prompt: {positive_prompt}")
_extender_cache[extender_key] = positive_prompt
_extender_cache.move_to_end(extender_key)
while len(_extender_cache) > _MAX_EXTENDER_CACHE:
_extender_cache.popitem(last=False)
del qwen
pbar.update(1)

Expand Down
Loading