Skip to content

Commit

Permalink
adjust how bounds are passed
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Nov 30, 2023
1 parent 2a4587c commit ce3e067
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 15 deletions.
11 changes: 2 additions & 9 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ profile = "black"

[tool.ruff]
line-length = 88
target-version = "py39"
target-version = "py310"
unfixable = ["E711"]

select = [
Expand Down Expand Up @@ -166,14 +166,7 @@ ignore = [
"TID252",
]

exclude = [
".github",
"docs",
"notebook",
"tests",
"src/hssm/likelihoods/hddm_wfpt/cdfdif_wrapper.c",
"src/hssm/likelihoods/hddm_wfpt/wfpt.cpp",
]
exclude = [".github", "docs", "notebook", "tests"]

[tool.ruff.pydocstyle]
convention = "numpy"
Expand Down
10 changes: 4 additions & 6 deletions src/hssm/param.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,13 +165,11 @@ def override_default_priors(self, data: pd.DataFrame, eval_env: dict[str, Any]):
for name, term in dm.group.terms.items():
if term.kind == "intercept":
if has_common_intercept:
override_priors[name] = get_default_prior(
"group_intercept", self.bounds
)
override_priors[name] = get_default_prior("group_intercept", None)
else:
# treat the term as any other group-specific term
override_priors[name] = get_default_prior(
"group_specific", bounds=None
"group_specific", bounds=self.bounds
)
else:
override_priors[name] = get_default_prior("group_specific", bounds=None)
Expand Down Expand Up @@ -220,12 +218,12 @@ def override_default_priors_ddm(self, data: pd.DataFrame, eval_env: dict[str, An
if term.kind == "intercept":
if has_common_intercept:
override_priors[name] = get_hddm_default_prior(
"group_intercept", self.name, bounds=self.bounds
"group_intercept", self.name, bounds=None
)
else:
# treat the term as any other group-specific term
override_priors[name] = get_hddm_default_prior(
"group_intercept", self.name, bounds=None
"group_intercept", self.name, bounds=self.bounds
)
else:
override_priors[name] = get_hddm_default_prior(
Expand Down

0 comments on commit ce3e067

Please sign in to comment.