From d748335f3b995d3f6fd34a0e575dbfd0a16b66fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 20 Jun 2024 10:47:41 +0000 Subject: [PATCH] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pvnet_summation/models/flat_model.py | 9 ++++----- tests/conftest.py | 4 ---- 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pvnet_summation/models/flat_model.py b/pvnet_summation/models/flat_model.py index d627adb..1996f0b 100644 --- a/pvnet_summation/models/flat_model.py +++ b/pvnet_summation/models/flat_model.py @@ -7,7 +7,6 @@ import torch import torch.nn.functional as F from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork -from pvnet.models.multimodal.linear_networks.networks import DefaultFCNet from pvnet.optimizers import AbstractOptimizer from torch import nn @@ -16,7 +15,7 @@ class FlatModel(BaseModel): """Neural network which combines GSP predictions from PVNet naively - + This model flattens all the features into a 1D vector before feeding them into the sub network """ @@ -53,12 +52,12 @@ def __init__( self.relative_scale_pvnet_outputs = relative_scale_pvnet_outputs self.predict_difference_from_sum = predict_difference_from_sum - + self.model = output_network( in_features=np.prod(self.pvnet_output_shape), out_features=self.num_output_features, ) - + # Add linear layer if predicting difference from sum # This allows difference to be positive or negative if predict_difference_from_sum: @@ -102,4 +101,4 @@ def forward(self, x): out = F.leaky_relu(gsp_sum + out) - return out \ No newline at end of file + return out diff --git a/tests/conftest.py b/tests/conftest.py index 68dfd2b..951fd85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -135,13 +135,10 @@ def sample_batch(sample_datamodule): @pytest.fixture() def flat_model_kwargs(): - kwargs = dict( - # These kwargs define the pvnet model which the summation model uses model_name="openclimatefix/pvnet_v2", model_version="4203e12e719efd93da641c43d2e38527648f4915", - # These kwargs define the structure of the summation model output_network=dict( _target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2", @@ -151,7 +148,6 @@ def flat_model_kwargs(): res_block_layers=2, dropout_frac=0.0, ), - ) return hydra.utils.instantiate(kwargs)