diff --git a/src/pytimetk/crossvalidation/time_series_cv.py b/src/pytimetk/crossvalidation/time_series_cv.py
index d9949222..efe30bce 100644
--- a/src/pytimetk/crossvalidation/time_series_cv.py
+++ b/src/pytimetk/crossvalidation/time_series_cv.py
@@ -255,17 +255,17 @@ def plot(
y: pd.Series,
time_series: pd.Series = None,
color_palette: Optional[Union[dict, list, str]] = None,
- bar_height: float = 0.6, # Height of each bar (adjust as needed)
+ bar_height: float = 0.3,
title: str = "Time Series Cross-Validation Plot",
x_lab: str = "",
y_lab: str = "Fold",
- x_axis_date_labels: str = None,
+ x_axis_date_labels: str = None,
base_size: float = 11,
width: Optional[int] = None,
height: Optional[int] = None,
engine: str = "plotly"
):
- """Plots the cross-validation folds on a single plot with folds on the y-axis and dates on the x-axis using rectangle shapes.
+ """Plots the cross-validation folds on a single plot with folds on the y-axis and dates on the x-axis using filled Scatter traces.
Arguments:
y: The Pandas series of target values to plot.
@@ -307,7 +307,7 @@ def plot(
# Calculate the vertical positions for each fold
fold_positions = list(range(1, num_folds + 1))
- # Enumerate through the splits and add rectangle shapes
+ # Enumerate through the splits and add filled Scatter traces
for fold, (train_forecast, split_state) in enumerate(splits, start=1):
train_indices, forecast_indices = train_forecast
@@ -338,36 +338,45 @@ def plot(
y0 = fold - bar_height / 2
y1 = fold + bar_height / 2
- # Add rectangle for the training period
- fig.add_shape(
- type="rect",
- x0=ts_date,
- y0=y0,
- x1=te_date,
- y1=y1,
- line=dict(width=0),
+ # Create coordinates for the training period rectangle
+ x_train = [ts_date, te_date, te_date, ts_date, ts_date]
+ y_train = [y0, y0, y1, y1, y0]
+
+ # Add Scatter trace for the training period
+ fig.add_trace(go.Scatter(
+ x=x_train,
+ y=y_train,
+ mode='lines',
+ fill='toself',
fillcolor=color_palette[0],
- opacity=0.8,
- layer='below',
- )
-
- # Add rectangle for the forecast period
- fig.add_shape(
- type="rect",
- x0=fs_date,
- y0=y0,
- x1=fe_date,
- y1=y1,
line=dict(width=0),
+ hoverinfo='text',
+ hoverlabel=dict(font_size=base_size * 0.8),
+ text=f"Fold {fold}
Train Period
{ts_date.date()} to {te_date.date()}",
+ showlegend=False,
+ ))
+
+ # Create coordinates for the forecast period rectangle
+ x_forecast = [fs_date, fe_date, fe_date, fs_date, fs_date]
+ y_forecast = [y0, y0, y1, y1, y0]
+
+ # Add Scatter trace for the forecast period
+ fig.add_trace(go.Scatter(
+ x=x_forecast,
+ y=y_forecast,
+ mode='lines',
+ fill='toself',
fillcolor=color_palette[1],
- opacity=0.8,
- layer='below',
- )
+ line=dict(width=0),
+ hoverinfo='text',
+ hoverlabel=dict(font_size=base_size * 0.8),
+ text=f"Fold {fold}
Forecast Period
{fs_date.date()} to {fe_date.date()}",
+ showlegend=False,
+ ))
- # Calculate midpoint for annotation
+ # Optionally, add text annotations for each fold
train_midpoint = ts_date + (te_date - ts_date) / 2
- # Optionally, add text annotations for each fold
fig.add_trace(go.Scatter(
x=[train_midpoint],
y=[fold],
@@ -419,6 +428,7 @@ def plot(
+
# class TimeSeriesCV:
# """Generates tuples of train_idx, test_idx pairs
# Assumes the MultiIndex contains levels 'symbol' and 'date'