Skip to content

Commit 6732014

Browse files
authored
convert nodes_compositing.py to V3 schema (Comfy-Org#10174)
1 parent 989f715 commit 6732014

File tree

1 file changed

+69
-60
lines changed

1 file changed

+69
-60
lines changed

comfy_extras/nodes_compositing.py

Lines changed: 69 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
import torch
22
import comfy.utils
33
from enum import Enum
4+
from typing_extensions import override
5+
from comfy_api.latest import ComfyExtension, io
6+
47

58
def resize_mask(mask, shape):
69
return torch.nn.functional.interpolate(mask.reshape((-1, 1, mask.shape[-2], mask.shape[-1])), size=(shape[0], shape[1]), mode="bilinear").squeeze(1)
@@ -101,24 +104,28 @@ def porter_duff_composite(src_image: torch.Tensor, src_alpha: torch.Tensor, dst_
101104
return out_image, out_alpha
102105

103106

104-
class PorterDuffImageComposite:
107+
class PorterDuffImageComposite(io.ComfyNode):
108+
@classmethod
109+
def define_schema(cls):
110+
return io.Schema(
111+
node_id="PorterDuffImageComposite",
112+
display_name="Porter-Duff Image Composite",
113+
category="mask/compositing",
114+
inputs=[
115+
io.Image.Input("source"),
116+
io.Mask.Input("source_alpha"),
117+
io.Image.Input("destination"),
118+
io.Mask.Input("destination_alpha"),
119+
io.Combo.Input("mode", options=[mode.name for mode in PorterDuffMode], default=PorterDuffMode.DST.name),
120+
],
121+
outputs=[
122+
io.Image.Output(),
123+
io.Mask.Output(),
124+
],
125+
)
126+
105127
@classmethod
106-
def INPUT_TYPES(s):
107-
return {
108-
"required": {
109-
"source": ("IMAGE",),
110-
"source_alpha": ("MASK",),
111-
"destination": ("IMAGE",),
112-
"destination_alpha": ("MASK",),
113-
"mode": ([mode.name for mode in PorterDuffMode], {"default": PorterDuffMode.DST.name}),
114-
},
115-
}
116-
117-
RETURN_TYPES = ("IMAGE", "MASK")
118-
FUNCTION = "composite"
119-
CATEGORY = "mask/compositing"
120-
121-
def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode):
128+
def execute(cls, source: torch.Tensor, source_alpha: torch.Tensor, destination: torch.Tensor, destination_alpha: torch.Tensor, mode) -> io.NodeOutput:
122129
batch_size = min(len(source), len(source_alpha), len(destination), len(destination_alpha))
123130
out_images = []
124131
out_alphas = []
@@ -150,65 +157,67 @@ def composite(self, source: torch.Tensor, source_alpha: torch.Tensor, destinatio
150157
out_images.append(out_image)
151158
out_alphas.append(out_alpha.squeeze(2))
152159

153-
result = (torch.stack(out_images), torch.stack(out_alphas))
154-
return result
160+
return io.NodeOutput(torch.stack(out_images), torch.stack(out_alphas))
155161

156162

157-
class SplitImageWithAlpha:
163+
class SplitImageWithAlpha(io.ComfyNode):
158164
@classmethod
159-
def INPUT_TYPES(s):
160-
return {
161-
"required": {
162-
"image": ("IMAGE",),
163-
}
164-
}
165-
166-
CATEGORY = "mask/compositing"
167-
RETURN_TYPES = ("IMAGE", "MASK")
168-
FUNCTION = "split_image_with_alpha"
169-
170-
def split_image_with_alpha(self, image: torch.Tensor):
165+
def define_schema(cls):
166+
return io.Schema(
167+
node_id="SplitImageWithAlpha",
168+
display_name="Split Image with Alpha",
169+
category="mask/compositing",
170+
inputs=[
171+
io.Image.Input("image"),
172+
],
173+
outputs=[
174+
io.Image.Output(),
175+
io.Mask.Output(),
176+
],
177+
)
178+
179+
@classmethod
180+
def execute(cls, image: torch.Tensor) -> io.NodeOutput:
171181
out_images = [i[:,:,:3] for i in image]
172182
out_alphas = [i[:,:,3] if i.shape[2] > 3 else torch.ones_like(i[:,:,0]) for i in image]
173-
result = (torch.stack(out_images), 1.0 - torch.stack(out_alphas))
174-
return result
183+
return io.NodeOutput(torch.stack(out_images), 1.0 - torch.stack(out_alphas))
184+
175185

186+
class JoinImageWithAlpha(io.ComfyNode):
187+
@classmethod
188+
def define_schema(cls):
189+
return io.Schema(
190+
node_id="JoinImageWithAlpha",
191+
display_name="Join Image with Alpha",
192+
category="mask/compositing",
193+
inputs=[
194+
io.Image.Input("image"),
195+
io.Mask.Input("alpha"),
196+
],
197+
outputs=[io.Image.Output()],
198+
)
176199

177-
class JoinImageWithAlpha:
178200
@classmethod
179-
def INPUT_TYPES(s):
180-
return {
181-
"required": {
182-
"image": ("IMAGE",),
183-
"alpha": ("MASK",),
184-
}
185-
}
186-
187-
CATEGORY = "mask/compositing"
188-
RETURN_TYPES = ("IMAGE",)
189-
FUNCTION = "join_image_with_alpha"
190-
191-
def join_image_with_alpha(self, image: torch.Tensor, alpha: torch.Tensor):
201+
def execute(cls, image: torch.Tensor, alpha: torch.Tensor) -> io.NodeOutput:
192202
batch_size = min(len(image), len(alpha))
193203
out_images = []
194204

195205
alpha = 1.0 - resize_mask(alpha, image.shape[1:])
196206
for i in range(batch_size):
197207
out_images.append(torch.cat((image[i][:,:,:3], alpha[i].unsqueeze(2)), dim=2))
198208

199-
result = (torch.stack(out_images),)
200-
return result
209+
return io.NodeOutput(torch.stack(out_images))
201210

202211

203-
NODE_CLASS_MAPPINGS = {
204-
"PorterDuffImageComposite": PorterDuffImageComposite,
205-
"SplitImageWithAlpha": SplitImageWithAlpha,
206-
"JoinImageWithAlpha": JoinImageWithAlpha,
207-
}
212+
class CompositingExtension(ComfyExtension):
213+
@override
214+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
215+
return [
216+
PorterDuffImageComposite,
217+
SplitImageWithAlpha,
218+
JoinImageWithAlpha,
219+
]
208220

209221

210-
NODE_DISPLAY_NAME_MAPPINGS = {
211-
"PorterDuffImageComposite": "Porter-Duff Image Composite",
212-
"SplitImageWithAlpha": "Split Image with Alpha",
213-
"JoinImageWithAlpha": "Join Image with Alpha",
214-
}
222+
async def comfy_entrypoint() -> CompositingExtension:
223+
return CompositingExtension()

0 commit comments

Comments
 (0)