Skip to content

Commit

Permalink
Address comments for setup.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Hanbin Hu committed Mar 21, 2021
1 parent 8e42ca2 commit e5f8722
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 9 deletions.
2 changes: 1 addition & 1 deletion examples/pytorch_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
import torch
import matplotlib
matplotlib.use('agg')
matplotlib.use('agg') # Make matplotlib more robust when interface plotting is impossible.
import matplotlib.pyplot as plt
import argparse

Expand Down
22 changes: 14 additions & 8 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import sys
import textwrap
import traceback
from typing import List


from distutils.errors import CompileError, DistutilsError, \
Expand Down Expand Up @@ -437,21 +438,26 @@ def is_torch_cuda(build_ext, include_dirs, extra_compile_args):
print('INFO: Above error indicates that this PyTorch installation does not support CUDA.')
return False

def build_nvcc_extra_objects(cxx11_abi: bool):
def get_nvcc_cmd() -> str:
from shutil import which
nvcc_cmd = which('nvcc')
if nvcc_cmd is None:
raise DistutilsPlatformError('Unable to find NVCC compiler')
return nvcc_cmd

def build_nvcc_extra_objects(nvcc_cmd: str, cxx11_abi: bool) -> List[str]:
# nvcc --compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI=0' -rdc=true -c cuda_kernels.cu
# nvcc --compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI=0' -dlink -o cuda_kernels_link.o \
# cuda_kernels.o -lcudart
nvcc = 'nvcc'
common_flags = f"'-fPIC -D_GLIBCXX_USE_CXX11_ABI={int(cxx11_abi)}'"
nvcc_with_flags = f"{nvcc} --compiler-options {common_flags}"
nvcc_flags = f"--compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI={int(cxx11_abi)}'"

extra_object_dir = 'bluefog/common/cuda/'
source = extra_object_dir+'cuda_kernels.cu'
object_file = extra_object_dir+'cuda_kernels.o'
object_link = extra_object_dir+'cuda_kernels_link.o'

command_object = f"{nvcc_with_flags} -rdc=true -c {source} -o {object_file}"
command_link = f"{nvcc_with_flags} -dlink {object_file} -lcudart -o {object_link}"
command_object = f"{nvcc_cmd} {nvcc_flags} -rdc=true -c {source} -o {object_file}"
command_link = f"{nvcc_cmd} {nvcc_flags} -dlink {object_file} -lcudart -o {object_link}"

subprocess.check_call([command_object], shell=True)
subprocess.check_call([command_link], shell=True)
Expand All @@ -469,7 +475,8 @@ def build_torch_extension(build_ext, global_options, torch_version):
if have_cuda:
cuda_include_dirs, cuda_lib_dirs = get_cuda_dirs(
build_ext, options['COMPILE_FLAGS'])
cuda_extra_objects = build_nvcc_extra_objects(is_cxx11_abi)
nvcc_cmd = get_nvcc_cmd()
cuda_extra_objects = build_nvcc_extra_objects(nvcc_cmd, is_cxx11_abi)
options['EXTRA_OBJECTS'] += cuda_extra_objects

options['INCLUDES'] += cuda_include_dirs
Expand Down Expand Up @@ -515,7 +522,6 @@ def build_torch_extension(build_ext, global_options, torch_version):
else:
# CUDAExtension fails with `ld: library not found for -lcudart` if CUDA is not present
from torch.utils.cpp_extension import CppExtension as TorchExtension
bluefog_tensorflow_mpi_lib.extra_objects = options['EXTRA_OBJECTS']

ext = TorchExtension(bluefog_torch_mpi_lib.name,
define_macros=updated_macros,
Expand Down

0 comments on commit e5f8722

Please sign in to comment.