From a3d2e86a27c413011fc979059eb16788a3f079cf Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Mon, 29 Jan 2024 15:21:09 -0800 Subject: [PATCH 01/11] Support layer parallelism in transformer application --- applications/nlp/transformer/modeling.py | 2 + applications/nlp/transformer/parallelism.py | 73 +++++++++++++++++++++ applications/nlp/transformer/trainer.py | 4 +- python/lbann/contrib/args.py | 12 ++-- 4 files changed, 84 insertions(+), 7 deletions(-) diff --git a/applications/nlp/transformer/modeling.py b/applications/nlp/transformer/modeling.py index fc6ddf48bc8..95d46415144 100644 --- a/applications/nlp/transformer/modeling.py +++ b/applications/nlp/transformer/modeling.py @@ -124,6 +124,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace): ) parallelism.apply_fsdp_allweights(result, args) + parallelism.apply_layer_parallelism(transformer, result, args) return result @@ -227,6 +228,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int, ) parallelism.apply_fsdp_allweights(result, args) + parallelism.apply_layer_parallelism(transformer, result, args) return result diff --git a/applications/nlp/transformer/parallelism.py b/applications/nlp/transformer/parallelism.py index 7c56faf51eb..2f82990867d 100644 --- a/applications/nlp/transformer/parallelism.py +++ b/applications/nlp/transformer/parallelism.py @@ -7,6 +7,8 @@ import itertools import lbann import lbann.models.subgraph.transformer +import math +import re from typing import Any, Dict, Optional, List, Tuple, Union ############################################################################# @@ -195,6 +197,64 @@ def apply_subgraph_parallelism( return sgmodule, extra_model_kwargs +############################################################################# +# Layer parallelism + +lp_grids = None +def apply_layer_parallelism(module: lbann.models.Transformer, + model: lbann.Model, args: argparse.Namespace): + """ + Applies a model-parallel strategy on sequences of contiguous transformer + blocks, sometimes referred to as pipeline parallelism or layer parallelism. + + :param module: Transformer module to take as reference for block counts. + :param model: The model to modify. + :param args: Command-line arguments. + :param layers: If not None, a list of integers representing which blocks + to apply model parallelism to. + """ + if not args.layer_parallel: + return + + lp_count = args.lp_count + if args.lp_count == 0: + lp_count = args.nodes * args.procs_per_node + + blocks = len(module.encoder) + len(module.decoder) + + # Assign blocks to increasing grid tags + blocks_per_grid_tag = math.ceil(blocks / lp_count) + cur_grid_tag = 0 + + # Go over all layers in traversal order, applying grid tags in increasing order + last_block_id = -1 + block_id = -1 + total_block_id = 0 + for layer in model.layers: + if layer.name.startswith('transformer_decoder'): + block_id = int( + re.search(r'transformer_decoder(\d+)_', + layer.name).groups(1)[0]) + elif layer.name.startswith('transformer_encoder'): + block_id = int( + re.search(r'transformer_encoder(\d+)_', + layer.name).groups(1)[0]) + if last_block_id != block_id: + if total_block_id % blocks_per_grid_tag == 0: + cur_grid_tag += 1 + last_block_id = block_id + total_block_id += 1 + + # Apply layer parallelism + layer.grid_tag = { 'value': cur_grid_tag } + + global lp_grids + lp_grids = cur_grid_tag + +def get_layer_parallel_args() -> List[str]: + if lp_grids is not None: + return ['--num-subgrids', str(lp_grids)] + def add_transformer_parallelism_arguments(parser: argparse.Namespace, subgraph: bool = True): @@ -277,3 +337,16 @@ def add_transformer_parallelism_arguments(parser: argparse.Namespace, action='store_true', help='Apply Fully-Sharded Data-Parallelism (FSDP) and shard MLP weights' ) + + ####################################### + # Layer parallelism + parser.add_argument( + '--layer-parallel', + action='store_true', + help='Apply layer parallelism (also referred to as pipelining)') + parser.add_argument( + '--lp-count', + default=0, + type=int, + help='In layer parallelism, the number of portions to divide network to' + ' (Default: divide evenly between all ranks)') diff --git a/applications/nlp/transformer/trainer.py b/applications/nlp/transformer/trainer.py index d34a894ecd1..395512eee91 100644 --- a/applications/nlp/transformer/trainer.py +++ b/applications/nlp/transformer/trainer.py @@ -12,6 +12,7 @@ from lbann.launcher.batch_script import BatchScript import utils.paths +import parallelism def construct_training_task(model: lbann.Model, @@ -238,7 +239,8 @@ def make_batch_script(model: lbann.Model, script.add_parallel_command([ lbann.lbann_exe(), f'--prototext={protobuf_file}', - ] + lbann.contrib.args.get_profile_args(args)) + ] + (lbann.contrib.args.get_profile_args(args) + + parallelism.get_layer_parallel_args())) script.add_command('status=$?') script.add_command('echo "Finished training at $(date)"') script.add_command('exit ${status}') diff --git a/python/lbann/contrib/args.py b/python/lbann/contrib/args.py index d63acb0e8c6..a43de8a6f62 100644 --- a/python/lbann/contrib/args.py +++ b/python/lbann/contrib/args.py @@ -1,6 +1,6 @@ """Helper functions to add common command-line arguments.""" -from typing import Any +from typing import Any, List import argparse import shlex @@ -250,10 +250,10 @@ def add_profiling_arguments(parser: argparse.ArgumentParser) -> None: action='store_true', default=False, help='enable itemized memory usage analysis') - parser.add_argument('--profile-init', + parser.add_argument('--profile-noinit', action='store_true', default=False, - help='enable profiling initialization') + help='disable profiling initialization') parser.add_argument('--caliper', action='store_true', default=False, @@ -285,7 +285,7 @@ def create_profile_callbacks(args: argparse.Namespace) -> Any: """ try: profile = args.profile - profile_init = not args.profile_init + profile_noinit = args.profile_noinit memprofile = args.memory_profile memprof_verbose = args.memory_profile_verbose except AttributeError: @@ -294,7 +294,7 @@ def create_profile_callbacks(args: argparse.Namespace) -> Any: result = [] if profile: - result.append(lbann.CallbackProfiler(skip_init=profile_init)) + result.append(lbann.CallbackProfiler(skip_init=profile_noinit)) if memprofile: result.append(lbann.CallbackMemoryProfiler( detailed_first_step=memprof_verbose)) @@ -302,7 +302,7 @@ def create_profile_callbacks(args: argparse.Namespace) -> Any: return result -def get_profile_args(args: argparse.Namespace) -> list[str]: +def get_profile_args(args: argparse.Namespace) -> List[str]: """Get LBANN command-line arguments for profiling. The parsed arguments must be generated by an From f1c3bf8bc257f3a8b194dca042574ea52a12fb57 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 30 Jan 2024 16:47:48 -0800 Subject: [PATCH 02/11] Fix bug in the original Pile and keep epilogue layers on grid tag 0 --- applications/nlp/transformer/datasets/thepile.py | 11 ++++++++++- applications/nlp/transformer/parallelism.py | 7 ++++++- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/applications/nlp/transformer/datasets/thepile.py b/applications/nlp/transformer/datasets/thepile.py index ee82f01a910..8678e061afd 100644 --- a/applications/nlp/transformer/datasets/thepile.py +++ b/applications/nlp/transformer/datasets/thepile.py @@ -91,7 +91,7 @@ def get_train_sample(index): def get_val_sample(index): """Token indices for a data sample from the validation set.""" - text = dataset_train[index]['text'] + text = dataset_val[index]['text'] tokenized = tokenize(text) # Trim long sequences, left-pad short sequences @@ -120,3 +120,12 @@ def sample_dims(): def vocab_size(): return tokenizer.get_vocab_size() + + +if __name__ == '__main__': + print('Training samples:', num_train_samples()) + print('Validation samples:', num_val_samples()) + print('Training sample 101:') + print(tokenizer.decode(get_train_sample(101))) + print('Validation sample 233:') + print(tokenizer.decode(get_val_sample(233))) diff --git a/applications/nlp/transformer/parallelism.py b/applications/nlp/transformer/parallelism.py index 2f82990867d..bd511e9ccc3 100644 --- a/applications/nlp/transformer/parallelism.py +++ b/applications/nlp/transformer/parallelism.py @@ -230,7 +230,7 @@ def apply_layer_parallelism(module: lbann.models.Transformer, last_block_id = -1 block_id = -1 total_block_id = 0 - for layer in model.layers: + for i, layer in enumerate(model.layers): if layer.name.startswith('transformer_decoder'): block_id = int( re.search(r'transformer_decoder(\d+)_', @@ -248,6 +248,11 @@ def apply_layer_parallelism(module: lbann.models.Transformer, # Apply layer parallelism layer.grid_tag = { 'value': cur_grid_tag } + # ...everywhere but the epilogue layers + if i >= len(model.layers) - 8: + layer.grid_tag = { 'value': 0 } + print(layer.grid_tag['value'], '-', layer.name) + global lp_grids lp_grids = cur_grid_tag From e3a5b863efd17d7df3aa502f391c5b08c640afee Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 2 Feb 2024 16:36:49 -0800 Subject: [PATCH 03/11] Support variable-length pretokenized dataset --- .../varlen/pretokenize-validation.py | 34 ++++++ .../pretokenize/varlen/pretokenize.py | 78 +++++++++++++ .../datasets/thepile_pretokenized.py | 6 +- .../datasets/thepile_pretokenized_varlen.py | 105 ++++++++++++++++++ applications/nlp/transformer/parallelism.py | 3 + 5 files changed, 224 insertions(+), 2 deletions(-) create mode 100644 applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize-validation.py create mode 100644 applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize.py create mode 100644 applications/nlp/transformer/datasets/thepile_pretokenized_varlen.py diff --git a/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize-validation.py b/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize-validation.py new file mode 100644 index 00000000000..dc3dad824a0 --- /dev/null +++ b/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize-validation.py @@ -0,0 +1,34 @@ +from tqdm import trange +from multiprocessing import Pool +import numpy as np +import pickle + + +class Processor: + + def __init__(self, total_threads: int): + self.threads = total_threads + + def __call__(self, tid: int): + import thepile as dataset + num_samples = dataset.num_val_samples() + filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/val.bin' + len_filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/val-seqlen.bin' + + with open(filename, 'ab') as fp: + with open(len_filename, 'ab') as slfp: + for i in trange(num_samples): + text = dataset.dataset_val[i]['text'] + tokenized = dataset.tokenize(text) + sample = np.array(tokenized, dtype=np.uint16) + sample_len = np.array([len(sample)], dtype=np.uint32) + sample.tofile(fp) + sample_len.tofile(slfp) + + print('Done') + + +if __name__ == '__main__': + threads = 1 + with Pool(threads) as pool: + pool.map(Processor(threads), range(threads)) diff --git a/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize.py b/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize.py new file mode 100644 index 00000000000..90a811a8abb --- /dev/null +++ b/applications/nlp/transformer/datasets/pretokenize/varlen/pretokenize.py @@ -0,0 +1,78 @@ +from tqdm import trange +from multiprocessing import Pool +import numpy as np +import os +import argparse +from pathlib import Path + + +class Processor: + + def __init__(self, total_threads: int): + self.threads = total_threads + + def __call__(self, tid: int): + import thepile as dataset + num_samples = dataset.num_train_samples() + np.random.seed(20231023) + indices = np.random.permutation(num_samples) + local_samples = num_samples // self.threads + offset = tid * local_samples + # Add remainder + if tid == self.threads - 1: + local_samples += num_samples % self.threads + section = indices[offset:offset + local_samples] + filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/train-pretokenized-{tid:02d}-of-{self.threads}.bin' + len_filename = f'/p/vast1/data/datasets/the-pile-huggingface/pretokenized-varlen/train-seqlen-{tid:02d}-of-{self.threads}.bin' + + # Create file + if not os.path.isfile(filename): + Path(filename).touch() + if not os.path.isfile(len_filename): + Path(len_filename).touch() + + sz = os.path.getsize(len_filename) + assert sz % 4 == 0 + sequences_processed = sz // 4 + print(tid, ': Size in bytes:', sz, '. Sequences processed:', + sequences_processed) + + with open(filename, 'ab') as fp: + with open(len_filename, 'ab') as slfp: + for i in trange(sequences_processed, + section.shape[0], + desc=f'Thread {tid}'): + text = dataset.dataset_train[int(section[i])]['text'] + sample = dataset.tokenize(text) + sample = np.array(sample, dtype=np.uint16) + sample.tofile(fp) + sample_len = np.array([len(sample)], dtype=np.uint32) + sample_len.tofile(slfp) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + + parser.add_argument('-j', + action='store', + default=0, + type=int, + help='Threads (default 0 = number of cores)') + parser.add_argument('-t', + action='store', + default=0, + type=int, + help='Total Chunks (default 0 = number of threads)') + parser.add_argument('-o', + action='store', + default=0, + type=int, + help='Chunk offset (default 0)') + args = parser.parse_args() + + threads = args.j or os.cpu_count() + total_chunks = args.t or threads + offset = args.o + assert offset + threads <= total_chunks + with Pool(threads) as pool: + pool.map(Processor(total_chunks), range(offset, offset + threads)) diff --git a/applications/nlp/transformer/datasets/thepile_pretokenized.py b/applications/nlp/transformer/datasets/thepile_pretokenized.py index 89c490e2db0..65b7c7e80ad 100644 --- a/applications/nlp/transformer/datasets/thepile_pretokenized.py +++ b/applications/nlp/transformer/datasets/thepile_pretokenized.py @@ -1,5 +1,5 @@ """ -The Pile dataset, stored as pre-tokenized binary files for optimized processing. +The Pile dataset, stored as pre-tokenized, pre-packed binary files for optimized processing. """ import os import os.path @@ -10,7 +10,9 @@ # Options # ---------------------------------------------- -sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512')) +# Sequence length is hardcoded to 512 in the pre-packed binary dataset. +# To use other sequence lengths, see ``thepile_pretokenized_varlen.py`` +sequence_length = 512 # ---------------------------------------------- # Setup diff --git a/applications/nlp/transformer/datasets/thepile_pretokenized_varlen.py b/applications/nlp/transformer/datasets/thepile_pretokenized_varlen.py new file mode 100644 index 00000000000..71e26e0c117 --- /dev/null +++ b/applications/nlp/transformer/datasets/thepile_pretokenized_varlen.py @@ -0,0 +1,105 @@ +""" +The Pile dataset, stored as pre-tokenized binary files for optimized processing. +""" +import os +import os.path + +import numpy as np +# ---------------------------------------------- +# Options +# ---------------------------------------------- + +sequence_length = int(os.getenv('THE_PILE_SEQUENCE_LENGTH', default='512')) + +# ---------------------------------------------- +# Setup +# ---------------------------------------------- + +# Load the datasets +data_dir = os.getenv('THE_PILE_DATA_DIR', + '/p/vast1/data/datasets/the-pile-pretokenized') +dataset_train = np.memmap(os.path.join(data_dir, 'train.bin'), + dtype=np.uint16, + mode='r') +sample_lengths_train = np.fromfile(os.path.join(data_dir, 'train-seqlen.bin'), + dtype=np.uint32).astype(np.uint64) +sample_offsets_train = np.zeros_like(sample_lengths_train) +sample_offsets_train[1:] = np.cumsum(sample_lengths_train)[:-1] +dataset_val = np.memmap(os.path.join(data_dir, 'val.bin'), + dtype=np.uint16, + mode='r') +sample_lengths_val = np.fromfile(os.path.join(data_dir, 'val-seqlen.bin'), + dtype=np.uint32).astype(np.uint64) +sample_offsets_val = np.zeros_like(sample_lengths_val) +sample_offsets_val[1:] = np.cumsum(sample_lengths_val)[:-1] + +# Uses the definition from the GPT-NeoX-20B tokenizer +pad_index = 1 # '<|padding|>' +_vocab_size = 50277 + +# ---------------------------------------------- +# Sample access functions +# ---------------------------------------------- + + +def trim_and_pad(sample, random: bool): + # Trim long sequences + if len(sample) > sequence_length: + if random: + pos = np.random.rand() + offset = (len(sample) - sequence_length + 1) * pos + offset = int(np.floor(offset)) + sample = sample[offset:offset + sequence_length] + else: + sample = sample[0:sequence_length] + + # Left-pad short sequences + if len(sample) < sequence_length: + sample_pad = np.full(sequence_length, pad_index, dtype=np.int32) + if len(sample) > 0: + sample_pad[-len(sample):] = sample + return sample_pad + + return sample + + +def get_train_sample(index: int): + sample = np.copy( + dataset_train[sample_offsets_train[index]:sample_offsets_train[index] + + sample_lengths_train[index]]).astype(np.int32) + return trim_and_pad(sample, True) + + +def get_val_sample(index): + sample = np.copy( + dataset_val[sample_offsets_val[index]:sample_offsets_val[index] + + sample_lengths_val[index]]).astype(np.int32) + return trim_and_pad(sample, False) + + +def num_train_samples(): + return sample_lengths_train.shape[0] + + +def num_val_samples(): + return sample_lengths_val.shape[0] + + +def sample_dims(): + return (sequence_length, ) + + +def vocab_size(): + return _vocab_size + + +if __name__ == '__main__': + print('Training samples:', num_train_samples()) + print('Validation samples:', num_val_samples()) + from tokenizers import Tokenizer + tokenizer = Tokenizer.from_file( + os.path.join(data_dir, '20B_tokenizer.json')) + print('Training sample 101:') + print(tokenizer.decode(get_train_sample(101))) + print('Validation sample 233:') + print(tokenizer.decode(get_val_sample(233))) diff --git a/applications/nlp/transformer/parallelism.py b/applications/nlp/transformer/parallelism.py index bd511e9ccc3..d1f22c89801 100644 --- a/applications/nlp/transformer/parallelism.py +++ b/applications/nlp/transformer/parallelism.py @@ -256,9 +256,12 @@ def apply_layer_parallelism(module: lbann.models.Transformer, global lp_grids lp_grids = cur_grid_tag + def get_layer_parallel_args() -> List[str]: if lp_grids is not None: return ['--num-subgrids', str(lp_grids)] + return [] + def add_transformer_parallelism_arguments(parser: argparse.Namespace, subgraph: bool = True): From 86ba8715ccce41e0958f06adba53425a210f17d0 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 2 Feb 2024 16:38:48 -0800 Subject: [PATCH 04/11] Support additional layer arguments in transformer --- applications/nlp/transformer/modeling.py | 6 +- applications/nlp/transformer/parallelism.py | 99 +++++++++----- python/lbann/models/transformer.py | 86 ++++++++----- python/lbann/modules/transformer/attention.py | 121 +++++++++++++----- python/lbann/modules/transformer/encoding.py | 76 +++++++---- 5 files changed, 266 insertions(+), 122 deletions(-) diff --git a/applications/nlp/transformer/modeling.py b/applications/nlp/transformer/modeling.py index 95d46415144..013a8e180b9 100644 --- a/applications/nlp/transformer/modeling.py +++ b/applications/nlp/transformer/modeling.py @@ -87,6 +87,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace): transformer, args) parallelism.apply_ffn_model_parallelism(transformer, args) parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args) + parallelism.apply_layer_parallelism(transformer, args) # Run through transformer result = transformer(encoder_input, decoder_input, sequence_length - 1) @@ -124,7 +125,7 @@ def create_encoder_decoder_transformer(dataset, args: argparse.Namespace): ) parallelism.apply_fsdp_allweights(result, args) - parallelism.apply_layer_parallelism(transformer, result, args) + parallelism.apply_layer_parallelism_postamble(result, args) return result @@ -187,6 +188,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int, transformer, args) parallelism.apply_ffn_model_parallelism(transformer, args) parallelism.apply_fsdp_mlp(transformer, [embedding_weights], args) + parallelism.apply_layer_parallelism(transformer, args) # Run through transformer with the same sequence result = transformer(decoder_input, decoder_input, sequence_length) @@ -228,7 +230,7 @@ def create_causal_lm_decoder_transformer(dataset, embed_dim: int, ) parallelism.apply_fsdp_allweights(result, args) - parallelism.apply_layer_parallelism(transformer, result, args) + parallelism.apply_layer_parallelism_postamble(result, args) return result diff --git a/applications/nlp/transformer/parallelism.py b/applications/nlp/transformer/parallelism.py index d1f22c89801..0e74310c919 100644 --- a/applications/nlp/transformer/parallelism.py +++ b/applications/nlp/transformer/parallelism.py @@ -4,6 +4,7 @@ strategies found in this file. """ import argparse +import collections import itertools import lbann import lbann.models.subgraph.transformer @@ -201,17 +202,16 @@ def apply_subgraph_parallelism( # Layer parallelism lp_grids = None + + def apply_layer_parallelism(module: lbann.models.Transformer, - model: lbann.Model, args: argparse.Namespace): + args: argparse.Namespace): """ Applies a model-parallel strategy on sequences of contiguous transformer blocks, sometimes referred to as pipeline parallelism or layer parallelism. - :param module: Transformer module to take as reference for block counts. - :param model: The model to modify. + :param module: Transformer module to modify. :param args: Command-line arguments. - :param layers: If not None, a list of integers representing which blocks - to apply model parallelism to. """ if not args.layer_parallel: return @@ -226,43 +226,80 @@ def apply_layer_parallelism(module: lbann.models.Transformer, blocks_per_grid_tag = math.ceil(blocks / lp_count) cur_grid_tag = 0 - # Go over all layers in traversal order, applying grid tags in increasing order - last_block_id = -1 - block_id = -1 - total_block_id = 0 - for i, layer in enumerate(model.layers): - if layer.name.startswith('transformer_decoder'): - block_id = int( - re.search(r'transformer_decoder(\d+)_', - layer.name).groups(1)[0]) - elif layer.name.startswith('transformer_encoder'): - block_id = int( - re.search(r'transformer_encoder(\d+)_', - layer.name).groups(1)[0]) - if last_block_id != block_id: - if total_block_id % blocks_per_grid_tag == 0: - cur_grid_tag += 1 - last_block_id = block_id - total_block_id += 1 - - # Apply layer parallelism - layer.grid_tag = { 'value': cur_grid_tag } - - # ...everywhere but the epilogue layers - if i >= len(model.layers) - 8: - layer.grid_tag = { 'value': 0 } - print(layer.grid_tag['value'], '-', layer.name) + # Go over all blocks, applying grid tags in increasing order + for i, block in enumerate(itertools.chain(module.encoder, module.decoder)): + cur_grid_tag = max(cur_grid_tag, (i // blocks_per_grid_tag) + 1) + block.extra_layer_args['grid_tag'] = cur_grid_tag global lp_grids lp_grids = cur_grid_tag +def _get_grid_tag(tag: Union[int, Dict[str, int]]): + if isinstance(tag, dict): + return tag.get('value', 0) + return tag + + +def apply_layer_parallelism_postamble(model: lbann.Model, + args: argparse.Namespace): + """ + Applies post-model creation optimizations of the layer-parallel strategy + (see ``apply_layer_parallelism``). + + :param model: LBANN Model to modify. + :param args: Command-line arguments. + """ + if not args.layer_parallel: + return + + # Loop over all layers that have multiple outgoing cross-grid edges + layers_to_insert = [] + for i, layer in enumerate(model.layers): + if len(layer.children) == 1: + continue + tag = _get_grid_tag(layer.grid_tag) + unique_grids = collections.defaultdict(list) + new_children = [] + for child in layer.children: + ctag = _get_grid_tag(child.grid_tag) + if ctag != tag: + unique_grids[ctag].append(child) + new_children.append(None) + else: + new_children.append(child) + + # Inject interim layers for each grid and reconnect + for dst_grid, children in unique_grids.items(): + interim = lbann.Identity(layer, grid_tag=dst_grid) + layers_to_insert.append((i, interim)) + + # Reconnect parents + for child in children: + pind = child.parents.index(layer) + child.parents[pind] = interim + cind = layer.children.index(child) + new_children[cind] = interim + + # Reconnect children + if unique_grids: + layer.children = new_children + + # Add identity layers to the traversed graph right after the source layer + # was computed + for i, l in reversed(layers_to_insert): + model.layers.insert(i, l) + + def get_layer_parallel_args() -> List[str]: if lp_grids is not None: return ['--num-subgrids', str(lp_grids)] return [] +############################################################################# + + def add_transformer_parallelism_arguments(parser: argparse.Namespace, subgraph: bool = True): diff --git a/python/lbann/models/transformer.py b/python/lbann/models/transformer.py index 95a10ffde35..6c7bf796dce 100644 --- a/python/lbann/models/transformer.py +++ b/python/lbann/models/transformer.py @@ -40,30 +40,33 @@ def __init__(self, normalized_shape, name=None, builtin=True): name=f'{self.name}_bias', ) - def forward(self, x): + def forward(self, x, **extra_kwargs): if self.builtin: return lbann.LayerNorm(x, scale=True, bias=True, start_dim=-1, name=self.name, - weights=[self.weight, self.bias]) + weights=[self.weight, self.bias], + **extra_kwargs) # Normalization - x = lbann.InstanceNorm(x) + x = lbann.InstanceNorm(x, **extra_kwargs) # Affine transform s = lbann.WeightsLayer( weights=self.weight, dims=[1] + list(make_iterable(self.normalized_shape)), + **extra_kwargs, ) - s = lbann.Tessellate(s, hint_layer=x) + s = lbann.Tessellate(s, hint_layer=x, **extra_kwargs) b = lbann.WeightsLayer( weights=self.bias, dims=[1] + list(make_iterable(self.normalized_shape)), + **extra_kwargs, ) - b = lbann.Tessellate(b, hint_layer=x) - x = lbann.Add(lbann.Multiply(s, x), b) + b = lbann.Tessellate(b, hint_layer=x, **extra_kwargs) + x = lbann.Add(lbann.Multiply(s, x, **extra_kwargs), b, **extra_kwargs) return x @@ -124,6 +127,7 @@ def __init__( self.pre_layernorm = pre_layernorm self.activation = activation self.extra_ffn_args = {} + self.extra_layer_args = {} # Module name self.name = name @@ -172,26 +176,27 @@ def forward(self, x, mask=None): name = f'{self.name}_instance{self.instance}' if self.pre_layernorm: - y = self.norm1(x) + y = self.norm1(x, **self.extra_layer_args) else: y = x # Self-attention with residual connection - y = self.attention(y, y, y, mask=mask) + y = self.attention(y, y, y, mask=mask, **self.extra_layer_args) if self.dropout_prob > 0: y = lbann.Dropout( y, keep_prob=1 - self.dropout_prob, name=f'{name}_drop1', + **self.extra_layer_args, ) - z = lbann.Sum(x, y, name=f'{name}_sum1') + z = lbann.Sum(x, y, name=f'{name}_sum1', **self.extra_layer_args) if not self.pre_layernorm: - z = self.norm1(z) + z = self.norm1(z, **self.extra_layer_args) x = z # Feedforward network with residual connection if self.pre_layernorm: - y = self.norm2(z) + y = self.norm2(z, **self.extra_layer_args) else: y = x @@ -200,14 +205,19 @@ def forward(self, x, mask=None): weights=self.fc1_weights, output_channel_dims=[self.feedforward_dim], name=f'{name}_fc1', + **self.extra_layer_args, **self.extra_ffn_args, ) - y = self.activation(y, name=f'{name}_ffn_act', **self.extra_ffn_args) + y = self.activation(y, + name=f'{name}_ffn_act', + **self.extra_layer_args, + **self.extra_ffn_args) if self.dropout_prob > 0: y = lbann.Dropout( y, keep_prob=1 - self.dropout_prob, name=f'{name}_drop2', + **self.extra_layer_args, **self.extra_ffn_args, ) y = lbann.ChannelwiseFullyConnected( @@ -215,6 +225,7 @@ def forward(self, x, mask=None): weights=self.fc2_weights, output_channel_dims=[self.embed_dim], name=f'{name}_fc2', + **self.extra_layer_args, **self.extra_ffn_args, ) if self.dropout_prob > 0: @@ -222,11 +233,12 @@ def forward(self, x, mask=None): y, keep_prob=1 - self.dropout_prob, name=f'{name}_drop3', + **self.extra_layer_args, **self.extra_ffn_args, ) - z = lbann.Sum(x, y, name=f'{name}_sum2') + z = lbann.Sum(x, y, name=f'{name}_sum2', **self.extra_layer_args) if not self.pre_layernorm: - z = self.norm2(z) + z = self.norm2(z, **self.extra_layer_args) return z @@ -288,6 +300,7 @@ def __init__( self.pre_layernorm = pre_layernorm self.activation = activation self.extra_ffn_args = {} + self.extra_layer_args = {} # Module name self.name = name @@ -350,22 +363,23 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None): name = f'{self.name}_instance{self.instance}' if self.pre_layernorm: - y = self.norm1(x) + y = self.norm1(x, **self.extra_layer_args) else: y = x # Self-attention with residual connection - y = self.attention1(y, y, y, mask=tgt_mask) + y = self.attention1(y, y, y, mask=tgt_mask, **self.extra_layer_args) if self.dropout_prob > 0: y = lbann.Dropout( y, keep_prob=1 - self.dropout_prob, name=f'{name}_drop1', + **self.extra_layer_args, ) - z = lbann.Sum(x, y, name=f'{name}_sum1') + z = lbann.Sum(x, y, name=f'{name}_sum1', **self.extra_layer_args) if not self.pre_layernorm: - z = self.norm1(z) + z = self.norm1(z, **self.extra_layer_args) x = z @@ -373,27 +387,30 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None): if memory is not None: # Attention on encoder output with residual connection if self.pre_layernorm: - y = self.norm2(x) + y = self.norm2(x, **self.extra_layer_args) else: y = x - y = self.attention2(y, memory, memory, mask=src_mask) + y = self.attention2(y, + memory, + memory, + mask=src_mask, + **self.extra_layer_args) if self.dropout_prob > 0: - y = lbann.Dropout( - y, - keep_prob=1 - self.dropout_prob, - name=f'{name}_drop2', - ) - z = lbann.Sum(x, y, name=f'{name}_sum2') + y = lbann.Dropout(y, + keep_prob=1 - self.dropout_prob, + name=f'{name}_drop2', + **self.extra_layer_args) + z = lbann.Sum(x, y, name=f'{name}_sum2', **self.extra_layer_args) if not self.pre_layernorm: - z = self.norm2(z) + z = self.norm2(z, **self.extra_layer_args) x = z # Feedforward network with residual connection if self.pre_layernorm: - y = self.norm3(x) + y = self.norm3(x, **self.extra_layer_args) else: y = x @@ -402,14 +419,19 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None): weights=self.fc1_weights, output_channel_dims=[self.feedforward_dim], name=f'{name}_fc1', + **self.extra_layer_args, **self.extra_ffn_args, ) - y = self.activation(y, name=f'{name}_ffn_act', **self.extra_ffn_args) + y = self.activation(y, + name=f'{name}_ffn_act', + **self.extra_layer_args, + **self.extra_ffn_args) if self.dropout_prob > 0: y = lbann.Dropout( y, keep_prob=1 - self.dropout_prob, name=f'{name}_drop3', + **self.extra_layer_args, **self.extra_ffn_args, ) y = lbann.ChannelwiseFullyConnected( @@ -417,6 +439,7 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None): weights=self.fc2_weights, output_channel_dims=[self.embed_dim], name=f'{name}_fc2', + **self.extra_layer_args, **self.extra_ffn_args, ) if self.dropout_prob > 0: @@ -424,12 +447,13 @@ def forward(self, x, memory, src_mask=None, tgt_mask=None): y, keep_prob=1 - self.dropout_prob, name=f'{name}_drop4', + **self.extra_layer_args, **self.extra_ffn_args, ) - z = lbann.Sum(x, y, name=f'{name}_sum3') + z = lbann.Sum(x, y, name=f'{name}_sum3', **self.extra_layer_args) if not self.pre_layernorm: - z = self.norm3(z) + z = self.norm3(z, **self.extra_layer_args) return z diff --git a/python/lbann/modules/transformer/attention.py b/python/lbann/modules/transformer/attention.py index cac70833325..b721d9a2e80 100644 --- a/python/lbann/modules/transformer/attention.py +++ b/python/lbann/modules/transformer/attention.py @@ -113,7 +113,13 @@ def __init__(self, name=f'{self.name}_output_bias'), ] - def forward(self, queries, keys, values, mask=None, seqlen=None): + def forward(self, + queries, + keys, + values, + mask=None, + seqlen=None, + **extra_kwargs): """Apply multi-head attention. The input and output tensors are interpreted as sequences of @@ -147,7 +153,8 @@ def forward(self, queries, keys, values, mask=None, seqlen=None): output_channel_dims=[self.embed_dim * 3], name=f'{name}_qkv_fc', bias=True, - transpose=False) + transpose=False, + **extra_kwargs) # Unstack qkv_slice = lbann.Slice(qkv_fc, @@ -155,10 +162,11 @@ def forward(self, queries, keys, values, mask=None, seqlen=None): slice_points=[ 0, self.embed_dim, 2 * self.embed_dim, 3 * self.embed_dim - ]) - queries_fc = lbann.Identity(qkv_slice) - keys_fc = lbann.Identity(qkv_slice) - values_fc = lbann.Identity(qkv_slice) + ], + **extra_kwargs) + queries_fc = lbann.Identity(qkv_slice, **extra_kwargs) + keys_fc = lbann.Identity(qkv_slice, **extra_kwargs) + values_fc = lbann.Identity(qkv_slice, **extra_kwargs) else: # Otherwise, apply fully-connected layers to input sequences separately queries_fc = lbann.ChannelwiseFullyConnected( @@ -166,47 +174,52 @@ def forward(self, queries, keys, values, mask=None, seqlen=None): weights=self.query_weights, output_channel_dims=[self.embed_dim], name=f'{name}_queries_fc', + **extra_kwargs, ) keys_fc = lbann.ChannelwiseFullyConnected( keys, weights=self.key_weights, output_channel_dims=[self.embed_dim], name=f'{name}_keys_fc', + **extra_kwargs, ) values_fc = lbann.ChannelwiseFullyConnected( values, weights=self.value_weights, output_channel_dims=[self.embed_dim], name=f'{name}_values_fc', + **extra_kwargs, ) if self.positional_encoding is not None: queries_fc, keys_fc, values_fc = self.positional_encoding.apply_layer( - queries_fc, keys_fc, values_fc, seqlen) + queries_fc, keys_fc, values_fc, seqlen, **extra_kwargs) if self.separate_heads: attentions = self.dot_product_attn_separate_heads( - name, queries_fc, keys_fc, values_fc, mask) + name, queries_fc, keys_fc, values_fc, mask, **extra_kwargs) else: attentions = self.dot_product_attn_batched(name, queries_fc, keys_fc, values_fc, - mask) + mask, **extra_kwargs) outputs_fc = lbann.ChannelwiseFullyConnected( attentions, weights=self.output_weights, output_channel_dims=[self.embed_dim], name=f'{name}', + **extra_kwargs, ) return outputs_fc def dot_product_attn_batched(self, name, queries_fc, keys_fc, values_fc, - mask): + mask, **extra_kwargs): head_name = f'{name}_all_heads' queries_fc = lbann.Scale( queries_fc, constant=1 / math.sqrt(self.head_dim), name=f'{head_name}_scale', + **extra_kwargs, ) # Dimension key: @@ -216,15 +229,24 @@ def dot_product_attn_batched(self, name, queries_fc, keys_fc, values_fc, # * P = Head size # SxE -> HxPxS - q_headsfirst = lbann.TensorPermute(queries_fc, axes=(1, 0)) + q_headsfirst = lbann.TensorPermute(queries_fc, + axes=(1, 0), + **extra_kwargs) q_headsfirst = lbann.Reshape(q_headsfirst, - dims=(self.num_heads, self.head_dim, -1)) - k_headsfirst = lbann.TensorPermute(keys_fc, axes=(1, 0)) + dims=(self.num_heads, self.head_dim, -1), + **extra_kwargs) + k_headsfirst = lbann.TensorPermute(keys_fc, + axes=(1, 0), + **extra_kwargs) k_headsfirst = lbann.Reshape(k_headsfirst, - dims=(self.num_heads, self.head_dim, -1)) - v_headsfirst = lbann.TensorPermute(values_fc, axes=(1, 0)) + dims=(self.num_heads, self.head_dim, -1), + **extra_kwargs) + v_headsfirst = lbann.TensorPermute(values_fc, + axes=(1, 0), + **extra_kwargs) v_headsfirst = lbann.Reshape(v_headsfirst, - dims=(self.num_heads, self.head_dim, -1)) + dims=(self.num_heads, self.head_dim, -1), + **extra_kwargs) # HxPxS -> HxSxS y = lbann.MatMul( @@ -233,24 +255,30 @@ def dot_product_attn_batched(self, name, queries_fc, keys_fc, values_fc, transpose_a=True, transpose_b=False, name=f'{head_name}_matmul', + **extra_kwargs, ) if mask: - y = lbann.Add(y, mask, name=f'{head_name}_mask') + y = lbann.Add(y, mask, name=f'{head_name}_mask', **extra_kwargs) if self.bias: - y = lbann.Add(y, self.bias, name=f'{head_name}_attnbias') + y = lbann.Add(y, + self.bias, + name=f'{head_name}_attnbias', + **extra_kwargs) y = lbann.ChannelwiseSoftmax(y, dim=-1, single_dim_mode=True, - name=f'{head_name}_softmax') + name=f'{head_name}_softmax', + **extra_kwargs) if self.dropout > 0: y = lbann.Dropout( y, keep_prob=1 - self.dropout, name=f'{head_name}_drop', + **extra_kwargs, ) # Attention output as batched matrix multiplication @@ -258,11 +286,16 @@ def dot_product_attn_batched(self, name, queries_fc, keys_fc, values_fc, attentions = lbann.MatMul(y, v_headsfirst, transpose_b=True, - name=head_name) + name=head_name, + **extra_kwargs) # HxSxP -> SxE - attentions = lbann.TensorPermute(attentions, axes=(1, 0, 2)) - attentions = lbann.Reshape(attentions, dims=(-1, self.embed_dim)) + attentions = lbann.TensorPermute(attentions, + axes=(1, 0, 2), + **extra_kwargs) + attentions = lbann.Reshape(attentions, + dims=(-1, self.embed_dim), + **extra_kwargs) return attentions def _get_subgraph(self, tag_id: int = 0) -> Dict[str, int]: @@ -279,7 +312,7 @@ def _get_subgraph(self, tag_id: int = 0) -> Dict[str, int]: return dict(grid_tag=tag_id) def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc, - values_fc, mask): + values_fc, mask, **extra_kwargs): # Slice embedding vectors for each head slice_points = [self.head_dim * i for i in range(self.num_heads + 1)] queries_slice = lbann.Slice( @@ -288,6 +321,7 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc, slice_points=slice_points, name=f'{name}_queries_slice', parallel_strategy=self._get_subgraph(), + **extra_kwargs, ) keys_slice = lbann.Slice( keys_fc, @@ -295,6 +329,7 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc, slice_points=slice_points, name=f'{name}_keys_slice', parallel_strategy=self._get_subgraph(), + **extra_kwargs, ) values_slice = lbann.Slice( values_fc, @@ -302,6 +337,7 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc, slice_points=slice_points, name=f'{name}_values_slice', parallel_strategy=self._get_subgraph(), + **extra_kwargs, ) if self.subgraph_branches > 0 and mask is not None: @@ -321,11 +357,14 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc, # Attention inputs q = lbann.Identity(queries_slice, - parallel_strategy=self._get_subgraph(tag)) + parallel_strategy=self._get_subgraph(tag), + **extra_kwargs) k = lbann.Identity(keys_slice, - parallel_strategy=self._get_subgraph(tag)) + parallel_strategy=self._get_subgraph(tag), + **extra_kwargs) v = lbann.Identity(values_slice, - parallel_strategy=self._get_subgraph(tag)) + parallel_strategy=self._get_subgraph(tag), + **extra_kwargs) # Multiply queries and keys # Note: num_queries x num_keys @@ -334,36 +373,52 @@ def dot_product_attn_separate_heads(self, name, queries_fc, keys_fc, k, transpose_b=True, name=f'{head_name}_matmul', + **extra_kwargs, ) y = lbann.Scale(y, constant=1 / math.sqrt(self.head_dim), - name=f'{head_name}_scale') + name=f'{head_name}_scale', + **extra_kwargs) if mask: if self.subgraph_branches > 0: - y = lbann.Add(y, mask[tag - 1], name=f'{head_name}_mask') + y = lbann.Add(y, + mask[tag - 1], + name=f'{head_name}_mask', + **extra_kwargs) else: - y = lbann.Add(y, mask, name=f'{head_name}_mask') + y = lbann.Add(y, + mask, + name=f'{head_name}_mask', + **extra_kwargs) if self.bias: - y = lbann.Add(y, self.bias, name=f'{head_name}_attnbias') + y = lbann.Add(y, + self.bias, + name=f'{head_name}_attnbias', + **extra_kwargs) - y = lbann.ChannelwiseSoftmax(y, name=f'{head_name}_softmax') + y = lbann.ChannelwiseSoftmax(y, + name=f'{head_name}_softmax', + **extra_kwargs) if self.dropout > 0: y = lbann.Dropout( y, keep_prob=1 - self.dropout, name=f'{head_name}_drop', + **extra_kwargs, ) # Attention output # Note: num_queries x head_dim - attentions.append(lbann.MatMul(y, v, name=head_name)) + attentions.append( + lbann.MatMul(y, v, name=head_name, **extra_kwargs)) # Concatenate heads and apply fully-connected layer attentions = lbann.Concatenation( attentions, axis=1, name=f'{name}_heads_concat', - parallel_strategy=self._get_subgraph()) + parallel_strategy=self._get_subgraph(), + **extra_kwargs) return attentions diff --git a/python/lbann/modules/transformer/encoding.py b/python/lbann/modules/transformer/encoding.py index cd56915002f..653ce6d6f9d 100644 --- a/python/lbann/modules/transformer/encoding.py +++ b/python/lbann/modules/transformer/encoding.py @@ -15,20 +15,22 @@ class SequenceEncoding: the layer type and index. """ - def apply_input(self, x: lbann.Layer, length: int) -> lbann.Layer: + def apply_input(self, x: lbann.Layer, length: int, + **extra_kwargs) -> lbann.Layer: """ Applies sequence encoding on the input of a transformer, immediately after token embedding. :param x: The output of the embedded sequence minibatch. :param length: Sequence length. + :param extra_kwargs: Additional arguments to pass to each internal Layer. :return: Encoded input. """ return x # Do nothing def apply_layer( - self, q: lbann.Layer, k: lbann.Layer, v: lbann.Layer, - length: int) -> Tuple[lbann.Layer, lbann.Layer, lbann.Layer]: + self, q: lbann.Layer, k: lbann.Layer, v: lbann.Layer, length: int, + **extra_kwargs) -> Tuple[lbann.Layer, lbann.Layer, lbann.Layer]: """ Applies sequence encoding within a transformer encoder/decoder layer. Encoding is performed on each transformer layer's multi-head attention @@ -38,6 +40,7 @@ def apply_layer( :param k: The input keys of the transformer layer. :param v: The input values of the transformer layer. :param length: Sequence length. + :param extra_kwargs: Additional arguments to pass to each internal Layer. :return: Encoded tuple of (q, k, v). """ return q, k, v # Do nothing @@ -121,13 +124,14 @@ def _positional_encoding(self, sequence_length): # Return cached positional encoding return self._positional_encoding_cache[sequence_length] - def apply_input(self, inputs, input_length): + def apply_input(self, inputs, input_length, **extra_kwargs): self.instance += 1 result = lbann.Add( inputs, self._positional_encoding(input_length), name=f'{self.name}_instance{self.instance}_peadd', + **extra_kwargs, ) # Input dropout @@ -136,6 +140,7 @@ def apply_input(self, inputs, input_length): result, keep_prob=1 - self.dropout_prob, name=f'{self.name}_pedrop', + **extra_kwargs, ) return result @@ -182,7 +187,11 @@ def compute_embeddings(self): embedding_dim=self.embed_dim, ) - def apply_input(self, inputs, input_length, learned_encoding=None): + def apply_input(self, + inputs, + input_length, + learned_encoding=None, + **extra_kwargs): self.instance += 1 if learned_encoding is None: @@ -193,12 +202,14 @@ def apply_input(self, inputs, input_length, learned_encoding=None): learned_encoding = lbann.Identity( lbann.Slice(learned_encoding, axis=0, - slice_points=[0, input_length])) + slice_points=[0, input_length], + **extra_kwargs), **extra_kwargs) result = lbann.Add( inputs, learned_encoding, name=f'{self.name}_instance{self.instance}_peadd', + **extra_kwargs, ) # Input dropout @@ -207,6 +218,7 @@ def apply_input(self, inputs, input_length, learned_encoding=None): result, keep_prob=1 - self.dropout_prob, name=f'{self.name}_pedrop', + **extra_kwargs, ) return result @@ -278,35 +290,43 @@ def _precompute_frequencies(self, sequence_length: int): _make_constant_from_array(sin, f'rope_sin_{sequence_length}'), ) - def _rotate_half(self, x: lbann.Layer, length: int): + def _rotate_half(self, x: lbann.Layer, length: int, **extra_kwargs): """ Helper method that rotates half of a tensor x. """ # SxE -> SxHxP - r = lbann.Reshape(x, dims=(length, self.num_heads, self.dim)) - s = lbann.Slice(r, slice_points=[0, self.dim // 2, self.dim], axis=2) - x1 = lbann.Identity(s) - x2 = lbann.Identity(s) - nx2 = lbann.Scale(x2, constant=-1) - cat = lbann.Concatenation([nx2, x1], axis=2) + r = lbann.Reshape(x, + dims=(length, self.num_heads, self.dim), + **extra_kwargs) + s = lbann.Slice(r, + slice_points=[0, self.dim // 2, self.dim], + axis=2, + **extra_kwargs) + x1 = lbann.Identity(s, **extra_kwargs) + x2 = lbann.Identity(s, **extra_kwargs) + nx2 = lbann.Scale(x2, constant=-1, **extra_kwargs) + cat = lbann.Concatenation([nx2, x1], axis=2, **extra_kwargs) # Reshape back to SxE - return lbann.Reshape(cat, dims=(length, self.num_heads * self.dim)) + return lbann.Reshape(cat, + dims=(length, self.num_heads * self.dim), + **extra_kwargs) def _embed(self, x: lbann.Layer, length: int, sliced_cos: lbann.Layer, - sliced_sin: lbann.Layer): + sliced_sin: lbann.Layer, **extra_kwargs): """ Helper method that applies rotary embeddings on a tensor x. """ - rot = self._rotate_half(x, length) + rot = self._rotate_half(x, length, **extra_kwargs) return lbann.Add( - lbann.Multiply(x, sliced_cos), - lbann.Multiply(rot, sliced_sin), + lbann.Multiply(x, sliced_cos, **extra_kwargs), + lbann.Multiply(rot, sliced_sin, **extra_kwargs), + **extra_kwargs, ) def apply_layer( - self, q: lbann.Layer, k: lbann.Layer, v: lbann.Layer, - length: int) -> Tuple[lbann.Layer, lbann.Layer, lbann.Layer]: + self, q: lbann.Layer, k: lbann.Layer, v: lbann.Layer, length: int, + **extra_kwargs) -> Tuple[lbann.Layer, lbann.Layer, lbann.Layer]: # If length is not given, maximum sequence length is assumed if length is None: length = self.max_sequence_length @@ -316,15 +336,21 @@ def apply_layer( sliced_sin = self.sin else: sliced_cos = lbann.Identity( - lbann.Slice(self.cos, slice_points=[0, length], axis=0)) + lbann.Slice(self.cos, + slice_points=[0, length], + axis=0, + **extra_kwargs), **extra_kwargs) sliced_sin = lbann.Identity( - lbann.Slice(self.sin, slice_points=[0, length], axis=0)) + lbann.Slice(self.sin, + slice_points=[0, length], + axis=0, + **extra_kwargs), **extra_kwargs) - eq = self._embed(q, length, sliced_cos, sliced_sin) - ek = self._embed(k, length, sliced_cos, sliced_sin) + eq = self._embed(q, length, sliced_cos, sliced_sin, **extra_kwargs) + ek = self._embed(k, length, sliced_cos, sliced_sin, **extra_kwargs) if self.embed_values: - ev = self._embed(v, length, sliced_cos, sliced_sin) + ev = self._embed(v, length, sliced_cos, sliced_sin, **extra_kwargs) else: ev = v From 8f991041a86e7522f25f12cd8091eb76b312db71 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 2 Feb 2024 16:39:22 -0800 Subject: [PATCH 05/11] General-purpose experiment file loading API --- python/lbann/proto/serialize.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/python/lbann/proto/serialize.py b/python/lbann/proto/serialize.py index 63b3f7f60f5..37d9eca0c7f 100644 --- a/python/lbann/proto/serialize.py +++ b/python/lbann/proto/serialize.py @@ -80,3 +80,21 @@ def bin2text(infile: str, outfile: str): f.write( google.protobuf.text_format.MessageToString( message, use_index_order=True).encode()) + + +def generic_load(filename: str): + """ + Loads a .protobin or .prototext file. + """ + try: # Try binary first + message = lbann_pb2.LbannPB() + + # Read file + with open(filename, 'rb') as f: + message.ParseFromString(f.read()) + except: # Try text + with open(filename, 'rb') as f: + message = google.protobuf.text_format.Parse( + f.read(), lbann_pb2.LbannPB()) + + return message From 2477ac2ccc7bee327ffc9f71491894e27876e026 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Fri, 2 Feb 2024 16:42:01 -0800 Subject: [PATCH 06/11] Visualization: support new models, operator layers, grid tag coloring, and cross-grid edge highlighting --- scripts/viz.py | 70 ++++++++++++++++++++++++++++++++++++++------------ 1 file changed, 54 insertions(+), 16 deletions(-) diff --git a/scripts/viz.py b/scripts/viz.py index 09d7d86dd33..09217d0c405 100755 --- a/scripts/viz.py +++ b/scripts/viz.py @@ -2,10 +2,17 @@ """Visualize an LBANN model's layer graph and save to file.""" import argparse +import random import re import graphviz -import google.protobuf.text_format from lbann import lbann_pb2, layers_pb2 +from lbann.proto import serialize + +# Pastel rainbow (slightly shuffled) from colorkit.co +palette = [ + '#ffffff', '#a0c4ff', '#ffadad', '#fdffb6', '#caffbf', '#9bf6ff', + '#bdb2ff', '#ffc6ff', '#ffd6a5' +] # Parse command-line arguments parser = argparse.ArgumentParser( @@ -17,14 +24,14 @@ parser.add_argument('output', action='store', nargs='?', - default='graph.pdf', + default='graph.dot', type=str, - help='output file (default: graph.pdf)') + help='output file (default: graph.dot)') parser.add_argument('--file-format', action='store', - default='pdf', + default='dot', type=str, - help='output file format (default: pdf)', + help='output file format (default: dot)', metavar='FORMAT') parser.add_argument('--label-format', action='store', @@ -39,6 +46,10 @@ type=str, help='Graphviz visualization scheme (default: dot)', metavar='ENGINE') +parser.add_argument('--color-cross-grid', + action='store_true', + default=False, + help='Highlight cross-grid edges') args = parser.parse_args() # Strip extension from filename @@ -51,9 +62,7 @@ label_format = re.sub(r' |-|_', '', args.label_format.lower()) # Read prototext file -proto = lbann_pb2.LbannPB() -with open(args.input, 'r') as f: - google.protobuf.text_format.Merge(f.read(), proto) +proto = serialize.generic_load(args.input) model = proto.model # Construct graphviz graph @@ -62,29 +71,36 @@ engine=args.graphviz_engine) graph.attr('node', shape='rect') +layer_to_grid_tag = {} + # Construct nodes in layer graph layer_types = (set(layers_pb2.Layer.DESCRIPTOR.fields_by_name.keys()) - set([ 'name', 'parents', 'children', 'datatype', 'data_layout', 'device_allocation', 'weights', 'freeze', 'hint_layer', 'top', 'bottom', - 'type', 'motif_layer' + 'type', 'motif_layer', 'parallel_strategy', 'grid_tag' ])) for l in model.layer: # Determine layer type - type = '' + ltype = '' for _type in layer_types: if l.HasField(_type): - type = getattr(l, _type).DESCRIPTOR.name + ltype = getattr(l, _type).DESCRIPTOR.name break + # If operator layer, use operator type + if ltype == 'OperatorLayer': + url = l.operator_layer.ops[0].parameters.type_url + ltype = url[url.rfind('.') + 1:] + # Construct node label label = '' if label_format == 'nameonly': label = l.name elif label_format == 'typeonly': - label = type + label = ltype elif label_format == 'typeandname': - label = '<{0}
{1}>'.format(type, l.name) + label = '<{0}
{1}>'.format(ltype, l.name) elif label_format == 'full': label = '<' for (index, line) in enumerate(str(l).strip().split('\n')): @@ -94,14 +110,36 @@ label += '>' # Add layer as layer graph node - graph.node(l.name, label=label) + tag = l.grid_tag.value + layer_to_grid_tag[l.name] = tag + attrs = {} + if tag != 0: + attrs = dict(style='filled', fillcolor=palette[tag % len(palette)]) + graph.node(l.name, label=label, **attrs) # Add parent/child relationships as layer graph edges edges = set() +cross_grid_edges = set() for l in model.layer: - edges.update([(p, l.name) for p in l.parents.split()]) - edges.update([(l.name, c) for c in l.children.split()]) + tag = layer_to_grid_tag[l.name] + for p in l.parents: + if tag != layer_to_grid_tag[p]: + cross_grid_edges.add((p, l.name)) + else: + edges.add((p, l.name)) + + for c in l.children: + if tag != layer_to_grid_tag[c]: + cross_grid_edges.add((l.name, c)) + else: + edges.add((l.name, c)) + graph.edges(edges) +if args.color_cross_grid: + for u, v in cross_grid_edges: + graph.edge(u, v, color='red') +else: + graph.edges(cross_grid_edges) # Save to file graph.render(filename=filename, cleanup=True, format=file_format) From 9580149f3c2ed319c932ab45a492830a3392b36a Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 6 Feb 2024 14:45:34 -0800 Subject: [PATCH 07/11] Fix constant weight optimizers --- python/lbann/models/subgraph/transformer.py | 2 +- python/lbann/models/transformer.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/lbann/models/subgraph/transformer.py b/python/lbann/models/subgraph/transformer.py index b0674f8ee87..e7291dc474b 100644 --- a/python/lbann/models/subgraph/transformer.py +++ b/python/lbann/models/subgraph/transformer.py @@ -1285,7 +1285,7 @@ def _subsequent_mask(self, size): vals = np.triu(np.full((size, size), -1e9), k=1) weights = lbann.Weights( initializer=lbann.ValueInitializer(values=vals.flat), - optimizer=None, + optimizer=lbann.NoOptimizer(), name=f"{self.name}_mask{size}_weights", ) self._subsequent_mask_cache[size] = lbann.WeightsLayer( diff --git a/python/lbann/models/transformer.py b/python/lbann/models/transformer.py index 6c7bf796dce..63774869970 100644 --- a/python/lbann/models/transformer.py +++ b/python/lbann/models/transformer.py @@ -594,7 +594,7 @@ def _subsequent_mask(self, size): weights = lbann.Weights( initializer=lbann.ValueInitializer(values=vals.flat), - optimizer=None, + optimizer=lbann.NoOptimizer(), name=f'{self.name}_mask{size}_weights', ) self._subsequent_mask_cache[size] = lbann.WeightsLayer( From 7663da388229773172b2bb560bbeb2bb5d8dc5d1 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 6 Feb 2024 14:45:52 -0800 Subject: [PATCH 08/11] Add signal handler disable as an envvar --- src/utils/options.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/utils/options.cpp b/src/utils/options.cpp index 4466d1aca20..b83509f22af 100644 --- a/src/utils/options.cpp +++ b/src/utils/options.cpp @@ -53,6 +53,7 @@ void construct_std_options() arg_parser.add_flag( LBANN_OPTION_DISABLE_SIGNAL_HANDLER, {"--disable_signal_handler"}, + utils::ENV("LBANN_DISABLE_SIGNAL_HANDLER"), "[STD] Disables signal handling (signal handling on by default)"); arg_parser.add_flag(LBANN_OPTION_EXIT_AFTER_SETUP, {"--exit_after_setup"}, From 6b93e65c84110e3bf42df6fd4537cec1aeb35096 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 6 Feb 2024 17:40:08 -0800 Subject: [PATCH 09/11] Fix split layer grid tag --- src/models/model.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/models/model.cpp b/src/models/model.cpp index 6ed719817dd..719471e6f0f 100644 --- a/src/models/model.cpp +++ b/src/models/model.cpp @@ -1382,10 +1382,13 @@ void model::add_split_layers(std::unordered_set& layer_names) split->set_name(name); layer_names.insert(name); - // Copy parallel strategy from parent. + // Copy parallel strategy and grid tag from parent. ParallelStrategy& ps = split->get_parallel_strategy(); ParallelStrategy& orig_ps = l.get_parallel_strategy(); ps = orig_ps; + if (l.grid_tag() >= 0) { + split->grid_tag(l.grid_tag()); + } // Setup relationships between split layer and child layers for (int j = 0; j < l.get_num_children(); ++j) { @@ -1674,8 +1677,9 @@ void model::backward_prop(bool compute_weight_grads_only, bool skip_callbacks) // Based on gradient/optimizer requirements if (compute_weight_grads_only && m_needed_for_backprop.size() > 0 && - m_needed_for_backprop.find(&l) == m_needed_for_backprop.end()) + m_needed_for_backprop.find(&l) == m_needed_for_backprop.end()) { enable_layer = false; + } } // Check if all children skip gradient backpropagation From 9de6e2fbdbb3ecc5698c57587b98045a9d1a0604 Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Tue, 6 Feb 2024 17:40:35 -0800 Subject: [PATCH 10/11] Fix layer insertion order in transformer LP --- applications/nlp/transformer/parallelism.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/applications/nlp/transformer/parallelism.py b/applications/nlp/transformer/parallelism.py index 0e74310c919..9c0c84ea50d 100644 --- a/applications/nlp/transformer/parallelism.py +++ b/applications/nlp/transformer/parallelism.py @@ -272,7 +272,7 @@ def apply_layer_parallelism_postamble(model: lbann.Model, # Inject interim layers for each grid and reconnect for dst_grid, children in unique_grids.items(): interim = lbann.Identity(layer, grid_tag=dst_grid) - layers_to_insert.append((i, interim)) + layers_to_insert.append((i+1, interim)) # Reconnect parents for child in children: @@ -281,9 +281,9 @@ def apply_layer_parallelism_postamble(model: lbann.Model, cind = layer.children.index(child) new_children[cind] = interim - # Reconnect children + # Reconnect and condense children if unique_grids: - layer.children = new_children + layer.children = list(set(new_children)) # Add identity layers to the traversed graph right after the source layer # was computed From 1e1605931ef25bdf98c1bf357f475c06ada3547d Mon Sep 17 00:00:00 2001 From: Tal Ben-Nun Date: Thu, 15 Feb 2024 19:29:53 -0800 Subject: [PATCH 11/11] Use copy to broadcast weights layers instead of GEMM --- src/layers/transform/weights.cpp | 32 +++++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/src/layers/transform/weights.cpp b/src/layers/transform/weights.cpp index bac97570a71..e9d40e46504 100644 --- a/src/layers/transform/weights.cpp +++ b/src/layers/transform/weights.cpp @@ -157,15 +157,29 @@ void weights_layer::fp_compute() // Duplicate weights across columns of output matrix const auto& local_weights = this->weights_values(0).LockedMatrix(); - MatType ones; - El::Ones(ones, local_output.Width(), 1); - El::Gemm(El::NORMAL, - El::TRANSPOSE, - El::TypeTraits::One(), - local_weights, - ones, - El::TypeTraits::Zero(), - local_output); + if (local_output.Width() <= 32) { // The number 32 is a heuristic + // Use copies for broadcast + for (int i = 0; i < local_output.Width(); ++i) { + MatType v; + El::View(v, + local_output, + El::IR(0, local_weights.Height()), + El::IR(i, i + 1)); + El::Copy(local_weights, v); + } + } + else { + // Use GEMM with ones for broadcast + MatType ones; + El::Ones(ones, local_output.Width(), 1); + El::Gemm(El::NORMAL, + El::TRANSPOSE, + El::TypeTraits::One(), + local_weights, + ones, + El::TypeTraits::Zero(), + local_output); + } } template