From 67dcab9fabd58d04e46f37cec2afb893939721b0 Mon Sep 17 00:00:00 2001 From: Daniel King <43149077+dakinggg@users.noreply.github.com> Date: Sun, 24 Mar 2024 08:25:51 -0700 Subject: [PATCH] LLM Foundry CLI (just registry) (#1043) --- REGISTRY.md | 8 +++- llmfoundry/cli/__init__.py | 2 + llmfoundry/cli/cli.py | 12 ++++++ llmfoundry/cli/registry_cli.py | 72 ++++++++++++++++++++++++++++++++++ llmfoundry/registry.py | 54 ++++++++++++++++--------- setup.py | 4 ++ tests/cli/test_registry_cli.py | 22 +++++++++++ 7 files changed, 154 insertions(+), 20 deletions(-) create mode 100644 llmfoundry/cli/__init__.py create mode 100644 llmfoundry/cli/cli.py create mode 100644 llmfoundry/cli/registry_cli.py create mode 100644 tests/cli/test_registry_cli.py diff --git a/REGISTRY.md b/REGISTRY.md index b0ac6f9d81..ebb70d41de 100644 --- a/REGISTRY.md +++ b/REGISTRY.md @@ -81,4 +81,10 @@ code_paths: ## Discovering registrable components -Coming soon +To help find and understand registrable components, you can use the `llmfoundry registry` cli command. + +We provide two commands: +- `llmfoundry registry get [--group]`: List all registries, and their components, optionally specifying a specific registry. Example usage: `llmfoundry registry get --group loggers` or `llmfoundry registry get` +- `llmfoundry registry find `: Get information about a specific registered component. Example usage: `llmfoundry registry find loggers wandb` + +Use `--help` on any of these commands for more information. diff --git a/llmfoundry/cli/__init__.py b/llmfoundry/cli/__init__.py new file mode 100644 index 0000000000..80950cb7b4 --- /dev/null +++ b/llmfoundry/cli/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 diff --git a/llmfoundry/cli/cli.py b/llmfoundry/cli/cli.py new file mode 100644 index 0000000000..25c1a6d230 --- /dev/null +++ b/llmfoundry/cli/cli.py @@ -0,0 +1,12 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import typer + +from llmfoundry.cli import registry_cli + +app = typer.Typer(pretty_exceptions_show_locals=False) +app.add_typer(registry_cli.app, name='registry') + +if __name__ == '__main__': + app() diff --git a/llmfoundry/cli/registry_cli.py b/llmfoundry/cli/registry_cli.py new file mode 100644 index 0000000000..03046c2f07 --- /dev/null +++ b/llmfoundry/cli/registry_cli.py @@ -0,0 +1,72 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import Optional + +import typer +from rich.console import Console +from rich.table import Table + +from llmfoundry import registry +from llmfoundry.utils.registry_utils import TypedRegistry + +console = Console() +app = typer.Typer(pretty_exceptions_show_locals=False) + + +def _get_registries(group: Optional[str] = None) -> list[TypedRegistry]: + registry_attr_names = dir(registry) + registry_attrs = [getattr(registry, name) for name in registry_attr_names] + available_registries = [ + r for r in registry_attrs if isinstance(r, TypedRegistry) + ] + + if group is not None and group not in registry_attr_names: + console.print( + f'Group {group} not found in registry. Run `llmfoundry registry get` to see available groups.' + ) + return [] + + if group is not None: + available_registries = [getattr(registry, group)] + + return available_registries + + +@app.command() +def get(group: Optional[str] = None): + """Get the available registries. + + Args: + group (Optional[str], optional): The group to get. If not provided, all groups will be shown. Defaults to None. + """ + available_registries = _get_registries(group) + + table = Table('Registry', 'Description', 'Options', show_lines=True) + for r in available_registries: + table.add_row('.'.join(r.namespace), r.description, + ', '.join(r.get_all())) + + console.print(table) + + +@app.command() +def find(group: str, name: str): + """Find a registry entry by name. + + Args: + group (str): The group to search. + name (str): The name of the entry to search for. + """ + available_registries = _get_registries(group) + if not available_registries: + return + + r = available_registries[0] + find_output = r.find(name) + + table = Table('Module', 'File', 'Line number', 'Docstring') + table.add_row(find_output['module'], find_output['file'], + str(find_output['line_no']), find_output['docstring']) + + console.print(table) diff --git a/llmfoundry/registry.py b/llmfoundry/registry.py index 86dc3513b6..6e664ca9c1 100644 --- a/llmfoundry/registry.py +++ b/llmfoundry/registry.py @@ -11,54 +11,70 @@ from llmfoundry.interfaces import CallbackWithConfig from llmfoundry.utils.registry_utils import create_registry -_loggers_description = """The loggers registry is used to register classes that implement the LoggerDestination interface. -These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers -will be constructed by directly passing along the specified kwargs to the constructor.""" +_loggers_description = ( + 'The loggers registry is used to register classes that implement the LoggerDestination interface. ' + + + 'These classes are used to log data from the training loop, and will be passed to the loggers arg of the Trainer. The loggers ' + + + 'will be constructed by directly passing along the specified kwargs to the constructor.' +) loggers = create_registry('llmfoundry', 'loggers', generic_type=Type[LoggerDestination], entry_points=True, description=_loggers_description) -_callbacks_description = """The callbacks registry is used to register classes that implement the Callback interface. -These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer. -The callbacks will be constructed by directly passing along the specified kwargs to the constructor.""" +_callbacks_description = ( + 'The callbacks registry is used to register classes that implement the Callback interface. ' + + + 'These classes are used to interact with the Composer event system, and will be passed to the callbacks arg of the Trainer. ' + + + 'The callbacks will be constructed by directly passing along the specified kwargs to the constructor.' +) callbacks = create_registry('llmfoundry', 'callbacks', generic_type=Type[Callback], entry_points=True, description=_callbacks_description) -_callbacks_with_config_description = """The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. -These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor.""" +_callbacks_with_config_description = ( + 'The callbacks_with_config registry is used to register classes that implement the CallbackWithConfig interface. ' + + + 'These are the same as the callbacks registry, except that they additionally take the full training config as an argument to their constructor.' +) callbacks_with_config = create_registry( - 'llm_foundry', - 'callbacks_with_config', + 'llm_foundry.callbacks_with_config', generic_type=Type[CallbackWithConfig], entry_points=True, description=_callbacks_with_config_description) -_optimizers_description = """The optimizers registry is used to register classes that implement the Optimizer interface. -The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the -specified kwargs to the constructor, along with the model parameters.""" +_optimizers_description = ( + 'The optimizers registry is used to register classes that implement the Optimizer interface. ' + + + 'The optimizer will be passed to the optimizers arg of the Trainer. The optimizer will be constructed by directly passing along the ' + + 'specified kwargs to the constructor, along with the model parameters.') optimizers = create_registry('llmfoundry', 'optimizers', generic_type=Type[Optimizer], entry_points=True, description=_optimizers_description) -_algorithms_description = """The algorithms registry is used to register classes that implement the Algorithm interface. -The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the -specified kwargs to the constructor.""" +_algorithms_description = ( + 'The algorithms registry is used to register classes that implement the Algorithm interface. ' + + + 'The algorithm will be passed to the algorithms arg of the Trainer. The algorithm will be constructed by directly passing along the ' + + 'specified kwargs to the constructor.') algorithms = create_registry('llmfoundry', 'algorithms', generic_type=Type[Algorithm], entry_points=True, description=_algorithms_description) -_schedulers_description = """The schedulers registry is used to register classes that implement the ComposerScheduler interface. -The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the -specified kwargs to the constructor.""" +_schedulers_description = ( + 'The schedulers registry is used to register classes that implement the ComposerScheduler interface. ' + + + 'The scheduler will be passed to the schedulers arg of the Trainer. The scheduler will be constructed by directly passing along the ' + + 'specified kwargs to the constructor.') schedulers = create_registry('llmfoundry', 'schedulers', generic_type=Type[ComposerScheduler], diff --git a/setup.py b/setup.py index 37ab272f9d..fd89cc30a5 100644 --- a/setup.py +++ b/setup.py @@ -74,6 +74,7 @@ 'beautifulsoup4>=4.12.2,<5', # required for model download utils 'tenacity>=8.2.3,<9', 'catalogue>=2,<3', + 'typer[all]<1', ] extra_deps = {} @@ -145,4 +146,7 @@ install_requires=install_requires, extras_require=extra_deps, python_requires='>=3.9', + entry_points={ + 'console_scripts': ['llmfoundry = llmfoundry.cli.cli:app'], + }, ) diff --git a/tests/cli/test_registry_cli.py b/tests/cli/test_registry_cli.py new file mode 100644 index 0000000000..2c61118baf --- /dev/null +++ b/tests/cli/test_registry_cli.py @@ -0,0 +1,22 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from llmfoundry import registry +from llmfoundry.cli.registry_cli import _get_registries +from llmfoundry.utils.registry_utils import TypedRegistry + + +def test_get_registries(): + available_registries = _get_registries() + expected_registries = [ + getattr(registry, r) + for r in dir(registry) + if isinstance(getattr(registry, r), TypedRegistry) + ] + assert available_registries == expected_registries + + +def test_get_registries_group(): + available_registries = _get_registries('loggers') + assert len(available_registries) == 1 + assert available_registries[0].namespace == ('llmfoundry', 'loggers')