From ce3e067f65281bff1b08283ae258787afffa3fe6 Mon Sep 17 00:00:00 2001 From: Paul Xu Date: Thu, 30 Nov 2023 11:30:37 -0500 Subject: [PATCH] adjust how bounds are passed --- pyproject.toml | 11 ++--------- src/hssm/param.py | 10 ++++------ 2 files changed, 6 insertions(+), 15 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0200e92f..eccafae3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,7 +69,7 @@ profile = "black" [tool.ruff] line-length = 88 -target-version = "py39" +target-version = "py310" unfixable = ["E711"] select = [ @@ -166,14 +166,7 @@ ignore = [ "TID252", ] -exclude = [ - ".github", - "docs", - "notebook", - "tests", - "src/hssm/likelihoods/hddm_wfpt/cdfdif_wrapper.c", - "src/hssm/likelihoods/hddm_wfpt/wfpt.cpp", -] +exclude = [".github", "docs", "notebook", "tests"] [tool.ruff.pydocstyle] convention = "numpy" diff --git a/src/hssm/param.py b/src/hssm/param.py index 2f2309a9..e8e4195d 100644 --- a/src/hssm/param.py +++ b/src/hssm/param.py @@ -165,13 +165,11 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]): for name, term in dm.group.terms.items(): if term.kind == "intercept": if has_common_intercept: - override_priors[name] = get_default_prior( - "group_intercept", self.bounds - ) + override_priors[name] = get_default_prior("group_intercept", None) else: # treat the term as any other group-specific term override_priors[name] = get_default_prior( - "group_specific", bounds=None + "group_specific", bounds=self.bounds ) else: override_priors[name] = get_default_prior("group_specific", bounds=None) @@ -220,12 +218,12 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An if term.kind == "intercept": if has_common_intercept: override_priors[name] = get_hddm_default_prior( - "group_intercept", self.name, bounds=self.bounds + "group_intercept", self.name, bounds=None ) else: # treat the term as any other group-specific term override_priors[name] = get_hddm_default_prior( - "group_intercept", self.name, bounds=None + "group_intercept", self.name, bounds=self.bounds ) else: override_priors[name] = get_hddm_default_prior(