From 513dfbe820d99106aab8282b201656b3c4df3811 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Sat, 11 Jun 2022 16:23:06 +0900 Subject: [PATCH 01/37] add deparallelize_functions --- .../_parallel_2p5d/_wrapper.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index ac0662e8..538451ed 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -1,5 +1,3 @@ -from typing import Dict - import torch import torch.nn as nn @@ -447,3 +445,22 @@ def _slice_embedding(self, module): if isinstance(module, nn.Embedding): module.__class__ = Embedding2p5D + + @torch.no_grad() + def _deparallelize(self): + self._deparallelize_layernorm() + self._deparallelize_linear() + self._deparallelize_embedding() + self._rollback_mp_arguments() + + def _rollback_mp_arguments(self): + pass + + def _deparallelize_embedding(self): + pass + + def _deparallelize_linear(self): + pass + + def _deparallelize_layernorm(self): + pass From 7e318eae2962c949d8436a2ec494d5ee2fba5f98 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Sat, 11 Jun 2022 20:50:58 +0900 Subject: [PATCH 02/37] implement gather_linear function --- .../_parallel_2p5d/_wrapper.py | 117 +++++++++++++++++- 1 file changed, 114 insertions(+), 3 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 295b022d..135cb319 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -12,7 +12,11 @@ ) from oslo.torch.nn.modules.linear import Linear2p5D from oslo.torch.nn.modules.layer_norm import LayerNorm2p5D -from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import split_batch_2p5d +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( + split_batch_2p5d, + gather_2d, + gather_1d +) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, @@ -452,13 +456,120 @@ def _deparallelize(self): self._rollback_mp_arguments() def _rollback_mp_arguments(self): - pass + for module in self.module.modules(): + for elem in self.tensor_parallel_mapping.update_attrs(self.module): + if hasattr(module, elem.name): + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) + reduced_arg = getattr(module, elem.name) * tesseract_dim + setattr(module, elem.name, reduced_arg) def _deparallelize_embedding(self): pass def _deparallelize_linear(self): + # for param_name, module in self.module.named_modules(): + # if self.tensor_parallel_mapping.is_column_parallel( + # self.module, param_name + # ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + # self._slice_linear( + # module=module, + # reversed=self.tensor_parallel_mapping.is_reversed_param( + # self.module, param_name + # ), + # fusion_degree=self.tensor_parallel_mapping.get_combined_qkv_degree( + # self.module, param_name, module + # ), + # slice_bias=True, + # ) + # module.__class__ = Linear2p5D + # TODO: gather this logics + for param_name, module in self.module.named_modules(): + if module.__class__ == Linear2p5D: + self._gather_linear(module) + module.__class__ = nn.Linear + + def _deparallelize_layernorm(self, module): pass - def 
_deparallelize_layernorm(self): + def _gather_embedding(self, module): pass + + def _gather_linear(self, module: Linear2p5D): + # , reversed, fusion_degree, slice_bias + is_reversed = module.reversed + fusion_degree = module.fusion_degree + slice_bias = module.slice_bias + + tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + + if fusion_degree > 1: + w = self._reconstruct_combined_qkv(self.weight.data, tesseract_dim, fusion_degree, False) + else: + w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim=tesseract_dim, col_first=True) + + if is_reversed: + w = w.t() + module.weight.data = w + + if hasattr(module, "bias") and module.bias is not None: + if slice_bias is True and module.bias.dim() >= 1: + if fusion_degree > 1: + b = self._reconstruct_combined_qkv(self.bias.data, tesseract_dim, fusion_degree, True) + else: + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + module.bias.data = b + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + row_rank=None, + col_rank=None, + dep_rank=None, + tesseract_dim=None, + data_parallel_rank=None, + pipeline_parallel_rank=None, + tensor_parallel_size=None, + pipeline_parallel_size=None, + reversed=None, + fusion_degree=None, + orig_module=None, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + gather_output=None, + ) + module.__class__ = nn.Linear + + def _gather_layernorm(self, module): + pass + + @staticmethod + def _reconstruct_combined_qkv(tensor, tessearct_dim, fusion_degree, is_bias=False): + tensor = [ + [ + tensor[i][j * tessearct_dim + k] + for i in range(tessearct_dim) + for k in range(tessearct_dim) + ] + for j in range(fusion_degree) + ] + tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) + tensor = [ + [tensor[i * tessearct_dim + j] for j in range(tessearct_dim)] + for i in range(tessearct_dim) + ] + return tensor + + @staticmethod + def _reconstrunct_combined_qkv_bias(tensor, tessearct_dim, fusion_degree): + tensor = [ + [tensor[j * tessearct_dim + k] for k in range(tessearct_dim)] + for j in range(fusion_degree) + ] + tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) + tensor = [tensor[j] for j in range(tessearct_dim)] + return tensor From 586a9041486d70cc715d22367395e9e6272ad4b8 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Sun, 12 Jun 2022 08:54:39 +0900 Subject: [PATCH 03/37] implement deconstrunct_qkv --- .../_parallel_2p5d/_wrapper.py | 26 ++++---- .../_parallel_2p5d/test_qkv.py | 63 +++++++++++++++++++ 2 files changed, 74 insertions(+), 15 deletions(-) create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_qkv.py diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 135cb319..ce1eabf2 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -548,21 +548,17 @@ def _gather_layernorm(self, module): pass @staticmethod - def _reconstruct_combined_qkv(tensor, tessearct_dim, fusion_degree, is_bias=False): - tensor = [ - [ - tensor[i][j * tessearct_dim + k] - for i in range(tessearct_dim) - for k in range(tessearct_dim) - ] - for j in range(fusion_degree) - ] - tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) - tensor = [ - [tensor[i * 
tessearct_dim + j] for j in range(tessearct_dim)] - for i in range(tessearct_dim) - ] - return tensor + def _reconstruct_combined_qkv(tensor, tesseract_dim, fusion_degree, is_bias=False): + last_dim = tensor.size()[-1] + if is_bias is False: + reshaped_w = tensor.view(-1, tesseract_dim, last_dim) + recon_w = torch.cat([ + reshaped_w[:fusion_degree], reshaped_w[fusion_degree:]], 1).view(-1, last_dim).contiguous() + else: + reshaped_w = tensor.view(-1, tesseract_dim) + recon_w = torch.cat([ + reshaped_w[:fusion_degree], reshaped_w[fusion_degree:]], 1).view(-1, last_dim).contiguous() + return recon_w @staticmethod def _reconstrunct_combined_qkv_bias(tensor, tessearct_dim, fusion_degree): diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_qkv.py new file mode 100644 index 00000000..2b5df38b --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_qkv.py @@ -0,0 +1,63 @@ +import torch.nn as nn + +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._wrapper import _TensorParallel2p5D +from oslo.torch.nn import Linear2p5D +from oslo.torch.distributed import ParallelContext, ParallelMode +from copy import deepcopy +from _utils import gather_2d, gather_1d + + +tp_size = 8 +tp_depth = 2 + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +row_rank = parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) +col_rank = parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) +tesseract_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) +fusion_degree = 3 + +linear = nn.Linear(4, fusion_degree*4).cuda() +w = deepcopy(linear.weight.data) +b = deepcopy(linear.bias.data) + +weight_list = w.chunk(tesseract_dim, dim=1) +weight_list = [ + weight.chunk(tesseract_dim * fusion_degree, dim=0) for weight in weight_list +] +bias_list = b.chunk(tesseract_dim * fusion_degree, dim=0) + +# [t][f*t] +weight_list = _TensorParallel2p5D._deconstruct_combined_qkv(weight_list, tesseract_dim, fusion_degree, False) +bias_list = _TensorParallel2p5D._deconstruct_combined_qkv(bias_list, tesseract_dim, fusion_degree, True) +chunked_w = weight_list[row_rank][col_rank] +chunked_b = bias_list[row_rank] + +linear_2p5d = Linear2p5D(4, fusion_degree*4, parallel_context=parallel_context, bias=True) +if parallel_context.get_global_rank() == 0: + print(chunked_w.size()) + print(linear_2p5d.weight.data.size()) +linear_2p5d.weight.data = chunked_w +linear_2p5d.bias.data = chunked_b + +recon_chunked_w = gather_2d(parallel_context, linear_2p5d.weight.data, tesseract_dim, True) +recon_chunked_b = gather_1d(parallel_context, linear_2p5d.bias.data, tesseract_dim, 0) +# reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) +# recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() + +recon_w = _TensorParallel2p5D._reconstruct_combined_qkv(recon_chunked_w, tesseract_dim, fusion_degree, False) +recon_b = _TensorParallel2p5D._reconstruct_combined_qkv(recon_chunked_b, tesseract_dim, fusion_degree, True) + +if parallel_context.get_global_rank() == 0: + print(f"original w: \n{w}\n") + print(f"reconstruct w: \n{recon_w}\n") + + print(f"original b: \n{b}\n") + print(f"reconstruct b: \n{recon_b}\n") \ No newline at end of file From 126c9865cc3b9f38fe8e43fa28c167aebe54afb4 Mon Sep 17 00:00:00 
2001 From: jason960903 Date: Mon, 13 Jun 2022 14:23:44 +0900 Subject: [PATCH 04/37] fixed tp2.5d differ backward --- oslo/torch/nn/modules/linear.py | 12 +- .../tensor_parallel/_parallel_2p5d/_ops.py | 154 ++++++++++++++++-- oslo/torch/nn/parallel/utils.py | 5 + .../_parallel_2p5d/test_linear_2p5d.py | 25 ++- 4 files changed, 172 insertions(+), 24 deletions(-) diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index afb78b73..38767ea5 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -490,12 +490,12 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: ParallelMode.TENSOR_2P5D_COL, ) if self.gather_output: - output = all_gather_tensor_2p5d( - output, - dim=0, - col_parallel_mode=ParallelMode.TENSOR_2P5D_DEP, - parallel_context=self.parallel_context, - ).clone() + # output = all_gather_tensor_2p5d( + # output, + # dim=0, + # col_parallel_mode=ParallelMode.TENSOR_2P5D_DEP, + # parallel_context=self.parallel_context, + # ).clone() output = all_gather_tensor_2p5d( output, dim=0, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py index f018f224..499036f5 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py @@ -160,11 +160,12 @@ def split_batch_2p5d( col_chunked = torch.chunk( inputs, parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL), dim=dim )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL)].contiguous() - return torch.chunk( - col_chunked, - parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_DEP), - dim=dim, - )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP)].contiguous() + return col_chunked + # return torch.chunk( + # col_chunked, + # parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_DEP), + # dim=dim, + # )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP)].contiguous() def get_current_device(): @@ -1209,11 +1210,138 @@ def backward(ctx: Any, output_grad: Tensor): return output_grad, None -# def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor: -# r"""All-reduce the input from the model parallel region. -# Args: -# input_ (:class:`torch.tensor`): input matrix. -# reduce_mean (bool, optional): -# If set to ``True``, it will divide the output by column parallel size, default to False. 
-# """ -# return _ReduceByBatch2p5D.apply(input_, reduce_mean) +def split_batch_2d(parallel_context, tensor, tesseract_dim): + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def split_2p5d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + else: + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = torch.chunk(tensor, tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP) + ] + return tensor + + +def split_2d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + else: + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def split_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor = tensor.chunk(summa_dim, dim=dim)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def gather_2p5d(parallel_context, tensor, tesseract_dim, col_first=True): + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, tensor, parallel_context.get_group(ParallelMode.TENSOR_2P5D_DEP) + ) + tensor = torch.cat(tensor_list, dim=0) + if col_first: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_2d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list 
= [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor diff --git a/oslo/torch/nn/parallel/utils.py b/oslo/torch/nn/parallel/utils.py index c3414a1c..a74ecfe7 100644 --- a/oslo/torch/nn/parallel/utils.py +++ b/oslo/torch/nn/parallel/utils.py @@ -65,6 +65,11 @@ def _update_module_arguments(module: nn.Module, **kwargs): setattr(module, k, v) +def _remove_module_arguments(module: nn.Module, args: list): + for k in args: + delattr(module, k) + + def allocate_params(model: nn.Module, parallel_context: ParallelContext): for name, parameter in model.named_parameters(): if hasattr(parameter, "oslo_parallel"): diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py index eba931e4..b5a866a8 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py @@ -4,7 +4,7 @@ from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn import Linear2p5D -from _utils import split_2p5d, split_2d, gather_2p5d +from _utils import split_2p5d, split_2d, gather_2p5d, gather_2d from copy import deepcopy @@ -47,8 +47,8 @@ # split input_ into # 0:[0, 0, 0], 1:[0, 0, 1], 2:[0, 1, 0], 3:[0, 1, 1], 4:[1, 0, 0], 5:[1, 0, 1], 6:[1, 1, 0], 7:[1, 1, 1] # input shape: (m/dq, n/q) -input_ = split_2p5d(parallel_context, input_, tesseract_dim) -ptarget = split_2p5d(parallel_context, target, tesseract_dim) +input_ = split_2d(parallel_context, input_, tesseract_dim) +ptarget = split_2d(parallel_context, target, tesseract_dim) # split weight into 0,4:[0, 0], 1,5:[1, 0], 2,6:[0, 1], 3,7:[1, 1] # input shape: (n/q, k/q) @@ -71,8 +71,8 @@ pout_update = linear_2p5d(input_) -pout = gather_2p5d(parallel_context, pout, tesseract_dim, False) -pout_update = gather_2p5d(parallel_context, pout_update, tesseract_dim, False) +pout = gather_2d(parallel_context, pout, tesseract_dim, False) +pout_update = gather_2d(parallel_context, pout_update, tesseract_dim, False) # w = gather_2d(parallel_context, linear_2p5d.weight.data, tesseract_dim, True) # b = gather_1d(parallel_context, linear_2p5d.bias.data, tesseract_dim, 0) @@ -84,8 +84,14 @@ if parallel_context.get_global_rank() == 0: sse = torch.sum((out - pout) ** 2).item() sse_update = torch.sum((out_update - pout_update) ** 2).item() + minmax_update = (out_update - pout_update) ** 2 print(f"output sse: \n{sse}\n") print(f"next output sse: \n{sse_update}\n") + print(f"next output max: \n{minmax_update.max()}\n") + print(f"next output min: \n{minmax_update.min()}\n") + + + linear_2p5d = Linear2p5D(4, 4, gather_output=True, parallel_context=parallel_context) linear_2p5d.weight.data.copy_(w) @@ -103,8 +109,17 @@ print(f"parallel output (gather_output=True): 
\n{pout}\n") print(f"parallel update output (gather_output=True): \n{pout_update}\n") + if parallel_context.get_global_rank() == 0: sse = torch.sum((out - pout) ** 2).item() sse_update = torch.sum((out_update - pout_update) ** 2).item() + minmax_update = (out_update - pout_update) ** 2 print(f"output sse (gather_output=True): \n{sse}\n") print(f"next output sse (gather_output=True): \n{sse_update}\n") + import pprint + # top5 = torch.clamp(minmax_update.flatten(), 1e-8) + top5 = minmax_update.flatten() + top5 = [t.item() for t in top5] + top5 = [top5[i:i+4] for i in range(0, len(top5), 4)] + pprint.pprint(top5) + print(f"next output min: \n{minmax_update.min()}\n") From 20265c44bc2f5ecdceac950e26a8051baf6c997b Mon Sep 17 00:00:00 2001 From: jason960903 Date: Tue, 14 Jun 2022 08:39:13 +0900 Subject: [PATCH 05/37] add tp2.5d interface for save pretrained and deparallelization --- oslo/torch/nn/modules/linear.py | 6 - .../tensor_parallel/_parallel_1d/_wrapper.py | 10 + .../tensor_parallel/_parallel_2d/_wrapper.py | 11 + .../_parallel_2p5d/_wrapper.py | 277 +++++++++++++++--- .../tensor_parallel/tensor_parallel.py | 44 ++- .../_parallel_2p5d/deparallel/__init__.py | 3 + .../{ => deparallel}/test_qkv.py | 3 +- .../_parallel_2p5d/test_wrapper_2p5d.py | 50 +++- 8 files changed, 337 insertions(+), 67 deletions(-) create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/__init__.py rename tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/{ => deparallel}/test_qkv.py (98%) diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index 38767ea5..d174af8e 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -490,12 +490,6 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: ParallelMode.TENSOR_2P5D_COL, ) if self.gather_output: - # output = all_gather_tensor_2p5d( - # output, - # dim=0, - # col_parallel_mode=ParallelMode.TENSOR_2P5D_DEP, - # parallel_context=self.parallel_context, - # ).clone() output = all_gather_tensor_2p5d( output, dim=0, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 45b6a425..5fcc2b88 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -42,6 +42,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None ): super().__init__() self.module = module @@ -56,6 +57,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." 
+ ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 945886f4..b14ddf59 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -48,6 +48,8 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None + ): super().__init__() self.module = module @@ -62,6 +64,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." + ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index ce1eabf2..fd252c65 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -1,8 +1,12 @@ import copy +import os import torch import torch.nn as nn +from typing import Union, Optional, Callable +from logging import getLogger + from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn.modules.embedding import ( @@ -33,7 +37,7 @@ ) from oslo.transformers.constants import BATCH_DIMENSIONS - +from transformers import AutoConfig class _TensorParallel2p5D(ParallelWrapper): """ @@ -49,9 +53,11 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None ): super().__init__() self.module = module + self.config = self.module.config self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -63,6 +69,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." 
+ ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() @@ -449,7 +464,7 @@ def _slice_head(self, module, reversed): module.__class__ = Linear2p5D @torch.no_grad() - def _deparallelize(self): + def deparallelize(self): self._deparallelize_layernorm() self._deparallelize_linear() self._deparallelize_embedding() @@ -462,90 +477,138 @@ def _rollback_mp_arguments(self): tesseract_dim = self.parallel_context.get_world_size( ParallelMode.TENSOR_2P5D_COL ) - reduced_arg = getattr(module, elem.name) * tesseract_dim - setattr(module, elem.name, reduced_arg) + expanded_arg = getattr(module, elem.name) * tesseract_dim + setattr(module, elem.name, expanded_arg) def _deparallelize_embedding(self): - pass + for param_name, module in self.module.named_modules(): + if module.__class__ == VocabParallelEmbedding2p5D: + self._gather_embedding(module) def _deparallelize_linear(self): - # for param_name, module in self.module.named_modules(): - # if self.tensor_parallel_mapping.is_column_parallel( - # self.module, param_name - # ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): - # self._slice_linear( - # module=module, - # reversed=self.tensor_parallel_mapping.is_reversed_param( - # self.module, param_name - # ), - # fusion_degree=self.tensor_parallel_mapping.get_combined_qkv_degree( - # self.module, param_name, module - # ), - # slice_bias=True, - # ) - # module.__class__ = Linear2p5D - # TODO: gather this logics for param_name, module in self.module.named_modules(): if module.__class__ == Linear2p5D: self._gather_linear(module) - module.__class__ = nn.Linear - def _deparallelize_layernorm(self, module): - pass + def _deparallelize_layernorm(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == LayerNorm2p5D: + self._gather_layernorm(module) def _gather_embedding(self, module): - pass + tesseract_dim = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + if self.module.get_input_embeddings(): + w = gather_2d(self.parallel_context, module.weight, tesseract_dim, col_first=True) + + assert hasattr( + self.module, "orig_vocab_size" + ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." 
+ orig_vocab_size = self.module.orig_vocab_size + + module.weight.data = w[:, :orig_vocab_size] + + _update_module_arguments( + module=module, + vocab_start_index=None, + vocab_end_index=None, + parallel_context=None, + num_embeddings=module.weight.size()[0], + embedding_dim=module.weight.size()[1], + orig_module=None + ) + else: + w = gather_1d(self.parallel_context, module.weight, tesseract_dim, 1) + w = gather_1d(self.parallel_context, w, tesseract_dim, 1) + module.weight.data = w + + _update_module_arguments( + module=module, + parallel_context=None, + embedding_dim = module.weight.size()[1] + ) + module.__class__ = nn.Embedding def _gather_linear(self, module: Linear2p5D): - # , reversed, fusion_degree, slice_bias is_reversed = module.reversed fusion_degree = module.fusion_degree slice_bias = module.slice_bias tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim=tesseract_dim, col_first=True) if fusion_degree > 1: w = self._reconstruct_combined_qkv(self.weight.data, tesseract_dim, fusion_degree, False) - else: - w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim=tesseract_dim, col_first=True) + module.weight.data = w if is_reversed: w = w.t() module.weight.data = w if hasattr(module, "bias") and module.bias is not None: - if slice_bias is True and module.bias.dim() >= 1: - if fusion_degree > 1: - b = self._reconstruct_combined_qkv(self.bias.data, tesseract_dim, fusion_degree, True) - else: - b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) - module.bias.data = b + # if slice_bias is True and module.bias.dim() >= 1: + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + if fusion_degree > 1: + b = self._reconstruct_combined_qkv(self.bias.data, tesseract_dim, fusion_degree, True) + module.bias.data = b _update_module_arguments( module=module, in_features=module.weight.size()[1], out_features=module.weight.size()[0], parallel_context=self.parallel_context, - row_rank=None, - col_rank=None, - dep_rank=None, - tesseract_dim=None, - data_parallel_rank=None, - pipeline_parallel_rank=None, - tensor_parallel_size=None, - pipeline_parallel_size=None, - reversed=None, - fusion_degree=None, - orig_module=None, skip_bias_add=module.skip_bias_add if hasattr(module, "skip_bias_add") else False, - gather_output=None, ) + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + module.__class__ = nn.Linear def _gather_layernorm(self, module): - pass + if hasattr(module, "weight") and module.weight is not None: + if module.weight.dim() >= 1: + w = gather_1d(self.parallel_context, module.weight.data, 0) + module.weight.data = w + + if hasattr(module.weight, "oslo_parallel"): + # delete oslo_parallel if it exists + del module.weight.oslo_parallel + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + b = gather_1d(self.parallel_context, module.bias.data, 0) + module.bias.data = b + + if hasattr(module.bias, "oslo_parallel"): + del module.weight.oslo_parallel + + del module.partitioned_dim + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del 
module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.orig_module + _update_module_arguments( + module, + normalized_shape=module.weight.size()[0], + ) + module.__class__ = nn.LayerNorm @staticmethod def _reconstruct_combined_qkv(tensor, tesseract_dim, fusion_degree, is_bias=False): @@ -569,3 +632,123 @@ def _reconstrunct_combined_qkv_bias(tensor, tessearct_dim, fusion_degree): tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) tensor = [tensor[j] for j in range(tessearct_dim)] return tensor + + @torch.no_grad() + def save_parallelized( + self, + new_module, + save_directory: Union[str, os.PathLike], + save_config: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + merge_checkpoints: bool = False, + mapping: Optional[dict] = None, + **kwargs, + ): + import os + import torch.distributed as dist + from transformers.modeling_utils import get_parameter_dtype, unwrap_model + + logger = getLogger("Tensor2p5D") + PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" + # mapping = kwargs.pop("tp_mapping", None) + + if ( + self.parallel_context.get_world_size(ParallelMode.TENSOR) == 1 + and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1 + ): + if dist.get_rank() == 0: + self.save_pretrained( + save_directory=save_directory, + save_config=save_config, + state_dict=state_dict, + save_function=save_function, + **kwargs, + ) + dist.barrier() + return None + + if merge_checkpoints: + model_to_save = self.__class__( + module=new_module, + parallel_context=self.parallel_context, + mapping=mapping, + module_args=self.config + ).eval() + + if state_dict is None: + state_dict = self.state_dict() + + model_to_save.load_state_dict(state_dict) + + if self.parallel_context.get_world_size(ParallelMode.TENSOR) > 1: + model_to_save.deparallelize() + + if dist.get_rank() == 0: + model_to_save.module.save_pretrained( + save_directory=save_directory, + save_config=save_config, + save_function=save_function, + **kwargs, + ) + del model_to_save + + dist.barrier() + return None + + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + + os.makedirs(save_directory, exist_ok=True) + + # Only save the model itself if we are using distributed training + model_to_save = unwrap_model(self) + + # save the string version of dtype to the config, e.g. 
convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Save the config + if save_config: + model_to_save.config.save_pretrained(save_directory) + + # Save the model + if state_dict is None: + state_dict = model_to_save.state_dict() + + # Handle the case where some state_dict keys shouldn't be saved + if getattr(self, "_keys_to_ignore_on_save") is not None: + state_dict = { + k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save + } + + # If we save using the predefined names, we can load using `from_pretrained` + weights_name = PARALLELIZED_WEIGHTS_NAME + weights_name = weights_name.replace( + "tp_0", f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}" + ) + weights_name = weights_name.replace( + "pp_0", f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}" + ) + + output_model_file = os.path.join(save_directory, weights_name) + + if self.parallel_context.get_world_size(ParallelMode.DATA) > 1: + if self.parallel_context.get_local_rank(ParallelMode.DATA) == 0: + save_function(state_dict, output_model_file) + else: + save_function(state_dict, output_model_file) + + dist.barrier() + logger.info(f"Model weights saved in {output_model_file}") + + @staticmethod + def from_parallelized(cls): + pass diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index ab7a7c6c..a52918a0 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -1,4 +1,6 @@ -from typing import Optional +from typing import Union, Optional, Callable + +import os import torch import torch.nn as nn @@ -56,17 +58,18 @@ def __init__( module: nn.Module, parallel_context: Optional[ParallelContext] = None, mapping: dict = None, + config: dict = None ): super().__init__() self.parallel_context = get_parallel_context(module, parallel_context) module = self._resize_vocab_size(module, self.parallel_context) module = self._resize_num_classes(module, self.parallel_context, mapping) if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: - self.module = _TensorParallel1D(module, self.parallel_context, mapping) + self.module = _TensorParallel1D(module, self.parallel_context, mapping, config) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D: - self.module = _TensorParallel2D(module, self.parallel_context, mapping) + self.module = _TensorParallel2D(module, self.parallel_context, mapping, config) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D: - self.module = _TensorParallel2p5D(module, self.parallel_context, mapping) + self.module = _TensorParallel2p5D(module, self.parallel_context, mapping, config) else: raise ValueError( "currently, only 1d, 2d, 2p5d tensor parallelism is supported." 
@@ -177,3 +180,36 @@ def _resize_num_classes(model, parallel_context, mapping): def _remove_embeddings(self, model, parallel_context): pass + + @torch.no_grad() + def save_parallelized( + self, + save_directory: Union[str, os.PathLike], + save_config: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + merge_checkpoints: bool = False, + mapping: Optional[dict] = None, + **kwargs, + ): + unwrapped_model = unwrap_parallel(self.module.module) + if is_huggingface_model(unwrapped_model): + new_module = unwrapped_model.__class__(self.module.config) + else: + new_module = unwrapped_model.__class__(**self.module.config) + new_module = self._resize_vocab_size(new_module, self.parallel_context) + new_module = self._resize_num_classes(new_module, self.parallel_context, mapping) + return self.module.save_parallelized( + new_module, + save_directory, + save_config, + state_dict, + save_function, + merge_checkpoints, + mapping, + **kwargs, + ) + + @staticmethod + def from_parallelized(cls): + pass \ No newline at end of file diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/__init__.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/__init__.py new file mode 100644 index 00000000..28fcc144 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/__init__.py @@ -0,0 +1,3 @@ +from .. import _utils + +_ALL__ = [_utils] diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py similarity index 98% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_qkv.py rename to tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py index 2b5df38b..177e91f8 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_qkv.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py @@ -4,8 +4,7 @@ from oslo.torch.nn import Linear2p5D from oslo.torch.distributed import ParallelContext, ParallelMode from copy import deepcopy -from _utils import gather_2d, gather_1d - +from .._utils import * tp_size = 8 tp_depth = 2 diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py index a55cb5cf..f21a9434 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py @@ -3,7 +3,7 @@ from datasets import load_dataset from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, AutoModelForCausalLM from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params @@ -13,9 +13,34 @@ from oslo.torch.nn import Linear2D from _utils import split_2d, gather_2d + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end-start + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + tp_size = 8 tp_depth = 2 +model_name = "gpt2" +mkwargs = { +} +dataset_name = "squad" + # parallel context 생성 parallel_context = 
ParallelContext.from_torch( data_parallel_size=1, @@ -26,7 +51,7 @@ ) # 토크나이저 생성 -tokenizer = AutoTokenizer.from_pretrained("gpt2") +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) tokenizer.pad_token = tokenizer.eos_token # 모델 생성 및 병렬화 수행 @@ -46,13 +71,13 @@ # 데이터셋 생성 batch_size = 16 -datasets = load_dataset("squad").data["train"]["context"] +datasets = load_dataset(dataset_name).data["train"]["context"] datasets = [str(sample) for sample in datasets[:500]] dataloader = DataLoader(datasets, batch_size=batch_size) # 모니터링 생성 if dist.get_rank() == 0: - wandb.init(project="oslo", name=f"tp2p5d_bs{batch_size}") + wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") cur = time.time() # 모니터링 생성 대기 @@ -71,17 +96,26 @@ max_length=512, ).to("cuda") - loss_no_tp = model_no_tp(**inputs, labels=inputs["input_ids"]).loss - loss_tp = wrapper_tp(**inputs, labels=inputs["input_ids"]).loss + loss_no_tp, notp_fw_time = \ + fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) if dist.get_rank() == 0: print(f"TP:{loss_tp}, NOTP:{loss_no_tp}") wandb.log({"tp": loss_tp, "notp": loss_no_tp}) - loss_no_tp.backward() - loss_tp.backward() + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) optimizer_tp.step() optimizer_no_tp.step() + if dist.get_rank() == 0: + wandb.log({ + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time}) + dist.barrier() From 97beeea4a9614823a22914db35855bc1fa30751a Mon Sep 17 00:00:00 2001 From: jason960903 Date: Tue, 14 Jun 2022 19:35:07 +0900 Subject: [PATCH 06/37] add deparallelize and test code --- .../_parallel_2p5d/_wrapper.py | 34 +++-- .../deparallel/test_deparallelize.py | 139 ++++++++++++++++++ .../_parallel_2p5d/deparallel/test_qkv.py | 2 +- 3 files changed, 158 insertions(+), 17 deletions(-) create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index fd252c65..e3b64017 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -31,6 +31,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, + allocate_params ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -465,9 +466,9 @@ def _slice_head(self, module, reversed): @torch.no_grad() def deparallelize(self): - self._deparallelize_layernorm() self._deparallelize_linear() self._deparallelize_embedding() + self._deparallelize_layernorm() self._rollback_mp_arguments() def _rollback_mp_arguments(self): @@ -496,9 +497,9 @@ def _deparallelize_layernorm(self): self._gather_layernorm(module) def _gather_embedding(self, module): - tesseract_dim = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) - if self.module.get_input_embeddings(): - w = gather_2d(self.parallel_context, module.weight, tesseract_dim, col_first=True) + tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim, col_first=True) assert hasattr( self.module, "orig_vocab_size" @@ -531,24 +532,23 @@ def _gather_embedding(self, 
module): def _gather_linear(self, module: Linear2p5D): is_reversed = module.reversed fusion_degree = module.fusion_degree - slice_bias = module.slice_bias + # slice_bias = module.slice_bias tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim=tesseract_dim, col_first=True) + # print(f"w shape: {w.shape}\nweight shape: {module.weight.data.shape}") if fusion_degree > 1: - w = self._reconstruct_combined_qkv(self.weight.data, tesseract_dim, fusion_degree, False) - module.weight.data = w - + w = self._reconstruct_combined_qkv(w, tesseract_dim, fusion_degree, False) if is_reversed: - w = w.t() + w = module.weight.data.t() module.weight.data = w if hasattr(module, "bias") and module.bias is not None: # if slice_bias is True and module.bias.dim() >= 1: b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) if fusion_degree > 1: - b = self._reconstruct_combined_qkv(self.bias.data, tesseract_dim, fusion_degree, True) + b = self._reconstruct_combined_qkv(b, tesseract_dim, fusion_degree, True) module.bias.data = b _update_module_arguments( @@ -577,22 +577,22 @@ def _gather_linear(self, module: Linear2p5D): module.__class__ = nn.Linear def _gather_layernorm(self, module): + tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) if hasattr(module, "weight") and module.weight is not None: if module.weight.dim() >= 1: - w = gather_1d(self.parallel_context, module.weight.data, 0) + w = gather_1d(self.parallel_context, module.weight.data, tesseract_dim, 0) module.weight.data = w if hasattr(module.weight, "oslo_parallel"): - # delete oslo_parallel if it exists del module.weight.oslo_parallel if hasattr(module, "bias") and module.bias is not None: if module.bias.dim() >= 1: - b = gather_1d(self.parallel_context, module.bias.data, 0) + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) module.bias.data = b if hasattr(module.bias, "oslo_parallel"): - del module.weight.oslo_parallel + del module.bias.oslo_parallel del module.partitioned_dim del module.row_rank @@ -614,11 +614,12 @@ def _gather_layernorm(self, module): def _reconstruct_combined_qkv(tensor, tesseract_dim, fusion_degree, is_bias=False): last_dim = tensor.size()[-1] if is_bias is False: - reshaped_w = tensor.view(-1, tesseract_dim, last_dim) + reshaped_w = tensor.view(tesseract_dim*fusion_degree, -1, last_dim) + # print(f"tensor.size: {tensor.size()}, reshaped_w.size: {reshaped_w.size()}") recon_w = torch.cat([ reshaped_w[:fusion_degree], reshaped_w[fusion_degree:]], 1).view(-1, last_dim).contiguous() else: - reshaped_w = tensor.view(-1, tesseract_dim) + reshaped_w = tensor.view(fusion_degree*tesseract_dim, -1) recon_w = torch.cat([ reshaped_w[:fusion_degree], reshaped_w[fusion_degree:]], 1).view(-1, last_dim).contiguous() return recon_w @@ -680,6 +681,7 @@ def save_parallelized( state_dict = self.state_dict() model_to_save.load_state_dict(state_dict) + allocate_params(model_to_save, self.parallel_context) if self.parallel_context.get_world_size(ParallelMode.TENSOR) > 1: model_to_save.deparallelize() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py new file mode 100644 index 00000000..ec19918c --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py @@ -0,0 +1,139 @@ +import torch.distributed 
as dist
+import wandb
+from datasets import load_dataset
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel
+
+from oslo.torch.nn.parallel.tensor_parallel import TensorParallel
+from oslo.torch.nn.parallel.utils import allocate_params
+from oslo.torch.distributed import ParallelContext, ParallelMode
+import time
+
+
+def latency_trace(func):
+    def wrapper(*args, **kwargs):
+        start = time.time()
+        result = func(*args, **kwargs)
+        end = time.time()
+        return result, end-start
+    return wrapper
+
+
+@latency_trace
+def fw(func, *args, **kwargs):
+    return func(*args, **kwargs).loss
+
+
+@latency_trace
+def bw(tensors):
+    return tensors.backward()
+
+
+tp_size = 8
+tp_depth = 2
+
+model_name = "gpt2"
+mkwargs = {
+}
+dataset_name = "squad"
+
+# create parallel context
+parallel_context = ParallelContext.from_torch(
+    data_parallel_size=1,
+    pipeline_parallel_size=1,
+    tensor_parallel_size=tp_size,
+    tensor_parallel_mode=ParallelMode.TENSOR_2P5D,
+    tensor_parallel_depth=tp_depth,
+)
+
+# create tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs)
+tokenizer.pad_token = tokenizer.eos_token
+
+# create model and parallelize
+model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda()
+model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2"))
+wrapper_tp = TensorParallel(model_tp, parallel_context)
+allocate_params(wrapper_tp, parallel_context)
+# the allocate_params function will later be handled by a class that manages all parallel wrappers
+# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py
+
+if dist.get_rank() == 0:
+    print(wrapper_tp)
+
+# create optimizers
+optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5)
+optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5)
+
+# create dataset
+batch_size = 16
+datasets = load_dataset(dataset_name).data["train"]["context"]
+datasets = [str(sample) for sample in datasets[:500]]
+dataloader = DataLoader(datasets, batch_size=batch_size)
+
+# create monitoring
+if dist.get_rank() == 0:
+    wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}")
+    cur = time.time()
+
+# save
+wrapper_tp.save_parallelized('test/', merge_checkpoints=True)
+
+# wait for monitoring setup
+dist.barrier()
+
+# load
+model_gathered = GPT2LMHeadModel(GPT2Config.from_pretrained("test/")).cuda()
+optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5)
+
+dist.barrier()
+
+# start training
+for data in dataloader:
+    optimizer_tp.zero_grad()
+    optimizer_no_tp.zero_grad()
+    optimizer_gathered.zero_grad()
+
+    inputs = tokenizer(
+        data,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512,
+    ).to("cuda")
+
+    loss_no_tp, notp_fw_time = \
+        fw(model_no_tp, **inputs, labels=inputs["input_ids"])
+    loss_tp, tp_fw_time = \
+        fw(wrapper_tp, **inputs, labels=inputs["input_ids"])
+    loss_gathered, gathered_fw_time = \
+        fw(wrapper_tp, **inputs, labels=inputs["input_ids"])
+
+    if dist.get_rank() == 0:
+        print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHERED:{loss_gathered}")
+        wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHERED": loss_gathered})
+
+    _, notp_bw_time = bw(loss_no_tp)
+    _, tp_bw_time = bw(loss_tp)
+    _, gathered_bw_time = bw(loss_gathered)
+
+    optimizer_tp.step()
+    optimizer_no_tp.step()
+    optimizer_gathered.step()
+
+    if dist.get_rank() == 0:
+        wandb.log({
+            "tp.forward.time:": tp_fw_time,
+            "tp.backward.time:": tp_bw_time,
+            "notp.forward.time:": notp_fw_time,
+            "notp.backward.time:": notp_bw_time,
+            "gathered.forward.time:": gathered_fw_time,
+
"gathered.backward.time:": gathered_bw_time + }) + +dist.barrier() + + + + diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py index 177e91f8..1d4591e9 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py @@ -4,7 +4,7 @@ from oslo.torch.nn import Linear2p5D from oslo.torch.distributed import ParallelContext, ParallelMode from copy import deepcopy -from .._utils import * +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import gather_1d, gather_2d tp_size = 8 tp_depth = 2 From dbe21d2ae029c1ad20a2302a111cbd45066287e4 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Tue, 14 Jun 2022 21:56:01 +0900 Subject: [PATCH 07/37] modified argument name (config -> model_args) --- .../_parallel_2p5d/_wrapper.py | 28 +++++++++++-------- .../tensor_parallel/tensor_parallel.py | 16 ++++++++--- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index e3b64017..0e632563 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -31,14 +31,16 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - allocate_params + allocate_params, + unwrap_parallel, + get_parameter_dtype ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, ) from oslo.transformers.constants import BATCH_DIMENSIONS -from transformers import AutoConfig + class _TensorParallel2p5D(ParallelWrapper): """ @@ -58,7 +60,6 @@ def __init__( ): super().__init__() self.module = module - self.config = self.module.config self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -648,7 +649,6 @@ def save_parallelized( ): import os import torch.distributed as dist - from transformers.modeling_utils import get_parameter_dtype, unwrap_model logger = getLogger("Tensor2p5D") PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" @@ -687,12 +687,16 @@ def save_parallelized( model_to_save.deparallelize() if dist.get_rank() == 0: - model_to_save.module.save_pretrained( - save_directory=save_directory, - save_config=save_config, - save_function=save_function, - **kwargs, - ) + if is_huggingface_model(model_to_save.module): + model_to_save.module.save_pretrained( + save_directory=save_directory, + save_config=save_config, + save_function=save_function, + **kwargs, + ) + else: + # TODO : Non-huggingface model + pass del model_to_save dist.barrier() @@ -707,7 +711,7 @@ def save_parallelized( os.makedirs(save_directory, exist_ok=True) # Only save the model itself if we are using distributed training - model_to_save = unwrap_model(self) + model_to_save = unwrap_parallel(self) # save the string version of dtype to the config, e.g. 
convert torch.float32 => "float32" # we currently don't use this setting automatically, but may start to use with v5 @@ -726,7 +730,7 @@ def save_parallelized( state_dict = model_to_save.state_dict() # Handle the case where some state_dict keys shouldn't be saved - if getattr(self, "_keys_to_ignore_on_save") is not None: + if getattr(self, "_keys_to_ignore_on_save", None) is not None: state_dict = { k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save } diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index a52918a0..08de4a3f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -5,6 +5,8 @@ import torch import torch.nn as nn +from transformers import AutoConfig + from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._wrapper import ( _TensorParallel1D, @@ -58,18 +60,24 @@ def __init__( module: nn.Module, parallel_context: Optional[ParallelContext] = None, mapping: dict = None, - config: dict = None + module_args: dict = None ): super().__init__() + if is_huggingface_model(module): + assert module_args is not None, "module_args must not be provided in huggingface module." + else: + assert isinstance(module_args, dict), "module_args must be a dict." + self.parallel_context = get_parallel_context(module, parallel_context) module = self._resize_vocab_size(module, self.parallel_context) module = self._resize_num_classes(module, self.parallel_context, mapping) + if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: - self.module = _TensorParallel1D(module, self.parallel_context, mapping, config) + self.module = _TensorParallel1D(module, self.parallel_context, mapping, module_args) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D: - self.module = _TensorParallel2D(module, self.parallel_context, mapping, config) + self.module = _TensorParallel2D(module, self.parallel_context, mapping, module_args) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D: - self.module = _TensorParallel2p5D(module, self.parallel_context, mapping, config) + self.module = _TensorParallel2p5D(module, self.parallel_context, mapping, module_args) else: raise ValueError( "currently, only 1d, 2d, 2p5d tensor parallelism is supported." 
From e06a03a7c84d0661def925a3d9045e4da1a1841a Mon Sep 17 00:00:00 2001 From: jason960903 Date: Wed, 15 Jun 2022 08:54:12 +0900 Subject: [PATCH 08/37] add from_parallelized (huggingface only) --- .../nn/parallel/tensor_parallel/__init__.py | 2 +- .../_parallel_2p5d/_wrapper.py | 72 ++++++++-- .../tensor_parallel/tensor_parallel.py | 22 ++- .../deparallel/test_deparallelize.py | 4 +- .../deparallel/test_load_parallel.py | 135 ++++++++++++++++++ 5 files changed, 213 insertions(+), 22 deletions(-) create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py diff --git a/oslo/torch/nn/parallel/tensor_parallel/__init__.py b/oslo/torch/nn/parallel/tensor_parallel/__init__.py index 296bd0cf..9d487932 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/__init__.py +++ b/oslo/torch/nn/parallel/tensor_parallel/__init__.py @@ -1,6 +1,6 @@ from oslo.torch.nn.parallel.tensor_parallel.mapping import Column, Row, Update, Head from oslo.torch.nn.parallel.tensor_parallel.tensor_parallel import ( - TensorParallel, + TensorParallel ) __ALL__ = [TensorParallel, Column, Row, Update, Head] diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 0e632563..95731c76 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -3,6 +3,7 @@ import torch import torch.nn as nn +import torch.distributed as dist from typing import Union, Optional, Callable from logging import getLogger @@ -52,11 +53,11 @@ class _TensorParallel2p5D(ParallelWrapper): """ def __init__( - self, - module: nn.Module, - parallel_context: ParallelContext, - mapping: dict = None, - module_args: dict = None + self, + module: nn.Module, + parallel_context: ParallelContext, + mapping: dict = None, + module_args: dict = None ): super().__init__() self.module = module @@ -119,7 +120,7 @@ def _update_mp_arguments(self): ParallelMode.TENSOR_2P5D_COL ) assert ( - getattr(module, elem.name) % tesseract_dim == 0 + getattr(module, elem.name) % tesseract_dim == 0 ), f"{elem.name} must be divisible by tesseract_dim." 
reduced_arg = getattr(module, elem.name) // tesseract_dim setattr(module, elem.name, reduced_arg) @@ -134,7 +135,7 @@ def _parallelize_embedding(self): def _parallalize_linear(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_column_parallel( - self.module, param_name + self.module, param_name ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): self._slice_linear( module=module, @@ -157,7 +158,7 @@ def _parallelize_layernorm(self): def _parallelize_head(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, param_name ) and isinstance(module, nn.Linear): self._slice_head( module=module, @@ -647,12 +648,8 @@ def save_parallelized( mapping: Optional[dict] = None, **kwargs, ): - import os - import torch.distributed as dist - logger = getLogger("Tensor2p5D") PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" - # mapping = kwargs.pop("tp_mapping", None) if ( self.parallel_context.get_world_size(ParallelMode.TENSOR) == 1 @@ -755,6 +752,51 @@ def save_parallelized( dist.barrier() logger.info(f"Model weights saved in {output_model_file}") - @staticmethod - def from_parallelized(cls): - pass + def from_parallelized(self, path): + PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" + parallelized_model_path = path + + file_names = { + os.path.join( + parallelized_model_path, + PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp}").replace( + "pp_0", f"pp_{pp}" + ), + ) + for tp in range(self.parallel_context.get_world_size(ParallelMode.TENSOR)) + for pp in range(self.parallel_context.get_world_size(ParallelMode.PIPELINE)) + } + + if os.path.isdir(parallelized_model_path): + if all(os.path.isfile(file_name) for file_name in file_names): + state_dict = torch.load( + os.path.join( + parallelized_model_path, + PARALLELIZED_WEIGHTS_NAME.replace( + "tp_0", + f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}", + ).replace( + "pp_0", + f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}", + ), + ) + ) + + if getattr(self, "_keys_to_ignore_on_save", None) is not None: + state_dict = { + k: v + for k, v in state_dict.items() + if k not in self._keys_to_ignore_on_save + } + + self.load_state_dict(state_dict=state_dict, strict=False) + + else: + raise FileNotFoundError( + f"all the {file_names} are necessary. " + f"but some of them do not exist. Please check your checkpoint files." + ) + else: + raise NotADirectoryError( + f"directory named {parallelized_model_path} is not valid. " + ) diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index 08de4a3f..76f52a80 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +import torch.distributed as dist from transformers import AutoConfig @@ -28,6 +29,8 @@ unwrap_parallel, get_parallel_context, is_huggingface_model, + allocate_params, + get_parameter_dtype ) @@ -64,7 +67,7 @@ def __init__( ): super().__init__() if is_huggingface_model(module): - assert module_args is not None, "module_args must not be provided in huggingface module." + assert module_args is None, "module_args must not be provided in huggingface module." else: assert isinstance(module_args, dict), "module_args must be a dict." 
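The hunk below adds `from_parallelized`, which expects one weight shard per (tensor-parallel rank, pipeline-parallel rank) pair as written by `save_parallelized` with `merge_checkpoints=False`. As a reading aid, a small sketch of the file layout that loader looks for; `expected_shard_files` is a hypothetical helper written only for illustration, while the `pytorch_model_tp_*_pp_*.bin` naming scheme is taken directly from the patch.

import os

PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin"

def expected_shard_files(path: str, tp_world_size: int, pp_world_size: int):
    # Reproduces the name substitution used by from_parallelized:
    # every (tp, pp) rank pair must have its own shard on disk.
    return [
        os.path.join(
            path,
            PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp}").replace(
                "pp_0", f"pp_{pp}"
            ),
        )
        for tp in range(tp_world_size)
        for pp in range(pp_world_size)
    ]

# e.g. tensor_parallel_size=4, pipeline_parallel_size=1 under 'test/':
#   test/pytorch_model_tp_0_pp_0.bin ... test/pytorch_model_tp_3_pp_0.bin
print(expected_shard_files("test/", 4, 1))
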
@@ -218,6 +221,17 @@ def save_parallelized( **kwargs, ) - @staticmethod - def from_parallelized(cls): - pass \ No newline at end of file + @classmethod + def from_parallelized(cls, parallelized_model_path, parallel_context, huggingface_task_class=None): + assert huggingface_task_class is not None, "currently, only huggingface model is supported." + + config = AutoConfig.from_pretrained(parallelized_model_path) + base_model = huggingface_task_class(config) + + config = config.to_dict() + + self = cls(base_model, parallel_context, None, config if huggingface_task_class is None else None) + allocate_params(self, parallel_context) + self.module.from_parallelized(parallelized_model_path) + + return self diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py index ec19918c..c5ac9152 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py @@ -78,7 +78,7 @@ def bw(tensors): cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', merge_checkpoints=True) +wrapper_tp.save_parallelized('test/', merge_checkpoints=False) # 모니터링 생성 대기 dist.barrier() @@ -108,7 +108,7 @@ def bw(tensors): loss_tp, tp_fw_time = \ fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) loss_gathered, gathered_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + fw(model_gathered, **inputs, labels=inputs["input_ids"]) if dist.get_rank() == 0: print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py new file mode 100644 index 00000000..60bb361e --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py @@ -0,0 +1,135 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end-start + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 8 +tp_depth = 2 + +model_name = "gpt2" +mkwargs = { +} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() +model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# 
allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized('test/', merge_checkpoints=False) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_reparallel = TensorParallel.from_parallelized('test/', parallel_context, GPT2LMHeadModel) +optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + model_reparallel.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = \ + fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "reparallel": loss_reparallel}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, reparallel_bw_time = bw(loss_reparallel) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_reparallel.step() + + if dist.get_rank() == 0: + wandb.log({ + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time + }) + +dist.barrier() \ No newline at end of file From f6f28a4fc962bc1c44a45373d57fc07f4845e5fa Mon Sep 17 00:00:00 2001 From: jason960903 Date: Wed, 15 Jun 2022 20:21:35 +0900 Subject: [PATCH 09/37] modified from_parallelized --- .../_parallel_2p5d/_wrapper.py | 16 +++- .../tensor_parallel/tensor_parallel.py | 74 +++++++++++++++---- 2 files changed, 74 insertions(+), 16 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 95731c76..730c022e 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -1,5 +1,6 @@ import copy import os +import json import torch import torch.nn as nn @@ -692,8 +693,13 @@ def save_parallelized( **kwargs, ) else: - # TODO : Non-huggingface model - pass + if save_config: + with open(os.path.join(save_directory, "config.json"), "w") as f: + json.dump(self.config, f) + save_function( + model_to_save, + os.path.join(save_directory, "pytorch_model.bin"), + ) del model_to_save dist.barrier() @@ -753,6 +759,12 @@ def save_parallelized( logger.info(f"Model weights saved in {output_model_file}") def from_parallelized(self, path): + """ + Example: + >>> model = AnyModel() + >>> model = TensorParallel(model, ...) 
+ >>> model.from_parallelized(path) + """ PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" parallelized_model_path = path diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index 76f52a80..97503819 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -1,6 +1,8 @@ from typing import Union, Optional, Callable import os +import json +from operator import xor import torch import torch.nn as nn @@ -221,17 +223,61 @@ def save_parallelized( **kwargs, ) - @classmethod - def from_parallelized(cls, parallelized_model_path, parallel_context, huggingface_task_class=None): - assert huggingface_task_class is not None, "currently, only huggingface model is supported." - - config = AutoConfig.from_pretrained(parallelized_model_path) - base_model = huggingface_task_class(config) - - config = config.to_dict() - - self = cls(base_model, parallel_context, None, config if huggingface_task_class is None else None) - allocate_params(self, parallel_context) - self.module.from_parallelized(parallelized_model_path) - - return self + @staticmethod + def get_module_args(module): + state_dict = module.state_dict() + return { + key: value.shape for key, value in state_dict.items() + } + + def from_parallelized(self, path): + return self.module.from_parallelized(path) + + # @classmethod + # def from_parallelized(cls, parallelized_model_path, parallel_context, + # huggingface_task_class=None, base_model_class=None): + # """ + # :param parallelized_model_path: path to the parallelized model + # :param parallel_context: parallel context + # :param huggingface_task_class: huggingface task class ex. BertForSequenceClassification + # :param base_model_class: custom model's class field ex. CustomModel + # :return: TensorParallelWrapper + # + # Examples: + # + # >>> # huggingface model + # >>> huggingface_parallelized = TensorParallel.from_parallelized( + # >>> "bert-base-uncased", + # >>> parallel_context, + # >>> huggingface_task_class=BertForSequenceClassification, + # >>> ) + # + # >>> # custom model + # >>> from path.to.custom.model import CustomModel + # >>> custom_parallelized = TensorParallel.from_parallelized( + # >>> "/path/to/custom_model", + # >>> parallel_context, + # >>> base_model_class=CustomModel, + # >>>) + # """ + # assert xor(huggingface_task_class is None, base_model_class is None), \ + # "`huggingface_task_class` and `orig_model` must be input only one of them." 
+ # + # if base_model_class is None: + # config = AutoConfig.from_pretrained(parallelized_model_path) + # base_model = huggingface_task_class(config) + # config = None + # else: + # config_path = os.path.join(parallelized_model_path, "config.json") + # if os.path.exists(config_path): + # with open(config_path, "r") as f: + # config = dict(json.load(f)) + # else: + # raise ValueError(f"`config.json` is not found in {parallelized_model_path}.") + # base_model = base_model_class(**config) + # + # self = cls(base_model, parallel_context, None, config) + # allocate_params(self, parallel_context) + # self.module.from_parallelized(parallelized_model_path) + # + # return self From 0805272a83b682a6ecc65eacf456088ae3c074b1 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Wed, 15 Jun 2022 21:24:35 +0900 Subject: [PATCH 10/37] refactoring model wrapper --- .../nn/parallel/tensor_parallel/__init__.py | 3 +- .../parallel/tensor_parallel/_base_wrapper.py | 217 ++++++++++ .../_parallel_2p5d/_wrapper.py | 369 +++++++++--------- .../tensor_parallel/tensor_parallel.py | 49 --- .../deparallel/test_load_parallel.py | 7 +- 5 files changed, 406 insertions(+), 239 deletions(-) create mode 100644 oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py diff --git a/oslo/torch/nn/parallel/tensor_parallel/__init__.py b/oslo/torch/nn/parallel/tensor_parallel/__init__.py index 9d487932..ba704bc3 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/__init__.py +++ b/oslo/torch/nn/parallel/tensor_parallel/__init__.py @@ -2,5 +2,6 @@ from oslo.torch.nn.parallel.tensor_parallel.tensor_parallel import ( TensorParallel ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import BaseTensorParallelWrapper -__ALL__ = [TensorParallel, Column, Row, Update, Head] +__ALL__ = [TensorParallel, Column, Row, Update, Head, BaseTensorParallelWrapper] diff --git a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py new file mode 100644 index 00000000..8c65b9da --- /dev/null +++ b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py @@ -0,0 +1,217 @@ +import copy +import os +import json + +import torch +import torch.nn as nn +import torch.distributed as dist + +from typing import Union, Optional, Callable +from logging import getLogger + +from oslo.torch.distributed import ParallelContext, ParallelMode + +from oslo.torch.nn.parallel.utils import ( + ParallelWrapper, + _update_module_arguments, + is_huggingface_model, + is_oslo_model, + allocate_params, + unwrap_parallel, + get_parameter_dtype +) + + +class BaseTensorParallelWrapper(ParallelWrapper): + """ + PyTorch module for 2.5D tensor parallelism + + Args: + module (nn.Module): model object + parallel_context (ParallelContext): parallel context object + """ + + def __init__( + self, + module: nn.Module, + parallel_context: ParallelContext, + mapping: dict = None, + module_args: dict = None + ): + super().__init__() + + @torch.no_grad() + def save_parallelized( + self, + new_module, + save_directory: Union[str, os.PathLike], + save_config: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + merge_checkpoints: bool = False, + mapping: Optional[dict] = None, + **kwargs, + ): + logger = getLogger("Tensor2p5D") + PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" + + if ( + self.parallel_context.get_world_size(ParallelMode.TENSOR) == 1 + and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1 + ): + if dist.get_rank() == 0: + self.save_pretrained( + 
save_directory=save_directory, + save_config=save_config, + state_dict=state_dict, + save_function=save_function, + **kwargs, + ) + dist.barrier() + return None + + if merge_checkpoints: + model_to_save = self.__class__( + module=new_module, + parallel_context=self.parallel_context, + mapping=mapping, + module_args=self.config + ).eval() + + if state_dict is None: + state_dict = self.state_dict() + + model_to_save.load_state_dict(state_dict) + allocate_params(model_to_save, self.parallel_context) + + if self.parallel_context.get_world_size(ParallelMode.TENSOR) > 1: + model_to_save.deparallelize() + + if dist.get_rank() == 0: + if is_huggingface_model(model_to_save.module): + model_to_save.module.save_pretrained( + save_directory=save_directory, + save_config=save_config, + save_function=save_function, + **kwargs, + ) + else: + if save_config: + with open(os.path.join(save_directory, "config.json"), "w") as f: + json.dump(self.config, f) + save_function( + model_to_save, + os.path.join(save_directory, "pytorch_model.bin"), + ) + del model_to_save + + dist.barrier() + return None + + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + + os.makedirs(save_directory, exist_ok=True) + + # Only save the model itself if we are using distributed training + model_to_save = unwrap_parallel(self) + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Save the config + if save_config: + model_to_save.config.save_pretrained(save_directory) + + # Save the model + if state_dict is None: + state_dict = model_to_save.state_dict() + + # Handle the case where some state_dict keys shouldn't be saved + if getattr(self, "_keys_to_ignore_on_save", None) is not None: + state_dict = { + k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save + } + + # If we save using the predefined names, we can load using `from_pretrained` + weights_name = PARALLELIZED_WEIGHTS_NAME + weights_name = weights_name.replace( + "tp_0", f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}" + ) + weights_name = weights_name.replace( + "pp_0", f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}" + ) + + output_model_file = os.path.join(save_directory, weights_name) + + if self.parallel_context.get_world_size(ParallelMode.DATA) > 1: + if self.parallel_context.get_local_rank(ParallelMode.DATA) == 0: + save_function(state_dict, output_model_file) + else: + save_function(state_dict, output_model_file) + + dist.barrier() + logger.info(f"Model weights saved in {output_model_file}") + + def from_parallelized(self, path): + """ + Example: + >>> model = AnyModel() + >>> model = TensorParallel(model, ...) 
+ >>> model.from_parallelized(path) + """ + PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" + parallelized_model_path = path + + file_names = { + os.path.join( + parallelized_model_path, + PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp}").replace( + "pp_0", f"pp_{pp}" + ), + ) + for tp in range(self.parallel_context.get_world_size(ParallelMode.TENSOR)) + for pp in range(self.parallel_context.get_world_size(ParallelMode.PIPELINE)) + } + + if os.path.isdir(parallelized_model_path): + if all(os.path.isfile(file_name) for file_name in file_names): + state_dict = torch.load( + os.path.join( + parallelized_model_path, + PARALLELIZED_WEIGHTS_NAME.replace( + "tp_0", + f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}", + ).replace( + "pp_0", + f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}", + ), + ) + ) + + if getattr(self, "_keys_to_ignore_on_save", None) is not None: + state_dict = { + k: v + for k, v in state_dict.items() + if k not in self._keys_to_ignore_on_save + } + + self.load_state_dict(state_dict=state_dict, strict=False) + + else: + raise FileNotFoundError( + f"all the {file_names} are necessary. " + f"but some of them do not exist. Please check your checkpoint files." + ) + else: + raise NotADirectoryError( + f"directory named {parallelized_model_path} is not valid. " + ) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 730c022e..b86d305a 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -1,13 +1,7 @@ import copy -import os -import json import torch import torch.nn as nn -import torch.distributed as dist - -from typing import Union, Optional, Callable -from logging import getLogger from oslo.torch.distributed import ParallelContext, ParallelMode @@ -27,15 +21,14 @@ from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, +) from oslo.torch.nn.parallel.utils import ( - ParallelWrapper, _update_module_arguments, is_huggingface_model, - is_oslo_model, - allocate_params, - unwrap_parallel, - get_parameter_dtype + is_oslo_model ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -44,7 +37,7 @@ from oslo.transformers.constants import BATCH_DIMENSIONS -class _TensorParallel2p5D(ParallelWrapper): +class _TensorParallel2p5D(BaseTensorParallelWrapper): """ PyTorch module for 2.5D tensor parallelism @@ -60,7 +53,7 @@ def __init__( mapping: dict = None, module_args: dict = None ): - super().__init__() + super().__init__(module, parallel_context, mapping, module_args) self.module = module self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -637,178 +630,178 @@ def _reconstrunct_combined_qkv_bias(tensor, tessearct_dim, fusion_degree): tensor = [tensor[j] for j in range(tessearct_dim)] return tensor - @torch.no_grad() - def save_parallelized( - self, - new_module, - save_directory: Union[str, os.PathLike], - save_config: bool = True, - state_dict: Optional[dict] = None, - save_function: Callable = torch.save, - merge_checkpoints: bool = False, - mapping: Optional[dict] = None, - **kwargs, - ): - logger = getLogger("Tensor2p5D") - PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" - - if ( - self.parallel_context.get_world_size(ParallelMode.TENSOR) 
== 1 - and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1 - ): - if dist.get_rank() == 0: - self.save_pretrained( - save_directory=save_directory, - save_config=save_config, - state_dict=state_dict, - save_function=save_function, - **kwargs, - ) - dist.barrier() - return None - - if merge_checkpoints: - model_to_save = self.__class__( - module=new_module, - parallel_context=self.parallel_context, - mapping=mapping, - module_args=self.config - ).eval() - - if state_dict is None: - state_dict = self.state_dict() - - model_to_save.load_state_dict(state_dict) - allocate_params(model_to_save, self.parallel_context) - - if self.parallel_context.get_world_size(ParallelMode.TENSOR) > 1: - model_to_save.deparallelize() - - if dist.get_rank() == 0: - if is_huggingface_model(model_to_save.module): - model_to_save.module.save_pretrained( - save_directory=save_directory, - save_config=save_config, - save_function=save_function, - **kwargs, - ) - else: - if save_config: - with open(os.path.join(save_directory, "config.json"), "w") as f: - json.dump(self.config, f) - save_function( - model_to_save, - os.path.join(save_directory, "pytorch_model.bin"), - ) - del model_to_save - - dist.barrier() - return None - - if os.path.isfile(save_directory): - logger.error( - f"Provided path ({save_directory}) should be a directory, not a file" - ) - return - - os.makedirs(save_directory, exist_ok=True) - - # Only save the model itself if we are using distributed training - model_to_save = unwrap_parallel(self) - - # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" - # we currently don't use this setting automatically, but may start to use with v5 - dtype = get_parameter_dtype(model_to_save) - model_to_save.config.torch_dtype = str(dtype).split(".")[1] - - # Attach architecture to the config - model_to_save.config.architectures = [model_to_save.__class__.__name__] - - # Save the config - if save_config: - model_to_save.config.save_pretrained(save_directory) - - # Save the model - if state_dict is None: - state_dict = model_to_save.state_dict() - - # Handle the case where some state_dict keys shouldn't be saved - if getattr(self, "_keys_to_ignore_on_save", None) is not None: - state_dict = { - k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save - } - - # If we save using the predefined names, we can load using `from_pretrained` - weights_name = PARALLELIZED_WEIGHTS_NAME - weights_name = weights_name.replace( - "tp_0", f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}" - ) - weights_name = weights_name.replace( - "pp_0", f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}" - ) - - output_model_file = os.path.join(save_directory, weights_name) - - if self.parallel_context.get_world_size(ParallelMode.DATA) > 1: - if self.parallel_context.get_local_rank(ParallelMode.DATA) == 0: - save_function(state_dict, output_model_file) - else: - save_function(state_dict, output_model_file) - - dist.barrier() - logger.info(f"Model weights saved in {output_model_file}") - - def from_parallelized(self, path): - """ - Example: - >>> model = AnyModel() - >>> model = TensorParallel(model, ...) 
- >>> model.from_parallelized(path) - """ - PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" - parallelized_model_path = path - - file_names = { - os.path.join( - parallelized_model_path, - PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp}").replace( - "pp_0", f"pp_{pp}" - ), - ) - for tp in range(self.parallel_context.get_world_size(ParallelMode.TENSOR)) - for pp in range(self.parallel_context.get_world_size(ParallelMode.PIPELINE)) - } - - if os.path.isdir(parallelized_model_path): - if all(os.path.isfile(file_name) for file_name in file_names): - state_dict = torch.load( - os.path.join( - parallelized_model_path, - PARALLELIZED_WEIGHTS_NAME.replace( - "tp_0", - f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}", - ).replace( - "pp_0", - f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}", - ), - ) - ) - - if getattr(self, "_keys_to_ignore_on_save", None) is not None: - state_dict = { - k: v - for k, v in state_dict.items() - if k not in self._keys_to_ignore_on_save - } - - self.load_state_dict(state_dict=state_dict, strict=False) - - else: - raise FileNotFoundError( - f"all the {file_names} are necessary. " - f"but some of them do not exist. Please check your checkpoint files." - ) - else: - raise NotADirectoryError( - f"directory named {parallelized_model_path} is not valid. " - ) + # @torch.no_grad() + # def save_parallelized( + # self, + # new_module, + # save_directory: Union[str, os.PathLike], + # save_config: bool = True, + # state_dict: Optional[dict] = None, + # save_function: Callable = torch.save, + # merge_checkpoints: bool = False, + # mapping: Optional[dict] = None, + # **kwargs, + # ): + # logger = getLogger("Tensor2p5D") + # PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" + # + # if ( + # self.parallel_context.get_world_size(ParallelMode.TENSOR) == 1 + # and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1 + # ): + # if dist.get_rank() == 0: + # self.save_pretrained( + # save_directory=save_directory, + # save_config=save_config, + # state_dict=state_dict, + # save_function=save_function, + # **kwargs, + # ) + # dist.barrier() + # return None + # + # if merge_checkpoints: + # model_to_save = self.__class__( + # module=new_module, + # parallel_context=self.parallel_context, + # mapping=mapping, + # module_args=self.config + # ).eval() + # + # if state_dict is None: + # state_dict = self.state_dict() + # + # model_to_save.load_state_dict(state_dict) + # allocate_params(model_to_save, self.parallel_context) + # + # if self.parallel_context.get_world_size(ParallelMode.TENSOR) > 1: + # model_to_save.deparallelize() + # + # if dist.get_rank() == 0: + # if is_huggingface_model(model_to_save.module): + # model_to_save.module.save_pretrained( + # save_directory=save_directory, + # save_config=save_config, + # save_function=save_function, + # **kwargs, + # ) + # else: + # if save_config: + # with open(os.path.join(save_directory, "config.json"), "w") as f: + # json.dump(self.config, f) + # save_function( + # model_to_save, + # os.path.join(save_directory, "pytorch_model.bin"), + # ) + # del model_to_save + # + # dist.barrier() + # return None + # + # if os.path.isfile(save_directory): + # logger.error( + # f"Provided path ({save_directory}) should be a directory, not a file" + # ) + # return + # + # os.makedirs(save_directory, exist_ok=True) + # + # # Only save the model itself if we are using distributed training + # model_to_save = unwrap_parallel(self) + # + # # save the string version of dtype to the config, 
e.g. convert torch.float32 => "float32" + # # we currently don't use this setting automatically, but may start to use with v5 + # dtype = get_parameter_dtype(model_to_save) + # model_to_save.config.torch_dtype = str(dtype).split(".")[1] + # + # # Attach architecture to the config + # model_to_save.config.architectures = [model_to_save.__class__.__name__] + # + # # Save the config + # if save_config: + # model_to_save.config.save_pretrained(save_directory) + # + # # Save the model + # if state_dict is None: + # state_dict = model_to_save.state_dict() + # + # # Handle the case where some state_dict keys shouldn't be saved + # if getattr(self, "_keys_to_ignore_on_save", None) is not None: + # state_dict = { + # k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save + # } + # + # # If we save using the predefined names, we can load using `from_pretrained` + # weights_name = PARALLELIZED_WEIGHTS_NAME + # weights_name = weights_name.replace( + # "tp_0", f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}" + # ) + # weights_name = weights_name.replace( + # "pp_0", f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}" + # ) + # + # output_model_file = os.path.join(save_directory, weights_name) + # + # if self.parallel_context.get_world_size(ParallelMode.DATA) > 1: + # if self.parallel_context.get_local_rank(ParallelMode.DATA) == 0: + # save_function(state_dict, output_model_file) + # else: + # save_function(state_dict, output_model_file) + # + # dist.barrier() + # logger.info(f"Model weights saved in {output_model_file}") + # + # def from_parallelized(self, path): + # """ + # Example: + # >>> model = AnyModel() + # >>> model = TensorParallel(model, ...) + # >>> model.from_parallelized(path) + # """ + # PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" + # parallelized_model_path = path + # + # file_names = { + # os.path.join( + # parallelized_model_path, + # PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp}").replace( + # "pp_0", f"pp_{pp}" + # ), + # ) + # for tp in range(self.parallel_context.get_world_size(ParallelMode.TENSOR)) + # for pp in range(self.parallel_context.get_world_size(ParallelMode.PIPELINE)) + # } + # + # if os.path.isdir(parallelized_model_path): + # if all(os.path.isfile(file_name) for file_name in file_names): + # state_dict = torch.load( + # os.path.join( + # parallelized_model_path, + # PARALLELIZED_WEIGHTS_NAME.replace( + # "tp_0", + # f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}", + # ).replace( + # "pp_0", + # f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}", + # ), + # ) + # ) + # + # if getattr(self, "_keys_to_ignore_on_save", None) is not None: + # state_dict = { + # k: v + # for k, v in state_dict.items() + # if k not in self._keys_to_ignore_on_save + # } + # + # self.load_state_dict(state_dict=state_dict, strict=False) + # + # else: + # raise FileNotFoundError( + # f"all the {file_names} are necessary. " + # f"but some of them do not exist. Please check your checkpoint files." + # ) + # else: + # raise NotADirectoryError( + # f"directory named {parallelized_model_path} is not valid. 
" + # ) diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index 97503819..d35dcc89 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -232,52 +232,3 @@ def get_module_args(module): def from_parallelized(self, path): return self.module.from_parallelized(path) - - # @classmethod - # def from_parallelized(cls, parallelized_model_path, parallel_context, - # huggingface_task_class=None, base_model_class=None): - # """ - # :param parallelized_model_path: path to the parallelized model - # :param parallel_context: parallel context - # :param huggingface_task_class: huggingface task class ex. BertForSequenceClassification - # :param base_model_class: custom model's class field ex. CustomModel - # :return: TensorParallelWrapper - # - # Examples: - # - # >>> # huggingface model - # >>> huggingface_parallelized = TensorParallel.from_parallelized( - # >>> "bert-base-uncased", - # >>> parallel_context, - # >>> huggingface_task_class=BertForSequenceClassification, - # >>> ) - # - # >>> # custom model - # >>> from path.to.custom.model import CustomModel - # >>> custom_parallelized = TensorParallel.from_parallelized( - # >>> "/path/to/custom_model", - # >>> parallel_context, - # >>> base_model_class=CustomModel, - # >>>) - # """ - # assert xor(huggingface_task_class is None, base_model_class is None), \ - # "`huggingface_task_class` and `orig_model` must be input only one of them." - # - # if base_model_class is None: - # config = AutoConfig.from_pretrained(parallelized_model_path) - # base_model = huggingface_task_class(config) - # config = None - # else: - # config_path = os.path.join(parallelized_model_path, "config.json") - # if os.path.exists(config_path): - # with open(config_path, "r") as f: - # config = dict(json.load(f)) - # else: - # raise ValueError(f"`config.json` is not found in {parallelized_model_path}.") - # base_model = base_model_class(**config) - # - # self = cls(base_model, parallel_context, None, config) - # allocate_params(self, parallel_context) - # self.module.from_parallelized(parallelized_model_path) - # - # return self diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py index 60bb361e..51f9009c 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py @@ -84,7 +84,12 @@ def bw(tensors): dist.barrier() # 로드 -model_reparallel = TensorParallel.from_parallelized('test/', parallel_context, GPT2LMHeadModel) +model_reparallel = TensorParallel( + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), + parallel_context +) +allocate_params(model_reparallel, parallel_context) +model_reparallel.from_parallelized('test/') optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) dist.barrier() From 1469b95ca5ab4e400cf07181ed21f0cd0315fec3 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Fri, 17 Jun 2022 10:19:30 +0900 Subject: [PATCH 11/37] add 1d deparallelism --- .../parallel/tensor_parallel/_base_wrapper.py | 5 + .../tensor_parallel/_parallel_1d/_ops.py | 19 +++ .../tensor_parallel/_parallel_1d/_wrapper.py | 155 +++++++++++++++++- .../_parallel_1d/deparallel/__init__.py | 3 + .../deparallel/test_deparallelize.py | 139 ++++++++++++++++ 
.../deparallel/test_load_parallel.py | 140 ++++++++++++++++ .../_parallel_1d/deparallel/test_qkv.py | 59 +++++++ .../_parallel_1d/deparallel/test_vocab.py | 85 ++++++++++ .../_parallel_1d/test_wrapper_1d.py | 62 ++++++- 9 files changed, 657 insertions(+), 10 deletions(-) create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/__init__.py create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py diff --git a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py index 8c65b9da..fd5d302a 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py @@ -215,3 +215,8 @@ def from_parallelized(self, path): raise NotADirectoryError( f"directory named {parallelized_model_path} is not valid. " ) + + @torch.no_grad() + def deparallelize(self): + return NotImplementedError + diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py index 9f55fa18..3a5c0f3f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py @@ -1,6 +1,7 @@ from typing import Any import torch +import torch.distributed as dist from torch import Tensor from oslo.torch.distributed import ParallelMode, ParallelContext @@ -110,3 +111,21 @@ def all_gather_tensor_1d(inputs: Tensor, dim: int, parallel_context: ParallelCon def scatter_tensor_1d(inputs: Tensor, dim: int, parallel_context: ParallelContext): return _ScatterTensor1D.apply(inputs, dim, parallel_context) + + +def split_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor = tensor.chunk(summa_dim, dim=dim)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_1D) + ] + return tensor + + +def gather_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_1D), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 5fcc2b88..0aa61c04 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +import torch.distributed as dist from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn.modules.embedding import ( @@ -17,17 +18,19 @@ TensorParallelMapping, ) from oslo.torch.nn.parallel.utils import ( - ParallelWrapper, _update_module_arguments, is_huggingface_model, is_oslo_model, ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, +) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, ) -class _TensorParallel1D(ParallelWrapper): +class _TensorParallel1D(BaseTensorParallelWrapper): """ PyTorch module for 1D tensor parallelism @@ -125,7 +128,7 @@ def _parallelize_linear(self): 
def _parallelize_head(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, param_name ) and isinstance(module, nn.Linear): self._slice_head( module=module, @@ -308,3 +311,149 @@ def _slice_head(self, module, reversed): else False, ) module.__class__ = ColumnParallelLinear + + @torch.no_grad() + def deparallelize(self): + self._deparallelize_linear() + self._deparallelize_embedding() + self._rollback_mp_arguments() + + def _rollback_mp_arguments(self): + for module in self.module.modules(): + for elem in self.tensor_parallel_mapping.update_attrs(self.module): + if hasattr(module, elem.name): + world_size = self.parallel_context.get_world_size( + ParallelMode.TENSOR_1D + ) + expanded_arg = getattr(module, elem.name) * world_size + setattr(module, elem.name, expanded_arg) + + def _deparallelize_embedding(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == VocabParallelEmbedding1D: + self._gather_embedding(module) + + def _deparallelize_linear(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_column_parallel(self.module, param_name): + self._gather_column_linear(module) + + elif self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + self._gather_row_linear(module) + + def _gather_embedding(self, module): + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + # w = gather_2d(self.parallel_context, module.weight.data, world_size, col_first=True) + tensor_list = [torch.zeros_like(module.weight.data) for _ in range(world_size)] + dist.all_gather( + tensor_list, + module.weight.data.contiguous(), + self.parallel_context.get_group(ParallelMode.TENSOR_1D), + ) + w = torch.cat(tensor_list, dim=0) + + assert hasattr( + self.module, "orig_vocab_size" + ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." 
+ orig_vocab_size = self.module.orig_vocab_size + + module.weight.data = w[:, :orig_vocab_size] + + _update_module_arguments( + module=module, + vocab_start_index=None, + vocab_end_index=None, + parallel_context=None, + num_embeddings=module.weight.size()[0], + embedding_dim=module.weight.size()[1], + orig_module=None + ) + else: + tensor_list = [torch.zeros_like(module.weight.data) for _ in range(world_size)] + dist.all_gather( + tensor_list, + module.weight.data.contiguous(), + self.parallel_context.get_group(ParallelMode.TENSOR_1D), + ) + w = torch.cat(tensor_list, dim=1) + module.weight.data = w + + _update_module_arguments( + module=module, + parallel_context=None, + embedding_dim = module.weight.size()[1] + ) + module.__class__ = nn.Embedding + + def _gather_linear(self, module, dim=1): + is_reversed = module.reversed + fusion_degree = module.fusion_degree + + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + + # w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim=tesseract_dim, col_first=True) + tensor_list = [torch.zeros_like(module.weight.data) for _ in range(world_size)] + dist.all_gather( + tensor_list, + module.weight.data.contiguous(), + self.parallel_context.get_group(ParallelMode.TENSOR_1D), + ) + w = torch.cat(tensor_list, dim=dim) + + if fusion_degree > 1: + w = self._reconstruct_combined_qkv(w, world_size, fusion_degree, False) + if is_reversed: + w = module.weight.data.t() + module.weight.data = w + + if hasattr(module, "bias") and module.bias is not None: + # if slice_bias is True and module.bias.dim() >= 1: + # b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, dim) + tensor_list = [torch.zeros_like(module.bias.data) for _ in range(world_size)] + dist.all_gather( + tensor_list, + module.bias.data.contiguous(), + self.parallel_context.get_group(ParallelMode.TENSOR_1D), + ) + b = torch.cat(tensor_list, dim=dim) + if fusion_degree > 1: + b = self._reconstruct_combined_qkv(b, world_size, fusion_degree, dim) + module.bias.data = b + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + module.__class__ = nn.Linear + + def _gather_column_linear(self, module): + self._gather_linear(module, dim=0) + + def _gather_row_linear(self, module): + self._gather_linear(module, dim=1) + + # TODO: fix + @staticmethod + def _reconstruct_combined_qkv(tensor, world_size, fusion_degree, dim): + last_dim = tensor.size()[dim-1] + reshaped_w = tensor.view(fusion_degree * world_size, -1) + # print(f"tensor.size: {tensor.size()}, reshaped_w.size: {reshaped_w.size()}") + recon_w = torch.cat([ + reshaped_w[i * fusion_degree: (i+1) * fusion_degree] + for i in range(world_size)], 1).view(-1, last_dim).contiguous() + return recon_w \ No newline at end of file diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/__init__.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/__init__.py new file mode 100644 index 00000000..28fcc144 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/__init__.py @@ -0,0 
+1,3 @@ +from .. import _utils + +_ALL__ = [_utils] diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py new file mode 100644 index 00000000..a95aacd1 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py @@ -0,0 +1,139 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end-start + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 4 +tp_depth = 1 + +model_name = "gpt2" +mkwargs = { +} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_1D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() +model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized('test/', merge_checkpoints=False) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_gathered = GPT2LMHeadModel(GPT2Config.from_pretrained("test/")).cuda() +optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + optimizer_gathered.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = \ + fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = \ + fw(model_gathered, **inputs, labels=inputs["input_ids"]) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": 
loss_gathered}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, gathered_bw_time = bw(loss_gathered) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_gathered.step() + + if dist.get_rank() == 0: + wandb.log({ + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time + }) + +dist.barrier() + + + + diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py new file mode 100644 index 00000000..51f9009c --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py @@ -0,0 +1,140 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end-start + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 8 +tp_depth = 2 + +model_name = "gpt2" +mkwargs = { +} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() +model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized('test/', merge_checkpoints=False) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_reparallel = TensorParallel( + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), + parallel_context +) +allocate_params(model_reparallel, parallel_context) +model_reparallel.from_parallelized('test/') +optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + 
optimizer_no_tp.zero_grad() + model_reparallel.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = \ + fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "reparallel": loss_reparallel}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, reparallel_bw_time = bw(loss_reparallel) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_reparallel.step() + + if dist.get_rank() == 0: + wandb.log({ + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time + }) + +dist.barrier() \ No newline at end of file diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py new file mode 100644 index 00000000..18f3cb04 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py @@ -0,0 +1,59 @@ +import torch.nn as nn + +from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._wrapper import _TensorParallel1D +from oslo.torch.nn import ColumnParallelLinear, RowParallelLinear +from oslo.torch.distributed import ParallelContext, ParallelMode +from copy import deepcopy +from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import gather_1d + +tp_size = 4 +tp_depth = 2 + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_1D, + tensor_parallel_depth=tp_depth, +) + +world_size = parallel_context.get_world_size(ParallelMode.TENSOR_1D) +rank = parallel_context.get_local_rank(ParallelMode.TENSOR_1D) +fusion_degree = 3 + +linear = nn.Linear(4, fusion_degree*4).cuda() +w = deepcopy(linear.weight.data) +b = deepcopy(linear.bias.data) + +dim= 1 + +weight_list = w.t().chunk(fusion_degree * world_size, dim=dim) +bias_list = b.chunk(fusion_degree * world_size, dim=0) + +# [t][f*t] +weight_list = _TensorParallel1D._deconstruct_combined_qkv(weight_list, world_size, fusion_degree, dim) +chunked_w = weight_list[rank].contiguous() +bias_list = _TensorParallel1D._deconstruct_combined_qkv(bias_list, world_size, fusion_degree, 0) +chunked_b = bias_list[rank].contiguous() + +linear_1d = RowParallelLinear(4, fusion_degree * 4, parallel_context=parallel_context, bias=True) +if parallel_context.get_global_rank() == 0: + print(chunked_w.size()) + print(linear_1d.weight.data.size()) +linear_1d.weight.data = chunked_w + +recon_chunked_w = gather_1d(parallel_context, linear_1d.weight.data, world_size, dim) +recon_chunked_b = gather_1d(parallel_context, linear_1d.bias.data, world_size, 0) +# reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) +# recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() +print(recon_chunked_w.shape) +recon_w = _TensorParallel1D._reconstruct_combined_qkv(recon_chunked_w, world_size, fusion_degree, dim) +recon_b = 
_TensorParallel1D._reconstruct_combined_qkv(recon_chunked_b, world_size, fusion_degree, 0) + +if parallel_context.get_global_rank() == 0: + print(f"original w: \n{w}\n") + print(f"reconstruct w: \n{recon_w}\n") + + print(f"original b: \n{b}\n") + print(f"reconstruct b: \n{recon_b}\n") \ No newline at end of file diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py new file mode 100644 index 00000000..3edfde0c --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py @@ -0,0 +1,85 @@ +import torch +import torch.distributed as dist + +from oslo.torch.distributed import ParallelContext, ParallelMode +from oslo.torch.nn import VocabParallelEmbedding2p5D +from oslo.torch.nn.parallel import utils + +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import split_batch_2d, split_2d, gather_2d + +from copy import deepcopy + + +tp_size = 8 +tp_depth = 2 + +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +torch.set_printoptions(sci_mode=False) +torch.manual_seed(0) +tesseract_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) +input_ = torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]]).cuda() +target = torch.randn((2, 4, 16)).cuda() +dist.broadcast(input_, src=0) +dist.broadcast(target, src=0) + +vocab_embedding = torch.nn.Embedding(10, 16).cuda() +w = deepcopy(vocab_embedding.weight.data) + +out = vocab_embedding(input_) +optimizer = torch.optim.Adam(vocab_embedding.parameters(), lr=1e-3) +logits = torch.nn.MSELoss()(out, target) +logits.backward() +optimizer.step() + +out_update = vocab_embedding(input_) + +if parallel_context.get_global_rank() == 0: + print(f"original output: \n{out}\n") + print(f"original update output: \n{out_update}\n") + # print(f"vocab start: {vocab_embedding.start_index}, vocab end: {vocab_embedding.end_index}") + +input_ = split_batch_2d(parallel_context, input_, tesseract_dim) +# split target into 0:[0, 0], 1:[0, 1], 2:[1, 0], 3:[1, 1] +target = split_2d(parallel_context, target, tesseract_dim, col_first=True) +# split weight into 0:[0, 0], 1:[1, 0], 2:[0, 1], 3:[1, 1] +w = split_2d(parallel_context, w, tesseract_dim, col_first=False) + +vocab_embedding_2p5d = VocabParallelEmbedding2p5D( + 10, 16, parallel_context=parallel_context +) +vocab_embedding_2p5d.weight.data.copy_(w) + +pout = vocab_embedding_2p5d(input_) +optimizer = torch.optim.Adam(vocab_embedding_2p5d.parameters(), lr=1e-3) +logits = torch.nn.MSELoss()(pout, target) +logits.backward() +optimizer.step() + +if parallel_context.get_global_rank() == 0: + unwrapped_model = utils.unwrap_parallel(vocab_embedding_2p5d) + print(f"original vocab size: {unwrapped_model.orig_vocab_size}") + + +# +# +# pout_update = vocab_embedding_2p5d(input_) +# +# pout = gather_2d(parallel_context, pout, tesseract_dim, col_first=False) +# pout_update = gather_2d(parallel_context, pout_update, tesseract_dim, col_first=False) +# +# if parallel_context.get_global_rank() == 0: +# print(f"parallel output: \n{pout}\n") +# print(f"parallel update output: \n{pout_update}\n") +# +# if parallel_context.get_global_rank() == 0: +# sse = torch.sum((out - pout) ** 2).item() +# sse_update = torch.sum((out_update - pout_update) ** 2).item() +# print(f"output sse: \n{sse}\n") +# print(f"next output sse: 
\n{sse_update}\n") \ No newline at end of file diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py index 4d1a7299..3057110e 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py @@ -8,6 +8,26 @@ from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params +import time + + +def time_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end-start + return wrapper + + +@time_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@time_trace +def bw(tensors): + return tensors.backward() # parallel context 생성 parallel_context = ParallelContext.from_torch( @@ -17,8 +37,13 @@ tensor_parallel_mode=ParallelMode.TENSOR_1D, ) +model_name = "gpt2" +mkwargs = { +} +dataset_name = "squad" + # 토크나이저 생성 -tokenizer = AutoTokenizer.from_pretrained("gpt2") +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) tokenizer.pad_token = tokenizer.eos_token # 모델 생성 및 병렬화 수행 @@ -38,13 +63,14 @@ # 데이터셋 생성 batch_size = 16 -datasets = load_dataset("squad").data["train"]["context"] +datasets = load_dataset(dataset_name).data["train"]["context"] datasets = [str(sample) for sample in datasets[:500]] dataloader = DataLoader(datasets, batch_size=batch_size) # 모니터링 생성 if dist.get_rank() == 0: - wandb.init(project="oslo", name=f"tp1d_bs{batch_size}") + wandb.init(project="oslo", name=f"{model_name}_tp1d_bs{batch_size}") + cur = time.time() # 모니터링 생성 대기 dist.barrier() @@ -62,17 +88,39 @@ max_length=512, ).to("cuda") - loss_tp = wrapper_tp(**inputs, labels=inputs["input_ids"]).loss - loss_no_tp = model_no_tp(**inputs, labels=inputs["input_ids"]).loss + loss_no_tp, notp_fw_time = \ + fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) if dist.get_rank() == 0: print(f"TP:{loss_tp}, NOTP:{loss_no_tp}") wandb.log({"tp": loss_tp, "notp": loss_no_tp}) - loss_tp.backward() - loss_no_tp.backward() + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) optimizer_tp.step() optimizer_no_tp.step() + if dist.get_rank() == 0: + wandb.log({ + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time}) + # + # loss_tp = wrapper_tp(**inputs, labels=inputs["input_ids"]).loss + # loss_no_tp = model_no_tp(**inputs, labels=inputs["input_ids"]).loss + # + # if dist.get_rank() == 0: + # print(f"TP:{loss_tp}, NOTP:{loss_no_tp}") + # wandb.log({"tp": loss_tp, "notp": loss_no_tp}) + # + # loss_tp.backward() + # loss_no_tp.backward() + # + # optimizer_tp.step() + # optimizer_no_tp.step() + dist.barrier() From 4e345bed93effa26e23fd0b9b79d717990e6d65c Mon Sep 17 00:00:00 2001 From: jason960903 Date: Fri, 17 Jun 2022 10:19:53 +0900 Subject: [PATCH 12/37] fixed 2p5p5d deparallel (deconstruct_qkv) bug --- .../nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index b86d305a..62a81efa 100644 --- 
a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -613,11 +613,13 @@ def _reconstruct_combined_qkv(tensor, tesseract_dim, fusion_degree, is_bias=Fals reshaped_w = tensor.view(tesseract_dim*fusion_degree, -1, last_dim) # print(f"tensor.size: {tensor.size()}, reshaped_w.size: {reshaped_w.size()}") recon_w = torch.cat([ - reshaped_w[:fusion_degree], reshaped_w[fusion_degree:]], 1).view(-1, last_dim).contiguous() + reshaped_w[i * fusion_degree: (i+1) * fusion_degree] + for i in range(tesseract_dim)], 1).view(-1, last_dim).contiguous() else: reshaped_w = tensor.view(fusion_degree*tesseract_dim, -1) recon_w = torch.cat([ - reshaped_w[:fusion_degree], reshaped_w[fusion_degree:]], 1).view(-1, last_dim).contiguous() + reshaped_w[i * fusion_degree: (i+1) * fusion_degree] + for i in range(tesseract_dim)], 1).view(-1, last_dim).contiguous() return recon_w @staticmethod From d4562f7a73e0935f1be3eca9eb4171b51dd58d51 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Thu, 7 Jul 2022 22:27:54 +0900 Subject: [PATCH 13/37] implement tensor 1d deprarllelize --- .../tensor_parallel/_parallel_1d/_wrapper.py | 21 ++++++++++++------- .../_parallel_1d/deparallel/test_qkv.py | 1 + 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 0aa61c04..4d06fb38 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -47,7 +47,7 @@ def __init__( mapping: dict = None, module_args: dict = None ): - super().__init__() + super().__init__(module, parallel_context) self.module = module self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -447,13 +447,18 @@ def _gather_column_linear(self, module): def _gather_row_linear(self, module): self._gather_linear(module, dim=1) - # TODO: fix @staticmethod def _reconstruct_combined_qkv(tensor, world_size, fusion_degree, dim): - last_dim = tensor.size()[dim-1] - reshaped_w = tensor.view(fusion_degree * world_size, -1) - # print(f"tensor.size: {tensor.size()}, reshaped_w.size: {reshaped_w.size()}") + if dim == 0: + reshaped_w = tensor + else: + reshaped_w = tensor.permute( + dim, *range(0, dim), *range(dim+1, tensor.dim())) + reshaped_w = reshaped_w.view(world_size, fusion_degree, -1) recon_w = torch.cat([ - reshaped_w[i * fusion_degree: (i+1) * fusion_degree] - for i in range(world_size)], 1).view(-1, last_dim).contiguous() - return recon_w \ No newline at end of file + reshaped_w[i] + for i in range(world_size)], 1) + recon_w = recon_w.view(recon_w.size()[0] * world_size, recon_w.size()[1]//world_size).contiguous() + if dim == 0: + recon_w = recon_w.permute(1, 0) + return recon_w diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py index 18f3cb04..003d3d5e 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py @@ -42,6 +42,7 @@ print(chunked_w.size()) print(linear_1d.weight.data.size()) linear_1d.weight.data = chunked_w +linear_1d.bias.data = chunked_b recon_chunked_w = gather_1d(parallel_context, linear_1d.weight.data, world_size, dim) recon_chunked_b = gather_1d(parallel_context, linear_1d.bias.data, 
world_size, 0) From 72ca5db3dc37529be6f1f8658f11cc219e289276 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Tue, 12 Jul 2022 09:24:22 +0900 Subject: [PATCH 14/37] fix 2p5d deparallel shape error --- .../_parallel_2p5d/_wrapper.py | 197 ++---------------- .../deparallel/test_deparallelize.py | 4 +- 2 files changed, 15 insertions(+), 186 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 62a81efa..f3866fb0 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -462,9 +462,10 @@ def _slice_head(self, module, reversed): @torch.no_grad() def deparallelize(self): + # must deparallelize linear first than embedding self._deparallelize_linear() - self._deparallelize_embedding() self._deparallelize_layernorm() + self._deparallelize_embedding() self._rollback_mp_arguments() def _rollback_mp_arguments(self): @@ -481,6 +482,8 @@ def _deparallelize_embedding(self): for param_name, module in self.module.named_modules(): if module.__class__ == VocabParallelEmbedding2p5D: self._gather_embedding(module) + if module.__class__ == Embedding2p5D: + self._gather_embedding(module) def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): @@ -495,14 +498,18 @@ def _deparallelize_layernorm(self): def _gather_embedding(self, module): tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): - w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim, col_first=True) + w = module.weight.data + + # if module is shared with linear, then skip this loop + if module.embedding_dim == module.weight.size()[0]: + w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim, col_first=True) assert hasattr( self.module, "orig_vocab_size" ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." 
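# A minimal, single-process sketch of the vocab padding that the `w[:orig_vocab_size, :]`
# trim just below undoes. The pad-to-a-multiple rule and all names here are illustrative
# assumptions; in OSLO the resized size comes from _resize_vocab_size and the original
# size is kept on the wrapper as `orig_vocab_size`.
import torch

def pad_vocab(weight: torch.Tensor, multiple: int) -> torch.Tensor:
    # pad the vocab (row) dimension so it splits evenly across tensor-parallel ranks
    vocab, dim = weight.shape
    padded = -(-vocab // multiple) * multiple  # ceil to the next multiple
    return torch.cat([weight, weight.new_zeros(padded - vocab, dim)], dim=0)

full = torch.randn(50257, 64)        # toy "vocab x hidden" embedding
padded = pad_vocab(full, 8)          # e.g. tensor_parallel_size = 8 -> 50264 rows
restored = padded[:full.size(0), :]  # what the trim recovers after gathering
assert torch.equal(full, restored)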
orig_vocab_size = self.module.orig_vocab_size - module.weight.data = w[:, :orig_vocab_size] + module.weight.data = w[:orig_vocab_size, :] _update_module_arguments( module=module, @@ -521,7 +528,7 @@ def _gather_embedding(self, module): _update_module_arguments( module=module, parallel_context=None, - embedding_dim = module.weight.size()[1] + embedding_dim=module.weight.size()[1] ) module.__class__ = nn.Embedding @@ -533,18 +540,17 @@ def _gather_linear(self, module: Linear2p5D): tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim=tesseract_dim, col_first=True) - # print(f"w shape: {w.shape}\nweight shape: {module.weight.data.shape}") if fusion_degree > 1: w = self._reconstruct_combined_qkv(w, tesseract_dim, fusion_degree, False) if is_reversed: - w = module.weight.data.t() + w = w.t() module.weight.data = w if hasattr(module, "bias") and module.bias is not None: - # if slice_bias is True and module.bias.dim() >= 1: b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) if fusion_degree > 1: b = self._reconstruct_combined_qkv(b, tesseract_dim, fusion_degree, True) + b = b.view(b.size()[1:]) module.bias.data = b _update_module_arguments( @@ -611,7 +617,6 @@ def _reconstruct_combined_qkv(tensor, tesseract_dim, fusion_degree, is_bias=Fals last_dim = tensor.size()[-1] if is_bias is False: reshaped_w = tensor.view(tesseract_dim*fusion_degree, -1, last_dim) - # print(f"tensor.size: {tensor.size()}, reshaped_w.size: {reshaped_w.size()}") recon_w = torch.cat([ reshaped_w[i * fusion_degree: (i+1) * fusion_degree] for i in range(tesseract_dim)], 1).view(-1, last_dim).contiguous() @@ -631,179 +636,3 @@ def _reconstrunct_combined_qkv_bias(tensor, tessearct_dim, fusion_degree): tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) tensor = [tensor[j] for j in range(tessearct_dim)] return tensor - - # @torch.no_grad() - # def save_parallelized( - # self, - # new_module, - # save_directory: Union[str, os.PathLike], - # save_config: bool = True, - # state_dict: Optional[dict] = None, - # save_function: Callable = torch.save, - # merge_checkpoints: bool = False, - # mapping: Optional[dict] = None, - # **kwargs, - # ): - # logger = getLogger("Tensor2p5D") - # PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" - # - # if ( - # self.parallel_context.get_world_size(ParallelMode.TENSOR) == 1 - # and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1 - # ): - # if dist.get_rank() == 0: - # self.save_pretrained( - # save_directory=save_directory, - # save_config=save_config, - # state_dict=state_dict, - # save_function=save_function, - # **kwargs, - # ) - # dist.barrier() - # return None - # - # if merge_checkpoints: - # model_to_save = self.__class__( - # module=new_module, - # parallel_context=self.parallel_context, - # mapping=mapping, - # module_args=self.config - # ).eval() - # - # if state_dict is None: - # state_dict = self.state_dict() - # - # model_to_save.load_state_dict(state_dict) - # allocate_params(model_to_save, self.parallel_context) - # - # if self.parallel_context.get_world_size(ParallelMode.TENSOR) > 1: - # model_to_save.deparallelize() - # - # if dist.get_rank() == 0: - # if is_huggingface_model(model_to_save.module): - # model_to_save.module.save_pretrained( - # save_directory=save_directory, - # save_config=save_config, - # save_function=save_function, - # **kwargs, - # ) - # else: - # if save_config: - # with 
open(os.path.join(save_directory, "config.json"), "w") as f: - # json.dump(self.config, f) - # save_function( - # model_to_save, - # os.path.join(save_directory, "pytorch_model.bin"), - # ) - # del model_to_save - # - # dist.barrier() - # return None - # - # if os.path.isfile(save_directory): - # logger.error( - # f"Provided path ({save_directory}) should be a directory, not a file" - # ) - # return - # - # os.makedirs(save_directory, exist_ok=True) - # - # # Only save the model itself if we are using distributed training - # model_to_save = unwrap_parallel(self) - # - # # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" - # # we currently don't use this setting automatically, but may start to use with v5 - # dtype = get_parameter_dtype(model_to_save) - # model_to_save.config.torch_dtype = str(dtype).split(".")[1] - # - # # Attach architecture to the config - # model_to_save.config.architectures = [model_to_save.__class__.__name__] - # - # # Save the config - # if save_config: - # model_to_save.config.save_pretrained(save_directory) - # - # # Save the model - # if state_dict is None: - # state_dict = model_to_save.state_dict() - # - # # Handle the case where some state_dict keys shouldn't be saved - # if getattr(self, "_keys_to_ignore_on_save", None) is not None: - # state_dict = { - # k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save - # } - # - # # If we save using the predefined names, we can load using `from_pretrained` - # weights_name = PARALLELIZED_WEIGHTS_NAME - # weights_name = weights_name.replace( - # "tp_0", f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}" - # ) - # weights_name = weights_name.replace( - # "pp_0", f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}" - # ) - # - # output_model_file = os.path.join(save_directory, weights_name) - # - # if self.parallel_context.get_world_size(ParallelMode.DATA) > 1: - # if self.parallel_context.get_local_rank(ParallelMode.DATA) == 0: - # save_function(state_dict, output_model_file) - # else: - # save_function(state_dict, output_model_file) - # - # dist.barrier() - # logger.info(f"Model weights saved in {output_model_file}") - # - # def from_parallelized(self, path): - # """ - # Example: - # >>> model = AnyModel() - # >>> model = TensorParallel(model, ...) 
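# A minimal sketch of the two checkpoint flows exercised by the deparallel tests in this
# patch series. It assumes a multi-process launch (e.g. torchrun on 4 GPUs); directory
# names and the tp_size used here are placeholders, not values fixed by the library.
from transformers import GPT2Config, GPT2LMHeadModel
from oslo.torch.distributed import ParallelContext, ParallelMode
from oslo.torch.nn.parallel.tensor_parallel import TensorParallel
from oslo.torch.nn.parallel.utils import allocate_params

parallel_context = ParallelContext.from_torch(
    data_parallel_size=1,
    pipeline_parallel_size=1,
    tensor_parallel_size=4,
    tensor_parallel_mode=ParallelMode.TENSOR_2D,
    tensor_parallel_depth=1,
)
wrapper_tp = TensorParallel(
    GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), parallel_context
)
allocate_params(wrapper_tp, parallel_context)

# 1) merged save: shards are gathered back (deparallelize) into a single checkpoint
#    that plain Hugging Face code can load again without OSLO.
wrapper_tp.save_parallelized("ckpt_merged/", merge_checkpoints=True)
model_gathered = GPT2LMHeadModel.from_pretrained("ckpt_merged/")

# 2) sharded save: one file per (tp, pp) rank, reloaded into a parallelized wrapper.
wrapper_tp.save_parallelized("ckpt_sharded/", merge_checkpoints=False)
model_reparallel = TensorParallel(
    GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), parallel_context
)
allocate_params(model_reparallel, parallel_context)
model_reparallel.from_parallelized("ckpt_sharded/")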
- # >>> model.from_parallelized(path) - # """ - # PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" - # parallelized_model_path = path - # - # file_names = { - # os.path.join( - # parallelized_model_path, - # PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp}").replace( - # "pp_0", f"pp_{pp}" - # ), - # ) - # for tp in range(self.parallel_context.get_world_size(ParallelMode.TENSOR)) - # for pp in range(self.parallel_context.get_world_size(ParallelMode.PIPELINE)) - # } - # - # if os.path.isdir(parallelized_model_path): - # if all(os.path.isfile(file_name) for file_name in file_names): - # state_dict = torch.load( - # os.path.join( - # parallelized_model_path, - # PARALLELIZED_WEIGHTS_NAME.replace( - # "tp_0", - # f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}", - # ).replace( - # "pp_0", - # f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}", - # ), - # ) - # ) - # - # if getattr(self, "_keys_to_ignore_on_save", None) is not None: - # state_dict = { - # k: v - # for k, v in state_dict.items() - # if k not in self._keys_to_ignore_on_save - # } - # - # self.load_state_dict(state_dict=state_dict, strict=False) - # - # else: - # raise FileNotFoundError( - # f"all the {file_names} are necessary. " - # f"but some of them do not exist. Please check your checkpoint files." - # ) - # else: - # raise NotADirectoryError( - # f"directory named {parallelized_model_path} is not valid. " - # ) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py index c5ac9152..46aaa839 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py @@ -78,13 +78,13 @@ def bw(tensors): cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', merge_checkpoints=False) +wrapper_tp.save_parallelized('test/', merge_checkpoints=True) # 모니터링 생성 대기 dist.barrier() # 로드 -model_gathered = GPT2LMHeadModel(GPT2Config.from_pretrained("test/")).cuda() +model_gathered = GPT2LMHeadModel.from_pretrained("test/").cuda() optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) dist.barrier() From 731650cffc5444b55038e477fdc99ec98295c290 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Tue, 19 Jul 2022 08:19:24 +0900 Subject: [PATCH 15/37] fix test code --- .../parallel/tensor_parallel/_base_wrapper.py | 4 +- .../tensor_parallel/_parallel_1d/_wrapper.py | 63 ++---- .../tensor_parallel/_parallel_2d/_ops.py | 61 ++++++ .../tensor_parallel/_parallel_2d/_wrapper.py | 185 +++++++++++++++++- .../deparallel/test_deparallelize.py | 6 +- .../deparallel/test_load_parallel.py | 8 +- .../_parallel_2d/deparallel/__init__.py | 3 + .../deparallel/test_deparallelize.py | 139 +++++++++++++ .../deparallel/test_load_parallel.py | 140 +++++++++++++ .../_parallel_2d/deparallel/test_qkv.py | 62 ++++++ .../_parallel_2p5d/deparallel/test_vocab.py | 85 ++++++++ 11 files changed, 699 insertions(+), 57 deletions(-) create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/__init__.py create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py create mode 100644 
tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py diff --git a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py index fd5d302a..f22438cd 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py @@ -24,7 +24,7 @@ class BaseTensorParallelWrapper(ParallelWrapper): """ - PyTorch module for 2.5D tensor parallelism + PyTorch module for xD tensor parallelism Args: module (nn.Module): model object @@ -52,7 +52,7 @@ def save_parallelized( mapping: Optional[dict] = None, **kwargs, ): - logger = getLogger("Tensor2p5D") + logger = getLogger("TensorParallel") PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" if ( diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 4d06fb38..4613d335 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -332,6 +332,8 @@ def _deparallelize_embedding(self): for param_name, module in self.module.named_modules(): if module.__class__ == VocabParallelEmbedding1D: self._gather_embedding(module) + if module.__class__ == Embedding1D: + self._gather_embedding(module) def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): @@ -358,7 +360,7 @@ def _gather_embedding(self, module): ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." orig_vocab_size = self.module.orig_vocab_size - module.weight.data = w[:, :orig_vocab_size] + module.weight.data = w[:orig_vocab_size, :] _update_module_arguments( module=module, @@ -390,35 +392,15 @@ def _gather_linear(self, module, dim=1): is_reversed = module.reversed fusion_degree = module.fusion_degree - world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) - - # w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim=tesseract_dim, col_first=True) - tensor_list = [torch.zeros_like(module.weight.data) for _ in range(world_size)] - dist.all_gather( - tensor_list, - module.weight.data.contiguous(), - self.parallel_context.get_group(ParallelMode.TENSOR_1D), - ) - w = torch.cat(tensor_list, dim=dim) + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR) - if fusion_degree > 1: - w = self._reconstruct_combined_qkv(w, world_size, fusion_degree, False) + w = self._reconstruct_combined_qkv(module.weight, world_size, fusion_degree, dim) if is_reversed: - w = module.weight.data.t() + w = w.t() module.weight.data = w - if hasattr(module, "bias") and module.bias is not None: - # if slice_bias is True and module.bias.dim() >= 1: - # b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, dim) - tensor_list = [torch.zeros_like(module.bias.data) for _ in range(world_size)] - dist.all_gather( - tensor_list, - module.bias.data.contiguous(), - self.parallel_context.get_group(ParallelMode.TENSOR_1D), - ) - b = torch.cat(tensor_list, dim=dim) - if fusion_degree > 1: - b = self._reconstruct_combined_qkv(b, world_size, fusion_degree, dim) + if hasattr(module, "bias") and module.bias is not None and dim != 1: + b = self._reconstruct_combined_qkv(module.bias, world_size, fusion_degree, dim) module.bias.data = b _update_module_arguments( @@ -430,14 +412,10 @@ def _gather_linear(self, module, dim=1): if hasattr(module, "skip_bias_add") else False, ) - del module.data_parallel_rank - del 
module.pipeline_parallel_rank - del module.tensor_parallel_size - del module.pipeline_parallel_size + del module.reversed del module.fusion_degree del module.orig_module - del module.gather_output del module.parallel_context module.__class__ = nn.Linear @@ -447,18 +425,11 @@ def _gather_column_linear(self, module): def _gather_row_linear(self, module): self._gather_linear(module, dim=1) - @staticmethod - def _reconstruct_combined_qkv(tensor, world_size, fusion_degree, dim): - if dim == 0: - reshaped_w = tensor - else: - reshaped_w = tensor.permute( - dim, *range(0, dim), *range(dim+1, tensor.dim())) - reshaped_w = reshaped_w.view(world_size, fusion_degree, -1) - recon_w = torch.cat([ - reshaped_w[i] - for i in range(world_size)], 1) - recon_w = recon_w.view(recon_w.size()[0] * world_size, recon_w.size()[1]//world_size).contiguous() - if dim == 0: - recon_w = recon_w.permute(1, 0) - return recon_w + def _reconstruct_combined_qkv(self, tensor, world_size, fusion_degree, dim: int): + tensor_list = tensor.chunk(fusion_degree, dim=dim) + result_list = [] + for w in tensor_list: + w_list = [torch.zeros_like(w) for _ in range(world_size)] + dist.all_gather(w_list, w, self.parallel_context.get_group(ParallelMode.TENSOR_1D)) + result_list.append(torch.cat(w_list, dim=dim)) + return torch.cat(result_list, dim=dim) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py index 61226a6e..56d1c02c 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py @@ -144,6 +144,67 @@ def gather_batch_2d( ) +def gather_2d(parallel_context, tensor, summa_dim, col_first=True): + if col_first: + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2D_ROW), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor + + +def gather_1d_twice(parallel_context, tensor, summa_dim, dim=-1): + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2D_COL), + ) + tensor = torch.cat(tensor_list, dim=dim) + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, tensor, parallel_context.get_group(ParallelMode.TENSOR_2D_ROW) + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor + + def split_batch_2d( inputs: Tensor, 
dim: int = 0, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index b14ddf59..31818e44 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -17,6 +17,9 @@ ) from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._ops import ( split_batch_2d, + gather_2d, + gather_1d, + gather_1d_twice ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, @@ -31,10 +34,14 @@ _TensorParallelMappingForHuggingFace, ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, +) + from oslo.transformers.constants import BATCH_DIMENSIONS -class _TensorParallel2D(ParallelWrapper): +class _TensorParallel2D(BaseTensorParallelWrapper): """ PyTorch module for 2D tensor parallelism @@ -51,7 +58,7 @@ def __init__( module_args: dict = None ): - super().__init__() + super().__init__(module, parallel_context, mapping, module_args) self.module = module self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -422,3 +429,177 @@ def _slice_head(self, module, reversed): orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear2D + + @torch.no_grad() + def deparallelize(self): + self._deparallelize_linear() + self._deparallelize_layernorm() + self._deparallelize_embedding() + self._rollback_mp_arguments() + + def _rollback_mp_arguments(self): + for module in self.module.modules(): + for elem in self.tensor_parallel_mapping.update_attrs(self.module): + if hasattr(module, elem.name): + summa_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2D_COL + ) + expanded_arg = getattr(module, elem.name) * summa_dim + setattr(module, elem.name, expanded_arg) + + def _deparallelize_embedding(self): + for param_name, module in self.module.named_modules(): + if module.__class__ in [VocabParallelEmbedding2D, Embedding2D]: + self._gather_embedding(module) + + def _deparallelize_linear(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == Linear2D: + self._gather_linear(module) + + def _deparallelize_layernorm(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == LayerNorm2D: + self._gather_layernorm(module) + + def _gather_embedding(self, module): + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + w = module.weight.data + + if module.embedding_dim == module.weight.size()[0]: + w = gather_2d(self.parallel_context, module.weight.data, summa_dim, col_first=True) + + assert hasattr( + self.module, "orig_vocab_size" + ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." 
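# The fused-QKV handling added below (`_gather_linear` -> `_reconstruct_combined_qkv`)
# has to undo a per-rank interleaving: each rank stores its own slice of Q, K and V
# stacked together, so a plain all-gather returns [Q0 K0 V0 Q1 K1 V1 ...] instead of
# the original [Q0 Q1 ... K0 K1 ... V0 V1 ...]. A toy, single-process illustration with
# plain torch follows (sizes and fusion_degree=3 are illustrative only; the 2D/2.5D code
# gathers across two process groups, but the reordering idea is the same):
import torch

world_size, fusion_degree, hidden = 2, 3, 4
full = torch.arange(fusion_degree * hidden * hidden, dtype=torch.float32)
full = full.view(fusion_degree * hidden, hidden)  # fused [Q; K; V] weight

# what each rank holds: its slice of Q, of K and of V, stacked together
shards = [
    torch.cat(
        [part.chunk(world_size, dim=0)[rank] for part in full.chunk(fusion_degree, dim=0)],
        dim=0,
    )
    for rank in range(world_size)
]

gathered = torch.cat(shards, dim=0)  # rank-contiguous order: [Q0 K0 V0 Q1 K1 V1]

# regroup rank-contiguous blocks back into fusion-contiguous order, mirroring the
# reconstruction helper
blocks = gathered.view(world_size * fusion_degree, -1, hidden)
recon = torch.cat(
    [blocks[r * fusion_degree : (r + 1) * fusion_degree] for r in range(world_size)],
    dim=1,
).reshape(-1, hidden)
assert torch.equal(recon, full)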
+ orig_vocab_size = self.module.orig_vocab_size + + module.weight.data = w[:orig_vocab_size, :] + + _update_module_arguments( + module=module, + vocab_start_index=None, + vocab_end_index=None, + parallel_context=None, + num_embeddings=module.weight.size()[0], + embedding_dim=module.weight.size()[1], + orig_module=None + ) + else: + w = gather_1d_twice(self.parallel_context, module.weight.data, summa_dim, 1) + module.weight.data = w + + _update_module_arguments( + module=module, + parallel_context=None, + embedding_dim=module.weight.size()[1] + ) + module.__class__ = nn.Embedding + + def _gather_linear(self, module: Linear2D): + is_reversed = module.reversed + fusion_degree = module.fusion_degree + # slice_bias = module.slice_bias + + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) + + w = gather_2d(self.parallel_context, module.weight.data, summa_dim=summa_dim, col_first=True) + # print(f"w shape: {w.shape}\nweight shape: {module.weight.data.shape}") + if fusion_degree > 1: + w = self._reconstruct_combined_qkv(w, summa_dim, fusion_degree, False) + if is_reversed: + w = w.t() + module.weight.data = w + + if hasattr(module, "bias") and module.bias is not None: + # if slice_bias is True and module.bias.dim() >= 1: + b = gather_1d_twice(self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0) + if fusion_degree > 1: + b = self._reconstruct_combined_qkv(b, summa_dim, fusion_degree, True) + b = b.view(b.size()[1:]) + module.bias.data = b + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.row_rank + del module.col_rank + del module.summa_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + + def _gather_layernorm(self, module): + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) + if hasattr(module, "weight") and module.weight is not None: + if module.weight.dim() >= 1: + w = gather_1d_twice(self.parallel_context, module.weight.data, summa_dim=summa_dim, dim=0) + module.weight.data = w + + if hasattr(module.weight, "oslo_parallel"): + del module.weight.oslo_parallel + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + b = gather_1d_twice(self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0) + module.bias.data = b + + if hasattr(module.bias, "oslo_parallel"): + del module.bias.oslo_parallel + + del module.partitioned_dim + del module.row_rank + del module.col_rank + del module.summa_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.orig_module + _update_module_arguments( + module, + normalized_shape=module.weight.size()[0], + ) + module.__class__ = nn.LayerNorm + + @staticmethod + def _reconstruct_combined_qkv(tensor, summa_dim, fusion_degree, is_bias=False): + last_dim = tensor.size()[-1] + if is_bias is False: + reshaped_w = tensor.view(summa_dim*fusion_degree, -1, last_dim) + # print(f"tensor.size: {tensor.size()}, reshaped_w.size: {reshaped_w.size()}") + recon_w = torch.cat([ + reshaped_w[i * 
fusion_degree: (i+1) * fusion_degree] + for i in range(summa_dim)], 1).view(-1, last_dim).contiguous() + else: + reshaped_w = tensor.view(fusion_degree*summa_dim, -1) + recon_w = torch.cat([ + reshaped_w[i * fusion_degree: (i+1) * fusion_degree] + for i in range(summa_dim)], 1).view(-1, last_dim).contiguous() + return recon_w + + @staticmethod + def _reconstrunct_combined_qkv_bias(tensor, summa_dim, fusion_degree): + tensor = [ + [tensor[j * summa_dim + k] for k in range(summa_dim)] + for j in range(fusion_degree) + ] + tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) + tensor = [tensor[j] for j in range(summa_dim)] + return tensor + diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py index a95aacd1..986dc365 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py @@ -74,17 +74,17 @@ def bw(tensors): # 모니터링 생성 if dist.get_rank() == 0: - wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") + wandb.init(project="oslo", name=f"{model_name}_tp1d_bs{batch_size}") cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', merge_checkpoints=False) +wrapper_tp.save_parallelized('test/', merge_checkpoints=True) # 모니터링 생성 대기 dist.barrier() # 로드 -model_gathered = GPT2LMHeadModel(GPT2Config.from_pretrained("test/")).cuda() +model_gathered = GPT2LMHeadModel.from_pretrained("test/").cuda() optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py index 51f9009c..a0f5b3b7 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py @@ -30,8 +30,8 @@ def bw(tensors): return tensors.backward() -tp_size = 8 -tp_depth = 2 +tp_size = 4 +tp_depth = 1 model_name = "gpt2" mkwargs = { @@ -43,7 +43,7 @@ def bw(tensors): data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=tp_size, - tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_mode=ParallelMode.TENSOR_1D, tensor_parallel_depth=tp_depth, ) @@ -74,7 +74,7 @@ def bw(tensors): # 모니터링 생성 if dist.get_rank() == 0: - wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") + wandb.init(project="oslo", name=f"{model_name}_tp1d_bs{batch_size}") cur = time.time() # 저장 diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/__init__.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/__init__.py new file mode 100644 index 00000000..28fcc144 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/__init__.py @@ -0,0 +1,3 @@ +from .. 
import _utils + +_ALL__ = [_utils] diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py new file mode 100644 index 00000000..6d790952 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py @@ -0,0 +1,139 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end-start + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 4 +tp_depth = 1 + +model_name = "gpt2" +mkwargs = { +} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() +model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized('test/', merge_checkpoints=True) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_gathered = GPT2LMHeadModel.from_pretrained("test/").cuda() +optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + optimizer_gathered.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = \ + fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = \ + fw(model_gathered, **inputs, labels=inputs["input_ids"]) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": loss_gathered}) + + _, notp_bw_time = 
bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, gathered_bw_time = bw(loss_gathered) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_gathered.step() + + if dist.get_rank() == 0: + wandb.log({ + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time + }) + +dist.barrier() + + + + diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py new file mode 100644 index 00000000..d0a12ab3 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py @@ -0,0 +1,140 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end-start + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 4 +tp_depth = 1 + +model_name = "gpt2" +mkwargs = { +} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() +model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized('test/', merge_checkpoints=False) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_reparallel = TensorParallel( + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), + parallel_context +) +allocate_params(model_reparallel, parallel_context) +model_reparallel.from_parallelized('test/') +optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + 
model_reparallel.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = \ + fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = \ + fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "reparallel": loss_reparallel}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, reparallel_bw_time = bw(loss_reparallel) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_reparallel.step() + + if dist.get_rank() == 0: + wandb.log({ + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time + }) + +dist.barrier() \ No newline at end of file diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py new file mode 100644 index 00000000..6cb09fd8 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py @@ -0,0 +1,62 @@ +import torch.nn as nn + +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._wrapper import _TensorParallel2p5D +from oslo.torch.nn import Linear2p5D +from oslo.torch.distributed import ParallelContext, ParallelMode +from copy import deepcopy +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import gather_1d, gather_2d + +tp_size = 4 +tp_depth = 1 + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2D, + tensor_parallel_depth=tp_depth, +) + +row_rank = parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW) +col_rank = parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL) +summa_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) +fusion_degree = 3 + +linear = nn.Linear(4, fusion_degree*4).cuda() +w = deepcopy(linear.weight.data) +b = deepcopy(linear.bias.data) + +weight_list = w.chunk(summa_dim, dim=1) +weight_list = [ + weight.chunk(summa_dim * fusion_degree, dim=0) for weight in weight_list +] +bias_list = b.chunk(summa_dim * fusion_degree, dim=0) + +# [t][f*t] +weight_list = _TensorParallel2p5D._deconstruct_combined_qkv(weight_list, tesseract_dim, fusion_degree, False) +bias_list = _TensorParallel2p5D._deconstruct_combined_qkv(bias_list, tesseract_dim, fusion_degree, True) +chunked_w = weight_list[row_rank][col_rank] +chunked_b = bias_list[row_rank] + +linear_2d = Linear2p5D(4, fusion_degree*4, parallel_context=parallel_context, bias=True) +if parallel_context.get_global_rank() == 0: + print(chunked_w.size()) + print(linear_2d.weight.data.size()) +linear_2d.weight.data = chunked_w +linear_2d.bias.data = chunked_b + +recon_chunked_w = gather_2d(parallel_context, linear_2d.weight.data, summa_dim, True) +recon_chunked_b = gather_1d(parallel_context, linear_2d.bias.data, summa_dim, 0) +# reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) +# recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() + +recon_w = 
_TensorParallel2p5D._reconstruct_combined_qkv(recon_chunked_w, summa_dim, fusion_degree, False) +recon_b = _TensorParallel2p5D._reconstruct_combined_qkv(recon_chunked_b, summa_dim, fusion_degree, True) + +if parallel_context.get_global_rank() == 0: + print(f"original w: \n{w}\n") + print(f"reconstruct w: \n{recon_w}\n") + + print(f"original b: \n{b}\n") + print(f"reconstruct b: \n{recon_b}\n") \ No newline at end of file diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py new file mode 100644 index 00000000..3edfde0c --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py @@ -0,0 +1,85 @@ +import torch +import torch.distributed as dist + +from oslo.torch.distributed import ParallelContext, ParallelMode +from oslo.torch.nn import VocabParallelEmbedding2p5D +from oslo.torch.nn.parallel import utils + +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import split_batch_2d, split_2d, gather_2d + +from copy import deepcopy + + +tp_size = 8 +tp_depth = 2 + +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +torch.set_printoptions(sci_mode=False) +torch.manual_seed(0) +tesseract_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) +input_ = torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]]).cuda() +target = torch.randn((2, 4, 16)).cuda() +dist.broadcast(input_, src=0) +dist.broadcast(target, src=0) + +vocab_embedding = torch.nn.Embedding(10, 16).cuda() +w = deepcopy(vocab_embedding.weight.data) + +out = vocab_embedding(input_) +optimizer = torch.optim.Adam(vocab_embedding.parameters(), lr=1e-3) +logits = torch.nn.MSELoss()(out, target) +logits.backward() +optimizer.step() + +out_update = vocab_embedding(input_) + +if parallel_context.get_global_rank() == 0: + print(f"original output: \n{out}\n") + print(f"original update output: \n{out_update}\n") + # print(f"vocab start: {vocab_embedding.start_index}, vocab end: {vocab_embedding.end_index}") + +input_ = split_batch_2d(parallel_context, input_, tesseract_dim) +# split target into 0:[0, 0], 1:[0, 1], 2:[1, 0], 3:[1, 1] +target = split_2d(parallel_context, target, tesseract_dim, col_first=True) +# split weight into 0:[0, 0], 1:[1, 0], 2:[0, 1], 3:[1, 1] +w = split_2d(parallel_context, w, tesseract_dim, col_first=False) + +vocab_embedding_2p5d = VocabParallelEmbedding2p5D( + 10, 16, parallel_context=parallel_context +) +vocab_embedding_2p5d.weight.data.copy_(w) + +pout = vocab_embedding_2p5d(input_) +optimizer = torch.optim.Adam(vocab_embedding_2p5d.parameters(), lr=1e-3) +logits = torch.nn.MSELoss()(pout, target) +logits.backward() +optimizer.step() + +if parallel_context.get_global_rank() == 0: + unwrapped_model = utils.unwrap_parallel(vocab_embedding_2p5d) + print(f"original vocab size: {unwrapped_model.orig_vocab_size}") + + +# +# +# pout_update = vocab_embedding_2p5d(input_) +# +# pout = gather_2d(parallel_context, pout, tesseract_dim, col_first=False) +# pout_update = gather_2d(parallel_context, pout_update, tesseract_dim, col_first=False) +# +# if parallel_context.get_global_rank() == 0: +# print(f"parallel output: \n{pout}\n") +# print(f"parallel update output: \n{pout_update}\n") +# +# if parallel_context.get_global_rank() == 0: +# sse = torch.sum((out - pout) ** 2).item() +# 
sse_update = torch.sum((out_update - pout_update) ** 2).item() +# print(f"output sse: \n{sse}\n") +# print(f"next output sse: \n{sse_update}\n") \ No newline at end of file From 6a76699b45cb23ef3eebf47c4eb46121765ad9ef Mon Sep 17 00:00:00 2001 From: bzantium Date: Wed, 20 Jul 2022 14:26:39 +0900 Subject: [PATCH 16/37] Add splitting bias for head --- .../tensor_parallel/_parallel_1d/_wrapper.py | 13 +++++++++++++ .../tensor_parallel/_parallel_2d/_wrapper.py | 17 +++++++++++++++++ .../tensor_parallel/_parallel_2p5d/_wrapper.py | 18 ++++++++++++++++++ 3 files changed, 48 insertions(+) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 4613d335..f45baf9d 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -299,6 +299,19 @@ def _slice_head(self, module, reversed): gather_output=not is_oslo_model(self.module), ) else: + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_1D) + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + bias_list = module.bias.data.chunk(world_size, dim=0) + module.bias.data = bias_list[rank].contiguous() + + if hasattr(module.bias, "oslo_parallel"): + module.bias.oslo_parallel[ParallelMode.TENSOR_1D] = rank + else: + module.bias.oslo_parallel = {ParallelMode.TENSOR_1D: rank} + _update_module_arguments( module=module, parallel_context=self.parallel_context, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 31818e44..818a9cb3 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -410,6 +410,23 @@ def _slice_head(self, module, reversed): pipeline_parallel_size = self.parallel_context.get_world_size( ParallelMode.PIPELINE ) + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + bias_list = module.bias.data.chunk(summa_dim, dim=0) + bias_list = [ + bias.chunk(summa_dim, dim=0) for bias in bias_list + ] + module.bias.data = bias_list[row_rank][col_rank].contiguous() + + if hasattr(module.bias, "oslo_parallel"): + module.bias.oslo_parallel[ParallelMode.TENSOR_2D_ROW] = row_rank + module.bias.oslo_parallel[ParallelMode.TENSOR_2D_COL] = col_rank + else: + module.bias.oslo_parallel = { + ParallelMode.TENSOR_2D_ROW: row_rank, + ParallelMode.TENSOR_2D_COL: col_rank, + } + _update_module_arguments( module=module, in_features=module.weight.size()[1], diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index f3866fb0..d649f4d8 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -439,6 +439,24 @@ def _slice_head(self, module, reversed): pipeline_parallel_size = self.parallel_context.get_world_size( ParallelMode.PIPELINE ) + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + bias_list = module.bias.data.chunk(tesseract_dim, dim=0) + + module.bias.data = bias_list[row_rank].contiguous() + + if hasattr(module.bias, "oslo_parallel"): + module.bias.oslo_parallel[ParallelMode.TENSOR_2P5D_ROW] = row_rank + 
module.bias.oslo_parallel[ParallelMode.TENSOR_2P5D_COL] = col_rank + module.weight.oslo_parallel[ParallelMode.TENSOR_2P5D_DEP] = dep_rank + else: + module.bias.oslo_parallel = { + ParallelMode.TENSOR_2P5D_ROW: row_rank, + ParallelMode.TENSOR_2P5D_COL: col_rank, + ParallelMode.TENSOR_2P5D_DEP: dep_rank, + } + _update_module_arguments( module=module, in_features=module.weight.size()[1], From 176c374f57ccb0448923e325df647197ee6cce45 Mon Sep 17 00:00:00 2001 From: bzantium Date: Wed, 20 Jul 2022 14:29:40 +0900 Subject: [PATCH 17/37] Add space for consistency --- oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py | 1 + 1 file changed, 1 insertion(+) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 818a9cb3..d7146fcd 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -410,6 +410,7 @@ def _slice_head(self, module, reversed): pipeline_parallel_size = self.parallel_context.get_world_size( ParallelMode.PIPELINE ) + if hasattr(module, "bias") and module.bias is not None: if module.bias.dim() >= 1: bias_list = module.bias.data.chunk(summa_dim, dim=0) From 1629363d01d1f81646d0cc836ce320e944141510 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Sat, 6 Aug 2022 11:14:26 +0900 Subject: [PATCH 18/37] fixed roberta gathering --- oslo/torch/nn/modules/linear.py | 2 + .../_parallel_2p5d/_wrapper.py | 10 +++-- .../tensor_parallel/tensor_parallel.py | 2 +- oslo/transformers/mapping_utils.py | 23 ++++++---- .../deparallel/test_deparallelize.py | 45 ++++++++++++++++--- 5 files changed, 63 insertions(+), 19 deletions(-) diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index d174af8e..f127c3fd 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -360,6 +360,7 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: parallel_mode=ParallelMode.TENSOR_2D_ROW, parallel_context=self.parallel_context, ).clone() + outputs = outputs.contiguous() return outputs @@ -502,6 +503,7 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: col_parallel_mode=ParallelMode.TENSOR_2P5D_ROW, parallel_context=self.parallel_context, ).clone() + output = output.contiguous() return output diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index d649f4d8..189bf912 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -159,6 +159,9 @@ def _parallelize_head(self): reversed=self.tensor_parallel_mapping.is_reversed( self.module, param_name ), + gathered=self.tensor_parallel_mapping.is_gather_output( + self.module, param_name + ) ) @staticmethod @@ -403,7 +406,7 @@ def _slice_layernorm(self, module): module.__class__ = LayerNorm2p5D return module - def _slice_head(self, module, reversed): + def _slice_head(self, module, reversed, gathered): if module.weight is not self.module.get_input_embeddings().weight: self._slice_linear( module=module, @@ -413,7 +416,7 @@ def _slice_head(self, module, reversed): ) _update_module_arguments( module=module, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gathered, ) else: row_rank = self.parallel_context.get_local_rank( @@ -457,6 +460,7 @@ def 
_slice_head(self, module, reversed): ParallelMode.TENSOR_2P5D_DEP: dep_rank, } + gather_output = None _update_module_arguments( module=module, in_features=module.weight.size()[1], @@ -473,7 +477,7 @@ def _slice_head(self, module, reversed): reversed=reversed, fusion_degree=1, skip_bias_add=False, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gathered, orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear2p5D diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index d35dcc89..b52ad877 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -122,7 +122,7 @@ def _resize_vocab_size(model, parallel_context): module.weight.data = new_embeddings module.num_embeddings = new_vocab_size - setattr(unwrapped_model, "orig_vocab_size", vocab_size) + setattr(unwrapped_model, "orig_vocab_size", vocab_size) return model @staticmethod diff --git a/oslo/transformers/mapping_utils.py b/oslo/transformers/mapping_utils.py index 48093289..42983722 100644 --- a/oslo/transformers/mapping_utils.py +++ b/oslo/transformers/mapping_utils.py @@ -72,6 +72,7 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): "sop_classifier.classifier", "classifier", "qa_outputs", + gather_output=True ), ], "Bart": [ @@ -79,51 +80,52 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): Column("classification_head.dense", gather_output=True), Row("out_proj", "fc2"), Update("embed_dim", "num_heads"), - Head("lm_head", "classification_head.out_proj", "qa_outputs"), + Head("lm_head", "classification_head.out_proj", "qa_outputs", gather_output=True), ], "Bert": [ Column("query", "key", "value", "intermediate.dense"), Column("pooler.dense", gather_output=True), Row("output.dense"), Update("num_attention_heads", "all_head_size"), - Head("decoder", "seq_relationship", "classifier", "qa_outputs"), + Head("transform.dense", gather_output=False), + Head("decoder", "seq_relationship", "classifier", "qa_outputs", gather_output=True), ], "Blenderbot": [ Column("q_proj", "k_proj", "v_proj", "fc1"), Row("out_proj", "fc2"), Update("embed_dim", "num_heads"), - Head("lm_head"), + Head("lm_head", gather_output=True), ], "BlenderbotSmall": [ Column("q_proj", "k_proj", "v_proj", "fc1"), Row("out_proj", "fc2"), Update("embed_dim", "num_heads"), - Head("lm_head"), + Head("lm_head", gather_output=True), ], "T5": [ Column("q", "k", "v", "DenseReluDense.wi"), Row("o", "DenseReluDense.wo", "relative_attention_bias"), Update("d_model", "n_heads", "inner_dim"), - Head("lm_head"), + Head("lm_head", gather_output=True), ], "GPT2": [ Column("c_attn", reversed=True, combined_qkv=True), Column("c_fc", "q_attn", reversed=True), Row("c_proj", reversed=True), Update("embed_dim", "split_size", "num_heads"), - Head("lm_head", "score", "classifier", "summary"), + Head("lm_head", "score", "classifier", "summary", gather_output=True), ], "GPTNeo": [ Column("q_proj", "k_proj", "v_proj", "c_fc"), Row("out_proj", "c_proj"), Update("embed_dim", "num_heads"), - Head("lm_head", "score", "qa_outputs"), + Head("lm_head", "score", "qa_outputs", gather_output=True), ], "GPTJ": [ Column("q_proj", "k_proj", "v_proj", "fc_in"), Row("out_proj", "fc_out"), Update("embed_dim", "num_attention_heads"), - Head("lm_head", "score"), + Head("lm_head", "score", gather_output=True), ], "Electra": [ Column("query", "key", 
"value", "intermediate.dense"), @@ -143,6 +145,7 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): "classifier", "qa_outputs", "summary", + gather_output=True ), ], "Roberta": [ @@ -155,7 +158,9 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): ), Row("output.dense"), Update("num_attention_heads", "all_head_size"), - Head("lm_head.decoder", "classifier.out_proj", "classifier", "qa_outputs"), + Head("lm_head.decoder", "classifier.out_proj", "classifier", "qa_outputs", + gather_output=True + ), ], } diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py index 46aaa839..a8b3b7ad 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py @@ -1,15 +1,46 @@ import torch.distributed as dist import wandb from datasets import load_dataset + +import torch from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, AutoModelForCausalLM, AutoConfig from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params from oslo.torch.distributed import ParallelContext, ParallelMode import time +import numpy as np +import os +import random + + +def seed_all(seed: int = 1930): + + print("Using Seed Number {}".format(seed)) + + os.environ["PYTHONHASHSEED"] = str( + seed) # set PYTHONHASHSEED env var at fixed value + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) + np.random.seed(seed) # for numpy pseudo-random generator + random.seed( + seed) # set fixed value for python built-in pseudo-random generator + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + + +def seed_worker(_worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + +seed_all(seed=1994) + def latency_trace(func): def wrapper(*args, **kwargs): @@ -33,7 +64,7 @@ def bw(tensors): tp_size = 8 tp_depth = 2 -model_name = "gpt2" +model_name = "jason9693/soongsil-bert-base" mkwargs = { } dataset_name = "squad" @@ -52,8 +83,10 @@ def bw(tensors): tokenizer.pad_token = tokenizer.eos_token # 모델 생성 및 병렬화 수행 -model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() -model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +model_no_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name)).cuda() +model_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name)) wrapper_tp = TensorParallel(model_tp, parallel_context) allocate_params(wrapper_tp, parallel_context) # allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 @@ -84,7 +117,7 @@ def bw(tensors): dist.barrier() # 로드 -model_gathered = GPT2LMHeadModel.from_pretrained("test/").cuda() +model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) dist.barrier() @@ -129,7 +162,7 @@ def bw(tensors): "notp.forward.time:": notp_fw_time, "notp.backward.time:": notp_bw_time, "gathered.forward.time:": gathered_fw_time, - "gathered.backward.time:": 
gathered_bw_time + "gathered.backward.time:": gathered_bw_time, }) dist.barrier() From deaeeb0ffb63a5c36ef76e355601f54fb6f5571c Mon Sep 17 00:00:00 2001 From: kevin-ai Date: Sat, 6 Aug 2022 17:12:11 +0900 Subject: [PATCH 19/37] test finished tp2d/2p5d deparallelize --- .../deparallel/test_deparallelize.py | 45 ++++++++++++++++--- .../deparallel/test_deparallelize.py | 45 ++++++++++++++++--- 2 files changed, 78 insertions(+), 12 deletions(-) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py index 986dc365..275dc3f7 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py @@ -1,14 +1,45 @@ import torch.distributed as dist import wandb from datasets import load_dataset + +import torch from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, AutoModelForCausalLM, AutoConfig from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params from oslo.torch.distributed import ParallelContext, ParallelMode + +import numpy as np import time +import os +import random + + +def seed_all(seed: int = 1930): + + print("Using Seed Number {}".format(seed)) + + os.environ["PYTHONHASHSEED"] = str( + seed) # set PYTHONHASHSEED env var at fixed value + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) + np.random.seed(seed) # for numpy pseudo-random generator + random.seed( + seed) # set fixed value for python built-in pseudo-random generator + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + + +def seed_worker(_worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + +seed_all(seed=1994) def latency_trace(func): @@ -33,7 +64,7 @@ def bw(tensors): tp_size = 4 tp_depth = 1 -model_name = "gpt2" +model_name = "jason9693/soongsil-bert-base" mkwargs = { } dataset_name = "squad" @@ -52,8 +83,10 @@ def bw(tensors): tokenizer.pad_token = tokenizer.eos_token # 모델 생성 및 병렬화 수행 -model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() -model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +model_no_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name)).cuda() +model_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name)) wrapper_tp = TensorParallel(model_tp, parallel_context) allocate_params(wrapper_tp, parallel_context) # allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 @@ -74,7 +107,7 @@ def bw(tensors): # 모니터링 생성 if dist.get_rank() == 0: - wandb.init(project="oslo", name=f"{model_name}_tp1d_bs{batch_size}") + wandb.init(project="oslo", name=f"{model_name}_tp2d_bs{batch_size}") cur = time.time() # 저장 @@ -84,7 +117,7 @@ def bw(tensors): dist.barrier() # 로드 -model_gathered = GPT2LMHeadModel.from_pretrained("test/").cuda() +model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py 
b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py index 6d790952..26286378 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py @@ -1,14 +1,45 @@ import torch.distributed as dist import wandb from datasets import load_dataset + +import torch from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, AutoModelForCausalLM, AutoConfig from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params from oslo.torch.distributed import ParallelContext, ParallelMode + +import numpy as np import time +import os +import random + + +def seed_all(seed: int = 1930): + + print("Using Seed Number {}".format(seed)) + + os.environ["PYTHONHASHSEED"] = str( + seed) # set PYTHONHASHSEED env var at fixed value + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) + np.random.seed(seed) # for numpy pseudo-random generator + random.seed( + seed) # set fixed value for python built-in pseudo-random generator + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + + +def seed_worker(_worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + +seed_all(seed=1994) def latency_trace(func): @@ -33,7 +64,7 @@ def bw(tensors): tp_size = 4 tp_depth = 1 -model_name = "gpt2" +model_name = "jason9693/soongsil-bert-base" mkwargs = { } dataset_name = "squad" @@ -43,7 +74,7 @@ def bw(tensors): data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=tp_size, - tensor_parallel_mode=ParallelMode.TENSOR_2D, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, tensor_parallel_depth=tp_depth, ) @@ -52,8 +83,10 @@ def bw(tensors): tokenizer.pad_token = tokenizer.eos_token # 모델 생성 및 병렬화 수행 -model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() -model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +model_no_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name)).cuda() +model_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name)) wrapper_tp = TensorParallel(model_tp, parallel_context) allocate_params(wrapper_tp, parallel_context) # allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 @@ -84,7 +117,7 @@ def bw(tensors): dist.barrier() # 로드 -model_gathered = GPT2LMHeadModel.from_pretrained("test/").cuda() +model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) dist.barrier() From bc4d2246e49a075f8890dfc566f71faba98d50fa Mon Sep 17 00:00:00 2001 From: kevin-ai Date: Sun, 7 Aug 2022 23:18:40 +0900 Subject: [PATCH 20/37] deparallelize temp commit --- oslo/torch/nn/modules/linear.py | 3 +- .../parallel/tensor_parallel/_base_wrapper.py | 4 ++ .../tensor_parallel/_parallel_1d/_wrapper.py | 40 +++++++++++--- .../_parallel_2p5d/_wrapper.py | 4 +- .../tensor_parallel/tensor_parallel.py | 7 ++- oslo/transformers/mapping_utils.py | 2 +- .../deparallel/test_deparallelize.py | 55 ++++++++++--------- .../deparallel/test_deparallelize.py | 5 +- 8 files changed, 80 insertions(+), 40 deletions(-) diff --git 
a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index f127c3fd..43af941c 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -165,6 +165,7 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: if self.gather_output: outputs = all_gather_tensor_1d(outputs, -1, self.parallel_context).clone() + outputs = outputs.contiguous() return outputs @@ -221,7 +222,7 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: else: return outputs + self.bias - return outputs + return outputs.contiguous() class Linear2D(Linear): diff --git a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py index f22438cd..0af768d4 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py @@ -77,6 +77,10 @@ def save_parallelized( mapping=mapping, module_args=self.config ).eval() + # model_to_save = self.clone() + ## resize vocab & num_class + + if state_dict is None: state_dict = self.state_dict() diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index f45baf9d..6ba54e1d 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -133,8 +133,8 @@ def _parallelize_head(self): self._slice_head( module=module, reversed=self.tensor_parallel_mapping.is_reversed( - self.module, param_name - ), + self.module, param_name, + ) ) @staticmethod @@ -170,6 +170,32 @@ def _slice_embedding(self, module): orig_module=copy.deepcopy(module.__class__), ) module.__class__ = VocabParallelEmbedding1D + + for name, module_head in self.module.named_modules(): + if ( + hasattr(module_head, "weight") + and module_head.weight is module.weight + and not isinstance(module_head, nn.Embedding) + and not self.tensor_parallel_mapping.is_head( + self.module, name + ) + ): + _update_module_arguments( + module=module_head, + parallel_context=self.parallel_context, + reversed=self.tensor_parallel_mapping.is_reversed(self.module, name), + fusion_degree=1, + orig_module=copy.deepcopy(module_head.__class__), + out_features=module.weight.size()[0], + # in_features=module.weight.size()[1], + gather_output=not is_oslo_model(self.module), + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + + if isinstance(module_head, nn.Linear) or isinstance(module_head, nn.Conv1D): + module_head.__class__ = ColumnParallelLinear else: weight_list = module.weight.data.chunk(world_size, dim=1) module.weight.data = weight_list[rank].contiguous() @@ -290,13 +316,13 @@ def _row_slice_linear(self, module: nn.Module, reversed: bool, fusion_degree: in ) module.__class__ = RowParallelLinear - def _slice_head(self, module, reversed): + def _slice_head(self, module, reversed, gather_output=True): if module.weight is not self.module.get_input_embeddings().weight: self._column_slice_linear( module=module, reversed=reversed, fusion_degree=1, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, ) else: world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) @@ -318,7 +344,7 @@ def _slice_head(self, module, reversed): reversed=reversed, fusion_degree=1, orig_module=copy.deepcopy(module.__class__), - gather_output=not is_oslo_model(self.module), + gather_output=not 
is_oslo_model(self.module) and gather_output, skip_bias_add=module.skip_bias_add if hasattr(module, "skip_bias_add") else False, @@ -350,10 +376,10 @@ def _deparallelize_embedding(self): def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): - if self.tensor_parallel_mapping.is_column_parallel(self.module, param_name): + if module.__class__ == ColumnParallelLinear: self._gather_column_linear(module) - elif self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + elif module.__class__ == RowParallelLinear: self._gather_row_linear(module) def _gather_embedding(self, module): diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 189bf912..be8f1dca 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -529,9 +529,9 @@ def _gather_embedding(self, module): assert hasattr( self.module, "orig_vocab_size" ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." - orig_vocab_size = self.module.orig_vocab_size + # orig_vocab_size = self.module.orig_vocab_size - module.weight.data = w[:orig_vocab_size, :] + # module.weight.data = w[:orig_vocab_size, :] _update_module_arguments( module=module, diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index b52ad877..4d9e8982 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -1,3 +1,4 @@ +import imp from typing import Union, Optional, Callable import os @@ -210,9 +211,11 @@ def save_parallelized( new_module = unwrapped_model.__class__(self.module.config) else: new_module = unwrapped_model.__class__(**self.module.config) + new_module = self._resize_vocab_size(new_module, self.parallel_context) new_module = self._resize_num_classes(new_module, self.parallel_context, mapping) - return self.module.save_parallelized( + + new_module = self.module.save_parallelized( new_module, save_directory, save_config, @@ -223,6 +226,8 @@ def save_parallelized( **kwargs, ) + return new_module + @staticmethod def get_module_args(module): state_dict = module.state_dict() diff --git a/oslo/transformers/mapping_utils.py b/oslo/transformers/mapping_utils.py index 42983722..a3cf45b4 100644 --- a/oslo/transformers/mapping_utils.py +++ b/oslo/transformers/mapping_utils.py @@ -151,13 +151,13 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): "Roberta": [ Column("query", "key", "value", "intermediate.dense"), Column( - "lm_head.dense", "classifier.dense", "roberta.pooler", gather_output=True, ), Row("output.dense"), Update("num_attention_heads", "all_head_size"), + Head("lm_head.dense"), Head("lm_head.decoder", "classifier.out_proj", "classifier", "qa_outputs", gather_output=True ), diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py index 275dc3f7..2d8d5d8c 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py @@ -6,6 +6,7 @@ from torch.optim import Adam from torch.utils.data import DataLoader from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, 
AutoModelForCausalLM, AutoConfig +from transformers import AutoModelForMaskedLM from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params @@ -66,9 +67,11 @@ def bw(tensors): model_name = "jason9693/soongsil-bert-base" mkwargs = { + 'pad_token': '[PAD]' } dataset_name = "squad" + # parallel context 생성 parallel_context = ParallelContext.from_torch( data_parallel_size=1, @@ -80,11 +83,11 @@ def bw(tensors): # 토크나이저 생성 tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) -tokenizer.pad_token = tokenizer.eos_token +# tokenizer.pad_token = tokenizer.eos_token -# 모델 생성 및 병렬화 수행 -model_no_tp = AutoModelForCausalLM.from_config( - AutoConfig.from_pretrained(model_name)).cuda() +# # 모델 생성 및 병렬화 수행 +# model_no_tp = AutoModelForCausalLM.from_config( +# AutoConfig.from_pretrained(model_name)).cuda() model_tp = AutoModelForCausalLM.from_config( AutoConfig.from_pretrained(model_name)) wrapper_tp = TensorParallel(model_tp, parallel_context) @@ -97,7 +100,7 @@ def bw(tensors): # 옵티마이저 생성 optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) -optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) +# optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) # 데이터셋 생성 batch_size = 16 @@ -107,7 +110,7 @@ def bw(tensors): # 모니터링 생성 if dist.get_rank() == 0: - wandb.init(project="oslo", name=f"{model_name}_tp2d_bs{batch_size}") + wandb.init(project="oslo", name=f"{model_name}_tp1d_bs{batch_size}") cur = time.time() # 저장 @@ -116,17 +119,17 @@ def bw(tensors): # 모니터링 생성 대기 dist.barrier() -# 로드 -model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() -optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) +# # 로드 +# model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() +# optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) dist.barrier() # 학습 시작 for data in dataloader: optimizer_tp.zero_grad() - optimizer_no_tp.zero_grad() - optimizer_gathered.zero_grad() + # optimizer_no_tp.zero_grad() + # optimizer_gathered.zero_grad() inputs = tokenizer( data, @@ -136,33 +139,33 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = \ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + # loss_no_tp, notp_fw_time = \ + # fw(model_no_tp, **inputs, labels=inputs["input_ids"]) loss_tp, tp_fw_time = \ fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - loss_gathered, gathered_fw_time = \ - fw(model_gathered, **inputs, labels=inputs["input_ids"]) + # loss_gathered, gathered_fw_time = \ + # fw(model_gathered, **inputs, labels=inputs["input_ids"]) - if dist.get_rank() == 0: - print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") - wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": loss_gathered}) + # if dist.get_rank() == 0: + # print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") + # wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": loss_gathered}) - _, notp_bw_time = bw(loss_no_tp) + # _, notp_bw_time = bw(loss_no_tp) _, tp_bw_time = bw(loss_tp) - _, gathered_bw_time = bw(loss_gathered) + # _, gathered_bw_time = bw(loss_gathered) optimizer_tp.step() - optimizer_no_tp.step() - optimizer_gathered.step() + # optimizer_no_tp.step() + # optimizer_gathered.step() if dist.get_rank() == 0: wandb.log({ "tp.forward.time:": tp_fw_time, "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time, - "gathered.forward.time:": gathered_fw_time, - "gathered.backward.time:": gathered_bw_time + # 
"notp.forward.time:": notp_fw_time, + # "notp.backward.time:": notp_bw_time, + # "gathered.forward.time:": gathered_fw_time, + # "gathered.backward.time:": gathered_bw_time }) dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py index a8b3b7ad..56b7a483 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py @@ -64,8 +64,9 @@ def bw(tensors): tp_size = 8 tp_depth = 2 -model_name = "jason9693/soongsil-bert-base" +model_name = "bert-base-uncased" mkwargs = { + 'pad_token': '[PAD]' } dataset_name = "squad" @@ -80,7 +81,7 @@ def bw(tensors): # 토크나이저 생성 tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) -tokenizer.pad_token = tokenizer.eos_token +# tokenizer.pad_token = tokenizer.eos_token # 모델 생성 및 병렬화 수행 model_no_tp = AutoModelForCausalLM.from_config( From 5ad567542d68fc69600123efcf4e3be977d2a237 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Wed, 10 Aug 2022 00:32:04 +0900 Subject: [PATCH 21/37] fixed 2p5d bert crashed --- .gitignore | 7 + .../parallel/tensor_parallel/_base_wrapper.py | 5 - .../tensor_parallel/_parallel_2p5d/_ops.py | 398 ++++++++++++++++-- .../_parallel_2p5d/_wrapper.py | 101 ++++- .../tensor_parallel/tensor_parallel.py | 59 ++- oslo/transformers/mapping_utils.py | 1 - oslo/transformers/models/__init__.py | 11 + 7 files changed, 513 insertions(+), 69 deletions(-) diff --git a/.gitignore b/.gitignore index 61186f7e..8b1e6911 100644 --- a/.gitignore +++ b/.gitignore @@ -389,3 +389,10 @@ usecases /usecases */usecases wandb/ + +# multi gpu mem log +**/core.* + +# sample huggingface models +**/pytorch_model.bin +**/config.json \ No newline at end of file diff --git a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py index 0af768d4..5d100e9f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py @@ -77,10 +77,6 @@ def save_parallelized( mapping=mapping, module_args=self.config ).eval() - # model_to_save = self.clone() - ## resize vocab & num_class - - if state_dict is None: state_dict = self.state_dict() @@ -223,4 +219,3 @@ def from_parallelized(self, path): @torch.no_grad() def deparallelize(self): return NotImplementedError - diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py index 272d7d62..ac471f93 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py @@ -9,6 +9,40 @@ from oslo.torch.distributed.nn.functional import all_reduce, reduce_scatter, all_gather +def classifier_2p5d( + A: Tensor, + B: Tensor, + bias, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + parallel_context: ParallelContext, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, +) -> Tensor: + return _Classifier2p5D.apply( + A, + B, + bias, + tesseract_dim, + out_shape, + col_rank, + row_rank, + data_parallel_rank, + pipeline_parallel_rank, + pipeline_parallel_size, + tensor_parallel_size, + parallel_context, + 
row_parallel_mode, + col_parallel_mode, + ) + + def add_bias_2p5d( input: Tensor, bias: Tensor, @@ -56,15 +90,6 @@ def layernorm_2p5d( ) -def all_gather_tensor_2p5d( - inputs: Tensor, - dim: int, - parallel_context: ParallelContext, - col_parallel_mode: ParallelMode, -) -> Tensor: - return _AllGatherTensor2p5D.apply(inputs, dim, parallel_context, col_parallel_mode) - - def gather_batch_2p5d( inputs: Tensor, dim: int = 0, @@ -83,24 +108,13 @@ def gather_batch_2p5d( ) -def split_batch_2p5d( +def all_gather_tensor_2p5d( inputs: Tensor, - dim: int = 0, - parallel_context: Optional[ParallelContext] = None, + dim: int, + parallel_context: ParallelContext, + col_parallel_mode: ParallelMode, ) -> Tensor: - dim_size = inputs.size(dim) - world_size = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) - - if world_size <= 1: - return inputs - - assert ( - dim_size % world_size == 0 - ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." - - return torch.chunk( - inputs, parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL), dim=dim - )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL)].contiguous() + return _AllGatherTensor2p5D.apply(inputs, dim, parallel_context, col_parallel_mode) def reduce_by_batch_2p5d( @@ -130,6 +144,30 @@ def reduce_scatter_tensor_2p5d( return _ReduceScatterTensor2p5D.apply(inputs, dim, parallel_context, parallel_mode) +def split_batch_2p5d( + inputs: Tensor, dim: int, parallel_context: ParallelContext +) -> Tensor: + dim_size = inputs.size(dim) + world_size = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + + if world_size <= 1: + return inputs + + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." + + col_chunked = torch.chunk( + inputs, parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL), dim=dim + )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL)].contiguous() + return col_chunked + # return torch.chunk( + # col_chunked, + # parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_DEP), + # dim=dim, + # )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP)].contiguous() + + def get_current_device(): r""" Get current device. @@ -137,6 +175,116 @@ def get_current_device(): return torch.cuda.current_device() +# TODO: 만약 2D와 비슷할 경우 병합(?) 
+class _Classifier2p5D(torch.autograd.Function): + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward( + ctx: Any, + A: Tensor, + B: Tensor, + bias, + tesseract_dim: int, + out_shape: Tuple[int, ...], + row_rank: int, + col_rank: int, + data_parallel_rank: int, + pipeline_parallel_rank: int, + pipeline_parallel_size: int, + tensor_parallel_size: int, + parallel_context: ParallelContext, + row_parallel_mode: ParallelMode, + col_parallel_mode: ParallelMode, + ) -> Tensor: + A = A.clone().detach() + A_shape = A.shape + A = A.reshape((-1, A_shape[-1])) + B_shape = B.shape + B = B.reshape((-1, B_shape[-1])) + B_temp = all_gather( + B, -1, parallel_context=parallel_context, parallel_mode=col_parallel_mode + ) + if ctx: + ctx.save_for_backward(A, B_temp) + + C = torch.matmul(A, B_temp.transpose(0, 1)) + + C = all_reduce( + C, + parallel_context=parallel_context, + parallel_mode=row_parallel_mode, + ) + + ctx.use_bias = bias is not None + if bias is not None: + C = C + bias + + out = C.reshape(out_shape) + + if ctx: + ctx.tesseract_dim = tesseract_dim + ctx.col_rank = col_rank + ctx.row_rank = row_rank + ctx.A_shape = A_shape + ctx.B_shape = B_shape + ctx.data_parallel_rank = data_parallel_rank + ctx.pipeline_parallel_rank = pipeline_parallel_rank + ctx.pipeline_parallel_size = pipeline_parallel_size + ctx.tensor_parallel_size = tensor_parallel_size + ctx.parallel_context = parallel_context + ctx.row_parallel_mode = row_parallel_mode + ctx.col_parallel_mode = col_parallel_mode + + return out + + @staticmethod + @custom_bwd + def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: + A, B = ctx.saved_tensors + + with torch.no_grad(): + A_grad = torch.matmul(output_grad, B) + A_grad = A_grad.reshape(ctx.A_shape) + B_grad = torch.matmul( + output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A + ) + B_grad = reduce_scatter( + B_grad, + -1, + parallel_context=ctx.parallel_context, + parallel_mode=ctx.col_parallel_mode, + ) + B_grad = B_grad.reshape(ctx.B_shape) + + if ctx.use_bias: + bias_grad = torch.sum( + output_grad, dim=tuple(range(output_grad.ndim - 1)) + ) + bias_grad = all_reduce( + bias_grad, + parallel_context=ctx.parallel_context, + parallel_mode=ctx.col_parallel_mode, + ) + else: + bias_grad = None + + return ( + A_grad, + B_grad, + bias_grad, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + class Matmul_AB_2p5D(torch.autograd.Function): @staticmethod @custom_fwd(cast_inputs=torch.float16) @@ -157,6 +305,9 @@ def forward( row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, ) -> Tensor: + # A: [b / dq, s, h / q] -> [(b * s) / dq, h / q] + # B: [h / dq, s / q] + # C: [b / dq, s, s / q] -> [(b * s) / dq, s / q] assert A.shape[-1] == B.shape[-2], "Invalid shapes: A={}, B={} for AB.".format( A.shape, B.shape @@ -229,11 +380,11 @@ def forward( if ctx: ctx.tesseract_dim = tesseract_dim - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.row_rank = row_rank ctx.col_rank = col_rank + ctx.row_rank = row_rank ctx.dep_rank = dep_rank + ctx.A_shape = A_shape + ctx.B_shape = B_shape ctx.data_parallel_rank = data_parallel_rank ctx.pipeline_parallel_rank = pipeline_parallel_rank ctx.pipeline_parallel_size = pipeline_parallel_size @@ -399,12 +550,13 @@ def forward( out = C.reshape(out_shape) if ctx: + ctx.parallel_context = parallel_context ctx.tesseract_dim = tesseract_dim - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.row_rank = row_rank ctx.col_rank = col_rank + ctx.row_rank = row_rank ctx.dep_rank = 
dep_rank + ctx.A_shape = A_shape + ctx.B_shape = B_shape ctx.data_parallel_rank = data_parallel_rank ctx.pipeline_parallel_rank = pipeline_parallel_rank ctx.pipeline_parallel_size = pipeline_parallel_size @@ -467,6 +619,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: None, None, None, + None, ) @@ -567,11 +720,11 @@ def forward( if ctx: ctx.tesseract_dim = tesseract_dim - ctx.A_shape = A_shape - ctx.B_shape = B_shape - ctx.row_rank = row_rank ctx.col_rank = col_rank + ctx.row_rank = row_rank ctx.dep_rank = dep_rank + ctx.A_shape = A_shape + ctx.B_shape = B_shape ctx.data_parallel_rank = data_parallel_rank ctx.pipeline_parallel_rank = pipeline_parallel_rank ctx.pipeline_parallel_size = pipeline_parallel_size @@ -634,6 +787,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: None, None, None, + None, ) @@ -674,10 +828,10 @@ def forward( ) if ctx: - ctx.tesseract_dim = tesseract_dim - ctx.row_rank = row_rank ctx.col_rank = col_rank + ctx.row_rank = row_rank ctx.dep_rank = dep_rank + ctx.tesseract_dim = tesseract_dim ctx.bias = skip_bias_add ctx.data_parallel_rank = data_parallel_rank ctx.pipeline_parallel_rank = pipeline_parallel_rank @@ -695,10 +849,10 @@ def forward( @staticmethod @custom_bwd def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: - tesseract_dim = ctx.tesseract_dim - row_rank = ctx.row_rank col_rank = ctx.col_rank + row_rank = ctx.row_rank dep_rank = ctx.dep_rank + tesseract_dim = ctx.tesseract_dim data_parallel_rank = ctx.data_parallel_rank pipeline_parallel_rank = ctx.pipeline_parallel_rank pipeline_parallel_size = ctx.pipeline_parallel_size @@ -734,6 +888,8 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: None, None, None, + None, + None, ) else: grad_tmp = torch.zeros_like(output_grad) @@ -752,6 +908,8 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: None, None, None, + None, + None, ) else: reduce_dim = tuple(range(output_grad.ndim - 1)) @@ -783,6 +941,9 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: None, None, None, + None, + None, + None, ) else: reduce_tmp = torch.zeros_like(reduce) @@ -801,6 +962,9 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: None, None, None, + None, + None, + None, ) @@ -851,7 +1015,7 @@ def backward(ctx, output_grad): input_grad -= output_grad_sum input_grad *= Var_x - return input_grad, None, None, None, None, None + return input_grad, None, None, None, None, None, None class _AllGatherTensor2p5D(torch.autograd.Function): @@ -926,7 +1090,7 @@ def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: output_grad.contiguous(), group=ctx.parallel_context.get_group(ctx.col_parallel_mode), ) - return grad, None, None, None + return grad, None, None class _ReduceTensor2p5D(torch.autograd.Function): @@ -948,6 +1112,18 @@ def backward(ctx: Any, output_grad: Tensor): return output_grad, None, None +# def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: +# r"""All-reduce the input. +# Args: +# input_ (:class:`torch.tensor`): Input tensor. +# parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. +# Note: +# The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found +# in `parallel_mode `_ +# """ +# return _ReduceTensor2p5D.apply(input_, parallel_mode) + + class _ReduceScatterTensor2p5D(torch.autograd.Function): @staticmethod def forward( @@ -1029,6 +1205,144 @@ def forward( @custom_bwd def backward(ctx: Any, output_grad: Tensor): if ctx.reduce_mean: - return output_grad / ctx.reduce_size, None, None + return output_grad / ctx.reduce_size, None else: - return output_grad, None, None + return output_grad, None + + +def split_batch_2d(parallel_context, tensor, tesseract_dim): + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def split_2p5d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + else: + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = torch.chunk(tensor, tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP) + ] + return tensor + + +def split_2d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + else: + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def split_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor = tensor.chunk(summa_dim, dim=dim)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def gather_2p5d(parallel_context, tensor, tesseract_dim, col_first=True): + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, tensor, parallel_context.get_group(ParallelMode.TENSOR_2P5D_DEP) + ) + tensor = torch.cat(tensor_list, dim=0) + if col_first: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_2d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor_list = 
[torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_1d(parallel_context, tensor, summa_dim, dim=-1): + parallel_modde = ParallelMode.TENSOR_2P5D_ROW + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(parallel_modde), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor \ No newline at end of file diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 2ba89713..adf9a986 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -10,7 +10,7 @@ VocabUtility, Embedding2p5D, ) -from oslo.torch.nn.modules.linear import Linear2p5D +from oslo.torch.nn.modules.linear import Linear, Linear2p5D from oslo.torch.nn.modules.layer_norm import LayerNorm2p5D from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( split_batch_2p5d, @@ -462,7 +462,6 @@ def _slice_head(self, module, reversed, gathered): ParallelMode.TENSOR_2P5D_DEP: dep_rank, } - gather_output = None _update_module_arguments( module=module, in_features=module.weight.size()[1], @@ -484,12 +483,40 @@ def _slice_head(self, module, reversed, gathered): ) module.__class__ = Linear2p5D + def _zero_rank_log(self, txt): + import torch.distributed as dist + if dist.get_rank() == 0: + print(txt) + # 모니터링 생성 대기 + dist.barrier() + + def _pdb_set_trace(self): + import pdb + import torch.distributed as dist + if dist.get_rank() == 0: + pdb.set_trace() + # 모니터링 생성 대기 + dist.barrier() + @torch.no_grad() def deparallelize(self): - # must deparallelize linear first than embedding + # must deparallelize embedding first than linear + self._zero_rank_log("deparallelize embedding start") + self._deparallelize_embedding() + self._zero_rank_log("deparallelize embedding end") + + self._zero_rank_log("deparallelize linear start") self._deparallelize_linear() + self._zero_rank_log("deparallelize linear end") + + self._zero_rank_log("deparallelize layernorm start") self._deparallelize_layernorm() - self._deparallelize_embedding() + self._zero_rank_log("deparallelize layernorm end") + + self._zero_rank_log("deparallelize head start") + self._deparallelize_head() + self._zero_rank_log("deparallelize head end") + self._rollback_mp_arguments() def _rollback_mp_arguments(self): @@ -511,9 +538,19 @@ def _deparallelize_embedding(self): def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): - if module.__class__ == Linear2p5D: + if 
self.tensor_parallel_mapping.is_column_parallel( + self.module, param_name + ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): self._gather_linear(module) + def _deparallelize_head(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_head( + self.module, param_name + ) and isinstance(module, Linear2p5D): + self._zero_rank_log(f"deparallelize head {param_name}") + self._gather_head(module) + def _deparallelize_layernorm(self): for param_name, module in self.module.named_modules(): if module.__class__ == LayerNorm2p5D: @@ -521,19 +558,14 @@ def _deparallelize_layernorm(self): def _gather_embedding(self, module): tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) - if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): - w = module.weight.data - - # if module is shared with linear, then skip this loop - if module.embedding_dim == module.weight.size()[0]: - w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim, col_first=True) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim, col_first=True) assert hasattr( self.module, "orig_vocab_size" - ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." - # orig_vocab_size = self.module.orig_vocab_size - - # module.weight.data = w[:orig_vocab_size, :] + ), f"wrapper's vocab embedding module must have attribute 'orig_vocab_size'." + orig_vocab_size = self.module.orig_vocab_size + module.weight.data = w[:orig_vocab_size, :].contiguous() _update_module_arguments( module=module, @@ -556,6 +588,45 @@ def _gather_embedding(self, module): ) module.__class__ = nn.Embedding + def _gather_head(self, module: Linear2p5D): + if module.weight is not self.module.get_input_embeddings().weight: + return self._gather_linear(module) + elif hasattr(module, "bias") and module.bias is not None: + self._zero_rank_log("before gathering bias") + tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + + module.bias.data = b[:module.weight.size()[0]] + self._zero_rank_log("after gathering bias") + + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + + def _gather_linear(self, module: Linear2p5D): is_reversed = module.reversed fusion_degree = module.fusion_degree diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index 27d9e16d..56d356d6 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -98,6 +98,7 @@ def __init__( self.parallel_context = get_parallel_context(module, parallel_context) module = self._resize_vocab_size(module, self.parallel_context) 
module = self._resize_num_classes(module, self.parallel_context, mapping) + module = self._resize_head_bias_size(module, self.parallel_context, mapping) if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: self.module = _TensorParallel1D(module, self.parallel_context, mapping, module_args) @@ -146,8 +147,8 @@ def _resize_vocab_size(model, parallel_context): module.weight.data = new_embeddings module.num_embeddings = new_vocab_size - setattr(module, "orig_num_classes", vocab_size) - setattr(unwrapped_model, "orig_vocab_size", vocab_size) + setattr(module, "orig_num_classes", vocab_size) + setattr(unwrapped_model, "orig_vocab_size", vocab_size) return model @staticmethod @@ -173,6 +174,17 @@ def _resize_num_classes(model, parallel_context, mapping): module.out_features = ( unwrapped_model.get_input_embeddings().num_embeddings ) + + assert hasattr(unwrapped_model.get_input_embeddings(), "orig_num_classes"), ( + "call _resize_vocab before _resize_num_classes" + ) + out_features = unwrapped_model.get_input_embeddings().orig_num_classes + setattr(module, "orig_num_classes", out_features) + setattr( + unwrapped_model, + f"orig_{param_name.split('.')[-1]}_num_classes", + out_features, + ) else: out_features, in_features = module.weight.size() new_out_features = out_features @@ -214,11 +226,45 @@ def _resize_num_classes(model, parallel_context, mapping): ) return model - def _restore_vocab_size(self, model, parallel_context): - pass + @staticmethod + def _resize_head_bias_size(model, parallel_context, mapping): + unwrapped_model = unwrap_parallel(model) + divisible_by = get_divisible_by(parallel_context) + + if mapping is None: + if is_huggingface_model(unwrapped_model): + mapping = _TensorParallelMappingForHuggingFace().get_mapping( + unwrapped_model + ) + else: + raise ValueError( + "`mapping` must be input if the model is not huggingface model." 
+ ) + tensor_parallel_mapping = TensorParallelMapping(mapping) + divisible_by = get_divisible_by(parallel_context) - def _restore_num_classes(self, model, parallel_context): - pass + for param_name, module in unwrapped_model.named_modules(): + if tensor_parallel_mapping.is_head(unwrapped_model, param_name + ) and unwrapped_model.get_input_embeddings().weight is module.weight \ + and hasattr(module, "bias") and module.bias is not None: + out_features = module.bias.size()[0] + new_out_features = out_features + + while new_out_features % divisible_by != 0: + new_out_features += 1 + + if new_out_features != out_features: + padding = torch.zeros( + new_out_features - out_features, + dtype=module.bias.dtype, + device=module.bias.device, + ) + new_bias = torch.cat( + tensors=[module.bias.data, padding], + dim=0, + ) + module.bias.data = new_bias + return model @torch.no_grad() def save_parallelized( @@ -239,6 +285,7 @@ def save_parallelized( new_module = self._resize_vocab_size(new_module, self.parallel_context) new_module = self._resize_num_classes(new_module, self.parallel_context, mapping) + new_module = self._resize_head_bias_size(new_module, self.parallel_context, mapping) new_module = self.module.save_parallelized( new_module, diff --git a/oslo/transformers/mapping_utils.py b/oslo/transformers/mapping_utils.py index a3cf45b4..98f0085f 100644 --- a/oslo/transformers/mapping_utils.py +++ b/oslo/transformers/mapping_utils.py @@ -226,7 +226,6 @@ def __init__(self): HF_TO_OSLO = { transformers.GPT2Model: oslo.transformers.GPT2Model, transformers.GPT2LMHeadModel: oslo.transformers.GPT2LMHeadModel, - transformers.GPT2DoubleHeadsModel: oslo.transformers.GPT2DoubleHeadModel, transformers.GPT2ForSequenceClassification: oslo.transformers.GPT2ForSequenceClassification, transformers.GPT2ForTokenClassification: oslo.transformers.GPT2ForTokenClassification, } diff --git a/oslo/transformers/models/__init__.py b/oslo/transformers/models/__init__.py index e69de29b..6dfb8583 100644 --- a/oslo/transformers/models/__init__.py +++ b/oslo/transformers/models/__init__.py @@ -0,0 +1,11 @@ +from oslo.transformers.models.gpt2.modeling_gpt2 import ( + GPT2Model, + GPT2LMHeadModel, + GPT2ForSequenceClassification, + GPT2ForTokenClassification, +) + + +from oslo.transformers.training_args import TrainingArguments + +# from oslo.transformers.trainer import Trainer From dcfcde251cdbd9cb5f1ce0898f325f710ec53aad Mon Sep 17 00:00:00 2001 From: jason9693 Date: Wed, 10 Aug 2022 07:32:08 +0900 Subject: [PATCH 22/37] fixed tp2d bert crashed --- oslo/torch/nn/modules/linear.py | 1 + .../tensor_parallel/_parallel_2d/_wrapper.py | 74 ++++++++++++++++++- .../_parallel_2p5d/_wrapper.py | 1 - 3 files changed, 73 insertions(+), 3 deletions(-) diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index e50aa2cd..aa693e03 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -512,6 +512,7 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] outputs = outputs.contiguous() + return outputs diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index edf6a69f..714366b7 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -10,6 +10,7 @@ VocabUtility, ) from oslo.torch.nn.modules.linear import ( + 
Linear, Linear2D, ) from oslo.torch.nn.modules.layer_norm import ( @@ -450,13 +451,35 @@ def _slice_head(self, module, reversed): ) module.__class__ = Linear2D + def _zero_rank_log(self, txt): + import torch.distributed as dist + if dist.get_rank() == 0: + print(txt) + # 모니터링 생성 대기 + dist.barrier() + @torch.no_grad() def deparallelize(self): + # must deparallelize embedding first than linear + self._zero_rank_log("deparallelize embedding start") + self._deparallelize_embedding() + self._zero_rank_log("deparallelize embedding end") + + self._zero_rank_log("deparallelize linear start") self._deparallelize_linear() + self._zero_rank_log("deparallelize linear end") + + self._zero_rank_log("deparallelize layernorm start") self._deparallelize_layernorm() - self._deparallelize_embedding() + self._zero_rank_log("deparallelize layernorm end") + + self._zero_rank_log("deparallelize head start") + self._deparallelize_head() + self._zero_rank_log("deparallelize head end") + self._rollback_mp_arguments() + def _rollback_mp_arguments(self): for module in self.module.modules(): for elem in self.tensor_parallel_mapping.update_attrs(self.module): @@ -474,9 +497,19 @@ def _deparallelize_embedding(self): def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): - if module.__class__ == Linear2D: + if self.tensor_parallel_mapping.is_column_parallel( + self.module, param_name + ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): self._gather_linear(module) + def _deparallelize_head(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_head( + self.module, param_name + ) and isinstance(module, Linear2D): + self._zero_rank_log(f"deparallelize head {param_name}") + self._gather_head(module) + def _deparallelize_layernorm(self): for param_name, module in self.module.named_modules(): if module.__class__ == LayerNorm2D: @@ -517,6 +550,43 @@ def _gather_embedding(self, module): ) module.__class__ = nn.Embedding + def _gather_head(self, module: Linear2D): + if module.weight is not self.module.get_input_embeddings().weight: + return self._gather_linear(module) + elif hasattr(module, "bias") and module.bias is not None: + self._zero_rank_log("before gathering bias") + tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + + module.bias.data = b[:module.weight.size()[0]] + self._zero_rank_log("after gathering bias") + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + def _gather_linear(self, module: Linear2D): is_reversed = module.reversed fusion_degree = module.fusion_degree diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index adf9a986..83000946 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py 
+++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -626,7 +626,6 @@ def _gather_head(self, module: Linear2p5D): module.__class__ = nn.Linear - def _gather_linear(self, module: Linear2p5D): is_reversed = module.reversed fusion_degree = module.fusion_degree From 93d1e77a9906f9474a6ce47a3dc2198b2c62385b Mon Sep 17 00:00:00 2001 From: jason9693 Date: Wed, 10 Aug 2022 08:39:34 +0900 Subject: [PATCH 23/37] fixed tp1d bert crashed --- .../tensor_parallel/_parallel_1d/_wrapper.py | 64 ++++++++++++++++++- .../tensor_parallel/_parallel_2d/_wrapper.py | 7 +- .../deparallel/test_deparallelize.py | 50 +++++++-------- 3 files changed, 89 insertions(+), 32 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 37a9bdda..17445edf 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -385,10 +385,28 @@ def _slice_head(self, module, reversed): ) module.__class__ = ColLinear1D + def _zero_rank_log(self, txt): + import torch.distributed as dist + if dist.get_rank() == 0: + print(txt) + # 모니터링 생성 대기 + dist.barrier() + @torch.no_grad() def deparallelize(self): - self._deparallelize_linear() + # must deparallelize embedding first than linear + self._zero_rank_log("deparallelize embedding start") self._deparallelize_embedding() + self._zero_rank_log("deparallelize embedding end") + + self._zero_rank_log("deparallelize linear start") + self._deparallelize_linear() + self._zero_rank_log("deparallelize linear end") + + self._zero_rank_log("deparallelize head start") + self._deparallelize_head() + self._zero_rank_log("deparallelize head end") + self._rollback_mp_arguments() def _rollback_mp_arguments(self): @@ -410,12 +428,24 @@ def _deparallelize_embedding(self): def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): - if module.__class__ == ColLinear1D: + if self.tensor_parallel_mapping.is_column_parallel( + self.module, param_name + ): self._gather_column_linear(module) - elif module.__class__ == RowLinear1D: + elif self.tensor_parallel_mapping.is_row_parallel( + self.module, param_name + ): self._gather_row_linear(module) + def _deparallelize_head(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_head( + self.module, param_name + ) and isinstance(module, ColLinear1D): + self._zero_rank_log(f"deparallelize head {param_name}") + self._gather_head(module) + def _gather_embedding(self, module): world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): @@ -498,6 +528,34 @@ def _gather_column_linear(self, module): def _gather_row_linear(self, module): self._gather_linear(module, dim=1) + def _gather_head(self, module: ColLinear1D): + if module.weight is not self.module.get_input_embeddings().weight: + return self._gather_column_linear(module) + elif hasattr(module, "bias") and module.bias is not None: + self._zero_rank_log("before gathering head bias") + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + + b = self._reconstruct_combined_qkv(module.bias, world_size, 1, 0) + + module.bias.data = b[:module.weight.size()[0]] + self._zero_rank_log("after gathering head bias") + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], 
+ parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.reversed + del module.fusion_degree + del module.orig_module + del module.parallel_context + + module.__class__ = nn.Linear + def _reconstruct_combined_qkv(self, tensor, world_size, fusion_degree, dim: int): tensor_list = tensor.chunk(fusion_degree, dim=dim) result_list = [] diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 714366b7..2b1dc6e5 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -555,9 +555,9 @@ def _gather_head(self, module: Linear2D): return self._gather_linear(module) elif hasattr(module, "bias") and module.bias is not None: self._zero_rank_log("before gathering bias") - tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D) - b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + b = gather_1d(self.parallel_context, module.bias.data, summa_dim, 0) module.bias.data = b[:module.weight.size()[0]] self._zero_rank_log("after gathering bias") @@ -573,8 +573,7 @@ def _gather_head(self, module: Linear2D): ) del module.row_rank del module.col_rank - del module.dep_rank - del module.tesseract_dim + del module.summa_dim del module.data_parallel_rank del module.pipeline_parallel_rank del module.tensor_parallel_size diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py index 2d8d5d8c..4578e4d4 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py @@ -65,7 +65,7 @@ def bw(tensors): tp_size = 4 tp_depth = 1 -model_name = "jason9693/soongsil-bert-base" +model_name = "bert-base-uncased" mkwargs = { 'pad_token': '[PAD]' } @@ -85,9 +85,9 @@ def bw(tensors): tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) # tokenizer.pad_token = tokenizer.eos_token -# # 모델 생성 및 병렬화 수행 -# model_no_tp = AutoModelForCausalLM.from_config( -# AutoConfig.from_pretrained(model_name)).cuda() +# 모델 생성 및 병렬화 수행 +model_no_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name)).cuda() model_tp = AutoModelForCausalLM.from_config( AutoConfig.from_pretrained(model_name)) wrapper_tp = TensorParallel(model_tp, parallel_context) @@ -100,7 +100,7 @@ def bw(tensors): # 옵티마이저 생성 optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) -# optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) # 데이터셋 생성 batch_size = 16 @@ -119,17 +119,17 @@ def bw(tensors): # 모니터링 생성 대기 dist.barrier() -# # 로드 -# model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() -# optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) +# 로드 +model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() +optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) dist.barrier() # 학습 시작 for data in dataloader: optimizer_tp.zero_grad() - # optimizer_no_tp.zero_grad() - # optimizer_gathered.zero_grad() + optimizer_no_tp.zero_grad() + optimizer_gathered.zero_grad() inputs = tokenizer( data, @@ -139,33 
+139,33 @@ def bw(tensors): max_length=512, ).to("cuda") - # loss_no_tp, notp_fw_time = \ - # fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_no_tp, notp_fw_time = \ + fw(model_no_tp, **inputs, labels=inputs["input_ids"]) loss_tp, tp_fw_time = \ fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - # loss_gathered, gathered_fw_time = \ - # fw(model_gathered, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = \ + fw(model_gathered, **inputs, labels=inputs["input_ids"]) - # if dist.get_rank() == 0: - # print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") - # wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": loss_gathered}) + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": loss_gathered}) - # _, notp_bw_time = bw(loss_no_tp) + _, notp_bw_time = bw(loss_no_tp) _, tp_bw_time = bw(loss_tp) - # _, gathered_bw_time = bw(loss_gathered) + _, gathered_bw_time = bw(loss_gathered) optimizer_tp.step() - # optimizer_no_tp.step() - # optimizer_gathered.step() + optimizer_no_tp.step() + optimizer_gathered.step() if dist.get_rank() == 0: wandb.log({ "tp.forward.time:": tp_fw_time, "tp.backward.time:": tp_bw_time, - # "notp.forward.time:": notp_fw_time, - # "notp.backward.time:": notp_bw_time, - # "gathered.forward.time:": gathered_fw_time, - # "gathered.backward.time:": gathered_bw_time + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time }) dist.barrier() From 7de1d6f68a297396aa6ceeb8d4ce7d6b981f0ea8 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Wed, 10 Aug 2022 10:32:51 +0900 Subject: [PATCH 24/37] precommit run --- .../_parallel_2d/deparallel/test_deparallelize.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py index 26286378..2995d39d 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py @@ -64,8 +64,9 @@ def bw(tensors): tp_size = 4 tp_depth = 1 -model_name = "jason9693/soongsil-bert-base" +model_name = "gpt2" mkwargs = { + # 'pad_token': '[PAD]' } dataset_name = "squad" From 59efdf034c2dd7cce5feec71995f3b5424877bb1 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Wed, 10 Aug 2022 11:19:13 +0900 Subject: [PATCH 25/37] precommit run --- .gitignore | 2 +- oslo/torch/nn/modules/linear.py | 4 +- .../nn/parallel/tensor_parallel/__init__.py | 6 +- .../parallel/tensor_parallel/_base_wrapper.py | 44 ++++--- .../tensor_parallel/_parallel_1d/_wrapper.py | 62 ++++++---- .../tensor_parallel/_parallel_2d/_wrapper.py | 84 ++++++++----- .../tensor_parallel/_parallel_2p5d/_ops.py | 2 +- .../_parallel_2p5d/_wrapper.py | 116 ++++++++++++------ .../tensor_parallel/tensor_parallel.py | 61 +++++---- oslo/transformers/mapping_utils.py | 27 +++- .../deparallel/test_deparallelize.py | 65 +++++----- .../deparallel/test_load_parallel.py | 46 +++---- .../_parallel_1d/deparallel/test_qkv.py | 30 +++-- .../_parallel_1d/deparallel/test_vocab.py | 8 +- .../_parallel_1d/test_wrapper_1d.py | 32 ++--- .../deparallel/test_deparallelize.py | 61 ++++----- .../deparallel/test_load_parallel.py | 46 +++---- 
.../_parallel_2d/deparallel/test_qkv.py | 37 ++++-- .../deparallel/test_deparallelize.py | 65 +++++----- .../deparallel/test_load_parallel.py | 46 +++---- .../_parallel_2p5d/deparallel/test_qkv.py | 37 ++++-- .../_parallel_2p5d/deparallel/test_vocab.py | 8 +- .../_parallel_2p5d/test_linear_2p5d.py | 6 +- .../_parallel_2p5d/test_wrapper_2p5d.py | 36 +++--- 24 files changed, 553 insertions(+), 378 deletions(-) diff --git a/.gitignore b/.gitignore index 8b1e6911..bfa231e7 100644 --- a/.gitignore +++ b/.gitignore @@ -395,4 +395,4 @@ wandb/ # sample huggingface models **/pytorch_model.bin -**/config.json \ No newline at end of file +**/config.json diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index aa693e03..bc21e52b 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -369,7 +369,7 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] outputs = outputs.contiguous() - + return outputs @@ -512,7 +512,7 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] outputs = outputs.contiguous() - + return outputs diff --git a/oslo/torch/nn/parallel/tensor_parallel/__init__.py b/oslo/torch/nn/parallel/tensor_parallel/__init__.py index ba704bc3..43e5116c 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/__init__.py +++ b/oslo/torch/nn/parallel/tensor_parallel/__init__.py @@ -1,7 +1,7 @@ from oslo.torch.nn.parallel.tensor_parallel.mapping import Column, Row, Update, Head -from oslo.torch.nn.parallel.tensor_parallel.tensor_parallel import ( - TensorParallel +from oslo.torch.nn.parallel.tensor_parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, ) -from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import BaseTensorParallelWrapper __ALL__ = [TensorParallel, Column, Row, Update, Head, BaseTensorParallelWrapper] diff --git a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py index 5d100e9f..ec0fad62 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py @@ -18,7 +18,7 @@ is_oslo_model, allocate_params, unwrap_parallel, - get_parameter_dtype + get_parameter_dtype, ) @@ -32,32 +32,32 @@ class BaseTensorParallelWrapper(ParallelWrapper): """ def __init__( - self, - module: nn.Module, - parallel_context: ParallelContext, - mapping: dict = None, - module_args: dict = None + self, + module: nn.Module, + parallel_context: ParallelContext, + mapping: dict = None, + module_args: dict = None, ): super().__init__() @torch.no_grad() def save_parallelized( - self, - new_module, - save_directory: Union[str, os.PathLike], - save_config: bool = True, - state_dict: Optional[dict] = None, - save_function: Callable = torch.save, - merge_checkpoints: bool = False, - mapping: Optional[dict] = None, - **kwargs, + self, + new_module, + save_directory: Union[str, os.PathLike], + save_config: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + merge_checkpoints: bool = False, + mapping: Optional[dict] = None, + **kwargs, ): logger = getLogger("TensorParallel") PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" if ( - self.parallel_context.get_world_size(ParallelMode.TENSOR) == 
1 - and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1 + self.parallel_context.get_world_size(ParallelMode.TENSOR) == 1 + and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1 ): if dist.get_rank() == 0: self.save_pretrained( @@ -75,7 +75,7 @@ def save_parallelized( module=new_module, parallel_context=self.parallel_context, mapping=mapping, - module_args=self.config + module_args=self.config, ).eval() if state_dict is None: @@ -97,7 +97,9 @@ def save_parallelized( ) else: if save_config: - with open(os.path.join(save_directory, "config.json"), "w") as f: + with open( + os.path.join(save_directory, "config.json"), "w" + ) as f: json.dump(self.config, f) save_function( model_to_save, @@ -138,7 +140,9 @@ def save_parallelized( # Handle the case where some state_dict keys shouldn't be saved if getattr(self, "_keys_to_ignore_on_save", None) is not None: state_dict = { - k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save + k: v + for k, v in state_dict.items() + if k not in self._keys_to_ignore_on_save } # If we save using the predefined names, we can load using `from_pretrained` diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 17445edf..9543468e 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -48,7 +48,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, - module_args: dict = None + module_args: dict = None, ): super().__init__(module, parallel_context) self.module = module @@ -142,13 +142,14 @@ def _parallelize_linear(self): def _parallelize_head(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, param_name ) and isinstance(module, nn.Linear): self._slice_head( module=module, reversed=self.tensor_parallel_mapping.is_reversed( - self.module, param_name, - ) + self.module, + param_name, + ), ) @staticmethod @@ -191,25 +192,27 @@ def _slice_embedding(self, module): hasattr(module_head, "weight") and module_head.weight is module.weight and not isinstance(module_head, nn.Embedding) - and not self.tensor_parallel_mapping.is_head( - self.module, name - ) + and not self.tensor_parallel_mapping.is_head(self.module, name) ): _update_module_arguments( module=module_head, parallel_context=self.parallel_context, - reversed=self.tensor_parallel_mapping.is_reversed(self.module, name), + reversed=self.tensor_parallel_mapping.is_reversed( + self.module, name + ), fusion_degree=1, orig_module=copy.deepcopy(module_head.__class__), out_features=module.weight.size()[0], # in_features=module.weight.size()[1], gather_output=not is_oslo_model(self.module), skip_bias_add=module.skip_bias_add - if hasattr(module, "skip_bias_add") - else False, + if hasattr(module, "skip_bias_add") + else False, ) - if isinstance(module_head, nn.Linear) or isinstance(module_head, nn.Conv1D): + if isinstance(module_head, nn.Linear) or isinstance( + module_head, nn.Conv1D + ): module_head.__class__ = ColLinear1D else: weight_list = module.weight.data.chunk(world_size, dim=1) @@ -387,6 +390,7 @@ def _slice_head(self, module, reversed): def _zero_rank_log(self, txt): import torch.distributed as dist + if dist.get_rank() == 0: print(txt) # 모니터링 생성 대기 @@ -428,20 +432,16 @@ def _deparallelize_embedding(self): def _deparallelize_linear(self): for 
param_name, module in self.module.named_modules(): - if self.tensor_parallel_mapping.is_column_parallel( - self.module, param_name - ): + if self.tensor_parallel_mapping.is_column_parallel(self.module, param_name): self._gather_column_linear(module) - elif self.tensor_parallel_mapping.is_row_parallel( - self.module, param_name - ): + elif self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): self._gather_row_linear(module) def _deparallelize_head(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, param_name ) and isinstance(module, ColLinear1D): self._zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) @@ -450,7 +450,9 @@ def _gather_embedding(self, module): world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): # w = gather_2d(self.parallel_context, module.weight.data, world_size, col_first=True) - tensor_list = [torch.zeros_like(module.weight.data) for _ in range(world_size)] + tensor_list = [ + torch.zeros_like(module.weight.data) for _ in range(world_size) + ] dist.all_gather( tensor_list, module.weight.data.contiguous(), @@ -472,10 +474,12 @@ def _gather_embedding(self, module): parallel_context=None, num_embeddings=module.weight.size()[0], embedding_dim=module.weight.size()[1], - orig_module=None + orig_module=None, ) else: - tensor_list = [torch.zeros_like(module.weight.data) for _ in range(world_size)] + tensor_list = [ + torch.zeros_like(module.weight.data) for _ in range(world_size) + ] dist.all_gather( tensor_list, module.weight.data.contiguous(), @@ -487,7 +491,7 @@ def _gather_embedding(self, module): _update_module_arguments( module=module, parallel_context=None, - embedding_dim = module.weight.size()[1] + embedding_dim=module.weight.size()[1], ) module.__class__ = nn.Embedding @@ -497,13 +501,17 @@ def _gather_linear(self, module, dim=1): world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR) - w = self._reconstruct_combined_qkv(module.weight, world_size, fusion_degree, dim) + w = self._reconstruct_combined_qkv( + module.weight, world_size, fusion_degree, dim + ) if is_reversed: w = w.t() module.weight.data = w if hasattr(module, "bias") and module.bias is not None and dim != 1: - b = self._reconstruct_combined_qkv(module.bias, world_size, fusion_degree, dim) + b = self._reconstruct_combined_qkv( + module.bias, world_size, fusion_degree, dim + ) module.bias.data = b _update_module_arguments( @@ -537,7 +545,7 @@ def _gather_head(self, module: ColLinear1D): b = self._reconstruct_combined_qkv(module.bias, world_size, 1, 0) - module.bias.data = b[:module.weight.size()[0]] + module.bias.data = b[: module.weight.size()[0]] self._zero_rank_log("after gathering head bias") _update_module_arguments( @@ -561,6 +569,8 @@ def _reconstruct_combined_qkv(self, tensor, world_size, fusion_degree, dim: int) result_list = [] for w in tensor_list: w_list = [torch.zeros_like(w) for _ in range(world_size)] - dist.all_gather(w_list, w, self.parallel_context.get_group(ParallelMode.TENSOR_1D)) + dist.all_gather( + w_list, w, self.parallel_context.get_group(ParallelMode.TENSOR_1D) + ) result_list.append(torch.cat(w_list, dim=dim)) return torch.cat(result_list, dim=dim) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 2b1dc6e5..d71d44c3 100644 --- 
a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -20,7 +20,7 @@ split_batch_2d, gather_2d, gather_1d, - gather_1d_twice + gather_1d_twice, ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, @@ -56,8 +56,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, - module_args: dict = None - + module_args: dict = None, ): super().__init__(module, parallel_context, mapping, module_args) self.module = module @@ -413,13 +412,11 @@ def _slice_head(self, module, reversed): pipeline_parallel_size = self.parallel_context.get_world_size( ParallelMode.PIPELINE ) - + if hasattr(module, "bias") and module.bias is not None: if module.bias.dim() >= 1: bias_list = module.bias.data.chunk(summa_dim, dim=0) - bias_list = [ - bias.chunk(summa_dim, dim=0) for bias in bias_list - ] + bias_list = [bias.chunk(summa_dim, dim=0) for bias in bias_list] module.bias.data = bias_list[row_rank][col_rank].contiguous() if hasattr(module.bias, "oslo_parallel"): @@ -430,7 +427,7 @@ def _slice_head(self, module, reversed): ParallelMode.TENSOR_2D_ROW: row_rank, ParallelMode.TENSOR_2D_COL: col_rank, } - + _update_module_arguments( module=module, in_features=module.weight.size()[1], @@ -453,6 +450,7 @@ def _slice_head(self, module, reversed): def _zero_rank_log(self, txt): import torch.distributed as dist + if dist.get_rank() == 0: print(txt) # 모니터링 생성 대기 @@ -479,7 +477,6 @@ def deparallelize(self): self._rollback_mp_arguments() - def _rollback_mp_arguments(self): for module in self.module.modules(): for elem in self.tensor_parallel_mapping.update_attrs(self.module): @@ -498,14 +495,14 @@ def _deparallelize_embedding(self): def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_column_parallel( - self.module, param_name + self.module, param_name ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): self._gather_linear(module) def _deparallelize_head(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, param_name ) and isinstance(module, Linear2D): self._zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) @@ -521,7 +518,9 @@ def _gather_embedding(self, module): w = module.weight.data if module.embedding_dim == module.weight.size()[0]: - w = gather_2d(self.parallel_context, module.weight.data, summa_dim, col_first=True) + w = gather_2d( + self.parallel_context, module.weight.data, summa_dim, col_first=True + ) assert hasattr( self.module, "orig_vocab_size" @@ -537,7 +536,7 @@ def _gather_embedding(self, module): parallel_context=None, num_embeddings=module.weight.size()[0], embedding_dim=module.weight.size()[1], - orig_module=None + orig_module=None, ) else: w = gather_1d_twice(self.parallel_context, module.weight.data, summa_dim, 1) @@ -546,7 +545,7 @@ def _gather_embedding(self, module): _update_module_arguments( module=module, parallel_context=None, - embedding_dim=module.weight.size()[1] + embedding_dim=module.weight.size()[1], ) module.__class__ = nn.Embedding @@ -559,7 +558,7 @@ def _gather_head(self, module: Linear2D): b = gather_1d(self.parallel_context, module.bias.data, summa_dim, 0) - module.bias.data = b[:module.weight.size()[0]] + module.bias.data = b[: module.weight.size()[0]] self._zero_rank_log("after gathering bias") _update_module_arguments( @@ -593,7 
+592,12 @@ def _gather_linear(self, module: Linear2D): summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) - w = gather_2d(self.parallel_context, module.weight.data, summa_dim=summa_dim, col_first=True) + w = gather_2d( + self.parallel_context, + module.weight.data, + summa_dim=summa_dim, + col_first=True, + ) # print(f"w shape: {w.shape}\nweight shape: {module.weight.data.shape}") if fusion_degree > 1: w = self._reconstruct_combined_qkv(w, summa_dim, fusion_degree, False) @@ -603,7 +607,9 @@ def _gather_linear(self, module: Linear2D): if hasattr(module, "bias") and module.bias is not None: # if slice_bias is True and module.bias.dim() >= 1: - b = gather_1d_twice(self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0) + b = gather_1d_twice( + self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0 + ) if fusion_degree > 1: b = self._reconstruct_combined_qkv(b, summa_dim, fusion_degree, True) b = b.view(b.size()[1:]) @@ -637,7 +643,12 @@ def _gather_layernorm(self, module): summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) if hasattr(module, "weight") and module.weight is not None: if module.weight.dim() >= 1: - w = gather_1d_twice(self.parallel_context, module.weight.data, summa_dim=summa_dim, dim=0) + w = gather_1d_twice( + self.parallel_context, + module.weight.data, + summa_dim=summa_dim, + dim=0, + ) module.weight.data = w if hasattr(module.weight, "oslo_parallel"): @@ -645,7 +656,9 @@ def _gather_layernorm(self, module): if hasattr(module, "bias") and module.bias is not None: if module.bias.dim() >= 1: - b = gather_1d_twice(self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0) + b = gather_1d_twice( + self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0 + ) module.bias.data = b if hasattr(module.bias, "oslo_parallel"): @@ -670,16 +683,32 @@ def _gather_layernorm(self, module): def _reconstruct_combined_qkv(tensor, summa_dim, fusion_degree, is_bias=False): last_dim = tensor.size()[-1] if is_bias is False: - reshaped_w = tensor.view(summa_dim*fusion_degree, -1, last_dim) + reshaped_w = tensor.view(summa_dim * fusion_degree, -1, last_dim) # print(f"tensor.size: {tensor.size()}, reshaped_w.size: {reshaped_w.size()}") - recon_w = torch.cat([ - reshaped_w[i * fusion_degree: (i+1) * fusion_degree] - for i in range(summa_dim)], 1).view(-1, last_dim).contiguous() + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(summa_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) else: - reshaped_w = tensor.view(fusion_degree*summa_dim, -1) - recon_w = torch.cat([ - reshaped_w[i * fusion_degree: (i+1) * fusion_degree] - for i in range(summa_dim)], 1).view(-1, last_dim).contiguous() + reshaped_w = tensor.view(fusion_degree * summa_dim, -1) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(summa_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) return recon_w @staticmethod @@ -691,4 +720,3 @@ def _reconstrunct_combined_qkv_bias(tensor, summa_dim, fusion_degree): tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) tensor = [tensor[j] for j in range(summa_dim)] return tensor - diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py index ac471f93..e735b8fd 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py +++ 
b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py @@ -1345,4 +1345,4 @@ def gather_1d(parallel_context, tensor, summa_dim, dim=-1): parallel_context.get_group(parallel_modde), ) tensor = torch.cat(tensor_list, dim=dim) - return tensor \ No newline at end of file + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 83000946..d8e1185e 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -15,7 +15,7 @@ from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( split_batch_2p5d, gather_2d, - gather_1d + gather_1d, ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( @@ -28,7 +28,7 @@ from oslo.torch.nn.parallel.utils import ( _update_module_arguments, is_huggingface_model, - is_oslo_model + is_oslo_model, ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -47,11 +47,11 @@ class _TensorParallel2p5D(BaseTensorParallelWrapper): """ def __init__( - self, - module: nn.Module, - parallel_context: ParallelContext, - mapping: dict = None, - module_args: dict = None + self, + module: nn.Module, + parallel_context: ParallelContext, + mapping: dict = None, + module_args: dict = None, ): super().__init__(module, parallel_context, mapping, module_args) self.module = module @@ -129,7 +129,7 @@ def _parallelize_embedding(self): def _parallalize_linear(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_column_parallel( - self.module, param_name + self.module, param_name ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): self._slice_linear( module=module, @@ -152,7 +152,7 @@ def _parallelize_layernorm(self): def _parallelize_head(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, param_name ) and isinstance(module, nn.Linear): self._slice_head( module=module, @@ -161,7 +161,7 @@ def _parallelize_head(self): ), gathered=self.tensor_parallel_mapping.is_gather_output( self.module, param_name - ) + ), ) @staticmethod @@ -452,9 +452,15 @@ def _slice_head(self, module, reversed, gathered): module.bias.data = bias_list[row_rank].contiguous() if hasattr(module.bias, "oslo_parallel"): - module.bias.oslo_parallel[ParallelMode.TENSOR_2P5D_ROW] = row_rank - module.bias.oslo_parallel[ParallelMode.TENSOR_2P5D_COL] = col_rank - module.weight.oslo_parallel[ParallelMode.TENSOR_2P5D_DEP] = dep_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_2P5D_ROW + ] = row_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_2P5D_COL + ] = col_rank + module.weight.oslo_parallel[ + ParallelMode.TENSOR_2P5D_DEP + ] = dep_rank else: module.bias.oslo_parallel = { ParallelMode.TENSOR_2P5D_ROW: row_rank, @@ -485,6 +491,7 @@ def _slice_head(self, module, reversed, gathered): def _zero_rank_log(self, txt): import torch.distributed as dist + if dist.get_rank() == 0: print(txt) # 모니터링 생성 대기 @@ -493,6 +500,7 @@ def _zero_rank_log(self, txt): def _pdb_set_trace(self): import pdb import torch.distributed as dist + if dist.get_rank() == 0: pdb.set_trace() # 모니터링 생성 대기 @@ -539,14 +547,14 @@ def _deparallelize_embedding(self): def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_column_parallel( - self.module, param_name + 
self.module, param_name ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): self._gather_linear(module) def _deparallelize_head(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, param_name ) and isinstance(module, Linear2p5D): self._zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) @@ -557,13 +565,17 @@ def _deparallelize_layernorm(self): self._gather_layernorm(module) def _gather_embedding(self, module): - tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) - if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): - w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim, col_first=True) + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + w = gather_2d( + self.parallel_context, module.weight.data, tesseract_dim, col_first=True + ) assert hasattr( self.module, "orig_vocab_size" - ), f"wrapper's vocab embedding module must have attribute 'orig_vocab_size'." + ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." orig_vocab_size = self.module.orig_vocab_size module.weight.data = w[:orig_vocab_size, :].contiguous() @@ -574,7 +586,7 @@ def _gather_embedding(self, module): parallel_context=None, num_embeddings=module.weight.size()[0], embedding_dim=module.weight.size()[1], - orig_module=None + orig_module=None, ) else: w = gather_1d(self.parallel_context, module.weight, tesseract_dim, 1) @@ -584,7 +596,7 @@ def _gather_embedding(self, module): _update_module_arguments( module=module, parallel_context=None, - embedding_dim=module.weight.size()[1] + embedding_dim=module.weight.size()[1], ) module.__class__ = nn.Embedding @@ -593,14 +605,15 @@ def _gather_head(self, module: Linear2p5D): return self._gather_linear(module) elif hasattr(module, "bias") and module.bias is not None: self._zero_rank_log("before gathering bias") - tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) - module.bias.data = b[:module.weight.size()[0]] + module.bias.data = b[: module.weight.size()[0]] self._zero_rank_log("after gathering bias") - _update_module_arguments( module=module, in_features=module.weight.size()[1], @@ -631,9 +644,16 @@ def _gather_linear(self, module: Linear2p5D): fusion_degree = module.fusion_degree # slice_bias = module.slice_bias - tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) - w = gather_2d(self.parallel_context, module.weight.data, tesseract_dim=tesseract_dim, col_first=True) + w = gather_2d( + self.parallel_context, + module.weight.data, + tesseract_dim=tesseract_dim, + col_first=True, + ) if fusion_degree > 1: w = self._reconstruct_combined_qkv(w, tesseract_dim, fusion_degree, False) if is_reversed: @@ -643,7 +663,9 @@ def _gather_linear(self, module: Linear2p5D): if hasattr(module, "bias") and module.bias is not None: b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) if fusion_degree > 1: - b = self._reconstruct_combined_qkv(b, tesseract_dim, fusion_degree, True) + b = self._reconstruct_combined_qkv( 
+ b, tesseract_dim, fusion_degree, True + ) b = b.view(b.size()[1:]) module.bias.data = b @@ -673,10 +695,14 @@ def _gather_linear(self, module: Linear2p5D): module.__class__ = nn.Linear def _gather_layernorm(self, module): - tesseract_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) if hasattr(module, "weight") and module.weight is not None: if module.weight.dim() >= 1: - w = gather_1d(self.parallel_context, module.weight.data, tesseract_dim, 0) + w = gather_1d( + self.parallel_context, module.weight.data, tesseract_dim, 0 + ) module.weight.data = w if hasattr(module.weight, "oslo_parallel"): @@ -710,15 +736,31 @@ def _gather_layernorm(self, module): def _reconstruct_combined_qkv(tensor, tesseract_dim, fusion_degree, is_bias=False): last_dim = tensor.size()[-1] if is_bias is False: - reshaped_w = tensor.view(tesseract_dim*fusion_degree, -1, last_dim) - recon_w = torch.cat([ - reshaped_w[i * fusion_degree: (i+1) * fusion_degree] - for i in range(tesseract_dim)], 1).view(-1, last_dim).contiguous() + reshaped_w = tensor.view(tesseract_dim * fusion_degree, -1, last_dim) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(tesseract_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) else: - reshaped_w = tensor.view(fusion_degree*tesseract_dim, -1) - recon_w = torch.cat([ - reshaped_w[i * fusion_degree: (i+1) * fusion_degree] - for i in range(tesseract_dim)], 1).view(-1, last_dim).contiguous() + reshaped_w = tensor.view(fusion_degree * tesseract_dim, -1) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(tesseract_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) return recon_w @staticmethod diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index 56d356d6..f4cb7e13 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -36,7 +36,7 @@ get_parallel_context, is_huggingface_model, allocate_params, - get_parameter_dtype + get_parameter_dtype, ) @@ -87,11 +87,13 @@ def __init__( module: nn.Module, parallel_context: Optional[ParallelContext] = None, mapping: dict = None, - module_args: dict = None + module_args: dict = None, ): super().__init__() if is_huggingface_model(module): - assert module_args is None, "module_args must not be provided in huggingface module." + assert ( + module_args is None + ), "module_args must not be provided in huggingface module." else: assert isinstance(module_args, dict), "module_args must be a dict." 
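A minimal, self-contained sketch of the view/cat reordering performed by _reconstruct_combined_qkv in the 2p5D wrapper above may help when reading the diff. The toy sizes (tesseract_dim=2, fusion_degree=3, a 2-row block of width 4) and the rank-major q/k/v labelling are illustrative assumptions, not values taken from the patch:

    import torch

    # Illustrative assumption: after gather_2d the fused QKV rows arrive
    # rank-major, i.e. [q_r0, k_r0, v_r0, q_r1, k_r1, v_r1], while a plain
    # nn.Linear stores them fusion-major, i.e. [q_r0, q_r1, k_r0, k_r1, v_r0, v_r1].
    tesseract_dim, fusion_degree, rows_per_block, last_dim = 2, 3, 2, 4
    gathered = torch.arange(
        tesseract_dim * fusion_degree * rows_per_block * last_dim,
        dtype=torch.float32,
    ).view(tesseract_dim * fusion_degree * rows_per_block, last_dim)

    # Same view/cat trick as _reconstruct_combined_qkv for the non-bias case.
    reshaped = gathered.view(tesseract_dim * fusion_degree, -1, last_dim)
    reordered = (
        torch.cat(
            [
                reshaped[i * fusion_degree : (i + 1) * fusion_degree]
                for i in range(tesseract_dim)
            ],
            dim=1,
        )
        .view(-1, last_dim)
        .contiguous()
    )
    # reordered now lists every q row first, then every k row, then every v row.

Under these assumptions the reordered tensor matches the layout a plain nn.Linear weight uses after deparallelization, which is what the test_qkv.py scripts later in this series check by comparing the reconstructed weight against the original.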
@@ -101,9 +103,13 @@ def __init__( module = self._resize_head_bias_size(module, self.parallel_context, mapping) if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: - self.module = _TensorParallel1D(module, self.parallel_context, mapping, module_args) + self.module = _TensorParallel1D( + module, self.parallel_context, mapping, module_args + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D: - self.module = _TensorParallel2D(module, self.parallel_context, mapping, module_args) + self.module = _TensorParallel2D( + module, self.parallel_context, mapping, module_args + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D: self.module = _TensorParallel2p5D(module, self.parallel_context, mapping) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_3D: @@ -175,10 +181,12 @@ def _resize_num_classes(model, parallel_context, mapping): unwrapped_model.get_input_embeddings().num_embeddings ) - assert hasattr(unwrapped_model.get_input_embeddings(), "orig_num_classes"), ( - "call _resize_vocab before _resize_num_classes" + assert hasattr( + unwrapped_model.get_input_embeddings(), "orig_num_classes" + ), "call _resize_vocab before _resize_num_classes" + out_features = ( + unwrapped_model.get_input_embeddings().orig_num_classes ) - out_features = unwrapped_model.get_input_embeddings().orig_num_classes setattr(module, "orig_num_classes", out_features) setattr( unwrapped_model, @@ -244,9 +252,12 @@ def _resize_head_bias_size(model, parallel_context, mapping): divisible_by = get_divisible_by(parallel_context) for param_name, module in unwrapped_model.named_modules(): - if tensor_parallel_mapping.is_head(unwrapped_model, param_name - ) and unwrapped_model.get_input_embeddings().weight is module.weight \ - and hasattr(module, "bias") and module.bias is not None: + if ( + tensor_parallel_mapping.is_head(unwrapped_model, param_name) + and unwrapped_model.get_input_embeddings().weight is module.weight + and hasattr(module, "bias") + and module.bias is not None + ): out_features = module.bias.size()[0] new_out_features = out_features @@ -268,14 +279,14 @@ def _resize_head_bias_size(model, parallel_context, mapping): @torch.no_grad() def save_parallelized( - self, - save_directory: Union[str, os.PathLike], - save_config: bool = True, - state_dict: Optional[dict] = None, - save_function: Callable = torch.save, - merge_checkpoints: bool = False, - mapping: Optional[dict] = None, - **kwargs, + self, + save_directory: Union[str, os.PathLike], + save_config: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + merge_checkpoints: bool = False, + mapping: Optional[dict] = None, + **kwargs, ): unwrapped_model = unwrap_parallel(self.module.module) if is_huggingface_model(unwrapped_model): @@ -284,8 +295,12 @@ def save_parallelized( new_module = unwrapped_model.__class__(**self.module.config) new_module = self._resize_vocab_size(new_module, self.parallel_context) - new_module = self._resize_num_classes(new_module, self.parallel_context, mapping) - new_module = self._resize_head_bias_size(new_module, self.parallel_context, mapping) + new_module = self._resize_num_classes( + new_module, self.parallel_context, mapping + ) + new_module = self._resize_head_bias_size( + new_module, self.parallel_context, mapping + ) new_module = self.module.save_parallelized( new_module, @@ -303,9 +318,7 @@ def save_parallelized( @staticmethod def get_module_args(module): state_dict = module.state_dict() - return { - 
key: value.shape for key, value in state_dict.items() - } + return {key: value.shape for key, value in state_dict.items()} def from_parallelized(self, path): return self.module.from_parallelized(path) diff --git a/oslo/transformers/mapping_utils.py b/oslo/transformers/mapping_utils.py index 98f0085f..c8fd3973 100644 --- a/oslo/transformers/mapping_utils.py +++ b/oslo/transformers/mapping_utils.py @@ -72,7 +72,7 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): "sop_classifier.classifier", "classifier", "qa_outputs", - gather_output=True + gather_output=True, ), ], "Bart": [ @@ -80,7 +80,12 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): Column("classification_head.dense", gather_output=True), Row("out_proj", "fc2"), Update("embed_dim", "num_heads"), - Head("lm_head", "classification_head.out_proj", "qa_outputs", gather_output=True), + Head( + "lm_head", + "classification_head.out_proj", + "qa_outputs", + gather_output=True, + ), ], "Bert": [ Column("query", "key", "value", "intermediate.dense"), @@ -88,7 +93,13 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): Row("output.dense"), Update("num_attention_heads", "all_head_size"), Head("transform.dense", gather_output=False), - Head("decoder", "seq_relationship", "classifier", "qa_outputs", gather_output=True), + Head( + "decoder", + "seq_relationship", + "classifier", + "qa_outputs", + gather_output=True, + ), ], "Blenderbot": [ Column("q_proj", "k_proj", "v_proj", "fc1"), @@ -145,7 +156,7 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): "classifier", "qa_outputs", "summary", - gather_output=True + gather_output=True, ), ], "Roberta": [ @@ -158,8 +169,12 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): Row("output.dense"), Update("num_attention_heads", "all_head_size"), Head("lm_head.dense"), - Head("lm_head.decoder", "classifier.out_proj", "classifier", "qa_outputs", - gather_output=True + Head( + "lm_head.decoder", + "classifier.out_proj", + "classifier", + "qa_outputs", + gather_output=True, ), ], } diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py index 4578e4d4..da30d651 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py @@ -5,7 +5,13 @@ import torch from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, AutoModelForCausalLM, AutoConfig +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, + AutoConfig, +) from transformers import AutoModelForMaskedLM from oslo.torch.nn.parallel.tensor_parallel import TensorParallel @@ -23,13 +29,13 @@ def seed_all(seed: int = 1930): print("Using Seed Number {}".format(seed)) os.environ["PYTHONHASHSEED"] = str( - seed) # set PYTHONHASHSEED env var at fixed value + seed + ) # set PYTHONHASHSEED env var at fixed value torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) np.random.seed(seed) # for numpy pseudo-random generator - random.seed( - seed) # set fixed value for python built-in pseudo-random generator + random.seed(seed) # set fixed value for python built-in pseudo-random generator 
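    # The three cuDNN flags below favor reproducibility over speed:
    # deterministic=True forces deterministic kernels, benchmark=False
    # disables the autotuner, and enabled=False turns cuDNN off entirely,
    # so the TP / no-TP loss comparison in this test is repeatable.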
torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = False @@ -39,7 +45,8 @@ def seed_worker(_worker_id): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) - + + seed_all(seed=1994) @@ -48,7 +55,8 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - return result, end-start + return result, end - start + return wrapper @@ -66,9 +74,7 @@ def bw(tensors): tp_depth = 1 model_name = "bert-base-uncased" -mkwargs = { - 'pad_token': '[PAD]' -} +mkwargs = {"pad_token": "[PAD]"} dataset_name = "squad" @@ -87,9 +93,9 @@ def bw(tensors): # 모델 생성 및 병렬화 수행 model_no_tp = AutoModelForCausalLM.from_config( - AutoConfig.from_pretrained(model_name)).cuda() -model_tp = AutoModelForCausalLM.from_config( - AutoConfig.from_pretrained(model_name)) + AutoConfig.from_pretrained(model_name) +).cuda() +model_tp = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)) wrapper_tp = TensorParallel(model_tp, parallel_context) allocate_params(wrapper_tp, parallel_context) # allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 @@ -114,7 +120,7 @@ def bw(tensors): cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', merge_checkpoints=True) +wrapper_tp.save_parallelized("test/", merge_checkpoints=True) # 모니터링 생성 대기 dist.barrier() @@ -139,12 +145,11 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = \ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) - loss_tp, tp_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - loss_gathered, gathered_fw_time = \ - fw(model_gathered, **inputs, labels=inputs["input_ids"]) + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = fw( + model_gathered, **inputs, labels=inputs["input_ids"] + ) if dist.get_rank() == 0: print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") @@ -159,17 +164,15 @@ def bw(tensors): optimizer_gathered.step() if dist.get_rank() == 0: - wandb.log({ - "tp.forward.time:": tp_fw_time, - "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time, - "gathered.forward.time:": gathered_fw_time, - "gathered.backward.time:": gathered_bw_time - }) + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time, + } + ) dist.barrier() - - - - diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py index a0f5b3b7..fdb0b32b 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py @@ -16,7 +16,8 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - return result, end-start + return result, end - start + return wrapper @@ -34,8 +35,7 @@ def bw(tensors): tp_depth = 1 model_name = "gpt2" -mkwargs = { -} +mkwargs = {} dataset_name = "squad" # parallel context 생성 @@ -78,18 +78,17 @@ def bw(tensors): cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', 
merge_checkpoints=False) +wrapper_tp.save_parallelized("test/", merge_checkpoints=False) # 모니터링 생성 대기 dist.barrier() # 로드 model_reparallel = TensorParallel( - GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), - parallel_context + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), parallel_context ) allocate_params(model_reparallel, parallel_context) -model_reparallel.from_parallelized('test/') +model_reparallel.from_parallelized("test/") optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) dist.barrier() @@ -108,12 +107,11 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = \ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) - loss_tp, tp_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - loss_reparallel, reparallel_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = fw( + wrapper_tp, **inputs, labels=inputs["input_ids"] + ) if dist.get_rank() == 0: print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") @@ -128,13 +126,15 @@ def bw(tensors): optimizer_reparallel.step() if dist.get_rank() == 0: - wandb.log({ - "tp.forward.time:": tp_fw_time, - "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time, - "reparallel.forward.time:": reparallel_fw_time, - "reparallel.backward.time:": reparallel_bw_time - }) - -dist.barrier() \ No newline at end of file + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py index 003d3d5e..14033085 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py @@ -1,6 +1,8 @@ import torch.nn as nn -from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._wrapper import _TensorParallel1D +from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._wrapper import ( + _TensorParallel1D, +) from oslo.torch.nn import ColumnParallelLinear, RowParallelLinear from oslo.torch.distributed import ParallelContext, ParallelMode from copy import deepcopy @@ -22,22 +24,28 @@ rank = parallel_context.get_local_rank(ParallelMode.TENSOR_1D) fusion_degree = 3 -linear = nn.Linear(4, fusion_degree*4).cuda() +linear = nn.Linear(4, fusion_degree * 4).cuda() w = deepcopy(linear.weight.data) b = deepcopy(linear.bias.data) -dim= 1 +dim = 1 weight_list = w.t().chunk(fusion_degree * world_size, dim=dim) bias_list = b.chunk(fusion_degree * world_size, dim=0) # [t][f*t] -weight_list = _TensorParallel1D._deconstruct_combined_qkv(weight_list, world_size, fusion_degree, dim) +weight_list = _TensorParallel1D._deconstruct_combined_qkv( + weight_list, world_size, fusion_degree, dim +) chunked_w = weight_list[rank].contiguous() -bias_list = _TensorParallel1D._deconstruct_combined_qkv(bias_list, world_size, fusion_degree, 0) +bias_list = _TensorParallel1D._deconstruct_combined_qkv( + bias_list, world_size, fusion_degree, 0 +) chunked_b = 
bias_list[rank].contiguous() -linear_1d = RowParallelLinear(4, fusion_degree * 4, parallel_context=parallel_context, bias=True) +linear_1d = RowParallelLinear( + 4, fusion_degree * 4, parallel_context=parallel_context, bias=True +) if parallel_context.get_global_rank() == 0: print(chunked_w.size()) print(linear_1d.weight.data.size()) @@ -49,12 +57,16 @@ # reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) # recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() print(recon_chunked_w.shape) -recon_w = _TensorParallel1D._reconstruct_combined_qkv(recon_chunked_w, world_size, fusion_degree, dim) -recon_b = _TensorParallel1D._reconstruct_combined_qkv(recon_chunked_b, world_size, fusion_degree, 0) +recon_w = _TensorParallel1D._reconstruct_combined_qkv( + recon_chunked_w, world_size, fusion_degree, dim +) +recon_b = _TensorParallel1D._reconstruct_combined_qkv( + recon_chunked_b, world_size, fusion_degree, 0 +) if parallel_context.get_global_rank() == 0: print(f"original w: \n{w}\n") print(f"reconstruct w: \n{recon_w}\n") print(f"original b: \n{b}\n") - print(f"reconstruct b: \n{recon_b}\n") \ No newline at end of file + print(f"reconstruct b: \n{recon_b}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py index 3edfde0c..1db376eb 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py @@ -5,7 +5,11 @@ from oslo.torch.nn import VocabParallelEmbedding2p5D from oslo.torch.nn.parallel import utils -from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import split_batch_2d, split_2d, gather_2d +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( + split_batch_2d, + split_2d, + gather_2d, +) from copy import deepcopy @@ -82,4 +86,4 @@ # sse = torch.sum((out - pout) ** 2).item() # sse_update = torch.sum((out_update - pout_update) ** 2).item() # print(f"output sse: \n{sse}\n") -# print(f"next output sse: \n{sse_update}\n") \ No newline at end of file +# print(f"next output sse: \n{sse_update}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py index 0998a199..028fa255 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py @@ -1,4 +1,3 @@ -import time import wandb import torch import torch.distributed as dist @@ -9,7 +8,6 @@ from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params from oslo.torch.distributed import ParallelContext, ParallelMode - import time @@ -18,7 +16,8 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - return result, end-start + return result, end - start + return wrapper @@ -31,6 +30,7 @@ def fw(func, *args, **kwargs): def bw(tensors): return tensors.backward() + tp_size = 4 # parallel context 생성 @@ -42,8 +42,7 @@ def bw(tensors): ) model_name = "gpt2" -mkwargs = { -} +mkwargs = {} dataset_name = "squad" # 토크나이저 생성 @@ -92,22 +91,23 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = \ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) - loss_tp, tp_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + 
loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) _, notp_bw_time = bw(loss_no_tp) _, tp_bw_time = bw(loss_tp) if dist.get_rank() == 0: - wandb.log({ - "tp_loss": loss_tp, - "notp_loss": loss_no_tp, - "tp.forward.time:": tp_fw_time, - "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time}) + wandb.log( + { + "tp_loss": loss_tp, + "notp_loss": loss_no_tp, + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + } + ) # # loss_tp = wrapper_tp(**inputs, labels=inputs["input_ids"]).loss # loss_no_tp = model_no_tp(**inputs, labels=inputs["input_ids"]).loss diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py index 2995d39d..f62e96d2 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py @@ -5,7 +5,13 @@ import torch from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, AutoModelForCausalLM, AutoConfig +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, + AutoConfig, +) from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params @@ -22,13 +28,13 @@ def seed_all(seed: int = 1930): print("Using Seed Number {}".format(seed)) os.environ["PYTHONHASHSEED"] = str( - seed) # set PYTHONHASHSEED env var at fixed value + seed + ) # set PYTHONHASHSEED env var at fixed value torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) np.random.seed(seed) # for numpy pseudo-random generator - random.seed( - seed) # set fixed value for python built-in pseudo-random generator + random.seed(seed) # set fixed value for python built-in pseudo-random generator torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = False @@ -38,7 +44,8 @@ def seed_worker(_worker_id): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) - + + seed_all(seed=1994) @@ -47,7 +54,8 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - return result, end-start + return result, end - start + return wrapper @@ -85,9 +93,9 @@ def bw(tensors): # 모델 생성 및 병렬화 수행 model_no_tp = AutoModelForCausalLM.from_config( - AutoConfig.from_pretrained(model_name)).cuda() -model_tp = AutoModelForCausalLM.from_config( - AutoConfig.from_pretrained(model_name)) + AutoConfig.from_pretrained(model_name) +).cuda() +model_tp = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)) wrapper_tp = TensorParallel(model_tp, parallel_context) allocate_params(wrapper_tp, parallel_context) # allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 @@ -112,7 +120,7 @@ def bw(tensors): cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', merge_checkpoints=True) +wrapper_tp.save_parallelized("test/", merge_checkpoints=True) # 모니터링 생성 대기 dist.barrier() @@ -137,12 +145,11 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = 
\ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) - loss_tp, tp_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - loss_gathered, gathered_fw_time = \ - fw(model_gathered, **inputs, labels=inputs["input_ids"]) + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = fw( + model_gathered, **inputs, labels=inputs["input_ids"] + ) if dist.get_rank() == 0: print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") @@ -157,17 +164,15 @@ def bw(tensors): optimizer_gathered.step() if dist.get_rank() == 0: - wandb.log({ - "tp.forward.time:": tp_fw_time, - "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time, - "gathered.forward.time:": gathered_fw_time, - "gathered.backward.time:": gathered_bw_time - }) + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time, + } + ) dist.barrier() - - - - diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py index d0a12ab3..4a145cf8 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py @@ -16,7 +16,8 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - return result, end-start + return result, end - start + return wrapper @@ -34,8 +35,7 @@ def bw(tensors): tp_depth = 1 model_name = "gpt2" -mkwargs = { -} +mkwargs = {} dataset_name = "squad" # parallel context 생성 @@ -78,18 +78,17 @@ def bw(tensors): cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', merge_checkpoints=False) +wrapper_tp.save_parallelized("test/", merge_checkpoints=False) # 모니터링 생성 대기 dist.barrier() # 로드 model_reparallel = TensorParallel( - GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), - parallel_context + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), parallel_context ) allocate_params(model_reparallel, parallel_context) -model_reparallel.from_parallelized('test/') +model_reparallel.from_parallelized("test/") optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) dist.barrier() @@ -108,12 +107,11 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = \ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) - loss_tp, tp_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - loss_reparallel, reparallel_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = fw( + wrapper_tp, **inputs, labels=inputs["input_ids"] + ) if dist.get_rank() == 0: print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") @@ -128,13 +126,15 @@ def bw(tensors): optimizer_reparallel.step() if dist.get_rank() == 0: - wandb.log({ - "tp.forward.time:": tp_fw_time, - "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time, - 
"reparallel.forward.time:": reparallel_fw_time, - "reparallel.backward.time:": reparallel_bw_time - }) - -dist.barrier() \ No newline at end of file + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py index 6cb09fd8..3a7f021a 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py @@ -1,10 +1,15 @@ import torch.nn as nn -from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._wrapper import _TensorParallel2p5D -from oslo.torch.nn import Linear2p5D +from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._wrapper import ( + _TensorParallel2D, +) +from oslo.torch.nn import Linear2D from oslo.torch.distributed import ParallelContext, ParallelMode from copy import deepcopy -from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import gather_1d, gather_2d +from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._ops import ( + gather_1d, + gather_2d, +) tp_size = 4 tp_depth = 1 @@ -23,23 +28,25 @@ summa_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) fusion_degree = 3 -linear = nn.Linear(4, fusion_degree*4).cuda() +linear = nn.Linear(4, fusion_degree * 4).cuda() w = deepcopy(linear.weight.data) b = deepcopy(linear.bias.data) weight_list = w.chunk(summa_dim, dim=1) -weight_list = [ - weight.chunk(summa_dim * fusion_degree, dim=0) for weight in weight_list -] +weight_list = [weight.chunk(summa_dim * fusion_degree, dim=0) for weight in weight_list] bias_list = b.chunk(summa_dim * fusion_degree, dim=0) # [t][f*t] -weight_list = _TensorParallel2p5D._deconstruct_combined_qkv(weight_list, tesseract_dim, fusion_degree, False) -bias_list = _TensorParallel2p5D._deconstruct_combined_qkv(bias_list, tesseract_dim, fusion_degree, True) +weight_list = _TensorParallel2D._deconstruct_combined_qkv( + weight_list, summa_dim, fusion_degree +) +bias_list = _TensorParallel2D._deconstruct_combined_qkv( + bias_list, summa_dim, fusion_degree +) chunked_w = weight_list[row_rank][col_rank] chunked_b = bias_list[row_rank] -linear_2d = Linear2p5D(4, fusion_degree*4, parallel_context=parallel_context, bias=True) +linear_2d = Linear2D(4, fusion_degree * 4, parallel_context=parallel_context, bias=True) if parallel_context.get_global_rank() == 0: print(chunked_w.size()) print(linear_2d.weight.data.size()) @@ -51,12 +58,16 @@ # reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) # recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() -recon_w = _TensorParallel2p5D._reconstruct_combined_qkv(recon_chunked_w, summa_dim, fusion_degree, False) -recon_b = _TensorParallel2p5D._reconstruct_combined_qkv(recon_chunked_b, summa_dim, fusion_degree, True) +recon_w = _TensorParallel2D._reconstruct_combined_qkv( + recon_chunked_w, summa_dim, fusion_degree, False +) +recon_b = _TensorParallel2D._reconstruct_combined_qkv( + recon_chunked_b, summa_dim, fusion_degree, True +) if parallel_context.get_global_rank() == 0: print(f"original w: \n{w}\n") print(f"reconstruct w: \n{recon_w}\n") print(f"original b: \n{b}\n") - print(f"reconstruct b: \n{recon_b}\n") \ No newline at 
end of file + print(f"reconstruct b: \n{recon_b}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py index 56b7a483..bf6fde34 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py @@ -5,7 +5,13 @@ import torch from torch.optim import Adam from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, AutoModelForCausalLM, AutoConfig +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, + AutoConfig, +) from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params @@ -22,13 +28,13 @@ def seed_all(seed: int = 1930): print("Using Seed Number {}".format(seed)) os.environ["PYTHONHASHSEED"] = str( - seed) # set PYTHONHASHSEED env var at fixed value + seed + ) # set PYTHONHASHSEED env var at fixed value torch.manual_seed(seed) torch.cuda.manual_seed_all(seed) torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) np.random.seed(seed) # for numpy pseudo-random generator - random.seed( - seed) # set fixed value for python built-in pseudo-random generator + random.seed(seed) # set fixed value for python built-in pseudo-random generator torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False torch.backends.cudnn.enabled = False @@ -38,7 +44,8 @@ def seed_worker(_worker_id): worker_seed = torch.initial_seed() % 2**32 np.random.seed(worker_seed) random.seed(worker_seed) - + + seed_all(seed=1994) @@ -47,7 +54,8 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - return result, end-start + return result, end - start + return wrapper @@ -65,9 +73,7 @@ def bw(tensors): tp_depth = 2 model_name = "bert-base-uncased" -mkwargs = { - 'pad_token': '[PAD]' -} +mkwargs = {"pad_token": "[PAD]"} dataset_name = "squad" # parallel context 생성 @@ -85,9 +91,9 @@ def bw(tensors): # 모델 생성 및 병렬화 수행 model_no_tp = AutoModelForCausalLM.from_config( - AutoConfig.from_pretrained(model_name)).cuda() -model_tp = AutoModelForCausalLM.from_config( - AutoConfig.from_pretrained(model_name)) + AutoConfig.from_pretrained(model_name) +).cuda() +model_tp = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)) wrapper_tp = TensorParallel(model_tp, parallel_context) allocate_params(wrapper_tp, parallel_context) # allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 @@ -112,7 +118,7 @@ def bw(tensors): cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', merge_checkpoints=True) +wrapper_tp.save_parallelized("test/", merge_checkpoints=True) # 모니터링 생성 대기 dist.barrier() @@ -137,12 +143,11 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = \ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) - loss_tp, tp_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - loss_gathered, gathered_fw_time = \ - fw(model_gathered, **inputs, labels=inputs["input_ids"]) + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = fw( + model_gathered, **inputs, labels=inputs["input_ids"] + ) if dist.get_rank() == 0: print(f"TP:{loss_tp}, 
NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") @@ -157,17 +162,15 @@ def bw(tensors): optimizer_gathered.step() if dist.get_rank() == 0: - wandb.log({ - "tp.forward.time:": tp_fw_time, - "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time, - "gathered.forward.time:": gathered_fw_time, - "gathered.backward.time:": gathered_bw_time, - }) + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time, + } + ) dist.barrier() - - - - diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py index 51f9009c..432a0d16 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py @@ -16,7 +16,8 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - return result, end-start + return result, end - start + return wrapper @@ -34,8 +35,7 @@ def bw(tensors): tp_depth = 2 model_name = "gpt2" -mkwargs = { -} +mkwargs = {} dataset_name = "squad" # parallel context 생성 @@ -78,18 +78,17 @@ def bw(tensors): cur = time.time() # 저장 -wrapper_tp.save_parallelized('test/', merge_checkpoints=False) +wrapper_tp.save_parallelized("test/", merge_checkpoints=False) # 모니터링 생성 대기 dist.barrier() # 로드 model_reparallel = TensorParallel( - GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), - parallel_context + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), parallel_context ) allocate_params(model_reparallel, parallel_context) -model_reparallel.from_parallelized('test/') +model_reparallel.from_parallelized("test/") optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) dist.barrier() @@ -108,12 +107,11 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = \ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) - loss_tp, tp_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - loss_reparallel, reparallel_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = fw( + wrapper_tp, **inputs, labels=inputs["input_ids"] + ) if dist.get_rank() == 0: print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") @@ -128,13 +126,15 @@ def bw(tensors): optimizer_reparallel.step() if dist.get_rank() == 0: - wandb.log({ - "tp.forward.time:": tp_fw_time, - "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time, - "reparallel.forward.time:": reparallel_fw_time, - "reparallel.backward.time:": reparallel_bw_time - }) - -dist.barrier() \ No newline at end of file + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py 
b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py index 1d4591e9..903fb7c1 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py @@ -1,10 +1,15 @@ import torch.nn as nn -from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._wrapper import _TensorParallel2p5D +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._wrapper import ( + _TensorParallel2p5D, +) from oslo.torch.nn import Linear2p5D from oslo.torch.distributed import ParallelContext, ParallelMode from copy import deepcopy -from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import gather_1d, gather_2d +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( + gather_1d, + gather_2d, +) tp_size = 8 tp_depth = 2 @@ -23,7 +28,7 @@ tesseract_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) fusion_degree = 3 -linear = nn.Linear(4, fusion_degree*4).cuda() +linear = nn.Linear(4, fusion_degree * 4).cuda() w = deepcopy(linear.weight.data) b = deepcopy(linear.bias.data) @@ -34,29 +39,41 @@ bias_list = b.chunk(tesseract_dim * fusion_degree, dim=0) # [t][f*t] -weight_list = _TensorParallel2p5D._deconstruct_combined_qkv(weight_list, tesseract_dim, fusion_degree, False) -bias_list = _TensorParallel2p5D._deconstruct_combined_qkv(bias_list, tesseract_dim, fusion_degree, True) +weight_list = _TensorParallel2p5D._deconstruct_combined_qkv( + weight_list, tesseract_dim, fusion_degree, False +) +bias_list = _TensorParallel2p5D._deconstruct_combined_qkv( + bias_list, tesseract_dim, fusion_degree, True +) chunked_w = weight_list[row_rank][col_rank] chunked_b = bias_list[row_rank] -linear_2p5d = Linear2p5D(4, fusion_degree*4, parallel_context=parallel_context, bias=True) +linear_2p5d = Linear2p5D( + 4, fusion_degree * 4, parallel_context=parallel_context, bias=True +) if parallel_context.get_global_rank() == 0: print(chunked_w.size()) print(linear_2p5d.weight.data.size()) linear_2p5d.weight.data = chunked_w linear_2p5d.bias.data = chunked_b -recon_chunked_w = gather_2d(parallel_context, linear_2p5d.weight.data, tesseract_dim, True) +recon_chunked_w = gather_2d( + parallel_context, linear_2p5d.weight.data, tesseract_dim, True +) recon_chunked_b = gather_1d(parallel_context, linear_2p5d.bias.data, tesseract_dim, 0) # reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) # recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() -recon_w = _TensorParallel2p5D._reconstruct_combined_qkv(recon_chunked_w, tesseract_dim, fusion_degree, False) -recon_b = _TensorParallel2p5D._reconstruct_combined_qkv(recon_chunked_b, tesseract_dim, fusion_degree, True) +recon_w = _TensorParallel2p5D._reconstruct_combined_qkv( + recon_chunked_w, tesseract_dim, fusion_degree, False +) +recon_b = _TensorParallel2p5D._reconstruct_combined_qkv( + recon_chunked_b, tesseract_dim, fusion_degree, True +) if parallel_context.get_global_rank() == 0: print(f"original w: \n{w}\n") print(f"reconstruct w: \n{recon_w}\n") print(f"original b: \n{b}\n") - print(f"reconstruct b: \n{recon_b}\n") \ No newline at end of file + print(f"reconstruct b: \n{recon_b}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py index 3edfde0c..1db376eb 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py +++ 
b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py @@ -5,7 +5,11 @@ from oslo.torch.nn import VocabParallelEmbedding2p5D from oslo.torch.nn.parallel import utils -from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import split_batch_2d, split_2d, gather_2d +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( + split_batch_2d, + split_2d, + gather_2d, +) from copy import deepcopy @@ -82,4 +86,4 @@ # sse = torch.sum((out - pout) ** 2).item() # sse_update = torch.sum((out_update - pout_update) ** 2).item() # print(f"output sse: \n{sse}\n") -# print(f"next output sse: \n{sse_update}\n") \ No newline at end of file +# print(f"next output sse: \n{sse_update}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py index 46ea6c93..f8d1404b 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py @@ -1,4 +1,3 @@ -from copy import deepcopy import torch import torch.distributed as dist from oslo.torch.distributed import ParallelContext, ParallelMode @@ -96,8 +95,6 @@ print(f"next output min: \n{minmax_update.min()}\n") - - linear_2p5d = Linear2p5D( input_dim, hidden_dim, gather_output=True, parallel_context=parallel_context ) @@ -124,9 +121,10 @@ print(f"output sse (gather_output=True): \n{sse}\n") print(f"next output sse (gather_output=True): \n{sse_update}\n") import pprint + # top5 = torch.clamp(minmax_update.flatten(), 1e-8) top5 = minmax_update.flatten() top5 = [t.item() for t in top5] - top5 = [top5[i:i+4] for i in range(0, len(top5), 4)] + top5 = [top5[i : i + 4] for i in range(0, len(top5), 4)] pprint.pprint(top5) print(f"next output min: \n{minmax_update.min()}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py index fe0398fc..9ac65560 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py @@ -6,7 +6,12 @@ from datasets import load_dataset from torch.utils.data import DataLoader -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel, AutoModelForCausalLM +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, +) from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params @@ -18,7 +23,8 @@ def wrapper(*args, **kwargs): start = time.time() result = func(*args, **kwargs) end = time.time() - return result, end-start + return result, end - start + return wrapper @@ -37,8 +43,7 @@ def bw(tensors): model_name = "gpt2" model_name = "gpt2" -mkwargs = { -} +mkwargs = {} dataset_name = "squad" # parallel context 생성 @@ -99,21 +104,22 @@ def bw(tensors): max_length=512, ).to("cuda") - loss_no_tp, notp_fw_time = \ - fw(model_no_tp, **inputs, labels=inputs["input_ids"]) - loss_tp, tp_fw_time = \ - fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) _, notp_bw_time = bw(loss_no_tp) _, tp_bw_time = bw(loss_tp) if dist.get_rank() == 0: - wandb.log({ - "tp_loss": loss_tp, - "notp_loss": loss_no_tp, - 
"tp.forward.time:": tp_fw_time, - "tp.backward.time:": tp_bw_time, - "notp.forward.time:": notp_fw_time, - "notp.backward.time:": notp_bw_time}) + wandb.log( + { + "tp_loss": loss_tp, + "notp_loss": loss_no_tp, + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + } + ) dist.barrier() From 8715f7c5c433f89d6693e36e50f2a2cf93d75a06 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Thu, 11 Aug 2022 01:29:50 +0900 Subject: [PATCH 26/37] fixed tp2d --- .../tensor_parallel/_parallel_2d/_wrapper.py | 25 ++++--- .../_parallel_2p5d/_wrapper.py | 8 +-- .../tensor_parallel/tensor_parallel.py | 69 ++++++------------- .../deparallel/test_deparallelize.py | 6 +- 4 files changed, 43 insertions(+), 65 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index d71d44c3..9a30b970 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -164,6 +164,9 @@ def _parallelize_head(self): reversed=self.tensor_parallel_mapping.is_reversed( self.module, param_name ), + gather_output=self.tensor_parallel_mapping.is_gather_output( + self.module, param_name + ), ) @staticmethod @@ -385,7 +388,7 @@ def _slice_layernorm(self, module): ) module.__class__ = LayerNorm2D - def _slice_head(self, module, reversed): + def _slice_head(self, module, reversed, gather_output): if module.weight is not self.module.get_input_embeddings().weight: self._slice_linear( module=module, @@ -395,7 +398,7 @@ def _slice_head(self, module, reversed): ) _update_module_arguments( module=module, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, ) else: row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW) @@ -443,7 +446,7 @@ def _slice_head(self, module, reversed): reversed=reversed, fusion_degree=1, skip_bias_add=False, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear2D @@ -515,12 +518,9 @@ def _deparallelize_layernorm(self): def _gather_embedding(self, module): summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): - w = module.weight.data - - if module.embedding_dim == module.weight.size()[0]: - w = gather_2d( - self.parallel_context, module.weight.data, summa_dim, col_first=True - ) + w = gather_2d( + self.parallel_context, module.weight.data, summa_dim, col_first=True + ) assert hasattr( self.module, "orig_vocab_size" @@ -554,9 +554,11 @@ def _gather_head(self, module: Linear2D): return self._gather_linear(module) elif hasattr(module, "bias") and module.bias is not None: self._zero_rank_log("before gathering bias") - summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D) + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_ROW) - b = gather_1d(self.parallel_context, module.bias.data, summa_dim, 0) + b = gather_1d_twice( + self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0 + ) module.bias.data = b[: module.weight.size()[0]] self._zero_rank_log("after gathering bias") @@ -570,6 +572,7 @@ def _gather_head(self, module: Linear2D): if hasattr(module, "skip_bias_add") else False, ) + del 
module.row_rank del module.col_rank del module.summa_dim diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index d8e1185e..68a70efd 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -159,7 +159,7 @@ def _parallelize_head(self): reversed=self.tensor_parallel_mapping.is_reversed( self.module, param_name ), - gathered=self.tensor_parallel_mapping.is_gather_output( + gather_output=self.tensor_parallel_mapping.is_gather_output( self.module, param_name ), ) @@ -408,7 +408,7 @@ def _slice_layernorm(self, module): module.__class__ = LayerNorm2p5D return module - def _slice_head(self, module, reversed, gathered): + def _slice_head(self, module, reversed, gather_output): if module.weight is not self.module.get_input_embeddings().weight: self._slice_linear( module=module, @@ -418,7 +418,7 @@ def _slice_head(self, module, reversed, gathered): ) _update_module_arguments( module=module, - gather_output=not is_oslo_model(self.module) and gathered, + gather_output=not is_oslo_model(self.module) and gather_output, ) else: row_rank = self.parallel_context.get_local_rank( @@ -484,7 +484,7 @@ def _slice_head(self, module, reversed, gathered): reversed=reversed, fusion_degree=1, skip_bias_add=False, - gather_output=not is_oslo_model(self.module) and gathered, + gather_output=not is_oslo_model(self.module) and gather_output, orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear2p5D diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index f4cb7e13..6d455179 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -100,7 +100,6 @@ def __init__( self.parallel_context = get_parallel_context(module, parallel_context) module = self._resize_vocab_size(module, self.parallel_context) module = self._resize_num_classes(module, self.parallel_context, mapping) - module = self._resize_head_bias_size(module, self.parallel_context, mapping) if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: self.module = _TensorParallel1D( @@ -193,6 +192,25 @@ def _resize_num_classes(model, parallel_context, mapping): f"orig_{param_name.split('.')[-1]}_num_classes", out_features, ) + + if hasattr(module, "bias") and module.bias is not None: + out_features = module.bias.size()[0] + new_out_features = out_features + + while new_out_features % divisible_by != 0: + new_out_features += 1 + + if new_out_features != out_features: + padding = torch.zeros( + new_out_features - out_features, + dtype=module.bias.dtype, + device=module.bias.device, + ) + new_bias = torch.cat( + tensors=[module.bias.data, padding], + dim=0, + ) + module.bias.data = new_bias else: out_features, in_features = module.weight.size() new_out_features = out_features @@ -234,49 +252,6 @@ def _resize_num_classes(model, parallel_context, mapping): ) return model - @staticmethod - def _resize_head_bias_size(model, parallel_context, mapping): - unwrapped_model = unwrap_parallel(model) - divisible_by = get_divisible_by(parallel_context) - - if mapping is None: - if is_huggingface_model(unwrapped_model): - mapping = _TensorParallelMappingForHuggingFace().get_mapping( - unwrapped_model - ) - else: - raise ValueError( - "`mapping` must be input if the model is not huggingface model." 
- ) - tensor_parallel_mapping = TensorParallelMapping(mapping) - divisible_by = get_divisible_by(parallel_context) - - for param_name, module in unwrapped_model.named_modules(): - if ( - tensor_parallel_mapping.is_head(unwrapped_model, param_name) - and unwrapped_model.get_input_embeddings().weight is module.weight - and hasattr(module, "bias") - and module.bias is not None - ): - out_features = module.bias.size()[0] - new_out_features = out_features - - while new_out_features % divisible_by != 0: - new_out_features += 1 - - if new_out_features != out_features: - padding = torch.zeros( - new_out_features - out_features, - dtype=module.bias.dtype, - device=module.bias.device, - ) - new_bias = torch.cat( - tensors=[module.bias.data, padding], - dim=0, - ) - module.bias.data = new_bias - return model - @torch.no_grad() def save_parallelized( self, @@ -298,9 +273,9 @@ def save_parallelized( new_module = self._resize_num_classes( new_module, self.parallel_context, mapping ) - new_module = self._resize_head_bias_size( - new_module, self.parallel_context, mapping - ) + # new_module = self._resize_head_bias_size( + # new_module, self.parallel_context, mapping + # ) new_module = self.module.save_parallelized( new_module, diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py index f62e96d2..a7878003 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py @@ -72,7 +72,7 @@ def bw(tensors): tp_size = 4 tp_depth = 1 -model_name = "gpt2" +model_name = "roberta-base" mkwargs = { # 'pad_token': '[PAD]' } @@ -83,13 +83,13 @@ def bw(tensors): data_parallel_size=1, pipeline_parallel_size=1, tensor_parallel_size=tp_size, - tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_mode=ParallelMode.TENSOR_2D, tensor_parallel_depth=tp_depth, ) # 토크나이저 생성 tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) -tokenizer.pad_token = tokenizer.eos_token +# tokenizer.pad_token = tokenizer.eos_token # 모델 생성 및 병렬화 수행 model_no_tp = AutoModelForCausalLM.from_config( From 85a4ff11816f8319a7344f1e596dd6b3e7592034 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Thu, 11 Aug 2022 02:00:34 +0900 Subject: [PATCH 27/37] precommit run --- oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 9a30b970..ff97ad9c 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -572,7 +572,7 @@ def _gather_head(self, module: Linear2D): if hasattr(module, "skip_bias_add") else False, ) - + del module.row_rank del module.col_rank del module.summa_dim From 40f74c253ceb805cb8d6bb00e332d290a9f97e25 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Thu, 11 Aug 2022 22:16:19 +0900 Subject: [PATCH 28/37] applied code review --- oslo/torch/nn/modules/linear.py | 16 ++++++++++++---- .../tensor_parallel/_parallel_2p5d/_ops.py | 12 ------------ .../parallel/tensor_parallel/tensor_parallel.py | 3 --- 3 files changed, 12 insertions(+), 19 deletions(-) diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index bc21e52b..088a20b8 100644 --- 
a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -167,7 +167,8 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: outputs = all_gather_tensor_1d(outputs, -1, self.parallel_context).clone() if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] - outputs = outputs.contiguous() + if not outputs.is_contiguous(): + outputs = outputs.contiguous() return outputs @@ -227,7 +228,10 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: else: return outputs + self.bias - return outputs.contiguous() + if not ouptuts.is_contiguous(): + outputs = outputs.contiguous() + + return outputs class Linear2D(Linear): @@ -368,7 +372,9 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: ) if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] - outputs = outputs.contiguous() + + if not outputs.is_contiguous(): + outputs = outputs.contiguous() return outputs @@ -511,7 +517,9 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: ).clone() if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] - outputs = outputs.contiguous() + + if not outputs.is_contiguous(): + outputs = outputs.contiguous() return outputs diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py index e735b8fd..44acd879 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py @@ -1112,18 +1112,6 @@ def backward(ctx: Any, output_grad: Tensor): return output_grad, None, None -# def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor: -# r"""All-reduce the input. -# Args: -# input_ (:class:`torch.tensor`): Input tensor. -# parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode tensor used. -# Note: -# The parallel_mode should be concluded in ``ParallelMode``. 
More details about ``ParallelMode`` could be found -# in `parallel_mode `_ -# """ -# return _ReduceTensor2p5D.apply(input_, parallel_mode) - - class _ReduceScatterTensor2p5D(torch.autograd.Function): @staticmethod def forward( diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index 6d455179..d253d68e 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -273,9 +273,6 @@ def save_parallelized( new_module = self._resize_num_classes( new_module, self.parallel_context, mapping ) - # new_module = self._resize_head_bias_size( - # new_module, self.parallel_context, mapping - # ) new_module = self.module.save_parallelized( new_module, From 2e752a51080a0514c69e1e12c93250bc9277e4c1 Mon Sep 17 00:00:00 2001 From: jason9693 Date: Thu, 11 Aug 2022 22:16:30 +0900 Subject: [PATCH 29/37] applied code review --- .../parallel/tensor_parallel/_parallel_2p5d/_wrapper.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 68a70efd..8be6bdf8 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -497,15 +497,6 @@ def _zero_rank_log(self, txt): # 모니터링 생성 대기 dist.barrier() - def _pdb_set_trace(self): - import pdb - import torch.distributed as dist - - if dist.get_rank() == 0: - pdb.set_trace() - # 모니터링 생성 대기 - dist.barrier() - @torch.no_grad() def deparallelize(self): # must deparallelize embedding first than linear From a92041b02014223012ed5da7d658556fbe153d1f Mon Sep 17 00:00:00 2001 From: jason9693 Date: Fri, 12 Aug 2022 00:26:10 +0900 Subject: [PATCH 30/37] applied code review --- .../tensor_parallel/_parallel_1d/_wrapper.py | 27 ++++++---------- .../tensor_parallel/_parallel_2d/_wrapper.py | 23 +++++++------- .../_parallel_2p5d/_wrapper.py | 31 +++++++------------ oslo/torch/nn/parallel/utils.py | 9 ++++++ 4 files changed, 43 insertions(+), 47 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 9543468e..e6f215b6 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -24,6 +24,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, + zero_rank_log ) from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( BaseTensorParallelWrapper, @@ -388,28 +389,20 @@ def _slice_head(self, module, reversed): ) module.__class__ = ColLinear1D - def _zero_rank_log(self, txt): - import torch.distributed as dist - - if dist.get_rank() == 0: - print(txt) - # 모니터링 생성 대기 - dist.barrier() - @torch.no_grad() def deparallelize(self): # must deparallelize embedding first than linear - self._zero_rank_log("deparallelize embedding start") + zero_rank_log("deparallelize embedding start") self._deparallelize_embedding() - self._zero_rank_log("deparallelize embedding end") + zero_rank_log("deparallelize embedding end") - self._zero_rank_log("deparallelize linear start") + zero_rank_log("deparallelize linear start") self._deparallelize_linear() - self._zero_rank_log("deparallelize linear end") + zero_rank_log("deparallelize linear end") - self._zero_rank_log("deparallelize head start") + zero_rank_log("deparallelize 
head start") self._deparallelize_head() - self._zero_rank_log("deparallelize head end") + zero_rank_log("deparallelize head end") self._rollback_mp_arguments() @@ -443,7 +436,7 @@ def _deparallelize_head(self): if self.tensor_parallel_mapping.is_head( self.module, param_name ) and isinstance(module, ColLinear1D): - self._zero_rank_log(f"deparallelize head {param_name}") + zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) def _gather_embedding(self, module): @@ -540,13 +533,13 @@ def _gather_head(self, module: ColLinear1D): if module.weight is not self.module.get_input_embeddings().weight: return self._gather_column_linear(module) elif hasattr(module, "bias") and module.bias is not None: - self._zero_rank_log("before gathering head bias") + zero_rank_log("before gathering head bias") world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) b = self._reconstruct_combined_qkv(module.bias, world_size, 1, 0) module.bias.data = b[: module.weight.size()[0]] - self._zero_rank_log("after gathering head bias") + zero_rank_log("after gathering head bias") _update_module_arguments( module=module, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index ff97ad9c..9b4d7d12 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -30,6 +30,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, + zero_rank_log ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -462,21 +463,21 @@ def _zero_rank_log(self, txt): @torch.no_grad() def deparallelize(self): # must deparallelize embedding first than linear - self._zero_rank_log("deparallelize embedding start") + zero_rank_log("deparallelize embedding start") self._deparallelize_embedding() - self._zero_rank_log("deparallelize embedding end") + zero_rank_log("deparallelize embedding end") - self._zero_rank_log("deparallelize linear start") + zero_rank_log("deparallelize linear start") self._deparallelize_linear() - self._zero_rank_log("deparallelize linear end") + zero_rank_log("deparallelize linear end") - self._zero_rank_log("deparallelize layernorm start") + zero_rank_log("deparallelize layernorm start") self._deparallelize_layernorm() - self._zero_rank_log("deparallelize layernorm end") + zero_rank_log("deparallelize layernorm end") - self._zero_rank_log("deparallelize head start") + zero_rank_log("deparallelize head start") self._deparallelize_head() - self._zero_rank_log("deparallelize head end") + zero_rank_log("deparallelize head end") self._rollback_mp_arguments() @@ -507,7 +508,7 @@ def _deparallelize_head(self): if self.tensor_parallel_mapping.is_head( self.module, param_name ) and isinstance(module, Linear2D): - self._zero_rank_log(f"deparallelize head {param_name}") + zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) def _deparallelize_layernorm(self): @@ -553,7 +554,7 @@ def _gather_head(self, module: Linear2D): if module.weight is not self.module.get_input_embeddings().weight: return self._gather_linear(module) elif hasattr(module, "bias") and module.bias is not None: - self._zero_rank_log("before gathering bias") + zero_rank_log("before gathering bias") summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_ROW) b = gather_1d_twice( @@ -561,7 +562,7 @@ def _gather_head(self, module: Linear2D): ) module.bias.data = b[: 
module.weight.size()[0]] - self._zero_rank_log("after gathering bias") + zero_rank_log("after gathering bias") _update_module_arguments( module=module, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 8be6bdf8..d3001785 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -29,6 +29,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, + zero_rank_log ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -489,32 +490,24 @@ def _slice_head(self, module, reversed, gather_output): ) module.__class__ = Linear2p5D - def _zero_rank_log(self, txt): - import torch.distributed as dist - - if dist.get_rank() == 0: - print(txt) - # 모니터링 생성 대기 - dist.barrier() - @torch.no_grad() def deparallelize(self): # must deparallelize embedding first than linear - self._zero_rank_log("deparallelize embedding start") + zero_rank_log("deparallelize embedding start") self._deparallelize_embedding() - self._zero_rank_log("deparallelize embedding end") + zero_rank_log("deparallelize embedding end") - self._zero_rank_log("deparallelize linear start") + zero_rank_log("deparallelize linear start") self._deparallelize_linear() - self._zero_rank_log("deparallelize linear end") + zero_rank_log("deparallelize linear end") - self._zero_rank_log("deparallelize layernorm start") + zero_rank_log("deparallelize layernorm start") self._deparallelize_layernorm() - self._zero_rank_log("deparallelize layernorm end") + zero_rank_log("deparallelize layernorm end") - self._zero_rank_log("deparallelize head start") + zero_rank_log("deparallelize head start") self._deparallelize_head() - self._zero_rank_log("deparallelize head end") + zero_rank_log("deparallelize head end") self._rollback_mp_arguments() @@ -547,7 +540,7 @@ def _deparallelize_head(self): if self.tensor_parallel_mapping.is_head( self.module, param_name ) and isinstance(module, Linear2p5D): - self._zero_rank_log(f"deparallelize head {param_name}") + zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) def _deparallelize_layernorm(self): @@ -595,7 +588,7 @@ def _gather_head(self, module: Linear2p5D): if module.weight is not self.module.get_input_embeddings().weight: return self._gather_linear(module) elif hasattr(module, "bias") and module.bias is not None: - self._zero_rank_log("before gathering bias") + zero_rank_log("before gathering bias") tesseract_dim = self.parallel_context.get_world_size( ParallelMode.TENSOR_2P5D_COL ) @@ -603,7 +596,7 @@ def _gather_head(self, module: Linear2p5D): b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) module.bias.data = b[: module.weight.size()[0]] - self._zero_rank_log("after gathering bias") + zero_rank_log("after gathering bias") _update_module_arguments( module=module, diff --git a/oslo/torch/nn/parallel/utils.py b/oslo/torch/nn/parallel/utils.py index a74ecfe7..f9bfa9d2 100644 --- a/oslo/torch/nn/parallel/utils.py +++ b/oslo/torch/nn/parallel/utils.py @@ -105,3 +105,12 @@ def get_parallel_context(module: nn.Module, parallel_context: ParallelContext): ) return parallel_context + + +def zero_rank_log(txt): + import torch.distributed as dist + + if dist.get_rank() == 0: + print(txt) + # 모니터링 생성 대기 + dist.barrier() \ No newline at end of file From 43d41887e2c55135da226639516d4756bcebdc24 Mon Sep 17 00:00:00 2001 From: jason960903 Date: Fri, 12 
Aug 2022 00:29:30 +0900 Subject: [PATCH 31/37] precommit run --- oslo/torch/nn/modules/linear.py | 2 +- oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py | 2 +- oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py | 2 +- .../nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py | 2 +- oslo/torch/nn/parallel/utils.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index 088a20b8..c5670f10 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -228,7 +228,7 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: else: return outputs + self.bias - if not ouptuts.is_contiguous(): + if not outputs.is_contiguous(): outputs = outputs.contiguous() return outputs diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index e6f215b6..d4c62811 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -24,7 +24,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - zero_rank_log + zero_rank_log, ) from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( BaseTensorParallelWrapper, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 9b4d7d12..8e4191da 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -30,7 +30,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - zero_rank_log + zero_rank_log, ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index d3001785..5f18a685 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -29,7 +29,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - zero_rank_log + zero_rank_log, ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, diff --git a/oslo/torch/nn/parallel/utils.py b/oslo/torch/nn/parallel/utils.py index f9bfa9d2..f2b2be45 100644 --- a/oslo/torch/nn/parallel/utils.py +++ b/oslo/torch/nn/parallel/utils.py @@ -113,4 +113,4 @@ def zero_rank_log(txt): if dist.get_rank() == 0: print(txt) # 모니터링 생성 대기 - dist.barrier() \ No newline at end of file + dist.barrier() From b34130c70ca210e8e3c84d1e93d13fea7a9a2be1 Mon Sep 17 00:00:00 2001 From: yang-kichang Date: Sat, 13 Aug 2022 15:35:20 +0900 Subject: [PATCH 32/37] Implemented 3d Deparallelization --- .../tensor_parallel/_parallel_3d/_ops.py | 72 +++++ .../tensor_parallel/_parallel_3d/_wrapper.py | 266 ++++++++++++++++++ 2 files changed, 338 insertions(+) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py index 209f1e4c..b03fdb01 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py @@ -1,6 +1,7 @@ from typing import Any, Tuple, Optional import torch +import torch.ditributed as dist from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd 
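# A minimal, single-process sketch of the split/gather round trip that the 3D
# gather helpers added in the next hunk perform across ranks. Everything here is
# illustrative plain PyTorch, not code from the patches: torch.cat over local
# chunks stands in for the dist.all_gather calls over the tensor-parallel
# process groups, and the variable names and shapes are made up for the example.
import torch

cubic_dim = 2
weight = torch.arange(16.0).reshape(4, 4)  # full, unsharded parameter

# "parallelize": chunk rows across one process group and columns across another
row_shards = weight.chunk(cubic_dim, dim=0)
shards = [row.chunk(cubic_dim, dim=-1) for row in row_shards]

# "deparallelize": concatenate back in the reverse order of the split
gathered = torch.cat([torch.cat(cols, dim=-1) for cols in shards], dim=0)
assert torch.equal(gathered, weight)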
@@ -919,3 +920,74 @@ def broadcast_weight_3d_from_diagonal( weight_parallel_mode, output_parallel_mode, ) + + +def gather_3d(parallel_context, tensor, cubic_dim): + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_3D_OUTPUT), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_3D_INPUT), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_3D_WEIGHT), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_2d(parallel_context, tensor, cubic_dim, col_first=True): + if col_first: + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_1d(parallel_context, tensor, cubic_dim, dim=-1): + parallel_modde = ParallelMode.TENSOR_2P5D_ROW + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(parallel_modde), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index bdbe431b..524415b1 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -17,6 +17,9 @@ ) from oslo.torch.nn.parallel.tensor_parallel._parallel_3d._ops import ( split_batch_3d, + gather_3d, + gather_2d, + gather_1d ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, @@ -26,6 +29,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, + zero_rank_log ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -422,3 +426,265 @@ def _slice_head(self, module, reversed): orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear3D + + @torch.no_grad() + def deparallelize(self): + # must deparallelize embedding first than linear + zero_rank_log("deparallelize embedding start") + self._deparallelize_embedding() + zero_rank_log("deparallelize embedding end") + + zero_rank_log("deparallelize linear start") + self._deparallelize_linear() + zero_rank_log("deparallelize linear end") + + zero_rank_log("deparallelize layernorm start") + self._deparallelize_layernorm() + 
zero_rank_log("deparallelize layernorm end") + + zero_rank_log("deparallelize head start") + self._deparallelize_head() + zero_rank_log("deparallelize head end") + + self._rollback_mp_arguments() + + def _rollback_mp_arguments(self): + for module in self.module.modules(): + for elem in self.tensor_parallel_mapping.update_attrs(self.module): + if hasattr(module, elem.name): + cubic_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_3D_INPUT + ) + expanded_arg = getattr(module, elem.name) * cubic_dim + setattr(module, elem.name, expanded_arg) + + def _deparallelize_embedding(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == VocabParallelEmbedding3D: + self._gather_embedding(module) + if module.__class__ == Embedding3D: + self._gather_embedding(module) + + def _deparallelize_linear(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_column_parallel( + self.module, param_name + ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + self._gather_linear(module) + + def _deparallelize_head(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_head( + self.module, param_name + ) and isinstance(module, Linear3D): + zero_rank_log(f"deparallelize head {param_name}") + self._gather_head(module) + + def _deparallelize_layernorm(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == LayerNorm3D: + self._gather_layernorm(module) + + def _gather_embedding(self, module): + cubic_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_3D_INPUT + ) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + w = gather_3d( + self.parallel_context, module.weight.data, cubic_dim + ) + + assert hasattr( + self.module, "orig_vocab_size" + ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." 
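# (Aside: this relies on the vocab padding done earlier by _resize_vocab_size;
#  the embedding rows were padded so the vocab divides evenly across the ranks,
#  so after the all-gather the padding rows are dropped again by slicing the
#  gathered weight back to orig_vocab_size just below.)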
+ orig_vocab_size = self.module.orig_vocab_size + module.weight.data = w[:orig_vocab_size, :].contiguous() + + _update_module_arguments( + module=module, + vocab_start_index=None, + vocab_end_index=None, + parallel_context=None, + num_embeddings=module.weight.size()[0], + embedding_dim=module.weight.size()[1], + orig_module=None, + ) + else: + w = gather_1d(self.parallel_context, module.weight, cubic_dim, -1) + module.weight.data = w + + _update_module_arguments( + module=module, + parallel_context=None, + embedding_dim=module.weight.size()[1], + ) + module.__class__ = nn.Embedding + + def _gather_head(self, module: Linear3D): + if module.weight is not self.module.get_input_embeddings().weight: + return self._gather_linear(module) + elif hasattr(module, "bias") and module.bias is not None: + zero_rank_log("before gathering bias") + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) + + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + + module.bias.data = b[: module.weight.size()[0]] + zero_rank_log("after gathering bias") + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + + def _gather_linear(self, module: Linear3D): + is_reversed = module.reversed + fusion_degree = module.fusion_degree + # slice_bias = module.slice_bias + + cubic_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_3D_INPUT + ) + + w = gather_2d( + self.parallel_context, + module.weight.data, + cubic_dim=cubic_dim, + col_first=True, + ) + if fusion_degree > 1: + w = self._reconstruct_combined_qkv(w, cubic_dim, fusion_degree, False) + if is_reversed: + w = w.t() + module.weight.data = w + + if hasattr(module, "bias") and module.bias is not None: + b = gather_1d(self.parallel_context, module.bias.data, cubic_dim, 0) + if fusion_degree > 1: + b = self._reconstruct_combined_qkv( + b, cubic_dim, fusion_degree, True + ) + b = b.view(b.size()[1:]) + module.bias.data = b + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + + def _gather_layernorm(self, module): + cubic_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_3D_INPUT + ) + if hasattr(module, "weight") and module.weight is not None: + if module.weight.dim() >= 1: + w = gather_1d( + self.parallel_context, module.weight.data, cubic_dim, 0 + ) + module.weight.data = w + + if 
hasattr(module.weight, "oslo_parallel"): + del module.weight.oslo_parallel + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + b = gather_1d(self.parallel_context, module.bias.data, cubic_dim, 0) + module.bias.data = b + + if hasattr(module.bias, "oslo_parallel"): + del module.bias.oslo_parallel + + del module.partitioned_dim + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.orig_module + _update_module_arguments( + module, + normalized_shape=module.weight.size()[0], + ) + module.__class__ = nn.LayerNorm + + @staticmethod + def _reconstruct_combined_qkv(tensor, cubic_dim, fusion_degree, is_bias=False): + last_dim = tensor.size()[-1] + if is_bias is False: + reshaped_w = tensor.view(cubic_dim * fusion_degree, -1, last_dim) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(cubic_dim) + ], + 1, + ).view(-1, last_dim).contiguous() + ) + else: + reshaped_w = tensor.view(fusion_degree * cubic_dim, -1) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(cubic_dim) + ], + 1, + ).view(-1, last_dim).contiguous() + ) + return recon_w + + @staticmethod + def _reconstrunct_combined_qkv_bias(tensor, cubic_dim, fusion_degree): + tensor = [ + [tensor[j * cubic_dim + k] for k in range(cubic_dim)] + for j in range(fusion_degree) + ] + tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) + tensor = [tensor[j] for j in range(cubic_dim)] + return tensor From 8ac13e446b4b94b915f9a195b4b7d6bc60e06fc4 Mon Sep 17 00:00:00 2001 From: yang-kichang Date: Mon, 15 Aug 2022 21:05:02 +0900 Subject: [PATCH 33/37] add 3d deparallez --- oslo/torch/nn/modules/linear.py | 3 + .../tensor_parallel/_parallel_3d/_ops.py | 21 ++- .../tensor_parallel/_parallel_3d/_wrapper.py | 103 +++++----- .../_parallel_3d/deparallel/__init__.py | 3 + .../deparallel/test_deparallelize.py | 176 ++++++++++++++++++ 5 files changed, 247 insertions(+), 59 deletions(-) create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/__init__.py create mode 100644 tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/test_deparallelize.py diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index c5670f10..c7230a52 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -603,4 +603,7 @@ def forward(self, input: Tensor) -> Tensor: ) if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] + + if not outputs.is_contiguous(): + outputs = outputs.contiguous() return outputs diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py index b03fdb01..98f4990a 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py @@ -1,7 +1,7 @@ from typing import Any, Tuple, Optional import torch -import torch.ditributed as dist +import torch.distributed as dist from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd @@ -947,20 +947,20 @@ def gather_3d(parallel_context, tensor, cubic_dim): return tensor -def gather_2d(parallel_context, tensor, cubic_dim, col_first=True): - if col_first: +def gather_2d(parallel_context, tensor, cubic_dim, 
input_first=True): + if input_first: tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] dist.all_gather( tensor_list, tensor.contiguous(), - parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + parallel_context.get_group(ParallelMode.TENSOR_3D_INPUT), ) tensor = torch.cat(tensor_list, dim=0) tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] dist.all_gather( tensor_list, tensor, - parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + parallel_context.get_group(ParallelMode.TENSOR_3D_OUTPUT), ) tensor = torch.cat(tensor_list, dim=-1) else: @@ -968,26 +968,27 @@ def gather_2d(parallel_context, tensor, cubic_dim, col_first=True): dist.all_gather( tensor_list, tensor, - parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + parallel_context.get_group(ParallelMode.TENSOR_3D_OUTPUT), ) tensor = torch.cat(tensor_list, dim=-1) tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] dist.all_gather( tensor_list, tensor.contiguous(), - parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + parallel_context.get_group(ParallelMode.TENSOR_3D_INPUT), ) tensor = torch.cat(tensor_list, dim=0) return tensor -def gather_1d(parallel_context, tensor, cubic_dim, dim=-1): - parallel_modde = ParallelMode.TENSOR_2P5D_ROW +def gather_1d(parallel_context, tensor, cubic_dim, dim=-1, mode=None): + if mode is None: + mode = ParallelMode.TENSOR_3D_OUTPUT if dim == -1 else ParallelMode.TENSOR_3D_INPUT tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] dist.all_gather( tensor_list, tensor.contiguous(), - parallel_context.get_group(parallel_modde), + parallel_context.get_group(mode), ) tensor = torch.cat(tensor_list, dim=dim) return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index 524415b1..288217ff 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -24,8 +24,10 @@ from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, +) from oslo.torch.nn.parallel.utils import ( - ParallelWrapper, _update_module_arguments, is_huggingface_model, is_oslo_model, @@ -38,7 +40,7 @@ from oslo.transformers.constants import BATCH_DIMENSIONS -class _TensorParallel3D(ParallelWrapper): +class _TensorParallel3D(BaseTensorParallelWrapper): """ PyTorch module for 3D tensor parallelism @@ -52,8 +54,9 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): - super().__init__() + super().__init__(module, parallel_context, mapping, module_args) self.module = module self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -66,6 +69,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." 
+ ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() @@ -150,6 +162,9 @@ def _parallelize_head(self): reversed=self.tensor_parallel_mapping.is_reversed( self.module, param_name ), + gather_output=self.tensor_parallel_mapping.is_gather_output( + self.module, param_name + ), ) @staticmethod @@ -396,7 +411,7 @@ def _slice_layernorm(self, module): ) module.__class__ = LayerNorm3D - def _slice_head(self, module, reversed): + def _slice_head(self, module, reversed, gather_output): if module.weight is not self.module.get_input_embeddings().weight: self._slice_linear( module=module, @@ -406,12 +421,41 @@ def _slice_head(self, module, reversed): ) _update_module_arguments( module=module, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, ) else: cubic_dim = self.parallel_context.get_world_size( ParallelMode.TENSOR_3D_INPUT ) + input_rank = self.parallel_context.get_local_rank( + ParallelMode.TENSOR_3D_INPUT + ) + output_rank = self.parallel_context.get_local_rank( + ParallelMode.TENSOR_3D_OUTPUT + ) + weight_rank = self.parallel_context.get_local_rank( + ParallelMode.TENSOR_3D_WEIGHT + ) + # TODO: add bias chunking + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + bias_list = module.bias.chunk(cubic_dim, dim=0) + module.bias.data = bias_list[input_rank].contiguous() + + if hasattr(module.bias, "oslo_parallel"): + module.bias.oslo_parallel[ParallelMode.TENSOR_3D_INPUT] = input_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_3D_OUTPUT + ] = output_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_3D_WEIGHT + ] = weight_rank + else: + module.bias.oslo_parallel = { + ParallelMode.TENSOR_3D_INPUT: input_rank, + ParallelMode.TENSOR_3D_OUTPUT: output_rank, + ParallelMode.TENSOR_3D_WEIGHT: weight_rank, + } _update_module_arguments( module=module, @@ -422,10 +466,11 @@ def _slice_head(self, module, reversed): reversed=reversed, fusion_degree=1, skip_bias_add=False, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear3D + module.__class__ = Linear3D @torch.no_grad() def deparallelize(self): @@ -526,7 +571,7 @@ def _gather_head(self, module: Linear3D): elif hasattr(module, "bias") and module.bias is not None: zero_rank_log("before gathering bias") tesseract_dim = self.parallel_context.get_world_size( - ParallelMode.TENSOR_2P5D_COL + ParallelMode.TENSOR_3D_INPUT ) b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) @@ -543,19 +588,6 @@ def _gather_head(self, module: Linear3D): if hasattr(module, "skip_bias_add") else False, ) - del module.row_rank - del module.col_rank - del module.dep_rank - del module.tesseract_dim - del module.data_parallel_rank - del module.pipeline_parallel_rank - del module.tensor_parallel_size - del module.pipeline_parallel_size - del module.reversed - del module.fusion_degree - del module.orig_module - del module.gather_output - del module.parallel_context module.__class__ = nn.Linear @@ -568,16 +600,12 @@ def _gather_linear(self, module: Linear3D): ParallelMode.TENSOR_3D_INPUT ) - w = gather_2d( - self.parallel_context, - module.weight.data, - cubic_dim=cubic_dim, - col_first=True, - ) + w = gather_1d(self.parallel_context, module.weight.data, cubic_dim, 0, ParallelMode.TENSOR_3D_WEIGHT) if fusion_degree > 1: w = self._reconstruct_combined_qkv(w, cubic_dim, 
fusion_degree, False) if is_reversed: w = w.t() + w = gather_2d(self.parallel_context, w, cubic_dim=cubic_dim, input_first=False) module.weight.data = w if hasattr(module, "bias") and module.bias is not None: @@ -598,19 +626,6 @@ def _gather_linear(self, module: Linear3D): if hasattr(module, "skip_bias_add") else False, ) - del module.row_rank - del module.col_rank - del module.dep_rank - del module.tesseract_dim - del module.data_parallel_rank - del module.pipeline_parallel_rank - del module.tensor_parallel_size - del module.pipeline_parallel_size - del module.reversed - del module.fusion_degree - del module.orig_module - del module.gather_output - del module.parallel_context module.__class__ = nn.Linear @@ -636,16 +651,6 @@ def _gather_layernorm(self, module): if hasattr(module.bias, "oslo_parallel"): del module.bias.oslo_parallel - del module.partitioned_dim - del module.row_rank - del module.col_rank - del module.dep_rank - del module.tesseract_dim - del module.data_parallel_rank - del module.pipeline_parallel_rank - del module.tensor_parallel_size - del module.pipeline_parallel_size - del module.orig_module _update_module_arguments( module, normalized_shape=module.weight.size()[0], diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/__init__.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/__init__.py new file mode 100644 index 00000000..28fcc144 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/__init__.py @@ -0,0 +1,3 @@ +from .. import _utils + +_ALL__ = [_utils] diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/test_deparallelize.py new file mode 100644 index 00000000..f58401d6 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/test_deparallelize.py @@ -0,0 +1,176 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset + +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, + AutoConfig, +) + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + +import numpy as np +import os +import random + + +def seed_all(seed: int = 1930): + + print("Using Seed Number {}".format(seed)) + + os.environ["PYTHONHASHSEED"] = str( + seed + ) # set PYTHONHASHSEED env var at fixed value + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) + np.random.seed(seed) # for numpy pseudo-random generator + random.seed(seed) # set fixed value for python built-in pseudo-random generator + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + + +def seed_worker(_worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +seed_all(seed=1994) + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end - start + + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size 
= 8
+tp_depth = 1
+
+model_name = "jason960903/soongsil-bert-base"
+mkwargs = {}
+dataset_name = "korquadv1"
+
+# create parallel context
+parallel_context = ParallelContext.from_torch(
+    data_parallel_size=1,
+    pipeline_parallel_size=1,
+    tensor_parallel_size=tp_size,
+    tensor_parallel_mode=ParallelMode.TENSOR_3D,
+    tensor_parallel_depth=tp_depth,
+)
+
+# create tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs)
+tokenizer.pad_token = tokenizer.eos_token
+
+# create model and parallelize it
+model_no_tp = AutoModelForCausalLM.from_config(
+    AutoConfig.from_pretrained(model_name)
+).cuda()
+model_tp = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name))
+wrapper_tp = TensorParallel(model_tp, parallel_context)
+allocate_params(wrapper_tp, parallel_context)
+# the allocate_params call will later be handled by a class that manages all parallel wrappers
+# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py
+
+if dist.get_rank() == 0:
+    print(wrapper_tp)
+
+# create optimizers
+optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5)
+optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5)
+
+# create dataset
+batch_size = 16
+datasets = load_dataset(dataset_name).data["train"]["context"]
+datasets = [str(sample) for sample in datasets[:500]]
+dataloader = DataLoader(datasets, batch_size=batch_size)
+
+# create monitoring
+if dist.get_rank() == 0:
+    wandb.init(project="oslo", name=f"{model_name}_tp3d_bs{batch_size}")
+    cur = time.time()
+
+# save
+wrapper_tp.save_parallelized("test/", merge_checkpoints=True)
+
+# wait for monitoring creation
+dist.barrier()
+
+# load
+model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda()
+optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5)
+
+dist.barrier()
+
+# start training
+for data in dataloader:
+    optimizer_tp.zero_grad()
+    optimizer_no_tp.zero_grad()
+    optimizer_gathered.zero_grad()
+
+    inputs = tokenizer(
+        data,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512,
+    ).to("cuda")
+
+    loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"])
+    loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"])
+    loss_gathered, gathered_fw_time = fw(
+        model_gathered, **inputs, labels=inputs["input_ids"]
+    )
+
+    if dist.get_rank() == 0:
+        print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHERED:{loss_gathered}")
+        wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHERED": loss_gathered})
+
+    _, notp_bw_time = bw(loss_no_tp)
+    _, tp_bw_time = bw(loss_tp)
+    _, gathered_bw_time = bw(loss_gathered)
+
+    optimizer_tp.step()
+    optimizer_no_tp.step()
+    optimizer_gathered.step()
+
+    if dist.get_rank() == 0:
+        wandb.log(
+            {
+                "tp.forward.time:": tp_fw_time,
+                "tp.backward.time:": tp_bw_time,
+                "notp.forward.time:": notp_fw_time,
+                "notp.backward.time:": notp_bw_time,
+                "gathered.forward.time:": gathered_fw_time,
+                "gathered.backward.time:": gathered_bw_time,
+            }
+        )
+
+dist.barrier()

From fda2b0d0d4b802bbd2fb6bdf581a4cff024bb009 Mon Sep 17 00:00:00 2001
From: yang-kichang
Date: Tue, 16 Aug 2022 08:07:33 +0900
Subject: [PATCH 34/37] precommit run

---
 .../tensor_parallel/_parallel_3d/_ops.py      |  4 +-
 .../tensor_parallel/_parallel_3d/_wrapper.py  | 52 +++++++++----------
 2 files changed, 29 insertions(+), 27 deletions(-)

diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py
index 98f4990a..1bc27977 100644
--- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py
+++ 
b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py @@ -983,7 +983,9 @@ def gather_2d(parallel_context, tensor, cubic_dim, input_first=True): def gather_1d(parallel_context, tensor, cubic_dim, dim=-1, mode=None): if mode is None: - mode = ParallelMode.TENSOR_3D_OUTPUT if dim == -1 else ParallelMode.TENSOR_3D_INPUT + mode = ( + ParallelMode.TENSOR_3D_OUTPUT if dim == -1 else ParallelMode.TENSOR_3D_INPUT + ) tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] dist.all_gather( tensor_list, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index 288217ff..8ae12dad 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -19,7 +19,7 @@ split_batch_3d, gather_3d, gather_2d, - gather_1d + gather_1d, ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, @@ -31,7 +31,7 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - zero_rank_log + zero_rank_log, ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -443,7 +443,9 @@ def _slice_head(self, module, reversed, gather_output): module.bias.data = bias_list[input_rank].contiguous() if hasattr(module.bias, "oslo_parallel"): - module.bias.oslo_parallel[ParallelMode.TENSOR_3D_INPUT] = input_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_3D_INPUT + ] = input_rank module.bias.oslo_parallel[ ParallelMode.TENSOR_3D_OUTPUT ] = output_rank @@ -513,14 +515,14 @@ def _deparallelize_embedding(self): def _deparallelize_linear(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_column_parallel( - self.module, param_name + self.module, param_name ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): self._gather_linear(module) def _deparallelize_head(self): for param_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, param_name ) and isinstance(module, Linear3D): zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) @@ -531,13 +533,9 @@ def _deparallelize_layernorm(self): self._gather_layernorm(module) def _gather_embedding(self, module): - cubic_dim = self.parallel_context.get_world_size( - ParallelMode.TENSOR_3D_INPUT - ) + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): - w = gather_3d( - self.parallel_context, module.weight.data, cubic_dim - ) + w = gather_3d(self.parallel_context, module.weight.data, cubic_dim) assert hasattr( self.module, "orig_vocab_size" @@ -596,11 +594,15 @@ def _gather_linear(self, module: Linear3D): fusion_degree = module.fusion_degree # slice_bias = module.slice_bias - cubic_dim = self.parallel_context.get_world_size( - ParallelMode.TENSOR_3D_INPUT - ) + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) - w = gather_1d(self.parallel_context, module.weight.data, cubic_dim, 0, ParallelMode.TENSOR_3D_WEIGHT) + w = gather_1d( + self.parallel_context, + module.weight.data, + cubic_dim, + 0, + ParallelMode.TENSOR_3D_WEIGHT, + ) if fusion_degree > 1: w = self._reconstruct_combined_qkv(w, cubic_dim, fusion_degree, False) if is_reversed: @@ -611,9 +613,7 @@ def _gather_linear(self, module: Linear3D): if hasattr(module, "bias") and module.bias is not 
None: b = gather_1d(self.parallel_context, module.bias.data, cubic_dim, 0) if fusion_degree > 1: - b = self._reconstruct_combined_qkv( - b, cubic_dim, fusion_degree, True - ) + b = self._reconstruct_combined_qkv(b, cubic_dim, fusion_degree, True) b = b.view(b.size()[1:]) module.bias.data = b @@ -630,14 +630,10 @@ def _gather_linear(self, module: Linear3D): module.__class__ = nn.Linear def _gather_layernorm(self, module): - cubic_dim = self.parallel_context.get_world_size( - ParallelMode.TENSOR_3D_INPUT - ) + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) if hasattr(module, "weight") and module.weight is not None: if module.weight.dim() >= 1: - w = gather_1d( - self.parallel_context, module.weight.data, cubic_dim, 0 - ) + w = gather_1d(self.parallel_context, module.weight.data, cubic_dim, 0) module.weight.data = w if hasattr(module.weight, "oslo_parallel"): @@ -669,7 +665,9 @@ def _reconstruct_combined_qkv(tensor, cubic_dim, fusion_degree, is_bias=False): for i in range(cubic_dim) ], 1, - ).view(-1, last_dim).contiguous() + ) + .view(-1, last_dim) + .contiguous() ) else: reshaped_w = tensor.view(fusion_degree * cubic_dim, -1) @@ -680,7 +678,9 @@ def _reconstruct_combined_qkv(tensor, cubic_dim, fusion_degree, is_bias=False): for i in range(cubic_dim) ], 1, - ).view(-1, last_dim).contiguous() + ) + .view(-1, last_dim) + .contiguous() ) return recon_w From 45cdac05836194857b8722f5051d19844df6110f Mon Sep 17 00:00:00 2001 From: yang-kichang Date: Tue, 16 Aug 2022 08:15:08 +0900 Subject: [PATCH 35/37] remove zero_rank_log --- .../tensor_parallel/_parallel_1d/_wrapper.py | 13 ---------- .../tensor_parallel/_parallel_2d/_wrapper.py | 24 ------------------- .../_parallel_2p5d/_wrapper.py | 16 ------------- .../tensor_parallel/_parallel_3d/_wrapper.py | 18 -------------- 4 files changed, 71 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index d4c62811..105b7eac 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -24,7 +24,6 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - zero_rank_log, ) from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( BaseTensorParallelWrapper, @@ -392,18 +391,9 @@ def _slice_head(self, module, reversed): @torch.no_grad() def deparallelize(self): # must deparallelize embedding first than linear - zero_rank_log("deparallelize embedding start") self._deparallelize_embedding() - zero_rank_log("deparallelize embedding end") - - zero_rank_log("deparallelize linear start") self._deparallelize_linear() - zero_rank_log("deparallelize linear end") - - zero_rank_log("deparallelize head start") self._deparallelize_head() - zero_rank_log("deparallelize head end") - self._rollback_mp_arguments() def _rollback_mp_arguments(self): @@ -436,7 +426,6 @@ def _deparallelize_head(self): if self.tensor_parallel_mapping.is_head( self.module, param_name ) and isinstance(module, ColLinear1D): - zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) def _gather_embedding(self, module): @@ -533,13 +522,11 @@ def _gather_head(self, module: ColLinear1D): if module.weight is not self.module.get_input_embeddings().weight: return self._gather_column_linear(module) elif hasattr(module, "bias") and module.bias is not None: - zero_rank_log("before gathering head bias") world_size = 
self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) b = self._reconstruct_combined_qkv(module.bias, world_size, 1, 0) module.bias.data = b[: module.weight.size()[0]] - zero_rank_log("after gathering head bias") _update_module_arguments( module=module, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 8e4191da..4ffeeba9 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -30,7 +30,6 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - zero_rank_log, ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -452,33 +451,13 @@ def _slice_head(self, module, reversed, gather_output): ) module.__class__ = Linear2D - def _zero_rank_log(self, txt): - import torch.distributed as dist - - if dist.get_rank() == 0: - print(txt) - # 모니터링 생성 대기 - dist.barrier() - @torch.no_grad() def deparallelize(self): # must deparallelize embedding first than linear - zero_rank_log("deparallelize embedding start") self._deparallelize_embedding() - zero_rank_log("deparallelize embedding end") - - zero_rank_log("deparallelize linear start") self._deparallelize_linear() - zero_rank_log("deparallelize linear end") - - zero_rank_log("deparallelize layernorm start") self._deparallelize_layernorm() - zero_rank_log("deparallelize layernorm end") - - zero_rank_log("deparallelize head start") self._deparallelize_head() - zero_rank_log("deparallelize head end") - self._rollback_mp_arguments() def _rollback_mp_arguments(self): @@ -508,7 +487,6 @@ def _deparallelize_head(self): if self.tensor_parallel_mapping.is_head( self.module, param_name ) and isinstance(module, Linear2D): - zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) def _deparallelize_layernorm(self): @@ -554,7 +532,6 @@ def _gather_head(self, module: Linear2D): if module.weight is not self.module.get_input_embeddings().weight: return self._gather_linear(module) elif hasattr(module, "bias") and module.bias is not None: - zero_rank_log("before gathering bias") summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_ROW) b = gather_1d_twice( @@ -562,7 +539,6 @@ def _gather_head(self, module: Linear2D): ) module.bias.data = b[: module.weight.size()[0]] - zero_rank_log("after gathering bias") _update_module_arguments( module=module, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 5f18a685..f24c2d74 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -29,7 +29,6 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - zero_rank_log, ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -493,22 +492,10 @@ def _slice_head(self, module, reversed, gather_output): @torch.no_grad() def deparallelize(self): # must deparallelize embedding first than linear - zero_rank_log("deparallelize embedding start") self._deparallelize_embedding() - zero_rank_log("deparallelize embedding end") - - zero_rank_log("deparallelize linear start") self._deparallelize_linear() - zero_rank_log("deparallelize linear end") - - zero_rank_log("deparallelize layernorm start") self._deparallelize_layernorm() - zero_rank_log("deparallelize layernorm end") - - 
zero_rank_log("deparallelize head start") self._deparallelize_head() - zero_rank_log("deparallelize head end") - self._rollback_mp_arguments() def _rollback_mp_arguments(self): @@ -540,7 +527,6 @@ def _deparallelize_head(self): if self.tensor_parallel_mapping.is_head( self.module, param_name ) and isinstance(module, Linear2p5D): - zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) def _deparallelize_layernorm(self): @@ -588,7 +574,6 @@ def _gather_head(self, module: Linear2p5D): if module.weight is not self.module.get_input_embeddings().weight: return self._gather_linear(module) elif hasattr(module, "bias") and module.bias is not None: - zero_rank_log("before gathering bias") tesseract_dim = self.parallel_context.get_world_size( ParallelMode.TENSOR_2P5D_COL ) @@ -596,7 +581,6 @@ def _gather_head(self, module: Linear2p5D): b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) module.bias.data = b[: module.weight.size()[0]] - zero_rank_log("after gathering bias") _update_module_arguments( module=module, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index 8ae12dad..c4c05da3 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -31,7 +31,6 @@ _update_module_arguments, is_huggingface_model, is_oslo_model, - zero_rank_log, ) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, @@ -477,22 +476,10 @@ def _slice_head(self, module, reversed, gather_output): @torch.no_grad() def deparallelize(self): # must deparallelize embedding first than linear - zero_rank_log("deparallelize embedding start") self._deparallelize_embedding() - zero_rank_log("deparallelize embedding end") - - zero_rank_log("deparallelize linear start") self._deparallelize_linear() - zero_rank_log("deparallelize linear end") - - zero_rank_log("deparallelize layernorm start") self._deparallelize_layernorm() - zero_rank_log("deparallelize layernorm end") - - zero_rank_log("deparallelize head start") self._deparallelize_head() - zero_rank_log("deparallelize head end") - self._rollback_mp_arguments() def _rollback_mp_arguments(self): @@ -524,7 +511,6 @@ def _deparallelize_head(self): if self.tensor_parallel_mapping.is_head( self.module, param_name ) and isinstance(module, Linear3D): - zero_rank_log(f"deparallelize head {param_name}") self._gather_head(module) def _deparallelize_layernorm(self): @@ -567,15 +553,11 @@ def _gather_head(self, module: Linear3D): if module.weight is not self.module.get_input_embeddings().weight: return self._gather_linear(module) elif hasattr(module, "bias") and module.bias is not None: - zero_rank_log("before gathering bias") tesseract_dim = self.parallel_context.get_world_size( ParallelMode.TENSOR_3D_INPUT ) - b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) - module.bias.data = b[: module.weight.size()[0]] - zero_rank_log("after gathering bias") _update_module_arguments( module=module, From 20204eaaee37a3ff47de794eb700a5e6aa389dc6 Mon Sep 17 00:00:00 2001 From: yang-kichang Date: Tue, 16 Aug 2022 08:49:15 +0900 Subject: [PATCH 36/37] precommit run --- .../parallel/tensor_parallel/_parallel_2d/_wrapper.py | 3 +-- .../nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py | 10 +--------- .../parallel/tensor_parallel/_parallel_3d/_wrapper.py | 2 +- .../nn/parallel/tensor_parallel/tensor_parallel.py | 6 +----- 4 files changed, 4 
insertions(+), 17 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index b4d116f3..7a6f607c 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -17,10 +17,9 @@ LayerNorm2D, ) from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._ops import ( - split_batch_2d, gather_2d, - gather_1d, gather_1d_twice, +) from oslo.torch.distributed.nn.functional import ( scatter, ) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py index e56e921c..339a0fe2 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py @@ -83,15 +83,6 @@ def gather_batch_2p5d( ) -def all_gather_tensor_2p5d( - inputs: Tensor, - dim: int, - parallel_context: ParallelContext, - col_parallel_mode: ParallelMode, -) -> Tensor: - return _AllGatherTensor2p5D.apply(inputs, dim, parallel_context, col_parallel_mode) - - def reduce_by_batch_2p5d( inputs, reduce_mean: bool, parallel_context: ParallelContext ) -> Tensor: @@ -137,6 +128,7 @@ def split_batch_2p5d( )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL)].contiguous() return col_chunked + def get_current_device(): r""" Get current device. diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index 5e011815..e16af883 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -16,10 +16,10 @@ LayerNorm3D, ) from oslo.torch.nn.parallel.tensor_parallel._parallel_3d._ops import ( - split_batch_3d, gather_3d, gather_2d, gather_1d, +) from oslo.torch.distributed.nn.functional import ( scatter, ) diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index 1ce462a6..d18e2212 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -4,7 +4,6 @@ import os import json from operator import xor -from typing import Optional import warnings import torch @@ -96,7 +95,6 @@ def __init__( module = self._resize_vocab_size(module, self.parallel_context) module = self._resize_num_classes(module, self.parallel_context, mapping) - if parallel_context.tensor_parallel_mode != ParallelMode.TENSOR_1D: if memory_priority and parallel_context.tensor_parallel_size > 1: warnings.warn( @@ -108,9 +106,7 @@ def __init__( module, self.parallel_context, mapping, memory_priority ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D: - self.module = _TensorParallel2D( - module, self.parallel_context, mapping - ) + self.module = _TensorParallel2D(module, self.parallel_context, mapping) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D: self.module = _TensorParallel2p5D(module, self.parallel_context, mapping) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_3D: From 88ac33f69f69f86b0a4b9453c2ece6eff08140de Mon Sep 17 00:00:00 2001 From: yang-kichang Date: Tue, 16 Aug 2022 09:17:33 +0900 Subject: [PATCH 37/37] add module args for save_parallelized --- .../tensor_parallel/_parallel_1d/_wrapper.py | 11 +++++++++++ 
.../tensor_parallel/_parallel_2d/_wrapper.py | 10 ++++++++++ .../tensor_parallel/_parallel_2p5d/_wrapper.py | 10 ++++++++++ .../tensor_parallel/_parallel_3d/_wrapper.py | 10 ++++++++++ .../parallel/tensor_parallel/tensor_parallel.py | 15 +++++++++++---- 5 files changed, 52 insertions(+), 4 deletions(-) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 5d891fc4..0c9090df 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -53,8 +53,19 @@ def __init__( parallel_context: ParallelContext, mapping: dict = None, memory_priority: bool = False, + module_args: dict = None, ): super().__init__(module, parallel_context) + + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." + ) + + self.config = module_args self.module = module self.parallel_context = parallel_context self.memory_priority = memory_priority diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index 7a6f607c..cd68334f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -57,6 +57,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): super().__init__(module, parallel_context, mapping) self.module = module @@ -71,6 +72,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." + ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 86b36fe8..70e36a92 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -54,6 +54,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): super().__init__(module, parallel_context, mapping) self.module = module @@ -68,6 +69,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." 
+ ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index e16af883..4543cee2 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -55,6 +55,7 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): super().__init__(module, parallel_context, mapping) self.module = module @@ -69,6 +70,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." + ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index d18e2212..ce7eb142 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -89,6 +89,7 @@ def __init__( parallel_context: Optional[ParallelContext] = None, mapping: dict = None, memory_priority: bool = False, + module_args: dict = None, ): super().__init__() self.parallel_context = get_parallel_context(module, parallel_context) @@ -103,14 +104,20 @@ def __init__( if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: self.module = _TensorParallel1D( - module, self.parallel_context, mapping, memory_priority + module, self.parallel_context, mapping, memory_priority, module_args ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D: - self.module = _TensorParallel2D(module, self.parallel_context, mapping) + self.module = _TensorParallel2D( + module, self.parallel_context, mapping, module_args + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D: - self.module = _TensorParallel2p5D(module, self.parallel_context, mapping) + self.module = _TensorParallel2p5D( + module, self.parallel_context, mapping, module_args + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_3D: - self.module = _TensorParallel3D(module, self.parallel_context, mapping) + self.module = _TensorParallel3D( + module, self.parallel_context, mapping, module_args + ) else: raise ValueError( "currently, only 1d, 2d, 2p5d tensor parallelism is supported."
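For reference, the round trip these patches enable looks roughly like the sketch below, condensed from the 3D test added in PATCH 33. This is a minimal sketch, not part of the patch series: the model id ("gpt2"), the output directory ("merged/"), the 8-rank layout, and the assumption that save_parallelized(merge_checkpoints=True) goes through the new deparallelize() path are illustrative; TensorParallel, allocate_params, ParallelContext.from_torch, and save_parallelized are the APIs exercised in the test above.

import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM

from oslo.torch.distributed import ParallelContext, ParallelMode
from oslo.torch.nn.parallel.tensor_parallel import TensorParallel
from oslo.torch.nn.parallel.utils import allocate_params

# 3D tensor parallelism over 8 ranks (a 2 x 2 x 2 cube); launch with e.g.
# torchrun --nproc_per_node=8 this_script.py. Sizes here are illustrative.
parallel_context = ParallelContext.from_torch(
    data_parallel_size=1,
    pipeline_parallel_size=1,
    tensor_parallel_size=8,
    tensor_parallel_mode=ParallelMode.TENSOR_3D,
    tensor_parallel_depth=1,
)

# build a model, wrap it, and materialize the sharded parameters on each rank
model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("gpt2"))
wrapper = TensorParallel(model, parallel_context)
allocate_params(wrapper, parallel_context)

# merge_checkpoints=True gathers the 3D-sharded weights back into a single
# checkpoint on disk (the gather/deparallelize path added in these patches)
wrapper.save_parallelized("merged/", merge_checkpoints=True)
dist.barrier()

# the merged checkpoint then loads as an ordinary, non-parallel model
merged_model = AutoModelForCausalLM.from_pretrained("merged/")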