diff --git a/examples/pytorch_optimization.py b/examples/pytorch_optimization.py
index f731cc4b..658cd485 100644
--- a/examples/pytorch_optimization.py
+++ b/examples/pytorch_optimization.py
@@ -16,7 +16,7 @@
 import os
 import torch
 import matplotlib
-matplotlib.use('agg')
+matplotlib.use('agg')  # Make matplotlib more robust when interactive plotting is impossible.
 import matplotlib.pyplot as plt
 import argparse
 
diff --git a/setup.py b/setup.py
index 6857b9bc..f4e3bffb 100644
--- a/setup.py
+++ b/setup.py
@@ -22,6 +22,7 @@
 import sys
 import textwrap
 import traceback
+from typing import List
 
 from distutils.errors import CompileError, DistutilsError, \
@@ -437,21 +438,26 @@ def is_torch_cuda(build_ext, include_dirs, extra_compile_args):
         print('INFO: Above error indicates that this PyTorch installation does not support CUDA.')
         return False
 
-def build_nvcc_extra_objects(cxx11_abi: bool):
+def get_nvcc_cmd() -> str:
+    from shutil import which
+    nvcc_cmd = which('nvcc')
+    if nvcc_cmd is None:
+        raise DistutilsPlatformError('Unable to find NVCC compiler')
+    return nvcc_cmd
+
+def build_nvcc_extra_objects(nvcc_cmd: str, cxx11_abi: bool) -> List[str]:
     # nvcc --compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI=0' -rdc=true -c cuda_kernels.cu
     # nvcc --compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI=0' -dlink -o cuda_kernels_link.o \
     #      cuda_kernels.o -lcudart
-    nvcc = 'nvcc'
-    common_flags = f"'-fPIC -D_GLIBCXX_USE_CXX11_ABI={int(cxx11_abi)}'"
-    nvcc_with_flags = f"{nvcc} --compiler-options {common_flags}"
+    nvcc_flags = f"--compiler-options '-fPIC -D_GLIBCXX_USE_CXX11_ABI={int(cxx11_abi)}'"
 
     extra_object_dir = 'bluefog/common/cuda/'
     source = extra_object_dir+'cuda_kernels.cu'
     object_file = extra_object_dir+'cuda_kernels.o'
     object_link = extra_object_dir+'cuda_kernels_link.o'
 
-    command_object = f"{nvcc_with_flags} -rdc=true -c {source} -o {object_file}"
-    command_link = f"{nvcc_with_flags} -dlink {object_file} -lcudart -o {object_link}"
+    command_object = f"{nvcc_cmd} {nvcc_flags} -rdc=true -c {source} -o {object_file}"
+    command_link = f"{nvcc_cmd} {nvcc_flags} -dlink {object_file} -lcudart -o {object_link}"
 
     subprocess.check_call([command_object], shell=True)
     subprocess.check_call([command_link], shell=True)
@@ -469,7 +475,8 @@ def build_torch_extension(build_ext, global_options, torch_version):
     if have_cuda:
         cuda_include_dirs, cuda_lib_dirs = get_cuda_dirs(
             build_ext, options['COMPILE_FLAGS'])
-        cuda_extra_objects = build_nvcc_extra_objects(is_cxx11_abi)
+        nvcc_cmd = get_nvcc_cmd()
+        cuda_extra_objects = build_nvcc_extra_objects(nvcc_cmd, is_cxx11_abi)
         options['EXTRA_OBJECTS'] += cuda_extra_objects
 
         options['INCLUDES'] += cuda_include_dirs
@@ -515,7 +522,6 @@ def build_torch_extension(build_ext, global_options, torch_version):
     else:
         # CUDAExtension fails with `ld: library not found for -lcudart` if CUDA is not present
         from torch.utils.cpp_extension import CppExtension as TorchExtension
-        bluefog_tensorflow_mpi_lib.extra_objects = options['EXTRA_OBJECTS']
 
     ext = TorchExtension(bluefog_torch_mpi_lib.name,
                          define_macros=updated_macros,
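A note on the `is_cxx11_abi` value that the patched call site passes into `build_nvcc_extra_objects()`: the diff does not show where that flag comes from. The sketch below is an assumption rather than part of the patch; it only illustrates that the `-D_GLIBCXX_USE_CXX11_ABI` define handed to nvcc's host compiler is expected to match the ABI the installed PyTorch wheel was built with, which PyTorch reports via `torch.compiled_with_cxx11_abi()`. The helper name `get_is_cxx11_abi` is hypothetical.

```python
# Hedged sketch, not from the patch: one plausible way to derive the
# cxx11_abi argument for build_nvcc_extra_objects(). Only
# torch.compiled_with_cxx11_abi() is a real PyTorch API here.
import torch


def get_is_cxx11_abi() -> bool:
    # True if the installed torch wheel uses the post-C++11 libstdc++ ABI,
    # False if it was built with the old ABI.
    return torch.compiled_with_cxx11_abi()


if __name__ == '__main__':
    is_cxx11_abi = get_is_cxx11_abi()
    # This is the define the nvcc --compiler-options string must carry so the
    # device-linked objects and the torch extension agree on std:: symbol names.
    print(f'-D_GLIBCXX_USE_CXX11_ABI={int(is_cxx11_abi)}')
```

If the two sides disagree on this define, the extension typically fails at link or import time with undefined `std::string`-related symbols, which is why the ABI flag is threaded through the nvcc command line at all.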