diff --git a/tests/entrypoints/anthropic/test_anthropic_messages_conversion.py b/tests/entrypoints/anthropic/test_anthropic_messages_conversion.py index eb9798980f06..de3dbb50cbe5 100644 --- a/tests/entrypoints/anthropic/test_anthropic_messages_conversion.py +++ b/tests/entrypoints/anthropic/test_anthropic_messages_conversion.py @@ -9,7 +9,11 @@ blocks echoed back by Anthropic clients. """ +import pytest +from pydantic import ValidationError + from vllm.entrypoints.anthropic.protocol import ( + AnthropicCountTokensRequest, AnthropicMessagesRequest, ) from vllm.entrypoints.anthropic.serving import AnthropicServingMessages @@ -30,6 +34,86 @@ def _make_request( ) +# ====================================================================== +# Claude Code role normalization +# ====================================================================== + + +class TestClaudeCodeRoleNormalization: + def test_system_role_message_moves_to_system_prompt(self): + request = _make_request( + [ + {"role": "system", "content": "You are terse."}, + {"role": "user", "content": "Hello"}, + ] + ) + + result = _convert(request) + + assert result.messages[0] == { + "role": "system", + "content": "You are terse.", + } + assert result.messages[1] == {"role": "user", "content": "Hello"} + + def test_system_role_message_merges_after_existing_system(self): + request = _make_request( + [ + {"role": "system", "content": "Add this too."}, + {"role": "user", "content": "Hello"}, + ], + system="Keep this. ", + ) + + result = _convert(request) + + assert result.messages[0] == { + "role": "system", + "content": "Keep this. Add this too.", + } + + def test_system_role_message_preserves_content_blocks(self): + request = _make_request( + [ + { + "role": "system", + "content": [{"type": "text", "text": "Block system."}], + }, + {"role": "user", "content": "Hello"}, + ], + system=[{"type": "text", "text": "Existing system. "}], + ) + + result = _convert(request) + + assert result.messages[0] == { + "role": "system", + "content": "Existing system. Block system.", + } + + def test_count_tokens_request_normalizes_claude_code_system_role(self): + request = AnthropicCountTokensRequest( + model="test-model", + system="Keep this. ", + messages=[ + {"role": "system", "content": "Add this too."}, + {"role": "user", "content": "User request"}, + ], + ) + + result = _convert(request) + + assert [message.role for message in request.messages] == ["user"] + assert result.messages == [ + {"role": "system", "content": "Keep this. Add this too."}, + {"role": "user", "content": "User request"}, + ] + + def test_unknown_role_still_fails_validation(self): + with pytest.raises(ValidationError): + _make_request([{"role": "tool", "content": "not accepted"}]) + + # ====================================================================== # _convert_image_source_to_url # ====================================================================== diff --git a/vllm/entrypoints/anthropic/protocol.py b/vllm/entrypoints/anthropic/protocol.py index 3ebc171173e9..a5d6e0b98586 100644 --- a/vllm/entrypoints/anthropic/protocol.py +++ b/vllm/entrypoints/anthropic/protocol.py @@ -8,6 +8,75 @@ from pydantic import BaseModel, Field, field_validator, model_validator +def _content_to_system_blocks(content: Any) -> list[Any]: + if content is None: + return [] + if isinstance(content, str): + return [{"type": "text", "text": content}] + if isinstance(content, list): + return content + return [{"type": "text", "text": str(content)}] + + +def _merge_system_content( + existing_system: Any, system_message_content: list[Any] +) -> str | list[Any] | None: + if not system_message_content: + return existing_system + + if existing_system is None and len(system_message_content) == 1: + content = system_message_content[0] + if isinstance(content, str): + return content + if isinstance(content, list): + return content + + system_blocks: list[Any] = [] + if existing_system is not None: + system_blocks.extend(_content_to_system_blocks(existing_system)) + for content in system_message_content: + system_blocks.extend(_content_to_system_blocks(content)) + return system_blocks + + +def _normalize_claude_code_message_roles(data: Any) -> Any: + """Move Claude Code system messages into the Anthropic system field.""" + if not isinstance(data, dict): + return data + + messages = data.get("messages") + if not isinstance(messages, list): + return data + + normalized_messages: list[Any] = [] + system_message_content: list[Any] = [] + changed = False + + for message in messages: + if not isinstance(message, dict): + normalized_messages.append(message) + continue + + role = message.get("role") + if role == "system": + system_message_content.append(message.get("content", "")) + changed = True + continue + + normalized_messages.append(message) + + if not changed: + return data + + normalized_data = dict(data) + normalized_data["messages"] = normalized_messages + if system_message_content: + normalized_data["system"] = _merge_system_content( + data.get("system"), system_message_content + ) + return normalized_data + + class AnthropicError(BaseModel): """Error structure for Anthropic API""" @@ -144,6 +213,11 @@ class AnthropicMessagesRequest(BaseModel): ), ) + @model_validator(mode="before") + @classmethod + def normalize_claude_code_message_roles(cls, data: Any) -> Any: + return _normalize_claude_code_message_roles(data) + @field_validator("model") @classmethod def validate_model(cls, v): @@ -247,6 +321,11 @@ class AnthropicCountTokensRequest(BaseModel): ), ) + @model_validator(mode="before") + @classmethod + def normalize_claude_code_message_roles(cls, data: Any) -> Any: + return _normalize_claude_code_message_roles(data) + @field_validator("model") @classmethod def validate_model(cls, v):