diff --git a/cirq-core/cirq/transformers/connected_component.py b/cirq-core/cirq/transformers/connected_component.py new file mode 100644 index 00000000000..a905fa0cd49 --- /dev/null +++ b/cirq-core/cirq/transformers/connected_component.py @@ -0,0 +1,240 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines a connected component of operations, to be used in merge transformers.""" + +from __future__ import annotations + +from typing import Callable, cast, Sequence, TYPE_CHECKING + +from scipy.cluster.hierarchy import DisjointSet +from typing_extensions import override + +from cirq import ops, protocols + +if TYPE_CHECKING: + import cirq + + +class Component: + """Internal representation for a connected component of operations.""" + + # Circuit moment containing the component + moment_id: int + # Union of all op qubits in the component + qubits: frozenset[cirq.Qid] + # Union of all measurement keys in the component + mkeys: frozenset[cirq.MeasurementKey] + # Union of all control keys in the component + ckeys: frozenset[cirq.MeasurementKey] + # Initial operation in the component + op: cirq.Operation + + # True if the component can be merged with other components + is_mergeable: bool + + def __init__(self, op: cirq.Operation, moment_id: int, is_mergeable=True): + """Initializes a singleton component.""" + self.op = op + self.is_mergeable = is_mergeable + self.moment_id = moment_id + self.qubits = frozenset(op.qubits) + self.mkeys = protocols.measurement_key_objs(op) + self.ckeys = protocols.control_keys(op) + + +class ComponentWithOps(Component): + """Component that keeps track of operations.""" + + # List of all operations in the component + ops: list[cirq.Operation] + + def __init__(self, op: cirq.Operation, moment_id: int, is_mergeable=True): + super().__init__(op, moment_id, is_mergeable) + self.ops = [op] + + +class ComponentWithCircuitOp(Component): + """Component that keeps track of operations as a CircuitOperation.""" + + # CircuitOperation containing all the operations in the component, + # or a single Operation if the component is a singleton + circuit_op: cirq.Operation + + def __init__(self, op: cirq.Operation, moment_id: int, is_mergeable=True): + super().__init__(op, moment_id, is_mergeable) + self.circuit_op = op + + +class ComponentSet: + """Represents a set of mergeable components of operations.""" + + _comp_type: type[Component] + + _disjoint_set: DisjointSet + + # Callable to decide if a component is mergeable + _is_mergeable: Callable[[cirq.Operation], bool] + + # List of components in creation order + _components: list[Component] + + def __init__(self, is_mergeable: Callable[[cirq.Operation], bool]): + self._is_mergeable = is_mergeable + self._disjoint_set = DisjointSet() + self._components = [] + self._comp_type = Component + + def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: + """Creates a new component and adds it to the set.""" + c = self._comp_type(op, moment_id, self._is_mergeable(op) and is_mergeable) + self._disjoint_set.add(c) + self._components.append(c) + return c + + def components(self) -> list[Component]: + """Returns the initial components in creation order.""" + return self._components + + def find(self, x: Component) -> Component: + """Finds the representative for a merged component.""" + return self._disjoint_set[x] + + def merge(self, x: Component, y: Component, merge_left=True) -> Component | None: + """Attempts to merge two components. + + If merge_left is True, y is merged into x, and the representative will keep + y's moment. If merge_left is False, x is merged into y, and the representative + will keep y's moment. + + Args: + x: First component to merge. + y: Second component to merge. + merge_left: True to keep x's moment for the merged component, False to + keep y's moment for the merged component. + + Returns: + None, if the components can't be merged. + Otherwise the new component representative. + """ + x = self._disjoint_set[x] + y = self._disjoint_set[y] + + if not x.is_mergeable or not y.is_mergeable: + return None + + if not self._disjoint_set.merge(x, y): + return x + + root = self._disjoint_set[x] + root.moment_id = x.moment_id if merge_left else y.moment_id + root.qubits = x.qubits.union(y.qubits) + root.mkeys = x.mkeys.union(y.mkeys) + root.ckeys = x.ckeys.union(y.ckeys) + + return root + + +class ComponentWithOpsSet(ComponentSet): + """Represents a set of mergeable components, where each component tracks operations.""" + + # Callable that returns if two components can be merged based on their operations + _can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] + + def __init__( + self, + is_mergeable: Callable[[cirq.Operation], bool], + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], + ): + super().__init__(is_mergeable) + self._can_merge = can_merge + self._comp_type = ComponentWithOps + + @override + def merge(self, x: Component, y: Component, merge_left=True) -> Component | None: + """Attempts to merge two components. + + Returns: + None if can_merge is False or the merge doesn't succeed, otherwise the + new representative. The representative will have ops = x.ops + y.ops. + """ + x = cast(ComponentWithOps, self._disjoint_set[x]) + y = cast(ComponentWithOps, self._disjoint_set[y]) + + if x == y: + return x + + if not x.is_mergeable or not y.is_mergeable or not self._can_merge(x.ops, y.ops): + return None + + root = cast(ComponentWithOps, super().merge(x, y, merge_left)) + root.ops = x.ops + y.ops + # Clear the ops list in the non-representative component to avoid memory consumption + if x != root: + x.ops = [] + else: + y.ops = [] + return root + + +class ComponentWithCircuitOpSet(ComponentSet): + """Represents a set of mergeable components, with operations as a CircuitOperation.""" + + # Callable that merges CircuitOperations from two components + _merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] + + def __init__( + self, + is_mergeable: Callable[[cirq.Operation], bool], + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], + ): + super().__init__(is_mergeable) + self._merge_func = merge_func + self._comp_type = ComponentWithCircuitOp + + @override + def merge(self, x: Component, y: Component, merge_left=True) -> Component | None: + """Attempts to merge two components. + + Returns: + None if merge_func returns None or the merge doesn't succeed, + otherwise the new representative. + """ + x = cast(ComponentWithCircuitOp, self._disjoint_set[x]) + y = cast(ComponentWithCircuitOp, self._disjoint_set[y]) + + if x == y: + return x + + if not x.is_mergeable or not y.is_mergeable: + return None + + new_op = self._merge_func(x.circuit_op, y.circuit_op) + if not new_op: + return None + + root = cast(ComponentWithCircuitOp, super().merge(x, y, merge_left)) + + root.circuit_op = new_op + # The merge_func can be arbitrary, so we need to recompute the component properties + root.qubits = frozenset(new_op.qubits) + root.mkeys = protocols.measurement_key_objs(new_op) + root.ckeys = protocols.control_keys(new_op) + + # Clear the circuit op in the non-representative component to avoid memory consumption + if x != root: + del x.circuit_op + else: + del y.circuit_op + return root diff --git a/cirq-core/cirq/transformers/connected_component_test.py b/cirq-core/cirq/transformers/connected_component_test.py new file mode 100644 index 00000000000..ad067d60f63 --- /dev/null +++ b/cirq-core/cirq/transformers/connected_component_test.py @@ -0,0 +1,307 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import cirq +from cirq.transformers.connected_component import ( + ComponentSet, + ComponentWithCircuitOpSet, + ComponentWithOpsSet, +) + + +def test_find_returns_itself_for_singleton(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + + q = cirq.NamedQubit('x') + c = cset.new_component(op=cirq.X(q), moment_id=0) + assert cset.find(c) == c + + +def test_merge_components(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + + q = cirq.NamedQubit('x') + c = [cset.new_component(op=cirq.X(q), moment_id=i) for i in range(5)] + cset.merge(c[1], c[0]) + cset.merge(c[2], c[1]) + cset.merge(c[4], c[3]) + cset.merge(c[3], c[0]) + + for i in range(5): + assert cset.find(c[i]) == cset.find(c[0]) + + +def test_merge_same_component(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + + q = cirq.NamedQubit('x') + c = [cset.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] + cset.merge(c[1], c[0]) + cset.merge(c[2], c[1]) + + root = cset.find(c[0]) + + assert cset.merge(c[0], c[2]) == root + + +def test_merge_returns_None_if_one_component_is_not_mergeable(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + + q = cirq.NamedQubit('x') + c0 = cset.new_component(op=cirq.X(q), moment_id=0, is_mergeable=True) + c1 = cset.new_component(op=cirq.X(q), moment_id=1, is_mergeable=False) + assert cset.merge(c0, c1) is None + + +def test_cset_merge_returns_None_if_is_mergeable_is_false(): + q = cirq.NamedQubit('x') + + def is_mergeable(_: cirq.Operation) -> bool: + return False + + cset = ComponentSet(is_mergeable=is_mergeable) + + c0 = cset.new_component(op=cirq.X(q), moment_id=0, is_mergeable=True) + c1 = cset.new_component(op=cirq.X(q), moment_id=1, is_mergeable=True) + assert cset.merge(c0, c1) is None + + +def test_merge_qubits_with_merge_left_true(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = cset.new_component(op=cirq.X(q0), moment_id=0) + c1 = cset.new_component(op=cirq.X(q1), moment_id=0) + c2 = cset.new_component(op=cirq.X(q1), moment_id=1) + cset.merge(c1, c2) + cset.merge(c0, c1, merge_left=True) + assert cset.find(c1).qubits == frozenset([q0, q1]) + + +def test_merge_qubits_with_merge_left_false(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = cset.new_component(op=cirq.X(q0), moment_id=0) + c1 = cset.new_component(op=cirq.X(q0), moment_id=0) + c2 = cset.new_component(op=cirq.X(q1), moment_id=1) + cset.merge(c0, c1) + cset.merge(c1, c2, merge_left=False) + assert cset.find(c0).qubits == frozenset([q0, q1]) + + +def test_merge_moment_with_merge_left_true(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = cset.new_component(op=cirq.X(q0), moment_id=0) + c1 = cset.new_component(op=cirq.X(q1), moment_id=1) + c2 = cset.new_component(op=cirq.X(q1), moment_id=1) + cset.merge(c1, c2) + cset.merge(c0, c1, merge_left=True) + # the set representative kept c0's moment + assert cset.find(c1).moment_id == 0 + + +def test_merge_moment_with_merge_left_false(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = cset.new_component(op=cirq.X(q0), moment_id=0) + c1 = cset.new_component(op=cirq.X(q0), moment_id=0) + c2 = cset.new_component(op=cirq.X(q1), moment_id=1) + cset.merge(c0, c1) + cset.merge(c1, c2, merge_left=False) + # the set representative kept c2's moment + assert cset.find(c0).moment_id == 1 + + +def test_component_with_ops_merge(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + def can_merge(_ops1: list[cirq.Operation], _ops2: list[cirq.Operation]) -> bool: + return True + + cset = ComponentWithOpsSet(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] + + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + assert cset.find(c[0]).ops == ops + + +def test_component_with_ops_merge_same_component(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + def can_merge(_ops1: list[cirq.Operation], _ops2: list[cirq.Operation]) -> bool: + return True + + cset = ComponentWithOpsSet(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + assert cset.merge(c[0], c[2]).ops == ops + + +def test_component_with_ops_merge_when_merge_fails(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + def can_merge(_ops1: list[cirq.Operation], _ops2: list[cirq.Operation]) -> bool: + return False + + cset = ComponentWithOpsSet(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] + + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + # No merge happened + for i in range(3): + assert cset.find(c[i]) == c[i] + + +def test_component_with_ops_merge_when_is_mergeable_is_false(): + def is_mergeable(_: cirq.Operation) -> bool: + return False + + def can_merge(_ops1: list[cirq.Operation], _ops2: list[cirq.Operation]) -> bool: + return True + + cset = ComponentWithOpsSet(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] + + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + # No merge happened + for i in range(3): + assert cset.find(c[i]) == c[i] + + +def test_component_with_circuit_op_merge(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: + return op1 + + cset = ComponentWithCircuitOpSet(is_mergeable, merge_func) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] + + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + for i in range(3): + assert cset.find(c[i]).circuit_op == ops[0] + + +def test_component_with_circuit_op_merge_same_component(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: + return op1 + + cset = ComponentWithCircuitOpSet(is_mergeable, merge_func) + + q = cirq.NamedQubit('x') + c = [cset.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] + cset.merge(c[1], c[0]) + cset.merge(c[2], c[1]) + assert cset.merge(c[0], c[2]) == cset.find(c[1]) + + +def test_component_with_circuit_op_merge_func_is_none(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + def merge_func(_op1: cirq.Operation, _op2: cirq.Operation) -> None: + return None + + cset = ComponentWithCircuitOpSet(is_mergeable, merge_func) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] + + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + # No merge happened + for i in range(3): + assert cset.find(c[i]) == c[i] + + +def test_component_with_circuit_op_merge_when_is_mergeable_is_false(): + def is_mergeable(_: cirq.Operation) -> bool: + return False + + def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: + return op1 + + cset = ComponentWithCircuitOpSet(is_mergeable, merge_func) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] + + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + # No merge happened + for i in range(3): + assert cset.find(c[i]) == c[i] diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 14396764011..ba30f57f6e2 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -17,12 +17,20 @@ from __future__ import annotations import bisect +import copy import dataclasses from collections import defaultdict from typing import Callable, cast, Hashable, Sequence, TYPE_CHECKING from cirq import circuits, ops, protocols from cirq.circuits.circuit import CIRCUIT_TYPE +from cirq.transformers.connected_component import ( + Component, + ComponentSet, + ComponentWithCircuitOp, + ComponentWithCircuitOpSet, + ComponentWithOpsSet, +) if TYPE_CHECKING: import cirq @@ -282,17 +290,23 @@ def map_operations_and_unroll( @dataclasses.dataclass class _MergedCircuit: - """An optimized internal representation of a circuit, tailored for `cirq.merge_operations` + """An optimized internal representation of a circuit, tailored for merge operations. + + Operations are represented as mergeable components. + Each component has a moment id, a set of qubits, a set of measurement keys, and a set of + control keys. The moment id is the index of the moment that contains the component. Attributes: - qubit_indexes: Mapping from qubits to (sorted) list of moment indexes containing operations - acting on the qubit. - mkey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing - measurement operations with the same key. - ckey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing - classically controlled operations controlled on the same key. - ops_by_index: List of circuit moments containing operations. We use a dictionary instead - of a set to store operations to preserve insertion order. + qubit_indexes: Mapping from qubits to (sorted) list of component moments containing + operations acting on the qubit. + mkey_indexes: Mapping from measurement keys to (sorted) list of component moments + containing measurement operations with the same key. + ckey_indexes: Mapping from measurement keys to (sorted) list of component moments + containing classically controlled operations controlled on the same key. + components_by_index: List of components indexed by moment id. + For a moment id, we use a dictionary to keep track of the + components in the moment. The dictionary instead of a set is used to preserve + insertion order and the dictionary's values are intentionally unused. """ qubit_indexes: dict[cirq.Qid, list[int]] = dataclasses.field( @@ -304,54 +318,218 @@ class _MergedCircuit: ckey_indexes: dict[cirq.MeasurementKey, list[int]] = dataclasses.field( default_factory=lambda: defaultdict(lambda: [-1]) ) - ops_by_index: list[dict[cirq.Operation, int]] = dataclasses.field(default_factory=list) + components_by_index: list[dict[Component, int]] = dataclasses.field(default_factory=list) def append_empty_moment(self) -> None: - self.ops_by_index.append({}) + self.components_by_index.append({}) - def add_op_to_moment(self, moment_index: int, op: cirq.Operation) -> None: - self.ops_by_index[moment_index][op] = 0 - for q in op.qubits: - if moment_index > self.qubit_indexes[q][-1]: - self.qubit_indexes[q].append(moment_index) - else: - bisect.insort(self.qubit_indexes[q], moment_index) - for mkey in protocols.measurement_key_objs(op): - bisect.insort(self.mkey_indexes[mkey], moment_index) - for ckey in protocols.control_keys(op): - bisect.insort(self.ckey_indexes[ckey], moment_index) - - def remove_op_from_moment(self, moment_index: int, op: cirq.Operation) -> None: - self.ops_by_index[moment_index].pop(op) - for q in op.qubits: - if self.qubit_indexes[q][-1] == moment_index: - self.qubit_indexes[q].pop() - else: - self.qubit_indexes[q].remove(moment_index) - for mkey in protocols.measurement_key_objs(op): - self.mkey_indexes[mkey].remove(moment_index) - for ckey in protocols.control_keys(op): - self.ckey_indexes[ckey].remove(moment_index) - - def get_mergeable_ops( - self, op: cirq.Operation, op_qs: set[cirq.Qid] - ) -> tuple[int, list[cirq.Operation]]: - # Find the index of previous moment which can be merged with `op`. - idx = max([self.qubit_indexes[q][-1] for q in op_qs], default=-1) - idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in protocols.control_keys(op)]) - idx = max( - [idx] + [self.ckey_indexes[mkey][-1] for mkey in protocols.measurement_key_objs(op)] - ) - # Return the set of overlapping ops in moment with index `idx`. + def add_moment(self, index: list[int], moment_id: int) -> None: + """Adds a moment to a sorted list of moment indexes. + + Optimized for the majority case when the new moment is higher than any moment in the list. + """ + if index[-1] < moment_id: + index.append(moment_id) + else: + bisect.insort(index, moment_id) + + def remove_moment(self, index: list[int], moment_id: int) -> None: + """Removes a moment from a sorted list of moment indexes. + + Optimized for the majority case when the moment is last in the list. + """ + if index[-1] == moment_id: + index.pop() + else: + index.remove(moment_id) + + def add_component(self, c: Component) -> None: + """Adds a new components to merged circuit.""" + self.components_by_index[c.moment_id][c] = 0 + for q in c.qubits: + self.add_moment(self.qubit_indexes[q], c.moment_id) + for mkey in c.mkeys: + self.add_moment(self.mkey_indexes[mkey], c.moment_id) + for ckey in c.ckeys: + self.add_moment(self.ckey_indexes[ckey], c.moment_id) + + def remove_component(self, c: Component, c_data: Component) -> None: + """Removes a component from the merged circuit. + + Args: + c: Reference to the component to be removed. + c_data: Copy of the data in c before any component merges involving c + (this is necessary as component merges alter the component data). + """ + self.components_by_index[c_data.moment_id].pop(c) + for q in c_data.qubits: + self.remove_moment(self.qubit_indexes[q], c_data.moment_id) + for mkey in c_data.mkeys: + self.remove_moment(self.mkey_indexes[mkey], c_data.moment_id) + for ckey in c_data.ckeys: + self.remove_moment(self.ckey_indexes[ckey], c_data.moment_id) + + def get_mergeable_components(self, c: Component, c_qs: set[cirq.Qid]) -> list[Component]: + """Finds all components that can be merged with c. + + Args: + c: Component to be merged with existing components. + c_qs: Subset of c.qubits used to decide which components are mergeable. + + Returns: + List of mergeable components. + """ + # Find the index of previous moment which can be merged with `c`. + idx = max([self.qubit_indexes[q][-1] for q in c_qs], default=-1) + idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in c.ckeys]) + idx = max([idx] + [self.ckey_indexes[mkey][-1] for mkey in c.mkeys]) + # Return the set of overlapping components in moment with index `idx`. if idx == -1: - return idx, [] + return [] + + return [c for c in self.components_by_index[idx] if not c_qs.isdisjoint(c.qubits)] + + def get_cirq_circuit(self, cset: ComponentSet, merged_circuit_op_tag: str) -> cirq.Circuit: + """Returns the merged circuit. + + Args: + cset: Disjoint set data structure containing the components. + merged_circuit_op_tag: Tag to use for CircuitOperations. + + Returns: + The circuit with merged components as a CircuitOperation. + """ + component_ops: dict[Component, list[cirq.Operation]] = defaultdict(list) + + # Traverse the components in creation order and collect operations + for c in cset.components(): + root = cset.find(c) + component_ops[root].append(c.op) + + moments = [] + for m in self.components_by_index: + ops = [] + for c in m.keys(): + if isinstance(c, ComponentWithCircuitOp): + ops.append(c.circuit_op) + continue + if len(component_ops[c]) == 1: + ops.append(component_ops[c][0]) + else: + ops.append( + circuits.CircuitOperation( + circuits.FrozenCircuit(component_ops[c]) + ).with_tags(merged_circuit_op_tag) + ) + moments.append(circuits.Moment(ops)) + return circuits.Circuit(moments) + + +def _merge_operations_impl( + circuit: CIRCUIT_TYPE, + cset: ComponentSet, + *, + merged_circuit_op_tag: str = "Merged connected component", + tags_to_ignore: Sequence[Hashable] = (), + deep: bool = False, +) -> CIRCUIT_TYPE: + """Merges operations in a circuit. + + Two operations op1 and op2 are merge-able if + - There is no other operations between op1 and op2 in the circuit + - is_subset(op1.qubits, op2.qubits) or is_subset(op2.qubits, op1.qubits) + + The method iterates on the input circuit moment-by-moment from left to right and attempts + to repeatedly merge each operation in the latest moment with all the corresponding merge-able + operations to its left. + + Operations are wrapped in a component and then cset.merge() is called to merge two + components. + + If op1 and op2 are merged, both op1 and op2 are deleted from the circuit and + the merged component is inserted at the index corresponding to the larger + of op1/op2. If both op1 and op2 act on the same number of qubits, the merged component is + inserted in the smaller moment index to minimize circuit depth. + + At the end every component with more than one operation is replaced by a CircuitOperation. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + cset: Disjoint set data structure that is used to create and merge components. + merged_circuit_op_tag: Tag used for CircuitOperations created from merged components. + tags_to_ignore: Sequence of tags which should be ignored during the merge: operations with + these tags will not be merged. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. + + + Returns: + Copy of input circuit with merged operations. + """ + tags_to_ignore_set = set(tags_to_ignore) + + merged_circuit = _MergedCircuit() + for moment_idx, current_moment in enumerate(cast(list['cirq.Moment'], circuit)): + merged_circuit.append_empty_moment() + for op in sorted(current_moment.operations, key=lambda op: op.qubits): + if ( + deep + and isinstance(op.untagged, circuits.CircuitOperation) + and tags_to_ignore_set.isdisjoint(op.tags) + ): + op_untagged = op.untagged + merged_op = op_untagged.replace( + circuit=_merge_operations_impl( + op_untagged.circuit, + cset, + merged_circuit_op_tag=merged_circuit_op_tag, + tags_to_ignore=tags_to_ignore, + deep=True, + ) + ).with_tags(*op.tags) + c = cset.new_component(merged_op, moment_idx, is_mergeable=False) + merged_circuit.add_component(c) + continue - return idx, [ - left_op for left_op in self.ops_by_index[idx] if not op_qs.isdisjoint(left_op.qubits) - ] + c = cset.new_component( + op, moment_idx, is_mergeable=tags_to_ignore_set.isdisjoint(op.tags) + ) + if not c.is_mergeable: + merged_circuit.add_component(c) + continue - def get_cirq_circuit(self) -> cirq.Circuit: - return circuits.Circuit(circuits.Moment(m.keys()) for m in self.ops_by_index) + c_qs = set(c.qubits) + left_comp = merged_circuit.get_mergeable_components(c, c_qs) + if len(left_comp) == 1 and c_qs.issubset(left_comp[0].qubits): + # Make a shallow copy of the left component data before merge + left_c_data = copy.copy(left_comp[0]) + # Case-1: Try to merge c with the larger component on the left. + new_comp = cset.merge(left_comp[0], c, merge_left=True) + if new_comp is not None: + merged_circuit.remove_component(left_comp[0], left_c_data) + merged_circuit.add_component(new_comp) + else: + merged_circuit.add_component(c) + continue + + while left_comp and c_qs: + # Case-2: left_c will merge right into `c` whenever possible. + for left_c in left_comp: + is_merged = False + if c_qs.issuperset(left_c.qubits): + # Make a shallow copy of the left component data before merge + left_c_data = copy.copy(left_c) + # Try to merge left_c into c + new_comp = cset.merge(left_c, c, merge_left=False) + if new_comp is not None: + merged_circuit.remove_component(left_c, left_c_data) + c, is_merged = new_comp, True + if not is_merged: + c_qs -= left_c.qubits + left_comp = merged_circuit.get_mergeable_components(c, c_qs) + merged_circuit.add_component(c) + ret_circuit = merged_circuit.get_cirq_circuit(cset, merged_circuit_op_tag) + return _to_target_circuit_type(ret_circuit, circuit) def merge_operations( @@ -407,12 +585,8 @@ def merge_operations( ValueError if the merged operation acts on new qubits outside the set of qubits corresponding to the original operations to be merged. """ - _circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit" - tags_to_ignore_set = set(tags_to_ignore) | {_circuit_op_tag} def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> ops.Operation | None: - if not all(tags_to_ignore_set.isdisjoint(op.tags) for op in [op1, op2]): - return None new_op = merge_func(op1, op2) qubit_set = frozenset(op1.qubits + op2.qubits) if new_op is not None and not qubit_set.issuperset(new_op.qubits): @@ -422,63 +596,15 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> ops.Operation | ) return new_op - merged_circuit = _MergedCircuit() - for moment_idx, current_moment in enumerate(cast(list['cirq.Moment'], circuit)): - merged_circuit.append_empty_moment() - for op in sorted(current_moment.operations, key=lambda op: op.qubits): - if ( - deep - and isinstance(op.untagged, circuits.CircuitOperation) - and tags_to_ignore_set.isdisjoint(op.tags) - ): - op_untagged = op.untagged - merged_circuit.add_op_to_moment( - moment_idx, - op_untagged.replace( - circuit=merge_operations( - op_untagged.circuit, - merge_func, - tags_to_ignore=tags_to_ignore, - deep=True, - ) - ).with_tags(*op.tags, _circuit_op_tag), - ) - continue - - op_qs = set(op.qubits) - left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs) - if len(left_ops) == 1 and op_qs.issubset(left_ops[0].qubits): - # Case-1: Try to merge op with the larger operation on the left. - new_op = apply_merge_func(left_ops[0], op) - if new_op is not None: - merged_circuit.remove_op_from_moment(left_idx, left_ops[0]) - merged_circuit.add_op_to_moment(left_idx, new_op) - else: - merged_circuit.add_op_to_moment(moment_idx, op) - continue + def is_mergeable(_: cirq.Operation): + return True - while left_ops and op_qs: - # Case-2: left_ops will merge right into `op` whenever possible. - for left_op in left_ops: - is_merged = False - if op_qs.issuperset(left_op.qubits): - # Try to merge left_op into op - new_op = apply_merge_func(left_op, op) - if new_op is not None: - merged_circuit.remove_op_from_moment(left_idx, left_op) - op, is_merged = new_op, True - if not is_merged: - op_qs -= frozenset(left_op.qubits) - left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs) - merged_circuit.add_op_to_moment(moment_idx, op) - ret_circuit = merged_circuit.get_cirq_circuit() - if deep: - ret_circuit = map_operations( - ret_circuit, - lambda o, _: o.untagged.with_tags(*(set(o.tags) - {_circuit_op_tag})), - deep=True, - ) - return _to_target_circuit_type(ret_circuit, circuit) + return _merge_operations_impl( + circuit, + ComponentWithCircuitOpSet(is_mergeable, apply_merge_func), + tags_to_ignore=tags_to_ignore, + deep=deep, + ) def merge_operations_to_circuit_op( @@ -491,10 +617,9 @@ def merge_operations_to_circuit_op( ) -> CIRCUIT_TYPE: """Merges connected components of operations and wraps each component into a circuit operation. - Uses `cirq.merge_operations` to identify connected components of operations. Moment structure - is preserved for operations that do not participate in merging. For merged operations, the - newly created circuit operations are constructed by inserting operations using EARLIEST - strategy. + Moment structure is preserved for operations that do not participate in merging. + For merged operations, the newly created circuit operations are constructed by inserting + operations using EARLIEST strategy. If you need more control on moment structure of newly created circuit operations, consider using `cirq.merge_operations` directly with a custom `merge_func`. @@ -514,24 +639,16 @@ def merge_operations_to_circuit_op( Copy of input circuit with valid connected components wrapped in tagged circuit operations. """ - def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation | None: - def get_ops(op: cirq.Operation): - op_untagged = op.untagged - return ( - [*op_untagged.circuit.all_operations()] - if isinstance(op_untagged, circuits.CircuitOperation) - and merged_circuit_op_tag in op.tags - else [op] - ) - - left_ops, right_ops = get_ops(op1), get_ops(op2) - if not can_merge(left_ops, right_ops): - return None - return circuits.CircuitOperation(circuits.FrozenCircuit(left_ops, right_ops)).with_tags( - merged_circuit_op_tag - ) + def is_mergeable(_: cirq.Operation): + return True - return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=deep) + return _merge_operations_impl( + circuit, + ComponentWithOpsSet(is_mergeable, can_merge), + merged_circuit_op_tag=merged_circuit_op_tag, + tags_to_ignore=tags_to_ignore, + deep=deep, + ) def merge_k_qubit_unitaries_to_circuit_op( @@ -544,10 +661,9 @@ def merge_k_qubit_unitaries_to_circuit_op( ) -> CIRCUIT_TYPE: """Merges connected components of operations, acting on <= k qubits, into circuit operations. - Uses `cirq.merge_operations_to_circuit_op` to identify and merge connected components of - unitary operations acting on at-most k-qubits. Moment structure is preserved for operations - that do not participate in merging. For merged operations, the newly created circuit operations - are constructed by inserting operations using EARLIEST strategy. + Moment structure is preserved for operations that do not participate in merging. + For merged operations, the newly created circuit operations are constructed by inserting + operations using EARLIEST strategy. Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. @@ -563,18 +679,14 @@ def merge_k_qubit_unitaries_to_circuit_op( Copy of input circuit with valid connected components wrapped in tagged circuit operations. """ - def can_merge(ops1: Sequence[cirq.Operation], ops2: Sequence[cirq.Operation]) -> bool: - return all( - protocols.num_qubits(op) <= k and protocols.has_unitary(op) - for op_list in [ops1, ops2] - for op in op_list - ) + def is_mergeable(op: cirq.Operation): + return protocols.num_qubits(op) <= k and protocols.has_unitary(op) - return merge_operations_to_circuit_op( + return _merge_operations_impl( circuit, - can_merge, - tags_to_ignore=tags_to_ignore, + ComponentSet(is_mergeable), merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.", + tags_to_ignore=tags_to_ignore, deep=deep, ) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index e1152b60aff..eddc866266c 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -877,3 +877,187 @@ def merge_func(op1, op2): cirq.testing.assert_same_circuits( cirq.align_left(cirq.merge_operations(circuit, merge_func)), expected_circuit ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_absorbs_overlapping_2q_gates(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.Moment( + cirq.H(q[0]).with_tags("ignore"), + cirq.H(q[1]).with_tags("ignore"), + cirq.H(q[2]).with_tags("ignore"), + ), + cirq.Moment(cirq.CNOT(q[0], q[2]), cirq.X(q[1]).with_tags("ignore")), + cirq.CNOT(q[0], q[1]), + cirq.CNOT(q[1], q[2]), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' + ┌──────────┐ +0: ───H[ignore]────@─────────────@───────@─── + │ │ │ +1: ───H[ignore]────┼X[ignore]────X───@───@─── + │ │ │ +2: ───H[ignore]────X─────────────────X───@─── + └──────────┘ +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@───────@─── ] + [ │ │ │ ] +0: ───H[ignore]───────────────[ 1: ───┼───X───@───@─── ]─────────── + [ │ │ │ ] + [ 2: ───X───────X───@─── ][merged] + │ +1: ───H[ignore]───X[ignore]───#2─────────────────────────────────── + │ +2: ───H[ignore]───────────────#3─────────────────────────────────── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_absorbs_disjoint_gates(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2])), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───@─── + │ │ +1: ───X───@─── + │ +2: ───X───@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@─── ] + [ │ │ ] +0: ───[ 1: ───X───@─── ]─────────── + [ │ ] + [ 2: ───X───@─── ][merged] + │ +1: ───#2─────────────────────────── + │ +2: ───#3─────────────────────────── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_doesnt_absorb_unmergeable_gate(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.CCZ(*q), + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2]).with_tags("ignore")), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───@───────────@─── + │ │ │ +1: ───@───X───────────@─── + │ │ +2: ───@───X[ignore]───@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@─── ] + [ │ │ ] +0: ───[ 1: ───@───X─── ]───────────────────────@─── + [ │ ] │ + [ 2: ───@─────── ][merged] │ + │ │ +1: ───#2───────────────────────────────────────@─── + │ │ +2: ───#3───────────────────────────X[ignore]───@─── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_prefer_to_merge_into_earlier_op(): + q = cirq.LineQubit.range(6) + c_orig = cirq.Circuit( + cirq.Moment( + cirq.CCZ(*q[0:3]), cirq.X(q[3]), cirq.H(q[4]), cirq.H(q[5]).with_tags("ignore") + ), + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2]).with_tags("ignore"), cirq.CCZ(*q[3:6])), + cirq.Moment( + cirq.X(q[0]), + cirq.X(q[1]), + cirq.X(q[2]), + cirq.X(q[3]).with_tags("ignore"), + cirq.CNOT(*q[4:6]), + ), + cirq.Moment(cirq.CCZ(*q[0:3]), cirq.CCZ(*q[3:6])), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───────────@───────────X───────────@─── + │ │ │ +1: ───@───────────X───────────X───────────@─── + │ │ +2: ───@───────────X[ignore]───X───────────@─── + +3: ───X───────────@───────────X[ignore]───@─── + │ │ +4: ───H───────────@───────────@───────────@─── + │ │ │ +5: ───H[ignore]───@───────────X───────────@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@───X─── ] [ 0: ───────@─── ] + [ │ │ ] [ │ ] +0: ───[ 1: ───@───X───X─── ]────────────────────────────────────────────────────────[ 1: ───────@─── ]─────────── + [ │ ] [ │ ] + [ 2: ───@─────────── ][merged] [ 2: ───X───@─── ][merged] + │ │ +1: ───#2────────────────────────────────────────────────────────────────────────────#2─────────────────────────── + │ │ +2: ───#3───────────────────────────────X[ignore]────────────────────────────────────#3─────────────────────────── + + [ 3: ───X───@─────── ] + [ │ ] +3: ────────────────────────────────────[ 4: ───H───@───@─── ]───────────X[ignore]───@──────────────────────────── + [ │ │ ] │ + [ 5: ───────@───X─── ][merged] │ + │ │ +4: ────────────────────────────────────#2───────────────────────────────────────────@──────────────────────────── + │ │ +5: ───H[ignore]────────────────────────#3───────────────────────────────────────────@──────────────────────────── +''', # noqa: E501 + )