Always store seq length stats when logging
manuelburger committed Oct 9, 2024
1 parent 45328b2 commit 814e21a
Showing 2 changed files with 79 additions and 78 deletions.
9 changes: 9 additions & 0 deletions — src/nanotron/data/petagraph_dataset.py

@@ -15,6 +15,7 @@
 from tqdm import tqdm
 import numpy as np
 from typing import Dict, Optional, Tuple
+import json
 
 # import zstd
 import zstandard

@@ -77,6 +78,7 @@ def __init__(self,
         self._bos_token_id = self.VOCAB["BOS"]
         self._unk_token_id = self.VOCAB["UNK"]
 
+
         self.num_files = len(url_list)
         self.current_epoch = 0

@@ -85,6 +87,13 @@
         self.num_consumed_sequences = 0
         self.consumed_files_path = self.log_directory / f"consumed_files/consumed_files_rank_{self.rank}.txt"
 
+        # Save the vocabulary as json on head node
+        if self.rank == 0:
+            self.vocab_path = log_directory / "vocabulary.json"
+            with open(self.vocab_path, "w") as f:
+                json.dump(self.VOCAB, f)
+
         # Take list of already consumed lists and remove them from the
         # url list, to continue training from the last checkpoint properly
         # - Check if the consumed_files exist
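The added hunk dumps self.VOCAB (a token-to-id dict, per the BOS/UNK lookups earlier in the file) to log_directory / "vocabulary.json" on rank 0. Below is a minimal companion sketch for reading that file back, assuming a flat token-to-id mapping; the load_vocabulary helper and the example path are illustrative and not part of the commit.

    import json
    from pathlib import Path

    def load_vocabulary(log_directory: Path):
        # Read the vocabulary.json written by rank 0 and build both directions
        # of the mapping (token -> id and id -> token).
        vocab_path = log_directory / "vocabulary.json"
        with open(vocab_path, "r") as f:
            token_to_id = json.load(f)
        # json.dump wrote a flat {token: id} dict, so invert it for decoding
        id_to_token = {idx: token for token, idx in token_to_id.items()}
        return token_to_id, id_to_token

    # Hypothetical usage:
    # token_to_id, id_to_token = load_vocabulary(Path("checkpoints/run_001"))
    # print(id_to_token[token_to_id["BOS"]])  # prints "BOS"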
148 changes: 70 additions & 78 deletions — src/nanotron/trainer.py

@@ -571,91 +571,83 @@ def train_step_logs(

The guard

    if self.iteration_step < 5 or (self.iteration_step - 1) % self.config.checkpoints.checkpoint_interval == 0:

is removed, together with its else branch, which previously set num_consumed_files_all, current_epoch_all, num_consumed_seq_log and mean_consumed_seq_len_all to None. The block it wrapped is de-indented one level and is otherwise unchanged, so the dataloader statistics are now gathered on every logging step:

        # Gather information from the dataloaders and log statistics
        # Don't do this too often as it contains an allgather
        if dataloaders is not None:
            if isinstance(dataloaders, tuple):
                current_dataset = dataloaders[0].dataset
            elif isinstance(dataloaders, dict):
                current_dataset = dataloaders[list(dataloaders.keys())[0]].dataset
            else:
                current_dataset = dataloaders.dataset
        else:
            current_dataset = None

        if current_dataset is not None and hasattr(current_dataset, "consumed_seq_len_queue"):
            consumed_seq_lens = np.array(list(current_dataset.consumed_seq_len_queue), dtype=np.int64)
            mean_seq_len = np.mean(consumed_seq_lens)
        else:
            mean_seq_len = 0.0

        if current_dataset is not None and hasattr(current_dataset, "consumed_files"):
            num_consumed_files = len(current_dataset.consumed_files)
        else:
            num_consumed_files = -1

        if current_dataset is not None and hasattr(current_dataset, "current_epoch"):
            current_epoch = current_dataset.current_epoch
        else:
            current_epoch = -1

        if current_dataset is not None and hasattr(current_dataset, "num_consumed_sequences"):
            num_consumed_sequences = current_dataset.num_consumed_sequences
            current_dataset.num_consumed_sequences = 0
        else:
            num_consumed_sequences = 0

        # Gather the values across all ranks
        world_size_dp_pg = self.parallel_context.dp_pg.size()

        num_consumed_files_t = torch.tensor(num_consumed_files, device="cuda", dtype=torch.int64)
        num_consumed_files_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64)
        dist.all_gather_into_tensor(
            output_tensor=num_consumed_files_t_all,
            input_tensor=num_consumed_files_t,
            group=self.parallel_context.dp_pg
        )
        num_consumed_files_ranks = num_consumed_files_t_all.cpu().numpy()
        num_consumed_files_all = num_consumed_files_ranks.sum()
        self.metadata.consumed_num_logan_files = int(num_consumed_files_all)

        current_epoch_t = torch.tensor(current_epoch, device="cuda", dtype=torch.int64)
        current_epoch_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64)
        dist.all_gather_into_tensor(
            output_tensor=current_epoch_t_all,
            input_tensor=current_epoch_t,
            group=self.parallel_context.dp_pg
        )
        current_epoch_ranks = current_epoch_t_all.cpu().numpy()
        current_epoch_all = current_epoch_ranks.mean()

        num_consumed_seq_t = torch.tensor(num_consumed_sequences, device="cuda", dtype=torch.int64)
        num_consumed_seq_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.int64)
        dist.all_gather_into_tensor(
            output_tensor=num_consumed_seq_t_all,
            input_tensor=num_consumed_seq_t,
            group=self.parallel_context.dp_pg
        )
        num_consumed_seq_ranks = num_consumed_seq_t_all.cpu().numpy()
        num_consumed_seq_all = num_consumed_seq_ranks.sum()
        self.metadata.consumed_num_sequences += int(num_consumed_seq_all)
        num_consumed_seq_log = self.metadata.consumed_num_sequences

        mean_consumed_seq_len_t = torch.tensor(mean_seq_len, device="cuda", dtype=torch.float32)
        mean_consumed_seq_len_t_all = torch.zeros(world_size_dp_pg, device="cuda", dtype=torch.float32)
        dist.all_gather_into_tensor(
            output_tensor=mean_consumed_seq_len_t_all,
            input_tensor=mean_consumed_seq_len_t,
            group=self.parallel_context.dp_pg
        )
        mean_consumed_seq_len_ranks = mean_consumed_seq_len_t_all.cpu().numpy()
        mean_consumed_seq_len_all = mean_consumed_seq_len_ranks.mean()

        # Logging on logger ranks
        if dist.get_rank(self.parallel_context.world_pg) in self.logger_ranks:
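The de-indented block applies one pattern to each statistic: every data-parallel rank contributes a scalar, the scalars are all-gathered over the dp group, and the gathered vector is reduced (summed for file and sequence counts, averaged for the epoch counter and the mean sequence length). The following is a minimal, self-contained sketch of that pattern and is not part of the commit: it uses the list-based dist.all_gather on a single-process gloo group so it runs on CPU without a launcher, whereas the trainer calls dist.all_gather_into_tensor on CUDA tensors over self.parallel_context.dp_pg; the gather_scalar_stats name and the example values are assumptions.

    import os
    import torch
    import torch.distributed as dist

    def gather_scalar_stats(local_count: int, local_mean: float, group=None):
        # All-gather one scalar per rank, then aggregate: counts are summed,
        # per-rank means are averaged (same reductions as in train_step_logs).
        world_size = dist.get_world_size(group)

        count_t = torch.tensor([local_count], dtype=torch.int64)
        count_list = [torch.zeros(1, dtype=torch.int64) for _ in range(world_size)]
        dist.all_gather(count_list, count_t, group=group)

        mean_t = torch.tensor([local_mean], dtype=torch.float32)
        mean_list = [torch.zeros(1, dtype=torch.float32) for _ in range(world_size)]
        dist.all_gather(mean_list, mean_t, group=group)

        total_count = int(torch.cat(count_list).sum().item())
        mean_of_means = float(torch.cat(mean_list).mean().item())
        return total_count, mean_of_means

    if __name__ == "__main__":
        # Single-process gloo group so the sketch runs without a launcher;
        # in training this happens on CUDA tensors over the data-parallel group.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("gloo", rank=0, world_size=1)
        total, avg_len = gather_scalar_stats(local_count=128, local_mean=4096.0)
        print(total, avg_len)  # 128 4096.0
        dist.destroy_process_group()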
