Skip to content

Commit

Permalink
Merge pull request #319 from lnccbrown/292-graphing-quantile-probabil…
Browse files Browse the repository at this point in the history
…ity-plot

First attempt at Quantile probability plot
  • Loading branch information
digicosmos86 authored Nov 15, 2023
2 parents 99dc88a + 42f6339 commit b0c124e
Show file tree
Hide file tree
Showing 16 changed files with 1,174 additions and 436 deletions.
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ repos:
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.278
rev: v0.1.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.10.1
hooks:
- id: black-jupyter
args:
Expand All @@ -29,7 +29,7 @@ repos:
build|
dist"""
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1 # Use the sha / tag you want to point at
rev: v1.6.1 # Use the sha / tag you want to point at
hooks:
- id: mypy
args: [--no-strict-optional, --ignore-missing-imports]
488 changes: 243 additions & 245 deletions docs/tutorials/plotting.ipynb

Large diffs are not rendered by default.

9 changes: 6 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@ bambi = "^0.12.0"
numpyro = "^0.12.1"
hddm-wfpt = "^0.1.1"
seaborn = "^0.13.0"
xhistogram = "^0.3.2"

[tool.poetry.group.dev.dependencies]
pytest = "^7.3.1"
black = { extras = ["jupyter"], version = "^23.7.0" }
mypy = "^1.4.1"
black = { extras = ["jupyter"], version = "^23.10.1" }
mypy = "^1.6.1"
pre-commit = "^2.20.0"
jupyterlab = "^4.0.2"
ipykernel = "^6.16.0"
ipywidgets = "^8.0.3"
graphviz = "^0.20.1"
ruff = "^0.0.272"
ruff = "^0.1.3"
mkdocs = "^1.4.3"
mkdocs-material = "^9.1.17"
mkdocstrings-python = "^1.1.2"
Expand Down Expand Up @@ -159,6 +160,8 @@ ignore = [
"PLR2004",
# Consider `elif` instead of `else` then `if` to remove indentation level
"PLR5501",
# Ignore "Use `float` instead of `int | float`."
"PYI041",
# Allow importing from parent modules
"TID252",
]
Expand Down
12 changes: 9 additions & 3 deletions src/hssm/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,20 @@
"""

import os
from collections import namedtuple
from typing import Optional, Union
from typing import NamedTuple, Optional, Union

import pandas as pd

base_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

FileMetadata = namedtuple("FileMetadata", ["filename", "path", "description"])

class FileMetadata(NamedTuple):
"""Typing for dataset metadata."""

filename: str
path: str
description: str


DATASETS = {
"cavanagh_theta": FileMetadata(
Expand Down
8 changes: 7 additions & 1 deletion src/hssm/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,13 @@ class DefaultConfig(TypedDict):
"approx_differentiable": {
"loglik": "ddm_sdv.onnx",
"backend": "jax",
"default_priors": {},
"default_priors": {
"t": {
"name": "HalfNormal",
"sigma": 2.0,
"initval": 0.1,
},
},
"bounds": {
"v": (-3.0, 3.0),
"a": (0.3, 2.5),
Expand Down
19 changes: 12 additions & 7 deletions src/hssm/distribution_utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import logging
from os import PathLike
from typing import Any, Callable, Iterable, Type
from typing import Any, Callable, Type

import bambi as bmb
import numpy as np
Expand Down Expand Up @@ -279,19 +279,24 @@ def rng_fn(
)
out_shape = sims_out.shape[:-1]
replace = rng.binomial(n=1, p=p_outlier, size=out_shape).astype(bool)
replace = np.stack([replace, replace], axis=-1)
n_draws = np.prod(out_shape)
replace_n = int(np.sum(replace, axis=None))
if replace_n == 0:
return sims_out
replace_shape = (*out_shape[:-1], replace_n)
replace_mask = np.stack([replace, replace], axis=-1)
n_draws = np.prod(replace_shape)
lapse_rt = pm.draw(
get_distribution_from_prior(cls._lapse).dist(**cls._lapse.args),
n_draws,
random_seed=rng,
).reshape(out_shape)
lapse_response = rng.binomial(n=1, p=0.5, size=out_shape)
).reshape(replace_shape)
lapse_response = rng.binomial(n=1, p=0.5, size=replace_shape)
lapse_response = np.where(lapse_response == 1, 1, -1)
lapse_output = np.stack(
[lapse_rt, lapse_response],
axis=-1,
)
np.putmask(sims_out, replace, lapse_output)
np.putmask(sims_out, replace_mask, lapse_output)

return sims_out

Expand Down Expand Up @@ -379,7 +384,7 @@ def dist(cls, **kwargs): # pylint: disable=arguments-renamed

def logp(data, *dist_params): # pylint: disable=E0213
num_params = len(list_params)
extra_fields: Iterable[np.ndarray] = []
extra_fields = []

if num_params < len(dist_params):
extra_fields = dist_params[num_params:]
Expand Down
2 changes: 1 addition & 1 deletion src/hssm/distribution_utils/onnx/onnx2pt.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def pt_interpret_onnx(graph, *args):
"""
vals = dict(
{n.name: a for n, a in zip(graph.input, args)},
**{n.name: _asarray(n) for n in graph.initializer}
**{n.name: _asarray(n) for n in graph.initializer},
)
for node in graph.node:
args = (vals[name] for name in node.input)
Expand Down
2 changes: 1 addition & 1 deletion src/hssm/distribution_utils/onnx/onnx2xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def interpret_onnx(graph, *args):
"""
vals = dict(
{n.name: a for n, a in zip(graph.input, args)},
**{n.name: _asarray(n) for n in graph.initializer}
**{n.name: _asarray(n) for n in graph.initializer},
)
for node in graph.node:
args = (vals[name] for name in node.input)
Expand Down
3 changes: 2 additions & 1 deletion src/hssm/plotting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Plotting functionalities for HSSM."""

from .posterior_predictive import plot_posterior_predictive
from .quantile_probability import plot_quantile_probability

__all__ = ["plot_posterior_predictive"]
__all__ = ["plot_posterior_predictive", "plot_quantile_probability"]
Loading

0 comments on commit b0c124e

Please sign in to comment.