merge setup-deps
vchiley committed Jul 27, 2023
2 parents abe9ae4 + 876841e commit 5ac27d0
Showing 5 changed files with 37 additions and 10 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -12,6 +12,8 @@ MegaBlocks dMoEs outperform MoEs trained with [Tutel](https://github.com/microsoft/tutel)

 # :building_construction: Installation
 
+Note: this assumes you have `numpy` and `torch` installed.
+
 **Training models with Megatron-LM:** We recommend using NGC's [`nvcr.io/nvidia/pytorch:23.01-py3`](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch/tags) PyTorch container. The [Dockerfile](Dockerfile) builds on this image with additional dependencies. To build the image, run `docker build . -t megablocks-dev` and then `bash docker.sh` to launch the container. Once inside the container, install MegaBlocks with `pip install .`. See [Usage](#steam_locomotive-usage) for instructions on training MoEs with MegaBlocks + Megatron-LM.
 
 **Using MegaBlocks in other packages:** To install the MegaBlocks package for use in other frameworks, run `pip install megablocks`.
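For reference, the installation commands from this README section collected into one sequence:

```bash
# Build the development image and launch the container.
docker build . -t megablocks-dev
bash docker.sh

# Inside the container, install MegaBlocks from source.
pip install .

# Alternatively, to use MegaBlocks in other frameworks, install from PyPI.
pip install megablocks
```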
12 changes: 12 additions & 0 deletions megablocks/layers/common.py
@@ -8,3 +8,15 @@ def dtype(args : Arguments):
     elif args.bf16:
         dtype = torch.bfloat16
     return dtype
+
+
+def cast_if_autocast_enabled(tensor):
+    if torch.is_autocast_enabled():
+        if tensor.device.type == 'cuda':
+            dtype = torch.get_autocast_gpu_dtype()
+        elif tensor.device.type == 'cpu':
+            dtype = torch.get_autocast_cpu_dtype()
+        else:
+            raise NotImplementedError()
+        return tensor.to(dtype=dtype)
+    return tensor
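A minimal sketch of how this new helper behaves, assuming a CUDA device and bf16 autocast (the tensor `x` here is illustrative):

```python
import torch

from megablocks.layers.common import cast_if_autocast_enabled

x = torch.randn(4, 8, device='cuda', dtype=torch.float32)

# Outside an autocast region, the tensor passes through unchanged.
assert cast_if_autocast_enabled(x).dtype == torch.float32

# Inside an autocast region, the tensor is cast to the autocast dtype,
# matching what autocast-eligible ops would produce.
with torch.autocast(device_type='cuda', dtype=torch.bfloat16):
    assert cast_if_autocast_enabled(x).dtype == torch.bfloat16
```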
2 changes: 2 additions & 0 deletions megablocks/layers/dmoe.py
@@ -148,6 +148,7 @@ def forward_once(self, x, expert_weights, top_experts):

         # Perform the expert computation.
         x = self.mlp(x, topo)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         x = ops.padded_scatter(
@@ -195,6 +196,7 @@ def permute_and_compute(

         # Perform the expert computation.
         x = self.mlp(x, topo)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         return ops.padded_scatter(
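Both hunks in this file make the same change: the sparse expert MLP presumably runs custom kernels that are not autocast-aware, so inside an autocast region its output can come back in a different dtype; casting it to the autocast dtype before `padded_scatter` keeps the un-routing path in a consistent precision.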
7 changes: 4 additions & 3 deletions megablocks/layers/moe.py
@@ -70,10 +70,9 @@ def batched_load_balancing_loss(args : Arguments):
     # the correct types and formats for the dot product.
     if args.moe_lbl_in_fp32:
         expert_scores = torch.cat(expert_scores, dim=1).float().mean(dim=0)
-        tokens_per_expert = torch.cat(tokens_per_expert).float()
     else:
         expert_scores = torch.cat(expert_scores, dim=1).mean(dim=0)
-        tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
+    tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype)
 
     expected_values = num_layers_per_pipeline_stage * args.moe_num_experts
     assert tokens_per_expert.numel() == expected_values
@@ -147,7 +146,7 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores):
         assert num_experts == self.num_experts
         scale = self.num_experts / (tokens * self.top_k)
         return scale * torch.dot(
-            tokens_per_expert.half(),
+            tokens_per_expert.to(expert_scores.dtype),
             expert_scores.mean(dim=0))
 
     def indices_and_bins(self, top_expert):
@@ -191,6 +190,7 @@ def permute_and_compute(
         # Perform the expert computation. Note that we don't
         # use biases for these linear operations.
         x = self.mlp(x)
+        x = common.cast_if_autocast_enabled(x)
 
         # Un-route the data for the MoE output.
         return ops.binned_scatter(
@@ -387,6 +387,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
         return x, tokens_per_expert.flatten()
 
     def forward(self, x):
+        x = common.cast_if_autocast_enabled(x)
         sl, bs, hs = x.size()
 
         # Compute the expert scores and assignments.
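The `load_balancing_loss` change above replaces a hard-coded `half()` cast with the scores' own dtype, so the loss works under bf16 and fp32 as well as fp16. A standalone sketch of the computation (the free-function signature and shapes are our assumptions, not the module's API):

```python
import torch

def load_balancing_loss(tokens_per_expert, expert_scores, top_k):
    """Auxiliary loss encouraging uniform token assignment across experts.

    tokens_per_expert: (num_experts,) count of tokens routed to each expert.
    expert_scores: (tokens, num_experts) router probabilities per token.
    """
    tokens, num_experts = expert_scores.size()
    scale = num_experts / (tokens * top_k)
    # Match the counts to the scores' dtype (fp32/bf16/fp16) instead of
    # hard-coding half precision, mirroring the fix in this commit.
    return scale * torch.dot(
        tokens_per_expert.to(expert_scores.dtype),
        expert_scores.mean(dim=0))
```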
24 changes: 17 additions & 7 deletions setup.py
@@ -1,7 +1,10 @@
 from setuptools import setup, find_packages
+from torch import cuda
 from torch.utils.cpp_extension import BuildExtension, CUDAExtension
 
 
+_dc = cuda.get_device_capability()
+_dc = f"{_dc[0]}{_dc[1]}"
 ext_modules = [
     CUDAExtension(
         "megablocks_ops",
@@ -12,11 +15,22 @@
"nvcc": [
"--ptxas-options=-v",
"--optimize=2",
"--generate-code=arch=compute_80,code=sm_80"
f"--generate-code=arch=compute_{_dc},code=sm_{_dc}"
]
})
]

install_requires=[
'stanford-stk @ git+https://github.com/vchiley/stk.git@setup_deps',
]

extra_deps = {}

extra_deps['dev'] = [
'absl-py',
]

extra_deps['all'] = set(dep for deps in extra_deps.values() for dep in deps)

setup(
name="megablocks",
@@ -35,10 +49,6 @@
     packages=find_packages(),
     ext_modules=ext_modules,
     cmdclass={"build_ext": BuildExtension},
-    install_requires=[
-        "absl-py",
-        "numpy",
-        "torch",
-        "stanford-stk @ git+https://github.com/stanford-futuredata/stk.git@main"
-    ],
+    install_requires=install_requires,
+    extras_require=extra_deps,
 )
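The rewritten `setup.py` compiles the CUDA kernels for whatever GPU is visible at build time instead of hard-coding `sm_80`, which means a CUDA device must be present when the package is built. A small sketch of what the `_dc` logic produces (device names are examples):

```python
from torch import cuda

# get_device_capability() returns (major, minor), e.g. (8, 0) on an A100.
major, minor = cuda.get_device_capability()
arch = f"{major}{minor}"

# nvcc then receives, e.g., --generate-code=arch=compute_80,code=sm_80.
print(f"--generate-code=arch=compute_{arch},code=sm_{arch}")
```

With the new `extras_require`, the `absl-py` dev dependency is now opt-in via `pip install megablocks[dev]` (or `megablocks[all]`), and `numpy` and `torch` are assumed to be preinstalled, as the README note above states.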
