Skip to content

Commit 9cf299a

Browse files
Make regular empty latent node work properly on flux 2 variants. (Comfy-Org#12050)
1 parent e89b229 commit 9cf299a

File tree

5 files changed

+20
-8
lines changed

5 files changed

+20
-8
lines changed

comfy/latent_formats.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class LatentFormat:
88
latent_rgb_factors_bias = None
99
latent_rgb_factors_reshape = None
1010
taesd_decoder_name = None
11+
spacial_downscale_ratio = 8
1112

1213
def process_in(self, latent):
1314
return latent * self.scale_factor
@@ -181,6 +182,7 @@ def process_out(self, latent):
181182

182183
class Flux2(LatentFormat):
183184
latent_channels = 128
185+
spacial_downscale_ratio = 16
184186

185187
def __init__(self):
186188
self.latent_rgb_factors =[
@@ -749,6 +751,7 @@ class ACEAudio(LatentFormat):
749751

750752
class ChromaRadiance(LatentFormat):
751753
latent_channels = 3
754+
spacial_downscale_ratio = 1
752755

753756
def __init__(self):
754757
self.latent_rgb_factors = [

comfy/sample.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,18 @@ def prepare_noise(latent_image, seed, noise_inds=None):
3737

3838
return noises
3939

40-
def fix_empty_latent_channels(model, latent_image):
40+
def fix_empty_latent_channels(model, latent_image, downscale_ratio_spacial=None):
4141
if latent_image.is_nested:
4242
return latent_image
4343
latent_format = model.get_model_object("latent_format") #Resize the empty latent image so it has the right number of channels
44-
if latent_format.latent_channels != latent_image.shape[1] and torch.count_nonzero(latent_image) == 0:
45-
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
44+
if torch.count_nonzero(latent_image) == 0:
45+
if latent_format.latent_channels != latent_image.shape[1]:
46+
latent_image = comfy.utils.repeat_to_batch_size(latent_image, latent_format.latent_channels, dim=1)
47+
if downscale_ratio_spacial is not None:
48+
if downscale_ratio_spacial != latent_format.spacial_downscale_ratio:
49+
ratio = downscale_ratio_spacial / latent_format.spacial_downscale_ratio
50+
latent_image = comfy.utils.common_upscale(latent_image, round(latent_image.shape[-1] * ratio), round(latent_image.shape[-2] * ratio), "nearest-exact", crop="disabled")
51+
4652
if latent_format.latent_dimensions == 3 and latent_image.ndim == 4:
4753
latent_image = latent_image.unsqueeze(2)
4854
return latent_image

comfy_extras/nodes_custom_sampler.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler,
741741
latent = latent_image
742742
latent_image = latent["samples"]
743743
latent = latent.copy()
744-
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
744+
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
745745
latent["samples"] = latent_image
746746

747747
if not add_noise:
@@ -760,6 +760,7 @@ def execute(cls, model, add_noise, noise_seed, cfg, positive, negative, sampler,
760760
samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed)
761761

762762
out = latent.copy()
763+
out.pop("downscale_ratio_spacial", None)
763764
out["samples"] = samples
764765
if "x0" in x0_output:
765766
x0_out = model.model.process_latent_out(x0_output["x0"].cpu())
@@ -939,7 +940,7 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
939940
latent = latent_image
940941
latent_image = latent["samples"]
941942
latent = latent.copy()
942-
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image)
943+
latent_image = comfy.sample.fix_empty_latent_channels(guider.model_patcher, latent_image, latent.get("downscale_ratio_spacial", None))
943944
latent["samples"] = latent_image
944945

945946
noise_mask = None
@@ -954,6 +955,7 @@ def execute(cls, noise, guider, sampler, sigmas, latent_image) -> io.NodeOutput:
954955
samples = samples.to(comfy.model_management.intermediate_device())
955956

956957
out = latent.copy()
958+
out.pop("downscale_ratio_spacial", None)
957959
out["samples"] = samples
958960
if "x0" in x0_output:
959961
x0_out = guider.model_patcher.model.process_latent_out(x0_output["x0"].cpu())

comfy_extras/nodes_sd3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def define_schema(cls):
5555
@classmethod
5656
def execute(cls, width, height, batch_size=1) -> io.NodeOutput:
5757
latent = torch.zeros([batch_size, 16, height // 8, width // 8], device=comfy.model_management.intermediate_device())
58-
return io.NodeOutput({"samples":latent})
58+
return io.NodeOutput({"samples": latent, "downscale_ratio_spacial": 8})
5959

6060
generate = execute # TODO: remove
6161

nodes.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1230,7 +1230,7 @@ def INPUT_TYPES(s):
12301230

12311231
def generate(self, width, height, batch_size=1):
12321232
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
1233-
return ({"samples":latent}, )
1233+
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
12341234

12351235

12361236
class LatentFromBatch:
@@ -1538,7 +1538,7 @@ def set_mask(self, samples, mask):
15381538

15391539
def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, negative, latent, denoise=1.0, disable_noise=False, start_step=None, last_step=None, force_full_denoise=False):
15401540
latent_image = latent["samples"]
1541-
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image)
1541+
latent_image = comfy.sample.fix_empty_latent_channels(model, latent_image, latent.get("downscale_ratio_spacial", None))
15421542

15431543
if disable_noise:
15441544
noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
@@ -1556,6 +1556,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive,
15561556
denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step,
15571557
force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
15581558
out = latent.copy()
1559+
out.pop("downscale_ratio_spacial", None)
15591560
out["samples"] = samples
15601561
return (out, )
15611562

0 commit comments

Comments
 (0)