Skip to content

Commit

Permalink
review comments; slightly DRYing up things
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Saunders committed Dec 16, 2024
1 parent 2fb3ed5 commit edfee9e
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 21 deletions.
1 change: 0 additions & 1 deletion src/axolotl/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,6 @@ def load_datasets(
tokenizer,
processor=processor,
)
print(train_dataset, eval_dataset, total_num_steps)

if (
cli_args.debug
Expand Down
12 changes: 3 additions & 9 deletions src/axolotl/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from axolotl.train import TrainDatasetMeta
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer

project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
Expand Down Expand Up @@ -79,14 +79,8 @@ def evaluate(
- Dictionary of evaluation metrics
"""
# pylint: disable=duplicate-code
# Set up CUDA allocation config if using PyTorch >= 2.2
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

# Load tokenizer
LOG.debug(
Expand Down
19 changes: 8 additions & 11 deletions src/axolotl/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from axolotl.utils.dict import DictDefault
from axolotl.utils.freeze import freeze_layers_except
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer

try:
from optimum.bettertransformer import BetterTransformer
Expand Down Expand Up @@ -53,25 +53,22 @@ class TrainDatasetMeta:
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# enable expandable segments for cuda allocation to improve VRAM usage
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"

# load the tokenizer first
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()

# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)

# Load processor for multimodal models if needed
processor = None
if cfg.is_multimodal:
processor = load_processor(cfg, tokenizer)

# Get datasets
train_dataset = dataset_meta.train_dataset
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
Expand Down
11 changes: 11 additions & 0 deletions src/axolotl/utils/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg):
os.environ["TOKENIZERS_PARALLELISM"] = "false"


def set_pytorch_cuda_alloc_conf():
"""Set up CUDA allocation config if using PyTorch >= 2.2"""
torch_version = torch.__version__.split(".")
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"


def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
Expand Down

0 comments on commit edfee9e

Please sign in to comment.