Skip to content

Commit

Permalink
Python: Improve hashing of cmc and scmc items. Add tests. (#10332)
Browse files Browse the repository at this point in the history
### Motivation and Context

We have some content types that can't be hashed as-is. This PR makes
sure that any unhashable types are converted into hashable types
before hashing. It also adds unit tests to exercise these scenarios.

<!-- Thank you for your contribution to the semantic-kernel repo!
Please help reviewers and future users, providing the following
information:
  1. Why is this change required?
  2. What problem does it solve?
  3. What scenario does it contribute to?
  4. If it fixes an open issue, please link to the issue here.
-->

### Description

<!-- Describe your changes, the overall approach, the underlying design.
These notes will help understanding how your code works. Thanks! -->

### Contribution Checklist

<!-- Before submitting this PR, please make sure: -->

- [X] The code builds clean without any errors or warnings
- [X] The PR follows the [SK Contribution
Guidelines](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md)
and the [pre-submission formatting
script](https://github.com/microsoft/semantic-kernel/blob/main/CONTRIBUTING.md#development-scripts)
raises no violations
- [X] All unit tests pass, and I have added new tests where possible
- [X] I didn't break anyone 😄
  • Loading branch information
moonbox3 authored Jan 29, 2025
1 parent ec9b980 commit 44034eb
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 2 deletions.
4 changes: 3 additions & 1 deletion python/semantic_kernel/contents/chat_message_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from semantic_kernel.contents.text_content import TextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.contents.utils.hashing import make_hashable
from semantic_kernel.exceptions.content_exceptions import ContentInitializationError

TAG_CONTENT_MAP = {
Expand Down Expand Up @@ -315,4 +316,5 @@ def _parse_items(self) -> str | list[dict[str, Any]]:

def __hash__(self) -> int:
"""Return the hash of the chat message content."""
# NOTE(review): this span is a diff rendering, so it shows two versions of the body.
# The next line is the pre-change return (the hunk's single deletion): it splats the
# raw items straight into the hashed tuple.
return hash((self.tag, self.role, self.content, self.encoding, self.finish_reason, *self.items))
# The two lines below are the post-change body: every item is first passed through
# make_hashable, and a falsy `self.items` contributes no extra tuple members.
hashable_items = [make_hashable(item) for item in self.items] if self.items else []
return hash((self.tag, self.role, self.content, self.encoding, self.finish_reason, *hashable_items))
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from semantic_kernel.contents.streaming_text_content import StreamingTextContent
from semantic_kernel.contents.utils.author_role import AuthorRole
from semantic_kernel.contents.utils.finish_reason import FinishReason
from semantic_kernel.contents.utils.hashing import make_hashable
from semantic_kernel.exceptions import ContentAdditionException

ITEM_TYPES = Union[
Expand Down Expand Up @@ -222,6 +223,7 @@ def to_element(self) -> "Element":

def __hash__(self) -> int:
"""Return the hash of the streaming chat message content."""
hashable_items = [make_hashable(item) for item in self.items] if self.items else []
return hash((
self.tag,
self.role,
Expand All @@ -230,5 +232,5 @@ def __hash__(self) -> int:
self.finish_reason,
self.choice_index,
self.function_invoke_attempt,
*self.items,
*hashable_items,
))
35 changes: 35 additions & 0 deletions python/tests/unit/contents/test_chat_message_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from defusedxml.ElementTree import XML

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.file_reference_content import FileReferenceContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.image_content import ImageContent
Expand Down Expand Up @@ -380,3 +381,37 @@ def test_cmc_to_dict_keys():
def test_cmc_to_dict_items(input_args, expected_dict):
message = ChatMessageContent(**input_args)
assert message.to_dict() == expected_dict


def test_cmc_with_unhashable_types_can_hash():
    """Check that ChatMessageContent can be hashed whatever items it carries.

    The messages cover text plus image items, text plus a file reference
    item, and a message constructed without any items.
    """
    user_messages = [
        ChatMessageContent(
            role=AuthorRole.USER,
            items=[
                TextContent(text="Describe this image."),
                ImageContent(
                    uri="https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/New_york_times_square-terabass.jpg/1200px-New_york_times_square-terabass.jpg"
                ),
            ],
        ),
        ChatMessageContent(
            role=AuthorRole.USER,
            items=[
                TextContent(text="What is the main color in this image?"),
                ImageContent(uri="https://upload.wikimedia.org/wikipedia/commons/5/56/White_shark.jpg"),
            ],
        ),
        ChatMessageContent(
            role=AuthorRole.USER,
            items=[
                TextContent(text="Is there an animal in this image?"),
                FileReferenceContent(file_id="test_file_id"),
            ],
        ),
        ChatMessageContent(
            role=AuthorRole.USER,
        ),
    ]

    for message in user_messages:
        # hash() either returns an int or raises TypeError, so an
        # `is not None` check could never fail. Getting a value back proves
        # the message is hashable; additionally require that value to be
        # stable across repeated calls on the same object.
        first_hash = hash(message)
        assert isinstance(first_hash, int)
        assert hash(message) == first_hash
40 changes: 40 additions & 0 deletions python/tests/unit/contents/test_streaming_chat_message_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from defusedxml.ElementTree import XML

from semantic_kernel.contents.chat_message_content import ChatMessageContent
from semantic_kernel.contents.file_reference_content import FileReferenceContent
from semantic_kernel.contents.function_call_content import FunctionCallContent
from semantic_kernel.contents.function_result_content import FunctionResultContent
from semantic_kernel.contents.image_content import ImageContent
from semantic_kernel.contents.streaming_chat_message_content import StreamingChatMessageContent
from semantic_kernel.contents.streaming_text_content import StreamingTextContent
from semantic_kernel.contents.text_content import TextContent
Expand Down Expand Up @@ -410,3 +412,41 @@ def test_scmc_bytes():
message = StreamingChatMessageContent(choice_index=0, role=AuthorRole.USER, content="Hello, world!")
assert bytes(message) == b"Hello, world!"
assert bytes(message.items[0]) == b"Hello, world!"


def test_scmc_with_unhashable_types_can_hash():
    """Check that StreamingChatMessageContent can be hashed whatever items it carries.

    The messages cover streaming text plus image items, streaming text plus
    a file reference item, and a message constructed without any items.
    """
    user_messages = [
        StreamingChatMessageContent(
            role=AuthorRole.USER,
            items=[
                StreamingTextContent(text="Describe this image.", choice_index=0),
                ImageContent(
                    uri="https://upload.wikimedia.org/wikipedia/commons/thumb/4/47/New_york_times_square-terabass.jpg/1200px-New_york_times_square-terabass.jpg"
                ),
            ],
            choice_index=0,
        ),
        StreamingChatMessageContent(
            role=AuthorRole.USER,
            items=[
                StreamingTextContent(text="What is the main color in this image?", choice_index=0),
                ImageContent(uri="https://upload.wikimedia.org/wikipedia/commons/5/56/White_shark.jpg"),
            ],
            choice_index=0,
        ),
        StreamingChatMessageContent(
            role=AuthorRole.USER,
            items=[
                StreamingTextContent(text="Is there an animal in this image?", choice_index=0),
                FileReferenceContent(file_id="test_file_id"),
            ],
            choice_index=0,
        ),
        StreamingChatMessageContent(
            role=AuthorRole.USER,
            choice_index=0,
        ),
    ]

    for message in user_messages:
        # hash() either returns an int or raises TypeError, so an
        # `is not None` check could never fail. Getting a value back proves
        # the message is hashable; additionally require that value to be
        # stable across repeated calls on the same object.
        first_hash = hash(message)
        assert isinstance(first_hash, int)
        assert hash(message) == first_hash

0 comments on commit 44034eb

Please sign in to comment.