Skip to content

Commit

Permalink
Add att unit tests on example data
Browse files Browse the repository at this point in the history
  • Loading branch information
sdfordham committed Nov 20, 2023
1 parent 8423a6c commit 156bd2c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 0 deletions.
19 changes: 19 additions & 0 deletions tests/test_synth_basque.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def setUp(self):
)
self.treatment_time = 1975
self.pvalue = 0.16666666666666666
self.att = {'att': -0.6995647842110987, 'se': 0.07078092130438395}
self.att_time_period = range(1975, 1998)

def test_weights(self):
synth = Synth()
Expand Down Expand Up @@ -199,3 +201,20 @@ def test_placebo_weights(self):
placebo_test.pvalue(treatment_time=self.treatment_time),
places=3,
)

def test_att(self):
synth = Synth()
synth.fit(
dataprep=self.dataprep,
optim_method=self.optim_method,
optim_initial=self.optim_initial,
)
synth_att = synth.att(time_period=self.att_time_period)

# Allow a tolerance of 2.5%
att_perc_delta = abs(1.0 - self.att['att'] / synth_att['att'])
self.assertLessEqual(att_perc_delta, 0.025)

# Allow a tolerance of 2.5%
se_perc_delta = abs(1.0 - self.att['se'] / synth_att['se'])
self.assertLessEqual(se_perc_delta, 0.025)
20 changes: 20 additions & 0 deletions tests/test_synth_germany.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def setUp(self):
"Australia": 0.0,
"New Zealand": 0.0,
}
self.att = {'att': -1555.1346777620479, 'se': 317.6469306023242}
self.att_time_period = range(1990, 2004)

def test_weights(self):
synth = Synth()
Expand All @@ -115,3 +117,21 @@ def test_weights(self):
pd.testing.assert_series_equal(
weights, synth.weights(round=9), check_exact=False, atol=0.025
)

def test_att(self):
synth = Synth()
synth.fit(
dataprep=self.dataprep,
optim_method=self.optim_method,
optim_initial=self.optim_initial,
custom_V=self.custom_V,
)
synth_att = synth.att(time_period=self.att_time_period)

# Allow a tolerance of 2.5%
att_perc_delta = abs(1.0 - self.att['att'] / synth_att['att'])
self.assertLessEqual(att_perc_delta, 0.025)

# Allow a tolerance of 2.5%
se_perc_delta = abs(1.0 - self.att['se'] / synth_att['se'])
self.assertLessEqual(se_perc_delta, 0.025)
19 changes: 19 additions & 0 deletions tests/test_synth_texas.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def setUp(self):
"Wisconsin": 0.0,
"Wyoming": 0.0,
}
self.att = {'att': 20339.375838131393, 'se': 3190.4946788704715}
self.att_time_period = range(1993, 2001)

def test_weights(self):
synth = Synth()
Expand All @@ -150,3 +152,20 @@ def test_weights(self):
pd.testing.assert_series_equal(
weights, synth.weights(round=9), check_exact=False, atol=0.025
)

def test_att(self):
synth = Synth()
synth.fit(
dataprep=self.dataprep,
optim_method=self.optim_method,
optim_initial=self.optim_initial,
)
synth_att = synth.att(time_period=self.att_time_period)

# Allow a tolerance of 2.5%
att_perc_delta = abs(1.0 - self.att['att'] / synth_att['att'])
self.assertLessEqual(att_perc_delta, 0.025)

# Allow a tolerance of 2.5%
se_perc_delta = abs(1.0 - self.att['se'] / synth_att['se'])
self.assertLessEqual(se_perc_delta, 0.025)

0 comments on commit 156bd2c

Please sign in to comment.