Skip to content

Commit

Permalink
add support for different inputs - either raw batches of the output o…
Browse files Browse the repository at this point in the history
…f PVNet
  • Loading branch information
dfulu committed Jul 20, 2023
1 parent 3b801ee commit 6011359
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 5 deletions.
9 changes: 6 additions & 3 deletions pvnet_summation/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def _training_accumulate_log(self, batch_idx, losses, y_hat, y, times):

def training_step(self, batch, batch_idx):
"""Run training step"""
y_hat = self.forward(batch['pvnet_inputs'])

y_hat = self.forward(batch)
y = batch["national_targets"]
times = batch["times"]

Expand All @@ -152,7 +153,8 @@ def training_step(self, batch, batch_idx):

def validation_step(self, batch: dict, batch_idx):
"""Run validation step"""
y_hat = self.forward(batch['pvnet_inputs'])

y_hat = self.forward(batch)
y = batch["national_targets"]
times = batch["times"]

Expand Down Expand Up @@ -201,7 +203,8 @@ def validation_step(self, batch: dict, batch_idx):

def test_step(self, batch, batch_idx):
"""Run test step"""
y_hat = self.forward(batch['pvnet_inputs'])

y_hat = self.forward(batch)
y = batch["national_targets"]

losses = self._calculate_common_losses(y, y_hat)
Expand Down
9 changes: 7 additions & 2 deletions pvnet_summation/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,13 @@ def __init__(


def forward(self, x):
"""Run central model forward"""
pvnet_out = self.predict_pvnet_batch(x)
"""Run model forward"""

if "pvnet_outputs" in x:
pvnet_out = x["pvnet_outputs"]
else:
pvnet_out = self.predict_pvnet_batch(x['pvnet_inputs'])

pvnet_out = torch.flatten(pvnet_out, start_dim=1)
out = self.model(pvnet_out)

Expand Down

0 comments on commit 6011359

Please sign in to comment.