From 728e076d7a73a9257d0e20f1084920de833615d8 Mon Sep 17 00:00:00 2001
From: janEbert
Date: Tue, 13 Dec 2022 12:08:46 +0100
Subject: [PATCH] Support UL2 for decoder-only models

---
 megatron/arguments.py          |  18 +++-
 megatron/data/dataset_utils.py |   1 +
 megatron/data/ul2_dataset.py   | 161 ++++++++++++++++++++++++---------
 megatron/enums.py              |   5 +
 pretrain_ul2.py                |  99 +++++++++++++++-----
 5 files changed, 215 insertions(+), 69 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 5fd48a7b8..9aae25cda 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -24,7 +24,7 @@
 import torch
 import deepspeed
 
-from megatron.enums import PositionEmbeddingType
+from megatron.enums import PositionEmbeddingType, UL2ModelType
 import megatron
 from megatron.logging import log_levels
 
@@ -311,6 +311,17 @@ def parse_args(extra_args_provider=None, defaults={},
     )
     args.skip_train_iteration_range = skip_train_iteration_range
 
+    args.ul2_model_type = UL2ModelType(args.ul2_model_type)
+    if (
+            args.ul2_model_type is not UL2ModelType.ENCODER_DECODER
+            and args.decoder_seq_length is not None
+    ):
+        print(
+            f'WARNING: `--decoder-seq-length` is ignored when '
+            f'`--ul2-model-type` is not '
+            f'"{UL2ModelType.ENCODER_DECODER.value}"!'
+        )
+
     if args.use_bnb_optimizer:
         try:
             import bitsandbytes as bnb
@@ -1028,6 +1039,11 @@ def _add_vit_args(parser):
 def _add_ul2_args(parser):
     group = parser.add_argument_group(title="UL2")
 
+    group.add_argument('--ul2-model-type', type=str, default='ED',
+                       choices=['ED', 'ND', 'CD'],
+                       help='What type of model to use for UL2 pretraining. '
+                       'ED = encoder-decoder; ND = non-causal decoder-only; '
+                       'CD = causal decoder-only')
     group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float,
                        default=None,
                        help='Probability of each denoising objective to be '
diff --git a/megatron/data/dataset_utils.py b/megatron/data/dataset_utils.py
index 92f37b2f8..60d4e0d90 100644
--- a/megatron/data/dataset_utils.py
+++ b/megatron/data/dataset_utils.py
@@ -597,6 +597,7 @@ def build_dataset(index, name):
                 args = get_args()
                 dataset = UL2Dataset(
                     indexed_dataset=indexed_dataset,
+                    model_type=args.ul2_model_type,
                     denoiser_ratios=args.ul2_denoiser_ratios,
                     denoisers=args.ul2_denoisers,
                     mean_span_lengths=args.ul2_mean_span_lengths,
diff --git a/megatron/data/ul2_dataset.py b/megatron/data/ul2_dataset.py
index 4f2d333a1..7fc3e6f32 100644
--- a/megatron/data/ul2_dataset.py
+++ b/megatron/data/ul2_dataset.py
@@ -15,6 +15,8 @@
 
 """UL2-style dataset."""
 
+import math
+
 import numpy as np
 
 from megatron import get_tokenizer
@@ -23,16 +25,34 @@
     get_samples_mapping,
     SamplingStyle
 )
-from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
+from megatron.data.t5_dataset import (
+    make_history_mask,
+    merge_subsequent_masks,
+    pad_and_convert_to_numpy,
+    T5Dataset,
+)
+from megatron.enums import UL2ModelType
+
+
+def is_decoder_only(ul2_model_type):
+    """Return whether we use a decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is not UL2ModelType.ENCODER_DECODER
+
+
+def is_prefix_lm(ul2_model_type):
+    """Return whether we use a non-causal decoder-only model."""
+    assert isinstance(ul2_model_type, UL2ModelType)
+    return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER
 
 
 class UL2Dataset(T5Dataset):
 
     def __init__(self, name, indexed_dataset, data_prefix,
-                 num_epochs, max_num_samples, denoiser_ratios,
-                 denoisers, mean_span_lengths, mask_ratios,
-                 denoiser_tokens, max_seq_length, max_seq_length_dec,
-                 short_seq_prob, seed):
+                 num_epochs, max_num_samples, model_type,
+                 denoiser_ratios, denoisers, mean_span_lengths,
+                 mask_ratios, denoiser_tokens, max_seq_length,
+                 max_seq_length_dec, short_seq_prob, seed):
 
         if denoiser_ratios is None:
             # Uniform distribution by default.
@@ -52,6 +72,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
                          short_seq_prob, seed)
 
         # Params to store.
+        self.model_type = model_type
         self.denoiser_ratios = [
             denoiser_ratio / sum(denoiser_ratios)
             for denoiser_ratio in denoiser_ratios
@@ -97,21 +118,21 @@ def __getitem__(self, idx):
                                      self.vocab_id_to_token_dict,
                                      self.cls_ids, self.sep_id,
                                      self.mask_id, self.pad_id,
-                                     self.denoiser_ratios, self.denoisers,
-                                     self.mean_span_lengths, self.mask_ratios,
-                                     np_rng,
-                                     self.bos_id, self.eos_id,
-                                     self.sentinel_tokens)
+                                     self.model_type, self.denoiser_ratios,
+                                     self.denoisers, self.mean_span_lengths,
+                                     self.mask_ratios, np_rng, self.bos_id,
+                                     self.eos_id, self.sentinel_tokens)
 
 
 def build_training_sample(sample, target_seq_length,
                           max_seq_length, max_seq_length_dec,
                           vocab_id_list, vocab_id_to_token_dict,
                           cls_ids, sep_id, mask_id, pad_id,
-                          denoiser_ratios, denoisers,
-                          mean_span_lengths, mask_ratios,
-                          np_rng, bos_id=None,
-                          eos_id=None, sentinel_tokens=None):
+                          model_type, denoiser_ratios,
+                          denoisers, mean_span_lengths,
+                          mask_ratios, np_rng,
+                          bos_id=None, eos_id=None,
+                          sentinel_tokens=None):
     """Build training sample.
 
     Arguments:
@@ -125,6 +146,7 @@ def build_training_sample(sample, target_seq_length,
         sep_id: Separator id.
        mask_id: Mask token id.
         pad_id: Padding token id.
+        model_type: What type of model is used.
         denoiser_ratios: Probability of each denoising objective to be selected.
         denoisers: What type of UL2 denoising objective the other UL2
             configurations refer to.
@@ -139,24 +161,28 @@ def build_training_sample(sample, target_seq_length,
         sentinel_tokens: unique value to be substituted for every replaced span
     """
 
+    # Denoiser selection
+    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
+    denoiser = denoisers[denoiser_index]
+    masked_lm_prob = mask_ratios[denoiser_index]
+
     assert target_seq_length <= max_seq_length
 
     # flatten sentences into one list
     tokens = [token for sentence in sample for token in sentence]
 
-    # Truncate to `target_sequence_length`.
     max_num_tokens = target_seq_length
-    truncated = len(tokens) > max_num_tokens
-    tokens = tokens[:max_num_tokens]
-
-    # Denoiser selection
-    denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
-    denoiser = denoisers[denoiser_index]
-    masked_lm_prob = mask_ratios[denoiser_index]
-    mean_ngrams = mean_span_lengths[denoiser_index]
-    if mean_ngrams < 1:
-        mean_ngrams = round(len(tokens) * mean_ngrams)
-    max_ngrams = mean_ngrams * 2 - 1
+    if is_decoder_only(model_type):
+        # Keep space for repeated `extra_id` tokens; not the most data
+        # efficient since we calculate this based on the maximum number
+        # of possible `extra_id` tokens.
+        safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
+        truncated = len(tokens) > safe_max_seq_len
+        tokens = tokens[:safe_max_seq_len]
+    else:
+        # Truncate to `target_sequence_length`.
+        truncated = len(tokens) > max_num_tokens
+        tokens = tokens[:max_num_tokens]
 
     # Prepend objective token.
     cls_id = cls_ids.get(denoiser)
@@ -166,6 +192,11 @@ def build_training_sample(sample, target_seq_length,
 
     # Masking.
     max_predictions_per_seq = masked_lm_prob * len(tokens)
+    mean_ngrams = mean_span_lengths[denoiser_index]
+    if mean_ngrams < 1:
+        mean_ngrams = round(len(tokens) * mean_ngrams)
+    max_ngrams = mean_ngrams * 2 - 1
+
     if denoiser == 'R' or denoiser == 'X':
         sampling_style = SamplingStyle.NORMAL
         prefix_lm = False
@@ -183,22 +214,64 @@ def build_training_sample(sample, target_seq_length,
         sampling_style=sampling_style,
         prefix_lm=prefix_lm,
     )
-    # Padding.
-    tokens_enc, tokens_dec_in, labels, enc_mask, \
-        dec_mask, enc_dec_mask, loss_mask \
-        = pad_and_convert_to_numpy(tokens, masked_positions,
-                                   masked_labels, pad_id, max_seq_length,
-                                   max_seq_length_dec, masked_spans,
-                                   bos_id, eos_id, sentinel_tokens)
-
-    train_sample = {
-        'text_enc': tokens_enc,
-        'text_dec': tokens_dec_in,
-        'labels': labels,
-        'loss_mask': loss_mask,
-        'truncated': int(truncated),
-        'enc_mask': enc_mask,
-        'dec_mask': dec_mask,
-        'enc_dec_mask': enc_dec_mask,
-    }
+    if is_decoder_only(model_type):
+        # Concatenate to one sequence.
+        tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
+            tokens, masked_spans, bos_id, eos_id, sentinel_tokens)
+
+        # Move EOS tokens to end of sequence.
+        while tokens_enc[-1] == eos_id:
+            del tokens_enc[-1]
+            tokens_dec_in.append(eos_id)
+            labels.append(eos_id)
+
+        num_labels = len(labels)
+
+        # Move BOS token to start of sequence.
+        tokens_dec_in = tokens_dec_in[1:]
+        tokens = np.concatenate([
+            np.array([bos_id], dtype=np.int64),
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            tokens_dec_in,
+        ])
+        labels = np.concatenate([
+            tokens_enc,
+            np.array([sep_id], dtype=np.int64),
+            labels,
+        ])
+
+        loss_mask = np.zeros(len(tokens), dtype=np.int64)
+        loss_mask[-num_labels:] = 1
+
+        dec_mask = make_history_mask(tokens)
+        if is_prefix_lm(model_type):
+            dec_mask[:-num_labels, :-num_labels] = 1
+
+        train_sample = {
+            'text': tokens,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'dec_mask': dec_mask,
+        }
+    else:
+        # Padding.
+        tokens_enc, tokens_dec_in, labels, enc_mask, \
+            dec_mask, enc_dec_mask, loss_mask \
+            = pad_and_convert_to_numpy(tokens, masked_positions,
+                                       masked_labels, pad_id, max_seq_length,
+                                       max_seq_length_dec, masked_spans,
+                                       bos_id, eos_id, sentinel_tokens)
+
+        train_sample = {
+            'text_enc': tokens_enc,
+            'text_dec': tokens_dec_in,
+            'labels': labels,
+            'loss_mask': loss_mask,
+            'truncated': int(truncated),
+            'enc_mask': enc_mask,
+            'dec_mask': dec_mask,
+            'enc_dec_mask': enc_dec_mask,
+        }
     return train_sample
diff --git a/megatron/enums.py b/megatron/enums.py
index 90d00a071..2961cbb66 100644
--- a/megatron/enums.py
+++ b/megatron/enums.py
@@ -33,3 +33,8 @@ class PositionEmbeddingType(enum.Enum):
     rotary = 1
     absolute = 2
     alibi = 3
+
+class UL2ModelType(enum.Enum):
+    ENCODER_DECODER = 'ED'
+    NON_CAUSAL_DECODER = 'ND'
+    CAUSAL_DECODER = 'CD'
diff --git a/pretrain_ul2.py b/pretrain_ul2.py
index 04b2b0dc6..cab24ced0 100644
--- a/pretrain_ul2.py
+++ b/pretrain_ul2.py
@@ -26,26 +26,56 @@
     print_rank_0
 )
 from megatron.data.dataset_utils import build_train_valid_test_datasets
-from megatron.model.t5_model import T5Model
+from megatron.data.ul2_dataset import (
+    is_decoder_only as _is_decoder_only,
+    is_prefix_lm as _is_prefix_lm,
+)
+from megatron.model.gpt_model import GPTModel
+from megatron.model.t5_model import T5Model, t5_position_ids
 from megatron.training import pretrain
 from megatron.utils import average_losses_across_data_parallel_group
 
 
+def is_decoder_only():
+    """Return whether we use a decoder-only model."""
+    args = get_args()
+    return _is_decoder_only(args.ul2_model_type)
+
+
+def is_prefix_lm():
+    """Return whether we use a non-causal decoder-only model."""
+    args = get_args()
+    return _is_prefix_lm(args.ul2_model_type)
+
+
 def model_provider(pre_process=True, post_process=True):
     """Build the model."""
 
     assert pre_process and post_process, "UL2 doesn't yet support pipelining"
     print_rank_0('building UL2 model ...')
-    model = T5Model(num_tokentypes=0,
-                    parallel_output=True)
+    if is_decoder_only():
+        print_rank_0('Using decoder-only UL2 model.')
+        model = GPTModel(
+            num_tokentypes=0,
+            parallel_output=True,
+            pre_process=pre_process,
+            post_process=post_process,
+            prefix_lm=is_prefix_lm(),
+        )
+    else:
+        print_rank_0('Using encoder-decoder UL2 model.')
+        model = T5Model(num_tokentypes=0, parallel_output=True)
     return model
 
 
 def get_batch(data_iterator):
     """Build the batch."""
 
-    keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
-            'enc_mask', 'dec_mask', 'enc_dec_mask']
+    if is_decoder_only():
+        keys = ['text', 'labels', 'loss_mask', 'dec_mask']
+    else:
+        keys = ['text_enc', 'text_dec', 'labels', 'loss_mask',
+                'enc_mask', 'dec_mask', 'enc_dec_mask']
     datatype = torch.int64
 
     # Broadcast data.
@@ -56,21 +86,32 @@ def get_batch(data_iterator):
     data_b = mpu.broadcast_data(keys, data, datatype)
 
     # Unpack.
-    tokens_enc = data_b['text_enc'].long()
-    tokens_dec = data_b['text_dec'].long()
-    labels = data_b['labels'].long()
-    loss_mask = data_b['loss_mask'].float()
+    if is_decoder_only():
+        tokens = data_b['text'].long()
+        labels = data_b['labels'].long()
+        loss_mask = data_b['loss_mask'].float()
+
+        dec_mask = (data_b['dec_mask'] < 0.5)
+        return tokens, loss_mask, labels, dec_mask
+    else:
+        tokens_enc = data_b['text_enc'].long()
+        tokens_dec = data_b['text_dec'].long()
+        labels = data_b['labels'].long()
+        loss_mask = data_b['loss_mask'].float()
 
-    enc_mask = (data_b['enc_mask'] < 0.5)
-    dec_mask = (data_b['dec_mask'] < 0.5)
-    enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)
+        enc_mask = (data_b['enc_mask'] < 0.5)
+        dec_mask = (data_b['dec_mask'] < 0.5)
+        enc_dec_mask = (data_b['enc_dec_mask'] < 0.5)
 
-    return tokens_enc, tokens_dec, loss_mask, labels, \
-           enc_mask, dec_mask, enc_dec_mask
+        return tokens_enc, tokens_dec, loss_mask, labels, \
+            enc_mask, dec_mask, enc_dec_mask
 
 
 def loss_func(loss_mask, output_tensor):
-    lm_loss_, _ = output_tensor
+    if is_decoder_only():
+        lm_loss_ = output_tensor
+    else:
+        lm_loss_, _ = output_tensor
     lm_loss_ = lm_loss_.float()
 
     lm_loss = torch.sum(
@@ -89,18 +130,28 @@
 
     # Get the batch.
     timers('batch generator').start()
-    tokens_enc, tokens_dec, loss_mask, lm_labels, enc_mask, dec_mask, enc_dec_mask \
-        = get_batch(data_iterator)
+    if is_decoder_only():
+        (tokens, loss_mask, lm_labels, dec_mask) = get_batch(data_iterator)
+    else:
+        (
+            tokens_enc, tokens_dec, loss_mask, lm_labels,
+            enc_mask, dec_mask, enc_dec_mask,
+        ) = get_batch(data_iterator)
     timers('batch generator').stop()
 
     # Forward model lm_labels
-    output_tensor = model(tokens_enc,
-                          tokens_dec,
-                          enc_mask,
-                          dec_mask,
-                          enc_dec_mask,
-                          tokentype_ids=None,
-                          lm_labels=lm_labels)
+    if is_decoder_only():
+        position_ids = t5_position_ids(tokens)
+        output_tensor = model(tokens, position_ids, dec_mask,
+                              labels=lm_labels)
+    else:
+        output_tensor = model(tokens_enc,
+                              tokens_dec,
+                              enc_mask,
+                              dec_mask,
+                              enc_dec_mask,
+                              tokentype_ids=None,
+                              lm_labels=lm_labels)
 
     return output_tensor, partial(loss_func, loss_mask)