
TFTExplainer: getting explanations when trained on multiple time series #2489

Open
sikhapentyala opened this issue Aug 1, 2024 · 5 comments
Labels: bug (Something isn't working), good first issue (Good for newcomers)

@sikhapentyala

I have trained a TFT on multiple time series (for example, a retail dataset with 350 time series, each having a target, past covariates, and future covariates). My understanding was that TFTExplainer would give importances (temporal and variable) based on what the model learned from all the time series. To get these, I pass the background_series (and the other covariates) that I used for training to the TFTExplainer, i.e. I pass all 350 series. This gives me the following error:

Traceback (most recent call last):
File "/main.py", line 90, in
results = explainer.explain()
File "/lib/python3.9/site-packages/darts/explainability/tft_explainer.py", line 224, in explain
values=np.take(attention_heads[idx], horizon_idx, axis=0).T,
IndexError: index 30 is out of bounds for axis 0 with size 30

I found that in TFTExplainer.explain(), the size of attention_heads is 30 and not 350.

When I pass only one series as background series, it works (size of attention_heads is 1).

How can I get global explanations for the TFT model when it is trained on multiple time series?

Thank you.

@sikhapentyala changed the title from "TFTExplainer:" to "TFTExplainer: getting explanations when trained on multiple time series" on Aug 1, 2024
@madtoinou (Collaborator)

Hi @sikhapentyala,

I did a bit of investigation and as expected, the size of the attention heads does not depend on the number of series but on the batch size (which is set to 32 by default).

This means (and it's not documented, so it's kind of a bug) that in the current implementation the maximum number of series that can be explained with TFTExplainer is batch_size (this could be improved by iterating over batches). In the meantime, if you want to explain a lot of series, you will need to make separate calls to TFTExplainer and explain():

import numpy as np
import pandas as pd

from darts.models import TFTModel
from darts.explainability.tft_explainer import TFTExplainer
from darts.utils.timeseries_generation import datetime_attribute_timeseries, sine_timeseries

batch_size = 4
model = TFTModel(
    input_chunk_length=10,
    output_chunk_length=2,
    pl_trainer_kwargs={"accelerator": "cpu"},
    n_epochs=1,
    batch_size=batch_size,
)

# generate batch_size + 1 sine series with random start and end dates
possible_starts = [pd.Timestamp(date) for date in ["2000-01-01", "2005-01-01", "2010-01-01"]]
possible_ends = [pd.Timestamp(date) for date in ["2010-01-01", "2015-01-01", "2020-01-01"]]
training_series = [
    sine_timeseries(
        value_frequency=i,
        start=np.random.choice(possible_starts),
        end=np.random.choice(possible_ends),
        freq="M",
    )
    for i in range(batch_size + 1)
]
future_cov = datetime_attribute_timeseries(
    pd.date_range(start=pd.Timestamp("1900-01-01"), end=pd.Timestamp("2025-01-01"), freq="M"),
    "month",
    cyclic=True,
)

# note the parentheses: "[future_cov] * batch_size + 1" would raise a TypeError
model.fit(series=training_series, future_covariates=[future_cov] * (batch_size + 1))

# works: number of background series <= batch_size
explainer = TFTExplainer(
    model,
    background_series=training_series[:batch_size],
    background_future_covariates=[future_cov] * batch_size,
)
explanations = explainer.explain()

# does not work: number of background series > batch_size
explainer = TFTExplainer(
    model,
    background_series=training_series[: batch_size + 1],
    background_future_covariates=[future_cov] * (batch_size + 1),
)
explanations = explainer.explain()

# workaround: explain the series in chunks of at most batch_size
nb_batches = len(training_series) // batch_size
if len(training_series) % batch_size != 0:
    nb_batches += 1
explanations = []
for batch_idx in range(nb_batches):
    bg_series = training_series[batch_size * batch_idx : batch_size * (batch_idx + 1)]
    fut_cov = [future_cov] * len(bg_series)
    explainer = TFTExplainer(
        model,
        background_series=bg_series,
        background_future_covariates=fut_cov,
    )
    explanations.append(explainer.explain())
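
If you then want a single, roughly "global" picture, you could average the variable importances across all the explained series. A minimal sketch, assuming each explainability result exposes get_encoder_importance() and returns one pandas DataFrame per background series (as in the Darts documentation; the exact return type may vary with the Darts version):

import pandas as pd

# aggregate encoder variable importances over all batches to approximate
# a "global" importance table (assumption: one DataFrame per background series)
all_importances = []
for res in explanations:
    enc_imp = res.get_encoder_importance()
    if not isinstance(enc_imp, list):
        enc_imp = [enc_imp]  # a single background series yields one DataFrame
    all_importances.extend(enc_imp)

global_importance = pd.concat(all_importances).mean()
print(global_importance.sort_values(ascending=False))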

Just out of curiosity, what do you expect to learn from applying this to all 350 series in your training set?

@madtoinou madtoinou added the bug Something isn't working label Aug 15, 2024
@sikhapentyala (Author)

Thank you.

My understanding is that the TFT interpretations output "...the general relationships it has learned" (Section 7 in the paper). For example, Table 3 in the paper gives the variable importance not for a single series but across all series (a single series being the time series for a given store_item pair, i.e. one entity). Through the interpretations, I wanted to see what the model has learned overall, i.e. something like global interpretations rather than interpretations of a batch of examples.

@madtoinou (Collaborator)

Nice, thank you for pointing this out.

The interpretability analysis described in Table 3 of the paper is slightly different from what is implemented in TFTExplainer. The authors look at the variable selection weights, which are stored in self.static_covariates_vsn, self.encoder_vsn and self.decoder_vsn (and depend only on the training data, see the source), whereas the Darts module returns the weights of the attention mechanism (which depend on both the training data and the input passed during prediction).

You should be able to obtain a similar table if you access those attributes and analyze them.
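
For instance, a minimal sketch of how to reach them on a fitted model (the attribute names come from the Darts source quoted above; the exact module layout may differ between Darts versions):

# the variable selection networks live on the internal _TFTModule,
# reachable through the fitted model's `model.model` attribute
vsn_modules = {
    "static": getattr(model.model, "static_covariates_vsn", None),  # None if unused
    "encoder": model.model.encoder_vsn,
    "decoder": model.model.decoder_vsn,
}
for name, module in vsn_modules.items():
    print(name, module)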

@madtoinou madtoinou added the good first issue Good for newcomers label Aug 28, 2024
@Tanbuliaobei

Hi @madtoinou
You are correct that to access the internal attributes such as self.static_covariates_vsn, self.encoder_vsn, and self.decoder_vsn, you need to first create an instance of _TFTModule (the internal model used by TFTModel). This is done via the _create_model method in TFTModel.

The parameters for _create_model (i.e., train_sample) are indeed expected to be a tuple of six tensors:

  1. past_target
  2. past_covariates
  3. historic_future_covariates
  4. future_covariates
  5. static_covariates
  6. future_target

Each of these tensors is expected to have the shape (n_timesteps, n_variables), meaning they are 2D tensors (or arrays). However, my goal is to input multiple sequences for global interpretability analysis, which would typically require 3D tensors with an additional batch dimension, i.e., (batch_size, n_timesteps, n_variables). The key challenge is that the _create_model method appears to only accept 2D tensors, which suggests it is designed for single sequences or individual samples. So, how do I solve this problem?

@madtoinou (Collaborator)

madtoinou commented Nov 7, 2024

Hi @Tanbuliaobei,

First of all, the parameters of _create_model are not relevant for explainability; they are just used to create the model with the proper dimensions. During training, the model is fed with batches of samples taken across all the series in the dataset.

As explained in one of my messages above, if you want to use the TFTExplainer and look at the attention heads, you can process the series in batches. The information won't be "global" but specific to each input series.

If you want to look at the variable selection weights (which is probably what you mean by "global interpretability"), you can access the following attributes of the model: self.static_covariates_vsn, self.encoder_vsn and self.decoder_vsn. It does not make sense to inspect them prior to training (i.e., by calling _create_model() by itself):

from darts.datasets import AirPassengersDataset
from darts.models import TFTModel
from darts.utils.timeseries_generation import datetime_attribute_timeseries

series = AirPassengersDataset().load()
ts1, ts2 = series.split_after(0.6)

fut_cov = datetime_attribute_timeseries(
    time_index=series.time_index,
    attribute="weekday",
    cyclic=True,
    add_length=10,
)

# fit the model so that the weights actually have meaning
model = TFTModel(input_chunk_length=3, output_chunk_length=1, n_epochs=5)
model.fit(series=[ts1, ts2], future_covariates=[fut_cov, fut_cov])

# works exactly the same for decoder_vsn
list(model.model.encoder_vsn.prescalers.keys())
>>> ['target_0', 'future_covariate_0', 'future_covariate_1']

list(x.weight.detach().numpy() for x in model.model.encoder_vsn.prescalers.values())
>>> [array([[-0.58796742], ...

Let me know if this solves your problem.
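
If you want something closer to the paper's Table 3 (selection weights aggregated over real inputs rather than raw layer parameters), one option is to capture the VSN output during prediction with a PyTorch forward hook. A rough sketch, assuming (as in the pytorch-forecasting code that the Darts TFT adapts) that the variable selection network's forward returns an (outputs, sparse_weights) tuple; inspect the shapes before aggregating, as they depend on the Darts version:

import torch

captured = []

def grab_weights(module, inputs, output):
    # assumption: forward returns (outputs, sparse_weights), where the
    # sparse weights are the per-variable selection probabilities
    captured.append(output[1].detach())

handle = model.model.encoder_vsn.register_forward_hook(grab_weights)
model.predict(n=1, series=ts1, future_covariates=fut_cov)
handle.remove()

weights = torch.cat(captured)
print(weights.shape)        # inspect before aggregating
print(weights.mean(dim=0))  # average selection weight across samples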
