Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simple shrink spot #57

Merged
merged 5 commits into from
Dec 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
464 changes: 464 additions & 0 deletions notebooks/basic_tutorial_12122024.ipynb

Large diffs are not rendered by default.

File renamed without changes.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ requires = ["setuptools", "wheel", "Cython>=0.29.23", "numpy >= 1.20"]

[project]
name= "ssm-simulators"
version= "0.7.8"
version= "0.7.9"
authors= [{name = "Alexander Fenger", email = "[email protected]"}]
description= "SSMS is a package collecting simulators and training data generators for a bunch of generative models of interest in the cognitive science / neuroscience and approximate bayesian computation communities"
readme = "README.md"
Expand Down
2 changes: 1 addition & 1 deletion ssms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,6 @@
from . import config
from . import support_utils

__version__ = "0.7.8" # importlib.metadata.version(__package__ or __name__)
__version__ = "0.7.9" # importlib.metadata.version(__package__ or __name__)

__all__ = ["basic_simulators", "dataset_generators", "config", "support_utils"]
54 changes: 47 additions & 7 deletions ssms/basic_simulators/drift_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,9 @@ def ds_conflict_drift(

def attend_drift(
t: np.ndarray = np.arange(0, 20, 0.1),
p_target: float = -0.3,
p_outer: float = -0.3,
p_inner: float = 0.3,
ptarget: float = -0.3,
pouter: float = -0.3,
pinner: float = 0.3,
r: float = 0.5,
sda: float = 2,
) -> np.ndarray:
Expand All @@ -160,11 +160,11 @@ def attend_drift(
t: np.ndarray
Timepoints at which to evaluate the drift.
Usually np.arange() of some sort.
p_outer: float
pouter: float
perceptual input for outer flankers
p_inner: float
pinner: float
perceptual input for inner flankers
p_target: float
ptarget: float
perceptual input for target flanker
r: float
rate parameter for sda decrease
Expand All @@ -184,7 +184,47 @@ def attend_drift(
-0.5, loc=0, scale=new_sda
)

v_t = 2 * p_outer * a_outer + 2 * p_inner * a_inner + p_target * a_target
v_t = (2 * pouter * a_outer) + (2 * pinner * a_inner) + (ptarget * a_target)

return v_t


def attend_drift_simple(
t: np.ndarray = np.arange(0, 20, 0.1),
ptarget: float = -0.3,
pouter: float = -0.3,
r: float = 0.5,
sda: float = 2,
) -> np.ndarray:
"""Drift function for shrinking spotlight model, which involves a time varying
function dependent on a linearly decreasing standard deviation of attention.

Arguments
--------
t: np.ndarray
Timepoints at which to evaluate the drift.
Usually np.arange() of some sort.
pouter: float
perceptual input for outer flankers
ptarget: float
perceptual input for target flanker
r: float
rate parameter for sda decrease
sda: float
width of attentional spotlight
Return
------
np.ndarray
Drift evaluated at timepoints t
"""

new_sda = np.maximum(sda - r * t, 0.001)
a_outer = 1.0 - norm.cdf(
0.5, loc=0, scale=new_sda
) # equivalent to norm.sf(0.5, loc=0, scale=new_sda)
a_target = norm.cdf(0.5, loc=0, scale=new_sda) - 0.5

v_t = (2 * pouter * a_outer) + (2 * ptarget * a_target)

return v_t

Expand Down
7 changes: 6 additions & 1 deletion ssms/basic_simulators/theta_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,12 @@ def process_theta(
theta["sv"]
)

if model in ["shrink_spot", "shrink_spot_extended"]:
if model in [
"shrink_spot",
"shrink_spot_extended",
"shrink_spot_extended_angle",
"shrink_spot_simple_extended",
]:
theta["v"] = np.tile(np.array([0], dtype=np.float32), n_trials)

# Multi-particle models
Expand Down
43 changes: 36 additions & 7 deletions ssms/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,11 @@ def boundary_config_to_function_params(boundary_config: dict) -> dict:
},
"attend_drift": {
"fun": df.attend_drift,
"params": ["p_target", "p_outer", "p_inner", "r", "sda"],
"params": ["ptarget", "pouter", "pinner", "r", "sda"],
},
"attend_drift_simple": {
"fun": df.attend_drift_simple,
"params": ["ptarget", "pouter", "r", "sda"],
},
}

Expand Down Expand Up @@ -344,9 +348,9 @@ def boundary_config_to_function_params(boundary_config: dict) -> dict:
"a",
"z",
"t",
"p.target",
"p.outer",
"p.inner",
"ptarget",
"pouter",
"pinner",
"r",
"sda",
],
Expand All @@ -370,9 +374,9 @@ def boundary_config_to_function_params(boundary_config: dict) -> dict:
"a",
"z",
"t",
"p.target",
"p.outer",
"p.inner",
"ptarget",
"pouter",
"pinner",
"r",
"sda",
],
Expand All @@ -390,6 +394,31 @@ def boundary_config_to_function_params(boundary_config: dict) -> dict:
"n_particles": 1,
"simulator": cssm.ddm_flex,
},
"shrink_spot_simple_extended": {
"name": "shrink_spot_simple_extended",
"params": [
"a",
"z",
"t",
"ptarget",
"pouter",
"r",
"sda",
],
"param_bounds": [
[0.3, 0.1, 1e-3, 2.0, -5.5, 0.01, 1],
[3.0, 0.9, 2.0, 5.5, 5.5, 1.0, 3],
],
"boundary_name": "constant",
"boundary": bf.constant,
"drift_name": "attend_drift_simple",
"drift_fun": df.attend_drift_simple,
"n_params": 7,
"default_params": [0.7, 0.5, 0.25, 2.0, -2.0, 0.01, 1],
"nchoices": 2,
"n_particles": 1,
"simulator": cssm.ddm_flex,
},
"gamma_drift_angle": {
"name": "gamma_drift_angle",
"params": ["v", "a", "z", "t", "theta", "shape", "scale", "c"],
Expand Down
Loading