diff --git a/setup.py b/setup.py index 4b3016a..2284e3a 100644 --- a/setup.py +++ b/setup.py @@ -10,12 +10,13 @@ ) except: raise ModuleNotFoundError("Please install pytorch >= 1.1 before proceeding.") - + import glob from os import path + this_directory = path.abspath(path.dirname(__file__)) -with open(path.join(this_directory, 'README.md'), encoding='utf-8') as f: +with open(path.join(this_directory, "README.md"), encoding="utf-8") as f: long_description = f.read() @@ -27,7 +28,9 @@ def get_ext_modules(): extra_compile_args += ["-DVERSION_GE_1_3"] ext_src_root = "cuda" - ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob("{}/src/*.cu".format(ext_src_root)) + ext_sources = glob.glob("{}/src/*.cpp".format(ext_src_root)) + glob.glob( + "{}/src/*.cu".format(ext_src_root) + ) ext_modules = [] if CUDA_HOME: @@ -36,7 +39,10 @@ def get_ext_modules(): name="torch_points_kernels.points_cuda", sources=ext_sources, include_dirs=["{}/include".format(ext_src_root)], - extra_compile_args={"cxx": extra_compile_args, "nvcc": extra_compile_args,}, + extra_compile_args={ + "cxx": extra_compile_args, + "nvcc": extra_compile_args, + }, ) ) @@ -53,26 +59,28 @@ def get_ext_modules(): ) return ext_modules + def get_cmdclass(): return {"build_ext": BuildExtension} + requirements = ["torch>=1.1.0"] -url = 'https://github.com/nicolas-chaulet/torch-points-kernels' -__version__="0.6.3" +url = "https://github.com/nicolas-chaulet/torch-points-kernels" +__version__ = "0.6.3" setup( name="torch-points-kernels", version=__version__, author="Nicolas Chaulet", packages=find_packages(), - description="PyTorch kernels for spatial operations on point clouds" + description="PyTorch kernels for spatial operations on point clouds", url=url, - download_url='{}/archive/{}.tar.gz'.format(url, __version__), + download_url="{}/archive/{}.tar.gz".format(url, __version__), install_requires=requirements, ext_modules=get_ext_modules(), cmdclass=get_cmdclass(), long_description=long_description, - long_description_content_type='text/markdown', + long_description_content_type="text/markdown", classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License",