From 58d4c8eda1224d3caa717929c195d756eba4254e Mon Sep 17 00:00:00 2001 From: Stiofain <17852477+sdfordham@users.noreply.github.com> Date: Tue, 30 Apr 2024 20:11:09 +0100 Subject: [PATCH] W_names allows for correct naming of weights series --- pysyncon/base.py | 5 +++-- pysyncon/synth.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pysyncon/base.py b/pysyncon/base.py index a833c10..4786020 100644 --- a/pysyncon/base.py +++ b/pysyncon/base.py @@ -1,5 +1,5 @@ from __future__ import annotations -from typing import Optional, Literal +from typing import Optional, Literal, Sequence from abc import ABCMeta, abstractmethod import numpy as np @@ -16,6 +16,7 @@ class BaseSynth(metaclass=ABCMeta): def __init__(self) -> None: self.dataprep: Optional[Dataprep] = None self.W: Optional[np.ndarray] = None + self.W_names: Optional[Sequence] = None @abstractmethod def fit(*args, **kwargs) -> None: @@ -207,7 +208,7 @@ def weights(self, round: int = 3, threshold: Optional[float] = None) -> pd.Serie if self.W is None: raise ValueError("No weight matrix available; fit data first.") if self.dataprep is None: - weights_ser = pd.Series(self.W, name="weights") + weights_ser = pd.Series(self.W, index=self.W_names, name="weights") else: weights_ser = pd.Series( self.W, index=list(self.dataprep.controls_identifier), name="weights" diff --git a/pysyncon/synth.py b/pysyncon/synth.py index 95e05bd..06ed317 100644 --- a/pysyncon/synth.py +++ b/pysyncon/synth.py @@ -178,6 +178,7 @@ def fun(x): loss_V = self.calc_loss_V(W=W, Z0=Z0_arr, Z1=Z1_arr) self.W, self.loss_W, self.V, self.loss_V = W, loss_W, V_mat.diagonal(), loss_V + self.W_names = Z0.columns @staticmethod def calc_loss_V(W: np.ndarray, Z0: np.ndarray, Z1: np.ndarray) -> float: