Skip to content

Commit

Permalink
format
Browse files Browse the repository at this point in the history
  • Loading branch information
Eitan Turok committed Aug 1, 2024
1 parent de11700 commit d1a3395
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
2 changes: 1 addition & 1 deletion megablocks/_version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2022 MegaBlocks Composer authors
# Copyright 2024 MosaicML MegaBlocks authors
# SPDX-License-Identifier: Apache-2.0

"""The MegaBlocks Version."""
Expand Down
28 changes: 11 additions & 17 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,17 @@

from setuptools import find_packages, setup


# We require torch in setup.py to build cpp extensions "ahead of time"
# More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html
try:
import torch
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension,
CUDAExtension,)
from torch.utils.cpp_extension import (
CUDA_HOME,
BuildExtension,
CUDAExtension,
)
except ModuleNotFoundError as e:
raise ModuleNotFoundError(
"No module named 'torch'. `torch` is required to install `MegaBlocks`."
) from e

raise ModuleNotFoundError("No module named 'torch'. `torch` is required to install `MegaBlocks`.",) from e

_PACKAGE_NAME = 'megablocks'
_PACKAGE_DIR = 'megablocks'
Expand All @@ -37,7 +36,6 @@
exec(content, version_globals, version_locals)
repo_version = version_locals['__version__']


with open('README.md', 'r', encoding='utf-8') as fh:
long_description = fh.read()

Expand All @@ -56,7 +54,6 @@
long_description = long_description[:start] + \
long_description[end + len(end_tag):]


classifiers = [
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.9',
Expand Down Expand Up @@ -95,7 +92,6 @@

extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}})


cmdclass = {}
ext_modules = []

Expand All @@ -113,9 +109,7 @@
device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}'

if device_capability:
nvcc_flags.append(
f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}'
)
nvcc_flags.append(f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}',)

ext_modules = [
CUDAExtension(
Expand All @@ -124,19 +118,19 @@
include_dirs=['csrc'],
extra_compile_args={
'cxx': ['-fopenmp'],
'nvcc': nvcc_flags
'nvcc': nvcc_flags,
},
)
),
]
elif CUDA_HOME is None:
warnings.warn(
'Attempted to install CUDA extensions, but CUDA_HOME was None. ' +
'Please install CUDA and ensure that the CUDA_HOME environment ' +
'variable points to the installation location.')
'variable points to the installation location.',
)
else:
warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.')


setup(
name=_PACKAGE_NAME,
version=repo_version,
Expand Down

0 comments on commit d1a3395

Please sign in to comment.