diff --git a/cirq-core/cirq/ops/gate_operation_test.py b/cirq-core/cirq/ops/gate_operation_test.py index d911c25261b..be7abe25932 100644 --- a/cirq-core/cirq/ops/gate_operation_test.py +++ b/cirq-core/cirq/ops/gate_operation_test.py @@ -490,7 +490,9 @@ def all_subclasses(cls): gate_subclasses = { g for g in all_subclasses(cirq.Gate) - if "cirq." in g.__module__ and "contrib" not in g.__module__ and "test" not in g.__module__ + if g.__module__.startswith("cirq.") + and "contrib" not in g.__module__ + and "test" not in g.__module__ } test_module_spec = cirq.testing.json.spec_for("cirq.protocols") diff --git a/cirq-google/cirq_google/serialization/circuit_serializer.py b/cirq-google/cirq_google/serialization/circuit_serializer.py index 085586ce093..10399f3f76c 100644 --- a/cirq-google/cirq_google/serialization/circuit_serializer.py +++ b/cirq-google/cirq_google/serialization/circuit_serializer.py @@ -15,6 +15,7 @@ """Support for serializing and deserializing cirq_google.api.v2 protos.""" from typing import Any, Dict, List, Optional +import functools import warnings import numpy as np import sympy @@ -38,6 +39,9 @@ # CircuitSerializer is the dedicated serializer for the v2.5 format. _SERIALIZER_NAME = 'v2_5' +# Package name for stimcirq +_STIMCIRQ_MODULE = "stimcirq" + class CircuitSerializer(serializer.Serializer): """A class for serializing and deserializing programs and operations. @@ -193,7 +197,6 @@ def _serialize_gate_op( ValueError: If the operation cannot be serialized. """ gate = op.gate - if isinstance(gate, InternalGate): arg_func_langs.internal_gate_arg_to_proto(gate, out=msg.internalgate) elif isinstance(gate, cirq.XPowGate): @@ -260,6 +263,30 @@ def _serialize_gate_op( arg_func_langs.float_arg_to_proto( gate.q1_detune_mhz, out=msg.couplerpulsegate.q1_detune_mhz ) + elif getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) or getattr( + gate, "__module__", "" + ).startswith(_STIMCIRQ_MODULE): + # Special handling for stimcirq objects, which can be both operations and gates. + stimcirq_obj = ( + op if getattr(op, "__module__", "").startswith(_STIMCIRQ_MODULE) else gate + ) + if stimcirq_obj is not None and hasattr(stimcirq_obj, '_json_dict_'): + # All stimcirq gates currently have _json_dict_defined + msg.internalgate.name = type(stimcirq_obj).__name__ + msg.internalgate.module = _STIMCIRQ_MODULE + if isinstance(stimcirq_obj, cirq.Gate): + msg.internalgate.num_qubits = stimcirq_obj.num_qubits() + else: + msg.internalgate.num_qubits = len(stimcirq_obj.qubits) + + # Store json_dict objects in gate_args + for k, v in stimcirq_obj._json_dict_().items(): + arg_func_langs.arg_to_proto(value=v, out=msg.internalgate.gate_args[k]) + else: + # New stimcirq op without a json dict has been introduced + raise ValueError( + f'Cannot serialize stimcirq {op!r}:{type(gate)}' + ) # pragma: no cover else: raise ValueError(f'Cannot serialize op {op!r} of type {type(gate)}') @@ -670,7 +697,21 @@ def _deserialize_gate_op( raise ValueError(f"dimensions {dimensions} for ResetChannel must be an integer!") op = cirq.ResetChannel(dimension=dimensions)(*qubits) elif which_gate_type == 'internalgate': - op = arg_func_langs.internal_gate_from_proto(operation_proto.internalgate)(*qubits) + msg = operation_proto.internalgate + if msg.module == _STIMCIRQ_MODULE and msg.name in _stimcirq_json_resolvers(): + # special handling for stimcirq + # Use JSON resolver to instantiate the object + kwargs = {} + for k, v in msg.gate_args.items(): + arg = arg_func_langs.arg_from_proto(v) + if arg is not None: + kwargs[k] = arg + op = _stimcirq_json_resolvers()[msg.name](**kwargs) + if qubits: + op = op(*qubits) + else: + # all other internal gates + op = arg_func_langs.internal_gate_from_proto(msg)(*qubits) elif which_gate_type == 'couplerpulsegate': gate = CouplerPulse( hold_time=cirq.Duration( @@ -766,4 +807,16 @@ def _deserialize_tag(self, msg: v2.program_pb2.Tag): return None +@functools.cache +def _stimcirq_json_resolvers(): + """Retrieves stimcirq JSON resolvers if stimcirq is installed. + Returns an empty dict if not installed.""" + try: + import stimcirq + + return stimcirq.JSON_RESOLVERS_DICT + except ModuleNotFoundError: # pragma: no cover + return {} # pragma: no cover + + CIRCUIT_SERIALIZER = CircuitSerializer() diff --git a/cirq-google/cirq_google/serialization/circuit_serializer_test.py b/cirq-google/cirq_google/serialization/circuit_serializer_test.py index 311f3f3341a..4502821b4b2 100644 --- a/cirq-google/cirq_google/serialization/circuit_serializer_test.py +++ b/cirq-google/cirq_google/serialization/circuit_serializer_test.py @@ -1043,3 +1043,18 @@ def test_reset_gate_with_improper_argument(): with pytest.raises(ValueError, match="must be an integer"): serializer.deserialize(circuit_proto) + + +def test_stimcirq_gates(): + stimcirq = pytest.importorskip("stimcirq") + serializer = cg.CircuitSerializer() + q = cirq.q(1, 2) + q2 = cirq.q(2, 2) + c = cirq.Circuit( + cirq.Moment(stimcirq.CXSwapGate(inverted=True)(q, q2)), + cirq.Moment(cirq.measure(q, key="m")), + cirq.Moment(stimcirq.DetAnnotation(parity_keys=["m"])), + ) + msg = serializer.serialize(c) + deserialized_circuit = serializer.deserialize(msg) + assert deserialized_circuit == c diff --git a/dev_tools/conf/mypy.ini b/dev_tools/conf/mypy.ini index bed12adf7a0..0f5e589e654 100644 --- a/dev_tools/conf/mypy.ini +++ b/dev_tools/conf/mypy.ini @@ -21,7 +21,7 @@ follow_imports = silent ignore_missing_imports = true # Non-Google -[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*] +[mypy-IPython.*,sympy.*,matplotlib.*,proto.*,pandas.*,scipy.*,freezegun.*,mpl_toolkits.*,networkx.*,ply.*,astroid.*,pytest.*,_pytest.*,pylint.*,setuptools.*,qiskit.*,quimb.*,pylatex.*,filelock.*,sortedcontainers.*,tqdm.*,ruamel.*,absl.*,tensorflow_docs.*,ipywidgets.*,cachetools.*,stimcirq.*] follow_imports = silent ignore_missing_imports = true diff --git a/dev_tools/requirements/deps/dev-tools.txt b/dev_tools/requirements/deps/dev-tools.txt index 62eeb5f9ece..b5e44db5870 100644 --- a/dev_tools/requirements/deps/dev-tools.txt +++ b/dev_tools/requirements/deps/dev-tools.txt @@ -12,3 +12,6 @@ asv # For verifying behavior of qasm output. qiskit-aer~=0.16.1 + +# For testing stimcirq compatibility (cirq-google) +stimcirq