From 2ad2f0b11bf1730100f169a8feb20c6e6866a4d4 Mon Sep 17 00:00:00 2001 From: daxfo Date: Mon, 3 Mar 2025 16:26:12 -0800 Subject: [PATCH] Add compensating global phase op if needed to MatrixGate decomposition. --- cirq-core/cirq/ops/controlled_gate.py | 7 ----- cirq-core/cirq/ops/controlled_gate_test.py | 2 +- cirq-core/cirq/ops/matrix_gates.py | 31 +++++++++++++++---- .../three_qubit_decomposition.py | 2 +- 4 files changed, 27 insertions(+), 15 deletions(-) diff --git a/cirq-core/cirq/ops/controlled_gate.py b/cirq-core/cirq/ops/controlled_gate.py index 96356b139ad..132227f2c34 100644 --- a/cirq-core/cirq/ops/controlled_gate.py +++ b/cirq-core/cirq/ops/controlled_gate.py @@ -34,7 +34,6 @@ controlled_operation as cop, diagonal_gate as dg, global_phase_op as gp, - matrix_gates, op_tree, raw_types, ) @@ -220,12 +219,6 @@ def _decompose_with_context_( control_qid_shape=self.control_qid_shape, ).on(*control_qubits) return [result, controlled_phase_op] - - if isinstance(self.sub_gate, matrix_gates.MatrixGate): - # Default decompositions of 2/3 qubit `cirq.MatrixGate` ignores global phase, which is - # local phase in the controlled variant and hence cannot be ignored. - return NotImplemented - result = protocols.decompose_once_with_qubits( self.sub_gate, qubits[self.num_controls() :], diff --git a/cirq-core/cirq/ops/controlled_gate_test.py b/cirq-core/cirq/ops/controlled_gate_test.py index c908418f806..ebff6b9c709 100644 --- a/cirq-core/cirq/ops/controlled_gate_test.py +++ b/cirq-core/cirq/ops/controlled_gate_test.py @@ -491,7 +491,7 @@ def _test_controlled_gate_is_consistent( decomposed = cirq.decompose(cgate.on(*qids)) first_op = cirq.IdentityGate(qid_shape=shape).on(*qids) # To ensure same qid order circuit = cirq.Circuit(first_op, *decomposed) - assert cirq.equal_up_to_global_phase(cirq.unitary(cgate), cirq.unitary(circuit)) + np.testing.assert_allclose(cirq.unitary(cgate), cirq.unitary(circuit), atol=1e-13) def test_pow_inverse(): diff --git a/cirq-core/cirq/ops/matrix_gates.py b/cirq-core/cirq/ops/matrix_gates.py index 6d5ee10fc7f..8b5ebc01702 100644 --- a/cirq-core/cirq/ops/matrix_gates.py +++ b/cirq-core/cirq/ops/matrix_gates.py @@ -14,13 +14,13 @@ """Quantum gates defined by a matrix.""" -from typing import Any, Dict, Iterable, Optional, Tuple, TYPE_CHECKING +from typing import Any, Dict, Iterable, List, Optional, Tuple, TYPE_CHECKING import numpy as np from cirq import linalg, protocols, _import from cirq._compat import proper_repr -from cirq.ops import raw_types, phased_x_z_gate +from cirq.ops import raw_types, phased_x_z_gate, global_phase_op as gp, identity if TYPE_CHECKING: import cirq @@ -148,18 +148,37 @@ def _phase_by_(self, phase_turns: float, qubit_index: int) -> 'MatrixGate': return MatrixGate(matrix=result.reshape(self._matrix.shape), qid_shape=self._qid_shape) def _decompose_(self, qubits: Tuple['cirq.Qid', ...]) -> 'cirq.OP_TREE': + decomposed: List['cirq.Operation'] = NotImplemented if self._qid_shape == (2,): - return [ + decomposed = [ g.on(qubits[0]) for g in single_qubit_decompositions.single_qubit_matrix_to_gates(self._matrix) ] if self._qid_shape == (2,) * 2: - return two_qubit_to_cz.two_qubit_matrix_to_cz_operations( + decomposed = two_qubit_to_cz.two_qubit_matrix_to_cz_operations( *qubits, self._matrix, allow_partial_czs=True ) if self._qid_shape == (2,) * 3: - return three_qubit_decomposition.three_qubit_matrix_to_operations(*qubits, self._matrix) - return NotImplemented + decomposed = three_qubit_decomposition.three_qubit_matrix_to_operations( + *qubits, self._matrix + ) + if decomposed is NotImplemented: + return NotImplemented + # The above algorithms ignore phase, but phase is important to maintain if the gate is + # controlled. Here, we add it back in with a global phase op. + from cirq.circuits import Circuit + + ident = identity.IdentityGate(qid_shape=self._qid_shape).on(*qubits) + u = protocols.unitary(Circuit(ident, *decomposed)).reshape(self._matrix.shape) + # All cells will have the same phase difference. Just choose the cell with the largest + # absolute value, to minimize rounding error. + max_index = np.unravel_index(np.abs(self._matrix).argmax(), self._matrix.shape) + phase_delta = self._matrix[max_index] / u[max_index] + # Phase delta is on the complex unit circle, so if real(phase_delta) >= 1, that means + # no phase delta. (>1 is rounding error). + if phase_delta.real < 1: + decomposed.append(gp.global_phase_operation(phase_delta)) + return decomposed def _has_unitary_(self) -> bool: return True diff --git a/cirq-core/cirq/transformers/analytical_decompositions/three_qubit_decomposition.py b/cirq-core/cirq/transformers/analytical_decompositions/three_qubit_decomposition.py index 0990e4dba24..186f1905f56 100644 --- a/cirq-core/cirq/transformers/analytical_decompositions/three_qubit_decomposition.py +++ b/cirq-core/cirq/transformers/analytical_decompositions/three_qubit_decomposition.py @@ -25,7 +25,7 @@ def three_qubit_matrix_to_operations( q0: ops.Qid, q1: ops.Qid, q2: ops.Qid, u: np.ndarray, atol: float = 1e-8 -) -> Sequence[ops.Operation]: +) -> List[ops.Operation]: """Returns operations for a 3 qubit unitary. The algorithm is described in Shende et al.: