From 6a51439989d912d013f35fc06c18f36ffc3874b2 Mon Sep 17 00:00:00 2001 From: Codrut Date: Sun, 13 Jul 2025 11:43:07 +0200 Subject: [PATCH 1/7] Use an efficient representation for connected components of operations in the merge methods. Currently merge_operations represents merged components as a CircuitOperation. This means in a merge of n operations, n-1 CircuitOperations are created, with a complexity of O(n^2). We use a disjont-set data structure to reduce the complexity to O(n) for merge_k_qubit_unitaries_to_circuit_op. merge_operations itself can't be improved because it uses a merge_func that requires creation of CircuitOperation at every step. --- .../cirq/transformers/connected_component.py | 272 ++++++++++++ .../transformers/connected_component_test.py | 212 +++++++++ .../transformers/transformer_primitives.py | 404 +++++++++++------- .../transformer_primitives_test.py | 184 ++++++++ 4 files changed, 927 insertions(+), 145 deletions(-) create mode 100644 cirq-core/cirq/transformers/connected_component.py create mode 100644 cirq-core/cirq/transformers/connected_component_test.py diff --git a/cirq-core/cirq/transformers/connected_component.py b/cirq-core/cirq/transformers/connected_component.py new file mode 100644 index 00000000000..ef4089c07f8 --- /dev/null +++ b/cirq-core/cirq/transformers/connected_component.py @@ -0,0 +1,272 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines a connected component of operations, to be used in merge transformers.""" + +from __future__ import annotations + +from typing import Callable, cast, Sequence, TYPE_CHECKING + +from cirq import ops, protocols + +if TYPE_CHECKING: + import cirq + +class Component: + """Internal representation for a connected component of operations. + + It uses the disjoint-set data structure to implement merge efficiently. + Additional merge conditions can be added by deriving from the Component + class and overriding the merge function (see ComponentWithOps and + ComponentWithCircuitOp) below. + """ + + # Properties for the disjoint set data structure + parent: Component|None = None + rank: int = 0 + + # True if the component can be merged + is_mergeable: bool + + # Circuit moment containing the component + moment: int + # Union of all op qubits in the component + qubits: frozenset[cirq.Qid] + # Union of all measurement keys in the component + mkeys: frozenset[cirq.MeasurementKey] + # Union of all control keys in the component + ckeys: frozenset[cirq.MeasurementKey] + # Initial operation in the component + op: cirq.Operation + + def __init__(self, op: cirq.Operation, moment: int, is_mergeable = True): + """Initializes a singleton component.""" + self.is_mergeable = is_mergeable + self.moment = moment + self.qubits = frozenset(op.qubits) + self.mkeys = protocols.measurement_key_objs(op) + self.ckeys = protocols.control_keys(op) + self.op = op + + def find(self) -> Component: + """Finds the component representative.""" + + root = self + while root.parent != None: + root = root.parent + x = self + while x != root: + parent = x.parent + x.parent = root + x = cast(Component, parent) + return root + + def merge(self, c: Component, merge_left = True) -> Component|None: + """Attempts to merge two components. + + We assume the following is true whenever merge is called: + - if merge_left = True then c.qubits are a subset of self.qubits + - if merge_left = False then self.qubits are a subset of c.qubits + + If merge_left is True, c is merged into this component, and the representative + will keep this moment and qubits. If merge_left is False, this component is + merged into c, and the representative will keep c's moment and qubits. + + Args: + c: other component to merge + merge_left: True to keep self's data for the merged component, False to + keep c's data for the merged component. + + Returns: + None, if the components can't be merged. + Otherwise the new component representative. + """ + x = self.find() + y = c.find() + + if not x.is_mergeable or not y.is_mergeable: + return None + + if x == y: + return x + + if x.rank < y.rank: + if merge_left: + # As y will be the new representative, copy moment and qubits from x + y.moment = x.moment + y.qubits = x.qubits + x, y = y, x + elif not merge_left: + # As x will be the new representative, copy moment and qubits from y + x.moment = y.moment + x.qubits = y.qubits + + y.parent = x + if x.rank == y.rank: + x.rank += 1 + + x.mkeys = x.mkeys.union(y.mkeys) + x.ckeys = x.ckeys.union(y.ckeys) + return x + + +class ComponentWithOps(Component): + """Component that keeps track of operations. + + Encapsulates a method can_merge that is used to decide if two components + can be merged. + """ + + # List of all operations in the component + ops: list[cirq.Operation] + + # Method to decide if two components can be merged based on their operations + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] + + def __init__(self, op: cirq.Operation, moment: int, + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], + is_mergeable = True): + super().__init__(op, moment, is_mergeable) + self.ops = [op] + self.can_merge = can_merge + + def merge(self, c: Component, merge_left = True) -> Component|None: + """Attempts to merge two components. + + Returns: + None if can_merge is False, otherwise the new representative. + The representative will have ops = a.ops + b.ops. + """ + x = cast(ComponentWithOps, self.find()) + y = cast(ComponentWithOps, c.find()) + + if x == y: + return x + + if not x.is_mergeable or not y.is_mergeable or not x.can_merge(x.ops, y.ops): + return None + + root = cast(ComponentWithOps, super(ComponentWithOps, x).merge(y, merge_left)) + if not root: + return None + root.ops = x.ops + y.ops + # Clear the ops list in the non-representative set to avoid memory consumption + if x != root: + x.ops = [] + else: + y.ops = [] + return root + + +class ComponentWithCircuitOp(Component): + """Component that keeps track of operations as a CircuitOperation. + + Encapsulates a method merge_func that is used to merge two components. + """ + + # CircuitOperation containing all the operations in the component, + # or a single Operation if the component is a singleton + circuit_op: cirq.Operation + + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] + + def __init__(self, op: cirq.Operation, moment: int, + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], + is_mergeable = True): + super().__init__(op, moment, is_mergeable) + self.circuit_op = op + self.merge_func = merge_func + + def merge(self, c: Component, merge_left = True) -> Component|None: + """Attempts to merge two components. + + Returns: + None if merge_func returns None, otherwise the new representative. + """ + x = cast(ComponentWithCircuitOp, self.find()) + y = cast(ComponentWithCircuitOp, c.find()) + + if x == y: + return x + + if not x.is_mergeable or not y.is_mergeable: + return None + + new_op = x.merge_func(x.circuit_op, y.circuit_op) + if not new_op: + return None + + root = cast(ComponentWithCircuitOp, super(ComponentWithCircuitOp, x).merge(y, merge_left)) + if not root: + return None + + root.circuit_op = new_op + # The merge_func can be arbitrary, so we need to recompute the component properties + root.qubits = frozenset(new_op.qubits) + root.mkeys = protocols.measurement_key_objs(new_op) + root.ckeys = protocols.control_keys(new_op) + + # Clear the circuit op in the non-representative set to avoid memory consumption + if x != root: + del x.circuit_op + else: + del y.circuit_op + return root + + +class ComponentFactory: + """Factory for components.""" + + is_mergeable: Callable[[cirq.Operation], bool] + + def __init__(self, + is_mergeable: Callable[[cirq.Operation], bool]): + self.is_mergeable = is_mergeable + + def new_component(self, op: cirq.Operation, moment: int, is_mergeable = True) -> Component: + return Component(op, moment, self.is_mergeable(op) and is_mergeable) + + +class ComponentWithOpsFactory(ComponentFactory): + """Factory for components with operations.""" + + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] + + def __init__(self, + is_mergeable: Callable[[cirq.Operation], bool], + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool]): + super().__init__(is_mergeable) + self.can_merge = can_merge + + def new_component(self, op: cirq.Operation, moment: int, is_mergeable = True) -> Component: + return ComponentWithOps(op, moment, self.can_merge, self.is_mergeable(op) and is_mergeable) + + +class ComponentWithCircuitOpFactory(ComponentFactory): + """Factory for components with operations as CircuitOperation.""" + + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] + + def __init__(self, + is_mergeable: Callable[[cirq.Operation], bool], + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None]): + super().__init__(is_mergeable) + self.merge_func = merge_func + + def new_component(self, op: cirq.Operation, moment: int, is_mergeable=True) -> Component: + return ComponentWithCircuitOp(op, moment, self.merge_func, self.is_mergeable(op) and is_mergeable) + + + + diff --git a/cirq-core/cirq/transformers/connected_component_test.py b/cirq-core/cirq/transformers/connected_component_test.py new file mode 100644 index 00000000000..aaf80c3b00c --- /dev/null +++ b/cirq-core/cirq/transformers/connected_component_test.py @@ -0,0 +1,212 @@ +# Copyright 2025 The Cirq Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import cirq +from cirq.transformers.connected_component import ( + Component, ComponentWithOps, ComponentWithCircuitOp, ComponentFactory, ComponentWithOpsFactory, ComponentWithCircuitOpFactory +) + + +def test_find_returns_itself_for_singleton(): + q = cirq.NamedQubit('x') + c = Component(op=cirq.X(q), moment=0) + assert c.find() == c + + +def test_merge_components(): + q = cirq.NamedQubit('x') + c = [Component(op=cirq.X(q), moment=i) for i in range(5)] + c[1].merge(c[0]) + c[2].merge(c[1]) + c[4].merge(c[3]) + c[3].merge(c[0]) + # Disjoint set structure: + # c[4] + # / \ + # c[1] c[3] + # / \ + # c[0] c[2] + assert c[0].parent == c[1] + assert c[2].parent == c[1] + assert c[1].parent == c[4] + assert c[3].parent == c[4] + + for i in range(5): + assert c[i].find() == c[4] + # Find() compressed all paths + for i in range(4): + assert c[i].parent == c[4] + + +def test_merge_returns_None_if_one_component_is_not_mergeable(): + q = cirq.NamedQubit('x') + c0 = Component(op=cirq.X(q), moment=0, is_mergeable=True) + c1 = Component(op=cirq.X(q), moment=1, is_mergeable=False) + assert c0.merge(c1) == None + + +def test_factory_merge_returns_None_if_is_mergeable_is_false(): + q = cirq.NamedQubit('x') + + def is_mergeable(op: cirq.Operation) -> bool: + del op + return False + + factory = ComponentFactory(is_mergeable=is_mergeable) + c0 = factory.new_component(op=cirq.X(q), moment=0, is_mergeable=True) + c1 = factory.new_component(op=cirq.X(q), moment=1, is_mergeable=True) + assert c0.merge(c1) == None + + +def test_merge_qubits_with_merge_left_true(): + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = Component(op=cirq.X(q0), moment=0) + c1 = Component(op=cirq.X(q1), moment=0) + c2 = Component(op=cirq.X(q1), moment=1) + c1.merge(c2) + c0.merge(c1, merge_left=True) + assert c0.find() == c1 + # c1 is the set representative but kept c0's qubits + assert c1.qubits == frozenset([q0]) + + +def test_merge_qubits_with_merge_left_false(): + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = Component(op=cirq.X(q0), moment=0) + c1 = Component(op=cirq.X(q0), moment=0) + c2 = Component(op=cirq.X(q1), moment=1) + c0.merge(c1) + c1.merge(c2, merge_left=False) + assert c2.find() == c0 + # c0 is the set representative but kept c2's qubits + assert c0.qubits == frozenset([q1]) + + +def test_merge_moment_with_merge_left_true(): + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = Component(op=cirq.X(q0), moment=0) + c1 = Component(op=cirq.X(q1), moment=1) + c2 = Component(op=cirq.X(q1), moment=1) + c1.merge(c2) + c0.merge(c1, merge_left=True) + assert c0.find() == c1 + # c1 is the set representative but kept c0's moment + assert c1.moment == 0 + + +def test_merge_moment_with_merge_left_false(): + q0 = cirq.NamedQubit('x') + q1 = cirq.NamedQubit('y') + c0 = Component(op=cirq.X(q0), moment=0) + c1 = Component(op=cirq.X(q0), moment=0) + c2 = Component(op=cirq.X(q1), moment=1) + c0.merge(c1) + c1.merge(c2, merge_left=False) + assert c2.find() == c0 + # c0 is the set representative but kept c2's moment + assert c0.moment == 1 + + +def test_component_with_ops_merge(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: + del ops1, ops2 + return True + + factory = ComponentWithOpsFactory(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + assert c[0].find().ops == ops + + +def test_component_with_ops_merge_when_merge_fails(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: + del ops1, ops2 + return False + + factory = ComponentWithOpsFactory(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + # No merge happened + for i in range(3): + assert c[i].find() == c[i] + + +def test_component_with_circuit_op_merge(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op2 + return op1 + + factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + for i in range(3): + assert c[i].find().circuit_op == ops[0] + + +def test_component_with_circuit_op_merge_func_is_none(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> None: + del op1, op2 + return None + + factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + # No merge happened + for i in range(3): + assert c[i].find() == c[i] + + + + diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 14396764011..930438bbb7b 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -17,12 +17,20 @@ from __future__ import annotations import bisect +import copy import dataclasses from collections import defaultdict from typing import Callable, cast, Hashable, Sequence, TYPE_CHECKING from cirq import circuits, ops, protocols from cirq.circuits.circuit import CIRCUIT_TYPE +from cirq.transformers.connected_component import ( + Component, + ComponentFactory, + ComponentWithCircuitOp, + ComponentWithCircuitOpFactory, + ComponentWithOpsFactory, +) if TYPE_CHECKING: import cirq @@ -282,17 +290,17 @@ def map_operations_and_unroll( @dataclasses.dataclass class _MergedCircuit: - """An optimized internal representation of a circuit, tailored for `cirq.merge_operations` + """An optimized internal representation of a circuit, tailored for merge operations Attributes: - qubit_indexes: Mapping from qubits to (sorted) list of moment indexes containing operations - acting on the qubit. - mkey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing - measurement operations with the same key. - ckey_indexes: Mapping from measurement keys to (sorted) list of moment indexes containing - classically controlled operations controlled on the same key. - ops_by_index: List of circuit moments containing operations. We use a dictionary instead - of a set to store operations to preserve insertion order. + qubit_indexes: Mapping from qubits to (sorted) list of component moments containing + operations acting on the qubit. + mkey_indexes: Mapping from measurement keys to (sorted) list of component moments + containing measurement operations with the same key. + ckey_indexes: Mapping from measurement keys to (sorted) list of component moments + containing classically controlled operations controlled on the same key. + components_by_index: List of circuit moments containing components. We use a dictionary + instead of a set to store components to preserve insertion order. """ qubit_indexes: dict[cirq.Qid, list[int]] = dataclasses.field( @@ -304,54 +312,224 @@ class _MergedCircuit: ckey_indexes: dict[cirq.MeasurementKey, list[int]] = dataclasses.field( default_factory=lambda: defaultdict(lambda: [-1]) ) - ops_by_index: list[dict[cirq.Operation, int]] = dataclasses.field(default_factory=list) + components_by_index: list[dict[Component, int]] = dataclasses.field(default_factory=list) def append_empty_moment(self) -> None: - self.ops_by_index.append({}) + self.components_by_index.append({}) - def add_op_to_moment(self, moment_index: int, op: cirq.Operation) -> None: - self.ops_by_index[moment_index][op] = 0 - for q in op.qubits: - if moment_index > self.qubit_indexes[q][-1]: - self.qubit_indexes[q].append(moment_index) - else: - bisect.insort(self.qubit_indexes[q], moment_index) - for mkey in protocols.measurement_key_objs(op): - bisect.insort(self.mkey_indexes[mkey], moment_index) - for ckey in protocols.control_keys(op): - bisect.insort(self.ckey_indexes[ckey], moment_index) - - def remove_op_from_moment(self, moment_index: int, op: cirq.Operation) -> None: - self.ops_by_index[moment_index].pop(op) - for q in op.qubits: - if self.qubit_indexes[q][-1] == moment_index: - self.qubit_indexes[q].pop() - else: - self.qubit_indexes[q].remove(moment_index) - for mkey in protocols.measurement_key_objs(op): - self.mkey_indexes[mkey].remove(moment_index) - for ckey in protocols.control_keys(op): - self.ckey_indexes[ckey].remove(moment_index) - - def get_mergeable_ops( - self, op: cirq.Operation, op_qs: set[cirq.Qid] - ) -> tuple[int, list[cirq.Operation]]: - # Find the index of previous moment which can be merged with `op`. - idx = max([self.qubit_indexes[q][-1] for q in op_qs], default=-1) - idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in protocols.control_keys(op)]) - idx = max( - [idx] + [self.ckey_indexes[mkey][-1] for mkey in protocols.measurement_key_objs(op)] - ) - # Return the set of overlapping ops in moment with index `idx`. + def add_moment(self, index: list[int], moment: int) -> None: + """Adds a moment to a sorted list of moment indexes. + + Optimized for the majority case when the new moment is higher than any moment in the list. + """ + if index[-1] < moment: + index.append(moment) + else: + bisect.insort(index, moment) + + def remove_moment(self, index: list[int], moment: int) -> None: + """Removes a moment from a sorted list of moment indexes. + + Optimized for the majority case when the moment is last in the list. + """ + if index[-1] == moment: + index.pop() + else: + index.remove(moment) + + def add_component(self, c: Component) -> None: + """Adds a new components to merged circuit.""" + self.components_by_index[c.moment][c] = 0 + for q in c.qubits: + self.add_moment(self.qubit_indexes[q], c.moment) + for mkey in c.mkeys: + self.add_moment(self.mkey_indexes[mkey], c.moment) + for ckey in c.ckeys: + self.add_moment(self.ckey_indexes[ckey], c.moment) + + def remove_component(self, c: Component, c_data: Component) -> None: + """Removes a component from the merged circuit. + + Args: + c: reference to the component to be removed + c_data: copy of the data in c before any component merges involving c + (this is necessary as component merges alter the component data) + """ + self.components_by_index[c_data.moment].pop(c) + for q in c_data.qubits: + self.remove_moment(self.qubit_indexes[q], c_data.moment) + for mkey in c_data.mkeys: + self.remove_moment(self.mkey_indexes[mkey], c_data.moment) + for ckey in c_data.ckeys: + self.remove_moment(self.ckey_indexes[ckey], c_data.moment) + + def get_mergeable_components(self, c: Component, c_qs: set[cirq.Qid]) -> list[Component]: + """Finds all components that can be merged with c. + + Args: + c: component to be merged with existing components + c_qs: subset of c.qubits used to decide which components are mergeable + + Returns: + list of mergeable components + """ + # Find the index of previous moment which can be merged with `c`. + idx = max([self.qubit_indexes[q][-1] for q in c_qs], default=-1) + idx = max([idx] + [self.mkey_indexes[ckey][-1] for ckey in c.ckeys]) + idx = max([idx] + [self.ckey_indexes[mkey][-1] for mkey in c.mkeys]) + # Return the set of overlapping components in moment with index `idx`. if idx == -1: - return idx, [] + return [] + + return [c for c in self.components_by_index[idx] if not c_qs.isdisjoint(c.qubits)] + + def get_cirq_circuit( + self, components: list[Component], merged_circuit_op_tag: str + ) -> cirq.Circuit: + """Returns the merged circuit. + + Args: + components: all components in creation order + merged_circuit_op_tag: tag to use for CircuitOperations + + Returns: + the circuit with merged components as a CircuitOperation + """ + component_ops: dict[Component, list[cirq.Operation]] = defaultdict(list) + + # Traverse the components in creation order and collect operations + for c in components: + root = c.find() + component_ops[root].append(c.op) + + moments = [] + for m in self.components_by_index: + ops = [] + for c in m.keys(): + if isinstance(c, ComponentWithCircuitOp): + ops.append(c.circuit_op) + continue + if len(component_ops[c]) == 1: + ops.append(component_ops[c][0]) + else: + ops.append( + circuits.CircuitOperation( + circuits.FrozenCircuit(component_ops[c]) + ).with_tags(merged_circuit_op_tag) + ) + moments.append(circuits.Moment(ops)) + return circuits.Circuit(moments) + + +def _merge_operations_impl( + circuit: CIRCUIT_TYPE, + factory: ComponentFactory, + *, + merged_circuit_op_tag: str = "Merged connected component", + tags_to_ignore: Sequence[Hashable] = (), + deep: bool = False, +) -> CIRCUIT_TYPE: + """Merges operations in a circuit. + + Two operations op1 and op2 are merge-able if + - There is no other operations between op1 and op2 in the circuit + - is_subset(op1.qubits, op2.qubits) or is_subset(op2.qubits, op1.qubits) + + The method iterates on the input circuit moment-by-moment from left to right and attempts + to repeatedly merge each operation in the latest moment with all the corresponding merge-able + operations to its left. - return idx, [ - left_op for left_op in self.ops_by_index[idx] if not op_qs.isdisjoint(left_op.qubits) - ] + Operations are wrapped in a component and then component.merge is called to merge two + components. The factory can provide components with different implementations of the merge + function, allowing for optimizations. - def get_cirq_circuit(self) -> cirq.Circuit: - return circuits.Circuit(circuits.Moment(m.keys()) for m in self.ops_by_index) + If op1 and op2 are merged, both op1 and op2 are deleted from the circuit and + the merged component is inserted at the index corresponding to the larger + of op1/op2. If both op1 and op2 act on the same number of qubits, the merged component is + inserted in the smaller moment index to minimize circuit depth. + + At the end every component with more than one operation is replaced by a CircuitOperation. + + Args: + circuit: Input circuit to apply the transformations on. The input circuit is not mutated. + factory: Factory that creates components from an operation. + merged_circuit_op_tag: tag used for CircuitOperations created from merged components. + tags_to_ignore: Sequence of tags which should be ignored during the merge: operations with + these tags will not be merged. + deep: If true, the transformer primitive will be recursively applied to all circuits + wrapped inside circuit operations. + + + Returns: + Copy of input circuit with merged operations. + """ + components = [] # List of all components in creation order + tags_to_ignore_set = set(tags_to_ignore) + + merged_circuit = _MergedCircuit() + for moment_idx, current_moment in enumerate(cast(list['cirq.Moment'], circuit)): + merged_circuit.append_empty_moment() + for op in sorted(current_moment.operations, key=lambda op: op.qubits): + if ( + deep + and isinstance(op.untagged, circuits.CircuitOperation) + and tags_to_ignore_set.isdisjoint(op.tags) + ): + op_untagged = op.untagged + merged_op = op_untagged.replace( + circuit=_merge_operations_impl( + op_untagged.circuit, + factory, + merged_circuit_op_tag=merged_circuit_op_tag, + tags_to_ignore=tags_to_ignore, + deep=True, + ) + ).with_tags(*op.tags) + c = factory.new_component(merged_op, moment_idx, is_mergeable=False) + components.append(c) + merged_circuit.add_component(c) + continue + + c = factory.new_component( + op, moment_idx, is_mergeable=tags_to_ignore_set.isdisjoint(op.tags) + ) + components.append(c) + if not c.is_mergeable: + merged_circuit.add_component(c) + continue + + c_qs = set(c.qubits) + left_comp = merged_circuit.get_mergeable_components(c, c_qs) + if len(left_comp) == 1 and c_qs.issubset(left_comp[0].qubits): + # Make a shallow copy of the left component data before merge + left_c_data = copy.copy(left_comp[0]) + # Case-1: Try to merge c with the larger component on the left. + new_comp = left_comp[0].merge(c, merge_left=True) + if new_comp is not None: + merged_circuit.remove_component(left_comp[0], left_c_data) + merged_circuit.add_component(new_comp) + else: + merged_circuit.add_component(c) + continue + + while left_comp and c_qs: + # Case-2: left_c will merge right into `c` whenever possible. + for left_c in left_comp: + is_merged = False + if c_qs.issuperset(left_c.qubits): + # Make a shallow copy of the left component data before merge + left_c_data = copy.copy(left_c) + # Try to merge left_c into c + new_comp = left_c.merge(c, merge_left=False) + if new_comp is not None: + merged_circuit.remove_component(left_c, left_c_data) + c, is_merged = new_comp, True + if not is_merged: + c_qs -= left_c.qubits + left_comp = merged_circuit.get_mergeable_components(c, c_qs) + merged_circuit.add_component(c) + ret_circuit = merged_circuit.get_cirq_circuit(components, merged_circuit_op_tag) + return _to_target_circuit_type(ret_circuit, circuit) def merge_operations( @@ -407,12 +585,8 @@ def merge_operations( ValueError if the merged operation acts on new qubits outside the set of qubits corresponding to the original operations to be merged. """ - _circuit_op_tag = "_internal_tag_to_mark_circuit_ops_in_circuit" - tags_to_ignore_set = set(tags_to_ignore) | {_circuit_op_tag} def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> ops.Operation | None: - if not all(tags_to_ignore_set.isdisjoint(op.tags) for op in [op1, op2]): - return None new_op = merge_func(op1, op2) qubit_set = frozenset(op1.qubits + op2.qubits) if new_op is not None and not qubit_set.issuperset(new_op.qubits): @@ -422,63 +596,16 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> ops.Operation | ) return new_op - merged_circuit = _MergedCircuit() - for moment_idx, current_moment in enumerate(cast(list['cirq.Moment'], circuit)): - merged_circuit.append_empty_moment() - for op in sorted(current_moment.operations, key=lambda op: op.qubits): - if ( - deep - and isinstance(op.untagged, circuits.CircuitOperation) - and tags_to_ignore_set.isdisjoint(op.tags) - ): - op_untagged = op.untagged - merged_circuit.add_op_to_moment( - moment_idx, - op_untagged.replace( - circuit=merge_operations( - op_untagged.circuit, - merge_func, - tags_to_ignore=tags_to_ignore, - deep=True, - ) - ).with_tags(*op.tags, _circuit_op_tag), - ) - continue - - op_qs = set(op.qubits) - left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs) - if len(left_ops) == 1 and op_qs.issubset(left_ops[0].qubits): - # Case-1: Try to merge op with the larger operation on the left. - new_op = apply_merge_func(left_ops[0], op) - if new_op is not None: - merged_circuit.remove_op_from_moment(left_idx, left_ops[0]) - merged_circuit.add_op_to_moment(left_idx, new_op) - else: - merged_circuit.add_op_to_moment(moment_idx, op) - continue + def is_mergeable(op: cirq.Operation): + del op + return True - while left_ops and op_qs: - # Case-2: left_ops will merge right into `op` whenever possible. - for left_op in left_ops: - is_merged = False - if op_qs.issuperset(left_op.qubits): - # Try to merge left_op into op - new_op = apply_merge_func(left_op, op) - if new_op is not None: - merged_circuit.remove_op_from_moment(left_idx, left_op) - op, is_merged = new_op, True - if not is_merged: - op_qs -= frozenset(left_op.qubits) - left_idx, left_ops = merged_circuit.get_mergeable_ops(op, op_qs) - merged_circuit.add_op_to_moment(moment_idx, op) - ret_circuit = merged_circuit.get_cirq_circuit() - if deep: - ret_circuit = map_operations( - ret_circuit, - lambda o, _: o.untagged.with_tags(*(set(o.tags) - {_circuit_op_tag})), - deep=True, - ) - return _to_target_circuit_type(ret_circuit, circuit) + return _merge_operations_impl( + circuit, + ComponentWithCircuitOpFactory(is_mergeable, apply_merge_func), + tags_to_ignore=tags_to_ignore, + deep=deep, + ) def merge_operations_to_circuit_op( @@ -491,10 +618,9 @@ def merge_operations_to_circuit_op( ) -> CIRCUIT_TYPE: """Merges connected components of operations and wraps each component into a circuit operation. - Uses `cirq.merge_operations` to identify connected components of operations. Moment structure - is preserved for operations that do not participate in merging. For merged operations, the - newly created circuit operations are constructed by inserting operations using EARLIEST - strategy. + Moment structure is preserved for operations that do not participate in merging. + For merged operations, the newly created circuit operations are constructed by inserting + operations using EARLIEST strategy. If you need more control on moment structure of newly created circuit operations, consider using `cirq.merge_operations` directly with a custom `merge_func`. @@ -514,24 +640,17 @@ def merge_operations_to_circuit_op( Copy of input circuit with valid connected components wrapped in tagged circuit operations. """ - def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation | None: - def get_ops(op: cirq.Operation): - op_untagged = op.untagged - return ( - [*op_untagged.circuit.all_operations()] - if isinstance(op_untagged, circuits.CircuitOperation) - and merged_circuit_op_tag in op.tags - else [op] - ) - - left_ops, right_ops = get_ops(op1), get_ops(op2) - if not can_merge(left_ops, right_ops): - return None - return circuits.CircuitOperation(circuits.FrozenCircuit(left_ops, right_ops)).with_tags( - merged_circuit_op_tag - ) + def is_mergeable(op: cirq.Operation): + del op + return True - return merge_operations(circuit, merge_func, tags_to_ignore=tags_to_ignore, deep=deep) + return _merge_operations_impl( + circuit, + ComponentWithOpsFactory(is_mergeable, can_merge), + merged_circuit_op_tag=merged_circuit_op_tag, + tags_to_ignore=tags_to_ignore, + deep=deep, + ) def merge_k_qubit_unitaries_to_circuit_op( @@ -544,10 +663,9 @@ def merge_k_qubit_unitaries_to_circuit_op( ) -> CIRCUIT_TYPE: """Merges connected components of operations, acting on <= k qubits, into circuit operations. - Uses `cirq.merge_operations_to_circuit_op` to identify and merge connected components of - unitary operations acting on at-most k-qubits. Moment structure is preserved for operations - that do not participate in merging. For merged operations, the newly created circuit operations - are constructed by inserting operations using EARLIEST strategy. + Moment structure is preserved for operations that do not participate in merging. + For merged operations, the newly created circuit operations are constructed by inserting + operations using EARLIEST strategy. Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. @@ -563,18 +681,14 @@ def merge_k_qubit_unitaries_to_circuit_op( Copy of input circuit with valid connected components wrapped in tagged circuit operations. """ - def can_merge(ops1: Sequence[cirq.Operation], ops2: Sequence[cirq.Operation]) -> bool: - return all( - protocols.num_qubits(op) <= k and protocols.has_unitary(op) - for op_list in [ops1, ops2] - for op in op_list - ) + def is_mergeable(op: cirq.Operation): + return protocols.num_qubits(op) <= k and protocols.has_unitary(op) - return merge_operations_to_circuit_op( + return _merge_operations_impl( circuit, - can_merge, - tags_to_ignore=tags_to_ignore, + ComponentFactory(is_mergeable), merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.", + tags_to_ignore=tags_to_ignore, deep=deep, ) diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index e1152b60aff..e1bba5e06af 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -877,3 +877,187 @@ def merge_func(op1, op2): cirq.testing.assert_same_circuits( cirq.align_left(cirq.merge_operations(circuit, merge_func)), expected_circuit ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_absorbs_overlapping_2q_gates(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.Moment( + cirq.H(q[0]).with_tags("ignore"), + cirq.H(q[1]).with_tags("ignore"), + cirq.H(q[2]).with_tags("ignore"), + ), + cirq.Moment(cirq.CNOT(q[0], q[2]), cirq.X(q[1]).with_tags("ignore")), + cirq.CNOT(q[0], q[1]), + cirq.CNOT(q[1], q[2]), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' + ┌──────────┐ +0: ───H[ignore]────@─────────────@───────@─── + │ │ │ +1: ───H[ignore]────┼X[ignore]────X───@───@─── + │ │ │ +2: ───H[ignore]────X─────────────────X───@─── + └──────────┘ +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@───────@─── ] + [ │ │ │ ] +0: ───H[ignore]───────────────[ 1: ───┼───X───@───@─── ]─────────── + [ │ │ │ ] + [ 2: ───X───────X───@─── ][merged] + │ +1: ───H[ignore]───X[ignore]───#2─────────────────────────────────── + │ +2: ───H[ignore]───────────────#3─────────────────────────────────── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_absorbs_disjoint_gates(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2])), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───@─── + │ │ +1: ───X───@─── + │ +2: ───X───@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@─── ] + [ │ │ ] +0: ───[ 1: ───X───@─── ]─────────── + [ │ ] + [ 2: ───X───@─── ][merged] + │ +1: ───#2─────────────────────────── + │ +2: ───#3─────────────────────────── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_3q_gate_doesnt_absorb_unmergeable_gate(): + q = cirq.LineQubit.range(3) + c_orig = cirq.Circuit( + cirq.CCZ(*q), + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2]).with_tags("ignore")), + cirq.CCZ(*q), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───@───────────@─── + │ │ │ +1: ───@───X───────────@─── + │ │ +2: ───@───X[ignore]───@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@─── ] + [ │ │ ] +0: ───[ 1: ───@───X─── ]───────────────────────@─── + [ │ ] │ + [ 2: ───@─────── ][merged] │ + │ │ +1: ───#2───────────────────────────────────────@─── + │ │ +2: ───#3───────────────────────────X[ignore]───@─── +''', + ) + + +def test_merge_3q_unitaries_to_circuit_op_prefer_to_merge_into_earlier_op(): + q = cirq.LineQubit.range(6) + c_orig = cirq.Circuit( + cirq.Moment( + cirq.CCZ(*q[0:3]), cirq.X(q[3]), cirq.H(q[4]), cirq.H(q[5]).with_tags("ignore") + ), + cirq.Moment(cirq.CNOT(q[0], q[1]), cirq.X(q[2]).with_tags("ignore"), cirq.CCZ(*q[3:6])), + cirq.Moment( + cirq.X(q[0]), + cirq.X(q[1]), + cirq.X(q[2]), + cirq.X(q[3]).with_tags("ignore"), + cirq.CNOT(*q[4:6]), + ), + cirq.Moment(cirq.CCZ(*q[0:3]), cirq.CCZ(*q[3:6])), + strategy=cirq.InsertStrategy.NEW, + ) + cirq.testing.assert_has_diagram( + c_orig, + ''' +0: ───@───────────@───────────X───────────@─── + │ │ │ +1: ───@───────────X───────────X───────────@─── + │ │ +2: ───@───────────X[ignore]───X───────────@─── + +3: ───X───────────@───────────X[ignore]───@─── + │ │ +4: ───H───────────@───────────@───────────@─── + │ │ │ +5: ───H[ignore]───@───────────X───────────@─── +''', + ) + + c_new = cirq.merge_k_qubit_unitaries_to_circuit_op( + c_orig, k=3, merged_circuit_op_tag="merged", tags_to_ignore=["ignore"] + ) + cirq.testing.assert_has_diagram( + cirq.drop_empty_moments(c_new), + ''' + [ 0: ───@───@───X─── ] [ 0: ───────@─── ] + [ │ │ ] [ │ ] +0: ───[ 1: ───@───X───X─── ]────────────────────────────────────────────────────────[ 1: ───────@─── ]─────────── + [ │ ] [ │ ] + [ 2: ───@─────────── ][merged] [ 2: ───X───@─── ][merged] + │ │ +1: ───#2────────────────────────────────────────────────────────────────────────────#2─────────────────────────── + │ │ +2: ───#3───────────────────────────────X[ignore]────────────────────────────────────#3─────────────────────────── + + [ 3: ───X───@─────── ] + [ │ ] +3: ────────────────────────────────────[ 4: ───H───@───@─── ]───────────X[ignore]───@──────────────────────────── + [ │ │ ] │ + [ 5: ───────@───X─── ][merged] │ + │ │ +4: ────────────────────────────────────#2───────────────────────────────────────────@──────────────────────────── + │ │ +5: ───H[ignore]────────────────────────#3───────────────────────────────────────────@──────────────────────────── +''', + ) From 3828322e03615412a6fee4903132d6b9ec63a989 Mon Sep 17 00:00:00 2001 From: Codrut Date: Sun, 13 Jul 2025 20:26:47 +0200 Subject: [PATCH 2/7] Fix lint, format and coverage issues. --- .../cirq/transformers/connected_component.py | 68 +++++++------ .../transformers/connected_component_test.py | 97 ++++++++++++++++++- .../transformer_primitives_test.py | 2 +- 3 files changed, 132 insertions(+), 35 deletions(-) diff --git a/cirq-core/cirq/transformers/connected_component.py b/cirq-core/cirq/transformers/connected_component.py index ef4089c07f8..dfbfb322130 100644 --- a/cirq-core/cirq/transformers/connected_component.py +++ b/cirq-core/cirq/transformers/connected_component.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: import cirq + class Component: """Internal representation for a connected component of operations. @@ -33,7 +34,7 @@ class and overriding the merge function (see ComponentWithOps and """ # Properties for the disjoint set data structure - parent: Component|None = None + parent: Component | None = None rank: int = 0 # True if the component can be merged @@ -50,7 +51,7 @@ class and overriding the merge function (see ComponentWithOps and # Initial operation in the component op: cirq.Operation - def __init__(self, op: cirq.Operation, moment: int, is_mergeable = True): + def __init__(self, op: cirq.Operation, moment: int, is_mergeable=True): """Initializes a singleton component.""" self.is_mergeable = is_mergeable self.moment = moment @@ -63,7 +64,7 @@ def find(self) -> Component: """Finds the component representative.""" root = self - while root.parent != None: + while root.parent is not None: root = root.parent x = self while x != root: @@ -72,7 +73,7 @@ def find(self) -> Component: x = cast(Component, parent) return root - def merge(self, c: Component, merge_left = True) -> Component|None: + def merge(self, c: Component, merge_left=True) -> Component | None: """Attempts to merge two components. We assume the following is true whenever merge is called: @@ -134,14 +135,18 @@ class ComponentWithOps(Component): # Method to decide if two components can be merged based on their operations can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] - def __init__(self, op: cirq.Operation, moment: int, - can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], - is_mergeable = True): + def __init__( + self, + op: cirq.Operation, + moment: int, + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], + is_mergeable=True, + ): super().__init__(op, moment, is_mergeable) self.ops = [op] self.can_merge = can_merge - def merge(self, c: Component, merge_left = True) -> Component|None: + def merge(self, c: Component, merge_left=True) -> Component | None: """Attempts to merge two components. Returns: @@ -158,8 +163,6 @@ def merge(self, c: Component, merge_left = True) -> Component|None: return None root = cast(ComponentWithOps, super(ComponentWithOps, x).merge(y, merge_left)) - if not root: - return None root.ops = x.ops + y.ops # Clear the ops list in the non-representative set to avoid memory consumption if x != root: @@ -181,14 +184,18 @@ class ComponentWithCircuitOp(Component): merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] - def __init__(self, op: cirq.Operation, moment: int, - merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], - is_mergeable = True): + def __init__( + self, + op: cirq.Operation, + moment: int, + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], + is_mergeable=True, + ): super().__init__(op, moment, is_mergeable) self.circuit_op = op self.merge_func = merge_func - def merge(self, c: Component, merge_left = True) -> Component|None: + def merge(self, c: Component, merge_left=True) -> Component | None: """Attempts to merge two components. Returns: @@ -208,8 +215,6 @@ def merge(self, c: Component, merge_left = True) -> Component|None: return None root = cast(ComponentWithCircuitOp, super(ComponentWithCircuitOp, x).merge(y, merge_left)) - if not root: - return None root.circuit_op = new_op # The merge_func can be arbitrary, so we need to recompute the component properties @@ -230,11 +235,10 @@ class ComponentFactory: is_mergeable: Callable[[cirq.Operation], bool] - def __init__(self, - is_mergeable: Callable[[cirq.Operation], bool]): + def __init__(self, is_mergeable: Callable[[cirq.Operation], bool]): self.is_mergeable = is_mergeable - def new_component(self, op: cirq.Operation, moment: int, is_mergeable = True) -> Component: + def new_component(self, op: cirq.Operation, moment: int, is_mergeable=True) -> Component: return Component(op, moment, self.is_mergeable(op) and is_mergeable) @@ -243,13 +247,15 @@ class ComponentWithOpsFactory(ComponentFactory): can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] - def __init__(self, - is_mergeable: Callable[[cirq.Operation], bool], - can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool]): + def __init__( + self, + is_mergeable: Callable[[cirq.Operation], bool], + can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], + ): super().__init__(is_mergeable) self.can_merge = can_merge - def new_component(self, op: cirq.Operation, moment: int, is_mergeable = True) -> Component: + def new_component(self, op: cirq.Operation, moment: int, is_mergeable=True) -> Component: return ComponentWithOps(op, moment, self.can_merge, self.is_mergeable(op) and is_mergeable) @@ -258,15 +264,15 @@ class ComponentWithCircuitOpFactory(ComponentFactory): merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] - def __init__(self, - is_mergeable: Callable[[cirq.Operation], bool], - merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None]): + def __init__( + self, + is_mergeable: Callable[[cirq.Operation], bool], + merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], + ): super().__init__(is_mergeable) self.merge_func = merge_func def new_component(self, op: cirq.Operation, moment: int, is_mergeable=True) -> Component: - return ComponentWithCircuitOp(op, moment, self.merge_func, self.is_mergeable(op) and is_mergeable) - - - - + return ComponentWithCircuitOp( + op, moment, self.merge_func, self.is_mergeable(op) and is_mergeable + ) diff --git a/cirq-core/cirq/transformers/connected_component_test.py b/cirq-core/cirq/transformers/connected_component_test.py index aaf80c3b00c..40199bedea0 100644 --- a/cirq-core/cirq/transformers/connected_component_test.py +++ b/cirq-core/cirq/transformers/connected_component_test.py @@ -16,7 +16,10 @@ import cirq from cirq.transformers.connected_component import ( - Component, ComponentWithOps, ComponentWithCircuitOp, ComponentFactory, ComponentWithOpsFactory, ComponentWithCircuitOpFactory + Component, + ComponentFactory, + ComponentWithCircuitOpFactory, + ComponentWithOpsFactory, ) @@ -51,11 +54,23 @@ def test_merge_components(): assert c[i].parent == c[4] +def test_merge_same_component(): + q = cirq.NamedQubit('x') + c = [Component(op=cirq.X(q), moment=i) for i in range(3)] + c[1].merge(c[0]) + c[2].merge(c[1]) + # Disjoint set structure: + # c[1] + # / \ + # c[0] c[2] + assert c[0].merge(c[2]) == c[1] + + def test_merge_returns_None_if_one_component_is_not_mergeable(): q = cirq.NamedQubit('x') c0 = Component(op=cirq.X(q), moment=0, is_mergeable=True) c1 = Component(op=cirq.X(q), moment=1, is_mergeable=False) - assert c0.merge(c1) == None + assert c0.merge(c1) is None def test_factory_merge_returns_None_if_is_mergeable_is_false(): @@ -68,7 +83,7 @@ def is_mergeable(op: cirq.Operation) -> bool: factory = ComponentFactory(is_mergeable=is_mergeable) c0 = factory.new_component(op=cirq.X(q), moment=0, is_mergeable=True) c1 = factory.new_component(op=cirq.X(q), moment=1, is_mergeable=True) - assert c0.merge(c1) == None + assert c0.merge(c1) is None def test_merge_qubits_with_merge_left_true(): @@ -143,6 +158,24 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: assert c[0].find().ops == ops +def test_component_with_ops_merge_same_component(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: + del ops1, ops2 + return True + + factory = ComponentWithOpsFactory(is_mergeable, can_merge) + + q = cirq.NamedQubit('x') + c = [factory.new_component(op=cirq.X(q), moment=i) for i in range(3)] + c[1].merge(c[0]) + c[2].merge(c[1]) + assert c[0].merge(c[2]) == c[1] + + def test_component_with_ops_merge_when_merge_fails(): def is_mergeable(op: cirq.Operation) -> bool: del op @@ -165,6 +198,28 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: assert c[i].find() == c[i] +def test_component_with_ops_merge_when_is_mergeable_is_false(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return False + + def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: + del ops1, ops2 + return True + + factory = ComponentWithOpsFactory(is_mergeable, can_merge) + + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + # No merge happened + for i in range(3): + assert c[i].find() == c[i] + + def test_component_with_circuit_op_merge(): def is_mergeable(op: cirq.Operation) -> bool: del op @@ -186,6 +241,24 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: assert c[i].find().circuit_op == ops[0] +def test_component_with_circuit_op_merge_same_component(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return True + + def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op2 + return op1 + + factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + + q = cirq.NamedQubit('x') + c = [factory.new_component(op=cirq.X(q), moment=i) for i in range(3)] + c[1].merge(c[0]) + c[2].merge(c[1]) + assert c[0].merge(c[2]) == c[1] + + def test_component_with_circuit_op_merge_func_is_none(): def is_mergeable(op: cirq.Operation) -> bool: del op @@ -208,5 +281,23 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> None: assert c[i].find() == c[i] +def test_component_with_circuit_op_merge_when_is_mergeable_is_false(): + def is_mergeable(op: cirq.Operation) -> bool: + del op + return False + + def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op2 + return op1 + factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + + c[0].merge(c[1]) + c[1].merge(c[2]) + # No merge happened + for i in range(3): + assert c[i].find() == c[i] diff --git a/cirq-core/cirq/transformers/transformer_primitives_test.py b/cirq-core/cirq/transformers/transformer_primitives_test.py index e1bba5e06af..eddc866266c 100644 --- a/cirq-core/cirq/transformers/transformer_primitives_test.py +++ b/cirq-core/cirq/transformers/transformer_primitives_test.py @@ -1059,5 +1059,5 @@ def test_merge_3q_unitaries_to_circuit_op_prefer_to_merge_into_earlier_op(): 4: ────────────────────────────────────#2───────────────────────────────────────────@──────────────────────────── │ │ 5: ───H[ignore]────────────────────────#3───────────────────────────────────────────@──────────────────────────── -''', +''', # noqa: E501 ) From 0aeafa80f5e4b2e3e55bd2bd16d7e6d49faff8e9 Mon Sep 17 00:00:00 2001 From: Codrut Date: Wed, 30 Jul 2025 21:19:28 +0200 Subject: [PATCH 3/7] Address review comments. --- .../cirq/transformers/connected_component.py | 66 ++++++++------ .../transformers/connected_component_test.py | 85 ++++++++++++------- .../transformers/transformer_primitives.py | 30 +++---- 3 files changed, 105 insertions(+), 76 deletions(-) diff --git a/cirq-core/cirq/transformers/connected_component.py b/cirq-core/cirq/transformers/connected_component.py index dfbfb322130..65365ddd3b5 100644 --- a/cirq-core/cirq/transformers/connected_component.py +++ b/cirq-core/cirq/transformers/connected_component.py @@ -18,6 +18,8 @@ from typing import Callable, cast, Sequence, TYPE_CHECKING +from typing_extensions import override + from cirq import ops, protocols if TYPE_CHECKING: @@ -41,7 +43,7 @@ class and overriding the merge function (see ComponentWithOps and is_mergeable: bool # Circuit moment containing the component - moment: int + moment_id: int # Union of all op qubits in the component qubits: frozenset[cirq.Qid] # Union of all measurement keys in the component @@ -51,10 +53,10 @@ class and overriding the merge function (see ComponentWithOps and # Initial operation in the component op: cirq.Operation - def __init__(self, op: cirq.Operation, moment: int, is_mergeable=True): + def __init__(self, op: cirq.Operation, moment_id: int, is_mergeable=True): """Initializes a singleton component.""" self.is_mergeable = is_mergeable - self.moment = moment + self.moment_id = moment_id self.qubits = frozenset(op.qubits) self.mkeys = protocols.measurement_key_objs(op) self.ckeys = protocols.control_keys(op) @@ -76,18 +78,14 @@ def find(self) -> Component: def merge(self, c: Component, merge_left=True) -> Component | None: """Attempts to merge two components. - We assume the following is true whenever merge is called: - - if merge_left = True then c.qubits are a subset of self.qubits - - if merge_left = False then self.qubits are a subset of c.qubits - If merge_left is True, c is merged into this component, and the representative - will keep this moment and qubits. If merge_left is False, this component is - merged into c, and the representative will keep c's moment and qubits. + will keep this component's moment. If merge_left is False, this component is + merged into c, and the representative will keep c's moment. Args: c: other component to merge - merge_left: True to keep self's data for the merged component, False to - keep c's data for the merged component. + merge_left: True to keep self's moment for the merged component, False to + keep c's moment for the merged component. Returns: None, if the components can't be merged. @@ -104,19 +102,18 @@ def merge(self, c: Component, merge_left=True) -> Component | None: if x.rank < y.rank: if merge_left: - # As y will be the new representative, copy moment and qubits from x - y.moment = x.moment - y.qubits = x.qubits + # As y will be the new representative, copy moment id from x + y.moment_id = x.moment_id x, y = y, x elif not merge_left: - # As x will be the new representative, copy moment and qubits from y - x.moment = y.moment - x.qubits = y.qubits + # As x will be the new representative, copy moment id from y + x.moment_id = y.moment_id y.parent = x if x.rank == y.rank: x.rank += 1 + x.qubits = x.qubits.union(y.qubits) x.mkeys = x.mkeys.union(y.mkeys) x.ckeys = x.ckeys.union(y.ckeys) return x @@ -138,14 +135,15 @@ class ComponentWithOps(Component): def __init__( self, op: cirq.Operation, - moment: int, + moment_id: int, can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], is_mergeable=True, ): - super().__init__(op, moment, is_mergeable) + super().__init__(op, moment_id, is_mergeable) self.ops = [op] self.can_merge = can_merge + @override def merge(self, c: Component, merge_left=True) -> Component | None: """Attempts to merge two components. @@ -187,17 +185,22 @@ class ComponentWithCircuitOp(Component): def __init__( self, op: cirq.Operation, - moment: int, + moment_id: int, merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], is_mergeable=True, ): - super().__init__(op, moment, is_mergeable) + super().__init__(op, moment_id, is_mergeable) self.circuit_op = op self.merge_func = merge_func + @override def merge(self, c: Component, merge_left=True) -> Component | None: """Attempts to merge two components. + If merge_left is True, the merge will use this component representative's + merge_func. If merge_left is False, the merge will use c representative's + merge_func. + Returns: None if merge_func returns None, otherwise the new representative. """ @@ -210,7 +213,10 @@ def merge(self, c: Component, merge_left=True) -> Component | None: if not x.is_mergeable or not y.is_mergeable: return None - new_op = x.merge_func(x.circuit_op, y.circuit_op) + if merge_left: + new_op = x.merge_func(x.circuit_op, y.circuit_op) + else: + new_op = y.merge_func(x.circuit_op, y.circuit_op) if not new_op: return None @@ -238,8 +244,8 @@ class ComponentFactory: def __init__(self, is_mergeable: Callable[[cirq.Operation], bool]): self.is_mergeable = is_mergeable - def new_component(self, op: cirq.Operation, moment: int, is_mergeable=True) -> Component: - return Component(op, moment, self.is_mergeable(op) and is_mergeable) + def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: + return Component(op, moment_id, self.is_mergeable(op) and is_mergeable) class ComponentWithOpsFactory(ComponentFactory): @@ -255,8 +261,11 @@ def __init__( super().__init__(is_mergeable) self.can_merge = can_merge - def new_component(self, op: cirq.Operation, moment: int, is_mergeable=True) -> Component: - return ComponentWithOps(op, moment, self.can_merge, self.is_mergeable(op) and is_mergeable) + @override + def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: + return ComponentWithOps( + op, moment_id, self.can_merge, self.is_mergeable(op) and is_mergeable + ) class ComponentWithCircuitOpFactory(ComponentFactory): @@ -272,7 +281,8 @@ def __init__( super().__init__(is_mergeable) self.merge_func = merge_func - def new_component(self, op: cirq.Operation, moment: int, is_mergeable=True) -> Component: + @override + def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: return ComponentWithCircuitOp( - op, moment, self.merge_func, self.is_mergeable(op) and is_mergeable + op, moment_id, self.merge_func, self.is_mergeable(op) and is_mergeable ) diff --git a/cirq-core/cirq/transformers/connected_component_test.py b/cirq-core/cirq/transformers/connected_component_test.py index 40199bedea0..e8db273c477 100644 --- a/cirq-core/cirq/transformers/connected_component_test.py +++ b/cirq-core/cirq/transformers/connected_component_test.py @@ -14,10 +14,13 @@ from __future__ import annotations +from typing import cast + import cirq from cirq.transformers.connected_component import ( Component, ComponentFactory, + ComponentWithCircuitOp, ComponentWithCircuitOpFactory, ComponentWithOpsFactory, ) @@ -25,13 +28,13 @@ def test_find_returns_itself_for_singleton(): q = cirq.NamedQubit('x') - c = Component(op=cirq.X(q), moment=0) + c = Component(op=cirq.X(q), moment_id=0) assert c.find() == c def test_merge_components(): q = cirq.NamedQubit('x') - c = [Component(op=cirq.X(q), moment=i) for i in range(5)] + c = [Component(op=cirq.X(q), moment_id=i) for i in range(5)] c[1].merge(c[0]) c[2].merge(c[1]) c[4].merge(c[3]) @@ -56,7 +59,7 @@ def test_merge_components(): def test_merge_same_component(): q = cirq.NamedQubit('x') - c = [Component(op=cirq.X(q), moment=i) for i in range(3)] + c = [Component(op=cirq.X(q), moment_id=i) for i in range(3)] c[1].merge(c[0]) c[2].merge(c[1]) # Disjoint set structure: @@ -68,8 +71,8 @@ def test_merge_same_component(): def test_merge_returns_None_if_one_component_is_not_mergeable(): q = cirq.NamedQubit('x') - c0 = Component(op=cirq.X(q), moment=0, is_mergeable=True) - c1 = Component(op=cirq.X(q), moment=1, is_mergeable=False) + c0 = Component(op=cirq.X(q), moment_id=0, is_mergeable=True) + c1 = Component(op=cirq.X(q), moment_id=1, is_mergeable=False) assert c0.merge(c1) is None @@ -81,61 +84,59 @@ def is_mergeable(op: cirq.Operation) -> bool: return False factory = ComponentFactory(is_mergeable=is_mergeable) - c0 = factory.new_component(op=cirq.X(q), moment=0, is_mergeable=True) - c1 = factory.new_component(op=cirq.X(q), moment=1, is_mergeable=True) + c0 = factory.new_component(op=cirq.X(q), moment_id=0, is_mergeable=True) + c1 = factory.new_component(op=cirq.X(q), moment_id=1, is_mergeable=True) assert c0.merge(c1) is None def test_merge_qubits_with_merge_left_true(): q0 = cirq.NamedQubit('x') q1 = cirq.NamedQubit('y') - c0 = Component(op=cirq.X(q0), moment=0) - c1 = Component(op=cirq.X(q1), moment=0) - c2 = Component(op=cirq.X(q1), moment=1) + c0 = Component(op=cirq.X(q0), moment_id=0) + c1 = Component(op=cirq.X(q1), moment_id=0) + c2 = Component(op=cirq.X(q1), moment_id=1) c1.merge(c2) c0.merge(c1, merge_left=True) assert c0.find() == c1 - # c1 is the set representative but kept c0's qubits - assert c1.qubits == frozenset([q0]) + assert c1.qubits == frozenset([q0, q1]) def test_merge_qubits_with_merge_left_false(): q0 = cirq.NamedQubit('x') q1 = cirq.NamedQubit('y') - c0 = Component(op=cirq.X(q0), moment=0) - c1 = Component(op=cirq.X(q0), moment=0) - c2 = Component(op=cirq.X(q1), moment=1) + c0 = Component(op=cirq.X(q0), moment_id=0) + c1 = Component(op=cirq.X(q0), moment_id=0) + c2 = Component(op=cirq.X(q1), moment_id=1) c0.merge(c1) c1.merge(c2, merge_left=False) assert c2.find() == c0 - # c0 is the set representative but kept c2's qubits - assert c0.qubits == frozenset([q1]) + assert c0.qubits == frozenset([q0, q1]) def test_merge_moment_with_merge_left_true(): q0 = cirq.NamedQubit('x') q1 = cirq.NamedQubit('y') - c0 = Component(op=cirq.X(q0), moment=0) - c1 = Component(op=cirq.X(q1), moment=1) - c2 = Component(op=cirq.X(q1), moment=1) + c0 = Component(op=cirq.X(q0), moment_id=0) + c1 = Component(op=cirq.X(q1), moment_id=1) + c2 = Component(op=cirq.X(q1), moment_id=1) c1.merge(c2) c0.merge(c1, merge_left=True) assert c0.find() == c1 # c1 is the set representative but kept c0's moment - assert c1.moment == 0 + assert c1.moment_id == 0 def test_merge_moment_with_merge_left_false(): q0 = cirq.NamedQubit('x') q1 = cirq.NamedQubit('y') - c0 = Component(op=cirq.X(q0), moment=0) - c1 = Component(op=cirq.X(q0), moment=0) - c2 = Component(op=cirq.X(q1), moment=1) + c0 = Component(op=cirq.X(q0), moment_id=0) + c1 = Component(op=cirq.X(q0), moment_id=0) + c2 = Component(op=cirq.X(q1), moment_id=1) c0.merge(c1) c1.merge(c2, merge_left=False) assert c2.find() == c0 # c0 is the set representative but kept c2's moment - assert c0.moment == 1 + assert c0.moment_id == 1 def test_component_with_ops_merge(): @@ -151,7 +152,7 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] c[0].merge(c[1]) c[1].merge(c[2]) @@ -170,7 +171,7 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: factory = ComponentWithOpsFactory(is_mergeable, can_merge) q = cirq.NamedQubit('x') - c = [factory.new_component(op=cirq.X(q), moment=i) for i in range(3)] + c = [factory.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] c[1].merge(c[0]) c[2].merge(c[1]) assert c[0].merge(c[2]) == c[1] @@ -189,7 +190,7 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] c[0].merge(c[1]) c[1].merge(c[2]) @@ -211,7 +212,7 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] c[0].merge(c[1]) c[1].merge(c[2]) @@ -233,7 +234,7 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] c[0].merge(c[1]) c[1].merge(c[2]) @@ -253,7 +254,7 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) q = cirq.NamedQubit('x') - c = [factory.new_component(op=cirq.X(q), moment=i) for i in range(3)] + c = [factory.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] c[1].merge(c[0]) c[2].merge(c[1]) assert c[0].merge(c[2]) == c[1] @@ -272,7 +273,7 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> None: q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] c[0].merge(c[1]) c[1].merge(c[2]) @@ -294,10 +295,28 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment=i) for i in range(3)] + c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] c[0].merge(c[1]) c[1].merge(c[2]) # No merge happened for i in range(3): assert c[i].find() == c[i] + + +def test_component_with_circuit_op_merge_when_merge_left_is_false(): + def merge_func_x(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op2 + return op1 + + def merge_func_y(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: + del op1 + return op2 + + q = cirq.LineQubit.range(2) + x = ComponentWithCircuitOp(cirq.X(q[0]), moment_id=0, merge_func=merge_func_x) + y = ComponentWithCircuitOp(cirq.X(q[1]), moment_id=1, merge_func=merge_func_y) + + root = cast(ComponentWithCircuitOp, x.merge(y, merge_left=False)) + # The merge used merge_func_y because merge_left=False + assert root.circuit_op == cirq.X(q[1]) diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 930438bbb7b..a9c165dfc13 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -317,35 +317,35 @@ class _MergedCircuit: def append_empty_moment(self) -> None: self.components_by_index.append({}) - def add_moment(self, index: list[int], moment: int) -> None: + def add_moment(self, index: list[int], moment_id: int) -> None: """Adds a moment to a sorted list of moment indexes. Optimized for the majority case when the new moment is higher than any moment in the list. """ - if index[-1] < moment: - index.append(moment) + if index[-1] < moment_id: + index.append(moment_id) else: - bisect.insort(index, moment) + bisect.insort(index, moment_id) - def remove_moment(self, index: list[int], moment: int) -> None: + def remove_moment(self, index: list[int], moment_id: int) -> None: """Removes a moment from a sorted list of moment indexes. Optimized for the majority case when the moment is last in the list. """ - if index[-1] == moment: + if index[-1] == moment_id: index.pop() else: - index.remove(moment) + index.remove(moment_id) def add_component(self, c: Component) -> None: """Adds a new components to merged circuit.""" - self.components_by_index[c.moment][c] = 0 + self.components_by_index[c.moment_id][c] = 0 for q in c.qubits: - self.add_moment(self.qubit_indexes[q], c.moment) + self.add_moment(self.qubit_indexes[q], c.moment_id) for mkey in c.mkeys: - self.add_moment(self.mkey_indexes[mkey], c.moment) + self.add_moment(self.mkey_indexes[mkey], c.moment_id) for ckey in c.ckeys: - self.add_moment(self.ckey_indexes[ckey], c.moment) + self.add_moment(self.ckey_indexes[ckey], c.moment_id) def remove_component(self, c: Component, c_data: Component) -> None: """Removes a component from the merged circuit. @@ -355,13 +355,13 @@ def remove_component(self, c: Component, c_data: Component) -> None: c_data: copy of the data in c before any component merges involving c (this is necessary as component merges alter the component data) """ - self.components_by_index[c_data.moment].pop(c) + self.components_by_index[c_data.moment_id].pop(c) for q in c_data.qubits: - self.remove_moment(self.qubit_indexes[q], c_data.moment) + self.remove_moment(self.qubit_indexes[q], c_data.moment_id) for mkey in c_data.mkeys: - self.remove_moment(self.mkey_indexes[mkey], c_data.moment) + self.remove_moment(self.mkey_indexes[mkey], c_data.moment_id) for ckey in c_data.ckeys: - self.remove_moment(self.ckey_indexes[ckey], c_data.moment) + self.remove_moment(self.ckey_indexes[ckey], c_data.moment_id) def get_mergeable_components(self, c: Component, c_qs: set[cirq.Qid]) -> list[Component]: """Finds all components that can be merged with c. From 3d1de8df8a6fd5912bf980d44d76549bb082ae23 Mon Sep 17 00:00:00 2001 From: Codrut Date: Sun, 10 Aug 2025 12:28:23 +0200 Subject: [PATCH 4/7] Address review comments. --- .../transformers/connected_component_test.py | 42 +++++++------------ .../transformers/transformer_primitives.py | 18 ++++---- 2 files changed, 25 insertions(+), 35 deletions(-) diff --git a/cirq-core/cirq/transformers/connected_component_test.py b/cirq-core/cirq/transformers/connected_component_test.py index e8db273c477..4f9b17808ea 100644 --- a/cirq-core/cirq/transformers/connected_component_test.py +++ b/cirq-core/cirq/transformers/connected_component_test.py @@ -79,8 +79,7 @@ def test_merge_returns_None_if_one_component_is_not_mergeable(): def test_factory_merge_returns_None_if_is_mergeable_is_false(): q = cirq.NamedQubit('x') - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return False factory = ComponentFactory(is_mergeable=is_mergeable) @@ -140,8 +139,7 @@ def test_merge_moment_with_merge_left_false(): def test_component_with_ops_merge(): - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return True def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: @@ -160,8 +158,7 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: def test_component_with_ops_merge_same_component(): - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return True def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: @@ -178,8 +175,7 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: def test_component_with_ops_merge_when_merge_fails(): - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return True def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: @@ -200,8 +196,7 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: def test_component_with_ops_merge_when_is_mergeable_is_false(): - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return False def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: @@ -222,12 +217,10 @@ def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: def test_component_with_circuit_op_merge(): - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return True - def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: - del op2 + def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: return op1 factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) @@ -243,12 +236,10 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: def test_component_with_circuit_op_merge_same_component(): - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return True - def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: - del op2 + def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: return op1 factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) @@ -261,8 +252,7 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: def test_component_with_circuit_op_merge_func_is_none(): - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return True def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> None: @@ -283,12 +273,10 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> None: def test_component_with_circuit_op_merge_when_is_mergeable_is_false(): - def is_mergeable(op: cirq.Operation) -> bool: - del op + def is_mergeable(_: cirq.Operation) -> bool: return False - def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: - del op2 + def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: return op1 factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) @@ -305,12 +293,10 @@ def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: def test_component_with_circuit_op_merge_when_merge_left_is_false(): - def merge_func_x(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: - del op2 + def merge_func_x(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: return op1 - def merge_func_y(op1: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: - del op1 + def merge_func_y(_: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: return op2 q = cirq.LineQubit.range(2) diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index a9c165dfc13..74e2a4ac948 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -290,7 +290,11 @@ def map_operations_and_unroll( @dataclasses.dataclass class _MergedCircuit: - """An optimized internal representation of a circuit, tailored for merge operations + """An optimized internal representation of a circuit, tailored for merge operations. + + Operations are represented as mergeable components. + Each component has a moment id, a set of qubits, a set of measurement keys, and a set of + control keys. The moment id is the index of the moment that contains the component. Attributes: qubit_indexes: Mapping from qubits to (sorted) list of component moments containing @@ -299,8 +303,10 @@ class _MergedCircuit: containing measurement operations with the same key. ckey_indexes: Mapping from measurement keys to (sorted) list of component moments containing classically controlled operations controlled on the same key. - components_by_index: List of circuit moments containing components. We use a dictionary - instead of a set to store components to preserve insertion order. + components_by_index: List of components indexed by moment id. + For a moment id, we use a dictionary instead of a set to keep track of the + components in the moment. The dictionary is used to preserve insertion order, + and the values have no meaning. """ qubit_indexes: dict[cirq.Qid, list[int]] = dataclasses.field( @@ -596,8 +602,7 @@ def apply_merge_func(op1: ops.Operation, op2: ops.Operation) -> ops.Operation | ) return new_op - def is_mergeable(op: cirq.Operation): - del op + def is_mergeable(_: cirq.Operation): return True return _merge_operations_impl( @@ -640,8 +645,7 @@ def merge_operations_to_circuit_op( Copy of input circuit with valid connected components wrapped in tagged circuit operations. """ - def is_mergeable(op: cirq.Operation): - del op + def is_mergeable(_: cirq.Operation): return True return _merge_operations_impl( From c240b3afe4819335993e69221826cd3e9565479e Mon Sep 17 00:00:00 2001 From: Codrut Date: Fri, 15 Aug 2025 23:58:52 +0200 Subject: [PATCH 5/7] Use scipy.DisjointSet to do union-find. --- .../cirq/transformers/connected_component.py | 274 +++++++---------- .../transformers/connected_component_test.py | 281 +++++++++--------- .../transformers/transformer_primitives.py | 64 ++-- 3 files changed, 283 insertions(+), 336 deletions(-) diff --git a/cirq-core/cirq/transformers/connected_component.py b/cirq-core/cirq/transformers/connected_component.py index 65365ddd3b5..f611e24a747 100644 --- a/cirq-core/cirq/transformers/connected_component.py +++ b/cirq-core/cirq/transformers/connected_component.py @@ -18,6 +18,7 @@ from typing import Callable, cast, Sequence, TYPE_CHECKING +from scipy.cluster.hierarchy import DisjointSet from typing_extensions import override from cirq import ops, protocols @@ -27,20 +28,7 @@ class Component: - """Internal representation for a connected component of operations. - - It uses the disjoint-set data structure to implement merge efficiently. - Additional merge conditions can be added by deriving from the Component - class and overriding the merge function (see ComponentWithOps and - ComponentWithCircuitOp) below. - """ - - # Properties for the disjoint set data structure - parent: Component | None = None - rank: int = 0 - - # True if the component can be merged - is_mergeable: bool + """Internal representation for a connected component of operations.""" # Circuit moment containing the component moment_id: int @@ -53,116 +41,146 @@ class and overriding the merge function (see ComponentWithOps and # Initial operation in the component op: cirq.Operation + # True if the component can be merged with other components + is_mergeable: bool + def __init__(self, op: cirq.Operation, moment_id: int, is_mergeable=True): """Initializes a singleton component.""" + self.op = op self.is_mergeable = is_mergeable self.moment_id = moment_id self.qubits = frozenset(op.qubits) self.mkeys = protocols.measurement_key_objs(op) self.ckeys = protocols.control_keys(op) - self.op = op - def find(self) -> Component: - """Finds the component representative.""" - - root = self - while root.parent is not None: - root = root.parent - x = self - while x != root: - parent = x.parent - x.parent = root - x = cast(Component, parent) - return root - def merge(self, c: Component, merge_left=True) -> Component | None: +class ComponentWithOps(Component): + """Component that keeps track of operations.""" + + # List of all operations in the component + ops: list[cirq.Operation] + + def __init__(self, op: cirq.Operation, moment_id: int, is_mergeable=True): + super().__init__(op, moment_id, is_mergeable) + self.ops = [op] + + +class ComponentWithCircuitOp(Component): + """Component that keeps track of operations as a CircuitOperation.""" + + # CircuitOperation containing all the operations in the component, + # or a single Operation if the component is a singleton + circuit_op: cirq.Operation + + def __init__(self, op: cirq.Operation, moment_id: int, is_mergeable=True): + super().__init__(op, moment_id, is_mergeable) + self.circuit_op = op + + +class ComponentSet: + """Represents a set of mergeable components of operations.""" + + _comp_type: type[Component] + + _disjoint_set: DisjointSet + + # Callable to decide if a component is mergeable + _is_mergeable: Callable[[cirq.Operation], bool] + + # List of components in creation order + _components: list[Component] + + def __init__(self, is_mergeable: Callable[[cirq.Operation], bool]): + self._is_mergeable = is_mergeable + self._disjoint_set = DisjointSet() + self._components = [] + self._comp_type = Component + + def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: + """Creates a new component and adds it to the set.""" + c = self._comp_type(op, moment_id, self._is_mergeable(op) and is_mergeable) + self._disjoint_set.add(c) + self._components.append(c) + return c + + def components(self) -> list[Component]: + """Returns the initial components in creation order.""" + return self._components + + def find(self, x: Component) -> Component: + """Finds the representative for a merged component.""" + return self._disjoint_set[x] + + def merge(self, x: Component, y: Component, merge_left=True) -> Component | None: """Attempts to merge two components. - If merge_left is True, c is merged into this component, and the representative - will keep this component's moment. If merge_left is False, this component is - merged into c, and the representative will keep c's moment. + If merge_left is True, y is merged into x, and the representative will keep + y's moment. If merge_left is False, x is merged into y, and the representative + will keep y's moment. Args: - c: other component to merge - merge_left: True to keep self's moment for the merged component, False to - keep c's moment for the merged component. + x: First component to merge. + y: Second component to merge. + merge_left: True to keep x's moment for the merged component, False to + keep y's moment for the merged component. Returns: None, if the components can't be merged. Otherwise the new component representative. """ - x = self.find() - y = c.find() + x = self._disjoint_set[x] + y = self._disjoint_set[y] if not x.is_mergeable or not y.is_mergeable: return None - if x == y: + if not self._disjoint_set.merge(x, y): return x - if x.rank < y.rank: - if merge_left: - # As y will be the new representative, copy moment id from x - y.moment_id = x.moment_id - x, y = y, x - elif not merge_left: - # As x will be the new representative, copy moment id from y - x.moment_id = y.moment_id - - y.parent = x - if x.rank == y.rank: - x.rank += 1 - - x.qubits = x.qubits.union(y.qubits) - x.mkeys = x.mkeys.union(y.mkeys) - x.ckeys = x.ckeys.union(y.ckeys) - return x - + root = self._disjoint_set[x] + root.moment_id = x.moment_id if merge_left else y.moment_id + root.qubits = x.qubits.union(y.qubits) + root.mkeys = x.mkeys.union(y.mkeys) + root.ckeys = x.ckeys.union(y.ckeys) -class ComponentWithOps(Component): - """Component that keeps track of operations. + return root - Encapsulates a method can_merge that is used to decide if two components - can be merged. - """ - # List of all operations in the component - ops: list[cirq.Operation] +class ComponentWithOpsSet(ComponentSet): + """Represents a set of mergeable components, where each component tracks operations.""" - # Method to decide if two components can be merged based on their operations - can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] + # Callable that returns if two components can be merged based on their operations + _can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] def __init__( self, - op: cirq.Operation, - moment_id: int, + is_mergeable: Callable[[cirq.Operation], bool], can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], - is_mergeable=True, ): - super().__init__(op, moment_id, is_mergeable) - self.ops = [op] - self.can_merge = can_merge + super().__init__(is_mergeable) + self._can_merge = can_merge + self._comp_type = ComponentWithOps @override - def merge(self, c: Component, merge_left=True) -> Component | None: + def merge(self, x: Component, y: Component, merge_left=True) -> Component | None: """Attempts to merge two components. Returns: - None if can_merge is False, otherwise the new representative. - The representative will have ops = a.ops + b.ops. + None if can_merge is False or the merge doesn't succeed, otherwise the + new representative. The representative will have ops = x.ops + y.ops. """ - x = cast(ComponentWithOps, self.find()) - y = cast(ComponentWithOps, c.find()) + x = cast(ComponentWithOps, self._disjoint_set[x]) + y = cast(ComponentWithOps, self._disjoint_set[y]) if x == y: return x - if not x.is_mergeable or not y.is_mergeable or not x.can_merge(x.ops, y.ops): + if not x.is_mergeable or not y.is_mergeable or not self._can_merge(x.ops, y.ops): return None - root = cast(ComponentWithOps, super(ComponentWithOps, x).merge(y, merge_left)) + root = cast(ComponentWithOps, super(ComponentWithOpsSet, self).merge(x, y, merge_left)) root.ops = x.ops + y.ops - # Clear the ops list in the non-representative set to avoid memory consumption + # Clear the ops list in the non-representative component to avoid memory consumption if x != root: x.ops = [] else: @@ -170,42 +188,31 @@ def merge(self, c: Component, merge_left=True) -> Component | None: return root -class ComponentWithCircuitOp(Component): - """Component that keeps track of operations as a CircuitOperation. - - Encapsulates a method merge_func that is used to merge two components. - """ +class ComponentWithCircuitOpSet(ComponentSet): + """Represents a set of mergeable components, with operations as a CircuitOperation.""" - # CircuitOperation containing all the operations in the component, - # or a single Operation if the component is a singleton - circuit_op: cirq.Operation - - merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] + # Callable that merges CircuitOperations from two components + _merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] def __init__( self, - op: cirq.Operation, - moment_id: int, + is_mergeable: Callable[[cirq.Operation], bool], merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], - is_mergeable=True, ): - super().__init__(op, moment_id, is_mergeable) - self.circuit_op = op - self.merge_func = merge_func + super().__init__(is_mergeable) + self._merge_func = merge_func + self._comp_type = ComponentWithCircuitOp @override - def merge(self, c: Component, merge_left=True) -> Component | None: + def merge(self, x: Component, y: Component, merge_left=True) -> Component | None: """Attempts to merge two components. - If merge_left is True, the merge will use this component representative's - merge_func. If merge_left is False, the merge will use c representative's - merge_func. - Returns: - None if merge_func returns None, otherwise the new representative. + None if merge_func returns None or the merge doesn't succeed, + otherwise the new representative. """ - x = cast(ComponentWithCircuitOp, self.find()) - y = cast(ComponentWithCircuitOp, c.find()) + x = cast(ComponentWithCircuitOp, self._disjoint_set[x]) + y = cast(ComponentWithCircuitOp, self._disjoint_set[y]) if x == y: return x @@ -213,14 +220,13 @@ def merge(self, c: Component, merge_left=True) -> Component | None: if not x.is_mergeable or not y.is_mergeable: return None - if merge_left: - new_op = x.merge_func(x.circuit_op, y.circuit_op) - else: - new_op = y.merge_func(x.circuit_op, y.circuit_op) + new_op = self._merge_func(x.circuit_op, y.circuit_op) if not new_op: return None - root = cast(ComponentWithCircuitOp, super(ComponentWithCircuitOp, x).merge(y, merge_left)) + root = cast( + ComponentWithCircuitOp, super(ComponentWithCircuitOpSet, self).merge(x, y, merge_left) + ) root.circuit_op = new_op # The merge_func can be arbitrary, so we need to recompute the component properties @@ -228,61 +234,9 @@ def merge(self, c: Component, merge_left=True) -> Component | None: root.mkeys = protocols.measurement_key_objs(new_op) root.ckeys = protocols.control_keys(new_op) - # Clear the circuit op in the non-representative set to avoid memory consumption + # Clear the circuit op in the non-representative component to avoid memory consumption if x != root: del x.circuit_op else: del y.circuit_op return root - - -class ComponentFactory: - """Factory for components.""" - - is_mergeable: Callable[[cirq.Operation], bool] - - def __init__(self, is_mergeable: Callable[[cirq.Operation], bool]): - self.is_mergeable = is_mergeable - - def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: - return Component(op, moment_id, self.is_mergeable(op) and is_mergeable) - - -class ComponentWithOpsFactory(ComponentFactory): - """Factory for components with operations.""" - - can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool] - - def __init__( - self, - is_mergeable: Callable[[cirq.Operation], bool], - can_merge: Callable[[Sequence[cirq.Operation], Sequence[cirq.Operation]], bool], - ): - super().__init__(is_mergeable) - self.can_merge = can_merge - - @override - def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: - return ComponentWithOps( - op, moment_id, self.can_merge, self.is_mergeable(op) and is_mergeable - ) - - -class ComponentWithCircuitOpFactory(ComponentFactory): - """Factory for components with operations as CircuitOperation.""" - - merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None] - - def __init__( - self, - is_mergeable: Callable[[cirq.Operation], bool], - merge_func: Callable[[ops.Operation, ops.Operation], ops.Operation | None], - ): - super().__init__(is_mergeable) - self.merge_func = merge_func - - @override - def new_component(self, op: cirq.Operation, moment_id: int, is_mergeable=True) -> Component: - return ComponentWithCircuitOp( - op, moment_id, self.merge_func, self.is_mergeable(op) and is_mergeable - ) diff --git a/cirq-core/cirq/transformers/connected_component_test.py b/cirq-core/cirq/transformers/connected_component_test.py index 4f9b17808ea..ad067d60f63 100644 --- a/cirq-core/cirq/transformers/connected_component_test.py +++ b/cirq-core/cirq/transformers/connected_component_test.py @@ -14,206 +14,222 @@ from __future__ import annotations -from typing import cast - import cirq from cirq.transformers.connected_component import ( - Component, - ComponentFactory, - ComponentWithCircuitOp, - ComponentWithCircuitOpFactory, - ComponentWithOpsFactory, + ComponentSet, + ComponentWithCircuitOpSet, + ComponentWithOpsSet, ) def test_find_returns_itself_for_singleton(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + q = cirq.NamedQubit('x') - c = Component(op=cirq.X(q), moment_id=0) - assert c.find() == c + c = cset.new_component(op=cirq.X(q), moment_id=0) + assert cset.find(c) == c def test_merge_components(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + q = cirq.NamedQubit('x') - c = [Component(op=cirq.X(q), moment_id=i) for i in range(5)] - c[1].merge(c[0]) - c[2].merge(c[1]) - c[4].merge(c[3]) - c[3].merge(c[0]) - # Disjoint set structure: - # c[4] - # / \ - # c[1] c[3] - # / \ - # c[0] c[2] - assert c[0].parent == c[1] - assert c[2].parent == c[1] - assert c[1].parent == c[4] - assert c[3].parent == c[4] + c = [cset.new_component(op=cirq.X(q), moment_id=i) for i in range(5)] + cset.merge(c[1], c[0]) + cset.merge(c[2], c[1]) + cset.merge(c[4], c[3]) + cset.merge(c[3], c[0]) for i in range(5): - assert c[i].find() == c[4] - # Find() compressed all paths - for i in range(4): - assert c[i].parent == c[4] + assert cset.find(c[i]) == cset.find(c[0]) def test_merge_same_component(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + q = cirq.NamedQubit('x') - c = [Component(op=cirq.X(q), moment_id=i) for i in range(3)] - c[1].merge(c[0]) - c[2].merge(c[1]) - # Disjoint set structure: - # c[1] - # / \ - # c[0] c[2] - assert c[0].merge(c[2]) == c[1] + c = [cset.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] + cset.merge(c[1], c[0]) + cset.merge(c[2], c[1]) + + root = cset.find(c[0]) + + assert cset.merge(c[0], c[2]) == root def test_merge_returns_None_if_one_component_is_not_mergeable(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + q = cirq.NamedQubit('x') - c0 = Component(op=cirq.X(q), moment_id=0, is_mergeable=True) - c1 = Component(op=cirq.X(q), moment_id=1, is_mergeable=False) - assert c0.merge(c1) is None + c0 = cset.new_component(op=cirq.X(q), moment_id=0, is_mergeable=True) + c1 = cset.new_component(op=cirq.X(q), moment_id=1, is_mergeable=False) + assert cset.merge(c0, c1) is None -def test_factory_merge_returns_None_if_is_mergeable_is_false(): +def test_cset_merge_returns_None_if_is_mergeable_is_false(): q = cirq.NamedQubit('x') def is_mergeable(_: cirq.Operation) -> bool: return False - factory = ComponentFactory(is_mergeable=is_mergeable) - c0 = factory.new_component(op=cirq.X(q), moment_id=0, is_mergeable=True) - c1 = factory.new_component(op=cirq.X(q), moment_id=1, is_mergeable=True) - assert c0.merge(c1) is None + cset = ComponentSet(is_mergeable=is_mergeable) + + c0 = cset.new_component(op=cirq.X(q), moment_id=0, is_mergeable=True) + c1 = cset.new_component(op=cirq.X(q), moment_id=1, is_mergeable=True) + assert cset.merge(c0, c1) is None def test_merge_qubits_with_merge_left_true(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + q0 = cirq.NamedQubit('x') q1 = cirq.NamedQubit('y') - c0 = Component(op=cirq.X(q0), moment_id=0) - c1 = Component(op=cirq.X(q1), moment_id=0) - c2 = Component(op=cirq.X(q1), moment_id=1) - c1.merge(c2) - c0.merge(c1, merge_left=True) - assert c0.find() == c1 - assert c1.qubits == frozenset([q0, q1]) + c0 = cset.new_component(op=cirq.X(q0), moment_id=0) + c1 = cset.new_component(op=cirq.X(q1), moment_id=0) + c2 = cset.new_component(op=cirq.X(q1), moment_id=1) + cset.merge(c1, c2) + cset.merge(c0, c1, merge_left=True) + assert cset.find(c1).qubits == frozenset([q0, q1]) def test_merge_qubits_with_merge_left_false(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + q0 = cirq.NamedQubit('x') q1 = cirq.NamedQubit('y') - c0 = Component(op=cirq.X(q0), moment_id=0) - c1 = Component(op=cirq.X(q0), moment_id=0) - c2 = Component(op=cirq.X(q1), moment_id=1) - c0.merge(c1) - c1.merge(c2, merge_left=False) - assert c2.find() == c0 - assert c0.qubits == frozenset([q0, q1]) + c0 = cset.new_component(op=cirq.X(q0), moment_id=0) + c1 = cset.new_component(op=cirq.X(q0), moment_id=0) + c2 = cset.new_component(op=cirq.X(q1), moment_id=1) + cset.merge(c0, c1) + cset.merge(c1, c2, merge_left=False) + assert cset.find(c0).qubits == frozenset([q0, q1]) def test_merge_moment_with_merge_left_true(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + q0 = cirq.NamedQubit('x') q1 = cirq.NamedQubit('y') - c0 = Component(op=cirq.X(q0), moment_id=0) - c1 = Component(op=cirq.X(q1), moment_id=1) - c2 = Component(op=cirq.X(q1), moment_id=1) - c1.merge(c2) - c0.merge(c1, merge_left=True) - assert c0.find() == c1 - # c1 is the set representative but kept c0's moment - assert c1.moment_id == 0 + c0 = cset.new_component(op=cirq.X(q0), moment_id=0) + c1 = cset.new_component(op=cirq.X(q1), moment_id=1) + c2 = cset.new_component(op=cirq.X(q1), moment_id=1) + cset.merge(c1, c2) + cset.merge(c0, c1, merge_left=True) + # the set representative kept c0's moment + assert cset.find(c1).moment_id == 0 def test_merge_moment_with_merge_left_false(): + def is_mergeable(_: cirq.Operation) -> bool: + return True + + cset = ComponentSet(is_mergeable) + q0 = cirq.NamedQubit('x') q1 = cirq.NamedQubit('y') - c0 = Component(op=cirq.X(q0), moment_id=0) - c1 = Component(op=cirq.X(q0), moment_id=0) - c2 = Component(op=cirq.X(q1), moment_id=1) - c0.merge(c1) - c1.merge(c2, merge_left=False) - assert c2.find() == c0 - # c0 is the set representative but kept c2's moment - assert c0.moment_id == 1 + c0 = cset.new_component(op=cirq.X(q0), moment_id=0) + c1 = cset.new_component(op=cirq.X(q0), moment_id=0) + c2 = cset.new_component(op=cirq.X(q1), moment_id=1) + cset.merge(c0, c1) + cset.merge(c1, c2, merge_left=False) + # the set representative kept c2's moment + assert cset.find(c0).moment_id == 1 def test_component_with_ops_merge(): def is_mergeable(_: cirq.Operation) -> bool: return True - def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: - del ops1, ops2 + def can_merge(_ops1: list[cirq.Operation], _ops2: list[cirq.Operation]) -> bool: return True - factory = ComponentWithOpsFactory(is_mergeable, can_merge) + cset = ComponentWithOpsSet(is_mergeable, can_merge) q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] - c[0].merge(c[1]) - c[1].merge(c[2]) - assert c[0].find().ops == ops + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + assert cset.find(c[0]).ops == ops def test_component_with_ops_merge_same_component(): def is_mergeable(_: cirq.Operation) -> bool: return True - def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: - del ops1, ops2 + def can_merge(_ops1: list[cirq.Operation], _ops2: list[cirq.Operation]) -> bool: return True - factory = ComponentWithOpsFactory(is_mergeable, can_merge) + cset = ComponentWithOpsSet(is_mergeable, can_merge) - q = cirq.NamedQubit('x') - c = [factory.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] - c[1].merge(c[0]) - c[2].merge(c[1]) - assert c[0].merge(c[2]) == c[1] + q = cirq.LineQubit.range(3) + ops = [cirq.X(q[i]) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) + assert cset.merge(c[0], c[2]).ops == ops def test_component_with_ops_merge_when_merge_fails(): def is_mergeable(_: cirq.Operation) -> bool: return True - def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: - del ops1, ops2 + def can_merge(_ops1: list[cirq.Operation], _ops2: list[cirq.Operation]) -> bool: return False - factory = ComponentWithOpsFactory(is_mergeable, can_merge) + cset = ComponentWithOpsSet(is_mergeable, can_merge) q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] - c[0].merge(c[1]) - c[1].merge(c[2]) + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) # No merge happened for i in range(3): - assert c[i].find() == c[i] + assert cset.find(c[i]) == c[i] def test_component_with_ops_merge_when_is_mergeable_is_false(): def is_mergeable(_: cirq.Operation) -> bool: return False - def can_merge(ops1: list[cirq.Operation], ops2: list[cirq.Operation]) -> bool: - del ops1, ops2 + def can_merge(_ops1: list[cirq.Operation], _ops2: list[cirq.Operation]) -> bool: return True - factory = ComponentWithOpsFactory(is_mergeable, can_merge) + cset = ComponentWithOpsSet(is_mergeable, can_merge) q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] - c[0].merge(c[1]) - c[1].merge(c[2]) + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) # No merge happened for i in range(3): - assert c[i].find() == c[i] + assert cset.find(c[i]) == c[i] def test_component_with_circuit_op_merge(): @@ -223,16 +239,16 @@ def is_mergeable(_: cirq.Operation) -> bool: def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: return op1 - factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + cset = ComponentWithCircuitOpSet(is_mergeable, merge_func) q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] - c[0].merge(c[1]) - c[1].merge(c[2]) + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) for i in range(3): - assert c[i].find().circuit_op == ops[0] + assert cset.find(c[i]).circuit_op == ops[0] def test_component_with_circuit_op_merge_same_component(): @@ -242,34 +258,33 @@ def is_mergeable(_: cirq.Operation) -> bool: def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: return op1 - factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + cset = ComponentWithCircuitOpSet(is_mergeable, merge_func) q = cirq.NamedQubit('x') - c = [factory.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] - c[1].merge(c[0]) - c[2].merge(c[1]) - assert c[0].merge(c[2]) == c[1] + c = [cset.new_component(op=cirq.X(q), moment_id=i) for i in range(3)] + cset.merge(c[1], c[0]) + cset.merge(c[2], c[1]) + assert cset.merge(c[0], c[2]) == cset.find(c[1]) def test_component_with_circuit_op_merge_func_is_none(): def is_mergeable(_: cirq.Operation) -> bool: return True - def merge_func(op1: cirq.Operation, op2: cirq.Operation) -> None: - del op1, op2 + def merge_func(_op1: cirq.Operation, _op2: cirq.Operation) -> None: return None - factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + cset = ComponentWithCircuitOpSet(is_mergeable, merge_func) q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] - c[0].merge(c[1]) - c[1].merge(c[2]) + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) # No merge happened for i in range(3): - assert c[i].find() == c[i] + assert cset.find(c[i]) == c[i] def test_component_with_circuit_op_merge_when_is_mergeable_is_false(): @@ -279,30 +294,14 @@ def is_mergeable(_: cirq.Operation) -> bool: def merge_func(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: return op1 - factory = ComponentWithCircuitOpFactory(is_mergeable, merge_func) + cset = ComponentWithCircuitOpSet(is_mergeable, merge_func) q = cirq.LineQubit.range(3) ops = [cirq.X(q[i]) for i in range(3)] - c = [factory.new_component(op=ops[i], moment_id=i) for i in range(3)] + c = [cset.new_component(op=ops[i], moment_id=i) for i in range(3)] - c[0].merge(c[1]) - c[1].merge(c[2]) + cset.merge(c[0], c[1]) + cset.merge(c[1], c[2]) # No merge happened for i in range(3): - assert c[i].find() == c[i] - - -def test_component_with_circuit_op_merge_when_merge_left_is_false(): - def merge_func_x(op1: cirq.Operation, _: cirq.Operation) -> cirq.Operation: - return op1 - - def merge_func_y(_: cirq.Operation, op2: cirq.Operation) -> cirq.Operation: - return op2 - - q = cirq.LineQubit.range(2) - x = ComponentWithCircuitOp(cirq.X(q[0]), moment_id=0, merge_func=merge_func_x) - y = ComponentWithCircuitOp(cirq.X(q[1]), moment_id=1, merge_func=merge_func_y) - - root = cast(ComponentWithCircuitOp, x.merge(y, merge_left=False)) - # The merge used merge_func_y because merge_left=False - assert root.circuit_op == cirq.X(q[1]) + assert cset.find(c[i]) == c[i] diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 74e2a4ac948..1bb69276681 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -26,10 +26,10 @@ from cirq.circuits.circuit import CIRCUIT_TYPE from cirq.transformers.connected_component import ( Component, - ComponentFactory, + ComponentSet, ComponentWithCircuitOp, - ComponentWithCircuitOpFactory, - ComponentWithOpsFactory, + ComponentWithCircuitOpSet, + ComponentWithOpsSet, ) if TYPE_CHECKING: @@ -357,9 +357,9 @@ def remove_component(self, c: Component, c_data: Component) -> None: """Removes a component from the merged circuit. Args: - c: reference to the component to be removed - c_data: copy of the data in c before any component merges involving c - (this is necessary as component merges alter the component data) + c: Reference to the component to be removed. + c_data: Copy of the data in c before any component merges involving c + (this is necessary as component merges alter the component data). """ self.components_by_index[c_data.moment_id].pop(c) for q in c_data.qubits: @@ -373,11 +373,11 @@ def get_mergeable_components(self, c: Component, c_qs: set[cirq.Qid]) -> list[Co """Finds all components that can be merged with c. Args: - c: component to be merged with existing components - c_qs: subset of c.qubits used to decide which components are mergeable + c: Component to be merged with existing components. + c_qs: Subset of c.qubits used to decide which components are mergeable. Returns: - list of mergeable components + List of mergeable components. """ # Find the index of previous moment which can be merged with `c`. idx = max([self.qubit_indexes[q][-1] for q in c_qs], default=-1) @@ -389,23 +389,21 @@ def get_mergeable_components(self, c: Component, c_qs: set[cirq.Qid]) -> list[Co return [c for c in self.components_by_index[idx] if not c_qs.isdisjoint(c.qubits)] - def get_cirq_circuit( - self, components: list[Component], merged_circuit_op_tag: str - ) -> cirq.Circuit: + def get_cirq_circuit(self, cset: ComponentSet, merged_circuit_op_tag: str) -> cirq.Circuit: """Returns the merged circuit. Args: - components: all components in creation order - merged_circuit_op_tag: tag to use for CircuitOperations + cset: Disjoint set data structure containing the components. + merged_circuit_op_tag: Tag to use for CircuitOperations. Returns: - the circuit with merged components as a CircuitOperation + The circuit with merged components as a CircuitOperation. """ component_ops: dict[Component, list[cirq.Operation]] = defaultdict(list) # Traverse the components in creation order and collect operations - for c in components: - root = c.find() + for c in cset.components(): + root = cset.find(c) component_ops[root].append(c.op) moments = [] @@ -429,7 +427,7 @@ def get_cirq_circuit( def _merge_operations_impl( circuit: CIRCUIT_TYPE, - factory: ComponentFactory, + cset: ComponentSet, *, merged_circuit_op_tag: str = "Merged connected component", tags_to_ignore: Sequence[Hashable] = (), @@ -445,9 +443,8 @@ def _merge_operations_impl( to repeatedly merge each operation in the latest moment with all the corresponding merge-able operations to its left. - Operations are wrapped in a component and then component.merge is called to merge two - components. The factory can provide components with different implementations of the merge - function, allowing for optimizations. + Operations are wrapped in a component and then cset.merge() is called to merge two + components. If op1 and op2 are merged, both op1 and op2 are deleted from the circuit and the merged component is inserted at the index corresponding to the larger @@ -458,8 +455,8 @@ def _merge_operations_impl( Args: circuit: Input circuit to apply the transformations on. The input circuit is not mutated. - factory: Factory that creates components from an operation. - merged_circuit_op_tag: tag used for CircuitOperations created from merged components. + cset: Disjoint set data structure that is used to create and merge components. + merged_circuit_op_tag: Tag used for CircuitOperations created from merged components. tags_to_ignore: Sequence of tags which should be ignored during the merge: operations with these tags will not be merged. deep: If true, the transformer primitive will be recursively applied to all circuits @@ -469,7 +466,6 @@ def _merge_operations_impl( Returns: Copy of input circuit with merged operations. """ - components = [] # List of all components in creation order tags_to_ignore_set = set(tags_to_ignore) merged_circuit = _MergedCircuit() @@ -485,21 +481,19 @@ def _merge_operations_impl( merged_op = op_untagged.replace( circuit=_merge_operations_impl( op_untagged.circuit, - factory, + cset, merged_circuit_op_tag=merged_circuit_op_tag, tags_to_ignore=tags_to_ignore, deep=True, ) ).with_tags(*op.tags) - c = factory.new_component(merged_op, moment_idx, is_mergeable=False) - components.append(c) + c = cset.new_component(merged_op, moment_idx, is_mergeable=False) merged_circuit.add_component(c) continue - c = factory.new_component( + c = cset.new_component( op, moment_idx, is_mergeable=tags_to_ignore_set.isdisjoint(op.tags) ) - components.append(c) if not c.is_mergeable: merged_circuit.add_component(c) continue @@ -510,7 +504,7 @@ def _merge_operations_impl( # Make a shallow copy of the left component data before merge left_c_data = copy.copy(left_comp[0]) # Case-1: Try to merge c with the larger component on the left. - new_comp = left_comp[0].merge(c, merge_left=True) + new_comp = cset.merge(left_comp[0], c, merge_left=True) if new_comp is not None: merged_circuit.remove_component(left_comp[0], left_c_data) merged_circuit.add_component(new_comp) @@ -526,7 +520,7 @@ def _merge_operations_impl( # Make a shallow copy of the left component data before merge left_c_data = copy.copy(left_c) # Try to merge left_c into c - new_comp = left_c.merge(c, merge_left=False) + new_comp = cset.merge(left_c, c, merge_left=False) if new_comp is not None: merged_circuit.remove_component(left_c, left_c_data) c, is_merged = new_comp, True @@ -534,7 +528,7 @@ def _merge_operations_impl( c_qs -= left_c.qubits left_comp = merged_circuit.get_mergeable_components(c, c_qs) merged_circuit.add_component(c) - ret_circuit = merged_circuit.get_cirq_circuit(components, merged_circuit_op_tag) + ret_circuit = merged_circuit.get_cirq_circuit(cset, merged_circuit_op_tag) return _to_target_circuit_type(ret_circuit, circuit) @@ -607,7 +601,7 @@ def is_mergeable(_: cirq.Operation): return _merge_operations_impl( circuit, - ComponentWithCircuitOpFactory(is_mergeable, apply_merge_func), + ComponentWithCircuitOpSet(is_mergeable, apply_merge_func), tags_to_ignore=tags_to_ignore, deep=deep, ) @@ -650,7 +644,7 @@ def is_mergeable(_: cirq.Operation): return _merge_operations_impl( circuit, - ComponentWithOpsFactory(is_mergeable, can_merge), + ComponentWithOpsSet(is_mergeable, can_merge), merged_circuit_op_tag=merged_circuit_op_tag, tags_to_ignore=tags_to_ignore, deep=deep, @@ -690,7 +684,7 @@ def is_mergeable(op: cirq.Operation): return _merge_operations_impl( circuit, - ComponentFactory(is_mergeable), + ComponentSet(is_mergeable), merged_circuit_op_tag=merged_circuit_op_tag or f"Merged {k}q unitary connected component.", tags_to_ignore=tags_to_ignore, deep=deep, From 3b783c6283779f1f1676ddf3eb79373bb5fc613f Mon Sep 17 00:00:00 2001 From: Codrut Date: Sat, 16 Aug 2025 00:16:57 +0200 Subject: [PATCH 6/7] Update attribute comment. --- cirq-core/cirq/transformers/transformer_primitives.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cirq-core/cirq/transformers/transformer_primitives.py b/cirq-core/cirq/transformers/transformer_primitives.py index 1bb69276681..ba30f57f6e2 100644 --- a/cirq-core/cirq/transformers/transformer_primitives.py +++ b/cirq-core/cirq/transformers/transformer_primitives.py @@ -304,9 +304,9 @@ class _MergedCircuit: ckey_indexes: Mapping from measurement keys to (sorted) list of component moments containing classically controlled operations controlled on the same key. components_by_index: List of components indexed by moment id. - For a moment id, we use a dictionary instead of a set to keep track of the - components in the moment. The dictionary is used to preserve insertion order, - and the values have no meaning. + For a moment id, we use a dictionary to keep track of the + components in the moment. The dictionary instead of a set is used to preserve + insertion order and the dictionary's values are intentionally unused. """ qubit_indexes: dict[cirq.Qid, list[int]] = dataclasses.field( From 867f2c096062d98b8a5e670c6d3e01c38bec264c Mon Sep 17 00:00:00 2001 From: Codrut Date: Sat, 16 Aug 2025 11:30:53 +0200 Subject: [PATCH 7/7] Fix lint issue. --- cirq-core/cirq/transformers/connected_component.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cirq-core/cirq/transformers/connected_component.py b/cirq-core/cirq/transformers/connected_component.py index f611e24a747..a905fa0cd49 100644 --- a/cirq-core/cirq/transformers/connected_component.py +++ b/cirq-core/cirq/transformers/connected_component.py @@ -178,7 +178,7 @@ def merge(self, x: Component, y: Component, merge_left=True) -> Component | None if not x.is_mergeable or not y.is_mergeable or not self._can_merge(x.ops, y.ops): return None - root = cast(ComponentWithOps, super(ComponentWithOpsSet, self).merge(x, y, merge_left)) + root = cast(ComponentWithOps, super().merge(x, y, merge_left)) root.ops = x.ops + y.ops # Clear the ops list in the non-representative component to avoid memory consumption if x != root: @@ -224,9 +224,7 @@ def merge(self, x: Component, y: Component, merge_left=True) -> Component | None if not new_op: return None - root = cast( - ComponentWithCircuitOp, super(ComponentWithCircuitOpSet, self).merge(x, y, merge_left) - ) + root = cast(ComponentWithCircuitOp, super().merge(x, y, merge_left)) root.circuit_op = new_op # The merge_func can be arbitrary, so we need to recompute the component properties