Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

factor out orbital rotation from t1 amplitudes function #343

Merged
merged 3 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions python/ffsim/variational/ucj_spin_balanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from typing import cast

import numpy as np
import scipy.linalg

from ffsim import gates, linalg
from ffsim.variational.util import (
orbital_rotation_from_parameters,
orbital_rotation_from_t1_amplitudes,
orbital_rotation_to_parameters,
validate_interaction_pairs,
)
Expand Down Expand Up @@ -451,10 +451,7 @@ def from_t_amplitudes(

final_orbital_rotation = None
if t1 is not None:
final_orbital_rotation_generator = np.zeros((norb, norb), dtype=complex)
final_orbital_rotation_generator[:nocc, nocc:] = t1
final_orbital_rotation_generator[nocc:, :nocc] = -t1.T
final_orbital_rotation = scipy.linalg.expm(final_orbital_rotation_generator)
final_orbital_rotation = orbital_rotation_from_t1_amplitudes(t1)

# Zero out diagonal coulomb matrix entries if requested
if pairs_aa is not None:
Expand Down
18 changes: 3 additions & 15 deletions python/ffsim/variational/ucj_spin_unbalanced.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from typing import cast

import numpy as np
import scipy.linalg

from ffsim import gates, linalg
from ffsim.variational.util import (
orbital_rotation_from_parameters,
orbital_rotation_from_t1_amplitudes,
orbital_rotation_to_parameters,
validate_interaction_pairs,
)
Expand Down Expand Up @@ -579,20 +579,8 @@ def from_t_amplitudes(
final_orbital_rotation = None
if t1 is not None:
t1a, t1b = t1

final_orbital_rotation_generator_a = np.zeros((norb, norb), dtype=complex)
final_orbital_rotation_generator_a[:nocc_a, nocc_a:] = t1a
final_orbital_rotation_generator_a[nocc_a:, :nocc_a] = -t1a.T
final_orbital_rotation_a = scipy.linalg.expm(
final_orbital_rotation_generator_a
)

final_orbital_rotation_generator_b = np.zeros((norb, norb), dtype=complex)
final_orbital_rotation_generator_b[:nocc_b, nocc_b:] = t1b
final_orbital_rotation_generator_b[nocc_b:, :nocc_b] = -t1b.T
final_orbital_rotation_b = scipy.linalg.expm(
final_orbital_rotation_generator_b
)
final_orbital_rotation_a = orbital_rotation_from_t1_amplitudes(t1a)
final_orbital_rotation_b = orbital_rotation_from_t1_amplitudes(t1b)
final_orbital_rotation = np.stack(
[final_orbital_rotation_a, final_orbital_rotation_b]
)
Expand Down
7 changes: 2 additions & 5 deletions python/ffsim/variational/ucj_spinless.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@
from typing import cast

import numpy as np
import scipy.linalg

from ffsim import gates, linalg
from ffsim.variational.util import (
orbital_rotation_from_parameters,
orbital_rotation_from_t1_amplitudes,
orbital_rotation_to_parameters,
validate_interaction_pairs,
)
Expand Down Expand Up @@ -385,10 +385,7 @@ def from_t_amplitudes(

final_orbital_rotation = None
if t1 is not None:
final_orbital_rotation_generator = np.zeros((norb, norb), dtype=complex)
final_orbital_rotation_generator[:nocc, nocc:] = t1
final_orbital_rotation_generator[nocc:, :nocc] = -t1.T
final_orbital_rotation = scipy.linalg.expm(final_orbital_rotation_generator)
final_orbital_rotation = orbital_rotation_from_t1_amplitudes(t1)

# Zero out diagonal coulomb matrix entries if requested
if interaction_pairs is not None:
Expand Down
19 changes: 19 additions & 0 deletions python/ffsim/variational/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,22 @@ def orbital_rotation_from_parameters(
generator[rows, cols] += vals
generator[cols, rows] -= vals
return scipy.linalg.expm(generator)


def orbital_rotation_from_t1_amplitudes(t1: np.ndarray) -> np.ndarray:
"""Construct an orbital rotation from t1 amplitudes.

The orbital rotation is constructed as exp(t1 - t1†).

Args:
t1: The t1 amplitudes.

Returns:
The orbital rotation.
"""
nocc, nvrt = t1.shape
norb = nocc + nvrt
generator = np.zeros((norb, norb), dtype=t1.dtype)
generator[:nocc, nocc:] = t1
generator[nocc:, :nocc] = -t1.T.conj()
return scipy.linalg.expm(generator)
Loading