From c97bf026540f6a6d3d6ff3a5fba6974a61f4532a Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Wed, 31 Jul 2024 17:02:15 +0000 Subject: [PATCH 01/43] add GA yaml --- .github/workflows/code-quality.yaml | 42 +++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) create mode 100644 .github/workflows/code-quality.yaml diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml new file mode 100644 index 0000000..062aa41 --- /dev/null +++ b/.github/workflows/code-quality.yaml @@ -0,0 +1,42 @@ +name: Code Quality Checks +on: + push: + branches: + - main + - release/** + pull_request: + branches: + - main + - release/** + workflow_call: + workflow_dispatch: +# Cancel old runs when a new commit is pushed to the same branch if not on main or dev +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} +defaults: + run: + working-directory: . +jobs: + code-quality: + runs-on: linux-ubuntu-latest + timeout-minutes: 30 + strategy: + matrix: + python_version: + - "3.9" + - "3.10" + pip_deps: + - "[dev]" + steps: + - uses: actions/checkout@v3 + - name: Get composite run steps repository + uses: actions/checkout@v3 + with: + repository: mosaicml/ci-testing + ref: v0.0.9 + path: ./ci-testing + - uses: ./ci-testing/.github/actions/code-quality + with: + python_version: ${{ matrix.python_version }} + pip_deps: ${{ matrix.pip_deps }} From ddef812266271bb0b8ca996c6cc1740fbdd06907 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Wed, 31 Jul 2024 17:18:22 +0000 Subject: [PATCH 02/43] apply ruff --- megablocks/__init__.py | 2 -- megablocks/backend/kernels.py | 14 ++++---- megablocks/layers/glu.py | 19 ++++++---- megablocks/layers/memory_test.py | 5 +-- megablocks/layers/mlp.py | 26 +++++++------- megablocks/layers/moe.py | 14 +++----- megablocks/ops/__init__.py | 15 -------- megablocks/ops/all_to_all_benchmark.py | 10 +++--- megablocks/ops/matmul_benchmark.py | 42 ++++++++++++++-------- megablocks/ops/padded_scatter_benchmark.py | 11 +++--- megablocks/ops/permute_benchmark.py | 24 +++++++------ megablocks/ops/repeat.py | 1 - megablocks/ops/sum.py | 1 - tests/layers/glu_test.py | 1 - 14 files changed, 89 insertions(+), 96 deletions(-) diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 90e4511..e69de29 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,2 +0,0 @@ -import megablocks.layers.dmoe -import megablocks.layers.moe diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index f99f93c..09dbfc9 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -62,12 +62,12 @@ def _padded_copy( # Now we know what bin we're assigned to, but we need to know how # many threadblocks were assigned to earlier bins so we can offset # in our bin properly. - offset_in_bin = tl.program_id(0); + offset_in_bin = tl.program_id(0) if bin_idx > 0: offset_in_bin -= tl.load(bins + bin_idx - 1) # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin; + index_b = offset_in_bin if bin_idx > 0: index_b += tl.load(padded_bins + bin_idx - 1) @@ -247,12 +247,12 @@ def _padded_copy_wgrad( # Now we know what bin we're assigned to, but we need to know how # many threadblocks were assigned to earlier bins so we can offset # in our bin properly. - offset_in_bin = tl.program_id(0); + offset_in_bin = tl.program_id(0) if bin_idx > 0: offset_in_bin -= tl.load(bins + bin_idx - 1) # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin; + index_x = offset_in_bin if bin_idx > 0: index_x += tl.load(padded_bins + bin_idx - 1) @@ -264,7 +264,7 @@ def _padded_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for i in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -380,7 +380,7 @@ def _binned_copy( optr = b if A_TO_B else a iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for i in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -510,7 +510,7 @@ def _binned_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for i in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index cc6931a..828f10f 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,12 +1,17 @@ -from megablocks.layers import common -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.mlp import SparseMLP, SharedMLP, create_dmoe_expert_weights, resolve_dtensor -from megablocks.layers import mpu -from megablocks.layers.arguments import Arguments, DEFAULT_ACTIVATION_FN -from megablocks import grouped_gemm_util as gg import stk import torch +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + class SparseGLU(SparseMLP): @@ -91,7 +96,7 @@ def backward(ctx, ddsd_out): raise ValueError("Expected all MLP inputs to need grad.") # Unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, v1, w2 = saved_tensors[:3] batch_sizes = saved_tensors[3] diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index e314272..f2e5e49 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -1,9 +1,6 @@ -import functools import gc -from megablocks.layers import dmoe, arguments, mpu -from megablocks import benchmark_util -import numpy as np +from megablocks.layers import dmoe, arguments import torch _TESTS = ( diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 2bb1e3b..f18a824 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -1,16 +1,14 @@ -from packaging import version from typing import Any -from megablocks.layers import common -from megablocks.layers import gelu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers import mpu -from megablocks.layers import weight_parallel as wp -from megablocks.layers.arguments import Arguments, InitFn, DEFAULT_ACTIVATION_FN -from megablocks import grouped_gemm_util as gg import stk import torch -import torch.nn.functional as F +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers import weight_parallel as wp +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): @@ -85,7 +83,7 @@ class MLP(torch.nn.Module): def __init__(self, args : Arguments): super().__init__() self.args = args - expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) experts_per_rank = mpu.experts_per_rank(args) self.w1 = torch.nn.Parameter(torch.empty( @@ -216,7 +214,7 @@ def backward(ctx, ddsd_out): raise ValueError("Expected all MLP inputs to need grad.") # unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, w2 = saved_tensors[:2] topo_tensors = saved_tensors[2:8] @@ -335,7 +333,9 @@ def parallel_forward(self, x, topo): w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) if self.args.memory_optimized_mlp: if self.args.activation_fn is not DEFAULT_ACTIVATION_FN: - raise NotImplementedError(f'memory_optimized_weight_parallel_mlp not implemented for custom {activation_fn=}.') + raise NotImplementedError( + f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.' + ) return wp.memory_optimized_weight_parallel_mlp( x, w1, w2, topo, group) @@ -404,7 +404,7 @@ def backward(ctx, ddsd_out): raise ValueError("Expected all MLP inputs to need grad.") # Unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, w2 = saved_tensors[:2] batch_sizes = saved_tensors[2] diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 9d26da2..021c00f 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,14 +1,10 @@ -from megablocks.layers import common -from megablocks.layers import mpu -from megablocks.layers import router -from megablocks.layers import mlp -from megablocks.layers import sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments -import megablocks.ops as ops import numpy as np import torch +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments _LOAD_BALANCING_LOSS = [] @@ -113,7 +109,7 @@ def __init__(self, args : Arguments): # Calculate the number of experts in total and the number of experts # owned by this rank. - world_size = mpu.get_expert_parallel_world_size(args) + # world_size = mpu.get_expert_parallel_world_size(args) self.num_experts = args.moe_num_experts self.top_k = self.args.moe_top_k diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index 44a2909..e69de29 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -1,15 +0,0 @@ -from megablocks.ops.binned_gather import binned_gather -from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum -from megablocks.ops.cumsum import inclusive_cumsum -from megablocks.ops.gather import gather -from megablocks.ops.histogram import histogram -from megablocks.ops.padded_gather import padded_gather -from megablocks.ops.padded_scatter import padded_scatter -from megablocks.ops.repeat import repeat -from megablocks.ops.replicate import replicate -from megablocks.ops.round_up import round_up -from megablocks.ops.scatter import scatter -from megablocks.ops.sort import sort -from megablocks.ops.sum import sum -from megablocks.ops.topology import topology diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index d3fbcf3..ccae8f3 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -1,7 +1,8 @@ -from megablocks.layers.all_to_all import all_to_all -from megablocks import benchmark_util import torch +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + _ALL_TO_ALL_BENCHMARK = ( (8, 1024), (16, 1024), @@ -35,8 +36,9 @@ def benchmark_all_to_all(group, sl, hs): "message_size (B)": send_recv_sizes[0] * hs * 2, # 2B elements. } - fn = lambda: all_to_all(x, send_recv_sizes, send_recv_sizes, group) - time, std = benchmark_util.benchmark_function(fn) + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) + time, std = benchmark_util.benchmark_function(benchmark) if torch.distributed.get_rank(group) == 0: benchmark_util.log_benchmark("All-To-All", details, time, std) diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 632155c..6016bd5 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -1,10 +1,10 @@ import unittest -from absl.testing import parameterized -from megablocks import benchmark_util -from megablocks import ops import stk import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls @@ -96,7 +96,8 @@ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) w = transpose_view(w) - benchmark = lambda: stk.ops.sdd(x, w, topo) + def benchmark(): + return stk.ops.sdd(x, w, topo) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -113,7 +114,8 @@ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - benchmark = lambda: stk.ops.dsd(topo, w) + def benchmark(): + return stk.ops.dsd(topo, w) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -130,7 +132,8 @@ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) topo = topo.t() - benchmark = lambda: stk.ops.dsd(topo, x) + def benchmark(): + return stk.ops.dsd(topo, x) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -147,7 +150,8 @@ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - benchmark = lambda: stk.ops.dsd(x, w) + def benchmark(): + return stk.ops.dsd(x, w) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -166,7 +170,8 @@ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): out = stk.ops.dsd(x, w) w = transpose_view(w) - benchmark = lambda: stk.ops.sdd(out, w, x) + def benchmark(): + return stk.ops.sdd(out, w, x) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -185,7 +190,8 @@ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): out = stk.ops.dsd(x, w) x = x.t() - benchmark = lambda: stk.ops.dsd(x, out) + def benchmark(): + return stk.ops.dsd(x, out) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -205,7 +211,8 @@ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): w = w.transpose(1, 2).contiguous() w = w.transpose(1, 2) - benchmark = lambda: torch.bmm(x, w) + def benchmark(): + return torch.bmm(x, w) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -224,7 +231,8 @@ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) w = w.transpose(1, 2).contiguous() - benchmark = lambda: torch.bmm(out, w) + def benchmark(): + return torch.bmm(out, w) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -243,7 +251,8 @@ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) out = out.transpose(1, 2) - benchmark = lambda: torch.bmm(out, x) + def benchmark(): + return torch.bmm(out, x) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -260,7 +269,8 @@ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): x = torch.randn((ne, sl // ne, fhs)).cuda().half() w = torch.randn((ne, fhs, hs)).cuda().half() - benchmark = lambda: torch.bmm(x, w) + def benchmark(): + return torch.bmm(x, w) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -279,7 +289,8 @@ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): out = torch.bmm(x, w) w = torch.transpose(w, 1, 2) - benchmark = lambda: torch.bmm(out, w) + def benchmark(): + return torch.bmm(out, w) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -298,7 +309,8 @@ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) x = torch.transpose(x, 1, 2) - benchmark = lambda: torch.bmm(x, out) + def benchmark(): + return torch.bmm(x, out) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index 7a7c337..87e17ed 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -1,10 +1,9 @@ import unittest -from absl.testing import parameterized -from megablocks import ops -from megablocks import benchmark_util import torch +from absl.testing import parameterized +from megablocks import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. @@ -35,10 +34,10 @@ def testPaddedScatter(self, sl, hs, ne, top_k): # Gather the data to prepare for backwards. x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - fn = lambda: ops.padded_scatter( - x, indices, bin_ids, weights, bins, padded_bins, top_k) + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k) - time, std = benchmark_util.benchmark_function(fn) + time, std = benchmark_util.benchmark_function(benchmark) benchmark_util.log_benchmark( "Padded Scatter", {"sequence_length": sl, diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index fb5b7f1..f727a03 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -1,12 +1,9 @@ import unittest -from absl.testing import parameterized -from megablocks import benchmark_util -from megablocks import ops -import numpy as np -import stk import torch +from absl.testing import parameterized +from megablocks import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), @@ -40,7 +37,8 @@ def testBinnedGather(self, sl, hs, ne): tokens_per_expert = ops.histogram(indices, ne) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - benchmark = lambda: ops.binned_gather(x, indices, bins, ec) + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -62,7 +60,8 @@ def testBinnedScatter(self, sl, hs, ne): bins = ops.inclusive_cumsum(tokens_per_expert, 0) x = ops.binned_gather(x, indices, bins, ec) - benchmark = lambda: ops.binned_scatter(x, indices, bins) + def benchmark(): + return ops.binned_scatter(x, indices, bins) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -84,7 +83,8 @@ def testPaddedGather(self, sl, hs, ne): padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - benchmark = lambda: ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -107,7 +107,8 @@ def testPaddedScatter(self, sl, hs, ne): bins = ops.inclusive_cumsum(tokens_per_expert, 0) x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - benchmark = lambda: ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -119,13 +120,14 @@ def testPaddedScatter(self, sl, hs, ne): @parameterized.parameters(*_PERMUTE_TESTS) def testCopy(self, sl, hs, ne): # NOTE: Capacity factor == 1. - ec = sl // ne + # ec = sl // ne # Create the data and indices. x = torch.randn((sl, hs)).cuda().half() y = x.clone() - benchmark = lambda: y.copy_(x) + def benchmark(): + return y.copy_(x) mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index d02c956..8fdb5fa 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,4 +1,3 @@ -import torch def repeat(x, tiling): diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index 9d550b5..5b511e6 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,4 +1,3 @@ -import torch def sum(x, dim=0): diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index 0487ec8..b09cff3 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -6,7 +6,6 @@ from megablocks.layers import dmlp_registry, testing from megablocks.layers.arguments import Arguments -from megablocks.layers.glu import GroupedGLU, SparseGLU _DENSE_TESTS = ( (16, 1024, 512), From d5502ab5963fc8ee2d75ecd7de701905d81ff7a9 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Wed, 31 Jul 2024 17:31:32 +0000 Subject: [PATCH 03/43] fix init files --- megablocks/__init__.py | 8 ++++++++ megablocks/backend/__init__.py | 1 - megablocks/layers/__init__.py | 6 ++++++ megablocks/ops/__init__.py | 32 ++++++++++++++++++++++++++++++++ 4 files changed, 46 insertions(+), 1 deletion(-) diff --git a/megablocks/__init__.py b/megablocks/__init__.py index e69de29..0c6bd07 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -0,0 +1,8 @@ +from megablocks.layers import dmoe, moe + +"""Key classes are available directly in the ``MegaBlocks`` namespace.""" + +__all__ = [ + 'dmoe', + 'moe', +] diff --git a/megablocks/backend/__init__.py b/megablocks/backend/__init__.py index 8b13789..e69de29 100644 --- a/megablocks/backend/__init__.py +++ b/megablocks/backend/__init__.py @@ -1 +0,0 @@ - diff --git a/megablocks/layers/__init__.py b/megablocks/layers/__init__.py index 8b13789..da9bec7 100644 --- a/megablocks/layers/__init__.py +++ b/megablocks/layers/__init__.py @@ -1 +1,7 @@ +from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE +__all__ = [ + 'MoE', + 'dMoE', +] diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index e69de29..222784e 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -0,0 +1,32 @@ +from megablocks.ops.binned_gather import binned_gather +from megablocks.ops.binned_scatter import binned_scatter +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum +from megablocks.ops.gather import gather +from megablocks.ops.histogram import histogram +from megablocks.ops.padded_gather import padded_gather +from megablocks.ops.padded_scatter import padded_scatter +from megablocks.ops.repeat import repeat +from megablocks.ops.replicate import replicate +from megablocks.ops.round_up import round_up +from megablocks.ops.scatter import scatter +from megablocks.ops.sort import sort +from megablocks.ops.sum import sum +from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] From 243afa9428488c87a6b8b8a448d240213642f4d0 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 01:58:56 +0000 Subject: [PATCH 04/43] add ruff to pre-commit --- .pre-commit-config.yaml | 8 ++++++++ 1 file changed, 8 insertions(+) create mode 100644 .pre-commit-config.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..975cbf9 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,8 @@ +default_language_version: + python: python3 +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.2 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] From 32ef2e6be2831a13fd61fc65d52dcd147da1b26d Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 02:09:36 +0000 Subject: [PATCH 05/43] add ruff settings to pyproject.toml --- megablocks/layers/dmoe.py | 2 +- megablocks/layers/memory_test.py | 5 +++-- megablocks/layers/mlp.py | 2 +- megablocks/layers/moe.py | 12 ++++++------ megablocks/ops/histogram_benchmark.py | 4 ++-- megablocks/ops/matmul_benchmark.py | 24 ++++++++++++------------ megablocks/ops/permute_benchmark.py | 12 ++++++------ megablocks/ops/repeat.py | 2 +- megablocks/ops/sort_benchmark.py | 4 ++-- pyproject.toml | 27 +++++++++++++++++++++++++++ setup.py | 4 ++-- tests/conftest.py | 2 +- tests/ops/sort_test.py | 2 +- 13 files changed, 65 insertions(+), 37 deletions(-) diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 04a538d..c1de9a8 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -95,7 +95,7 @@ def topology(self, x, padded_bins): device='meta') shape = ( padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args) + self.ffn_hidden_size * mpu.experts_per_rank(self.args), ) row_indices = stk.ops.row_indices( shape, data, offsets, column_indices) diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index f2e5e49..9b333b3 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -1,8 +1,9 @@ import gc -from megablocks.layers import dmoe, arguments import torch +from megablocks.layers import arguments, dmoe + _TESTS = ( (8, 2048, 4096, 4096, 32, 4), @@ -10,7 +11,7 @@ def get_tensors(): - ptrs = set([]) + ptrs = set() out = [] for obj in gc.get_objects(): if torch.is_tensor(obj): diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index f18a824..134dd73 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -334,7 +334,7 @@ def parallel_forward(self, x, topo): if self.args.memory_optimized_mlp: if self.args.activation_fn is not DEFAULT_ACTIVATION_FN: raise NotImplementedError( - f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.' + f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.', ) return wp.memory_optimized_weight_parallel_mlp( x, w1, w2, topo, group) diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 021c00f..6a31bfd 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -54,16 +54,16 @@ def batched_load_balancing_loss(args : Arguments): f" = {args.num_layers_per_virtual_pipeline_stage}") # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all([ + assert all(( x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert - ]) + )) tokens = expert_scores[0].shape[0] - assert all([ + assert all(( (x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores - ]) + )) # Concatenate the contributions of each layer and convert to @@ -345,7 +345,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): torch.arange( self.num_experts * mpu.hidden_sharding_degree(self.args), dtype=torch.int32, - device=indices.device + device=indices.device, ), mpu.experts_per_rank(self.args), ) @@ -406,7 +406,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): shape = ( mpu.hidden_sharding_degree(self.args), -1, - self.args.hidden_size + self.args.hidden_size, ) x = ops.sum(x.view(shape), dim=0) diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index 9e0e930..ff4689a 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -54,7 +54,7 @@ def testHistogram(self, n, dtype, max_val): arguments = { "n": n, "dtype": dtype, - "max_val": max_val + "max_val": max_val, } log_benchmark(arguments, mean_t, std_t) @@ -67,7 +67,7 @@ def testTorchHistogram(self, n, dtype, max_val): arguments = { "n": n, "dtype": dtype, - "max_val": max_val + "max_val": max_val, } log_benchmark(arguments, mean_t, std_t) diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 6016bd5..6fe4626 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -103,7 +103,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("0::Fwd::SDD::NT", arguments, mean_t, std_t, x.numel() * fhs * 2) @@ -121,7 +121,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("0::GradX::DSD::NN", arguments, mean_t, std_t, x.numel() * fhs * 2) @@ -139,7 +139,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("0::GradW::DSD::TN", arguments, mean_t, std_t, x.numel() * fhs * 2) @@ -157,7 +157,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("1::Fwd::DSD::NN", arguments, mean_t, std_t, x.nnz * hs * 2) @@ -177,7 +177,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("1::GradX::SDD::NT", arguments, mean_t, std_t, x.nnz * hs * 2) @@ -197,7 +197,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("1::GradW::DSD::TN", arguments, mean_t, std_t, x.nnz * hs * 2) @@ -218,7 +218,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("0::Fwd:DDD::NT", arguments, mean_t, std_t, x.numel() * fhs * 2) @@ -238,7 +238,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("0:GradX:DDD::NN", arguments, mean_t, std_t, x.numel() * fhs * 2) @@ -258,7 +258,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("0:GradW:DDD::TN", arguments, mean_t, std_t, x.numel() * fhs * 2) @@ -276,7 +276,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("1::Fwd::DDD::NN", arguments, mean_t, std_t, x.numel() * hs * 2) @@ -296,7 +296,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("1::GradX::DDD::NT", arguments, mean_t, std_t, x.numel() * hs * 2) @@ -316,7 +316,7 @@ def benchmark(): "sequence_length": sl, "hidden_size": hs, "ffn_hidden_size": fhs, - "num_experts": ne + "num_experts": ne, } log_benchmark("1::GradW::DDD::TN", arguments, mean_t, std_t, x.numel() * hs * 2) diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index f727a03..a93dc27 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -19,7 +19,7 @@ (16384 * 8, 768, 16), (16384 * 8, 768, 32), (16384 * 8, 768, 64), - (16384 * 8, 768, 128) + (16384 * 8, 768, 128), ) @@ -43,7 +43,7 @@ def benchmark(): arguments = { "sequence_length": sl, "hidden_size": hs, - "num_experts": ne + "num_experts": ne, } benchmark_util.log_benchmark("BinnedGather", arguments, mean_t, std_t) @@ -66,7 +66,7 @@ def benchmark(): arguments = { "sequence_length": sl, "hidden_size": hs, - "num_experts": ne + "num_experts": ne, } benchmark_util.log_benchmark("BinnedScatter", arguments, mean_t, std_t) @@ -89,7 +89,7 @@ def benchmark(): arguments = { "sequence_length": sl, "hidden_size": hs, - "num_experts": ne + "num_experts": ne, } benchmark_util.log_benchmark("PaddedGather", arguments, mean_t, std_t) @@ -113,7 +113,7 @@ def benchmark(): arguments = { "sequence_length": sl, "hidden_size": hs, - "num_experts": ne + "num_experts": ne, } benchmark_util.log_benchmark("PaddedScatter", arguments, mean_t, std_t) @@ -132,7 +132,7 @@ def benchmark(): arguments = { "sequence_length": sl, "hidden_size": hs, - "num_experts": ne + "num_experts": ne, } benchmark_util.log_benchmark("Copy", arguments, mean_t, std_t) diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index 8fdb5fa..66e4b10 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,6 +1,6 @@ def repeat(x, tiling): - if all([t == 1 for t in tiling]): + if all((t == 1 for t in tiling)): return x return x.repeat(*tiling) diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 4305767..56e343f 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -21,7 +21,7 @@ def numpy_dtype(dtype): types = { torch.int16: np.int16, torch.int32: np.int32, - torch.int64: np.int64 + torch.int64: np.int64, } return types[dtype] @@ -66,7 +66,7 @@ def testSort(self, n, dtype, max_val): arguments = { "n": n, "dtype": dtype, - "max_val": max_val + "max_val": max_val, } log_benchmark(arguments, mean_t, std_t) diff --git a/pyproject.toml b/pyproject.toml index b4b90ec..2cf40c5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,3 +44,30 @@ concurrency = ["thread"] include = [ "megablocks/*" ] + + +# Ruff global +[tool.ruff] +exclude = [ + "build/**", + "docs/**", + "node_modules/**", +] + +# Ruff linter +[tool.ruff.lint] +select = [ + "C4", # flake8-comprehensions + # TODO port pydocstyle + # "D", # pydocstyle + "LOG", + "PERF", + "PLE", + "COM812", +] + +ignore = [ + "C408", + "PERF2", + "PERF4", +] diff --git a/setup.py b/setup.py index ac1b43f..646a17b 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ ] if device_capability: nvcc_flags.append( - f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}" + f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}", ) ext_modules = [ @@ -26,7 +26,7 @@ ["csrc/ops.cu"], include_dirs=["csrc"], extra_compile_args={"cxx": ["-fopenmp"], "nvcc": nvcc_flags}, - ) + ), ] install_requires = [ diff --git a/tests/conftest.py b/tests/conftest.py index 335140c..328c712 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ # Add the path of any pytest fixture files you want to make global pytest_plugins = [ 'tests.fixtures.autouse', - 'tests.fixtures.fixtures' + 'tests.fixtures.fixtures', ] diff --git a/tests/ops/sort_test.py b/tests/ops/sort_test.py index e07f2e1..243aef1 100644 --- a/tests/ops/sort_test.py +++ b/tests/ops/sort_test.py @@ -36,7 +36,7 @@ def torch_to_numpy_dtype( types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = { torch.int16: np.int16, torch.int32: np.int32, - torch.int64: np.int64 + torch.int64: np.int64, } return types[dtype] From 36dadc33fca562d9504e659d85f93e8532ed6456 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 02:17:50 +0000 Subject: [PATCH 06/43] yapf --- .pre-commit-config.yaml | 12 + megablocks/__init__.py | 1 - megablocks/backend/kernels.py | 162 ++++----- megablocks/benchmark_util.py | 6 +- megablocks/grouped_gemm_util.py | 6 +- megablocks/layers/activation_fn.py | 10 +- megablocks/layers/all_to_all.py | 30 +- megablocks/layers/arguments.py | 62 ++-- megablocks/layers/common.py | 3 +- megablocks/layers/dmlp_registry.py | 17 +- megablocks/layers/dmoe.py | 180 ++++----- megablocks/layers/gelu.py | 12 +- megablocks/layers/glu.py | 106 ++++-- megablocks/layers/memory_test.py | 62 ++-- megablocks/layers/mlp.py | 283 ++++++++++----- megablocks/layers/moe.py | 181 ++++----- megablocks/layers/mpu.py | 53 +-- megablocks/layers/router.py | 22 +- megablocks/layers/sharedexpert_registry.py | 4 +- megablocks/layers/testing.py | 53 ++- megablocks/layers/weight_parallel.py | 131 +++++-- megablocks/ops/all_to_all_benchmark.py | 28 +- megablocks/ops/binned_gather.py | 3 + megablocks/ops/binned_scatter.py | 14 +- megablocks/ops/cumsum.py | 6 + megablocks/ops/gather.py | 8 +- megablocks/ops/histogram.py | 3 + megablocks/ops/histogram_benchmark.py | 12 +- megablocks/ops/matmul_benchmark.py | 152 ++++++-- megablocks/ops/padded_gather.py | 20 +- megablocks/ops/padded_scatter.py | 50 ++- megablocks/ops/padded_scatter_benchmark.py | 23 +- megablocks/ops/permute_benchmark.py | 5 + megablocks/ops/repeat.py | 2 - megablocks/ops/replicate.py | 17 +- megablocks/ops/scatter.py | 23 +- megablocks/ops/sort.py | 4 +- megablocks/ops/sort_benchmark.py | 19 +- megablocks/ops/sum.py | 2 - megablocks/ops/topology.py | 35 +- pyproject.toml | 403 +++++++++++++++++++++ setup.py | 7 +- tests/conftest.py | 18 +- tests/fixtures/autouse.py | 9 +- tests/layers/dmoe_test.py | 145 ++++---- tests/layers/glu_test.py | 22 +- tests/layers/moe_test.py | 40 +- tests/layers/parallelism_test.py | 61 ++-- tests/ops/binned_gather_test.py | 9 +- tests/ops/binned_scatter_test.py | 15 +- tests/ops/histogram_test.py | 5 +- tests/ops/padded_gather_test.py | 11 +- tests/ops/padded_scatter_test.py | 35 +- tests/ops/sort_test.py | 3 +- tests/ops/topology_test.py | 24 +- 55 files changed, 1815 insertions(+), 814 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 975cbf9..d991b58 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -6,3 +6,15 @@ repos: hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] +- repo: https://github.com/google/yapf + rev: v0.32.0 + hooks: + - id: yapf + name: yapf + description: A formatter for Python files. + entry: yapf + args: [-i, -vv, -p] # inplace + language: python + types: [python] + additional_dependencies: + - toml diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 0c6bd07..1e527a4 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,5 +1,4 @@ from megablocks.layers import dmoe, moe - """Key classes are available directly in the ``MegaBlocks`` namespace.""" __all__ = [ diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index 09dbfc9..a1668eb 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -19,7 +19,9 @@ def assert_is_vector(x): def assert_equal(a, b): if a != b: - raise ValueError(f"Expected dimensions to be equal but got {a} and {b}.") + raise ValueError( + f"Expected dimensions to be equal but got {a} and {b}.", + ) # a: (tokens, hidden_size), real. @@ -40,18 +42,19 @@ def assert_equal(a, b): ) @triton.jit def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr, - A_TO_B : tl.constexpr, - SCALE : tl.constexpr): + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): # Our index into array 'a'. index_a = tl.load(indices + tl.program_id(0)) @@ -116,10 +119,7 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. output_rows = padded_bins[-1].cpu().item() - out = torch.zeros( - (output_rows, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, out, @@ -131,7 +131,8 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -150,10 +151,7 @@ def gather(x, indices, bin_ids, weights, bins, top_k): # NOTE: There is no padding so the output rows equals the # input rows multiplied by top_k. output_rows = x.shape[0] * top_k - out = torch.empty( - (output_rows, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, out, @@ -165,7 +163,8 @@ def gather(x, indices, bin_ids, weights, bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -183,10 +182,9 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): assert_equal(indices.shape[0], weights.shape[0]) tokens = indices.shape[0] // top_k - out = torch.empty( - (tokens, top_k, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.empty((tokens, top_k, x.shape[1]), + dtype=x.dtype, + device=x.device) _padded_copy[(indices.shape[0],)]( out, x, @@ -198,7 +196,8 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=False, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) # Reduce along the top-k dimension, if needed. return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) @@ -227,16 +226,17 @@ def scatter(x, indices, bin_ids, weights, bins, top_k): ) @triton.jit def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr): + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): # Our index into 'tokens * top_k'. index_out = tl.load(indices + tl.program_id(0)) @@ -288,10 +288,7 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): assert_equal(bins.size(), padded_bins.size()) tokens = indices.shape[0] // top_k - out = torch.empty( - (tokens * top_k), - dtype=x.dtype, - device=x.device) + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) _padded_copy_wgrad[(indices.shape[0],)]( x, grad, @@ -301,7 +298,8 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): bins, padded_bins, NUM_COLUMNS=x.shape[1], - TOP_K=top_k) + TOP_K=top_k, + ) return out @@ -326,18 +324,19 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): ) @triton.jit def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr, - A_TO_B : tl.constexpr, - SCALE : tl.constexpr): + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): # Load our indices into the output. expert_idx = tl.program_id(0) entry_idx = tl.program_id(1) @@ -349,7 +348,7 @@ def _binned_copy( # the number of tokens assigned to our expert. start = 0 if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) + start = tl.load(bins + expert_idx - 1) end = tl.load(bins + expert_idx) num_tokens = end - start @@ -401,10 +400,9 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): assert_equal(weights.shape[0], x.shape[0] * top_k) num_experts = bins.shape[0] - out = torch.zeros( - (num_experts, expert_capacity, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), + dtype=x.dtype, + device=x.device) _binned_copy[(num_experts, expert_capacity)]( x, @@ -417,7 +415,8 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -433,10 +432,9 @@ def binned_scatter(x, indices, weights, bins, top_k): num_experts, expert_capacity, hidden_size = x.shape tokens = indices.shape[0] // top_k - out = torch.zeros( - (tokens, top_k, hidden_size), - dtype=x.dtype, - device=x.device) + out = torch.zeros((tokens, top_k, hidden_size), + dtype=x.dtype, + device=x.device) _binned_copy[(num_experts, expert_capacity)]( out, x, @@ -448,7 +446,8 @@ def binned_scatter(x, indices, weights, bins, top_k): NUM_COLUMNS=hidden_size, A_TO_B=False, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) # Reduce along the top-k dimension, if needed. return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) @@ -471,16 +470,17 @@ def binned_scatter(x, indices, weights, bins, top_k): ) @triton.jit def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr): + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): # Load our indices into the output. expert_idx = tl.program_id(0) entry_idx = tl.program_id(1) @@ -492,7 +492,7 @@ def _binned_copy_wgrad( # the number of tokens assigned to our expert. start = 0 if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) + start = tl.load(bins + expert_idx - 1) end = tl.load(bins + expert_idx) num_tokens = end - start @@ -532,10 +532,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k): num_experts, expert_capacity, hidden_size = x.shape tokens = indices.shape[0] // top_k - out = torch.zeros( - (tokens * top_k), - dtype=x.dtype, - device=x.device) + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) _binned_copy_wgrad[(num_experts, expert_capacity)]( x, grad, @@ -545,5 +542,6 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k): indices, bins, NUM_COLUMNS=hidden_size, - TOP_K=top_k) + TOP_K=top_k, + ) return out diff --git a/megablocks/benchmark_util.py b/megablocks/benchmark_util.py index abf3521..2984ac0 100644 --- a/megablocks/benchmark_util.py +++ b/megablocks/benchmark_util.py @@ -3,14 +3,14 @@ def log_benchmark(name, arguments, time, std): - print("="*60) + print("=" * 60) print(f"{name} Benchmark") print("Benchmark Parameters:") for (key, value) in arguments.items(): print(f"{key} = {value}") print("Results:") print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std)) - print("="*60) + print("=" * 60) def benchmark_function(fn, iterations=100, warmup=10): @@ -26,7 +26,7 @@ def benchmark_function(fn, iterations=100, warmup=10): start.record() fn() end.record() - + torch.cuda.synchronize() times.append(start.elapsed_time(end)) return np.mean(times), np.std(times) diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index be24c6f..bdd81ee 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -3,13 +3,17 @@ except ImportError: grouped_gemm = None + def grouped_gemm_is_available(): return grouped_gemm is not None + def assert_grouped_gemm_is_available(): assert grouped_gemm_is_available(), ( "Grouped GEMM not available. Please run " - "`pip install git+https://github.com/tgale96/grouped_gemm@main`.") + "`pip install git+https://github.com/tgale96/grouped_gemm@main`." + ) + backend = grouped_gemm.backend if grouped_gemm_is_available() else None ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 613ef31..3038f44 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -4,7 +4,12 @@ import stk -def act_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kwargs): +def act_fn( + x: stk.Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +): assert isinstance(x, stk.Matrix) with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): if return_grad_fn: @@ -18,7 +23,8 @@ def act_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kw x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) if return_grad_fn: return y, out.backward return y diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 12098eb..b94f662 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -1,23 +1,26 @@ import torch + class AllToAllOp(torch.autograd.Function): @staticmethod def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty( - (sum(output_split_sizes),) + x.shape[1:], - device=x.device, dtype=x.dtype) + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], + device=x.device, + dtype=x.dtype) ctx.input_shape = x.shape ctx.output_split_sizes = output_split_sizes ctx.input_split_sizes = input_split_sizes ctx.group = group handle = torch.distributed.all_to_all_single( - out, x, + out, + x, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=group, - async_op=async_op) + async_op=async_op, + ) return out, handle @staticmethod @@ -26,15 +29,24 @@ def backward(ctx, grad, _): out = torch.empty( ctx.input_shape, device=grad.device, - dtype=grad.dtype) + dtype=grad.dtype, + ) torch.distributed.all_to_all_single( - out, grad, + out, + grad, output_split_sizes=ctx.input_split_sizes, input_split_sizes=ctx.output_split_sizes, - group=ctx.group) + group=ctx.group, + ) return out, None, None, None, None return None, None, None, None, None + def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): return AllToAllOp.apply( - x, output_split_sizes, input_split_sizes, group, async_op) + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 9b6c49b..674dd47 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -16,51 +16,55 @@ @dataclasses.dataclass class Arguments: # Model arguments. - hidden_size : int = 1024 - ffn_hidden_size : int = 4096 - num_layers : int = 1 - bias : bool = True - return_bias : bool = True - activation_fn : Optional[Callable] = DEFAULT_ACTIVATION_FN + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN # MoE arguments. - moe_num_experts : int = 1 - moe_top_k : int = 1 - moe_capacity_factor : int = 1 - moe_normalize_expert_weights : Optional[Union[int, float]] = None - moe_loss_weight : float = 0.1 - moe_jitter_eps : Optional[float] = None - moe_lbl_in_fp32 : bool = False + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False # Parallelism arguments. - moe_expert_model_parallelism : bool = False - expert_parallel_group : Optional[torch.distributed.ProcessGroup] = None - moe_weight_parallelism : bool = False - weight_parallel_group : Optional[torch.distributed.ProcessGroup] = None - pipeline_model_parallel_size : int = 1 - num_layers_per_virtual_pipeline_stage : Optional[int] = None + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None + moe_weight_parallelism: bool = False + weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None # Compute arguments. - memory_optimized_mlp : bool = False - mlp_type : str = 'mlp' - mlp_impl : str = 'sparse' + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' # Initialization arguments. - fp16 : bool = True + fp16: bool = True bf16: bool = False - device : torch.device = torch.cuda.current_device() - init_method : InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method : InitFn = init_method + device: torch.device = torch.cuda.current_device() + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method # Benchmarking arguments. - uniform_expert_assignment : bool = False + uniform_expert_assignment: bool = False # shared expert arguments shared_expert: bool = False # enable using shared expert fc_cls: torch.nn.Module = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) # kwargs for custom fc layers + fc_kwargs: dict[str, Any] = dataclasses.field( + default_factory=dict, + ) # kwargs for custom fc layers remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_hidden_size: Optional[ + int + ] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) def __post_init__(self): diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py index fd99aa4..b9b2ab1 100644 --- a/megablocks/layers/common.py +++ b/megablocks/layers/common.py @@ -1,7 +1,8 @@ from megablocks.layers.arguments import Arguments import torch -def dtype(args : Arguments): + +def dtype(args: Arguments): if args.fp16: return torch.float16 elif args.bf16: diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index 666398a..b227a2e 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -6,10 +6,17 @@ MlpType = Union[mlp.SparseMLP, glu.SparseGLU] _REGISTRY = { - 'mlp': {'grouped': mlp.GroupedMLP, 'sparse' : mlp.SparseMLP}, - 'glu': {'grouped': glu.GroupedGLU, 'sparse': glu.SparseGLU}, + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, } + def get(args: Arguments) -> MlpType: """Returns an MLP for use in a dMoE instance. @@ -24,10 +31,12 @@ def get(args: Arguments) -> MlpType: An instantiated MLP constructed using the input args. """ - if args.mlp_type not in _REGISTRY: + if args.mlp_type not in _REGISTRY: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.') + raise ValueError( + f'{args.mlp_type} does not support {args.mlp_impl} backend.', + ) return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index c1de9a8..52ef7a7 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -8,12 +8,14 @@ import stk import torch + def promote_scalar(x): return x.view(1) if not len(x.size()) else x + class ParallelDroplessMLP(moe.ParallelMLP): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(ParallelDroplessMLP, self).__init__(args) self.hidden_size = args.hidden_size self.ffn_hidden_size = mpu.features_per_rank(args) @@ -22,10 +24,12 @@ def __init__(self, args : Arguments): # Calculate the number of bits needed to represent the column indices # in the intermediate sparse matrix. - max_column_index = ( - (self.ffn_hidden_size * self.num_experts) // self.blocking) + max_column_index = ((self.ffn_hidden_size * self.num_experts) // + self.blocking) self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), 1) + int(np.ceil(np.log2(max_column_index))), + 1, + ) def sparse_transpose(self, size, row_indices, column_indices, offsets): block_columns = size[1] // self.blocking @@ -37,7 +41,9 @@ def sparse_transpose(self, size, row_indices, column_indices, offsets): # To avoid overflow when we have large activation matrices we cast to # 32-bit before sorting. _, gather_indices = ops.sort( - column_indices.int(), self.transpose_sort_end_bit) + column_indices.int(), + self.transpose_sort_end_bit, + ) # There are a constant number of blocks in every row of the sparse matrix. # A blocks offset is: @@ -62,8 +68,11 @@ def topology(self, x, padded_bins): padded_tokens, _ = x.size() assert padded_tokens % self.blocking == 0 if self.ffn_hidden_size % self.blocking != 0: - raise ValueError(f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. Please update your configuration.') + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + + f'the block size {self.blocking}. Please update your configuration.', + ) # Offsets for the sparse matrix. All rows have the # same number of nonzero blocks dictated by the @@ -75,15 +84,18 @@ def topology(self, x, padded_bins): block_rows * blocks_per_row + 1, blocks_per_row, dtype=torch.int32, - device=x.device) + device=x.device, + ) # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, - self.blocking, - block_rows, - blocks_per_row) + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) # TODO(tgale): This is unused. Remove the need for this in stk. # For now, use meta init to save the device memory. @@ -92,17 +104,29 @@ def topology(self, x, padded_bins): self.blocking, self.blocking, dtype=common.dtype(self.args), - device='meta') + device='meta', + ) shape = ( padded_tokens, self.ffn_hidden_size * mpu.experts_per_rank(self.args), ) - row_indices = stk.ops.row_indices( - shape, data, offsets, column_indices) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, row_indices, column_indices, offsets) - return stk.Matrix(shape, data, row_indices, column_indices, offsets, - column_indices_t, offsets_t, block_offsets_t) + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) def indices_and_padded_bins(self, top_experts): # Sort the expert ids to produce the scatter/gather @@ -118,7 +142,9 @@ def indices_and_padded_bins(self, top_experts): # the matrix muliplications. Caculate the starting # position of each bin. padded_tokens_per_expert = ops.round_up( - tokens_per_expert, self.blocking) + tokens_per_expert, + self.blocking, + ) padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) padded_bins = promote_scalar(padded_bins) @@ -135,7 +161,8 @@ def sparse_forward_once(self, x, expert_weights, top_experts): top_experts = top_experts.flatten() with torch.no_grad(): indices, bin_ids, bins, padded_bins, tokens_per_expert = ( - self.indices_and_padded_bins(top_experts)) + self.indices_and_padded_bins(top_experts) + ) # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) @@ -145,7 +172,8 @@ def sparse_forward_once(self, x, expert_weights, top_experts): bin_ids, bins, padded_bins, - self.top_k) + self.top_k, + ) # Create the sparse matrix topology. with torch.no_grad(): @@ -162,37 +190,35 @@ def sparse_forward_once(self, x, expert_weights, top_experts): expert_weights, bins, padded_bins, - self.top_k) + self.top_k, + ) return x, tokens_per_expert # For use in the base-class parallel_forward_once. def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): # Round the token counts up to the block size used in the matrix # multiplication. Calculate the starting position of each bin. padded_tokens_per_expert = ops.round_up( - tokens_per_expert, self.blocking) + tokens_per_expert, + self.blocking, + ) padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) padded_bins = promote_scalar(padded_bins) # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - top_k) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) # Create the sparse matrix topology. with torch.no_grad(): @@ -201,7 +227,6 @@ def sparse_permute_and_compute( # Perform the expert computation. x = self.mlp(x, topo) - # Un-route the data for the MoE output. return ops.padded_scatter( x, @@ -210,7 +235,8 @@ def sparse_permute_and_compute( expert_weights, bins, padded_bins, - top_k) + top_k, + ) def grouped_forward_once(self, x, expert_weights, top_experts): # x: [sl, bs, hs] @@ -220,7 +246,8 @@ def grouped_forward_once(self, x, expert_weights, top_experts): top_experts = top_experts.flatten() with torch.no_grad(): indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + self.indices_and_bins(top_experts) + ) out = self.grouped_permute_and_compute( x, @@ -230,60 +257,49 @@ def grouped_forward_once(self, x, expert_weights, top_experts): expert_weights, bins, -1, # unused - self.args.moe_top_k) + self.args.moe_top_k, + ) return out, tokens_per_expert def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.gather( - x, - indices, - bin_ids, - bins, - top_k) + x = ops.gather(x, indices, bin_ids, bins, top_k) # Perform the expert computation. x = self.mlp(x, tokens_per_expert) # Un-route the data for the MoE output. - return ops.scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - top_k) + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) def forward_once(self, x, expert_weights, top_experts): if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once( - x, expert_weights, top_experts) + return self.sparse_forward_once(x, expert_weights, top_experts) else: - return self.grouped_forward_once( - x, expert_weights, top_experts) - + return self.grouped_forward_once(x, expert_weights, top_experts) def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): if self.args.mlp_impl == 'sparse': return self.sparse_permute_and_compute( x, @@ -293,7 +309,8 @@ def permute_and_compute( expert_weights, bins, expert_capactiy, - top_k) + top_k, + ) else: return self.grouped_permute_and_compute( x, @@ -303,7 +320,8 @@ def permute_and_compute( expert_weights, bins, expert_capactiy, - top_k) + top_k, + ) class dMoE(moe.MoE): diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index 49ac4a8..e0eb8c0 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -7,10 +7,8 @@ def _gelu_backward_inplace(g, x): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) ff = ( - 0.5 * x * ( - (1 - tanh_out * tanh_out) * - (0.79788456 + 0.1070322243 * x * x) - ) + 0.5 * (1 + tanh_out) + 0.5 * x * ((1 - tanh_out * tanh_out) * + (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) ) return g.mul_(ff) @@ -26,7 +24,8 @@ def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) return _gelu_backward_inplace(grad, x) @@ -40,4 +39,5 @@ def gelu(x: stk.Matrix): x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 828f10f..9a41a9b 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -15,30 +15,49 @@ class SparseGLU(SparseMLP): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__(args) - self.v1 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) with torch.no_grad(): - self.v1.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method)) + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) mpu.set_expert_model_parallel_attributes( - self.v1, self._should_set_parallelism_attribute) + self.v1, + self._should_set_parallelism_attribute, + ) if self.args.moe_weight_parallelism: - raise NotImplementedError("Weight parallelism not yet supported with GLU.") + raise NotImplementedError( + "Weight parallelism not yet supported with GLU.", + ) def forward(self, x, topo): if self.args.memory_optimized_mlp: - raise NotImplementedError("Memory optimized implementation not yet supported with GLU with sparse kernels.") + raise NotImplementedError( + "Memory optimized implementation not yet supported with GLU with sparse kernels.", + ) - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2) + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad( + self.v1, + ), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor( + v1, + ), resolve_dtensor(w2) # Compute the GLU. x1 = stk.ops.sdd(x, w1.t(), topo) @@ -49,6 +68,7 @@ def forward(self, x, topo): return stk.ops.dsd(x1, w2) + class MemoryOptimizedGroupedGLU(torch.autograd.Function): """GroupedMLP with manually scheduled memory reuse.""" @@ -62,8 +82,10 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): v1 = v1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not v1.is_contiguous() or not w2.is_contiguous()): + if ( + not x.is_contiguous() or not w1.is_contiguous() or + not v1.is_contiguous() or not w2.is_contiguous() + ): raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -90,9 +112,10 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): + if ( + not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or + not ctx.needs_input_grad[2] + ): raise ValueError("Expected all MLP inputs to need grad.") # Unpack saved tensors @@ -106,21 +129,30 @@ def backward(ctx, ddsd_out): # Rematerialize activation_fn output. activation_fn = ctx.activation_fn with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. dw2 = gg.backend.gmm( - activation_fn_out, ddsd_out, batch_sizes, trans_a=True) + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) # Compute dactivation_fn_out. # # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out gg.backend.gmm( - ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out) + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) # Compute dsdd_out. # @@ -144,14 +176,22 @@ def backward(ctx, ddsd_out): dx += gg.backend.gmm(dv1_out, v1, batch_sizes) return dx, dw1, dv1, dw2, None, None + memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply class GroupedGLU(SparseGLU): + def forward(self, x, tokens_per_expert): batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor( + v1, + ), resolve_dtensor(w2) # Re-shape the weights for the grouped GEMMs. ne = mpu.experts_per_rank(self.args) @@ -161,8 +201,13 @@ def forward(self, x, tokens_per_expert): if self.args.memory_optimized_mlp: return memory_optimized_grouped_glu( - x, w1, v1, w2, batch_sizes, - self.args.activation_fn) + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) # Compute the MLP. x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) @@ -176,7 +221,8 @@ class SharedGLU(SharedMLP): Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class """ - def __init__(self, args : Arguments): + + def __init__(self, args: Arguments): super().__init__(args) self.gate_proj = args.fc_cls( args.hidden_size, diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index 9b333b3..c1512bd 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -4,10 +4,7 @@ from megablocks.layers import arguments, dmoe -_TESTS = ( - (8, 2048, 4096, 4096, 32, 4), - -) +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) def get_tensors(): @@ -23,13 +20,14 @@ def get_tensors(): def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k): + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): args = arguments.Arguments( hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, @@ -39,13 +37,13 @@ def test_memory( expert_parallel_group=group, fp16=False, bf16=True, - device=torch.cuda.current_device()) + device=torch.cuda.current_device(), + ) layer = dmoe.dMoE(args).cuda() - x = torch.randn( - (batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) torch.cuda.empty_cache() # Run forward + backward. @@ -55,16 +53,18 @@ def test_memory( # Report peak memory. mem = torch.cuda.max_memory_allocated() - print("Max Memory Allocated = {:0.0f}MiB".format( - mem / 1e6)) - print("Max Memory Reserved = {:0.0f}MiB".format( - torch.cuda.max_memory_reserved() / 1e6)) + print("Max Memory Allocated = {:0.0f}MiB".format(mem / 1e6)) + print( + "Max Memory Reserved = {:0.0f}MiB".format( + torch.cuda.max_memory_reserved() / 1e6, + ), + ) # Calculate weight and gradient memory usage. weight_memory = 2 * ( - layer.router.layer.weight.numel() + - layer.experts.mlp.w1.numel() + - layer.experts.mlp.w2.numel()) + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + + layer.experts.mlp.w2.numel() + ) def grad_numel(x): if x.grad is not None: @@ -73,14 +73,15 @@ def grad_numel(x): grad_memory = 2 * ( grad_numel(layer.router.layer.weight) + - grad_numel(layer.experts.mlp.w1) + - grad_numel(layer.experts.mlp.w2)) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) weight_memory += grad_memory - print("Weight Memory Allocated = {:0.0f}MiB".format( - weight_memory / 1e6)) - print("Activation Memory Allocated = {:0.0f}MiB".format( - (mem - weight_memory) / 1e6)) + print("Weight Memory Allocated = {:0.0f}MiB".format(weight_memory / 1e6)) + print( + "Activation Memory Allocated = {:0.0f}MiB".format((mem - weight_memory) + / 1e6,), + ) # Manually calculate GPU memory usage from the garbage # collector. @@ -93,8 +94,7 @@ def grad_numel(x): print(f"{i}: {t.shape}, {t.numel() * 2}") del tensors - print("Total Bytes Found = {:0.0f}MiB".format( - total * 2 / 1e6)) + print("Total Bytes Found = {:0.0f}MiB".format(total * 2 / 1e6)) if __name__ == '__main__': diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 134dd73..8ce5fe3 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -23,6 +23,8 @@ def forward(ctx, x, scale): @torch.cuda.amp.custom_bwd def backward(ctx, grad): return grad * ctx.scale, None + + scale_gradient = ScaleGradient.apply @@ -34,18 +36,23 @@ def resolve_dtensor(weight): return weight -def create_moe_expert_weights(args : Arguments, - num_experts : int, - ffn_hidden_size : int, - hidden_size : int, - init_method : InitFn): +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): # Create the entire weight matrix such that the sampled weights will # not vary between data parallelism and expert model parallelism for # the same random seed. master_weights = torch.empty( - num_experts, ffn_hidden_size, hidden_size, + num_experts, + ffn_hidden_size, + hidden_size, device=args.device, - dtype=common.dtype(args)) + dtype=common.dtype(args), + ) init_method(master_weights) if not args.moe_expert_model_parallelism: @@ -73,35 +80,44 @@ def create_moe_expert_weights(args : Arguments, # Slice the weight matrix to get the chunk for this rank. with torch.no_grad(): - weights = master_weights[ - start_expert:end_expert, start_row:end_row] + weights = master_weights[start_expert:end_expert, start_row:end_row] return weights class MLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) experts_per_rank = mpu.experts_per_rank(args) - self.w1 = torch.nn.Parameter(torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args))) - self.w2 = torch.nn.Parameter(torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) mpu.set_expert_model_parallel_attributes( - self.w1, args.moe_expert_model_parallelism) + self.w1, + args.moe_expert_model_parallelism, + ) mpu.set_expert_model_parallel_attributes( - self.w2, args.moe_expert_model_parallelism) + self.w2, + args.moe_expert_model_parallelism, + ) # Initialize the parameters for the MLP. # @@ -113,16 +129,28 @@ def __init__(self, args : Arguments): # usage. with torch.no_grad(): w1 = create_moe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method) + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_(create_moe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.output_layer_init_method)) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) self.gradient_scale = None if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size( + self.args, + ) def scale_grad(self, w): if self.gradient_scale is None: @@ -137,13 +165,20 @@ def forward(self, x): return torch.bmm(x, w2) -def create_dmoe_expert_weights(args : Arguments, - num_experts : int, - rows : int, - columns : int, - init_method : InitFn): +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): weights = create_moe_expert_weights( - args, num_experts, rows, columns, init_method) + args, + num_experts, + rows, + columns, + init_method, + ) weights = weights.view([-1, columns]) rows, columns = weights.shape @@ -173,16 +208,20 @@ def forward(ctx, x, w1, w2, topo, activation_fn): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): + if ( + not x.is_contiguous() or not w1.is_contiguous() or + not w2.is_contiguous() + ): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - topo_tensors = (topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) # Layer 0: x @ w1.t(). sdd_out = stk.ops.sdd(x, w1.t(), topo) @@ -208,9 +247,10 @@ def forward(ctx, x, w1, w2, topo, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): + if ( + not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or + not ctx.needs_input_grad[2] + ): raise ValueError("Expected all MLP inputs to need grad.") # unpack saved tensors @@ -224,7 +264,11 @@ def backward(ctx, ddsd_out): # rematerialize activation function output activation_fn = ctx.activation_fn sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn(sdd_out, activation_fn, return_grad_fn=True) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) # Compute dw2 with recomputed activation_fn output. dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) @@ -234,12 +278,14 @@ def backward(ctx, ddsd_out): # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out stk.backend.triton_kernels.sdd( - ddsd_out, w2.t(), + ddsd_out, + w2.t(), dactivation_fn_out.shape, dactivation_fn_out.data, dactivation_fn_out.offsets, dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices) + dactivation_fn_out.column_indices, + ) # Compute dsdd_out. # @@ -268,16 +314,18 @@ def backward(ctx, ddsd_out): dsdd_out.block_offsets_t, False, w1, - ddsd_out) + ddsd_out, + ) dx = ddsd_out return dx, dw1, dw2, None, None + memory_optimized_mlp = MemoryOptimizedMLP.apply class SparseMLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args self._num_rows_per_rank = ( @@ -285,16 +333,22 @@ def __init__(self, args : Arguments): mpu.get_weight_parallel_world_size(args) ) - self.w1 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) - self.w2 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) # Initialize the parameters for the MLP. # @@ -305,23 +359,42 @@ def __init__(self, args : Arguments): # and the slice which causes large increases in our peak memory # usage. with torch.no_grad(): - self.w1.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method)) - self.w2.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.output_layer_init_method)) + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) self._should_set_parallelism_attribute = ( - args.moe_expert_model_parallelism or args.moe_weight_parallelism) + args.moe_expert_model_parallelism or args.moe_weight_parallelism + ) mpu.set_expert_model_parallel_attributes( - self.w1, self._should_set_parallelism_attribute) + self.w1, + self._should_set_parallelism_attribute, + ) mpu.set_expert_model_parallel_attributes( - self.w2, self._should_set_parallelism_attribute) + self.w2, + self._should_set_parallelism_attribute, + ) self.gradient_scale = None if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size( + self.args, + ) def scale_grad(self, w): if self.gradient_scale is None: @@ -337,7 +410,12 @@ def parallel_forward(self, x, topo): f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.', ) return wp.memory_optimized_weight_parallel_mlp( - x, w1, w2, topo, group) + x, + w1, + w2, + topo, + group, + ) # Compute the MLP. x = wp.sdd_nt(x, w1, topo, group) @@ -351,7 +429,12 @@ def forward(self, x, topo): return self.parallel_forward(x, topo) elif self.args.memory_optimized_mlp: return memory_optimized_mlp( - x, w1, w2, topo, self.args.activation_fn) + x, + w1, + w2, + topo, + self.args.activation_fn, + ) # Compute the MLP. x = stk.ops.sdd(x, w1.t(), topo) @@ -371,8 +454,10 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): + if ( + not x.is_contiguous() or not w1.is_contiguous() or + not w2.is_contiguous() + ): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -398,9 +483,10 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): + if ( + not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or + not ctx.needs_input_grad[2] + ): raise ValueError("Expected all MLP inputs to need grad.") # Unpack saved tensors @@ -420,14 +506,23 @@ def backward(ctx, ddsd_out): # Compute dw2 with recomputed activation_fn output. dw2 = gg.backend.gmm( - activation_fn_out, ddsd_out, batch_sizes, trans_a=True) + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) # Compute dactivation_fn_out. # # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out gg.backend.gmm( - ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out) + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) # Compute dsdd_out. # @@ -449,6 +544,7 @@ def backward(ctx, ddsd_out): dx = ddsd_out return dx, dw1, dw2, None, None + memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply @@ -465,12 +561,17 @@ def forward(self, x, tokens_per_expert): if self.args.moe_weight_parallelism: raise NotImplementedError( - "Weight parallelism not yet supported with GroupedMLP.") + "Weight parallelism not yet supported with GroupedMLP.", + ) if self.args.memory_optimized_mlp: return memory_optimized_grouped_mlp( - x, w1, w2, batch_sizes, - self.args.activation_fn) + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) # Compute the MLP. x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) @@ -483,7 +584,8 @@ class SharedMLP(torch.nn.Module): Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class """ - def __init__(self, args : Arguments): + + def __init__(self, args: Arguments): super().__init__() self.args = args self.fc_kwargs: dict[str, Any] = { @@ -505,7 +607,11 @@ def __init__(self, args : Arguments): ) self.down_proj._is_residual = True # a flag for llm-foundry init - def add_experts_sharedexpert(self, shared_expert_out: torch.Tensor, expert_out: torch.Tensor) -> torch.Tensor: + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: # Helper function to add expert output to shared expert output # with optional weighted sum. if self.args.shared_expert_weighted_sum: @@ -513,7 +619,10 @@ def add_experts_sharedexpert(self, shared_expert_out: torch.Tensor, expert_out: # wieghted by number of experts used t_experts = self.args.moe_top_k + 1 sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add(expert_out, alpha=(self.args.moe_top_k / t_experts)) + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) return shared_expert_out + expert_out diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 6a31bfd..7f659d2 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -24,7 +24,7 @@ def clear_load_balancing_loss(): _LOAD_BALANCING_LOSS.clear() -def batched_load_balancing_loss(args : Arguments): +def batched_load_balancing_loss(args: Arguments): if args.moe_loss_weight == 0: return 0.0 @@ -32,7 +32,8 @@ def batched_load_balancing_loss(args : Arguments): # expert_scores[i].shape = (tokens, num_experts) tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) num_layers_per_pipeline_stage = ( - args.num_layers // args.pipeline_model_parallel_size) + args.num_layers // args.pipeline_model_parallel_size + ) if args.num_layers_per_virtual_pipeline_stage is not None: num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage @@ -43,7 +44,8 @@ def batched_load_balancing_loss(args : Arguments): f"{args.num_layers}\npipeline_model_parallel_size = " f"{args.pipeline_model_parallel_size}\n" "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}") + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) if len(expert_scores) != num_layers_per_pipeline_stage: raise ValueError( f"Expected {num_layers_per_pipeline_stage} expert_scores " @@ -51,7 +53,8 @@ def batched_load_balancing_loss(args : Arguments): f"{args.num_layers}\npipeline_model_parallel_size = " f"{args.pipeline_model_parallel_size}\n" "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}") + f" = {args.num_layers_per_virtual_pipeline_stage}", + ) # Verify the shape of the tokens_per_expert and expert_scores tensors. assert all(( @@ -60,11 +63,10 @@ def batched_load_balancing_loss(args : Arguments): )) tokens = expert_scores[0].shape[0] - assert all(( - (x.ndim == 2 and x.shape[1] == args.moe_num_experts and - x.shape[0] == tokens) for x in expert_scores - )) - + assert all((( + x.ndim == 2 and x.shape[1] == args.moe_num_experts and + x.shape[0] == tokens + ) for x in expert_scores)) # Concatenate the contributions of each layer and convert to # the correct types and formats for the dot product. @@ -84,15 +86,8 @@ def batched_load_balancing_loss(args : Arguments): # Calculate the total scale across all factors. # # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = ( - args.moe_num_experts * - args.moe_loss_weight - ) - scale_denominator = ( - args.num_layers * - tokens * - args.moe_top_k - ) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) scale = scale_numerator / scale_denominator return scale * torch.dot(tokens_per_expert, expert_scores) @@ -103,7 +98,7 @@ def batched_load_balancing_loss(args : Arguments): # parallel all2all. class ParallelMLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(ParallelMLP, self).__init__() self.args = args @@ -123,24 +118,28 @@ def __init__(self, args : Arguments): if self.args.bias: # Note that the output bias is not parallelized with expert # model parallelism. - self.bias = torch.nn.Parameter(torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) torch.nn.init.zeros_(self.bias) else: self.register_parameter('bias', None) # Select the forward function for the operating mode. self.forward_fn = ( - self.parallel_forward_once if - args.moe_expert_model_parallelism else - self.forward_once) + self.parallel_forward_once + if args.moe_expert_model_parallelism else self.forward_once + ) def expert_capacity(self, tokens): world_size = mpu.get_expert_parallel_world_size(self.args) tokens_per_expert = ( - self.top_k * tokens * world_size / self.num_experts) + self.top_k * tokens * world_size / self.num_experts + ) return int(self.args.moe_capacity_factor * tokens_per_expert) def load_balancing_loss(self, tokens_per_expert, expert_scores): @@ -154,7 +153,8 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores): scale = self.num_experts / (tokens * self.top_k) return scale * torch.dot( tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0)) + expert_scores.mean(dim=0), + ) def indices_and_bins(self, top_expert): # Sort the expert ids to produce the scatter/gather @@ -180,27 +180,26 @@ def indices_and_bins(self, top_expert): return indices, bin_ids, bins, tokens_per_expert def permute_and_compute( - self, - x, - tokens_per_expert, # unused - indices, - bin_ids, # unused - expert_weights, - bins, - expert_capacity, - top_k): + self, + x, + tokens_per_expert, # unused + indices, + bin_ids, # unused + expert_weights, + bins, + expert_capacity, + top_k, + ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.binned_gather( - x, indices, bins, expert_capacity, top_k) + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) # Perform the expert computation. Note that we don't # use biases for these linear operations. x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter( - x, indices, expert_weights, bins, top_k) + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) def forward_once(self, x, expert_weights, top_experts): # x: [sl, bs, hs] @@ -210,7 +209,8 @@ def forward_once(self, x, expert_weights, top_experts): top_experts = top_experts.flatten() with torch.no_grad(): indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + self.indices_and_bins(top_experts) + ) # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. @@ -227,7 +227,8 @@ def forward_once(self, x, expert_weights, top_experts): expert_weights, bins, expert_capacity, - self.top_k) + self.top_k, + ) return x, tokens_per_expert def parallel_forward_once(self, x, expert_weights, top_experts): @@ -256,22 +257,28 @@ def parallel_forward_once(self, x, expert_weights, top_experts): top_experts = top_experts.flatten() with torch.no_grad(): indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + self.indices_and_bins(top_experts) + ) # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (mpu.hidden_sharding_degree(self.args),)) + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) # Pass token count information to the device on which the # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + parallel_tokens_per_expert = torch.empty_like( + repeated_tokens_per_expert, + ) tpe_handle = torch.distributed.all_to_all_single( parallel_tokens_per_expert, repeated_tokens_per_expert, group=self.args.expert_parallel_group, - async_op=True) + async_op=True, + ) # Permute locally and without any padding so that tokens for each # parallel device are stored contiguously. @@ -279,12 +286,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - x = ops.gather( - x, - indices, - bin_ids, - bins, - self.top_k) + x = ops.gather(x, indices, bin_ids, bins, self.top_k) # Compute the number of tokens that will be received from each # device and permute the input data across the devices. @@ -295,9 +297,11 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Reshape to [world_size, num_experts_per_rank]. world_size = mpu.get_expert_parallel_world_size(self.args) repeated_tokens_per_expert = ( - repeated_tokens_per_expert.view(world_size, experts_per_rank)) + repeated_tokens_per_expert.view(world_size, experts_per_rank) + ) parallel_tokens_per_expert = ( - parallel_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert.view(world_size, experts_per_rank) + ) # TODO(tgale): It might be faster to do this on the GPU and # then communicate the results back to the host. @@ -321,9 +325,12 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Start the cross-device permutation asynchronously so we can # overlap communication with computation. parallel_x, parallel_x_handle = all_to_all( - x, recv_counts, send_counts, + x, + recv_counts, + send_counts, self.args.expert_parallel_group, - async_op=True) + async_op=True, + ) with torch.no_grad(): # After we do the cross-device permutation we have the tokens on the @@ -333,11 +340,12 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. replicate_bins = ops.inclusive_cumsum( - parallel_tokens_per_expert.flatten(), 0) + parallel_tokens_per_expert.flatten(), + 0, + ) replicate_bins = ( replicate_bins.view(1) - if not len(replicate_bins.size()) - else replicate_bins + if not len(replicate_bins.size()) else replicate_bins ) # Construct the expert indices for the permuted tokens. @@ -351,21 +359,25 @@ def parallel_forward_once(self, x, expert_weights, top_experts): ) parallel_top_expert = ops.replicate( parallel_top_expert.unsqueeze(dim=0), - replicate_bins, tokens_received).flatten() + replicate_bins, + tokens_received, + ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, self.sort_end_bit) + parallel_top_expert, + self.sort_end_bit, + ) # Calculate the bins boundaries from the token counts. parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int) - parallel_bins = ops.inclusive_cumsum( - parallel_tokens_per_expert, 0) + dim=0, + dtype=torch.int, + ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) parallel_bins = ( parallel_bins.view(1) - if not len(parallel_bins.size()) - else parallel_bins + if not len(parallel_bins.size()) else parallel_bins ) # If expert_capacity is set to zero, set the number of tokens @@ -373,8 +385,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): tokens, hs = x.size() expert_capacity = self.expert_capacity(tokens) if expert_capacity == 0: - expert_capacity = torch.max( - parallel_tokens_per_expert).item() + expert_capacity = torch.max(parallel_tokens_per_expert).item() # Locally permute the tokens and perform the expert computation. # Block to make sure that the cross-device permutation is complete. @@ -383,7 +394,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # moved to CPU for the prior all_to_all, which avoids an extra # device synchronization. parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, dtype=torch.int) + dim=0, + dtype=torch.int, + ) parallel_x_handle.wait() parallel_x = self.permute_and_compute( parallel_x, @@ -393,12 +406,16 @@ def parallel_forward_once(self, x, expert_weights, top_experts): None, # expert_weights parallel_bins, expert_capacity, - top_k=1) + top_k=1, + ) # Un-permute the tokens across the devices. x, _ = all_to_all( - parallel_x, send_counts, recv_counts, - self.args.expert_parallel_group) + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) # Reduce along the hidden sharding to get the final outputs. # @@ -411,21 +428,14 @@ def parallel_forward_once(self, x, expert_weights, top_experts): x = ops.sum(x.view(shape), dim=0) # Un-permute locally to setup for the next series of operations. - x = ops.scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - self.top_k) + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x, scores, expert_weights, top_experts): in_shape = x.size() # Compute the experts. - x, tokens_per_expert = self.forward_fn( - x, expert_weights, top_experts) + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) if self.training and self.args.moe_loss_weight > 0: save_load_balancing_loss((tokens_per_expert, scores)) x = x.view(in_shape) @@ -438,7 +448,7 @@ def forward(self, x, scores, expert_weights, top_experts): class MoE(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(MoE, self).__init__() # Token router. @@ -467,5 +477,8 @@ def forward(self, x): out = self.experts(x, scores, expert_weights, top_experts) if self.shared_expert is not None: shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert(shared_expert_out, out) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) return out diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 49bbcbe..3bed037 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -2,50 +2,58 @@ import torch -def is_moe_param(tensor : torch.Tensor) -> bool: +def is_moe_param(tensor: torch.Tensor) -> bool: return hasattr(tensor, 'expert_model_parallel') -def get_expert_parallel_world_size(args : Arguments) -> int: +def get_expert_parallel_world_size(args: Arguments) -> int: return ( torch.distributed.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1 ) -def get_expert_parallel_rank(args : Arguments) -> int: +def get_expert_parallel_rank(args: Arguments) -> int: return ( torch.distributed.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0 ) -def set_expert_model_parallel_attributes(tensor : torch.Tensor, - is_parallel : bool): +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): assert not hasattr(tensor, 'expert_model_parallel') setattr(tensor, 'expert_model_parallel', is_parallel) -def param_is_expert_model_parallel(param : torch.Tensor) -> bool: - return (hasattr(param, 'expert_model_parallel') and - param.expert_model_parallel) +def param_is_expert_model_parallel(param: torch.Tensor) -> bool: + return ( + hasattr(param, 'expert_model_parallel') and param.expert_model_parallel + ) -def copy_expert_model_parallel_attributes(destination_tensor : torch.Tensor, - source_tensor : torch.Tensor): +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): if hasattr(source_tensor, 'expert_model_parallel'): - setattr(destination_tensor, 'expert_model_parallel', - getattr(source_tensor,'expert_model_parallel')) + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 'expert_model_parallel'), + ) -def get_weight_parallel_world_size(args : Arguments) -> int: +def get_weight_parallel_world_size(args: Arguments) -> int: return ( torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1 ) -def get_weight_parallel_rank(args : Arguments) -> int: +def get_weight_parallel_rank(args: Arguments) -> int: return ( torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0 @@ -62,35 +70,38 @@ def synchronized_print(group, *x): # Helpers for expert/tensor sharding. -def expert_sharding_degree(args : Arguments) -> int: +def expert_sharding_degree(args: Arguments) -> int: world_size = get_expert_parallel_world_size(args) esd = min(world_size, args.moe_num_experts) if (args.moe_num_experts % esd) != 0: raise ValueError( - f"Cannot shard {args.moe_num_experts} experts {esd} ways.") + f"Cannot shard {args.moe_num_experts} experts {esd} ways.", + ) return esd -def hidden_sharding_degree(args : Arguments) -> int: +def hidden_sharding_degree(args: Arguments) -> int: world_size = get_expert_parallel_world_size(args) esd = expert_sharding_degree(args) hsd = world_size // esd if (args.ffn_hidden_size % hsd) != 0: raise ValueError( - f"Cannot shard {args.ffn_hidden_size} features {hsd} ways.") + f"Cannot shard {args.ffn_hidden_size} features {hsd} ways.", + ) if (esd * hsd) != world_size: raise ValueError( f"Invalid sharding. 'expert_sharding_degree' " f"({esd}) * hidden_sharding_degree " - f"({hsd}) != world_size ({world_size}).") + f"({hsd}) != world_size ({world_size}).", + ) return hsd -def experts_per_rank(args : Arguments) -> int: +def experts_per_rank(args: Arguments) -> int: return args.moe_num_experts // expert_sharding_degree(args) -def features_per_rank(args : Arguments) -> int: +def features_per_rank(args: Arguments) -> int: return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index e1abddf..5039cf4 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -9,18 +9,19 @@ # so that PyTorch still executes the full set of router operation. class _UniformExpertAssignment(torch.autograd.Function): - @staticmethod def forward(ctx, x, num_experts): out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) out = torch.remainder(out, num_experts) return out.view(x.shape) + + _uniform_expert_assignment = _UniformExpertAssignment.apply class LearnedRouter(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args @@ -34,7 +35,8 @@ def __init__(self, args : Arguments): args.moe_num_experts, bias=False, dtype=common.dtype(args), - device=args.device) + device=args.device, + ) args.init_method(self.layer.weight) def jitter(self, x): @@ -45,7 +47,7 @@ def jitter(self, x): def _top_k(self, scores): if self.args.moe_top_k == 1: - return scores.max(dim=-1,keepdim=True) + return scores.max(dim=-1, keepdim=True) return torch.topk(scores, self.args.moe_top_k, dim=-1) def forward(self, x): @@ -56,10 +58,16 @@ def forward(self, x): expert_weights, expert_indices = self._top_k(scores) if self.args.moe_normalize_expert_weights: expert_weights = expert_weights / torch.norm( - expert_weights, p=self.args.moe_normalize_expert_weights,dim=-1, keepdim=True) + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) expert_indices = ( - _uniform_expert_assignment(expert_indices, self.args.moe_num_experts) - if self.args.uniform_expert_assignment else expert_indices + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices ) return scores, expert_weights, expert_indices diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 4d323ee..7396de8 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -3,12 +3,12 @@ from megablocks.layers import glu from megablocks.layers.arguments import Arguments - _REGISTRY = { 'mlp': mlp.SharedMLP, 'glu': glu.SharedGLU, } + def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: """Returns an SharedMLP for use in a dMoE instance. @@ -22,7 +22,7 @@ def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: An instantiated SharedMLP constructed using the input args. """ - if args.mlp_type not in _REGISTRY: + if args.mlp_type not in _REGISTRY: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') return _REGISTRY[args.mlp_type](args) diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index 530026e..5f027dc 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -14,33 +14,46 @@ def allclose(x, y, pct=0.5): class FFN(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() - self.w1 = torch.nn.Parameter(torch.empty( - args.hidden_size, - args.ffn_hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) - self.w2 = torch.nn.Parameter(torch.empty( - args.ffn_hidden_size, - args.hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) + self.w1 = torch.nn.Parameter( + torch.empty( + args.hidden_size, + args.ffn_hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + args.ffn_hidden_size, + args.hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) def forward(self, x): - return torch.matmul(F.gelu( - torch.matmul(x, self.w1), approximate="tanh"), self.w2) + return torch.matmul( + F.gelu(torch.matmul(x, self.w1), approximate="tanh"), + self.w2, + ) + class GLU(FFN): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__(args) - self.v1 = torch.nn.Parameter(torch.empty( - args.hidden_size, - args.ffn_hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) + self.v1 = torch.nn.Parameter( + torch.empty( + args.hidden_size, + args.ffn_hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) def forward(self, x): - x1 = F.gelu(torch.matmul(x, self.w1), approximate="tanh") * torch.matmul(x, self.v1) + x1 = F.gelu(torch.matmul(x, self.w1), + approximate="tanh") * torch.matmul(x, self.v1) return torch.matmul(x1, self.w2) diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index 46d5674..272d3d2 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -22,9 +22,17 @@ def _gather_weights(w, group, parallel_w=None, async_op=False): if parallel_w is None: parallel_w = torch.empty( - n * world_size, k, device=w.device, dtype=w.dtype) + n * world_size, + k, + device=w.device, + dtype=w.dtype, + ) handle = torch.distributed.all_gather_into_tensor( - parallel_w, w, group=group, async_op=async_op) + parallel_w, + w, + group=group, + async_op=async_op, + ) return parallel_w, handle @@ -52,11 +60,17 @@ def _scaled_reduce_scatter(parallel_dw, group, dw=None, async_op=False): if dw is None: dw = torch.empty( - n // world_size, k, + n // world_size, + k, device=parallel_dw.device, - dtype=torch.float32) + dtype=torch.float32, + ) handle = torch.distributed.reduce_scatter_tensor( - dw, parallel_dw, group=group, async_op=async_op) + dw, + parallel_dw, + group=group, + async_op=async_op, + ) return dw, handle @@ -76,13 +90,15 @@ def forward(ctx, x, w, topo, group): ctx.group = group ctx.shape = topo.shape ctx.save_for_backward( - x, w, + x, + w, topo.row_indices, topo.column_indices, topo.offsets, topo.column_indices_t, topo.offsets_t, - topo.block_offsets_t) + topo.block_offsets_t, + ) # TODO(tgale): Support prefetching forward weights. parallel_w, _ = _gather_weights(w, group) @@ -104,7 +120,11 @@ def backward(ctx, grad): # Start the weight gradient reduce scatter to overlap with the # data gradient computation. handle.wait() - dw, handle = _scaled_reduce_scatter(parallel_dw, ctx.group, async_op=True) + dw, handle = _scaled_reduce_scatter( + parallel_dw, + ctx.group, + async_op=True, + ) dx = None if ctx.needs_input_grad[0]: dx = stk.ops.dsd(grad, parallel_w) @@ -125,24 +145,27 @@ def sdd_nt(a, b, topo, group): topo.offsets, topo.column_indices_t, topo.offsets_t, - topo.block_offsets_t) + topo.block_offsets_t, + ) class WeightParallelDsdNn(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd - def forward(ctx, - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - w, - group): + def forward( + ctx, + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + w, + group, + ): # [m, k] x [k, n] = [m, n] # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -161,7 +184,8 @@ def forward(ctx, column_indices_t, offsets_t, block_offsets_t, - w) + w, + ) x = stk.Matrix( shape, data, @@ -170,7 +194,8 @@ def forward(ctx, offsets, column_indices_t, offsets_t, - block_offsets_t) + block_offsets_t, + ) # TODO(tgale): Support prefetching forward weights. parallel_w, _ = _gather_weights(w, group) @@ -192,7 +217,11 @@ def backward(ctx, grad): # Start the weight gradient reduce scatter to overlap with the # data gradient computation. handle.wait() - dw, handle = _scaled_reduce_scatter(parallel_dw, ctx.group, async_op=True) + dw, handle = _scaled_reduce_scatter( + parallel_dw, + ctx.group, + async_op=True, + ) dx = None if ctx.needs_input_grad[1]: dx = stk.ops.sdd(grad, parallel_w.t(), x) @@ -215,7 +244,8 @@ def dsd_nn(a, b, group): a.offsets_t, a.block_offsets_t, b, - group) + group, + ) class MemoryOptimizedWeightParallelMLP(torch.autograd.Function): @@ -230,8 +260,10 @@ def forward(ctx, x, w1, w2, topo, group): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): + if ( + not x.is_contiguous() or not w1.is_contiguous() or + not w2.is_contiguous() + ): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -254,13 +286,17 @@ def forward(ctx, x, w1, w2, topo, group): ctx.group = group ctx.shape = topo.shape ctx.save_for_backward( - x, w1, w2, sdd_out.data, + x, + w1, + w2, + sdd_out.data, topo.row_indices, topo.column_indices, topo.offsets, topo.column_indices_t, topo.offsets_t, - topo.block_offsets_t) + topo.block_offsets_t, + ) return dsd_out @staticmethod @@ -269,15 +305,15 @@ def backward(ctx, ddsd_out): x, w1, w2 = ctx.saved_tensors[:3] sdd_out = stk.Matrix(ctx.shape, *ctx.saved_tensors[3:]) - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): + if ( + not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or + not ctx.needs_input_grad[2] + ): raise ValueError("Expected all MLP inputs to need grad.") # Start the weight gather asynchronously to overlap with the # weight gradient computation and gelu recompute. - parallel_w2, handle = _gather_weights( - w2, ctx.group, async_op=True) + parallel_w2, handle = _gather_weights(w2, ctx.group, async_op=True) # Compute dw2 with recomputed gelu output. gelu_out = gelu.gelu(sdd_out) @@ -287,18 +323,23 @@ def backward(ctx, ddsd_out): # data gradient computation. handle.wait() dw2, handle = _scaled_reduce_scatter( - parallel_dw2, ctx.group, async_op=True) + parallel_dw2, + ctx.group, + async_op=True, + ) # Compute dgelu_out. # # NOTE: We reuse the gelu_out allocation. stk.backend.triton_kernels.sdd( - ddsd_out, parallel_w2.t(), + ddsd_out, + parallel_w2.t(), sdd_out.shape, gelu_out.data, sdd_out.offsets, sdd_out.row_indices, - sdd_out.column_indices) + sdd_out.column_indices, + ) dgelu_out = gelu_out # NOTE: Be careful to wait and only cast dw to the output dtype once @@ -311,7 +352,11 @@ def backward(ctx, ddsd_out): # # NOTE: Reuse the buffer from the w2 weight gather. parallel_w1, handle = _gather_weights( - w1, ctx.group, parallel_w2, async_op=True) + w1, + ctx.group, + parallel_w2, + async_op=True, + ) # Compute dsdd_out. # @@ -332,14 +377,18 @@ def backward(ctx, ddsd_out): dsdd_out.block_offsets_t, True, # transpose_a x, - parallel_dw2) + parallel_dw2, + ) parallel_dw1 = parallel_dw2 # Start the weight gradient reduce scatter to overlap with the # data gradient computation. handle.wait() dw1, handle = _scaled_reduce_scatter( - parallel_dw1, ctx.group, async_op=True) + parallel_dw1, + ctx.group, + async_op=True, + ) # Compute dx. # @@ -355,7 +404,8 @@ def backward(ctx, ddsd_out): dsdd_out.block_offsets_t, False, parallel_w1, - ddsd_out) + ddsd_out, + ) dx = ddsd_out # NOTE: Be careful to wait and only cast dw to the output dtype once @@ -364,4 +414,5 @@ def backward(ctx, ddsd_out): dw1 = dw1.to(w1.dtype) return dx, dw1, dw2, None, None + memory_optimized_weight_parallel_mlp = MemoryOptimizedWeightParallelMLP.apply diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index ccae8f3..a26b8fb 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -24,24 +24,26 @@ (1024 * 1024, 1024), ) + def benchmark_all_to_all(group, sl, hs): - world_size = torch.distributed.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size + world_size = torch.distributed.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() - x = torch.randn((sl, hs)).cuda().half() + details = { + "world_size": world_size, + "message_size (B)": send_recv_sizes[0] * hs * 2, # 2B elements. + } - details = { - "world_size": world_size, - "message_size (B)": send_recv_sizes[0] * hs * 2, # 2B elements. - } + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - def benchmark(): - return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - time, std = benchmark_util.benchmark_function(benchmark) + time, std = benchmark_util.benchmark_function(benchmark) - if torch.distributed.get_rank(group) == 0: - benchmark_util.log_benchmark("All-To-All", details, time, std) + if torch.distributed.get_rank(group) == 0: + benchmark_util.log_benchmark("All-To-All", details, time, std) if __name__ == '__main__': diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 0592a55..94b6ea5 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -2,6 +2,7 @@ from megablocks.backend import kernels from stk.backend.autocast import custom_fwd, custom_bwd + # Autograd wrapper for binned_gather kernel. class BinnedGatherOp(torch.autograd.Function): @@ -19,4 +20,6 @@ def backward(ctx, grad): indices, bins = ctx.saved_tensors out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) return out, None, None, None, None + + binned_gather = BinnedGatherOp.apply diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index 453de7d..143ac87 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -2,6 +2,7 @@ from megablocks.backend import kernels from stk.backend.autocast import custom_fwd, custom_bwd + # Autograd wrapper for binned_scatter kernel. class BinnedScatterOp(torch.autograd.Function): @@ -23,7 +24,13 @@ def backward(ctx, grad): grad = grad.contiguous() x, indices, weights, bins = ctx.saved_tensors out = kernels.binned_gather( - grad, indices, weights, bins, ctx.bin_size, ctx.top_k) + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[2]: @@ -32,6 +39,9 @@ def backward(ctx, grad): grad, indices, bins, - ctx.top_k) + ctx.top_k, + ) return out, None, wgrad, None, None + + binned_scatter = BinnedScatterOp.apply diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 6907f81..87b1298 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -7,6 +7,7 @@ # c++ operations. import megablocks_ops as ops + # Autograd wrappers for cumsum kernels. # # NOTE: Does not support gradients. @@ -22,8 +23,11 @@ def forward(ctx, x, dim): out = torch.empty_like(x) ops.exclusive_cumsum(x, dim, out) return out + + exclusive_cumsum = ExclusiveCumsumOp.apply + class InclusiveCumsumOp(torch.autograd.Function): @staticmethod @@ -36,4 +40,6 @@ def forward(ctx, x, dim): out = torch.empty_like(x) ops.inclusive_cumsum(x, dim, out) return out + + inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index bd8da3a..ec78d2c 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -11,8 +11,7 @@ class GatherOp(torch.autograd.Function): def forward(ctx, x, indices, bin_ids, bins, top_k): ctx.save_for_backward(indices, bin_ids, bins) ctx.top_k = top_k - return kernels.gather( - x, indices, bin_ids, None, bins, top_k) + return kernels.gather(x, indices, bin_ids, None, bins, top_k) @staticmethod @custom_bwd @@ -20,7 +19,8 @@ def backward(ctx, grad): grad = grad.contiguous() indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter( - grad, indices, bin_ids, None, bins, ctx.top_k) + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) return out, None, None, None, None, None + + gather = GatherOp.apply diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index f81862b..f77a8e1 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -7,6 +7,7 @@ # c++ operations. import megablocks_ops as ops + # Autograd wrapper for histogram kernel. # # NOTE: Does not support gradients. @@ -15,4 +16,6 @@ class HistogramOp(torch.autograd.Function): @staticmethod def forward(ctx, x, max_val): return ops.histogram(x, max_val) + + histogram = HistogramOp.apply diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index ff4689a..c6c31e1 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -5,7 +5,6 @@ import numpy as np import torch - _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), (16384, torch.int32, 4), @@ -17,6 +16,7 @@ (16384, torch.int32, 256), ) + def benchmark_function(fn, iterations=10): # Run once to get rid of startup overhead. fn() @@ -34,13 +34,13 @@ def benchmark_function(fn, iterations=10): def log_benchmark(arguments, mean_t, std_t): - print("="*60) + print("=" * 60) print("Benchmark Parameters:") for (key, value) in arguments.items(): print(f"{key} = {value}") print("Results:") print("mean / std = {:.2f}ms / {:.2f}ms".format(mean_t, std_t)) - print("="*60) + print("=" * 60) class HistogramBenchmark(parameterized.TestCase): @@ -50,7 +50,8 @@ def testHistogram(self, n, dtype, max_val): x = torch.randint(0, max_val, (n,)).cuda().to(dtype) mean_t, std_t, max_t, min_t = benchmark_function( - lambda: ops.histogram(x, max_val)) + lambda: ops.histogram(x, max_val), + ) arguments = { "n": n, "dtype": dtype, @@ -63,7 +64,8 @@ def testTorchHistogram(self, n, dtype, max_val): x = torch.randint(0, 128, (n,)).cuda().to(dtype) mean_t, std_t, max_t, min_t = benchmark_function( - lambda: torch.histc(x, max_val, 0, max_val-1)) + lambda: torch.histc(x, max_val, 0, max_val - 1), + ) arguments = { "n": n, "dtype": dtype, diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 6fe4626..b039cee 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -12,7 +12,10 @@ # this adds. def transpose_view(x): return torch.as_strided( - x, (x.shape[1], x.shape[0]), (x.stride()[1], x.stride()[0])) + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) _MATMUL_TESTS = ( @@ -27,7 +30,7 @@ def log_benchmark(name, arguments, time, std, flops): benchmark_util.log_benchmark(name, arguments, time, std) print("flops = {:.2f}B".format(flops / 1e9)) print("throughput = {:.2f}T".format(flops / 1e9 / time)) - print("="*60) + print("=" * 60) class MatmulBenchmark(parameterized.TestCase): @@ -48,29 +51,28 @@ def build_sparse_matrix(self, x, padded_bins, fhs, ne): block_rows * blocks_per_row + 1, blocks_per_row, dtype=torch.int32, - device=x.device) + device=x.device, + ) # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, - blocking, - block_rows, - blocks_per_row) + column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) data = torch.empty( column_indices.numel(), blocking, blocking, dtype=torch.float16, - device=x.device) + device=x.device, + ) shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices( - shape, data, offsets, column_indices) - return stk.Matrix(shape, - data, - row_indices, - column_indices, - offsets) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) def build_input_matrix(self, sl, hs, ne): x = torch.randn((sl, hs)).cuda().half() @@ -98,6 +100,7 @@ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): def benchmark(): return stk.ops.sdd(x, w, topo) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -105,8 +108,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("0::Fwd::SDD::NT", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + "0::Fwd::SDD::NT", + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): @@ -116,6 +124,7 @@ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): def benchmark(): return stk.ops.dsd(topo, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -123,8 +132,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("0::GradX::DSD::NN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + "0::GradX::DSD::NN", + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): @@ -134,6 +148,7 @@ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): def benchmark(): return stk.ops.dsd(topo, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -141,8 +156,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("0::GradW::DSD::TN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + "0::GradW::DSD::TN", + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): @@ -152,6 +172,7 @@ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): def benchmark(): return stk.ops.dsd(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -159,8 +180,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("1::Fwd::DSD::NN", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + "1::Fwd::DSD::NN", + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): @@ -172,6 +198,7 @@ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): def benchmark(): return stk.ops.sdd(out, w, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -179,8 +206,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("1::GradX::SDD::NT", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + "1::GradX::SDD::NT", + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): @@ -192,6 +224,7 @@ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): def benchmark(): return stk.ops.dsd(x, out) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -199,8 +232,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("1::GradW::DSD::TN", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + "1::GradW::DSD::TN", + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): @@ -213,6 +251,7 @@ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): def benchmark(): return torch.bmm(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -220,8 +259,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("0::Fwd:DDD::NT", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + "0::Fwd:DDD::NT", + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): @@ -233,6 +277,7 @@ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): def benchmark(): return torch.bmm(out, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -240,8 +285,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("0:GradX:DDD::NN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + "0:GradX:DDD::NN", + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): @@ -253,6 +303,7 @@ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): def benchmark(): return torch.bmm(out, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -260,8 +311,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("0:GradW:DDD::TN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + "0:GradW:DDD::TN", + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): @@ -271,6 +327,7 @@ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): def benchmark(): return torch.bmm(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -278,8 +335,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("1::Fwd::DDD::NN", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + "1::Fwd::DDD::NN", + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): @@ -291,6 +353,7 @@ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): def benchmark(): return torch.bmm(out, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -298,8 +361,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("1::GradX::DDD::NT", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + "1::GradX::DDD::NT", + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): @@ -311,6 +379,7 @@ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): def benchmark(): return torch.bmm(x, out) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -318,8 +387,13 @@ def benchmark(): "ffn_hidden_size": fhs, "num_experts": ne, } - log_benchmark("1::GradW::DDD::TN", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + "1::GradW::DDD::TN", + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) if __name__ == '__main__': diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 3c2685f..696629b 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -12,7 +12,14 @@ def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): ctx.save_for_backward(indices, bin_ids, bins, padded_bins) ctx.top_k = top_k return kernels.padded_gather( - x, indices, bin_ids, None, bins, padded_bins, top_k) + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) @staticmethod @custom_bwd @@ -21,6 +28,15 @@ def backward(ctx, grad): indices, bin_ids, bins, padded_bins = ctx.saved_tensors out = kernels.padded_scatter( - grad, indices, bin_ids, None, bins, padded_bins, ctx.top_k) + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) return out, None, None, None, None, None + + padded_gather = PaddedGatherOp.apply diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 22ae923..8780b33 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -11,11 +11,24 @@ class PaddedScatterOp(torch.autograd.Function): def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k): maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward( - indices, bin_ids, weights, bins, padded_bins, *maybe_x) + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) ctx.top_k = top_k ctx.x_shape = x.shape return kernels.padded_scatter( - x, indices, bin_ids, weights, bins, padded_bins, top_k) + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) @staticmethod @custom_bwd @@ -33,7 +46,8 @@ def backward(ctx, grad): weights, bins, padded_bins, - ctx.top_k) + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[3]: # need wgrad @@ -45,16 +59,26 @@ def backward(ctx, grad): bin_ids, bins, padded_bins, - ctx.top_k) + ctx.top_k, + ) return dgrad, None, None, wgrad, None, None, None, None -def padded_scatter(x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int): - return PaddedScatterOp.apply(x, indices, bin_ids, weights, bins, - padded_bins, top_k) +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index 87e17ed..8580d38 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -35,17 +35,28 @@ def testPaddedScatter(self, sl, hs, ne, top_k): x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) def benchmark(): - return ops.padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k) + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) time, std = benchmark_util.benchmark_function(benchmark) benchmark_util.log_benchmark( "Padded Scatter", - {"sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, - "top_k": top_k}, + { + "sequence_length": sl, + "hidden_size": hs, + "num_experts": ne, + "top_k": top_k, + }, time, - std) + std, + ) if __name__ == '__main__': diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index a93dc27..886781b 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -39,6 +39,7 @@ def testBinnedGather(self, sl, hs, ne): def benchmark(): return ops.binned_gather(x, indices, bins, ec) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -62,6 +63,7 @@ def testBinnedScatter(self, sl, hs, ne): def benchmark(): return ops.binned_scatter(x, indices, bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -85,6 +87,7 @@ def testPaddedGather(self, sl, hs, ne): def benchmark(): return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -109,6 +112,7 @@ def testPaddedScatter(self, sl, hs, ne): def benchmark(): return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, @@ -128,6 +132,7 @@ def testCopy(self, sl, hs, ne): def benchmark(): return y.copy_(x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { "sequence_length": sl, diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index 66e4b10..db995ff 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,5 +1,3 @@ - - def repeat(x, tiling): if all((t == 1 for t in tiling)): return x diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index 4d0cf34..f2ef1d0 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -7,26 +7,27 @@ # c++ operations. import megablocks_ops as ops + # Autograd wrapper for replicate kernel. class ReplicateOp(torch.autograd.Function): @staticmethod def forward(ctx, x, bins, num_outputs): ctx.save_for_backward(bins) - out = torch.empty( - (x.shape[0], num_outputs), - dtype=x.dtype, - device=x.device) + out = torch.empty((x.shape[0], num_outputs), + dtype=x.dtype, + device=x.device) ops.replicate_forward(x, bins, out) return out @staticmethod def backward(ctx, grad): bins, = ctx.saved_tensors - out = torch.empty( - (grad.shape[0], bins.shape[0]), - dtype=grad.dtype, - device=grad.device) + out = torch.empty((grad.shape[0], bins.shape[0]), + dtype=grad.dtype, + device=grad.device) ops.replicate_backward(grad, bins, out) return out, None, None + + replicate = ReplicateOp.apply diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 0e91d80..b4a8576 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -13,8 +13,7 @@ def forward(ctx, x, indices, bin_ids, weights, bins, top_k): ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) ctx.top_k = top_k ctx.x_shape = x.shape - return kernels.scatter( - x, indices, bin_ids, weights, bins, top_k) + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) @staticmethod @custom_bwd @@ -31,7 +30,8 @@ def backward(ctx, grad): bin_ids, weights, bins, - ctx.top_k) + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[3]: # need wgrad @@ -42,14 +42,17 @@ def backward(ctx, grad): indices, bin_ids, bins, - ctx.top_k) + ctx.top_k, + ) return dgrad, None, None, wgrad, None, None, None -def scatter(x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int): +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +): return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index a4bb99f..ce22783 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -7,13 +7,13 @@ # c++ operations. import megablocks_ops as ops - _BITS_FOR_DTYPE = { torch.int16: 16, torch.int32: 32, torch.int64: 64, } + # Autograd wrapper for sort kernel. # # NOTE: Does not support gradients. @@ -27,4 +27,6 @@ def forward(ctx, x, end_bit=None): iota_out = torch.empty_like(x) ops.sort(x, end_bit, x_out, iota_out) return (x_out, iota_out) + + sort = SortOp.apply diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 56e343f..8def5de 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -5,16 +5,13 @@ import numpy as np import torch - _SORT_TESTS = ( (16384, torch.int32, None), (16384, torch.int32, 2), (16384, torch.int32, 128), ) -_BASELINE_SORT_TESTS = ( - (16384,), -) +_BASELINE_SORT_TESTS = ((16384,),) def numpy_dtype(dtype): @@ -43,13 +40,13 @@ def benchmark_function(fn, iterations=10): def log_benchmark(arguments, mean_t, std_t): - print("="*60) + print("=" * 60) print("Benchmark Parameters:") for (key, value) in arguments.items(): print(f"{key} = {value}") print("Results:") print("mean / std = {:.2f}ms / {:.2f}ms".format(mean_t, std_t)) - print("="*60) + print("=" * 60) class SortBenchmark(parameterized.TestCase): @@ -62,7 +59,8 @@ def testSort(self, n, dtype, max_val): x = torch.randint(0, max_val, (n,)).cuda().to(dtype) mean_t, std_t, max_t, min_t = benchmark_function( - lambda: ops.sort(x, end_bit)) + lambda: ops.sort(x, end_bit), + ) arguments = { "n": n, "dtype": dtype, @@ -74,9 +72,10 @@ def testSort(self, n, dtype, max_val): def testTorchSort(self, n): x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: torch.sort(x)) - arguments = {"n": n,} + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + "n": n, + } log_benchmark(arguments, mean_t, std_t) diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index 5b511e6..76797da 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,5 +1,3 @@ - - def sum(x, dim=0): if x.shape[dim] == 1: return x.squeeze(dim=dim) diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index 7ce31bc..fb43daa 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -7,24 +7,33 @@ # c++ operations. import megablocks_ops as ops + # Autograd wrapper for topology kernel. # # NOTE: Does not support gradients. class TopologyOp(torch.autograd.Function): @staticmethod - def forward(ctx, - padded_bins, - block_size, - output_block_rows, - output_block_columns): - out = torch.empty(output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device) - ops.indices(padded_bins, - block_size, - output_block_rows, - output_block_columns, - out) + def forward( + ctx, + padded_bins, + block_size, + output_block_rows, + output_block_columns, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) return out + + topology = TopologyOp.apply diff --git a/pyproject.toml b/pyproject.toml index 2cf40c5..9d74667 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,3 +71,406 @@ ignore = [ "PERF2", "PERF4", ] + +# Yapf +[tool.yapf] +# Align closing bracket with visual indentation. +align_closing_bracket_with_visual_indent = false + +# Allow dictionary keys to exist on multiple lines. For example: +# +# x = { +# ('this is the first element of a tuple', +# 'this is the second element of a tuple'): +# value, +# } +allow_multiline_dictionary_keys = false + +# Allow lambdas to be formatted on more than one line. +allow_multiline_lambdas = false + +# Allow splitting before a default / named assignment in an argument list. +allow_split_before_default_or_named_assigns = true + +# Allow splits before the dictionary value. +allow_split_before_dict_value = true + +# Let spacing indicate operator precedence. For example: +# +# a = 1 * 2 + 3 / 4 +# b = 1 / 2 - 3 * 4 +# c = (1 + 2) * (3 - 4) +# d = (1 - 2) / (3 + 4) +# e = 1 * 2 - 3 +# f = 1 + 2 + 3 + 4 +# +# will be formatted as follows to indicate precedence: +# +# a = 1*2 + 3/4 +# b = 1/2 - 3*4 +# c = (1+2) * (3-4) +# d = (1-2) / (3+4) +# e = 1*2 - 3 +# f = 1 + 2 + 3 + 4 +# +arithmetic_precedence_indication = false + +# Number of blank lines surrounding top-level function and class +# definitions. +blank_lines_around_top_level_definition = 2 + +# Insert a blank line before a class-level docstring. +blank_line_before_class_docstring = false + +# Insert a blank line before a module docstring. +blank_line_before_module_docstring = true + +# Insert a blank line before a 'def' or 'class' immediately nested +# within another 'def' or 'class'. For example: +# +# class Foo: +# # <------ this blank line +# def method(): +# ... +blank_line_before_nested_class_or_def = true + +# Do not split consecutive brackets. Only relevant when +# dedent_closing_brackets is set. For example: +# +# call_func_that_takes_a_dict( +# { +# 'key1': 'value1', +# 'key2': 'value2', +# } +# ) +# +# would reformat to: +# +# call_func_that_takes_a_dict({ +# 'key1': 'value1', +# 'key2': 'value2', +# }) +coalesce_brackets = true + +# The column limit. +column_limit = 80 + +# The style for continuation alignment. Possible values are: +# +# - SPACE: Use spaces for continuation alignment. This is default behavior. +# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns +# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or +# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment. +# - VALIGN-RIGHT: Vertically align continuation lines to multiple of +# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if +# cannot vertically align continuation lines with indent characters. +continuation_align_style = 'SPACE' + +# Indent width used for line continuations. +continuation_indent_width = 4 + +# Put closing brackets on a separate line, dedented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is dedented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is dedented and on a separate line +dedent_closing_brackets = true + +# Disable the heuristic which places each list element on a separate line +# if the list is comma-terminated. +disable_ending_comma_heuristic = false + +# Place each dictionary entry onto its own line. +each_dict_entry_on_separate_line = true + +# Require multiline dictionary even if it would normally fit on one line. +# For example: +# +# config = { +# 'key1': 'value1' +# } +force_multiline_dict = false + +# The regex for an i18n comment. The presence of this comment stops +# reformatting of that line, because the comments are required to be +# next to the string they translate. +i18n_comment = '#\..*' + +# The i18n function call names. The presence of this function stops +# reformattting on that line, because the string it has cannot be moved +# away from the i18n comment. +i18n_function_call = 'N_, _' + +# Indent blank lines. +indent_blank_lines = false + +# Put closing brackets on a separate line, indented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is indented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is indented and on a separate line +indent_closing_brackets = false + +# Indent the dictionary value if it cannot fit on the same line as the +# dictionary key. For example: +# +# config = { +# 'key1': +# 'value1', +# 'key2': value1 + +# value2, +# } +indent_dictionary_value = true + +# The number of columns to use for indentation. +indent_width = 4 + +# Join short lines into one line. E.g., single line 'if' statements. +join_multiple_lines = false + +# Do not include spaces around selected binary operators. For example: +# +# 1 + 2 * 3 - 4 / 5 +# +# will be formatted as follows when configured with "*,/": +# +# 1 + 2*3 - 4/5 +no_spaces_around_selected_binary_operators = '' + +# Use spaces around default or named assigns. +spaces_around_default_or_named_assign = false + +# Adds a space after the opening '{' and before the ending '}' dict delimiters. +# +# {1: 2} +# +# will be formatted as: +# +# { 1: 2 } +spaces_around_dict_delimiters = false + +# Adds a space after the opening '[' and before the ending ']' list delimiters. +# +# [1, 2] +# +# will be formatted as: +# +# [ 1, 2 ] +spaces_around_list_delimiters = false + +# Use spaces around the power operator. +spaces_around_power_operator = false + +# Use spaces around the subscript / slice operator. For example: +# +# my_list[1 : 10 : 2] +spaces_around_subscript_colon = false + +# Adds a space after the opening '(' and before the ending ')' tuple delimiters. +# +# (1, 2, 3) +# +# will be formatted as: +# +# ( 1, 2, 3 ) +spaces_around_tuple_delimiters = false + +# The number of spaces required before a trailing comment. +# This can be a single value (representing the number of spaces +# before each trailing comment) or list of values (representing +# alignment column values; trailing comments within a block will +# be aligned to the first column value that is greater than the maximum +# line length within the block). For example: +# +# With spaces_before_comment=5: +# +# 1 + 1 # Adding values +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment +# +# With spaces_before_comment = '15, 20:' +# +# 1 + 1 # Adding values +# two + two # More adding +# +# longer_statement # This is a longer statement +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment +# short # This is a shorter statement +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15 +# two + two # More adding +# +# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20 +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length +# short # This is a shorter statement +# +spaces_before_comment = 2 + +# Insert a space between the ending comma and closing bracket of a list, +# etc. +space_between_ending_comma_and_closing_bracket = false + +# Use spaces inside brackets, braces, and parentheses. For example: +# +# method_call( 1 ) +# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ] +# my_set = { 1, 2, 3 } +space_inside_brackets = false + +# Split before arguments +split_all_comma_separated_values = false + +# Split before arguments, but do not split all subexpressions recursively +# (unless needed). +split_all_top_level_comma_separated_values = false + +# Split before arguments if the argument list is terminated by a +# comma. +split_arguments_when_comma_terminated = true + +# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@' +# rather than after. +split_before_arithmetic_operator = false + +# Set to True to prefer splitting before '&', '|' or '^' rather than +# after. +split_before_bitwise_operator = false + +# Split before the closing bracket if a list or dict literal doesn't fit on +# a single line. +split_before_closing_bracket = true + +# Split before a dictionary or set generator (comp_for). For example, note +# the split before the 'for': +# +# foo = { +# variable: 'Hello world, have a nice day!' +# for variable in bar if variable != 42 +# } +split_before_dict_set_generator = false + +# Split before the '.' if we need to split a longer expression: +# +# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d)) +# +# would reformat to something like: +# +# foo = ('This is a really long string: {}, {}, {}, {}' +# .format(a, b, c, d)) +split_before_dot = false + +# Split after the opening paren which surrounds an expression if it doesn't +# fit on a single line. +split_before_expression_after_opening_paren = false + +# If an argument / parameter list is going to be split, then split before +# the first argument. +split_before_first_argument = false + +# Set to True to prefer splitting before 'and' or 'or' rather than +# after. +split_before_logical_operator = false + +# Split named assignments onto individual lines. +split_before_named_assigns = true + +# Set to True to split list comprehensions and generators that have +# non-trivial expressions and multiple clauses before each of these +# clauses. For example: +# +# result = [ +# a_long_var + 100 for a_long_var in xrange(1000) +# if a_long_var % 10] +# +# would reformat to something like: +# +# result = [ +# a_long_var + 100 +# for a_long_var in xrange(1000) +# if a_long_var % 10] +split_complex_comprehension = true + +# The penalty for splitting right after the opening bracket. +split_penalty_after_opening_bracket = 300 + +# The penalty for splitting the line after a unary operator. +split_penalty_after_unary_operator = 10000 + +# The penalty of splitting the line around the '+', '-', '*', '/', '//', +# ``%``, and '@' operators. +split_penalty_arithmetic_operator = 300 + +# The penalty for splitting right before an if expression. +split_penalty_before_if_expr = 0 + +# The penalty of splitting the line around the '&', '|', and '^' +# operators. +split_penalty_bitwise_operator = 300 + +# The penalty for splitting a list comprehension or generator +# expression. +split_penalty_comprehension = 2100 + +# The penalty for characters over the column limit. +split_penalty_excess_character = 7000 + +# The penalty incurred by adding a line split to the unwrapped line. The +# more line splits added the higher the penalty. +split_penalty_for_added_line_split = 20 + +# The penalty of splitting a list of "import as" names. For example: +# +# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, +# long_argument_2, +# long_argument_3) +# +# would reformat to something like: +# +# from a_very_long_or_indented_module_name_yada_yad import ( +# long_argument_1, long_argument_2, long_argument_3) +split_penalty_import_names = 0 + +# The penalty of splitting the line around the 'and' and 'or' +# operators. +split_penalty_logical_operator = 300 + +# Use the Tab character for indentation. +use_tabs = false + +# Ignore directories +[tool.yapfignore] +ignore_patterns = [ + "runs/**/*.py", + "wandb/**/*.py", + "build/**/*.py", +] diff --git a/setup.py b/setup.py index 646a17b..ec6e4d8 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,10 @@ "megablocks_ops", ["csrc/ops.cu"], include_dirs=["csrc"], - extra_compile_args={"cxx": ["-fopenmp"], "nvcc": nvcc_flags}, + extra_compile_args={ + "cxx": ["-fopenmp"], + "nvcc": nvcc_flags, + }, ), ] @@ -44,7 +47,7 @@ ] extra_deps['dev'] = [ - 'absl-py', # todo: delete when finish removing all absl tests + 'absl-py', # todo: delete when finish removing all absl tests 'coverage[toml]==7.4.4', 'pytest_codeblocks>=0.16.1,<0.17', 'pytest-cov>=4,<5', diff --git a/tests/conftest.py b/tests/conftest.py index 328c712..9fff00e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,8 +23,11 @@ def _get_world_size(item: pytest.Item): return item.get_closest_marker('world_size', default=_default).args[0] - -def _get_option(config: pytest.Config, name: str, default: Optional[str] = None) -> str: # type: ignore +def _get_option( + config: pytest.Config, + name: str, + default: Optional[str] = None, +) -> str: # type: ignore val = config.getoption(name) if val is not None: assert isinstance(val, str) @@ -34,13 +37,20 @@ def _get_option(config: pytest.Config, name: str, default: Optional[str] = None) val = None if val is None: if default is None: - pytest.fail(f'Config option {name} is not specified but is required') + pytest.fail( + f'Config option {name} is not specified but is required', + ) val = default assert isinstance(val, str) return val -def _add_option(parser: pytest.Parser, name: str, help: str, choices: Optional[list[str]] = None): +def _add_option( + parser: pytest.Parser, + name: str, + help: str, + choices: Optional[list[str]] = None, +): parser.addoption( f'--{name}', default=None, diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 29fbdeb..3a3fab2 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -50,7 +50,8 @@ def configure_dist(request: pytest.FixtureRequest): device = None for item in request.session.items: - device = DeviceCPU() if item.get_closest_marker('gpu') is None else DeviceGPU() + device = DeviceCPU( + ) if item.get_closest_marker('gpu') is None else DeviceGPU() break assert device is not None @@ -74,7 +75,11 @@ def set_log_levels(): def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch): """Monkeypatch reproducibility get_random_seed to always return the rank zero seed, and set the random seed before each test to the rank local seed.""" - monkeypatch.setattr(reproducibility, 'get_random_seed', lambda: rank_zero_seed) + monkeypatch.setattr( + reproducibility, + 'get_random_seed', + lambda: rank_zero_seed, + ) reproducibility.seed_all(rank_zero_seed + dist.get_global_rank()) diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py index 3ead862..9ab70fe 100644 --- a/tests/layers/dmoe_test.py +++ b/tests/layers/dmoe_test.py @@ -44,24 +44,28 @@ ) -def construct_moes(hidden_size: int, - ffn_hidden_size: int, - moe_num_experts: int = 1, - moe_capacity_factor: int = 1, - moe_top_k: int = 1, - mlp_impl: str = 'sparse'): +def construct_moes( + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int = 1, + moe_capacity_factor: int = 1, + moe_top_k: int = 1, + mlp_impl: str = 'sparse', +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) - args = Arguments(hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=moe_num_experts, - moe_capacity_factor=moe_capacity_factor, - moe_top_k=moe_top_k, - init_method=init_method, - memory_optimized_mlp=True, - mlp_type='mlp', - mlp_impl=mlp_impl, - fp16=False, - bf16=True) + args = Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=moe_num_experts, + moe_capacity_factor=moe_capacity_factor, + moe_top_k=moe_top_k, + init_method=init_method, + memory_optimized_mlp=True, + mlp_type='mlp', + mlp_impl=mlp_impl, + fp16=False, + bf16=True, + ) mlp = testing.FFN(args) moe_mlp = moe.MoE(args) @@ -76,8 +80,9 @@ def construct_moes(hidden_size: int, ne, hs, fhs = moe_mlp.experts.mlp.w1.size() w1 = dmoe_mlp.experts.mlp.w1.view([ne, fhs, hs]) moe_mlp.experts.mlp.w1.copy_(torch.transpose(w1, 1, 2).contiguous()) - moe_mlp.experts.mlp.w2.copy_(dmoe_mlp.experts.mlp.w2.view([ne, fhs, - hs])) + moe_mlp.experts.mlp.w2.copy_( + dmoe_mlp.experts.mlp.w2.view([ne, fhs, hs]), + ) moe_mlp.router.layer.weight.copy_(dmoe_mlp.router.layer.weight) if moe_num_experts == 1: mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze()) @@ -88,18 +93,22 @@ def construct_moes(hidden_size: int, @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) -def test_dmoe_forward(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +def test_dmoe_forward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, _, _, layer = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=num_experts, - moe_top_k=top_k, - mlp_impl=mlp_impl) + _, _, _, layer = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=num_experts, + moe_top_k=top_k, + mlp_impl=mlp_impl, + ) out, _ = layer(x) assert out.shape == x.shape @@ -109,20 +118,24 @@ def test_dmoe_forward(bs: int, @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) -def test_dmoe_forward_backward(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +def test_dmoe_forward_backward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() x.requires_grad_(True) - args, _, _, layer = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=num_experts, - moe_top_k=top_k, - mlp_impl=mlp_impl) + args, _, _, layer = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=num_experts, + moe_top_k=top_k, + mlp_impl=mlp_impl, + ) out, _ = layer(x) assert out.shape == x.shape @@ -136,18 +149,22 @@ def test_dmoe_forward_backward(bs: int, @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS) -def test_dmoe_forward_vs_baseline(bs: int, - sl: int, - hs: int, - mlp_impl: str = 'sparse'): +def test_dmoe_forward_vs_baseline( + bs: int, + sl: int, + hs: int, + mlp_impl: str = 'sparse', +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, mlp, _, dmoe_mlp = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=1, - moe_capacity_factor=1, - moe_top_k=1, - mlp_impl=mlp_impl) + _, mlp, _, dmoe_mlp = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=1, + moe_capacity_factor=1, + moe_top_k=1, + mlp_impl=mlp_impl, + ) expected_out = mlp(x) out, _ = dmoe_mlp(x) @@ -158,21 +175,25 @@ def test_dmoe_forward_vs_baseline(bs: int, @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) -def test_dmoe_forward_vs_moe(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +def test_dmoe_forward_vs_moe( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): torch.manual_seed(42) x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, _, moe_mlp, dmoe_mlp = construct_moes(hidden_size=hs, - ffn_hidden_size=hs, - moe_num_experts=num_experts, - moe_capacity_factor=0, - mlp_impl=mlp_impl) + _, _, moe_mlp, dmoe_mlp = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs, + moe_num_experts=num_experts, + moe_capacity_factor=0, + mlp_impl=mlp_impl, + ) expected_out, _ = moe_mlp(x) out, _ = dmoe_mlp(x) diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index b09cff3..93d7c0a 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -14,10 +14,11 @@ def construct_dmoe_glu( - hidden_size: int, - ffn_hidden_size: int, - mlp_impl: str ='sparse', - memory_optimized_mlp: bool =False): + hidden_size: int, + ffn_hidden_size: int, + mlp_impl: str = 'sparse', + memory_optimized_mlp: bool = False, +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) args = Arguments( hidden_size=hidden_size, @@ -29,7 +30,8 @@ def construct_dmoe_glu( mlp_type='glu', mlp_impl=mlp_impl, fp16=False, - bf16=True) + bf16=True, + ) glu = testing.GLU(args) dmoe_glu = dmlp_registry.get(args) @@ -45,7 +47,6 @@ def construct_dmoe_glu( return args, glu, dmoe_glu - @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS) def test_glu_forward_grouped_mlp(bs: int, sl: int, hs: int): @@ -54,7 +55,8 @@ def test_glu_forward_grouped_mlp(bs: int, sl: int, hs: int): _, glu, dmoe_glu = construct_dmoe_glu( hidden_size=hs, ffn_hidden_size=hs * 2, - mlp_impl='grouped') + mlp_impl='grouped', + ) expected_out = glu(x) tokens_per_expert = torch.tensor([bs * sl]).cuda() @@ -74,7 +76,8 @@ def test_glu_forward_grouped_mlp_mem_opt(bs: int, sl: int, hs: int): hidden_size=hs, ffn_hidden_size=hs * 2, mlp_impl='grouped', - memory_optimized_mlp=True) + memory_optimized_mlp=True, + ) expected_out = glu(x) tokens_per_expert = torch.tensor([bs * sl]).cuda() @@ -93,7 +96,8 @@ def test_glu_forward_sparse_mlp(bs: int, sl: int, hs: int): _, glu, dmoe_glu = construct_dmoe_glu( hidden_size=hs, ffn_hidden_size=hs * 2, - mlp_impl='sparse') + mlp_impl='sparse', + ) expected_out = glu(x) with torch.no_grad(): diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index 75ea196..8591a7a 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -22,7 +22,6 @@ (16, 1024, 512, 8, 8), ) - _DENSE_TESTS = ( (16, 1024, 512), (8, 2048, 512), @@ -30,11 +29,12 @@ def construct_moe( - hidden_size, - ffn_hidden_size, - moe_num_experts=1, - moe_capacity_factor=1, - moe_top_k=1): + hidden_size, + ffn_hidden_size, + moe_num_experts=1, + moe_capacity_factor=1, + moe_top_k=1, +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) args = Arguments( hidden_size=hidden_size, @@ -42,7 +42,8 @@ def construct_moe( moe_num_experts=moe_num_experts, moe_capacity_factor=moe_capacity_factor, moe_top_k=moe_top_k, - init_method=init_method) + init_method=init_method, + ) mlp = testing.FFN(args) moe_mlp = moe.MoE(args) @@ -68,16 +69,24 @@ def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): hidden_size=hs, ffn_hidden_size=hs * 2, moe_num_experts=num_experts, - moe_top_k=top_k) + moe_top_k=top_k, + ) out, _ = layer(x) assert out.shape == x.shape moe.clear_load_balancing_loss() + @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS) -def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): +def test_moe_forward_backward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, +): x = torch.randn(sl, bs, hs).half().cuda() x.requires_grad_(True) @@ -85,7 +94,8 @@ def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k hidden_size=hs, ffn_hidden_size=hs * 2, moe_num_experts=num_experts, - moe_top_k=top_k) + moe_top_k=top_k, + ) out, _ = layer(x) assert out.shape == x.shape @@ -102,9 +112,7 @@ def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k def test_moe_forward_vs_dense(bs: int, sl: int, hs: int): x = torch.randn(sl, bs, hs).half().cuda() - _, mlp, moe_mlp = construct_moe( - hidden_size=hs, - ffn_hidden_size=hs * 2) + _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2) expected_out = mlp(x) out, _ = moe_mlp(x) @@ -119,9 +127,7 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): x = torch.randn(sl, bs, hs).half().cuda() x.requires_grad_(True) - _, mlp, moe_mlp = construct_moe( - hidden_size=hs, - ffn_hidden_size=hs * 2) + _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2) out, _ = moe_mlp(x) loss = out.sum() @@ -141,7 +147,7 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): x.grad = None # Verify the gradients match. - assert w1_grad.shape == expected_w1_grad.shape + assert w1_grad.shape == expected_w1_grad.shape assert w2_grad.shape == expected_w2_grad.shape assert torch.allclose(w1_grad, expected_w1_grad) assert torch.allclose(w2_grad, expected_w2_grad) diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py index 0aa4269..b7fce22 100644 --- a/tests/layers/parallelism_test.py +++ b/tests/layers/parallelism_test.py @@ -15,6 +15,7 @@ (4, 1, 512, 2048, 4, 1, True), ) + # Todo: Fix this long term @pytest.fixture def group(): @@ -23,17 +24,25 @@ def group(): @pytest.mark.world_size(2) @pytest.mark.gpu -@pytest.mark.parametrize(('batch_size', 'sequence_length', 'hidden_size', 'ffn_hidden_size', 'num_experts', 'top_k', 'memory_optimized'), - _PARALLELISM_TESTS) +@pytest.mark.parametrize(( + 'batch_size', + 'sequence_length', + 'hidden_size', + 'ffn_hidden_size', + 'num_experts', + 'top_k', + 'memory_optimized', +), _PARALLELISM_TESTS) def test_expert_parallel_versus_weight_parallel( - group, - batch_size: int, - sequence_length: int, - hidden_size: int, - ffn_hidden_size: int, - num_experts: int, - top_k: int, - memory_optimized: bool): + group, + batch_size: int, + sequence_length: int, + hidden_size: int, + ffn_hidden_size: int, + num_experts: int, + top_k: int, + memory_optimized: bool, +): init_fn = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1) ep_args = arguments.Arguments( @@ -47,7 +56,8 @@ def test_expert_parallel_versus_weight_parallel( bf16=False, device=torch.cuda.current_device(), init_method=init_fn, - memory_optimized_mlp=memory_optimized) + memory_optimized_mlp=memory_optimized, + ) wp_args = arguments.Arguments( hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, @@ -59,7 +69,8 @@ def test_expert_parallel_versus_weight_parallel( bf16=False, device=torch.cuda.current_device(), init_method=init_fn, - memory_optimized_mlp=memory_optimized) + memory_optimized_mlp=memory_optimized, + ) # NOTE: Reset the seed so that the models get identical weights. torch.manual_seed(1234) @@ -70,10 +81,9 @@ def test_expert_parallel_versus_weight_parallel( # NOTE: Include the rank in the seed so we get different data per rank. rank = torch.distributed.get_rank(group) torch.manual_seed(1234 * rank) - x = torch.randn( - (batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.float32).requires_grad_(True) + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.float32).requires_grad_(True) # Test forward. out, _ = wp(x) @@ -86,7 +96,9 @@ def test_expert_parallel_versus_weight_parallel( assert np.testing.assert_allclose( out.detach().float().cpu(), expected_out.detach().float().cpu(), - rtol=1e-4, atol=1e-4) is None + rtol=1e-4, + atol=1e-4, + ) is None # Test backward. out.mean().backward() @@ -97,8 +109,7 @@ def test_expert_parallel_versus_weight_parallel( def gather(x): m, n = x.shape world_size = torch.distributed.get_world_size(group) - out = torch.empty( - m * world_size, n, device=x.device, dtype=x.dtype) + out = torch.empty(m * world_size, n, device=x.device, dtype=x.dtype) torch.distributed.all_gather_into_tensor(out, x, group=group) return out @@ -114,7 +125,9 @@ def permute(x): assert np.testing.assert_allclose( wp_w2_grad.float().cpu(), ep_w2_grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None + rtol=1e-5, + atol=1e-5, + ) is None wp_w1_grad = gather(wp.experts.mlp.w1.grad) ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad)) @@ -122,7 +135,9 @@ def permute(x): assert np.testing.assert_allclose( wp_w1_grad.float().cpu(), ep_w1_grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None + rtol=1e-5, + atol=1e-5, + ) is None # Verify the router weight gradient, which is not sharded. for i in range(torch.distributed.get_world_size(group)): @@ -131,4 +146,6 @@ def permute(x): assert np.testing.assert_allclose( wp.router.layer.weight.grad.float().cpu(), ep.router.layer.weight.grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None + rtol=1e-5, + atol=1e-5, + ) is None diff --git a/tests/ops/binned_gather_test.py b/tests/ops/binned_gather_test.py index cc59ae3..d889bce 100644 --- a/tests/ops/binned_gather_test.py +++ b/tests/ops/binned_gather_test.py @@ -46,8 +46,13 @@ def test_binned_gather(sl: int, hs: int, ne: int, top_k: int): _, indices = ops.sort(top_expert) bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0) - def binned_gather(x: torch.Tensor, indices: torch.Tensor, - bins: torch.Tensor, ec: int, top_k: int): + def binned_gather( + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + ec: int, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() bins = bins.cpu().numpy() diff --git a/tests/ops/binned_scatter_test.py b/tests/ops/binned_scatter_test.py index 2d1c585..d13f89c 100644 --- a/tests/ops/binned_scatter_test.py +++ b/tests/ops/binned_scatter_test.py @@ -48,8 +48,13 @@ def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int): x = ops.binned_gather(x, indices, bins, ec, top_k) - def binned_scatter(x: torch.Tensor, indices: torch.Tensor, - weights: torch.Tensor, bins: torch.Tensor, top_k: int): + def binned_scatter( + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() weights = weights.cpu().numpy() @@ -66,10 +71,14 @@ def binned_scatter(x: torch.Tensor, indices: torch.Tensor, out[index, :] += scale * x[i, j, :] start = end return torch.from_numpy(out).cuda().half() + out = ops.binned_scatter(x, indices, weights, bins, top_k) expected_out = binned_scatter(x, indices, weights, bins, top_k) # NOTE: We need to check approximate equality because the # scatter reduce uses atomics. assert np.testing.assert_allclose( - out.cpu(), expected_out.cpu(), rtol=5e-3) is None + out.cpu(), + expected_out.cpu(), + rtol=5e-3, + ) is None diff --git a/tests/ops/histogram_test.py b/tests/ops/histogram_test.py index 25b30cb..5af55e7 100644 --- a/tests/ops/histogram_test.py +++ b/tests/ops/histogram_test.py @@ -78,6 +78,7 @@ def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int): x = torch.randint(0, max_val, (m, n)).cuda().to(dtype) out = ops.histogram(x, max_val) - expected_out = torch.stack( - [torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)]) + expected_out = torch.stack([ + torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1) + ]) assert torch.all(torch.eq(out, expected_out)) diff --git a/tests/ops/padded_gather_test.py b/tests/ops/padded_gather_test.py index e6eb7f7..e7b8a09 100644 --- a/tests/ops/padded_gather_test.py +++ b/tests/ops/padded_gather_test.py @@ -63,9 +63,14 @@ def testPaddedGather(sl: int, hs: int, ne: int, top_k: int): padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - def padded_gather(x: torch.Tensor, indices: torch.Tensor, - bin_ids: torch.Tensor, bins: torch.Tensor, - padded_bins: torch.Tensor, top_k: int): + def padded_gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() bin_ids = bin_ids.cpu().numpy() diff --git a/tests/ops/padded_scatter_test.py b/tests/ops/padded_scatter_test.py index ebd04a8..637b04b 100644 --- a/tests/ops/padded_scatter_test.py +++ b/tests/ops/padded_scatter_test.py @@ -94,10 +94,15 @@ def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int): # Gather the data to prepare for backwards. x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - def padded_scatter(x: torch.Tensor, indices: torch.Tensor, - bin_ids: torch.Tensor, weights: torch.Tensor, - bins: torch.Tensor, padded_bins: torch.Tensor, - top_k: int): + def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): x = x.detach().cpu().numpy() indices: np.ndarray = _to_numpy(indices) bin_ids: np.ndarray = _to_numpy(bin_ids) @@ -120,10 +125,24 @@ def padded_scatter(x: torch.Tensor, indices: torch.Tensor, in_idx += 1 return torch.from_numpy(out).cuda().half() - out = ops.padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, - top_k) - expected_out = padded_scatter(x, indices, bin_ids, weights, bins, - padded_bins, top_k) + out = ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + expected_out = padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) out.backward(torch.randn_like(out)) # sanity check backward pass diff --git a/tests/ops/sort_test.py b/tests/ops/sort_test.py index 243aef1..f65ae16 100644 --- a/tests/ops/sort_test.py +++ b/tests/ops/sort_test.py @@ -32,7 +32,8 @@ def torch_to_numpy_dtype( - dtype: torch.dtype) -> Union[np.int16, np.int32, np.int64]: + dtype: torch.dtype, +) -> Union[np.int16, np.int32, np.int64]: types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = { torch.int16: np.int16, torch.int32: np.int32, diff --git a/tests/ops/topology_test.py b/tests/ops/topology_test.py index a7135be..b7a28b8 100644 --- a/tests/ops/topology_test.py +++ b/tests/ops/topology_test.py @@ -48,8 +48,12 @@ def test_topology(sl: int, hs: int, ne: int): output_block_rows = int(padded_bins[-1]) // blocking output_block_columns = hs // blocking - def topology(padded_bins: torch.Tensor, blocking: torch.Tensor, rows: int, - columns: int): + def topology( + padded_bins: torch.Tensor, + blocking: torch.Tensor, + rows: int, + columns: int, + ): padded_bins = padded_bins.cpu().numpy() out = np.zeros([rows * columns]) @@ -62,8 +66,16 @@ def topology(padded_bins: torch.Tensor, blocking: torch.Tensor, rows: int, start += 1 return torch.from_numpy(out).cuda().short() - out = ops.topology(padded_bins, blocking, output_block_rows, - output_block_columns) - expected_out = topology(padded_bins, blocking, output_block_rows, - output_block_columns) + out = ops.topology( + padded_bins, + blocking, + output_block_rows, + output_block_columns, + ) + expected_out = topology( + padded_bins, + blocking, + output_block_rows, + output_block_columns, + ) assert torch.all(torch.eq(out, expected_out)) From beb5d9bd8823e7f752708bdaa234520205a30e61 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 02:57:37 +0000 Subject: [PATCH 07/43] yapf --- megablocks/backend/kernels.py | 16 ++------ megablocks/grouped_gemm_util.py | 7 ++-- megablocks/layers/all_to_all.py | 4 +- megablocks/layers/arguments.py | 7 +--- megablocks/layers/dmlp_registry.py | 4 +- megablocks/layers/dmoe.py | 14 ++----- megablocks/layers/gelu.py | 5 +-- megablocks/layers/glu.py | 26 +++---------- megablocks/layers/memory_test.py | 17 ++------- megablocks/layers/mlp.py | 42 +++++---------------- megablocks/layers/moe.py | 53 ++++++--------------------- megablocks/layers/mpu.py | 32 ++++------------ megablocks/layers/testing.py | 3 +- megablocks/layers/weight_parallel.py | 10 +---- megablocks/ops/histogram_benchmark.py | 8 +--- megablocks/ops/replicate.py | 8 +--- megablocks/ops/sort_benchmark.py | 4 +- pyproject.toml | 12 ++++-- setup.py | 9 +---- tests/conftest.py | 4 +- tests/fixtures/autouse.py | 3 +- tests/layers/dmoe_test.py | 22 ++++------- tests/layers/moe_test.py | 6 +-- tests/layers/parallelism_test.py | 3 +- tests/ops/histogram_test.py | 4 +- tests/ops/sort_test.py | 4 +- 26 files changed, 87 insertions(+), 240 deletions(-) diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index a1668eb..3366f48 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -19,9 +19,7 @@ def assert_is_vector(x): def assert_equal(a, b): if a != b: - raise ValueError( - f"Expected dimensions to be equal but got {a} and {b}.", - ) + raise ValueError(f"Expected dimensions to be equal but got {a} and {b}.",) # a: (tokens, hidden_size), real. @@ -182,9 +180,7 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): assert_equal(indices.shape[0], weights.shape[0]) tokens = indices.shape[0] // top_k - out = torch.empty((tokens, top_k, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( out, x, @@ -400,9 +396,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): assert_equal(weights.shape[0], x.shape[0] * top_k) num_experts = bins.shape[0] - out = torch.zeros((num_experts, expert_capacity, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) _binned_copy[(num_experts, expert_capacity)]( x, @@ -432,9 +426,7 @@ def binned_scatter(x, indices, weights, bins, top_k): num_experts, expert_capacity, hidden_size = x.shape tokens = indices.shape[0] // top_k - out = torch.zeros((tokens, top_k, hidden_size), - dtype=x.dtype, - device=x.device) + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) _binned_copy[(num_experts, expert_capacity)]( out, x, diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index bdd81ee..36fd7f0 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -9,10 +9,9 @@ def grouped_gemm_is_available(): def assert_grouped_gemm_is_available(): - assert grouped_gemm_is_available(), ( - "Grouped GEMM not available. Please run " - "`pip install git+https://github.com/tgale96/grouped_gemm@main`." - ) + assert grouped_gemm_is_available( + ), ("Grouped GEMM not available. Please run " + "`pip install git+https://github.com/tgale96/grouped_gemm@main`.") backend = grouped_gemm.backend if grouped_gemm_is_available() else None diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index b94f662..07cc584 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -5,9 +5,7 @@ class AllToAllOp(torch.autograd.Function): @staticmethod def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty((sum(output_split_sizes),) + x.shape[1:], - device=x.device, - dtype=x.dtype) + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) ctx.input_shape = x.shape ctx.output_split_sizes = output_split_sizes diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 674dd47..fe67ad6 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -58,13 +58,10 @@ class Arguments: # shared expert arguments shared_expert: bool = False # enable using shared expert fc_cls: torch.nn.Module = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field( - default_factory=dict, - ) # kwargs for custom fc layers + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored shared_expert_hidden_size: Optional[ - int - ] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) def __post_init__(self): diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index b227a2e..a5b86d8 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -35,8 +35,6 @@ def get(args: Arguments) -> MlpType: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError( - f'{args.mlp_type} does not support {args.mlp_impl} backend.', - ) + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 52ef7a7..5ae5ee8 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -24,8 +24,7 @@ def __init__(self, args: Arguments): # Calculate the number of bits needed to represent the column indices # in the intermediate sparse matrix. - max_column_index = ((self.ffn_hidden_size * self.num_experts) // - self.blocking) + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) self.transpose_sort_end_bit = max( int(np.ceil(np.log2(max_column_index))), 1, @@ -69,8 +68,7 @@ def topology(self, x, padded_bins): assert padded_tokens % self.blocking == 0 if self.ffn_hidden_size % self.blocking != 0: raise ValueError( - f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' - + + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + f'the block size {self.blocking}. Please update your configuration.', ) @@ -160,9 +158,7 @@ def sparse_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = ( - self.indices_and_padded_bins(top_experts) - ) + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) @@ -245,9 +241,7 @@ def grouped_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts) - ) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) out = self.grouped_permute_and_compute( x, diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index e0eb8c0..50a7c56 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -6,10 +6,7 @@ @torch.jit.script def _gelu_backward_inplace(g, x): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = ( - 0.5 * x * ((1 - tanh_out * tanh_out) * - (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out) - ) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) return g.mul_(ff) diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 9a41a9b..d321aeb 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -42,9 +42,7 @@ def __init__(self, args: Arguments): ) if self.args.moe_weight_parallelism: - raise NotImplementedError( - "Weight parallelism not yet supported with GLU.", - ) + raise NotImplementedError("Weight parallelism not yet supported with GLU.",) def forward(self, x, topo): if self.args.memory_optimized_mlp: @@ -52,12 +50,8 @@ def forward(self, x, topo): "Memory optimized implementation not yet supported with GLU with sparse kernels.", ) - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad( - self.v1, - ), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor( - v1, - ), resolve_dtensor(w2) + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) # Compute the GLU. x1 = stk.ops.sdd(x, w1.t(), topo) @@ -82,10 +76,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): v1 = v1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if ( - not x.is_contiguous() or not w1.is_contiguous() or - not v1.is_contiguous() or not w2.is_contiguous() - ): + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -112,10 +103,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if ( - not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2] - ): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError("Expected all MLP inputs to need grad.") # Unpack saved tensors @@ -189,9 +177,7 @@ def forward(self, x, tokens_per_expert): self.scale_grad(self.v1), self.scale_grad(self.w2), ) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor( - v1, - ), resolve_dtensor(w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) # Re-shape the weights for the grouped GEMMs. ne = mpu.experts_per_rank(self.args) diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index c1512bd..22bd3dc 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -54,16 +54,11 @@ def test_memory( # Report peak memory. mem = torch.cuda.max_memory_allocated() print("Max Memory Allocated = {:0.0f}MiB".format(mem / 1e6)) - print( - "Max Memory Reserved = {:0.0f}MiB".format( - torch.cuda.max_memory_reserved() / 1e6, - ), - ) + print("Max Memory Reserved = {:0.0f}MiB".format(torch.cuda.max_memory_reserved() / 1e6,),) # Calculate weight and gradient memory usage. weight_memory = 2 * ( - layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + - layer.experts.mlp.w2.numel() + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() ) def grad_numel(x): @@ -72,16 +67,12 @@ def grad_numel(x): return 0 grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + - grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) ) weight_memory += grad_memory print("Weight Memory Allocated = {:0.0f}MiB".format(weight_memory / 1e6)) - print( - "Activation Memory Allocated = {:0.0f}MiB".format((mem - weight_memory) - / 1e6,), - ) + print("Activation Memory Allocated = {:0.0f}MiB".format((mem - weight_memory) / 1e6,),) # Manually calculate GPU memory usage from the garbage # collector. diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 8ce5fe3..024e405 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -148,9 +148,7 @@ def __init__(self, args: Arguments): self.gradient_scale = None if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size( - self.args, - ) + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) def scale_grad(self, w): if self.gradient_scale is None: @@ -208,10 +206,7 @@ def forward(ctx, x, w1, w2, topo, activation_fn): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if ( - not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous() - ): + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") topo_tensors = ( @@ -247,10 +242,7 @@ def forward(ctx, x, w1, w2, topo, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if ( - not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2] - ): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError("Expected all MLP inputs to need grad.") # unpack saved tensors @@ -328,10 +320,8 @@ class SparseMLP(torch.nn.Module): def __init__(self, args: Arguments): super().__init__() self.args = args - self._num_rows_per_rank = ( - (mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // - mpu.get_weight_parallel_world_size(args) - ) + self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // + mpu.get_weight_parallel_world_size(args)) self.w1 = torch.nn.Parameter( torch.empty( @@ -378,9 +368,7 @@ def __init__(self, args: Arguments): ), ) - self._should_set_parallelism_attribute = ( - args.moe_expert_model_parallelism or args.moe_weight_parallelism - ) + self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism) mpu.set_expert_model_parallel_attributes( self.w1, self._should_set_parallelism_attribute, @@ -392,9 +380,7 @@ def __init__(self, args: Arguments): self.gradient_scale = None if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size( - self.args, - ) + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) def scale_grad(self, w): if self.gradient_scale is None: @@ -454,10 +440,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if ( - not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous() - ): + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -483,10 +466,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if ( - not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2] - ): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError("Expected all MLP inputs to need grad.") # Unpack saved tensors @@ -560,9 +540,7 @@ def forward(self, x, tokens_per_expert): w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) if self.args.moe_weight_parallelism: - raise NotImplementedError( - "Weight parallelism not yet supported with GroupedMLP.", - ) + raise NotImplementedError("Weight parallelism not yet supported with GroupedMLP.",) if self.args.memory_optimized_mlp: return memory_optimized_grouped_mlp( diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 7f659d2..52ab852 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -31,9 +31,7 @@ def batched_load_balancing_loss(args: Arguments): # tokens_per_expert[i].shape = (num_experts) # expert_scores[i].shape = (tokens, num_experts) tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = ( - args.num_layers // args.pipeline_model_parallel_size - ) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) if args.num_layers_per_virtual_pipeline_stage is not None: num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage @@ -57,16 +55,10 @@ def batched_load_balancing_loss(args: Arguments): ) # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all(( - x.ndim == 1 and x.numel() == args.moe_num_experts - for x in tokens_per_expert - )) + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) tokens = expert_scores[0].shape[0] - assert all((( - x.ndim == 2 and x.shape[1] == args.moe_num_experts and - x.shape[0] == tokens - ) for x in expert_scores)) + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) # Concatenate the contributions of each layer and convert to # the correct types and formats for the dot product. @@ -130,16 +122,11 @@ def __init__(self, args: Arguments): self.register_parameter('bias', None) # Select the forward function for the operating mode. - self.forward_fn = ( - self.parallel_forward_once - if args.moe_expert_model_parallelism else self.forward_once - ) + self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) def expert_capacity(self, tokens): world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = ( - self.top_k * tokens * world_size / self.num_experts - ) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) return int(self.args.moe_capacity_factor * tokens_per_expert) def load_balancing_loss(self, tokens_per_expert, expert_scores): @@ -208,9 +195,7 @@ def forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts) - ) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. @@ -256,9 +241,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts) - ) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. @@ -270,9 +253,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Pass token count information to the device on which the # target expert resides. - parallel_tokens_per_expert = torch.empty_like( - repeated_tokens_per_expert, - ) + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) tpe_handle = torch.distributed.all_to_all_single( parallel_tokens_per_expert, repeated_tokens_per_expert, @@ -296,12 +277,8 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Reshape to [world_size, num_experts_per_rank]. world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = ( - repeated_tokens_per_expert.view(world_size, experts_per_rank) - ) - parallel_tokens_per_expert = ( - parallel_tokens_per_expert.view(world_size, experts_per_rank) - ) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) # TODO(tgale): It might be faster to do this on the GPU and # then communicate the results back to the host. @@ -343,10 +320,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): parallel_tokens_per_expert.flatten(), 0, ) - replicate_bins = ( - replicate_bins.view(1) - if not len(replicate_bins.size()) else replicate_bins - ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) # Construct the expert indices for the permuted tokens. parallel_top_expert = torch.remainder( @@ -375,10 +349,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): dtype=torch.int, ) parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) - if not len(parallel_bins.size()) else parallel_bins - ) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 3bed037..49b5f73 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -7,17 +7,11 @@ def is_moe_param(tensor: torch.Tensor) -> bool: def get_expert_parallel_world_size(args: Arguments) -> int: - return ( - torch.distributed.get_world_size(args.expert_parallel_group) - if args.moe_expert_model_parallelism else 1 - ) + return (torch.distributed.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) def get_expert_parallel_rank(args: Arguments) -> int: - return ( - torch.distributed.get_rank(args.expert_parallel_group) - if args.moe_expert_model_parallelism else 0 - ) + return (torch.distributed.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) def set_expert_model_parallel_attributes( @@ -29,9 +23,7 @@ def set_expert_model_parallel_attributes( def param_is_expert_model_parallel(param: torch.Tensor) -> bool: - return ( - hasattr(param, 'expert_model_parallel') and param.expert_model_parallel - ) + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) def copy_expert_model_parallel_attributes( @@ -47,17 +39,11 @@ def copy_expert_model_parallel_attributes( def get_weight_parallel_world_size(args: Arguments) -> int: - return ( - torch.distributed.get_world_size(args.weight_parallel_group) - if args.moe_weight_parallelism else 1 - ) + return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1) def get_weight_parallel_rank(args: Arguments) -> int: - return ( - torch.distributed.get_rank(args.weight_parallel_group) - if args.moe_weight_parallelism else 0 - ) + return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0) def synchronized_print(group, *x): @@ -75,9 +61,7 @@ def expert_sharding_degree(args: Arguments) -> int: esd = min(world_size, args.moe_num_experts) if (args.moe_num_experts % esd) != 0: - raise ValueError( - f"Cannot shard {args.moe_num_experts} experts {esd} ways.", - ) + raise ValueError(f"Cannot shard {args.moe_num_experts} experts {esd} ways.",) return esd @@ -87,9 +71,7 @@ def hidden_sharding_degree(args: Arguments) -> int: hsd = world_size // esd if (args.ffn_hidden_size % hsd) != 0: - raise ValueError( - f"Cannot shard {args.ffn_hidden_size} features {hsd} ways.", - ) + raise ValueError(f"Cannot shard {args.ffn_hidden_size} features {hsd} ways.",) if (esd * hsd) != world_size: raise ValueError( f"Invalid sharding. 'expert_sharding_degree' " diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index 5f027dc..b81366b 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -54,6 +54,5 @@ def __init__(self, args: Arguments): ) def forward(self, x): - x1 = F.gelu(torch.matmul(x, self.w1), - approximate="tanh") * torch.matmul(x, self.v1) + x1 = F.gelu(torch.matmul(x, self.w1), approximate="tanh") * torch.matmul(x, self.v1) return torch.matmul(x1, self.w2) diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index 272d3d2..579832d 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -260,10 +260,7 @@ def forward(ctx, x, w1, w2, topo, group): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if ( - not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous() - ): + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -305,10 +302,7 @@ def backward(ctx, ddsd_out): x, w1, w2 = ctx.saved_tensors[:3] sdd_out = stk.Matrix(ctx.shape, *ctx.saved_tensors[3:]) - if ( - not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2] - ): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError("Expected all MLP inputs to need grad.") # Start the weight gather asynchronously to overlap with the diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index c6c31e1..cc977ec 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -49,9 +49,7 @@ class HistogramBenchmark(parameterized.TestCase): def testHistogram(self, n, dtype, max_val): x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: ops.histogram(x, max_val), - ) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) arguments = { "n": n, "dtype": dtype, @@ -63,9 +61,7 @@ def testHistogram(self, n, dtype, max_val): def testTorchHistogram(self, n, dtype, max_val): x = torch.randint(0, 128, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: torch.histc(x, max_val, 0, max_val - 1), - ) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) arguments = { "n": n, "dtype": dtype, diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index f2ef1d0..74d9324 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -14,18 +14,14 @@ class ReplicateOp(torch.autograd.Function): @staticmethod def forward(ctx, x, bins, num_outputs): ctx.save_for_backward(bins) - out = torch.empty((x.shape[0], num_outputs), - dtype=x.dtype, - device=x.device) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) ops.replicate_forward(x, bins, out) return out @staticmethod def backward(ctx, grad): bins, = ctx.saved_tensors - out = torch.empty((grad.shape[0], bins.shape[0]), - dtype=grad.dtype, - device=grad.device) + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) ops.replicate_backward(grad, bins, out) return out, None, None diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 8def5de..47df184 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -58,9 +58,7 @@ def testSort(self, n, dtype, max_val): end_bit = int(np.ceil(np.log2(max_val))) x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: ops.sort(x, end_bit), - ) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) arguments = { "n": n, "dtype": dtype, diff --git a/pyproject.toml b/pyproject.toml index 9d74667..f617559 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,6 +48,9 @@ include = [ # Ruff global [tool.ruff] + +preview = true # enable preview features, see https://docs.astral.sh/ruff/preview/ + exclude = [ "build/**", "docs/**", @@ -60,10 +63,10 @@ select = [ "C4", # flake8-comprehensions # TODO port pydocstyle # "D", # pydocstyle - "LOG", - "PERF", + "LOG", # flake8-logging + "PERF", # Perflint "PLE", - "COM812", + "COM812", # missing-trailing-comma ] ignore = [ @@ -74,6 +77,7 @@ ignore = [ # Yapf [tool.yapf] + # Align closing bracket with visual indentation. align_closing_bracket_with_visual_indent = false @@ -153,7 +157,7 @@ blank_line_before_nested_class_or_def = true coalesce_brackets = true # The column limit. -column_limit = 80 +column_limit = 120 # The style for continuation alignment. Possible values are: # diff --git a/setup.py b/setup.py index ec6e4d8..d10edc0 100644 --- a/setup.py +++ b/setup.py @@ -16,9 +16,7 @@ "--optimize=2", ] if device_capability: - nvcc_flags.append( - f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}", - ) + nvcc_flags.append(f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",) ext_modules = [ CUDAExtension( @@ -59,10 +57,7 @@ 'mosaicml>=0.22.0', ] -extra_deps['all'] = list({ - dep for key, deps in extra_deps.items() for dep in deps - if key not in {'testing'} -}) +extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) setup( name="megablocks", diff --git a/tests/conftest.py b/tests/conftest.py index 9fff00e..97bfa92 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -37,9 +37,7 @@ def _get_option( val = None if val is None: if default is None: - pytest.fail( - f'Config option {name} is not specified but is required', - ) + pytest.fail(f'Config option {name} is not specified but is required',) val = default assert isinstance(val, str) return val diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 3a3fab2..68f9406 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -50,8 +50,7 @@ def configure_dist(request: pytest.FixtureRequest): device = None for item in request.session.items: - device = DeviceCPU( - ) if item.get_closest_marker('gpu') is None else DeviceGPU() + device = DeviceCPU() if item.get_closest_marker('gpu') is None else DeviceGPU() break assert device is not None diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py index 9ab70fe..01a7e42 100644 --- a/tests/layers/dmoe_test.py +++ b/tests/layers/dmoe_test.py @@ -28,13 +28,10 @@ (16, 1024, 128, 1, 1), ) -_FORWARD_TESTS_GROUPED_MLP = tuple([ - p + ('grouped',) for p in _FORWARD_TESTS_DEFAULT -]) if gg.grouped_gemm_is_available() else () +_FORWARD_TESTS_GROUPED_MLP = tuple([p + ('grouped',) for p in _FORWARD_TESTS_DEFAULT + ],) if gg.grouped_gemm_is_available() else () -_FORWARD_TESTS_SPARSE_MLP = tuple([ - p + ('sparse',) for p in _FORWARD_TESTS_DEFAULT -]) +_FORWARD_TESTS_SPARSE_MLP = tuple([p + ('sparse',) for p in _FORWARD_TESTS_DEFAULT]) _FORWARD_TESTS = (_FORWARD_TESTS_SPARSE_MLP + _FORWARD_TESTS_GROUPED_MLP) @@ -80,9 +77,7 @@ def construct_moes( ne, hs, fhs = moe_mlp.experts.mlp.w1.size() w1 = dmoe_mlp.experts.mlp.w1.view([ne, fhs, hs]) moe_mlp.experts.mlp.w1.copy_(torch.transpose(w1, 1, 2).contiguous()) - moe_mlp.experts.mlp.w2.copy_( - dmoe_mlp.experts.mlp.w2.view([ne, fhs, hs]), - ) + moe_mlp.experts.mlp.w2.copy_(dmoe_mlp.experts.mlp.w2.view([ne, fhs, hs]),) moe_mlp.router.layer.weight.copy_(dmoe_mlp.router.layer.weight) if moe_num_experts == 1: mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze()) @@ -91,8 +86,7 @@ def construct_moes( @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) def test_dmoe_forward( bs: int, sl: int, @@ -116,8 +110,7 @@ def test_dmoe_forward( @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) def test_dmoe_forward_backward( bs: int, sl: int, @@ -173,8 +166,7 @@ def test_dmoe_forward_vs_baseline( @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) def test_dmoe_forward_vs_moe( bs: int, sl: int, diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index 8591a7a..ae5b5bf 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -60,8 +60,7 @@ def construct_moe( @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), - _FORWARD_TESTS) +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS) def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): x = torch.randn(sl, bs, hs).half().cuda() @@ -78,8 +77,7 @@ def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), - _FORWARD_TESTS) +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS) def test_moe_forward_backward( bs: int, sl: int, diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py index b7fce22..d32746a 100644 --- a/tests/layers/parallelism_test.py +++ b/tests/layers/parallelism_test.py @@ -81,8 +81,7 @@ def test_expert_parallel_versus_weight_parallel( # NOTE: Include the rank in the seed so we get different data per rank. rank = torch.distributed.get_rank(group) torch.manual_seed(1234 * rank) - x = torch.randn((batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), + x = torch.randn((batch_size, sequence_length, hidden_size), device=torch.cuda.current_device(), dtype=torch.float32).requires_grad_(True) # Test forward. diff --git a/tests/ops/histogram_test.py b/tests/ops/histogram_test.py index 5af55e7..2f98fb7 100644 --- a/tests/ops/histogram_test.py +++ b/tests/ops/histogram_test.py @@ -78,7 +78,5 @@ def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int): x = torch.randint(0, max_val, (m, n)).cuda().to(dtype) out = ops.histogram(x, max_val) - expected_out = torch.stack([ - torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1) - ]) + expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)]) assert torch.all(torch.eq(out, expected_out)) diff --git a/tests/ops/sort_test.py b/tests/ops/sort_test.py index f65ae16..3a527de 100644 --- a/tests/ops/sort_test.py +++ b/tests/ops/sort_test.py @@ -31,9 +31,7 @@ ] -def torch_to_numpy_dtype( - dtype: torch.dtype, -) -> Union[np.int16, np.int32, np.int64]: +def torch_to_numpy_dtype(dtype: torch.dtype,) -> Union[np.int16, np.int32, np.int64]: types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = { torch.int16: np.int16, torch.int32: np.int32, From 021052c11f5d666a802c424d4472b1f334b03162 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 03:17:44 +0000 Subject: [PATCH 08/43] isort --- .pre-commit-config.yaml | 4 ++++ pyproject.toml | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index d991b58..97c1da0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -18,3 +18,7 @@ repos: types: [python] additional_dependencies: - toml +- repo: https://github.com/pycqa/isort + hooks: + - id: isort + rev: 5.12.0 diff --git a/pyproject.toml b/pyproject.toml index f617559..e7610d7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,14 @@ ignore = [ "PERF4", ] +# iSort +[tool.isort] +multi_line_output = 0 +line_length = 120 +skip = ["env", "wandb", "runs", "build", "node_modules" ] +include_trailing_comma = true +split_on_trailing_comma = true + # Yapf [tool.yapf] From 37feb12aedd9cd85aea6aa5f510ea709b60c63d7 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 03:19:01 +0000 Subject: [PATCH 09/43] isort --- megablocks/__init__.py | 1 - megablocks/layers/activation_fn.py | 2 +- megablocks/layers/arguments.py | 6 ++++-- megablocks/layers/common.py | 3 ++- megablocks/layers/dmlp_registry.py | 4 ++-- megablocks/layers/dmoe.py | 10 ++++------ megablocks/layers/mpu.py | 3 ++- megablocks/layers/router.py | 3 ++- megablocks/layers/sharedexpert_registry.py | 4 ++-- megablocks/layers/testing.py | 3 ++- megablocks/layers/weight_parallel.py | 3 ++- megablocks/ops/binned_gather.py | 3 ++- megablocks/ops/binned_scatter.py | 3 ++- megablocks/ops/gather.py | 3 ++- megablocks/ops/histogram_benchmark.py | 5 +++-- megablocks/ops/padded_gather.py | 3 ++- megablocks/ops/padded_scatter.py | 3 ++- megablocks/ops/scatter.py | 3 ++- megablocks/ops/sort_benchmark.py | 5 +++-- tests/fixtures/fixtures.py | 1 + 20 files changed, 42 insertions(+), 29 deletions(-) diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 1e527a4..9fb35d3 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,5 +1,4 @@ from megablocks.layers import dmoe, moe -"""Key classes are available directly in the ``MegaBlocks`` namespace.""" __all__ = [ 'dmoe', diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 3038f44..b0d2b53 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,7 +1,7 @@ from typing import Callable -import torch import stk +import torch def act_fn( diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index fe67ad6..5714c34 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -1,9 +1,11 @@ import dataclasses from functools import partial -import megablocks.grouped_gemm_util as grouped_gemm +from typing import Any, Callable, Optional, Union + import torch import torch.nn.functional as F -from typing import Any, Callable, Optional, Union + +import megablocks.grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Callable[[torch.Tensor], None] diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py index b9b2ab1..ff7ffc3 100644 --- a/megablocks/layers/common.py +++ b/megablocks/layers/common.py @@ -1,6 +1,7 @@ -from megablocks.layers.arguments import Arguments import torch +from megablocks.layers.arguments import Arguments + def dtype(args: Arguments): if args.fp16: diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index a5b86d8..3d1a06e 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -1,6 +1,6 @@ from typing import Union -from megablocks.layers import mlp -from megablocks.layers import glu + +from megablocks.layers import glu, mlp from megablocks.layers.arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 5ae5ee8..310d5ce 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -1,13 +1,11 @@ -from megablocks.layers import common -from megablocks.layers import moe -from megablocks.layers import dmlp_registry -from megablocks.layers import mpu -from megablocks.layers.arguments import Arguments -import megablocks.ops as ops import numpy as np import stk import torch +import megablocks.ops as ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + def promote_scalar(x): return x.view(1) if not len(x.size()) else x diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 49b5f73..2b3d573 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,6 +1,7 @@ -from megablocks.layers.arguments import Arguments import torch +from megablocks.layers.arguments import Arguments + def is_moe_param(tensor: torch.Tensor) -> bool: return hasattr(tensor, 'expert_model_parallel') diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index 5039cf4..f05f607 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,6 +1,7 @@ +import torch + from megablocks.layers import common from megablocks.layers.arguments import Arguments -import torch # NOTE: To enable end-to-end benchmarking without convergence we diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 7396de8..3fd1840 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -1,6 +1,6 @@ from typing import Union -from megablocks.layers import mlp -from megablocks.layers import glu + +from megablocks.layers import glu, mlp from megablocks.layers.arguments import Arguments _REGISTRY = { diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index b81366b..e4c89ab 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -1,7 +1,8 @@ -from megablocks.layers.arguments import Arguments import torch import torch.nn.functional as F +from megablocks.layers.arguments import Arguments + def allclose(x, y, pct=0.5): mask = torch.isclose(x, y, rtol=5e-2) diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index 579832d..5724eda 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -1,7 +1,8 @@ -from megablocks.layers import gelu import stk import torch +from megablocks.layers import gelu + def _gather_weights(w, group, parallel_w=None, async_op=False): """Gather the weights across the process group. diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 94b6ea5..d1315f7 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,6 +1,7 @@ import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for binned_gather kernel. diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index 143ac87..ce3c92e 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,6 +1,7 @@ import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for binned_scatter kernel. diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index ec78d2c..efffed5 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,6 +1,7 @@ import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for gather kernel. diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index cc977ec..1e5a047 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -1,9 +1,10 @@ import unittest -from absl.testing import parameterized -from megablocks import ops import numpy as np import torch +from absl.testing import parameterized + +from megablocks import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 696629b..90bfbc7 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,6 +1,7 @@ import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for padded_gather kernel. diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 8780b33..04766e8 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,6 +1,7 @@ import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for padded_scatter kernel. diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index b4a8576..9f5725f 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,6 +1,7 @@ import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for scatter kernel. diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 47df184..d906fbb 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -1,9 +1,10 @@ import unittest -from absl.testing import parameterized -from megablocks import ops import numpy as np import torch +from absl.testing import parameterized + +from megablocks import ops _SORT_TESTS = ( (16384, torch.int32, None), diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index 48645a8..2f50624 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -1,4 +1,5 @@ import pytest + from tests.conftest import _get_option From 9eff68424ae50054132c3dec4383f88f9a85e36a Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 03:26:30 +0000 Subject: [PATCH 10/43] pycln --- .pre-commit-config.yaml | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 97c1da0..b9bf01c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,3 +22,8 @@ repos: hooks: - id: isort rev: 5.12.0 +- repo: https://github.com/hadialqattan/pycln + rev: v2.1.2 + hooks: + - id: pycln + args: [. --all] From 1e6b92745d480d231d6b4cab2ab72852fa0c79c3 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 03:31:11 +0000 Subject: [PATCH 11/43] pre-commit-hooks --- .pre-commit-config.yaml | 30 +++++ Dockerfile | 2 +- LICENSE | 2 +- MANIFEST.in | 2 +- megablocks/backend/kernels.py | 6 +- megablocks/benchmark_util.py | 14 +-- megablocks/grouped_gemm_util.py | 4 +- megablocks/layers/arguments.py | 2 +- megablocks/layers/dmlp_registry.py | 2 +- megablocks/layers/gelu.py | 2 +- megablocks/layers/glu.py | 6 +- megablocks/layers/memory_test.py | 12 +- megablocks/layers/mlp.py | 6 +- megablocks/layers/moe.py | 24 ++-- megablocks/layers/mpu.py | 10 +- megablocks/layers/testing.py | 6 +- megablocks/layers/weight_parallel.py | 2 +- megablocks/ops/all_to_all_benchmark.py | 6 +- megablocks/ops/histogram_benchmark.py | 24 ++-- megablocks/ops/matmul_benchmark.py | 126 ++++++++++----------- megablocks/ops/padded_scatter_benchmark.py | 10 +- megablocks/ops/permute_benchmark.py | 40 +++---- megablocks/ops/round_up.py | 2 +- megablocks/ops/sort_benchmark.py | 20 ++-- setup.py | 44 +++---- yamls/triton_benchmark.yaml | 2 +- 26 files changed, 218 insertions(+), 188 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b9bf01c..5789e73 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,3 +27,33 @@ repos: hooks: - id: pycln args: [. --all] +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-added-large-files + - id: check-ast + - id: check-builtin-literals + - id: check-case-conflict + - id: check-docstring-first + - id: check-executables-have-shebangs + - id: check-json + - id: check-shebang-scripts-are-executable + - id: pretty-format-json + args: + - --autofix + - --no-sort-keys + - --indent=4 + - --no-ensure-ascii + - id: check-merge-conflict + - id: check-symlinks + - id: check-toml + - id: check-vcs-permalinks + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: destroyed-symlinks + - id: double-quote-string-fixer + - id: end-of-file-fixer + - id: fix-byte-order-marker + - id: mixed-line-ending + - id: trailing-whitespace diff --git a/Dockerfile b/Dockerfile index c71ed0a..e5d9ef8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,4 +6,4 @@ RUN pip install flash-attn ENV PYTHONPATH="/mount/megablocks/third_party/Megatron-LM:${PYTHONPATH}" -WORKDIR /mount/megablocks \ No newline at end of file +WORKDIR /mount/megablocks diff --git a/LICENSE b/LICENSE index 63fd052..be2d25e 100644 --- a/LICENSE +++ b/LICENSE @@ -387,4 +387,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in index 99749aa..b701a75 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ recursive-include csrc *.h -recursive-include csrc *.cu \ No newline at end of file +recursive-include csrc *.cu diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index 3366f48..c193f9d 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -5,7 +5,7 @@ def assert_is_tensor(x, ndim): if x.ndim != ndim: - raise ValueError(f"Expected {ndim}-tensor but got {x.ndim}-tensor") + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') def assert_is_matrix(x): @@ -14,12 +14,12 @@ def assert_is_matrix(x): def assert_is_vector(x): if x.ndim != 1: - raise ValueError(f"Expected 1-tensor but got {x.ndim}-tensor") + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') def assert_equal(a, b): if a != b: - raise ValueError(f"Expected dimensions to be equal but got {a} and {b}.",) + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) # a: (tokens, hidden_size), real. diff --git a/megablocks/benchmark_util.py b/megablocks/benchmark_util.py index 2984ac0..9902a58 100644 --- a/megablocks/benchmark_util.py +++ b/megablocks/benchmark_util.py @@ -3,14 +3,14 @@ def log_benchmark(name, arguments, time, std): - print("=" * 60) - print(f"{name} Benchmark") - print("Benchmark Parameters:") + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std)) - print("=" * 60) + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) def benchmark_function(fn, iterations=100, warmup=10): diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index 36fd7f0..3e6772b 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -10,8 +10,8 @@ def grouped_gemm_is_available(): def assert_grouped_gemm_is_available(): assert grouped_gemm_is_available( - ), ("Grouped GEMM not available. Please run " - "`pip install git+https://github.com/tgale96/grouped_gemm@main`.") + ), ('Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.') backend = grouped_gemm.backend if grouped_gemm_is_available() else None diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 5714c34..1597b9f 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -12,7 +12,7 @@ _ALLOWED_BITWIDTHS = (-1, 4, 8) -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate="tanh") +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') @dataclasses.dataclass diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index 3d1a06e..62f5e18 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -21,7 +21,7 @@ def get(args: Arguments) -> MlpType: """Returns an MLP for use in a dMoE instance. Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs + MLP instance. This only contains MLPs for use in dMoEs (ie. only for the dropless versions of MoEs). Args: diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index 50a7c56..7dd69c7 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -30,7 +30,7 @@ def gelu(x: stk.Matrix): assert isinstance(x, stk.Matrix) return stk.Matrix( x.size(), - F.gelu(x.data, approximate="tanh"), + F.gelu(x.data, approximate='tanh'), x.row_indices, x.column_indices, x.offsets, diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index d321aeb..25cb39b 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -42,12 +42,12 @@ def __init__(self, args: Arguments): ) if self.args.moe_weight_parallelism: - raise NotImplementedError("Weight parallelism not yet supported with GLU.",) + raise NotImplementedError('Weight parallelism not yet supported with GLU.',) def forward(self, x, topo): if self.args.memory_optimized_mlp: raise NotImplementedError( - "Memory optimized implementation not yet supported with GLU with sparse kernels.", + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', ) w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) @@ -104,7 +104,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + raise ValueError('Expected all MLP inputs to need grad.') # Unpack saved tensors # dtype = ctx.dtype diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index 22bd3dc..3871fd2 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -53,8 +53,8 @@ def test_memory( # Report peak memory. mem = torch.cuda.max_memory_allocated() - print("Max Memory Allocated = {:0.0f}MiB".format(mem / 1e6)) - print("Max Memory Reserved = {:0.0f}MiB".format(torch.cuda.max_memory_reserved() / 1e6,),) + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) # Calculate weight and gradient memory usage. weight_memory = 2 * ( @@ -71,8 +71,8 @@ def grad_numel(x): ) weight_memory += grad_memory - print("Weight Memory Allocated = {:0.0f}MiB".format(weight_memory / 1e6)) - print("Activation Memory Allocated = {:0.0f}MiB".format((mem - weight_memory) / 1e6,),) + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) # Manually calculate GPU memory usage from the garbage # collector. @@ -82,10 +82,10 @@ def grad_numel(x): tensors = sorted(tensors, key=lambda x: -x.numel()) for i, t in enumerate(tensors): total += t.numel() - print(f"{i}: {t.shape}, {t.numel() * 2}") + print(f'{i}: {t.shape}, {t.numel() * 2}') del tensors - print("Total Bytes Found = {:0.0f}MiB".format(total * 2 / 1e6)) + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) if __name__ == '__main__': diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 024e405..3fe7437 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -243,7 +243,7 @@ def forward(ctx, x, w1, w2, topo, activation_fn): @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + raise ValueError('Expected all MLP inputs to need grad.') # unpack saved tensors # dtype = ctx.dtype @@ -467,7 +467,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + raise ValueError('Expected all MLP inputs to need grad.') # Unpack saved tensors # dtype = ctx.dtype @@ -540,7 +540,7 @@ def forward(self, x, tokens_per_expert): w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) if self.args.moe_weight_parallelism: - raise NotImplementedError("Weight parallelism not yet supported with GroupedMLP.",) + raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',) if self.args.memory_optimized_mlp: return memory_optimized_grouped_mlp( diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 52ab852..1429cb3 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -37,21 +37,21 @@ def batched_load_balancing_loss(args: Arguments): if len(tokens_per_expert) != num_layers_per_pipeline_stage: raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', ) if len(expert_scores) != num_layers_per_pipeline_stage: raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}", + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', ) # Verify the shape of the tokens_per_expert and expert_scores tensors. diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 2b3d573..58314ba 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -53,7 +53,7 @@ def synchronized_print(group, *x): for i in range(world_size): torch.distributed.barrier(group) if i == rank: - print(f"rank = {rank}", *x) + print(f'rank = {rank}', *x) # Helpers for expert/tensor sharding. @@ -62,7 +62,7 @@ def expert_sharding_degree(args: Arguments) -> int: esd = min(world_size, args.moe_num_experts) if (args.moe_num_experts % esd) != 0: - raise ValueError(f"Cannot shard {args.moe_num_experts} experts {esd} ways.",) + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) return esd @@ -72,12 +72,12 @@ def hidden_sharding_degree(args: Arguments) -> int: hsd = world_size // esd if (args.ffn_hidden_size % hsd) != 0: - raise ValueError(f"Cannot shard {args.ffn_hidden_size} features {hsd} ways.",) + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) if (esd * hsd) != world_size: raise ValueError( f"Invalid sharding. 'expert_sharding_degree' " - f"({esd}) * hidden_sharding_degree " - f"({hsd}) != world_size ({world_size}).", + f'({esd}) * hidden_sharding_degree ' + f'({hsd}) != world_size ({world_size}).', ) return hsd diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index e4c89ab..2fd0782 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -8,7 +8,7 @@ def allclose(x, y, pct=0.5): mask = torch.isclose(x, y, rtol=5e-2) pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) + print('{:.2f}% of values not close.'.format(pct_diff)) return False return True @@ -36,7 +36,7 @@ def __init__(self, args: Arguments): def forward(self, x): return torch.matmul( - F.gelu(torch.matmul(x, self.w1), approximate="tanh"), + F.gelu(torch.matmul(x, self.w1), approximate='tanh'), self.w2, ) @@ -55,5 +55,5 @@ def __init__(self, args: Arguments): ) def forward(self, x): - x1 = F.gelu(torch.matmul(x, self.w1), approximate="tanh") * torch.matmul(x, self.v1) + x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1) return torch.matmul(x1, self.w2) diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index 5724eda..d0b7b16 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -304,7 +304,7 @@ def backward(ctx, ddsd_out): sdd_out = stk.Matrix(ctx.shape, *ctx.saved_tensors[3:]) if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + raise ValueError('Expected all MLP inputs to need grad.') # Start the weight gather asynchronously to overlap with the # weight gradient computation and gelu recompute. diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index a26b8fb..5b247e5 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -33,8 +33,8 @@ def benchmark_all_to_all(group, sl, hs): x = torch.randn((sl, hs)).cuda().half() details = { - "world_size": world_size, - "message_size (B)": send_recv_sizes[0] * hs * 2, # 2B elements. + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. } def benchmark(): @@ -43,7 +43,7 @@ def benchmark(): time, std = benchmark_util.benchmark_function(benchmark) if torch.distributed.get_rank(group) == 0: - benchmark_util.log_benchmark("All-To-All", details, time, std) + benchmark_util.log_benchmark('All-To-All', details, time, std) if __name__ == '__main__': diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index 1e5a047..29ab76a 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -35,13 +35,13 @@ def benchmark_function(fn, iterations=10): def log_benchmark(arguments, mean_t, std_t): - print("=" * 60) - print("Benchmark Parameters:") + print('=' * 60) + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean / std = {:.2f}ms / {:.2f}ms".format(mean_t, std_t)) - print("=" * 60) + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) class HistogramBenchmark(parameterized.TestCase): @@ -52,9 +52,9 @@ def testHistogram(self, n, dtype, max_val): mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val, + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) @@ -64,9 +64,9 @@ def testTorchHistogram(self, n, dtype, max_val): mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val, + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index b039cee..45d9ade 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -28,9 +28,9 @@ def transpose_view(x): def log_benchmark(name, arguments, time, std, flops): benchmark_util.log_benchmark(name, arguments, time, std) - print("flops = {:.2f}B".format(flops / 1e9)) - print("throughput = {:.2f}T".format(flops / 1e9 / time)) - print("=" * 60) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) class MatmulBenchmark(parameterized.TestCase): @@ -103,13 +103,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "0::Fwd::SDD::NT", + '0::Fwd::SDD::NT', arguments, mean_t, std_t, @@ -127,13 +127,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "0::GradX::DSD::NN", + '0::GradX::DSD::NN', arguments, mean_t, std_t, @@ -151,13 +151,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "0::GradW::DSD::TN", + '0::GradW::DSD::TN', arguments, mean_t, std_t, @@ -175,13 +175,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "1::Fwd::DSD::NN", + '1::Fwd::DSD::NN', arguments, mean_t, std_t, @@ -201,13 +201,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "1::GradX::SDD::NT", + '1::GradX::SDD::NT', arguments, mean_t, std_t, @@ -227,13 +227,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "1::GradW::DSD::TN", + '1::GradW::DSD::TN', arguments, mean_t, std_t, @@ -254,13 +254,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "0::Fwd:DDD::NT", + '0::Fwd:DDD::NT', arguments, mean_t, std_t, @@ -280,13 +280,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "0:GradX:DDD::NN", + '0:GradX:DDD::NN', arguments, mean_t, std_t, @@ -306,13 +306,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "0:GradW:DDD::TN", + '0:GradW:DDD::TN', arguments, mean_t, std_t, @@ -330,13 +330,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "1::Fwd::DDD::NN", + '1::Fwd::DDD::NN', arguments, mean_t, std_t, @@ -356,13 +356,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "1::GradX::DDD::NT", + '1::GradX::DDD::NT', arguments, mean_t, std_t, @@ -382,13 +382,13 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } log_benchmark( - "1::GradW::DDD::TN", + '1::GradW::DDD::TN', arguments, mean_t, std_t, diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index 8580d38..5f814c8 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -47,12 +47,12 @@ def benchmark(): time, std = benchmark_util.benchmark_function(benchmark) benchmark_util.log_benchmark( - "Padded Scatter", + 'Padded Scatter', { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, - "top_k": top_k, + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, }, time, std, diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index 886781b..b97c82b 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -42,11 +42,11 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("BinnedGather", arguments, mean_t, std_t) + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testBinnedScatter(self, sl, hs, ne): @@ -66,11 +66,11 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("BinnedScatter", arguments, mean_t, std_t) + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testPaddedGather(self, sl, hs, ne): @@ -90,11 +90,11 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("PaddedGather", arguments, mean_t, std_t) + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testPaddedScatter(self, sl, hs, ne): @@ -115,11 +115,11 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("PaddedScatter", arguments, mean_t, std_t) + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testCopy(self, sl, hs, ne): @@ -135,11 +135,11 @@ def benchmark(): mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("Copy", arguments, mean_t, std_t) + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) if __name__ == '__main__': diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index fc81d61..ed872d3 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -8,4 +8,4 @@ def round_up(x, value): # TODO(tgale): If this becomes and issue # do this in a custom kernel. We only expect # to use this on arrays of less than 1k elements. - return torch.div(x + (value - 1), value, rounding_mode="trunc") * value + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index d906fbb..4a767d0 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -41,13 +41,13 @@ def benchmark_function(fn, iterations=10): def log_benchmark(arguments, mean_t, std_t): - print("=" * 60) - print("Benchmark Parameters:") + print('=' * 60) + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean / std = {:.2f}ms / {:.2f}ms".format(mean_t, std_t)) - print("=" * 60) + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) class SortBenchmark(parameterized.TestCase): @@ -61,9 +61,9 @@ def testSort(self, n, dtype, max_val): mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val, + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) @@ -73,7 +73,7 @@ def testTorchSort(self, n): mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) arguments = { - "n": n, + 'n': n, } log_benchmark(arguments, mean_t, std_t) diff --git a/setup.py b/setup.py index d10edc0..5ef72dd 100644 --- a/setup.py +++ b/setup.py @@ -4,28 +4,28 @@ from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension -if os.environ.get("TORCH_CUDA_ARCH_LIST"): +if os.environ.get('TORCH_CUDA_ARCH_LIST'): # Let PyTorch builder to choose device to target for. - device_capability = "" + device_capability = '' else: device_capability = torch.cuda.get_device_capability() - device_capability = f"{device_capability[0]}{device_capability[1]}" + device_capability = f'{device_capability[0]}{device_capability[1]}' nvcc_flags = [ - "--ptxas-options=-v", - "--optimize=2", + '--ptxas-options=-v', + '--optimize=2', ] if device_capability: - nvcc_flags.append(f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}",) + nvcc_flags.append(f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}',) ext_modules = [ CUDAExtension( - "megablocks_ops", - ["csrc/ops.cu"], - include_dirs=["csrc"], + 'megablocks_ops', + ['csrc/ops.cu'], + include_dirs=['csrc'], extra_compile_args={ - "cxx": ["-fopenmp"], - "nvcc": nvcc_flags, + 'cxx': ['-fopenmp'], + 'nvcc': nvcc_flags, }, ), ] @@ -40,7 +40,7 @@ extra_deps = {} -extra_deps["gg"] = [ +extra_deps['gg'] = [ 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb', ] @@ -60,22 +60,22 @@ extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) setup( - name="megablocks", - version="0.5.1", - author="Trevor Gale", - author_email="tgale@stanford.edu", - description="MegaBlocks", + name='megablocks', + version='0.5.1', + author='Trevor Gale', + author_email='tgale@stanford.edu', + description='MegaBlocks', long_description=open('README.md').read(), long_description_content_type='text/markdown', - url="https://github.com/stanford-futuredata/megablocks", + url='https://github.com/stanford-futuredata/megablocks', classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", + 'Programming Language :: Python :: 3', + 'License :: OSI Approved :: BSD License', + 'Operating System :: Unix', ], packages=find_packages(), ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension}, + cmdclass={'build_ext': BuildExtension}, install_requires=install_requires, extras_require=extra_deps, ) diff --git a/yamls/triton_benchmark.yaml b/yamls/triton_benchmark.yaml index 70c9626..a31cfb5 100644 --- a/yamls/triton_benchmark.yaml +++ b/yamls/triton_benchmark.yaml @@ -11,7 +11,7 @@ integrations: command: |- export ENABLE_TMA=1 export ENABLE_MMA_V3=1 - + cd triton/python pip install . --no-dependencies From f25369c6818211bb809a5b2d88c169a408744e09 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 03:35:05 +0000 Subject: [PATCH 12/43] add license --- .pre-commit-config.yaml | 11 +++++++++++ .pre-commit/FILE_HEADER | 2 ++ megablocks/__init__.py | 3 +++ megablocks/backend/__init__.py | 2 ++ megablocks/backend/kernels.py | 3 +++ megablocks/benchmark_util.py | 3 +++ megablocks/grouped_gemm_util.py | 3 +++ megablocks/layers/__init__.py | 3 +++ megablocks/layers/activation_fn.py | 3 +++ megablocks/layers/all_to_all.py | 3 +++ megablocks/layers/arguments.py | 3 +++ megablocks/layers/common.py | 3 +++ megablocks/layers/dmlp_registry.py | 3 +++ megablocks/layers/dmoe.py | 3 +++ megablocks/layers/gelu.py | 3 +++ megablocks/layers/glu.py | 3 +++ megablocks/layers/memory_test.py | 3 +++ megablocks/layers/mlp.py | 3 +++ megablocks/layers/moe.py | 3 +++ megablocks/layers/mpu.py | 3 +++ megablocks/layers/router.py | 3 +++ megablocks/layers/sharedexpert_registry.py | 3 +++ megablocks/layers/testing.py | 3 +++ megablocks/layers/weight_parallel.py | 3 +++ megablocks/ops/__init__.py | 3 +++ megablocks/ops/all_to_all_benchmark.py | 3 +++ megablocks/ops/binned_gather.py | 3 +++ megablocks/ops/binned_scatter.py | 3 +++ megablocks/ops/cumsum.py | 3 +++ megablocks/ops/gather.py | 3 +++ megablocks/ops/histogram.py | 3 +++ megablocks/ops/histogram_benchmark.py | 3 +++ megablocks/ops/matmul_benchmark.py | 3 +++ megablocks/ops/padded_gather.py | 3 +++ megablocks/ops/padded_scatter.py | 3 +++ megablocks/ops/padded_scatter_benchmark.py | 3 +++ megablocks/ops/permute_benchmark.py | 3 +++ megablocks/ops/repeat.py | 4 ++++ megablocks/ops/replicate.py | 3 +++ megablocks/ops/round_up.py | 3 +++ megablocks/ops/scatter.py | 3 +++ megablocks/ops/sort.py | 3 +++ megablocks/ops/sort_benchmark.py | 3 +++ megablocks/ops/sum.py | 4 ++++ megablocks/ops/topology.py | 3 +++ setup.py | 3 +++ tests/conftest.py | 3 +++ tests/fixtures/autouse.py | 3 +++ tests/fixtures/fixtures.py | 3 +++ tests/layers/glu_test.py | 3 +++ tests/layers/moe_test.py | 3 +++ tests/layers/parallelism_test.py | 3 +++ tests/ops/binned_scatter_test.py | 3 +++ 53 files changed, 167 insertions(+) create mode 100644 .pre-commit/FILE_HEADER diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 5789e73..22aca8c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -57,3 +57,14 @@ repos: - id: fix-byte-order-marker - id: mixed-line-ending - id: trailing-whitespace +- repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.5.4 + hooks: + - id: insert-license + args: + - --license-filepath + - .pre-commit/FILE_HEADER + - --comment-style + - "#" + - --allow-past-years + types: [python] diff --git a/.pre-commit/FILE_HEADER b/.pre-commit/FILE_HEADER new file mode 100644 index 0000000..69d0cd5 --- /dev/null +++ b/.pre-commit/FILE_HEADER @@ -0,0 +1,2 @@ +Copyright 2024 MosaicML MegaBlocks authors +SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 9fb35d3..3cdf43d 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from megablocks.layers import dmoe, moe __all__ = [ diff --git a/megablocks/backend/__init__.py b/megablocks/backend/__init__.py index e69de29..1d3c2fd 100644 --- a/megablocks/backend/__init__.py +++ b/megablocks/backend/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index c193f9d..81ea6a0 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch import triton import triton.language as tl diff --git a/megablocks/benchmark_util.py b/megablocks/benchmark_util.py index 9902a58..fe7c998 100644 --- a/megablocks/benchmark_util.py +++ b/megablocks/benchmark_util.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import torch diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index 3e6772b..899bf60 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + try: import grouped_gemm except ImportError: diff --git a/megablocks/layers/__init__.py b/megablocks/layers/__init__.py index da9bec7..de11bcb 100644 --- a/megablocks/layers/__init__.py +++ b/megablocks/layers/__init__.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from megablocks.layers.dmoe import dMoE from megablocks.layers.moe import MoE diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index b0d2b53..78d6d00 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from typing import Callable import stk diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 07cc584..51f4b9e 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 1597b9f..f75bcc0 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import dataclasses from functools import partial from typing import Any, Callable, Optional, Union diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py index ff7ffc3..27e7473 100644 --- a/megablocks/layers/common.py +++ b/megablocks/layers/common.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from megablocks.layers.arguments import Arguments diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index 62f5e18..fe6c2a1 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from typing import Union from megablocks.layers import glu, mlp diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 310d5ce..940477c 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import stk import torch diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index 7dd69c7..92a0741 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import stk import torch import torch.nn.functional as F diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 25cb39b..badaf51 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import stk import torch diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index 3871fd2..78e1fa1 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import gc import torch diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 3fe7437..b27ad5d 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from typing import Any import stk diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 1429cb3..d6687ed 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import torch diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 58314ba..b80ecd8 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from megablocks.layers.arguments import Arguments diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index f05f607..0b5b670 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from megablocks.layers import common diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 3fd1840..24317c4 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from typing import Union from megablocks.layers import glu, mlp diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index 2fd0782..12363cb 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch import torch.nn.functional as F diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index d0b7b16..37935b1 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import stk import torch diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index 222784e..d582191 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from megablocks.ops.binned_gather import binned_gather from megablocks.ops.binned_scatter import binned_scatter from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index 5b247e5..7cfa957 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from megablocks import benchmark_util diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index d1315f7..800daed 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index ce3c92e..eef6130 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 87b1298..fd5b3f2 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index efffed5..19033cc 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index f77a8e1..5b28c2e 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index 29ab76a..938079a 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import unittest import numpy as np diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 45d9ade..8577095 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import unittest import stk diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 90bfbc7..65700ba 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 04766e8..e3ed44a 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index 5f814c8..bbf82bd 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import unittest import torch diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index b97c82b..a16951d 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import unittest import torch diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index db995ff..bff1e4b 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,3 +1,7 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + + def repeat(x, tiling): if all((t == 1 for t in tiling)): return x diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index 74d9324..1130622 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index ed872d3..706ed07 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 9f5725f..67449cd 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index ce22783..b0a2bf4 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 4a767d0..ba52917 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import unittest import numpy as np diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index 76797da..88874a0 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,3 +1,7 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + + def sum(x, dim=0): if x.shape[dim] == 1: return x.squeeze(dim=dim) diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index fb43daa..ce2cd90 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/setup.py b/setup.py index 5ef72dd..d1a7133 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import os import torch diff --git a/tests/conftest.py b/tests/conftest.py index 97bfa92..dd0ebc0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import os from typing import List, Optional diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 68f9406..34bf17f 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import gc import logging import os diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index 2f50624..c167fb6 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import pytest from tests.conftest import _get_option diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index 93d7c0a..ab4207c 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from functools import partial import pytest diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index ae5b5bf..9f36641 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + from functools import partial import pytest diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py index d32746a..d3c5586 100644 --- a/tests/layers/parallelism_test.py +++ b/tests/layers/parallelism_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import functools import numpy as np diff --git a/tests/ops/binned_scatter_test.py b/tests/ops/binned_scatter_test.py index d13f89c..50c309f 100644 --- a/tests/ops/binned_scatter_test.py +++ b/tests/ops/binned_scatter_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import pytest import torch From e89975a293fd45156c75e7bec352564e0bc07647 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 03:36:41 +0000 Subject: [PATCH 13/43] docformatter --- .pre-commit-config.yaml | 5 +++++ megablocks/layers/dmlp_registry.py | 1 - megablocks/layers/glu.py | 2 +- megablocks/layers/mlp.py | 2 +- megablocks/layers/sharedexpert_registry.py | 1 - tests/fixtures/autouse.py | 2 +- 6 files changed, 8 insertions(+), 5 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 22aca8c..0c8ede1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -68,3 +68,8 @@ repos: - "#" - --allow-past-years types: [python] +- repo: https://github.com/PyCQA/docformatter + rev: v1.5.0 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index fe6c2a1..ee51d3e 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -32,7 +32,6 @@ def get(args: Arguments) -> MlpType: Returns: An instantiated MLP constructed using the input args. - """ if args.mlp_type not in _REGISTRY: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index badaf51..fa9f07d 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -206,7 +206,7 @@ def forward(self, x, tokens_per_expert): class SharedGLU(SharedMLP): - """GPU for shared expert + """GPU for shared expert. Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class """ diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index b27ad5d..f98b440 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -561,7 +561,7 @@ def forward(self, x, tokens_per_expert): class SharedMLP(torch.nn.Module): - """MLP for shared expert + """MLP for shared expert. Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class """ diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 24317c4..98c9554 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -23,7 +23,6 @@ def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: Returns: An instantiated SharedMLP constructed using the input args. - """ if args.mlp_type not in _REGISTRY: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 34bf17f..1e71e15 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -23,7 +23,7 @@ def clear_cuda_cache(request: pytest.FixtureRequest): @pytest.fixture(autouse=True) def reset_mlflow_tracking_dir(): - """Reset MLFlow tracking dir so it doesn't persist across tests""" + """Reset MLFlow tracking dir so it doesn't persist across tests.""" try: import mlflow mlflow.set_tracking_uri(None) # type: ignore From 90c12a93022809a97eb565c3d6e803b9ddee1db8 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 05:39:05 +0000 Subject: [PATCH 14/43] pydocstyle --- .pre-commit-config.yaml | 14 ++++++++++++++ pyproject.toml | 25 +++++++++++++++++++++++++ 2 files changed, 39 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0c8ede1..442415d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + default_language_version: python: python3 repos: @@ -68,6 +71,17 @@ repos: - "#" - --allow-past-years types: [python] +- repo: https://github.com/PyCQA/pydocstyle + rev: 6.1.1 + hooks: + - id: pydocstyle + name: pydocstyle + entry: pydocstyle + language: python + types: [python] + exclude: (.ci|.github) + additional_dependencies: + - toml - repo: https://github.com/PyCQA/docformatter rev: v1.5.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index e7610d7..8682326 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # build requirements [build-system] requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4"] @@ -69,12 +72,28 @@ select = [ "COM812", # missing-trailing-comma ] +extend-select = ["D404"] # pydocstyle + ignore = [ "C408", "PERF2", "PERF4", + "D100", + "D101", + "D102", + "D103", + "D104", + "D105", + "D107", + "D400", + "D401", + "D415", ] +[tool.ruff.lint.pydocstyle] +convention = "google" + + # iSort [tool.isort] multi_line_output = 0 @@ -486,3 +505,9 @@ ignore_patterns = [ "wandb/**/*.py", "build/**/*.py", ] + +# PyDocStyle +[tool.pydocstyle] +convention="google" +add_ignore="D100,D101,D102,D103,D104,D105,D107,D400,D401,D415" +add_select="D404" From 121b7b5047e4bce72373ff743df04819cf048b84 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 05:39:17 +0000 Subject: [PATCH 15/43] more pydocstyle --- pyproject.toml | 22 ---------------------- tests/fixtures/autouse.py | 7 +++++-- 2 files changed, 5 insertions(+), 24 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8682326..13217d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,28 +72,6 @@ select = [ "COM812", # missing-trailing-comma ] -extend-select = ["D404"] # pydocstyle - -ignore = [ - "C408", - "PERF2", - "PERF4", - "D100", - "D101", - "D102", - "D103", - "D104", - "D105", - "D107", - "D400", - "D401", - "D415", -] - -[tool.ruff.lint.pydocstyle] -convention = "google" - - # iSort [tool.isort] multi_line_output = 0 diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 1e71e15..738a8eb 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -75,8 +75,11 @@ def set_log_levels(): @pytest.fixture(autouse=True) def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch): - """Monkeypatch reproducibility get_random_seed to always return the rank zero seed, and set the random seed before - each test to the rank local seed.""" + """Monkeypatch reproducibility. + + Make get_random_seed to always return the rank zero seed, and set the random seed before each test to the rank local + seed. + """ monkeypatch.setattr( reproducibility, 'get_random_seed', From f292bb68f956a39e2d724e0b1c726c3abe25bd3c Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 05:50:07 +0000 Subject: [PATCH 16/43] yamllint --- .github/workflows/pr-gpu.yaml | 2 +- .pre-commit-config.yaml | 9 +++++++++ yamls/matmul_benchmark.yaml | 10 +++++----- yamls/triton_benchmark.yaml | 8 ++++---- 4 files changed, 19 insertions(+), 10 deletions(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 0221752..cbdb407 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -38,7 +38,7 @@ jobs: gpu_num: ${{ matrix.gpu_num }} git_repo: databricks/megablocks pip_deps: "[all,testing]" - pytest_command: "coverage run -m pytest tests" # todo: remove tests from pytest tests when we delete all tests outside of MegaBlocks repo + pytest_command: "coverage run -m pytest tests" # todo: remove `tests` when delete tests outside megablocks dir pytest_markers: "gpu" composer_package_name: mosaicml # Required as Composer is built from source mcloud_timeout: 3600 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 442415d..39b3ca7 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -87,3 +87,12 @@ repos: hooks: - id: docformatter args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] +- repo: https://github.com/adrienverge/yamllint.git + rev: v1.28.0 + hooks: + - id: yamllint + name: yamllint + description: This hook runs yamllint. + entry: yamllint + language: python + types: [file, yaml] diff --git a/yamls/matmul_benchmark.yaml b/yamls/matmul_benchmark.yaml index de26f58..46a79d9 100644 --- a/yamls/matmul_benchmark.yaml +++ b/yamls/matmul_benchmark.yaml @@ -4,11 +4,11 @@ cluster: r9z1 gpu_num: 8 gpu_type: h100_80gb integrations: - - integration_type: git_repo - git_repo: stanford-futuredata/megablocks - git_branch: main - pip_install: absl-py 'git+https://github.com/openai/triton.git@main#egg=triton&subdirectory=python' - ssh_clone: false +- integration_type: git_repo + git_repo: stanford-futuredata/megablocks + git_branch: main + pip_install: absl-py 'git+https://github.com/openai/triton.git@main#egg=triton&subdirectory=python' + ssh_clone: false command: |- cd megablocks export ENABLE_TMA=1 diff --git a/yamls/triton_benchmark.yaml b/yamls/triton_benchmark.yaml index a31cfb5..fd30946 100644 --- a/yamls/triton_benchmark.yaml +++ b/yamls/triton_benchmark.yaml @@ -4,10 +4,10 @@ cluster: r9z1 gpu_num: 8 gpu_type: h100_80gb integrations: - - integration_type: git_repo - git_repo: openai/triton - git_branch: main - ssh_clone: false +- integration_type: git_repo + git_repo: openai/triton + git_branch: main + ssh_clone: false command: |- export ENABLE_TMA=1 export ENABLE_MMA_V3=1 From 9d4b5f3c408bc480ba56d6a43e15b369f4d8f43c Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 05:52:11 +0000 Subject: [PATCH 17/43] trufflehog --- .pre-commit-config.yaml | 10 ++++++++++ .yamllint.yaml | 42 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 52 insertions(+) create mode 100644 .yamllint.yaml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 39b3ca7..59fd0ba 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -96,3 +96,13 @@ repos: entry: yamllint language: python types: [file, yaml] +- repo: https://github.com/trufflesecurity/trufflehog.git + rev: v3.40.0 + hooks: + - id: trufflehog + name: secret scan + exclude: tests/horrible_strings.py + entry: trufflehog filesystem ./ + args: + - --only-verified + - --fail diff --git a/.yamllint.yaml b/.yamllint.yaml new file mode 100644 index 0000000..84a08ef --- /dev/null +++ b/.yamllint.yaml @@ -0,0 +1,42 @@ +yaml-files: +- "*.yaml" +- "*.yml" +- .yamllint + +ignore: | + wandb + +rules: + braces: + forbid: non-empty + brackets: + forbid: false + colons: enable + commas: enable + comments: enable + comments-indentation: enable + document-end: + present: false + document-start: + present: false + empty-lines: enable + empty-values: disable + hyphens: enable + indentation: + spaces: 2 + indent-sequences: false + check-multi-line-strings: false + key-duplicates: enable + key-ordering: disable + line-length: + max: 120 + allow-non-breakable-words: true + allow-non-breakable-inline-mappings: true + new-line-at-end-of-file: enable + new-lines: enable + octal-values: enable + quoted-strings: + quote-type: double + required: false + trailing-spaces: enable + truthy: disable From da015b538c63bc6feecc33f827119e85a90cca89 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 06:00:54 +0000 Subject: [PATCH 18/43] pyright init --- .pre-commit-config.yaml | 28 +++++++++++++++++++--------- pyproject.toml | 38 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 9 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 59fd0ba..e08799c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,6 +4,16 @@ default_language_version: python: python3 repos: +- repo: local + hooks: + - id: pyright + name: pyright + entry: pyright + language: node + types: [python] + pass_filenames: false + args: [--warnings] + additional_dependencies: ["pyright@1.1.256"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.2.2 hooks: @@ -21,15 +31,15 @@ repos: types: [python] additional_dependencies: - toml -- repo: https://github.com/pycqa/isort - hooks: - - id: isort - rev: 5.12.0 - repo: https://github.com/hadialqattan/pycln rev: v2.1.2 hooks: - id: pycln args: [. --all] +- repo: https://github.com/pycqa/isort + hooks: + - id: isort + rev: 5.12.0 - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 hooks: @@ -71,6 +81,11 @@ repos: - "#" - --allow-past-years types: [python] +- repo: https://github.com/PyCQA/docformatter + rev: v1.5.0 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] - repo: https://github.com/PyCQA/pydocstyle rev: 6.1.1 hooks: @@ -82,11 +97,6 @@ repos: exclude: (.ci|.github) additional_dependencies: - toml -- repo: https://github.com/PyCQA/docformatter - rev: v1.5.0 - hooks: - - id: docformatter - args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] - repo: https://github.com/adrienverge/yamllint.git rev: v1.28.0 hooks: diff --git a/pyproject.toml b/pyproject.toml index 13217d1..c72dbdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -489,3 +489,41 @@ ignore_patterns = [ convention="google" add_ignore="D100,D101,D102,D103,D104,D105,D107,D400,D401,D415" add_select="D404" + + +# Pyright +[tool.pyright] +exclude = ['env-**', 'venv*', '.venv', 'tests/*', '**benchmark'] +stubPath = "" # suppress useless 'stubPath is not a valid directory' errors + +reportUnnecessaryIsInstance = "none" # it is ok to do this for clarity or safety +reportMissingTypeStubs = "none" +reportIncompatibleMethodOverride = "none" +reportIncompatibleVariableOverride = "error" +reportUnusedImport = "error" +reportUnusedClass = "warning" +reportUnusedFunction = "warning" +reportUnusedVariable = "error" +reportDuplicateImport = "error" +reportWildcardImportFromLibrary = "error" +reportUntypedFunctionDecorator = "warning" +reportPrivateImportUsage = "none" +reportUndefinedVariable = "error" +strictParameterNoneValue = true +reportPropertyTypeMismatch = "error" +reportUntypedNamedTuple = "error" +reportUnnecessaryCast = "error" +reportInvalidTypeVarUse = "error" +reportOverlappingOverload = "error" +reportUninitializedInstanceVariable = "error" +reportInvalidStringEscapeSequence = "error" +reportMissingParameterType = "error" +reportCallInDefaultInitializer = "error" +reportUnnecessaryComparison = "error" +reportSelfClsParameterName = "error" +reportImplicitStringConcatenation = "warning" # TODO: make this an error +reportInvalidStubStatement = "error" +reportIncompleteStub = "error" +reportUnsupportedDunderAll = "error" +reportUnusedCoroutine = "error" +reportMissingImports = "none" From de1170030b30303ed312997916d6e2d38635757e Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 19:14:59 +0000 Subject: [PATCH 19/43] change runner --- .github/workflows/code-quality.yaml | 2 +- .github/workflows/pr-gpu.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index 062aa41..23441ee 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -19,7 +19,7 @@ defaults: working-directory: . jobs: code-quality: - runs-on: linux-ubuntu-latest + runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later timeout-minutes: 30 strategy: matrix: diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 2c3f187..2efb63b 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -15,7 +15,7 @@ concurrency: jobs: pytest-gpu: name: ${{ matrix.name }} - runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later + runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later strategy: fail-fast: false matrix: From d1a33959cf49600641b7164ba69ee12f48f6ec22 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 19:26:25 +0000 Subject: [PATCH 20/43] format --- megablocks/_version.py | 2 +- setup.py | 28 +++++++++++----------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/megablocks/_version.py b/megablocks/_version.py index 2bb5d50..a9ac8bc 100644 --- a/megablocks/_version.py +++ b/megablocks/_version.py @@ -1,4 +1,4 @@ -# Copyright 2022 MegaBlocks Composer authors +# Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 """The MegaBlocks Version.""" diff --git a/setup.py b/setup.py index b03a82d..6a8ad77 100644 --- a/setup.py +++ b/setup.py @@ -9,18 +9,17 @@ from setuptools import find_packages, setup - # We require torch in setup.py to build cpp extensions "ahead of time" # More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html try: import torch - from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, - CUDAExtension,) + from torch.utils.cpp_extension import ( + CUDA_HOME, + BuildExtension, + CUDAExtension, + ) except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "No module named 'torch'. `torch` is required to install `MegaBlocks`." - ) from e - + raise ModuleNotFoundError("No module named 'torch'. `torch` is required to install `MegaBlocks`.",) from e _PACKAGE_NAME = 'megablocks' _PACKAGE_DIR = 'megablocks' @@ -37,7 +36,6 @@ exec(content, version_globals, version_locals) repo_version = version_locals['__version__'] - with open('README.md', 'r', encoding='utf-8') as fh: long_description = fh.read() @@ -56,7 +54,6 @@ long_description = long_description[:start] + \ long_description[end + len(end_tag):] - classifiers = [ 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.9', @@ -95,7 +92,6 @@ extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) - cmdclass = {} ext_modules = [] @@ -113,9 +109,7 @@ device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}' if device_capability: - nvcc_flags.append( - f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}' - ) + nvcc_flags.append(f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}',) ext_modules = [ CUDAExtension( @@ -124,19 +118,19 @@ include_dirs=['csrc'], extra_compile_args={ 'cxx': ['-fopenmp'], - 'nvcc': nvcc_flags + 'nvcc': nvcc_flags, }, - ) + ), ] elif CUDA_HOME is None: warnings.warn( 'Attempted to install CUDA extensions, but CUDA_HOME was None. ' + 'Please install CUDA and ensure that the CUDA_HOME environment ' + - 'variable points to the installation location.') + 'variable points to the installation location.', + ) else: warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.') - setup( name=_PACKAGE_NAME, version=repo_version, From 7c7c9f21eb251c7d9454dd527cf82d502b87e69a Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 19:27:04 +0000 Subject: [PATCH 21/43] remove pyright --- .pre-commit-config.yaml | 10 ---------- 1 file changed, 10 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e08799c..1d315f5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -4,16 +4,6 @@ default_language_version: python: python3 repos: -- repo: local - hooks: - - id: pyright - name: pyright - entry: pyright - language: node - types: [python] - pass_filenames: false - args: [--warnings] - additional_dependencies: ["pyright@1.1.256"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.2.2 hooks: From fb72ccdba77b8f55a5b65cbff0da48784e1aaadb Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 2 Aug 2024 14:47:52 +0000 Subject: [PATCH 22/43] make all executable --- docker.sh | 0 exp/dmoe/dmoe_125m_8gpu.sh | 0 exp/dmoe/dmoe_356m_8gpu.sh | 0 exp/dmoe/dmoe_46m_8gpu.sh | 0 exp/dmoe/dmoe_760m_8gpu.sh | 0 exp/gpt2/gpt2_125m_1gpu.sh | 0 exp/gpt2/gpt2_125m_8gpu.sh | 0 exp/gpt2/gpt2_1315m_1gpu.sh | 0 exp/gpt2/gpt2_1315m_8gpu.sh | 0 exp/gpt2/gpt2_356m_1gpu.sh | 0 exp/gpt2/gpt2_356m_8gpu.sh | 0 exp/gpt2/gpt2_46m_8gpu.sh | 0 exp/gpt2/gpt2_760m_1gpu.sh | 0 exp/gpt2/gpt2_760m_8gpu.sh | 0 exp/moe/moe_125m_8gpu.sh | 0 exp/moe/moe_356m_8gpu.sh | 0 exp/moe/moe_46m_8gpu.sh | 0 megablocks/layers/memory_test.sh | 0 megablocks/ops/all_to_all_benchmark.sh | 0 19 files changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 docker.sh mode change 100644 => 100755 exp/dmoe/dmoe_125m_8gpu.sh mode change 100644 => 100755 exp/dmoe/dmoe_356m_8gpu.sh mode change 100644 => 100755 exp/dmoe/dmoe_46m_8gpu.sh mode change 100644 => 100755 exp/dmoe/dmoe_760m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_125m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_125m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_1315m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_1315m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_356m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_356m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_46m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_760m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_760m_8gpu.sh mode change 100644 => 100755 exp/moe/moe_125m_8gpu.sh mode change 100644 => 100755 exp/moe/moe_356m_8gpu.sh mode change 100644 => 100755 exp/moe/moe_46m_8gpu.sh mode change 100644 => 100755 megablocks/layers/memory_test.sh mode change 100644 => 100755 megablocks/ops/all_to_all_benchmark.sh diff --git a/docker.sh b/docker.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_125m_8gpu.sh b/exp/dmoe/dmoe_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_356m_8gpu.sh b/exp/dmoe/dmoe_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_46m_8gpu.sh b/exp/dmoe/dmoe_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_760m_8gpu.sh b/exp/dmoe/dmoe_760m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_125m_1gpu.sh b/exp/gpt2/gpt2_125m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_125m_8gpu.sh b/exp/gpt2/gpt2_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_1315m_1gpu.sh b/exp/gpt2/gpt2_1315m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_1315m_8gpu.sh b/exp/gpt2/gpt2_1315m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_356m_1gpu.sh b/exp/gpt2/gpt2_356m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_356m_8gpu.sh b/exp/gpt2/gpt2_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_46m_8gpu.sh b/exp/gpt2/gpt2_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_760m_1gpu.sh b/exp/gpt2/gpt2_760m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_760m_8gpu.sh b/exp/gpt2/gpt2_760m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_125m_8gpu.sh b/exp/moe/moe_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_356m_8gpu.sh b/exp/moe/moe_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_46m_8gpu.sh b/exp/moe/moe_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/megablocks/layers/memory_test.sh b/megablocks/layers/memory_test.sh old mode 100644 new mode 100755 diff --git a/megablocks/ops/all_to_all_benchmark.sh b/megablocks/ops/all_to_all_benchmark.sh old mode 100644 new mode 100755 From 11b9e834bafa17ce890a206d57b38c84634a67ab Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 2 Aug 2024 14:56:34 +0000 Subject: [PATCH 23/43] chmod --- exp/gpt2/gpt2_46m_1gpu.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 exp/gpt2/gpt2_46m_1gpu.sh diff --git a/exp/gpt2/gpt2_46m_1gpu.sh b/exp/gpt2/gpt2_46m_1gpu.sh old mode 100644 new mode 100755 From 6fed8ab319ff259e55f554eeeb80d6f8d6aaa263 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 2 Aug 2024 15:07:09 +0000 Subject: [PATCH 24/43] fix format --- megablocks/ops/cumsum.py | 11 ++++++----- megablocks/ops/histogram.py | 11 ++++++----- megablocks/ops/replicate.py | 10 ++++++---- megablocks/ops/sort.py | 11 ++++++----- megablocks/ops/topology.py | 11 ++++++----- 5 files changed, 30 insertions(+), 24 deletions(-) diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index fd5b3f2..86619d3 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -5,14 +5,15 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e # Autograd wrappers for cumsum kernels. -# # NOTE: Does not support gradients. class ExclusiveCumsumOp(torch.autograd.Function): diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index 5b28c2e..8fd4e05 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -5,14 +5,15 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e # Autograd wrapper for histogram kernel. -# # NOTE: Does not support gradients. class HistogramOp(torch.autograd.Function): diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index 1130622..b261e2e 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -5,10 +5,12 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e # Autograd wrapper for replicate kernel. diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index b0a2bf4..7370119 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -5,10 +5,12 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e _BITS_FOR_DTYPE = { torch.int16: 16, @@ -18,7 +20,6 @@ # Autograd wrapper for sort kernel. -# # NOTE: Does not support gradients. class SortOp(torch.autograd.Function): diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index ce2cd90..502f7cd 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -5,14 +5,15 @@ # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e # Autograd wrapper for topology kernel. -# # NOTE: Does not support gradients. class TopologyOp(torch.autograd.Function): From bf6b2bdc3b8e059dda9d8f312a2f010f50f640ac Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 2 Aug 2024 16:25:11 +0000 Subject: [PATCH 25/43] test against python 3.11 --- .github/workflows/code-quality.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index 23441ee..db33036 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -26,6 +26,7 @@ jobs: python_version: - "3.9" - "3.10" + - "3.11" pip_deps: - "[dev]" steps: From ab2793865297d8b9d65cd3471cf4a51fdfda5b55 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 2 Aug 2024 16:28:31 +0000 Subject: [PATCH 26/43] todo -> TODO --- .github/workflows/code-quality.yaml | 2 +- .github/workflows/pr-gpu.yaml | 4 ++-- setup.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index db33036..34dc22c 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -19,7 +19,7 @@ defaults: working-directory: . jobs: code-quality: - runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later + runs-on: ubuntu-latest # TODO: switch to linux-ubuntu-latest later timeout-minutes: 30 strategy: matrix: diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 2efb63b..93e7e6f 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -15,7 +15,7 @@ concurrency: jobs: pytest-gpu: name: ${{ matrix.name }} - runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later + runs-on: ubuntu-latest # TODO: switch to linux-ubuntu-latest later strategy: fail-fast: false matrix: @@ -38,7 +38,7 @@ jobs: gpu_num: ${{ matrix.gpu_num }} git_repo: databricks/megablocks pip_deps: "[all,testing]" - pytest_command: "coverage run -m pytest tests" # todo: remove `tests` when delete tests outside megablocks dir + pytest_command: "coverage run -m pytest tests" # TODO: remove `tests` when delete tests outside megablocks dir pytest_markers: "gpu" composer_package_name: mosaicml # Required as Composer is built from source mcloud_timeout: 3600 diff --git a/setup.py b/setup.py index 6a8ad77..e1ee2e2 100644 --- a/setup.py +++ b/setup.py @@ -78,7 +78,7 @@ ] extra_deps['dev'] = [ - 'absl-py', # todo: delete when finish removing all absl tests + 'absl-py', # TODO: delete when finish removing all absl tests 'coverage[toml]==7.4.4', 'pytest_codeblocks>=0.16.1,<0.17', 'pytest-cov>=4,<5', From 7ba7aaafd6f792da3b5c15984f44b275e217489b Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 2 Aug 2024 16:34:10 +0000 Subject: [PATCH 27/43] use v0.1.0 of ci-testing --- .github/workflows/code-quality.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index 34dc22c..813e270 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -35,7 +35,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.0.9 + ref: v0.1.0 path: ./ci-testing - uses: ./ci-testing/.github/actions/code-quality with: From 40f1a97c6e3b58a95d6d838623b3ff72d4340a74 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Tue, 6 Aug 2024 16:21:02 -0400 Subject: [PATCH 28/43] fix yaml lint --- .github/workflows/pr-gpu.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index f0fa90a..eca1273 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -16,7 +16,7 @@ jobs: pytest-gpu: name: ${{ matrix.name }} if: github.repository_owner == 'databricks' - runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later + runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later strategy: fail-fast: false matrix: From 63d8b1800b6c39f1c3f2fb7087176db2e3a3df5d Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 7 Aug 2024 12:18:18 -0400 Subject: [PATCH 29/43] update ci-testing version --- .github/workflows/code-quality.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index 813e270..bf95f57 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -35,7 +35,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.1.0 + ref: v0.1.1 path: ./ci-testing - uses: ./ci-testing/.github/actions/code-quality with: From 87daf42bdbb03db7e08a80c230e021bf60434399 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 02:02:36 +0000 Subject: [PATCH 30/43] fix comment --- .github/workflows/pr-gpu.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index eca1273..4df2147 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -39,7 +39,7 @@ jobs: gpu_num: ${{ matrix.gpu_num }} git_repo: databricks/megablocks pip_deps: "[all,testing]" - pytest_command: "coverage run -m pytest tests" # TODO: remove `tests` when delete tests outside megablocks dir + pytest_command: "coverage run -m pytest tests" # TODO delete `tests` after removing all tests in megablocks dir pytest_markers: "gpu" composer_package_name: mosaicml # Required as Composer is built from source mcloud_timeout: 3600 From cffc463849383aefeba757b3316373a21db19e1f Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 12:29:38 +0000 Subject: [PATCH 31/43] add contributing and style guides --- CONTRIBUTING.md | 104 +++++++++++ STYLE_GUIDE.md | 480 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 584 insertions(+) create mode 100644 CONTRIBUTING.md create mode 100644 STYLE_GUIDE.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..abbe03d --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,104 @@ +# Contributing to MegaBlocks + +Thanks for considering contributing to MegaBlocks! + +Issues tagged with [good first issue](https://github.com/mosaicml/megablocks/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) are great options to start contributing. + +If you have questions, join us on [Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) -- we'll be happy to help you! + +We welcome contributions for bug fixes, new efficient methods you'd like to contribute to the community, or new models and datasets! + +## Prerequisites + +To set up the development environment in your local box, run the commands below. + +1\. Install the dependencies needed for testing and linting the code: + + +```bash +pip install -e '.[all]' +``` + +2\. Configure [pre-commit](https://pre-commit.com/), which automatically formats code before +each commit: + + +```bash +pre-commit install +``` + +## Submitting a Contribution + +To submit a contribution: + +1\. Fork a copy of the [MegaBlocks](https://github.com/databricks/megablocks) library to your own account. + +2\. Clone your fork locally and add the megablocks repo as a remote repository: + + +```bash +git clone git@github.com:/megablocks.git +cd megablocks +git remote add upstream https://github.com/databricks/megablocks.git +``` + +3\. Create a branch and make your proposed changes. + + +```bash +git checkout -b cool-new-feature +``` + +4\. When you are ready, submit a pull request into the megablocks repository! + +## Pull request (PR) guidelines + +We have some rough guidelines that will make your PR easier to review and more likely to get smoothly merged. Please don't let uncertainty or difficulty with any of these things stop you from opening a PR! We are happy to help you through them :) +* Self-contained title and description. Please include a concise title and clear PR description. The title should allow someone to understand what the PR changes or does at a glance. The description should allow someone to understand the contents of the PR _without_ looking at the code. +* If the PR affects output that is displayed to a user of MegaBlocks (e.g. console logging or experiment tracker reporting), please include screenshots showing what the new output looks like. UX is important! +* Include tests. If you are fixing a bug, please add a test that would've caught the bug. If you are adding a new feature, please add unit tests that test the various components of the feature, and also a test that tests the full functionality of the feature. +* Please consider whether your changes affect the example notebooks or large parts of the code base, and run the daily tests locally if so (`pytest -m 'daily and not remote and not gpu and not vision and not doctest'`) +* `pre-commit` should help you handle formatting and type checking, but please do make sure you have it installed as described [above](#prerequisites). + +## Configuring README Code Snippets + +MegaBlocks uses [pytest-codeblocks](https://github.com/nschloe/pytest-codeblocks) to test all example code snippets. The pytest-codeblocks repository explains how to annotate code snippets, which supports most `pytest` configurations. For example, if a test requires model training, the GPU mark (``) should be applied. + +## Running Tests + +To test your changes locally, run: + +* `make test` # run CPU tests +* `make test-gpu` # run GPU tests +* `cd docs && make doctest` # run doctests + +Some of our checks test distributed training as well. To test these, run: + +* `make test-dist WORLD_SIZE=2` # run 2-cpu distributed tests +* `make test-dist-gpu WORLD_SIZE=2` # run 2-gpu distributed tests + +These tests run with the `composer` launcher. We also support `WORLD_SIZE=1`, which would run the tests with the `composer` launcher on a single device. + +See the [Makefile](/Makefile) for more information. + +If you want to run pre-commit hooks manually, which check for code formatting and type annotations, run `pre-commit run --all-files` + +### Docker + +To run the tests in the provided docker containers: + +* `docker pull mosaicml/composer` (or an alternative image like `mosaicml/composer:latest_cpu`) +* `docker run --rm -v ./:/composer --user $(id -u):$(id -g) -it mosaicml/composer` +* from inside the container + * `cd /megablocks` + * `pip install -e .` + * `pytest ` or `make ` to run the desired tests + + +## Code Style & Typing + +See the [MegaBlocks Style Guide](/STYLE_GUIDE.md) for guidelines on how to structure and format your code. + +MegaBlocks aims to annotate all functions with type annotations (introduced in +[PEP 526](https://www.python.org/dev/peps/pep-0526/)). Don't worry if you are not a Python typing expert; +put in the pull request, and we'll help you with getting the code into shape. diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md new file mode 100644 index 0000000..3d5876d --- /dev/null +++ b/STYLE_GUIDE.md @@ -0,0 +1,480 @@ +# 1. Style and Conventions + +## 1.1 Style Guide + +MegaBlocks generally follows Google's +[Python Style Guide](https://google.github.io/styleguide/pyguide.html) for how to format and structure code. + +## 1.2. Pre-Commit Hooks + +MegaBlocks uses [Pre Commit](https://pre-commit.com/) to enforce style checks. To configure, run +``` +pip install '.[dev]' # if not already installed +pre-commit install +``` + +The pre-commit hooks will now be run before each commit. You can also run the hooks manually via: + +``` +pre-commit run # run all hooks on changed files +pre-commit run --all-files # or, run all hooks on all files +``` + + + +## 1.3. Code Formatting + +MegaBlocks uses the [yapf](https://github.com/google/yapf) formatter for general formatting +[isort](https://github.com/PyCQA/isort) to sort imports. These checks run through pre-commit +(see section 2.2). These checks can also be run manually via: + +``` +pre-commit run yapf --all-files # for yapf +pre-commit run isort --all-files # for isort +``` + +The configuration is stored in [pyproject.toml](pyproject.toml). + + +## 1.4. Code Structure + +As a general rule of thumb, + +- Don't: Default to using inheritance for code reuse + + Do: prefer [composition over inheritance](https://en.wikipedia.org/wiki/Composition_over_inheritance) +- Don't: strive to implement all logic using classes + + Do: strive to implement logic as pure functions when possible, and classes when there is good reason +- Don't: Have a function accept falsy values that would then result in a no-op. + + Example of the anti-pattern: + + ```python + from typing import Optional + + def configure_deepspeed(deepspeed_config: Optional[dict]): + if deepspeed_config is None: + # Don't do this check in the callee, which results in a no-op + return + ... + ``` + + Do: Require the caller, instead of the callee, check for and handle falsy values. It's ok to accept falsy values + for individual arguments of a caller function, so long as the entire function would not be a no-op. + + Example: + ```python + from typing import Optional + + def configure_deepspeed(deepspeed_config: dict): + ... + + def trainer(deepspeed_config: Optional[dict]): + if deepspeed_config is not None: + # Do this check in the caller function + configure_deepspeed(deepspeed_config) + ... + ``` + +# 2. Type Annotations and Typechecking + +MegaBlocks aims to annotate all functions with type annotations (introduced in +[PEP 526](https://www.python.org/dev/peps/pep-0526/)). Type annotations help statically catch `TypeError` and +`AttributeError` bugs, in addition to other benefits, as outlined in the PEP. + +For documentation on typing annotations, see: +* [PEP 483](https://peps.python.org/pep-0483/) for a simplified introducation +* [PEP 484](https://peps.python.org/pep-0484/) for the full specification +* [Python docs for `typing`](https://docs.python.org/3/library/typing.html) for the API reference + +MegaBlocks uses [pyright](https://github.com/microsoft/pyright) +to validate type annotations. PyRight is automatically run as part of the pre-commit hooks, but you can also +run PyRight specifically via: + +``` +pre-commit run pyright --all-files +``` + +The pyright configuration is stored in [pyproject.toml](pyproject.toml). + + +## 2.1 Debugging + +Here are some suggestions to deal with pyright errors: + +1. Suppose a variable could be one of multiple types, like the following: + + ```python + from typing import Union + + def foo(x: Union[int, None]): + return x + 5 # type error -- None + 5 is not allowed! + ``` + + PyRight will complain since `None + 5` is not a valid operation. + Instead, add a check to ensure that `x is not None`: + + ```python + from typing import Union + + def foo(x: Union[int, None]): + if x is None: + raise TypeError("x must be an integer, not None!") + return x + 5 # valid + ``` + + Assert statements also work. However, assert statements should not be used for data validation + (see the assert statement section below). + ```python + from typing import Union + + def foo(x: Union[int, None]): + assert x is not None, "x should never be None" + return x + 5 # valid + ``` + +1. For variables where it is impossible for pyright to infer the correct type, use +[cast](https://docs.python.org/3/library/typing.html#typing.cast). +1. As a last resort, add a `# type: ignore` comment to the line where pyright emits an error. +Immediately following this statement, paste in the error emitted by pyright, +so other contributors will know why this error was silenced. + + +# 3. Public APIs +A public API, generally speaking, can be invoked by a user without a leading underscore in any portion of the path. +The following are examples of public APIs in [composer](https://github.com/mosaicml/composer/tree/dev): + +* Standalone functions in public modules (e.g. `composer.utils.dist.get_world_size`) +* Classes in public modules (e.g. `composer.trainer.trainer.Trainer`) +* Public methods in public classes (e.g. `composer.trainer.trainer.Trainer.fit`) +* Public modules (e.g. `composer.trainer.trainer`) + +The following rules apply to public APIs: +1. All public APIs must have a docstring (see the Documentation section below) +1. All parameters must have type annotations. +1. To minimize user imports, parameters should should use native PyTorch or Python types whenever possible. + + It is acceptable to use a union of types, so long as one of the options is a primitive. For example, in the + constructor for `composer.trainer.trainer.Trainer`, the `device` parameter is annotated like the following: + + ```python + from typing import Optional, Union + + from composer.devices import Device + + class Trainer: + def __init__( + self, + device: Union[str, Device], + ): + if isinstance(device, str): + device = Device(device) + ... + ``` + + This signature allows a user to pass a string for a device, + rather than having to import our custom device class. + + Parameters that are for power users (such as `load_object_store`) in the Trainer are exempt from this rule. + These parameters can require custom imports. + +1. Parameters that could take a sequence of elements should also allow `None` or a singleton. + This simplifies the user API by not having to construct a list (or tuple) to hold a single element + (or no element). For example, use `Optional[Union[torch.Tensor, Sequence[torch.Tensor]]`. + + The `composer.utils.ensure_tuple` helper method can convert a singleton, list, or tuple into a tuple. + For example + + ```python + from torch import Tensor + from typing import Optional, Sequence, Union + from composer.utils import ensure_tuple + + def foo(x: Optional[Union[Tensor, Sequence[Tensor]]]) -> tuple[Tensor, ...]: + return ensure_tuple(x) # ensures that the result is always a (potentially empty) tuple of tensors + ``` + + +# 4. Use of `assert` + +`assert` should be used only in test cases and for verifying invariants (likely required for type checking), +not for data validation. As asserts can be disabled in python by using the `-O` flag +(e.g. `python -O path/to/script.py`), they are not guaranteed to run. For data validation, instead use a style like +the following: + + + + +```python +if parameter is None: + raise ValueError("parameter must be specified and cannot be None") +``` + + +# 5. Imports and `__init__.py` + +All imports in MegaBlocks should be absolute -- that is, they do not begin with a period. + +## 5.1 External Dependencies +1. All external dependencies must be specified in both [setup.py](setup.py) for pip and [meta.yaml](meta.yaml) + for Anaconda. + +1. If a dependency is not core to MegaBlocks (e.g. it is for a model, dataset, algorithm, or some callbacks): + 1. It must be specified in a entry of the `extra_deps` dictionary of [setup.py](setup.py). + This dictionary groups dependencies that can be conditionally installed. An entry named `foo` can be installed with `pip install 'megablocks[foo]'`. For example, running `pip install 'megablocks[gg]'` will install everything in `install_requires`, along with `grouped_gemm`. + 1. It must also be specified in the `run_constrained` and the `test.requires` section. + 1. The import must be conditionally imported in the code. For example: + + + + ```python + from composer import Callback + from composer.utils import MissingConditionalImportError + + class SystemMetricsMonitor(Callback) + try: + import pynvml + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group="system_metrics_monitor", + conda_package="pynvml", + conda_channel="conda-forge",) from e + ``` + + This style allows users to perform minimal install of Composer without triggering `ImportError`s if + an optional dependency is missing. + + If the corresponding package is not published on Anaconda, then set the ``conda_package`` to the pip package + name, and set ``conda_channel`` to ``None``. For example, with DeepSpeed: + + + ```python + from composer.utils import MissingConditionalImportError + + try: + import deepspeed + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group="deepspeed", + conda_package="deepspeed>=0.5.5", + conda_channel=None) from e + ``` + + + + 1. If the dependency is core to MegaBlocks, add the dependency to the `install_requires` section of + [setup.py](./setup.py) and the `requirements.run` section of [meta.yaml](./meta.yaml). + +## 5.2 Use of `__all__` + +All public modules must define `__all__` to be the list of members that should be re-exported. +The variable is necessary to 1) limit what `from XXX import *` imports, and 2) ensure that the documentation only +includes exported members, not unrelated re-imports. + +For example, from [composer/callbacks/memory_monitor.py](composer/callbacks/memory_monitor.py) + +```python +"""Log memory usage during training.""" +import logging +from typing import Union + +import torch.cuda + +from composer.core import State +from composer.loggers import Logger +from composer.core.callback import Callback + +log = logging.getLogger(__name__) + +__all__ = ["MemoryMonitor"] # export only the MemoryMonitor, not other imports like `Logger`, `State`, or `Callback` + + +class MemoryMonitor(Callback): + ... +``` + + +## 5.3 `__init__.py` + +All public classes and functions should be added to the module's `__init__.py`. + + +```python +from composer.path.to.module.file import MyClass as MyClass +from composer.path.to.module.file import my_func as my_func +``` + +If a file only contains public functions, then the following is also acceptable: + + +```python +from composer.path.to.module import my_file as my_file +``` + + +# 6. Documentation + +## 6.1 Docstrings + +MegaBlocks uses [Google Style Docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html). +All public APIs require documentation. + +### 6.1.1 What to include in Docstrings? + +Docstrings, at a minimum, should include a summary of what the function or class does, along with the arguments it takes. See [below](#612-formatting-docstrings) for how to format docstrings. The [Google Style Guide](https://google.github.io/styleguide/pyguide.html) also includes some guidelines on how to write docstrings. + +### 6.1.2 Formatting Docstrings + +The following guidelines apply to documentation. +1. Each function that needs a docstring must have its input arguments, return statement (if not None), and any custom + exceptions annotated. +1. The arguments for the `__init__` signature of classes should be documented under the class-level docstring. There + should not be any `__init__`-level docstring. +1. Each argument annotation should include the type. If the argument has a default value, the type annotation should + specify "optional", and the docstring should say the default value. Some examples: + + ```python + from typing import Optional, Union + + def foo(bar: int): + """Foo. + + Args: + bar (int): Required bar. + """ + ... + + def foo2(bar: int = 42): + """Foo2. + + Args: + bar (int, optional): The first Argument. Default: ``42``. + """ + ... + + def foo3(bar: Optional[int] = None): + """Foo3. + + Args: + bar (int, optional): The first Argument. Default: ``None``. + """ + ... + + def foo4(bar: Union[int, str] = 42): + """Foo4. + + Args: + bar (int | str, optional): The first Argument. Default: ``42``. + """ + ... + + def foo5(bar: int) -> int: + """Foo5. + + Args: + bar (int): Required bar. + + Returns: + int: Description of return statement. + """ + ... + + def foo6(bar: int) -> tuple[int, str]: + """Foo6. + + Args: + bar (int): Required bar. + + Returns: + a (int): Returned value. + b (str): Returned value. + """ + ... + ``` + +### 6.1.3 Building and Viewing Docs Locally + +Assuming you already have a development install of MegaBlocks (see these [instructions](CONTRIBUTING.md#prerequisites)), here’s how to build and previous the docs locally. + +**️️ ⚠ Warning:** Jenkins treats all sphinx warnings as errors, so they must be addressed before a PR can be merged. Building docs locally can help debug any warnings showing up on Jenkins! + +In one terminal, run: + + +```bash +source path/to/megablocks_venv/bin/activate # activate your megablocks virtual env +cd megablocks/docs # cd to the docs folder insde your megablocks clone +make clean +make html +``` + +In a second terminal, run: + + +```bash +cd megablocks/docs +python3 -m http.server --directory _build/html/ +``` + +Then, navigate to [http://localhost:8000](http://localhost:8000) in your browser. + +## 6.2 Doctests + +Most docstrings should also include a `.. doctest` or `.. testcode` example to clearly illustrate how one would interact with the class or function. As part of the CI/CD process, all `.. doctest` blocks are executed to ensure the example in the documentation actually works. + +### 6.2.1 Writing Doctests + +See the [Sphinx Doctest Extension](https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html) for all of the available directives. Do not use `.. code-block::` for Python examples, as they are untested. + +Any test fixtures for doctests should go in [docs/source/doctest_fixtures.py](docs/source/doctest_fixtures.py) or in a `.. testsetup::` block. + +For example: +```python +import torch +from typing import Optional + +def my_function(x: Optional[torch.Tensor]) -> torch.Tensor: + """blah function + + Args: + input (torch.Tensor): Your guess. + + Returns: + torch.Tensor: How good your input is. + + Raises: + ValueError: If your input is negative. + + Example: + .. testsetup:: + + # optional setup section, not shown in docs + import torch + x = torch.randn(42) + + + .. testcode:: + + # shown in docs; runs after testsetup + my_function(x) + """ + ... +``` + +All doctests load the [docs/source/doctest_fixtures.py](docs/source/doctest_fixtures.py) file *before* tests run. If there are any variables that would be helpful have defined for all tests, feel free to add them into this file. However, if a variable is more specific to an individual doctest, then it would be best to include it in a `.. testsetup::` block, as not to pollute the global fixture namespace. (Unlike pytest fixtures, all doctest fixtures are given to every doctest; they cannot be specifically requested) + +### 6.2.2 Running Doctests Locally + +Assuming you already have a development install of MegaBlocks (see these [instructions](CONTRIBUTING.md#prerequisites)), here’s how to run the doctests. + + +```bash +source path/to/megablocks_venv/bin/activate # activate your megablocks virtual env +cd megablocks/docs # cd to the docs folder insde your megablocks clone +make clean +make html # the html build must be completed first to ensure all doctests are identified +make doctest 2>/dev/null # For more verbosity, do not direct stderr to /dev/null +``` From 5277c7056a8904da944f3d4e522eaec3fc302ea0 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 12:29:59 +0000 Subject: [PATCH 32/43] add PULL_REQUEST template --- .github/PULL_REQUEST_TEMPLATE.md | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..34272ee --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,29 @@ +# What does this PR do? + + + +# What issue(s) does this change relate to? + + + +# Before submitting +- [ ] Have you read the [contributor guidelines](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md)? +- [ ] Is this change a documentation change or typo fix? If so, skip the rest of this checklist. +- [ ] Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so. +- [ ] Did you update any related docs and document your change? +- [ ] Did you update any related tests and add any new tests related to your change? (see [testing](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md#running-tests)) +- [ ] Did you run the tests locally to make sure they pass? +- [ ] Did you run `pre-commit` on your change? (see the `pre-commit` section of [prerequisites](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md#prerequisites)) + + From e339ac56931f2120320bc01b96018790535e9d25 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 12:31:14 +0000 Subject: [PATCH 33/43] fix imports --- megablocks/__init__.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 3cdf43d..de11bcb 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,9 +1,10 @@ # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -from megablocks.layers import dmoe, moe +from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE __all__ = [ - 'dmoe', - 'moe', + 'MoE', + 'dMoE', ] From 159a6c63258380a17666771102676bcb857dbf95 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:54:10 -0400 Subject: [PATCH 34/43] Update .github/workflows/pr-gpu.yaml Co-authored-by: Mihir Patel --- .github/workflows/pr-gpu.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 4df2147..b7332e2 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -16,7 +16,7 @@ jobs: pytest-gpu: name: ${{ matrix.name }} if: github.repository_owner == 'databricks' - runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later + runs-on: ubuntu-latest # TODO: switch to linux-ubuntu-latest later strategy: fail-fast: false matrix: From 6c303a063a6d71ffed47391ca4b82f94f6ef6ca6 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 19:21:48 +0000 Subject: [PATCH 35/43] update python version --- .github/workflows/pr-gpu.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 4df2147..1392a9b 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -21,9 +21,9 @@ jobs: fail-fast: false matrix: include: - - name: "python3.11-pytorch2.3.1-gpus1" + - name: "python3.9-pytorch2.3.1-gpus1" gpu_num: 1 - python_version: 3.11 + python_version: 3.9 container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 - name: "python3.11-pytorch2.3.1-gpus2" gpu_num: 2 From a309f94a4a1ced01452b021e7e6231725992fc99 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 19:24:22 +0000 Subject: [PATCH 36/43] import more files --- megablocks/__init__.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/megablocks/__init__.py b/megablocks/__init__.py index de11bcb..0eb6061 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,10 +1,20 @@ # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -from megablocks.layers.dmoe import dMoE -from megablocks.layers.moe import MoE +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss __all__ = [ 'MoE', 'dMoE', + 'get_load_balancing_loss', + 'ParallelMLP', + 'ParallelDroplessMLP', + 'SparseMLP', + 'MLP', + 'SparseGLU', + 'Arguments', ] From b5f5869f93dc1a20e08ed477d1c3bc403c772662 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 19:49:34 +0000 Subject: [PATCH 37/43] change license --- .pre-commit/FILE_HEADER | 2 +- megablocks/__init__.py | 3 +++ megablocks/_version.py | 3 +++ megablocks/backend/__init__.py | 3 +++ megablocks/backend/kernels.py | 3 +++ megablocks/benchmark_util.py | 3 +++ megablocks/grouped_gemm_util.py | 3 +++ megablocks/layers/__init__.py | 3 +++ megablocks/layers/activation_fn.py | 3 +++ megablocks/layers/all_to_all.py | 3 +++ megablocks/layers/arguments.py | 3 +++ megablocks/layers/common.py | 3 +++ megablocks/layers/dmlp_registry.py | 3 +++ megablocks/layers/dmoe.py | 3 +++ megablocks/layers/gelu.py | 3 +++ megablocks/layers/glu.py | 3 +++ megablocks/layers/memory_test.py | 3 +++ megablocks/layers/mlp.py | 3 +++ megablocks/layers/moe.py | 3 +++ megablocks/layers/mpu.py | 3 +++ megablocks/layers/router.py | 3 +++ megablocks/layers/sharedexpert_registry.py | 3 +++ megablocks/layers/testing.py | 3 +++ megablocks/layers/weight_parallel.py | 3 +++ megablocks/ops/__init__.py | 3 +++ megablocks/ops/all_to_all_benchmark.py | 3 +++ megablocks/ops/binned_gather.py | 3 +++ megablocks/ops/binned_scatter.py | 3 +++ megablocks/ops/cumsum.py | 3 +++ megablocks/ops/gather.py | 3 +++ megablocks/ops/histogram.py | 3 +++ megablocks/ops/histogram_benchmark.py | 3 +++ megablocks/ops/matmul_benchmark.py | 3 +++ megablocks/ops/padded_gather.py | 3 +++ megablocks/ops/padded_scatter.py | 3 +++ megablocks/ops/padded_scatter_benchmark.py | 3 +++ megablocks/ops/permute_benchmark.py | 3 +++ megablocks/ops/repeat.py | 3 +++ megablocks/ops/replicate.py | 3 +++ megablocks/ops/round_up.py | 3 +++ megablocks/ops/scatter.py | 3 +++ megablocks/ops/sort.py | 3 +++ megablocks/ops/sort_benchmark.py | 3 +++ megablocks/ops/sum.py | 3 +++ megablocks/ops/topology.py | 3 +++ setup.py | 3 +++ tests/conftest.py | 3 +++ tests/fixtures/autouse.py | 3 +++ tests/fixtures/fixtures.py | 3 +++ tests/layers/dmoe_test.py | 3 +++ tests/layers/glu_test.py | 3 +++ tests/layers/moe_test.py | 3 +++ tests/layers/parallelism_test.py | 3 +++ tests/ops/binned_gather_test.py | 3 +++ tests/ops/binned_scatter_test.py | 3 +++ tests/ops/cumsum_test.py | 3 +++ tests/ops/histogram_test.py | 3 +++ tests/ops/padded_gather_test.py | 3 +++ tests/ops/padded_scatter_test.py | 3 +++ tests/ops/replicate_test.py | 3 +++ tests/ops/sort_test.py | 3 +++ tests/ops/topology_test.py | 3 +++ 62 files changed, 184 insertions(+), 1 deletion(-) diff --git a/.pre-commit/FILE_HEADER b/.pre-commit/FILE_HEADER index 69d0cd5..a2432cc 100644 --- a/.pre-commit/FILE_HEADER +++ b/.pre-commit/FILE_HEADER @@ -1,2 +1,2 @@ -Copyright 2024 MosaicML MegaBlocks authors +Copyright 2024 Databricks MegaBlocks authors SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 0eb6061..44898e9 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/_version.py b/megablocks/_version.py index a9ac8bc..9e9578d 100644 --- a/megablocks/_version.py +++ b/megablocks/_version.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/backend/__init__.py b/megablocks/backend/__init__.py index 1d3c2fd..51d38f3 100644 --- a/megablocks/backend/__init__.py +++ b/megablocks/backend/__init__.py @@ -1,2 +1,5 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index 81ea6a0..4215794 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/benchmark_util.py b/megablocks/benchmark_util.py index fe7c998..2a4ecac 100644 --- a/megablocks/benchmark_util.py +++ b/megablocks/benchmark_util.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index 899bf60..fc9a0f9 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/__init__.py b/megablocks/layers/__init__.py index de11bcb..ed970fe 100644 --- a/megablocks/layers/__init__.py +++ b/megablocks/layers/__init__.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 78d6d00..4364198 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 51f4b9e..e82121b 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index f75bcc0..5f09104 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py index 27e7473..db55f4a 100644 --- a/megablocks/layers/common.py +++ b/megablocks/layers/common.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index ee51d3e..50d08f3 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 940477c..6d86a5d 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index 92a0741..5593705 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index fa9f07d..6f3ae05 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index 78e1fa1..489b853 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index f98b440..bf0684d 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index d6687ed..7448389 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index b80ecd8..bdd6a71 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index 0b5b670..44ab735 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 98c9554..19daf2f 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index 12363cb..dac575a 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index 37935b1..cc66d79 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index d582191..0be6781 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index 7cfa957..bfd5b8b 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 800daed..6efa5fa 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index eef6130..fb9df38 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 86619d3..349ec6a 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index 19033cc..24f8fff 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index 8fd4e05..30a84ff 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index 938079a..a44ffa6 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 8577095..331a686 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 65700ba..6340e2d 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index e3ed44a..f2341e1 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index bbf82bd..f48d503 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index a16951d..daa16b4 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index bff1e4b..9f02c53 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index b261e2e..60ee4dd 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index 706ed07..345f33d 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 67449cd..d7390a1 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index 7370119..2f916fa 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index ba52917..465f09b 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index 88874a0..4f9a0b4 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index 502f7cd..307440a 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/setup.py b/setup.py index e1ee2e2..166ccfd 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/conftest.py b/tests/conftest.py index dd0ebc0..e810d03 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 738a8eb..f6d5b1a 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index c167fb6..19f694b 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py index 01a7e42..d7df87f 100644 --- a/tests/layers/dmoe_test.py +++ b/tests/layers/dmoe_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index ab4207c..b41cefe 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index 9f36641..ccf9b96 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py index d3c5586..aba6e69 100644 --- a/tests/layers/parallelism_test.py +++ b/tests/layers/parallelism_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/binned_gather_test.py b/tests/ops/binned_gather_test.py index d889bce..45a23b5 100644 --- a/tests/ops/binned_gather_test.py +++ b/tests/ops/binned_gather_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/binned_scatter_test.py b/tests/ops/binned_scatter_test.py index 50c309f..317265f 100644 --- a/tests/ops/binned_scatter_test.py +++ b/tests/ops/binned_scatter_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/cumsum_test.py b/tests/ops/cumsum_test.py index a1b7160..442a6c2 100644 --- a/tests/ops/cumsum_test.py +++ b/tests/ops/cumsum_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/histogram_test.py b/tests/ops/histogram_test.py index 2f98fb7..53d561d 100644 --- a/tests/ops/histogram_test.py +++ b/tests/ops/histogram_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/padded_gather_test.py b/tests/ops/padded_gather_test.py index e7b8a09..0b37dbb 100644 --- a/tests/ops/padded_gather_test.py +++ b/tests/ops/padded_gather_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/padded_scatter_test.py b/tests/ops/padded_scatter_test.py index 637b04b..19cae51 100644 --- a/tests/ops/padded_scatter_test.py +++ b/tests/ops/padded_scatter_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/replicate_test.py b/tests/ops/replicate_test.py index 94aeb67..5cdbea8 100644 --- a/tests/ops/replicate_test.py +++ b/tests/ops/replicate_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/sort_test.py b/tests/ops/sort_test.py index 3a527de..eb057d9 100644 --- a/tests/ops/sort_test.py +++ b/tests/ops/sort_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 diff --git a/tests/ops/topology_test.py b/tests/ops/topology_test.py index b7a28b8..3e9223e 100644 --- a/tests/ops/topology_test.py +++ b/tests/ops/topology_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 From e9e7576c48a38e85c6e1d81086fb944d0cf5d5dc Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 19:50:17 +0000 Subject: [PATCH 38/43] only support 3.11 --- .github/workflows/code-quality.yaml | 2 -- 1 file changed, 2 deletions(-) diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index bf95f57..2b1d931 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -24,8 +24,6 @@ jobs: strategy: matrix: python_version: - - "3.9" - - "3.10" - "3.11" pip_deps: - "[dev]" From b69d4f8b03d056991000484bc1c79c666364f4b7 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 19:55:46 +0000 Subject: [PATCH 39/43] comment on new line --- .github/workflows/pr-gpu.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 973e715..f2905ca 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -39,7 +39,8 @@ jobs: gpu_num: ${{ matrix.gpu_num }} git_repo: databricks/megablocks pip_deps: "[all,testing]" - pytest_command: "coverage run -m pytest tests" # TODO delete `tests` after removing all tests in megablocks dir + pytest_command: "coverage run -m pytest tests" + # TODO: remove tests from pytest tests when we delete all tests in the MegaBlocks dir pytest_markers: "gpu" composer_package_name: mosaicml # Required as Composer is built from source mcloud_timeout: 3600 From b1ab8b753f9fe6c39f2a794e49ef21ec5cee0053 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 20:04:03 +0000 Subject: [PATCH 40/43] fix license --- megablocks/__init__.py | 3 --- megablocks/_version.py | 3 --- megablocks/backend/__init__.py | 3 --- megablocks/backend/kernels.py | 3 --- megablocks/benchmark_util.py | 3 --- megablocks/grouped_gemm_util.py | 3 --- megablocks/layers/__init__.py | 3 --- megablocks/layers/activation_fn.py | 3 --- megablocks/layers/all_to_all.py | 3 --- megablocks/layers/arguments.py | 3 --- megablocks/layers/common.py | 3 --- megablocks/layers/dmlp_registry.py | 3 --- megablocks/layers/dmoe.py | 3 --- megablocks/layers/gelu.py | 3 --- megablocks/layers/glu.py | 3 --- megablocks/layers/memory_test.py | 3 --- megablocks/layers/mlp.py | 3 --- megablocks/layers/moe.py | 3 --- megablocks/layers/mpu.py | 3 --- megablocks/layers/router.py | 3 --- megablocks/layers/sharedexpert_registry.py | 3 --- megablocks/layers/testing.py | 3 --- megablocks/layers/weight_parallel.py | 3 --- megablocks/ops/__init__.py | 3 --- megablocks/ops/all_to_all_benchmark.py | 3 --- megablocks/ops/binned_gather.py | 3 --- megablocks/ops/binned_scatter.py | 3 --- megablocks/ops/cumsum.py | 3 --- megablocks/ops/gather.py | 3 --- megablocks/ops/histogram.py | 3 --- megablocks/ops/histogram_benchmark.py | 3 --- megablocks/ops/matmul_benchmark.py | 3 --- megablocks/ops/padded_gather.py | 3 --- megablocks/ops/padded_scatter.py | 3 --- megablocks/ops/padded_scatter_benchmark.py | 3 --- megablocks/ops/permute_benchmark.py | 3 --- megablocks/ops/replicate.py | 3 --- megablocks/ops/round_up.py | 3 --- megablocks/ops/scatter.py | 3 --- megablocks/ops/sort.py | 3 --- megablocks/ops/sort_benchmark.py | 3 --- megablocks/ops/sum.py | 3 --- megablocks/ops/topology.py | 3 --- setup.py | 9 +-------- tests/conftest.py | 3 --- tests/fixtures/autouse.py | 3 --- tests/fixtures/fixtures.py | 3 --- tests/layers/dmoe_test.py | 3 --- tests/layers/glu_test.py | 3 --- tests/layers/moe_test.py | 3 --- tests/layers/parallelism_test.py | 3 --- tests/ops/binned_gather_test.py | 3 --- tests/ops/binned_scatter_test.py | 3 --- tests/ops/cumsum_test.py | 3 --- tests/ops/histogram_test.py | 3 --- tests/ops/padded_gather_test.py | 3 --- tests/ops/padded_scatter_test.py | 3 --- tests/ops/replicate_test.py | 3 --- tests/ops/sort_test.py | 3 --- tests/ops/topology_test.py | 3 --- 60 files changed, 1 insertion(+), 185 deletions(-) diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 44898e9..7f022e9 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from megablocks.layers.arguments import Arguments from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE from megablocks.layers.glu import SparseGLU diff --git a/megablocks/_version.py b/megablocks/_version.py index 9e9578d..ffdc54a 100644 --- a/megablocks/_version.py +++ b/megablocks/_version.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - """The MegaBlocks Version.""" __version__ = '0.5.1' diff --git a/megablocks/backend/__init__.py b/megablocks/backend/__init__.py index 51d38f3..1018343 100644 --- a/megablocks/backend/__init__.py +++ b/megablocks/backend/__init__.py @@ -1,5 +1,2 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 - -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index 4215794..a25e1a4 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch import triton import triton.language as tl diff --git a/megablocks/benchmark_util.py b/megablocks/benchmark_util.py index 2a4ecac..ea6069b 100644 --- a/megablocks/benchmark_util.py +++ b/megablocks/benchmark_util.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import torch diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index fc9a0f9..70b410f 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - try: import grouped_gemm except ImportError: diff --git a/megablocks/layers/__init__.py b/megablocks/layers/__init__.py index ed970fe..c52ab00 100644 --- a/megablocks/layers/__init__.py +++ b/megablocks/layers/__init__.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from megablocks.layers.dmoe import dMoE from megablocks.layers.moe import MoE diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 4364198..102685f 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from typing import Callable import stk diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index e82121b..230f81b 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 5f09104..08cf3cc 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import dataclasses from functools import partial from typing import Any, Callable, Optional, Union diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py index db55f4a..890da35 100644 --- a/megablocks/layers/common.py +++ b/megablocks/layers/common.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from megablocks.layers.arguments import Arguments diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index 50d08f3..ac43035 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from typing import Union from megablocks.layers import glu, mlp diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 6d86a5d..83f87b8 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import stk import torch diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index 5593705..da067ac 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import stk import torch import torch.nn.functional as F diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 6f3ae05..ea2102f 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import stk import torch diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index 489b853..afd3a29 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import gc import torch diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index bf0684d..abee132 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from typing import Any import stk diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 7448389..6114fc1 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import torch diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index bdd6a71..8b2cd9a 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from megablocks.layers.arguments import Arguments diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index 44ab735..6a356fd 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from megablocks.layers import common diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 19daf2f..07ed662 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from typing import Union from megablocks.layers import glu, mlp diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index dac575a..d1f12ca 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch import torch.nn.functional as F diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index cc66d79..bbc2571 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import stk import torch diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index 0be6781..2951ca1 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from megablocks.ops.binned_gather import binned_gather from megablocks.ops.binned_scatter import binned_scatter from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index bfd5b8b..58bfcb9 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from megablocks import benchmark_util diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 6efa5fa..cf69f90 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index fb9df38..dd15ecc 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 349ec6a..c3804a5 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index 24f8fff..4f52a95 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index 30a84ff..49b3ba9 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index a44ffa6..955d264 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import unittest import numpy as np diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 331a686..3b766d1 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import unittest import stk diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 6340e2d..1bca6bd 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index f2341e1..8a2c4c7 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index f48d503..7c11e0a 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import unittest import torch diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index daa16b4..11a4201 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import unittest import torch diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index 60ee4dd..3853494 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index 345f33d..0300fe4 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index d7390a1..ccffba3 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import torch from stk.backend.autocast import custom_bwd, custom_fwd diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index 2f916fa..37dd13a 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 465f09b..4a06580 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import unittest import numpy as np diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index 4f9a0b4..ac55d24 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - def sum(x, dim=0): if x.shape[dim] == 1: diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index 307440a..fb07210 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch diff --git a/setup.py b/setup.py index 166ccfd..9149352 100644 --- a/setup.py +++ b/setup.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - """MegaBlocks package setup.""" import os @@ -16,11 +13,7 @@ # More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html try: import torch - from torch.utils.cpp_extension import ( - CUDA_HOME, - BuildExtension, - CUDAExtension, - ) + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension except ModuleNotFoundError as e: raise ModuleNotFoundError("No module named 'torch'. `torch` is required to install `MegaBlocks`.",) from e diff --git a/tests/conftest.py b/tests/conftest.py index e810d03..afe3e43 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import os from typing import List, Optional diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index f6d5b1a..986972e 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import gc import logging import os diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index 19f694b..577c077 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import pytest from tests.conftest import _get_option diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py index d7df87f..5a1e9d5 100644 --- a/tests/layers/dmoe_test.py +++ b/tests/layers/dmoe_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from functools import partial import pytest diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index b41cefe..5ede1db 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from functools import partial import pytest diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index ccf9b96..a4032c2 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from functools import partial import pytest diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py index aba6e69..6fd4d1e 100644 --- a/tests/layers/parallelism_test.py +++ b/tests/layers/parallelism_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import functools import numpy as np diff --git a/tests/ops/binned_gather_test.py b/tests/ops/binned_gather_test.py index 45a23b5..3906d9c 100644 --- a/tests/ops/binned_gather_test.py +++ b/tests/ops/binned_gather_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import pytest import torch diff --git a/tests/ops/binned_scatter_test.py b/tests/ops/binned_scatter_test.py index 317265f..f3c352b 100644 --- a/tests/ops/binned_scatter_test.py +++ b/tests/ops/binned_scatter_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import pytest import torch diff --git a/tests/ops/cumsum_test.py b/tests/ops/cumsum_test.py index 442a6c2..1e79db1 100644 --- a/tests/ops/cumsum_test.py +++ b/tests/ops/cumsum_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import pytest import torch diff --git a/tests/ops/histogram_test.py b/tests/ops/histogram_test.py index 53d561d..7a0fc9a 100644 --- a/tests/ops/histogram_test.py +++ b/tests/ops/histogram_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import pytest import torch diff --git a/tests/ops/padded_gather_test.py b/tests/ops/padded_gather_test.py index 0b37dbb..a9c880b 100644 --- a/tests/ops/padded_gather_test.py +++ b/tests/ops/padded_gather_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import pytest import torch diff --git a/tests/ops/padded_scatter_test.py b/tests/ops/padded_scatter_test.py index 19cae51..d4fbba9 100644 --- a/tests/ops/padded_scatter_test.py +++ b/tests/ops/padded_scatter_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import pytest import torch diff --git a/tests/ops/replicate_test.py b/tests/ops/replicate_test.py index 5cdbea8..4076d9d 100644 --- a/tests/ops/replicate_test.py +++ b/tests/ops/replicate_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import pytest import torch diff --git a/tests/ops/sort_test.py b/tests/ops/sort_test.py index eb057d9..0f86984 100644 --- a/tests/ops/sort_test.py +++ b/tests/ops/sort_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - from typing import Dict, Optional, Union import numpy as np diff --git a/tests/ops/topology_test.py b/tests/ops/topology_test.py index 3e9223e..cbffcbe 100644 --- a/tests/ops/topology_test.py +++ b/tests/ops/topology_test.py @@ -1,9 +1,6 @@ # Copyright 2024 Databricks MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 - import numpy as np import pytest import torch From 3b342ff652352b6982970de89d44759cb94d0d6e Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 20:05:47 +0000 Subject: [PATCH 41/43] fix license --- .pre-commit/FILE_HEADER | 2 +- megablocks/__init__.py | 2 +- megablocks/_version.py | 2 +- megablocks/backend/__init__.py | 2 +- megablocks/backend/kernels.py | 2 +- megablocks/benchmark_util.py | 2 +- megablocks/grouped_gemm_util.py | 2 +- megablocks/layers/__init__.py | 2 +- megablocks/layers/activation_fn.py | 2 +- megablocks/layers/all_to_all.py | 2 +- megablocks/layers/arguments.py | 2 +- megablocks/layers/common.py | 2 +- megablocks/layers/dmlp_registry.py | 2 +- megablocks/layers/dmoe.py | 2 +- megablocks/layers/gelu.py | 2 +- megablocks/layers/glu.py | 2 +- megablocks/layers/memory_test.py | 2 +- megablocks/layers/mlp.py | 2 +- megablocks/layers/moe.py | 2 +- megablocks/layers/mpu.py | 2 +- megablocks/layers/router.py | 2 +- megablocks/layers/sharedexpert_registry.py | 2 +- megablocks/layers/testing.py | 2 +- megablocks/layers/weight_parallel.py | 2 +- megablocks/ops/__init__.py | 2 +- megablocks/ops/all_to_all_benchmark.py | 2 +- megablocks/ops/binned_gather.py | 2 +- megablocks/ops/binned_scatter.py | 2 +- megablocks/ops/cumsum.py | 2 +- megablocks/ops/gather.py | 2 +- megablocks/ops/histogram.py | 2 +- megablocks/ops/histogram_benchmark.py | 2 +- megablocks/ops/matmul_benchmark.py | 2 +- megablocks/ops/padded_gather.py | 2 +- megablocks/ops/padded_scatter.py | 2 +- megablocks/ops/padded_scatter_benchmark.py | 2 +- megablocks/ops/permute_benchmark.py | 2 +- megablocks/ops/repeat.py | 2 +- megablocks/ops/replicate.py | 2 +- megablocks/ops/round_up.py | 2 +- megablocks/ops/scatter.py | 2 +- megablocks/ops/sort.py | 2 +- megablocks/ops/sort_benchmark.py | 2 +- megablocks/ops/sum.py | 2 +- megablocks/ops/topology.py | 2 +- setup.py | 2 +- tests/conftest.py | 2 +- tests/fixtures/autouse.py | 2 +- tests/fixtures/fixtures.py | 2 +- tests/layers/dmoe_test.py | 2 +- tests/layers/glu_test.py | 2 +- tests/layers/moe_test.py | 2 +- tests/layers/parallelism_test.py | 2 +- tests/ops/binned_gather_test.py | 2 +- tests/ops/binned_scatter_test.py | 2 +- tests/ops/cumsum_test.py | 2 +- tests/ops/histogram_test.py | 2 +- tests/ops/padded_gather_test.py | 2 +- tests/ops/padded_scatter_test.py | 2 +- tests/ops/replicate_test.py | 2 +- tests/ops/sort_test.py | 2 +- tests/ops/topology_test.py | 2 +- 62 files changed, 62 insertions(+), 62 deletions(-) diff --git a/.pre-commit/FILE_HEADER b/.pre-commit/FILE_HEADER index a2432cc..5081c93 100644 --- a/.pre-commit/FILE_HEADER +++ b/.pre-commit/FILE_HEADER @@ -1,2 +1,2 @@ -Copyright 2024 Databricks MegaBlocks authors +Copyright 2024 Databricks SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 7f022e9..d8d1848 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from megablocks.layers.arguments import Arguments diff --git a/megablocks/_version.py b/megablocks/_version.py index ffdc54a..44ea780 100644 --- a/megablocks/_version.py +++ b/megablocks/_version.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 """The MegaBlocks Version.""" diff --git a/megablocks/backend/__init__.py b/megablocks/backend/__init__.py index 1018343..9d4e43e 100644 --- a/megablocks/backend/__init__.py +++ b/megablocks/backend/__init__.py @@ -1,2 +1,2 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index a25e1a4..b831826 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/benchmark_util.py b/megablocks/benchmark_util.py index ea6069b..02612d9 100644 --- a/megablocks/benchmark_util.py +++ b/megablocks/benchmark_util.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index 70b410f..07dbc04 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 try: diff --git a/megablocks/layers/__init__.py b/megablocks/layers/__init__.py index c52ab00..f0c42de 100644 --- a/megablocks/layers/__init__.py +++ b/megablocks/layers/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from megablocks.layers.dmoe import dMoE diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 102685f..736d311 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from typing import Callable diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 230f81b..82a6f40 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 08cf3cc..efe131d 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import dataclasses diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py index 890da35..ee30e79 100644 --- a/megablocks/layers/common.py +++ b/megablocks/layers/common.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index ac43035..d765bd0 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from typing import Union diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 83f87b8..e683f8a 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index da067ac..40b601d 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import stk diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index ea2102f..fa888a6 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import stk diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index afd3a29..809e317 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import gc diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index abee132..1cae4fb 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from typing import Any diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 6114fc1..e5eaaa8 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 8b2cd9a..6aa0015 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index 6a356fd..42cfbe1 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 07ed662..0f62db3 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from typing import Union diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index d1f12ca..4cd9500 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index bbc2571..82effec 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import stk diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index 2951ca1..b9dc286 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from megablocks.ops.binned_gather import binned_gather diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index 58bfcb9..b3a8537 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index cf69f90..8a22317 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index dd15ecc..f65fbe8 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index c3804a5..09b23ab 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 # NOTE: Torch needs to be imported before the custom diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index 4f52a95..a335273 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index 49b3ba9..7660e82 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 # NOTE: Torch needs to be imported before the custom diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index 955d264..9de8e65 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import unittest diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 3b766d1..bfa7b7c 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import unittest diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 1bca6bd..b57a518 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 8a2c4c7..1ca1605 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index 7c11e0a..81dde4e 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import unittest diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index 11a4201..837f07e 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import unittest diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index 9f02c53..61bb04b 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 # Copyright 2024 MosaicML MegaBlocks authors diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index 3853494..b7cb9c3 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 # NOTE: Torch needs to be imported before the custom diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index 0300fe4..2c59a78 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index ccffba3..33f051c 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import torch diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index 37dd13a..12ec8f3 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 # NOTE: Torch needs to be imported before the custom diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 4a06580..f28e3f2 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import unittest diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index ac55d24..aa81334 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index fb07210..ba4ade0 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 # NOTE: Torch needs to be imported before the custom diff --git a/setup.py b/setup.py index 9149352..fa15ee4 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 """MegaBlocks package setup.""" diff --git a/tests/conftest.py b/tests/conftest.py index afe3e43..663bda3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import os diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 986972e..6805f3c 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import gc diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index 577c077..4039db7 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import pytest diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py index 5a1e9d5..a737ef4 100644 --- a/tests/layers/dmoe_test.py +++ b/tests/layers/dmoe_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from functools import partial diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index 5ede1db..d89af89 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from functools import partial diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index a4032c2..dd40ef9 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from functools import partial diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py index 6fd4d1e..35e40a0 100644 --- a/tests/layers/parallelism_test.py +++ b/tests/layers/parallelism_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import functools diff --git a/tests/ops/binned_gather_test.py b/tests/ops/binned_gather_test.py index 3906d9c..c165086 100644 --- a/tests/ops/binned_gather_test.py +++ b/tests/ops/binned_gather_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/tests/ops/binned_scatter_test.py b/tests/ops/binned_scatter_test.py index f3c352b..b725700 100644 --- a/tests/ops/binned_scatter_test.py +++ b/tests/ops/binned_scatter_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/tests/ops/cumsum_test.py b/tests/ops/cumsum_test.py index 1e79db1..5d8b082 100644 --- a/tests/ops/cumsum_test.py +++ b/tests/ops/cumsum_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import pytest diff --git a/tests/ops/histogram_test.py b/tests/ops/histogram_test.py index 7a0fc9a..d6d3f23 100644 --- a/tests/ops/histogram_test.py +++ b/tests/ops/histogram_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import pytest diff --git a/tests/ops/padded_gather_test.py b/tests/ops/padded_gather_test.py index a9c880b..7198099 100644 --- a/tests/ops/padded_gather_test.py +++ b/tests/ops/padded_gather_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/tests/ops/padded_scatter_test.py b/tests/ops/padded_scatter_test.py index d4fbba9..0e80dbb 100644 --- a/tests/ops/padded_scatter_test.py +++ b/tests/ops/padded_scatter_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/tests/ops/replicate_test.py b/tests/ops/replicate_test.py index 4076d9d..aeb1405 100644 --- a/tests/ops/replicate_test.py +++ b/tests/ops/replicate_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/tests/ops/sort_test.py b/tests/ops/sort_test.py index 0f86984..147426e 100644 --- a/tests/ops/sort_test.py +++ b/tests/ops/sort_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from typing import Dict, Optional, Union diff --git a/tests/ops/topology_test.py b/tests/ops/topology_test.py index cbffcbe..dc3c0ae 100644 --- a/tests/ops/topology_test.py +++ b/tests/ops/topology_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 Databricks MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np From 43f389a3862cf7464db5dc1b79a37bffc936761e Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 8 Aug 2024 20:50:07 +0000 Subject: [PATCH 42/43] modify --- megablocks/layers/glu.py | 3 --- megablocks/layers/mlp.py | 24 +++--------------------- megablocks/layers/mpu.py | 8 -------- 3 files changed, 3 insertions(+), 32 deletions(-) diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index fa888a6..4654576 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -44,9 +44,6 @@ def __init__(self, args: Arguments): self._should_set_parallelism_attribute, ) - if self.args.moe_weight_parallelism: - raise NotImplementedError('Weight parallelism not yet supported with GLU.',) - def forward(self, x, topo): if self.args.memory_optimized_mlp: raise NotImplementedError( diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 1cae4fb..00bc18b 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -181,20 +181,7 @@ def create_dmoe_expert_weights( init_method, ) weights = weights.view([-1, columns]) - rows, columns = weights.shape - - if not args.moe_weight_parallelism: - return weights - - # Caclculate the number of rows on this weight parallel partition. - # 'rows' must be divisible by weight parallel world size. - weight_parallel_world_size = mpu.get_weight_parallel_world_size(args) - assert (rows % weight_parallel_world_size) == 0 - num_rows_per_rank = rows // weight_parallel_world_size - rank = mpu.get_weight_parallel_rank(args) - start_row = rank * num_rows_per_rank - end_row = (rank + 1) * num_rows_per_rank - return weights[start_row:end_row] + return weights class MemoryOptimizedMLP(torch.autograd.Function): @@ -371,7 +358,7 @@ def __init__(self, args: Arguments): ), ) - self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism) + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism mpu.set_expert_model_parallel_attributes( self.w1, self._should_set_parallelism_attribute, @@ -414,9 +401,7 @@ def parallel_forward(self, x, topo): def forward(self, x, topo): w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.moe_weight_parallelism: - return self.parallel_forward(x, topo) - elif self.args.memory_optimized_mlp: + if self.args.memory_optimized_mlp: return memory_optimized_mlp( x, w1, @@ -542,9 +527,6 @@ def forward(self, x, tokens_per_expert): w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - if self.args.moe_weight_parallelism: - raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',) - if self.args.memory_optimized_mlp: return memory_optimized_grouped_mlp( x, diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 6aa0015..239f75f 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -42,14 +42,6 @@ def copy_expert_model_parallel_attributes( ) -def get_weight_parallel_world_size(args: Arguments) -> int: - return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1) - - -def get_weight_parallel_rank(args: Arguments) -> int: - return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0) - - def synchronized_print(group, *x): world_size = torch.distributed.get_world_size(group) rank = torch.distributed.get_rank(group) From 98033f30d8626e5c7274f954144e457ae0f4620c Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Fri, 9 Aug 2024 16:44:56 +0000 Subject: [PATCH 43/43] remove it! --- megablocks/layers/arguments.py | 2 - megablocks/layers/mlp.py | 5 +- tests/layers/parallelism_test.py | 153 ------------------------------- 3 files changed, 2 insertions(+), 158 deletions(-) delete mode 100644 tests/layers/parallelism_test.py diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index efe131d..ddbe2b7 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -40,8 +40,6 @@ class Arguments: # Parallelism arguments. moe_expert_model_parallelism: bool = False expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None - moe_weight_parallelism: bool = False - weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None pipeline_model_parallel_size: int = 1 num_layers_per_virtual_pipeline_stage: Optional[int] = None diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 00bc18b..361f9c9 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -310,8 +310,7 @@ class SparseMLP(torch.nn.Module): def __init__(self, args: Arguments): super().__init__() self.args = args - self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // - mpu.get_weight_parallel_world_size(args)) + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) self.w1 = torch.nn.Parameter( torch.empty( @@ -378,7 +377,7 @@ def scale_grad(self, w): return scale_gradient(w, self.gradient_scale) def parallel_forward(self, x, topo): - group = self.args.weight_parallel_group + group = None w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) if self.args.memory_optimized_mlp: if self.args.activation_fn is not DEFAULT_ACTIVATION_FN: diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py deleted file mode 100644 index 35e40a0..0000000 --- a/tests/layers/parallelism_test.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import functools - -import numpy as np -import pytest -import torch - -from megablocks.layers import arguments, dmoe, mpu - -_PARALLELISM_TESTS = ( - (64, 1024, 512, 2048, 64, 1, False), - (64, 1024, 512, 2048, 64, 1, True), - # Test with fewer experts than ranks to verify tensor - # sharding in tandem with expert sharding. - (4, 1, 512, 2048, 4, 1, False), - (4, 1, 512, 2048, 4, 1, True), -) - - -# Todo: Fix this long term -@pytest.fixture -def group(): - return None - - -@pytest.mark.world_size(2) -@pytest.mark.gpu -@pytest.mark.parametrize(( - 'batch_size', - 'sequence_length', - 'hidden_size', - 'ffn_hidden_size', - 'num_experts', - 'top_k', - 'memory_optimized', -), _PARALLELISM_TESTS) -def test_expert_parallel_versus_weight_parallel( - group, - batch_size: int, - sequence_length: int, - hidden_size: int, - ffn_hidden_size: int, - num_experts: int, - top_k: int, - memory_optimized: bool, -): - - init_fn = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1) - ep_args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=False, - device=torch.cuda.current_device(), - init_method=init_fn, - memory_optimized_mlp=memory_optimized, - ) - wp_args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_weight_parallelism=True, - weight_parallel_group=group, - fp16=False, - bf16=False, - device=torch.cuda.current_device(), - init_method=init_fn, - memory_optimized_mlp=memory_optimized, - ) - - # NOTE: Reset the seed so that the models get identical weights. - torch.manual_seed(1234) - ep = dmoe.dMoE(ep_args) - torch.manual_seed(1234) - wp = dmoe.dMoE(wp_args) - - # NOTE: Include the rank in the seed so we get different data per rank. - rank = torch.distributed.get_rank(group) - torch.manual_seed(1234 * rank) - x = torch.randn((batch_size, sequence_length, hidden_size), device=torch.cuda.current_device(), - dtype=torch.float32).requires_grad_(True) - - # Test forward. - out, _ = wp(x) - expected_out, _ = ep(x) - - # Check the forward outputs. - for i in range(torch.distributed.get_world_size(group)): - torch.distributed.barrier(group) - if i == rank: - assert np.testing.assert_allclose( - out.detach().float().cpu(), - expected_out.detach().float().cpu(), - rtol=1e-4, - atol=1e-4, - ) is None - - # Test backward. - out.mean().backward() - expected_out.mean().backward() - - # NOTE: If tensor parallelism is used different weights can be on - # different ranks. Gather the full grads to rank 0 to compare. - def gather(x): - m, n = x.shape - world_size = torch.distributed.get_world_size(group) - out = torch.empty(m * world_size, n, device=x.device, dtype=x.dtype) - torch.distributed.all_gather_into_tensor(out, x, group=group) - return out - - def permute(x): - esd = mpu.expert_sharding_degree(ep_args) - hsd = mpu.hidden_sharding_degree(ep_args) - out = x.view(hsd, esd, -1).transpose(1, 0).contiguous() - return out.view(num_experts * ffn_hidden_size, hidden_size) - - wp_w2_grad = gather(wp.experts.mlp.w2.grad) - ep_w2_grad = permute(gather(ep.experts.mlp.w2.grad)) - if rank == 0: - assert np.testing.assert_allclose( - wp_w2_grad.float().cpu(), - ep_w2_grad.float().cpu(), - rtol=1e-5, - atol=1e-5, - ) is None - - wp_w1_grad = gather(wp.experts.mlp.w1.grad) - ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad)) - if rank == 0: - assert np.testing.assert_allclose( - wp_w1_grad.float().cpu(), - ep_w1_grad.float().cpu(), - rtol=1e-5, - atol=1e-5, - ) is None - - # Verify the router weight gradient, which is not sharded. - for i in range(torch.distributed.get_world_size(group)): - torch.distributed.barrier(group) - if i == rank: - assert np.testing.assert_allclose( - wp.router.layer.weight.grad.float().cpu(), - ep.router.layer.weight.grad.float().cpu(), - rtol=1e-5, - atol=1e-5, - ) is None