Skip to content

Commit

Permalink
bumped ssms version so tests pass
Browse files Browse the repository at this point in the history
  • Loading branch information
digicosmos86 committed Oct 11, 2023
1 parent d0fa9ca commit 5583825
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ numpy = ">=1.23.4,<1.26"
onnx = "^1.12.0"
jax = "^0.4.0"
jaxlib = "^0.4.0"
ssm-simulators = "^0.4.1"
ssm-simulators = "0.5.1"
huggingface-hub = "^0.15.1"
onnxruntime = "^1.15.0"
bambi = "^0.12.0"
Expand Down
10 changes: 8 additions & 2 deletions src/hssm/hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,12 +183,18 @@ def __init__(
loglik_kind: LoglikKind | None = None,
p_outlier: float | dict | bmb.Prior | None = 0.05,
lapse: dict | bmb.Prior | None = bmb.Prior("Uniform", lower=0.0, upper=10.0),
hierarchical: bool = True,
hierarchical: bool = False,
**kwargs,
):
self.data = data
self._inference_obj = None
self.hierarchical = hierarchical and "participant_id" in data.columns
self.hierarchical = hierarchical

if self.hierarchical and "participant_id" not in self.data.columns:
raise ValueError(
"You have specified a hierarchical model, but there is no "
+ "`participant_id` field in the DataFrame that you have passed."
)

# Construct a model_config from defaults
self.model_config = Config.from_defaults(model, loglik_kind)
Expand Down
13 changes: 10 additions & 3 deletions tests/test_hssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,21 +224,27 @@ def test_hierarchical(data_ddm):
data_ddm = data_ddm.iloc[:10, :].copy()
data_ddm["participant_id"] = np.arange(10)

model = HSSM(data=data_ddm)
model = HSSM(data=data_ddm, hierarchical=True)
assert all(
param.is_regression
for name, param in model.params.items()
if name != "p_outlier"
)

model = HSSM(data=data_ddm, v=bmb.Prior("Uniform", lower=-10.0, upper=10.0))
model = HSSM(
data=data_ddm,
v=bmb.Prior("Uniform", lower=-10.0, upper=10.0),
hierarchical=True,
)
assert all(
param.is_regression
for name, param in model.params.items()
if name not in ["v", "p_outlier"]
)

model = HSSM(data=data_ddm, a=bmb.Prior("Uniform", lower=0.0, upper=10.0))
model = HSSM(
data=data_ddm, a=bmb.Prior("Uniform", lower=0.0, upper=10.0), hierarchical=True
)
assert all(
param.is_regression
for name, param in model.params.items()
Expand All @@ -249,6 +255,7 @@ def test_hierarchical(data_ddm):
data=data_ddm,
v=bmb.Prior("Uniform", lower=-10.0, upper=10.0),
a=bmb.Prior("Uniform", lower=0.0, upper=10.0),
hierarchical=True,
)
assert all(
param.is_regression
Expand Down

0 comments on commit 5583825

Please sign in to comment.