
[Bugfix] Sanitize malformed tool call recipients in Harmony parser#31677

Open

eous wants to merge 1 commit into vllm-project:main from eous:fix/harmony-malformed-recipient-sanitization

Conversation

@eous commented Jan 4, 2026

Some GPT-OSS base models occasionally generate malformed Harmony format sequences like `to=functions.bash<|channel|>commentary` instead of the correct `to=functions.bash <|constrain|>json`. This causes the function name to be parsed incorrectly as `bash<|channel|>commentary` instead of `bash`.

This fix sanitizes the recipient string by stripping `<|channel|>` and everything after it before extracting the function name. The fix is applied in three locations to cover all API endpoints:

  • harmony_utils.py: /v1/responses (non-streaming)
  • openai_tool_parser.py: /v1/chat/completions (non-streaming)
  • serving_chat_stream_harmony.py: /v1/chat/completions (streaming)

The /v1/responses streaming endpoint already worked correctly because it captures the recipient before malformed tokens can corrupt it.

  • Before: ~35% failure rate
  • After: 0% failure rate
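
As a rough illustration of the fix described above (a hedged sketch; the names are illustrative, not the actual vLLM code), the sanitization amounts to:

```python
# Illustrative sketch of the sanitization described in this PR; not the
# actual vLLM code. The malformed recipient mirrors the PR's example.
def sanitize_recipient(recipient: str) -> str:
    # Strip "<|channel|>" and everything after it, then trim whitespace.
    if "<|channel|>" in recipient:
        recipient = recipient.split("<|channel|>")[0].strip()
    return recipient

# The function name is the part after the "functions." namespace prefix.
malformed = "functions.bash<|channel|>commentary"
name = sanitize_recipient(malformed).split(".", 1)[1]
print(name)  # bash
```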

Purpose

Improve tool usage success rate with gpt-oss-20b

Test Plan

Test Result


Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@gemini-code-assist bot (Contributor) left a comment
Code Review

This pull request addresses a bug where malformed tool call recipients from certain models were causing parsing failures. The fix involves sanitizing the recipient string in three different locations to handle these malformed sequences. The changes are correct and a test case has been added to verify the fix. My main feedback is to refactor the duplicated sanitization logic into a single shared utility function to improve code maintainability. I've left comments in the relevant files with suggestions on how to achieve this.

Comment on lines +538 to +542
# Sanitize recipient: the model sometimes outputs malformed sequences
# like "to=functions.bash<|channel|>commentary" instead of the correct
# "to=functions.bash <|constrain|>json". Strip the malformed part.
if "<|channel|>" in recipient:
    recipient = recipient.split("<|channel|>")[0].strip()
Severity: high

This sanitization logic is duplicated in vllm/tool_parsers/openai_tool_parser.py and vllm/entrypoints/openai/serving_chat_stream_harmony.py. To improve maintainability and avoid code duplication, consider extracting this logic into a shared utility function. You could add a sanitize_recipient function in this file and then use it in all three places.

For example, you could add:

def sanitize_recipient(recipient: str) -> str:
    """Sanitizes a malformed tool call recipient by stripping `<|channel|>` and anything after it."""
    if "<|channel|>" in recipient:
        return recipient.split("<|channel|>")[0].strip()
    return recipient

And then replace this block with recipient = sanitize_recipient(recipient).
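
As a quick sanity check of the suggested helper (reproduced here for illustration only), both malformed and well-formed recipients would normalize as expected:

```python
def sanitize_recipient(recipient: str) -> str:
    """Sanitize a malformed tool call recipient by stripping
    "<|channel|>" and anything after it."""
    if "<|channel|>" in recipient:
        return recipient.split("<|channel|>")[0].strip()
    return recipient

# Malformed recipient: a channel marker leaked into the recipient string.
assert sanitize_recipient("functions.bash<|channel|>commentary") == "functions.bash"
# Well-formed recipient: returned unchanged.
assert sanitize_recipient("functions.bash") == "functions.bash"
```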

Comment on lines +51 to +56
# Sanitize recipient: the model sometimes outputs malformed sequences
# like "functions.bash<|channel|>commentary" instead of "functions.bash".
# Strip the malformed part.
sanitized_recipient = cur_recipient
if "<|channel|>" in sanitized_recipient:
    sanitized_recipient = sanitized_recipient.split("<|channel|>")[0].strip()
Severity: high

This sanitization logic is duplicated in other files. To improve maintainability, it should be centralized into a single utility function. Please see my comment in vllm/entrypoints/openai/parser/harmony_utils.py for a detailed suggestion.

Comment on lines +68 to +73
# Sanitize recipient: the model sometimes outputs malformed
# sequences like "functions.bash<|channel|>commentary"
# instead of "functions.bash". Strip the malformed part.
recipient = msg.recipient
if "<|channel|>" in recipient:
    recipient = recipient.split("<|channel|>")[0].strip()
Severity: high

This sanitization logic is duplicated in other files. To improve maintainability, it should be centralized into a single utility function. Please see my comment in vllm/entrypoints/openai/parser/harmony_utils.py for a detailed suggestion.


github-actions bot commented Jan 4, 2026

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which covers a small and essential subset of CI tests to quickly catch errors.

You can ask your reviewers to trigger select CI tests on top of fastcheck CI.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either: Add ready label to the PR or enable auto-merge.

If you have any questions, please reach out to us on Slack at https://slack.vllm.ai.

🚀


mergify bot commented Jan 4, 2026

Hi @eous, the pre-commit checks have failed. Please run:

uv pip install pre-commit
pre-commit install
pre-commit run --all-files

Then, commit the changes and push to your branch.

For future commits, pre-commit will run automatically on changed files before each commit.

Tip

Is mypy or markdownlint failing?
mypy and markdownlint are run differently in CI. If the failure is related to either of these checks, please use the following commands to run them locally:
# For mypy (substitute "3.10" with the failing version if needed)
pre-commit run --hook-stage manual mypy-3.10
# For markdownlint
pre-commit run --hook-stage manual markdownlint

Copilot AI (Contributor) left a comment

Copilot encountered an error and was unable to review this pull request. You can try again by re-requesting a review.

@bbrowning (Contributor)

I have seen this in the wild specifically with gpt-oss-20b in multi-turn tool calling scenarios. I don't regularly see it with gpt-oss-120b, but that doesn't mean it cannot happen there.

I agree with the Gemini review bot's suggestion to centralize the recipient sanitization logic instead of duplicating it in three separate places. And, longer-term, this may be something we want to push down into the https://github.com/openai/harmony library instead of doing it directly in vLLM, since ultimately it's that library whose parsing goes wrong here, leaving the channel tokens as part of the recipient.

@mergify mergify bot added the bug Something isn't working label Jan 14, 2026

mergify bot commented Jan 14, 2026

This pull request has merge conflicts that must be resolved before it can be
merged. Please rebase the PR, @eous.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Jan 14, 2026
@thomasgeulen

Is there any possibility of merging this into the main branch? I’m experiencing the same issues and would appreciate being able to remove my workarounds.

@eous (Author) commented Feb 5, 2026

> Is there any possibility of merging this into the main branch? I’m experiencing the same issues and would appreciate being able to remove my workarounds.

Sorry, I gave up on this route and ended up doing a full SFT of gpt-oss-20b to fix the issue and improve tool use in general (https://huggingface.co/eousphoros/persona_theta_20b_131k). If there is still interest in this, I can clean it up and rebase, but really OpenAI should just fix their base model to adhere to their own specification for Harmony.

@bbrowning (Contributor)

I opened openai/harmony#97 to attempt to fix this in the openai/harmony library itself. That would require a new release of that library and a dependency bump to get the fix into vLLM, so it could be reasonable to do something in the interim directly in vLLM.

I regularly see this type of failure with gpt-oss-20b models in deep multi-turn tool calling scenarios. It often doesn't show up until at least 10 turns in, across various tool calls. As far as I can tell so far, the model is being prompted properly and is simply struggling not to emit additional channel markers when constructing recipients.


Labels

bug (Something isn't working), frontend, gpt-oss (Related to GPT-OSS models), needs-rebase, tool-calling

Projects

Status: No status
Status: To Triage

4 participants