From 3d36232a462a2ab9ab5cb5717b7c14e2f7e1be18 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Thu, 2 Oct 2025 16:05:28 -0700 Subject: [PATCH 1/3] Micro-optimization for Gateset containment - Check gate families with tags before doing instance GateFamily checks. - This change will check cirq.GateFamily(cirq.ZPowGate, tags_to_accept=[PhysicalZ()]) before instance gate families. - We generally want to avoid checking instance gate families if at all possible, since they can call protocols.equal_up_to_global_phase which is quite slow. --- cirq-core/cirq/ops/gateset.py | 16 ++++++++++++++++ cirq-core/cirq/ops/gateset_test.py | 3 +++ 2 files changed, 19 insertions(+) diff --git a/cirq-core/cirq/ops/gateset.py b/cirq-core/cirq/ops/gateset.py index a591ac4d8a2..a4b52ab3607 100644 --- a/cirq-core/cirq/ops/gateset.py +++ b/cirq-core/cirq/ops/gateset.py @@ -340,6 +340,7 @@ def __init__( self._unroll_circuit_op = unroll_circuit_op self._instance_gate_families: dict[raw_types.Gate, GateFamily] = {} self._type_gate_families: dict[type[raw_types.Gate], GateFamily] = {} + self._gate_families_with_tags: list[GateFamily] = [] self._gates_repr_str = ", ".join([_gate_str(g, repr) for g in gates]) unique_gate_list: list[GateFamily] = list( dict.fromkeys(g if isinstance(g, GateFamily) else GateFamily(gate=g) for g in gates) @@ -351,6 +352,12 @@ def __init__( self._instance_gate_families[g.gate] = g else: self._type_gate_families[g.gate] = g + else: + if isinstance(g.gate, raw_types.Gate): + self._gate_families_with_tags.append(g) + else: + # Instance checks are faster, so test them first. + self._gate_families_with_tags.insert(0, g) self._unique_gate_list = unique_gate_list self._gates = frozenset(unique_gate_list) @@ -422,6 +429,7 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool: g = item if isinstance(item, raw_types.Gate) else item.gate assert g is not None, f'`item`: {item} must be a gate or have a valid `item.gate`' + # Check "type" based GateFamily since isinstance is fast for gate_mro_type in type(g).mro(): if gate_mro_type in self._type_gate_families: assert item in self._type_gate_families[gate_mro_type], ( @@ -430,6 +438,7 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool: ) return True + # Check exact instance equality next if g in self._instance_gate_families: assert item in self._instance_gate_families[g], ( f"{item} instance matches {self._instance_gate_families[g]} but " @@ -437,6 +446,13 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool: ) return True + # Check other GateFamilies next + if any(item in gate_family for gate_family in self._gate_families_with_tags): + return True + + # Lastly, do a final exhaustive check to make sure this is not equivalent + # to another type of gate. This will catch things like: + # cirq.XPowGate(exponent=0) in cirq.GateFamily(cirq.I) return any(item in gate_family for gate_family in self._gates) def validate(self, circuit_or_optree: cirq.AbstractCircuit | op_tree.OP_TREE) -> bool: diff --git a/cirq-core/cirq/ops/gateset_test.py b/cirq-core/cirq/ops/gateset_test.py index ca5870a0aeb..2930dc7b065 100644 --- a/cirq-core/cirq/ops/gateset_test.py +++ b/cirq-core/cirq/ops/gateset_test.py @@ -329,6 +329,9 @@ def test_gateset_repr_and_str(g) -> None: ], ) def test_gateset_contains(gate, result) -> None: + print(gateset._gate_families_with_tags) + for g in gateset.gates: + print(f"{g} {gate in g}") assert (gate in gateset) is result op = gate(*cirq.LineQubit.range(gate.num_qubits())) assert (op in gateset) is result From 1c8a5bd0789704105e698876f37aee7c1497da57 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Wed, 8 Oct 2025 06:12:17 -0700 Subject: [PATCH 2/3] Remove prints --- cirq-core/cirq/ops/gateset_test.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/cirq-core/cirq/ops/gateset_test.py b/cirq-core/cirq/ops/gateset_test.py index 2930dc7b065..ca5870a0aeb 100644 --- a/cirq-core/cirq/ops/gateset_test.py +++ b/cirq-core/cirq/ops/gateset_test.py @@ -329,9 +329,6 @@ def test_gateset_repr_and_str(g) -> None: ], ) def test_gateset_contains(gate, result) -> None: - print(gateset._gate_families_with_tags) - for g in gateset.gates: - print(f"{g} {gate in g}") assert (gate in gateset) is result op = gate(*cirq.LineQubit.range(gate.num_qubits())) assert (op in gateset) is result From 7d51273d9e4390a5059d99682aed1455e87fce31 Mon Sep 17 00:00:00 2001 From: Doug Strain Date: Wed, 8 Oct 2025 06:14:11 -0700 Subject: [PATCH 3/3] Address comments. --- cirq-core/cirq/ops/gateset.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cirq-core/cirq/ops/gateset.py b/cirq-core/cirq/ops/gateset.py index a4b52ab3607..1f299e8fc03 100644 --- a/cirq-core/cirq/ops/gateset.py +++ b/cirq-core/cirq/ops/gateset.py @@ -453,7 +453,10 @@ def __contains__(self, item: raw_types.Gate | raw_types.Operation) -> bool: # Lastly, do a final exhaustive check to make sure this is not equivalent # to another type of gate. This will catch things like: # cirq.XPowGate(exponent=0) in cirq.GateFamily(cirq.I) - return any(item in gate_family for gate_family in self._gates) + return any( + item in gate_family + for gate_family in self._gates.difference(self._gate_families_with_tags) + ) def validate(self, circuit_or_optree: cirq.AbstractCircuit | op_tree.OP_TREE) -> bool: """Validates gates forming `circuit_or_optree` should be contained in Gateset.