From 814e21a2ab87bc0e319c4ea85775a0f6b3e4a6a2 Mon Sep 17 00:00:00 2001 From: Manuel Burger Date: Wed, 9 Oct 2024 23:05:48 +0200 Subject: [PATCH] Always store seq length stats when logging --- src/nanotron/data/petagraph_dataset.py | 9 ++ src/nanotron/trainer.py | 148 ++++++++++++------------- 2 files changed, 79 insertions(+), 78 deletions(-) diff --git a/src/nanotron/data/petagraph_dataset.py b/src/nanotron/data/petagraph_dataset.py index 5b8e23a0..4af72fe9 100644 --- a/src/nanotron/data/petagraph_dataset.py +++ b/src/nanotron/data/petagraph_dataset.py @@ -15,6 +15,7 @@ from tqdm import tqdm import numpy as np from typing import Dict, Optional, Tuple +import json # import zstd import zstandard @@ -77,6 +78,7 @@ def __init__(self, self._bos_token_id = self.VOCAB["BOS"] self._unk_token_id = self.VOCAB["UNK"] + self.num_files = len(url_list) self.current_epoch = 0 @@ -85,6 +87,13 @@ def __init__(self, self.num_consumed_sequences = 0 self.consumed_files_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt" + + # Save the vocabulary as json on head node + if self.rank == 0: + self.vocab_path = log_directory / "vocabulary.json" + with open(self.vocab_path, "w") as f: + json.dump(self.VOCAB, f) + # Take list of already consumed lists and remove them from the # url list, to continue training from the last checkpoint properly # - Check if the consumed_files exist diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 58698358..22951197 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -571,91 +571,83 @@ def train_step_logs( # Gather information from the dataloaders and log statistics # Don't do this too often as it contains an allgather - if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0: - if dataloaders is not None: - if isinstance(dataloaders, tuple): - current_dataset = dataloaders[0].dataset - elif isinstance(dataloaders, dict): - current_dataset = dataloaders[list(dataloaders.keys())[0]].dataset - else: - current_dataset = dataloaders.dataset - else: - current_dataset = None - - if current_dataset is not None and hasattr(current_dataset, "consumed_seq_len_queue"): - consumed_seq_lens = np.array(list(current_dataset.consumed_seq_len_queue), dtype=np.int64) - mean_seq_len = np.mean(consumed_seq_lens) - else: - mean_seq_len = 0.0 - - if current_dataset is not None and hasattr(current_dataset, "consumed_files"): - num_consumed_files = len(current_dataset.consumed_files) + if dataloaders is not None: + if isinstance(dataloaders, tuple): + current_dataset = dataloaders[0].dataset + elif isinstance(dataloaders, dict): + current_dataset = dataloaders[list(dataloaders.keys())[0]].dataset else: - num_consumed_files = -1 + current_dataset = dataloaders.dataset + else: + current_dataset = None - if current_dataset is not None and hasattr(current_dataset, "current_epoch"): - current_epoch = current_dataset.current_epoch - else: - current_epoch = -1 + if current_dataset is not None and hasattr(current_dataset, "consumed_seq_len_queue"): + consumed_seq_lens = np.array(list(current_dataset.consumed_seq_len_queue), dtype=np.int64) + mean_seq_len = np.mean(consumed_seq_lens) + else: + mean_seq_len = 0.0 - if current_dataset is not None and hasattr(current_dataset, "num_consumed_sequences"): - num_consumed_sequences = current_dataset.num_consumed_sequences - current_dataset.num_consumed_sequences = 0 - else: - num_consumed_sequences = 0 + if current_dataset is not None and hasattr(current_dataset, "consumed_files"): + num_consumed_files = len(current_dataset.consumed_files) + else: + num_consumed_files = -1 - # Gather the values across all ranks - world_size_dp_pg = self.parallel_context.dp_pg.size() + if current_dataset is not None and hasattr(current_dataset, "current_epoch"): + current_epoch = current_dataset.current_epoch + else: + current_epoch = -1 - num_consumed_files_t = torch.tensor(num_consumed_files, device="cuda", dtype=torch.int64) - num_consumed_files_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) - dist.all_gather_into_tensor( - output_tensor=num_consumed_files_t_all, - input_tensor=num_consumed_files_t, - group=self.parallel_context.dp_pg - ) - num_consumed_files_ranks = num_consumed_files_t_all.cpu().numpy() - num_consumed_files_all = num_consumed_files_ranks.sum() - self.metadata.consumed_num_logan_files = int(num_consumed_files_all) - - current_epoch_t = torch.tensor(current_epoch, device="cuda", dtype=torch.int64) - current_epoch_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) - dist.all_gather_into_tensor( - output_tensor=current_epoch_t_all, - input_tensor=current_epoch_t, - group=self.parallel_context.dp_pg - ) - current_epoch_ranks = current_epoch_t_all.cpu().numpy() - current_epoch_all = current_epoch_ranks.mean() - - num_consumed_seq_t = torch.tensor(num_consumed_sequences, device="cuda", dtype=torch.int64) - num_consumed_seq_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) - dist.all_gather_into_tensor( - output_tensor=num_consumed_seq_t_all, - input_tensor=num_consumed_seq_t, - group=self.parallel_context.dp_pg - ) - num_consumed_seq_ranks = num_consumed_seq_t_all.cpu().numpy() - num_consumed_seq_all = num_consumed_seq_ranks.sum() - self.metadata.consumed_num_sequences += int(num_consumed_seq_all) - num_consumed_seq_log = self.metadata.consumed_num_sequences - - mean_consumed_seq_len_t = torch.tensor(mean_seq_len, device="cuda", dtype=torch.float32) - mean_consumed_seq_len_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.float32) - dist.all_gather_into_tensor( - output_tensor=mean_consumed_seq_len_t_all, - input_tensor=mean_consumed_seq_len_t, - group=self.parallel_context.dp_pg - ) - mean_consumed_seq_len_ranks = mean_consumed_seq_len_t_all.cpu().numpy() - mean_consumed_seq_len_all = mean_consumed_seq_len_ranks.mean() + if current_dataset is not None and hasattr(current_dataset, "num_consumed_sequences"): + num_consumed_sequences = current_dataset.num_consumed_sequences + current_dataset.num_consumed_sequences = 0 + else: + num_consumed_sequences = 0 + # Gather the values across all ranks + world_size_dp_pg = self.parallel_context.dp_pg.size() - else: - num_consumed_files_all = None - current_epoch_all = None - num_consumed_seq_log = None - mean_consumed_seq_len_all = None + num_consumed_files_t = torch.tensor(num_consumed_files, device="cuda", dtype=torch.int64) + num_consumed_files_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) + dist.all_gather_into_tensor( + output_tensor=num_consumed_files_t_all, + input_tensor=num_consumed_files_t, + group=self.parallel_context.dp_pg + ) + num_consumed_files_ranks = num_consumed_files_t_all.cpu().numpy() + num_consumed_files_all = num_consumed_files_ranks.sum() + self.metadata.consumed_num_logan_files = int(num_consumed_files_all) + + current_epoch_t = torch.tensor(current_epoch, device="cuda", dtype=torch.int64) + current_epoch_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) + dist.all_gather_into_tensor( + output_tensor=current_epoch_t_all, + input_tensor=current_epoch_t, + group=self.parallel_context.dp_pg + ) + current_epoch_ranks = current_epoch_t_all.cpu().numpy() + current_epoch_all = current_epoch_ranks.mean() + + num_consumed_seq_t = torch.tensor(num_consumed_sequences, device="cuda", dtype=torch.int64) + num_consumed_seq_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64) + dist.all_gather_into_tensor( + output_tensor=num_consumed_seq_t_all, + input_tensor=num_consumed_seq_t, + group=self.parallel_context.dp_pg + ) + num_consumed_seq_ranks = num_consumed_seq_t_all.cpu().numpy() + num_consumed_seq_all = num_consumed_seq_ranks.sum() + self.metadata.consumed_num_sequences += int(num_consumed_seq_all) + num_consumed_seq_log = self.metadata.consumed_num_sequences + + mean_consumed_seq_len_t = torch.tensor(mean_seq_len, device="cuda", dtype=torch.float32) + mean_consumed_seq_len_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.float32) + dist.all_gather_into_tensor( + output_tensor=mean_consumed_seq_len_t_all, + input_tensor=mean_consumed_seq_len_t, + group=self.parallel_context.dp_pg + ) + mean_consumed_seq_len_ranks = mean_consumed_seq_len_t_all.cpu().numpy() + mean_consumed_seq_len_all = mean_consumed_seq_len_ranks.mean() # Logging on logger ranks if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks: