diff --git a/oslo/torch/distributed/nn/functional.py b/oslo/torch/distributed/nn/functional.py index 568c7c5b..7d768677 100644 --- a/oslo/torch/distributed/nn/functional.py +++ b/oslo/torch/distributed/nn/functional.py @@ -93,6 +93,9 @@ def reduce_scatter( out = tensor work = None else: + assert ( + tensor.size(dim) % world_size == 0 + ), "tensor_size must be divisible by world size for tensor parallelism" temp = list( map(lambda x: x.contiguous(), torch.chunk(tensor, world_size, dim=dim)) ) @@ -203,15 +206,11 @@ def scatter( if world_size == 1: return tensor - tensor_size = tensor.size(dim) assert ( - tensor_size % world_size == 0 + tensor.size(dim) % world_size == 0 ), "tensor_size must be divisible by world size for tensor parallelism" - split_size_or_sections = tensor_size // world_size - tensor_list = torch.split( - tensor, split_size_or_sections=split_size_or_sections, dim=dim - ) + tensor_list = torch.chunk(tensor, world_size, dim=dim) return tensor_list[rank].contiguous() diff --git a/oslo/torch/distributed/parallel_context.py b/oslo/torch/distributed/parallel_context.py index 43c4cb92..e0429cea 100644 --- a/oslo/torch/distributed/parallel_context.py +++ b/oslo/torch/distributed/parallel_context.py @@ -1,5 +1,6 @@ import os import random +import warnings from typing import List, Optional import numpy as np @@ -212,8 +213,8 @@ def from_torch( expert_parallel_size=expert_parallel_size, pipeline_parallel_size=pipeline_parallel_size, tensor_parallel_size=tensor_parallel_size, - tensor_parallel_mode=tensor_parallel_mode, tensor_parallel_depth=tensor_parallel_depth, + tensor_parallel_mode=tensor_parallel_mode, backend=backend, seed=seed, ) @@ -282,8 +283,8 @@ def from_slurm( expert_parallel_size=expert_parallel_size, pipeline_parallel_size=pipeline_parallel_size, tensor_parallel_size=tensor_parallel_size, - tensor_parallel_mode=tensor_parallel_mode, tensor_parallel_depth=tensor_parallel_depth, + tensor_parallel_mode=tensor_parallel_mode, backend=backend, seed=seed, ) @@ -351,8 +352,8 @@ def from_openmpi( expert_parallel_size=expert_parallel_size, pipeline_parallel_size=pipeline_parallel_size, tensor_parallel_size=tensor_parallel_size, - tensor_parallel_mode=tensor_parallel_mode, tensor_parallel_depth=tensor_parallel_depth, + tensor_parallel_mode=tensor_parallel_mode, backend=backend, seed=seed, ) @@ -370,8 +371,8 @@ def __init__( expert_parallel_size: int, pipeline_parallel_size: int, tensor_parallel_size: int, - tensor_parallel_mode: Optional[str], tensor_parallel_depth: Optional[int], + tensor_parallel_mode: Optional[str], backend: str, seed: int, ): diff --git a/oslo/torch/nn/modules/dropout.py b/oslo/torch/nn/modules/dropout.py index 2e8433d2..16440dab 100644 --- a/oslo/torch/nn/modules/dropout.py +++ b/oslo/torch/nn/modules/dropout.py @@ -1,5 +1,8 @@ +from typing import Optional import torch +import torch.nn.functional as F from torch.nn.modules.dropout import _DropoutNd +from oslo.torch.distributed import ParallelContext from oslo.torch.nn.modules.functional import ( fused_bias_dropout, diff --git a/oslo/torch/nn/modules/embedding.py b/oslo/torch/nn/modules/embedding.py index 08c68b1b..de7cfb07 100644 --- a/oslo/torch/nn/modules/embedding.py +++ b/oslo/torch/nn/modules/embedding.py @@ -127,6 +127,7 @@ def __init__( parallel_context: Optional[ParallelContext] = None, ): self.parallel_context = parallel_context + self.memory_priority = False self.world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) assert ( embedding_dim % self.world_size == 0 @@ -141,8 +142,20 
@@ def __init__( def forward(self, input: Tensor) -> Tensor: from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import ( - all_gather_tensor_1d, + gather_tensor_1d, + scatter_tensor_1d, ) + from oslo.torch.distributed.nn.functional import ( + all_gather, + ) + + if self.memory_priority: + input = all_gather( + input, + dim=1, + parallel_context=self.parallel_context, + parallel_mode=ParallelMode.TENSOR_1D, + ) output = F.embedding( input, @@ -154,11 +167,15 @@ def forward(self, input: Tensor) -> Tensor: self.sparse, ) - output = all_gather_tensor_1d( + output = gather_tensor_1d( output, -1, self.parallel_context, ) + if self.memory_priority: + output = scatter_tensor_1d( + output, dim=1, parallel_context=self.parallel_context + ) return output @@ -171,6 +188,7 @@ def __init__( parallel_context: Optional[ParallelContext] = None, ): self.parallel_context = parallel_context + self.memory_priority = False rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_1D) self.world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) assert ( @@ -192,8 +210,20 @@ def __init__( def forward(self, input: Tensor) -> Tensor: from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import ( - all_reduce_tensor_1d, + reduce_tensor_1d, + scatter_tensor_1d, ) + from oslo.torch.distributed.nn.functional import ( + all_gather, + ) + + if self.memory_priority: + input = all_gather( + input, + dim=1, + parallel_context=self.parallel_context, + parallel_mode=ParallelMode.TENSOR_1D, + ) if self.world_size > 1: input_mask = (input < self.vocab_start_index) | ( @@ -218,7 +248,11 @@ def forward(self, input: Tensor) -> Tensor: output_parallel[input_mask, :] = 0.0 # Reduce across all the model parallel GPUs. - output = all_reduce_tensor_1d(output_parallel, self.parallel_context) + output = reduce_tensor_1d(output_parallel, self.parallel_context) + if self.memory_priority: + output = scatter_tensor_1d( + output, dim=1, parallel_context=self.parallel_context + ) return output diff --git a/oslo/torch/nn/modules/layer_norm.py b/oslo/torch/nn/modules/layer_norm.py index 4a1d4296..0e67e36a 100644 --- a/oslo/torch/nn/modules/layer_norm.py +++ b/oslo/torch/nn/modules/layer_norm.py @@ -81,6 +81,8 @@ def __init__( parallel_context: Optional[ParallelContext] = None, ): self.parallel_context = parallel_context + self.memory_priority = False + super().__init__( normalized_shape=normalized_shape, partitioned_dim=normalized_shape, @@ -89,6 +91,29 @@ def __init__( dtype=dtype, ) + def forward(self, input: Tensor) -> Tensor: + from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import ( + broadcast_tensor_1d, + ) + + weight = ( + broadcast_tensor_1d(self.weight, parallel_context=self.parallel_context) + if self.memory_priority + else self.weight + ) + bias = ( + broadcast_tensor_1d(self.bias, parallel_context=self.parallel_context) + if self.memory_priority and self.bias is not None + else self.bias + ) + normalized_shape = ( + (self.normalized_shape,) + if isinstance(self.normalized_shape, int) + else self.normalized_shape + ) + output = F.layer_norm(input, normalized_shape, weight, bias, self.eps) + return output + class LayerNorm2D(LayerNorm): def __init__( diff --git a/oslo/torch/nn/modules/linear.py b/oslo/torch/nn/modules/linear.py index ba358565..7817164f 100644 --- a/oslo/torch/nn/modules/linear.py +++ b/oslo/torch/nn/modules/linear.py @@ -127,7 +127,9 @@ def __init__( ): self.gather_output = gather_output self.parallel_context = parallel_context + self.memory_priority = False 
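The `memory_priority` flag threaded through these modules follows a sequence-parallel pattern: activations arrive sharded along the sequence dimension, are all-gathered just before the dense op, and the result is scattered back so that only a sequence shard stays resident on each rank. A minimal single-process sketch of the shape bookkeeping for a column-parallel linear (the sizes, the explicit loop over ranks, and the chunk/cat calls are illustrative stand-ins for the collective ops, not OSLO APIs):

```python
import torch

# Illustrative only: simulates on one process what memory_priority does across
# TENSOR_1D ranks for a column-parallel linear.
world_size = 4                                  # assumed tensor-parallel degree
batch, seq, hidden, out_dim = 2, 8, 16, 32

x_full = torch.randn(batch, seq, hidden)
x_shards = torch.chunk(x_full, world_size, dim=1)                         # sequence shard per rank
w_shards = torch.chunk(torch.randn(out_dim, hidden), world_size, dim=0)   # column shard per rank

outputs = []
for rank in range(world_size):
    # all_gather over the sequence dim rebuilds the full activation ...
    gathered = torch.cat(x_shards, dim=1)
    # ... each rank then applies only its column shard of the weight ...
    y = torch.nn.functional.linear(gathered, w_shards[rank])
    # ... and scatter_output splits the result back along the sequence dim.
    outputs.append(torch.chunk(y, world_size, dim=1)[rank])

assert outputs[0].shape == (batch, seq // world_size, out_dim // world_size)
```

In the real implementation the backward pass re-gathers the input asynchronously and overlaps it with the gradient matmuls, which is what `_MemoryPriorityLinear` below does with `async_op=True`.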
self.reversed = False + self.scatter_output = False self.world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) assert ( @@ -150,12 +152,17 @@ def extra_repr(self) -> str: def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import ( - all_gather_tensor_1d, + gather_tensor_1d, broadcast_tensor_1d, + scatter_tensor_1d, + memory_priority_linear, ) - input = broadcast_tensor_1d(input, self.parallel_context) - outputs = F.linear(input, self.weight) + if self.memory_priority: + outputs = memory_priority_linear(input, self.weight, self.parallel_context) + else: + input = broadcast_tensor_1d(input, self.parallel_context) + outputs = F.linear(input, self.weight) if self.bias is not None: if self.skip_bias_add: @@ -164,13 +171,20 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: outputs = outputs + self.bias if self.gather_output: - outputs = all_gather_tensor_1d( + outputs = gather_tensor_1d( outputs, dim=-1, parallel_context=self.parallel_context, ) if hasattr(self, "orig_num_classes"): outputs = outputs[..., : self.orig_num_classes] + + if self.memory_priority and self.scatter_output: + outputs = scatter_tensor_1d( + outputs, + dim=1, + parallel_context=self.parallel_context, + ) return outputs @@ -187,6 +201,7 @@ def __init__( ): self.parallel_input = parallel_input self.parallel_context = parallel_context + self.memory_priority = False self.reversed = False self.world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) @@ -210,25 +225,37 @@ def extra_repr(self) -> str: def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]: from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import ( - all_reduce_tensor_1d, + reduce_tensor_1d, scatter_tensor_1d, + reduce_scatter_tensor_1d, + broadcast_tensor_1d, ) if not self.parallel_input: + assert ( + not self.memory_priority + ), "Input must be parallelized when using memory priority." 
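The row-parallel forward that follows swaps the all-reduce for a reduce-scatter along the sequence dimension when `memory_priority` is enabled. This works because a reduce-scatter over a dimension is exactly an all-reduce followed by a scatter over that dimension, so each rank ends up with the correctly summed sequence shard without ever materializing the full output. A single-process sketch of that equivalence (shapes and the list of per-rank partials are made up for illustration):

```python
import torch

world_size = 4
# Stand-ins for the per-rank partial outputs of a row-parallel matmul.
partials = [torch.randn(2, 8, 16) for _ in range(world_size)]

# Path 1: all-reduce, then scatter the sequence dimension (dim=1).
reduced = torch.stack(partials).sum(0)
path1 = [torch.chunk(reduced, world_size, dim=1)[r] for r in range(world_size)]

# Path 2: reduce-scatter along the sequence dimension directly.
path2 = [
    torch.stack([torch.chunk(p, world_size, dim=1)[r] for p in partials]).sum(0)
    for r in range(world_size)
]

for a, b in zip(path1, path2):
    torch.testing.assert_close(a, b)
```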
input = scatter_tensor_1d( input, dim=-1, parallel_context=self.parallel_context, ) - outputs = F.linear(input, self.weight) - outputs = all_reduce_tensor_1d(outputs, self.parallel_context) - + outputs = F.linear(input, self.weight) + if self.memory_priority: + outputs = reduce_scatter_tensor_1d( + outputs, dim=1, parallel_context=self.parallel_context + ) + else: + outputs = reduce_tensor_1d(outputs, parallel_context=self.parallel_context) if self.bias is not None: if self.skip_bias_add: return outputs, self.bias else: - return outputs + self.bias + if self.memory_priority: + bias = broadcast_tensor_1d(self.bias, self.parallel_context) + else: + bias = self.bias + return outputs + bias return outputs diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py index 9f55fa18..19d9c37c 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py @@ -1,19 +1,27 @@ from typing import Any import torch +import torch.nn.functional as F from torch import Tensor from oslo.torch.distributed import ParallelMode, ParallelContext -from oslo.torch.distributed.nn.functional import all_gather, all_reduce, scatter +from oslo.torch.distributed.nn.functional import ( + all_gather, + all_reduce, + reduce_scatter, + scatter, +) class _BroadcastTensor1D(torch.autograd.Function): + @staticmethod def forward(ctx: Any, inputs: Tensor, parallel_context: ParallelContext): if ctx: ctx.parallel_context = parallel_context return inputs - def backward(ctx, grad): + @staticmethod + def backward(ctx: Any, grad: Tensor): parallel_context = ctx.parallel_context return ( all_reduce( @@ -27,7 +35,8 @@ ) -class _AllReduceTensor1D(torch.autograd.Function): +class _ReduceTensor1D(torch.autograd.Function): + @staticmethod def forward(ctx: Any, inputs: Tensor, parallel_context: ParallelContext): return all_reduce( inputs, @@ -37,11 +46,13 @@ def forward(ctx: Any, inputs: Tensor, parallel_context: ParallelContext): parallel_mode=ParallelMode.TENSOR_1D, ) - def backward(ctx, grad): + @staticmethod + def backward(ctx: Any, grad: Tensor): return grad, None -class _AllGatherTensor1D(torch.autograd.Function): +class _GatherTensor1D(torch.autograd.Function): + @staticmethod def forward(ctx: Any, inputs: Tensor, dim: int, parallel_context: ParallelContext): if ctx: ctx.dim = dim @@ -55,6 +66,7 @@ def forward(ctx: Any, inputs: Tensor, dim: int, parallel_context: ParallelContex parallel_mode=ParallelMode.TENSOR_1D, ) + @staticmethod def backward(ctx: Any, grad: Tensor): return ( scatter( @@ -69,21 +81,47 @@ ) class _ScatterTensor1D(torch.autograd.Function): + @staticmethod def forward(ctx: Any, inputs: Tensor, dim: int, parallel_context: ParallelContext): if ctx: ctx.dim = dim ctx.parallel_context = parallel_context + return scatter( + inputs, + dim=dim, + parallel_context=parallel_context, + parallel_mode=ParallelMode.TENSOR_1D, + ) + + @staticmethod + def backward(ctx: Any, grad: Tensor): return ( - scatter( - inputs, - dim=dim, - parallel_context=parallel_context, + all_gather( + grad, + dim=ctx.dim, + parallel_context=ctx.parallel_context, parallel_mode=ParallelMode.TENSOR_1D, ), None, + None, ) - def backward(ctx, grad): + +class _ReduceScatterTensor1D(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, inputs: Tensor, dim: int, parallel_context: ParallelContext): + if ctx: + ctx.dim = dim + ctx.parallel_context = parallel_context + return
reduce_scatter( + inputs, + dim, + parallel_context=parallel_context, + parallel_mode=ParallelMode.TENSOR_1D, + ) + + @staticmethod + def backward(ctx: Any, grad: Tensor): return ( all_gather( grad, @@ -96,17 +134,83 @@ def backward(ctx, grad): ) +class _MemoryPriorityLinear(torch.autograd.Function): + @staticmethod + def forward( + ctx: Any, inputs: Tensor, weight: Tensor, parallel_context: ParallelContext + ): + if ctx: + ctx.save_for_backward(inputs, weight) + ctx.parallel_context = parallel_context + + total_inputs = all_gather( + inputs, + dim=1, + parallel_context=parallel_context, + parallel_mode=ParallelMode.TENSOR_1D, + ) + outputs = F.linear(total_inputs, weight) + return outputs + + @staticmethod + def backward(ctx: Any, grad_outputs: Tensor): + inputs, weight = ctx.saved_tensors + + total_inputs, handle = all_gather( + inputs, + dim=1, + async_op=True, + parallel_context=ctx.parallel_context, + parallel_mode=ParallelMode.TENSOR_1D, + ) + + grad_inputs = grad_outputs.matmul(weight) + handle.wait() + + grad_outputs = grad_outputs.reshape( + grad_outputs.shape[0] * grad_outputs.shape[1], grad_outputs.shape[2] + ) + total_inputs = total_inputs.reshape( + total_inputs.shape[0] * total_inputs.shape[1], total_inputs.shape[2] + ) + + sub_grad_inputs, handle = reduce_scatter( + grad_inputs, + dim=1, + async_op=True, + parallel_context=ctx.parallel_context, + parallel_mode=ParallelMode.TENSOR_1D, + ) + + grad_weight = grad_outputs.t().matmul(total_inputs) + + handle.wait() + return sub_grad_inputs, grad_weight, None + + def broadcast_tensor_1d(inputs: Tensor, parallel_context: ParallelContext): return _BroadcastTensor1D.apply(inputs, parallel_context) -def all_reduce_tensor_1d(inputs: Tensor, parallel_context: ParallelContext): - return _AllReduceTensor1D.apply(inputs, parallel_context) +def reduce_tensor_1d(inputs: Tensor, parallel_context: ParallelContext): + return _ReduceTensor1D.apply(inputs, parallel_context) -def all_gather_tensor_1d(inputs: Tensor, dim: int, parallel_context: ParallelContext): - return _AllGatherTensor1D.apply(inputs, dim, parallel_context) +def gather_tensor_1d(inputs: Tensor, dim: int, parallel_context: ParallelContext): + return _GatherTensor1D.apply(inputs, dim, parallel_context) def scatter_tensor_1d(inputs: Tensor, dim: int, parallel_context: ParallelContext): return _ScatterTensor1D.apply(inputs, dim, parallel_context) + + +def reduce_scatter_tensor_1d( + inputs: Tensor, dim: int, parallel_context: ParallelContext +): + return _ReduceScatterTensor1D.apply(inputs, dim, parallel_context) + + +def memory_priority_linear( + inputs: Tensor, weight: Tensor, parallel_context: ParallelContext +): + return _MemoryPriorityLinear.apply(inputs, weight, parallel_context) diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py index 29e1901c..c8533e3d 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_wrapper.py @@ -16,6 +16,9 @@ from oslo.torch.nn.modules.layer_norm import ( LayerNorm1D, ) +from oslo.torch.distributed.nn.functional import ( + scatter, +) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, ) @@ -28,6 +31,7 @@ from oslo.transformers.mapping_utils import ( _TensorParallelMappingForHuggingFace, ) +from oslo.transformers.constants import SEQ_DIMENSIONS class _TensorParallel1D(ParallelWrapper): @@ -45,10 +49,12 @@ def __init__( module: nn.Module, 
parallel_context: ParallelContext, mapping: dict = None, + memory_priority: bool = False, ): super().__init__() self.module = module self.parallel_context = parallel_context + self.memory_priority = memory_priority self.device = torch.cuda.current_device() if mapping is None: @@ -63,6 +69,30 @@ def __init__( self._parallelize() def forward(self, *args, **kwargs): + assert len(args) == 0, ( + "1D tensor parallel model only supports ``**kwargs`` input (keyword arguments). " + "If you wrote code like ``model(input_ids, labels)``, " + "please modify your code like ``model(input_ids=input_ids, labels=labels)``." + ) + if self.memory_priority and not is_oslo_model(self.module): + assert ( + "past_key_values" not in kwargs + ), "``past_key_values`` argument is forbidden with memory priority." + if "position_ids" not in kwargs: + kwargs["position_ids"] = torch.arange( + kwargs["input_ids"].shape[-1], device=kwargs["input_ids"].device + ).unsqueeze(0) + kwargs = { + key: scatter( + value, + dim=SEQ_DIMENSIONS[key], + parallel_context=self.parallel_context, + parallel_mode=ParallelMode.TENSOR_1D, + ) + if key in SEQ_DIMENSIONS + else value + for key, value in kwargs.items() + } return self.module(*args, **kwargs) @torch.no_grad() @@ -102,39 +132,44 @@ def _parallelize_layernorm(self): ) def _parallelize_linear(self): - for param_name, module in self.module.named_modules(): - if self.tensor_parallel_mapping.is_column_parallel(self.module, param_name): + for module_name, module in self.module.named_modules(): + if self.tensor_parallel_mapping.is_column_parallel( + self.module, module_name + ): self._column_slice_linear( module=module, reversed=self.tensor_parallel_mapping.is_reversed( - self.module, param_name + self.module, module_name ), fusion_degree=self.tensor_parallel_mapping.get_combined_qkv_degree( - self.module, param_name, module + self.module, module_name, module ), gather_output=self.tensor_parallel_mapping.is_gather_output( - self.module, param_name + self.module, module_name + ), + scatter_output=self.tensor_parallel_mapping.is_gather_output( + self.module, module_name ), ) - elif self.tensor_parallel_mapping.is_row_parallel(self.module, param_name): + elif self.tensor_parallel_mapping.is_row_parallel(self.module, module_name): self._row_slice_linear( module=module, reversed=self.tensor_parallel_mapping.is_reversed( - self.module, param_name + self.module, module_name ), fusion_degree=1, ) def _parallelize_head(self): - for param_name, module in self.module.named_modules(): + for module_name, module in self.module.named_modules(): if self.tensor_parallel_mapping.is_head( - self.module, param_name + self.module, module_name ) and isinstance(module, nn.Linear): self._slice_head( module=module, reversed=self.tensor_parallel_mapping.is_reversed( - self.module, param_name + self.module, module_name ), ) @@ -147,8 +182,8 @@ def _deconstruct_combined_qkv(tensor, world_size, fusion_degree, dim): return tensor def _slice_embedding(self, module): - rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_1D) world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_1D) if module is self.module.get_input_embeddings(): ( @@ -167,6 +202,7 @@ def _slice_embedding(self, module): vocab_start_index=vocab_start_index, vocab_end_index=vocab_end_index, parallel_context=self.parallel_context, + memory_priority=self.memory_priority, world_size=world_size, num_embeddings=module.weight.size()[0], 
orig_module=copy.deepcopy(module.__class__), @@ -179,6 +215,7 @@ def _slice_embedding(self, module): _update_module_arguments( module=module, parallel_context=self.parallel_context, + memory_priority=self.memory_priority, world_size=world_size, embedding_dim=module.weight.size()[1], orig_module=copy.deepcopy(module.__class__), @@ -198,8 +235,8 @@ def _slice_linear( slice_bias: bool, dim: int, ): - rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_1D) world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_1D) if reversed: module.weight.data = module.weight.data.t() @@ -246,6 +283,7 @@ def _column_slice_linear( reversed: bool, fusion_degree: int, gather_output: bool, + scatter_output: bool, ): world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) self._slice_linear( @@ -261,18 +299,25 @@ def _column_slice_linear( in_features=module.weight.size()[1], out_features=module.weight.size()[0], parallel_context=self.parallel_context, + memory_priority=self.memory_priority, world_size=world_size, reversed=reversed, fusion_degree=fusion_degree, orig_module=copy.deepcopy(module.__class__), gather_output=gather_output, + scatter_output=scatter_output, skip_bias_add=module.skip_bias_add if hasattr(module, "skip_bias_add") else False, ) module.__class__ = ColLinear1D - def _row_slice_linear(self, module: nn.Module, reversed: bool, fusion_degree: int): + def _row_slice_linear( + self, + module: nn.Module, + reversed: bool, + fusion_degree: int, + ): world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) self._slice_linear( module=module, @@ -286,6 +331,7 @@ def _row_slice_linear(self, module: nn.Module, reversed: bool, fusion_degree: in in_features=module.weight.size()[1], out_features=module.weight.size()[0], parallel_context=self.parallel_context, + memory_priority=self.memory_priority, world_size=world_size, reversed=reversed, fusion_degree=fusion_degree, @@ -304,29 +350,45 @@ def _slice_layernorm(self, module): normalized_shape=module.weight.size()[0], partitioned_dim=module.weight.size()[0], parallel_context=self.parallel_context, + memory_priority=self.memory_priority, world_size=world_size, orig_module=copy.deepcopy(module.__class__), ) module.__class__ = LayerNorm1D def _slice_head(self, module, reversed): - world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) if module.weight is not self.module.get_input_embeddings().weight: self._column_slice_linear( module=module, reversed=reversed, fusion_degree=1, gather_output=not is_oslo_model(self.module), + scatter_output=False, ) else: + world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D) + rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_1D) + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + bias_list = module.bias.data.chunk(world_size, dim=0) + module.bias.data = bias_list[rank].contiguous() + + if hasattr(module.bias, "oslo_parallel"): + module.bias.oslo_parallel[ParallelMode.TENSOR_1D] = rank + else: + module.bias.oslo_parallel = {ParallelMode.TENSOR_1D: rank} + _update_module_arguments( module=module, parallel_context=self.parallel_context, + memory_priority=self.memory_priority, world_size=world_size, reversed=reversed, fusion_degree=1, orig_module=copy.deepcopy(module.__class__), gather_output=not is_oslo_model(self.module), + scatter_output=False, skip_bias_add=module.skip_bias_add if hasattr(module, 
"skip_bias_add") else False, diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py index c9d31b96..a618c4f4 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_ops.py @@ -94,26 +94,6 @@ def gather_batch_2d( ) -def split_batch_2d( - inputs: Tensor, - dim: int = 0, - parallel_context: Optional[ParallelContext] = None, -) -> Tensor: - dim_size = inputs.size(dim) - world_size = parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) - - if world_size <= 1: - return inputs - - assert ( - dim_size % world_size == 0 - ), f"The batch size ({dim_size}) is not a multiple of 2D size ({world_size})." - - return inputs.chunk(world_size, dim=dim)[ - parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL) - ].contiguous() - - def reduce_tensor_2d( inputs: Tensor, parallel_context: ParallelContext, @@ -802,18 +782,17 @@ def forward( parallel_context: ParallelContext, parallel_mode: ParallelMode, ) -> Tensor: + if ctx: + ctx.dim = dim + ctx.parallel_context = parallel_context + ctx.parallel_mode = parallel_mode + outputs = all_gather( inputs, dim, parallel_context=parallel_context, parallel_mode=parallel_mode, ) - - if ctx: - ctx.dim = dim - ctx.parallel_context = parallel_context - ctx.parallel_mode = parallel_mode - return outputs @staticmethod diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py index f9e6f90d..eef4339d 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2d/_wrapper.py @@ -15,8 +15,8 @@ from oslo.torch.nn.modules.layer_norm import ( LayerNorm2D, ) -from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._ops import ( - split_batch_2d, +from oslo.torch.distributed.nn.functional import ( + scatter, ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, @@ -73,10 +73,11 @@ def forward(self, *args, **kwargs): ) if not is_oslo_model(self.module): kwargs = { - key: split_batch_2d( + key: scatter( value, dim=BATCH_DIMENSIONS[key], parallel_context=self.parallel_context, + parallel_mode=ParallelMode.TENSOR_2D_COL, ) if key in BATCH_DIMENSIONS else value @@ -166,9 +167,9 @@ def _deconstruct_combined_qkv(tensor, summa_dim, fusion_degree): return tensor def _slice_embedding(self, module): + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW) col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL) - summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) if module is self.module.get_input_embeddings(): ( @@ -220,10 +221,9 @@ def _slice_embedding(self, module): } def _slice_linear(self, module, reversed, fusion_degree, slice_bias): + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW) col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL) - summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) - data_parallel_rank = self.parallel_context.get_local_rank(ParallelMode.DATA) pipeline_parallel_rank = self.parallel_context.get_local_rank( ParallelMode.PIPELINE @@ -308,10 +308,9 @@ def _slice_linear(self, module, reversed, fusion_degree, slice_bias): 
module.__class__ = Linear2D def _slice_layernorm(self, module): + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW) col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL) - summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) - data_parallel_rank = self.parallel_context.get_local_rank(ParallelMode.DATA) pipeline_parallel_rank = self.parallel_context.get_local_rank( ParallelMode.PIPELINE @@ -380,10 +379,9 @@ def _slice_head(self, module, reversed): gather_output=not is_oslo_model(self.module), ) else: + summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_ROW) col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2D_COL) - summa_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_2D_COL) - data_parallel_rank = self.parallel_context.get_local_rank(ParallelMode.DATA) pipeline_parallel_rank = self.parallel_context.get_local_rank( ParallelMode.PIPELINE @@ -394,6 +392,21 @@ def _slice_head(self, module, reversed): pipeline_parallel_size = self.parallel_context.get_world_size( ParallelMode.PIPELINE ) + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + bias_list = module.bias.data.chunk(summa_dim, dim=0) + bias_list = [bias.chunk(summa_dim, dim=0) for bias in bias_list] + module.bias.data = bias_list[row_rank][col_rank].contiguous() + + if hasattr(module.bias, "oslo_parallel"): + module.bias.oslo_parallel[ParallelMode.TENSOR_2D_ROW] = row_rank + module.bias.oslo_parallel[ParallelMode.TENSOR_2D_COL] = col_rank + else: + module.bias.oslo_parallel = { + ParallelMode.TENSOR_2D_ROW: row_rank, + ParallelMode.TENSOR_2D_COL: col_rank, + } + _update_module_arguments( module=module, in_features=module.weight.size()[1], diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py index 272d7d62..253a31ec 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py @@ -83,26 +83,6 @@ def gather_batch_2p5d( ) -def split_batch_2p5d( - inputs: Tensor, - dim: int = 0, - parallel_context: Optional[ParallelContext] = None, -) -> Tensor: - dim_size = inputs.size(dim) - world_size = parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL) - - if world_size <= 1: - return inputs - - assert ( - dim_size % world_size == 0 - ), f"The batch size ({dim_size}) is not a multiple of 2.5D size * depth ({world_size})." 
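Both this 2D wrapper and the 2.5D/3D wrappers further down drop their mode-specific `split_batch_*` helpers (such as the one being removed in these hunks) in favour of the shared `scatter` primitive; the behaviour they all share is to chunk along a dimension and keep the local rank's slice. A toy sketch of that semantics (`toy_scatter`, `rank`, and `world_size` are hypothetical stand-ins for what a `ParallelContext` would supply, not OSLO functions):

```python
import torch

def toy_scatter(inputs: torch.Tensor, dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Keep only the local rank's slice along `dim`, as the generic scatter does.
    if world_size == 1:
        return inputs
    assert inputs.size(dim) % world_size == 0, "size must divide evenly across ranks"
    return torch.chunk(inputs, world_size, dim=dim)[rank].contiguous()

batch = torch.arange(8).reshape(8, 1)
print(toy_scatter(batch, dim=0, rank=2, world_size=4))  # rows 4 and 5
```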
- - return torch.chunk( - inputs, parallel_context.get_world_size(ParallelMode.TENSOR_2P5D_COL), dim=dim - )[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL)].contiguous() - - def reduce_by_batch_2p5d( inputs, reduce_mean: bool, parallel_context: ParallelContext ) -> Tensor: diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py index e425b519..6315c76f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_wrapper.py @@ -12,8 +12,9 @@ ) from oslo.torch.nn.modules.linear import Linear2p5D from oslo.torch.nn.modules.layer_norm import LayerNorm2p5D -from oslo.torch.nn.parallel.tensor_parallel._parallel_2p5d._ops import split_batch_2p5d - +from oslo.torch.distributed.nn.functional import ( + scatter, +) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, ) @@ -70,10 +71,11 @@ def forward(self, *args, **kwargs): ) if not is_oslo_model(self.module): kwargs = { - key: split_batch_2p5d( + key: scatter( value, dim=BATCH_DIMENSIONS[key], parallel_context=self.parallel_context, + parallel_mode=ParallelMode.TENSOR_2P5D_COL, ) if key in BATCH_DIMENSIONS else value @@ -228,13 +230,12 @@ def _slice_embedding(self, module): } def _slice_linear(self, module, reversed, fusion_degree, slice_bias): - row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) - col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) - dep_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP) tesseract_dim = self.parallel_context.get_world_size( ParallelMode.TENSOR_2P5D_COL ) - + row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + dep_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP) data_parallel_rank = self.parallel_context.get_local_rank(ParallelMode.DATA) pipeline_parallel_rank = self.parallel_context.get_local_rank( ParallelMode.PIPELINE @@ -323,13 +324,12 @@ def _slice_linear(self, module, reversed, fusion_degree, slice_bias): return module def _slice_layernorm(self, module): - row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) - col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) - dep_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP) tesseract_dim = self.parallel_context.get_world_size( ParallelMode.TENSOR_2P5D_COL ) - + row_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_ROW) + col_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL) + dep_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_DEP) data_parallel_rank = self.parallel_context.get_local_rank(ParallelMode.DATA) pipeline_parallel_rank = self.parallel_context.get_local_rank( ParallelMode.PIPELINE @@ -402,6 +402,9 @@ def _slice_head(self, module, reversed): gather_output=not is_oslo_model(self.module), ) else: + tesseract_dim = self.parallel_context.get_world_size( + ParallelMode.TENSOR_2P5D_COL + ) row_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_2P5D_ROW ) @@ -411,10 +414,6 @@ def _slice_head(self, module, reversed): dep_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_2P5D_DEP ) - tesseract_dim = self.parallel_context.get_world_size( - ParallelMode.TENSOR_2P5D_COL - ) - 
data_parallel_rank = self.parallel_context.get_local_rank(ParallelMode.DATA) pipeline_parallel_rank = self.parallel_context.get_local_rank( ParallelMode.PIPELINE @@ -425,6 +424,30 @@ pipeline_parallel_size = self.parallel_context.get_world_size( ParallelMode.PIPELINE ) + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + bias_list = module.bias.data.chunk(tesseract_dim, dim=0) + + module.bias.data = bias_list[row_rank].contiguous() + + if hasattr(module.bias, "oslo_parallel"): + module.bias.oslo_parallel[ + ParallelMode.TENSOR_2P5D_ROW + ] = row_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_2P5D_COL + ] = col_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_2P5D_DEP + ] = dep_rank + else: + module.bias.oslo_parallel = { + ParallelMode.TENSOR_2P5D_ROW: row_rank, + ParallelMode.TENSOR_2P5D_COL: col_rank, + ParallelMode.TENSOR_2P5D_DEP: dep_rank, + } + _update_module_arguments( module=module, in_features=module.weight.size()[1], diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py index 209f1e4c..364484b9 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_ops.py @@ -582,43 +582,6 @@ def split_tensor_3d( return output -def split_batch_3d( - inputs: Tensor, - dim: int = 0, - parallel_context: Optional[ParallelContext] = None, -) -> Tensor: - r"""Splits 3D tensor in batch. - - Args: - input_ (:class:`torch.tensor`): Input tensor. - dim (int): Specified dimension in which to split. - - Returns: - :class:`torch.tensor`: The tensor has been split. - - Note: - The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found - in `parallel_mode `_. - """ - dim_size = inputs.size(dim) - weight_world_size = parallel_context.get_world_size(ParallelMode.TENSOR_3D_WEIGHT) - input_world_size = parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) - - assert ( - dim_size % (input_world_size * weight_world_size) == 0 - ), f"The batch size ({dim_size}) is not a multiple of square of 3D cubic dim ({input_world_size*weight_world_size})."
- - if inputs.size(dim) <= 1: - return inputs - output = torch.chunk(inputs, weight_world_size, dim=dim)[ - parallel_context.get_local_rank(ParallelMode.TENSOR_3D_WEIGHT) - ].contiguous() - output = torch.chunk(output, input_world_size, dim=dim)[ - parallel_context.get_local_rank(ParallelMode.TENSOR_3D_INPUT) - ].contiguous() - return output - - class _ReduceTensor3D(torch.autograd.Function): @staticmethod def forward( diff --git a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py index bdbe431b..046f550f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py +++ b/oslo/torch/nn/parallel/tensor_parallel/_parallel_3d/_wrapper.py @@ -15,8 +15,8 @@ from oslo.torch.nn.modules.layer_norm import ( LayerNorm3D, ) -from oslo.torch.nn.parallel.tensor_parallel._parallel_3d._ops import ( - split_batch_3d, +from oslo.torch.distributed.nn.functional import ( + scatter, ) from oslo.torch.nn.parallel.tensor_parallel.mapping import ( TensorParallelMapping, @@ -73,10 +73,16 @@ def forward(self, *args, **kwargs): ) if not is_oslo_model(self.module): kwargs = { - key: split_batch_3d( - value, + key: scatter( + scatter( + value, + dim=BATCH_DIMENSIONS[key], + parallel_context=self.parallel_context, + parallel_mode=ParallelMode.TENSOR_3D_WEIGHT, + ), dim=BATCH_DIMENSIONS[key], parallel_context=self.parallel_context, + parallel_mode=ParallelMode.TENSOR_3D_INPUT, ) if key in BATCH_DIMENSIONS else value @@ -174,6 +180,7 @@ def _deconstruct_combined_qkv(tensor, cubic_dim, fusion_degree, is_bias=False): return tensor def _slice_embedding(self, module): + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) input_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_3D_INPUT) output_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_3D_OUTPUT @@ -181,7 +188,6 @@ def _slice_embedding(self, module): weight_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_3D_WEIGHT ) - cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) if module is self.module.get_input_embeddings(): ( @@ -240,6 +246,7 @@ def _slice_embedding(self, module): } def _slice_linear(self, module, reversed, fusion_degree, slice_bias): + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) input_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_3D_INPUT) output_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_3D_OUTPUT @@ -247,7 +254,6 @@ def _slice_linear(self, module, reversed, fusion_degree, slice_bias): weight_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_3D_WEIGHT ) - cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) if reversed: module.weight.data = module.weight.data.t() @@ -331,6 +337,7 @@ def _slice_linear(self, module, reversed, fusion_degree, slice_bias): module.__class__ = Linear3D def _slice_layernorm(self, module): + cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) input_rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_3D_INPUT) output_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_3D_OUTPUT @@ -338,7 +345,6 @@ def _slice_layernorm(self, module): weight_rank = self.parallel_context.get_local_rank( ParallelMode.TENSOR_3D_WEIGHT ) - cubic_dim = self.parallel_context.get_world_size(ParallelMode.TENSOR_3D_INPUT) if hasattr(module, "weight") and module.weight is not None: if module.weight.dim() 
>= 1: @@ -408,6 +414,37 @@ def _slice_head(self, module, reversed): cubic_dim = self.parallel_context.get_world_size( ParallelMode.TENSOR_3D_INPUT ) + input_rank = self.parallel_context.get_local_rank( + ParallelMode.TENSOR_3D_INPUT + ) + output_rank = self.parallel_context.get_local_rank( + ParallelMode.TENSOR_3D_OUTPUT + ) + weight_rank = self.parallel_context.get_local_rank( + ParallelMode.TENSOR_3D_WEIGHT + ) + + if hasattr(module, "bias") and module.bias is not None: + if module.bias.dim() >= 1: + bias_list = module.bias.data.chunk(cubic_dim, dim=0) + module.bias.data = bias_list[input_rank].contiguous() + + if hasattr(module.bias, "oslo_parallel"): + module.bias.oslo_parallel[ + ParallelMode.TENSOR_3D_INPUT + ] = input_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_3D_OUTPUT + ] = output_rank + module.bias.oslo_parallel[ + ParallelMode.TENSOR_3D_WEIGHT + ] = weight_rank + else: + module.bias.oslo_parallel = { + ParallelMode.TENSOR_3D_INPUT: input_rank, + ParallelMode.TENSOR_3D_OUTPUT: output_rank, + ParallelMode.TENSOR_3D_WEIGHT: weight_rank, + } _update_module_arguments( module=module, diff --git a/oslo/torch/nn/parallel/tensor_parallel/mapping.py b/oslo/torch/nn/parallel/tensor_parallel/mapping.py index 822abe2e..37c865a4 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/mapping.py +++ b/oslo/torch/nn/parallel/tensor_parallel/mapping.py @@ -6,7 +6,7 @@ class TensorParallelInfo(object): A class to describe tensor parallelization information. Args: - name (Tuple[str]): the name of parameter + name (Tuple[str]): the name of module combined_qkv (bool): combined qkv or not parallel (bool): parallelizable param or not reverse (bool): reversed param or not @@ -130,7 +130,7 @@ def update_attrs(self, model): if mapping is not None: return mapping["Update"] - def search(self, model, param_name): + def search(self, model, module_name): """ Get element by parameter name @@ -142,12 +142,12 @@ def search(self, model, param_name): """ mapping = self.get_mapping(model) count_contain_elem_in_param = 0 - param_split = param_name.split(".") + param_split = module_name.split(".") first_check = [] for elems in mapping.values(): for elem in elems: - if elem.name in param_name: + if elem.name in module_name: first_check.append(elem) for elem in first_check: @@ -160,110 +160,110 @@ def search(self, model, param_name): return None - def is_combined_qkv_param(self, model, param_name): + def is_combined_qkv_param(self, model, module_name): """ - Check whether the param is combined qkv or not + Check whether the module is combined qkv or not Args: model (PreTrainedModel): model obj - param_name (str): name of parameter + module_name (str): name of module Returns: - bool: whether the param is combined qkv or not + bool: whether the module is combined qkv or not """ - elem = self.search(model, param_name) + elem = self.search(model, module_name) if elem is not None: return elem.combined_qkv - def get_combined_qkv_degree(self, model, param_name, module): + def get_combined_qkv_degree(self, model, module_name, module): """ Get combined qkv degree Args: model (PreTrainedModel): model obj - param_name (str): name of parameter + module_name (str): name of module module (nn.Module): module that has `weight` parameter Returns: int: combined qkv degree """ - if self.is_combined_qkv_param(model, param_name) and hasattr(module, "weight"): + if self.is_combined_qkv_param(model, module_name) and hasattr(module, "weight"): bigger = max(module.weight.size(0), module.weight.size(1)) smaller = 
min(module.weight.size(0), module.weight.size(1)) return bigger // smaller return 1 - def is_reversed(self, model, param_name): + def is_reversed(self, model, module_name): """ - Check whether the parameter is reversed or not + Check whether the module is reversed or not Args: model (PreTrainedModel): model obj - param_name (str): name of parameter + module_name (str): name of module Returns: - bool: whether the param is reversed or not + bool: whether the module is reversed or not """ - elem = self.search(model, param_name) + elem = self.search(model, module_name) if elem is not None: return elem.reversed - def is_gather_output(self, model, param_name): + def is_gather_output(self, model, module_name): """ - Check whether the param is gather output or not + Check whether the module is gather output or not Args: model (PreTrainedModel): model obj - param_name (str): name of parameter + module_name (str): name of module Returns: - bool: whether the param is combined qkv or not + bool: whether the module is gather output or not """ - elem = self.search(model, param_name) + elem = self.search(model, module_name) if elem is not None: return elem.gather_output - def is_column_parallel(self, model, param_name): + def is_column_parallel(self, model, module_name): """ - Check whether the parameter is column parallelizable or not + Check whether the module is column parallelizable or not Args: model (PreTrainedModel): model obj - param_name (str): name of parameter + module_name (str): name of module Returns: - bool: whether the param is column parallelizable or not + bool: whether the module is column parallelizable or not """ - elem = self.search(model, param_name) + elem = self.search(model, module_name) if elem is not None: return isinstance(elem, Column) - def is_row_parallel(self, model, param_name): + def is_row_parallel(self, model, module_name): """ - Check whether the parameter is row parallelizable or not + Check whether the module is row parallelizable or not Args: model (PreTrainedModel): model obj - param_name (str): name of parameter + module_name (str): name of module Returns: - bool: whether the param is row parallelizable or not + bool: whether the module is row parallelizable or not """ - elem = self.search(model, param_name) + elem = self.search(model, module_name) if elem is not None: return isinstance(elem, Row) - def is_head(self, model, param_name): + def is_head(self, model, module_name): """ - Check whether the parameter is lm head or not + Check whether the module is head or not Args: model (PreTrainedModel): model obj - param_name (str): name of parameter + module_name (str): name of module Returns: - bool: whether the param is lm head or not + bool: whether the module is head or not """ - elem = self.search(model, param_name) + elem = self.search(model, module_name) if elem is not None: return isinstance(elem, Head) diff --git a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py index 4e66a0b9..fba41a1f 100644 --- a/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py +++ b/oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py @@ -1,4 +1,5 @@ from typing import Optional +import warnings import torch import torch.nn as nn @@ -77,13 +78,23 @@ def __init__( module: nn.Module, parallel_context: Optional[ParallelContext] = None, mapping: dict = None, + memory_priority: bool = False, ): super().__init__() self.parallel_context = get_parallel_context(module, parallel_context) module =
self._resize_vocab_size(module, self.parallel_context) module = self._resize_num_classes(module, self.parallel_context, mapping) + + if self.parallel_context.tensor_parallel_mode != ParallelMode.TENSOR_1D: + if memory_priority and self.parallel_context.tensor_parallel_size > 1: + warnings.warn( + "memory_priority is available only with 1D tensor parallel." + ) + if self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_1D: - self.module = _TensorParallel1D(module, self.parallel_context, mapping) + self.module = _TensorParallel1D( + module, self.parallel_context, mapping, memory_priority + ) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D: self.module = _TensorParallel2D(module, self.parallel_context, mapping) elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D: diff --git a/oslo/transformers/constants.py b/oslo/transformers/constants.py index d294d361..8a0ab313 100644 --- a/oslo/transformers/constants.py +++ b/oslo/transformers/constants.py @@ -5,3 +5,10 @@ "position_ids": 0, "inputs_embeds": 0, } + +SEQ_DIMENSIONS = { + "input_ids": -1, + "token_type_ids": -1, + "position_ids": -1, + "inputs_embeds": 1, +} diff --git a/oslo/transformers/mapping_utils.py b/oslo/transformers/mapping_utils.py index 48093289..84e82c89 100644 --- a/oslo/transformers/mapping_utils.py +++ b/oslo/transformers/mapping_utils.py @@ -83,7 +83,7 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): ], "Bert": [ Column("query", "key", "value", "intermediate.dense"), - Column("pooler.dense", gather_output=True), + Column("pooler.dense", "transform.dense", gather_output=True), Row("output.dense"), Update("num_attention_heads", "all_head_size"), Head("decoder", "seq_relationship", "classifier", "qa_outputs"), @@ -130,8 +130,8 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): Column( "electra.embeddings_project", "classifier.dense", - "discriminator_predictions.dense", "generator_predictions.dense", + "discriminator_predictions.dense", gather_output=True, ), Row("output.dense"), @@ -148,9 +148,9 @@ class _TensorParallelMappingForHuggingFace(_ParallelMappingForHuggingFace): "Roberta": [ Column("query", "key", "value", "intermediate.dense"), Column( - "lm_head.dense", "classifier.dense", "roberta.pooler", + "lm_head.dense", gather_output=True, ), Row("output.dense"), @@ -221,7 +221,7 @@ def __init__(self): HF_TO_OSLO = { transformers.GPT2Model: oslo.transformers.GPT2Model, transformers.GPT2LMHeadModel: oslo.transformers.GPT2LMHeadModel, - transformers.GPT2DoubleHeadsModel: oslo.transformers.GPT2DoubleHeadModel, + transformers.GPT2DoubleHeadsModel: oslo.transformers.GPT2DoubleHeadsModel, transformers.GPT2ForSequenceClassification: oslo.transformers.GPT2ForSequenceClassification, transformers.GPT2ForTokenClassification: oslo.transformers.GPT2ForTokenClassification, } diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/__init__.py b/tests/torch/nn/parallel/tensor_parallel/1d/__init__.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_1d/__init__.py rename to tests/torch/nn/parallel/tensor_parallel/1d/__init__.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/_utils.py b/tests/torch/nn/parallel/tensor_parallel/1d/_utils.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_1d/_utils.py rename to tests/torch/nn/parallel/tensor_parallel/1d/_utils.py diff --git
a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_col_linear_1d.py b/tests/torch/nn/parallel/tensor_parallel/1d/test_col_linear_1d.py similarity index 83% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_col_linear_1d.py rename to tests/torch/nn/parallel/tensor_parallel/1d/test_col_linear_1d.py index 5ee46dd5..01aa7620 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_col_linear_1d.py +++ b/tests/torch/nn/parallel/tensor_parallel/1d/test_col_linear_1d.py @@ -1,3 +1,4 @@ +import argparse from copy import deepcopy import torch import torch.distributed as dist @@ -5,6 +6,11 @@ from oslo.torch.nn import ColLinear1D from _utils import split_1d, gather_1d + +parser = argparse.ArgumentParser() +parser.add_argument("--memory_priority", action="store_true", default=False) +args = parser.parse_args() + tp_size = 4 parallel_context = ParallelContext.from_torch( @@ -18,7 +24,7 @@ torch.manual_seed(0) batch_size = 2 -seq_len = 2 +seq_len = 4 input_dim = 4 hidden_dim = 8 world_size = parallel_context.get_world_size(ParallelMode.TENSOR_1D) @@ -33,28 +39,27 @@ out = linear(input_) optimizer = torch.optim.Adam(linear.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = linear(input_) -if parallel_context.get_global_rank() == 0: - print(f"original output: \n{out}\n") - print(f"original next output: \n{out_update}\n") - +if args.memory_priority: + input_ = split_1d(input_, world_size, dim=1, parallel_context=parallel_context) target = split_1d(target, world_size, dim=-1, parallel_context=parallel_context) w = split_1d(w, world_size, dim=0, parallel_context=parallel_context) b = split_1d(b, world_size, dim=0, parallel_context=parallel_context) col_linear = ColLinear1D(input_dim, hidden_dim, parallel_context=parallel_context) +col_linear.memory_priority = args.memory_priority col_linear.weight.data.copy_(w) col_linear.bias.data.copy_(b) pout = col_linear(input_) optimizer = torch.optim.Adam(col_linear.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = col_linear(input_) @@ -65,7 +70,9 @@ ) if parallel_context.get_global_rank() == 0: + print(f"original output: \n{out}\n") print(f"parallel output: \n{pout}\n") + print(f"original next output: \n{out_update}\n") print(f"parallel next output: \n{pout_update}\n") if parallel_context.get_global_rank() == 0: diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_embedding_1d.py b/tests/torch/nn/parallel/tensor_parallel/1d/test_embedding_1d.py similarity index 62% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_embedding_1d.py rename to tests/torch/nn/parallel/tensor_parallel/1d/test_embedding_1d.py index ebad02e6..39eda6dd 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_embedding_1d.py +++ b/tests/torch/nn/parallel/tensor_parallel/1d/test_embedding_1d.py @@ -1,9 +1,15 @@ +import argparse from copy import deepcopy import torch import torch.distributed as dist from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn import Embedding1D -from _utils import split_1d +from _utils import split_1d, gather_1d + + +parser = argparse.ArgumentParser() +parser.add_argument("--memory_priority", action="store_true", default=False) +args = parser.parse_args() tp_size = 4 @@ -16,9 +22,13 @@ 
torch.set_printoptions(sci_mode=False) torch.manual_seed(0) + +batch_size = 2 +seq_len = 4 +hidden_dim = 8 world_size = parallel_context.get_world_size(ParallelMode.TENSOR_1D) -input_ = torch.LongTensor([[0, 1, 6, 3, 8], [5, 2, 7, 4, 9]]).cuda() -target = torch.randn((2, 5, 8)).cuda() +input_ = torch.LongTensor([[0, 1, 6, 3], [5, 2, 7, 9]]).cuda() +target = torch.randn((batch_size, seq_len, hidden_dim)).cuda() dist.broadcast(input_, src=0) dist.broadcast(target, src=0) @@ -27,8 +37,8 @@ out = embedding(input_) optimizer = torch.optim.Adam(embedding.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = embedding(input_) @@ -38,17 +48,28 @@ print(f"original update output: \n{out_update}\n") w = split_1d(w, world_size, dim=-1, parallel_context=parallel_context) +if args.memory_priority: + input_ = split_1d(input_, world_size, dim=1, parallel_context=parallel_context) + target = split_1d(target, world_size, dim=1, parallel_context=parallel_context) embedding_1d = Embedding1D(16, 8, parallel_context=parallel_context) +embedding_1d.memory_priority = args.memory_priority embedding_1d.weight.data = w pout = embedding_1d(input_) optimizer = torch.optim.Adam(embedding_1d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = embedding_1d(input_) +if args.memory_priority: + pout = gather_1d( + pout.contiguous(), world_size, dim=1, parallel_context=parallel_context + ) + pout_update = gather_1d( + pout_update.contiguous(), world_size, dim=1, parallel_context=parallel_context + ) if parallel_context.get_global_rank() == 0: print(f"parallel output: \n{out}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_layer_norm_1d.py b/tests/torch/nn/parallel/tensor_parallel/1d/test_layer_norm_1d.py similarity index 66% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_layer_norm_1d.py rename to tests/torch/nn/parallel/tensor_parallel/1d/test_layer_norm_1d.py index 539596b4..f8441194 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_layer_norm_1d.py +++ b/tests/torch/nn/parallel/tensor_parallel/1d/test_layer_norm_1d.py @@ -1,8 +1,15 @@ +import argparse from copy import deepcopy import torch import torch.distributed as dist from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn import LayerNorm1D +from _utils import split_1d, gather_1d + + +parser = argparse.ArgumentParser() +parser.add_argument("--memory_priority", action="store_true", default=False) +args = parser.parse_args() tp_size = 4 @@ -17,8 +24,9 @@ torch.manual_seed(0) batch_size = 2 -seq_len = 2 +seq_len = 4 hidden_dim = 8 +world_size = parallel_context.get_world_size(ParallelMode.TENSOR_1D) input_ = torch.randn((batch_size, seq_len, hidden_dim)).cuda() target = torch.randn((batch_size, seq_len, hidden_dim)).cuda() @@ -31,8 +39,8 @@ out = layernorm(input_) optimizer = torch.optim.Adam(layernorm.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = layernorm(input_) @@ -41,19 +49,27 @@ print(f"original output: \n{out}\n") print(f"original update output: \n{out_update}\n") -dist.barrier() +if args.memory_priority: + input_ = split_1d(input_, world_size, dim=1, parallel_context=parallel_context) + target = 
split_1d(target, world_size, dim=1, parallel_context=parallel_context) layernorm_1d = LayerNorm1D(hidden_dim, parallel_context=parallel_context) +layernorm_1d.memory_priority = args.memory_priority layernorm_1d.weight.data = w layernorm_1d.bias.data = b pout = layernorm_1d(input_) optimizer = torch.optim.Adam(layernorm_1d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = layernorm_1d(input_) +if args.memory_priority: + pout = gather_1d(pout, world_size, dim=1, parallel_context=parallel_context) + pout_update = gather_1d( + pout_update, world_size, dim=1, parallel_context=parallel_context + ) if parallel_context.get_global_rank() == 0: print(f"parallel output: \n{pout}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_row_linear_1d.py b/tests/torch/nn/parallel/tensor_parallel/1d/test_row_linear_1d.py similarity index 73% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_row_linear_1d.py rename to tests/torch/nn/parallel/tensor_parallel/1d/test_row_linear_1d.py index f61d4588..e3b85385 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_row_linear_1d.py +++ b/tests/torch/nn/parallel/tensor_parallel/1d/test_row_linear_1d.py @@ -1,9 +1,15 @@ +import argparse from copy import deepcopy import torch import torch.distributed as dist from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn import RowLinear1D -from _utils import split_1d +from _utils import split_1d, gather_1d + + +parser = argparse.ArgumentParser() +parser.add_argument("--memory_priority", action="store_true", default=False) +args = parser.parse_args() tp_size = 4 @@ -18,7 +24,7 @@ torch.manual_seed(0) batch_size = 2 -seq_len = 2 +seq_len = 4 input_dim = 4 hidden_dim = 8 world_size = parallel_context.get_world_size(ParallelMode.TENSOR_1D) @@ -33,33 +39,39 @@ out = linear(input_) optimizer = torch.optim.Adam(linear.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = linear(input_) -if parallel_context.get_global_rank() == 0: - print(f"original output: \n{out}\n") - print(f"original next output: \n{out_update}\n") - input_ = split_1d(input_, world_size, dim=-1, parallel_context=parallel_context) +if args.memory_priority: + target = split_1d(target, world_size, dim=1, parallel_context=parallel_context) w = split_1d(w, world_size, dim=-1, parallel_context=parallel_context) row_linear = RowLinear1D(input_dim, hidden_dim, parallel_context=parallel_context) +row_linear.memory_priority = args.memory_priority row_linear.weight.data.copy_(w) row_linear.bias.data.copy_(b) pout = row_linear(input_) optimizer = torch.optim.Adam(row_linear.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = row_linear(input_) +if args.memory_priority: + pout = gather_1d(pout, world_size, dim=1, parallel_context=parallel_context) + pout_update = gather_1d( + pout_update, world_size, dim=1, parallel_context=parallel_context + ) if parallel_context.get_global_rank() == 0: + print(f"original output: \n{out}\n") print(f"parallel output: \n{pout}\n") + print(f"original next output: \n{out_update}\n") print(f"parallel next output: \n{pout_update}\n") if parallel_context.get_global_rank() == 0: diff --git 
a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_vocab_embedding_1d.py b/tests/torch/nn/parallel/tensor_parallel/1d/test_vocab_embedding_1d.py similarity index 64% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_vocab_embedding_1d.py rename to tests/torch/nn/parallel/tensor_parallel/1d/test_vocab_embedding_1d.py index 79b50dc8..8d9b617f 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_vocab_embedding_1d.py +++ b/tests/torch/nn/parallel/tensor_parallel/1d/test_vocab_embedding_1d.py @@ -1,9 +1,15 @@ +import argparse from copy import deepcopy import torch import torch.distributed as dist from oslo.torch.distributed import ParallelContext, ParallelMode from oslo.torch.nn import VocabParallelEmbedding1D -from _utils import split_1d +from _utils import split_1d, gather_1d + + +parser = argparse.ArgumentParser() +parser.add_argument("--memory_priority", action="store_true", default=False) +args = parser.parse_args() tp_size = 4 @@ -16,9 +22,13 @@ torch.set_printoptions(sci_mode=False) torch.manual_seed(0) + +batch_size = 2 +seq_len = 4 +hidden_dim = 8 world_size = parallel_context.get_world_size(ParallelMode.TENSOR_1D) -input_ = torch.LongTensor([[0, 1, 6, 3, 8], [5, 2, 7, 4, 9]]).cuda() -target = torch.randn((2, 5, 8)).cuda() +input_ = torch.LongTensor([[0, 1, 6, 3], [5, 2, 7, 9]]).cuda() +target = torch.randn((batch_size, seq_len, hidden_dim)).cuda() dist.broadcast(input_, src=0) dist.broadcast(target, src=0) @@ -27,8 +37,8 @@ out = vocab_embedding(input_) optimizer = torch.optim.Adam(vocab_embedding.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = vocab_embedding(input_) @@ -38,17 +48,26 @@ print(f"original next output: \n{out_update}\n") w = split_1d(w, world_size, dim=0, parallel_context=parallel_context) +if args.memory_priority: + input_ = split_1d(input_, world_size, dim=1, parallel_context=parallel_context) + target = split_1d(target, world_size, dim=1, parallel_context=parallel_context) vocab_embedding_1d = VocabParallelEmbedding1D(16, 8, parallel_context=parallel_context) +vocab_embedding_1d.memory_priority = args.memory_priority vocab_embedding_1d.weight.data = w pout = vocab_embedding_1d(input_) optimizer = torch.optim.Adam(vocab_embedding_1d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = vocab_embedding_1d(input_) +if args.memory_priority: + pout = gather_1d(pout, world_size, dim=1, parallel_context=parallel_context) + pout_update = gather_1d( + pout_update, world_size, dim=1, parallel_context=parallel_context + ) if parallel_context.get_global_rank() == 0: print(f"parallel output: \n{pout}\n") diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py b/tests/torch/nn/parallel/tensor_parallel/1d/test_wrapper_1d.py similarity index 90% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py rename to tests/torch/nn/parallel/tensor_parallel/1d/test_wrapper_1d.py index ed961ee9..4e430b4c 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_1d/test_wrapper_1d.py +++ b/tests/torch/nn/parallel/tensor_parallel/1d/test_wrapper_1d.py @@ -1,3 +1,4 @@ +import argparse import time import wandb import torch @@ -10,8 +11,13 @@ from oslo.torch.nn.parallel.utils import allocate_params from oslo.torch.distributed import ParallelContext, 
ParallelMode +parser = argparse.ArgumentParser() +parser.add_argument("--memory_priority", action="store_true", default=False) +args = parser.parse_args() + tp_size = 4 batch_size = 16 +seq_length = 128 model_name = "gpt2" # parallel context 생성 @@ -32,7 +38,9 @@ # 모델 생성 및 병렬화 수행 model_no_tp = GPT2LMHeadModel(GPT2Config.from_pretrained(model_name)).cuda() model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained(model_name)) -wrapper_tp = TensorParallel(model_tp, parallel_context) +wrapper_tp = TensorParallel( + model_tp, parallel_context, memory_priority=args.memory_priority +) allocate_params(wrapper_tp, parallel_context) # allocate_params 함수는 추후에 모든 페러렐 래퍼를 관장하는 클래스에서 처리될 예정 # https://github.com/tunib-ai/oslo/blob/307131bbd5ed995ea8dca8ac541bfbce9bfec29b/oslo/pytorch/model_parallelism/model_parallel_engine.py @@ -65,9 +73,9 @@ inputs = tokenizer( data, return_tensors="pt", - padding=True, + padding="max_length", truncation=True, - max_length=512, + max_length=seq_length, ).to("cuda") fw_start = time.time() diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/__init__.py b/tests/torch/nn/parallel/tensor_parallel/2d/__init__.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2d/__init__.py rename to tests/torch/nn/parallel/tensor_parallel/2d/__init__.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/_utils.py b/tests/torch/nn/parallel/tensor_parallel/2d/_utils.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2d/_utils.py rename to tests/torch/nn/parallel/tensor_parallel/2d/_utils.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_embedding_2d.py b/tests/torch/nn/parallel/tensor_parallel/2d/test_embedding_2d.py similarity index 94% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_embedding_2d.py rename to tests/torch/nn/parallel/tensor_parallel/2d/test_embedding_2d.py index 61ef007d..61afaf57 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_embedding_2d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2d/test_embedding_2d.py @@ -27,8 +27,8 @@ out = embedding(input_) optimizer = torch.optim.Adam(embedding.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = embedding(input_) @@ -46,8 +46,8 @@ pout = embedding_2d(input_) optimizer = torch.optim.Adam(embedding_2d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = embedding_2d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_layer_norm_2d.py b/tests/torch/nn/parallel/tensor_parallel/2d/test_layer_norm_2d.py similarity index 95% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_layer_norm_2d.py rename to tests/torch/nn/parallel/tensor_parallel/2d/test_layer_norm_2d.py index 04947010..6949cfa5 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_layer_norm_2d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2d/test_layer_norm_2d.py @@ -33,8 +33,8 @@ out = layernorm(input_) optimizer = torch.optim.Adam(layernorm.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = layernorm(input_) @@ -56,8 +56,8 @@ pout = layernorm_2d(input_) optimizer = 
torch.optim.Adam(layernorm_2d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = layernorm_2d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_linear_2d.py b/tests/torch/nn/parallel/tensor_parallel/2d/test_linear_2d.py similarity index 94% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_linear_2d.py rename to tests/torch/nn/parallel/tensor_parallel/2d/test_linear_2d.py index 2c39883a..f1c718d9 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_linear_2d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2d/test_linear_2d.py @@ -33,8 +33,8 @@ out = linear(input_) optimizer = torch.optim.Adam(linear.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = linear(input_) @@ -54,8 +54,8 @@ pout = linear_2d(input_) optimizer = torch.optim.Adam(linear_2d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, ptarget) -logits.backward() +loss = torch.nn.MSELoss()(pout, ptarget) +loss.backward() optimizer.step() pout_update = linear_2d(input_) @@ -81,8 +81,8 @@ pout = linear_2d(input_) optimizer = torch.optim.Adam(linear_2d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = linear_2d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_vocab_embedding_2d.py b/tests/torch/nn/parallel/tensor_parallel/2d/test_vocab_embedding_2d.py similarity index 94% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_vocab_embedding_2d.py rename to tests/torch/nn/parallel/tensor_parallel/2d/test_vocab_embedding_2d.py index 082d83d1..181ab2ce 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_vocab_embedding_2d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2d/test_vocab_embedding_2d.py @@ -27,8 +27,8 @@ out = vocab_embedding(input_) optimizer = torch.optim.Adam(vocab_embedding.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = vocab_embedding(input_) @@ -46,8 +46,8 @@ pout = vocab_embedding_2d(input_) optimizer = torch.optim.Adam(vocab_embedding_2d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = vocab_embedding_2d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_wrapper_2d.py b/tests/torch/nn/parallel/tensor_parallel/2d/test_wrapper_2d.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2d/test_wrapper_2d.py rename to tests/torch/nn/parallel/tensor_parallel/2d/test_wrapper_2d.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/__init__.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/__init__.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/__init__.py rename to tests/torch/nn/parallel/tensor_parallel/2p5d/__init__.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_utils.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/_utils.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_utils.py rename to 
tests/torch/nn/parallel/tensor_parallel/2p5d/_utils.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_embedding_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_embedding_2p5d.py similarity index 95% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_embedding_2p5d.py rename to tests/torch/nn/parallel/tensor_parallel/2p5d/test_embedding_2p5d.py index f76af224..6c0e6975 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_embedding_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_embedding_2p5d.py @@ -34,8 +34,8 @@ out = embedding(input_) optimizer = torch.optim.Adam(embedding.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = embedding(input_) @@ -55,8 +55,8 @@ pout = embedding_2p5d(input_) optimizer = torch.optim.Adam(embedding_2p5d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = embedding_2p5d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_layer_norm_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_layer_norm_2p5d.py similarity index 95% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_layer_norm_2p5d.py rename to tests/torch/nn/parallel/tensor_parallel/2p5d/test_layer_norm_2p5d.py index aefb90af..fc4a64c7 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_layer_norm_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_layer_norm_2p5d.py @@ -35,8 +35,8 @@ out = layernorm(input_) optimizer = torch.optim.Adam(layernorm.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = layernorm(input_) @@ -59,8 +59,8 @@ pout = layernorm_2p5d(input_) optimizer = torch.optim.Adam(layernorm_2p5d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = layernorm_2p5d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_linear_2p5d.py similarity index 94% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py rename to tests/torch/nn/parallel/tensor_parallel/2p5d/test_linear_2p5d.py index f0d7b6a5..59f4df74 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_linear_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_linear_2p5d.py @@ -35,8 +35,8 @@ out = linear(input_) optimizer = torch.optim.Adam(linear.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = linear(input_) @@ -56,8 +56,8 @@ pout = linear_2p5d(input_) optimizer = torch.optim.Adam(linear_2p5d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, ptarget) -logits.backward() +loss = torch.nn.MSELoss()(pout, ptarget) +loss.backward() optimizer.step() pout_update = linear_2p5d(input_) @@ -83,8 +83,8 @@ pout = linear_2p5d(input_) optimizer = torch.optim.Adam(linear_2p5d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() 
optimizer.step() pout_update = linear_2p5d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_vocab_embedding_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_vocab_embedding_2p5d.py similarity index 95% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_vocab_embedding_2p5d.py rename to tests/torch/nn/parallel/tensor_parallel/2p5d/test_vocab_embedding_2p5d.py index 433af6bd..7726b40c 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_vocab_embedding_2p5d.py +++ b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_vocab_embedding_2p5d.py @@ -34,8 +34,8 @@ out = vocab_embedding(input_) optimizer = torch.optim.Adam(vocab_embedding.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = vocab_embedding(input_) @@ -55,8 +55,8 @@ pout = vocab_embedding_2p5d(input_) optimizer = torch.optim.Adam(vocab_embedding_2p5d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = vocab_embedding_2p5d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py b/tests/torch/nn/parallel/tensor_parallel/2p5d/test_wrapper_2p5d.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_2p5d/test_wrapper_2p5d.py rename to tests/torch/nn/parallel/tensor_parallel/2p5d/test_wrapper_2p5d.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/__init__.py b/tests/torch/nn/parallel/tensor_parallel/3d/__init__.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_3d/__init__.py rename to tests/torch/nn/parallel/tensor_parallel/3d/__init__.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/_utils.py b/tests/torch/nn/parallel/tensor_parallel/3d/_utils.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_3d/_utils.py rename to tests/torch/nn/parallel/tensor_parallel/3d/_utils.py diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_embedding_3d.py b/tests/torch/nn/parallel/tensor_parallel/3d/test_embedding_3d.py similarity index 95% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_embedding_3d.py rename to tests/torch/nn/parallel/tensor_parallel/3d/test_embedding_3d.py index eae2ded1..d26fd444 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_embedding_3d.py +++ b/tests/torch/nn/parallel/tensor_parallel/3d/test_embedding_3d.py @@ -34,8 +34,8 @@ out = embedding(input_) optimizer = torch.optim.Adam(embedding.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = embedding(input_) @@ -55,8 +55,8 @@ pout = embedding_3d(input_) optimizer = torch.optim.Adam(embedding_3d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = embedding_3d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_layer_norm_3d.py b/tests/torch/nn/parallel/tensor_parallel/3d/test_layer_norm_3d.py similarity index 95% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_layer_norm_3d.py rename to 
tests/torch/nn/parallel/tensor_parallel/3d/test_layer_norm_3d.py index c1e0fd02..7fd68d8c 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_layer_norm_3d.py +++ b/tests/torch/nn/parallel/tensor_parallel/3d/test_layer_norm_3d.py @@ -33,8 +33,8 @@ out = layernorm(input_) optimizer = torch.optim.Adam(layernorm.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = layernorm(input_) @@ -56,8 +56,8 @@ pout = layernorm_3d(input_) optimizer = torch.optim.Adam(layernorm_3d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = layernorm_3d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_linear_3d.py b/tests/torch/nn/parallel/tensor_parallel/3d/test_linear_3d.py similarity index 94% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_linear_3d.py rename to tests/torch/nn/parallel/tensor_parallel/3d/test_linear_3d.py index 810a6ad8..193f55e2 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_linear_3d.py +++ b/tests/torch/nn/parallel/tensor_parallel/3d/test_linear_3d.py @@ -35,8 +35,8 @@ out = linear(input_) optimizer = torch.optim.Adam(linear.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = linear(input_) @@ -57,8 +57,8 @@ pout = linear_3d(input_) optimizer = torch.optim.Adam(linear_3d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, ptarget) -logits.backward() +loss = torch.nn.MSELoss()(pout, ptarget) +loss.backward() optimizer.step() pout_update = linear_3d(input_) @@ -86,8 +86,8 @@ pout = linear_3d(input_) optimizer = torch.optim.Adam(linear_3d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = linear_3d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_vocab_embedding_3d.py b/tests/torch/nn/parallel/tensor_parallel/3d/test_vocab_embedding_3d.py similarity index 95% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_vocab_embedding_3d.py rename to tests/torch/nn/parallel/tensor_parallel/3d/test_vocab_embedding_3d.py index 6b5b6ddf..95d0c3fe 100644 --- a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_vocab_embedding_3d.py +++ b/tests/torch/nn/parallel/tensor_parallel/3d/test_vocab_embedding_3d.py @@ -34,8 +34,8 @@ out = embedding(input_) optimizer = torch.optim.Adam(embedding.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(out, target) -logits.backward() +loss = torch.nn.MSELoss()(out, target) +loss.backward() optimizer.step() out_update = embedding(input_) @@ -55,8 +55,8 @@ pout = vocab_embedding_3d(input_) optimizer = torch.optim.Adam(vocab_embedding_3d.parameters(), lr=1e-3) -logits = torch.nn.MSELoss()(pout, target) -logits.backward() +loss = torch.nn.MSELoss()(pout, target) +loss.backward() optimizer.step() pout_update = vocab_embedding_3d(input_) diff --git a/tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_wrapper_3d.py b/tests/torch/nn/parallel/tensor_parallel/3d/test_wrapper_3d.py similarity index 100% rename from tests/torch/nn/parallel/tensor_parallel/_parallel_3d/test_wrapper_3d.py rename to 
tests/torch/nn/parallel/tensor_parallel/3d/test_wrapper_3d.py
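The `--memory_priority` branches added throughout the 1D tests above rely on the `split_1d` and `gather_1d` helpers imported from the tests' `_utils` module, which this patch does not include. A minimal sketch of what such helpers presumably look like, given how the tests call them; `get_local_rank`/`get_world_size` appear in the tests themselves, while `get_group` is an assumption about the `ParallelContext` API:

```python
# Hypothetical stand-ins for the split_1d / gather_1d helpers the tests import
# from _utils.py (not shown in this patch). Signatures mirror the call sites.
import torch
import torch.distributed as dist

from oslo.torch.distributed import ParallelMode


def split_1d(tensor, world_size, dim, parallel_context):
    # Keep only this rank's chunk along `dim`; mirrors how the tests shard
    # weights (dim=0 / dim=-1) and, under --memory_priority, activations (dim=1).
    rank = parallel_context.get_local_rank(ParallelMode.TENSOR_1D)
    return torch.chunk(tensor, world_size, dim=dim)[rank].contiguous()


def gather_1d(tensor, world_size, dim, parallel_context):
    # Reassemble the full tensor by all-gathering every rank's chunk along `dim`.
    # `get_group` is assumed here; the tests only show get_local_rank/get_world_size.
    group = parallel_context.get_group(ParallelMode.TENSOR_1D)
    chunks = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(chunks, tensor.contiguous(), group=group)
    return torch.cat(chunks, dim=dim)
```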
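Every 1D test gains the same scaffolding: when `--memory_priority` is set, the replicated `input_` and `target` are split along the sequence dimension (dim=1) before the parallel module runs, the module's `memory_priority` attribute is flipped on, and the sharded outputs are gathered back along dim=1 before being printed next to the single-GPU reference. The bump from `seq_len = 2` to `seq_len = 4` (and the shortening of the embedding inputs to length 4) presumably keeps the sequence dimension divisible by `tp_size = 4`. The tests compare outputs by eye via `print`; a programmatic check following the same flow, condensed from `test_layer_norm_1d.py` (names come from that test body, and the `atol` value is an arbitrary choice, not something the tests assert), could look like:

```python
# Sketch of the shared memory_priority comparison pattern; `out`, `out_update`,
# `optimizer`, `args`, `world_size` and `parallel_context` are defined earlier
# in the test exactly as in the diff above.
import torch

if args.memory_priority:
    input_ = split_1d(input_, world_size, dim=1, parallel_context=parallel_context)
    target = split_1d(target, world_size, dim=1, parallel_context=parallel_context)

layernorm_1d.memory_priority = args.memory_priority
pout = layernorm_1d(input_)

loss = torch.nn.MSELoss()(pout, target)
loss.backward()
optimizer.step()
pout_update = layernorm_1d(input_)

if args.memory_priority:
    # Bring the sequence-sharded outputs back to full length before comparing.
    pout = gather_1d(pout, world_size, dim=1, parallel_context=parallel_context)
    pout_update = gather_1d(pout_update, world_size, dim=1, parallel_context=parallel_context)

assert torch.allclose(out, pout, atol=1e-5)
assert torch.allclose(out_update, pout_update, atol=1e-5)
```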
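The two embedding tests shard the reference weight differently before handing it to the parallel module: `Embedding1D` is column-parallel over the embedding dimension, so its weight is split with `dim=-1`, while `VocabParallelEmbedding1D` partitions the vocabulary, so its weight is split with `dim=0`; `LayerNorm1D`, by contrast, receives its weight and bias unsplit. A condensed view of that difference (module construction and the `deepcopy` of the reference weight are as in the tests):

```python
# How the 1D tests hand a reference 16x8 embedding weight to their parallel modules.
from copy import deepcopy

w = deepcopy(embedding.weight.data)

# Embedding1D: each rank holds a slice of the hidden/embedding dimension.
embedding_1d.weight.data = split_1d(w, world_size, dim=-1, parallel_context=parallel_context)

# VocabParallelEmbedding1D: each rank holds a slice of the vocabulary.
vocab_embedding_1d.weight.data = split_1d(w, world_size, dim=0, parallel_context=parallel_context)
```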
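At the model level, `test_wrapper_1d.py` threads the flag through `TensorParallel(model, parallel_context, memory_priority=args.memory_priority)` and switches tokenization from dynamic padding to `padding="max_length"` with `max_length = seq_length = 128`, presumably so every batch has a fixed sequence length that divides evenly by the tensor-parallel size once activations are split along dim=1. A condensed sketch of that usage; the `TensorParallel` import path, tokenizer construction, and pad-token handling are assumptions not visible in this patch, and `parallel_context`, `args`, and the batch `data` come from the test:

```python
# Condensed from test_wrapper_1d.py; import paths and tokenizer setup are assumed.
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

from oslo.torch.nn.parallel import TensorParallel          # assumed import path
from oslo.torch.nn.parallel.utils import allocate_params

seq_length = 128

# Build the model and wrap it for tensor parallelism with the new flag.
model_tp = GPT2LMHeadModel(GPT2Config.from_pretrained("gpt2"))
wrapper_tp = TensorParallel(model_tp, parallel_context, memory_priority=args.memory_priority)
allocate_params(wrapper_tp, parallel_context)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token                   # GPT-2 has no pad token by default

inputs = tokenizer(
    data,
    return_tensors="pt",
    padding="max_length",   # fixed-length padding instead of padding=True
    truncation=True,
    max_length=seq_length,
).to("cuda")

# Standard HF causal-LM loss; the wrapper is assumed to forward kwargs to GPT-2.
loss = wrapper_tp(**inputs, labels=inputs["input_ids"]).loss
loss.backward()
```

With `tp_size = 4`, these scripts are presumably launched with a four-process distributed launcher, e.g. something along the lines of `python -m torch.distributed.launch --nproc_per_node=4 test_wrapper_1d.py --memory_priority`; the exact launcher depends on the environment.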