diff --git a/src/khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py b/src/khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py
new file mode 100644
index 000000000..4431a9d86
--- /dev/null
+++ b/src/khoj/database/migrations/0061_alter_texttoimagemodelconfig_model_type.py
@@ -0,0 +1,21 @@
+# Generated by Django 5.0.7 on 2024-09-12 05:43
+
+from django.db import migrations, models
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0060_merge_20240905_1828"),
+    ]
+
+    operations = [
+        migrations.AlterField(
+            model_name="texttoimagemodelconfig",
+            name="model_type",
+            field=models.CharField(
+                choices=[("openai", "Openai"), ("stability-ai", "Stabilityai"), ("replicate", "Replicate")],
+                default="openai",
+                max_length=200,
+            ),
+        ),
+    ]
diff --git a/src/khoj/database/migrations/0062_merge_20240913_0222.py b/src/khoj/database/migrations/0062_merge_20240913_0222.py
new file mode 100644
index 000000000..51175c50c
--- /dev/null
+++ b/src/khoj/database/migrations/0062_merge_20240913_0222.py
@@ -0,0 +1,14 @@
+# Generated by Django 5.0.8 on 2024-09-13 02:22
+
+from typing import List
+
+from django.db import migrations
+
+
+class Migration(migrations.Migration):
+    dependencies = [
+        ("database", "0061_alter_chatmodeloptions_model_type"),
+        ("database", "0061_alter_texttoimagemodelconfig_model_type"),
+    ]
+
+    operations: List[str] = []
diff --git a/src/khoj/database/models/__init__.py b/src/khoj/database/models/__init__.py
index 80769de8c..4029cf3c9 100644
--- a/src/khoj/database/models/__init__.py
+++ b/src/khoj/database/models/__init__.py
@@ -280,6 +280,7 @@ class TextToImageModelConfig(BaseModel):
     class ModelType(models.TextChoices):
         OPENAI = "openai"
         STABILITYAI = "stability-ai"
+        REPLICATE = "replicate"
 
     model_name = models.CharField(max_length=200, default="dall-e-3")
     model_type = models.CharField(max_length=200, choices=ModelType.choices, default=ModelType.OPENAI)
diff --git a/src/khoj/processor/conversation/prompts.py b/src/khoj/processor/conversation/prompts.py
index 24bf8fddc..13e2b72c6 100644
--- a/src/khoj/processor/conversation/prompts.py
+++ b/src/khoj/processor/conversation/prompts.py
@@ -128,8 +128,8 @@
 
 ## --
 image_generation_improve_prompt_base = """
-You are a talented creator with the ability to describe images to compose in vivid, fine detail.
-Use the provided context and user prompt to generate a more detailed prompt to create an image:
+You are a talented media artist with the ability to describe images to compose in professional, fine detail.
+Generate a vivid description of the image to be rendered using the provided context and user prompt below:
 
 Today's Date: {current_date}
 User's Location: {location}
@@ -145,10 +145,10 @@
 
 User Prompt: "{query}"
 
-Now generate an improved prompt describing the image to generate in vivid, fine detail.
+Now generate a professional description of the image to generate in vivid, fine detail.
 - Use today's date, user's location, user's notes and online references to weave in any context that will improve the image generation.
 - Retain any important information and follow any instructions in the conversation log or user prompt.
-- Add specific, fine position details to compose the image.
+- Add specific, fine position details. Mention painting style and camera parameters to compose the image.
 - Ensure your improved prompt is in prose format."""
 
 image_generation_improve_prompt_dalle = PromptTemplate.from_template(
diff --git a/src/khoj/processor/image/generate.py b/src/khoj/processor/image/generate.py
new file mode 100644
index 000000000..200473ebd
--- /dev/null
+++ b/src/khoj/processor/image/generate.py
@@ -0,0 +1,212 @@
+import base64
+import io
+import logging
+import time
+from typing import Any, Callable, Dict, List, Optional
+
+import openai
+import requests
+
+from khoj.database.adapters import ConversationAdapters
+from khoj.database.models import KhojUser, TextToImageModelConfig
+from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
+from khoj.routers.storage import upload_image
+from khoj.utils import state
+from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
+from khoj.utils.rawconfig import LocationData
+
+logger = logging.getLogger(__name__)
+
+
+async def text_to_image(
+    message: str,
+    user: KhojUser,
+    conversation_log: dict,
+    location_data: LocationData,
+    references: List[Dict[str, Any]],
+    online_results: Dict[str, Any],
+    subscribed: bool = False,
+    send_status_func: Optional[Callable] = None,
+    uploaded_image_url: Optional[str] = None,
+):
+    status_code = 200
+    image = None
+    image_url = None
+    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+
+    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
+    if not text_to_image_config:
+        # If the user has not configured a text to image model, return an unsupported on server error
+        status_code = 501
+        message = "Failed to generate image. Setup image generation on the server."
+        yield image_url or image, status_code, message, intent_type.value
+        return
+
+    text2image_model = text_to_image_config.model_name
+    chat_history = ""
+    for chat in conversation_log.get("chat", [])[-4:]:
+        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
+            chat_history += f"Q: {chat['intent']['query']}\n"
+            chat_history += f"A: {chat['message']}\n"
+        elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
+            chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
+            chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
+
+    if send_status_func:
+        async for event in send_status_func("**Enhancing the Painting Prompt**"):
+            yield {ChatEvent.STATUS: event}
+
+    # Generate a better image prompt
+    # Use the user's message, chat history, and other context
+    image_prompt = await generate_better_image_prompt(
+        message,
+        chat_history,
+        location_data=location_data,
+        note_references=references,
+        online_results=online_results,
+        model_type=text_to_image_config.model_type,
+        subscribed=subscribed,
+        uploaded_image_url=uploaded_image_url,
+    )
+
+    if send_status_func:
+        async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
+            yield {ChatEvent.STATUS: event}
+
+    # Generate image using the configured model and API
+    with timer(f"Generate image with {text_to_image_config.model_type}", logger):
+        try:
+            if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
+                webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
+            elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
+                webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
+            elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
+                webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
+        except (openai.OpenAIError, openai.BadRequestError, openai.APIConnectionError) as e:
+            if "content_policy_violation" in e.message:
+                logger.error(f"Image Generation blocked by OpenAI: {e}")
+                status_code = e.status_code  # type: ignore
+                message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
+                yield image_url or image, status_code, message, intent_type.value
+                return
+            else:
+                logger.error(f"Image Generation failed with {e}", exc_info=True)
+                message = f"Image generation failed with OpenAI error: {e.message}"  # type: ignore
+                status_code = e.status_code  # type: ignore
+                yield image_url or image, status_code, message, intent_type.value
+                return
+        except requests.RequestException as e:
+            logger.error(f"Image Generation failed with {e}", exc_info=True)
+            message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
+            status_code = 502
+            yield image_url or image, status_code, message, intent_type.value
+            return
+
+    # Decide how to store the generated image
+    with timer("Upload image to S3", logger):
+        image_url = upload_image(webp_image_bytes, user.uuid)
+        if image_url:
+            intent_type = ImageIntentType.TEXT_TO_IMAGE2
+        else:
+            intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
+            image = base64.b64encode(webp_image_bytes).decode("utf-8")
+
+    yield image_url or image, status_code, image_prompt, intent_type.value
+
+
+def generate_image_with_openai(
+    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+):
+    "Generate image using OpenAI API"
+
+    # Get the API key from the user's configuration
+    if text_to_image_config.api_key:
+        api_key = text_to_image_config.api_key
+    elif text_to_image_config.openai_config:
+        api_key = text_to_image_config.openai_config.api_key
+    elif state.openai_client:
+        api_key = state.openai_client.api_key
+    auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
+
+    # Generate image using OpenAI API
+    OPENAI_IMAGE_GEN_STYLE = "vivid"
+    response = state.openai_client.images.generate(
+        prompt=improved_image_prompt,
+        model=text2image_model,
+        style=OPENAI_IMAGE_GEN_STYLE,
+        response_format="b64_json",
+        extra_headers=auth_header,
+    )
+
+    # Extract the base64 image from the response
+    image = response.data[0].b64_json
+    # Decode base64 png and convert it to webp for faster loading
+    return convert_image_to_webp(base64.b64decode(image))
+
+
+def generate_image_with_stability(
+    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+):
+    "Generate image using Stability AI"
+
+    # Call Stability AI API to generate image
+    response = requests.post(
+        f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
+        headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
+        files={"none": ""},
+        data={
+            "prompt": improved_image_prompt,
+            "model": text2image_model,
+            "mode": "text-to-image",
+            "output_format": "png",
+            "aspect_ratio": "1:1",
+        },
+    )
+    # Convert png to webp for faster loading
+    return convert_image_to_webp(response.content)
+
+
+def generate_image_with_replicate(
+    improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
+):
+    "Generate image using Replicate API"
+
+    # Create image generation task on Replicate
+    replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
+    headers = {
"Authorization": f"Bearer {text_to_image_config.api_key}", + "Content-Type": "application/json", + } + json = { + "input": { + "prompt": improved_image_prompt, + "num_outputs": 1, + "aspect_ratio": "1:1", + "output_format": "webp", + "output_quality": 100, + } + } + create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json() + + # Get status of image generation task + get_prediction_url = create_prediction["urls"]["get"] + get_prediction = requests.get(get_prediction_url, headers=headers).json() + status = get_prediction["status"] + retry_count = 1 + + # Poll the image generation task for completion status + while status not in ["succeeded", "failed", "canceled"] and retry_count < 20: + time.sleep(2) + get_prediction = requests.get(get_prediction_url, headers=headers).json() + status = get_prediction["status"] + retry_count += 1 + + # Raise exception if the image generation task fails + if status != "succeeded": + if retry_count >= 10: + raise requests.RequestException("Image generation timed out") + raise requests.RequestException(f"Image generation failed with status: {status}") + + # Get the generated image + image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"] + return io.BytesIO(requests.get(image_url).content).getvalue() diff --git a/src/khoj/routers/api_chat.py b/src/khoj/routers/api_chat.py index 73a8816c4..181593e82 100644 --- a/src/khoj/routers/api_chat.py +++ b/src/khoj/routers/api_chat.py @@ -26,6 +26,7 @@ from khoj.database.models import KhojUser from khoj.processor.conversation.prompts import help_message, no_entries_found from khoj.processor.conversation.utils import save_to_conversation_log +from khoj.processor.image.generate import text_to_image from khoj.processor.speech.text_to_speech import generate_text_to_speech from khoj.processor.tools.online_search import read_webpages, search_online from khoj.routers.api import extract_references_and_questions @@ -44,7 +45,6 @@ is_query_empty, is_ready_to_chat, read_chat_stream, - text_to_image, update_telemetry_state, validate_conversation_config, ) diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py index f1b8ddd66..0fd40e5a3 100644 --- a/src/khoj/routers/helpers.py +++ b/src/khoj/routers/helpers.py @@ -1,7 +1,5 @@ import asyncio -import base64 import hashlib -import io import json import logging import math @@ -16,7 +14,6 @@ Annotated, Any, AsyncGenerator, - Callable, Dict, Iterator, List, @@ -24,17 +21,15 @@ Tuple, Union, ) -from urllib.parse import parse_qs, urlencode, urljoin, urlparse +from urllib.parse import parse_qs, urljoin, urlparse import cron_descriptor -import openai import pytz import requests from apscheduler.job import Job from apscheduler.triggers.cron import CronTrigger from asgiref.sync import sync_to_async from fastapi import Depends, Header, HTTPException, Request, UploadFile -from PIL import Image from starlette.authentication import has_required_scope from starlette.requests import URL @@ -93,7 +88,6 @@ ) from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled from khoj.routers.email import is_resend_enabled, send_task_email -from khoj.routers.storage import upload_image from khoj.routers.twilio import is_twilio_enabled from khoj.search_type import text_search from khoj.utils import state @@ -101,8 +95,6 @@ from khoj.utils.helpers import ( LRU, ConversationCommand, - ImageIntentType, - convert_image_to_webp, is_none_or_empty, is_valid_url, log_telemetry, @@ -568,7 
             references=user_references,
             online_results=simplified_online_results,
         )
-    elif model_type == TextToImageModelConfig.ModelType.STABILITYAI:
+    elif model_type in [TextToImageModelConfig.ModelType.STABILITYAI, TextToImageModelConfig.ModelType.REPLICATE]:
         image_prompt = prompts.image_generation_improve_prompt_sd.format(
             query=q,
             chat_history=conversation_history,
@@ -921,129 +913,6 @@ def generate_chat_response(
     return chat_response, metadata
 
 
-async def text_to_image(
-    message: str,
-    user: KhojUser,
-    conversation_log: dict,
-    location_data: LocationData,
-    references: List[Dict[str, Any]],
-    online_results: Dict[str, Any],
-    subscribed: bool = False,
-    send_status_func: Optional[Callable] = None,
-    uploaded_image_url: Optional[str] = None,
-):
-    status_code = 200
-    image = None
-    response = None
-    image_url = None
-    intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
-
-    text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
-    if not text_to_image_config:
-        # If the user has not configured a text to image model, return an unsupported on server error
-        status_code = 501
-        message = "Failed to generate image. Setup image generation on the server."
-        yield image_url or image, status_code, message, intent_type.value
-        return
-
-    text2image_model = text_to_image_config.model_name
-    chat_history = ""
-    for chat in conversation_log.get("chat", [])[-4:]:
-        if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
-            chat_history += f"Q: {chat['intent']['query']}\n"
-            chat_history += f"A: {chat['message']}\n"
-        elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
-            chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
-            chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"
-
-    if send_status_func:
-        async for event in send_status_func("**Enhancing the Painting Prompt**"):
-            yield {ChatEvent.STATUS: event}
-    improved_image_prompt = await generate_better_image_prompt(
-        message,
-        chat_history,
-        location_data=location_data,
-        note_references=references,
-        online_results=online_results,
-        model_type=text_to_image_config.model_type,
-        subscribed=subscribed,
-        uploaded_image_url=uploaded_image_url,
-    )
-
-    if send_status_func:
-        async for event in send_status_func(f"**Painting to Imagine**:\n{improved_image_prompt}"):
-            yield {ChatEvent.STATUS: event}
-
-    if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
-        with timer("Generate image with OpenAI", logger):
-            if text_to_image_config.api_key:
-                api_key = text_to_image_config.api_key
-            elif text_to_image_config.openai_config:
-                api_key = text_to_image_config.openai_config.api_key
-            elif state.openai_client:
-                api_key = state.openai_client.api_key
-            auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}
-            try:
-                response = state.openai_client.images.generate(
-                    prompt=improved_image_prompt,
-                    model=text2image_model,
-                    response_format="b64_json",
-                    extra_headers=auth_header,
-                )
-                image = response.data[0].b64_json
-                decoded_image = base64.b64decode(image)
-            except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
-                if "content_policy_violation" in e.message:
-                    logger.error(f"Image Generation blocked by OpenAI: {e}")
-                    status_code = e.status_code  # type: ignore
-                    message = f"Image generation blocked by OpenAI: {e.message}"  # type: ignore
-                    yield image_url or image, status_code, message, intent_type.value
-                    return
-                else:
logger.error(f"Image Generation failed with {e}", exc_info=True) - message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore - status_code = e.status_code # type: ignore - yield image_url or image, status_code, message, intent_type.value - return - - elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI: - with timer("Generate image with Stability AI", logger): - try: - response = requests.post( - f"https://api.stability.ai/v2beta/stable-image/generate/sd3", - headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"}, - files={"none": ""}, - data={ - "prompt": improved_image_prompt, - "model": text2image_model, - "mode": "text-to-image", - "output_format": "png", - "aspect_ratio": "1:1", - }, - ) - decoded_image = response.content - except requests.RequestException as e: - logger.error(f"Image Generation failed with {e}", exc_info=True) - message = f"Image generation failed with Stability AI error: {e}" - status_code = e.status_code # type: ignore - yield image_url or image, status_code, message, intent_type.value - return - - with timer("Convert image to webp", logger): - # Convert png to webp for faster loading - webp_image_bytes = convert_image_to_webp(decoded_image) - - with timer("Upload image to S3", logger): - image_url = upload_image(webp_image_bytes, user.uuid) - if image_url: - intent_type = ImageIntentType.TEXT_TO_IMAGE2 - else: - intent_type = ImageIntentType.TEXT_TO_IMAGE_V3 - image = base64.b64encode(webp_image_bytes).decode("utf-8") - - yield image_url or image, status_code, improved_image_prompt, intent_type.value - - class ApiUserRateLimiter: def __init__(self, requests: int, subscribed_requests: int, window: int, slug: str): self.requests = requests