comfy/ldm/kandinsky5/model.py: 7 additions & 1 deletion
@@ -387,6 +387,9 @@ def block_wrap(args):
         return self.out_layer(visual_embed, time_embed)
 
     def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+        original_dims = x.ndim
+        if original_dims == 4:
+            x = x.unsqueeze(2)
         bs, c, t_len, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
 
@@ -397,7 +400,10 @@ def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_o
         freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)
 
-        return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+        out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+        if original_dims == 4:
+            out = out.squeeze(2)
+        return out
 
     def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
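The change above lets the Kandinsky 5 model accept image-shaped latents as well as video-shaped ones. A minimal sketch of the same pattern outside ComfyUI, assuming only PyTorch (process_5d is a made-up stand-in for forward_orig):

import torch

def process_5d(x):
    # Stand-in for a model that expects video latents shaped (batch, channels, time, height, width).
    assert x.ndim == 5
    return x * 2.0

def forward_any(x):
    # Accept both image latents (N, C, H, W) and video latents (N, C, T, H, W).
    original_dims = x.ndim
    if original_dims == 4:
        x = x.unsqueeze(2)   # insert a singleton time axis -> (N, C, 1, H, W)
    out = process_5d(x)
    if original_dims == 4:
        out = out.squeeze(2) # drop the time axis so the caller gets 4D back
    return out

print(forward_any(torch.zeros(1, 4, 32, 32)).shape)     # torch.Size([1, 4, 32, 32])
print(forward_any(torch.zeros(1, 4, 8, 32, 32)).shape)  # torch.Size([1, 4, 8, 32, 32])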
comfy/model_management.py: 14 additions & 0 deletions
@@ -1492,6 +1492,20 @@ def extended_fp16_support():
 
     return True
 
+LORA_COMPUTE_DTYPES = {}
+def lora_compute_dtype(device):
+    dtype = LORA_COMPUTE_DTYPES.get(device, None)
+    if dtype is not None:
+        return dtype
+
+    if should_use_fp16(device):
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
+    LORA_COMPUTE_DTYPES[device] = dtype
+    return dtype
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
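lora_compute_dtype memoizes one compute dtype per device, so the should_use_fp16 check runs once and later LoRA patching reuses the cached answer. A small usage sketch, assuming a standard ComfyUI environment where get_torch_device is the usual device lookup:

import torch
import comfy.model_management

device = comfy.model_management.get_torch_device()
dtype = comfy.model_management.lora_compute_dtype(device)
print(dtype)  # torch.float16 on devices where fp16 is preferred, otherwise torch.float32

# Subsequent calls for the same device hit the LORA_COMPUTE_DTYPES cache.
assert comfy.model_management.lora_compute_dtype(device) == dtype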
comfy/model_patcher.py: 5 additions & 2 deletions
@@ -614,10 +614,11 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)
 
+        temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
         if device_to is not None:
-            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+            temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
         else:
-            temp_weight = weight.to(torch.float32, copy=True)
+            temp_weight = weight.to(temp_dtype, copy=True)
         if convert_func is not None:
             temp_weight = convert_func(temp_weight, inplace=True)
 
@@ -761,6 +762,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                 key = "{}.{}".format(n, param)
                 self.unpin_weight(key)
                 self.patch_weight_to_device(key, device_to=device_to)
+            if comfy.model_management.is_device_cuda(device_to):
+                torch.cuda.synchronize()
 
             logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
             m.comfy_patched_weights = True
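Two things change in ModelPatcher: weight patching now runs in the cached LoRA compute dtype instead of always fp32, and CUDA loads synchronize once after a module's weights are patched so any asynchronous copies have completed. A rough standalone sketch of the compute-dtype idea (the rank-8 LoRA factors and alpha are made up for illustration; this is not ComfyUI's actual patching code):

import torch

def apply_lora_sketch(weight, lora_down, lora_up, alpha, device, compute_dtype):
    # Cast the stored weight to the compute dtype on the target device, apply the
    # low-rank update there, then cast back to the weight's original dtype.
    w = weight.to(device=device, dtype=compute_dtype, copy=True)
    delta = alpha * (lora_up.to(device=device, dtype=compute_dtype) @ lora_down.to(device=device, dtype=compute_dtype))
    return (w + delta).to(weight.dtype)

device = "cuda" if torch.cuda.is_available() else "cpu"
compute_dtype = torch.float16 if device == "cuda" else torch.float32

weight = torch.randn(64, 32)            # pretend fp32 model weight
lora_down = torch.randn(8, 32) * 0.01   # rank-8 LoRA factors
lora_up = torch.randn(64, 8) * 0.01
patched = apply_lora_sketch(weight, lora_down, lora_up, 1.0, device, compute_dtype)

if torch.cuda.is_available():
    torch.cuda.synchronize()            # mirror the new per-module sync after patching on CUDA
print(patched.shape, patched.dtype)     # torch.Size([64, 32]) torch.float32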
comfy/text_encoders/kandinsky5.py: 6 additions & 6 deletions
@@ -24,10 +24,10 @@ def __init__(self, embedding_directory=None, tokenizer_data={}):
 
 class Qwen25_7BVLIModel(sd1_clip.SDClipModel):
     def __init__(self, device="cpu", layer="hidden", layer_idx=-1, dtype=None, attention_mask=True, model_options={}):
-        llama_scaled_fp8 = model_options.get("qwen_scaled_fp8", None)
-        if llama_scaled_fp8 is not None:
+        llama_quantization_metadata = model_options.get("llama_quantization_metadata", None)
+        if llama_quantization_metadata is not None:
             model_options = model_options.copy()
-            model_options["scaled_fp8"] = llama_scaled_fp8
+            model_options["quantization_metadata"] = llama_quantization_metadata
         super().__init__(device=device, layer=layer, layer_idx=layer_idx, textmodel_json_config={}, dtype=dtype, special_tokens={"pad": 151643}, layer_norm_hidden_state=False, model_class=Qwen25_7BVLI, enable_attention_masks=attention_mask, return_attention_masks=attention_mask, model_options=model_options)
 
 
@@ -56,12 +56,12 @@ def load_sd(self, sd):
         else:
             return super().load_sd(sd)
 
-def te(dtype_llama=None, llama_scaled_fp8=None):
+def te(dtype_llama=None, llama_quantization_metadata=None):
     class Kandinsky5TEModel_(Kandinsky5TEModel):
         def __init__(self, device="cpu", dtype=None, model_options={}):
-            if llama_scaled_fp8 is not None and "scaled_fp8" not in model_options:
+            if llama_quantization_metadata is not None:
                 model_options = model_options.copy()
-                model_options["qwen_scaled_fp8"] = llama_scaled_fp8
+                model_options["llama_quantization_metadata"] = llama_quantization_metadata
             if dtype_llama is not None:
                 dtype = dtype_llama
             super().__init__(device=device, dtype=dtype, model_options=model_options)
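The text encoder plumbing now passes a generic quantization metadata value instead of the scaled-fp8-only one, under the new llama_quantization_metadata / quantization_metadata keys. A hedged usage sketch of the te() factory (the metadata dict shown is a placeholder; in practice it comes from the checkpoint loader):

import torch
import comfy.text_encoders.kandinsky5 as kandinsky5

# te() returns a configured class; the loader instantiates it later.
TEClass = kandinsky5.te(
    dtype_llama=torch.float16,
    llama_quantization_metadata={"format": "scaled_fp8"},  # placeholder metadata for illustration
)
# When instantiated, model_options carries "llama_quantization_metadata",
# which Qwen25_7BVLIModel copies into "quantization_metadata" for the Qwen model.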