Skip to content

Commit

Permalink
move set_pytorch_cuda_alloc_conf to a different module to have fewer …
Browse files Browse the repository at this point in the history
…loaded dependencies for the CLI
  • Loading branch information
winglian committed Dec 17, 2024
1 parent 1699178 commit 0e0c6d7
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 16 deletions.
4 changes: 1 addition & 3 deletions src/axolotl/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
"""CLI definition for various axolotl commands."""
import os

# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
Expand All @@ -15,8 +13,8 @@
fetch_from_github,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf


@click.group()
Expand Down
3 changes: 2 additions & 1 deletion src/axolotl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,10 @@
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
from axolotl.utils.trainer import setup_trainer

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
Expand Down
11 changes: 10 additions & 1 deletion src/axolotl/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""

import importlib.util
import os
import re

import torch
Expand Down Expand Up @@ -33,4 +34,12 @@ def get_pytorch_version() -> tuple[int, int, int]:
return major, minor, patch


# pylint: enable=duplicate-code
def set_pytorch_cuda_alloc_conf():
"""Set up CUDA allocation config if using PyTorch >= 2.2"""
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"
11 changes: 0 additions & 11 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,17 +512,6 @@ def prepare_opinionated_env(cfg):
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def set_pytorch_cuda_alloc_conf():
"""Set up CUDA allocation config if using PyTorch >= 2.2"""
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"


def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
Expand Down

0 comments on commit 0e0c6d7

Please sign in to comment.