diff --git a/cirq-core/cirq/ops/gateset.py b/cirq-core/cirq/ops/gateset.py index a591ac4d8a2..1f299e8fc03 100644 --- a/cirq-core/cirq/ops/gateset.py +++ b/cirq-core/cirq/ops/gateset.py @@ -340,6 +340,7 @@ def __init__( self._unroll_circuit_op = unroll_circuit_op self._instance_gate_families: dict[raw_types.Gate, GateFamily] = {} self._type_gate_families: dict[type[raw_types.Gate], GateFamily] = {} + self._gate_families_with_tags: list[GateFamily] = [] self._gates_repr_str = ", ".join([_gate_str(g, repr) for g in gates]) unique_gate_list: list[GateFamily] = list( dict.fromkeys(g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates) @@ -351,6 +352,12 @@ def __init__( self._instance_gate_families[g.gate] = g else: self._type_gate_families[g.gate] = g + else: + if isinstance(g.gate, raw_types.Gate): + self._gate_families_with_tags.append(g) + else: + # Instance checks are faster, so test them first. + self._gate_families_with_tags.insert(0, g) self._unique_gate_list = unique_gate_list self._gates = frozenset(unique_gate_list) @@ -422,6 +429,7 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool: g = item if isinstance(item, raw_types.Gate) else item.gate assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`' + # Check "type" based GateFamily since isinstance is fast for gate_mro_type in type(g).mro(): if gate_mro_type in self._type_gate_families: assert item in self._type_gate_families[gate_mro_type], ( @@ -430,6 +438,7 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool: ) return True + # Check exact instance equality next if g in self._instance_gate_families: assert item in self._instance_gate_families[g], ( f"{item} instance matches {self._instance_gate_families[g]} but " @@ -437,7 +446,17 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool: ) return True - return any(item in gate_family for gate_family in self._gates) + # Check other GateFamilies next + if any(item in gate_family for gate_family in self._gate_families_with_tags): + return True + + # Lastly, do a final exhaustive check to make sure this is not equivalent + # to another type of gate. This will catch things like: + # cirq.XPowGate(exponent=0) in cirq.GateFamily(cirq.I) + return any( + item in gate_family + for gate_family in self._gates.difference(self._gate_families_with_tags) + ) def validate(self, circuit_or_optree: cirq.AbstractCircuit | op_tree.OP_TREE) -> bool: """Validates gates forming `circuit_or_optree` should be contained in Gateset.