|
1 | | -from typing import Optional |
| 1 | +from datetime import date |
| 2 | +from enum import Enum |
| 3 | +from typing import Any |
2 | 4 |
|
3 | | -from comfy_api_nodes.apis import GeminiGenerationConfig, GeminiContent, GeminiSafetySetting, GeminiSystemInstructionContent, GeminiTool, GeminiVideoMetadata |
4 | | -from pydantic import BaseModel |
| 5 | +from pydantic import BaseModel, Field |
| 6 | + |
| 7 | + |
class GeminiSafetyCategory(str, Enum):
    """Harm categories used by ``GeminiSafetySetting`` and ``GeminiSafetyRating``.

    Members are ``str`` subclasses, so each one compares equal to its wire value.
    """

    HARM_CATEGORY_SEXUALLY_EXPLICIT = "HARM_CATEGORY_SEXUALLY_EXPLICIT"
    HARM_CATEGORY_HATE_SPEECH = "HARM_CATEGORY_HATE_SPEECH"
    HARM_CATEGORY_HARASSMENT = "HARM_CATEGORY_HARASSMENT"
    HARM_CATEGORY_DANGEROUS_CONTENT = "HARM_CATEGORY_DANGEROUS_CONTENT"
| 13 | + |
| 14 | + |
class GeminiSafetyThreshold(str, Enum):
    """Blocking thresholds selectable per category in ``GeminiSafetySetting``.

    Members are ``str`` subclasses, so each one compares equal to its wire value.
    """

    OFF = "OFF"
    BLOCK_NONE = "BLOCK_NONE"
    BLOCK_LOW_AND_ABOVE = "BLOCK_LOW_AND_ABOVE"
    BLOCK_MEDIUM_AND_ABOVE = "BLOCK_MEDIUM_AND_ABOVE"
    BLOCK_ONLY_HIGH = "BLOCK_ONLY_HIGH"
| 21 | + |
| 22 | + |
class GeminiSafetySetting(BaseModel):
    """Pairs one harm category with the blocking threshold to apply to it.

    Both fields are required.
    """

    category: GeminiSafetyCategory
    threshold: GeminiSafetyThreshold
| 26 | + |
| 27 | + |
class GeminiRole(str, Enum):
    """Author of a content message: ``user`` or ``model``.

    Member names are intentionally lowercase and identical to their values.
    """

    user = "user"
    model = "model"
| 31 | + |
| 32 | + |
class GeminiMimeType(str, Enum):
    """MIME types accepted for ``GeminiInlineData.mimeType``.

    Member names are the MIME type with ``/`` replaced by ``_``.
    """

    # Documents
    application_pdf = "application/pdf"
    # Audio
    audio_mpeg = "audio/mpeg"
    audio_mp3 = "audio/mp3"
    audio_wav = "audio/wav"
    # Images
    image_png = "image/png"
    image_jpeg = "image/jpeg"
    image_webp = "image/webp"
    # Plain text
    text_plain = "text/plain"
    # Video
    video_mov = "video/mov"
    video_mpeg = "video/mpeg"
    video_mp4 = "video/mp4"
    video_mpg = "video/mpg"
    video_avi = "video/avi"
    video_wmv = "video/wmv"
    video_mpegps = "video/mpegps"
    video_flv = "video/flv"
| 50 | + |
| 51 | + |
class GeminiInlineData(BaseModel):
    """Media embedded directly in a prompt part as base64 text plus its MIME type.

    Both fields default to ``None``.
    """

    data: str | None = Field(
        None,
        description="The base64 encoding of the image, PDF, or video to include inline in the prompt. "
        "When including media inline, you must also specify the media type (mimeType) of the data. Size limit: 20MB",
    )
    mimeType: GeminiMimeType | None = Field(None)
| 59 | + |
| 60 | + |
class GeminiPart(BaseModel):
    """One part of a content message: inline media and/or text.

    Both fields are optional and default to ``None``; the model does not enforce
    that at least one of them is set.
    """

    inlineData: GeminiInlineData | None = Field(None)
    text: str | None = Field(None)
| 64 | + |
| 65 | + |
class GeminiTextPart(BaseModel):
    """Text-only part, used by ``GeminiSystemInstructionContent.parts``."""

    text: str | None = Field(None)
| 68 | + |
| 69 | + |
class GeminiContent(BaseModel):
    """A single message: an ordered list of parts plus the role of its author.

    Both fields are required.
    """

    parts: list[GeminiPart] = Field(...)
    role: GeminiRole = Field(..., examples=["user"])
| 73 | + |
| 74 | + |
class GeminiSystemInstructionContent(BaseModel):
    """System-instruction message: text-only parts plus a role.

    Both fields are required by this model.
    """

    parts: list[GeminiTextPart] = Field(
        ...,
        description="A list of ordered parts that make up a single message. "
        "Different parts may have different IANA MIME types.",
    )
    # NOTE(review): the description below says the field "can be left blank or
    # unset", but `Field(...)` makes it required here. Confirm which is intended
    # before relaxing it, since callers may rely on validation failing.
    role: GeminiRole = Field(
        ...,
        description="The identity of the entity that creates the message. "
        "The following values are supported: "
        "user: This indicates that the message is sent by a real person, typically a user-generated message. "
        "model: This indicates that the message is generated by the model. "
        "The model value is used to insert messages from model into the conversation during multi-turn conversations. "
        "For non-multi-turn conversations, this field can be left blank or unset.",
    )
| 90 | + |
| 91 | + |
class GeminiFunctionDeclaration(BaseModel):
    """Declaration of a callable function exposed through ``GeminiTool``.

    ``name`` and ``parameters`` are required; ``description`` is optional.
    """

    description: str | None = Field(None)
    name: str = Field(...)
    parameters: dict[str, Any] = Field(..., description="JSON schema for the function parameters")
| 96 | + |
| 97 | + |
class GeminiTool(BaseModel):
    """Tool definition carrying an optional list of function declarations."""

    functionDeclarations: list[GeminiFunctionDeclaration] | None = Field(None)
| 100 | + |
| 101 | + |
class GeminiOffset(BaseModel):
    """Time offset split into whole seconds and a nanosecond remainder.

    ``nanos`` is validated to 0..999_999_999 and ``seconds`` to
    -315_576_000_000..315_576_000_000; both default to ``None``.
    """

    nanos: int | None = Field(None, ge=0, le=999999999)
    seconds: int | None = Field(None, ge=-315576000000, le=315576000000)
| 105 | + |
| 106 | + |
class GeminiVideoMetadata(BaseModel):
    """Optional start and end offsets attached to a request as ``videoMetadata``."""

    endOffset: GeminiOffset | None = Field(None)
    startOffset: GeminiOffset | None = Field(None)
| 110 | + |
| 111 | + |
class GeminiGenerationConfig(BaseModel):
    """Sampling and output-length settings for a generation request.

    Validation bounds: ``maxOutputTokens`` 16..8192, ``temperature`` 0.0..2.0,
    ``topK`` >= 1, ``topP`` 0.0..1.0. Every field also accepts ``None``.
    """

    maxOutputTokens: int | None = Field(None, ge=16, le=8192)
    seed: int | None = Field(None)
    stopSequences: list[str] | None = Field(None)
    # Default written as a float literal so it matches the annotation; Pydantic
    # does not validate/coerce defaults unless `validate_default` is enabled.
    temperature: float | None = Field(1.0, ge=0.0, le=2.0)
    topK: int | None = Field(40, ge=1)
    topP: float | None = Field(0.95, ge=0.0, le=1.0)
5 | 119 |
|
6 | 120 |
|
class GeminiImageConfig(BaseModel):
    """Image output options: aspect ratio and resolution, both optional strings.

    The accepted string values are not constrained by this model.
    """

    aspectRatio: str | None = Field(None)
    resolution: str | None = Field(None)
9 | 124 |
|
10 | 125 |
|
class GeminiImageGenerationConfig(GeminiGenerationConfig):
    """Generation config extended with image-specific options.

    Inherits every field of ``GeminiGenerationConfig`` and adds the requested
    response modalities and an optional ``GeminiImageConfig``.
    """

    responseModalities: list[str] | None = Field(None)
    imageConfig: GeminiImageConfig | None = Field(None)
14 | 129 |
|
15 | 130 |
|
class GeminiImageGenerateContentRequest(BaseModel):
    """Generate-content request whose config is ``GeminiImageGenerationConfig``.

    Only ``contents`` is required; every other field defaults to ``None``.
    """

    contents: list[GeminiContent] = Field(...)
    generationConfig: GeminiImageGenerationConfig | None = Field(None)
    safetySettings: list[GeminiSafetySetting] | None = Field(None)
    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
    tools: list[GeminiTool] | None = Field(None)
    videoMetadata: GeminiVideoMetadata | None = Field(None)
| 138 | + |
| 139 | + |
class GeminiGenerateContentRequest(BaseModel):
    """Generate-content request using the base ``GeminiGenerationConfig``.

    Only ``contents`` is required; every other field defaults to ``None``.
    """

    contents: list[GeminiContent] = Field(...)
    generationConfig: GeminiGenerationConfig | None = Field(None)
    safetySettings: list[GeminiSafetySetting] | None = Field(None)
    systemInstruction: GeminiSystemInstructionContent | None = Field(None)
    tools: list[GeminiTool] | None = Field(None)
    videoMetadata: GeminiVideoMetadata | None = Field(None)
| 147 | + |
| 148 | + |
class Modality(str, Enum):
    """Content modality label used by ``ModalityTokenCount``.

    Members are ``str`` subclasses, so each one compares equal to its wire value.
    """

    MODALITY_UNSPECIFIED = "MODALITY_UNSPECIFIED"
    TEXT = "TEXT"
    IMAGE = "IMAGE"
    VIDEO = "VIDEO"
    AUDIO = "AUDIO"
    DOCUMENT = "DOCUMENT"
| 156 | + |
| 157 | + |
class ModalityTokenCount(BaseModel):
    """Token count attributed to a single modality; both fields are optional."""

    modality: Modality | None = None
    tokenCount: int | None = Field(None, description="Number of tokens for the given modality.")
| 161 | + |
| 162 | + |
class Probability(str, Enum):
    """Probability levels reported in ``GeminiSafetyRating.probability``.

    Members are ``str`` subclasses, so each one compares equal to its wire value.
    """

    NEGLIGIBLE = "NEGLIGIBLE"
    LOW = "LOW"
    MEDIUM = "MEDIUM"
    HIGH = "HIGH"
    UNKNOWN = "UNKNOWN"
| 169 | + |
| 170 | + |
class GeminiSafetyRating(BaseModel):
    """Safety rating: a harm category and the probability assigned to it.

    Both fields are optional and default to ``None``.
    """

    category: GeminiSafetyCategory | None = None
    probability: Probability | None = Field(
        None,
        description="The probability that the content violates the specified safety category",
    )
| 177 | + |
| 178 | + |
class GeminiCitation(BaseModel):
    """A single source citation; every field is optional.

    ``startIndex``/``endIndex`` are integer positions; what they index into is
    not specified by this model.
    """

    authors: list[str] | None = None
    endIndex: int | None = None
    license: str | None = None
    publicationDate: date | None = None
    startIndex: int | None = None
    title: str | None = None
    uri: str | None = None
| 187 | + |
| 188 | + |
class GeminiCitationMetadata(BaseModel):
    """Container for the optional list of citations attached to a candidate."""

    citations: list[GeminiCitation] | None = None
| 191 | + |
| 192 | + |
class GeminiCandidate(BaseModel):
    """One response candidate: content, finish reason, citations and safety ratings.

    Every field is optional. ``finishReason`` is kept as a free-form string
    rather than an enum.
    """

    citationMetadata: GeminiCitationMetadata | None = None
    content: GeminiContent | None = None
    finishReason: str | None = None
    safetyRatings: list[GeminiSafetyRating] | None = None
| 198 | + |
| 199 | + |
class GeminiPromptFeedback(BaseModel):
    """Prompt-level feedback: block reason, its message, and safety ratings.

    Every field is optional. ``blockReason`` is kept as a free-form string
    rather than an enum.
    """

    blockReason: str | None = None
    blockReasonMessage: str | None = None
    safetyRatings: list[GeminiSafetyRating] | None = None
| 204 | + |
| 205 | + |
class GeminiUsageMetadata(BaseModel):
    """Token accounting exposed as ``GeminiGenerateContentResponse.usageMetadata``.

    Every field is optional and defaults to ``None``.
    """

    cachedContentTokenCount: int | None = Field(
        None,
        description="Output only. Number of tokens in the cached part in the input (the cached content).",
    )
    candidatesTokenCount: int | None = Field(None, description="Number of tokens in the response(s).")
    candidatesTokensDetails: list[ModalityTokenCount] | None = Field(
        None, description="Breakdown of candidate tokens by modality."
    )
    promptTokenCount: int | None = Field(
        None,
        description="Number of tokens in the request. When cachedContent is set, this is still the total effective prompt size meaning this includes the number of tokens in the cached content.",
    )
    promptTokensDetails: list[ModalityTokenCount] | None = Field(
        None, description="Breakdown of prompt tokens by modality."
    )
    thoughtsTokenCount: int | None = Field(None, description="Number of tokens present in thoughts output.")
    toolUsePromptTokenCount: int | None = Field(None, description="Number of tokens present in tool-use prompt(s).")
| 224 | + |
| 225 | + |
class GeminiGenerateContentResponse(BaseModel):
    """Top-level generate-content response.

    Holds the candidate list, prompt feedback and usage metadata. Every field
    is optional, so callers must handle ``None`` (for example, no candidates).
    """

    candidates: list[GeminiCandidate] | None = Field(None)
    promptFeedback: GeminiPromptFeedback | None = Field(None)
    usageMetadata: GeminiUsageMetadata | None = Field(None)
0 commit comments