Skip to content

Commit

Permalink
simplify t2 amplitudes tensor reshaping
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung committed Nov 7, 2024
1 parent 55501f9 commit bc7f63c
Showing 1 changed file with 4 additions and 15 deletions.
19 changes: 4 additions & 15 deletions python/ffsim/linalg/double_factorized_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -532,11 +532,7 @@ def double_factorized_t2(
nocc, _, nvrt, _ = t2_amplitudes.shape
norb = nocc + nvrt

occ, vrt = np.meshgrid(range(nocc), range(nvrt), indexing="ij")
occ = occ.reshape(-1)
vrt = vrt.reshape(-1)
t2_mat = t2_amplitudes[occ[:, None], occ[None, :], vrt[:, None], vrt[None, :]]

t2_mat = t2_amplitudes.transpose(0, 2, 1, 3).reshape(nocc * nvrt, nocc * nvrt)
outer_eigs, outer_vecs = _truncated_eigh(t2_mat, tol=tol, max_vecs=max_vecs)
n_vecs = len(outer_eigs)

Expand Down Expand Up @@ -643,16 +639,9 @@ def reconstruct_t2_alpha_beta(
nocc_a, nocc_b, nvrt_a, nvrt_b = t2_amplitudes.shape
norb = nocc_a + nvrt_a

occ_a, vrt_a = np.meshgrid(range(nocc_a), range(nvrt_a), indexing="ij")
occ_b, vrt_b = np.meshgrid(range(nocc_b), range(nvrt_b), indexing="ij")
occ_a = occ_a.reshape(-1)
vrt_a = vrt_a.reshape(-1)
occ_b = occ_b.reshape(-1)
vrt_b = vrt_b.reshape(-1)
t2_mat = t2_amplitudes[
occ_a[:, None], occ_b[None, :], vrt_a[:, None], vrt_b[None, :]
]

t2_mat = t2_amplitudes.transpose(0, 2, 1, 3).reshape(
nocc_a * nvrt_a, nocc_b * nvrt_b
)
left_vecs, singular_vals, right_vecs = _truncated_svd(
t2_mat, tol=tol, max_vecs=max_vecs
)
Expand Down

0 comments on commit bc7f63c

Please sign in to comment.