From f22688c6715eda8a9b67ba44a35568519069aafb Mon Sep 17 00:00:00 2001 From: daxfo Date: Wed, 10 Sep 2025 13:06:00 -0700 Subject: [PATCH 01/11] Preserve placement cache during inserts --- cirq-core/cirq/circuits/circuit.py | 104 ++++++++++++++++++------ cirq-core/cirq/circuits/circuit_test.py | 5 +- 2 files changed, 85 insertions(+), 24 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 15152fc2c6e..5fe7eea8ac0 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -22,6 +22,7 @@ from __future__ import annotations import abc +import copy import enum import html import itertools @@ -1957,8 +1958,13 @@ def copy(self) -> Circuit: """Return a copy of this circuit.""" copied_circuit = Circuit() copied_circuit._moments[:] = self._moments - copied_circuit._placement_cache = None + copied_circuit._placement_cache = copy.deepcopy(self._placement_cache) copied_circuit._tags = self.tags + copied_circuit._all_qubits = self._all_qubits + copied_circuit._frozen = self._frozen + copied_circuit._is_measurement = self._is_measurement + copied_circuit._is_parameterized = self._is_parameterized + copied_circuit._parameter_names = copy.copy(self._parameter_names) return copied_circuit @overload @@ -2002,7 +2008,7 @@ def __radd__(self, other): return NotImplemented # Auto wrap OP_TREE inputs into a circuit. result = self.copy() - result._moments[:0] = Circuit(other)._moments + result._insert_moment(0, *Circuit(other)._moments) return result # Needed for numpy to handle multiplication by np.int64 correctly. @@ -2011,8 +2017,13 @@ def __radd__(self, other): def __imul__(self, repetitions: _INT_TYPE): if not isinstance(repetitions, (int, np.integer)): return NotImplemented + moment_len = len(self._moments) self._moments *= int(repetitions) - self._mutated() + if self._placement_cache: + for _ in range(len(self._moments) - moment_len): + # todo: add "count" + self._placement_cache.insert_moment(0) + self._mutated(preserve_placement_cache=True) return self def __mul__(self, repetitions: _INT_TYPE): @@ -2186,10 +2197,9 @@ def insert( """ # limit index to 0..len(self._moments), also deal with indices smaller 0 k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0) - if strategy != InsertStrategy.EARLIEST or k != len(self._moments): - self._placement_cache = None + appending = strategy == InsertStrategy.EARLIEST and k == len(self._moments) mops = list(ops.flatten_to_ops_or_moments(moment_or_operation_tree)) - if self._placement_cache: + if appending: batches = [mops] # Any grouping would work here; this just happens to be the fastest. elif strategy is InsertStrategy.NEW: batches = [[mop] for mop in mops] # Each op goes into its own moment. @@ -2198,7 +2208,7 @@ def insert( for batch in batches: # Insert a moment if inline/earliest and _any_ op in the batch requires it. if ( - not self._placement_cache + not appending and not isinstance(batch[0], Moment) and strategy in (InsertStrategy.INLINE, InsertStrategy.EARLIEST) and not all( @@ -2207,18 +2217,18 @@ def insert( for op in cast(list[cirq.Operation], batch) ) ): - self._moments.insert(k, Moment()) + self._insert_moment(k) if strategy is InsertStrategy.INLINE: k += 1 max_p = 0 for moment_or_op in batch: # Determine Placement - if self._placement_cache: + if self._placement_cache and appending: p = self._placement_cache.append(moment_or_op) elif isinstance(moment_or_op, Moment): p = k elif strategy in (InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE): - self._moments.insert(k, Moment()) + self._insert_moment(k) p = k elif strategy is InsertStrategy.INLINE: p = k - 1 @@ -2226,11 +2236,9 @@ def insert( p = self.earliest_available_moment(moment_or_op, end_moment_index=k) # Place if isinstance(moment_or_op, Moment): - self._moments.insert(p, moment_or_op) - elif p == len(self._moments): - self._moments.append(Moment(moment_or_op)) + self._insert_moment(p, moment_or_op) else: - self._moments[p] = self._moments[p].with_operation(moment_or_op) + self._put_op(p, moment_or_op, skip_cache=appending) # Iterate max_p = max(p, max_p) if strategy is InsertStrategy.NEW_THEN_INLINE: @@ -2240,6 +2248,28 @@ def insert( self._mutated(preserve_placement_cache=True) return k + def _insert_moment(self, k: int, *moments: Moment, count:int = 1): + # todo: *args are slow + if not moments: + moments = [Moment()] + moments *= count + self._moments[k:k] = moments + if self._placement_cache: + for i, m in enumerate(moments): + # todo: overload with moments + self._placement_cache.insert_moment(k + i) + self._placement_cache.put(m, k + i) + + def _put_op(self, k: int, *ops: cirq.Operation, skip_cache: bool=False): + # todo: *args are slow, need OP_TREE too + if k == len(self._moments): + self._moments.append(Moment(ops)) + else: + self._moments[k] = self._moments[k].with_operations(*ops) + if self._placement_cache and not skip_cache: + for op in ops: + self._placement_cache.put(op, k) + def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> int: """Writes operations inline into an area of the circuit. @@ -2272,9 +2302,9 @@ def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> i if i >= end: break - self._moments[i] = self._moments[i].with_operation(op) + self._put_op(i, op) op_index += 1 - self._mutated() + self._mutated(preserve_placement_cache=True) if op_index >= len(flat_ops): return end @@ -2319,8 +2349,8 @@ def _push_frontier( ) if n_new_moments > 0: insert_index = min(late_frontier.values()) - self._moments[insert_index:insert_index] = [Moment()] * n_new_moments - self._mutated() + self._insert_moment(insert_index, count=n_new_moments) + self._mutated(preserve_placement_cache=True) for q in update_qubits: if early_frontier.get(q, 0) > insert_index: early_frontier[q] += n_new_moments @@ -2347,12 +2377,12 @@ def _insert_operations( if len(operations) != len(insertion_indices): raise ValueError('operations and insertion_indices must have the same length.') self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))] - self._mutated() + self._mutated(preserve_placement_cache=True) moment_to_ops: dict[int, list[cirq.Operation]] = defaultdict(list) for op_index, moment_index in enumerate(insertion_indices): moment_to_ops[moment_index].append(operations[op_index]) for moment_index, new_ops in moment_to_ops.items(): - self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops) + self._put_op(moment_index, *new_ops) def insert_at_frontier( self, operations: cirq.OP_TREE, start: int, frontier: dict[cirq.Qid, int] | None = None @@ -2455,9 +2485,11 @@ def batch_insert_into(self, insert_intos: Iterable[tuple[int, cirq.OP_TREE]]) -> """ copy = self.copy() for i, insertions in insert_intos: - copy._moments[i] = copy._moments[i].with_operations(insertions) + copy._put_op(i, insertions) + # todo: add self._copy_from() self._moments = copy._moments - self._mutated() + self._placement_cache = copy._placement_cache + self._mutated(preserve_placement_cache=True) def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: """Applies a batched insert operation to the circuit. @@ -2492,7 +2524,8 @@ def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: if next_index > insert_index: shift += next_index - insert_index self._moments = copy._moments - self._mutated() + self._placement_cache = copy._placement_cache + self._mutated(preserve_placement_cache=True) def append( self, @@ -2539,6 +2572,7 @@ def with_tags(self, *new_tags: Hashable) -> cirq.Circuit: return self new_circuit = Circuit(tags=self.tags + new_tags) new_circuit._moments[:] = self._moments + new_circuit._placement_cache = copy.deepcopy(self._placement_cache) # todo: implement return new_circuit def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit: @@ -3062,3 +3096,27 @@ def append(self, moment_or_operation: _MOMENT_OR_OP) -> int: ) self._length = max(self._length, index + 1) return index + + def put(self, moment_or_operation: _MOMENT_OR_OP, index: int): + self._update_index(self._qubit_indices, moment_or_operation.qubits, index) + self._update_index(self._mkey_indices, protocols.measurement_key_objs(moment_or_operation), index) + self._update_index(self._ckey_indices, protocols.control_keys(moment_or_operation), index) + self._length = max(self._length, index + 1) + + def insert_moment(self, index: int): + self._insert_moment(self._qubit_indices, index) + self._insert_moment(self._mkey_indices, index) + self._insert_moment(self._ckey_indices, index) + self._length += 1 + + @staticmethod + def _update_index(key_indices, mop_keys, mop_index): + for key in mop_keys: + key_indices[key] = max(mop_index, key_indices.get(key, -1)) + + @staticmethod + def _insert_moment(key_indices, index): + for key in key_indices: + key_index = key_indices[key] + if key_index >= index: + key_indices[key] = key_index + 1 diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index cd8f7b00c70..6e88d7fc691 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4931,11 +4931,14 @@ def test_append_speed() -> None: t = time.perf_counter() # Iterating with the moments in the inner loop highlights the improvement: when filling in the # second qubit, we no longer have to search backwards from moment 10000 for a placement index. + # c.insert(0, xs[0]) + # c.insert(0, xs[0]) for q in range(qs): for _ in range(moments): c.append(xs[q]) duration = time.perf_counter() - t - assert len(c) == moments + print(duration) + #assert len(c) == moments assert duration < 5 From f0a2eaf7571ad3aaea95a2cca43c6a665c93f3af Mon Sep 17 00:00:00 2001 From: daxfo Date: Wed, 10 Sep 2025 16:21:45 -0700 Subject: [PATCH 02/11] Fix bugs --- cirq-core/cirq/circuits/circuit.py | 71 +++++++++---------- cirq-core/cirq/circuits/circuit_test.py | 11 +-- .../cirq/circuits/frozen_circuit_test.py | 9 ++- cirq-core/cirq/circuits/moment.py | 3 + 4 files changed, 47 insertions(+), 47 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 5fe7eea8ac0..d57ad1ed9c0 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1967,6 +1967,16 @@ def copy(self) -> Circuit: copied_circuit._parameter_names = copy.copy(self._parameter_names) return copied_circuit + def _copy_from(self, other: Circuit) -> None: + """Copies the contents of another circuit into this one.""" + self._moments[:] = other._moments + self._placement_cache = copy.deepcopy(other._placement_cache) + self._tags = other.tags + self._all_qubits = other._all_qubits + self._is_measurement = other._is_measurement + self._is_parameterized = other._is_parameterized + self._parameter_names = copy.copy(other._parameter_names) + @overload def __setitem__(self, key: int, value: cirq.Moment): pass @@ -2017,13 +2027,10 @@ def __radd__(self, other): def __imul__(self, repetitions: _INT_TYPE): if not isinstance(repetitions, (int, np.integer)): return NotImplemented - moment_len = len(self._moments) + num_moments_added = len(self._moments) * (repetitions - 1) self._moments *= int(repetitions) if self._placement_cache: - for _ in range(len(self._moments) - moment_len): - # todo: add "count" - self._placement_cache.insert_moment(0) - self._mutated(preserve_placement_cache=True) + self._placement_cache.insert_moment(0, num_moments_added) return self def __mul__(self, repetitions: _INT_TYPE): @@ -2199,7 +2206,7 @@ def insert( k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0) appending = strategy == InsertStrategy.EARLIEST and k == len(self._moments) mops = list(ops.flatten_to_ops_or_moments(moment_or_operation_tree)) - if appending: + if self._placement_cache and appending: batches = [mops] # Any grouping would work here; this just happens to be the fastest. elif strategy is InsertStrategy.NEW: batches = [[mop] for mop in mops] # Each op goes into its own moment. @@ -2236,7 +2243,7 @@ def insert( p = self.earliest_available_moment(moment_or_op, end_moment_index=k) # Place if isinstance(moment_or_op, Moment): - self._insert_moment(p, moment_or_op) + self._insert_moment(p, moment_or_op, skip_cache=appending) else: self._put_op(p, moment_or_op, skip_cache=appending) # Iterate @@ -2248,22 +2255,19 @@ def insert( self._mutated(preserve_placement_cache=True) return k - def _insert_moment(self, k: int, *moments: Moment, count:int = 1): - # todo: *args are slow + def _insert_moment(self, k: int, *moments: Moment, count: int = 1, skip_cache: bool = False): if not moments: - moments = [Moment()] + moments = (Moment(),) moments *= count self._moments[k:k] = moments - if self._placement_cache: + if self._placement_cache and not skip_cache: + self._placement_cache.insert_moment(k, len(moments)) for i, m in enumerate(moments): - # todo: overload with moments - self._placement_cache.insert_moment(k + i) self._placement_cache.put(m, k + i) - def _put_op(self, k: int, *ops: cirq.Operation, skip_cache: bool=False): - # todo: *args are slow, need OP_TREE too + def _put_op(self, k: int, *ops: cirq.Operation, skip_cache: bool = False): if k == len(self._moments): - self._moments.append(Moment(ops)) + self._moments.append(Moment.from_ops(*ops)) else: self._moments[k] = self._moments[k].with_operations(*ops) if self._placement_cache and not skip_cache: @@ -2485,11 +2489,8 @@ def batch_insert_into(self, insert_intos: Iterable[tuple[int, cirq.OP_TREE]]) -> """ copy = self.copy() for i, insertions in insert_intos: - copy._put_op(i, insertions) - # todo: add self._copy_from() - self._moments = copy._moments - self._placement_cache = copy._placement_cache - self._mutated(preserve_placement_cache=True) + copy._put_op(i, *ops.flatten_to_ops(insertions)) + self._copy_from(copy) def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: """Applies a batched insert operation to the circuit. @@ -2523,9 +2524,7 @@ def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: next_index = copy.insert(insert_index, reversed(group), InsertStrategy.EARLIEST) if next_index > insert_index: shift += next_index - insert_index - self._moments = copy._moments - self._placement_cache = copy._placement_cache - self._mutated(preserve_placement_cache=True) + self._copy_from(copy) def append( self, @@ -2572,7 +2571,7 @@ def with_tags(self, *new_tags: Hashable) -> cirq.Circuit: return self new_circuit = Circuit(tags=self.tags + new_tags) new_circuit._moments[:] = self._moments - new_circuit._placement_cache = copy.deepcopy(self._placement_cache) # todo: implement + new_circuit._placement_cache = copy.deepcopy(self._placement_cache) # todo: implement return new_circuit def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit: @@ -3098,25 +3097,25 @@ def append(self, moment_or_operation: _MOMENT_OR_OP) -> int: return index def put(self, moment_or_operation: _MOMENT_OR_OP, index: int): - self._update_index(self._qubit_indices, moment_or_operation.qubits, index) - self._update_index(self._mkey_indices, protocols.measurement_key_objs(moment_or_operation), index) - self._update_index(self._ckey_indices, protocols.control_keys(moment_or_operation), index) + self._put(self._qubit_indices, moment_or_operation.qubits, index) + self._put(self._mkey_indices, protocols.measurement_key_objs(moment_or_operation), index) + self._put(self._ckey_indices, protocols.control_keys(moment_or_operation), index) self._length = max(self._length, index + 1) - def insert_moment(self, index: int): - self._insert_moment(self._qubit_indices, index) - self._insert_moment(self._mkey_indices, index) - self._insert_moment(self._ckey_indices, index) - self._length += 1 + def insert_moment(self, index: int, count: int = 1): + self._insert_moment(self._qubit_indices, index, count) + self._insert_moment(self._mkey_indices, index, count) + self._insert_moment(self._ckey_indices, index, count) + self._length += count @staticmethod - def _update_index(key_indices, mop_keys, mop_index): + def _put[T](key_indices: dict[T, int], mop_keys: Iterable[T], mop_index: int): for key in mop_keys: key_indices[key] = max(mop_index, key_indices.get(key, -1)) @staticmethod - def _insert_moment(key_indices, index): + def _insert_moment[T](key_indices: dict[T, int], index: int, count: int): for key in key_indices: key_index = key_indices[key] if key_index >= index: - key_indices[key] = key_index + 1 + key_indices[key] = key_index + count diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 6e88d7fc691..cd2cab8e72a 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4641,7 +4641,6 @@ def test_freeze_is_cached() -> None: lambda c: c.__setitem__(0, cirq.Moment(cirq.Y(cirq.q(0)))), ), (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__delitem__(0)), - (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2)), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert(1, cirq.Y(cirq.q(0))), @@ -4658,14 +4657,6 @@ def test_freeze_is_cached() -> None: cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.batch_replace([(0, cirq.X(cirq.q(0)), cirq.Y(cirq.q(0)))]), ), - ( - cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0), cirq.q(1))), - lambda c: c.batch_insert_into([(0, cirq.X(cirq.q(1)))]), - ), - ( - cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), - lambda c: c.batch_insert([(1, cirq.Y(cirq.q(0)))]), - ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.clear_operations_touching([cirq.q(0)], [0]), @@ -4938,7 +4929,7 @@ def test_append_speed() -> None: c.append(xs[q]) duration = time.perf_counter() - t print(duration) - #assert len(c) == moments + # assert len(c) == moments assert duration < 5 diff --git a/cirq-core/cirq/circuits/frozen_circuit_test.py b/cirq-core/cirq/circuits/frozen_circuit_test.py index fda9c1fe1be..12528d88d7b 100644 --- a/cirq-core/cirq/circuits/frozen_circuit_test.py +++ b/cirq-core/cirq/circuits/frozen_circuit_test.py @@ -76,9 +76,16 @@ def test_freeze_and_unfreeze() -> None: cc = c.unfreeze() assert cc is not c + # Refreezing without modification returns original FrozenCircuit. fcc = cc.freeze() assert fcc.moments == f.moments - assert fcc is not f + assert fcc is f + + # Modifying and refreezing returns new FrozenCircuit. + cc.append(cirq.X(a)) + fcc2 = cc.freeze() + assert tuple(cc.moments) == fcc2.moments + assert fcc2 is not f def test_immutable() -> None: diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index 414e2fab0f2..051fc1ecce8 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -256,6 +256,9 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment: Raises: ValueError: If the contents given overlaps a current operation in the moment. """ + if len(contents) == 1 and isinstance(contents[0], ops.Operation): + return self.with_operation(contents[0]) + flattened_contents = tuple(op_tree.flatten_to_ops(contents)) if not flattened_contents: From fa72d38b7b4b40d4a274208817798791f482e062 Mon Sep 17 00:00:00 2001 From: daxfo Date: Wed, 10 Sep 2025 21:49:19 -0700 Subject: [PATCH 03/11] nits --- cirq-core/cirq/circuits/circuit.py | 47 +++++++++++++---------- cirq-core/cirq/circuits/frozen_circuit.py | 4 +- 2 files changed, 28 insertions(+), 23 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index d57ad1ed9c0..f91edc19b01 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -30,7 +30,6 @@ from collections import defaultdict from types import NotImplementedType from typing import ( - AbstractSet, Any, Callable, cast, @@ -1342,10 +1341,10 @@ def _is_parameterized_(self) -> bool: protocols.is_parameterized(tag) for tag in self.tags ) - def _parameter_names_(self) -> AbstractSet[str]: + def _parameter_names_(self) -> frozenset[str]: op_params = {name for op in self.all_operations() for name in protocols.parameter_names(op)} tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)} - return op_params | tag_params + return frozenset(op_params | tag_params) def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> Self: changed = False @@ -1840,7 +1839,7 @@ def __init__( self._frozen: cirq.FrozenCircuit | None = None self._is_measurement: bool | None = None self._is_parameterized: bool | None = None - self._parameter_names: AbstractSet[str] | None = None + self._parameter_names: frozenset[str] | None = None if not contents: return flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents)) @@ -1949,7 +1948,7 @@ def _is_parameterized_(self) -> bool: self._is_parameterized = super()._is_parameterized_() return self._is_parameterized - def _parameter_names_(self) -> AbstractSet[str]: + def _parameter_names_(self) -> frozenset[str]: if self._parameter_names is None: self._parameter_names = super()._parameter_names_() return self._parameter_names @@ -1958,24 +1957,25 @@ def copy(self) -> Circuit: """Return a copy of this circuit.""" copied_circuit = Circuit() copied_circuit._moments[:] = self._moments - copied_circuit._placement_cache = copy.deepcopy(self._placement_cache) + copied_circuit._placement_cache = copy.copy(self._placement_cache) copied_circuit._tags = self.tags copied_circuit._all_qubits = self._all_qubits copied_circuit._frozen = self._frozen copied_circuit._is_measurement = self._is_measurement copied_circuit._is_parameterized = self._is_parameterized - copied_circuit._parameter_names = copy.copy(self._parameter_names) + copied_circuit._parameter_names = self._parameter_names return copied_circuit def _copy_from(self, other: Circuit) -> None: """Copies the contents of another circuit into this one.""" self._moments[:] = other._moments - self._placement_cache = copy.deepcopy(other._placement_cache) + self._placement_cache = other._placement_cache self._tags = other.tags self._all_qubits = other._all_qubits + self._frozen = other._frozen self._is_measurement = other._is_measurement self._is_parameterized = other._is_parameterized - self._parameter_names = copy.copy(other._parameter_names) + self._parameter_names = other._parameter_names @overload def __setitem__(self, key: int, value: cirq.Moment): @@ -2252,7 +2252,6 @@ def insert( strategy = InsertStrategy.INLINE k += 1 k = max(k, max_p + 1) - self._mutated(preserve_placement_cache=True) return k def _insert_moment(self, k: int, *moments: Moment, count: int = 1, skip_cache: bool = False): @@ -2264,6 +2263,7 @@ def _insert_moment(self, k: int, *moments: Moment, count: int = 1, skip_cache: b self._placement_cache.insert_moment(k, len(moments)) for i, m in enumerate(moments): self._placement_cache.put(m, k + i) + self._mutated(preserve_placement_cache=True) def _put_op(self, k: int, *ops: cirq.Operation, skip_cache: bool = False): if k == len(self._moments): @@ -2273,6 +2273,7 @@ def _put_op(self, k: int, *ops: cirq.Operation, skip_cache: bool = False): if self._placement_cache and not skip_cache: for op in ops: self._placement_cache.put(op, k) + self._mutated(preserve_placement_cache=True) def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> int: """Writes operations inline into an area of the circuit. @@ -2308,7 +2309,6 @@ def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> i self._put_op(i, op) op_index += 1 - self._mutated(preserve_placement_cache=True) if op_index >= len(flat_ops): return end @@ -2354,7 +2354,6 @@ def _push_frontier( if n_new_moments > 0: insert_index = min(late_frontier.values()) self._insert_moment(insert_index, count=n_new_moments) - self._mutated(preserve_placement_cache=True) for q in update_qubits: if early_frontier.get(q, 0) > insert_index: early_frontier[q] += n_new_moments @@ -2380,8 +2379,7 @@ def _insert_operations( """ if len(operations) != len(insertion_indices): raise ValueError('operations and insertion_indices must have the same length.') - self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))] - self._mutated(preserve_placement_cache=True) + self._insert_moment(len(self), count=1 + max(insertion_indices) - len(self)) moment_to_ops: dict[int, list[cirq.Operation]] = defaultdict(list) for op_index, moment_index in enumerate(insertion_indices): moment_to_ops[moment_index].append(operations[op_index]) @@ -2569,9 +2567,8 @@ def with_tags(self, *new_tags: Hashable) -> cirq.Circuit: """Creates a new tagged `Circuit` with `self.tags` and `new_tags` combined.""" if not new_tags: return self - new_circuit = Circuit(tags=self.tags + new_tags) - new_circuit._moments[:] = self._moments - new_circuit._placement_cache = copy.deepcopy(self._placement_cache) # todo: implement + new_circuit = self.copy() + new_circuit._tags = self.tags + new_tags return new_circuit def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit: @@ -3096,26 +3093,34 @@ def append(self, moment_or_operation: _MOMENT_OR_OP) -> int: self._length = max(self._length, index + 1) return index - def put(self, moment_or_operation: _MOMENT_OR_OP, index: int): + def put(self, moment_or_operation: _MOMENT_OR_OP, index: int) -> None: self._put(self._qubit_indices, moment_or_operation.qubits, index) self._put(self._mkey_indices, protocols.measurement_key_objs(moment_or_operation), index) self._put(self._ckey_indices, protocols.control_keys(moment_or_operation), index) self._length = max(self._length, index + 1) - def insert_moment(self, index: int, count: int = 1): + def insert_moment(self, index: int, count: int = 1) -> None: self._insert_moment(self._qubit_indices, index, count) self._insert_moment(self._mkey_indices, index, count) self._insert_moment(self._ckey_indices, index, count) self._length += count @staticmethod - def _put[T](key_indices: dict[T, int], mop_keys: Iterable[T], mop_index: int): + def _put[T](key_indices: dict[T, int], mop_keys: Iterable[T], mop_index: int) -> None: for key in mop_keys: key_indices[key] = max(mop_index, key_indices.get(key, -1)) @staticmethod - def _insert_moment[T](key_indices: dict[T, int], index: int, count: int): + def _insert_moment[T](key_indices: dict[T, int], index: int, count: int) -> None: for key in key_indices: key_index = key_indices[key] if key_index >= index: key_indices[key] = key_index + count + + def __copy__(self) -> _PlacementCache: + cache = _PlacementCache() + cache._qubit_indices = self._qubit_indices.copy() + cache._mkey_indices = self._mkey_indices.copy() + cache._ckey_indices = self._ckey_indices.copy() + cache._length = self._length + return cache diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index 2aa8d401bda..b4b03d51462 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -18,7 +18,7 @@ from functools import cached_property from types import NotImplementedType -from typing import AbstractSet, Hashable, Iterable, Iterator, Sequence, TYPE_CHECKING +from typing import Hashable, Iterable, Iterator, Sequence, TYPE_CHECKING from cirq import _compat, protocols from cirq.circuits import AbstractCircuit, Alignment, Circuit @@ -172,7 +172,7 @@ def _is_parameterized_(self) -> bool: return super()._is_parameterized_() @_compat.cached_method - def _parameter_names_(self) -> AbstractSet[str]: + def _parameter_names_(self) -> frozenset[str]: return super()._parameter_names_() def _measurement_key_names_(self) -> frozenset[str]: From b0dabab651bd85e320f138bd7898ff92d9fea917 Mon Sep 17 00:00:00 2001 From: daxfo Date: Wed, 10 Sep 2025 22:47:07 -0700 Subject: [PATCH 04/11] tests --- cirq-core/cirq/circuits/circuit.py | 18 ++-- cirq-core/cirq/circuits/circuit_test.py | 125 ++++++++++++++++++++---- 2 files changed, 114 insertions(+), 29 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index f91edc19b01..5886866c2f9 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1957,25 +1957,25 @@ def copy(self) -> Circuit: """Return a copy of this circuit.""" copied_circuit = Circuit() copied_circuit._moments[:] = self._moments - copied_circuit._placement_cache = copy.copy(self._placement_cache) copied_circuit._tags = self.tags copied_circuit._all_qubits = self._all_qubits copied_circuit._frozen = self._frozen copied_circuit._is_measurement = self._is_measurement copied_circuit._is_parameterized = self._is_parameterized copied_circuit._parameter_names = self._parameter_names + copied_circuit._placement_cache = copy.copy(self._placement_cache) return copied_circuit - def _copy_from(self, other: Circuit) -> None: + def _copy_from_shallow(self, other: Circuit) -> None: """Copies the contents of another circuit into this one.""" self._moments[:] = other._moments - self._placement_cache = other._placement_cache self._tags = other.tags self._all_qubits = other._all_qubits self._frozen = other._frozen self._is_measurement = other._is_measurement self._is_parameterized = other._is_parameterized self._parameter_names = other._parameter_names + self._placement_cache = other._placement_cache @overload def __setitem__(self, key: int, value: cirq.Moment): @@ -2034,14 +2034,10 @@ def __imul__(self, repetitions: _INT_TYPE): return self def __mul__(self, repetitions: _INT_TYPE): - if not isinstance(repetitions, (int, np.integer)): - return NotImplemented - return Circuit(self._moments * int(repetitions), tags=self.tags) + return self.copy().__imul__(repetitions) def __rmul__(self, repetitions: _INT_TYPE): - if not isinstance(repetitions, (int, np.integer)): - return NotImplemented - return self * int(repetitions) + return self.copy().__imul__(repetitions) def __pow__(self, exponent: int) -> cirq.Circuit: """A circuit raised to a power, only valid for exponent -1, the inverse. @@ -2488,7 +2484,7 @@ def batch_insert_into(self, insert_intos: Iterable[tuple[int, cirq.OP_TREE]]) -> copy = self.copy() for i, insertions in insert_intos: copy._put_op(i, *ops.flatten_to_ops(insertions)) - self._copy_from(copy) + self._copy_from_shallow(copy) def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: """Applies a batched insert operation to the circuit. @@ -2522,7 +2518,7 @@ def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: next_index = copy.insert(insert_index, reversed(group), InsertStrategy.EARLIEST) if next_index > insert_index: shift += next_index - insert_index - self._copy_from(copy) + self._copy_from_shallow(copy) def append( self, diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index cd2cab8e72a..a96407552a3 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -19,7 +19,7 @@ import time from collections import defaultdict from random import randint, random, randrange, sample -from typing import Iterator, Sequence +from typing import Any, Callable, Iterator, Sequence import numpy as np import pytest @@ -4634,46 +4634,106 @@ def test_freeze_is_cached() -> None: @pytest.mark.parametrize( - "circuit, mutate", + "circuit, mutate, inserts, replaces_or_deletes", [ ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__setitem__(0, cirq.Moment(cirq.Y(cirq.q(0)))), + False, + True, ), - (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__delitem__(0)), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), + lambda c: c.__delitem__(0), + False, + True, + ), + # Formally `mul` does insert, but in a way that doesn't affect caches. + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2), False, False), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c * 2, False, False), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: 2 * c, False, False), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert(1, cirq.Y(cirq.q(0))), + True, + False, + ), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), + lambda c: c.__iadd__([cirq.Y(cirq.q(0))]), + True, + False, + ), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), + lambda c: c + [cirq.Y(cirq.q(0))], + True, + False, + ), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), + lambda c: [cirq.Y(cirq.q(0))] + c, + True, + False, ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert_into_range([cirq.Y(cirq.q(1)), cirq.M(cirq.q(1))], 0, 2), + True, + False, ), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.insert(1, []), False, False), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert_at_frontier([cirq.Y(cirq.q(0)), cirq.Y(cirq.q(1))], 1), + True, + False, ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.batch_replace([(0, cirq.X(cirq.q(0)), cirq.Y(cirq.q(0)))]), + False, + True, + ), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0), cirq.q(1))), + lambda c: c.batch_insert_into([(0, cirq.X(cirq.q(1)))]), + True, + False, + ), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), + lambda c: c.batch_insert([(1, cirq.Y(cirq.q(0)))]), + True, + False, ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.clear_operations_touching([cirq.q(0)], [0]), + False, + True, ), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.copy(), False, False), ], ) -def test_mutation_clears_cached_attributes(circuit, mutate) -> None: - cached_attributes = [ +def test_mutation_clears_cached_attributes( + circuit: cirq.Circuit, + mutate: Callable[[cirq.Circuit], Any], + inserts: bool, + replaces_or_deletes: bool, +) -> None: + standard_attributes = [ "_all_qubits", "_frozen", "_is_measurement", "_is_parameterized", "_parameter_names", ] + all_attributes = standard_attributes + ['_placement_cache'] - for attr in cached_attributes: + for attr in standard_attributes: assert getattr(circuit, attr) is None, f"{attr=} is not None" + assert circuit._placement_cache is not None # Check that attributes are cached after getting them. qubits = circuit.all_qubits() @@ -4682,7 +4742,7 @@ def test_mutation_clears_cached_attributes(circuit, mutate) -> None: is_parameterized = cirq.is_parameterized(circuit) parameter_names = cirq.parameter_names(circuit) - for attr in cached_attributes: + for attr in all_attributes: assert getattr(circuit, attr) is not None, f"{attr=} is None" # Check that getting again returns same object. @@ -4693,9 +4753,18 @@ def test_mutation_clears_cached_attributes(circuit, mutate) -> None: assert cirq.parameter_names(circuit) is parameter_names # Check that attributes are cleared after mutation. - mutate(circuit) - for attr in cached_attributes: - assert getattr(circuit, attr) is None, f"{attr=} is not None" + result = mutate(circuit) + if isinstance(result, cirq.Circuit): + # For functional "mutations" + circuit = result + + for attr in all_attributes: + # Standard attributes get cleared on any mutation (except `mul`), but placement cache only + # gets cleared on replacements and deletes. + if replaces_or_deletes or (inserts and attr in standard_attributes): + assert getattr(circuit, attr) is None, f"{attr=} is not None" + else: + assert getattr(circuit, attr) is not None, f"{attr=} is None" def test_factorize_one_factor() -> None: @@ -4908,7 +4977,25 @@ def test_create_speed() -> None: assert duration < 4 -def test_append_speed() -> None: +@pytest.mark.parametrize( + 'mutate', + [ + lambda c: c.insert(0, cirq.X(cirq.q('init'))), + lambda c: c.__imul__(2), + lambda c: c * 2, + lambda c: 2 * c, + lambda c: 2 * (c.copy()), + lambda c: (2 * c).copy(), + lambda c: c.__iadd__([cirq.X(cirq.q('init'))]), + lambda c: c + [cirq.X(cirq.q('init'))], + lambda c: [cirq.X(cirq.q('init'))] + c, + lambda c: c.insert_into_range([cirq.X(cirq.q('init'))], 0, 1), + lambda c: c.insert_at_frontier([cirq.X(cirq.q('init'))], 0), + lambda c: c.batch_insert_into([(1, cirq.X(cirq.q('init')))]), + lambda c: c.batch_insert([(0, cirq.X(cirq.q('init')))]), + ], +) +def test_append_speed(mutate) -> None: # Previously this took ~17s to run. Now it should take ~150ms. However the coverage test can # run this slowly, so allowing 5 sec to account for things like that. Feel free to increase the # buffer time or delete the test entirely if it ends up causing flakes. @@ -4918,18 +5005,20 @@ def test_append_speed() -> None: qs = 2 moments = 10000 xs = [cirq.X(cirq.LineQubit(i)) for i in range(qs)] - c = cirq.Circuit() + c = cirq.Circuit(cirq.X(cirq.q('init'))) + result = mutate(c) + if isinstance(result, cirq.Circuit): + # For functional "mutations" + c = result t = time.perf_counter() - # Iterating with the moments in the inner loop highlights the improvement: when filling in the - # second qubit, we no longer have to search backwards from moment 10000 for a placement index. - # c.insert(0, xs[0]) - # c.insert(0, xs[0]) + # Iterating with the moments in the inner loop highlights the improvement: when filling in + # the second qubit, we no longer have to search backwards from moment 10000 for a placement + # index. for q in range(qs): for _ in range(moments): c.append(xs[q]) duration = time.perf_counter() - t - print(duration) - # assert len(c) == moments + assert len(c) == moments assert duration < 5 From 07d00a5a5d33580bf5d415a6dd161bd300eb68aa Mon Sep 17 00:00:00 2001 From: daxfo Date: Thu, 11 Sep 2025 10:05:57 -0700 Subject: [PATCH 05/11] docstrings --- cirq-core/cirq/circuits/circuit.py | 96 ++++++++++++++++++++---------- 1 file changed, 63 insertions(+), 33 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 5886866c2f9..3715befeddf 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -2018,7 +2018,7 @@ def __radd__(self, other): return NotImplemented # Auto wrap OP_TREE inputs into a circuit. result = self.copy() - result._insert_moment(0, *Circuit(other)._moments) + result._insert_moments(0, *Circuit(other)._moments) return result # Needed for numpy to handle multiplication by np.int64 correctly. @@ -2030,7 +2030,7 @@ def __imul__(self, repetitions: _INT_TYPE): num_moments_added = len(self._moments) * (repetitions - 1) self._moments *= int(repetitions) if self._placement_cache: - self._placement_cache.insert_moment(0, num_moments_added) + self._placement_cache.insert_moments(0, num_moments_added) return self def __mul__(self, repetitions: _INT_TYPE): @@ -2220,18 +2220,21 @@ def insert( for op in cast(list[cirq.Operation], batch) ) ): - self._insert_moment(k) + self._insert_moments(k) if strategy is InsertStrategy.INLINE: k += 1 max_p = 0 for moment_or_op in batch: # Determine Placement if self._placement_cache and appending: + # This updates the cache and returns placement in a single step. It would be + # cleaner to "check" placement here and avoid the special `skip_cache_update` + # args below, but that adds about 15% latency to this perf-critical case. p = self._placement_cache.append(moment_or_op) elif isinstance(moment_or_op, Moment): p = k elif strategy in (InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE): - self._insert_moment(k) + self._insert_moments(k) p = k elif strategy is InsertStrategy.INLINE: p = k - 1 @@ -2239,9 +2242,9 @@ def insert( p = self.earliest_available_moment(moment_or_op, end_moment_index=k) # Place if isinstance(moment_or_op, Moment): - self._insert_moment(p, moment_or_op, skip_cache=appending) + self._insert_moments(p, moment_or_op, skip_cache_update=appending) else: - self._put_op(p, moment_or_op, skip_cache=appending) + self._put_ops(p, moment_or_op, skip_cache_update=appending) # Iterate max_p = max(p, max_p) if strategy is InsertStrategy.NEW_THEN_INLINE: @@ -2250,25 +2253,49 @@ def insert( k = max(k, max_p + 1) return k - def _insert_moment(self, k: int, *moments: Moment, count: int = 1, skip_cache: bool = False): + def _insert_moments( + self, index: int, *moments: Moment, count: int = 1, skip_cache_update: bool = False + ): + """Inserts moments directly to a circuit at index k and updates caches. + + Args: + index: The moment index to insert the moment. If greater than the circuit length, the + moments will be appended. + moments: The moments to insert. If none are provided, a single empty moment will be + assumed. + count: The number of moments to insert. If both `moments` and `count` are provided, + the provided moments will be inserted `count` times. + skip_cache_update: Skips updates to the placement cache. Only use if the placement cache + has already been updated. + """ if not moments: moments = (Moment(),) moments *= count - self._moments[k:k] = moments - if self._placement_cache and not skip_cache: - self._placement_cache.insert_moment(k, len(moments)) + self._moments[index:index] = moments + if self._placement_cache and not skip_cache_update: + self._placement_cache.insert_moments(index, len(moments)) for i, m in enumerate(moments): - self._placement_cache.put(m, k + i) + self._placement_cache.put(index + i, m) self._mutated(preserve_placement_cache=True) - def _put_op(self, k: int, *ops: cirq.Operation, skip_cache: bool = False): - if k == len(self._moments): + def _put_ops(self, index: int, *ops: cirq.Operation, skip_cache_update: bool = False): + """Adds operations directly to moment k and updates caches. + + This is intended to be low-level and will fail if the moment does not exist or already has + conflicting operations. + + Args: + index: The moment index to add operations to. + ops: The operations to add. + skip_cache_update: Skips updates to the placement cache. Only use if the placement cache + has already been updated. + """ + if index == len(self._moments): self._moments.append(Moment.from_ops(*ops)) else: - self._moments[k] = self._moments[k].with_operations(*ops) - if self._placement_cache and not skip_cache: - for op in ops: - self._placement_cache.put(op, k) + self._moments[index] = self._moments[index].with_operations(*ops) + if self._placement_cache and not skip_cache_update: + self._placement_cache.put(index, *ops) self._mutated(preserve_placement_cache=True) def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> int: @@ -2303,7 +2330,7 @@ def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> i if i >= end: break - self._put_op(i, op) + self._put_ops(i, op) op_index += 1 if op_index >= len(flat_ops): @@ -2349,7 +2376,7 @@ def _push_frontier( ) if n_new_moments > 0: insert_index = min(late_frontier.values()) - self._insert_moment(insert_index, count=n_new_moments) + self._insert_moments(insert_index, count=n_new_moments) for q in update_qubits: if early_frontier.get(q, 0) > insert_index: early_frontier[q] += n_new_moments @@ -2375,12 +2402,12 @@ def _insert_operations( """ if len(operations) != len(insertion_indices): raise ValueError('operations and insertion_indices must have the same length.') - self._insert_moment(len(self), count=1 + max(insertion_indices) - len(self)) + self._insert_moments(len(self), count=1 + max(insertion_indices) - len(self)) moment_to_ops: dict[int, list[cirq.Operation]] = defaultdict(list) for op_index, moment_index in enumerate(insertion_indices): moment_to_ops[moment_index].append(operations[op_index]) for moment_index, new_ops in moment_to_ops.items(): - self._put_op(moment_index, *new_ops) + self._put_ops(moment_index, *new_ops) def insert_at_frontier( self, operations: cirq.OP_TREE, start: int, frontier: dict[cirq.Qid, int] | None = None @@ -2483,7 +2510,7 @@ def batch_insert_into(self, insert_intos: Iterable[tuple[int, cirq.OP_TREE]]) -> """ copy = self.copy() for i, insertions in insert_intos: - copy._put_op(i, *ops.flatten_to_ops(insertions)) + copy._put_ops(i, *ops.flatten_to_ops(insertions)) self._copy_from_shallow(copy) def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: @@ -3089,25 +3116,28 @@ def append(self, moment_or_operation: _MOMENT_OR_OP) -> int: self._length = max(self._length, index + 1) return index - def put(self, moment_or_operation: _MOMENT_OR_OP, index: int) -> None: - self._put(self._qubit_indices, moment_or_operation.qubits, index) - self._put(self._mkey_indices, protocols.measurement_key_objs(moment_or_operation), index) - self._put(self._ckey_indices, protocols.control_keys(moment_or_operation), index) - self._length = max(self._length, index + 1) - - def insert_moment(self, index: int, count: int = 1) -> None: - self._insert_moment(self._qubit_indices, index, count) - self._insert_moment(self._mkey_indices, index, count) - self._insert_moment(self._ckey_indices, index, count) + def insert_moments(self, index: int, count: int = 1) -> None: + """Updates cache to account for empty moments inserted at circuit[index].""" + self._insert_moments(self._qubit_indices, index, count) + self._insert_moments(self._mkey_indices, index, count) + self._insert_moments(self._ckey_indices, index, count) self._length += count + def put(self, index: int, *moments_or_operations: _MOMENT_OR_OP) -> None: + """Updates cache to account for ops added to circuit[index].""" + for mop in moments_or_operations: + self._put(self._qubit_indices, mop.qubits, index) + self._put(self._mkey_indices, protocols.measurement_key_objs(mop), index) + self._put(self._ckey_indices, protocols.control_keys(mop), index) + self._length = max(self._length, index + len(moments_or_operations)) + @staticmethod def _put[T](key_indices: dict[T, int], mop_keys: Iterable[T], mop_index: int) -> None: for key in mop_keys: key_indices[key] = max(mop_index, key_indices.get(key, -1)) @staticmethod - def _insert_moment[T](key_indices: dict[T, int], index: int, count: int) -> None: + def _insert_moments[T](key_indices: dict[T, int], index: int, count: int) -> None: for key in key_indices: key_index = key_indices[key] if key_index >= index: From d407a4fa1df8aef5fdf60d831f8f6d3c55bb5eb6 Mon Sep 17 00:00:00 2001 From: daxfo Date: Thu, 11 Sep 2025 15:15:59 -0700 Subject: [PATCH 06/11] from_moments --- cirq-core/cirq/circuits/circuit.py | 9 +++------ cirq-core/cirq/circuits/circuit_test.py | 13 +++++++++++-- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 3715befeddf..7da41d79c72 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1844,8 +1844,9 @@ def __init__( return flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents)) if all(isinstance(c, Moment) for c in flattened_contents): - self._placement_cache = None self._moments[:] = cast(Iterable[Moment], flattened_contents) + for i, moment in enumerate(self._moments): + self._placement_cache.put(i, moment) return with _compat.block_overlapping_deprecation('.*'): if strategy == InsertStrategy.EARLIEST: @@ -1865,11 +1866,7 @@ def _mutated(self, *, preserve_placement_cache=False) -> None: @classmethod def _from_moments(cls, moments: Iterable[cirq.Moment], tags: Sequence[Hashable]) -> Circuit: - new_circuit = Circuit() - new_circuit._moments[:] = moments - new_circuit._placement_cache = None - new_circuit._tags = tuple(tags) - return new_circuit + return Circuit(moments, tags=tags) def _load_contents_with_earliest_strategy(self, contents: cirq.OP_TREE): """Optimized algorithm to load contents quickly. diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index a96407552a3..c20cc524b06 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4977,6 +4977,13 @@ def test_create_speed() -> None: assert duration < 4 +@pytest.mark.parametrize( + 'contents', + [ + cirq.X(cirq.q('init')), + cirq.Moment(cirq.X(cirq.q('init'))), + ], +) @pytest.mark.parametrize( 'mutate', [ @@ -4986,6 +4993,8 @@ def test_create_speed() -> None: lambda c: 2 * c, lambda c: 2 * (c.copy()), lambda c: (2 * c).copy(), + lambda c: (2 * c).freeze().unfreeze(), + lambda c: (2 * c.freeze()).unfreeze(), lambda c: c.__iadd__([cirq.X(cirq.q('init'))]), lambda c: c + [cirq.X(cirq.q('init'))], lambda c: [cirq.X(cirq.q('init'))] + c, @@ -4995,7 +5004,7 @@ def test_create_speed() -> None: lambda c: c.batch_insert([(0, cirq.X(cirq.q('init')))]), ], ) -def test_append_speed(mutate) -> None: +def test_append_speed(contents, mutate) -> None: # Previously this took ~17s to run. Now it should take ~150ms. However the coverage test can # run this slowly, so allowing 5 sec to account for things like that. Feel free to increase the # buffer time or delete the test entirely if it ends up causing flakes. @@ -5005,7 +5014,7 @@ def test_append_speed(mutate) -> None: qs = 2 moments = 10000 xs = [cirq.X(cirq.LineQubit(i)) for i in range(qs)] - c = cirq.Circuit(cirq.X(cirq.q('init'))) + c = cirq.Circuit(contents) result = mutate(c) if isinstance(result, cirq.Circuit): # For functional "mutations" From 2dcc25222780aa2f1ad89650f8c061b511dfaf9d Mon Sep 17 00:00:00 2001 From: daxfo Date: Thu, 11 Sep 2025 17:18:38 -0700 Subject: [PATCH 07/11] clear _frozen for 'mul' --- cirq-core/cirq/circuits/circuit.py | 2 + cirq-core/cirq/circuits/circuit_test.py | 76 ++++++++++--------------- 2 files changed, 33 insertions(+), 45 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 7da41d79c72..39d1038c0f4 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -2028,6 +2028,8 @@ def __imul__(self, repetitions: _INT_TYPE): self._moments *= int(repetitions) if self._placement_cache: self._placement_cache.insert_moments(0, num_moments_added) + if self._frozen: + self._frozen = None return self def __mul__(self, repetitions: _INT_TYPE): diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index c20cc524b06..719a8eb08ac 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4634,93 +4634,75 @@ def test_freeze_is_cached() -> None: @pytest.mark.parametrize( - "circuit, mutate, inserts, replaces_or_deletes", + "circuit, mutate, action", [ ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__setitem__(0, cirq.Moment(cirq.Y(cirq.q(0)))), - False, - True, + 'update', ), - ( - cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), - lambda c: c.__delitem__(0), - False, - True, - ), - # Formally `mul` does insert, but in a way that doesn't affect caches. - (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2), False, False), - (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c * 2, False, False), - (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: 2 * c, False, False), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__delitem__(0), 'delete'), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2), 'mul'), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c * 2, 'mul'), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: 2 * c, 'mul'), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert(1, cirq.Y(cirq.q(0))), - True, - False, + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__iadd__([cirq.Y(cirq.q(0))]), - True, - False, + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c + [cirq.Y(cirq.q(0))], - True, - False, + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: [cirq.Y(cirq.q(0))] + c, - True, - False, + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert_into_range([cirq.Y(cirq.q(1)), cirq.M(cirq.q(1))], 0, 2), - True, - False, + 'insert', ), - (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.insert(1, []), False, False), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.insert(1, []), 'none'), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert_at_frontier([cirq.Y(cirq.q(0)), cirq.Y(cirq.q(1))], 1), - True, - False, + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.batch_replace([(0, cirq.X(cirq.q(0)), cirq.Y(cirq.q(0)))]), - False, - True, + 'update', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0), cirq.q(1))), lambda c: c.batch_insert_into([(0, cirq.X(cirq.q(1)))]), - True, - False, + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.batch_insert([(1, cirq.Y(cirq.q(0)))]), - True, - False, + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.clear_operations_touching([cirq.q(0)], [0]), - False, - True, + 'delete', ), - (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.copy(), False, False), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.copy(), 'none'), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.unfreeze(), 'none'), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.freeze().unfreeze(), 'freeze_cycle'), ], ) def test_mutation_clears_cached_attributes( - circuit: cirq.Circuit, - mutate: Callable[[cirq.Circuit], Any], - inserts: bool, - replaces_or_deletes: bool, + circuit: cirq.Circuit, mutate: Callable[[cirq.Circuit], Any], action: str ) -> None: standard_attributes = [ "_all_qubits", @@ -4761,7 +4743,11 @@ def test_mutation_clears_cached_attributes( for attr in all_attributes: # Standard attributes get cleared on any mutation (except `mul`), but placement cache only # gets cleared on replacements and deletes. - if replaces_or_deletes or (inserts and attr in standard_attributes): + if ( + (action in ['update', 'delete']) + or (action in ['insert', 'freeze_cycle'] and attr in standard_attributes) + or (action == 'mul' and attr == '_frozen') + ): assert getattr(circuit, attr) is None, f"{attr=} is not None" else: assert getattr(circuit, attr) is not None, f"{attr=} is None" @@ -4978,10 +4964,10 @@ def test_create_speed() -> None: @pytest.mark.parametrize( - 'contents', + 'create_circuit', [ - cirq.X(cirq.q('init')), - cirq.Moment(cirq.X(cirq.q('init'))), + lambda: cirq.Circuit(cirq.X(cirq.q('init'))), + lambda: cirq.Circuit.from_moments(cirq.Moment(cirq.X(cirq.q('init')))), ], ) @pytest.mark.parametrize( @@ -5004,7 +4990,7 @@ def test_create_speed() -> None: lambda c: c.batch_insert([(0, cirq.X(cirq.q('init')))]), ], ) -def test_append_speed(contents, mutate) -> None: +def test_append_speed(create_circuit, mutate) -> None: # Previously this took ~17s to run. Now it should take ~150ms. However the coverage test can # run this slowly, so allowing 5 sec to account for things like that. Feel free to increase the # buffer time or delete the test entirely if it ends up causing flakes. @@ -5014,7 +5000,7 @@ def test_append_speed(contents, mutate) -> None: qs = 2 moments = 10000 xs = [cirq.X(cirq.LineQubit(i)) for i in range(qs)] - c = cirq.Circuit(contents) + c = create_circuit() result = mutate(c) if isinstance(result, cirq.Circuit): # For functional "mutations" From eb905605ad6e72264f918006a0fb484cd7771ecc Mon Sep 17 00:00:00 2001 From: daxfo Date: Thu, 11 Sep 2025 19:22:10 -0700 Subject: [PATCH 08/11] docs --- cirq-core/cirq/circuits/circuit.py | 22 ++++++++++++++-------- cirq-core/cirq/circuits/circuit_test.py | 6 +++--- 2 files changed, 17 insertions(+), 11 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index 39d1038c0f4..44f9dbbf2ef 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1964,7 +1964,11 @@ def copy(self) -> Circuit: return copied_circuit def _copy_from_shallow(self, other: Circuit) -> None: - """Copies the contents of another circuit into this one.""" + """Copies the contents of another circuit into this one. + + This performs a shallow copy from another circuit. It is primarily intended for reimporting + data from temporary copies that were created during multistep mutations to allow them to be + performed atomically.""" self._moments[:] = other._moments self._tags = other.tags self._all_qubits = other._all_qubits @@ -2027,9 +2031,9 @@ def __imul__(self, repetitions: _INT_TYPE): num_moments_added = len(self._moments) * (repetitions - 1) self._moments *= int(repetitions) if self._placement_cache: + # Shift everything `num_moments_added` to the right. self._placement_cache.insert_moments(0, num_moments_added) - if self._frozen: - self._frozen = None + self._frozen = None # All other cache values are resilient to mul. return self def __mul__(self, repetitions: _INT_TYPE): @@ -2225,11 +2229,13 @@ def insert( max_p = 0 for moment_or_op in batch: # Determine Placement + cache_updated = False if self._placement_cache and appending: # This updates the cache and returns placement in a single step. It would be # cleaner to "check" placement here and avoid the special `skip_cache_update` # args below, but that adds about 15% latency to this perf-critical case. p = self._placement_cache.append(moment_or_op) + cache_updated = True elif isinstance(moment_or_op, Moment): p = k elif strategy in (InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE): @@ -2241,9 +2247,9 @@ def insert( p = self.earliest_available_moment(moment_or_op, end_moment_index=k) # Place if isinstance(moment_or_op, Moment): - self._insert_moments(p, moment_or_op, skip_cache_update=appending) + self._insert_moments(p, moment_or_op, skip_cache_update=cache_updated) else: - self._put_ops(p, moment_or_op, skip_cache_update=appending) + self._put_ops(p, moment_or_op, skip_cache_update=cache_updated) # Iterate max_p = max(p, max_p) if strategy is InsertStrategy.NEW_THEN_INLINE: @@ -2255,7 +2261,7 @@ def insert( def _insert_moments( self, index: int, *moments: Moment, count: int = 1, skip_cache_update: bool = False ): - """Inserts moments directly to a circuit at index k and updates caches. + """Inserts moments directly before circuit[index] and updates caches. Args: index: The moment index to insert the moment. If greater than the circuit length, the @@ -2278,10 +2284,10 @@ def _insert_moments( self._mutated(preserve_placement_cache=True) def _put_ops(self, index: int, *ops: cirq.Operation, skip_cache_update: bool = False): - """Adds operations directly to moment k and updates caches. + """Adds operations directly to circuit[index] and updates caches. This is intended to be low-level and will fail if the moment does not exist or already has - conflicting operations. + conflicting operations. Args: index: The moment index to add operations to. diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 719a8eb08ac..16ce4dea328 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4967,6 +4967,7 @@ def test_create_speed() -> None: 'create_circuit', [ lambda: cirq.Circuit(cirq.X(cirq.q('init'))), + lambda: cirq.Circuit(cirq.Moment(cirq.X(cirq.q('init')))), lambda: cirq.Circuit.from_moments(cirq.Moment(cirq.X(cirq.q('init')))), ], ) @@ -5006,9 +5007,8 @@ def test_append_speed(create_circuit, mutate) -> None: # For functional "mutations" c = result t = time.perf_counter() - # Iterating with the moments in the inner loop highlights the improvement: when filling in - # the second qubit, we no longer have to search backwards from moment 10000 for a placement - # index. + # Iterating with the moments in the inner loop highlights the improvement: when filling in the + # second qubit, we no longer have to search backwards from moment 10000 for a placement index. for q in range(qs): for _ in range(moments): c.append(xs[q]) From e46f3c5e888e1ee4a67590708e52ce10cbff012f Mon Sep 17 00:00:00 2001 From: daxfo Date: Thu, 11 Sep 2025 23:24:43 -0700 Subject: [PATCH 09/11] fix generic syntax --- cirq-core/cirq/circuits/circuit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index a1668f64fee..d1253e3e413 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -3143,12 +3143,12 @@ def put(self, index: int, *moments_or_operations: _MOMENT_OR_OP) -> None: self._length = max(self._length, index + len(moments_or_operations)) @staticmethod - def _put[T](key_indices: dict[T, int], mop_keys: Iterable[T], mop_index: int) -> None: + def _put(key_indices: dict[_TKey, int], mop_keys: Iterable[_TKey], mop_index: int) -> None: for key in mop_keys: key_indices[key] = max(mop_index, key_indices.get(key, -1)) @staticmethod - def _insert_moments[T](key_indices: dict[T, int], index: int, count: int) -> None: + def _insert_moments(key_indices: dict[_TKey, int], index: int, count: int) -> None: for key in key_indices: key_index = key_indices[key] if key_index >= index: From 59b5e03eedf933b1feee497c2213688930fd655c Mon Sep 17 00:00:00 2001 From: daxfo Date: Thu, 11 Sep 2025 23:42:28 -0700 Subject: [PATCH 10/11] Revert change that builds cache when creating Circuit from Moment lists, update tests accordingly. --- cirq-core/cirq/circuits/circuit.py | 9 ++++++--- cirq-core/cirq/circuits/circuit_test.py | 18 ++++-------------- 2 files changed, 10 insertions(+), 17 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index d1253e3e413..a3cb63b0e60 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1850,9 +1850,8 @@ def __init__( return flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents)) if all(isinstance(c, Moment) for c in flattened_contents): + self._placement_cache = None self._moments[:] = cast(Iterable[Moment], flattened_contents) - for i, moment in enumerate(self._moments): - self._placement_cache.put(i, moment) return with _compat.block_overlapping_deprecation('.*'): if strategy == InsertStrategy.EARLIEST: @@ -1872,7 +1871,11 @@ def _mutated(self, *, preserve_placement_cache=False) -> None: @classmethod def _from_moments(cls, moments: Iterable[cirq.Moment], tags: Sequence[Hashable]) -> Circuit: - return Circuit(moments, tags=tags) + new_circuit = Circuit() + new_circuit._moments[:] = moments + new_circuit._placement_cache = None + new_circuit._tags = tuple(tags) + return new_circuit def _load_contents_with_earliest_strategy(self, contents: cirq.OP_TREE): """Optimized algorithm to load contents quickly. diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 16ce4dea328..66031c97863 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -4744,8 +4744,8 @@ def test_mutation_clears_cached_attributes( # Standard attributes get cleared on any mutation (except `mul`), but placement cache only # gets cleared on replacements and deletes. if ( - (action in ['update', 'delete']) - or (action in ['insert', 'freeze_cycle'] and attr in standard_attributes) + (action in ['update', 'delete', 'freeze_cycle']) + or (action in ['insert'] and attr in standard_attributes) or (action == 'mul' and attr == '_frozen') ): assert getattr(circuit, attr) is None, f"{attr=} is not None" @@ -4963,14 +4963,6 @@ def test_create_speed() -> None: assert duration < 4 -@pytest.mark.parametrize( - 'create_circuit', - [ - lambda: cirq.Circuit(cirq.X(cirq.q('init'))), - lambda: cirq.Circuit(cirq.Moment(cirq.X(cirq.q('init')))), - lambda: cirq.Circuit.from_moments(cirq.Moment(cirq.X(cirq.q('init')))), - ], -) @pytest.mark.parametrize( 'mutate', [ @@ -4980,8 +4972,6 @@ def test_create_speed() -> None: lambda c: 2 * c, lambda c: 2 * (c.copy()), lambda c: (2 * c).copy(), - lambda c: (2 * c).freeze().unfreeze(), - lambda c: (2 * c.freeze()).unfreeze(), lambda c: c.__iadd__([cirq.X(cirq.q('init'))]), lambda c: c + [cirq.X(cirq.q('init'))], lambda c: [cirq.X(cirq.q('init'))] + c, @@ -4991,7 +4981,7 @@ def test_create_speed() -> None: lambda c: c.batch_insert([(0, cirq.X(cirq.q('init')))]), ], ) -def test_append_speed(create_circuit, mutate) -> None: +def test_append_speed(mutate) -> None: # Previously this took ~17s to run. Now it should take ~150ms. However the coverage test can # run this slowly, so allowing 5 sec to account for things like that. Feel free to increase the # buffer time or delete the test entirely if it ends up causing flakes. @@ -5001,7 +4991,7 @@ def test_append_speed(create_circuit, mutate) -> None: qs = 2 moments = 10000 xs = [cirq.X(cirq.LineQubit(i)) for i in range(qs)] - c = create_circuit() + c = cirq.Circuit(cirq.X(cirq.q('init'))) result = mutate(c) if isinstance(result, cirq.Circuit): # For functional "mutations" From d347a267738f200e6b0145c767701209070b8b8f Mon Sep 17 00:00:00 2001 From: daxfo Date: Sat, 13 Sep 2025 20:50:09 -0700 Subject: [PATCH 11/11] nits --- cirq-core/cirq/circuits/circuit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index a3cb63b0e60..cf12e4e575c 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -1978,7 +1978,7 @@ def _copy_from_shallow(self, other: Circuit) -> None: This performs a shallow copy from another circuit. It is primarily intended for reimporting data from temporary copies that were created during multistep mutations to allow them to be performed atomically.""" - self._moments[:] = other._moments + self._moments = other._moments self._tags = other.tags self._all_qubits = other._all_qubits self._frozen = other._frozen @@ -3143,7 +3143,7 @@ def put(self, index: int, *moments_or_operations: _MOMENT_OR_OP) -> None: self._put(self._qubit_indices, mop.qubits, index) self._put(self._mkey_indices, protocols.measurement_key_objs(mop), index) self._put(self._ckey_indices, protocols.control_keys(mop), index) - self._length = max(self._length, index + len(moments_or_operations)) + self._length = max(self._length, index + 1) @staticmethod def _put(key_indices: dict[_TKey, int], mop_keys: Iterable[_TKey], mop_index: int) -> None: