From 156bd2c0b77d27beddfe3030f4ec2d207b78cc1b Mon Sep 17 00:00:00 2001 From: Stiofain <17852477+sdfordham@users.noreply.github.com> Date: Mon, 20 Nov 2023 19:37:27 +0000 Subject: [PATCH] Add att unit tests on example data --- tests/test_synth_basque.py | 19 +++++++++++++++++++ tests/test_synth_germany.py | 20 ++++++++++++++++++++ tests/test_synth_texas.py | 19 +++++++++++++++++++ 3 files changed, 58 insertions(+) diff --git a/tests/test_synth_basque.py b/tests/test_synth_basque.py index 6c14027..4a8b74c 100644 --- a/tests/test_synth_basque.py +++ b/tests/test_synth_basque.py @@ -154,6 +154,8 @@ def setUp(self): ) self.treatment_time = 1975 self.pvalue = 0.16666666666666666 + self.att = {'att': -0.6995647842110987, 'se': 0.07078092130438395} + self.att_time_period = range(1975, 1998) def test_weights(self): synth = Synth() @@ -199,3 +201,20 @@ def test_placebo_weights(self): placebo_test.pvalue(treatment_time=self.treatment_time), places=3, ) + + def test_att(self): + synth = Synth() + synth.fit( + dataprep=self.dataprep, + optim_method=self.optim_method, + optim_initial=self.optim_initial, + ) + synth_att = synth.att(time_period=self.att_time_period) + + # Allow a tolerance of 2.5% + att_perc_delta = abs(1.0 - self.att['att'] / synth_att['att']) + self.assertLessEqual(att_perc_delta, 0.025) + + # Allow a tolerance of 2.5% + se_perc_delta = abs(1.0 - self.att['se'] / synth_att['se']) + self.assertLessEqual(se_perc_delta, 0.025) diff --git a/tests/test_synth_germany.py b/tests/test_synth_germany.py index 92f8d83..d6ae5c5 100644 --- a/tests/test_synth_germany.py +++ b/tests/test_synth_germany.py @@ -102,6 +102,8 @@ def setUp(self): "Australia": 0.0, "New Zealand": 0.0, } + self.att = {'att': -1555.1346777620479, 'se': 317.6469306023242} + self.att_time_period = range(1990, 2004) def test_weights(self): synth = Synth() @@ -115,3 +117,21 @@ def test_weights(self): pd.testing.assert_series_equal( weights, synth.weights(round=9), check_exact=False, atol=0.025 ) + + def test_att(self): + synth = Synth() + synth.fit( + dataprep=self.dataprep, + optim_method=self.optim_method, + optim_initial=self.optim_initial, + custom_V=self.custom_V, + ) + synth_att = synth.att(time_period=self.att_time_period) + + # Allow a tolerance of 2.5% + att_perc_delta = abs(1.0 - self.att['att'] / synth_att['att']) + self.assertLessEqual(att_perc_delta, 0.025) + + # Allow a tolerance of 2.5% + se_perc_delta = abs(1.0 - self.att['se'] / synth_att['se']) + self.assertLessEqual(se_perc_delta, 0.025) diff --git a/tests/test_synth_texas.py b/tests/test_synth_texas.py index 58fedf0..7574e93 100644 --- a/tests/test_synth_texas.py +++ b/tests/test_synth_texas.py @@ -137,6 +137,8 @@ def setUp(self): "Wisconsin": 0.0, "Wyoming": 0.0, } + self.att = {'att': 20339.375838131393, 'se': 3190.4946788704715} + self.att_time_period = range(1993, 2001) def test_weights(self): synth = Synth() @@ -150,3 +152,20 @@ def test_weights(self): pd.testing.assert_series_equal( weights, synth.weights(round=9), check_exact=False, atol=0.025 ) + + def test_att(self): + synth = Synth() + synth.fit( + dataprep=self.dataprep, + optim_method=self.optim_method, + optim_initial=self.optim_initial, + ) + synth_att = synth.att(time_period=self.att_time_period) + + # Allow a tolerance of 2.5% + att_perc_delta = abs(1.0 - self.att['att'] / synth_att['att']) + self.assertLessEqual(att_perc_delta, 0.025) + + # Allow a tolerance of 2.5% + se_perc_delta = abs(1.0 - self.att['se'] / synth_att['se']) + self.assertLessEqual(se_perc_delta, 0.025)