diff --git a/pvnet/data/datamodule.py b/pvnet/data/datamodule.py index 8155b41c..abddbe8f 100644 --- a/pvnet/data/datamodule.py +++ b/pvnet/data/datamodule.py @@ -24,7 +24,7 @@ def __len__(self): def __getitem__(self, idx): return torch.load(self.sample_paths[idx]) - + def collate_fn(samples: list[NumpyBatch]): """Convert a list of NumpyBatch samples to a tensor batch"""