Skip to content

Commit

Permalink
Added dims parameter to vdid
Browse files Browse the repository at this point in the history
  • Loading branch information
Julian Blank committed Aug 14, 2024
1 parent 242bd0c commit 862afcc
Showing 1 changed file with 22 additions and 18 deletions.
40 changes: 22 additions & 18 deletions azcausal/estimators/panel/vdid.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
from typing import Iterator, List, Callable

import numpy as np
Expand Down Expand Up @@ -81,17 +82,21 @@ def group_by_index(dx):
return dx.groupby(list(dx.index.names))


def vdid_avg_by(dx, label, col, dim=None):
    """Average ``dx`` over the index level ``col`` within each ``label`` group.

    Parameters
    ----------
    dx
        pandas object whose index contains levels named ``label`` and ``col``.
    label
        Name of the index level that defines the groups.
    col
        Name of the index level that is summed out and removed.
    dim
        Optional Series indexed by ``label`` values whose entries are the
        collections of ``col`` members belonging to each label. When omitted
        it is derived from the members actually present in ``dx``. Supplying
        it fixes the divisor independently of what is observed in ``dx``.

    Returns
    -------
    tuple
        ``(avg, counts)`` where ``avg`` is ``dx`` summed over ``col`` and
        divided by the per-label member count, and ``counts`` is that
        per-label member count.
    """
    if dim is None:
        # Fall back to the members observed in the data itself.
        dim = dx.reset_index().groupby(label)[col].unique()

    # Number of members per label; this is the divisor of the average.
    counts = dim.map(len)

    # Sum over `col` by dropping that level and regrouping on what remains,
    # then scale each row by 1 / count of its label.
    totals = group_by_index(dx.droplevel(col, axis='index')).sum()
    avg = totals.multiply(1 / counts, axis='index', level=label)
    return avg, counts


def vdid_avg(dx, groups, dims=None):
    """Apply ``vdid_avg_by`` successively for every ``(label, col)`` pair.

    Parameters
    ----------
    dx
        pandas object passed to the first ``vdid_avg_by`` call; each call
        receives the averaged result of the previous one.
    groups
        Iterable of ``(label, col)`` pairs, processed in order.
    dims
        Optional mapping from label to the ``dim`` argument forwarded to
        ``vdid_avg_by``. Labels missing from the mapping are forwarded as
        ``None``.

    Returns
    -------
    tuple
        ``(dx, counts)`` where ``dx`` is the result after all averaging steps
        and ``counts`` maps each processed label to the counts returned by
        ``vdid_avg_by`` for it.
    """
    if dims is None:
        dims = dict()

    counts = dict()
    for label, col in groups:
        # `.get` yields None for labels without an explicit dimension.
        dx, counts[label] = vdid_avg_by(dx, label, col, dims.get(label))
    return dx, counts


Expand Down Expand Up @@ -170,13 +175,10 @@ def sample(treatment: pd.Series) -> Iterator[pd.Series]:
return sample, vdid_se


def dot_by_columns(ds, columns, name):
    """Average groups of columns of ``ds`` into one output column per group.

    Parameters
    ----------
    ds
        DataFrame holding the columns to combine.
    columns
        Series mapping each output key to the collection of ``ds`` column
        labels that belong to it.
    name
        Name given to the column axis of the result.

    Returns
    -------
    pandas.DataFrame
        Frame indexed like ``ds`` with one column per key of ``columns``.
        Each value is the row-wise sum of the group's columns that exist in
        ``ds``, divided by the number of labels declared for the group, so
        labels absent from ``ds`` still count in the divisor.
    """
    # Divisor uses the declared group size, not the number of columns found.
    counts = columns.map(len)

    # Keep only labels present in `ds`. A plain list is used rather than a
    # numpy array so that mixed-type or tuple labels are not coerced.
    present = columns.map(lambda members: [e for e in members if e in ds])

    averaged = {
        key: np.sum(ds[members].values, axis=1) / counts[key]
        for key, members in present.items()
    }
    return pd.DataFrame(averaged, index=ds.index).rename_axis(name, axis=1)


def vdid_sign(row):
Expand Down Expand Up @@ -207,6 +209,7 @@ def vdid(dx: pd.DataFrame,
ratio=None,
ratio_marginal=None,
fillna=None,
dims=None,
f: Callable = lambda dx: dx,
g: Callable = lambda dx: dx
):
Expand All @@ -216,6 +219,8 @@ def vdid(dx: pd.DataFrame,
ratio_marginal = dict()
if randomize is None:
randomize, _ = diffs[-1]
if dims is None:
dims = defaultdict(None)

labels = {k: v for k, v in diffs}
did = list(labels.keys())
Expand All @@ -233,12 +238,11 @@ def vdid(dx: pd.DataFrame,
dx = pd.melt(dx, id_vars=index, var_name='target', value_name='value').set_index(index + ['target'])['value']

# grouping along the difference list that was provided
davg = dx
counts = dict()

davg, counts = vdid_avg(davg, [(k, v) for (k, v) in diffs if k != randomize], counts=counts)
davg, counts = vdid_avg(dx, [(k, v) for (k, v) in diffs if k != randomize], dims=dims)

units = davg.reset_index().groupby(randomize)[labels[randomize]].unique()
units = dims.get(randomize)
if units is None:
units = davg.reset_index().groupby(randomize)[labels[randomize]].unique()
counts[randomize] = units.map(lambda x: len(x))

matrix = davg.droplevel(axis='index', level=randomize).unstack(labels[randomize]).fillna(0.0)
Expand Down

0 comments on commit 862afcc

Please sign in to comment.