This repository has been archived by the owner on Aug 26, 2022. It is now read-only.

Merge pull request #147 from tunib-ai/tensor_sequence_parallel
Tensor sequence parallel
hyunwoongko authored Aug 14, 2022
2 parents a0bc6e4 + b208d03 commit abd9906
Showing 47 changed files with 655 additions and 304 deletions.
11 changes: 5 additions & 6 deletions oslo/torch/distributed/nn/functional.py
@@ -93,6 +93,9 @@ def reduce_scatter(
out = tensor
work = None
else:
assert (
tensor.size(dim) % world_size == 0
), "tensor_size must be divisible by world size for tensor parallelism"
temp = list(
map(lambda x: x.contiguous(), torch.chunk(tensor, world_size, dim=dim))
)
@@ -203,15 +206,11 @@ def scatter(
if world_size == 1:
return tensor

tensor_size = tensor.size(dim)
assert (
tensor_size % world_size == 0
tensor.size(dim) % world_size == 0
), "tensor_size must be divisible by world size for tensor parallelism"
split_size_or_sections = tensor_size // world_size

tensor_list = torch.split(
tensor, split_size_or_sections=split_size_or_sections, dim=dim
)
tensor_list = torch.chunk(tensor, world_size, dim=dim)
return tensor_list[rank].contiguous()


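
For illustration, a minimal single-process sketch of the new torch.chunk-based scatter path; the helper name demo_scatter and the world_size of 4 are assumptions, not part of this commit. When the size along dim is divisible by world_size (which the assert added to reduce_scatter above also guarantees), torch.chunk yields the same shards as the previous torch.split call, so behavior is unchanged while the intermediate bookkeeping variables disappear.

import torch

def demo_scatter(tensor: torch.Tensor, rank: int, world_size: int, dim: int = -1) -> torch.Tensor:
    # Mirrors the patched scatter(): check divisibility, chunk along dim, keep this rank's shard.
    assert (
        tensor.size(dim) % world_size == 0
    ), "tensor_size must be divisible by world size for tensor parallelism"
    tensor_list = torch.chunk(tensor, world_size, dim=dim)
    return tensor_list[rank].contiguous()

x = torch.arange(8.0)
print(demo_scatter(x, rank=1, world_size=4, dim=0))  # tensor([2., 3.])
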
9 changes: 5 additions & 4 deletions oslo/torch/distributed/parallel_context.py
@@ -1,5 +1,6 @@
import os
import random
import warnings
from typing import List, Optional

import numpy as np
@@ -212,8 +213,8 @@ def from_torch(
expert_parallel_size=expert_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_depth=tensor_parallel_depth,
tensor_parallel_mode=tensor_parallel_mode,
backend=backend,
seed=seed,
)
@@ -282,8 +283,8 @@ def from_slurm(
expert_parallel_size=expert_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_depth=tensor_parallel_depth,
tensor_parallel_mode=tensor_parallel_mode,
backend=backend,
seed=seed,
)
@@ -351,8 +352,8 @@ def from_openmpi(
expert_parallel_size=expert_parallel_size,
pipeline_parallel_size=pipeline_parallel_size,
tensor_parallel_size=tensor_parallel_size,
tensor_parallel_mode=tensor_parallel_mode,
tensor_parallel_depth=tensor_parallel_depth,
tensor_parallel_mode=tensor_parallel_mode,
backend=backend,
seed=seed,
)
@@ -370,8 +371,8 @@ def __init__(
expert_parallel_size: int,
pipeline_parallel_size: int,
tensor_parallel_size: int,
tensor_parallel_mode: Optional[str],
tensor_parallel_depth: Optional[int],
tensor_parallel_mode: Optional[str],
backend: str,
seed: int,
):
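
Because these factory methods pass every parallel argument by keyword, moving tensor_parallel_mode after tensor_parallel_depth does not affect call sites. A hedged usage sketch follows; the concrete sizes and the "1d" mode string are illustrative assumptions, and the call assumes the usual torch.distributed environment variables (RANK, WORLD_SIZE, MASTER_ADDR, MASTER_PORT) are already set.

from oslo.torch.distributed import ParallelContext

# Illustrative values only; the mode string and depth semantics are assumptions.
parallel_context = ParallelContext.from_torch(
    pipeline_parallel_size=1,
    tensor_parallel_size=4,
    tensor_parallel_depth=1,   # assumed to matter only for the 2.5D mode
    tensor_parallel_mode="1d",
)
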
3 changes: 3 additions & 0 deletions oslo/torch/nn/modules/dropout.py
@@ -1,5 +1,8 @@
from typing import Optional
import torch
import torch.nn.functional as F
from torch.nn.modules.dropout import _DropoutNd
from oslo.torch.distributed import ParallelContext

from oslo.torch.nn.modules.functional import (
fused_bias_dropout,
42 changes: 38 additions & 4 deletions oslo/torch/nn/modules/embedding.py
@@ -127,6 +127,7 @@ def __init__(
parallel_context: Optional[ParallelContext] = None,
):
self.parallel_context = parallel_context
self.memory_priority = False
self.world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D)
assert (
embedding_dim % self.world_size == 0
@@ -141,8 +142,20 @@ def __init__(

def forward(self, input: Tensor) -> Tensor:
from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import (
all_gather_tensor_1d,
gather_tensor_1d,
scatter_tensor_1d,
)
from oslo.torch.distributed.nn.functional import (
all_gather,
)

if self.memory_priority:
input = all_gather(
input,
dim=1,
parallel_context=self.parallel_context,
parallel_mode=ParallelMode.TENSOR_1D,
)

output = F.embedding(
input,
@@ -154,11 +167,15 @@ def forward(self, input: Tensor) -> Tensor:
self.sparse,
)

output = all_gather_tensor_1d(
output = gather_tensor_1d(
output,
-1,
self.parallel_context,
)
if self.memory_priority:
output = scatter_tensor_1d(
output, dim=1, parallel_context=self.parallel_context
)
return output
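
To make the new memory_priority (sequence-parallel) path in Embedding1D.forward concrete, the following single-process sketch simulates the collectives with cat/chunk; world_size=2, the shapes, and the loop over ranks (standing in for parallel workers) are all assumptions. Sequence-sharded token ids are all-gathered along dim 1, embedded against the hidden-dim-sharded weight, gathered along the hidden dimension, and finally re-scattered along the sequence dimension.

import torch
import torch.nn.functional as F

world_size, vocab, hidden, seq = 2, 10, 8, 6
weight_shards = torch.randn(vocab, hidden).chunk(world_size, dim=-1)  # hidden-dim sharded embedding weight
tokens = torch.randint(0, vocab, (1, seq))
token_shards = tokens.chunk(world_size, dim=1)                        # sequence-sharded input ids

partials = []
for rank in range(world_size):
    full_tokens = torch.cat(token_shards, dim=1)                      # all_gather along the sequence dim
    partials.append(F.embedding(full_tokens, weight_shards[rank]))    # (1, seq, hidden // world_size)

output = torch.cat(partials, dim=-1)                                  # gather_tensor_1d along the hidden dim
seq_shards = output.chunk(world_size, dim=1)                          # scatter_tensor_1d back along the sequence
print(seq_shards[0].shape)                                            # torch.Size([1, 3, 8])
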


@@ -171,6 +188,7 @@ def __init__(
parallel_context: Optional[ParallelContext] = None,
):
self.parallel_context = parallel_context
self.memory_priority = False
rank = self.parallel_context.get_local_rank(ParallelMode.TENSOR_1D)
self.world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D)
assert (
@@ -192,8 +210,20 @@ def __init__(

def forward(self, input: Tensor) -> Tensor:
from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import (
all_reduce_tensor_1d,
reduce_tensor_1d,
scatter_tensor_1d,
)
from oslo.torch.distributed.nn.functional import (
all_gather,
)

if self.memory_priority:
input = all_gather(
input,
dim=1,
parallel_context=self.parallel_context,
parallel_mode=ParallelMode.TENSOR_1D,
)

if self.world_size > 1:
input_mask = (input < self.vocab_start_index) | (
@@ -218,7 +248,11 @@ def forward(self, input: Tensor) -> Tensor:
output_parallel[input_mask, :] = 0.0

# Reduce across all the model parallel GPUs.
output = all_reduce_tensor_1d(output_parallel, self.parallel_context)
output = reduce_tensor_1d(output_parallel, self.parallel_context)
if self.memory_priority:
output = scatter_tensor_1d(
output, dim=1, parallel_context=self.parallel_context
)
return output


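
The vocab-parallel variant above follows the same sequence-parallel pattern, with a reduce instead of a hidden-dim gather. A single-process sketch, again with world_size=2 and simulated collectives (sum for reduce_tensor_1d, chunk for scatter_tensor_1d); all names and shapes here are illustrative assumptions. Each rank masks token ids outside its vocabulary slice, embeds locally, zeroes the masked rows, sums the partial results, and keeps only its sequence shard.

import torch
import torch.nn.functional as F

world_size, vocab, hidden, seq = 2, 10, 8, 6
vocab_per_rank = vocab // world_size
weight_shards = torch.randn(vocab, hidden).chunk(world_size, dim=0)   # vocab-dim sharded embedding weight
tokens = torch.randint(0, vocab, (1, seq))
token_shards = tokens.chunk(world_size, dim=1)                        # sequence-sharded input ids

partials = []
for rank in range(world_size):
    full_tokens = torch.cat(token_shards, dim=1)                      # all_gather along the sequence dim
    start, end = rank * vocab_per_rank, (rank + 1) * vocab_per_rank
    input_mask = (full_tokens < start) | (full_tokens >= end)         # ids outside this rank's vocab slice
    local_ids = (full_tokens - start).masked_fill(input_mask, 0)
    out = F.embedding(local_ids, weight_shards[rank])
    out[input_mask, :] = 0.0                                          # zero masked rows before the reduce
    partials.append(out)

output = torch.stack(partials).sum(0)                                 # reduce_tensor_1d (sum of partials)
seq_shards = output.chunk(world_size, dim=1)                          # scatter_tensor_1d along the sequence
print(seq_shards[0].shape)                                            # torch.Size([1, 3, 8])
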
25 changes: 25 additions & 0 deletions oslo/torch/nn/modules/layer_norm.py
@@ -81,6 +81,8 @@ def __init__(
parallel_context: Optional[ParallelContext] = None,
):
self.parallel_context = parallel_context
self.memory_priority = False

super().__init__(
normalized_shape=normalized_shape,
partitioned_dim=normalized_shape,
@@ -89,6 +91,29 @@ def __init__(
dtype=dtype,
)

def forward(self, input: Tensor) -> Tensor:
from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import (
broadcast_tensor_1d,
)

weight = (
broadcast_tensor_1d(self.weight, parallel_context=self.parallel_context)
if self.memory_priority
else self.weight
)
bias = (
broadcast_tensor_1d(self.bias, parallel_context=self.parallel_context)
if self.memory_priority and self.bias is not None
else self.bias
)
normalized_shape = (
(self.normalized_shape,)
if isinstance(self.normalized_shape, int)
else self.normalized_shape
)
output = F.layer_norm(input, normalized_shape, weight, bias, self.eps)
return output


class LayerNorm2D(LayerNorm):
def __init__(
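
A minimal sketch of the arithmetic in the forward method added above; this single-process illustration omits the broadcast of weight and bias (which only matters when memory_priority is enabled across ranks), and the hidden size is an assumption. The int-to-tuple conversion mirrors the normalized_shape handling in the patch.

import torch
import torch.nn.functional as F

hidden = 8
weight, bias, eps = torch.ones(hidden), torch.zeros(hidden), 1e-5
x = torch.randn(2, 4, hidden)

normalized_shape = (hidden,) if isinstance(hidden, int) else hidden   # int -> tuple, as in the patch
y = F.layer_norm(x, normalized_shape, weight, bias, eps)
print(y.shape)                                                        # torch.Size([2, 4, 8])
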
45 changes: 36 additions & 9 deletions oslo/torch/nn/modules/linear.py
@@ -127,7 +127,9 @@ def __init__(
):
self.gather_output = gather_output
self.parallel_context = parallel_context
self.memory_priority = False
self.reversed = False
self.scatter_output = False

self.world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D)
assert (
@@ -150,12 +152,17 @@ def extra_repr(self) -> str:

def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import (
all_gather_tensor_1d,
gather_tensor_1d,
broadcast_tensor_1d,
scatter_tensor_1d,
memory_priority_linear,
)

input = broadcast_tensor_1d(input, self.parallel_context)
outputs = F.linear(input, self.weight)
if self.memory_priority:
outputs = memory_priority_linear(input, self.weight, self.parallel_context)
else:
input = broadcast_tensor_1d(input, self.parallel_context)
outputs = F.linear(input, self.weight)

if self.bias is not None:
if self.skip_bias_add:
@@ -164,13 +171,20 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
outputs = outputs + self.bias

if self.gather_output:
outputs = all_gather_tensor_1d(
outputs = gather_tensor_1d(
outputs,
dim=-1,
parallel_context=self.parallel_context,
)
if hasattr(self, "orig_num_classes"):
outputs = outputs[..., : self.orig_num_classes]

if self.memory_priority and self.scatter_output:
outputs = scatter_tensor_1d(
outputs,
dim=1,
parallel_context=self.parallel_context,
)
return outputs
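
A single-process sketch of the memory-priority column-parallel path above; world_size=2 and the shapes are assumptions, and the sequence-dimension all-gather that memory_priority_linear presumably performs internally is simulated with torch.cat. Each rank multiplies the full-sequence input by its output-dim shard of the weight; with gather_output=True the shards are concatenated along the last dimension, and scatter_output would then split the result back along the sequence dimension.

import torch
import torch.nn.functional as F

world_size, seq, d_in, d_out = 2, 6, 8, 12
weight_shards = torch.randn(d_out, d_in).chunk(world_size, dim=0)     # output-dim (column) sharded weight
x = torch.randn(1, seq, d_in)
x_shards = x.chunk(world_size, dim=1)                                 # sequence-sharded activations

outputs = []
for rank in range(world_size):
    full_x = torch.cat(x_shards, dim=1)                               # sequence all-gather inside the op
    outputs.append(F.linear(full_x, weight_shards[rank]))             # (1, seq, d_out // world_size)

gathered = torch.cat(outputs, dim=-1)                                 # gather_output=True: gather on dim -1
print(gathered.shape)                                                 # torch.Size([1, 6, 12])
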


@@ -187,6 +201,7 @@ def __init__(
):
self.parallel_input = parallel_input
self.parallel_context = parallel_context
self.memory_priority = False
self.reversed = False

self.world_size = self.parallel_context.get_world_size(ParallelMode.TENSOR_1D)
@@ -210,25 +225,37 @@ def extra_repr(self) -> str:

def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
from oslo.torch.nn.parallel.tensor_parallel._parallel_1d._ops import (
all_reduce_tensor_1d,
reduce_tensor_1d,
scatter_tensor_1d,
reduce_scatter_tensor_1d,
broadcast_tensor_1d,
)

if not self.parallel_input:
assert (
not self.memory_priority
), "Input must be parallelized when using memory priority."
input = scatter_tensor_1d(
input,
dim=-1,
parallel_context=self.parallel_context,
)

outputs = F.linear(input, self.weight)
outputs = all_reduce_tensor_1d(outputs, self.parallel_context)

if self.memory_priority:
outputs = reduce_scatter_tensor_1d(
outputs, dim=1, parallel_context=self.parallel_context
)
else:
outputs = reduce_tensor_1d(outputs, parallel_context=self.parallel_context)
if self.bias is not None:
if self.skip_bias_add:
return outputs, self.bias
else:
return outputs + self.bias
if self.memory_priority:
bias = broadcast_tensor_1d(self.bias, self.parallel_context)
else:
bias = self.bias
return outputs + bias

return outputs

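
Finally, a single-process sketch of the memory-priority row-parallel path; world_size=2 and the shapes are assumptions, and the reduce-scatter is simulated with sum plus chunk. Each rank multiplies its hidden-dim shard of the input by its input-dim shard of the weight; summing the partial products reproduces the full linear output (the reduce), and keeping one sequence shard per rank (the scatter) is what reduce_scatter_tensor_1d buys over the old all-reduce.

import torch
import torch.nn.functional as F

world_size, seq, d_in, d_out = 2, 6, 8, 12
weight_shards = torch.randn(d_out, d_in).chunk(world_size, dim=1)     # input-dim (row) sharded weight
x = torch.randn(1, seq, d_in)
x_shards = x.chunk(world_size, dim=-1)                                # hidden-dim sharded input (parallel_input=True)

partials = [F.linear(x_shards[r], weight_shards[r]) for r in range(world_size)]  # each (1, seq, d_out)
reduced = torch.stack(partials).sum(0)                                # the "reduce": sum of partial products
seq_shards = reduced.chunk(world_size, dim=1)                         # the "scatter": one sequence shard per rank
print(seq_shards[0].shape)                                            # torch.Size([1, 3, 12])
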