comfy/ldm/flux/model.py: 6 changes (6 additions, 0 deletions)
@@ -171,7 +171,10 @@ def forward_orig(
             pe = None
 
         blocks_replace = patches_replace.get("dit", {})
+        transformer_options["total_blocks"] = len(self.double_blocks)
+        transformer_options["block_type"] = "double"
         for i, block in enumerate(self.double_blocks):
+            transformer_options["block_index"] = i
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
@@ -215,7 +218,10 @@ def block_wrap(args):
         if self.params.global_modulation:
             vec, _ = self.single_stream_modulation(vec_orig)
 
+        transformer_options["total_blocks"] = len(self.single_blocks)
+        transformer_options["block_type"] = "single"
         for i, block in enumerate(self.single_blocks):
+            transformer_options["block_index"] = i
             if ("single_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
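With this bookkeeping, a "dit" block-replace patch can tell which block it is wrapping and how deep it sits in the stack. Below is a minimal sketch of such a patch; it assumes the usual (args, extra) calling convention illustrated by block_wrap above, and the 1.05 scaling policy plus the set_model_patch_replace registration line are purely illustrative:

def my_block_patch(args, extra):
    # extra["transformer_options"] carries the keys added in this PR.
    opts = extra["transformer_options"]
    i = opts["block_index"]      # index of the block currently being wrapped
    n = opts["total_blocks"]     # len(self.double_blocks) or len(self.single_blocks)
    out = extra["original_block"](args)
    # Hypothetical policy: nudge the image stream only in the last quarter
    # of the single-stream blocks.
    if opts["block_type"] == "single" and i >= n - n // 4:
        out["img"] = out["img"] * 1.05
    return out

# Registration sketch (assumed API; custom nodes typically go through ModelPatcher):
# model_patcher.set_model_patch_replace(my_block_patch, "dit", "single_block", block_index)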
comfy/model_management.py: 6 changes (3 additions, 3 deletions)
@@ -689,7 +689,7 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
             loaded_memory = loaded_model.model_loaded_memory()
             current_free_mem = get_free_memory(torch_dev) + loaded_memory
 
-            lowvram_model_memory = max(128 * 1024 * 1024, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
+            lowvram_model_memory = max(0, (current_free_mem - minimum_memory_required), min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO, current_free_mem - minimum_inference_memory()))
             lowvram_model_memory = lowvram_model_memory - loaded_memory
 
         if lowvram_model_memory == 0:
@@ -1012,7 +1012,7 @@ def force_channels_last():
 
 
 STREAMS = {}
-NUM_STREAMS = 1
+NUM_STREAMS = 0
 if args.async_offload:
     NUM_STREAMS = 2
     logging.info("Using async weight offloading with {} streams".format(NUM_STREAMS))
@@ -1030,7 +1030,7 @@ def current_stream(device):
 stream_counters = {}
 def get_offload_stream(device):
     stream_counter = stream_counters.get(device, 0)
-    if NUM_STREAMS <= 1:
+    if NUM_STREAMS == 0:
         return None
 
     if device in STREAMS:
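Two coupled changes here: NUM_STREAMS is now 0 when async offload is disabled (the guard in get_offload_stream changes accordingly), which lets model_patcher.py use module_offload_mem * (comfy.model_management.NUM_STREAMS + 1) directly when sizing its offload buffer, and the 128 MB floor on lowvram_model_memory is dropped so a model can be fully offloaded when free VRAM is tight. A worked example of the second change; all numbers are hypothetical (MIN_WEIGHT_MEMORY_RATIO and the reserve sizes stand in for whatever the running config produces):

MIN_WEIGHT_MEMORY_RATIO = 0.4                 # illustrative stand-in
minimum_inference_memory = 1 * 1024**3        # assume 1 GiB reserved for inference
current_free_mem = 1 * 1024**3                # 1 GiB currently free on the device
minimum_memory_required = 3 * 1024**3         # the queued job wants 3 GiB kept free

candidates = (
    current_free_mem - minimum_memory_required,              # -2 GiB
    min(current_free_mem * MIN_WEIGHT_MEMORY_RATIO,
        current_free_mem - minimum_inference_memory),        # min(0.4 GiB, 0) = 0
)

old_budget = max(128 * 1024 * 1024, *candidates)  # 128 MiB of weights forced onto the GPU
new_budget = max(0, *candidates)                  # 0: the weights may be fully offloaded
print(old_budget // 2**20, new_budget // 2**20)   # 128 0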
comfy/model_patcher.py: 59 changes (45 additions, 14 deletions)
@@ -148,6 +148,15 @@ def __call__(self, weight):
         else:
             return out
 
+#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+
+def low_vram_patch_estimate_vram(model, key):
+    weight, set_func, convert_func = get_key_weight(model, key)
+    if weight is None:
+        return 0
+    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+
 def get_key_weight(model, key):
     set_func = None
     convert_func = None
@@ -269,6 +278,9 @@ def __init__(self, model, load_device, offload_device, size=0, weight_inplace_up
         if not hasattr(self.model, 'current_weight_patches_uuid'):
             self.model.current_weight_patches_uuid = None
 
+        if not hasattr(self.model, 'model_offload_buffer_memory'):
+            self.model.model_offload_buffer_memory = 0
+
     def model_size(self):
         if self.size > 0:
             return self.size
@@ -662,7 +674,16 @@ def _load_list(self):
                     skip = True # skip random weights in non leaf modules
                     break
             if not skip and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
-                loading.append((comfy.model_management.module_size(m), n, m, params))
+                module_mem = comfy.model_management.module_size(m)
+                module_offload_mem = module_mem
+                if hasattr(m, "comfy_cast_weights"):
+                    weight_key = "{}.weight".format(n)
+                    bias_key = "{}.bias".format(n)
+                    if weight_key in self.patches:
+                        module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
+                    if bias_key in self.patches:
+                        module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                loading.append((module_offload_mem, module_mem, n, m, params))
         return loading
 
     def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@@ -676,20 +697,22 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
 
             load_completely = []
             offloaded = []
+            offload_buffer = 0
             loading.sort(reverse=True)
             for x in loading:
-                n = x[1]
-                m = x[2]
-                params = x[3]
-                module_mem = x[0]
+                module_offload_mem, module_mem, n, m, params = x
 
                 lowvram_weight = False
 
+                potential_offload = max(offload_buffer, module_offload_mem * (comfy.model_management.NUM_STREAMS + 1))
+                lowvram_fits = mem_counter + module_mem + potential_offload < lowvram_model_memory
+
                 weight_key = "{}.weight".format(n)
                 bias_key = "{}.bias".format(n)
 
                 if not full_load and hasattr(m, "comfy_cast_weights"):
-                    if mem_counter + module_mem >= lowvram_model_memory:
+                    if not lowvram_fits:
+                        offload_buffer = potential_offload
                         lowvram_weight = True
                         lowvram_counter += 1
                         lowvram_mem_counter += module_mem
@@ -723,9 +746,11 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                     if hasattr(m, "comfy_cast_weights"):
                         wipe_lowvram_weight(m)
 
-                if full_load or mem_counter + module_mem < lowvram_model_memory:
+                if full_load or lowvram_fits:
                     mem_counter += module_mem
                     load_completely.append((module_mem, n, m, params))
+                else:
+                    offload_buffer = potential_offload
 
                 if cast_weight and hasattr(m, "comfy_cast_weights"):
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
@@ -766,7 +791,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                     self.pin_weight_to_device("{}.{}".format(n, param))
 
             if lowvram_counter > 0:
-                logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), patch_counter))
+                logging.info("loaded partially; {:.2f} MB usable, {:.2f} MB loaded, {:.2f} MB offloaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), lowvram_mem_counter / (1024 * 1024), offload_buffer / (1024 * 1024), patch_counter))
                 self.model.model_lowvram = True
             else:
                 logging.info("loaded completely; {:.2f} MB usable, {:.2f} MB loaded, full load: {}".format(lowvram_model_memory / (1024 * 1024), mem_counter / (1024 * 1024), full_load))
@@ -778,6 +803,7 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
             self.model.lowvram_patch_counter += patch_counter
             self.model.device = device_to
             self.model.model_loaded_weight_memory = mem_counter
+            self.model.model_offload_buffer_memory = offload_buffer
             self.model.current_weight_patches_uuid = self.patches_uuid
 
             for callback in self.get_all_callbacks(CallbacksMP.ON_LOAD):
@@ -831,6 +857,7 @@ def unpatch_model(self, device_to=None, unpatch_weights=True):
                 self.model.to(device_to)
                 self.model.device = device_to
                 self.model.model_loaded_weight_memory = 0
+                self.model.model_offload_buffer_memory = 0
 
             for m in self.model.modules():
                 if hasattr(m, "comfy_patched_weights"):
@@ -849,13 +876,14 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
             patch_counter = 0
             unload_list = self._load_list()
             unload_list.sort()
+            offload_buffer = self.model.model_offload_buffer_memory
 
             for unload in unload_list:
-                if memory_to_free < memory_freed:
+                if memory_to_free + offload_buffer - self.model.model_offload_buffer_memory < memory_freed:
                     break
-                module_mem = unload[0]
-                n = unload[1]
-                m = unload[2]
-                params = unload[3]
+                module_offload_mem, module_mem, n, m, params = unload
+
+                potential_offload = (comfy.model_management.NUM_STREAMS + 1) * module_offload_mem
+
                 lowvram_possible = hasattr(m, "comfy_cast_weights")
                 if hasattr(m, "comfy_patched_weights") and m.comfy_patched_weights == True:
@@ -906,15 +934,18 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
                     m.comfy_cast_weights = True
                     m.comfy_patched_weights = False
                     memory_freed += module_mem
+                    offload_buffer = max(offload_buffer, potential_offload)
                     logging.debug("freed {}".format(n))
 
                 for param in params:
                     self.pin_weight_to_device("{}.{}".format(n, param))
 
+
             self.model.model_lowvram = True
             self.model.lowvram_patch_counter += patch_counter
             self.model.model_loaded_weight_memory -= memory_freed
-            logging.info("loaded partially: {:.2f} MB loaded, lowvram patches: {}".format(self.model.model_loaded_weight_memory / (1024 * 1024), self.model.lowvram_patch_counter))
+            self.model.model_offload_buffer_memory = offload_buffer
+            logging.info("Unloaded partially: {:.2f} MB freed, {:.2f} MB remains loaded, {:.2f} MB buffer reserved, lowvram patches: {}".format(memory_freed / (1024 * 1024), self.model.model_loaded_weight_memory / (1024 * 1024), offload_buffer / (1024 * 1024), self.model.lowvram_patch_counter))
             return memory_freed
 
     def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
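To get a feel for the sizes involved, here is a back-of-the-envelope check of the new estimate for a single patched layer. The Linear module, the fp16 dtype, and NUM_STREAMS = 2 are hypothetical; only the two formulas (the fp32 x 3 patch estimate and the (NUM_STREAMS + 1) buffer multiplier) come from the diff:

import torch

LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3   # fp32 upcast plus intermediate math, per the comment above
NUM_STREAMS = 2                          # value model_management.py uses with --async-offload

layer = torch.nn.Linear(4096, 4096, bias=True, dtype=torch.float16)

# Plain module size, as module_size() would count it: ~32.0 MiB in fp16.
module_mem = sum(p.numel() * p.element_size() for p in layer.parameters())

# If both weight and bias carry patches (e.g. a LoRA), each adds numel * 4 bytes * 3.
patch_estimate = sum(p.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
                     for p in layer.parameters())
module_offload_mem = module_mem + patch_estimate   # ~224.1 MiB

# load()/partially_unload() keep room for (NUM_STREAMS + 1) copies of the largest
# offloaded module; the running maximum becomes the "buffer reserved" value in the logs.
potential_offload = module_offload_mem * (NUM_STREAMS + 1)   # ~672.2 MiB

print(module_mem / 2**20, module_offload_mem / 2**20, potential_offload / 2**20)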