Skip to content

Commit

Permalink
run black
Browse files Browse the repository at this point in the history
  • Loading branch information
krishnbera committed Dec 14, 2024
1 parent a276091 commit ba9628f
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
10 changes: 9 additions & 1 deletion ssms/basic_simulators/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,7 +526,15 @@ def check_if_z_gt_a(z: np.ndarray, a: np.ndarray) -> None:
if np.any(z >= a):
raise ValueError("Starting point z >= a for at least one trial")

if model in ["lba_3_v1", "lba_angle_3_v1", "lba_angle_3_v2", "lba_angle_3_v3", "rlwm_lba_race_v1", "rlwm_lba_race_v2", "rlwm_lba_pw_v1"]:
if model in [
"lba_3_v1",
"lba_angle_3_v1",
"lba_angle_3_v2",
"lba_angle_3_v3",
"rlwm_lba_race_v1",
"rlwm_lba_race_v2",
"rlwm_lba_pw_v1",
]:
if model in ["lba_3_v1", "lba_angle_3_v1"]:
check_lba_drifts_sum(theta["v"])
check_if_z_gt_a(theta["z"], theta["a"])
Expand Down
2 changes: 1 addition & 1 deletion ssms/dataset_generators/lan_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def parameter_transform_for_data_gen(self, theta: dict):
Dictionary containing the transformed parameters.
"""

#print(theta)
# print(theta)

if self.model_config["name"] in ["lba_angle_3_v2", "lba_angle_3_v3"]:
# ensure that a is always greater than z
Expand Down

0 comments on commit ba9628f

Please sign in to comment.