
Commit

refactor: use jax.tree module
wcxve committed Nov 18, 2024
1 parent caab33c commit 2e01322
Showing 5 changed files with 39 additions and 39 deletions.
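
The change is mechanical: every call to the deprecated top-level alias jax.tree_map becomes jax.tree.map from the jax.tree namespace (the two behave identically), and the one use of jax.tree_util.tree_flatten becomes jax.tree.flatten. A minimal sketch of the rename, using a hypothetical pytree for illustration:

    import jax
    import jax.numpy as jnp

    # Hypothetical parameter pytree, for illustration only.
    params = {'alpha': jnp.array(1.5), 'beta': jnp.array(-0.5)}

    # Before (deprecated alias):
    #     scaled = jax.tree_map(lambda x: -2.0 * x, params)
    # After (jax.tree namespace, same semantics), as used throughout this commit:
    scaled = jax.tree.map(lambda x: -2.0 * x, params)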
2 changes: 1 addition & 1 deletion src/elisa/infer/fit.py
@@ -201,7 +201,7 @@ def _optimize_ns(self, max_steps=131072, verbose=False) -> JAXArray:
samples[f'{i}_loglike'].sum(axis=-1) for i in self._data.keys()
]
mle_idx = np.sum(loglike, axis=0).argmax()
-mle = jax.tree_map(lambda s: s[mle_idx], samples)
+mle = jax.tree.map(lambda s: s[mle_idx], samples)
mle = {i: mle[i] for i in self._helper.params_names['free']}
return self._helper.constr_dic_to_unconstr_arr(mle)

28 changes: 14 additions & 14 deletions src/elisa/infer/helper.py
@@ -368,7 +368,7 @@ def numpyro_model(predictive: bool = False) -> None:
numpyro.deterministic(pid_to_pname[pid], fn(params_id_values))

# the likelihood between observation and model for each dataset
-jax.tree_map(
+jax.tree.map(
lambda f: f(params_name_values, predictive=predictive),
likelihood,
)
@@ -581,9 +581,9 @@ def deviance(unconstr_arr: JAXArray) -> dict:
"""
loglike_dic = loglike(unconstr_arr)
neg_double = jax.jit(lambda x: -2.0 * x)
-point = jax.tree_map(neg_double, loglike_dic['point'])
-group = jax.tree_map(neg_double, loglike_dic['group'])
-total = jax.tree_map(neg_double, loglike_dic['total'])
+point = jax.tree.map(neg_double, loglike_dic['point'])
+group = jax.tree.map(neg_double, loglike_dic['group'])
+total = jax.tree.map(neg_double, loglike_dic['total'])
return {'total': total, 'group': group, 'point': point}

def deviance_total(unconstr_arr: JAXArray) -> JAXFloat:
@@ -633,15 +633,15 @@ def fit_once(i: int, args: tuple) -> tuple:

# update best fit params to result
params = sites['params']
-result['params'] = jax.tree_map(
+result['params'] = jax.tree.map(
lambda x, y: x.at[i].set(y),
result['params'],
params,
)

# update the best fit model to result
models = sites['models']
-result['models'] = jax.tree_map(
+result['models'] = jax.tree.map(
lambda x, y: x.at[i].set(y),
result['models'],
{k: models[k] for k in result['models']},
@@ -650,12 +650,12 @@
# update the deviance information to result
dev = new_deviance(fitted_params)
res_dev = result['deviance']
-res_dev['group'] = jax.tree_map(
+res_dev['group'] = jax.tree.map(
lambda x, y: x.at[i].set(y),
res_dev['group'],
dev['group'],
)
-res_dev['point'] = jax.tree_map(
+res_dev['point'] = jax.tree.map(
lambda x, y: x.at[i].set(y),
res_dev['point'],
dev['point'],
@@ -719,12 +719,12 @@ def sim_parallel_fit(
fit_pmap = jax.pmap(lambda *args: lax.fori_loop(0, batch, fn, args)[1])
reshape = lambda x: x.reshape((n_parallel, -1) + x.shape[1:])
result = fit_pmap(
-jax.tree_map(reshape, sim_data),
-jax.tree_map(reshape, result),
-jax.tree_map(reshape, init),
+jax.tree.map(reshape, sim_data),
+jax.tree.map(reshape, result),
+jax.tree.map(reshape, init),
)

-return jax.tree_map(jnp.concatenate, result)
+return jax.tree.map(jnp.concatenate, result)

def simulate_and_fit(
seed: int,
@@ -768,7 +768,7 @@ def simulate_and_fit(
The simulation and fitting result.
"""
seed = int(seed)
-free_params = jax.tree_map(jnp.array, free_params)
+free_params = jax.tree.map(jnp.array, free_params)
model_values = {
f'{k}_model': model_values[f'{k}_model'] for k in simulators
}
@@ -779,7 +779,7 @@ def simulate_and_fit(
assert n > 0

# check if all params shapes are the same
-shapes = list(jax.tree_map(jnp.shape, free_params).values())
+shapes = list(jax.tree.map(jnp.shape, free_params).values())
assert all(i == shapes[0] for i in shapes)

# TODO: support posterior prediction with n > 1
22 changes: 11 additions & 11 deletions src/elisa/infer/results.py
@@ -88,7 +88,7 @@ def _flux(
for name, fn in fns.items():
f = fn(egrid, params, comps)
if comps:
-flux[name] = jax.tree_map(
+flux[name] = jax.tree.map(
lambda v: jnp.sum(v * de, axis=-1), f
)
else:
@@ -450,15 +450,15 @@ def boot(
'Bootstrap',
)
valid = result.pop('valid')
-result = jax.tree_map(lambda x: x[valid], result)
+result = jax.tree.map(lambda x: x[valid], result)

self._boot = BootstrapResult(
mle={k: v[0] for k, v in self.mle.items()},
data=result['data'],
models=result['models'],
params=result['params'],
deviance=result['deviance'],
-p_value=jax.tree_map(
+p_value=jax.tree.map(
lambda obs, sim: np.sum(sim >= obs, axis=0) / len(sim),
self._deviance,
result['deviance'],
@@ -2027,7 +2027,7 @@ def ppc(
'PPC',
)
valid = result.pop('valid')
-result = jax.tree_map(lambda x: x[valid], result)
+result = jax.tree.map(lambda x: x[valid], result)

self._ppc = PPCResult(
params_rep=params,
@@ -2036,7 +2036,7 @@
params_fit=result['params'],
models_fit=result['models'],
deviance=result['deviance'],
-p_value=jax.tree_map(
+p_value=jax.tree.map(
lambda obs, sim: np.sum(sim >= obs, axis=0) / len(sim),
self._mle['deviance'],
result['deviance'],
@@ -2084,7 +2084,7 @@ def _mle(self):
# drop unnecessary terms
loglike.pop('data')
loglike.pop('channels')
-mle_result['deviance'] = jax.tree_map(lambda x: -2.0 * x, loglike)
+mle_result['deviance'] = jax.tree.map(lambda x: -2.0 * x, loglike)

# model values at MLE
mle_result['models'] = sites['models']
@@ -2173,7 +2173,7 @@ def _init_from_jaxns(self, sampler: NestedSampler):
# get posterior samples
total = result.total_num_samples
rng_key = jax.random.PRNGKey(helper.seed['mcmc'])
-samples = jax.tree_map(
+samples = jax.tree.map(
lambda x: x[None, ...],
sampler.get_samples(rng_key, total),
)
@@ -2202,7 +2202,7 @@ def _init_from_ultranest(self, sampler: ReactiveNestedSampler):
ndrop = nsamples % ncores

# get posterior samples
-samples = jax.tree_map(lambda x: x[None, : nsamples - ndrop], result)
+samples = jax.tree.map(lambda x: x[None, : nsamples - ndrop], result)

# attrs for each group of arviz.InferenceData
attrs = {
@@ -2230,7 +2230,7 @@ def _init_from_nautilus(self, sampler: Sampler):
ncores = jax.local_device_count()

# get posterior samples
-samples = jax.tree_map(
+samples = jax.tree.map(
lambda x: x[None, : len(x) - len(x) % ncores],
result,
)
@@ -2254,7 +2254,7 @@ def _init_from_nautilus(self, sampler: Sampler):
self._lnZ = (float(sampler.log_z), None)

def _generate_idata(self, samples, attrs, sample_stats=None):
-samples = jax.tree_map(jax.device_get, samples)
+samples = jax.tree.map(jax.device_get, samples)
helper = self._helper

params = helper.get_params(samples)
Expand Down Expand Up @@ -2857,5 +2857,5 @@ class PPCResult(NamedTuple):

def _format_result(result: dict, order: Sequence[str]) -> dict:
"""Sort the result and use float type."""
-formatted = jax.tree_map(float, result)
+formatted = jax.tree.map(float, result)
return {k: formatted[k] for k in order}
22 changes: 11 additions & 11 deletions src/elisa/models/model.py
@@ -209,7 +209,7 @@ def _compile_model_fn(self, model_info: ModelInfo) -> ModelCompiledFn:
@jax.jit
def fn(egrid: JAXArray, params: ParamIDValMapping) -> JAXArray:
"""The model evaluation function"""
-comps_params = jax.tree_map(lambda f: f(params), pid_to_value)
+comps_params = jax.tree.map(lambda f: f(params), pid_to_value)
return eval_fn(egrid, comps_params)

for integrate in model_info.integrate.values():
@@ -374,7 +374,7 @@ def _prepare_eval(self, params: ArrayLike | Sequence | Mapping | None):
# missing = set(self.params_name) - set(params)
# raise ValueError(f'missing parameters: {", ".join(missing)}')

-params = jax.tree_map(jnp.asarray, params)
+params = jax.tree.map(jnp.asarray, params)
params = self._value_mapping_to_params(params)

elif params is None:
@@ -385,8 +385,8 @@ def _prepare_eval(self, params: ArrayLike | Sequence | Mapping | None):

fn = self._fn
add_fn = self._additive_fn
-shapes = jax.tree_util.tree_flatten(
-tree=jax.tree_map(jnp.shape, params),
+shapes = jax.tree.flatten(
+tree=jax.tree.map(jnp.shape, params),
is_leaf=lambda i: isinstance(i, tuple),
)[0]

@@ -467,7 +467,7 @@ def ne(
if comps:
_, additive_fn, params = self._prepare_eval(params)
comps_value = additive_fn(egrid, params)
-ne = jax.tree_map(lambda v: v / de, comps_value)
+ne = jax.tree.map(lambda v: v / de, comps_value)
else:
ne = self.eval(egrid, params) / de

@@ -510,7 +510,7 @@ def ene(
ne = self.ne(egrid, params, comps)

if comps:
-ene = jax.tree_map(lambda v: factor * v, ne)
+ene = jax.tree.map(lambda v: factor * v, ne)
else:
ene = factor * ne

@@ -553,7 +553,7 @@ def eene(
ne = self.ne(egrid, params, comps)

if comps:
-eene = jax.tree_map(lambda v: factor * v, ne)
+eene = jax.tree.map(lambda v: factor * v, ne)
else:
eene = factor * ne

@@ -603,7 +603,7 @@ def ce(
fn = jax.jit(lambda v: (v * de) @ resp_matrix / channel_width)

if comps:
-return jax.tree_map(fn, ne)
+return jax.tree.map(fn, ne)
else:
return fn(ne)

@@ -668,7 +668,7 @@ def flux(
fn = jax.jit(lambda v: jnp.sum(v * de, axis=-1))

if comps:
-return jax.tree_map(fn, f)
+return jax.tree.map(fn, f)
else:
return fn(f)

@@ -736,7 +736,7 @@ def lumin(
to_lumin = lambda x: (x * flux_unit * factor).to('erg s^-1')

if comps:
-return jax.tree_map(to_lumin, flux)
+return jax.tree.map(to_lumin, flux)
else:
return to_lumin(flux)

@@ -801,7 +801,7 @@ def eiso(
to_eiso = lambda x: (x * factor).to('erg')

if comps:
-return jax.tree_map(to_eiso, lumin)
+return jax.tree.map(to_eiso, lumin)
else:
return to_eiso(lumin)

4 changes: 2 additions & 2 deletions src/elisa/util/misc.py
@@ -11,7 +11,7 @@
import jax
import jax.numpy as jnp
from astropy.units import Unit
-from jax import lax, tree_util
+from jax import lax
from jax.custom_derivatives import SymbolicZero
from jax.experimental import io_callback
from jax.flatten_util import ravel_pytree
@@ -193,7 +193,7 @@ def fdjvp(primals, tangents):

primals_out = fn(egrid, params)

-tvals, _ = tree_util.tree_flatten(params_tangent)
+tvals, _ = jax.tree.flatten(params_tangent)
if any(jnp.shape(v) != () for v in tvals):
raise NotImplementedError(
'JVP for non-scalar parameter is not implemented'
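The misc.py hunks make the same move for flattening: jax.tree.flatten returns the same (leaves, treedef) pair as jax.tree_util.tree_flatten, so dropping the tree_util import changes nothing in behavior. A small equivalence sketch with an illustrative pytree:

    import jax

    tree = {'a': 1.0, 'b': (2.0, 3.0)}

    leaves, treedef = jax.tree.flatten(tree)       # leaves == [1.0, 2.0, 3.0]
    leaves2, _ = jax.tree_util.tree_flatten(tree)  # identical result
    assert leaves == leaves2
    assert jax.tree.unflatten(treedef, leaves) == tree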
