Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 51 additions & 103 deletions comfy_api_nodes/apinode_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
import aiohttp
import io
import logging
import mimetypes
Expand All @@ -21,7 +22,6 @@

import numpy as np
from PIL import Image
import requests
import torch
import math
import base64
Expand All @@ -30,7 +30,7 @@
import av


def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
async def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFromFile:
"""Downloads a video from a URL and returns a `VIDEO` output.

Args:
Expand All @@ -39,7 +39,7 @@ def download_url_to_video_output(video_url: str, timeout: int = None) -> VideoFr
Returns:
A Comfy node `VIDEO` output.
"""
video_io = download_url_to_bytesio(video_url, timeout)
video_io = await download_url_to_bytesio(video_url, timeout)
if video_io is None:
error_msg = f"Failed to download video from {video_url}"
logging.error(error_msg)
Expand All @@ -62,7 +62,7 @@ def downscale_image_tensor(image, total_pixels=1536 * 1024) -> torch.Tensor:
return s


def validate_and_cast_response(
async def validate_and_cast_response(
response, timeout: int = None, node_id: Union[str, None] = None
) -> torch.Tensor:
"""Validates and casts a response to a torch.Tensor.
Expand All @@ -86,35 +86,24 @@ def validate_and_cast_response(
image_tensors: list[torch.Tensor] = []

# Process each image in the data array
for image_data in data:
image_url = image_data.url
b64_data = image_data.b64_json

if not image_url and not b64_data:
raise ValueError("No image was generated in the response")

if b64_data:
img_data = base64.b64decode(b64_data)
img = Image.open(io.BytesIO(img_data))

elif image_url:
if node_id:
PromptServer.instance.send_progress_text(
f"Result URL: {image_url}", node_id
)
img_response = requests.get(image_url, timeout=timeout)
if img_response.status_code != 200:
raise ValueError("Failed to download the image")
img = Image.open(io.BytesIO(img_response.content))

img = img.convert("RGBA")

# Convert to numpy array, normalize to float32 between 0 and 1
img_array = np.array(img).astype(np.float32) / 255.0
img_tensor = torch.from_numpy(img_array)

# Add to list of tensors
image_tensors.append(img_tensor)
async with aiohttp.ClientSession(timeout=aiohttp.ClientTimeout(total=timeout)) as session:
for img_data in data:
img_bytes: bytes
if img_data.b64_json:
img_bytes = base64.b64decode(img_data.b64_json)
elif img_data.url:
if node_id:
PromptServer.instance.send_progress_text(f"Result URL: {img_data.url}", node_id)
async with session.get(img_data.url) as resp:
if resp.status != 200:
raise ValueError("Failed to download generated image")
img_bytes = await resp.read()
else:
raise ValueError("Invalid image payload – neither URL nor base64 data present.")

pil_img = Image.open(BytesIO(img_bytes)).convert("RGBA")
arr = np.asarray(pil_img).astype(np.float32) / 255.0
image_tensors.append(torch.from_numpy(arr))

return torch.stack(image_tensors, dim=0)

Expand Down Expand Up @@ -175,7 +164,7 @@ def mimetype_to_extension(mime_type: str) -> str:
return mime_type.split("/")[-1].lower()


def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
async def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
"""Downloads content from a URL using requests and returns it as BytesIO.

Args:
Expand All @@ -185,9 +174,11 @@ def download_url_to_bytesio(url: str, timeout: int = None) -> BytesIO:
Returns:
BytesIO object containing the downloaded content.
"""
response = requests.get(url, stream=True, timeout=timeout)
response.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(response.content)
timeout_cfg = aiohttp.ClientTimeout(total=timeout) if timeout else None
async with aiohttp.ClientSession(timeout=timeout_cfg) as session:
async with session.get(url) as resp:
resp.raise_for_status() # Raises HTTPError for bad responses (4XX or 5XX)
return BytesIO(await resp.read())


def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch.Tensor:
Expand All @@ -210,15 +201,15 @@ def bytesio_to_image_tensor(image_bytesio: BytesIO, mode: str = "RGBA") -> torch
return torch.from_numpy(image_array).unsqueeze(0)


def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
async def download_url_to_image_tensor(url: str, timeout: int = None) -> torch.Tensor:
"""Downloads an image from a URL and returns a [B, H, W, C] tensor."""
image_bytesio = download_url_to_bytesio(url, timeout)
image_bytesio = await download_url_to_bytesio(url, timeout)
return bytesio_to_image_tensor(image_bytesio)


def process_image_response(response: requests.Response) -> torch.Tensor:
def process_image_response(response_content: bytes | str) -> torch.Tensor:
"""Uses content from a Response object and converts it to a torch.Tensor"""
return bytesio_to_image_tensor(BytesIO(response.content))
return bytesio_to_image_tensor(BytesIO(response_content))


def _tensor_to_pil(image: torch.Tensor, total_pixels: int = 2048 * 2048) -> Image.Image:
Expand Down Expand Up @@ -336,10 +327,10 @@ def text_filepath_to_data_uri(filepath: str) -> str:
return f"data:{mime_type};base64,{base64_string}"


def upload_file_to_comfyapi(
async def upload_file_to_comfyapi(
file_bytes_io: BytesIO,
filename: str,
upload_mime_type: str,
upload_mime_type: Optional[str],
auth_kwargs: Optional[dict[str, str]] = None,
) -> str:
"""
Expand All @@ -354,7 +345,10 @@ def upload_file_to_comfyapi(
Returns:
The download URL for the uploaded file.
"""
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
if upload_mime_type is None:
request_object = UploadRequest(file_name=filename)
else:
request_object = UploadRequest(file_name=filename, content_type=upload_mime_type)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
Expand All @@ -366,12 +360,8 @@ def upload_file_to_comfyapi(
auth_kwargs=auth_kwargs,
)

response: UploadResponse = operation.execute()
upload_response = ApiClient.upload_file(
response.upload_url, file_bytes_io, content_type=upload_mime_type
)
upload_response.raise_for_status()

response: UploadResponse = await operation.execute()
await ApiClient.upload_file(response.upload_url, file_bytes_io, content_type=upload_mime_type)
return response.download_url


Expand Down Expand Up @@ -399,7 +389,7 @@ def video_to_base64_string(
return base64.b64encode(video_bytes_io.getvalue()).decode("utf-8")


def upload_video_to_comfyapi(
async def upload_video_to_comfyapi(
video: VideoInput,
auth_kwargs: Optional[dict[str, str]] = None,
container: VideoContainer = VideoContainer.MP4,
Expand Down Expand Up @@ -439,9 +429,7 @@ def upload_video_to_comfyapi(
video.save_to(video_bytes_io, format=container, codec=codec)
video_bytes_io.seek(0)

return upload_file_to_comfyapi(
video_bytes_io, filename, upload_mime_type, auth_kwargs
)
return await upload_file_to_comfyapi(video_bytes_io, filename, upload_mime_type, auth_kwargs)


def audio_tensor_to_contiguous_ndarray(waveform: torch.Tensor) -> np.ndarray:
Expand Down Expand Up @@ -501,7 +489,7 @@ def audio_ndarray_to_bytesio(
return audio_bytes_io


def upload_audio_to_comfyapi(
async def upload_audio_to_comfyapi(
audio: AudioInput,
auth_kwargs: Optional[dict[str, str]] = None,
container_format: str = "mp4",
Expand All @@ -527,7 +515,7 @@ def upload_audio_to_comfyapi(
audio_data_np, sample_rate, container_format, codec_name
)

return upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)
return await upload_file_to_comfyapi(audio_bytes_io, filename, mime_type, auth_kwargs)


def audio_to_base64_string(
Expand All @@ -544,7 +532,7 @@ def audio_to_base64_string(
return base64.b64encode(audio_bytes).decode("utf-8")


def upload_images_to_comfyapi(
async def upload_images_to_comfyapi(
image: torch.Tensor,
max_images=8,
auth_kwargs: Optional[dict[str, str]] = None,
Expand All @@ -561,55 +549,15 @@ def upload_images_to_comfyapi(
mime_type: Optional MIME type for the image.
"""
# if batch, try to upload each file if max_images is greater than 0
idx_image = 0
download_urls: list[str] = []
is_batch = len(image.shape) > 3
batch_length = 1
if is_batch:
batch_length = image.shape[0]
while True:
curr_image = image
if len(image.shape) > 3:
curr_image = image[idx_image]
# get BytesIO version of image
img_binary = tensor_to_bytesio(curr_image, mime_type=mime_type)
# first, request upload/download urls from comfy API
if not mime_type:
request_object = UploadRequest(file_name=img_binary.name)
else:
request_object = UploadRequest(
file_name=img_binary.name, content_type=mime_type
)
operation = SynchronousOperation(
endpoint=ApiEndpoint(
path="/customers/storage",
method=HttpMethod.POST,
request_model=UploadRequest,
response_model=UploadResponse,
),
request=request_object,
auth_kwargs=auth_kwargs,
)
response = operation.execute()
batch_len = image.shape[0] if is_batch else 1

upload_response = ApiClient.upload_file(
response.upload_url, img_binary, content_type=mime_type
)
# verify success
try:
upload_response.raise_for_status()
except requests.exceptions.HTTPError as e:
raise ValueError(f"Could not upload one or more images: {e}") from e
# add download_url to list
download_urls.append(response.download_url)

idx_image += 1
# stop uploading additional files if done
if is_batch and max_images > 0:
if idx_image >= max_images:
break
if idx_image >= batch_length:
break
for idx in range(min(batch_len, max_images)):
tensor = image[idx] if is_batch else image
img_io = tensor_to_bytesio(tensor, mime_type=mime_type)
url = await upload_file_to_comfyapi(img_io, img_io.name, mime_type, auth_kwargs)
download_urls.append(url)
return download_urls


Expand Down
Loading
Loading