Feature add activation to BlockRNN (#2492) (#2504)
* Feature add activation to BlockRNN (#2492)

* Added support for specifying PyTorch activation functions (`ReLU`, `Sigmoid`, `Tanh`, or `None`) in the `BlockRNNModel`.
* Ensured that activation functions are applied between fully connected layers, but not after the final layer.
* Implemented a check to raise an error if an activation function is set but the model only contains one linear layer.
* Updated documentation to reflect the new activation parameter and usage examples (a short usage sketch follows this list).
* Added test cases to verify the correct application of activation functions and to handle edge cases.
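
A minimal usage sketch of the new parameter, not taken from the commit itself; the toy series and training settings are illustrative only:

import numpy as np
import pandas as pd

from darts import TimeSeries
from darts.models import BlockRNNModel

# Toy series used only for illustration.
series = TimeSeries.from_series(pd.Series(np.sin(np.arange(100) / 5.0)))

# `activation` ("ReLU", "Sigmoid", "Tanh", or None) is applied between the
# fully connected layers of the decoder, but not after the final layer.
model = BlockRNNModel(
    input_chunk_length=12,
    output_chunk_length=6,
    model="LSTM",
    hidden_fc_sizes=[64, 32],  # two hidden FC layers -> activation applied twice
    activation="ReLU",
    n_epochs=5,
)
model.fit(series)
forecast = model.predict(n=6)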

* Update darts/models/forecasting/block_rnn_model.py

Co-authored-by: madtoinou <[email protected]>

* Update darts/models/forecasting/block_rnn_model.py

Co-authored-by: madtoinou <[email protected]>

* Feature add activation to BlockRNN (#2492)

* Add a check that raises an error when `activation` is None and `hidden_fc_sizes` is not empty

* Update darts/models/forecasting/block_rnn_model.py

Co-authored-by: Dennis Bader <[email protected]>

* Feature add activation to BlockRNN (#2492)

* _check_ckpt_parameters
* Remove redundant raise_if

* Update darts/models/forecasting/block_rnn_model.py

Co-authored-by: Dennis Bader <[email protected]>

* Feature add activation to BlockRNN (#2492)

* Revert docstring _BlockRNNModule

---------

Co-authored-by: madtoinou <[email protected]>
Co-authored-by: Dennis Bader <[email protected]>
3 people authored Sep 13, 2024
1 parent 26c5f39 commit 38c066b
Showing 2 changed files with 53 additions and 10 deletions.
34 changes: 31 additions & 3 deletions darts/models/forecasting/block_rnn_model.py
@@ -30,6 +30,7 @@ def __init__(
nr_params: int,
num_layers_out_fc: Optional[List] = None,
dropout: float = 0.0,
activation: str = "ReLU",
**kwargs,
):
"""This class allows to create custom block RNN modules that can later be used with Darts'
@@ -63,6 +64,8 @@ def __init__(
This network connects the last hidden layer of the PyTorch RNN module to the output.
dropout
The fraction of neurons that are dropped in all-but-last RNN layers.
activation
The name of the activation function to be applied between the layers of the fully connected network.
**kwargs
all parameters required for :class:`darts.models.forecasting.pl_forecasting_module.PLForecastingModule`
base class.
@@ -77,6 +80,7 @@ def __init__(
self.nr_params = nr_params
self.num_layers_out_fc = [] if num_layers_out_fc is None else num_layers_out_fc
self.dropout = dropout
self.activation = activation
self.out_len = self.output_chunk_length

@io_processor
@@ -105,6 +109,7 @@ class _BlockRNNModule(CustomBlockRNNModule):
def __init__(
self,
name: str,
activation: Optional[str] = None,
**kwargs,
):
"""PyTorch module implementing a block RNN to be used in `BlockRNNModel`.
@@ -116,6 +121,7 @@ def __init__(
This module uses an RNN to encode the input sequence, and subsequently uses a fully connected
network as the decoder which takes as input the last hidden state of the encoder RNN.
Optionally, a non-linear activation function can be applied between the layers of the fully connected network.
The final output of the decoder is a sequence of length `output_chunk_length`. In this sense,
the `_BlockRNNModule` produces 'blocks' of forecasts at a time (which is different
from `_RNNModule` used by the `RNNModel`).
@@ -124,6 +130,9 @@ def __init__(
----------
name
The name of the specific PyTorch RNN module ("RNN", "GRU" or "LSTM").
activation
The name of the activation function to be applied between the layers of the fully connected network.
Options include "ReLU", "Sigmoid", "Tanh", or None for no activation. Default: None.
**kwargs
all parameters required for the :class:`darts.models.forecasting.CustomBlockRNNModule` base class.
@@ -155,10 +164,15 @@ def __init__(
# to the output of desired length
last = self.hidden_dim
feats = []
for feature in self.num_layers_out_fc + [
self.out_len * self.target_size * self.nr_params
]:
for index, feature in enumerate(
self.num_layers_out_fc + [self.out_len * self.target_size * self.nr_params]
):
feats.append(nn.Linear(last, feature))

# Add activation only between layers, but not on the final layer
if activation and index < len(self.num_layers_out_fc):
activation_function = getattr(nn, activation)()
feats.append(activation_function)
last = feature
self.fc = nn.Sequential(*feats)

@@ -195,6 +209,7 @@ def __init__(
n_rnn_layers: int = 1,
hidden_fc_sizes: Optional[List] = None,
dropout: float = 0.0,
activation: str = "ReLU",
**kwargs,
):
"""Block Recurrent Neural Network Model (RNNs).
@@ -243,6 +258,9 @@ def __init__(
Sizes of hidden layers connecting the last hidden layer of the RNN module to the output, if any.
dropout
Fraction of neurons affected by Dropout.
activation
The name of a torch.nn activation function to be applied between the layers of the fully connected network.
Default: "ReLU".
**kwargs
Optional arguments to initialize the pytorch_lightning.Module, pytorch_lightning.Trainer, and
Darts' :class:`TorchForecastingModel`.
@@ -435,6 +453,7 @@ def encode_year(idx):
self.hidden_dim = hidden_dim
self.n_rnn_layers = n_rnn_layers
self.dropout = dropout
self.activation = activation

@property
def supports_multivariate(self) -> bool:
@@ -464,6 +483,15 @@ def _create_model(self, train_sample: Tuple[torch.Tensor]) -> torch.nn.Module:
num_layers=self.n_rnn_layers,
num_layers_out_fc=hidden_fc_sizes,
dropout=self.dropout,
activation=self.activation,
**self.pl_module_params,
**kwargs,
)

def _check_ckpt_parameters(self, tfm_save):
# new parameters were added that will break loading weights
new_params = ["activation"]
for param in new_params:
if param not in tfm_save.model_params:
tfm_save.model_params[param] = "ReLU"
super()._check_ckpt_parameters(tfm_save)
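
For illustration, a small standalone sketch (not part of the diff) that mirrors the decoder construction loop in `_BlockRNNModule` shown above: the chosen activation is inserted after each hidden linear layer, but never after the output layer.

import torch.nn as nn

def build_decoder(hidden_dim, hidden_fc_sizes, out_features, activation="ReLU"):
    # Linear layers with the chosen activation between them,
    # and no activation after the last (output) layer.
    feats = []
    last = hidden_dim
    for index, feature in enumerate(hidden_fc_sizes + [out_features]):
        feats.append(nn.Linear(last, feature))
        if activation and index < len(hidden_fc_sizes):
            feats.append(getattr(nn, activation)())
        last = feature
    return nn.Sequential(*feats)

# Linear(128, 64) -> ReLU -> Linear(64, 32) -> ReLU -> Linear(32, 6)
print(build_decoder(128, [64, 32], 6))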
29 changes: 22 additions & 7 deletions darts/tests/models/forecasting/test_block_RNN.py
@@ -85,31 +85,46 @@ def test_creation(self):
model1.fit(self.series)
preds1 = model1.predict(n=3)

# can create from a custom class itself
# can create from valid module name with ReLU activation
model2 = BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model=ModuleValid1,
model="RNN",
activation="ReLU",
hidden_fc_sizes=[10],
n_epochs=1,
random_state=42,
**tfm_kwargs,
)
model2.fit(self.series)
preds2 = model2.predict(n=3)
np.testing.assert_array_equal(preds1.all_values(), preds2.all_values())
assert preds1.values().shape == preds2.values().shape

# can create from a custom class itself
model3 = BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model=ModuleValid2,
model=ModuleValid1,
n_epochs=1,
random_state=42,
**tfm_kwargs,
)
model3.fit(self.series)
preds3 = model2.predict(n=3)
assert preds3.all_values().shape == preds2.all_values().shape
assert preds3.time_index.equals(preds2.time_index)
preds3 = model3.predict(n=3)
np.testing.assert_array_equal(preds1.all_values(), preds3.all_values())

model4 = BlockRNNModel(
input_chunk_length=1,
output_chunk_length=1,
model=ModuleValid2,
n_epochs=1,
random_state=42,
**tfm_kwargs,
)
model4.fit(self.series)
preds4 = model4.predict(n=3)
assert preds4.all_values().shape == preds3.all_values().shape
assert preds4.time_index.equals(preds3.time_index)

def test_fit(self, tmpdir_module):
# Test basic fit()
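
A hedged sketch of how one might verify the activation placement after fitting; accessing the underlying PyTorch module via `model.model.fc` is an assumption about Darts' internals and is not taken from the diff.

import torch.nn as nn

from darts.models import BlockRNNModel

# Assumes a `series` object like `self.series` in the tests above.
model = BlockRNNModel(
    input_chunk_length=1,
    output_chunk_length=1,
    model="RNN",
    activation="ReLU",
    hidden_fc_sizes=[10],
    n_epochs=1,
    random_state=42,
)
model.fit(series)

# The decoder should contain a ReLU between its two Linear layers,
# and no activation after the final layer.
fc_layers = list(model.model.fc)
assert any(isinstance(layer, nn.ReLU) for layer in fc_layers)
assert isinstance(fc_layers[-1], nn.Linear)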
