diff --git a/src/llama_cpp_agent/messages_formatter.py b/src/llama_cpp_agent/messages_formatter.py index 8e0e69b..4ead17a 100644 --- a/src/llama_cpp_agent/messages_formatter.py +++ b/src/llama_cpp_agent/messages_formatter.py @@ -28,6 +28,9 @@ class MessagesFormatterType(Enum): AUTOCODER = 15 GEMMA_2 = 16 DEEP_SEEK_CODER_2 = 17 + PHI_4 = 18 + DEEPSEEK_R1_DISTILL_QWEN = 19 + MISTRAL_SMALL_3 = 20 @dataclass class PromptMarkers: @@ -137,6 +140,14 @@ def _format_response( Roles.tool: PromptMarkers("", ""), } +mistral_small3_prompt_markers = { + Roles.system: PromptMarkers("""[SYSTEM_PROMPT]""", """[/SYSTEM_PROMPT]"""), + Roles.user: PromptMarkers("""[INST]""", """ [/INST]"""), + Roles.assistant: PromptMarkers("""""", """"""), + Roles.tool: PromptMarkers("", ""), +} + + chatml_prompt_markers = { Roles.system: PromptMarkers("""<|im_start|>system\n""", """<|im_end|>\n"""), Roles.user: PromptMarkers("""<|im_start|>user\n""", """<|im_end|>\n"""), @@ -161,7 +172,7 @@ def _format_response( llama_3_prompt_markers = { Roles.system: PromptMarkers("""<|start_header_id|>system<|end_header_id|>\n""", """<|eot_id|>"""), Roles.user: PromptMarkers("""<|start_header_id|>user<|end_header_id|>\n""", """<|eot_id|>"""), - Roles.assistant: PromptMarkers("""<|start_header_id|>assistant<|end_header_id|>\n\n""", """<|eot_id|>"""), + Roles.assistant: PromptMarkers("""<|start_header_id|>assistant<|end_header_id|>\n""", """<|eot_id|>"""), Roles.tool: PromptMarkers("""<|start_header_id|>function_calling_results<|end_header_id|>\n""", """<|eot_id|>"""), } @@ -181,7 +192,7 @@ def _format_response( gemma_2_prompt_markers = { Roles.system: PromptMarkers("""""", """\n\n"""), Roles.user: PromptMarkers("""user\n""", """\n"""), - Roles.assistant: PromptMarkers("""model\n\n""", """\n"""), + Roles.assistant: PromptMarkers("""model\n""", """\n"""), Roles.tool: PromptMarkers("", ""), } code_ds_prompt_markers = { @@ -243,6 +254,18 @@ def _format_response( Roles.assistant: PromptMarkers("""Assistant: """, """<|end▁of▁sentence|>"""), Roles.tool: PromptMarkers("", ""), } +phi_4_chat_prompt_markers = { + Roles.system: PromptMarkers("""<|im_start|>system<|im_sep|>\n""", """<|im_end|>"""), + Roles.user: PromptMarkers("""<|im_start|>user<|im_sep|>\n""", """<|im_end|>\n"""), + Roles.assistant: PromptMarkers("""<|im_start|>assistant<|im_sep|>""", """<|im_end|>\n"""), + Roles.tool: PromptMarkers("", ""), +} +deepseek_r1_distill_qwen_chat_prompt_markers = { + Roles.system: PromptMarkers("""<|begin▁of▁sentence|>""", ""), + Roles.user: PromptMarkers("""<|User|>""", ""), + Roles.assistant: PromptMarkers("""<|Assistant|>""", ""), + Roles.tool: PromptMarkers("", ""), +} """ ### Instruction: @@ -253,6 +276,15 @@ def _format_response( mixtral_prompt_markers, True, [""], + strip_prompt=False, #added +) + +mistral_small3_formatter = MessagesFormatter( + "", + mistral_small3_prompt_markers, + True, + [""], + strip_prompt=False, #added ) chatml_formatter = MessagesFormatter( @@ -284,7 +316,7 @@ def _format_response( False, ["assistant", "<|eot_id|>"], use_user_role_for_function_call_result=False, - strip_prompt=True, + strip_prompt=False, ) synthia_formatter = MessagesFormatter( @@ -348,6 +380,14 @@ def _format_response( use_user_role_for_function_call_result=True, ) +phi_4_chat_formatter = MessagesFormatter( + "", + phi_4_chat_prompt_markers, + True, + ["<|im_end|>", "<|endoftext|>"], + use_user_role_for_function_call_result=True, +) + open_interpreter_chat_formatter = MessagesFormatter( "You are an AI programming assistant, utilizing the Deepseek Coder model, developed by Deepseek Company, and you only answer questions related to computer science. For politically sensitive questions, security and privacy issues, and other non-computer science questions, you will refuse to answer.\n", open_interpreter_chat_prompt_markers, @@ -369,7 +409,8 @@ def _format_response( "", gemma_2_prompt_markers, True, - ["", ""] + ["", ""], + strip_prompt=False, #added ) deep_seek_coder_2_chat_formatter = MessagesFormatter( @@ -380,6 +421,14 @@ def _format_response( bos_token="<|begin▁of▁sentence|>", eos_token="<|end▁of▁sentence|>", ) +deepseek_r1_distill_qwen_chat_formatter = MessagesFormatter( + "", + deepseek_r1_distill_qwen_chat_prompt_markers, + True, + ["<|end▁of▁sentence|>"], + bos_token="<|begin▁of▁sentence|>", + eos_token="<|end▁of▁sentence|>", +) predefined_formatter = { MessagesFormatterType.MISTRAL: mixtral_formatter, @@ -398,7 +447,10 @@ def _format_response( MessagesFormatterType.OPEN_INTERPRETER: open_interpreter_chat_formatter, MessagesFormatterType.AUTOCODER: autocoder_chat_formatter, MessagesFormatterType.GEMMA_2: gemma_2_chat_formatter, - MessagesFormatterType.DEEP_SEEK_CODER_2: deep_seek_coder_2_chat_formatter + MessagesFormatterType.DEEP_SEEK_CODER_2: deep_seek_coder_2_chat_formatter, + MessagesFormatterType.PHI_4: phi_4_chat_formatter, + MessagesFormatterType.DEEPSEEK_R1_DISTILL_QWEN: deepseek_r1_distill_qwen_chat_formatter, + MessagesFormatterType.MISTRAL_SMALL_3: mistral_small3_formatter, }