-
-
Notifications
You must be signed in to change notification settings - Fork 705
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Extract image generation code into new image processor for modularity
- Loading branch information
Showing
3 changed files
with
214 additions
and
186 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,212 @@ | ||
import base64 | ||
import io | ||
import logging | ||
import time | ||
from typing import Any, Callable, Dict, List, Optional | ||
|
||
import openai | ||
import requests | ||
|
||
from khoj.database.adapters import ConversationAdapters | ||
from khoj.database.models import KhojUser, TextToImageModelConfig | ||
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt | ||
from khoj.routers.storage import upload_image | ||
from khoj.utils import state | ||
from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer | ||
from khoj.utils.rawconfig import LocationData | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
async def text_to_image(
    message: str,
    user: KhojUser,
    conversation_log: dict,
    location_data: LocationData,
    references: List[Dict[str, Any]],
    online_results: Dict[str, Any],
    subscribed: bool = False,
    send_status_func: Optional[Callable] = None,
    uploaded_image_url: Optional[str] = None,
):
    """Generate an image for the user's message with their configured text to image model.

    This is an async generator. It yields two kinds of items:

    * ``{ChatEvent.STATUS: event}`` dicts, relayed from ``send_status_func`` when one is provided.
    * Exactly one final tuple ``(image, status_code, message, intent_type)`` where
      ``image`` is the uploaded image URL, else the base64 encoded webp image, else ``None`` on failure;
      ``status_code`` is 200 on success or an error code on failure;
      ``message`` is the improved image prompt on success or an error description on failure;
      ``intent_type`` is the value of the ``ImageIntentType`` describing how the image is returned.

    Args:
        message: The user's image generation request.
        user: The user to generate the image for. Used to look up their image model and to upload the image.
        conversation_log: Conversation so far. The last 4 entries under its "chat" key are used as context.
        location_data: The user's location, passed on to the prompt enhancer.
        references: Notes passed on to the prompt enhancer.
        online_results: Online search results passed on to the prompt enhancer.
        subscribed: Whether the user is subscribed, passed on to the prompt enhancer.
        send_status_func: Optional callable returning an async iterable of status events to relay.
        uploaded_image_url: Optional URL of an image uploaded by the user, passed on to the prompt enhancer.
    """
    status_code = 200
    image = None
    image_url = None
    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
    if not text_to_image_config:
        # If the user has not configured a text to image model, return an unsupported on server error
        status_code = 501
        message = "Failed to generate image. Setup image generation on the server."
        yield image_url or image, status_code, message, intent_type.value
        return

    text2image_model = text_to_image_config.model_name

    # Construct chat history from the recent khoj responses for context
    history_lines: List[str] = []
    for chat in conversation_log.get("chat", [])[-4:]:
        if chat["by"] != "khoj":
            continue
        # The intent type may be unset. Default to empty string so the substring check below cannot fail
        chat_intent_type = chat["intent"].get("type") or ""
        if chat_intent_type in ["remember", "reminder"]:
            history_lines.append(f"Q: {chat['intent']['query']}\n")
            history_lines.append(f"A: {chat['message']}\n")
        elif "text-to-image" in chat_intent_type:
            history_lines.append(f"Q: Prompt: {chat['intent']['query']}\n")
            history_lines.append(f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n")
    chat_history = "".join(history_lines)

    if send_status_func:
        async for event in send_status_func("**Enhancing the Painting Prompt**"):
            yield {ChatEvent.STATUS: event}

    # Generate a better image prompt
    # Use the user's message, chat history, and other context
    image_prompt = await generate_better_image_prompt(
        message,
        chat_history,
        location_data=location_data,
        note_references=references,
        online_results=online_results,
        model_type=text_to_image_config.model_type,
        subscribed=subscribed,
        uploaded_image_url=uploaded_image_url,
    )

    if send_status_func:
        async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
            yield {ChatEvent.STATUS: event}

    # Generate image using the configured model and API
    # NOTE(review): the generate_image_with_* helpers are synchronous, so they block the event loop while running
    with timer(f"Generate image with {text_to_image_config.model_type}", logger):
        try:
            if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
                webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
            elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
                webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
            elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
                webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
            else:
                # Report unsupported model types explicitly instead of failing later on an unbound variable
                logger.error(f"Unsupported text to image model type: {text_to_image_config.model_type}")
                status_code = 501
                message = f"Failed to generate image. Unsupported image model type: {text_to_image_config.model_type}"
                yield image_url or image, status_code, message, intent_type.value
                return
        except openai.OpenAIError as e:
            # Not every OpenAI error carries a message or HTTP status code (e.g connection errors).
            # Fall back to the stringified error and a generic server error code
            error_message = getattr(e, "message", None) or str(e)
            status_code = getattr(e, "status_code", 500)
            if "content_policy_violation" in error_message:
                logger.error(f"Image Generation blocked by OpenAI: {e}")
                message = f"Image generation blocked by OpenAI: {error_message}"
            else:
                logger.error(f"Image Generation failed with {e}", exc_info=True)
                message = f"Image generation failed with OpenAI error: {error_message}"
            yield image_url or image, status_code, message, intent_type.value
            return
        except requests.RequestException as e:
            logger.error(f"Image Generation failed with {e}", exc_info=True)
            message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
            status_code = 502
            yield image_url or image, status_code, message, intent_type.value
            return

    # Decide how to store the generated image
    with timer("Upload image to S3", logger):
        image_url = upload_image(webp_image_bytes, user.uuid)
        if image_url:
            intent_type = ImageIntentType.TEXT_TO_IMAGE2
        else:
            # Return the image inline as base64 when no uploaded image URL is available
            intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
            image = base64.b64encode(webp_image_bytes).decode("utf-8")

    yield image_url or image, status_code, image_prompt, intent_type.value
|
||
|
||
def generate_image_with_openai(
    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
    """Generate image using OpenAI API.

    Args:
        improved_image_prompt: The prompt to generate the image from.
        text_to_image_config: Image model configuration. Its API key, else the API key of its
            linked OpenAI config, is preferred over the API key of the shared OpenAI client.
        text2image_model: Name of the OpenAI image model to use.

    Returns:
        The generated image, converted to webp by ``convert_image_to_webp``.
    """

    # Get the API key from the user's configuration
    # Default to no key so the auth header check below is safe when no key source is configured
    api_key = None
    if text_to_image_config.api_key:
        api_key = text_to_image_config.api_key
    elif text_to_image_config.openai_config:
        api_key = text_to_image_config.openai_config.api_key
    elif state.openai_client:
        api_key = state.openai_client.api_key
    auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}

    # Generate image using OpenAI API
    # The request always goes via the shared client. The resolved API key is passed via the auth header
    OPENAI_IMAGE_GEN_STYLE = "vivid"
    response = state.openai_client.images.generate(
        prompt=improved_image_prompt,
        model=text2image_model,
        style=OPENAI_IMAGE_GEN_STYLE,
        response_format="b64_json",
        extra_headers=auth_header,
    )

    # Extract the base64 image from the response
    image = response.data[0].b64_json
    # Decode base64 png and convert it to webp for faster loading
    return convert_image_to_webp(base64.b64decode(image))
|
||
|
||
def generate_image_with_stability(
    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
    """Generate image using Stability AI.

    Args:
        improved_image_prompt: The prompt to generate the image from.
        text_to_image_config: Image model configuration. Its API key authenticates the request.
        text2image_model: Name of the Stability AI model to use.

    Returns:
        The generated image, converted to webp by ``convert_image_to_webp``.

    Raises:
        requests.RequestException: If the request fails or the API responds with an HTTP error status.
    """

    # Call Stability AI API to generate image
    response = requests.post(
        "https://api.stability.ai/v2beta/stable-image/generate/sd3",
        headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
        # Passing files makes requests encode the body as multipart/form-data
        # NOTE(review): presumably required by the Stability API - confirm
        files={"none": ""},
        data={
            "prompt": improved_image_prompt,
            "model": text2image_model,
            "mode": "text-to-image",
            "output_format": "png",
            "aspect_ratio": "1:1",
        },
    )
    # Fail on HTTP errors instead of treating the error response body as image data
    response.raise_for_status()

    # Convert png to webp for faster loading
    return convert_image_to_webp(response.content)
|
||
|
||
def generate_image_with_replicate(
    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
    """Generate image using Replicate API.

    Creates an image generation task on Replicate, polls it until it reaches a terminal
    status or the polling budget is exhausted, then downloads the generated image.

    Args:
        improved_image_prompt: The prompt to generate the image from.
        text_to_image_config: Image model configuration. Its API key authenticates the requests.
        text2image_model: Name of the Replicate model to use. Interpolated into the predictions URL.

    Returns:
        The raw bytes of the generated image. The image is requested in webp format.

    Raises:
        requests.RequestException: If any request fails, an API call responds with an HTTP error
            status, the task fails or is canceled, or the task does not finish in time.
    """
    terminal_statuses = ["succeeded", "failed", "canceled"]
    max_status_checks = 20
    seconds_between_checks = 2

    # Create image generation task on Replicate
    replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
    headers = {
        "Authorization": f"Bearer {text_to_image_config.api_key}",
        "Content-Type": "application/json",
    }
    payload = {
        "input": {
            "prompt": improved_image_prompt,
            "num_outputs": 1,
            "aspect_ratio": "1:1",
            "output_format": "webp",
            "output_quality": 100,
        }
    }
    create_response = requests.post(replicate_create_prediction_url, headers=headers, json=payload)
    # Fail on HTTP errors here so callers see a request error rather than a KeyError on the error body
    create_response.raise_for_status()
    create_prediction = create_response.json()

    def fetch_prediction(url: str) -> dict:
        "Get the current state of the image generation task"
        prediction_response = requests.get(url, headers=headers)
        prediction_response.raise_for_status()
        return prediction_response.json()

    # Get status of image generation task
    get_prediction_url = create_prediction["urls"]["get"]
    get_prediction = fetch_prediction(get_prediction_url)
    status = get_prediction["status"]
    retry_count = 1

    # Poll the image generation task for completion status
    while status not in terminal_statuses and retry_count < max_status_checks:
        time.sleep(seconds_between_checks)
        get_prediction = fetch_prediction(get_prediction_url)
        status = get_prediction["status"]
        retry_count += 1

    # Raise exception if the image generation task fails
    if status != "succeeded":
        # Only report a timeout when the task was still in progress once the polling budget ran out
        if status not in terminal_statuses:
            raise requests.RequestException("Image generation timed out")
        raise requests.RequestException(f"Image generation failed with status: {status}")

    # Get the generated image
    # The task output is handled as either a list of image urls or a single image url
    output = get_prediction["output"]
    image_url = output[0] if isinstance(output, list) else output
    image_response = requests.get(image_url)
    image_response.raise_for_status()
    return image_response.content
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.