From 29b36bb80753e7ac54a82e3acbe44501f21a84f2 Mon Sep 17 00:00:00 2001 From: Grigorii Smirnov-Pinchukov Date: Wed, 21 Feb 2024 10:20:20 +0100 Subject: [PATCH] Do not fail if holidays are missing in workalendar --- .../_holiday_transformer.py | 25 +++++++++++++------ 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/hcrystalball/feature_extraction/_holiday_transformer.py b/src/hcrystalball/feature_extraction/_holiday_transformer.py index 5195ddb..7ee3a61 100644 --- a/src/hcrystalball/feature_extraction/_holiday_transformer.py +++ b/src/hcrystalball/feature_extraction/_holiday_transformer.py @@ -1,3 +1,5 @@ +import logging + import pandas as pd from sklearn.base import BaseEstimator from sklearn.base import TransformerMixin @@ -67,7 +69,7 @@ def get_feature_names(self): @unified_country_code.setter def unified_country_code(self, value): if value is not None and value not in list(registry.region_registry.keys()): - raise ValueError("Unknown `country_code`. For list of valid codes please look at workalendar.") + logging.warning("Unknown `country_code`. For list of valid codes please look at workalendar.") self._unified_country_code = value def fit(self, X, y=None): @@ -141,14 +143,21 @@ def transform(self, X, y=None): self.unified_country_code = X[self.country_code_column].unique()[0] years = X.index.year.unique().tolist() + [max(X.index.year)] - cal = registry.region_registry[self.unified_country_code]() - holidays = ( - pd.concat( - [pd.DataFrame(data=cal.holidays(year), columns=["date", self._col_name]) for year in years] + try: + cal = registry.region_registry[self.unified_country_code]() + holidays = ( + pd.concat( + [ + pd.DataFrame(data=cal.holidays(year), columns=["date", self._col_name]) + for year in years + ] + ) + # one day could have multiple public holidays + .drop_duplicates(subset="date").set_index("date") ) - # one day could have multiple public holidays - .drop_duplicates(subset="date").set_index("date") - ) + except KeyError: + logging.warning("HolidayTransformer: No holidays found for %s", self.unified_country_code) + holidays = pd.DataFrame(columns=["date", self._col_name]).set_index("date") df = ( pd.merge(X, holidays, left_index=True, right_index=True, how="left")