Skip to content

Commit 24fdb92

Browse files
authored
feat(api-nodes): add new Gemini model (Comfy-Org#10789)
1 parent d526974 commit 24fdb92

File tree

2 files changed

+246
-32
lines changed

2 files changed

+246
-32
lines changed

comfy_api_nodes/apis/gemini_api.py

Lines changed: 219 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,229 @@
1-
from typing import Optional
1+
from datetime import date
2+
from enum import Enum
3+
from typing import Any
24

3-
from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata
4-
from pydantic import BaseModel
5+
from pydantic import BaseModel, Field
6+
7+
8+
class GeminiSafetyCategory(str, Enum):
    """Harm categories that a Gemini safety setting can target.

    Values mirror the wire-format strings accepted by the Gemini API.
    """

    HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
    HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
    HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
    HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
13+
14+
15+
class GeminiSafetyThreshold(str, Enum):
    """Blocking thresholds for a Gemini safety category.

    Ordered from least restrictive (OFF) to most restrictive
    (BLOCK_LOW_AND_ABOVE); values mirror the API wire format.
    """

    OFF = "OFF"
    BLOCK_NONE = "BLOCK_NONE"
    BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
    BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
    BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
21+
22+
23+
class GeminiSafetySetting(BaseModel):
    """A single safety rule: block content in `category` at `threshold`."""

    # Both fields are required — there is no sensible default safety rule.
    category: GeminiSafetyCategory
    threshold: GeminiSafetyThreshold
26+
27+
28+
class GeminiRole(str, Enum):
    """Author of a conversation turn: the end user or the model itself."""

    user = "user"
    model = "model"
31+
32+
33+
class GeminiMimeType(str, Enum):
    """IANA media types accepted for inline data in a Gemini prompt."""

    application_pdf = "application/pdf"
    audio_mpeg = "audio/mpeg"
    audio_mp3 = "audio/mp3"
    audio_wav = "audio/wav"
    image_png = "image/png"
    image_jpeg = "image/jpeg"
    image_webp = "image/webp"
    text_plain = "text/plain"
    video_mov = "video/mov"
    video_mpeg = "video/mpeg"
    video_mp4 = "video/mp4"
    video_mpg = "video/mpg"
    video_avi = "video/avi"
    video_wmv = "video/wmv"
    video_mpegps = "video/mpegps"
    video_flv = "video/flv"
50+
51+
52+
class GeminiInlineData(BaseModel):
    """Base64-encoded media embedded directly in a prompt part."""

    data: str | None = Field(
        default=None,
        description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
        "When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
    )
    # Media type of `data`; callers must set it whenever `data` is present.
    mimeType: GeminiMimeType | None = None
59+
60+
61+
class GeminiPart(BaseModel):
    """One element of a message: either inline media or plain text."""

    inlineData: GeminiInlineData | None = None
    text: str | None = None
64+
65+
66+
class GeminiTextPart(BaseModel):
    """A text-only message part (used by system instructions)."""

    text: str | None = None
68+
69+
70+
class GeminiContent(BaseModel):
    """A single conversation turn: an ordered list of parts plus its author."""

    # Required: at least the list itself must be supplied by the caller.
    parts: list[GeminiPart] = Field(description=None)
    role: GeminiRole = Field(examples=["user"])
73+
74+
75+
class GeminiSystemInstructionContent(BaseModel):
    """System-instruction message: text-only parts plus the author role."""

    parts: list[GeminiTextPart] = Field(
        description="A list of ordered parts that make up a single message. "
        "Different parts may have different IANA MIME types.",
    )
    role: GeminiRole = Field(
        description="The identity of the entity that creates the message. "
        "The following values are supported: "
        "user: This indicates that the message is sent by a real person, typically a user-generated message. "
        "model: This indicates that the message is generated by the model. "
        "The model value is used to insert messages from model into the conversation during multi-turn conversations. "
        "For non-multi-turn conversations, this field can be left blank or unset.",
    )
90+
91+
92+
class GeminiFunctionDeclaration(BaseModel):
    """Declaration of a callable tool function exposed to the model."""

    description: str | None = None
    # Function name is mandatory; parameters are described as a JSON schema.
    name: str
    parameters: dict[str, Any] = Field(description="JSON schema for the function parameters")
96+
97+
98+
class GeminiTool(BaseModel):
    """Tool container: an optional list of function declarations."""

    functionDeclarations: list[GeminiFunctionDeclaration] | None = None
100+
101+
102+
class GeminiOffset(BaseModel):
    """A duration offset expressed as whole seconds plus nanoseconds."""

    # nanos is the sub-second remainder, constrained to [0, 1e9).
    nanos: int | None = Field(default=None, ge=0, le=999999999)
    # seconds bounds (±315576000000) correspond to ±10,000 years.
    seconds: int | None = Field(default=None, ge=-315576000000, le=315576000000)
105+
106+
107+
class GeminiVideoMetadata(BaseModel):
    """Start/end offsets selecting a segment of an input video."""

    endOffset: GeminiOffset | None = None
    startOffset: GeminiOffset | None = None
110+
111+
112+
class GeminiGenerationConfig(BaseModel):
    """Decoding/sampling options for a Gemini generateContent request."""

    maxOutputTokens: int | None = Field(default=None, ge=16, le=8192)
    seed: int | None = None
    stopSequences: list[str] | None = None
    # Defaults below (1 / 40 / 0.95) are sent unless overridden by the caller.
    temperature: float | None = Field(default=1, ge=0.0, le=2.0)
    topK: int | None = Field(default=40, ge=1)
    topP: float | None = Field(default=0.95, ge=0.0, le=1.0)
5119

6120

7121
class GeminiImageConfig(BaseModel):
    """Image-output options (aspect ratio / resolution), all optional."""

    aspectRatio: str | None = None
    resolution: str | None = None
9124

10125

11126
class GeminiImageGenerationConfig(GeminiGenerationConfig):
    """Generation config extended with image-specific output options."""

    # e.g. ["TEXT", "IMAGE"] to request both text and image candidates.
    responseModalities: list[str] | None = None
    imageConfig: GeminiImageConfig | None = None
14129

15130

16131
class GeminiImageGenerateContentRequest(BaseModel):
    """Request body for image generation; only `contents` is required."""

    contents: list[GeminiContent]
    generationConfig: GeminiImageGenerationConfig | None = None
    safetySettings: list[GeminiSafetySetting] | None = None
    systemInstruction: GeminiSystemInstructionContent | None = None
    tools: list[GeminiTool] | None = None
    videoMetadata: GeminiVideoMetadata | None = None
138+
139+
140+
class GeminiGenerateContentRequest(BaseModel):
    """Request body for text generation; only `contents` is required."""

    contents: list[GeminiContent]
    generationConfig: GeminiGenerationConfig | None = None
    safetySettings: list[GeminiSafetySetting] | None = None
    systemInstruction: GeminiSystemInstructionContent | None = None
    tools: list[GeminiTool] | None = None
    videoMetadata: GeminiVideoMetadata | None = None
147+
148+
149+
class Modality(str, Enum):
    """Content modality labels used in token-count breakdowns."""

    MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
    TEXT = "TEXT"
    IMAGE = "IMAGE"
    VIDEO = "VIDEO"
    AUDIO = "AUDIO"
    DOCUMENT = "DOCUMENT"
156+
157+
158+
class ModalityTokenCount(BaseModel):
    """Token count attributed to one content modality."""

    modality: Modality | None = None
    tokenCount: int | None = Field(default=None, description="Number of tokens for the given modality.")
161+
162+
163+
class Probability(str, Enum):
    """Likelihood levels reported in safety ratings."""

    NEGLIGIBLE = "NEGLIGIBLE"
    LOW = "LOW"
    MEDIUM = "MEDIUM"
    HIGH = "HIGH"
    UNKNOWN = "UNKNOWN"
169+
170+
171+
class GeminiSafetyRating(BaseModel):
    """Safety assessment of generated content for one harm category."""

    category: GeminiSafetyCategory | None = None
    probability: Probability | None = Field(
        default=None,
        description="The probability that the content violates the specified safety category",
    )
177+
178+
179+
class GeminiCitation(BaseModel):
    """A source citation attached to generated content.

    start/end indices delimit the cited span; all fields are optional.
    """

    authors: list[str] | None = None
    endIndex: int | None = None
    license: str | None = None
    publicationDate: date | None = None
    startIndex: int | None = None
    title: str | None = None
    uri: str | None = None
187+
188+
189+
class GeminiCitationMetadata(BaseModel):
    """Container for the citations attached to a candidate."""

    citations: list[GeminiCitation] | None = None
191+
192+
193+
class GeminiCandidate(BaseModel):
    """One generated response candidate with its metadata."""

    citationMetadata: GeminiCitationMetadata | None = None
    content: GeminiContent | None = None
    finishReason: str | None = None
    safetyRatings: list[GeminiSafetyRating] | None = None
198+
199+
200+
class GeminiPromptFeedback(BaseModel):
    """Feedback about the prompt, populated when the request is blocked."""

    blockReason: str | None = None
    blockReasonMessage: str | None = None
    safetyRatings: list[GeminiSafetyRating] | None = None
204+
205+
206+
class GeminiUsageMetadata(BaseModel):
    """Token accounting for a generateContent call (prompt, output, cache)."""

    cachedContentTokenCount: int | None = Field(
        default=None,
        description="Output only. Number of tokens in the cached part in the input (the cached content).",
    )
    candidatesTokenCount: int | None = Field(default=None, description="Number of tokens in the response(s).")
    candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
        default=None, description="Breakdown of candidate tokens by modality."
    )
    promptTokenCount: int | None = Field(
        default=None,
        description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
    )
    promptTokensDetails: list[ModalityTokenCount] | None = Field(
        default=None, description="Breakdown of prompt tokens by modality."
    )
    thoughtsTokenCount: int | None = Field(default=None, description="Number of tokens present in thoughts output.")
    toolUsePromptTokenCount: int | None = Field(default=None, description="Number of tokens present in tool-use prompt(s).")
224+
225+
226+
class GeminiGenerateContentResponse(BaseModel):
    """Top-level generateContent response.

    `candidates` may be absent when the request was blocked; in that case
    `promptFeedback` carries the block reason.
    """

    candidates: list[GeminiCandidate] | None = None
    promptFeedback: GeminiPromptFeedback | None = None
    usageMetadata: GeminiUsageMetadata | None = None

comfy_api_nodes/nodes_gemini.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,35 +3,32 @@
33
See: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
44
"""
55

6-
from __future__ import annotations
7-
86
import base64
97
import json
108
import os
119
import time
1210
import uuid
1311
from enum import Enum
1412
from io import BytesIO
15-
from typing import Literal, Optional
13+
from typing import Literal
1614

1715
import torch
1816
from typing_extensions import override
1917

2018
import folder_paths
2119
from comfy_api.latest import IO, ComfyExtension, Input
2220
from comfy_api.util import VideoCodec, VideoContainer
23-
from comfy_api_nodes.apis import (
21+
from comfy_api_nodes.apis.gemini_api import (
2422
GeminiContent,
2523
GeminiGenerateContentRequest,
2624
GeminiGenerateContentResponse,
27-
GeminiInlineData,
28-
GeminiMimeType,
29-
GeminiPart,
30-
)
31-
from comfy_api_nodes.apis.gemini_api import (
3225
GeminiImageConfig,
3326
GeminiImageGenerateContentRequest,
3427
GeminiImageGenerationConfig,
28+
GeminiInlineData,
29+
GeminiMimeType,
30+
GeminiPart,
31+
GeminiRole,
3532
)
3633
from comfy_api_nodes.util import (
3734
ApiEndpoint,
@@ -57,6 +54,7 @@ class GeminiModel(str, Enum):
5754
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
5855
gemini_2_5_pro = "gemini-2.5-pro"
5956
gemini_2_5_flash = "gemini-2.5-flash"
57+
gemini_3_0_pro = "gemini-3-pro-preview"
6058

6159

6260
class GeminiImageModel(str, Enum):
@@ -103,6 +101,16 @@ def get_parts_by_type(response: GeminiGenerateContentResponse, part_type: Litera
103101
Returns:
104102
List of response parts matching the requested type.
105103
"""
104+
if response.candidates is None:
105+
if response.promptFeedback.blockReason:
106+
feedback = response.promptFeedback
107+
raise ValueError(
108+
f"Gemini API blocked the request. Reason: {feedback.blockReason} ({feedback.blockReasonMessage})"
109+
)
110+
raise NotImplementedError(
111+
"Gemini returned no response candidates. "
112+
"Please report to ComfyUI repository with the example of workflow to reproduce this."
113+
)
106114
parts = []
107115
for part in response.candidates[0].content.parts:
108116
if part_type == "text" and hasattr(part, "text") and part.text:
@@ -272,10 +280,10 @@ async def execute(
272280
prompt: str,
273281
model: str,
274282
seed: int,
275-
images: Optional[torch.Tensor] = None,
276-
audio: Optional[Input.Audio] = None,
277-
video: Optional[Input.Video] = None,
278-
files: Optional[list[GeminiPart]] = None,
283+
images: torch.Tensor | None = None,
284+
audio: Input.Audio | None = None,
285+
video: Input.Video | None = None,
286+
files: list[GeminiPart] | None = None,
279287
) -> IO.NodeOutput:
280288
validate_string(prompt, strip_whitespace=False)
281289

@@ -300,15 +308,14 @@ async def execute(
300308
data=GeminiGenerateContentRequest(
301309
contents=[
302310
GeminiContent(
303-
role="user",
311+
role=GeminiRole.user,
304312
parts=parts,
305313
)
306314
]
307315
),
308316
response_model=GeminiGenerateContentResponse,
309317
)
310318

311-
# Get result output
312319
output_text = get_text_from_response(response)
313320
if output_text:
314321
# Not a true chat history like the OpenAI Chat node. It is emulated so the frontend can show a copy button.
@@ -406,7 +413,7 @@ def create_file_part(cls, file_path: str) -> GeminiPart:
406413
)
407414

408415
@classmethod
409-
def execute(cls, file: str, GEMINI_INPUT_FILES: Optional[list[GeminiPart]] = None) -> IO.NodeOutput:
416+
def execute(cls, file: str, GEMINI_INPUT_FILES: list[GeminiPart] | None = None) -> IO.NodeOutput:
410417
"""Loads and formats input files for Gemini API."""
411418
if GEMINI_INPUT_FILES is None:
412419
GEMINI_INPUT_FILES = []
@@ -421,7 +428,7 @@ class GeminiImage(IO.ComfyNode):
421428
def define_schema(cls):
422429
return IO.Schema(
423430
node_id="GeminiImageNode",
424-
display_name="Google Gemini Image",
431+
display_name="Nano Banana (Google Gemini Image)",
425432
category="api node/image/Gemini",
426433
description="Edit images synchronously via Google API.",
427434
inputs=[
@@ -488,8 +495,8 @@ async def execute(
488495
prompt: str,
489496
model: str,
490497
seed: int,
491-
images: Optional[torch.Tensor] = None,
492-
files: Optional[list[GeminiPart]] = None,
498+
images: torch.Tensor | None = None,
499+
files: list[GeminiPart] | None = None,
493500
aspect_ratio: str = "auto",
494501
) -> IO.NodeOutput:
495502
validate_string(prompt, strip_whitespace=True, min_length=1)
@@ -510,7 +517,7 @@ async def execute(
510517
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model}", method="POST"),
511518
data=GeminiImageGenerateContentRequest(
512519
contents=[
513-
GeminiContent(role="user", parts=parts),
520+
GeminiContent(role=GeminiRole.user, parts=parts),
514521
],
515522
generationConfig=GeminiImageGenerationConfig(
516523
responseModalities=["TEXT", "IMAGE"],

0 commit comments

Comments
 (0)