-
Notifications
You must be signed in to change notification settings - Fork 176
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #8 from eitanturok/eitan-temp
- Loading branch information
Showing
4 changed files
with
118 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
# Copyright 2022 MegaBlocks Composer authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""The MegaBlocks Version.""" | ||
|
||
__version__ = '0.5.1' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,77 @@ | ||
# Copyright 2024 MosaicML MegaBlocks authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
"""MegaBlocks package setup.""" | ||
|
||
import os | ||
import warnings | ||
from typing import Any, Dict, Mapping | ||
|
||
import torch | ||
from setuptools import find_packages, setup | ||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension | ||
|
||
if os.environ.get('TORCH_CUDA_ARCH_LIST'): | ||
# Let PyTorch builder to choose device to target for. | ||
device_capability = '' | ||
else: | ||
device_capability = torch.cuda.get_device_capability() | ||
device_capability = f'{device_capability[0]}{device_capability[1]}' | ||
|
||
nvcc_flags = [ | ||
'--ptxas-options=-v', | ||
'--optimize=2', | ||
] | ||
if device_capability: | ||
nvcc_flags.append(f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}',) | ||
|
||
ext_modules = [ | ||
CUDAExtension( | ||
'megablocks_ops', | ||
['csrc/ops.cu'], | ||
include_dirs=['csrc'], | ||
extra_compile_args={ | ||
'cxx': ['-fopenmp'], | ||
'nvcc': nvcc_flags, | ||
}, | ||
), | ||
# We require torch in setup.py to build cpp extensions "ahead of time" | ||
# More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html | ||
try: | ||
import torch | ||
from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, | ||
CUDAExtension,) | ||
except ModuleNotFoundError as e: | ||
raise ModuleNotFoundError( | ||
"No module named 'torch'. `torch` is required to install `MegaBlocks`." | ||
) from e | ||
|
||
|
||
_PACKAGE_NAME = 'megablocks' | ||
_PACKAGE_DIR = 'megablocks' | ||
_REPO_REAL_PATH = os.path.dirname(os.path.realpath(__file__)) | ||
_PACKAGE_REAL_PATH = os.path.join(_REPO_REAL_PATH, _PACKAGE_DIR) | ||
|
||
# Read the package version | ||
# We can't use `.__version__` from the library since it's not installed yet | ||
version_path = os.path.join(_PACKAGE_REAL_PATH, '_version.py') | ||
with open(version_path, encoding='utf-8') as f: | ||
version_globals: Dict[str, Any] = {} | ||
version_locals: Mapping[str, object] = {} | ||
content = f.read() | ||
exec(content, version_globals, version_locals) | ||
repo_version = version_locals['__version__'] | ||
|
||
|
||
with open('README.md', 'r', encoding='utf-8') as fh: | ||
long_description = fh.read() | ||
|
||
# Hide the content between <!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_BEGIN --> and | ||
# <!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_END --> tags in the README | ||
while True: | ||
start_tag = '<!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_BEGIN -->' | ||
end_tag = '<!-- SETUPTOOLS_LONG_DESCRIPTION_HIDE_END -->' | ||
start = long_description.find(start_tag) | ||
end = long_description.find(end_tag) | ||
if start == -1: | ||
assert end == -1, 'there should be a balanced number of start and ends' | ||
break | ||
else: | ||
assert end != -1, 'there should be a balanced number of start and ends' | ||
long_description = long_description[:start] + \ | ||
long_description[end + len(end_tag):] | ||
|
||
|
||
classifiers = [ | ||
'Programming Language :: Python :: 3', | ||
'Programming Language :: Python :: 3.9', | ||
'Programming Language :: Python :: 3.10', | ||
'Programming Language :: Python :: 3.11', | ||
'License :: OSI Approved :: BSD License', | ||
'Operating System :: Unix', | ||
] | ||
|
||
install_requires = [ | ||
'numpy>=1.21.5,<2.1.0', | ||
'packaging>=21.3.0,<24.2', | ||
'torch>=2.3.0,<2.4', | ||
'triton>=2.1.0', | ||
'stanford-stk @ git+https://[email protected]/stanford-futuredata/stk.git@a1ddf98466730b88a2988860a9d8000fd1833301', | ||
'packaging>=21.3.0,<24.2', | ||
] | ||
|
||
extra_deps = {} | ||
|
@@ -62,23 +95,62 @@ | |
|
||
extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) | ||
|
||
|
||
cmdclass = {} | ||
ext_modules = [] | ||
|
||
# Only install CUDA extensions if available | ||
if 'cu' in torch.__version__ and CUDA_HOME is not None: | ||
|
||
cmdclass = {'build_ext': BuildExtension} | ||
nvcc_flags = ['--ptxas-options=-v', '--optimize=2'] | ||
|
||
if os.environ.get('TORCH_CUDA_ARCH_LIST'): | ||
# Let PyTorch builder to choose device to target for. | ||
device_capability = '' | ||
else: | ||
device_capability_tuple = torch.cuda.get_device_capability() | ||
device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}' | ||
|
||
if device_capability: | ||
nvcc_flags.append( | ||
f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}' | ||
) | ||
|
||
ext_modules = [ | ||
CUDAExtension( | ||
'megablocks_ops', | ||
['csrc/ops.cu'], | ||
include_dirs=['csrc'], | ||
extra_compile_args={ | ||
'cxx': ['-fopenmp'], | ||
'nvcc': nvcc_flags | ||
}, | ||
) | ||
] | ||
elif CUDA_HOME is None: | ||
warnings.warn( | ||
'Attempted to install CUDA extensions, but CUDA_HOME was None. ' + | ||
'Please install CUDA and ensure that the CUDA_HOME environment ' + | ||
'variable points to the installation location.') | ||
else: | ||
warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.') | ||
|
||
|
||
setup( | ||
name='megablocks', | ||
version='0.5.1', | ||
name=_PACKAGE_NAME, | ||
version=repo_version, | ||
author='Trevor Gale', | ||
author_email='[email protected]', | ||
description='MegaBlocks', | ||
long_description=open('README.md').read(), | ||
long_description=long_description, | ||
long_description_content_type='text/markdown', | ||
url='https://github.com/stanford-futuredata/megablocks', | ||
classifiers=[ | ||
'Programming Language :: Python :: 3', | ||
'License :: OSI Approved :: BSD License', | ||
'Operating System :: Unix', | ||
], | ||
packages=find_packages(), | ||
url='https://github.com/databricks/megablocks', | ||
classifiers=classifiers, | ||
packages=find_packages(exclude=['tests*', 'third_party*', 'yamls*', 'exp*', '.github*']), | ||
ext_modules=ext_modules, | ||
cmdclass={'build_ext': BuildExtension}, | ||
cmdclass=cmdclass, | ||
install_requires=install_requires, | ||
extras_require=extra_deps, | ||
python_requires='>=3.9', | ||
) |