diff --git a/README.rst b/README.rst index 604c649..f7a52dc 100644 --- a/README.rst +++ b/README.rst @@ -194,11 +194,12 @@ We support automatic functional derivative! import jax import jax_xc - import autofd - from autofd.general_array import general_shape + import autofd.operators as o + from autofd import function import jax.numpy as jnp from jaxtyping import Array, Float32 + @function def rho(r: Float32[Array, "3"]) -> Float32[Array, ""]: """Electron number density. We take gaussian as an example. @@ -214,21 +215,20 @@ We support automatic functional derivative! return jnp.prod(jax.scipy.stats.norm.pdf(r, loc=0, scale=1)) # create a density functional - gga_xc_pbe = jax_xc.experimental.gga_x_pbe(polarized=False) + epsilon_xc = jax_xc.experimental.gga_x_pbe(rho) # a grid point in 3D r = jnp.array([0.1, 0.2, 0.3]) # pass rho and r to the functional to compute epsilon_xc (energy density) at r. # corresponding to the 'zk' in libxc - epsilon_xc = gga_xc_pbe(rho) - print(f"The function signature of epsilon_xc is {general_shape(epsilon_xc)}") + print(f"The function signature of epsilon_xc is {epsilon_xc}") energy_density = epsilon_xc(r) print(f"epsilon_xc(r) = {energy_density}") - vxc = jax.grad(lambda rho: autofd.operators.integrate(gga_xc_pbe(rho)))(rho) - print(f"The function signature of vxc is {general_shape(vxc)}") + vxc = jax.grad(lambda rho: o.integrate(rho * gga_xc_pbe(rho)))(rho) + print(f"The function signature of vxc is {vxc}") print(vxc(r))