Skip to content

Commit

Permalink
Add constrained transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
tomicapretto committed Nov 29, 2023
1 parent 2d4b260 commit e83be83
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 4 deletions.
35 changes: 34 additions & 1 deletion bambi/backend/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,34 @@ def build_response_distribution(self, kwargs, pymc_backend):
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)

# Handle constrained responses (through truncated distributions)
elif self.term.is_constrained:
dims = kwargs.pop("dims", None)
data_matrix = kwargs.pop("observed")

# Get values of the response variable
observed = np.squeeze(data_matrix[:, 0])

# Get truncation values
lower = np.squeeze(data_matrix[:, 1])
upper = np.squeeze(data_matrix[:, 2])

# Handle 'None' and scalars appropriately
if np.all(lower == -np.inf):
lower = None
elif np.all(lower == lower[0]):
lower = lower[0]

if np.all(upper == np.inf):
upper = None
elif np.all(upper == upper[0]):
upper = upper[0]

stateless_dist = distribution.dist(**kwargs)
dist_rv = pm.Truncated(
self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims
)

# Handle weighted responses
elif self.term.is_weighted:
dims = kwargs.pop("dims", None)
Expand Down Expand Up @@ -361,7 +389,12 @@ def robustify_dims(self, pymc_backend, kwargs):
if isinstance(self.family, (Multinomial, DirichletMultinomial)):
return kwargs

if self.term.is_censored or self.term.is_truncated or self.term.is_weighted:
if (
self.term.is_censored
or self.term.is_truncated
or self.term.is_weighted
or self.term.is_constrained
):
return kwargs

dims, data = kwargs["dims"], kwargs["observed"]
Expand Down
16 changes: 14 additions & 2 deletions bambi/families/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,9 @@ def posterior_predictive(self, model, posterior, **kwargs):
A data array with the draws from the posterior predictive distribution
"""
response_dist = get_response_dist(model.family)
response_term = model.response_component.response_term
params = model.family.likelihood.params
response_aliased_name = get_aliased_name(model.response_component.response_term)
response_aliased_name = get_aliased_name(response_term)

kwargs.pop("data", None) # Remove the 'data' kwarg
dont_reshape = kwargs.pop("dont_reshape", [])
Expand Down Expand Up @@ -181,7 +182,18 @@ def posterior_predictive(self, model, posterior, **kwargs):
if hasattr(model.family, "transform_kwargs"):
kwargs = model.family.transform_kwargs(kwargs)

output_array = pm.draw(response_dist.dist(**kwargs))
# Handle constrained responses
if response_term.is_constrained:
# Bounds are scalars, we can safely pick them from the first row
lower, upper = response_term.data[0, 1:]
lower = lower if lower != -np.inf else None
upper = upper if upper != np.inf else None
output_array = pm.draw(
pm.Truncated.dist(response_dist.dist(**kwargs), lower=lower, upper=upper)
)
else:
output_array = pm.draw(response_dist.dist(**kwargs))

output_coords_all = xr.merge(output_dataset_list).coords

coord_names = ["chain", "draw", response_aliased_name + "_obs"]
Expand Down
8 changes: 7 additions & 1 deletion bambi/terms/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,20 @@

from bambi.terms.base import BaseTerm

from bambi.terms.utils import is_censored_response, is_truncated_response, is_weighted_response
from bambi.terms.utils import (
is_censored_response,
is_constrained_response,
is_truncated_response,
is_weighted_response,
)


class ResponseTerm(BaseTerm):
def __init__(self, response, family):
self.term = response.term.term
self.family = family
self.is_censored = is_censored_response(self.term)
self.is_constrained = is_constrained_response(self.term)
self.is_truncated = is_truncated_response(self.term)
self.is_weighted = is_weighted_response(self.term)

Expand Down
10 changes: 10 additions & 0 deletions bambi/terms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,16 @@ def is_truncated_response(term):
return is_call_of_kind(component, "truncated")


def is_constrained_response(term):
    """Determines if a formulae term represents a constrained response

    The term qualifies only when it passes `is_single_component`, its first component
    passes `is_call_component`, and `is_call_of_kind` reports that component as a
    "constrained" call. In every other case the result is `False`.
    """
    if is_single_component(term):
        # There is exactly one component, so the first one is the only candidate
        only_component = term.components[0]
        if is_call_component(only_component):
            return is_call_of_kind(only_component, "constrained")
    return False


def is_weighted_response(term):
"""Determines if a formulae term represents a weighted response"""
if not is_single_component(term):
Expand Down
18 changes: 18 additions & 0 deletions bambi/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,23 @@ def truncated(x, lb=None, ub=None):
truncated.__metadata__ = {"kind": "truncated"}


def constrained(x, lb=None, ub=None):
    """Construct an array for a constrained response

    It's built exactly like `truncated()`, but it's interpreted by Bambi in a different way:
    this one truncates/constrains the bounds of a probability distribution, while
    `truncated()` is interpreted as the missing data mechanism.

    Parameters
    ----------
    x
        The values of the response variable. Forwarded to `truncated()` untouched.
    lb
        The lower bound. Either `None` (the default) or a real scalar
        (Python `int`/`float` or a NumPy integer/floating scalar).
    ub
        The upper bound. Either `None` (the default) or a real scalar
        (Python `int`/`float` or a NumPy integer/floating scalar).

    Returns
    -------
    The result of `truncated(x, lb, ub)`.

    Raises
    ------
    TypeError
        If `lb` or `ub` is neither `None` nor a real scalar.
    """
    # Validate with an explicit raise rather than `assert`: assert statements are removed
    # when Python runs with `-O`, which would let non-scalar bounds through silently.
    # NumPy scalar types are accepted too, since bounds are often computed from data.
    scalar_types = (int, float, np.integer, np.floating)
    for bound_name, bound in (("lb", lb), ("ub", ub)):
        if bound is not None and not isinstance(bound, scalar_types):
            raise TypeError(
                f"'{bound_name}' must be None or a scalar number, "
                f"got {type(bound).__name__}"
            )
    return truncated(x, lb, ub)


constrained.__metadata__ = {"kind": "constrained"}


def weighted(x, weights):
"""Construct array for a weighted response
Expand Down Expand Up @@ -403,6 +420,7 @@ def get_distance(x):
transformations_namespace = {
"c": c,
"censored": censored,
"constrained": constrained,
"truncated": truncated,
"weighted": weighted,
"log": np.log,
Expand Down

0 comments on commit e83be83

Please sign in to comment.