Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions libs/core/langchain_core/messages/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
InputTokenDetails,
OutputTokenDetails,
UsageMetadata,
add_ai_message_chunks,
)
from langchain_core.messages.base import (
BaseMessage,
Expand Down Expand Up @@ -112,6 +113,7 @@
"UsageMetadata",
"VideoContentBlock",
"_message_from_dict",
"add_ai_message_chunks",
"convert_to_messages",
"convert_to_openai_data_block",
"convert_to_openai_image_block",
Expand Down Expand Up @@ -184,6 +186,7 @@
"message_chunk_to_message": "utils",
"messages_from_dict": "utils",
"trim_messages": "utils",
"add_ai_message_chunks": "ai",
}


Expand Down
44 changes: 41 additions & 3 deletions libs/core/langchain_core/messages/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Sequence
from typing import Any, Literal, cast, overload

from pydantic import model_validator
from pydantic import Field, model_validator
from typing_extensions import NotRequired, Self, TypedDict, override

from langchain_core.messages import content as types
Expand Down Expand Up @@ -173,25 +173,37 @@ class AIMessage(BaseMessage):
type: Literal["ai"] = "ai"
"""The type of the message (used for deserialization)."""

raw_response: dict[str, Any] | list[dict[str, Any]] | None = Field(
default=None,
exclude=True, # Exclude from serialization by default
)
"""Optional raw model response data, as returned directly by the LLM provider."""

@overload
def __init__(
self,
content: str | list[str | dict],
*,
raw_response: dict[str, Any] | list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> None: ...

@overload
def __init__(
self,
content: str | list[str | dict] | None = None,
*,
content_blocks: list[types.ContentBlock] | None = None,
raw_response: dict[str, Any] | list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> None: ...

def __init__(
self,
content: str | list[str | dict] | None = None,
*,
content_blocks: list[types.ContentBlock] | None = None,
raw_response: dict[str, Any] | list[dict[str, Any]] | None = None,
**kwargs: Any,
) -> None:
"""Initialize an `AIMessage`.
Expand All @@ -200,7 +212,8 @@ def __init__(

Args:
content: The content of the message.
content_blocks: Typed standard content.
content_blocks: Typed standard content blocks.
raw_response: Optional raw model response data from the LLM provider.
**kwargs: Additional arguments to pass to the parent class.
"""
if content_blocks is not None:
Expand All @@ -218,17 +231,24 @@ def __init__(
else:
super().__init__(content=content, **kwargs)

# Store raw_response if provided
if raw_response is not None:
self.raw_response = raw_response

@property
def lc_attributes(self) -> dict:
"""Attributes to be serialized.

Includes all attributes, even if they are derived from other initialization
arguments.
"""
return {
attrs = {
"tool_calls": self.tool_calls,
"invalid_tool_calls": self.invalid_tool_calls,
}
if self.raw_response is not None:
attrs["raw_response"] = self.raw_response
return attrs

@property
def content_blocks(self) -> list[types.ContentBlock]:
Expand Down Expand Up @@ -396,6 +416,12 @@ class AIMessageChunk(AIMessage, BaseMessageChunk):
`tool_call_chunks` in message content will be parsed into `tool_calls`.
"""

raw_response: dict[str, Any] | None = Field(
default=None,
exclude=True, # Exclude from serialization by default
)
"""The raw chunk response from the model, as returned directly by the provider."""

@property
def lc_attributes(self) -> dict:
"""Attributes to be serialized, even if they are derived from other initialization args.""" # noqa: E501
Expand Down Expand Up @@ -671,6 +697,17 @@ def add_ai_message_chunks(
"last" if any(x.chunk_position == "last" for x in [left, *others]) else None
)

# Merge raw_response: collect all non-None raw_response values
raw_responses = [
c.raw_response for c in [left, *others] if c.raw_response is not None
]
# If only one raw_response, use it directly; if multiple, keep as list
merged_raw_response: dict[str, Any] | list[dict[str, Any]] | None = None
if len(raw_responses) == 1:
merged_raw_response = raw_responses[0]
elif len(raw_responses) > 1:
merged_raw_response = raw_responses

return left.__class__(
content=content,
additional_kwargs=additional_kwargs,
Expand All @@ -679,6 +716,7 @@ def add_ai_message_chunks(
usage_metadata=usage_metadata,
id=chunk_id,
chunk_position=chunk_position,
raw_response=merged_raw_response,
)


Expand Down
10 changes: 7 additions & 3 deletions libs/core/langchain_core/messages/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,9 +202,13 @@ def message_chunk_to_message(chunk: BaseMessage) -> BaseMessage:
ignore_keys = ["type"]
if isinstance(chunk, AIMessageChunk):
ignore_keys.extend(["tool_call_chunks", "chunk_position"])
return chunk.__class__.__mro__[1](
**{k: v for k, v in chunk.__dict__.items() if k not in ignore_keys}
)

data = {k: v for k, v in chunk.__dict__.items() if k not in ignore_keys}
# Preserve raw_response if it exists and is not None
if hasattr(chunk, "raw_response") and chunk.raw_response is not None:
data["raw_response"] = chunk.raw_response

return chunk.__class__.__mro__[1](**data)


MessageLikeRepresentation = (
Expand Down
1 change: 1 addition & 0 deletions libs/core/tests/unit_tests/messages/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"VideoContentBlock",
"ReasoningContentBlock",
"RemoveMessage",
"add_ai_message_chunks",
"convert_to_messages",
"ensure_id",
"get_buffer_string",
Expand Down
60 changes: 60 additions & 0 deletions libs/core/tests/unit_tests/prompts/__snapshots__/test_chat.ambr
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,24 @@
'default': None,
'title': 'Name',
}),
'raw_response': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'items': dict({
'type': 'object',
}),
'type': 'array',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Raw Response',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
Expand Down Expand Up @@ -182,6 +200,18 @@
'default': None,
'title': 'Name',
}),
'raw_response': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Raw Response',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
Expand Down Expand Up @@ -1485,6 +1515,24 @@
'default': None,
'title': 'Name',
}),
'raw_response': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'items': dict({
'type': 'object',
}),
'type': 'array',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Raw Response',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
Expand Down Expand Up @@ -1596,6 +1644,18 @@
'default': None,
'title': 'Name',
}),
'raw_response': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Raw Response',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,24 @@
'default': None,
'title': 'Name',
}),
'raw_response': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'items': dict({
'type': 'object',
}),
'type': 'array',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Raw Response',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
Expand Down Expand Up @@ -606,6 +624,18 @@
'default': None,
'title': 'Name',
}),
'raw_response': dict({
'anyOf': list([
dict({
'type': 'object',
}),
dict({
'type': 'null',
}),
]),
'default': None,
'title': 'Raw Response',
}),
'response_metadata': dict({
'title': 'Response Metadata',
'type': 'object',
Expand Down
76 changes: 76 additions & 0 deletions libs/core/tests/unit_tests/test_ai_message_raw_response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from typing import Any

import pytest

from langchain_core.messages import (
AIMessage,
AIMessageChunk,
add_ai_message_chunks,
message_chunk_to_message,
)


@pytest.fixture
def base_chunk() -> AIMessageChunk:
    """Shared chunk fixture carrying a populated ``raw_response`` payload."""
    chunk = AIMessageChunk(content="hello", raw_response={"delta": "hello"})
    return chunk


def test_ai_message_stores_raw_response() -> None:
    """An AIMessage constructed with ``raw_response`` retains it verbatim."""
    payload = {"raw": "ok"}
    message = AIMessage(content="hi", raw_response=payload)
    assert message.raw_response == payload


def test_add_ai_message_chunks_merges_raw_response(base_chunk: AIMessageChunk) -> None:
    """Merging two chunks concatenates content and collects raw_responses into a list."""
    second = AIMessageChunk(content=" world", raw_response={"delta": " world"})

    combined = add_ai_message_chunks(base_chunk, second)

    assert combined.content == "hello world"
    # Two non-None raw_response values are preserved in order as a list.
    assert isinstance(combined.raw_response, list)
    assert combined.raw_response == [{"delta": "hello"}, {"delta": " world"}]


def test_add_ai_message_chunks_handles_missing_raw_response() -> None:
    """A lone non-None raw_response survives merging as a dict, not a one-item list."""
    with_raw = AIMessageChunk(content="foo", raw_response={"delta": "foo"})
    without_raw = AIMessageChunk(content="bar", raw_response=None)

    combined = add_ai_message_chunks(with_raw, without_raw)

    assert combined.raw_response == {"delta": "foo"}


def test_message_chunk_to_message_transfers_raw_response(
    base_chunk: AIMessageChunk,
) -> None:
    """Converting a chunk to a full message carries ``raw_response`` across."""
    converted = message_chunk_to_message(base_chunk)

    assert isinstance(converted, AIMessage)
    assert converted.raw_response == {"delta": "hello"}


def test_message_chunk_to_message_ignores_non_chunk_input() -> None:
    """A non-chunk message is returned unchanged (same object identity)."""
    message = AIMessage(content="hi", raw_response={"data": 1})

    assert message_chunk_to_message(message) is message


def test_empty_raw_response_not_present_in_serialized_message() -> None:
    """``lc_attributes`` omits the raw_response key when it is None."""
    message = AIMessage(content="test", raw_response=None)

    serialized: dict[str, Any] = message.lc_attributes
    assert "raw_response" not in serialized


def test_invalid_input_handling_for_merging_different_types() -> None:
    """Merging a single chunk with no others is a graceful pass-through.

    The docstring of this test previously claimed coverage of the
    "no raw_response" case without asserting it; we now verify that when
    no chunk supplies a ``raw_response``, the merged result has ``None``.
    """
    chunk: AIMessageChunk = AIMessageChunk(content="hello")
    # Merging with an empty ``others`` list must leave content unaffected.
    merged: AIMessageChunk = add_ai_message_chunks(chunk)
    assert merged.content == "hello"
    # No input chunk carried a raw_response, so the merged chunk has none.
    assert merged.raw_response is None
1 change: 1 addition & 0 deletions libs/langchain/tests/unit_tests/chat_models/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def test_configurable() -> None:
"store": None,
"extra_body": None,
"include_response_headers": False,
"include_raw_response": False,
"stream_usage": True,
"use_previous_response_id": False,
"use_responses_api": None,
Expand Down
Loading
Loading