Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Assume PhasedXPowGate is different from XPowGate and YPowGate #7070

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 4 additions & 13 deletions cirq-core/cirq/ops/phased_x_gate.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import cirq
from cirq import value, protocols
from cirq._compat import proper_repr
from cirq.ops import common_gates, raw_types
from cirq.ops import raw_types


@value.value_equality(manual_cls=True, approximate=True)
Expand Down Expand Up @@ -243,21 +243,12 @@ def _canonical_exponent(self):
return self._exponent % period

def _value_equality_values_cls_(self):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can this be removed completely? I assume by the description in #4585, removing it entirely would have the same effect. But I could be wrong.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1 - the _value_equality_values_cls_ can be indeed removed provided manual_cls in the decorator is changed to the False default. The _value_equality_approximate_values_ defaults to _value_equality_values_ so it can be removed as well.

diff --git a/cirq-core/cirq/ops/phased_x_gate.py b/cirq-core/cirq/ops/phased_x_gate.py
index d769b863..51bc5f39 100644
--- a/cirq-core/cirq/ops/phased_x_gate.py
+++ b/cirq-core/cirq/ops/phased_x_gate.py
@@ -30,3 +30,3 @@ from cirq.ops import raw_types
 
-@value.value_equality(manual_cls=True, approximate=True)
+@value.value_equality(approximate=True)
 class PhasedXPowGate(raw_types.Gate):
@@ -244,5 +244,2 @@ class PhasedXPowGate(raw_types.Gate):
 
-    def _value_equality_values_cls_(self):
-        return PhasedXPowGate
-
     def _value_equality_values_(self):
@@ -250,5 +247,2 @@ class PhasedXPowGate(raw_types.Gate):
 
-    def _value_equality_approximate_values_(self):
-        return self.phase_exponent, self._canonical_exponent, self._global_shift
-
     def _json_dict_(self) -> Dict[str, Any]:

if self.phase_exponent == 0:
return common_gates.XPowGate
if self.phase_exponent == 0.5:
return common_gates.YPowGate
return PhasedXPowGate

def _value_equality_values_(self):
if self.phase_exponent == 0:
return common_gates.XPowGate(
exponent=self._exponent, global_shift=self._global_shift
)._value_equality_values_()
if self.phase_exponent == 0.5:
return common_gates.YPowGate(
exponent=self._exponent, global_shift=self._global_shift
)._value_equality_values_()
return self.phase_exponent, self._canonical_exponent, self._global_shift

def _value_equality_approximate_values_(self):
return self.phase_exponent, self._canonical_exponent, self._global_shift

def _json_dict_(self) -> Dict[str, Any]:
Expand Down
19 changes: 16 additions & 3 deletions cirq-core/cirq/ops/phased_x_gate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,16 +79,14 @@ def test_eq():
cirq.PhasedXPowGate(exponent=1, phase_exponent=0),
cirq.PhasedXPowGate(exponent=1, phase_exponent=2),
cirq.PhasedXPowGate(exponent=1, phase_exponent=-2),
cirq.X,
)
eq.add_equality_group(cirq.PhasedXPowGate(exponent=1, phase_exponent=2, global_shift=0.1))

eq.add_equality_group(
cirq.PhasedXPowGate(phase_exponent=0.5, exponent=1),
cirq.PhasedXPowGate(phase_exponent=2.5, exponent=3),
cirq.Y,
)
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.25), cirq.Y**0.25)
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.5, exponent=0.25))

eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=0.25, exponent=0.25, global_shift=0.1))
eq.add_equality_group(cirq.PhasedXPowGate(phase_exponent=2.25, exponent=0.25, global_shift=0.2))
Expand Down Expand Up @@ -266,3 +264,18 @@ def test_exponent_consistency(exponent, phase_exponent):
u = cirq.protocols.unitary(g)
u2 = cirq.protocols.unitary(g2)
assert np.all(u == u2)


def test_approx_eq_for_close_phase_exponents():
gate1 = cirq.PhasedXPowGate(phase_exponent=0)
gate2 = cirq.PhasedXPowGate(phase_exponent=1e-12)
gate3 = cirq.PhasedXPowGate(phase_exponent=2e-12)
gate4 = cirq.PhasedXPowGate(phase_exponent=0.345)

assert cirq.approx_eq(gate2, gate3)
assert cirq.approx_eq(gate2, gate1)
assert not cirq.approx_eq(gate2, gate4)

assert cirq.equal_up_to_global_phase(gate2, gate3)
assert cirq.equal_up_to_global_phase(gate2, gate1)
assert not cirq.equal_up_to_global_phase(gate2, gate4)
Loading