
Commit

Merge pull request #143 from tunib-ai/models
Update data collators and Add models
hyunwoongko authored Aug 25, 2022
2 parents f2d95ed + a51329c commit a4cad0e
Showing 35 changed files with 5,051 additions and 520 deletions.
3 changes: 3 additions & 0 deletions oslo/torch/nn/__init__.py
@@ -2,6 +2,9 @@
 from oslo.torch.nn.modules.dropout import (
     FusedBiasDropout,
 )
+from oslo.torch.nn.modules.activation import (
+    FusedBiasGeLU,
+)
 from oslo.torch.nn.modules.embedding import (
     Embedding1D,
     Embedding2D,
10 changes: 9 additions & 1 deletion oslo/torch/nn/modules/activation.py
@@ -7,7 +7,10 @@
 from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
 from torch.nn.parameter import Parameter

-from oslo.torch.nn.modules.functional import multi_head_attention_forward
+from oslo.torch.nn.modules.functional import (
+    multi_head_attention_forward,
+    fused_bias_gelu,
+)


 class MultiheadAttention(nn.Module):
@@ -268,3 +271,8 @@ def forward(
             return attn_output.transpose(1, 0), attn_output_weights
         else:
             return attn_output, attn_output_weights
+
+
+class FusedBiasGeLU(nn.Module):
+    def forward(self, input: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+        return fused_bias_gelu(input, bias)
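For reference, the new FusedBiasGeLU module is a thin wrapper over the fused kernel. A minimal unfused sketch of the semantics, assuming the common tanh-approximated bias-GeLU fusion (the exact approximation used by oslo's kernel is not shown in this diff):

import torch

def bias_gelu_reference(input: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
    # Unfused reference: add the bias, then apply tanh-approximated GeLU.
    # A fused kernel produces the same result in a single memory pass,
    # avoiding the intermediate tensor from the separate bias add.
    x = input + bias
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608 * x * (1.0 + 0.044715 * x * x)))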
4 changes: 3 additions & 1 deletion oslo/torch/nn/modules/functional.py
@@ -463,9 +463,11 @@ def _fused_scale_mask_softmax_cuda(input, scale, use_triang_mask, pad_mask):
         return output.view(bsz, np, sq, sk)
     else:
         if pad_mask is not None:
+            if pad_mask.size(2) == 1:
+                pad_mask = pad_mask.repeat(1, 1, sq, 1)
             return _FusedScaleMaskSoftmaxFunction.apply(
                 input,
-                pad_mask.repeat(1, 1, sq, 1).bool(),
+                pad_mask.bool(),
                 scale,
             )
         else:
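The functional.py change guards the query-axis repeat: a key-padding mask of shape (bsz, 1, 1, sk) is still expanded to (bsz, 1, sq, sk), but a mask that already spans the query dimension now passes through untouched instead of being repeated again. A shape illustration with hypothetical sizes:

import torch

bsz, sq, sk = 2, 8, 8  # hypothetical batch size, query length, key length

# Key-padding mask with a singleton query dim: the guard triggers the repeat.
pad_mask = torch.zeros(bsz, 1, 1, sk, dtype=torch.bool)
if pad_mask.size(2) == 1:
    pad_mask = pad_mask.repeat(1, 1, sq, 1)
assert pad_mask.shape == (bsz, 1, sq, sk)

# Mask that already covers every query position: left as-is, whereas the
# old unconditional repeat would have produced a (bsz, 1, sq*sq, sk) tensor.
full_mask = torch.zeros(bsz, 1, sq, sk, dtype=torch.bool)
if full_mask.size(2) == 1:
    full_mask = full_mask.repeat(1, 1, sq, 1)
assert full_mask.shape == (bsz, 1, sq, sk)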
11 changes: 0 additions & 11 deletions oslo/transformers/models/distilbert/modeling_distilbert.py
@@ -65,17 +65,6 @@ def _init_weights(self, module: nn.Module):


 def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
-    if is_deepspeed_zero3_enabled():
-        import deepspeed
-
-        with deepspeed.zero.GatheredParameters(out, modifier_rank=0):
-            if torch.distributed.get_rank() == 0:
-                _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
-    else:
-        _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
-
-
-def _create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
     position_enc = np.array(
         [
             [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
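The deletion removes the DeepSpeed ZeRO-3 gathering wrapper, leaving the sinusoidal table built directly in create_sinusoidal_embeddings. The body is truncated above; a self-contained sketch of the standard computation it starts (the lines past the truncation point are assumed to follow the usual sin/cos formulation, as in the upstream Hugging Face DistilBERT code this file mirrors):

import numpy as np
import torch

def create_sinusoidal_embeddings(n_pos: int, dim: int, out: torch.Tensor):
    # Classic transformer position encoding: angle = pos / 10000^(2*(j//2)/dim);
    # even feature indices take sin, odd indices take cos.
    position_enc = np.array(
        [
            [pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)]
            for pos in range(n_pos)
        ]
    )
    out.requires_grad = False  # positions are fixed, not learned
    out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
    out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))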
