Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adapt autofd #34

Merged
merged 24 commits into from
Dec 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
python-version: ['3.9', '3.10', '3.11']
container:
image: ghcr.io/sail-sg/jax-xc-image:latest
steps:
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ jobs:
- uses: actions/checkout@v2
- name: Test
run: |
eval "$(pyenv init -)" && pyenv global 3.8-dev
pip install -r requirements.txt
eval "$(pyenv init -)" && pyenv global 3.11-dev
pip install --upgrade -r requirements.txt
bazel test --test_output=all --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 //tests/...
14 changes: 7 additions & 7 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,12 @@ We support automatic functional derivative!

import jax
import jax_xc
import autofd
from autofd.general_array import general_shape
import autofd.operators as o
from autofd import function
import jax.numpy as jnp
from jaxtyping import Array, Float32

@function
def rho(r: Float32[Array, "3"]) -> Float32[Array, ""]:
"""Electron number density. We take gaussian as an example.

Expand All @@ -214,21 +215,20 @@ We support automatic functional derivative!
return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1))

# create a density functional
gga_xc_pbe = jax_xc.experimental.gga_x_pbe(polarized=False)
epsilon_xc = jax_xc.experimental.gga_x_pbe(rho)

# a grid point in 3D
r = jnp.array([0.1, 0.2, 0.3])

# pass rho and r to the functional to compute epsilon_xc (energy density) at r.
# corresponding to the 'zk' in libxc
epsilon_xc = gga_xc_pbe(rho)
print(f"The function signature of epsilon_xc is {general_shape(epsilon_xc)}")
print(f"The function signature of epsilon_xc is {epsilon_xc}")

energy_density = epsilon_xc(r)
print(f"epsilon_xc(r) = {energy_density}")

vxc = jax.grad(lambda rho: autofd.operators.integrate(gga_xc_pbe(rho)))(rho)
print(f"The function signature of vxc is {general_shape(vxc)}")
vxc = jax.grad(lambda rho: o.integrate(rho * gga_xc_pbe(rho)))(rho)
print(f"The function signature of vxc is {vxc}")
print(vxc(r))


Expand Down
2 changes: 1 addition & 1 deletion maple2jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@
from .functionals import * # noqa
from . import experimental # noqa

__version__ = "0.0.8"
__version__ = "0.0.9"
142 changes: 121 additions & 21 deletions maple2jax/experimental.jinja
Original file line number Diff line number Diff line change
@@ -1,43 +1,140 @@
import jax
from jax import lax
import jax.numpy as jnp
import ctypes
from collections import namedtuple
from typing import Callable, Optional
from functools import partial
import autofd.operators as o

from . import impl
from .impl.utils import energy_functional
from .utils import get_p

def get_functional(p, deo_functional=None):

def _filter_small_density(p, code, r, *args):
dens = r if p.nspin == 1 else r.sum()
ret = lax.cond(
(dens < p.dens_threshold), lambda *_: 0., lambda *_: code(p, r, *args),
None
)
return ret


def make_epsilon_xc(
p, rho: Callable, mo: Optional[Callable] = None, deorbitalize=None
):
# if they are deorbitalize or hybrid functionals
if p.maple_name == "DEORBITALIZE":
p0, p1 = (p.func_aux[0], p.func_aux[1])
epsilon_xc_p1 = get_functional(p1)
epsilon_xc_p0 = get_functional(p0, epsilon_xc_p1)
return epsilon_xc_p0
deorbitalize = partial(make_epsilon_xc, p1)
return make_epsilon_xc(p0, rho, mo, deorbitalize=deorbitalize)
elif p.maple_name == "":
def epsilon_xc(rho, mo=None):
fnals = [
get_functional(fn_p)(rho, mo)
for fn_p in p.func_aux
]
epsilon_xc = sum(coef * f for coef, f in zip(p.mix_coef, fnals))

def mix(*args):
return sum(coef * a for a, coef in zip(args, p.mix_coef))

epsilon_xc = o.compose(mix, *[make_epsilon_xc(fn_p, rho, mo)
for fn_p in p.func_aux], share_inputs=True)
epsilon_xc.cam_alpha = p.cam_alpha
epsilon_xc.cam_beta = p.cam_beta
epsilon_xc.cam_omega = p.cam_omega
epsilon_xc.nlc_b = p.nlc_b
epsilon_xc.nlc_C = p.nlc_C
return epsilon_xc
else:
if p.nspin == 1:
code = getattr(impl, p.maple_name).unpol
elif p.nspin == 2:
code = getattr(impl, p.maple_name).pol
return energy_functional(p, code, deo_functional)

# otherwise, it is a single functional
if p.nspin == 1:
code = getattr(impl, p.maple_name).unpol
elif p.nspin == 2:
code = getattr(impl, p.maple_name).pol

code = partial(_filter_small_density, p, code)

# construct first order derivative of rho for gga
nabla_rho = o.nabla(rho, method=jax.jacrev)

def compute_s(jac):
if jac.shape == (2, 3):
return jnp.stack([jac[0] @ jac[0], jac[0] @ jac[1], jac[1] @ jac[1]])
elif jac.shape == (3,):
return jac @ jac

# construct second order derivative of rho for mgga
hess_rho = o.nabla(nabla_rho, method=jax.jacfwd)

def compute_l(hess_rho):
return jnp.trace(hess_rho, axis1=-2, axis2=-1)

# create the epsilon_xc function
if p.type == "lda":
return o.compose(code, rho)
elif p.type == "gga":
return o.compose(
code, rho, o.compose(compute_s, nabla_rho), share_inputs=True
)
elif p.type == "mgga":
nabla_mo = o.nabla(mo, method=jax.jacfwd)

def compute_tau(mo_jac):
tau = jnp.sum(jnp.real(jnp.conj(mo_jac) * mo_jac), axis=[-1, -2]) / 2
return tau

if deorbitalize is None:
tau_fn = o.compose(compute_tau, nabla_mo)
else:
tau_fn = rho * deorbitalize(rho, mo)
return o.compose(
code,
rho,
o.compose(compute_s, nabla_rho),
o.compose(compute_l, hess_rho),
tau_fn,
share_inputs=True
)


def is_polarized(rho):
try:
out = jax.eval_shape(rho, jax.ShapeDtypeStruct((3,), jnp.float32))
except:
out = jax.eval_shape(rho, jax.ShapeDtypeStruct((3,), jnp.float64))
if out.shape != (2,) and out.shape != ():
raise ValueError(
f"rho must return an array of shape (2,) or (), got {out.shape}"
)
return (out.shape == (2,))


def check_mo_shape(mo, polarized):
try:
out = jax.eval_shape(mo, jax.ShapeDtypeStruct((3,), jnp.float32))
except:
out = jax.eval_shape(mo, jax.ShapeDtypeStruct((3,), jnp.float64))
if polarized:
if len(out.shape) != 2 or out.shape[0] != 2:
raise ValueError(
"Return value of rho has shape (2,), which means it is polarized. "
"Therefore mo must return an array of shape (2, number_of_orbital), "
f"got {out.shape}"
)
else:
if len(out.shape) != 1:
raise ValueError(
"Return value of rho has shape (), which means it is unpolarized. "
"Therefore mo must return an array of shape (number_of_orbital,), "
f"got {out.shape}"
)


{% for p, ext_params, ext_params_descriptions, info in functionals %}
def {{ p.name }}(
polarized: bool = True,
rho: Callable,
{% if p.type == "mgga" %}
mo: Callable,
{% endif %}
{% if ext_params|length > 0 %}
*,
{% endif %}
{% for param_name in ext_params.keys() %}
{{ param_name }}: Optional[float] = None,
{% endfor %}
Expand All @@ -61,17 +158,20 @@ def {{ p.name }}(
{% endif %}
Parameters
----------
polarized : bool
Whether the calculation is polarized.
rho: the density function
{% for (param_name, param_val), param_descrip in zip(ext_params.items(), ext_params_descriptions) %}
{{ param_name }} : Optional[float], default: {{ param_val }}
{{ param_descrip }}
{% endfor %}
"""
polarized = is_polarized(rho)
{% if p.type == "mgga" %}
check_mo_shape(mo, polarized)
{% endif %}
{% for param_name, value in ext_params.items() %}
{{ param_name }} = ({{ param_name }} or {{ value }})
{% endfor %}
p = get_p("{{ p.name }}", polarized, {{ ext_params.keys()|join(', ') }})
return get_functional(p)
return make_epsilon_xc(p, rho{% if p.type == "mgga" %}, mo{% endif %})

{% endfor %}
Loading
Loading