diff --git a/ml_metrics/_src/utils/proto_utils.py b/ml_metrics/_src/utils/proto_utils.py index b607c34f..f69537fc 100644 --- a/ml_metrics/_src/utils/proto_utils.py +++ b/ml_metrics/_src/utils/proto_utils.py @@ -12,14 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. """Proto utils.""" + +import collections from collections.abc import Iterable from typing import Any -from absl import logging from ml_metrics._src.tools.telemetry import telemetry -import more_itertools as mit import numpy as np from tensorflow.core.example import example_pb2 + _ExampleOrBytes = bytes | example_pb2.Example @@ -32,61 +33,90 @@ def _maybe_deserialize(ex: _ExampleOrBytes) -> example_pb2.Example: @telemetry.function_monitor(api='ml_metrics', category=telemetry.CATEGORY.UTIL) -def tf_examples_to_dict(examples: Iterable[_ExampleOrBytes] | _ExampleOrBytes): - """Parses a serialized tf.train.Example to a dict.""" +def tf_examples_to_dict( + examples: Iterable[_ExampleOrBytes] | _ExampleOrBytes, +) -> dict[ + str, + list[int | float | bytes] | list[list[int | float | bytes]], +]: + """Parses serialized or unserialized tf.train.Examples to a dict. + + The conversion assumes all examples have the same features. If not, a + ValueError will be raised. + + Args: + examples: A single tf.train.Example, serialized tf.train.Example, or an + iterable of tf.train.Examples and/or serialized tf.train.Examples. + + Returns: + A dict mapping feature names to lists of feature values. + + Raises: + ValueError: If the features are not all present in all examples. + """ + single_example = False if isinstance(examples, (bytes, example_pb2.Example)): single_example = True examples = [examples] - examples = (_maybe_deserialize(ex) for ex in examples) - examples = mit.peekable(examples) - if (head := examples.peek(None)) is None: - return {} - result = {k: [] for k in head.features.feature} + result = collections.defaultdict(list) + for ex in examples: - missing = set(result) - for key, feature in ex.features.feature.items(): - missing.remove(key) - value = getattr(feature, feature.WhichOneof('kind')).value - if value and isinstance(value[0], bytes): - try: - value = [v.decode() for v in value] - except UnicodeDecodeError: - logging.info( - 'chainable: %s', - f'Failed to decode for {key}, forward the raw bytes.', - ) - result[key].extend(value) - if missing: + ex = _maybe_deserialize(ex) + features = dict(ex.features.feature) + + if result and result.keys() != features.keys(): raise ValueError( - f'Missing keys: {missing}, expecting {set(result)}, got {ex=}' + 'All examples must have the same features, got %s and %s' + % (result.keys(), features.keys()) ) - result = {k: v for k, v in result.items()} - # Scalar value in a single example will be returned with the scalar directly. - if single_example and all(len(v) == 1 for v in result.values()): - result = {k: v[0] for k, v in result.items()} + + for name, values in features.items(): + result[name].append(getattr(values, values.WhichOneof('kind')).value) + + if single_example: + return {k: v[0] for k, v in result.items()} return result @telemetry.function_monitor(api='ml_metrics', category=telemetry.CATEGORY.UTIL) def dict_to_tf_example(data: dict[str, Any]) -> example_pb2.Example: """Creates a tf.Example from a dictionary.""" + example = example_pb2.Example() - for key, value in data.items(): - if isinstance(value, (str, bytes, np.floating, float, int, np.integer)): - value = [value] - feature = example.features.feature - if isinstance(value[0], str): - for v in value: - assert isinstance(v, str), f'bad str type: {value}' - feature[key].bytes_list.value.append(v.encode()) - elif isinstance(value[0], bytes): - feature[key].bytes_list.value.extend(value) - elif isinstance(value[0], (int, np.integer)): - feature[key].int64_list.value.extend(value) - elif isinstance(value[0], (float, np.floating)): - feature[key].float_list.value.extend(value) + for key, values in data.items(): + if isinstance(values, (str, bytes, np.floating, float, int, np.integer)): + values = [values] + + if not values: + # Skip empty features. + continue + + if isinstance(values[0], str): + for v in values: + assert isinstance(v, str), f'bad str type: {values}' + example.features.feature[key].bytes_list.value.append(v.encode()) + continue + + if isinstance(values[0], bytes): + feature_kind = 'bytes_list' + elif isinstance(values[0], (float, np.floating)): + feature_kind = 'float_list' + elif isinstance(values[0], (int, np.integer)): + feature_kind = 'int64_list' + for v in values: + if isinstance(v, (float, np.floating)): + # If a float is encountered in the list, we consider the whole feature + # to be a float_list. + feature_kind = 'float_list' + break + elif not isinstance(v, (int, np.integer)): + break else: - raise TypeError(f'Value for "{key}" is not a supported type.') + raise TypeError(f'Values for "{key}" is not a supported type.') + + feature_list = getattr(example.features.feature[key], feature_kind).value + feature_list.extend(values) + return example diff --git a/ml_metrics/_src/utils/proto_utils_test.py b/ml_metrics/_src/utils/proto_utils_test.py index 66b55536..35c31136 100644 --- a/ml_metrics/_src/utils/proto_utils_test.py +++ b/ml_metrics/_src/utils/proto_utils_test.py @@ -11,91 +11,365 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from google.protobuf import text_format from ml_metrics._src.utils import proto_utils from ml_metrics._src.utils import test_utils -import numpy as np +import tensorflow as tf from absl.testing import absltest from absl.testing import parameterized from tensorflow.core.example import example_pb2 -def _get_tf_example(**kwargs): - example = example_pb2.Example() - for k, v in kwargs.items(): - example.features.feature[k].bytes_list.value.append(v) - return example +EXAMPLE_1 = text_format.Parse( + """ +features { + feature { + key: "bytes" + value { + bytes_list { + value: "ab" + } + } + } + feature { + key: "bytes_arr" + value { + bytes_list { + value: "cd" + value: "ef" + } + } + } + feature { + key: "int64" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "int64_arr" + value { + int64_list { + value: 2 + value: 3 + } + } + } + feature { + key: "float" + value { + float_list { + value: 1.5 + } + } + } + feature { + key: "float_arr" + value { + float_list { + value: 2.5 + value: 3.5 + } + } + } +} +""", + example_pb2.Example(), +) -class TFExampleTest(parameterized.TestCase): +EXAMPLE_2 = text_format.Parse( + """ +features { + feature { + key: "bytes" + value { + bytes_list { + value: "mn" + } + } + } + feature { + key: "bytes_arr" + value { + bytes_list { + value: "op" + value: "qr" + } + } + } + feature { + key: "int64" + value { + int64_list { + value: 4 + } + } + } + feature { + key: "int64_arr" + value { + int64_list { + value: 5 + value: 6 + } + } + } + feature { + key: "float" + value { + float_list { + value: 11.5 + } + } + } + feature { + key: "float_arr" + value { + float_list { + value: 12.5 + value: 13.5 + } + } + } +} +""", + example_pb2.Example(), +) - def test_single_example(self): - data = { - 'bytes_key': b'\x80abc', # not utf-8 decodable - 'str_key': 'str_test', - 'init_key': 123, - 'np_int': np.int32(123), - 'float_key': 4.56, - 'np_float': np.float32(123), - } - e = proto_utils.dict_to_tf_example(data).SerializeToString() - actual = proto_utils.tf_examples_to_dict(e) - self.assertDictAlmostEqual(data, actual, places=6) - - def test_batch_example(self): - data = { - 'bytes_key': [b'\x80abc', b'\x80def'], # not utf-8 decodable - 'str_key': ['str_test', 'str_test2'], - 'init_key': [123, 456], - 'np_int': [np.int32(123), np.int32(456)], - 'float_key': [4.56, 7.89], - 'np_float': [np.float32(123), np.float32(456)], + +class TFExampleTest(tf.test.TestCase, parameterized.TestCase): + + @parameterized.named_parameters( + dict( + testcase_name='tf_example', + example=EXAMPLE_1, + ), + dict( + testcase_name='serialized_tf_example', + example=EXAMPLE_1.SerializeToString(), + ), + ) + def test_single_example_to_dict(self, example): + actual = proto_utils.tf_examples_to_dict(example) + expected = { + 'bytes': [b'ab'], + 'bytes_arr': [b'cd', b'ef'], + 'int64': [1], + 'int64_arr': [2, 3], + 'float': [1.5], + 'float_arr': [2.5, 3.5], } - e = proto_utils.dict_to_tf_example(data) - actual = proto_utils.tf_examples_to_dict(e) - test_utils.assert_nested_container_equal(self, data, actual, places=6) + test_utils.assert_nested_container_equal(self, expected, actual, places=6) @parameterized.named_parameters( dict( - testcase_name='with_single_example', - num_elems=1, + testcase_name='batch_tf_examples', + examples=[EXAMPLE_1, EXAMPLE_2], ), dict( - testcase_name='multiple_examples', - num_elems=3, + testcase_name='batch_serialized_tf_examples', + examples=[ + EXAMPLE_1.SerializeToString(), + EXAMPLE_2.SerializeToString(), + ], + ), + dict( + testcase_name='batch_mixed_tf_examples', + examples=[EXAMPLE_1.SerializeToString(), EXAMPLE_2], ), ) - def test_multiple_examples_as_batch(self, num_elems): - data = { - 'bytes_key': b'\x80abc', # not utf-8 decodable - 'str_key': 'str_test', - 'init_key': 123, - 'np_int': np.int32(123), - 'float_key': 4.56, - 'np_float': np.float32(123), - } - e = [proto_utils.dict_to_tf_example(data) for _ in range(num_elems)] - actual = proto_utils.tf_examples_to_dict(e) - expected = {k: [v] * num_elems for k, v in data.items()} + def test_batched_examples_to_dict(self, examples): + actual = proto_utils.tf_examples_to_dict(examples) + expected = { + 'bytes': [[b'ab'], [b'mn']], + 'bytes_arr': [[b'cd', b'ef'], [b'op', b'qr']], + 'int64': [[1], [4]], + 'int64_arr': [[2, 3], [5, 6]], + 'float': [[1.5], [11.5]], + 'float_arr': [[2.5, 3.5], [12.5, 13.5]], + } test_utils.assert_nested_container_equal(self, expected, actual, places=6) - def test_empty_example(self): - self.assertEmpty(proto_utils.tf_examples_to_dict([])) + def test_batched_single_example_to_dict(self): + actual = proto_utils.tf_examples_to_dict([EXAMPLE_1]) + expected = { + 'bytes': [[b'ab']], + 'bytes_arr': [[b'cd', b'ef']], + 'int64': [[1]], + 'int64_arr': [[2, 3]], + 'float': [[1.5]], + 'float_arr': [[2.5, 3.5]], + } + test_utils.assert_nested_container_equal(self, expected, actual, places=6) - def test_unsupported_type(self): - with self.assertRaisesRegex(TypeError, 'Unsupported type'): - proto_utils.tf_examples_to_dict('unsupported_type') + def test_missing_features_example_to_dict(self): + example_missing_features = text_format.Parse( + """ + features { + feature { + key: "bytes" + value { + bytes_list { + value: "xy" + } + } + } + } + """, + example_pb2.Example(), + ) - def test_unsupported_value_type(self): with self.assertRaisesRegex( - TypeError, 'Value for "a" is not a supported type' + ValueError, 'All examples must have the same features' ): - proto_utils.dict_to_tf_example({'a': [example_pb2.Example()]}) + _ = proto_utils.tf_examples_to_dict([example_missing_features, EXAMPLE_1]) + + def test_empty_example_to_dict(self): + self.assertEmpty(proto_utils.tf_examples_to_dict([])) + + def test_dict_to_tf_example(self): + data = { + 'bytes_scalar': b'a', + 'str_scalar': 'b', + 'int64_scalar': 1, + 'flaot_scalar': 2.1, + 'bytes_list': [b'cd', b'ef'], + 'str_list': ['gh', 'ij'], + 'int64_list': [2, 3], + 'float_list': [1, 3.5], + } + expected = text_format.Parse( + """ + features { + feature { + key: "bytes_scalar" + value { + bytes_list { + value: "a" + } + } + } + feature { + key: "str_scalar" + value { + bytes_list { + value: "b" + } + } + } + feature { + key: "int64_scalar" + value { + int64_list { + value: 1 + } + } + } + feature { + key: "flaot_scalar" + value { + float_list { + value: 2.1 + } + } + } + feature { + key: "bytes_list" + value { + bytes_list { + value: "cd" + value: "ef" + } + } + } + feature { + key: "str_list" + value { + bytes_list { + value: "gh" + value: "ij" + } + } + } + feature { + key: "int64_list" + value { + int64_list { + value: 2 + value: 3 + } + } + } + feature { + key: "float_list" + value { + float_list { + value: 1.0 + value: 3.5 + } + } + } + } + """, + example_pb2.Example(), + ) + actual = proto_utils.dict_to_tf_example(data) + self.assertProtoEquals(expected, actual) + + def test_dict_to_tf_example_key_with_empty_list(self): + data = {'int64_list': 1, 'empty_list': []} + expected = text_format.Parse( + """ + features { + feature { + key: "int64_list" + value { + int64_list { + value: 1 + } + } + } + } + """, + example_pb2.Example(), + ) + actual = proto_utils.dict_to_tf_example(data) + self.assertProtoEquals(expected, actual) - def test_multiple_examples_missing_key(self): - data = [{'a': 'a', 'b': 1}, {'b': 2}] - examples = [proto_utils.dict_to_tf_example(d) for d in data] - with self.assertRaisesRegex(ValueError, 'Missing keys'): - _ = proto_utils.tf_examples_to_dict(examples) + def test_dict_to_tf_example_bad_str_type(self): + data = {'str_arr': ['abc', b'def']} + with self.assertRaisesRegex(AssertionError, 'bad str type'): + _ = proto_utils.dict_to_tf_example(data) + + @parameterized.named_parameters( + dict( + testcase_name='bytes', + data={'bad_type': [b'ab', 'cd']}, + ), + dict( + testcase_name='float', + data={'bad_type': [1.0, 'a']}, + ), + dict( + testcase_name='int64', + data={'bad_type': [1, 'b']}, + ), + ) + def test_dict_to_tf_example_inconsistent_types(self, data): + # This test is required as the logic to determine the type of the feature + # list is based on the first value of the list. + with self.assertRaises(Exception): + _ = proto_utils.dict_to_tf_example(data) + + def test_dict_to_tf_example_unsupported_type(self): + data = {'bad_type': [example_pb2.Example()]} + with self.assertRaisesRegex( + TypeError, 'Values for "bad_type" is not a supported type.' + ): + _ = proto_utils.dict_to_tf_example(data) if __name__ == '__main__':