Skip to content

Commit

Permalink
Update setup.py to support multiple device capabilities (#56)
Browse files Browse the repository at this point in the history
  • Loading branch information
simon-mo authored Dec 11, 2023
1 parent 396de2a commit 5897cd6
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions setup.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,31 @@
from setuptools import setup, find_packages
import os
import torch
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

if os.environ.get("TORCH_CUDA_ARCH_LIST"):
# Let PyTorch builder to choose device to target for.
device_capability = ""
else:
device_capability = torch.cuda.get_device_capability()
device_capability = f"{device_capability[0]}{device_capability[1]}"

nvcc_flags = [
"--ptxas-options=-v",
"--optimize=2",
]
if device_capability:
nvcc_flags.append(
f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}"
)

_dc = torch.cuda.get_device_capability()
_dc = f"{_dc[0]}{_dc[1]}"
ext_modules = [
CUDAExtension(
"megablocks_ops",
["csrc/ops.cu"],
include_dirs = ["csrc"],
extra_compile_args={
"cxx": ["-fopenmp"],
"nvcc": [
"--ptxas-options=-v",
"--optimize=2",
f"--generate-code=arch=compute_{_dc},code=sm_{_dc}"
]
})
include_dirs=["csrc"],
extra_compile_args={"cxx": ["-fopenmp"], "nvcc": nvcc_flags},
)
]

install_requires=[
Expand Down

0 comments on commit 5897cd6

Please sign in to comment.