Added simple workarounds for gather_mm and segment_mm #57

Merged
10 commits merged on Sep 27, 2024
4 changes: 3 additions & 1 deletion .github/workflows/black.yml
@@ -6,5 +6,7 @@ jobs:
   lint:
     runs-on: ubuntu-latest
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - uses: psf/black@stable
+        with:
+          jupyter: true
17 changes: 12 additions & 5 deletions .github/workflows/python-package.yml
@@ -12,21 +12,28 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        python-version: ["3.8", "3.9", "3.10"]
-        torch-version: ["1.13.1", "2.0.1"]
+        python-version: ["3.8", "3.10", "3.12"]
+        torch-version: ["1.13.1", "2.4.1"]
+        exclude:
+          - python-version: "3.12"
+            torch-version: "1.13.1"
 
     steps:
-      - uses: actions/checkout@v3
+      - uses: actions/checkout@v4
       - name: Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v3
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
           pip install torch==${{ matrix.torch-version }}
-          python -m pip install flake8 black
+          python -m pip install flake8 black[jupyter]
           if [ -f requirements-dev.txt ]; then pip install -r requirements-dev.txt; fi
+      - name: numpy downgrade for pytorch 1.x
+        if: startsWith(matrix.torch-version, '1.')
+        run: |
+          pip install "numpy<2"
       - name: Lint check with flake8
         run: |
           # stop the build if there are Python syntax errors or undefined names
3 changes: 2 additions & 1 deletion setup.py
@@ -18,8 +18,9 @@ def readme():
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.12",
],
python_requires=">=3.8, <3.11",
python_requires=">=3.8",
keywords="sparse torch utility",
url="https://github.com/cai4cai/torchsparsegradutils",
author="CAI4CAI research group",
10 changes: 9 additions & 1 deletion torchsparsegradutils/__init__.py
@@ -1,5 +1,13 @@
 from .sparse_matmul import sparse_mm
+from .indexed_matmul import gather_mm, segment_mm
 from .sparse_solve import sparse_triangular_solve, sparse_generic_solve
 from .sparse_lstsq import sparse_generic_lstsq
 
-__all__ = ["sparse_mm", "sparse_triangular_solve", "sparse_generic_solve", "sparse_generic_lstsq"]
+__all__ = [
+    "sparse_mm",
+    "gather_mm",
+    "segment_mm",
+    "sparse_triangular_solve",
+    "sparse_generic_solve",
+    "sparse_generic_lstsq",
+]
117 changes: 117 additions & 0 deletions torchsparsegradutils/indexed_matmul.py
@@ -0,0 +1,117 @@
import torch

try:
    import dgl.ops as dglops

    dgl_installed = True
except ImportError:
    dgl_installed = False


def segment_mm(a, b, seglen_a):
"""
Performs matrix multiplication according to segments.
See https://docs.dgl.ai/generated/dgl.ops.segment_mm.html

Suppose ``seglen_a == [10, 5, 0, 3]``, the operator will perform
four matrix multiplications::

a[0:10] @ b[0], a[10:15] @ b[1],
a[15:15] @ b[2], a[15:18] @ b[3]

Args:
a (torch.Tensor): The left operand, 2-D tensor of shape ``(N, D1)``
b (torch.Tensor): The right operand, 3-D tensor of shape ``(R, D1, D2)``
seglen_a (torch.Tensor): An integer tensor of shape ``(R,)``. Each element is the length of segments of input ``a``. The summation of all elements must be equal to ``N``.

Returns:
torch.Tensor: The output dense matrix of shape ``(N, D2)``
"""
if torch.__version__ < (2, 4):
raise NotImplementedError("PyTorch version is too old for nested tesors")

if dgl_installed:
# DGL is probably more computationally efficient
# See https://github.com/pytorch/pytorch/issues/136747
return dglops.segment_mm(a, b, seglen_a)

if not a.dim() == 2 or not b.dim() == 3 or not seglen_a.dim() == 1:
raise ValueError("Input tensors have unexpected dimensions")

N, _ = a.shape
R, D1, D2 = b.shape

# Sanity check sizes
if not a.shape[1] == D1 or not seglen_a.shape[0] == R:
raise ValueError("Incompatible size for inputs")

segidx_a = torch.cumsum(seglen_a[:-1], dim=0)

# Ideally the conversions below to nested tensor would be handled natively
nested_a = torch.nested.as_nested_tensor(torch.tensor_split(a, segidx_a, dim=0))
nested_b = torch.nested.as_nested_tensor(list(map(torch.squeeze, torch.split(b, 1, dim=0))))

# The actual gather matmul computation
nested_ab = torch.matmul(nested_a, nested_b)

# Convert back to tensors, again ideally this would be handled natively
ab = torch.cat(nested_ab.unbind(), dim=0)
return ab


def gather_mm(a, b, idx_b):
"""
Gather data according to the given indices and perform matrix multiplication.
See https://docs.dgl.ai/generated/dgl.ops.gather_mm.html

Let the result tensor be ``c``, the operator conducts the following computation:

c[i] = a[i] @ b[idx_b[i]]
, where len(c) == len(idx_b)

Args:
a (torch.Tensor): A 2-D tensor of shape ``(N, D1)``
b (torch.Tensor): A 3-D tensor of shape ``(R, D1, D2)``
idx_b (torch.Tensor): An 1-D integer tensor of shape ``(N,)``.

Returns:
torch.Tensor: The output dense matrix of shape ``(N, D2)``
"""
if torch.__version__ < (2, 4):
raise NotImplementedError("PyTorch version is too old for nested tesors")

if dgl_installed:
# DGL is more computationally efficient
# See https://github.com/pytorch/pytorch/issues/136747
return dglops.gather_mm(a, b, idx_b)

# Dependency free fallback
if not isinstance(a, torch.Tensor) or not isinstance(b, torch.Tensor) or not isinstance(idx_b, torch.Tensor):
raise ValueError("Inputs should be instances of torch.Tensor")

if not a.dim() == 2 or not b.dim() == 3 or not idx_b.dim() == 1:
raise ValueError("Input tensors have unexpected dimensions")

N = idx_b.shape[0]
R, D1, D2 = b.shape

# Sanity check sizes
if not a.shape[0] == N or not a.shape[1] == D1:
raise ValueError("Incompatible size for inputs")

torchdevice = a.device
src_idx = torch.arange(N, device=torchdevice)

# Ideally the conversions below to nested tensor would be handled without for looops and without copy
nested_a = torch.nested.as_nested_tensor([a[idx_b == i, :] for i in range(R)])
src_idx_reshuffled = torch.cat([src_idx[idx_b == i] for i in range(R)])
nested_b = torch.nested.as_nested_tensor([b[i, :, :].squeeze() for i in range(R)])

# The actual gather matmul computation
nested_ab = torch.matmul(nested_a, nested_b)

# Convert back to tensors, again, ideally this would be handled natively with no copy
ab_segmented = torch.cat(nested_ab.unbind(), dim=0)
ab = torch.empty((N, D2), device=torchdevice)
ab[src_idx_reshuffled] = ab_segmented
return ab
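
For context (not part of the diff), a minimal usage sketch of the two workarounds added above, assuming torchsparsegradutils is installed together with PyTorch 2.4+ (or DGL); the tensor shapes, values, and variable names here are illustrative only:

import torch
from torchsparsegradutils import gather_mm, segment_mm

a = torch.randn(18, 7)      # (N, D1) left operand
b = torch.randn(4, 7, 10)   # (R, D1, D2) stack of right operands

# segment_mm: consecutive row segments of `a`, one per matrix in `b`
seglen_a = torch.tensor([10, 5, 0, 3])  # segment lengths, must sum to N
ab_seg = segment_mm(a, b, seglen_a)     # shape (18, 10)
assert torch.allclose(ab_seg[:10], a[:10] @ b[0], atol=1e-6)

# gather_mm: each row of `a` selects its right operand through `idx_b`
idx_b = torch.randint(0, 4, (18,))
ab_gat = gather_mm(a, b, idx_b)         # shape (18, 10)
assert torch.allclose(ab_gat[0], a[0] @ b[idx_b[0]], atol=1e-6)
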
94 changes: 94 additions & 0 deletions torchsparsegradutils/tests/test_indexed_matmul.py
@@ -0,0 +1,94 @@
import torch
import pytest

if torch.__version__ < (2, 4):
    pytest.skip(
        "Skipping test based on nested tensors since an old version of pytorch is used", allow_module_level=True
    )

from torchsparsegradutils import gather_mm, segment_mm

# Identify Testing Parameters
DEVICES = [torch.device("cpu")]
if torch.cuda.is_available():
    DEVICES.append(torch.device("cuda"))

TEST_DATA = [
    # name     N,  R, D1, D2
    ("small", 100, 32, 7, 10),
]

INDEX_DTYPES = [torch.int32, torch.int64]
VALUE_DTYPES = [torch.float32, torch.float64]

ATOL = 1e-6 # relaxed tolerance to allow for float32
RTOL = 1e-4


# Define Test Names:
def data_id(shapes):
    return shapes[0]


def device_id(device):
    return str(device)


def dtype_id(dtype):
    return str(dtype).split(".")[-1]


# Define Fixtures


@pytest.fixture(params=TEST_DATA, ids=[data_id(d) for d in TEST_DATA])
def shapes(request):
    return request.param


@pytest.fixture(params=VALUE_DTYPES, ids=[dtype_id(d) for d in VALUE_DTYPES])
def value_dtype(request):
    return request.param


@pytest.fixture(params=INDEX_DTYPES, ids=[dtype_id(d) for d in INDEX_DTYPES])
def index_dtype(request):
    return request.param


@pytest.fixture(params=DEVICES, ids=[device_id(d) for d in DEVICES])
def device(request):
    return request.param


# Define Tests


def test_segment_mm(device, value_dtype, index_dtype, shapes):
    _, N, R, D1, D2 = shapes

    a = torch.randn((N, D1), device=device, dtype=value_dtype)
    b = torch.randn((R, D1, D2), device=device, dtype=value_dtype)
    seglen_a = torch.randint(low=1, high=int(N / R), size=(R,), device=device)
    seglen_a[-1] = N - seglen_a[:-1].sum()

    ab = segment_mm(a, b, seglen_a)

    k = 0
    for i in range(R):
        for j in range(seglen_a[i]):
            assert torch.allclose(ab[k, :].squeeze(), a[k, :].squeeze() @ b[i, :, :].squeeze(), atol=ATOL, rtol=RTOL)
            k += 1


def test_gather_mm(device, value_dtype, index_dtype, shapes):
    _, N, R, D1, D2 = shapes

    a = torch.randn((N, D1), device=device, dtype=value_dtype)
    b = torch.randn((R, D1, D2), device=device, dtype=value_dtype)
    idx_b = torch.randint(low=0, high=R, size=(N,), device=device, dtype=index_dtype)

    ab = gather_mm(a, b, idx_b)

    for i in range(N):
        assert torch.allclose(ab[i, :].squeeze(), a[i, :].squeeze() @ b[idx_b[i], :, :].squeeze(), atol=ATOL, rtol=RTOL)
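
As a usage note (not part of the diff), the new tests can be run on their own with pytest; they are skipped automatically when the installed PyTorch predates 2.4 and therefore lacks the nested-tensor support the fallbacks rely on:

pytest torchsparsegradutils/tests/test_indexed_matmul.py -v
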