diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index a4eb0227ee1..cf12e4e575c 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -22,6 +22,7 @@ from __future__ import annotations import abc +import copy import enum import html import itertools @@ -29,7 +30,6 @@ from collections import defaultdict from types import NotImplementedType from typing import ( - AbstractSet, Any, Callable, cast, @@ -1341,10 +1341,10 @@ def _is_parameterized_(self) -> bool: protocols.is_parameterized(tag) for tag in self.tags ) - def _parameter_names_(self) -> AbstractSet[str]: + def _parameter_names_(self) -> frozenset[str]: op_params = {name for op in self.all_operations() for name in protocols.parameter_names(op)} tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)} - return op_params | tag_params + return frozenset(op_params | tag_params) def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> Self: changed = False @@ -1845,7 +1845,7 @@ def __init__( self._frozen: cirq.FrozenCircuit | None = None self._is_measurement: bool | None = None self._is_parameterized: bool | None = None - self._parameter_names: AbstractSet[str] | None = None + self._parameter_names: frozenset[str] | None = None if not contents: return flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents)) @@ -1954,7 +1954,7 @@ def _is_parameterized_(self) -> bool: self._is_parameterized = super()._is_parameterized_() return self._is_parameterized - def _parameter_names_(self) -> AbstractSet[str]: + def _parameter_names_(self) -> frozenset[str]: if self._parameter_names is None: self._parameter_names = super()._parameter_names_() return self._parameter_names @@ -1963,10 +1963,30 @@ def copy(self) -> Circuit: """Return a copy of this circuit.""" copied_circuit = Circuit() copied_circuit._moments[:] = self._moments - copied_circuit._placement_cache = None copied_circuit._tags = self.tags + copied_circuit._all_qubits = self._all_qubits + copied_circuit._frozen = self._frozen + copied_circuit._is_measurement = self._is_measurement + copied_circuit._is_parameterized = self._is_parameterized + copied_circuit._parameter_names = self._parameter_names + copied_circuit._placement_cache = copy.copy(self._placement_cache) return copied_circuit + def _copy_from_shallow(self, other: Circuit) -> None: + """Copies the contents of another circuit into this one. + + This performs a shallow copy from another circuit. It is primarily intended for reimporting + data from temporary copies that were created during multistep mutations to allow them to be + performed atomically.""" + self._moments = other._moments + self._tags = other.tags + self._all_qubits = other._all_qubits + self._frozen = other._frozen + self._is_measurement = other._is_measurement + self._is_parameterized = other._is_parameterized + self._parameter_names = other._parameter_names + self._placement_cache = other._placement_cache + @overload def __setitem__(self, key: int, value: cirq.Moment): pass @@ -2008,7 +2028,7 @@ def __radd__(self, other): return NotImplemented # Auto wrap OP_TREE inputs into a circuit. result = self.copy() - result._moments[:0] = Circuit(other)._moments + result._insert_moments(0, *Circuit(other)._moments) return result # Needed for numpy to handle multiplication by np.int64 correctly. @@ -2017,19 +2037,19 @@ def __radd__(self, other): def __imul__(self, repetitions: _INT_TYPE): if not isinstance(repetitions, (int, np.integer)): return NotImplemented + num_moments_added = len(self._moments) * (repetitions - 1) self._moments *= int(repetitions) - self._mutated() + if self._placement_cache: + # Shift everything `num_moments_added` to the right. + self._placement_cache.insert_moments(0, num_moments_added) + self._frozen = None # All other cache values are resilient to mul. return self def __mul__(self, repetitions: _INT_TYPE): - if not isinstance(repetitions, (int, np.integer)): - return NotImplemented - return Circuit(self._moments * int(repetitions), tags=self.tags) + return self.copy().__imul__(repetitions) def __rmul__(self, repetitions: _INT_TYPE): - if not isinstance(repetitions, (int, np.integer)): - return NotImplemented - return self * int(repetitions) + return self.copy().__imul__(repetitions) def __pow__(self, exponent: int) -> cirq.Circuit: """A circuit raised to a power, only valid for exponent -1, the inverse. @@ -2192,10 +2212,9 @@ def insert( """ # limit index to 0..len(self._moments), also deal with indices smaller 0 k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0) - if strategy != InsertStrategy.EARLIEST or k != len(self._moments): - self._placement_cache = None + appending = strategy == InsertStrategy.EARLIEST and k == len(self._moments) mops = list(ops.flatten_to_ops_or_moments(moment_or_operation_tree)) - if self._placement_cache: + if self._placement_cache and appending: batches = [mops] # Any grouping would work here; this just happens to be the fastest. elif strategy is InsertStrategy.NEW: batches = [[mop] for mop in mops] # Each op goes into its own moment. @@ -2204,7 +2223,7 @@ def insert( for batch in batches: # Insert a moment if inline/earliest and _any_ op in the batch requires it. if ( - not self._placement_cache + not appending and not isinstance(batch[0], Moment) and strategy in (InsertStrategy.INLINE, InsertStrategy.EARLIEST) and not all( @@ -2213,18 +2232,23 @@ def insert( for op in cast(list[cirq.Operation], batch) ) ): - self._moments.insert(k, Moment()) + self._insert_moments(k) if strategy is InsertStrategy.INLINE: k += 1 max_p = 0 for moment_or_op in batch: # Determine Placement - if self._placement_cache: + cache_updated = False + if self._placement_cache and appending: + # This updates the cache and returns placement in a single step. It would be + # cleaner to "check" placement here and avoid the special `skip_cache_update` + # args below, but that adds about 15% latency to this perf-critical case. p = self._placement_cache.append(moment_or_op) + cache_updated = True elif isinstance(moment_or_op, Moment): p = k elif strategy in (InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE): - self._moments.insert(k, Moment()) + self._insert_moments(k) p = k elif strategy is InsertStrategy.INLINE: p = k - 1 @@ -2232,20 +2256,62 @@ def insert( p = self.earliest_available_moment(moment_or_op, end_moment_index=k) # Place if isinstance(moment_or_op, Moment): - self._moments.insert(p, moment_or_op) - elif p == len(self._moments): - self._moments.append(Moment(moment_or_op)) + self._insert_moments(p, moment_or_op, skip_cache_update=cache_updated) else: - self._moments[p] = self._moments[p].with_operation(moment_or_op) + self._put_ops(p, moment_or_op, skip_cache_update=cache_updated) # Iterate max_p = max(p, max_p) if strategy is InsertStrategy.NEW_THEN_INLINE: strategy = InsertStrategy.INLINE k += 1 k = max(k, max_p + 1) - self._mutated(preserve_placement_cache=True) return k + def _insert_moments( + self, index: int, *moments: Moment, count: int = 1, skip_cache_update: bool = False + ): + """Inserts moments directly before circuit[index] and updates caches. + + Args: + index: The moment index to insert the moment. If greater than the circuit length, the + moments will be appended. + moments: The moments to insert. If none are provided, a single empty moment will be + assumed. + count: The number of moments to insert. If both `moments` and `count` are provided, + the provided moments will be inserted `count` times. + skip_cache_update: Skips updates to the placement cache. Only use if the placement cache + has already been updated. + """ + if not moments: + moments = (Moment(),) + moments *= count + self._moments[index:index] = moments + if self._placement_cache and not skip_cache_update: + self._placement_cache.insert_moments(index, len(moments)) + for i, m in enumerate(moments): + self._placement_cache.put(index + i, m) + self._mutated(preserve_placement_cache=True) + + def _put_ops(self, index: int, *ops: cirq.Operation, skip_cache_update: bool = False): + """Adds operations directly to circuit[index] and updates caches. + + This is intended to be low-level and will fail if the moment does not exist or already has + conflicting operations. + + Args: + index: The moment index to add operations to. + ops: The operations to add. + skip_cache_update: Skips updates to the placement cache. Only use if the placement cache + has already been updated. + """ + if index == len(self._moments): + self._moments.append(Moment.from_ops(*ops)) + else: + self._moments[index] = self._moments[index].with_operations(*ops) + if self._placement_cache and not skip_cache_update: + self._placement_cache.put(index, *ops) + self._mutated(preserve_placement_cache=True) + def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> int: """Writes operations inline into an area of the circuit. @@ -2278,9 +2344,8 @@ def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> i if i >= end: break - self._moments[i] = self._moments[i].with_operation(op) + self._put_ops(i, op) op_index += 1 - self._mutated() if op_index >= len(flat_ops): return end @@ -2325,8 +2390,7 @@ def _push_frontier( ) if n_new_moments > 0: insert_index = min(late_frontier.values()) - self._moments[insert_index:insert_index] = [Moment()] * n_new_moments - self._mutated() + self._insert_moments(insert_index, count=n_new_moments) for q in update_qubits: if early_frontier.get(q, 0) > insert_index: early_frontier[q] += n_new_moments @@ -2352,13 +2416,12 @@ def _insert_operations( """ if len(operations) != len(insertion_indices): raise ValueError('operations and insertion_indices must have the same length.') - self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))] - self._mutated() + self._insert_moments(len(self), count=1 + max(insertion_indices) - len(self)) moment_to_ops: dict[int, list[cirq.Operation]] = defaultdict(list) for op_index, moment_index in enumerate(insertion_indices): moment_to_ops[moment_index].append(operations[op_index]) for moment_index, new_ops in moment_to_ops.items(): - self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops) + self._put_ops(moment_index, *new_ops) def insert_at_frontier( self, operations: cirq.OP_TREE, start: int, frontier: dict[cirq.Qid, int] | None = None @@ -2461,9 +2524,8 @@ def batch_insert_into(self, insert_intos: Iterable[tuple[int, cirq.OP_TREE]]) -> """ copy = self.copy() for i, insertions in insert_intos: - copy._moments[i] = copy._moments[i].with_operations(insertions) - self._moments = copy._moments - self._mutated() + copy._put_ops(i, *ops.flatten_to_ops(insertions)) + self._copy_from_shallow(copy) def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: """Applies a batched insert operation to the circuit. @@ -2497,8 +2559,7 @@ def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None: next_index = copy.insert(insert_index, reversed(group), InsertStrategy.EARLIEST) if next_index > insert_index: shift += next_index - insert_index - self._moments = copy._moments - self._mutated() + self._copy_from_shallow(copy) def append( self, @@ -2543,8 +2604,8 @@ def with_tags(self, *new_tags: Hashable) -> cirq.Circuit: """Creates a new tagged `Circuit` with `self.tags` and `new_tags` combined.""" if not new_tags: return self - new_circuit = Circuit(tags=self.tags + new_tags) - new_circuit._moments[:] = self._moments + new_circuit = self.copy() + new_circuit._tags = self.tags + new_tags return new_circuit def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit: @@ -3068,3 +3129,38 @@ def append(self, moment_or_operation: _MOMENT_OR_OP) -> int: ) self._length = max(self._length, index + 1) return index + + def insert_moments(self, index: int, count: int = 1) -> None: + """Updates cache to account for empty moments inserted at circuit[index].""" + self._insert_moments(self._qubit_indices, index, count) + self._insert_moments(self._mkey_indices, index, count) + self._insert_moments(self._ckey_indices, index, count) + self._length += count + + def put(self, index: int, *moments_or_operations: _MOMENT_OR_OP) -> None: + """Updates cache to account for ops added to circuit[index].""" + for mop in moments_or_operations: + self._put(self._qubit_indices, mop.qubits, index) + self._put(self._mkey_indices, protocols.measurement_key_objs(mop), index) + self._put(self._ckey_indices, protocols.control_keys(mop), index) + self._length = max(self._length, index + 1) + + @staticmethod + def _put(key_indices: dict[_TKey, int], mop_keys: Iterable[_TKey], mop_index: int) -> None: + for key in mop_keys: + key_indices[key] = max(mop_index, key_indices.get(key, -1)) + + @staticmethod + def _insert_moments(key_indices: dict[_TKey, int], index: int, count: int) -> None: + for key in key_indices: + key_index = key_indices[key] + if key_index >= index: + key_indices[key] = key_index + count + + def __copy__(self) -> _PlacementCache: + cache = _PlacementCache() + cache._qubit_indices = self._qubit_indices.copy() + cache._mkey_indices = self._mkey_indices.copy() + cache._ckey_indices = self._ckey_indices.copy() + cache._length = self._length + return cache diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index cd8f7b00c70..66031c97863 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -19,7 +19,7 @@ import time from collections import defaultdict from random import randint, random, randrange, sample -from typing import Iterator, Sequence +from typing import Any, Callable, Iterator, Sequence import numpy as np import pytest @@ -4634,55 +4634,88 @@ def test_freeze_is_cached() -> None: @pytest.mark.parametrize( - "circuit, mutate", + "circuit, mutate, action", [ ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__setitem__(0, cirq.Moment(cirq.Y(cirq.q(0)))), + 'update', ), - (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__delitem__(0)), - (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2)), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__delitem__(0), 'delete'), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.__imul__(2), 'mul'), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c * 2, 'mul'), + (cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: 2 * c, 'mul'), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert(1, cirq.Y(cirq.q(0))), + 'insert', + ), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), + lambda c: c.__iadd__([cirq.Y(cirq.q(0))]), + 'insert', + ), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), + lambda c: c + [cirq.Y(cirq.q(0))], + 'insert', + ), + ( + cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), + lambda c: [cirq.Y(cirq.q(0))] + c, + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert_into_range([cirq.Y(cirq.q(1)), cirq.M(cirq.q(1))], 0, 2), + 'insert', ), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.insert(1, []), 'none'), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.insert_at_frontier([cirq.Y(cirq.q(0)), cirq.Y(cirq.q(1))], 1), + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.batch_replace([(0, cirq.X(cirq.q(0)), cirq.Y(cirq.q(0)))]), + 'update', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0), cirq.q(1))), lambda c: c.batch_insert_into([(0, cirq.X(cirq.q(1)))]), + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.batch_insert([(1, cirq.Y(cirq.q(0)))]), + 'insert', ), ( cirq.Circuit(cirq.X(cirq.q(0)), cirq.M(cirq.q(0))), lambda c: c.clear_operations_touching([cirq.q(0)], [0]), + 'delete', ), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.copy(), 'none'), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.unfreeze(), 'none'), + (cirq.Circuit(cirq.X(cirq.q(0))), lambda c: c.freeze().unfreeze(), 'freeze_cycle'), ], ) -def test_mutation_clears_cached_attributes(circuit, mutate) -> None: - cached_attributes = [ +def test_mutation_clears_cached_attributes( + circuit: cirq.Circuit, mutate: Callable[[cirq.Circuit], Any], action: str +) -> None: + standard_attributes = [ "_all_qubits", "_frozen", "_is_measurement", "_is_parameterized", "_parameter_names", ] + all_attributes = standard_attributes + ['_placement_cache'] - for attr in cached_attributes: + for attr in standard_attributes: assert getattr(circuit, attr) is None, f"{attr=} is not None" + assert circuit._placement_cache is not None # Check that attributes are cached after getting them. qubits = circuit.all_qubits() @@ -4691,7 +4724,7 @@ def test_mutation_clears_cached_attributes(circuit, mutate) -> None: is_parameterized = cirq.is_parameterized(circuit) parameter_names = cirq.parameter_names(circuit) - for attr in cached_attributes: + for attr in all_attributes: assert getattr(circuit, attr) is not None, f"{attr=} is None" # Check that getting again returns same object. @@ -4702,9 +4735,22 @@ def test_mutation_clears_cached_attributes(circuit, mutate) -> None: assert cirq.parameter_names(circuit) is parameter_names # Check that attributes are cleared after mutation. - mutate(circuit) - for attr in cached_attributes: - assert getattr(circuit, attr) is None, f"{attr=} is not None" + result = mutate(circuit) + if isinstance(result, cirq.Circuit): + # For functional "mutations" + circuit = result + + for attr in all_attributes: + # Standard attributes get cleared on any mutation (except `mul`), but placement cache only + # gets cleared on replacements and deletes. + if ( + (action in ['update', 'delete', 'freeze_cycle']) + or (action in ['insert'] and attr in standard_attributes) + or (action == 'mul' and attr == '_frozen') + ): + assert getattr(circuit, attr) is None, f"{attr=} is not None" + else: + assert getattr(circuit, attr) is not None, f"{attr=} is None" def test_factorize_one_factor() -> None: @@ -4917,7 +4963,25 @@ def test_create_speed() -> None: assert duration < 4 -def test_append_speed() -> None: +@pytest.mark.parametrize( + 'mutate', + [ + lambda c: c.insert(0, cirq.X(cirq.q('init'))), + lambda c: c.__imul__(2), + lambda c: c * 2, + lambda c: 2 * c, + lambda c: 2 * (c.copy()), + lambda c: (2 * c).copy(), + lambda c: c.__iadd__([cirq.X(cirq.q('init'))]), + lambda c: c + [cirq.X(cirq.q('init'))], + lambda c: [cirq.X(cirq.q('init'))] + c, + lambda c: c.insert_into_range([cirq.X(cirq.q('init'))], 0, 1), + lambda c: c.insert_at_frontier([cirq.X(cirq.q('init'))], 0), + lambda c: c.batch_insert_into([(1, cirq.X(cirq.q('init')))]), + lambda c: c.batch_insert([(0, cirq.X(cirq.q('init')))]), + ], +) +def test_append_speed(mutate) -> None: # Previously this took ~17s to run. Now it should take ~150ms. However the coverage test can # run this slowly, so allowing 5 sec to account for things like that. Feel free to increase the # buffer time or delete the test entirely if it ends up causing flakes. @@ -4927,7 +4991,11 @@ def test_append_speed() -> None: qs = 2 moments = 10000 xs = [cirq.X(cirq.LineQubit(i)) for i in range(qs)] - c = cirq.Circuit() + c = cirq.Circuit(cirq.X(cirq.q('init'))) + result = mutate(c) + if isinstance(result, cirq.Circuit): + # For functional "mutations" + c = result t = time.perf_counter() # Iterating with the moments in the inner loop highlights the improvement: when filling in the # second qubit, we no longer have to search backwards from moment 10000 for a placement index. diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index 2aa8d401bda..b4b03d51462 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -18,7 +18,7 @@ from functools import cached_property from types import NotImplementedType -from typing import AbstractSet, Hashable, Iterable, Iterator, Sequence, TYPE_CHECKING +from typing import Hashable, Iterable, Iterator, Sequence, TYPE_CHECKING from cirq import _compat, protocols from cirq.circuits import AbstractCircuit, Alignment, Circuit @@ -172,7 +172,7 @@ def _is_parameterized_(self) -> bool: return super()._is_parameterized_() @_compat.cached_method - def _parameter_names_(self) -> AbstractSet[str]: + def _parameter_names_(self) -> frozenset[str]: return super()._parameter_names_() def _measurement_key_names_(self) -> frozenset[str]: diff --git a/cirq-core/cirq/circuits/frozen_circuit_test.py b/cirq-core/cirq/circuits/frozen_circuit_test.py index fda9c1fe1be..12528d88d7b 100644 --- a/cirq-core/cirq/circuits/frozen_circuit_test.py +++ b/cirq-core/cirq/circuits/frozen_circuit_test.py @@ -76,9 +76,16 @@ def test_freeze_and_unfreeze() -> None: cc = c.unfreeze() assert cc is not c + # Refreezing without modification returns original FrozenCircuit. fcc = cc.freeze() assert fcc.moments == f.moments - assert fcc is not f + assert fcc is f + + # Modifying and refreezing returns new FrozenCircuit. + cc.append(cirq.X(a)) + fcc2 = cc.freeze() + assert tuple(cc.moments) == fcc2.moments + assert fcc2 is not f def test_immutable() -> None: diff --git a/cirq-core/cirq/circuits/moment.py b/cirq-core/cirq/circuits/moment.py index 414e2fab0f2..051fc1ecce8 100644 --- a/cirq-core/cirq/circuits/moment.py +++ b/cirq-core/cirq/circuits/moment.py @@ -256,6 +256,9 @@ def with_operations(self, *contents: cirq.OP_TREE) -> cirq.Moment: Raises: ValueError: If the contents given overlaps a current operation in the moment. """ + if len(contents) == 1 and isinstance(contents[0], ops.Operation): + return self.with_operation(contents[0]) + flattened_contents = tuple(op_tree.flatten_to_ops(contents)) if not flattened_contents: