@@ -340,6 +340,7 @@ def __init__(
340340 self ._unroll_circuit_op = unroll_circuit_op
341341 self ._instance_gate_families : dict [raw_types .Gate , GateFamily ] = {}
342342 self ._type_gate_families : dict [type [raw_types .Gate ], GateFamily ] = {}
343+ self ._gate_families_with_tags : list [GateFamily ] = []
343344 self ._gates_repr_str = ", " .join ([_gate_str (g , repr ) for g in gates ])
344345 unique_gate_list : list [GateFamily ] = list (
345346 dict .fromkeys (g if isinstance (g , GateFamily ) else GateFamily (gate = g ) for g in gates )
@@ -351,6 +352,12 @@ def __init__(
351352 self ._instance_gate_families [g .gate ] = g
352353 else :
353354 self ._type_gate_families [g .gate ] = g
355+ else :
356+ if isinstance (g .gate , raw_types .Gate ):
357+ self ._gate_families_with_tags .append (g )
358+ else :
359+ # Instance checks are faster, so test them first.
360+ self ._gate_families_with_tags .insert (0 , g )
354361 self ._unique_gate_list = unique_gate_list
355362 self ._gates = frozenset (unique_gate_list )
356363
@@ -422,6 +429,7 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool:
422429 g = item if isinstance (item , raw_types .Gate ) else item .gate
423430 assert g is not None , f'`item`: { item } must be a gate or have a valid `item.gate`'
424431
432+ # Check "type" based GateFamily since isinstance is fast
425433 for gate_mro_type in type (g ).mro ():
426434 if gate_mro_type in self ._type_gate_families :
427435 assert item in self ._type_gate_families [gate_mro_type ], (
@@ -430,14 +438,25 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool:
430438 )
431439 return True
432440
441+ # Check exact instance equality next
433442 if g in self ._instance_gate_families :
434443 assert item in self ._instance_gate_families [g ], (
435444 f"{ item } instance matches { self ._instance_gate_families [g ]} but "
436445 f"is not accepted by it."
437446 )
438447 return True
439448
440- return any (item in gate_family for gate_family in self ._gates )
449+ # Check other GateFamilies next
450+ if any (item in gate_family for gate_family in self ._gate_families_with_tags ):
451+ return True
452+
453+ # Lastly, do a final exhaustive check to make sure this is not equivalent
454+ # to another type of gate. This will catch things like:
455+ # cirq.XPowGate(exponent=0) in cirq.GateFamily(cirq.I)
456+ return any (
457+ item in gate_family
458+ for gate_family in self ._gates .difference (self ._gate_families_with_tags )
459+ )
441460
442461 def validate (self , circuit_or_optree : cirq .AbstractCircuit | op_tree .OP_TREE ) -> bool :
443462 """Validates gates forming `circuit_or_optree` should be contained in Gateset.
0 commit comments