Skip to content

Commit

Permalink
TSCV: add plotting method
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Nov 5, 2024
1 parent b4c8608 commit 64c9428
Showing 1 changed file with 151 additions and 24 deletions.
175 changes: 151 additions & 24 deletions src/pytimetk/crossvalidation/time_series_cv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import pandas as pd
import numpy as np

import plotly.graph_objects as go
from plotly.subplots import make_subplots

from timebasedcv import TimeBasedSplit
from timebasedcv.splitstate import SplitState
from timebasedcv.utils._types import ModeType
Expand All @@ -20,7 +23,7 @@

class TimeSeriesCV(TimeBasedSplit):
"""`TimeSeriesCV` is a subclass of `TimeBasedSplit` with default mode set to 'backward'
and an optional `slice_limit` to return the first `n` slices of time series cross-validation sets.
and an optional `split_limit` to return the first `n` slices of time series cross-validation sets.
Parameters
----------
Expand All @@ -41,8 +44,8 @@ class TimeSeriesCV(TimeBasedSplit):
The type of window to use, either "rolling" or "expanding".
mode: ModeType, optional
The mode to use for cross-validation. Default is 'backward'.
slice_limit: int, optional
The maximum number of slices to return. If not provided, all slices are returned.
split_limit: int, optional
The maximum number of splits to return. If not provided, all splits are returned.
Raises:
----------
Expand All @@ -59,7 +62,7 @@ class TimeSeriesCV(TimeBasedSplit):
Examples:
---------
```python
``` {python}
import pandas as pd
import numpy as np
from pytimetk import TimeSeriesCV
Expand Down Expand Up @@ -97,29 +100,35 @@ class TimeSeriesCV(TimeBasedSplit):
forecast_horizon=5,
gap=1,
stride=0,
slice_limit=3 # Limiting to 3 slices
split_limit=3 # Limiting to 3 splits
)
X, y = df.loc[:, ["a", "b"]], df["y"]
# If `time_series` is not provided, it will use the index of `X` or `y` if available
for X_train, X_forecast, y_train, y_forecast in tscv.split(X, y):
# Get the start and end dates for the training and forecast periods
train_start_date = min(X_train.index)
train_end_date = max(X_train.index)
forecast_start_date = min(X_forecast.index)
forecast_end_date = max(X_forecast.index)
print(f"Train: {X_train.shape}, Forecast: {X_forecast.shape}")
print(f"Train Period: {train_start_date} to {train_end_date}")
print(f"Forecast Period: {forecast_start_date} to {forecast_end_date}\n")
# Creates a split generator
splits = tscv.split(X, y)
for X_train, X_forecast, y_train, y_forecast in splits:
print(X_train)
print(X_forecast)
```
``` {python}
# Also, you can use `glimpse()` to print summary information about the splits
tscv.glimpse(y)
```
``` {python}
# You can also plot the splits by calling `plot()` on the `TimeSeriesCV` instance with the `y` Pandas series
tscv.plot(y)
```
"""

def __init__(self, *args, mode: ModeType = "backward", slice_limit: int = None, **kwargs):
def __init__(self, *args, mode: ModeType = "backward", split_limit: int = None, **kwargs):
super().__init__(*args, mode=mode, **kwargs)
self.slice_limit = slice_limit
self.split_limit = split_limit

def split(
self,
Expand All @@ -129,7 +138,7 @@ def split(
end_dt: NullableDatetime = None,
return_splitstate: bool = False,
) -> Generator[Union[Tuple[TL, ...], Tuple[Tuple[TL, ...], SplitState]], None, None]:
"""Returns a generator of split arrays with an optional `slice_limit`.
"""Returns a generator of split arrays with an optional `split_limit`.
Arguments:
*arrays:
Expand All @@ -145,8 +154,8 @@ def split(
Whether to return the `SplitState` instance for each split.
Yields:
A generator of tuples of arrays containing the training and forecast data. If `slice_limit` is set,
yields only up to `slice_limit` splits.
A generator of tuples of arrays containing the training and forecast data. If `split_limit` is set,
yields only up to `split_limit` splits.
"""
# If time_series is not provided, attempt to extract it from the index of the first array
if time_series is None:
Expand All @@ -163,14 +172,132 @@ def split(
*arrays, time_series=time_series, start_dt=start_dt, end_dt=end_dt, return_splitstate=return_splitstate
)

if self.slice_limit is not None:
if self.split_limit is not None:
for i, split in enumerate(split_generator):
if i >= self.slice_limit:
if i >= self.split_limit:
break
yield split
else:
yield from split_generator

def glimpse(self, *arrays: TL, time_series: SeriesLike[DateTimeLike] = None):
"""Prints summary information about the splits, focusing on the first two arrays.
Arguments:
*arrays:
The arrays to split. Only the first two will be used for summary information.
time_series:
The time series used for splitting. If not provided, the index of the first array is used. Default is None.
"""

# Use only the first array for splitting and summary
X = arrays[0]

if time_series is None:
if isinstance(X, (pd.DataFrame, pd.Series)):
time_series = X.index
else:
raise ValueError("time_series must be provided if the first array does not have a time-based index.")

# If the time_series is an index, convert it to a Series for easier handling
if isinstance(time_series, pd.Index):
time_series = pd.Series(time_series, index=time_series)

# Iterate through the splits and print summary information
for split_number, (X_train, X_forecast) in enumerate(self.split(X, time_series=time_series), start=1):
# Get the start and end dates for the training and forecast periods
train_start_date = time_series[X_train.index[0]]
train_end_date = time_series[X_train.index[-1]]
forecast_start_date = time_series[X_forecast.index[0]]
forecast_end_date = time_series[X_forecast.index[-1]]

# Print summary information
print(f"Split Number: {split_number}")
print(f"Train Shape: {X_train.shape}, Forecast Shape: {X_forecast.shape}")
print(f"Train Period: {train_start_date} to {train_end_date}")
print(f"Forecast Period: {forecast_start_date} to {forecast_end_date}\n")


def plot(self, y: pd.Series, time_series: pd.Series = None):
"""Plots the cross-validation sets using Plotly with each fold in a separate subplot.
Arguments:
y: Pandas.Series
The Pandas series of target values to plot.
time_series: Optional[pd.Series]
The time series used for the x-axis. If not provided, the index of `y` will be used.
"""
# Use the index of y if time_series is not provided
if time_series is None:
if isinstance(y, pd.Series):
time_series = y.index
else:
raise ValueError("time_series must be provided if y does not have a time-based index.")

# Ensure time_series is a Pandas Index
if not isinstance(time_series, pd.Index):
raise ValueError("time_series must be a Pandas Index or convertible to one.")

# Determine the number of folds
splits = list(self.split(y, time_series=time_series, return_splitstate=True))
num_folds = len(splits)

# Create subplots
fig = make_subplots(
rows=num_folds, cols=1, # One column, multiple rows
shared_xaxes=True, # Share the x-axis across all subplots
subplot_titles=[f"Fold {i+1}" for i in range(num_folds)]
)

# Enumerate through the splits and add traces to each subplot
for fold, (train_forecast, split_state) in enumerate(splits, start=1):
train, forecast = train_forecast

ts = split_state.train_start
te = split_state.train_end
fs = split_state.forecast_start
fe = split_state.forecast_end

# Add train set trace to the current subplot
fig.add_trace(
go.Scatter(
x=time_series[(time_series >= ts) & (time_series < te)],
y=train + fold,
name=f"Train Fold {fold}",
mode="markers",
marker={"color": "rgb(57, 105, 172)"}
),
row=fold, col=1
)

# Add forecast set trace to the current subplot
fig.add_trace(
go.Scatter(
x=time_series[(time_series >= fs) & (time_series < fe)],
y=forecast + fold,
name=f"Forecast Fold {fold}",
mode="markers",
marker={"color": "indianred"}
),
row=fold, col=1
)

# Update layout
fig.update_layout(
title={
"text": "Time-Based Cross Validation",
"y": 0.95, "x": 0.5,
"xanchor": "center",
"yanchor": "top"
},
showlegend=True,
height=300 * num_folds, # Adjust height based on the number of folds
xaxis_title="Time",
yaxis_title="Fold"
)

return fig



# class TimeSeriesCV:
Expand Down

0 comments on commit 64c9428

Please sign in to comment.