forked from masadcv/FastGeodis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
setup.py
executable file
·148 lines (126 loc) · 4.56 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
from gettext import install
import glob
import os
import re
import sys
import warnings
import pkg_resources
from setuptools import find_packages, setup
# Build-time feature flags, resolved once when setup.py is executed:
#   FORCE_CUDA    -- True when the environment sets FORCE_CUDA=1.
#   BUILD_CPP     -- True once torch and its C++ extension helpers imported.
#   BUILD_CUDA    -- whether the CUDA sources should be compiled as well.
#   TORCH_VERSION -- torch release encoded as major*10000 + minor*100 + patch;
#                    left at 0 if torch is missing or its version can't be parsed.
FORCE_CUDA = os.getenv("FORCE_CUDA", "0") == "1"
BUILD_CPP = BUILD_CUDA = False
TORCH_VERSION = 0
try:
    import torch

    print(f"setup.py with torch {torch.__version__}")
    from torch.utils.cpp_extension import BuildExtension, CppExtension

    BUILD_CPP = True
    from torch.utils.cpp_extension import CUDA_HOME, CUDAExtension

    # With a visible GPU, build CUDA only if a CUDA toolkit was located;
    # without one, build CUDA only when explicitly forced via FORCE_CUDA.
    BUILD_CUDA = (CUDA_HOME is not None) if torch.cuda.is_available() else FORCE_CUDA
    # Use the public `release` tuple rather than the private `_version` field;
    # the `is None` check is kept as a defensive guard.
    _pt_version = pkg_resources.parse_version(torch.__version__).release
    if _pt_version is None or len(_pt_version) < 3:
        raise AssertionError("unknown torch version")
    TORCH_VERSION = (
        int(_pt_version[0]) * 10000 + int(_pt_version[1]) * 100 + int(_pt_version[2])
    )
except (ImportError, TypeError, ValueError, AssertionError, AttributeError) as e:
    # ValueError also covers an unparseable version string (packaging's
    # InvalidVersion subclasses ValueError).
    # NOTE(review): a failure after BUILD_CPP is set (e.g. version parsing)
    # leaves the extension build enabled despite the "skipped" wording.
    warnings.warn(f"extension build skipped: {e}")
finally:
    print(
        f"BUILD_CPP={BUILD_CPP}, BUILD_CUDA={BUILD_CUDA}, TORCH_VERSION={TORCH_VERSION}."
    )
def torch_parallel_backend():
    """Map torch's reported ATen parallel backend to a preprocessor macro name.

    Reads the "ATen parallel backend: ..." line from
    ``torch._C._parallel_info()`` and translates it.

    Returns:
        ``"AT_PARALLEL_OPENMP"``, ``"AT_PARALLEL_NATIVE"`` or
        ``"AT_PARALLEL_NATIVE_TBB"`` for the three recognised backends.
        ``None`` when the line is absent, the backend is unrecognised, or
        torch / its ``_parallel_info`` binding is unavailable. The last case
        also emits a warning.
    """
    backend_to_macro = {
        "OpenMP": "AT_PARALLEL_OPENMP",
        "native thread pool": "AT_PARALLEL_NATIVE",
        "native thread pool and TBB": "AT_PARALLEL_NATIVE_TBB",
    }
    try:
        report = torch._C._parallel_info()
    except (NameError, AttributeError):  # no torch or no binaries
        warnings.warn("Could not determine torch parallel_info.")
        return None
    hit = re.search(
        "^ATen parallel backend: (?P<backend>.*)$", report, re.MULTILINE
    )
    if hit is None:
        return None
    return backend_to_macro.get(hit.group("backend"))
def omp_flags():
    """Return the OpenMP compiler/linker flags to use on the current platform.

    Windows (``win32``) gets the MSVC-style ``/openmp``, and macOS (``darwin``)
    gets no flag at all. An alternative ``-fopenmp=libiomp5`` was considered for
    macOS, see https://stackoverflow.com/questions/37362414/. Every other
    platform gets ``-fopenmp``.

    A fresh list is built on every call, so callers may extend it in place.
    """
    platform = sys.platform
    if platform == "darwin":
        return []
    return ["/openmp"] if platform == "win32" else ["-fopenmp"]
def get_extensions():
    """Build the list of C++/CUDA extension modules for ``setup()``.

    Collects every ``*.cpp`` under ``FastGeodis``, plus every ``*.cu`` when the
    CUDA build is enabled. Wraps them in ``CppExtension`` or ``CUDAExtension``
    named ``FastGeodisCpp``.

    The detected torch parallel-backend name, when there is one, is defined as
    a preprocessor macro. OpenMP compile flags are added when that backend is
    OpenMP.

    Returns:
        A one-element list holding the extension. An empty list when neither
        BUILD_CPP nor BUILD_CUDA is set, or when no sources were found.
    """
    # Paths are relative to the directory setup.py is run from.
    ext_dir = "FastGeodis"
    include_dirs = [ext_dir]
    # Sort the glob results so the source order is stable across filesystems,
    # keeping builds reproducible.
    source_cpu = sorted(glob.glob(os.path.join(ext_dir, "**", "*.cpp"), recursive=True))
    source_cuda = sorted(glob.glob(os.path.join(ext_dir, "**", "*.cu"), recursive=True))

    # Query torch once and reuse the answer for both the macro and OpenMP flags.
    parallel_backend = torch_parallel_backend()
    use_openmp = parallel_backend == "AT_PARALLEL_OPENMP"

    # An undetected backend (None) must not become a literal "-DNone=1" define.
    define_macros = [] if parallel_backend is None else [(parallel_backend, "1")]

    extension = None
    extra_compile_args = {}
    extra_link_args = []
    sources = list(source_cpu)
    if BUILD_CPP:
        extension = CppExtension
        extra_compile_args["cxx"] = omp_flags() if use_openmp else []
        # NOTE(review): link flags are applied whenever the C++ build is
        # enabled, regardless of backend, unlike the compile flags. Confirm
        # this is intended.
        extra_link_args = omp_flags()
    if BUILD_CUDA:
        extension = CUDAExtension
        sources += source_cuda
        define_macros.append(("WITH_CUDA", None))
        extra_compile_args = {"cxx": omp_flags() if use_openmp else [], "nvcc": []}

    if extension is None or not sources:
        return []  # compile nothing

    # Release build: no debug info from the host C++ compiler.
    extra_compile_args["cxx"].append("-g0")
    return [
        extension(
            name="FastGeodisCpp",
            sources=sources,
            include_dirs=include_dirs,
            define_macros=define_macros,
            extra_compile_args=extra_compile_args,
            extra_link_args=extra_link_args,
        )
    ]
# Long description shown on the package index. Read as UTF-8 explicitly so the
# result does not depend on the build machine's locale (e.g. cp1252 on Windows).
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

# One requirement per line; blank lines and "#" comment lines are dropped.
with open("requirements.txt", "r", encoding="utf-8") as fp:
    install_requires = [
        req for req in (line.strip() for line in fp) if req and not req.startswith("#")
    ]

# add dependencies folder in include path
dep_dir = os.path.join(".", "dependency")

setup(
    name="FastGeodis",
    version="1.0.4",
    description="Fast Implementation of Generalised Geodesic Distance Transform for CPU (OpenMP) and GPU (CUDA)",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/masadcv/FastGeodis",
    author="Muhammad Asad",
    # NOTE(review): this value looks like scraped/obfuscated text rather than a
    # real address. Confirm the author's actual email.
    author_email="[email protected]",
    license="BSD-3-Clause License",
    classifiers=[
        "License :: OSI Approved :: BSD License",
        "Programming Language :: Python :: 3",
    ],
    install_requires=install_requires,
    # BuildExtension is only defined once torch's cpp_extension was imported
    # (BUILD_CPP is True). Without torch, fall back to setuptools' default
    # build_ext instead of failing with a NameError.
    # (Option considered: BuildExtension.with_options(no_python_abi_suffix=True).)
    cmdclass={"build_ext": BuildExtension} if BUILD_CPP else {},
    packages=find_packages(exclude=("data", "docs", "examples", "scripts", "tests")),
    zip_safe=False,
    ext_modules=get_extensions(),
    include_dirs=[dep_dir],
)