Skip to content

Commit

Permalink
TSCV: plot - finalize
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Nov 6, 2024
1 parent 90bdd99 commit 03de697
Showing 1 changed file with 38 additions and 28 deletions.
66 changes: 38 additions & 28 deletions src/pytimetk/crossvalidation/time_series_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,17 +255,17 @@ def plot(
y: pd.Series,
time_series: pd.Series = None,
color_palette: Optional[Union[dict, list, str]] = None,
bar_height: float = 0.6, # Height of each bar (adjust as needed)
bar_height: float = 0.3,
title: str = "Time Series Cross-Validation Plot",
x_lab: str = "",
y_lab: str = "Fold",
x_axis_date_labels: str = None,
x_axis_date_labels: str = None,
base_size: float = 11,
width: Optional[int] = None,
height: Optional[int] = None,
engine: str = "plotly"
):
"""Plots the cross-validation folds on a single plot with folds on the y-axis and dates on the x-axis using rectangle shapes.
"""Plots the cross-validation folds on a single plot with folds on the y-axis and dates on the x-axis using filled Scatter traces.
Arguments:
y: The Pandas series of target values to plot.
Expand Down Expand Up @@ -307,7 +307,7 @@ def plot(
# Calculate the vertical positions for each fold
fold_positions = list(range(1, num_folds + 1))

# Enumerate through the splits and add rectangle shapes
# Enumerate through the splits and add filled Scatter traces
for fold, (train_forecast, split_state) in enumerate(splits, start=1):
train_indices, forecast_indices = train_forecast

Expand Down Expand Up @@ -338,36 +338,45 @@ def plot(
y0 = fold - bar_height / 2
y1 = fold + bar_height / 2

# Add rectangle for the training period
fig.add_shape(
type="rect",
x0=ts_date,
y0=y0,
x1=te_date,
y1=y1,
line=dict(width=0),
# Create coordinates for the training period rectangle
x_train = [ts_date, te_date, te_date, ts_date, ts_date]
y_train = [y0, y0, y1, y1, y0]

# Add Scatter trace for the training period
fig.add_trace(go.Scatter(
x=x_train,
y=y_train,
mode='lines',
fill='toself',
fillcolor=color_palette[0],
opacity=0.8,
layer='below',
)

# Add rectangle for the forecast period
fig.add_shape(
type="rect",
x0=fs_date,
y0=y0,
x1=fe_date,
y1=y1,
line=dict(width=0),
hoverinfo='text',
hoverlabel=dict(font_size=base_size * 0.8),
text=f"Fold {fold}<br>Train Period<br>{ts_date.date()} to {te_date.date()}",
showlegend=False,
))

# Create coordinates for the forecast period rectangle
x_forecast = [fs_date, fe_date, fe_date, fs_date, fs_date]
y_forecast = [y0, y0, y1, y1, y0]

# Add Scatter trace for the forecast period
fig.add_trace(go.Scatter(
x=x_forecast,
y=y_forecast,
mode='lines',
fill='toself',
fillcolor=color_palette[1],
opacity=0.8,
layer='below',
)
line=dict(width=0),
hoverinfo='text',
hoverlabel=dict(font_size=base_size * 0.8),
text=f"Fold {fold}<br>Forecast Period<br>{fs_date.date()} to {fe_date.date()}",
showlegend=False,
))

# Calculate midpoint for annotation
# Optionally, add text annotations for each fold
train_midpoint = ts_date + (te_date - ts_date) / 2

# Optionally, add text annotations for each fold
fig.add_trace(go.Scatter(
x=[train_midpoint],
y=[fold],
Expand Down Expand Up @@ -419,6 +428,7 @@ def plot(




# class TimeSeriesCV:
# """Generates tuples of train_idx, test_idx pairs
# Assumes the MultiIndex contains levels 'symbol' and 'date'
Expand Down

0 comments on commit 03de697

Please sign in to comment.