diff --git a/scripts/train/train.py b/scripts/train/train.py index c07a1898f8..c9e2d67bf4 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -21,25 +21,16 @@ from composer.utils import dist, get_device, reproducibility from omegaconf import DictConfig from omegaconf import OmegaConf as om -from rich.traceback import install +from llmfoundry.callbacks import AsyncEval +from llmfoundry.data.dataloader import build_dataloader from llmfoundry.eval.metrics.nlp import InContextLearningMetric +from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.utils import ( find_mosaicml_logger, log_train_analytics, maybe_create_mosaicml_logger, ) -from llmfoundry.utils.exceptions import ( - BaseContextualError, - EvalDataLoaderLocation, - TrainDataLoaderLocation, -) - -install() - -from llmfoundry.callbacks import AsyncEval -from llmfoundry.data.dataloader import build_dataloader -from llmfoundry.layers_registry import ffns_with_megablocks from llmfoundry.utils.builders import ( add_metrics_to_eval_loaders, build_algorithm, @@ -61,6 +52,11 @@ process_init_device, update_batch_size_info, ) +from llmfoundry.utils.exceptions import ( + BaseContextualError, + EvalDataLoaderLocation, + TrainDataLoaderLocation, +) from llmfoundry.utils.registry_utils import import_file log = logging.getLogger(__name__)