diff --git a/cirq-core/cirq/linalg/decompositions.py b/cirq-core/cirq/linalg/decompositions.py index 0a6c52062d1..f176507ded6 100644 --- a/cirq-core/cirq/linalg/decompositions.py +++ b/cirq-core/cirq/linalg/decompositions.py @@ -24,7 +24,6 @@ Iterable, List, Optional, - Set, Tuple, TYPE_CHECKING, TypeVar, @@ -106,29 +105,6 @@ def deconstruct_single_qubit_matrix_into_angles(mat: np.ndarray) -> Tuple[float, return right_phase + diagonal_phase, rotation * 2, bottom_phase -def _group_similar(items: List[T], comparer: Callable[[T, T], bool]) -> List[List[T]]: - """Combines similar items into groups. - - Args: - items: The list of items to group. - comparer: Determines if two items are similar. - - Returns: - A list of groups of items. - """ - groups: List[List[T]] = [] - used: Set[int] = set() - for i in range(len(items)): - if i not in used: - group = [items[i]] - for j in range(i + 1, len(items)): - if j not in used and comparer(items[i], items[j]): - used.add(j) - group.append(items[j]) - groups.append(group) - return groups - - def unitary_eig( matrix: np.ndarray, check_preconditions: bool = True, atol: float = 1e-8 ) -> Tuple[np.ndarray, np.ndarray]: @@ -175,7 +151,6 @@ def map_eigenvalues( Args: matrix: The matrix to modify with the function. func: The function to apply to the eigenvalues of the matrix. - rtol: Relative threshold used when separating eigenspaces. atol: Absolute threshold used when separating eigenspaces. Returns: @@ -191,15 +166,18 @@ def map_eigenvalues( return total -def kron_factor_4x4_to_2x2s(matrix: np.ndarray) -> Tuple[complex, np.ndarray, np.ndarray]: +def kron_factor_4x4_to_2x2s( + matrix: np.ndarray, rtol=1e-5, atol=1e-8 +) -> Tuple[complex, np.ndarray, np.ndarray]: """Splits a 4x4 matrix U = kron(A, B) into A, B, and a global factor. Requires the matrix to be the kronecker product of two 2x2 unitaries. Requires the matrix to have a non-zero determinant. - Giving an incorrect matrix will cause garbage output. Args: matrix: The 4x4 unitary matrix to factor. + rtol: Per-matrix-entry relative tolerance on equality. + atol: Per-matrix-entry absolute tolerance on equality. Returns: A scalar factor and a pair of 2x2 unit-determinant matrices. The @@ -232,6 +210,9 @@ def kron_factor_4x4_to_2x2s(matrix: np.ndarray) -> Tuple[complex, np.ndarray, np f1 *= -1 g = -g + if not np.allclose(matrix, g * np.kron(f1, f2), rtol=rtol, atol=atol): + raise ValueError("Invalid 4x4 kronecker product.") + return g, f1, f2 @@ -266,7 +247,7 @@ def so4_to_magic_su2s( raise ValueError('mat must be 4x4 special orthogonal.') ab = combinators.dot(MAGIC, mat, MAGIC_CONJ_T) - _, a, b = kron_factor_4x4_to_2x2s(ab) + _, a, b = kron_factor_4x4_to_2x2s(ab, rtol, atol) return a, b @@ -987,7 +968,7 @@ def _canonicalize_kak_vector(k_vec: np.ndarray, atol: float) -> np.ndarray: unitaries required to bring the KAK vector into canonical form. Args: - k_vec: THe KAK vector to be canonicalized. This input may be vectorized, + k_vec: The KAK vector to be canonicalized. This input may be vectorized, with shape (...,3), where the final axis denotes the k_vector and all other axes are broadcast. atol: How close x2 must be to π/4 to guarantee z2 >= 0. diff --git a/cirq-core/cirq/linalg/decompositions_test.py b/cirq-core/cirq/linalg/decompositions_test.py index 3bfc2717792..b99a5618146 100644 --- a/cirq-core/cirq/linalg/decompositions_test.py +++ b/cirq-core/cirq/linalg/decompositions_test.py @@ -20,6 +20,7 @@ import cirq from cirq import value from cirq import unitary_eig +from cirq.linalg.decompositions import MAGIC, MAGIC_CONJ_T X = np.array([[0, 1], [1, 0]]) Y = np.array([[0, -1j], [1j, 0]]) @@ -45,9 +46,7 @@ def assert_kronecker_factorization_not_within_tolerance(matrix, g, f1, f2): def assert_magic_su2_within_tolerance(mat, a, b): - M = cirq.linalg.decompositions.MAGIC - MT = cirq.linalg.decompositions.MAGIC_CONJ_T - recon = cirq.linalg.combinators.dot(MT, cirq.linalg.combinators.kron(a, b), M) + recon = cirq.linalg.combinators.dot(MAGIC_CONJ_T, cirq.linalg.combinators.kron(a, b), MAGIC) assert np.allclose(recon, mat), "Failed to decompose within tolerance." @@ -149,14 +148,15 @@ def test_kron_factor_special_unitaries(f1, f2): assert_kronecker_factorization_within_tolerance(p, g, g1, g2) -def test_kron_factor_fail(): - mat = cirq.kron_with_controls(cirq.CONTROL_TAG, X) - g, f1, f2 = cirq.kron_factor_4x4_to_2x2s(mat) - with pytest.raises(ValueError): - assert_kronecker_factorization_not_within_tolerance(mat, g, f1, f2) - mat = cirq.kron_factor_4x4_to_2x2s(np.diag([1, 1, 1, 1j])) - with pytest.raises(ValueError): - assert_kronecker_factorization_not_within_tolerance(mat, g, f1, f2) +def test_kron_factor_invalid_input(): + mats = [ + cirq.kron_with_controls(cirq.CONTROL_TAG, X), + np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 2, 3, 4]]), + np.diag([1, 1, 1, 1j]), + ] + for mat in mats: + with pytest.raises(ValueError, match="Invalid 4x4 kronecker product"): + cirq.kron_factor_4x4_to_2x2s(mat) def recompose_so4(a: np.ndarray, b: np.ndarray) -> np.ndarray: @@ -165,8 +165,7 @@ def recompose_so4(a: np.ndarray, b: np.ndarray) -> np.ndarray: assert cirq.is_special_unitary(a) assert cirq.is_special_unitary(b) - magic = np.array([[1, 0, 0, 1j], [0, 1j, 1, 0], [0, 1j, -1, 0], [1, 0, 0, -1j]]) * np.sqrt(0.5) - result = np.real(cirq.dot(np.conj(magic.T), cirq.kron(a, b), magic)) + result = np.real(cirq.dot(MAGIC_CONJ_T, cirq.kron(a, b), MAGIC)) assert cirq.is_orthogonal(result) return result @@ -656,7 +655,7 @@ def test_kak_vector_matches_vectorized(): np.testing.assert_almost_equal(actual, expected) -def test_KAK_vector_local_invariants_random_input(): +def test_kak_vector_local_invariants_random_input(): actual = _local_invariants_from_kak(cirq.kak_vector(_random_unitaries)) expected = _local_invariants_from_kak(_kak_vecs) @@ -697,7 +696,7 @@ def test_kak_vector_on_weyl_chamber_face(): (np.kron(X, X), (0, 0, 0)), ), ) -def test_KAK_vector_weyl_chamber_vertices(unitary, expected): +def test_kak_vector_weyl_chamber_vertices(unitary, expected): actual = cirq.kak_vector(unitary) np.testing.assert_almost_equal(actual, expected)