From d61c53d2c3961c24a87c564859a007aa816c5ead Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Mon, 11 Mar 2024 13:57:21 -0700
Subject: [PATCH] Add code import to train/eval scripts (#1002)

---
Review amendments folded into this patch:
- import_file: removed the docstring entry for a non-existent ``name``
  argument, documented the raised exceptions, and replaced the ``assert``
  checks with an explicit ImportError (asserts are stripped under
  ``python -O``).
- tests: always remove TEST_ENVIRON_REGISTRY_KEY (the failing-import test used
  to leak it) and cover the new ImportError path.
- The ``index`` blob ids of the two amended new files were dropped because
  they would no longer match the content.

 llmfoundry/registry.py | 43 +++++++++++++++++++++++++++++++++++++++++++
 scripts/eval/eval.py   | 11 +++++++++++
 scripts/train/train.py | 11 +++++++++++
 tests/test_registry.py | 59 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 124 insertions(+)
 create mode 100644 llmfoundry/registry.py
 create mode 100644 tests/test_registry.py

diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py
new file mode 100644
--- /dev/null
+++ b/llmfoundry/registry.py
@@ -0,0 +1,43 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import importlib.util
+import os
+from pathlib import Path
+from types import ModuleType
+from typing import Union
+
+__all__ = ['import_file']
+
+
+def import_file(loc: Union[str, Path]) -> ModuleType:
+    """Import module from a file. Used to run arbitrary python code.
+
+    Args:
+        loc (str / Path): Path to the file.
+
+    Returns:
+        ModuleType: The module object.
+
+    Raises:
+        FileNotFoundError: If ``loc`` does not exist.
+        ImportError: If no import spec or loader can be created for ``loc``
+            (e.g. it is not a python source file).
+        RuntimeError: If executing the file raises an exception.
+    """
+    if not os.path.exists(loc):
+        raise FileNotFoundError(f'File {loc} does not exist.')
+
+    spec = importlib.util.spec_from_file_location('python_code', str(loc))
+
+    # Raise explicitly rather than ``assert``: asserts are stripped under -O.
+    if spec is None or spec.loader is None:
+        raise ImportError(f'Could not create an import spec for {loc}.')
+
+    module = importlib.util.module_from_spec(spec)
+
+    try:
+        spec.loader.exec_module(module)
+    except Exception as e:
+        raise RuntimeError(f'Error executing {loc}') from e
+    return module
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index e36e08575b..7997eef46c 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -23,6 +23,7 @@
 install()
 
 from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
+from llmfoundry.registry import import_file
 from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
                                        build_evaluators, build_logger,
                                        build_tokenizer)
@@ -188,6 +189,16 @@ def evaluate_model(
 
 
 def main(cfg: DictConfig) -> Tuple[List[Trainer], pd.DataFrame]:
+    # Run user provided code if specified
+    code_paths = pop_config(cfg,
+                            'code_paths',
+                            must_exist=False,
+                            default_value=[],
+                            convert=True)
+    # Import any user provided code
+    for code_path in code_paths:
+        import_file(code_path)
+
     om.resolve(cfg)
 
     # Create copy of config for logging
diff --git a/scripts/train/train.py b/scripts/train/train.py
index aa09157bc2..92c6afa128 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -30,6 +30,7 @@
 from llmfoundry import COMPOSER_MODEL_REGISTRY
 from llmfoundry.callbacks import AsyncEval
 from llmfoundry.data.dataloader import build_dataloader
+from llmfoundry.registry import import_file
 from llmfoundry.utils.builders import (add_metrics_to_eval_loaders,
                                        build_algorithm, build_callback,
                                        build_evaluators, build_logger,
@@ -158,6 +159,16 @@ def main(cfg: DictConfig) -> Trainer:
         'torch.distributed.*_base is a private function and will be deprecated.*'
     )
 
+    # Run user provided code if specified
+    code_paths = pop_config(cfg,
+                            'code_paths',
+                            must_exist=False,
+                            default_value=[],
+                            convert=True)
+    # Import any user provided code
+    for code_path in code_paths:
+        import_file(code_path)
+
     # Check for incompatibilities between the model and data loaders
     validate_config(cfg)
 
diff --git a/tests/test_registry.py b/tests/test_registry.py
new file mode 100644
--- /dev/null
+++ b/tests/test_registry.py
@@ -0,0 +1,59 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import os
+import pathlib
+
+import pytest
+
+from llmfoundry.registry import import_file
+
+
+def test_registry_init_code(tmp_path: pathlib.Path):
+    register_code = """
+import os
+os.environ['TEST_ENVIRON_REGISTRY_KEY'] = 'test'
+"""
+
+    with open(tmp_path / 'init_code.py', 'w') as _f:
+        _f.write(register_code)
+
+    try:
+        import_file(tmp_path / 'init_code.py')
+
+        assert os.environ['TEST_ENVIRON_REGISTRY_KEY'] == 'test'
+    finally:
+        # Clean up even on failure so the variable cannot leak.
+        os.environ.pop('TEST_ENVIRON_REGISTRY_KEY', None)
+
+
+def test_registry_init_code_fails(tmp_path: pathlib.Path):
+    register_code = """
+import os
+os.environ['TEST_ENVIRON_REGISTRY_KEY'] = 'test'
+asdf
+"""
+
+    with open(tmp_path / 'init_code.py', 'w') as _f:
+        _f.write(register_code)
+
+    try:
+        with pytest.raises(RuntimeError,
+                           match='Error executing .*init_code.py'):
+            import_file(tmp_path / 'init_code.py')
+    finally:
+        # The file sets the variable before it fails; do not leak it.
+        os.environ.pop('TEST_ENVIRON_REGISTRY_KEY', None)
+
+
+def test_registry_init_code_dne(tmp_path: pathlib.Path):
+    with pytest.raises(FileNotFoundError, match='File .* does not exist'):
+        import_file(tmp_path / 'init_code.py')
+
+
+def test_registry_init_code_not_python(tmp_path: pathlib.Path):
+    with open(tmp_path / 'init_code.txt', 'w') as _f:
+        _f.write('x = 1')
+
+    with pytest.raises(ImportError, match='Could not create an import spec'):
+        import_file(tmp_path / 'init_code.txt')