diff --git a/scripts/save_concurrent_batches.py b/scripts/save_concurrent_batches.py index ed32dfba..37833b9e 100644 --- a/scripts/save_concurrent_batches.py +++ b/scripts/save_concurrent_batches.py @@ -32,7 +32,7 @@ import hydra import numpy as np import torch -from ocf_datapipes.batch import BatchKey +from ocf_datapipes.batch import BatchKey, batch_to_tensor from ocf_datapipes.training.pvnet_all_gsp import ( construct_sliced_data_pipeline, construct_time_pipeline, @@ -72,7 +72,7 @@ def _get_datapipe(config_path, start_time, end_time, n_batches): datapipe = construct_sliced_data_pipeline( config_path, t0_datapipe, - ) + ).map(batch_to_tensor) return datapipe