diff --git a/benchmarks/rust.py b/benchmarks/rust.py index b8b26f1c1..d7c85a16e 100644 --- a/benchmarks/rust.py +++ b/benchmarks/rust.py @@ -71,8 +71,7 @@ def time_apply_givens_rotation_in_place_python(self, *_): apply_givens_rotation_in_place_slow( self.vec_as_mat, c=0.5, - s=0.5, - phase=(1j) ** 0.5, + s=(1j) ** 0.5 * np.sqrt(0.75), slice1=self.slice1, slice2=self.slice2, ) @@ -81,8 +80,7 @@ def time_apply_givens_rotation_in_place_rust(self, *_): apply_givens_rotation_in_place( self.vec_as_mat, c=0.5, - s=0.5, - phase=(1j) ** 0.5, + s=(1j) ** 0.5 * np.sqrt(0.75), slice1=self.slice1, slice2=self.slice2, ) diff --git a/python/ffsim/gates/orbital_rotation.py b/python/ffsim/gates/orbital_rotation.py index 15b6cc591..19cceea6f 100644 --- a/python/ffsim/gates/orbital_rotation.py +++ b/python/ffsim/gates/orbital_rotation.py @@ -311,7 +311,8 @@ def _apply_orbital_rotation_adjacent_spin_in_place( slice1 = indices[: len(indices) // 2] slice2 = indices[len(indices) // 2 :] c, s = mat[0] - apply_givens_rotation_in_place(vec, c.real, abs(s), s / abs(s), slice1, slice2) + c = c.real + apply_givens_rotation_in_place(vec, c, s, slice1, slice2) @lru_cache(maxsize=None) diff --git a/python/ffsim/slow/gates/orbital_rotation.py b/python/ffsim/slow/gates/orbital_rotation.py index bf9fd222b..941a4cc56 100644 --- a/python/ffsim/slow/gates/orbital_rotation.py +++ b/python/ffsim/slow/gates/orbital_rotation.py @@ -11,7 +11,7 @@ from __future__ import annotations import numpy as np -from scipy.linalg.blas import zdrot, zscal +from scipy.linalg.lapack import zrot def gen_orbital_rotation_index_in_place_slow( @@ -66,14 +66,10 @@ def apply_single_column_transformation_in_place_slow( def apply_givens_rotation_in_place_slow( vec: np.ndarray, c: float, - s: float, - phase: complex, + s: complex, slice1: np.ndarray, slice2: np.ndarray, ) -> None: """Apply a Givens rotation to slices of a state vector.""" - phase_conj = phase.conjugate() for i, j in zip(slice1, slice2): - zscal(phase_conj, vec[i]) - vec[i], vec[j] = zdrot(vec[i], vec[j], c, s) - zscal(phase, vec[i]) + vec[i], vec[j] = zrot(vec[i], vec[j], c, s) diff --git a/src/gates/orbital_rotation.rs b/src/gates/orbital_rotation.rs index 2c7d3bbd5..cb0399d31 100644 --- a/src/gates/orbital_rotation.rs +++ b/src/gates/orbital_rotation.rs @@ -150,8 +150,7 @@ pub fn apply_single_column_transformation_in_place( pub fn apply_givens_rotation_in_place( mut vec: PyReadwriteArray2, c: f64, - s: f64, - phase: Complex64, + s: Complex64, slice1: PyReadonlyArray1, slice2: PyReadonlyArray1, ) { @@ -160,6 +159,8 @@ pub fn apply_givens_rotation_in_place( let slice2 = slice2.as_array(); let shape = vec.shape(); let dim_b = shape[1] as i32; + let s_abs = s.norm(); + let phase = s / s_abs; let phase_conj = phase.conj(); // TODO parallelize this @@ -169,7 +170,9 @@ pub fn apply_givens_rotation_in_place( Some(row_i) => match row_j.as_slice_mut() { Some(row_j) => unsafe { zscal(dim_b, phase_conj, row_i, 1); - zdrot(dim_b, row_i, 1, row_j, 1, c, s); + // TODO use zrot from lapack once it's available + // See https://github.com/blas-lapack-rs/lapack/issues/30 + zdrot(dim_b, row_i, 1, row_j, 1, c, s_abs); zscal(dim_b, phase, row_i, 1); }, None => panic!( diff --git a/tests/slow/gates/orbital_rotation_test.py b/tests/slow/gates/orbital_rotation_test.py index 0d26caf99..f47b92f31 100644 --- a/tests/slow/gates/orbital_rotation_test.py +++ b/tests/slow/gates/orbital_rotation_test.py @@ -44,14 +44,13 @@ def test_apply_givens_rotation_in_place_slow(): (dim_a, dim_b) ) vec_fast = vec_slow.copy() - c = rng.uniform() - s = 1 - c**2 - phase = (1j) ** rng.uniform(0, 4) + c = rng.uniform(0, 1) + s = (1j) ** rng.uniform(0, 4) * np.sqrt(1 - c**2) indices = _zero_one_subspace_indices(norb, n_alpha, (1, 3)) slice1 = indices[: len(indices) // 2] slice2 = indices[len(indices) // 2 :] - apply_givens_rotation_in_place_slow(vec_slow, c, s, phase, slice1, slice2) - apply_givens_rotation_in_place(vec_fast, c, s, phase, slice1, slice2) + apply_givens_rotation_in_place_slow(vec_slow, c, s, slice1, slice2) + apply_givens_rotation_in_place(vec_fast, c, s, slice1, slice2) np.testing.assert_allclose(vec_slow, vec_fast, atol=1e-8)