From d1a33959cf49600641b7164ba69ee12f48f6ec22 Mon Sep 17 00:00:00 2001 From: Eitan Turok Date: Thu, 1 Aug 2024 19:26:25 +0000 Subject: [PATCH] format --- megablocks/_version.py | 2 +- setup.py | 28 +++++++++++----------------- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/megablocks/_version.py b/megablocks/_version.py index 2bb5d50..a9ac8bc 100644 --- a/megablocks/_version.py +++ b/megablocks/_version.py @@ -1,4 +1,4 @@ -# Copyright 2022 MegaBlocks Composer authors +# Copyright 2024 MosaicML MegaBlocks authors # SPDX-License-Identifier: Apache-2.0 """The MegaBlocks Version.""" diff --git a/setup.py b/setup.py index b03a82d..6a8ad77 100644 --- a/setup.py +++ b/setup.py @@ -9,18 +9,17 @@ from setuptools import find_packages, setup - # We require torch in setup.py to build cpp extensions "ahead of time" # More info here: # https://pytorch.org/tutorials/advanced/cpp_extension.html try: import torch - from torch.utils.cpp_extension import (CUDA_HOME, BuildExtension, - CUDAExtension,) + from torch.utils.cpp_extension import ( + CUDA_HOME, + BuildExtension, + CUDAExtension, + ) except ModuleNotFoundError as e: - raise ModuleNotFoundError( - "No module named 'torch'. `torch` is required to install `MegaBlocks`." - ) from e - + raise ModuleNotFoundError("No module named 'torch'. `torch` is required to install `MegaBlocks`.",) from e _PACKAGE_NAME = 'megablocks' _PACKAGE_DIR = 'megablocks' @@ -37,7 +36,6 @@ exec(content, version_globals, version_locals) repo_version = version_locals['__version__'] - with open('README.md', 'r', encoding='utf-8') as fh: long_description = fh.read() @@ -56,7 +54,6 @@ long_description = long_description[:start] + \ long_description[end + len(end_tag):] - classifiers = [ 'Programming Language :: Python :: 3', 'Programming Language :: Python :: 3.9', @@ -95,7 +92,6 @@ extra_deps['all'] = list({dep for key, deps in extra_deps.items() for dep in deps if key not in {'testing'}}) - cmdclass = {} ext_modules = [] @@ -113,9 +109,7 @@ device_capability = f'{device_capability_tuple[0]}{device_capability_tuple[1]}' if device_capability: - nvcc_flags.append( - f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}' - ) + nvcc_flags.append(f'--generate-code=arch=compute_{device_capability},code=sm_{device_capability}',) ext_modules = [ CUDAExtension( @@ -124,19 +118,19 @@ include_dirs=['csrc'], extra_compile_args={ 'cxx': ['-fopenmp'], - 'nvcc': nvcc_flags + 'nvcc': nvcc_flags, }, - ) + ), ] elif CUDA_HOME is None: warnings.warn( 'Attempted to install CUDA extensions, but CUDA_HOME was None. ' + 'Please install CUDA and ensure that the CUDA_HOME environment ' + - 'variable points to the installation location.') + 'variable points to the installation location.', + ) else: warnings.warn('Warning: No CUDA devices; cuda code will not be compiled.') - setup( name=_PACKAGE_NAME, version=repo_version,