diff --git a/src/pytimetk/crossvalidation/time_series_cv.py b/src/pytimetk/crossvalidation/time_series_cv.py index d9949222..efe30bce 100644 --- a/src/pytimetk/crossvalidation/time_series_cv.py +++ b/src/pytimetk/crossvalidation/time_series_cv.py @@ -255,17 +255,17 @@ def plot( y: pd.Series, time_series: pd.Series = None, color_palette: Optional[Union[dict, list, str]] = None, - bar_height: float = 0.6, # Height of each bar (adjust as needed) + bar_height: float = 0.3, title: str = "Time Series Cross-Validation Plot", x_lab: str = "", y_lab: str = "Fold", - x_axis_date_labels: str = None, + x_axis_date_labels: str = None, base_size: float = 11, width: Optional[int] = None, height: Optional[int] = None, engine: str = "plotly" ): - """Plots the cross-validation folds on a single plot with folds on the y-axis and dates on the x-axis using rectangle shapes. + """Plots the cross-validation folds on a single plot with folds on the y-axis and dates on the x-axis using filled Scatter traces. Arguments: y: The Pandas series of target values to plot. @@ -307,7 +307,7 @@ def plot( # Calculate the vertical positions for each fold fold_positions = list(range(1, num_folds + 1)) - # Enumerate through the splits and add rectangle shapes + # Enumerate through the splits and add filled Scatter traces for fold, (train_forecast, split_state) in enumerate(splits, start=1): train_indices, forecast_indices = train_forecast @@ -338,36 +338,45 @@ def plot( y0 = fold - bar_height / 2 y1 = fold + bar_height / 2 - # Add rectangle for the training period - fig.add_shape( - type="rect", - x0=ts_date, - y0=y0, - x1=te_date, - y1=y1, - line=dict(width=0), + # Create coordinates for the training period rectangle + x_train = [ts_date, te_date, te_date, ts_date, ts_date] + y_train = [y0, y0, y1, y1, y0] + + # Add Scatter trace for the training period + fig.add_trace(go.Scatter( + x=x_train, + y=y_train, + mode='lines', + fill='toself', fillcolor=color_palette[0], - opacity=0.8, - layer='below', - ) - - # Add rectangle for the forecast period - fig.add_shape( - type="rect", - x0=fs_date, - y0=y0, - x1=fe_date, - y1=y1, line=dict(width=0), + hoverinfo='text', + hoverlabel=dict(font_size=base_size * 0.8), + text=f"Fold {fold}
Train Period
{ts_date.date()} to {te_date.date()}", + showlegend=False, + )) + + # Create coordinates for the forecast period rectangle + x_forecast = [fs_date, fe_date, fe_date, fs_date, fs_date] + y_forecast = [y0, y0, y1, y1, y0] + + # Add Scatter trace for the forecast period + fig.add_trace(go.Scatter( + x=x_forecast, + y=y_forecast, + mode='lines', + fill='toself', fillcolor=color_palette[1], - opacity=0.8, - layer='below', - ) + line=dict(width=0), + hoverinfo='text', + hoverlabel=dict(font_size=base_size * 0.8), + text=f"Fold {fold}
Forecast Period
{fs_date.date()} to {fe_date.date()}", + showlegend=False, + )) - # Calculate midpoint for annotation + # Optionally, add text annotations for each fold train_midpoint = ts_date + (te_date - ts_date) / 2 - # Optionally, add text annotations for each fold fig.add_trace(go.Scatter( x=[train_midpoint], y=[fold], @@ -419,6 +428,7 @@ def plot( + # class TimeSeriesCV: # """Generates tuples of train_idx, test_idx pairs # Assumes the MultiIndex contains levels 'symbol' and 'date'