Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 136 additions & 40 deletions cirq-core/cirq/circuits/circuit.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,14 @@
from __future__ import annotations

import abc
import copy
import enum
import html
import itertools
import math
from collections import defaultdict
from types import NotImplementedType
from typing import (
AbstractSet,
Any,
Callable,
cast,
Expand Down Expand Up @@ -1341,10 +1341,10 @@ def _is_parameterized_(self) -> bool:
protocols.is_parameterized(tag) for tag in self.tags
)

def _parameter_names_(self) -> AbstractSet[str]:
def _parameter_names_(self) -> frozenset[str]:
op_params = {name for op in self.all_operations() for name in protocols.parameter_names(op)}
tag_params = {name for tag in self.tags for name in protocols.parameter_names(tag)}
return op_params | tag_params
return frozenset(op_params | tag_params)

def _resolve_parameters_(self, resolver: cirq.ParamResolver, recursive: bool) -> Self:
changed = False
Expand Down Expand Up @@ -1845,7 +1845,7 @@ def __init__(
self._frozen: cirq.FrozenCircuit | None = None
self._is_measurement: bool | None = None
self._is_parameterized: bool | None = None
self._parameter_names: AbstractSet[str] | None = None
self._parameter_names: frozenset[str] | None = None
if not contents:
return
flattened_contents = tuple(ops.flatten_to_ops_or_moments(contents))
Expand Down Expand Up @@ -1954,7 +1954,7 @@ def _is_parameterized_(self) -> bool:
self._is_parameterized = super()._is_parameterized_()
return self._is_parameterized

def _parameter_names_(self) -> AbstractSet[str]:
def _parameter_names_(self) -> frozenset[str]:
if self._parameter_names is None:
self._parameter_names = super()._parameter_names_()
return self._parameter_names
Expand All @@ -1963,10 +1963,30 @@ def copy(self) -> Circuit:
"""Return a copy of this circuit."""
copied_circuit = Circuit()
copied_circuit._moments[:] = self._moments
copied_circuit._placement_cache = None
copied_circuit._tags = self.tags
copied_circuit._all_qubits = self._all_qubits
copied_circuit._frozen = self._frozen
copied_circuit._is_measurement = self._is_measurement
copied_circuit._is_parameterized = self._is_parameterized
copied_circuit._parameter_names = self._parameter_names
copied_circuit._placement_cache = copy.copy(self._placement_cache)
return copied_circuit

def _copy_from_shallow(self, other: Circuit) -> None:
"""Copies the contents of another circuit into this one.

This performs a shallow copy from another circuit. It is primarily intended for reimporting
data from temporary copies that were created during multistep mutations to allow them to be
performed atomically."""
self._moments = other._moments
self._tags = other.tags
self._all_qubits = other._all_qubits
self._frozen = other._frozen
self._is_measurement = other._is_measurement
self._is_parameterized = other._is_parameterized
self._parameter_names = other._parameter_names
self._placement_cache = other._placement_cache

@overload
def __setitem__(self, key: int, value: cirq.Moment):
pass
Expand Down Expand Up @@ -2008,7 +2028,7 @@ def __radd__(self, other):
return NotImplemented
# Auto wrap OP_TREE inputs into a circuit.
result = self.copy()
result._moments[:0] = Circuit(other)._moments
result._insert_moments(0, *Circuit(other)._moments)
return result

# Needed for numpy to handle multiplication by np.int64 correctly.
Expand All @@ -2017,19 +2037,19 @@ def __radd__(self, other):
def __imul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
num_moments_added = len(self._moments) * (repetitions - 1)
self._moments *= int(repetitions)
self._mutated()
if self._placement_cache:
# Shift everything `num_moments_added` to the right.
self._placement_cache.insert_moments(0, num_moments_added)
self._frozen = None # All other cache values are resilient to mul.
return self

def __mul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
return Circuit(self._moments * int(repetitions), tags=self.tags)
return self.copy().__imul__(repetitions)

def __rmul__(self, repetitions: _INT_TYPE):
if not isinstance(repetitions, (int, np.integer)):
return NotImplemented
return self * int(repetitions)
return self.copy().__imul__(repetitions)

def __pow__(self, exponent: int) -> cirq.Circuit:
"""A circuit raised to a power, only valid for exponent -1, the inverse.
Expand Down Expand Up @@ -2192,10 +2212,9 @@ def insert(
"""
# limit index to 0..len(self._moments), also deal with indices smaller 0
k = max(min(index if index >= 0 else len(self._moments) + index, len(self._moments)), 0)
if strategy != InsertStrategy.EARLIEST or k != len(self._moments):
self._placement_cache = None
appending = strategy == InsertStrategy.EARLIEST and k == len(self._moments)
mops = list(ops.flatten_to_ops_or_moments(moment_or_operation_tree))
if self._placement_cache:
if self._placement_cache and appending:
batches = [mops] # Any grouping would work here; this just happens to be the fastest.
elif strategy is InsertStrategy.NEW:
batches = [[mop] for mop in mops] # Each op goes into its own moment.
Expand All @@ -2204,7 +2223,7 @@ def insert(
for batch in batches:
# Insert a moment if inline/earliest and _any_ op in the batch requires it.
if (
not self._placement_cache
not appending
and not isinstance(batch[0], Moment)
and strategy in (InsertStrategy.INLINE, InsertStrategy.EARLIEST)
and not all(
Expand All @@ -2213,39 +2232,86 @@ def insert(
for op in cast(list[cirq.Operation], batch)
)
):
self._moments.insert(k, Moment())
self._insert_moments(k)
if strategy is InsertStrategy.INLINE:
k += 1
max_p = 0
for moment_or_op in batch:
# Determine Placement
if self._placement_cache:
cache_updated = False
if self._placement_cache and appending:
# This updates the cache and returns placement in a single step. It would be
# cleaner to "check" placement here and avoid the special `skip_cache_update`
# args below, but that adds about 15% latency to this perf-critical case.
p = self._placement_cache.append(moment_or_op)
cache_updated = True
elif isinstance(moment_or_op, Moment):
p = k
elif strategy in (InsertStrategy.NEW, InsertStrategy.NEW_THEN_INLINE):
self._moments.insert(k, Moment())
self._insert_moments(k)
p = k
elif strategy is InsertStrategy.INLINE:
p = k - 1
else: # InsertStrategy.EARLIEST:
p = self.earliest_available_moment(moment_or_op, end_moment_index=k)
# Place
if isinstance(moment_or_op, Moment):
self._moments.insert(p, moment_or_op)
elif p == len(self._moments):
self._moments.append(Moment(moment_or_op))
self._insert_moments(p, moment_or_op, skip_cache_update=cache_updated)
else:
self._moments[p] = self._moments[p].with_operation(moment_or_op)
self._put_ops(p, moment_or_op, skip_cache_update=cache_updated)
# Iterate
max_p = max(p, max_p)
if strategy is InsertStrategy.NEW_THEN_INLINE:
strategy = InsertStrategy.INLINE
k += 1
k = max(k, max_p + 1)
self._mutated(preserve_placement_cache=True)
return k

def _insert_moments(
self, index: int, *moments: Moment, count: int = 1, skip_cache_update: bool = False
):
"""Inserts moments directly before circuit[index] and updates caches.

Args:
index: The moment index to insert the moment. If greater than the circuit length, the
moments will be appended.
moments: The moments to insert. If none are provided, a single empty moment will be
assumed.
count: The number of moments to insert. If both `moments` and `count` are provided,
the provided moments will be inserted `count` times.
skip_cache_update: Skips updates to the placement cache. Only use if the placement cache
has already been updated.
"""
if not moments:
moments = (Moment(),)
moments *= count
self._moments[index:index] = moments
if self._placement_cache and not skip_cache_update:
self._placement_cache.insert_moments(index, len(moments))
for i, m in enumerate(moments):
self._placement_cache.put(index + i, m)
self._mutated(preserve_placement_cache=True)

def _put_ops(self, index: int, *ops: cirq.Operation, skip_cache_update: bool = False):
"""Adds operations directly to circuit[index] and updates caches.

This is intended to be low-level and will fail if the moment does not exist or already has
conflicting operations.

Args:
index: The moment index to add operations to.
ops: The operations to add.
skip_cache_update: Skips updates to the placement cache. Only use if the placement cache
has already been updated.
"""
if index == len(self._moments):
self._moments.append(Moment.from_ops(*ops))
else:
self._moments[index] = self._moments[index].with_operations(*ops)
if self._placement_cache and not skip_cache_update:
self._placement_cache.put(index, *ops)
self._mutated(preserve_placement_cache=True)

def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> int:
"""Writes operations inline into an area of the circuit.

Expand Down Expand Up @@ -2278,9 +2344,8 @@ def insert_into_range(self, operations: cirq.OP_TREE, start: int, end: int) -> i
if i >= end:
break

self._moments[i] = self._moments[i].with_operation(op)
self._put_ops(i, op)
op_index += 1
self._mutated()

if op_index >= len(flat_ops):
return end
Expand Down Expand Up @@ -2325,8 +2390,7 @@ def _push_frontier(
)
if n_new_moments > 0:
insert_index = min(late_frontier.values())
self._moments[insert_index:insert_index] = [Moment()] * n_new_moments
self._mutated()
self._insert_moments(insert_index, count=n_new_moments)
for q in update_qubits:
if early_frontier.get(q, 0) > insert_index:
early_frontier[q] += n_new_moments
Expand All @@ -2352,13 +2416,12 @@ def _insert_operations(
"""
if len(operations) != len(insertion_indices):
raise ValueError('operations and insertion_indices must have the same length.')
self._moments += [Moment() for _ in range(1 + max(insertion_indices) - len(self))]
self._mutated()
self._insert_moments(len(self), count=1 + max(insertion_indices) - len(self))
moment_to_ops: dict[int, list[cirq.Operation]] = defaultdict(list)
for op_index, moment_index in enumerate(insertion_indices):
moment_to_ops[moment_index].append(operations[op_index])
for moment_index, new_ops in moment_to_ops.items():
self._moments[moment_index] = self._moments[moment_index].with_operations(*new_ops)
self._put_ops(moment_index, *new_ops)

def insert_at_frontier(
self, operations: cirq.OP_TREE, start: int, frontier: dict[cirq.Qid, int] | None = None
Expand Down Expand Up @@ -2461,9 +2524,8 @@ def batch_insert_into(self, insert_intos: Iterable[tuple[int, cirq.OP_TREE]]) ->
"""
copy = self.copy()
for i, insertions in insert_intos:
copy._moments[i] = copy._moments[i].with_operations(insertions)
self._moments = copy._moments
self._mutated()
copy._put_ops(i, *ops.flatten_to_ops(insertions))
self._copy_from_shallow(copy)

def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None:
"""Applies a batched insert operation to the circuit.
Expand Down Expand Up @@ -2497,8 +2559,7 @@ def batch_insert(self, insertions: Iterable[tuple[int, cirq.OP_TREE]]) -> None:
next_index = copy.insert(insert_index, reversed(group), InsertStrategy.EARLIEST)
if next_index > insert_index:
shift += next_index - insert_index
self._moments = copy._moments
self._mutated()
self._copy_from_shallow(copy)

def append(
self,
Expand Down Expand Up @@ -2543,8 +2604,8 @@ def with_tags(self, *new_tags: Hashable) -> cirq.Circuit:
"""Creates a new tagged `Circuit` with `self.tags` and `new_tags` combined."""
if not new_tags:
return self
new_circuit = Circuit(tags=self.tags + new_tags)
new_circuit._moments[:] = self._moments
new_circuit = self.copy()
new_circuit._tags = self.tags + new_tags
return new_circuit

def with_noise(self, noise: cirq.NOISE_MODEL_LIKE) -> cirq.Circuit:
Expand Down Expand Up @@ -3068,3 +3129,38 @@ def append(self, moment_or_operation: _MOMENT_OR_OP) -> int:
)
self._length = max(self._length, index + 1)
return index

def insert_moments(self, index: int, count: int = 1) -> None:
"""Updates cache to account for empty moments inserted at circuit[index]."""
self._insert_moments(self._qubit_indices, index, count)
self._insert_moments(self._mkey_indices, index, count)
self._insert_moments(self._ckey_indices, index, count)
self._length += count

def put(self, index: int, *moments_or_operations: _MOMENT_OR_OP) -> None:
"""Updates cache to account for ops added to circuit[index]."""
for mop in moments_or_operations:
self._put(self._qubit_indices, mop.qubits, index)
self._put(self._mkey_indices, protocols.measurement_key_objs(mop), index)
self._put(self._ckey_indices, protocols.control_keys(mop), index)
self._length = max(self._length, index + 1)

@staticmethod
def _put(key_indices: dict[_TKey, int], mop_keys: Iterable[_TKey], mop_index: int) -> None:
for key in mop_keys:
key_indices[key] = max(mop_index, key_indices.get(key, -1))

@staticmethod
def _insert_moments(key_indices: dict[_TKey, int], index: int, count: int) -> None:
for key in key_indices:
key_index = key_indices[key]
if key_index >= index:
key_indices[key] = key_index + count

def __copy__(self) -> _PlacementCache:
cache = _PlacementCache()
cache._qubit_indices = self._qubit_indices.copy()
cache._mkey_indices = self._mkey_indices.copy()
cache._ckey_indices = self._ckey_indices.copy()
cache._length = self._length
return cache
Loading