Skip to content

Commit c84f4d9

Browse files
authored
Fix special character escaping in tool parsers (#3693)
Co-authored-by: Copilot <[email protected]> Co-authored-by: Adrian Tobiszewski <[email protected]> This PR fixes special character escaping in tool parsers to properly handle complex arguments containing characters like quotes, newlines, tabs, and backslashes. The changes standardize escape handling across different model output parsers and improve JSON parsing robustness.
1 parent f630543 commit c84f4d9

File tree

12 files changed

+497
-28
lines changed

12 files changed

+497
-28
lines changed

src/llm/apis/openai_completions.cpp

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,34 @@ static absl::Status downloadImage(const char* url, std::string& image, const int
134134
return absl::OkStatus();
135135
}
136136

137+
absl::Status OpenAIChatCompletionsHandler::ensureArgumentsInToolCalls(Value& messageObj, bool& jsonChanged) {
138+
auto& allocator = doc.GetAllocator();
139+
auto toolCallsIt = messageObj.FindMember("tool_calls");
140+
if (toolCallsIt != messageObj.MemberEnd() && toolCallsIt->value.IsArray()) {
141+
const auto& toolCallsArray = toolCallsIt->value.GetArray();
142+
for (rapidjson::SizeType j = 0; j < toolCallsArray.Size(); ++j) {
143+
auto& toolCall = toolCallsArray[j];
144+
if (!toolCall.IsObject()) {
145+
return absl::InvalidArgumentError("Each tool_call must be an object");
146+
}
147+
auto functionIt = toolCall.FindMember("function");
148+
if (functionIt == toolCall.MemberEnd() || !functionIt->value.IsObject()) {
149+
return absl::InvalidArgumentError("Each tool_call must have a 'function' object");
150+
}
151+
const auto& functionObj = functionIt->value.GetObject();
152+
if (functionObj.FindMember("arguments") == functionObj.MemberEnd()) {
153+
// Add "arguments": "{}"
154+
rapidjson::Value argumentsKey("arguments", allocator);
155+
rapidjson::Value argumentsValue;
156+
argumentsValue.SetString("{}", allocator);
157+
functionIt->value.GetObject().AddMember(argumentsKey, argumentsValue, allocator);
158+
jsonChanged = true;
159+
}
160+
}
161+
}
162+
return absl::OkStatus();
163+
}
164+
137165
absl::Status OpenAIChatCompletionsHandler::parseMessages(std::optional<std::string> allowedLocalMediaPath) {
138166
auto it = doc.FindMember("messages");
139167
if (it == doc.MemberEnd())
@@ -264,10 +292,21 @@ absl::Status OpenAIChatCompletionsHandler::parseMessages(std::optional<std::stri
264292
}
265293
}
266294
}
267-
const auto& lastMessage = request.chatHistory.back();
295+
auto& lastMessage = request.chatHistory.back();
268296
if (lastMessage.find("role") == lastMessage.end()) {
269297
return absl::InvalidArgumentError("Every message must have 'role' field");
270298
}
299+
if (lastMessage.find("content") == lastMessage.end()) {
300+
SPDLOG_LOGGER_DEBUG(llm_calculator_logger, "Message does not have content field which might be an issue for some chat templates. Adding empty content.");
301+
lastMessage["content"] = "";
302+
obj.AddMember("content", Value().SetString("", doc.GetAllocator()), doc.GetAllocator());
303+
jsonChanged = true;
304+
}
305+
// If message has tool calls, make sure each tool call has "arguments" field
306+
auto status = ensureArgumentsInToolCalls(obj, jsonChanged);
307+
if (status != absl::OkStatus()) {
308+
return status;
309+
}
271310
}
272311
if (jsonChanged) {
273312
StringBuffer buffer;

src/llm/apis/openai_completions.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class OpenAIChatCompletionsHandler {
7878
absl::Status parseCommonPart(std::optional<uint32_t> maxTokensLimit, uint32_t bestOfLimit, std::optional<uint32_t> maxModelLength);
7979

8080
ParsedOutput parseOutputIfNeeded(const std::vector<int64_t>& generatedIds);
81+
absl::Status ensureArgumentsInToolCalls(Value& messageObj, bool& jsonChanged);
8182

8283
public:
8384
OpenAIChatCompletionsHandler(Document& doc, Endpoint endpoint, std::chrono::time_point<std::chrono::system_clock> creationTime,

src/llm/io_processing/hermes3/tool_parser.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
#include "../../../logging.hpp"
2929
#include "tool_parser.hpp"
3030
#include "../utils.hpp"
31+
#include "src/stringutils.hpp"
3132

3233
namespace ovms {
3334

@@ -293,10 +294,8 @@ std::optional<rapidjson::Document> Hermes3ToolParser::parseChunk(const std::stri
293294
*/
294295

295296
if (lastJson.HasMember("arguments")) {
296-
// Escaping double quotes in the arguments string
297-
for (size_t pos = 0; (pos = modifiedChunk.find("\"", pos)) != std::string::npos; pos += 2) {
298-
modifiedChunk.insert(pos, "\\");
299-
}
297+
// Since inside a string, we need to escape characters like quotes, new lines, tabs, etc.
298+
escapeSpecialCharacters(modifiedChunk);
300299

301300
bool processingFirstArgumentsChunk = argumentsDelayWindow[0].empty();
302301
// Handle the case when we are starting to collect arguments.

src/llm/io_processing/llama3/tool_parser.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "../../../logging.hpp"
3030
#include "tool_parser.hpp"
3131
#include "../utils.hpp"
32+
#include "src/stringutils.hpp"
3233

3334
namespace ovms {
3435
void Llama3ToolParser::parse(ParsedOutput& parsedOutput, const std::vector<int64_t>& generatedTokens) {
@@ -158,10 +159,8 @@ std::optional<rapidjson::Document> Llama3ToolParser::parseChunk(const std::strin
158159
// JSON already contains 'parameters'/'arguments' (they cannot be null at this point). Apply modifications to the input chunk if needed to keep the format valid.
159160
if (jsonHasArgumentsOrParameters(lastJson)) {
160161
std::string modifiedChunk = chunk;
161-
// Escaping all double quotes in the parameters/arguments string
162-
for (size_t pos = 0; (pos = modifiedChunk.find("\"", pos)) != std::string::npos; pos += 2) {
163-
modifiedChunk.insert(pos, "\\");
164-
}
162+
// Since inside a string, we need to escape characters like quotes, new lines, tabs, etc.
163+
escapeSpecialCharacters(modifiedChunk);
165164

166165
// Handle the case when we are starting to collect parameters/arguments.
167166
// Force parameters/arguments string type and fill first element of the delay array.

src/llm/io_processing/partial_json_builder.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,18 @@ Document PartialJsonBuilder::add(const std::string& chunk) {
202202
}
203203
}
204204
} else if (c == '\\') {
205-
finishedWithEscapeCharacter = true;
205+
// Count consecutive backslashes before current position
206+
auto backslashIt = it;
207+
// Start with 1 since we found one backslash already
208+
int backslashCount = 1;
209+
while (backslashIt != buffer.begin() && *(backslashIt - 1) == '\\') {
210+
--backslashIt;
211+
++backslashCount;
212+
}
213+
if (backslashCount % 2 != 0) {
214+
// Odd number of backslashes finishing the buffer: current backslash is escaping the next character
215+
finishedWithEscapeCharacter = true;
216+
}
206217
}
207218
}
208219
}
@@ -264,7 +275,7 @@ Document PartialJsonBuilder::add(const std::string& chunk) {
264275
}
265276
doc.Parse(closedInput.c_str());
266277
if (doc.HasParseError()) {
267-
throw std::runtime_error("Invalid JSON. Content:\n" + closedInput);
278+
throw std::runtime_error("Invalid JSON. Content with closure attempt:\n" + closedInput + "\nOriginal content:\n" + buffer);
268279
}
269280
return doc;
270281
}

src/llm/io_processing/phi4/tool_parser.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "../../../logging.hpp"
3030
#include "tool_parser.hpp"
3131
#include "../utils.hpp"
32+
#include "src/stringutils.hpp"
3233

3334
namespace ovms {
3435

@@ -257,15 +258,10 @@ std::optional<rapidjson::Document> Phi4ToolParser::parseChunk(const std::string&
257258
return std::nullopt;
258259
}
259260
} else { // internalState == PROCESSING_TOOL_CALL
260-
// Remove any newlines to avoid breaking JSON format
261-
modifiedChunk.erase(std::remove(modifiedChunk.begin(), modifiedChunk.end(), '\n'), modifiedChunk.end());
262-
263261
// JSON already contains 'arguments' (they cannot be null at this point). Apply modifications to the input chunk if needed to keep the format valid.
264262
if (processingArguments) {
265-
// Escaping double quotes in the arguments string
266-
for (size_t pos = 0; (pos = modifiedChunk.find("\"", pos)) != std::string::npos; pos += 2) {
267-
modifiedChunk.insert(pos, "\\");
268-
}
263+
// Since inside a string, we need to escape characters like quotes, new lines, tabs, etc.
264+
escapeSpecialCharacters(modifiedChunk);
269265

270266
// Keep track of opened/closed braces to identify the end of the tool call object.
271267
updateOpenBracesCount(modifiedChunk);
@@ -282,6 +278,9 @@ std::optional<rapidjson::Document> Phi4ToolParser::parseChunk(const std::string&
282278
// If we balanced the braces, we are at the end of the tool call object
283279
handleEndOfToolCall(modifiedChunk);
284280
}
281+
} else {
282+
// Remove any newlines to avoid breaking JSON format
283+
modifiedChunk.erase(std::remove(modifiedChunk.begin(), modifiedChunk.end(), '\n'), modifiedChunk.end());
285284
}
286285

287286
// Phase 2: Parse the modified chunk with PartialJsonBuilder and return appropriate delta if possible

src/stringutils.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,4 +255,37 @@ bool stringsOverlap(const std::string& lhs, const std::string& rhs) {
255255
return false;
256256
}
257257

258+
void escapeSpecialCharacters(std::string& text) {
259+
// Escape all double quotes, backslashes, and control characters in the text
260+
std::string escaped;
261+
for (char c : text) {
262+
switch (c) {
263+
case '\"':
264+
escaped += "\\\"";
265+
break;
266+
case '\\':
267+
escaped += "\\\\";
268+
break;
269+
case '\b':
270+
escaped += "\\b";
271+
break;
272+
case '\f':
273+
escaped += "\\f";
274+
break;
275+
case '\n':
276+
escaped += "\\n";
277+
break;
278+
case '\r':
279+
escaped += "\\r";
280+
break;
281+
case '\t':
282+
escaped += "\\t";
283+
break;
284+
default:
285+
escaped += c;
286+
}
287+
}
288+
text = escaped;
289+
}
290+
258291
} // namespace ovms

src/stringutils.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,6 @@ std::string toLower(const std::string& input);
111111

112112
bool stringsOverlap(const std::string& lhs, const std::string& rhs);
113113

114+
void escapeSpecialCharacters(std::string& text);
115+
114116
} // namespace ovms

src/test/llm/output_parsers/llama3_output_parser_test.cpp

Lines changed: 124 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ const std::string tokenizerPath = getWindowsRepoRootPath() + "\\src\\test\\llm_t
3131
const std::string tokenizerPath = "/ovms/src/test/llm_testing/meta-llama/Llama-3.1-8B-Instruct";
3232
#endif
3333

34-
static const ovms::ToolsSchemas_t EMPTY_TOOL_SCHEMA = {}; // not used for llama3
34+
static const ovms::ToolsSchemas_t EMPTY_TOOLS_SCHEMA = {}; // not used for llama3
3535
static std::unique_ptr<ov::genai::Tokenizer> llama3Tokenizer;
3636

3737
// Id of the <|python_tag|> which is a special token used to indicate the start of a tool calls
@@ -57,8 +57,8 @@ class Llama3OutputParserTest : public ::testing::Test {
5757
}
5858

5959
void SetUp() override {
60-
outputParserWithRegularToolParsing = std::make_unique<OutputParser>(*llama3Tokenizer, "llama3", "", EMPTY_TOOL_SCHEMA);
61-
outputParserWithImmediateToolParsing = std::make_unique<OutputParser>(*llama3Tokenizer, "llama3", "", EMPTY_TOOL_SCHEMA);
60+
outputParserWithRegularToolParsing = std::make_unique<OutputParser>(*llama3Tokenizer, "llama3", "", EMPTY_TOOLS_SCHEMA);
61+
outputParserWithImmediateToolParsing = std::make_unique<OutputParser>(*llama3Tokenizer, "llama3", "", EMPTY_TOOLS_SCHEMA);
6262
outputParserWithImmediateToolParsing->enableImmediateToolParsing();
6363
}
6464
};
@@ -223,7 +223,7 @@ TEST_F(Llama3OutputParserTest, HolisticStreaming) {
223223

224224
for (auto lastFinishReason : {ov::genai::GenerationFinishReason::NONE, ov::genai::GenerationFinishReason::STOP, ov::genai::GenerationFinishReason::LENGTH}) {
225225
// Need to have new output parser per case to simulate separate request processing
226-
outputParserWithRegularToolParsing = std::make_unique<OutputParser>(*llama3Tokenizer, "llama3", "", EMPTY_TOOL_SCHEMA);
226+
outputParserWithRegularToolParsing = std::make_unique<OutputParser>(*llama3Tokenizer, "llama3", "", EMPTY_TOOLS_SCHEMA);
227227
auto chunkToDeltaVecCopy = chunkToDeltaVec;
228228
if (lastFinishReason == ov::genai::GenerationFinishReason::NONE) {
229229
chunkToDeltaVecCopy.push_back({"Paris\"}}", ov::genai::GenerationFinishReason::NONE, R"({"delta":{"tool_calls":[{"index":1,"function":{"arguments":" \""}}]}})"});
@@ -284,6 +284,126 @@ TEST_F(Llama3OutputParserTest, HolisticStreaming) {
284284
}
285285
}
286286

287+
// Positive test for streaming tool calls with complex arguments containing special characters
288+
TEST_F(Llama3OutputParserTest, StreamingToolWithComplexArguments) {
289+
std::vector<std::tuple<std::string, std::optional<std::string>>> chunkToDeltaVec{
290+
{"{\"", std::nullopt},
291+
{"name", std::nullopt},
292+
{"\":", std::nullopt},
293+
{" \"", std::nullopt},
294+
{"python_code", std::nullopt},
295+
{"_", std::nullopt},
296+
{"execution_tool", std::nullopt},
297+
{"\",", std::nullopt},
298+
{" \"", std::nullopt},
299+
{"arguments", std::nullopt},
300+
// As we have 'arguments' key present, we can return first delta
301+
{"\":", "{\"delta\":{\"tool_calls\":[{\"id\":\"XXXXXXXXX\",\"type\":\"function\",\"index\":0,\"function\":{\"name\":\"python_code_execution_tool\"}}]}}"},
302+
// Consecutive deltas without 'id' and 'type'. In order to find the end of arguments parser has one chunk delay to handle end of tool.
303+
{" {", std::nullopt},
304+
{"\"", "{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"{\"}}]}}"},
305+
{"function", "{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"\"}}]}}"},
306+
{"\": ", "{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"function\"}}]}}"},
307+
{"\"", "{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\": \"}}]}}"},
308+
/*
309+
Next chunks will simulate sending piece of Python code as argument value.
310+
```python
311+
def example_function(arg1, arg2):
312+
nested_dict = {"nested_arg1": "nested_value1", "nested_arg2": "nested_value2"}
313+
if arg1 == "value1" and arg2 == "arg2":
314+
return nested_dict
315+
else:
316+
return {}
317+
```
318+
*/
319+
{"def example_function(arg1, arg2):\n", "{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"\"}}]}}"},
320+
321+
{"\tnested_dict = {\"nested_arg1\": \"nested_value1\", \"nested_arg2\": \"nested_value2\"}\n",
322+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"def example_function(arg1, arg2):\\n\"}}]}}"},
323+
{"\tif arg1 == \"value1\" and arg2 == \"arg2\":\n",
324+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\tnested_dict = {\\\"nested_arg1\\\": \\\"nested_value1\\\", \\\"nested_arg2\\\": \\\"nested_value2\\\"}\\n\"}}]}}"},
325+
{"\t\treturn nested_dict\n",
326+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\tif arg1 == \\\"value1\\\" and arg2 == \\\"arg2\\\":\\n\"}}]}}"},
327+
{"\telse:\n",
328+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\t\\treturn nested_dict\\n\"}}]}}"},
329+
{"\t\treturn {}\n",
330+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\telse:\\n\"}}]}}"},
331+
{"nested_arg1",
332+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\t\\treturn {}\\n\"}}]}}"},
333+
{"\": ",
334+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"nested_arg1\"}}]}}"},
335+
{"\"",
336+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\": \"}}]}}"},
337+
{"nested_value1",
338+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"\"}}]}}"},
339+
{"\", ",
340+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"nested_value1\"}}]}}"},
341+
{"\"",
342+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\", \"}}]}}"},
343+
{"nested_arg2",
344+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"\"}}]}}"},
345+
{"\": ",
346+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"nested_arg2\"}}]}}"},
347+
{"\"",
348+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\": \"}}]}}"},
349+
{"nested_value2",
350+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"\\\"\"}}]}}"},
351+
{"\"}}}",
352+
"{\"delta\":{\"tool_calls\":[{\"index\":0,\"function\":{\"arguments\":\"nested_value2\"}}]}}"},
353+
};
354+
355+
auto outputParser = std::make_unique<OutputParser>(*llama3Tokenizer, "llama3", "", EMPTY_TOOLS_SCHEMA);
356+
for (const auto& [chunk, expectedDelta] : chunkToDeltaVec) {
357+
std::optional<rapidjson::Document> doc = outputParser->parseChunk(chunk, true, ov::genai::GenerationFinishReason::NONE);
358+
if (!expectedDelta.has_value() && !doc.has_value()) {
359+
continue; // Both are nullopt, OK
360+
}
361+
if (expectedDelta.has_value() && doc.has_value()) {
362+
rapidjson::StringBuffer buffer;
363+
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
364+
doc->Accept(writer);
365+
std::string docStr = buffer.GetString();
366+
// If both strings contain "id":"...", compare id values by length and alphanumeric, else compare whole strings
367+
std::string expected = expectedDelta.value();
368+
std::string idKey = "\"id\":\"";
369+
auto docIdPos = docStr.find(idKey);
370+
auto expectedIdPos = expected.find(idKey);
371+
if (docIdPos != std::string::npos && expectedIdPos != std::string::npos) {
372+
auto docIdStart = docIdPos + idKey.size();
373+
auto docIdEnd = docStr.find("\"", docIdStart);
374+
auto expectedIdStart = expectedIdPos + idKey.size();
375+
auto expectedIdEnd = expected.find("\"", expectedIdStart);
376+
ASSERT_NE(docIdEnd, std::string::npos);
377+
ASSERT_NE(expectedIdEnd, std::string::npos);
378+
std::string docId = docStr.substr(docIdStart, docIdEnd - docIdStart);
379+
std::string expectedId = expected.substr(expectedIdStart, expectedIdEnd - expectedIdStart);
380+
EXPECT_EQ(docId.size(), expectedId.size()) << "ID length mismatch for chunk: " << chunk;
381+
EXPECT_TRUE(std::all_of(docId.begin(), docId.end(), ::isalnum)) << "ID not alphanumeric for chunk: " << chunk;
382+
// Compare everything except the id value
383+
std::string docStrNoId = docStr;
384+
std::string expectedNoId = expected;
385+
docStrNoId.replace(docIdStart, docId.size(), std::string(docId.size(), '*'));
386+
expectedNoId.replace(expectedIdStart, expectedId.size(), std::string(expectedId.size(), '*'));
387+
EXPECT_EQ(docStrNoId, expectedNoId) << "Mismatch for chunk (ignoring id value): " << chunk;
388+
} else {
389+
EXPECT_EQ(docStr, expected) << "Mismatch for chunk: " << chunk << " Received: " << docStr << ", expected: " << expected;
390+
}
391+
} else {
392+
std::string expectedStr = expectedDelta.has_value() ? expectedDelta.value() : "std::nullopt";
393+
std::string docStr = doc.has_value() ? [&]() {
394+
rapidjson::StringBuffer buffer;
395+
rapidjson::Writer<rapidjson::StringBuffer> writer(buffer);
396+
doc->Accept(writer);
397+
return std::string(buffer.GetString());
398+
}()
399+
: "std::nullopt";
400+
FAIL() << "Mismatch between expectedDelta and doc for chunk: " << chunk
401+
<< "\nexpectedDelta: " << expectedStr
402+
<< "\ndoc: " << docStr;
403+
}
404+
}
405+
}
406+
287407
TEST_F(Llama3OutputParserTest, ToolCallsWithoutToolsInTheRequestStreaming) {
288408
std::vector<std::pair<std::string, std::optional<std::string>>> chunkToDeltaVec{
289409
// Tool parser is available, but tools are not in the request so every chunk is just a regular content

0 commit comments

Comments
 (0)