forked from ar4/deepwave
-
Notifications
You must be signed in to change notification settings - Fork 1
/
setup.py
70 lines (60 loc) · 2.68 KB
/
setup.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os
import setuptools
from torch.utils.cpp_extension import (BuildExtension, CppExtension,
CUDAExtension)
from torch import cuda
# Use the README as the long description shown on PyPI. Decode it as UTF-8
# explicitly: without an ``encoding`` argument, open() uses the locale's
# preferred encoding (e.g. cp1252 on many Windows systems), which can fail
# or silently mis-decode non-ASCII characters in the Markdown.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()
# Source files for the compiled scalar wave propagators; all of them live in
# deepwave/scalar/ (kept relative to the setup.py directory, as setuptools
# expects for extension sources).
scalar_dir = os.path.join('deepwave', 'scalar')
(scalar_cpp_file,
 scalar_cpu_file,
 scalar_gpu_file,
 scalar_wrapper_file) = (os.path.join(scalar_dir, source_name)
                         for source_name in ('scalar.cpp',
                                             'scalar_cpu.cpp',
                                             'scalar_gpu.cu',
                                             'scalar_wrapper.cpp'))
def _make_cpp_extension(dim, dtype):
    """Describe the CPU (OpenMP) scalar extension for one dim/dtype pair.

    Args:
        dim: Dimension as a string (this file passes '1', '2' or '3'); it is
            forwarded to the compiler as the ``DIM`` macro.
        dtype: C scalar type name (this file passes 'float' or 'double'); it
            is forwarded to the compiler as the ``TYPE`` macro.

    Returns:
        A ``CppExtension`` whose module name is
        ``scalar<dim>d_cpu_iso_4_<dtype>``.
    """
    module_name = 'scalar{dim}d_cpu_iso_4_{dtype}'.format(dim=dim,
                                                          dtype=dtype)
    sources = [scalar_cpu_file, scalar_cpp_file, scalar_wrapper_file]
    macros = [('DIM', dim), ('TYPE', dtype)]
    # OpenMP has to be enabled both when compiling and when linking.
    openmp_flag = '-fopenmp'
    # -march=native tunes the binary for the machine doing the build.
    compile_flags = ['-Ofast', '-march=native', openmp_flag]
    return CppExtension(module_name,
                        sources,
                        define_macros=macros,
                        include_dirs=[scalar_dir],
                        extra_compile_args=compile_flags,
                        extra_link_args=[openmp_flag])
def _make_cuda_extension(dim, dtype):
    """Describe the CUDA scalar extension for one dim/dtype pair.

    Args:
        dim: Dimension as a string (this file passes '1', '2' or '3'); it is
            forwarded to the compilers as the ``DIM`` macro.
        dtype: C scalar type name (this file only passes 'float'); it is
            forwarded to the compilers as the ``TYPE`` macro.

    Returns:
        A ``CUDAExtension`` whose module name is
        ``scalar<dim>d_gpu_iso_4_<dtype>``.
    """
    module_name = 'scalar{dim}d_gpu_iso_4_{dtype}'.format(dim=dim,
                                                          dtype=dtype)
    sources = [scalar_gpu_file, scalar_cpp_file, scalar_wrapper_file]
    # Flags are keyed by compiler: 'nvcc' for the .cu file, 'cxx' for the
    # host C++ sources.
    compile_flags = {
        'nvcc': ['--restrict', '-O3', '--use_fast_math'],
        'cxx': ['-Ofast', '-march=native'],
    }
    return CUDAExtension(module_name,
                         sources,
                         define_macros=[('DIM', dim), ('TYPE', dtype)],
                         include_dirs=[scalar_dir],
                         extra_compile_args=compile_flags)
# One CPU extension for every (dimension, precision) combination, ordered by
# dimension first and then by dtype.
cpp_extensions = [
    _make_cpp_extension(dim, dtype)
    for dim in ('1', '2', '3')
    for dtype in ('float', 'double')
]
# GPU extensions are single precision only, and are built only when PyTorch
# reports CUDA as available at build time.
# NOTE(review): torch.cuda.is_available() needs a usable device/driver, so a
# GPU-less build machine with the CUDA toolkit installed will skip these.
cuda_extensions = (
    [_make_cuda_extension(dim, 'float') for dim in ('1', '2', '3')]
    if cuda.is_available()
    else []
)
# Package metadata and build configuration. BuildExtension is PyTorch's
# replacement build_ext command, needed to compile the CppExtension /
# CUDAExtension objects listed in ext_modules.
setuptools.setup(
    name="deepwave",
    version="0.0.8",
    author="Alan Richardson",
    # NOTE(review): this value looks like web-scraping email obfuscation
    # residue rather than a real address -- confirm against upstream.
    author_email="[email protected]",
    description="Wave propagation modules for PyTorch.",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/ar4/deepwave",
    packages=setuptools.find_packages(),
    # Core metadata fields such as classifiers must be a list; passing a
    # tuple makes distutils/setuptools warn "'classifiers' should be a list".
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    install_requires=["numpy",
                      "scipy",
                      "torch>=1.7.1"],
    # NOTE(review): torch is imported at the top of this script, before
    # setup() runs, so setup_requires cannot supply it; torch must already
    # be installed in the build environment. Kept for backward compatibility.
    setup_requires=["torch>=1.7.1"],
    extras_require={"testing": ["pytest"]},
    ext_modules=cpp_extensions + cuda_extensions,
    cmdclass={'build_ext': BuildExtension},
)