Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@
blocks echoed back by Anthropic clients.
"""

import pytest
from pydantic import ValidationError

from vllm.entrypoints.anthropic.protocol import (
AnthropicCountTokensRequest,
AnthropicMessagesRequest,
)
from vllm.entrypoints.anthropic.serving import AnthropicServingMessages
Expand All @@ -30,6 +34,86 @@ def _make_request(
)


# ======================================================================
# Claude Code role normalization
# ======================================================================


class TestClaudeCodeRoleNormalization:
def test_system_role_message_moves_to_system_prompt(self):
request = _make_request(
[
{"role": "system", "content": "You are terse."},
{"role": "user", "content": "Hello"},
]
)

result = _convert(request)

assert result.messages[0] == {
"role": "system",
"content": "You are terse.",
}
assert result.messages[1] == {"role": "user", "content": "Hello"}

def test_system_role_message_merges_after_existing_system(self):
request = _make_request(
[
{"role": "system", "content": "Add this too."},
{"role": "user", "content": "Hello"},
],
system="Keep this. ",
)

result = _convert(request)

assert result.messages[0] == {
"role": "system",
"content": "Keep this. Add this too.",
}

def test_system_role_message_preserves_content_blocks(self):
request = _make_request(
[
{
"role": "system",
"content": [{"type": "text", "text": "Block system."}],
},
{"role": "user", "content": "Hello"},
],
system=[{"type": "text", "text": "Existing system. "}],
)

result = _convert(request)

assert result.messages[0] == {
"role": "system",
"content": "Existing system. Block system.",
}

def test_count_tokens_request_normalizes_claude_code_system_role(self):
request = AnthropicCountTokensRequest(
model="test-model",
system="Keep this. ",
messages=[
{"role": "system", "content": "Add this too."},
{"role": "user", "content": "User request"},
],
)

result = _convert(request)

assert [message.role for message in request.messages] == ["user"]
assert result.messages == [
{"role": "system", "content": "Keep this. Add this too."},
{"role": "user", "content": "User request"},
]

def test_unknown_role_still_fails_validation(self):
with pytest.raises(ValidationError):
_make_request([{"role": "tool", "content": "not accepted"}])


# ======================================================================
# _convert_image_source_to_url
# ======================================================================
Expand Down
79 changes: 79 additions & 0 deletions vllm/entrypoints/anthropic/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,75 @@
from pydantic import BaseModel, Field, field_validator, model_validator


def _content_to_system_blocks(content: Any) -> list[Any]:
if content is None:
return []
if isinstance(content, str):
return [{"type": "text", "text": content}]
if isinstance(content, list):
return content
return [{"type": "text", "text": str(content)}]


def _merge_system_content(
existing_system: Any, system_message_content: list[Any]
) -> str | list[Any] | None:
if not system_message_content:
return existing_system

if existing_system is None and len(system_message_content) == 1:
content = system_message_content[0]
if isinstance(content, str):
return content
if isinstance(content, list):
return content

system_blocks: list[Any] = []
if existing_system is not None:
system_blocks.extend(_content_to_system_blocks(existing_system))
for content in system_message_content:
system_blocks.extend(_content_to_system_blocks(content))
return system_blocks


def _normalize_claude_code_message_roles(data: Any) -> Any:
"""Move Claude Code system messages into the Anthropic system field."""
if not isinstance(data, dict):
return data

messages = data.get("messages")
if not isinstance(messages, list):
return data

normalized_messages: list[Any] = []
system_message_content: list[Any] = []
changed = False

for message in messages:
if not isinstance(message, dict):
normalized_messages.append(message)
continue

role = message.get("role")
if role == "system":
system_message_content.append(message.get("content", ""))
changed = True
continue

normalized_messages.append(message)

if not changed:
return data

normalized_data = dict(data)
normalized_data["messages"] = normalized_messages
if system_message_content:
normalized_data["system"] = _merge_system_content(
data.get("system"), system_message_content
)
return normalized_data


class AnthropicError(BaseModel):
"""Error structure for Anthropic API"""

Expand Down Expand Up @@ -144,6 +213,11 @@ class AnthropicMessagesRequest(BaseModel):
),
)

@model_validator(mode="before")
@classmethod
def normalize_claude_code_message_roles(cls, data: Any) -> Any:
return _normalize_claude_code_message_roles(data)

@field_validator("model")
@classmethod
def validate_model(cls, v):
Expand Down Expand Up @@ -247,6 +321,11 @@ class AnthropicCountTokensRequest(BaseModel):
),
)

@model_validator(mode="before")
@classmethod
def normalize_claude_code_message_roles(cls, data: Any) -> Any:
return _normalize_claude_code_message_roles(data)

@field_validator("model")
@classmethod
def validate_model(cls, v):
Expand Down
Loading