Skip to content

Commit

Permalink
incorporate displace t correctly into data_generators (#60)
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFengler authored and krishnbera committed Jan 6, 2025
1 parent 329f1bf commit 91e0cd1
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 28 deletions.
54 changes: 32 additions & 22 deletions docs/basic_tutorial/basic_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -71,7 +71,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"outputs": [
{
Expand All @@ -89,7 +89,7 @@
" 'ddm_truncnormt']"
]
},
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -101,19 +101,28 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"metadata": {},
"outputs": [
{
"ename": "NameError",
"evalue": "name 'ssms' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m# Take an example config for a given model\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[43mssms\u001b[49m\u001b[38;5;241m.\u001b[39mconfig\u001b[38;5;241m.\u001b[39mmodel_config[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mddm\u001b[39m\u001b[38;5;124m\"\u001b[39m]\n",
"\u001b[0;31mNameError\u001b[0m: name 'ssms' is not defined"
]
"data": {
"text/plain": [
"{'name': 'ddm',\n",
" 'params': ['v', 'a', 'z', 't'],\n",
" 'param_bounds': [[-3.0, 0.3, 0.1, 0.0], [3.0, 2.5, 0.9, 2.0]],\n",
" 'boundary_name': 'constant',\n",
" 'boundary': <function ssms.basic_simulators.boundary_functions.constant(t: float | numpy.ndarray = 0) -> float | numpy.ndarray>,\n",
" 'boundary_params': [],\n",
" 'n_params': 4,\n",
" 'default_params': [0.0, 1.0, 0.5, 0.001],\n",
" 'nchoices': 2,\n",
" 'n_particles': 1,\n",
" 'simulator': <cyfunction ddm_flexbound at 0x16b3a3c60>}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
Expand All @@ -140,7 +149,7 @@
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -184,7 +193,7 @@
},
{
"cell_type": "code",
"execution_count": 14,
"execution_count": 6,
"metadata": {},
"outputs": [
{
Expand All @@ -211,10 +220,11 @@
" 'n_subruns': 10,\n",
" 'bin_pointwise': False,\n",
" 'separate_response_channels': False,\n",
" 'smooth_unif': True}"
" 'smooth_unif': True,\n",
" 'kde_displace_t': False}"
]
},
"execution_count": 14,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -233,7 +243,7 @@
},
{
"cell_type": "code",
"execution_count": 15,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -258,14 +268,14 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'name': 'angle', 'params': ['v', 'a', 'z', 't', 'theta'], 'param_bounds': [[-3.0, 0.3, 0.1, 0.001, -0.1], [3.0, 3.0, 0.9, 2.0, 1.3]], 'boundary_name': 'angle', 'boundary': <function angle at 0x1260a27a0>, 'n_params': 5, 'default_params': [0.0, 1.0, 0.5, 0.001, 0.0], 'nchoices': 2, 'n_particles': 1, 'simulator': <cyfunction ddm_flexbound at 0x156032f60>}\n"
"{'name': 'angle', 'params': ['v', 'a', 'z', 't', 'theta'], 'param_bounds': [[-3.0, 0.3, 0.1, 0.001, -0.1], [3.0, 3.0, 0.9, 2.0, 1.3]], 'boundary_name': 'angle', 'boundary': <function angle at 0x126eb30a0>, 'n_params': 5, 'default_params': [0.0, 1.0, 0.5, 0.001, 0.0], 'nchoices': 2, 'n_particles': 1, 'simulator': <cyfunction ddm_flexbound at 0x16b3a3c60>}\n"
]
}
],
Expand All @@ -283,7 +293,7 @@
},
{
"cell_type": "code",
"execution_count": 17,
"execution_count": 9,
"metadata": {},
"outputs": [
{
Expand All @@ -303,7 +313,7 @@
},
{
"cell_type": "code",
"execution_count": 18,
"execution_count": 11,
"metadata": {
"tags": []
},
Expand Down
1 change: 1 addition & 0 deletions ssms/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -1747,6 +1747,7 @@ def boundary_config_to_function_params(boundary_config: dict) -> dict:
"bin_pointwise": False,
"separate_response_channels": False,
"smooth_unif": True,
"kde_displace_t": False,
},
# AF-TODO: Add opn, gonogo
"ratio_estimator": {
Expand Down
20 changes: 14 additions & 6 deletions ssms/dataset_generators/lan_mlp.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ssms.basic_simulators.simulator import simulator # , bin_simulator_output
from ssms.support_utils import kde_class
import numpy as np
import warnings
from copy import deepcopy
import pickle
import uuid
Expand Down Expand Up @@ -113,6 +114,18 @@ def __init__(
self.model_config["name"] += "_deadline"
self.model_config["n_params"] += 1

if "kde_displace_t" not in self.generator_config:
self.generator_config["kde_displace_t"] = False

if (
self.generator_config["kde_displace_t"] is True
and self.model_config["name"].split("_deadline")[0] in KDE_NO_DISPLACE_T
):
warnings.warn(
f"kde_displace_t is True, but model is in {KDE_NO_DISPLACE_T}. Overriding setting to False"
)
self.generator_config["kde_displace_t"] = False

# Define constrained parameter space as dictionary
# and add to internal model config
# AF-COMMENT: This will eventually be replaced so that
Expand Down Expand Up @@ -287,12 +300,7 @@ def _make_kde_data(

tmp_kde = kde_class.LogKDE(
simulations,
displace_t=(
True
if self.model_config["name"].split("_deadline")[0]
not in KDE_NO_DISPLACE_T
else False
),
displace_t=self.generator_config["kde_displace_t"],
)

# Get kde part
Expand Down

0 comments on commit 91e0cd1

Please sign in to comment.