diff --git a/Ovi/vae/edm2_utils.py b/Ovi/vae/edm2_utils.py index a18ffba5..8d4bdf8a 100644 --- a/Ovi/vae/edm2_utils.py +++ b/Ovi/vae/edm2_utils.py @@ -9,12 +9,15 @@ import numpy as np import torch +from collections import OrderedDict #---------------------------------------------------------------------------- # Variant of constant() that inherits dtype and device from the given # reference tensor by default. -_constant_cache = dict() +# Cache for broadcasted constant tensors, bounded by LRU to avoid unbounded growth. +_MAX_CONSTANT_CACHE = 64 +_constant_cache = OrderedDict() def constant(value, shape=None, dtype=None, device=None, memory_format=None): @@ -36,6 +39,9 @@ def constant(value, shape=None, dtype=None, device=None, memory_format=None): tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) tensor = tensor.contiguous(memory_format=memory_format) _constant_cache[key] = tensor + _constant_cache.move_to_end(key) + while len(_constant_cache) > _MAX_CONSTANT_CACHE: + _constant_cache.popitem(last=False) return tensor diff --git a/context_windows/context.py b/context_windows/context.py index 4f64f3f9..c16432b2 100644 --- a/context_windows/context.py +++ b/context_windows/context.py @@ -235,11 +235,12 @@ def create_window_mask(noise_pred_context, c, latent_video_length, context_overl return window_mask class WindowTracker: - def __init__(self, verbose=False): + def __init__(self, verbose=False, max_windows=64): self.window_map = {} # Maps frame sequence to persistent ID self.next_id = 0 self.cache_states = {} # Maps persistent ID to teacache state self.verbose = verbose + self.max_windows = max_windows def get_window_id(self, frames): key = tuple(sorted(frames)) # Order-independent frame sequence @@ -248,6 +249,13 @@ def get_window_id(self, frames): if self.verbose: log.info(f"New window pattern {key} -> ID {self.next_id}") self.next_id += 1 + # Prevent unbounded growth if many unique window patterns are used + if len(self.window_map) > self.max_windows: + oldest_key = next(iter(self.window_map)) + oldest_id = self.window_map.pop(oldest_key) + self.cache_states.pop(oldest_id, None) + if self.verbose: + log.info(f"Evicted oldest window pattern {oldest_key} (ID {oldest_id})") return self.window_map[key] def get_teacache(self, window_id, base_state): diff --git a/custom_linear.py b/custom_linear.py index 9fe47628..0a977fb9 100644 --- a/custom_linear.py +++ b/custom_linear.py @@ -1,7 +1,9 @@ -import torch +import torch, gc import torch.nn as nn from accelerate import init_empty_weights from .gguf.gguf_utils import GGUFParameter, dequantize_gguf_tensor +import logging +_ram_log = logging.getLogger(__name__) @torch.library.custom_op("wanvideo::apply_lora", mutates_args=()) def apply_lora(weight: torch.Tensor, lora_diff_0: torch.Tensor, lora_diff_1: torch.Tensor, lora_diff_2: float, lora_strength: torch.Tensor) -> torch.Tensor: @@ -58,13 +60,20 @@ def _replace_linear(model, compute_dtype, state_dict, prefix="", patches=None, s if isinstance(module, nn.Linear) and "loras" not in module_prefix and "dual_controller" not in module_prefix and name not in modules_to_not_convert: weight_key = module_prefix + "weight" - if weight_key not in state_dict: - continue + if state_dict is not None: + if weight_key not in state_dict: + continue + weight = state_dict[weight_key] + else: + # sd was released to save memory; fall back to the already-loaded parameter + weight = getattr(module, "weight", None) + if weight is None or weight.numel() == 0: + continue - in_features = state_dict[weight_key].shape[1] - out_features = state_dict[weight_key].shape[0] + in_features = weight.shape[1] + out_features = weight.shape[0] - is_gguf = isinstance(state_dict[weight_key], GGUFParameter) + is_gguf = isinstance(weight, GGUFParameter) scale_weight = None if not is_gguf and scale_weights is not None: @@ -86,42 +95,164 @@ def _replace_linear(model, compute_dtype, state_dict, prefix="", patches=None, s return model -def set_lora_params(module, patches, module_prefix="", device=torch.device("cpu")): +def set_lora_params(module, patches, module_prefix="", device=torch.device("cpu"), force_cpu=False, _depth=0, _diag=None): + """Apply LoRA patches to CustomLinear layers using a progressive approach. + + Instead of recursively iterating modules and holding all float32 LoRA tensors + in memory throughout, this function: + 1. Clears any previously-applied LoRA attributes + 2. Builds a key→CustomLinear map once via named_modules() + 3. Iterates patches dict, applies each patch immediately, then DELETES + the entry from patches to free float32 originals progressively + + This avoids the peak memory scenario where ALL float32 originals (~16 GB) + AND ALL bfloat16 copies (~8 GB) coexist simultaneously. + + Returns (lora_param_count, lora_total_bytes, module_count_matched). + When _diag is a dict, fills it with diagnostic counters: + _diag['customlinear_total'] = total CustomLinear modules found + _diag['customlinear_matched'] = number matched with a patch + _diag['customlinear_bytes'] = bytes of bfloat16 LoRA tensors stored + _diag['_key_mismatches'] = first 5 patch keys with no matching module + """ + import psutil, os as _os + _pid = _os.getpid() + def _rss_mb(): + try: + return psutil.Process(_pid).memory_info().rss / (1024 * 1024) + except Exception: + return 0.0 + + if _diag is None: + _diag = {} + + _rss0 = _rss_mb() + _ram_log.info(f"[RAM-diag] set_lora_params start | RSS: {_rss0:.1f} MB | {len(patches)} patch entries") + + # Step 1: Clear any previously applied LoRA attrs from all CustomLinear modules remove_lora_from_module(module) - # Recursively set lora_diffs and lora_strengths for all CustomLinear layers - for name, child in module.named_children(): - params = list(child.parameters()) - if params: - device = params[0].device - else: - device = torch.device("cpu") - child_prefix = (f"{module_prefix}{name}.") - set_lora_params(child, patches, child_prefix, device) - if isinstance(module, CustomLinear): - key = f"diffusion_model.{module_prefix}weight" - patch = patches.get(key, []) - #print(f"Processing LoRA patches for {key}: {len(patch)} patches found") - if len(patch) == 0: - key = key.replace("_orig_mod.", "") - patch = patches.get(key, []) - #print(f"Processing LoRA patches for {key}: {len(patch)} patches found") - if len(patch) != 0: - lora_diffs = [] - for p in patch: - lora_obj = p[1] - if "head" in key: - continue # For now skip LoRA for head layers - elif hasattr(lora_obj, "weights"): - lora_diffs.append(lora_obj.weights) - elif isinstance(lora_obj, tuple) and lora_obj[0] == "diff": - lora_diffs.append(lora_obj[1]) - else: - continue - lora_strengths = [p[0] for p in patch] - module.set_lora_diffs(lora_diffs, device=device) - module.set_lora_strengths(lora_strengths, device=device) - module._step.fill_(0) # Initialize step for LoRA scheduling + _rss1 = _rss_mb() + _ram_log.info(f"[RAM-diag] after remove_lora_from_module | RSS: {_rss1:.1f} MB (delta: {_rss1 - _rss0:+.1f} MB)") + + # Step 2: Build key→CustomLinear map once + # named_modules() returns (name, module) where name is the full dotted path + cl_map = {} + for mod_name, submodule in module.named_modules(): + if isinstance(submodule, CustomLinear): + # Construct key matching patcher.patches convention + key = f"diffusion_model.{mod_name}.weight" + cl_map[key] = submodule + # Also register the _orig_mod stripped variant (from torch.compile wrapping) + stripped = key.replace("_orig_mod.", "") + if stripped != key: + cl_map[stripped] = submodule + + _unique_cl = len(set(id(m) for _, m in cl_map.items())) + _rss2 = _rss_mb() + _ram_log.info(f"[RAM-diag] after building CL map ({len(cl_map)} key entries, {_unique_cl} unique modules) | RSS: {_rss2:.1f} MB") + + # Step 3a: Eject ALL unmatched patches FIRST — frees float32 tensors + # before any bfloat16 conversion starts, eliminating the + # intermediate RSS peak from ~650 leftover float32 patches. + unmatched_keys = [k for k in patches if k not in cl_map] + _diag.setdefault('_key_mismatches', []) + for k in unmatched_keys[:5]: + _diag['_key_mismatches'].append(k) + for k in unmatched_keys: + del patches[k] + unmatched_count = len(unmatched_keys) + if unmatched_count: + _rss_pre = _rss_mb() + gc.collect() + _rss_post = _rss_mb() + _ram_log.info( + f"[RAM-diag] ejected {unmatched_count} unmatched patches | " + f"RSS: {_rss_pre:.1f} → {_rss_post:.1f} MB (delta: {_rss_post - _rss_pre:+.1f} MB)" + ) + + # Step 3b: Process matched patches progressively + lora_param_count = 0 + lora_total_bytes = 0 + module_count_matched = 0 + total_patches = len(patches) + processed = 0 + + for key in list(patches.keys()): + cl_module = cl_map[key] + patch = patches[key] + + # Build LoRA diff list from patch entries. + # PRE-CONVERT to (CPU, compute_dtype) so set_lora_diffs sees + # already-matching tensors and returns self (zero-copy). + # Without explicit device="cpu", float32 CUDA tensors get + # converted to bf16-CUDA, then set_lora_diffs creates a SECOND + # copy via .to(cpu, bf16) — doubling RSS. + target_device = torch.device("cpu") + lora_diffs = [] + for p in patch: + lora_obj = p[1] + if "head" in key: + continue # Skip LoRA for head layers + elif hasattr(lora_obj, "weights"): + weights = lora_obj.weights + new_weights = tuple( + w.to(device=target_device, dtype=cl_module.compute_dtype) + if torch.is_tensor(w) else w + for w in weights + ) + lora_obj.weights = new_weights + lora_diffs.append(new_weights) + elif isinstance(lora_obj, tuple) and lora_obj[0] == "diff": + diffs = lora_obj[1] + new_diffs = tuple( + w.to(device=target_device, dtype=cl_module.compute_dtype) + if torch.is_tensor(w) else w + for w in diffs + ) + lora_diffs.append(new_diffs) + else: + continue + + if not lora_diffs: + del patches[key] + continue + + lora_strengths = [p[0] for p in patch] + diff_bytes = cl_module.set_lora_diffs(lora_diffs, device=device, _diag=_diag) + cl_module.set_lora_strengths(lora_strengths, device=device) + cl_module._step.fill_(0) + + module_count_matched += 1 + lora_total_bytes += diff_bytes + lora_param_count += len(lora_diffs) + + # IMMEDIATELY delete from patches dict to free float32 originals + del patches[key] + processed += 1 + + # Periodic GC to help OS reclaim freed pages + if processed % 50 == 0: + gc.collect() + _rss_now = _rss_mb() + _ram_log.info(f"[RAM-diag] progress {processed}/{total_patches} | RSS: {_rss_now:.1f} MB | accum bytes: {lora_total_bytes / (1024**2):.1f} MB") + # Final cleanup + gc.collect() + + _diag['customlinear_total'] = _unique_cl # unique CustomLinear modules, not key entries + _diag['customlinear_matched'] = module_count_matched + _diag['customlinear_bytes'] = lora_total_bytes + + _rss3 = _rss_mb() + _ram_log.info( + f"[RAM-diag] set_lora_params done | RSS: {_rss3:.1f} MB (delta: {_rss3 - _rss0:+.1f} MB) | " + f"total CL: {_unique_cl} | matched: {module_count_matched} | " + f"accum_bytes: {lora_total_bytes / (1024**2):.1f} MB | " + f"to_copies={_diag.get('to_copies', '?')} to_noops={_diag.get('to_noops', '?')} | " + f"d2_bytes={_diag.get('d2_total_bytes', 0) / (1024**2):.1f} MB (count={_diag.get('d2_count', 0)})" + ) + + return lora_param_count, lora_total_bytes, module_count_matched class CustomLinear(nn.Linear): def __init__( @@ -182,17 +313,54 @@ def _apply_single_lora_custom_op(self, weight, lora_diff, lora_strength): def _linear_forward_custom_op(self, input, weight, bias): return torch.ops.wanvideo.linear_forward(input, weight, bias) - def set_lora_diffs(self, lora_diffs, device=torch.device("cpu")): + def set_lora_diffs(self, lora_diffs, device=torch.device("cpu"), _diag=None): self.lora_diffs = [] + diff_bytes = 0 + if _diag is None: + _diag = {} for i, diff in enumerate(lora_diffs): if len(diff) > 1: - self.register_buffer(f"lora_diff_{i}_0", diff[0].to(device, self.compute_dtype)) - self.register_buffer(f"lora_diff_{i}_1", diff[1].to(device, self.compute_dtype)) - setattr(self, f"lora_diff_{i}_2", diff[2]) + d0_src = diff[0] + d1_src = diff[1] + d2_src = diff[2] + d0 = d0_src.to(device, self.compute_dtype) + d1 = d1_src.to(device, self.compute_dtype) + # Diagnostic: check if .to() created a copy + _diag.setdefault('to_copies', 0) + _diag.setdefault('to_noops', 0) + if d0 is d0_src: + _diag['to_noops'] += 1 + else: + _diag['to_copies'] += 1 + if d1 is d1_src: + _diag['to_noops'] += 1 + else: + _diag['to_copies'] += 1 + # Check diff[2] size + if torch.is_tensor(d2_src): + d2_bytes = d2_src.numel() * d2_src.element_size() + _diag.setdefault('d2_total_bytes', 0) + _diag['d2_total_bytes'] += d2_bytes + _diag.setdefault('d2_count', 0) + _diag['d2_count'] += 1 + setattr(self, f"lora_diff_{i}_0", d0) + setattr(self, f"lora_diff_{i}_1", d1) + setattr(self, f"lora_diff_{i}_2", d2_src) self.lora_diffs.append((f"lora_diff_{i}_0", f"lora_diff_{i}_1", f"lora_diff_{i}_2")) + diff_bytes += d0.numel() * d0.element_size() + d1.numel() * d1.element_size() else: - self.register_buffer(f"lora_diff_{i}_0", diff[0].to(device, self.compute_dtype)) + d0_src = diff[0] + d0 = d0_src.to(device, self.compute_dtype) + _diag.setdefault('to_copies', 0) + _diag.setdefault('to_noops', 0) + if d0 is d0_src: + _diag['to_noops'] += 1 + else: + _diag['to_copies'] += 1 + setattr(self, f"lora_diff_{i}_0", d0) self.lora_diffs.append(f"lora_diff_{i}_0") + diff_bytes += d0.numel() * d0.element_size() + return diff_bytes def set_lora_strengths(self, lora_strengths, device=torch.device("cpu")): self._lora_strength_tensors = [] @@ -201,11 +369,11 @@ def set_lora_strengths(self, lora_strengths, device=torch.device("cpu")): for i, strength in enumerate(lora_strengths): if isinstance(strength, list): tensor = torch.tensor(strength, dtype=self.compute_dtype, device=device) - self.register_buffer(f"_lora_strength_{i}", tensor) + setattr(self, f"_lora_strength_{i}", tensor) self._lora_strength_is_scheduled.append(True) else: tensor = torch.tensor([strength], dtype=self.compute_dtype, device=device) - self.register_buffer(f"_lora_strength_{i}", tensor) + setattr(self, f"_lora_strength_{i}", tensor) self._lora_strength_is_scheduled.append(False) def _get_lora_strength(self, idx): @@ -221,11 +389,17 @@ def _get_weight_with_lora(self, weight): for idx, lora_diff_names in enumerate(self.lora_diffs): lora_strength = self._get_lora_strength(idx) + if lora_strength.device != weight.device: + lora_strength = lora_strength.to(weight.device, weight.dtype) if isinstance(lora_diff_names, tuple): lora_diff_0 = getattr(self, lora_diff_names[0]) lora_diff_1 = getattr(self, lora_diff_names[1]) lora_diff_2 = getattr(self, lora_diff_names[2]) + if lora_diff_0.device != weight.device: + lora_diff_0 = lora_diff_0.to(weight.device, weight.dtype) + if lora_diff_1.device != weight.device: + lora_diff_1 = lora_diff_1.to(weight.device, weight.dtype) weight = self._apply_lora_impl( weight, lora_diff_0, lora_diff_1, @@ -233,6 +407,8 @@ def _get_weight_with_lora(self, weight): ) else: lora_diff = getattr(self, lora_diff_names) + if lora_diff.device != weight.device: + lora_diff = lora_diff.to(weight.device, weight.dtype) weight = self._apply_single_lora_impl(weight, lora_diff, lora_strength) return weight @@ -279,3 +455,10 @@ def remove_lora_from_module(module): delattr(submodule, f"lora_diff_{i}_1") if hasattr(submodule, f"lora_diff_{i}_2"): delattr(submodule, f"lora_diff_{i}_2") + # Clear strength tensors as well + i = 0 + while hasattr(submodule, f"_lora_strength_{i}"): + delattr(submodule, f"_lora_strength_{i}") + i += 1 + submodule.lora_diffs = [] + submodule._lora_strength_is_scheduled = [] diff --git a/gguf/gguf.py b/gguf/gguf.py index 4b168aab..3547982d 100644 --- a/gguf/gguf.py +++ b/gguf/gguf.py @@ -19,7 +19,7 @@ def load_gguf(model_path): def _replace_with_gguf_linear(model, compute_dtype, state_dict, prefix="", modules_to_not_convert=[], patches=None, compile_args=None): return _replace_linear(model, compute_dtype, state_dict, prefix, patches, None, compile_args, modules_to_not_convert) -def set_lora_params_gguf(module, patches, module_prefix="", device=torch.device("cpu")): - return set_lora_params(module, patches, module_prefix, device) +def set_lora_params_gguf(module, patches, module_prefix="", device=torch.device("cpu"), force_cpu=False, _diag=None): + return set_lora_params(module, patches, module_prefix, device, force_cpu, _diag=_diag) GGUFLinear = CustomLinear \ No newline at end of file diff --git a/multitalk/multitalk_loop.py b/multitalk/multitalk_loop.py index c40df11f..7b87fcb3 100644 --- a/multitalk/multitalk_loop.py +++ b/multitalk/multitalk_loop.py @@ -6,7 +6,7 @@ from ..latent_preview import prepare_callback from ..wanvideo.schedulers import get_scheduler from .multitalk import timestep_transform, add_noise -from ..utils import log, print_memory, temporal_score_rescaling, offload_transformer, init_blockswap, match_and_blend_colors +from ..utils import log, print_memory, temporal_score_rescaling, offload_transformer, init_blockswap, match_and_blend_colors, reopen_gguf_readers from comfy.utils import load_torch_file from ..nodes_model_loading import load_weights from ..HuMo.nodes import get_audio_emb_window @@ -22,6 +22,27 @@ device = mm.get_torch_device() offload_device = mm.unet_offload_device() +def _weights_load_signature(model, patcher, transformer, block_swap_args, device, weight_dtype, base_dtype, gguf_reader): + """Return a hashable signature describing the current weight-load configuration. + + Used to skip redundant full-model load_weights() calls when the model has + already been loaded with identical settings in a previous execution. + """ + try: + compile_args = model["compile_args"] + except KeyError: + compile_args = None + return ( + str(device), + str(weight_dtype), + str(base_dtype), + str(block_swap_args) if block_swap_args is not None else None, + str(compile_args) if compile_args is not None else None, + gguf_reader is not None, + getattr(transformer, "patched_linear", False), + len(patcher.patches), + ) + def multitalk_loop(self, **kwargs): # Unpack kwargs into local variables (latent, total_steps, steps, start_step, end_step, shift, cfg, denoise_strength, @@ -113,7 +134,7 @@ def multitalk_loop(self, **kwargs): try: silence_path = os.path.join(script_directory, "encoded_silence.safetensors") encoded_silence = load_torch_file(silence_path)["audio_emb"].to(dtype) - except Exception: + except: log.warning("No encoded silence file found, padding with end of audio embedding instead.") total_frames = len(audio_embedding[0]) @@ -343,14 +364,28 @@ def multitalk_loop(self, **kwargs): del motion_add_noise, add_latent if offloaded: - # Load weights - if transformer.patched_linear and gguf_reader is None: - load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args) - elif gguf_reader is not None: #handle GGUF - load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args) + # Load weights (only if configuration changed or weights were offloaded) + offloaded_sig = _weights_load_signature(model, patcher, transformer, block_swap_args, device, weight_dtype, dtype, gguf_reader) + offloaded_loaded = (getattr(transformer, "_wan_weights_load_signature", None) == offloaded_sig) and not getattr(transformer, "_wan_weights_offloaded", False) + if not offloaded_loaded: + if gguf_reader is not None: #handle GGUF + reopen_gguf_readers(gguf_reader) + load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args) + else: + load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args) + transformer._wan_weights_load_signature = offloaded_sig + transformer._wan_weights_offloaded = False #blockswap init init_blockswap(transformer, block_swap_args, model) + # Release the original state dict if still present (main sampler usually does this earlier, + # but keep it here for entry points that call multitalk_loop directly). + if patcher.model.get("sd") is not None: + log.info("MultiTalk: releasing patcher.model['sd'] to free memory") + patcher.model["sd"] = None + gc.collect() + mm.soft_empty_cache() + # Use the appropriate prompt for this section if len(text_embeds["prompt_embeds"]) > 1: prompt_index = min(iteration_count, len(text_embeds["prompt_embeds"]) - 1) @@ -564,6 +599,6 @@ def multitalk_loop(self, **kwargs): try: print_memory(device) torch.cuda.reset_peak_memory_stats(device) - except Exception: + except: pass return {"video": gen_video_samples.permute(1, 2, 3, 0), "output_path": output_path}, diff --git a/nodes.py b/nodes.py index f0b0e84d..378941bd 100644 --- a/nodes.py +++ b/nodes.py @@ -1,4 +1,5 @@ import os, gc, math +from collections import OrderedDict import torch import torch.nn.functional as F import hashlib @@ -148,8 +149,9 @@ def create_list(self, blocks): -# In-memory cache for prompt extender output -_extender_cache = {} +# In-memory cache for prompt extender output (LRU-bounded) +_MAX_EXTENDER_CACHE = 128 +_extender_cache = OrderedDict() cache_dir = os.path.join(script_directory, 'text_embed_cache') @@ -231,6 +233,7 @@ def process(self, model_name, precision, positive_prompt, negative_prompt, quant extender_key = (orig_prompt, str(extender_args)) if extender_key in _extender_cache: positive_prompt = _extender_cache[extender_key] + _extender_cache.move_to_end(extender_key) log.info(f"Loaded extended prompt from in-memory cache: {positive_prompt}") else: from .qwen.qwen import QwenLoader, WanVideoPromptExtender @@ -250,6 +253,9 @@ def process(self, model_name, precision, positive_prompt, negative_prompt, quant ) log.info(f"Extended positive prompt: {positive_prompt}") _extender_cache[extender_key] = positive_prompt + _extender_cache.move_to_end(extender_key) + while len(_extender_cache) > _MAX_EXTENDER_CACHE: + _extender_cache.popitem(last=False) del qwen pbar.update(1) diff --git a/nodes_model_loading.py b/nodes_model_loading.py index f5b7558e..09e9c0c5 100644 --- a/nodes_model_loading.py +++ b/nodes_model_loading.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn import os, gc, uuid -from .utils import log, apply_lora +from .utils import log, apply_lora, log_memory_peak, close_gguf_readers, log_ram_usage import numpy as np from tqdm import tqdm import re @@ -23,7 +23,7 @@ try: from .gguf.gguf import _replace_with_gguf_linear, GGUFParameter from gguf import GGMLQuantizationType -except Exception: +except: pass script_directory = os.path.dirname(os.path.abspath(__file__)) @@ -33,7 +33,7 @@ try: from server import PromptServer -except Exception: +except: PromptServer = None attention_modes = ["sdpa", "flash_attn_2", "flash_attn_3", "sageattn", "sageattn_3", "radial_sage_attention", "sageattn_compiled", @@ -105,9 +105,6 @@ def filter_state_dict_by_blocks(state_dict, blocks_mapping, layer_filter=[]): else: filtered_dict[key] = state_dict[key] - for key in filtered_dict: - print(key) - #from safetensors.torch import save_file #save_file(filtered_dict, "filtered_state_dict_2.safetensors") @@ -414,7 +411,7 @@ def getlorapath(self, lora, strength, unique_id, blocks={}, prev_lora=None, low_ try: lora_path = folder_paths.get_full_path_or_raise("loras", lora) - except Exception: + except: lora_path = lora # Load metadata from the safetensors file @@ -716,13 +713,16 @@ def load_lora_for_models_mod(model, lora, strength_model): key_map = model_lora_keys_unet(model.model, key_map) loaded = comfy.lora.load_lora(lora, key_map) + log_ram_usage(f"[RAM] load_lora_for_models_mod: After comfy.load_lora, {len(loaded)} loaded keys") new_modelpatcher = model.clone() + log_ram_usage(f"[RAM] load_lora_for_models_mod: After model.clone(), old patches={len(model.patches)}, new patches={len(new_modelpatcher.patches)}") k = add_patches(new_modelpatcher, loaded, strength_model) k = set(k) for x in loaded: if (x not in k): log.warning("NOT LOADED {}".format(x)) + log_ram_usage(f"[RAM] load_lora_for_models_mod: After add_patches, patches={len(new_modelpatcher.patches)}") return (new_modelpatcher) @@ -750,7 +750,9 @@ def setlora(self, model, lora=None): if lora is None: return (model,) + log_ram_usage("[RAM] WanVideoSetLoRAs.setlora START") patcher = model.clone() + log_ram_usage("[RAM] After model.clone()") merge_loras = False for l in lora: @@ -770,7 +772,11 @@ def setlora(self, model, lora=None): if lora_strength == 0: log.warning(f"LoRA {lora_path} has strength 0, skipping...") continue + log_ram_usage(f"[RAM] setlora: Before load_torch_file ({l['name']})") lora_sd = load_torch_file(lora_path, safe_load=True) + lora_bytes = sum(v.numel() * v.element_size() for v in lora_sd.values() if torch.is_tensor(v)) + log.info(f"[RAM] setlora: lora_sd size: {lora_bytes / (1024*1024):.1f} MB ({len(lora_sd)} tensors)") + log_ram_usage(f"[RAM] setlora: After load_torch_file ({l['name']})") if "dwpose_embedding.0.weight" in lora_sd: #unianimate raise NotImplementedError("Unianimate LoRA patching is not implemented in this node.") if "base_model.model.blocks.0.cross_attn.k.lora_A.weight" in lora_sd: # assume rs_lora @@ -787,10 +793,16 @@ def setlora(self, model, lora=None): if "diffusion_model.patch_embedding.lora_A.weight" in lora_sd: raise NotImplementedError("Control LoRA patching is not implemented in this node.") + log_ram_usage(f"[RAM] setlora: Before load_lora_for_models_mod ({l['name']})") patcher = load_lora_for_models_mod(patcher, lora_sd, lora_strength) + log_ram_usage(f"[RAM] setlora: After load_lora_for_models_mod ({l['name']}), patcher.patches={len(patcher.patches)}") + log_ram_usage(f"[RAM] setlora: Before del lora_sd ({l['name']})") del lora_sd + gc.collect() + log_ram_usage(f"[RAM] setlora: After del lora_sd + gc ({l['name']})") + log_ram_usage("[RAM] WanVideoSetLoRAs.setlora END") return (patcher,) def rename_fuser_block(name): @@ -812,56 +824,50 @@ def load_weights(transformer, sd=None, weight_dtype=None, base_dtype=None, pbar = ProgressBar(param_count) block_idx = vace_block_idx = None + # Initialize GGUF on-demand loading structures (used only when gguf=True, + # but defined here so the named_parameters loop can safely reference them). + extra_sd = {} + tensor_map = {} + if gguf: log.info("Using GGUF to load and assign model weights to device...") - # Prepare sd from GGUF readers - - # handle possible non-GGUF weights - extra_sd = {} - for key, value in sd.items(): - if value.device != torch.device("meta"): - extra_sd[key] = value + # Collect non-meta weights from the passed-in sd (if any). + # These are weights that were already loaded (not from GGUF reader). + if sd is not None: + for key, value in sd.items(): + if value.device != torch.device("meta"): + extra_sd[key] = value - sd = {} - all_tensors = [] + # Build name -> reader tensor mapping (NO data copies, just references). + # This avoids building a full sd dict with ALL tensor data copies in RAM + # simultaneously, which was the main cause of RAM exhaustion for large + # GGUF models (e.g. ~14GB for a 14B Q8_0 model). for r in reader: - all_tensors.extend(r.tensors) - for tensor in all_tensors: - name = rename_fuser_block(tensor.name) - if "glob" not in name and "multitalk_audio_proj" not in name and "audio_proj" in name: - name = name.replace("audio_proj", "multitalk_audio_proj") - load_device = device - if "vace_blocks." in name: - try: - vace_block_idx = int(name.split("vace_blocks.")[1].split(".")[0]) - except Exception: - vace_block_idx = None - elif "blocks." in name and "face" not in name: - try: - block_idx = int(name.split("blocks.")[1].split(".")[0]) - except Exception: - block_idx = None - - if block_swap_args is not None: - if block_idx is not None: - if block_idx >= len(transformer.blocks) - block_swap_args.get("blocks_to_swap", 0): - load_device = offload_device - elif vace_block_idx is not None: - if vace_block_idx >= len(transformer.vace_blocks) - block_swap_args.get("vace_blocks_to_swap", 0): - load_device = offload_device - - is_gguf_quant = tensor.tensor_type not in [GGMLQuantizationType.F32, GGMLQuantizationType.F16] - weights = torch.from_numpy(tensor.data.copy()).to(load_device) - sd[name] = GGUFParameter(weights, quant_type=tensor.tensor_type) if is_gguf_quant else weights - sd.update(extra_sd) - del all_tensors, extra_sd + for tensor in r.tensors: + name = rename_fuser_block(tensor.name) + if "glob" not in name and "multitalk_audio_proj" not in name and "audio_proj" in name: + name = name.replace("audio_proj", "multitalk_audio_proj") + tensor_map[name] = tensor + + # Use the original sd (meta tensors from load_gguf) + extra_sd for module + # replacement. _replace_linear only needs shapes and is_gguf flag, NOT + # actual weight data, so meta tensors are sufficient. + replacement_sd = {} + if sd is not None: + replacement_sd.update(sd) + replacement_sd.update(extra_sd) if not getattr(transformer, "gguf_patched", False): transformer = _replace_with_gguf_linear( - transformer, base_dtype, sd, patches=patcher.patches, compile_args=compile_args + transformer, base_dtype, replacement_sd, patches=patcher.patches, compile_args=compile_args ) transformer.gguf_patched = True + + del replacement_sd + # Set sd to None so the named_parameters loop uses on-demand loading + # from tensor_map instead of a pre-built sd dict. + sd = None else: log.info("Loading and assigning model weights to device...") named_params = transformer.named_parameters() @@ -890,7 +896,43 @@ def load_weights(transformer, sd=None, weight_dtype=None, base_dtype=None, continue key = name.replace("_orig_mod.", "") - value=sd[key] + + # Determine load_device early — needed for on-demand GGUF tensor loading + load_device = transformer_load_device + if block_swap_args is not None: + load_device = device + if block_idx is not None: + if block_idx >= len(transformer.blocks) - block_swap_args.get("blocks_to_swap", 0): + load_device = offload_device + elif vace_block_idx is not None: + if vace_block_idx >= len(transformer.vace_blocks) - block_swap_args.get("vace_blocks_to_swap", 0): + load_device = offload_device + + on_demand_value = None + if sd is None: + if gguf and key in tensor_map: + # On-demand load from GGUF reader: copy a single tensor's data + # and wrap as GGUFParameter. This avoids holding ALL tensor + # copies in RAM simultaneously (the previous approach built a + # full sd dict containing every tensor's data at once). + _tensor = tensor_map[key] + _is_gguf_quant = _tensor.tensor_type not in [GGMLQuantizationType.F32, GGMLQuantizationType.F16] + _weights = torch.from_numpy(_tensor.data.copy()).to(load_device) + on_demand_value = GGUFParameter(_weights, quant_type=_tensor.tensor_type) if _is_gguf_quant else _weights + del _weights + value = on_demand_value + elif key in extra_sd: + value = extra_sd[key] + elif param.data.device == torch.device("meta"): + log.warning(f"Parameter {name} is on meta device and sd is None; skipping reload") + continue + else: + value = param.data + else: + if key not in sd: + log.warning(f"Key {key} not found in sd; skipping") + continue + value = sd[key] keep_fp32 = ["patch_embedding", "motion_encoder", "condition_embedding"] if gguf: @@ -898,9 +940,17 @@ def load_weights(transformer, sd=None, weight_dtype=None, base_dtype=None, else: dtype_to_use = base_dtype if any(keyword in name for keyword in params_to_keep) else weight_dtype dtype_to_use = weight_dtype if value.dtype == weight_dtype else dtype_to_use - scale_key = key.replace(".weight", ".scale_weight") - if scale_key in sd: + + # Check scale_weight on the module itself so sd can be released after first load. + # CustomLinear (from _replace_linear) and fp8_optimization both store scale_weight + # as a module attribute, so we no longer need the original state dict. + if "." in name: + module = transformer.get_submodule(name.rsplit(".", 1)[0]) + else: + module = transformer + if getattr(module, "scale_weight", None) is not None: dtype_to_use = value.dtype + if "bias" in name or "img_emb" in name: dtype_to_use = base_dtype if any(k in name for k in keep_fp32): @@ -908,19 +958,29 @@ def load_weights(transformer, sd=None, weight_dtype=None, base_dtype=None, if "modulation" in name or "norm" in name: dtype_to_use = value.dtype if value.dtype == torch.float32 else base_dtype - load_device = transformer_load_device - if block_swap_args is not None: - load_device = device - if block_idx is not None: - if block_idx >= len(transformer.blocks) - block_swap_args.get("blocks_to_swap", 0): - load_device = offload_device - elif vace_block_idx is not None: - if vace_block_idx >= len(transformer.vace_blocks) - block_swap_args.get("vace_blocks_to_swap", 0): - load_device = offload_device # Set tensor to device set_module_tensor_to_device(transformer, name, device=load_device, dtype=dtype_to_use, value=value) pbar.update(1) + # Release on-demand loaded value immediately to keep peak RAM low. + # Only one tensor's data copy exists in RAM at any given time. + if on_demand_value is not None: + del on_demand_value + del value + + # For non-GGUF: release the source tensor from sd immediately after + # assignment. For GPU-bound tensors, the CPU copy in sd is now redundant + # (the data lives on GPU). For CPU-offloaded tensors (block swap), the + # module now holds the only reference, so no duplication occurs. + # This avoids the peak RAM double-copy where the full sd dict (~28GB + # for a 14B FP16 model) coexists with the model's parameters. + # Safe because: (1) each key is looked up exactly once in this loop, + # (2) all callers release sd right after load_weights returns, + # (3) apply_lora with low_mem_load=True is only used when load_weights + # is NOT called (they are mutually exclusive in the merge_loras path). + elif sd is not None and key in sd: + del sd[key] + #[print(name, param.device, param.dtype) for name, param in transformer.named_parameters()] memory_on_device = get_module_memory_mb_per_device(transformer) log.info("-" * 25) @@ -932,6 +992,30 @@ def load_weights(transformer, sd=None, weight_dtype=None, base_dtype=None, pbar._last_sent_value = -1 pbar.update_absolute(0) + # Release loading resources so no tensor data copies linger after the + # weights have been assigned to the model. + if gguf: + tensor_map.clear() + extra_sd.clear() + del tensor_map, extra_sd + # Close GGUF reader mmap to release system RAM. + # After all tensors are loaded, the entire GGUF file sits in the OS + # page cache (~10-15GB for large models). Closing the mmap releases + # those pages, reducing system RAM usage to just the actual model + # weights on CPU (~2GB for block-swapped layers). + # Readers are reopened on demand via reopen_gguf_readers() before + # weight reloads (block swap cycles). + close_gguf_readers(reader) + gc.collect() + mm.soft_empty_cache() + elif sd is not None: + # Non-GGUF: sd should be mostly empty now (entries popped during the + # loop), but clear any remaining keys (e.g. keys not matching any + # named_parameter) and force garbage collection. + sd.clear() + gc.collect() + mm.soft_empty_cache() + def patch_control_lora(transformer, device): log.info("Control-LoRA detected, patching model...") @@ -971,6 +1055,7 @@ def patch_stand_in_lora(transformer, lora_sd, transformer_load_device, base_dtyp param.data.copy_(lora_sd["diffusion_model." + name].to(param.device, dtype=param.dtype)) def add_lora_weights(patcher, lora, base_dtype, merge_loras=False): + log_ram_usage("[RAM] add_lora_weights START") unianimate_sd = None control_lora=False #spacepxl's control LoRA patch @@ -985,7 +1070,12 @@ def add_lora_weights(patcher, lora, base_dtype, merge_loras=False): if lora_strength == 0: log.warning(f"LoRA {lora_path} has strength 0, skipping...") continue + log_ram_usage(f"[RAM] Before load_torch_file: {l['name']}") lora_sd = load_torch_file(lora_path, safe_load=True) + log_ram_usage(f"[RAM] After load_torch_file ({l['name']}, {len(lora_sd)} keys)") + # Estimate lora_sd size + lora_bytes = sum(v.numel() * v.element_size() for v in lora_sd.values() if torch.is_tensor(v)) + log.info(f"[RAM] lora_sd estimated size: {lora_bytes / (1024*1024):.1f} MB ({len(lora_sd)} tensors)") if "dwpose_embedding.0.weight" in lora_sd: #unianimate from .unianimate.nodes import update_transformer log.info("Unianimate LoRA detected, patching model...") @@ -1008,9 +1098,16 @@ def add_lora_weights(patcher, lora, base_dtype, merge_loras=False): patch_stand_in_lora(patcher.model.diffusion_model, lora_sd, device, base_dtype, lora_strength) # normal LoRA patch else: + log_ram_usage(f"[RAM] Before load_lora_for_models ({l['name']})") patcher, _ = load_lora_for_models(patcher, None, lora_sd, lora_strength, 0) + log_ram_usage(f"[RAM] After load_lora_for_models ({l['name']}), patcher.patches now has {len(patcher.patches)} entries") + log_ram_usage(f"[RAM] Before del lora_sd ({l['name']})") del lora_sd + log_ram_usage(f"[RAM] After del lora_sd ({l['name']}) gc") + gc.collect() + log_ram_usage(f"[RAM] After gc.collect ({l['name']})") + log_ram_usage("[RAM] add_lora_weights END") return patcher, control_lora, unianimate_sd class WanVideoSetAttentionModeOverride: @@ -1151,7 +1248,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, try: if hasattr(torch.backends.cuda.matmul, "allow_fp16_accumulation"): torch.backends.cuda.matmul.allow_fp16_accumulation = False - except Exception: + except: pass @@ -1159,7 +1256,10 @@ def loadmodel(self, model, base_precision, load_device, quantization, gguf_reader = None if not gguf: - sd = load_torch_file(model_path, device=transformer_load_device, safe_load=True) + # Load the raw state dict to CPU first, then let load_weights() move + # individual tensors to the target device. This avoids a transient + # double-copy on the GPU (sd + transformer parameters) during load. + sd = load_torch_file(model_path, device=offload_device, safe_load=True) else: gguf_reader=[] from .gguf.gguf import load_gguf @@ -1238,7 +1338,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, else: if _model["path"].endswith(".gguf"): raise ValueError("With GGUF extra model the main model must also be GGUF quantized model") - extra_sd = load_torch_file(_model["path"], device=transformer_load_device, safe_load=True) + extra_sd = load_torch_file(_model["path"], device=offload_device, safe_load=True) if "audio_model.patch_embedding.0.weight" in extra_sd: extra_audio_model = True sd.update(extra_sd) @@ -1512,24 +1612,12 @@ def loadmodel(self, model, base_precision, load_device, quantization, block.cross_attn.ip_adapter_single_stream_v_proj = nn.Linear(context_dim, dim, bias=False) # LongCat Avatar - proj1_key = "multitalk_audio_proj.proj1.weight" if "multitalk_audio_proj.proj1.weight" in sd \ - else "multitalk_audio_proj.proj1.weight_int8" if "multitalk_audio_proj.proj1.weight_int8" in sd \ - else None - if proj1_key is not None and ("blocks.0.audio_cross_attn.q_norm.weight" in sd or "blocks.0.audio_cross_attn.q_norm.weight_int8" in sd): + if "multitalk_audio_proj.proj1.weight" in sd and "blocks.0.audio_cross_attn.q_norm.weight" in sd: log.info("MultiTalk/InfiniteTalk model detected, patching model...") from .multitalk.multitalk import AudioProjModel from .wanvideo.modules.model import WanLayerNorm from .LongCat.layers import SingleStreamAttention - # Detect LongCat-Avatar audio encoder variant from proj1 input dim: - # v1.0 (wav2vec2): seq_len * blocks * channels = 5 * 12 * 768 = 46080 - # v1.5 (whisper): seq_len * blocks * channels = 5 * 5 * 1280 = 32000 - proj1_in = sd[proj1_key].shape[1] - if proj1_in == 32000: - audio_proj_blocks, audio_proj_channels = 5, 1280 - log.info("LongCat-Avatar-1.5 (Whisper) audio proj detected") - else: - audio_proj_blocks, audio_proj_channels = 12, 768 for block in transformer.blocks: with init_empty_weights(): @@ -1546,7 +1634,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, class_interval=4, attention_mode=attention_mode, ) - multitalk_proj_model = AudioProjModel(blocks=audio_proj_blocks, channels=audio_proj_channels) + multitalk_proj_model = AudioProjModel() transformer.multitalk_audio_proj = multitalk_proj_model # SkyreelsV3 elif "blocks.1.audio_cross_attn.kv_linear.weight" in sd and "audio_proj.proj1.weight" in sd: @@ -1589,7 +1677,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, gguf_reader.append(extra_reader) del extra_reader else: - extra_sd_temp = load_torch_file(extra_model_path, device=transformer_load_device, safe_load=True) + extra_sd_temp = load_torch_file(extra_model_path, device=offload_device, safe_load=True) for k, v in extra_sd_temp.items(): extra_sd[k.replace("audio_proj.", "multitalk_audio_proj.")] = v @@ -1707,6 +1795,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, comfy_model.load_device = transformer_load_device patcher = comfy.model_patcher.ModelPatcher(comfy_model, device, offload_device) patcher.model.is_patched = False + log_ram_usage("[RAM] After ModelPatcher creation in loadmodel") scale_weights = {} if "fp8" in quantization: @@ -1729,7 +1818,9 @@ def loadmodel(self, model, base_precision, load_device, quantization, log.warning("Control-LoRA patching is only supported with merge_loras=True") if lora is not None: + log_ram_usage("[RAM] loadmodel: Before add_lora_weights") patcher, control_lora, unianimate_sd = add_lora_weights(patcher, lora, base_dtype, merge_loras=merge_loras) + log_ram_usage(f"[RAM] loadmodel: After add_lora_weights, patcher.patches={len(patcher.patches)}") if unianimate_sd is not None: log.info("Merging UniAnimate weights to the model...") sd.update(unianimate_sd) @@ -1739,6 +1830,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, if lora is not None and merge_loras: if not lora_low_mem_load: load_weights(transformer, sd, weight_dtype, base_dtype, transformer_load_device) + log_memory_peak("After model load_weights (merge_loras path)", device=transformer_load_device, reset_peak=True) if control_lora: patch_control_lora(patcher.model.diffusion_model, device) @@ -1762,6 +1854,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, raise NotImplementedError("fp8_fast is not supported with unmerged LoRAs") from .fp8_optimization import convert_fp8_linear convert_fp8_linear(transformer, base_dtype, params_to_keep, scale_weight_keys=scale_weights) + log_memory_peak("After FP8 conversion", device=device, reset_peak=True) if vram_management_args is not None: if gguf: @@ -1805,16 +1898,13 @@ def loadmodel(self, model, base_precision, load_device, quantization, ), compile_args = compile_args, ) + log_memory_peak("After vram_management", device=device, reset_peak=True) if merge_loras and lora is not None: - # Skip offloading if load_device is main_device (for unified memory systems like AMD Strix Halo) - if load_device != "main_device": - log.info(f"Moving diffusion model from {patcher.model.diffusion_model.device} to {offload_device}") - patcher.model.diffusion_model.to(offload_device) - gc.collect() - mm.soft_empty_cache() - else: - log.info(f"Skipping offload (load_device=main_device, keeping model on {patcher.model.diffusion_model.device})") + log.info(f"Moving diffusion model from {patcher.model.diffusion_model.device} to {offload_device}") + patcher.model.diffusion_model.to(offload_device) + gc.collect() + mm.soft_empty_cache() patcher.model["base_dtype"] = base_dtype patcher.model["weight_dtype"] = weight_dtype @@ -1835,9 +1925,7 @@ def loadmodel(self, model, base_precision, load_device, quantization, patcher.model_options["transformer_options"]["block_swap_args"] = block_swap_args patcher.model_options["transformer_options"]["merge_loras"] = merge_loras - for model in mm.current_loaded_models: - if model._model() == patcher: - mm.current_loaded_models.remove(model) + log_memory_peak("Model loader return", device=device, reset_peak=True) return (patcher,) # class WanVideoSaveModel: diff --git a/nodes_sampler.py b/nodes_sampler.py index 32b51c6a..f0b5c735 100644 --- a/nodes_sampler.py +++ b/nodes_sampler.py @@ -9,7 +9,7 @@ from .gguf.gguf import set_lora_params_gguf from .multitalk.multitalk import add_noise from .utils import(log, print_memory, apply_lora, fourier_filter, optimized_scale, setup_radial_attention, - compile_model, dict_to_device, tangential_projection, get_raag_guidance, temporal_score_rescaling, offload_transformer, init_blockswap) + compile_model, dict_to_device, tangential_projection, get_raag_guidance, temporal_score_rescaling, offload_transformer, init_blockswap, log_memory_peak, reopen_gguf_readers, log_ram_usage) from .multitalk.multitalk_loop import multitalk_loop from .cache_methods.cache_methods import cache_report from .nodes_model_loading import load_weights @@ -32,6 +32,28 @@ PATCH_SIZE = (1, 2, 2) +def _weights_load_signature(model, patcher, transformer, block_swap_args, device, weight_dtype, base_dtype, gguf_reader): + """Return a hashable signature describing the current weight-load configuration. + + Used to skip redundant full-model load_weights() calls when the model has + already been loaded with identical settings in a previous execution. + """ + try: + compile_args = model["compile_args"] + except KeyError: + compile_args = None + return ( + str(device), + str(weight_dtype), + str(base_dtype), + str(block_swap_args) if block_swap_args is not None else None, + str(compile_args) if compile_args is not None else None, + gguf_reader is not None, + getattr(transformer, "patched_linear", False), + len(patcher.patches), + ) + + class WanVideoSampler: @classmethod def INPUT_TYPES(s): @@ -83,6 +105,7 @@ def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, rifle experimental_args=None, sigmas=None, unianimate_poses=None, fantasytalking_embeds=None, uni3c_embeds=None, multitalk_embeds=None, freeinit_args=None, start_step=0, end_step=-1, add_noise_to_samples=False): if flowedit_args is not None: raise Exception("FlowEdit support has been deprecated and removed due to lack of use and code maintainability") + log_ram_usage("[RAM] Sampler process START") patcher = model model = model.model transformer = model.diffusion_model @@ -118,26 +141,85 @@ def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, rifle if hasattr(block, 'audio_block'): block.audio_block = None - if not transformer.patched_linear and patcher.model["sd"] is not None and len(patcher.patches) != 0 and gguf_reader is None: - transformer = _replace_linear(transformer, dtype, patcher.model["sd"], compile_args=model["compile_args"]) + if not transformer.patched_linear and len(patcher.patches) != 0 and gguf_reader is None: + transformer = _replace_linear(transformer, dtype, patcher.model["sd"], patches=patcher.patches, scale_weights=patcher.model.get("scale_weights", None), compile_args=model["compile_args"]) transformer.patched_linear = True - if patcher.model["sd"] is not None and gguf_reader is None: - load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, - block_swap_args=block_swap_args, compile_args=model["compile_args"]) + + load_sig = _weights_load_signature(model, patcher, transformer, block_swap_args, device, weight_dtype, dtype, gguf_reader) + weights_offloaded = getattr(transformer, "_wan_weights_offloaded", False) + weights_already_loaded = (getattr(transformer, "_wan_weights_load_signature", None) == load_sig) and not weights_offloaded + if not weights_already_loaded: + if gguf_reader is not None: #handle GGUF + log_ram_usage("[RAM] Before reopen + load_weights (GGUF)") + reopen_gguf_readers(gguf_reader) # Reopen mmap if closed from prior load + load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, + reader=gguf_reader, block_swap_args=block_swap_args, compile_args=model["compile_args"]) + else: + load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, + block_swap_args=block_swap_args, compile_args=model["compile_args"]) + transformer._wan_weights_load_signature = load_sig + transformer._wan_weights_offloaded = False + log_memory_peak("After load_weights", device=device, reset_peak=True) + else: + log.info("WanVideoSampler: transformer weights already loaded with matching configuration, skipping load_weights") + + # Release the original state dict to free memory now that weights are loaded + # into the transformer. Subsequent reloads can fall back to param.data. + if patcher.model["sd"] is not None: + log_ram_usage("[RAM] Before sd release") + log.info("WanVideoSampler: releasing patcher.model['sd'] to free memory") + patcher.model["sd"] = None + gc.collect() + mm.soft_empty_cache() + log_memory_peak("After sd release", device=device, reset_peak=True) + log_ram_usage("[RAM] After sd release, before LoRA setup") if gguf_reader is not None: #handle GGUF - load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, - reader=gguf_reader, block_swap_args=block_swap_args, compile_args=model["compile_args"]) - set_lora_params_gguf(transformer, patcher.patches) + log.info(f"[RAM] patcher.patches has {len(patcher.patches)} entries before set_lora_params_gguf") + # Show first 5 patcher.patches keys for diagnostic comparison + _sample_keys = list(patcher.patches.keys())[:5] + log.info(f"[RAM-diag] patcher.patches sample keys: {_sample_keys}") + log_ram_usage("[RAM] Before set_lora_params_gguf") + _diag = {} + lora_count, lora_bytes, lora_modules = set_lora_params_gguf( + transformer, patcher.patches, force_cpu=True, _diag=_diag + ) + _mismatches = _diag.get('_key_mismatches', []) + log.info( + f"[RAM] set_lora_params_gguf: {lora_count} params, " + f"{lora_bytes / (1024*1024):.1f} MB, " + f"{lora_modules} CustomLinear matched " + f"(total CL: {_diag.get('customlinear_total', '?')}, " + f"matched: {_diag.get('customlinear_matched', '?')}, " + f"bytes: {_diag.get('customlinear_bytes', 0) / (1024*1024):.1f} MB)" + ) + if _mismatches: + log.info(f"[RAM-diag] First unmatched CL keys (up to 5): {_mismatches}") + log_ram_usage("[RAM] After set_lora_params_gguf") transformer.patched_linear = True elif len(patcher.patches) != 0: #handle patched linear layers (unmerged loras, fp8 scaled) log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model") if not merge_loras and fp8_matmul: raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported") - set_lora_params(transformer, patcher.patches) + log_ram_usage("[RAM] Before set_lora_params (non-GGUF)") + set_lora_params(transformer, patcher.patches, force_cpu=True) + log_ram_usage("[RAM] After set_lora_params (non-GGUF)") else: remove_lora_from_module(transformer) #clear possible unmerged lora weights + # Free the original LoRA tensors in patcher.patches now that + # set_lora_params has copied all diffs into CustomLinear module + # attributes. Without this, patcher.patches holds ~1-2GB of LoRA + # tensor data in CPU RAM that is duplicated by the module copies. + if len(patcher.patches) > 0: + log_ram_usage("[RAM] Before patcher.patches.clear()") + patcher.patches.clear() + gc.collect() + mm.soft_empty_cache() + log_ram_usage("[RAM] After patcher.patches.clear() + gc") + + log_memory_peak("After LoRA setup", device=device, reset_peak=True) + transformer.lora_scheduling_enabled = transformer_options.get("lora_scheduling_enabled", False) #torch.compile @@ -564,7 +646,6 @@ def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, rifle # MultiTalk multitalk_audio_embeds = audio_emb_slice = audio_features_in = None - multitalk_audio_stride = None multitalk_embeds = image_embeds.get("multitalk_embeds", multitalk_embeds) if multitalk_embeds is not None: @@ -585,7 +666,6 @@ def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, rifle audio_scale = multitalk_embeds.get("audio_scale", 1.0) audio_cfg_scale = multitalk_embeds.get("audio_cfg_scale", 1.0) ref_target_masks = multitalk_embeds.get("ref_target_masks", None) - multitalk_audio_stride = multitalk_embeds.get("audio_stride", None) if not isinstance(audio_cfg_scale, list): audio_cfg_scale = [audio_cfg_scale] * (steps + 1) @@ -819,11 +899,7 @@ def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, rifle latent_video_length += insert_len longcat_num_cond_latents = len(clean_latent_indices) log.info(f"LongCat num_cond_latents: {longcat_num_cond_latents} num_ref_latents: {longcat_num_ref_latents}") - # v1.5 (Whisper) embeds set audio_stride=1; v1.0 (wav2vec2) uses 2 for LongCat - if multitalk_audio_stride is not None: - audio_stride = multitalk_audio_stride - else: - audio_stride = 2 if transformer.is_longcat else 1 + audio_stride = 2 if transformer.is_longcat else 1 #controlnet controlnet_latents = controlnet = None @@ -877,9 +953,11 @@ def process(self, model, image_embeds, shift, steps, cfg, seed, scheduler, rifle mm.unload_all_models() mm.soft_empty_cache() gc.collect() + log_memory_peak("Before init_blockswap", device=device, reset_peak=True) #blockswap init init_blockswap(transformer, block_swap_args, model) + log_memory_peak("After init_blockswap", device=device, reset_peak=True) # Initialize Cache if enabled previous_cache_states = None @@ -1185,6 +1263,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i return z*0, None nonlocal patcher + log_memory_peak(f"predict_with_cfg step {idx} start", device=device, reset_peak=True) current_step_percentage = idx / len(timesteps) control_lora_enabled = False image_cond_input = None @@ -1736,7 +1815,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i gc.collect() try: torch.cuda.reset_peak_memory_stats(device) - except Exception: + except: pass # Main sampling loop with FreeInit iterations @@ -1799,6 +1878,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i pbar = ProgressBar(len(timesteps) - ttm_start_step) #region main loop start for idx, t in enumerate(tqdm(timesteps[ttm_start_step:], disable=multitalk_sampling or wananimate_loop)): + log_memory_peak(f"Step {idx}/{len(timesteps)} start", device=device, reset_peak=True) if bidirectional_sampling: latent_flipped = torch.flip(latent, dims=[1]) @@ -2188,7 +2268,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i try: print_memory(device) torch.cuda.reset_peak_memory_stats(device) - except Exception: + except: pass return {"video": gen_video_samples}, # region wananimate loop @@ -2212,11 +2292,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i bg_images = image_embeds.get("bg_images", None) pose_images = image_embeds.get("pose_images", None) - current_ref_images = image_embeds.get("start_ref_image", None) - if current_ref_images is not None: - log.info( - "WanAnimate: Detected manual start reference image, enabling continuous generation across windows.") - face_images = face_images_in = None + current_ref_images = face_images = face_images_in = None if wananim_face_pixels is not None: face_images = tensor_pingpong_pad(wananim_face_pixels, target_len) @@ -2255,10 +2331,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i mm.soft_empty_cache() - if current_ref_images is not None: - mask_reft_len = refert_num - else: - mask_reft_len = 0 if start == 0 else refert_num + mask_reft_len = 0 if start == 0 else refert_num self.cache_state = [None, None] @@ -2367,11 +2440,17 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i latent = noise if offloaded: - # Load weights - if transformer.patched_linear and gguf_reader is None: - load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args) - elif gguf_reader is not None: #handle GGUF + # Load weights (only if configuration changed or weights were offloaded) + offloaded_sig = _weights_load_signature(model, patcher, transformer, block_swap_args, device, weight_dtype, dtype, gguf_reader) + offloaded_loaded = (getattr(transformer, "_wan_weights_load_signature", None) == offloaded_sig) and not getattr(transformer, "_wan_weights_offloaded", False) + if not offloaded_loaded: + if gguf_reader is not None: #handle GGUF + reopen_gguf_readers(gguf_reader) load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args) + else: + load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args) + transformer._wan_weights_load_signature = offloaded_sig + transformer._wan_weights_offloaded = False #blockswap init init_blockswap(transformer, block_swap_args, model) @@ -2437,7 +2516,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i videos = vae.decode(latent[:, 1:].unsqueeze(0).to(device, vae.dtype), device=device, tiled=tiled_vae, pbar=False)[0].cpu() del latent - if start != 0 or current_ref_images is not None: + if start != 0: videos = videos[:, refert_num:] sampling_pbar.close() @@ -2489,7 +2568,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i try: print_memory(device) torch.cuda.reset_peak_memory_stats(device) - except Exception: + except: pass return {"video": gen_video_samples.permute(1, 2, 3, 0), "output_path": output_path}, @@ -2503,6 +2582,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i humo_image_cond=humo_image_cond, humo_image_cond_neg=humo_image_cond_neg, humo_audio=humo_audio, humo_audio_neg=humo_audio_neg, wananim_face_pixels=wananim_face_pixels, wananim_pose_latents=wananim_pose_latents, uni3c_data = uni3c_data, latent_model_input_ovi=latent_model_input_ovi, flashvsr_LQ_latent=flashvsr_LQ_latent, ) + log_memory_peak(f"predict_with_cfg step {idx} return", device=device, reset_peak=False) if bidirectional_sampling: noise_pred_flipped, _,self.cache_state = predict_with_cfg( latent_model_input_flipped, @@ -2596,13 +2676,14 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i callback(idx, callback_latent.permute(1,0,2,3), None, len(timesteps)) else: pbar.update(1) + log_memory_peak(f"Step {idx}/{len(timesteps)} end", device=device, reset_peak=False) except Exception as e: log.error(f"Error during sampling: {e}") - raise - finally: - if force_offload and not model["auto_cpu_offload"]: - offload_transformer(transformer) + if force_offload: + if not model["auto_cpu_offload"]: + offload_transformer(transformer) + raise e if phantom_latents is not None: latent = latent[:,:-phantom_latents.shape[1]] @@ -2626,10 +2707,14 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i "magcache_state": transformer.magcache_state, } + if force_offload: + if not model["auto_cpu_offload"]: + offload_transformer(transformer) + try: print_memory(device) torch.cuda.reset_peak_memory_stats(device) - except Exception: + except: pass return ({ "samples": latent.unsqueeze(0).cpu(), @@ -2773,7 +2858,7 @@ def process(self, scheduler, steps, start_step, end_step, shift, unique_id, sigm import io import base64 import matplotlib.pyplot as plt - except Exception: + except: PromptServer = None if unique_id and PromptServer is not None: try: diff --git a/skyreels/nodes.py b/skyreels/nodes.py index a188edea..4be9fa70 100644 --- a/skyreels/nodes.py +++ b/skyreels/nodes.py @@ -1,7 +1,7 @@ import os import torch import gc -from ..utils import log, print_memory, fourier_filter, optimized_scale, setup_radial_attention, compile_model +from ..utils import log, print_memory, fourier_filter, optimized_scale, setup_radial_attention, compile_model, reopen_gguf_readers import math from tqdm import tqdm @@ -19,7 +19,7 @@ from comfy.utils import ProgressBar from comfy.cli_args import args, LatentPreviewMethod from ..nodes_model_loading import load_weights -from ..nodes_sampler import offload_transformer, init_blockswap +from ..nodes_sampler import offload_transformer, init_blockswap, _weights_load_signature from ..custom_linear import remove_lora_from_module, set_lora_params, _replace_linear device = mm.get_torch_device() @@ -171,24 +171,48 @@ def process(self, model, text_embeds, image_embeds, shift, fps, steps, addnoise_ vae_upscale_factor = 16 if is_5b else 8 # Load weights - if not transformer.patched_linear and patcher.model["sd"] is not None and len(patcher.patches) != 0: - transformer = _replace_linear(transformer, dtype, patcher.model["sd"], compile_args=model["compile_args"]) + if not transformer.patched_linear and len(patcher.patches) != 0: + transformer = _replace_linear(transformer, dtype, patcher.model["sd"], patches=patcher.patches, scale_weights=patcher.model.get("scale_weights", None), compile_args=model["compile_args"]) transformer.patched_linear = True - if patcher.model["sd"] is not None and gguf_reader is None: - load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args) + + load_sig = _weights_load_signature(model, patcher, transformer, block_swap_args, device, weight_dtype, dtype, gguf_reader) + weights_offloaded = getattr(transformer, "_wan_weights_offloaded", False) + weights_already_loaded = (getattr(transformer, "_wan_weights_load_signature", None) == load_sig) and not weights_offloaded + if not weights_already_loaded: + if gguf_reader is not None: #handle GGUF + reopen_gguf_readers(gguf_reader) + load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args) + else: + load_weights(patcher.model.diffusion_model, patcher.model["sd"], weight_dtype, base_dtype=dtype, transformer_load_device=device, block_swap_args=block_swap_args) + transformer._wan_weights_load_signature = load_sig + transformer._wan_weights_offloaded = False + else: + log.info("SkyReels: transformer weights already loaded with matching configuration, skipping load_weights") + + if patcher.model["sd"] is not None: + log.info("SkyReels: releasing patcher.model['sd'] to free memory") + patcher.model["sd"] = None + gc.collect() + mm.soft_empty_cache() if gguf_reader is not None: #handle GGUF - load_weights(transformer, patcher.model["sd"], base_dtype=dtype, transformer_load_device=device, patcher=patcher, gguf=True, reader=gguf_reader, block_swap_args=block_swap_args) - set_lora_params_gguf(transformer, patcher.patches) + set_lora_params_gguf(transformer, patcher.patches, force_cpu=True) transformer.patched_linear = True elif len(patcher.patches) != 0: #handle patched linear layers (unmerged loras, fp8 scaled) log.info(f"Using {len(patcher.patches)} LoRA weight patches for WanVideo model") if not merge_loras and fp8_matmul: raise NotImplementedError("FP8 matmul with unmerged LoRAs is not supported") - set_lora_params(transformer, patcher.patches) + set_lora_params(transformer, patcher.patches, force_cpu=True) else: remove_lora_from_module(transformer) #clear possible unmerged lora weights + # Free the original LoRA tensors in patcher.patches now that + # set_lora_params has copied all diffs into CustomLinear modules. + if len(patcher.patches) > 0: + patcher.patches.clear() + gc.collect() + mm.soft_empty_cache() + transformer.lora_scheduling_enabled = transformer_options.get("lora_scheduling_enabled", False) #torch.compile @@ -548,7 +572,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i gc.collect() try: torch.cuda.reset_peak_memory_stats(device) - except Exception: + except: pass #region main loop start @@ -615,7 +639,7 @@ def predict_with_cfg(z, cfg_scale, positive_embeds, negative_embeds, timestep, i try: print_memory(device) torch.cuda.reset_peak_memory_stats(device) - except Exception: + except: pass return ({ diff --git a/utils.py b/utils.py index e2bf1844..20c077d9 100644 --- a/utils.py +++ b/utils.py @@ -7,18 +7,16 @@ import gc import types, collections from comfy.utils import ProgressBar, copy_to_param, set_attr_param -from comfy.model_patcher import get_key_weight +from comfy.model_patcher import get_key_weight, string_to_seed from comfy.lora import calculate_weight -try: - from comfy.utils import string_to_seed -except Exception: - from comfy.model_patcher import string_to_seed - from comfy.float import stochastic_rounding from .custom_linear import remove_lora_from_module import folder_paths -logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +# Only configure logging if the root logger has no handlers, to avoid overriding +# ComfyUI's or other extensions' logging setup. +if not logging.getLogger().handlers: + logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') log = logging.getLogger(__name__) import comfy.model_management as mm @@ -27,7 +25,7 @@ try: from .gguf.gguf import GGUFParameter -except Exception: +except: pass COLOR_CODES = { @@ -60,6 +58,8 @@ def offload_transformer(transformer, remove_lora=True): transformer.easycache_state.clear_all() if transformer.patched_linear: + # Offload to CPU instead of meta, so weights can be reloaded from param.data + # when patcher.model["sd"] has been released. for name, param in transformer.named_parameters(): if "loras" in name or "controlnet" in name: continue @@ -69,11 +69,12 @@ def offload_transformer(transformer, remove_lora=True): module = getattr(module, subname) attr_name = subnames[-1] if param.data.is_floating_point(): - meta_param = torch.nn.Parameter(torch.empty_like(param.data, device='meta'), requires_grad=False) - setattr(module, attr_name, meta_param) + offloaded_param = torch.nn.Parameter(param.data.to(offload_device), requires_grad=False) + setattr(module, attr_name, offloaded_param) elif isinstance(param.data, GGUFParameter): quant_type = getattr(param, 'quant_type', None) - setattr(module, attr_name, MetaParameter(param.data.dtype, quant_type)) + offloaded_param = GGUFParameter(param.data.to(offload_device), quant_type=quant_type) + setattr(module, attr_name, offloaded_param) else: pass if remove_lora: @@ -86,38 +87,38 @@ def offload_transformer(transformer, remove_lora=True): if transformer.audio_model is not None and hasattr(block, 'audio_block'): block.audio_block = None + transformer._wan_weights_offloaded = True mm.soft_empty_cache() gc.collect() def init_blockswap(transformer, block_swap_args, model): - if not transformer.patched_linear: - if block_swap_args is not None: - for name, param in transformer.named_parameters(): - if "block" not in name or "control_adapter" in name or "face" in name: - param.data = param.data.to(device) - elif block_swap_args["offload_txt_emb"] and "txt_emb" in name: - param.data = param.data.to(offload_device) - elif block_swap_args["offload_img_emb"] and "img_emb" in name: - param.data = param.data.to(offload_device) - - transformer.block_swap( - block_swap_args["blocks_to_swap"] - 1 , - block_swap_args["offload_txt_emb"], - block_swap_args["offload_img_emb"], - vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None), - ) - elif model["auto_cpu_offload"]: - for module in transformer.modules(): - if hasattr(module, "offload"): - module.offload() - if hasattr(module, "onload"): - module.onload() - for block in transformer.blocks: - block.modulation = torch.nn.Parameter(block.modulation.to(device)) - transformer.head.modulation = torch.nn.Parameter(transformer.head.modulation.to(device)) - else: - transformer.to(device) + if block_swap_args is not None: + for name, param in transformer.named_parameters(): + if "block" not in name or "control_adapter" in name or "face" in name: + param.data = param.data.to(device) + elif block_swap_args["offload_txt_emb"] and "txt_emb" in name: + param.data = param.data.to(offload_device) + elif block_swap_args["offload_img_emb"] and "img_emb" in name: + param.data = param.data.to(offload_device) + + transformer.block_swap( + block_swap_args["blocks_to_swap"] - 1 , + block_swap_args["offload_txt_emb"], + block_swap_args["offload_img_emb"], + vace_blocks_to_swap = block_swap_args.get("vace_blocks_to_swap", None), + ) + elif model["auto_cpu_offload"]: + for module in transformer.modules(): + if hasattr(module, "offload"): + module.offload() + if hasattr(module, "onload"): + module.onload() + for block in transformer.blocks: + block.modulation = torch.nn.Parameter(block.modulation.to(device)) + transformer.head.modulation = torch.nn.Parameter(transformer.head.modulation.to(device)) + elif not transformer.patched_linear: + transformer.to(device) def check_device_same(first_device, second_device): if first_device.type != second_device.type: @@ -220,6 +221,35 @@ def print_memory(device, process="Sampling"): #memory_summary = torch.cuda.memory_summary(device=device, abbreviated=False) #log.info(f"Memory Summary:\n{memory_summary}") + +def log_ram_usage(tag="", gc_collect=False): + """Print process RAM (RSS) usage in MB. Optionally triggers gc first.""" + if gc_collect: + gc.collect() + try: + import psutil + proc = psutil.Process() + rss_mb = proc.memory_info().rss / (1024 * 1024) + vms_mb = proc.memory_info().vms / (1024 * 1024) + log.info(f"[RAM] {tag} | RSS: {rss_mb:.1f} MB | VMS: {vms_mb:.1f} MB") + except ImportError: + log.info(f"[RAM] {tag} | psutil not available, skipping RAM log") + + +def log_memory_peak(tag="", device=None, reset_peak=False): + """Print current and peak VRAM usage in GB.""" + if device is None: + device = mm.get_torch_device() + if not torch.cuda.is_available() or "cuda" not in str(device): + log.info(f"[Mem] {tag} | device={device} (non-CUDA)") + return + allocated = torch.cuda.memory_allocated(device) / 1024**3 + reserved = torch.cuda.memory_reserved(device) / 1024**3 + peak = torch.cuda.max_memory_allocated(device) / 1024**3 + log.info(f"[Mem] {tag} | Allocated: {allocated:.3f} GB | Reserved: {reserved:.3f} GB | Peak: {peak:.3f} GB") + if reset_peak: + torch.cuda.reset_peak_memory_stats(device) + def get_module_memory_mb(module): memory = 0 for param in module.parameters(): @@ -229,12 +259,10 @@ def get_module_memory_mb(module): def get_module_memory_mb_per_device(module): memory_per_device = {} - memory = 0 for param in module.parameters(): if param.data is not None: device = str(param.device) - memory += param.nelement() * param.element_size() - memory_per_device[device] = memory_per_device.get(device, 0) + memory + memory_per_device[device] = memory_per_device.get(device, 0) + param.nelement() * param.element_size() memory_per_device = {dev: mem / (1024 * 1024) for dev, mem in memory_per_device.items()} return memory_per_device @@ -309,7 +337,7 @@ def apply_lora(model, device_to, transformer_load_device, params_to_keep=None, d key = f"{name.replace('diffusion_model.', '')}.{param}" try: set_module_tensor_to_device(model.model.diffusion_model, key, device=transformer_load_device, dtype=dtype_to_use, value=state_dict[key]) - except Exception: + except: continue key = f"{name}.{param}" if scale_weights is not None: @@ -323,7 +351,7 @@ def apply_lora(model, device_to, transformer_load_device, params_to_keep=None, d if low_mem_load: try: set_module_tensor_to_device(model.model.diffusion_model, key, device=transformer_load_device, dtype=dtype_to_use, value=model.model.diffusion_model.state_dict()[key]) - except Exception: + except: continue m.comfy_patched_weights = True cnt += 1 @@ -352,7 +380,7 @@ def apply_lora(model, device_to, transformer_load_device, params_to_keep=None, d dtype_to_use = torch.float32 try: set_module_tensor_to_device(model.model.diffusion_model, name, device=transformer_load_device, dtype=dtype_to_use, value=state_dict[name]) - except Exception: + except: continue return model @@ -703,9 +731,8 @@ def check_duplicate_nodes(): # Check all directories in custom_nodes for path in custom_nodes_dir.iterdir(): - if (path.is_dir() and + if (path.is_dir() and path != current_path and - not path.name.endswith('.disabled') and 'wanvideo' in path.name.lower() and 'wrapper' in path.name.lower()): wanvideo_dirs.append(str(path)) @@ -776,3 +803,89 @@ def match_and_blend_colors( # [0,1] -> [-1,1] return (blended_rgb_01 * 2.0 - 1.0)[0].to(dtype=input_dtype) + + +# --------------------------------------------------------------------------- +# GGUF reader mmap lifecycle management +# --------------------------------------------------------------------------- +# The GGUFReader uses np.memmap to access GGUF files. During load_weights(), +# every tensor is read via tensor.data.copy(), which touches every page of +# the mmap'd file. The OS caches the entire file in the process working set +# (system RAM). For a 14B Q8 model, this adds ~14GB of system RAM usage that +# persists as long as the reader's mmap is alive. +# +# These helpers close the mmap after loading to release system RAM, and +# reopen the reader on demand for weights reload (block swap cycles). + +def close_gguf_readers(gguf_reader): + """Close memory-mapped files backing GGUF readers to release system RAM. + + Saves each reader's original file path so it can be reopened later via + reopen_gguf_readers(). After closing, the OS releases mmap pages from + the process working set, dramatically reducing system RAM usage (often + 10-15GB for large GGUF files like Wan2.1 14B Q8). + + Safe to call on None or an empty list. + """ + if gguf_reader is None or not gguf_reader: + return + import numpy as np + for r in gguf_reader: + if r is None: + continue + if hasattr(r, 'data') and hasattr(r.data, '_mmap'): + try: + # MUST save filename BEFORE closing — accessing .filename + # after _mmap.close() segfaults on some numpy versions. + r._gguf_filename = r.data.filename + except Exception: + r._gguf_filename = None + try: + r.data._mmap.close() + r._gguf_mmap_closed = True + except Exception: + pass + gc.collect() + + +def reopen_gguf_readers(gguf_reader): + """Reopen GGUF readers previously closed by close_gguf_readers(). + + Recreates GGUFReader objects from the saved filenames. This re-parses + the file header and re-mmaps the file, which is necessary before + calling load_weights() for a weights reload (e.g. after block swap + offload/reload cycles). + + Safe to call on None or an empty list. Readers that are still open + are left untouched. + """ + if gguf_reader is None or not gguf_reader: + return + from gguf import GGUFReader + reopened = False + for i, r in enumerate(gguf_reader): + if r is not None and getattr(r, '_gguf_mmap_closed', False): + fname = getattr(r, '_gguf_filename', None) + if fname: + # Explicitly clean up the old reader's tensor list and + # data BEFORE replacement. The old reader holds 500+ numpy + # view objects that keep the closed mmap's C-level memory + # allocation alive. Clearing these prevents a phantom RAM + # leak of 10-15GB per reopen cycle. + try: + if hasattr(r, 'tensors'): + r.tensors.clear() + del r.tensors + if hasattr(r, 'data'): + del r.data + except Exception: + pass + try: + new_reader = GGUFReader(fname) + gguf_reader[i] = new_reader + reopened = True + except Exception as e: + log.warning(f"Failed to reopen GGUF reader for {fname}: {e}") + if reopened: + del r # break loop variable reference to last replaced reader + gc.collect() diff --git a/wanvideo/modules/model.py b/wanvideo/modules/model.py index e2566ec8..0126106b 100644 --- a/wanvideo/modules/model.py +++ b/wanvideo/modules/model.py @@ -10,7 +10,7 @@ try: from ..radial_attention.attn_mask import RadialSpargeSageAttn, RadialSpargeSageAttnDense, MaskMap -except Exception: +except: pass from .attention import attention @@ -18,7 +18,7 @@ from tqdm import tqdm import gc -from ...utils import log, get_module_memory_mb +from ...utils import log, get_module_memory_mb, log_memory_peak from ...cache_methods.cache_methods import TeaCacheState, MagCacheState, EasyCacheState, relative_l1_distance from ...multitalk.multitalk import get_attn_map_with_target from ...echoshot.echoshot import rope_apply_z, rope_apply_c, rope_apply_echoshot @@ -647,7 +647,7 @@ def __init__(self, in_features, out_features, num_heads, kv_dim=None, qk_norm=Tr def forward(self, x, context, grid_sizes=None, clip_embed=None, audio_proj=None, audio_scale=1.0, num_latent_frames=21, nag_params={}, nag_context=None, rope_func="comfy", inner_t=None, inner_c=None, cross_freqs=None, - adapter_proj=None, ip_scale=1.0, orig_seq_len=None, lynx_x_ip=None, lynx_ip_scale=1.0, longcat_num_cond_latents=None, **kwargs): + adapter_proj=None, adapter_attn_mask=None, ip_scale=1.0, orig_seq_len=None, lynx_x_ip=None, lynx_ip_scale=1.0, longcat_num_cond_latents=None, **kwargs): b, n, d = x.size(0), self.num_heads, self.head_dim s = x.size(1) # compute query @@ -702,7 +702,7 @@ def forward(self, x, context, grid_sizes=None, clip_embed=None, audio_proj=None, # FantasyPortrait adapter attention if adapter_proj is not None: if len(adapter_proj.shape) == 4: - q_in = q[:, :orig_seq_len] + q_in = q[:, :orig_seq_len] adapter_q = q_in.view(b * num_latent_frames, -1, n, d) ip_key = self.ip_adapter_single_stream_k_proj(adapter_proj).view(b * num_latent_frames, -1, n, d) ip_value = self.ip_adapter_single_stream_v_proj(adapter_proj).view(b * num_latent_frames, -1, n, d) @@ -745,7 +745,7 @@ def __init__(self, in_features, out_features, num_heads, qk_norm=True, eps=1e-6, def forward(self, x, context, grid_sizes=None, clip_embed=None, audio_proj=None, audio_scale=1.0, num_latent_frames=21, nag_params={}, nag_context=None, rope_func="comfy", - adapter_proj=None, ip_scale=1.0, orig_seq_len=None, **kwargs): + adapter_proj=None, adapter_attn_mask=None, ip_scale=1.0, orig_seq_len=None, **kwargs): r""" Args: x(Tensor): Shape [B, L1, C] @@ -758,21 +758,23 @@ def forward(self, x, context, grid_sizes=None, clip_embed=None, audio_proj=None, if nag_context is not None: x_positive, x_negative = self.nag_attention(b, n, d, q, context, nag_context) - x = self.normalized_attention_guidance(x_positive, x_negative, nag_params) + x_text = self.normalized_attention_guidance(x_positive, x_negative, nag_params) del x_positive, x_negative else: # text attention k = self.norm_k(self.k(context).to(self.norm_k.weight.dtype)).view(b, -1, n, d).to(x.dtype) v = self.v(context).view(b, -1, n, d) - x = attention(q, k, v, attention_mode=self.attention_mode, heads=self.num_heads).flatten(2) - del k, v + x_text = attention(q, k, v, attention_mode=self.attention_mode, heads=self.num_heads).flatten(2) #img attention if clip_embed is not None: k_img = self.norm_k_img(self.k_img(clip_embed).to(self.norm_k_img.weight.dtype)).view(b, -1, n, d).to(x.dtype) v_img = self.v_img(clip_embed).view(b, -1, n, d) - x.add_(attention(q, k_img, v_img, attention_mode=self.attention_mode, heads=self.num_heads).flatten(2)) - del k_img, v_img + img_x = attention(q, k_img, v_img, attention_mode=self.attention_mode, heads=self.num_heads).flatten(2) + x_text.add_(img_x) + x = x_text + else: + x = x_text # FantasyTalking audio attention if audio_proj is not None: @@ -805,7 +807,7 @@ def forward(self, x, context, grid_sizes=None, clip_embed=None, audio_proj=None, adapter_x = attention(q, ip_key, ip_value, attention_mode=self.attention_mode, heads=self.num_heads) adapter_x = adapter_x.flatten(2) x = x + adapter_x * ip_scale - del q + return self.o(x) class WanHuMoCrossAttention(WanSelfAttention): @@ -1039,6 +1041,10 @@ def forward( T = num_latent_frames is_longcat = C == 4096 + if self.block_idx == 0: + log.info(f"WanAttentionBlock input: shape={x.shape}, dtype={x.dtype}, seq_len={N}") + log_memory_peak(f"WanAttentionBlock step {current_step} start", device=x.device, reset_peak=True) + zero_timestep = len(e) == 2 if zero_timestep: #s2v zero timestep self.seg_idx = e[1] @@ -2206,7 +2212,7 @@ def wananimate_forward(self, block, x, motion_vec, strength=1.0, motion_masks=No def rope_encode_comfy(self, t, h, w, freq_offset=0, t_start=0, ref_frame_shape=None, pose_frame_shape=None, steps_t=None, steps_h=None, steps_w=None, ntk_alphas=[1,1,1], device=None, dtype=None, - ref_frame_index=10, longcat_num_ref_latents=0, num_memory_frames=3, rope_negative_offset=0): + ref_frame_index=10, longcat_num_ref_latents=0, num_memory_frames=3, rope_negative_offset=5): patch_size = self.patch_size t_len = ((t + (patch_size[0] // 2)) // patch_size[0]) diff --git a/wanvideo/radial_attention/attn_mask.py b/wanvideo/radial_attention/attn_mask.py index dd3ff5f5..37a8c73b 100644 --- a/wanvideo/radial_attention/attn_mask.py +++ b/wanvideo/radial_attention/attn_mask.py @@ -1,18 +1,19 @@ # based on https://github.com/mit-han-lab/radial-attention/blob/main/radial_attn/attn_mask.py import torch +from collections import OrderedDict try: from spas_sage_attn import block_sparse_sage2_attn_cuda sparse_attn_func = block_sparse_sage2_attn_cuda -except Exception: +except: try: from sparse_sageattn import sparse_sageattn sparse_attn_func = sparse_sageattn - except Exception: + except: try: from .sparse_sage.core import sparse_sageattn sparse_attn_func = sparse_sageattn - except Exception: + except: sparse_sageattn = None raise ImportError("sparse_sageattn is not available. Please install the sparse_sageattn package or check your import path.") @@ -138,9 +139,10 @@ def RadialSpargeSageAttnDense(query, key, value, mask_map): @torch.compiler.disable() def RadialSpargeSageAttn(query, key, value, mask_map, decay_factor): - # Simple cache based on function arguments + # Simple LRU-bounded cache based on function arguments + _MAX_RADIAL_CACHE = 8 if not hasattr(RadialSpargeSageAttn, "_cache"): - RadialSpargeSageAttn._cache = {} + RadialSpargeSageAttn._cache = OrderedDict() # print(mask_map.block_size) block_size = mask_map.block_size cache_key = ( @@ -152,6 +154,7 @@ def RadialSpargeSageAttn(query, key, value, mask_map, decay_factor): ) if cache_key in RadialSpargeSageAttn._cache: input_mask = RadialSpargeSageAttn._cache[cache_key] + RadialSpargeSageAttn._cache.move_to_end(cache_key) else: print("Radial Attention: Generating block mask") video_mask = mask_map.queryLogMask(query.shape[0] * query.shape[1], "radial", block_size=block_size, decay_factor=decay_factor) @@ -164,6 +167,9 @@ def RadialSpargeSageAttn(query, key, value, mask_map, decay_factor): mask = torch.max(reshaped_mask, dim=1).values input_mask = mask.unsqueeze(0).unsqueeze(1).expand(1, query.shape[-2], mask.shape[0], mask.shape[1]) RadialSpargeSageAttn._cache[cache_key] = input_mask + RadialSpargeSageAttn._cache.move_to_end(cache_key) + while len(RadialSpargeSageAttn._cache) > _MAX_RADIAL_CACHE: + RadialSpargeSageAttn._cache.popitem(last=False) return sparse_attn_func( query[:, :, :mask_map.video_token_num, :],