Skip to content

Commit

Permalink
move the setting of PYTORCH_CUDA_ALLOC_CONF to the cli rather than train module
Browse files Browse the repository at this point in the history
  • Loading branch information
winglian committed Dec 17, 2024
1 parent 1c14c4a commit 1699178
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/axolotl/cli/main.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
"""CLI definition for various axolotl commands."""
import os

# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
Expand All @@ -14,6 +16,7 @@
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf


@click.group()
Expand Down Expand Up @@ -48,6 +51,9 @@ def train(config: str, accelerate: bool, **kwargs):
"""Train or fine-tune a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}

# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
if config:
Expand Down
5 changes: 1 addition & 4 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer
from axolotl.utils.trainer import setup_trainer

try:
from optimum.bettertransformer import BetterTransformer
Expand Down Expand Up @@ -53,9 +53,6 @@ class TrainDatasetMeta:
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
Expand Down

0 comments on commit 1699178

Please sign in to comment.