Support Google's Gemini model series (#902)
* Add functions to chat with Google's gemini model series
* Gracefully close thread when there's an exception in the gemini llm thread
* Use enums for verifying the chat model option type
* Add a migration to add the gemini chat model type to the db model
* Fix chat model selection verification and math prompt tuning
* Fix extract questions method with gemini. Enforce json response in extract questions.
* Add standard stop sequence for Gemini chat response generation

---------

Co-authored-by: sabaimran <[email protected]>
Co-authored-by: Debanjum Singh Solanky <[email protected]>
3 people authored Sep 13, 2024
1 parent 42b727e commit 9570933
Showing 11 changed files with 438 additions and 16 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -87,7 +87,8 @@ dependencies = [
     "cron-descriptor == 1.4.3",
     "django_apscheduler == 0.6.2",
     "anthropic == 0.26.1",
-    "docx2txt == 0.8"
+    "docx2txt == 0.8",
+    "google-generativeai == 0.7.2"
 ]
 dynamic = ["version"]

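Note: to sanity-check the new dependency locally, a quick import works (assuming the SDK exposes __version__, as recent releases do; this check is illustrative, not part of the commit):

import google.generativeai as genai

print(genai.__version__)  # expected to match the "0.7.2" pin above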
9 changes: 7 additions & 2 deletions src/khoj/database/adapters/__init__.py
@@ -973,7 +973,7 @@ def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
         if conversation_config is None:
             conversation_config = ConversationAdapters.get_default_conversation_config()

-        if conversation_config.model_type == "offline":
+        if conversation_config.model_type == ChatModelOptions.ModelType.OFFLINE:
             if state.offline_chat_processor_config is None or state.offline_chat_processor_config.loaded_model is None:
                 chat_model = conversation_config.chat_model
                 max_tokens = conversation_config.max_prompt_size
@@ -982,7 +982,12 @@ def get_valid_conversation_config(user: KhojUser, conversation: Conversation):
             return conversation_config

         if (
-            conversation_config.model_type == "openai" or conversation_config.model_type == "anthropic"
+            conversation_config.model_type
+            in [
+                ChatModelOptions.ModelType.ANTHROPIC,
+                ChatModelOptions.ModelType.OPENAI,
+                ChatModelOptions.ModelType.GOOGLE,
+            ]
         ) and conversation_config.openai_config:
             return conversation_config

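Note: ChatModelOptions.ModelType extends Django's models.TextChoices, whose members are str subclasses, so the enum membership check above stays compatible with rows that store plain string values. A minimal illustration:

from khoj.database.models import ChatModelOptions

# TextChoices members compare equal to their raw string values
assert ChatModelOptions.ModelType.GOOGLE == "google"
assert "google" in [ChatModelOptions.ModelType.GOOGLE, ChatModelOptions.ModelType.OPENAI]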
26 changes: 26 additions & 0 deletions src/khoj/database/migrations/0061_alter_chatmodeloptions_model_type.py
@@ -0,0 +1,26 @@
# Generated by Django 5.0.8 on 2024-09-12 20:06

from django.db import migrations, models


class Migration(migrations.Migration):
    dependencies = [
        ("database", "0060_merge_20240905_1828"),
    ]

    operations = [
        migrations.AlterField(
            model_name="chatmodeloptions",
            name="model_type",
            field=models.CharField(
                choices=[
                    ("openai", "Openai"),
                    ("offline", "Offline"),
                    ("anthropic", "Anthropic"),
                    ("google", "Google"),
                ],
                default="offline",
                max_length=200,
            ),
        ),
    ]
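Note: once this migration is applied, a Gemini model can be registered like the other providers. A hypothetical sketch (the field values are illustrative, not part of this commit):

from khoj.database.models import ChatModelOptions

# Illustrative only: store a Gemini chat model entry
ChatModelOptions.objects.create(
    chat_model="gemini-1.5-flash",
    model_type=ChatModelOptions.ModelType.GOOGLE,
)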
1 change: 1 addition & 0 deletions src/khoj/database/models/__init__.py
@@ -87,6 +87,7 @@ class ModelType(models.TextChoices):
         OPENAI = "openai"
         OFFLINE = "offline"
         ANTHROPIC = "anthropic"
+        GOOGLE = "google"

     max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
     subscribed_max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
Empty file.
221 changes: 221 additions & 0 deletions src/khoj/processor/conversation/google/gemini_chat.py
@@ -0,0 +1,221 @@
import json
import logging
import re
from datetime import datetime, timedelta
from typing import Dict, Optional

from langchain.schema import ChatMessage

from khoj.database.models import Agent, KhojUser
from khoj.processor.conversation import prompts
from khoj.processor.conversation.google.utils import (
    gemini_chat_completion_with_backoff,
    gemini_completion_with_backoff,
)
from khoj.processor.conversation.utils import generate_chatml_messages_with_context
from khoj.utils.helpers import ConversationCommand, is_none_or_empty
from khoj.utils.rawconfig import LocationData

logger = logging.getLogger(__name__)


def extract_questions_gemini(
    text,
    model: Optional[str] = "gemini-1.5-flash",
    conversation_log={},
    api_key=None,
    temperature=0,
    max_tokens=None,
    location_data: LocationData = None,
    user: KhojUser = None,
):
    """
    Infer search queries to retrieve relevant notes to answer user query
    """
    # Assemble the user's location and name for prompt personalization
    location = f"{location_data.city}, {location_data.region}, {location_data.country}" if location_data else "Unknown"
    username = prompts.user_name.format(name=user.get_full_name()) if user and user.get_full_name() else ""

    # Extract Past User Message and Inferred Questions from Conversation Log
    chat_history = "".join(
        [
            f'User: {chat["intent"]["query"]}\nAssistant: {{"queries": {chat["intent"].get("inferred-queries") or list([chat["intent"]["query"]])}}}\nA: {chat["message"]}\n\n'
            for chat in conversation_log.get("chat", [])[-4:]
            if chat["by"] == "khoj" and "text-to-image" not in chat["intent"].get("type")
        ]
    )

    # Get dates relative to today for prompt creation
    today = datetime.today()
    current_new_year = today.replace(month=1, day=1)
    last_new_year = current_new_year.replace(year=today.year - 1)

    system_prompt = prompts.extract_questions_anthropic_system_prompt.format(
        current_date=today.strftime("%Y-%m-%d"),
        day_of_week=today.strftime("%A"),
        current_month=today.strftime("%Y-%m"),
        last_new_year=last_new_year.strftime("%Y"),
        last_new_year_date=last_new_year.strftime("%Y-%m-%d"),
        current_new_year_date=current_new_year.strftime("%Y-%m-%d"),
        yesterday_date=(today - timedelta(days=1)).strftime("%Y-%m-%d"),
        location=location,
        username=username,
    )

    prompt = prompts.extract_questions_anthropic_user_message.format(
        chat_history=chat_history,
        text=text,
    )

    messages = [ChatMessage(content=prompt, role="user")]

    model_kwargs = {"response_mime_type": "application/json"}

    response = gemini_completion_with_backoff(
        messages=messages,
        system_prompt=system_prompt,
        model_name=model,
        temperature=temperature,
        api_key=api_key,
        max_tokens=max_tokens,
        model_kwargs=model_kwargs,
    )

    # Extract, Clean Message from Gemini's Response
    try:
        response = response.strip()
        match = re.search(r"\{.*?\}", response)
        if match:
            response = match.group()
        response = json.loads(response)
        response = [q.strip() for q in response["queries"] if q.strip()]
        if not isinstance(response, list) or not response:
            logger.error(f"Invalid response for constructing subqueries: {response}")
            return [text]
        return response
    except:
        logger.warning(f"Gemini returned invalid JSON. Falling back to using user message as search query.\n{response}")
        questions = [text]
    logger.debug(f"Extracted Questions by Gemini: {questions}")
    return questions


def gemini_send_message_to_model(messages, api_key, model, response_type="text"):
    """
    Send message to model
    """
    system_prompt = None
    if len(messages) == 1:
        messages[0].role = "user"
    else:
        system_prompt = ""
        for message in messages.copy():
            if message.role == "system":
                system_prompt += message.content
                messages.remove(message)

    model_kwargs = {}
    if response_type == "json_object":
        model_kwargs["response_mime_type"] = "application/json"

    # Get Response from Gemini
    return gemini_completion_with_backoff(
        messages=messages, system_prompt=system_prompt, model_name=model, api_key=api_key, model_kwargs=model_kwargs
    )


def converse_gemini(
    references,
    user_query,
    online_results: Optional[Dict[str, Dict]] = None,
    conversation_log={},
    model: Optional[str] = "gemini-1.5-flash",
    api_key: Optional[str] = None,
    temperature: float = 0.2,
    completion_func=None,
    conversation_commands=[ConversationCommand.Default],
    max_prompt_size=None,
    tokenizer_name=None,
    location_data: LocationData = None,
    user_name: str = None,
    agent: Agent = None,
):
    """
    Converse with user using Google's Gemini
    """
    # Initialize Variables
    current_date = datetime.now()
    compiled_references = "\n\n".join({f"# {item}" for item in references})

    conversation_primer = prompts.query_prompt.format(query=user_query)

    if agent and agent.personality:
        system_prompt = prompts.custom_personality.format(
            name=agent.name,
            bio=agent.personality,
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )
    else:
        system_prompt = prompts.personality.format(
            current_date=current_date.strftime("%Y-%m-%d"),
            day_of_week=current_date.strftime("%A"),
        )

    if location_data:
        location = f"{location_data.city}, {location_data.region}, {location_data.country}"
        location_prompt = prompts.user_location.format(location=location)
        system_prompt = f"{system_prompt}\n{location_prompt}"

    if user_name:
        user_name_prompt = prompts.user_name.format(name=user_name)
        system_prompt = f"{system_prompt}\n{user_name_prompt}"

    # Get Conversation Primer appropriate to Conversation Type
    if conversation_commands == [ConversationCommand.Notes] and is_none_or_empty(compiled_references):
        completion_func(chat_response=prompts.no_notes_found.format())
        return iter([prompts.no_notes_found.format()])
    elif conversation_commands == [ConversationCommand.Online] and is_none_or_empty(online_results):
        completion_func(chat_response=prompts.no_online_results_found.format())
        return iter([prompts.no_online_results_found.format()])

    if ConversationCommand.Online in conversation_commands or ConversationCommand.Webpage in conversation_commands:
        conversation_primer = (
            f"{prompts.online_search_conversation.format(online_results=str(online_results))}\n{conversation_primer}"
        )
    if not is_none_or_empty(compiled_references):
        conversation_primer = f"{prompts.notes_conversation.format(query=user_query, references=compiled_references)}\n\n{conversation_primer}"

    # Setup Prompt with Primer or Conversation History
    messages = generate_chatml_messages_with_context(
        conversation_primer,
        conversation_log=conversation_log,
        model_name=model,
        max_prompt_size=max_prompt_size,
        tokenizer_name=tokenizer_name,
    )

    # Gemini expects the assistant role to be named "model"
    for message in messages:
        if message.role == "assistant":
            message.role = "model"

    # Fold any system messages into the single system prompt Gemini accepts
    for message in messages.copy():
        if message.role == "system":
            system_prompt += message.content
            messages.remove(message)

    truncated_messages = "\n".join({f"{message.content[:40]}..." for message in messages})
    logger.debug(f"Conversation Context for Gemini: {truncated_messages}")

    # Get Response from Google AI
    return gemini_chat_completion_with_backoff(
        messages=messages,
        compiled_references=references,
        online_results=online_results,
        model_name=model,
        temperature=temperature,
        api_key=api_key,
        system_prompt=system_prompt,
        completion_func=completion_func,
        max_prompt_size=max_prompt_size,
    )
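Note: a hypothetical usage sketch for the two entry points above (the API key, query, and messages are illustrative):

from langchain.schema import ChatMessage

# Infer search queries from a user message
queries = extract_questions_gemini(
    "What did I write about my trip last winter?",
    model="gemini-1.5-flash",
    api_key="<GEMINI_API_KEY>",
)

# Request a strict JSON reply, as the extract-questions flow does internally
raw_json = gemini_send_message_to_model(
    messages=[ChatMessage(content='Reply with {"ok": true}', role="user")],
    api_key="<GEMINI_API_KEY>",
    model="gemini-1.5-flash",
    response_type="json_object",
)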
93 changes: 93 additions & 0 deletions src/khoj/processor/conversation/google/utils.py
@@ -0,0 +1,93 @@
import logging
from threading import Thread

import google.generativeai as genai
from tenacity import (
    before_sleep_log,
    retry,
    stop_after_attempt,
    wait_exponential,
    wait_random_exponential,
)

from khoj.processor.conversation.utils import ThreadedGenerator

logger = logging.getLogger(__name__)


DEFAULT_MAX_TOKENS_GEMINI = 8192


@retry(
    wait=wait_random_exponential(min=1, max=10),
    stop=stop_after_attempt(2),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def gemini_completion_with_backoff(
    messages, system_prompt, model_name, temperature=0, api_key=None, model_kwargs=None, max_tokens=None
) -> str:
    genai.configure(api_key=api_key)
    max_tokens = max_tokens or DEFAULT_MAX_TOKENS_GEMINI
    model_kwargs = model_kwargs or dict()
    model_kwargs["temperature"] = temperature
    model_kwargs["max_output_tokens"] = max_tokens
    model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)

    formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
    # All messages up to the last are considered part of the chat history
    chat_session = model.start_chat(history=formatted_messages[0:-1])
    # The last message is considered the current prompt
    aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
    return aggregated_response.text


@retry(
    wait=wait_exponential(multiplier=1, min=4, max=10),
    stop=stop_after_attempt(2),
    before_sleep=before_sleep_log(logger, logging.DEBUG),
    reraise=True,
)
def gemini_chat_completion_with_backoff(
    messages,
    compiled_references,
    online_results,
    model_name,
    temperature,
    api_key,
    system_prompt,
    max_prompt_size=None,
    completion_func=None,
    model_kwargs=None,
):
    g = ThreadedGenerator(compiled_references, online_results, completion_func=completion_func)
    t = Thread(
        target=gemini_llm_thread,
        args=(g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size, model_kwargs),
    )
    t.start()
    return g


def gemini_llm_thread(
    g, messages, system_prompt, model_name, temperature, api_key, max_prompt_size=None, model_kwargs=None
):
    try:
        genai.configure(api_key=api_key)
        max_tokens = max_prompt_size or DEFAULT_MAX_TOKENS_GEMINI
        model_kwargs = model_kwargs or dict()
        model_kwargs["temperature"] = temperature
        model_kwargs["max_output_tokens"] = max_tokens
        model_kwargs["stop_sequences"] = ["Notes:\n["]
        model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)

        formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
        # All messages up to the last are considered part of the chat history
        chat_session = model.start_chat(history=formatted_messages[0:-1])
        # The last message is considered the current prompt
        for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
            g.send(chunk.text)
    except Exception as e:
        logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
    finally:
        # Always close the generator so the consumer isn't left waiting if the thread errors out
        g.close()
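Note: gemini_chat_completion_with_backoff returns the ThreadedGenerator immediately while gemini_llm_thread streams chunks into it from the worker thread, so callers can iterate over the response as it arrives. A minimal consumption sketch (assuming the generator is iterated the same way as the other providers' streaming responses; values are illustrative):

from langchain.schema import ChatMessage

response_stream = gemini_chat_completion_with_backoff(
    messages=[ChatMessage(content="Hello!", role="user")],
    compiled_references=[],
    online_results={},
    model_name="gemini-1.5-flash",
    temperature=0.2,
    api_key="<GEMINI_API_KEY>",
    system_prompt="You are a helpful assistant.",
)
for chunk in response_stream:
    print(chunk, end="")  # chunks arrive as the model generates them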
4 changes: 2 additions & 2 deletions src/khoj/processor/conversation/prompts.py
@@ -13,8 +13,8 @@
 - You *CAN* generate images, look-up real-time information from the internet, set reminders and answer questions based on the user's notes.
 - Say "I don't know" or "I don't understand" if you don't know what to say or if you don't know the answer to a question.
 - Make sure to use the specific LaTeX math mode delimiters for your response. LaTex math mode specific delimiters as following
-  - inline math mode : `\\(` and `\\)`
-  - display math mode: insert linebreak after opening `$$`, `\\[` and before closing `$$`, `\\]`
+  - inline math mode : \\( and \\)
+  - display math mode: insert linebreak after opening $$, \\[ and before closing $$, \\]
 - Ask crisp follow-up questions to get additional context, when the answer cannot be inferred from the provided notes or past conversations.
 - Sometimes the user will share personal information that needs to be remembered, like an account ID or a residential address. These can be acknowledged with a simple "Got it" or "Okay".
 - Provide inline references to quotes from the user's notes or any web pages you refer to in your responses in markdown format. For example, "The farmer had ten sheep. [1](https://example.com)". *ALWAYS CITE YOUR SOURCES AND PROVIDE REFERENCES*. Add them inline to directly support your claim.
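Note: the prompt tuning above drops the backticks around the delimiters so models emit the LaTeX markers directly rather than as code-quoted text. Rendered, the requested format looks like this (the formulas are illustrative):

% inline math mode
\( E = mc^2 \)

% display math mode: linebreak after the opening $$ and before the closing $$
$$
\int_0^1 x^2 \, dx = \frac{1}{3}
$$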
