Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 73 additions & 43 deletions ml_metrics/_src/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Proto utils."""

import collections
from collections.abc import Iterable
from typing import Any
from absl import logging
from ml_metrics._src.tools.telemetry import telemetry
import more_itertools as mit
import numpy as np
from tensorflow.core.example import example_pb2


_ExampleOrBytes = bytes | example_pb2.Example


Expand All @@ -32,61 +33,90 @@ def _maybe_deserialize(ex: _ExampleOrBytes) -> example_pb2.Example:


@telemetry.function_monitor(api='ml_metrics', category=telemetry.CATEGORY.UTIL)
def tf_examples_to_dict(examples: Iterable[_ExampleOrBytes] | _ExampleOrBytes):
"""Parses a serialized tf.train.Example to a dict."""
def tf_examples_to_dict(
examples: Iterable[_ExampleOrBytes] | _ExampleOrBytes,
) -> dict[
str,
list[int | float | bytes] | list[list[int | float | bytes]],
]:
"""Parses serialized or unserialized tf.train.Examples to a dict.

The conversion assumes all examples have the same features. If not, a
ValueError will be raised.

Args:
examples: A single tf.train.Example, serialized tf.train.Example, or an
iterable of tf.train.Examples and/or serialized tf.train.Examples.

Returns:
A dict mapping feature names to lists of feature values.

Raises:
ValueError: If the features are not all present in all examples.
"""

single_example = False
if isinstance(examples, (bytes, example_pb2.Example)):
single_example = True
examples = [examples]
examples = (_maybe_deserialize(ex) for ex in examples)
examples = mit.peekable(examples)
if (head := examples.peek(None)) is None:
return {}

result = {k: [] for k in head.features.feature}
result = collections.defaultdict(list)

for ex in examples:
missing = set(result)
for key, feature in ex.features.feature.items():
missing.remove(key)
value = getattr(feature, feature.WhichOneof('kind')).value
if value and isinstance(value[0], bytes):
try:
value = [v.decode() for v in value]
except UnicodeDecodeError:
logging.info(
'chainable: %s',
f'Failed to decode for {key}, forward the raw bytes.',
)
result[key].extend(value)
if missing:
ex = _maybe_deserialize(ex)
features = dict(ex.features.feature)

if result and result.keys() != features.keys():
raise ValueError(
f'Missing keys: {missing}, expecting {set(result)}, got {ex=}'
'All examples must have the same features, got %s and %s'
% (result.keys(), features.keys())
)
result = {k: v for k, v in result.items()}
# Scalar value in a single example will be returned with the scalar directly.
if single_example and all(len(v) == 1 for v in result.values()):
result = {k: v[0] for k, v in result.items()}

for name, values in features.items():
result[name].append(getattr(values, values.WhichOneof('kind')).value)

if single_example:
return {k: v[0] for k, v in result.items()}
return result


@telemetry.function_monitor(api='ml_metrics', category=telemetry.CATEGORY.UTIL)
def dict_to_tf_example(data: dict[str, Any]) -> example_pb2.Example:
"""Creates a tf.Example from a dictionary."""

example = example_pb2.Example()
for key, value in data.items():
if isinstance(value, (str, bytes, np.floating, float, int, np.integer)):
value = [value]
feature = example.features.feature
if isinstance(value[0], str):
for v in value:
assert isinstance(v, str), f'bad str type: {value}'
feature[key].bytes_list.value.append(v.encode())
elif isinstance(value[0], bytes):
feature[key].bytes_list.value.extend(value)
elif isinstance(value[0], (int, np.integer)):
feature[key].int64_list.value.extend(value)
elif isinstance(value[0], (float, np.floating)):
feature[key].float_list.value.extend(value)
for key, values in data.items():
if isinstance(values, (str, bytes, np.floating, float, int, np.integer)):
values = [values]

if not values:
# Skip empty features.
continue

if isinstance(values[0], str):
for v in values:
assert isinstance(v, str), f'bad str type: {values}'
example.features.feature[key].bytes_list.value.append(v.encode())
continue

if isinstance(values[0], bytes):
feature_kind = 'bytes_list'
elif isinstance(values[0], (float, np.floating)):
feature_kind = 'float_list'
elif isinstance(values[0], (int, np.integer)):
feature_kind = 'int64_list'
for v in values:
if isinstance(v, (float, np.floating)):
# If a float is encountered in the list, we consider the whole feature
# to be a float_list.
feature_kind = 'float_list'
break
elif not isinstance(v, (int, np.integer)):
break
else:
raise TypeError(f'Value for "{key}" is not a supported type.')
raise TypeError(f'Values for "{key}" is not a supported type.')

feature_list = getattr(example.features.feature[key], feature_kind).value
feature_list.extend(values)

return example
Loading
Loading