diff --git a/dm_env_rpc/v1/spec_manager_test.py b/dm_env_rpc/v1/spec_manager_test.py index a8781ff..ea14b7c 100644 --- a/dm_env_rpc/v1/spec_manager_test.py +++ b/dm_env_rpc/v1/spec_manager_test.py @@ -102,13 +102,9 @@ def test_pack_wrong_shape_raises_error(self): self._spec_manager.pack({'foo': [1, 2]}) def test_pack_wrong_dtype_raises_error(self): - with self.assertRaisesRegex(TypeError, 'int32'): + with self.assertRaises(ValueError): self._spec_manager.pack({'foo': 'hello'}) - def test_pack_cast_float_to_int_raises_error(self): - with self.assertRaisesRegex(TypeError, 'int32'): - self._spec_manager.pack({'foo': [0.5, 1.0, 1]}) - def test_pack_cast_int_to_float_is_ok(self): packed = self._spec_manager.pack({'fuzz': [1, 2]}) self.assertEqual([1.0, 2.0], packed[54].floats.array) diff --git a/dm_env_rpc/v1/tensor_utils.py b/dm_env_rpc/v1/tensor_utils.py index fc8bf39..02ee94e 100644 --- a/dm_env_rpc/v1/tensor_utils.py +++ b/dm_env_rpc/v1/tensor_utils.py @@ -322,7 +322,7 @@ def pack_tensor( value = value.astype( dtype=_DM_ENV_RPC_DTYPE_TO_NUMPY_DTYPE.get(dtype, dtype), copy=False, - casting='same_kind' if value.size else 'unsafe') + casting='unsafe') packed.shape[:] = value.shape packer = get_packer(value.dtype.type) diff --git a/dm_env_rpc/v1/tensor_utils_test.py b/dm_env_rpc/v1/tensor_utils_test.py index 63cf8bc..73c5a9f 100644 --- a/dm_env_rpc/v1/tensor_utils_test.py +++ b/dm_env_rpc/v1/tensor_utils_test.py @@ -168,7 +168,7 @@ def test_packed_rowmajor(self): np.testing.assert_array_equal([1, 2, 3, 4, 5, 6], tensor.int32s.array) def test_mixed_scalar_types_raises_exception(self): - with self.assertRaises(TypeError): + with self.assertRaises(ValueError): tensor_utils.pack_tensor(['hello!', 75], dtype=np.float32) def test_jagged_arrays_throw_exceptions(self):