diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..34272ee6 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,29 @@ +# What does this PR do? + + + +# What issue(s) does this change relate to? + + + +# Before submitting +- [ ] Have you read the [contributor guidelines](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md)? +- [ ] Is this change a documentation change or typo fix? If so, skip the rest of this checklist. +- [ ] Was this change discussed/approved in a GitHub issue first? It is much more likely to be merged if so. +- [ ] Did you update any related docs and document your change? +- [ ] Did you update any related tests and add any new tests related to your change? (see [testing](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md#running-tests)) +- [ ] Did you run the tests locally to make sure they pass? +- [ ] Did you run `pre-commit` on your change? (see the `pre-commit` section of [prerequisites](https://github.com/databricks/megablocks/blob/dev/CONTRIBUTING.md#prerequisites)) + + diff --git a/.github/workflows/code-quality.yaml b/.github/workflows/code-quality.yaml new file mode 100644 index 00000000..ff9081d0 --- /dev/null +++ b/.github/workflows/code-quality.yaml @@ -0,0 +1,41 @@ +name: Code Quality Checks +on: + push: + branches: + - main + - release/** + pull_request: + branches: + - main + - release/** + workflow_call: + workflow_dispatch: +# Cancel old runs when a new commit is pushed to the same branch if not on main or dev +concurrency: + group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} +defaults: + run: + working-directory: . +jobs: + code-quality: + runs-on: ubuntu-latest # TODO: switch to linux-ubuntu-latest later + timeout-minutes: 30 + strategy: + matrix: + python_version: + - "3.11" + pip_deps: + - "[dev]" + steps: + - uses: actions/checkout@v3 + - name: Get composite run steps repository + uses: actions/checkout@v3 + with: + repository: mosaicml/ci-testing + ref: v0.1.2 + path: ./ci-testing + - uses: ./ci-testing/.github/actions/code-quality + with: + python_version: ${{ matrix.python_version }} + pip_deps: ${{ matrix.pip_deps }} diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 0447b87a..d94b057b 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -15,22 +15,23 @@ concurrency: jobs: pytest-gpu: name: ${{ matrix.name }} - runs-on: ubuntu-latest # todo: switch to linux-ubuntu-latest later + if: github.repository_owner == 'databricks' + runs-on: ubuntu-latest # TODO: switch to linux-ubuntu-latest later strategy: fail-fast: false matrix: include: - - name: "python3.11-pytorch2.3.1-gpus1" + - name: "python3.11-pytorch2.4.0-gpus1" gpu_num: 1 python_version: 3.11 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 - - name: "python3.11-pytorch2.3.1-gpus2" + container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 + - name: "python3.11-pytorch2.4.0-gpus2" gpu_num: 2 python_version: 3.11 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 steps: - name: Run PR GPU tests - uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.0 + uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2 with: name: ${{ matrix.name }} container: ${{ matrix.container }} @@ -38,7 +39,8 @@ jobs: gpu_num: ${{ matrix.gpu_num }} git_repo: databricks/megablocks pip_deps: "[all,testing]" - pytest_command: "coverage run -m pytest tests" # todo: remove tests from pytest tests when we delete all tests outside of MegaBlocks repo + pytest_command: "coverage run -m pytest tests" + # TODO: remove tests from pytest tests when we delete all tests in the MegaBlocks dir pytest_markers: "gpu" composer_package_name: mosaicml # Required as Composer is built from source mcloud_timeout: 3600 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..c754b29d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,118 @@ +# Copyright 2024 Databricks authors +# SPDX-License-Identifier: Apache-2.0 + +default_language_version: + python: python3 +repos: +# - repo: local +# hooks: +# - id: pyright +# name: pyright +# entry: pyright +# language: node +# types: [python] +# pass_filenames: false +# args: [--warnings] +# additional_dependencies: ["pyright@1.1.310"] +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.2.2 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix] +- repo: https://github.com/google/yapf + rev: v0.32.0 + hooks: + - id: yapf + name: yapf + description: A formatter for Python files. + entry: yapf + args: [-i, -vv, -p] # inplace + language: python + types: [python] + additional_dependencies: + - toml +- repo: https://github.com/hadialqattan/pycln + rev: v2.1.2 + hooks: + - id: pycln + args: [. --all] +- repo: https://github.com/pycqa/isort + hooks: + - id: isort + rev: 5.12.0 +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.3.0 + hooks: + - id: check-added-large-files + - id: check-ast + - id: check-builtin-literals + - id: check-case-conflict + - id: check-docstring-first + - id: check-executables-have-shebangs + - id: check-json + - id: check-shebang-scripts-are-executable + - id: pretty-format-json + args: + - --autofix + - --no-sort-keys + - --indent=4 + - --no-ensure-ascii + - id: check-merge-conflict + - id: check-symlinks + - id: check-toml + - id: check-vcs-permalinks + - id: check-xml + - id: check-yaml + - id: debug-statements + - id: destroyed-symlinks + - id: double-quote-string-fixer + - id: end-of-file-fixer + - id: fix-byte-order-marker + - id: mixed-line-ending + - id: trailing-whitespace +- repo: https://github.com/Lucas-C/pre-commit-hooks + rev: v1.5.4 + hooks: + - id: insert-license + args: + - --license-filepath + - .pre-commit/FILE_HEADER + - --comment-style + - "#" + - --allow-past-years + types: [python] +- repo: https://github.com/PyCQA/docformatter + rev: v1.5.0 + hooks: + - id: docformatter + args: [--in-place, --wrap-summaries=120, --wrap-descriptions=120] +- repo: https://github.com/PyCQA/pydocstyle + rev: 6.1.1 + hooks: + - id: pydocstyle + name: pydocstyle + entry: pydocstyle + language: python + types: [python] + exclude: (.ci|.github) + additional_dependencies: + - toml +- repo: https://github.com/adrienverge/yamllint.git + rev: v1.28.0 + hooks: + - id: yamllint + name: yamllint + description: This hook runs yamllint. + entry: yamllint + language: python + types: [file, yaml] +- repo: https://github.com/trufflesecurity/trufflehog.git + rev: v3.40.0 + hooks: + - id: trufflehog + name: secret scan + exclude: tests/horrible_strings.py + entry: trufflehog filesystem ./ + args: + - --only-verified + - --fail diff --git a/.pre-commit/FILE_HEADER b/.pre-commit/FILE_HEADER new file mode 100644 index 00000000..5081c939 --- /dev/null +++ b/.pre-commit/FILE_HEADER @@ -0,0 +1,2 @@ +Copyright 2024 Databricks +SPDX-License-Identifier: Apache-2.0 diff --git a/.yamllint.yaml b/.yamllint.yaml new file mode 100644 index 00000000..84a08ef7 --- /dev/null +++ b/.yamllint.yaml @@ -0,0 +1,42 @@ +yaml-files: +- "*.yaml" +- "*.yml" +- .yamllint + +ignore: | + wandb + +rules: + braces: + forbid: non-empty + brackets: + forbid: false + colons: enable + commas: enable + comments: enable + comments-indentation: enable + document-end: + present: false + document-start: + present: false + empty-lines: enable + empty-values: disable + hyphens: enable + indentation: + spaces: 2 + indent-sequences: false + check-multi-line-strings: false + key-duplicates: enable + key-ordering: disable + line-length: + max: 120 + allow-non-breakable-words: true + allow-non-breakable-inline-mappings: true + new-line-at-end-of-file: enable + new-lines: enable + octal-values: enable + quoted-strings: + quote-type: double + required: false + trailing-spaces: enable + truthy: disable diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..abbe03de --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,104 @@ +# Contributing to MegaBlocks + +Thanks for considering contributing to MegaBlocks! + +Issues tagged with [good first issue](https://github.com/mosaicml/megablocks/issues?q=is%3Aissue+is%3Aopen+label%3A%22good+first+issue%22) are great options to start contributing. + +If you have questions, join us on [Slack](https://join.slack.com/t/mosaicml-community/shared_invite/zt-w0tiddn9-WGTlRpfjcO9J5jyrMub1dg) -- we'll be happy to help you! + +We welcome contributions for bug fixes, new efficient methods you'd like to contribute to the community, or new models and datasets! + +## Prerequisites + +To set up the development environment in your local box, run the commands below. + +1\. Install the dependencies needed for testing and linting the code: + + +```bash +pip install -e '.[all]' +``` + +2\. Configure [pre-commit](https://pre-commit.com/), which automatically formats code before +each commit: + + +```bash +pre-commit install +``` + +## Submitting a Contribution + +To submit a contribution: + +1\. Fork a copy of the [MegaBlocks](https://github.com/databricks/megablocks) library to your own account. + +2\. Clone your fork locally and add the megablocks repo as a remote repository: + + +```bash +git clone git@github.com:/megablocks.git +cd megablocks +git remote add upstream https://github.com/databricks/megablocks.git +``` + +3\. Create a branch and make your proposed changes. + + +```bash +git checkout -b cool-new-feature +``` + +4\. When you are ready, submit a pull request into the megablocks repository! + +## Pull request (PR) guidelines + +We have some rough guidelines that will make your PR easier to review and more likely to get smoothly merged. Please don't let uncertainty or difficulty with any of these things stop you from opening a PR! We are happy to help you through them :) +* Self-contained title and description. Please include a concise title and clear PR description. The title should allow someone to understand what the PR changes or does at a glance. The description should allow someone to understand the contents of the PR _without_ looking at the code. +* If the PR affects output that is displayed to a user of MegaBlocks (e.g. console logging or experiment tracker reporting), please include screenshots showing what the new output looks like. UX is important! +* Include tests. If you are fixing a bug, please add a test that would've caught the bug. If you are adding a new feature, please add unit tests that test the various components of the feature, and also a test that tests the full functionality of the feature. +* Please consider whether your changes affect the example notebooks or large parts of the code base, and run the daily tests locally if so (`pytest -m 'daily and not remote and not gpu and not vision and not doctest'`) +* `pre-commit` should help you handle formatting and type checking, but please do make sure you have it installed as described [above](#prerequisites). + +## Configuring README Code Snippets + +MegaBlocks uses [pytest-codeblocks](https://github.com/nschloe/pytest-codeblocks) to test all example code snippets. The pytest-codeblocks repository explains how to annotate code snippets, which supports most `pytest` configurations. For example, if a test requires model training, the GPU mark (``) should be applied. + +## Running Tests + +To test your changes locally, run: + +* `make test` # run CPU tests +* `make test-gpu` # run GPU tests +* `cd docs && make doctest` # run doctests + +Some of our checks test distributed training as well. To test these, run: + +* `make test-dist WORLD_SIZE=2` # run 2-cpu distributed tests +* `make test-dist-gpu WORLD_SIZE=2` # run 2-gpu distributed tests + +These tests run with the `composer` launcher. We also support `WORLD_SIZE=1`, which would run the tests with the `composer` launcher on a single device. + +See the [Makefile](/Makefile) for more information. + +If you want to run pre-commit hooks manually, which check for code formatting and type annotations, run `pre-commit run --all-files` + +### Docker + +To run the tests in the provided docker containers: + +* `docker pull mosaicml/composer` (or an alternative image like `mosaicml/composer:latest_cpu`) +* `docker run --rm -v ./:/composer --user $(id -u):$(id -g) -it mosaicml/composer` +* from inside the container + * `cd /megablocks` + * `pip install -e .` + * `pytest ` or `make ` to run the desired tests + + +## Code Style & Typing + +See the [MegaBlocks Style Guide](/STYLE_GUIDE.md) for guidelines on how to structure and format your code. + +MegaBlocks aims to annotate all functions with type annotations (introduced in +[PEP 526](https://www.python.org/dev/peps/pep-0526/)). Don't worry if you are not a Python typing expert; +put in the pull request, and we'll help you with getting the code into shape. diff --git a/Dockerfile b/Dockerfile index c71ed0ae..e5d9ef8b 100644 --- a/Dockerfile +++ b/Dockerfile @@ -6,4 +6,4 @@ RUN pip install flash-attn ENV PYTHONPATH="/mount/megablocks/third_party/Megatron-LM:${PYTHONPATH}" -WORKDIR /mount/megablocks \ No newline at end of file +WORKDIR /mount/megablocks diff --git a/LICENSE b/LICENSE index 63fd0525..be2d25e0 100644 --- a/LICENSE +++ b/LICENSE @@ -387,4 +387,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. \ No newline at end of file + limitations under the License. diff --git a/MANIFEST.in b/MANIFEST.in index 99749aa3..b701a758 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,2 +1,2 @@ recursive-include csrc *.h -recursive-include csrc *.cu \ No newline at end of file +recursive-include csrc *.cu diff --git a/README.md b/README.md index ee3628f0..a3013d00 100644 --- a/README.md +++ b/README.md @@ -22,9 +22,9 @@ NOTE: This assumes you have `numpy` and `torch` installed. Installing `megablocks[gg]` enables dMoE computation with grouped GEMM. This feature is enabled by setting the `mlp_impl` argument to `grouped`. This is currently our recommended path for Hopper-generation GPUs. -Installing `megablocks[dev]` allows you to contribute to MegaBlocks and test locally. Installing `megablocks[testing]` allows you to test via Github Actions. +Installing `megablocks[dev]` allows you to contribute to MegaBlocks and test locally. Installing `megablocks[testing]` allows you to test via Github Actions. If you've installed megablocks[dev], you can run pre-commit install to configure the pre-commit hook to automatically format the code. -MegaBlocks can be installed with all dependencies via the `megablocks[all]` package. +MegaBlocks can be installed with all dependencies (except for `testing`) via the `megablocks[all]` package. # :steam_locomotive: Usage diff --git a/STYLE_GUIDE.md b/STYLE_GUIDE.md new file mode 100644 index 00000000..3d5876d1 --- /dev/null +++ b/STYLE_GUIDE.md @@ -0,0 +1,480 @@ +# 1. Style and Conventions + +## 1.1 Style Guide + +MegaBlocks generally follows Google's +[Python Style Guide](https://google.github.io/styleguide/pyguide.html) for how to format and structure code. + +## 1.2. Pre-Commit Hooks + +MegaBlocks uses [Pre Commit](https://pre-commit.com/) to enforce style checks. To configure, run +``` +pip install '.[dev]' # if not already installed +pre-commit install +``` + +The pre-commit hooks will now be run before each commit. You can also run the hooks manually via: + +``` +pre-commit run # run all hooks on changed files +pre-commit run --all-files # or, run all hooks on all files +``` + + + +## 1.3. Code Formatting + +MegaBlocks uses the [yapf](https://github.com/google/yapf) formatter for general formatting +[isort](https://github.com/PyCQA/isort) to sort imports. These checks run through pre-commit +(see section 2.2). These checks can also be run manually via: + +``` +pre-commit run yapf --all-files # for yapf +pre-commit run isort --all-files # for isort +``` + +The configuration is stored in [pyproject.toml](pyproject.toml). + + +## 1.4. Code Structure + +As a general rule of thumb, + +- Don't: Default to using inheritance for code reuse + + Do: prefer [composition over inheritance](https://en.wikipedia.org/wiki/Composition_over_inheritance) +- Don't: strive to implement all logic using classes + + Do: strive to implement logic as pure functions when possible, and classes when there is good reason +- Don't: Have a function accept falsy values that would then result in a no-op. + + Example of the anti-pattern: + + ```python + from typing import Optional + + def configure_deepspeed(deepspeed_config: Optional[dict]): + if deepspeed_config is None: + # Don't do this check in the callee, which results in a no-op + return + ... + ``` + + Do: Require the caller, instead of the callee, check for and handle falsy values. It's ok to accept falsy values + for individual arguments of a caller function, so long as the entire function would not be a no-op. + + Example: + ```python + from typing import Optional + + def configure_deepspeed(deepspeed_config: dict): + ... + + def trainer(deepspeed_config: Optional[dict]): + if deepspeed_config is not None: + # Do this check in the caller function + configure_deepspeed(deepspeed_config) + ... + ``` + +# 2. Type Annotations and Typechecking + +MegaBlocks aims to annotate all functions with type annotations (introduced in +[PEP 526](https://www.python.org/dev/peps/pep-0526/)). Type annotations help statically catch `TypeError` and +`AttributeError` bugs, in addition to other benefits, as outlined in the PEP. + +For documentation on typing annotations, see: +* [PEP 483](https://peps.python.org/pep-0483/) for a simplified introducation +* [PEP 484](https://peps.python.org/pep-0484/) for the full specification +* [Python docs for `typing`](https://docs.python.org/3/library/typing.html) for the API reference + +MegaBlocks uses [pyright](https://github.com/microsoft/pyright) +to validate type annotations. PyRight is automatically run as part of the pre-commit hooks, but you can also +run PyRight specifically via: + +``` +pre-commit run pyright --all-files +``` + +The pyright configuration is stored in [pyproject.toml](pyproject.toml). + + +## 2.1 Debugging + +Here are some suggestions to deal with pyright errors: + +1. Suppose a variable could be one of multiple types, like the following: + + ```python + from typing import Union + + def foo(x: Union[int, None]): + return x + 5 # type error -- None + 5 is not allowed! + ``` + + PyRight will complain since `None + 5` is not a valid operation. + Instead, add a check to ensure that `x is not None`: + + ```python + from typing import Union + + def foo(x: Union[int, None]): + if x is None: + raise TypeError("x must be an integer, not None!") + return x + 5 # valid + ``` + + Assert statements also work. However, assert statements should not be used for data validation + (see the assert statement section below). + ```python + from typing import Union + + def foo(x: Union[int, None]): + assert x is not None, "x should never be None" + return x + 5 # valid + ``` + +1. For variables where it is impossible for pyright to infer the correct type, use +[cast](https://docs.python.org/3/library/typing.html#typing.cast). +1. As a last resort, add a `# type: ignore` comment to the line where pyright emits an error. +Immediately following this statement, paste in the error emitted by pyright, +so other contributors will know why this error was silenced. + + +# 3. Public APIs +A public API, generally speaking, can be invoked by a user without a leading underscore in any portion of the path. +The following are examples of public APIs in [composer](https://github.com/mosaicml/composer/tree/dev): + +* Standalone functions in public modules (e.g. `composer.utils.dist.get_world_size`) +* Classes in public modules (e.g. `composer.trainer.trainer.Trainer`) +* Public methods in public classes (e.g. `composer.trainer.trainer.Trainer.fit`) +* Public modules (e.g. `composer.trainer.trainer`) + +The following rules apply to public APIs: +1. All public APIs must have a docstring (see the Documentation section below) +1. All parameters must have type annotations. +1. To minimize user imports, parameters should should use native PyTorch or Python types whenever possible. + + It is acceptable to use a union of types, so long as one of the options is a primitive. For example, in the + constructor for `composer.trainer.trainer.Trainer`, the `device` parameter is annotated like the following: + + ```python + from typing import Optional, Union + + from composer.devices import Device + + class Trainer: + def __init__( + self, + device: Union[str, Device], + ): + if isinstance(device, str): + device = Device(device) + ... + ``` + + This signature allows a user to pass a string for a device, + rather than having to import our custom device class. + + Parameters that are for power users (such as `load_object_store`) in the Trainer are exempt from this rule. + These parameters can require custom imports. + +1. Parameters that could take a sequence of elements should also allow `None` or a singleton. + This simplifies the user API by not having to construct a list (or tuple) to hold a single element + (or no element). For example, use `Optional[Union[torch.Tensor, Sequence[torch.Tensor]]`. + + The `composer.utils.ensure_tuple` helper method can convert a singleton, list, or tuple into a tuple. + For example + + ```python + from torch import Tensor + from typing import Optional, Sequence, Union + from composer.utils import ensure_tuple + + def foo(x: Optional[Union[Tensor, Sequence[Tensor]]]) -> tuple[Tensor, ...]: + return ensure_tuple(x) # ensures that the result is always a (potentially empty) tuple of tensors + ``` + + +# 4. Use of `assert` + +`assert` should be used only in test cases and for verifying invariants (likely required for type checking), +not for data validation. As asserts can be disabled in python by using the `-O` flag +(e.g. `python -O path/to/script.py`), they are not guaranteed to run. For data validation, instead use a style like +the following: + + + + +```python +if parameter is None: + raise ValueError("parameter must be specified and cannot be None") +``` + + +# 5. Imports and `__init__.py` + +All imports in MegaBlocks should be absolute -- that is, they do not begin with a period. + +## 5.1 External Dependencies +1. All external dependencies must be specified in both [setup.py](setup.py) for pip and [meta.yaml](meta.yaml) + for Anaconda. + +1. If a dependency is not core to MegaBlocks (e.g. it is for a model, dataset, algorithm, or some callbacks): + 1. It must be specified in a entry of the `extra_deps` dictionary of [setup.py](setup.py). + This dictionary groups dependencies that can be conditionally installed. An entry named `foo` can be installed with `pip install 'megablocks[foo]'`. For example, running `pip install 'megablocks[gg]'` will install everything in `install_requires`, along with `grouped_gemm`. + 1. It must also be specified in the `run_constrained` and the `test.requires` section. + 1. The import must be conditionally imported in the code. For example: + + + + ```python + from composer import Callback + from composer.utils import MissingConditionalImportError + + class SystemMetricsMonitor(Callback) + try: + import pynvml + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group="system_metrics_monitor", + conda_package="pynvml", + conda_channel="conda-forge",) from e + ``` + + This style allows users to perform minimal install of Composer without triggering `ImportError`s if + an optional dependency is missing. + + If the corresponding package is not published on Anaconda, then set the ``conda_package`` to the pip package + name, and set ``conda_channel`` to ``None``. For example, with DeepSpeed: + + + ```python + from composer.utils import MissingConditionalImportError + + try: + import deepspeed + except ImportError as e: + raise MissingConditionalImportError(extra_deps_group="deepspeed", + conda_package="deepspeed>=0.5.5", + conda_channel=None) from e + ``` + + + + 1. If the dependency is core to MegaBlocks, add the dependency to the `install_requires` section of + [setup.py](./setup.py) and the `requirements.run` section of [meta.yaml](./meta.yaml). + +## 5.2 Use of `__all__` + +All public modules must define `__all__` to be the list of members that should be re-exported. +The variable is necessary to 1) limit what `from XXX import *` imports, and 2) ensure that the documentation only +includes exported members, not unrelated re-imports. + +For example, from [composer/callbacks/memory_monitor.py](composer/callbacks/memory_monitor.py) + +```python +"""Log memory usage during training.""" +import logging +from typing import Union + +import torch.cuda + +from composer.core import State +from composer.loggers import Logger +from composer.core.callback import Callback + +log = logging.getLogger(__name__) + +__all__ = ["MemoryMonitor"] # export only the MemoryMonitor, not other imports like `Logger`, `State`, or `Callback` + + +class MemoryMonitor(Callback): + ... +``` + + +## 5.3 `__init__.py` + +All public classes and functions should be added to the module's `__init__.py`. + + +```python +from composer.path.to.module.file import MyClass as MyClass +from composer.path.to.module.file import my_func as my_func +``` + +If a file only contains public functions, then the following is also acceptable: + + +```python +from composer.path.to.module import my_file as my_file +``` + + +# 6. Documentation + +## 6.1 Docstrings + +MegaBlocks uses [Google Style Docstrings](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html). +All public APIs require documentation. + +### 6.1.1 What to include in Docstrings? + +Docstrings, at a minimum, should include a summary of what the function or class does, along with the arguments it takes. See [below](#612-formatting-docstrings) for how to format docstrings. The [Google Style Guide](https://google.github.io/styleguide/pyguide.html) also includes some guidelines on how to write docstrings. + +### 6.1.2 Formatting Docstrings + +The following guidelines apply to documentation. +1. Each function that needs a docstring must have its input arguments, return statement (if not None), and any custom + exceptions annotated. +1. The arguments for the `__init__` signature of classes should be documented under the class-level docstring. There + should not be any `__init__`-level docstring. +1. Each argument annotation should include the type. If the argument has a default value, the type annotation should + specify "optional", and the docstring should say the default value. Some examples: + + ```python + from typing import Optional, Union + + def foo(bar: int): + """Foo. + + Args: + bar (int): Required bar. + """ + ... + + def foo2(bar: int = 42): + """Foo2. + + Args: + bar (int, optional): The first Argument. Default: ``42``. + """ + ... + + def foo3(bar: Optional[int] = None): + """Foo3. + + Args: + bar (int, optional): The first Argument. Default: ``None``. + """ + ... + + def foo4(bar: Union[int, str] = 42): + """Foo4. + + Args: + bar (int | str, optional): The first Argument. Default: ``42``. + """ + ... + + def foo5(bar: int) -> int: + """Foo5. + + Args: + bar (int): Required bar. + + Returns: + int: Description of return statement. + """ + ... + + def foo6(bar: int) -> tuple[int, str]: + """Foo6. + + Args: + bar (int): Required bar. + + Returns: + a (int): Returned value. + b (str): Returned value. + """ + ... + ``` + +### 6.1.3 Building and Viewing Docs Locally + +Assuming you already have a development install of MegaBlocks (see these [instructions](CONTRIBUTING.md#prerequisites)), here’s how to build and previous the docs locally. + +**️️ ⚠ Warning:** Jenkins treats all sphinx warnings as errors, so they must be addressed before a PR can be merged. Building docs locally can help debug any warnings showing up on Jenkins! + +In one terminal, run: + + +```bash +source path/to/megablocks_venv/bin/activate # activate your megablocks virtual env +cd megablocks/docs # cd to the docs folder insde your megablocks clone +make clean +make html +``` + +In a second terminal, run: + + +```bash +cd megablocks/docs +python3 -m http.server --directory _build/html/ +``` + +Then, navigate to [http://localhost:8000](http://localhost:8000) in your browser. + +## 6.2 Doctests + +Most docstrings should also include a `.. doctest` or `.. testcode` example to clearly illustrate how one would interact with the class or function. As part of the CI/CD process, all `.. doctest` blocks are executed to ensure the example in the documentation actually works. + +### 6.2.1 Writing Doctests + +See the [Sphinx Doctest Extension](https://www.sphinx-doc.org/en/master/usage/extensions/doctest.html) for all of the available directives. Do not use `.. code-block::` for Python examples, as they are untested. + +Any test fixtures for doctests should go in [docs/source/doctest_fixtures.py](docs/source/doctest_fixtures.py) or in a `.. testsetup::` block. + +For example: +```python +import torch +from typing import Optional + +def my_function(x: Optional[torch.Tensor]) -> torch.Tensor: + """blah function + + Args: + input (torch.Tensor): Your guess. + + Returns: + torch.Tensor: How good your input is. + + Raises: + ValueError: If your input is negative. + + Example: + .. testsetup:: + + # optional setup section, not shown in docs + import torch + x = torch.randn(42) + + + .. testcode:: + + # shown in docs; runs after testsetup + my_function(x) + """ + ... +``` + +All doctests load the [docs/source/doctest_fixtures.py](docs/source/doctest_fixtures.py) file *before* tests run. If there are any variables that would be helpful have defined for all tests, feel free to add them into this file. However, if a variable is more specific to an individual doctest, then it would be best to include it in a `.. testsetup::` block, as not to pollute the global fixture namespace. (Unlike pytest fixtures, all doctest fixtures are given to every doctest; they cannot be specifically requested) + +### 6.2.2 Running Doctests Locally + +Assuming you already have a development install of MegaBlocks (see these [instructions](CONTRIBUTING.md#prerequisites)), here’s how to run the doctests. + + +```bash +source path/to/megablocks_venv/bin/activate # activate your megablocks virtual env +cd megablocks/docs # cd to the docs folder insde your megablocks clone +make clean +make html # the html build must be completed first to ensure all doctests are identified +make doctest 2>/dev/null # For more verbosity, do not direct stderr to /dev/null +``` diff --git a/docker.sh b/docker.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_125m_8gpu.sh b/exp/dmoe/dmoe_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_356m_8gpu.sh b/exp/dmoe/dmoe_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_46m_8gpu.sh b/exp/dmoe/dmoe_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/dmoe/dmoe_760m_8gpu.sh b/exp/dmoe/dmoe_760m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_125m_1gpu.sh b/exp/gpt2/gpt2_125m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_125m_8gpu.sh b/exp/gpt2/gpt2_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_1315m_1gpu.sh b/exp/gpt2/gpt2_1315m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_1315m_8gpu.sh b/exp/gpt2/gpt2_1315m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_356m_1gpu.sh b/exp/gpt2/gpt2_356m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_356m_8gpu.sh b/exp/gpt2/gpt2_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_46m_1gpu.sh b/exp/gpt2/gpt2_46m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_46m_8gpu.sh b/exp/gpt2/gpt2_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_760m_1gpu.sh b/exp/gpt2/gpt2_760m_1gpu.sh old mode 100644 new mode 100755 diff --git a/exp/gpt2/gpt2_760m_8gpu.sh b/exp/gpt2/gpt2_760m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_125m_8gpu.sh b/exp/moe/moe_125m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_356m_8gpu.sh b/exp/moe/moe_356m_8gpu.sh old mode 100644 new mode 100755 diff --git a/exp/moe/moe_46m_8gpu.sh b/exp/moe/moe_46m_8gpu.sh old mode 100644 new mode 100755 diff --git a/megablocks/__init__.py b/megablocks/__init__.py index 90e45114..d8d18483 100644 --- a/megablocks/__init__.py +++ b/megablocks/__init__.py @@ -1,2 +1,20 @@ -import megablocks.layers.dmoe -import megablocks.layers.moe +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import ParallelDroplessMLP, dMoE +from megablocks.layers.glu import SparseGLU +from megablocks.layers.mlp import MLP, SparseMLP +from megablocks.layers.moe import MoE, ParallelMLP, get_load_balancing_loss + +__all__ = [ + 'MoE', + 'dMoE', + 'get_load_balancing_loss', + 'ParallelMLP', + 'ParallelDroplessMLP', + 'SparseMLP', + 'MLP', + 'SparseGLU', + 'Arguments', +] diff --git a/megablocks/_version.py b/megablocks/_version.py new file mode 100644 index 00000000..5f259c3a --- /dev/null +++ b/megablocks/_version.py @@ -0,0 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""The MegaBlocks Version.""" + +__version__ = '0.7.0.dev0' diff --git a/megablocks/backend/__init__.py b/megablocks/backend/__init__.py index 8b137891..9d4e43e9 100644 --- a/megablocks/backend/__init__.py +++ b/megablocks/backend/__init__.py @@ -1 +1,2 @@ - +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index f99f93cf..b584ceed 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch import triton import triton.language as tl @@ -5,7 +8,7 @@ def assert_is_tensor(x, ndim): if x.ndim != ndim: - raise ValueError(f"Expected {ndim}-tensor but got {x.ndim}-tensor") + raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') def assert_is_matrix(x): @@ -14,12 +17,12 @@ def assert_is_matrix(x): def assert_is_vector(x): if x.ndim != 1: - raise ValueError(f"Expected 1-tensor but got {x.ndim}-tensor") + raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') def assert_equal(a, b): if a != b: - raise ValueError(f"Expected dimensions to be equal but got {a} and {b}.") + raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) # a: (tokens, hidden_size), real. @@ -40,18 +43,19 @@ def assert_equal(a, b): ) @triton.jit def _padded_copy( - a, - b, - indices, - bin_ids, - weights, - bins, - padded_bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr, - A_TO_B : tl.constexpr, - SCALE : tl.constexpr): + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): # Our index into array 'a'. index_a = tl.load(indices + tl.program_id(0)) @@ -62,12 +66,12 @@ def _padded_copy( # Now we know what bin we're assigned to, but we need to know how # many threadblocks were assigned to earlier bins so we can offset # in our bin properly. - offset_in_bin = tl.program_id(0); + offset_in_bin = tl.program_id(0) if bin_idx > 0: offset_in_bin -= tl.load(bins + bin_idx - 1) # Load the starting index of our bin in array 'b'. - index_b = offset_in_bin; + index_b = offset_in_bin if bin_idx > 0: index_b += tl.load(padded_bins + bin_idx - 1) @@ -89,7 +93,8 @@ def _padded_copy( iptr = a if A_TO_B else b optr = b if A_TO_B else a - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) + for _ in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -116,10 +121,7 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. output_rows = padded_bins[-1].cpu().item() - out = torch.zeros( - (output_rows, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, out, @@ -131,7 +133,8 @@ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -150,10 +153,7 @@ def gather(x, indices, bin_ids, weights, bins, top_k): # NOTE: There is no padding so the output rows equals the # input rows multiplied by top_k. output_rows = x.shape[0] * top_k - out = torch.empty( - (output_rows, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.empty((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, out, @@ -165,7 +165,8 @@ def gather(x, indices, bin_ids, weights, bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -183,10 +184,7 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): assert_equal(indices.shape[0], weights.shape[0]) tokens = indices.shape[0] // top_k - out = torch.empty( - (tokens, top_k, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.empty((tokens, top_k, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( out, x, @@ -198,7 +196,8 @@ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=False, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) # Reduce along the top-k dimension, if needed. return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) @@ -227,16 +226,17 @@ def scatter(x, indices, bin_ids, weights, bins, top_k): ) @triton.jit def _padded_copy_wgrad( - x, - grad, - wgrad, - indices, - bin_ids, - bins, - padded_bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr): + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): # Our index into 'tokens * top_k'. index_out = tl.load(indices + tl.program_id(0)) @@ -247,12 +247,12 @@ def _padded_copy_wgrad( # Now we know what bin we're assigned to, but we need to know how # many threadblocks were assigned to earlier bins so we can offset # in our bin properly. - offset_in_bin = tl.program_id(0); + offset_in_bin = tl.program_id(0) if bin_idx > 0: offset_in_bin -= tl.load(bins + bin_idx - 1) # Load the starting index of our bin in array 'x'. - index_x = offset_in_bin; + index_x = offset_in_bin if bin_idx > 0: index_x += tl.load(padded_bins + bin_idx - 1) @@ -264,7 +264,7 @@ def _padded_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for _ in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -288,10 +288,7 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): assert_equal(bins.size(), padded_bins.size()) tokens = indices.shape[0] // top_k - out = torch.empty( - (tokens * top_k), - dtype=x.dtype, - device=x.device) + out = torch.empty((tokens * top_k), dtype=x.dtype, device=x.device) _padded_copy_wgrad[(indices.shape[0],)]( x, grad, @@ -301,7 +298,8 @@ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): bins, padded_bins, NUM_COLUMNS=x.shape[1], - TOP_K=top_k) + TOP_K=top_k, + ) return out @@ -326,18 +324,19 @@ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): ) @triton.jit def _binned_copy( - a, - b, - num_experts, - expert_capacity, - indices, - weights, - bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr, - A_TO_B : tl.constexpr, - SCALE : tl.constexpr): + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, + A_TO_B: tl.constexpr, + SCALE: tl.constexpr, +): # Load our indices into the output. expert_idx = tl.program_id(0) entry_idx = tl.program_id(1) @@ -349,7 +348,7 @@ def _binned_copy( # the number of tokens assigned to our expert. start = 0 if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) + start = tl.load(bins + expert_idx - 1) end = tl.load(bins + expert_idx) num_tokens = end - start @@ -380,7 +379,7 @@ def _binned_copy( optr = b if A_TO_B else a iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for _ in range(iterations): mask = offsets < NUM_COLUMNS x = tl.load(iptr + offsets, mask=mask) x = x.to(tl.float32) * scale.to(tl.float32) @@ -401,10 +400,7 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): assert_equal(weights.shape[0], x.shape[0] * top_k) num_experts = bins.shape[0] - out = torch.zeros( - (num_experts, expert_capacity, x.shape[1]), - dtype=x.dtype, - device=x.device) + out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) _binned_copy[(num_experts, expert_capacity)]( x, @@ -417,7 +413,8 @@ def binned_gather(x, indices, weights, bins, expert_capacity, top_k): NUM_COLUMNS=x.shape[1], A_TO_B=True, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) return out @@ -433,10 +430,7 @@ def binned_scatter(x, indices, weights, bins, top_k): num_experts, expert_capacity, hidden_size = x.shape tokens = indices.shape[0] // top_k - out = torch.zeros( - (tokens, top_k, hidden_size), - dtype=x.dtype, - device=x.device) + out = torch.zeros((tokens, top_k, hidden_size), dtype=x.dtype, device=x.device) _binned_copy[(num_experts, expert_capacity)]( out, x, @@ -448,7 +442,8 @@ def binned_scatter(x, indices, weights, bins, top_k): NUM_COLUMNS=hidden_size, A_TO_B=False, TOP_K=top_k, - SCALE=weights is not None) + SCALE=weights is not None, + ) # Reduce along the top-k dimension, if needed. return out.sum(dim=1) if top_k > 1 else out.view(tokens, hidden_size) @@ -471,16 +466,17 @@ def binned_scatter(x, indices, weights, bins, top_k): ) @triton.jit def _binned_copy_wgrad( - x, - grad, - wgrad, - num_experts, - expert_capacity, - indices, - bins, - NUM_COLUMNS : tl.constexpr, - TOP_K : tl.constexpr, - BLOCK_X : tl.constexpr): + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, + NUM_COLUMNS: tl.constexpr, + TOP_K: tl.constexpr, + BLOCK_X: tl.constexpr, +): # Load our indices into the output. expert_idx = tl.program_id(0) entry_idx = tl.program_id(1) @@ -492,7 +488,7 @@ def _binned_copy_wgrad( # the number of tokens assigned to our expert. start = 0 if expert_idx > 0: - start = tl.load(bins + expert_idx - 1) + start = tl.load(bins + expert_idx - 1) end = tl.load(bins + expert_idx) num_tokens = end - start @@ -510,7 +506,7 @@ def _binned_copy_wgrad( acc = tl.zeros((BLOCK_X,), dtype=tl.float32) iterations = tl.cdiv(NUM_COLUMNS, BLOCK_X) - for i in range(tl.cdiv(NUM_COLUMNS, BLOCK_X)): + for _ in range(iterations): mask = offsets < NUM_COLUMNS data = tl.load(x + offsets, mask=mask).to(tl.float32) scale = tl.load(grad + offsets, mask=mask).to(tl.float32) @@ -532,10 +528,7 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k): num_experts, expert_capacity, hidden_size = x.shape tokens = indices.shape[0] // top_k - out = torch.zeros( - (tokens * top_k), - dtype=x.dtype, - device=x.device) + out = torch.zeros((tokens * top_k), dtype=x.dtype, device=x.device) _binned_copy_wgrad[(num_experts, expert_capacity)]( x, grad, @@ -545,5 +538,6 @@ def binned_scatter_wgrad(x, grad, indices, bins, top_k): indices, bins, NUM_COLUMNS=hidden_size, - TOP_K=top_k) + TOP_K=top_k, + ) return out diff --git a/megablocks/benchmark_util.py b/megablocks/benchmark_util.py index abf35212..02612d95 100644 --- a/megablocks/benchmark_util.py +++ b/megablocks/benchmark_util.py @@ -1,16 +1,19 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import torch def log_benchmark(name, arguments, time, std): - print("="*60) - print(f"{name} Benchmark") - print("Benchmark Parameters:") + print('=' * 60) + print(f'{name} Benchmark') + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean time = {:.3f}ms, std time = {:.3f}ms".format(time, std)) - print("="*60) + print(f'{key} = {value}') + print('Results:') + print('mean time = {:.3f}ms, std time = {:.3f}ms'.format(time, std)) + print('=' * 60) def benchmark_function(fn, iterations=100, warmup=10): @@ -26,7 +29,7 @@ def benchmark_function(fn, iterations=100, warmup=10): start.record() fn() end.record() - + torch.cuda.synchronize() times.append(start.elapsed_time(end)) return np.mean(times), np.std(times) diff --git a/megablocks/grouped_gemm_util.py b/megablocks/grouped_gemm_util.py index be24c6f6..6d3f977f 100644 --- a/megablocks/grouped_gemm_util.py +++ b/megablocks/grouped_gemm_util.py @@ -1,15 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +import warnings + +_grouped_gemm_is_available: bool = False try: import grouped_gemm -except ImportError: - grouped_gemm = None + _grouped_gemm_is_available = True +except ImportError as error: + warnings.warn('Grouped GEMM not available.') + def grouped_gemm_is_available(): - return grouped_gemm is not None + return _grouped_gemm_is_available + def assert_grouped_gemm_is_available(): - assert grouped_gemm_is_available(), ( - "Grouped GEMM not available. Please run " - "`pip install git+https://github.com/tgale96/grouped_gemm@main`.") + msg = ( + 'Grouped GEMM not available. Please run ' + '`pip install git+https://github.com/tgale96/grouped_gemm@main`.', + ) + assert _grouped_gemm_is_available, msg + backend = grouped_gemm.backend if grouped_gemm_is_available() else None ops = grouped_gemm.ops if grouped_gemm_is_available() else None diff --git a/megablocks/layers/__init__.py b/megablocks/layers/__init__.py index 8b137891..f0c42de3 100644 --- a/megablocks/layers/__init__.py +++ b/megablocks/layers/__init__.py @@ -1 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE + +__all__ = [ + 'MoE', + 'dMoE', +] diff --git a/megablocks/layers/activation_fn.py b/megablocks/layers/activation_fn.py index 613ef311..a31770ba 100644 --- a/megablocks/layers/activation_fn.py +++ b/megablocks/layers/activation_fn.py @@ -1,16 +1,24 @@ -from typing import Callable +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Callable, Union import torch -import stk +from stk import Matrix -def act_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kwargs): - assert isinstance(x, stk.Matrix) +def act_fn( + x: Matrix, + function: Callable, + return_grad_fn: bool = False, + **kwargs, +) -> Union[tuple[Matrix, Any] | Matrix]: + assert isinstance(x, Matrix) with torch.set_grad_enabled(torch.is_grad_enabled() or return_grad_fn): if return_grad_fn: x.data.requires_grad = True out = function(x.data, **kwargs) - y = stk.Matrix( + y = Matrix( x.size(), out, x.row_indices, @@ -18,7 +26,8 @@ def act_fn(x: stk.Matrix, function: Callable, return_grad_fn: bool = False, **kw x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) if return_grad_fn: return y, out.backward return y diff --git a/megablocks/layers/all_to_all.py b/megablocks/layers/all_to_all.py index 12098ebb..5ac7067b 100644 --- a/megablocks/layers/all_to_all.py +++ b/megablocks/layers/all_to_all.py @@ -1,23 +1,28 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +import torch.distributed as dist + class AllToAllOp(torch.autograd.Function): @staticmethod def forward(ctx, x, output_split_sizes, input_split_sizes, group, async_op): - out = torch.empty( - (sum(output_split_sizes),) + x.shape[1:], - device=x.device, dtype=x.dtype) + out = torch.empty((sum(output_split_sizes),) + x.shape[1:], device=x.device, dtype=x.dtype) ctx.input_shape = x.shape ctx.output_split_sizes = output_split_sizes ctx.input_split_sizes = input_split_sizes ctx.group = group - handle = torch.distributed.all_to_all_single( - out, x, + handle = dist.all_to_all_single( + out, + x, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=group, - async_op=async_op) + async_op=async_op, + ) return out, handle @staticmethod @@ -26,15 +31,24 @@ def backward(ctx, grad, _): out = torch.empty( ctx.input_shape, device=grad.device, - dtype=grad.dtype) - torch.distributed.all_to_all_single( - out, grad, + dtype=grad.dtype, + ) + dist.all_to_all_single( + out, + grad, output_split_sizes=ctx.input_split_sizes, input_split_sizes=ctx.output_split_sizes, - group=ctx.group) + group=ctx.group, + ) return out, None, None, None, None return None, None, None, None, None + def all_to_all(x, output_split_sizes, input_split_sizes, group, async_op=False): return AllToAllOp.apply( - x, output_split_sizes, input_split_sizes, group, async_op) + x, + output_split_sizes, + input_split_sizes, + group, + async_op, + ) diff --git a/megablocks/layers/arguments.py b/megablocks/layers/arguments.py index c14b1721..0d5399b8 100644 --- a/megablocks/layers/arguments.py +++ b/megablocks/layers/arguments.py @@ -1,67 +1,72 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import dataclasses from functools import partial -import megablocks.grouped_gemm_util as grouped_gemm +from typing import Any, Callable, Optional, Union + import torch +import torch.distributed as dist import torch.nn.functional as F -from typing import Any, Callable, Optional, Union + +import megablocks.grouped_gemm_util as grouped_gemm # Type annotation for in-place Tensor initialization function. -InitFn = Callable[[torch.Tensor], None] +InitFn = Union[Callable[[torch.Tensor], None], partial[torch.Tensor]] _ALLOWED_BITWIDTHS = (-1, 4, 8) -DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate="tanh") +DEFAULT_ACTIVATION_FN = partial(F.gelu, approximate='tanh') @dataclasses.dataclass class Arguments: # Model arguments. - hidden_size : int = 1024 - ffn_hidden_size : int = 4096 - num_layers : int = 1 - bias : bool = True - return_bias : bool = True - activation_fn : Optional[Callable] = DEFAULT_ACTIVATION_FN + hidden_size: int = 1024 + ffn_hidden_size: int = 4096 + num_layers: int = 1 + bias: bool = True + return_bias: bool = True + activation_fn: Optional[Callable] = DEFAULT_ACTIVATION_FN # MoE arguments. - moe_num_experts : int = 1 - moe_top_k : int = 1 - moe_capacity_factor : int = 1 - moe_normalize_expert_weights : Optional[Union[int, float]] = None - moe_loss_weight : float = 0.1 - moe_zloss_weight : float = 0.001 - moe_jitter_eps : Optional[float] = None - moe_lbl_in_fp32 : bool = False + moe_num_experts: int = 1 + moe_top_k: int = 1 + moe_capacity_factor: int = 1 + moe_normalize_expert_weights: Optional[Union[int, float]] = None + moe_loss_weight: float = 0.1 + moe_zloss_weight: float = 0.001 + moe_jitter_eps: Optional[float] = None + moe_lbl_in_fp32: bool = False # Parallelism arguments. - moe_expert_model_parallelism : bool = False - expert_parallel_group : Optional[torch.distributed.ProcessGroup] = None - moe_weight_parallelism : bool = False - weight_parallel_group : Optional[torch.distributed.ProcessGroup] = None - pipeline_model_parallel_size : int = 1 - num_layers_per_virtual_pipeline_stage : Optional[int] = None + moe_expert_model_parallelism: bool = False + expert_parallel_group: Optional[dist.ProcessGroup] = None + pipeline_model_parallel_size: int = 1 + num_layers_per_virtual_pipeline_stage: Optional[int] = None # Compute arguments. - memory_optimized_mlp : bool = False - mlp_type : str = 'mlp' - mlp_impl : str = 'sparse' + memory_optimized_mlp: bool = False + mlp_type: str = 'mlp' + mlp_impl: str = 'sparse' # Initialization arguments. - fp16 : bool = True + fp16: bool = True bf16: bool = False - device : torch.device = torch.cuda.current_device() - init_method : InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) - output_layer_init_method : InitFn = init_method + device: Union[int, torch.device] = torch.cuda.current_device() + init_method: InitFn = partial(torch.nn.init.normal_, mean=0.0, std=0.02) + output_layer_init_method: InitFn = init_method # Benchmarking arguments. - uniform_expert_assignment : bool = False + uniform_expert_assignment: bool = False # shared expert arguments shared_expert: bool = False # enable using shared expert - fc_cls: torch.nn.Module = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) - fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict) # kwargs for custom fc layers + fc_cls: Any = torch.nn.Linear # class of the fully connected layer in shared expert (purpose: to allow using custom FC layer eg te.Linear (for FP8)) + fc_kwargs: dict[str, Any] = dataclasses.field(default_factory=dict,) # kwargs for custom fc layers remat_act_fn: bool = True # enable act fn to be rematerialized instead of stored - shared_expert_hidden_size: Optional[int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size + shared_expert_hidden_size: Optional[ + int] = None # hidden size of the shared expert IF we want to set it to something different from hidden_size shared_expert_weighted_sum: bool = False # enable using weighted sum for shared expert output (wieghted by number of experts used) def __post_init__(self): @@ -72,7 +77,7 @@ def __post_init__(self): self.shared_expert_hidden_size = self.ffn_hidden_size -def from_megatron(megatron_args): +def from_megatron(megatron_args: Any): args = Arguments() for field in dataclasses.fields(args): if hasattr(megatron_args, field.name): diff --git a/megablocks/layers/common.py b/megablocks/layers/common.py index fd99aa48..ee30e793 100644 --- a/megablocks/layers/common.py +++ b/megablocks/layers/common.py @@ -1,7 +1,12 @@ -from megablocks.layers.arguments import Arguments +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch -def dtype(args : Arguments): +from megablocks.layers.arguments import Arguments + + +def dtype(args: Arguments): if args.fp16: return torch.float16 elif args.bf16: diff --git a/megablocks/layers/dmlp_registry.py b/megablocks/layers/dmlp_registry.py index 666398ae..d765bd04 100644 --- a/megablocks/layers/dmlp_registry.py +++ b/megablocks/layers/dmlp_registry.py @@ -1,20 +1,30 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from typing import Union -from megablocks.layers import mlp -from megablocks.layers import glu + +from megablocks.layers import glu, mlp from megablocks.layers.arguments import Arguments MlpType = Union[mlp.SparseMLP, glu.SparseGLU] _REGISTRY = { - 'mlp': {'grouped': mlp.GroupedMLP, 'sparse' : mlp.SparseMLP}, - 'glu': {'grouped': glu.GroupedGLU, 'sparse': glu.SparseGLU}, + 'mlp': { + 'grouped': mlp.GroupedMLP, + 'sparse': mlp.SparseMLP, + }, + 'glu': { + 'grouped': glu.GroupedGLU, + 'sparse': glu.SparseGLU, + }, } + def get(args: Arguments) -> MlpType: """Returns an MLP for use in a dMoE instance. Uses the provided arguments to instantiate the appropriate - MLP instance. This only contains MLPs for use in dMoEs + MLP instance. This only contains MLPs for use in dMoEs (ie. only for the dropless versions of MoEs). Args: @@ -22,12 +32,11 @@ def get(args: Arguments) -> MlpType: Returns: An instantiated MLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: + if args.mlp_type not in _REGISTRY: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') if args.mlp_impl not in _REGISTRY[args.mlp_type]: - raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.') + raise ValueError(f'{args.mlp_type} does not support {args.mlp_impl} backend.',) return _REGISTRY[args.mlp_type][args.mlp_impl](args) diff --git a/megablocks/layers/dmoe.py b/megablocks/layers/dmoe.py index 04a538d6..377b77f2 100644 --- a/megablocks/layers/dmoe.py +++ b/megablocks/layers/dmoe.py @@ -1,19 +1,23 @@ -from megablocks.layers import common -from megablocks.layers import moe -from megablocks.layers import dmlp_registry -from megablocks.layers import mpu -from megablocks.layers.arguments import Arguments -import megablocks.ops as ops +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import numpy as np -import stk +import stk.ops import torch +from stk import Matrix + +import megablocks.ops as ops +from megablocks.layers import common, dmlp_registry, moe, mpu +from megablocks.layers.arguments import Arguments + def promote_scalar(x): return x.view(1) if not len(x.size()) else x + class ParallelDroplessMLP(moe.ParallelMLP): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(ParallelDroplessMLP, self).__init__(args) self.hidden_size = args.hidden_size self.ffn_hidden_size = mpu.features_per_rank(args) @@ -22,10 +26,11 @@ def __init__(self, args : Arguments): # Calculate the number of bits needed to represent the column indices # in the intermediate sparse matrix. - max_column_index = ( - (self.ffn_hidden_size * self.num_experts) // self.blocking) + max_column_index = ((self.ffn_hidden_size * self.num_experts) // self.blocking) self.transpose_sort_end_bit = max( - int(np.ceil(np.log2(max_column_index))), 1) + int(np.ceil(np.log2(max_column_index))), + 1, + ) def sparse_transpose(self, size, row_indices, column_indices, offsets): block_columns = size[1] // self.blocking @@ -37,7 +42,9 @@ def sparse_transpose(self, size, row_indices, column_indices, offsets): # To avoid overflow when we have large activation matrices we cast to # 32-bit before sorting. _, gather_indices = ops.sort( - column_indices.int(), self.transpose_sort_end_bit) + column_indices.int(), + self.transpose_sort_end_bit, + ) # There are a constant number of blocks in every row of the sparse matrix. # A blocks offset is: @@ -62,8 +69,10 @@ def topology(self, x, padded_bins): padded_tokens, _ = x.size() assert padded_tokens % self.blocking == 0 if self.ffn_hidden_size % self.blocking != 0: - raise ValueError(f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + - f'the block size {self.blocking}. Please update your configuration.') + raise ValueError( + f'The ffn_hidden_size {self.ffn_hidden_size} must be divisible by ' + + f'the block size {self.blocking}. Please update your configuration.', + ) # Offsets for the sparse matrix. All rows have the # same number of nonzero blocks dictated by the @@ -75,15 +84,18 @@ def topology(self, x, padded_bins): block_rows * blocks_per_row + 1, blocks_per_row, dtype=torch.int32, - device=x.device) + device=x.device, + ) # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, - self.blocking, - block_rows, - blocks_per_row) + column_indices = ops.topology( + padded_bins, + self.blocking, + block_rows, + blocks_per_row, + ) # TODO(tgale): This is unused. Remove the need for this in stk. # For now, use meta init to save the device memory. @@ -92,17 +104,29 @@ def topology(self, x, padded_bins): self.blocking, self.blocking, dtype=common.dtype(self.args), - device='meta') + device='meta', + ) shape = ( padded_tokens, - self.ffn_hidden_size * mpu.experts_per_rank(self.args) + self.ffn_hidden_size * mpu.experts_per_rank(self.args), ) - row_indices = stk.ops.row_indices( - shape, data, offsets, column_indices) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) column_indices_t, offsets_t, block_offsets_t = self.sparse_transpose( - shape, row_indices, column_indices, offsets) - return stk.Matrix(shape, data, row_indices, column_indices, offsets, - column_indices_t, offsets_t, block_offsets_t) + shape, + row_indices, + column_indices, + offsets, + ) + return stk.Matrix( + shape, + data, + row_indices, + column_indices, + offsets, + column_indices_t, + offsets_t, + block_offsets_t, + ) def indices_and_padded_bins(self, top_experts): # Sort the expert ids to produce the scatter/gather @@ -118,7 +142,9 @@ def indices_and_padded_bins(self, top_experts): # the matrix muliplications. Caculate the starting # position of each bin. padded_tokens_per_expert = ops.round_up( - tokens_per_expert, self.blocking) + tokens_per_expert, + self.blocking, + ) padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) padded_bins = promote_scalar(padded_bins) @@ -134,8 +160,7 @@ def sparse_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, padded_bins, tokens_per_expert = ( - self.indices_and_padded_bins(top_experts)) + indices, bin_ids, bins, padded_bins, tokens_per_expert = (self.indices_and_padded_bins(top_experts)) # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) @@ -145,7 +170,8 @@ def sparse_forward_once(self, x, expert_weights, top_experts): bin_ids, bins, padded_bins, - self.top_k) + self.top_k, + ) # Create the sparse matrix topology. with torch.no_grad(): @@ -162,37 +188,35 @@ def sparse_forward_once(self, x, expert_weights, top_experts): expert_weights, bins, padded_bins, - self.top_k) + self.top_k, + ) return x, tokens_per_expert # For use in the base-class parallel_forward_once. def sparse_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): # Round the token counts up to the block size used in the matrix # multiplication. Calculate the starting position of each bin. padded_tokens_per_expert = ops.round_up( - tokens_per_expert, self.blocking) + tokens_per_expert, + self.blocking, + ) padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) padded_bins = promote_scalar(padded_bins) # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.padded_gather( - x, - indices, - bin_ids, - bins, - padded_bins, - top_k) + x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) # Create the sparse matrix topology. with torch.no_grad(): @@ -201,7 +225,6 @@ def sparse_permute_and_compute( # Perform the expert computation. x = self.mlp(x, topo) - # Un-route the data for the MoE output. return ops.padded_scatter( x, @@ -210,7 +233,8 @@ def sparse_permute_and_compute( expert_weights, bins, padded_bins, - top_k) + top_k, + ) def grouped_forward_once(self, x, expert_weights, top_experts): # x: [sl, bs, hs] @@ -219,8 +243,7 @@ def grouped_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) out = self.grouped_permute_and_compute( x, @@ -230,60 +253,49 @@ def grouped_forward_once(self, x, expert_weights, top_experts): expert_weights, bins, -1, # unused - self.args.moe_top_k) + self.args.moe_top_k, + ) return out, tokens_per_expert def grouped_permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, # unused - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, # unused + top_k, + ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.gather( - x, - indices, - bin_ids, - bins, - top_k) + x = ops.gather(x, indices, bin_ids, bins, top_k) # Perform the expert computation. x = self.mlp(x, tokens_per_expert) # Un-route the data for the MoE output. - return ops.scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - top_k) + return ops.scatter(x, indices, bin_ids, expert_weights, bins, top_k) def forward_once(self, x, expert_weights, top_experts): if self.args.mlp_impl == 'sparse': - return self.sparse_forward_once( - x, expert_weights, top_experts) + return self.sparse_forward_once(x, expert_weights, top_experts) else: - return self.grouped_forward_once( - x, expert_weights, top_experts) - + return self.grouped_forward_once(x, expert_weights, top_experts) def permute_and_compute( - self, - x, - tokens_per_expert, - indices, - bin_ids, - expert_weights, - bins, - expert_capactiy, - top_k): + self, + x, + tokens_per_expert, + indices, + bin_ids, + expert_weights, + bins, + expert_capactiy, + top_k, + ): if self.args.mlp_impl == 'sparse': return self.sparse_permute_and_compute( x, @@ -293,7 +305,8 @@ def permute_and_compute( expert_weights, bins, expert_capactiy, - top_k) + top_k, + ) else: return self.grouped_permute_and_compute( x, @@ -303,7 +316,8 @@ def permute_and_compute( expert_weights, bins, expert_capactiy, - top_k) + top_k, + ) class dMoE(moe.MoE): diff --git a/megablocks/layers/gelu.py b/megablocks/layers/gelu.py index 49ac4a89..40b601d4 100644 --- a/megablocks/layers/gelu.py +++ b/megablocks/layers/gelu.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import stk import torch import torch.nn.functional as F @@ -6,12 +9,7 @@ @torch.jit.script def _gelu_backward_inplace(g, x): tanh_out = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x)) - ff = ( - 0.5 * x * ( - (1 - tanh_out * tanh_out) * - (0.79788456 + 0.1070322243 * x * x) - ) + 0.5 * (1 + tanh_out) - ) + ff = (0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh_out)) return g.mul_(ff) @@ -26,7 +24,8 @@ def gelu_backward_(grad: stk.Matrix, x: stk.Matrix): x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) return _gelu_backward_inplace(grad, x) @@ -34,10 +33,11 @@ def gelu(x: stk.Matrix): assert isinstance(x, stk.Matrix) return stk.Matrix( x.size(), - F.gelu(x.data, approximate="tanh"), + F.gelu(x.data, approximate='tanh'), x.row_indices, x.column_indices, x.offsets, x.column_indices_t, x.offsets_t, - x.block_offsets_t) + x.block_offsets_t, + ) diff --git a/megablocks/layers/glu.py b/megablocks/layers/glu.py index cc6931ab..cbe0c915 100644 --- a/megablocks/layers/glu.py +++ b/megablocks/layers/glu.py @@ -1,39 +1,57 @@ -from megablocks.layers import common -from megablocks.layers.activation_fn import act_fn -from megablocks.layers.mlp import SparseMLP, SharedMLP, create_dmoe_expert_weights, resolve_dtensor -from megablocks.layers import mpu -from megablocks.layers.arguments import Arguments, DEFAULT_ACTIVATION_FN -from megablocks import grouped_gemm_util as gg -import stk +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import stk.ops import torch +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import Arguments +from megablocks.layers.mlp import ( + SharedMLP, + SparseMLP, + create_dmoe_expert_weights, + resolve_dtensor, +) + class SparseGLU(SparseMLP): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__(args) - self.v1 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.v1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) with torch.no_grad(): - self.v1.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method)) + self.v1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) mpu.set_expert_model_parallel_attributes( - self.v1, self._should_set_parallelism_attribute) - - if self.args.moe_weight_parallelism: - raise NotImplementedError("Weight parallelism not yet supported with GLU.") + self.v1, + self._should_set_parallelism_attribute, + ) def forward(self, x, topo): if self.args.memory_optimized_mlp: - raise NotImplementedError("Memory optimized implementation not yet supported with GLU with sparse kernels.") + raise NotImplementedError( + 'Memory optimized implementation not yet supported with GLU with sparse kernels.', + ) - w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2) + w1, v1, w2 = self.scale_grad(self.w1), self.scale_grad(self.v1,), self.scale_grad(self.w2) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) # Compute the GLU. x1 = stk.ops.sdd(x, w1.t(), topo) @@ -44,11 +62,12 @@ def forward(self, x, topo): return stk.ops.dsd(x1, w2) + class MemoryOptimizedGroupedGLU(torch.autograd.Function): """GroupedMLP with manually scheduled memory reuse.""" @staticmethod - @torch.cuda.amp.custom_fwd + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -57,11 +76,11 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): v1 = v1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], v1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not v1.is_contiguous() or not w2.is_contiguous()): + if (not x.is_contiguous() or not w1.is_contiguous() or not v1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1', 'v1' and 'w2'.") # Layer 0: x @ w1.t(). + assert gg.backend is not None sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) v1_out = gg.backend.gmm(x, v1, batch_sizes, trans_b=True) @@ -83,15 +102,13 @@ def forward(ctx, x, w1, v1, w2, batch_sizes, activation_fn): return dsd_out @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') # Unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, v1, w2 = saved_tensors[:3] batch_sizes = saved_tensors[3] @@ -101,21 +118,31 @@ def backward(ctx, ddsd_out): # Rematerialize activation_fn output. activation_fn = ctx.activation_fn with torch.set_grad_enabled(True): - sdd_out.requires_grad = True - v1_out.requires_grad = True - activation_fn_out = activation_fn(sdd_out) * v1_out - activation_grad_fn = activation_fn_out.backward + sdd_out.requires_grad = True + v1_out.requires_grad = True + activation_fn_out = activation_fn(sdd_out) * v1_out + activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None dw2 = gg.backend.gmm( - activation_fn_out, ddsd_out, batch_sizes, trans_a=True) + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) # Compute dactivation_fn_out. # # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out gg.backend.gmm( - ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out) + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) # Compute dsdd_out. # @@ -139,14 +166,20 @@ def backward(ctx, ddsd_out): dx += gg.backend.gmm(dv1_out, v1, batch_sizes) return dx, dw1, dv1, dw2, None, None + memory_optimized_grouped_glu = MemoryOptimizedGroupedGLU.apply class GroupedGLU(SparseGLU): + def forward(self, x, tokens_per_expert): batch_sizes = tokens_per_expert.cpu().to(torch.long) - w1, v1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.v1), self.scale_grad(self.w2)) - w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1), resolve_dtensor(w2) + w1, v1, w2 = ( + self.scale_grad(self.w1), + self.scale_grad(self.v1), + self.scale_grad(self.w2), + ) + w1, v1, w2 = resolve_dtensor(w1), resolve_dtensor(v1,), resolve_dtensor(w2) # Re-shape the weights for the grouped GEMMs. ne = mpu.experts_per_rank(self.args) @@ -156,10 +189,16 @@ def forward(self, x, tokens_per_expert): if self.args.memory_optimized_mlp: return memory_optimized_grouped_glu( - x, w1, v1, w2, batch_sizes, - self.args.activation_fn) + x, + w1, + v1, + w2, + batch_sizes, + self.args.activation_fn, + ) # Compute the MLP. + assert gg.ops is not None x1 = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x2 = gg.ops.gmm(x, v1, batch_sizes, trans_b=True) x1 = self.args.activation_fn(x1) * x2 @@ -167,11 +206,12 @@ def forward(self, x, tokens_per_expert): class SharedGLU(SharedMLP): - """GPU for shared expert + """GPU for shared expert. Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTGLU class """ - def __init__(self, args : Arguments): + + def __init__(self, args: Arguments): super().__init__(args) self.gate_proj = args.fc_cls( args.hidden_size, diff --git a/megablocks/layers/memory_test.py b/megablocks/layers/memory_test.py index e3142726..4acbd94f 100644 --- a/megablocks/layers/memory_test.py +++ b/megablocks/layers/memory_test.py @@ -1,19 +1,18 @@ -import functools +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import gc -from megablocks.layers import dmoe, arguments, mpu -from megablocks import benchmark_util -import numpy as np import torch +import torch.distributed as dist -_TESTS = ( - (8, 2048, 4096, 4096, 32, 4), +from megablocks.layers import arguments, dmoe -) +_TESTS = ((8, 2048, 4096, 4096, 32, 4),) def get_tensors(): - ptrs = set([]) + ptrs = set() out = [] for obj in gc.get_objects(): if torch.is_tensor(obj): @@ -25,13 +24,14 @@ def get_tensors(): def test_memory( - group, - batch_size, - sequence_length, - hidden_size, - ffn_hidden_size, - num_experts, - top_k): + group, + batch_size, + sequence_length, + hidden_size, + ffn_hidden_size, + num_experts, + top_k, +): args = arguments.Arguments( hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, @@ -41,13 +41,13 @@ def test_memory( expert_parallel_group=group, fp16=False, bf16=True, - device=torch.cuda.current_device()) + device=torch.cuda.current_device(), + ) layer = dmoe.dMoE(args).cuda() - x = torch.randn( - (batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.bfloat16).requires_grad_(True) + x = torch.randn((batch_size, sequence_length, hidden_size), + device=torch.cuda.current_device(), + dtype=torch.bfloat16).requires_grad_(True) torch.cuda.empty_cache() # Run forward + backward. @@ -57,16 +57,13 @@ def test_memory( # Report peak memory. mem = torch.cuda.max_memory_allocated() - print("Max Memory Allocated = {:0.0f}MiB".format( - mem / 1e6)) - print("Max Memory Reserved = {:0.0f}MiB".format( - torch.cuda.max_memory_reserved() / 1e6)) + print('Max Memory Allocated = {:0.0f}MiB'.format(mem / 1e6)) + print('Max Memory Reserved = {:0.0f}MiB'.format(torch.cuda.max_memory_reserved() / 1e6,),) # Calculate weight and gradient memory usage. weight_memory = 2 * ( - layer.router.layer.weight.numel() + - layer.experts.mlp.w1.numel() + - layer.experts.mlp.w2.numel()) + layer.router.layer.weight.numel() + layer.experts.mlp.w1.numel() + layer.experts.mlp.w2.numel() + ) def grad_numel(x): if x.grad is not None: @@ -74,15 +71,12 @@ def grad_numel(x): return 0 grad_memory = 2 * ( - grad_numel(layer.router.layer.weight) + - grad_numel(layer.experts.mlp.w1) + - grad_numel(layer.experts.mlp.w2)) + grad_numel(layer.router.layer.weight) + grad_numel(layer.experts.mlp.w1) + grad_numel(layer.experts.mlp.w2) + ) weight_memory += grad_memory - print("Weight Memory Allocated = {:0.0f}MiB".format( - weight_memory / 1e6)) - print("Activation Memory Allocated = {:0.0f}MiB".format( - (mem - weight_memory) / 1e6)) + print('Weight Memory Allocated = {:0.0f}MiB'.format(weight_memory / 1e6)) + print('Activation Memory Allocated = {:0.0f}MiB'.format((mem - weight_memory) / 1e6,),) # Manually calculate GPU memory usage from the garbage # collector. @@ -92,17 +86,16 @@ def grad_numel(x): tensors = sorted(tensors, key=lambda x: -x.numel()) for i, t in enumerate(tensors): total += t.numel() - print(f"{i}: {t.shape}, {t.numel() * 2}") + print(f'{i}: {t.shape}, {t.numel() * 2}') del tensors - print("Total Bytes Found = {:0.0f}MiB".format( - total * 2 / 1e6)) + print('Total Bytes Found = {:0.0f}MiB'.format(total * 2 / 1e6)) if __name__ == '__main__': - assert torch.distributed.is_available() - group = torch.distributed.init_process_group(backend='nccl') - local_rank = torch.distributed.get_rank(group) + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) torch.cuda.set_device(local_rank) for args in _TESTS: diff --git a/megablocks/layers/memory_test.sh b/megablocks/layers/memory_test.sh old mode 100644 new mode 100755 diff --git a/megablocks/layers/mlp.py b/megablocks/layers/mlp.py index 2bb1e3b5..6e6f4d82 100644 --- a/megablocks/layers/mlp.py +++ b/megablocks/layers/mlp.py @@ -1,34 +1,38 @@ -from packaging import version +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from typing import Any -from megablocks.layers import common -from megablocks.layers import gelu -from megablocks.layers.activation_fn import act_fn -from megablocks.layers import mpu -from megablocks.layers import weight_parallel as wp -from megablocks.layers.arguments import Arguments, InitFn, DEFAULT_ACTIVATION_FN -from megablocks import grouped_gemm_util as gg import stk +import stk.backend.triton_kernels +import stk.ops import torch -import torch.nn.functional as F +from packaging import version + +from megablocks import grouped_gemm_util as gg +from megablocks.layers import common, gelu, mpu +from megablocks.layers.activation_fn import act_fn +from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn class ScaleGradient(torch.autograd.Function): @staticmethod - @torch.cuda.amp.custom_fwd - def forward(ctx, x, scale): + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') + def forward(ctx: Any, x: torch.Tensor, scale: float): ctx.scale = scale return x @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad): + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: torch.Tensor, grad: torch.Tensor): return grad * ctx.scale, None + + scale_gradient = ScaleGradient.apply -def resolve_dtensor(weight): +def resolve_dtensor(weight: torch.Tensor): if version.parse(torch.__version__) >= version.parse('2.0.0'): from torch.distributed._tensor import DTensor if isinstance(weight, DTensor): @@ -36,18 +40,23 @@ def resolve_dtensor(weight): return weight -def create_moe_expert_weights(args : Arguments, - num_experts : int, - ffn_hidden_size : int, - hidden_size : int, - init_method : InitFn): +def create_moe_expert_weights( + args: Arguments, + num_experts: int, + ffn_hidden_size: int, + hidden_size: int, + init_method: InitFn, +): # Create the entire weight matrix such that the sampled weights will # not vary between data parallelism and expert model parallelism for # the same random seed. master_weights = torch.empty( - num_experts, ffn_hidden_size, hidden_size, + num_experts, + ffn_hidden_size, + hidden_size, device=args.device, - dtype=common.dtype(args)) + dtype=common.dtype(args), + ) init_method(master_weights) if not args.moe_expert_model_parallelism: @@ -75,35 +84,44 @@ def create_moe_expert_weights(args : Arguments, # Slice the weight matrix to get the chunk for this rank. with torch.no_grad(): - weights = master_weights[ - start_expert:end_expert, start_row:end_row] + weights = master_weights[start_expert:end_expert, start_row:end_row] return weights class MLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args - expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) + # expert_parallel_world_size = mpu.get_expert_parallel_world_size(args) experts_per_rank = mpu.experts_per_rank(args) - self.w1 = torch.nn.Parameter(torch.empty( - experts_per_rank, - args.hidden_size, - mpu.features_per_rank(args), - device=args.device, - dtype=common.dtype(args))) - self.w2 = torch.nn.Parameter(torch.empty( - experts_per_rank, - mpu.features_per_rank(args), - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.w1 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + args.hidden_size, + mpu.features_per_rank(args), + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + experts_per_rank, + mpu.features_per_rank(args), + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) mpu.set_expert_model_parallel_attributes( - self.w1, args.moe_expert_model_parallelism) + self.w1, + args.moe_expert_model_parallelism, + ) mpu.set_expert_model_parallel_attributes( - self.w2, args.moe_expert_model_parallelism) + self.w2, + args.moe_expert_model_parallelism, + ) # Initialize the parameters for the MLP. # @@ -115,16 +133,26 @@ def __init__(self, args : Arguments): # usage. with torch.no_grad(): w1 = create_moe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method) + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ) self.w1.copy_(w1.transpose(1, 2).contiguous()) - self.w2.copy_(create_moe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.output_layer_init_method)) + self.w2.copy_( + create_moe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) self.gradient_scale = None if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) def scale_grad(self, w): if self.gradient_scale is None: @@ -139,35 +167,28 @@ def forward(self, x): return torch.bmm(x, w2) -def create_dmoe_expert_weights(args : Arguments, - num_experts : int, - rows : int, - columns : int, - init_method : InitFn): +def create_dmoe_expert_weights( + args: Arguments, + num_experts: int, + rows: int, + columns: int, + init_method: InitFn, +): weights = create_moe_expert_weights( - args, num_experts, rows, columns, init_method) - weights = weights.view([-1, columns]) - rows, columns = weights.shape - - if not args.moe_weight_parallelism: - return weights - - # Caclculate the number of rows on this weight parallel partition. - # 'rows' must be divisible by weight parallel world size. - weight_parallel_world_size = mpu.get_weight_parallel_world_size(args) - assert (rows % weight_parallel_world_size) == 0 - num_rows_per_rank = rows // weight_parallel_world_size - rank = mpu.get_weight_parallel_rank(args) - start_row = rank * num_rows_per_rank - end_row = (rank + 1) * num_rows_per_rank - return weights[start_row:end_row] + args, + num_experts, + rows, + columns, + init_method, + ) + return weights.view([-1, columns]) class MemoryOptimizedMLP(torch.autograd.Function): """Sparse MLP with manually scheduled memory reuse.""" @staticmethod - @torch.cuda.amp.custom_fwd + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, w2, topo, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -175,16 +196,17 @@ def forward(ctx, x, w1, w2, topo, activation_fn): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - topo_tensors = (topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) + topo_tensors = ( + topo.row_indices, + topo.column_indices, + topo.offsets, + topo.column_indices_t, + topo.offsets_t, + topo.block_offsets_t, + ) # Layer 0: x @ w1.t(). sdd_out = stk.ops.sdd(x, w1.t(), topo) @@ -208,15 +230,13 @@ def forward(ctx, x, w1, w2, topo, activation_fn): return dsd_out @staticmethod - @torch.cuda.amp.custom_bwd + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') # unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, w2 = saved_tensors[:2] topo_tensors = saved_tensors[2:8] @@ -226,7 +246,11 @@ def backward(ctx, ddsd_out): # rematerialize activation function output activation_fn = ctx.activation_fn sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors) - activation_fn_out, activation_grad_fn = act_fn(sdd_out, activation_fn, return_grad_fn=True) + activation_fn_out, activation_grad_fn = act_fn( + sdd_out, + activation_fn, + return_grad_fn=True, + ) # Compute dw2 with recomputed activation_fn output. dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out) @@ -236,12 +260,14 @@ def backward(ctx, ddsd_out): # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out stk.backend.triton_kernels.sdd( - ddsd_out, w2.t(), + ddsd_out, + w2.t(), dactivation_fn_out.shape, dactivation_fn_out.data, dactivation_fn_out.offsets, dactivation_fn_out.row_indices, - dactivation_fn_out.column_indices) + dactivation_fn_out.column_indices, + ) # Compute dsdd_out. # @@ -270,33 +296,38 @@ def backward(ctx, ddsd_out): dsdd_out.block_offsets_t, False, w1, - ddsd_out) + ddsd_out, + ) dx = ddsd_out return dx, dw1, dw2, None, None + memory_optimized_mlp = MemoryOptimizedMLP.apply class SparseMLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args - self._num_rows_per_rank = ( - (mpu.experts_per_rank(args) * mpu.features_per_rank(args)) // - mpu.get_weight_parallel_world_size(args) + self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args) + + self.w1 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + self._num_rows_per_rank, + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), ) - - self.w1 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) - self.w2 = torch.nn.Parameter(torch.empty( - self._num_rows_per_rank, - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) # Initialize the parameters for the MLP. # @@ -307,51 +338,55 @@ def __init__(self, args : Arguments): # and the slice which causes large increases in our peak memory # usage. with torch.no_grad(): - self.w1.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.init_method)) - self.w2.copy_(create_dmoe_expert_weights( - args, args.moe_num_experts, args.ffn_hidden_size, - args.hidden_size, args.output_layer_init_method)) - - self._should_set_parallelism_attribute = ( - args.moe_expert_model_parallelism or args.moe_weight_parallelism) + self.w1.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.init_method, + ), + ) + self.w2.copy_( + create_dmoe_expert_weights( + args, + args.moe_num_experts, + args.ffn_hidden_size, + args.hidden_size, + args.output_layer_init_method, + ), + ) + + self._should_set_parallelism_attribute = args.moe_expert_model_parallelism mpu.set_expert_model_parallel_attributes( - self.w1, self._should_set_parallelism_attribute) + self.w1, + self._should_set_parallelism_attribute, + ) mpu.set_expert_model_parallel_attributes( - self.w2, self._should_set_parallelism_attribute) + self.w2, + self._should_set_parallelism_attribute, + ) self.gradient_scale = None if self.args.moe_expert_model_parallelism: - self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args) + self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args,) def scale_grad(self, w): if self.gradient_scale is None: return w return scale_gradient(w, self.gradient_scale) - def parallel_forward(self, x, topo): - group = self.args.weight_parallel_group - w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2)) - if self.args.memory_optimized_mlp: - if self.args.activation_fn is not DEFAULT_ACTIVATION_FN: - raise NotImplementedError(f'memory_optimized_weight_parallel_mlp not implemented for custom {activation_fn=}.') - return wp.memory_optimized_weight_parallel_mlp( - x, w1, w2, topo, group) - - # Compute the MLP. - x = wp.sdd_nt(x, w1, topo, group) - activation_fn_out = act_fn(x, self.args.activation_fn) - return wp.dsd_nn(activation_fn_out, w2, group) - def forward(self, x, topo): w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2) w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2) - if self.args.moe_weight_parallelism: - return self.parallel_forward(x, topo) - elif self.args.memory_optimized_mlp: + if self.args.memory_optimized_mlp: return memory_optimized_mlp( - x, w1, w2, topo, self.args.activation_fn) + x, + w1, + w2, + topo, + self.args.activation_fn, + ) # Compute the MLP. x = stk.ops.sdd(x, w1.t(), topo) @@ -363,7 +398,7 @@ class MemoryOptimizedGroupedMLP(torch.autograd.Function): """GroupedMLP with manually scheduled memory reuse.""" @staticmethod - @torch.cuda.amp.custom_fwd + @torch.amp.autocast_mode.custom_fwd(device_type='cuda') def forward(ctx, x, w1, w2, batch_sizes, activation_fn): # Cast inputs using ctx dtype from AMP if ctx._fwd_used_autocast: @@ -371,11 +406,11 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): w1 = w1.to(ctx._dtype) w2 = w2.to(ctx._dtype) # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): + if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()): raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") # Layer 0: x @ w1.t(). + assert gg.backend is not None sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True) # activation_fn @@ -396,15 +431,13 @@ def forward(ctx, x, w1, w2, batch_sizes, activation_fn): return dsd_out @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, ddsd_out): - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") + @torch.amp.autocast_mode.custom_bwd(device_type='cuda') + def backward(ctx: Any, ddsd_out: torch.Tensor): + if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]): + raise ValueError('Expected all MLP inputs to need grad.') # Unpack saved tensors - dtype = ctx.dtype + # dtype = ctx.dtype saved_tensors = ctx.saved_tensors w1, w2 = saved_tensors[:2] batch_sizes = saved_tensors[2] @@ -419,15 +452,25 @@ def backward(ctx, ddsd_out): activation_grad_fn = activation_fn_out.backward # Compute dw2 with recomputed activation_fn output. + assert gg.backend is not None dw2 = gg.backend.gmm( - activation_fn_out, ddsd_out, batch_sizes, trans_a=True) + activation_fn_out, + ddsd_out, + batch_sizes, + trans_a=True, + ) # Compute dactivation_fn_out. # # NOTE: We reuse the activation_fn_out allocation. dactivation_fn_out = activation_fn_out gg.backend.gmm( - ddsd_out, w2, batch_sizes, trans_b=True, c=dactivation_fn_out) + ddsd_out, + w2, + batch_sizes, + trans_b=True, + c=dactivation_fn_out, + ) # Compute dsdd_out. # @@ -449,6 +492,7 @@ def backward(ctx, ddsd_out): dx = ddsd_out return dx, dw1, dw2, None, None + memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply @@ -463,27 +507,29 @@ def forward(self, x, tokens_per_expert): w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size) w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size) - if self.args.moe_weight_parallelism: - raise NotImplementedError( - "Weight parallelism not yet supported with GroupedMLP.") - if self.args.memory_optimized_mlp: return memory_optimized_grouped_mlp( - x, w1, w2, batch_sizes, - self.args.activation_fn) + x, + w1, + w2, + batch_sizes, + self.args.activation_fn, + ) # Compute the MLP. + assert gg.ops is not None x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True) x = self.args.activation_fn(x) return gg.ops.gmm(x, w2, batch_sizes) class SharedMLP(torch.nn.Module): - """MLP for shared expert + """MLP for shared expert. Note: this is a copy -> pasta -> modify of the LLM-Foundry MPTMLP class """ - def __init__(self, args : Arguments): + + def __init__(self, args: Arguments): super().__init__() self.args = args self.fc_kwargs: dict[str, Any] = { @@ -505,7 +551,11 @@ def __init__(self, args : Arguments): ) self.down_proj._is_residual = True # a flag for llm-foundry init - def add_experts_sharedexpert(self, shared_expert_out: torch.Tensor, expert_out: torch.Tensor) -> torch.Tensor: + def add_experts_sharedexpert( + self, + shared_expert_out: torch.Tensor, + expert_out: torch.Tensor, + ) -> torch.Tensor: # Helper function to add expert output to shared expert output # with optional weighted sum. if self.args.shared_expert_weighted_sum: @@ -513,7 +563,10 @@ def add_experts_sharedexpert(self, shared_expert_out: torch.Tensor, expert_out: # wieghted by number of experts used t_experts = self.args.moe_top_k + 1 sh_mlp_out = shared_expert_out / t_experts - return sh_mlp_out.add(expert_out, alpha=(self.args.moe_top_k / t_experts)) + return sh_mlp_out.add( + expert_out, + alpha=(self.args.moe_top_k / t_experts), + ) return shared_expert_out + expert_out diff --git a/megablocks/layers/moe.py b/megablocks/layers/moe.py index 22202178..dc5b8456 100644 --- a/megablocks/layers/moe.py +++ b/megablocks/layers/moe.py @@ -1,14 +1,15 @@ -from megablocks.layers import common -from megablocks.layers import mpu -from megablocks.layers import router -from megablocks.layers import mlp -from megablocks.layers import sharedexpert_registry -from megablocks.layers.all_to_all import all_to_all -from megablocks.layers.arguments import Arguments -import megablocks.ops as ops +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Tuple + import numpy as np import torch +import torch.distributed as dist +import megablocks.ops as ops +from megablocks.layers import common, mlp, mpu, router, sharedexpert_registry +from megablocks.layers.all_to_all import all_to_all +from megablocks.layers.arguments import Arguments _LOAD_BALANCING_LOSS = [] @@ -28,7 +29,7 @@ def clear_load_balancing_loss(): _LOAD_BALANCING_LOSS.clear() -def batched_load_balancing_loss(args : Arguments): +def batched_load_balancing_loss(args: Arguments): if args.moe_loss_weight == 0: return 0.0 @@ -36,40 +37,35 @@ def batched_load_balancing_loss(args : Arguments): # expert_scores[i].shape = (tokens, num_experts) # tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) tokens_per_expert, expert_scores, expert_logits = zip(*get_load_balancing_loss()) - num_layers_per_pipeline_stage = ( - args.num_layers // args.pipeline_model_parallel_size) + num_layers_per_pipeline_stage = (args.num_layers // args.pipeline_model_parallel_size) + if args.num_layers_per_virtual_pipeline_stage is not None: num_layers_per_pipeline_stage = args.num_layers_per_virtual_pipeline_stage if len(tokens_per_expert) != num_layers_per_pipeline_stage: raise ValueError( - f"Expected {num_layers_per_pipeline_stage} token_per_experts " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}") + f'Expected {num_layers_per_pipeline_stage} token_per_experts ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) if len(expert_scores) != num_layers_per_pipeline_stage: raise ValueError( - f"Expected {num_layers_per_pipeline_stage} expert_scores " - f"but found {len(tokens_per_expert)}.\nnum_layers = " - f"{args.num_layers}\npipeline_model_parallel_size = " - f"{args.pipeline_model_parallel_size}\n" - "num_layers_per_virtual_pipeline_stage" - f" = {args.num_layers_per_virtual_pipeline_stage}") + f'Expected {num_layers_per_pipeline_stage} expert_scores ' + f'but found {len(tokens_per_expert)}.\nnum_layers = ' + f'{args.num_layers}\npipeline_model_parallel_size = ' + f'{args.pipeline_model_parallel_size}\n' + 'num_layers_per_virtual_pipeline_stage' + f' = {args.num_layers_per_virtual_pipeline_stage}', + ) # Verify the shape of the tokens_per_expert and expert_scores tensors. - assert all([ - x.ndim == 1 and x.numel() == args.moe_num_experts - for x in tokens_per_expert - ]) + assert all((x.ndim == 1 and x.numel() == args.moe_num_experts for x in tokens_per_expert)) tokens = expert_scores[0].shape[0] - assert all([ - (x.ndim == 2 and x.shape[1] == args.moe_num_experts and - x.shape[0] == tokens) for x in expert_scores - ]) - + assert all(((x.ndim == 2 and x.shape[1] == args.moe_num_experts and x.shape[0] == tokens) for x in expert_scores)) # Concatenate the contributions of each layer and convert to # the correct types and formats for the dot product. @@ -90,15 +86,8 @@ def batched_load_balancing_loss(args : Arguments): # Calculate the total scale across all factors. # # loss_weight * num_experts / (num_layers * tokens * top_k) - scale_numerator = ( - args.moe_num_experts * - args.moe_loss_weight - ) - scale_denominator = ( - args.num_layers * - tokens * - args.moe_top_k - ) + scale_numerator = (args.moe_num_experts * args.moe_loss_weight) + scale_denominator = (args.num_layers * tokens * args.moe_top_k) scale = scale_numerator / scale_denominator zloss = (torch.logsumexp(expert_logits, dim=-1) ** 2).sum() / scale_denominator return scale * torch.dot(tokens_per_expert, expert_scores), args.moe_zloss_weight * zloss @@ -110,13 +99,13 @@ def batched_load_balancing_loss(args : Arguments): # parallel all2all. class ParallelMLP(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(ParallelMLP, self).__init__() self.args = args # Calculate the number of experts in total and the number of experts # owned by this rank. - world_size = mpu.get_expert_parallel_world_size(args) + # world_size = mpu.get_expert_parallel_world_size(args) self.num_experts = args.moe_num_experts self.top_k = self.args.moe_top_k @@ -127,30 +116,30 @@ def __init__(self, args : Arguments): # Expert MLP. self.mlp = mlp.MLP(args) + self.bias: Optional[torch.Tensor] if self.args.bias: # Note that the output bias is not parallelized with expert # model parallelism. - self.bias = torch.nn.Parameter(torch.empty( - args.hidden_size, - device=args.device, - dtype=common.dtype(args))) + self.bias = torch.nn.Parameter( + torch.empty( + args.hidden_size, + device=args.device, + dtype=common.dtype(args), + ), + ) torch.nn.init.zeros_(self.bias) else: self.register_parameter('bias', None) # Select the forward function for the operating mode. - self.forward_fn = ( - self.parallel_forward_once if - args.moe_expert_model_parallelism else - self.forward_once) + self.forward_fn = (self.parallel_forward_once if args.moe_expert_model_parallelism else self.forward_once) - def expert_capacity(self, tokens): + def expert_capacity(self, tokens: int) -> int: world_size = mpu.get_expert_parallel_world_size(self.args) - tokens_per_expert = ( - self.top_k * tokens * world_size / self.num_experts) + tokens_per_expert = (self.top_k * tokens * world_size / self.num_experts) return int(self.args.moe_capacity_factor * tokens_per_expert) - def load_balancing_loss(self, tokens_per_expert, expert_scores): + def load_balancing_loss(self, tokens_per_expert: torch.Tensor, expert_scores: torch.Tensor): """Calculate the load balancing loss contribution.""" assert len(expert_scores.size()) == 2 tokens, num_experts = expert_scores.size() @@ -161,9 +150,11 @@ def load_balancing_loss(self, tokens_per_expert, expert_scores): scale = self.num_experts / (tokens * self.top_k) return scale * torch.dot( tokens_per_expert.to(expert_scores.dtype), - expert_scores.mean(dim=0)) + expert_scores.mean(dim=0), + ) - def indices_and_bins(self, top_expert): + def indices_and_bins(self, + top_expert: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: # Sort the expert ids to produce the scatter/gather # indices for the permutation. # @@ -171,7 +162,9 @@ def indices_and_bins(self, top_expert): # prior? Could we place the `torch.max` operation to return # 32-bit expert indices? top_expert = top_expert.int() - bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + output = ops.sort(top_expert, self.sort_end_bit) + assert output is not None + bin_ids, indices = output # Histogram the expert ids to identify the number of # tokens routed to each expert. @@ -183,45 +176,52 @@ def indices_and_bins(self, top_expert): # Calculate the bin bounds for the sorted tokens. bins = ops.inclusive_cumsum(tokens_per_expert, 0) + assert bins is not None bins = bins.view(1) if not len(bins.size()) else bins + + assert isinstance(indices, torch.Tensor) + assert isinstance(bin_ids, torch.Tensor) + assert isinstance(bins, torch.Tensor) + assert isinstance(tokens_per_expert, torch.Tensor) + return indices, bin_ids, bins, tokens_per_expert def permute_and_compute( - self, - x, - tokens_per_expert, # unused - indices, - bin_ids, # unused - expert_weights, - bins, - expert_capacity, - top_k): + self, + x: torch.Tensor, + tokens_per_expert: int, # unused + indices: torch.Tensor, + bin_ids: torch.Tensor, # unused + expert_weights: torch.Tensor, + bins: torch.Tensor, + expert_capacity: int, + top_k: int, + ): # Route the tokens for MoE computation. x = x.view(-1, x.shape[-1]) - x = ops.binned_gather( - x, indices, bins, expert_capacity, top_k) + output = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + assert output is not None + x = output # Perform the expert computation. Note that we don't # use biases for these linear operations. x = self.mlp(x) # Un-route the data for the MoE output. - return ops.binned_scatter( - x, indices, expert_weights, bins, top_k) + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) - def forward_once(self, x, expert_weights, top_experts): + def forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # x: [sl, bs, hs] # expert_weights: [sl * bs, top-k] # top_experts: [sl * bs, top-k] expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. - sl, bs, hs = x.size() + sl, bs, _ = x.size() expert_capacity = self.expert_capacity(sl * bs) if expert_capacity == 0: expert_capacity = torch.max(tokens_per_expert).item() @@ -234,10 +234,11 @@ def forward_once(self, x, expert_weights, top_experts): expert_weights, bins, expert_capacity, - self.top_k) + self.top_k, + ) return x, tokens_per_expert - def parallel_forward_once(self, x, expert_weights, top_experts): + def parallel_forward_once(self, x: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): # NOTE: This function implements the same computation as forward_once # but with expert model parallelism. # @@ -262,23 +263,25 @@ def parallel_forward_once(self, x, expert_weights, top_experts): expert_weights = expert_weights.flatten() top_experts = top_experts.flatten() with torch.no_grad(): - indices, bin_ids, bins, tokens_per_expert = ( - self.indices_and_bins(top_experts)) + indices, bin_ids, bins, tokens_per_expert = (self.indices_and_bins(top_experts)) # If we're sharding the experts along the hidden dimension # multiple devices own parts of the same sets of experts. # Replicate the token counts so every device gets the counts. repeated_tokens_per_expert = ops.repeat( - tokens_per_expert, (mpu.hidden_sharding_degree(self.args),)) + tokens_per_expert, + (mpu.hidden_sharding_degree(self.args),), + ) # Pass token count information to the device on which the # target expert resides. - parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert) - tpe_handle = torch.distributed.all_to_all_single( + parallel_tokens_per_expert = torch.empty_like(repeated_tokens_per_expert,) + tpe_handle = dist.all_to_all_single( parallel_tokens_per_expert, repeated_tokens_per_expert, group=self.args.expert_parallel_group, - async_op=True) + async_op=True, + ) # Permute locally and without any padding so that tokens for each # parallel device are stored contiguously. @@ -286,12 +289,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # This view updates the shape of the tensor from [sl, bs, hs] to # [sl * bs, hs] prior to the permutation. x = x.view(-1, x.shape[-1]) - x = ops.gather( - x, - indices, - bin_ids, - bins, - self.top_k) + output = ops.gather(x, indices, bin_ids, bins, self.top_k) + assert output is not None + x = output # Compute the number of tokens that will be received from each # device and permute the input data across the devices. @@ -301,10 +301,8 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Reshape to [world_size, num_experts_per_rank]. world_size = mpu.get_expert_parallel_world_size(self.args) - repeated_tokens_per_expert = ( - repeated_tokens_per_expert.view(world_size, experts_per_rank)) - parallel_tokens_per_expert = ( - parallel_tokens_per_expert.view(world_size, experts_per_rank)) + repeated_tokens_per_expert = (repeated_tokens_per_expert.view(world_size, experts_per_rank)) + parallel_tokens_per_expert = (parallel_tokens_per_expert.view(world_size, experts_per_rank)) # TODO(tgale): It might be faster to do this on the GPU and # then communicate the results back to the host. @@ -328,9 +326,12 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # Start the cross-device permutation asynchronously so we can # overlap communication with computation. parallel_x, parallel_x_handle = all_to_all( - x, recv_counts, send_counts, + x, + recv_counts, + send_counts, self.args.expert_parallel_group, - async_op=True) + async_op=True, + ) with torch.no_grad(): # After we do the cross-device permutation we have the tokens on the @@ -340,48 +341,46 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # rest of this torch.no_grad() scope sets up the indices and bins # for this permutation. replicate_bins = ops.inclusive_cumsum( - parallel_tokens_per_expert.flatten(), 0) - replicate_bins = ( - replicate_bins.view(1) - if not len(replicate_bins.size()) - else replicate_bins + parallel_tokens_per_expert.flatten(), + 0, ) + replicate_bins = (replicate_bins.view(1) if not len(replicate_bins.size()) else replicate_bins) # Construct the expert indices for the permuted tokens. parallel_top_expert = torch.remainder( torch.arange( self.num_experts * mpu.hidden_sharding_degree(self.args), dtype=torch.int32, - device=indices.device + device=indices.device, ), mpu.experts_per_rank(self.args), ) parallel_top_expert = ops.replicate( parallel_top_expert.unsqueeze(dim=0), - replicate_bins, tokens_received).flatten() + replicate_bins, + tokens_received, + ).flatten() # TODO(tgale): The sort_end_bit here can be reduced. parallel_bin_ids, parallel_indices = ops.sort( - parallel_top_expert, self.sort_end_bit) + parallel_top_expert, + self.sort_end_bit, + ) # Calculate the bins boundaries from the token counts. parallel_tokens_per_expert = parallel_tokens_per_expert.sum( - dim=0, dtype=torch.int) - parallel_bins = ops.inclusive_cumsum( - parallel_tokens_per_expert, 0) - parallel_bins = ( - parallel_bins.view(1) - if not len(parallel_bins.size()) - else parallel_bins + dim=0, + dtype=torch.int, ) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = (parallel_bins.view(1) if not len(parallel_bins.size()) else parallel_bins) # If expert_capacity is set to zero, set the number of tokens # per expert to the maximum we need to avoid dropping tokens. - tokens, hs = x.size() + tokens, _ = x.size() expert_capacity = self.expert_capacity(tokens) if expert_capacity == 0: - expert_capacity = torch.max( - parallel_tokens_per_expert).item() + expert_capacity = torch.max(parallel_tokens_per_expert).item() # Locally permute the tokens and perform the expert computation. # Block to make sure that the cross-device permutation is complete. @@ -390,7 +389,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts): # moved to CPU for the prior all_to_all, which avoids an extra # device synchronization. parallel_tokens_per_expert = parallel_tokens_per_expert_cpu.sum( - dim=0, dtype=torch.int) + dim=0, + dtype=torch.int, + ) parallel_x_handle.wait() parallel_x = self.permute_and_compute( parallel_x, @@ -400,12 +401,16 @@ def parallel_forward_once(self, x, expert_weights, top_experts): None, # expert_weights parallel_bins, expert_capacity, - top_k=1) + top_k=1, + ) # Un-permute the tokens across the devices. x, _ = all_to_all( - parallel_x, send_counts, recv_counts, - self.args.expert_parallel_group) + parallel_x, + send_counts, + recv_counts, + self.args.expert_parallel_group, + ) # Reduce along the hidden sharding to get the final outputs. # @@ -413,26 +418,19 @@ def parallel_forward_once(self, x, expert_weights, top_experts): shape = ( mpu.hidden_sharding_degree(self.args), -1, - self.args.hidden_size + self.args.hidden_size, ) x = ops.sum(x.view(shape), dim=0) # Un-permute locally to setup for the next series of operations. - x = ops.scatter( - x, - indices, - bin_ids, - expert_weights, - bins, - self.top_k) + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k) return x, tokens_per_expert.flatten() - def forward(self, x, scores, logits, expert_weights, top_experts): + def forward(self, x: torch.Tensor, scores: torch.Tensor, logits: torch.Tensor, expert_weights: torch.Tensor, top_experts: torch.Tensor): in_shape = x.size() # Compute the experts. - x, tokens_per_expert = self.forward_fn( - x, expert_weights, top_experts) + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) if self.training and self.args.moe_loss_weight > 0: save_load_balancing_loss((tokens_per_expert, scores, logits)) x = x.view(in_shape) @@ -445,7 +443,7 @@ def forward(self, x, scores, logits, expert_weights, top_experts): class MoE(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super(MoE, self).__init__() # Token router. @@ -462,7 +460,7 @@ def __init__(self, args : Arguments): def _init_experts_mlp(self, args: Arguments): return ParallelMLP(args) - def forward(self, x): + def forward(self, x: torch.Tensor): # NOTE: If we're going to cast the activations to lower precision # do it before we permute the tokens to save bandwidth. x = common.cast_if_autocast_enabled(x) @@ -474,5 +472,8 @@ def forward(self, x): out = self.experts(x, scores, logits, expert_weights, top_experts) if self.shared_expert is not None: shared_expert_out = self.shared_expert(x) - out = self.shared_expert.add_experts_sharedexpert(shared_expert_out, out) + out = self.shared_expert.add_experts_sharedexpert( + shared_expert_out, + out, + ) return out diff --git a/megablocks/layers/mpu.py b/megablocks/layers/mpu.py index 49bbcbe6..b2321390 100644 --- a/megablocks/layers/mpu.py +++ b/megablocks/layers/mpu.py @@ -1,96 +1,93 @@ -from megablocks.layers.arguments import Arguments +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + import torch +import torch.distributed as dist +from megablocks.layers.arguments import Arguments -def is_moe_param(tensor : torch.Tensor) -> bool: - return hasattr(tensor, 'expert_model_parallel') +class MoeParam(torch.Tensor): -def get_expert_parallel_world_size(args : Arguments) -> int: - return ( - torch.distributed.get_world_size(args.expert_parallel_group) - if args.moe_expert_model_parallelism else 1 - ) + def __init__(self): + super().__init__(self) + self.expert_model_parallel: bool -def get_expert_parallel_rank(args : Arguments) -> int: - return ( - torch.distributed.get_rank(args.expert_parallel_group) - if args.moe_expert_model_parallelism else 0 - ) +def is_moe_param(tensor: torch.Tensor) -> bool: + return hasattr(tensor, 'expert_model_parallel') -def set_expert_model_parallel_attributes(tensor : torch.Tensor, - is_parallel : bool): - assert not hasattr(tensor, 'expert_model_parallel') - setattr(tensor, 'expert_model_parallel', is_parallel) +def get_expert_parallel_world_size(args: Arguments) -> int: + return (dist.get_world_size(args.expert_parallel_group) if args.moe_expert_model_parallelism else 1) -def param_is_expert_model_parallel(param : torch.Tensor) -> bool: - return (hasattr(param, 'expert_model_parallel') and - param.expert_model_parallel) +def get_expert_parallel_rank(args: Arguments) -> int: + return (dist.get_rank(args.expert_parallel_group) if args.moe_expert_model_parallelism else 0) -def copy_expert_model_parallel_attributes(destination_tensor : torch.Tensor, - source_tensor : torch.Tensor): - if hasattr(source_tensor, 'expert_model_parallel'): - setattr(destination_tensor, 'expert_model_parallel', - getattr(source_tensor,'expert_model_parallel')) +def set_expert_model_parallel_attributes( + tensor: torch.Tensor, + is_parallel: bool, +): + assert not hasattr(tensor, 'expert_model_parallel') + setattr(tensor, 'expert_model_parallel', is_parallel) -def get_weight_parallel_world_size(args : Arguments) -> int: - return ( - torch.distributed.get_world_size(args.weight_parallel_group) - if args.moe_weight_parallelism else 1 - ) +def param_is_expert_model_parallel(param: MoeParam) -> bool: + return (hasattr(param, 'expert_model_parallel') and param.expert_model_parallel) -def get_weight_parallel_rank(args : Arguments) -> int: - return ( - torch.distributed.get_rank(args.weight_parallel_group) - if args.moe_weight_parallelism else 0 - ) +def copy_expert_model_parallel_attributes( + destination_tensor: torch.Tensor, + source_tensor: torch.Tensor, +): + if hasattr(source_tensor, 'expert_model_parallel'): + setattr( + destination_tensor, + 'expert_model_parallel', + getattr(source_tensor, 'expert_model_parallel'), + ) -def synchronized_print(group, *x): - world_size = torch.distributed.get_world_size(group) - rank = torch.distributed.get_rank(group) +def synchronized_print(group: Optional[dist.ProcessGroup], *x: torch.Tensor): + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) for i in range(world_size): - torch.distributed.barrier(group) + dist.barrier(group) if i == rank: - print(f"rank = {rank}", *x) + print(f'rank = {rank}', *x) # Helpers for expert/tensor sharding. -def expert_sharding_degree(args : Arguments) -> int: +def expert_sharding_degree(args: Arguments) -> int: world_size = get_expert_parallel_world_size(args) esd = min(world_size, args.moe_num_experts) if (args.moe_num_experts % esd) != 0: - raise ValueError( - f"Cannot shard {args.moe_num_experts} experts {esd} ways.") + raise ValueError(f'Cannot shard {args.moe_num_experts} experts {esd} ways.',) return esd -def hidden_sharding_degree(args : Arguments) -> int: +def hidden_sharding_degree(args: Arguments) -> int: world_size = get_expert_parallel_world_size(args) esd = expert_sharding_degree(args) hsd = world_size // esd if (args.ffn_hidden_size % hsd) != 0: - raise ValueError( - f"Cannot shard {args.ffn_hidden_size} features {hsd} ways.") + raise ValueError(f'Cannot shard {args.ffn_hidden_size} features {hsd} ways.',) if (esd * hsd) != world_size: raise ValueError( - f"Invalid sharding. 'expert_sharding_degree' " - f"({esd}) * hidden_sharding_degree " - f"({hsd}) != world_size ({world_size}).") + f"Invalid sharding. 'expert_sharding_degree' ({esd}) * hidden_sharding_degree ({hsd}) != world_size ({world_size}).", + ) return hsd -def experts_per_rank(args : Arguments) -> int: +def experts_per_rank(args: Arguments) -> int: return args.moe_num_experts // expert_sharding_degree(args) -def features_per_rank(args : Arguments) -> int: +def features_per_rank(args: Arguments) -> int: return args.ffn_hidden_size // hidden_sharding_degree(args) diff --git a/megablocks/layers/router.py b/megablocks/layers/router.py index ec598728..a73f4a4e 100644 --- a/megablocks/layers/router.py +++ b/megablocks/layers/router.py @@ -1,6 +1,11 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + +import torch + from megablocks.layers import common from megablocks.layers.arguments import Arguments -import torch # NOTE: To enable end-to-end benchmarking without convergence we @@ -9,18 +14,19 @@ # so that PyTorch still executes the full set of router operation. class _UniformExpertAssignment(torch.autograd.Function): - @staticmethod - def forward(ctx, x, num_experts): + def forward(ctx: Any, x: torch.Tensor, num_experts: int): out = torch.arange(x.numel(), dtype=x.dtype, device=x.device) out = torch.remainder(out, num_experts) return out.view(x.shape) + + _uniform_expert_assignment = _UniformExpertAssignment.apply class LearnedRouter(torch.nn.Module): - def __init__(self, args : Arguments): + def __init__(self, args: Arguments): super().__init__() self.args = args @@ -34,21 +40,23 @@ def __init__(self, args : Arguments): args.moe_num_experts, bias=False, dtype=common.dtype(args), - device=args.device) + device=args.device, + ) args.init_method(self.layer.weight) - def jitter(self, x): + def jitter(self, x: torch.Tensor): + assert isinstance(self.args.moe_jitter_eps, float) low = 1.0 - self.args.moe_jitter_eps high = 1.0 + self.args.moe_jitter_eps noise = torch.rand(x.size(), dtype=x.dtype, device=x.device) return low + noise * (high - low) - def _top_k(self, scores): + def _top_k(self, scores: torch.Tensor): if self.args.moe_top_k == 1: - return scores.max(dim=-1,keepdim=True) + return scores.max(dim=-1, keepdim=True) return torch.topk(scores, self.args.moe_top_k, dim=-1) - def forward(self, x): + def forward(self, x: torch.Tensor): if self.training and self.args.moe_jitter_eps is not None: x = x * self.jitter(x) # scores = self.layer(x.view(-1, x.shape[-1])).softmax(dim=-1) @@ -57,10 +65,16 @@ def forward(self, x): expert_weights, expert_indices = self._top_k(scores) if self.args.moe_normalize_expert_weights: expert_weights = expert_weights / torch.norm( - expert_weights, p=self.args.moe_normalize_expert_weights,dim=-1, keepdim=True) + expert_weights, + p=self.args.moe_normalize_expert_weights, + dim=-1, + keepdim=True, + ) expert_indices = ( - _uniform_expert_assignment(expert_indices, self.args.moe_num_experts) - if self.args.uniform_expert_assignment else expert_indices + _uniform_expert_assignment( + expert_indices, + self.args.moe_num_experts, + ) if self.args.uniform_expert_assignment else expert_indices ) return scores, logits, expert_weights, expert_indices diff --git a/megablocks/layers/sharedexpert_registry.py b/megablocks/layers/sharedexpert_registry.py index 4d323ee5..0f62db39 100644 --- a/megablocks/layers/sharedexpert_registry.py +++ b/megablocks/layers/sharedexpert_registry.py @@ -1,14 +1,17 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from typing import Union -from megablocks.layers import mlp -from megablocks.layers import glu -from megablocks.layers.arguments import Arguments +from megablocks.layers import glu, mlp +from megablocks.layers.arguments import Arguments _REGISTRY = { 'mlp': mlp.SharedMLP, 'glu': glu.SharedGLU, } + def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: """Returns an SharedMLP for use in a dMoE instance. @@ -20,9 +23,8 @@ def get(args: Arguments) -> Union[mlp.SharedMLP, glu.SharedGLU]: Returns: An instantiated SharedMLP constructed using the input args. - """ - if args.mlp_type not in _REGISTRY: + if args.mlp_type not in _REGISTRY: raise ValueError(f'Unsupported mlp type: {args.mlp_type}') return _REGISTRY[args.mlp_type](args) diff --git a/megablocks/layers/testing.py b/megablocks/layers/testing.py deleted file mode 100644 index 530026ec..00000000 --- a/megablocks/layers/testing.py +++ /dev/null @@ -1,46 +0,0 @@ -from megablocks.layers.arguments import Arguments -import torch -import torch.nn.functional as F - - -def allclose(x, y, pct=0.5): - mask = torch.isclose(x, y, rtol=5e-2) - pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100 - if pct_diff > pct: - print("{:.2f}% of values not close.".format(pct_diff)) - return False - return True - - -class FFN(torch.nn.Module): - - def __init__(self, args : Arguments): - super().__init__() - self.w1 = torch.nn.Parameter(torch.empty( - args.hidden_size, - args.ffn_hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) - self.w2 = torch.nn.Parameter(torch.empty( - args.ffn_hidden_size, - args.hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) - - def forward(self, x): - return torch.matmul(F.gelu( - torch.matmul(x, self.w1), approximate="tanh"), self.w2) - -class GLU(FFN): - - def __init__(self, args : Arguments): - super().__init__(args) - self.v1 = torch.nn.Parameter(torch.empty( - args.hidden_size, - args.ffn_hidden_size, - device=args.device, - dtype=torch.float16 if args.fp16 else torch.float32)) - - def forward(self, x): - x1 = F.gelu(torch.matmul(x, self.w1), approximate="tanh") * torch.matmul(x, self.v1) - return torch.matmul(x1, self.w2) diff --git a/megablocks/layers/weight_parallel.py b/megablocks/layers/weight_parallel.py deleted file mode 100644 index 46d5674f..00000000 --- a/megablocks/layers/weight_parallel.py +++ /dev/null @@ -1,367 +0,0 @@ -from megablocks.layers import gelu -import stk -import torch - - -def _gather_weights(w, group, parallel_w=None, async_op=False): - """Gather the weights across the process group. - - Args: - w: torch.Tensor, local shard of the weights. - group: ProcessGroup, the group to gather across. - parallel_w: torch.Tensor, option output tensor to use - for the gather. - async_op: Whether to gather asynchronously. - - Returns: - The gathered weights tensor and a handle for asynchronous - communication. - """ - n, k = w.shape - world_size = torch.distributed.get_world_size(group) - - if parallel_w is None: - parallel_w = torch.empty( - n * world_size, k, device=w.device, dtype=w.dtype) - handle = torch.distributed.all_gather_into_tensor( - parallel_w, w, group=group, async_op=async_op) - return parallel_w, handle - - -def _scaled_reduce_scatter(parallel_dw, group, dw=None, async_op=False): - """Scatter reduce the weights across the process group. - - Args: - parallel_dw: torch.Tensor, local shard of the weights. - group: ProcessGroup, the group to scatter-reduce across. - dw: torch.Tensor, option output tensor to use for the op. - async_op: Whether to scatter reduce asynchronously. - - Returns: - The reduced weights tensor, scaled by 1 / world_size, and - a handle for asynchronous communication. - """ - n, k = parallel_dw.shape - world_size = torch.distributed.get_world_size(group) - assert (n % world_size) == 0 - - # Pre-scale the gradients by the world size. - # - # NOTE: Reduce in float32, always. - parallel_dw = parallel_dw.float() / world_size - - if dw is None: - dw = torch.empty( - n // world_size, k, - device=parallel_dw.device, - dtype=torch.float32) - handle = torch.distributed.reduce_scatter_tensor( - dw, parallel_dw, group=group, async_op=async_op) - return dw, handle - - -class WeightParallelSddNt(torch.autograd.Function): - - @staticmethod - @torch.cuda.amp.custom_fwd - def forward(ctx, x, w, topo, group): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w = w.to(ctx._dtype) - # [m, k] x [n, k] = [m, n] - if not x.is_contiguous() or not w.is_contiguous(): - raise ValueError("Expected contiguous 'x' and 'w'.") - - ctx.group = group - ctx.shape = topo.shape - ctx.save_for_backward( - x, w, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) - - # TODO(tgale): Support prefetching forward weights. - parallel_w, _ = _gather_weights(w, group) - return stk.ops.sdd(x, parallel_w.t(), topo).data - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad): - x, w = ctx.saved_tensors[:2] - grad = stk.Matrix(ctx.shape, grad, *ctx.saved_tensors[2:]) - - # Start the weight gather asynchronously to overlap with the - # weight gradient computation. - parallel_w, handle = _gather_weights(w, ctx.group, async_op=True) - parallel_dw = None - if ctx.needs_input_grad[1]: - parallel_dw = stk.ops.dsd(grad.t(), x) - - # Start the weight gradient reduce scatter to overlap with the - # data gradient computation. - handle.wait() - dw, handle = _scaled_reduce_scatter(parallel_dw, ctx.group, async_op=True) - dx = None - if ctx.needs_input_grad[0]: - dx = stk.ops.dsd(grad, parallel_w) - - # NOTE: Be careful to wait and only cast dw to the output dtype once - # we've blocked on the asynchronous NCCL operation. - handle.wait() - dw = dw.to(w.dtype) - return dx, dw, None, None - - -def sdd_nt(a, b, topo, group): - return stk.Matrix( - topo.size(), - WeightParallelSddNt.apply(a, b, topo, group), - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) - - -class WeightParallelDsdNn(torch.autograd.Function): - - @staticmethod - @torch.cuda.amp.custom_fwd - def forward(ctx, - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - w, - group): - # [m, k] x [k, n] = [m, n] - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - data = data.to(ctx._dtype) - w = w.to(ctx._dtype) - if not data.is_contiguous() or not w.is_contiguous(): - raise ValueError("Expected contiguous 'data' and 'w'.") - - ctx.group = group - ctx.shape = shape - ctx.save_for_backward( - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t, - w) - x = stk.Matrix( - shape, - data, - row_indices, - column_indices, - offsets, - column_indices_t, - offsets_t, - block_offsets_t) - - # TODO(tgale): Support prefetching forward weights. - parallel_w, _ = _gather_weights(w, group) - return stk.ops.dsd(x, parallel_w) - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, grad): - x = stk.Matrix(ctx.shape, *ctx.saved_tensors[:-1]) - w = ctx.saved_tensors[-1] - - # Start the weight gather asynchronously to overlap with the - # weight gradient computation. - parallel_w, handle = _gather_weights(w, ctx.group, async_op=True) - parallel_dw = None - if ctx.needs_input_grad[-2]: - parallel_dw = stk.ops.dsd(x.t(), grad) - - # Start the weight gradient reduce scatter to overlap with the - # data gradient computation. - handle.wait() - dw, handle = _scaled_reduce_scatter(parallel_dw, ctx.group, async_op=True) - dx = None - if ctx.needs_input_grad[1]: - dx = stk.ops.sdd(grad, parallel_w.t(), x) - - # NOTE: Be careful to wait and only cast dw to the output dtype once - # we've blocked on the asynchronous NCCL operation. - handle.wait() - dw = dw.to(w.dtype) - return None, dx.data, None, None, None, None, None, None, dw, None - - -def dsd_nn(a, b, group): - return WeightParallelDsdNn.apply( - a.size(), - a.data, - a.row_indices, - a.column_indices, - a.offsets, - a.column_indices_t, - a.offsets_t, - a.block_offsets_t, - b, - group) - - -class MemoryOptimizedWeightParallelMLP(torch.autograd.Function): - """Sparse MLP with manually scheduled memory reuse.""" - - @staticmethod - @torch.cuda.amp.custom_fwd - def forward(ctx, x, w1, w2, topo, group): - # Cast inputs using ctx dtype from AMP - if ctx._fwd_used_autocast: - x = x.to(ctx._dtype) - w1 = w1.to(ctx._dtype) - w2 = w2.to(ctx._dtype) - # x: [m, k], w1: [n, k], w2: [n, k] - if (not x.is_contiguous() or not w1.is_contiguous() or - not w2.is_contiguous()): - raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.") - - # Layer 0: x @ w1.t(). - parallel_w1, _ = _gather_weights(w1, group) - sdd_out = stk.ops.sdd(x, parallel_w1.t(), topo) - - # GeLU. - gelu_out = gelu.gelu(sdd_out) - - # Layer 1: x @ w2. - # - # NOTE: Reuse the buffer for the w1 weight gather. - parallel_w2, _ = _gather_weights(w2, group, parallel_w1) - dsd_out = stk.ops.dsd(gelu_out, parallel_w2) - - # NOTE: Save the input to the layer and the gelu input for - # gradient computation. We'll re-compute the gelu forward - # pass in the backward pass to avoid materializing another - # intermediate. - ctx.group = group - ctx.shape = topo.shape - ctx.save_for_backward( - x, w1, w2, sdd_out.data, - topo.row_indices, - topo.column_indices, - topo.offsets, - topo.column_indices_t, - topo.offsets_t, - topo.block_offsets_t) - return dsd_out - - @staticmethod - @torch.cuda.amp.custom_bwd - def backward(ctx, ddsd_out): - x, w1, w2 = ctx.saved_tensors[:3] - sdd_out = stk.Matrix(ctx.shape, *ctx.saved_tensors[3:]) - - if (not ctx.needs_input_grad[0] or - not ctx.needs_input_grad[1] or - not ctx.needs_input_grad[2]): - raise ValueError("Expected all MLP inputs to need grad.") - - # Start the weight gather asynchronously to overlap with the - # weight gradient computation and gelu recompute. - parallel_w2, handle = _gather_weights( - w2, ctx.group, async_op=True) - - # Compute dw2 with recomputed gelu output. - gelu_out = gelu.gelu(sdd_out) - parallel_dw2 = stk.ops.dsd(gelu_out.t(), ddsd_out) - - # Start the weight gradient reduce scatter to overlap with the - # data gradient computation. - handle.wait() - dw2, handle = _scaled_reduce_scatter( - parallel_dw2, ctx.group, async_op=True) - - # Compute dgelu_out. - # - # NOTE: We reuse the gelu_out allocation. - stk.backend.triton_kernels.sdd( - ddsd_out, parallel_w2.t(), - sdd_out.shape, - gelu_out.data, - sdd_out.offsets, - sdd_out.row_indices, - sdd_out.column_indices) - dgelu_out = gelu_out - - # NOTE: Be careful to wait and only cast dw to the output dtype once - # we've blocked on the asynchronous NCCL operation. - handle.wait() - dw2 = dw2.to(w2.dtype) - - # Start the weight gather asynchronously to overlap with the - # weight and gelu gradient computation. - # - # NOTE: Reuse the buffer from the w2 weight gather. - parallel_w1, handle = _gather_weights( - w1, ctx.group, parallel_w2, async_op=True) - - # Compute dsdd_out. - # - # NOTE: This reuses the dgelu_out allocation. - dsdd_out = gelu.gelu_backward_(dgelu_out, sdd_out) - - # Compute dw1. - # - # NOTE: This reuses the parallel_dw2 allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.t().shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - True, # transpose_a - x, - parallel_dw2) - parallel_dw1 = parallel_dw2 - - # Start the weight gradient reduce scatter to overlap with the - # data gradient computation. - handle.wait() - dw1, handle = _scaled_reduce_scatter( - parallel_dw1, ctx.group, async_op=True) - - # Compute dx. - # - # NOTE: This reuses the ddsd_out allocation. - stk.backend.triton_kernels.dsd( - dsdd_out.shape, - dsdd_out.data, - dsdd_out.offsets, - dsdd_out.row_indices, - dsdd_out.column_indices, - dsdd_out.offsets_t, - dsdd_out.column_indices_t, - dsdd_out.block_offsets_t, - False, - parallel_w1, - ddsd_out) - dx = ddsd_out - - # NOTE: Be careful to wait and only cast dw to the output dtype once - # we've blocked on the asynchronous NCCL operation. - handle.wait() - dw1 = dw1.to(w1.dtype) - return dx, dw1, dw2, None, None - -memory_optimized_weight_parallel_mlp = MemoryOptimizedWeightParallelMLP.apply diff --git a/megablocks/ops/__init__.py b/megablocks/ops/__init__.py index 44a2909c..b9dc286a 100644 --- a/megablocks/ops/__init__.py +++ b/megablocks/ops/__init__.py @@ -1,7 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from megablocks.ops.binned_gather import binned_gather from megablocks.ops.binned_scatter import binned_scatter -from megablocks.ops.cumsum import exclusive_cumsum -from megablocks.ops.cumsum import inclusive_cumsum +from megablocks.ops.cumsum import exclusive_cumsum, inclusive_cumsum from megablocks.ops.gather import gather from megablocks.ops.histogram import histogram from megablocks.ops.padded_gather import padded_gather @@ -13,3 +15,21 @@ from megablocks.ops.sort import sort from megablocks.ops.sum import sum from megablocks.ops.topology import topology + +__all__ = [ + 'binned_gather', + 'binned_scatter', + 'exclusive_cumsum', + 'inclusive_cumsum', + 'gather', + 'histogram', + 'padded_gather', + 'padded_scatter', + 'repeat', + 'replicate', + 'round_up', + 'scatter', + 'sort', + 'sum', + 'topology', +] diff --git a/megablocks/ops/all_to_all_benchmark.py b/megablocks/ops/all_to_all_benchmark.py index d3fbcf3e..47b95301 100644 --- a/megablocks/ops/all_to_all_benchmark.py +++ b/megablocks/ops/all_to_all_benchmark.py @@ -1,6 +1,11 @@ -from megablocks.layers.all_to_all import all_to_all -from megablocks import benchmark_util +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch +import torch.distributed as dist + +from megablocks import benchmark_util +from megablocks.layers.all_to_all import all_to_all _ALL_TO_ALL_BENCHMARK = ( (8, 1024), @@ -23,29 +28,32 @@ (1024 * 1024, 1024), ) + def benchmark_all_to_all(group, sl, hs): - world_size = torch.distributed.get_world_size(group) - assert (sl % world_size) == 0 - send_recv_sizes = [sl // world_size] * world_size + world_size = dist.get_world_size(group) + assert (sl % world_size) == 0 + send_recv_sizes = [sl // world_size] * world_size + + x = torch.randn((sl, hs)).cuda().half() - x = torch.randn((sl, hs)).cuda().half() + details = { + 'world_size': world_size, + 'message_size (B)': send_recv_sizes[0] * hs * 2, # 2B elements. + } - details = { - "world_size": world_size, - "message_size (B)": send_recv_sizes[0] * hs * 2, # 2B elements. - } + def benchmark(): + return all_to_all(x, send_recv_sizes, send_recv_sizes, group) - fn = lambda: all_to_all(x, send_recv_sizes, send_recv_sizes, group) - time, std = benchmark_util.benchmark_function(fn) + time, std = benchmark_util.benchmark_function(benchmark) - if torch.distributed.get_rank(group) == 0: - benchmark_util.log_benchmark("All-To-All", details, time, std) + if dist.get_rank(group) == 0: + benchmark_util.log_benchmark('All-To-All', details, time, std) if __name__ == '__main__': - assert torch.distributed.is_available() - group = torch.distributed.init_process_group(backend='nccl') - local_rank = torch.distributed.get_rank(group) + assert dist.is_available() + group = dist.init_process_group(backend='nccl') + local_rank = dist.get_rank(group) torch.cuda.set_device(local_rank) for args in _ALL_TO_ALL_BENCHMARK: diff --git a/megablocks/ops/all_to_all_benchmark.sh b/megablocks/ops/all_to_all_benchmark.sh old mode 100644 new mode 100755 diff --git a/megablocks/ops/binned_gather.py b/megablocks/ops/binned_gather.py index 0592a553..89cce1b6 100644 --- a/megablocks/ops/binned_gather.py +++ b/megablocks/ops/binned_gather.py @@ -1,22 +1,37 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd + # Autograd wrapper for binned_gather kernel. class BinnedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bins, bin_size, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + bin_size: int, + top_k: int, + ): ctx.save_for_backward(indices, bins) ctx.top_k = top_k return kernels.binned_gather(x, indices, None, bins, bin_size, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bins = ctx.saved_tensors out = kernels.binned_scatter(grad, indices, None, bins, ctx.top_k) return out, None, None, None, None + + binned_gather = BinnedGatherOp.apply diff --git a/megablocks/ops/binned_scatter.py b/megablocks/ops/binned_scatter.py index 453de7d3..f5ce0d6f 100644 --- a/megablocks/ops/binned_scatter.py +++ b/megablocks/ops/binned_scatter.py @@ -1,13 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd + # Autograd wrapper for binned_scatter kernel. class BinnedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, weights, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): assert len(x.size()) == 3 ctx.bin_size = x.size(1) ctx.top_k = top_k @@ -19,11 +32,17 @@ def forward(ctx, x, indices, weights, bins, top_k): @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() x, indices, weights, bins = ctx.saved_tensors out = kernels.binned_gather( - grad, indices, weights, bins, ctx.bin_size, ctx.top_k) + grad, + indices, + weights, + bins, + ctx.bin_size, + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[2]: @@ -32,6 +51,9 @@ def backward(ctx, grad): grad, indices, bins, - ctx.top_k) + ctx.top_k, + ) return out, None, wgrad, None, None + + binned_scatter = BinnedScatterOp.apply diff --git a/megablocks/ops/cumsum.py b/megablocks/ops/cumsum.py index 6907f81c..bf0482ac 100644 --- a/megablocks/ops/cumsum.py +++ b/megablocks/ops/cumsum.py @@ -1,19 +1,26 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + # Autograd wrappers for cumsum kernels. -# # NOTE: Does not support gradients. class ExclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, dim): + def forward(ctx: Any, x: torch.Tensor, dim: int): if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) @@ -22,12 +29,15 @@ def forward(ctx, x, dim): out = torch.empty_like(x) ops.exclusive_cumsum(x, dim, out) return out + + exclusive_cumsum = ExclusiveCumsumOp.apply + class InclusiveCumsumOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, dim): + def forward(ctx: Any, x: torch.Tensor, dim: int) -> torch.Tensor: if len(x.size()) == 1: x = x.view([1, -1]) out = torch.empty_like(x) @@ -36,4 +46,6 @@ def forward(ctx, x, dim): out = torch.empty_like(x) ops.inclusive_cumsum(x, dim, out) return out + + inclusive_cumsum = InclusiveCumsumOp.apply diff --git a/megablocks/ops/gather.py b/megablocks/ops/gather.py index bd8da3a8..41b09a12 100644 --- a/megablocks/ops/gather.py +++ b/megablocks/ops/gather.py @@ -1,6 +1,11 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for gather kernel. @@ -8,19 +13,26 @@ class GatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): ctx.save_for_backward(indices, bin_ids, bins) ctx.top_k = top_k - return kernels.gather( - x, indices, bin_ids, None, bins, top_k) + return kernels.gather(x, indices, bin_ids, None, bins, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bin_ids, bins = ctx.saved_tensors - out = kernels.scatter( - grad, indices, bin_ids, None, bins, ctx.top_k) + out = kernels.scatter(grad, indices, bin_ids, None, bins, ctx.top_k) return out, None, None, None, None, None + + gather = GatherOp.apply diff --git a/megablocks/ops/histogram.py b/megablocks/ops/histogram.py index f81862b6..78552338 100644 --- a/megablocks/ops/histogram.py +++ b/megablocks/ops/histogram.py @@ -1,18 +1,27 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + # Autograd wrapper for histogram kernel. -# # NOTE: Does not support gradients. class HistogramOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, max_val): + def forward(ctx: Any, x: torch.Tensor, max_val: float): return ops.histogram(x, max_val) + + histogram = HistogramOp.apply diff --git a/megablocks/ops/histogram_benchmark.py b/megablocks/ops/histogram_benchmark.py index 9e0e9304..9de8e652 100644 --- a/megablocks/ops/histogram_benchmark.py +++ b/megablocks/ops/histogram_benchmark.py @@ -1,10 +1,13 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import ops import numpy as np import torch +from absl.testing import parameterized +from megablocks import ops _HISTOGRAM_TESTS = ( (16384, torch.int32, 2), @@ -17,6 +20,7 @@ (16384, torch.int32, 256), ) + def benchmark_function(fn, iterations=10): # Run once to get rid of startup overhead. fn() @@ -34,13 +38,13 @@ def benchmark_function(fn, iterations=10): def log_benchmark(arguments, mean_t, std_t): - print("="*60) - print("Benchmark Parameters:") + print('=' * 60) + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean / std = {:.2f}ms / {:.2f}ms".format(mean_t, std_t)) - print("="*60) + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) class HistogramBenchmark(parameterized.TestCase): @@ -49,12 +53,11 @@ class HistogramBenchmark(parameterized.TestCase): def testHistogram(self, n, dtype, max_val): x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: ops.histogram(x, max_val)) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.histogram(x, max_val),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) @@ -62,12 +65,11 @@ def testHistogram(self, n, dtype, max_val): def testTorchHistogram(self, n, dtype, max_val): x = torch.randint(0, 128, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: torch.histc(x, max_val, 0, max_val-1)) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.histc(x, max_val, 0, max_val - 1),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) diff --git a/megablocks/ops/matmul_benchmark.py b/megablocks/ops/matmul_benchmark.py index 632155c5..bfa7b7c2 100644 --- a/megablocks/ops/matmul_benchmark.py +++ b/megablocks/ops/matmul_benchmark.py @@ -1,10 +1,13 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import benchmark_util -from megablocks import ops import stk import torch +from absl.testing import parameterized + +from megablocks import benchmark_util, ops # Calling tensor.t() calls tensor.transpose(0, 1) which calls @@ -12,7 +15,10 @@ # this adds. def transpose_view(x): return torch.as_strided( - x, (x.shape[1], x.shape[0]), (x.stride()[1], x.stride()[0])) + x, + (x.shape[1], x.shape[0]), + (x.stride()[1], x.stride()[0]), + ) _MATMUL_TESTS = ( @@ -25,9 +31,9 @@ def transpose_view(x): def log_benchmark(name, arguments, time, std, flops): benchmark_util.log_benchmark(name, arguments, time, std) - print("flops = {:.2f}B".format(flops / 1e9)) - print("throughput = {:.2f}T".format(flops / 1e9 / time)) - print("="*60) + print('flops = {:.2f}B'.format(flops / 1e9)) + print('throughput = {:.2f}T'.format(flops / 1e9 / time)) + print('=' * 60) class MatmulBenchmark(parameterized.TestCase): @@ -48,29 +54,28 @@ def build_sparse_matrix(self, x, padded_bins, fhs, ne): block_rows * blocks_per_row + 1, blocks_per_row, dtype=torch.int32, - device=x.device) + device=x.device, + ) # Indices for the sparse matrix. The indices for # the intermediate matrix are dynamic depending # on the mapping of tokens to experts. - column_indices = ops.topology(padded_bins, - blocking, - block_rows, - blocks_per_row) + column_indices = ops.topology( + padded_bins, + blocking, + block_rows, + blocks_per_row, + ) data = torch.empty( column_indices.numel(), blocking, blocking, dtype=torch.float16, - device=x.device) + device=x.device, + ) shape = (padded_tokens, fhs * ne) - row_indices = stk.ops.row_indices( - shape, data, offsets, column_indices) - return stk.Matrix(shape, - data, - row_indices, - column_indices, - offsets) + row_indices = stk.ops.row_indices(shape, data, offsets, column_indices) + return stk.Matrix(shape, data, row_indices, column_indices, offsets) def build_input_matrix(self, sl, hs, ne): x = torch.randn((sl, hs)).cuda().half() @@ -96,16 +101,23 @@ def testFFN_Linear0_Fwd_SDD_NT(self, sl, hs, fhs, ne): topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) w = transpose_view(w) - benchmark = lambda: stk.ops.sdd(x, w, topo) + def benchmark(): + return stk.ops.sdd(x, w, topo) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0::Fwd::SDD::NT", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0::Fwd::SDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): @@ -113,16 +125,23 @@ def testFFN_Linear0_GradX_DSD_NN(self, sl, hs, fhs, ne): w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) - benchmark = lambda: stk.ops.dsd(topo, w) + def benchmark(): + return stk.ops.dsd(topo, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0::GradX::DSD::NN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0::GradX::DSD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): @@ -130,16 +149,23 @@ def testFFN_Linear0_GradW_DSD_TN(self, sl, hs, fhs, ne): topo = self.build_sparse_matrix(x, padded_bins, fhs, ne) topo = topo.t() - benchmark = lambda: stk.ops.dsd(topo, x) + def benchmark(): + return stk.ops.dsd(topo, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0::GradW::DSD::TN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): @@ -147,16 +173,23 @@ def testFFN_Linear1_Fwd_DSD_NN(self, sl, hs, fhs, ne): w = self.build_weight_matrix(ne, hs, fhs).t().contiguous() x = self.build_sparse_matrix(x, padded_bins, fhs, ne) - benchmark = lambda: stk.ops.dsd(x, w) + def benchmark(): + return stk.ops.dsd(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::Fwd::DSD::NN", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + '1::Fwd::DSD::NN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): @@ -166,16 +199,23 @@ def testFFN_Linear1_GradX_SDD_NT(self, sl, hs, fhs, ne): out = stk.ops.dsd(x, w) w = transpose_view(w) - benchmark = lambda: stk.ops.sdd(out, w, x) + def benchmark(): + return stk.ops.sdd(out, w, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::GradX::SDD::NT", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + '1::GradX::SDD::NT', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): @@ -185,16 +225,23 @@ def testFFN_Linear1_GradW_DSD_TN(self, sl, hs, fhs, ne): out = stk.ops.dsd(x, w) x = x.t() - benchmark = lambda: stk.ops.dsd(x, out) + def benchmark(): + return stk.ops.dsd(x, out) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::GradW::DSD::TN", arguments, mean_t, std_t, - x.nnz * hs * 2) + log_benchmark( + '1::GradW::DSD::TN', + arguments, + mean_t, + std_t, + x.nnz * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): @@ -205,16 +252,23 @@ def testFFN_Linear0_Fwd_DDD_NT(self, sl, hs, fhs, ne): w = w.transpose(1, 2).contiguous() w = w.transpose(1, 2) - benchmark = lambda: torch.bmm(x, w) + def benchmark(): + return torch.bmm(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0::Fwd:DDD::NT", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0::Fwd:DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): @@ -224,16 +278,23 @@ def testFFN_Linear0_GradX_DDD_NN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) w = w.transpose(1, 2).contiguous() - benchmark = lambda: torch.bmm(out, w) + def benchmark(): + return torch.bmm(out, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0:GradX:DDD::NN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0:GradX:DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): @@ -243,16 +304,23 @@ def testFFN_Linear0_GradW_DDD_TN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) out = out.transpose(1, 2) - benchmark = lambda: torch.bmm(out, x) + def benchmark(): + return torch.bmm(out, x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("0:GradW:DDD::TN", arguments, mean_t, std_t, - x.numel() * fhs * 2) + log_benchmark( + '0:GradW:DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * fhs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): @@ -260,16 +328,23 @@ def testFFN_Linear1_Fwd_DDD_NN(self, sl, hs, fhs, ne): x = torch.randn((ne, sl // ne, fhs)).cuda().half() w = torch.randn((ne, fhs, hs)).cuda().half() - benchmark = lambda: torch.bmm(x, w) + def benchmark(): + return torch.bmm(x, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::Fwd::DDD::NN", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + '1::Fwd::DDD::NN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): @@ -279,16 +354,23 @@ def testFFN_Linear1_GradX_DDD_NT(self, sl, hs, fhs, ne): out = torch.bmm(x, w) w = torch.transpose(w, 1, 2) - benchmark = lambda: torch.bmm(out, w) + def benchmark(): + return torch.bmm(out, w) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::GradX::DDD::NT", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + '1::GradX::DDD::NT', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) @parameterized.parameters(*_MATMUL_TESTS) def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): @@ -298,16 +380,23 @@ def testFFN_Linear1_GradW_DDD_TN(self, sl, hs, fhs, ne): out = torch.bmm(x, w) x = torch.transpose(x, 1, 2) - benchmark = lambda: torch.bmm(x, out) + def benchmark(): + return torch.bmm(x, out) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "ffn_hidden_size": fhs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'ffn_hidden_size': fhs, + 'num_experts': ne, } - log_benchmark("1::GradW::DDD::TN", arguments, mean_t, std_t, - x.numel() * hs * 2) + log_benchmark( + '1::GradW::DDD::TN', + arguments, + mean_t, + std_t, + x.numel() * hs * 2, + ) if __name__ == '__main__': diff --git a/megablocks/ops/padded_gather.py b/megablocks/ops/padded_gather.py index 3c2685ff..f272a776 100644 --- a/megablocks/ops/padded_gather.py +++ b/megablocks/ops/padded_gather.py @@ -1,6 +1,11 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for padded_gather kernel. @@ -8,19 +13,43 @@ class PaddedGatherOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, bins, padded_bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): ctx.save_for_backward(indices, bin_ids, bins, padded_bins) ctx.top_k = top_k return kernels.padded_gather( - x, indices, bin_ids, None, bins, padded_bins, top_k) + x, + indices, + bin_ids, + None, + bins, + padded_bins, + top_k, + ) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() indices, bin_ids, bins, padded_bins = ctx.saved_tensors out = kernels.padded_scatter( - grad, indices, bin_ids, None, bins, padded_bins, ctx.top_k) + grad, + indices, + bin_ids, + None, + bins, + padded_bins, + ctx.top_k, + ) return out, None, None, None, None, None + + padded_gather = PaddedGatherOp.apply diff --git a/megablocks/ops/padded_scatter.py b/megablocks/ops/padded_scatter.py index 22ae9237..9ff81dd9 100644 --- a/megablocks/ops/padded_scatter.py +++ b/megablocks/ops/padded_scatter.py @@ -1,6 +1,11 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 +from typing import Any + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for padded_scatter kernel. @@ -8,18 +13,40 @@ class PaddedScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, weights, bins, padded_bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward( - indices, bin_ids, weights, bins, padded_bins, *maybe_x) + indices, + bin_ids, + weights, + bins, + padded_bins, + *maybe_x, + ) ctx.top_k = top_k ctx.x_shape = x.shape return kernels.padded_scatter( - x, indices, bin_ids, weights, bins, padded_bins, top_k) + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() saved_tensors = ctx.saved_tensors @@ -33,7 +60,8 @@ def backward(ctx, grad): weights, bins, padded_bins, - ctx.top_k) + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[3]: # need wgrad @@ -45,16 +73,26 @@ def backward(ctx, grad): bin_ids, bins, padded_bins, - ctx.top_k) + ctx.top_k, + ) return dgrad, None, None, wgrad, None, None, None, None -def padded_scatter(x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int): - return PaddedScatterOp.apply(x, indices, bin_ids, weights, bins, - padded_bins, top_k) +def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, +): + return PaddedScatterOp.apply( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) diff --git a/megablocks/ops/padded_scatter_benchmark.py b/megablocks/ops/padded_scatter_benchmark.py index 7a7c3378..81dde4e4 100644 --- a/megablocks/ops/padded_scatter_benchmark.py +++ b/megablocks/ops/padded_scatter_benchmark.py @@ -1,10 +1,12 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import ops -from megablocks import benchmark_util import torch +from absl.testing import parameterized +from megablocks import benchmark_util, ops _PADDED_SCATTER_BENCHMARK = ( # dMoE-Medium, 8-way EMP. @@ -35,18 +37,29 @@ def testPaddedScatter(self, sl, hs, ne, top_k): # Gather the data to prepare for backwards. x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - fn = lambda: ops.padded_scatter( - x, indices, bin_ids, weights, bins, padded_bins, top_k) + def benchmark(): + return ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) - time, std = benchmark_util.benchmark_function(fn) + time, std = benchmark_util.benchmark_function(benchmark) benchmark_util.log_benchmark( - "Padded Scatter", - {"sequence_length": sl, - "hidden_size": hs, - "num_experts": ne, - "top_k": top_k}, + 'Padded Scatter', + { + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, + 'top_k': top_k, + }, time, - std) + std, + ) if __name__ == '__main__': diff --git a/megablocks/ops/permute_benchmark.py b/megablocks/ops/permute_benchmark.py index fb5b7f1a..837f07e2 100644 --- a/megablocks/ops/permute_benchmark.py +++ b/megablocks/ops/permute_benchmark.py @@ -1,12 +1,12 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import benchmark_util -from megablocks import ops -import numpy as np -import stk import torch +from absl.testing import parameterized +from megablocks import benchmark_util, ops _PERMUTE_TESTS = ( (16384, 768, 2), @@ -22,7 +22,7 @@ (16384 * 8, 768, 16), (16384 * 8, 768, 32), (16384 * 8, 768, 64), - (16384 * 8, 768, 128) + (16384 * 8, 768, 128), ) @@ -40,14 +40,16 @@ def testBinnedGather(self, sl, hs, ne): tokens_per_expert = ops.histogram(indices, ne) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - benchmark = lambda: ops.binned_gather(x, indices, bins, ec) + def benchmark(): + return ops.binned_gather(x, indices, bins, ec) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("BinnedGather", arguments, mean_t, std_t) + benchmark_util.log_benchmark('BinnedGather', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testBinnedScatter(self, sl, hs, ne): @@ -62,14 +64,16 @@ def testBinnedScatter(self, sl, hs, ne): bins = ops.inclusive_cumsum(tokens_per_expert, 0) x = ops.binned_gather(x, indices, bins, ec) - benchmark = lambda: ops.binned_scatter(x, indices, bins) + def benchmark(): + return ops.binned_scatter(x, indices, bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("BinnedScatter", arguments, mean_t, std_t) + benchmark_util.log_benchmark('BinnedScatter', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testPaddedGather(self, sl, hs, ne): @@ -84,14 +88,16 @@ def testPaddedGather(self, sl, hs, ne): padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - benchmark = lambda: ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + def benchmark(): + return ops.padded_gather(x, indices, bin_ids, bins, padded_bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("PaddedGather", arguments, mean_t, std_t) + benchmark_util.log_benchmark('PaddedGather', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testPaddedScatter(self, sl, hs, ne): @@ -107,32 +113,36 @@ def testPaddedScatter(self, sl, hs, ne): bins = ops.inclusive_cumsum(tokens_per_expert, 0) x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins) - benchmark = lambda: ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + def benchmark(): + return ops.padded_scatter(x, indices, bin_ids, bins, padded_bins) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("PaddedScatter", arguments, mean_t, std_t) + benchmark_util.log_benchmark('PaddedScatter', arguments, mean_t, std_t) @parameterized.parameters(*_PERMUTE_TESTS) def testCopy(self, sl, hs, ne): # NOTE: Capacity factor == 1. - ec = sl // ne + # ec = sl // ne # Create the data and indices. x = torch.randn((sl, hs)).cuda().half() y = x.clone() - benchmark = lambda: y.copy_(x) + def benchmark(): + return y.copy_(x) + mean_t, std_t = benchmark_util.benchmark_function(benchmark) arguments = { - "sequence_length": sl, - "hidden_size": hs, - "num_experts": ne + 'sequence_length': sl, + 'hidden_size': hs, + 'num_experts': ne, } - benchmark_util.log_benchmark("Copy", arguments, mean_t, std_t) + benchmark_util.log_benchmark('Copy', arguments, mean_t, std_t) if __name__ == '__main__': diff --git a/megablocks/ops/repeat.py b/megablocks/ops/repeat.py index d02c9566..7e9e09de 100644 --- a/megablocks/ops/repeat.py +++ b/megablocks/ops/repeat.py @@ -1,7 +1,10 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch -def repeat(x, tiling): - if all([t == 1 for t in tiling]): +def repeat(x: torch.Tensor, tiling: torch.Size): + if all((t == 1 for t in tiling)): return x return x.repeat(*tiling) diff --git a/megablocks/ops/replicate.py b/megablocks/ops/replicate.py index 4d0cf344..2dbec35c 100644 --- a/megablocks/ops/replicate.py +++ b/megablocks/ops/replicate.py @@ -1,32 +1,36 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + # Autograd wrapper for replicate kernel. class ReplicateOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, bins, num_outputs): + def forward(ctx: Any, x: torch.Tensor, bins: torch.Tensor, num_outputs: int): ctx.save_for_backward(bins) - out = torch.empty( - (x.shape[0], num_outputs), - dtype=x.dtype, - device=x.device) + out = torch.empty((x.shape[0], num_outputs), dtype=x.dtype, device=x.device) ops.replicate_forward(x, bins, out) return out @staticmethod - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): bins, = ctx.saved_tensors - out = torch.empty( - (grad.shape[0], bins.shape[0]), - dtype=grad.dtype, - device=grad.device) + out = torch.empty((grad.shape[0], bins.shape[0]), dtype=grad.dtype, device=grad.device) ops.replicate_backward(grad, bins, out) return out, None, None + + replicate = ReplicateOp.apply diff --git a/megablocks/ops/round_up.py b/megablocks/ops/round_up.py index fc81d61f..6cf6bc87 100644 --- a/megablocks/ops/round_up.py +++ b/megablocks/ops/round_up.py @@ -1,11 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import torch -def round_up(x, value): +def round_up(x: torch.Tensor, value: int): assert isinstance(value, int) assert x.dtype == torch.int32 # TODO(tgale): If this becomes and issue # do this in a custom kernel. We only expect # to use this on arrays of less than 1k elements. - return torch.div(x + (value - 1), value, rounding_mode="trunc") * value + return torch.div(x + (value - 1), value, rounding_mode='trunc') * value diff --git a/megablocks/ops/scatter.py b/megablocks/ops/scatter.py index 0e91d80e..a5aaafc4 100644 --- a/megablocks/ops/scatter.py +++ b/megablocks/ops/scatter.py @@ -1,6 +1,12 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional + import torch +from stk.backend.autocast import custom_bwd, custom_fwd + from megablocks.backend import kernels -from stk.backend.autocast import custom_fwd, custom_bwd # Autograd wrapper for scatter kernel. @@ -8,17 +14,24 @@ class ScatterOp(torch.autograd.Function): @staticmethod @custom_fwd - def forward(ctx, x, indices, bin_ids, weights, bins, top_k): + def forward( + ctx: Any, + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ) -> torch.Tensor: maybe_x = [x] if ctx.needs_input_grad[3] else [] ctx.save_for_backward(indices, bin_ids, weights, bins, *maybe_x) ctx.top_k = top_k ctx.x_shape = x.shape - return kernels.scatter( - x, indices, bin_ids, weights, bins, top_k) + return kernels.scatter(x, indices, bin_ids, weights, bins, top_k) @staticmethod @custom_bwd - def backward(ctx, grad): + def backward(ctx: Any, grad: torch.Tensor): grad = grad.contiguous() saved_tensors = ctx.saved_tensors @@ -31,7 +44,8 @@ def backward(ctx, grad): bin_ids, weights, bins, - ctx.top_k) + ctx.top_k, + ) wgrad = None if ctx.needs_input_grad[3]: # need wgrad @@ -42,14 +56,17 @@ def backward(ctx, grad): indices, bin_ids, bins, - ctx.top_k) + ctx.top_k, + ) return dgrad, None, None, wgrad, None, None, None -def scatter(x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: torch.Tensor, - bins: torch.Tensor, - top_k: int): +def scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, +) -> Optional[torch.Tensor]: return ScatterOp.apply(x, indices, bin_ids, weights, bins, top_k) diff --git a/megablocks/ops/sort.py b/megablocks/ops/sort.py index a4bb99f5..4fb0aab4 100644 --- a/megablocks/ops/sort.py +++ b/megablocks/ops/sort.py @@ -1,12 +1,18 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any, Optional, Tuple + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops - +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e _BITS_FOR_DTYPE = { torch.int16: 16, @@ -14,17 +20,19 @@ torch.int64: 64, } + # Autograd wrapper for sort kernel. -# # NOTE: Does not support gradients. class SortOp(torch.autograd.Function): @staticmethod - def forward(ctx, x, end_bit=None): + def forward(ctx: Any, x: torch.Tensor, end_bit: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: if end_bit is None: end_bit = _BITS_FOR_DTYPE[x.dtype] x_out = torch.empty_like(x) iota_out = torch.empty_like(x) ops.sort(x, end_bit, x_out, iota_out) return (x_out, iota_out) + + sort = SortOp.apply diff --git a/megablocks/ops/sort_benchmark.py b/megablocks/ops/sort_benchmark.py index 4305767d..f28e3f2f 100644 --- a/megablocks/ops/sort_benchmark.py +++ b/megablocks/ops/sort_benchmark.py @@ -1,10 +1,13 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import unittest -from absl.testing import parameterized -from megablocks import ops import numpy as np import torch +from absl.testing import parameterized +from megablocks import ops _SORT_TESTS = ( (16384, torch.int32, None), @@ -12,16 +15,14 @@ (16384, torch.int32, 128), ) -_BASELINE_SORT_TESTS = ( - (16384,), -) +_BASELINE_SORT_TESTS = ((16384,),) def numpy_dtype(dtype): types = { torch.int16: np.int16, torch.int32: np.int32, - torch.int64: np.int64 + torch.int64: np.int64, } return types[dtype] @@ -43,13 +44,13 @@ def benchmark_function(fn, iterations=10): def log_benchmark(arguments, mean_t, std_t): - print("="*60) - print("Benchmark Parameters:") + print('=' * 60) + print('Benchmark Parameters:') for (key, value) in arguments.items(): - print(f"{key} = {value}") - print("Results:") - print("mean / std = {:.2f}ms / {:.2f}ms".format(mean_t, std_t)) - print("="*60) + print(f'{key} = {value}') + print('Results:') + print('mean / std = {:.2f}ms / {:.2f}ms'.format(mean_t, std_t)) + print('=' * 60) class SortBenchmark(parameterized.TestCase): @@ -61,12 +62,11 @@ def testSort(self, n, dtype, max_val): end_bit = int(np.ceil(np.log2(max_val))) x = torch.randint(0, max_val, (n,)).cuda().to(dtype) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: ops.sort(x, end_bit)) + mean_t, std_t, max_t, min_t = benchmark_function(lambda: ops.sort(x, end_bit),) arguments = { - "n": n, - "dtype": dtype, - "max_val": max_val + 'n': n, + 'dtype': dtype, + 'max_val': max_val, } log_benchmark(arguments, mean_t, std_t) @@ -74,9 +74,10 @@ def testSort(self, n, dtype, max_val): def testTorchSort(self, n): x = torch.randint(0, 128, (n,)).cuda().to(torch.int32) - mean_t, std_t, max_t, min_t = benchmark_function( - lambda: torch.sort(x)) - arguments = {"n": n,} + mean_t, std_t, max_t, min_t = benchmark_function(lambda: torch.sort(x)) + arguments = { + 'n': n, + } log_benchmark(arguments, mean_t, std_t) diff --git a/megablocks/ops/sum.py b/megablocks/ops/sum.py index 9d550b55..e00c1aa6 100644 --- a/megablocks/ops/sum.py +++ b/megablocks/ops/sum.py @@ -1,7 +1,9 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 import torch -def sum(x, dim=0): +def sum(x: torch.Tensor, dim: int = 0): if x.shape[dim] == 1: return x.squeeze(dim=dim) return x.sum(dim=dim) diff --git a/megablocks/ops/topology.py b/megablocks/ops/topology.py index 7ce31bc2..b41b5fa5 100644 --- a/megablocks/ops/topology.py +++ b/megablocks/ops/topology.py @@ -1,30 +1,45 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +from typing import Any + # NOTE: Torch needs to be imported before the custom # extensions. Otherwise libc10.so cannot be found. import torch -# TODO(tgale): Wrap this in a try-block with better -# error message and instructions for building the -# c++ operations. -import megablocks_ops as ops +# Wrap this in a try-block with better error message and +# instructions for building the c++ operations. +try: + import megablocks_ops as ops # type: ignore +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e + # Autograd wrapper for topology kernel. -# # NOTE: Does not support gradients. class TopologyOp(torch.autograd.Function): @staticmethod - def forward(ctx, - padded_bins, - block_size, - output_block_rows, - output_block_columns): - out = torch.empty(output_block_rows * output_block_columns, - dtype=torch.int16, - device=padded_bins.device) - ops.indices(padded_bins, - block_size, - output_block_rows, - output_block_columns, - out) + def forward( + ctx: Any, + padded_bins: torch.Tensor, + block_size: int, + output_block_rows: int, + output_block_columns: int, + ): + out = torch.empty( + output_block_rows * output_block_columns, + dtype=torch.int16, + device=padded_bins.device, + ) + ops.indices( + padded_bins, + block_size, + output_block_rows, + output_block_columns, + out, + ) return out + + topology = TopologyOp.apply diff --git a/megablocks/py.typed b/megablocks/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index b4b90ecf..ad4dc1bb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,9 @@ +# Copyright 2024 Databricks authors +# SPDX-License-Identifier: Apache-2.0 + # build requirements [build-system] -requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4"] +requires = ["setuptools < 70.0.0", "torch >= 2.4.0, < 2.4.1"] build-backend = "setuptools.build_meta" # Pytest @@ -44,3 +47,483 @@ concurrency = ["thread"] include = [ "megablocks/*" ] + + +# Ruff global +[tool.ruff] + +preview = true # enable preview features, see https://docs.astral.sh/ruff/preview/ + +exclude = [ + "build/**", + "docs/**", + "node_modules/**", +] + +# Ruff linter +[tool.ruff.lint] +select = [ + "C4", # flake8-comprehensions + # TODO port pydocstyle + # "D", # pydocstyle + "LOG", # flake8-logging + "PERF", # Perflint + "PLE", + "COM812", # missing-trailing-comma +] + +# iSort +[tool.isort] +multi_line_output = 0 +line_length = 120 +skip = ["env", "wandb", "runs", "build", "node_modules" ] +include_trailing_comma = true +split_on_trailing_comma = true + +# Yapf +[tool.yapf] + +# Align closing bracket with visual indentation. +align_closing_bracket_with_visual_indent = false + +# Allow dictionary keys to exist on multiple lines. For example: +# +# x = { +# ('this is the first element of a tuple', +# 'this is the second element of a tuple'): +# value, +# } +allow_multiline_dictionary_keys = false + +# Allow lambdas to be formatted on more than one line. +allow_multiline_lambdas = false + +# Allow splitting before a default / named assignment in an argument list. +allow_split_before_default_or_named_assigns = true + +# Allow splits before the dictionary value. +allow_split_before_dict_value = true + +# Let spacing indicate operator precedence. For example: +# +# a = 1 * 2 + 3 / 4 +# b = 1 / 2 - 3 * 4 +# c = (1 + 2) * (3 - 4) +# d = (1 - 2) / (3 + 4) +# e = 1 * 2 - 3 +# f = 1 + 2 + 3 + 4 +# +# will be formatted as follows to indicate precedence: +# +# a = 1*2 + 3/4 +# b = 1/2 - 3*4 +# c = (1+2) * (3-4) +# d = (1-2) / (3+4) +# e = 1*2 - 3 +# f = 1 + 2 + 3 + 4 +# +arithmetic_precedence_indication = false + +# Number of blank lines surrounding top-level function and class +# definitions. +blank_lines_around_top_level_definition = 2 + +# Insert a blank line before a class-level docstring. +blank_line_before_class_docstring = false + +# Insert a blank line before a module docstring. +blank_line_before_module_docstring = true + +# Insert a blank line before a 'def' or 'class' immediately nested +# within another 'def' or 'class'. For example: +# +# class Foo: +# # <------ this blank line +# def method(): +# ... +blank_line_before_nested_class_or_def = true + +# Do not split consecutive brackets. Only relevant when +# dedent_closing_brackets is set. For example: +# +# call_func_that_takes_a_dict( +# { +# 'key1': 'value1', +# 'key2': 'value2', +# } +# ) +# +# would reformat to: +# +# call_func_that_takes_a_dict({ +# 'key1': 'value1', +# 'key2': 'value2', +# }) +coalesce_brackets = true + +# The column limit. +column_limit = 120 + +# The style for continuation alignment. Possible values are: +# +# - SPACE: Use spaces for continuation alignment. This is default behavior. +# - FIXED: Use fixed number (CONTINUATION_INDENT_WIDTH) of columns +# (ie: CONTINUATION_INDENT_WIDTH/INDENT_WIDTH tabs or +# CONTINUATION_INDENT_WIDTH spaces) for continuation alignment. +# - VALIGN-RIGHT: Vertically align continuation lines to multiple of +# INDENT_WIDTH columns. Slightly right (one tab or a few spaces) if +# cannot vertically align continuation lines with indent characters. +continuation_align_style = 'SPACE' + +# Indent width used for line continuations. +continuation_indent_width = 4 + +# Put closing brackets on a separate line, dedented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is dedented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is dedented and on a separate line +dedent_closing_brackets = true + +# Disable the heuristic which places each list element on a separate line +# if the list is comma-terminated. +disable_ending_comma_heuristic = false + +# Place each dictionary entry onto its own line. +each_dict_entry_on_separate_line = true + +# Require multiline dictionary even if it would normally fit on one line. +# For example: +# +# config = { +# 'key1': 'value1' +# } +force_multiline_dict = false + +# The regex for an i18n comment. The presence of this comment stops +# reformatting of that line, because the comments are required to be +# next to the string they translate. +i18n_comment = '#\..*' + +# The i18n function call names. The presence of this function stops +# reformattting on that line, because the string it has cannot be moved +# away from the i18n comment. +i18n_function_call = 'N_, _' + +# Indent blank lines. +indent_blank_lines = false + +# Put closing brackets on a separate line, indented, if the bracketed +# expression can't fit in a single line. Applies to all kinds of brackets, +# including function definitions and calls. For example: +# +# config = { +# 'key1': 'value1', +# 'key2': 'value2', +# } # <--- this bracket is indented and on a separate line +# +# time_series = self.remote_client.query_entity_counters( +# entity='dev3246.region1', +# key='dns.query_latency_tcp', +# transform=Transformation.AVERAGE(window=timedelta(seconds=60)), +# start_ts=now()-timedelta(days=3), +# end_ts=now(), +# ) # <--- this bracket is indented and on a separate line +indent_closing_brackets = false + +# Indent the dictionary value if it cannot fit on the same line as the +# dictionary key. For example: +# +# config = { +# 'key1': +# 'value1', +# 'key2': value1 + +# value2, +# } +indent_dictionary_value = true + +# The number of columns to use for indentation. +indent_width = 4 + +# Join short lines into one line. E.g., single line 'if' statements. +join_multiple_lines = false + +# Do not include spaces around selected binary operators. For example: +# +# 1 + 2 * 3 - 4 / 5 +# +# will be formatted as follows when configured with "*,/": +# +# 1 + 2*3 - 4/5 +no_spaces_around_selected_binary_operators = '' + +# Use spaces around default or named assigns. +spaces_around_default_or_named_assign = false + +# Adds a space after the opening '{' and before the ending '}' dict delimiters. +# +# {1: 2} +# +# will be formatted as: +# +# { 1: 2 } +spaces_around_dict_delimiters = false + +# Adds a space after the opening '[' and before the ending ']' list delimiters. +# +# [1, 2] +# +# will be formatted as: +# +# [ 1, 2 ] +spaces_around_list_delimiters = false + +# Use spaces around the power operator. +spaces_around_power_operator = false + +# Use spaces around the subscript / slice operator. For example: +# +# my_list[1 : 10 : 2] +spaces_around_subscript_colon = false + +# Adds a space after the opening '(' and before the ending ')' tuple delimiters. +# +# (1, 2, 3) +# +# will be formatted as: +# +# ( 1, 2, 3 ) +spaces_around_tuple_delimiters = false + +# The number of spaces required before a trailing comment. +# This can be a single value (representing the number of spaces +# before each trailing comment) or list of values (representing +# alignment column values; trailing comments within a block will +# be aligned to the first column value that is greater than the maximum +# line length within the block). For example: +# +# With spaces_before_comment=5: +# +# 1 + 1 # Adding values +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- 5 spaces between the end of the statement and comment +# +# With spaces_before_comment = '15, 20:' +# +# 1 + 1 # Adding values +# two + two # More adding +# +# longer_statement # This is a longer statement +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment +# short # This is a shorter statement +# +# will be formatted as: +# +# 1 + 1 # Adding values <-- end of line comments in block aligned to col 15 +# two + two # More adding +# +# longer_statement # This is a longer statement <-- end of line comments in block aligned to col 20 +# short # This is a shorter statement +# +# a_very_long_statement_that_extends_beyond_the_final_column # Comment <-- the end of line comments are aligned based on the line length +# short # This is a shorter statement +# +spaces_before_comment = 2 + +# Insert a space between the ending comma and closing bracket of a list, +# etc. +space_between_ending_comma_and_closing_bracket = false + +# Use spaces inside brackets, braces, and parentheses. For example: +# +# method_call( 1 ) +# my_dict[ 3 ][ 1 ][ get_index( *args, **kwargs ) ] +# my_set = { 1, 2, 3 } +space_inside_brackets = false + +# Split before arguments +split_all_comma_separated_values = false + +# Split before arguments, but do not split all subexpressions recursively +# (unless needed). +split_all_top_level_comma_separated_values = false + +# Split before arguments if the argument list is terminated by a +# comma. +split_arguments_when_comma_terminated = true + +# Set to True to prefer splitting before '+', '-', '*', '/', '//', or '@' +# rather than after. +split_before_arithmetic_operator = false + +# Set to True to prefer splitting before '&', '|' or '^' rather than +# after. +split_before_bitwise_operator = false + +# Split before the closing bracket if a list or dict literal doesn't fit on +# a single line. +split_before_closing_bracket = true + +# Split before a dictionary or set generator (comp_for). For example, note +# the split before the 'for': +# +# foo = { +# variable: 'Hello world, have a nice day!' +# for variable in bar if variable != 42 +# } +split_before_dict_set_generator = false + +# Split before the '.' if we need to split a longer expression: +# +# foo = ('This is a really long string: {}, {}, {}, {}'.format(a, b, c, d)) +# +# would reformat to something like: +# +# foo = ('This is a really long string: {}, {}, {}, {}' +# .format(a, b, c, d)) +split_before_dot = false + +# Split after the opening paren which surrounds an expression if it doesn't +# fit on a single line. +split_before_expression_after_opening_paren = false + +# If an argument / parameter list is going to be split, then split before +# the first argument. +split_before_first_argument = false + +# Set to True to prefer splitting before 'and' or 'or' rather than +# after. +split_before_logical_operator = false + +# Split named assignments onto individual lines. +split_before_named_assigns = true + +# Set to True to split list comprehensions and generators that have +# non-trivial expressions and multiple clauses before each of these +# clauses. For example: +# +# result = [ +# a_long_var + 100 for a_long_var in xrange(1000) +# if a_long_var % 10] +# +# would reformat to something like: +# +# result = [ +# a_long_var + 100 +# for a_long_var in xrange(1000) +# if a_long_var % 10] +split_complex_comprehension = true + +# The penalty for splitting right after the opening bracket. +split_penalty_after_opening_bracket = 300 + +# The penalty for splitting the line after a unary operator. +split_penalty_after_unary_operator = 10000 + +# The penalty of splitting the line around the '+', '-', '*', '/', '//', +# ``%``, and '@' operators. +split_penalty_arithmetic_operator = 300 + +# The penalty for splitting right before an if expression. +split_penalty_before_if_expr = 0 + +# The penalty of splitting the line around the '&', '|', and '^' +# operators. +split_penalty_bitwise_operator = 300 + +# The penalty for splitting a list comprehension or generator +# expression. +split_penalty_comprehension = 2100 + +# The penalty for characters over the column limit. +split_penalty_excess_character = 7000 + +# The penalty incurred by adding a line split to the unwrapped line. The +# more line splits added the higher the penalty. +split_penalty_for_added_line_split = 20 + +# The penalty of splitting a list of "import as" names. For example: +# +# from a_very_long_or_indented_module_name_yada_yad import (long_argument_1, +# long_argument_2, +# long_argument_3) +# +# would reformat to something like: +# +# from a_very_long_or_indented_module_name_yada_yad import ( +# long_argument_1, long_argument_2, long_argument_3) +split_penalty_import_names = 0 + +# The penalty of splitting the line around the 'and' and 'or' +# operators. +split_penalty_logical_operator = 300 + +# Use the Tab character for indentation. +use_tabs = false + +# Ignore directories +[tool.yapfignore] +ignore_patterns = [ + "runs/**/*.py", + "wandb/**/*.py", + "build/**/*.py", +] + +# PyDocStyle +[tool.pydocstyle] +convention="google" +add_ignore="D100,D101,D102,D103,D104,D105,D107,D400,D401,D415" +add_select="D404" + + +# Pyright +[tool.pyright] +exclude = ['env-**', 'venv*', '.venv', 'tests/*', '**benchmark'] +stubPath = "" # suppress useless 'stubPath is not a valid directory' errors + +reportUnnecessaryIsInstance = "none" # it is ok to do this for clarity or safety +reportMissingTypeStubs = "none" +reportIncompatibleMethodOverride = "none" +reportIncompatibleVariableOverride = "error" +reportUnusedImport = "error" +reportUnusedClass = "warning" +reportUnusedFunction = "warning" +reportUnusedVariable = "error" +reportDuplicateImport = "error" +reportWildcardImportFromLibrary = "error" +reportUntypedFunctionDecorator = "warning" +reportPrivateImportUsage = "none" +reportUndefinedVariable = "error" +strictParameterNoneValue = true +reportPropertyTypeMismatch = "error" +reportUntypedNamedTuple = "error" +reportUnnecessaryCast = "error" +reportInvalidTypeVarUse = "error" +reportOverlappingOverload = "error" +reportUninitializedInstanceVariable = "error" +reportInvalidStringEscapeSequence = "error" +reportMissingParameterType = "error" +reportCallInDefaultInitializer = "error" +reportUnnecessaryComparison = "error" +reportSelfClsParameterName = "error" +reportImplicitStringConcatenation = "warning" # TODO: make this an error +reportInvalidStubStatement = "error" +reportIncompleteStub = "error" +reportUnsupportedDunderAll = "error" +reportUnusedCoroutine = "error" +reportMissingImports = "none" diff --git a/setup.py b/setup.py index ac1b43fc..4b7be375 100644 --- a/setup.py +++ b/setup.py @@ -1,50 +1,80 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +"""MegaBlocks package setup.""" + import os +import warnings +from typing import Any, Dict, Mapping -import torch from setuptools import find_packages, setup -from torch.utils.cpp_extension import BuildExtension, CUDAExtension -if os.environ.get("TORCH_CUDA_ARCH_LIST"): - # Let PyTorch builder to choose device to target for. - device_capability = "" -else: - device_capability = torch.cuda.get_device_capability() - device_capability = f"{device_capability[0]}{device_capability[1]}" +# We require torch in setup.py to build cpp extensions "ahead of time" +# More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html +try: + import torch + from torch.utils.cpp_extension import CUDA_HOME, BuildExtension, CUDAExtension +except ModuleNotFoundError as e: + raise ModuleNotFoundError("No module named 'torch'. `torch` is required to install `MegaBlocks`.",) from e -nvcc_flags = [ - "--ptxas-options=-v", - "--optimize=2", -] -if device_capability: - nvcc_flags.append( - f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}" - ) +_PACKAGE_NAME = 'megablocks' +_PACKAGE_DIR = 'megablocks' +_REPO_REAL_PATH = os.path.dirname(os.path.realpath(__file__)) +_PACKAGE_REAL_PATH = os.path.join(_REPO_REAL_PATH, _PACKAGE_DIR) -ext_modules = [ - CUDAExtension( - "megablocks_ops", - ["csrc/ops.cu"], - include_dirs=["csrc"], - extra_compile_args={"cxx": ["-fopenmp"], "nvcc": nvcc_flags}, - ) +# Read the package version +# We can't use `.__version__` from the library since it's not installed yet +version_path = os.path.join(_PACKAGE_REAL_PATH, '_version.py') +with open(version_path, encoding='utf-8') as f: + version_globals: Dict[str, Any] = {} + version_locals: Mapping[str, object] = {} + content = f.read() + exec(content, version_globals, version_locals) + repo_version = version_locals['__version__'] + +with open('README.md', 'r', encoding='utf-8') as fh: + long_description = fh.read() + +# Hide the content between and +# tags in the README +while True: + start_tag = '' + end_tag = '' + start = long_description.find(start_tag) + end = long_description.find(end_tag) + if start == -1: + assert end == -1, 'there should be a balanced number of start and ends' + break + else: + assert end != -1, 'there should be a balanced number of start and ends' + long_description = long_description[:start] + \ + long_description[end + len(end_tag):] + +classifiers = [ + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.9', + 'Programming Language :: Python :: 3.10', + 'Programming Language :: Python :: 3.11', + 'License :: OSI Approved :: BSD License', + 'Operating System :: Unix', ] install_requires = [ 'numpy>=1.21.5,<2.1.0', - 'torch>=2.3.0,<2.4', - 'triton>=2.1.0', - 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301', 'packaging>=21.3.0,<24.2', + 'torch>=2.4.0,<2.4.1', + 'triton>=2.1.0', + 'stanford-stk==0.7.1', ] extra_deps = {} -extra_deps["gg"] = [ - 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb', +extra_deps['gg'] = [ + 'grouped_gemm==0.1.6', ] extra_deps['dev'] = [ - 'absl-py', # todo: delete when finish removing all absl tests + 'absl-py', # TODO: delete when finish removing all absl tests 'coverage[toml]==7.4.4', 'pytest_codeblocks>=0.16.1,<0.17', 'pytest-cov>=4,<5', @@ -53,31 +83,65 @@ ] extra_deps['testing'] = [ - 'mosaicml>=0.22.0', + 'mosaicml>=0.24.1', ] -extra_deps['all'] = list({ - dep for key, deps in extra_deps.items() for dep in deps - if key not in {'testing'} -}) +extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) + +cmdclass = {} +ext_modules = [] + +# Only install CUDA extensions if available +if 'cu' in torch.__version__ and CUDA_HOME is not None: + + cmdclass = {'build_ext': BuildExtension} + nvcc_flags = ['--ptxas-options=-v', '--optimize=2'] + + if os.environ.get('TORCH_CUDA_ARCH_LIST'): + # Let PyTorch builder to choose device to target for. + device_capability = '' + else: + device_capability_tuple = torch.cuda.get_device_capability() + device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}' + + if device_capability: + nvcc_flags.append(f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}',) + + ext_modules = [ + CUDAExtension( + 'megablocks_ops', + ['csrc/ops.cu'], + include_dirs=['csrc'], + extra_compile_args={ + 'cxx': ['-fopenmp'], + 'nvcc': nvcc_flags, + }, + ), + ] +elif CUDA_HOME is None: + warnings.warn( + 'Attempted to install CUDA extensions, but CUDA_HOME was None. ' + + 'Please install CUDA and ensure that the CUDA_HOME environment ' + + 'variable points to the installation location.', + ) +else: + warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.') setup( - name="megablocks", - version="0.5.1", - author="Trevor Gale", - author_email="tgale@stanford.edu", - description="MegaBlocks", - long_description=open('README.md').read(), + name=_PACKAGE_NAME, + version=repo_version, + author='Trevor Gale', + author_email='tgale@stanford.edu', + description='MegaBlocks', + long_description=long_description, long_description_content_type='text/markdown', - url="https://github.com/stanford-futuredata/megablocks", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - ], - packages=find_packages(), + url='https://github.com/databricks/megablocks', + classifiers=classifiers, + packages=find_packages(exclude=['tests*', 'third_party*', 'yamls*', 'exp*', '.github*']), ext_modules=ext_modules, - cmdclass={"build_ext": BuildExtension}, + cmdclass=cmdclass, install_requires=install_requires, extras_require=extra_deps, + python_requires='>=3.9', + package_data={_PACKAGE_NAME: ['py.typed']}, ) diff --git a/tests/conftest.py b/tests/conftest.py index 335140ce..663bda39 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import os from typing import List, Optional @@ -13,7 +16,7 @@ # Add the path of any pytest fixture files you want to make global pytest_plugins = [ 'tests.fixtures.autouse', - 'tests.fixtures.fixtures' + 'tests.fixtures.fixtures', ] @@ -23,8 +26,11 @@ def _get_world_size(item: pytest.Item): return item.get_closest_marker('world_size', default=_default).args[0] - -def _get_option(config: pytest.Config, name: str, default: Optional[str] = None) -> str: # type: ignore +def _get_option( + config: pytest.Config, + name: str, + default: Optional[str] = None, +) -> str: # type: ignore val = config.getoption(name) if val is not None: assert isinstance(val, str) @@ -34,13 +40,18 @@ def _get_option(config: pytest.Config, name: str, default: Optional[str] = None) val = None if val is None: if default is None: - pytest.fail(f'Config option {name} is not specified but is required') + pytest.fail(f'Config option {name} is not specified but is required',) val = default assert isinstance(val, str) return val -def _add_option(parser: pytest.Parser, name: str, help: str, choices: Optional[list[str]] = None): +def _add_option( + parser: pytest.Parser, + name: str, + help: str, + choices: Optional[list[str]] = None, +): parser.addoption( f'--{name}', default=None, diff --git a/tests/fixtures/autouse.py b/tests/fixtures/autouse.py index 29fbdeb8..6805f3c1 100644 --- a/tests/fixtures/autouse.py +++ b/tests/fixtures/autouse.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import gc import logging import os @@ -20,7 +23,7 @@ def clear_cuda_cache(request: pytest.FixtureRequest): @pytest.fixture(autouse=True) def reset_mlflow_tracking_dir(): - """Reset MLFlow tracking dir so it doesn't persist across tests""" + """Reset MLFlow tracking dir so it doesn't persist across tests.""" try: import mlflow mlflow.set_tracking_uri(None) # type: ignore @@ -72,9 +75,16 @@ def set_log_levels(): @pytest.fixture(autouse=True) def seed_all(rank_zero_seed: int, monkeypatch: pytest.MonkeyPatch): - """Monkeypatch reproducibility get_random_seed to always return the rank zero seed, and set the random seed before - each test to the rank local seed.""" - monkeypatch.setattr(reproducibility, 'get_random_seed', lambda: rank_zero_seed) + """Monkeypatch reproducibility. + + Make get_random_seed to always return the rank zero seed, and set the random seed before each test to the rank local + seed. + """ + monkeypatch.setattr( + reproducibility, + 'get_random_seed', + lambda: rank_zero_seed, + ) reproducibility.seed_all(rank_zero_seed + dist.get_global_rank()) diff --git a/tests/fixtures/fixtures.py b/tests/fixtures/fixtures.py index 48645a89..4039db70 100644 --- a/tests/fixtures/fixtures.py +++ b/tests/fixtures/fixtures.py @@ -1,4 +1,8 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import pytest + from tests.conftest import _get_option diff --git a/tests/layers/architectures.py b/tests/layers/architectures.py new file mode 100644 index 00000000..da1c5950 --- /dev/null +++ b/tests/layers/architectures.py @@ -0,0 +1,53 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + +import torch +import torch.nn.functional as F + +from megablocks.layers.arguments import Arguments + + +class FFN(torch.nn.Module): + + def __init__(self, args: Arguments): + super().__init__() + self.w1 = torch.nn.Parameter( + torch.empty( + args.hidden_size, + args.ffn_hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) + self.w2 = torch.nn.Parameter( + torch.empty( + args.ffn_hidden_size, + args.hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) + + def forward(self, x): + return torch.matmul( + F.gelu(torch.matmul(x, self.w1), approximate='tanh'), + self.w2, + ) + + +class GLU(FFN): + + def __init__(self, args: Arguments): + super().__init__(args) + self.v1 = torch.nn.Parameter( + torch.empty( + args.hidden_size, + args.ffn_hidden_size, + device=args.device, + dtype=torch.float16 if args.fp16 else torch.float32, + ), + ) + + def forward(self, x): + x1 = F.gelu(torch.matmul(x, self.w1), approximate='tanh') * torch.matmul(x, self.v1) + return torch.matmul(x1, self.w2) diff --git a/tests/layers/dmoe_test.py b/tests/layers/dmoe_test.py index 3ead862a..3d6565c8 100644 --- a/tests/layers/dmoe_test.py +++ b/tests/layers/dmoe_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from functools import partial @@ -7,8 +7,10 @@ import torch from megablocks import grouped_gemm_util as gg -from megablocks.layers import dmoe, moe, testing from megablocks.layers.arguments import Arguments +from megablocks.layers.dmoe import dMoE +from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss +from tests.layers.architectures import FFN # min size: (1, 2, 128, 2, 1) _FORWARD_TESTS_DEFAULT = ( @@ -28,13 +30,10 @@ (16, 1024, 128, 1, 1), ) -_FORWARD_TESTS_GROUPED_MLP = tuple([ - p + ('grouped',) for p in _FORWARD_TESTS_DEFAULT -]) if gg.grouped_gemm_is_available() else () +_FORWARD_TESTS_GROUPED_MLP = tuple([p + ('grouped',) for p in _FORWARD_TESTS_DEFAULT + ],) if gg.grouped_gemm_is_available() else () -_FORWARD_TESTS_SPARSE_MLP = tuple([ - p + ('sparse',) for p in _FORWARD_TESTS_DEFAULT -]) +_FORWARD_TESTS_SPARSE_MLP = tuple([p + ('sparse',) for p in _FORWARD_TESTS_DEFAULT]) _FORWARD_TESTS = (_FORWARD_TESTS_SPARSE_MLP + _FORWARD_TESTS_GROUPED_MLP) @@ -44,28 +43,32 @@ ) -def construct_moes(hidden_size: int, - ffn_hidden_size: int, - moe_num_experts: int = 1, - moe_capacity_factor: int = 1, - moe_top_k: int = 1, - mlp_impl: str = 'sparse'): +def construct_moes( + hidden_size: int, + ffn_hidden_size: int, + moe_num_experts: int = 1, + moe_capacity_factor: int = 1, + moe_top_k: int = 1, + mlp_impl: str = 'sparse', +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) - args = Arguments(hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=moe_num_experts, - moe_capacity_factor=moe_capacity_factor, - moe_top_k=moe_top_k, - init_method=init_method, - memory_optimized_mlp=True, - mlp_type='mlp', - mlp_impl=mlp_impl, - fp16=False, - bf16=True) - - mlp = testing.FFN(args) - moe_mlp = moe.MoE(args) - dmoe_mlp = dmoe.dMoE(args) + args = Arguments( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + moe_num_experts=moe_num_experts, + moe_capacity_factor=moe_capacity_factor, + moe_top_k=moe_top_k, + init_method=init_method, + memory_optimized_mlp=True, + mlp_type='mlp', + mlp_impl=mlp_impl, + fp16=False, + bf16=True, + ) + + mlp = FFN(args) + moe_mlp = MoE(args) + dmoe_mlp = dMoE(args) mlp.cuda(torch.cuda.current_device()).to(torch.bfloat16) moe_mlp.cuda(torch.cuda.current_device()).to(torch.bfloat16) @@ -76,8 +79,7 @@ def construct_moes(hidden_size: int, ne, hs, fhs = moe_mlp.experts.mlp.w1.size() w1 = dmoe_mlp.experts.mlp.w1.view([ne, fhs, hs]) moe_mlp.experts.mlp.w1.copy_(torch.transpose(w1, 1, 2).contiguous()) - moe_mlp.experts.mlp.w2.copy_(dmoe_mlp.experts.mlp.w2.view([ne, fhs, - hs])) + moe_mlp.experts.mlp.w2.copy_(dmoe_mlp.experts.mlp.w2.view([ne, fhs, hs]),) moe_mlp.router.layer.weight.copy_(dmoe_mlp.router.layer.weight) if moe_num_experts == 1: mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze()) @@ -86,68 +88,78 @@ def construct_moes(hidden_size: int, @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) -def test_dmoe_forward(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) +def test_dmoe_forward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, _, _, layer = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=num_experts, - moe_top_k=top_k, - mlp_impl=mlp_impl) + _, _, _, layer = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=num_experts, + moe_top_k=top_k, + mlp_impl=mlp_impl, + ) out, _ = layer(x) assert out.shape == x.shape - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) -def test_dmoe_forward_backward(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) +def test_dmoe_forward_backward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() x.requires_grad_(True) - args, _, _, layer = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=num_experts, - moe_top_k=top_k, - mlp_impl=mlp_impl) + args, _, _, layer = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=num_experts, + moe_top_k=top_k, + mlp_impl=mlp_impl, + ) out, _ = layer(x) assert out.shape == x.shape - loss = out.sum() + moe.batched_load_balancing_loss(args) + loss = out.sum() + batched_load_balancing_loss(args) loss.backward() assert x.grad is not None layer.zero_grad(set_to_none=True) x.grad = None - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS) -def test_dmoe_forward_vs_baseline(bs: int, - sl: int, - hs: int, - mlp_impl: str = 'sparse'): +def test_dmoe_forward_vs_baseline( + bs: int, + sl: int, + hs: int, + mlp_impl: str = 'sparse', +): x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, mlp, _, dmoe_mlp = construct_moes(hidden_size=hs, - ffn_hidden_size=hs * 2, - moe_num_experts=1, - moe_capacity_factor=1, - moe_top_k=1, - mlp_impl=mlp_impl) + _, mlp, _, dmoe_mlp = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs * 2, + moe_num_experts=1, + moe_capacity_factor=1, + moe_top_k=1, + mlp_impl=mlp_impl, + ) expected_out = mlp(x) out, _ = dmoe_mlp(x) @@ -156,23 +168,26 @@ def test_dmoe_forward_vs_baseline(bs: int, @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), - _FORWARD_TESTS) -def test_dmoe_forward_vs_moe(bs: int, - sl: int, - hs: int, - num_experts: int, - top_k: int, - mlp_impl: str): +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k', 'mlp_impl'), _FORWARD_TESTS) +def test_dmoe_forward_vs_moe( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, + mlp_impl: str, +): torch.manual_seed(42) x = torch.randn(sl, bs, hs).to(torch.bfloat16).cuda() - _, _, moe_mlp, dmoe_mlp = construct_moes(hidden_size=hs, - ffn_hidden_size=hs, - moe_num_experts=num_experts, - moe_capacity_factor=0, - mlp_impl=mlp_impl) + _, _, moe_mlp, dmoe_mlp = construct_moes( + hidden_size=hs, + ffn_hidden_size=hs, + moe_num_experts=num_experts, + moe_capacity_factor=0, + mlp_impl=mlp_impl, + ) expected_out, _ = moe_mlp(x) out, _ = dmoe_mlp(x) diff --git a/tests/layers/glu_test.py b/tests/layers/glu_test.py index 0487ec88..1e031ded 100644 --- a/tests/layers/glu_test.py +++ b/tests/layers/glu_test.py @@ -1,12 +1,15 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from functools import partial import pytest import stk import torch -from megablocks.layers import dmlp_registry, testing +from megablocks.layers import dmlp_registry from megablocks.layers.arguments import Arguments -from megablocks.layers.glu import GroupedGLU, SparseGLU +from tests.layers.architectures import GLU _DENSE_TESTS = ( (16, 1024, 512), @@ -15,10 +18,11 @@ def construct_dmoe_glu( - hidden_size: int, - ffn_hidden_size: int, - mlp_impl: str ='sparse', - memory_optimized_mlp: bool =False): + hidden_size: int, + ffn_hidden_size: int, + mlp_impl: str = 'sparse', + memory_optimized_mlp: bool = False, +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) args = Arguments( hidden_size=hidden_size, @@ -30,9 +34,10 @@ def construct_dmoe_glu( mlp_type='glu', mlp_impl=mlp_impl, fp16=False, - bf16=True) + bf16=True, + ) - glu = testing.GLU(args) + glu = GLU(args) dmoe_glu = dmlp_registry.get(args) dmoe_glu.cuda(torch.cuda.current_device()).to(torch.bfloat16) @@ -46,7 +51,6 @@ def construct_dmoe_glu( return args, glu, dmoe_glu - @pytest.mark.gpu @pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS) def test_glu_forward_grouped_mlp(bs: int, sl: int, hs: int): @@ -55,7 +59,8 @@ def test_glu_forward_grouped_mlp(bs: int, sl: int, hs: int): _, glu, dmoe_glu = construct_dmoe_glu( hidden_size=hs, ffn_hidden_size=hs * 2, - mlp_impl='grouped') + mlp_impl='grouped', + ) expected_out = glu(x) tokens_per_expert = torch.tensor([bs * sl]).cuda() @@ -75,7 +80,8 @@ def test_glu_forward_grouped_mlp_mem_opt(bs: int, sl: int, hs: int): hidden_size=hs, ffn_hidden_size=hs * 2, mlp_impl='grouped', - memory_optimized_mlp=True) + memory_optimized_mlp=True, + ) expected_out = glu(x) tokens_per_expert = torch.tensor([bs * sl]).cuda() @@ -94,7 +100,8 @@ def test_glu_forward_sparse_mlp(bs: int, sl: int, hs: int): _, glu, dmoe_glu = construct_dmoe_glu( hidden_size=hs, ffn_hidden_size=hs * 2, - mlp_impl='sparse') + mlp_impl='sparse', + ) expected_out = glu(x) with torch.no_grad(): diff --git a/tests/layers/moe_test.py b/tests/layers/moe_test.py index 75ea196f..ffd32cbf 100644 --- a/tests/layers/moe_test.py +++ b/tests/layers/moe_test.py @@ -1,10 +1,14 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + from functools import partial import pytest import torch -from megablocks.layers import moe, testing from megablocks.layers.arguments import Arguments +from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss +from tests.layers.architectures import FFN _FORWARD_TESTS = ( (16, 1024, 512, 1, 1), @@ -22,7 +26,6 @@ (16, 1024, 512, 8, 8), ) - _DENSE_TESTS = ( (16, 1024, 512), (8, 2048, 512), @@ -30,11 +33,12 @@ def construct_moe( - hidden_size, - ffn_hidden_size, - moe_num_experts=1, - moe_capacity_factor=1, - moe_top_k=1): + hidden_size, + ffn_hidden_size, + moe_num_experts=1, + moe_capacity_factor=1, + moe_top_k=1, +): init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1) args = Arguments( hidden_size=hidden_size, @@ -42,10 +46,11 @@ def construct_moe( moe_num_experts=moe_num_experts, moe_capacity_factor=moe_capacity_factor, moe_top_k=moe_top_k, - init_method=init_method) + init_method=init_method, + ) - mlp = testing.FFN(args) - moe_mlp = moe.MoE(args) + mlp = FFN(args) + moe_mlp = MoE(args) mlp.cuda(torch.cuda.current_device()).half() moe_mlp.cuda(torch.cuda.current_device()).half() @@ -59,8 +64,7 @@ def construct_moe( @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), - _FORWARD_TESTS) +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS) def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): x = torch.randn(sl, bs, hs).half().cuda() @@ -68,16 +72,23 @@ def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): hidden_size=hs, ffn_hidden_size=hs * 2, moe_num_experts=num_experts, - moe_top_k=top_k) + moe_top_k=top_k, + ) out, _ = layer(x) assert out.shape == x.shape - moe.clear_load_balancing_loss() + clear_load_balancing_loss() + @pytest.mark.gpu -@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), - _FORWARD_TESTS) -def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k: int): +@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS) +def test_moe_forward_backward( + bs: int, + sl: int, + hs: int, + num_experts: int, + top_k: int, +): x = torch.randn(sl, bs, hs).half().cuda() x.requires_grad_(True) @@ -85,16 +96,17 @@ def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k hidden_size=hs, ffn_hidden_size=hs * 2, moe_num_experts=num_experts, - moe_top_k=top_k) + moe_top_k=top_k, + ) out, _ = layer(x) assert out.shape == x.shape - loss = out.sum() + moe.batched_load_balancing_loss(args) + loss = out.sum() + batched_load_balancing_loss(args) loss.backward() layer.zero_grad(set_to_none=True) x.grad = None - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu @@ -102,15 +114,13 @@ def test_moe_forward_backward(bs: int, sl: int, hs: int, num_experts: int, top_k def test_moe_forward_vs_dense(bs: int, sl: int, hs: int): x = torch.randn(sl, bs, hs).half().cuda() - _, mlp, moe_mlp = construct_moe( - hidden_size=hs, - ffn_hidden_size=hs * 2) + _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2) expected_out = mlp(x) out, _ = moe_mlp(x) assert out.shape == x.shape == expected_out.shape assert torch.allclose(out, expected_out) - moe.clear_load_balancing_loss() + clear_load_balancing_loss() @pytest.mark.gpu @@ -119,9 +129,7 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): x = torch.randn(sl, bs, hs).half().cuda() x.requires_grad_(True) - _, mlp, moe_mlp = construct_moe( - hidden_size=hs, - ffn_hidden_size=hs * 2) + _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2) out, _ = moe_mlp(x) loss = out.sum() @@ -130,7 +138,7 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze() moe_mlp.zero_grad(set_to_none=True) x.grad = None - moe.clear_load_balancing_loss() + clear_load_balancing_loss() expected_out = mlp(x) expected_loss = expected_out.sum() @@ -141,8 +149,8 @@ def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int): x.grad = None # Verify the gradients match. - assert w1_grad.shape == expected_w1_grad.shape + assert w1_grad.shape == expected_w1_grad.shape assert w2_grad.shape == expected_w2_grad.shape assert torch.allclose(w1_grad, expected_w1_grad) assert torch.allclose(w2_grad, expected_w2_grad) - moe.clear_load_balancing_loss() + clear_load_balancing_loss() diff --git a/tests/layers/parallelism_test.py b/tests/layers/parallelism_test.py deleted file mode 100644 index 0aa42694..00000000 --- a/tests/layers/parallelism_test.py +++ /dev/null @@ -1,134 +0,0 @@ -import functools - -import numpy as np -import pytest -import torch - -from megablocks.layers import arguments, dmoe, mpu - -_PARALLELISM_TESTS = ( - (64, 1024, 512, 2048, 64, 1, False), - (64, 1024, 512, 2048, 64, 1, True), - # Test with fewer experts than ranks to verify tensor - # sharding in tandem with expert sharding. - (4, 1, 512, 2048, 4, 1, False), - (4, 1, 512, 2048, 4, 1, True), -) - -# Todo: Fix this long term -@pytest.fixture -def group(): - return None - - -@pytest.mark.world_size(2) -@pytest.mark.gpu -@pytest.mark.parametrize(('batch_size', 'sequence_length', 'hidden_size', 'ffn_hidden_size', 'num_experts', 'top_k', 'memory_optimized'), - _PARALLELISM_TESTS) -def test_expert_parallel_versus_weight_parallel( - group, - batch_size: int, - sequence_length: int, - hidden_size: int, - ffn_hidden_size: int, - num_experts: int, - top_k: int, - memory_optimized: bool): - - init_fn = functools.partial(torch.nn.init.normal_, mean=0.0, std=0.1) - ep_args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_expert_model_parallelism=True, - expert_parallel_group=group, - fp16=False, - bf16=False, - device=torch.cuda.current_device(), - init_method=init_fn, - memory_optimized_mlp=memory_optimized) - wp_args = arguments.Arguments( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - moe_num_experts=num_experts, - moe_top_k=top_k, - moe_weight_parallelism=True, - weight_parallel_group=group, - fp16=False, - bf16=False, - device=torch.cuda.current_device(), - init_method=init_fn, - memory_optimized_mlp=memory_optimized) - - # NOTE: Reset the seed so that the models get identical weights. - torch.manual_seed(1234) - ep = dmoe.dMoE(ep_args) - torch.manual_seed(1234) - wp = dmoe.dMoE(wp_args) - - # NOTE: Include the rank in the seed so we get different data per rank. - rank = torch.distributed.get_rank(group) - torch.manual_seed(1234 * rank) - x = torch.randn( - (batch_size, sequence_length, hidden_size), - device=torch.cuda.current_device(), - dtype=torch.float32).requires_grad_(True) - - # Test forward. - out, _ = wp(x) - expected_out, _ = ep(x) - - # Check the forward outputs. - for i in range(torch.distributed.get_world_size(group)): - torch.distributed.barrier(group) - if i == rank: - assert np.testing.assert_allclose( - out.detach().float().cpu(), - expected_out.detach().float().cpu(), - rtol=1e-4, atol=1e-4) is None - - # Test backward. - out.mean().backward() - expected_out.mean().backward() - - # NOTE: If tensor parallelism is used different weights can be on - # different ranks. Gather the full grads to rank 0 to compare. - def gather(x): - m, n = x.shape - world_size = torch.distributed.get_world_size(group) - out = torch.empty( - m * world_size, n, device=x.device, dtype=x.dtype) - torch.distributed.all_gather_into_tensor(out, x, group=group) - return out - - def permute(x): - esd = mpu.expert_sharding_degree(ep_args) - hsd = mpu.hidden_sharding_degree(ep_args) - out = x.view(hsd, esd, -1).transpose(1, 0).contiguous() - return out.view(num_experts * ffn_hidden_size, hidden_size) - - wp_w2_grad = gather(wp.experts.mlp.w2.grad) - ep_w2_grad = permute(gather(ep.experts.mlp.w2.grad)) - if rank == 0: - assert np.testing.assert_allclose( - wp_w2_grad.float().cpu(), - ep_w2_grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None - - wp_w1_grad = gather(wp.experts.mlp.w1.grad) - ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad)) - if rank == 0: - assert np.testing.assert_allclose( - wp_w1_grad.float().cpu(), - ep_w1_grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None - - # Verify the router weight gradient, which is not sharded. - for i in range(torch.distributed.get_world_size(group)): - torch.distributed.barrier(group) - if i == rank: - assert np.testing.assert_allclose( - wp.router.layer.weight.grad.float().cpu(), - ep.router.layer.weight.grad.float().cpu(), - rtol=1e-5, atol=1e-5) is None diff --git a/tests/ops/binned_gather_test.py b/tests/ops/binned_gather_test.py index cc59ae3c..c165086f 100644 --- a/tests/ops/binned_gather_test.py +++ b/tests/ops/binned_gather_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -46,8 +46,13 @@ def test_binned_gather(sl: int, hs: int, ne: int, top_k: int): _, indices = ops.sort(top_expert) bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0) - def binned_gather(x: torch.Tensor, indices: torch.Tensor, - bins: torch.Tensor, ec: int, top_k: int): + def binned_gather( + x: torch.Tensor, + indices: torch.Tensor, + bins: torch.Tensor, + ec: int, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() bins = bins.cpu().numpy() diff --git a/tests/ops/binned_scatter_test.py b/tests/ops/binned_scatter_test.py index 2d1c5856..b725700b 100644 --- a/tests/ops/binned_scatter_test.py +++ b/tests/ops/binned_scatter_test.py @@ -1,3 +1,6 @@ +# Copyright 2024 Databricks +# SPDX-License-Identifier: Apache-2.0 + import numpy as np import pytest import torch @@ -48,8 +51,13 @@ def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int): x = ops.binned_gather(x, indices, bins, ec, top_k) - def binned_scatter(x: torch.Tensor, indices: torch.Tensor, - weights: torch.Tensor, bins: torch.Tensor, top_k: int): + def binned_scatter( + x: torch.Tensor, + indices: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() weights = weights.cpu().numpy() @@ -66,10 +74,14 @@ def binned_scatter(x: torch.Tensor, indices: torch.Tensor, out[index, :] += scale * x[i, j, :] start = end return torch.from_numpy(out).cuda().half() + out = ops.binned_scatter(x, indices, weights, bins, top_k) expected_out = binned_scatter(x, indices, weights, bins, top_k) # NOTE: We need to check approximate equality because the # scatter reduce uses atomics. assert np.testing.assert_allclose( - out.cpu(), expected_out.cpu(), rtol=5e-3) is None + out.cpu(), + expected_out.cpu(), + rtol=5e-3, + ) is None diff --git a/tests/ops/cumsum_test.py b/tests/ops/cumsum_test.py index a1b71605..5d8b0824 100644 --- a/tests/ops/cumsum_test.py +++ b/tests/ops/cumsum_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import pytest diff --git a/tests/ops/histogram_test.py b/tests/ops/histogram_test.py index 25b30cb9..d6d3f238 100644 --- a/tests/ops/histogram_test.py +++ b/tests/ops/histogram_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import pytest @@ -78,6 +78,5 @@ def test_histogram(m: int, n: int, dtype: torch.dtype, max_val: int): x = torch.randint(0, max_val, (m, n)).cuda().to(dtype) out = ops.histogram(x, max_val) - expected_out = torch.stack( - [torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)]) + expected_out = torch.stack([torch.histc(y, max_val, 0, max_val - 1) for y in torch.split(x, 1)]) assert torch.all(torch.eq(out, expected_out)) diff --git a/tests/ops/padded_gather_test.py b/tests/ops/padded_gather_test.py index e6eb7f75..71980998 100644 --- a/tests/ops/padded_gather_test.py +++ b/tests/ops/padded_gather_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -63,9 +63,14 @@ def testPaddedGather(sl: int, hs: int, ne: int, top_k: int): padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0) bins = ops.inclusive_cumsum(tokens_per_expert, 0) - def padded_gather(x: torch.Tensor, indices: torch.Tensor, - bin_ids: torch.Tensor, bins: torch.Tensor, - padded_bins: torch.Tensor, top_k: int): + def padded_gather( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): x = x.cpu().numpy() indices = indices.cpu().numpy() bin_ids = bin_ids.cpu().numpy() diff --git a/tests/ops/padded_scatter_test.py b/tests/ops/padded_scatter_test.py index ebd04a86..0e80dbb5 100644 --- a/tests/ops/padded_scatter_test.py +++ b/tests/ops/padded_scatter_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -94,10 +94,15 @@ def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int): # Gather the data to prepare for backwards. x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k) - def padded_scatter(x: torch.Tensor, indices: torch.Tensor, - bin_ids: torch.Tensor, weights: torch.Tensor, - bins: torch.Tensor, padded_bins: torch.Tensor, - top_k: int): + def padded_scatter( + x: torch.Tensor, + indices: torch.Tensor, + bin_ids: torch.Tensor, + weights: torch.Tensor, + bins: torch.Tensor, + padded_bins: torch.Tensor, + top_k: int, + ): x = x.detach().cpu().numpy() indices: np.ndarray = _to_numpy(indices) bin_ids: np.ndarray = _to_numpy(bin_ids) @@ -120,10 +125,24 @@ def padded_scatter(x: torch.Tensor, indices: torch.Tensor, in_idx += 1 return torch.from_numpy(out).cuda().half() - out = ops.padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, - top_k) - expected_out = padded_scatter(x, indices, bin_ids, weights, bins, - padded_bins, top_k) + out = ops.padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) + expected_out = padded_scatter( + x, + indices, + bin_ids, + weights, + bins, + padded_bins, + top_k, + ) out.backward(torch.randn_like(out)) # sanity check backward pass diff --git a/tests/ops/replicate_test.py b/tests/ops/replicate_test.py index 94aeb67a..aeb1405e 100644 --- a/tests/ops/replicate_test.py +++ b/tests/ops/replicate_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np diff --git a/tests/ops/sort_test.py b/tests/ops/sort_test.py index e07f2e1a..147426e3 100644 --- a/tests/ops/sort_test.py +++ b/tests/ops/sort_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 from typing import Dict, Optional, Union @@ -31,12 +31,11 @@ ] -def torch_to_numpy_dtype( - dtype: torch.dtype) -> Union[np.int16, np.int32, np.int64]: +def torch_to_numpy_dtype(dtype: torch.dtype,) -> Union[np.int16, np.int32, np.int64]: types: Dict[torch.dtype, Union[np.int16, np.int32, np.int64]] = { torch.int16: np.int16, torch.int32: np.int32, - torch.int64: np.int64 + torch.int64: np.int64, } return types[dtype] diff --git a/tests/ops/topology_test.py b/tests/ops/topology_test.py index a7135be0..dc3c0aee 100644 --- a/tests/ops/topology_test.py +++ b/tests/ops/topology_test.py @@ -1,4 +1,4 @@ -# Copyright 2024 MosaicML MegaBlocks authors +# Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 import numpy as np @@ -48,8 +48,12 @@ def test_topology(sl: int, hs: int, ne: int): output_block_rows = int(padded_bins[-1]) // blocking output_block_columns = hs // blocking - def topology(padded_bins: torch.Tensor, blocking: torch.Tensor, rows: int, - columns: int): + def topology( + padded_bins: torch.Tensor, + blocking: torch.Tensor, + rows: int, + columns: int, + ): padded_bins = padded_bins.cpu().numpy() out = np.zeros([rows * columns]) @@ -62,8 +66,16 @@ def topology(padded_bins: torch.Tensor, blocking: torch.Tensor, rows: int, start += 1 return torch.from_numpy(out).cuda().short() - out = ops.topology(padded_bins, blocking, output_block_rows, - output_block_columns) - expected_out = topology(padded_bins, blocking, output_block_rows, - output_block_columns) + out = ops.topology( + padded_bins, + blocking, + output_block_rows, + output_block_columns, + ) + expected_out = topology( + padded_bins, + blocking, + output_block_rows, + output_block_columns, + ) assert torch.all(torch.eq(out, expected_out)) diff --git a/yamls/matmul_benchmark.yaml b/yamls/matmul_benchmark.yaml index de26f581..46a79d9c 100644 --- a/yamls/matmul_benchmark.yaml +++ b/yamls/matmul_benchmark.yaml @@ -4,11 +4,11 @@ cluster: r9z1 gpu_num: 8 gpu_type: h100_80gb integrations: - - integration_type: git_repo - git_repo: stanford-futuredata/megablocks - git_branch: main - pip_install: absl-py 'git+https://github.com/openai/triton.git@main#egg=triton&subdirectory=python' - ssh_clone: false +- integration_type: git_repo + git_repo: stanford-futuredata/megablocks + git_branch: main + pip_install: absl-py 'git+https://github.com/openai/triton.git@main#egg=triton&subdirectory=python' + ssh_clone: false command: |- cd megablocks export ENABLE_TMA=1 diff --git a/yamls/triton_benchmark.yaml b/yamls/triton_benchmark.yaml index 70c96266..fd309462 100644 --- a/yamls/triton_benchmark.yaml +++ b/yamls/triton_benchmark.yaml @@ -4,14 +4,14 @@ cluster: r9z1 gpu_num: 8 gpu_type: h100_80gb integrations: - - integration_type: git_repo - git_repo: openai/triton - git_branch: main - ssh_clone: false +- integration_type: git_repo + git_repo: openai/triton + git_branch: main + ssh_clone: false command: |- export ENABLE_TMA=1 export ENABLE_MMA_V3=1 - + cd triton/python pip install . --no-dependencies