diff --git a/pvnet_summation/models/base_model.py b/pvnet_summation/models/base_model.py index 0d10862..dec091f 100644 --- a/pvnet_summation/models/base_model.py +++ b/pvnet_summation/models/base_model.py @@ -135,7 +135,8 @@ def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times): def training_step(self, batch, batch_idx): """Run training step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -152,7 +153,8 @@ def training_step(self, batch, batch_idx): def validation_step(self, batch: dict, batch_idx): """Run validation step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] times = batch["times"] @@ -201,7 +203,8 @@ def validation_step(self, batch: dict, batch_idx): def test_step(self, batch, batch_idx): """Run test step""" - y_hat = self.forward(batch['pvnet_inputs']) + + y_hat = self.forward(batch) y = batch["national_targets"] losses = self._calculate_common_losses(y, y_hat) diff --git a/pvnet_summation/models/model.py b/pvnet_summation/models/model.py index b978ef4..901621e 100644 --- a/pvnet_summation/models/model.py +++ b/pvnet_summation/models/model.py @@ -66,8 +66,13 @@ def __init__( def forward(self, x): - """Run central model forward""" - pvnet_out = self.predict_pvnet_batch(x) + """Run model forward""" + + if "pvnet_outputs" in x: + pvnet_out = x["pvnet_outputs"] + else: + pvnet_out = self.predict_pvnet_batch(x['pvnet_inputs']) + pvnet_out = torch.flatten(pvnet_out, start_dim=1) out = self.model(pvnet_out)