Support UL2 for decoder-only models
janEbert committed Dec 13, 2022
1 parent deed87f commit 728e076
Showing 5 changed files with 215 additions and 69 deletions.
18 changes: 17 additions & 1 deletion megatron/arguments.py
@@ -24,7 +24,7 @@
import torch
import deepspeed

from megatron.enums import PositionEmbeddingType
from megatron.enums import PositionEmbeddingType, UL2ModelType
import megatron
from megatron.logging import log_levels

@@ -311,6 +311,17 @@ def parse_args(extra_args_provider=None, defaults={},
)
args.skip_train_iteration_range = skip_train_iteration_range

args.ul2_model_type = UL2ModelType(args.ul2_model_type)
if (
args.ul2_model_type is not UL2ModelType.ENCODER_DECODER
and args.decoder_seq_length is not None
):
print(
f'WARNING: `--decoder_seq_length` is ignored when '
f'`--ul2-model-type` is not '
f'"{UL2ModelType.ENCODER_DECODER.value}"!'
)

if args.use_bnb_optimizer:
try:
import bitsandbytes as bnb
@@ -1028,6 +1039,11 @@ def _add_vit_args(parser):
def _add_ul2_args(parser):
group = parser.add_argument_group(title="UL2")

group.add_argument('--ul2-model-type', type=str, default='ED',
choices=['ED', 'ND', 'CD'],
help='What type of model to use for UL2 pretraining. '
'ED = encoder-decoder; ND = non-causal decoder-only; '
'CD = causal decoder-only')
group.add_argument('--ul2-denoiser-ratios', nargs='+', type=float,
default=None,
help='Probability of each denoising objective to be '
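The new `--ul2-model-type` choice is converted into the `UL2ModelType` enum right after argument parsing, so later code can branch on enum members instead of raw strings. Below is a minimal sketch of that conversion and of the new warning, assuming it runs inside the repo after this commit; the throwaway argparse parser and the `--decoder-seq-length` flag merely stand in for Megatron's full argument setup.

    import argparse

    from megatron.enums import UL2ModelType

    parser = argparse.ArgumentParser()
    parser.add_argument('--ul2-model-type', type=str, default='ED',
                        choices=['ED', 'ND', 'CD'])
    parser.add_argument('--decoder-seq-length', type=int, default=None)
    args = parser.parse_args(['--ul2-model-type', 'ND', '--decoder-seq-length', '128'])

    # 'ND' maps to UL2ModelType.NON_CAUSAL_DECODER.
    args.ul2_model_type = UL2ModelType(args.ul2_model_type)

    # Decoder-only variants train on a single packed sequence, so a separate
    # decoder sequence length has no effect; it only triggers the warning.
    if (args.ul2_model_type is not UL2ModelType.ENCODER_DECODER
            and args.decoder_seq_length is not None):
        print(f'WARNING: `--decoder_seq_length` is ignored when '
              f'`--ul2-model-type` is not "{UL2ModelType.ENCODER_DECODER.value}"!')

With the default 'ED' the warning never fires and the existing encoder-decoder behaviour is unchanged.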
1 change: 1 addition & 0 deletions megatron/data/dataset_utils.py
@@ -597,6 +597,7 @@ def build_dataset(index, name):
args = get_args()
dataset = UL2Dataset(
indexed_dataset=indexed_dataset,
model_type=args.ul2_model_type,
denoiser_ratios=args.ul2_denoiser_ratios,
denoisers=args.ul2_denoisers,
mean_span_lengths=args.ul2_mean_span_lengths,
161 changes: 117 additions & 44 deletions megatron/data/ul2_dataset.py
@@ -15,6 +15,8 @@

"""UL2-style dataset."""

import math

import numpy as np

from megatron import get_tokenizer
@@ -23,16 +25,34 @@
get_samples_mapping,
SamplingStyle
)
from megatron.data.t5_dataset import pad_and_convert_to_numpy, T5Dataset
from megatron.data.t5_dataset import (
make_history_mask,
merge_subsequent_masks,
pad_and_convert_to_numpy,
T5Dataset,
)
from megatron.enums import UL2ModelType


def is_decoder_only(ul2_model_type):
"""Return whether we use a decoder-only model."""
assert isinstance(ul2_model_type, UL2ModelType)
return ul2_model_type is not UL2ModelType.ENCODER_DECODER


def is_prefix_lm(ul2_model_type):
"""Return whether we use a non-causal decoder-only model."""
assert isinstance(ul2_model_type, UL2ModelType)
return ul2_model_type is UL2ModelType.NON_CAUSAL_DECODER


class UL2Dataset(T5Dataset):

def __init__(self, name, indexed_dataset, data_prefix,
num_epochs, max_num_samples, denoiser_ratios,
denoisers, mean_span_lengths, mask_ratios,
denoiser_tokens, max_seq_length, max_seq_length_dec,
short_seq_prob, seed):
num_epochs, max_num_samples, model_type,
denoiser_ratios, denoisers, mean_span_lengths,
mask_ratios, denoiser_tokens, max_seq_length,
max_seq_length_dec, short_seq_prob, seed):

if denoiser_ratios is None:
# Uniform distribution by default.
@@ -52,6 +72,7 @@ def __init__(self, name, indexed_dataset, data_prefix,
short_seq_prob, seed)

# Params to store.
self.model_type = model_type
self.denoiser_ratios = [
denoiser_ratio / sum(denoiser_ratios)
for denoiser_ratio in denoiser_ratios
@@ -97,21 +118,21 @@ def __getitem__(self, idx):
self.vocab_id_to_token_dict,
self.cls_ids, self.sep_id,
self.mask_id, self.pad_id,
self.denoiser_ratios, self.denoisers,
self.mean_span_lengths, self.mask_ratios,
np_rng,
self.bos_id, self.eos_id,
self.sentinel_tokens)
self.model_type, self.denoiser_ratios,
self.denoisers, self.mean_span_lengths,
self.mask_ratios, np_rng, self.bos_id,
self.eos_id, self.sentinel_tokens)


def build_training_sample(sample, target_seq_length,
max_seq_length, max_seq_length_dec,
vocab_id_list, vocab_id_to_token_dict,
cls_ids, sep_id, mask_id, pad_id,
denoiser_ratios, denoisers,
mean_span_lengths, mask_ratios,
np_rng, bos_id=None,
eos_id=None, sentinel_tokens=None):
model_type, denoiser_ratios,
denoisers, mean_span_lengths,
mask_ratios, np_rng,
bos_id=None, eos_id=None,
sentinel_tokens=None):
"""Build training sample.
Arguments:
@@ -125,6 +146,7 @@ def build_training_sample(sample, target_seq_length,
sep_id: Separator id.
mask_id: Mask token id.
pad_id: Padding token id.
model_type: What type of model is used.
denoiser_ratios: Probability of each denoising objective to be selected.
denoisers: What type of UL2 denoising objective the other UL2
configurations refer to.
@@ -139,24 +161,28 @@
sentinel_tokens: unique value to be substituted for every replaced span
"""

# Denoiser selection
denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
denoiser = denoisers[denoiser_index]
masked_lm_prob = mask_ratios[denoiser_index]

assert target_seq_length <= max_seq_length

# flatten sentences into one list
tokens = [token for sentence in sample for token in sentence]

# Truncate to `target_sequence_length`.
max_num_tokens = target_seq_length
truncated = len(tokens) > max_num_tokens
tokens = tokens[:max_num_tokens]

# Denoiser selection
denoiser_index = np_rng.choice(np.arange(len(denoisers)), p=denoiser_ratios)
denoiser = denoisers[denoiser_index]
masked_lm_prob = mask_ratios[denoiser_index]
mean_ngrams = mean_span_lengths[denoiser_index]
if mean_ngrams < 1:
mean_ngrams = round(len(tokens) * mean_ngrams)
max_ngrams = mean_ngrams * 2 - 1
if is_decoder_only(model_type):
# Keep space for repeated `extra_id` tokens; not the most data
# efficient since we calculate this based on the maximum number
# of possible `extra_id` tokens.
safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
truncated = len(tokens) > safe_max_seq_len
tokens = tokens[:safe_max_seq_len]
else:
# Truncate to `target_sequence_length`.
truncated = len(tokens) > max_num_tokens
tokens = tokens[:max_num_tokens]

# Prepend objective token.
cls_id = cls_ids.get(denoiser)
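For decoder-only variants the masked input and its targets later share one packed sequence, so the raw tokens are truncated with headroom for the sentinel tokens that will be added. A quick worked example of the `safe_max_seq_len` calculation above, with illustrative numbers that are not taken from the source:

    import math

    max_num_tokens = 512     # assumed token budget for the packed sequence
    masked_lm_prob = 0.15    # assumed mask ratio for the selected denoiser

    # Worst case every masked span has length 1, so up to
    # masked_lm_prob * len(tokens) extra sentinel tokens can be inserted.
    safe_max_seq_len = math.floor(max_num_tokens / (1 + masked_lm_prob))
    print(safe_max_seq_len)  # 445, and 445 + ceil(0.15 * 445) == 512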
@@ -166,6 +192,11 @@

# Masking.
max_predictions_per_seq = masked_lm_prob * len(tokens)
mean_ngrams = mean_span_lengths[denoiser_index]
if mean_ngrams < 1:
mean_ngrams = round(len(tokens) * mean_ngrams)
max_ngrams = mean_ngrams * 2 - 1

if denoiser == 'R' or denoiser == 'X':
sampling_style = SamplingStyle.NORMAL
prefix_lm = False
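A mean span length below 1 is treated as a fraction of the current (already truncated) sequence length rather than an absolute token count, and the upper bound handed to the span sampler is set so the range is roughly symmetric around the mean. A small illustrative example, continuing the assumed numbers from above:

    num_tokens = 445                 # tokens after truncation (assumed)
    masked_lm_prob = 0.15
    max_predictions_per_seq = masked_lm_prob * num_tokens   # 66.75

    for mean_span in (3, 0.25):      # absolute vs. fractional mean span length
        mean_ngrams = mean_span
        if mean_ngrams < 1:
            mean_ngrams = round(num_tokens * mean_ngrams)   # 0.25 -> 111
        max_ngrams = mean_ngrams * 2 - 1                    # 3 -> 5, 111 -> 221
        print(mean_ngrams, max_ngrams)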
@@ -183,22 +214,64 @@
sampling_style=sampling_style, prefix_lm=prefix_lm,
)

# Padding.
tokens_enc, tokens_dec_in, labels, enc_mask, \
dec_mask, enc_dec_mask, loss_mask \
= pad_and_convert_to_numpy(tokens, masked_positions,
masked_labels, pad_id, max_seq_length,
max_seq_length_dec, masked_spans,
bos_id, eos_id, sentinel_tokens)

train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'enc_mask': enc_mask,
'dec_mask': dec_mask,
'enc_dec_mask': enc_dec_mask,
}
if is_decoder_only(model_type):
# Concatenate to one sequence.
tokens_enc, tokens_dec_in, labels = merge_subsequent_masks(
tokens, masked_spans, bos_id, eos_id, sentinel_tokens)

# Move EOS tokens to end of sequence.
while tokens_enc[-1] == eos_id:
del tokens_enc[-1]
tokens_dec_in.append(eos_id)
labels.append(eos_id)

num_labels = len(labels)

# Move BOS token to start of sequence.
tokens_dec_in = tokens_dec_in[1:]
tokens = np.concatenate([
np.array([bos_id], dtype=np.int64),
tokens_enc,
np.array([sep_id], dtype=np.int64),
tokens_dec_in,
])
labels = np.concatenate([
tokens_enc,
np.array([sep_id], dtype=np.int64),
labels,
])

loss_mask = np.zeros(len(tokens), dtype=np.int64)
loss_mask[-num_labels:] = 1

dec_mask = make_history_mask(tokens)
if is_prefix_lm(model_type):
dec_mask[:-num_labels, :-num_labels] = 1

train_sample = {
'text': tokens,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'dec_mask': dec_mask,
}
else:
# Padding.
tokens_enc, tokens_dec_in, labels, enc_mask, \
dec_mask, enc_dec_mask, loss_mask \
= pad_and_convert_to_numpy(tokens, masked_positions,
masked_labels, pad_id, max_seq_length,
max_seq_length_dec, masked_spans,
bos_id, eos_id, sentinel_tokens)

train_sample = {
'text_enc': tokens_enc,
'text_dec': tokens_dec_in,
'labels': labels,
'loss_mask': loss_mask,
'truncated': int(truncated),
'enc_mask': enc_mask,
'dec_mask': dec_mask,
'enc_dec_mask': enc_dec_mask,
}
return train_sample
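The decoder-only branch above packs what the encoder-decoder path keeps as separate encoder input and decoder target into one sequence, [BOS] masked input [SEP] targets, with the loss restricted to the target portion and a causal attention mask that, for the non-causal ('ND') variant, is opened up over the prefix. A condensed sketch with toy token ids; the span pieces are hypothetical stand-ins for what `merge_subsequent_masks` would return, and the lower-triangular helper stands in for `make_history_mask`:

    import numpy as np

    bos_id, sep_id, eos_id = 1, 2, 3          # toy special tokens
    tokens_enc = [10, 500, 13, 501]           # input with spans replaced by sentinels 500/501
    tokens_dec_in = [500, 11, 12, 501, 14]    # sentinel-delimited targets (BOS already dropped)
    dec_labels = [500, 11, 12, 501, 14, eos_id]
    num_labels = len(dec_labels)

    # One packed sequence: prefix, separator, then the targets; the labels are
    # the same sequence shifted by one position.
    tokens = np.concatenate([[bos_id], tokens_enc, [sep_id], tokens_dec_in]).astype(np.int64)
    labels = np.concatenate([tokens_enc, [sep_id], dec_labels]).astype(np.int64)

    # Loss only on the target portion at the end of the packed sequence.
    loss_mask = np.zeros(len(tokens), dtype=np.int64)
    loss_mask[-num_labels:] = 1

    # Causal (history) mask; the non-causal decoder additionally lets every
    # prefix position attend to the whole prefix.
    dec_mask = np.tril(np.ones((len(tokens), len(tokens)), dtype=np.int64))
    prefix_lm = True                          # i.e. --ul2-model-type ND
    if prefix_lm:
        dec_mask[:-num_labels, :-num_labels] = 1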
5 changes: 5 additions & 0 deletions megatron/enums.py
@@ -33,3 +33,8 @@ class PositionEmbeddingType(enum.Enum):
rotary = 1
absolute = 2
alibi = 3

class UL2ModelType(enum.Enum):
ENCODER_DECODER = 'ED'
NON_CAUSAL_DECODER = 'ND'
CAUSAL_DECODER = 'CD'
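For completeness, the enum round-trips the CLI strings added in megatron/arguments.py, so a quick sanity check (run inside the repo) would be:

    from megatron.enums import UL2ModelType

    assert UL2ModelType('ED') is UL2ModelType.ENCODER_DECODER
    assert UL2ModelType('ND') is UL2ModelType.NON_CAUSAL_DECODER
    assert UL2ModelType('CD') is UL2ModelType.CAUSAL_DECODER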
