Skip to content

Commit

Permalink
Update is_holiday polars backend
Browse files Browse the repository at this point in the history
is_holiday polars backend
  • Loading branch information
JustinKurland authored Oct 23, 2023
1 parent 1877488 commit 191d9ed
Showing 1 changed file with 85 additions and 14 deletions.
99 changes: 85 additions & 14 deletions src/pytimetk/utils/datetime_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,9 +378,10 @@ def _week_of_month_polars(idx: Union[pd.Series, pd.DatetimeIndex]) -> pd.Series:

@pf.register_series_method
def is_holiday(
idx: Union[str, datetime, List[Union[str, datetime]], pd.DatetimeIndex, pd.Series],
idx: Union[str, datetime, List[Union[str, datetime]], pd.Series],
country_name: str = 'UnitedStates',
country: str = None
country: str = None,
engine: str = 'pandas'
) -> pd.Series:
"""
Check if a given list of dates are holidays for a specified country.
Expand All @@ -389,14 +390,23 @@ def is_holiday(
Parameters
----------
idx : Union[str, datetime, List[Union[str, datetime]], pd.DatetimeIndex, pd.Series]
idx : Union[str, datetime, List[Union[str, datetime]], pd.Series]
The dates to check for holiday status.
country_name (str, optional):
The name of the country for which to check the holiday status. Defaults
to 'UnitedStates' if not specified.
country (str, optional):
An alternative parameter to specify the country for holiday checking,
overriding country_name.
engine : str, optional
The `engine` parameter is used to specify the engine to use for
generating the boolean series. It can be either "pandas" or "polars".
- The default value is "pandas".
- When "polars", the function will internally use the `polars` library
to compute the boolean holiday indicator. This can be faster than
using "pandas" for long series.
Returns:
-------
Expand All @@ -411,7 +421,7 @@ def is_holiday(
Examples:
--------
```{python}
import pandas as pd
import polars as pl
import pytimetk as tk
tk.is_holiday('2023-01-01', country_name='UnitedStates')
Expand All @@ -423,19 +433,24 @@ def is_holiday(
```
```{python}
# DatetimeIndex
tk.is_holiday(pd.date_range("2023-01-01", "2023-01-03"), country_name='UnitedStates')
```
```{python}
# Pandas Series Method
(
pd.Series(pd.date_range("2023-01-01", "2023-01-03"))
.is_holiday(country_name='UnitedStates')
)
# Polars Series
tk.is_holiday(pl.Series(['2023-01-01', '2023-01-02', '2023-01-03']), country_name='UnitedStates')
```
"""

if engine == 'pandas':
return _is_holiday_pandas(idx, country_name , country)
elif engine == 'polars':
return _is_holiday_polars(idx, country_name, country)
else:
raise ValueError("Invalid engine. Use 'pandas' or 'polars'.")

def _is_holiday_pandas(
idx: Union[str, datetime, List[Union[str, datetime]], pd.DatetimeIndex, pd.Series],
country_name: str = 'UnitedStates',
country: str = None
) -> pd.Series:

# This function requires the holidays package to be installed
try:
import holidays
Expand Down Expand Up @@ -463,6 +478,62 @@ def is_holiday(

return ret

def _is_holiday_polars(
    idx: Union[str, datetime, List[Union[str, datetime]], pd.Series],
    country_name: str = 'UnitedStates',
    country: str = None
) -> pd.Series:
    """
    Flag which of the given dates are holidays, using polars for the lookup.

    Parameters
    ----------
    idx : str, datetime, list of str/datetime, pd.DatetimeIndex, pd.Series or pl.Series
        The date(s) to check. A single string or datetime is treated as a
        one-element input. Any time-of-day component is discarded.
    country_name : str, optional
        Country identifier passed to ``holidays.country_holidays``.
        Defaults to 'UnitedStates'.
    country : str, optional
        When truthy, overrides ``country_name``.

    Returns
    -------
    pd.Series
        Boolean series named 'is_holiday' with one entry per input date, in
        input order. True where the date is a holiday for the country.

    Raises
    ------
    ImportError
        If the ``holidays`` package is not installed.
    """
    # This function requires the holidays package to be installed
    try:
        import holidays
    except ImportError as err:
        raise ImportError("The 'holidays' package is not installed. Please install it by running 'pip install holidays'.") from err

    if country:
        country_name = country  # Override the default country_name with the provided one

    # Normalise every supported input shape to a flat, list-like collection.
    if isinstance(idx, (str, datetime)):
        idx = [idx]
    elif isinstance(idx, pl.Series):
        idx = idx.to_list()

    # pd.to_datetime returns a Series for Series input and an index otherwise;
    # wrapping in DatetimeIndex gives one type to work with. normalize() drops
    # the time of day so comparison happens on calendar dates only.
    dates = pd.DatetimeIndex(pd.to_datetime(idx)).normalize()

    if len(dates) == 0:
        return pd.Series([], dtype=bool, name='is_holiday')

    # Request every year spanned by the input, not just the first and last,
    # so holidays in intermediate years are not missed.
    first_year = int(dates.year.min())
    last_year = int(dates.year.max())
    holiday_calendar = holidays.country_holidays(
        country_name, years=range(first_year, last_year + 1)
    )

    # Explicit Date dtype keeps both sides comparable even if the holiday
    # list happens to be empty.
    holiday_dates = pl.Series('holiday', list(holiday_calendar), dtype=pl.Date)
    input_dates = pl.Series('is_holiday', dates.date.tolist(), dtype=pl.Date)

    # Element-wise membership test: the result is aligned with the input
    # rather than with a synthetic first-to-last daily range.
    return input_dates.is_in(holiday_dates).to_pandas()

def is_datetime_string(x: Union[str, pd.Series, pd.DatetimeIndex]) -> bool:

if isinstance(x, pd.Series):
Expand Down

0 comments on commit 191d9ed

Please sign in to comment.