From 893ae60a6aff93e470798b95cc684d1dcbde163f Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Sun, 15 Sep 2024 01:10:35 -0700
Subject: [PATCH] Improve handling of harmful categorized responses by Gemini

Previously Khoj would stop in the middle of response generation when the
safety filters got triggered at default thresholds. This was confusing as
it felt like a service error, not expected behavior.

Going forward Khoj will
- Only block responding to high confidence harmful content detected by
  Gemini's safety filters instead of using the default safety settings
- Show an explanatory, conversational response (w/ harm category) when
  response is terminated due to Gemini's safety filters
---
 .../processor/conversation/google/utils.py   | 113 ++++++++++++++++--
 1 file changed, 106 insertions(+), 7 deletions(-)

diff --git a/src/khoj/processor/conversation/google/utils.py b/src/khoj/processor/conversation/google/utils.py
index 4ddf5e2c2..63b8b6100 100644
--- a/src/khoj/processor/conversation/google/utils.py
+++ b/src/khoj/processor/conversation/google/utils.py
@@ -1,7 +1,18 @@
 import logging
+import random
 from threading import Thread
 
 import google.generativeai as genai
+from google.generativeai.types.answer_types import FinishReason
+from google.generativeai.types.generation_types import (
+    GenerateContentResponse,
+    StopCandidateException,
+)
+from google.generativeai.types.safety_types import (
+    HarmBlockThreshold,
+    HarmCategory,
+    HarmProbability,
+)
 from tenacity import (
     before_sleep_log,
     retry,
@@ -32,14 +43,35 @@ def gemini_completion_with_backoff(
     model_kwargs = model_kwargs or dict()
     model_kwargs["temperature"] = temperature
     model_kwargs["max_output_tokens"] = max_tokens
-    model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
+    model = genai.GenerativeModel(
+        model_name,
+        generation_config=model_kwargs,
+        system_instruction=system_prompt,
+        safety_settings={
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+        },
+    )
 
     formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
-    # all messages up to the last are considered to be part of the chat history
+
+    # Start chat session. All messages up to the last are considered to be part of the chat history
     chat_session = model.start_chat(history=formatted_messages[0:-1])
-    # the last message is considered to be the current prompt
-    aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
-    return aggregated_response.text
+
+    try:
+        # Generate the response. The last message is considered to be the current prompt
+        aggregated_response = chat_session.send_message(formatted_messages[-1]["parts"][0])
+        return aggregated_response.text
+    except StopCandidateException as e:
+        response_message, _ = handle_gemini_response(e.args)
+        # Respond with reason for stopping
+        logger.warning(
+            f"LLM Response Prevented for {model_name}: {response_message}.\n"
+            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
+        )
+        return response_message
 
 
 @retry(
@@ -79,15 +111,82 @@ def gemini_llm_thread(
         model_kwargs["temperature"] = temperature
         model_kwargs["max_output_tokens"] = max_tokens
         model_kwargs["stop_sequences"] = ["Notes:\n["]
-        model = genai.GenerativeModel(model_name, generation_config=model_kwargs, system_instruction=system_prompt)
+        model = genai.GenerativeModel(
+            model_name,
+            generation_config=model_kwargs,
+            system_instruction=system_prompt,
+            safety_settings={
+                HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+                HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_ONLY_HIGH,
+            },
+        )
 
         formatted_messages = [{"role": message.role, "parts": [message.content]} for message in messages]
         # all messages up to the last are considered to be part of the chat history
         chat_session = model.start_chat(history=formatted_messages[0:-1])
         # the last message is considered to be the current prompt
         for chunk in chat_session.send_message(formatted_messages[-1]["parts"][0], stream=True):
-            g.send(chunk.text)
+            message, stopped = handle_gemini_response(chunk.candidates, chunk.prompt_feedback)
+            message = message or chunk.text
+            g.send(message)
+            if stopped:
+                raise StopCandidateException(message)
+    except StopCandidateException as e:
+        logger.warning(
+            f"LLM Response Prevented for {model_name}: {e.args[0]}.\n"
+            + f"Last Message by {messages[-1].role}: {messages[-1].content}"
+        )
     except Exception as e:
         logger.error(f"Error in gemini_llm_thread: {e}", exc_info=True)
     finally:
         g.close()
+
+
+def handle_gemini_response(candidates, prompt_feedback=None):
+    """Check if Gemini response was blocked and return an explanatory error message."""
+    # Check if the response was blocked due to safety concerns with the prompt
+    if len(candidates) == 0 and prompt_feedback:
+        message = f"\nI'd prefer to not respond to that due to **{prompt_feedback.block_reason.name}** issues with your query."
+        stopped = True
+    # Check if the response was blocked due to safety concerns with the generated content
+    elif candidates[0].finish_reason == FinishReason.SAFETY:
+        message = generate_safety_response(candidates[0].safety_ratings)
+        stopped = True
+    # Check if the response was stopped due to reaching maximum token limit or other reasons
+    elif candidates[0].finish_reason != FinishReason.STOP:
+        message = f"\nI can't talk further about that because of a **{candidates[0].finish_reason.name}** issue."
+        stopped = True
+    # Otherwise, the response is valid and can be used
+    else:
+        message = None
+        stopped = False
+    return message, stopped
+
+
+def generate_safety_response(safety_ratings):
+    """Generate a conversational response based on the safety ratings of the response."""
+    # Get the safety rating with the highest probability
+    max_safety_rating = sorted(safety_ratings, key=lambda x: x.probability, reverse=True)[0]
+    # Remove the "HARM_CATEGORY_" prefix and title case the category name
+    max_safety_category = " ".join(max_safety_rating.category.name.split("_")[2:]).title()
+    # Add a bit of variety to the discomfort level based on the safety rating probability
+    discomfort_level = {
+        HarmProbability.HARM_PROBABILITY_UNSPECIFIED: " ",
+        HarmProbability.LOW: "a bit ",
+        HarmProbability.MEDIUM: "moderately ",
+        HarmProbability.HIGH: random.choice(["very ", "quite ", "fairly "]),
+    }[max_safety_rating.probability]
+    # Generate a response using a random response template
+    safety_response_choice = random.choice(
+        [
+            "\nUmm, I'd rather not respond to that. The conversation has some probability of going into **{category}** territory.",
+            "\nI'd prefer not to talk about **{category}** related topics. It makes me {discomfort_level}uncomfortable.",
+            "\nI feel {discomfort_level}squeamish talking about **{category}** related stuff! Can we talk about something less controversial?",
+            "\nThat sounds {discomfort_level}outside the [Overton Window](https://en.wikipedia.org/wiki/Overton_window) of acceptable conversation. Should we stick to something less {category} related?",
+        ]
+    )
+    return safety_response_choice.format(
+        category=max_safety_category, probability=max_safety_rating.probability.name, discomfort_level=discomfort_level
+    )
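
Reviewer note (not part of the patch): a minimal sketch of how the new handle_gemini_response helper maps blocked candidates to conversational messages. The SimpleNamespace objects are hypothetical stubs that only mimic the attributes the helper reads (finish_reason, safety_ratings[*].category, safety_ratings[*].probability), and the import assumes the patched module is available as khoj.processor.conversation.google.utils with google-generativeai installed.

    # Illustrative sketch only; stand-in candidate objects, not real Gemini response classes.
    from types import SimpleNamespace

    from google.generativeai.types.answer_types import FinishReason
    from google.generativeai.types.safety_types import HarmCategory, HarmProbability

    from khoj.processor.conversation.google.utils import handle_gemini_response

    # A candidate terminated by the safety filter, rated high-probability harassment
    blocked = SimpleNamespace(
        finish_reason=FinishReason.SAFETY,
        safety_ratings=[
            SimpleNamespace(category=HarmCategory.HARM_CATEGORY_HARASSMENT, probability=HarmProbability.HIGH)
        ],
    )
    message, stopped = handle_gemini_response([blocked])
    print(stopped)  # True
    print(message)  # one of the conversational refusal templates, naming "Harassment"

    # A candidate that finished normally is passed through untouched
    ok = SimpleNamespace(finish_reason=FinishReason.STOP, safety_ratings=[])
    message, stopped = handle_gemini_response([ok])
    print(message, stopped)  # None False

Passing an empty candidate list together with a prompt_feedback object exercises the prompt-blocked branch in the same way.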