Skip to content

Commit

Permalink
Addressed #67 and started implementation of #68
Browse files Browse the repository at this point in the history
  • Loading branch information
Javier Sanchez committed Oct 30, 2024
1 parent bd2c6e0 commit ff862a8
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions augur/analyze.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __init__(self, config, likelihood=None, tools=None, req_params=None):
Output Fisher matrix
"""

_config = parse_config(config) # Load full config
config = parse_config(config) # Load full config

# Load the likelihood if no likelihood is passed along
if likelihood is None:
Expand Down Expand Up @@ -78,7 +78,7 @@ def __init__(self, config, likelihood=None, tools=None, req_params=None):
self.pars_fid = tools.get_ccl_cosmology().__dict__['_params_init_kwargs']

# Load the relevant section of the configuration file
self.config = _config['fisher']
self.config = config['fisher']

# Initialize pivot point
self.x = []
Expand Down Expand Up @@ -187,13 +187,29 @@ def f(self, x, labels, pars_fid, sys_fid):
f_out.append(self.lk.compute_theory_vector(self.tools))
return np.array(f_out)

def get_derivatives(self, force=False, method='5pt_stencil'):
def get_derivatives(self, force=False, method='5pt_stencil', normalize_params=True):
"""
Auxiliary function to compute numerical derivatives of the helper function `f`
Parameters:
-----------
force : bool, If `True` force recalculation of the derivatives
method : str, Method to compute derivatives, currently only `5pt_stencil` or
numdifftools are accepted.
normalize_params : bool, If `True` it normalizes the parameters before computing the
derivatives.
"""
if normalize_params:
x_piv = self.x # Modify
else:
x_piv = self.x

# Compute the derivatives with respect to the parameters in var_pars at x
if (self.derivatives is None) or (force):
if '5pt_stencil' in method:
self.derivatives = five_pt_stencil(lambda y: self.f(y, self.var_pars, self.pars_fid,
self.req_params),
self.x, h=float(self.config['step']))
x_piv, h=float(self.config['step']))
elif 'numdifftools' in method:
import numdifftools as nd
if 'numdifftools_kwargs' in self.config.keys():
Expand All @@ -203,7 +219,7 @@ def get_derivatives(self, force=False, method='5pt_stencil'):
self.derivatives = nd.Gradient(lambda y: self.f(y, self.var_pars, self.pars_fid,
self.req_params),
step=float(self.config['step']),
**ndkwargs)(self.x).T
**ndkwargs)(x_piv).T
else:
raise ValueError(f'Selected method: `{method}` is not available. \
Please select 5pt_stencil or numdifftools.')
Expand Down

0 comments on commit ff862a8

Please sign in to comment.