⚡️ Speed up function _mark_int64_fields_for_proto_maps by 17%
#126
📄 17% (0.17x) speedup for _mark_int64_fields_for_proto_maps in mlflow/utils/proto_json_utils.py
⏱️ Runtime: 1.69 milliseconds → 1.44 milliseconds (best of 52 runs)
📝 Explanation and details
The optimized code achieves a 17% speedup through four changes:
1. Set-based lookup optimization: The most important change is converting _PROTOBUF_INT64_FIELDS from a list to a set, enabling O(1) average-case membership testing instead of an O(n) linear search. This speeds up the value_field_type in int64_types checks.
2. Branch prediction optimization in _mark_int64_fields_for_proto_maps: Instead of checking the value type for every map item, the code now branches once outside the loop based on value_field_type, creating three specialized tight loops: one that calls _mark_int64_fields on message values, one that applies int() to int64 values, and one for all other value types.
3. Local variable caching: Constants like FieldDescriptor.TYPE_MESSAGE and frequently accessed attributes are cached in local variables to avoid repeated attribute lookups during iteration.
4. Method binding optimization: In _mark_int64_fields, proto_message.ListFields is bound to a local variable to avoid method resolution overhead in the loop.
The test results show the optimization is most effective for large-scale workloads: in the generated tests, 1,000-entry maps with scalar values run 36.4% to 123% faster and large message-valued maps gain roughly 1% to 8%, while small maps change by at most about 10% in either direction and the empty-map case is 30.4% to 35.8% slower.
These optimizations are particularly valuable since protobuf processing often involves large collections of data, where the cumulative effect of reducing per-iteration overhead becomes significant.
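The first two changes can be illustrated with a minimal sketch. It is not MLflow's actual source: the constants mirror the FieldDescriptor values used in the generated tests, `mark_message` is a hypothetical placeholder for the per-message helper, and the handling of the remaining types (keeping only entries with int keys) is an assumption taken from the comments in those tests.

```python
# Hypothetical stand-ins: these values mirror the FieldDescriptor constants
# used in the generated tests; this is not MLflow's actual source.
TYPE_INT64, TYPE_UINT64, TYPE_FIXED64, TYPE_SFIXED64, TYPE_SINT64 = 3, 4, 5, 6, 7
TYPE_MESSAGE = 11

# A frozenset gives O(1) average-case membership tests, unlike a list.
INT64_TYPES = frozenset(
    {TYPE_INT64, TYPE_UINT64, TYPE_FIXED64, TYPE_SFIXED64, TYPE_SINT64}
)


def mark_message(message):
    """Placeholder for the per-message conversion (assumed, not the real helper)."""
    return {"message": message}


def mark_map(proto_map, value_field_type):
    """Convert a map, choosing the loop once instead of branching per item."""
    if value_field_type == TYPE_MESSAGE:
        return {key: mark_message(value) for key, value in proto_map.items()}
    if value_field_type in INT64_TYPES:
        return {key: int(value) for key, value in proto_map.items()}
    # Assumption from the test comments: for other types only int keys are kept.
    return {key: value for key, value in proto_map.items() if isinstance(key, int)}
```

Because the type check runs once per call rather than once per item, the saving grows with the size of the map, which is consistent with the larger gains reported for 1,000-entry maps.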
✅ Correctness verification report:
🌀 Generated Regression Tests and Runtime
from functools import partial

# imports
import pytest  # used for our unit tests
from mlflow.utils.proto_json_utils import _mark_int64_fields_for_proto_maps


# Simulate google.protobuf.descriptor.FieldDescriptor for testing
class FakeFieldDescriptor:
    TYPE_INT64 = 3
    TYPE_UINT64 = 4
    TYPE_FIXED64 = 5
    TYPE_SFIXED64 = 6
    TYPE_SINT64 = 7
    TYPE_MESSAGE = 11
    LABEL_REPEATED = 3


class FakeMessageType:
    def __init__(self, has_options=False, map_entry=False, fields_by_name=None):
        self.has_options = has_options
        self._map_entry = map_entry
        self.fields_by_name = fields_by_name or {}


# Helper for simulating proto message ListFields()
class FakeProtoMessage:
    def __init__(self, fields):
        # fields: list of (FakeFieldDescriptor, value)
        self._fields = fields


# The actual function under test, copied from above
_PROTOBUF_INT64_FIELDS = [
    FakeFieldDescriptor.TYPE_INT64,
    FakeFieldDescriptor.TYPE_UINT64,
    FakeFieldDescriptor.TYPE_FIXED64,
    FakeFieldDescriptor.TYPE_SFIXED64,
    FakeFieldDescriptor.TYPE_SINT64,
]

from mlflow.utils.proto_json_utils import _mark_int64_fields_for_proto_maps
# ----------------------
# Unit tests start here
# ----------------------

# 1. Basic Test Cases

def test_basic_int64_map():
    # Test a proto map with int64 values
    proto_map = {"a": 123, "b": -456}
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 2.07μs -> 2.12μs (2.36% slower)

def test_basic_uint64_map():
    proto_map = {"x": 2**63, "y": 0}
    value_field_type = FakeFieldDescriptor.TYPE_UINT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 2.09μs -> 1.92μs (9.18% faster)

def test_basic_fixed64_map():
    proto_map = {"foo": 42, "bar": 99}
    value_field_type = FakeFieldDescriptor.TYPE_FIXED64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 1.93μs -> 1.83μs (5.07% faster)

def test_basic_sfixed64_map():
    proto_map = {"neg": -100, "pos": 100}
    value_field_type = FakeFieldDescriptor.TYPE_SFIXED64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 2.00μs -> 1.92μs (3.95% faster)

def test_basic_sint64_map():
    proto_map = {"min": -2**63, "max": 2**63-1}
    value_field_type = FakeFieldDescriptor.TYPE_SINT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 1.98μs -> 1.79μs (10.3% faster)

def test_basic_message_map():
    # Map with message values
    value_field_type = FakeFieldDescriptor.TYPE_MESSAGE
    # Simulate a message with an int64 field
    int64_field = FakeFieldDescriptor("val", FakeFieldDescriptor.TYPE_INT64)
    msg1 = FakeProtoMessage([(int64_field, 10)])
    msg2 = FakeProtoMessage([(int64_field, 20)])
    proto_map = {"first": msg1, "second": msg2}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 3.86μs -> 3.75μs (2.80% faster)

def test_basic_non_int_key_map():
    # If key is int, but value_field_type is not int64, should preserve value
    proto_map = {1: "foo", 2: "bar"}
    value_field_type = 0  # Not an int64 type
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 2.21μs -> 2.01μs (9.97% faster)
# 2. Edge Test Cases

def test_empty_map():
    # Empty proto map should return empty dict
    proto_map = {}
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 972ns -> 1.51μs (35.8% slower)

def test_map_with_mixed_types():
    # Map with int, float, string values, only int keys preserved if not int64
    proto_map = {1: "one", "two": 2, 3: 3.0}
    value_field_type = 0  # Not int64 type
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 2.32μs -> 2.11μs (10.0% faster)

def test_map_with_none_values():
    # Map with None values, should convert to int if int64 type (raises TypeError)
    proto_map = {"a": None}
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    with pytest.raises(TypeError):
        _mark_int64_fields_for_proto_maps(proto_map, value_field_type) # 3.01μs -> 3.05μs (1.31% slower)

def test_message_map_with_no_int64_fields():
    # Message values with only non-int64 fields, should return empty dicts
    field = FakeFieldDescriptor("val", 1)  # Not int64
    msg = FakeProtoMessage([(field, "not_int64")])
    proto_map = {"k": msg}
    value_field_type = FakeFieldDescriptor.TYPE_MESSAGE
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 3.00μs -> 2.98μs (0.637% faster)

def test_message_map_with_repeated_int64_field():
    # Message values with repeated int64 field
    field = FakeFieldDescriptor("numbers", FakeFieldDescriptor.TYPE_INT64, label=FakeFieldDescriptor.LABEL_REPEATED, is_repeated=True)
    msg = FakeProtoMessage([(field, [1, 2, 3])])
    proto_map = {"nums": msg}
    value_field_type = FakeFieldDescriptor.TYPE_MESSAGE
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 3.70μs -> 3.98μs (7.04% slower)

def test_map_with_negative_and_large_ints():
    # Map with negative and large int values
    proto_map = {"neg": -999999999999, "big": 999999999999}
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 2.01μs -> 1.94μs (3.19% faster)

def test_map_with_float_values_and_int64_type():
    # Map with float values and int64 type, should cast to int
    proto_map = {"f": 1.23}
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 1.81μs -> 1.81μs (0.331% slower)

def test_map_with_bool_values_and_int64_type():
    # Map with bool values and int64 type, should cast to int (True->1, False->0)
    proto_map = {"t": True, "f": False}
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 1.96μs -> 2.06μs (5.00% slower)

def test_map_with_unusual_keys():
    # Map with tuple keys, only int keys preserved if not int64
    proto_map = {(1, 2): "tuple", 3: "int"}
    value_field_type = 0
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 2.18μs -> 2.16μs (1.11% faster)

def test_message_map_with_map_field():
    # Message value contains a map field (nested map)
    # Simulate nested map field
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    map_field_desc = FakeFieldDescriptor(
        "nested_map", FakeFieldDescriptor.TYPE_MESSAGE,
        message_type=FakeMessageType(
            has_options=True,
            map_entry=True,
            fields_by_name={"value": FakeFieldDescriptor("value", FakeFieldDescriptor.TYPE_INT64)}
        )
    )
    nested_map = {"x": 1, "y": 2}
    msg = FakeProtoMessage([(map_field_desc, nested_map)])
    proto_map = {"outer": msg}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FakeFieldDescriptor.TYPE_MESSAGE); result = codeflash_output # 16.3μs -> 16.0μs (1.72% faster)

def test_message_map_with_map_field_non_int64():
    # Message value contains a map field with non-int64 type
    map_field_desc = FakeFieldDescriptor(
        "nested_map", FakeFieldDescriptor.TYPE_MESSAGE,
        message_type=FakeMessageType(
            has_options=True,
            map_entry=True,
            fields_by_name={"value": FakeFieldDescriptor("value", 1)}  # Not int64
        )
    )
    nested_map = {"x": "foo", "y": "bar"}
    msg = FakeProtoMessage([(map_field_desc, nested_map)])
    proto_map = {"outer": msg}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FakeFieldDescriptor.TYPE_MESSAGE); result = codeflash_output # 15.6μs -> 15.9μs (2.22% slower)

# 3. Large Scale Test Cases

def test_large_map_int64():
    # Large map with int64 values
    proto_map = {str(i): i for i in range(1000)}
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 86.9μs -> 60.6μs (43.5% faster)

def test_large_map_message_values():
    # Large map with message values, each with int64 field
    int64_field = FakeFieldDescriptor("val", FakeFieldDescriptor.TYPE_INT64)
    proto_map = {str(i): FakeProtoMessage([(int64_field, i)]) for i in range(1000)}
    value_field_type = FakeFieldDescriptor.TYPE_MESSAGE
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 344μs -> 320μs (7.56% faster)

def test_large_map_repeated_int64():
    # Large map with repeated int64 field in message
    field = FakeFieldDescriptor("nums", FakeFieldDescriptor.TYPE_INT64, label=FakeFieldDescriptor.LABEL_REPEATED, is_repeated=True)
    proto_map = {str(i): FakeProtoMessage([(field, [i, i+1, i+2])]) for i in range(1000)}
    value_field_type = FakeFieldDescriptor.TYPE_MESSAGE
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 544μs -> 539μs (0.981% faster)

def test_large_map_non_int_keys():
    # Large map with int keys and non-int64 value type
    proto_map = {i: f"val{i}" for i in range(1000)}
    value_field_type = 0
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 84.9μs -> 44.7μs (89.9% faster)

def test_large_map_with_mixed_types():
    # Large map with mixed value types, only int64 values processed
    proto_map = {str(i): i if i % 2 == 0 else str(i) for i in range(1000)}
    value_field_type = FakeFieldDescriptor.TYPE_INT64
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, value_field_type); result = codeflash_output # 102μs -> 74.9μs (36.4% faster)

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
#------------------------------------------------
from functools import partial

# imports
import pytest
from mlflow.utils.proto_json_utils import _mark_int64_fields_for_proto_maps


# Simulate google.protobuf.descriptor.FieldDescriptor for testing
class FieldDescriptor:
    TYPE_INT64 = 3
    TYPE_UINT64 = 4
    TYPE_FIXED64 = 5
    TYPE_SFIXED64 = 6
    TYPE_SINT64 = 7
    TYPE_MESSAGE = 11
    LABEL_REPEATED = 3


_PROTOBUF_INT64_FIELDS = [
    FieldDescriptor.TYPE_INT64,
    FieldDescriptor.TYPE_UINT64,
    FieldDescriptor.TYPE_FIXED64,
    FieldDescriptor.TYPE_SFIXED64,
    FieldDescriptor.TYPE_SINT64,
]

from mlflow.utils.proto_json_utils import _mark_int64_fields_for_proto_maps


# Helper classes to simulate proto fields and messages for tests
class FakeField:
    def __init__(
        self,
        name,
        type_,
        label=None,
        is_repeated=None,
        message_type=None,
    ):
        self.name = name
        self.type = type_
        self.label = label
        self.is_repeated = is_repeated if is_repeated is not None else (label == FieldDescriptor.LABEL_REPEATED)
        self.message_type = message_type


class FakeProtoMessage:
    def __init__(self, fields):
        # fields: list of (FakeField, value)
        self._fields = fields

    def ListFields(self):
        return self._fields

# --------------------------
# Basic Test Cases
# --------------------------

def test_empty_proto_map_returns_empty_dict():
    # Test with empty proto_map
    codeflash_output = _mark_int64_fields_for_proto_maps({}, FieldDescriptor.TYPE_INT64); result = codeflash_output # 1.00μs -> 1.44μs (30.4% slower)

def test_int64_map_basic():
    # Test with int64 values in proto_map
    proto_map = {"a": 123, "b": -456}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FieldDescriptor.TYPE_INT64); result = codeflash_output # 1.90μs -> 1.85μs (2.64% faster)

def test_uint64_map_basic():
    # Test with uint64 values
    proto_map = {"x": 2**63, "y": 0}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FieldDescriptor.TYPE_UINT64); result = codeflash_output # 1.97μs -> 1.88μs (4.41% faster)

def test_fixed64_map_basic():
    # Test with fixed64 values
    proto_map = {"foo": 42, "bar": 99}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FieldDescriptor.TYPE_FIXED64); result = codeflash_output # 1.92μs -> 1.81μs (6.18% faster)

def test_message_type_map():
    # Test with message type values in proto_map
    # Simulate a message with a single int64 field
    field = FakeField("val", FieldDescriptor.TYPE_INT64)
    msg1 = FakeProtoMessage([(field, 100)])
    msg2 = FakeProtoMessage([(field, 200)])
    proto_map = {"k1": msg1, "k2": msg2}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FieldDescriptor.TYPE_MESSAGE); result = codeflash_output # 3.76μs -> 3.73μs (1.05% faster)

def test_non_int_key_map():
    # Test with non-int value_field_type and int keys
    proto_map = {1: "abc", 2: "def"}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, 9999); result = codeflash_output # 2.13μs -> 1.94μs (9.79% faster)

# --------------------------
# Edge Test Cases
# --------------------------

def test_proto_map_with_mixed_keys():
    # Test with mixed key types
    proto_map = {1: "a", "2": "b", 3.0: "c"}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, 9999); result = codeflash_output # 2.37μs -> 2.27μs (4.22% faster)

def test_proto_map_with_large_int64():
    # Test with very large int64 values
    max_int64 = 2**63 - 1
    min_int64 = -2**63
    proto_map = {"max": max_int64, "min": min_int64}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FieldDescriptor.TYPE_INT64); result = codeflash_output # 2.04μs -> 2.00μs (2.05% faster)

def test_proto_map_with_negative_and_positive_keys():
    # Test with negative and positive int keys
    proto_map = {-1: "neg", 0: "zero", 1: "pos"}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, 9999); result = codeflash_output # 2.32μs -> 2.15μs (7.80% faster)

def test_large_proto_map_int64():
    # Test with a large proto_map of int64 values
    proto_map = {str(i): i for i in range(1000)}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FieldDescriptor.TYPE_INT64); result = codeflash_output # 88.4μs -> 60.5μs (46.1% faster)

def test_large_proto_map_message_type():
    # Test with a large proto_map of message types
    field = FakeField("val", FieldDescriptor.TYPE_INT64)
    proto_map = {str(i): FakeProtoMessage([(field, i)]) for i in range(500)}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FieldDescriptor.TYPE_MESSAGE); result = codeflash_output # 178μs -> 168μs (5.85% faster)

def test_large_proto_map_non_int_keys():
    # Test with a large proto_map with non-int keys
    proto_map = {f"key_{i}": i for i in range(1000)}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, 9999); result = codeflash_output # 76.7μs -> 34.4μs (123% faster)

def test_large_proto_map_all_int_keys():
    # Test with a large proto_map with all int keys and non-int64 type
    proto_map = {i: str(i) for i in range(1000)}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, 9999); result = codeflash_output # 87.5μs -> 46.4μs (88.3% faster)

# --------------------------
# Additional Edge Cases
# --------------------------

def test_proto_map_with_float_and_bool_values():
    # Test with float and bool values
    proto_map = {"float": 1.5, "bool_true": True, "bool_false": False}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, FieldDescriptor.TYPE_INT64); result = codeflash_output # 2.22μs -> 2.25μs (1.29% slower)

def test_proto_map_with_none_key():
    # Test with None as key
    proto_map = {None: 123, 1: 456}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, 9999); result = codeflash_output # 2.25μs -> 2.21μs (1.94% faster)

def test_proto_map_with_empty_string_key():
    # Test with empty string key
    proto_map = {"": 789, 2: 101}
    codeflash_output = _mark_int64_fields_for_proto_maps(proto_map, 9999); result = codeflash_output # 2.18μs -> 2.00μs (9.06% faster)

# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
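
The idea of comparing the original and optimized outputs can be sketched as follows. How Codeflash performs this comparison internally is not shown in this report, so the code below is only an illustration: `original_impl`, `optimized_impl`, and `outputs_match` are hypothetical names, and the two implementations are toy stand-ins that differ only in using a list versus a set for membership checks.

```python
def original_impl(values):
    """Reference version: per-item membership test against a list."""
    allowed = [3, 4, 5, 6, 7]
    return [v for v in values if v in allowed]


def optimized_impl(values):
    """Candidate version: same result, using a set for membership."""
    allowed = {3, 4, 5, 6, 7}
    return [v for v in values if v in allowed]


def outputs_match(inputs):
    """Return True if both implementations agree on every input."""
    return all(original_impl(x) == optimized_impl(x) for x in inputs)


# Run both versions on the same inputs, including an empty edge case.
print(outputs_match([[], [1, 3, 7], list(range(20))]))
```

An optimization would only be considered behavior-preserving on the tested inputs if a check of this kind passes for every case, including edge cases such as empty input.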
To edit these changes, run git checkout codeflash/optimize-_mark_int64_fields_for_proto_maps-mhuh3tqg and push.