From 60113597d511dfcd0c7c3d78e1f766cacf568e0c Mon Sep 17 00:00:00 2001 From: James Fulton Date: Thu, 20 Jul 2023 16:13:15 +0000 Subject: [PATCH] add support for different inputs - either raw batches of the output of PVNet --- pvnet_summation/models/base_model.py | 9 ++++++--- pvnet_summation/models/model.py | 9 +++++++-- 2 files changed, 13 insertions(+), 5 deletions(-) diff --git a/pvnet_summation/models/base_model.py b/pvnet_summation/models/base_model.py index 0d10862..dec091f 100644 --- a/pvnet_summation/models/base_model.py +++ b/pvnet_summation/models/base_model.py @@ -135,7 +135,8 @@ def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times): def training_step(self, batch, batch_idx): """Run training step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -152,7 +153,8 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch: dict, batch_idx): """Run validation step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -201,7 +203,8 @@ def validation_step(self, batch: dict, batch_idx): def test_step(self, batch, batch_idx): """Run test step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] losses = self._calculate_common_losses(y, y_hat) diff --git a/pvnet_summation/models/model.py b/pvnet_summation/models/model.py index b978ef4..901621e 100644 --- a/pvnet_summation/models/model.py +++ b/pvnet_summation/models/model.py @@ -66,8 +66,13 @@ def __init__( def forward(self, x): - """Run central model forward""" - pvnet_out = self.predict_pvnet_batch(x) + """Run model forward""" + + if "pvnet_outputs" in x: + pvnet_out = x["pvnet_outputs"] + else: + pvnet_out = self.predict_pvnet_batch(x['pvnet_inputs']) + pvnet_out = torch.flatten(pvnet_out, start_dim=1) out = self.model(pvnet_out)