From af5ac422d04bb7e9a92fce42d282713812329020 Mon Sep 17 00:00:00 2001 From: Sebastian Hoffmann Date: Tue, 10 Oct 2023 19:20:13 +0200 Subject: [PATCH] feat: more comprehensive dataset logs --- dmlcloud/training/trainer.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/dmlcloud/training/trainer.py b/dmlcloud/training/trainer.py index c6cba11..03c2070 100644 --- a/dmlcloud/training/trainer.py +++ b/dmlcloud/training/trainer.py @@ -233,20 +233,36 @@ def setup_dataset(self): logging.info(f'Dataset creation took {(datetime.now() - ts).total_seconds():.1f}s') if hasattr(self.train_dl, 'dataset') and hasattr(self.train_dl.dataset, '__len__'): - logging.info(f'Train dataset size: {len(self.train_dl.dataset)}') - if hasattr(self.val_dl, 'dataset') and hasattr(self.val_dl.dataset, '__len__'): - logging.info(f' Val dataset size: {len(self.val_dl.dataset)}') - + train_samples = f'{len(self.train_dl.dataset)}' + else: + train_samples = 'N/A' train_sizes = hvd.allgather(torch.tensor([len(self.train_dl)]), name='train_dataset_size') train_sizes = [t.item() for t in train_sizes] + msg = 'Train dataset:' + msg += f'\n\t* Batches: {train_sizes[0]}' + msg += f'\n\t* Batches (total): {sum(train_sizes)}' + msg += f'\n\t* Samples (calculated): {sum(train_sizes) * self.cfg.batch_size}' + msg += f'\n\t* Samples (raw): {train_samples}' + logging.info(msg) if len(set(train_sizes)) > 1 and self.is_root: - logging.warning(f'Uneven train dataset batches: {train_sizes}') + logging.warning(f'!!! Uneven train dataset batches: {train_sizes}') if self.val_dl is not None: + if hasattr(self.val_dl, 'dataset') and hasattr(self.val_dl.dataset, '__len__'): + val_samples = f'{len(self.val_dl.dataset)}' + else: + val_samples = 'N/A' + val_sizes = hvd.allgather(torch.tensor([len(self.val_dl)]), name='val_dataset_size') val_sizes = [t.item() for t in val_sizes] + msg = 'Train dataset:' + msg += f'\n\t* Batches: {val_sizes[0]}' + msg += f'\n\t* Batches (total): {sum(val_sizes)}' + msg += f'\n\t* Samples (calculated): {sum(val_sizes) * self.cfg.batch_size}' + msg += f'\n\t* Samples (raw): {val_samples}' + logging.info(msg) if len(set(val_sizes)) > 1 and self.is_root: - logging.warning(f'Uneven val dataset batches: {val_sizes}') + logging.warning(f'!!! Uneven val dataset batches: {val_sizes}') log_delimiter()