Skip to content

Commit

Permalink
Support for stimcirq gates and operations in cirq_google protos (#7101)
Browse files Browse the repository at this point in the history
* Support for stimcirq gates and operations in cirq_google protos

- Adds special casing for stimcirq gates and operations.
- Note that this only supports gates and operations where the arguments
can be serialized.
- Serializing the stimcirq gates uses the json dictionary in order to
gather arguments from the operations.
- Tests will only be run if stimcirq is installed (manual use only)

* Fix some tests.

* Add requirements for stimcirq

* fix coverage

* format

* Address comments.

* Fix coverage.

* format

* Update cirq-google/cirq_google/serialization/circuit_serializer_test.py

Co-authored-by: Pavol Juhas <[email protected]>

* Move import to cached function

---------

Co-authored-by: Pavol Juhas <[email protected]>
  • Loading branch information
dstrain115 and pavoljuhas authored Mar 7, 2025
1 parent 34b9c81 commit 2167552
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 4 deletions.
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def all_subclasses(cls):
gate_subclasses = {
g
for g in all_subclasses(cirq.Gate)
if "cirq." in g.__module__ and "contrib" not in g.__module__ and "test" not in g.__module__
if g.__module__.startswith("cirq.")
and "contrib" not in g.__module__
and "test" not in g.__module__
}

test_module_spec = cirq.testing.json.spec_for("cirq.protocols")
Expand Down
57 changes: 55 additions & 2 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Support for serializing and deserializing cirq_google.api.v2 protos."""

from typing import Any, Dict, List, Optional
import functools
import warnings
import numpy as np
import sympy
Expand All @@ -38,6 +39,9 @@
# CircuitSerializer is the dedicated serializer for the v2.5 format.
_SERIALIZER_NAME = 'v2_5'

# Package name for stimcirq
_STIMCIRQ_MODULE = "stimcirq"


class CircuitSerializer(serializer.Serializer):
"""A class for serializing and deserializing programs and operations.
Expand Down Expand Up @@ -193,7 +197,6 @@ def _serialize_gate_op(
ValueError: If the operation cannot be serialized.
"""
gate = op.gate

if isinstance(gate, InternalGate):
arg_func_langs.internal_gate_arg_to_proto(gate, out=msg.internalgate)
elif isinstance(gate, cirq.XPowGate):
Expand Down Expand Up @@ -260,6 +263,30 @@ def _serialize_gate_op(
arg_func_langs.float_arg_to_proto(
gate.q1_detune_mhz, out=msg.couplerpulsegate.q1_detune_mhz
)
elif getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) or getattr(
gate, "__module__", ""
).startswith(_STIMCIRQ_MODULE):
# Special handling for stimcirq objects, which can be both operations and gates.
stimcirq_obj = (
op if getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) else gate
)
if stimcirq_obj is not None and hasattr(stimcirq_obj, '_json_dict_'):
# All stimcirq gates currently have _json_dict_defined
msg.internalgate.name = type(stimcirq_obj).__name__
msg.internalgate.module = _STIMCIRQ_MODULE
if isinstance(stimcirq_obj, cirq.Gate):
msg.internalgate.num_qubits = stimcirq_obj.num_qubits()
else:
msg.internalgate.num_qubits = len(stimcirq_obj.qubits)

# Store json_dict objects in gate_args
for k, v in stimcirq_obj._json_dict_().items():
arg_func_langs.arg_to_proto(value=v, out=msg.internalgate.gate_args[k])
else:
# New stimcirq op without a json dict has been introduced
raise ValueError(
f'Cannot serialize stimcirq {op!r}:{type(gate)}'
) # pragma: no cover
else:
raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')

Expand Down Expand Up @@ -670,7 +697,21 @@ def _deserialize_gate_op(
raise ValueError(f"dimensions {dimensions} for ResetChannel must be an integer!")
op = cirq.ResetChannel(dimension=dimensions)(*qubits)
elif which_gate_type == 'internalgate':
op = arg_func_langs.internal_gate_from_proto(operation_proto.internalgate)(*qubits)
msg = operation_proto.internalgate
if msg.module == _STIMCIRQ_MODULE and msg.name in _stimcirq_json_resolvers():
# special handling for stimcirq
# Use JSON resolver to instantiate the object
kwargs = {}
for k, v in msg.gate_args.items():
arg = arg_func_langs.arg_from_proto(v)
if arg is not None:
kwargs[k] = arg
op = _stimcirq_json_resolvers()[msg.name](**kwargs)
if qubits:
op = op(*qubits)
else:
# all other internal gates
op = arg_func_langs.internal_gate_from_proto(msg)(*qubits)
elif which_gate_type == 'couplerpulsegate':
gate = CouplerPulse(
hold_time=cirq.Duration(
Expand Down Expand Up @@ -766,4 +807,16 @@ def _deserialize_tag(self, msg: v2.program_pb2.Tag):
return None


@functools.cache
def _stimcirq_json_resolvers():
"""Retrieves stimcirq JSON resolvers if stimcirq is installed.
Returns an empty dict if not installed."""
try:
import stimcirq

return stimcirq.JSON_RESOLVERS_DICT
except ModuleNotFoundError: # pragma: no cover
return {} # pragma: no cover


CIRCUIT_SERIALIZER = CircuitSerializer()
15 changes: 15 additions & 0 deletions cirq-google/cirq_google/serialization/circuit_serializer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,3 +1043,18 @@ def test_reset_gate_with_improper_argument():

with pytest.raises(ValueError, match="must be an integer"):
serializer.deserialize(circuit_proto)


def test_stimcirq_gates():
stimcirq = pytest.importorskip("stimcirq")
serializer = cg.CircuitSerializer()
q = cirq.q(1, 2)
q2 = cirq.q(2, 2)
c = cirq.Circuit(
cirq.Moment(stimcirq.CXSwapGate(inverted=True)(q, q2)),
cirq.Moment(cirq.measure(q, key="m")),
cirq.Moment(stimcirq.DetAnnotation(parity_keys=["m"])),
)
msg = serializer.serialize(c)
deserialized_circuit = serializer.deserialize(msg)
assert deserialized_circuit == c
2 changes: 1 addition & 1 deletion dev_tools/conf/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ follow_imports = silent
ignore_missing_imports = true

# Non-Google
[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*]
[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*,stimcirq.*]
follow_imports = silent
ignore_missing_imports = true

Expand Down
3 changes: 3 additions & 0 deletions dev_tools/requirements/deps/dev-tools.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ asv

# For verifying behavior of qasm output.
qiskit-aer~=0.16.1

# For testing stimcirq compatibility (cirq-google)
stimcirq

0 comments on commit 2167552

Please sign in to comment.