1717
1818import cirq
1919
20+ a = np .array ([1 ])
21+ b = np .array ([1j ])
22+
2023
2124class NoMethod :
2225 pass
@@ -32,35 +35,35 @@ def _has_mixture_(self):
3235
3336class ReturnsValidTuple (cirq .SupportsMixture ):
3437 def _mixture_ (self ):
35- return ((0.4 , 'a' ), (0.6 , 'b' ))
38+ return ((0.4 , a ), (0.6 , b ))
3639
3740 def _has_mixture_ (self ):
3841 return True
3942
4043
4144class ReturnsNonnormalizedTuple :
4245 def _mixture_ (self ):
43- return ((0.4 , 'a' ), (0.4 , 'b' ))
46+ return ((0.4 , a ), (0.4 , b ))
4447
4548
4649class ReturnsNegativeProbability :
4750 def _mixture_ (self ):
48- return ((0.4 , 'a' ), (- 0.4 , 'b' ))
51+ return ((0.4 , a ), (- 0.4 , b ))
4952
5053
5154class ReturnsGreaterThanUnityProbability :
5255 def _mixture_ (self ):
53- return ((1.2 , 'a' ), (0.4 , 'b' ))
56+ return ((1.2 , a ), (0.4 , b ))
5457
5558
5659class ReturnsMixtureButNoHasMixture :
5760 def _mixture_ (self ):
58- return ((0.4 , 'a' ), (0.6 , 'b' ))
61+ return ((0.4 , a ), (0.6 , b ))
5962
6063
6164class ReturnsUnitary :
6265 def _unitary_ (self ):
63- return np .ones (( 2 , 2 ) )
66+ return np .eye ( 2 )
6467
6568 def _has_unitary_ (self ):
6669 return True
@@ -74,12 +77,18 @@ def _has_unitary_(self):
7477 return NotImplemented
7578
7679
80+ class ReturnsMixtureOfReturnsUnitary :
81+ def _mixture_ (self ):
82+ return ((0.4 , ReturnsUnitary ()), (0.6 , ReturnsUnitary ()))
83+
84+
7785@pytest .mark .parametrize (
7886 'val,mixture' ,
7987 (
80- (ReturnsValidTuple (), ((0.4 , 'a' ), (0.6 , 'b' ))),
81- (ReturnsNonnormalizedTuple (), ((0.4 , 'a' ), (0.4 , 'b' ))),
82- (ReturnsUnitary (), ((1.0 , np .ones ((2 , 2 ))),)),
88+ (ReturnsValidTuple (), ((0.4 , a ), (0.6 , b ))),
89+ (ReturnsNonnormalizedTuple (), ((0.4 , a ), (0.4 , b ))),
90+ (ReturnsUnitary (), ((1.0 , np .eye (2 )),)),
91+ (ReturnsMixtureOfReturnsUnitary (), ((0.4 , np .eye (2 )), (0.6 , np .eye (2 )))),
8392 ),
8493)
8594def test_objects_with_mixture (val , mixture ):
@@ -88,7 +97,7 @@ def test_objects_with_mixture(val, mixture):
8897 np .testing .assert_almost_equal (keys , expected_keys )
8998 np .testing .assert_equal (values , expected_values )
9099
91- keys , values = zip (* cirq .mixture (val , ((0.3 , 'a' ), (0.7 , 'b' ))))
100+ keys , values = zip (* cirq .mixture (val , ((0.3 , a ), (0.7 , b ))))
92101 np .testing .assert_almost_equal (keys , expected_keys )
93102 np .testing .assert_equal (values , expected_values )
94103
@@ -101,7 +110,7 @@ def test_objects_with_no_mixture(val):
101110 _ = cirq .mixture (val )
102111 assert cirq .mixture (val , None ) is None
103112 assert cirq .mixture (val , NotImplemented ) is NotImplemented
104- default = ((0.4 , 'a' ), (0.6 , 'b' ))
113+ default = ((0.4 , a ), (0.6 , b ))
105114 assert cirq .mixture (val , default ) == default
106115
107116
0 commit comments