Skip to content

Commit e8e79bb

Browse files
authored
Micro-optimization for Gateset containment (#7692)
- Check gate families with tags before doing instance GateFamily checks. - This change will check cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[PhysicalZ()]) before instance gate families. - We generally want to avoid checking instance gate families if at all possible, since they can call protocols.equal_up_to_global_phase which is quite slow.
1 parent 01e82bf commit e8e79bb

File tree

1 file changed

+20
-1
lines changed

1 file changed

+20
-1
lines changed

cirq-core/cirq/ops/gateset.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,7 @@ def __init__(
340340
self._unroll_circuit_op = unroll_circuit_op
341341
self._instance_gate_families: dict[raw_types.Gate, GateFamily] = {}
342342
self._type_gate_families: dict[type[raw_types.Gate], GateFamily] = {}
343+
self._gate_families_with_tags: list[GateFamily] = []
343344
self._gates_repr_str = ", ".join([_gate_str(g, repr) for g in gates])
344345
unique_gate_list: list[GateFamily] = list(
345346
dict.fromkeys(g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates)
@@ -351,6 +352,12 @@ def __init__(
351352
self._instance_gate_families[g.gate] = g
352353
else:
353354
self._type_gate_families[g.gate] = g
355+
else:
356+
if isinstance(g.gate, raw_types.Gate):
357+
self._gate_families_with_tags.append(g)
358+
else:
359+
# Instance checks are faster, so test them first.
360+
self._gate_families_with_tags.insert(0, g)
354361
self._unique_gate_list = unique_gate_list
355362
self._gates = frozenset(unique_gate_list)
356363

@@ -422,6 +429,7 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool:
422429
g = item if isinstance(item, raw_types.Gate) else item.gate
423430
assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`'
424431

432+
# Check "type" based GateFamily since isinstance is fast
425433
for gate_mro_type in type(g).mro():
426434
if gate_mro_type in self._type_gate_families:
427435
assert item in self._type_gate_families[gate_mro_type], (
@@ -430,14 +438,25 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool:
430438
)
431439
return True
432440

441+
# Check exact instance equality next
433442
if g in self._instance_gate_families:
434443
assert item in self._instance_gate_families[g], (
435444
f"{item} instance matches {self._instance_gate_families[g]} but "
436445
f"is not accepted by it."
437446
)
438447
return True
439448

440-
return any(item in gate_family for gate_family in self._gates)
449+
# Check other GateFamilies next
450+
if any(item in gate_family for gate_family in self._gate_families_with_tags):
451+
return True
452+
453+
# Lastly, do a final exhaustive check to make sure this is not equivalent
454+
# to another type of gate. This will catch things like:
455+
# cirq.XPowGate(exponent=0) in cirq.GateFamily(cirq.I)
456+
return any(
457+
item in gate_family
458+
for gate_family in self._gates.difference(self._gate_families_with_tags)
459+
)
441460

442461
def validate(self, circuit_or_optree: cirq.AbstractCircuit | op_tree.OP_TREE) -> bool:
443462
"""Validates gates forming `circuit_or_optree` should be contained in Gateset.

0 commit comments

Comments
 (0)