Bump torch to <2.4.1 #145

Merged
8 commits merged on Aug 30, 2024
8 changes: 4 additions & 4 deletions .github/workflows/pr-gpu.yaml
@@ -21,14 +21,14 @@ jobs:
fail-fast: false
matrix:
include:
- name: "python3.11-pytorch2.3.1-gpus1"
- name: "python3.11-pytorch2.4.0-gpus1"
gpu_num: 1
python_version: 3.11
container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
- name: "python3.11-pytorch2.3.1-gpus2"
container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04
- name: "python3.11-pytorch2.4.0-gpus2"
gpu_num: 2
python_version: 3.11
container: mosaicml/pytorch:2.3.1_cu121-python3.11-ubuntu20.04
container: mosaicml/pytorch:2.4.0_cu124-python3.11-ubuntu20.04
steps:
- name: Run PR GPU tests
uses: mosaicml/ci-testing/.github/actions/[email protected]
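The updated CI images pin the toolchain the GPU tests run against. As a rough local sanity check (not part of this PR, and assuming the cu124 tag corresponds to a CUDA 12.4 build of torch 2.4.0), one might confirm a development environment matches the new containers:

# Sketch: confirm the local torch build matches the new CI images.
import torch

assert torch.__version__.startswith('2.4.0'), torch.__version__
assert torch.version.cuda == '12.4', torch.version.cuda  # cu124 image
print(f'torch {torch.__version__} / CUDA {torch.version.cuda}')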
144 changes: 43 additions & 101 deletions megablocks/backend/kernels.py
@@ -1,27 +1,26 @@
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
- from typing import Any, Optional

import torch
import triton
import triton.language as tl


- def assert_is_tensor(x: torch.Tensor, ndim: int):
+ def assert_is_tensor(x, ndim):
if x.ndim != ndim:
raise ValueError(f'Expected {ndim}-tensor but got {x.ndim}-tensor')


- def assert_is_matrix(x: torch.Tensor):
+ def assert_is_matrix(x):
assert_is_tensor(x, 2)


- def assert_is_vector(x: torch.Tensor):
+ def assert_is_vector(x):
if x.ndim != 1:
raise ValueError(f'Expected 1-tensor but got {x.ndim}-tensor')


- def assert_equal(a: Any, b: Any):
+ def assert_equal(a, b):
if a != b:
raise ValueError(f'Expected dimensions to be equal but got {a} and {b}.',)

@@ -44,13 +43,13 @@ def assert_equal(a: Any, b: Any):
)
@triton.jit
def _padded_copy(
- a: torch.Tensor,
- b: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Any,
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
+ a,
+ b,
+ indices,
+ bin_ids,
+ weights,
+ bins,
+ padded_bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -105,15 +104,7 @@ def _padded_copy(
offsets += BLOCK_X


- def padded_gather(
- x: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
- top_k: int,
- ):
+ def padded_gather(x, indices, bin_ids, weights, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -129,7 +120,7 @@ def padded_gather(

# NOTE: Because of the padding, the output size is dynamic.
# We load the final padded bin bound to get the output rows.
- output_rows = int(padded_bins[-1].cpu().item())
+ output_rows = padded_bins[-1].cpu().item()
out = torch.zeros((output_rows, x.shape[1]), dtype=x.dtype, device=x.device)
_padded_copy[(indices.shape[0],)](
x,
@@ -147,14 +138,7 @@ def padded_gather(
return out


- def gather(
- x: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- top_k: int,
- ):
+ def gather(x, indices, bin_ids, weights, bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -186,15 +170,7 @@ def gather(
return out


- def padded_scatter(
- x: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
- top_k: int,
- ) -> torch.Tensor:
+ def padded_scatter(x, indices, bin_ids, weights, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -227,14 +203,7 @@ def padded_scatter(
return out.sum(dim=1) if top_k > 1 else out.view(tokens, x.shape[1])


- def scatter(
- x: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- top_k: int,
- ) -> torch.Tensor:
+ def scatter(x, indices, bin_ids, weights, bins, top_k):
return padded_scatter(x, indices, bin_ids, weights, bins, bins, top_k)


@@ -257,13 +226,13 @@ def scatter(
)
@triton.jit
def _padded_copy_wgrad(
- x: torch.Tensor,
- grad: torch.Tensor,
- wgrad: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
+ x,
+ grad,
+ wgrad,
+ indices,
+ bin_ids,
+ bins,
+ padded_bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -307,15 +276,7 @@ def _padded_copy_wgrad(
tl.store(wgrad, out)


- def padded_scatter_wgrad(
- x: torch.Tensor,
- grad: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- bins: torch.Tensor,
- padded_bins: torch.Tensor,
- top_k: int,
- ):
+ def padded_scatter_wgrad(x, grad, indices, bin_ids, bins, padded_bins, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_matrix(grad)
@@ -342,14 +303,7 @@ def padded_scatter_wgrad(
return out


- def scatter_wgrad(
- x: torch.Tensor,
- grad: torch.Tensor,
- indices: torch.Tensor,
- bin_ids: torch.Tensor,
- bins: torch.Tensor,
- top_k: int,
- ):
+ def scatter_wgrad(x, grad, indices, bin_ids, bins, top_k):
return padded_scatter_wgrad(x, grad, indices, bin_ids, bins, bins, top_k)


@@ -370,13 +324,13 @@ def scatter_wgrad(
)
@triton.jit
def _binned_copy(
- a: torch.Tensor,
- b: torch.Tensor,
- num_experts: int,
- expert_capacity: int,
- indices: torch.Tensor,
- weights, #: Optional[torch.Tensor],
- bins: torch.Tensor,
+ a,
+ b,
+ num_experts,
+ expert_capacity,
+ indices,
+ weights,
+ bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -435,14 +389,7 @@ def _binned_copy(
offsets += BLOCK_X


- def binned_gather(
- x: torch.Tensor,
- indices: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- expert_capacity: int,
- top_k: int,
- ):
+ def binned_gather(x, indices, weights, bins, expert_capacity, top_k):
# Validate the input shapes.
assert_is_matrix(x)
assert_is_vector(indices)
@@ -454,6 +401,7 @@ def binned_gather(

num_experts = bins.shape[0]
out = torch.zeros((num_experts, expert_capacity, x.shape[1]), dtype=x.dtype, device=x.device)

_binned_copy[(num_experts, expert_capacity)](
x,
out,
@@ -470,13 +418,7 @@ def binned_gather(
return out


- def binned_scatter(
- x: torch.Tensor,
- indices: torch.Tensor,
- weights: Optional[torch.Tensor],
- bins: torch.Tensor,
- top_k: int,
- ):
+ def binned_scatter(x, indices, weights, bins, top_k):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_vector(indices)
@@ -524,13 +466,13 @@ def binned_scatter(
)
@triton.jit
def _binned_copy_wgrad(
- x: torch.Tensor,
- grad: torch.Tensor,
- wgrad: torch.Tensor,
- num_experts: int,
- expert_capacity: int,
- indices: torch.Tensor,
- bins: torch.Tensor,
+ x,
+ grad,
+ wgrad,
+ num_experts,
+ expert_capacity,
+ indices,
+ bins,
NUM_COLUMNS: tl.constexpr,
TOP_K: tl.constexpr,
BLOCK_X: tl.constexpr,
@@ -576,7 +518,7 @@ def _binned_copy_wgrad(
tl.store(wgrad, out)


- def binned_scatter_wgrad(x: torch.Tensor, grad: torch.Tensor, indices: torch.Tensor, bins: torch.Tensor, top_k: int):
+ def binned_scatter_wgrad(x, grad, indices, bins, top_k):
# Validate the input shapes.
assert_is_tensor(x, 3)
assert_is_matrix(grad)
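The signature changes above drop the type annotations but keep argument order and semantics. A rough usage sketch (not from this PR) of the padded gather/scatter round trip under the new signatures, assuming a CUDA device with Triton available and illustrative routing metadata for two experts with top_k=1:

# Sketch: gather 4 tokens into padded, expert-contiguous storage, then scatter back.
import torch

from megablocks.backend import kernels

x = torch.randn(4, 8, device='cuda')                  # 4 tokens, hidden size 8
indices = torch.tensor([0, 2, 1, 3], device='cuda')   # token ids sorted by expert
bin_ids = torch.tensor([0, 0, 1, 1], device='cuda')   # expert id for each sorted entry
bins = torch.tensor([2, 4], device='cuda')             # cumulative tokens per expert
padded_bins = torch.tensor([4, 8], device='cuda')      # bins rounded up to a capacity of 4

padded = kernels.padded_gather(x, indices, bin_ids, None, bins, padded_bins, top_k=1)
restored = kernels.padded_scatter(padded, indices, bin_ids, None, bins, padded_bins, top_k=1)
assert torch.allclose(restored, x)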
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -3,7 +3,7 @@

# build requirements
[build-system]
requires = ["setuptools < 70.0.0", "torch >= 2.3.0, < 2.4"]
requires = ["setuptools < 70.0.0", "torch >= 2.4.0, < 2.4.1"]
build-backend = "setuptools.build_meta"

# Pytest
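The tightened build pin effectively requires the torch 2.4.0 release exactly. A small check (a sketch, not part of the PR) that an existing environment satisfies the new specifier, using the packaging library the project already depends on:

# Sketch: verify the installed torch satisfies the new pin.
from importlib.metadata import version

from packaging.specifiers import SpecifierSet

assert version('torch') in SpecifierSet('>=2.4.0, <2.4.1'), version('torch')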
8 changes: 4 additions & 4 deletions setup.py
@@ -62,15 +62,15 @@
install_requires = [
'numpy>=1.21.5,<2.1.0',
'packaging>=21.3.0,<24.2',
- 'torch>=2.3.0,<2.4',
+ 'torch>=2.4.0,<2.4.1',
'triton>=2.1.0',
- 'stanford-stk @ git+https://[email protected]/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301',
+ 'stanford-stk @ git+https://[email protected]/stanford-futuredata/stk.git@v0.7.1',
]

extra_deps = {}

extra_deps['gg'] = [
- 'grouped_gemm @ git+https://[email protected]/tgale96/grouped_gemm.git@66c7195e35e8c4f22fa6a014037ef511bfa397cb',
+ 'grouped_gemm @ git+https://[email protected]/tgale96/grouped_gemm.git@v0.1.6',
]

extra_deps['dev'] = [
@@ -83,7 +83,7 @@
]

extra_deps['testing'] = [
- 'mosaicml>=0.22.0',
+ 'mosaicml>=0.24.1',
]

extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}})
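With the git dependencies now pinned to tagged releases, a quick post-install smoke test (a sketch, not part of this PR) is to import the core package and its optional dependencies and report what is present:

# Sketch: report which of the core and optional dependencies import cleanly.
import importlib

for name in ('megablocks', 'stk', 'grouped_gemm'):
    try:
        importlib.import_module(name)
        print(f'{name}: ok')
    except ImportError as err:  # grouped_gemm is only installed with the 'gg' extra
        print(f'{name}: not available ({err})')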