Improve handling of harmful categorized responses by Gemini
Previously, Khoj would stop mid-way through response generation when
Gemini's safety filters were triggered at their default thresholds. This
was confusing, as it felt like a service error rather than expected behavior.

Going forward Khoj will
- Only block responses to high-confidence harmful content detected by
  Gemini's safety filters, instead of using the default safety settings
- Show an explanatory, conversational response (with the harm category)
  when a response is terminated by Gemini's safety filters, as sketched below
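
A minimal, self-contained sketch of the new behavior described above (the model name, prompt, and API key handling here are illustrative assumptions, not code from this commit):

import google.generativeai as genai
from google.generativeai.types.generation_types import StopCandidateException
from google.generativeai.types.safety_types import HarmBlockThreshold, HarmCategory

genai.configure(api_key="YOUR_GEMINI_API_KEY")  # assumes a valid Gemini API key

# Raise each safety filter from its default threshold to BLOCK_ONLY_HIGH so that
# only high-confidence harmful content stops response generation.
model = genai.GenerativeModel(
    "gemini-1.5-flash",  # illustrative model name
    safety_settings={
        HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
        HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
    },
)

chat = model.start_chat()
try:
    print(chat.send_message("Hello there").text)
except StopCandidateException:
    # Rather than surfacing a service error, Khoj now replies with a
    # conversational message that names the triggered harm category.
    print("Response was blocked by Gemini's safety filters.")
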
debanjum committed Sep 15, 2024
1 parent ec1f87a commit 893ae60
Showing 1 changed file with 106 additions and 7 deletions.
src/khoj/processor/conversation/google/utils.py: 113 changes (106 additions, 7 deletions)
@@ -1,7 +1,18 @@
 import logging
+import random
 from threading import Thread
 
 import google.generativeai as genai
+from google.generativeai.types.answer_types import FinishReason
+from google.generativeai.types.generation_types import (
+    GenerateContentResponse,
+    StopCandidateException,
+)
+from google.generativeai.types.safety_types import (
+    HarmBlockThreshold,
+    HarmCategory,
+    HarmProbability,
+)
 from tenacity import (
     before_sleep_log,
     retry,
@@ -32,14 +43,35 @@ def gemini_completion_with_backoff(
     model_kwargs = model_kwargs or dict()
     model_kwargs["temperature"] = temperature
     model_kwargs["max_output_tokens"] = max_tokens
-    model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
+    model = genai.GenerativeModel(
+        model_name,
+        generation_config=model_kwargs,
+        system_instruction=system_prompt,
+        safety_settings={
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+        },
+    )
 
     formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
-    # all messages up to the last are considered to be part of the chat history
+
+    # Start chat session. All messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])
-    # the last message is considered to be the current prompt
-    aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
-    return aggregated_response.text
+
+    try:
+        # Generate the response. The last message is considered to be the current prompt
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        return aggregated_response.text
+    except StopCandidateException as e:
+        response_message, _ = handle_gemini_response(e.args)
+        # Respond with reason for stopping
+        logger.warning(
+            f"LLM Response Prevented for {model_name}: {response_message}.\n"
+            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
+        )
+        return response_message
 
 
 @retry(
@@ -79,15 +111,82 @@ def gemini_llm_thread(
         model_kwargs["temperature"] = temperature
         model_kwargs["max_output_tokens"] = max_tokens
         model_kwargs["stop_sequences"] = ["Notes:\n["]
-        model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
+        model = genai.GenerativeModel(
+            model_name,
+            generation_config=model_kwargs,
+            system_instruction=system_prompt,
+            safety_settings={
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            },
+        )
 
         formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
         # all messages up to the last are considered to be part of the chat history
         chat_session = model.start_chat(history=formatted_messages[0:-1])
         # the last message is considered to be the current prompt
         for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
-            g.send(chunk.text)
+            message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
+            message = message or chunk.text
+            g.send(message)
+            if stopped:
+                raise StopCandidateException(message)
+    except StopCandidateException as e:
+        logger.warning(
+            f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
+            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
+        )
     except Exception as e:
         logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
     finally:
         g.close()
+
+
+def handle_gemini_response(candidates, prompt_feedback=None):
+    """Check if Gemini response was blocked and return an explanatory error message."""
+    # Check if the response was blocked due to safety concerns with the prompt
+    if len(candidates) == 0 and prompt_feedback:
+        message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
+        stopped = True
+    # Check if the response was blocked due to safety concerns with the generated content
+    elif candidates[0].finish_reason == FinishReason.SAFETY:
+        message = generate_safety_response(candidates[0].safety_ratings)
+        stopped = True
+    # Check if the response was stopped due to reaching maximum token limit or other reasons
+    elif candidates[0].finish_reason != FinishReason.STOP:
+        message = f"\nI can't talk further about that because of **{candidates[0].finish_reason.name} issue.**"
+        stopped = True
+    # Otherwise, the response is valid and can be used
+    else:
+        message = None
+        stopped = False
+    return message, stopped
+
+
+def generate_safety_response(safety_ratings):
+    """Generate a conversational response based on the safety ratings of the response."""
+    # Get the safety rating with the highest probability
+    max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
+    # Remove the "HARM_CATEGORY_" prefix and title case the category name
+    max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
+    # Add a bit of variety to the discomfort level based on the safety rating probability
+    discomfort_level = {
+        HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
+        HarmProbability.LOW: "a bit ",
+        HarmProbability.MEDIUM: "moderately ",
+        HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
+    }[max_safety_rating.probability]
+    # Generate a response using a random response template
+    safety_response_choice = random.choice(
+        [
+            "\nUmm, I'd rather not to respond to that. The conversation has some probability of going into **{category}** territory.",
+            "\nI'd prefer not to talk about **{category}** related topics. It makes me {discomfort_level}uncomfortable.",
+            "\nI feel {discomfort_level}squeamish talking about **{category}** related stuff! Can we talk about something less controversial?",
+            "\nThat sounds {discomfort_level}outside the [Overtone Window](https://en.wikipedia.org/wiki/Overton_window) of acceptable conversation. Should we stick to something less {category} related?",
+        ]
+    )
+    return safety_response_choice.format(
+        category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
+    )
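
For a rough sense of the conversational messages this produces, here is a hypothetical use of the new helper; the SimpleNamespace object only mimics the shape of a Gemini safety rating entry and is not part of the commit:

from types import SimpleNamespace

from google.generativeai.types.safety_types import HarmCategory, HarmProbability

from khoj.processor.conversation.google.utils import generate_safety_response

# Fake safety rating shaped like the entries attached to a blocked Gemini candidate.
rating = SimpleNamespace(category=HarmCategory.HARM_CATEGORY_HARASSMENT, probability=HarmProbability.MEDIUM)

# Templates are chosen at random; one possible output is:
# "I'd prefer not to talk about **Harassment** related topics. It makes me moderately uncomfortable."
print(generate_safety_response([rating]))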
