Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion cirq-core/cirq/ops/gateset.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def __init__(
self._unroll_circuit_op = unroll_circuit_op
self._instance_gate_families: dict[raw_types.Gate, GateFamily] = {}
self._type_gate_families: dict[type[raw_types.Gate], GateFamily] = {}
self._gate_families_with_tags: list[GateFamily] = []
self._gates_repr_str = ", ".join([_gate_str(g, repr) for g in gates])
unique_gate_list: list[GateFamily] = list(
dict.fromkeys(g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates)
Expand All @@ -351,6 +352,12 @@ def __init__(
self._instance_gate_families[g.gate] = g
else:
self._type_gate_families[g.gate] = g
else:
if isinstance(g.gate, raw_types.Gate):
self._gate_families_with_tags.append(g)
else:
# Instance checks are faster, so test them first.
self._gate_families_with_tags.insert(0, g)
self._unique_gate_list = unique_gate_list
self._gates = frozenset(unique_gate_list)

Expand Down Expand Up @@ -422,6 +429,7 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool:
g = item if isinstance(item, raw_types.Gate) else item.gate
assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'

# Check "type" based GateFamily since isinstance is fast
for gate_mro_type in type(g).mro():
if gate_mro_type in self._type_gate_families:
assert item in self._type_gate_families[gate_mro_type], (
Expand All @@ -430,14 +438,25 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool:
)
return True

# Check exact instance equality next
if g in self._instance_gate_families:
assert item in self._instance_gate_families[g], (
f"{item} instance matches {self._instance_gate_families[g]} but "
f"is not accepted by it."
)
return True

return any(item in gate_family for gate_family in self._gates)
# Check other GateFamilies next
if any(item in gate_family for gate_family in self._gate_families_with_tags):
return True

# Lastly, do a final exhaustive check to make sure this is not equivalent
# to another type of gate. This will catch things like:
# cirq.XPowGate(exponent=0) in cirq.GateFamily(cirq.I)
return any(
item in gate_family
for gate_family in self._gates.difference(self._gate_families_with_tags)
)

def validate(self, circuit_or_optree: cirq.AbstractCircuit | op_tree.OP_TREE) -> bool:
"""Validates gates forming `circuit_or_optree` should be contained in Gateset.
Expand Down
Loading