diff --git a/pvnet/data/utils.py b/pvnet/data/utils.py index cd77c887..c85ce40c 100644 --- a/pvnet/data/utils.py +++ b/pvnet/data/utils.py @@ -25,7 +25,7 @@ def batch_to_tensor(batch): def split_batches(batch): """Splits a single batch of data.""" - n_samples = batch[BatchKey.sensor].shape[0] + n_samples = batch[BatchKey.gsp].shape[0] keys = list(batch.keys()) examples = [{} for _ in range(n_samples)] for i in range(n_samples):