Skip to content

Commit

Permalink
Extract image generation code into new image processor for modularity
Browse files Browse the repository at this point in the history
  • Loading branch information
debanjum committed Sep 13, 2024
1 parent 84051d7 commit 75d3b34
Show file tree
Hide file tree
Showing 3 changed files with 214 additions and 186 deletions.
212 changes: 212 additions & 0 deletions src/khoj/processor/image/generate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
import base64
import io
import logging
import time
from typing import Any, Callable, Dict, List, Optional

import openai
import requests

from khoj.database.adapters import ConversationAdapters
from khoj.database.models import KhojUser, TextToImageModelConfig
from khoj.routers.helpers import ChatEvent, generate_better_image_prompt
from khoj.routers.storage import upload_image
from khoj.utils import state
from khoj.utils.helpers import ImageIntentType, convert_image_to_webp, timer
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


async def text_to_image(
message: str,
user: KhojUser,
conversation_log: dict,
location_data: LocationData,
references: List[Dict[str, Any]],
online_results: Dict[str, Any],
subscribed: bool = False,
send_status_func: Optional[Callable] = None,
uploaded_image_url: Optional[str] = None,
):
status_code = 200
image = None
image_url = None
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3

text_to_image_config = await ConversationAdapters.aget_user_text_to_image_model(user)
if not text_to_image_config:
# If the user has not configured a text to image model, return an unsupported on server error
status_code = 501
message = "Failed to generate image. Setup image generation on the server."
yield image_url or image, status_code, message, intent_type.value
return

text2image_model = text_to_image_config.model_name
chat_history = ""
for chat in conversation_log.get("chat", [])[-4:]:
if chat["by"] == "khoj" and chat["intent"].get("type") in ["remember", "reminder"]:
chat_history += f"Q: {chat['intent']['query']}\n"
chat_history += f"A: {chat['message']}\n"
elif chat["by"] == "khoj" and "text-to-image" in chat["intent"].get("type"):
chat_history += f"Q: Prompt: {chat['intent']['query']}\n"
chat_history += f"A: Improved Prompt: {chat['intent']['inferred-queries'][0]}\n"

if send_status_func:
async for event in send_status_func("**Enhancing the Painting Prompt**"):
yield {ChatEvent.STATUS: event}

# Generate a better image prompt
# Use the user's message, chat history, and other context
image_prompt = await generate_better_image_prompt(
message,
chat_history,
location_data=location_data,
note_references=references,
online_results=online_results,
model_type=text_to_image_config.model_type,
subscribed=subscribed,
uploaded_image_url=uploaded_image_url,
)

if send_status_func:
async for event in send_status_func(f"**Painting to Imagine**:\n{image_prompt}"):
yield {ChatEvent.STATUS: event}

# Generate image using the configured model and API
with timer(f"Generate image with {text_to_image_config.model_type}", logger):
try:
if text_to_image_config.model_type == TextToImageModelConfig.ModelType.OPENAI:
webp_image_bytes = generate_image_with_openai(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.STABILITYAI:
webp_image_bytes = generate_image_with_stability(image_prompt, text_to_image_config, text2image_model)
elif text_to_image_config.model_type == TextToImageModelConfig.ModelType.REPLICATE:
webp_image_bytes = generate_image_with_replicate(image_prompt, text_to_image_config, text2image_model)
except openai.OpenAIError or openai.BadRequestError or openai.APIConnectionError as e:
if "content_policy_violation" in e.message:
logger.error(f"Image Generation blocked by OpenAI: {e}")
status_code = e.status_code # type: ignore
message = f"Image generation blocked by OpenAI: {e.message}" # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
else:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation failed with OpenAI error: {e.message}" # type: ignore
status_code = e.status_code # type: ignore
yield image_url or image, status_code, message, intent_type.value
return
except requests.RequestException as e:
logger.error(f"Image Generation failed with {e}", exc_info=True)
message = f"Image generation using {text2image_model} via {text_to_image_config.model_type} failed with error: {e}"
status_code = 502
yield image_url or image, status_code, message, intent_type.value
return

# Decide how to store the generated image
with timer("Upload image to S3", logger):
image_url = upload_image(webp_image_bytes, user.uuid)
if image_url:
intent_type = ImageIntentType.TEXT_TO_IMAGE2
else:
intent_type = ImageIntentType.TEXT_TO_IMAGE_V3
image = base64.b64encode(webp_image_bytes).decode("utf-8")

yield image_url or image, status_code, image_prompt, intent_type.value


def generate_image_with_openai(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using OpenAI API"

# Get the API key from the user's configuration
if text_to_image_config.api_key:
api_key = text_to_image_config.api_key
elif text_to_image_config.openai_config:
api_key = text_to_image_config.openai_config.api_key
elif state.openai_client:
api_key = state.openai_client.api_key
auth_header = {"Authorization": f"Bearer {api_key}"} if api_key else {}

# Generate image using OpenAI API
OPENAI_IMAGE_GEN_STYLE = "vivid"
response = state.openai_client.images.generate(
prompt=improved_image_prompt,
model=text2image_model,
style=OPENAI_IMAGE_GEN_STYLE,
response_format="b64_json",
extra_headers=auth_header,
)

# Extract the base64 image from the response
image = response.data[0].b64_json
# Decode base64 png and convert it to webp for faster loading
return convert_image_to_webp(base64.b64decode(image))


def generate_image_with_stability(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using Stability AI"

# Call Stability AI API to generate image
response = requests.post(
f"https://api.stability.ai/v2beta/stable-image/generate/sd3",
headers={"authorization": f"Bearer {text_to_image_config.api_key}", "accept": "image/*"},
files={"none": ""},
data={
"prompt": improved_image_prompt,
"model": text2image_model,
"mode": "text-to-image",
"output_format": "png",
"aspect_ratio": "1:1",
},
)
# Convert png to webp for faster loading
return convert_image_to_webp(response.content)


def generate_image_with_replicate(
improved_image_prompt: str, text_to_image_config: TextToImageModelConfig, text2image_model: str
):
"Generate image using Replicate API"

# Create image generation task on Replicate
replicate_create_prediction_url = f"https://api.replicate.com/v1/models/{text2image_model}/predictions"
headers = {
"Authorization": f"Bearer {text_to_image_config.api_key}",
"Content-Type": "application/json",
}
json = {
"input": {
"prompt": improved_image_prompt,
"num_outputs": 1,
"aspect_ratio": "1:1",
"output_format": "webp",
"output_quality": 100,
}
}
create_prediction = requests.post(replicate_create_prediction_url, headers=headers, json=json).json()

# Get status of image generation task
get_prediction_url = create_prediction["urls"]["get"]
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count = 1

# Poll the image generation task for completion status
while status not in ["succeeded", "failed", "canceled"] and retry_count < 20:
time.sleep(2)
get_prediction = requests.get(get_prediction_url, headers=headers).json()
status = get_prediction["status"]
retry_count += 1

# Raise exception if the image generation task fails
if status != "succeeded":
if retry_count >= 10:
raise requests.RequestException("Image generation timed out")
raise requests.RequestException(f"Image generation failed with status: {status}")

# Get the generated image
image_url = get_prediction["output"][0] if isinstance(get_prediction["output"], list) else get_prediction["output"]
return io.BytesIO(requests.get(image_url).content).getvalue()
2 changes: 1 addition & 1 deletion src/khoj/routers/api_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from khoj.database.models import KhojUser
from khoj.processor.conversation.prompts import help_message, no_entries_found
from khoj.processor.conversation.utils import save_to_conversation_log
from khoj.processor.image.generate import text_to_image
from khoj.processor.speech.text_to_speech import generate_text_to_speech
from khoj.processor.tools.online_search import read_webpages, search_online
from khoj.routers.api import extract_references_and_questions
Expand All @@ -44,7 +45,6 @@
is_query_empty,
is_ready_to_chat,
read_chat_stream,
text_to_image,
update_telemetry_state,
validate_conversation_config,
)
Expand Down
Loading

0 comments on commit 75d3b34

Please sign in to comment.