⚡️ Speed up function _mark_int64_fields by 12%
#127
+35
−23
📄 12% (0.12x) speedup for `_mark_int64_fields` in `mlflow/utils/proto_json_utils.py`
⏱️ Runtime: 812 microseconds → 728 microseconds (best of 24 runs)
📝 Explanation and details
The optimized code achieves its roughly 12% speedup through several targeted micro-optimizations that reduce Python overhead in hot loops.

Key Performance Improvements:
- Cached attribute lookups: The optimized version stores `FieldDescriptor.TYPE_MESSAGE`, `_PROTOBUF_INT64_FIELDS`, and `_mark_int64_fields` in local variables at function start. This eliminates repeated attribute and global lookups inside loops, which matters because the line profiler shows these functions are called thousands of times (1049+ hits for `_mark_int64_fields`).
- Dictionary comprehensions over explicit loops: In `_mark_int64_fields_for_proto_maps`, the optimized version replaces the explicit for-loop with dictionary comprehensions specialized for each case (message values, int64 values, integer keys). This reduces per-iteration overhead and leverages Python's optimized comprehension implementation.
- Eliminated `partial` function overhead: The original code used `partial(_mark_int64_fields)`, which allocates a `functools.partial` wrapper object. The optimized version directly references the cached function, avoiding this allocation.
- Reduced repeated field lookups: The optimization extracts `field.message_type.fields_by_name["value"].type` into a local variable so that the dictionary lookup does not happen inside the recursive call.
- Specialized repeated field handling: For repeated int64 fields, the optimized version uses a direct `[int(v) for v in value]` comprehension instead of the more generic `[ftype(v) for v in value]`, eliminating the indirect call through `ftype`.

Performance Impact by Test Case:
`test_large_map_field_with_int64_values` sees a 39% speedup, and `test_large_message_with_mixed_fields` sees a 23% speedup. Not every case improves: in the generated tests below, `test_empty_message` runs about 18-20% slower and `test_large_repeated_int64_field` 14% slower.

These optimizations are particularly effective for protobuf processing workloads involving large messages with many int64 fields or large maps, which are common in MLflow's model tracking and experiment logging scenarios.
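The caching and comprehension changes described above can be illustrated with a minimal sketch. It is not MLflow's code: `Descriptor`, `mark_map_loop`, `mark_map_fast`, `convert`, and the type codes are hypothetical stand-ins chosen only to show the pattern of hoisting a lookup out of a loop and picking the branch once per map.

```python
from functools import partial


class Descriptor:
    """Hypothetical stand-in for a class holding type constants."""

    INT64_TYPES = frozenset({3, 4})  # illustrative type codes only


def mark_map_loop(proto_map, value_type):
    # Baseline: the attribute lookup and the type check run per entry.
    result = {}
    for key, value in proto_map.items():
        if value_type in Descriptor.INT64_TYPES:
            result[key] = int(value)
        else:
            result[key] = value
    return result


def mark_map_fast(proto_map, value_type):
    # Look the constant up once and bind it to a local name.
    int64_types = Descriptor.INT64_TYPES
    # Decide the branch once, then use a comprehension specialized for it.
    if value_type in int64_types:
        return {key: int(value) for key, value in proto_map.items()}
    return dict(proto_map)


def convert(value):
    return int(value)


# partial() with no bound arguments behaves like the function itself,
# but allocates an extra wrapper object; a plain reference avoids that.
wrapped = partial(convert)
direct = convert
```

Both map functions return equal dictionaries for the same input; the second simply does less work per entry. Whether the difference is measurable depends on map size, as the per-test timings in this report suggest.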
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
from functools import partial
from types import SimpleNamespace

# imports
import pytest
from mlflow.utils.proto_json_utils import _mark_int64_fields


class FieldDescriptor:
    # Fake protobuf field types
    TYPE_INT64 = 1
    TYPE_UINT64 = 2
    TYPE_FIXED64 = 3
    TYPE_SFIXED64 = 4
    TYPE_SINT64 = 5
    TYPE_MESSAGE = 6
    LABEL_REPEATED = 3


class FakeMessageType:
    def __init__(self, has_options=False, map_entry=False, fields_by_name=None):
        self.has_options = has_options
        self._map_entry = map_entry
        self.fields_by_name = fields_by_name or {}


# --- End: function to test ---

# --- Begin: helper classes for testing ---


class FakeProtoMessage:
    """Fake proto message for testing."""

    def __init__(self, fields):
        # fields: list of (FieldDescriptor, value)
        self._fields = fields


# --- End: helper classes for testing ---
# --- Begin: Basic Test Cases ---


def test_single_int64_field():
    # Test a message with a single int64 field
    fd = FieldDescriptor("foo", FieldDescriptor.TYPE_INT64)
    msg = FakeProtoMessage([(fd, 1234567890123456789)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.95μs -> 1.75μs (11.9% faster)


def test_multiple_int64_fields():
    # Test a message with multiple int64 fields
    fd1 = FieldDescriptor("bar", FieldDescriptor.TYPE_INT64)
    fd2 = FieldDescriptor("baz", FieldDescriptor.TYPE_UINT64)
    msg = FakeProtoMessage([(fd1, 42), (fd2, 99)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.03μs -> 2.01μs (1.14% faster)


def test_non_int64_field_skipped():
    # Test that non-int64 fields are skipped
    fd1 = FieldDescriptor("foo", FieldDescriptor.TYPE_INT64)
    fd2 = FieldDescriptor("not_int64", 999)  # Not an int64 type
    msg = FakeProtoMessage([(fd1, 100), (fd2, "should be skipped")])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.10μs -> 1.88μs (12.0% faster)


def test_repeated_int64_field():
    # Test a repeated int64 field
    fd = FieldDescriptor("nums", FieldDescriptor.TYPE_INT64, label=FieldDescriptor.LABEL_REPEATED)
    msg = FakeProtoMessage([(fd, [1, 2, 3, 4])])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.80μs -> 1.76μs (2.45% faster)


def test_repeated_int64_field_with_is_repeated():
    # Test a repeated int64 field using is_repeated property
    fd = FieldDescriptor("nums", FieldDescriptor.TYPE_INT64, is_repeated=True)
    msg = FakeProtoMessage([(fd, [10, 20, 30])])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.72μs -> 1.75μs (1.89% slower)


def test_empty_message():
    # Empty message should return empty dict
    msg = FakeProtoMessage([])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.04μs -> 1.31μs (20.2% slower)


def test_message_with_all_non_int64_fields():
    # All fields are non-int64, result should be empty
    fd1 = FieldDescriptor("foo", 999)
    fd2 = FieldDescriptor("bar", 888)
    msg = FakeProtoMessage([(fd1, "a"), (fd2, "b")])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.23μs -> 1.93μs (15.5% faster)


def test_int64_field_with_zero():
    # Int64 field with value zero
    fd = FieldDescriptor("zero", FieldDescriptor.TYPE_INT64)
    msg = FakeProtoMessage([(fd, 0)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.73μs -> 1.69μs (2.19% faster)


def test_int64_field_with_negative():
    # Int64 field with negative value
    fd = FieldDescriptor("neg", FieldDescriptor.TYPE_INT64)
    msg = FakeProtoMessage([(fd, -1234567890123456789)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.77μs -> 1.55μs (14.5% faster)


def test_repeated_field_empty():
    # Repeated int64 field with empty list
    fd = FieldDescriptor("nums", FieldDescriptor.TYPE_INT64, label=FieldDescriptor.LABEL_REPEATED)
    msg = FakeProtoMessage([(fd, [])])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.73μs -> 1.71μs (1.41% faster)


def test_large_repeated_int64_field():
    # Large repeated int64 field (1000 elements)
    nums = list(range(1000))
    fd = FieldDescriptor("many_nums", FieldDescriptor.TYPE_INT64, label=FieldDescriptor.LABEL_REPEATED)
    msg = FakeProtoMessage([(fd, nums)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.15μs -> 1.82μs (18.1% faster)
#------------------------------------------------
from functools import partial

# imports
import pytest  # used for our unit tests
from mlflow.utils.proto_json_utils import _mark_int64_fields


# Simulate google.protobuf.descriptor.FieldDescriptor for testing
class FieldDescriptor:
    # Field types
    TYPE_INT64 = 3
    TYPE_UINT64 = 4
    TYPE_FIXED64 = 5
    TYPE_SFIXED64 = 6
    TYPE_SINT64 = 7
    TYPE_MESSAGE = 11
    # Field labels
    LABEL_REPEATED = 10


class MessageType:
    def __init__(self, has_options=False, map_entry=False, fields_by_name=None):
        self.has_options = has_options
        self._map_entry = map_entry
        self.fields_by_name = fields_by_name or {}


# Simulate a proto message
class ProtoMessage:
    def __init__(self, fields):
        # fields: List of tuples (FieldDescriptor, value)
        self._fields = fields


# Simulate a proto map (as dict)
class ProtoMap(dict):
    pass


# --- End function to test ---

# --- Begin unit tests ---
# Basic Test Cases


def test_single_int64_field():
    # Test with a single int64 field
    fd = FieldDescriptor("my_int", FieldDescriptor.TYPE_INT64)
    msg = ProtoMessage([(fd, 42)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.72μs -> 2.21μs (23.3% faster)


def test_single_uint64_field():
    fd = FieldDescriptor("my_uint", FieldDescriptor.TYPE_UINT64)
    msg = ProtoMessage([(fd, 12345678901234567890)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.33μs -> 2.17μs (7.13% faster)


def test_multiple_int64_fields():
    fd1 = FieldDescriptor("a", FieldDescriptor.TYPE_INT64)
    fd2 = FieldDescriptor("b", FieldDescriptor.TYPE_SINT64)
    msg = ProtoMessage([(fd1, -1), (fd2, 99)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.56μs -> 2.51μs (2.15% faster)


def test_non_int64_field_skipped():
    fd = FieldDescriptor("my_float", 1)  # Not an int64 type
    msg = ProtoMessage([(fd, 3.14)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.93μs -> 1.79μs (8.23% faster)


def test_repeated_int64_field():
    fd = FieldDescriptor("ints", FieldDescriptor.TYPE_INT64, is_repeated=True)
    msg = ProtoMessage([(fd, [1, 2, 3])])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 3.06μs -> 3.09μs (1.04% slower)


def test_empty_message():
    msg = ProtoMessage([])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 1.11μs -> 1.36μs (18.5% slower)


def test_message_with_all_non_int64_fields():
    fd1 = FieldDescriptor("f1", 1)
    fd2 = FieldDescriptor("f2", 2)
    msg = ProtoMessage([(fd1, 1.1), (fd2, "abc")])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.30μs -> 2.08μs (10.7% faster)


def test_message_with_mixed_fields():
    fd1 = FieldDescriptor("f1", FieldDescriptor.TYPE_INT64)
    fd2 = FieldDescriptor("f2", 1)
    fd3 = FieldDescriptor("f3", FieldDescriptor.TYPE_UINT64)
    msg = ProtoMessage([(fd1, 100), (fd2, "skip"), (fd3, 200)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.94μs -> 2.79μs (5.19% faster)


def test_message_with_none_value():
    fd = FieldDescriptor("my_int", FieldDescriptor.TYPE_INT64)
    msg = ProtoMessage([(fd, None)])
    # Should raise TypeError when trying to convert None to int
    with pytest.raises(TypeError):
        _mark_int64_fields(msg)  # 3.46μs -> 3.35μs (3.07% faster)


def test_message_with_large_int64():
    fd = FieldDescriptor("my_int", FieldDescriptor.TYPE_INT64)
    big_val = 2**63 - 1
    msg = ProtoMessage([(fd, big_val)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.28μs -> 2.17μs (5.03% faster)


def test_message_with_negative_int64():
    fd = FieldDescriptor("my_int", FieldDescriptor.TYPE_INT64)
    msg = ProtoMessage([(fd, -2**63)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 2.07μs -> 2.03μs (1.92% faster)
def test_map_field_with_int64_value():
    # Simulate a proto map field
    value_fd = FieldDescriptor("value", FieldDescriptor.TYPE_INT64)
    map_type = MessageType(
        has_options=True,
        map_entry=True,
        fields_by_name={"value": value_fd},
    )
    fd_map = FieldDescriptor("my_map", FieldDescriptor.TYPE_MESSAGE, message_type=map_type)
    proto_map = ProtoMap({"key1": 123, "key2": 456})
    msg = ProtoMessage([(fd_map, proto_map)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 18.3μs -> 19.5μs (5.80% slower)


def test_map_field_with_message_value():
    # Simulate a proto map field with message values
    value_fd = FieldDescriptor("value", FieldDescriptor.TYPE_MESSAGE)
    map_type = MessageType(
        has_options=True,
        map_entry=True,
        fields_by_name={"value": value_fd},
    )
    fd_map = FieldDescriptor("my_map", FieldDescriptor.TYPE_MESSAGE, message_type=map_type)
    fd_inner = FieldDescriptor("val", FieldDescriptor.TYPE_INT64)
    msg_inner1 = ProtoMessage([(fd_inner, 1)])
    msg_inner2 = ProtoMessage([(fd_inner, 2)])
    proto_map = ProtoMap({"k1": msg_inner1, "k2": msg_inner2})
    msg = ProtoMessage([(fd_map, proto_map)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 17.7μs -> 18.3μs (3.19% slower)


def test_map_field_with_non_int_key():
    # Map with non-int key, non-int64 value
    value_fd = FieldDescriptor("value", 1)
    map_type = MessageType(
        has_options=True,
        map_entry=True,
        fields_by_name={"value": value_fd},
    )
    fd_map = FieldDescriptor("my_map", FieldDescriptor.TYPE_MESSAGE, message_type=map_type)
    proto_map = ProtoMap({"foo": "bar", "baz": "qux"})
    msg = ProtoMessage([(fd_map, proto_map)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 16.1μs -> 16.6μs (3.09% slower)


def test_map_field_with_int_key_and_non_int64_value():
    value_fd = FieldDescriptor("value", 1)
    map_type = MessageType(
        has_options=True,
        map_entry=True,
        fields_by_name={"value": value_fd},
    )
    fd_map = FieldDescriptor("my_map", FieldDescriptor.TYPE_MESSAGE, message_type=map_type)
    proto_map = ProtoMap({1: "bar", 2: "qux"})
    msg = ProtoMessage([(fd_map, proto_map)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 15.8μs -> 16.7μs (5.23% slower)


def test_map_field_with_empty_map():
    value_fd = FieldDescriptor("value", FieldDescriptor.TYPE_INT64)
    map_type = MessageType(
        has_options=True,
        map_entry=True,
        fields_by_name={"value": value_fd},
    )
    fd_map = FieldDescriptor("my_map", FieldDescriptor.TYPE_MESSAGE, message_type=map_type)
    proto_map = ProtoMap({})
    msg = ProtoMessage([(fd_map, proto_map)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 15.2μs -> 15.5μs (1.38% slower)
# Large Scale Test Cases


def test_large_repeated_int64_field():
    fd = FieldDescriptor("ints", FieldDescriptor.TYPE_INT64, is_repeated=True)
    vals = list(range(1000))
    msg = ProtoMessage([(fd, vals)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 33.5μs -> 39.0μs (14.0% slower)


def test_large_map_field_with_int64_values():
    value_fd = FieldDescriptor("value", FieldDescriptor.TYPE_INT64)
    map_type = MessageType(
        has_options=True,
        map_entry=True,
        fields_by_name={"value": value_fd},
    )
    fd_map = FieldDescriptor("big_map", FieldDescriptor.TYPE_MESSAGE, message_type=map_type)
    proto_map = ProtoMap({str(i): i for i in range(1000)})
    msg = ProtoMessage([(fd_map, proto_map)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 114μs -> 82.1μs (39.1% faster)


def test_large_map_field_with_message_values():
    value_fd = FieldDescriptor("value", FieldDescriptor.TYPE_MESSAGE)
    map_type = MessageType(
        has_options=True,
        map_entry=True,
        fields_by_name={"value": value_fd},
    )
    fd_map = FieldDescriptor("big_map", FieldDescriptor.TYPE_MESSAGE, message_type=map_type)
    fd_inner = FieldDescriptor("val", FieldDescriptor.TYPE_INT64)
    proto_map = ProtoMap({str(i): ProtoMessage([(fd_inner, i)]) for i in range(1000)})
    msg = ProtoMessage([(fd_map, proto_map)])
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 394μs -> 361μs (8.92% faster)


def test_large_message_with_mixed_fields():
    # Only int64 fields should be kept
    fields = []
    for i in range(500):
        fd_int = FieldDescriptor(f"int_{i}", FieldDescriptor.TYPE_INT64)
        fd_float = FieldDescriptor(f"float_{i}", 1)
        fields.append((fd_int, i))
        fields.append((fd_float, float(i)))
    msg = ProtoMessage(fields)
    codeflash_output = _mark_int64_fields(msg); result = codeflash_output  # 139μs -> 113μs (22.8% faster)
    expected = {f"int_{i}": i for i in range(500)}
codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
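As a rough illustration of that equivalence check, here is a minimal sketch of a harness that compares outputs before timing. It is an assumption about how such a check could be written, not Codeflash's actual implementation; `original`, `optimized`, `best_of`, and `verify_and_time` are hypothetical names.

```python
import timeit


def original(values):
    # Baseline: converts through a generic function reference.
    ftype = int
    return [ftype(v) for v in values]


def optimized(values):
    # Candidate: calls int directly.
    return [int(v) for v in values]


def best_of(func, arg, runs=24):
    # Smallest wall-clock time in seconds over `runs` single calls.
    return min(timeit.repeat(lambda: func(arg), number=1, repeat=runs))


def verify_and_time(arg):
    baseline_output = original(arg)
    candidate_output = optimized(arg)
    # Outputs must match before any timing comparison is meaningful.
    assert candidate_output == baseline_output
    return best_of(original, arg), best_of(optimized, arg)
```

Taking the minimum over repeated runs reduces noise from other processes, which is one plausible reading of the "best of 24 runs" figure reported above.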
To edit these changes, run `git checkout codeflash/optimize-_mark_int64_fields-mhuhgoko` and push.