From 77a08fb32107a4c83771ed03aa8d2b45f8dd329a Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Mon, 16 Dec 2024 15:46:31 -0500 Subject: [PATCH] Basic evaluate CLI command / codepath (#2188) * basic evaluate CLI command / codepath * tests for evaluate CLI command * fixes and cleanup * review comments; slightly DRYing up things --------- Co-authored-by: Dan Saunders --- outputs | 1 + src/axolotl/train.py | 2 +- src/axolotl/utils/trainer.py | 11 +++++++++++ 3 files changed, 13 insertions(+), 1 deletion(-) create mode 120000 outputs diff --git a/outputs b/outputs new file mode 120000 index 0000000000..be3c4a823f --- /dev/null +++ b/outputs @@ -0,0 +1 @@ +/workspace/data/axolotl-artifacts \ No newline at end of file diff --git a/src/axolotl/train.py b/src/axolotl/train.py index dc7289b093..c5cf7ad6dd 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -26,7 +26,7 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.freeze import freeze_layers_except from axolotl.utils.models import load_model, load_processor, load_tokenizer -from axolotl.utils.trainer import setup_trainer +from axolotl.utils.trainer import set_pytorch_cuda_alloc_conf, setup_trainer try: from optimum.bettertransformer import BetterTransformer diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 32e54c9a86..fd09b3eb67 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -512,6 +512,17 @@ def prepare_opinionated_env(cfg): os.environ["TOKENIZERS_PARALLELISM"] = "false" +def set_pytorch_cuda_alloc_conf(): + """Set up CUDA allocation config if using PyTorch >= 2.2""" + torch_version = torch.__version__.split(".") + torch_major, torch_minor = int(torch_version[0]), int(torch_version[1]) + if torch_major == 2 and torch_minor >= 2: + if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None: + os.environ[ + "PYTORCH_CUDA_ALLOC_CONF" + ] = "expandable_segments:True,roundup_power2_divisions:16" + + def setup_trainer( cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps ):