From 3b13b54a93a95abc3ad1c7322319223450323893 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 9 Apr 2024 16:21:00 -0400 Subject: [PATCH] Fix circular imports --- src/axolotl/core/trainer_builder.py | 5 +---- src/axolotl/utils/__init__.py | 8 ++++++++ src/axolotl/utils/callbacks/__init__.py | 2 +- 3 files changed, 10 insertions(+), 5 deletions(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 8bdcfba268..35318b836d 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -36,6 +36,7 @@ from axolotl.loraplus import create_loraplus_optimizer from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler +from axolotl.utils import is_mlflow_available from axolotl.utils.callbacks import ( EvalFirstStepCallback, GPUStatsCallback, @@ -71,10 +72,6 @@ LOG = logging.getLogger("axolotl.core.trainer_builder") -def is_mlflow_available(): - return importlib.util.find_spec("mlflow") is not None - - def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): if isinstance(tag_names, str): tag_names = [tag_names] diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index e69de29bb2..99dec79f1b 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -0,0 +1,8 @@ +""" +Basic utils for Axolotl +""" +import importlib + + +def is_mlflow_available(): + return importlib.util.find_spec("mlflow") is not None diff --git a/src/axolotl/utils/callbacks/__init__.py b/src/axolotl/utils/callbacks/__init__.py index 890883512e..d907e3f6a3 100644 --- a/src/axolotl/utils/callbacks/__init__.py +++ b/src/axolotl/utils/callbacks/__init__.py @@ -27,7 +27,7 @@ ) from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy -from axolotl.core.trainer_builder import is_mlflow_available +from axolotl.utils import is_mlflow_available from axolotl.utils.bench import log_gpu_memory_usage from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig from axolotl.utils.distributed import (