
precommit run
jason9693 committed Aug 15, 2022
1 parent 7183465 commit 20204ea
Showing 4 changed files with 4 additions and 17 deletions.
@@ -17,10 +17,9 @@
LayerNorm2D,
)
from oslo.torch.nn.parallel.tensor_parallel._parallel_2d._ops import (
split_batch_2d,
gather_2d,
gather_1d,
gather_1d_twice,
)
from oslo.torch.distributed.nn.functional import (
scatter,
)
10 changes: 1 addition & 9 deletions oslo/torch/nn/parallel/tensor_parallel/_parallel_2p5d/_ops.py
@@ -83,15 +83,6 @@ def gather_batch_2p5d(
)


def all_gather_tensor_2p5d(
inputs: Tensor,
dim: int,
parallel_context: ParallelContext,
col_parallel_mode: ParallelMode,
) -> Tensor:
return _AllGatherTensor2p5D.apply(inputs, dim, parallel_context, col_parallel_mode)


def reduce_by_batch_2p5d(
inputs, reduce_mean: bool, parallel_context: ParallelContext
) -> Tensor:
@@ -137,6 +128,7 @@ def split_batch_2p5d(
)[parallel_context.get_local_rank(ParallelMode.TENSOR_2P5D_COL)].contiguous()
return col_chunked


def get_current_device():
r"""
Get current device.
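For context, the hunk above deletes the `all_gather_tensor_2p5d` wrapper, which delegated to `_AllGatherTensor2p5D.apply`. The sketch below is a minimal, hypothetical illustration of what an "all-gather along a dim" helper generally does using plain `torch.distributed`; the function name, the `group` argument, and the equal-shard assumption are illustrative and do not reflect the oslo implementation.

import torch
import torch.distributed as dist


def all_gather_along_dim(shard: torch.Tensor, dim: int, group=None) -> torch.Tensor:
    # Collect the equally sized shard held by every rank in `group`,
    # then concatenate the pieces back together along `dim`.
    world_size = dist.get_world_size(group=group)
    buffers = [torch.empty_like(shard) for _ in range(world_size)]
    dist.all_gather(buffers, shard.contiguous(), group=group)
    return torch.cat(buffers, dim=dim)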
@@ -16,10 +16,10 @@
LayerNorm3D,
)
from oslo.torch.nn.parallel.tensor_parallel._parallel_3d._ops import (
split_batch_3d,
gather_3d,
gather_2d,
gather_1d,
)
from oslo.torch.distributed.nn.functional import (
scatter,
)
6 changes: 1 addition & 5 deletions oslo/torch/nn/parallel/tensor_parallel/tensor_parallel.py
@@ -4,7 +4,6 @@
import os
import json
from operator import xor
from typing import Optional
import warnings

import torch
@@ -96,7 +95,6 @@ def __init__(
module = self._resize_vocab_size(module, self.parallel_context)
module = self._resize_num_classes(module, self.parallel_context, mapping)


if parallel_context.tensor_parallel_mode != ParallelMode.TENSOR_1D:
if memory_priority and parallel_context.tensor_parallel_size > 1:
warnings.warn(
@@ -108,9 +106,7 @@
module, self.parallel_context, mapping, memory_priority
)
elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2D:
self.module = _TensorParallel2D(
module, self.parallel_context, mapping
)
self.module = _TensorParallel2D(module, self.parallel_context, mapping)
elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_2P5D:
self.module = _TensorParallel2p5D(module, self.parallel_context, mapping)
elif self.parallel_context.tensor_parallel_mode == ParallelMode.TENSOR_3D:
