From 82ae24b47844b200bed128d9bcde2440271533b9 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 18 Aug 2023 18:05:18 +0800 Subject: [PATCH 01/23] fix some minor issues --- maple2jax/impl/python_template.jinja | 2 +- maple2jax/impl/utils.py | 19 ++++++------------- maple2jax/python_template.jinja | 6 +++--- 3 files changed, 10 insertions(+), 17 deletions(-) diff --git a/maple2jax/impl/python_template.jinja b/maple2jax/impl/python_template.jinja index a169db6..3568d2f 100644 --- a/maple2jax/impl/python_template.jinja +++ b/maple2jax/impl/python_template.jinja @@ -10,7 +10,7 @@ from typing import Callable, Optional from .utils import * -def pol(p, r, s=(None, None, None), l=(None, None), tau=(None, None)): +def pol(p, r, s=None, l=None, tau=None): params = p.params (r0, r1), (s0, s1, s2), (l0, l1), (tau0, tau1) = r, s, l, tau {{ pol_code | indent(2) }} diff --git a/maple2jax/impl/utils.py b/maple2jax/impl/utils.py index c053a78..27fd412 100644 --- a/maple2jax/impl/utils.py +++ b/maple2jax/impl/utils.py @@ -67,8 +67,7 @@ def rho_to_arguments( if p.type == "mgga": if mo is None: raise ValueError( - "Molecular orbital function are required for mgga functionals, " - f"got {p.type}" + "Molecular orbital function are required for mgga functionals" ) # compute density @@ -83,14 +82,12 @@ def rho_to_arguments( if (p.nspin == 1 and polarized) or (p.nspin == 2 and not polarized): raise ValueError( - f"The functional is initialized as polarized={p.nspin==2}, " - f"while for the density function polarized={polarized}" + f"The functional is initialized with nspin={p.nspin}, " + f"while the density function returns array of shape {density.shape}." ) - dens = (density[0], density[1]) if polarized else density - if p.type == "lda": - return (dens,) + return (density,) # compute s jac, hvp = jax.linearize(jax.jacrev(rho), r) @@ -103,14 +100,12 @@ def rho_to_arguments( s = jnp.dot(jac, jac) if p.type == "gga": - return (dens, s) + return (density, s) # compute l # normally, r is a 3d vector for a coordinate in real space. eye = jnp.eye(r.shape[-1]) ll = sum([hvp(eye[i])[..., i] for i in range(r.shape[-1])]) - if polarized: - ll = (ll[0], ll[1]) # compute tau mo_jac = jax.jacobian(mo)(r) @@ -129,6 +124,4 @@ def rho_to_arguments( tau = jnp.sum(mo_jac**2, axis=[-1, -2]) / 2 if deorbitalize is not None: tau = density * deorbitalize - if polarized: - tau = (tau[0], tau[1]) - return (dens, s, ll, tau) + return (density, s, ll, tau) diff --git a/maple2jax/python_template.jinja b/maple2jax/python_template.jinja index 27a822a..97d6095 100644 --- a/maple2jax/python_template.jinja +++ b/maple2jax/python_template.jinja @@ -6,7 +6,7 @@ from typing import Callable, Optional from . import impl from .utils import get_p -def invoke(p, rho, r, mo=None, deo=None): +def invoke(p, rho, r, mo=None): if p.maple_name == "DEORBITALIZE": p0, p1 = (p.func_aux[0], p.func_aux[1]) deo = invoke(p1, rho, r, mo) @@ -14,12 +14,12 @@ def invoke(p, rho, r, mo=None, deo=None): elif p.maple_name == "": return sum( [ - coeff * invoke(fn_p, rho, r, mo, deo) + coeff * invoke(fn_p, rho, r, mo) for fn_p, coeff in zip(p.func_aux, p.mix_coef) ] ) else: - return getattr(impl, p.maple_name).invoke(p, rho, r, mo, deo) + return getattr(impl, p.maple_name).invoke(p, rho, r, mo) {% for p, ext_params, ext_params_descriptions, info in functionals %} From 2bdcfa0309b0b638d8d6256b7bfffc97d32ee379 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 18 Aug 2023 19:06:21 +0800 Subject: [PATCH 02/23] There are still some errors --- README.rst | 47 ++++++++++++ maple2jax/BUILD | 1 + maple2jax/__init__.py | 1 + maple2jax/build.jinja | 23 +++++- maple2jax/experimental.jinja | 72 ++++++++++++++++++ maple2jax/impl/utils.py | 137 +++++++++++++++++++++++++++++++++++ maple2jax/maple2jax.bzl | 1 + maple2jax/wheel.BUILD | 1 + 8 files changed, 282 insertions(+), 1 deletion(-) create mode 100644 maple2jax/experimental.jinja diff --git a/README.rst b/README.rst index e0c453e..604c649 100644 --- a/README.rst +++ b/README.rst @@ -185,6 +185,53 @@ The meaning for each attribute is the same as libxc: - nlc_b: non-local correlation, b parameter - nlc_C: non-local correlation, C parameter +Experimental +------------ + +We support automatic functional derivative! + +.. code:: python + + import jax + import jax_xc + import autofd + from autofd.general_array import general_shape + import jax.numpy as jnp + from jaxtyping import Array, Float32 + + def rho(r: Float32[Array, "3"]) -> Float32[Array, ""]: + """Electron number density. We take gaussian as an example. + + A function that takes a real coordinate, and returns a scalar + indicating the number density of electron at coordinate r. + + Args: + r: a 3D coordinate. + Returns: + rho: If it is unpolarized, it is a scalar. + If it is polarized, it is a array of shape (2,). + """ + return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1)) + + # create a density functional + gga_xc_pbe = jax_xc.experimental.gga_x_pbe(polarized=False) + + # a grid point in 3D + r = jnp.array([0.1, 0.2, 0.3]) + + # pass rho and r to the functional to compute epsilon_xc (energy density) at r. + # corresponding to the 'zk' in libxc + epsilon_xc = gga_xc_pbe(rho) + print(f"The function signature of epsilon_xc is {general_shape(epsilon_xc)}") + + energy_density = epsilon_xc(r) + print(f"epsilon_xc(r) = {energy_density}") + + vxc = jax.grad(lambda rho: autofd.operators.integrate(gga_xc_pbe(rho)))(rho) + print(f"The function signature of vxc is {general_shape(vxc)}") + print(vxc(r)) + + Support Functionals ------------------- diff --git a/maple2jax/BUILD b/maple2jax/BUILD index b652a2a..910cd8a 100644 --- a/maple2jax/BUILD +++ b/maple2jax/BUILD @@ -8,6 +8,7 @@ exports_files([ "gen_py.py", "build.jinja", "python_template.jinja", + "experimental.jinja", "utils.py", "wheel.BUILD", ]) diff --git a/maple2jax/__init__.py b/maple2jax/__init__.py index 199e571..ec725a3 100644 --- a/maple2jax/__init__.py +++ b/maple2jax/__init__.py @@ -5,5 +5,6 @@ # You can obtain one at https://mozilla.org/MPL/2.0/. from .functionals import * # noqa +from . import experimental __version__ = "0.0.7" diff --git a/maple2jax/build.jinja b/maple2jax/build.jinja index 1e1993c..9ab0b49 100644 --- a/maple2jax/build.jinja +++ b/maple2jax/build.jinja @@ -24,6 +24,16 @@ genrule( ], ) +genrule( + name = "gen_experimental", + outs = ["experimental.py"], + cmd = "$(execpath :gen_py) --output $@ --template $(execpath :experimental.jinja)", + tools = [ + ":gen_py", + ":experimental.jinja", + ], +) + py_library( name = "functionals", srcs = [":gen_functionals"], @@ -35,9 +45,20 @@ py_library( visibility = ["//visibility:public"], ) +py_library( + name = "experimental", + srcs = [":gen_experimental"], + deps = [ + ":utils", + "@maple2jax//jax_xc/libxc", + "@maple2jax//jax_xc/impl", + ], + visibility = ["//visibility:public"], +) + py_library( name = "jax_xc", srcs = ["__init__.py"], - deps = [":functionals"], + deps = [":functionals", ":experimental"], visibility = ["//visibility:public"], ) diff --git a/maple2jax/experimental.jinja b/maple2jax/experimental.jinja new file mode 100644 index 0000000..b81733c --- /dev/null +++ b/maple2jax/experimental.jinja @@ -0,0 +1,72 @@ +import jax +import jax.numpy as jnp +import ctypes +from collections import namedtuple +from typing import Callable, Optional +from . import impl +from .impl.utils import energy_functional +from .utils import get_p + +def get_functional(p): + if p.nspin == 1: + code = getattr(impl, p.maple_name).unpol + elif p.nspin == 2: + code = getattr(impl, p.maple_name).pol + if p.maple_name == "DEORBITALIZE": + p0, p1 = (p.func_aux[0], p.func_aux[1]) + epsilon_xc_p1 = energy_functional(p1, code) + epsilon_xc_p0 = energy_functional(p0, code, epsilon_xc_p1) + fnal = p0 + # elif p.maple_name == "": + # def epsilon_xc(rho, mo): + # funals = [energy_functional(fn_p, code)(rho, mo) for fn_p, coeff in zip(p.func_aux, p.mix_coef)] + else: + fnal = energy_functional(p, code) + if p.maple_name == "": + fnal.cam_alpha = p.cam_alpha + fnal.cam_beta = p.cam_beta + fnal.cam_omega = p.cam_omega + fnal.nlc_b = p.nlc_b + fnal.nlc_C = p.nlc_C + return fnal + +{% for p, ext_params, ext_params_descriptions, info in functionals %} +def {{ p.name }}( + polarized: bool = True, +{% for param_name in ext_params.keys() %} + {{ param_name }}: Optional[float] = None, +{% endfor %} +) -> Callable: + r""" + {% for url, doi, ref in info %} + {{ ref }} + {% if url != "" %} + `{{ doi }} <{{ url }}>`_ + {% else %} + {{ doi }} + {% endif %} + + {% endfor %} + + {% if p.maple_name == "" %} + Mixing of the following functionals: + {% for fn_p, coeff in zip(p.func_aux, p.mix_coef) %} + {{ fn_p.name }} (coefficient: {{ coeff }}) + {% endfor %} + {% endif %} + Parameters + ---------- + polarized : bool + Whether the calculation is polarized. +{% for (param_name, param_val), param_descrip in zip(ext_params.items(), ext_params_descriptions) %} + {{ param_name }} : Optional[float], default: {{ param_val }} + {{ param_descrip }} +{% endfor %} + """ +{% for param_name, value in ext_params.items() %} + {{ param_name }} = ({{ param_name }} or {{ value }}) +{% endfor %} + p = get_p("{{ p.name }}", polarized, {{ ext_params.keys()|join(', ') }}) + return get_functional(p) + +{% endfor %} diff --git a/maple2jax/impl/utils.py b/maple2jax/impl/utils.py index 27fd412..f4a0bcd 100644 --- a/maple2jax/impl/utils.py +++ b/maple2jax/impl/utils.py @@ -17,6 +17,8 @@ import jax.numpy as jnp import tensorflow_probability as tfp from typing import Callable, Optional, NamedTuple +from jaxtyping import Array +from typing import Tuple def Heaviside(x): @@ -47,6 +49,141 @@ def lax_cond(a, b, c): return lax.cond(a, lambda _: b, lambda _: c, None) +def energy_functional(p, impl, deorbitalize=None): + import autofd.operators as o + from autofd.general_array import ( + SpecTree, + return_annotation, + _dtype_to_jaxtyping, + ) + + # filter 0 density + def _impl(r, s=None, l=None, tau=None): + dens = r if p.nspin == 1 else r.sum() + ret = lax.cond( + (dens < p.dens_threshold), lambda *_: 0., + lambda *_: impl(p, r, s, l, tau), None + ) + return ret + + # define the energy functional, that takes a rho function + # and an optional mo function. + def epsilon_xc(rho: Callable, mo: Optional[Callable] = None): + if p.type == "mgga": + if mo is None: + raise ValueError( + "Molecular orbital function are required for mgga functionals." + ) + + o_spec = SpecTree.from_ret(rho) + i_spec = SpecTree.from_args(rho)[0] + if o_spec.shape not in ((), (2,)): + raise RuntimeError( + "The density function rho passed to the functional " + "needs to return either a scalar (unpolarized), " + f"or a shape (2,) array (polarized). Got shape {o_spec.shape}." + ) + polarized = (o_spec.shape == (2,)) + if (p.nspin == 1 and polarized) or (p.nspin == 2 and not polarized): + raise ValueError( + f"The functional is initialized with nspin={p.nspin}, " + f"while the density function returns array of shape {o_spec.shape}." + ) + + # compute arguments relating to density + T = _dtype_to_jaxtyping[o_spec.dtype.name] + + # 0th order + r_fn = rho + + if p.type == "lda": + + def _energy(r: return_annotation(rho)) -> return_annotation(rho): + return _impl(r) + + return o.compose(_energy, rho) + + # 1st order + nabla_rho = o.nabla(rho) + + def compute_s( + jac: return_annotation(nabla_rho), + ) -> T[Array, ("3" if polarized else "")]: + if polarized: + return jnp.stack( + [ + jnp.dot(jac[0], jac[0]), + jnp.dot(jac[0], jac[1]), + jnp.dot(jac[1], jac[1]), + ] + ) + else: + return jnp.dot(jac, jac) + + s_fn = o.compose(compute_s, nabla_rho) + + # compute the functional + if p.type == "gga": + + def _energy(r: return_annotation(rho), + s: return_annotation(s_fn)) -> return_annotation(rho): + return _impl(r, s) + + return o.compose(_energy, rho, s_fn, share_inputs=True) + + # 2nd order + hess_rho = o.nabla(nabla_rho) + + def compute_l( + hess: return_annotation(hess_rho), + ) -> T[Array, ("2" if polarized else "")]: + return jnp.diagonal(hess, axis1=-2, axis2=-1).sum(axis=-1) + + l_fn = o.compose(compute_l, hess_rho) + + # Now deal with mo + mo_o_spec = SpecTree.from_ret(mo) + mo_i_spec = SpecTree.from_args(mo)[0] + if mo_i_spec != i_spec: + raise ValueError("mo must take the same argument as rho.") + if mo_o_spec.shape != (*(2,) * polarized, mo_o_spec.shape[1]): + raise ValueError( + "mo must return (2, N) if polarized, or (N,) if not. " + f"Got {mo_o_spec.shape} while polarized={polarized}." + ) + nabla_mo = o.nabla(mo) + + def compute_tau( + mo_jac: return_annotation(nabla_mo), + ) -> return_annotation(rho): + tau = jnp.sum(mo_jac**2, axis=[-1, -2]) / 2 + return tau + + def compute_tau_deorbitalize( + density: return_annotation(rho), + deo: return_annotation(rho), + ) -> return_annotation(rho): + return density * deo + + if deorbitalize is None: + tau_fn = o.compose(compute_tau, nabla_mo) + else: + tau_fn = o.compose( + compute_tau_deorbitalize, rho, deorbitalize(rho, mo), share_inputs=True + ) + + # compute the functional + def _energy( + r: return_annotation(rho), s: return_annotation(s_fn), + l: return_annotation(l_fn), tau: return_annotation(tau_fn) + ) -> return_annotation(rho): + return _impl(r, s, l, tau) + + return o.compose(_energy, rho, s_fn, l_fn, tau_fn, share_inputs=True) + + return epsilon_xc + + def rho_to_arguments( p: NamedTuple, rho: Callable, diff --git a/maple2jax/maple2jax.bzl b/maple2jax/maple2jax.bzl index 2bab3e0..5a0be0b 100644 --- a/maple2jax/maple2jax.bzl +++ b/maple2jax/maple2jax.bzl @@ -78,6 +78,7 @@ def _impl(rctx): rctx.symlink(Label("//maple2jax:gen_py.py"), "jax_xc/gen_py.py") rctx.symlink(Label("//maple2jax:utils.py"), "jax_xc/utils.py") rctx.symlink(Label("//maple2jax:python_template.jinja"), "jax_xc/python_template.jinja") + rctx.symlink(Label("//maple2jax:experimental.jinja"), "jax_xc/experimental.jinja") rctx.symlink(Label("//maple2jax:wheel.BUILD"), "BUILD") maple2jax_repo = repository_rule(_impl, environ = ["GITHUB_ACTIONS"]) diff --git a/maple2jax/wheel.BUILD b/maple2jax/wheel.BUILD index 7a2c925..ee9f6f1 100644 --- a/maple2jax/wheel.BUILD +++ b/maple2jax/wheel.BUILD @@ -23,6 +23,7 @@ py_wheel( version = "0.0.7", deps = [ "@maple2jax//jax_xc", + "@maple2jax//jax_xc:experimental", "@maple2jax//jax_xc:functionals", "@maple2jax//jax_xc:utils", "@maple2jax//jax_xc/impl", From ad3c2daf7c1b0129d63bba92745eabaa769539d9 Mon Sep 17 00:00:00 2001 From: linmin Date: Mon, 21 Aug 2023 16:35:56 +0800 Subject: [PATCH 03/23] update benchmark scrip --- maple2jax/impl/python_template.jinja | 2 +- maple2jax/libxc/build.jinja | 1 + scripts/speed_benchmark.py | 14 +++++++------- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/maple2jax/impl/python_template.jinja b/maple2jax/impl/python_template.jinja index 3568d2f..a169db6 100644 --- a/maple2jax/impl/python_template.jinja +++ b/maple2jax/impl/python_template.jinja @@ -10,7 +10,7 @@ from typing import Callable, Optional from .utils import * -def pol(p, r, s=None, l=None, tau=None): +def pol(p, r, s=(None, None, None), l=(None, None), tau=(None, None)): params = p.params (r0, r1), (s0, s1, s2), (l0, l1), (tau0, tau1) = r, s, l, tau {{ pol_code | indent(2) }} diff --git a/maple2jax/libxc/build.jinja b/maple2jax/libxc/build.jinja index 8711968..f76c943 100644 --- a/maple2jax/libxc/build.jinja +++ b/maple2jax/libxc/build.jinja @@ -59,6 +59,7 @@ cc_binary( "src", ], local_defines = [ + "XC_DONT_COMPILE_VXC", "XC_DONT_COMPILE_FXC", "XC_DONT_COMPILE_KXC", "XC_DONT_COMPILE_LXC", diff --git a/scripts/speed_benchmark.py b/scripts/speed_benchmark.py index acfc7c7..5ffdb43 100644 --- a/scripts/speed_benchmark.py +++ b/scripts/speed_benchmark.py @@ -56,7 +56,7 @@ def get_impl_fn_and_inputs(inputs, impl, fn_type, p, polarized): rho0, rho1, sigma0, sigma1, sigma2, lapl0, lapl1, tau0, tau1 = inputs - impl_fn = jax.vmap(Partial(impl, params=p.params, p=p)) + impl_fn = jax.vmap(Partial(impl, p)) if polarized: libxc_input_args = { @@ -66,21 +66,21 @@ def get_impl_fn_and_inputs(inputs, impl, fn_type, p, polarized): "tau": jnp.stack([tau0, tau1], -1), } if fn_type == "lda": - fn_input_args = (rho0, rho1) + fn_input_args = ((rho0, rho1),) elif fn_type == "gga": - fn_input_args = (rho0, rho1, sigma0, sigma1, sigma2) + fn_input_args = ((rho0, rho1), (sigma0, sigma1, sigma2)) elif fn_type == "mgga": fn_input_args = ( - rho0, rho1, sigma0, sigma1, sigma2, lapl0, lapl1, tau0, tau1 + (rho0, rho1), (sigma0, sigma1, sigma2), (lapl0, lapl1), (tau0, tau1) ) else: rho = rho0 + rho1 sigma = sigma0 + sigma2 + 2 * sigma1 lapl = lapl0 + lapl1 tau = tau0 + tau1 - libxc_input_args = dict(rho=rho, sigma=sigma, lapl=lapl, tau=tau) + libxc_input_args = {"rho": rho, "sigma": sigma, "lapl": lapl, "tau": tau} if fn_type == "lda": fn_input_args = (rho,) elif fn_type == "gga": @@ -110,7 +110,7 @@ def test_speed(batch): if not hasattr(impl, p.name): logging.debug(f"Skipping {p.name} due to no maple code implementation") continue - fn_type = utils.functional_name_to_type(name) + fn_type = p.type seed = 0 key = jax.random.PRNGKey(seed) @@ -150,7 +150,7 @@ def test_speed(batch): jaxxc_time.append(end_time - start_time) start_time = time.time() - res1 = func.compute(libxc_input_args) # noqa: F841 + res1 = func.compute(libxc_input_args, do_vxc=False) # noqa: F841 end_time = time.time() logging.debug(f"pylibxc {name} took {end_time - start_time} seconds") libxc_time.append(end_time - start_time) From 0752deb52c3b7646a078d08e6a7dff4c6834b4a5 Mon Sep 17 00:00:00 2001 From: linmin Date: Mon, 21 Aug 2023 16:36:09 +0800 Subject: [PATCH 04/23] suppress flake8 on annotations --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index 21a527c..7375ef9 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,7 @@ max-line-length = 80 extend-ignore = E731 E124 + F722 [pycodestyle] ignore = E731 From e3cc1ad2829c58259d59955168a840d8e14dbfcb Mon Sep 17 00:00:00 2001 From: linmin Date: Mon, 21 Aug 2023 16:40:21 +0800 Subject: [PATCH 05/23] fix error --- maple2jax/__init__.py | 2 +- maple2jax/impl/utils.py | 3 --- setup.cfg | 1 + 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/maple2jax/__init__.py b/maple2jax/__init__.py index ec725a3..114dc3c 100644 --- a/maple2jax/__init__.py +++ b/maple2jax/__init__.py @@ -5,6 +5,6 @@ # You can obtain one at https://mozilla.org/MPL/2.0/. from .functionals import * # noqa -from . import experimental +from . import experimental # noqa __version__ = "0.0.7" diff --git a/maple2jax/impl/utils.py b/maple2jax/impl/utils.py index f4a0bcd..b14f4c4 100644 --- a/maple2jax/impl/utils.py +++ b/maple2jax/impl/utils.py @@ -18,7 +18,6 @@ import tensorflow_probability as tfp from typing import Callable, Optional, NamedTuple from jaxtyping import Array -from typing import Tuple def Heaviside(x): @@ -94,8 +93,6 @@ def epsilon_xc(rho: Callable, mo: Optional[Callable] = None): T = _dtype_to_jaxtyping[o_spec.dtype.name] # 0th order - r_fn = rho - if p.type == "lda": def _energy(r: return_annotation(rho)) -> return_annotation(rho): diff --git a/setup.cfg b/setup.cfg index 7375ef9..4efcfac 100644 --- a/setup.cfg +++ b/setup.cfg @@ -14,6 +14,7 @@ extend-ignore = E731 E124 F722 + E741 [pycodestyle] ignore = E731 From de9c6bdfd52d45b7856fb78db1f82689ee9fdcfe Mon Sep 17 00:00:00 2001 From: linmin Date: Mon, 21 Aug 2023 16:41:25 +0800 Subject: [PATCH 06/23] add requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index dc22afc..9598fb0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ absl-py numpy pyscf regex +jaxtyping From 6a8a6897ab5eeea7534b33fa22bca977f12541a7 Mon Sep 17 00:00:00 2001 From: linmin Date: Tue, 22 Aug 2023 15:46:41 +0800 Subject: [PATCH 07/23] remove vxc --- tests/test_impl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_impl.py b/tests/test_impl.py index 9aeffe5..fe74c73 100644 --- a/tests/test_impl.py +++ b/tests/test_impl.py @@ -124,7 +124,7 @@ def _test_impl(self, name, polarized): module = getattr(impl, p.maple_name) fn = module.pol if polarized else module.unpol res2_zk = jax.jit(jax.vmap(lambda *args: fn(p, *args)))(r, s, l, t) - res1 = func.compute(libxc_input_args) + res1 = func.compute(libxc_input_args, do_vxc=False) res1_zk = res1["zk"].squeeze() # absolute(res2_zk - res1_zk) <= (atol + rtol * absolute(res1_zk) np.testing.assert_allclose(res2_zk, res1_zk, rtol=THRESHOLD, atol=THRESHOLD) From 45d38f541710f3a2e18a211a1d2d8f964bd6035f Mon Sep 17 00:00:00 2001 From: linmin Date: Thu, 24 Aug 2023 16:34:18 +0800 Subject: [PATCH 08/23] update test --- maple2jax/experimental.jinja | 56 ++++-- maple2jax/impl/utils.py | 114 ++++++----- maple2jax/python_template.jinja | 6 +- tests/BUILD | 4 +- tests/{test_impl.py => test_against_libxc.py} | 182 +++++++++++------- 5 files changed, 220 insertions(+), 142 deletions(-) rename tests/{test_impl.py => test_against_libxc.py} (50%) diff --git a/maple2jax/experimental.jinja b/maple2jax/experimental.jinja index b81733c..c131253 100644 --- a/maple2jax/experimental.jinja +++ b/maple2jax/experimental.jinja @@ -6,29 +6,47 @@ from typing import Callable, Optional from . import impl from .impl.utils import energy_functional from .utils import get_p +from autofd.general_array import SpecTree -def get_functional(p): - if p.nspin == 1: - code = getattr(impl, p.maple_name).unpol - elif p.nspin == 2: - code = getattr(impl, p.maple_name).pol +def get_functional(p, deo_functional=None): if p.maple_name == "DEORBITALIZE": p0, p1 = (p.func_aux[0], p.func_aux[1]) - epsilon_xc_p1 = energy_functional(p1, code) - epsilon_xc_p0 = energy_functional(p0, code, epsilon_xc_p1) - fnal = p0 - # elif p.maple_name == "": - # def epsilon_xc(rho, mo): - # funals = [energy_functional(fn_p, code)(rho, mo) for fn_p, coeff in zip(p.func_aux, p.mix_coef)] + epsilon_xc_p1 = get_functional(p1) + epsilon_xc_p0 = get_functional(p0, epsilon_xc_p1) + return epsilon_xc_p0 + elif p.maple_name == "": + def epsilon_xc(rho, mo=None): + energy_density_fns = [ + get_functional(fn_p)(rho, mo) + for fn_p in p.func_aux + ] + parameters = ( + SpecTree.to_parameter(SpecTree.from_ret(f), name=f"arg{i}") + for i, f in enumerate(energy_density_fns) + ) + @with_signature( + inspect.Signature( + parameters, + return_annotation=return_annotation(rho), + ) + ) + def linear_combine(*args): + return sum(c * d for c, d in zip(p.mix_coef, args)) + return compose(linear_combine, *energy_density_fns, share_inputs=True) + epsilon_xc.cam_alpha = p.cam_alpha + epsilon_xc.cam_beta = p.cam_beta + epsilon_xc.cam_omega = p.cam_omega + epsilon_xc.nlc_b = p.nlc_b + epsilon_xc.nlc_C = p.nlc_C + return epsilon_xc else: - fnal = energy_functional(p, code) - if p.maple_name == "": - fnal.cam_alpha = p.cam_alpha - fnal.cam_beta = p.cam_beta - fnal.cam_omega = p.cam_omega - fnal.nlc_b = p.nlc_b - fnal.nlc_C = p.nlc_C - return fnal + if p.nspin == 1: + code = getattr(impl, p.maple_name).unpol + elif p.nspin == 2: + code = getattr(impl, p.maple_name).pol + return energy_functional(p, code, deo_functional) + + {% for p, ext_params, ext_params_descriptions, info in functionals %} def {{ p.name }}( diff --git a/maple2jax/impl/utils.py b/maple2jax/impl/utils.py index b14f4c4..e5c8b53 100644 --- a/maple2jax/impl/utils.py +++ b/maple2jax/impl/utils.py @@ -49,33 +49,58 @@ def lax_cond(a, b, c): def energy_functional(p, impl, deorbitalize=None): - import autofd.operators as o + from autofd.operators import compose, nabla from autofd.general_array import ( SpecTree, return_annotation, - _dtype_to_jaxtyping, + dtype_to_jaxtyping, ) # filter 0 density - def _impl(r, s=None, l=None, tau=None): + def _impl(r, *args): dens = r if p.nspin == 1 else r.sum() ret = lax.cond( - (dens < p.dens_threshold), lambda *_: 0., - lambda *_: impl(p, r, s, l, tau), None + (dens < p.dens_threshold), lambda *_: 0., lambda *_: impl(p, r, *args), + None ) return ret # define the energy functional, that takes a rho function # and an optional mo function. def epsilon_xc(rho: Callable, mo: Optional[Callable] = None): - if p.type == "mgga": - if mo is None: - raise ValueError( - "Molecular orbital function are required for mgga functionals." - ) + r"""epsilon_xc is the xc energy density functional. + The exchange correlation energy is defined as + .. raw:: latex + + E_{xc} = \int \rho(r) \epsilon_{xc}[\rho](r) dr + + Therefore the way to use this functional is to + + .. code-block:: python + + energy_density = epsilon_xc(rho) + def Exc(rho): + return integrate(compose(mul, energy_density, rho)) + + Vxc = jax.grad(Exc)(rho) + + `compose` and `integrate` are operators imported from autofd. + + Args: + rho: the density function `f[3] -> f[2]` if polarized, + `f[3] -> f[]` otherwise. + mo: the molecular orbital function `f[3] -> f[2, nmo]` if polarized, + `f[3] -> f[nmo]` otherwise. + + Returns: + The energy density function, `f[3] -> f[2]` if polarized, + `f[3] -> f[]` otherwise. + """ o_spec = SpecTree.from_ret(rho) i_spec = SpecTree.from_args(rho)[0] + T = dtype_to_jaxtyping[o_spec.dtype.name] + # Check for any errors if o_spec.shape not in ((), (2,)): raise RuntimeError( "The density function rho passed to the functional " @@ -89,19 +114,14 @@ def epsilon_xc(rho: Callable, mo: Optional[Callable] = None): f"while the density function returns array of shape {o_spec.shape}." ) - # compute arguments relating to density - T = _dtype_to_jaxtyping[o_spec.dtype.name] + def lda_energy(r: return_annotation(rho)) -> return_annotation(rho): + return _impl(r) - # 0th order if p.type == "lda": + return compose(lda_energy, rho) - def _energy(r: return_annotation(rho)) -> return_annotation(rho): - return _impl(r) - - return o.compose(_energy, rho) - - # 1st order - nabla_rho = o.nabla(rho) + # 1st order derivative + nabla_rho = nabla(rho) def compute_s( jac: return_annotation(nabla_rho), @@ -117,43 +137,44 @@ def compute_s( else: return jnp.dot(jac, jac) - s_fn = o.compose(compute_s, nabla_rho) + def gga_energy(r: return_annotation(rho), + s: return_annotation(compute_s)) -> return_annotation(rho): + return _impl(r, s) # compute the functional if p.type == "gga": + return compose( + gga_energy, rho, compose(compute_s, nabla_rho), share_inputs=True + ) - def _energy(r: return_annotation(rho), - s: return_annotation(s_fn)) -> return_annotation(rho): - return _impl(r, s) - - return o.compose(_energy, rho, s_fn, share_inputs=True) - - # 2nd order - hess_rho = o.nabla(nabla_rho) + # 2nd order derivative + hess_rho = nabla(nabla_rho) def compute_l( hess: return_annotation(hess_rho), ) -> T[Array, ("2" if polarized else "")]: return jnp.diagonal(hess, axis1=-2, axis2=-1).sum(axis=-1) - l_fn = o.compose(compute_l, hess_rho) - - # Now deal with mo + # Now deal with the terms related to mo + if mo is None: + raise ValueError( + "Molecular orbital function are required for mgga functionals." + ) mo_o_spec = SpecTree.from_ret(mo) mo_i_spec = SpecTree.from_args(mo)[0] if mo_i_spec != i_spec: raise ValueError("mo must take the same argument as rho.") - if mo_o_spec.shape != (*(2,) * polarized, mo_o_spec.shape[1]): + if mo_o_spec.shape != (*(2,) * polarized, mo_o_spec.shape[-1]): raise ValueError( "mo must return (2, N) if polarized, or (N,) if not. " f"Got {mo_o_spec.shape} while polarized={polarized}." ) - nabla_mo = o.nabla(mo) + nabla_mo = nabla(mo) def compute_tau( mo_jac: return_annotation(nabla_mo), ) -> return_annotation(rho): - tau = jnp.sum(mo_jac**2, axis=[-1, -2]) / 2 + tau = jnp.sum(jnp.real(jnp.conj(mo_jac) * mo_jac), axis=[-1, -2]) / 2 return tau def compute_tau_deorbitalize( @@ -163,20 +184,27 @@ def compute_tau_deorbitalize( return density * deo if deorbitalize is None: - tau_fn = o.compose(compute_tau, nabla_mo) + tau_fn = compose(compute_tau, nabla_mo) else: - tau_fn = o.compose( + tau_fn = compose( compute_tau_deorbitalize, rho, deorbitalize(rho, mo), share_inputs=True ) # compute the functional - def _energy( - r: return_annotation(rho), s: return_annotation(s_fn), - l: return_annotation(l_fn), tau: return_annotation(tau_fn) + def mgga_energy( + r: return_annotation(rho), s: return_annotation(compute_s), + l: return_annotation(compute_l), tau: return_annotation(rho) ) -> return_annotation(rho): return _impl(r, s, l, tau) - return o.compose(_energy, rho, s_fn, l_fn, tau_fn, share_inputs=True) + return compose( + mgga_energy, + rho, + compose(compute_s, nabla_rho), + compose(compute_l, hess_rho), + tau_fn, + share_inputs=True, + ) return epsilon_xc @@ -242,7 +270,7 @@ def rho_to_arguments( ll = sum([hvp(eye[i])[..., i] for i in range(r.shape[-1])]) # compute tau - mo_jac = jax.jacobian(mo)(r) + mo_jac = jax.jacfwd(mo)(r) if polarized and mo_jac.shape != (2, mo_jac.shape[1], r.shape[-1]): raise ValueError( "Since this functional is initialized to be polarized." @@ -255,7 +283,7 @@ def rho_to_arguments( "mo must return an array of shape (N,), where N stands for the number of " "molecular orbitals." ) - tau = jnp.sum(mo_jac**2, axis=[-1, -2]) / 2 + tau = jnp.sum(jnp.real(jnp.conj(mo_jac) * mo_jac), axis=[-1, -2]) / 2 if deorbitalize is not None: tau = density * deorbitalize return (density, s, ll, tau) diff --git a/maple2jax/python_template.jinja b/maple2jax/python_template.jinja index 97d6095..e939465 100644 --- a/maple2jax/python_template.jinja +++ b/maple2jax/python_template.jinja @@ -59,7 +59,7 @@ def {{ p.name }}( {{ param_name }} = ({{ param_name }} or {{ value }}) {% endfor %} p = get_p("{{ p.name }}", polarized, {{ ext_params.keys()|join(', ') }}) - def _{{ p.name }}_epsilon_xc(rho, r{% if p.type == "mgga" %}, mo=None{% endif %}): + def _{{ p.name }}_epsilon_xc(rho, r, mo=None): r""" The exchange-correlation energy density of {{ p.name }}. @@ -72,11 +72,9 @@ def {{ p.name }}( The 3D coordinate of the point to evaluate the functional. Note this function is designed to accept single input, use vmap for batch. -{% if p.type == "mgga" %} mo: Callable Molecular orbital function :math:`R^3 \rightarrow R^{2 \times N}`. - :math:`N` is the number of orbitals. -{% endif %} + :math:`N` is the number of orbitals. Requried only for mgga functionals. Returns: The energy density evaluated at r. """ diff --git a/tests/BUILD b/tests/BUILD index 2fa143e..de4b3fb 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -15,8 +15,8 @@ py_test( ) py_test( - name = "test_impl", - srcs = ["test_impl.py"], + name = "test_against_libxc", + srcs = ["test_against_libxc.py"], deps = [ "@maple2jax//jax_xc", ], diff --git a/tests/test_impl.py b/tests/test_against_libxc.py similarity index 50% rename from tests/test_impl.py rename to tests/test_against_libxc.py index fe74c73..8375c32 100644 --- a/tests/test_impl.py +++ b/tests/test_against_libxc.py @@ -12,9 +12,12 @@ from absl import logging import numpy as np +import jax_xc +from jax_xc.utils import get_p from jax_xc import libxc as pylibxc -from jax_xc.libxc import libxc from jax_xc import utils, impl, functionals +from functools import partial +from jaxtyping import Array, Float64 config.update("jax_enable_x64", True) config.update("jax_debug_nans", True) @@ -51,102 +54,133 @@ ] names = pylibxc.util.xc_available_functional_names() -impl_names = [] -hybrid_names = [] +lda = [] +gga = [] +mgga = [] for n in names: - func = pylibxc.LibXCFunctional(n, 1) - x = ctypes.cast(func.xc_func, ctypes.c_void_p) - p = libxc.get_p(x.value) - p = utils.dict_to_namedtuple(p, "P") + p = get_p(n, 1) + assert n == p.name if ( p.maple_name not in SKIP_LIST and p.maple_name != "" and p.maple_name != "DEORBITALIZE" ): - impl_names.append(p.name) - if p.maple_name == "": - hybrid_names.append(p.name) - assert n == p.name + if p.name.startswith("mgga") or p.name.startswith("hyb_mgga"): + mgga.append(p.name) + elif p.name.startswith("gga") or p.name.startswith("hyb_gga"): + gga.append(p.name) + elif p.name.startswith("lda") or p.name.startswith("hyb_lda"): + lda.append(p.name) + + +def sigma(rho, r): + jac = jax.jacfwd(rho)(r) + if jac.ndim == 2: + return jnp.stack( + [ + jnp.dot(jac[0], jac[0]), + jnp.dot(jac[0], jac[1]), + jnp.dot(jac[1], jac[1]), + ] + ) + else: + return jnp.dot(jac, jac) + + +def lapl(rho, r): + hess = jax.hessian(rho)(r) + return jnp.diagonal(hess, axis1=-2, axis2=-1).sum(axis=-1) + + +def tau(rho, mo, r, deorbitalize=None): + mo_jac = jax.jacfwd(mo)(r) + if deorbitalize is None: + tau = jnp.sum(jnp.real(jnp.conj(mo_jac) * mo_jac), axis=[-1, -2]) / 2 + else: + tau = rho(r) * deorbitalize + return tau + + +def rho1(r: Float64[Array, "3"]) -> Float64[Array, ""]: + return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1)) + + +def rho2(r: Float64[Array, "3"]) -> Float64[Array, ""]: + return jnp.prod(jax.scipy.stats.cauchy.pdf(r, loc=0, scale=1)) + + +def rho3(r: Float64[Array, "3"]) -> Float64[Array, "2"]: + return jnp.stack([rho1(r), rho2(r)], axis=0) -class _TestImpl(parameterized.TestCase): +def mo1(r: Float64[Array, "3"]) -> Float64[Array, "8"]: + # create 8 orbitals + r = r[None, :] * jnp.arange(8)[:, None] + return jnp.sum(jnp.sin(r), axis=-1) + jnp.sum(jnp.cos(r), axis=-1) * 1.j - @parameterized.parameters(*impl_names) + +def mo2(r: Float64[Array, "3"]) -> Float64[Array, "8"]: + r = r[None, :] * jnp.arange(8)[:, None] * 2 + return jnp.sum(jnp.sin(r), axis=-1) + jnp.sum(jnp.cos(r), axis=-1) * 1.j + + +def mo3(r: Float64[Array, "3"]) -> Float64[Array, "2 8"]: + return jnp.stack([mo1(r), mo2(r)], axis=0) + + +class _TestAgainstLibxc(parameterized.TestCase): + + @parameterized.parameters(*lda, *gga, *mgga) def test_unpol(self, name): - self._test_impl(name, False) + self._test_impl(name, 0, rho1, mo1) - @parameterized.parameters(*impl_names) + @parameterized.parameters(*lda, *gga, *mgga) def test_pol(self, name): - self._test_impl(name, True) + self._test_impl(name, 1, rho3, mo3) - def _test_impl(self, name, polarized): + def _test_impl(self, name, polarized, rho, mo): batch = 100 - # r0, r1, s0, s1, s2, l0, l1, t0, t1 - inputs = jax.random.uniform( - jax.random.PRNGKey(10), - (9, batch), + r = jax.random.uniform( + jax.random.PRNGKey(42), + (batch, 3), dtype=jnp.float64, - minval=1e-5, - maxval=1e2, - ) - inputs = inputs.at[2:5, :].set( - jnp.where(inputs[2] + inputs[4] - 2 * inputs[3] < 0, 1, inputs[2:5]) + minval=-3, + maxval=3, ) - rho0, rho1, sigma0, sigma1, sigma2, lapl0, lapl1, tau0, tau1 = inputs + rho_r = jax.vmap(rho)(r) + sigma_r = jax.vmap(partial(sigma, rho))(r) + lapl_r = jax.vmap(partial(lapl, rho))(r) + tau_r = jax.vmap(partial(tau, rho, mo))(r) + + # libxc func = pylibxc.LibXCFunctional(name, int(polarized) + 1) - x = ctypes.cast(func.xc_func, ctypes.c_void_p) - p = libxc.get_p(x.value) - p = utils.dict_to_namedtuple(p, "P") logging.info( "Testing %s, implemented by maple file %s", p.name, p.maple_name ) - if polarized: - r, s, l, t = ( - (rho0, rho1), (sigma0, sigma1, sigma2), (lapl0, lapl1), (tau0, tau1) - ) - libxc_input_args = { - "rho": jnp.stack([rho0, rho1], -1), - "sigma": jnp.stack([sigma0, sigma1, sigma2], -1), - "lapl": jnp.stack([lapl0, lapl1], -1), - "tau": jnp.stack([tau0, tau1], -1), - } - else: - r, s, l, t = ( - rho0 + rho1, sigma0 + sigma2 + 2 * sigma1, lapl0 + lapl1, tau0 + tau1 - ) - libxc_input_args = { - "rho": r, - "sigma": s, - "lapl": l, - "tau": t, - } - - module = getattr(impl, p.maple_name) - fn = module.pol if polarized else module.unpol - res2_zk = jax.jit(jax.vmap(lambda *args: fn(p, *args)))(r, s, l, t) - res1 = func.compute(libxc_input_args, do_vxc=False) + res1 = func.compute( + { + "rho": rho_r, + "sigma": sigma_r, + "lapl": lapl_r, + "tau": tau_r, + }, + do_vxc=False + ) res1_zk = res1["zk"].squeeze() + + # jax_xc experimental + epsilon_xc = getattr(jax_xc.experimental, name)(polarized) + energy_density = epsilon_xc(rho, mo) + res2_zk = jax.vmap(energy_density)(r) + + # jax_xc + epsilon_xc = getattr(jax_xc, name)(polarized) + energy_density = lambda r: epsilon_xc(rho, r, mo) + res3_zk = jax.vmap(energy_density)(r) + # absolute(res2_zk - res1_zk) <= (atol + rtol * absolute(res1_zk) np.testing.assert_allclose(res2_zk, res1_zk, rtol=THRESHOLD, atol=THRESHOLD) - - @parameterized.parameters(*hybrid_names) - def test_get_hyb_params(self, name): - func = pylibxc.LibXCFunctional(name, 1) - x = ctypes.cast(func.xc_func, ctypes.c_void_p) - p = libxc.get_p(x.value) - p = utils.dict_to_namedtuple(p, "P") - impl_fn = getattr(functionals, name) - functional = impl_fn(False) - alpha = functional.cam_alpha - beta = functional.cam_beta - omega = functional.cam_omega - nlc_b = functional.nlc_b - nlc_C = functional.nlc_C - self.assertTrue(alpha is not None) - self.assertTrue(beta is not None) - self.assertTrue(omega is not None) - self.assertTrue(nlc_b is not None) - self.assertTrue(nlc_C is not None) + np.testing.assert_allclose(res3_zk, res1_zk, rtol=THRESHOLD, atol=THRESHOLD) if __name__ == "__main__": From 5a61a650eeecc5cf3e245dfb381981c551058b0f Mon Sep 17 00:00:00 2001 From: linmin Date: Thu, 24 Aug 2023 16:36:24 +0800 Subject: [PATCH 09/23] fix flake8 --- tests/test_against_libxc.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/test_against_libxc.py b/tests/test_against_libxc.py index 8375c32..a3c9597 100644 --- a/tests/test_against_libxc.py +++ b/tests/test_against_libxc.py @@ -5,7 +5,6 @@ # You can obtain one at https://mozilla.org/MPL/2.0/. import jax -import ctypes import jax.numpy as jnp from jax.config import config from absl.testing import absltest, parameterized @@ -15,7 +14,6 @@ import jax_xc from jax_xc.utils import get_p from jax_xc import libxc as pylibxc -from jax_xc import utils, impl, functionals from functools import partial from jaxtyping import Array, Float64 From 09559f61198b249ce92f7cf722eab09902822492 Mon Sep 17 00:00:00 2001 From: linmin Date: Thu, 24 Aug 2023 16:45:28 +0800 Subject: [PATCH 10/23] fix error caused by latest buildifier --- maple2jax/wheel.BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/maple2jax/wheel.BUILD b/maple2jax/wheel.BUILD index ee9f6f1..cf5d63c 100644 --- a/maple2jax/wheel.BUILD +++ b/maple2jax/wheel.BUILD @@ -1,5 +1,5 @@ -load("@rules_python//python:packaging.bzl", "py_wheel") load("@python_abi//:abi.bzl", "abi_tag", "python_tag") +load("@rules_python//python:packaging.bzl", "py_wheel") py_wheel( name = "jax_xc_wheel", From 6918528ddcef9860ba25e6b6bec8f246f7782a48 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 8 Sep 2023 12:34:54 +0800 Subject: [PATCH 11/23] hide autofd within experimental --- maple2jax/experimental.jinja | 17 +-------- maple2jax/impl/utils.py | 64 +++++++++++++++----------------- tests/test_against_libxc.py | 72 +++++++++++++++++++++++++----------- 3 files changed, 83 insertions(+), 70 deletions(-) diff --git a/maple2jax/experimental.jinja b/maple2jax/experimental.jinja index c131253..36cab56 100644 --- a/maple2jax/experimental.jinja +++ b/maple2jax/experimental.jinja @@ -6,7 +6,6 @@ from typing import Callable, Optional from . import impl from .impl.utils import energy_functional from .utils import get_p -from autofd.general_array import SpecTree def get_functional(p, deo_functional=None): if p.maple_name == "DEORBITALIZE": @@ -16,23 +15,11 @@ def get_functional(p, deo_functional=None): return epsilon_xc_p0 elif p.maple_name == "": def epsilon_xc(rho, mo=None): - energy_density_fns = [ + fnals = [ get_functional(fn_p)(rho, mo) for fn_p in p.func_aux ] - parameters = ( - SpecTree.to_parameter(SpecTree.from_ret(f), name=f"arg{i}") - for i, f in enumerate(energy_density_fns) - ) - @with_signature( - inspect.Signature( - parameters, - return_annotation=return_annotation(rho), - ) - ) - def linear_combine(*args): - return sum(c * d for c, d in zip(p.mix_coef, args)) - return compose(linear_combine, *energy_density_fns, share_inputs=True) + epsilon_xc = sum(coef * f for coef, f in zip(p.mix_coef, fnals)) epsilon_xc.cam_alpha = p.cam_alpha epsilon_xc.cam_beta = p.cam_beta epsilon_xc.cam_omega = p.cam_omega diff --git a/maple2jax/impl/utils.py b/maple2jax/impl/utils.py index e5c8b53..ce7215b 100644 --- a/maple2jax/impl/utils.py +++ b/maple2jax/impl/utils.py @@ -17,7 +17,6 @@ import jax.numpy as jnp import tensorflow_probability as tfp from typing import Callable, Optional, NamedTuple -from jaxtyping import Array def Heaviside(x): @@ -51,9 +50,9 @@ def lax_cond(a, b, c): def energy_functional(p, impl, deorbitalize=None): from autofd.operators import compose, nabla from autofd.general_array import ( + with_spec, + Spec, SpecTree, - return_annotation, - dtype_to_jaxtyping, ) # filter 0 density @@ -98,8 +97,8 @@ def Exc(rho): `f[3] -> f[]` otherwise. """ o_spec = SpecTree.from_ret(rho) - i_spec = SpecTree.from_args(rho)[0] - T = dtype_to_jaxtyping[o_spec.dtype.name] + i_spec = SpecTree.from_args(rho) + dtype = o_spec.dtype # Check for any errors if o_spec.shape not in ((), (2,)): raise RuntimeError( @@ -114,18 +113,21 @@ def Exc(rho): f"while the density function returns array of shape {o_spec.shape}." ) - def lda_energy(r: return_annotation(rho)) -> return_annotation(rho): + @with_spec((o_spec,), o_spec) + def lda_energy(r): return _impl(r) if p.type == "lda": return compose(lda_energy, rho) # 1st order derivative - nabla_rho = nabla(rho) + nabla_rho = nabla(rho, method=jax.jacrev) - def compute_s( - jac: return_annotation(nabla_rho), - ) -> T[Array, ("3" if polarized else "")]: + @with_spec( + (nabla_rho.ret_spec,), + Spec((3,) if polarized else (), dtype), + ) + def compute_s(jac): if polarized: return jnp.stack( [ @@ -137,8 +139,8 @@ def compute_s( else: return jnp.dot(jac, jac) - def gga_energy(r: return_annotation(rho), - s: return_annotation(compute_s)) -> return_annotation(rho): + @with_spec((o_spec, compute_s.ret_spec), o_spec) + def gga_energy(r, s): return _impl(r, s) # compute the functional @@ -148,11 +150,13 @@ def gga_energy(r: return_annotation(rho), ) # 2nd order derivative - hess_rho = nabla(nabla_rho) + hess_rho = nabla(nabla_rho, method=jax.jacfwd) - def compute_l( - hess: return_annotation(hess_rho), - ) -> T[Array, ("2" if polarized else "")]: + @with_spec( + (hess_rho.ret_spec,), + Spec((2,) if polarized else (), dtype), + ) + def compute_l(hess): return jnp.diagonal(hess, axis1=-2, axis2=-1).sum(axis=-1) # Now deal with the terms related to mo @@ -161,7 +165,7 @@ def compute_l( "Molecular orbital function are required for mgga functionals." ) mo_o_spec = SpecTree.from_ret(mo) - mo_i_spec = SpecTree.from_args(mo)[0] + mo_i_spec = SpecTree.from_args(mo) if mo_i_spec != i_spec: raise ValueError("mo must take the same argument as rho.") if mo_o_spec.shape != (*(2,) * polarized, mo_o_spec.shape[-1]): @@ -169,32 +173,24 @@ def compute_l( "mo must return (2, N) if polarized, or (N,) if not. " f"Got {mo_o_spec.shape} while polarized={polarized}." ) - nabla_mo = nabla(mo) + nabla_mo = nabla(mo, method=jax.jacfwd) - def compute_tau( - mo_jac: return_annotation(nabla_mo), - ) -> return_annotation(rho): + @with_spec((nabla_mo.ret_spec,), o_spec) + def compute_tau(mo_jac): tau = jnp.sum(jnp.real(jnp.conj(mo_jac) * mo_jac), axis=[-1, -2]) / 2 return tau - def compute_tau_deorbitalize( - density: return_annotation(rho), - deo: return_annotation(rho), - ) -> return_annotation(rho): - return density * deo - if deorbitalize is None: tau_fn = compose(compute_tau, nabla_mo) else: - tau_fn = compose( - compute_tau_deorbitalize, rho, deorbitalize(rho, mo), share_inputs=True - ) + tau_fn = rho * deorbitalize(rho, mo) # compute the functional - def mgga_energy( - r: return_annotation(rho), s: return_annotation(compute_s), - l: return_annotation(compute_l), tau: return_annotation(rho) - ) -> return_annotation(rho): + @with_spec( + (o_spec, compute_s.ret_spec, compute_l.ret_spec, o_spec), + o_spec, + ) + def mgga_energy(r, s, l, tau): return _impl(r, s, l, tau) return compose( diff --git a/tests/test_against_libxc.py b/tests/test_against_libxc.py index a3c9597..f6ce6a8 100644 --- a/tests/test_against_libxc.py +++ b/tests/test_against_libxc.py @@ -20,7 +20,23 @@ config.update("jax_enable_x64", True) config.update("jax_debug_nans", True) -THRESHOLD = 2e-10 +THRESHOLD = { + "mgga_x_br89_explicit": 1e-9, + "gga_c_op_pw91": 1e-14, + "lda_x_rel": 1e-13, + "mgga_x_pjs18": 1e-11, + "mgga_x_m08": 1e-12, + "mgga_c_m08": 1e-12, + "mgga_x_edmgga": 1e-12, + "mgga_x_ft98": 1e-12, + "mgga_x_m061": 1e-13, + "hyb_mgga_x_pjs18": 1e-11, + "gga_k_meyer": 1e-12, + "mgga_x_sa_tpss": 1e-13, + "gga_x_beefvdw": 1e-10, + "gga_x_pbepow": 1e-10, + "mgga_c_bc95": 1e-12, +} # NOT-IMPLEMENTED due to jax's lack of support SKIP_LIST = [ @@ -55,6 +71,7 @@ lda = [] gga = [] mgga = [] +sensitive = [] for n in names: p = get_p(n, 1) @@ -63,12 +80,14 @@ p.maple_name not in SKIP_LIST and p.maple_name != "" and p.maple_name != "DEORBITALIZE" ): - if p.name.startswith("mgga") or p.name.startswith("hyb_mgga"): - mgga.append(p.name) + if p.maple_name in THRESHOLD: + sensitive.append((p.name, p.maple_name)) + elif p.name.startswith("mgga") or p.name.startswith("hyb_mgga"): + mgga.append((p.name, p.maple_name)) elif p.name.startswith("gga") or p.name.startswith("hyb_gga"): - gga.append(p.name) + gga.append((p.name, p.maple_name)) elif p.name.startswith("lda") or p.name.startswith("hyb_lda"): - lda.append(p.name) + lda.append((p.name, p.maple_name)) def sigma(rho, r): @@ -129,14 +148,20 @@ def mo3(r: Float64[Array, "3"]) -> Float64[Array, "2 8"]: class _TestAgainstLibxc(parameterized.TestCase): @parameterized.parameters(*lda, *gga, *mgga) - def test_unpol(self, name): - self._test_impl(name, 0, rho1, mo1) + def test_unpol(self, name, maple_name): + self._test_impl(name, maple_name, 0, rho1, mo1) @parameterized.parameters(*lda, *gga, *mgga) - def test_pol(self, name): - self._test_impl(name, 1, rho3, mo3) + def test_pol(self, name, maple_name): + self._test_impl(name, maple_name, 1, rho3, mo3) - def _test_impl(self, name, polarized, rho, mo): + @parameterized.parameters(*sensitive) + def test_sensitive(self, name, maple_name): + self._test_impl(name, maple_name, 0, rho1, mo1) + self._test_impl(name, maple_name, 1, rho3, mo3) + + def _test_impl(self, name, maple_name, polarized, rho, mo): + threshold = THRESHOLD.get(maple_name, 1e-14) batch = 100 r = jax.random.uniform( jax.random.PRNGKey(42), @@ -152,9 +177,7 @@ def _test_impl(self, name, polarized, rho, mo): # libxc func = pylibxc.LibXCFunctional(name, int(polarized) + 1) - logging.info( - "Testing %s, implemented by maple file %s", p.name, p.maple_name - ) + logging.info("Testing %s, implemented by maple file %s", name, maple_name) res1 = func.compute( { "rho": rho_r, @@ -166,19 +189,26 @@ def _test_impl(self, name, polarized, rho, mo): ) res1_zk = res1["zk"].squeeze() - # jax_xc experimental - epsilon_xc = getattr(jax_xc.experimental, name)(polarized) - energy_density = epsilon_xc(rho, mo) - res2_zk = jax.vmap(energy_density)(r) - # jax_xc epsilon_xc = getattr(jax_xc, name)(polarized) energy_density = lambda r: epsilon_xc(rho, r, mo) - res3_zk = jax.vmap(energy_density)(r) + res2_zk = jax.jit(jax.vmap(energy_density))(r) # absolute(res2_zk - res1_zk) <= (atol + rtol * absolute(res1_zk) - np.testing.assert_allclose(res2_zk, res1_zk, rtol=THRESHOLD, atol=THRESHOLD) - np.testing.assert_allclose(res3_zk, res1_zk, rtol=THRESHOLD, atol=THRESHOLD) + np.testing.assert_allclose(res2_zk, res1_zk, rtol=threshold, atol=threshold) + + # jax_xc experimental + try: + from autofd.general_array import function + rho = function(rho) + epsilon_xc = getattr(jax_xc.experimental, name)(polarized) + energy_density = epsilon_xc(rho, mo) + res3_zk = jax.jit(jax.vmap(energy_density))(r) + np.testing.assert_allclose( + res3_zk, res1_zk, rtol=threshold, atol=threshold + ) + except ImportError: + logging.info("Skipping experimental test because autofd is not found") if __name__ == "__main__": From 912a55b3cb1256c3450c5b25b7072aed5e675eff Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 8 Sep 2023 12:39:02 +0800 Subject: [PATCH 12/23] bump version --- maple2jax/__init__.py | 2 +- maple2jax/wheel.BUILD | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/maple2jax/__init__.py b/maple2jax/__init__.py index 114dc3c..284ad40 100644 --- a/maple2jax/__init__.py +++ b/maple2jax/__init__.py @@ -7,4 +7,4 @@ from .functionals import * # noqa from . import experimental # noqa -__version__ = "0.0.7" +__version__ = "0.0.8" diff --git a/maple2jax/wheel.BUILD b/maple2jax/wheel.BUILD index cf5d63c..4ce46d9 100644 --- a/maple2jax/wheel.BUILD +++ b/maple2jax/wheel.BUILD @@ -20,7 +20,7 @@ py_wheel( "numpy", "tensorflow-probability", ], - version = "0.0.7", + version = "0.0.8", deps = [ "@maple2jax//jax_xc", "@maple2jax//jax_xc:experimental", From 96fba451b4ac50572771f04f61899db3b2ade273 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 8 Sep 2023 12:52:22 +0800 Subject: [PATCH 13/23] fix timeout error --- tests/BUILD | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/BUILD b/tests/BUILD index de4b3fb..b04706e 100644 --- a/tests/BUILD +++ b/tests/BUILD @@ -16,6 +16,7 @@ py_test( py_test( name = "test_against_libxc", + timeout = "eternal", srcs = ["test_against_libxc.py"], deps = [ "@maple2jax//jax_xc", From 676eb079348e1f6cc2b0da8a845cb955720ba2da Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 8 Sep 2023 13:37:40 +0800 Subject: [PATCH 14/23] mgga_x_m06l is also sensitive --- tests/test_against_libxc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_against_libxc.py b/tests/test_against_libxc.py index f6ce6a8..df2c6de 100644 --- a/tests/test_against_libxc.py +++ b/tests/test_against_libxc.py @@ -36,6 +36,7 @@ "gga_x_beefvdw": 1e-10, "gga_x_pbepow": 1e-10, "mgga_c_bc95": 1e-12, + "mgga_x_m06l": 1e-12, } # NOT-IMPLEMENTED due to jax's lack of support From 9e8cb77ba1e35fdebe89b0d77221a716b04566d6 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 8 Sep 2023 15:43:33 +0800 Subject: [PATCH 15/23] energy actually returns a scalar --- maple2jax/impl/utils.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/maple2jax/impl/utils.py b/maple2jax/impl/utils.py index ce7215b..67ced42 100644 --- a/maple2jax/impl/utils.py +++ b/maple2jax/impl/utils.py @@ -99,6 +99,7 @@ def Exc(rho): o_spec = SpecTree.from_ret(rho) i_spec = SpecTree.from_args(rho) dtype = o_spec.dtype + scalar = Spec((), dtype) # Check for any errors if o_spec.shape not in ((), (2,)): raise RuntimeError( @@ -113,7 +114,7 @@ def Exc(rho): f"while the density function returns array of shape {o_spec.shape}." ) - @with_spec((o_spec,), o_spec) + @with_spec((o_spec,), scalar) def lda_energy(r): return _impl(r) @@ -139,7 +140,7 @@ def compute_s(jac): else: return jnp.dot(jac, jac) - @with_spec((o_spec, compute_s.ret_spec), o_spec) + @with_spec((o_spec, compute_s.ret_spec), scalar) def gga_energy(r, s): return _impl(r, s) @@ -186,10 +187,7 @@ def compute_tau(mo_jac): tau_fn = rho * deorbitalize(rho, mo) # compute the functional - @with_spec( - (o_spec, compute_s.ret_spec, compute_l.ret_spec, o_spec), - o_spec, - ) + @with_spec((o_spec, compute_s.ret_spec, compute_l.ret_spec, o_spec), scalar) def mgga_energy(r, s, l, tau): return _impl(r, s, l, tau) From d1bddd4c14081f30ea63f27b2173ebdc91c18221 Mon Sep 17 00:00:00 2001 From: linmin Date: Fri, 1 Dec 2023 15:46:46 +0800 Subject: [PATCH 16/23] update code to use latest autofd --- .bazelrc | 1 + maple2jax/__init__.py | 2 +- maple2jax/experimental.jinja | 142 ++++++++++++++++++++++++++----- maple2jax/impl/utils.py | 156 ----------------------------------- maple2jax/wheel.BUILD | 5 +- requirements.txt | 1 + tests/test_against_libxc.py | 47 +++++++---- 7 files changed, 159 insertions(+), 195 deletions(-) diff --git a/.bazelrc b/.bazelrc index 8f6734f..99d8320 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,3 +1,4 @@ build --copt=-g0 --copt=-O3 --copt=-DNDEBUG build --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm build --action_env=BAZEL_LINKOPTS=-static-libgcc +build --incompatible_strict_action_env diff --git a/maple2jax/__init__.py b/maple2jax/__init__.py index 284ad40..61ef665 100644 --- a/maple2jax/__init__.py +++ b/maple2jax/__init__.py @@ -7,4 +7,4 @@ from .functionals import * # noqa from . import experimental # noqa -__version__ = "0.0.8" +__version__ = "0.0.9" diff --git a/maple2jax/experimental.jinja b/maple2jax/experimental.jinja index 36cab56..0cc1806 100644 --- a/maple2jax/experimental.jinja +++ b/maple2jax/experimental.jinja @@ -1,43 +1,140 @@ import jax +from jax import lax import jax.numpy as jnp import ctypes from collections import namedtuple from typing import Callable, Optional +from functools import partial +import autofd.operators as o + from . import impl -from .impl.utils import energy_functional from .utils import get_p -def get_functional(p, deo_functional=None): + +def _filter_small_density(p, code, r, *args): + dens = r if p.nspin == 1 else r.sum() + ret = lax.cond( + (dens < p.dens_threshold), lambda *_: 0., lambda *_: code(p, r, *args), + None + ) + return ret + + +def make_epsilon_xc( + p, rho: Callable, mo: Optional[Callable] = None, deorbitalize=None +): + # if they are deorbitalize or hybrid functionals if p.maple_name == "DEORBITALIZE": p0, p1 = (p.func_aux[0], p.func_aux[1]) - epsilon_xc_p1 = get_functional(p1) - epsilon_xc_p0 = get_functional(p0, epsilon_xc_p1) - return epsilon_xc_p0 + deorbitalize = partial(make_epsilon_xc, p1) + return make_epsilon_xc(p0, rho, mo, deorbitalize=deorbitalize) elif p.maple_name == "": - def epsilon_xc(rho, mo=None): - fnals = [ - get_functional(fn_p)(rho, mo) - for fn_p in p.func_aux - ] - epsilon_xc = sum(coef * f for coef, f in zip(p.mix_coef, fnals)) + + def mix(*args): + return sum(coef * a for a, coef in zip(args, p.mix_coef)) + + epsilon_xc = o.compose(mix, *[make_epsilon_xc(fn_p, rho, mo) + for fn_p in p.func_aux], share_inputs=True) epsilon_xc.cam_alpha = p.cam_alpha epsilon_xc.cam_beta = p.cam_beta epsilon_xc.cam_omega = p.cam_omega epsilon_xc.nlc_b = p.nlc_b epsilon_xc.nlc_C = p.nlc_C return epsilon_xc - else: - if p.nspin == 1: - code = getattr(impl, p.maple_name).unpol - elif p.nspin == 2: - code = getattr(impl, p.maple_name).pol - return energy_functional(p, code, deo_functional) + # otherwise, it is a single functional + if p.nspin == 1: + code = getattr(impl, p.maple_name).unpol + elif p.nspin == 2: + code = getattr(impl, p.maple_name).pol + + code = partial(_filter_small_density, p, code) + + # construct first order derivative of rho for gga + nabla_rho = o.nabla(rho, method=jax.jacrev) + + def compute_s(jac): + if jac.shape == (2, 3): + return jnp.stack([jac[0] @ jac[0], jac[0] @ jac[1], jac[1] @ jac[1]]) + elif jac.shape == (3,): + return jac @ jac + + # construct second order derivative of rho for mgga + hess_rho = o.nabla(nabla_rho, method=jax.jacfwd) + + def compute_l(hess_rho): + return jnp.trace(hess_rho, axis1=-2, axis2=-1) + + # create the epsilon_xc function + if p.type == "lda": + return o.compose(code, rho) + elif p.type == "gga": + return o.compose( + code, rho, o.compose(compute_s, nabla_rho), share_inputs=True + ) + elif p.type == "mgga": + nabla_mo = o.nabla(mo, method=jax.jacfwd) + + def compute_tau(mo_jac): + tau = jnp.sum(jnp.real(jnp.conj(mo_jac) * mo_jac), axis=[-1, -2]) / 2 + return tau + + if deorbitalize is None: + tau_fn = o.compose(compute_tau, nabla_mo) + else: + tau_fn = rho * deorbitalize(rho, mo) + return o.compose( + code, + rho, + o.compose(compute_s, nabla_rho), + o.compose(compute_l, hess_rho), + tau_fn, + share_inputs=True + ) + + +def is_polarized(rho): + try: + out = jax.eval_shape(rho, jax.ShapeDtypeStruct((3,), jnp.float32)) + except: + out = jax.eval_shape(rho, jax.ShapeDtypeStruct((3,), jnp.float64)) + if out.shape != (2,) and out.shape != (): + raise ValueError( + f"rho must return an array of shape (2,) or (), got {out.shape}" + ) + return (out.shape == (2,)) + + +def check_mo_shape(mo, polarized): + try: + out = jax.eval_shape(mo, jax.ShapeDtypeStruct((3,), jnp.float32)) + except: + out = jax.eval_shape(mo, jax.ShapeDtypeStruct((3,), jnp.float64)) + if polarized: + if len(out.shape) != 2 or out.shape[0] != 2: + raise ValueError( + "Return value of rho has shape (2,), which means it is polarized. " + "Therefore mo must return an array of shape (2, number_of_orbital), " + f"got {out.shape}" + ) + else: + if len(out.shape) != 1: + raise ValueError( + "Return value of rho has shape (), which means it is unpolarized. " + "Therefore mo must return an array of shape (number_of_orbital,), " + f"got {out.shape}" + ) {% for p, ext_params, ext_params_descriptions, info in functionals %} def {{ p.name }}( - polarized: bool = True, + rho: Callable, +{% if p.type == "mgga" %} + mo: Callable, +{% endif %} +{% if ext_params|length > 0 %} + *, +{% endif %} {% for param_name in ext_params.keys() %} {{ param_name }}: Optional[float] = None, {% endfor %} @@ -61,17 +158,20 @@ def {{ p.name }}( {% endif %} Parameters ---------- - polarized : bool - Whether the calculation is polarized. + rho: the density function {% for (param_name, param_val), param_descrip in zip(ext_params.items(), ext_params_descriptions) %} {{ param_name }} : Optional[float], default: {{ param_val }} {{ param_descrip }} {% endfor %} """ + polarized = is_polarized(rho) +{% if p.type == "mgga" %} + check_mo_shape(mo, polarized) +{% endif %} {% for param_name, value in ext_params.items() %} {{ param_name }} = ({{ param_name }} or {{ value }}) {% endfor %} p = get_p("{{ p.name }}", polarized, {{ ext_params.keys()|join(', ') }}) - return get_functional(p) + return make_epsilon_xc(p, rho{% if p.type == "mgga" %}, mo{% endif %}) {% endfor %} diff --git a/maple2jax/impl/utils.py b/maple2jax/impl/utils.py index 67ced42..5f22ddd 100644 --- a/maple2jax/impl/utils.py +++ b/maple2jax/impl/utils.py @@ -47,162 +47,6 @@ def lax_cond(a, b, c): return lax.cond(a, lambda _: b, lambda _: c, None) -def energy_functional(p, impl, deorbitalize=None): - from autofd.operators import compose, nabla - from autofd.general_array import ( - with_spec, - Spec, - SpecTree, - ) - - # filter 0 density - def _impl(r, *args): - dens = r if p.nspin == 1 else r.sum() - ret = lax.cond( - (dens < p.dens_threshold), lambda *_: 0., lambda *_: impl(p, r, *args), - None - ) - return ret - - # define the energy functional, that takes a rho function - # and an optional mo function. - def epsilon_xc(rho: Callable, mo: Optional[Callable] = None): - r"""epsilon_xc is the xc energy density functional. - The exchange correlation energy is defined as - .. raw:: latex - - E_{xc} = \int \rho(r) \epsilon_{xc}[\rho](r) dr - - Therefore the way to use this functional is to - - .. code-block:: python - - energy_density = epsilon_xc(rho) - - def Exc(rho): - return integrate(compose(mul, energy_density, rho)) - - Vxc = jax.grad(Exc)(rho) - - `compose` and `integrate` are operators imported from autofd. - - Args: - rho: the density function `f[3] -> f[2]` if polarized, - `f[3] -> f[]` otherwise. - mo: the molecular orbital function `f[3] -> f[2, nmo]` if polarized, - `f[3] -> f[nmo]` otherwise. - - Returns: - The energy density function, `f[3] -> f[2]` if polarized, - `f[3] -> f[]` otherwise. - """ - o_spec = SpecTree.from_ret(rho) - i_spec = SpecTree.from_args(rho) - dtype = o_spec.dtype - scalar = Spec((), dtype) - # Check for any errors - if o_spec.shape not in ((), (2,)): - raise RuntimeError( - "The density function rho passed to the functional " - "needs to return either a scalar (unpolarized), " - f"or a shape (2,) array (polarized). Got shape {o_spec.shape}." - ) - polarized = (o_spec.shape == (2,)) - if (p.nspin == 1 and polarized) or (p.nspin == 2 and not polarized): - raise ValueError( - f"The functional is initialized with nspin={p.nspin}, " - f"while the density function returns array of shape {o_spec.shape}." - ) - - @with_spec((o_spec,), scalar) - def lda_energy(r): - return _impl(r) - - if p.type == "lda": - return compose(lda_energy, rho) - - # 1st order derivative - nabla_rho = nabla(rho, method=jax.jacrev) - - @with_spec( - (nabla_rho.ret_spec,), - Spec((3,) if polarized else (), dtype), - ) - def compute_s(jac): - if polarized: - return jnp.stack( - [ - jnp.dot(jac[0], jac[0]), - jnp.dot(jac[0], jac[1]), - jnp.dot(jac[1], jac[1]), - ] - ) - else: - return jnp.dot(jac, jac) - - @with_spec((o_spec, compute_s.ret_spec), scalar) - def gga_energy(r, s): - return _impl(r, s) - - # compute the functional - if p.type == "gga": - return compose( - gga_energy, rho, compose(compute_s, nabla_rho), share_inputs=True - ) - - # 2nd order derivative - hess_rho = nabla(nabla_rho, method=jax.jacfwd) - - @with_spec( - (hess_rho.ret_spec,), - Spec((2,) if polarized else (), dtype), - ) - def compute_l(hess): - return jnp.diagonal(hess, axis1=-2, axis2=-1).sum(axis=-1) - - # Now deal with the terms related to mo - if mo is None: - raise ValueError( - "Molecular orbital function are required for mgga functionals." - ) - mo_o_spec = SpecTree.from_ret(mo) - mo_i_spec = SpecTree.from_args(mo) - if mo_i_spec != i_spec: - raise ValueError("mo must take the same argument as rho.") - if mo_o_spec.shape != (*(2,) * polarized, mo_o_spec.shape[-1]): - raise ValueError( - "mo must return (2, N) if polarized, or (N,) if not. " - f"Got {mo_o_spec.shape} while polarized={polarized}." - ) - nabla_mo = nabla(mo, method=jax.jacfwd) - - @with_spec((nabla_mo.ret_spec,), o_spec) - def compute_tau(mo_jac): - tau = jnp.sum(jnp.real(jnp.conj(mo_jac) * mo_jac), axis=[-1, -2]) / 2 - return tau - - if deorbitalize is None: - tau_fn = compose(compute_tau, nabla_mo) - else: - tau_fn = rho * deorbitalize(rho, mo) - - # compute the functional - @with_spec((o_spec, compute_s.ret_spec, compute_l.ret_spec, o_spec), scalar) - def mgga_energy(r, s, l, tau): - return _impl(r, s, l, tau) - - return compose( - mgga_energy, - rho, - compose(compute_s, nabla_rho), - compose(compute_l, hess_rho), - tau_fn, - share_inputs=True, - ) - - return epsilon_xc - - def rho_to_arguments( p: NamedTuple, rho: Callable, diff --git a/maple2jax/wheel.BUILD b/maple2jax/wheel.BUILD index 4ce46d9..3e7ffdf 100644 --- a/maple2jax/wheel.BUILD +++ b/maple2jax/wheel.BUILD @@ -13,14 +13,15 @@ py_wheel( description_file = "@jax_xc//:README.rst", distribution = "jax_xc", platform = "manylinux_2_17_x86_64", - python_requires = ">=3.7", + python_requires = ">=3.9", python_tag = python_tag(), requires = [ "jax", "numpy", "tensorflow-probability", + "autofd", ], - version = "0.0.8", + version = "0.0.9", deps = [ "@maple2jax//jax_xc", "@maple2jax//jax_xc:experimental", diff --git a/requirements.txt b/requirements.txt index 9598fb0..9b9165f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ numpy pyscf regex jaxtyping +autofd diff --git a/tests/test_against_libxc.py b/tests/test_against_libxc.py index df2c6de..1c79a04 100644 --- a/tests/test_against_libxc.py +++ b/tests/test_against_libxc.py @@ -15,7 +15,8 @@ from jax_xc.utils import get_p from jax_xc import libxc as pylibxc from functools import partial -from jaxtyping import Array, Float64 +from jaxtyping import Array, Float64, Complex128 +from autofd import function config.update("jax_enable_x64", True) config.update("jax_debug_nans", True) @@ -119,41 +120,54 @@ def tau(rho, mo, r, deorbitalize=None): return tau +@function def rho1(r: Float64[Array, "3"]) -> Float64[Array, ""]: return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1)) +@function def rho2(r: Float64[Array, "3"]) -> Float64[Array, ""]: return jnp.prod(jax.scipy.stats.cauchy.pdf(r, loc=0, scale=1)) +@function def rho3(r: Float64[Array, "3"]) -> Float64[Array, "2"]: return jnp.stack([rho1(r), rho2(r)], axis=0) -def mo1(r: Float64[Array, "3"]) -> Float64[Array, "8"]: +@function +def mo1(r: Float64[Array, "3"]) -> Complex128[Array, "8"]: # create 8 orbitals r = r[None, :] * jnp.arange(8)[:, None] return jnp.sum(jnp.sin(r), axis=-1) + jnp.sum(jnp.cos(r), axis=-1) * 1.j -def mo2(r: Float64[Array, "3"]) -> Float64[Array, "8"]: +@function +def mo2(r: Float64[Array, "3"]) -> Complex128[Array, "8"]: r = r[None, :] * jnp.arange(8)[:, None] * 2 return jnp.sum(jnp.sin(r), axis=-1) + jnp.sum(jnp.cos(r), axis=-1) * 1.j -def mo3(r: Float64[Array, "3"]) -> Float64[Array, "2 8"]: +@function +def mo3(r: Float64[Array, "3"]) -> Complex128[Array, "2 8"]: return jnp.stack([mo1(r), mo2(r)], axis=0) class _TestAgainstLibxc(parameterized.TestCase): - @parameterized.parameters(*lda, *gga, *mgga) - def test_unpol(self, name, maple_name): - self._test_impl(name, maple_name, 0, rho1, mo1) + @parameterized.parameters(*lda) + def test_lda(self, name, maple_name): + self._test_impl(name, maple_name, 0, rho1) + self._test_impl(name, maple_name, 1, rho3) + + @parameterized.parameters(*gga) + def test_gga(self, name, maple_name): + self._test_impl(name, maple_name, 0, rho1) + self._test_impl(name, maple_name, 1, rho3) - @parameterized.parameters(*lda, *gga, *mgga) - def test_pol(self, name, maple_name): + @parameterized.parameters(*mgga) + def test_mgga(self, name, maple_name): + self._test_impl(name, maple_name, 0, rho1, mo1) self._test_impl(name, maple_name, 1, rho3, mo3) @parameterized.parameters(*sensitive) @@ -161,7 +175,7 @@ def test_sensitive(self, name, maple_name): self._test_impl(name, maple_name, 0, rho1, mo1) self._test_impl(name, maple_name, 1, rho3, mo3) - def _test_impl(self, name, maple_name, polarized, rho, mo): + def _test_impl(self, name, maple_name, polarized, rho, mo=None): threshold = THRESHOLD.get(maple_name, 1e-14) batch = 100 r = jax.random.uniform( @@ -174,7 +188,10 @@ def _test_impl(self, name, maple_name, polarized, rho, mo): rho_r = jax.vmap(rho)(r) sigma_r = jax.vmap(partial(sigma, rho))(r) lapl_r = jax.vmap(partial(lapl, rho))(r) - tau_r = jax.vmap(partial(tau, rho, mo))(r) + if mo is not None: + tau_r = jax.vmap(partial(tau, rho, mo))(r) + else: + tau_r = None # libxc func = pylibxc.LibXCFunctional(name, int(polarized) + 1) @@ -200,11 +217,11 @@ def _test_impl(self, name, maple_name, polarized, rho, mo): # jax_xc experimental try: - from autofd.general_array import function + from autofd import function rho = function(rho) - epsilon_xc = getattr(jax_xc.experimental, name)(polarized) - energy_density = epsilon_xc(rho, mo) - res3_zk = jax.jit(jax.vmap(energy_density))(r) + args = (rho, mo) if mo is not None else (rho,) + epsilon_xc = getattr(jax_xc.experimental, name)(*args) + res3_zk = jax.jit(jax.vmap(epsilon_xc))(r) np.testing.assert_allclose( res3_zk, res1_zk, rtol=threshold, atol=threshold ) From 9fe3a70c512b8a225157bb88ad6cf6f85d21be5a Mon Sep 17 00:00:00 2001 From: mavenlin Date: Fri, 1 Dec 2023 21:12:56 +0800 Subject: [PATCH 17/23] Deprecate old python --- .github/workflows/release.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 04f41f9..65c8ddb 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -11,7 +11,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.9', '3.10', '3.11'] container: image: ghcr.io/sail-sg/jax-xc-image:latest steps: From 33952aa5eaecfe3e1539364025540ef7ae0930a5 Mon Sep 17 00:00:00 2001 From: mavenlin Date: Fri, 1 Dec 2023 21:13:35 +0800 Subject: [PATCH 18/23] Deprecate old python --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b0075ce..2b49414 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,6 +15,6 @@ jobs: - uses: actions/checkout@v2 - name: Test run: | - eval "$(pyenv init -)" && pyenv global 3.8-dev + eval "$(pyenv init -)" && pyenv global 3.9-dev pip install -r requirements.txt bazel test --test_output=all --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 //tests/... From 6ace64b29d0190552fcddcc2e68bf30b6d0ee86b Mon Sep 17 00:00:00 2001 From: mavenlin Date: Fri, 1 Dec 2023 23:22:10 +0800 Subject: [PATCH 19/23] Update test.yml --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 2b49414..bd968a1 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -15,6 +15,6 @@ jobs: - uses: actions/checkout@v2 - name: Test run: | - eval "$(pyenv init -)" && pyenv global 3.9-dev + eval "$(pyenv init -)" && pyenv global 3.11-dev pip install -r requirements.txt bazel test --test_output=all --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 //tests/... From ffc7d9a0af387dfc81d5b2d2deef4670436b6d65 Mon Sep 17 00:00:00 2001 From: mavenlin Date: Fri, 1 Dec 2023 23:26:07 +0800 Subject: [PATCH 20/23] Update test.yml --- .github/workflows/test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index bd968a1..0da960f 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -16,5 +16,5 @@ jobs: - name: Test run: | eval "$(pyenv init -)" && pyenv global 3.11-dev - pip install -r requirements.txt + pip install --upgrade -r requirements.txt bazel test --test_output=all --remote_cache=http://${{ secrets.BAZEL_CACHE }}:8080 //tests/... From fab7e4d2c3452c4bfba280ee27732b5ce9cd9105 Mon Sep 17 00:00:00 2001 From: mavenlin Date: Fri, 1 Dec 2023 23:30:59 +0800 Subject: [PATCH 21/23] Update .bazelrc --- .bazelrc | 1 - 1 file changed, 1 deletion(-) diff --git a/.bazelrc b/.bazelrc index 99d8320..8f6734f 100644 --- a/.bazelrc +++ b/.bazelrc @@ -1,4 +1,3 @@ build --copt=-g0 --copt=-O3 --copt=-DNDEBUG build --action_env=BAZEL_LINKLIBS=-l%:libstdc++.a:-lm build --action_env=BAZEL_LINKOPTS=-static-libgcc -build --incompatible_strict_action_env From 30a82f4bbd8d89102fe0c742e44b079dfb71bb9f Mon Sep 17 00:00:00 2001 From: mavenlin Date: Sat, 2 Dec 2023 10:36:19 +0800 Subject: [PATCH 22/23] Update test_against_libxc.py --- tests/test_against_libxc.py | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/tests/test_against_libxc.py b/tests/test_against_libxc.py index 1c79a04..aebd27b 100644 --- a/tests/test_against_libxc.py +++ b/tests/test_against_libxc.py @@ -84,7 +84,7 @@ ): if p.maple_name in THRESHOLD: sensitive.append((p.name, p.maple_name)) - elif p.name.startswith("mgga") or p.name.startswith("hyb_mgga"): + if p.name.startswith("mgga") or p.name.startswith("hyb_mgga"): mgga.append((p.name, p.maple_name)) elif p.name.startswith("gga") or p.name.startswith("hyb_gga"): gga.append((p.name, p.maple_name)) @@ -170,11 +170,6 @@ def test_mgga(self, name, maple_name): self._test_impl(name, maple_name, 0, rho1, mo1) self._test_impl(name, maple_name, 1, rho3, mo3) - @parameterized.parameters(*sensitive) - def test_sensitive(self, name, maple_name): - self._test_impl(name, maple_name, 0, rho1, mo1) - self._test_impl(name, maple_name, 1, rho3, mo3) - def _test_impl(self, name, maple_name, polarized, rho, mo=None): threshold = THRESHOLD.get(maple_name, 1e-14) batch = 100 From a3aaf5dda497fafd856a639867c91d9113358231 Mon Sep 17 00:00:00 2001 From: mavenlin Date: Sat, 2 Dec 2023 11:03:11 +0800 Subject: [PATCH 23/23] Update README.rst --- README.rst | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/README.rst b/README.rst index 604c649..f7a52dc 100644 --- a/README.rst +++ b/README.rst @@ -194,11 +194,12 @@ We support automatic functional derivative! import jax import jax_xc - import autofd - from autofd.general_array import general_shape + import autofd.operators as o + from autofd import function import jax.numpy as jnp from jaxtyping import Array, Float32 + @function def rho(r: Float32[Array, "3"]) -> Float32[Array, ""]: """Electron number density. We take gaussian as an example. @@ -214,21 +215,20 @@ We support automatic functional derivative! return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1)) # create a density functional - gga_xc_pbe = jax_xc.experimental.gga_x_pbe(polarized=False) + epsilon_xc = jax_xc.experimental.gga_x_pbe(rho) # a grid point in 3D r = jnp.array([0.1, 0.2, 0.3]) # pass rho and r to the functional to compute epsilon_xc (energy density) at r. # corresponding to the 'zk' in libxc - epsilon_xc = gga_xc_pbe(rho) - print(f"The function signature of epsilon_xc is {general_shape(epsilon_xc)}") + print(f"The function signature of epsilon_xc is {epsilon_xc}") energy_density = epsilon_xc(r) print(f"epsilon_xc(r) = {energy_density}") - vxc = jax.grad(lambda rho: autofd.operators.integrate(gga_xc_pbe(rho)))(rho) - print(f"The function signature of vxc is {general_shape(vxc)}") + vxc = jax.grad(lambda rho: o.integrate(rho * gga_xc_pbe(rho)))(rho) + print(f"The function signature of vxc is {vxc}") print(vxc(r))