2 changes: 1 addition & 1 deletion comfy/comfy_types/node_typing.py
@@ -235,7 +235,7 @@ class ComfyNodeABC(ABC):
DEPRECATED: bool
"""Flags a node as deprecated, indicating to users that they should find alternatives to this node."""
API_NODE: Optional[bool]
"""Flags a node as an API node."""
"""Flags a node as an API node. See: https://docs.comfy.org/tutorials/api-nodes/overview."""

@classmethod
@abstractmethod
11 changes: 9 additions & 2 deletions comfy_api_nodes/apinode_utils.py
@@ -1,7 +1,7 @@
from __future__ import annotations
import io
import logging
from typing import Optional
from typing import Optional, Union
from comfy.utils import common_upscale
from comfy_api.input_impl import VideoFromFile
from comfy_api.util import VideoContainer, VideoCodec
@@ -15,6 +15,7 @@
UploadRequest,
UploadResponse,
)
from server import PromptServer


import numpy as np
@@ -60,7 +61,9 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
return s


def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.

Args:
@@ -94,6 +97,10 @@ def validate_and_cast_response(response, timeout: int = None) -> torch.Tensor:
img = Image.open(io.BytesIO(img_data))

elif image_url:
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {image_url}", node_id
)
img_response = requests.get(image_url, timeout=timeout)
if img_response.status_code != 200:
raise ValueError("Failed to download the image")
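Below is a minimal sketch (not part of the diff) of how a node's api_call could pass its id into the updated helper; the `finish_api_node` wrapper name, the 120-second timeout, and the way `response` is produced are assumptions for illustration.

# Hedged sketch: threading a node's unique_id into validate_and_cast_response.
# Assumes it runs inside the ComfyUI repo; the wrapper name and timeout are illustrative.
from comfy_api_nodes.apinode_utils import validate_and_cast_response

def finish_api_node(response, unique_id=None):
    # With node_id set, the helper sends "Result URL: ..." progress text to the UI
    # via PromptServer.instance.send_progress_text before downloading the image.
    return validate_and_cast_response(response, timeout=120, node_id=unique_id)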
44 changes: 41 additions & 3 deletions comfy_api_nodes/apis/client.py
@@ -103,6 +103,7 @@
from pydantic import BaseModel, Field
import uuid # For generating unique operation IDs

from server import PromptServer
from comfy.cli_args import args
from comfy import utils
from . import request_logger
@@ -900,6 +901,7 @@ def __init__(
failed_statuses: list,
status_extractor: Callable[[R], str],
progress_extractor: Callable[[R], float] = None,
result_url_extractor: Callable[[R], str] = None,
request: Optional[T] = None,
api_base: str | None = None,
auth_token: Optional[str] = None,
@@ -910,6 +912,8 @@ def __init__(
max_retries: int = 3, # Max retries per individual API call
retry_delay: float = 1.0,
retry_backoff_factor: float = 2.0,
estimated_duration: Optional[float] = None,
node_id: Optional[str] = None,
):
self.poll_endpoint = poll_endpoint
self.request = request
@@ -924,12 +928,15 @@ def __init__(
self.max_retries = max_retries
self.retry_delay = retry_delay
self.retry_backoff_factor = retry_backoff_factor
self.estimated_duration = estimated_duration

# Polling configuration
self.status_extractor = status_extractor or (
lambda x: getattr(x, "status", None)
)
self.progress_extractor = progress_extractor
self.result_url_extractor = result_url_extractor
self.node_id = node_id
self.completed_statuses = completed_statuses
self.failed_statuses = failed_statuses

@@ -965,6 +972,26 @@ def execute(self, client: Optional[ApiClient] = None) -> R:
except Exception as e:
raise Exception(f"Error during polling: {str(e)}")

def _display_text_on_node(self, text: str):
"""Sends text to the client which will be displayed on the node in the UI"""
if not self.node_id:
return

PromptServer.instance.send_progress_text(text, self.node_id)

def _display_time_progress_on_node(self, time_completed: int):
if not self.node_id:
return

if self.estimated_duration is not None:
estimated_time_remaining = max(
0, int(self.estimated_duration) - int(time_completed)
)
message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)"
else:
message = f"Task in progress: {time_completed:.0f}s"
self._display_text_on_node(message)

def _check_task_status(self, response: R) -> TaskStatus:
"""Check task status using the status extractor function"""
try:
@@ -1031,7 +1058,15 @@ def _poll_until_complete(self, client: ApiClient) -> R:
progress.update_absolute(new_progress, total=PROGRESS_BAR_MAX)

if status == TaskStatus.COMPLETED:
logging.debug("[DEBUG] Task completed successfully")
message = "Task completed successfully"
if self.result_url_extractor:
result_url = self.result_url_extractor(response_obj)
if result_url:
message = f"Result URL: {result_url}"
else:
message = "Task completed successfully!"
logging.debug(f"[DEBUG] {message}")
self._display_text_on_node(message)
self.final_response = response_obj
if self.progress_extractor:
progress.update(100)
@@ -1047,7 +1082,10 @@ def _poll_until_complete(self, client: ApiClient) -> R:
logging.debug(
f"[DEBUG] Waiting {self.poll_interval} seconds before next poll"
)
time.sleep(self.poll_interval)
for i in range(int(self.poll_interval)):
time_completed = (poll_count * self.poll_interval) + i
self._display_time_progress_on_node(time_completed)
time.sleep(1)

except (LocalNetworkError, ApiServerError) as e:
# For network-related errors, increment error count and potentially abort
@@ -1067,7 +1105,7 @@ def _poll_until_complete(self, client: ApiClient) -> R:
except Exception as e:
# For other errors, increment count and potentially abort
consecutive_errors += 1
if consecutive_errors >= max_consecutive_errors:
if consecutive_errors >= max_consecutive_errors or status == TaskStatus.FAILED:
raise Exception(
f"Polling aborted after {consecutive_errors} consecutive errors: {str(e)}"
) from e
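For orientation, here is a hedged sketch of how a node could wire the new PollingOperation parameters added above; the endpoint path, the request/response models, and the `output_url` field are illustrative assumptions, not definitions from this PR.

# Hedged sketch: using result_url_extractor, estimated_duration, and node_id.
# The models, endpoint path, and `output_url` field are assumed for illustration.
from pydantic import BaseModel
from comfy_api_nodes.apis.client import ApiEndpoint, HttpMethod, PollingOperation

class EmptyRequest(BaseModel):
    pass

class TaskStatusResponse(BaseModel):
    status: str
    output_url: str | None = None

def poll_task(task_id: str, unique_id: str | None = None) -> TaskStatusResponse:
    operation = PollingOperation(
        poll_endpoint=ApiEndpoint(
            path=f"/proxy/example/tasks/{task_id}",  # assumed path
            method=HttpMethod.GET,
            request_model=EmptyRequest,
            response_model=TaskStatusResponse,
        ),
        completed_statuses=["succeeded"],
        failed_statuses=["failed", "canceled"],
        status_extractor=lambda r: r.status,
        result_url_extractor=lambda r: r.output_url,  # displayed on the node at completion
        estimated_duration=90,  # enables the "~Ns remaining" progress text
        node_id=unique_id,      # hidden UNIQUE_ID input from the node
    )
    return operation.execute()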
63 changes: 44 additions & 19 deletions comfy_api_nodes/nodes_bfl.py
@@ -1,5 +1,6 @@
import io
from inspect import cleandoc
from typing import Union
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.apis.bfl_api import (
BFLStatus,
@@ -30,6 +31,7 @@
import torch
import base64
import time
from server import PromptServer


def convert_mask_to_image(mask: torch.Tensor):
@@ -42,14 +44,19 @@ def convert_mask_to_image(mask: torch.Tensor):


def handle_bfl_synchronous_operation(
operation: SynchronousOperation, timeout_bfl_calls=360
operation: SynchronousOperation,
timeout_bfl_calls=360,
node_id: Union[str, None] = None,
):
response_api: BFLFluxProGenerateResponse = operation.execute()
return _poll_until_generated(
response_api.polling_url, timeout=timeout_bfl_calls
response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id
)

def _poll_until_generated(polling_url: str, timeout=360):

def _poll_until_generated(
polling_url: str, timeout=360, node_id: Union[str, None] = None
):
# used bfl-comfy-nodes to verify code implementation:
# https://github.com/black-forest-labs/bfl-comfy-nodes/tree/main
start_time = time.time()
@@ -61,11 +68,21 @@ def _poll_until_generated(polling_url: str, timeout=360):
request = requests.Request(method=HttpMethod.GET, url=polling_url)
# NOTE: should True loop be replaced with checking if workflow has been interrupted?
while True:
if node_id:
time_elapsed = time.time() - start_time
PromptServer.instance.send_progress_text(
f"Generating ({time_elapsed:.0f}s)", node_id
)

response = requests.Session().send(request.prepare())
if response.status_code == 200:
result = response.json()
if result["status"] == BFLStatus.ready:
img_url = result["result"]["sample"]
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {img_url}", node_id
)
img_response = requests.get(img_url)
return process_image_response(img_response)
elif result["status"] in [
@@ -180,6 +197,7 @@ def INPUT_TYPES(s):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -212,6 +230,7 @@ def api_call(
seed=0,
image_prompt=None,
image_prompt_strength=0.1,
unique_id: Union[str, None] = None,
**kwargs,
):
if image_prompt is None:
@@ -246,7 +265,7 @@ def api_call(
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)


@@ -320,6 +339,7 @@ def INPUT_TYPES(s):
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -338,6 +358,7 @@ def api_call(
seed=0,
image_prompt=None,
# image_prompt_strength=0.1,
unique_id: Union[str, None] = None,
**kwargs,
):
image_prompt = (
@@ -363,7 +384,7 @@ def api_call(
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)


@@ -457,11 +478,11 @@ def INPUT_TYPES(s):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -483,6 +504,7 @@ def api_call(
steps: int,
guidance: float,
seed=0,
unique_id: Union[str, None] = None,
**kwargs,
):
image = convert_image_to_base64(image)
@@ -508,7 +530,7 @@ def api_call(
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)


@@ -568,11 +590,11 @@ def INPUT_TYPES(s):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -591,13 +613,14 @@ def api_call(
steps: int,
guidance: float,
seed=0,
unique_id: Union[str, None] = None,
**kwargs,
):
# prepare mask
mask = resize_mask_to_image(mask, image)
mask = convert_image_to_base64(convert_mask_to_image(mask))
# make sure image will have alpha channel removed
image = convert_image_to_base64(image[:,:,:,:3])
image = convert_image_to_base64(image[:, :, :, :3])

operation = SynchronousOperation(
endpoint=ApiEndpoint(
Expand All @@ -617,7 +640,7 @@ def api_call(
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)


@@ -702,11 +725,11 @@ def INPUT_TYPES(s):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -727,9 +750,10 @@ def api_call(
steps: int,
guidance: float,
seed=0,
unique_id: Union[str, None] = None,
**kwargs,
):
control_image = convert_image_to_base64(control_image[:,:,:,:3])
control_image = convert_image_to_base64(control_image[:, :, :, :3])
preprocessed_image = None

# scale canny threshold between 0-500, to match BFL's API
@@ -765,7 +789,7 @@ def scale_value(value: float, min_val=0, max_val=500):
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)


@@ -830,11 +854,11 @@ def INPUT_TYPES(s):
},
),
},
"optional": {
},
"optional": {},
"hidden": {
"auth_token": "AUTH_TOKEN_COMFY_ORG",
"comfy_api_key": "API_KEY_COMFY_ORG",
"unique_id": "UNIQUE_ID",
},
}

@@ -853,6 +877,7 @@ def api_call(
steps: int,
guidance: float,
seed=0,
unique_id: Union[str, None] = None,
**kwargs,
):
control_image = convert_image_to_base64(control_image[:,:,:,:3])
@@ -880,7 +905,7 @@ def api_call(
),
auth_kwargs=kwargs,
)
output_image = handle_bfl_synchronous_operation(operation)
output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
return (output_image,)


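To summarize the pattern nodes_bfl.py now repeats, a condensed, hedged sketch: declare a hidden `unique_id` input and forward it to `handle_bfl_synchronous_operation` so elapsed time and the result URL show up on the node. The class name, the prompt input, and the `build_bfl_operation` helper are placeholders, not code from this PR.

# Hedged sketch of the recurring nodes_bfl.py pattern; the class name and the
# build_bfl_operation helper are placeholders for illustration.
from typing import Union
from comfy.comfy_types.node_typing import IO, ComfyNodeABC
from comfy_api_nodes.nodes_bfl import handle_bfl_synchronous_operation

class ExampleBFLNode(ComfyNodeABC):
    RETURN_TYPES = (IO.IMAGE,)
    FUNCTION = "api_call"

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {"prompt": (IO.STRING, {"multiline": True})},
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",  # ComfyUI fills this with the node's id
            },
        }

    def api_call(self, prompt, unique_id: Union[str, None] = None, **kwargs):
        operation = build_bfl_operation(prompt, auth_kwargs=kwargs)  # placeholder helper
        # node_id lets _poll_until_generated report elapsed time and the result URL
        output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id)
        return (output_image,)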