17
17
18
18
import cirq
19
19
20
+ a = np .array ([1 ])
21
+ b = np .array ([1j ])
22
+
20
23
21
24
class NoMethod :
22
25
pass
@@ -32,35 +35,35 @@ def _has_mixture_(self):
32
35
33
36
class ReturnsValidTuple (cirq .SupportsMixture ):
34
37
def _mixture_ (self ):
35
- return ((0.4 , 'a' ), (0.6 , 'b' ))
38
+ return ((0.4 , a ), (0.6 , b ))
36
39
37
40
def _has_mixture_ (self ):
38
41
return True
39
42
40
43
41
44
class ReturnsNonnormalizedTuple :
42
45
def _mixture_ (self ):
43
- return ((0.4 , 'a' ), (0.4 , 'b' ))
46
+ return ((0.4 , a ), (0.4 , b ))
44
47
45
48
46
49
class ReturnsNegativeProbability :
47
50
def _mixture_ (self ):
48
- return ((0.4 , 'a' ), (- 0.4 , 'b' ))
51
+ return ((0.4 , a ), (- 0.4 , b ))
49
52
50
53
51
54
class ReturnsGreaterThanUnityProbability :
52
55
def _mixture_ (self ):
53
- return ((1.2 , 'a' ), (0.4 , 'b' ))
56
+ return ((1.2 , a ), (0.4 , b ))
54
57
55
58
56
59
class ReturnsMixtureButNoHasMixture :
57
60
def _mixture_ (self ):
58
- return ((0.4 , 'a' ), (0.6 , 'b' ))
61
+ return ((0.4 , a ), (0.6 , b ))
59
62
60
63
61
64
class ReturnsUnitary :
62
65
def _unitary_ (self ):
63
- return np .ones (( 2 , 2 ) )
66
+ return np .eye ( 2 )
64
67
65
68
def _has_unitary_ (self ):
66
69
return True
@@ -74,12 +77,18 @@ def _has_unitary_(self):
74
77
return NotImplemented
75
78
76
79
80
+ class ReturnsMixtureOfReturnsUnitary :
81
+ def _mixture_ (self ):
82
+ return ((0.4 , ReturnsUnitary ()), (0.6 , ReturnsUnitary ()))
83
+
84
+
77
85
@pytest .mark .parametrize (
78
86
'val,mixture' ,
79
87
(
80
- (ReturnsValidTuple (), ((0.4 , 'a' ), (0.6 , 'b' ))),
81
- (ReturnsNonnormalizedTuple (), ((0.4 , 'a' ), (0.4 , 'b' ))),
82
- (ReturnsUnitary (), ((1.0 , np .ones ((2 , 2 ))),)),
88
+ (ReturnsValidTuple (), ((0.4 , a ), (0.6 , b ))),
89
+ (ReturnsNonnormalizedTuple (), ((0.4 , a ), (0.4 , b ))),
90
+ (ReturnsUnitary (), ((1.0 , np .eye (2 )),)),
91
+ (ReturnsMixtureOfReturnsUnitary (), ((0.4 , np .eye (2 )), (0.6 , np .eye (2 )))),
83
92
),
84
93
)
85
94
def test_objects_with_mixture (val , mixture ):
@@ -88,7 +97,7 @@ def test_objects_with_mixture(val, mixture):
88
97
np .testing .assert_almost_equal (keys , expected_keys )
89
98
np .testing .assert_equal (values , expected_values )
90
99
91
- keys , values = zip (* cirq .mixture (val , ((0.3 , 'a' ), (0.7 , 'b' ))))
100
+ keys , values = zip (* cirq .mixture (val , ((0.3 , a ), (0.7 , b ))))
92
101
np .testing .assert_almost_equal (keys , expected_keys )
93
102
np .testing .assert_equal (values , expected_values )
94
103
@@ -101,7 +110,7 @@ def test_objects_with_no_mixture(val):
101
110
_ = cirq .mixture (val )
102
111
assert cirq .mixture (val , None ) is None
103
112
assert cirq .mixture (val , NotImplemented ) is NotImplemented
104
- default = ((0.4 , 'a' ), (0.6 , 'b' ))
113
+ default = ((0.4 , a ), (0.6 , b ))
105
114
assert cirq .mixture (val , default ) == default
106
115
107
116
0 commit comments