Add additional causal analysis flags
kbattocchi committed Jul 10, 2021
1 parent 0cec5ab commit c3d5bd6
Showing 2 changed files with 196 additions and 29 deletions.
151 changes: 122 additions & 29 deletions econml/solutions/causal_analysis/_causal_analysis.py
@@ -9,6 +9,7 @@
import joblib
import lightgbm as lgb
import numpy as np
from numpy.lib.function_base import iterable
import pandas as pd
from sklearn.base import TransformerMixin
from sklearn.compose import ColumnTransformer
@@ -47,6 +48,7 @@ class _CausalInsightsConstants:
ConfoundingIntervalKey = 'confounding_interval'
ViewKey = 'view'
InitArgsKey = 'init_args'
RowData = 'row_data' # NOTE: RowData is mutually exclusive with the other data columns

ALL = [RawFeatureNameKey,
EngineeredNameKey,
@@ -62,7 +64,8 @@ class _CausalInsightsConstants:
CausalComputationTypeKey,
ConfoundingIntervalKey,
ViewKey,
InitArgsKey]
InitArgsKey,
RowData]


def _get_default_shared_insights_output():
@@ -105,6 +108,22 @@ def _get_metadata_causal_insights_keys():
_CausalInsightsConstants.ViewKey]


def _get_column_causal_insights_keys():
return [_CausalInsightsConstants.RawFeatureNameKey,
_CausalInsightsConstants.EngineeredNameKey,
_CausalInsightsConstants.CategoricalColumnKey,
_CausalInsightsConstants.TypeKey]


def _get_data_causal_insights_keys():
return [_CausalInsightsConstants.PointEstimateKey,
_CausalInsightsConstants.StandardErrorKey,
_CausalInsightsConstants.ZStatKey,
_CausalInsightsConstants.ConfidenceIntervalLowerKey,
_CausalInsightsConstants.ConfidenceIntervalUpperKey,
_CausalInsightsConstants.PValueKey]
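
The two helpers above split the insight constants into per-column metadata and per-row statistics; the new row-wise serialization path in _dict_summary further down relies on that split. A minimal sanity sketch (assuming, as the constant names suggest, that the two groups use distinct values):

column_keys = set(_get_column_causal_insights_keys())  # feature name, engineered name, categorical flag, type
data_keys = set(_get_data_causal_insights_keys())      # point estimate, stderr, zstat, CI bounds, p-value
assert column_keys.isdisjoint(data_keys)               # each key plays exactly one of the two roles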


def _first_stage_reg(X, y, *, automl=True, random_state=None, verbose=0):
if automl:
model = GridSearchCVList([make_pipeline(StandardScaler(), LassoCV(random_state=random_state)),
@@ -551,6 +570,15 @@ def fit(self, X, y, warm_start=False):

assert np.ndim(X) == 2, f"X must be a 2-dimensional array, but here had shape {np.shape(X)}"

assert iterable(self.feature_inds), f"feature_inds should be array-like, but got {self.feature_inds}"
assert iterable(self.categorical), f"categorical should be array-like, but got {self.categorical}"
assert self.heterogeneity_inds is None or iterable(self.heterogeneity_inds), (
f"heterogeneity_inds should be None or array-like, but got {self.heterogeneity_inds}")
assert self.feature_names is None or iterable(self.feature_names), (
f"feature_names should be None or array-like, but got {self.feature_names}")
assert self.categories == 'auto' or iterable(self.categories), (
f"categories should be 'auto' or array-like, but got {self.categories}")

# TODO: check compatibility of X and Y lengths

if warm_start:
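
For context, the iterable used in the new assertions is numpy's np.iterable helper, imported at the top of this file from numpy.lib.function_base; a quick standalone illustration of what the checks accept, with made-up inputs:

import numpy as np
np.iterable([0, 2, 5])         # True: a list of feature indices passes the new assertion
np.iterable(np.array([0, 2]))  # True: arrays are fine too
np.iterable(3)                 # False: a bare scalar for feature_inds now fails fast with a clear message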
@@ -860,8 +888,8 @@ def coalesce(attr):

return summary([(key, coalesce(val)) for key, val in props])

def _pandas_summary(self, get_inference, props, n,
expand_arr=False):
def _pandas_summary(self, get_inference, *, props, n,
expand_arr=False, keep_all_levels=False):
"""
Summarizes results into a dataframe.
@@ -874,7 +902,10 @@ def _pandas_summary(self, get_inference, props, n,
n : int
The number of samples in the dataset
expand_arr : boolean, default False
Whether to add an initial sample dimension to the result arrays
Whether to add a synthetic sample dimension to the result arrays when performing internal computations
keep_all_levels : boolean, default False
Whether to keep all levels, even when they don't take on more than one value;
Note that regardless of this argument the "sample" level will only be present if expand_arr is False
"""
def make_dataframe(props):

@@ -890,13 +921,19 @@ def make_dataframe(props):
for res in self._results
for lvl in res.feature_levels],
names=["sample", "outcome", "feature", "feature_value"])
for lvl in index.levels:
if len(lvl) == 1:
if not isinstance(index, pd.MultiIndex):
# can't drop only level
index = pd.Index([self._results[0].feature_name], name="feature")
else:
index = index.droplevel(lvl.name)

if expand_arr:
# There is no actual sample level in this data
index = index.droplevel("sample")

if not keep_all_levels:
for lvl in index.levels:
if len(lvl) == 1:
if not isinstance(index, pd.MultiIndex):
# can't drop only level
index = pd.Index([self._results[0].feature_name], name="feature")
else:
index = index.droplevel(lvl.name)
return pd.DataFrame(to_include, index=index)

return self._summarize(summary=make_dataframe,
@@ -905,14 +942,16 @@ def make_dataframe(props):
expand_arr=expand_arr,
drop_sample=False) # dropping the sample dimension is handled above instead
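
The new keep_all_levels flag simply skips the level-dropping loop in make_dataframe above. A minimal pandas sketch of what that loop does by default (standalone, with made-up level values; the real code also special-cases the point where only one level would remain):

import pandas as pd
idx = pd.MultiIndex.from_product([['y'], ['age', 'income']], names=['outcome', 'feature'])
for lvl in idx.levels:                                    # levels are captured once, before any dropping
    if len(lvl) == 1 and isinstance(idx, pd.MultiIndex):
        idx = idx.droplevel(lvl.name)                     # drop single-valued levels such as 'outcome'
# idx is now Index(['age', 'income'], name='feature'); with keep_all_levels=True both levels survive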

def _dict_summary(self, get_inference, *, props, kind, drop_sample=False, expand_arr=False):
def _dict_summary(self, get_inference, *, n, props, kind, drop_sample=False, expand_arr=False, row_wise=False):
"""
Summarizes results into a dictionary.
Parameters
----------
get_inference : lambda
Method to get the relevant inference results from each result object
n : int
The number of samples in the dataset
props : list of (string, string or lambda)
Set of column names and ways to get the corresponding values from the inference object
kind : string
@@ -921,27 +960,59 @@ def _dict_summary(self, get_inference, *, props, kind, drop_sample=False, expand
Whether to drop the sample dimension from each array
expand_arr : boolean, default False
Whether to add an initial sample dimension to the result arrays
row_wise : boolean, default False
Whether to return a list of dictionaries (one dictionary per row) instead of
a dictionary of lists (one list per column)
"""
def make_dict(props):
# should be serialization-ready and contain no numpy arrays
res = _get_default_specific_insights(kind)
res.update([(key, value.tolist()) for key, value in props])
return {**self._shared, **res}
shared = self._shared

if row_wise:
row_data = {}
# remove entries belonging to row data, since we're including them in the list of nested dictionaries
for k in _get_data_causal_insights_keys():
del res[k]

shared = shared.copy() # copy so that we can modify without affecting shared state
# TODO: Note that there's no column metadata for the sample number - should there be?
for k in _get_column_causal_insights_keys():
# need to replicate the column info for each sample, then remove from the shared data
row_data[k] = shared[k] * n
del shared[k]

# NOTE: the flattened order has the output dimension before the feature dimension
# which may need to be revisited once we support multiclass
row_data.update([(key, value.flatten()) for key, value in props])

# get the length of the list corresponding to the first dictionary key
# `list(row_data)` gets the keys as a list, since `row_data.keys()` can't be indexed into
n_rows = len(row_data[list(row_data)[0]])
res[_CausalInsightsConstants.RowData] = [{key: row_data[key][i]
for key in row_data} for i in range(n_rows)]
else:
res.update([(key, value.tolist()) for key, value in props])

return {**shared, **res}

return self._summarize(summary=make_dict,
get_inference=get_inference,
props=props,
expand_arr=expand_arr,
drop_sample=drop_sample)
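
With row_wise=True the payload switches from a dictionary of lists (one list per column) to a list of per-row dictionaries stored under the new RowData key. A standalone sketch of that transposition, mirroring the n_rows logic above (the column names here are placeholders, not the real constant values):

columns = {'point': [0.5, -0.2, 0.1], 'p_value': [0.01, 0.30, 0.07]}       # placeholder keys and values
n_rows = len(columns[list(columns)[0]])                                    # length of the first column
rows = [{key: columns[key][i] for key in columns} for i in range(n_rows)]
# rows == [{'point': 0.5, 'p_value': 0.01}, {'point': -0.2, 'p_value': 0.3}, {'point': 0.1, 'p_value': 0.07}]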

def global_causal_effect(self, alpha=0.1):
def global_causal_effect(self, *, alpha=0.1, keep_all_levels=False):
"""
Get the global causal effect for each feature as a pandas DataFrame.
Parameters
----------
alpha : float, default 0.1
The confidence level of the confidence interval
keep_all_levels : bool, default False
Whether to keep all levels of the output dataframe ('outcome', 'feature', and 'feature_value')
even if there was only a single value for that level; by default single-valued levels are dropped.
Returns
-------
@@ -959,9 +1030,9 @@ def global_causal_effect(self, alpha=0.1):
"""
# a global inference indicates the effect of that one feature on the outcome
return self._pandas_summary(lambda res: res.global_inference, props=self._point_props(alpha),
n=1, expand_arr=True)
n=1, expand_arr=True, keep_all_levels=keep_all_levels)
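
From the caller's side the new flag only controls whether single-valued index levels are retained; a hedged usage sketch, assuming ca is an already-fitted CausalAnalysis instance from this module:

narrow = ca.global_causal_effect(alpha=0.05)                      # default: single-valued levels are dropped
wide = ca.global_causal_effect(alpha=0.05, keep_all_levels=True)  # keeps levels such as 'outcome' even when constant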

def _global_causal_effect_dict(self, alpha=0.1):
def _global_causal_effect_dict(self, *, alpha=0.1, row_wise=False):
"""
Gets the global causal effect for each feature as dictionary.
@@ -970,7 +1041,7 @@ def _global_causal_effect_dict(self, alpha=0.1):
Only for serialization purposes to upload to AzureML
"""
return self._dict_summary(lambda res: res.global_inference, props=self._point_props(alpha),
kind='global', drop_sample=True, expand_arr=True)
kind='global', n=1, row_wise=row_wise, drop_sample=True, expand_arr=True)

def _cohort_effect_inference(self, Xtest):
assert np.ndim(Xtest) == 2 and np.shape(Xtest)[1] == self._d_x, (
@@ -986,7 +1057,7 @@ def inference_from_result(result):
return est.const_marginal_ate_inference(X=X)
return inference_from_result

def cohort_causal_effect(self, Xtest, alpha=0.1):
def cohort_causal_effect(self, Xtest, *, alpha=0.1, keep_all_levels=False):
"""
Gets the average causal effects for a particular cohort defined by a population of X's.
@@ -996,6 +1067,9 @@ def cohort_causal_effect(self, Xtest, alpha=0.1):
The cohort samples for which to return the average causal effects within cohort
alpha : float, default 0.1
The confidence level of the confidence interval
keep_all_levels : bool, default False
Whether to keep all levels of the output dataframe ('outcome', 'feature', and 'feature_value')
even if there was only a single value for that level; by default single-valued levels are dropped.
Returns
-------
@@ -1013,9 +1087,9 @@ def cohort_causal_effect(self, Xtest, alpha=0.1):
"""
return self._pandas_summary(self._cohort_effect_inference(Xtest),
props=self._summary_props(alpha), n=1,
expand_arr=True)
expand_arr=True, keep_all_levels=keep_all_levels)

def _cohort_causal_effect_dict(self, Xtest, alpha=0.1):
def _cohort_causal_effect_dict(self, Xtest, *, alpha=0.1, row_wise=False):
"""
Gets the cohort causal effects for each feature as dictionary.
@@ -1024,7 +1098,7 @@ def _cohort_causal_effect_dict(self, Xtest, alpha=0.1):
Only for serialization purposes to upload to AzureML
"""
return self._dict_summary(self._cohort_effect_inference(Xtest), props=self._summary_props(alpha),
kind='cohort', expand_arr=True, drop_sample=True)
kind='cohort', n=1, row_wise=row_wise, expand_arr=True, drop_sample=True)

def _local_effect_inference(self, Xtest):
assert np.ndim(Xtest) == 2 and np.shape(Xtest)[1] == self._d_x, (
Expand All @@ -1044,7 +1118,7 @@ def inference_from_result(result):
return eff
return inference_from_result

def local_causal_effect(self, Xtest, alpha=0.1):
def local_causal_effect(self, Xtest, *, alpha=0.1, keep_all_levels=False):
"""
Gets the local causal effect for each feature as a pandas DataFrame.
@@ -1054,7 +1128,9 @@ def local_causal_effect(self, Xtest, alpha=0.1):
The samples for which to return the causal effects
alpha : float, default 0.1
The confidence level of the confidence interval
keep_all_levels : bool, default False
Whether to keep all levels of the output dataframe ('sample', 'outcome', 'feature', and 'feature_value')
even if there was only a single value for that level; by default single-valued levels are dropped.
Returns
-------
local_effect : pandas DataFrame
@@ -1073,9 +1149,9 @@ def local_causal_effect(self, Xtest, alpha=0.1):
in the serialized dict.
"""
return self._pandas_summary(self._local_effect_inference(Xtest),
props=self._point_props(alpha), n=Xtest.shape[0])
props=self._point_props(alpha), n=Xtest.shape[0], keep_all_levels=keep_all_levels)
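
The cohort and local views handle the sample dimension differently: cohort_causal_effect averages over the cohort (n=1 with a synthetic sample dimension), while local_causal_effect keeps one row per sample of Xtest. A hedged usage sketch, assuming ca is a fitted CausalAnalysis instance and X_test is a 2-D array with the same columns used in fit:

per_cohort = ca.cohort_causal_effect(X_test, alpha=0.05)                # average effects within the cohort defined by X_test
per_sample = ca.local_causal_effect(X_test, alpha=0.05)                 # per-row effects; includes a 'sample' index level
per_sample_full = ca.local_causal_effect(X_test, keep_all_levels=True)  # retain even single-valued index levels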

def _local_causal_effect_dict(self, Xtest, alpha=0.1):
def _local_causal_effect_dict(self, Xtest, *, alpha=0.1, row_wise=False):
"""
Gets the local feature importance as dictionary
@@ -1084,7 +1160,7 @@ def _local_causal_effect_dict(self, Xtest, alpha=0.1):
Only for serialization purposes to upload to AzureML
"""
return self._dict_summary(self._local_effect_inference(Xtest), props=self._point_props(alpha),
kind='local')
kind='local', n=Xtest.shape[0], row_wise=row_wise)

def _safe_result_index(self, X, feature_index):
assert hasattr(self, "_results"), "This instance has not yet been fitted"
@@ -1153,7 +1229,7 @@ def whatif(self, X, Xnew, feature_index, y, *, alpha=0.1):
"""
return self._whatif_inference(X, Xnew, feature_index, y).summary_frame(alpha=alpha)

def _whatif_dict(self, X, Xnew, feature_index, y, alpha=0.1):
def _whatif_dict(self, X, Xnew, feature_index, y, *, alpha=0.1, row_wise=False):
"""
Get counterfactual predictions when feature_index is changed to Xnew from its observational counterpart.
@@ -1173,6 +1249,9 @@ def _whatif_dict(self, X, Xnew, feature_index, y, alpha=0.1):
alpha : float in [0, 1], default 0.1
Confidence level of the confidence intervals displayed in the leaf nodes.
A (1-alpha)*100% confidence interval is displayed.
row_wise : boolean, default False
Whether to return a list of dictionaries (one dictionary per row) instead of
a dictionary of lists (one list per column)
Returns
-------
dict : dict
@@ -1182,7 +1261,21 @@ def _whatif_dict(self, X, Xnew, feature_index, y, alpha=0.1):
inf = self._whatif_inference(X, Xnew, feature_index, y)
props = self._point_props(alpha=alpha)
res = _get_default_specific_insights('whatif')
res.update([(key, self._make_accessor(attr)(inf).tolist()) for key, attr in props])
if row_wise:
row_data = {}
# remove entries belonging to row data, since we're including them in the list of nested dictionaries
for k in _get_data_causal_insights_keys():
del res[k]

row_data.update([(key, self._make_accessor(attr)(inf).flatten()) for key, attr in props])

# get the length of the list corresponding to the first dictionary key
# `list(row_data)` gets the keys as a list, since `row_data.keys()` can't be indexed into
n_rows = len(row_data[list(row_data)[0]])
res[_CausalInsightsConstants.RowData] = [{key: row_data[key][i]
for key in row_data} for i in range(n_rows)]
else:
res.update([(key, self._make_accessor(attr)(inf).tolist()) for key, attr in props])
return res
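
When row_wise=True, callers can iterate over the serialized records directly; a small sketch, assuming res is the dictionary returned above and that _point_props uses the point-estimate constant as its key (the props definitions are not shown in this hunk):

rows = res[_CausalInsightsConstants.RowData]                      # one dict per counterfactual row
first_point = rows[0][_CausalInsightsConstants.PointEstimateKey]  # assumed key name; actual keys come from _point_props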

def _tree(self, is_policy, Xtest, feature_index, *, treatment_costs=0,