Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assume PhasedXPowGate is different from XPowGate and YPowGate #7070

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions cirq-core/cirq/ops/common_gates_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,9 +941,13 @@ def test_cx_cz_stabilizer(gate):


def test_phase_by_xy():
assert cirq.phase_by(cirq.X, 0.25, 0) == cirq.Y
assert cirq.phase_by(cirq.X**0.5, 0.25, 0) == cirq.Y**0.5
assert cirq.phase_by(cirq.X**-0.5, 0.25, 0) == cirq.Y**-0.5
assert cirq.phase_by(cirq.X, 0.25, 0) == cirq.PhasedXPowGate(phase_exponent=0.5)
assert cirq.phase_by(cirq.X**0.5, 0.25, 0) == cirq.PhasedXPowGate(
exponent=0.5, phase_exponent=0.5
)
assert cirq.phase_by(cirq.X**-0.5, 0.25, 0) == cirq.PhasedXPowGate(
exponent=-0.5, phase_exponent=0.5
)


def test_ixyz_circuit_diagram():
Expand Down
19 changes: 2 additions & 17 deletions cirq-core/cirq/ops/phased_x_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@
import cirq
from cirq import value, protocols
from cirq._compat import proper_repr
from cirq.ops import common_gates, raw_types
from cirq.ops import raw_types


@value.value_equality(manual_cls=True, approximate=True)
@value.value_equality(approximate=True)
class PhasedXPowGate(raw_types.Gate):
r"""A gate equivalent to $Z^{-p} X^t Z^{p}$ (in time order).

Expand Down Expand Up @@ -242,22 +242,7 @@ def _canonical_exponent(self):

return self._exponent % period

def _value_equality_values_cls_(self):
if self.phase_exponent == 0:
return common_gates.XPowGate
if self.phase_exponent == 0.5:
return common_gates.YPowGate
return PhasedXPowGate

def _value_equality_values_(self):
if self.phase_exponent == 0:
return common_gates.XPowGate(
exponent=self._exponent, global_shift=self._global_shift
)._value_equality_values_()
if self.phase_exponent == 0.5:
return common_gates.YPowGate(
exponent=self._exponent, global_shift=self._global_shift
)._value_equality_values_()
return self.phase_exponent, self._canonical_exponent, self._global_shift

def _json_dict_(self) -> Dict[str, Any]:
Expand Down
19 changes: 16 additions & 3 deletions cirq-core/cirq/ops/phased_x_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,14 @@ def test_eq():
cirq.PhasedXPowGate(exponent=1, phase_exponent=0),
cirq.PhasedXPowGate(exponent=1, phase_exponent=2),
cirq.PhasedXPowGate(exponent=1, phase_exponent=-2),
cirq.X,
)
eq.add_equality_group(cirq.PhasedXPowGate(exponent=1, phase_exponent=2, global_shift=0.1))

eq.add_equality_group(
cirq.PhasedXPowGate(phase_exponent=0.5, exponent=1),
cirq.PhasedXPowGate(phase_exponent=2.5, exponent=3),
cirq.Y,
)
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.25), cirq.Y**0.25)
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.25))

eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.25, exponent=0.25, global_shift=0.1))
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=2.25, exponent=0.25, global_shift=0.2))
Expand Down Expand Up @@ -266,3 +264,18 @@ def test_exponent_consistency(exponent, phase_exponent):
u = cirq.protocols.unitary(g)
u2 = cirq.protocols.unitary(g2)
assert np.all(u == u2)


def test_approx_eq_for_close_phase_exponents():
gate1 = cirq.PhasedXPowGate(phase_exponent=0)
gate2 = cirq.PhasedXPowGate(phase_exponent=1e-12)
gate3 = cirq.PhasedXPowGate(phase_exponent=2e-12)
gate4 = cirq.PhasedXPowGate(phase_exponent=0.345)

assert cirq.approx_eq(gate2, gate3)
assert cirq.approx_eq(gate2, gate1)
assert not cirq.approx_eq(gate2, gate4)

assert cirq.equal_up_to_global_phase(gate2, gate3)
assert cirq.equal_up_to_global_phase(gate2, gate1)
assert not cirq.equal_up_to_global_phase(gate2, gate4)
43 changes: 34 additions & 9 deletions cirq-core/cirq/transformers/eject_phased_paulis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,11 +196,17 @@ def test_crosses_czs():
# Partial CZ.
assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.CZ(a, b) ** 0.25]),
expected=quick_circuit([cirq.Z(b) ** 0.25], [cirq.CZ(a, b) ** -0.25], [cirq.X(a)]),
expected=quick_circuit(
[cirq.Z(b) ** 0.25],
[cirq.CZ(a, b) ** -0.25],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
)
assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.CZ(a, b) ** x]),
expected=quick_circuit([cirq.Z(b) ** x], [cirq.CZ(a, b) ** -x], [cirq.X(a)]),
expected=quick_circuit(
[cirq.Z(b) ** x], [cirq.CZ(a, b) ** -x], [cirq.PhasedXPowGate(phase_exponent=0)(a)]
),
eject_parameterized=True,
)

Expand Down Expand Up @@ -380,7 +386,8 @@ def test_phases_partial_ws():
[cirq.X(q)], [cirq.PhasedXPowGate(phase_exponent=0.25, exponent=0.5).on(q)]
),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=-0.25, exponent=0.5).on(q)], [cirq.X(q)]
[cirq.PhasedXPowGate(phase_exponent=-0.25, exponent=0.5).on(q)],
[cirq.PhasedXPowGate(phase_exponent=0)(q)],
),
)

Expand All @@ -398,7 +405,8 @@ def test_phases_partial_ws():
[cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.75).on(q)],
),
expected=quick_circuit(
[cirq.X(q) ** 0.75], [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)]
[cirq.PhasedXPowGate(phase_exponent=0)(q) ** 0.75],
[cirq.PhasedXPowGate(phase_exponent=0.25).on(q)],
),
)

Expand All @@ -407,7 +415,8 @@ def test_phases_partial_ws():
[cirq.X(q)], [cirq.PhasedXPowGate(exponent=-0.25, phase_exponent=0.5).on(q)]
),
expected=quick_circuit(
[cirq.PhasedXPowGate(exponent=-0.25, phase_exponent=-0.5).on(q)], [cirq.X(q)]
[cirq.PhasedXPowGate(exponent=-0.25, phase_exponent=-0.5).on(q)],
[cirq.PhasedXPowGate(phase_exponent=0)(q)],
),
)

Expand All @@ -431,18 +440,30 @@ def test_blocked_by_unknown_and_symbols(sym):

assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.SWAP(a, b)], [cirq.X(a)]),
expected=quick_circuit([cirq.X(a)], [cirq.SWAP(a, b)], [cirq.X(a)]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
[cirq.SWAP(a, b)],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
)

assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.Z(a) ** sym], [cirq.X(a)]),
expected=quick_circuit([cirq.X(a)], [cirq.Z(a) ** sym], [cirq.X(a)]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
[cirq.Z(a) ** sym],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
compare_unitaries=False,
)

assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.CZ(a, b) ** sym], [cirq.X(a)]),
expected=quick_circuit([cirq.X(a)], [cirq.CZ(a, b) ** sym], [cirq.X(a)]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
[cirq.CZ(a, b) ** sym],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
compare_unitaries=False,
)

Expand All @@ -453,7 +474,11 @@ def test_blocked_by_nocompile_tag():

assert_optimizes(
before=quick_circuit([cirq.X(a)], [cirq.CZ(a, b).with_tags("nocompile")], [cirq.X(a)]),
expected=quick_circuit([cirq.X(a)], [cirq.CZ(a, b).with_tags("nocompile")], [cirq.X(a)]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
[cirq.CZ(a, b).with_tags("nocompile")],
[cirq.PhasedXPowGate(phase_exponent=0)(a)],
),
with_context=True,
)

Expand Down
6 changes: 5 additions & 1 deletion cirq-core/cirq/transformers/eject_z_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,11 @@ def test_z_pushes_past_xy_and_phases_it():
assert_optimizes(
before=cirq.Circuit([cirq.Moment([cirq.Z(q) ** 0.5]), cirq.Moment([cirq.Y(q) ** 0.25])]),
expected=cirq.Circuit(
[cirq.Moment(), cirq.Moment([cirq.X(q) ** 0.25]), cirq.Moment([cirq.Z(q) ** 0.5])]
[
cirq.Moment(),
cirq.Moment([cirq.PhasedXPowGate(phase_exponent=0)(q) ** 0.25]),
cirq.Moment([cirq.Z(q) ** 0.5]),
]
),
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_merge_single_qubit_gates_to_phased_x_and_z():
optimized=cirq.merge_single_qubit_gates_to_phased_x_and_z(c),
expected=cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=1)(a),
cirq.Y(b) ** 0.5,
cirq.PhasedXPowGate(phase_exponent=0.5)(b) ** 0.5,
cirq.CZ(a, b),
(cirq.PhasedXPowGate(phase_exponent=-0.5)(a)) ** 0.5,
cirq.measure(b, key="m"),
Expand Down
6 changes: 3 additions & 3 deletions cirq-google/cirq_google/api/v1/programs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def assert_proto_dict_convert(gate: cirq.Gate, proto: operations_pb2.Operation,
def test_protobuf_round_trip():
qubits = cirq.GridQubit.rect(1, 5)
circuit = cirq.Circuit(
[cirq.X(q) ** 0.5 for q in qubits],
[cirq.PhasedXPowGate(phase_exponent=0)(q) ** 0.5 for q in qubits],
[
cirq.CZ(q, q2)
for q in [cirq.GridQubit(0, 0)]
Expand Down Expand Up @@ -245,7 +245,7 @@ def test_w_to_proto():
)
assert_proto_dict_convert(gate, proto, cirq.GridQubit(2, 3))

gate = cirq.X**0.25
gate = cirq.PhasedXPowGate(exponent=0.25, phase_exponent=0)
proto = operations_pb2.Operation(
exp_w=operations_pb2.ExpW(
target=operations_pb2.Qubit(row=2, col=3),
Expand All @@ -255,7 +255,7 @@ def test_w_to_proto():
)
assert_proto_dict_convert(gate, proto, cirq.GridQubit(2, 3))

gate = cirq.Y**0.25
gate = cirq.PhasedXPowGate(exponent=0.25, phase_exponent=0.5)
proto = operations_pb2.Operation(
exp_w=operations_pb2.ExpW(
target=operations_pb2.Qubit(row=2, col=3),
Expand Down
3 changes: 2 additions & 1 deletion cirq-google/cirq_google/engine/engine_program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,8 @@ def test_get_circuit_v1(get_program_async):
@mock.patch('cirq_google.engine.engine_client.EngineClient.get_program_async')
def test_get_circuit_v2(get_program_async):
circuit = cirq.Circuit(
cirq.X(cirq.GridQubit(5, 2)) ** 0.5, cirq.measure(cirq.GridQubit(5, 2), key='result')
cirq.PhasedXPowGate(phase_exponent=0)(cirq.GridQubit(5, 2)) ** 0.5,
cirq.measure(cirq.GridQubit(5, 2), key='result'),
)

program = cg.EngineProgram('a', 'b', EngineContext())
Expand Down
Loading