diff --git a/src/interface/web/app/components/chatMessage/chatMessage.module.css b/src/interface/web/app/components/chatMessage/chatMessage.module.css
index e0dd31ace..7890fc2b0 100644
--- a/src/interface/web/app/components/chatMessage/chatMessage.module.css
+++ b/src/interface/web/app/components/chatMessage/chatMessage.module.css
@@ -18,6 +18,11 @@ div.chatMessageWrapper p:not(:last-child) {
     margin-bottom: 16px;
 }

+/* Override some link styling by KaTeX to improve rendering */
+div.chatMessageWrapper a span {
+    display: revert !important;
+}
+
 div.khojfullHistory {
     border-width: 1px;
     padding-left: 4px;
diff --git a/src/khoj/processor/conversation/openai/gpt.py b/src/khoj/processor/conversation/openai/gpt.py
index c6f744fa8..90cd4df99 100644
--- a/src/khoj/processor/conversation/openai/gpt.py
+++ b/src/khoj/processor/conversation/openai/gpt.py
@@ -14,6 +14,7 @@ from khoj.processor.conversation.utils import (
     construct_structured_message,
     generate_chatml_messages_with_context,
+    remove_json_codeblock,
 )
 from khoj.utils.helpers import ConversationCommand, is_none_or_empty
 from khoj.utils.rawconfig import LocationData

@@ -85,6 +86,7 @@ def extract_questions(
     # Extract, Clean Message from GPT's Response
     try:
         response = response.strip()
+        response = remove_json_codeblock(response)
         response = json.loads(response)
         response = [q.strip() for q in response["queries"] if q.strip()]
         if not isinstance(response, list) or not response:
diff --git a/src/khoj/processor/conversation/openai/utils.py b/src/khoj/processor/conversation/openai/utils.py
index 1a42113ee..878dbb9c8 100644
--- a/src/khoj/processor/conversation/openai/utils.py
+++ b/src/khoj/processor/conversation/openai/utils.py
@@ -45,15 +45,28 @@ def completion_with_backoff(
         openai_clients[client_key] = client

     formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+    stream = True
+
+    # Update request parameters for compatibility with the o1 model series
+    # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+    if model.startswith("o1"):
+        stream = False
+        temperature = 1
+        model_kwargs.pop("stop", None)
+        model_kwargs.pop("response_format", None)

     chat = client.chat.completions.create(
-        stream=True,
+        stream=stream,
         messages=formatted_messages,  # type: ignore
         model=model,  # type: ignore
         temperature=temperature,
         timeout=20,
         **(model_kwargs or dict()),
     )
+
+    if not stream:
+        return chat.choices[0].message.content
+
     aggregated_response = ""
     for chunk in chat:
         if len(chunk.choices) == 0:
@@ -112,9 +125,18 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
         client: openai.OpenAI = openai_clients[client_key]

         formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
+        stream = True
+
+        # Update request parameters for compatibility with the o1 model series
+        # Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
+        if model_name.startswith("o1"):
+            stream = False
+            temperature = 1
+            model_kwargs.pop("stop", None)
+            model_kwargs.pop("response_format", None)

         chat = client.chat.completions.create(
-            stream=True,
+            stream=stream,
             messages=formatted_messages,
             model=model_name,  # type: ignore
             temperature=temperature,
@@ -122,14 +144,17 @@ def llm_thread(g, messages, model_name, temperature, openai_api_key=None, api_ba
             **(model_kwargs or dict()),
         )

-        for chunk in chat:
-            if len(chunk.choices) == 0:
-                continue
-            delta_chunk = chunk.choices[0].delta
-            if isinstance(delta_chunk, str):
-                g.send(delta_chunk)
-            elif delta_chunk.content:
-                g.send(delta_chunk.content)
+        if not stream:
+            g.send(chat.choices[0].message.content)
+        else:
+            for chunk in chat:
+                if len(chunk.choices) == 0:
+                    continue
+                delta_chunk = chunk.choices[0].delta
+                if isinstance(delta_chunk, str):
+                    g.send(delta_chunk)
+                elif delta_chunk.content:
+                    g.send(delta_chunk.content)
     except Exception as e:
         logger.error(f"Error in llm_thread: {e}", exc_info=True)
     finally:
diff --git a/src/khoj/processor/conversation/utils.py b/src/khoj/processor/conversation/utils.py
index 999473e28..03bd17a39 100644
--- a/src/khoj/processor/conversation/utils.py
+++ b/src/khoj/processor/conversation/utils.py
@@ -1,4 +1,3 @@
-import json
 import logging
 import math
 import queue
@@ -24,6 +23,8 @@
     "gpt-4-0125-preview": 20000,
     "gpt-4-turbo-preview": 20000,
     "gpt-4o-mini": 20000,
+    "o1-preview": 20000,
+    "o1-mini": 20000,
     "TheBloke/Mistral-7B-Instruct-v0.2-GGUF": 3500,
     "NousResearch/Hermes-2-Pro-Mistral-7B-GGUF": 3500,
     "bartowski/Meta-Llama-3.1-8B-Instruct-GGUF": 20000,
@@ -220,8 +221,9 @@ def truncate_messages(
     try:
         if loaded_model:
             encoder = loaded_model.tokenizer()
-        elif model_name.startswith("gpt-"):
-            encoder = tiktoken.encoding_for_model(model_name)
+        elif model_name.startswith("gpt-") or model_name.startswith("o1"):
+            # tiktoken doesn't recognize the o1 model series yet, so fall back to the gpt-4o tokenizer
+            encoder = tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
         elif tokenizer_name:
             if tokenizer_name in state.pretrained_tokenizers:
                 encoder = state.pretrained_tokenizers[tokenizer_name]
@@ -278,10 +280,19 @@ def truncate_messages(
     )

     if system_message:
-        system_message.role = "user" if "gemma-2" in model_name else "system"
+        # The default system message role is "system".
+        # Fall back to the "user" role for models like gemma-2 and OpenAI's o1 series that do not support the "system" role.
+        system_message.role = "user" if "gemma-2" in model_name or model_name.startswith("o1") else "system"
     return messages + [system_message] if system_message else messages


 def reciprocal_conversation_to_chatml(message_pair):
     """Convert a single back and forth between user and assistant to chatml format"""
     return [ChatMessage(content=message, role=role) for message, role in zip(message_pair, ["user", "assistant"])]
+
+
+def remove_json_codeblock(response):
+    """Remove any markdown json codeblock formatting if present. Useful for models that don't support JSON schema enforcement"""
+    if response.startswith("```json") and response.endswith("```"):
+        response = response[7:-3]
+    return response
diff --git a/src/khoj/routers/helpers.py b/src/khoj/routers/helpers.py
index 5687937a9..f1b8ddd66 100644
--- a/src/khoj/routers/helpers.py
+++ b/src/khoj/routers/helpers.py
@@ -88,6 +88,7 @@ from khoj.processor.conversation.utils import (
     ThreadedGenerator,
     generate_chatml_messages_with_context,
+    remove_json_codeblock,
     save_to_conversation_log,
 )
 from khoj.processor.speech.text_to_speech import is_eleven_labs_enabled

@@ -298,9 +299,7 @@ async def aget_relevant_information_sources(

     try:
         response = response.strip()
-        # Remove any markdown json codeblock formatting if present (useful for gemma-2)
-        if response.startswith("```json"):
-            response = response[7:-3]
+        response = remove_json_codeblock(response)
         response = json.loads(response)
         response = [q.strip() for q in response["source"] if q.strip()]
         if not isinstance(response, list) or not response or len(response) == 0:
@@ -353,7 +352,9 @@ async def aget_relevant_output_modes(
     response = await send_message_to_model_wrapper(relevant_mode_prompt, response_type="json_object")

     try:
-        response = json.loads(response.strip())
+        response = response.strip()
+        response = remove_json_codeblock(response)
+        response = json.loads(response)

         if is_none_or_empty(response):
             return ConversationCommand.Text
@@ -433,9 +434,7 @@ async def generate_online_subqueries(
     # Validate that the response is a non-empty, JSON-serializable list
     try:
         response = response.strip()
-        # Remove any markdown json codeblock formatting if present (useful for gemma-2)
-        if response.startswith("```json") and response.endswith("```"):
-            response = response[7:-3]
+        response = remove_json_codeblock(response)
        response = json.loads(response)
         response = [q.strip() for q in response["queries"] if q.strip()]
         if not isinstance(response, list) or not response or len(response) == 0:
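For reference, the behavior of the new remove_json_codeblock helper can be sketched in isolation. The sample strings below are illustrative, not taken from the source:

# Sketch: behavior of the remove_json_codeblock helper added in
# src/khoj/processor/conversation/utils.py above.
def remove_json_codeblock(response):
    """Remove any markdown json codeblock formatting if present."""
    if response.startswith("```json") and response.endswith("```"):
        response = response[7:-3]  # strip the leading ```json and trailing ```
    return response

# A schema-enforced model returns bare JSON; it passes through unchanged.
assert remove_json_codeblock('{"queries": ["a"]}') == '{"queries": ["a"]}'
# A model that wraps its JSON in a markdown fence gets the fence stripped;
# the leftover whitespace is harmless because json.loads tolerates it.
assert remove_json_codeblock('```json\n{"queries": ["a"]}\n```') == '\n{"queries": ["a"]}\n'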
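Both OpenAI call sites now share the same o1 handling: disable streaming, pin temperature to 1, and drop the unsupported stop and response_format parameters. A condensed, self-contained sketch of that pattern follows; the chat_completion wrapper name is hypothetical, and only the openai client calls mirror the diff:

# Sketch of the o1 compatibility branch used in completion_with_backoff
# and llm_thread above. Assumes `client` is an openai.OpenAI instance and
# `messages` is a list of {"role": ..., "content": ...} dicts.
import openai

def chat_completion(client: openai.OpenAI, model: str, messages: list, temperature: float, model_kwargs: dict) -> str:
    stream = True
    if model.startswith("o1"):
        # o1 beta limitations: no streaming, fixed temperature,
        # and no stop or response_format request parameters.
        stream = False
        temperature = 1
        model_kwargs.pop("stop", None)
        model_kwargs.pop("response_format", None)

    chat = client.chat.completions.create(
        stream=stream,
        messages=messages,
        model=model,
        temperature=temperature,
        **(model_kwargs or dict()),
    )
    if not stream:
        # Non-streaming responses arrive as a single complete choice.
        return chat.choices[0].message.content

    # Streaming responses arrive as deltas that must be aggregated.
    aggregated = ""
    for chunk in chat:
        if chunk.choices and chunk.choices[0].delta.content:
            aggregated += chunk.choices[0].delta.content
    return aggregated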
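The tokenizer fallback in truncate_messages follows the same idea: tiktoken has no encoding registered for the o1 series, so the gpt-4o encoding stands in for token counting. A minimal sketch, assuming tiktoken is installed; the encoder_for helper is hypothetical:

# Sketch: pick a tiktoken encoding, substituting gpt-4o's for o1 models.
import tiktoken

def encoder_for(model_name: str) -> tiktoken.Encoding:
    if model_name.startswith("gpt-") or model_name.startswith("o1"):
        # tiktoken doesn't know the o1 series yet; gpt-4o shares its tokenizer family.
        return tiktoken.encoding_for_model("gpt-4o" if model_name.startswith("o1") else model_name)
    raise ValueError(f"No known tokenizer for {model_name}")

tokens = encoder_for("o1-preview").encode("How many tokens is this?")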