diff --git a/.gitignore b/.gitignore index 61186f7e..bfa231e7 100644 --- a/.gitignore +++ b/.gitignore @@ -389,3 +389,10 @@ usecases /usecases */usecases wandb/ + +# multi gpu mem log +**/core.* + +# sample huggingface models +**/pytorch_model.bin +**/config.json diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index 7817164f..268633e0 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -178,6 +178,8 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: ) if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] + if not outputs.is_contiguous(): + outputs = outputs.contiguous() if self.memory_priority and self.scatter_output: outputs = scatter_tensor_1d( @@ -257,6 +259,9 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: bias = self.bias return outputs + bias + if not outputs.is_contiguous(): + outputs = outputs.contiguous() + return outputs @@ -398,6 +403,10 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: ) if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] + + if not outputs.is_contiguous(): + outputs = outputs.contiguous() + return outputs @@ -530,15 +539,19 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: dim=-1, parallel_context=self.parallel_context, col_parallel_mode=ParallelMode.TENSOR_2P5D_ROW, - ) + ).clone() outputs = all_gather_tensor_2p5d( outputs, dim=0, parallel_context=self.parallel_context, col_parallel_mode=ParallelMode.TENSOR_2P5D_COL, - ) + ).clone() if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] + + if not outputs.is_contiguous(): + outputs = outputs.contiguous() + return outputs @@ -621,4 +634,7 @@ def forward(self, input: Tensor) -> Tensor: ) if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] + + if not outputs.is_contiguous(): + outputs = outputs.contiguous() return outputs diff --git a/oslo/torch/nn/parallel/tensor_parallel/__init__.py b/oslo/torch/nn/parallel/tensor_parallel/__init__.py index 296bd0cf..43e5116c 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/__init__.py +++ b/oslo/torch/nn/parallel/tensor_parallel/__init__.py @@ -1,6 +1,7 @@ from oslo.torch.nn.parallel.tensor_parallel.mapping import Column, Row, Update, Head -from oslo.torch.nn.parallel.tensor_parallel.tensor_parallel import ( - TensorParallel, +from oslo.torch.nn.parallel.tensor_parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, ) -__ALL__ = [TensorParallel, Column, Row, Update, Head] +__ALL__ = [TensorParallel, Column, Row, Update, Head, BaseTensorParallelWrapper] diff --git a/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py new file mode 100644 index 00000000..ec0fad62 --- /dev/null +++ b/oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py @@ -0,0 +1,225 @@ +import copy +import os +import json + +import torch +import torch.nn as nn +import torch.distributed as dist + +from typing import Union, Optional, Callable +from logging import getLogger + +from oslo.torch.distributed import ParallelContext, ParallelMode + +from oslo.torch.nn.parallel.utils import ( + ParallelWrapper, + _update_module_arguments, + is_huggingface_model, + is_oslo_model, + allocate_params, + unwrap_parallel, + get_parameter_dtype, +) + + +class 
BaseTensorParallelWrapper(ParallelWrapper): + """ + PyTorch module for xD tensor parallelism + + Args: + module (nn.Module): model object + parallel_context (ParallelContext): parallel context object + """ + + def __init__( + self, + module: nn.Module, + parallel_context: ParallelContext, + mapping: dict = None, + module_args: dict = None, + ): + super().__init__() + + @torch.no_grad() + def save_parallelized( + self, + new_module, + save_directory: Union[str, os.PathLike], + save_config: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + merge_checkpoints: bool = False, + mapping: Optional[dict] = None, + **kwargs, + ): + logger = getLogger("TensorParallel") + PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" + + if ( + self.parallel_context.get_world_size(ParallelMode.TENSOR) == 1 + and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1 + ): + if dist.get_rank() == 0: + self.save_pretrained( + save_directory=save_directory, + save_config=save_config, + state_dict=state_dict, + save_function=save_function, + **kwargs, + ) + dist.barrier() + return None + + if merge_checkpoints: + model_to_save = self.__class__( + module=new_module, + parallel_context=self.parallel_context, + mapping=mapping, + module_args=self.config, + ).eval() + + if state_dict is None: + state_dict = self.state_dict() + + model_to_save.load_state_dict(state_dict) + allocate_params(model_to_save, self.parallel_context) + + if self.parallel_context.get_world_size(ParallelMode.TENSOR) > 1: + model_to_save.deparallelize() + + if dist.get_rank() == 0: + if is_huggingface_model(model_to_save.module): + model_to_save.module.save_pretrained( + save_directory=save_directory, + save_config=save_config, + save_function=save_function, + **kwargs, + ) + else: + if save_config: + with open( + os.path.join(save_directory, "config.json"), "w" + ) as f: + json.dump(self.config, f) + save_function( + model_to_save, + os.path.join(save_directory, "pytorch_model.bin"), + ) + del model_to_save + + dist.barrier() + return None + + if os.path.isfile(save_directory): + logger.error( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + return + + os.makedirs(save_directory, exist_ok=True) + + # Only save the model itself if we are using distributed training + model_to_save = unwrap_parallel(self) + + # save the string version of dtype to the config, e.g. 
convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.torch_dtype = str(dtype).split(".")[1] + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Save the config + if save_config: + model_to_save.config.save_pretrained(save_directory) + + # Save the model + if state_dict is None: + state_dict = model_to_save.state_dict() + + # Handle the case where some state_dict keys shouldn't be saved + if getattr(self, "_keys_to_ignore_on_save", None) is not None: + state_dict = { + k: v + for k, v in state_dict.items() + if k not in self._keys_to_ignore_on_save + } + + # If we save using the predefined names, we can load using `from_pretrained` + weights_name = PARALLELIZED_WEIGHTS_NAME + weights_name = weights_name.replace( + "tp_0", f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}" + ) + weights_name = weights_name.replace( + "pp_0", f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}" + ) + + output_model_file = os.path.join(save_directory, weights_name) + + if self.parallel_context.get_world_size(ParallelMode.DATA) > 1: + if self.parallel_context.get_local_rank(ParallelMode.DATA) == 0: + save_function(state_dict, output_model_file) + else: + save_function(state_dict, output_model_file) + + dist.barrier() + logger.info(f"Model weights saved in {output_model_file}") + + def from_parallelized(self, path): + """ + Example: + >>> model = AnyModel() + >>> model = TensorParallel(model, ...) + >>> model.from_parallelized(path) + """ + PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin" + parallelized_model_path = path + + file_names = { + os.path.join( + parallelized_model_path, + PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp}").replace( + "pp_0", f"pp_{pp}" + ), + ) + for tp in range(self.parallel_context.get_world_size(ParallelMode.TENSOR)) + for pp in range(self.parallel_context.get_world_size(ParallelMode.PIPELINE)) + } + + if os.path.isdir(parallelized_model_path): + if all(os.path.isfile(file_name) for file_name in file_names): + state_dict = torch.load( + os.path.join( + parallelized_model_path, + PARALLELIZED_WEIGHTS_NAME.replace( + "tp_0", + f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}", + ).replace( + "pp_0", + f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}", + ), + ) + ) + + if getattr(self, "_keys_to_ignore_on_save", None) is not None: + state_dict = { + k: v + for k, v in state_dict.items() + if k not in self._keys_to_ignore_on_save + } + + self.load_state_dict(state_dict=state_dict, strict=False) + + else: + raise FileNotFoundError( + f"all the {file_names} are necessary. " + f"but some of them do not exist. Please check your checkpoint files." + ) + else: + raise NotADirectoryError( + f"directory named {parallelized_model_path} is not valid. 
" + ) + + @torch.no_grad() + def deparallelize(self): + return NotImplementedError diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py index 19d9c37c..c5104a65 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py @@ -1,6 +1,7 @@ from typing import Any import torch +import torch.distributed as dist import torch.nn.functional as F from torch import Tensor @@ -214,3 +215,21 @@ def memory_priority_linear( inputs: Tensor, weight: Tensor, parallel_context: ParallelContext ): return _MemoryPriorityLinear.apply(inputs, weight, parallel_context) + + +def split_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor = tensor.chunk(summa_dim, dim=dim)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_1D) + ] + return tensor + + +def gather_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_1D), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index c8533e3d..0c9090df 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +import torch.distributed as dist from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn.modules.embedding import ( @@ -23,18 +24,20 @@ TensorParallelMapping, ) from oslo.torch.nn.parallel.utils import ( - ParallelWrapper, _update_module_arguments, is_huggingface_model, is_oslo_model, ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, +) from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, ) from oslo.transformers.constants import SEQ_DIMENSIONS -class _TensorParallel1D(ParallelWrapper): +class _TensorParallel1D(BaseTensorParallelWrapper): """ PyTorch module for 1D tensor parallelism @@ -50,8 +53,19 @@ def __init__( parallel_context: ParallelContext, mapping: dict = None, memory_priority: bool = False, + module_args: dict = None, ): - super().__init__() + super().__init__(module, parallel_context) + + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." 
+ ) + + self.config = module_args self.module = module self.parallel_context = parallel_context self.memory_priority = memory_priority @@ -208,6 +222,34 @@ def _slice_embedding(self, module): orig_module=copy.deepcopy(module.__class__), ) module.__class__ = VocabParallelEmbedding1D + + for name, module_head in self.module.named_modules(): + if ( + hasattr(module_head, "weight") + and module_head.weight is module.weight + and not isinstance(module_head, nn.Embedding) + and not self.tensor_parallel_mapping.is_head(self.module, name) + ): + _update_module_arguments( + module=module_head, + parallel_context=self.parallel_context, + reversed=self.tensor_parallel_mapping.is_reversed( + self.module, name + ), + fusion_degree=1, + orig_module=copy.deepcopy(module_head.__class__), + out_features=module.weight.size()[0], + # in_features=module.weight.size()[1], + gather_output=not is_oslo_model(self.module), + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + + if isinstance(module_head, nn.Linear) or isinstance( + module_head, nn.Conv1D + ): + module_head.__class__ = ColLinear1D else: weight_list = module.weight.data.chunk(world_size, dim=1) module.weight.data = weight_list[rank].contiguous() @@ -394,3 +436,170 @@ def _slice_head(self, module, reversed): else False, ) module.__class__ = ColLinear1D + + @torch.no_grad() + def deparallelize(self): + # must deparallelize embedding first than linear + self._deparallelize_embedding() + self._deparallelize_linear() + self._deparallelize_head() + self._rollback_mp_arguments() + + def _rollback_mp_arguments(self): + for module in self.module.modules(): + for elem in self.tensor_parallel_mapping.update_attrs(self.module): + if hasattr(module, elem.name): + world_size = self.parallel_context.get_world_size( + ParallelMode.TENSOR_1D + ) + expanded_arg = getattr(module, elem.name) * world_size + setattr(module, elem.name, expanded_arg) + + def _deparallelize_embedding(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == VocabParallelEmbedding1D: + self._gather_embedding(module) + if module.__class__ == Embedding1D: + self._gather_embedding(module) + + def _deparallelize_linear(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_column_parallel(self.module, param_name): + self._gather_column_linear(module) + + elif self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + self._gather_row_linear(module) + + def _deparallelize_head(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_head( + self.module, param_name + ) and isinstance(module, ColLinear1D): + self._gather_head(module) + + def _gather_embedding(self, module): + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + # w = gather_2d(self.parallel_context, module.weight.data, world_size, col_first=True) + tensor_list = [ + torch.zeros_like(module.weight.data) for _ in range(world_size) + ] + dist.all_gather( + tensor_list, + module.weight.data.contiguous(), + self.parallel_context.get_group(ParallelMode.TENSOR_1D), + ) + w = torch.cat(tensor_list, dim=0) + + assert hasattr( + self.module, "orig_vocab_size" + ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." 
+ orig_vocab_size = self.module.orig_vocab_size + + module.weight.data = w[:orig_vocab_size, :] + + _update_module_arguments( + module=module, + vocab_start_index=None, + vocab_end_index=None, + parallel_context=None, + num_embeddings=module.weight.size()[0], + embedding_dim=module.weight.size()[1], + orig_module=None, + ) + else: + tensor_list = [ + torch.zeros_like(module.weight.data) for _ in range(world_size) + ] + dist.all_gather( + tensor_list, + module.weight.data.contiguous(), + self.parallel_context.get_group(ParallelMode.TENSOR_1D), + ) + w = torch.cat(tensor_list, dim=1) + module.weight.data = w + + _update_module_arguments( + module=module, + parallel_context=None, + embedding_dim=module.weight.size()[1], + ) + module.__class__ = nn.Embedding + + def _gather_linear(self, module, dim=1): + is_reversed = module.reversed + fusion_degree = module.fusion_degree + + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR) + + w = self._reconstruct_combined_qkv( + module.weight, world_size, fusion_degree, dim + ) + if is_reversed: + w = w.t() + module.weight.data = w + + if hasattr(module, "bias") and module.bias is not None and dim != 1: + b = self._reconstruct_combined_qkv( + module.bias, world_size, fusion_degree, dim + ) + module.bias.data = b + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + + del module.reversed + del module.fusion_degree + del module.orig_module + del module.parallel_context + module.__class__ = nn.Linear + + def _gather_column_linear(self, module): + self._gather_linear(module, dim=0) + + def _gather_row_linear(self, module): + self._gather_linear(module, dim=1) + + def _gather_head(self, module: ColLinear1D): + if module.weight is not self.module.get_input_embeddings().weight: + return self._gather_column_linear(module) + elif hasattr(module, "bias") and module.bias is not None: + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + + b = self._reconstruct_combined_qkv(module.bias, world_size, 1, 0) + + module.bias.data = b[: module.weight.size()[0]] + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.reversed + del module.fusion_degree + del module.orig_module + del module.parallel_context + + module.__class__ = nn.Linear + + def _reconstruct_combined_qkv(self, tensor, world_size, fusion_degree, dim: int): + tensor_list = tensor.chunk(fusion_degree, dim=dim) + result_list = [] + for w in tensor_list: + w_list = [torch.zeros_like(w) for _ in range(world_size)] + dist.all_gather( + w_list, w, self.parallel_context.get_group(ParallelMode.TENSOR_1D) + ) + result_list.append(torch.cat(w_list, dim=dim)) + return torch.cat(result_list, dim=dim) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py index a618c4f4..80f91827 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py @@ -94,6 +94,87 @@ def gather_batch_2d( ) +def gather_2d(parallel_context, tensor, summa_dim, col_first=True): + if col_first: + tensor_list = 
[torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2D_ROW), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor + + +def gather_1d_twice(parallel_context, tensor, summa_dim, dim=-1): + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2D_COL), + ) + tensor = torch.cat(tensor_list, dim=dim) + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, tensor, parallel_context.get_group(ParallelMode.TENSOR_2D_ROW) + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor + + +def split_batch_2d( + inputs: Tensor, + dim: int = 0, + parallel_context: Optional[ParallelContext] = None, +) -> Tensor: + dim_size = inputs.size(dim) + world_size = parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) + + if world_size <= 1: + return inputs + + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2D size ({world_size})." 
+ + return inputs.chunk(world_size, dim=dim)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL) + ].contiguous() + + def reduce_tensor_2d( inputs: Tensor, parallel_context: ParallelContext, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index eef4339d..cd68334f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -10,11 +10,16 @@ VocabUtility, ) from oslo.torch.nn.modules.linear import ( + Linear, Linear2D, ) from oslo.torch.nn.modules.layer_norm import ( LayerNorm2D, ) +from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._ops import ( + gather_2d, + gather_1d_twice, +) from oslo.torch.distributed.nn.functional import ( scatter, ) @@ -31,10 +36,14 @@ _TensorParallelMappingForHuggingFace, ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, +) + from oslo.transformers.constants import BATCH_DIMENSIONS -class _TensorParallel2D(ParallelWrapper): +class _TensorParallel2D(BaseTensorParallelWrapper): """ PyTorch module for 2D tensor parallelism @@ -48,8 +57,9 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): - super().__init__() + super().__init__(module, parallel_context, mapping) self.module = module self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -62,6 +72,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." 
+ ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() @@ -147,6 +166,9 @@ def _parallelize_head(self): reversed=self.tensor_parallel_mapping.is_reversed( self.module, param_name ), + gather_output=self.tensor_parallel_mapping.is_gather_output( + self.module, param_name + ), ) @staticmethod @@ -366,7 +388,7 @@ def _slice_layernorm(self, module): ) module.__class__ = LayerNorm2D - def _slice_head(self, module, reversed): + def _slice_head(self, module, reversed, gather_output): if module.weight is not self.module.get_input_embeddings().weight: self._slice_linear( module=module, @@ -376,7 +398,7 @@ def _slice_head(self, module, reversed): ) _update_module_arguments( module=module, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, ) else: summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) @@ -392,6 +414,7 @@ def _slice_head(self, module, reversed): pipeline_parallel_size = self.parallel_context.get_world_size( ParallelMode.PIPELINE ) + if hasattr(module, "bias") and module.bias is not None: if module.bias.dim() >= 1: bias_list = module.bias.data.chunk(summa_dim, dim=0) @@ -422,7 +445,257 @@ def _slice_head(self, module, reversed): reversed=reversed, fusion_degree=1, skip_bias_add=False, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear2D + + @torch.no_grad() + def deparallelize(self): + # must deparallelize embedding first than linear + self._deparallelize_embedding() + self._deparallelize_linear() + self._deparallelize_layernorm() + self._deparallelize_head() + self._rollback_mp_arguments() + + def _rollback_mp_arguments(self): + for module in self.module.modules(): + for elem in self.tensor_parallel_mapping.update_attrs(self.module): + if hasattr(module, elem.name): + summa_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2D_COL + ) + expanded_arg = getattr(module, elem.name) * summa_dim + setattr(module, elem.name, expanded_arg) + + def _deparallelize_embedding(self): + for param_name, module in self.module.named_modules(): + if module.__class__ in [VocabParallelEmbedding2D, Embedding2D]: + self._gather_embedding(module) + + def _deparallelize_linear(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_column_parallel( + self.module, param_name + ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + self._gather_linear(module) + + def _deparallelize_head(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_head( + self.module, param_name + ) and isinstance(module, Linear2D): + self._gather_head(module) + + def _deparallelize_layernorm(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == LayerNorm2D: + self._gather_layernorm(module) + + def _gather_embedding(self, module): + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + w = gather_2d( + self.parallel_context, module.weight.data, summa_dim, col_first=True + ) + + assert hasattr( + self.module, "orig_vocab_size" + ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." 
+ orig_vocab_size = self.module.orig_vocab_size + + module.weight.data = w[:orig_vocab_size, :] + + _update_module_arguments( + module=module, + vocab_start_index=None, + vocab_end_index=None, + parallel_context=None, + num_embeddings=module.weight.size()[0], + embedding_dim=module.weight.size()[1], + orig_module=None, + ) + else: + w = gather_1d_twice(self.parallel_context, module.weight.data, summa_dim, 1) + module.weight.data = w + + _update_module_arguments( + module=module, + parallel_context=None, + embedding_dim=module.weight.size()[1], + ) + module.__class__ = nn.Embedding + + def _gather_head(self, module: Linear2D): + if module.weight is not self.module.get_input_embeddings().weight: + return self._gather_linear(module) + elif hasattr(module, "bias") and module.bias is not None: + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_ROW) + + b = gather_1d_twice( + self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0 + ) + + module.bias.data = b[: module.weight.size()[0]] + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + + del module.row_rank + del module.col_rank + del module.summa_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + + def _gather_linear(self, module: Linear2D): + is_reversed = module.reversed + fusion_degree = module.fusion_degree + # slice_bias = module.slice_bias + + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) + + w = gather_2d( + self.parallel_context, + module.weight.data, + summa_dim=summa_dim, + col_first=True, + ) + # print(f"w shape: {w.shape}\nweight shape: {module.weight.data.shape}") + if fusion_degree > 1: + w = self._reconstruct_combined_qkv(w, summa_dim, fusion_degree, False) + if is_reversed: + w = w.t() + module.weight.data = w + + if hasattr(module, "bias") and module.bias is not None: + # if slice_bias is True and module.bias.dim() >= 1: + b = gather_1d_twice( + self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0 + ) + if fusion_degree > 1: + b = self._reconstruct_combined_qkv(b, summa_dim, fusion_degree, True) + b = b.view(b.size()[1:]) + module.bias.data = b + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.row_rank + del module.col_rank + del module.summa_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + + def _gather_layernorm(self, module): + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) + if hasattr(module, "weight") and module.weight is not None: + if module.weight.dim() >= 1: + w = gather_1d_twice( + self.parallel_context, + module.weight.data, + summa_dim=summa_dim, + dim=0, + ) + 
module.weight.data = w + + if hasattr(module.weight, "oslo_parallel"): + del module.weight.oslo_parallel + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + b = gather_1d_twice( + self.parallel_context, module.bias.data, summa_dim=summa_dim, dim=0 + ) + module.bias.data = b + + if hasattr(module.bias, "oslo_parallel"): + del module.bias.oslo_parallel + + del module.partitioned_dim + del module.row_rank + del module.col_rank + del module.summa_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.orig_module + _update_module_arguments( + module, + normalized_shape=module.weight.size()[0], + ) + module.__class__ = nn.LayerNorm + + @staticmethod + def _reconstruct_combined_qkv(tensor, summa_dim, fusion_degree, is_bias=False): + last_dim = tensor.size()[-1] + if is_bias is False: + reshaped_w = tensor.view(summa_dim * fusion_degree, -1, last_dim) + # print(f"tensor.size: {tensor.size()}, reshaped_w.size: {reshaped_w.size()}") + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(summa_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) + else: + reshaped_w = tensor.view(fusion_degree * summa_dim, -1) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(summa_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) + return recon_w + + @staticmethod + def _reconstrunct_combined_qkv_bias(tensor, summa_dim, fusion_degree): + tensor = [ + [tensor[j * summa_dim + k] for k in range(summa_dim)] + for j in range(fusion_degree) + ] + tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) + tensor = [tensor[j] for j in range(summa_dim)] + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py index 253a31ec..339a0fe2 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py @@ -110,6 +110,25 @@ def reduce_scatter_tensor_2p5d( return _ReduceScatterTensor2p5D.apply(inputs, dim, parallel_context, parallel_mode) +def split_batch_2p5d( + inputs: Tensor, dim: int, parallel_context: ParallelContext +) -> Tensor: + dim_size = inputs.size(dim) + world_size = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) + + if world_size <= 1: + return inputs + + assert ( + dim_size % world_size == 0 + ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." + + col_chunked = torch.chunk( + inputs, parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL), dim=dim + )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL)].contiguous() + return col_chunked + + def get_current_device(): r""" Get current device. 
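Note (not part of the diff): the split helpers above and the gather helpers added in the next hunk are only inverses if every rank concatenates the gathered chunks in rank order along the same dimension the tensor was split on. A minimal single-process sanity check of that layout assumption — ranks are simulated by loop indices, no torch.distributed initialization or real collectives are involved:

import torch

# Simulate the chunk -> gather -> cat round trip relied on by
# split_batch_2p5d / gather_1d: each "rank" keeps one contiguous chunk of
# the batch dimension, and concatenating the per-rank pieces in rank order
# must reproduce the original tensor.
world_size = 4  # stands in for the TENSOR_2P5D_COL group size
full = torch.arange(32 * 8, dtype=torch.float32).view(32, 8)

# What each rank would hold after splitting dim 0 across the group.
shards = [full.chunk(world_size, dim=0)[rank].contiguous() for rank in range(world_size)]

# What an all_gather over the same group followed by torch.cat(dim=0) produces.
restored = torch.cat(shards, dim=0)
assert torch.equal(full, restored)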
@@ -1012,3 +1031,141 @@ def backward(ctx: Any, output_grad: Tensor): return output_grad / ctx.reduce_size, None, None else: return output_grad, None, None + + +def split_batch_2d(parallel_context, tensor, tesseract_dim): + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def split_2p5d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + else: + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = torch.chunk(tensor, tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP) + ] + return tensor + + +def split_2d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + else: + tensor = tensor.chunk(tesseract_dim, dim=-1)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + ] + tensor = tensor.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def split_1d(parallel_context, tensor, summa_dim, dim=-1): + tensor = tensor.chunk(summa_dim, dim=dim)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + ] + return tensor + + +def gather_2p5d(parallel_context, tensor, tesseract_dim, col_first=True): + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, tensor, parallel_context.get_group(ParallelMode.TENSOR_2P5D_DEP) + ) + tensor = torch.cat(tensor_list, dim=0) + if col_first: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_2d(parallel_context, tensor, tesseract_dim, col_first=True): + if col_first: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + 
parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_2P5D_ROW), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(tesseract_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_2P5D_COL), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_1d(parallel_context, tensor, summa_dim, dim=-1): + parallel_modde = ParallelMode.TENSOR_2P5D_ROW + tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(parallel_modde), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index 6315c76f..70e36a92 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -10,17 +10,25 @@ VocabUtility, Embedding2p5D, ) -from oslo.torch.nn.modules.linear import Linear2p5D +from oslo.torch.nn.modules.linear import Linear, Linear2p5D from oslo.torch.nn.modules.layer_norm import LayerNorm2p5D +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( + split_batch_2p5d, + gather_2d, + gather_1d, +) + from oslo.torch.distributed.nn.functional import ( scatter, ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, +) from oslo.torch.nn.parallel.utils import ( - ParallelWrapper, _update_module_arguments, is_huggingface_model, is_oslo_model, @@ -32,7 +40,7 @@ from oslo.transformers.constants import BATCH_DIMENSIONS -class _TensorParallel2p5D(ParallelWrapper): +class _TensorParallel2p5D(BaseTensorParallelWrapper): """ PyTorch module for 2.5D tensor parallelism @@ -46,8 +54,9 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): - super().__init__() + super().__init__(module, parallel_context, mapping) self.module = module self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -60,6 +69,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." 
+ ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() @@ -145,6 +163,9 @@ def _parallelize_head(self): reversed=self.tensor_parallel_mapping.is_reversed( self.module, param_name ), + gather_output=self.tensor_parallel_mapping.is_gather_output( + self.module, param_name + ), ) @staticmethod @@ -389,7 +410,7 @@ def _slice_layernorm(self, module): module.__class__ = LayerNorm2p5D return module - def _slice_head(self, module, reversed): + def _slice_head(self, module, reversed, gather_output): if module.weight is not self.module.get_input_embeddings().weight: self._slice_linear( module=module, @@ -399,7 +420,7 @@ def _slice_head(self, module, reversed): ) _update_module_arguments( module=module, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, ) else: tesseract_dim = self.parallel_context.get_world_size( @@ -464,7 +485,259 @@ def _slice_head(self, module, reversed): reversed=reversed, fusion_degree=1, skip_bias_add=False, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear2p5D + + @torch.no_grad() + def deparallelize(self): + # must deparallelize embedding first than linear + self._deparallelize_embedding() + self._deparallelize_linear() + self._deparallelize_layernorm() + self._deparallelize_head() + self._rollback_mp_arguments() + + def _rollback_mp_arguments(self): + for module in self.module.modules(): + for elem in self.tensor_parallel_mapping.update_attrs(self.module): + if hasattr(module, elem.name): + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) + expanded_arg = getattr(module, elem.name) * tesseract_dim + setattr(module, elem.name, expanded_arg) + + def _deparallelize_embedding(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == VocabParallelEmbedding2p5D: + self._gather_embedding(module) + if module.__class__ == Embedding2p5D: + self._gather_embedding(module) + + def _deparallelize_linear(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_column_parallel( + self.module, param_name + ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + self._gather_linear(module) + + def _deparallelize_head(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_head( + self.module, param_name + ) and isinstance(module, Linear2p5D): + self._gather_head(module) + + def _deparallelize_layernorm(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == LayerNorm2p5D: + self._gather_layernorm(module) + + def _gather_embedding(self, module): + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + w = gather_2d( + self.parallel_context, module.weight.data, tesseract_dim, col_first=True + ) + + assert hasattr( + self.module, "orig_vocab_size" + ), "wrapper's vocab embedding module must have attribute 'orig_vocab_size'." 
+ orig_vocab_size = self.module.orig_vocab_size + module.weight.data = w[:orig_vocab_size, :].contiguous() + + _update_module_arguments( + module=module, + vocab_start_index=None, + vocab_end_index=None, + parallel_context=None, + num_embeddings=module.weight.size()[0], + embedding_dim=module.weight.size()[1], + orig_module=None, + ) + else: + w = gather_1d(self.parallel_context, module.weight, tesseract_dim, 1) + w = gather_1d(self.parallel_context, w, tesseract_dim, 1) + module.weight.data = w + + _update_module_arguments( + module=module, + parallel_context=None, + embedding_dim=module.weight.size()[1], + ) + module.__class__ = nn.Embedding + + def _gather_head(self, module: Linear2p5D): + if module.weight is not self.module.get_input_embeddings().weight: + return self._gather_linear(module) + elif hasattr(module, "bias") and module.bias is not None: + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) + + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + + module.bias.data = b[: module.weight.size()[0]] + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + + def _gather_linear(self, module: Linear2p5D): + is_reversed = module.reversed + fusion_degree = module.fusion_degree + # slice_bias = module.slice_bias + + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) + + w = gather_2d( + self.parallel_context, + module.weight.data, + tesseract_dim=tesseract_dim, + col_first=True, + ) + if fusion_degree > 1: + w = self._reconstruct_combined_qkv(w, tesseract_dim, fusion_degree, False) + if is_reversed: + w = w.t() + module.weight.data = w + + if hasattr(module, "bias") and module.bias is not None: + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + if fusion_degree > 1: + b = self._reconstruct_combined_qkv( + b, tesseract_dim, fusion_degree, True + ) + b = b.view(b.size()[1:]) + module.bias.data = b + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.reversed + del module.fusion_degree + del module.orig_module + del module.gather_output + del module.parallel_context + + module.__class__ = nn.Linear + + def _gather_layernorm(self, module): + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) + if hasattr(module, "weight") and module.weight is not None: + if module.weight.dim() >= 1: + w = gather_1d( + self.parallel_context, module.weight.data, tesseract_dim, 0 + ) + module.weight.data = w + + if 
hasattr(module.weight, "oslo_parallel"): + del module.weight.oslo_parallel + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + module.bias.data = b + + if hasattr(module.bias, "oslo_parallel"): + del module.bias.oslo_parallel + + del module.partitioned_dim + del module.row_rank + del module.col_rank + del module.dep_rank + del module.tesseract_dim + del module.data_parallel_rank + del module.pipeline_parallel_rank + del module.tensor_parallel_size + del module.pipeline_parallel_size + del module.orig_module + _update_module_arguments( + module, + normalized_shape=module.weight.size()[0], + ) + module.__class__ = nn.LayerNorm + + @staticmethod + def _reconstruct_combined_qkv(tensor, tesseract_dim, fusion_degree, is_bias=False): + last_dim = tensor.size()[-1] + if is_bias is False: + reshaped_w = tensor.view(tesseract_dim * fusion_degree, -1, last_dim) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(tesseract_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) + else: + reshaped_w = tensor.view(fusion_degree * tesseract_dim, -1) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(tesseract_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) + return recon_w + + @staticmethod + def _reconstrunct_combined_qkv_bias(tensor, tessearct_dim, fusion_degree): + tensor = [ + [tensor[j * tessearct_dim + k] for k in range(tessearct_dim)] + for j in range(fusion_degree) + ] + tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) + tensor = [tensor[j] for j in range(tessearct_dim)] + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py index 364484b9..fe28b12c 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py @@ -1,6 +1,7 @@ from typing import Any, Tuple, Optional import torch +import torch.distributed as dist from torch import Tensor from torch.cuda.amp import custom_bwd, custom_fwd @@ -882,3 +883,77 @@ def broadcast_weight_3d_from_diagonal( weight_parallel_mode, output_parallel_mode, ) + + +def gather_3d(parallel_context, tensor, cubic_dim): + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_3D_OUTPUT), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_3D_INPUT), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_3D_WEIGHT), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_2d(parallel_context, tensor, cubic_dim, input_first=True): + if input_first: + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_3D_INPUT), + ) + tensor = torch.cat(tensor_list, dim=0) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor, + 
parallel_context.get_group(ParallelMode.TENSOR_3D_OUTPUT), + ) + tensor = torch.cat(tensor_list, dim=-1) + else: + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor, + parallel_context.get_group(ParallelMode.TENSOR_3D_OUTPUT), + ) + tensor = torch.cat(tensor_list, dim=-1) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(ParallelMode.TENSOR_3D_INPUT), + ) + tensor = torch.cat(tensor_list, dim=0) + return tensor + + +def gather_1d(parallel_context, tensor, cubic_dim, dim=-1, mode=None): + if mode is None: + mode = ( + ParallelMode.TENSOR_3D_OUTPUT if dim == -1 else ParallelMode.TENSOR_3D_INPUT + ) + tensor_list = [torch.zeros_like(tensor) for _ in range(cubic_dim)] + dist.all_gather( + tensor_list, + tensor.contiguous(), + parallel_context.get_group(mode), + ) + tensor = torch.cat(tensor_list, dim=dim) + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index 046f550f..4543cee2 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -15,14 +15,21 @@ from oslo.torch.nn.modules.layer_norm import ( LayerNorm3D, ) +from oslo.torch.nn.parallel.tensor_parallel._parallel_3d._ops import ( + gather_3d, + gather_2d, + gather_1d, +) from oslo.torch.distributed.nn.functional import ( scatter, ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, ) +from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import ( + BaseTensorParallelWrapper, +) from oslo.torch.nn.parallel.utils import ( - ParallelWrapper, _update_module_arguments, is_huggingface_model, is_oslo_model, @@ -34,7 +41,7 @@ from oslo.transformers.constants import BATCH_DIMENSIONS -class _TensorParallel3D(ParallelWrapper): +class _TensorParallel3D(BaseTensorParallelWrapper): """ PyTorch module for 3D tensor parallelism @@ -48,8 +55,9 @@ def __init__( module: nn.Module, parallel_context: ParallelContext, mapping: dict = None, + module_args: dict = None, ): - super().__init__() + super().__init__(module, parallel_context, mapping) self.module = module self.parallel_context = parallel_context self.device = torch.cuda.current_device() @@ -62,6 +70,15 @@ def __init__( "`mapping` must be input if the model is not huggingface model." ) + if module_args is None: + if is_huggingface_model(module): + module_args = module.config + else: + raise ValueError( + "`config` must be input if the model is not huggingface model." 
+ ) + + self.config = module_args self.tensor_parallel_mapping = TensorParallelMapping(mapping) self._parallelize() @@ -152,6 +169,9 @@ def _parallelize_head(self): reversed=self.tensor_parallel_mapping.is_reversed( self.module, param_name ), + gather_output=self.tensor_parallel_mapping.is_gather_output( + self.module, param_name + ), ) @staticmethod @@ -398,7 +418,7 @@ def _slice_layernorm(self, module): ) module.__class__ = LayerNorm3D - def _slice_head(self, module, reversed): + def _slice_head(self, module, reversed, gather_output): if module.weight is not self.module.get_input_embeddings().weight: self._slice_linear( module=module, @@ -408,7 +428,7 @@ def _slice_head(self, module, reversed): ) _update_module_arguments( module=module, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, ) else: cubic_dim = self.parallel_context.get_world_size( @@ -423,10 +443,9 @@ def _slice_head(self, module, reversed): weight_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_3D_WEIGHT ) - if hasattr(module, "bias") and module.bias is not None: if module.bias.dim() >= 1: - bias_list = module.bias.data.chunk(cubic_dim, dim=0) + bias_list = module.bias.chunk(cubic_dim, dim=0) module.bias.data = bias_list[input_rank].contiguous() if hasattr(module.bias, "oslo_parallel"): @@ -455,7 +474,210 @@ def _slice_head(self, module, reversed): reversed=reversed, fusion_degree=1, skip_bias_add=False, - gather_output=not is_oslo_model(self.module), + gather_output=not is_oslo_model(self.module) and gather_output, orig_module=copy.deepcopy(module.__class__), ) module.__class__ = Linear3D + + @torch.no_grad() + def deparallelize(self): + # must deparallelize embedding first than linear + self._deparallelize_embedding() + self._deparallelize_linear() + self._deparallelize_layernorm() + self._deparallelize_head() + self._rollback_mp_arguments() + + def _rollback_mp_arguments(self): + for module in self.module.modules(): + for elem in self.tensor_parallel_mapping.update_attrs(self.module): + if hasattr(module, elem.name): + cubic_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_3D_INPUT + ) + expanded_arg = getattr(module, elem.name) * cubic_dim + setattr(module, elem.name, expanded_arg) + + def _deparallelize_embedding(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == VocabParallelEmbedding3D: + self._gather_embedding(module) + if module.__class__ == Embedding3D: + self._gather_embedding(module) + + def _deparallelize_linear(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_column_parallel( + self.module, param_name + ) or self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + self._gather_linear(module) + + def _deparallelize_head(self): + for param_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_head( + self.module, param_name + ) and isinstance(module, Linear3D): + self._gather_head(module) + + def _deparallelize_layernorm(self): + for param_name, module in self.module.named_modules(): + if module.__class__ == LayerNorm3D: + self._gather_layernorm(module) + + def _gather_embedding(self, module): + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) + if hasattr(module, "vocab_start_index") and hasattr(module, "vocab_end_index"): + w = gather_3d(self.parallel_context, module.weight.data, cubic_dim) + + assert hasattr( + self.module, "orig_vocab_size" + ), 
"wrapper's vocab embedding module must have attribute 'orig_vocab_size'." + orig_vocab_size = self.module.orig_vocab_size + module.weight.data = w[:orig_vocab_size, :].contiguous() + + _update_module_arguments( + module=module, + vocab_start_index=None, + vocab_end_index=None, + parallel_context=None, + num_embeddings=module.weight.size()[0], + embedding_dim=module.weight.size()[1], + orig_module=None, + ) + else: + w = gather_1d(self.parallel_context, module.weight, cubic_dim, -1) + module.weight.data = w + + _update_module_arguments( + module=module, + parallel_context=None, + embedding_dim=module.weight.size()[1], + ) + module.__class__ = nn.Embedding + + def _gather_head(self, module: Linear3D): + if module.weight is not self.module.get_input_embeddings().weight: + return self._gather_linear(module) + elif hasattr(module, "bias") and module.bias is not None: + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_3D_INPUT + ) + b = gather_1d(self.parallel_context, module.bias.data, tesseract_dim, 0) + module.bias.data = b[: module.weight.size()[0]] + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + + module.__class__ = nn.Linear + + def _gather_linear(self, module: Linear3D): + is_reversed = module.reversed + fusion_degree = module.fusion_degree + # slice_bias = module.slice_bias + + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) + + w = gather_1d( + self.parallel_context, + module.weight.data, + cubic_dim, + 0, + ParallelMode.TENSOR_3D_WEIGHT, + ) + if fusion_degree > 1: + w = self._reconstruct_combined_qkv(w, cubic_dim, fusion_degree, False) + if is_reversed: + w = w.t() + w = gather_2d(self.parallel_context, w, cubic_dim=cubic_dim, input_first=False) + module.weight.data = w + + if hasattr(module, "bias") and module.bias is not None: + b = gather_1d(self.parallel_context, module.bias.data, cubic_dim, 0) + if fusion_degree > 1: + b = self._reconstruct_combined_qkv(b, cubic_dim, fusion_degree, True) + b = b.view(b.size()[1:]) + module.bias.data = b + + _update_module_arguments( + module=module, + in_features=module.weight.size()[1], + out_features=module.weight.size()[0], + parallel_context=self.parallel_context, + skip_bias_add=module.skip_bias_add + if hasattr(module, "skip_bias_add") + else False, + ) + + module.__class__ = nn.Linear + + def _gather_layernorm(self, module): + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) + if hasattr(module, "weight") and module.weight is not None: + if module.weight.dim() >= 1: + w = gather_1d(self.parallel_context, module.weight.data, cubic_dim, 0) + module.weight.data = w + + if hasattr(module.weight, "oslo_parallel"): + del module.weight.oslo_parallel + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + b = gather_1d(self.parallel_context, module.bias.data, cubic_dim, 0) + module.bias.data = b + + if hasattr(module.bias, "oslo_parallel"): + del module.bias.oslo_parallel + + _update_module_arguments( + module, + normalized_shape=module.weight.size()[0], + ) + module.__class__ = nn.LayerNorm + + @staticmethod + def _reconstruct_combined_qkv(tensor, cubic_dim, fusion_degree, is_bias=False): + last_dim = tensor.size()[-1] + if is_bias is False: + reshaped_w = tensor.view(cubic_dim * fusion_degree, -1, last_dim) + 
recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(cubic_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) + else: + reshaped_w = tensor.view(fusion_degree * cubic_dim, -1) + recon_w = ( + torch.cat( + [ + reshaped_w[i * fusion_degree : (i + 1) * fusion_degree] + for i in range(cubic_dim) + ], + 1, + ) + .view(-1, last_dim) + .contiguous() + ) + return recon_w + + @staticmethod + def _reconstrunct_combined_qkv_bias(tensor, cubic_dim, fusion_degree): + tensor = [ + [tensor[j * cubic_dim + k] for k in range(cubic_dim)] + for j in range(fusion_degree) + ] + tensor = list(map(lambda x: torch.cat([*x], dim=0), zip(*tensor))) + tensor = [tensor[j] for j in range(cubic_dim)] + return tensor diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index fba41a1f..ce7eb142 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -1,8 +1,16 @@ -from typing import Optional +import imp +from typing import Union, Optional, Callable + +import os +import json +from operator import xor import warnings import torch import torch.nn as nn +import torch.distributed as dist + +from transformers import AutoConfig from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._wrapper import ( @@ -28,6 +36,8 @@ unwrap_parallel, get_parallel_context, is_huggingface_model, + allocate_params, + get_parameter_dtype, ) @@ -79,6 +89,7 @@ def __init__( parallel_context: Optional[ParallelContext] = None, mapping: dict = None, memory_priority: bool = False, + module_args: dict = None, ): super().__init__() self.parallel_context = get_parallel_context(module, parallel_context) @@ -93,14 +104,20 @@ def __init__( if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: self.module = _TensorParallel1D( - module, self.parallel_context, mapping, memory_priority + module, self.parallel_context, mapping, memory_priority, module_args ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D: - self.module = _TensorParallel2D(module, self.parallel_context, mapping) + self.module = _TensorParallel2D( + module, self.parallel_context, mapping, module_args + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D: - self.module = _TensorParallel2p5D(module, self.parallel_context, mapping) + self.module = _TensorParallel2p5D( + module, self.parallel_context, mapping, module_args + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_3D: - self.module = _TensorParallel3D(module, self.parallel_context, mapping) + self.module = _TensorParallel3D( + module, self.parallel_context, mapping, module_args + ) else: raise ValueError( "currently, only 1d, 2d, 2p5d tensor parallelism is supported." 
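As a reference for the API extended above, here is a minimal usage sketch that mirrors the test scripts added later in this patch; the GPT-2 model, 1D mode, world size of 4, and the test/ path are illustrative only:

import torch.distributed as dist
from transformers import AutoModelForCausalLM, GPT2Config, GPT2LMHeadModel

from oslo.torch.distributed import ParallelContext, ParallelMode
from oslo.torch.nn.parallel.tensor_parallel import TensorParallel
from oslo.torch.nn.parallel.utils import allocate_params

# launch with one process per tensor-parallel rank (e.g. torchrun --nproc_per_node=4)
parallel_context = ParallelContext.from_torch(
    data_parallel_size=1,
    pipeline_parallel_size=1,
    tensor_parallel_size=4,
    tensor_parallel_mode=ParallelMode.TENSOR_1D,
    tensor_parallel_depth=1,
)

model = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2"))
wrapper = TensorParallel(model, parallel_context)  # module_args is optional and omitted here
allocate_params(wrapper, parallel_context)

# merge_checkpoints=True gathers the shards into a single checkpoint that plain
# transformers can load; merge_checkpoints=False writes per-rank shards that are
# re-loaded on the parallel wrapper with from_parallelized()
wrapper.save_parallelized("test/", merge_checkpoints=True)
dist.barrier()

merged_model = AutoModelForCausalLM.from_pretrained("test/")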
@@ -140,8 +157,8 @@ def _resize_vocab_size(model, parallel_context): module.weight.data = new_embeddings module.num_embeddings = new_vocab_size - setattr(module, "orig_num_classes", vocab_size) - setattr(unwrapped_model, "orig_vocab_size", vocab_size) + setattr(module, "orig_num_classes", vocab_size) + setattr(unwrapped_model, "orig_vocab_size", vocab_size) return model @staticmethod @@ -167,6 +184,38 @@ def _resize_num_classes(model, parallel_context, mapping): module.out_features = ( unwrapped_model.get_input_embeddings().num_embeddings ) + + assert hasattr( + unwrapped_model.get_input_embeddings(), "orig_num_classes" + ), "call _resize_vocab before _resize_num_classes" + out_features = ( + unwrapped_model.get_input_embeddings().orig_num_classes + ) + setattr(module, "orig_num_classes", out_features) + setattr( + unwrapped_model, + f"orig_{param_name.split('.')[-1]}_num_classes", + out_features, + ) + + if hasattr(module, "bias") and module.bias is not None: + out_features = module.bias.size()[0] + new_out_features = out_features + + while new_out_features % divisible_by != 0: + new_out_features += 1 + + if new_out_features != out_features: + padding = torch.zeros( + new_out_features - out_features, + dtype=module.bias.dtype, + device=module.bias.device, + ) + new_bias = torch.cat( + tensors=[module.bias.data, padding], + dim=0, + ) + module.bias.data = new_bias else: out_features, in_features = module.weight.size() new_out_features = out_features @@ -208,8 +257,45 @@ def _resize_num_classes(model, parallel_context, mapping): ) return model - def _restore_vocab_size(self, model, parallel_context): - pass + @torch.no_grad() + def save_parallelized( + self, + save_directory: Union[str, os.PathLike], + save_config: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = torch.save, + merge_checkpoints: bool = False, + mapping: Optional[dict] = None, + **kwargs, + ): + unwrapped_model = unwrap_parallel(self.module.module) + if is_huggingface_model(unwrapped_model): + new_module = unwrapped_model.__class__(self.module.config) + else: + new_module = unwrapped_model.__class__(**self.module.config) + + new_module = self._resize_vocab_size(new_module, self.parallel_context) + new_module = self._resize_num_classes( + new_module, self.parallel_context, mapping + ) + + new_module = self.module.save_parallelized( + new_module, + save_directory, + save_config, + state_dict, + save_function, + merge_checkpoints, + mapping, + **kwargs, + ) + + return new_module + + @staticmethod + def get_module_args(module): + state_dict = module.state_dict() + return {key: value.shape for key, value in state_dict.items()} - def _restore_num_classes(self, model, parallel_context): - pass + def from_parallelized(self, path): + return self.module.from_parallelized(path) diff --git a/oslo/torch/nn/parallel/utils.py b/oslo/torch/nn/parallel/utils.py index c3414a1c..f2b2be45 100644 --- a/oslo/torch/nn/parallel/utils.py +++ b/oslo/torch/nn/parallel/utils.py @@ -65,6 +65,11 @@ def _update_module_arguments(module: nn.Module, **kwargs): setattr(module, k, v) +def _remove_module_arguments(module: nn.Module, args: list): + for k in args: + delattr(module, k) + + def allocate_params(model: nn.Module, parallel_context: ParallelContext): for name, parameter in model.named_parameters(): if hasattr(parameter, "oslo_parallel"): @@ -100,3 +105,12 @@ def get_parallel_context(module: nn.Module, parallel_context: ParallelContext): ) return parallel_context + + +def zero_rank_log(txt): + import torch.distributed as 
dist + + if dist.get_rank() == 0: + print(txt) + # 모니터링 생성 대기 + dist.barrier() diff --git a/oslo/transformers/mapping_utils.py b/oslo/transformers/mapping_utils.py index 84e82c89..dbca283d 100644 --- a/oslo/transformers/mapping_utils.py +++ b/oslo/transformers/mapping_utils.py @@ -72,6 +72,7 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): "sop_classifier.classifier", "classifier", "qa_outputs", + gather_output=True, ), ], "Bart": [ @@ -79,14 +80,26 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): Column("classification_head.dense", gather_output=True), Row("out_proj", "fc2"), Update("embed_dim", "num_heads"), - Head("lm_head", "classification_head.out_proj", "qa_outputs"), + Head( + "lm_head", + "classification_head.out_proj", + "qa_outputs", + gather_output=True, + ), ], "Bert": [ Column("query", "key", "value", "intermediate.dense"), - Column("pooler.dense", "transform.dense", gather_output=True), + Column("pooler.dense", gather_output=True), Row("output.dense"), Update("num_attention_heads", "all_head_size"), - Head("decoder", "seq_relationship", "classifier", "qa_outputs"), + Head("transform.dense", gather_output=False), + Head( + "decoder", + "seq_relationship", + "classifier", + "qa_outputs", + gather_output=True, + ), ], "Blenderbot": [ Column("q_proj", "k_proj", "v_proj", "fc1"), @@ -111,19 +124,19 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): Column("c_fc", "q_attn", reversed=True), Row("c_proj", reversed=True), Update("embed_dim", "split_size", "num_heads"), - Head("lm_head", "score", "classifier", "summary"), + Head("lm_head", "score", "classifier", "summary", gather_output=True), ], "GPTNeo": [ Column("q_proj", "k_proj", "v_proj", "c_fc"), Row("out_proj", "c_proj"), Update("embed_dim", "num_heads"), - Head("lm_head", "score", "qa_outputs"), + Head("lm_head", "score", "qa_outputs", gather_output=True), ], "GPTJ": [ Column("q_proj", "k_proj", "v_proj", "fc_in"), Row("out_proj", "fc_out"), Update("embed_dim", "num_attention_heads"), - Head("lm_head", "score"), + Head("lm_head", "score", gather_output=True), ], "Electra": [ Column("query", "key", "value", "intermediate.dense"), @@ -143,19 +156,27 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): "classifier", "qa_outputs", "summary", + gather_output=True, ), ], "Roberta": [ Column("query", "key", "value", "intermediate.dense"), Column( + "lm_head.dense", "classifier.dense", "roberta.pooler", - "lm_head.dense", gather_output=True, ), Row("output.dense"), Update("num_attention_heads", "all_head_size"), - Head("lm_head.decoder", "classifier.out_proj", "classifier", "qa_outputs"), + Head("lm_head.dense"), + Head( + "lm_head.decoder", + "classifier.out_proj", + "classifier", + "qa_outputs", + gather_output=True, + ), ], } diff --git a/oslo/transformers/models/__init__.py b/oslo/transformers/models/__init__.py index e69de29b..6dfb8583 100644 --- a/oslo/transformers/models/__init__.py +++ b/oslo/transformers/models/__init__.py @@ -0,0 +1,11 @@ +from oslo.transformers.models.gpt2.modeling_gpt2 import ( + GPT2Model, + GPT2LMHeadModel, + GPT2ForSequenceClassification, + GPT2ForTokenClassification, +) + + +from oslo.transformers.training_args import TrainingArguments + +# from oslo.transformers.trainer import Trainer diff --git a/tests/torch/nn/parallel/tensor_parallel/2p5d/test_linear_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_linear_2p5d.py index 59f4df74..534cc6ca 100644 --- 
a/tests/torch/nn/parallel/tensor_parallel/2p5d/test_linear_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_linear_2p5d.py @@ -1,9 +1,11 @@ -from copy import deepcopy import torch import torch.distributed as dist from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn import Linear2p5D -from _utils import split_2p5d, split_bias_2p5d, gather_2p5d + +from _utils import split_2p5d, split_2d, gather_2p5d, gather_2d + +from copy import deepcopy tp_size = 8 tp_depth = 2 @@ -43,14 +45,25 @@ if parallel_context.get_global_rank() == 0: print(f"original output: \n{out}\n") - print(f"original update output: \n{out_update}\n") - -input_ = split_2p5d(input_, tesseract_dim, parallel_context=parallel_context) -ptarget = split_2p5d(target, tesseract_dim, parallel_context=parallel_context) -w = split_2p5d(w, tesseract_dim, parallel_context=parallel_context) -b = split_bias_2p5d(b, tesseract_dim, parallel_context=parallel_context) - -linear_2p5d = Linear2p5D(input_dim, hidden_dim, parallel_context=parallel_context) + print(f"original updated weight: \n{linear.weight.data}\n") + print(f"original updated bias: \n{linear.bias.data}\n") + +# split input_ into +# 0:[0, 0, 0], 1:[0, 0, 1], 2:[0, 1, 0], 3:[0, 1, 1], 4:[1, 0, 0], 5:[1, 0, 1], 6:[1, 1, 0], 7:[1, 1, 1] +# input shape: (m/dq, n/q) +input_ = split_2d(parallel_context, input_, tesseract_dim) +ptarget = split_2d(parallel_context, target, tesseract_dim) + +# split weight into 0,4:[0, 0], 1,5:[1, 0], 2,6:[0, 1], 3,7:[1, 1] +# input shape: (n/q, k/q) +w = split_2d(parallel_context, w, tesseract_dim, col_first=False) +# split bias into 0,4:[0], 2,6:[1] +# input shape: (k/q) +b = b.chunk(tesseract_dim, dim=0)[ + parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) +] + +linear_2p5d = Linear2p5D(4, 4, parallel_context=parallel_context) linear_2p5d.weight.data.copy_(w) linear_2p5d.bias.data.copy_(b) @@ -62,8 +75,11 @@ pout_update = linear_2p5d(input_) -pout = gather_2p5d(pout, tesseract_dim, parallel_context=parallel_context) -pout_update = gather_2p5d(pout_update, tesseract_dim, parallel_context=parallel_context) +pout = gather_2d(parallel_context, pout, tesseract_dim, False) +pout_update = gather_2d(parallel_context, pout_update, tesseract_dim, False) + +# w = gather_2d(parallel_context, linear_2p5d.weight.data, tesseract_dim, True) +# b = gather_1d(parallel_context, linear_2p5d.bias.data, tesseract_dim, 0) if parallel_context.get_global_rank() == 0: print(f"parallel output: \n{pout}\n") @@ -72,8 +88,12 @@ if parallel_context.get_global_rank() == 0: sse = torch.sum((out - pout) ** 2).item() sse_update = torch.sum((out_update - pout_update) ** 2).item() + minmax_update = (out_update - pout_update) ** 2 print(f"output sse: \n{sse}\n") print(f"next output sse: \n{sse_update}\n") + print(f"next output max: \n{minmax_update.max()}\n") + print(f"next output min: \n{minmax_update.min()}\n") + linear_2p5d = Linear2p5D( input_dim, hidden_dim, gather_output=True, parallel_context=parallel_context @@ -93,8 +113,18 @@ print(f"parallel output (gather_output=True): \n{pout}\n") print(f"parallel update output (gather_output=True): \n{pout_update}\n") + if parallel_context.get_global_rank() == 0: sse = torch.sum((out - pout) ** 2).item() sse_update = torch.sum((out_update - pout_update) ** 2).item() + minmax_update = (out_update - pout_update) ** 2 print(f"output sse (gather_output=True): \n{sse}\n") print(f"next output sse (gather_output=True): \n{sse_update}\n") + import pprint + + # top5 = 
torch.clamp(minmax_update.flatten(), 1e-8) + top5 = minmax_update.flatten() + top5 = [t.item() for t in top5] + top5 = [top5[i : i + 4] for i in range(0, len(top5), 4)] + pprint.pprint(top5) + print(f"next output min: \n{minmax_update.min()}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/2p5d/test_wrapper_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_wrapper_2p5d.py index 45f19be3..9ac65560 100644 --- a/tests/torch/nn/parallel/tensor_parallel/2p5d/test_wrapper_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_wrapper_2p5d.py @@ -3,17 +3,49 @@ import torch import torch.distributed as dist from torch.optim import Adam -from torch.utils.data import DataLoader from datasets import load_dataset -from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from torch.utils.data import DataLoader +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, +) + from oslo.torch.nn.parallel.tensor_parallel import TensorParallel from oslo.torch.nn.parallel.utils import allocate_params from oslo.torch.distributed import ParallelContext, ParallelMode + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end - start + + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + tp_size = 8 batch_size = 16 model_name = "gpt2" +model_name = "gpt2" +mkwargs = {} +dataset_name = "squad" + # parallel context 생성 parallel_context = ParallelContext.from_torch( data_parallel_size=1, @@ -27,7 +59,7 @@ torch.manual_seed(0) # 토크나이저 생성 -tokenizer = AutoTokenizer.from_pretrained(model_name) +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) tokenizer.pad_token = tokenizer.eos_token # 모델 생성 및 병렬화 수행 @@ -46,15 +78,14 @@ optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) # 데이터셋 생성 -datasets = load_dataset("squad").data["train"]["context"] +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] datasets = [str(sample) for sample in datasets[:500]] dataloader = DataLoader(datasets, batch_size=batch_size) # 모니터링 생성 if dist.get_rank() == 0: - wandb.init( - project="oslo", name=f"{model_name}_nprocs{tp_size}_tp2p5d_bs{batch_size}" - ) + wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") cur = time.time() # 모니터링 생성 대기 @@ -73,34 +104,21 @@ max_length=512, ).to("cuda") - fw_start = time.time() - loss_no_tp = model_no_tp(**inputs, labels=inputs["input_ids"]).loss - fw_time = time.time() - fw_start - - fw_start_tp = time.time() - loss_tp = wrapper_tp(**inputs, labels=inputs["input_ids"]).loss - fw_time_tp = time.time() - fw_start_tp - - bw_start = time.time() - loss_no_tp.backward() - optimizer_no_tp.step() - bw_time = time.time() - bw_start + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) - bw_start_tp = time.time() - loss_tp.backward() - optimizer_tp.step() - bw_time_tp = time.time() - bw_start_tp + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) if dist.get_rank() == 0: - print(f"[tp/notp loss]: {loss_tp:.4f}, {loss_no_tp:.4f}") wandb.log( { "tp_loss": loss_tp, "notp_loss": loss_no_tp, - "tp_fw_time": fw_time_tp, - "notp_fw_time": fw_time, - "tp_bw_time": bw_time_tp, - "notp_bw_time": bw_time, + "tp.forward.time:": tp_fw_time, + 
"tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, } ) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/__init__.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/__init__.py new file mode 100644 index 00000000..28fcc144 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/__init__.py @@ -0,0 +1,3 @@ +from .. import _utils + +_ALL__ = [_utils] diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py new file mode 100644 index 00000000..da30d651 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_deparallelize.py @@ -0,0 +1,178 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset + +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, + AutoConfig, +) +from transformers import AutoModelForMaskedLM + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode + +import numpy as np +import time +import os +import random + + +def seed_all(seed: int = 1930): + + print("Using Seed Number {}".format(seed)) + + os.environ["PYTHONHASHSEED"] = str( + seed + ) # set PYTHONHASHSEED env var at fixed value + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) + np.random.seed(seed) # for numpy pseudo-random generator + random.seed(seed) # set fixed value for python built-in pseudo-random generator + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + + +def seed_worker(_worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +seed_all(seed=1994) + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end - start + + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 4 +tp_depth = 1 + +model_name = "bert-base-uncased" +mkwargs = {"pad_token": "[PAD]"} +dataset_name = "squad" + + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_1D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +# tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name) +).cuda() +model_tp = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 
+optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp1d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized("test/", merge_checkpoints=True) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() +optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + optimizer_gathered.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = fw( + model_gathered, **inputs, labels=inputs["input_ids"] + ) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": loss_gathered}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, gathered_bw_time = bw(loss_gathered) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_gathered.step() + + if dist.get_rank() == 0: + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py new file mode 100644 index 00000000..fdb0b32b --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_load_parallel.py @@ -0,0 +1,140 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end - start + + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 4 +tp_depth = 1 + +model_name = "gpt2" +mkwargs = {} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_1D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = 
GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() +model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp1d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized("test/", merge_checkpoints=False) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_reparallel = TensorParallel( + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), parallel_context +) +allocate_params(model_reparallel, parallel_context) +model_reparallel.from_parallelized("test/") +optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + model_reparallel.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = fw( + wrapper_tp, **inputs, labels=inputs["input_ids"] + ) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "reparallel": loss_reparallel}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, reparallel_bw_time = bw(loss_reparallel) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_reparallel.step() + + if dist.get_rank() == 0: + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py new file mode 100644 index 00000000..14033085 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_qkv.py @@ -0,0 +1,72 @@ +import torch.nn as nn + +from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._wrapper import ( + _TensorParallel1D, +) +from oslo.torch.nn import ColumnParallelLinear, RowParallelLinear +from oslo.torch.distributed import ParallelContext, ParallelMode +from copy import deepcopy +from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import gather_1d + +tp_size = 4 +tp_depth = 2 + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_1D, + tensor_parallel_depth=tp_depth, +) + +world_size = 
parallel_context.get_world_size(ParallelMode.TENSOR_1D) +rank = parallel_context.get_local_rank(ParallelMode.TENSOR_1D) +fusion_degree = 3 + +linear = nn.Linear(4, fusion_degree * 4).cuda() +w = deepcopy(linear.weight.data) +b = deepcopy(linear.bias.data) + +dim = 1 + +weight_list = w.t().chunk(fusion_degree * world_size, dim=dim) +bias_list = b.chunk(fusion_degree * world_size, dim=0) + +# [t][f*t] +weight_list = _TensorParallel1D._deconstruct_combined_qkv( + weight_list, world_size, fusion_degree, dim +) +chunked_w = weight_list[rank].contiguous() +bias_list = _TensorParallel1D._deconstruct_combined_qkv( + bias_list, world_size, fusion_degree, 0 +) +chunked_b = bias_list[rank].contiguous() + +linear_1d = RowParallelLinear( + 4, fusion_degree * 4, parallel_context=parallel_context, bias=True +) +if parallel_context.get_global_rank() == 0: + print(chunked_w.size()) + print(linear_1d.weight.data.size()) +linear_1d.weight.data = chunked_w +linear_1d.bias.data = chunked_b + +recon_chunked_w = gather_1d(parallel_context, linear_1d.weight.data, world_size, dim) +recon_chunked_b = gather_1d(parallel_context, linear_1d.bias.data, world_size, 0) +# reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) +# recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() +print(recon_chunked_w.shape) +recon_w = _TensorParallel1D._reconstruct_combined_qkv( + recon_chunked_w, world_size, fusion_degree, dim +) +recon_b = _TensorParallel1D._reconstruct_combined_qkv( + recon_chunked_b, world_size, fusion_degree, 0 +) + +if parallel_context.get_global_rank() == 0: + print(f"original w: \n{w}\n") + print(f"reconstruct w: \n{recon_w}\n") + + print(f"original b: \n{b}\n") + print(f"reconstruct b: \n{recon_b}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py new file mode 100644 index 00000000..1db376eb --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/deparallel/test_vocab.py @@ -0,0 +1,89 @@ +import torch +import torch.distributed as dist + +from oslo.torch.distributed import ParallelContext, ParallelMode +from oslo.torch.nn import VocabParallelEmbedding2p5D +from oslo.torch.nn.parallel import utils + +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( + split_batch_2d, + split_2d, + gather_2d, +) + +from copy import deepcopy + + +tp_size = 8 +tp_depth = 2 + +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +torch.set_printoptions(sci_mode=False) +torch.manual_seed(0) +tesseract_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) +input_ = torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]]).cuda() +target = torch.randn((2, 4, 16)).cuda() +dist.broadcast(input_, src=0) +dist.broadcast(target, src=0) + +vocab_embedding = torch.nn.Embedding(10, 16).cuda() +w = deepcopy(vocab_embedding.weight.data) + +out = vocab_embedding(input_) +optimizer = torch.optim.Adam(vocab_embedding.parameters(), lr=1e-3) +logits = torch.nn.MSELoss()(out, target) +logits.backward() +optimizer.step() + +out_update = vocab_embedding(input_) + +if parallel_context.get_global_rank() == 0: + print(f"original output: \n{out}\n") + print(f"original update output: \n{out_update}\n") + # print(f"vocab start: {vocab_embedding.start_index}, vocab end: {vocab_embedding.end_index}") + 
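+# partition the reference tensors across the tesseract before feeding the parallel
+# embedding: the input is batch-split, the target and weight are 2D-split below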
+input_ = split_batch_2d(parallel_context, input_, tesseract_dim) +# split target into 0:[0, 0], 1:[0, 1], 2:[1, 0], 3:[1, 1] +target = split_2d(parallel_context, target, tesseract_dim, col_first=True) +# split weight into 0:[0, 0], 1:[1, 0], 2:[0, 1], 3:[1, 1] +w = split_2d(parallel_context, w, tesseract_dim, col_first=False) + +vocab_embedding_2p5d = VocabParallelEmbedding2p5D( + 10, 16, parallel_context=parallel_context +) +vocab_embedding_2p5d.weight.data.copy_(w) + +pout = vocab_embedding_2p5d(input_) +optimizer = torch.optim.Adam(vocab_embedding_2p5d.parameters(), lr=1e-3) +logits = torch.nn.MSELoss()(pout, target) +logits.backward() +optimizer.step() + +if parallel_context.get_global_rank() == 0: + unwrapped_model = utils.unwrap_parallel(vocab_embedding_2p5d) + print(f"original vocab size: {unwrapped_model.orig_vocab_size}") + + +# +# +# pout_update = vocab_embedding_2p5d(input_) +# +# pout = gather_2d(parallel_context, pout, tesseract_dim, col_first=False) +# pout_update = gather_2d(parallel_context, pout_update, tesseract_dim, col_first=False) +# +# if parallel_context.get_global_rank() == 0: +# print(f"parallel output: \n{pout}\n") +# print(f"parallel update output: \n{pout_update}\n") +# +# if parallel_context.get_global_rank() == 0: +# sse = torch.sum((out - pout) ** 2).item() +# sse_update = torch.sum((out_update - pout_update) ** 2).item() +# print(f"output sse: \n{sse}\n") +# print(f"next output sse: \n{sse_update}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/__init__.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/__init__.py new file mode 100644 index 00000000..28fcc144 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/__init__.py @@ -0,0 +1,3 @@ +from .. 
import _utils + +_ALL__ = [_utils] diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py new file mode 100644 index 00000000..a7878003 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_deparallelize.py @@ -0,0 +1,178 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset + +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, + AutoConfig, +) + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode + +import numpy as np +import time +import os +import random + + +def seed_all(seed: int = 1930): + + print("Using Seed Number {}".format(seed)) + + os.environ["PYTHONHASHSEED"] = str( + seed + ) # set PYTHONHASHSEED env var at fixed value + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) + np.random.seed(seed) # for numpy pseudo-random generator + random.seed(seed) # set fixed value for python built-in pseudo-random generator + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + + +def seed_worker(_worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +seed_all(seed=1994) + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end - start + + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 4 +tp_depth = 1 + +model_name = "roberta-base" +mkwargs = { + # 'pad_token': '[PAD]' +} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +# tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name) +).cuda() +model_tp = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2d_bs{batch_size}") + cur = time.time() + +# 저장 
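+# merge_checkpoints=True gathers the TP shards into one checkpoint under test/,
+# which is re-loaded below with AutoModelForCausalLM.from_pretrained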
+wrapper_tp.save_parallelized("test/", merge_checkpoints=True) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() +optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + optimizer_gathered.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = fw( + model_gathered, **inputs, labels=inputs["input_ids"] + ) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": loss_gathered}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, gathered_bw_time = bw(loss_gathered) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_gathered.step() + + if dist.get_rank() == 0: + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py new file mode 100644 index 00000000..4a145cf8 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_load_parallel.py @@ -0,0 +1,140 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end - start + + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 4 +tp_depth = 1 + +model_name = "gpt2" +mkwargs = {} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() +model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = 
Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized("test/", merge_checkpoints=False) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_reparallel = TensorParallel( + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), parallel_context +) +allocate_params(model_reparallel, parallel_context) +model_reparallel.from_parallelized("test/") +optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + model_reparallel.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = fw( + wrapper_tp, **inputs, labels=inputs["input_ids"] + ) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "reparallel": loss_reparallel}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, reparallel_bw_time = bw(loss_reparallel) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_reparallel.step() + + if dist.get_rank() == 0: + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py new file mode 100644 index 00000000..3a7f021a --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/deparallel/test_qkv.py @@ -0,0 +1,73 @@ +import torch.nn as nn + +from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._wrapper import ( + _TensorParallel2D, +) +from oslo.torch.nn import Linear2D +from oslo.torch.distributed import ParallelContext, ParallelMode +from copy import deepcopy +from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._ops import ( + gather_1d, + gather_2d, +) + +tp_size = 4 +tp_depth = 1 + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2D, + tensor_parallel_depth=tp_depth, +) + +row_rank = parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW) +col_rank = parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL) +summa_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) +fusion_degree = 3 + +linear = nn.Linear(4, fusion_degree * 4).cuda() +w = deepcopy(linear.weight.data) +b = deepcopy(linear.bias.data) + +weight_list = w.chunk(summa_dim, dim=1) +weight_list = [weight.chunk(summa_dim * fusion_degree, dim=0) for weight in weight_list] +bias_list = b.chunk(summa_dim * fusion_degree, dim=0) + +# [t][f*t] 
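+# _deconstruct_combined_qkv re-packs the fused QKV chunks so that
+# weight_list[row_rank][col_rank] and bias_list[row_rank] hold the slice owned by this rank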
+weight_list = _TensorParallel2D._deconstruct_combined_qkv( + weight_list, summa_dim, fusion_degree +) +bias_list = _TensorParallel2D._deconstruct_combined_qkv( + bias_list, summa_dim, fusion_degree +) +chunked_w = weight_list[row_rank][col_rank] +chunked_b = bias_list[row_rank] + +linear_2d = Linear2D(4, fusion_degree * 4, parallel_context=parallel_context, bias=True) +if parallel_context.get_global_rank() == 0: + print(chunked_w.size()) + print(linear_2d.weight.data.size()) +linear_2d.weight.data = chunked_w +linear_2d.bias.data = chunked_b + +recon_chunked_w = gather_2d(parallel_context, linear_2d.weight.data, summa_dim, True) +recon_chunked_b = gather_1d(parallel_context, linear_2d.bias.data, summa_dim, 0) +# reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) +# recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() + +recon_w = _TensorParallel2D._reconstruct_combined_qkv( + recon_chunked_w, summa_dim, fusion_degree, False +) +recon_b = _TensorParallel2D._reconstruct_combined_qkv( + recon_chunked_b, summa_dim, fusion_degree, True +) + +if parallel_context.get_global_rank() == 0: + print(f"original w: \n{w}\n") + print(f"reconstruct w: \n{recon_w}\n") + + print(f"original b: \n{b}\n") + print(f"reconstruct b: \n{recon_b}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/__init__.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/__init__.py new file mode 100644 index 00000000..28fcc144 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/__init__.py @@ -0,0 +1,3 @@ +from .. import _utils + +_ALL__ = [_utils] diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py new file mode 100644 index 00000000..bf6fde34 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_deparallelize.py @@ -0,0 +1,176 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset + +import torch +from torch.optim import Adam +from torch.utils.data import DataLoader +from transformers import ( + AutoTokenizer, + GPT2Config, + GPT2LMHeadModel, + AutoModelForCausalLM, + AutoConfig, +) + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + +import numpy as np +import os +import random + + +def seed_all(seed: int = 1930): + + print("Using Seed Number {}".format(seed)) + + os.environ["PYTHONHASHSEED"] = str( + seed + ) # set PYTHONHASHSEED env var at fixed value + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) # pytorch (both CPU and CUDA) + np.random.seed(seed) # for numpy pseudo-random generator + random.seed(seed) # set fixed value for python built-in pseudo-random generator + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.enabled = False + + +def seed_worker(_worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +seed_all(seed=1994) + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end - start + + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, 
**kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 8 +tp_depth = 2 + +model_name = "bert-base-uncased" +mkwargs = {"pad_token": "[PAD]"} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +# tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = AutoModelForCausalLM.from_config( + AutoConfig.from_pretrained(model_name) +).cuda() +model_tp = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name)) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized("test/", merge_checkpoints=True) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda() +optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + optimizer_gathered.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_gathered, gathered_fw_time = fw( + model_gathered, **inputs, labels=inputs["input_ids"] + ) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHRED:{loss_gathered}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "GATHRED": loss_gathered}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, gathered_bw_time = bw(loss_gathered) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_gathered.step() + + if dist.get_rank() == 0: + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + "gathered.forward.time:": gathered_fw_time, + "gathered.backward.time:": gathered_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py new file mode 100644 index 00000000..432a0d16 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_load_parallel.py @@ -0,0 +1,140 @@ +import torch.distributed as dist +import wandb +from datasets import load_dataset +from torch.optim import Adam +from torch.utils.data import DataLoader 
+from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel + +from oslo.torch.nn.parallel.tensor_parallel import TensorParallel +from oslo.torch.nn.parallel.utils import allocate_params +from oslo.torch.distributed import ParallelContext, ParallelMode +import time + + +def latency_trace(func): + def wrapper(*args, **kwargs): + start = time.time() + result = func(*args, **kwargs) + end = time.time() + return result, end - start + + return wrapper + + +@latency_trace +def fw(func, *args, **kwargs): + return func(*args, **kwargs).loss + + +@latency_trace +def bw(tensors): + return tensors.backward() + + +tp_size = 8 +tp_depth = 2 + +model_name = "gpt2" +mkwargs = {} +dataset_name = "squad" + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +# 토크나이저 생성 +tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs) +tokenizer.pad_token = tokenizer.eos_token + +# 모델 생성 및 병렬화 수행 +model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")).cuda() +model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")) +wrapper_tp = TensorParallel(model_tp, parallel_context) +allocate_params(wrapper_tp, parallel_context) +# allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 +# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py + +if dist.get_rank() == 0: + print(wrapper_tp) + +# 옵티마이저 생성 +optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5) +optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5) + +# 데이터셋 생성 +batch_size = 16 +datasets = load_dataset(dataset_name).data["train"]["context"] +datasets = [str(sample) for sample in datasets[:500]] +dataloader = DataLoader(datasets, batch_size=batch_size) + +# 모니터링 생성 +if dist.get_rank() == 0: + wandb.init(project="oslo", name=f"{model_name}_tp2p5d_bs{batch_size}") + cur = time.time() + +# 저장 +wrapper_tp.save_parallelized("test/", merge_checkpoints=False) + +# 모니터링 생성 대기 +dist.barrier() + +# 로드 +model_reparallel = TensorParallel( + GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2")), parallel_context +) +allocate_params(model_reparallel, parallel_context) +model_reparallel.from_parallelized("test/") +optimizer_reparallel = Adam(model_reparallel.parameters(), lr=3e-5) + +dist.barrier() + +# 학습 시작 +for data in dataloader: + optimizer_tp.zero_grad() + optimizer_no_tp.zero_grad() + model_reparallel.zero_grad() + + inputs = tokenizer( + data, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512, + ).to("cuda") + + loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"]) + loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"]) + loss_reparallel, reparallel_fw_time = fw( + wrapper_tp, **inputs, labels=inputs["input_ids"] + ) + + if dist.get_rank() == 0: + print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, reparallel:{loss_reparallel}") + wandb.log({"tp": loss_tp, "notp": loss_no_tp, "reparallel": loss_reparallel}) + + _, notp_bw_time = bw(loss_no_tp) + _, tp_bw_time = bw(loss_tp) + _, reparallel_bw_time = bw(loss_reparallel) + + optimizer_tp.step() + optimizer_no_tp.step() + optimizer_reparallel.step() + + if dist.get_rank() == 0: + wandb.log( + { + "tp.forward.time:": tp_fw_time, + "tp.backward.time:": tp_bw_time, + "notp.forward.time:": notp_fw_time, + "notp.backward.time:": notp_bw_time, + 
"reparallel.forward.time:": reparallel_fw_time, + "reparallel.backward.time:": reparallel_bw_time, + } + ) + +dist.barrier() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py new file mode 100644 index 00000000..903fb7c1 --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_qkv.py @@ -0,0 +1,79 @@ +import torch.nn as nn + +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._wrapper import ( + _TensorParallel2p5D, +) +from oslo.torch.nn import Linear2p5D +from oslo.torch.distributed import ParallelContext, ParallelMode +from copy import deepcopy +from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import ( + gather_1d, + gather_2d, +) + +tp_size = 8 +tp_depth = 2 + +# parallel context 생성 +parallel_context = ParallelContext.from_torch( + data_parallel_size=1, + pipeline_parallel_size=1, + tensor_parallel_size=tp_size, + tensor_parallel_mode=ParallelMode.TENSOR_2P5D, + tensor_parallel_depth=tp_depth, +) + +row_rank = parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) +col_rank = parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) +tesseract_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) +fusion_degree = 3 + +linear = nn.Linear(4, fusion_degree * 4).cuda() +w = deepcopy(linear.weight.data) +b = deepcopy(linear.bias.data) + +weight_list = w.chunk(tesseract_dim, dim=1) +weight_list = [ + weight.chunk(tesseract_dim * fusion_degree, dim=0) for weight in weight_list +] +bias_list = b.chunk(tesseract_dim * fusion_degree, dim=0) + +# [t][f*t] +weight_list = _TensorParallel2p5D._deconstruct_combined_qkv( + weight_list, tesseract_dim, fusion_degree, False +) +bias_list = _TensorParallel2p5D._deconstruct_combined_qkv( + bias_list, tesseract_dim, fusion_degree, True +) +chunked_w = weight_list[row_rank][col_rank] +chunked_b = bias_list[row_rank] + +linear_2p5d = Linear2p5D( + 4, fusion_degree * 4, parallel_context=parallel_context, bias=True +) +if parallel_context.get_global_rank() == 0: + print(chunked_w.size()) + print(linear_2p5d.weight.data.size()) +linear_2p5d.weight.data = chunked_w +linear_2p5d.bias.data = chunked_b + +recon_chunked_w = gather_2d( + parallel_context, linear_2p5d.weight.data, tesseract_dim, True +) +recon_chunked_b = gather_1d(parallel_context, linear_2p5d.bias.data, tesseract_dim, 0) +# reshaped_w = recon_chunked_w.view(-1, tesseract_dim, 4) +# recon_w = torch.cat([reshaped_w[:3], reshaped_w[3:]], 1).view(-1, 4).contiguous() + +recon_w = _TensorParallel2p5D._reconstruct_combined_qkv( + recon_chunked_w, tesseract_dim, fusion_degree, False +) +recon_b = _TensorParallel2p5D._reconstruct_combined_qkv( + recon_chunked_b, tesseract_dim, fusion_degree, True +) + +if parallel_context.get_global_rank() == 0: + print(f"original w: \n{w}\n") + print(f"reconstruct w: \n{recon_w}\n") + + print(f"original b: \n{b}\n") + print(f"reconstruct b: \n{recon_b}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py new file mode 100644 index 00000000..1db376eb --- /dev/null +++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/deparallel/test_vocab.py @@ -0,0 +1,89 @@ +import torch +import torch.distributed as dist + +from oslo.torch.distributed import ParallelContext, ParallelMode +from oslo.torch.nn import VocabParallelEmbedding2p5D +from 
+
+from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import (
+    split_batch_2d,
+    split_2d,
+    gather_2d,
+)
+
+from copy import deepcopy
+
+
+tp_size = 8
+tp_depth = 2
+
+parallel_context = ParallelContext.from_torch(
+    data_parallel_size=1,
+    pipeline_parallel_size=1,
+    tensor_parallel_size=tp_size,
+    tensor_parallel_mode=ParallelMode.TENSOR_2P5D,
+    tensor_parallel_depth=tp_depth,
+)
+
+torch.set_printoptions(sci_mode=False)
+torch.manual_seed(0)
+tesseract_dim = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL)
+input_ = torch.LongTensor([[1, 2, 3, 4], [5, 6, 7, 8]]).cuda()
+target = torch.randn((2, 4, 16)).cuda()
+dist.broadcast(input_, src=0)
+dist.broadcast(target, src=0)
+
+vocab_embedding = torch.nn.Embedding(10, 16).cuda()
+w = deepcopy(vocab_embedding.weight.data)
+
+out = vocab_embedding(input_)
+optimizer = torch.optim.Adam(vocab_embedding.parameters(), lr=1e-3)
+loss = torch.nn.MSELoss()(out, target)
+loss.backward()
+optimizer.step()
+
+out_update = vocab_embedding(input_)
+
+if parallel_context.get_global_rank() == 0:
+    print(f"original output: \n{out}\n")
+    print(f"original update output: \n{out_update}\n")
+    # print(f"vocab start: {vocab_embedding.start_index}, vocab end: {vocab_embedding.end_index}")
+
+input_ = split_batch_2d(parallel_context, input_, tesseract_dim)
+# split target into 0:[0, 0], 1:[0, 1], 2:[1, 0], 3:[1, 1]
+target = split_2d(parallel_context, target, tesseract_dim, col_first=True)
+# split weight into 0:[0, 0], 1:[1, 0], 2:[0, 1], 3:[1, 1]
+w = split_2d(parallel_context, w, tesseract_dim, col_first=False)
+
+vocab_embedding_2p5d = VocabParallelEmbedding2p5D(
+    10, 16, parallel_context=parallel_context
+)
+vocab_embedding_2p5d.weight.data.copy_(w)
+
+pout = vocab_embedding_2p5d(input_)
+optimizer = torch.optim.Adam(vocab_embedding_2p5d.parameters(), lr=1e-3)
+loss = torch.nn.MSELoss()(pout, target)
+loss.backward()
+optimizer.step()
+
+if parallel_context.get_global_rank() == 0:
+    unwrapped_model = utils.unwrap_parallel(vocab_embedding_2p5d)
+    print(f"original vocab size: {unwrapped_model.orig_vocab_size}")
+
+
+#
+#
+# pout_update = vocab_embedding_2p5d(input_)
+#
+# pout = gather_2d(parallel_context, pout, tesseract_dim, col_first=False)
+# pout_update = gather_2d(parallel_context, pout_update, tesseract_dim, col_first=False)
+#
+# if parallel_context.get_global_rank() == 0:
+#     print(f"parallel output: \n{pout}\n")
+#     print(f"parallel update output: \n{pout_update}\n")
+#
+# if parallel_context.get_global_rank() == 0:
+#     sse = torch.sum((out - pout) ** 2).item()
+#     sse_update = torch.sum((out_update - pout_update) ** 2).item()
+#     print(f"output sse: \n{sse}\n")
+#     print(f"next output sse: \n{sse_update}\n")
diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/__init__.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/__init__.py
new file mode 100644
index 00000000..28fcc144
--- /dev/null
+++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/__init__.py
@@ -0,0 +1,3 @@
+from .. import _utils
+
+__ALL__ = [_utils]
diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/test_deparallelize.py b/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/test_deparallelize.py
new file mode 100644
index 00000000..f58401d6
--- /dev/null
+++ b/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/deparallel/test_deparallelize.py
@@ -0,0 +1,176 @@
+import torch.distributed as dist
+import wandb
+from datasets import load_dataset
+
+import torch
+from torch.optim import Adam
+from torch.utils.data import DataLoader
+from transformers import (
+    AutoTokenizer,
+    GPT2Config,
+    GPT2LMHeadModel,
+    AutoModelForCausalLM,
+    AutoConfig,
+)
+
+from oslo.torch.nn.parallel.tensor_parallel import TensorParallel
+from oslo.torch.nn.parallel.utils import allocate_params
+from oslo.torch.distributed import ParallelContext, ParallelMode
+import time
+
+import numpy as np
+import os
+import random
+
+
+def seed_all(seed: int = 1930):
+
+    print("Using Seed Number {}".format(seed))
+
+    os.environ["PYTHONHASHSEED"] = str(
+        seed
+    )  # set PYTHONHASHSEED env var at fixed value
+    torch.manual_seed(seed)
+    torch.cuda.manual_seed_all(seed)
+    torch.cuda.manual_seed(seed)  # pytorch (both CPU and CUDA)
+    np.random.seed(seed)  # for numpy pseudo-random generator
+    random.seed(seed)  # set fixed value for python built-in pseudo-random generator
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    torch.backends.cudnn.enabled = False
+
+
+def seed_worker(_worker_id):
+    worker_seed = torch.initial_seed() % 2**32
+    np.random.seed(worker_seed)
+    random.seed(worker_seed)
+
+
+seed_all(seed=1994)
+
+
+def latency_trace(func):
+    def wrapper(*args, **kwargs):
+        start = time.time()
+        result = func(*args, **kwargs)
+        end = time.time()
+        return result, end - start
+
+    return wrapper
+
+
+@latency_trace
+def fw(func, *args, **kwargs):
+    return func(*args, **kwargs).loss
+
+
+@latency_trace
+def bw(tensors):
+    return tensors.backward()
+
+
+tp_size = 8
+tp_depth = 1
+
+model_name = "jason960903/soongsil-bert-base"
+mkwargs = {}
+dataset_name = "korquadv1"
+
+# create parallel context
+parallel_context = ParallelContext.from_torch(
+    data_parallel_size=1,
+    pipeline_parallel_size=1,
+    tensor_parallel_size=tp_size,
+    tensor_parallel_mode=ParallelMode.TENSOR_3D,
+    tensor_parallel_depth=tp_depth,
+)
+
+# create tokenizer
+tokenizer = AutoTokenizer.from_pretrained(model_name, **mkwargs)
+tokenizer.pad_token = tokenizer.eos_token
+
+# create models and parallelize
+model_no_tp = AutoModelForCausalLM.from_config(
+    AutoConfig.from_pretrained(model_name)
+).cuda()
+model_tp = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained(model_name))
+wrapper_tp = TensorParallel(model_tp, parallel_context)
+allocate_params(wrapper_tp, parallel_context)
+# allocate_params will later be handled by a class that manages all parallel wrappers
+# https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py
+
+if dist.get_rank() == 0:
+    print(wrapper_tp)
+
+# create optimizers
+optimizer_tp = Adam(wrapper_tp.parameters(), lr=3e-5)
+optimizer_no_tp = Adam(model_no_tp.parameters(), lr=3e-5)
+
+# create dataset
+batch_size = 16
+datasets = load_dataset(dataset_name).data["train"]["context"]
+datasets = [str(sample) for sample in datasets[:500]]
+dataloader = DataLoader(datasets, batch_size=batch_size)
+
+# set up monitoring
+if dist.get_rank() == 0:
+    wandb.init(project="oslo", name=f"{model_name}_tp3d_bs{batch_size}")
+    cur = time.time()
+
+# save
+wrapper_tp.save_parallelized("test/", merge_checkpoints=True)
+
+# wait for monitoring setup
+dist.barrier()
+
+# load
+model_gathered = AutoModelForCausalLM.from_pretrained("test/").cuda()
+optimizer_gathered = Adam(model_gathered.parameters(), lr=3e-5)
+
+dist.barrier()
+
+# start training
+for data in dataloader:
+    optimizer_tp.zero_grad()
+    optimizer_no_tp.zero_grad()
+    optimizer_gathered.zero_grad()
+
+    inputs = tokenizer(
+        data,
+        return_tensors="pt",
+        padding=True,
+        truncation=True,
+        max_length=512,
+    ).to("cuda")
+
+    loss_no_tp, notp_fw_time = fw(model_no_tp, **inputs, labels=inputs["input_ids"])
+    loss_tp, tp_fw_time = fw(wrapper_tp, **inputs, labels=inputs["input_ids"])
+    loss_gathered, gathered_fw_time = fw(
+        model_gathered, **inputs, labels=inputs["input_ids"]
+    )
+
+    if dist.get_rank() == 0:
+        print(f"TP:{loss_tp}, NOTP:{loss_no_tp}, GATHERED:{loss_gathered}")
+        wandb.log({"tp": loss_tp, "notp": loss_no_tp, "gathered": loss_gathered})
+
+    _, notp_bw_time = bw(loss_no_tp)
+    _, tp_bw_time = bw(loss_tp)
+    _, gathered_bw_time = bw(loss_gathered)
+
+    optimizer_tp.step()
+    optimizer_no_tp.step()
+    optimizer_gathered.step()
+
+    if dist.get_rank() == 0:
+        wandb.log(
+            {
+                "tp.forward.time:": tp_fw_time,
+                "tp.backward.time:": tp_bw_time,
+                "notp.forward.time:": notp_fw_time,
+                "notp.backward.time:": notp_bw_time,
+                "gathered.forward.time:": gathered_fw_time,
+                "gathered.backward.time:": gathered_bw_time,
+            }
+        )
+
+dist.barrier()
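
For reference, a minimal sketch of the save/reload round trip exercised by the tests above, assuming an 8-GPU launch (e.g. `torchrun --nproc_per_node=8 sketch.py`) and hypothetical output directories "ckpt_sharded/" and "ckpt_merged/"; only the calls already used in these tests are assumed to exist.

# Minimal sketch, not a definitive implementation: save/reload with and without checkpoint merging.
import torch.distributed as dist
from transformers import AutoConfig, AutoModelForCausalLM

from oslo.torch.distributed import ParallelContext, ParallelMode
from oslo.torch.nn.parallel.tensor_parallel import TensorParallel
from oslo.torch.nn.parallel.utils import allocate_params

parallel_context = ParallelContext.from_torch(
    data_parallel_size=1,
    pipeline_parallel_size=1,
    tensor_parallel_size=8,
    tensor_parallel_mode=ParallelMode.TENSOR_2P5D,
    tensor_parallel_depth=2,
)

model = AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("gpt2"))
wrapper = TensorParallel(model, parallel_context)
allocate_params(wrapper, parallel_context)

# (a) keep per-rank shards and reload them into a fresh parallel wrapper
wrapper.save_parallelized("ckpt_sharded/", merge_checkpoints=False)
dist.barrier()
reloaded = TensorParallel(
    AutoModelForCausalLM.from_config(AutoConfig.from_pretrained("gpt2")),
    parallel_context,
)
allocate_params(reloaded, parallel_context)
reloaded.from_parallelized("ckpt_sharded/")

# (b) merge shards into a single standard checkpoint and reload it with plain transformers
wrapper.save_parallelized("ckpt_merged/", merge_checkpoints=True)
dist.barrier()
merged = AutoModelForCausalLM.from_pretrained("ckpt_merged/")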