diff --git a/include/minja/chat-template.hpp b/include/minja/chat-template.hpp index 142a58e..ab5b521 100644 --- a/include/minja/chat-template.hpp +++ b/include/minja/chat-template.hpp @@ -395,8 +395,8 @@ class chat_template { for (const auto & message_ : adjusted_messages) { auto message = message_; - if (!message.contains("role") || !message.contains("content")) { - throw std::runtime_error("message must have 'role' and 'content' fields: " + message.dump()); + if (!message.contains("role") || (!message.contains("content") && !message.contains("tool_calls"))) { + throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump()); } std::string role = message.at("role"); @@ -417,7 +417,6 @@ class chat_template { } } if (polyfill_tool_calls) { - auto content = message.at("content"); auto tool_calls = json::array(); for (const auto & tool_call : message.at("tool_calls")) { if (tool_call.at("type") != "function") { @@ -436,8 +435,11 @@ class chat_template { auto obj = json { {"tool_calls", tool_calls}, }; - if (!content.is_null() && !content.empty()) { - obj["content"] = content; + if (message.contains("content")) { + auto content = message.at("content"); + if (!content.is_null() && !content.empty()) { + obj["content"] = content; + } } message["content"] = obj.dump(2); message.erase("tool_calls");