Skip to content

Commit

Permalink
Unit test on fit args to conf intervals
Browse files Browse the repository at this point in the history
  • Loading branch information
sdfordham committed May 10, 2024
1 parent 309720a commit 7946b09
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 41 deletions.
2 changes: 1 addition & 1 deletion pysyncon/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def confidence_intervals(
raise ValueError("`step_sz_div` must be greater than 0.0")
if scm.W is None:
raise ValueError("No weight matrix available; fit data first.")

gaps = scm._gaps(Z0=Z0, Z1=Z1)
if step_sz is None:
if len(post_periods) > 1:
Expand Down
18 changes: 4 additions & 14 deletions tests/test_conformal_interence.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,7 @@ def setUp(self):
index=range(1, 5),
columns=range(1, 11),
)
self.X1 = pd.Series(
data=self.rng.random(size=(4,)), index=range(1, 5), name=0
)
self.X1 = pd.Series(data=self.rng.random(size=(4,)), index=range(1, 5), name=0)
self.pre_periods = list(range(1, 21))
self.post_periods = list(range(21, 31))
self.max_iter = 20
Expand Down Expand Up @@ -218,11 +216,7 @@ def test_no_weights(self):
}

conformal_inf = ConformalInference()
self.assertRaises(
ValueError,
conformal_inf.confidence_intervals,
**kwargs
)
self.assertRaises(ValueError, conformal_inf.confidence_intervals, **kwargs)

def test_step_sz_options(self):
self.scm.fit(X0=self.X0, X1=self.X1, Z0=self.Z0, Z1=self.Z1)
Expand All @@ -240,13 +234,9 @@ def test_step_sz_options(self):
}

conformal_inf = ConformalInference()
conformal_inf.confidence_intervals(post_periods=self.post_periods, **kwargs)
conformal_inf.confidence_intervals(
post_periods=self.post_periods,
**kwargs
)
conformal_inf.confidence_intervals(
post_periods=[self.post_periods[0]],
**kwargs
post_periods=[self.post_periods[0]], **kwargs
)

def test_root_search(self):
Expand Down
60 changes: 34 additions & 26 deletions tests/test_synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,59 +295,46 @@ def test_confidence_intervals(self):
dataprep = pysyncon.Dataprep(**kwargs)
synth = pysyncon.Synth()
synth.fit(dataprep=dataprep)

# Bad option
self.assertRaises(
ValueError,
synth.confidence_interval,
alpha=0.5,
time_periods=[4],
method="foo"
method="foo",
)

# With dataprep supplied
try:
synth.confidence_interval(
alpha=0.5,
time_periods=[4],
dataprep=dataprep
)
synth.confidence_interval(alpha=0.5, time_periods=[4], dataprep=dataprep)
except Exception as e:
self.fail(f"Confidence interval failed: {e}.")

# Too few time periods for alpha value
self.assertRaises(
ValueError,
synth.confidence_interval,
alpha=0.05,
time_periods=[4],
dataprep=dataprep
dataprep=dataprep,
)

# Without dataprep supplied
try:
synth.confidence_interval(
alpha=0.5,
time_periods=[4]
)
synth.confidence_interval(alpha=0.5, time_periods=[4])
except Exception as e:
self.fail(f"Confidence interval failed: {e}.")

# Too few time periods for alpha value
self.assertRaises(
ValueError,
synth.confidence_interval,
alpha=0.05,
time_periods=[4]
ValueError, synth.confidence_interval, alpha=0.05, time_periods=[4]
)

# Without dataprep supplied or matrices
synth.dataprep = None
self.assertRaises(
ValueError,
synth.confidence_interval,
alpha=0.5,
time_periods=[4]
ValueError, synth.confidence_interval, alpha=0.5, time_periods=[4]
)

# No pre-periods supplied
Expand All @@ -362,7 +349,7 @@ def test_confidence_intervals(self):
X0=X0,
X1=X1,
Z0=Z0,
Z1=Z1
Z1=Z1,
)

# Bad alpha value
Expand All @@ -375,9 +362,32 @@ def test_confidence_intervals(self):
X0=X0,
X1=X1,
Z0=Z0,
Z1=Z1
Z1=Z1,
)

# Add fit options
_, n_c = X0.shape
custom_V = np.full(n_c, 1 / n_c)
optim_method = "BFGS"
optim_initial = "ols"
optim_options = {"max_iter": 1000}
try:
synth.confidence_interval(
alpha=0.5,
time_periods=[4],
pre_periods=[1, 2, 3],
X0=X0,
X1=X1,
Z0=Z0,
Z1=Z1,
custom_V=custom_V,
optim_method=optim_method,
optim_initial=optim_initial,
optim_options=optim_options,
)
except Exception as e:
self.fail(f"Confidence interval failed: {e}.")

# Dataframes supplied instead of series
X1 = X1.to_frame()
Z1 = Z1.to_frame()
Expand All @@ -390,7 +400,5 @@ def test_confidence_intervals(self):
X0=X0,
X1=X1,
Z0=Z0,
Z1=Z1
Z1=Z1,
)


0 comments on commit 7946b09

Please sign in to comment.