diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py
new file mode 100644
index 00000000..613ef311
--- /dev/null
+++ b/megablocks/layers/activation_fn.py
@@ -0,0 +1,24 @@
+from typing import Callable
+
+import torch
+import stk
+
+
+def act_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kwargs):
+    assert isinstance(x, stk.Matrix)
+    with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn):
+        if return_grad_fn:
+            x.data.requires_grad = True
+        out = function(x.data, **kwargs)
+        y = stk.Matrix(
+            x.size(),
+            out,
+            x.row_indices,
+            x.column_indices,
+            x.offsets,
+            x.column_indices_t,
+            x.offsets_t,
+            x.block_offsets_t)
+        if return_grad_fn:
+            return y, out.backward
+        return y
diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py
index 6db5c4ab..9b71170b 100644
--- a/megablocks/layers/arguments.py
+++ b/megablocks/layers/arguments.py
@@ -3,6 +3,8 @@
 import megablocks.turbo_util as turbo
 import megablocks.grouped_gemm_util as grouped_gemm
 import torch
+import torch.nn.functional as F
+from functools import partial
 from typing import Callable, Optional, Union
 
 # Type annotation for in-place Tensor initialization function.
@@ -10,6 +12,8 @@
 
 _ALLOWED_BITWIDTHS = (-1, 4, 8)
 
+DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate="tanh")
+
 
 @dataclasses.dataclass
 class Arguments:
@@ -19,12 +23,13 @@ class Arguments:
     num_layers : int = 1
     bias : bool = True
     return_bias : bool = True
+    activation_fn : Optional[Callable] = DEFAULT_ACTIVATION_FN
 
     # MoE arguments.
     moe_num_experts : int = 1
     moe_top_k : int = 1
     moe_capacity_factor : int = 1
-    moe_normalize_expert_weights: Optional[Union[int, float]] = None
+    moe_normalize_expert_weights : Optional[Union[int, float]] = None
     moe_loss_weight : float = 0.1
     moe_jitter_eps : Optional[float] = None
     moe_lbl_in_fp32 : bool = False
diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py
index 0a1fe6b5..b97ebd33 100644
--- a/megablocks/layers/glu.py
+++ b/megablocks/layers/glu.py
@@ -1,12 +1,11 @@
 from megablocks.layers import common
-from megablocks.layers import gelu
+from megablocks.layers.activation_fn import act_fn
 from megablocks.layers.mlp import SparseMLP, create_dmoe_expert_weights
 from megablocks.layers import mpu
 from megablocks.layers.arguments import Arguments, InitFn
 from megablocks import grouped_gemm_util as gg
 import stk
 import torch
-import torch.nn.functional as F
 
 
 class SparseGLU(SparseMLP):
@@ -38,7 +37,8 @@ def forward(self, x, topo):
 
         x1 = stk.ops.sdd(x, w1.t(), topo)
         x2 = stk.ops.sdd(x, v1.t(), topo)
-        x1 = stk.ops.mul(gelu.gelu(x1), x2)
+        activation_fn_out = act_fn(x1, self.args.activation_fn)
+        x1 = stk.ops.mul(activation_fn_out, x2)
         return stk.ops.dsd(x1, w2)
 
 
@@ -56,5 +56,5 @@ def forward(self, x, tokens_per_expert):
         # Compute the MLP.
         x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
         x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True)
-        x1 = F.gelu(x1, approximate="tanh") * x2
+        x1 = self.args.activation_fn(x1) * x2
         return gg.ops.gmm(x1, w2, batch_sizes)
diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py
index 7894b43c..e8e15056 100644
--- a/megablocks/layers/mlp.py
+++ b/megablocks/layers/mlp.py
@@ -1,8 +1,9 @@
 from megablocks.layers import common
 from megablocks.layers import gelu
+from megablocks.layers.activation_fn import act_fn
 from megablocks.layers import mpu
 from megablocks.layers import weight_parallel as wp
-from megablocks.layers.arguments import Arguments, InitFn
+from megablocks.layers.arguments import Arguments, InitFn, DEFAULT_ACTIVATION_FN
 from megablocks import turbo_util as turbo
 from megablocks import grouped_gemm_util as gg
 import stk
@@ -121,9 +122,9 @@ def scale_grad(self, w):
         return scale_gradient(w, self.gradient_scale)
 
     def forward(self, x):
-        return torch.bmm(
-            F.gelu(torch.bmm(x, self.scale_grad(self.w1)), approximate="tanh"),
-            self.scale_grad(self.w2))
+        x = torch.bmm(x, self.scale_grad(self.w1))
+        x = self.args.activation_fn(x)
+        return torch.bmm(x, self.scale_grad(self.w2))
 
 
 def create_dmoe_expert_weights(args : Arguments,
@@ -155,7 +156,7 @@ class MemoryOptimizedMLP(torch.autograd.Function):
 
     @staticmethod
    @torch.cuda.amp.custom_fwd
-    def forward(ctx, x, w1, w2, topo, num_input_bits, num_remat_bits):
+    def forward(ctx, x, w1, w2, topo, num_input_bits, num_remat_bits, activation_fn):
         # x: [m, k], w1: [n, k], w2: [n, k]
         if (not x.is_contiguous() or not w1.is_contiguous() or
             not w2.is_contiguous()):
@@ -177,23 +178,25 @@ def forward(ctx, x, w1, w2, topo, num_input_bits, num_remat_bits):
             x_q, x_scales = turbo.quantize_signed(x, num_bits=num_input_bits)
             input_save_args = (x_q, x_scales)
 
-        # GeLU.
+        # Activation function.
         if num_remat_bits == -1:
-            gelu_out = gelu.gelu(sdd_out)
+            activation_fn_out = act_fn(sdd_out, activation_fn)
             input_save_args += (sdd_out.data,)
         else:
+            if activation_fn is not DEFAULT_ACTIVATION_FN:
+                raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({num_remat_bits=}).')
             # fused GELU into sdd_out buffer while quantizing input
-            hidden_q, hidden_scales, gelu_out_data = turbo.quantize_signed(
+            hidden_q, hidden_scales, activation_fn_out_data = turbo.quantize_signed(
                 sdd_out.data, num_bits=num_remat_bits,
                 op=turbo.ElemwiseOps.GELU_FORWARD, x_forward=sdd_out.data)
-            gelu_out = sdd_out
+            activation_fn_out = sdd_out
             input_save_args += (hidden_q, hidden_scales)
 
         # Layer 1: x @ w2.
-        dsd_out = stk.ops.dsd(gelu_out, w2)
+        dsd_out = stk.ops.dsd(activation_fn_out, w2)
 
-        # NOTE: Save the input to the layer and the gelu input for
-        # gradient computation. We'll re-compute the gelu forward
+        # NOTE: Save the input to the layer and the activation_fn input for
+        # gradient computation. We'll re-compute the activation_fn forward
         # pass in the backward pass to avoid materializing another
         # intermediate.
         ctx.shape = topo.shape
@@ -202,6 +205,7 @@ def forward(ctx, x, w1, w2, topo, num_input_bits, num_remat_bits):
         ctx.x_shape = x.shape
         ctx.sdd_out_shape = sdd_out.data.shape
         ctx.dtype = x.dtype
+        ctx.activation_fn = activation_fn
         ctx.save_for_backward(w1, w2, *topo_tensors, *input_save_args)
         return dsd_out
 
@@ -230,43 +234,51 @@ def backward(ctx, ddsd_out):
         else:
             hidden_q, hidden_scales = saved_tensors[-2:]
 
-        # rematerialize gelu output
+        # rematerialize activation function output
+        activation_fn = ctx.activation_fn
+        activation_grad_fn = None
         if ctx.num_remat_bits == -1:
             sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
-            gelu_out = gelu.gelu(sdd_out)
+            activation_fn_out, activation_grad_fn = act_fn(sdd_out, activation_fn, return_grad_fn=True)
         else:
-            gelu_out_tensor = turbo.dequantize_signed(
+            if activation_fn is not DEFAULT_ACTIVATION_FN:
+                raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({ctx.num_remat_bits=}).')
+            activation_fn_out_tensor = turbo.dequantize_signed(
                 hidden_q, hidden_scales, num_bits=ctx.num_remat_bits,
                 op=turbo.ElemwiseOps.GELU_FORWARD, out_shape=ctx.sdd_out_shape, out_dtype=dtype)
-            gelu_out = stk.Matrix(ctx.shape, gelu_out_tensor, *topo_tensors)
+            activation_fn_out = stk.Matrix(ctx.shape, activation_fn_out_tensor, *topo_tensors)
 
-        # Compute dw2 with recomputed gelu output.
-        dw2 = stk.ops.dsd(gelu_out.t(), ddsd_out)
+        # Compute dw2 with recomputed activation_fn output.
+        dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
 
-        # Compute dgelu_out.
+        # Compute dactivation_fn_out.
         #
-        # NOTE: We reuse the gelu_out allocation.
+        # NOTE: We reuse the activation_fn_out allocation.
+        dactivation_fn_out = activation_fn_out
         stk.backend.triton_kernels.sdd(
             ddsd_out,
             w2.t(),
-            gelu_out.shape,
-            gelu_out.data,
-            gelu_out.offsets,
-            gelu_out.row_indices,
-            gelu_out.column_indices)
-        dgelu_out = gelu_out
+            dactivation_fn_out.shape,
+            dactivation_fn_out.data,
+            dactivation_fn_out.offsets,
+            dactivation_fn_out.row_indices,
+            dactivation_fn_out.column_indices)
 
         # Compute dsdd_out.
         #
-        # NOTE: This reuses the dgelu_out allocation.
+        # NOTE: This reuses the dactivation_fn_out allocation.
         if ctx.num_remat_bits == -1:
-            dsdd_out = gelu.gelu_backward_(dgelu_out, sdd_out)
+            if activation_grad_fn is not None:
+                activation_grad_fn(dactivation_fn_out.data)
+                dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)
+            else:
+                dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
         else:
             # confusingly, x_out is interpreted as the gradient to overwrite
             # in-place when the elemwise op is a backwards op
             ddsd_out_tensor = turbo.dequantize_signed(
                 hidden_q, hidden_scales, num_bits=ctx.num_remat_bits,
-                op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=dgelu_out.data)
+                op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=dactivation_fn_out.data)
             dsdd_out = stk.Matrix(ctx.shape, ddsd_out_tensor, *topo_tensors)
 
         # rematerialize MLP input now that we need it
@@ -294,7 +306,7 @@ def backward(ctx, ddsd_out):
             w1,
             ddsd_out)
         dx = ddsd_out
-        return dx, dw1, dw2, None, None, None
+        return dx, dw1, dw2, None, None, None, None
 
 
 memory_optimized_mlp = MemoryOptimizedMLP.apply
@@ -356,12 +368,15 @@ def parallel_forward(self, x, topo):
         group = self.args.weight_parallel_group
         w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
         if self.args.memory_optimized_mlp:
+            if self.args.activation_fn is not DEFAULT_ACTIVATION_FN:
+                raise NotImplementedError(f'memory_optimized_weight_parallel_mlp not implemented for custom {self.args.activation_fn=}.')
             return wp.memory_optimized_weight_parallel_mlp(
                 x, w1, w2, topo, group)
 
         # Compute the MLP.
         x = wp.sdd_nt(x, w1, topo, group)
-        return wp.dsd_nn(gelu.gelu(x), w2, group)
+        activation_fn_out = act_fn(x, self.args.activation_fn)
+        return wp.dsd_nn(activation_fn_out, w2, group)
 
     def forward(self, x, topo):
         w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
@@ -370,11 +385,12 @@ def forward(self, x, topo):
         elif self.args.memory_optimized_mlp:
             return memory_optimized_mlp(
                 x, w1, w2, topo, self.args.quantize_inputs_num_bits,
-                self.args.quantize_rematerialize_num_bits)
+                self.args.quantize_rematerialize_num_bits, self.args.activation_fn)
 
         # Compute the MLP.
         x = stk.ops.sdd(x, w1.t(), topo)
-        return stk.ops.dsd(gelu.gelu(x), w2)
+        activation_fn_out = act_fn(x, self.args.activation_fn)
+        return stk.ops.dsd(activation_fn_out, w2)
 
 
 class MemoryOptimizedGroupedMLP(torch.autograd.Function):
@@ -382,7 +398,7 @@ class MemoryOptimizedGroupedMLP(torch.autograd.Function):
 
     @staticmethod
     @torch.cuda.amp.custom_fwd
-    def forward(ctx, x, w1, w2, batch_sizes, num_input_bits, num_remat_bits):
+    def forward(ctx, x, w1, w2, batch_sizes, num_input_bits, num_remat_bits, activation_fn):
         # x: [m, k], w1: [n, k], w2: [n, k]
         if (not x.is_contiguous() or not w1.is_contiguous() or
             not w2.is_contiguous()):
@@ -399,21 +415,23 @@ def forward(ctx, x, w1, w2, batch_sizes, num_input_bits, num_remat_bits):
 
         # GeLU.
         if num_remat_bits == -1:
-            gelu_out = F.gelu(sdd_out, approximate="tanh")
+            activation_fn_out = activation_fn(sdd_out)
             input_save_args += (sdd_out,)
         else:
+            if activation_fn is not DEFAULT_ACTIVATION_FN:
+                raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({num_remat_bits=}).')
             # Fused GELU into sdd_out buffer while quantizing input
-            hidden_q, hidden_scales, gelu_out_data = turbo.quantize_signed(
+            hidden_q, hidden_scales, activation_fn_out_data = turbo.quantize_signed(
                 sdd_out, num_bits=num_remat_bits,
                 op=turbo.ElemwiseOps.GELU_FORWARD, x_forward=sdd_out)
-            gelu_out = sdd_out
+            activation_fn_out = sdd_out
             input_save_args += (hidden_q, hidden_scales)
 
         # Layer 1: x @ w2.
-        dsd_out = gg.backend.gmm(gelu_out, w2, batch_sizes)
+        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
 
-        # NOTE: Save the input to the layer and the gelu input for
-        # gradient computation. We'll re-compute the gelu forward
+        # NOTE: Save the input to the layer and the activation_fn input for
+        # gradient computation. We'll re-compute the activation_fn forward
         # pass in the backward pass to avoid materializing another
         # intermediate.
         ctx.num_input_bits = num_input_bits
@@ -421,6 +439,7 @@ def forward(ctx, x, w1, w2, batch_sizes, num_input_bits, num_remat_bits):
         ctx.x_shape = x.shape
         ctx.sdd_out_shape = sdd_out.shape
         ctx.dtype = x.dtype
+        ctx.activation_fn = activation_fn
         ctx.save_for_backward(w1, w2, batch_sizes, *input_save_args)
         return dsd_out
 
@@ -451,37 +470,48 @@ def backward(ctx, ddsd_out):
         else:
             hidden_q, hidden_scales = saved_tensors[-2:]
 
-        # Rematerialize gelu output.
+        # Rematerialize activation_fn output.
+        activation_fn = ctx.activation_fn
+        activation_grad_fn = None
         if ctx.num_remat_bits == -1:
-            gelu_out = F.gelu(sdd_out, approximate="tanh")
+            with torch.set_grad_enabled(True):
+                sdd_out.requires_grad = True
+                activation_fn_out = activation_fn(sdd_out)
+                activation_grad_fn = activation_fn_out.backward
         else:
-            gelu_out = turbo.dequantize_signed(
+            if activation_fn is not DEFAULT_ACTIVATION_FN:
+                raise NotImplementedError(f'`num_remat_bits` != -1 not implemented for custom {activation_fn=} ({ctx.num_remat_bits=}).')
+            activation_fn_out = turbo.dequantize_signed(
                 hidden_q, hidden_scales, num_bits=ctx.num_remat_bits,
                 op=turbo.ElemwiseOps.GELU_FORWARD,
                 out_shape=ctx.sdd_out_shape, out_dtype=dtype)
 
-        # Compute dw2 with recomputed gelu output.
+        # Compute dw2 with recomputed activation_fn output.
         dw2 = gg.backend.gmm(
-            gelu_out, ddsd_out, batch_sizes, trans_a=True)
+            activation_fn_out, ddsd_out, batch_sizes, trans_a=True)
 
-        # Compute dgelu_out.
+        # Compute dactivation_fn_out.
         #
-        # NOTE: We reuse the gelu_out allocation.
+        # NOTE: We reuse the activation_fn_out allocation.
+        dactivation_fn_out = activation_fn_out
         gg.backend.gmm(
-            ddsd_out, w2, batch_sizes, trans_b=True, c=gelu_out)
-        dgelu_out = gelu_out
+            ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out)
 
         # Compute dsdd_out.
         #
-        # NOTE: This reuses the dgelu_out allocation.
+        # NOTE: This reuses the dactivation_fn_out allocation.
         if ctx.num_remat_bits == -1:
-            dsdd_out = gelu.gelu_backward_(dgelu_out, sdd_out)
+            if activation_grad_fn is not None:
+                activation_grad_fn(dactivation_fn_out)
+                dsdd_out = sdd_out.grad
+            else:
+                dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
         else:
             # confusingly, x_out is interpreted as the gradient to overwrite
             # in-place when the elemwise op is a backwards op
             dsdd_out = turbo.dequantize_signed(
                 hidden_q, hidden_scales, num_bits=ctx.num_remat_bits,
-                op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=dgelu_out.data)
+                op=turbo.ElemwiseOps.GELU_BACKWARD, x_out=dactivation_fn_out.data)
 
         # rematerialize MLP input now that we need it
         if ctx.num_input_bits != -1:
@@ -497,7 +527,7 @@ def backward(ctx, ddsd_out):
         # NOTE: This reuses the ddsd_out allocation.
         gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
         dx = ddsd_out
-        return dx, dw1, dw2, None, None, None
+        return dx, dw1, dw2, None, None, None, None
 
 
 memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply
@@ -514,16 +544,17 @@ def forward(self, x, tokens_per_expert):
         w2 = w2.view(ne, -1, self.args.hidden_size)
 
         if self.args.moe_weight_parallelism:
-            raise ValueError(
+            raise NotImplementedError(
                 "Weight parallelism not yet supported with GroupedMLP.")
 
         if self.args.memory_optimized_mlp:
             return memory_optimized_grouped_mlp(
                 x, w1, w2, batch_sizes,
                 self.args.quantize_inputs_num_bits,
-                self.args.quantize_rematerialize_num_bits)
+                self.args.quantize_rematerialize_num_bits,
+                self.args.activation_fn)
 
         # Compute the MLP.
         x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
-        x = F.gelu(x, approximate="tanh")
+        x = self.args.activation_fn(x)
         return gg.ops.gmm(x, w2, batch_sizes)
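
Usage sketch (illustrative, not part of the diff): any Tensor -> Tensor callable can now be passed through `Arguments.activation_fn`; the dense, grouped, and sparse paths above call it either directly or via `act_fn`, while the memory-optimized and quantized-rematerialization paths still require the default tanh-approximated GELU (see the NotImplementedError guards). The snippet below is a minimal sketch assuming the remaining `Arguments` fields (e.g. `ffn_hidden_size`) keep the defaults defined elsewhere in arguments.py; the activations shown are examples only.

    from functools import partial

    import torch
    import torch.nn.functional as F

    from megablocks.layers.arguments import Arguments

    # Bind keyword arguments ahead of time with partial(), mirroring
    # DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate="tanh").
    args = Arguments(
        hidden_size=1024,
        ffn_hidden_size=4096,
        moe_num_experts=8,
        moe_top_k=2,
        activation_fn=F.silu,  # or partial(F.gelu, approximate="none"), F.relu, ...
    )

    # The forward passes above reduce to calls of this form on the expert
    # hidden states (for SparseGLU, the result is then multiplied by x2).
    h = torch.randn(16, 4096)
    out = args.activation_fn(h)
    assert out.shape == h.shape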