Skip to content

Commit

Permalink
factor out apply_term function from FermionOperator linop
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinsung committed Apr 2, 2024
1 parent 215be14 commit ebe8418
Showing 1 changed file with 40 additions and 31 deletions.
71 changes: 40 additions & 31 deletions python/ffsim/protocols/linear_operator_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
import pyscf.fci
from scipy.sparse.linalg import LinearOperator

from ffsim import states
from ffsim._lib import FermionOperator
from ffsim.states import dim, dims


class SupportsLinearOperator(Protocol):
Expand Down Expand Up @@ -69,41 +69,50 @@ def _fermion_operator_to_linear_operator(
f"Conserves spin Z: {operator.conserves_spin_z()}"
)

dim_ = dim(norb, nelec)
dims_ = dims(norb, nelec)
dim = states.dim(norb, nelec)

def matvec(vec: np.ndarray):
result = np.zeros(dim, dtype=complex)
for term, coeff in operator.items():
result += coeff * _apply_term(vec, term, norb, nelec)
return result

return LinearOperator(
shape=(dim, dim), matvec=matvec, rmatvec=matvec, dtype=complex
)


def _apply_term(
vec: np.ndarray,
term: tuple[tuple[bool, bool, int], ...],
norb: int,
nelec: tuple[int, int],
) -> np.ndarray:
result = _apply_term_real(vec.real, term, norb, nelec)
result += 1j * _apply_term_real(vec.imag, term, norb, nelec)
return result


def _apply_term_real(
vec: np.ndarray,
term: tuple[tuple[bool, bool, int], ...],
norb: int,
nelec: tuple[int, int],
) -> np.ndarray:
action_funcs = {
# key: (action, spin)
(False, False): pyscf.fci.addons.des_a,
(False, True): pyscf.fci.addons.des_b,
(True, False): pyscf.fci.addons.cre_a,
(True, True): pyscf.fci.addons.cre_b,
}

def matvec(vec: np.ndarray):
result = np.zeros(dim_, dtype=complex)
vec_real = np.real(vec)
vec_imag = np.imag(vec)
for term in operator:
coeff = operator[term]
transformed_real = vec_real.reshape(dims_)
transformed_imag = vec_imag.reshape(dims_)
this_nelec = list(nelec)
zero = False
for action, spin, orb in reversed(term):
action_func = action_funcs[(action, spin)]
transformed_real = action_func(transformed_real, norb, this_nelec, orb)
transformed_imag = action_func(transformed_imag, norb, this_nelec, orb)
this_nelec[spin] += 1 if action else -1
if this_nelec[spin] < 0 or this_nelec[spin] > norb:
zero = True
break
if zero:
continue
result += coeff * transformed_real.reshape(-1)
result += coeff * 1j * transformed_imag.reshape(-1)
return result

return LinearOperator(
shape=(dim_, dim_), matvec=matvec, rmatvec=matvec, dtype=complex
)
(dim_a, dim_b) = states.dims(norb, nelec)
transformed = vec.reshape((dim_a, dim_b))
this_nelec = list(nelec)
for action, spin, orb in reversed(term):
action_func = action_funcs[(action, spin)]
transformed = action_func(transformed, norb, this_nelec, orb)
this_nelec[spin] += 1 if action else -1
if this_nelec[spin] < 0 or this_nelec[spin] > norb:
return np.zeros(dim_a * dim_b, dtype=complex)
return transformed.reshape(-1).astype(complex, copy=False)

0 comments on commit ebe8418

Please sign in to comment.