Skip to content

Commit

Permalink
TimeSeriesCVSplitter: init
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Nov 6, 2024
1 parent 4a01ac3 commit dad84b0
Show file tree
Hide file tree
Showing 5 changed files with 226 additions and 29 deletions.
1 change: 1 addition & 0 deletions docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ quartodoc:
package: pytimetk
contents:
- TimeSeriesCV
- TimeSeriesCVSplitter
- title: 💹 Finance Module (Momentum Indicators)
desc: Momentum indicators for financial time series data.
package: pytimetk
Expand Down
55 changes: 30 additions & 25 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ pathos = "^0.3.1"
adjusttext = "^0.8"
xarray = "^2024.6.0" # "<=2023.10.1" https://github.com/pyjanitor-devs/pandas_flavor/issues/33
timebasedcv = "^0.2.1"
scikit-learn = "^1.5.2"


[tool.pytest.ini_options]
Expand Down
2 changes: 1 addition & 1 deletion src/pytimetk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
augment_fourier
)
from .crossvalidation.time_series_cv import (
TimeSeriesCV,
TimeSeriesCV, TimeSeriesCVSplitter
)
from .core.ts_features import (
ts_features
Expand Down
196 changes: 193 additions & 3 deletions src/pytimetk/crossvalidation/time_series_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import numpy as np

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.model_selection import BaseCrossValidator

from timebasedcv import TimeBasedSplit
from timebasedcv.splitstate import SplitState
Expand Down Expand Up @@ -182,9 +183,9 @@ def split(
time_series: pd.Series
The time series used to create boolean masks for splits. If not provided, the method will try
to use the index of the first array (if it is a DataFrame or Series) as the time series.
start_dt: str
start_dt: pd.Timestamp
The start of the time period. If provided, it is used in place of `time_series.min()`.
end_dt: str
end_dt: pd.Timestamp
The end of the time period. If provided, it is used in place of `time_series.max()`.
return_splitstate: bool
Whether to return the `SplitState` instance for each split.
Expand Down Expand Up @@ -447,6 +448,195 @@ def plot(



class TimeSeriesCVSplitter(BaseCrossValidator):
    """Scikit-learn compatible cross-validator built on top of `TimeSeriesCV`.

    The splitter wraps a `TimeSeriesCV` instance together with the time values
    of the data set, and yields integer train/test positions so that it can be
    passed as the `cv` argument of scikit-learn model-selection utilities.

    Parameters
    ----------
    frequency: str
        The frequency of the time series (e.g., "days", "hours").
    train_size: int
        Minimum number of time units in the training set.
    forecast_horizon: int
        Number of time units to forecast in each split.
    time_series: pd.Series or pd.Index
        The time values of the data. Its length defines the number of samples
        that `X`, `y` and `groups` are expected to have.
    gap: int
        Number of time units to skip between training and testing sets.
    stride: int or None
        Number of time units to move forward after each split. Forwarded
        unchanged to `TimeSeriesCV`.
    window: str
        Type of window, either "rolling" or "expanding".
    mode: str
        Order of split generation, "forward" or "backward".
    start_dt: pd.Timestamp or None
        Start date for the time period. When omitted, `time_series.min()` is
        used to count the splits.
    end_dt: pd.Timestamp or None
        End date for the time period. When omitted, `time_series.max()` is
        used to count the splits.

    Attributes
    ----------
    splitter: TimeSeriesCV
        The underlying time-based splitter.
    n_splits: int
        Number of splits, computed once at construction time.

    See Also
    --------
    TimeSeriesCV

    Examples
    --------
    ``` {python}
    import pandas as pd
    import numpy as np

    from pytimetk import TimeSeriesCVSplitter

    start_dt = pd.Timestamp(2023, 1, 1)
    end_dt = pd.Timestamp(2023, 1, 31)

    time_series = pd.Series(pd.date_range(start_dt, end_dt, freq="D"))
    size = len(time_series)

    df = pd.DataFrame(data=np.random.randn(size, 2), columns=["a", "b"])

    X, y = df[["a", "b"]], df[["a", "b"]].sum(axis=1)

    cv = TimeSeriesCVSplitter(
        time_series=time_series,
        frequency="days",
        train_size=7,
        forecast_horizon=11,
        gap=0,
        stride=1,
        window="rolling",
    )

    cv
    ```

    ``` python
    # Using the TimeSeriesCVSplitter in a scikit-learn CV model

    from sklearn.linear_model import Ridge
    from sklearn.model_selection import RandomizedSearchCV

    # Fit and get best estimator
    param_grid = {
        "alpha": np.linspace(0.1, 2, 10),
        "fit_intercept": [True, False],
        "positive": [True, False],
    }

    random_search_cv = RandomizedSearchCV(
        estimator=Ridge(),
        param_distributions=param_grid,
        cv=cv,
        n_jobs=-1,
    ).fit(X, y)

    random_search_cv.best_estimator_
    ```
    """

    def __init__(
        self,
        *,
        frequency: str,
        train_size: int,
        forecast_horizon: int,
        time_series: Union[pd.Series, pd.Index],
        gap: int = 0,
        stride: Union[int, None] = None,
        window: str = "rolling",
        mode: str = "backward",
        start_dt: Union[pd.Timestamp, None] = None,
        end_dt: Union[pd.Timestamp, None] = None,
    ) -> None:
        # NOTE(review): the constructor arguments are kept on `self.splitter`
        # and under trailing-underscore names rather than under their own
        # names; scikit-learn's default repr looks attributes up by argument
        # name -- confirm the rendered repr is acceptable.
        self.splitter = TimeSeriesCV(
            frequency=frequency,
            train_size=train_size,
            forecast_horizon=forecast_horizon,
            gap=gap,
            stride=stride,
            window=window,
            mode=mode,
        )
        self.time_series_ = time_series
        self.start_dt_ = start_dt
        self.end_dt_ = end_dt
        # Counted eagerly so `get_n_splits` is a cheap attribute read; this
        # relies on the three attributes assigned just above.
        self.n_splits = self._compute_n_splits()
        self.size_ = len(time_series)

    def split(
        self,
        X: Union[np.ndarray, None] = None,
        y: Union[np.ndarray, None] = None,
        groups: Union[np.ndarray, None] = None,
    ) -> Generator[Tuple[np.ndarray, np.ndarray], None, None]:
        """Generate train and test indices for cross-validation.

        Parameters
        ----------
        X:
            Optional input features. Only its length is checked.
        y:
            Optional target variable. Only its length is checked.
        groups:
            Optional group labels. Only its length is checked.

        Yields
        ------
        Tuples `(train_indices, test_indices)` of integer position arrays.

        Raises
        ------
        ValueError:
            If `X`, `y` or `groups` is given and its length differs from the
            length of the time series supplied at construction.
        """
        self._validate_split_args(self.size_, X, y, groups)

        # The positions 0..size-1 are what gets split; the time values only
        # drive where the boundaries fall.
        index_range = np.arange(self.size_)

        for train_mask, test_mask in self.splitter.split(
            index_range,
            time_series=self.time_series_,
            start_dt=self.start_dt_,
            end_dt=self.end_dt_,
            return_splitstate=False,
        ):
            yield index_range[train_mask], index_range[test_mask]

    def get_n_splits(
        self,
        X: Union[np.ndarray, None] = None,
        y: Union[np.ndarray, None] = None,
        groups: Union[np.ndarray, None] = None,
    ) -> int:
        """Return the number of splits computed at construction time.

        `X`, `y` and `groups` are only length-checked against the time series;
        a `ValueError` is raised on mismatch.
        """
        self._validate_split_args(self.size_, X, y, groups)
        return self.n_splits

    def _compute_n_splits(self) -> int:
        """Count the splits generated over the configured time period."""
        # Fall back to the observed range of the time series when an explicit
        # bound was not supplied.
        time_start = self.start_dt_ or self.time_series_.min()
        time_end = self.end_dt_ or self.time_series_.max()
        # Count lazily instead of materialising every split in a list.
        return sum(
            1 for _ in self.splitter._splits_from_period(time_start, time_end)
        )

    @staticmethod
    def _validate_split_args(
        size: int,
        X: Union[np.ndarray, None] = None,
        y: Union[np.ndarray, None] = None,
        groups: Union[np.ndarray, None] = None,
    ) -> None:
        """Check that every provided array has exactly `size` elements.

        Arguments left as `None` are skipped.

        Raises
        ------
        ValueError:
            If a provided array's length differs from `size`.
        """
        if X is not None and len(X) != size:
            raise ValueError(f"Invalid shape: X has {len(X)} elements, expected {size}.")
        if y is not None and len(y) != size:
            raise ValueError(f"Invalid shape: y has {len(y)} elements, expected {size}.")
        if groups is not None and len(groups) != size:
            raise ValueError(f"Invalid shape: groups has {len(groups)} elements, expected {size}.")




Expand Down

0 comments on commit dad84b0

Please sign in to comment.