diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 365ef5d..28a6dc8 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -30,3 +30,7 @@ replace = release = '{new_version}' [bumpversion:file:src/torch_max_mem/version.py] search = VERSION = "{current_version}" replace = VERSION = "{new_version}" + +[bumpversion:file:pyproject.toml] +search = version = "{current_version}" +replace = version = "{new_version}" diff --git a/.flake8 b/.flake8 deleted file mode 100644 index 37e31a5..0000000 --- a/.flake8 +++ /dev/null @@ -1,31 +0,0 @@ -######################### -# Flake8 Configuration # -# (.flake8) # -######################### -[flake8] -ignore = - # the following two ignores are for compatibility with black formatting - # Line break before binary operator (flake8 is wrong) - W503 - # whitespace before ':' - E203 -exclude = - .tox, - .git, - __pycache__, - docs/source/conf.py, - build, - dist, - tests/fixtures/*, - *.pyc, - *.egg-info, - .cache, - .eggs, - data -max-line-length = 120 -max-complexity = 20 -import-order-style = pycharm -application-import-names = - torch_max_mem - tests - diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index a4d5691..a3357d9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -11,51 +11,66 @@ jobs: python-version: [ "3.8", "3.11" ] steps: - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} + - name: Install dependencies - run: pip install tox - - name: Check manifest - run: tox -e manifest - - name: Check code quality with flake8 - run: tox -e flake8 - - name: Check package metadata with Pyroma - run: tox -e pyroma + run: pip install tox tox-uv + - name: Check static typing with MyPy run: tox -e mypy - docs: - name: Documentation + + - name: Check code quality + run: tox -e lint + + lint-single-version: + name: Package Meta & Documentation runs-on: ubuntu-latest strategy: 
matrix: - python-version: [ "3.8", "3.11" ] + python-version: [ "3.11" ] steps: - - uses: actions/checkout@v4 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - name: Install dependencies - run: pip install tox - - name: Check RST conformity with doc8 - run: tox -e doc8 - - name: Check docstring coverage - run: tox -e docstr-coverage - - name: Check documentation build with Sphinx - run: tox -e docs-test + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: pip install tox tox-uv + + - name: Check package metadata with Pyroma + run: tox -e pyroma + + - name: Check docstring coverage + run: tox -e docstr-coverage + + - name: Check manifest + run: tox -e manifest + + - name: Check RST conformity with doc8 + run: tox -e doc8 + + - name: Check documentation build with Sphinx + run: tox -e docs-test + tests: name: Tests runs-on: ${{ matrix.os }} strategy: matrix: - os: [ ubuntu-latest, macos-14 ] + os: [ ubuntu-latest, windows-latest, macos-latest ] python-version: [ "3.8", "3.11" ] exclude: # 3.8 is not available for M1 macOS - - os: macos-14 + - os: macos-latest python-version: "3.8" + needs: + - lint steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -63,7 +78,12 @@ jobs: with: python-version: ${{ matrix.python-version }} - name: Install dependencies - run: pip install tox + run: pip install tox tox-uv - name: Test with pytest and generate coverage file run: tox -e py + + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true diff --git a/.gitignore b/.gitignore index a206b71..842fba2 100644 --- a/.gitignore +++ b/.gitignore @@ -318,4 +318,5 @@ $RECYCLE.BIN/ scratch/ -.vscode \ No newline at end of file +.vscode +.idea \ No newline at end of file diff --git 
a/docs/source/conf.py b/docs/source/conf.py index 085d52a..de313f1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # @@ -27,17 +26,17 @@ author = "Max Berrendorf" # The full version, including alpha/beta/rc tags. -release = '0.1.4-dev' +release = "0.1.4-dev" # The short X.Y version. parsed_version = re.match( - "(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<release>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P<build>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?", + r"(?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(?:-(?P<release>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?(?:\+(?P<build>[0-9A-Za-z-]+(?:\.[0-9A-Za-z-]+)*))?", release, ) -version = parsed_version.expand("\g<major>.\g<minor>.\g<patch>") +version = parsed_version.expand(r"\g<major>.\g<minor>.\g<patch>") if parsed_version.group("release"): - tags.add("prerelease") + tags.add("prerelease") # noqa:F821 # -- General configuration --------------------------------------------------- @@ -63,11 +62,8 @@ "sphinx.ext.todo", "sphinx.ext.mathjax", "sphinx.ext.viewcode", - "sphinx_autodoc_typehints", - "sphinx_click.ext", "sphinx_automodapi.automodapi", "sphinx_automodapi.smart_resolver", - # 'texext', ] # generate autosummary pages diff --git a/pyproject.toml b/pyproject.toml index 2936354..3438fe9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,20 +1,162 @@ # See https://setuptools.readthedocs.io/en/latest/build_meta.html [build-system] requires = ["setuptools", "wheel"] -build-backend = "setuptools.build_meta:__legacy__" +build-backend = "setuptools.build_meta" + +[project] +name = "torch_max_mem" +version = "0.1.4-dev" +description = "Maximize memory utilization with PyTorch." 
+# Author information +authors = [{ name = "Max Berrendorf", email = "max.berrendorf@gmail.com" }] +maintainers = [{ name = "Max Berrendorf", email = "max.berrendorf@gmail.com" }] + +# See https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#classifiers +# Search tags using the controlled vocabulary at https://pypi.org/classifiers +classifiers = [ + "Development Status :: 4 - Beta", + "Environment :: Console", + "Intended Audience :: Developers", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Framework :: Pytest", + "Framework :: tox", + "Framework :: Sphinx", + "Programming Language :: Python", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3 :: Only", + # TODO add your topics from the Trove controlled vocabulary (see https://pypi.org/classifiers) +] +keywords = [ + "snekpack", # please keep this keyword to credit the cookiecutter-snekpack template + "cookiecutter", + "torch", +] + +# License Information. This can be any valid SPDX identifiers that can be resolved +# with URLs like https://spdx.org/licenses/MIT +# See https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#license +license = { file = "LICENSE" } + +requires-python = ">=3.8" +dependencies = [ + "torch>=2.0", + "torch<2.4; platform_system=='Windows'", + "typing_extensions", +] + +[project.optional-dependencies] +tests = ["numpy", "numpy<2; platform_system=='Windows'", "pytest", "coverage"] +docs = [ + # Sphinx >= 8.0 not supported by rtd theme, cf. https://github.com/readthedocs/sphinx_rtd_theme/issues/1582 + "sphinx<8", + "sphinx-rtd-theme", + "sphinx_automodapi", + # To include LaTeX comments easily in your docs. 
+ # If you uncomment this, don't forget to do the same in docs/conf.py + # texext +] + +[project.urls] +Homepage = "https://github.com/mberr/torch-max-mem" +Download = "https://github.com/mberr/torch-max-mem/releases" +"Bug Tracker" = "https://github.com/mberr/torch-max-mem/issues" +"Source Code" = "https://github.com/mberr/torch-max-mem" + +[project.readme] +file = "README.md" +content-type = "text/markdown" +# URLs associated with the project + + +[tool.setuptools] +# Where is my code +package-dir = { "" = "src" } + +[tool.setuptools.packages.find] +# this implicitly sets `packages = ":find"` +where = ["src"] # list of folders that contain the packages (["."] by default) + + +# Doc8, see https://doc8.readthedocs.io/en/stable/readme.html#ini-file-usage +[tool.doc8] +max-line-length = 120 + +# Coverage, see https://coverage.readthedocs.io/en/latest/config.html +[tool.coverage.run] +branch = true +source = ["torch_max_mem"] +omit = ["tests/*", "docs/*"] + +[tool.coverage.paths] +source = ["src/torch_max_mem", ".tox/*/lib/python*/site-packages/torch_max_mem"] + +[tool.coverage.report] +show_missing = true +exclude_lines = [ + "pragma: no cover", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "def __str__", + "def __repr__", +] [tool.black] -line-length = 100 -target-version = ["py37", "py38", "py39"] +line-length = 120 +target-version = ["py38", "py39", "py310", "py311", "py312"] [tool.isort] profile = "black" multi_line_output = 3 -line_length = 100 +line_length = 120 include_trailing_comma = true reverse_relative = true +[tool.ruff] +line-length = 120 + +[tool.ruff.lint] +# See https://docs.astral.sh/ruff/rules +extend-select = [ + "F", # pyflakes + "E", # pycodestyle errors + "W", # pycodestyle warnings + "C90", # mccabe + "I", # isort + "N", # pep8 naming + "D", # pydocstyle + "UP", # pyupgrade + "S", # bandit + "B", # bugbear + "T20", # print + "PT", # pytest-style + "RSE", #raise + "SIM", # simplify + "ERA", # eradicate commented out code + 
"NPY", # numpy checks + "RUF", # ruff rules +] +ignore = [] + +# See https://docs.astral.sh/ruff/settings/#per-file-ignores +[tool.ruff.lint.per-file-ignores] +# asserts in tests +"tests/**/*.py" = ["S101"] +"docs/source/conf.py" = ["D100", "ERA001"] + + +[tool.ruff.lint.pydocstyle] +convention = "pep257" + +[tool.ruff.lint.isort] +known-third-party = [] +known-first-party = ["torch_max_mem", "tests"] +relative-imports-order = "closest-to-furthest" + +# Pytest, see https://docs.pytest.org/en/stable/reference/customize.html#pyproject-toml [tool.pytest.ini_options] -markers = [ - "slow: marks tests as slow (deselect with '-m \"not slow\"')" -] \ No newline at end of file +addopts = "--strict-markers" +markers = ["slow: marks tests as slow (deselect with '-m \"not slow\"')"] diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 1c41724..0000000 --- a/setup.cfg +++ /dev/null @@ -1,124 +0,0 @@ -########################## -# Setup.py Configuration # -########################## -[metadata] -name = torch_max_mem -version = 0.1.4-dev -description = Maximize memory utilization with PyTorch. 
-long_description = file: README.md -long_description_content_type = text/markdown - -# URLs associated with the project -url = https://github.com/mberr/torch-max-mem -download_url = https://github.com/mberr/torch-max-mem/releases -project_urls = - Bug Tracker = https://github.com/mberr/torch-max-mem/issues - Source Code = https://github.com/mberr/torch-max-mem - -# Author information -author = Max Berrendorf -author_email = max.berrendorf@gmail.com -maintainer = Max Berrendorf -maintainer_email = max.berrendorf@gmail.com - -# License Information -license = MIT -license_file = LICENSE - -# Search tags -classifiers = - Development Status :: 4 - Beta - Environment :: Console - Intended Audience :: Developers - License :: OSI Approved :: MIT License - Operating System :: OS Independent - Framework :: Pytest - Framework :: tox - Framework :: Sphinx - Programming Language :: Python - Programming Language :: Python :: 3.8 - Programming Language :: Python :: 3.9 - Programming Language :: Python :: 3.10 - Programming Language :: Python :: 3.11 - Programming Language :: Python :: 3 :: Only - # TODO add your topics from the Trove controlled vocabulary (see https://pypi.org/classifiers) -keywords = - snekpack - cookiecutter - torch - -[options] -install_requires = - torch>=2.0 - typing_extensions - -# Random options -zip_safe = false -include_package_data = True -python_requires = >=3.8 - -# Where is my code -packages = find: -package_dir = - = src - -[options.packages.find] -where = src - -[options.extras_require] -tests = - numpy - pytest - coverage -formatting = - black - isort -docs = - # ... until RTD issues are fixed - sphinx<7 - sphinx-rtd-theme - sphinx-click - sphinx-autodoc-typehints - sphinx_automodapi - # To include LaTeX comments easily in your docs. 
- # If you uncomment this, don't forget to do the same in docs/conf.py - # texext - -###################### -# Doc8 Configuration # -# (doc8.ini) # -###################### -[doc8] -max-line-length = 120 - -########################## -# Coverage Configuration # -# (.coveragerc) # -########################## -[coverage:run] -branch = True -source = torch_max_mem -omit = - tests/* - docs/* - -[coverage:paths] -source = - src/torch_max_mem - .tox/*/lib/python*/site-packages/torch_max_mem - -[coverage:report] -show_missing = True -exclude_lines = - pragma: no cover - raise NotImplementedError - if __name__ == .__main__.: - def __str__ - def __repr__ - -########################## -# Darglint Configuration # -########################## -[darglint] -docstring_style = sphinx -strictness = short diff --git a/src/torch_max_mem/__init__.py b/src/torch_max_mem/__init__.py index 9be7873..e26d392 100644 --- a/src/torch_max_mem/__init__.py +++ b/src/torch_max_mem/__init__.py @@ -1,6 +1,6 @@ -# -*- coding: utf-8 -*- """Maximize memory utilization with PyTorch.""" + from .api import MemoryUtilizationMaximizer, maximize_memory_utilization __all__ = [ diff --git a/src/torch_max_mem/api.py b/src/torch_max_mem/api.py index 0962b59..985f0a9 100644 --- a/src/torch_max_mem/api.py +++ b/src/torch_max_mem/api.py @@ -1,5 +1,3 @@ -# -*- coding: utf-8 -*- - """ This module contains the public API. @@ -45,6 +43,7 @@ def knn(x, y, batch_size, k: int = 3): y = torch.rand(200, 100, device="cuda") knn(x, y, batch_size=x.shape[0]) """ + # cf. 
https://gist.github.com/mberr/c37a8068b38cabc98228db2cbe358043 from __future__ import annotations @@ -59,9 +58,7 @@ def knn(x, y, batch_size, k: int = 3): Iterable, Mapping, MutableMapping, - Optional, Sequence, - Tuple, TypeVar, ) @@ -98,9 +95,7 @@ def upgrade_to_sequence( when the (inferred) length of q and parameter_name do not match """ # normalize parameter name - parameter_names = ( - (parameter_name,) if isinstance(parameter_name, str) else tuple(parameter_name) - ) + parameter_names = (parameter_name,) if isinstance(parameter_name, str) else tuple(parameter_name) q = (q,) if isinstance(q, int) else tuple(q) q = q * len(parameter_names) if len(q) == 1 else q if len(q) != len(parameter_names): @@ -127,7 +122,7 @@ def determine_default_max_value( :raises ValueError: when the function does not have a parameter of the given name """ - if parameter_name not in signature.parameters.keys(): + if parameter_name not in signature.parameters: raise ValueError(f"{func} does not have a parameter {parameter_name}.") _parameter = signature.parameters[parameter_name] if _parameter.annotation != inspect.Parameter.empty and _parameter.annotation not in ( @@ -203,12 +198,13 @@ def determine_max_value( def iter_tensor_devices(*args: Any, **kwargs: Any) -> Iterable[torch.device]: """Iterate over tensors' devices (may contain duplicates).""" for obj in itertools.chain(args, kwargs.values()): - if torch.is_tensor(obj): - assert isinstance(obj, torch.Tensor) + if isinstance(obj, torch.Tensor): yield obj.device -def create_tensor_checker(safe_devices: Collection[str] | None = None) -> Callable[P, None]: +def create_tensor_checker( + safe_devices: Collection[str] | None = None, +) -> Callable[P, None]: """ Create a function that warns when tensors are on any device that is not considered safe. 
@@ -282,7 +278,7 @@ def maximize_memory_utilization_decorator( parameter_name: str | Sequence[str] = "batch_size", q: int | Sequence[int] = 32, safe_devices: Collection[str] | None = None, -) -> Callable[[Callable[P, R]], Callable[P, Tuple[R, tuple[int, ...]]]]: +) -> Callable[[Callable[P, R]], Callable[P, tuple[R, tuple[int, ...]]]]: """ Create decorators to create methods for memory utilization maximization. @@ -300,8 +296,8 @@ def maximize_memory_utilization_decorator( parameter_names, qs = upgrade_to_sequence(parameter_name, q) def decorator_maximize_memory_utilization( - func: Callable[P, R] - ) -> Callable[P, Tuple[R, tuple[int, ...]]]: + func: Callable[P, R], + ) -> Callable[P, tuple[R, tuple[int, ...]]]: """ Decorate a function to maximize memory utilization. @@ -319,9 +315,7 @@ def decorator_maximize_memory_utilization( } @functools.wraps(func) - def wrapper_maximize_memory_utilization( - *args: P.args, **kwargs: P.kwargs - ) -> Tuple[R, tuple[int, ...]]: + def wrapper_maximize_memory_utilization(*args: P.args, **kwargs: P.kwargs) -> tuple[R, tuple[int, ...]]: """ Wrap a function to maximize memory utilization by successive halving. 
@@ -359,15 +353,11 @@ def wrapper_maximize_memory_utilization( while i < len(max_values): while max_values[i] > 0: - p_kwargs = { - name: max_value for name, max_value in zip(parameter_names, max_values) - } + p_kwargs = {name: max_value for name, max_value in zip(parameter_names, max_values)} # note: changes to arguments apply to both, .args and .kwargs bound_arguments.arguments.update(p_kwargs) try: - return func(*bound_arguments.args, **bound_arguments.kwargs), tuple( - max_values - ) + return func(*bound_arguments.args, **bound_arguments.kwargs), tuple(max_values) except (torch.cuda.OutOfMemoryError, RuntimeError) as error: # raise errors unrelated to out-of-memory if not is_oom_error(error): @@ -392,12 +382,8 @@ def wrapper_maximize_memory_utilization( i += 1 # log memory summary for each CUDA device before raising memory error for device in {d for d in iter_tensor_devices(*args, **kwargs) if d.type == "cuda"}: - logger.debug( - f"Memory summary for {device=}:\n{torch.cuda.memory_summary(device=device)}" - ) - raise MemoryError( - f"Execution did not even succeed with {parameter_names} all equal to 1." - ) from last_error + logger.debug(f"Memory summary for {device=}:\n{torch.cuda.memory_summary(device=device)}") + raise MemoryError(f"Execution did not even succeed with {parameter_names} all equal to 1.") from last_error return wrapper_maximize_memory_utilization @@ -456,7 +442,7 @@ def __init__( parameter_name: str | Sequence[str] = "batch_size", q: int | Sequence[int] = 32, safe_devices: Collection[str] | None = None, - hasher: Optional[Callable[[Mapping[str, Any]], int]] = None, + hasher: Callable[[Mapping[str, Any]], int] | None = None, keys: Collection[str] | str | None = None, ) -> None: """ diff --git a/src/torch_max_mem/version.py b/src/torch_max_mem/version.py index 8b3958d..08bc4c9 100644 --- a/src/torch_max_mem/version.py +++ b/src/torch_max_mem/version.py @@ -1,12 +1,10 @@ -# -*- coding: utf-8 -*- - """Version information for :mod:`torch_max_mem`. 
Run with ``python -m torch_max_mem.version`` """ import os -from subprocess import CalledProcessError, check_output # noqa: S404 +from subprocess import CalledProcessError, check_output __all__ = [ "VERSION", @@ -21,8 +19,8 @@ def get_git_hash() -> str: """Get the :mod:`torch_max_mem` git hash.""" with open(os.devnull, "w") as devnull: try: - ret = check_output( # noqa: S603,S607 - ["git", "rev-parse", "HEAD"], + ret = check_output( # noqa: S603 + ["git", "rev-parse", "HEAD"], # noqa: S607 cwd=os.path.dirname(__file__), stderr=devnull, ) diff --git a/tests/__init__.py b/tests/__init__.py index 98adbd8..b242c2c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -1,3 +1,2 @@ -# -*- coding: utf-8 -*- """Tests for :mod:`torch_max_mem`.""" diff --git a/tests/test_decorator.py b/tests/test_decorator.py index 6ae3bc5..62602b3 100644 --- a/tests/test_decorator.py +++ b/tests/test_decorator.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """Tests.""" @@ -101,7 +100,7 @@ def test_optimization(): def func(batch_size: int = 8): """Test function.""" if batch_size > 2: - raise torch.cuda.OutOfMemoryError() + raise torch.cuda.OutOfMemoryError return batch_size assert func() == 2 @@ -114,13 +113,13 @@ def test_optimization_multi_level(): def func(batch_size: int = 8, slice_size: int = 16): """Test function.""" if batch_size > 1 or slice_size > 8: - raise torch.cuda.OutOfMemoryError() + raise torch.cuda.OutOfMemoryError return batch_size, slice_size assert func() == (1, 8) -@pytest.mark.parametrize("x,q", [(15, 4), (3, 4)]) +@pytest.mark.parametrize(("x", "q"), [(15, 4), (3, 4)]) def test_floor_to_nearest_multiple_of(x: int, q: int) -> None: """Test floor_to_nearest_multiple_of.""" r = floor_to_nearest_multiple_of(x=x, q=q) @@ -135,7 +134,7 @@ def test_floor_to_nearest_multiple_of(x: int, q: int) -> None: @pytest.mark.parametrize( - "error,exp", + ("error", "exp"), [ # base cases (NameError(), False), @@ -161,7 +160,7 @@ def test_oom_error_detection(error: BaseException, exp: 
bool) -> None: assert is_oom_error(error) is exp -@pytest.mark.slow +@pytest.mark.slow() def test_large_on_mps(): """Test memory optimization on a large input.""" import torch.backends.mps diff --git a/tox.ini b/tox.ini index efe987b..1bea996 100644 --- a/tox.ini +++ b/tox.ini @@ -13,10 +13,10 @@ envlist = # always keep coverage-clean first # coverage-clean # code linters/stylers - lint + format manifest pyroma - flake8 + lint mypy # documentation linters/checkers doc8 @@ -43,15 +43,14 @@ deps = coverage skip_install = true commands = coverage erase -[testenv:lint] +[testenv:format] deps = - black - isort + ruff skip_install = true commands = - black src/ tests/ - isort src/ tests/ -description = Run linters. + ruff check --fix + ruff format +description = Format the code in a deterministic way using ruff [testenv:manifest] deps = check-manifest @@ -59,23 +58,15 @@ skip_install = true commands = check-manifest description = Check that the MANIFEST.in is written properly and give feedback on how to fix it. -[testenv:flake8] +[testenv:lint] skip_install = true deps = - darglint - flake8 - flake8-black - # flake8-bandit # incompatible with latest flake8 - flake8-bugbear - flake8-colors - flake8-docstrings - flake8-isort - flake8-print - pep8-naming - pydocstyle + ruff + darglint2 commands = - flake8 src/ tests/ -description = Run the flake8 tool with several plugins (bandit, docstrings, import order, pep8 naming). See https://cthoyt.com/2020/04/25/how-to-code-with-me-flake8.html for more information. + ruff check + darglint2 --strictness short --docstring-style sphinx -v 2 src/ tests/ +description = Check code quality using ruff and other tools. See https://github.com/akaihola/darglint2 [testenv:pyroma] deps =