Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for stimcirq gates and operations in cirq_google protos #7101

Merged
merged 14 commits into from
Mar 7, 2025
Merged
4 changes: 3 additions & 1 deletion cirq-core/cirq/ops/gate_operation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ def all_subclasses(cls):
gate_subclasses = {
g
for g in all_subclasses(cirq.Gate)
if "cirq." in g.__module__ and "contrib" not in g.__module__ and "test" not in g.__module__
if g.__module__.startswith("cirq.")
and "contrib" not in g.__module__
and "test" not in g.__module__
}

test_module_spec = cirq.testing.json.spec_for("cirq.protocols")
Expand Down
54 changes: 52 additions & 2 deletions cirq-google/cirq_google/serialization/circuit_serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
# CircuitSerializer is the dedicated serializer for the v2.5 format.
_SERIALIZER_NAME = 'v2_5'

# Package name for stimcirq
_STIMCIRQ_MODULE = "stimcirq"


class CircuitSerializer(serializer.Serializer):
"""A class for serializing and deserializing programs and operations.
Expand Down Expand Up @@ -193,7 +196,6 @@ def _serialize_gate_op(
ValueError: If the operation cannot be serialized.
"""
gate = op.gate

if isinstance(gate, InternalGate):
arg_func_langs.internal_gate_arg_to_proto(gate, out=msg.internalgate)
elif isinstance(gate, cirq.XPowGate):
Expand Down Expand Up @@ -260,6 +262,30 @@ def _serialize_gate_op(
arg_func_langs.float_arg_to_proto(
gate.q1_detune_mhz, out=msg.couplerpulsegate.q1_detune_mhz
)
elif getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) or getattr(
gate, "__module__", ""
).startswith(_STIMCIRQ_MODULE):
# Special handling for stimcirq objects, which can be both operations and gates.
stimcirq_obj = (
op if getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) else gate
)
if stimcirq_obj is not None and hasattr(stimcirq_obj, '_json_dict_'):
# All stimcirq gates currently have _json_dict_defined
msg.internalgate.name = type(stimcirq_obj).__name__
msg.internalgate.module = _STIMCIRQ_MODULE
if isinstance(stimcirq_obj, cirq.Gate):
msg.internalgate.num_qubits = stimcirq_obj.num_qubits()
else:
msg.internalgate.num_qubits = len(stimcirq_obj.qubits)

# Store json_dict objects in gate_args
for k, v in stimcirq_obj._json_dict_().items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of relying on the stimcirq object to always have _json_dict_ defined, prefer to use getattr instead. Also from your comment on line 266, it's not clear if stimcirq operations will also always have _json_dict_ defined.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I understand what you are proposing. How do we know which attributes to get from the stimcirq object otherwise? Also, how do we know how to instantiate the appropriate object on deserialization.

I think saying that a stimcirq gate/operation must have a json_dict to be imported/exported by cirq_google protos is probably a reasonable requirement (plus, we control both repos, so we can add it if it's missing).

I did add a value error if this condition is not met though.

arg_func_langs.arg_to_proto(value=v, out=msg.internalgate.gate_args[k])
else:
# New stimcirq op without a json dict has been introduced
raise ValueError(
f'Cannot serialize stimcirq {op!r}:{type(gate)}'
) # pragma: no cover
else:
raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}')

Expand Down Expand Up @@ -670,7 +696,31 @@ def _deserialize_gate_op(
raise ValueError(f"dimensions {dimensions} for ResetChannel must be an integer!")
op = cirq.ResetChannel(dimension=dimensions)(*qubits)
elif which_gate_type == 'internalgate':
op = arg_func_langs.internal_gate_from_proto(operation_proto.internalgate)(*qubits)
parsed_as_stimcirq = False
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we move this code to a local _deserialize_with_stimcirq_if_installed function which would return Optional[cirq.Operation] - and if None it will fallback to internal_gate_from_proto below?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After factoring out the import, this code block is pretty short. No one gates have a local function.
Let me know if you still think this is a good idea.

msg = operation_proto.internalgate
if msg.module == _STIMCIRQ_MODULE:
# special handling for stimcirq
try:
import stimcirq
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If stimcirq is not installed this would search sys.path for every deserialized gate.

Consider moving this to a @functools.cache decorated function which returns stimcirq.JSON_RESOLVERS_DICT if installed or an empty dictionary otherwise.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's only for operations with a module name of stimcirq, but, still, this is a good idea, and done.


# Use JSON resolver to instantiate the object
if msg.name in stimcirq.JSON_RESOLVERS_DICT:
kwargs = {}
for k, v in msg.gate_args.items():
arg = arg_func_langs.arg_from_proto(v)
if arg is not None:
kwargs[k] = arg
op = stimcirq.JSON_RESOLVERS_DICT[msg.name](**kwargs)
if qubits:
op = op(*qubits)
parsed_as_stimcirq = True

except ModuleNotFoundError: # pragma: no cover
# fall back to creating internal gates if stimcirq not installed
pass
if not parsed_as_stimcirq:
# all other internal gates
op = arg_func_langs.internal_gate_from_proto(msg)(*qubits)
elif which_gate_type == 'couplerpulsegate':
gate = CouplerPulse(
hold_time=cirq.Duration(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1043,3 +1043,22 @@ def test_reset_gate_with_improper_argument():

with pytest.raises(ValueError, match="must be an integer"):
serializer.deserialize(circuit_proto)


def test_stimcirq_gates():
try:
import stimcirq
except ModuleNotFoundError: # pragma: no cover
# Stimcirq not found, these are optional tests.
return
serializer = cg.CircuitSerializer()
q = cirq.q(1, 2)
q2 = cirq.q(2, 2)
c = cirq.Circuit(
cirq.Moment(stimcirq.CXSwapGate(inverted=True)(q, q2)),
cirq.Moment(cirq.measure(q, key="m")),
cirq.Moment(stimcirq.DetAnnotation(parity_keys=["m"])),
)
msg = serializer.serialize(c)
deserialized_circuit = serializer.deserialize(msg)
assert deserialized_circuit == c
2 changes: 1 addition & 1 deletion dev_tools/conf/mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ follow_imports = silent
ignore_missing_imports = true

# Non-Google
[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*]
[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*,stimcirq.*]
follow_imports = silent
ignore_missing_imports = true

Expand Down
3 changes: 3 additions & 0 deletions dev_tools/requirements/deps/dev-tools.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,6 @@ asv

# For verifying behavior of qasm output.
qiskit-aer~=0.16.1

# For testing stimcirq compatibility (cirq-google)
stimcirq
Loading