Skip to content

Commit

Permalink
Merge pull request #257 from CITCOM-project/conditions
Browse files Browse the repository at this point in the history
We now support placing conditions on the data again.
  • Loading branch information
jmafoster1 authored Jan 31, 2024
2 parents e7a3e90 + 41bd188 commit 9abdcbc
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 22 deletions.
5 changes: 2 additions & 3 deletions causal_testing/json_front/json_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,6 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
"""Create the necessary inputs for a single test case
:param causal_test_case: The concrete test case to be executed
:param test: Single JSON test definition stored in a mapping (dict)
:param conditions: A list of conditions which should be applied to the
data. Conditions should be in the query format detailed at
https://pandas.pydata.org/docs/reference/api/pandas.DataFrame.query.html
:returns:
- estimation_model - Estimator instance for the test being run
"""
Expand All @@ -323,11 +320,13 @@ def _setup_test(self, causal_test_case: CausalTestCase, test: Mapping) -> Estima
minimal_adjustment_set = minimal_adjustment_set - {causal_test_case.treatment_variable}
estimator_kwargs["adjustment_set"] = minimal_adjustment_set

estimator_kwargs["query"] = test["query"] if "query" in test else ""
estimator_kwargs["treatment"] = causal_test_case.treatment_variable.name
estimator_kwargs["treatment_value"] = causal_test_case.treatment_value
estimator_kwargs["control_value"] = causal_test_case.control_value
estimator_kwargs["outcome"] = causal_test_case.outcome_variable.name
estimator_kwargs["effect_modifiers"] = causal_test_case.effect_modifier_configuration
estimator_kwargs["df"] = self.data_collector.collect_data()
estimator_kwargs["alpha"] = test["alpha"] if "alpha" in test else 0.05

estimation_model = test["estimator"](**estimator_kwargs)
Expand Down
5 changes: 3 additions & 2 deletions causal_testing/testing/causal_test_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,10 @@ def push(s, inc=" "):
f"Treatment value: {self.estimator.treatment_value}\n"
f"Outcome: {self.estimator.outcome}\n"
f"Adjustment set: {self.adjustment_set}\n"
f"Formula: {self.estimator.formula}\n"
f"{self.test_value.type}: {result_str}\n"
)
if hasattr(self.estimator, "formula"):
base_str += f"Formula: {self.estimator.formula}\n"
base_str += f"{self.test_value.type}: {result_str}\n"
confidence_str = ""
if self.confidence_intervals:
ci_str = " " + str(self.confidence_intervals)
Expand Down
63 changes: 51 additions & 12 deletions causal_testing/testing/estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import statsmodels.api as sm
import statsmodels.formula.api as smf
from econml.dml import CausalForestDML
from patsy import dmatrix
from patsy import dmatrix # pylint: disable = no-name-in-module

from sklearn.ensemble import GradientBoostingRegressor
from statsmodels.regression.linear_model import RegressionResultsWrapper
Expand Down Expand Up @@ -50,21 +50,25 @@ def __init__(
df: pd.DataFrame = None,
effect_modifiers: dict[str:Any] = None,
alpha: float = 0.05,
query: str = "",
):
self.treatment = treatment
self.treatment_value = treatment_value
self.control_value = control_value
self.adjustment_set = adjustment_set
self.outcome = outcome
self.df = df
self.alpha = alpha
self.df = df.query(query) if query else df

if effect_modifiers is None:
self.effect_modifiers = {}
elif isinstance(effect_modifiers, dict):
self.effect_modifiers = effect_modifiers
else:
raise ValueError(f"Unsupported type for effect_modifiers {effect_modifiers}. Expected iterable")
self.modelling_assumptions = []
if query:
self.modelling_assumptions.append(query)
self.add_modelling_assumptions()
logger.debug("Effect Modifiers: %s", self.effect_modifiers)

Expand Down Expand Up @@ -100,8 +104,18 @@ def __init__(
df: pd.DataFrame = None,
effect_modifiers: dict[str:Any] = None,
formula: str = None,
query: str = "",
):
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers)
super().__init__(
treatment=treatment,
treatment_value=treatment_value,
control_value=control_value,
adjustment_set=adjustment_set,
outcome=outcome,
df=df,
effect_modifiers=effect_modifiers,
query=query,
)

self.model = None

Expand All @@ -116,13 +130,13 @@ def add_modelling_assumptions(self):
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
must hold if the resulting causal inference is to be considered valid.
"""
self.modelling_assumptions += (
self.modelling_assumptions.append(
"The variables in the data must fit a shape which can be expressed as a linear"
"combination of parameters and functions of variables. Note that these functions"
"do not need to be linear."
)
self.modelling_assumptions += "The outcome must be binary."
self.modelling_assumptions += "Independently and identically distributed errors."
self.modelling_assumptions.append("The outcome must be binary.")
self.modelling_assumptions.append("Independently and identically distributed errors.")

def _run_logistic_regression(self, data) -> RegressionResultsWrapper:
"""Run logistic regression of the treatment and adjustment set against the outcome and return the model.
Expand Down Expand Up @@ -291,9 +305,18 @@ def __init__(
effect_modifiers: dict[Variable:Any] = None,
formula: str = None,
alpha: float = 0.05,
query: str = "",
):
super().__init__(
treatment, treatment_value, control_value, adjustment_set, outcome, df, effect_modifiers, alpha=alpha
treatment,
treatment_value,
control_value,
adjustment_set,
outcome,
df,
effect_modifiers,
alpha=alpha,
query=query,
)

self.model = None
Expand All @@ -314,7 +337,7 @@ def add_modelling_assumptions(self):
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
must hold if the resulting causal inference is to be considered valid.
"""
self.modelling_assumptions += (
self.modelling_assumptions.append(
"The variables in the data must fit a shape which can be expressed as a linear"
"combination of parameters and functions of variables. Note that these functions"
"do not need to be linear."
Expand Down Expand Up @@ -509,8 +532,20 @@ def __init__(
df: pd.DataFrame = None,
intercept: int = 1,
effect_modifiers: dict = None, # Not used (yet?). Needed for compatibility
alpha: float = 0.05,
query: str = "",
):
super().__init__(treatment, treatment_value, control_value, adjustment_set, outcome, df, None)
super().__init__(
treatment=treatment,
treatment_value=treatment_value,
control_value=control_value,
adjustment_set=adjustment_set,
outcome=outcome,
df=df,
effect_modifiers=None,
alpha=alpha,
query=query,
)
self.intercept = intercept
self.model = None
self.instrument = instrument
Expand All @@ -520,13 +555,17 @@ def add_modelling_assumptions(self):
Add modelling assumptions to the estimator. This is a list of strings which list the modelling assumptions that
must hold if the resulting causal inference is to be considered valid.
"""
self.modelling_assumptions += """The instrument and the treatment, and the treatment and the outcome must be
self.modelling_assumptions.append(
"""The instrument and the treatment, and the treatment and the outcome must be
related linearly in the form Y = aX + b."""
self.modelling_assumptions += """The three IV conditions must hold
)
self.modelling_assumptions.append(
"""The three IV conditions must hold
(i) Instrument is associated with treatment
(ii) Instrument does not affect outcome except through its potential effect on treatment
(iii) Instrument and outcome do not share causes
"""
)

def estimate_iv_coefficient(self, df):
"""
Expand Down Expand Up @@ -569,7 +608,7 @@ def add_modelling_assumptions(self):
:return self: Update self.modelling_assumptions
"""
self.modelling_assumptions += "Non-parametric estimator: no restrictions imposed on the data."
self.modelling_assumptions.append("Non-parametric estimator: no restrictions imposed on the data.")

def estimate_ate(self) -> float:
"""Estimate the average treatment effect.
Expand Down
15 changes: 10 additions & 5 deletions tests/testing_tests/test_estimators.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,14 @@ def test_ate_adjustment(self):
logistic_regression_estimator = LogisticRegressionEstimator(
"length_in", 65, 55, {"large_gauge"}, "completed", df
)
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config = {"large_gauge": 0})
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})
self.assertEqual(round(ate, 4), -0.3388)

def test_ate_invalid_adjustment(self):
df = self.scarf_df.copy()
logistic_regression_estimator = LogisticRegressionEstimator("length_in", 65, 55, {}, "completed", df)
with self.assertRaises(ValueError):
ate, _ = logistic_regression_estimator.estimate_ate(
adjustment_config = {"large_gauge": 0}
)
ate, _ = logistic_regression_estimator.estimate_ate(adjustment_config={"large_gauge": 0})

def test_ate_effect_modifiers(self):
df = self.scarf_df.copy()
Expand Down Expand Up @@ -216,6 +214,13 @@ def setUpClass(cls) -> None:
cls.nhefs_df = load_nhefs_df()
cls.chapter_11_df = load_chapter_11_df()

def test_query(self):
    """Check that a ``query`` condition restricts the estimator's data to the matching rows."""
    df = self.nhefs_df
    linear_regression_estimator = LinearRegressionEstimator(
        "treatments", None, None, set(), "outcomes", df, query="sex==1"
    )
    filtered = linear_regression_estimator.df
    # Guard against a vacuous pass: ``Series.all()`` is True on an empty frame, so a query
    # that filtered out every row would otherwise still satisfy the check below.
    self.assertFalse(filtered.empty)
    # Compare against the queried value explicitly rather than relying on truthiness.
    self.assertTrue((filtered.sex == 1).all())

def test_program_11_2(self):
"""Test whether our linear regression implementation produces the same results as program 11.2 (p. 141)."""
df = self.chapter_11_df
Expand Down Expand Up @@ -395,7 +400,7 @@ def test_program_15_no_interaction_ate_calculated(self):
# for term_to_square in terms_to_square:

ate, [ci_low, ci_high] = linear_regression_estimator.estimate_ate_calculated(
adjustment_config = {k: self.nhefs_df.mean()[k] for k in covariates}
adjustment_config={k: self.nhefs_df.mean()[k] for k in covariates}
)
self.assertEqual(round(ate, 1), 3.5)
self.assertEqual([round(ci_low, 1), round(ci_high, 1)], [1.9, 5])
Expand Down

0 comments on commit 9abdcbc

Please sign in to comment.