From 3edce3a9e5f4b4bfc6d16025eca8f6c1f0bf33c7 Mon Sep 17 00:00:00 2001 From: Greg DeVosNouri Date: Sun, 22 Sep 2024 22:26:16 -0700 Subject: [PATCH] add TODO --- darts/models/forecasting/times_net_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/darts/models/forecasting/times_net_model.py b/darts/models/forecasting/times_net_model.py index 91bba8ebec..c3f4a0809d 100644 --- a/darts/models/forecasting/times_net_model.py +++ b/darts/models/forecasting/times_net_model.py @@ -116,6 +116,8 @@ def __init__( num_layers: int, num_kernels: int, top_k: int, + embed_type:str="fixed", + freq:str="h", **kwargs, ): super().__init__(**kwargs) @@ -124,7 +126,9 @@ def __init__( self.output_dim = output_dim self.nr_params = nr_params - self.embedding = DataEmbedding(input_dim, hidden_size, "fixed", "h", 0.1) + # embed_type and freq are placeholders and are not used until the futures + # covariate in the forward method are figured out + self.embedding = DataEmbedding(input_dim, hidden_size, embed_type=embed_type, freq=freq, dropout=0.1) self.model = nn.ModuleList([ TimesBlock( @@ -148,7 +152,7 @@ def forward(self, x_in: Tuple) -> torch.Tensor: x, _ = x_in # Embedding - x = self.embedding(x, None) + x = self.embedding(x, None) # TODO: future covariate would go here x = self.predict_linear(x.transpose(1, 2)).transpose(1, 2) # TimesNet