From bf5be90d390988b0afc17acec8edf756ee968e1a Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 21 Aug 2024 12:21:51 -0400 Subject: [PATCH 1/7] bump torch to <2.5 (#142) --- .github/workflows/pr-gpu.yaml | 8 ++++---- pyproject.toml | 2 +- setup.py | 6 +++--- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/.github/workflows/pr-gpu.yaml b/.github/workflows/pr-gpu.yaml index 1ca8d5b..d94b057 100644 --- a/.github/workflows/pr-gpu.yaml +++ b/.github/workflows/pr-gpu.yaml @@ -21,14 +21,14 @@ jobs: fail-fast: false matrix: include: - - name: "python3.11-pytorch2.3.1-gpus1" + - name: "python3.11-pytorch2.4.0-gpus1" gpu_num: 1 python_version: 3.11 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 - - name: "python3.11-pytorch2.3.1-gpus2" + container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 + - name: "python3.11-pytorch2.4.0-gpus2" gpu_num: 2 python_version: 3.11 - container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04 + container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04 steps: - name: Run PR GPU tests uses: mosaicml/ci-testing/.github/actions/pytest-gpu@v0.1.2 diff --git a/pyproject.toml b/pyproject.toml index c72dbdf..fc8f3dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ # build requirements [build-system] -requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4"] +requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4.1"] build-backend = "setuptools.build_meta" # Pytest diff --git a/setup.py b/setup.py index fa15ee4..e39dd49 100644 --- a/setup.py +++ b/setup.py @@ -62,15 +62,15 @@ install_requires = [ 'numpy>=1.21.5,<2.1.0', 'packaging>=21.3.0,<24.2', - 'torch>=2.3.0,<2.4', + 'torch>=2.3.0,<2.4.1', 'triton>=2.1.0', - 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301', + 'stanford-stk @ git+https://git@github.com/eitanturok/stk.git@bump-version', ] extra_deps = {} extra_deps['gg'] = [ - 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb', + 'grouped_gemm @ git+https://git@github.com/eitanturok/grouped_gemm.git@bump-version', ] extra_deps['dev'] = [ From 31aefba7498e554a56358cd3b5c74f4e7d52d745 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 21 Aug 2024 14:11:17 -0400 Subject: [PATCH 2/7] bump torch to <2.5 (#143) --- pyproject.toml | 2 +- setup.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fc8f3dc..454868f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ # build requirements [build-system] -requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4.1"] +requires = ["setuptools < 70.0.0", "torch >= 2.4.0, < 2.4.1"] build-backend = "setuptools.build_meta" # Pytest diff --git a/setup.py b/setup.py index e39dd49..062a63a 100644 --- a/setup.py +++ b/setup.py @@ -62,7 +62,7 @@ install_requires = [ 'numpy>=1.21.5,<2.1.0', 'packaging>=21.3.0,<24.2', - 'torch>=2.3.0,<2.4.1', + 'torch>=2.4.0,<2.4.1', 'triton>=2.1.0', 'stanford-stk @ git+https://git@github.com/eitanturok/stk.git@bump-version', ] From 6bfbc42e03457eda7f3b46031fe09ff6afdcc791 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 21 Aug 2024 14:27:46 -0400 Subject: [PATCH 3/7] bump torch to <2.4.1 (#144) --- setup.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 062a63a..906adb0 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,8 @@ ] extra_deps['testing'] = [ - 'mosaicml>=0.22.0', + # 'mosaicml>=0.23.6', # uses when released + 'mosaicml @ git+https://github.com/mosaicml/composer.git@7c48cfc00ed5df553c947b336fee72437d2e68a7', ] extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) From c1c5f2222967a063f48a9676e621bf145d191374 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Wed, 21 Aug 2024 15:07:09 -0400 Subject: [PATCH 4/7] bump torch (#146) --- setup.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/setup.py b/setup.py index 906adb0..7596bd8 100644 --- a/setup.py +++ b/setup.py @@ -64,13 +64,13 @@ 'packaging>=21.3.0,<24.2', 'torch>=2.4.0,<2.4.1', 'triton>=2.1.0', - 'stanford-stk @ git+https://git@github.com/eitanturok/stk.git@bump-version', + 'stanford-stk>=0.7.1', ] extra_deps = {} extra_deps['gg'] = [ - 'grouped_gemm @ git+https://git@github.com/eitanturok/grouped_gemm.git@bump-version', + 'grouped_gemm>=0.1.6', ] extra_deps['dev'] = [ @@ -83,8 +83,7 @@ ] extra_deps['testing'] = [ - # 'mosaicml>=0.23.6', # uses when released - 'mosaicml @ git+https://github.com/mosaicml/composer.git@7c48cfc00ed5df553c947b336fee72437d2e68a7', + 'mosaicml>=0.23.6', ] extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) From be461a5ef6ba80ed4a62fb1ee9c866b51a9ca26a Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Mon, 26 Aug 2024 14:10:29 -0400 Subject: [PATCH 5/7] install from git, not pypi --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 7596bd8..714f7ae 100644 --- a/setup.py +++ b/setup.py @@ -64,13 +64,13 @@ 'packaging>=21.3.0,<24.2', 'torch>=2.4.0,<2.4.1', 'triton>=2.1.0', - 'stanford-stk>=0.7.1', + 'stanford-stk @ git+https://git@github.com/stanford-futuredata/stk.git@v0.7.1', ] extra_deps = {} extra_deps['gg'] = [ - 'grouped_gemm>=0.1.6', + 'grouped_gemm @ git+https://git@github.com/tgale96/grouped_gemm.git@v0.1.6', ] extra_deps['dev'] = [ From aeb6b4700086c8e522825f483e0093909c46ad65 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Thu, 29 Aug 2024 15:09:29 -0400 Subject: [PATCH 6/7] Update setup.py Co-authored-by: Saaketh Narayan --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 714f7ae..2ed3fcc 100644 --- a/setup.py +++ b/setup.py @@ -83,7 +83,7 @@ ] extra_deps['testing'] = [ - 'mosaicml>=0.23.6', + 'mosaicml>=0.24.1', ] extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) From 9b77d1626bc1146e05fb37536746442621ef10b0 Mon Sep 17 00:00:00 2001 From: Eitan Turok <150733043+eitanturok@users.noreply.github.com> Date: Fri, 30 Aug 2024 09:01:04 -0400 Subject: [PATCH 7/7] no type checking in `kernel.py` (#147) --- megablocks/backend/kernels.py | 144 ++++++++++------------------------ 1 file changed, 43 insertions(+), 101 deletions(-) diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py index ca0120b..b584cee 100644 --- a/megablocks/backend/kernels.py +++ b/megablocks/backend/kernels.py @@ -1,27 +1,26 @@ # Copyright 2024 Databricks # SPDX-License-Identifier: Apache-2.0 -from typing import Any, Optional import torch import triton import triton.language as tl -def assert_is_tensor(x: torch.Tensor, ndim: int): +def assert_is_tensor(x, ndim): if x.ndim != ndim: raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor') -def assert_is_matrix(x: torch.Tensor): +def assert_is_matrix(x): assert_is_tensor(x, 2) -def assert_is_vector(x: torch.Tensor): +def assert_is_vector(x): if x.ndim != 1: raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor') -def assert_equal(a: Any, b: Any): +def assert_equal(a, b): if a != b: raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',) @@ -44,13 +43,13 @@ def assert_equal(a: Any, b: Any): ) @triton.jit def _padded_copy( - a: torch.Tensor, - b: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Any, - bins: torch.Tensor, - padded_bins: torch.Tensor, + a, + b, + indices, + bin_ids, + weights, + bins, + padded_bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -105,15 +104,7 @@ def _padded_copy( offsets += BLOCK_X -def padded_gather( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): +def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -129,7 +120,7 @@ def padded_gather( # NOTE: Because of the padding, the output size is dynamic. # We load the final padded bin bound to get the output rows. - output_rows = int(padded_bins[-1].cpu().item()) + output_rows = padded_bins[-1].cpu().item() out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device) _padded_copy[(indices.shape[0],)]( x, @@ -147,14 +138,7 @@ def padded_gather( return out -def gather( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -): +def gather(x, indices, bin_ids, weights, bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -186,15 +170,7 @@ def gather( return out -def padded_scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -) -> torch.Tensor: +def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -227,14 +203,7 @@ def padded_scatter( return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1]) -def scatter( - x: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -) -> torch.Tensor: +def scatter(x, indices, bin_ids, weights, bins, top_k): return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k) @@ -257,13 +226,13 @@ def scatter( ) @triton.jit def _padded_copy_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - wgrad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, + x, + grad, + wgrad, + indices, + bin_ids, + bins, + padded_bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -307,15 +276,7 @@ def _padded_copy_wgrad( tl.store(wgrad, out) -def padded_scatter_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - padded_bins: torch.Tensor, - top_k: int, -): +def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_matrix(grad) @@ -342,14 +303,7 @@ def padded_scatter_wgrad( return out -def scatter_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - indices: torch.Tensor, - bin_ids: torch.Tensor, - bins: torch.Tensor, - top_k: int, -): +def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k): return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k) @@ -370,13 +324,13 @@ def scatter_wgrad( ) @triton.jit def _binned_copy( - a: torch.Tensor, - b: torch.Tensor, - num_experts: int, - expert_capacity: int, - indices: torch.Tensor, - weights, #: Optional[torch.Tensor], - bins: torch.Tensor, + a, + b, + num_experts, + expert_capacity, + indices, + weights, + bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -435,14 +389,7 @@ def _binned_copy( offsets += BLOCK_X -def binned_gather( - x: torch.Tensor, - indices: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - expert_capacity: int, - top_k: int, -): +def binned_gather(x, indices, weights, bins, expert_capacity, top_k): # Validate the input shapes. assert_is_matrix(x) assert_is_vector(indices) @@ -454,6 +401,7 @@ def binned_gather( num_experts = bins.shape[0] out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device) + _binned_copy[(num_experts, expert_capacity)]( x, out, @@ -470,13 +418,7 @@ def binned_gather( return out -def binned_scatter( - x: torch.Tensor, - indices: torch.Tensor, - weights: Optional[torch.Tensor], - bins: torch.Tensor, - top_k: int, -): +def binned_scatter(x, indices, weights, bins, top_k): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_vector(indices) @@ -524,13 +466,13 @@ def binned_scatter( ) @triton.jit def _binned_copy_wgrad( - x: torch.Tensor, - grad: torch.Tensor, - wgrad: torch.Tensor, - num_experts: int, - expert_capacity: int, - indices: torch.Tensor, - bins: torch.Tensor, + x, + grad, + wgrad, + num_experts, + expert_capacity, + indices, + bins, NUM_COLUMNS: tl.constexpr, TOP_K: tl.constexpr, BLOCK_X: tl.constexpr, @@ -576,7 +518,7 @@ def _binned_copy_wgrad( tl.store(wgrad, out) -def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int): +def binned_scatter_wgrad(x, grad, indices, bins, top_k): # Validate the input shapes. assert_is_tensor(x, 3) assert_is_matrix(grad)