enable custom activation functions (#65)
* enable custom activation functions

* add file; add default; rm branch

* updt set_grad_enabled in activation fn
vchiley authored Dec 19, 2023
1 parent e684e02 · commit 44d9743
Showing 4 changed files with 120 additions and 61 deletions.
24 changes: 24 additions & 0 deletions megablocks/layers/activation_fn.py
@@ -0,0 +1,24 @@
+from typing import Callable
+
+import torch
+import stk
+
+
+def act_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kwargs):
+    assert isinstance(x, stk.Matrix)
+    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+        if return_grad_fn:
+            x.data.requires_grad = True
+        out = function(x.data, **kwargs)
+        y = stk.Matrix(
+            x.size(),
+            out,
+            x.row_indices,
+            x.column_indices,
+            x.offsets,
+            x.column_indices_t,
+            x.offsets_t,
+            x.block_offsets_t)
+        if return_grad_fn:
+            return y, out.backward
+        return y
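
For context, a minimal usage sketch of the new helper (not part of the commit; the SiLU choice and the function names below are illustrative). act_fn applies an ordinary elementwise callable to the .data tensor of a block-sparse stk.Matrix and rewraps the result with the original topology metadata; with return_grad_fn=True it also returns out.backward so the caller can drive the activation's backward pass manually.

import torch.nn.functional as F
import stk

from megablocks.layers.activation_fn import act_fn


def apply_custom_activation(x1: stk.Matrix) -> stk.Matrix:
    # `x1` would typically be the block-sparse output of stk.ops.sdd
    # inside an expert layer; any elementwise callable can be passed.
    return act_fn(x1, F.silu)


def apply_with_manual_backward(x1: stk.Matrix):
    # return_grad_fn=True enables grad on x1.data and additionally returns
    # `out.backward`, letting the caller trigger the activation's backward
    # pass explicitly (useful for memory-efficient recomputation schemes).
    y, activation_grad_fn = act_fn(x1, F.silu, return_grad_fn=True)
    return y, activation_grad_fn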
6 changes: 5 additions & 1 deletion megablocks/layers/arguments.py
@@ -3,13 +3,16 @@
import megablocks.turbo_util as turbo
import megablocks.grouped_gemm_util as grouped_gemm
import torch
+import torch.nn.functional as F
from typing import Callable, Optional, Union

# Type annotation for in-place Tensor initialization function.
InitFn = Callable[[torch.Tensor], None]

_ALLOWED_BITWIDTHS = (-1, 4, 8)

+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate="tanh")
+


@dataclasses.dataclass
class Arguments:
@@ -19,12 +22,13 @@ class Arguments:
    num_layers : int = 1
    bias : bool = True
    return_bias : bool = True
+    activation_fn : Optional[Callable] = DEFAULT_ACTIVATION_FN

    # MoE arguments.
    moe_num_experts : int = 1
    moe_top_k : int = 1
    moe_capacity_factor : int = 1
-    moe_normalize_expert_weights: Optional[Union[int, float]] = None
+    moe_normalize_expert_weights : Optional[Union[int, float]] = None
    moe_loss_weight : float = 0.1
    moe_jitter_eps : Optional[float] = None
    moe_lbl_in_fp32 : bool = False
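
To show what the new field enables at the user level, a hedged configuration sketch (not from the repository; the values and the non-activation field names such as hidden_size, ffn_hidden_size, moe_num_experts, and moe_top_k are illustrative assumptions about the rest of the dataclass). Any elementwise callable can be supplied, and functools.partial can bind extra keyword arguments just as the default does.

from functools import partial

import torch.nn.functional as F

from megablocks.layers.arguments import Arguments

# Default: tanh-approximated GELU via DEFAULT_ACTIVATION_FN.
args = Arguments(hidden_size=1024, ffn_hidden_size=4096, moe_num_experts=8, moe_top_k=2)

# Custom activations: plain callables, or partials with bound kwargs.
silu_args = Arguments(hidden_size=1024, ffn_hidden_size=4096, activation_fn=F.silu)
leaky_args = Arguments(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation_fn=partial(F.leaky_relu, negative_slope=0.1),
)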
8 changes: 4 additions & 4 deletions megablocks/layers/glu.py
@@ -1,12 +1,11 @@
from megablocks.layers import common
-from megablocks.layers import gelu
+from megablocks.layers.activation_fn import act_fn
from megablocks.layers.mlp import SparseMLP, create_dmoe_expert_weights
from megablocks.layers import mpu
from megablocks.layers.arguments import Arguments, InitFn
from megablocks import grouped_gemm_util as gg
import stk
import torch
-import torch.nn.functional as F


class SparseGLU(SparseMLP):
@@ -38,7 +37,8 @@ def forward(self, x, topo):
        x1 = stk.ops.sdd(x, w1.t(), topo)
        x2 = stk.ops.sdd(x, v1.t(), topo)

-        x1 = stk.ops.mul(gelu.gelu(x1), x2)
+        activation_fn_out = act_fn(x1, self.args.activation_fn)
+        x1 = stk.ops.mul(activation_fn_out, x2)

        return stk.ops.dsd(x1, w2)

@@ -56,5 +56,5 @@ def forward(self, x, tokens_per_expert):
        # Compute the MLP.
        x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
-        x1 = F.gelu(x1, approximate="tanh") * x2
+        x1 = self.args.activation_fn(x1) * x2
        return gg.ops.gmm(x1, w2, batch_sizes)
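
With the default argument the dense grouped-GEMM path is numerically unchanged: DEFAULT_ACTIVATION_FN is exactly partial(F.gelu, approximate="tanh"), so self.args.activation_fn(x1) reproduces the previously hard-coded call. A small check (illustrative, not from the repository):

import torch
import torch.nn.functional as F

from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN

# The new default matches the old hard-coded activation exactly.
x1 = torch.randn(16, 64)
assert torch.equal(DEFAULT_ACTIVATION_FN(x1), F.gelu(x1, approximate="tanh"))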
