diff --git a/comfy_api_nodes/apinode_utils.py b/comfy_api_nodes/apinode_utils.py index 788e2803f5d6..f953f86df0de 100644 --- a/comfy_api_nodes/apinode_utils.py +++ b/comfy_api_nodes/apinode_utils.py @@ -1,4 +1,5 @@ from __future__ import annotations +import aiohttp import io import logging import mimetypes @@ -21,7 +22,6 @@ import numpy as np from PIL import Image -import requests import torch import math import base64 @@ -30,7 +30,7 @@ import av -def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: +async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile: """Downloads a video from a URL and returns a `VIDEO` output. Args: @@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr Returns: A Comfy node `VIDEO` output. """ - video_io = download_url_to_bytesio(video_url, timeout) + video_io = await download_url_to_bytesio(video_url, timeout) if video_io is None: error_msg = f"Failed to download video from {video_url}" logging.error(error_msg) @@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor: return s -def validate_and_cast_response( +async def validate_and_cast_response( response, timeout: int = None, node_id: Union[str, None] = None ) -> torch.Tensor: """Validates and casts a response to a torch.Tensor. @@ -86,35 +86,24 @@ def validate_and_cast_response( image_tensors: list[torch.Tensor] = [] # Process each image in the data array - for image_data in data: - image_url = image_data.url - b64_data = image_data.b64_json - - if not image_url and not b64_data: - raise ValueError("No image was generated in the response") - - if b64_data: - img_data = base64.b64decode(b64_data) - img = Image.open(io.BytesIO(img_data)) - - elif image_url: - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {image_url}", node_id - ) - img_response = requests.get(image_url, timeout=timeout) - if img_response.status_code != 200: - raise ValueError("Failed to download the image") - img = Image.open(io.BytesIO(img_response.content)) - - img = img.convert("RGBA") - - # Convert to numpy array, normalize to float32 between 0 and 1 - img_array = np.array(img).astype(np.float32) / 255.0 - img_tensor = torch.from_numpy(img_array) - - # Add to list of tensors - image_tensors.append(img_tensor) + async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session: + for img_data in data: + img_bytes: bytes + if img_data.b64_json: + img_bytes = base64.b64decode(img_data.b64_json) + elif img_data.url: + if node_id: + PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id) + async with session.get(img_data.url) as resp: + if resp.status != 200: + raise ValueError("Failed to download generated image") + img_bytes = await resp.read() + else: + raise ValueError("Invalid image payload – neither URL nor base64 data present.") + + pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA") + arr = np.asarray(pil_img).astype(np.float32) / 255.0 + image_tensors.append(torch.from_numpy(arr)) return torch.stack(image_tensors, dim=0) @@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str: return mime_type.split("/")[-1].lower() -def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: +async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: """Downloads content from a URL using requests and returns it as BytesIO. Args: @@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO: Returns: BytesIO object containing the downloaded content. """ - response = requests.get(url, stream=True, timeout=timeout) - response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) - return BytesIO(response.content) + timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None + async with aiohttp.ClientSession(timeout=timeout_cfg) as session: + async with session.get(url) as resp: + resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX) + return BytesIO(await resp.read()) def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor: @@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch return torch.from_numpy(image_array).unsqueeze(0) -def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: +async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor: """Downloads an image from a URL and returns a [B, H, W, C] tensor.""" - image_bytesio = download_url_to_bytesio(url, timeout) + image_bytesio = await download_url_to_bytesio(url, timeout) return bytesio_to_image_tensor(image_bytesio) -def process_image_response(response: requests.Response) -> torch.Tensor: +def process_image_response(response_content: bytes | str) -> torch.Tensor: """Uses content from a Response object and converts it to a torch.Tensor""" - return bytesio_to_image_tensor(BytesIO(response.content)) + return bytesio_to_image_tensor(BytesIO(response_content)) def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image: @@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str: return f"data:{mime_type};base64,{base64_string}" -def upload_file_to_comfyapi( +async def upload_file_to_comfyapi( file_bytes_io: BytesIO, filename: str, - upload_mime_type: str, + upload_mime_type: Optional[str], auth_kwargs: Optional[dict[str, str]] = None, ) -> str: """ @@ -354,7 +345,10 @@ def upload_file_to_comfyapi( Returns: The download URL for the uploaded file. """ - request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) + if upload_mime_type is None: + request_object = UploadRequest(file_name=filename) + else: + request_object = UploadRequest(file_name=filename, content_type=upload_mime_type) operation = SynchronousOperation( endpoint=ApiEndpoint( path="/customers/storage", @@ -366,12 +360,8 @@ def upload_file_to_comfyapi( auth_kwargs=auth_kwargs, ) - response: UploadResponse = operation.execute() - upload_response = ApiClient.upload_file( - response.upload_url, file_bytes_io, content_type=upload_mime_type - ) - upload_response.raise_for_status() - + response: UploadResponse = await operation.execute() + await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type) return response.download_url @@ -399,7 +389,7 @@ def video_to_base64_string( return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8") -def upload_video_to_comfyapi( +async def upload_video_to_comfyapi( video: VideoInput, auth_kwargs: Optional[dict[str, str]] = None, container: VideoContainer = VideoContainer.MP4, @@ -439,9 +429,7 @@ def upload_video_to_comfyapi( video.save_to(video_bytes_io, format=container, codec=codec) video_bytes_io.seek(0) - return upload_file_to_comfyapi( - video_bytes_io, filename, upload_mime_type, auth_kwargs - ) + return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs) def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray: @@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio( return audio_bytes_io -def upload_audio_to_comfyapi( +async def upload_audio_to_comfyapi( audio: AudioInput, auth_kwargs: Optional[dict[str, str]] = None, container_format: str = "mp4", @@ -527,7 +515,7 @@ def upload_audio_to_comfyapi( audio_data_np, sample_rate, container_format, codec_name ) - return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) + return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs) def audio_to_base64_string( @@ -544,7 +532,7 @@ def audio_to_base64_string( return base64.b64encode(audio_bytes).decode("utf-8") -def upload_images_to_comfyapi( +async def upload_images_to_comfyapi( image: torch.Tensor, max_images=8, auth_kwargs: Optional[dict[str, str]] = None, @@ -561,55 +549,15 @@ def upload_images_to_comfyapi( mime_type: Optional MIME type for the image. """ # if batch, try to upload each file if max_images is greater than 0 - idx_image = 0 download_urls: list[str] = [] is_batch = len(image.shape) > 3 - batch_length = 1 - if is_batch: - batch_length = image.shape[0] - while True: - curr_image = image - if len(image.shape) > 3: - curr_image = image[idx_image] - # get BytesIO version of image - img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type) - # first, request upload/download urls from comfy API - if not mime_type: - request_object = UploadRequest(file_name=img_binary.name) - else: - request_object = UploadRequest( - file_name=img_binary.name, content_type=mime_type - ) - operation = SynchronousOperation( - endpoint=ApiEndpoint( - path="/customers/storage", - method=HttpMethod.POST, - request_model=UploadRequest, - response_model=UploadResponse, - ), - request=request_object, - auth_kwargs=auth_kwargs, - ) - response = operation.execute() + batch_len = image.shape[0] if is_batch else 1 - upload_response = ApiClient.upload_file( - response.upload_url, img_binary, content_type=mime_type - ) - # verify success - try: - upload_response.raise_for_status() - except requests.exceptions.HTTPError as e: - raise ValueError(f"Could not upload one or more images: {e}") from e - # add download_url to list - download_urls.append(response.download_url) - - idx_image += 1 - # stop uploading additional files if done - if is_batch and max_images > 0: - if idx_image >= max_images: - break - if idx_image >= batch_length: - break + for idx in range(min(batch_len, max_images)): + tensor = image[idx] if is_batch else image + img_io = tensor_to_bytesio(tensor, mime_type=mime_type) + url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs) + download_urls.append(url) return download_urls diff --git a/comfy_api_nodes/apis/client.py b/comfy_api_nodes/apis/client.py index 2a4bac88b246..4ad0b783bc61 100644 --- a/comfy_api_nodes/apis/client.py +++ b/comfy_api_nodes/apis/client.py @@ -43,7 +43,7 @@ endpoint=user_info_endpoint, request=request ) -user_profile = operation.execute(client=api_client) # Returns immediately with the result +user_profile = await operation.execute(client=api_client) # Returns immediately with the result # Example 2: Asynchronous API Operation with Polling @@ -87,18 +87,19 @@ ) # This will make the initial request and then poll until completion -result = operation.execute(client=api_client) # Returns the final ImageGenerationResult when done +result = await operation.execute(client=api_client) # Returns the final ImageGenerationResult when done """ from __future__ import annotations +import aiohttp +import asyncio import logging -import time import io import socket +from aiohttp.client_exceptions import ClientError, ClientResponseError from typing import Dict, Type, Optional, Any, TypeVar, Generic, Callable, Tuple from enum import Enum import json -import requests from urllib.parse import urljoin, urlparse from pydantic import BaseModel, Field import uuid # For generating unique operation IDs @@ -174,6 +175,7 @@ def __init__( retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, retry_status_codes: Optional[Tuple[int, ...]] = None, + session: Optional[aiohttp.ClientSession] = None, ): self.base_url = base_url self.auth_token = auth_token @@ -186,13 +188,16 @@ def __init__( # Default retry status codes: 408 (Request Timeout), 429 (Too Many Requests), # 500, 502, 503, 504 (Server Errors) self.retry_status_codes = retry_status_codes or (408, 429, 500, 502, 503, 504) + self._session: Optional[aiohttp.ClientSession] = session + self._owns_session = session is None # Track if we have to close it - def _generate_operation_id(self, path: str) -> str: + @staticmethod + def _generate_operation_id(path: str) -> str: """Generates a unique operation ID for logging.""" return f"{path.strip('/').replace('/', '_')}_{uuid.uuid4().hex[:8]}" + @staticmethod def _create_json_payload_args( - self, data: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: @@ -203,31 +208,53 @@ def _create_json_payload_args( def _create_form_data_args( self, - data: Dict[str, Any], - files: Dict[str, Any], + data: Dict[str, Any] | None, + files: Dict[str, Any] | None, headers: Optional[Dict[str, str]] = None, - multipart_parser = None, + multipart_parser: Callable | None = None, ) -> Dict[str, Any]: if headers and "Content-Type" in headers: del headers["Content-Type"] - if multipart_parser: + if multipart_parser and data: data = multipart_parser(data) - return { - "data": data, - "files": files, - "headers": headers, - } + form = aiohttp.FormData(default_to_multipart=True) + if data: # regular text fields + for k, v in data.items(): + if v is None: + continue # aiohttp fails to serialize "None" values + # aiohttp expects strings or bytes; convert enums etc. + form.add_field(k, str(v) if not isinstance(v, (bytes, bytearray)) else v) + + if files: + file_iter = files if isinstance(files, list) else files.items() + for field_name, file_obj in file_iter: + if file_obj is None: + continue # aiohttp fails to serialize "None" values + # file_obj can be (filename, bytes/io.BytesIO, content_type) tuple + if isinstance(file_obj, tuple): + filename, file_value, content_type = self._unpack_tuple(file_obj) + else: + file_value = file_obj + filename = getattr(file_obj, "name", field_name) + content_type = "application/octet-stream" + + form.add_field( + name=field_name, + value=file_value, + filename=filename, + content_type=content_type, + ) + return {"data": form, "headers": headers or {}} + @staticmethod def _create_urlencoded_form_data_args( - self, data: Dict[str, Any], headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: headers = headers or {} headers["Content-Type"] = "application/x-www-form-urlencoded" - return { "data": data, "headers": headers, @@ -244,7 +271,7 @@ def get_headers(self) -> Dict[str, str]: return headers - def _check_connectivity(self, target_url: str) -> Dict[str, bool]: + async def _check_connectivity(self, target_url: str) -> Dict[str, bool]: """ Check connectivity to determine if network issues are local or server-related. @@ -258,52 +285,39 @@ def _check_connectivity(self, target_url: str) -> Dict[str, bool]: "internet_accessible": False, "api_accessible": False, "is_local_issue": False, - "is_api_issue": False + "is_api_issue": False, } + timeout = aiohttp.ClientTimeout(total=5.0) + async with aiohttp.ClientSession(timeout=timeout) as session: + try: + async with session.get("https://www.google.com", ssl=self.verify_ssl) as resp: + results["internet_accessible"] = resp.status < 500 + except (ClientError, asyncio.TimeoutError, socket.gaierror): + results["is_local_issue"] = True + return results # cannot reach the internet – early exit + + # Now check API health endpoint + parsed = urlparse(target_url) + health_url = f"{parsed.scheme}://{parsed.netloc}/health" + try: + async with session.get(health_url, ssl=self.verify_ssl) as resp: + results["api_accessible"] = resp.status < 500 + except ClientError: + pass # leave as False - # First check basic internet connectivity using a reliable external site - try: - # Use a reliable external domain for checking basic connectivity - check_response = requests.get("https://www.google.com", - timeout=5.0, - verify=self.verify_ssl) - if check_response.status_code < 500: - results["internet_accessible"] = True - except (requests.RequestException, socket.error): - results["internet_accessible"] = False - results["is_local_issue"] = True - return results - - # Now check API server connectivity - try: - # Extract domain from the target URL to do a simpler health check - parsed_url = urlparse(target_url) - api_base = f"{parsed_url.scheme}://{parsed_url.netloc}" - - # Try to reach the API domain - api_response = requests.get(f"{api_base}/health", timeout=5.0, verify=self.verify_ssl) - if api_response.status_code < 500: - results["api_accessible"] = True - else: - results["api_accessible"] = False - results["is_api_issue"] = True - except requests.RequestException: - results["api_accessible"] = False - # If we can reach the internet but not the API, it's an API issue - results["is_api_issue"] = True - + results["is_api_issue"] = results["internet_accessible"] and not results["api_accessible"] return results - def request( + async def request( self, method: str, path: str, params: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None, - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, headers: Optional[Dict[str, str]] = None, content_type: str = "application/json", - multipart_parser: Callable = None, + multipart_parser: Callable | None = None, retry_count: int = 0, # Used internally for tracking retries ) -> Dict[str, Any]: """ @@ -327,18 +341,19 @@ def request( ApiServerError: If the API server is unreachable but internet is working Exception: For other request failures """ - # Use urljoin but ensure path is relative to avoid absolute path behavior - relative_path = path.lstrip('/') + + # Build full URL and merge headers + relative_path = path.lstrip("/") url = urljoin(self.base_url, relative_path) - self.check_auth(self.auth_token, self.comfy_api_key) - # Combine default headers with any provided headers + self._check_auth(self.auth_token, self.comfy_api_key) + request_headers = self.get_headers() if headers: request_headers.update(headers) - - # Let requests handle the content type when files are present. if files: - del request_headers["Content-Type"] + request_headers.pop("Content-Type", None) + if params: + params = {k: v for k, v in params.items() if v is not None} # aiohttp fails to serialize None values logging.debug(f"[DEBUG] Request Headers: {request_headers}") logging.debug(f"[DEBUG] Files: {files}") @@ -346,11 +361,9 @@ def request( logging.debug(f"[DEBUG] Data: {data}") if content_type == "application/x-www-form-urlencoded": - payload_args = self._create_urlencoded_form_data_args(data, request_headers) + payload_args = self._create_urlencoded_form_data_args(data or {}, request_headers) elif content_type == "multipart/form-data": - payload_args = self._create_form_data_args( - data, files, request_headers, multipart_parser - ) + payload_args = self._create_form_data_args(data, files, request_headers, multipart_parser) else: payload_args = self._create_json_payload_args(data, request_headers) @@ -361,220 +374,67 @@ def request( request_url=url, request_headers=request_headers, request_params=params, - request_data=data if content_type == "application/json" else "[form-data or other]" + request_data=data if content_type == "application/json" else "[form-data or other]", ) + session = await self._get_session() try: - response = requests.request( - method=method, - url=url, + async with session.request( + method, + url, params=params, - timeout=self.timeout, - verify=self.verify_ssl, + ssl=self.verify_ssl, **payload_args, - ) - - # Check if we should retry based on status code - if (response.status_code in self.retry_status_codes and - retry_count < self.max_retries): - - # Calculate delay with exponential backoff - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - - logging.warning( - f"Request failed with status {response.status_code}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - - # Raise exception for error status codes - response.raise_for_status() - - # Log successful response - response_content_to_log = response.content - try: - # Attempt to parse JSON for prettier logging, fallback to raw content - response_content_to_log = response.json() - except json.JSONDecodeError: - pass # Keep as bytes/str if not JSON - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, # Pass request details again for context in log - request_url=url, - response_status_code=response.status_code, - response_headers=dict(response.headers), - response_content=response_content_to_log - ) + ) as resp: + if resp.status >= 400: + try: + error_data = await resp.json() + except (aiohttp.ContentTypeError, json.JSONDecodeError): + error_data = await resp.text() + + return await self._handle_http_error( + ClientResponseError(resp.request_info, resp.history, status=resp.status, message=error_data), + operation_id, + method, + url, + params, + data, + files, + headers, + content_type, + multipart_parser, + retry_count=retry_count, + response_content=error_data, + ) - except requests.ConnectionError as e: - error_message = f"ConnectionError: {str(e)}" - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, - request_url=url, - error_message=error_message - ) - # Only perform connectivity check if we've exhausted all retries - if retry_count >= self.max_retries: - # Check connectivity to determine if it's a local or API issue - connectivity = self._check_connectivity(self.base_url) - - if connectivity["is_local_issue"]: - raise LocalNetworkError( - "Unable to connect to the API server due to local network issues. " - "Please check your internet connection and try again." - ) from e - elif connectivity["is_api_issue"]: - raise ApiServerError( - f"The API server at {self.base_url} is currently unreachable. " - f"The service may be experiencing issues. Please try again later." - ) from e + # Success – parse JSON (safely) and log + try: + payload = await resp.json() + response_content_to_log = payload + except (aiohttp.ContentTypeError, json.JSONDecodeError): + payload = {} + response_content_to_log = await resp.text() - # If we haven't exhausted retries yet, retry the request - if retry_count < self.max_retries: - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"Connection error: {str(e)}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, + request_logger.log_request_response( + operation_id=operation_id, + request_method=method, + request_url=url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content=response_content_to_log, ) + return payload - # If we've exhausted retries and didn't identify the specific issue, - # raise a generic exception - final_error_message = ( - f"Unable to connect to the API server after {self.max_retries} attempts. " - f"Please check your internet connection or try again later." - ) - request_logger.log_request_response( # Log final failure - operation_id=operation_id, - request_method=method, request_url=url, - error_message=final_error_message - ) - raise Exception(final_error_message) from e - - except requests.Timeout as e: - error_message = f"Timeout: {str(e)}" - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, request_url=url, - error_message=error_message - ) - # Retry timeouts if we haven't exhausted retries + except (ClientError, asyncio.TimeoutError, socket.gaierror) as e: + # Treat as *connection* problem – optionally retry, else escalate if retry_count < self.max_retries: delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"Request timed out. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, - params=params, - data=data, - files=files, - headers=headers, - content_type=content_type, - multipart_parser=multipart_parser, - retry_count=retry_count + 1, - ) - final_error_message = ( - f"Request timed out after {self.timeout} seconds and {self.max_retries} retry attempts. " - f"The server might be experiencing high load or the operation is taking longer than expected." - ) - request_logger.log_request_response( # Log final failure - operation_id=operation_id, - request_method=method, request_url=url, - error_message=final_error_message - ) - raise Exception(final_error_message) from e - - except requests.HTTPError as e: - status_code = e.response.status_code if hasattr(e, "response") else None - original_error_message = f"HTTP Error: {str(e)}" - error_content_for_log = None - if hasattr(e, "response") and e.response is not None: - error_content_for_log = e.response.content - try: - error_content_for_log = e.response.json() - except json.JSONDecodeError: - pass - - - # Try to extract detailed error message from JSON response for user display - # but log the full error content. - user_display_error_message = original_error_message - - try: - if hasattr(e, "response") and e.response is not None and e.response.content: - error_json = e.response.json() - if "error" in error_json and "message" in error_json["error"]: - user_display_error_message = f"API Error: {error_json['error']['message']}" - if "type" in error_json["error"]: - user_display_error_message += f" (Type: {error_json['error']['type']})" - elif isinstance(error_json, dict): # Handle cases where error is just a JSON dict - user_display_error_message = f"API Error: {json.dumps(error_json)}" - else: # Non-dict JSON error - user_display_error_message = f"API Error: {str(error_json)}" - except json.JSONDecodeError: - # If not JSON, use the raw content if it's not too long, or a summary - if hasattr(e, "response") and e.response is not None and e.response.content: - raw_content = e.response.content.decode(errors='ignore') - if len(raw_content) < 200: # Arbitrary limit for display - user_display_error_message = f"API Error (raw): {raw_content}" - else: - user_display_error_message = f"API Error (raw, status {status_code})" - - request_logger.log_request_response( - operation_id=operation_id, - request_method=method, request_url=url, - response_status_code=status_code, - response_headers=dict(e.response.headers) if hasattr(e, "response") and e.response is not None else None, - response_content=error_content_for_log, - error_message=original_error_message # Log the original exception string as error - ) - - logging.debug(f"[DEBUG] API Error: {user_display_error_message} (Status: {status_code})") - if hasattr(e, "response") and e.response is not None and e.response.content: - logging.debug(f"[DEBUG] Response content: {e.response.content}") - - # Retry if the status code is in our retry list and we haven't exhausted retries - if (status_code in self.retry_status_codes and - retry_count < self.max_retries): - - delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) - logging.warning( - f"HTTP error {status_code}. " - f"Retrying in {delay:.2f}s ({retry_count + 1}/{self.max_retries})" - ) - time.sleep(delay) - return self.request( - method=method, - path=path, + logging.warning("Connection error. Retrying in %.2fs (%s/%s): %s", delay, retry_count + 1, + self.max_retries, str(e)) + await asyncio.sleep(delay) + return await self.request( + method, + path, params=params, data=data, files=files, @@ -583,40 +443,34 @@ def request( multipart_parser=multipart_parser, retry_count=retry_count + 1, ) + # One final connectivity check for diagnostics + connectivity = await self._check_connectivity(self.base_url) + if connectivity["is_local_issue"]: + raise LocalNetworkError( + "Unable to connect to the API server due to local network issues. " + "Please check your internet connection and try again." + ) from e + raise ApiServerError( + f"The API server at {self.base_url} is currently unreachable. " + f"The service may be experiencing issues. Please try again later." + ) from e - # Specific error messages for common status codes for user display - if status_code == 401: - user_display_error_message = "Unauthorized: Please login first to use this node." - elif status_code == 402: - user_display_error_message = "Payment Required: Please add credits to your account to use this node." - elif status_code == 409: - user_display_error_message = "There is a problem with your account. Please contact support@comfy.org." - elif status_code == 429: - user_display_error_message = "Rate Limit Exceeded: Please try again later." - # else, user_display_error_message remains as parsed from response or original HTTPError string - - raise Exception(user_display_error_message) # Raise with the user-friendly message - - # Parse and return JSON response - if response.content: - return response.json() - return {} - - def check_auth(self, auth_token, comfy_api_key): + @staticmethod + def _check_auth(auth_token, comfy_api_key): """Verify that an auth token is present or comfy_api_key is present""" if auth_token is None and comfy_api_key is None: raise Exception("Unauthorized: Please login first to use this node.") return auth_token or comfy_api_key @staticmethod - def upload_file( + async def upload_file( upload_url: str, file: io.BytesIO | str, content_type: str | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, - ): + ) -> aiohttp.ClientResponse: """Upload a file to the API with retry logic. Args: @@ -627,112 +481,167 @@ def upload_file( retry_delay: Initial delay between retries in seconds retry_backoff_factor: Multiplier for the delay after each retry """ - headers = {} + headers: Dict[str, str] = {} + skip_auto_headers: set[str] = set() if content_type: headers["Content-Type"] = content_type + else: + # tell aiohttp not to add Content-Type that will break the request signature and result in a 403 status. + skip_auto_headers.add("Content-Type") - # Prepare the file data + # Extract file bytes if isinstance(file, io.BytesIO): - file.seek(0) # Ensure we're at the start of the file + file.seek(0) data = file.read() elif isinstance(file, str): with open(file, "rb") as f: data = f.read() else: - raise ValueError("File must be either a BytesIO object or a file path string") - - # Try the upload with retries - last_exception = None - operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" # Simplified ID for uploads + raise ValueError("File must be BytesIO or str path") - # Log initial attempt (without full file data for brevity) + operation_id = f"upload_{upload_url.split('/')[-1]}_{uuid.uuid4().hex[:8]}" request_logger.log_request_response( operation_id=operation_id, request_method="PUT", request_url=upload_url, request_headers=headers, - request_data=f"[File data of type {content_type or 'unknown'}, size {len(data)} bytes]" + request_data=f"[File data {len(data)} bytes]", ) - for retry_attempt in range(max_retries + 1): + delay = retry_delay + for attempt in range(max_retries + 1): try: - response = requests.put(upload_url, data=data, headers=headers) - response.raise_for_status() - request_logger.log_request_response( - operation_id=operation_id, - request_method="PUT", request_url=upload_url, # For context - response_status_code=response.status_code, - response_headers=dict(response.headers), - response_content="File uploaded successfully." # Or response.text if available - ) - return response - - except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e: - last_exception = e - error_message_for_log = f"{type(e).__name__}: {str(e)}" - response_content_for_log = None - status_code_for_log = None - headers_for_log = None - - if hasattr(e, 'response') and e.response is not None: - status_code_for_log = e.response.status_code - headers_for_log = dict(e.response.headers) - try: - response_content_for_log = e.response.json() - except json.JSONDecodeError: - response_content_for_log = e.response.content - - + timeout = aiohttp.ClientTimeout(total=None) # honour server side timeouts + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.put( + upload_url, data=data, headers=headers, skip_auto_headers=skip_auto_headers, + ) as resp: + resp.raise_for_status() + request_logger.log_request_response( + operation_id=operation_id, + request_method="PUT", + request_url=upload_url, + response_status_code=resp.status, + response_headers=dict(resp.headers), + response_content="File uploaded successfully.", + ) + return resp + except (ClientError, asyncio.TimeoutError) as e: request_logger.log_request_response( operation_id=operation_id, - request_method="PUT", request_url=upload_url, - response_status_code=status_code_for_log, - response_headers=headers_for_log, - response_content=response_content_for_log, - error_message=error_message_for_log + request_method="PUT", + request_url=upload_url, + response_status_code=e.status if hasattr(e, "status") else None, + response_headers=dict(e.headers) if getattr(e, "headers") else None, + response_content=None, + error_message=f"{type(e).__name__}: {str(e)}", ) - - if retry_attempt < max_retries: - delay = retry_delay * (retry_backoff_factor ** retry_attempt) + if attempt < max_retries: logging.warning( - f"File upload failed: {str(e)}. " - f"Retrying in {delay:.2f}s ({retry_attempt + 1}/{max_retries})" + "Upload failed (%s/%s). Retrying in %.2fs. %s", attempt + 1, max_retries, delay, str(e) ) - time.sleep(delay) + await asyncio.sleep(delay) + delay *= retry_backoff_factor else: - break # Max retries reached + raise NetworkError(f"Failed to upload file after {max_retries + 1} attempts: {e}") from e - # If we've exhausted all retries, determine the final error type and raise - final_error_message = f"Failed to upload file after {max_retries + 1} attempts. Error: {str(last_exception)}" - try: - # Check basic internet connectivity - check_response = requests.get("https://www.google.com", timeout=5.0, verify=True) # Assuming verify=True is desired - if check_response.status_code >= 500: # Google itself has an issue (rare) - final_error_message = (f"Failed to upload file. Internet connectivity check to Google failed " - f"(status {check_response.status_code}). Original error: {str(last_exception)}") - # Not raising LocalNetworkError here as Google itself might be down. - # If Google is reachable, the issue is likely with the upload server or a more specific local problem - # not caught by a simple Google ping (e.g., DNS for the specific upload URL, firewall). - # The original last_exception is probably most relevant. - - except (requests.RequestException, socket.error) as conn_check_exc: - # Could not reach Google, likely a local network issue - final_error_message = (f"Failed to upload file due to network connectivity issues " - f"(cannot reach Google: {str(conn_check_exc)}). " - f"Original upload error: {str(last_exception)}") - request_logger.log_request_response( # Log final failure reason - operation_id=operation_id, - request_method="PUT", request_url=upload_url, - error_message=final_error_message - ) - raise LocalNetworkError(final_error_message) from last_exception + async def _handle_http_error( + self, + exc: ClientResponseError, + operation_id: str, + *req_meta, + retry_count: int, + response_content: dict | str = "", + ) -> Dict[str, Any]: + status_code = exc.status + if status_code == 401: + user_friendly = "Unauthorized: Please login first to use this node." + elif status_code == 402: + user_friendly = "Payment Required: Please add credits to your account to use this node." + elif status_code == 409: + user_friendly = "There is a problem with your account. Please contact support@comfy.org." + elif status_code == 429: + user_friendly = "Rate Limit Exceeded: Please try again later." + else: + if isinstance(response_content, dict): + if "error" in response_content and "message" in response_content["error"]: + user_friendly = f"API Error: {response_content['error']['message']}" + if "type" in response_content["error"]: + user_friendly += f" (Type: {response_content['error']['type']})" + else: # Handle cases where error is just a JSON dict with unknown format + user_friendly = f"API Error: {json.dumps(response_content)}" + else: + if len(response_content) < 200: # Arbitrary limit for display + user_friendly = f"API Error (raw): {response_content}" + else: + user_friendly = f"API Error (raw, status {response_content})" - request_logger.log_request_response( # Log final failure reason if not LocalNetworkError + request_logger.log_request_response( operation_id=operation_id, - request_method="PUT", request_url=upload_url, - error_message=final_error_message + request_method=req_meta[0], + request_url=req_meta[1], + response_status_code=exc.status, + response_headers=dict(req_meta[5]) if req_meta[5] else None, + response_content=response_content, + error_message=f"HTTP Error {exc.status}", ) - raise Exception(final_error_message) from last_exception + + logging.debug(f"[DEBUG] API Error: {user_friendly} (Status: {status_code})") + if response_content: + logging.debug(f"[DEBUG] Response content: {response_content}") + + # Retry if eligible + if status_code in self.retry_status_codes and retry_count < self.max_retries: + delay = self.retry_delay * (self.retry_backoff_factor ** retry_count) + logging.warning( + "HTTP error %s. Retrying in %.2fs (%s/%s)", + status_code, + delay, + retry_count + 1, + self.max_retries, + ) + await asyncio.sleep(delay) + return await self.request( + req_meta[0], # method + req_meta[1].replace(self.base_url, ""), # path + params=req_meta[2], + data=req_meta[3], + files=req_meta[4], + headers=req_meta[5], + content_type=req_meta[6], + multipart_parser=req_meta[7], + retry_count=retry_count + 1, + ) + + raise Exception(user_friendly) from exc + + @staticmethod + def _unpack_tuple(t): + """Helper to normalise (filename, file, content_type) tuples.""" + if len(t) == 3: + return t + elif len(t) == 2: + return t[0], t[1], "application/octet-stream" + else: + raise ValueError("files tuple must be (filename, file[, content_type])") + + async def _get_session(self) -> aiohttp.ClientSession: + if self._session is None or self._session.closed: + timeout = aiohttp.ClientTimeout(total=self.timeout) + self._session = aiohttp.ClientSession(timeout=timeout) + self._owns_session = True + return self._session + + async def close(self) -> None: + if self._owns_session and self._session and not self._session.closed: + await self._session.close() + + async def __aenter__(self) -> "ApiClient": + """Allow usage as async‑context‑manager – ensures clean teardown""" + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.close() class ApiEndpoint(Generic[T, R]): @@ -763,31 +672,28 @@ def __init__( class SynchronousOperation(Generic[T, R]): - """ - Represents a single synchronous API operation. - """ + """Represents a single synchronous API operation.""" def __init__( self, endpoint: ApiEndpoint[T, R], request: T, - files: Optional[Dict[str, Any]] = None, + files: Optional[Dict[str, Any] | list[tuple[str, Any]]] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str,str]] = None, + auth_kwargs: Optional[Dict[str, str]] = None, timeout: float = 604800.0, verify_ssl: bool = True, content_type: str = "application/json", - multipart_parser: Callable = None, + multipart_parser: Callable | None = None, max_retries: int = 3, retry_delay: float = 1.0, retry_backoff_factor: float = 2.0, - ): + ) -> None: self.endpoint = endpoint self.request = request - self.response = None - self.error = None + self.files = files self.api_base: str = api_base or args.comfy_api_base self.auth_token = auth_token self.comfy_api_key = comfy_api_key @@ -796,91 +702,64 @@ def __init__( self.comfy_api_key = auth_kwargs.get("comfy_api_key", self.comfy_api_key) self.timeout = timeout self.verify_ssl = verify_ssl - self.files = files self.content_type = content_type self.multipart_parser = multipart_parser self.max_retries = max_retries self.retry_delay = retry_delay self.retry_backoff_factor = retry_backoff_factor - def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the API operation using the provided client or create one with retry support""" - try: - # Create client if not provided - if client is None: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - timeout=self.timeout, - verify_ssl=self.verify_ssl, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - - # Convert request model to dict, but use None for EmptyRequest - request_dict = ( - None - if isinstance(self.request, EmptyRequest) - else self.request.model_dump(exclude_none=True) + async def execute(self, client: Optional[ApiClient] = None) -> R: + owns_client = client is None + if owns_client: + client = ApiClient( + base_url=self.api_base, + auth_token=self.auth_token, + comfy_api_key=self.comfy_api_key, + timeout=self.timeout, + verify_ssl=self.verify_ssl, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, ) - if request_dict: - for key, value in request_dict.items(): - if isinstance(value, Enum): - request_dict[key] = value.value - # Debug log for request + try: + request_dict: Optional[Dict[str, Any]] + if isinstance(self.request, EmptyRequest): + request_dict = None + else: + request_dict = self.request.model_dump(exclude_none=True) + for k, v in list(request_dict.items()): + if isinstance(v, Enum): + request_dict[k] = v.value + logging.debug( f"[DEBUG] API Request: {self.endpoint.method.value} {self.endpoint.path}" ) logging.debug(f"[DEBUG] Request Data: {json.dumps(request_dict, indent=2)}") logging.debug(f"[DEBUG] Query Params: {self.endpoint.query_params}") - # Make the request with built-in retry - resp = client.request( - method=self.endpoint.method.value, - path=self.endpoint.path, - data=request_dict, + response_json = await client.request( + self.endpoint.method.value, + self.endpoint.path, params=self.endpoint.query_params, + data=request_dict, files=self.files, content_type=self.content_type, - multipart_parser=self.multipart_parser + multipart_parser=self.multipart_parser, ) - # Debug log for response logging.debug("=" * 50) logging.debug("[DEBUG] RESPONSE DETAILS:") logging.debug("[DEBUG] Status Code: 200 (Success)") - logging.debug(f"[DEBUG] Response Body: {json.dumps(resp, indent=2)}") + logging.debug(f"[DEBUG] Response Body: {json.dumps(response_json, indent=2)}") logging.debug("=" * 50) - # Parse and return the response - return self._parse_response(resp) - - except LocalNetworkError as e: - # Propagate specific network error types - logging.error(f"[ERROR] Local network error: {str(e)}") - raise - - except ApiServerError as e: - # Propagate API server errors - logging.error(f"[ERROR] API server error: {str(e)}") - raise - - except Exception as e: - logging.error(f"[ERROR] API Exception: {str(e)}") - raise Exception(str(e)) - - def _parse_response(self, resp): - """Parse response data - can be overridden by subclasses""" - # The response is already the complete object, don't extract just the "data" field - # as that would lose the outer structure (created timestamp, etc.) - - # Parse response using the provided model - self.response = self.endpoint.response_model.model_validate(resp) - logging.debug(f"[DEBUG] Parsed Response: {self.response}") - return self.response + parsed_response = self.endpoint.response_model.model_validate(response_json) + logging.debug(f"[DEBUG] Parsed Response: {parsed_response}") + return parsed_response + finally: + if owns_client: + await client.close() class TaskStatus(str, Enum): @@ -892,23 +771,21 @@ class TaskStatus(str, Enum): class PollingOperation(Generic[T, R]): - """ - Represents an asynchronous API operation that requires polling for completion. - """ + """Represents an asynchronous API operation that requires polling for completion.""" def __init__( self, poll_endpoint: ApiEndpoint[EmptyRequest, R], - completed_statuses: list, - failed_statuses: list, + completed_statuses: list[str], + failed_statuses: list[str], status_extractor: Callable[[R], str], - progress_extractor: Callable[[R], float] = None, - result_url_extractor: Callable[[R], str] = None, + progress_extractor: Callable[[R], float] | None = None, + result_url_extractor: Callable[[R], str] | None = None, request: Optional[T] = None, api_base: str | None = None, auth_token: Optional[str] = None, comfy_api_key: Optional[str] = None, - auth_kwargs: Optional[Dict[str,str]] = None, + auth_kwargs: Optional[Dict[str, str]] = None, poll_interval: float = 5.0, max_poll_attempts: int = 120, # Default max polling attempts (10 minutes with 5s interval) max_retries: int = 3, # Max retries per individual API call @@ -916,7 +793,7 @@ def __init__( retry_backoff_factor: float = 2.0, estimated_duration: Optional[float] = None, node_id: Optional[str] = None, - ): + ) -> None: self.poll_endpoint = poll_endpoint self.request = request self.api_base: str = api_base or args.comfy_api_base @@ -931,100 +808,73 @@ def __init__( self.retry_delay = retry_delay self.retry_backoff_factor = retry_backoff_factor self.estimated_duration = estimated_duration - - # Polling configuration - self.status_extractor = status_extractor or ( - lambda x: getattr(x, "status", None) - ) + self.status_extractor = status_extractor or (lambda x: getattr(x, "status", None)) self.progress_extractor = progress_extractor self.result_url_extractor = result_url_extractor self.node_id = node_id self.completed_statuses = completed_statuses self.failed_statuses = failed_statuses - - # For storing response data - self.final_response = None - self.error = None - - def execute(self, client: Optional[ApiClient] = None) -> R: - """Execute the polling operation using the provided client. If failed, raise an exception.""" + self.final_response: Optional[R] = None + + async def execute(self, client: Optional[ApiClient] = None) -> R: + owns_client = client is None + if owns_client: + client = ApiClient( + base_url=self.api_base, + auth_token=self.auth_token, + comfy_api_key=self.comfy_api_key, + max_retries=self.max_retries, + retry_delay=self.retry_delay, + retry_backoff_factor=self.retry_backoff_factor, + ) try: - if client is None: - client = ApiClient( - base_url=self.api_base, - auth_token=self.auth_token, - comfy_api_key=self.comfy_api_key, - max_retries=self.max_retries, - retry_delay=self.retry_delay, - retry_backoff_factor=self.retry_backoff_factor, - ) - return self._poll_until_complete(client) - except LocalNetworkError as e: - # Provide clear message for local network issues - raise Exception( - f"Polling failed due to local network issues. Please check your internet connection. " - f"Details: {str(e)}" - ) from e - except ApiServerError as e: - # Provide clear message for API server issues - raise Exception( - f"Polling failed due to API server issues. The service may be experiencing problems. " - f"Please try again later. Details: {str(e)}" - ) from e - except Exception as e: - raise Exception(f"Error during polling: {str(e)}") + return await self._poll_until_complete(client) + finally: + if owns_client: + await client.close() def _display_text_on_node(self, text: str): - """Sends text to the client which will be displayed on the node in the UI""" if not self.node_id: return - PromptServer.instance.send_progress_text(text, self.node_id) - def _display_time_progress_on_node(self, time_completed: int): + def _display_time_progress_on_node(self, time_completed: int | float): if not self.node_id: return - if self.estimated_duration is not None: - estimated_time_remaining = max( - 0, int(self.estimated_duration) - int(time_completed) - ) - message = f"Task in progress: {time_completed:.0f}s (~{estimated_time_remaining:.0f}s remaining)" + remaining = max(0, int(self.estimated_duration) - time_completed) + message = f"Task in progress: {time_completed}s (~{remaining}s remaining)" else: - message = f"Task in progress: {time_completed:.0f}s" + message = f"Task in progress: {time_completed}s" self._display_text_on_node(message) def _check_task_status(self, response: R) -> TaskStatus: - """Check task status using the status extractor function""" try: status = self.status_extractor(response) if status in self.completed_statuses: return TaskStatus.COMPLETED - elif status in self.failed_statuses: + if status in self.failed_statuses: return TaskStatus.FAILED return TaskStatus.PENDING except Exception as e: - logging.error(f"Error extracting status: {e}") + logging.error("Error extracting status: %s", e) return TaskStatus.PENDING - def _poll_until_complete(self, client: ApiClient) -> R: + async def _poll_until_complete(self, client: ApiClient) -> R: """Poll until the task is complete""" - poll_count = 0 consecutive_errors = 0 max_consecutive_errors = min(5, self.max_retries * 2) # Limit consecutive errors if self.progress_extractor: progress = utils.ProgressBar(PROGRESS_BAR_MAX) - while poll_count < self.max_poll_attempts: + status = TaskStatus.PENDING + for poll_count in range(1, self.max_poll_attempts + 1): try: - poll_count += 1 logging.debug(f"[DEBUG] Polling attempt #{poll_count}") request_dict = ( - self.request.model_dump(exclude_none=True) - if self.request is not None - else None + None if self.request is None else self.request.model_dump(exclude_none=True) ) if poll_count == 1: @@ -1036,18 +886,14 @@ def _poll_until_complete(self, client: ApiClient) -> R: ) # Query task status - resp = client.request( - method=self.poll_endpoint.method.value, - path=self.poll_endpoint.path, + resp = await client.request( + self.poll_endpoint.method.value, + self.poll_endpoint.path, params=self.poll_endpoint.query_params, data=request_dict, ) - - # Successfully got a response, reset consecutive error count - consecutive_errors = 0 - - # Parse response - response_obj = self.poll_endpoint.response_model.model_validate(resp) + consecutive_errors = 0 # reset on success + response_obj: R = self.poll_endpoint.response_model.model_validate(resp) # Check if task is complete status = self._check_task_status(response_obj) @@ -1065,45 +911,30 @@ def _poll_until_complete(self, client: ApiClient) -> R: result_url = self.result_url_extractor(response_obj) if result_url: message = f"Result URL: {result_url}" - else: - message = "Task completed successfully!" logging.debug(f"[DEBUG] {message}") self._display_text_on_node(message) self.final_response = response_obj if self.progress_extractor: progress.update(100) return self.final_response - elif status == TaskStatus.FAILED: + if status == TaskStatus.FAILED: message = f"Task failed: {json.dumps(resp)}" logging.error(f"[DEBUG] {message}") raise Exception(message) - else: - logging.debug("[DEBUG] Task still pending, continuing to poll...") - - # Wait before polling again - logging.debug( - f"[DEBUG] Waiting {self.poll_interval} seconds before next poll" - ) + logging.debug("[DEBUG] Task still pending, continuing to poll...") + # Task pending – wait for i in range(int(self.poll_interval)): - time_completed = (poll_count * self.poll_interval) + i - self._display_time_progress_on_node(time_completed) - time.sleep(1) + self._display_time_progress_on_node((poll_count - 1) * self.poll_interval + i) + await asyncio.sleep(1) - except (LocalNetworkError, ApiServerError) as e: - # For network-related errors, increment error count and potentially abort + except (LocalNetworkError, ApiServerError, NetworkError) as e: consecutive_errors += 1 if consecutive_errors >= max_consecutive_errors: raise Exception( - f"Polling aborted after {consecutive_errors} consecutive network errors: {str(e)}" + f"Polling aborted after {consecutive_errors} network errors: {str(e)}" ) from e - - # Log the error but continue polling - logging.warning( - f"Network error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " - f"Will retry in {self.poll_interval} seconds." - ) - time.sleep(self.poll_interval) - + logging.warning("Network error (%s/%s): %s", consecutive_errors, max_consecutive_errors, str(e)) + await asyncio.sleep(self.poll_interval) except Exception as e: # For other errors, increment count and potentially abort consecutive_errors += 1 @@ -1117,10 +948,10 @@ def _poll_until_complete(self, client: ApiClient) -> R: f"Error during polling (attempt {poll_count}/{self.max_poll_attempts}): {str(e)}. " f"Will retry in {self.poll_interval} seconds." ) - time.sleep(self.poll_interval) + await asyncio.sleep(self.poll_interval) # If we've exhausted all polling attempts raise Exception( - f"Polling timed out after {poll_count} attempts ({poll_count * self.poll_interval} seconds). " - f"The operation may still be running on the server but is taking longer than expected." + f"Polling timed out after {self.max_poll_attempts} attempts (" f"{self.max_poll_attempts * self.poll_interval} seconds). " + "The operation may still be running on the server but is taking longer than expected." ) diff --git a/comfy_api_nodes/nodes_bfl.py b/comfy_api_nodes/nodes_bfl.py index d93fbd77822b..c09be8d5bf22 100644 --- a/comfy_api_nodes/nodes_bfl.py +++ b/comfy_api_nodes/nodes_bfl.py @@ -1,3 +1,4 @@ +import asyncio import io from inspect import cleandoc from typing import Union, Optional @@ -28,7 +29,7 @@ import numpy as np from PIL import Image -import requests +import aiohttp import torch import base64 import time @@ -44,18 +45,18 @@ def convert_mask_to_image(mask: torch.Tensor): return mask -def handle_bfl_synchronous_operation( +async def handle_bfl_synchronous_operation( operation: SynchronousOperation, timeout_bfl_calls=360, node_id: Union[str, None] = None, ): - response_api: BFLFluxProGenerateResponse = operation.execute() - return _poll_until_generated( + response_api: BFLFluxProGenerateResponse = await operation.execute() + return await _poll_until_generated( response_api.polling_url, timeout=timeout_bfl_calls, node_id=node_id ) -def _poll_until_generated( +async def _poll_until_generated( polling_url: str, timeout=360, node_id: Union[str, None] = None ): # used bfl-comfy-nodes to verify code implementation: @@ -66,55 +67,56 @@ def _poll_until_generated( retry_404_seconds = 2 retry_202_seconds = 2 retry_pending_seconds = 1 - request = requests.Request(method=HttpMethod.GET, url=polling_url) - # NOTE: should True loop be replaced with checking if workflow has been interrupted? - while True: - if node_id: - time_elapsed = time.time() - start_time - PromptServer.instance.send_progress_text( - f"Generating ({time_elapsed:.0f}s)", node_id - ) - response = requests.Session().send(request.prepare()) - if response.status_code == 200: - result = response.json() - if result["status"] == BFLStatus.ready: - img_url = result["result"]["sample"] - if node_id: - PromptServer.instance.send_progress_text( - f"Result URL: {img_url}", node_id - ) - img_response = requests.get(img_url) - return process_image_response(img_response) - elif result["status"] in [ - BFLStatus.request_moderated, - BFLStatus.content_moderated, - ]: - status = result["status"] - raise Exception( - f"BFL API did not return an image due to: {status}." + async with aiohttp.ClientSession() as session: + # NOTE: should True loop be replaced with checking if workflow has been interrupted? + while True: + if node_id: + time_elapsed = time.time() - start_time + PromptServer.instance.send_progress_text( + f"Generating ({time_elapsed:.0f}s)", node_id ) - elif result["status"] == BFLStatus.error: - raise Exception(f"BFL API encountered an error: {result}.") - elif result["status"] == BFLStatus.pending: - time.sleep(retry_pending_seconds) - continue - elif response.status_code == 404: - if retries_404 < max_retries_404: - retries_404 += 1 - time.sleep(retry_404_seconds) - continue - raise Exception( - f"BFL API could not find task after {max_retries_404} tries." - ) - elif response.status_code == 202: - time.sleep(retry_202_seconds) - elif time.time() - start_time > timeout: - raise Exception( - f"BFL API experienced a timeout; could not return request under {timeout} seconds." - ) - else: - raise Exception(f"BFL API encountered an error: {response.json()}") + + async with session.get(polling_url) as response: + if response.status == 200: + result = await response.json() + if result["status"] == BFLStatus.ready: + img_url = result["result"]["sample"] + if node_id: + PromptServer.instance.send_progress_text( + f"Result URL: {img_url}", node_id + ) + async with session.get(img_url) as img_resp: + return process_image_response(await img_resp.content.read()) + elif result["status"] in [ + BFLStatus.request_moderated, + BFLStatus.content_moderated, + ]: + status = result["status"] + raise Exception( + f"BFL API did not return an image due to: {status}." + ) + elif result["status"] == BFLStatus.error: + raise Exception(f"BFL API encountered an error: {result}.") + elif result["status"] == BFLStatus.pending: + await asyncio.sleep(retry_pending_seconds) + continue + elif response.status == 404: + if retries_404 < max_retries_404: + retries_404 += 1 + await asyncio.sleep(retry_404_seconds) + continue + raise Exception( + f"BFL API could not find task after {max_retries_404} tries." + ) + elif response.status == 202: + await asyncio.sleep(retry_202_seconds) + elif time.time() - start_time > timeout: + raise Exception( + f"BFL API experienced a timeout; could not return request under {timeout} seconds." + ) + else: + raise Exception(f"BFL API encountered an error: {response.json()}") def convert_image_to_base64(image: torch.Tensor): scaled_image = downscale_image_tensor(image, total_pixels=2048 * 2048) @@ -222,7 +224,7 @@ def VALIDATE_INPUTS(cls, aspect_ratio: str): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -266,7 +268,7 @@ def api_call( ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -354,7 +356,7 @@ def INPUT_TYPES(s): BFL_PATH = "/proxy/bfl/flux-kontext-pro/generate" - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -397,7 +399,7 @@ def api_call( ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -489,7 +491,7 @@ def INPUT_TYPES(s): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, prompt: str, prompt_upsampling, @@ -524,7 +526,7 @@ def api_call( ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -632,7 +634,7 @@ def INPUT_TYPES(s): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -670,7 +672,7 @@ def api_call( ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -744,7 +746,7 @@ def INPUT_TYPES(s): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, image: torch.Tensor, mask: torch.Tensor, @@ -780,7 +782,7 @@ def api_call( ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -879,7 +881,7 @@ def INPUT_TYPES(s): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, control_image: torch.Tensor, prompt: str, @@ -929,7 +931,7 @@ def scale_value(value: float, min_val=0, max_val=500): ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) @@ -1008,7 +1010,7 @@ def INPUT_TYPES(s): API_NODE = True CATEGORY = "api node/image/BFL" - def api_call( + async def api_call( self, control_image: torch.Tensor, prompt: str, @@ -1045,7 +1047,7 @@ def api_call( ), auth_kwargs=kwargs, ) - output_image = handle_bfl_synchronous_operation(operation, node_id=unique_id) + output_image = await handle_bfl_synchronous_operation(operation, node_id=unique_id) return (output_image,) diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index af33279d525a..3751fb2a1e2e 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -303,7 +303,7 @@ def create_text_part(self, text: str) -> GeminiPart: """ return GeminiPart(text=text) - def api_call( + async def api_call( self, prompt: str, model: GeminiModel, @@ -332,7 +332,7 @@ def api_call( parts.extend(files) # Create response - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=get_gemini_endpoint(model), request=GeminiGenerateContentRequest( contents=[ diff --git a/comfy_api_nodes/nodes_ideogram.py b/comfy_api_nodes/nodes_ideogram.py index b8487355fadb..db24e6da430c 100644 --- a/comfy_api_nodes/nodes_ideogram.py +++ b/comfy_api_nodes/nodes_ideogram.py @@ -212,7 +212,7 @@ "1536x640" ] -def download_and_process_images(image_urls): +async def download_and_process_images(image_urls): """Helper function to download and process multiple images from URLs""" # Initialize list to store image tensors @@ -220,7 +220,7 @@ def download_and_process_images(image_urls): for image_url in image_urls: # Using functions from apinode_utils.py to handle downloading and processing - image_bytesio = download_url_to_bytesio(image_url) # Download image content to BytesIO + image_bytesio = await download_url_to_bytesio(image_url) # Download image content to BytesIO img_tensor = bytesio_to_image_tensor(image_bytesio, mode="RGB") # Convert to torch.Tensor with RGB mode image_tensors.append(img_tensor) @@ -328,7 +328,7 @@ def INPUT_TYPES(cls) -> InputTypeDict: DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, turbo=False, @@ -367,7 +367,7 @@ def api_call( auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -378,7 +378,7 @@ def api_call( raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) class IdeogramV2(ComfyNodeABC): @@ -487,7 +487,7 @@ def INPUT_TYPES(cls) -> InputTypeDict: DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, turbo=False, @@ -543,7 +543,7 @@ def api_call( auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -554,7 +554,7 @@ def api_call( raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) class IdeogramV3(ComfyNodeABC): """ @@ -653,7 +653,7 @@ def INPUT_TYPES(cls) -> InputTypeDict: DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, image=None, @@ -774,7 +774,7 @@ def api_call( ) # Execute the operation and process response - response = operation.execute() + response = await operation.execute() if not response.data or len(response.data) == 0: raise Exception("No images were generated in the response") @@ -785,7 +785,7 @@ def api_call( raise Exception("No image URLs were generated in the response") display_image_urls_on_node(image_urls, unique_id) - return (download_and_process_images(image_urls),) + return (await download_and_process_images(image_urls),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 69e9e5cf0657..9d9eb5628a42 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -109,7 +109,7 @@ class KlingApiError(Exception): pass -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, R], result_url_extractor: Optional[Callable[[R], str]] = None, @@ -117,7 +117,7 @@ def poll_until_finished( node_id: Optional[str] = None, ) -> R: """Polls the Kling API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ KlingTaskStatus.succeed.value, @@ -278,18 +278,18 @@ def get_images_urls_from_response(response) -> Optional[str]: return None -def video_result_to_node_output( +async def video_result_to_node_output( video: KlingVideoResult, ) -> tuple[VideoFromFile, str, str]: """Converts a KlingVideoResult to a tuple of (VideoFromFile, str, str) to be used as a ComfyUI node output.""" return ( - download_url_to_video_output(video.url), + await download_url_to_video_output(str(video.url)), str(video.id), str(video.duration), ) -def image_result_to_node_output( +async def image_result_to_node_output( images: list[KlingImageResult], ) -> torch.Tensor: """ @@ -297,9 +297,9 @@ def image_result_to_node_output( If multiple images are returned, they will be stacked along the batch dimension. """ if len(images) == 1: - return download_url_to_image_tensor(images[0].url) + return await download_url_to_image_tensor(str(images[0].url)) else: - return torch.cat([download_url_to_image_tensor(image.url) for image in images]) + return torch.cat([await download_url_to_image_tensor(str(image.url)) for image in images]) class KlingNodeBase(ComfyNodeABC): @@ -467,10 +467,10 @@ def INPUT_TYPES(s): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Text to Video Node" - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingText2VideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_TEXT_TO_VIDEO}/{task_id}", @@ -483,7 +483,7 @@ def get_response( node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -519,17 +519,17 @@ def api_call( auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingCameraControlT2VNode(KlingTextToVideoNode): @@ -581,7 +581,7 @@ def INPUT_TYPES(s): DESCRIPTION = "Transform text into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original text." - def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -591,7 +591,7 @@ def api_call( unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( model_name=KlingVideoGenModelName.kling_v1, cfg_scale=cfg_scale, mode=KlingVideoGenMode.std, @@ -670,10 +670,10 @@ def INPUT_TYPES(s): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Image to Video Node" - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingImage2VideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_IMAGE_TO_VIDEO}/{task_id}", @@ -686,7 +686,7 @@ def get_response( node_id=node_id, ) - def api_call( + async def api_call( self, start_frame: torch.Tensor, prompt: str, @@ -733,17 +733,17 @@ def api_call( auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingCameraControlI2VNode(KlingImage2VideoNode): @@ -798,7 +798,7 @@ def INPUT_TYPES(s): DESCRIPTION = "Transform still images into cinematic videos with professional camera movements that simulate real-world cinematography. Control virtual camera actions including zoom, rotation, pan, tilt, and first-person view, while maintaining focus on your original image." - def api_call( + async def api_call( self, start_frame: torch.Tensor, prompt: str, @@ -809,7 +809,7 @@ def api_call( unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( model_name=KlingVideoGenModelName.kling_v1_5, start_frame=start_frame, cfg_scale=cfg_scale, @@ -897,7 +897,7 @@ def INPUT_TYPES(s): DESCRIPTION = "Generate a video sequence that transitions between your provided start and end images. The node creates all frames in between, producing a smooth transformation from the first frame to the last." - def api_call( + async def api_call( self, start_frame: torch.Tensor, end_frame: torch.Tensor, @@ -912,7 +912,7 @@ def api_call( mode, duration, model_name = KlingStartEndFrameNode.get_mode_string_mapping()[ mode ] - return super().api_call( + return await super().api_call( prompt=prompt, negative_prompt=negative_prompt, model_name=model_name, @@ -964,10 +964,10 @@ def INPUT_TYPES(s): RETURN_NAMES = ("VIDEO", "video_id", "duration") DESCRIPTION = "Kling Video Extend Node. Extend videos made by other Kling nodes. The video_id is created by using other Kling Nodes." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVideoExtendResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIDEO_EXTEND}/{task_id}", @@ -980,7 +980,7 @@ def get_response( node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, negative_prompt: str, @@ -1006,17 +1006,17 @@ def api_call( auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingVideoEffectsBase(KlingNodeBase): @@ -1025,10 +1025,10 @@ class KlingVideoEffectsBase(KlingNodeBase): RETURN_TYPES = ("VIDEO", "STRING", "STRING") RETURN_NAMES = ("VIDEO", "video_id", "duration") - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVideoEffectsResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIDEO_EFFECTS}/{task_id}", @@ -1041,7 +1041,7 @@ def get_response( node_id=node_id, ) - def api_call( + async def api_call( self, dual_character: bool, effect_scene: KlingDualCharacterEffectsScene | KlingSingleImageEffectsScene, @@ -1084,17 +1084,17 @@ def api_call( auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingDualCharacterVideoEffectNode(KlingVideoEffectsBase): @@ -1142,7 +1142,7 @@ def INPUT_TYPES(s): RETURN_TYPES = ("VIDEO", "STRING") RETURN_NAMES = ("VIDEO", "duration") - def api_call( + async def api_call( self, image_left: torch.Tensor, image_right: torch.Tensor, @@ -1153,7 +1153,7 @@ def api_call( unique_id: Optional[str] = None, **kwargs, ): - video, _, duration = super().api_call( + video, _, duration = await super().api_call( dual_character=True, effect_scene=effect_scene, model_name=model_name, @@ -1208,7 +1208,7 @@ def INPUT_TYPES(s): DESCRIPTION = "Achieve different special effects when generating a video based on the effect_scene." - def api_call( + async def api_call( self, image: torch.Tensor, effect_scene: KlingSingleImageEffectsScene, @@ -1217,7 +1217,7 @@ def api_call( unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( dual_character=False, effect_scene=effect_scene, model_name=model_name, @@ -1253,11 +1253,11 @@ def validate_text(self, text: str): f"Text is too long. Maximum length is {MAX_PROMPT_LENGTH_LIP_SYNC} characters." ) - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingLipSyncResponse: """Polls the Kling API endpoint until the task reaches a terminal state.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_LIP_SYNC}/{task_id}", @@ -1270,7 +1270,7 @@ def get_response( node_id=node_id, ) - def api_call( + async def api_call( self, video: VideoInput, audio: Optional[AudioInput] = None, @@ -1287,12 +1287,12 @@ def api_call( self.validate_lip_sync_video(video) # Upload video to Comfy API and get download URL - video_url = upload_video_to_comfyapi(video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi(video, auth_kwargs=kwargs) logging.info("Uploaded video to Comfy API. URL: %s", video_url) # Upload the audio file to Comfy API and get download URL if audio: - audio_url = upload_audio_to_comfyapi(audio, auth_kwargs=kwargs) + audio_url = await upload_audio_to_comfyapi(audio, auth_kwargs=kwargs) logging.info("Uploaded audio to Comfy API. URL: %s", audio_url) else: audio_url = None @@ -1319,17 +1319,17 @@ def api_call( auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_video_result_response(final_response) video = get_video_from_response(final_response) - return video_result_to_node_output(video) + return await video_result_to_node_output(video) class KlingLipSyncAudioToVideoNode(KlingLipSyncBase): @@ -1357,7 +1357,7 @@ def INPUT_TYPES(s): DESCRIPTION = "Kling Lip Sync Audio to Video Node. Syncs mouth movements in a video file to the audio content of an audio file. When using, ensure that the audio contains clearly distinguishable vocals and that the video contains a distinct face. The audio file should not be larger than 5MB. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - def api_call( + async def api_call( self, video: VideoInput, audio: AudioInput, @@ -1365,7 +1365,7 @@ def api_call( unique_id: Optional[str] = None, **kwargs, ): - return super().api_call( + return await super().api_call( video=video, audio=audio, voice_language=voice_language, @@ -1469,7 +1469,7 @@ def INPUT_TYPES(s): DESCRIPTION = "Kling Lip Sync Text to Video Node. Syncs mouth movements in a video file to a text prompt. The video file should not be larger than 100MB, should have height/width between 720px and 1920px, and should be between 2s and 10s in length." - def api_call( + async def api_call( self, video: VideoInput, text: str, @@ -1479,7 +1479,7 @@ def api_call( **kwargs, ): voice_id, voice_language = KlingLipSyncTextToVideoNode.get_voice_config()[voice] - return super().api_call( + return await super().api_call( video=video, text=text, voice_language=voice_language, @@ -1533,10 +1533,10 @@ def INPUT_TYPES(s): DESCRIPTION = "Kling Virtual Try On Node. Input a human image and a cloth image to try on the cloth on the human. You can merge multiple clothing item pictures into one image with a white background." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> KlingVirtualTryOnResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_VIRTUAL_TRY_ON}/{task_id}", @@ -1549,7 +1549,7 @@ def get_response( node_id=node_id, ) - def api_call( + async def api_call( self, human_image: torch.Tensor, cloth_image: torch.Tensor, @@ -1572,17 +1572,17 @@ def api_call( auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (image_result_to_node_output(images),) + return (await image_result_to_node_output(images),) class KlingImageGenerationNode(KlingImageGenerationBase): @@ -1655,13 +1655,13 @@ def INPUT_TYPES(s): DESCRIPTION = "Kling Image Generation Node. Generate an image from a text prompt with an optional reference image." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: Optional[dict[str, str]], node_id: Optional[str] = None, ) -> KlingImageGenerationsResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_IMAGE_GENERATIONS}/{task_id}", @@ -1674,7 +1674,7 @@ def get_response( node_id=node_id, ) - def api_call( + async def api_call( self, model_name: KlingImageGenModelName, prompt: str, @@ -1714,17 +1714,17 @@ def api_call( auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.data.task_id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) validate_image_result_response(final_response) images = get_images_from_response(final_response) - return (image_result_to_node_output(images),) + return (await image_result_to_node_output(images),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_luma.py b/comfy_api_nodes/nodes_luma.py index 525dc38e628b..b3c32bed55c5 100644 --- a/comfy_api_nodes/nodes_luma.py +++ b/comfy_api_nodes/nodes_luma.py @@ -38,7 +38,7 @@ ) from server import PromptServer -import requests +import aiohttp import torch from io import BytesIO @@ -217,7 +217,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -234,19 +234,19 @@ def api_call( # handle image_luma_ref api_image_ref = None if image_luma_ref is not None: - api_image_ref = self._convert_luma_refs( + api_image_ref = await self._convert_luma_refs( image_luma_ref, max_refs=4, auth_kwargs=kwargs, ) # handle style_luma_ref api_style_ref = None if style_image is not None: - api_style_ref = self._convert_style_image( + api_style_ref = await self._convert_style_image( style_image, weight=style_image_weight, auth_kwargs=kwargs, ) # handle character_ref images character_ref = None if character_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( character_image, max_images=4, auth_kwargs=kwargs, ) character_ref = LumaCharacterRef( @@ -270,7 +270,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -286,19 +286,20 @@ def api_call( node_id=unique_id, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - img_response = requests.get(response_poll.assets.image) - img = process_image_response(img_response) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.image) as img_response: + img = process_image_response(await img_response.content.read()) return (img,) - def _convert_luma_refs( + async def _convert_luma_refs( self, luma_ref: LumaReferenceChain, max_refs: int, auth_kwargs: Optional[dict[str,str]] = None ): luma_urls = [] ref_count = 0 for ref in luma_ref.refs: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( ref.image, max_images=1, auth_kwargs=auth_kwargs ) luma_urls.append(download_urls[0]) @@ -307,13 +308,13 @@ def _convert_luma_refs( break return luma_ref.create_api_model(download_urls=luma_urls, max_refs=max_refs) - def _convert_style_image( + async def _convert_style_image( self, style_image: torch.Tensor, weight: float, auth_kwargs: Optional[dict[str,str]] = None ): chain = LumaReferenceChain( first_ref=LumaReference(image=style_image, weight=weight) ) - return self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) + return await self._convert_luma_refs(chain, max_refs=1, auth_kwargs=auth_kwargs) class LumaImageModifyNode(ComfyNodeABC): @@ -370,7 +371,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -381,7 +382,7 @@ def api_call( **kwargs, ): # first, upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( image, max_images=1, auth_kwargs=kwargs, ) image_url = download_urls[0] @@ -402,7 +403,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -418,10 +419,11 @@ def api_call( node_id=unique_id, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - img_response = requests.get(response_poll.assets.image) - img = process_image_response(img_response) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.image) as img_response: + img = process_image_response(await img_response.content.read()) return (img,) @@ -494,7 +496,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -529,7 +531,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() if unique_id: PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) @@ -549,10 +551,11 @@ def api_call( estimated_duration=LUMA_T2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.assets.video) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.video) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class LumaImageToVideoGenerationNode(ComfyNodeABC): @@ -626,7 +629,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, model: str, @@ -644,7 +647,7 @@ def api_call( raise Exception( "At least one of first_image and last_image requires an input." ) - keyframes = self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) + keyframes = await self._convert_to_keyframes(first_image, last_image, auth_kwargs=kwargs) duration = duration if model != LumaVideoModel.ray_1_6 else None resolution = resolution if model != LumaVideoModel.ray_1_6 else None @@ -667,7 +670,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response_api: LumaGeneration = operation.execute() + response_api: LumaGeneration = await operation.execute() if unique_id: PromptServer.instance.send_progress_text(f"Luma video generation started: {response_api.id}", unique_id) @@ -687,12 +690,13 @@ def api_call( estimated_duration=LUMA_I2V_AVERAGE_DURATION, auth_kwargs=kwargs, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.assets.video) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.assets.video) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) - def _convert_to_keyframes( + async def _convert_to_keyframes( self, first_image: torch.Tensor = None, last_image: torch.Tensor = None, @@ -703,12 +707,12 @@ def _convert_to_keyframes( frame0 = None frame1 = None if first_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( first_image, max_images=1, auth_kwargs=auth_kwargs, ) frame0 = LumaImageReference(type="image", url=download_urls[0]) if last_image is not None: - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( last_image, max_images=1, auth_kwargs=auth_kwargs, ) frame1 = LumaImageReference(type="image", url=download_urls[0]) diff --git a/comfy_api_nodes/nodes_minimax.py b/comfy_api_nodes/nodes_minimax.py index 9b46636dbbd1..58d2ed90c45c 100644 --- a/comfy_api_nodes/nodes_minimax.py +++ b/comfy_api_nodes/nodes_minimax.py @@ -86,7 +86,7 @@ def INPUT_TYPES(s): API_NODE = True OUTPUT_NODE = True - def generate_video( + async def generate_video( self, prompt_text, seed=0, @@ -104,12 +104,12 @@ def generate_video( # upload image, if passed in image_url = None if image is not None: - image_url = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs)[0] + image_url = (await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs))[0] # TODO: figure out how to deal with subject properly, API returns invalid params when using S2V-01 model subject_reference = None if subject is not None: - subject_url = upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs)[0] + subject_url = (await upload_images_to_comfyapi(subject, max_images=1, auth_kwargs=kwargs))[0] subject_reference = [SubjectReferenceItem(image=subject_url)] @@ -130,7 +130,7 @@ def generate_video( ), auth_kwargs=kwargs, ) - response = video_generate_operation.execute() + response = await video_generate_operation.execute() task_id = response.task_id if not task_id: @@ -151,7 +151,7 @@ def generate_video( node_id=unique_id, auth_kwargs=kwargs, ) - task_result = video_generate_operation.execute() + task_result = await video_generate_operation.execute() file_id = task_result.file_id if file_id is None: @@ -167,7 +167,7 @@ def generate_video( request=EmptyRequest(), auth_kwargs=kwargs, ) - file_result = file_retrieve_operation.execute() + file_result = await file_retrieve_operation.execute() file_url = file_result.file.download_url if file_url is None: @@ -182,7 +182,7 @@ def generate_video( message = f"Result URL: {file_url}" PromptServer.instance.send_progress_text(message, unique_id) - video_io = download_url_to_bytesio(file_url) + video_io = await download_url_to_bytesio(file_url) if video_io is None: error_msg = f"Failed to download video from {file_url}" logging.error(error_msg) diff --git a/comfy_api_nodes/nodes_moonvalley.py b/comfy_api_nodes/nodes_moonvalley.py index 789fcef02864..164ca3ea56c7 100644 --- a/comfy_api_nodes/nodes_moonvalley.py +++ b/comfy_api_nodes/nodes_moonvalley.py @@ -95,14 +95,14 @@ def get_video_url_from_response(response) -> Optional[str]: return None -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, R], result_url_extractor: Optional[Callable[[R], str]] = None, node_id: Optional[str] = None, ) -> R: """Polls the Moonvalley API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ "completed", @@ -394,10 +394,10 @@ def parseControlParameter(self, value): else: return control_map["Motion Transfer"] - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> MoonvalleyPromptResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{API_PROMPTS_ENDPOINT}/{task_id}", @@ -507,7 +507,7 @@ def INPUT_TYPES(cls): RETURN_NAMES = ("video",) DESCRIPTION = "Moonvalley Marey Image to Video Node" - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): image = kwargs.get("image", None) @@ -532,9 +532,9 @@ def generate( # Get MIME type from tensor - assuming PNG format for image tensors mime_type = "image/png" - image_url = upload_images_to_comfyapi( + image_url = (await upload_images_to_comfyapi( image, max_images=1, auth_kwargs=kwargs, mime_type=mime_type - )[0] + ))[0] request = MoonvalleyTextToVideoRequest( image_url=image_url, prompt_text=prompt, inference_params=inference_params @@ -549,14 +549,14 @@ def generate( request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) @@ -609,7 +609,7 @@ def INPUT_TYPES(cls): RETURN_TYPES = ("VIDEO",) RETURN_NAMES = ("video",) - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): video = kwargs.get("video") @@ -620,7 +620,7 @@ def generate( video_url = "" if video: validated_video = validate_video_to_video_input(video) - video_url = upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) + video_url = await upload_video_to_comfyapi(validated_video, auth_kwargs=kwargs) control_type = kwargs.get("control_type") motion_intensity = kwargs.get("motion_intensity") @@ -658,15 +658,15 @@ def generate( request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) @@ -688,7 +688,7 @@ def INPUT_TYPES(cls): del input_types["optional"][param] return input_types - def generate( + async def generate( self, prompt, negative_prompt, unique_id: Optional[str] = None, **kwargs ): validate_prompts(prompt, negative_prompt, MOONVALLEY_MAREY_MAX_PROMPT_LENGTH) @@ -717,15 +717,15 @@ def generate( request=request, auth_kwargs=kwargs, ) - task_creation_response = initial_operation.execute() + task_creation_response = await initial_operation.execute() validate_task_creation_response(task_creation_response) task_id = task_creation_response.id - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) - video = download_url_to_video_output(final_response.output_url) + video = await download_url_to_video_output(final_response.output_url) return (video,) diff --git a/comfy_api_nodes/nodes_openai.py b/comfy_api_nodes/nodes_openai.py index be1d2de4aeb8..ab3c5363bad7 100644 --- a/comfy_api_nodes/nodes_openai.py +++ b/comfy_api_nodes/nodes_openai.py @@ -163,7 +163,7 @@ def INPUT_TYPES(cls) -> InputTypeDict: DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -233,9 +233,9 @@ def api_call( auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -311,7 +311,7 @@ def INPUT_TYPES(cls) -> InputTypeDict: DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -343,9 +343,9 @@ def api_call( auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -446,7 +446,7 @@ def INPUT_TYPES(cls) -> InputTypeDict: DESCRIPTION = cleandoc(__doc__ or "") API_NODE = True - def api_call( + async def api_call( self, prompt, seed=0, @@ -537,9 +537,9 @@ def api_call( auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() - img_tensor = validate_and_cast_response(response, node_id=unique_id) + img_tensor = await validate_and_cast_response(response, node_id=unique_id) return (img_tensor,) @@ -623,7 +623,7 @@ def INPUT_TYPES(cls) -> InputTypeDict: DESCRIPTION = "Generate text responses from an OpenAI model." - def get_result_response( + async def get_result_response( self, response_id: str, include: Optional[list[Includable]] = None, @@ -639,7 +639,7 @@ def get_result_response( creation above for more information. """ - return PollingOperation( + return await PollingOperation( poll_endpoint=ApiEndpoint( path=f"{RESPONSES_ENDPOINT}/{response_id}", method=HttpMethod.GET, @@ -784,7 +784,7 @@ def delete_history_after_response_id( self.history[session_id] = new_history - def api_call( + async def api_call( self, prompt: str, persist_context: bool, @@ -815,7 +815,7 @@ def api_call( previous_response_id = None # Create response - create_response = SynchronousOperation( + create_response = await SynchronousOperation( endpoint=ApiEndpoint( path=RESPONSES_ENDPOINT, method=HttpMethod.POST, @@ -848,7 +848,7 @@ def api_call( response_id = create_response.id # Get result output - result_response = self.get_result_response(response_id, auth_kwargs=kwargs) + result_response = await self.get_result_response(response_id, auth_kwargs=kwargs) output_text = self.parse_output_text_from_response(result_response) # Update history diff --git a/comfy_api_nodes/nodes_pika.py b/comfy_api_nodes/nodes_pika.py index 1cc7085645eb..a8dc43cb3ba2 100644 --- a/comfy_api_nodes/nodes_pika.py +++ b/comfy_api_nodes/nodes_pika.py @@ -122,7 +122,7 @@ def get_base_inputs_types( FUNCTION = "api_call" RETURN_TYPES = ("VIDEO",) - def poll_for_task_status( + async def poll_for_task_status( self, task_id: str, auth_kwargs: Optional[dict[str, str]] = None, @@ -152,9 +152,9 @@ def poll_for_task_status( node_id=node_id, estimated_duration=60 ) - return polling_operation.execute() + return await polling_operation.execute() - def execute_task( + async def execute_task( self, initial_operation: SynchronousOperation[R, PikaGenerateResponse], auth_kwargs: Optional[dict[str, str]] = None, @@ -169,14 +169,14 @@ def execute_task( Returns: A tuple containing the video file as a VIDEO output. """ - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() if not is_valid_initial_response(initial_response): error_msg = f"Pika initial request failed. Code: {initial_response.code}, Message: {initial_response.message}, Data: {initial_response.data}" logging.error(error_msg) raise PikaApiError(error_msg) task_id = initial_response.video_id - final_response = self.poll_for_task_status(task_id, auth_kwargs) + final_response = await self.poll_for_task_status(task_id, auth_kwargs) if not is_valid_video_response(final_response): error_msg = ( f"Pika task {task_id} succeeded but no video data found in response." @@ -187,7 +187,7 @@ def execute_task( video_url = str(final_response.url) logging.info("Pika task %s succeeded. Video URL: %s", task_id, video_url) - return (download_url_to_video_output(video_url),) + return (await download_url_to_video_output(video_url),) class PikaImageToVideoV2_2(PikaNodeBase): @@ -212,7 +212,7 @@ def INPUT_TYPES(cls): DESCRIPTION = "Sends an image and prompt to the Pika API v2.2 to generate a video." - def api_call( + async def api_call( self, image: torch.Tensor, prompt_text: str, @@ -251,7 +251,7 @@ def api_call( auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaTextToVideoNodeV2_2(PikaNodeBase): @@ -281,7 +281,7 @@ def INPUT_TYPES(cls): DESCRIPTION = "Sends a text prompt to the Pika API v2.2 to generate a video." - def api_call( + async def api_call( self, prompt_text: str, negative_prompt: str, @@ -311,7 +311,7 @@ def api_call( content_type="application/x-www-form-urlencoded", ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaScenesV2_2(PikaNodeBase): @@ -361,7 +361,7 @@ def INPUT_TYPES(cls): DESCRIPTION = "Combine your images to create a video with the objects in them. Upload multiple images as ingredients and generate a high-quality video that incorporates all of them." - def api_call( + async def api_call( self, prompt_text: str, negative_prompt: str, @@ -420,7 +420,7 @@ def api_call( auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikAdditionsNode(PikaNodeBase): @@ -462,7 +462,7 @@ def INPUT_TYPES(cls): DESCRIPTION = "Add any object or image into your video. Upload a video and specify what you'd like to add to create a seamlessly integrated result." - def api_call( + async def api_call( self, video: VideoInput, image: torch.Tensor, @@ -481,10 +481,10 @@ def api_call( image_bytes_io = tensor_to_bytesio(image) image_bytes_io.seek(0) - pika_files = [ - ("video", ("video.mp4", video_bytes_io, "video/mp4")), - ("image", ("image.png", image_bytes_io, "image/png")), - ] + pika_files = { + "video": ("video.mp4", video_bytes_io, "video/mp4"), + "image": ("image.png", image_bytes_io, "image/png"), + } # Prepare non-file data pika_request_data = PikaBodyGeneratePikadditionsGeneratePikadditionsPost( @@ -506,7 +506,7 @@ def api_call( auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaSwapsNode(PikaNodeBase): @@ -558,7 +558,7 @@ def INPUT_TYPES(cls): DESCRIPTION = "Swap out any object or region of your video with a new image or object. Define areas to replace either with a mask or coordinates." RETURN_TYPES = ("VIDEO",) - def api_call( + async def api_call( self, video: VideoInput, image: torch.Tensor, @@ -587,11 +587,11 @@ def api_call( image_bytes_io = tensor_to_bytesio(image) image_bytes_io.seek(0) - pika_files = [ - ("video", ("video.mp4", video_bytes_io, "video/mp4")), - ("image", ("image.png", image_bytes_io, "image/png")), - ("modifyRegionMask", ("mask.png", mask_bytes_io, "image/png")), - ] + pika_files = { + "video": ("video.mp4", video_bytes_io, "video/mp4"), + "image": ("image.png", image_bytes_io, "image/png"), + "modifyRegionMask": ("mask.png", mask_bytes_io, "image/png"), + } # Prepare non-file data pika_request_data = PikaBodyGeneratePikaswapsGeneratePikaswapsPost( @@ -613,7 +613,7 @@ def api_call( auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaffectsNode(PikaNodeBase): @@ -664,7 +664,7 @@ def INPUT_TYPES(cls): DESCRIPTION = "Generate a video with a specific Pikaffect. Supported Pikaffects: Cake-ify, Crumble, Crush, Decapitate, Deflate, Dissolve, Explode, Eye-pop, Inflate, Levitate, Melt, Peel, Poke, Squish, Ta-da, Tear" - def api_call( + async def api_call( self, image: torch.Tensor, pikaffect: str, @@ -693,7 +693,7 @@ def api_call( auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) class PikaStartEndFrameNode2_2(PikaNodeBase): @@ -718,7 +718,7 @@ def INPUT_TYPES(cls): DESCRIPTION = "Generate a video by combining your first and last frame. Upload two images to define the start and end points, and let the AI create a smooth transition between them." - def api_call( + async def api_call( self, image_start: torch.Tensor, image_end: torch.Tensor, @@ -732,10 +732,7 @@ def api_call( ) -> tuple[VideoFromFile]: pika_files = [ - ( - "keyFrames", - ("image_start.png", tensor_to_bytesio(image_start), "image/png"), - ), + ("keyFrames", ("image_start.png", tensor_to_bytesio(image_start), "image/png")), ("keyFrames", ("image_end.png", tensor_to_bytesio(image_end), "image/png")), ] @@ -758,7 +755,7 @@ def api_call( auth_kwargs=kwargs, ) - return self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) + return await self.execute_task(initial_operation, auth_kwargs=kwargs, node_id=unique_id) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_pixverse.py b/comfy_api_nodes/nodes_pixverse.py index ef4a9a8023f2..7c5a52feb42c 100644 --- a/comfy_api_nodes/nodes_pixverse.py +++ b/comfy_api_nodes/nodes_pixverse.py @@ -30,7 +30,7 @@ from comfy_api.input_impl import VideoFromFile import torch -import requests +import aiohttp from io import BytesIO @@ -47,7 +47,7 @@ def get_video_url_from_response( return str(response.Resp.url) -def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): +async def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): # first, upload image to Pixverse and get image id to use in actual generation call files = {"image": tensor_to_bytesio(image)} operation = SynchronousOperation( @@ -62,7 +62,7 @@ def upload_image_to_pixverse(image: torch.Tensor, auth_kwargs=None): content_type="multipart/form-data", auth_kwargs=auth_kwargs, ) - response_upload: PixverseImageUploadResponse = operation.execute() + response_upload: PixverseImageUploadResponse = await operation.execute() if response_upload.Resp is None: raise Exception( @@ -164,7 +164,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, aspect_ratio: str, @@ -205,7 +205,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -229,11 +229,11 @@ def api_call( result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class PixverseImageToVideoNode(ComfyNodeABC): @@ -302,7 +302,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -316,7 +316,7 @@ def api_call( **kwargs, ): validate_string(prompt, strip_whitespace=False) - img_id = upload_image_to_pixverse(image, auth_kwargs=kwargs) + img_id = await upload_image_to_pixverse(image, auth_kwargs=kwargs) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -345,7 +345,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -369,10 +369,11 @@ def api_call( result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_I2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) class PixverseTransitionVideoNode(ComfyNodeABC): @@ -436,7 +437,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, first_frame: torch.Tensor, last_frame: torch.Tensor, @@ -450,8 +451,8 @@ def api_call( **kwargs, ): validate_string(prompt, strip_whitespace=False) - first_frame_id = upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) - last_frame_id = upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) + first_frame_id = await upload_image_to_pixverse(first_frame, auth_kwargs=kwargs) + last_frame_id = await upload_image_to_pixverse(last_frame, auth_kwargs=kwargs) # 1080p is limited to 5 seconds duration # only normal motion_mode supported for 1080p or for non-5 second duration @@ -480,7 +481,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.Resp is None: raise Exception(f"PixVerse request failed: '{response_api.ErrMsg}'") @@ -504,10 +505,11 @@ def api_call( result_url_extractor=get_video_url_from_response, estimated_duration=AVERAGE_DURATION_T2V, ) - response_poll = operation.execute() + response_poll = await operation.execute() - vid_response = requests.get(response_poll.Resp.url) - return (VideoFromFile(BytesIO(vid_response.content)),) + async with aiohttp.ClientSession() as session: + async with session.get(response_poll.Resp.url) as vid_response: + return (VideoFromFile(BytesIO(await vid_response.content.read())),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index e369c4b7e283..c8516b368208 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -37,7 +37,7 @@ from PIL import UnidentifiedImageError -def handle_recraft_file_request( +async def handle_recraft_file_request( image: torch.Tensor, path: str, mask: torch.Tensor=None, @@ -71,13 +71,13 @@ def handle_recraft_file_request( auth_kwargs=auth_kwargs, multipart_parser=recraft_multipart_parser, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() all_bytesio = [] if response.image is not None: - all_bytesio.append(download_url_to_bytesio(response.image.url, timeout=timeout)) + all_bytesio.append(await download_url_to_bytesio(response.image.url, timeout=timeout)) else: for data in response.data: - all_bytesio.append(download_url_to_bytesio(data.url, timeout=timeout)) + all_bytesio.append(await download_url_to_bytesio(data.url, timeout=timeout)) return all_bytesio @@ -395,7 +395,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, size: str, @@ -439,7 +439,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() images = [] urls = [] for data in response.data: @@ -451,7 +451,7 @@ def api_call( f"Result URL: {urls_string}", unique_id ) image = bytesio_to_image_tensor( - download_url_to_bytesio(data.url, timeout=1024) + await download_url_to_bytesio(data.url, timeout=1024) ) if len(image.shape) < 4: image = image.unsqueeze(0) @@ -538,7 +538,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -578,7 +578,7 @@ def api_call( total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/imageToImage", request=request, @@ -654,7 +654,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, image: torch.Tensor, mask: torch.Tensor, @@ -690,7 +690,7 @@ def api_call( total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], mask=mask[i:i+1], path="/proxy/recraft/images/inpaint", @@ -779,7 +779,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, substyle: str, @@ -821,7 +821,7 @@ def api_call( ), auth_kwargs=kwargs, ) - response: RecraftImageGenerationResponse = operation.execute() + response: RecraftImageGenerationResponse = await operation.execute() svg_data = [] urls = [] for data in response.data: @@ -831,7 +831,7 @@ def api_call( PromptServer.instance.send_progress_text( f"Result URL: {' '.join(urls)}", unique_id ) - svg_data.append(download_url_to_bytesio(data.url, timeout=1024)) + svg_data.append(await download_url_to_bytesio(data.url, timeout=1024)) return (SVG(svg_data),) @@ -861,7 +861,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -870,7 +870,7 @@ def api_call( total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/vectorize", auth_kwargs=kwargs, @@ -942,7 +942,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, image: torch.Tensor, prompt: str, @@ -973,7 +973,7 @@ def api_call( total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/replaceBackground", request=request, @@ -1011,7 +1011,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -1020,7 +1020,7 @@ def api_call( total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path="/proxy/recraft/images/removeBackground", auth_kwargs=kwargs, @@ -1062,7 +1062,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, image: torch.Tensor, **kwargs, @@ -1071,7 +1071,7 @@ def api_call( total = image.shape[0] pbar = ProgressBar(total) for i in range(total): - sub_bytes = handle_recraft_file_request( + sub_bytes = await handle_recraft_file_request( image=image[i], path=self.RECRAFT_PATH, auth_kwargs=kwargs, diff --git a/comfy_api_nodes/nodes_rodin.py b/comfy_api_nodes/nodes_rodin.py index 67f90478c1bc..c89d087e5fd2 100644 --- a/comfy_api_nodes/nodes_rodin.py +++ b/comfy_api_nodes/nodes_rodin.py @@ -9,11 +9,10 @@ from inspect import cleandoc from comfy.comfy_types.node_typing import IO import folder_paths as comfy_paths -import requests +import aiohttp import os import datetime -import shutil -import time +import asyncio import io import logging import math @@ -66,7 +65,6 @@ def create_task_error(response: Rodin3DGenerateResponse): return hasattr(response, "error") - class Rodin3DAPI: """ Generate 3D Assets using Rodin API @@ -123,8 +121,8 @@ def check_rodin_status(self, response: Rodin3DCheckStatusResponse) -> str: else: return "Generating" - def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): - if images == None: + async def create_generate_task(self, images=None, seed=1, material="PBR", quality="medium", tier="Regular", mesh_mode="Quad", **kwargs): + if images is None: raise Exception("Rodin 3D generate requires at least 1 image.") if len(images) >= 5: raise Exception("Rodin 3D generate requires up to 5 image.") @@ -155,7 +153,7 @@ def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="mediu auth_kwargs=kwargs, ) - response = operation.execute() + response = await operation.execute() if create_task_error(response): error_message = f"Rodin3D Create 3D generate Task Failed. Message: {response.message}, error: {response.error}" @@ -168,7 +166,7 @@ def CreateGenerateTask(self, images=None, seed=1, material="PBR", quality="mediu logging.info(f"[ Rodin3D API - Submit Jobs ] UUID: {task_uuid}") return task_uuid, subscription_key - def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: + async def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatusResponse: path = "/proxy/rodin/api/v2/status" @@ -191,11 +189,9 @@ def poll_for_task_status(self, subscription_key, **kwargs) -> Rodin3DCheckStatus logging.info("[ Rodin3D API - CheckStatus ] Generate Start!") - return poll_operation.execute() - + return await poll_operation.execute() - - def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse: + async def get_rodin_download_list(self, uuid, **kwargs) -> Rodin3DDownloadResponse: logging.info("[ Rodin3D API - Downloading ] Generate Successfully!") path = "/proxy/rodin/api/v2/download" @@ -212,53 +208,59 @@ def GetRodinDownloadList(self, uuid, **kwargs) -> Rodin3DDownloadResponse: auth_kwargs=kwargs ) - return operation.execute() + return await operation.execute() - def GetQualityAndMode(self, PolyCount): - if PolyCount == "200K-Triangle": + def get_quality_mode(self, poly_count): + if poly_count == "200K-Triangle": mesh_mode = "Raw" quality = "medium" else: mesh_mode = "Quad" - if PolyCount == "4K-Quad": + if poly_count == "4K-Quad": quality = "extra-low" - elif PolyCount == "8K-Quad": + elif poly_count == "8K-Quad": quality = "low" - elif PolyCount == "18K-Quad": + elif poly_count == "18K-Quad": quality = "medium" - elif PolyCount == "50K-Quad": + elif poly_count == "50K-Quad": quality = "high" else: quality = "medium" return mesh_mode, quality - def DownLoadFiles(self, Url_List): - Save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) - os.makedirs(Save_path, exist_ok=True) + async def download_files(self, url_list): + save_path = os.path.join(comfy_paths.get_output_directory(), "Rodin3D", datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")) + os.makedirs(save_path, exist_ok=True) model_file_path = None - for Item in Url_List.list: - url = Item.url - file_name = Item.name - file_path = os.path.join(Save_path, file_name) - if file_path.endswith(".glb"): - model_file_path = file_path - logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") - max_retries = 5 - for attempt in range(max_retries): - try: - with requests.get(url, stream=True) as r: - r.raise_for_status() - with open(file_path, "wb") as f: - shutil.copyfileobj(r.raw, f) - break - except Exception as e: - logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") - if attempt < max_retries - 1: - logging.info("Retrying...") - time.sleep(2) - else: - logging.info(f"[ Rodin3D API - download_files ] Failed to download {file_path} after {max_retries} attempts.") + async with aiohttp.ClientSession() as session: + for i in url_list.list: + url = i.url + file_name = i.name + file_path = os.path.join(save_path, file_name) + if file_path.endswith(".glb"): + model_file_path = file_path + logging.info(f"[ Rodin3D API - download_files ] Downloading file: {file_path}") + max_retries = 5 + for attempt in range(max_retries): + try: + async with session.get(url) as resp: + resp.raise_for_status() + with open(file_path, "wb") as f: + async for chunk in resp.content.iter_chunked(32 * 1024): + f.write(chunk) + break + except Exception as e: + logging.info(f"[ Rodin3D API - download_files ] Error downloading {file_path}:{e}") + if attempt < max_retries - 1: + logging.info("Retrying...") + await asyncio.sleep(2) + else: + logging.info( + "[ Rodin3D API - download_files ] Failed to download %s after %s attempts.", + file_path, + max_retries, + ) return model_file_path @@ -285,7 +287,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -298,14 +300,17 @@ def api_call( m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Detail(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -328,7 +333,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -341,14 +346,17 @@ def api_call( m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Smooth(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -371,7 +379,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -384,14 +392,17 @@ def api_call( m_images = [] for i in range(num_images): m_images.append(Images[i]) - mesh_mode, quality = self.GetQualityAndMode(Polygon_count) - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=Material_Type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + mesh_mode, quality = self.get_quality_mode(Polygon_count) + task_uuid, subscription_key = await self.create_generate_task(images=m_images, seed=Seed, material=Material_Type, + quality=quality, tier=tier, mesh_mode=mesh_mode, + **kwargs) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) + class Rodin3D_Sketch(Rodin3DAPI): @classmethod def INPUT_TYPES(s): @@ -423,7 +434,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, Images, Seed, @@ -437,10 +448,12 @@ def api_call( material_type = "PBR" quality = "medium" mesh_mode = "Quad" - task_uuid, subscription_key = self.CreateGenerateTask(images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs) - self.poll_for_task_status(subscription_key, **kwargs) - Download_List = self.GetRodinDownloadList(task_uuid, **kwargs) - model = self.DownLoadFiles(Download_List) + task_uuid, subscription_key = await self.create_generate_task( + images=m_images, seed=Seed, material=material_type, quality=quality, tier=tier, mesh_mode=mesh_mode, **kwargs + ) + await self.poll_for_task_status(subscription_key, **kwargs) + download_list = await self.get_rodin_download_list(task_uuid, **kwargs) + model = await self.download_files(download_list) return (model,) diff --git a/comfy_api_nodes/nodes_runway.py b/comfy_api_nodes/nodes_runway.py index af4b321f96d0..98024a9fa0b8 100644 --- a/comfy_api_nodes/nodes_runway.py +++ b/comfy_api_nodes/nodes_runway.py @@ -99,14 +99,14 @@ def validate_input_image(image: torch.Tensor) -> bool: return image.shape[2] < 8000 and image.shape[1] < 8000 -def poll_until_finished( +async def poll_until_finished( auth_kwargs: dict[str, str], api_endpoint: ApiEndpoint[Any, TaskStatusResponse], estimated_duration: Optional[int] = None, node_id: Optional[str] = None, ) -> TaskStatusResponse: """Polls the Runway API endpoint until the task reaches a terminal state, then returns the response.""" - return PollingOperation( + return await PollingOperation( poll_endpoint=api_endpoint, completed_statuses=[ TaskStatus.SUCCEEDED.value, @@ -115,7 +115,7 @@ def poll_until_finished( TaskStatus.FAILED.value, TaskStatus.CANCELLED.value, ], - status_extractor=lambda response: (response.status.value), + status_extractor=lambda response: response.status.value, auth_kwargs=auth_kwargs, result_url_extractor=get_video_url_from_task_status, estimated_duration=estimated_duration, @@ -167,11 +167,11 @@ def validate_response(self, response: RunwayImageToVideoResponse) -> bool: ) return True - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> RunwayImageToVideoResponse: """Poll the task status until it is finished then get the response.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -183,7 +183,7 @@ def get_response( node_id=node_id, ) - def generate_video( + async def generate_video( self, request: RunwayImageToVideoRequest, auth_kwargs: dict[str, str], @@ -200,15 +200,15 @@ def generate_video( auth_kwargs=auth_kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() self.validate_task_created(initial_response) task_id = initial_response.id - final_response = self.get_response(task_id, auth_kwargs, node_id) + final_response = await self.get_response(task_id, auth_kwargs, node_id) self.validate_response(final_response) video_url = get_video_url_from_task_status(final_response) - return (download_url_to_video_output(video_url),) + return (await download_url_to_video_output(video_url),) class RunwayImageToVideoNodeGen3a(RunwayVideoGenNode): @@ -250,7 +250,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -265,7 +265,7 @@ def api_call( validate_input_image(start_frame) # Upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( start_frame, max_images=1, mime_type="image/png", @@ -274,7 +274,7 @@ def api_call( if len(download_urls) != 1: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -333,7 +333,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -348,7 +348,7 @@ def api_call( validate_input_image(start_frame) # Upload image - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( start_frame, max_images=1, mime_type="image/png", @@ -357,7 +357,7 @@ def api_call( if len(download_urls) != 1: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -382,10 +382,10 @@ class RunwayFirstLastFrameNode(RunwayVideoGenNode): DESCRIPTION = "Upload first and last keyframes, draft a prompt, and generate a video. More complex transitions, such as cases where the Last frame is completely different from the First frame, may benefit from the longer 10s duration. This would give the generation more time to smoothly transition between the two inputs. Before diving in, review these best practices to ensure that your input selections will set your generation up for success: https://help.runwayml.com/hc/en-us/articles/34170748696595-Creating-with-Keyframes-on-Gen-3." - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> RunwayImageToVideoResponse: - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -437,7 +437,7 @@ def INPUT_TYPES(s): }, } - def api_call( + async def api_call( self, prompt: str, start_frame: torch.Tensor, @@ -455,7 +455,7 @@ def api_call( # Upload images stacked_input_images = image_tensor_pair_to_batch(start_frame, end_frame) - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( stacked_input_images, max_images=2, mime_type="image/png", @@ -464,7 +464,7 @@ def api_call( if len(download_urls) != 2: raise RunwayApiError("Failed to upload one or more images to comfy api.") - return self.generate_video( + return await self.generate_video( RunwayImageToVideoRequest( promptText=prompt, seed=seed, @@ -543,11 +543,11 @@ def validate_response(self, response: TaskStatusResponse) -> bool: ) return True - def get_response( + async def get_response( self, task_id: str, auth_kwargs: dict[str, str], node_id: Optional[str] = None ) -> TaskStatusResponse: """Poll the task status until it is finished then get the response.""" - return poll_until_finished( + return await poll_until_finished( auth_kwargs, ApiEndpoint( path=f"{PATH_GET_TASK_STATUS}/{task_id}", @@ -559,7 +559,7 @@ def get_response( node_id=node_id, ) - def api_call( + async def api_call( self, prompt: str, ratio: str, @@ -574,7 +574,7 @@ def api_call( reference_images = None if reference_image is not None: validate_input_image(reference_image) - download_urls = upload_images_to_comfyapi( + download_urls = await upload_images_to_comfyapi( reference_image, max_images=1, mime_type="image/png", @@ -605,19 +605,19 @@ def api_call( auth_kwargs=kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() self.validate_task_created(initial_response) task_id = initial_response.id # Poll for completion - final_response = self.get_response( + final_response = await self.get_response( task_id, auth_kwargs=kwargs, node_id=unique_id ) self.validate_response(final_response) # Download and return image image_url = get_image_url_from_task_status(final_response) - return (download_url_to_image_tensor(image_url),) + return (await download_url_to_image_tensor(image_url),) NODE_CLASS_MAPPINGS = { diff --git a/comfy_api_nodes/nodes_stability.py b/comfy_api_nodes/nodes_stability.py index 02e4216780b9..31309d831b9d 100644 --- a/comfy_api_nodes/nodes_stability.py +++ b/comfy_api_nodes/nodes_stability.py @@ -124,7 +124,7 @@ def INPUT_TYPES(s): }, } - def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, + async def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, **kwargs): validate_string(prompt, strip_whitespace=False) @@ -163,7 +163,7 @@ def api_call(self, prompt: str, aspect_ratio: str, style_preset: str, seed: int, content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.") @@ -257,7 +257,7 @@ def INPUT_TYPES(s): }, } - def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, + async def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str, seed: int, cfg_scale: float, negative_prompt: str=None, image: torch.Tensor = None, image_denoise: float=None, **kwargs): validate_string(prompt, strip_whitespace=False) @@ -302,7 +302,7 @@ def api_call(self, model: str, prompt: str, aspect_ratio: str, style_preset: str content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.") @@ -374,7 +374,7 @@ def INPUT_TYPES(s): }, } - def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, + async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: int, negative_prompt: str=None, **kwargs): validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -403,7 +403,7 @@ def api_call(self, image: torch.Tensor, prompt: str, creativity: float, seed: in content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.") @@ -480,7 +480,7 @@ def INPUT_TYPES(s): }, } - def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, + async def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_preset: str, seed: int, negative_prompt: str=None, **kwargs): validate_string(prompt, strip_whitespace=False) image_binary = tensor_to_bytesio(image, total_pixels=1024*1024).read() @@ -512,7 +512,7 @@ def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_pr content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() operation = PollingOperation( poll_endpoint=ApiEndpoint( @@ -527,7 +527,7 @@ def api_call(self, image: torch.Tensor, prompt: str, creativity: float, style_pr status_extractor=lambda x: get_async_dummy_status(x), auth_kwargs=kwargs, ) - response_poll: StabilityResultsGetResponse = operation.execute() + response_poll: StabilityResultsGetResponse = await operation.execute() if response_poll.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.") @@ -563,8 +563,7 @@ def INPUT_TYPES(s): }, } - def api_call(self, image: torch.Tensor, - **kwargs): + async def api_call(self, image: torch.Tensor, **kwargs): image_binary = tensor_to_bytesio(image, total_pixels=4096*4096).read() files = { @@ -583,7 +582,7 @@ def api_call(self, image: torch.Tensor, content_type="multipart/form-data", auth_kwargs=kwargs, ) - response_api = operation.execute() + response_api = await operation.execute() if response_api.finish_reason != "SUCCESS": raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.") diff --git a/comfy_api_nodes/nodes_tripo.py b/comfy_api_nodes/nodes_tripo.py index 65f3b21f5cc1..d08cf9007655 100644 --- a/comfy_api_nodes/nodes_tripo.py +++ b/comfy_api_nodes/nodes_tripo.py @@ -37,8 +37,8 @@ ) -def upload_image_to_tripo(image, **kwargs): - urls = upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) +async def upload_image_to_tripo(image, **kwargs): + urls = await upload_images_to_comfyapi(image, max_images=1, auth_kwargs=kwargs) return TripoFileReference(TripoUrlReference(url=urls[0], type="jpeg")) def get_model_url_from_response(response: TripoTaskResponse) -> str: @@ -49,7 +49,7 @@ def get_model_url_from_response(response: TripoTaskResponse) -> str: raise RuntimeError(f"Failed to get model url from response: {response}") -def poll_until_finished( +async def poll_until_finished( kwargs: dict[str, str], response: TripoTaskResponse, ) -> tuple[str, str]: @@ -57,7 +57,7 @@ def poll_until_finished( if response.code != 0: raise RuntimeError(f"Failed to generate mesh: {response.error}") task_id = response.data.task_id - response_poll = PollingOperation( + response_poll = await PollingOperation( poll_endpoint=ApiEndpoint( path=f"/proxy/tripo/v2/openapi/task/{task_id}", method=HttpMethod.GET, @@ -80,7 +80,7 @@ def poll_until_finished( ).execute() if response_poll.data.status == TripoTaskStatus.SUCCESS: url = get_model_url_from_response(response_poll) - bytesio = download_url_to_bytesio(url) + bytesio = await download_url_to_bytesio(url) # Save the downloaded model file model_file = f"tripo_model_{task_id}.glb" with open(os.path.join(get_output_directory(), model_file), "wb") as f: @@ -88,6 +88,7 @@ def poll_until_finished( return model_file, task_id raise RuntimeError(f"Failed to generate mesh: {response_poll}") + class TripoTextToModelNode: """ Generates 3D models synchronously based on a text prompt using Tripo's API. @@ -126,11 +127,11 @@ def INPUT_TYPES(s): API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style=None, texture=None, pbr=None, image_seed=None, model_seed=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): style_enum = None if style == "None" else style if not prompt: raise RuntimeError("Prompt is required") - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -155,7 +156,8 @@ def generate_mesh(self, prompt, negative_prompt=None, model_version=None, style= ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoImageToModelNode: """ @@ -195,12 +197,12 @@ def INPUT_TYPES(s): API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr=None, model_seed=None, orientation=None, texture_alignment=None, texture_seed=None, texture_quality=None, face_limit=None, quad=None, **kwargs): style_enum = None if style == "None" else style if image is None: raise RuntimeError("Image is required") - tripo_file = upload_image_to_tripo(image, **kwargs) - response = SynchronousOperation( + tripo_file = await upload_image_to_tripo(image, **kwargs) + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -225,7 +227,8 @@ def generate_mesh(self, image, model_version=None, style=None, texture=None, pbr ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoMultiviewToModelNode: """ @@ -267,7 +270,7 @@ def INPUT_TYPES(s): API_NODE = True OUTPUT_NODE = True - def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): + async def generate_mesh(self, image, image_left=None, image_back=None, image_right=None, model_version=None, orientation=None, texture=None, pbr=None, model_seed=None, texture_seed=None, texture_quality=None, texture_alignment=None, face_limit=None, quad=None, **kwargs): if image is None: raise RuntimeError("front image for multiview is required") images = [] @@ -282,11 +285,11 @@ def generate_mesh(self, image, image_left=None, image_back=None, image_right=Non for image_name in ["image", "image_left", "image_back", "image_right"]: image_ = image_dict[image_name] if image_ is not None: - tripo_file = upload_image_to_tripo(image_, **kwargs) + tripo_file = await upload_image_to_tripo(image_, **kwargs) images.append(tripo_file) else: images.append(TripoFileEmptyReference()) - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -309,7 +312,8 @@ def generate_mesh(self, image, image_left=None, image_back=None, image_right=Non ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoTextureNode: @classmethod @@ -340,8 +344,8 @@ def INPUT_TYPES(s): OUTPUT_NODE = True AVERAGE_DURATION = 80 - def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None, texture_quality=None, texture_alignment=None, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -358,7 +362,7 @@ def generate_mesh(self, model_task_id, texture=None, pbr=None, texture_seed=None ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) class TripoRefineNode: @@ -387,8 +391,8 @@ def INPUT_TYPES(s): OUTPUT_NODE = True AVERAGE_DURATION = 240 - def generate_mesh(self, model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -400,7 +404,7 @@ def generate_mesh(self, model_task_id, **kwargs): ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) class TripoRigNode: @@ -425,8 +429,8 @@ def INPUT_TYPES(s): OUTPUT_NODE = True AVERAGE_DURATION = 180 - def generate_mesh(self, original_model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, original_model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -440,7 +444,8 @@ def generate_mesh(self, original_model_task_id, **kwargs): ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoRetargetNode: @classmethod @@ -475,8 +480,8 @@ def INPUT_TYPES(s): OUTPUT_NODE = True AVERAGE_DURATION = 30 - def generate_mesh(self, animation, original_model_task_id, **kwargs): - response = SynchronousOperation( + async def generate_mesh(self, animation, original_model_task_id, **kwargs): + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -491,7 +496,8 @@ def generate_mesh(self, animation, original_model_task_id, **kwargs): ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + class TripoConversionNode: @classmethod @@ -529,10 +535,10 @@ def VALIDATE_INPUTS(cls, input_types): OUTPUT_NODE = True AVERAGE_DURATION = 30 - def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): + async def generate_mesh(self, original_model_task_id, format, quad, face_limit, texture_size, texture_format, **kwargs): if not original_model_task_id: raise RuntimeError("original_model_task_id is required") - response = SynchronousOperation( + response = await SynchronousOperation( endpoint=ApiEndpoint( path="/proxy/tripo/v2/openapi/task", method=HttpMethod.POST, @@ -549,7 +555,8 @@ def generate_mesh(self, original_model_task_id, format, quad, face_limit, textur ), auth_kwargs=kwargs, ).execute() - return poll_until_finished(kwargs, response) + return await poll_until_finished(kwargs, response) + NODE_CLASS_MAPPINGS = { "TripoTextToModelNode": TripoTextToModelNode, diff --git a/comfy_api_nodes/nodes_veo2.py b/comfy_api_nodes/nodes_veo2.py index 97bfe20e63e5..e25dab2f5770 100644 --- a/comfy_api_nodes/nodes_veo2.py +++ b/comfy_api_nodes/nodes_veo2.py @@ -1,7 +1,7 @@ import io import logging import base64 -import requests +import aiohttp import torch from typing import Optional @@ -152,7 +152,7 @@ def INPUT_TYPES(s): DESCRIPTION = "Generates videos from text prompts using Google's Veo 2 API" API_NODE = True - def generate_video( + async def generate_video( self, prompt, aspect_ratio="16:9", @@ -217,7 +217,7 @@ def generate_video( auth_kwargs=kwargs, ) - initial_response = initial_operation.execute() + initial_response = await initial_operation.execute() operation_name = initial_response.name logging.info(f"Veo generation started with operation name: {operation_name}") @@ -256,7 +256,7 @@ def progress_extractor(response): ) # Execute the polling operation - poll_response = poll_operation.execute() + poll_response = await poll_operation.execute() # Now check for errors in the final response # Check for error in poll response @@ -281,7 +281,6 @@ def progress_extractor(response): raise Exception(error_message) # Extract video data - video_data = None if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0: video = poll_response.response.videos[0] @@ -291,9 +290,9 @@ def progress_extractor(response): video_data = base64.b64decode(video.bytesBase64Encoded) elif hasattr(video, 'gcsUri') and video.gcsUri: # Download from URL - video_url = video.gcsUri - video_response = requests.get(video_url) - video_data = video_response.content + async with aiohttp.ClientSession() as session: + async with session.get(video.gcsUri) as video_response: + video_data = await video_response.content.read() else: raise Exception("Video returned but no data or URL was provided") else: