From 52186a6e6037bf1671e3d466737b0d2c3f9eff46 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Thu, 1 Aug 2024 13:25:27 -0400 Subject: [PATCH 01/11] Clean up setup.py (#128) * ignore more things * update setup.py * init pyproject.toml * update README * remove section seperations * remove unused variable * add docstring * read version from megablocks/__init__.py * fix reading repo version * add type hints * add classifiers, better long-description * update url * exclude more packages * add python_requires * update * use Composer's .gitignore * use Composer's pyproject.toml + my changes * remove my stk fork * remove composer specific * better error msg * fix typo * add correct versions of stanford-stk, grouped_gemm; add packaging * test in my GA * use my fork for testing * __init__.py only has __version * change GA to defaults * restore __init__.py * update readme * update readme --- README.md | 4 +- megablocks/_version.py | 6 ++ setup.py | 154 ++++++++++++++++++++++++++++++----------- 3 files changed, 123 insertions(+), 41 deletions(-) create mode 100644 megablocks/_version.py diff --git a/README.md b/README.md index ee3628f..a3013d0 100644 --- a/README.md +++ b/README.md @@ -22,9 +22,9 @@ NOTE: This assumes you have `numpy` and `torch` installed. Installing `megablocks[gg]` enables dMoE computation with grouped GEMM. This feature is enabled by setting the `mlp_impl` argument to `grouped`. This is currently our recommended path for Hopper-generation GPUs. -Installing `megablocks[dev]` allows you to contribute to MegaBlocks and test locally. Installing `megablocks[testing]` allows you to test via Github Actions. +Installing `megablocks[dev]` allows you to contribute to MegaBlocks and test locally. Installing `megablocks[testing]` allows you to test via Github Actions. If you've installed megablocks[dev], you can run pre-commit install to configure the pre-commit hook to automatically format the code. -MegaBlocks can be installed with all dependencies via the `megablocks[all]` package. +MegaBlocks can be installed with all dependencies (except for `testing`) via the `megablocks[all]` package. # :steam_locomotive: Usage diff --git a/megablocks/_version.py b/megablocks/_version.py new file mode 100644 index 0000000..2bb5d50 --- /dev/null +++ b/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2022 MegaBlocks Composer authors +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.5.1' diff --git a/setup.py b/setup.py index ac1b43f..a247b0d 100644 --- a/setup.py +++ b/setup.py @@ -1,40 +1,77 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + +"""MegaBlocks package setup.""" + import os +import warnings +from typing import Any, Dict, Mapping -import torch from setuptools import find_packages, setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension -if os.environ.get("TORCH_CUDA_ARCH_LIST"): - # Let PyTorch builder to choose device to target for. - device_capability = "" -else: - device_capability = torch.cuda.get_device_capability() - device_capability = f"{device_capability[0]}{device_capability[1]}" -nvcc_flags = [ - "--ptxas-options=-v", - "--optimize=2", -] -if device_capability: - nvcc_flags.append( - f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}" - ) - -ext_modules = [ - CUDAExtension( - "megablocks_ops", - ["csrc/ops.cu"], - include_dirs=["csrc"], - extra_compile_args={"cxx": ["-fopenmp"], "nvcc": nvcc_flags}, - ) +# We require torch in setup.py to build cpp extensions "ahead of time" +# More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html +try: + import torch + from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, + CUDAExtension,) +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "No module named 'torch'. `torch` is required to install `MegaBlocks`." + ) from e + + +_PACKAGE_NAME = 'megablocks' +_PACKAGE_DIR = 'megablocks' +_REPO_REAL_PATH = os.path.dirname(os.path.realpath(__file__)) +_PACKAGE_REAL_PATH = os.path.join(_REPO_REAL_PATH, _PACKAGE_DIR) + +# Read the package version +# We can't use `.__version__` from the library since it's not installed yet +version_path = os.path.join(_PACKAGE_REAL_PATH, '_version.py') +with open(version_path, encoding='utf-8') as f: + version_globals: Dict[str, Any] = {} + version_locals: Mapping[str, object] = {} + content = f.read() + exec(content, version_globals, version_locals) + repo_version = version_locals['__version__'] + + +with open('README.md', 'r', encoding='utf-8') as fh: + long_description = fh.read() + +# Hide the content between and +# tags in the README +while True: + start_tag = '' + end_tag = '' + start = long_description.find(start_tag) + end = long_description.find(end_tag) + if start == -1: + assert end == -1, 'there should be a balanced number of start and ends' + break + else: + assert end != -1, 'there should be a balanced number of start and ends' + long_description = long_description[:start] + \ + long_description[end + len(end_tag):] + + +classifiers = [ + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'License :: OSI Approved :: BSD License', + 'Operating System :: Unix', ] install_requires = [ 'numpy>=1.21.5,<2.1.0', + 'packaging>=21.3.0,<24.2', 'torch>=2.3.0,<2.4', 'triton>=2.1.0', 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301', - 'packaging>=21.3.0,<24.2', ] extra_deps = {} @@ -61,23 +98,62 @@ if key not in {'testing'} }) + +cmdclass = {} +ext_modules = [] + +# Only install CUDA extensions if available +if 'cu' in torch.__version__ and CUDA_HOME is not None: + + cmdclass = {'build_ext': BuildExtension} + nvcc_flags = ['--ptxas-options=-v', '--optimize=2'] + + if os.environ.get('TORCH_CUDA_ARCH_LIST'): + # Let PyTorch builder to choose device to target for. + device_capability = '' + else: + device_capability_tuple = torch.cuda.get_device_capability() + device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}' + + if device_capability: + nvcc_flags.append( + f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}' + ) + + ext_modules = [ + CUDAExtension( + 'megablocks_ops', + ['csrc/ops.cu'], + include_dirs=['csrc'], + extra_compile_args={ + 'cxx': ['-fopenmp'], + 'nvcc': nvcc_flags + }, + ) + ] +elif CUDA_HOME is None: + warnings.warn( + 'Attempted to install CUDA extensions, but CUDA_HOME was None. ' + + 'Please install CUDA and ensure that the CUDA_HOME environment ' + + 'variable points to the installation location.') +else: + warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.') + + setup( - name="megablocks", - version="0.5.1", - author="Trevor Gale", - author_email="tgale@stanford.edu", - description="MegaBlocks", - long_description=open('README.md').read(), + name=_PACKAGE_NAME, + version=repo_version, + author='Trevor Gale', + author_email='tgale@stanford.edu', + description='MegaBlocks', + long_description=long_description, long_description_content_type='text/markdown', - url="https://github.com/stanford-futuredata/megablocks", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - ], - packages=find_packages(), + url='https://github.com/databricks/megablocks', + classifiers=classifiers, + packages=find_packages(exclude=['tests*', 'third_party*', 'yamls*', 'exp*', '.github*']), ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension}, + cmdclass=cmdclass, install_requires=install_requires, extras_require=extra_deps, + python_requires='>=3.9', ) From 2fa6dc873e84adeb7f1e9a57393bd4bd173fca61 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Mon, 5 Aug 2024 19:13:36 -0400 Subject: [PATCH 02/11] only run GA if repo owner is Databricks (#135) --- .github/workflows/pr-gpu.yaml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 0447b87..64d0fb6 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -15,6 +15,7 @@ concurrency: jobs: pytest-gpu: name: ${{ matrix.name }} + if: github.repository_owner == 'databricks' runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later strategy: fail-fast: false From 6adb1fba2c31bb17ed17b9363fbadfeeec10a862 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Fri, 9 Aug 2024 11:59:50 -0400 Subject: [PATCH 03/11] GA to Lint + Format MegaBlocks (#131) Co-authored-by: Mihir Patel --- .github/PULL_REQUEST_TEMPLATE.md | 29 ++ .github/workflows/code-quality.yaml | 41 ++ .github/workflows/pr-gpu.yaml | 5 +- .pre-commit-config.yaml | 108 +++++ .pre-commit/FILE_HEADER | 2 + .yamllint.yaml | 42 ++ CONTRIBUTING.md | 104 +++++ Dockerfile | 2 +- LICENSE | 2 +- MANIFEST.in | 2 +- STYLE_GUIDE.md | 480 ++++++++++++++++++++ docker.sh | 0 exp/dmoe/dmoe_125m_8gpu.sh | 0 exp/dmoe/dmoe_356m_8gpu.sh | 0 exp/dmoe/dmoe_46m_8gpu.sh | 0 exp/dmoe/dmoe_760m_8gpu.sh | 0 exp/gpt2/gpt2_125m_1gpu.sh | 0 exp/gpt2/gpt2_125m_8gpu.sh | 0 exp/gpt2/gpt2_1315m_1gpu.sh | 0 exp/gpt2/gpt2_1315m_8gpu.sh | 0 exp/gpt2/gpt2_356m_1gpu.sh | 0 exp/gpt2/gpt2_356m_8gpu.sh | 0 exp/gpt2/gpt2_46m_1gpu.sh | 0 exp/gpt2/gpt2_46m_8gpu.sh | 0 exp/gpt2/gpt2_760m_1gpu.sh | 0 exp/gpt2/gpt2_760m_8gpu.sh | 0 exp/moe/moe_125m_8gpu.sh | 0 exp/moe/moe_356m_8gpu.sh | 0 exp/moe/moe_46m_8gpu.sh | 0 megablocks/__init__.py | 22 +- megablocks/_version.py | 2 +- megablocks/backend/__init__.py | 3 +- megablocks/backend/kernels.py | 175 ++++---- megablocks/benchmark_util.py | 19 +- megablocks/grouped_gemm_util.py | 12 +- megablocks/layers/__init__.py | 9 + megablocks/layers/activation_fn.py | 15 +- megablocks/layers/all_to_all.py | 31 +- megablocks/layers/arguments.py | 70 +-- megablocks/layers/common.py | 9 +- megablocks/layers/dmlp_registry.py | 25 +- megablocks/layers/dmoe.py | 193 ++++---- megablocks/layers/gelu.py | 18 +- megablocks/layers/glu.py | 118 +++-- megablocks/layers/memory_test.py | 66 ++- megablocks/layers/memory_test.sh | 0 megablocks/layers/mlp.py | 310 ++++++++----- megablocks/layers/moe.py | 223 +++++----- megablocks/layers/mpu.py | 81 ++-- megablocks/layers/router.py | 28 +- megablocks/layers/sharedexpert_registry.py | 12 +- megablocks/layers/testing.py | 60 ++- megablocks/layers/weight_parallel.py | 133 ++++-- megablocks/ops/__init__.py | 24 +- megablocks/ops/all_to_all_benchmark.py | 35 +- megablocks/ops/all_to_all_benchmark.sh | 0 megablocks/ops/binned_gather.py | 9 +- megablocks/ops/binned_scatter.py | 20 +- megablocks/ops/cumsum.py | 20 +- megablocks/ops/gather.py | 14 +- megablocks/ops/histogram.py | 17 +- megablocks/ops/histogram_benchmark.py | 38 +- megablocks/ops/matmul_benchmark.py | 297 ++++++++----- megablocks/ops/padded_gather.py | 26 +- megablocks/ops/padded_scatter.py | 56 ++- megablocks/ops/padded_scatter_benchmark.py | 37 +- megablocks/ops/permute_benchmark.py | 74 ++-- megablocks/ops/repeat.py | 8 +- megablocks/ops/replicate.py | 26 +- megablocks/ops/round_up.py | 5 +- megablocks/ops/scatter.py | 29 +- megablocks/ops/sort.py | 18 +- megablocks/ops/sort_benchmark.py | 41 +- megablocks/ops/sum.py | 3 +- megablocks/ops/topology.py | 49 ++- pyproject.toml | 483 +++++++++++++++++++++ setup.py | 35 +- tests/conftest.py | 21 +- tests/fixtures/autouse.py | 18 +- tests/fixtures/fixtures.py | 4 + tests/layers/dmoe_test.py | 163 +++---- tests/layers/glu_test.py | 26 +- tests/layers/moe_test.py | 49 ++- tests/layers/parallelism_test.py | 63 ++- tests/ops/binned_gather_test.py | 11 +- tests/ops/binned_scatter_test.py | 18 +- tests/ops/cumsum_test.py | 2 +- tests/ops/histogram_test.py | 5 +- tests/ops/padded_gather_test.py | 13 +- tests/ops/padded_scatter_test.py | 37 +- tests/ops/replicate_test.py | 2 +- tests/ops/sort_test.py | 7 +- tests/ops/topology_test.py | 26 +- yamls/matmul_benchmark.yaml | 10 +- yamls/triton_benchmark.yaml | 10 +- 95 files changed, 3116 insertions(+), 1184 deletions(-) create mode 100644 .github/PULL_REQUEST_TEMPLATE.md create mode 100644 .github/workflows/code-quality.yaml create mode 100644 .pre-commit-config.yaml create mode 100644 .pre-commit/FILE_HEADER create mode 100644 .yamllint.yaml create mode 100644 CONTRIBUTING.md create mode 100644 STYLE_GUIDE.md mode change 100644 => 100755 docker.sh mode change 100644 => 100755 exp/dmoe/dmoe_125m_8gpu.sh mode change 100644 => 100755 exp/dmoe/dmoe_356m_8gpu.sh mode change 100644 => 100755 exp/dmoe/dmoe_46m_8gpu.sh mode change 100644 => 100755 exp/dmoe/dmoe_760m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_125m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_125m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_1315m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_1315m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_356m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_356m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_46m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_46m_8gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_760m_1gpu.sh mode change 100644 => 100755 exp/gpt2/gpt2_760m_8gpu.sh mode change 100644 => 100755 exp/moe/moe_125m_8gpu.sh mode change 100644 => 100755 exp/moe/moe_356m_8gpu.sh mode change 100644 => 100755 exp/moe/moe_46m_8gpu.sh mode change 100644 => 100755 megablocks/layers/memory_test.sh mode change 100644 => 100755 megablocks/ops/all_to_all_benchmark.sh diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..34272ee --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,29 @@ +# What does this PR do? + + + +# What issue(s) does this change relate to? + + + +# Before submitting +- [ ] Have you read the [contributor guidelines](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md)? +- [ ] Is this change a documentation change or typo fix? If so, skip the rest of this checklist. +- [ ] Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so. +- [ ] Did you update any related docs and document your change? +- [ ] Did you update any related tests and add any new tests related to your change? (see [testing](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md#running-tests)) +- [ ] Did you run the tests locally to make sure they pass? +- [ ] Did you run `pre-commit` on your change? (see the `pre-commit` section of [prerequisites](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md#prerequisites)) + + diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml new file mode 100644 index 0000000..2b1d931 --- /dev/null +++ b/.github/workflows/code-quality.yaml @@ -0,0 +1,41 @@ +name: Code Quality Checks +on: + push: + branches: + - main + - release/** + pull_request: + branches: + - main + - release/** + workflow_call: + workflow_dispatch: +# Cancel old runs when a new commit is pushed to the same branch if not on main or dev +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} +defaults: + run: + working-directory: . +jobs: + code-quality: + runs-on: ubuntu-latest # TODO: switch to linux-ubuntu-latest later + timeout-minutes: 30 + strategy: + matrix: + python_version: + - "3.11" + pip_deps: + - "[dev]" + steps: + - uses: actions/checkout@v3 + - name: Get composite run steps repository + uses: actions/checkout@v3 + with: + repository: mosaicml/ci-testing + ref: v0.1.1 + path: ./ci-testing + - uses: ./ci-testing/.github/actions/code-quality + with: + python_version: ${{ matrix.python_version }} + pip_deps: ${{ matrix.pip_deps }} diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 64d0fb6..e03d37f 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -16,7 +16,7 @@ jobs: pytest-gpu: name: ${{ matrix.name }} if: github.repository_owner == 'databricks' - runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later + runs-on: ubuntu-latest # TODO: switch to linux-ubuntu-latest later strategy: fail-fast: false matrix: @@ -39,7 +39,8 @@ jobs: gpu_num: ${{ matrix.gpu_num }} git_repo: databricks/megablocks pip_deps: "[all,testing]" - pytest_command: "coverage run -m pytest tests" # todo: remove tests from pytest tests when we delete all tests outside of MegaBlocks repo + pytest_command: "coverage run -m pytest tests" + # TODO: remove tests from pytest tests when we delete all tests in the MegaBlocks dir pytest_markers: "gpu" composer_package_name: mosaicml # Required as Composer is built from source mcloud_timeout: 3600 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..1d315f5 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,108 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + +default_language_version: + python: python3 +repos: +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.2 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] +- repo: https://github.com/google/yapf + rev: v0.32.0 + hooks: + - id: yapf + name: yapf + description: A formatter for Python files. + entry: yapf + args: [-i, -vv, -p] # inplace + language: python + types: [python] + additional_dependencies: + - toml +- repo: https://github.com/hadialqattan/pycln + rev: v2.1.2 + hooks: + - id: pycln + args: [. --all] +- repo: https://github.com/pycqa/isort + hooks: + - id: isort + rev: 5.12.0 +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-added-large-files + - id: check-ast + - id: check-builtin-literals + - id: check-case-conflict + - id: check-docstring-first + - id: check-executables-have-shebangs + - id: check-json + - id: check-shebang-scripts-are-executable + - id: pretty-format-json + args: + - --autofix + - --no-sort-keys + - --indent=4 + - --no-ensure-ascii + - id: check-merge-conflict + - id: check-symlinks + - id: check-toml + - id: check-vcs-permalinks + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: destroyed-symlinks + - id: double-quote-string-fixer + - id: end-of-file-fixer + - id: fix-byte-order-marker + - id: mixed-line-ending + - id: trailing-whitespace +- repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.5.4 + hooks: + - id: insert-license + args: + - --license-filepath + - .pre-commit/FILE_HEADER + - --comment-style + - "#" + - --allow-past-years + types: [python] +- repo: https://github.com/PyCQA/docformatter + rev: v1.5.0 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] +- repo: https://github.com/PyCQA/pydocstyle + rev: 6.1.1 + hooks: + - id: pydocstyle + name: pydocstyle + entry: pydocstyle + language: python + types: [python] + exclude: (.ci|.github) + additional_dependencies: + - toml +- repo: https://github.com/adrienverge/yamllint.git + rev: v1.28.0 + hooks: + - id: yamllint + name: yamllint + description: This hook runs yamllint. + entry: yamllint + language: python + types: [file, yaml] +- repo: https://github.com/trufflesecurity/trufflehog.git + rev: v3.40.0 + hooks: + - id: trufflehog + name: secret scan + exclude: tests/horrible_strings.py + entry: trufflehog filesystem ./ + args: + - --only-verified + - --fail diff --git a/.pre-commit/FILE_HEADER b/.pre-commit/FILE_HEADER new file mode 100644 index 0000000..5081c93 --- /dev/null +++ b/.pre-commit/FILE_HEADER @@ -0,0 +1,2 @@ +Copyright 2024 Databricks +SPDX-License-Identifier: Apache-2.0 diff --git a/.yamllint.yaml b/.yamllint.yaml new file mode 100644 index 0000000..84a08ef --- /dev/null +++ b/.yamllint.yaml @@ -0,0 +1,42 @@ +yaml-files: +- "*.yaml" +- "*.yml" +- .yamllint + +ignore: | + wandb + +rules: + braces: + forbid: non-empty + brackets: + forbid: false + colons: enable + commas: enable + comments: enable + comments-indentation: enable + document-end: + present: false + document-start: + present: false + empty-lines: enable + empty-values: disable + hyphens: enable + indentation: + spaces: 2 + indent-sequences: false + check-multi-line-strings: false + key-duplicates: enable + key-ordering: disable + line-length: + max: 120 + allow-non-breakable-words: true + allow-non-breakable-inline-mappings: true + new-line-at-end-of-file: enable + new-lines: enable + octal-values: enable + quoted-strings: + quote-type: double + required: false + trailing-spaces: enable + truthy: disable diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..abbe03d --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,104 @@ +# Contributing to MegaBlocks + +Thanks for considering contributing to MegaBlocks! + +Issues tagged with [good first issue](https://github.com/mosaicml/megablocks/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) are great options to start contributing. + +If you have questions, join us on [Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) -- we'll be happy to help you! + +We welcome contributions for bug fixes, new efficient methods you'd like to contribute to the community, or new models and datasets! + +## Prerequisites + +To set up the development environment in your local box, run the commands below. + +1\. Install the dependencies needed for testing and linting the code: + + +```bash +pip install -e '.[all]' +``` + +2\. Configure [pre-commit](https://pre-commit.com/), which automatically formats code before +each commit: + + +```bash +pre-commit install +``` + +## Submitting a Contribution + +To submit a contribution: + +1\. Fork a copy of the [MegaBlocks](https://github.com/databricks/megablocks) library to your own account. + +2\. Clone your fork locally and add the megablocks repo as a remote repository: + + +```bash +git clone git@github.com:/megablocks.git +cd megablocks +git remote add upstream https://github.com/databricks/megablocks.git +``` + +3\. Create a branch and make your proposed changes. + + +```bash +git checkout -b cool-new-feature +``` + +4\. When you are ready, submit a pull request into the megablocks repository! + +## Pull request (PR) guidelines + +We have some rough guidelines that will make your PR easier to review and more likely to get smoothly merged. Please don't let uncertainty or difficulty with any of these things stop you from opening a PR! We are happy to help you through them :) +* Self-contained title and description. Please include a concise title and clear PR description. The title should allow someone to understand what the PR changes or does at a glance. The description should allow someone to understand the contents of the PR _without_ looking at the code. +* If the PR affects output that is displayed to a user of MegaBlocks (e.g. console logging or experiment tracker reporting), please include screenshots showing what the new output looks like. UX is important! +* Include tests. If you are fixing a bug, please add a test that would've caught the bug. If you are adding a new feature, please add unit tests that test the various components of the feature, and also a test that tests the full functionality of the feature. +* Please consider whether your changes affect the example notebooks or large parts of the code base, and run the daily tests locally if so (`pytest -m 'daily and not remote and not gpu and not vision and not doctest'`) +* `pre-commit` should help you handle formatting and type checking, but please do make sure you have it installed as described [above](#prerequisites). + +## Configuring README Code Snippets + +MegaBlocks uses [pytest-codeblocks](https://github.com/nschloe/pytest-codeblocks) to test all example code snippets. The pytest-codeblocks repository explains how to annotate code snippets, which supports most `pytest` configurations. For example, if a test requires model training, the GPU mark (``) should be applied. + +## Running Tests + +To test your changes locally, run: + +* `make test` # run CPU tests +* `make test-gpu` # run GPU tests +* `cd docs && make doctest` # run doctests + +Some of our checks test distributed training as well. To test these, run: + +* `make test-dist WORLD_SIZE=2` # run 2-cpu distributed tests +* `make test-dist-gpu WORLD_SIZE=2` # run 2-gpu distributed tests + +These tests run with the `composer` launcher. We also support `WORLD_SIZE=1`, which would run the tests with the `composer` launcher on a single device. + +See the [Makefile](/Makefile) for more information. + +If you want to run pre-commit hooks manually, which check for code formatting and type annotations, run `pre-commit run --all-files` + +### Docker + +To run the tests in the provided docker containers: + +* `docker pull mosaicml/composer` (or an alternative image like `mosaicml/composer:latest_cpu`) +* `docker run --rm -v ./:/composer --user $(id -u):$(id -g) -it mosaicml/composer` +* from inside the container + * `cd /megablocks` + * `pip install -e .` + * `pytest ` or `make ` to run the desired tests + + +## Code Style & Typing + +See the [MegaBlocks Style Guide](/STYLE_GUIDE.md) for guidelines on how to structure and format your code. + +MegaBlocks aims to annotate all functions with type annotations (introduced in +[PEP 526](https://www.python.org/dev/peps/pep-0526/)). Don't worry if you are not a Python typing expert; +put in the pull request, and we'll help you with getting the code into shape. diff --git a/Dockerfile b/Dockerfile index c71ed0a..e5d9ef8 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,4 +6,4 @@ RUN pip install flash-attn ENV PYTHONPATH="/mount/megablocks/third_party/Megatron-LM:${PYTHONPATH}" -WORKDIR /mount/megablocks \ No newline at end of file +WORKDIR /mount/megablocks diff --git a/LICENSE b/LICENSE index 63fd052..be2d25e 100644 --- a/LICENSE +++ b/LICENSE @@ -387,4 +387,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in index 99749aa..b701a75 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ recursive-include csrc *.h -recursive-include csrc *.cu \ No newline at end of file +recursive-include csrc *.cu diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md new file mode 100644 index 0000000..3d5876d --- /dev/null +++ b/STYLE_GUIDE.md @@ -0,0 +1,480 @@ +# 1. Style and Conventions + +## 1.1 Style Guide + +MegaBlocks generally follows Google's +[Python Style Guide](https://google.github.io/styleguide/pyguide.html) for how to format and structure code. + +## 1.2. Pre-Commit Hooks + +MegaBlocks uses [Pre Commit](https://pre-commit.com/) to enforce style checks. To configure, run +``` +pip install '.[dev]' # if not already installed +pre-commit install +``` + +The pre-commit hooks will now be run before each commit. You can also run the hooks manually via: + +``` +pre-commit run # run all hooks on changed files +pre-commit run --all-files # or, run all hooks on all files +``` + + + +## 1.3. Code Formatting + +MegaBlocks uses the [yapf](https://github.com/google/yapf) formatter for general formatting +[isort](https://github.com/PyCQA/isort) to sort imports. These checks run through pre-commit +(see section 2.2). These checks can also be run manually via: + +``` +pre-commit run yapf --all-files # for yapf +pre-commit run isort --all-files # for isort +``` + +The configuration is stored in [pyproject.toml](pyproject.toml). + + +## 1.4. Code Structure + +As a general rule of thumb, + +- Don't: Default to using inheritance for code reuse + + Do: prefer [composition over inheritance](https://en.wikipedia.org/wiki/Composition_over_inheritance) +- Don't: strive to implement all logic using classes + + Do: strive to implement logic as pure functions when possible, and classes when there is good reason +- Don't: Have a function accept falsy values that would then result in a no-op. + + Example of the anti-pattern: + + ```python + from typing import Optional + + def configure_deepspeed(deepspeed_config: Optional[dict]): + if deepspeed_config is None: + # Don't do this check in the callee, which results in a no-op + return + ... + ``` + + Do: Require the caller, instead of the callee, check for and handle falsy values. It's ok to accept falsy values + for individual arguments of a caller function, so long as the entire function would not be a no-op. + + Example: + ```python + from typing import Optional + + def configure_deepspeed(deepspeed_config: dict): + ... + + def trainer(deepspeed_config: Optional[dict]): + if deepspeed_config is not None: + # Do this check in the caller function + configure_deepspeed(deepspeed_config) + ... + ``` + +# 2. Type Annotations and Typechecking + +MegaBlocks aims to annotate all functions with type annotations (introduced in +[PEP 526](https://www.python.org/dev/peps/pep-0526/)). Type annotations help statically catch `TypeError` and +`AttributeError` bugs, in addition to other benefits, as outlined in the PEP. + +For documentation on typing annotations, see: +* [PEP 483](https://peps.python.org/pep-0483/) for a simplified introducation +* [PEP 484](https://peps.python.org/pep-0484/) for the full specification +* [Python docs for `typing`](https://docs.python.org/3/library/typing.html) for the API reference + +MegaBlocks uses [pyright](https://github.com/microsoft/pyright) +to validate type annotations. PyRight is automatically run as part of the pre-commit hooks, but you can also +run PyRight specifically via: + +``` +pre-commit run pyright --all-files +``` + +The pyright configuration is stored in [pyproject.toml](pyproject.toml). + + +## 2.1 Debugging + +Here are some suggestions to deal with pyright errors: + +1. Suppose a variable could be one of multiple types, like the following: + + ```python + from typing import Union + + def foo(x: Union[int, None]): + return x + 5 # type error -- None + 5 is not allowed! + ``` + + PyRight will complain since `None + 5` is not a valid operation. + Instead, add a check to ensure that `x is not None`: + + ```python + from typing import Union + + def foo(x: Union[int, None]): + if x is None: + raise TypeError("x must be an integer, not None!") + return x + 5 # valid + ``` + + Assert statements also work. However, assert statements should not be used for data validation + (see the assert statement section below). + ```python + from typing import Union + + def foo(x: Union[int, None]): + assert x is not None, "x should never be None" + return x + 5 # valid + ``` + +1. For variables where it is impossible for pyright to infer the correct type, use +[cast](https://docs.python.org/3/library/typing.html#typing.cast). +1. As a last resort, add a `# type: ignore` comment to the line where pyright emits an error. +Immediately following this statement, paste in the error emitted by pyright, +so other contributors will know why this error was silenced. + + +# 3. Public APIs +A public API, generally speaking, can be invoked by a user without a leading underscore in any portion of the path. +The following are examples of public APIs in [composer](https://github.com/mosaicml/composer/tree/dev): + +* Standalone functions in public modules (e.g. `composer.utils.dist.get_world_size`) +* Classes in public modules (e.g. `composer.trainer.trainer.Trainer`) +* Public methods in public classes (e.g. `composer.trainer.trainer.Trainer.fit`) +* Public modules (e.g. `composer.trainer.trainer`) + +The following rules apply to public APIs: +1. All public APIs must have a docstring (see the Documentation section below) +1. All parameters must have type annotations. +1. To minimize user imports, parameters should should use native PyTorch or Python types whenever possible. + + It is acceptable to use a union of types, so long as one of the options is a primitive. For example, in the + constructor for `composer.trainer.trainer.Trainer`, the `device` parameter is annotated like the following: + + ```python + from typing import Optional, Union + + from composer.devices import Device + + class Trainer: + def __init__( + self, + device: Union[str, Device], + ): + if isinstance(device, str): + device = Device(device) + ... + ``` + + This signature allows a user to pass a string for a device, + rather than having to import our custom device class. + + Parameters that are for power users (such as `load_object_store`) in the Trainer are exempt from this rule. + These parameters can require custom imports. + +1. Parameters that could take a sequence of elements should also allow `None` or a singleton. + This simplifies the user API by not having to construct a list (or tuple) to hold a single element + (or no element). For example, use `Optional[Union[torch.Tensor, Sequence[torch.Tensor]]`. + + The `composer.utils.ensure_tuple` helper method can convert a singleton, list, or tuple into a tuple. + For example + + ```python + from torch import Tensor + from typing import Optional, Sequence, Union + from composer.utils import ensure_tuple + + def foo(x: Optional[Union[Tensor, Sequence[Tensor]]]) -> tuple[Tensor, ...]: + return ensure_tuple(x) # ensures that the result is always a (potentially empty) tuple of tensors + ``` + + +# 4. Use of `assert` + +`assert` should be used only in test cases and for verifying invariants (likely required for type checking), +not for data validation. As asserts can be disabled in python by using the `-O` flag +(e.g. `python -O path/to/script.py`), they are not guaranteed to run. For data validation, instead use a style like +the following: + + + + +```python +if parameter is None: + raise ValueError("parameter must be specified and cannot be None") +``` + + +# 5. Imports and `__init__.py` + +All imports in MegaBlocks should be absolute -- that is, they do not begin with a period. + +## 5.1 External Dependencies +1. All external dependencies must be specified in both [setup.py](setup.py) for pip and [meta.yaml](meta.yaml) + for Anaconda. + +1. If a dependency is not core to MegaBlocks (e.g. it is for a model, dataset, algorithm, or some callbacks): + 1. It must be specified in a entry of the `extra_deps` dictionary of [setup.py](setup.py). + This dictionary groups dependencies that can be conditionally installed. An entry named `foo` can be installed with `pip install 'megablocks[foo]'`. For example, running `pip install 'megablocks[gg]'` will install everything in `install_requires`, along with `grouped_gemm`. + 1. It must also be specified in the `run_constrained` and the `test.requires` section. + 1. The import must be conditionally imported in the code. For example: + + + + ```python + from composer import Callback + from composer.utils import MissingConditionalImportError + + class SystemMetricsMonitor(Callback) + try: + import pynvml + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group="system_metrics_monitor", + conda_package="pynvml", + conda_channel="conda-forge",) from e + ``` + + This style allows users to perform minimal install of Composer without triggering `ImportError`s if + an optional dependency is missing. + + If the corresponding package is not published on Anaconda, then set the ``conda_package`` to the pip package + name, and set ``conda_channel`` to ``None``. For example, with DeepSpeed: + + + ```python + from composer.utils import MissingConditionalImportError + + try: + import deepspeed + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group="deepspeed", + conda_package="deepspeed>=0.5.5", + conda_channel=None) from e + ``` + + + + 1. If the dependency is core to MegaBlocks, add the dependency to the `install_requires` section of + [setup.py](./setup.py) and the `requirements.run` section of [meta.yaml](./meta.yaml). + +## 5.2 Use of `__all__` + +All public modules must define `__all__` to be the list of members that should be re-exported. +The variable is necessary to 1) limit what `from XXX import *` imports, and 2) ensure that the documentation only +includes exported members, not unrelated re-imports. + +For example, from [composer/callbacks/memory_monitor.py](composer/callbacks/memory_monitor.py) + +```python +"""Log memory usage during training.""" +import logging +from typing import Union + +import torch.cuda + +from composer.core import State +from composer.loggers import Logger +from composer.core.callback import Callback + +log = logging.getLogger(__name__) + +__all__ = ["MemoryMonitor"] # export only the MemoryMonitor, not other imports like `Logger`, `State`, or `Callback` + + +class MemoryMonitor(Callback): + ... +``` + + +## 5.3 `__init__.py` + +All public classes and functions should be added to the module's `__init__.py`. + + +```python +from composer.path.to.module.file import MyClass as MyClass +from composer.path.to.module.file import my_func as my_func +``` + +If a file only contains public functions, then the following is also acceptable: + + +```python +from composer.path.to.module import my_file as my_file +``` + + +# 6. Documentation + +## 6.1 Docstrings + +MegaBlocks uses [Google Style Docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html). +All public APIs require documentation. + +### 6.1.1 What to include in Docstrings? + +Docstrings, at a minimum, should include a summary of what the function or class does, along with the arguments it takes. See [below](#612-formatting-docstrings) for how to format docstrings. The [Google Style Guide](https://google.github.io/styleguide/pyguide.html) also includes some guidelines on how to write docstrings. + +### 6.1.2 Formatting Docstrings + +The following guidelines apply to documentation. +1. Each function that needs a docstring must have its input arguments, return statement (if not None), and any custom + exceptions annotated. +1. The arguments for the `__init__` signature of classes should be documented under the class-level docstring. There + should not be any `__init__`-level docstring. +1. Each argument annotation should include the type. If the argument has a default value, the type annotation should + specify "optional", and the docstring should say the default value. Some examples: + + ```python + from typing import Optional, Union + + def foo(bar: int): + """Foo. + + Args: + bar (int): Required bar. + """ + ... + + def foo2(bar: int = 42): + """Foo2. + + Args: + bar (int, optional): The first Argument. Default: ``42``. + """ + ... + + def foo3(bar: Optional[int] = None): + """Foo3. + + Args: + bar (int, optional): The first Argument. Default: ``None``. + """ + ... + + def foo4(bar: Union[int, str] = 42): + """Foo4. + + Args: + bar (int | str, optional): The first Argument. Default: ``42``. + """ + ... + + def foo5(bar: int) -> int: + """Foo5. + + Args: + bar (int): Required bar. + + Returns: + int: Description of return statement. + """ + ... + + def foo6(bar: int) -> tuple[int, str]: + """Foo6. + + Args: + bar (int): Required bar. + + Returns: + a (int): Returned value. + b (str): Returned value. + """ + ... + ``` + +### 6.1.3 Building and Viewing Docs Locally + +Assuming you already have a development install of MegaBlocks (see these [instructions](CONTRIBUTING.md#prerequisites)), here’s how to build and previous the docs locally. + +**️️ ⚠ Warning:** Jenkins treats all sphinx warnings as errors, so they must be addressed before a PR can be merged. Building docs locally can help debug any warnings showing up on Jenkins! + +In one terminal, run: + + +```bash +source path/to/megablocks_venv/bin/activate # activate your megablocks virtual env +cd megablocks/docs # cd to the docs folder insde your megablocks clone +make clean +make html +``` + +In a second terminal, run: + + +```bash +cd megablocks/docs +python3 -m http.server --directory _build/html/ +``` + +Then, navigate to [http://localhost:8000](http://localhost:8000) in your browser. + +## 6.2 Doctests + +Most docstrings should also include a `.. doctest` or `.. testcode` example to clearly illustrate how one would interact with the class or function. As part of the CI/CD process, all `.. doctest` blocks are executed to ensure the example in the documentation actually works. + +### 6.2.1 Writing Doctests + +See the [Sphinx Doctest Extension](https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html) for all of the available directives. Do not use `.. code-block::` for Python examples, as they are untested. + +Any test fixtures for doctests should go in [docs/source/doctest_fixtures.py](docs/source/doctest_fixtures.py) or in a `.. testsetup::` block. + +For example: +```python +import torch +from typing import Optional + +def my_function(x: Optional[torch.Tensor]) -> torch.Tensor: + """blah function + + Args: + input (torch.Tensor): Your guess. + + Returns: + torch.Tensor: How good your input is. + + Raises: + ValueError: If your input is negative. + + Example: + .. testsetup:: + + # optional setup section, not shown in docs + import torch + x = torch.randn(42) + + + .. testcode:: + + # shown in docs; runs after testsetup + my_function(x) + """ + ... +``` + +All doctests load the [docs/source/doctest_fixtures.py](docs/source/doctest_fixtures.py) file *before* tests run. If there are any variables that would be helpful have defined for all tests, feel free to add them into this file. However, if a variable is more specific to an individual doctest, then it would be best to include it in a `.. testsetup::` block, as not to pollute the global fixture namespace. (Unlike pytest fixtures, all doctest fixtures are given to every doctest; they cannot be specifically requested) + +### 6.2.2 Running Doctests Locally + +Assuming you already have a development install of MegaBlocks (see these [instructions](CONTRIBUTING.md#prerequisites)), here’s how to run the doctests. + + +```bash +source path/to/megablocks_venv/bin/activate # activate your megablocks virtual env +cd megablocks/docs # cd to the docs folder insde your megablocks clone +make clean +make html # the html build must be completed first to ensure all doctests are identified +make doctest 2>/dev/null # For more verbosity, do not direct stderr to /dev/null +``` diff --git a/docker.sh b/docker.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_125m_8gpu.sh b/exp/dmoe/dmoe_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_356m_8gpu.sh b/exp/dmoe/dmoe_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_46m_8gpu.sh b/exp/dmoe/dmoe_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_760m_8gpu.sh b/exp/dmoe/dmoe_760m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_125m_1gpu.sh b/exp/gpt2/gpt2_125m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_125m_8gpu.sh b/exp/gpt2/gpt2_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_1315m_1gpu.sh b/exp/gpt2/gpt2_1315m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_1315m_8gpu.sh b/exp/gpt2/gpt2_1315m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_356m_1gpu.sh b/exp/gpt2/gpt2_356m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_356m_8gpu.sh b/exp/gpt2/gpt2_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_46m_1gpu.sh b/exp/gpt2/gpt2_46m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_46m_8gpu.sh b/exp/gpt2/gpt2_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_760m_1gpu.sh b/exp/gpt2/gpt2_760m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_760m_8gpu.sh b/exp/gpt2/gpt2_760m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_125m_8gpu.sh b/exp/moe/moe_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_356m_8gpu.sh b/exp/moe/moe_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_46m_8gpu.sh b/exp/moe/moe_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 90e4511..d8d1848 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,2 +1,20 @@ -import megablocks.layers.dmoe -import megablocks.layers.moe +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +__all__ = [ + 'MoE', + 'dMoE', + 'get_load_balancing_loss', + 'ParallelMLP', + 'ParallelDroplessMLP', + 'SparseMLP', + 'MLP', + 'SparseGLU', + 'Arguments', +] diff --git a/megablocks/_version.py b/megablocks/_version.py index 2bb5d50..44ea780 100644 --- a/megablocks/_version.py +++ b/megablocks/_version.py @@ -1,4 +1,4 @@ -# Copyright 2022 MegaBlocks Composer authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 """The MegaBlocks Version.""" diff --git a/megablocks/backend/__init__.py b/megablocks/backend/__init__.py index 8b13789..9d4e43e 100644 --- a/megablocks/backend/__init__.py +++ b/megablocks/backend/__init__.py @@ -1 +1,2 @@ - +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index f99f93c..b831826 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch import triton import triton.language as tl @@ -5,7 +8,7 @@ def assert_is_tensor(x, ndim): if x.ndim != ndim: - raise ValueError(f"Expected {ndim}-tensor but got {x.ndim}-tensor") + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') def assert_is_matrix(x): @@ -14,12 +17,12 @@ def assert_is_matrix(x): def assert_is_vector(x): if x.ndim != 1: - raise ValueError(f"Expected 1-tensor but got {x.ndim}-tensor") + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') def assert_equal(a, b): if a != b: - raise ValueError(f"Expected dimensions to be equal but got {a} and {b}.") + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) # a: (tokens, hidden_size), real. @@ -40,18 +43,19 @@ def assert_equal(a, b): ) @triton.jit def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr, - A_TO_B : tl.constexpr, - SCALE : tl.constexpr): + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): # Our index into array 'a'. index_a = tl.load(indices + tl.program_id(0)) @@ -62,12 +66,12 @@ def _padded_copy( # Now we know what bin we're assigned to, but we need to know how # many threadblocks were assigned to earlier bins so we can offset # in our bin properly. - offset_in_bin = tl.program_id(0); + offset_in_bin = tl.program_id(0) if bin_idx > 0: offset_in_bin -= tl.load(bins + bin_idx - 1) # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin; + index_b = offset_in_bin if bin_idx > 0: index_b += tl.load(padded_bins + bin_idx - 1) @@ -116,10 +120,7 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. output_rows = padded_bins[-1].cpu().item() - out = torch.zeros( - (output_rows, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, out, @@ -131,7 +132,8 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -150,10 +152,7 @@ def gather(x, indices, bin_ids, weights, bins, top_k): # NOTE: There is no padding so the output rows equals the # input rows multiplied by top_k. output_rows = x.shape[0] * top_k - out = torch.empty( - (output_rows, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, out, @@ -165,7 +164,8 @@ def gather(x, indices, bin_ids, weights, bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -183,10 +183,7 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): assert_equal(indices.shape[0], weights.shape[0]) tokens = indices.shape[0] // top_k - out = torch.empty( - (tokens, top_k, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( out, x, @@ -198,7 +195,8 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=False, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) # Reduce along the top-k dimension, if needed. return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) @@ -227,16 +225,17 @@ def scatter(x, indices, bin_ids, weights, bins, top_k): ) @triton.jit def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr): + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): # Our index into 'tokens * top_k'. index_out = tl.load(indices + tl.program_id(0)) @@ -247,12 +246,12 @@ def _padded_copy_wgrad( # Now we know what bin we're assigned to, but we need to know how # many threadblocks were assigned to earlier bins so we can offset # in our bin properly. - offset_in_bin = tl.program_id(0); + offset_in_bin = tl.program_id(0) if bin_idx > 0: offset_in_bin -= tl.load(bins + bin_idx - 1) # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin; + index_x = offset_in_bin if bin_idx > 0: index_x += tl.load(padded_bins + bin_idx - 1) @@ -264,7 +263,7 @@ def _padded_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for i in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -288,10 +287,7 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): assert_equal(bins.size(), padded_bins.size()) tokens = indices.shape[0] // top_k - out = torch.empty( - (tokens * top_k), - dtype=x.dtype, - device=x.device) + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) _padded_copy_wgrad[(indices.shape[0],)]( x, grad, @@ -301,7 +297,8 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): bins, padded_bins, NUM_COLUMNS=x.shape[1], - TOP_K=top_k) + TOP_K=top_k, + ) return out @@ -326,18 +323,19 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): ) @triton.jit def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr, - A_TO_B : tl.constexpr, - SCALE : tl.constexpr): + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): # Load our indices into the output. expert_idx = tl.program_id(0) entry_idx = tl.program_id(1) @@ -349,7 +347,7 @@ def _binned_copy( # the number of tokens assigned to our expert. start = 0 if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) + start = tl.load(bins + expert_idx - 1) end = tl.load(bins + expert_idx) num_tokens = end - start @@ -380,7 +378,7 @@ def _binned_copy( optr = b if A_TO_B else a iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for i in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -401,10 +399,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): assert_equal(weights.shape[0], x.shape[0] * top_k) num_experts = bins.shape[0] - out = torch.zeros( - (num_experts, expert_capacity, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) _binned_copy[(num_experts, expert_capacity)]( x, @@ -417,7 +412,8 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -433,10 +429,7 @@ def binned_scatter(x, indices, weights, bins, top_k): num_experts, expert_capacity, hidden_size = x.shape tokens = indices.shape[0] // top_k - out = torch.zeros( - (tokens, top_k, hidden_size), - dtype=x.dtype, - device=x.device) + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) _binned_copy[(num_experts, expert_capacity)]( out, x, @@ -448,7 +441,8 @@ def binned_scatter(x, indices, weights, bins, top_k): NUM_COLUMNS=hidden_size, A_TO_B=False, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) # Reduce along the top-k dimension, if needed. return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) @@ -471,16 +465,17 @@ def binned_scatter(x, indices, weights, bins, top_k): ) @triton.jit def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr): + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): # Load our indices into the output. expert_idx = tl.program_id(0) entry_idx = tl.program_id(1) @@ -492,7 +487,7 @@ def _binned_copy_wgrad( # the number of tokens assigned to our expert. start = 0 if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) + start = tl.load(bins + expert_idx - 1) end = tl.load(bins + expert_idx) num_tokens = end - start @@ -510,7 +505,7 @@ def _binned_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for i in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -532,10 +527,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k): num_experts, expert_capacity, hidden_size = x.shape tokens = indices.shape[0] // top_k - out = torch.zeros( - (tokens * top_k), - dtype=x.dtype, - device=x.device) + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) _binned_copy_wgrad[(num_experts, expert_capacity)]( x, grad, @@ -545,5 +537,6 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k): indices, bins, NUM_COLUMNS=hidden_size, - TOP_K=top_k) + TOP_K=top_k, + ) return out diff --git a/megablocks/benchmark_util.py b/megablocks/benchmark_util.py index abf3521..02612d9 100644 --- a/megablocks/benchmark_util.py +++ b/megablocks/benchmark_util.py @@ -1,16 +1,19 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import torch def log_benchmark(name, arguments, time, std): - print("="*60) - print(f"{name} Benchmark") - print("Benchmark Parameters:") + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std)) - print("="*60) + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) def benchmark_function(fn, iterations=100, warmup=10): @@ -26,7 +29,7 @@ def benchmark_function(fn, iterations=100, warmup=10): start.record() fn() end.record() - + torch.cuda.synchronize() times.append(start.elapsed_time(end)) return np.mean(times), np.std(times) diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index be24c6f..07dbc04 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,15 +1,21 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + try: import grouped_gemm except ImportError: grouped_gemm = None + def grouped_gemm_is_available(): return grouped_gemm is not None + def assert_grouped_gemm_is_available(): - assert grouped_gemm_is_available(), ( - "Grouped GEMM not available. Please run " - "`pip install git+https://github.com/tgale96/grouped_gemm@main`.") + assert grouped_gemm_is_available( + ), ('Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.') + backend = grouped_gemm.backend if grouped_gemm_is_available() else None ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/megablocks/layers/__init__.py b/megablocks/layers/__init__.py index 8b13789..f0c42de 100644 --- a/megablocks/layers/__init__.py +++ b/megablocks/layers/__init__.py @@ -1 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + 'dMoE', +] diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 613ef31..736d311 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,10 +1,18 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from typing import Callable -import torch import stk +import torch -def act_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kwargs): +def act_fn( + x: stk.Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +): assert isinstance(x, stk.Matrix) with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): if return_grad_fn: @@ -18,7 +26,8 @@ def act_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kw x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) if return_grad_fn: return y, out.backward return y diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 12098eb..82a6f40 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -1,23 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch + class AllToAllOp(torch.autograd.Function): @staticmethod def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty( - (sum(output_split_sizes),) + x.shape[1:], - device=x.device, dtype=x.dtype) + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) ctx.input_shape = x.shape ctx.output_split_sizes = output_split_sizes ctx.input_split_sizes = input_split_sizes ctx.group = group handle = torch.distributed.all_to_all_single( - out, x, + out, + x, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=group, - async_op=async_op) + async_op=async_op, + ) return out, handle @staticmethod @@ -26,15 +30,24 @@ def backward(ctx, grad, _): out = torch.empty( ctx.input_shape, device=grad.device, - dtype=grad.dtype) + dtype=grad.dtype, + ) torch.distributed.all_to_all_single( - out, grad, + out, + grad, output_split_sizes=ctx.input_split_sizes, input_split_sizes=ctx.output_split_sizes, - group=ctx.group) + group=ctx.group, + ) return out, None, None, None, None return None, None, None, None, None + def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): return AllToAllOp.apply( - x, output_split_sizes, input_split_sizes, group, async_op) + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index 9b6c49b..efe131d 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -1,66 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import dataclasses from functools import partial -import megablocks.grouped_gemm_util as grouped_gemm +from typing import Any, Callable, Optional, Union + import torch import torch.nn.functional as F -from typing import Any, Callable, Optional, Union + +import megablocks.grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. InitFn = Callable[[torch.Tensor], None] _ALLOWED_BITWIDTHS = (-1, 4, 8) -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate="tanh") +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') @dataclasses.dataclass class Arguments: # Model arguments. - hidden_size : int = 1024 - ffn_hidden_size : int = 4096 - num_layers : int = 1 - bias : bool = True - return_bias : bool = True - activation_fn : Optional[Callable] = DEFAULT_ACTIVATION_FN + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN # MoE arguments. - moe_num_experts : int = 1 - moe_top_k : int = 1 - moe_capacity_factor : int = 1 - moe_normalize_expert_weights : Optional[Union[int, float]] = None - moe_loss_weight : float = 0.1 - moe_jitter_eps : Optional[float] = None - moe_lbl_in_fp32 : bool = False + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False # Parallelism arguments. - moe_expert_model_parallelism : bool = False - expert_parallel_group : Optional[torch.distributed.ProcessGroup] = None - moe_weight_parallelism : bool = False - weight_parallel_group : Optional[torch.distributed.ProcessGroup] = None - pipeline_model_parallel_size : int = 1 - num_layers_per_virtual_pipeline_stage : Optional[int] = None + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None + moe_weight_parallelism: bool = False + weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None # Compute arguments. - memory_optimized_mlp : bool = False - mlp_type : str = 'mlp' - mlp_impl : str = 'sparse' + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' # Initialization arguments. - fp16 : bool = True + fp16: bool = True bf16: bool = False - device : torch.device = torch.cuda.current_device() - init_method : InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method : InitFn = init_method + device: torch.device = torch.cuda.current_device() + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method # Benchmarking arguments. - uniform_expert_assignment : bool = False + uniform_expert_assignment: bool = False # shared expert arguments shared_expert: bool = False # enable using shared expert fc_cls: torch.nn.Module = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) # kwargs for custom fc layers + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) def __post_init__(self): diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py index fd99aa4..ee30e79 100644 --- a/megablocks/layers/common.py +++ b/megablocks/layers/common.py @@ -1,7 +1,12 @@ -from megablocks.layers.arguments import Arguments +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch -def dtype(args : Arguments): +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): if args.fp16: return torch.float16 elif args.bf16: diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index 666398a..d765bd0 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -1,20 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from typing import Union -from megablocks.layers import mlp -from megablocks.layers import glu + +from megablocks.layers import glu, mlp from megablocks.layers.arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] _REGISTRY = { - 'mlp': {'grouped': mlp.GroupedMLP, 'sparse' : mlp.SparseMLP}, - 'glu': {'grouped': glu.GroupedGLU, 'sparse': glu.SparseGLU}, + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, } + def get(args: Arguments) -> MlpType: """Returns an MLP for use in a dMoE instance. Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs + MLP instance. This only contains MLPs for use in dMoEs (ie. only for the dropless versions of MoEs). Args: @@ -22,12 +32,11 @@ def get(args: Arguments) -> MlpType: Returns: An instantiated MLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: + if args.mlp_type not in _REGISTRY: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.') + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 04a538d..e683f8a 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -1,19 +1,22 @@ -from megablocks.layers import common -from megablocks.layers import moe -from megablocks.layers import dmlp_registry -from megablocks.layers import mpu -from megablocks.layers.arguments import Arguments -import megablocks.ops as ops +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import stk import torch +import megablocks.ops as ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + + def promote_scalar(x): return x.view(1) if not len(x.size()) else x + class ParallelDroplessMLP(moe.ParallelMLP): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(ParallelDroplessMLP, self).__init__(args) self.hidden_size = args.hidden_size self.ffn_hidden_size = mpu.features_per_rank(args) @@ -22,10 +25,11 @@ def __init__(self, args : Arguments): # Calculate the number of bits needed to represent the column indices # in the intermediate sparse matrix. - max_column_index = ( - (self.ffn_hidden_size * self.num_experts) // self.blocking) + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), 1) + int(np.ceil(np.log2(max_column_index))), + 1, + ) def sparse_transpose(self, size, row_indices, column_indices, offsets): block_columns = size[1] // self.blocking @@ -37,7 +41,9 @@ def sparse_transpose(self, size, row_indices, column_indices, offsets): # To avoid overflow when we have large activation matrices we cast to # 32-bit before sorting. _, gather_indices = ops.sort( - column_indices.int(), self.transpose_sort_end_bit) + column_indices.int(), + self.transpose_sort_end_bit, + ) # There are a constant number of blocks in every row of the sparse matrix. # A blocks offset is: @@ -62,8 +68,10 @@ def topology(self, x, padded_bins): padded_tokens, _ = x.size() assert padded_tokens % self.blocking == 0 if self.ffn_hidden_size % self.blocking != 0: - raise ValueError(f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. Please update your configuration.') + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) # Offsets for the sparse matrix. All rows have the # same number of nonzero blocks dictated by the @@ -75,15 +83,18 @@ def topology(self, x, padded_bins): block_rows * blocks_per_row + 1, blocks_per_row, dtype=torch.int32, - device=x.device) + device=x.device, + ) # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, - self.blocking, - block_rows, - blocks_per_row) + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) # TODO(tgale): This is unused. Remove the need for this in stk. # For now, use meta init to save the device memory. @@ -92,17 +103,29 @@ def topology(self, x, padded_bins): self.blocking, self.blocking, dtype=common.dtype(self.args), - device='meta') + device='meta', + ) shape = ( padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args) + self.ffn_hidden_size * mpu.experts_per_rank(self.args), ) - row_indices = stk.ops.row_indices( - shape, data, offsets, column_indices) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, row_indices, column_indices, offsets) - return stk.Matrix(shape, data, row_indices, column_indices, offsets, - column_indices_t, offsets_t, block_offsets_t) + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) def indices_and_padded_bins(self, top_experts): # Sort the expert ids to produce the scatter/gather @@ -118,7 +141,9 @@ def indices_and_padded_bins(self, top_experts): # the matrix muliplications. Caculate the starting # position of each bin. padded_tokens_per_expert = ops.round_up( - tokens_per_expert, self.blocking) + tokens_per_expert, + self.blocking, + ) padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) padded_bins = promote_scalar(padded_bins) @@ -134,8 +159,7 @@ def sparse_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = ( - self.indices_and_padded_bins(top_experts)) + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) @@ -145,7 +169,8 @@ def sparse_forward_once(self, x, expert_weights, top_experts): bin_ids, bins, padded_bins, - self.top_k) + self.top_k, + ) # Create the sparse matrix topology. with torch.no_grad(): @@ -162,37 +187,35 @@ def sparse_forward_once(self, x, expert_weights, top_experts): expert_weights, bins, padded_bins, - self.top_k) + self.top_k, + ) return x, tokens_per_expert # For use in the base-class parallel_forward_once. def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): # Round the token counts up to the block size used in the matrix # multiplication. Calculate the starting position of each bin. padded_tokens_per_expert = ops.round_up( - tokens_per_expert, self.blocking) + tokens_per_expert, + self.blocking, + ) padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) padded_bins = promote_scalar(padded_bins) # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - top_k) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) # Create the sparse matrix topology. with torch.no_grad(): @@ -201,7 +224,6 @@ def sparse_permute_and_compute( # Perform the expert computation. x = self.mlp(x, topo) - # Un-route the data for the MoE output. return ops.padded_scatter( x, @@ -210,7 +232,8 @@ def sparse_permute_and_compute( expert_weights, bins, padded_bins, - top_k) + top_k, + ) def grouped_forward_once(self, x, expert_weights, top_experts): # x: [sl, bs, hs] @@ -219,8 +242,7 @@ def grouped_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) out = self.grouped_permute_and_compute( x, @@ -230,60 +252,49 @@ def grouped_forward_once(self, x, expert_weights, top_experts): expert_weights, bins, -1, # unused - self.args.moe_top_k) + self.args.moe_top_k, + ) return out, tokens_per_expert def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.gather( - x, - indices, - bin_ids, - bins, - top_k) + x = ops.gather(x, indices, bin_ids, bins, top_k) # Perform the expert computation. x = self.mlp(x, tokens_per_expert) # Un-route the data for the MoE output. - return ops.scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - top_k) + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) def forward_once(self, x, expert_weights, top_experts): if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once( - x, expert_weights, top_experts) + return self.sparse_forward_once(x, expert_weights, top_experts) else: - return self.grouped_forward_once( - x, expert_weights, top_experts) - + return self.grouped_forward_once(x, expert_weights, top_experts) def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): if self.args.mlp_impl == 'sparse': return self.sparse_permute_and_compute( x, @@ -293,7 +304,8 @@ def permute_and_compute( expert_weights, bins, expert_capactiy, - top_k) + top_k, + ) else: return self.grouped_permute_and_compute( x, @@ -303,7 +315,8 @@ def permute_and_compute( expert_weights, bins, expert_capactiy, - top_k) + top_k, + ) class dMoE(moe.MoE): diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index 49ac4a8..40b601d 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import stk import torch import torch.nn.functional as F @@ -6,12 +9,7 @@ @torch.jit.script def _gelu_backward_inplace(g, x): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = ( - 0.5 * x * ( - (1 - tanh_out * tanh_out) * - (0.79788456 + 0.1070322243 * x * x) - ) + 0.5 * (1 + tanh_out) - ) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) return g.mul_(ff) @@ -26,7 +24,8 @@ def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) return _gelu_backward_inplace(grad, x) @@ -34,10 +33,11 @@ def gelu(x: stk.Matrix): assert isinstance(x, stk.Matrix) return stk.Matrix( x.size(), - F.gelu(x.data, approximate="tanh"), + F.gelu(x.data, approximate='tanh'), x.row_indices, x.column_indices, x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index cc6931a..fa888a6 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,39 +1,60 @@ -from megablocks.layers import common -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.mlp import SparseMLP, SharedMLP, create_dmoe_expert_weights, resolve_dtensor -from megablocks.layers import mpu -from megablocks.layers.arguments import Arguments, DEFAULT_ACTIVATION_FN -from megablocks import grouped_gemm_util as gg +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import stk import torch +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + class SparseGLU(SparseMLP): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__(args) - self.v1 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) with torch.no_grad(): - self.v1.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method)) + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) mpu.set_expert_model_parallel_attributes( - self.v1, self._should_set_parallelism_attribute) + self.v1, + self._should_set_parallelism_attribute, + ) if self.args.moe_weight_parallelism: - raise NotImplementedError("Weight parallelism not yet supported with GLU.") + raise NotImplementedError('Weight parallelism not yet supported with GLU.',) def forward(self, x, topo): if self.args.memory_optimized_mlp: - raise NotImplementedError("Memory optimized implementation not yet supported with GLU with sparse kernels.") + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2) + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) # Compute the GLU. x1 = stk.ops.sdd(x, w1.t(), topo) @@ -44,6 +65,7 @@ def forward(self, x, topo): return stk.ops.dsd(x1, w2) + class MemoryOptimizedGroupedGLU(torch.autograd.Function): """GroupedMLP with manually scheduled memory reuse.""" @@ -57,8 +79,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): v1 = v1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not v1.is_contiguous() or not w2.is_contiguous()): + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -85,13 +106,11 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') # Unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, v1, w2 = saved_tensors[:3] batch_sizes = saved_tensors[3] @@ -101,21 +120,30 @@ def backward(ctx, ddsd_out): # Rematerialize activation_fn output. activation_fn = ctx.activation_fn with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. dw2 = gg.backend.gmm( - activation_fn_out, ddsd_out, batch_sizes, trans_a=True) + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) # Compute dactivation_fn_out. # # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out gg.backend.gmm( - ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out) + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) # Compute dsdd_out. # @@ -139,14 +167,20 @@ def backward(ctx, ddsd_out): dx += gg.backend.gmm(dv1_out, v1, batch_sizes) return dx, dw1, dv1, dw2, None, None + memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply class GroupedGLU(SparseGLU): + def forward(self, x, tokens_per_expert): batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) # Re-shape the weights for the grouped GEMMs. ne = mpu.experts_per_rank(self.args) @@ -156,8 +190,13 @@ def forward(self, x, tokens_per_expert): if self.args.memory_optimized_mlp: return memory_optimized_grouped_glu( - x, w1, v1, w2, batch_sizes, - self.args.activation_fn) + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) # Compute the MLP. x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) @@ -167,11 +206,12 @@ def forward(self, x, tokens_per_expert): class SharedGLU(SharedMLP): - """GPU for shared expert + """GPU for shared expert. Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class """ - def __init__(self, args : Arguments): + + def __init__(self, args: Arguments): super().__init__(args) self.gate_proj = args.fc_cls( args.hidden_size, diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index e314272..809e317 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -1,19 +1,17 @@ -import functools +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import gc -from megablocks.layers import dmoe, arguments, mpu -from megablocks import benchmark_util -import numpy as np import torch -_TESTS = ( - (8, 2048, 4096, 4096, 32, 4), +from megablocks.layers import arguments, dmoe -) +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) def get_tensors(): - ptrs = set([]) + ptrs = set() out = [] for obj in gc.get_objects(): if torch.is_tensor(obj): @@ -25,13 +23,14 @@ def get_tensors(): def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k): + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): args = arguments.Arguments( hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, @@ -41,13 +40,13 @@ def test_memory( expert_parallel_group=group, fp16=False, bf16=True, - device=torch.cuda.current_device()) + device=torch.cuda.current_device(), + ) layer = dmoe.dMoE(args).cuda() - x = torch.randn( - (batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) torch.cuda.empty_cache() # Run forward + backward. @@ -57,16 +56,13 @@ def test_memory( # Report peak memory. mem = torch.cuda.max_memory_allocated() - print("Max Memory Allocated = {:0.0f}MiB".format( - mem / 1e6)) - print("Max Memory Reserved = {:0.0f}MiB".format( - torch.cuda.max_memory_reserved() / 1e6)) + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) # Calculate weight and gradient memory usage. weight_memory = 2 * ( - layer.router.layer.weight.numel() + - layer.experts.mlp.w1.numel() + - layer.experts.mlp.w2.numel()) + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) def grad_numel(x): if x.grad is not None: @@ -74,15 +70,12 @@ def grad_numel(x): return 0 grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + - grad_numel(layer.experts.mlp.w1) + - grad_numel(layer.experts.mlp.w2)) + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) weight_memory += grad_memory - print("Weight Memory Allocated = {:0.0f}MiB".format( - weight_memory / 1e6)) - print("Activation Memory Allocated = {:0.0f}MiB".format( - (mem - weight_memory) / 1e6)) + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) # Manually calculate GPU memory usage from the garbage # collector. @@ -92,11 +85,10 @@ def grad_numel(x): tensors = sorted(tensors, key=lambda x: -x.numel()) for i, t in enumerate(tensors): total += t.numel() - print(f"{i}: {t.shape}, {t.numel() * 2}") + print(f'{i}: {t.shape}, {t.numel() * 2}') del tensors - print("Total Bytes Found = {:0.0f}MiB".format( - total * 2 / 1e6)) + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) if __name__ == '__main__': diff --git a/megablocks/layers/memory_test.sh b/megablocks/layers/memory_test.sh old mode 100644 new mode 100755 diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 2bb1e3b..1cae4fb 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -1,16 +1,17 @@ -from packaging import version +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from typing import Any -from megablocks.layers import common -from megablocks.layers import gelu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers import mpu -from megablocks.layers import weight_parallel as wp -from megablocks.layers.arguments import Arguments, InitFn, DEFAULT_ACTIVATION_FN -from megablocks import grouped_gemm_util as gg import stk import torch -import torch.nn.functional as F +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers import weight_parallel as wp +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): @@ -25,6 +26,8 @@ def forward(ctx, x, scale): @torch.cuda.amp.custom_bwd def backward(ctx, grad): return grad * ctx.scale, None + + scale_gradient = ScaleGradient.apply @@ -36,18 +39,23 @@ def resolve_dtensor(weight): return weight -def create_moe_expert_weights(args : Arguments, - num_experts : int, - ffn_hidden_size : int, - hidden_size : int, - init_method : InitFn): +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): # Create the entire weight matrix such that the sampled weights will # not vary between data parallelism and expert model parallelism for # the same random seed. master_weights = torch.empty( - num_experts, ffn_hidden_size, hidden_size, + num_experts, + ffn_hidden_size, + hidden_size, device=args.device, - dtype=common.dtype(args)) + dtype=common.dtype(args), + ) init_method(master_weights) if not args.moe_expert_model_parallelism: @@ -75,35 +83,44 @@ def create_moe_expert_weights(args : Arguments, # Slice the weight matrix to get the chunk for this rank. with torch.no_grad(): - weights = master_weights[ - start_expert:end_expert, start_row:end_row] + weights = master_weights[start_expert:end_expert, start_row:end_row] return weights class MLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args - expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) experts_per_rank = mpu.experts_per_rank(args) - self.w1 = torch.nn.Parameter(torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args))) - self.w2 = torch.nn.Parameter(torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) mpu.set_expert_model_parallel_attributes( - self.w1, args.moe_expert_model_parallelism) + self.w1, + args.moe_expert_model_parallelism, + ) mpu.set_expert_model_parallel_attributes( - self.w2, args.moe_expert_model_parallelism) + self.w2, + args.moe_expert_model_parallelism, + ) # Initialize the parameters for the MLP. # @@ -115,16 +132,26 @@ def __init__(self, args : Arguments): # usage. with torch.no_grad(): w1 = create_moe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method) + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_(create_moe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.output_layer_init_method)) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) self.gradient_scale = None if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) def scale_grad(self, w): if self.gradient_scale is None: @@ -139,13 +166,20 @@ def forward(self, x): return torch.bmm(x, w2) -def create_dmoe_expert_weights(args : Arguments, - num_experts : int, - rows : int, - columns : int, - init_method : InitFn): +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): weights = create_moe_expert_weights( - args, num_experts, rows, columns, init_method) + args, + num_experts, + rows, + columns, + init_method, + ) weights = weights.view([-1, columns]) rows, columns = weights.shape @@ -175,16 +209,17 @@ def forward(ctx, x, w1, w2, topo, activation_fn): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - topo_tensors = (topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) # Layer 0: x @ w1.t(). sdd_out = stk.ops.sdd(x, w1.t(), topo) @@ -210,13 +245,11 @@ def forward(ctx, x, w1, w2, topo, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') # unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, w2 = saved_tensors[:2] topo_tensors = saved_tensors[2:8] @@ -226,7 +259,11 @@ def backward(ctx, ddsd_out): # rematerialize activation function output activation_fn = ctx.activation_fn sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn(sdd_out, activation_fn, return_grad_fn=True) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) # Compute dw2 with recomputed activation_fn output. dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) @@ -236,12 +273,14 @@ def backward(ctx, ddsd_out): # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out stk.backend.triton_kernels.sdd( - ddsd_out, w2.t(), + ddsd_out, + w2.t(), dactivation_fn_out.shape, dactivation_fn_out.data, dactivation_fn_out.offsets, dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices) + dactivation_fn_out.column_indices, + ) # Compute dsdd_out. # @@ -270,33 +309,39 @@ def backward(ctx, ddsd_out): dsdd_out.block_offsets_t, False, w1, - ddsd_out) + ddsd_out, + ) dx = ddsd_out return dx, dw1, dw2, None, None + memory_optimized_mlp = MemoryOptimizedMLP.apply class SparseMLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args - self._num_rows_per_rank = ( - (mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // - mpu.get_weight_parallel_world_size(args) + self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // + mpu.get_weight_parallel_world_size(args)) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), ) - - self.w1 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) - self.w2 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) # Initialize the parameters for the MLP. # @@ -307,23 +352,38 @@ def __init__(self, args : Arguments): # and the slice which causes large increases in our peak memory # usage. with torch.no_grad(): - self.w1.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method)) - self.w2.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.output_layer_init_method)) - - self._should_set_parallelism_attribute = ( - args.moe_expert_model_parallelism or args.moe_weight_parallelism) + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism) mpu.set_expert_model_parallel_attributes( - self.w1, self._should_set_parallelism_attribute) + self.w1, + self._should_set_parallelism_attribute, + ) mpu.set_expert_model_parallel_attributes( - self.w2, self._should_set_parallelism_attribute) + self.w2, + self._should_set_parallelism_attribute, + ) self.gradient_scale = None if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) def scale_grad(self, w): if self.gradient_scale is None: @@ -335,9 +395,16 @@ def parallel_forward(self, x, topo): w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) if self.args.memory_optimized_mlp: if self.args.activation_fn is not DEFAULT_ACTIVATION_FN: - raise NotImplementedError(f'memory_optimized_weight_parallel_mlp not implemented for custom {activation_fn=}.') + raise NotImplementedError( + f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.', + ) return wp.memory_optimized_weight_parallel_mlp( - x, w1, w2, topo, group) + x, + w1, + w2, + topo, + group, + ) # Compute the MLP. x = wp.sdd_nt(x, w1, topo, group) @@ -351,7 +418,12 @@ def forward(self, x, topo): return self.parallel_forward(x, topo) elif self.args.memory_optimized_mlp: return memory_optimized_mlp( - x, w1, w2, topo, self.args.activation_fn) + x, + w1, + w2, + topo, + self.args.activation_fn, + ) # Compute the MLP. x = stk.ops.sdd(x, w1.t(), topo) @@ -371,8 +443,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -398,13 +469,11 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') # Unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, w2 = saved_tensors[:2] batch_sizes = saved_tensors[2] @@ -420,14 +489,23 @@ def backward(ctx, ddsd_out): # Compute dw2 with recomputed activation_fn output. dw2 = gg.backend.gmm( - activation_fn_out, ddsd_out, batch_sizes, trans_a=True) + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) # Compute dactivation_fn_out. # # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out gg.backend.gmm( - ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out) + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) # Compute dsdd_out. # @@ -449,6 +527,7 @@ def backward(ctx, ddsd_out): dx = ddsd_out return dx, dw1, dw2, None, None + memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply @@ -464,13 +543,16 @@ def forward(self, x, tokens_per_expert): w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) if self.args.moe_weight_parallelism: - raise NotImplementedError( - "Weight parallelism not yet supported with GroupedMLP.") + raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',) if self.args.memory_optimized_mlp: return memory_optimized_grouped_mlp( - x, w1, w2, batch_sizes, - self.args.activation_fn) + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) # Compute the MLP. x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) @@ -479,11 +561,12 @@ def forward(self, x, tokens_per_expert): class SharedMLP(torch.nn.Module): - """MLP for shared expert + """MLP for shared expert. Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class """ - def __init__(self, args : Arguments): + + def __init__(self, args: Arguments): super().__init__() self.args = args self.fc_kwargs: dict[str, Any] = { @@ -505,7 +588,11 @@ def __init__(self, args : Arguments): ) self.down_proj._is_residual = True # a flag for llm-foundry init - def add_experts_sharedexpert(self, shared_expert_out: torch.Tensor, expert_out: torch.Tensor) -> torch.Tensor: + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: # Helper function to add expert output to shared expert output # with optional weighted sum. if self.args.shared_expert_weighted_sum: @@ -513,7 +600,10 @@ def add_experts_sharedexpert(self, shared_expert_out: torch.Tensor, expert_out: # wieghted by number of experts used t_experts = self.args.moe_top_k + 1 sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add(expert_out, alpha=(self.args.moe_top_k / t_experts)) + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) return shared_expert_out + expert_out diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 9d26da2..e5eaaa8 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,14 +1,13 @@ -from megablocks.layers import common -from megablocks.layers import mpu -from megablocks.layers import router -from megablocks.layers import mlp -from megablocks.layers import sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments -import megablocks.ops as ops +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import torch +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments _LOAD_BALANCING_LOSS = [] @@ -28,47 +27,41 @@ def clear_load_balancing_loss(): _LOAD_BALANCING_LOSS.clear() -def batched_load_balancing_loss(args : Arguments): +def batched_load_balancing_loss(args: Arguments): if args.moe_loss_weight == 0: return 0.0 # tokens_per_expert[i].shape = (num_experts) # expert_scores[i].shape = (tokens, num_experts) tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = ( - args.num_layers // args.pipeline_model_parallel_size) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) if args.num_layers_per_virtual_pipeline_stage is not None: num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage if len(tokens_per_expert) != num_layers_per_pipeline_stage: raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}") + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) if len(expert_scores) != num_layers_per_pipeline_stage: raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}") + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all([ - x.ndim == 1 and x.numel() == args.moe_num_experts - for x in tokens_per_expert - ]) + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) tokens = expert_scores[0].shape[0] - assert all([ - (x.ndim == 2 and x.shape[1] == args.moe_num_experts and - x.shape[0] == tokens) for x in expert_scores - ]) - + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) # Concatenate the contributions of each layer and convert to # the correct types and formats for the dot product. @@ -88,15 +81,8 @@ def batched_load_balancing_loss(args : Arguments): # Calculate the total scale across all factors. # # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = ( - args.moe_num_experts * - args.moe_loss_weight - ) - scale_denominator = ( - args.num_layers * - tokens * - args.moe_top_k - ) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) scale = scale_numerator / scale_denominator return scale * torch.dot(tokens_per_expert, expert_scores) @@ -107,13 +93,13 @@ def batched_load_balancing_loss(args : Arguments): # parallel all2all. class ParallelMLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(ParallelMLP, self).__init__() self.args = args # Calculate the number of experts in total and the number of experts # owned by this rank. - world_size = mpu.get_expert_parallel_world_size(args) + # world_size = mpu.get_expert_parallel_world_size(args) self.num_experts = args.moe_num_experts self.top_k = self.args.moe_top_k @@ -127,24 +113,23 @@ def __init__(self, args : Arguments): if self.args.bias: # Note that the output bias is not parallelized with expert # model parallelism. - self.bias = torch.nn.Parameter(torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) torch.nn.init.zeros_(self.bias) else: self.register_parameter('bias', None) # Select the forward function for the operating mode. - self.forward_fn = ( - self.parallel_forward_once if - args.moe_expert_model_parallelism else - self.forward_once) + self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) def expert_capacity(self, tokens): world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = ( - self.top_k * tokens * world_size / self.num_experts) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) return int(self.args.moe_capacity_factor * tokens_per_expert) def load_balancing_loss(self, tokens_per_expert, expert_scores): @@ -158,7 +143,8 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores): scale = self.num_experts / (tokens * self.top_k) return scale * torch.dot( tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0)) + expert_scores.mean(dim=0), + ) def indices_and_bins(self, top_expert): # Sort the expert ids to produce the scatter/gather @@ -184,27 +170,26 @@ def indices_and_bins(self, top_expert): return indices, bin_ids, bins, tokens_per_expert def permute_and_compute( - self, - x, - tokens_per_expert, # unused - indices, - bin_ids, # unused - expert_weights, - bins, - expert_capacity, - top_k): + self, + x, + tokens_per_expert, # unused + indices, + bin_ids, # unused + expert_weights, + bins, + expert_capacity, + top_k, + ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.binned_gather( - x, indices, bins, expert_capacity, top_k) + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) # Perform the expert computation. Note that we don't # use biases for these linear operations. x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter( - x, indices, expert_weights, bins, top_k) + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) def forward_once(self, x, expert_weights, top_experts): # x: [sl, bs, hs] @@ -213,8 +198,7 @@ def forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. @@ -231,7 +215,8 @@ def forward_once(self, x, expert_weights, top_experts): expert_weights, bins, expert_capacity, - self.top_k) + self.top_k, + ) return x, tokens_per_expert def parallel_forward_once(self, x, expert_weights, top_experts): @@ -259,23 +244,25 @@ def parallel_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (mpu.hidden_sharding_degree(self.args),)) + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) # Pass token count information to the device on which the # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) tpe_handle = torch.distributed.all_to_all_single( parallel_tokens_per_expert, repeated_tokens_per_expert, group=self.args.expert_parallel_group, - async_op=True) + async_op=True, + ) # Permute locally and without any padding so that tokens for each # parallel device are stored contiguously. @@ -283,12 +270,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - x = ops.gather( - x, - indices, - bin_ids, - bins, - self.top_k) + x = ops.gather(x, indices, bin_ids, bins, self.top_k) # Compute the number of tokens that will be received from each # device and permute the input data across the devices. @@ -298,10 +280,8 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Reshape to [world_size, num_experts_per_rank]. world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = ( - repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = ( - parallel_tokens_per_expert.view(world_size, experts_per_rank)) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) # TODO(tgale): It might be faster to do this on the GPU and # then communicate the results back to the host. @@ -325,9 +305,12 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Start the cross-device permutation asynchronously so we can # overlap communication with computation. parallel_x, parallel_x_handle = all_to_all( - x, recv_counts, send_counts, + x, + recv_counts, + send_counts, self.args.expert_parallel_group, - async_op=True) + async_op=True, + ) with torch.no_grad(): # After we do the cross-device permutation we have the tokens on the @@ -337,48 +320,46 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. replicate_bins = ops.inclusive_cumsum( - parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) - if not len(replicate_bins.size()) - else replicate_bins + parallel_tokens_per_expert.flatten(), + 0, ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) # Construct the expert indices for the permuted tokens. parallel_top_expert = torch.remainder( torch.arange( self.num_experts * mpu.hidden_sharding_degree(self.args), dtype=torch.int32, - device=indices.device + device=indices.device, ), mpu.experts_per_rank(self.args), ) parallel_top_expert = ops.replicate( parallel_top_expert.unsqueeze(dim=0), - replicate_bins, tokens_received).flatten() + replicate_bins, + tokens_received, + ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, self.sort_end_bit) + parallel_top_expert, + self.sort_end_bit, + ) # Calculate the bins boundaries from the token counts. parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int) - parallel_bins = ops.inclusive_cumsum( - parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) - if not len(parallel_bins.size()) - else parallel_bins + dim=0, + dtype=torch.int, ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. tokens, hs = x.size() expert_capacity = self.expert_capacity(tokens) if expert_capacity == 0: - expert_capacity = torch.max( - parallel_tokens_per_expert).item() + expert_capacity = torch.max(parallel_tokens_per_expert).item() # Locally permute the tokens and perform the expert computation. # Block to make sure that the cross-device permutation is complete. @@ -387,7 +368,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # moved to CPU for the prior all_to_all, which avoids an extra # device synchronization. parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, dtype=torch.int) + dim=0, + dtype=torch.int, + ) parallel_x_handle.wait() parallel_x = self.permute_and_compute( parallel_x, @@ -397,12 +380,16 @@ def parallel_forward_once(self, x, expert_weights, top_experts): None, # expert_weights parallel_bins, expert_capacity, - top_k=1) + top_k=1, + ) # Un-permute the tokens across the devices. x, _ = all_to_all( - parallel_x, send_counts, recv_counts, - self.args.expert_parallel_group) + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) # Reduce along the hidden sharding to get the final outputs. # @@ -410,26 +397,19 @@ def parallel_forward_once(self, x, expert_weights, top_experts): shape = ( mpu.hidden_sharding_degree(self.args), -1, - self.args.hidden_size + self.args.hidden_size, ) x = ops.sum(x.view(shape), dim=0) # Un-permute locally to setup for the next series of operations. - x = ops.scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - self.top_k) + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() def forward(self, x, scores, expert_weights, top_experts): in_shape = x.size() # Compute the experts. - x, tokens_per_expert = self.forward_fn( - x, expert_weights, top_experts) + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) if self.training and self.args.moe_loss_weight > 0: save_load_balancing_loss((tokens_per_expert, scores)) x = x.view(in_shape) @@ -442,7 +422,7 @@ def forward(self, x, scores, expert_weights, top_experts): class MoE(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(MoE, self).__init__() # Token router. @@ -471,5 +451,8 @@ def forward(self, x): out = self.experts(x, scores, expert_weights, top_experts) if self.shared_expert is not None: shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert(shared_expert_out, out) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) return out diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 49bbcbe..6aa0015 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,55 +1,53 @@ -from megablocks.layers.arguments import Arguments +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +from megablocks.layers.arguments import Arguments + -def is_moe_param(tensor : torch.Tensor) -> bool: +def is_moe_param(tensor: torch.Tensor) -> bool: return hasattr(tensor, 'expert_model_parallel') -def get_expert_parallel_world_size(args : Arguments) -> int: - return ( - torch.distributed.get_world_size(args.expert_parallel_group) - if args.moe_expert_model_parallelism else 1 - ) +def get_expert_parallel_world_size(args: Arguments) -> int: + return (torch.distributed.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) -def get_expert_parallel_rank(args : Arguments) -> int: - return ( - torch.distributed.get_rank(args.expert_parallel_group) - if args.moe_expert_model_parallelism else 0 - ) +def get_expert_parallel_rank(args: Arguments) -> int: + return (torch.distributed.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) -def set_expert_model_parallel_attributes(tensor : torch.Tensor, - is_parallel : bool): +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): assert not hasattr(tensor, 'expert_model_parallel') setattr(tensor, 'expert_model_parallel', is_parallel) -def param_is_expert_model_parallel(param : torch.Tensor) -> bool: - return (hasattr(param, 'expert_model_parallel') and - param.expert_model_parallel) +def param_is_expert_model_parallel(param: torch.Tensor) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) -def copy_expert_model_parallel_attributes(destination_tensor : torch.Tensor, - source_tensor : torch.Tensor): +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): if hasattr(source_tensor, 'expert_model_parallel'): - setattr(destination_tensor, 'expert_model_parallel', - getattr(source_tensor,'expert_model_parallel')) + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 'expert_model_parallel'), + ) -def get_weight_parallel_world_size(args : Arguments) -> int: - return ( - torch.distributed.get_world_size(args.weight_parallel_group) - if args.moe_weight_parallelism else 1 - ) +def get_weight_parallel_world_size(args: Arguments) -> int: + return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1) -def get_weight_parallel_rank(args : Arguments) -> int: - return ( - torch.distributed.get_rank(args.weight_parallel_group) - if args.moe_weight_parallelism else 0 - ) +def get_weight_parallel_rank(args: Arguments) -> int: + return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0) def synchronized_print(group, *x): @@ -58,39 +56,38 @@ def synchronized_print(group, *x): for i in range(world_size): torch.distributed.barrier(group) if i == rank: - print(f"rank = {rank}", *x) + print(f'rank = {rank}', *x) # Helpers for expert/tensor sharding. -def expert_sharding_degree(args : Arguments) -> int: +def expert_sharding_degree(args: Arguments) -> int: world_size = get_expert_parallel_world_size(args) esd = min(world_size, args.moe_num_experts) if (args.moe_num_experts % esd) != 0: - raise ValueError( - f"Cannot shard {args.moe_num_experts} experts {esd} ways.") + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) return esd -def hidden_sharding_degree(args : Arguments) -> int: +def hidden_sharding_degree(args: Arguments) -> int: world_size = get_expert_parallel_world_size(args) esd = expert_sharding_degree(args) hsd = world_size // esd if (args.ffn_hidden_size % hsd) != 0: - raise ValueError( - f"Cannot shard {args.ffn_hidden_size} features {hsd} ways.") + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) if (esd * hsd) != world_size: raise ValueError( f"Invalid sharding. 'expert_sharding_degree' " - f"({esd}) * hidden_sharding_degree " - f"({hsd}) != world_size ({world_size}).") + f'({esd}) * hidden_sharding_degree ' + f'({hsd}) != world_size ({world_size}).', + ) return hsd -def experts_per_rank(args : Arguments) -> int: +def experts_per_rank(args: Arguments) -> int: return args.moe_num_experts // expert_sharding_degree(args) -def features_per_rank(args : Arguments) -> int: +def features_per_rank(args: Arguments) -> int: return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index e1abddf..42cfbe1 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,6 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch + from megablocks.layers import common from megablocks.layers.arguments import Arguments -import torch # NOTE: To enable end-to-end benchmarking without convergence we @@ -9,18 +13,19 @@ # so that PyTorch still executes the full set of router operation. class _UniformExpertAssignment(torch.autograd.Function): - @staticmethod def forward(ctx, x, num_experts): out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) out = torch.remainder(out, num_experts) return out.view(x.shape) + + _uniform_expert_assignment = _UniformExpertAssignment.apply class LearnedRouter(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args @@ -34,7 +39,8 @@ def __init__(self, args : Arguments): args.moe_num_experts, bias=False, dtype=common.dtype(args), - device=args.device) + device=args.device, + ) args.init_method(self.layer.weight) def jitter(self, x): @@ -45,7 +51,7 @@ def jitter(self, x): def _top_k(self, scores): if self.args.moe_top_k == 1: - return scores.max(dim=-1,keepdim=True) + return scores.max(dim=-1, keepdim=True) return torch.topk(scores, self.args.moe_top_k, dim=-1) def forward(self, x): @@ -56,10 +62,16 @@ def forward(self, x): expert_weights, expert_indices = self._top_k(scores) if self.args.moe_normalize_expert_weights: expert_weights = expert_weights / torch.norm( - expert_weights, p=self.args.moe_normalize_expert_weights,dim=-1, keepdim=True) + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) expert_indices = ( - _uniform_expert_assignment(expert_indices, self.args.moe_num_experts) - if self.args.uniform_expert_assignment else expert_indices + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices ) return scores, expert_weights, expert_indices diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 4d323ee..0f62db3 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -1,14 +1,17 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from typing import Union -from megablocks.layers import mlp -from megablocks.layers import glu -from megablocks.layers.arguments import Arguments +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, 'glu': glu.SharedGLU, } + def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: """Returns an SharedMLP for use in a dMoE instance. @@ -20,9 +23,8 @@ def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: Returns: An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: + if args.mlp_type not in _REGISTRY: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') return _REGISTRY[args.mlp_type](args) diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py index 530026e..4cd9500 100644 --- a/megablocks/layers/testing.py +++ b/megablocks/layers/testing.py @@ -1,46 +1,62 @@ -from megablocks.layers.arguments import Arguments +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch import torch.nn.functional as F +from megablocks.layers.arguments import Arguments + def allclose(x, y, pct=0.5): mask = torch.isclose(x, y, rtol=5e-2) pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) + print('{:.2f}% of values not close.'.format(pct_diff)) return False return True class FFN(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() - self.w1 = torch.nn.Parameter(torch.empty( - args.hidden_size, - args.ffn_hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) - self.w2 = torch.nn.Parameter(torch.empty( - args.ffn_hidden_size, - args.hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) + self.w1 = torch.nn.Parameter( + torch.empty( + args.hidden_size, + args.ffn_hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + args.ffn_hidden_size, + args.hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) def forward(self, x): - return torch.matmul(F.gelu( - torch.matmul(x, self.w1), approximate="tanh"), self.w2) + return torch.matmul( + F.gelu(torch.matmul(x, self.w1), approximate='tanh'), + self.w2, + ) + class GLU(FFN): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__(args) - self.v1 = torch.nn.Parameter(torch.empty( - args.hidden_size, - args.ffn_hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) + self.v1 = torch.nn.Parameter( + torch.empty( + args.hidden_size, + args.ffn_hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) def forward(self, x): - x1 = F.gelu(torch.matmul(x, self.w1), approximate="tanh") * torch.matmul(x, self.v1) + x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1) return torch.matmul(x1, self.w2) diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py index 46d5674..82effec 100644 --- a/megablocks/layers/weight_parallel.py +++ b/megablocks/layers/weight_parallel.py @@ -1,7 +1,11 @@ -from megablocks.layers import gelu +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import stk import torch +from megablocks.layers import gelu + def _gather_weights(w, group, parallel_w=None, async_op=False): """Gather the weights across the process group. @@ -22,9 +26,17 @@ def _gather_weights(w, group, parallel_w=None, async_op=False): if parallel_w is None: parallel_w = torch.empty( - n * world_size, k, device=w.device, dtype=w.dtype) + n * world_size, + k, + device=w.device, + dtype=w.dtype, + ) handle = torch.distributed.all_gather_into_tensor( - parallel_w, w, group=group, async_op=async_op) + parallel_w, + w, + group=group, + async_op=async_op, + ) return parallel_w, handle @@ -52,11 +64,17 @@ def _scaled_reduce_scatter(parallel_dw, group, dw=None, async_op=False): if dw is None: dw = torch.empty( - n // world_size, k, + n // world_size, + k, device=parallel_dw.device, - dtype=torch.float32) + dtype=torch.float32, + ) handle = torch.distributed.reduce_scatter_tensor( - dw, parallel_dw, group=group, async_op=async_op) + dw, + parallel_dw, + group=group, + async_op=async_op, + ) return dw, handle @@ -76,13 +94,15 @@ def forward(ctx, x, w, topo, group): ctx.group = group ctx.shape = topo.shape ctx.save_for_backward( - x, w, + x, + w, topo.row_indices, topo.column_indices, topo.offsets, topo.column_indices_t, topo.offsets_t, - topo.block_offsets_t) + topo.block_offsets_t, + ) # TODO(tgale): Support prefetching forward weights. parallel_w, _ = _gather_weights(w, group) @@ -104,7 +124,11 @@ def backward(ctx, grad): # Start the weight gradient reduce scatter to overlap with the # data gradient computation. handle.wait() - dw, handle = _scaled_reduce_scatter(parallel_dw, ctx.group, async_op=True) + dw, handle = _scaled_reduce_scatter( + parallel_dw, + ctx.group, + async_op=True, + ) dx = None if ctx.needs_input_grad[0]: dx = stk.ops.dsd(grad, parallel_w) @@ -125,24 +149,27 @@ def sdd_nt(a, b, topo, group): topo.offsets, topo.column_indices_t, topo.offsets_t, - topo.block_offsets_t) + topo.block_offsets_t, + ) class WeightParallelDsdNn(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd - def forward(ctx, - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - w, - group): + def forward( + ctx, + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + w, + group, + ): # [m, k] x [k, n] = [m, n] # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -161,7 +188,8 @@ def forward(ctx, column_indices_t, offsets_t, block_offsets_t, - w) + w, + ) x = stk.Matrix( shape, data, @@ -170,7 +198,8 @@ def forward(ctx, offsets, column_indices_t, offsets_t, - block_offsets_t) + block_offsets_t, + ) # TODO(tgale): Support prefetching forward weights. parallel_w, _ = _gather_weights(w, group) @@ -192,7 +221,11 @@ def backward(ctx, grad): # Start the weight gradient reduce scatter to overlap with the # data gradient computation. handle.wait() - dw, handle = _scaled_reduce_scatter(parallel_dw, ctx.group, async_op=True) + dw, handle = _scaled_reduce_scatter( + parallel_dw, + ctx.group, + async_op=True, + ) dx = None if ctx.needs_input_grad[1]: dx = stk.ops.sdd(grad, parallel_w.t(), x) @@ -215,7 +248,8 @@ def dsd_nn(a, b, group): a.offsets_t, a.block_offsets_t, b, - group) + group, + ) class MemoryOptimizedWeightParallelMLP(torch.autograd.Function): @@ -230,8 +264,7 @@ def forward(ctx, x, w1, w2, topo, group): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). @@ -254,13 +287,17 @@ def forward(ctx, x, w1, w2, topo, group): ctx.group = group ctx.shape = topo.shape ctx.save_for_backward( - x, w1, w2, sdd_out.data, + x, + w1, + w2, + sdd_out.data, topo.row_indices, topo.column_indices, topo.offsets, topo.column_indices_t, topo.offsets_t, - topo.block_offsets_t) + topo.block_offsets_t, + ) return dsd_out @staticmethod @@ -269,15 +306,12 @@ def backward(ctx, ddsd_out): x, w1, w2 = ctx.saved_tensors[:3] sdd_out = stk.Matrix(ctx.shape, *ctx.saved_tensors[3:]) - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') # Start the weight gather asynchronously to overlap with the # weight gradient computation and gelu recompute. - parallel_w2, handle = _gather_weights( - w2, ctx.group, async_op=True) + parallel_w2, handle = _gather_weights(w2, ctx.group, async_op=True) # Compute dw2 with recomputed gelu output. gelu_out = gelu.gelu(sdd_out) @@ -287,18 +321,23 @@ def backward(ctx, ddsd_out): # data gradient computation. handle.wait() dw2, handle = _scaled_reduce_scatter( - parallel_dw2, ctx.group, async_op=True) + parallel_dw2, + ctx.group, + async_op=True, + ) # Compute dgelu_out. # # NOTE: We reuse the gelu_out allocation. stk.backend.triton_kernels.sdd( - ddsd_out, parallel_w2.t(), + ddsd_out, + parallel_w2.t(), sdd_out.shape, gelu_out.data, sdd_out.offsets, sdd_out.row_indices, - sdd_out.column_indices) + sdd_out.column_indices, + ) dgelu_out = gelu_out # NOTE: Be careful to wait and only cast dw to the output dtype once @@ -311,7 +350,11 @@ def backward(ctx, ddsd_out): # # NOTE: Reuse the buffer from the w2 weight gather. parallel_w1, handle = _gather_weights( - w1, ctx.group, parallel_w2, async_op=True) + w1, + ctx.group, + parallel_w2, + async_op=True, + ) # Compute dsdd_out. # @@ -332,14 +375,18 @@ def backward(ctx, ddsd_out): dsdd_out.block_offsets_t, True, # transpose_a x, - parallel_dw2) + parallel_dw2, + ) parallel_dw1 = parallel_dw2 # Start the weight gradient reduce scatter to overlap with the # data gradient computation. handle.wait() dw1, handle = _scaled_reduce_scatter( - parallel_dw1, ctx.group, async_op=True) + parallel_dw1, + ctx.group, + async_op=True, + ) # Compute dx. # @@ -355,7 +402,8 @@ def backward(ctx, ddsd_out): dsdd_out.block_offsets_t, False, parallel_w1, - ddsd_out) + ddsd_out, + ) dx = ddsd_out # NOTE: Be careful to wait and only cast dw to the output dtype once @@ -364,4 +412,5 @@ def backward(ctx, ddsd_out): dw1 = dw1.to(w1.dtype) return dx, dw1, dw2, None, None + memory_optimized_weight_parallel_mlp = MemoryOptimizedWeightParallelMLP.apply diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index 44a2909..b9dc286 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -1,7 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from megablocks.ops.binned_gather import binned_gather from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum -from megablocks.ops.cumsum import inclusive_cumsum +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum from megablocks.ops.gather import gather from megablocks.ops.histogram import histogram from megablocks.ops.padded_gather import padded_gather @@ -13,3 +15,21 @@ from megablocks.ops.sort import sort from megablocks.ops.sum import sum from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index d3fbcf3..b3a8537 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -1,7 +1,11 @@ -from megablocks.layers.all_to_all import all_to_all -from megablocks import benchmark_util +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all + _ALL_TO_ALL_BENCHMARK = ( (8, 1024), (16, 1024), @@ -23,23 +27,26 @@ (1024 * 1024, 1024), ) + def benchmark_all_to_all(group, sl, hs): - world_size = torch.distributed.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size + world_size = torch.distributed.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() - x = torch.randn((sl, hs)).cuda().half() + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. + } - details = { - "world_size": world_size, - "message_size (B)": send_recv_sizes[0] * hs * 2, # 2B elements. - } + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - fn = lambda: all_to_all(x, send_recv_sizes, send_recv_sizes, group) - time, std = benchmark_util.benchmark_function(fn) + time, std = benchmark_util.benchmark_function(benchmark) - if torch.distributed.get_rank(group) == 0: - benchmark_util.log_benchmark("All-To-All", details, time, std) + if torch.distributed.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) if __name__ == '__main__': diff --git a/megablocks/ops/all_to_all_benchmark.sh b/megablocks/ops/all_to_all_benchmark.sh old mode 100644 new mode 100755 diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 0592a55..8a22317 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,6 +1,11 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd + # Autograd wrapper for binned_gather kernel. class BinnedGatherOp(torch.autograd.Function): @@ -19,4 +24,6 @@ def backward(ctx, grad): indices, bins = ctx.saved_tensors out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) return out, None, None, None, None + + binned_gather = BinnedGatherOp.apply diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index 453de7d..f65fbe8 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,6 +1,11 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd + # Autograd wrapper for binned_scatter kernel. class BinnedScatterOp(torch.autograd.Function): @@ -23,7 +28,13 @@ def backward(ctx, grad): grad = grad.contiguous() x, indices, weights, bins = ctx.saved_tensors out = kernels.binned_gather( - grad, indices, weights, bins, ctx.bin_size, ctx.top_k) + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[2]: @@ -32,6 +43,9 @@ def backward(ctx, grad): grad, indices, bins, - ctx.top_k) + ctx.top_k, + ) return out, None, wgrad, None, None + + binned_scatter = BinnedScatterOp.apply diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 6907f81..09b23ab 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,14 +1,19 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + # Autograd wrappers for cumsum kernels. -# # NOTE: Does not support gradients. class ExclusiveCumsumOp(torch.autograd.Function): @@ -22,8 +27,11 @@ def forward(ctx, x, dim): out = torch.empty_like(x) ops.exclusive_cumsum(x, dim, out) return out + + exclusive_cumsum = ExclusiveCumsumOp.apply + class InclusiveCumsumOp(torch.autograd.Function): @staticmethod @@ -36,4 +44,6 @@ def forward(ctx, x, dim): out = torch.empty_like(x) ops.inclusive_cumsum(x, dim, out) return out + + inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index bd8da3a..a335273 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,6 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for gather kernel. @@ -11,8 +15,7 @@ class GatherOp(torch.autograd.Function): def forward(ctx, x, indices, bin_ids, bins, top_k): ctx.save_for_backward(indices, bin_ids, bins) ctx.top_k = top_k - return kernels.gather( - x, indices, bin_ids, None, bins, top_k) + return kernels.gather(x, indices, bin_ids, None, bins, top_k) @staticmethod @custom_bwd @@ -20,7 +23,8 @@ def backward(ctx, grad): grad = grad.contiguous() indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter( - grad, indices, bin_ids, None, bins, ctx.top_k) + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) return out, None, None, None, None, None + + gather = GatherOp.apply diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index f81862b..7660e82 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -1,18 +1,25 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + # Autograd wrapper for histogram kernel. -# # NOTE: Does not support gradients. class HistogramOp(torch.autograd.Function): @staticmethod def forward(ctx, x, max_val): return ops.histogram(x, max_val) + + histogram = HistogramOp.apply diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index 9e0e930..9de8e65 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -1,10 +1,13 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import ops import numpy as np import torch +from absl.testing import parameterized +from megablocks import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), @@ -17,6 +20,7 @@ (16384, torch.int32, 256), ) + def benchmark_function(fn, iterations=10): # Run once to get rid of startup overhead. fn() @@ -34,13 +38,13 @@ def benchmark_function(fn, iterations=10): def log_benchmark(arguments, mean_t, std_t): - print("="*60) - print("Benchmark Parameters:") + print('=' * 60) + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean / std = {:.2f}ms / {:.2f}ms".format(mean_t, std_t)) - print("="*60) + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) class HistogramBenchmark(parameterized.TestCase): @@ -49,12 +53,11 @@ class HistogramBenchmark(parameterized.TestCase): def testHistogram(self, n, dtype, max_val): x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: ops.histogram(x, max_val)) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) @@ -62,12 +65,11 @@ def testHistogram(self, n, dtype, max_val): def testTorchHistogram(self, n, dtype, max_val): x = torch.randint(0, 128, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: torch.histc(x, max_val, 0, max_val-1)) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 632155c..bfa7b7c 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -1,10 +1,13 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import benchmark_util -from megablocks import ops import stk import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls @@ -12,7 +15,10 @@ # this adds. def transpose_view(x): return torch.as_strided( - x, (x.shape[1], x.shape[0]), (x.stride()[1], x.stride()[0])) + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) _MATMUL_TESTS = ( @@ -25,9 +31,9 @@ def transpose_view(x): def log_benchmark(name, arguments, time, std, flops): benchmark_util.log_benchmark(name, arguments, time, std) - print("flops = {:.2f}B".format(flops / 1e9)) - print("throughput = {:.2f}T".format(flops / 1e9 / time)) - print("="*60) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) class MatmulBenchmark(parameterized.TestCase): @@ -48,29 +54,28 @@ def build_sparse_matrix(self, x, padded_bins, fhs, ne): block_rows * blocks_per_row + 1, blocks_per_row, dtype=torch.int32, - device=x.device) + device=x.device, + ) # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, - blocking, - block_rows, - blocks_per_row) + column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) data = torch.empty( column_indices.numel(), blocking, blocking, dtype=torch.float16, - device=x.device) + device=x.device, + ) shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices( - shape, data, offsets, column_indices) - return stk.Matrix(shape, - data, - row_indices, - column_indices, - offsets) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) def build_input_matrix(self, sl, hs, ne): x = torch.randn((sl, hs)).cuda().half() @@ -96,16 +101,23 @@ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) w = transpose_view(w) - benchmark = lambda: stk.ops.sdd(x, w, topo) + def benchmark(): + return stk.ops.sdd(x, w, topo) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0::Fwd::SDD::NT", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): @@ -113,16 +125,23 @@ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - benchmark = lambda: stk.ops.dsd(topo, w) + def benchmark(): + return stk.ops.dsd(topo, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0::GradX::DSD::NN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): @@ -130,16 +149,23 @@ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) topo = topo.t() - benchmark = lambda: stk.ops.dsd(topo, x) + def benchmark(): + return stk.ops.dsd(topo, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0::GradW::DSD::TN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): @@ -147,16 +173,23 @@ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - benchmark = lambda: stk.ops.dsd(x, w) + def benchmark(): + return stk.ops.dsd(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::Fwd::DSD::NN", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): @@ -166,16 +199,23 @@ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): out = stk.ops.dsd(x, w) w = transpose_view(w) - benchmark = lambda: stk.ops.sdd(out, w, x) + def benchmark(): + return stk.ops.sdd(out, w, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::GradX::SDD::NT", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): @@ -185,16 +225,23 @@ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): out = stk.ops.dsd(x, w) x = x.t() - benchmark = lambda: stk.ops.dsd(x, out) + def benchmark(): + return stk.ops.dsd(x, out) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::GradW::DSD::TN", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): @@ -205,16 +252,23 @@ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): w = w.transpose(1, 2).contiguous() w = w.transpose(1, 2) - benchmark = lambda: torch.bmm(x, w) + def benchmark(): + return torch.bmm(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0::Fwd:DDD::NT", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): @@ -224,16 +278,23 @@ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) w = w.transpose(1, 2).contiguous() - benchmark = lambda: torch.bmm(out, w) + def benchmark(): + return torch.bmm(out, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0:GradX:DDD::NN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): @@ -243,16 +304,23 @@ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) out = out.transpose(1, 2) - benchmark = lambda: torch.bmm(out, x) + def benchmark(): + return torch.bmm(out, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0:GradW:DDD::TN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): @@ -260,16 +328,23 @@ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): x = torch.randn((ne, sl // ne, fhs)).cuda().half() w = torch.randn((ne, fhs, hs)).cuda().half() - benchmark = lambda: torch.bmm(x, w) + def benchmark(): + return torch.bmm(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::Fwd::DDD::NN", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): @@ -279,16 +354,23 @@ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): out = torch.bmm(x, w) w = torch.transpose(w, 1, 2) - benchmark = lambda: torch.bmm(out, w) + def benchmark(): + return torch.bmm(out, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::GradX::DDD::NT", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): @@ -298,16 +380,23 @@ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) x = torch.transpose(x, 1, 2) - benchmark = lambda: torch.bmm(x, out) + def benchmark(): + return torch.bmm(x, out) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::GradW::DDD::TN", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) if __name__ == '__main__': diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 3c2685f..b57a518 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,6 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for padded_gather kernel. @@ -12,7 +16,14 @@ def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): ctx.save_for_backward(indices, bin_ids, bins, padded_bins) ctx.top_k = top_k return kernels.padded_gather( - x, indices, bin_ids, None, bins, padded_bins, top_k) + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) @staticmethod @custom_bwd @@ -21,6 +32,15 @@ def backward(ctx, grad): indices, bin_ids, bins, padded_bins = ctx.saved_tensors out = kernels.padded_scatter( - grad, indices, bin_ids, None, bins, padded_bins, ctx.top_k) + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) return out, None, None, None, None, None + + padded_gather = PaddedGatherOp.apply diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 22ae923..1ca1605 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,6 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for padded_scatter kernel. @@ -11,11 +15,24 @@ class PaddedScatterOp(torch.autograd.Function): def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k): maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward( - indices, bin_ids, weights, bins, padded_bins, *maybe_x) + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) ctx.top_k = top_k ctx.x_shape = x.shape return kernels.padded_scatter( - x, indices, bin_ids, weights, bins, padded_bins, top_k) + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) @staticmethod @custom_bwd @@ -33,7 +50,8 @@ def backward(ctx, grad): weights, bins, padded_bins, - ctx.top_k) + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[3]: # need wgrad @@ -45,16 +63,26 @@ def backward(ctx, grad): bin_ids, bins, padded_bins, - ctx.top_k) + ctx.top_k, + ) return dgrad, None, None, wgrad, None, None, None, None -def padded_scatter(x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int): - return PaddedScatterOp.apply(x, indices, bin_ids, weights, bins, - padded_bins, top_k) +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index 7a7c337..81dde4e 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -1,10 +1,12 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import ops -from megablocks import benchmark_util import torch +from absl.testing import parameterized +from megablocks import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. @@ -35,18 +37,29 @@ def testPaddedScatter(self, sl, hs, ne, top_k): # Gather the data to prepare for backwards. x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - fn = lambda: ops.padded_scatter( - x, indices, bin_ids, weights, bins, padded_bins, top_k) + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) - time, std = benchmark_util.benchmark_function(fn) + time, std = benchmark_util.benchmark_function(benchmark) benchmark_util.log_benchmark( - "Padded Scatter", - {"sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, - "top_k": top_k}, + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, time, - std) + std, + ) if __name__ == '__main__': diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index fb5b7f1..837f07e 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -1,12 +1,12 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import benchmark_util -from megablocks import ops -import numpy as np -import stk import torch +from absl.testing import parameterized +from megablocks import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), @@ -22,7 +22,7 @@ (16384 * 8, 768, 16), (16384 * 8, 768, 32), (16384 * 8, 768, 64), - (16384 * 8, 768, 128) + (16384 * 8, 768, 128), ) @@ -40,14 +40,16 @@ def testBinnedGather(self, sl, hs, ne): tokens_per_expert = ops.histogram(indices, ne) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - benchmark = lambda: ops.binned_gather(x, indices, bins, ec) + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("BinnedGather", arguments, mean_t, std_t) + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testBinnedScatter(self, sl, hs, ne): @@ -62,14 +64,16 @@ def testBinnedScatter(self, sl, hs, ne): bins = ops.inclusive_cumsum(tokens_per_expert, 0) x = ops.binned_gather(x, indices, bins, ec) - benchmark = lambda: ops.binned_scatter(x, indices, bins) + def benchmark(): + return ops.binned_scatter(x, indices, bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("BinnedScatter", arguments, mean_t, std_t) + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testPaddedGather(self, sl, hs, ne): @@ -84,14 +88,16 @@ def testPaddedGather(self, sl, hs, ne): padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - benchmark = lambda: ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("PaddedGather", arguments, mean_t, std_t) + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testPaddedScatter(self, sl, hs, ne): @@ -107,32 +113,36 @@ def testPaddedScatter(self, sl, hs, ne): bins = ops.inclusive_cumsum(tokens_per_expert, 0) x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - benchmark = lambda: ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("PaddedScatter", arguments, mean_t, std_t) + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testCopy(self, sl, hs, ne): # NOTE: Capacity factor == 1. - ec = sl // ne + # ec = sl // ne # Create the data and indices. x = torch.randn((sl, hs)).cuda().half() y = x.clone() - benchmark = lambda: y.copy_(x) + def benchmark(): + return y.copy_(x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("Copy", arguments, mean_t, std_t) + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) if __name__ == '__main__': diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index d02c956..61bb04b 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,7 +1,11 @@ -import torch +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 def repeat(x, tiling): - if all([t == 1 for t in tiling]): + if all((t == 1 for t in tiling)): return x return x.repeat(*tiling) diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index 4d0cf34..b7cb9c3 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,11 +1,17 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + # Autograd wrapper for replicate kernel. class ReplicateOp(torch.autograd.Function): @@ -13,20 +19,16 @@ class ReplicateOp(torch.autograd.Function): @staticmethod def forward(ctx, x, bins, num_outputs): ctx.save_for_backward(bins) - out = torch.empty( - (x.shape[0], num_outputs), - dtype=x.dtype, - device=x.device) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) ops.replicate_forward(x, bins, out) return out @staticmethod def backward(ctx, grad): bins, = ctx.saved_tensors - out = torch.empty( - (grad.shape[0], bins.shape[0]), - dtype=grad.dtype, - device=grad.device) + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) ops.replicate_backward(grad, bins, out) return out, None, None + + replicate = ReplicateOp.apply diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index fc81d61..2c59a78 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch @@ -8,4 +11,4 @@ def round_up(x, value): # TODO(tgale): If this becomes and issue # do this in a custom kernel. We only expect # to use this on arrays of less than 1k elements. - return torch.div(x + (value - 1), value, rounding_mode="trunc") * value + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 0e91d80..33f051c 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,6 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for scatter kernel. @@ -13,8 +17,7 @@ def forward(ctx, x, indices, bin_ids, weights, bins, top_k): ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) ctx.top_k = top_k ctx.x_shape = x.shape - return kernels.scatter( - x, indices, bin_ids, weights, bins, top_k) + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) @staticmethod @custom_bwd @@ -31,7 +34,8 @@ def backward(ctx, grad): bin_ids, weights, bins, - ctx.top_k) + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[3]: # need wgrad @@ -42,14 +46,17 @@ def backward(ctx, grad): indices, bin_ids, bins, - ctx.top_k) + ctx.top_k, + ) return dgrad, None, None, wgrad, None, None, None -def scatter(x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int): +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +): return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index a4bb99f..12ec8f3 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,12 +1,16 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops - +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e _BITS_FOR_DTYPE = { torch.int16: 16, @@ -14,8 +18,8 @@ torch.int64: 64, } + # Autograd wrapper for sort kernel. -# # NOTE: Does not support gradients. class SortOp(torch.autograd.Function): @@ -27,4 +31,6 @@ def forward(ctx, x, end_bit=None): iota_out = torch.empty_like(x) ops.sort(x, end_bit, x_out, iota_out) return (x_out, iota_out) + + sort = SortOp.apply diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 4305767..f28e3f2 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -1,10 +1,13 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import ops import numpy as np import torch +from absl.testing import parameterized +from megablocks import ops _SORT_TESTS = ( (16384, torch.int32, None), @@ -12,16 +15,14 @@ (16384, torch.int32, 128), ) -_BASELINE_SORT_TESTS = ( - (16384,), -) +_BASELINE_SORT_TESTS = ((16384,),) def numpy_dtype(dtype): types = { torch.int16: np.int16, torch.int32: np.int32, - torch.int64: np.int64 + torch.int64: np.int64, } return types[dtype] @@ -43,13 +44,13 @@ def benchmark_function(fn, iterations=10): def log_benchmark(arguments, mean_t, std_t): - print("="*60) - print("Benchmark Parameters:") + print('=' * 60) + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean / std = {:.2f}ms / {:.2f}ms".format(mean_t, std_t)) - print("="*60) + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) class SortBenchmark(parameterized.TestCase): @@ -61,12 +62,11 @@ def testSort(self, n, dtype, max_val): end_bit = int(np.ceil(np.log2(max_val))) x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: ops.sort(x, end_bit)) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) @@ -74,9 +74,10 @@ def testSort(self, n, dtype, max_val): def testTorchSort(self, n): x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: torch.sort(x)) - arguments = {"n": n,} + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } log_benchmark(arguments, mean_t, std_t) diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index 9d550b5..aa81334 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,4 +1,5 @@ -import torch +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 def sum(x, dim=0): diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index 7ce31bc..ba4ade0 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,30 +1,43 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + # Autograd wrapper for topology kernel. -# # NOTE: Does not support gradients. class TopologyOp(torch.autograd.Function): @staticmethod - def forward(ctx, - padded_bins, - block_size, - output_block_rows, - output_block_columns): - out = torch.empty(output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device) - ops.indices(padded_bins, - block_size, - output_block_rows, - output_block_columns, - out) + def forward( + ctx, + padded_bins, + block_size, + output_block_rows, + output_block_columns, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) return out + + topology = TopologyOp.apply diff --git a/pyproject.toml b/pyproject.toml index b4b90ec..c72dbdf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,6 @@ +# Copyright 2024 MosaicML MegaBlocks authors +# SPDX-License-Identifier: Apache-2.0 + # build requirements [build-system] requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4"] @@ -44,3 +47,483 @@ concurrency = ["thread"] include = [ "megablocks/*" ] + + +# Ruff global +[tool.ruff] + +preview = true # enable preview features, see https://docs.astral.sh/ruff/preview/ + +exclude = [ + "build/**", + "docs/**", + "node_modules/**", +] + +# Ruff linter +[tool.ruff.lint] +select = [ + "C4", # flake8-comprehensions + # TODO port pydocstyle + # "D", # pydocstyle + "LOG", # flake8-logging + "PERF", # Perflint + "PLE", + "COM812", # missing-trailing-comma +] + +# iSort +[tool.isort] +multi_line_output = 0 +line_length = 120 +skip = ["env", "wandb", "runs", "build", "node_modules" ] +include_trailing_comma = true +split_on_trailing_comma = true + +# Yapf +[tool.yapf] + +# Align closing bracket with visual indentation. +align_closing_bracket_with_visual_indent = false + +# Allow dictionary keys to exist on multiple lines. For example: +# +# x = { +# ('this is the first element of a tuple', +# 'this is the second element of a tuple'): +# value, +# } +allow_multiline_dictionary_keys = false + +# Allow lambdas to be formatted on more than one line. +allow_multiline_lambdas = false + +# Allow splitting before a default / named assignment in an argument list. +allow_split_before_default_or_named_assigns = true + +# Allow splits before the dictionary value. +allow_split_before_dict_value = true + +# Let spacing indicate operator precedence. For example: +# +# a = 1 * 2 + 3 / 4 +# b = 1 / 2 - 3 * 4 +# c = (1 + 2) * (3 - 4) +# d = (1 - 2) / (3 + 4) +# e = 1 * 2 - 3 +# f = 1 + 2 + 3 + 4 +# +# will be formatted as follows to indicate precedence: +# +# a = 1*2 + 3/4 +# b = 1/2 - 3*4 +# c = (1+2) * (3-4) +# d = (1-2) / (3+4) +# e = 1*2 - 3 +# f = 1 + 2 + 3 + 4 +# +arithmetic_precedence_indication = false + +# Number of blank lines surrounding top-level function and class +# definitions. +blank_lines_around_top_level_definition = 2 + +# Insert a blank line before a class-level docstring. +blank_line_before_class_docstring = false + +# Insert a blank line before a module docstring. +blank_line_before_module_docstring = true + +# Insert a blank line before a 'def' or 'class' immediately nested +# within another 'def' or 'class'. For example: +# +# class Foo: +# # <------ this blank line +# def method(): +# ... +blank_line_before_nested_class_or_def = true + +# Do not split consecutive brackets. Only relevant when +# dedent_closing_brackets is set. For example: +# +# call_func_that_takes_a_dict( +# { +# 'key1': 'value1', +# 'key2': 'value2', +# } +# ) +# +# would reformat to: +# +# call_func_that_takes_a_dict({ +# 'key1': 'value1', +# 'key2': 'value2', +# }) +coalesce_brackets = true + +# The column limit. +column_limit = 120 + +# The style for continuation alignment. Possible values are: +# +# - SPACE: Use spaces for continuation alignment. This is default behavior. +# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns +# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or +# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment. +# - VALIGN-RIGHT: Vertically align continuation lines to multiple of +# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if +# cannot vertically align continuation lines with indent characters. +continuation_align_style = 'SPACE' + +# Indent width used for line continuations. +continuation_indent_width = 4 + +# Put closing brackets on a separate line, dedented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is dedented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is dedented and on a separate line +dedent_closing_brackets = true + +# Disable the heuristic which places each list element on a separate line +# if the list is comma-terminated. +disable_ending_comma_heuristic = false + +# Place each dictionary entry onto its own line. +each_dict_entry_on_separate_line = true + +# Require multiline dictionary even if it would normally fit on one line. +# For example: +# +# config = { +# 'key1': 'value1' +# } +force_multiline_dict = false + +# The regex for an i18n comment. The presence of this comment stops +# reformatting of that line, because the comments are required to be +# next to the string they translate. +i18n_comment = '#\..*' + +# The i18n function call names. The presence of this function stops +# reformattting on that line, because the string it has cannot be moved +# away from the i18n comment. +i18n_function_call = 'N_, _' + +# Indent blank lines. +indent_blank_lines = false + +# Put closing brackets on a separate line, indented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is indented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is indented and on a separate line +indent_closing_brackets = false + +# Indent the dictionary value if it cannot fit on the same line as the +# dictionary key. For example: +# +# config = { +# 'key1': +# 'value1', +# 'key2': value1 + +# value2, +# } +indent_dictionary_value = true + +# The number of columns to use for indentation. +indent_width = 4 + +# Join short lines into one line. E.g., single line 'if' statements. +join_multiple_lines = false + +# Do not include spaces around selected binary operators. For example: +# +# 1 + 2 * 3 - 4 / 5 +# +# will be formatted as follows when configured with "*,/": +# +# 1 + 2*3 - 4/5 +no_spaces_around_selected_binary_operators = '' + +# Use spaces around default or named assigns. +spaces_around_default_or_named_assign = false + +# Adds a space after the opening '{' and before the ending '}' dict delimiters. +# +# {1: 2} +# +# will be formatted as: +# +# { 1: 2 } +spaces_around_dict_delimiters = false + +# Adds a space after the opening '[' and before the ending ']' list delimiters. +# +# [1, 2] +# +# will be formatted as: +# +# [ 1, 2 ] +spaces_around_list_delimiters = false + +# Use spaces around the power operator. +spaces_around_power_operator = false + +# Use spaces around the subscript / slice operator. For example: +# +# my_list[1 : 10 : 2] +spaces_around_subscript_colon = false + +# Adds a space after the opening '(' and before the ending ')' tuple delimiters. +# +# (1, 2, 3) +# +# will be formatted as: +# +# ( 1, 2, 3 ) +spaces_around_tuple_delimiters = false + +# The number of spaces required before a trailing comment. +# This can be a single value (representing the number of spaces +# before each trailing comment) or list of values (representing +# alignment column values; trailing comments within a block will +# be aligned to the first column value that is greater than the maximum +# line length within the block). For example: +# +# With spaces_before_comment=5: +# +# 1 + 1 # Adding values +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment +# +# With spaces_before_comment = '15, 20:' +# +# 1 + 1 # Adding values +# two + two # More adding +# +# longer_statement # This is a longer statement +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment +# short # This is a shorter statement +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15 +# two + two # More adding +# +# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20 +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length +# short # This is a shorter statement +# +spaces_before_comment = 2 + +# Insert a space between the ending comma and closing bracket of a list, +# etc. +space_between_ending_comma_and_closing_bracket = false + +# Use spaces inside brackets, braces, and parentheses. For example: +# +# method_call( 1 ) +# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ] +# my_set = { 1, 2, 3 } +space_inside_brackets = false + +# Split before arguments +split_all_comma_separated_values = false + +# Split before arguments, but do not split all subexpressions recursively +# (unless needed). +split_all_top_level_comma_separated_values = false + +# Split before arguments if the argument list is terminated by a +# comma. +split_arguments_when_comma_terminated = true + +# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@' +# rather than after. +split_before_arithmetic_operator = false + +# Set to True to prefer splitting before '&', '|' or '^' rather than +# after. +split_before_bitwise_operator = false + +# Split before the closing bracket if a list or dict literal doesn't fit on +# a single line. +split_before_closing_bracket = true + +# Split before a dictionary or set generator (comp_for). For example, note +# the split before the 'for': +# +# foo = { +# variable: 'Hello world, have a nice day!' +# for variable in bar if variable != 42 +# } +split_before_dict_set_generator = false + +# Split before the '.' if we need to split a longer expression: +# +# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d)) +# +# would reformat to something like: +# +# foo = ('This is a really long string: {}, {}, {}, {}' +# .format(a, b, c, d)) +split_before_dot = false + +# Split after the opening paren which surrounds an expression if it doesn't +# fit on a single line. +split_before_expression_after_opening_paren = false + +# If an argument / parameter list is going to be split, then split before +# the first argument. +split_before_first_argument = false + +# Set to True to prefer splitting before 'and' or 'or' rather than +# after. +split_before_logical_operator = false + +# Split named assignments onto individual lines. +split_before_named_assigns = true + +# Set to True to split list comprehensions and generators that have +# non-trivial expressions and multiple clauses before each of these +# clauses. For example: +# +# result = [ +# a_long_var + 100 for a_long_var in xrange(1000) +# if a_long_var % 10] +# +# would reformat to something like: +# +# result = [ +# a_long_var + 100 +# for a_long_var in xrange(1000) +# if a_long_var % 10] +split_complex_comprehension = true + +# The penalty for splitting right after the opening bracket. +split_penalty_after_opening_bracket = 300 + +# The penalty for splitting the line after a unary operator. +split_penalty_after_unary_operator = 10000 + +# The penalty of splitting the line around the '+', '-', '*', '/', '//', +# ``%``, and '@' operators. +split_penalty_arithmetic_operator = 300 + +# The penalty for splitting right before an if expression. +split_penalty_before_if_expr = 0 + +# The penalty of splitting the line around the '&', '|', and '^' +# operators. +split_penalty_bitwise_operator = 300 + +# The penalty for splitting a list comprehension or generator +# expression. +split_penalty_comprehension = 2100 + +# The penalty for characters over the column limit. +split_penalty_excess_character = 7000 + +# The penalty incurred by adding a line split to the unwrapped line. The +# more line splits added the higher the penalty. +split_penalty_for_added_line_split = 20 + +# The penalty of splitting a list of "import as" names. For example: +# +# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, +# long_argument_2, +# long_argument_3) +# +# would reformat to something like: +# +# from a_very_long_or_indented_module_name_yada_yad import ( +# long_argument_1, long_argument_2, long_argument_3) +split_penalty_import_names = 0 + +# The penalty of splitting the line around the 'and' and 'or' +# operators. +split_penalty_logical_operator = 300 + +# Use the Tab character for indentation. +use_tabs = false + +# Ignore directories +[tool.yapfignore] +ignore_patterns = [ + "runs/**/*.py", + "wandb/**/*.py", + "build/**/*.py", +] + +# PyDocStyle +[tool.pydocstyle] +convention="google" +add_ignore="D100,D101,D102,D103,D104,D105,D107,D400,D401,D415" +add_select="D404" + + +# Pyright +[tool.pyright] +exclude = ['env-**', 'venv*', '.venv', 'tests/*', '**benchmark'] +stubPath = "" # suppress useless 'stubPath is not a valid directory' errors + +reportUnnecessaryIsInstance = "none" # it is ok to do this for clarity or safety +reportMissingTypeStubs = "none" +reportIncompatibleMethodOverride = "none" +reportIncompatibleVariableOverride = "error" +reportUnusedImport = "error" +reportUnusedClass = "warning" +reportUnusedFunction = "warning" +reportUnusedVariable = "error" +reportDuplicateImport = "error" +reportWildcardImportFromLibrary = "error" +reportUntypedFunctionDecorator = "warning" +reportPrivateImportUsage = "none" +reportUndefinedVariable = "error" +strictParameterNoneValue = true +reportPropertyTypeMismatch = "error" +reportUntypedNamedTuple = "error" +reportUnnecessaryCast = "error" +reportInvalidTypeVarUse = "error" +reportOverlappingOverload = "error" +reportUninitializedInstanceVariable = "error" +reportInvalidStringEscapeSequence = "error" +reportMissingParameterType = "error" +reportCallInDefaultInitializer = "error" +reportUnnecessaryComparison = "error" +reportSelfClsParameterName = "error" +reportImplicitStringConcatenation = "warning" # TODO: make this an error +reportInvalidStubStatement = "error" +reportIncompleteStub = "error" +reportUnsupportedDunderAll = "error" +reportUnusedCoroutine = "error" +reportMissingImports = "none" diff --git a/setup.py b/setup.py index a247b0d..fa15ee4 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 """MegaBlocks package setup.""" @@ -9,18 +9,13 @@ from setuptools import find_packages, setup - # We require torch in setup.py to build cpp extensions "ahead of time" # More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html try: import torch - from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, - CUDAExtension,) + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "No module named 'torch'. `torch` is required to install `MegaBlocks`." - ) from e - + raise ModuleNotFoundError("No module named 'torch'. `torch` is required to install `MegaBlocks`.",) from e _PACKAGE_NAME = 'megablocks' _PACKAGE_DIR = 'megablocks' @@ -37,7 +32,6 @@ exec(content, version_globals, version_locals) repo_version = version_locals['__version__'] - with open('README.md', 'r', encoding='utf-8') as fh: long_description = fh.read() @@ -56,7 +50,6 @@ long_description = long_description[:start] + \ long_description[end + len(end_tag):] - classifiers = [ 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.9', @@ -76,12 +69,12 @@ extra_deps = {} -extra_deps["gg"] = [ +extra_deps['gg'] = [ 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb', ] extra_deps['dev'] = [ - 'absl-py', # todo: delete when finish removing all absl tests + 'absl-py', # TODO: delete when finish removing all absl tests 'coverage[toml]==7.4.4', 'pytest_codeblocks>=0.16.1,<0.17', 'pytest-cov>=4,<5', @@ -93,11 +86,7 @@ 'mosaicml>=0.22.0', ] -extra_deps['all'] = list({ - dep for key, deps in extra_deps.items() for dep in deps - if key not in {'testing'} -}) - +extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) cmdclass = {} ext_modules = [] @@ -116,9 +105,7 @@ device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}' if device_capability: - nvcc_flags.append( - f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}' - ) + nvcc_flags.append(f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}',) ext_modules = [ CUDAExtension( @@ -127,19 +114,19 @@ include_dirs=['csrc'], extra_compile_args={ 'cxx': ['-fopenmp'], - 'nvcc': nvcc_flags + 'nvcc': nvcc_flags, }, - ) + ), ] elif CUDA_HOME is None: warnings.warn( 'Attempted to install CUDA extensions, but CUDA_HOME was None. ' + 'Please install CUDA and ensure that the CUDA_HOME environment ' + - 'variable points to the installation location.') + 'variable points to the installation location.', + ) else: warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.') - setup( name=_PACKAGE_NAME, version=repo_version, diff --git a/tests/conftest.py b/tests/conftest.py index 335140c..663bda3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import os from typing import List, Optional @@ -13,7 +16,7 @@ # Add the path of any pytest fixture files you want to make global pytest_plugins = [ 'tests.fixtures.autouse', - 'tests.fixtures.fixtures' + 'tests.fixtures.fixtures', ] @@ -23,8 +26,11 @@ def _get_world_size(item: pytest.Item): return item.get_closest_marker('world_size', default=_default).args[0] - -def _get_option(config: pytest.Config, name: str, default: Optional[str] = None) -> str: # type: ignore +def _get_option( + config: pytest.Config, + name: str, + default: Optional[str] = None, +) -> str: # type: ignore val = config.getoption(name) if val is not None: assert isinstance(val, str) @@ -34,13 +40,18 @@ def _get_option(config: pytest.Config, name: str, default: Optional[str] = None) val = None if val is None: if default is None: - pytest.fail(f'Config option {name} is not specified but is required') + pytest.fail(f'Config option {name} is not specified but is required',) val = default assert isinstance(val, str) return val -def _add_option(parser: pytest.Parser, name: str, help: str, choices: Optional[list[str]] = None): +def _add_option( + parser: pytest.Parser, + name: str, + help: str, + choices: Optional[list[str]] = None, +): parser.addoption( f'--{name}', default=None, diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 29fbdeb..6805f3c 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import gc import logging import os @@ -20,7 +23,7 @@ def clear_cuda_cache(request: pytest.FixtureRequest): @pytest.fixture(autouse=True) def reset_mlflow_tracking_dir(): - """Reset MLFlow tracking dir so it doesn't persist across tests""" + """Reset MLFlow tracking dir so it doesn't persist across tests.""" try: import mlflow mlflow.set_tracking_uri(None) # type: ignore @@ -72,9 +75,16 @@ def set_log_levels(): @pytest.fixture(autouse=True) def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch): - """Monkeypatch reproducibility get_random_seed to always return the rank zero seed, and set the random seed before - each test to the rank local seed.""" - monkeypatch.setattr(reproducibility, 'get_random_seed', lambda: rank_zero_seed) + """Monkeypatch reproducibility. + + Make get_random_seed to always return the rank zero seed, and set the random seed before each test to the rank local + seed. + """ + monkeypatch.setattr( + reproducibility, + 'get_random_seed', + lambda: rank_zero_seed, + ) reproducibility.seed_all(rank_zero_seed + dist.get_global_rank()) diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index 48645a8..4039db7 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -1,4 +1,8 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import pytest + from tests.conftest import _get_option diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py index 3ead862..a737ef4 100644 --- a/tests/layers/dmoe_test.py +++ b/tests/layers/dmoe_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from functools import partial @@ -28,13 +28,10 @@ (16, 1024, 128, 1, 1), ) -_FORWARD_TESTS_GROUPED_MLP = tuple([ - p + ('grouped',) for p in _FORWARD_TESTS_DEFAULT -]) if gg.grouped_gemm_is_available() else () +_FORWARD_TESTS_GROUPED_MLP = tuple([p + ('grouped',) for p in _FORWARD_TESTS_DEFAULT + ],) if gg.grouped_gemm_is_available() else () -_FORWARD_TESTS_SPARSE_MLP = tuple([ - p + ('sparse',) for p in _FORWARD_TESTS_DEFAULT -]) +_FORWARD_TESTS_SPARSE_MLP = tuple([p + ('sparse',) for p in _FORWARD_TESTS_DEFAULT]) _FORWARD_TESTS = (_FORWARD_TESTS_SPARSE_MLP + _FORWARD_TESTS_GROUPED_MLP) @@ -44,24 +41,28 @@ ) -def construct_moes(hidden_size: int, - ffn_hidden_size: int, - moe_num_experts: int = 1, - moe_capacity_factor: int = 1, - moe_top_k: int = 1, - mlp_impl: str = 'sparse'): +def construct_moes( + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int = 1, + moe_capacity_factor: int = 1, + moe_top_k: int = 1, + mlp_impl: str = 'sparse', +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) - args = Arguments(hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=moe_num_experts, - moe_capacity_factor=moe_capacity_factor, - moe_top_k=moe_top_k, - init_method=init_method, - memory_optimized_mlp=True, - mlp_type='mlp', - mlp_impl=mlp_impl, - fp16=False, - bf16=True) + args = Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=moe_num_experts, + moe_capacity_factor=moe_capacity_factor, + moe_top_k=moe_top_k, + init_method=init_method, + memory_optimized_mlp=True, + mlp_type='mlp', + mlp_impl=mlp_impl, + fp16=False, + bf16=True, + ) mlp = testing.FFN(args) moe_mlp = moe.MoE(args) @@ -76,8 +77,7 @@ def construct_moes(hidden_size: int, ne, hs, fhs = moe_mlp.experts.mlp.w1.size() w1 = dmoe_mlp.experts.mlp.w1.view([ne, fhs, hs]) moe_mlp.experts.mlp.w1.copy_(torch.transpose(w1, 1, 2).contiguous()) - moe_mlp.experts.mlp.w2.copy_(dmoe_mlp.experts.mlp.w2.view([ne, fhs, - hs])) + moe_mlp.experts.mlp.w2.copy_(dmoe_mlp.experts.mlp.w2.view([ne, fhs, hs]),) moe_mlp.router.layer.weight.copy_(dmoe_mlp.router.layer.weight) if moe_num_experts == 1: mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze()) @@ -86,20 +86,23 @@ def construct_moes(hidden_size: int, @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) -def test_dmoe_forward(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) +def test_dmoe_forward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, _, _, layer = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=num_experts, - moe_top_k=top_k, - mlp_impl=mlp_impl) + _, _, _, layer = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=num_experts, + moe_top_k=top_k, + mlp_impl=mlp_impl, + ) out, _ = layer(x) assert out.shape == x.shape @@ -107,22 +110,25 @@ def test_dmoe_forward(bs: int, @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) -def test_dmoe_forward_backward(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) +def test_dmoe_forward_backward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() x.requires_grad_(True) - args, _, _, layer = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=num_experts, - moe_top_k=top_k, - mlp_impl=mlp_impl) + args, _, _, layer = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=num_experts, + moe_top_k=top_k, + mlp_impl=mlp_impl, + ) out, _ = layer(x) assert out.shape == x.shape @@ -136,18 +142,22 @@ def test_dmoe_forward_backward(bs: int, @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS) -def test_dmoe_forward_vs_baseline(bs: int, - sl: int, - hs: int, - mlp_impl: str = 'sparse'): +def test_dmoe_forward_vs_baseline( + bs: int, + sl: int, + hs: int, + mlp_impl: str = 'sparse', +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, mlp, _, dmoe_mlp = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=1, - moe_capacity_factor=1, - moe_top_k=1, - mlp_impl=mlp_impl) + _, mlp, _, dmoe_mlp = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=1, + moe_capacity_factor=1, + moe_top_k=1, + mlp_impl=mlp_impl, + ) expected_out = mlp(x) out, _ = dmoe_mlp(x) @@ -156,23 +166,26 @@ def test_dmoe_forward_vs_baseline(bs: int, @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) -def test_dmoe_forward_vs_moe(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) +def test_dmoe_forward_vs_moe( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): torch.manual_seed(42) x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, _, moe_mlp, dmoe_mlp = construct_moes(hidden_size=hs, - ffn_hidden_size=hs, - moe_num_experts=num_experts, - moe_capacity_factor=0, - mlp_impl=mlp_impl) + _, _, moe_mlp, dmoe_mlp = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs, + moe_num_experts=num_experts, + moe_capacity_factor=0, + mlp_impl=mlp_impl, + ) expected_out, _ = moe_mlp(x) out, _ = dmoe_mlp(x) diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index 0487ec8..d89af89 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from functools import partial import pytest @@ -6,7 +9,6 @@ from megablocks.layers import dmlp_registry, testing from megablocks.layers.arguments import Arguments -from megablocks.layers.glu import GroupedGLU, SparseGLU _DENSE_TESTS = ( (16, 1024, 512), @@ -15,10 +17,11 @@ def construct_dmoe_glu( - hidden_size: int, - ffn_hidden_size: int, - mlp_impl: str ='sparse', - memory_optimized_mlp: bool =False): + hidden_size: int, + ffn_hidden_size: int, + mlp_impl: str = 'sparse', + memory_optimized_mlp: bool = False, +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) args = Arguments( hidden_size=hidden_size, @@ -30,7 +33,8 @@ def construct_dmoe_glu( mlp_type='glu', mlp_impl=mlp_impl, fp16=False, - bf16=True) + bf16=True, + ) glu = testing.GLU(args) dmoe_glu = dmlp_registry.get(args) @@ -46,7 +50,6 @@ def construct_dmoe_glu( return args, glu, dmoe_glu - @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS) def test_glu_forward_grouped_mlp(bs: int, sl: int, hs: int): @@ -55,7 +58,8 @@ def test_glu_forward_grouped_mlp(bs: int, sl: int, hs: int): _, glu, dmoe_glu = construct_dmoe_glu( hidden_size=hs, ffn_hidden_size=hs * 2, - mlp_impl='grouped') + mlp_impl='grouped', + ) expected_out = glu(x) tokens_per_expert = torch.tensor([bs * sl]).cuda() @@ -75,7 +79,8 @@ def test_glu_forward_grouped_mlp_mem_opt(bs: int, sl: int, hs: int): hidden_size=hs, ffn_hidden_size=hs * 2, mlp_impl='grouped', - memory_optimized_mlp=True) + memory_optimized_mlp=True, + ) expected_out = glu(x) tokens_per_expert = torch.tensor([bs * sl]).cuda() @@ -94,7 +99,8 @@ def test_glu_forward_sparse_mlp(bs: int, sl: int, hs: int): _, glu, dmoe_glu = construct_dmoe_glu( hidden_size=hs, ffn_hidden_size=hs * 2, - mlp_impl='sparse') + mlp_impl='sparse', + ) expected_out = glu(x) with torch.no_grad(): diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index 75ea196..dd40ef9 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from functools import partial import pytest @@ -22,7 +25,6 @@ (16, 1024, 512, 8, 8), ) - _DENSE_TESTS = ( (16, 1024, 512), (8, 2048, 512), @@ -30,11 +32,12 @@ def construct_moe( - hidden_size, - ffn_hidden_size, - moe_num_experts=1, - moe_capacity_factor=1, - moe_top_k=1): + hidden_size, + ffn_hidden_size, + moe_num_experts=1, + moe_capacity_factor=1, + moe_top_k=1, +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) args = Arguments( hidden_size=hidden_size, @@ -42,7 +45,8 @@ def construct_moe( moe_num_experts=moe_num_experts, moe_capacity_factor=moe_capacity_factor, moe_top_k=moe_top_k, - init_method=init_method) + init_method=init_method, + ) mlp = testing.FFN(args) moe_mlp = moe.MoE(args) @@ -59,8 +63,7 @@ def construct_moe( @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), - _FORWARD_TESTS) +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS) def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): x = torch.randn(sl, bs, hs).half().cuda() @@ -68,16 +71,23 @@ def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): hidden_size=hs, ffn_hidden_size=hs * 2, moe_num_experts=num_experts, - moe_top_k=top_k) + moe_top_k=top_k, + ) out, _ = layer(x) assert out.shape == x.shape moe.clear_load_balancing_loss() + @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), - _FORWARD_TESTS) -def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS) +def test_moe_forward_backward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, +): x = torch.randn(sl, bs, hs).half().cuda() x.requires_grad_(True) @@ -85,7 +95,8 @@ def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k hidden_size=hs, ffn_hidden_size=hs * 2, moe_num_experts=num_experts, - moe_top_k=top_k) + moe_top_k=top_k, + ) out, _ = layer(x) assert out.shape == x.shape @@ -102,9 +113,7 @@ def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k def test_moe_forward_vs_dense(bs: int, sl: int, hs: int): x = torch.randn(sl, bs, hs).half().cuda() - _, mlp, moe_mlp = construct_moe( - hidden_size=hs, - ffn_hidden_size=hs * 2) + _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2) expected_out = mlp(x) out, _ = moe_mlp(x) @@ -119,9 +128,7 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): x = torch.randn(sl, bs, hs).half().cuda() x.requires_grad_(True) - _, mlp, moe_mlp = construct_moe( - hidden_size=hs, - ffn_hidden_size=hs * 2) + _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2) out, _ = moe_mlp(x) loss = out.sum() @@ -141,7 +148,7 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): x.grad = None # Verify the gradients match. - assert w1_grad.shape == expected_w1_grad.shape + assert w1_grad.shape == expected_w1_grad.shape assert w2_grad.shape == expected_w2_grad.shape assert torch.allclose(w1_grad, expected_w1_grad) assert torch.allclose(w2_grad, expected_w2_grad) diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py index 0aa4269..35e40a0 100644 --- a/tests/layers/parallelism_test.py +++ b/tests/layers/parallelism_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import functools import numpy as np @@ -15,6 +18,7 @@ (4, 1, 512, 2048, 4, 1, True), ) + # Todo: Fix this long term @pytest.fixture def group(): @@ -23,17 +27,25 @@ def group(): @pytest.mark.world_size(2) @pytest.mark.gpu -@pytest.mark.parametrize(('batch_size', 'sequence_length', 'hidden_size', 'ffn_hidden_size', 'num_experts', 'top_k', 'memory_optimized'), - _PARALLELISM_TESTS) +@pytest.mark.parametrize(( + 'batch_size', + 'sequence_length', + 'hidden_size', + 'ffn_hidden_size', + 'num_experts', + 'top_k', + 'memory_optimized', +), _PARALLELISM_TESTS) def test_expert_parallel_versus_weight_parallel( - group, - batch_size: int, - sequence_length: int, - hidden_size: int, - ffn_hidden_size: int, - num_experts: int, - top_k: int, - memory_optimized: bool): + group, + batch_size: int, + sequence_length: int, + hidden_size: int, + ffn_hidden_size: int, + num_experts: int, + top_k: int, + memory_optimized: bool, +): init_fn = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1) ep_args = arguments.Arguments( @@ -47,7 +59,8 @@ def test_expert_parallel_versus_weight_parallel( bf16=False, device=torch.cuda.current_device(), init_method=init_fn, - memory_optimized_mlp=memory_optimized) + memory_optimized_mlp=memory_optimized, + ) wp_args = arguments.Arguments( hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, @@ -59,7 +72,8 @@ def test_expert_parallel_versus_weight_parallel( bf16=False, device=torch.cuda.current_device(), init_method=init_fn, - memory_optimized_mlp=memory_optimized) + memory_optimized_mlp=memory_optimized, + ) # NOTE: Reset the seed so that the models get identical weights. torch.manual_seed(1234) @@ -70,10 +84,8 @@ def test_expert_parallel_versus_weight_parallel( # NOTE: Include the rank in the seed so we get different data per rank. rank = torch.distributed.get_rank(group) torch.manual_seed(1234 * rank) - x = torch.randn( - (batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.float32).requires_grad_(True) + x = torch.randn((batch_size, sequence_length, hidden_size), device=torch.cuda.current_device(), + dtype=torch.float32).requires_grad_(True) # Test forward. out, _ = wp(x) @@ -86,7 +98,9 @@ def test_expert_parallel_versus_weight_parallel( assert np.testing.assert_allclose( out.detach().float().cpu(), expected_out.detach().float().cpu(), - rtol=1e-4, atol=1e-4) is None + rtol=1e-4, + atol=1e-4, + ) is None # Test backward. out.mean().backward() @@ -97,8 +111,7 @@ def test_expert_parallel_versus_weight_parallel( def gather(x): m, n = x.shape world_size = torch.distributed.get_world_size(group) - out = torch.empty( - m * world_size, n, device=x.device, dtype=x.dtype) + out = torch.empty(m * world_size, n, device=x.device, dtype=x.dtype) torch.distributed.all_gather_into_tensor(out, x, group=group) return out @@ -114,7 +127,9 @@ def permute(x): assert np.testing.assert_allclose( wp_w2_grad.float().cpu(), ep_w2_grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None + rtol=1e-5, + atol=1e-5, + ) is None wp_w1_grad = gather(wp.experts.mlp.w1.grad) ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad)) @@ -122,7 +137,9 @@ def permute(x): assert np.testing.assert_allclose( wp_w1_grad.float().cpu(), ep_w1_grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None + rtol=1e-5, + atol=1e-5, + ) is None # Verify the router weight gradient, which is not sharded. for i in range(torch.distributed.get_world_size(group)): @@ -131,4 +148,6 @@ def permute(x): assert np.testing.assert_allclose( wp.router.layer.weight.grad.float().cpu(), ep.router.layer.weight.grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None + rtol=1e-5, + atol=1e-5, + ) is None diff --git a/tests/ops/binned_gather_test.py b/tests/ops/binned_gather_test.py index cc59ae3..c165086 100644 --- a/tests/ops/binned_gather_test.py +++ b/tests/ops/binned_gather_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -46,8 +46,13 @@ def test_binned_gather(sl: int, hs: int, ne: int, top_k: int): _, indices = ops.sort(top_expert) bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0) - def binned_gather(x: torch.Tensor, indices: torch.Tensor, - bins: torch.Tensor, ec: int, top_k: int): + def binned_gather( + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + ec: int, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() bins = bins.cpu().numpy() diff --git a/tests/ops/binned_scatter_test.py b/tests/ops/binned_scatter_test.py index 2d1c585..b725700 100644 --- a/tests/ops/binned_scatter_test.py +++ b/tests/ops/binned_scatter_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import pytest import torch @@ -48,8 +51,13 @@ def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int): x = ops.binned_gather(x, indices, bins, ec, top_k) - def binned_scatter(x: torch.Tensor, indices: torch.Tensor, - weights: torch.Tensor, bins: torch.Tensor, top_k: int): + def binned_scatter( + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() weights = weights.cpu().numpy() @@ -66,10 +74,14 @@ def binned_scatter(x: torch.Tensor, indices: torch.Tensor, out[index, :] += scale * x[i, j, :] start = end return torch.from_numpy(out).cuda().half() + out = ops.binned_scatter(x, indices, weights, bins, top_k) expected_out = binned_scatter(x, indices, weights, bins, top_k) # NOTE: We need to check approximate equality because the # scatter reduce uses atomics. assert np.testing.assert_allclose( - out.cpu(), expected_out.cpu(), rtol=5e-3) is None + out.cpu(), + expected_out.cpu(), + rtol=5e-3, + ) is None diff --git a/tests/ops/cumsum_test.py b/tests/ops/cumsum_test.py index a1b7160..5d8b082 100644 --- a/tests/ops/cumsum_test.py +++ b/tests/ops/cumsum_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import pytest diff --git a/tests/ops/histogram_test.py b/tests/ops/histogram_test.py index 25b30cb..d6d3f23 100644 --- a/tests/ops/histogram_test.py +++ b/tests/ops/histogram_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import pytest @@ -78,6 +78,5 @@ def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int): x = torch.randint(0, max_val, (m, n)).cuda().to(dtype) out = ops.histogram(x, max_val) - expected_out = torch.stack( - [torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)]) + expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)]) assert torch.all(torch.eq(out, expected_out)) diff --git a/tests/ops/padded_gather_test.py b/tests/ops/padded_gather_test.py index e6eb7f7..7198099 100644 --- a/tests/ops/padded_gather_test.py +++ b/tests/ops/padded_gather_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -63,9 +63,14 @@ def testPaddedGather(sl: int, hs: int, ne: int, top_k: int): padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - def padded_gather(x: torch.Tensor, indices: torch.Tensor, - bin_ids: torch.Tensor, bins: torch.Tensor, - padded_bins: torch.Tensor, top_k: int): + def padded_gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() bin_ids = bin_ids.cpu().numpy() diff --git a/tests/ops/padded_scatter_test.py b/tests/ops/padded_scatter_test.py index ebd04a8..0e80dbb 100644 --- a/tests/ops/padded_scatter_test.py +++ b/tests/ops/padded_scatter_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -94,10 +94,15 @@ def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int): # Gather the data to prepare for backwards. x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - def padded_scatter(x: torch.Tensor, indices: torch.Tensor, - bin_ids: torch.Tensor, weights: torch.Tensor, - bins: torch.Tensor, padded_bins: torch.Tensor, - top_k: int): + def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): x = x.detach().cpu().numpy() indices: np.ndarray = _to_numpy(indices) bin_ids: np.ndarray = _to_numpy(bin_ids) @@ -120,10 +125,24 @@ def padded_scatter(x: torch.Tensor, indices: torch.Tensor, in_idx += 1 return torch.from_numpy(out).cuda().half() - out = ops.padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, - top_k) - expected_out = padded_scatter(x, indices, bin_ids, weights, bins, - padded_bins, top_k) + out = ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + expected_out = padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) out.backward(torch.randn_like(out)) # sanity check backward pass diff --git a/tests/ops/replicate_test.py b/tests/ops/replicate_test.py index 94aeb67..aeb1405 100644 --- a/tests/ops/replicate_test.py +++ b/tests/ops/replicate_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/tests/ops/sort_test.py b/tests/ops/sort_test.py index e07f2e1..147426e 100644 --- a/tests/ops/sort_test.py +++ b/tests/ops/sort_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from typing import Dict, Optional, Union @@ -31,12 +31,11 @@ ] -def torch_to_numpy_dtype( - dtype: torch.dtype) -> Union[np.int16, np.int32, np.int64]: +def torch_to_numpy_dtype(dtype: torch.dtype,) -> Union[np.int16, np.int32, np.int64]: types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = { torch.int16: np.int16, torch.int32: np.int32, - torch.int64: np.int64 + torch.int64: np.int64, } return types[dtype] diff --git a/tests/ops/topology_test.py b/tests/ops/topology_test.py index a7135be..dc3c0ae 100644 --- a/tests/ops/topology_test.py +++ b/tests/ops/topology_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -48,8 +48,12 @@ def test_topology(sl: int, hs: int, ne: int): output_block_rows = int(padded_bins[-1]) // blocking output_block_columns = hs // blocking - def topology(padded_bins: torch.Tensor, blocking: torch.Tensor, rows: int, - columns: int): + def topology( + padded_bins: torch.Tensor, + blocking: torch.Tensor, + rows: int, + columns: int, + ): padded_bins = padded_bins.cpu().numpy() out = np.zeros([rows * columns]) @@ -62,8 +66,16 @@ def topology(padded_bins: torch.Tensor, blocking: torch.Tensor, rows: int, start += 1 return torch.from_numpy(out).cuda().short() - out = ops.topology(padded_bins, blocking, output_block_rows, - output_block_columns) - expected_out = topology(padded_bins, blocking, output_block_rows, - output_block_columns) + out = ops.topology( + padded_bins, + blocking, + output_block_rows, + output_block_columns, + ) + expected_out = topology( + padded_bins, + blocking, + output_block_rows, + output_block_columns, + ) assert torch.all(torch.eq(out, expected_out)) diff --git a/yamls/matmul_benchmark.yaml b/yamls/matmul_benchmark.yaml index de26f58..46a79d9 100644 --- a/yamls/matmul_benchmark.yaml +++ b/yamls/matmul_benchmark.yaml @@ -4,11 +4,11 @@ cluster: r9z1 gpu_num: 8 gpu_type: h100_80gb integrations: - - integration_type: git_repo - git_repo: stanford-futuredata/megablocks - git_branch: main - pip_install: absl-py 'git+https://github.com/openai/triton.git@main#egg=triton&subdirectory=python' - ssh_clone: false +- integration_type: git_repo + git_repo: stanford-futuredata/megablocks + git_branch: main + pip_install: absl-py 'git+https://github.com/openai/triton.git@main#egg=triton&subdirectory=python' + ssh_clone: false command: |- cd megablocks export ENABLE_TMA=1 diff --git a/yamls/triton_benchmark.yaml b/yamls/triton_benchmark.yaml index 70c9626..fd30946 100644 --- a/yamls/triton_benchmark.yaml +++ b/yamls/triton_benchmark.yaml @@ -4,14 +4,14 @@ cluster: r9z1 gpu_num: 8 gpu_type: h100_80gb integrations: - - integration_type: git_repo - git_repo: openai/triton - git_branch: main - ssh_clone: false +- integration_type: git_repo + git_repo: openai/triton + git_branch: main + ssh_clone: false command: |- export ENABLE_TMA=1 export ENABLE_MMA_V3=1 - + cd triton/python pip install . --no-dependencies From f87b26f7ec7791100545e65cdcb368b48c9a23a0 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Fri, 9 Aug 2024 16:50:43 -0400 Subject: [PATCH 04/11] bump to v0.1.2 (#138) --- .github/workflows/code-quality.yaml | 2 +- .github/workflows/pr-gpu.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml index 2b1d931..ff9081d 100644 --- a/.github/workflows/code-quality.yaml +++ b/.github/workflows/code-quality.yaml @@ -33,7 +33,7 @@ jobs: uses: actions/checkout@v3 with: repository: mosaicml/ci-testing - ref: v0.1.1 + ref: v0.1.2 path: ./ci-testing - uses: ./ci-testing/.github/actions/code-quality with: diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index e03d37f..1ca8d5b 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -31,7 +31,7 @@ jobs: container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 steps: - name: Run PR GPU tests - uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2 with: name: ${{ matrix.name }} container: ${{ matrix.container }} From 27d3d2c32319e75caa87b0a7860d64cd556cc26d Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Mon, 12 Aug 2024 17:03:52 -0400 Subject: [PATCH 05/11] remove weight parallelism (#137) * remove weight parallelism * fix linting * remove parallel forward from mlp * remove weight parallel * cleanup --- megablocks/layers/arguments.py | 2 - megablocks/layers/glu.py | 3 - megablocks/layers/mlp.py | 50 +--- megablocks/layers/mpu.py | 8 - megablocks/layers/weight_parallel.py | 416 --------------------------- tests/layers/parallelism_test.py | 153 ---------- 6 files changed, 4 insertions(+), 628 deletions(-) delete mode 100644 megablocks/layers/weight_parallel.py delete mode 100644 tests/layers/parallelism_test.py diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index efe131d..ddbe2b7 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -40,8 +40,6 @@ class Arguments: # Parallelism arguments. moe_expert_model_parallelism: bool = False expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None - moe_weight_parallelism: bool = False - weight_parallel_group: Optional[torch.distributed.ProcessGroup] = None pipeline_model_parallel_size: int = 1 num_layers_per_virtual_pipeline_stage: Optional[int] = None diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index fa888a6..4654576 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -44,9 +44,6 @@ def __init__(self, args: Arguments): self._should_set_parallelism_attribute, ) - if self.args.moe_weight_parallelism: - raise NotImplementedError('Weight parallelism not yet supported with GLU.',) - def forward(self, x, topo): if self.args.memory_optimized_mlp: raise NotImplementedError( diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 1cae4fb..f7cb782 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -9,7 +9,6 @@ from megablocks import grouped_gemm_util as gg from megablocks.layers import common, gelu, mpu -from megablocks.layers import weight_parallel as wp from megablocks.layers.activation_fn import act_fn from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn @@ -180,21 +179,7 @@ def create_dmoe_expert_weights( columns, init_method, ) - weights = weights.view([-1, columns]) - rows, columns = weights.shape - - if not args.moe_weight_parallelism: - return weights - - # Caclculate the number of rows on this weight parallel partition. - # 'rows' must be divisible by weight parallel world size. - weight_parallel_world_size = mpu.get_weight_parallel_world_size(args) - assert (rows % weight_parallel_world_size) == 0 - num_rows_per_rank = rows // weight_parallel_world_size - rank = mpu.get_weight_parallel_rank(args) - start_row = rank * num_rows_per_rank - end_row = (rank + 1) * num_rows_per_rank - return weights[start_row:end_row] + return weights.view([-1, columns]) class MemoryOptimizedMLP(torch.autograd.Function): @@ -323,8 +308,7 @@ class SparseMLP(torch.nn.Module): def __init__(self, args: Arguments): super().__init__() self.args = args - self._num_rows_per_rank = ((mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // - mpu.get_weight_parallel_world_size(args)) + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) self.w1 = torch.nn.Parameter( torch.empty( @@ -371,7 +355,7 @@ def __init__(self, args: Arguments): ), ) - self._should_set_parallelism_attribute = (args.moe_expert_model_parallelism or args.moe_weight_parallelism) + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism mpu.set_expert_model_parallel_attributes( self.w1, self._should_set_parallelism_attribute, @@ -390,33 +374,10 @@ def scale_grad(self, w): return w return scale_gradient(w, self.gradient_scale) - def parallel_forward(self, x, topo): - group = self.args.weight_parallel_group - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - if self.args.memory_optimized_mlp: - if self.args.activation_fn is not DEFAULT_ACTIVATION_FN: - raise NotImplementedError( - f'memory_optimized_weight_parallel_mlp not implemented for custom activation_fn={self.args.activation_fn}.', - ) - return wp.memory_optimized_weight_parallel_mlp( - x, - w1, - w2, - topo, - group, - ) - - # Compute the MLP. - x = wp.sdd_nt(x, w1, topo, group) - activation_fn_out = act_fn(x, self.args.activation_fn) - return wp.dsd_nn(activation_fn_out, w2, group) - def forward(self, x, topo): w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.moe_weight_parallelism: - return self.parallel_forward(x, topo) - elif self.args.memory_optimized_mlp: + if self.args.memory_optimized_mlp: return memory_optimized_mlp( x, w1, @@ -542,9 +503,6 @@ def forward(self, x, tokens_per_expert): w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - if self.args.moe_weight_parallelism: - raise NotImplementedError('Weight parallelism not yet supported with GroupedMLP.',) - if self.args.memory_optimized_mlp: return memory_optimized_grouped_mlp( x, diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 6aa0015..239f75f 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -42,14 +42,6 @@ def copy_expert_model_parallel_attributes( ) -def get_weight_parallel_world_size(args: Arguments) -> int: - return (torch.distributed.get_world_size(args.weight_parallel_group) if args.moe_weight_parallelism else 1) - - -def get_weight_parallel_rank(args: Arguments) -> int: - return (torch.distributed.get_rank(args.weight_parallel_group) if args.moe_weight_parallelism else 0) - - def synchronized_print(group, *x): world_size = torch.distributed.get_world_size(group) rank = torch.distributed.get_rank(group) diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py deleted file mode 100644 index 82effec..0000000 --- a/megablocks/layers/weight_parallel.py +++ /dev/null @@ -1,416 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import stk -import torch - -from megablocks.layers import gelu - - -def _gather_weights(w, group, parallel_w=None, async_op=False): - """Gather the weights across the process group. - - Args: - w: torch.Tensor, local shard of the weights. - group: ProcessGroup, the group to gather across. - parallel_w: torch.Tensor, option output tensor to use - for the gather. - async_op: Whether to gather asynchronously. - - Returns: - The gathered weights tensor and a handle for asynchronous - communication. - """ - n, k = w.shape - world_size = torch.distributed.get_world_size(group) - - if parallel_w is None: - parallel_w = torch.empty( - n * world_size, - k, - device=w.device, - dtype=w.dtype, - ) - handle = torch.distributed.all_gather_into_tensor( - parallel_w, - w, - group=group, - async_op=async_op, - ) - return parallel_w, handle - - -def _scaled_reduce_scatter(parallel_dw, group, dw=None, async_op=False): - """Scatter reduce the weights across the process group. - - Args: - parallel_dw: torch.Tensor, local shard of the weights. - group: ProcessGroup, the group to scatter-reduce across. - dw: torch.Tensor, option output tensor to use for the op. - async_op: Whether to scatter reduce asynchronously. - - Returns: - The reduced weights tensor, scaled by 1 / world_size, and - a handle for asynchronous communication. - """ - n, k = parallel_dw.shape - world_size = torch.distributed.get_world_size(group) - assert (n % world_size) == 0 - - # Pre-scale the gradients by the world size. - # - # NOTE: Reduce in float32, always. - parallel_dw = parallel_dw.float() / world_size - - if dw is None: - dw = torch.empty( - n // world_size, - k, - device=parallel_dw.device, - dtype=torch.float32, - ) - handle = torch.distributed.reduce_scatter_tensor( - dw, - parallel_dw, - group=group, - async_op=async_op, - ) - return dw, handle - - -class WeightParallelSddNt(torch.autograd.Function): - - @staticmethod - @torch.cuda.amp.custom_fwd - def forward(ctx, x, w, topo, group): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w = w.to(ctx._dtype) - # [m, k] x [n, k] = [m, n] - if not x.is_contiguous() or not w.is_contiguous(): - raise ValueError("Expected contiguous 'x' and 'w'.") - - ctx.group = group - ctx.shape = topo.shape - ctx.save_for_backward( - x, - w, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - # TODO(tgale): Support prefetching forward weights. - parallel_w, _ = _gather_weights(w, group) - return stk.ops.sdd(x, parallel_w.t(), topo).data - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad): - x, w = ctx.saved_tensors[:2] - grad = stk.Matrix(ctx.shape, grad, *ctx.saved_tensors[2:]) - - # Start the weight gather asynchronously to overlap with the - # weight gradient computation. - parallel_w, handle = _gather_weights(w, ctx.group, async_op=True) - parallel_dw = None - if ctx.needs_input_grad[1]: - parallel_dw = stk.ops.dsd(grad.t(), x) - - # Start the weight gradient reduce scatter to overlap with the - # data gradient computation. - handle.wait() - dw, handle = _scaled_reduce_scatter( - parallel_dw, - ctx.group, - async_op=True, - ) - dx = None - if ctx.needs_input_grad[0]: - dx = stk.ops.dsd(grad, parallel_w) - - # NOTE: Be careful to wait and only cast dw to the output dtype once - # we've blocked on the asynchronous NCCL operation. - handle.wait() - dw = dw.to(w.dtype) - return dx, dw, None, None - - -def sdd_nt(a, b, topo, group): - return stk.Matrix( - topo.size(), - WeightParallelSddNt.apply(a, b, topo, group), - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - - -class WeightParallelDsdNn(torch.autograd.Function): - - @staticmethod - @torch.cuda.amp.custom_fwd - def forward( - ctx, - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - w, - group, - ): - # [m, k] x [k, n] = [m, n] - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - data = data.to(ctx._dtype) - w = w.to(ctx._dtype) - if not data.is_contiguous() or not w.is_contiguous(): - raise ValueError("Expected contiguous 'data' and 'w'.") - - ctx.group = group - ctx.shape = shape - ctx.save_for_backward( - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - w, - ) - x = stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - ) - - # TODO(tgale): Support prefetching forward weights. - parallel_w, _ = _gather_weights(w, group) - return stk.ops.dsd(x, parallel_w) - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad): - x = stk.Matrix(ctx.shape, *ctx.saved_tensors[:-1]) - w = ctx.saved_tensors[-1] - - # Start the weight gather asynchronously to overlap with the - # weight gradient computation. - parallel_w, handle = _gather_weights(w, ctx.group, async_op=True) - parallel_dw = None - if ctx.needs_input_grad[-2]: - parallel_dw = stk.ops.dsd(x.t(), grad) - - # Start the weight gradient reduce scatter to overlap with the - # data gradient computation. - handle.wait() - dw, handle = _scaled_reduce_scatter( - parallel_dw, - ctx.group, - async_op=True, - ) - dx = None - if ctx.needs_input_grad[1]: - dx = stk.ops.sdd(grad, parallel_w.t(), x) - - # NOTE: Be careful to wait and only cast dw to the output dtype once - # we've blocked on the asynchronous NCCL operation. - handle.wait() - dw = dw.to(w.dtype) - return None, dx.data, None, None, None, None, None, None, dw, None - - -def dsd_nn(a, b, group): - return WeightParallelDsdNn.apply( - a.size(), - a.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t, - b, - group, - ) - - -class MemoryOptimizedWeightParallelMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.cuda.amp.custom_fwd - def forward(ctx, x, w1, w2, topo, group): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - parallel_w1, _ = _gather_weights(w1, group) - sdd_out = stk.ops.sdd(x, parallel_w1.t(), topo) - - # GeLU. - gelu_out = gelu.gelu(sdd_out) - - # Layer 1: x @ w2. - # - # NOTE: Reuse the buffer for the w1 weight gather. - parallel_w2, _ = _gather_weights(w2, group, parallel_w1) - dsd_out = stk.ops.dsd(gelu_out, parallel_w2) - - # NOTE: Save the input to the layer and the gelu input for - # gradient computation. We'll re-compute the gelu forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.group = group - ctx.shape = topo.shape - ctx.save_for_backward( - x, - w1, - w2, - sdd_out.data, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t, - ) - return dsd_out - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, ddsd_out): - x, w1, w2 = ctx.saved_tensors[:3] - sdd_out = stk.Matrix(ctx.shape, *ctx.saved_tensors[3:]) - - if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): - raise ValueError('Expected all MLP inputs to need grad.') - - # Start the weight gather asynchronously to overlap with the - # weight gradient computation and gelu recompute. - parallel_w2, handle = _gather_weights(w2, ctx.group, async_op=True) - - # Compute dw2 with recomputed gelu output. - gelu_out = gelu.gelu(sdd_out) - parallel_dw2 = stk.ops.dsd(gelu_out.t(), ddsd_out) - - # Start the weight gradient reduce scatter to overlap with the - # data gradient computation. - handle.wait() - dw2, handle = _scaled_reduce_scatter( - parallel_dw2, - ctx.group, - async_op=True, - ) - - # Compute dgelu_out. - # - # NOTE: We reuse the gelu_out allocation. - stk.backend.triton_kernels.sdd( - ddsd_out, - parallel_w2.t(), - sdd_out.shape, - gelu_out.data, - sdd_out.offsets, - sdd_out.row_indices, - sdd_out.column_indices, - ) - dgelu_out = gelu_out - - # NOTE: Be careful to wait and only cast dw to the output dtype once - # we've blocked on the asynchronous NCCL operation. - handle.wait() - dw2 = dw2.to(w2.dtype) - - # Start the weight gather asynchronously to overlap with the - # weight and gelu gradient computation. - # - # NOTE: Reuse the buffer from the w2 weight gather. - parallel_w1, handle = _gather_weights( - w1, - ctx.group, - parallel_w2, - async_op=True, - ) - - # Compute dsdd_out. - # - # NOTE: This reuses the dgelu_out allocation. - dsdd_out = gelu.gelu_backward_(dgelu_out, sdd_out) - - # Compute dw1. - # - # NOTE: This reuses the parallel_dw2 allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.t().shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - True, # transpose_a - x, - parallel_dw2, - ) - parallel_dw1 = parallel_dw2 - - # Start the weight gradient reduce scatter to overlap with the - # data gradient computation. - handle.wait() - dw1, handle = _scaled_reduce_scatter( - parallel_dw1, - ctx.group, - async_op=True, - ) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - parallel_w1, - ddsd_out, - ) - dx = ddsd_out - - # NOTE: Be careful to wait and only cast dw to the output dtype once - # we've blocked on the asynchronous NCCL operation. - handle.wait() - dw1 = dw1.to(w1.dtype) - return dx, dw1, dw2, None, None - - -memory_optimized_weight_parallel_mlp = MemoryOptimizedWeightParallelMLP.apply diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py deleted file mode 100644 index 35e40a0..0000000 --- a/tests/layers/parallelism_test.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright 2024 Databricks -# SPDX-License-Identifier: Apache-2.0 - -import functools - -import numpy as np -import pytest -import torch - -from megablocks.layers import arguments, dmoe, mpu - -_PARALLELISM_TESTS = ( - (64, 1024, 512, 2048, 64, 1, False), - (64, 1024, 512, 2048, 64, 1, True), - # Test with fewer experts than ranks to verify tensor - # sharding in tandem with expert sharding. - (4, 1, 512, 2048, 4, 1, False), - (4, 1, 512, 2048, 4, 1, True), -) - - -# Todo: Fix this long term -@pytest.fixture -def group(): - return None - - -@pytest.mark.world_size(2) -@pytest.mark.gpu -@pytest.mark.parametrize(( - 'batch_size', - 'sequence_length', - 'hidden_size', - 'ffn_hidden_size', - 'num_experts', - 'top_k', - 'memory_optimized', -), _PARALLELISM_TESTS) -def test_expert_parallel_versus_weight_parallel( - group, - batch_size: int, - sequence_length: int, - hidden_size: int, - ffn_hidden_size: int, - num_experts: int, - top_k: int, - memory_optimized: bool, -): - - init_fn = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1) - ep_args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=False, - device=torch.cuda.current_device(), - init_method=init_fn, - memory_optimized_mlp=memory_optimized, - ) - wp_args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_weight_parallelism=True, - weight_parallel_group=group, - fp16=False, - bf16=False, - device=torch.cuda.current_device(), - init_method=init_fn, - memory_optimized_mlp=memory_optimized, - ) - - # NOTE: Reset the seed so that the models get identical weights. - torch.manual_seed(1234) - ep = dmoe.dMoE(ep_args) - torch.manual_seed(1234) - wp = dmoe.dMoE(wp_args) - - # NOTE: Include the rank in the seed so we get different data per rank. - rank = torch.distributed.get_rank(group) - torch.manual_seed(1234 * rank) - x = torch.randn((batch_size, sequence_length, hidden_size), device=torch.cuda.current_device(), - dtype=torch.float32).requires_grad_(True) - - # Test forward. - out, _ = wp(x) - expected_out, _ = ep(x) - - # Check the forward outputs. - for i in range(torch.distributed.get_world_size(group)): - torch.distributed.barrier(group) - if i == rank: - assert np.testing.assert_allclose( - out.detach().float().cpu(), - expected_out.detach().float().cpu(), - rtol=1e-4, - atol=1e-4, - ) is None - - # Test backward. - out.mean().backward() - expected_out.mean().backward() - - # NOTE: If tensor parallelism is used different weights can be on - # different ranks. Gather the full grads to rank 0 to compare. - def gather(x): - m, n = x.shape - world_size = torch.distributed.get_world_size(group) - out = torch.empty(m * world_size, n, device=x.device, dtype=x.dtype) - torch.distributed.all_gather_into_tensor(out, x, group=group) - return out - - def permute(x): - esd = mpu.expert_sharding_degree(ep_args) - hsd = mpu.hidden_sharding_degree(ep_args) - out = x.view(hsd, esd, -1).transpose(1, 0).contiguous() - return out.view(num_experts * ffn_hidden_size, hidden_size) - - wp_w2_grad = gather(wp.experts.mlp.w2.grad) - ep_w2_grad = permute(gather(ep.experts.mlp.w2.grad)) - if rank == 0: - assert np.testing.assert_allclose( - wp_w2_grad.float().cpu(), - ep_w2_grad.float().cpu(), - rtol=1e-5, - atol=1e-5, - ) is None - - wp_w1_grad = gather(wp.experts.mlp.w1.grad) - ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad)) - if rank == 0: - assert np.testing.assert_allclose( - wp_w1_grad.float().cpu(), - ep_w1_grad.float().cpu(), - rtol=1e-5, - atol=1e-5, - ) is None - - # Verify the router weight gradient, which is not sharded. - for i in range(torch.distributed.get_world_size(group)): - torch.distributed.barrier(group) - if i == rank: - assert np.testing.assert_allclose( - wp.router.layer.weight.grad.float().cpu(), - ep.router.layer.weight.grad.float().cpu(), - rtol=1e-5, - atol=1e-5, - ) is None From bce5d7b2aaf5038bc93b36f76c2baf51c2939bd2 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Mon, 12 Aug 2024 19:32:16 -0400 Subject: [PATCH 06/11] refactor testing (#140) * refactor testing * rename to architectures --- .../layers/architectures.py | 9 --------- tests/layers/dmoe_test.py | 16 +++++++++------- tests/layers/glu_test.py | 5 +++-- tests/layers/moe_test.py | 19 ++++++++++--------- 4 files changed, 22 insertions(+), 27 deletions(-) rename megablocks/layers/testing.py => tests/layers/architectures.py (84%) diff --git a/megablocks/layers/testing.py b/tests/layers/architectures.py similarity index 84% rename from megablocks/layers/testing.py rename to tests/layers/architectures.py index 4cd9500..da1c595 100644 --- a/megablocks/layers/testing.py +++ b/tests/layers/architectures.py @@ -7,15 +7,6 @@ from megablocks.layers.arguments import Arguments -def allclose(x, y, pct=0.5): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print('{:.2f}% of values not close.'.format(pct_diff)) - return False - return True - - class FFN(torch.nn.Module): def __init__(self, args: Arguments): diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py index a737ef4..3d6565c 100644 --- a/tests/layers/dmoe_test.py +++ b/tests/layers/dmoe_test.py @@ -7,8 +7,10 @@ import torch from megablocks import grouped_gemm_util as gg -from megablocks.layers import dmoe, moe, testing from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss +from tests.layers.architectures import FFN # min size: (1, 2, 128, 2, 1) _FORWARD_TESTS_DEFAULT = ( @@ -64,9 +66,9 @@ def construct_moes( bf16=True, ) - mlp = testing.FFN(args) - moe_mlp = moe.MoE(args) - dmoe_mlp = dmoe.dMoE(args) + mlp = FFN(args) + moe_mlp = MoE(args) + dmoe_mlp = dMoE(args) mlp.cuda(torch.cuda.current_device()).to(torch.bfloat16) moe_mlp.cuda(torch.cuda.current_device()).to(torch.bfloat16) @@ -106,7 +108,7 @@ def test_dmoe_forward( out, _ = layer(x) assert out.shape == x.shape - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu @@ -132,12 +134,12 @@ def test_dmoe_forward_backward( out, _ = layer(x) assert out.shape == x.shape - loss = out.sum() + moe.batched_load_balancing_loss(args) + loss = out.sum() + batched_load_balancing_loss(args) loss.backward() assert x.grad is not None layer.zero_grad(set_to_none=True) x.grad = None - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index d89af89..1e031de 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -7,8 +7,9 @@ import stk import torch -from megablocks.layers import dmlp_registry, testing +from megablocks.layers import dmlp_registry from megablocks.layers.arguments import Arguments +from tests.layers.architectures import GLU _DENSE_TESTS = ( (16, 1024, 512), @@ -36,7 +37,7 @@ def construct_dmoe_glu( bf16=True, ) - glu = testing.GLU(args) + glu = GLU(args) dmoe_glu = dmlp_registry.get(args) dmoe_glu.cuda(torch.cuda.current_device()).to(torch.bfloat16) diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index dd40ef9..ffd32cb 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -6,8 +6,9 @@ import pytest import torch -from megablocks.layers import moe, testing from megablocks.layers.arguments import Arguments +from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss +from tests.layers.architectures import FFN _FORWARD_TESTS = ( (16, 1024, 512, 1, 1), @@ -48,8 +49,8 @@ def construct_moe( init_method=init_method, ) - mlp = testing.FFN(args) - moe_mlp = moe.MoE(args) + mlp = FFN(args) + moe_mlp = MoE(args) mlp.cuda(torch.cuda.current_device()).half() moe_mlp.cuda(torch.cuda.current_device()).half() @@ -76,7 +77,7 @@ def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): out, _ = layer(x) assert out.shape == x.shape - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu @@ -101,11 +102,11 @@ def test_moe_forward_backward( out, _ = layer(x) assert out.shape == x.shape - loss = out.sum() + moe.batched_load_balancing_loss(args) + loss = out.sum() + batched_load_balancing_loss(args) loss.backward() layer.zero_grad(set_to_none=True) x.grad = None - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu @@ -119,7 +120,7 @@ def test_moe_forward_vs_dense(bs: int, sl: int, hs: int): out, _ = moe_mlp(x) assert out.shape == x.shape == expected_out.shape assert torch.allclose(out, expected_out) - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu @@ -137,7 +138,7 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze() moe_mlp.zero_grad(set_to_none=True) x.grad = None - moe.clear_load_balancing_loss() + clear_load_balancing_loss() expected_out = mlp(x) expected_loss = expected_out.sum() @@ -152,4 +153,4 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): assert w2_grad.shape == expected_w2_grad.shape assert torch.allclose(w1_grad, expected_w1_grad) assert torch.allclose(w2_grad, expected_w2_grad) - moe.clear_load_balancing_loss() + clear_load_balancing_loss() From 5b2650a143f24f0d7ded8a20cd6ea41cc4c35567 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 28 Aug 2024 13:45:53 -0400 Subject: [PATCH 07/11] Type Checking (#141) * add type hints * more type checks * tyoe check router * more type checking * restore sum * more tests * more type checking * more updates * add py.typed * git rid of stk type errors * remove icecream package * fix matrix import * add type hints * fix all torch.distibuted type errors * fix more torch.distibuted type errors * fix all gmm type errors * more type checking * comment out type checking * update --- .pre-commit-config.yaml | 12 +- megablocks/backend/kernels.py | 153 +++++++++++++++++-------- megablocks/grouped_gemm_util.py | 17 ++- megablocks/layers/activation_fn.py | 12 +- megablocks/layers/all_to_all.py | 5 +- megablocks/layers/arguments.py | 11 +- megablocks/layers/dmoe.py | 3 +- megablocks/layers/glu.py | 5 +- megablocks/layers/memory_test.py | 7 +- megablocks/layers/mlp.py | 13 ++- megablocks/layers/moe.py | 59 ++++++---- megablocks/layers/mpu.py | 28 +++-- megablocks/layers/router.py | 10 +- megablocks/ops/all_to_all_benchmark.py | 11 +- megablocks/ops/binned_gather.py | 12 +- megablocks/ops/binned_scatter.py | 12 +- megablocks/ops/cumsum.py | 6 +- megablocks/ops/gather.py | 12 +- megablocks/ops/histogram.py | 4 +- megablocks/ops/padded_gather.py | 13 ++- megablocks/ops/padded_scatter.py | 14 ++- megablocks/ops/repeat.py | 5 +- megablocks/ops/replicate.py | 6 +- megablocks/ops/round_up.py | 2 +- megablocks/ops/scatter.py | 16 ++- megablocks/ops/sort.py | 4 +- megablocks/ops/sum.py | 3 +- megablocks/ops/topology.py | 12 +- megablocks/py.typed | 0 pyproject.toml | 2 +- setup.py | 1 + 31 files changed, 324 insertions(+), 146 deletions(-) create mode 100644 megablocks/py.typed diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1d315f5..c754b29 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,9 +1,19 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks authors # SPDX-License-Identifier: Apache-2.0 default_language_version: python: python3 repos: +# - repo: local +# hooks: +# - id: pyright +# name: pyright +# entry: pyright +# language: node +# types: [python] +# pass_filenames: false +# args: [--warnings] +# additional_dependencies: ["pyright@1.1.310"] - repo: https://github.com/astral-sh/ruff-pre-commit rev: v0.2.2 hooks: diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index b831826..ca0120b 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,26 +1,27 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional import torch import triton import triton.language as tl -def assert_is_tensor(x, ndim): +def assert_is_tensor(x: torch.Tensor, ndim: int): if x.ndim != ndim: raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') -def assert_is_matrix(x): +def assert_is_matrix(x: torch.Tensor): assert_is_tensor(x, 2) -def assert_is_vector(x): +def assert_is_vector(x: torch.Tensor): if x.ndim != 1: raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') -def assert_equal(a, b): +def assert_equal(a: Any, b: Any): if a != b: raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) @@ -43,13 +44,13 @@ def assert_equal(a, b): ) @triton.jit def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, + a: torch.Tensor, + b: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Any, + bins: torch.Tensor, + padded_bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -93,7 +94,8 @@ def _padded_copy( iptr = a if A_TO_B else b optr = b if A_TO_B else a - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -103,7 +105,15 @@ def _padded_copy( offsets += BLOCK_X -def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): +def padded_gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -119,7 +129,7 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. - output_rows = padded_bins[-1].cpu().item() + output_rows = int(padded_bins[-1].cpu().item()) out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, @@ -137,7 +147,14 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): return out -def gather(x, indices, bin_ids, weights, bins, top_k): +def gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -169,7 +186,15 @@ def gather(x, indices, bin_ids, weights, bins, top_k): return out -def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -202,7 +227,14 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) -def scatter(x, indices, bin_ids, weights, bins, top_k): +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +) -> torch.Tensor: return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) @@ -225,13 +257,13 @@ def scatter(x, indices, bin_ids, weights, bins, top_k): ) @triton.jit def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, + x: torch.Tensor, + grad: torch.Tensor, + wgrad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -263,7 +295,7 @@ def _padded_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -275,7 +307,15 @@ def _padded_copy_wgrad( tl.store(wgrad, out) -def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): +def padded_scatter_wgrad( + x: torch.Tensor, + grad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): # Validate the input shapes. assert_is_matrix(x) assert_is_matrix(grad) @@ -302,7 +342,14 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): return out -def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): +def scatter_wgrad( + x: torch.Tensor, + grad: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, +): return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) @@ -323,13 +370,13 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): ) @triton.jit def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, + a: torch.Tensor, + b: torch.Tensor, + num_experts: int, + expert_capacity: int, + indices: torch.Tensor, + weights, #: Optional[torch.Tensor], + bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -378,7 +425,7 @@ def _binned_copy( optr = b if A_TO_B else a iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -388,7 +435,14 @@ def _binned_copy( offsets += BLOCK_X -def binned_gather(x, indices, weights, bins, expert_capacity, top_k): +def binned_gather( + x: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + expert_capacity: int, + top_k: int, +): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -400,7 +454,6 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): num_experts = bins.shape[0] out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) - _binned_copy[(num_experts, expert_capacity)]( x, out, @@ -417,7 +470,13 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): return out -def binned_scatter(x, indices, weights, bins, top_k): +def binned_scatter( + x: torch.Tensor, + indices: torch.Tensor, + weights: Optional[torch.Tensor], + bins: torch.Tensor, + top_k: int, +): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_vector(indices) @@ -465,13 +524,13 @@ def binned_scatter(x, indices, weights, bins, top_k): ) @triton.jit def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, + x: torch.Tensor, + grad: torch.Tensor, + wgrad: torch.Tensor, + num_experts: int, + expert_capacity: int, + indices: torch.Tensor, + bins: torch.Tensor, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -505,7 +564,7 @@ def _binned_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(iterations): + for _ in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -517,7 +576,7 @@ def _binned_copy_wgrad( tl.store(wgrad, out) -def binned_scatter_wgrad(x, grad, indices, bins, top_k): +def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_matrix(grad) diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index 07dbc04..6d3f977 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,20 +1,25 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +import warnings +_grouped_gemm_is_available: bool = False try: import grouped_gemm -except ImportError: - grouped_gemm = None + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') def grouped_gemm_is_available(): - return grouped_gemm is not None + return _grouped_gemm_is_available def assert_grouped_gemm_is_available(): - assert grouped_gemm_is_available( - ), ('Grouped GEMM not available. Please run ' - '`pip install git+https://github.com/tgale96/grouped_gemm@main`.') + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg backend = grouped_gemm.backend if grouped_gemm_is_available() else None diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 736d311..a31770b 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,24 +1,24 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Callable +from typing import Any, Callable, Union -import stk import torch +from stk import Matrix def act_fn( - x: stk.Matrix, + x: Matrix, function: Callable, return_grad_fn: bool = False, **kwargs, -): - assert isinstance(x, stk.Matrix) +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): if return_grad_fn: x.data.requires_grad = True out = function(x.data, **kwargs) - y = stk.Matrix( + y = Matrix( x.size(), out, x.row_indices, diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 82a6f40..5ac7067 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torch.distributed as dist class AllToAllOp(torch.autograd.Function): @@ -14,7 +15,7 @@ def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): ctx.output_split_sizes = output_split_sizes ctx.input_split_sizes = input_split_sizes ctx.group = group - handle = torch.distributed.all_to_all_single( + handle = dist.all_to_all_single( out, x, output_split_sizes=output_split_sizes, @@ -32,7 +33,7 @@ def backward(ctx, grad, _): device=grad.device, dtype=grad.dtype, ) - torch.distributed.all_to_all_single( + dist.all_to_all_single( out, grad, output_split_sizes=ctx.input_split_sizes, diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index ddbe2b7..892cb91 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -6,12 +6,13 @@ from typing import Any, Callable, Optional, Union import torch +import torch.distributed as dist import torch.nn.functional as F import megablocks.grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. -InitFn = Callable[[torch.Tensor], None] +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] _ALLOWED_BITWIDTHS = (-1, 4, 8) @@ -39,7 +40,7 @@ class Arguments: # Parallelism arguments. moe_expert_model_parallelism: bool = False - expert_parallel_group: Optional[torch.distributed.ProcessGroup] = None + expert_parallel_group: Optional[dist.ProcessGroup] = None pipeline_model_parallel_size: int = 1 num_layers_per_virtual_pipeline_stage: Optional[int] = None @@ -51,7 +52,7 @@ class Arguments: # Initialization arguments. fp16: bool = True bf16: bool = False - device: torch.device = torch.cuda.current_device() + device: Union[int, torch.device] = torch.cuda.current_device() init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) output_layer_init_method: InitFn = init_method @@ -60,7 +61,7 @@ class Arguments: # shared expert arguments shared_expert: bool = False # enable using shared expert - fc_cls: torch.nn.Module = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored shared_expert_hidden_size: Optional[ @@ -75,7 +76,7 @@ def __post_init__(self): self.shared_expert_hidden_size = self.ffn_hidden_size -def from_megatron(megatron_args): +def from_megatron(megatron_args: Any): args = Arguments() for field in dataclasses.fields(args): if hasattr(megatron_args, field.name): diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index e683f8a..377b77f 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -2,8 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 import numpy as np -import stk +import stk.ops import torch +from stk import Matrix import megablocks.ops as ops from megablocks.layers import common, dmlp_registry, moe, mpu diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index 4654576..e510723 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,7 +1,7 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -import stk +import stk.ops import torch from megablocks import grouped_gemm_util as gg @@ -80,6 +80,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") # Layer 0: x @ w1.t(). + assert gg.backend is not None sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) @@ -123,6 +124,7 @@ def backward(ctx, ddsd_out): activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None dw2 = gg.backend.gmm( activation_fn_out, ddsd_out, @@ -196,6 +198,7 @@ def forward(self, x, tokens_per_expert): ) # Compute the MLP. + assert gg.ops is not None x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) x1 = self.args.activation_fn(x1) * x2 diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index 809e317..4acbd94 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -4,6 +4,7 @@ import gc import torch +import torch.distributed as dist from megablocks.layers import arguments, dmoe @@ -92,9 +93,9 @@ def grad_numel(x): if __name__ == '__main__': - assert torch.distributed.is_available() - group = torch.distributed.init_process_group(backend='nccl') - local_rank = torch.distributed.get_rank(group) + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) torch.cuda.set_device(local_rank) for args in _TESTS: diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index f7cb782..e8f2d7b 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -4,6 +4,8 @@ from typing import Any import stk +import stk.backend.triton_kernels +import stk.ops import torch from packaging import version @@ -17,20 +19,20 @@ class ScaleGradient(torch.autograd.Function): @staticmethod @torch.cuda.amp.custom_fwd - def forward(ctx, x, scale): + def forward(ctx: Any, x: torch.Tensor, scale: float): ctx.scale = scale return x @staticmethod @torch.cuda.amp.custom_bwd - def backward(ctx, grad): + def backward(ctx: torch.Tensor, grad: torch.Tensor): return grad * ctx.scale, None scale_gradient = ScaleGradient.apply -def resolve_dtensor(weight): +def resolve_dtensor(weight: torch.Tensor): if version.parse(torch.__version__) >= version.parse('2.0.0'): from torch.distributed._tensor import DTensor if isinstance(weight, DTensor): @@ -408,6 +410,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). + assert gg.backend is not None sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) # activation_fn @@ -429,7 +432,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): @staticmethod @torch.cuda.amp.custom_bwd - def backward(ctx, ddsd_out): + def backward(ctx: Any, ddsd_out: torch.Tensor): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.') @@ -449,6 +452,7 @@ def backward(ctx, ddsd_out): activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None dw2 = gg.backend.gmm( activation_fn_out, ddsd_out, @@ -513,6 +517,7 @@ def forward(self, x, tokens_per_expert): ) # Compute the MLP. + assert gg.ops is not None x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x = self.args.activation_fn(x) return gg.ops.gmm(x, w2, batch_sizes) diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index e5eaaa8..9ba5edb 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,8 +1,10 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple import numpy as np import torch +import torch.distributed as dist import megablocks.ops as ops from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry @@ -110,6 +112,7 @@ def __init__(self, args: Arguments): # Expert MLP. self.mlp = mlp.MLP(args) + self.bias: Optional[torch.Tensor] if self.args.bias: # Note that the output bias is not parallelized with expert # model parallelism. @@ -127,12 +130,12 @@ def __init__(self, args: Arguments): # Select the forward function for the operating mode. self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - def expert_capacity(self, tokens): + def expert_capacity(self, tokens: int) -> int: world_size = mpu.get_expert_parallel_world_size(self.args) tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) return int(self.args.moe_capacity_factor * tokens_per_expert) - def load_balancing_loss(self, tokens_per_expert, expert_scores): + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): """Calculate the load balancing loss contribution.""" assert len(expert_scores.size()) == 2 tokens, num_experts = expert_scores.size() @@ -146,7 +149,8 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores): expert_scores.mean(dim=0), ) - def indices_and_bins(self, top_expert): + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # Sort the expert ids to produce the scatter/gather # indices for the permutation. # @@ -154,7 +158,9 @@ def indices_and_bins(self, top_expert): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output # Histogram the expert ids to identify the number of # tokens routed to each expert. @@ -166,23 +172,32 @@ def indices_and_bins(self, top_expert): # Calculate the bin bounds for the sorted tokens. bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + return indices, bin_ids, bins, tokens_per_expert def permute_and_compute( self, - x, - tokens_per_expert, # unused - indices, - bin_ids, # unused - expert_weights, - bins, - expert_capacity, - top_k, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output # Perform the expert computation. Note that we don't # use biases for these linear operations. @@ -191,7 +206,7 @@ def permute_and_compute( # Un-route the data for the MoE output. return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - def forward_once(self, x, expert_weights, top_experts): + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] # expert_weights: [sl * bs, top-k] # top_experts: [sl * bs, top-k] @@ -202,7 +217,7 @@ def forward_once(self, x, expert_weights, top_experts): # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. - sl, bs, hs = x.size() + sl, bs, _ = x.size() expert_capacity = self.expert_capacity(sl * bs) if expert_capacity == 0: expert_capacity = torch.max(tokens_per_expert).item() @@ -219,7 +234,7 @@ def forward_once(self, x, expert_weights, top_experts): ) return x, tokens_per_expert - def parallel_forward_once(self, x, expert_weights, top_experts): + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # NOTE: This function implements the same computation as forward_once # but with expert model parallelism. # @@ -257,7 +272,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Pass token count information to the device on which the # target expert resides. parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) - tpe_handle = torch.distributed.all_to_all_single( + tpe_handle = dist.all_to_all_single( parallel_tokens_per_expert, repeated_tokens_per_expert, group=self.args.expert_parallel_group, @@ -270,7 +285,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - x = ops.gather(x, indices, bin_ids, bins, self.top_k) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output # Compute the number of tokens that will be received from each # device and permute the input data across the devices. @@ -356,7 +373,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. - tokens, hs = x.size() + tokens, _ = x.size() expert_capacity = self.expert_capacity(tokens) if expert_capacity == 0: expert_capacity = torch.max(parallel_tokens_per_expert).item() @@ -405,7 +422,7 @@ def parallel_forward_once(self, x, expert_weights, top_experts): x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() - def forward(self, x, scores, expert_weights, top_experts): + def forward(self, x: torch.Tensor, scores: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): in_shape = x.size() # Compute the experts. @@ -439,7 +456,7 @@ def __init__(self, args: Arguments): def _init_experts_mlp(self, args: Arguments): return ParallelMLP(args) - def forward(self, x): + def forward(self, x: torch.Tensor): # NOTE: If we're going to cast the activations to lower precision # do it before we permute the tokens to save bandwidth. x = common.cast_if_autocast_enabled(x) diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 239f75f..b232139 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,21 +1,31 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Optional + import torch +import torch.distributed as dist from megablocks.layers.arguments import Arguments +class MoeParam(torch.Tensor): + + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool + + def is_moe_param(tensor: torch.Tensor) -> bool: return hasattr(tensor, 'expert_model_parallel') def get_expert_parallel_world_size(args: Arguments) -> int: - return (torch.distributed.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) def get_expert_parallel_rank(args: Arguments) -> int: - return (torch.distributed.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) def set_expert_model_parallel_attributes( @@ -26,7 +36,7 @@ def set_expert_model_parallel_attributes( setattr(tensor, 'expert_model_parallel', is_parallel) -def param_is_expert_model_parallel(param: torch.Tensor) -> bool: +def param_is_expert_model_parallel(param: MoeParam) -> bool: return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) @@ -42,11 +52,11 @@ def copy_expert_model_parallel_attributes( ) -def synchronized_print(group, *x): - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) for i in range(world_size): - torch.distributed.barrier(group) + dist.barrier(group) if i == rank: print(f'rank = {rank}', *x) @@ -70,9 +80,7 @@ def hidden_sharding_degree(args: Arguments) -> int: raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) if (esd * hsd) != world_size: raise ValueError( - f"Invalid sharding. 'expert_sharding_degree' " - f'({esd}) * hidden_sharding_degree ' - f'({hsd}) != world_size ({world_size}).', + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", ) return hsd diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index 42cfbe1..9499870 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch @@ -14,7 +15,7 @@ class _UniformExpertAssignment(torch.autograd.Function): @staticmethod - def forward(ctx, x, num_experts): + def forward(ctx: Any, x: torch.Tensor, num_experts: int): out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) out = torch.remainder(out, num_experts) return out.view(x.shape) @@ -43,18 +44,19 @@ def __init__(self, args: Arguments): ) args.init_method(self.layer.weight) - def jitter(self, x): + def jitter(self, x: torch.Tensor): + assert isinstance(self.args.moe_jitter_eps, float) low = 1.0 - self.args.moe_jitter_eps high = 1.0 + self.args.moe_jitter_eps noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) return low + noise * (high - low) - def _top_k(self, scores): + def _top_k(self, scores: torch.Tensor): if self.args.moe_top_k == 1: return scores.max(dim=-1, keepdim=True) return torch.topk(scores, self.args.moe_top_k, dim=-1) - def forward(self, x): + def forward(self, x: torch.Tensor): if self.training and self.args.moe_jitter_eps is not None: x = x * self.jitter(x) diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index b3a8537..47b9530 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import torch +import torch.distributed as dist from megablocks import benchmark_util from megablocks.layers.all_to_all import all_to_all @@ -29,7 +30,7 @@ def benchmark_all_to_all(group, sl, hs): - world_size = torch.distributed.get_world_size(group) + world_size = dist.get_world_size(group) assert (sl % world_size) == 0 send_recv_sizes = [sl // world_size] * world_size @@ -45,14 +46,14 @@ def benchmark(): time, std = benchmark_util.benchmark_function(benchmark) - if torch.distributed.get_rank(group) == 0: + if dist.get_rank(group) == 0: benchmark_util.log_benchmark('All-To-All', details, time, std) if __name__ == '__main__': - assert torch.distributed.is_available() - group = torch.distributed.init_process_group(backend='nccl') - local_rank = torch.distributed.get_rank(group) + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) torch.cuda.set_device(local_rank) for args in _ALL_TO_ALL_BENCHMARK: diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 8a22317..89cce1b 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,14 +13,21 @@ class BinnedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bins, bin_size, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): ctx.save_for_backward(indices, bins) ctx.top_k = top_k return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bins = ctx.saved_tensors out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index f65fbe8..f5ce0d6 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,14 @@ class BinnedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, weights, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): assert len(x.size()) == 3 ctx.bin_size = x.size(1) ctx.top_k = top_k @@ -24,7 +32,7 @@ def forward(ctx, x, indices, weights, bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() x, indices, weights, bins = ctx.saved_tensors out = kernels.binned_gather( diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 09b23ab..bf0482a 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -18,7 +20,7 @@ class ExclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, dim): + def forward(ctx: Any, x: torch.Tensor, dim: int): if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) @@ -35,7 +37,7 @@ def forward(ctx, x, dim): class InclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, dim): + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index a335273..41b09a1 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,14 +13,21 @@ class GatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): ctx.save_for_backward(indices, bin_ids, bins) ctx.top_k = top_k return kernels.gather(x, indices, bin_ids, None, bins, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bin_ids, bins = ctx.saved_tensors diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index 7660e82..7855233 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -18,7 +20,7 @@ class HistogramOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, max_val): + def forward(ctx: Any, x: torch.Tensor, max_val: float): return ops.histogram(x, max_val) diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index b57a518..f272a77 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,15 @@ class PaddedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): ctx.save_for_backward(indices, bin_ids, bins, padded_bins) ctx.top_k = top_k return kernels.padded_gather( @@ -27,7 +36,7 @@ def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bin_ids, bins, padded_bins = ctx.saved_tensors diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 1ca1605..9ff81dd 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,5 +1,6 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +13,16 @@ class PaddedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward( indices, @@ -36,7 +46,7 @@ def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() saved_tensors = ctx.saved_tensors diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index 61bb04b..7e9e09d 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,11 +1,10 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -# Copyright 2024 MosaicML MegaBlocks authors -# SPDX-License-Identifier: Apache-2.0 +import torch -def repeat(x, tiling): +def repeat(x: torch.Tensor, tiling: torch.Size): if all((t == 1 for t in tiling)): return x return x.repeat(*tiling) diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index b7cb9c3..2dbec35 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -17,14 +19,14 @@ class ReplicateOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, bins, num_outputs): + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): ctx.save_for_backward(bins) out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) ops.replicate_forward(x, bins, out) return out @staticmethod - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): bins, = ctx.saved_tensors out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) ops.replicate_backward(grad, bins, out) diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index 2c59a78..6cf6bc8 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -4,7 +4,7 @@ import torch -def round_up(x, value): +def round_up(x: torch.Tensor, value: int): assert isinstance(value, int) assert x.dtype == torch.int32 diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 33f051c..a5aaafc 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional + import torch from stk.backend.autocast import custom_bwd, custom_fwd @@ -12,7 +14,15 @@ class ScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, weights, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) ctx.top_k = top_k @@ -21,7 +31,7 @@ def forward(ctx, x, indices, bin_ids, weights, bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() saved_tensors = ctx.saved_tensors @@ -58,5 +68,5 @@ def scatter( weights: torch.Tensor, bins: torch.Tensor, top_k: int, -): +) -> Optional[torch.Tensor]: return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index 12ec8f3..4fb0aab 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any, Optional, Tuple + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -24,7 +26,7 @@ class SortOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, end_bit=None): + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: if end_bit is None: end_bit = _BITS_FOR_DTYPE[x.dtype] x_out = torch.empty_like(x) diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index aa81334..e00c1aa 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,8 +1,9 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +import torch -def sum(x, dim=0): +def sum(x: torch.Tensor, dim: int = 0): if x.shape[dim] == 1: return x.squeeze(dim=dim) return x.sum(dim=dim) diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index ba4ade0..b41b5fa 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,6 +1,8 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch @@ -19,11 +21,11 @@ class TopologyOp(torch.autograd.Function): @staticmethod def forward( - ctx, - padded_bins, - block_size, - output_block_rows, - output_block_columns, + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, ): out = torch.empty( output_block_rows * output_block_columns, diff --git a/megablocks/py.typed b/megablocks/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml index c72dbdf..17e1b0f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks authors # SPDX-License-Identifier: Apache-2.0 # build requirements diff --git a/setup.py b/setup.py index fa15ee4..202e3da 100644 --- a/setup.py +++ b/setup.py @@ -143,4 +143,5 @@ install_requires=install_requires, extras_require=extra_deps, python_requires='>=3.9', + package_data={_PACKAGE_NAME: ['py.typed']}, ) From 35abddf845331007803a55d9a168b9acc2ee5a9d Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Fri, 30 Aug 2024 09:21:36 -0400 Subject: [PATCH 08/11] Bump torch to <2.4.1 (#145) * bump torch to <2.5 (#142) * bump torch to <2.5 (#143) * bump torch to <2.4.1 (#144) * bump torch (#146) * install from git, not pypi * Update setup.py Co-authored-by: Saaketh Narayan * no type checking in `kernel.py` (#147) --------- Co-authored-by: Saaketh Narayan --- .github/workflows/pr-gpu.yaml | 8 +- megablocks/backend/kernels.py | 144 ++++++++++------------------------ pyproject.toml | 2 +- setup.py | 8 +- 4 files changed, 52 insertions(+), 110 deletions(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 1ca8d5b..d94b057 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -21,14 +21,14 @@ jobs: fail-fast: false matrix: include: - - name: "python3.11-pytorch2.3.1-gpus1" + - name: "python3.11-pytorch2.4.0-gpus1" gpu_num: 1 python_version: 3.11 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 - - name: "python3.11-pytorch2.3.1-gpus2" + container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 + - name: "python3.11-pytorch2.4.0-gpus2" gpu_num: 2 python_version: 3.11 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 steps: - name: Run PR GPU tests uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2 diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index ca0120b..b584cee 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,27 +1,26 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Optional import torch import triton import triton.language as tl -def assert_is_tensor(x: torch.Tensor, ndim: int): +def assert_is_tensor(x, ndim): if x.ndim != ndim: raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') -def assert_is_matrix(x: torch.Tensor): +def assert_is_matrix(x): assert_is_tensor(x, 2) -def assert_is_vector(x: torch.Tensor): +def assert_is_vector(x): if x.ndim != 1: raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') -def assert_equal(a: Any, b: Any): +def assert_equal(a, b): if a != b: raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) @@ -44,13 +43,13 @@ def assert_equal(a: Any, b: Any): ) @triton.jit def _padded_copy( - a: torch.Tensor, - b: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Any, - bins: torch.Tensor, - padded_bins: torch.Tensor, + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -105,15 +104,7 @@ def _padded_copy( offsets += BLOCK_X -def padded_gather( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -129,7 +120,7 @@ def padded_gather( # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. - output_rows = int(padded_bins[-1].cpu().item()) + output_rows = padded_bins[-1].cpu().item() out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, @@ -147,14 +138,7 @@ def padded_gather( return out -def gather( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -): +def gather(x, indices, bin_ids, weights, bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -186,15 +170,7 @@ def gather( return out -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -) -> torch.Tensor: +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -227,14 +203,7 @@ def padded_scatter( return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -) -> torch.Tensor: +def scatter(x, indices, bin_ids, weights, bins, top_k): return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) @@ -257,13 +226,13 @@ def scatter( ) @triton.jit def _padded_copy_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - wgrad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -307,15 +276,7 @@ def _padded_copy_wgrad( tl.store(wgrad, out) -def padded_scatter_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_matrix(grad) @@ -342,14 +303,7 @@ def padded_scatter_wgrad( return out -def scatter_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, -): +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) @@ -370,13 +324,13 @@ def scatter_wgrad( ) @triton.jit def _binned_copy( - a: torch.Tensor, - b: torch.Tensor, - num_experts: int, - expert_capacity: int, - indices: torch.Tensor, - weights, #: Optional[torch.Tensor], - bins: torch.Tensor, + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -435,14 +389,7 @@ def _binned_copy( offsets += BLOCK_X -def binned_gather( - x: torch.Tensor, - indices: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - expert_capacity: int, - top_k: int, -): +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -454,6 +401,7 @@ def binned_gather( num_experts = bins.shape[0] out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( x, out, @@ -470,13 +418,7 @@ def binned_gather( return out -def binned_scatter( - x: torch.Tensor, - indices: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -): +def binned_scatter(x, indices, weights, bins, top_k): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_vector(indices) @@ -524,13 +466,13 @@ def binned_scatter( ) @triton.jit def _binned_copy_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - wgrad: torch.Tensor, - num_experts: int, - expert_capacity: int, - indices: torch.Tensor, - bins: torch.Tensor, + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -576,7 +518,7 @@ def _binned_copy_wgrad( tl.store(wgrad, out) -def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int): +def binned_scatter_wgrad(x, grad, indices, bins, top_k): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_matrix(grad) diff --git a/pyproject.toml b/pyproject.toml index 17e1b0f..ad4dc1b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ # build requirements [build-system] -requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4"] +requires = ["setuptools < 70.0.0", "torch >= 2.4.0, < 2.4.1"] build-backend = "setuptools.build_meta" # Pytest diff --git a/setup.py b/setup.py index 202e3da..a7a038e 100644 --- a/setup.py +++ b/setup.py @@ -62,15 +62,15 @@ install_requires = [ 'numpy>=1.21.5,<2.1.0', 'packaging>=21.3.0,<24.2', - 'torch>=2.3.0,<2.4', + 'torch>=2.4.0,<2.4.1', 'triton>=2.1.0', - 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301', + 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@v0.7.1', ] extra_deps = {} extra_deps['gg'] = [ - 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb', + 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@v0.1.6', ] extra_deps['dev'] = [ @@ -83,7 +83,7 @@ ] extra_deps['testing'] = [ - 'mosaicml>=0.22.0', + 'mosaicml>=0.24.1', ] extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) From 964ca73df2008dc7ed6e0e61ee0da38754ef7a7e Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Fri, 30 Aug 2024 17:00:47 -0700 Subject: [PATCH 09/11] yo (#149) --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index a7a038e..4b7be37 100644 --- a/setup.py +++ b/setup.py @@ -64,13 +64,13 @@ 'packaging>=21.3.0,<24.2', 'torch>=2.4.0,<2.4.1', 'triton>=2.1.0', - 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@v0.7.1', + 'stanford-stk==0.7.1', ] extra_deps = {} extra_deps['gg'] = [ - 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@v0.1.6', + 'grouped_gemm==0.1.6', ] extra_deps['dev'] = [ From d51654546d108e01f705022c0b9c7b4e6a8cb158 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Sat, 31 Aug 2024 11:21:59 -0400 Subject: [PATCH 10/11] bump `_version.py` to 0.7.0.dev0 (#148) * bump from 0.5.1 to 0.6.0.dev0 * Update megablocks/_version.py Co-authored-by: Saaketh Narayan * Empty-Commit --------- Co-authored-by: Saaketh Narayan --- megablocks/_version.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/megablocks/_version.py b/megablocks/_version.py index 44ea780..5f259c3 100644 --- a/megablocks/_version.py +++ b/megablocks/_version.py @@ -3,4 +3,4 @@ """The MegaBlocks Version.""" -__version__ = '0.5.1' +__version__ = '0.7.0.dev0' From 66d7894c180f3c6b0240f284da8b790d3e90b918 Mon Sep 17 00:00:00 2001 From: Saaketh Narayan Date: Wed, 4 Sep 2024 19:34:33 -0700 Subject: [PATCH 11/11] Remove deprecated torch.cuda.amp custom fwd and bwd (#150) * change * change --- megablocks/layers/glu.py | 4 ++-- megablocks/layers/mlp.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index e510723..cbe0c91 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -67,7 +67,7 @@ class MemoryOptimizedGroupedGLU(torch.autograd.Function): """GroupedMLP with manually scheduled memory reuse.""" @staticmethod - @torch.cuda.amp.custom_fwd + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -102,7 +102,7 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): return dsd_out @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx, ddsd_out): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.') diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index e8f2d7b..6e6f4d8 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -18,13 +18,13 @@ class ScaleGradient(torch.autograd.Function): @staticmethod - @torch.cuda.amp.custom_fwd + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx: Any, x: torch.Tensor, scale: float): ctx.scale = scale return x @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx: torch.Tensor, grad: torch.Tensor): return grad * ctx.scale, None @@ -188,7 +188,7 @@ class MemoryOptimizedMLP(torch.autograd.Function): """Sparse MLP with manually scheduled memory reuse.""" @staticmethod - @torch.cuda.amp.custom_fwd + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, w2, topo, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -230,7 +230,7 @@ def forward(ctx, x, w1, w2, topo, activation_fn): return dsd_out @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx, ddsd_out): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.') @@ -398,7 +398,7 @@ class MemoryOptimizedGroupedMLP(torch.autograd.Function): """GroupedMLP with manually scheduled memory reuse.""" @staticmethod - @torch.cuda.amp.custom_fwd + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, w2, batch_sizes, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -431,7 +431,7 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): return dsd_out @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx: Any, ddsd_out: torch.Tensor): if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): raise ValueError('Expected all MLP inputs to need grad.')