diff --git a/llmfoundry/cli/cli.py b/llmfoundry/cli/cli.py index 8e86e76467..f70a3ffa4e 100644 --- a/llmfoundry/cli/cli.py +++ b/llmfoundry/cli/cli.py @@ -6,7 +6,7 @@ import typer from llmfoundry.cli import registry_cli -from llmfoundry.train import train_from_yaml +from llmfoundry.command_utils import train_from_yaml app = typer.Typer(pretty_exceptions_show_locals=False) app.add_typer(registry_cli.app, name='registry') diff --git a/llmfoundry/train/__init__.py b/llmfoundry/command_utils/__init__.py similarity index 86% rename from llmfoundry/train/__init__.py rename to llmfoundry/command_utils/__init__.py index 8a4c2749db..cd3d699f47 100644 --- a/llmfoundry/train/__init__.py +++ b/llmfoundry/command_utils/__init__.py @@ -1,6 +1,6 @@ # Copyright 2024 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 -from llmfoundry.train.train import ( +from llmfoundry.command_utils.train import ( TRAIN_CONFIG_KEYS, TrainConfig, train, diff --git a/llmfoundry/train/train.py b/llmfoundry/command_utils/train.py similarity index 100% rename from llmfoundry/train/train.py rename to llmfoundry/command_utils/train.py diff --git a/scripts/train/train.py b/scripts/train/train.py index 3c8973048b..728010d13a 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -2,7 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import sys -from llmfoundry.train import train_from_yaml +from llmfoundry.command_utils import train_from_yaml if __name__ == '__main__': yaml_path, args_list = sys.argv[1], sys.argv[2:] diff --git a/tests/a_scripts/train/test_train.py b/tests/a_scripts/train/test_train.py index a49f1ac07a..1f724a6070 100644 --- a/tests/a_scripts/train/test_train.py +++ b/tests/a_scripts/train/test_train.py @@ -11,8 +11,8 @@ from omegaconf import DictConfig, ListConfig from omegaconf import OmegaConf as om -from llmfoundry.train import TrainConfig # noqa: E402 -from llmfoundry.train import TRAIN_CONFIG_KEYS, train, validate_config +from llmfoundry.command_utils import TrainConfig # noqa: E402 +from llmfoundry.command_utils import TRAIN_CONFIG_KEYS, train, validate_config from llmfoundry.utils.config_utils import ( make_dataclass_and_log_config, update_batch_size_info, diff --git a/tests/a_scripts/train/test_train_inputs.py b/tests/a_scripts/train/test_train_inputs.py index 328a06a69e..73540afe2f 100644 --- a/tests/a_scripts/train/test_train_inputs.py +++ b/tests/a_scripts/train/test_train_inputs.py @@ -9,7 +9,7 @@ from omegaconf import DictConfig from omegaconf import OmegaConf as om -from llmfoundry.train import train # noqa: E402 +from llmfoundry.command_utils import train def make_fake_index_file(path: str) -> None: