Skip to content

Commit

Permalink
Moved horseshoe hyperparameter sampling into _numpyro_model routine
Browse files Browse the repository at this point in the history
  • Loading branch information
mikkelbue committed Jan 22, 2024
1 parent 12e5677 commit 7af3848
Showing 1 changed file with 7 additions and 15 deletions.
22 changes: 7 additions & 15 deletions pysindy/optimizers/sbr.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,13 @@ def _numpyro_model(self, x, y):
n_features = x.shape[1]
n_targets = y.shape[1]

# sample the hyperparameters.
tau, c_sq = sample_reg_horseshoe_hyper(
self.sparsity_coef_tau0, self.slab_shape_nu, self.slab_shape_s
# sample the horseshoe hyperparameters.
tau = numpyro.sample("tau", HalfCauchy(self.sparsity_coef_tau0))
c_sq = numpyro.sample(
"c_sq",
InverseGamma(
self.slab_shape_nu / 2, self.slab_shape_nu / 2 * self.slab_shape_s**2
),
)

# sample the parameters compute the predicted outputs.
Expand Down Expand Up @@ -141,18 +145,6 @@ def _run_mcmc(self, x, y, **kwargs):
return mcmc


def sample_reg_horseshoe_hyper(tau0=0.1, nu=4, s=2):
"""
For details on this prior, please refer to:
Hirsh, S. M., Barajas-Solano, D. A., & Kutz, J. N. (2021).
parsifying Priors for Bayesian Uncertainty Quantification in
Model Discovery (arXiv:2107.02107). arXiv. http://arxiv.org/abs/2107.02107
"""
tau = numpyro.sample("tau", HalfCauchy(tau0))
c_sq = numpyro.sample("c_sq", InverseGamma(nu / 2, nu / 2 * s**2))
return tau, c_sq


def _sample_reg_horseshoe(tau, c_sq, shape):
lamb = numpyro.sample("lambda", HalfCauchy(1.0), sample_shape=shape)
lamb_squiggle = jnp.sqrt(c_sq) * lamb / jnp.sqrt(c_sq + tau**2 * lamb**2)
Expand Down

0 comments on commit 7af3848

Please sign in to comment.