@@ -31,35 +31,43 @@ def downscale_input(image):
3131 s = s .movedim (1 ,- 1 )
3232 return s
3333
34- def validate_and_cast_response (response ):
34+ def validate_and_cast_response (response ):
3535 # validate raw JSON response
3636 data = response .data
3737 if not data or len (data ) == 0 :
3838 raise Exception ("No images returned from API endpoint" )
3939
40- # Get base64 image data
41- image_url = data [0 ].url
42- b64_data = data [0 ].b64_json
43- if not image_url and not b64_data :
44- raise Exception ("No image was generated in the response" )
40+ # Initialize list to store image tensors
41+ image_tensors = []
4542
46- if b64_data :
47- img_data = base64 .b64decode (b64_data )
48- img = Image .open (io .BytesIO (img_data ))
43+ # Process each image in the data array
44+ for image_data in data :
45+ image_url = image_data .url
46+ b64_data = image_data .b64_json
4947
50- elif image_url :
51- img_response = requests .get (image_url )
52- if img_response .status_code != 200 :
53- raise Exception ("Failed to download the image" )
54- img = Image .open (io .BytesIO (img_response .content ))
48+ if not image_url and not b64_data :
49+ raise Exception ("No image was generated in the response" )
5550
56- img = img .convert ("RGBA" )
51+ if b64_data :
52+ img_data = base64 .b64decode (b64_data )
53+ img = Image .open (io .BytesIO (img_data ))
5754
58- # Convert to numpy array, normalize to float32 between 0 and 1
59- img_array = np .array (img ).astype (np .float32 ) / 255.0
55+ elif image_url :
56+ img_response = requests .get (image_url )
57+ if img_response .status_code != 200 :
58+ raise Exception ("Failed to download the image" )
59+ img = Image .open (io .BytesIO (img_response .content ))
6060
61- # Convert to torch tensor and add batch dimension
62- return torch .from_numpy (img_array )[None ,]
61+ img = img .convert ("RGBA" )
62+
63+ # Convert to numpy array, normalize to float32 between 0 and 1
64+ img_array = np .array (img ).astype (np .float32 ) / 255.0
65+ img_tensor = torch .from_numpy (img_array )
66+
67+ # Add to list of tensors
68+ image_tensors .append (img_tensor )
69+
70+ return torch .stack (image_tensors , dim = 0 )
6371
6472class OpenAIDalle2 (ComfyNodeABC ):
6573 """
0 commit comments