Skip to content

Commit 02941bf

Browse files
authored
Allow '_mixture_' to return gates instead of just raw unitary matrices (#7048)
* Allow gates to be returned from _mixture_ * Add test * lint * coverage * mypy * Make ReturnsUnitary return unitary
1 parent d87d87e commit 02941bf

File tree

6 files changed

+54
-26
lines changed

6 files changed

+54
-26
lines changed

cirq-core/cirq/ops/common_channels.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1089,15 +1089,12 @@ def __init__(self, p: float) -> None:
10891089
ValueError: if p is not a valid probability.
10901090
"""
10911091
self._p = value.validate_probability(p, 'p')
1092-
self._delegate = AsymmetricDepolarizingChannel(p, 0.0, 0.0)
10931092

10941093
def _num_qubits_(self) -> int:
10951094
return 1
10961095

1097-
def _mixture_(self) -> Sequence[Tuple[float, np.ndarray]]:
1098-
mixture = self._delegate._mixture_()
1099-
# just return identity and x term
1100-
return (mixture[0], mixture[1])
1096+
def _mixture_(self) -> Sequence[Tuple[float, Any]]:
1097+
return ((1 - self._p, identity.I), (self._p, pauli_gates.X))
11011098

11021099
def _has_mixture_(self) -> bool:
11031100
return True

cirq-core/cirq/protocols/apply_mixture_protocol.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from cirq._doc import doc_private
2525
from cirq.protocols import qid_shape_protocol
2626
from cirq.protocols.apply_unitary_protocol import apply_unitary, ApplyUnitaryArgs
27-
from cirq.protocols.mixture_protocol import mixture
2827

2928
# This is a special indicator value used by the apply_mixture method
3029
# to determine whether or not the caller provided a 'default' argument. It must
@@ -260,9 +259,9 @@ def err_str(buf_num_str):
260259
return result
261260

262261
# Fallback to using the object's `_mixture_` matrices. (STEP C)
263-
prob_mix = mixture(val, None)
264-
if prob_mix is not None:
265-
return _mixture_strat(prob_mix, args, is_density_matrix)
262+
result = _apply_mixture_from_mixture_strat(val, args, is_density_matrix)
263+
if result is not None:
264+
return result
266265

267266
# Don't know how to apply mixture. Fallback to specified default behavior.
268267
# (STEP D)
@@ -359,11 +358,19 @@ def _apply_unitary_from_matrix_strat(
359358
return args.target_tensor
360359

361360

362-
def _mixture_strat(val: Any, args: 'ApplyMixtureArgs', is_density_matrix: bool) -> np.ndarray:
361+
def _apply_mixture_from_mixture_strat(
362+
val: Any, args: 'ApplyMixtureArgs', is_density_matrix: bool
363+
) -> Optional[np.ndarray]:
363364
"""Attempt to use unitary matrices in _mixture_ and return the result."""
365+
method = getattr(val, '_mixture_', None)
366+
if method is None:
367+
return None
368+
prob_mix = method()
369+
if prob_mix is NotImplemented or prob_mix is None:
370+
return None
364371
args.out_buffer[:] = 0
365372
np.copyto(dst=args.auxiliary_buffer1, src=args.target_tensor)
366-
for prob, op in val:
373+
for prob, op in prob_mix:
367374
np.copyto(dst=args.target_tensor, src=args.auxiliary_buffer1)
368375
right_result = _apply_unitary_strat(op, args, is_density_matrix)
369376
if right_result is None:

cirq-core/cirq/protocols/apply_mixture_protocol_test.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,17 @@ class NoProtocols:
237237
assert_apply_mixture_returns(NoProtocols(), rho, left_axes=[1], right_axes=[1])
238238

239239

240+
def test_apply_mixture_mixture_returns_not_implemented():
241+
class NoMixture:
242+
def _mixture_(self):
243+
return NotImplemented
244+
245+
rho = np.ones((2, 2, 2, 2), dtype=np.complex128)
246+
247+
with pytest.raises(TypeError, match='has no'):
248+
assert_apply_mixture_returns(NoMixture(), rho, left_axes=[1], right_axes=[1])
249+
250+
240251
def test_apply_mixture_no_protocols_implemented_default():
241252
class NoProtocols:
242253
pass

cirq-core/cirq/protocols/kraus_protocol.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from cirq._doc import doc_private
2525
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
2626
from cirq.protocols.mixture_protocol import has_mixture
27+
from cirq.protocols.unitary_protocol import unitary
2728

2829
# This is a special indicator value used by the channel method to determine
2930
# whether or not the caller provided a 'default' argument. It must be of type
@@ -145,7 +146,9 @@ def kraus(
145146
mixture_getter = getattr(val, '_mixture_', None)
146147
mixture_result = NotImplemented if mixture_getter is None else mixture_getter()
147148
if mixture_result is not NotImplemented and mixture_result is not None:
148-
return tuple(np.sqrt(p) * u for p, u in mixture_result)
149+
return tuple(
150+
np.sqrt(p) * (u if isinstance(u, np.ndarray) else unitary(u)) for p, u in mixture_result
151+
)
149152

150153
unitary_getter = getattr(val, '_unitary_', None)
151154
unitary_result = NotImplemented if unitary_getter is None else unitary_getter()

cirq-core/cirq/protocols/mixture_protocol.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from cirq._doc import doc_private
2424
from cirq.protocols.decompose_protocol import _try_decompose_into_operations_and_qubits
2525
from cirq.protocols.has_unitary_protocol import has_unitary
26+
from cirq.protocols.unitary_protocol import unitary
2627

2728
# This is a special indicator value used by the inverse method to determine
2829
# whether or not the caller provided a 'default' argument.
@@ -84,14 +85,14 @@ def mixture(
8485
with that probability in the mixture. The probabilities will sum to 1.0.
8586
8687
Raises:
87-
TypeError: If `val` has no `_mixture_` or `_unitary_` mehod, or if it
88+
TypeError: If `val` has no `_mixture_` or `_unitary_` method, or if it
8889
does and this method returned `NotImplemented`.
8990
"""
9091

9192
mixture_getter = getattr(val, '_mixture_', None)
9293
result = NotImplemented if mixture_getter is None else mixture_getter()
93-
if result is not NotImplemented:
94-
return result
94+
if result is not NotImplemented and result is not None:
95+
return tuple((p, u if isinstance(u, np.ndarray) else unitary(u)) for p, u in result)
9596

9697
unitary_getter = getattr(val, '_unitary_', None)
9798
result = NotImplemented if unitary_getter is None else unitary_getter()

cirq-core/cirq/protocols/mixture_protocol_test.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,9 @@
1717

1818
import cirq
1919

20+
a = np.array([1])
21+
b = np.array([1j])
22+
2023

2124
class NoMethod:
2225
pass
@@ -32,35 +35,35 @@ def _has_mixture_(self):
3235

3336
class ReturnsValidTuple(cirq.SupportsMixture):
3437
def _mixture_(self):
35-
return ((0.4, 'a'), (0.6, 'b'))
38+
return ((0.4, a), (0.6, b))
3639

3740
def _has_mixture_(self):
3841
return True
3942

4043

4144
class ReturnsNonnormalizedTuple:
4245
def _mixture_(self):
43-
return ((0.4, 'a'), (0.4, 'b'))
46+
return ((0.4, a), (0.4, b))
4447

4548

4649
class ReturnsNegativeProbability:
4750
def _mixture_(self):
48-
return ((0.4, 'a'), (-0.4, 'b'))
51+
return ((0.4, a), (-0.4, b))
4952

5053

5154
class ReturnsGreaterThanUnityProbability:
5255
def _mixture_(self):
53-
return ((1.2, 'a'), (0.4, 'b'))
56+
return ((1.2, a), (0.4, b))
5457

5558

5659
class ReturnsMixtureButNoHasMixture:
5760
def _mixture_(self):
58-
return ((0.4, 'a'), (0.6, 'b'))
61+
return ((0.4, a), (0.6, b))
5962

6063

6164
class ReturnsUnitary:
6265
def _unitary_(self):
63-
return np.ones((2, 2))
66+
return np.eye(2)
6467

6568
def _has_unitary_(self):
6669
return True
@@ -74,12 +77,18 @@ def _has_unitary_(self):
7477
return NotImplemented
7578

7679

80+
class ReturnsMixtureOfReturnsUnitary:
81+
def _mixture_(self):
82+
return ((0.4, ReturnsUnitary()), (0.6, ReturnsUnitary()))
83+
84+
7785
@pytest.mark.parametrize(
7886
'val,mixture',
7987
(
80-
(ReturnsValidTuple(), ((0.4, 'a'), (0.6, 'b'))),
81-
(ReturnsNonnormalizedTuple(), ((0.4, 'a'), (0.4, 'b'))),
82-
(ReturnsUnitary(), ((1.0, np.ones((2, 2))),)),
88+
(ReturnsValidTuple(), ((0.4, a), (0.6, b))),
89+
(ReturnsNonnormalizedTuple(), ((0.4, a), (0.4, b))),
90+
(ReturnsUnitary(), ((1.0, np.eye(2)),)),
91+
(ReturnsMixtureOfReturnsUnitary(), ((0.4, np.eye(2)), (0.6, np.eye(2)))),
8392
),
8493
)
8594
def test_objects_with_mixture(val, mixture):
@@ -88,7 +97,7 @@ def test_objects_with_mixture(val, mixture):
8897
np.testing.assert_almost_equal(keys, expected_keys)
8998
np.testing.assert_equal(values, expected_values)
9099

91-
keys, values = zip(*cirq.mixture(val, ((0.3, 'a'), (0.7, 'b'))))
100+
keys, values = zip(*cirq.mixture(val, ((0.3, a), (0.7, b))))
92101
np.testing.assert_almost_equal(keys, expected_keys)
93102
np.testing.assert_equal(values, expected_values)
94103

@@ -101,7 +110,7 @@ def test_objects_with_no_mixture(val):
101110
_ = cirq.mixture(val)
102111
assert cirq.mixture(val, None) is None
103112
assert cirq.mixture(val, NotImplemented) is NotImplemented
104-
default = ((0.4, 'a'), (0.6, 'b'))
113+
default = ((0.4, a), (0.6, b))
105114
assert cirq.mixture(val, default) == default
106115

107116

0 commit comments

Comments
 (0)