From 70db7e72e98ce9818ea7b515656d35449450201a Mon Sep 17 00:00:00 2001 From: MackZackA Date: Tue, 30 Jul 2024 20:57:28 -0700 Subject: [PATCH] Open source 0725 patch (#42) --- README.md | 8 +- .../download_open_llama_ckpt.py | 2 +- .../llama_mfu_calculator.py | 2 +- .../run_open_llama_w_vescale.py | 2 +- .../open_llama_4D_benchmark/sharding_plan.py | 2 +- .../open_llama/test_open_llama_dp_reshard.py | 2 +- .../open_llama/test_open_llama_load_save.py | 2 +- .../open_llama/test_open_llama_tp_reshard.py | 2 +- .../api/test_pipe_single_stage_ops.py | 26 +- .../pipeline/api/test_schedule_engine.py | 21 +- test/parallel/pipeline/api/test_simple_api.py | 1 - .../e2e/test_pp_accuracy_alignment.py | 2 +- vescale/engine/pipe.py | 24 +- vescale/model/base_gpt/__init__.py | 5 - vescale/model/base_gpt/attention.py | 531 ------------------ vescale/model/base_gpt/checkpoint.py | 133 ----- vescale/model/base_gpt/enums.py | 27 - vescale/model/base_gpt/fuse_layer_norm.py | 119 ---- vescale/model/base_gpt/fuse_softmax.py | 203 ------- vescale/model/base_gpt/jit_func.py | 40 -- vescale/model/base_gpt/mlp.py | 101 ---- vescale/model/base_gpt/rotary.py | 52 -- vescale/model/base_gpt/transformer_block.py | 135 ----- vescale/model/base_gpt/transformer_layer.py | 194 ------- vescale/model/base_gpt/utils.py | 27 - vescale/pipe/_schedules/instruction_base.py | 19 + vescale/pipe/_schedules/looping_bfs.py | 2 +- vescale/pipe/_schedules/zero_bubble_v.py | 2 +- vescale/pipe/pipe_stage.py | 8 +- 29 files changed, 74 insertions(+), 1620 deletions(-) delete mode 100644 vescale/model/base_gpt/__init__.py delete mode 100644 vescale/model/base_gpt/attention.py delete mode 100644 vescale/model/base_gpt/checkpoint.py delete mode 100644 vescale/model/base_gpt/enums.py delete mode 100644 vescale/model/base_gpt/fuse_layer_norm.py delete mode 100644 vescale/model/base_gpt/fuse_softmax.py delete mode 100644 vescale/model/base_gpt/jit_func.py delete mode 100644 vescale/model/base_gpt/mlp.py delete mode 100644 vescale/model/base_gpt/rotary.py delete mode 100644 vescale/model/base_gpt/transformer_block.py delete mode 100644 vescale/model/base_gpt/transformer_layer.py delete mode 100644 vescale/model/base_gpt/utils.py diff --git a/README.md b/README.md index 32017c4..323dce2 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ _**An Industrial-Level Framework for Easy-of-Use**_ ## Latest News +- [2024-7-25] veScale's [pipeline parallelism](https://github.com/volcengine/veScale/blob/main/vescale/pipe/README.md) open sourced with API, graph parser, stage abstraction, schedules and execution runtime along with [nD distributed timeline](https://github.com/volcengine/veScale/blob/main/vescale/ndtimeline/README.md). + - [2024-5-31] veScale's [fast checkpointing system](https://github.com/volcengine/veScale/blob/main/vescale/checkpoint/README.md) open sourced with automatic checkpoint resharding, caching, load-balancing, fast copying, deduplicating, and asynchronous io. - [2024-5-21] veScale's examples ([Mixtral](https://github.com/volcengine/veScale/tree/main/examples/mixtral_4D_training), [LLama2](https://github.com/volcengine/veScale/tree/main/examples/llama2_4D_finetune), and [nanoGPT](https://github.com/volcengine/veScale/tree/main/examples/nanogpt_4D_finetune)) open sourced with bit-wise correctness of training loss curves. @@ -32,7 +34,11 @@ _**An Industrial-Level Framework for Easy-of-Use**_ _**veScale**_ is still in its early phase. We are refactoring our internal LLM training system components to meet open source standard. 
The tentative timeline is as follows: -- by end of July, CUDA event monitor, pipeline parallelism and supporting components for large-scale training +- High-level [nD parallel api](https://github.com/volcengine/veScale/issues/39) for extreme ease of use + +- Power-user plan api for easy customization of nD parallel training + +- End-to-end vescale/examples with 5D parallel training (TP, SP, DP, ZeRO, PP) ## Table of Content ([web view](https://volcengine.github.io/veScaleWeb/)) diff --git a/examples/open_llama_4D_benchmark/download_open_llama_ckpt.py b/examples/open_llama_4D_benchmark/download_open_llama_ckpt.py index 5fcf32e..876228a 100644 --- a/examples/open_llama_4D_benchmark/download_open_llama_ckpt.py +++ b/examples/open_llama_4D_benchmark/download_open_llama_ckpt.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/examples/open_llama_4D_benchmark/llama_mfu_calculator.py b/examples/open_llama_4D_benchmark/llama_mfu_calculator.py index b67908d..9bacdd5 100644 --- a/examples/open_llama_4D_benchmark/llama_mfu_calculator.py +++ b/examples/open_llama_4D_benchmark/llama_mfu_calculator.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py b/examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py index 22f7cf8..8117551 100644 --- a/examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py +++ b/examples/open_llama_4D_benchmark/run_open_llama_w_vescale.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/examples/open_llama_4D_benchmark/sharding_plan.py b/examples/open_llama_4D_benchmark/sharding_plan.py index 2aefd81..12bcd65 100644 --- a/examples/open_llama_4D_benchmark/sharding_plan.py +++ b/examples/open_llama_4D_benchmark/sharding_plan.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py index 370dadd..b1f6cb3 100644 --- a/test/checkpoint/open_llama/test_open_llama_dp_reshard.py +++ b/test/checkpoint/open_llama/test_open_llama_dp_reshard.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/test/checkpoint/open_llama/test_open_llama_load_save.py b/test/checkpoint/open_llama/test_open_llama_load_save.py index c0a8377..0a3a29a 100644 --- a/test/checkpoint/open_llama/test_open_llama_load_save.py +++ b/test/checkpoint/open_llama/test_open_llama_load_save.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at diff --git a/test/checkpoint/open_llama/test_open_llama_tp_reshard.py b/test/checkpoint/open_llama/test_open_llama_tp_reshard.py index 2a85cae..5096062 100644 --- a/test/checkpoint/open_llama/test_open_llama_tp_reshard.py +++ b/test/checkpoint/open_llama/test_open_llama_tp_reshard.py @@ -1,6 +1,6 @@ ################################################################################ # -# Copyright 2024 ByteDance Ltd. and/or its affiliates. All rights reserved. +# Copyright 2023 ByteDance Ltd. and/or its affiliates. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at diff --git a/test/parallel/pipeline/api/test_pipe_single_stage_ops.py b/test/parallel/pipeline/api/test_pipe_single_stage_ops.py index 9d0922b..6ea7d64 100644 --- a/test/parallel/pipeline/api/test_pipe_single_stage_ops.py +++ b/test/parallel/pipeline/api/test_pipe_single_stage_ops.py @@ -21,7 +21,7 @@ from torch.testing._internal.common_utils import run_tests from vescale.devicemesh_api import VESCALE_DEVICE_MESH from vescale.plan import PipelineScheduleType, PipelineParallelPlan, ModeType, PipelineSplitMethodType -from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules +from vescale.pipe.pipe_stage import construct_pipeline_stage from vescale.engine import PipeEngine from common_dtensor import DTensorTestBase, with_comms from torch.optim import SGD @@ -132,9 +132,7 @@ def test_stage_forward(self): def _run_no_pp_model(self): os.environ["model_name"] = "golden" model = EightMLP().to("cuda:0") - optimizer = torch.optim.SGD( - model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False - ) + optimizer = SGD(model.parameters(), lr=0.01, momentum=0, dampening=0, weight_decay=0, nesterov=False) torch.manual_seed(9999) batch = [torch.ones(microbatch_size, 128, 32, dtype=torch.float32).to("cuda:0") for _ in range(factor)] for mb in batch: @@ -166,13 +164,6 @@ def _run_stage_forward(self): mesh_dim_names=("PP", "DP", "TP"), ) - stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules( - model, - config, - VESCALE_DEVICE_MESH, - update_split_points=True, - ) - optimizer_fn_kwargs = { "lr": 0.01, "momentum": 0, @@ -183,9 +174,16 @@ def _run_stage_forward(self): "foreach": None, "differentiable": False, } - _parameters = list(stage_modules[0].parameters()) + list(stage_modules[1].parameters()) - optimizer = SGD(_parameters, **optimizer_fn_kwargs) - pipe_module = PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config) + + pipe_module = construct_pipeline_stage( + model, + config, + VESCALE_DEVICE_MESH, + lr_scheduler=None, + update_split_points=True, + ) + optimizer = SGD(pipe_module.parameters(), **optimizer_fn_kwargs) + pipe_module.doptimizer = optimizer engine = PipeEngine( pipe_module, diff --git a/test/parallel/pipeline/api/test_schedule_engine.py b/test/parallel/pipeline/api/test_schedule_engine.py index c508511..01246e2 100644 --- a/test/parallel/pipeline/api/test_schedule_engine.py +++ b/test/parallel/pipeline/api/test_schedule_engine.py @@ -20,7 +20,7 @@ from common_dtensor import DTensorTestBase, with_comms from torch.testing._internal.common_utils import run_tests from vescale.devicemesh_api import VESCALE_DEVICE_MESH -from vescale.pipe.pipe_stage import PipeModule, construct_stage_modules +from vescale.pipe.pipe_stage import construct_pipeline_stage from vescale.pipe._schedules.instruction_base import StageDeps from vescale.pipe.pipe_emmiter import ScheduleEngine from vescale.plan.spec import PipelineScheduleType, ModeType, PipelineSplitMethodType @@ -79,13 +79,6 @@ def test_simple_1f1b(self): schedule_type=PipelineScheduleType.SIMPLE_1F1B, ) - stage_modules, stage_dependency, p2p_index_mapping = construct_stage_modules( - model, - config, - VESCALE_DEVICE_MESH, - update_split_points=True, - ) - optimizer_fn_kwargs = { "lr": 0.01, "momentum": 0, @@ -96,9 +89,15 @@ def test_simple_1f1b(self): "foreach": None, "differentiable": False, } - _parameters = list(stage_modules[0].parameters()) - optimizer = SGD(_parameters, **optimizer_fn_kwargs) - pipe_module = 
PipeModule(stage_modules, optimizer, None, stage_dependency, p2p_index_mapping, config) + pipe_module = construct_pipeline_stage( + model, + config, + VESCALE_DEVICE_MESH, + lr_scheduler=None, + update_split_points=True, + ) + optimizer = SGD(pipe_module.parameters(), **optimizer_fn_kwargs) + pipe_module.doptimizer = optimizer dep = pipe_module.stage_deps device_mesh_list = VESCALE_DEVICE_MESH.get_global_tensor_parallel_meshes() diff --git a/test/parallel/pipeline/api/test_simple_api.py b/test/parallel/pipeline/api/test_simple_api.py index 97b3d3b..2b5086f 100644 --- a/test/parallel/pipeline/api/test_simple_api.py +++ b/test/parallel/pipeline/api/test_simple_api.py @@ -20,7 +20,6 @@ import torch import torch.nn as nn from torch.testing._internal.common_utils import run_tests -from vescale.debug.pdb import ForkedPdb from vescale.optim.base_optimizer import BasicOptimizer from vescale.pipe.pipe_stage import construct_pipeline_stage from vescale.devicemesh_api import VESCALE_DEVICE_MESH diff --git a/test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py b/test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py index 163f453..a1c27b8 100644 --- a/test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py +++ b/test/parallel/pipeline/e2e/test_pp_accuracy_alignment.py @@ -225,7 +225,7 @@ def _run_engine_with_1f1b(self, fixed_size=True): pipe_config, ) - engine.forward_backward(batch) + engine(batch) optimizer = engine.get_optimizer optimizer.step() diff --git a/vescale/engine/pipe.py b/vescale/engine/pipe.py index 4cd66c8..f3a8631 100644 --- a/vescale/engine/pipe.py +++ b/vescale/engine/pipe.py @@ -36,7 +36,7 @@ def __init__( module: PipeModule, global_mesh: VeDeviceMesh, loss_fn: Callable, - config: PipelineParallelPlan, + plan: PipelineParallelPlan, ): """ Training engine for pipeline parallelism and multi-dimensional @@ -46,8 +46,8 @@ def __init__( training, and optimizer synchronization. """ self.module = module - self.virtual_chunks_per_stage = config.virtual_chunks - self.engine_config = config + self.virtual_chunks_per_stage = plan.virtual_chunks + self.engine_plan = plan self.optimizer = self.module.get_optimizer self.lr_scheduler = self.module.get_lr_scheduler self.global_mesh = global_mesh @@ -59,16 +59,16 @@ def __init__( except: # noqa: E722 self.loss_fn = loss_fn self.schedule_engine = None - self.reuse_comm_shape = self.engine_config.reuse_p2p_tensor_shape + self.reuse_comm_shape = self.engine_plan.reuse_p2p_tensor_shape if self.reuse_comm_shape: os.environ["REUSE_COMM_SHAPE"] = "1" if ( - self.engine_config.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B + self.engine_plan.schedule_type == PipelineScheduleType.INTERLEAVED_1F1B and self.virtual_chunks_per_stage == 1 ): print("[warning]: #virtual pipeline chunks is 1. 
Falling back to simple 1F1B schedule.") - self.engine_config.schedule_type = PipelineScheduleType.SIMPLE_1F1B - self.schedule_type = self.engine_config.schedule_type + self.engine_plan.schedule_type = PipelineScheduleType.SIMPLE_1F1B + self.schedule_type = self.engine_plan.schedule_type def build_schedule(self, minibatches, data_shape=None): """ @@ -105,7 +105,7 @@ def _locate_tp_mesh(_rank): ) num_minibatches = self._align_num_batches(first_stage_rank, len(minibatches)) # TODO: insert shape inference - batch_p2p_comm = self.engine_config.batch_p2p_comm + batch_p2p_comm = self.engine_plan.batch_p2p_comm # if on interleaved 1f1b schedule, set batch_p2p_comm to False to execute p2p communication schedule_type = self.schedule_type if schedule_type in [PipelineScheduleType.INTERLEAVED_1F1B, PipelineScheduleType.ZERO_BUBBLE]: @@ -123,16 +123,16 @@ def _locate_tp_mesh(_rank): data_iterator=data_iterator, stage_id=self.global_mesh.get_pipeline_parallel_rank(), shape=data_shape, - dtype=self.engine_config.p2p_tensor_dtype, + dtype=self.engine_plan.p2p_tensor_dtype, num_chunks=self.virtual_chunks_per_stage, input_shapes=None, input_shapes_unpad=None, # send_dtypes_map=self.module.recv_dtypes_dict, - overlap_p2p_comm=self.engine_config.overlap_p2p_comm, + overlap_p2p_comm=self.engine_plan.overlap_p2p_comm, batch_p2p_comm=batch_p2p_comm, loss_fn=self.loss_fn, global_mesh=self.global_mesh, - forward_only=self.engine_config.forward_only, + forward_only=self.engine_plan.forward_only, ) def forward_backward( @@ -211,7 +211,7 @@ def parameters(self, including_frozen=False): def sync_shared_params(self, group_id: int = 0, share_params=True) -> None: """ Synchronize gradients and weights among groups of specified units, dictated by - "partition_units" in PipelineConfig. Typically, this function is used for + "partition_units" in PipelineParallelPlan. Typically, this function is used for synchronizing gradients and weights of embeddings layers in Transformer-based architecture. Args: diff --git a/vescale/model/base_gpt/__init__.py b/vescale/model/base_gpt/__init__.py deleted file mode 100644 index f3b869e..0000000 --- a/vescale/model/base_gpt/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ diff --git a/vescale/model/base_gpt/attention.py b/vescale/model/base_gpt/attention.py deleted file mode 100644 index 66c615d..0000000 --- a/vescale/model/base_gpt/attention.py +++ /dev/null @@ -1,531 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
-################################################################################ - -import math - -import torch -from torch import nn - -from vescale.dtensor.api import from_local -from vescale.model.base_gpt.checkpoint import checkpoint -from vescale.model.base_gpt.enums import AttnMaskType, AttnType -from vescale.model.base_gpt.fuse_softmax import FusedScaleMaskSoftmax -from vescale.model.base_gpt.rotary import apply_rotary_pos_emb -from vescale.model.random import get_cuda_rng_tracker -from vescale.model.utils import attention_mask_func, divide - -try: - from flash_attn.flash_attn_interface import flash_attn_unpadded_func -except ImportError: - try: - from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func - except ImportError: - flash_attn_unpadded_func = None - -try: - from einops import rearrange -except ImportError: - rearrange = None - - -class CoreAttention(nn.Module): - def __init__(self, layer_number, config, attn_mask_type=AttnMaskType.padding): - super().__init__() - self.fp16 = config.fp16 - self.bf16 = config.bf16 - - self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling - self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32 - if self.apply_query_key_layer_scaling: - self.attention_softmax_in_fp32 = True - self.layer_number = max(1, layer_number) - self.attn_mask_type = attn_mask_type - self.sequence_parallel = config.sequence_parallel - - self.config = config - - # Per attention head and per partition values. - - coeff = None - if self.apply_query_key_layer_scaling: - coeff = self.layer_number - - self.scale_mask_softmax = FusedScaleMaskSoftmax( - self.fp16, - self.bf16, - self.attn_mask_type, - config.masked_softmax_fusion, - attention_mask_func, - self.attention_softmax_in_fp32, - coeff, - ) - - # Dropout. Note that for a single iteration, this layer will generate - # different outputs on different number of parallel partitions but - # on average it should not be partition dependent. - self.attention_dropout = torch.nn.Dropout(config.attention_dropout) - - def forward(self, query_layer, key_layer, value_layer, attention_mask): - # =================================== - # Raw attention scores. [b, np, s, s] - # =================================== - - # [b, np, sq, sk] - output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0)) - - # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.reshape(output_size[2], output_size[0] * output_size[1], -1) - # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1) - - q_t = query_layer.transpose(0, 1) - # preallocting input tensor: [b * np, sq, sk] - matmul_input_buffer = torch.empty( - (output_size[0] * output_size[1] // query_layer._spec.mesh.size(), output_size[2], output_size[3]), - dtype=query_layer.dtype, - device=query_layer.device, - ) - matmul_input_buffer = from_local(matmul_input_buffer, query_layer._spec.mesh, q_t._spec.placements) - - # Raw attention scores. 
[b * np, sq, sk] - projection_size = self.config.kv_channels * self.config.num_attention_heads - hidden_size_per_attention_head = divide(projection_size, self.config.num_attention_heads) - norm_factor = math.sqrt(hidden_size_per_attention_head) - norm_factor *= self.layer_number - matmul_result = torch.baddbmm( - matmul_input_buffer, - query_layer.transpose(0, 1), # [b * np, sq, hn] - key_layer.transpose(0, 1).transpose(1, 2), # [b * np, hn, sk] - beta=0.0, - alpha=(1.0 / norm_factor), - ) - - # change view to [b, np, sq, sk] - attention_scores = matmul_result.view(*output_size) - - # =========================== - # Attention probs and dropout - # =========================== - - # attention scores and attention mask [b, np, sq, sk] - attention_probs = self.scale_mask_softmax(attention_scores, attention_mask) - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - if not self.sequence_parallel: - with get_cuda_rng_tracker().fork(): - attention_probs = self.attention_dropout(attention_probs) - else: - attention_probs = self.attention_dropout(attention_probs) - - attention_probs = from_local(attention_probs, attention_scores._spec.mesh, attention_scores._spec.placements) - - # ========================= - # Context layer. [sq, b, hp] - # ========================= - - # value_layer -> context layer. - # [sk, b, np, hn] --> [b, np, sq, hn] - - # context layer shape: [b, np, sq, hn] - output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3)) - - # change view [sk, b * np, hn] - value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1) - - # change view [b * np, sq, sk] - attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1) - - # matmul: [b * np, sq, hn] - context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1)) - - # change view [b, np, sq, hn] - context_layer = context_layer.view(*output_size) - - # [b, np, sq, hn] --> [sq, b, np, hn] - context_layer = context_layer.permute(2, 0, 1, 3).contiguous() - - # [sq, b, np, hn] --> [sq, b, hp] - context_layer = context_layer.view(*context_layer.size()[:-2], -1) - - return context_layer - - -class FlashSelfAttention(nn.Module): - """Implement the scaled dot product attention with softmax. - Arguments - --------- - softmax_scale: The temperature to use for the softmax attention. - (default: 1/sqrt(d_keys) where d_keys is computed at - runtime) - attention_dropout: The dropout rate to apply to the attention - (default: 0.0) - """ - - def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None): - super().__init__() - assert flash_attn_unpadded_func is not None, ( - "Please install FlashAttention first, " "e.g., with pip install flash-attn" - ) - assert rearrange is not None, "Please install einops first, e.g., with pip install einops" - self.causal = causal - self.softmax_scale = softmax_scale - self.dropout_p = attention_dropout - - def forward(self, q, k, v): - """Implements the multihead softmax attention. - Arguments - --------- - q, k, v: The tensor containing the query, key, and value. (B, S, H, D) - """ - - assert all(i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)) - assert all(i.is_cuda for i in (q, k, v)) - - batch_size, seqlen_q = q.shape[0], q.shape[1] - seqlen_k = k.shape[1] - - q, k, v = (rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]) - cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32, device=q.device) - - if self.training: - # during training q,k,v always have same seqlen - assert seqlen_k == seqlen_q - - is_causal = self.causal - cu_seqlens_k = cu_seqlens_q - dropout_p = self.dropout_p - else: - # turn off FA causal mask after first inference autoregressive iteration - # only on first autoregressive step q,k,v have same seqlen - is_causal = seqlen_q == seqlen_k - cu_seqlens_k = torch.arange( - 0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32, device=q.device - ) - dropout_p = 0 - - output = flash_attn_unpadded_func( - q, - k, - v, - cu_seqlens_q, - cu_seqlens_k, - seqlen_q, - seqlen_k, - dropout_p, - softmax_scale=self.softmax_scale, - causal=is_causal, - ) - - output = rearrange(output, "(b s) ... -> b s ...", b=batch_size) - return output - - -class ParallelAttention(nn.Module): - """Parallel self-attention layer abstract class. - - Self-attention layer takes input with size [s, b, h] - and returns output of the same size. - """ - - def __init__( - self, - config, - layer_number, - attention_type=AttnType.self_attn, - attn_mask_type=AttnMaskType.padding, - ): - super().__init__() - self.layer_number = max(1, layer_number) - self.attention_type = attention_type - self.attn_mask_type = attn_mask_type - self.config = config - - self.group_query_attention = config.group_query_attention - self.num_query_groups = config.num_query_groups - - query_projection_size = config.kv_channels * config.num_attention_heads - if self.group_query_attention: - kv_projection_size = config.kv_channels * config.num_query_groups - else: - kv_projection_size = config.kv_channels * config.num_attention_heads - - self.use_flash_attn = ( - config.use_flash_attn - and attention_type == AttnType.self_attn - and self.attn_mask_type == AttnMaskType.causal - ) - if self.use_flash_attn: - if flash_attn_unpadded_func is None: - raise ImportError("FlashAttention is not installed, please install with " "pip install flash-attn") - assert attention_type == AttnType.self_attn, ( - "FlashAttention code path only supports " "self-attention for now" - ) - assert self.attn_mask_type == AttnMaskType.causal, ( - "FlashAttention code path only " "supports causal mask for now" - ) - if rearrange is None: - raise ImportError("einops is not installed, please install with pip install einops") - - # Strided linear layer. 
- if attention_type == AttnType.self_attn: - self.query_key_value = nn.Linear( - config.hidden_size, - query_projection_size + 2 * kv_projection_size, - bias=config.add_bias_linear, - ) - config.init_method(self.query_key_value.weight) - else: - assert attention_type == AttnType.cross_attn - - if self.group_query_attention: - raise NotImplementedError("Grouped query attention not implemented for cross-attention.") - assert query_projection_size == kv_projection_size - - self.query = nn.Linear( - config.hidden_size, - query_projection_size, - bias=self.add_bias_linear, - ) - config.init_method(self.query.weight) - self.key_value = nn.Linear(config.hidden_size, 2 * kv_projection_size, bias=config.add_bias_linear) - config.init_method(self.key_value.weight) - - self.core_attention = CoreAttention(self.layer_number, config, self.attn_mask_type) - self.checkpoint_core_attention = config.recompute_granularity == "selective" - - if self.use_flash_attn: - self.core_attention_flash = FlashSelfAttention(causal=True, attention_dropout=config.attention_dropout) - - # Output. - self.dense = nn.Linear( - query_projection_size, - config.hidden_size, - bias=False, - ) - self.dense_bias = torch.empty(config.hidden_size) if config.add_bias_linear else None - config.output_layer_init_method(self.dense.weight) - - def _checkpointed_attention_forward(self, query_layer, key_layer, value_layer, attention_mask, rotary_pos_emb=None): - """Forward method with activation checkpointing.""" - - def custom_forward(*inputs): - query_layer = inputs[0] - key_layer = inputs[1] - value_layer = inputs[2] - attention_mask = inputs[3] - output_ = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - return output_ - - q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None else rotary_pos_emb - - hidden_states = checkpoint( - custom_forward, False, query_layer, key_layer, value_layer, attention_mask, q_pos_emb, k_pos_emb - ) - - return hidden_states - - def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads): - query_projection_size = self.config.kv_channels * self.config.num_attention_heads - hidden_size_per_attention_head = divide(query_projection_size, self.config.num_attention_heads) - return torch.empty( - inference_max_sequence_len, - batch_size, - num_attention_heads, - hidden_size_per_attention_head, - dtype=self.params_dtype, - device=torch.cuda.current_device(), - ) - - def forward( - self, hidden_states, attention_mask=None, encoder_output=None, inference_params=None, rotary_pos_emb=None - ): - # hidden_states: [sq, b, h] - - # Per attention head and per partition values. - world_size = self.hidden_states._spec.mesh.size() # TP - query_projection_size = self.config.kv_channels * self.config.num_attention_heads - self.hidden_size_per_attention_head = divide(query_projection_size, self.config.num_attention_heads) - self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) - - if self.group_query_attention: - self.num_query_groups_per_partition = divide(self.num_query_groups, world_size) - else: - self.num_query_groups_per_partition = self.num_attention_heads_per_partition - - self.num_attention_heads_per_partition = divide(self.config.num_attention_heads, world_size) - - # ================================================= - # Pre-allocate memory for key-values for inference. 
- # ================================================= - is_first_step = False - if inference_params: - if self.layer_number not in inference_params.key_value_memory_dict: - inf_max_seq_len = inference_params.max_sequence_length - inf_max_batch_size = inference_params.max_batch_size - inference_key_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, self.num_query_groups_per_partition - ) - inference_value_memory = self._allocate_memory( - inf_max_seq_len, inf_max_batch_size, self.num_query_groups_per_partition - ) - - inference_params.key_value_memory_dict[self.layer_number] = ( - inference_key_memory, - inference_value_memory, - ) - is_first_step = True - else: - inference_key_memory, inference_value_memory = inference_params.key_value_memory_dict[self.layer_number] - - # ===================== - # Query, Key, and Value - # ===================== - if self.attention_type == AttnType.self_attn: - # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)] - mixed_x_layer, _ = self.query_key_value(hidden_states) - - # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + ( - self.num_query_groups_per_partition, - ( - (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2) - * self.hidden_size_per_attention_head - ), - ) - mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) - - # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn] - (query_layer, key_layer, value_layer) = torch.split( - mixed_x_layer, - [ - ( - self.num_attention_heads_per_partition - // self.num_query_groups_per_partition - * self.hidden_size_per_attention_head - ), - self.hidden_size_per_attention_head, - self.hidden_size_per_attention_head, - ], - dim=3, - ) - - # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] - - query_layer = query_layer.view( - query_layer.size(0), query_layer.size(1), -1, self.hidden_size_per_attention_head - ) - else: - # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)] - mixed_kv_layer, _ = self.key_value(encoder_output) - - # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn] - new_tensor_shape = mixed_kv_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - 2 * self.hidden_size_per_attention_head, - ) - mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape) - - # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn] - (key_layer, value_layer) = split_tensor_along_last_dim(mixed_kv_layer, 2) - - # Attention head [sq, b, h] --> [sq, b, hp] - query_layer, _ = self.query(hidden_states) - # [sq, b, hp] --> [sq, b, np, hn] - new_tensor_shape = query_layer.size()[:-1] + ( - self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head, - ) - query_layer = query_layer.view(*new_tensor_shape) - - # ================================== - # Adjust key and value for inference - # ================================== - - # duplicate the pos_emb for self attention - if rotary_pos_emb is not None: - if isinstance(rotary_pos_emb, tuple): - rotary_pos_emb = rotary_pos_emb - else: - rotary_pos_emb = (rotary_pos_emb,) * 2 - - if inference_params: - batch_start = inference_params.batch_size_offset - batch_end = batch_start + key_layer.size(1) - assert batch_end <= inference_key_memory.size(1) - sequence_start = inference_params.sequence_len_offset - sequence_end = sequence_start + key_layer.size(0) - assert sequence_end <= inference_key_memory.size(0) - # Copy key and values. - inference_key_memory[sequence_start:sequence_end, batch_start:batch_end, ...] 
= key_layer - inference_value_memory[sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer - key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...] - value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...] - - # adjust the key rotary positional embedding - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - # need to cross check this condition during inference - # if not set_inference_key_value_memory: - if not is_first_step: - # In inference, we compute one token at a time. - # Select the correct positional embedding - # (only the last token in the sequence) - q_pos_emb = q_pos_emb[sequence_end - 1 : sequence_end] - else: - # In the first forward pass of inference, - # we use the entire provided prefix. - # q_pos_emb here has the rope embeddings of the entire - # prefix + to-be-generated output so - # we slice to just the prefix. - q_pos_emb = q_pos_emb[:sequence_end, :, :, :] - k_pos_emb = k_pos_emb[:sequence_end, :, :, :] - rotary_pos_emb = (q_pos_emb, k_pos_emb) - - # ================================== - # core attention computation - # ================================== - - # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn] - if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1: - key_layer = key_layer.repeat_interleave( - self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2 - ) - value_layer = value_layer.repeat_interleave( - self.num_attention_heads_per_partition // self.num_query_groups_per_partition, dim=2 - ) - - # apply relative positional encoding (rotary embedding) - if rotary_pos_emb is not None: - q_pos_emb, k_pos_emb = rotary_pos_emb - query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb) - key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb) - # TODO, can apply positional embedding to value_layer so it has - # absolute positional embedding. - # otherwise, only relative positional embedding takes effect - # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb) - - if not self.use_flash_attn: - if self.checkpoint_core_attention: - context_layer = self._checkpointed_attention_forward( - query_layer, key_layer, value_layer, attention_mask - ) - else: - context_layer = self.core_attention(query_layer, key_layer, value_layer, attention_mask) - else: - q, k, v = (rearrange(x, "s b ... -> b s ...").contiguous() for x in (query_layer, key_layer, value_layer)) - if not self.sequence_parallel: - with get_cuda_rng_tracker().fork(): - context_layer = self.core_attention_flash(q, k, v) - else: - context_layer = self.core_attention_flash(q, k, v) - context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous() - - # ================= - # Output. [sq, b, h] - # ================= - - output = self.dense(context_layer) - - return output, self.dense_bias diff --git a/vescale/model/base_gpt/checkpoint.py b/vescale/model/base_gpt/checkpoint.py deleted file mode 100644 index d6f6cd7..0000000 --- a/vescale/model/base_gpt/checkpoint.py +++ /dev/null @@ -1,133 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
-################################################################################ - -from typing import Any, Tuple - -import torch - -from vescale.dtensor.dtensor import DTensor -from vescale.dtensor.placement_types import Replicate -from vescale.model.random import _set_cuda_rng_state, get_cuda_rng_tracker - - -def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]: - if isinstance(inputs, tuple): - out = [] - for inp in inputs: - if not isinstance(inp, torch.Tensor): - out.append(inp) - continue - - x = inp.detach() - x.requires_grad = inp.requires_grad - out.append(x) - return tuple(out) - else: - raise RuntimeError( - "Only tuple of tensors is supported. Got Unsupported input type: ", - type(inputs).__name__, - ) - - -def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): - """Break a tensor into equal 1D chunks across tensor parallel ranks. - - Returns a Tensor or View with this rank's portion of the data. - - Arguments: - tensor: The tensor to split - - Keyword Arguments: - new_buffer (bool): If True, returns a new Tensor. - If False, returns a view into the existing Tensor. - Default is False - - """ - device_mesh = tensor.device_mesh - partition_size = torch.numel(tensor) // device_mesh.size() - start_index = partition_size * device_mesh.get_rank() - end_index = start_index + partition_size - if new_buffer: - data = torch.empty(partition_size, dtype=tensor.dtype, device=torch.cuda.current_device(), requires_grad=False) - data.copy_(tensor.view(-1)[start_index:end_index]) - else: - data = tensor.view(-1)[start_index:end_index] - return data - - -class CheckpointFunction(torch.autograd.Function): - """This function is adapted from torch.utils.checkpoint with - two main changes: - 1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state` - 2) the states in the model parallel tracker are also properly - tracked/set/reset. - """ - - @staticmethod - def forward(ctx, run_function, distribute_saved_activations, *args): - ctx.run_function = run_function - ctx.distribute_saved_activations = distribute_saved_activations - - # Copy the rng states. - ctx.fwd_cpu_rng_state = torch.get_rng_state() - ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state() - ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - with torch.no_grad(): - outputs = run_function(*args) - - # Divide hidden states across model parallel group and only keep - # the chunk corresponding to the current rank. - if distribute_saved_activations: - ctx.input_0_shape = args[0].data.shape - assert isinstance(args[0].data, DTensor) - args[0].data = split_tensor_into_1d_equal_chunks(args[0].data, new_buffer=True) - - # Store everything. - ctx.save_for_backward(*args) - - return outputs - - @staticmethod - def backward(ctx, *args): - if not torch.autograd._is_checkpoint_valid(): - raise RuntimeError("Checkpointing is not compatible with .grad(), " "please use .backward() if possible") - inputs = ctx.saved_tensors - if ctx.distribute_saved_activations: - assert isinstance(inputs[0].data, DTensor) - inputs[0].data.redistribute(inputs[0].data.device_mesh, [Replicate()]) - - # Store the current states. - bwd_cpu_rng_state = torch.get_rng_state() - bwd_cuda_rng_state = torch.cuda.get_rng_state() - bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states() - - # Set the states to what it used to be before the forward pass. 
- torch.set_rng_state(ctx.fwd_cpu_rng_state) - _set_cuda_rng_state(ctx.fwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker) - - # Compute the forward pass. - detached_inputs = detach_variable(inputs) - with torch.enable_grad(): - outputs = ctx.run_function(*detached_inputs) - - # Set the states back to what it was at the start of this function. - torch.set_rng_state(bwd_cpu_rng_state) - _set_cuda_rng_state(bwd_cuda_rng_state) - get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker) - - if isinstance(outputs, DTensor): - outputs = (outputs,) - torch.autograd.backward(outputs, args) - grads = tuple(inp.grad if isinstance(inp, DTensor) else inp for inp in detached_inputs) - return (None, None) + grads - - -def checkpoint(function, distribute_saved_activations, *args): - """Checkpoint a model or part of the model. - This has been directly copied from torch.utils.checkpoint.""" - return CheckpointFunction.apply(function, distribute_saved_activations, *args) diff --git a/vescale/model/base_gpt/enums.py b/vescale/model/base_gpt/enums.py deleted file mode 100644 index 841dffd..0000000 --- a/vescale/model/base_gpt/enums.py +++ /dev/null @@ -1,27 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -import enum - - -class ModelType(enum.Enum): - encoder_or_decoder = 1 - encoder_and_decoder = 2 - - -class LayerType(enum.Enum): - encoder = 1 - decoder = 2 - - -class AttnType(enum.Enum): - self_attn = 1 - cross_attn = 2 - - -class AttnMaskType(enum.Enum): - padding = 1 - causal = 2 diff --git a/vescale/model/base_gpt/fuse_layer_norm.py b/vescale/model/base_gpt/fuse_layer_norm.py deleted file mode 100644 index e1e5801..0000000 --- a/vescale/model/base_gpt/fuse_layer_norm.py +++ /dev/null @@ -1,119 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
-################################################################################ - -"""This code is copied fron NVIDIA apex: - https://github.com/NVIDIA/apex -with some changes.""" - -import importlib -import numbers - -import torch -from torch import nn -from torch.nn import init -from torch.nn.parameter import Parameter - - -fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") - -try: - from apex.contrib.layer_norm.layer_norm import FastLayerNormFN - - HAVE_PERSIST_LAYER_NORM = True -except ImportError: - HAVE_PERSIST_LAYER_NORM = False - -try: - from apex.normalization.fused_layer_norm import FusedLayerNormAffineFunction - - HAVE_FUSED_LAYER_NORM = True -except ImportError: - HAVE_FUSED_LAYER_NORM = False - - -class MixedFusedLayerNorm(nn.Module): - def __init__( - self, - normalized_shape, - eps=1e-5, - no_persist_layer_norm=True, - param_dtype=torch.float32, - sequence_parallel=False, - apply_layernorm_1p=False, - ): - super().__init__() - - self.apply_layernorm_1p = apply_layernorm_1p - - # List of hiddens sizes supported in the persistent layer norm kernel - # If the hidden size is not supported, fall back to the non-persistent - # kernel. - persist_ln_hidden_sizes = [ - 1024, - 1536, - 2048, - 2304, - 3072, - 3840, - 4096, - 5120, - 6144, - 8192, - 10240, - 12288, - 12800, - 15360, - 16384, - 18432, - 20480, - 24576, - 25600, - 30720, - 32768, - 40960, - 49152, - 65536, - ] - if normalized_shape not in persist_ln_hidden_sizes or not HAVE_PERSIST_LAYER_NORM: - no_persist_layer_norm = True - - if isinstance(normalized_shape, numbers.Integral): - normalized_shape = (normalized_shape,) - self.normalized_shape = torch.Size(normalized_shape) - self.eps = eps - self.weight = Parameter(torch.Tensor(*normalized_shape).to(param_dtype)) - self.bias = Parameter(torch.Tensor(*normalized_shape).to(param_dtype)) - self.reset_parameters() - self.no_persist_layer_norm = no_persist_layer_norm - self.sequence_parallel = sequence_parallel - - def reset_parameters(self): - if self.apply_layernorm_1p: - init.zeros_(self.weight) - init.zeros_(self.bias) - else: - init.ones_(self.weight) - init.zeros_(self.bias) - - def forward(self, input): - weight = self.weight + 1 if self.apply_layernorm_1p else self.weight - - if self.no_persist_layer_norm: - assert ( - FusedLayerNormAffineFunction is not None - ), "FusedLayerNormAffineFunction is not available, please install apex from https://github.com/NVIDIA/apex" - out = FusedLayerNormAffineFunction.apply(input, weight, self.bias, self.normalized_shape, self.eps, False) - return out - else: - output = FastLayerNormFN.apply(input, weight, self.bias, self.eps) - - # Apex's fast layer norm function outputs a 'view' tensor (i.e., has - # a populated '_base' field). This will result in schedule.py's - # deallocate_output_tensor() throwing an error, so a viewless tensor is - # created to prevent this. - # output = make_viewless_tensor( - # inp=output, requires_grad=input.requires_grad, keep_graph=True) - return output diff --git a/vescale/model/base_gpt/fuse_softmax.py b/vescale/model/base_gpt/fuse_softmax.py deleted file mode 100644 index 25f3021..0000000 --- a/vescale/model/base_gpt/fuse_softmax.py +++ /dev/null @@ -1,203 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
-################################################################################ - -import torch -from torch import nn - -from vescale.model.base_gpt.enums import AttnMaskType - - -class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply upper triangular mask (typically used in gpt models). - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_upper_triang_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - softmax_results = scaled_upper_triang_masked_softmax_cuda.forward(inputs, scale_t[0]) - - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_upper_triang_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - input_grads = scaled_upper_triang_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) - - return input_grads, None - - -class ScaledMaskedSoftmax(torch.autograd.Function): - """ - Fused operation which performs following three operations in sequence - 1. Scale the tensor. - 2. Apply the mask. - 3. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, mask, scale): - import scaled_masked_softmax_cuda - - scale_t = torch.tensor([scale]) - softmax_results = scaled_masked_softmax_cuda.forward(inputs, mask, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_masked_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_masked_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class ScaledSoftmax(torch.autograd.Function): - """ - Fused operation which performs following two operations in sequence - 1. Scale the tensor. - 2. Perform softmax. - """ - - @staticmethod - def forward(ctx, inputs, scale): - import scaled_softmax_cuda - - scale_t = torch.tensor([scale]) - - softmax_results = scaled_softmax_cuda.forward(inputs, scale_t[0]) - ctx.save_for_backward(softmax_results, scale_t) - return softmax_results - - @staticmethod - def backward(ctx, output_grads): - import scaled_softmax_cuda - - softmax_results, scale_t = ctx.saved_tensors - - input_grads = scaled_softmax_cuda.backward(output_grads, softmax_results, scale_t[0]) - return input_grads, None, None - - -class FusedScaleMaskSoftmax(nn.Module): - """ - fused operation: scaling + mask + softmax - - Arguments: - input_in_fp16: flag to indicate if input in fp16 data format. - input_in_bf16: flag to indicate if input in bf16 data format. - attn_mask_type: attention mask type (pad or causal) - scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion - mask_func: mask function to be applied. - softmax_in_fp32: if true, softmax in performed at fp32 precision. - scale: scaling factor used in input tensor scaling. - """ - - def __init__( - self, - input_in_fp16, - input_in_bf16, - attn_mask_type, - scaled_masked_softmax_fusion, - mask_func, - softmax_in_fp32, - scale, - ): - super().__init__() - self.input_in_fp16 = input_in_fp16 - self.input_in_bf16 = input_in_bf16 - assert not ( - self.input_in_fp16 and self.input_in_bf16 - ), "both fp16 and bf16 flags cannot be active at the same time." 
- self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 - self.attn_mask_type = attn_mask_type - self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion - self.mask_func = mask_func - self.softmax_in_fp32 = softmax_in_fp32 - self.scale = scale - - assert self.scale is None or softmax_in_fp32, "softmax should be in fp32 when scaled" - - def forward(self, input, mask): - # [b, np, sq, sk] - assert input.dim() == 4 - - if self.is_kernel_available(mask, *input.size()): - return self.forward_fused_softmax(input, mask) - else: - return self.forward_torch_softmax(input, mask) - - def is_kernel_available(self, mask, b, np, sq, sk): - attn_batches = b * np - - if ( - self.scaled_masked_softmax_fusion # user want to fuse - and self.input_in_float16 # input must be fp16 - and 16 < sk <= 4096 # sk must be 16 ~ 2048 - and sq % 4 == 0 # sq must be divisor of 4 - and sk % 4 == 0 # sk must be divisor of 4 - and attn_batches % 4 == 0 # np * b must be divisor of 4 - ): - if 0 <= sk <= 4096: - batch_per_block = self.get_batch_per_block(sq, sk, b, np) - - if self.attn_mask_type == AttnMaskType.causal: - if attn_batches % batch_per_block == 0: - return True - else: - if sq % batch_per_block == 0: - return True - return False - - def forward_fused_softmax(self, input, mask): - b, np, sq, sk = input.size() - scale = self.scale if self.scale is not None else 1.0 - - if self.attn_mask_type == AttnMaskType.causal: - assert sq == sk, "causal mask is only for self attention" - # input is 3D tensor (attn_batches, sq, sk) - input = input.view(-1, sq, sk) - probs = ScaledUpperTriangMaskedSoftmax.apply(input, scale) - return probs.view(b, np, sq, sk) - else: - # input is 4D tensor (b, np, sq, sk) - if mask is not None: - return ScaledMaskedSoftmax.apply(input, mask, scale) - else: - return ScaledSoftmax.apply(input, scale) - - def forward_torch_softmax(self, input, mask): - if self.input_in_float16 and self.softmax_in_fp32: - input = input.float() - - if self.scale is not None: - input = input * self.scale - mask_output = self.mask_func(input, mask) if mask is not None else input - probs = torch.nn.Softmax(dim=-1)(mask_output) - - if self.input_in_float16 and self.softmax_in_fp32: - if self.input_in_fp16: - probs = probs.half() - else: - probs = probs.bfloat16() - - return probs - - @staticmethod - def get_batch_per_block(sq, sk, b, np): - import scaled_masked_softmax_cuda - - return scaled_masked_softmax_cuda.get_batch_per_block(sq, sk, b, np) diff --git a/vescale/model/base_gpt/jit_func.py b/vescale/model/base_gpt/jit_func.py deleted file mode 100644 index c129688..0000000 --- a/vescale/model/base_gpt/jit_func.py +++ /dev/null @@ -1,40 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
-################################################################################ - -from typing import Optional - -import torch - - -def bias_dropout_add(x, bias, residual, prob, training): - # type: (torch.Tensor, Optional[torch.Tensor], torch.Tensor, float, bool) -> torch.Tensor - if bias is not None: - x = x + bias - out = torch.nn.functional.dropout(x, p=prob, training=training) - out = residual + out - return out - - -def get_bias_dropout_add(training): - def _bias_dropout_add(x, bias, residual, prob): - return bias_dropout_add(x, bias, residual, prob, training) - - return _bias_dropout_add - - -@torch.compile -def bias_dropout_add_fused_train( - x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float -) -> torch.Tensor: - return bias_dropout_add(x, bias, residual, prob, True) - - -# @torch.jit.script -@torch.compile -def bias_dropout_add_fused_inference( - x: torch.Tensor, bias: Optional[torch.Tensor], residual: torch.Tensor, prob: float -) -> torch.Tensor: - return bias_dropout_add(x, bias, residual, prob, False) diff --git a/vescale/model/base_gpt/mlp.py b/vescale/model/base_gpt/mlp.py deleted file mode 100644 index f2c33fc..0000000 --- a/vescale/model/base_gpt/mlp.py +++ /dev/null @@ -1,101 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -import torch -from torch import nn - -from vescale.model.utils import bias_gelu_impl, openai_gelu - - -class SwitchMLP(nn.Module): - """ - Routes input to one of N MLP "experts" - """ - - def __init__(self, hidden_size, num_experts): - super().__init__() - self.router = nn.Linear(hidden_size, num_experts) - self.experts = torch.nn.ModuleList() - for _ in range(num_experts): - self.experts.append(ParallelMLP(hidden_size)) - - def forward(self, hidden_states): - # hidden_states: [s, b, h] - s = hidden_states.size(0) - b = hidden_states.size(1) - h = hidden_states.size(2) - route = self.router(hidden_states) - route = torch.nn.functional.softmax(route, dim=2) - max_prob, max_ind = torch.max(route, dim=2) - max_prob = torch.unsqueeze(max_prob, 2) # [s b 1] - - # TODO (rprenger) TODO this could be made easier to read - # Converting [s, b, h] to [s*b, h]. - # Each vector could be routed differently - # [s*b h] - hidden_states = hidden_states.view(-1, hidden_states.size(2)) - max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1] - max_ind = max_ind.view(-1) # [s*b] - - output_total = torch.empty_like(hidden_states) - output_bias_total = torch.empty_like(hidden_states) - # TODO (rprenger) This does each expert in serial, but it could be parallelized - - for expert_num, expert in enumerate(self.experts): - local_indices = (max_ind == expert_num).nonzero() - hidden = hidden_states[local_indices, :] - output, output_bias = expert(hidden) - output_bias = output_bias.expand_as(output) - output_total[local_indices, :] = output - output_bias_total[local_indices, :] = output_bias - - output_total = output_total * max_prob - output_bias_total = output_bias_total * max_prob - output_total = output_total.view(s, b, h) - output_bias_total = output_bias_total.view(s, b, h) - - return output_total, output_bias_total - - -class ParallelMLP(nn.Module): - """MLP. 
- - MLP will take the input with h hidden state, project it to 4*h - hidden dimension, perform nonlinear transformation, and project the - state back into h hidden dimension. - """ - - def __init__(self, h, param_dtype=torch.float32, bias_gelu_fusion=None): - super().__init__() - - # Project to 4h. - self.dense_h_to_4h = nn.Linear(h, h * 4, bias=False, dtype=param_dtype) - # torch.nn.init.normal_(self.dense_h_to_4h.weight, mean=0.0, std=0.02) - torch.nn.init.xavier_normal_(self.dense_h_to_4h.weight) - self.dense_h_to_4h_bias = nn.Parameter(torch.zeros(4 * h, dtype=param_dtype)) - - self.bias_gelu_fusion = bias_gelu_fusion - self.activation_func = openai_gelu - - # Project back to h. - self.dense_4h_to_h = nn.Linear(4 * h, h, bias=False, dtype=param_dtype) - torch.nn.init.xavier_uniform_(self.dense_4h_to_h.weight) - # torch.nn.init.normal_(self.dense_4h_to_h.weight, mean=0.0, std=0.02) - self.dense_4h_to_h_bias = nn.Parameter(torch.zeros(h, dtype=param_dtype)) - - def forward(self, hidden_states): - intermediate_parallel = self.dense_h_to_4h(hidden_states) - bias_parallel = self.dense_h_to_4h_bias - - if self.bias_gelu_fusion: - intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel) - else: - intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) - - # [s, b, h] - output = self.dense_4h_to_h(intermediate_parallel) - output_bias = self.dense_4h_to_h_bias - return output, output_bias diff --git a/vescale/model/base_gpt/rotary.py b/vescale/model/base_gpt/rotary.py deleted file mode 100644 index eaa8d76..0000000 --- a/vescale/model/base_gpt/rotary.py +++ /dev/null @@ -1,52 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -from typing import Union - -import torch -from torch import Tensor - -from vescale.dtensor.dtensor import DTensor - - -def _rotate_half(x: Union[Tensor, DTensor]) -> Union[Tensor, DTensor]: - """Change sign so the last dimension becomes [-odd, +even] - - Args: - x (Tensor): Input tensor - - Returns: - Tensor: Tensor rotated half - """ - - x1, x2 = torch.chunk(x, 2, dim=-1) - return torch.cat((-x2, x1), dim=-1) - - -def apply_rotary_pos_emb(t: Union[Tensor, DTensor], freqs: Union[Tensor, DTensor]) -> Union[Tensor, DTensor]: - """Apply rotary positional embedding to input tensor T. - - check https://kexue.fm/archives/8265 for detailed formulas - - Args: - t (Tensor): Input tensor T is of shape [seq_length, ... 
, dim] - freqs (Tensor): Rotary Positional embedding tensor freq is of shape [seq_length, ..., dim] - - Returns: - Tensor: The input tensor after applying RoPE - """ - rot_dim = freqs.shape[-1] - - # ideally t_pass is empty so rotary pos embedding is applied to all tensor t - t, t_pass = t.narrow(-1, 0, rot_dim), t.narrow(-1, rot_dim, max(t.size()[-1] - rot_dim, 0)) - - # first part is cosine component - # second part is sine component, need to change signs with _rotate_half method - cos_ = torch.cos(freqs).to(t.dtype) - sin_ = torch.sin(freqs).to(t.dtype) - - t = (t * cos_) + (_rotate_half(t) * sin_) - return torch.cat((t, t_pass), dim=-1) diff --git a/vescale/model/base_gpt/transformer_block.py b/vescale/model/base_gpt/transformer_block.py deleted file mode 100644 index a2c09be..0000000 --- a/vescale/model/base_gpt/transformer_block.py +++ /dev/null @@ -1,135 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -import torch -from torch import nn -from contextlib import nullcontext -from vescale.dtensor.dtensor import DTensor -from vescale.initialize.deferred_init import deferred_init -from vescale.model.base_gpt.transformer_layer import ParallelTransformerLayer -from vescale.model.random import get_cuda_rng_tracker - - -class TransformerBlock(nn.Module): - """Transformer class.""" - - def __init__( - self, - num_layer, - args, - drop_path_rate=0.0, - pre_process=True, - deferred_init=False, - ): - super().__init__() - - self.config = args - self.drop_path_rate = drop_path_rate - self.pre_process = pre_process - self.num_layer = num_layer - self.deferred_init = deferred_init - - # required for pipeline parallel schedules - self.input_tensor = None - self._build_layers() - - def _build_layers(self): - # Transformer layers. - # @jcasper can we improve how we deal with layer_number? - # currently it's only used in CoreAttention? - # if self.apply_query_key_layer_scaling: - # coeff = self.layer_number - # self.norm_factor *= coeff - def build_layer(layer_number): - if self.deferred_init: - layer_config = { - "init_method": self.config.init_method, - "output_layer_init_method": self.config.output_layer_init_method, - "layer_number": layer_number, - "args": self.config, - "drop_path_rate": self.drop_path_rate, - } - layer = deferred_init(ParallelTransformerLayer, **layer_config) - else: - layer = ParallelTransformerLayer( - self.config.init_method, - self.config.output_layer_init_method, - layer_number, - self.config, - self.drop_path_rate, - ) - - return layer - - # offset is implicit in TransformerLayer - self.transformer_layers = torch.nn.ModuleList([build_layer(i + 1) for i in range(self.num_layer)]) - self.layers = torch.nn.Sequential() - for i in range(len(self.transformer_layers)): - self.layers.append(self.transformer_layers[i]) - - def _get_layer(self, layer_number): - return self.transformer_layers[layer_number] - - def set_input_tensor(self, input_tensor): - """Set input tensor to be used instead of forward()'s input. - - When doing pipeline parallelism the input from the previous - stage comes from communication, not from the input, so the - model's forward_step_func won't have it. 
This function is thus - used by internal code to bypass the input provided by the - forward_step_func""" - self.input_tensor = input_tensor - - def forward( - self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None - ): - # hidden_states (float): [s, b, h] - # attention_mask (bool): [1, 1, s, s] - - if not self.pre_process: - # See set_input_tensor() - hidden_states = self.input_tensor - - # Viewless tensor. - # - We only need to create a viewless tensor in the case of micro batch - # size (mbs) == 1, since in this case, 'hidden_states.transpose()' - # above creates a view tensor, and '.contiguous()' is a pass-through. - # For mbs >= 2, '.contiguous()' creates a new tensor, eliminating - # the need to make it viewless. - # - # However, we don't explicitly check mbs == 1 here because - # make_viewless_tensor() has negligible overhead when its input - # is already viewless. - # - # - For the 'else' case above, calling make_viewless_tensor() here is - # likely redundant, since p2p_communication.py (likely originator) - # already creates viewless tensors. That said, make_viewless_tensor() - # is called here to be future-proof and corner-case-proof. - # hidden_states = make_viewless_tensor( - # inp=hidden_states, - # requires_grad=True, - # keep_graph=True, - # ) - - rng_context = nullcontext() - if isinstance(hidden_states, DTensor): - placements = hidden_states.placements - # check sbh, for s - is_sp = any(placement.is_shard(dim=0) for placement in placements) - if is_sp: - rng_context = get_cuda_rng_tracker().fork() - - with rng_context: - for layer in self.transformer_layers: - hidden_states = layer( - hidden_states=hidden_states, - attention_mask=attention_mask, - encoder_output=encoder_output, - enc_dec_attn_mask=enc_dec_attn_mask, - inference_params=inference_params, - ) - - return hidden_states diff --git a/vescale/model/base_gpt/transformer_layer.py b/vescale/model/base_gpt/transformer_layer.py deleted file mode 100644 index f9931d1..0000000 --- a/vescale/model/base_gpt/transformer_layer.py +++ /dev/null @@ -1,194 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. -################################################################################ - -from contextlib import nullcontext - -import torch -from torch import nn - -from vescale.model.attention.dmodule_parallel_attention import ParallelAttention -from vescale.model.base_gpt.fuse_layer_norm import MixedFusedLayerNorm as LayerNorm -from vescale.model.base_gpt.jit_func import ( - bias_dropout_add_fused_inference, - bias_dropout_add_fused_train, - get_bias_dropout_add, -) -from vescale.model.base_gpt.mlp import ParallelMLP, SwitchMLP - - -class DropPath(nn.Module): - """Drop paths (Stochastic Depth) per sample - (when applied in main path of residual blocks). 
- """ - - def __init__(self, drop_prob=0.0): - super().__init__() - self.drop_prob = drop_prob - - def forward(self, hidden_state): - if self.drop_prob == 0.0 or not self.training: - return hidden_state - keep_prob = 1 - self.drop_prob - # work with diff dim tensors, not just 2D ConvNets - # hidden_state: [s, b, h] - random_tensor = keep_prob + torch.rand_like(hidden_state) - random_tensor.floor_() # binarize - output = hidden_state.div(keep_prob) * random_tensor - return output - - -class ParallelTransformerLayer(nn.Module): - """A single transformer layer. - - Transformer layer takes input with size [s, b, h] and returns an - output of the same size. - """ - - def __init__( - self, - init_method, - output_layer_init_method, - layer_number, - args, - drop_path_rate=0.0, - ): - super().__init__() - self.layer_number = layer_number - - self.apply_residual_connection_post_layernorm = args.apply_residual_connection_post_layernorm - - self.bf16 = args.bf16 - - # Layernorm on the input data. - self.input_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - no_persist_layer_norm=not args.persist_layer_norm, - param_dtype=args.param_dtype, - ) - - # Self attention. - self.self_attention = ParallelAttention( - args.hidden_size, - args.kv_channels, - args.num_attention_heads, - args.world_size, - 1, # n_shared_qhead - args.param_dtype, - ) - self.hidden_dropout = args.hidden_dropout - self.bias_dropout_fusion = args.bias_dropout_fusion - self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None - - # Layernorm on the attention output - self.post_attention_layernorm = LayerNorm( - args.hidden_size, - eps=args.layernorm_epsilon, - no_persist_layer_norm=not args.persist_layer_norm, - param_dtype=args.param_dtype, - ) - - # MLP - if args.num_experts is not None: - self.mlp = SwitchMLP(init_method, output_layer_init_method, args) - else: - self.mlp = ParallelMLP(args.hidden_size, param_dtype=args.param_dtype) - - # Set bias+dropout+add fusion grad_enable execution handler. - TORCH_MAJOR = int(torch.__version__.split(".")[0]) - TORCH_MINOR = int(torch.__version__.split(".")[1]) - use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10) - self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad - - def forward( - self, hidden_states, attention_mask, encoder_output=None, enc_dec_attn_mask=None, inference_params=None - ): - # hidden_states: [s, b, h] - - # Layer norm at the beginning of the transformer layer. - layernorm_output = self.input_layernorm(hidden_states) - - # Self attention. - attention_output, attention_bias = self.self_attention( - layernorm_output, attention_mask, inference_params=inference_params - ) - - # assert not torch.isnan(attention_output.to_local() - # ).any(), attention_output - # assert not torch.isnan(attention_bias.to_local()).any(), attention_bias - - # Residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = hidden_states - - if self.drop_path is None: - # jit scripting for a nn.module (with dropout) is not - # trigerring the fusion kernel. For now, we use two - # different nn.functional routines to account for varying - # dropout semantics during training and inference phases. 
- if self.bias_dropout_fusion: - if self.training: - bias_dropout_add_func = bias_dropout_add_fused_train - else: - bias_dropout_add_func = bias_dropout_add_fused_inference - else: - bias_dropout_add_func = get_bias_dropout_add(self.training) - - with self.bias_dropout_add_exec_handler(): - layernorm_input = bias_dropout_add_func( - attention_output, attention_bias.expand_as(residual), residual, self.hidden_dropout - ) - else: - out = attention_output + attention_bias - out = torch.nn.functional.dropout(out, p=self.hidden_dropout, training=self.training) - layernorm_input = residual + self.drop_path(out) - - # Layer norm post the self attention. - layernorm_output = self.post_attention_layernorm(layernorm_input) - # assert not torch.isnan(layernorm_output).any() - - # MLP. - mlp_output, mlp_bias = self.mlp(layernorm_output) - # assert not torch.isnan(mlp_output.to_local()).any(), mlp_output - # assert not torch.isnan(mlp_bias.to_local()).any(), mlp_bias - - # Second residual connection. - if self.apply_residual_connection_post_layernorm: - residual = layernorm_output - else: - residual = layernorm_input - - if self.drop_path is None: - with self.bias_dropout_add_exec_handler(): - output = bias_dropout_add_func(mlp_output, mlp_bias.expand_as(residual), residual, self.hidden_dropout) - - # Jit compiled function creates 'view' tensor. This tensor - # potentially gets saved in the MPU checkpoint function context, - # which rejects view tensors. While making a viewless tensor here - # won't result in memory savings (like the data loader, or - # p2p_communication), it serves to document the origin of this - # 'view' tensor. - # output = dtensor.utils.make_viewless_tensor(inp=output, requires_grad=output.requires_grad, keep_graph=True) - # - else: - out = mlp_output + mlp_bias - out = torch.nn.functional.dropout(out, p=self.hidden_dropout, training=self.training) - output = residual + self.drop_path(out) - - return output - - def forward_util(self, input_tensor, data): - ret = { - "hidden_states": input_tensor if input_tensor is not None else data["hidden_states"], - "attention_mask": data["attention_mask"], - } - return [ret["hidden_states"], ret["attention_mask"]] - - def output_utils(self, p2p_tensor): - p2p_tensor = torch.permute(p2p_tensor, (0, 2, 1)) - return p2p_tensor diff --git a/vescale/model/base_gpt/utils.py b/vescale/model/base_gpt/utils.py deleted file mode 100644 index 3a67817..0000000 --- a/vescale/model/base_gpt/utils.py +++ /dev/null @@ -1,27 +0,0 @@ -################################################################################ -# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved. -################################################################################ -# Modification Copyright 2023 ByteDance Ltd. and/or its affiliates. 
-################################################################################ - -import functools -from typing import Callable - -from optree import tree_map -from vescale.dtensor.dtensor import DTensor - - -def switch_dtensor(func: Callable): - @functools.wraps(func) - def wrap(*args, **kwargs): - def to_tensor(x): - if isinstance(x, DTensor): - return x.to_local() - return x - - new_args = tree_map(to_tensor, args) - new_kwargs = tree_map(to_tensor, kwargs) - out = func(*new_args, **new_kwargs) - return out - - return wrap diff --git a/vescale/pipe/_schedules/instruction_base.py b/vescale/pipe/_schedules/instruction_base.py index d43474e..5312f72 100644 --- a/vescale/pipe/_schedules/instruction_base.py +++ b/vescale/pipe/_schedules/instruction_base.py @@ -27,7 +27,10 @@ from vescale.pipe.pipe_stage import PipeModule from typing import List, Tuple, Union, Optional, Dict, Any import logging +import functools import numpy as np +from optree import tree_map +from vescale.dtensor.dtensor import DTensor from vescale.plan.spec import PipelineP2PSpec Shape = Union[List[int], torch.Size] @@ -36,6 +39,22 @@ registed_functions = {} +def switch_dtensor(func: Callable): + @functools.wraps(func) + def wrap(*args, **kwargs): + def to_tensor(x): + if isinstance(x, DTensor): + return x.to_local() + return x + + new_args = tree_map(to_tensor, args) + new_kwargs = tree_map(to_tensor, kwargs) + out = func(*new_args, **new_kwargs) + return out + + return wrap + + def register_instruction(name): assert name is not None, "The Instruction must have name" if name in registed_functions: diff --git a/vescale/pipe/_schedules/looping_bfs.py b/vescale/pipe/_schedules/looping_bfs.py index 4d0b6e6..0d68cf1 100644 --- a/vescale/pipe/_schedules/looping_bfs.py +++ b/vescale/pipe/_schedules/looping_bfs.py @@ -26,6 +26,7 @@ VESCALE_INTRUCTION_BUILDER as builder, register_instruction, registed_functions, + switch_dtensor, ) import contextlib from dataclasses import dataclass, field @@ -46,7 +47,6 @@ send_forward_recv_forward, send_backward_recv_backward, ) -from vescale.model.base_gpt.utils import switch_dtensor @dataclass diff --git a/vescale/pipe/_schedules/zero_bubble_v.py b/vescale/pipe/_schedules/zero_bubble_v.py index 294a806..51a633a 100644 --- a/vescale/pipe/_schedules/zero_bubble_v.py +++ b/vescale/pipe/_schedules/zero_bubble_v.py @@ -31,6 +31,7 @@ Shape, registed_functions, VESCALE_INTRUCTION_BUILDER as builder, + switch_dtensor, ) from vescale.pipe.p2p_communication import ( recv_backward, @@ -47,7 +48,6 @@ from vescale.dtensor.placement_types import Placement from vescale.dtensor._utils import compute_global_tensor_info from torch.distributed.distributed_c10d import _get_default_group -from vescale.model.base_gpt.utils import switch_dtensor import logging diff --git a/vescale/pipe/pipe_stage.py b/vescale/pipe/pipe_stage.py index 9e91f56..75c955d 100644 --- a/vescale/pipe/pipe_stage.py +++ b/vescale/pipe/pipe_stage.py @@ -69,7 +69,7 @@ def __init__( lr_scheduler: Callable, stage_deps: np.ndarray, p2p_index_mapping: Dict, - config: PipelineParallelPlan, + plan: PipelineParallelPlan, ): super().__init__() self.stage_modules = {} @@ -84,9 +84,9 @@ def __init__( self.sync_chunk_ids = set() self.shared_path_this_stage = {} self.shared_module_mapping = {} - self.config = config - self.num_stages = self.config.num_stages - self.virtual_chunks = self.config.virtual_chunks + self.plan = plan + self.num_stages = self.plan.num_stages + self.virtual_chunks = self.plan.virtual_chunks self.stage_deps = stage_deps 
self.p2p_index_mapping = p2p_index_mapping
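
Note (illustrative, not part of the patch): after this change, `switch_dtensor` no longer lives in the deleted `vescale/model/base_gpt/utils.py` but in `vescale/pipe/_schedules/instruction_base.py`, which is where `looping_bfs.py` and `zero_bubble_v.py` now import it from. The sketch below shows the intended behavior of the decorator as defined in the patch — unwrapping any `DTensor` argument to its local shard via `optree.tree_map` before the wrapped function runs. It is a minimal sketch, assuming torch, optree, and veScale are installed; the plain-tensor call is used only because constructing a real `DTensor` needs an initialized device mesh, and `my_loss` is a hypothetical function name, not an API from the repository.

    # Illustrative sketch only -- not part of the patch.
    import torch
    from vescale.pipe._schedules.instruction_base import switch_dtensor  # new home per this patch

    @switch_dtensor
    def my_loss(logits, target):  # hypothetical example function
        # By the time this body runs, any DTensor positional or keyword argument
        # has already been replaced by its local shard (x.to_local()), so plain
        # torch ops are safe here.
        return torch.nn.functional.cross_entropy(logits, target)

    # With regular tensors the decorator is a pass-through:
    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))
    print(my_loss(logits, target))

    # With DTensor inputs (e.g., outputs of a parallelized module on an
    # initialized device mesh), the wrapper maps .to_local() over every
    # DTensor argument before invoking my_loss.

Keeping the helper next to the instruction registry avoids the schedule runtime depending on the removed `vescale/model/base_gpt` package, which this patch deletes in full.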