|
1 | 1 | import torch |
2 | 2 | import comfy.utils |
3 | 3 | from enum import Enum |
| 4 | +from typing_extensions import override |
| 5 | +from comfy_api.latest import ComfyExtension, io |
| 6 | + |
4 | 7 |
|
5 | 8 | def resize_mask(mask, shape): |
6 | 9 | return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1) |
@@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_ |
101 | 104 | return out_image, out_alpha |
102 | 105 |
|
103 | 106 |
|
104 | | -class PorterDuffImageComposite: |
| 107 | +class PorterDuffImageComposite(io.ComfyNode): |
| 108 | + @classmethod |
| 109 | + def define_schema(cls): |
| 110 | + return io.Schema( |
| 111 | + node_id="PorterDuffImageComposite", |
| 112 | + display_name="Porter-Duff Image Composite", |
| 113 | + category="mask/compositing", |
| 114 | + inputs=[ |
| 115 | + io.Image.Input("source"), |
| 116 | + io.Mask.Input("source_alpha"), |
| 117 | + io.Image.Input("destination"), |
| 118 | + io.Mask.Input("destination_alpha"), |
| 119 | + io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name), |
| 120 | + ], |
| 121 | + outputs=[ |
| 122 | + io.Image.Output(), |
| 123 | + io.Mask.Output(), |
| 124 | + ], |
| 125 | + ) |
| 126 | + |
105 | 127 | @classmethod |
106 | | - def INPUT_TYPES(s): |
107 | | - return { |
108 | | - "required": { |
109 | | - "source": ("IMAGE",), |
110 | | - "source_alpha": ("MASK",), |
111 | | - "destination": ("IMAGE",), |
112 | | - "destination_alpha": ("MASK",), |
113 | | - "mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}), |
114 | | - }, |
115 | | - } |
116 | | - |
117 | | - RETURN_TYPES = ("IMAGE", "MASK") |
118 | | - FUNCTION = "composite" |
119 | | - CATEGORY = "mask/compositing" |
120 | | - |
121 | | - def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode): |
| 128 | + def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput: |
122 | 129 | batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha)) |
123 | 130 | out_images = [] |
124 | 131 | out_alphas = [] |
@@ -150,65 +157,67 @@ def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destinatio |
150 | 157 | out_images.append(out_image) |
151 | 158 | out_alphas.append(out_alpha.squeeze(2)) |
152 | 159 |
|
153 | | - result = (torch.stack(out_images), torch.stack(out_alphas)) |
154 | | - return result |
| 160 | + return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas)) |
155 | 161 |
|
156 | 162 |
|
157 | | -class SplitImageWithAlpha: |
| 163 | +class SplitImageWithAlpha(io.ComfyNode): |
158 | 164 | @classmethod |
159 | | - def INPUT_TYPES(s): |
160 | | - return { |
161 | | - "required": { |
162 | | - "image": ("IMAGE",), |
163 | | - } |
164 | | - } |
165 | | - |
166 | | - CATEGORY = "mask/compositing" |
167 | | - RETURN_TYPES = ("IMAGE", "MASK") |
168 | | - FUNCTION = "split_image_with_alpha" |
169 | | - |
170 | | - def split_image_with_alpha(self, image: torch.Tensor): |
| 165 | + def define_schema(cls): |
| 166 | + return io.Schema( |
| 167 | + node_id="SplitImageWithAlpha", |
| 168 | + display_name="Split Image with Alpha", |
| 169 | + category="mask/compositing", |
| 170 | + inputs=[ |
| 171 | + io.Image.Input("image"), |
| 172 | + ], |
| 173 | + outputs=[ |
| 174 | + io.Image.Output(), |
| 175 | + io.Mask.Output(), |
| 176 | + ], |
| 177 | + ) |
| 178 | + |
| 179 | + @classmethod |
| 180 | + def execute(cls, image: torch.Tensor) -> io.NodeOutput: |
171 | 181 | out_images = [i[:,:,:3] for i in image] |
172 | 182 | out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image] |
173 | | - result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas)) |
174 | | - return result |
| 183 | + return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas)) |
| 184 | + |
175 | 185 |
|
| 186 | +class JoinImageWithAlpha(io.ComfyNode): |
| 187 | + @classmethod |
| 188 | + def define_schema(cls): |
| 189 | + return io.Schema( |
| 190 | + node_id="JoinImageWithAlpha", |
| 191 | + display_name="Join Image with Alpha", |
| 192 | + category="mask/compositing", |
| 193 | + inputs=[ |
| 194 | + io.Image.Input("image"), |
| 195 | + io.Mask.Input("alpha"), |
| 196 | + ], |
| 197 | + outputs=[io.Image.Output()], |
| 198 | + ) |
176 | 199 |
|
177 | | -class JoinImageWithAlpha: |
178 | 200 | @classmethod |
179 | | - def INPUT_TYPES(s): |
180 | | - return { |
181 | | - "required": { |
182 | | - "image": ("IMAGE",), |
183 | | - "alpha": ("MASK",), |
184 | | - } |
185 | | - } |
186 | | - |
187 | | - CATEGORY = "mask/compositing" |
188 | | - RETURN_TYPES = ("IMAGE",) |
189 | | - FUNCTION = "join_image_with_alpha" |
190 | | - |
191 | | - def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor): |
| 201 | + def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput: |
192 | 202 | batch_size = min(len(image), len(alpha)) |
193 | 203 | out_images = [] |
194 | 204 |
|
195 | 205 | alpha = 1.0 - resize_mask(alpha, image.shape[1:]) |
196 | 206 | for i in range(batch_size): |
197 | 207 | out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2)) |
198 | 208 |
|
199 | | - result = (torch.stack(out_images),) |
200 | | - return result |
| 209 | + return io.NodeOutput(torch.stack(out_images)) |
201 | 210 |
|
202 | 211 |
|
203 | | -NODE_CLASS_MAPPINGS = { |
204 | | - "PorterDuffImageComposite": PorterDuffImageComposite, |
205 | | - "SplitImageWithAlpha": SplitImageWithAlpha, |
206 | | - "JoinImageWithAlpha": JoinImageWithAlpha, |
207 | | -} |
| 212 | +class CompositingExtension(ComfyExtension): |
| 213 | + @override |
| 214 | + async def get_node_list(self) -> list[type[io.ComfyNode]]: |
| 215 | + return [ |
| 216 | + PorterDuffImageComposite, |
| 217 | + SplitImageWithAlpha, |
| 218 | + JoinImageWithAlpha, |
| 219 | + ] |
208 | 220 |
|
209 | 221 |
|
210 | | -NODE_DISPLAY_NAME_MAPPINGS = { |
211 | | - "PorterDuffImageComposite": "Porter-Duff Image Composite", |
212 | | - "SplitImageWithAlpha": "Split Image with Alpha", |
213 | | - "JoinImageWithAlpha": "Join Image with Alpha", |
214 | | -} |
| 222 | +async def comfy_entrypoint() -> CompositingExtension: |
| 223 | + return CompositingExtension() |
0 commit comments