diff --git a/python/ffsim/gates/orbital_rotation.py b/python/ffsim/gates/orbital_rotation.py index 19cceea6f..4d4295d5d 100644 --- a/python/ffsim/gates/orbital_rotation.py +++ b/python/ffsim/gates/orbital_rotation.py @@ -271,9 +271,9 @@ def _apply_orbital_rotation_givens( dim_b = comb(norb, n_beta, exact=True) vec = vec.reshape((dim_a, dim_b)) # transform alpha - for givens_mat, target_orbs in givens_rotations: + for (c, s), target_orbs in givens_rotations: _apply_orbital_rotation_adjacent_spin_in_place( - vec, givens_mat.conj(), target_orbs, norb, n_alpha + vec, c, s.conjugate(), target_orbs, norb, n_alpha ) for i, phase_shift in enumerate(phase_shifts): indices = _one_subspace_indices(norb, n_alpha, (i,)) @@ -282,9 +282,9 @@ def _apply_orbital_rotation_givens( # transform beta # transpose vector to align memory layout vec = vec.T.copy() - for givens_mat, target_orbs in givens_rotations: + for (c, s), target_orbs in givens_rotations: _apply_orbital_rotation_adjacent_spin_in_place( - vec, givens_mat.conj(), target_orbs, norb, n_beta + vec, c, s.conjugate(), target_orbs, norb, n_beta ) for i, phase_shift in enumerate(phase_shifts): indices = _one_subspace_indices(norb, n_beta, (i,)) @@ -294,7 +294,12 @@ def _apply_orbital_rotation_givens( def _apply_orbital_rotation_adjacent_spin_in_place( - vec: np.ndarray, mat: np.ndarray, target_orbs: tuple[int, int], norb: int, nocc: int + vec: np.ndarray, + c: float, + s: complex, + target_orbs: tuple[int, int], + norb: int, + nocc: int, ) -> None: """Apply an orbital rotation to adjacent orbitals. @@ -310,8 +315,6 @@ def _apply_orbital_rotation_adjacent_spin_in_place( indices = _zero_one_subspace_indices(norb, nocc, target_orbs) slice1 = indices[: len(indices) // 2] slice2 = indices[len(indices) // 2 :] - c, s = mat[0] - c = c.real apply_givens_rotation_in_place(vec, c, s, slice1, slice2) diff --git a/python/ffsim/linalg/__init__.py b/python/ffsim/linalg/__init__.py index 525cb708d..ca785bc3b 100644 --- a/python/ffsim/linalg/__init__.py +++ b/python/ffsim/linalg/__init__.py @@ -15,11 +15,7 @@ double_factorized_t2, modified_cholesky, ) -from ffsim.linalg.givens import ( - apply_matrix_to_slices, - givens_decomposition, - givens_matrix, -) +from ffsim.linalg.givens import apply_matrix_to_slices, givens_decomposition from ffsim.linalg.linalg import ( expm_multiply_taylor, lup, diff --git a/python/ffsim/linalg/givens.py b/python/ffsim/linalg/givens.py index 8edae2cda..796575bcd 100644 --- a/python/ffsim/linalg/givens.py +++ b/python/ffsim/linalg/givens.py @@ -13,6 +13,8 @@ from __future__ import annotations import numpy as np +from scipy.linalg.blas import zrotg as zrotg_ +from scipy.linalg.lapack import zrot def apply_matrix_to_slices( @@ -45,63 +47,32 @@ def apply_matrix_to_slices( return out -def givens_matrix(a: complex, b: complex) -> np.ndarray: - r"""Compute the Givens rotation to zero out a row entry. +def zrotg(a: complex, b: complex, tol=1e-12) -> tuple[float, complex]: + r"""Safe version of the zrotg BLAS function. - Returns a :math:`2 \times 2` unitary matrix G that satisfies + The BLAS implementation of zrotg can return NaN values if either a or b is very + close to zero. This function detects if either a or b is close to zero up to the + specified tolerance, in which case it behaves as if it were exactly zero. - .. math:: - - G - \begin{pmatrix} - a \\ - b - \end{pmatrix} - = - \begin{pmatrix} - r \\ - 0 - \end{pmatrix} - - where :math:`r` is a complex number. - - References: - - ``_ - - ``_ - - Args: - a: A complex number representing the first row entry - b: A complex number representing the second row entry - - Returns: - The Givens rotation matrix. + Note that in contrast to `scipy.linalg.blas.zrotg`, this function returns c as a + float rather than a complex. """ - # Handle case that a is zero - if np.isclose(a, 0.0): - cosine = 0.0 - sine = 1.0 - # Handle case that b is zero and a is nonzero - elif np.isclose(b, 0.0): - cosine = 1.0 - sine = 0.0 - # Handle case that a and b are both nonzero - else: - hypotenuse = np.hypot(abs(a), abs(b)) - cosine = abs(a) / hypotenuse - sign_a = a / abs(a) - sine = sign_a * b.conjugate() / hypotenuse - - return np.array([[cosine, sine], [-sine.conjugate(), cosine]]) + if np.isclose(a, 0.0, atol=tol): + return 0.0, 1 + 0j + if np.isclose(b, 0.0, atol=tol): + return 1.0, 0j + c, s = zrotg_(a, b) + return c.real, s def givens_decomposition( mat: np.ndarray, -) -> tuple[list[tuple[np.ndarray, tuple[int, int]]], np.ndarray]: +) -> tuple[list[tuple[tuple[float, complex], tuple[int, int]]], np.ndarray]: """Givens rotation decomposition of a unitary matrix.""" n, _ = mat.shape - current_matrix = mat + current_matrix = mat.astype(complex, copy=False) left_rotations = [] - right_rotations: list[tuple[np.ndarray, tuple[int, int]]] = [] + right_rotations: list[tuple[tuple[float, complex], tuple[int, int]]] = [] # compute left and right Givens rotations for i in range(n - 1): @@ -112,18 +83,22 @@ def givens_decomposition( row = n - j - 1 if not np.isclose(current_matrix[row, target_index], 0.0): # zero out element at target index in given row - givens_mat = givens_matrix( + c, s = zrotg( current_matrix[row, target_index + 1], current_matrix[row, target_index], ) - right_rotations.append( - (givens_mat, (target_index + 1, target_index)) - ) - current_matrix = apply_matrix_to_slices( - current_matrix, - givens_mat, - [(Ellipsis, target_index + 1), (Ellipsis, target_index)], + right_rotations.append(((c, s), (target_index + 1, target_index))) + current_matrix = current_matrix.T.copy() + ( + current_matrix[target_index + 1], + current_matrix[target_index], + ) = zrot( + current_matrix[target_index + 1], + current_matrix[target_index], + c, + s, ) + current_matrix = current_matrix.T else: # rotate rows by left multiplication for j in range(i + 1): @@ -131,24 +106,31 @@ def givens_decomposition( col = j if not np.isclose(current_matrix[target_index, col], 0.0): # zero out element at target index in given column - givens_mat = givens_matrix( + c, s = zrotg( current_matrix[target_index - 1, col], current_matrix[target_index, col], ) - left_rotations.append( - (givens_mat, (target_index - 1, target_index)) - ) - current_matrix = apply_matrix_to_slices( - current_matrix, givens_mat, [target_index - 1, target_index] + left_rotations.append(((c, s), (target_index - 1, target_index))) + ( + current_matrix[target_index - 1], + current_matrix[target_index], + ) = zrot( + current_matrix[target_index - 1], + current_matrix[target_index], + c, + s, ) # convert left rotations to right rotations - for givens_mat, (i, j) in reversed(left_rotations): - givens_mat = givens_mat.T.conj().astype(mat.dtype, copy=False) + for (c, s), (i, j) in reversed(left_rotations): + c, s = zrotg(c * current_matrix[j, j], s.conjugate() * current_matrix[i, i]) + right_rotations.append(((c, -s.conjugate()), (i, j))) + + givens_mat = np.array([[c, -s], [s.conjugate(), c]]) givens_mat[:, 0] *= current_matrix[i, i] givens_mat[:, 1] *= current_matrix[j, j] - new_givens_mat = givens_matrix(givens_mat[1, 1], givens_mat[1, 0]) - right_rotations.append((new_givens_mat.T, (i, j))) + c, s = zrotg(givens_mat[1, 1], givens_mat[1, 0]) + new_givens_mat = np.array([[c, s], [-s.conjugate(), c]]) phase_matrix = givens_mat @ new_givens_mat current_matrix[i, i] = phase_matrix[0, 0] current_matrix[j, j] = phase_matrix[1, 1] diff --git a/tests/linalg/givens_test.py b/tests/linalg/givens_test.py index 7e5456926..fc2c97599 100644 --- a/tests/linalg/givens_test.py +++ b/tests/linalg/givens_test.py @@ -13,23 +13,26 @@ from __future__ import annotations import numpy as np +from scipy.linalg.lapack import zrot -from ffsim.linalg import ( - apply_matrix_to_slices, - givens_decomposition, -) -from ffsim.random import random_unitary +import ffsim +from ffsim.linalg import givens_decomposition def test_givens_decomposition(): dim = 5 - mat = random_unitary(dim) - givens_rotations, phase_shifts = givens_decomposition(mat) - reconstructed = np.eye(dim, dtype=complex) - for i, phase_shift in enumerate(phase_shifts): - reconstructed[i] *= phase_shift - for givens_mat, (i, j) in givens_rotations[::-1]: - reconstructed = apply_matrix_to_slices( - reconstructed, givens_mat.conj(), ((Ellipsis, j), (Ellipsis, i)) - ) - np.testing.assert_allclose(reconstructed, mat, atol=1e-8) + rng = np.random.default_rng() + for _ in range(5): + mat = ffsim.random.random_unitary(dim, seed=rng) + givens_rotations, phase_shifts = givens_decomposition(mat) + reconstructed = np.eye(dim, dtype=complex) + for i, phase_shift in enumerate(phase_shifts): + reconstructed[i] *= phase_shift + for (c, s), (i, j) in givens_rotations[::-1]: + reconstructed = reconstructed.T.copy() + reconstructed[j], reconstructed[i] = zrot( + reconstructed[j], reconstructed[i], c, s.conjugate() + ) + reconstructed = reconstructed.T + + np.testing.assert_allclose(reconstructed, mat, atol=1e-8)