diff --git a/setup.py b/setup.py index 3efbdc21..4a5a45ab 100644 --- a/setup.py +++ b/setup.py @@ -1,23 +1,31 @@ from setuptools import setup, find_packages +import os import torch from torch.utils.cpp_extension import BuildExtension, CUDAExtension +if os.environ.get("TORCH_CUDA_ARCH_LIST"): + # Let the PyTorch builder choose which device to target. + device_capability = "" +else: + device_capability = torch.cuda.get_device_capability() + device_capability = f"{device_capability[0]}{device_capability[1]}" + +nvcc_flags = [ + "--ptxas-options=-v", + "--optimize=2", +] +if device_capability: + nvcc_flags.append( + f"--generate-code=arch=compute_{device_capability},code=sm_{device_capability}" + ) -_dc = torch.cuda.get_device_capability() -_dc = f"{_dc[0]}{_dc[1]}" ext_modules = [ CUDAExtension( "megablocks_ops", ["csrc/ops.cu"], - include_dirs = ["csrc"], - extra_compile_args={ - "cxx": ["-fopenmp"], - "nvcc": [ - "--ptxas-options=-v", - "--optimize=2", - f"--generate-code=arch=compute_{_dc},code=sm_{_dc}" - ] - }) + include_dirs=["csrc"], + extra_compile_args={"cxx": ["-fopenmp"], "nvcc": nvcc_flags}, + ) ] install_requires=[