Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update pass_operations_over with fixed conjugated_by #7123

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
45 changes: 19 additions & 26 deletions cirq-core/cirq/ops/pauli_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import numpy as np
import sympy

import cirq
from cirq import value, protocols, linalg, qis, _compat
from cirq._doc import document
from cirq._import import LazyLoader
Expand Down Expand Up @@ -496,7 +495,7 @@ def matrix(self, qubits: Optional[Iterable[TKey]] = None) -> np.ndarray:
"""
qubits = self.qubits if qubits is None else qubits
factors = [self.get(q, default=identity.I) for q in qubits]
if cirq.is_parameterized(self):
if protocols.is_parameterized(self):
raise NotImplementedError('Cannot express as matrix when parameterized')
assert isinstance(self.coefficient, complex)
return linalg.kron(self.coefficient, *[protocols.unitary(f) for f in factors])
Expand Down Expand Up @@ -981,31 +980,28 @@ def conjugated_by(self, clifford: 'cirq.OP_TREE') -> 'PauliString':
ps = PauliString(qubit_pauli_map=self._qubit_pauli_map, coefficient=self.coefficient)
all_ops = list(op_tree.flatten_to_ops(clifford))
all_qubits = set.union(set(self.qubits), [q for op in all_ops for q in op.qubits])

# Iteratively calculate the conjugation in reverse order of ops.
for op in all_ops[::-1]:
# To calcuate the conjugation of P (`ps`) with respect to C (`op`)
# Decompose P = Pc⊗R, where Pc acts on the same qubits as C, R acts on the remaining.
# Then the conjugation = (C^{-1}⊗I·Pc⊗R·C⊗I) = (C^{-1}·Pc·C)⊗R.

# Isolate R
remain: 'cirq.PauliString' = PauliString()
for q in all_qubits:
pauli = ps.get(q)
if pauli is not None and not q in op.qubits:
remain *= pauli(q)
remain: 'cirq.PauliString' = PauliString(
*(pauli(q) for q in all_qubits - set(op.qubits) if (pauli := ps.get(q)) is not None)
)

# Initialize the conjugation of Pc.
conjugated: 'cirq.DensePauliString' = (
dense_pauli_string.DensePauliString(pauli_mask=[identity.I for _ in op.qubits])
* self.coefficient
* ps.coefficient
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bug of the previous pr I submitted. We should roll against ps in the iteration, coeff should be from ps not self, found this by enforcing unitary checks in tests for all conjugated_by tests, previously if expected is given, unitary isn't checked. Fix it here.

)

# Calculate the conjugation via CliffordGate's clifford_tableau.
# Note the clifford_tableau in CliffordGate represents C·P·C^-1 instead of C^-1·P·C.
# So we take the inverse of the tableau to match the definition of the conjugation here.
gate_in_clifford: 'cirq.CliffordGate'
if isinstance(op.gate, cirq.CliffordGate):
if isinstance(op.gate, clifford_gate.CliffordGate):
gate_in_clifford = op.gate
else:
# Convert the clifford gate to CliffordGate type.
Expand All @@ -1020,7 +1016,7 @@ def conjugated_by(self, clifford: 'cirq.OP_TREE') -> 'PauliString':
# Puali X_k's conjugation is from the destabilzer table;
# Puali Z_k's conjugation is from the stabilzer table;
# Puali Y_k's conjugation is calcluated according to Y = iXZ. E.g., for the kth qubit,
# C^{-1}·Y_k⊗I·C = C^{-1}·(iX_k⊗I·Z_k⊗I)·C = i (C^{-1}·X_k⊗I·C)·(C^{-1}·Z_k⊗I·C)
# C^{-1}·Y_k⊗I·C = C^{-1}·(iX_k⊗I·Z_k⊗I)·C = i (C^{-1}·X_k⊗I·C)·(C^{-1}·Z_k⊗I·C).
for qid, qubit in enumerate(op.qubits):
pauli = ps.get(qubit)
match pauli:
Expand Down Expand Up @@ -1100,20 +1096,17 @@ def pass_operations_over(
pauli string, instead of before (and so are moving in the
opposite direction).
"""
pauli_map = dict(self._qubit_pauli_map)
should_negate = False
for op in ops:
if pauli_map.keys().isdisjoint(set(op.qubits)):
continue
decomposed = _decompose_into_cliffords(op)
if not after_to_before:
decomposed = decomposed[::-1]
for clifford_op in decomposed:
if pauli_map.keys().isdisjoint(set(clifford_op.qubits)):
continue
should_negate ^= _pass_operation_over(pauli_map, clifford_op, after_to_before)
coef = -self._coefficient if should_negate else self.coefficient
return PauliString(qubit_pauli_map=pauli_map, coefficient=coef)
# TODO(#6946): deprecate this method.
# Note: This method is supposed to be replaced by conjugated_by()
# (see #2351 for details).
if after_to_before:
return self.after(ops)

if isinstance(ops, gate_operation.GateOperation):
return self.before(ops)

all_ops = list(op_tree.flatten_to_ops(ops))
return self.before(all_ops[::-1])

def _is_parameterized_(self) -> bool:
return protocols.is_parameterized(self.coefficient)
Expand Down Expand Up @@ -1179,7 +1172,7 @@ def _try_interpret_as_pauli_string(op: Any):
if (pauli := gates.get(type(op.gate), None)) is not None:
exponent = op.gate.exponent # type: ignore
if exponent % 2 == 0:
return cirq.PauliString()
return PauliString()
if exponent % 2 == 1:
return pauli.on(op.qubits[0])
return None
Expand Down
6 changes: 4 additions & 2 deletions cirq-core/cirq/ops/pauli_string_phasor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,8 +163,10 @@ def test_pass_operations_over():
ps_after = cirq.PauliString({q0: cirq.Z, q1: cirq.Y}, -1)
before = cirq.PauliStringPhasor(ps_before, exponent_neg=0.1)
after = cirq.PauliStringPhasor(ps_after, exponent_neg=0.1)
assert before.pass_operations_over([op]) == after
assert after.pass_operations_over([op], after_to_before=True) == before
assert before.pass_operations_over([op]).pauli_string == after.pauli_string
assert (
after.pass_operations_over([op], after_to_before=True).pauli_string == before.pauli_string
)


def test_extrapolate_effect():
Expand Down
86 changes: 67 additions & 19 deletions cirq-core/cirq/ops/pauli_string_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,53 @@ def _small_sample_qubit_pauli_maps():


def assert_conjugation(
input_ps: cirq.PauliString, ops: cirq.OP_TREE, expected: cirq.PauliString | None
input_ps: cirq.PauliString,
op: cirq.Operation,
expected: cirq.PauliString | None = None,
force_checking_unitary=True,
):
"""Verifies that conjugating `input_ps` by `op` results in `expected`.

Also ensures that the unitary representation of the Pauli string is
preserved under the conjugation.
"""

def _ps_on_qubits(ps: cirq.PauliString, qubits: tuple[cirq.Qid, ...]):
"""Extracts a sub-PauliString from a given PauliString, restricted to
a specified subset of qubits.
"""
pauli_map = {}
for q, pauli in ps.items():
if q in qubits:
pauli_map[q] = pauli
return cirq.PauliString(qubit_pauli_map=pauli_map, coefficient=ps.coefficient)

conjugation = input_ps.conjugated_by(op)
if expected is None or force_checking_unitary:
# Compares the unitary of the conjugation result and the expected unitary.
clifford = cirq.CliffordGate.from_op_list([op], op.qubits)
actual_unitary = cirq.unitary(_ps_on_qubits(conjugation, op.qubits).dense(op.qubits))
c = cirq.unitary(clifford)
expected_unitary = (
np.conj(c.T) @ cirq.unitary(_ps_on_qubits(input_ps, op.qubits).dense(op.qubits)) @ c
)
assert np.allclose(actual_unitary, expected_unitary, atol=1e-8)
if expected is not None:
assert conjugation == expected


def assert_conjugation_multi_ops(
input_ps: cirq.PauliString, ops: list[cirq.Operation], expected: cirq.PauliString | None = None
):
conjugation = input_ps.conjugated_by(ops)
if expected is not None:
assert conjugation == expected
else: # Compares the unitary of the conjugation result and the expected unitary.
op_list = list(cirq.flatten_to_ops(ops))
qubits_of_clifford = [q for op in op_list for q in op.qubits]
clifford = cirq.CliffordGate.from_op_list(op_list, qubits_of_clifford)
actual_unitary = cirq.unitary(conjugation.dense(qubits_of_clifford))
c = cirq.unitary(clifford)
expected_unitary = np.conj(c.T) @ cirq.unitary(input_ps.dense(qubits_of_clifford)) @ c
assert np.allclose(actual_unitary, expected_unitary, atol=1e-8)
# conj_by(op_{n-1}).conj_by(op_{n-1}).....conj_by(op_0)
conj_in_order = input_ps
for op in ops[::-1]:
assert_conjugation(conj_in_order, op)
conj_in_order = conj_in_order.conjugated_by(op)
assert conjugation == conj_in_order


def test_eq_ne_hash():
Expand Down Expand Up @@ -741,26 +775,31 @@ def test_pass_operations_over_double(shift: int, t_or_f1: bool, t_or_f2: bool, n
op0 = cirq.PauliInteractionGate(Z, t_or_f1, X, t_or_f2)(q0, q1)
ps_before = cirq.PauliString(qubit_pauli_map={q0: Z, q2: Y}, coefficient=sign)
ps_after = cirq.PauliString(qubit_pauli_map={q0: Z, q2: Y}, coefficient=sign)
assert_conjugation(ps_before, op0, ps_after, True)
_assert_pass_over([op0], ps_before, ps_after)

op0 = cirq.PauliInteractionGate(Y, t_or_f1, X, t_or_f2)(q0, q1)
ps_before = cirq.PauliString({q0: Z, q2: Y}, sign)
ps_after = cirq.PauliString({q0: Z, q2: Y, q1: X}, sign)
ps_after = cirq.PauliString({q0: Z, q2: Y, q1: X}, -sign if t_or_f2 else sign)
Copy link
Collaborator Author

@babacry babacry Mar 5, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous test case doesn't preserve unitaries of expected $P_{conj}$ and $C^\dagger P C$ for paulistring $P$ and clifford gate $C$ (which is a PauliInteractionGate here).

It's either a deeper issue from the e.g., decompose func of PauliInteractionGate or the implementation error of the old implementation of pass_operations_over.

assert_conjugation(ps_before, op0, ps_after, True)
_assert_pass_over([op0], ps_before, ps_after)

op0 = cirq.PauliInteractionGate(Z, t_or_f1, X, t_or_f2)(q0, q1)
ps_before = cirq.PauliString({q0: Z, q1: Y}, sign)
ps_after = cirq.PauliString({q1: Y}, sign)
ps_after = cirq.PauliString({q1: Y}, -sign if t_or_f1 else sign)
assert_conjugation(ps_before, op0, ps_after, True)
_assert_pass_over([op0], ps_before, ps_after)

op0 = cirq.PauliInteractionGate(Y, t_or_f1, X, t_or_f2)(q0, q1)
ps_before = cirq.PauliString({q0: Z, q1: Y}, sign)
ps_after = cirq.PauliString({q0: X, q1: Z}, -1 if neg ^ t_or_f1 ^ t_or_f2 else +1)
assert_conjugation(ps_before, op0, ps_after, True)
_assert_pass_over([op0], ps_before, ps_after)

op0 = cirq.PauliInteractionGate(X, t_or_f1, X, t_or_f2)(q0, q1)
ps_before = cirq.PauliString({q0: Z, q1: Y}, sign)
ps_after = cirq.PauliString({q0: Y, q1: Z}, +1 if neg ^ t_or_f1 ^ t_or_f2 else -1)
assert_conjugation(ps_before, op0, ps_after, True)
_assert_pass_over([op0], ps_before, ps_after)


Expand All @@ -774,7 +813,12 @@ def test_pass_operations_over_cz():

def test_pass_operations_over_no_common_qubits():
class ExampleGate(cirq.testing.SingleQubitGate):
pass

def num_qubits(self):
return 1

def _decompose_(self, qubits):
return cirq.X(qubits[0])

q0, q1 = _make_qubits(2)
op0 = ExampleGate()(q1)
Expand All @@ -786,7 +830,11 @@ class ExampleGate(cirq.testing.SingleQubitGate):
def test_pass_unsupported_operations_over():
(q0,) = _make_qubits(1)
pauli_string = cirq.PauliString({q0: cirq.X})
with pytest.raises(TypeError, match='not a known Clifford'):
with pytest.raises(
ValueError,
match='Clifford Gate can only be constructed from the operations'
' that has stabilizer effect.',
):
pauli_string.pass_operations_over([cirq.T(q0)])


Expand Down Expand Up @@ -1523,8 +1571,8 @@ def _decompose_(self, qubits):
def test_conjugated_by_move_into_uninvolved():
a, b, c, d = cirq.LineQubit.range(4)
ps = cirq.X(a) * cirq.Z(b)
assert_conjugation(ps, [cirq.SWAP(c, d), cirq.SWAP(b, c)], cirq.X(a) * cirq.Z(d))
assert_conjugation(ps, [cirq.SWAP(b, c), cirq.SWAP(c, d)], cirq.X(a) * cirq.Z(c))
assert_conjugation_multi_ops(ps, [cirq.SWAP(c, d), cirq.SWAP(b, c)], cirq.X(a) * cirq.Z(d))
assert_conjugation_multi_ops(ps, [cirq.SWAP(b, c), cirq.SWAP(c, d)], cirq.X(a) * cirq.Z(c))


def test_conjugated_by_common_single_qubit_gates():
Expand All @@ -1549,7 +1597,7 @@ def test_conjugated_by_common_single_qubit_gates():
# pauli gate on a, clifford on b: pauli gate preserves.
assert_conjugation(p(a), g(b), p(a))
# pauli gate on a, clifford on a: check conjugation in matrices.
assert_conjugation(p(a), g(a), None)
assert_conjugation(p(a), g(a))


def test_conjugated_by_common_two_qubit_gates():
Expand Down Expand Up @@ -1580,7 +1628,7 @@ def test_conjugated_by_common_two_qubit_gates():
assert_conjugation(p, g(c, d), p)
# pauli_string on (a,b), clifford on (a,b): compare unitaries of
# the conjugated_by and actual matrix conjugation.
assert_conjugation(p, g.on(a, b), None)
assert_conjugation(p, g.on(a, b))


def test_conjugated_by_ordering():
Expand All @@ -1602,7 +1650,7 @@ def _decompose_(self, qubits):

a, b = cirq.LineQubit.range(2)
inp = cirq.Z(b)
out1 = inp.pass_operations_over([OrderSensitiveGate().on(a, b)])
out1 = inp.pass_operations_over(OrderSensitiveGate().on(a, b))
out2 = inp.pass_operations_over([cirq.CNOT(a, b), cirq.Y(a) ** -0.5])
out3 = inp.pass_operations_over([cirq.CNOT(a, b)]).pass_operations_over([cirq.Y(a) ** -0.5])
assert out1 == out2 == out3 == cirq.X(a) * cirq.Z(b)
Expand All @@ -1618,7 +1666,7 @@ def _decompose_(self, qubits):

a, b = cirq.LineQubit.range(2)
inp = cirq.X(a) * cirq.Z(b)
out1 = inp.pass_operations_over([OrderSensitiveGate().on(a, b)], after_to_before=True)
out1 = inp.pass_operations_over(OrderSensitiveGate().on(a, b), after_to_before=True)
out2 = inp.pass_operations_over([cirq.Y(a) ** -0.5, cirq.CNOT(a, b)], after_to_before=True)
out3 = inp.pass_operations_over([cirq.Y(a) ** -0.5], after_to_before=True).pass_operations_over(
[cirq.CNOT(a, b)], after_to_before=True
Expand Down
Loading