Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for stimcirq gates and operations in cirq_google protos #7101

Merged
merged 14 commits into from
Mar 7, 2025
Merged
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def all_subclasses(cls):
gate_subclasses = {
g
for g in all_subclasses(cirq.Gate)
if "cirq." in g.__module__ and "contrib" not in g.__module__ and "test" not in g.__module__
if g.__module__.startswith("cirq.")
and "contrib" not in g.__module__
and "test" not in g.__module__
}

test_module_spec = cirq.testing.json.spec_for("cirq.protocols")
Expand Down
57 changes: 55 additions & 2 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Support for serializing and deserializing cirq_google.api.v2 protos."""

from typing import Any, Dict, List, Optional
import functools
import warnings
import numpy as np
import sympy
Expand All @@ -38,6 +39,9 @@
# CircuitSerializer is the dedicated serializer for the v2.5 format.
_SERIALIZER_NAME = 'v2_5'

# Package name for stimcirq
_STIMCIRQ_MODULE = "stimcirq"


class CircuitSerializer(serializer.Serializer):
"""A class for serializing and deserializing programs and operations.
Expand Down Expand Up @@ -193,7 +197,6 @@ def _serialize_gate_op(
ValueError: If the operation cannot be serialized.
"""
gate = op.gate

if isinstance(gate, InternalGate):
arg_func_langs.internal_gate_arg_to_proto(gate, out=msg.internalgate)
elif isinstance(gate, cirq.XPowGate):
Expand Down Expand Up @@ -260,6 +263,30 @@ def _serialize_gate_op(
arg_func_langs.float_arg_to_proto(
gate.q1_detune_mhz, out=msg.couplerpulsegate.q1_detune_mhz
)
elif getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) or getattr(
gate, "__module__", ""
).startswith(_STIMCIRQ_MODULE):
# Special handling for stimcirq objects, which can be both operations and gates.
stimcirq_obj = (
op if getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) else gate
)
if stimcirq_obj is not None and hasattr(stimcirq_obj, '_json_dict_'):
# All stimcirq gates currently have _json_dict_defined
msg.internalgate.name = type(stimcirq_obj).__name__
msg.internalgate.module = _STIMCIRQ_MODULE
if isinstance(stimcirq_obj, cirq.Gate):
msg.internalgate.num_qubits = stimcirq_obj.num_qubits()
else:
msg.internalgate.num_qubits = len(stimcirq_obj.qubits)

# Store json_dict objects in gate_args
for k, v in stimcirq_obj._json_dict_().items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of relying on the stimcirq object to always have _json_dict_ defined, prefer to use getattr instead. Also from your comment on line 266, it's not clear if stimcirq operations will also always have _json_dict_ defined.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand what you are proposing. How do we know which attributes to get from the stimcirq object otherwise? Also, how do we know how to instantiate the appropriate object on deserialization.

I think saying that a stimcirq gate/operation must have a json_dict to be imported/exported by cirq_google protos is probably a reasonable requirement (plus, we control both repos, so we can add it if it's missing).

I did add a value error if this condition is not met though.

arg_func_langs.arg_to_proto(value=v, out=msg.internalgate.gate_args[k])
else:
# New stimcirq op without a json dict has been introduced
raise ValueError(
f'Cannot serialize stimcirq {op!r}:{type(gate)}'
) # pragma: no cover
else:
raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')

Expand Down Expand Up @@ -670,7 +697,21 @@ def _deserialize_gate_op(
raise ValueError(f"dimensions {dimensions} for ResetChannel must be an integer!")
op = cirq.ResetChannel(dimension=dimensions)(*qubits)
elif which_gate_type == 'internalgate':
op = arg_func_langs.internal_gate_from_proto(operation_proto.internalgate)(*qubits)
msg = operation_proto.internalgate
if msg.module == _STIMCIRQ_MODULE and msg.name in _stimcirq_json_resolvers():
# special handling for stimcirq
# Use JSON resolver to instantiate the object
kwargs = {}
for k, v in msg.gate_args.items():
arg = arg_func_langs.arg_from_proto(v)
if arg is not None:
kwargs[k] = arg
op = _stimcirq_json_resolvers()[msg.name](**kwargs)
if qubits:
op = op(*qubits)
else:
# all other internal gates
op = arg_func_langs.internal_gate_from_proto(msg)(*qubits)
elif which_gate_type == 'couplerpulsegate':
gate = CouplerPulse(
hold_time=cirq.Duration(
Expand Down Expand Up @@ -766,4 +807,16 @@ def _deserialize_tag(self, msg: v2.program_pb2.Tag):
return None


@functools.cache
def _stimcirq_json_resolvers():
"""Retrieves stimcirq JSON resolvers if stimcirq is installed.
Returns an empty dict if not installed."""
try:
import stimcirq

return stimcirq.JSON_RESOLVERS_DICT
except ModuleNotFoundError: # pragma: no cover
return {} # pragma: no cover


CIRCUIT_SERIALIZER = CircuitSerializer()
Original file line number Diff line number Diff line change
Expand Up @@ -1043,3 +1043,18 @@ def test_reset_gate_with_improper_argument():

with pytest.raises(ValueError, match="must be an integer"):
serializer.deserialize(circuit_proto)


def test_stimcirq_gates():
stimcirq = pytest.importorskip("stimcirq")
serializer = cg.CircuitSerializer()
q = cirq.q(1, 2)
q2 = cirq.q(2, 2)
c = cirq.Circuit(
cirq.Moment(stimcirq.CXSwapGate(inverted=True)(q, q2)),
cirq.Moment(cirq.measure(q, key="m")),
cirq.Moment(stimcirq.DetAnnotation(parity_keys=["m"])),
)
msg = serializer.serialize(c)
deserialized_circuit = serializer.deserialize(msg)
assert deserialized_circuit == c
2 changes: 1 addition & 1 deletion dev_tools/conf/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ follow_imports = silent
ignore_missing_imports = true

# Non-Google
[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*]
[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*,stimcirq.*]
follow_imports = silent
ignore_missing_imports = true

Expand Down
3 changes: 3 additions & 0 deletions dev_tools/requirements/deps/dev-tools.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ asv

# For verifying behavior of qasm output.
qiskit-aer~=0.16.1

# For testing stimcirq compatibility (cirq-google)
stimcirq