Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Jun 18, 2024
1 parent 6cf05a7 commit 061a1ed
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 120 deletions.
4 changes: 2 additions & 2 deletions configs.example/model/default.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
_target_: pvnet_summation.models.model.Model
_target_: pvnet_summation.models.flat_model.FlatModel

output_quantiles: null

Expand All @@ -12,11 +12,11 @@ model_version: "898630f3f8cd4e8506525d813dd61c6d8de86144"
output_network:
_target_: pvnet.models.multimodal.linear_networks.networks.ResFCNet2
_partial_: True
output_network_kwargs:
fc_hidden_features: 128
n_res_blocks: 2
res_block_layers: 2
dropout_frac: 0.0

predict_difference_from_sum: False

# ----------------------------------------------
Expand Down
109 changes: 0 additions & 109 deletions pvnet_summation/models/model.py

This file was deleted.

31 changes: 22 additions & 9 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import math
import glob
import tempfile
from pvnet_summation.models.model import Model

import hydra
from pvnet_summation.models.flat_model import FlatModel

from ocf_datapipes.batch import BatchKey
from datetime import timedelta
Expand Down Expand Up @@ -134,22 +134,35 @@ def sample_batch(sample_datamodule):


@pytest.fixture()
def model_kwargs():
# These kwargs define the pvnet model which the summation model uses
def flat_model_kwargs():

kwargs = dict(

# These kwargs define the pvnet model which the summation model uses
model_name="openclimatefix/pvnet_v2",
model_version="4203e12e719efd93da641c43d2e38527648f4915",

# These kwargs define the structure of the summation model
output_network=dict(
_target_="pvnet.models.multimodal.linear_networks.networks.ResFCNet2",
_partial_=True,
fc_hidden_features=128,
n_res_blocks=2,
res_block_layers=2,
dropout_frac=0.0,
),

)
return kwargs
return hydra.utils.instantiate(kwargs)


@pytest.fixture()
def model(model_kwargs):
model = Model(**model_kwargs)
def model(flat_model_kwargs):
model = FlatModel(**flat_model_kwargs)
return model


@pytest.fixture()
def quantile_model(model_kwargs):
model = Model(output_quantiles=[0.1, 0.5, 0.9], **model_kwargs)
def quantile_model(flat_model_kwargs):
model = FlatModel(output_quantiles=[0.1, 0.5, 0.9], **flat_model_kwargs)
return model
File renamed without changes.

0 comments on commit 061a1ed

Please sign in to comment.