Skip to content

Commit

Permalink
Merge pull request #64 from sdfordham/further-arg-validation-dataprep
Browse files Browse the repository at this point in the history
Further arg validation dataprep
  • Loading branch information
sdfordham authored May 22, 2024
2 parents 81c9abe + afad7f3 commit ef72c8a
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 1 deletion.
18 changes: 17 additions & 1 deletion pysyncon/dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,11 @@ def __init__(
raise ValueError(f"time_variable {time_variable} not in foo columns.")
self.time_variable = time_variable

if foo[[unit_variable, time_variable]].duplicated().any():
raise ValueError(
"Multiple rows found in `foo` for same [unit, time] pairs."
)

if isinstance(treatment_identifier, (list, tuple)):
for treated in treatment_identifier:
# This throws FutureWarning (see https://stackoverflow.com/a/46721064/11594901)
Expand Down Expand Up @@ -169,7 +174,14 @@ def __init__(
)
self.controls_identifier = controls_identifier

if self.foo[self.foo[self.time_variable].isin(time_predictors_prior)].empty:
raise ValueError(
f"foo has no rows in the time range `time_predictors_prior`."
)
self.time_predictors_prior = time_predictors_prior

if self.foo[self.foo[self.time_variable].isin(time_optimize_ssr)].empty:
raise ValueError(f"foo has no rows in the time range `time_optimize_ssr`.")
self.time_optimize_ssr = time_optimize_ssr

if special_predictors:
Expand All @@ -178,11 +190,15 @@ def __init__(
raise ValueError(
"Elements of special_predictors should be tuples of length 3."
)
predictor, _, op = el
predictor, time_range, op = el
if predictor not in foo.columns:
raise ValueError(
f"{predictor} in special_predictors not in foo columns."
)
if self.foo[self.foo[self.time_variable].isin(time_range)].empty:
raise ValueError(
f"foo has no rows in the time range {time_range} for `special_predictor` {el}."
)
if op not in AGG_OP:
agg_op_str = ", ".join([f'"{o}"' for o in AGG_OP])
raise ValueError(
Expand Down
76 changes: 76 additions & 0 deletions tests/test_dataprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,27 @@ def test_init_arg_time_variable(self):

self.assertRaises(ValueError, Dataprep, time_variable="badval", **kwargs)

def test_init_multiple_rows(self):
    """Dataprep should raise ValueError when `foo` repeats one of its rows."""
    fixture_names = (
        "predictors",
        "predictors_op",
        "dependent",
        "unit_variable",
        "time_variable",
        "treatment_identifier",
        "controls_identifier",
        "time_predictors_prior",
        "time_optimize_ssr",
        "special_predictors",
    )
    init_args = {name: getattr(self, name) for name in fixture_names}

    # Append a copy of the first row so that row appears twice in the panel.
    # Built outside the `with` block so only the constructor call is guarded.
    first_row = self.foo.iloc[0:1]
    panel_with_repeat = pd.concat([self.foo, first_row], axis=0)

    with self.assertRaises(ValueError):
        Dataprep(foo=panel_with_repeat, **init_args)

def test_init_arg_treatment_identifier(self):
kwargs = {
"foo": self.foo,
Expand All @@ -176,6 +197,61 @@ def test_init_arg_treatment_identifier(self):
ValueError, Dataprep, treatment_identifier=["badval"], **kwargs
)

def test_bad_time_periods_time_predictors_prior(self):
    """Dataprep should raise ValueError for an unusable `time_predictors_prior`.

    NOTE(review): presumably the periods "2" and "3" do not occur in the
    fixture's time column -- confirm against the test setup.
    """
    fixture_names = (
        "foo",
        "predictors",
        "predictors_op",
        "dependent",
        "unit_variable",
        "time_variable",
        "controls_identifier",
        "time_optimize_ssr",
        "treatment_identifier",
        "special_predictors",
    )
    init_args = {name: getattr(self, name) for name in fixture_names}

    with self.assertRaises(ValueError):
        Dataprep(time_predictors_prior=["2", "3"], **init_args)

def test_bad_time_periods_time_optimize_ssr(self):
    """Dataprep should raise ValueError for an unusable `time_optimize_ssr`.

    NOTE(review): presumably the periods "2" and "3" do not occur in the
    fixture's time column -- confirm against the test setup.
    """
    fixture_names = (
        "foo",
        "predictors",
        "predictors_op",
        "dependent",
        "unit_variable",
        "time_variable",
        "controls_identifier",
        "time_predictors_prior",
        "treatment_identifier",
        "special_predictors",
    )
    init_args = {name: getattr(self, name) for name in fixture_names}

    with self.assertRaises(ValueError):
        Dataprep(time_optimize_ssr=["2", "3"], **init_args)

def test_bad_time_periods_special_predictors(self):
    """Dataprep should raise ValueError for a special predictor with a bad time range.

    NOTE(review): presumably the period "2" does not occur in the fixture's
    time column and "predictor1" is a valid column -- confirm against the
    test setup.
    """
    fixture_names = (
        "foo",
        "predictors",
        "predictors_op",
        "dependent",
        "unit_variable",
        "time_variable",
        "controls_identifier",
        "time_optimize_ssr",
        "time_predictors_prior",
        "treatment_identifier",
    )
    init_args = {name: getattr(self, name) for name in fixture_names}

    # A single (predictor, time_range, op) triple is passed as the only entry.
    bad_special = [("predictor1", ["2"], "mean")]

    with self.assertRaises(ValueError):
        Dataprep(special_predictors=bad_special, **init_args)

def test_init_arg_controls_identifier(self):
kwargs = {
"foo": self.foo,
Expand Down

0 comments on commit ef72c8a

Please sign in to comment.