diff --git a/src/allin1/training/data/datasets/collate.py b/src/allin1/training/data/datasets/collate.py index 64569cb..01e8d83 100644 --- a/src/allin1/training/data/datasets/collate.py +++ b/src/allin1/training/data/datasets/collate.py @@ -21,7 +21,7 @@ def collate_fn(raw_batch): 'true_beat', 'true_downbeat', 'true_section', 'true_function', 'widen_true_beat', 'widen_true_downbeat', 'widen_true_section', ]: - data[key] = value[:max_T] + data[key] = np.pad(value, (0, max_T - value.shape[0]), 'constant') elif key in ['spec']: T = raw_data[key].shape[1] spec = raw_data[key]