Skip to content

Commit

Permalink
Merge pull request #53 from sdfordham/add-w-names-attrib
Browse files Browse the repository at this point in the history
W_names allows for correct naming of weights series
  • Loading branch information
sdfordham authored Apr 30, 2024
2 parents d036beb + 58d4c8e commit 0cabf76
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
5 changes: 3 additions & 2 deletions pysyncon/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from __future__ import annotations
from typing import Optional, Literal
from typing import Optional, Literal, Sequence
from abc import ABCMeta, abstractmethod

import numpy as np
Expand All @@ -16,6 +16,7 @@ class BaseSynth(metaclass=ABCMeta):
def __init__(self) -> None:
self.dataprep: Optional[Dataprep] = None
self.W: Optional[np.ndarray] = None
self.W_names: Optional[Sequence] = None

@abstractmethod
def fit(*args, **kwargs) -> None:
Expand Down Expand Up @@ -207,7 +208,7 @@ def weights(self, round: int = 3, threshold: Optional[float] = None) -> pd.Serie
if self.W is None:
raise ValueError("No weight matrix available; fit data first.")
if self.dataprep is None:
weights_ser = pd.Series(self.W, name="weights")
weights_ser = pd.Series(self.W, index=self.W_names, name="weights")
else:
weights_ser = pd.Series(
self.W, index=list(self.dataprep.controls_identifier), name="weights"
Expand Down
1 change: 1 addition & 0 deletions pysyncon/synth.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def fun(x):
loss_V = self.calc_loss_V(W=W, Z0=Z0_arr, Z1=Z1_arr)

self.W, self.loss_W, self.V, self.loss_V = W, loss_W, V_mat.diagonal(), loss_V
self.W_names = Z0.columns

@staticmethod
def calc_loss_V(W: np.ndarray, Z0: np.ndarray, Z1: np.ndarray) -> float:
Expand Down

0 comments on commit 0cabf76

Please sign in to comment.