Skip to content

Commit

Permalink
Merge pull request #32 from sdfordham/add-att-method
Browse files Browse the repository at this point in the history
Add att method
  • Loading branch information
sdfordham authored Nov 20, 2023
2 parents 07f92a9 + 0966a76 commit 870efd5
Show file tree
Hide file tree
Showing 8 changed files with 208 additions and 22 deletions.
61 changes: 44 additions & 17 deletions examples/basque.ipynb

Large diffs are not rendered by default.

38 changes: 36 additions & 2 deletions examples/germany.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -233,10 +233,44 @@
"synth.gaps_plot(time_period=range(1960, 2004), treatment_time=1990)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the average treatment effect on the treated unit (ATT) over the post-treatment time period. This method also returns a standard error."
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'att': -1558.4329540422546, 'se': 317.5609062753852}"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"synth.att(time_period=range(1990, 2004))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The summary function gives more information on the predictor values. The first column shows the value of the $V$ matrix for each predictor, the column 'treated' shows the mean value of each predictor for the treated unit over the time period `time_predictors_prior`, the column 'synthetic' shows the mean value of each predictor for the synthetic control over the time period `time_predictors_prior`, and finally the column 'sample mean' shows the sample mean of that predictor for all control units over the time period `time_predictors_prior`, i.e. this is the same as the synthetic control with all weights equal."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
Expand Down Expand Up @@ -322,7 +356,7 @@
"special.3.invest80 0.155 27.018 27.073 25.895"
]
},
"execution_count": 7,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -348,7 +382,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0"
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
38 changes: 36 additions & 2 deletions examples/texas.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
Expand Down Expand Up @@ -210,6 +210,40 @@
"synth.gaps_plot(time_period=range(1985, 2001), treatment_time=1993)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Compute the average treatment effect on the treated unit (ATT) over the post-treatment time period. This method also returns a standard error."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'att': 20339.375838131393, 'se': 3190.4946788704715}"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"synth.att(time_period=range(1993, 2001))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The summary function gives more information on the predictor values. The first column shows the value of the $V$ matrix for each predictor, the column 'treated' shows the mean value of each predictor for the treated unit over the time period `time_predictors_prior`, the column 'synthetic' shows the mean value of each predictor for the synthetic control over the time period `time_predictors_prior`, and finally the column 'sample mean' shows the sample mean of that predictor for all control units over the time period `time_predictors_prior`, i.e. this is the same as the synthetic control with all weights equal."
]
},
{
"cell_type": "code",
"execution_count": 6,
Expand Down Expand Up @@ -389,7 +423,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.0 (tags/v3.9.0:9cf6752, Oct 5 2020, 15:34:40) [MSC v.1927 64 bit (AMD64)]"
"version": "3.10.4"
},
"orig_nbformat": 4,
"vscode": {
Expand Down
30 changes: 29 additions & 1 deletion pysyncon/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def path_plot(

def _gaps(self, time_period: Optional[IsinArg_t] = None) -> pd.Series:
"""Calculate the gaps (difference between factual
and estimated conterfactual)
and estimated counterfactual)
Parameters
----------
Expand Down Expand Up @@ -238,6 +238,34 @@ def summary(self, round: int = 3) -> pd.DataFrame:

return pd.concat([treated, synthetic, sample_mean], axis=1).round(round)

def att(self, time_period: IsinArg_t) -> dict[str, float]:
    """Compute the average treatment effect on the treated unit (ATT) and
    the standard error of that value over the chosen time period.

    Parameters
    ----------
    time_period : Iterable | pandas.Series | dict
        Time period to compute the ATT over.

    Returns
    -------
    dict
        A dictionary with two entries: ``"att"``, the mean of the gaps over
        the time period, and ``"se"``, the sample standard deviation of the
        gaps (``ddof=1``) divided by the square root of the number of gaps.

    Raises
    ------
    ValueError
        If there is no weight matrix available
    """
    if self.W is None:
        raise ValueError("No weight matrix available; fit data first.")
    gaps = self._gaps(time_period=time_period)

    att = np.mean(gaps)
    # Divide by the number of gaps actually computed rather than by
    # len(time_period): the argument may be an iterable without a length
    # (e.g. a generator), and the two counts need not agree if the selection
    # returns fewer rows than were requested.
    se = np.std(gaps, ddof=1) / np.sqrt(len(gaps))

    # float() yields plain Python floats whether the reductions above hand
    # back NumPy scalars or built-in floats.
    return {"att": float(att), "se": float(se)}


class VanillaOptimMixin:
@staticmethod
Expand Down
5 changes: 5 additions & 0 deletions tests/test_synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,3 +129,8 @@ def test_summary(self):
synth.V = None
# No V matrix available
self.assertRaises(ValueError, synth.summary)

def test_att(self):
    """`att` must raise ValueError when called before a weight matrix is set."""
    unfitted = pysyncon.Synth()
    # No weight matrix set
    with self.assertRaises(ValueError):
        unfitted.att(range(1))
19 changes: 19 additions & 0 deletions tests/test_synth_basque.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,8 @@ def setUp(self):
)
self.treatment_time = 1975
self.pvalue = 0.16666666666666666
self.att = {"att": -0.6995647842110987, "se": 0.07078092130438395}
self.att_time_period = range(1975, 1998)

def test_weights(self):
synth = Synth()
Expand Down Expand Up @@ -199,3 +201,20 @@ def test_placebo_weights(self):
placebo_test.pvalue(treatment_time=self.treatment_time),
places=3,
)

def test_att(self):
    """Fit the model and compare the ATT and its standard error with the
    reference values stored on the test case."""
    model = Synth()
    model.fit(
        dataprep=self.dataprep,
        optim_method=self.optim_method,
        optim_initial=self.optim_initial,
    )
    computed = model.att(time_period=self.att_time_period)

    # Allow a tolerance of 2.5% on each reported quantity, checking the
    # ATT first and the standard error second.
    for key in ("att", "se"):
        relative_delta = abs(1.0 - self.att[key] / computed[key])
        self.assertLessEqual(relative_delta, 0.025)
20 changes: 20 additions & 0 deletions tests/test_synth_germany.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,8 @@ def setUp(self):
"Australia": 0.0,
"New Zealand": 0.0,
}
self.att = {"att": -1555.1346777620479, "se": 317.6469306023242}
self.att_time_period = range(1990, 2004)

def test_weights(self):
synth = Synth()
Expand All @@ -115,3 +117,21 @@ def test_weights(self):
pd.testing.assert_series_equal(
weights, synth.weights(round=9), check_exact=False, atol=0.025
)

def test_att(self):
    """Fit the model with the custom V matrix and compare the ATT and its
    standard error with the reference values stored on the test case."""
    model = Synth()
    model.fit(
        dataprep=self.dataprep,
        optim_method=self.optim_method,
        optim_initial=self.optim_initial,
        custom_V=self.custom_V,
    )
    computed = model.att(time_period=self.att_time_period)

    # Allow a tolerance of 2.5% on each reported quantity, checking the
    # ATT first and the standard error second.
    for key in ("att", "se"):
        relative_delta = abs(1.0 - self.att[key] / computed[key])
        self.assertLessEqual(relative_delta, 0.025)
19 changes: 19 additions & 0 deletions tests/test_synth_texas.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ def setUp(self):
"Wisconsin": 0.0,
"Wyoming": 0.0,
}
self.att = {"att": 20339.375838131393, "se": 3190.4946788704715}
self.att_time_period = range(1993, 2001)

def test_weights(self):
synth = Synth()
Expand All @@ -150,3 +152,20 @@ def test_weights(self):
pd.testing.assert_series_equal(
weights, synth.weights(round=9), check_exact=False, atol=0.025
)

def test_att(self):
    """Fit the model and compare the ATT and its standard error with the
    reference values stored on the test case."""
    model = Synth()
    model.fit(
        dataprep=self.dataprep,
        optim_method=self.optim_method,
        optim_initial=self.optim_initial,
    )
    computed = model.att(time_period=self.att_time_period)

    # Allow a tolerance of 2.5% on each reported quantity, checking the
    # ATT first and the standard error second.
    for key in ("att", "se"):
        relative_delta = abs(1.0 - self.att[key] / computed[key])
        self.assertLessEqual(relative_delta, 0.025)

0 comments on commit 870efd5

Please sign in to comment.