From 272eae5d66824cc243bd177e44273b4906c02ec4 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Thu, 12 Sep 2024 15:31:11 -0700
Subject: [PATCH 1/3] Add support for the newly released OpenAI O1 model series for preview

The O1 series doesn't currently seem to support streaming, stop words, temperature, or response_format.
---
 .../processor/conversation/openai/utils.py | 45 ++++++++++++++-----
 src/khoj/processor/conversation/utils.py   | 12 +++--
 2 files changed, 43 insertions(+), 14 deletions(-)

diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 1a42113ee..878dbb9c8 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -45,15 +45,28 @@ def completion_with_backoff(
         openai_clients[client_key] = client

     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+    stream = True
+
+    # Update request parameters for compatibility with the o1 model series
+    # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+    if model.startswith("o1"):
+        stream = False
+        temperature = 1
+        model_kwargs.pop("stop", None)
+        model_kwargs.pop("response_format", None)

     chat = client.chat.completions.create(
-        stream=True,
+        stream=stream,
         messages=formatted_messages,  # type: ignore
         model=model,  # type: ignore
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
     )
+
+    if not stream:
+        return chat.choices[0].message.content
+
     aggregated_response = ""
     for chunk in chat:
         if len(chunk.choices) == 0:
@@ -112,9 +125,18 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
         client: openai.OpenAI = openai_clients[client_key]

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+        stream = True
+
+        # Update request parameters for compatibility with the o1 model series
+        # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+        if model_name.startswith("o1"):
+            stream = False
+            temperature = 1
+            model_kwargs.pop("stop", None)
+            model_kwargs.pop("response_format", None)

         chat = client.chat.completions.create(
-            stream=True,
+            stream=stream,
             messages=formatted_messages,
             model=model_name,  # type: ignore
             temperature=temperature,
@@ -122,14 +144,17 @@
             **(model_kwargs or dict()),
         )

-        for chunk in chat:
-            if len(chunk.choices) == 0:
-                continue
-            delta_chunk = chunk.choices[0].delta
-            if isinstance(delta_chunk, str):
-                g.send(delta_chunk)
-            elif delta_chunk.content:
-                g.send(delta_chunk.content)
+        if not stream:
+            g.send(chat.choices[0].message.content)
+        else:
+            for chunk in chat:
+                if len(chunk.choices) == 0:
+                    continue
+                delta_chunk = chunk.choices[0].delta
+                if isinstance(delta_chunk, str):
+                    g.send(delta_chunk)
+                elif delta_chunk.content:
+                    g.send(delta_chunk.content)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 999473e28..6444b14df 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -1,4 +1,3 @@
-import json
 import logging
 import math
 import queue
@@ -24,6 +23,8 @@
     "gpt-4-0125-preview": 20000,
     "gpt-4-turbo-preview": 20000,
     "gpt-4o-mini": 20000,
+    "o1-preview": 20000,
+    "o1-mini": 20000,
     "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
     "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -220,8 +221,9 @@ def truncate_messages(
     try:
         if loaded_model:
             encoder = loaded_model.tokenizer()
-        elif model_name.startswith("gpt-"):
-            encoder = tiktoken.encoding_for_model(model_name)
+        elif model_name.startswith("gpt-") or model_name.startswith("o1"):
+            # tiktoken doesn't recognize the o1 model series yet
+            encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
         elif tokenizer_name:
             if tokenizer_name in state.pretrained_tokenizers:
                 encoder = state.pretrained_tokenizers[tokenizer_name]
@@ -278,7 +280,9 @@ def truncate_messages(
     )

     if system_message:
-        system_message.role = "user" if "gemma-2" in model_name else "system"
+        # The system message role defaults to system.
+        # Fall back to the user role for models that don't support the system role, like gemma-2 and OpenAI's o1 model series.
+        system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
     return messages + [system_message] if system_message else messages
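
For context, the sketch below shows the request shape these beta limitations imply when calling the OpenAI API directly. It is an illustrative sample rather than part of the patch: it assumes the official openai Python client, and the prompt and model_kwargs are placeholders.

# Illustrative sketch, not part of the patch series: the o1 beta limitations
# currently disallow streaming, stop words, custom temperature and
# response_format. Prompt and model_kwargs below are placeholders.
import openai

client = openai.OpenAI()  # expects OPENAI_API_KEY in the environment

model = "o1-preview"
model_kwargs = {"stop": ["\n"], "response_format": {"type": "json_object"}}
temperature, stream = 0.2, True

if model.startswith("o1"):
    # Drop the unsupported parameters and fall back to a blocking request
    model_kwargs.pop("stop", None)
    model_kwargs.pop("response_format", None)
    temperature, stream = 1, False

chat = client.chat.completions.create(
    messages=[{"role": "user", "content": "Say hello"}],
    model=model,
    temperature=temperature,
    stream=stream,
    **model_kwargs,
)

if stream:
    # Streaming responses arrive as chunks with incremental deltas
    for chunk in chat:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="")
else:
    # Non-streaming responses return the full message in one object
    print(chat.choices[0].message.content)
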
"NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500, "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000, @@ -220,8 +221,9 @@ def truncate_messages( try: if loaded_model: encoder = loaded_model.tokenizer() - elif model_name.startswith("gpt-"): - encoder = tiktoken.encoding_for_model(model_name) + elif model_name.startswith("gpt-") or model_name.startswith("o1"): + # as tiktoken doesn't recognize o1 model series yet + encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name) elif tokenizer_name: if tokenizer_name in state.pretrained_tokenizers: encoder = state.pretrained_tokenizers[tokenizer_name] @@ -278,7 +280,9 @@ def truncate_messages( ) if system_message: - system_message.role = "user" if "gemma-2" in model_name else "system" + # Default system message role is system. + # Fallback to system message role of user for models that do not support this role like gemma-2 and openai's o1 model series. + system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system" return messages + [system_message] if system_message else messages From 6e660d11c981d2341b4ab24672a96ef7b1511171 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 12 Sep 2024 16:27:55 -0700 Subject: [PATCH 2/3] Override block display styling of links by Katex in chat messages This happens sometimes when LLM respons contains [\[1\]] kind of links as reference. Both markdown-it and katex apply styling. Katex's span uses display: block which makes the rendering of these references take up a whole line by themselves. Override block styling of spans within an `a' element to prevent such chat message styling issues --- .../web/app/components/chatMessage/chatMessage.module.css | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/interface/web/app/components/chatMessage/chatMessage.module.css b/src/interface/web/app/components/chatMessage/chatMessage.module.css index e0dd31ace..7890fc2b0 100644 --- a/src/interface/web/app/components/chatMessage/chatMessage.module.css +++ b/src/interface/web/app/components/chatMessage/chatMessage.module.css @@ -18,6 +18,11 @@ div.chatMessageWrapper p:not(:last-child) { margin-bottom: 16px; } +/* Override some link styling by Katex to improve rendering */ +div.chatMessageWrapper a span { + display: revert !important; +} + div.khojfullHistory { border-width: 1px; padding-left: 4px; From 0685a79748d0b6572b870c66bf7d12863e2cdff3 Mon Sep 17 00:00:00 2001 From: Debanjum Singh Solanky Date: Thu, 12 Sep 2024 16:41:40 -0700 Subject: [PATCH 3/3] Remove any markdown json codeblock in chat actors expecting json responses Strip any json md codeblock wrapper if exists before processing response by output mode, extract questions chat actor. 
From 0685a79748d0b6572b870c66bf7d12863e2cdff3 Mon Sep 17 00:00:00 2001
From: Debanjum Singh Solanky
Date: Thu, 12 Sep 2024 16:41:40 -0700
Subject: [PATCH 3/3] Remove any markdown json codeblock in chat actors expecting json responses

Strip any json md codeblock wrapper, if it exists, before processing the response in the output mode and extract questions chat actors. This is similar to what other chat actors already do.

Useful for successfully interpreting json output in chat actors when using models like o1 and gemma-2 that don't support enforcing a (json) response schema.

Use a conversation helper function to centralize the json md codeblock removal code.
---
 src/khoj/processor/conversation/openai/gpt.py |  2 ++
 src/khoj/processor/conversation/utils.py      |  7 +++++++
 src/khoj/routers/helpers.py                   | 13 ++++++-------
 3 files changed, 15 insertions(+), 7 deletions(-)

diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index c6f744fa8..90cd4df99 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -14,6 +14,7 @@
 from khoj.processor.conversation.utils import (
     construct_structured_message,
     generate_chatml_messages_with_context,
+    remove_json_codeblock,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData
@@ -85,6 +86,7 @@ def extract_questions(
     # Extract, Clean Message from GPT's Response
     try:
         response = response.strip()
+        response = remove_json_codeblock(response)
         response = json.loads(response)
         response = [q.strip() for q in response["queries"] if q.strip()]
         if not isinstance(response, list) or not response:
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 6444b14df..03bd17a39 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -289,3 +289,10 @@ def truncate_messages(
 def reciprocal_conversation_to_chatml(message_pair):
     """Convert a single back and forth between user and assistant to chatml format"""
     return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
+
+
+def remove_json_codeblock(response):
+    """Remove any markdown json codeblock formatting, if present. Useful for models that don't support enforcing a json schema"""
+    if response.startswith("```json") and response.endswith("```"):
+        response = response[7:-3]
+    return response
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 5687937a9..f1b8ddd66 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -88,6 +88,7 @@
 from khoj.processor.conversation.utils import (
     ThreadedGenerator,
     generate_chatml_messages_with_context,
+    remove_json_codeblock,
     save_to_conversation_log,
 )
 from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled
@@ -298,9 +299,7 @@ async def aget_relevant_information_sources(

     try:
         response = response.strip()
-        # Remove any markdown json codeblock formatting if present (useful for gemma-2)
-        if response.startswith("```json"):
-            response = response[7:-3]
+        response = remove_json_codeblock(response)
         response = json.loads(response)
         response = [q.strip() for q in response["source"] if q.strip()]
         if not isinstance(response, list) or not response or len(response) == 0:
@@ -353,7 +352,9 @@ async def aget_relevant_output_modes(
     response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")

     try:
-        response = json.loads(response.strip())
+        response = response.strip()
+        response = remove_json_codeblock(response)
+        response = json.loads(response)

         if is_none_or_empty(response):
             return ConversationCommand.Text
@@ -433,9 +434,7 @@ async def generate_online_subqueries(
     # Validate that the response is a non-empty, JSON-serializable list
     try:
         response = response.strip()
-        # Remove any markdown json codeblock formatting if present (useful for gemma-2)
-        if response.startswith("```json") and response.endswith("```"):
-            response = response[7:-3]
+        response = remove_json_codeblock(response)
        response = json.loads(response)
         response = [q.strip() for q in response["queries"] if q.strip()]
         if not isinstance(response, list) or not response or len(response) == 0:
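
A quick usage sketch of the new helper follows. It assumes the patched khoj.processor.conversation.utils module is importable; the sample response string is a made-up example of model output, not taken from the patches.

# Illustrative usage of remove_json_codeblock, not part of the patch series.
import json

from khoj.processor.conversation.utils import remove_json_codeblock

# A made-up model response wrapped in a markdown json codeblock
raw_response = '```json\n{"queries": ["best hikes near Berlin", "hiking gear checklist"]}\n```'

# Strip the codeblock fence, then parse the remaining json payload
cleaned = remove_json_codeblock(raw_response.strip())
queries = json.loads(cleaned)["queries"]
print(queries)  # ['best hikes near Berlin', 'hiking gear checklist']
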