Skip to content

Commit

Permalink
new models
Browse files Browse the repository at this point in the history
  • Loading branch information
krishnbera committed Dec 10, 2024
1 parent 60bca9c commit e8f74a9
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 5 deletions.
6 changes: 4 additions & 2 deletions ssms/basic_simulators/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,8 @@ def check_if_z_gt_a(z, a):
check_lba_drifts_sum(theta["v_RL"])
check_lba_drifts_sum(theta["v_WM"])
check_if_z_gt_a(theta["z"], theta["a"])
elif model in ["rlwm_lba_race_v2"]:
check_if_z_gt_a(theta["z"], theta["a"])
elif model in ["lba_angle_3_v2", "rlwm_lba_pw_v1"]:
check_if_z_gt_a(theta["z"], theta["a"])

Expand Down Expand Up @@ -541,7 +543,7 @@ def simulator(
theta["z"] = np.expand_dims(theta["z"], axis=1)
theta["theta"] = np.expand_dims(theta["theta"], axis=1)

if model == "rlwm_lba_race_v1":
if model in ["rlwm_lba_race_v1", "rlwm_lba_race_v2"]:
sim_param_dict["sd"] = noise_dict["lba_based_models"]
theta["v_RL"] = np.column_stack(
[theta["v_RL_0"], theta["v_RL_1"], theta["v_RL_2"]]
Expand All @@ -551,7 +553,7 @@ def simulator(
)
theta["a"] = np.expand_dims(theta["a"], axis=1)
theta["z"] = np.expand_dims(theta["z"], axis=1)

if model == "rlwm_lba_pw_v1":
sim_param_dict["sd"] = noise_dict["lba_based_models"]
theta["v_RL"] = np.column_stack(
Expand Down
4 changes: 2 additions & 2 deletions ssms/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@
"simulator": cssm.lba_angle,
},
"rlwm_lba_pw_v1": {
"name": "rlwm_lba_pw_v1",
"name": "rlwm_lba_pw_v1",
"params": [
"v_RL_0",
"v_RL_1",
Expand All @@ -320,7 +320,7 @@
"v_WM_2",
"a",
"z",
"t_WM"
"t_WM",
],
"param_bounds": [
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.05, 0.0, 0.01],
Expand Down
2 changes: 1 addition & 1 deletion ssms/dataset_generators/lan_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def parameter_transform_for_data_gen(self, theta):
tmp = theta[3]
theta[3] = theta[4]
theta[4] = tmp

if self.model_config["name"] == "rlwm_lba_pw_v1":
# ensure that a is always greater than z
if theta[6] <= theta[7]:
Expand Down

0 comments on commit e8f74a9

Please sign in to comment.