Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small fixes of linalg.decompositions #7128

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 10 additions & 29 deletions cirq-core/cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
Iterable,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Expand Down Expand Up @@ -106,29 +105,6 @@ def deconstruct_single_qubit_matrix_into_angles(mat: np.ndarray) -> Tuple[float,
return right_phase + diagonal_phase, rotation * 2, bottom_phase


def _group_similar(items: List[T], comparer: Callable[[T, T], bool]) -> List[List[T]]:
"""Combines similar items into groups.

Args:
items: The list of items to group.
comparer: Determines if two items are similar.

Returns:
A list of groups of items.
"""
groups: List[List[T]] = []
used: Set[int] = set()
for i in range(len(items)):
if i not in used:
group = [items[i]]
for j in range(i + 1, len(items)):
if j not in used and comparer(items[i], items[j]):
used.add(j)
group.append(items[j])
groups.append(group)
return groups


def unitary_eig(
matrix: np.ndarray, check_preconditions: bool = True, atol: float = 1e-8
) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -175,7 +151,6 @@ def map_eigenvalues(
Args:
matrix: The matrix to modify with the function.
func: The function to apply to the eigenvalues of the matrix.
rtol: Relative threshold used when separating eigenspaces.
atol: Absolute threshold used when separating eigenspaces.

Returns:
Expand All @@ -191,15 +166,18 @@ def map_eigenvalues(
return total


def kron_factor_4x4_to_2x2s(matrix: np.ndarray) -> Tuple[complex, np.ndarray, np.ndarray]:
def kron_factor_4x4_to_2x2s(
matrix: np.ndarray, rtol=1e-5, atol=1e-8
) -> Tuple[complex, np.ndarray, np.ndarray]:
"""Splits a 4x4 matrix U = kron(A, B) into A, B, and a global factor.

Requires the matrix to be the kronecker product of two 2x2 unitaries.
Requires the matrix to have a non-zero determinant.
Giving an incorrect matrix will cause garbage output.

Args:
matrix: The 4x4 unitary matrix to factor.
rtol: Per-matrix-entry relative tolerance on equality.
atol: Per-matrix-entry absolute tolerance on equality.

Returns:
A scalar factor and a pair of 2x2 unit-determinant matrices. The
Expand Down Expand Up @@ -232,6 +210,9 @@ def kron_factor_4x4_to_2x2s(matrix: np.ndarray) -> Tuple[complex, np.ndarray, np
f1 *= -1
g = -g

if not np.allclose(matrix, g * np.kron(f1, f2), rtol=rtol, atol=atol):
raise ValueError("Invalid 4x4 kronecker product.")

return g, f1, f2


Expand Down Expand Up @@ -266,7 +247,7 @@ def so4_to_magic_su2s(
raise ValueError('mat must be 4x4 special orthogonal.')

ab = combinators.dot(MAGIC, mat, MAGIC_CONJ_T)
_, a, b = kron_factor_4x4_to_2x2s(ab)
_, a, b = kron_factor_4x4_to_2x2s(ab, rtol, atol)

return a, b

Expand Down Expand Up @@ -987,7 +968,7 @@ def _canonicalize_kak_vector(k_vec: np.ndarray, atol: float) -> np.ndarray:
unitaries required to bring the KAK vector into canonical form.

Args:
k_vec: THe KAK vector to be canonicalized. This input may be vectorized,
k_vec: The KAK vector to be canonicalized. This input may be vectorized,
with shape (...,3), where the final axis denotes the k_vector and
all other axes are broadcast.
atol: How close x2 must be to π/4 to guarantee z2 >= 0.
Expand Down
29 changes: 14 additions & 15 deletions cirq-core/cirq/linalg/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cirq
from cirq import value
from cirq import unitary_eig
from cirq.linalg.decompositions import MAGIC, MAGIC_CONJ_T

X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1j], [1j, 0]])
Expand All @@ -45,9 +46,7 @@ def assert_kronecker_factorization_not_within_tolerance(matrix, g, f1, f2):


def assert_magic_su2_within_tolerance(mat, a, b):
M = cirq.linalg.decompositions.MAGIC
MT = cirq.linalg.decompositions.MAGIC_CONJ_T
recon = cirq.linalg.combinators.dot(MT, cirq.linalg.combinators.kron(a, b), M)
recon = cirq.linalg.combinators.dot(MAGIC_CONJ_T, cirq.linalg.combinators.kron(a, b), MAGIC)
assert np.allclose(recon, mat), "Failed to decompose within tolerance."


Expand Down Expand Up @@ -149,14 +148,15 @@ def test_kron_factor_special_unitaries(f1, f2):
assert_kronecker_factorization_within_tolerance(p, g, g1, g2)


def test_kron_factor_fail():
mat = cirq.kron_with_controls(cirq.CONTROL_TAG, X)
g, f1, f2 = cirq.kron_factor_4x4_to_2x2s(mat)
with pytest.raises(ValueError):
assert_kronecker_factorization_not_within_tolerance(mat, g, f1, f2)
mat = cirq.kron_factor_4x4_to_2x2s(np.diag([1, 1, 1, 1j]))
with pytest.raises(ValueError):
assert_kronecker_factorization_not_within_tolerance(mat, g, f1, f2)
def test_kron_factor_invalid_input():
mats = [
cirq.kron_with_controls(cirq.CONTROL_TAG, X),
np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 2, 3, 4]]),
np.diag([1, 1, 1, 1j]),
]
for mat in mats:
with pytest.raises(ValueError, match="Invalid 4x4 kronecker product"):
cirq.kron_factor_4x4_to_2x2s(mat)


def recompose_so4(a: np.ndarray, b: np.ndarray) -> np.ndarray:
Expand All @@ -165,8 +165,7 @@ def recompose_so4(a: np.ndarray, b: np.ndarray) -> np.ndarray:
assert cirq.is_special_unitary(a)
assert cirq.is_special_unitary(b)

magic = np.array([[1, 0, 0, 1j], [0, 1j, 1, 0], [0, 1j, -1, 0], [1, 0, 0, -1j]]) * np.sqrt(0.5)
result = np.real(cirq.dot(np.conj(magic.T), cirq.kron(a, b), magic))
result = np.real(cirq.dot(MAGIC_CONJ_T, cirq.kron(a, b), MAGIC))
assert cirq.is_orthogonal(result)
return result

Expand Down Expand Up @@ -656,7 +655,7 @@ def test_kak_vector_matches_vectorized():
np.testing.assert_almost_equal(actual, expected)


def test_KAK_vector_local_invariants_random_input():
def test_kak_vector_local_invariants_random_input():
actual = _local_invariants_from_kak(cirq.kak_vector(_random_unitaries))
expected = _local_invariants_from_kak(_kak_vecs)

Expand Down Expand Up @@ -697,7 +696,7 @@ def test_kak_vector_on_weyl_chamber_face():
(np.kron(X, X), (0, 0, 0)),
),
)
def test_KAK_vector_weyl_chamber_vertices(unitary, expected):
def test_kak_vector_weyl_chamber_vertices(unitary, expected):
actual = cirq.kak_vector(unitary)
np.testing.assert_almost_equal(actual, expected)

Expand Down