Skip to content

Commit

Permalink
better type handling for orb rot <-> params conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung committed Oct 14, 2024
1 parent df54023 commit f96d48d
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 15 deletions.
4 changes: 3 additions & 1 deletion python/ffsim/variational/uccsd.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,9 @@ def to_parameters(self) -> np.ndarray:
index += 1
# Final orbital rotation
if self.final_orbital_rotation is not None:
params[index:] = orbital_rotation_to_parameters(self.final_orbital_rotation)
params[index:] = orbital_rotation_to_parameters(
self.final_orbital_rotation, real=True
)
return params

def _apply_unitary_(
Expand Down
21 changes: 15 additions & 6 deletions python/ffsim/variational/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
import scipy.linalg


def orbital_rotation_to_parameters(orbital_rotation: np.ndarray) -> np.ndarray:
def orbital_rotation_to_parameters(
orbital_rotation: np.ndarray, real: bool = False
) -> np.ndarray:
"""Convert an orbital rotation to parameters.
Converts an orbital rotation to a real-valued parameter vector. The parameter vector
Expand All @@ -27,20 +29,27 @@ def orbital_rotation_to_parameters(orbital_rotation: np.ndarray) -> np.ndarray:
Args:
orbital_rotation: The orbital rotation.
real: Whether to construct a parameter vector for a real-valued
orbital rotation. If True, the orbital rotation must have a real-valued
data type.
Returns:
The list of real numbers parameterizing the orbital rotation.
"""
if real and np.iscomplexobj(orbital_rotation):
raise TypeError(
"real was set to True, but the orbital rotation has a complex data type. "
"Try passing an orbital rotation with a real-valued data type, or else "
"set real=False."
)
norb, _ = orbital_rotation.shape
triu_indices_no_diag = list(itertools.combinations(range(norb), 2))
mat = scipy.linalg.logm(orbital_rotation)
params = np.zeros(
norb**2 if np.iscomplexobj(orbital_rotation) else norb * (norb - 1) // 2
)
params = np.zeros(norb * (norb - 1) // 2 if real else norb**2)
# real part
params[: len(triu_indices_no_diag)] = mat[tuple(zip(*triu_indices_no_diag))].real
# imaginary part
if np.iscomplexobj(orbital_rotation):
if not real:
triu_indices = list(itertools.combinations_with_replacement(range(norb), 2))
params[len(triu_indices_no_diag) :] = mat[tuple(zip(*triu_indices))].imag
return params
Expand All @@ -59,7 +68,7 @@ def orbital_rotation_from_parameters(
params: The real-valued parameters.
norb: The number of spatial orbitals, which gives the width and height of the
orbital rotation matrix.
real: Whether to construct a real-valued orbital rotation
real: Whether the parameter vector describes a real-valued orbital rotation.
Returns:
The orbital rotation.
Expand Down
12 changes: 4 additions & 8 deletions tests/python/variational/ucj_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,9 @@ def test_n_params():
):
diag_coulomb_mats_alpha_alpha = np.zeros((n_reps, norb, norb))
diag_coulomb_mats_alpha_beta = np.zeros((n_reps, norb, norb))
orbital_rotations = np.stack(
[np.eye(norb, dtype=complex) for _ in range(n_reps)]
)
orbital_rotations = np.stack([np.eye(norb) for _ in range(n_reps)])

final_orbital_rotation = np.eye(norb, dtype=complex)
final_orbital_rotation = np.eye(norb)
operator = ffsim.UCJOperator(
diag_coulomb_mats_alpha_alpha=diag_coulomb_mats_alpha_alpha,
diag_coulomb_mats_alpha_beta=diag_coulomb_mats_alpha_beta,
Expand Down Expand Up @@ -256,11 +254,9 @@ def test_real_ucj_n_params():
):
diag_coulomb_mats_alpha_alpha = np.zeros((n_reps, norb, norb))
diag_coulomb_mats_alpha_beta = np.zeros((n_reps, norb, norb))
orbital_rotations = np.stack(
[np.eye(norb, dtype=complex) for _ in range(n_reps)]
)
orbital_rotations = np.stack([np.eye(norb) for _ in range(n_reps)])

final_orbital_rotation = np.eye(norb, dtype=complex)
final_orbital_rotation = np.eye(norb)
operator = ffsim.RealUCJOperator(
diag_coulomb_mats_alpha_alpha=diag_coulomb_mats_alpha_alpha,
diag_coulomb_mats_alpha_beta=diag_coulomb_mats_alpha_beta,
Expand Down

0 comments on commit f96d48d

Please sign in to comment.