Merge pull request #149 from tunib-ai/tp-2p5d-deparallelize
Tp 2p5d / 2d / 1d deparallelize
hyunwoongko authored Aug 16, 2022
2 parents abd9906 + 88ac33f commit d1a3463
Showing 34 changed files with 3,367 additions and 87 deletions.
7 changes: 7 additions & 0 deletions .gitignore
@@ -389,3 +389,10 @@ usecases
/usecases
*/usecases
wandb/

# multi gpu mem log
**/core.*

# sample huggingface models
**/pytorch_model.bin
**/config.json
20 changes: 18 additions & 2 deletions oslo/torch/nn/modules/linear.py
@@ -178,6 +178,8 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
)
if hasattr(self, "orig_num_classes"):
outputs = outputs[..., : self.orig_num_classes]
if not outputs.is_contiguous():
outputs = outputs.contiguous()

if self.memory_priority and self.scatter_output:
outputs = scatter_tensor_1d(
@@ -257,6 +259,9 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
bias = self.bias
return outputs + bias

if not outputs.is_contiguous():
outputs = outputs.contiguous()

return outputs


@@ -398,6 +403,10 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
)
if hasattr(self, "orig_num_classes"):
outputs = outputs[..., : self.orig_num_classes]

if not outputs.is_contiguous():
outputs = outputs.contiguous()

return outputs


@@ -530,15 +539,19 @@ def forward(self, input: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
dim=-1,
parallel_context=self.parallel_context,
col_parallel_mode=ParallelMode.TENSOR_2P5D_ROW,
)
).clone()
outputs = all_gather_tensor_2p5d(
outputs,
dim=0,
parallel_context=self.parallel_context,
col_parallel_mode=ParallelMode.TENSOR_2P5D_COL,
)
).clone()
if hasattr(self, "orig_num_classes"):
outputs = outputs[..., : self.orig_num_classes]

if not outputs.is_contiguous():
outputs = outputs.contiguous()

return outputs


@@ -621,4 +634,7 @@ def forward(self, input: Tensor) -> Tensor:
)
if hasattr(self, "orig_num_classes"):
outputs = outputs[..., : self.orig_num_classes]

if not outputs.is_contiguous():
outputs = outputs.contiguous()
return outputs
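
The pattern repeated in these hunks — slice the gathered output back to orig_num_classes, then call .contiguous() — is needed because narrowing the last dimension of a tensor returns a non-contiguous view. A minimal standalone sketch (plain PyTorch, not oslo code) of why the extra copy matters:

import torch

x = torch.randn(4, 8)        # e.g. a gathered output padded along the class dimension
y = x[..., :5]               # slice back to the original number of classes
print(y.is_contiguous())     # False: the view keeps the strides of the padded tensor
y = y.contiguous()           # copy into a dense layout, as the forward() methods above do
print(y.is_contiguous())     # True
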
7 changes: 4 additions & 3 deletions oslo/torch/nn/parallel/tensor_parallel/__init__.py
@@ -1,6 +1,7 @@
from oslo.torch.nn.parallel.tensor_parallel.mapping import Column, Row, Update, Head
from oslo.torch.nn.parallel.tensor_parallel.tensor_parallel import (
TensorParallel,
from oslo.torch.nn.parallel.tensor_parallel.tensor_parallel import TensorParallel
from oslo.torch.nn.parallel.tensor_parallel._base_wrapper import (
BaseTensorParallelWrapper,
)

__ALL__ = [TensorParallel, Column, Row, Update, Head]
__ALL__ = [TensorParallel, Column, Row, Update, Head, BaseTensorParallelWrapper]
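
A minimal sketch of what the new export enables — building a custom wrapper on top of BaseTensorParallelWrapper (the subclass name and body below are illustrative only, not oslo code):

import torch

from oslo.torch.nn.parallel.tensor_parallel import BaseTensorParallelWrapper


class MyTensorParallelWrapper(BaseTensorParallelWrapper):
    # The base class supplies save_parallelized()/from_parallelized()
    # (see _base_wrapper.py below); subclasses override deparallelize().
    @torch.no_grad()
    def deparallelize(self):
        ...
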
225 changes: 225 additions & 0 deletions oslo/torch/nn/parallel/tensor_parallel/_base_wrapper.py
@@ -0,0 +1,225 @@
import copy
import os
import json

import torch
import torch.nn as nn
import torch.distributed as dist

from typing import Union, Optional, Callable
from logging import getLogger

from oslo.torch.distributed import ParallelContext, ParallelMode

from oslo.torch.nn.parallel.utils import (
    ParallelWrapper,
    _update_module_arguments,
    is_huggingface_model,
    is_oslo_model,
    allocate_params,
    unwrap_parallel,
    get_parameter_dtype,
)


class BaseTensorParallelWrapper(ParallelWrapper):
    """
    PyTorch module for xD tensor parallelism

    Args:
        module (nn.Module): model object
        parallel_context (ParallelContext): parallel context object
    """

    def __init__(
        self,
        module: nn.Module,
        parallel_context: ParallelContext,
        mapping: dict = None,
        module_args: dict = None,
    ):
        super().__init__()

    @torch.no_grad()
    def save_parallelized(
        self,
        new_module,
        save_directory: Union[str, os.PathLike],
        save_config: bool = True,
        state_dict: Optional[dict] = None,
        save_function: Callable = torch.save,
        merge_checkpoints: bool = False,
        mapping: Optional[dict] = None,
        **kwargs,
    ):
        logger = getLogger("TensorParallel")
        PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin"

        if (
            self.parallel_context.get_world_size(ParallelMode.TENSOR) == 1
            and self.parallel_context.get_world_size(ParallelMode.PIPELINE) == 1
        ):
            if dist.get_rank() == 0:
                self.save_pretrained(
                    save_directory=save_directory,
                    save_config=save_config,
                    state_dict=state_dict,
                    save_function=save_function,
                    **kwargs,
                )
            dist.barrier()
            return None

        if merge_checkpoints:
            model_to_save = self.__class__(
                module=new_module,
                parallel_context=self.parallel_context,
                mapping=mapping,
                module_args=self.config,
            ).eval()

            if state_dict is None:
                state_dict = self.state_dict()

            model_to_save.load_state_dict(state_dict)
            allocate_params(model_to_save, self.parallel_context)

            if self.parallel_context.get_world_size(ParallelMode.TENSOR) > 1:
                model_to_save.deparallelize()

            if dist.get_rank() == 0:
                if is_huggingface_model(model_to_save.module):
                    model_to_save.module.save_pretrained(
                        save_directory=save_directory,
                        save_config=save_config,
                        save_function=save_function,
                        **kwargs,
                    )
                else:
                    if save_config:
                        with open(
                            os.path.join(save_directory, "config.json"), "w"
                        ) as f:
                            json.dump(self.config, f)
                    save_function(
                        model_to_save,
                        os.path.join(save_directory, "pytorch_model.bin"),
                    )
                del model_to_save

            dist.barrier()
            return None

        if os.path.isfile(save_directory):
            logger.error(
                f"Provided path ({save_directory}) should be a directory, not a file"
            )
            return

        os.makedirs(save_directory, exist_ok=True)

        # Only save the model itself if we are using distributed training
        model_to_save = unwrap_parallel(self)

        # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
        # we currently don't use this setting automatically, but may start to use with v5
        dtype = get_parameter_dtype(model_to_save)
        model_to_save.config.torch_dtype = str(dtype).split(".")[1]

        # Attach architecture to the config
        model_to_save.config.architectures = [model_to_save.__class__.__name__]

        # Save the config
        if save_config:
            model_to_save.config.save_pretrained(save_directory)

        # Save the model
        if state_dict is None:
            state_dict = model_to_save.state_dict()

        # Handle the case where some state_dict keys shouldn't be saved
        if getattr(self, "_keys_to_ignore_on_save", None) is not None:
            state_dict = {
                k: v
                for k, v in state_dict.items()
                if k not in self._keys_to_ignore_on_save
            }

        # If we save using the predefined names, we can load using `from_pretrained`
        weights_name = PARALLELIZED_WEIGHTS_NAME
        weights_name = weights_name.replace(
            "tp_0", f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}"
        )
        weights_name = weights_name.replace(
            "pp_0", f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}"
        )

        output_model_file = os.path.join(save_directory, weights_name)

        if self.parallel_context.get_world_size(ParallelMode.DATA) > 1:
            if self.parallel_context.get_local_rank(ParallelMode.DATA) == 0:
                save_function(state_dict, output_model_file)
        else:
            save_function(state_dict, output_model_file)

        dist.barrier()
        logger.info(f"Model weights saved in {output_model_file}")

    def from_parallelized(self, path):
        """
        Example:
            >>> model = AnyModel()
            >>> model = TensorParallel(model, ...)
            >>> model.from_parallelized(path)
        """
        PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin"
        parallelized_model_path = path

        file_names = {
            os.path.join(
                parallelized_model_path,
                PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp}").replace(
                    "pp_0", f"pp_{pp}"
                ),
            )
            for tp in range(self.parallel_context.get_world_size(ParallelMode.TENSOR))
            for pp in range(self.parallel_context.get_world_size(ParallelMode.PIPELINE))
        }

        if os.path.isdir(parallelized_model_path):
            if all(os.path.isfile(file_name) for file_name in file_names):
                state_dict = torch.load(
                    os.path.join(
                        parallelized_model_path,
                        PARALLELIZED_WEIGHTS_NAME.replace(
                            "tp_0",
                            f"tp_{self.parallel_context.get_local_rank(ParallelMode.TENSOR)}",
                        ).replace(
                            "pp_0",
                            f"pp_{self.parallel_context.get_local_rank(ParallelMode.PIPELINE)}",
                        ),
                    )
                )

                if getattr(self, "_keys_to_ignore_on_save", None) is not None:
                    state_dict = {
                        k: v
                        for k, v in state_dict.items()
                        if k not in self._keys_to_ignore_on_save
                    }

                self.load_state_dict(state_dict=state_dict, strict=False)

            else:
                raise FileNotFoundError(
                    f"All of the files {file_names} are necessary, "
                    f"but some of them do not exist. Please check your checkpoint files."
                )
        else:
            raise NotADirectoryError(
                f"Directory {parallelized_model_path} is not valid."
            )

    @torch.no_grad()
    def deparallelize(self):
        raise NotImplementedError
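
from_parallelized() above expects one weight shard per (tensor-parallel, pipeline-parallel) rank pair, named by rewriting pytorch_model_tp_0_pp_0.bin. A minimal sketch of that naming scheme (weights_file_for is a hypothetical helper for illustration, not part of oslo):

import os

PARALLELIZED_WEIGHTS_NAME = "pytorch_model_tp_0_pp_0.bin"


def weights_file_for(path: str, tp_rank: int, pp_rank: int) -> str:
    # Mirrors the replace() calls in save_parallelized()/from_parallelized() above.
    name = PARALLELIZED_WEIGHTS_NAME.replace("tp_0", f"tp_{tp_rank}")
    name = name.replace("pp_0", f"pp_{pp_rank}")
    return os.path.join(path, name)


# With tensor_parallel_size=2 and pipeline_parallel_size=1, the loader expects:
print([weights_file_for("ckpt", tp, 0) for tp in range(2)])
# ['ckpt/pytorch_model_tp_0_pp_0.bin', 'ckpt/pytorch_model_tp_1_pp_0.bin']
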
19 changes: 19 additions & 0 deletions oslo/torch/nn/parallel/tensor_parallel/_parallel_1d/_ops.py
@@ -1,6 +1,7 @@
from typing import Any

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import Tensor

@@ -214,3 +215,21 @@ def memory_priority_linear(
    inputs: Tensor, weight: Tensor, parallel_context: ParallelContext
):
    return _MemoryPriorityLinear.apply(inputs, weight, parallel_context)


def split_1d(parallel_context, tensor, summa_dim, dim=-1):
    tensor = tensor.chunk(summa_dim, dim=dim)[
        parallel_context.get_local_rank(ParallelMode.TENSOR_1D)
    ]
    return tensor


def gather_1d(parallel_context, tensor, summa_dim, dim=-1):
    tensor_list = [torch.zeros_like(tensor) for _ in range(summa_dim)]
    dist.all_gather(
        tensor_list,
        tensor.contiguous(),
        parallel_context.get_group(ParallelMode.TENSOR_1D),
    )
    tensor = torch.cat(tensor_list, dim=dim)
    return tensor
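
split_1d and gather_1d form a chunk/all-gather pair over the TENSOR_1D group (the summa_dim argument is effectively the 1D world size here). A single-process sketch of the round trip they implement, with no torch.distributed and an assumed world size:

import torch

world_size = 4                          # assumed tensor-parallel size for illustration
x = torch.randn(2, 8)

shards = x.chunk(world_size, dim=-1)    # what split_1d hands to each rank
restored = torch.cat(shards, dim=-1)    # what gather_1d reassembles after all_gather
assert torch.equal(restored, x)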