Skip to content

Assume PhasedXPowGate is different from XPowGate and YPowGate #7070

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 7, 2025
20 changes: 16 additions & 4 deletions cirq-core/cirq/ops/common_gates.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,9 +289,7 @@ def phase_exponent(self):

def _phase_by_(self, phase_turns, qubit_index):
"""See `cirq.SupportsPhase`."""
return cirq.ops.phased_x_gate.PhasedXPowGate(
exponent=self._exponent, phase_exponent=phase_turns * 2
)
return _phased_x_or_pauli_gate(exponent=self._exponent, phase_exponent=phase_turns * 2)

def _has_stabilizer_effect_(self) -> Optional[bool]:
if self._is_parameterized_() or self._dimension != 2:
Expand Down Expand Up @@ -484,7 +482,7 @@ def phase_exponent(self):

def _phase_by_(self, phase_turns, qubit_index):
"""See `cirq.SupportsPhase`."""
return cirq.ops.phased_x_gate.PhasedXPowGate(
return _phased_x_or_pauli_gate(
exponent=self._exponent, phase_exponent=0.5 + phase_turns * 2
)

Expand Down Expand Up @@ -1542,3 +1540,17 @@ def cphase(rads: value.TParamVal) -> CZPowGate:
$$
""",
)


def _phased_x_or_pauli_gate(
exponent: Union[float, sympy.Expr], phase_exponent: Union[float, sympy.Expr]
) -> Union['cirq.PhasedXPowGate', 'cirq.XPowGate', 'cirq.YPowGate']:
"""Return PhasedXPowGate or X or Y gate if equivalent at the given phase_exponent."""
if not isinstance(phase_exponent, sympy.Expr) or phase_exponent.is_constant():
half_turns = value.canonicalize_half_turns(float(phase_exponent))
match half_turns:
case 0.0:
return XPowGate(exponent=exponent)
case 0.5:
return YPowGate(exponent=exponent)
return cirq.ops.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)
19 changes: 2 additions & 17 deletions cirq-core/cirq/ops/phased_x_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
import cirq
from cirq import protocols, value
from cirq._compat import proper_repr
from cirq.ops import common_gates, raw_types
from cirq.ops import raw_types


@value.value_equality(manual_cls=True, approximate=True)
@value.value_equality(approximate=True)
class PhasedXPowGate(raw_types.Gate):
r"""A gate equivalent to $Z^{-p} X^t Z^{p}$ (in time order).

Expand Down Expand Up @@ -241,22 +241,7 @@ def _canonical_exponent(self):

return self._exponent % period

def _value_equality_values_cls_(self):
if self.phase_exponent == 0:
return common_gates.XPowGate
if self.phase_exponent == 0.5:
return common_gates.YPowGate
return PhasedXPowGate

def _value_equality_values_(self):
if self.phase_exponent == 0:
return common_gates.XPowGate(
exponent=self._exponent, global_shift=self._global_shift
)._value_equality_values_()
if self.phase_exponent == 0.5:
return common_gates.YPowGate(
exponent=self._exponent, global_shift=self._global_shift
)._value_equality_values_()
return self.phase_exponent, self._canonical_exponent, self._global_shift

def _json_dict_(self) -> Dict[str, Any]:
Expand Down
22 changes: 19 additions & 3 deletions cirq-core/cirq/ops/phased_x_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,17 @@ def test_eq():
cirq.PhasedXPowGate(exponent=1, phase_exponent=0),
cirq.PhasedXPowGate(exponent=1, phase_exponent=2),
cirq.PhasedXPowGate(exponent=1, phase_exponent=-2),
cirq.X,
)
eq.add_equality_group(cirq.X)
eq.add_equality_group(cirq.PhasedXPowGate(exponent=1, phase_exponent=2, global_shift=0.1))

eq.add_equality_group(
cirq.PhasedXPowGate(phase_exponent=0.5, exponent=1),
cirq.PhasedXPowGate(phase_exponent=2.5, exponent=3),
cirq.Y,
)
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.25), cirq.Y**0.25)
eq.add_equality_group(cirq.Y)
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.25))
eq.add_equality_group(cirq.Y**0.25)

eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.25, exponent=0.25, global_shift=0.1))
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=2.25, exponent=0.25, global_shift=0.2))
Expand Down Expand Up @@ -266,3 +267,18 @@ def test_exponent_consistency(exponent, phase_exponent):
u = cirq.protocols.unitary(g)
u2 = cirq.protocols.unitary(g2)
assert np.all(u == u2)


def test_approx_eq_for_close_phase_exponents():
gate1 = cirq.PhasedXPowGate(phase_exponent=0)
gate2 = cirq.PhasedXPowGate(phase_exponent=1e-12)
gate3 = cirq.PhasedXPowGate(phase_exponent=2e-12)
gate4 = cirq.PhasedXPowGate(phase_exponent=0.345)

assert cirq.approx_eq(gate2, gate3)
assert cirq.approx_eq(gate2, gate1)
assert not cirq.approx_eq(gate2, gate4)

assert cirq.equal_up_to_global_phase(gate2, gate3)
assert cirq.equal_up_to_global_phase(gate2, gate1)
assert not cirq.equal_up_to_global_phase(gate2, gate4)
37 changes: 25 additions & 12 deletions cirq-core/cirq/transformers/eject_phased_paulis.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

"""Transformer pass that pushes 180° rotations around axes in the XY plane later in the circuit."""

from typing import cast, Dict, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING
from typing import cast, Dict, Iterable, Iterator, Optional, Tuple, TYPE_CHECKING, Union

import numpy as np
import sympy
Expand Down Expand Up @@ -63,15 +63,15 @@ def eject_phased_paulis(
def map_func(op: 'cirq.Operation', _: int) -> 'cirq.OP_TREE':
# Dump if `op` marked with a no compile tag.
if set(op.tags) & tags_to_ignore:
return [_dump_held(op.qubits, held_w_phases), op]
return [_dump_held(op.qubits, held_w_phases, atol), op]

# Collect, phase, and merge Ws.
w = _try_get_known_phased_pauli(op, no_symbolic=not eject_parameterized)
if w is not None:
return (
_potential_cross_whole_w(op, atol, held_w_phases)
if single_qubit_decompositions.is_negligible_turn((w[0] - 1) / 2, atol)
else _potential_cross_partial_w(op, held_w_phases)
else _potential_cross_partial_w(op, held_w_phases, atol)
)

affected = [q for q in op.qubits if q in held_w_phases]
Expand All @@ -96,12 +96,12 @@ def map_func(op: 'cirq.Operation', _: int) -> 'cirq.OP_TREE':
)

# Don't know how to handle this situation. Dump the gates.
return [_dump_held(op.qubits, held_w_phases), op]
return [_dump_held(op.qubits, held_w_phases, atol), op]

# Map operations and put anything that's still held at the end of the circuit.
return circuits.Circuit(
transformer_primitives.map_operations_and_unroll(circuit, map_func),
_dump_held(held_w_phases.keys(), held_w_phases),
_dump_held(held_w_phases.keys(), held_w_phases, atol),
)


Expand All @@ -127,14 +127,14 @@ def _absorb_z_into_w(


def _dump_held(
qubits: Iterable[ops.Qid], held_w_phases: Dict[ops.Qid, value.TParamVal]
qubits: Iterable[ops.Qid], held_w_phases: Dict[ops.Qid, value.TParamVal], atol: float
) -> Iterator['cirq.OP_TREE']:
# Note: sorting is to avoid non-determinism in the insertion order.
for q in sorted(qubits):
p = held_w_phases.get(q)
if p is not None:
dump_op = ops.PhasedXPowGate(phase_exponent=p).on(q)
yield dump_op
gate = _phased_x_or_pauli_gate(exponent=1.0, phase_exponent=p, atol=atol)
yield gate.on(q)
held_w_phases.pop(q, None)


Expand Down Expand Up @@ -184,7 +184,7 @@ def _potential_cross_whole_w(


def _potential_cross_partial_w(
op: ops.Operation, held_w_phases: Dict[ops.Qid, value.TParamVal]
op: ops.Operation, held_w_phases: Dict[ops.Qid, value.TParamVal], atol: float
) -> 'cirq.OP_TREE':
"""Cross the held W over a partial W gate.

Expand All @@ -204,10 +204,10 @@ def _potential_cross_partial_w(
exponent, phase_exponent = cast(
Tuple[value.TParamVal, value.TParamVal], _try_get_known_phased_pauli(op)
)
new_op = ops.PhasedXPowGate(exponent=exponent, phase_exponent=2 * a - phase_exponent).on(
op.qubits[0]
gate = _phased_x_or_pauli_gate(
exponent=exponent, phase_exponent=2 * a - phase_exponent, atol=atol
)
return new_op
return gate.on(op.qubits[0])


def _single_cross_over_cz(op: ops.Operation, qubit_with_w: 'cirq.Qid') -> 'cirq.OP_TREE':
Expand Down Expand Up @@ -351,3 +351,16 @@ def _try_get_known_z_half_turns(
if no_symbolic and isinstance(h, sympy.Basic):
return None
return h


def _phased_x_or_pauli_gate(
exponent: Union[float, sympy.Expr], phase_exponent: Union[float, sympy.Expr], atol: float
) -> Union['cirq.PhasedXPowGate', 'cirq.XPowGate', 'cirq.YPowGate']:
"""Return PhasedXPowGate or X or Y gate if equivalent within atol in z-axis turns."""
if not isinstance(phase_exponent, sympy.Expr) or phase_exponent.is_constant():
half_turns = value.canonicalize_half_turns(float(phase_exponent))
if abs(half_turns / 2) <= atol:
return ops.XPowGate(exponent=exponent)
if abs((half_turns - 0.5) / 2) <= atol:
return ops.YPowGate(exponent=exponent)
return ops.PhasedXPowGate(exponent=exponent, phase_exponent=phase_exponent)
9 changes: 2 additions & 7 deletions cirq-core/cirq/transformers/eject_phased_paulis_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,7 @@ def test_crosses_czs():
[cirq.CZ(a, b) ** 0.25],
),
expected=quick_circuit(
[cirq.CZ(a, b) ** 0.25],
[
cirq.PhasedXPowGate(phase_exponent=0.5).on(b),
cirq.PhasedXPowGate(phase_exponent=0.25).on(a),
],
[cirq.CZ(a, b) ** 0.25], [cirq.Y(b), cirq.PhasedXPowGate(phase_exponent=0.25).on(a)]
),
)
assert_optimizes(
Expand Down Expand Up @@ -387,8 +383,7 @@ def test_phases_partial_ws():
assert_optimizes(
before=quick_circuit([cirq.PhasedXPowGate(phase_exponent=0.25).on(q)], [cirq.X(q) ** 0.5]),
expected=quick_circuit(
[cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.5).on(q)],
[cirq.PhasedXPowGate(phase_exponent=0.25).on(q)],
[cirq.Y(q) ** 0.5], [cirq.PhasedXPowGate(phase_exponent=0.25).on(q)]
),
)

Expand Down
4 changes: 2 additions & 2 deletions cirq-core/cirq/transformers/eject_z_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@ def assert_optimizes(
cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(3).with_tags("preserve_tag")),
)
c_expected = cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=0, exponent=0.25).on_each(*q),
(cirq.X**0.25).on_each(*q),
(cirq.Z**0.5).on_each(*q),
cirq.Moment(cirq.CircuitOperation(before.freeze()).repeat(2).with_tags("ignore")),
cirq.PhasedXPowGate(phase_exponent=0, exponent=0.25).on_each(*q),
(cirq.X**0.25).on_each(*q),
(cirq.Z**0.5).on_each(*q),
cirq.Moment(cirq.CircuitOperation(expected.freeze()).repeat(3).with_tags("preserve_tag")),
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def test_merge_single_qubit_gates_to_phased_x_and_z():
optimized=cirq.merge_single_qubit_gates_to_phased_x_and_z(c),
expected=cirq.Circuit(
cirq.PhasedXPowGate(phase_exponent=1)(a),
cirq.Y(b) ** 0.5,
cirq.PhasedXPowGate(phase_exponent=0.5)(b) ** 0.5,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. The updated expected circuit is actually more consistent with the merge_single_qubit_gates_to_phased_x_and_z docstring here -

def merge_single_qubit_gates_to_phased_x_and_z(
circuit: 'cirq.AbstractCircuit',
*,
context: Optional['cirq.TransformerContext'] = None,
atol: float = 1e-8,
) -> 'cirq.Circuit':
"""Replaces runs of single qubit rotations with `cirq.PhasedXPowGate` and `cirq.ZPowGate`.
Specifically, any run of non-parameterized single-qubit unitaries will be replaced by an
optional PhasedX operation followed by an optional Z operation.

cirq.CZ(a, b),
(cirq.PhasedXPowGate(phase_exponent=-0.5)(a)) ** 0.5,
cirq.measure(b, key="m"),
Expand Down
16 changes: 6 additions & 10 deletions cirq-google/cirq_google/api/v1/programs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,31 +15,27 @@
import pytest
import sympy

import cirq
import cirq.testing
import cirq_google.api.v1.programs as programs
from cirq_google.api.v1 import operations_pb2


def assert_proto_dict_convert(gate: cirq.Gate, proto: operations_pb2.Operation, *qubits: cirq.Qid):
assert programs.gate_to_proto(gate, qubits, delay=0) == proto
assert programs.xmon_op_from_proto(proto) == gate(*qubits)
xmon_op = programs.xmon_op_from_proto(proto)
assert xmon_op.qubits == qubits
assert xmon_op.gate == gate or np.allclose(cirq.unitary(xmon_op.gate), cirq.unitary(gate))


def test_protobuf_round_trip():
qubits = cirq.GridQubit.rect(1, 5)
circuit = cirq.Circuit(
[cirq.X(q) ** 0.5 for q in qubits],
[
cirq.CZ(q, q2)
for q in [cirq.GridQubit(0, 0)]
for q, q2 in zip(qubits, qubits)
if q != q2
],
[cirq.X(q) ** 0.5 for q in qubits], [cirq.CZ(qubits[0], q1) for q1 in qubits[1:]]
)

protos = list(programs.circuit_as_schedule_to_protos(circuit))
s2 = programs.circuit_from_schedule_from_protos(protos)
assert s2 == circuit
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(s2, circuit)


def make_bytes(s: str) -> bytes:
Expand Down
6 changes: 4 additions & 2 deletions cirq-google/cirq_google/engine/engine_program_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from google.protobuf import any_pb2, timestamp_pb2
from google.protobuf.text_format import Merge

import cirq
import cirq.testing
import cirq_google as cg
from cirq_google.api import v1, v2
from cirq_google.cloud import quantum
Expand Down Expand Up @@ -304,7 +304,9 @@ def test_get_circuit_v2(get_program_async):

program = cg.EngineProgram('a', 'b', EngineContext())
get_program_async.return_value = quantum.QuantumProgram(code=_PROGRAM_V2)
assert program.get_circuit() == circuit
cirq.testing.assert_circuits_with_terminal_measurements_are_equivalent(
program.get_circuit(), circuit
)
get_program_async.assert_called_once_with('a', 'b', True)


Expand Down