Skip to content

Commit

Permalink
TSCV: Prep tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
mdancho84 committed Nov 6, 2024
1 parent fe80f10 commit 6c16702
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 13 deletions.
134 changes: 134 additions & 0 deletions docs/guides/07_timeseries_crossvalidation.qmd
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
---
title: "Time Series Cross Validation"
jupyter: python3
toc: true
toc-depth: 3
number-sections: true
number-depth: 2
---

## Time-Based Cross-Validation Using `TimeSeriesCV` and `TimeSeriesCVSplitter`

In this tutorial, you'll learn how to use the `TimeSeriesCV` and `TimeSeriesCVSplitter` classes from `pytimetk` for time series cross-validation, using the walmart_sales_df dataset as an example. We'll start with exploring the data and move on to creating and visualizing time-based cross-validation splits.

### Step 1: Load and Explore the Data

First, let's load the Walmart sales dataset and explore its structure:

```{python}
# libraries
import pytimetk as tk
import pandas as pd
import numpy as np
# Import Data
walmart_sales_df = tk.load_dataset('walmart_sales_weekly')
walmart_sales_df['Date'] = pd.to_datetime(walmart_sales_df['Date'])
walmart_sales_df = walmart_sales_df[['id', 'Date', 'Weekly_Sales']]
walmart_sales_df.glimpse()
```

### Step 2: Visualize the Time Series Data

We can visualize the weekly sales data for different store IDs using the `plot_timeseries` method from `pytimetk`:

```{python}
walmart_sales_df \
.groupby('id') \
.plot_timeseries(
"Date", "Weekly_Sales",
plotly_dropdown = True,
)
```

This will generate an interactive time series plot, allowing you to explore sales data for different stores using a dropdown.

### Step 3: Set Up `TimeSeriesCV` for Cross-Validation

Now, let's set up a time-based cross-validation scheme using `TimeSeriesCV`:

```{python}
from pytimetk.crossvalidation import TimeSeriesCV
# Define parameters for TimeSeriesCV
tscv = TimeSeriesCV(
frequency="weeks",
train_size=52, # Use 52 weeks for training
forecast_horizon=12, # Forecast 12 weeks ahead
gap=0, # No gap between training and forecast sets
stride=4, # Move forward by 4 weeks after each split
window="rolling", # Use a rolling window
mode="backward" # Generate splits from end to start
)
# Glimpse the cross-validation splits
tscv.glimpse(
walmart_sales_df['Weekly_Sales'],
time_series=walmart_sales_df['Date']
)
```

The `glimpse` method provides a summary of each cross-validation fold, including the start and end dates of the training and forecast periods.


### Step 4: Plot the Cross-Validation Splits

You can visualize how the data is split for training and testing:

```{python}
# Plot the cross-validation splits
tscv.plot(
walmart_sales_df['Weekly_Sales'],
time_series=walmart_sales_df['Date']
)
```

This plot will show each fold, illustrating which weeks are used for training and which weeks are used for forecasting.

### Step 5: Using `TimeSeriesCVSplitter` for Model Evaluation

If you want to use the cross-validation scheme with scikit-learn's model evaluation methods, you can use `TimeSeriesCVSplitter`:


```{python}
from pytimetk.crossvalidation import TimeSeriesCVSplitter
from sklearn.linear_model import Ridge
from sklearn.model_selection import cross_val_score
# Set up TimeSeriesCVSplitter
cv_splitter = TimeSeriesCVSplitter(
time_series=walmart_sales_df['Date'],
frequency="weeks",
train_size=52,
forecast_horizon=4,
gap=0,
stride=4,
window="rolling",
mode="backward",
split_limit = 5
)
# Prepare data for modeling
# Extract time series features from the 'Date' column
X_time_features = tk.get_timeseries_signature(walmart_sales_df['Date']).drop('Date', axis=1)
# Dummy encode the 'id' column
X_id_dummies = pd.get_dummies(walmart_sales_df['id'], prefix='store')
# Combine the time series features and dummy encoded features
X = pd.concat([X_time_features, X_id_dummies], axis=1)
# Target
y = walmart_sales_df['Weekly_Sales'].values
# Fit and evaluate a model using cross-validation
model = Ridge()
scores = cross_val_score(model, X, y, cv=cv_splitter)
# Print cross-validation scores
print("Cross-Validation Scores:", scores)
```
13 changes: 0 additions & 13 deletions src/pytimetk/crossvalidation/time_series_cv.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,19 +401,6 @@ def plot(
showlegend=False,
))

# Optionally, add text annotations for each fold
train_midpoint = ts_date + (te_date - ts_date) / 2

fig.add_trace(go.Scatter(
x=[train_midpoint],
y=[fold],
text=[f"Fold {fold}"],
mode="text",
showlegend=False,
textposition="middle center",
hoverinfo='skip',
))

# Update layout
fig.update_layout(
title=title,
Expand Down

0 comments on commit 6c16702

Please sign in to comment.