Skip to content

Commit

Permalink
Tidy up message utils to not disable type checking.
Browse files Browse the repository at this point in the history
Also removed code path for python 3.7, which is no longer supported.

PiperOrigin-RevId: 643340534
Change-Id: Ife1e7678cc396f1725da8821b052fd7f38a7b895
  • Loading branch information
tomwardio authored and copybara-github committed Jun 14, 2024
1 parent 5e9d027 commit 2cd75a5
Showing 1 changed file with 27 additions and 50 deletions.
77 changes: 27 additions & 50 deletions dm_env_rpc/v1/message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
"""

import typing
from typing import Any, Iterable, NamedTuple, Type, Tuple, Union
from typing import Iterable, NamedTuple, Type, Union

import immutabledict

from google.protobuf import any_pb2
from google.protobuf import message
from dm_env_rpc.v1 import dm_env_rpc_pb2
from dm_env_rpc.v1 import error
from google.protobuf import message


_MESSAGE_TYPE_TO_FIELD = immutabledict.immutabledict({
Expand Down Expand Up @@ -81,8 +81,10 @@ def pack_rpc_request(
Packed extension (any_pb2.Any) - returned as-is.
Everything else - assumed to be an extension; wrapped in any_pb2.Any.
"""
# pytype: disable=bad-return-type
return _pack_message(request, _get_union_args(DmEnvRpcNativeRequest))
if isinstance(request, typing.get_args(DmEnvRpcNativeRequest)):
return typing.cast(DmEnvRpcNativeRequest, request)
else:
return _pack_message(request)


def unpack_rpc_request(
Expand All @@ -108,10 +110,12 @@ def unpack_rpc_request(
ValueError: The message is packed (any_pb2.Any), but not in |extension_type|
or the message type is not a native request or known extension.
"""
return _unpack_message(
request,
_get_union_args(DmEnvRpcNativeRequest),
extension_type=extension_type)
if isinstance(request, typing.get_args(DmEnvRpcNativeRequest)):
return request
else:
return _unpack_message(
request,
extension_type=extension_type)


def pack_rpc_response(
Expand All @@ -127,8 +131,10 @@ def pack_rpc_response(
Packed extension (any_pb2.Any) - returned as-is.
Everything else - assumed to be an extension; wrapped in any_pb2.Any.
"""
# pytype: disable=bad-return-type
return _pack_message(response, _get_union_args(DmEnvRpcNativeResponse))
if isinstance(response, typing.get_args(DmEnvRpcNativeResponse)):
return typing.cast(DmEnvRpcNativeResponse, response)
else:
return _pack_message(response)


def unpack_rpc_response(
Expand All @@ -154,8 +160,10 @@ def unpack_rpc_response(
ValueError: The message is packed (any_pb2.Any), but not in |extension_type|
or the message type is not a native request or known extension.
"""
return _unpack_message(response, _get_union_args(DmEnvRpcNativeResponse),
extension_type=extension_type)
if isinstance(response, typing.get_args(DmEnvRpcNativeResponse)):
return response
else:
return _unpack_message(response, extension_type=extension_type)


class EnvironmentRequestAndFieldName(NamedTuple):
Expand Down Expand Up @@ -213,24 +221,8 @@ def unpack_environment_response(
f'{expected_field_name}, actual: {response_field_name}')


def _pack_message(
msg: message.Message,
union_types: Tuple[Type[Any], ...],
) -> message.Message:
"""Packs a message based on its union types.
Args:
msg: The message to process.
union_types: Determines which types are "native" and "extension".
"native" types are passed through, unchanged.
"extension" types are packed into an Any message and returned.
Returns:
A message within union_types, or an Any proto.
"""
if isinstance(msg, union_types):
return msg

def _pack_message(msg) -> any_pb2.Any:
"""Helper to pack message into an Any proto."""
if isinstance(msg, any_pb2.Any):
return msg

Expand All @@ -242,32 +234,24 @@ def _pack_message(

def _unpack_message(
msg: message.Message,
union_types: Tuple[Type[Any], ...],
*,
extension_type: Union[
Type[message.Message], Iterable[Type[message.Message]]
],
) -> message.Message:
"""Packs a message based on its union types.
):
"""Helper to unpack a message from set of possible extensions.
Args:
msg: The message to process.
union_types: Determines which types are "native" and "extension". "native"
types are passed through, unchanged. "extension" types are packed into an
Any message and returned.
extension_type: Type or type(s) used to match extension messages. The first
matching type is used.
Returns:
An upacked message within |union_types| or an unpacked extension message
with type within |extension_type|.
An upacked extension message with type within |extension_type|.
Raises:
TypeError: Raised if a return type could not be determined.
"""
if isinstance(msg, union_types):
return msg

if isinstance(extension_type, type):
extension_type = (extension_type,)
else:
Expand All @@ -291,12 +275,5 @@ def _unpack_message(
return unpacked

raise ValueError(
'Message type does not appear in union_types and is not an extension: '
f'{type(msg).__name__}.\n'
f'Union Types:\n' + '\n'.join(f'- {t}' for t in union_types))


def _get_union_args(union) -> Tuple[Type[Any], ...]:
"""Python version-safe function for getting the types in a union."""
# __args__ exists < v3.8, typing.get_args() exists >= 3.8.
return getattr(union, '__args__', None) or typing.get_args(union)
f'Cannot unpack extension message with type: {type(msg).__name__}.'
)

0 comments on commit 2cd75a5

Please sign in to comment.