From 87cd325330877a2ce24dbb4d6e95241a7c31cfa3 Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Mon, 10 Jul 2023 11:21:15 -0700
Subject: [PATCH 1/6] Update setup.py

---
 setup.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/setup.py b/setup.py
index c2db3d01..8b4c1c5a 100644
--- a/setup.py
+++ b/setup.py
@@ -17,6 +17,17 @@
         })
 ]
 
+install_requires=[
+    'stanford-stk @ git+https://github.com/vchiley/stk.git@setup_deps',
+]
+
+extra_deps = {}
+
+extra_deps['dev'] = [
+    'absl-py',
+]
+
+extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)
 
 setup(
     name="megablocks",
@@ -35,10 +46,6 @@
     packages=find_packages(),
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
-    install_requires=[
-        "absl-py",
-        "numpy",
-        "torch",
-        "stanford-stk @ git+https://github.com/stanford-futuredata/stk.git@main"
-    ],
+    install_requires=install_requires,
+    extras_require=extra_deps,
 )

From e9738deae82196132864a63eed833b6b637aff03 Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Mon, 10 Jul 2023 11:26:04 -0700
Subject: [PATCH 2/6] Update README.md

---
 README.md | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/README.md b/README.md
index 109099fa..e8d53be2 100644
--- a/README.md
+++ b/README.md
@@ -12,6 +12,8 @@ MegaBlocks dMoEs outperform MoEs trained with [Tutel](https://github.com/microso
 
 # :building_construction: Installation
 
+Note: this assumes you have `numpy` and `torch` installed.
+
 **Training models with Megatron-LM:** We recommend using NGC's [`nvcr.io/nvidia/pytorch:23.01-py3`](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) PyTorch container. The [Dockerfile](Dockerfile) builds on this image with additional dependencies. To build the image, run `docker build . -t megablocks-dev` and then `bash docker.sh` to launch the container. Once inside the container, install MegaBlocks with `pip install .`. See [Usage](#steam_locomotive-usage) for instructions on training MoEs with MegaBlocks + Megatron-LM.
 
 **Using MegaBlocks in other packages:** To install the MegaBlocks package for use in other frameworks, run `pip install megablocks`.
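
Editor's note (illustrative, not part of the patch series): with the `extras_require` layout introduced in PATCH 1/6, the optional dependencies move out of the base install. A minimal sketch of what the new layout amounts to, using only names that appear in the diff:

    # Base install: only stanford-stk is required; numpy and torch are assumed
    # to already be present (see the README note added in PATCH 2/6).
    install_requires = [
        'stanford-stk @ git+https://github.com/vchiley/stk.git@setup_deps',
    ]

    extra_deps = {}
    extra_deps['dev'] = ['absl-py']
    # Aggregate every extra into a single deduplicated 'all' group.
    extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)
    assert extra_deps['all'] == {'absl-py'}

In practice this means `pip install megablocks` no longer pulls in `absl-py`, while `pip install megablocks[dev]` (or `[all]`) does.
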
From 331b431a7adf707fd3dce1ea09c8f49f1a369797 Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Tue, 18 Jul 2023 18:41:20 -0700
Subject: [PATCH 3/6] Update moe.py

---
 megablocks/layers/moe.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 94af93a8..a23ada23 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -145,7 +145,7 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores):
         assert num_experts == self.num_experts
         scale = self.num_experts / (tokens * self.args.moe_top_k)
         return scale * torch.dot(
-            tokens_per_expert.half(),
+            tokens_per_expert.to(expert_scores.dtype),
             expert_scores.mean(dim=0))
 
     def indices_and_bins(self, top_expert):

From 7c9a26de02e788173ad3fb3fa6c3253af4bc005e Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Wed, 26 Jul 2023 12:04:24 -0700
Subject: [PATCH 4/6] Update setup.py

detect device and set `--generate-code` automatically
---
 setup.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 8b4c1c5a..c283674d 100644
--- a/setup.py
+++ b/setup.py
@@ -1,7 +1,10 @@
 from setuptools import setup, find_packages
 
+from torch import cuda
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
+_dc = cuda.get_device_capability()
+_dc = f"{_dc[0]}{_dc[1]}"
 
 ext_modules = [
     CUDAExtension(
@@ -12,7 +15,7 @@
             "nvcc": [
                 "--ptxas-options=-v",
                 "--optimize=2",
-                "--generate-code=arch=compute_80,code=sm_80"
+                f"--generate-code=arch=compute_{_dc},code=sm_{_dc}"
             ]
         })
 ]

From bc47c6dafb4bdcba72d3df58efbbc2d4bf79df53 Mon Sep 17 00:00:00 2001
From: Vitaliy Chiley <6439018+vchiley@users.noreply.github.com>
Date: Wed, 26 Jul 2023 12:52:28 -0700
Subject: [PATCH 5/6] Update moe.py

---
 megablocks/layers/moe.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index a23ada23..3939a17a 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -70,10 +70,9 @@ def batched_load_balancing_loss(args : Arguments):
     # the correct types and formats for the dot product.
     if args.moe_lbl_in_fp32:
         expert_scores = torch.cat(expert_scores, dim=1).float().mean(dim=0)
-        tokens_per_expert = torch.cat(tokens_per_expert).float()
     else:
         expert_scores = torch.cat(expert_scores, dim=1).mean(dim=0)
-        tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
 
     expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
     assert tokens_per_expert.numel() == expected_values

From 876841e6154735356c3a943d4848334ba28358e4 Mon Sep 17 00:00:00 2001
From: root
Date: Thu, 27 Jul 2023 16:46:56 +0000
Subject: [PATCH 6/6] set all2all dtype using amp precision

---
 megablocks/layers/common.py | 12 ++++++++++++
 megablocks/layers/dmoe.py   |  2 ++
 megablocks/layers/moe.py    |  4 ++++
 3 files changed, 18 insertions(+)

diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py
index eb1cf397..c15bf02a 100644
--- a/megablocks/layers/common.py
+++ b/megablocks/layers/common.py
@@ -8,3 +8,15 @@ def dtype(args : Arguments):
     elif args.bf16:
         dtype = torch.bfloat16
     return dtype
+
+
+def cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor
diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py
index df592f2e..d20d91f6 100644
--- a/megablocks/layers/dmoe.py
+++ b/megablocks/layers/dmoe.py
@@ -136,6 +136,7 @@ def forward_once(self, x, top_expert):
 
         # Perform the expert computation.
         x = self.mlp(x, topo)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         x = ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
@@ -162,6 +163,7 @@ def permute_and_compute(
 
         # Perform the expert computation.
         x = self.mlp(x, topo)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins)
diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py
index 3939a17a..2ecddace 100644
--- a/megablocks/layers/moe.py
+++ b/megablocks/layers/moe.py
@@ -185,6 +185,7 @@ def permute_and_compute(
         # Perform the expert computation. Note that we don't
         # use biases for these linear operations.
         x = self.mlp(x)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         return ops.binned_scatter(x, indices, bins)
@@ -349,6 +350,9 @@ def forward(self, x):
         # Compute the top-1 expert routing.
         scores, expert_weights, top_experts = self.router(x)
 
+        # Guarantee routing is done with AMP precision.
+        x = common.cast_if_autocast_enabled(x)
+
         # Simplified code-path for the common case of top_k == 1.
         if self.args.moe_top_k == 1:
             x, tokens_per_expert = self.forward_fn(x, top_experts)
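
Editor's notes (illustrative sketches, not part of the patch series).

PATCH 3/6 and PATCH 5/6 make the load-balancing loss follow the dtype of `expert_scores` rather than hard-coding fp16. A minimal sketch of the resulting dot product, with made-up sizes:

    import torch

    num_experts, tokens = 8, 1024

    # expert_scores carries whatever dtype the router produced (fp32 here, but
    # fp16 or bf16 under mixed precision); tokens_per_expert is an integer count.
    expert_scores = torch.rand(tokens, num_experts).mean(dim=0)
    tokens_per_expert = torch.randint(0, tokens, (num_experts,))

    # Casting to expert_scores.dtype keeps torch.dot's operands matched at any
    # precision, whereas the old tokens_per_expert.half() always forced fp16.
    loss = torch.dot(tokens_per_expert.to(expert_scores.dtype), expert_scores)
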
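For PATCH 4/6, a sketch of what the device-capability detection expands to; the (8, 0) capability is just the example for an A100-class GPU:

    from torch import cuda

    # get_device_capability() returns e.g. (8, 0) on an A100, so _dc becomes
    # "80" and nvcc targets the GPU that is actually present at build time.
    _dc = cuda.get_device_capability()
    _dc = f"{_dc[0]}{_dc[1]}"
    print(f"--generate-code=arch=compute_{_dc},code=sm_{_dc}")
    # -> --generate-code=arch=compute_80,code=sm_80 on that device

One consequence is that setup.py now needs `torch` importable and a CUDA device visible at build time; the README note added in PATCH 2/6 covers the `torch` (and `numpy`) part.
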
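Finally, for PATCH 6/6, a minimal sketch of how the new `cast_if_autocast_enabled` helper behaves; the tensor shape and dtypes are illustrative:

    import torch

    from megablocks.layers import common

    x = torch.randn(8, 16, device='cuda', dtype=torch.float32)

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
        # Inside an autocast region the helper casts to the autocast dtype, so
        # routing and the un-route scatters see the same precision as the MLP
        # outputs produced under autocast.
        assert common.cast_if_autocast_enabled(x).dtype == torch.bfloat16

    # Outside autocast the tensor passes through unchanged.
    assert common.cast_if_autocast_enabled(x).dtype == torch.float32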