diff --git a/pvnet/models/multimodal/linear_networks/networks.py b/pvnet/models/multimodal/linear_networks/networks.py index 0009afa0..e29211ae 100644 --- a/pvnet/models/multimodal/linear_networks/networks.py +++ b/pvnet/models/multimodal/linear_networks/networks.py @@ -1,5 +1,5 @@ """Linear networks used for the fusion model""" -from torch import nn +from torch import nn, rand from pvnet.models.multimodal.linear_networks.basic_blocks import ( AbstractLinearNetwork, @@ -316,6 +316,7 @@ def __init__( virtual_batch_size=virtual_batch_size, momentum=momentum, mask_type=mask_type, + group_attention_matrix=rand(4, in_features) ) self.activation = nn.LeakyReLU(negative_slope=0.01) diff --git a/pyproject.toml b/pyproject.toml index c0233394..0a066326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "ipykernel", "h5netcdf", "torch>=2.0.0", - "pytorch_lightning==2.3.0", + "lightning", "torchvision", "pytest", "pytest-cov",