diff --git a/pysyncon/dataprep.py b/pysyncon/dataprep.py index b78a566..f35f431 100644 --- a/pysyncon/dataprep.py +++ b/pysyncon/dataprep.py @@ -131,6 +131,11 @@ def __init__( raise ValueError(f"time_variable {time_variable} not in foo columns.") self.time_variable = time_variable + if foo[[unit_variable, time_variable]].duplicated().any(): + raise ValueError( + "Multiple rows found in `foo` for same [unit, time] pairs." + ) + if isinstance(treatment_identifier, (list, tuple)): for treated in treatment_identifier: # This throws FutureWarning (see https://stackoverflow.com/a/46721064/11594901) @@ -169,7 +174,14 @@ def __init__( ) self.controls_identifier = controls_identifier + if self.foo[self.foo[self.time_variable].isin(time_predictors_prior)].empty: + raise ValueError( + f"foo has no rows in the time range `time_predictors_prior`." + ) self.time_predictors_prior = time_predictors_prior + + if self.foo[self.foo[self.time_variable].isin(time_optimize_ssr)].empty: + raise ValueError(f"foo has no rows in the time range `time_optimize_ssr`.") self.time_optimize_ssr = time_optimize_ssr if special_predictors: @@ -178,11 +190,15 @@ def __init__( raise ValueError( "Elements of special_predictors should be tuples of length 3." ) - predictor, _, op = el + predictor, time_range, op = el if predictor not in foo.columns: raise ValueError( f"{predictor} in special_predictors not in foo columns." ) + if self.foo[self.foo[self.time_variable].isin(time_range)].empty: + raise ValueError( + f"foo has no rows in the time range {time_range} for `special_predictor` {el}." + ) if op not in AGG_OP: agg_op_str = ", ".join([f'"{o}"' for o in AGG_OP]) raise ValueError( diff --git a/tests/test_dataprep.py b/tests/test_dataprep.py index 5acc377..f31f46f 100644 --- a/tests/test_dataprep.py +++ b/tests/test_dataprep.py @@ -157,6 +157,27 @@ def test_init_arg_time_variable(self): self.assertRaises(ValueError, Dataprep, time_variable="badval", **kwargs) + def test_init_multiple_rows(self): + kwargs = { + "predictors": self.predictors, + "predictors_op": self.predictors_op, + "dependent": self.dependent, + "unit_variable": self.unit_variable, + "time_variable": self.time_variable, + "treatment_identifier": self.treatment_identifier, + "controls_identifier": self.controls_identifier, + "time_predictors_prior": self.time_predictors_prior, + "time_optimize_ssr": self.time_optimize_ssr, + "special_predictors": self.special_predictors, + } + + self.assertRaises( + ValueError, + Dataprep, + foo=pd.concat([self.foo, self.foo.iloc[0:1]], axis=0), + **kwargs, + ) + def test_init_arg_treatment_identifier(self): kwargs = { "foo": self.foo, @@ -176,6 +197,61 @@ def test_init_arg_treatment_identifier(self): ValueError, Dataprep, treatment_identifier=["badval"], **kwargs ) + def test_bad_time_periods_time_predictors_prior(self): + kwargs = { + "foo": self.foo, + "predictors": self.predictors, + "predictors_op": self.predictors_op, + "dependent": self.dependent, + "unit_variable": self.unit_variable, + "time_variable": self.time_variable, + "controls_identifier": self.controls_identifier, + "time_optimize_ssr": self.time_optimize_ssr, + "treatment_identifier": self.treatment_identifier, + "special_predictors": self.special_predictors, + } + + self.assertRaises( + ValueError, Dataprep, time_predictors_prior=["2", "3"], **kwargs + ) + + def test_bad_time_periods_time_optimize_ssr(self): + kwargs = { + "foo": self.foo, + "predictors": self.predictors, + "predictors_op": self.predictors_op, + "dependent": self.dependent, + "unit_variable": self.unit_variable, + "time_variable": self.time_variable, + "controls_identifier": self.controls_identifier, + "time_predictors_prior": self.time_predictors_prior, + "treatment_identifier": self.treatment_identifier, + "special_predictors": self.special_predictors, + } + + self.assertRaises(ValueError, Dataprep, time_optimize_ssr=["2", "3"], **kwargs) + + def test_bad_time_periods_special_predictors(self): + kwargs = { + "foo": self.foo, + "predictors": self.predictors, + "predictors_op": self.predictors_op, + "dependent": self.dependent, + "unit_variable": self.unit_variable, + "time_variable": self.time_variable, + "controls_identifier": self.controls_identifier, + "time_optimize_ssr": self.time_optimize_ssr, + "time_predictors_prior": self.time_predictors_prior, + "treatment_identifier": self.treatment_identifier, + } + + self.assertRaises( + ValueError, + Dataprep, + special_predictors=[("predictor1", ["2"], "mean")], + **kwargs, + ) + def test_init_arg_controls_identifier(self): kwargs = { "foo": self.foo,