From b62066aa6a4dd7ea01b78c40f2ca99bddab85f24 Mon Sep 17 00:00:00 2001 From: Manuel Burger Date: Wed, 6 Nov 2024 14:05:51 +0100 Subject: [PATCH] Support reverse probability mode --- petagraph/run_train.py | 3 ++- src/nanotron/config/config.py | 1 + src/nanotron/data/petagraph_dataset.py | 16 +++++++++++++--- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/petagraph/run_train.py b/petagraph/run_train.py index d8e4e970..65c91e46 100644 --- a/petagraph/run_train.py +++ b/petagraph/run_train.py @@ -241,7 +241,8 @@ def get_dataloader_from_data_stage( create_attention_mask=True, log_directory=trainer.config.checkpoints.checkpoints_path, rank=global_rank, - packed=True + packed=True, + reverse_probability=data.reverse_probability, ) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index 9ad604d6..8a58999f 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -130,6 +130,7 @@ class DataArgs: sequence_files_path: Optional[str] = None prefetch_buffer_seq_size: Optional[int] = 1 all_sequences_resources_path: Optional[str] = None + reverse_probability: float = 0.0 def __post_init__(self): if self.seed is None: diff --git a/src/nanotron/data/petagraph_dataset.py b/src/nanotron/data/petagraph_dataset.py index 40f73f5b..07b02710 100644 --- a/src/nanotron/data/petagraph_dataset.py +++ b/src/nanotron/data/petagraph_dataset.py @@ -536,6 +536,8 @@ class PetaGraphStreamDatasetV2(torch.utils.data.IterableDataset): The sequence length at which to switch from sampling to keeping the sequence below the inflection point we only keep the sequence with a probability pr to its length. Above the inflection point we always keep the sequence. + reverse_probability : float + The probability to reverse the sequence. Only active if != 0.0 """ def __init__(self, @@ -550,13 +552,15 @@ def __init__(self, log_directory: Path = None, rank: int = 0, packed: bool = False, - sampling_seq_len_inflection: int = 1024 + sampling_seq_len_inflection: int = 1024, + reverse_probability: float = 0.0 ): self.maxlen = maxlen self.create_attention_mask = create_attention_mask self.debug = debug self.sampling_seq_len_inflection = sampling_seq_len_inflection + self.reverse_probability = reverse_probability self.logger = logger @@ -567,6 +571,8 @@ def __init__(self, self.logging_func(f"[PetaGraphStreamDataset] Num. URLs: {len(url_list)}") self.logging_func(f"[PetaGraphStreamDataset] From Cloud: {from_cloud}") self.logging_func(f"[PetaGraphStreamDataset] Sampling Seq. Len. Inflection: {self.sampling_seq_len_inflection}") + if self.reverse_probability > 0.0: + self.logging_func(f"[PetaGraphStreamDataset] Reverse Probability: {self.reverse_probability}") self.VOCAB = vocabulary self._pad_token_id = self.VOCAB["PAD"] @@ -858,6 +864,10 @@ def tokenize_and_pad(self, input_sequence: str, apply_pad: bool = True): tokenized_sequence.append(self._eos_token_id) # end with EOS token tokenized_sequence = np.array(tokenized_sequence, dtype=np.int32) + if self.reverse_probability > 0.0: + if np.random.rand() < self.reverse_probability: + tokenized_sequence = tokenized_sequence[::-1] + # Pad the sequence if apply_pad and len(tokenized_sequence) < maxlen: # 2 is the PAD token @@ -964,8 +974,8 @@ def generate(self): current_tokens = new_tokens else: # Check the last token of the current sequence - # is an EOS token - assert current_tokens[-1] == self._eos_token_id + # is an EOS token or BOS token (if reverse_probability > 0.0) + assert current_tokens[-1] == self._eos_token_id or (self.reverse_probability > 0.0 and current_tokens[-1] == self._bos_token_id) current_tokens = np.concatenate([current_tokens, new_tokens]) if len(current_tokens) >= self.maxlen: