Skip to content

Commit

Permalink
Make tolerance required in fn
Browse files Browse the repository at this point in the history
  • Loading branch information
sdfordham committed May 10, 2024
1 parent 23b7ca7 commit ac8c74d
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 80 deletions.
137 changes: 111 additions & 26 deletions examples/basque.ipynb

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions examples/factor-model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -343,12 +343,12 @@
"_ = synth.confidence_interval(\n",
" alpha=0.05,\n",
" time_periods=[51, 52],\n",
" tol=0.1,\n",
" pre_periods=list(range(1, 51)),\n",
" X0=X0,\n",
" X1=X1,\n",
" Z0=Z0,\n",
" Z1=Z1,\n",
" max_iter=40\n",
" Z1=Z1\n",
")"
]
}
Expand Down
62 changes: 31 additions & 31 deletions examples/germany.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -464,62 +464,62 @@
" <tr>\n",
" <th>1991</th>\n",
" <td>279.096860</td>\n",
" <td>43.148688</td>\n",
" <td>515.045031</td>\n",
" <td>43.070748</td>\n",
" <td>515.122971</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1992</th>\n",
" <td>99.762034</td>\n",
" <td>-136.186137</td>\n",
" <td>335.710206</td>\n",
" <td>-136.264077</td>\n",
" <td>335.788146</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1993</th>\n",
" <td>-631.543723</td>\n",
" <td>-867.491895</td>\n",
" <td>-395.595552</td>\n",
" <td>-867.569835</td>\n",
" <td>-395.517612</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1994</th>\n",
" <td>-1050.267990</td>\n",
" <td>-1286.216162</td>\n",
" <td>-814.319818</td>\n",
" <td>-1286.294102</td>\n",
" <td>-814.241878</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1995</th>\n",
" <td>-1205.254923</td>\n",
" <td>-1441.203094</td>\n",
" <td>-969.306751</td>\n",
" <td>-1441.281034</td>\n",
" <td>-969.228811</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1996</th>\n",
" <td>-1467.249163</td>\n",
" <td>-1703.197334</td>\n",
" <td>-1231.300991</td>\n",
" <td>-1703.275274</td>\n",
" <td>-1231.223051</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1997</th>\n",
" <td>-1954.374169</td>\n",
" <td>-2190.322341</td>\n",
" <td>-1718.425997</td>\n",
" <td>-2190.400281</td>\n",
" <td>-1718.348057</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1998</th>\n",
" <td>-2008.396030</td>\n",
" <td>-2244.344202</td>\n",
" <td>-1772.447858</td>\n",
" <td>-2244.422142</td>\n",
" <td>-1772.369918</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1999</th>\n",
" <td>-2160.627037</td>\n",
" <td>-2396.575208</td>\n",
" <td>-1924.678865</td>\n",
" <td>-2396.653148</td>\n",
" <td>-1924.600925</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2000</th>\n",
" <td>-2620.733091</td>\n",
" <td>-2856.681263</td>\n",
" <td>-2384.784919</td>\n",
" <td>-2856.759203</td>\n",
" <td>-2384.706979</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
Expand All @@ -528,16 +528,16 @@
"text/plain": [
" value lower_ci upper_ci\n",
"time \n",
"1991 279.096860 43.148688 515.045031\n",
"1992 99.762034 -136.186137 335.710206\n",
"1993 -631.543723 -867.491895 -395.595552\n",
"1994 -1050.267990 -1286.216162 -814.319818\n",
"1995 -1205.254923 -1441.203094 -969.306751\n",
"1996 -1467.249163 -1703.197334 -1231.300991\n",
"1997 -1954.374169 -2190.322341 -1718.425997\n",
"1998 -2008.396030 -2244.344202 -1772.447858\n",
"1999 -2160.627037 -2396.575208 -1924.678865\n",
"2000 -2620.733091 -2856.681263 -2384.784919"
"1991 279.096860 43.070748 515.122971\n",
"1992 99.762034 -136.264077 335.788146\n",
"1993 -631.543723 -867.569835 -395.517612\n",
"1994 -1050.267990 -1286.294102 -814.241878\n",
"1995 -1205.254923 -1441.281034 -969.228811\n",
"1996 -1467.249163 -1703.275274 -1231.223051\n",
"1997 -1954.374169 -2190.400281 -1718.348057\n",
"1998 -2008.396030 -2244.422142 -1772.369918\n",
"1999 -2160.627037 -2396.653148 -1924.600925\n",
"2000 -2620.733091 -2856.759203 -2384.706979"
]
},
"execution_count": 10,
Expand All @@ -550,7 +550,7 @@
" alpha=0.05,\n",
" time_periods=[1991, 1992, 1993, 1994, 1995, 1996, 1997, 1998, 1999, 2000],\n",
" custom_V=synth_train.V,\n",
" max_iter=50,\n",
" tol=0.01,\n",
" verbose=False\n",
")"
]
Expand Down
33 changes: 19 additions & 14 deletions pysyncon/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Callable
from typing import Optional, Callable, Literal

import numpy as np
import pandas as pd
Expand All @@ -8,8 +8,8 @@


class ConformalInference:
"""Implementation of the conformal inference method of
Chernozhukov et al. :cite:`inference2021`
"""Implementation of the conformal inference based confidence intervals
following Chernozhukov et al. :cite:`inference2021`
"""

def __init__(self) -> None:
Expand All @@ -23,15 +23,15 @@ def confidence_intervals(
Z1: pd.Series,
pre_periods: list,
post_periods: list,
max_iter: int = 20,
tol: float = 0.1,
max_iter: int = 50,
step_sz: Optional[float] = None,
step_sz_div: float = 20.0,
verbose: bool = True,
scm_fit_args: dict = {},
) -> pd.DataFrame:
"""Confidence intervals obtained from test-inversion, where
the p-values are obtained by adjusted refits of the data
the p-values are obtained by adjusted re-fits of the data
following Chernozhukov et al. :cite:`inference2021`.
Parameters
Expand All @@ -48,20 +48,19 @@ def confidence_intervals(
Z1 : pd.Series
Column vector giving the outcome variable values over time for the
treated unit.
tol : float
The required tolerance (accuracy) required when calculating the
lower/upper cut-off point of the confidence interval. The search
will try to obtain this tolerance level but will not exceed `max_iter`
iterations trying to achieve that.
pre_periods : list
The time-periods to use for the optimization when refitting the
data with the adjusted outcomes.
post_periods : list
The time-periods to calculate confidence intervals for.
max_iter : int, optional
Maximum number of times to re-fit the data when trying to locate
the lower/upper cut-off point and when binary searching for the
cut-off point, by default 20
tol : float, optional
The required tolerance (accuracy) required when calculating the
lower/upper cut-off point of the confidence interval. The search
will try to obtain this tolerance level but will not exceed `max_iter`
iterations trying to achieve that, by default 0.1.
the lower/upper cut-off point, by default 50
step_sz : Optional[float], optional
Step size to use when searching for an interval that contains the
lower or upper cut-off point of the confidence interval, by default None
Expand Down Expand Up @@ -124,6 +123,8 @@ def confidence_intervals(
raise TypeError("`step_sz` should be a float")
elif step_sz <= 0.0:
raise ValueError("`step_sz` should be greater than 0.0")
elif step_sz <= tol:
raise ValueError("`step_sz` must be greater than `tol`.")
if not isinstance(step_sz_div, float):
raise TypeError("`step_sz_div` must be a float")
elif step_sz_div <= 0.0:
Expand All @@ -133,12 +134,16 @@ def confidence_intervals(

gaps = scm._gaps(Z0=Z0, Z1=Z1)
if step_sz is None:
# Try to guess a step-size
if len(post_periods) > 1:
factor = np.std(gaps.loc[post_periods])
else:
factor = gaps.loc[post_periods].item() / 2.0
step_sz = 2.5 * factor / step_sz_div

if step_sz <= tol:
# Failed to guess a sensible step-size :(
step_sz = 1.1 * tol

conf_interval = dict()
n_periods = len(post_periods)
for idx, post_period in enumerate(post_periods, 1):
Expand Down Expand Up @@ -195,7 +200,7 @@ def _root_search(
self,
fn: Callable,
x0: float,
direction: int,
direction: Literal[+1, -1],
tol: float,
step_sz: float,
max_iter: int,
Expand Down
14 changes: 7 additions & 7 deletions pysyncon/synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def confidence_interval(
self,
alpha: float,
time_periods: list,
tol: float,
pre_periods: Optional[list] = None,
dataprep: Optional[Dataprep] = None,
X0: Optional[pd.DataFrame] = None,
Expand All @@ -271,8 +272,7 @@ def confidence_interval(
optim_initial: Literal["equal", "ols"] = None,
optim_options: dict = None,
method: Literal["conformal"] = "conformal",
max_iter: int = 20,
tol: float = 0.1,
max_iter: int = 50,
step_sz: Optional[float] = None,
step_sz_div: float = 20.0,
verbose: bool = True,
Expand All @@ -288,6 +288,11 @@ def confidence_interval(
yield a confidence level of 100 * (1 - alpha) = 95%.
time_periods : list
The time-periods to calculate confidence intervals for.
tol : float
The required tolerance (accuracy) required when calculating the
lower/upper cut-off point of the confidence interval. The search
will try to obtain this tolerance level but will not exceed `max_iter`
iterations trying to achieve that.
pre_periods : Optional[list], optional
The time-periods to use for the optimization when refitting the
data with the adjusted outcomes, optional.
Expand Down Expand Up @@ -357,11 +362,6 @@ def confidence_interval(
Maximum number of times to re-fit the data when trying to locate
the lower/upper cut-off point and when binary searching for the
cut-off point, by default 20.
tol : float, optional
The required tolerance (accuracy) required when calculating the
lower/upper cut-off point of the confidence interval. The search
will try to obtain this tolerance level but will not exceed `max_iter`
iterations trying to achieve that, by default 0.1.
step_sz : Optional[float], optional
Step size to use when searching for an interval that contains the
lower or upper cut-off point of the confidence interval, by default None.
Expand Down

0 comments on commit ac8c74d

Please sign in to comment.