Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Jun 20, 2024
1 parent a0f9e00 commit d748335
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 9 deletions.
9 changes: 4 additions & 5 deletions pvnet_summation/models/flat_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import torch
import torch.nn.functional as F
from pvnet.models.multimodal.linear_networks.basic_blocks import AbstractLinearNetwork
from pvnet.models.multimodal.linear_networks.networks import DefaultFCNet
from pvnet.optimizers import AbstractOptimizer
from torch import nn

Expand All @@ -16,7 +15,7 @@

class FlatModel(BaseModel):
"""Neural network which combines GSP predictions from PVNet naively
This model flattens all the features into a 1D vector before feeding them into the sub network
"""

Expand Down Expand Up @@ -53,12 +52,12 @@ def __init__(

self.relative_scale_pvnet_outputs = relative_scale_pvnet_outputs
self.predict_difference_from_sum = predict_difference_from_sum

self.model = output_network(
in_features=np.prod(self.pvnet_output_shape),
out_features=self.num_output_features,
)

# Add linear layer if predicting difference from sum
# This allows difference to be positive or negative
if predict_difference_from_sum:
Expand Down Expand Up @@ -102,4 +101,4 @@ def forward(self, x):

out = F.leaky_relu(gsp_sum + out)

return out
return out
4 changes: 0 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,13 +135,10 @@ def sample_batch(sample_datamodule):

@pytest.fixture()
def flat_model_kwargs():

kwargs = dict(

# These kwargs define the pvnet model which the summation model uses
model_name="openclimatefix/pvnet_v2",
model_version="4203e12e719efd93da641c43d2e38527648f4915",

# These kwargs define the structure of the summation model
output_network=dict(
_target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
Expand All @@ -151,7 +148,6 @@ def flat_model_kwargs():
res_block_layers=2,
dropout_frac=0.0,
),

)
return hydra.utils.instantiate(kwargs)

Expand Down

0 comments on commit d748335

Please sign in to comment.