Commit
Showing 2 changed files with 134 additions and 13 deletions.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
--- | ||
title: "Time Series Cross Validation" | ||
jupyter: python3 | ||
toc: true | ||
toc-depth: 3 | ||
number-sections: true | ||
number-depth: 2 | ||
--- | ||
|
||
## Time-Based Cross-Validation Using `TimeSeriesCV` and `TimeSeriesCVSplitter` | ||
|
||
In this tutorial, you'll learn how to use the `TimeSeriesCV` and `TimeSeriesCVSplitter` classes from `pytimetk` for time series cross-validation, using the `walmart_sales_weekly` dataset as an example. We'll start by exploring the data, then create and visualize time-based cross-validation splits, and finish by evaluating a scikit-learn model on those splits. | ||
|
||
### Step 1: Load and Explore the Data | ||
|
||
First, let's load the Walmart sales dataset and explore its structure: | ||
|
||
```{python} | ||
# libraries | ||
import pytimetk as tk | ||
import pandas as pd | ||
import numpy as np | ||
# Import Data | ||
walmart_sales_df = tk.load_dataset('walmart_sales_weekly') | ||
walmart_sales_df['Date'] = pd.to_datetime(walmart_sales_df['Date']) | ||
walmart_sales_df = walmart_sales_df[['id', 'Date', 'Weekly_Sales']] | ||
walmart_sales_df.glimpse() | ||
``` | ||
|
||
### Step 2: Visualize the Time Series Data | ||
|
||
We can visualize the weekly sales for each series `id` using the `plot_timeseries` method from `pytimetk`: | ||
|
||
```{python} | ||
walmart_sales_df \ | ||
    .groupby('id') \ | ||
    .plot_timeseries( | ||
        "Date", "Weekly_Sales", | ||
        plotly_dropdown=True, | ||
    ) | ||
``` | ||
|
||
This generates an interactive time series plot with a dropdown for switching between the series identified by `id`. | ||
|
||
### Step 3: Set Up `TimeSeriesCV` for Cross-Validation | ||
|
||
Now, let's set up a time-based cross-validation scheme using `TimeSeriesCV`: | ||
|
||
```{python} | ||
from pytimetk.crossvalidation import TimeSeriesCV | ||
# Define parameters for TimeSeriesCV | ||
tscv = TimeSeriesCV( | ||
    frequency="weeks", | ||
    train_size=52,          # Use 52 weeks for training | ||
    forecast_horizon=12,    # Forecast 12 weeks ahead | ||
    gap=0,                  # No gap between training and forecast sets | ||
    stride=4,               # Move by 4 weeks between consecutive splits | ||
    window="rolling",       # Use a rolling window | ||
    mode="backward"         # Generate splits from end to start | ||
) | ||
# Glimpse the cross-validation splits | ||
tscv.glimpse( | ||
    walmart_sales_df['Weekly_Sales'], | ||
    time_series=walmart_sales_df['Date'] | ||
) | ||
``` | ||
|
||
The `glimpse` method provides a summary of each cross-validation fold, including the start and end dates of the training and forecast periods. | ||
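
To make the parameters above concrete, the following sketch enumerates fold boundaries for a backward, rolling-window scheme using only the Python standard library. It is an illustration, not `pytimetk`'s implementation: the function name `backward_rolling_folds`, the half-open interval convention, and the two-year date range are all assumptions made for this example.

```python
from datetime import date, timedelta

WEEK = timedelta(weeks=1)


def backward_rolling_folds(start, end, train_weeks, horizon_weeks,
                           gap_weeks=0, stride_weeks=1):
    """Return folds as (train_start, train_end, forecast_start, forecast_end).

    Intervals are half-open, [start, end), and folds are listed newest first.
    This mirrors the idea of mode="backward" with window="rolling"; the exact
    boundary conventions of any real library may differ.
    """
    folds = []
    forecast_end = end
    while True:
        forecast_start = forecast_end - horizon_weeks * WEEK
        train_end = forecast_start - gap_weeks * WEEK
        train_start = train_end - train_weeks * WEEK
        if train_start < start:
            break  # not enough history left for a full training window
        folds.append((train_start, train_end, forecast_start, forecast_end))
        forecast_end -= stride_weeks * WEEK  # step back in time by the stride
    return folds


# Hypothetical two-year span of weekly data (not the Walmart dates)
span_start = date(2020, 1, 6)
span_end = span_start + 104 * WEEK

folds = backward_rolling_folds(
    span_start, span_end,
    train_weeks=52, horizon_weeks=12, gap_weeks=0, stride_weeks=4,
)

for i, (tr_s, tr_e, fc_s, fc_e) in enumerate(folds[:3]):
    print(f"fold {i}: train [{tr_s}, {tr_e})  forecast [{fc_s}, {fc_e})")
print(len(folds), "folds in total")
```

Each fold needs 52 + 12 = 64 weeks of data, so a larger `stride` or a shorter series yields fewer folds.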
|
||
|
||
### Step 4: Plot the Cross-Validation Splits | ||
|
||
You can visualize how the data is split into training and forecast periods: | ||
|
||
```{python} | ||
# Plot the cross-validation splits | ||
tscv.plot( | ||
    walmart_sales_df['Weekly_Sales'], | ||
    time_series=walmart_sales_df['Date'] | ||
) | ||
``` | ||
|
||
The plot shows each fold and marks which weeks are used for training and which for forecasting. | ||
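
For a rough idea of what such a plot conveys without rendering it, here is a small text-only sketch. It works on integer positions rather than dates, and the helper `render_folds` and its parameters are hypothetical names chosen for this illustration, not part of `pytimetk`.

```python
def render_folds(n_points, train_size, horizon, stride):
    """Return one string per fold: 'T' = training, 'F' = forecast, '.' = unused.

    Folds are generated from the end of the series backward, with a
    fixed-length (rolling) training window.
    """
    rows = []
    forecast_end = n_points
    while forecast_end - horizon - train_size >= 0:
        forecast_start = forecast_end - horizon
        train_start = forecast_start - train_size
        row = ["."] * n_points
        row[train_start:forecast_start] = ["T"] * train_size
        row[forecast_start:forecast_end] = ["F"] * horizon
        rows.append("".join(row))
        forecast_end -= stride
    return rows


rows = render_folds(n_points=24, train_size=8, horizon=4, stride=4)
for i, row in enumerate(rows):
    print(f"fold {i}: {row}")
```

Reading the rows top to bottom, the training and forecast blocks slide toward the start of the series by one stride per fold, which is the pattern to look for in the interactive plot.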
|
||
### Step 5: Using `TimeSeriesCVSplitter` for Model Evaluation | ||
|
||
If you want to use the cross-validation scheme with scikit-learn's model evaluation methods, you can use `TimeSeriesCVSplitter`: | ||
|
||
|
||
```{python} | ||
from pytimetk.crossvalidation import TimeSeriesCVSplitter | ||
from sklearn.linear_model import Ridge | ||
from sklearn.model_selection import cross_val_score | ||
# Set up TimeSeriesCVSplitter | ||
cv_splitter = TimeSeriesCVSplitter( | ||
    time_series=walmart_sales_df['Date'], | ||
    frequency="weeks", | ||
    train_size=52, | ||
    forecast_horizon=4, | ||
    gap=0, | ||
    stride=4, | ||
    window="rolling", | ||
    mode="backward", | ||
    split_limit=5,  # Cap the number of splits at 5 | ||
) | ||
# Prepare data for modeling | ||
# Extract time series features from the 'Date' column, keeping only | ||
# numeric/boolean columns (text label columns cannot be passed to Ridge) | ||
X_time_features = ( | ||
    tk.get_timeseries_signature(walmart_sales_df['Date']) | ||
    .drop('Date', axis=1) | ||
    .select_dtypes(include=['number', 'bool']) | ||
) | ||
# Dummy encode the 'id' column | ||
X_id_dummies = pd.get_dummies(walmart_sales_df['id'], prefix='store') | ||
# Combine the time series features and dummy encoded features | ||
X = pd.concat([X_time_features, X_id_dummies], axis=1) | ||
# Target | ||
y = walmart_sales_df['Weekly_Sales'].values | ||
# Fit and evaluate a model using cross-validation | ||
model = Ridge() | ||
scores = cross_val_score(model, X, y, cv=cv_splitter) | ||
# Print cross-validation scores (R^2 per fold, the default score for a regressor) | ||
print("Cross-Validation Scores:", scores) | ||
``` |
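
To clarify what `cross_val_score` does with a splitter, the sketch below reproduces the loop by hand in plain Python: for each `(train_indices, test_indices)` pair, fit on the training rows and score on the forecast rows. Everything here is a stand-in chosen for illustration: `ToyBackwardSplitter` is a hypothetical index-based splitter (not `TimeSeriesCVSplitter`), the model is a simple least-squares line instead of `Ridge`, and the data is synthetic.

```python
from statistics import mean


class ToyBackwardSplitter:
    """Hypothetical index-based splitter: rolling window, newest fold first."""

    def __init__(self, train_size, forecast_horizon, stride, split_limit=None):
        self.train_size = train_size
        self.forecast_horizon = forecast_horizon
        self.stride = stride
        self.split_limit = split_limit

    def split(self, X, y=None, groups=None):
        """Yield (train_indices, test_indices), as scikit-learn splitters do."""
        end = len(X)
        n_yielded = 0
        while end - self.forecast_horizon - self.train_size >= 0:
            if self.split_limit is not None and n_yielded >= self.split_limit:
                break
            test_start = end - self.forecast_horizon
            train_idx = list(range(test_start - self.train_size, test_start))
            test_idx = list(range(test_start, end))
            yield train_idx, test_idx
            n_yielded += 1
            end -= self.stride


def fit_line(x, y):
    """Ordinary least squares for y = intercept + slope * x."""
    x_bar, y_bar = mean(x), mean(y)
    slope = (sum((xi - x_bar) * (yi - y_bar) for xi, yi in zip(x, y))
             / sum((xi - x_bar) ** 2 for xi in x))
    return slope, y_bar - slope * x_bar


def r2(y_true, y_pred):
    """Coefficient of determination; can be negative on held-out data."""
    y_bar = mean(y_true)
    ss_res = sum((t - p) ** 2 for t, p in zip(y_true, y_pred))
    ss_tot = sum((t - y_bar) ** 2 for t in y_true)
    return 1.0 - ss_res / ss_tot


def manual_cross_val_score(x, y, cv):
    """Fit on each training fold and score on the matching forecast fold."""
    scores = []
    for train_idx, test_idx in cv.split(x, y):
        slope, intercept = fit_line([x[i] for i in train_idx],
                                    [y[i] for i in train_idx])
        y_pred = [intercept + slope * x[i] for i in test_idx]
        scores.append(r2([y[i] for i in test_idx], y_pred))
    return scores


# Synthetic, perfectly linear series so every fold should score close to 1
x = list(range(40))
y = [2.0 * xi + 1.0 for xi in x]

cv = ToyBackwardSplitter(train_size=12, forecast_horizon=4, stride=4, split_limit=5)
scores = manual_cross_val_score(x, y, cv)
print(scores)
```

On real data the per-fold scores will vary, and a negative R² on a fold means the model forecast worse than predicting that fold's mean. If an error metric is easier to interpret, `cross_val_score` also accepts a `scoring` argument such as `"neg_root_mean_squared_error"`.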