Skip to content

Commit

Permalink
entering mypy hell
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderFengler committed Dec 13, 2024
1 parent 650c7aa commit e179ad8
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/hssm/distribution_utils/dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,7 @@ def rng_fn(


def make_distribution(
rv: str | Type[RandomVariable] | Callable,
rv: str | Type[RandomVariable] | RandomVariable | Callable,
loglik: LogLikeFunc | pytensor.graph.Op,
list_params: list[str],
bounds: dict | None = None,
Expand Down Expand Up @@ -607,16 +607,18 @@ def make_distribution(
"""
if isinstance(rv, str):
random_variable = make_ssm_rv(rv, list_params, lapse)
rv_instance = random_variable()
elif isinstance(rv, type) and issubclass(rv, RandomVariable):
random_variable = rv
rv_instance = rv()
elif not isinstance(rv, type) and isinstance(rv, RandomVariable):
random_variable = rv
rv_instance = rv
elif callable(rv):
random_variable = make_custom_rv(

Check warning on line 616 in src/hssm/distribution_utils/dist.py

View check run for this annotation

Codecov / codecov/patch

src/hssm/distribution_utils/dist.py#L611-L616

Added lines #L611 - L616 were not covered by tests
simulator_fun=rv,
list_params=list_params,
lapse=lapse,
)
rv_instance = random_variable()

Check warning on line 621 in src/hssm/distribution_utils/dist.py

View check run for this annotation

Codecov / codecov/patch

src/hssm/distribution_utils/dist.py#L621

Added line #L621 was not covered by tests
else:
raise ValueError(f"rv is {rv}, which is not a valid type.")

Check warning on line 623 in src/hssm/distribution_utils/dist.py

View check run for this annotation

Codecov / codecov/patch

src/hssm/distribution_utils/dist.py#L623

Added line #L623 was not covered by tests

Expand Down Expand Up @@ -645,9 +647,7 @@ class SSMDistribution(pm.Distribution):
# Might be updated in the future once

# NOTE: rv_op is an INSTANCE of RandomVariable
rv_op = (
random_variable() if isinstance(random_variable, type) else random_variable
)
rv_op = rv_instance
_params = list_params

@classmethod
Expand Down

0 comments on commit e179ad8

Please sign in to comment.