diff --git a/.github/dependabot.yml b/.github/dependabot.yml index adee0ed1..8ac6b8c4 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -3,4 +3,4 @@ updates: - package-ecosystem: "github-actions" directory: "/" schedule: - interval: "monthly" \ No newline at end of file + interval: "monthly" diff --git a/.readthedocs.yaml b/.readthedocs.yaml index 8a7df556..d8643793 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -21,4 +21,4 @@ python: - method: pip path: . extra_requirements: - - docs \ No newline at end of file + - docs diff --git a/docs/_templates/breadcrumbs.html b/docs/_templates/breadcrumbs.html index 339f008b..4ecb013f 100644 --- a/docs/_templates/breadcrumbs.html +++ b/docs/_templates/breadcrumbs.html @@ -1,4 +1,4 @@ {%- extends "sphinx_rtd_theme/breadcrumbs.html" %} {% block breadcrumbs_aside %} -{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/docs/conf.py b/docs/conf.py index f7c7ea4a..21eb0162 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -29,7 +29,7 @@ 'sphinx.ext.mathjax', 'sphinx.ext.napoleon', 'sphinx.ext.viewcode', - 'numpydoc.numpydoc' + 'numpydoc.numpydoc', ] templates_path = ['_templates'] diff --git a/noxfile.py b/noxfile.py index 7a34d8f7..8e7d1584 100644 --- a/noxfile.py +++ b/noxfile.py @@ -1,9 +1,9 @@ import nox -PYTHON_VERSIONS = ["3.9", "3.10", "3.11"] +PYTHON_VERSIONS = ['3.9', '3.10', '3.11'] @nox.session(python=PYTHON_VERSIONS) def test(session: nox.Session) -> None: - session.install(".[test]") - session.run("pytest", "--cov-report=xml", "--cov=elisa", *session.posargs) + session.install('.[test]') + session.run('pytest', '--cov-report=xml', '--cov=elisa', *session.posargs) diff --git a/src/elisa/infer/analysis.py b/src/elisa/infer/analysis.py index 60e42f66..59507686 100644 --- a/src/elisa/infer/analysis.py +++ b/src/elisa/infer/analysis.py @@ -1,7 +1,8 @@ """Subsequent analysis of likelihood or Bayesian fit.""" from __future__ import annotations -from typing import Literal, NamedTuple, Optional, Sequence +from collections.abc import Sequence +from typing import Literal, NamedTuple, Optional import arviz as az import jax @@ -43,11 +44,7 @@ class BootstrapResult(NamedTuple): class MLEResult: """MLE result obtained from likelihood fit.""" - def __init__( - self, - minuit: Minuit, - fit: _fit.LikelihoodFit - ): + def __init__(self, minuit: Minuit, fit: _fit.LikelihoodFit): self._minuit = minuit self._helper = helper = fit._helper self._free_names = free_names = fit._free_names @@ -86,8 +83,8 @@ def __init__( 'deviance': { 'total': stat_total, 'group': stat_group, - 'point': stat_info['point'] - } + 'point': stat_info['point'], + }, } k = len(free_names) @@ -101,7 +98,7 @@ def __init__( def __repr__(self): tab = make_pretty_table( ['Parameter', 'Value', 'Error'], - [(k, f'{v[0]:.4g}', f'{v[1]:.4g}') for k, v in self._mle.items()] + [(k, f'{v[0]:.4g}', f'{v[1]:.4g}') for k, v in self._mle.items()], ) s = 'MLE:\n' + tab.get_string() + '\n' @@ -244,10 +241,7 @@ def ci( cl_ = 1.0 - 2.0 * norm.sf(cl) if cl >= 1.0 else cl - mle = { - k: v for k, v in self._result['params'].items() - if k in params - } + mle = {k: v for k, v in self._result['params'].items() if k in params} helper = self._helper @@ -257,8 +251,7 @@ def ci( mle0 = self._minuit.values.to_dict() others = { # set other unconstrained free parameter to mle - i: mle0[i] - for i in (set(mle0.keys()) - set(free_params)) + i: mle0[i] for i in (set(mle0.keys()) - set(free_params)) } ci = self._minuit.merrors @@ -284,12 +277,13 @@ def ci( # confidence interval 
of function of parameters,
         # see, e.g. https://doi.org/10.1007/s11222-021-10012-y
         for p in composite_params:
+
             def loss(x):
                 """The loss when calculating CI of composite parameter."""
                 unconstr = {k: v for k, v in zip(self._free_names, x[1:])}
                 p0 = helper.to_params_dict(unconstr)[p]
                 diff = (p0 / x[0] - 1) / 1e-3
-                return helper.deviance_unconstr(x[1:]) + diff*diff
+                return helper.deviance_unconstr(x[1:]) + diff * diff
 
             mle_p = mle[p]
 
@@ -317,8 +311,10 @@ def loss(x):
         else:
             boot_result = self.boot(n=n)
             interval = jax.tree_map(
-                lambda x: tuple(np.quantile(x, q=(0.5 - cl_/2, 0.5 + cl_/2))),
-                {k: v for k, v in boot_result.params.items() if k in params}
+                lambda x: tuple(
+                    np.quantile(x, q=(0.5 - cl_ / 2, 0.5 + cl_ / 2))
+                ),
+                {k: v for k, v in boot_result.params.items() if k in params},
             )
             error = {
                 k: (interval[k][0] - mle[k], interval[k][1] - mle[k])
@@ -327,7 +323,7 @@
             status = {
                 'n': boot_result.n,
                 'n_valid': boot_result.n_valid,
-                'seed': boot_result.seed
+                'seed': boot_result.seed,
             }
 
         else:
@@ -344,14 +340,11 @@ def format_result(v):
             error=format_result(error),
             cl=cl_,
             method=method,
-            status=status
+            status=status,
         )
 
     def boot(
-        self,
-        n: int = 10000,
-        parallel: bool = True,
-        seed: Optional[int] = None
+        self, n: int = 10000, parallel: bool = True, seed: Optional[int] = None
     ) -> BootstrapResult:
         """Parametric bootstrap.
 
@@ -389,7 +382,7 @@ def boot(
             n,
             self._seed,
             parallel,
-            run_str='Bootstrap'
+            run_str='Bootstrap',
         )
 
         boot_result = BootstrapResult(
@@ -401,7 +394,7 @@
             n=n,
             n_valid=result['valid'].sum(),
             seed=self._seed,
-            results=boot_result
+            results=boot_result,
         )
 
         self._boot = boot_result
@@ -418,6 +411,7 @@ def plot_corner(self):
 
 class CredibleInterval(NamedTuple):
     """Credible interval result."""
+
     mle: dict[str, float]
     median: dict[str, float]
     interval: dict[str, tuple[float, float]]
@@ -428,6 +422,7 @@ class CredibleInterval(NamedTuple):
 
 class PPCResult(NamedTuple):
     """Posterior predictive check result."""
+
     ...
 
 
@@ -441,7 +436,7 @@ def __init__(
         self,
         reff: float,
         lnZ: tuple[float, float],
         sampler,
-        fit: _fit.BayesianFit
+        fit: _fit.BayesianFit,
     ):
         self._idata = idata
         self._ess = ess
@@ -621,12 +616,9 @@ def ci(
         ...
 
 
     def ppc(
-        self,
-        n: int = 10000,
-        parallel: bool = True,
-        seed: Optional[int] = None
+        self, n: int = 10000, parallel: bool = True, seed: Optional[int] = None
     ) -> PPCResult:
         """Perform posterior predictive check. 
diff --git a/src/elisa/infer/fit.py b/src/elisa/infer/fit.py index 8966b9f4..1d82c3d9 100644 --- a/src/elisa/infer/fit.py +++ b/src/elisa/infer/fit.py @@ -3,8 +3,9 @@ import time from abc import ABC, abstractmethod +from collections.abc import Sequence from functools import partial, reduce -from typing import Callable, Literal, NamedTuple, Optional, Sequence, TypeVar +from typing import Callable, Literal, NamedTuple, Optional, TypeVar import arviz as az import jax @@ -19,16 +20,19 @@ from numpyro import handlers from numpyro.infer import MCMC, NUTS from numpyro.infer.initialization import init_to_value -from numpyro.infer.util import constrain_fn, unconstrain_fn, log_likelihood +from numpyro.infer.util import constrain_fn, log_likelihood, unconstrain_fn from prettytable import PrettyTable from ..data.ogip import Data from ..model.base import Model from .analysis import MLEResult, PosteriorResult -from .likelihood import chi2, cstat, pstat, pgstat, wstat +from .likelihood import chi2, cstat, pgstat, pstat, wstat from .nested_sampling import NestedSampler from .util import ( - make_pretty_table, order_composite, progress_bar_factory, replace_string + make_pretty_table, + order_composite, + progress_bar_factory, + replace_string, ) __all__ = ['LikelihoodFit', 'BayesianFit'] @@ -58,6 +62,7 @@ class BaseFit(ABC): Random number generator seed. The default is 42. """ + # TODO: introduce background model that directly fit the background data _stat_options: set[str] = { @@ -65,17 +70,14 @@ class BaseFit(ABC): 'cstat', 'pstat', 'pgstat', - 'wstat' + 'wstat', # It should be noted that 'lstat' does not have long run coverage # property for source estimation, which is probably due to the choice # of conjugate prior of Poisson background data. # 'lstat' will be included here with a proper prior at some point. 
} - _stat_with_back: set[str] = { - 'pgstat', - 'wstat' - } + _stat_with_back: set[str] = {'pgstat', 'wstat'} def __init__( self, @@ -106,9 +108,18 @@ def __init__( info = { i: reduce(lambda x, y: x | y, (j[i] for j in info_list)) for i in [ - 'sample', 'composite', - 'default', 'min', 'max', 'dist_expr', - 'params', 'mname', 'mfmt', 'mpfmt', 'pname', 'pfmt' + 'sample', + 'composite', + 'default', + 'min', + 'max', + 'dist_expr', + 'params', + 'mname', + 'mfmt', + 'mpfmt', + 'pname', + 'pfmt', ] } @@ -133,8 +144,7 @@ def __init__( fmt_suffix[comp] = r'^\mathrm{' + joined_name + '}' info['mname'] = { - k: v + name_suffix[k] - for k, v in info['mname'].items() + k: v + name_suffix[k] for k, v in info['mname'].items() } info['mfmt'] = { # k: r'$\big[' + v + fmt_suffix[k] + r'\big]$' @@ -166,8 +176,7 @@ def __init__( ] _name = [info['pname'][i] for i in aux_params] _name_suffix = [ - '' if (n := _name[:i+1].count(name)) == 1 - else str(n) + '' if (n := _name[: i + 1].count(name)) == 1 else str(n) for i, name in enumerate(_name) ] aux_params_name = { @@ -175,8 +184,7 @@ def __init__( } _fmt = [info['pfmt'][i] for i in aux_params] _fmt_suffix = [ - '' if (n := _fmt[:i+1].count(fmt)) == 1 - else '_{%s}' % n + '' if (n := _fmt[: i + 1].count(fmt)) == 1 else '_{%s}' % n for i, fmt in enumerate(_fmt) ] aux_params_fmt = { @@ -196,7 +204,7 @@ def __init__( # model information will be displayed self._model_info = replace_string( id_mapping, - {k: (v._node.name, self._stat[k]) for k, v in self._model.items()} + {k: (v._node.name, self._stat[k]) for k, v in self._model.items()}, ) # model information will be displayed @@ -222,7 +230,12 @@ def __init__( pidx = '' params_info[f'{comp}_{i}'] = [ - pidx, comp, i, default_dic[j], bound, dist_dic[j] + pidx, + comp, + i, + default_dic[j], + bound, + dist_dic[j], ] else: # composite parameter @@ -232,9 +245,7 @@ def __init__( if v != f'{comp}_{i}' } expr = replace_string(mapping, params_dic[comp][i]) - params_info[f'{comp}_{i}'] = [ - '', comp, i, expr, '', '' - ] + params_info[f'{comp}_{i}'] = ['', comp, i, expr, '', ''] for k, v in self._info['sample'].items(): if k in params_info: @@ -248,15 +259,12 @@ def __init__( bound = '' pidx = '' - params_info[k] = [ - pidx, '', k, default_dic[k], bound, dist_dic[k] - ] + params_info[k] = [pidx, '', k, default_dic[k], bound, dist_dic[k]] self._params_info = params_info # parameters Tex format self._params_fmt = replace_string( - id_mapping, - params_fmt | aux_params_fmt + id_mapping, params_fmt | aux_params_fmt ) # parameters of spectral model function @@ -268,13 +276,15 @@ def __init__( # free parameters self._free = { - k: v for k, v in self._info['sample'].items() + k: v + for k, v in self._info['sample'].items() if not isinstance(v, float) } # fixed parameters self._fixed = { - k: v for k, v in self._info['sample'].items() + k: v + for k, v in self._info['sample'].items() if isinstance(v, float) } @@ -291,7 +301,8 @@ def __init__( # ordered parameters of interest, # which are directly input to model and not fixed self._interest_names = tuple( - k for k, v in self._params_info.items() + k + for k, v in self._params_info.items() if v[1] and k not in self._fixed ) @@ -345,9 +356,10 @@ def _sanity_check( self, data: Data | Sequence[Data], model: Model | Sequence[Model], - stat: Statistic | Sequence[Statistic] + stat: Statistic | Sequence[Statistic], ): """Check if data, model, and stat are correct and return lists.""" + def get_list(inputs, name, itype, tname): """Check the model/data/stat, and return a list.""" if 
isinstance(inputs, itype): @@ -466,11 +478,11 @@ def _repr_html_(self) -> str: def _make_info_table(self) -> None: self._tab1 = make_pretty_table( ['Data', 'Model', 'Statistic'], - list((k, *v) for k, v in self._model_info.items()) + list((k, *v) for k, v in self._model_info.items()), ) self._tab2 = make_pretty_table( ['No.', 'Component', 'Parameter', 'Value', 'Bound'], - list(i[:-1] for i in self._params_info.values()) + list(i[:-1] for i in self._params_info.values()), ) def mle( @@ -478,7 +490,7 @@ def mle( init: Optional[dict[str, float]] = None, lopt: Literal['minuit', 'lm'] = 'minuit', strategy: Literal[0, 1, 2] = 1, - gopt: Optional[str] = None + gopt: Optional[str] = None, ) -> MLEResult: """Find the Maximum Likelihood Estimation (MLE) for the model. @@ -526,7 +538,7 @@ def mle( ), termination_kwargs=dict( live_evidence_frac=1e-5, - ) + ), ) t0 = time.time() @@ -549,10 +561,11 @@ def mle( init_unconstr = self._init_unconstr if lopt == 'lm': - res = jax.jit(jaxopt.LevenbergMarquardt( - self._helper.residual, - stop_criterion='grad-l2-norm' - ).run)(jnp.array(self._init_unconstr)) + res = jax.jit( + jaxopt.LevenbergMarquardt( + self._helper.residual, stop_criterion='grad-l2-norm' + ).run + )(jnp.array(self._init_unconstr)) init_unconstr = res.params elif lopt != 'minuit': raise ValueError(f'invalid local optimization method {lopt}') @@ -561,7 +574,7 @@ def mle( self._helper.deviance_unconstr, np.array(init_unconstr), grad=self._helper.deviance_unconstr_grad, - name=self._free_names + name=self._free_names, ) # TODO: use simplex to "polish" the initial guess? @@ -601,11 +614,11 @@ def _repr_html_(self) -> str: def _make_info_table(self) -> None: self._tab1 = make_pretty_table( ['Data', 'Model', 'Statistic'], - list((k, *v) for k, v in self._model_info.items()) + list((k, *v) for k, v in self._model_info.items()), ) self._tab2 = make_pretty_table( ['No.', 'Component', 'Parameter', 'Value', 'Prior'], - list(i[:4] + i[-1:] for i in self._params_info.values()) + list(i[:4] + i[-1:] for i in self._params_info.values()), ) def nuts( @@ -615,7 +628,7 @@ def nuts( chains: Optional[int] = None, init: Optional[dict[str, float]] = None, progress: bool = True, - nuts_kwargs: Optional[dict] = None + nuts_kwargs: Optional[dict] = None, ) -> PosteriorResult: """Run the No-U-Turn Sampler (NUTS) of :mod:`numpyro`. @@ -663,7 +676,7 @@ def nuts( dense_mass=dense_mass, max_tree_depth=max_tree_depth, init_strategy=init_to_value(values=init), - **nuts_kwargs + **nuts_kwargs, ), num_warmup=warmup, num_samples=samples, @@ -685,7 +698,7 @@ def ns( num_parallel_workers: Optional[int] = None, difficult_model: bool = False, parameter_estimation: bool = False, - term_cond: dict = None + term_cond: dict = None, ) -> PosteriorResult: """Run the Nested Sampler of :mod:`jaxns`. 
@@ -760,25 +773,18 @@ def _generate_result(self, sampler) -> PosteriorResult: if not isinstance(sampler, (MCMC, NestedSampler)): raise ValueError(f'unknown sampler type {type(sampler)}') - coords = { - f'{k}_channel': v.channel - for k, v in self._data.items() - } + coords = {f'{k}_channel': v.channel for k, v in self._data.items()} - dims = { - f'{k}_Non': [f'{k}_channel'] - for k in self._data.keys() - } | { - f'{k}_Noff': [f'{k}_channel'] - for k in self._data.keys() - if self._stat[k] in self._stat_with_back + dims = {f'{k}_Non': [f'{k}_channel'] for k in self._data.keys()} | { + f'{k}_Noff': [f'{k}_channel'] + for k in self._data.keys() + if self._stat[k] in self._stat_with_back } if isinstance(sampler, MCMC): # numpyro sampler idata = az.from_numpyro(sampler, coords=coords, dims=dims) ess = { - k: int(v.values) - for k, v in az.ess(idata).data_vars.items() + k: int(v.values) for k, v in az.ess(idata).data_vars.items() } # the calculation of reff is according to arviz: @@ -811,8 +817,7 @@ def _generate_result(self, sampler) -> PosteriorResult: # get observation data observed_data = { - f'{k}_Non': v.spec_counts - for k, v in self._data.items() + f'{k}_Non': v.spec_counts for k, v in self._data.items() } | { f'{k}_Noff': v.back_counts for k, v in self._data.items() @@ -824,7 +829,7 @@ def _generate_result(self, sampler) -> PosteriorResult: log_likelihood=log_like, observed_data=observed_data, coords=coords, - dims=dims + dims=dims, ) ess = {'total': int(result.ESS)} reff = float(result.ESS / result.total_num_samples) @@ -838,8 +843,9 @@ def _generate_result(self, sampler) -> PosteriorResult: for k, v in self._data.items(): # channel-wise log likelihood of data group if f'{k}_Noff' in ln_likelihood: - ln_likelihood[k] = ln_likelihood[f'{k}_Non'] \ - + ln_likelihood[f'{k}_Noff'] + ln_likelihood[k] = ( + ln_likelihood[f'{k}_Non'] + ln_likelihood[f'{k}_Noff'] + ) else: ln_likelihood[k] = ln_likelihood[f'{k}_Non'] @@ -849,13 +855,13 @@ def _generate_result(self, sampler) -> PosteriorResult: # channel-wise log likelihood ln_likelihood['channels'] = ( ('chain', 'draw', 'channel'), - np.concatenate([ln_likelihood[i] for i in self._data], axis=-1) + np.concatenate([ln_likelihood[i] for i in self._data], axis=-1), ) # channel-wise net counts observation['channels'] = ( ('channel',), - np.concatenate([observation[i] for i in self._data], axis=-1) + np.concatenate([observation[i] for i in self._data], axis=-1), ) # total log likelihood @@ -872,6 +878,7 @@ def _generate_result(self, sampler) -> PosteriorResult: class HelperFn(NamedTuple): """A collection of helper functions.""" + numpyro_model: Callable to_dict: Callable to_constr_dict: Callable @@ -924,7 +931,11 @@ def model_counts(constr_dict: dict) -> dict: p = params_by_group(constr_dict) return jax.tree_map( lambda mi, pi, ei, ri, ti: mi(ei, pi) @ ri * ti, - spec_model, p, egrid, resp, expo + spec_model, + p, + egrid, + resp, + expo, ) # ========================= create numpyro model ========================== @@ -943,11 +954,9 @@ def model_counts(constr_dict: dict) -> dict: def numpyro_model(predictive=False): """The numpyro model.""" params = { - name: numpyro.sample(name, dist) - for name, dist in free.items() + name: numpyro.sample(name, dist) for name, dist in free.items() } | { - k: numpyro.deterministic(k, jnp.array(v)) - for k, v in fixed.items() + k: numpyro.deterministic(k, jnp.array(v)) for k, v in fixed.items() } for name, (arg_names, fn) in composite.items(): args = (params[arg_name] for arg_name in arg_names) @@ -959,7 +968,8 @@ 
def numpyro_model(predictive=False): jax.tree_map( lambda f, m: f(m, predictive=predictive), - stat_fn, model_counts(params) + stat_fn, + model_counts(params), ) # ============================ other functions ============================ @@ -1002,7 +1012,7 @@ def to_constr_dict(unconstr_array: Sequence) -> dict: model=numpyro_model, model_args=(), model_kwargs={}, - params=to_dict(unconstr_array) + params=to_dict(unconstr_array), ) @jax.jit @@ -1012,7 +1022,7 @@ def to_unconstr_dict(constr_array: Sequence) -> dict: model=numpyro_model, model_args=(), model_kwargs={}, - params=to_dict(constr_array) + params=to_dict(constr_array), ) @jax.jit @@ -1032,10 +1042,7 @@ def deviance_unconstr(unconstr_array: Sequence) -> float: p = to_constr_dict(unconstr_array) return -2.0 * jax.tree_util.tree_reduce( lambda x, y: x + y, - jax.tree_map( - lambda x: x.sum(), - log_likelihood(numpyro_model, p) - ) + jax.tree_map(lambda x: x.sum(), log_likelihood(numpyro_model, p)), ) # deviance_unconstr_info will be used in simulation, @@ -1048,8 +1055,7 @@ def deviance_unconstr_info(unconstr_array: Sequence) -> dict: deviance = jax.tree_map(lambda x: -2.0 * x, log_like) group = { - k: sum(deviance[i].sum() for i in v) - for k, v in group_name.items() + k: sum(deviance[i].sum() for i in v) for k, v in group_name.items() } point = { @@ -1067,7 +1073,7 @@ def to_params_dict(unconstr_dict: dict) -> dict: model_args=(), model_kwargs={}, params=unconstr_dict, - return_deterministic=True + return_deterministic=True, ) @jax.jit @@ -1085,8 +1091,7 @@ def unconstr_covar(unconstr_array: Sequence) -> jnp.ndarray: @jax.jit def params_covar( - unconstr_array: Sequence, - cov_unconstr: Sequence + unconstr_array: Sequence, cov_unconstr: Sequence ) -> jnp.ndarray: """Covariance matrix in constrained space.""" jac = jax.jacobian(to_params_array)(unconstr_array) @@ -1112,16 +1117,16 @@ def sim_result_container(n: int): 'stat_rep': { 'total': jnp.empty(n), 'group': {k: jnp.empty(n) for k in data_names}, - 'point': {k: jnp.empty((n, v)) for k, v in ndata.items()} + 'point': {k: jnp.empty((n, v)) for k, v in ndata.items()}, }, 'params_fit': {k: jnp.empty(n) for k in params_names}, 'model_fit': {k: jnp.empty((n, v)) for k, v in ndata.items()}, 'stat_fit': { 'total': jnp.empty(n), 'group': {k: jnp.empty(n) for k in data_names}, - 'point': {k: jnp.empty((n, v)) for k, v in ndata.items()} + 'point': {k: jnp.empty((n, v)) for k, v in ndata.items()}, }, - 'valid': jnp.full(n, True, bool) + 'valid': jnp.full(n, True, bool), } @jax.jit @@ -1130,20 +1135,17 @@ def sim_fit_one(i, args): sim_data, result, init = args new_data = jax.tree_map(lambda x: x[i], sim_data) - new_residual = handlers.substitute( - fn=residual, - data=new_data - ) + new_residual = handlers.substitute(fn=residual, data=new_data) new_deviance_info = handlers.substitute( - fn=deviance_unconstr_info, - data=new_data + fn=deviance_unconstr_info, data=new_data ) # update best fit params to result params = to_params_dict(to_dict(init[i])) for k in result['params_rep']: - result['params_rep'][k] = \ + result['params_rep'][k] = ( result['params_rep'][k].at[i].set(params[k]) + ) # update unfit model to result model = model_counts(to_constr_dict(init[i])) @@ -1164,16 +1166,16 @@ def sim_fit_one(i, args): # fit simulation data res = jaxopt.LevenbergMarquardt( - residual_fun=new_residual, - stop_criterion='grad-l2-norm' + residual_fun=new_residual, stop_criterion='grad-l2-norm' ).run(init[i]) state = res.state # update best fit params to result params = 
to_params_dict(to_dict(res.params)) for k in result['params_fit']: - result['params_fit'][k] = \ + result['params_fit'][k] = ( result['params_fit'][k].at[i].set(params[k]) + ) # update best fit model to result model = model_counts(to_constr_dict(res.params)) @@ -1219,7 +1221,7 @@ def sim_parallel_fit(sim_data, result, init, run_str): fn = progress_bar_factory(neval, ncores, run_str=run_str)(sim_fit_one) fit_results = jax.pmap( - lambda *args: lax.fori_loop(0, neval//ncores, fn, args)[1] + lambda *args: lax.fori_loop(0, neval // ncores, fn, args)[1] )(sim_data_, result_, init_) return jax.tree_map(lambda x: jnp.hstack(x), fit_results) @@ -1247,7 +1249,7 @@ def sim_parallel_fit(sim_data, result, init, run_str): sim_result_container=sim_result_container, sim_fit_one=sim_fit_one, sim_sequence_fit=sim_sequence_fit, - sim_parallel_fit=sim_parallel_fit + sim_parallel_fit=sim_parallel_fit, ) @@ -1272,7 +1274,11 @@ def _likelihood_fn(data: Data, stat: str) -> Callable: ratio = data.spec_effexpo / data.back_effexpo return partial( pgstat, - name=name, spec=spec, back=back, back_error=back_error, ratio=ratio + name=name, + spec=spec, + back=back, + back_error=back_error, + ratio=ratio, ) elif stat == 'wstat': spec = data.spec_counts diff --git a/src/elisa/infer/likelihood.py b/src/elisa/infer/likelihood.py index 673b60c5..8d3469d9 100644 --- a/src/elisa/infer/likelihood.py +++ b/src/elisa/infer/likelihood.py @@ -4,7 +4,6 @@ import jax import jax.numpy as jnp import numpyro - from jax import lax from jax.scipy.special import xlogy from numpyro.distributions import Normal, Poisson @@ -14,11 +13,7 @@ def pgstat_background( - s: NDArray, - n: NDArray, - b_est: NDArray, - b_err: NDArray, - a: float | NDArray + s: NDArray, n: NDArray, b_est: NDArray, b_err: NDArray, a: float | NDArray ) -> jax.Array: """Optimized background for PG-statistics given estimate of source counts. @@ -51,16 +46,9 @@ def pgstat_background( c = a * e - s d = jnp.sqrt(c * c + 4.0 * a * f) b = jnp.where( - jnp.bitwise_or( - jnp.greater_equal(e, 0.0), - jnp.greater_equal(f, 0.0) - ), - jnp.where( - jnp.greater(n, 0.0), - (c + d) / (2 * a), - e - ), - 0.0 + jnp.bitwise_or(jnp.greater_equal(e, 0.0), jnp.greater_equal(f, 0.0)), + jnp.where(jnp.greater(n, 0.0), (c + d) / (2 * a), e), + 0.0, ) return b @@ -105,10 +93,10 @@ def wstat_background( jnp.where( jnp.less_equal(s, a / (a + 1) * n_on), n_on / (1 + a) - s / a, - 0.0 + 0.0, ), - (c + d) / (2 * a * (a + 1)) - ) + (c + d) / (2 * a * (a + 1)), + ), ) return b @@ -145,7 +133,7 @@ def log_prob(self, value): .at[nonzero] .add(tmp - gof) .reshape(shape), - a_max=0.0 + a_max=0.0, ) else: @@ -159,7 +147,7 @@ def chi2( name: str, spec: jnp.ndarray, error: jnp.ndarray, - predictive: bool + predictive: bool, ): """Chi-squared statistic, i.e. Gaussian likelihood.""" spec_data = numpyro.primitives.mutable(f'{name}_Non_data', spec) @@ -170,16 +158,11 @@ def chi2( numpyro.sample( name=f'{name}_Non', fn=NormalWithGoodness(spec_model, error), - obs=None if predictive else spec_data + obs=None if predictive else spec_data, ) -def cstat( - model: jnp.ndarray, - name: str, - spec: jnp.ndarray, - predictive: bool -): +def cstat(model: jnp.ndarray, name: str, spec: jnp.ndarray, predictive: bool): """C-statistic, i.e. 
Poisson likelihood.""" spec_data = numpyro.primitives.mutable(f'{name}_Non_data', spec) @@ -189,7 +172,7 @@ def cstat( numpyro.sample( name=f'{name}_Non', fn=PoissonWithGoodness(spec_model), - obs=None if predictive else spec_data + obs=None if predictive else spec_data, ) @@ -199,7 +182,7 @@ def pstat( spec: jnp.ndarray, back: jnp.ndarray, ratio: jnp.ndarray | float, - predictive: bool + predictive: bool, ): """P-statistic, i.e. Poisson likelihood for data with known background.""" spec_data = numpyro.primitives.mutable(f'{name}_Non_data', spec) @@ -211,7 +194,7 @@ def pstat( numpyro.sample( name=f'{name}_Non', fn=PoissonWithGoodness(spec_model), - obs=None if predictive else spec_data + obs=None if predictive else spec_data, ) @@ -222,7 +205,7 @@ def pgstat( back: NDArray, back_error: NDArray, ratio: NDArray | float, - predictive: bool + predictive: bool, ): """PG-statistic, i.e. Poisson likelihood for data and profile Gaussian likelihood for background. @@ -239,13 +222,13 @@ def pgstat( numpyro.sample( name=f'{name}_Non', fn=PoissonWithGoodness(spec_model), - obs=None if predictive else spec_data + obs=None if predictive else spec_data, ) numpyro.sample( name=f'{name}_Noff', fn=NormalWithGoodness(back_model, back_error), - obs=None if predictive else back_data + obs=None if predictive else back_data, ) @@ -255,7 +238,7 @@ def wstat( spec: NDArray, back: NDArray, ratio: NDArray | float, - predictive: bool + predictive: bool, ): """W-statistic, i.e. Poisson likelihood for data and profile Poisson likelihood for background. @@ -272,11 +255,11 @@ def wstat( numpyro.sample( name=f'{name}_Non', fn=PoissonWithGoodness(spec_model), - obs=None if predictive else spec_data + obs=None if predictive else spec_data, ) numpyro.sample( name=f'{name}_Noff', fn=PoissonWithGoodness(back_model), - obs=None if predictive else back_data + obs=None if predictive else back_data, ) diff --git a/src/elisa/infer/nested_sampling.py b/src/elisa/infer/nested_sampling.py index 2640bf69..fd6120aa 100644 --- a/src/elisa/infer/nested_sampling.py +++ b/src/elisa/infer/nested_sampling.py @@ -14,14 +14,17 @@ import numpyro import numpyro.distributions as dist import tensorflow_probability.substrates.jax as tfp - from jax import random from numpyro.handlers import reparam, seed, trace from numpyro.infer import Predictive from numpyro.infer.reparam import Reparam -from numpyro.infer.util import _guess_max_plate_nesting, _validate_model, log_density +from numpyro.infer.util import ( + _guess_max_plate_nesting, + _validate_model, + log_density, +) -__all__ = ["NestedSampler"] +__all__ = ['NestedSampler'] tfpd = tfp.distributions @@ -34,10 +37,13 @@ def uniform_reparam_transform(d): """ if isinstance(d, dist.TransformedDistribution): outer_transform = dist.transforms.ComposeTransform(d.transforms) - return lambda q: outer_transform(uniform_reparam_transform(d.base_dist)(q)) + return lambda q: outer_transform( + uniform_reparam_transform(d.base_dist)(q) + ) if isinstance( - d, (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution) + d, + (dist.Independent, dist.ExpandedDistribution, dist.MaskedDistribution), ): return lambda q: uniform_reparam_transform(d.base_dist)(q) @@ -63,7 +69,9 @@ def transform(q): @uniform_reparam_transform.register(dist.CategoricalLogits) @uniform_reparam_transform.register(dist.CategoricalProbs) def _(d): - return lambda q: jnp.sum(jnp.cumsum(d.probs, axis=-1) < q[..., None], axis=-1) + return lambda q: jnp.sum( + jnp.cumsum(d.probs, axis=-1) < q[..., None], axis=-1 + ) 
@uniform_reparam_transform.register(dist.Dirichlet) @@ -90,15 +98,20 @@ class UniformReparam(Reparam): """ def __call__(self, name, fn, obs): - assert obs is None, "TransformReparam does not support observe statements" + assert ( + obs is None + ), 'TransformReparam does not support observe statements' shape = fn.shape() fn, expand_shape, event_dim = self._unwrap(fn) transform = uniform_reparam_transform(fn) tiny = jnp.finfo(jnp.result_type(float)).tiny x = numpyro.sample( - "{}_base".format(name), - dist.Uniform(tiny, 1).expand(shape).to_event(event_dim).mask(False), + f'{name}_base', + dist.Uniform(tiny, 1) + .expand(shape) + .to_event(event_dim) + .mask(False), ) # Simulate a numpyro.deterministic() site. return None, transform(x) @@ -178,6 +191,7 @@ def __init__( # jaxns is import here because it runs jax program when importing logging.disable(logging.INFO) # temporarily disable jaxns logging import jaxns + logging.disable(logging.NOTSET) self._jaxns = jaxns @@ -193,18 +207,20 @@ def run(self, rng_key, *args, **kwargs): rng_sampling, rng_predictive = random.split(rng_key) # reparam the model so that latent sites have Uniform(0, 1) priors - prototype_trace = trace(seed(self.model, rng_key)).get_trace(*args, **kwargs) + prototype_trace = trace(seed(self.model, rng_key)).get_trace( + *args, **kwargs + ) param_names = [ - site["name"] + site['name'] for site in prototype_trace.values() - if site["type"] == "sample" - and not site["is_observed"] - and site["infer"].get("enumerate", "") != "parallel" + if site['type'] == 'sample' + and not site['is_observed'] + and site['infer'].get('enumerate', '') != 'parallel' ] deterministics = [ - site["name"] + site['name'] for site in prototype_trace.values() - if site["type"] == "deterministic" + if site['type'] == 'deterministic' ] reparam_model = reparam( self.model, config={k: UniformReparam() for k in param_names} @@ -212,12 +228,15 @@ def run(self, rng_key, *args, **kwargs): # enable enumerate if needed has_enum = any( - site["type"] == "sample" - and site["infer"].get("enumerate", "") == "parallel" + site['type'] == 'sample' + and site['infer'].get('enumerate', '') == 'parallel' for site in prototype_trace.values() ) if has_enum: - from numpyro.contrib.funsor import enum, log_density as log_density_ + from numpyro.contrib.funsor import ( + enum, + log_density as log_density_, + ) max_plate_nesting = _guess_max_plate_nesting(prototype_trace) _validate_model(prototype_trace) @@ -231,20 +250,20 @@ def run(self, rng_key, *args, **kwargs): \tparams = dict({})\n \treturn log_density_(reparam_model, args, kwargs, params)[0] """.format( - ", ".join([f"{name}_base" for name in param_names]), - ", ".join([f"{name}_base={name}_base" for name in param_names]), + ', '.join([f'{name}_base' for name in param_names]), + ', '.join([f'{name}_base={name}_base' for name in param_names]), ) exec(loglik_fn_def, locals(), local_dict) - loglik_fn = local_dict["loglik_fn"] + loglik_fn = local_dict['loglik_fn'] # use NestedSampler with identity prior chain def prior_model(): params = [] for name in param_names: - shape = prototype_trace[name]["fn"].shape() + shape = prototype_trace[name]['fn'].shape() param = yield jaxns.Prior( tfpd.Uniform(low=jnp.zeros(shape), high=jnp.ones(shape)), - name=name + "_base", + name=name + '_base', ) params.append(param) return tuple(params) @@ -281,7 +300,9 @@ def prior_model(): rng_sampling, term_cond=jaxns.TerminationCondition(**self.termination_kwargs), ) - results = ns.to_results(termination_reason=termination_reason, state=state) + 
results = ns.to_results( + termination_reason=termination_reason, state=state + ) # transform base samples back to original domains # Here we only transform the first valid num_samples samples @@ -308,7 +329,7 @@ def get_samples(self, rng_key, num_samples=None): if self._results is None: raise RuntimeError( - "NestedSampler.run(...) method should be called first to obtain results." + 'NestedSampler.run(...) method should be called first to obtain results.' ) weighted_samples, sample_weights = self.get_weighted_samples() @@ -319,7 +340,11 @@ def get_samples(self, rng_key, num_samples=None): num_samples = int(num_samples) return jaxns.resample( - rng_key, weighted_samples, sample_weights, S=num_samples, replace=True + rng_key, + weighted_samples, + sample_weights, + S=num_samples, + replace=True, ) def get_weighted_samples(self): @@ -328,7 +353,7 @@ def get_weighted_samples(self): """ if self._results is None: raise RuntimeError( - "NestedSampler.run(...) method should be called first to obtain results." + 'NestedSampler.run(...) method should be called first to obtain results.' ) return self._results.samples, self._results.log_dp_mean @@ -341,7 +366,7 @@ def print_summary(self): if self._results is None: raise RuntimeError( - "NestedSampler.run(...) method should be called first to obtain results." + 'NestedSampler.run(...) method should be called first to obtain results.' ) jaxns.summary(self._results) @@ -354,7 +379,7 @@ def diagnostics(self): if self._results is None: raise RuntimeError( - "NestedSampler.run(...) method should be called first to obtain results." + 'NestedSampler.run(...) method should be called first to obtain results.' ) jaxns.plot_diagnostics(self._results) jaxns.plot_cornerplot(self._results) @@ -365,16 +390,16 @@ def reparam_loglike(model, rng_key, *args, **kwargs): # reparam the model so that latent sites have Uniform(0, 1) priors prototype_trace = trace(seed(model, rng_key)).get_trace(*args, **kwargs) param_names = [ - site["name"] + site['name'] for site in prototype_trace.values() - if site["type"] == "sample" - and not site["is_observed"] - and site["infer"].get("enumerate", "") != "parallel" + if site['type'] == 'sample' + and not site['is_observed'] + and site['infer'].get('enumerate', '') != 'parallel' ] deterministics = [ - site["name"] + site['name'] for site in prototype_trace.values() - if site["type"] == "deterministic" + if site['type'] == 'deterministic' ] reparam_model = reparam( model, config={k: UniformReparam() for k in param_names} @@ -382,8 +407,8 @@ def reparam_loglike(model, rng_key, *args, **kwargs): # enable enumerate if needed has_enum = any( - site["type"] == "sample" - and site["infer"].get("enumerate", "") == "parallel" + site['type'] == 'sample' + and site['infer'].get('enumerate', '') == 'parallel' for site in prototype_trace.values() ) if has_enum: @@ -401,8 +426,8 @@ def reparam_loglike(model, rng_key, *args, **kwargs): \tparams = dict({})\n \treturn log_density_(reparam_model, args, kwargs, params)[0] """.format( - ", ".join([f"{name}_base" for name in param_names]), - ", ".join([f"{name}_base={name}_base" for name in param_names]), + ', '.join([f'{name}_base' for name in param_names]), + ', '.join([f'{name}_base={name}_base' for name in param_names]), ) exec(loglik_fn_def, locals(), local_dict) @@ -411,4 +436,9 @@ def transform_back(samples): reparam_model, samples, return_sites=param_names + deterministics ) return predictive(rng_predictive, *args, **kwargs) - return local_dict["loglik_fn"], transform_back, [f"{name}_base" for name 
in param_names] + + return ( + local_dict['loglik_fn'], + transform_back, + [f'{name}_base' for name in param_names], + ) diff --git a/src/elisa/infer/simulation.py b/src/elisa/infer/simulation.py index a51277f7..b1c54bfa 100644 --- a/src/elisa/infer/simulation.py +++ b/src/elisa/infer/simulation.py @@ -5,7 +5,6 @@ import jax import numpy as np - from jax.experimental.mesh_utils import create_device_mesh from jax.sharding import PositionalSharding from numpyro import handlers @@ -55,10 +54,7 @@ def __init__(self, model: Callable, seed: int): self._sharding = PositionalSharding(device) def sample_from_one_set( - self, - params: dict[str, float], - n: int, - seed: Optional[int] = None + self, params: dict[str, float], n: int, seed: Optional[int] = None ) -> dict[str, np.ndarray]: """Sample from one set of parameters. @@ -86,13 +82,11 @@ def sample_from_one_set( dist = self._get_dist(params) poisson_sample = jax.tree_map( - lambda v: _random_poisson(seed, v, n), - dist['poisson'] + lambda v: _random_poisson(seed, v, n), dist['poisson'] ) normal_sample = jax.tree_map( - lambda v: _random_normal(seed, v[0], v[1], n), - dist['normal'] + lambda v: _random_normal(seed, v[0], v[1], n), dist['normal'] ) sample = poisson_sample | normal_sample @@ -102,7 +96,7 @@ def sample_from_one_set( def sample_from_multi_sets( self, params: dict[str, np.ndarray | jax.Array], - seed: Optional[int] = None + seed: Optional[int] = None, ) -> dict[str, np.ndarray]: """Sample from multiple sets of parameters. @@ -130,13 +124,11 @@ def sample_from_multi_sets( dist = jax.vmap(self._get_dist)(sharded_params) poisson_sample = jax.tree_map( - lambda v: _random_poisson(seed, v), - dist['poisson'] + lambda v: _random_poisson(seed, v), dist['poisson'] ) normal_sample = jax.tree_map( - lambda v: _random_normal(seed, v[0], v[1]), - dist['normal'] + lambda v: _random_normal(seed, v[0], v[1]), dist['normal'] ) sample = poisson_sample | normal_sample @@ -203,7 +195,7 @@ def run_one_set( n: int, seed: Optional[int] = None, parallel: bool = True, - run_str: Optional[str] = None + run_str: Optional[str] = None, ): """Simulate from one set of parameters and fit the simulation. @@ -258,7 +250,7 @@ def run_multi_sets( params: dict[str, jax.Array], seed: Optional[int] = None, parallel: bool = True, - run_str: Optional[str] = None + run_str: Optional[str] = None, ): """Simulate from multiple sets of parameters and fit the simulation. diff --git a/src/elisa/infer/util.py b/src/elisa/infer/util.py index 1436ad16..fec7fd00 100644 --- a/src/elisa/infer/util.py +++ b/src/elisa/infer/util.py @@ -1,10 +1,10 @@ """Helper functions for inference module.""" from __future__ import annotations -from typing import Callable, Optional, Sequence, TypeVar - import re +from collections.abc import Sequence from functools import reduce +from typing import Callable, Optional, TypeVar from jax import lax from jax.experimental import host_callback @@ -63,7 +63,7 @@ def make_pretty_table(fields: Sequence[str], rows: Sequence) -> PrettyTable: top_right_junction_char='┐', top_left_junction_char='┌', bottom_right_junction_char='┘', - bottom_left_junction_char='└' + bottom_left_junction_char='└', ) table.add_rows(rows) return table @@ -119,7 +119,7 @@ def progress_bar_factory( neval: int, ncores: int, init_str: Optional[str] = None, - run_str: Optional[str] = None + run_str: Optional[str] = None, ) -> Callable: """Add a progress bar to fori_loop kernel. 
Adapt from: https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/ @@ -141,7 +141,7 @@ def progress_bar_factory( else: run_str = str(run_str) - process_re = re.compile(r"\d+$") + process_re = re.compile(r'\d+$') if neval > 20: print_rate = int(neval_single / 20) @@ -169,9 +169,7 @@ def _close_tqdm(arg, transform, device): def _update_progress_bar(iter_num): _ = lax.cond( iter_num == 1, - lambda _: host_callback.id_tap( - _update_tqdm, 0, result=iter_num - ), + lambda _: host_callback.id_tap(_update_tqdm, 0, result=iter_num), lambda _: iter_num, operand=None, ) diff --git a/src/elisa/model/__init__.py b/src/elisa/model/__init__.py index 1a313e46..565d28bc 100644 --- a/src/elisa/model/__init__.py +++ b/src/elisa/model/__init__.py @@ -1,7 +1,7 @@ -from .base import * from .add import * +from .base import * +from .con import * from .mul import * from .ncon import * -from .con import * __all__ = base.__all__ + add.__all__ + mul.__all__ + con.__all__ + ncon.__all__ diff --git a/src/elisa/model/add.py b/src/elisa/model/add.py index ab1f6be3..f8b4d280 100644 --- a/src/elisa/model/add.py +++ b/src/elisa/model/add.py @@ -9,11 +9,13 @@ from .base import Component, ParamConfig from .integral import integral, list_methods - __all__ = [ - 'Band', 'BandEp', - 'Bbody', 'Bbodyrad', - 'Compt', 'Cutoffpl', + 'Band', + 'BandEp', + 'Bbody', + 'Bbodyrad', + 'Compt', + 'Cutoffpl', 'OTTB', 'Powerlaw', ] @@ -104,8 +106,8 @@ def _continnum(egrid, kT, K): jnp.where( jnp.greater_equal(x, 50.0), 0.0, # avoid exponential overflow - tmp * x / jnp.expm1(x_) - ) + tmp * x / jnp.expm1(x_), + ), ) # return 8.0525 * K * e*e / (kT*kT*kT*kT * jnp.expm1(energy / kT)) @@ -134,8 +136,8 @@ def _continnum(egrid, kT, K): jnp.where( jnp.greater_equal(x, 50.0), 0.0, # avoid exponential overflow - tmp * e / jnp.expm1(x_) - ) + tmp * e / jnp.expm1(x_), + ), ) # return 1.0344e-3 * K * e*e / jnp.expm1(e / kT) @@ -157,12 +159,14 @@ def _continnum(egrid, alpha, beta, Ec, K): amb_ = alpha - beta inv_Ec = 1.0 / Ec amb = jnp.where(jnp.less(amb_, inv_Ec), inv_Ec, amb_) - Ebreak = Ec*amb + Ebreak = Ec * amb log_func = jnp.where( jnp.less(egrid, Ebreak), alpha * jnp.log(egrid / Epiv) - egrid / Ec, - amb * jnp.log(amb * Ec / Epiv) - amb + beta * jnp.log(egrid / Epiv) + amb * jnp.log(amb * Ec / Epiv) + - amb + + beta * jnp.log(egrid / Epiv), ) return K * jnp.exp(log_func) @@ -192,7 +196,7 @@ def _continnum(egrid, alpha, beta, Ep, K): log_func = jnp.where( jnp.less(e, Ebreak), alpha * jnp.log(e / Epiv) - e / Ec, - amb * jnp.log(amb * Ec / Epiv) - amb + beta * jnp.log(e / Epiv) + amb * jnp.log(amb * Ec / Epiv) - amb + beta * jnp.log(e / Epiv), ) return K * jnp.exp(log_func) diff --git a/src/elisa/model/base.py b/src/elisa/model/base.py index 17ce6233..83e607d2 100644 --- a/src/elisa/model/base.py +++ b/src/elisa/model/base.py @@ -37,7 +37,7 @@ class Parameter: def __init__(self, node: ParameterNodeType): if not isinstance(node, (ParameterNode, ParameterOperationNode)): raise ValueError( - "node must be ParameterNode or ParameterOperationNode" + 'node must be ParameterNode or ParameterOperationNode' ) self._node = node @@ -156,14 +156,14 @@ def __init__( min: float, max: float, frozen: bool = False, - log: bool = False + log: bool = False, ): self._config = { 'default': float(default), 'min': float(min), 'max': float(max), 'frozen': bool(frozen), - 'log': bool(log) + 'log': bool(log), } self._check_and_set_values() @@ -175,7 +175,7 @@ def __init__( distribution=self._get_distribution(), min=self._config['min'], max=self._config['max'], - 
dist_expr=self.get_expression() + dist_expr=self.get_expression(), ) super().__init__(node) @@ -286,14 +286,10 @@ def _check_and_set_values(self, default=None, min=None, max=None) -> None: _max = float(max) if _min <= 0.0 and config['log']: - raise ValueError( - f'min ({_min}) must be positive for LogUniform' - ) + raise ValueError(f'min ({_min}) must be positive for LogUniform') if _min > _max: - raise ValueError( - f'min ({_min}) must not larger than max ({_max})' - ) + raise ValueError(f'min ({_min}) must not larger than max ({_max})') if _default < _min: raise ValueError( @@ -334,7 +330,7 @@ def _get_distribution(self) -> Distribution | float: def get_expression(self) -> str: """Get expression of distribution.""" - default = self._config["default"] + default = self._config['default'] if self._config['frozen']: expr = str(default) @@ -391,12 +387,10 @@ def __init__( self, node: ModelNodeType, params: Optional[dict[str, Parameter]] = None, - params_fmt: Optional[dict[str, str]] = None + params_fmt: Optional[dict[str, str]] = None, ): if not isinstance(node, (ModelNode, ModelOperationNode)): - raise ValueError( - 'node must be ModelNode or ModelOperationNode' - ) + raise ValueError('node must be ModelNode or ModelOperationNode') if isinstance(node, ModelNode) and not ( isinstance(params, dict) @@ -427,9 +421,7 @@ def __init__( else: params_fmt = {i: r'\mathrm{%s}' % i for i in params.keys()} - self._params_fmt = { - node.name: ModelParameterFormat(params_fmt) - } + self._params_fmt = {node.name: ModelParameterFormat(params_fmt)} def __repr__(self): return self._label.name @@ -441,9 +433,11 @@ def __mul__(self, other: Model) -> SuperModel: return SuperModel(self, other, '*') def __setattr__(self, key, value): - if hasattr(self, '_params_name') \ - and self._params_name is not None \ - and key in self._params_name: + if ( + hasattr(self, '_params_name') + and self._params_name is not None + and key in self._params_name + ): self._set_param(key, value) super().__setattr__(key, value) @@ -495,7 +489,7 @@ def eval( shapes = jax.tree_util.tree_flatten( tree=jax.tree_map(jnp.shape, params), - is_leaf=lambda i: isinstance(i, tuple) + is_leaf=lambda i: isinstance(i, tuple), )[0] if not shapes: @@ -510,14 +504,11 @@ def eval( if shape == (): eval_fn = lambda f: f(egrid, params) elif len(shape) == 1: - eval_fn = lambda f: \ - jax.vmap(f, in_axes=(None, 0))(egrid, params) + eval_fn = lambda f: jax.vmap(f, in_axes=(None, 0))(egrid, params) elif len(shape) == 2: - eval_fn = lambda f: \ - jax.vmap( - jax.vmap(f, in_axes=(None, 0)), - in_axes=(None, 0) - )(egrid, params) + eval_fn = lambda f: jax.vmap( + jax.vmap(f, in_axes=(None, 0)), in_axes=(None, 0) + )(egrid, params) else: raise ValueError(f'params ndim should <= 2, got {len(shape)}') @@ -527,7 +518,7 @@ def ne( self, egrid: jax.Array, params: dict[str, dict[str, float | jax.Array]], - comps: bool = False + comps: bool = False, ) -> jax.Array | dict[str, jax.Array]: """Calculate :math:`N_E` over `egrid`. 
@@ -553,7 +544,7 @@ def ne( shapes = jax.tree_util.tree_flatten( tree=jax.tree_map(jnp.shape, params), - is_leaf=lambda i: isinstance(i, tuple) + is_leaf=lambda i: isinstance(i, tuple), )[0] if not shapes: @@ -570,14 +561,16 @@ def ne( if shape == (): eval_fn = lambda f: f(egrid, params) / de elif len(shape) == 1: - eval_fn = lambda f: \ - jax.vmap(f, in_axes=(None, 0))(egrid, params) / de + eval_fn = ( + lambda f: jax.vmap(f, in_axes=(None, 0))(egrid, params) / de + ) elif len(shape) == 2: - eval_fn = lambda f: \ - jax.vmap( - jax.vmap(f, in_axes=(None, 0)), - in_axes=(None, 0) - )(egrid, params) / de + eval_fn = ( + lambda f: jax.vmap( + jax.vmap(f, in_axes=(None, 0)), in_axes=(None, 0) + )(egrid, params) + / de + ) else: raise ValueError(f'params ndim should <= 2, got {len(shape)}') @@ -590,7 +583,7 @@ def ene( self, egrid: jax.Array, params: dict[str, dict[str, float | jax.Array]], - comps: bool = False + comps: bool = False, ) -> jax.Array | dict[str, jax.Array]: r"""Calculate :math:`E N_E` (:math:`F_\nu`) over `egrid`. @@ -626,7 +619,7 @@ def eene( self, egrid: jax.Array, params: dict[str, dict[str, float | jax.Array]], - comps: bool = False + comps: bool = False, ): r"""Calculate :math:`E^2 N_E` (:math:`\nu F_\nu`) over `egrid`. @@ -664,7 +657,7 @@ def folded( params: dict[str, dict[str, float | jax.Array]], resp_matrix: jax.Array, ch_width: jax.Array, - comps: bool = False + comps: bool = False, ) -> jax.Array | dict[str, jax.Array]: """Calculate the folded spectral model (:math:`C_E`). @@ -708,7 +701,7 @@ def flux( energy: bool = True, comps: bool = False, ngrid: int = 1000, - elog: bool = True + elog: bool = True, ) -> jax.Array | dict[str, jax.Array]: """Calculate flux of the model between `emin` and `emax`. @@ -849,8 +842,9 @@ def _wrapped_comp_fn(self) -> dict[str, Callable]: if self._comp_fn is None: mapping = self._label.mapping self._comp_fn = { - m._label._label('name', mapping): - jax.jit(m._fn_wrapper(mapping['name'])) + m._label._label('name', mapping): jax.jit( + m._fn_wrapper(mapping['name']) + ) for m in self._get_comp() } @@ -873,7 +867,7 @@ def _model_info(self) -> dict: params=self._node.params, mname=mapping['name'], mfmt=mapping['fmt'], - mpfmt=self._params_fmt + mpfmt=self._params_fmt, ) return info @@ -934,9 +928,7 @@ def __setattr__(self, key, value): def __getitem__(self, name: str) -> Model: if name not in self._comps_name: - raise ValueError( - f'{self} has no "{name}" component' - ) + raise ValueError(f'{self} has no "{name}" component') return getattr(self, name) @@ -997,7 +989,7 @@ def __init__(cls, *args, **kwargs): par_def = str(init_def) init_def += 'fmt=None' - init_body += f'fmt=fmt' + init_body += 'fmt=fmt' if hasattr(cls, '_extra_kw') and isinstance(cls._extra_kw, tuple): pos_args = [] @@ -1014,7 +1006,7 @@ def __init__(cls, *args, **kwargs): init_def = s[:6] + ', '.join(pos_args) + ', ' + s[6:] func_code = f'def __init__({init_def}):\n ' - func_code += f'super(type(self), type(self))' + func_code += 'super(type(self), type(self))' func_code += f'.__init__(self, {init_body})\n' func_code += f'def {name}({par_def}):\n ' func_code += f'return {par_body}' @@ -1108,7 +1100,7 @@ def __init__(self, fmt=None, **params): mtype=mtype, params={k: v._node for k, v in params_dict.items()}, func=self._func, - is_ncon=is_ncon + is_ncon=is_ncon, ) params_fmt = {cfg[0]: cfg[1] for cfg in self._config} @@ -1151,7 +1143,7 @@ def generate_parameter( distribution: Distribution, min: Optional[float] = None, max: Optional[float] = None, - dist_expr: Optional[str] = 
None + dist_expr: Optional[str] = None, ) -> Parameter: """Create :class:`Parameter` instance. @@ -1190,7 +1182,7 @@ def generate_model( mtype: str, params: dict[str, Parameter], func: Callable, - is_ncon: bool + is_ncon: bool, ) -> Model: """Create :class:`Model` instance. diff --git a/src/elisa/model/con.py b/src/elisa/model/con.py index 2ce53bdc..9169c8d7 100644 --- a/src/elisa/model/con.py +++ b/src/elisa/model/con.py @@ -4,8 +4,6 @@ from abc import ABC, abstractmethod from typing import Callable -import jax.numpy as jnp - from .base import Component __all__ = [] @@ -29,5 +27,3 @@ def _func(self) -> Callable: def _convolve(*args): """Return photon flux which has been convolved.""" pass - - diff --git a/src/elisa/model/integral.py b/src/elisa/model/integral.py index fc0dc642..d6099f0e 100644 --- a/src/elisa/model/integral.py +++ b/src/elisa/model/integral.py @@ -5,6 +5,7 @@ """ from __future__ import annotations + from functools import wraps from inspect import signature from typing import Callable @@ -34,7 +35,7 @@ def {name}(egrid, {def_str}): _template: dict = { 'default': _trapezoid, 'trapezoid': _trapezoid, - 'simpson': _simpson + 'simpson': _simpson, } diff --git a/src/elisa/model/mul.py b/src/elisa/model/mul.py index 634a64aa..8ba403d8 100644 --- a/src/elisa/model/mul.py +++ b/src/elisa/model/mul.py @@ -32,9 +32,7 @@ def _continnum(*args): class Constant(MultiplicativeComponent): - _config = ( - ParamConfig('factor', 'f', 1.0, 1e-5, 1e5, False, False), - ) + _config = (ParamConfig('factor', 'f', 1.0, 1e-5, 1e5, False, False),) @staticmethod def _continnum(egrid, factor): diff --git a/src/elisa/model/ncon.py b/src/elisa/model/ncon.py index fabc3202..c2ceff3c 100644 --- a/src/elisa/model/ncon.py +++ b/src/elisa/model/ncon.py @@ -31,12 +31,7 @@ class NormalizationConvolution(Component, ABC): """ - _extra_kw = ( - ('emin',), - ('emax',), - ('ngrid', 1000), - ('elog', True) - ) + _extra_kw = (('emin',), ('emax',), ('ngrid', 1000), ('elog', True)) def __init__( self, @@ -44,7 +39,7 @@ def __init__( emax: float | int, ngrid: int = 1000, elog: bool = True, - **kwargs + **kwargs, ): if emin >= emax: raise ValueError('emin must be less than emax') diff --git a/src/elisa/model/node.py b/src/elisa/model/node.py index b3d2c1bc..9003f31d 100644 --- a/src/elisa/model/node.py +++ b/src/elisa/model/node.py @@ -45,7 +45,7 @@ def __init__( fmt: str, is_operation: bool = False, predecessor: list[Node] | None = None, - attrs: dict[str, Any] | None = None + attrs: dict[str, Any] | None = None, ): if attrs is None: attrs = dict() @@ -72,7 +72,7 @@ def __init__( fmt=fmt, type=self.type, is_operation=is_operation, - id=node_id + id=node_id, ) self._attrs.update(attrs) @@ -156,10 +156,7 @@ def __init__(self, lh: Node, rh: Node, op: str): fmt = r'\times' if op == '*' else op super().__init__( - name=op, - fmt=fmt, - is_operation=True, - predecessor=[lh, rh] + name=op, fmt=fmt, is_operation=True, predecessor=[lh, rh] ) def _label_with_id(self, label: str) -> str: @@ -262,7 +259,7 @@ def __init__( distribution: Distribution | float | int, min: float | None = None, max: float | None = None, - dist_expr: str | None = None + dist_expr: str | None = None, ): self._validate_input(distribution) @@ -273,7 +270,7 @@ def __init__( distribution=distribution, min='' if min is None else f'{min:.4g}', max='' if max is None else f'{max:.4g}', - dist_expr='' if dist_expr is None else str(dist_expr) + dist_expr='' if dist_expr is None else str(dist_expr), ) super().__init__(name=name, fmt=fmt, attrs=attrs) @@ -314,7 +311,7 
@@ def site(self) -> dict: 'default': {name: self.default}, 'min': {name: self.attrs['min']}, 'max': {name: self.attrs['max']}, - 'dist_expr': {name: self.attrs['dist_expr']} + 'dist_expr': {name: self.attrs['dist_expr']}, } return info @@ -406,7 +403,7 @@ def get_field(key): 'default': get_field('default'), 'min': get_field('min'), 'max': get_field('max'), - 'dist_expr': get_field('dist_expr') + 'dist_expr': get_field('dist_expr'), } composite = info['composite'] @@ -460,7 +457,7 @@ def __init__( mtype: str, params: dict[str, ParameterNodeType], func: Callable, - is_ncon: bool + is_ncon: bool, ): if mtype not in {'add', 'mul', 'con'}: raise TypeError(f'unrecognized model type "{mtype}"') @@ -476,17 +473,10 @@ def __init__( self._params_name = tuple(params.keys()) predecessor = list(params.values()) - attrs = dict( - mtype=mtype, - func=func, - is_ncon=is_ncon - ) + attrs = dict(mtype=mtype, func=func, is_ncon=is_ncon) super().__init__( - name=name, - fmt=fmt, - predecessor=predecessor, - attrs=attrs + name=name, fmt=fmt, predecessor=predecessor, attrs=attrs ) def __add__(self, other: ModelNodeType) -> ModelOperationNode: @@ -533,7 +523,7 @@ def get_site_field(key): 'default': get_site_field('default'), 'min': get_site_field('min'), 'max': get_site_field('max'), - 'dist_expr': get_site_field('dist_expr') + 'dist_expr': get_site_field('dist_expr'), } def generate_func(self, mapping: dict[str, str]) -> Callable: @@ -552,6 +542,7 @@ def generate_func(self, mapping: dict[str, str]) -> Callable: # notation: p=params, e=egrid, f=flux_input, ff=flux_func # params structure should be {model_id: {param1: ..., param2: ...}} if mtype == 'add': + def wrapper_add(p, e, *_): """Evaluate add model.""" return func(e, **p[model_name]) @@ -559,6 +550,7 @@ def wrapper_add(p, e, *_): return wrapper_add elif mtype == 'mul': + def wrapper_mul(p, e, *_): """Evaluate mul model.""" return func(e, **p[model_name]) @@ -567,18 +559,18 @@ def wrapper_mul(p, e, *_): else: # mtype == 'con' if self.attrs['is_ncon']: + def wrapper_ncon(p, _=None, f=None, ff=None): """Evaluate ncon model, f and ff must be provided.""" other_kwargs = dict( - flux_input=f, - flux_func=ff, - func_params=p + flux_input=f, flux_func=ff, func_params=p ) return func(**p[model_name], **other_kwargs) return wrapper_ncon else: + def wrapper_con(p, e, f, *_): """Evaluate con model.""" return func(e, f, **p[model_name]) @@ -657,8 +649,10 @@ def __init__(self, lh: ModelNodeType, rh: ModelNodeType, op: str): self.attrs['is_ncon'] = is_ncon # for a convolution model, fmt is * - if not isinstance(lh, ModelOperationNode) and \ - lh.attrs.get('mtype', '') == 'con': + if ( + not isinstance(lh, ModelOperationNode) + and lh.attrs.get('mtype', '') == 'con' + ): self.attrs['fmt'] = '*' def __add__(self, other: ModelNodeType) -> ModelOperationNode: @@ -676,7 +670,7 @@ def type(self) -> str: def params(self) -> dict[str, ParameterNodeType]: """Parameter dict.""" lh, rh = self.predecessor - return lh.params |rh.params + return lh.params | rh.params @property def comps(self) -> dict[str, ModelNode]: @@ -703,7 +697,7 @@ def get_field(key): 'default': get_field('default'), 'min': get_field('min'), 'max': get_field('max'), - 'dist_expr': get_field('dist_expr') + 'dist_expr': get_field('dist_expr'), } def generate_func(self, mapping: dict[str, str]) -> Callable: @@ -717,6 +711,7 @@ def generate_func(self, mapping: dict[str, str]) -> Callable: # notation: p=params, e=egrid, f=flux, ff=flux_func if op == '+': + def wrapper_add_add(p, e, *_): """add + add""" return m1(p, 
e) + m2(p, e) @@ -725,6 +720,7 @@ def wrapper_add_add(p, e, *_): if type1 != 'con': # type1 is add or mul if type2 != 'con': # type2 is add or mul + def wrapper_op(p, e, *_): # add * add not allowed """add * mul, mul * add, mul * mul""" return m1(p, e) * m2(p, e) @@ -733,6 +729,7 @@ def wrapper_op(p, e, *_): # add * add not allowed else: # type2 is con if rh.attrs['is_ncon']: # type2 is ncon + def wrapper_mul_ncon(p, e, f, ff): """mul * ncon""" return m1(p, e) * m2(p, e, f, ff) @@ -740,6 +737,7 @@ def wrapper_mul_ncon(p, e, f, ff): return wrapper_mul_ncon else: # type2 is con + def wrapper_mul_con(p, e, f, *_): """mul * con""" return m1(p, e) * m2(p, e, f) @@ -749,6 +747,7 @@ def wrapper_mul_con(p, e, f, *_): else: # type1 is con if lh.attrs['is_ncon']: # type1 is ncon if type2 == 'add': + def wrapper_ncon_add(p, e, *_): """ncon * add""" return m1(p, e, m2(p, e), m2) @@ -756,8 +755,10 @@ def wrapper_ncon_add(p, e, *_): return wrapper_ncon_add elif type2 == 'mul': + def wrapper_ncon_mul(p, e, f, ff): """ncon * mul""" + def m2_ff(p_, e_, *_): """mul * add, this will be * by ncon""" return m2(p_, e_) * ff(p_, e_) @@ -767,8 +768,10 @@ def m2_ff(p_, e_, *_): return wrapper_ncon_mul else: # type2 == 'con' + def wrapper_ncon_con(p, e, f, ff): """ncon * con""" + def m2_ff(p_, e_, *_): """con * add, this will be * by ncon""" return m2(p_, e_, ff(p_, e_)) @@ -779,6 +782,7 @@ def m2_ff(p_, e_, *_): else: # type1 is con if type2 == 'add': + def wrapper_con_add(p, e, *_): """con * add""" return m1(p, e, m2(p, e)) @@ -786,6 +790,7 @@ def wrapper_con_add(p, e, *_): return wrapper_con_add elif type2 == 'mul': + def wrapper_con_mul(p, e, f, *_): """con * mul""" return m1(p, e, m2(p, e) * f) @@ -794,6 +799,7 @@ def wrapper_con_mul(p, e, f, *_): else: if rh.attrs['is_ncon']: + def wrapper_con_ncon(p, e, f, ff): """con * ncon""" return m1(p, e, m2(p, e, f, ff)) @@ -801,6 +807,7 @@ def wrapper_con_ncon(p, e, f, ff): return wrapper_con_ncon else: + def wrapper_con_con(p, e, f, *_): """con * con""" return m1(p, e, m2(p, e, f)) @@ -819,17 +826,16 @@ class LabelSpace: """ def __init__(self, node: Node): - self.node = node self._label_space = { 'name': self._get_sub_nodes_label('name'), - 'fmt': self._get_sub_nodes_label('fmt') + 'fmt': self._get_sub_nodes_label('fmt'), } self._label_map = { 'name': self._get_suffix_mapping('name'), - 'fmt': self._get_suffix_mapping('fmt') + 'fmt': self._get_suffix_mapping('fmt'), } @property diff --git a/src/elisa/plot/corner.py b/src/elisa/plot/corner.py index 49bd1d97..acdf7eab 100644 --- a/src/elisa/plot/corner.py +++ b/src/elisa/plot/corner.py @@ -6,15 +6,17 @@ from corner import corner -def plot_corner(data, axes_scale='linear', labels=None, color=None, weights=None): +def plot_corner( + data, axes_scale='linear', labels=None, color=None, weights=None +): """log_scale : bool, whether to plot vars in log which is log uniform""" plt.rcParams['font.family'] = 'serif' plt.rcParams['text.usetex'] = True levels = [ [0.683, 0.954, 0.997], # 1/2/3-sigma of 1d normal [0.393, 0.865, 0.989], # 1/2/3-sigma of 2d normal - [0.683, 0.9], # 1-sigma and 90% of 2d normal - [0.393, 0.683, 0.9] # 1-sigma, 68.3% and 90% of 2d normal + [0.683, 0.9], # 1-sigma and 90% of 2d normal + [0.393, 0.683, 0.9], # 1-sigma, 68.3% and 90% of 2d normal ][-1] # def to_hex(c): @@ -50,7 +52,7 @@ def plot_corner(data, axes_scale='linear', labels=None, color=None, weights=None no_fill_contours=True, contour_kwargs={'colors': colors1}, contourf_kwargs={'colors': ['white'] + colors2, 'alpha': 0.75}, - 
data_kwargs={'color': colors2[0], 'alpha': 0.75} + data_kwargs={'color': colors2[0], 'alpha': 0.75}, ) @@ -75,10 +77,7 @@ def clip(num): def _gradient_colors( - color: str, - n: int, - factor_min: float = 0.9, - factor_max: float = 1.5 + color: str, n: int, factor_min: float = 0.9, factor_max: float = 1.5 ) -> list[str]: color = str(color) n = int(n) @@ -106,7 +105,7 @@ def _contour_colors( n: int, factor_min: float = 0.9, factor_max: float = 1.5, - factor_f: float = 0.72 + factor_f: float = 0.72, ) -> tuple: color = str(color) n = int(n) @@ -115,6 +114,6 @@ def _contour_colors( f = float(factor_f) contourf_colors = _gradient_colors(color, n, factor_min, factor_max) - contour_colors = _gradient_colors(color, n, f*factor_min, f*factor_max) + contour_colors = _gradient_colors(color, n, f * factor_min, f * factor_max) return contour_colors, contourf_colors diff --git a/tests/model/test_name.py b/tests/model/test_name.py index 219edd4a..d5d2d393 100644 --- a/tests/model/test_name.py +++ b/tests/model/test_name.py @@ -3,4 +3,4 @@ def test_model_name(): model = Powerlaw() + Powerlaw() - assert repr(model) == "powerlaw + powerlaw2" + assert repr(model) == 'powerlaw + powerlaw2' diff --git a/tests/parameter/test_name.py b/tests/parameter/test_name.py index 4f7c76ea..7f032821 100644 --- a/tests/parameter/test_name.py +++ b/tests/parameter/test_name.py @@ -2,8 +2,8 @@ def test_param_name(): - a = UniformParameter("a", "a", 1.0, 0.0, 2.0) - a2 = UniformParameter("a2", "a2", 1.0, 0.0, 2.0) + a = UniformParameter('a', 'a', 1.0, 0.0, 2.0) + a2 = UniformParameter('a2', 'a2', 1.0, 0.0, 2.0) b = a + a2 - assert repr(b) == "a + a2" + assert repr(b) == 'a + a2' assert b.default == 2.0