Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung committed Dec 17, 2023
1 parent f3047cd commit 3c92d0e
Showing 1 changed file with 24 additions and 44 deletions.
68 changes: 24 additions & 44 deletions python/ffsim/states/rdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,99 +93,80 @@ def rdm(
link_index_a = gen_linkstr_index(range(norb), n_alpha)
link_index_b = gen_linkstr_index(range(norb), n_beta)
link_index = (link_index_a, link_index_b)
vec_real = np.real(vec)
vec_imag = np.imag(vec)
if rank == 1:
if spin_summed:
return _rdm1_spin_summed(vec_real, vec_imag, norb, nelec, link_index)
return _rdm1_spin_summed(vec, norb, nelec, link_index)
else:
return _rdm1(vec_real, vec_imag, norb, nelec, link_index)
return _rdm1(vec, norb, nelec, link_index)
if rank == 2:
if spin_summed:
return _rdm2_spin_summed(
vec_real,
vec_imag,
norb,
nelec,
reordered,
link_index,
return_lower_ranks,
vec, norb, nelec, reordered, link_index, return_lower_ranks
)
else:
return _rdm2(
vec_real,
vec_imag,
norb,
nelec,
reordered,
link_index,
return_lower_ranks,
)
return _rdm2(vec, norb, nelec, reordered, link_index, return_lower_ranks)
raise NotImplementedError(
f"Computing the rank {rank} reduced density matrix is currently not supported."
)


def _rdm1_spin_summed(
vec_real: np.ndarray,
vec_imag: np.ndarray,
vec: np.ndarray,
norb: int,
nelec: tuple[int, int],
link_index: tuple[np.ndarray, np.ndarray] | None,
) -> np.ndarray:
rdm1_real = make_rdm1(vec_real, norb, nelec, link_index=link_index)
rdm1_imag = make_rdm1(vec_imag, norb, nelec, link_index=link_index)
rdm1_real = make_rdm1(vec.real, norb, nelec, link_index=link_index)
rdm1_imag = make_rdm1(vec.imag, norb, nelec, link_index=link_index)
trans_rdm1_real_imag = trans_rdm1(
vec_real, vec_imag, norb, nelec, link_index=link_index
vec.real, vec.imag, norb, nelec, link_index=link_index
)
trans_rdm1_imag_real = trans_rdm1(
vec_imag, vec_real, norb, nelec, link_index=link_index
vec.imag, vec.real, norb, nelec, link_index=link_index
)
return _assemble_rdm1_spin_summed(
rdm1_real, rdm1_imag, trans_rdm1_real_imag, trans_rdm1_imag_real
)


def _rdm1(
vec_real: np.ndarray,
vec_imag: np.ndarray,
vec: np.ndarray,
norb: int,
nelec: tuple[int, int],
link_index: tuple[np.ndarray, np.ndarray] | None,
) -> np.ndarray:
rdms1_real = make_rdm1s(vec_real, norb, nelec, link_index=link_index)
rdms1_imag = make_rdm1s(vec_imag, norb, nelec, link_index=link_index)
rdms1_real = make_rdm1s(vec.real, norb, nelec, link_index=link_index)
rdms1_imag = make_rdm1s(vec.imag, norb, nelec, link_index=link_index)
trans_rdms1_real_imag = trans_rdm1s(
vec_real, vec_imag, norb, nelec, link_index=link_index
vec.real, vec.imag, norb, nelec, link_index=link_index
)
trans_rdms1_imag_real = trans_rdm1s(
vec_imag, vec_real, norb, nelec, link_index=link_index
vec.imag, vec.real, norb, nelec, link_index=link_index
)
return _assemble_rdm1(
rdms1_real, rdms1_imag, trans_rdms1_real_imag, trans_rdms1_imag_real
)


def _rdm2_spin_summed(
vec_real: np.ndarray,
vec_imag: np.ndarray,
vec: np.ndarray,
norb: int,
nelec: tuple[int, int],
reordered: bool,
link_index: tuple[np.ndarray, np.ndarray] | None,
return_lower_ranks: bool,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
rdm1_real, rdm2_real = make_rdm12(
vec_real, norb, nelec, reorder=reordered, link_index=link_index
vec.real, norb, nelec, reorder=reordered, link_index=link_index
)
rdm1_imag, rdm2_imag = make_rdm12(
vec_imag, norb, nelec, reorder=reordered, link_index=link_index
vec.imag, norb, nelec, reorder=reordered, link_index=link_index
)
trans_rdm1_real_imag, trans_rdm2_real_imag = trans_rdm12(
vec_real, vec_imag, norb, nelec, reorder=reordered, link_index=link_index
vec.real, vec.imag, norb, nelec, reorder=reordered, link_index=link_index
)
trans_rdm1_imag_real, trans_rdm2_imag_real = trans_rdm12(
vec_imag, vec_real, norb, nelec, reorder=reordered, link_index=link_index
vec.imag, vec.real, norb, nelec, reorder=reordered, link_index=link_index
)
rdm2 = _assemble_rdm2_spin_summed(
rdm2_real, rdm2_imag, trans_rdm2_real_imag, trans_rdm2_imag_real
Expand All @@ -199,25 +180,24 @@ def _rdm2_spin_summed(


def _rdm2(
vec_real: np.ndarray,
vec_imag: np.ndarray,
vec: np.ndarray,
norb: int,
nelec: tuple[int, int],
reordered: bool,
link_index: tuple[np.ndarray, np.ndarray] | None,
return_lower_ranks: bool,
) -> np.ndarray | tuple[np.ndarray, np.ndarray]:
rdms1_real, rdms2_real = make_rdm12s(
vec_real, norb, nelec, reorder=reordered, link_index=link_index
vec.real, norb, nelec, reorder=reordered, link_index=link_index
)
rdms1_imag, rdms2_imag = make_rdm12s(
vec_imag, norb, nelec, reorder=reordered, link_index=link_index
vec.imag, norb, nelec, reorder=reordered, link_index=link_index
)
trans_rdms1_real_imag, trans_rdms2_real_imag = trans_rdm12s(
vec_real, vec_imag, norb, nelec, reorder=reordered, link_index=link_index
vec.real, vec.imag, norb, nelec, reorder=reordered, link_index=link_index
)
trans_rdms1_imag_real, trans_rdms2_imag_real = trans_rdm12s(
vec_imag, vec_real, norb, nelec, reorder=reordered, link_index=link_index
vec.imag, vec.real, norb, nelec, reorder=reordered, link_index=link_index
)
rdm2 = _assemble_rdm2(
rdms2_real, rdms2_imag, trans_rdms2_real_imag, trans_rdms2_imag_real
Expand Down

0 comments on commit 3c92d0e

Please sign in to comment.