Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,7 @@ class Circuit(AbstractCircuit):
* batch_remove
* batch_insert_into
* insert_at_frontier
* reverse

Circuits can also be iterated over,

Expand Down Expand Up @@ -2525,6 +2526,16 @@ def clear_operations_touching(
self._moments[k] = self._moments[k].without_operations_touching(qubits)
self._mutated()

def reverse(self) -> None:
"""Reverses the moments in the circuit, and the operations in the moments."""
# Work on a copy in case validation fails halfway through.
copy = self.copy()
backwards = []
for moment in copy[::-1]:
backwards.append(Moment(reversed(moment.operations)))
self._moments = backwards
self._mutated()

@property
def moments(self) -> Sequence[cirq.Moment]:
return self._moments
Expand Down
173 changes: 173 additions & 0 deletions cirq-core/cirq/circuits/circuit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1868,6 +1868,179 @@ def test_clear_operations_touching() -> None:
)


def test_reverse_empty_circuit():
circuit = cirq.Circuit()
circuit.reverse()
assert len(circuit) == 0
assert circuit == cirq.Circuit()


def test_reverse_single_moment_single_operation():
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit(cirq.X(q))
original_str = str(circuit)

circuit.reverse()

assert str(circuit) == original_str
assert len(circuit) == 1


def test_reverse_single_moment_multiple_operations():
"""Test reversing a circuit with one moment and multiple operations."""
q0, q1, q2 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)
original_ops = [cirq.X(q0), cirq.Y(q1), cirq.Z(q2)]
circuit = cirq.Circuit(cirq.Moment(original_ops))

circuit.reverse()

# Moment order unchanged (only one moment), but operations reversed
assert len(circuit) == 1
reversed_ops = list(circuit[0])
assert reversed_ops == list(reversed(original_ops))


def test_reverse_multiple_moments_single_operations():
"""Test reversing a circuit with multiple moments, each with single operations."""
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit([cirq.Moment([cirq.X(q)]), cirq.Moment([cirq.Y(q)]), cirq.Moment([cirq.Z(q)])])

original_moments = [str(moment) for moment in circuit]
circuit.reverse()

# Moments should be reversed
assert len(circuit) == 3
reversed_moments = [str(moment) for moment in circuit]
assert reversed_moments == list(reversed(original_moments))

def test_reverse_multiple_moments_multiple_operations():
"""Test reversing a circuit with multiple moments and multiple operations."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
circuit = cirq.Circuit(
[
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.Z(q0), cirq.H(q1)]),
cirq.Moment([cirq.S(q0), cirq.T(q1)])
]
)

# Store original structure
original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Check that moments are reversed and operations within each moment are reversed
assert len(circuit) == 3

# First moment should be the reversed last moment
expected_first = list(reversed(original_structure[2]))
actual_first = list(circuit[0])
assert actual_first == expected_first

# Second moment should be the reversed middle moment
expected_second = list(reversed(original_structure[1]))
actual_second = list(circuit[1])
assert actual_second == expected_second

# Third moment should be the reversed first moment
expected_third = list(reversed(original_structure[0]))
actual_third = list(circuit[2])
assert actual_third == expected_third


def test_reverse_twice_returns_original():
"""Test that reversing twice returns the original circuit."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
original_circuit = cirq.Circuit([
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.Z(q0)]),
cirq.Moment([cirq.H(q0), cirq.S(q1)])
]
)

# Make a copy to compare against
expected = original_circuit.copy()

# Reverse twice
original_circuit.reverse()
original_circuit.reverse()

# Should be back to original
assert original_circuit == expected


def test_reverse_with_measurements():
"""Test reversing a circuit with measurement operations."""
q0, q1 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1)
circuit = cirq.Circuit(
[
cirq.Moment([cirq.X(q0), cirq.Y(q1)]),
cirq.Moment([cirq.measure(q0, key='a'), cirq.measure(q1, key='b')])
]
)

original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Check structure is properly reversed
assert len(circuit) == 2

# First moment should be reversed measurements
actual_first = list(circuit[0])
assert len(actual_first) == 2
assert all(isinstance(op.gate, cirq.MeasurementGate) for op in actual_first)

# Second moment should be reversed X, Y gates
actual_second = list(circuit[1])
assert len(actual_second) == 2


def test_reverse_with_two_qubit_gates():
"""Test reversing a circuit with two-qubit gates."""
q0, q1, q2 = cirq.GridQubit(0, 0), cirq.GridQubit(0, 1), cirq.GridQubit(0, 2)
circuit = cirq.Circuit(
[
cirq.Moment([cirq.CNOT(q0, q1), cirq.X(q2)]),
cirq.Moment([cirq.CZ(q1, q2)]),
cirq.Moment([cirq.SWAP(q0, q2), cirq.Y(q1)])
]
)

original_structure = []
for moment in circuit:
original_structure.append(list(moment.operations))

circuit.reverse()

# Verify the structure is correctly reversed
assert len(circuit) == 3

# Check that two-qubit gates are preserved correctly
for i, moment in enumerate(circuit):
expected_ops = list(reversed(original_structure[2 - i]))
actual_ops = list(moment.operations)
assert actual_ops == expected_ops


def test_reverse_modifies_original_circuit():
"""Test that reverse() modifies the original circuit in-place."""
q = cirq.GridQubit(0, 0)
circuit = cirq.Circuit([cirq.Moment([cirq.X(q)]), cirq.Moment([cirq.Y(q)])])

original_id = id(circuit)
circuit.reverse()

# Should be the same object
assert id(circuit) == original_id

# But content should be different
assert str(circuit[0]) != "X(q(0, 0))" # First moment is now Y

@pytest.mark.parametrize('circuit_cls', [cirq.Circuit, cirq.FrozenCircuit])
def test_all_qubits(circuit_cls) -> None:
a = cirq.NamedQubit('a')
Expand Down
8 changes: 6 additions & 2 deletions cirq-core/cirq/transformers/stratify.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import itertools
from typing import Callable, Iterable, Sequence, TYPE_CHECKING, Union

import copy

from cirq import _import, circuits, ops, protocols
from cirq.transformers import transformer_api

Expand Down Expand Up @@ -69,7 +71,8 @@ def stratified_circuit(
# Try the algorithm with each permutation of the classifiers.
smallest_depth = protocols.num_qubits(circuit) * len(circuit) + 1
shortest_stratified_circuit = circuits.Circuit()
reversed_circuit = circuit[::-1]
reversed_circuit = copy.deepcopy(circuit)
reversed_circuit.reverse()
for ordered_classifiers in itertools.permutations(classifiers):
solution = _stratify_circuit(
circuit,
Expand All @@ -87,7 +90,8 @@ def stratified_circuit(
reversed_circuit,
classifiers=ordered_classifiers,
context=context or transformer_api.TransformerContext(),
)[::-1]
)
solution.reverse()
if len(solution) < smallest_depth:
shortest_stratified_circuit = solution
smallest_depth = len(solution)
Expand Down
Loading