@@ -29,20 +29,26 @@ def execute(cls, *, width: int, height: int, batch_size: int=1) -> io.NodeOutput
 
 
 class ChromaRadianceStubVAE:
-    @classmethod
-    def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
-        device = comfy.model_management.intermediate_device()
-        if pixels.ndim == 3:
-            pixels = pixels.unsqueeze(0)
-        elif pixels.ndim != 4:
-            raise ValueError("Unexpected input image shape")
+    @staticmethod
+    def vae_encode_crop_pixels(pixels: torch.Tensor) -> torch.Tensor:
         dims = pixels.shape[1:-1]
         for d in range(len(dims)):
             d_adj = (dims[d] // 16) * 16
             if d_adj == dims[d]:
                 continue
             d_offset = (dims[d] % 16) // 2
             pixels = pixels.narrow(d + 1, d_offset, d_adj)
+        return pixels
+
+    @classmethod
+    def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
+        device = comfy.model_management.intermediate_device()
+        if pixels.ndim == 3:
+            pixels = pixels.unsqueeze(0)
+        elif pixels.ndim != 4:
+            raise ValueError("Unexpected input image shape")
+        # Ensure the image has spatial dimensions that are multiples of 16.
+        pixels = cls.vae_encode_crop_pixels(pixels)
         h, w, c = pixels.shape[1:]
         if h < 16 or w < 16:
             raise ValueError("Chroma Radiance image inputs must have height/width of at least 16 pixels.")
@@ -51,6 +57,7 @@ def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
             pixels = pixels.expand(-1, -1, -1, 3)
         elif c != 3:
             raise ValueError("Unexpected number of channels in input image")
+        # Rescale to -1..1 and move the channel dimension to position 1.
         latent = pixels.to(device=device, dtype=torch.float32, copy=True)
         latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
         latent -= 0.5
@@ -60,6 +67,7 @@ def encode(cls, pixels: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
     @classmethod
     def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
         device = comfy.model_management.intermediate_device()
+        # Rescale to 0..1 and move the channel dimension to the end.
         img = samples.to(device=device, dtype=torch.float32, copy=True)
         img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
         img += 1.0
@@ -71,6 +79,7 @@ def decode(cls, samples: torch.Tensor, *_args, **_kwargs) -> torch.Tensor:
 
     @classmethod
     def spacial_compression_decode(cls) -> int:
+        # This just exists so the tiled VAE nodes don't crash.
         return 1
 
     spacial_compression_encode = spacial_compression_decode
@@ -115,7 +124,7 @@ def define_schema(cls) -> io.Schema:
         return io.Schema(
             node_id="ChromaRadianceStubVAE",
             category="vae/chroma_radiance",
-            description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Radiance requires inputs with width/height that are multiples of 16 so your image will be cropped if necessary.",
+            description="For use with Chroma Radiance. Allows converting between latent and image types with nodes that require a VAE input. Note: Chroma Radiance requires inputs with width/height that are multiples of 16, so your image will be cropped if necessary.",
             outputs=[io.Vae.Output()],
         )
 
@@ -129,37 +138,39 @@ def define_schema(cls) -> io.Schema:
         return io.Schema(
             node_id="ChromaRadianceOptions",
             category="model_patches/chroma_radiance",
-            description="Allows setting some advanced options for the Chroma Radiance model.",
+            description="Allows setting advanced options for the Chroma Radiance model.",
             inputs=[
                 io.Model.Input(id="model"),
                 io.Boolean.Input(
                     id="preserve_wrapper",
                     default=True,
-                    tooltip="When enabled preserves an existing model wrapper if it exists. Generally should be left enabled.",
+                    tooltip="When enabled, will delegate to an existing model function wrapper if one exists. Generally should be left enabled.",
                 ),
                 io.Float.Input(
                     id="start_sigma",
                     default=1.0,
                     min=0.0,
                     max=1.0,
+                    tooltip="First sigma at which these options will be in effect.",
                 ),
                 io.Float.Input(
                     id="end_sigma",
                     default=0.0,
                     min=0.0,
                     max=1.0,
+                    tooltip="Last sigma at which these options will be in effect.",
                 ),
                 io.Int.Input(
                     id="nerf_tile_size",
                     default=-1,
                     min=-1,
-                    tooltip="Allows overriding the default NeRF tile size. -1 means use the default. 0 means use non-tiling mode (may require a lot of VRAM).",
+                    tooltip="Allows overriding the default NeRF tile size. -1 means use the default (32). 0 means use non-tiling mode (may require a lot of VRAM).",
                 ),
                 io.Combo.Input(
                     id="nerf_embedder_dtype",
                     default="default",
                     options=["default", "model_dtype", "float32", "float64", "float16", "bfloat16"],
-                    tooltip="Allows overriding the dtype the NeRF embedder uses.",
+                    tooltip="Allows overriding the dtype the NeRF embedder uses. The default is float32.",
                 ),
             ],
             outputs=[io.Model.Output()],
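
A note on the two sigma inputs: samplers step from high sigma (noisy) to low sigma (clean), so start_sigma (default 1.0) is the bound hit first and end_sigma (default 0.0) the bound hit last. A minimal sketch of the intended range check; options_active is a hypothetical helper for illustration, not part of the PR:

def options_active(sigma: float, start_sigma: float, end_sigma: float) -> bool:
    # Sampling runs from high sigma down to low sigma, so the options take
    # effect at start_sigma and stop applying below end_sigma.
    return end_sigma <= sigma <= start_sigma

print(options_active(0.5, start_sigma=1.0, end_sigma=0.0))  # True: mid-sampling
print(options_active(0.5, start_sigma=0.3, end_sigma=0.0))  # False: not reached yet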
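For readers who want to sanity-check the stub VAE's behavior outside ComfyUI, here is a self-contained sketch of the same crop and normalization logic in plain PyTorch. The demo_* names are hypothetical; the real methods live on ChromaRadianceStubVAE and pick their device via comfy.model_management.intermediate_device(), which is omitted here, and the final scaling steps (not visible in the hunks above) are inferred from the clamp ranges:

import torch

def crop_to_multiple_of_16(pixels: torch.Tensor) -> torch.Tensor:
    # Center-crop the spatial dims of an NHWC batch down to multiples of 16,
    # mirroring vae_encode_crop_pixels above.
    dims = pixels.shape[1:-1]
    for d in range(len(dims)):
        d_adj = (dims[d] // 16) * 16
        if d_adj == dims[d]:
            continue
        d_offset = (dims[d] % 16) // 2
        pixels = pixels.narrow(d + 1, d_offset, d_adj)
    return pixels

def demo_encode(pixels: torch.Tensor) -> torch.Tensor:
    # NHWC in 0..1 -> NCHW in -1..1; the "latent" is just the image itself.
    pixels = crop_to_multiple_of_16(pixels)
    latent = pixels.to(dtype=torch.float32, copy=True)
    latent = latent.clamp_(0, 1).movedim(-1, 1).contiguous()
    latent -= 0.5
    latent *= 2.0
    return latent

def demo_decode(samples: torch.Tensor) -> torch.Tensor:
    # NCHW in -1..1 -> NHWC in 0..1, the inverse of demo_encode.
    img = samples.to(dtype=torch.float32, copy=True)
    img = img.clamp_(-1, 1).movedim(1, -1).contiguous()
    img += 1.0
    img *= 0.5
    return img

image = torch.rand(1, 100, 130, 3)  # spatial dims not multiples of 16
latent = demo_encode(image)
print(latent.shape)  # torch.Size([1, 3, 96, 128]) -- cropped, channels-first
restored = demo_decode(latent)
print(restored.shape)  # torch.Size([1, 96, 128, 3])
# Lossless round trip (up to float error) for in-range inputs:
print(torch.allclose(restored, crop_to_multiple_of_16(image), atol=1e-6))  # True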