Skip to content

Commit

Permalink
fix batch to tensor and to device for nested NWP
Browse files Browse the repository at this point in the history
  • Loading branch information
dfulu committed Dec 21, 2023
1 parent adeb778 commit 0c135a0
Showing 1 changed file with 15 additions and 9 deletions.
24 changes: 15 additions & 9 deletions pvnet/data/utils.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,33 @@
"""Utils common between Wind and PV datamodules"""
import numpy as np
import torch
from ocf_datapipes.batch import unstack_np_batch_into_examples
from ocf_datapipes.utils.consts import BatchKey
from ocf_datapipes.batch import BatchKey, unstack_np_batch_into_examples
from torch.utils.data import IterDataPipe, functional_datapipe


def copy_batch_to_device(batch, device):
"""Moves a dict-batch of tensors to new device."""
batch_copy = {}
for k in list(batch.keys()):
if isinstance(batch[k], torch.Tensor):
batch_copy[k] = batch[k].to(device)

for k, v in batch.items():
if isinstance(v, dict):
# Recursion to reach the nested NWP
batch_copy[k] = copy_batch_to_device(v)
elif isinstance(v, torch.Tensor):
batch_copy[k] = v.to(device)
else:
batch_copy[k] = batch[k]
batch_copy[k] = v
return batch_copy


def batch_to_tensor(batch):
"""Moves numpy batch to a tensor"""
for k in list(batch.keys()):
if isinstance(batch[k], np.ndarray) and np.issubdtype(batch[k].dtype, np.number):
batch[k] = torch.as_tensor(batch[k])
for k, v in batch.items():
if isinstance(v, dict):
# Recursion to reach the nested NWP
batch[k] = batch_to_tensor(v)
elif isinstance(v, np.ndarray) and np.issubdtype(v.dtype, np.number):
batch[k] = torch.as_tensor(v)
return batch


Expand Down

0 comments on commit 0c135a0

Please sign in to comment.