-
Notifications
You must be signed in to change notification settings - Fork 538
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
LLM Foundry CLI (just registry) (#1043)
- Loading branch information
Showing
7 changed files
with
154 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Copyright 2024 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Copyright 2024 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
import typer | ||
|
||
from llmfoundry.cli import registry_cli | ||
|
||
app = typer.Typer(pretty_exceptions_show_locals=False) | ||
app.add_typer(registry_cli.app, name='registry') | ||
|
||
if __name__ == '__main__': | ||
app() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
# Copyright 2024 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from typing import Optional | ||
|
||
import typer | ||
from rich.console import Console | ||
from rich.table import Table | ||
|
||
from llmfoundry import registry | ||
from llmfoundry.utils.registry_utils import TypedRegistry | ||
|
||
console = Console() | ||
app = typer.Typer(pretty_exceptions_show_locals=False) | ||
|
||
|
||
def _get_registries(group: Optional[str] = None) -> list[TypedRegistry]: | ||
registry_attr_names = dir(registry) | ||
registry_attrs = [getattr(registry, name) for name in registry_attr_names] | ||
available_registries = [ | ||
r for r in registry_attrs if isinstance(r, TypedRegistry) | ||
] | ||
|
||
if group is not None and group not in registry_attr_names: | ||
console.print( | ||
f'Group {group} not found in registry. Run `llmfoundry registry get` to see available groups.' | ||
) | ||
return [] | ||
|
||
if group is not None: | ||
available_registries = [getattr(registry, group)] | ||
|
||
return available_registries | ||
|
||
|
||
@app.command() | ||
def get(group: Optional[str] = None): | ||
"""Get the available registries. | ||
Args: | ||
group (Optional[str], optional): The group to get. If not provided, all groups will be shown. Defaults to None. | ||
""" | ||
available_registries = _get_registries(group) | ||
|
||
table = Table('Registry', 'Description', 'Options', show_lines=True) | ||
for r in available_registries: | ||
table.add_row('.'.join(r.namespace), r.description, | ||
', '.join(r.get_all())) | ||
|
||
console.print(table) | ||
|
||
|
||
@app.command() | ||
def find(group: str, name: str): | ||
"""Find a registry entry by name. | ||
Args: | ||
group (str): The group to search. | ||
name (str): The name of the entry to search for. | ||
""" | ||
available_registries = _get_registries(group) | ||
if not available_registries: | ||
return | ||
|
||
r = available_registries[0] | ||
find_output = r.find(name) | ||
|
||
table = Table('Module', 'File', 'Line number', 'Docstring') | ||
table.add_row(find_output['module'], find_output['file'], | ||
str(find_output['line_no']), find_output['docstring']) | ||
|
||
console.print(table) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
# Copyright 2024 MosaicML LLM Foundry authors | ||
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
from llmfoundry import registry | ||
from llmfoundry.cli.registry_cli import _get_registries | ||
from llmfoundry.utils.registry_utils import TypedRegistry | ||
|
||
|
||
def test_get_registries(): | ||
available_registries = _get_registries() | ||
expected_registries = [ | ||
getattr(registry, r) | ||
for r in dir(registry) | ||
if isinstance(getattr(registry, r), TypedRegistry) | ||
] | ||
assert available_registries == expected_registries | ||
|
||
|
||
def test_get_registries_group(): | ||
available_registries = _get_registries('loggers') | ||
assert len(available_registries) == 1 | ||
assert available_registries[0].namespace == ('llmfoundry', 'loggers') |