Skip to content

Commit

Permalink
add double factorized t2 tol and max_vecs test
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung committed Jun 2, 2024
1 parent e24927c commit af595fe
Showing 1 changed file with 49 additions and 0 deletions.
49 changes: 49 additions & 0 deletions tests/python/linalg/double_factorized_decomposition_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -396,3 +396,52 @@ def test_double_factorized_t2_alpha_beta_random():
orbital_rotations[:, 0, 0], orbital_rotations[:, 3, 0].conj(), atol=1e-8
)
# TODO add the rest of the relations


def test_double_factorized_t2_alpha_beta_tol_max_vecs():
"""Test double-factorized decomposition alpha-beta error threshold and max vecs."""
mol = gto.Mole()
mol.build(
atom=[["H", (0, 0, 0)], ["O", (0, 0, 1.1)]],
basis="6-31g",
spin=1,
symmetry="Coov",
)
hartree_fock = scf.ROHF(mol).run()

ccsd = cc.CCSD(hartree_fock).run()
_, t2ab, _ = ccsd.t2
nocc_a, nocc_b, nvrt_a, _ = t2ab.shape
norb = nocc_a + nvrt_a

# test max_vecs
max_vecs = 25
diag_coulomb_mats, orbital_rotations = ffsim.linalg.double_factorized_t2_alpha_beta(
t2ab, max_vecs=max_vecs
)
reconstructed = reconstruct_t2_alpha_beta(
diag_coulomb_mats, orbital_rotations, norb=norb, nocc_a=nocc_a, nocc_b=nocc_b
)
assert len(diag_coulomb_mats) == max_vecs
np.testing.assert_allclose(reconstructed, t2ab, atol=1e-4)

# test error threshold
tol = 1e-3
diag_coulomb_mats, orbital_rotations = ffsim.linalg.double_factorized_t2_alpha_beta(
t2ab, tol=tol
)
reconstructed = reconstruct_t2_alpha_beta(
diag_coulomb_mats, orbital_rotations, norb=norb, nocc_a=nocc_a, nocc_b=nocc_b
)
assert len(diag_coulomb_mats) <= 23
np.testing.assert_allclose(reconstructed, t2ab, atol=tol)

# test error threshold and max vecs
diag_coulomb_mats, orbital_rotations = ffsim.linalg.double_factorized_t2_alpha_beta(
t2ab, tol=tol, max_vecs=max_vecs
)
reconstructed = reconstruct_t2_alpha_beta(
diag_coulomb_mats, orbital_rotations, norb=norb, nocc_a=nocc_a, nocc_b=nocc_b
)
assert len(orbital_rotations) <= 23
np.testing.assert_allclose(reconstructed, t2ab, atol=tol)

0 comments on commit af595fe

Please sign in to comment.