Skip to content

Commit

Permalink
Small fixes of linalg.decomposition.
Browse files Browse the repository at this point in the history
  • Loading branch information
babacry committed Mar 7, 2025
1 parent 78d30e7 commit fb02d0a
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 44 deletions.
37 changes: 8 additions & 29 deletions cirq-core/cirq/linalg/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
Iterable,
List,
Optional,
Set,
Tuple,
TYPE_CHECKING,
TypeVar,
Expand Down Expand Up @@ -106,29 +105,6 @@ def deconstruct_single_qubit_matrix_into_angles(mat: np.ndarray) -> Tuple[float,
return right_phase + diagonal_phase, rotation * 2, bottom_phase


def _group_similar(items: List[T], comparer: Callable[[T, T], bool]) -> List[List[T]]:
"""Combines similar items into groups.
Args:
items: The list of items to group.
comparer: Determines if two items are similar.
Returns:
A list of groups of items.
"""
groups: List[List[T]] = []
used: Set[int] = set()
for i in range(len(items)):
if i not in used:
group = [items[i]]
for j in range(i + 1, len(items)):
if j not in used and comparer(items[i], items[j]):
used.add(j)
group.append(items[j])
groups.append(group)
return groups


def unitary_eig(
matrix: np.ndarray, check_preconditions: bool = True, atol: float = 1e-8
) -> Tuple[np.ndarray, np.ndarray]:
Expand Down Expand Up @@ -175,7 +151,6 @@ def map_eigenvalues(
Args:
matrix: The matrix to modify with the function.
func: The function to apply to the eigenvalues of the matrix.
rtol: Relative threshold used when separating eigenspaces.
atol: Absolute threshold used when separating eigenspaces.
Returns:
Expand All @@ -191,12 +166,13 @@ def map_eigenvalues(
return total


def kron_factor_4x4_to_2x2s(matrix: np.ndarray) -> Tuple[complex, np.ndarray, np.ndarray]:
def kron_factor_4x4_to_2x2s(
matrix: np.ndarray, rtol=1e-5, atol=1e-8
) -> Tuple[complex, np.ndarray, np.ndarray]:
"""Splits a 4x4 matrix U = kron(A, B) into A, B, and a global factor.
Requires the matrix to be the kronecker product of two 2x2 unitaries.
Requires the matrix to have a non-zero determinant.
Giving an incorrect matrix will cause garbage output.
Args:
matrix: The 4x4 unitary matrix to factor.
Expand Down Expand Up @@ -232,6 +208,9 @@ def kron_factor_4x4_to_2x2s(matrix: np.ndarray) -> Tuple[complex, np.ndarray, np
f1 *= -1
g = -g

if not np.allclose(matrix, g * np.kron(f1, f2), rtol=rtol, atol=atol):
raise ValueError("Invalid 4x4 kronecker product.")

return g, f1, f2


Expand Down Expand Up @@ -266,7 +245,7 @@ def so4_to_magic_su2s(
raise ValueError('mat must be 4x4 special orthogonal.')

ab = combinators.dot(MAGIC, mat, MAGIC_CONJ_T)
_, a, b = kron_factor_4x4_to_2x2s(ab)
_, a, b = kron_factor_4x4_to_2x2s(ab, rtol, atol)

return a, b

Expand Down Expand Up @@ -987,7 +966,7 @@ def _canonicalize_kak_vector(k_vec: np.ndarray, atol: float) -> np.ndarray:
unitaries required to bring the KAK vector into canonical form.
Args:
k_vec: THe KAK vector to be canonicalized. This input may be vectorized,
k_vec: The KAK vector to be canonicalized. This input may be vectorized,
with shape (...,3), where the final axis denotes the k_vector and
all other axes are broadcast.
atol: How close x2 must be to π/4 to guarantee z2 >= 0.
Expand Down
29 changes: 14 additions & 15 deletions cirq-core/cirq/linalg/decompositions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import cirq
from cirq import value
from cirq import unitary_eig
from cirq.linalg.decompositions import MAGIC, MAGIC_CONJ_T

X = np.array([[0, 1], [1, 0]])
Y = np.array([[0, -1j], [1j, 0]])
Expand All @@ -45,9 +46,7 @@ def assert_kronecker_factorization_not_within_tolerance(matrix, g, f1, f2):


def assert_magic_su2_within_tolerance(mat, a, b):
M = cirq.linalg.decompositions.MAGIC
MT = cirq.linalg.decompositions.MAGIC_CONJ_T
recon = cirq.linalg.combinators.dot(MT, cirq.linalg.combinators.kron(a, b), M)
recon = cirq.linalg.combinators.dot(MAGIC_CONJ_T, cirq.linalg.combinators.kron(a, b), MAGIC)
assert np.allclose(recon, mat), "Failed to decompose within tolerance."


Expand Down Expand Up @@ -149,14 +148,15 @@ def test_kron_factor_special_unitaries(f1, f2):
assert_kronecker_factorization_within_tolerance(p, g, g1, g2)


def test_kron_factor_fail():
mat = cirq.kron_with_controls(cirq.CONTROL_TAG, X)
g, f1, f2 = cirq.kron_factor_4x4_to_2x2s(mat)
with pytest.raises(ValueError):
assert_kronecker_factorization_not_within_tolerance(mat, g, f1, f2)
mat = cirq.kron_factor_4x4_to_2x2s(np.diag([1, 1, 1, 1j]))
with pytest.raises(ValueError):
assert_kronecker_factorization_not_within_tolerance(mat, g, f1, f2)
def test_kron_factor_invalid_input():
mats = [
cirq.kron_with_controls(cirq.CONTROL_TAG, X),
np.array([[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 2, 3, 4]]),
np.diag([1, 1, 1, 1j]),
]
for mat in mats:
with pytest.raises(ValueError, match="Invalid 4x4 kronecker product"):
cirq.kron_factor_4x4_to_2x2s(mat)


def recompose_so4(a: np.ndarray, b: np.ndarray) -> np.ndarray:
Expand All @@ -165,8 +165,7 @@ def recompose_so4(a: np.ndarray, b: np.ndarray) -> np.ndarray:
assert cirq.is_special_unitary(a)
assert cirq.is_special_unitary(b)

magic = np.array([[1, 0, 0, 1j], [0, 1j, 1, 0], [0, 1j, -1, 0], [1, 0, 0, -1j]]) * np.sqrt(0.5)
result = np.real(cirq.dot(np.conj(magic.T), cirq.kron(a, b), magic))
result = np.real(cirq.dot(MAGIC_CONJ_T, cirq.kron(a, b), MAGIC))
assert cirq.is_orthogonal(result)
return result

Expand Down Expand Up @@ -656,7 +655,7 @@ def test_kak_vector_matches_vectorized():
np.testing.assert_almost_equal(actual, expected)


def test_KAK_vector_local_invariants_random_input():
def test_kak_vector_local_invariants_random_input():
actual = _local_invariants_from_kak(cirq.kak_vector(_random_unitaries))
expected = _local_invariants_from_kak(_kak_vecs)

Expand Down Expand Up @@ -697,7 +696,7 @@ def test_kak_vector_on_weyl_chamber_face():
(np.kron(X, X), (0, 0, 0)),
),
)
def test_KAK_vector_weyl_chamber_vertices(unitary, expected):
def test_kak_vector_weyl_chamber_vertices(unitary, expected):
actual = cirq.kak_vector(unitary)
np.testing.assert_almost_equal(actual, expected)

Expand Down

0 comments on commit fb02d0a

Please sign in to comment.