diff --git a/bambi/backend/terms.py b/bambi/backend/terms.py index 9f9dff549..0437d00c8 100644 --- a/bambi/backend/terms.py +++ b/bambi/backend/terms.py @@ -326,6 +326,34 @@ def build_response_distribution(self, kwargs, pymc_backend): self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims ) + # Handle constrained responses (through truncated distributions) + elif self.term.is_constrained: + dims = kwargs.pop("dims", None) + data_matrix = kwargs.pop("observed") + + # Get values of the response variable + observed = np.squeeze(data_matrix[:, 0]) + + # Get truncation values + lower = np.squeeze(data_matrix[:, 1]) + upper = np.squeeze(data_matrix[:, 2]) + + # Handle 'None' and scalars appropriately + if np.all(lower == -np.inf): + lower = None + elif np.all(lower == lower[0]): + lower = lower[0] + + if np.all(upper == np.inf): + upper = None + elif np.all(upper == upper[0]): + upper = upper[0] + + stateless_dist = distribution.dist(**kwargs) + dist_rv = pm.Truncated( + self.name, stateless_dist, lower=lower, upper=upper, observed=observed, dims=dims + ) + # Handle weighted responses elif self.term.is_weighted: dims = kwargs.pop("dims", None) @@ -361,7 +389,12 @@ def robustify_dims(self, pymc_backend, kwargs): if isinstance(self.family, (Multinomial, DirichletMultinomial)): return kwargs - if self.term.is_censored or self.term.is_truncated or self.term.is_weighted: + if ( + self.term.is_censored + or self.term.is_truncated + or self.term.is_weighted + or self.term.is_constrained + ): return kwargs dims, data = kwargs["dims"], kwargs["observed"] diff --git a/bambi/families/family.py b/bambi/families/family.py index 9f1375ced..e46482252 100644 --- a/bambi/families/family.py +++ b/bambi/families/family.py @@ -132,8 +132,9 @@ def posterior_predictive(self, model, posterior, **kwargs): A data array with the draws from the posterior predictive distribution """ response_dist = get_response_dist(model.family) + response_term = 
model.response_component.response_term params = model.family.likelihood.params - response_aliased_name = get_aliased_name(model.response_component.response_term) + response_aliased_name = get_aliased_name(response_term) kwargs.pop("data", None) # Remove the 'data' kwarg dont_reshape = kwargs.pop("dont_reshape", []) @@ -181,7 +182,18 @@ def posterior_predictive(self, model, posterior, **kwargs): if hasattr(model.family, "transform_kwargs"): kwargs = model.family.transform_kwargs(kwargs) - output_array = pm.draw(response_dist.dist(**kwargs)) + # Handle constrained responses + if response_term.is_constrained: + # Bounds are scalars, we can safely pick them from the first row + lower, upper = response_term.data[0, 1:] + lower = lower if lower != -np.inf else None + upper = upper if upper != np.inf else None + output_array = pm.draw( + pm.Truncated.dist(response_dist.dist(**kwargs), lower=lower, upper=upper) + ) + else: + output_array = pm.draw(response_dist.dist(**kwargs)) + output_coords_all = xr.merge(output_dataset_list).coords coord_names = ["chain", "draw", response_aliased_name + "_obs"] diff --git a/bambi/terms/response.py b/bambi/terms/response.py index 1751d8964..7249c48a7 100644 --- a/bambi/terms/response.py +++ b/bambi/terms/response.py @@ -2,7 +2,12 @@ from bambi.terms.base import BaseTerm -from bambi.terms.utils import is_censored_response, is_truncated_response, is_weighted_response +from bambi.terms.utils import ( + is_censored_response, + is_constrained_response, + is_truncated_response, + is_weighted_response, +) class ResponseTerm(BaseTerm): @@ -10,6 +15,7 @@ def __init__(self, response, family): self.term = response.term.term self.family = family self.is_censored = is_censored_response(self.term) + self.is_constrained = is_constrained_response(self.term) self.is_truncated = is_truncated_response(self.term) self.is_weighted = is_weighted_response(self.term) diff --git a/bambi/terms/utils.py b/bambi/terms/utils.py index b0cc5bc36..34abb3de1 100644 --- 
a/bambi/terms/utils.py
+++ b/bambi/terms/utils.py
@@ -42,6 +42,16 @@ def is_truncated_response(term):
     return is_call_of_kind(component, "truncated")
 
 
+def is_constrained_response(term):
+    """Determines if a formulae term represents a constrained response"""
+    if not is_single_component(term):
+        return False
+    component = term.components[0]  # get the first (and single) component
+    if not is_call_component(component):
+        return False
+    return is_call_of_kind(component, "constrained")
+
+
 def is_weighted_response(term):
     """Determines if a formulae term represents a weighted response"""
     if not is_single_component(term):
diff --git a/bambi/transformations.py b/bambi/transformations.py
index 03a489a35..bd0616620 100644
--- a/bambi/transformations.py
+++ b/bambi/transformations.py
@@ -128,6 +128,34 @@ def truncated(x, lb=None, ub=None):
 truncated.__metadata__ = {"kind": "truncated"}
 
 
+def constrained(x, lb=None, ub=None):
+    """Construct an array for a constrained response
+
+    It's exactly like truncated, but it's interpreted by Bambi in a different way as this
+    one truncates/constrains the bounds of a probability distribution, while `truncated()` is
+    interpreted as the missing data mechanism.
+
+    `lb` and `ub` can only be scalar values.
+
+    Raises
+    ------
+    TypeError
+        If `lb` or `ub` is neither `None` nor a scalar number.
+    """
+    # Explicit raises instead of `assert`: assertions are stripped when Python runs with `-O`,
+    # which would let non-scalar bounds reach `truncated()` unchecked.
+    scalar_types = (int, float, np.integer, np.floating)
+    for name, bound in (("lb", lb), ("ub", ub)):
+        if bound is not None and not isinstance(bound, scalar_types):
+            raise TypeError(
+                f"'{name}' must be None or a scalar number, not {type(bound).__name__}"
+            )
+    return truncated(x, lb, ub)
+
+
+constrained.__metadata__ = {"kind": "constrained"}
+
+
 def weighted(x, weights):
     """Construct array for a weighted response
 
@@ -403,6 +431,7 @@ def get_distance(x):
 transformations_namespace = {
     "c": c,
     "censored": censored,
+    "constrained": constrained,
     "truncated": truncated,
     "weighted": weighted,
     "log": np.log,