Skip to content

Commit

Permalink
ensures that tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Nov 28, 2023
1 parent 0dc986a commit 08a92e1
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 30 deletions.
14 changes: 3 additions & 11 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,6 @@ class HSSM:
recommended when you are using hierarchical models.
The default value is `None` when `hierarchical` is `False` and `"safe"` when
`hierarchical` is `True`.
center_predictors : optional
If `True`, and if there is an intercept in the common terms, the
data is centered by subtracting the mean. The centering is undone after sampling
to provide the actual intercept in all distributional components that have an
intercept. Note that this changes the interpretation of the prior on the
intercept because it refers to the intercept of the centered data.
extra_namespace : optional
Additional user supplied variables with transformations or data to include in
the environment where the formula is evaluated. Defaults to `None`.
Expand Down Expand Up @@ -224,7 +218,6 @@ def __init__(
hierarchical: bool = False,
link_settings: Literal["log_logit"] | None = None,
prior_settings: Literal["safe"] | None = None,
center_predictors: bool = False,
extra_namespace: dict[str, Any] | None = None,
**kwargs,
):
Expand Down Expand Up @@ -333,7 +326,6 @@ def __init__(
data=data,
family=self.family,
priors=self.priors,
center_predictors=center_predictors,
extra_namespace=extra_namespace,
**other_kwargs,
)
Expand Down Expand Up @@ -938,7 +930,7 @@ def _preprocess_rest(self, processed: dict[str, Param]) -> dict[str, Param]:
bounds = self.model_config.bounds.get(param_str)
param = Param(
param_str,
formula="1 + (1|participant_id)",
formula=f"{param_str} ~ 1 + (1|participant_id)",
bounds=bounds,
)
else:
Expand Down Expand Up @@ -982,8 +974,8 @@ def _find_parent(self) -> tuple[str, Param]:
def _override_defaults(self):
"""Override the default priors or links."""
is_ddm = (
self.model in ["ddm", "ddm_sdv"] and self.loglik_kind == "analytical"
) or (self.model == "ddm_full" and self.loglik_kind == "blackbox")
self.model_name in ["ddm", "ddm_sdv"] and self.loglik_kind == "analytical"
) or (self.model_name == "ddm_full" and self.loglik_kind == "blackbox")
for param in self.list_params:
param_obj = self.params[param]
if self.prior_settings == "safe":
Expand Down
24 changes: 11 additions & 13 deletions src/hssm/param.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
"""The Param utility class."""

import logging
from typing import Any, Union, cast
from copy import deepcopy
from typing import Any, Literal, Union, cast

import bambi as bmb
import numpy as np
import pandas as pd
from deepcopy import deepcopy
from formulae import design_matrices

from .link import Link
Expand Down Expand Up @@ -100,7 +100,7 @@ def override_default_link(self):
This is most likely because both default prior and default bounds are supplied.
"""
self._ensure_not_converted()
self._ensure_not_converted(context="link")

if not self.is_regression or self._link_specified:
return # do nothing
Expand Down Expand Up @@ -144,7 +144,7 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]):
eval_env
The environment used to evaluate the formula.
"""
self._ensure_not_converted()
self._ensure_not_converted(context="prior")

if not self.is_regression:
return
Expand Down Expand Up @@ -195,7 +195,7 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An
eval_env
The environment used to evaluate the formula.
"""
self._ensure_not_converted()
self._ensure_not_converted(context="prior")
assert self.name is not None

if not self.is_regression:
Expand Down Expand Up @@ -238,7 +238,7 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An
prior = cast(dict[str, ParamSpec], self.prior)
self.prior = merge_dicts(override_priors, prior)

def _get_design_matrices(self, data: pd.DataFrame, eval_env: dict[str, Any]):
def _get_design_matrices(self, data: pd.DataFrame, extra_namespace: dict[str, Any]):
"""Get the design matrices for the regression.
Parameters
Expand All @@ -251,19 +251,17 @@ def _get_design_matrices(self, data: pd.DataFrame, eval_env: dict[str, Any]):
formula = cast(str, self.formula)
rhs = formula.split("~")[1]
formula = "rt ~ " + rhs
dm = design_matrices(formula, data=data, eval_env=eval_env)
dm = design_matrices(formula, data=data, extra_namespace=extra_namespace)

return dm

def _ensure_not_converted(self):
def _ensure_not_converted(self, context=Literal["link", "prior"]):
"""Ensure that the object has not been converted."""
if self._is_converted:
context = "link function" if context == "link" else "priors"
raise ValueError(
(
"Cannot override the default priors for parameter %s."
+ " The object has already been processed."
)
% self.name,
f"Cannot override the default {context} for parameter {self.name}."
+ " The object has already been processed."
)

def set_parent(self):
Expand Down
19 changes: 13 additions & 6 deletions src/hssm/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,9 @@ def TruncatedDist(name):


def generate_prior(
dist: str | dict | int | float, bounds: tuple[float, float] | None = None, **kwargs
dist: str | dict | int | float | Prior,
bounds: tuple[float, float] | None = None,
**kwargs,
):
"""Generate a Prior distribution.
Expand Down Expand Up @@ -184,10 +186,12 @@ def generate_prior(
prior: Prior | int | float = Prior(dist, bounds=bounds, **default_settings)
elif isinstance(dist, dict):
prior_settings = dist.copy()
dist = prior_settings.pop("dist")
dist_name: str = prior_settings.pop("dist")
for k, v in prior_settings.items():
prior_settings[k] = generate_prior(v)
prior = generate_prior(dist, bounds=bounds, **prior_settings)
prior = Prior(dist_name, bounds=bounds, **prior_settings)
elif isinstance(dist, Prior):
prior = dist
elif isinstance(dist, (int, float)):
if bounds is not None:
lower, upper = bounds
Expand Down Expand Up @@ -246,9 +250,9 @@ def get_default_prior(term_type: str, bounds: tuple[float, float] | None):
else:
prior = generate_prior("Normal")
elif term_type == "group_intercept":
prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=bounds)
prior = generate_prior("Normal", mu="Normal", sigma="Weibull")
elif term_type == "group_specific":
prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=None)
prior = generate_prior("Normal", mu="Normal", sigma="Weibull")
else:
raise ValueError("Unrecognized term type.")
return prior
Expand All @@ -263,7 +267,7 @@ def get_hddm_default_prior(
elif term_type == "common_intercept":
prior = generate_prior(HDDM_MU[param], bounds=bounds)
elif term_type == "group_intercept":
prior = generate_prior(HDDM_SETTINGS_GROUP[param], bounds=bounds)
prior = generate_prior(HDDM_SETTINGS_GROUP[param], bounds=None)
elif term_type == "group_specific":
prior = generate_prior("Normal", mu="Normal", sigma="Weibull", bounds=None)
else:
Expand All @@ -274,6 +278,9 @@ def get_hddm_default_prior(
HSSM_SETTINGS_DISTRIBUTIONS: dict[Any, Any] = {
"Normal": {"mu": 0.0, "sigma": 0.25},
"Weibull": {"alpha": 1.5, "beta": 0.3},
"HalfNormal": {"sigma": 0.25},
"Beta": {"alpha": 1.0, "beta": 1.0},
"Gamma": {"k": 1.0, "theta": 1.0},
}

HDDM_MU: dict[Any, Any] = {
Expand Down

0 comments on commit 08a92e1

Please sign in to comment.