merged
milocress committed Mar 26, 2024
2 parents ae663cd + e590acf commit 4fd6269
Showing 51 changed files with 981 additions and 552 deletions.
5 changes: 0 additions & 5 deletions .github/workflows/pr-gpu.yaml
@@ -24,11 +24,6 @@ jobs:
markers: "gpu"
pip_deps: "[all]"
pytest_command: "coverage run -m pytest"
- name: "gpu-2.2.1-flash2"
container: mosaicml/llm-foundry:2.2.1_cu121_flash2-latest
markers: "gpu"
pip_deps: "[all-flash2]"
pytest_command: "coverage run -m pytest"
name: ${{ matrix.name }}
if: github.repository_owner == 'mosaicml'
with:
91 changes: 89 additions & 2 deletions README.md
@@ -255,9 +255,96 @@ export HUGGING_FACE_HUB_TOKEN=your-auth-token

and uncomment the line containing `--hf_repo_for_upload ...` in the above call to `inference/convert_composer_to_hf.py`.

# :construction: UNDER CONSTRUCTION: Registry
# Registry

You can use the registry to customize your workflows without forking the library. Some components of LLM Foundry are registrable, such as models, loggers, and callbacks. This means that you can register new options for these components, and then use them in your yaml config.

## Discovering registrable components
To help find and understand registrable components, you can use the `llmfoundry registry` cli command.

We provide two commands currently:
- `llmfoundry registry get [--group]`: List all registries, and their components, optionally specifying a specific registry. Example usage: `llmfoundry registry get --group loggers` or `llmfoundry registry get`
- `llmfoundry registry find <group> <name>`: Get information about a specific registered component. Example usage: `llmfoundry registry find loggers wandb`

Use `--help` on any of these commands for more information.

## How to register

There are a few ways to register a new component:

### Python entrypoints

You can specify registered components via a Python entrypoint if you are building your own package with registered components.

For example, the following would register the `MyLogger` class, under the key `my_logger`, in the `llm_foundry.loggers` registry:

<!--pytest.mark.skip-->
```toml
[build-system]
requires = ["setuptools>=42", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "foundry_registry"
version = "0.1.0"
dependencies = [
"mosaicml",
"llm-foundry",
]

[project.entry-points."llm_foundry.loggers"]
my_logger = "foundry_registry.loggers:MyLogger"
```

### Direct call to register

You can also register a component directly in your code:

<!--pytest.mark.skip-->
```python
from composer.loggers import LoggerDestination
from llmfoundry.registry import loggers

class MyLogger(LoggerDestination):
    pass

loggers.register("my_logger", func=MyLogger)
```

### Decorators

You can also use decorators to register components directly from your code:

<!--pytest.mark.skip-->
```python
from composer.loggers import LoggerDestination
from llmfoundry.registry import loggers

@loggers.register("my_logger")
class MyLogger(LoggerDestination):
    pass
```

For both the direct call and decorator approaches, if using the LLM Foundry train/eval scripts, you will need to provide the `code_paths` argument, which is a list of files that need to be executed in order to register your components. For example, you may have a file called `foundry_imports.py` that contains the following:

<!--pytest.mark.skip-->
```python
from foundry_registry.loggers import MyLogger
from llmfoundry.registry import loggers

loggers.register("my_logger", func=MyLogger)
```

You would then provide `code_paths` to the train/eval scripts in your yaml config:

<!--pytest.mark.skip-->
```yaml
...
code_paths:
- foundry_imports.py
...
```
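
The direct-call and decorator forms described above can both be served by a single `register` method. The following is a minimal sketch of that pattern, assuming a toy `ToyRegistry` class and the names `toy_loggers`, `DirectLogger`, and `DecoratedLogger`, all of which are hypothetical. It is not LLM Foundry's implementation, whose registries are `TypedRegistry` instances with internals not shown here.

```python
from typing import Callable, Dict, Optional


class ToyRegistry:
    """Hypothetical stand-in for a registry; not LLM Foundry's implementation."""

    def __init__(self) -> None:
        self._entries: Dict[str, Callable] = {}

    def register(self, name: str, func: Optional[Callable] = None):
        # Direct call: register("name", func=obj) stores obj immediately.
        if func is not None:
            self._entries[name] = func
            return func

        # Decorator: register("name") returns a function that stores
        # whatever object it is applied to, then hands it back unchanged.
        def decorator(obj: Callable) -> Callable:
            self._entries[name] = obj
            return obj

        return decorator

    def get(self, name: str) -> Callable:
        return self._entries[name]


toy_loggers = ToyRegistry()


class DirectLogger:
    pass


# Direct-call form.
toy_loggers.register("direct_logger", func=DirectLogger)


# Decorator form.
@toy_loggers.register("decorated_logger")
class DecoratedLogger:
    pass
```

In both forms the class is only registered once the module containing the call has been imported, which is why the train/eval scripts need `code_paths` to know which files to execute.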

We are adopting an extensible registry for LLM Foundry to allow various extensions of the library without forking it. See [./REGISTRY.md] for more information as it develops.

# Learn more about LLM Foundry!

84 changes: 0 additions & 84 deletions REGISTRY.md

This file was deleted.

2 changes: 1 addition & 1 deletion llmfoundry/README.md
@@ -1,7 +1,7 @@

# Source code

LLMFoundry is a Python package for training, finetuning, evaluating, and serving large scale LLM models on distributed compute infrustructure using MosaicML's Composer with PyTorch
LLMFoundry is a Python package for training, finetuning, evaluating, and serving large scale LLM models on distributed compute infrastructure using MosaicML's Composer with PyTorch

At a granular level, LLMFoundry is a library that consists of the following components:

2 changes: 0 additions & 2 deletions llmfoundry/__init__.py
@@ -31,7 +31,6 @@
flash_attn_fn, scaled_multihead_dot_product_attention, triton_flash_attn_fn)
from llmfoundry.models.layers.blocks import MPTBlock
from llmfoundry.models.layers.ffn import FFN_CLASS_REGISTRY, MPTMLP, build_ffn
from llmfoundry.models.model_registry import COMPOSER_MODEL_REGISTRY
from llmfoundry.models.mpt import (ComposerMPTCausalLM, MPTConfig,
MPTForCausalLM, MPTModel, MPTPreTrainedModel)
from llmfoundry.tokenizers import TiktokenTokenizerWrapper
@@ -53,7 +52,6 @@
'ComposerHFCausalLM',
'ComposerHFPrefixLM',
'ComposerHFT5',
'COMPOSER_MODEL_REGISTRY',
'scaled_multihead_dot_product_attention',
'flash_attn_fn',
'triton_flash_attn_fn',
2 changes: 1 addition & 1 deletion llmfoundry/callbacks/async_eval_callback.py
@@ -398,7 +398,7 @@ def _get_checkpoints_and_launch_runs(self, state: State):
log.debug('No saved checkpoints found yet on remote. Skipping eval')
return

if state.fsdp_elastic_sharded_enabled:
if state.fsdp_sharded_state_dict_enabled:
checkpoints_to_eval = self._get_ready_sharded_checkpoints(
checkpointer.all_saved_checkpoints_to_timestamp,
remote_checkpoints)
4 changes: 2 additions & 2 deletions llmfoundry/callbacks/curriculum_learning_callback.py
@@ -16,12 +16,12 @@
from torch.utils.data import DataLoader

from llmfoundry.interfaces import CallbackWithConfig
from llmfoundry.utils.warnings import experimental
from llmfoundry.utils.warnings import experimental_class

log = logging.getLogger(__name__)


@experimental('CurriculumLearning callback')
@experimental_class('CurriculumLearning callback')
class CurriculumLearning(CallbackWithConfig):
"""Starts an epoch with a different dataset when resuming from a checkpoint.
37 changes: 14 additions & 23 deletions llmfoundry/callbacks/hf_checkpointer.py
@@ -129,39 +129,30 @@ def __init__(
mlflow_logging_config = {}
if self.mlflow_registered_model_name is not None:
import numpy as np
from mlflow.models.signature import ModelSignature
from mlflow.types.schema import ColSpec, Schema

# Both the metadata and the task are needed in order for mlflow
# and databricks optimized model serving to work
default_metadata = {'task': 'llm/v1/completions'}
passed_metadata = mlflow_logging_config.get('metadata', {})
mlflow_logging_config['metadata'] = {
**default_metadata,
**passed_metadata
}
mlflow_logging_config.setdefault('task', 'text-generation')

# Define a default input/output that is good for standard text generation LMs
input_schema = Schema([
ColSpec('string', 'prompt'),
ColSpec('double', 'temperature', optional=True),
ColSpec('integer', 'max_tokens', optional=True),
ColSpec('string', 'stop', optional=True),
ColSpec('integer', 'candidate_count', optional=True)
])

output_schema = Schema([ColSpec('string', 'predictions')])

default_signature = ModelSignature(inputs=input_schema,
outputs=output_schema)
mlflow_logging_config['metadata'] = passed_metadata
mlflow_logging_config.setdefault('task', 'llm/v1/completions')

default_input_example = {
'prompt': np.array(['What is Machine Learning?'])
}
is_chat = mlflow_logging_config['task'].endswith(
'chat') or mlflow_logging_config['metadata'].get(
'task', '').endswith('chat')
if is_chat:
default_input_example = {
'messages':
np.array([{
'role': 'user',
'content': 'What is Machine Learning?'
}])
}
mlflow_logging_config.setdefault('example_no_conversion', True)
mlflow_logging_config.setdefault('input_example',
default_input_example)
mlflow_logging_config.setdefault('signature', default_signature)

self.mlflow_logging_config = mlflow_logging_config
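
The chat check that appears in this hunk can be read in isolation. The following is a minimal sketch, assuming a hypothetical helper named `default_input_example` and plain Python lists in place of the `np.array` values used in the hunk. It only illustrates how the task string selects between a `prompt` example and a `messages` example.

```python
from typing import Any, Dict


def default_input_example(config: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper mirroring the chat/completions branch in the hunk."""
    # The hunk defaults the top-level task to 'llm/v1/completions'.
    task = config.get('task', 'llm/v1/completions')
    metadata_task = config.get('metadata', {}).get('task', '')

    # A task counts as chat if either task string ends with 'chat'.
    is_chat = task.endswith('chat') or metadata_task.endswith('chat')
    if is_chat:
        return {
            'messages': [{
                'role': 'user',
                'content': 'What is Machine Learning?'
            }]
        }
    return {'prompt': ['What is Machine Learning?']}
```

Because the hunk applies the example with `setdefault`, a caller-supplied `input_example` would still take precedence over whichever default is chosen.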

2 changes: 2 additions & 0 deletions llmfoundry/cli/__init__.py
@@ -0,0 +1,2 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
12 changes: 12 additions & 0 deletions llmfoundry/cli/cli.py
@@ -0,0 +1,12 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

import typer

from llmfoundry.cli import registry_cli

app = typer.Typer(pretty_exceptions_show_locals=False)
app.add_typer(registry_cli.app, name='registry')

if __name__ == '__main__':
    app()
72 changes: 72 additions & 0 deletions llmfoundry/cli/registry_cli.py
@@ -0,0 +1,72 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import typer
from rich.console import Console
from rich.table import Table

from llmfoundry import registry
from llmfoundry.utils.registry_utils import TypedRegistry

console = Console()
app = typer.Typer(pretty_exceptions_show_locals=False)


def _get_registries(group: Optional[str] = None) -> list[TypedRegistry]:
    registry_attr_names = dir(registry)
    registry_attrs = [getattr(registry, name) for name in registry_attr_names]
    available_registries = [
        r for r in registry_attrs if isinstance(r, TypedRegistry)
    ]

    if group is not None and group not in registry_attr_names:
        console.print(
            f'Group {group} not found in registry. Run `llmfoundry registry get` to see available groups.'
        )
        return []

    if group is not None:
        available_registries = [getattr(registry, group)]

    return available_registries


@app.command()
def get(group: Optional[str] = None):
    """Get the available registries.

    Args:
        group (Optional[str], optional): The group to get. If not provided, all groups will be shown. Defaults to None.
    """
    available_registries = _get_registries(group)

    table = Table('Registry', 'Description', 'Options', show_lines=True)
    for r in available_registries:
        table.add_row('.'.join(r.namespace), r.description,
                      ', '.join(r.get_all()))

    console.print(table)


@app.command()
def find(group: str, name: str):
    """Find a registry entry by name.

    Args:
        group (str): The group to search.
        name (str): The name of the entry to search for.
    """
    available_registries = _get_registries(group)
    if not available_registries:
        return

    r = available_registries[0]
    find_output = r.find(name)

    table = Table('Module', 'File', 'Line number', 'Docstring')
    table.add_row(find_output['module'], find_output['file'],
                  str(find_output['line_no']), find_output['docstring'])

    console.print(table)
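
`_get_registries` discovers registries by scanning the attributes of the `registry` module and keeping those that are `TypedRegistry` instances. The following is a minimal sketch of that pattern, assuming a hypothetical `FakeRegistry` class, a `SimpleNamespace` standing in for the module, and a hypothetical function name `find_registries`. None of these names come from LLM Foundry.

```python
from types import SimpleNamespace
from typing import Any, List, Optional


class FakeRegistry:
    """Hypothetical stand-in for TypedRegistry."""

    def __init__(self, name: str) -> None:
        self.name = name


# Stand-in for the `registry` module: two registries and one unrelated attribute.
fake_module = SimpleNamespace(
    loggers=FakeRegistry('loggers'),
    callbacks=FakeRegistry('callbacks'),
    version='0.0',
)


def find_registries(module: Any, group: Optional[str] = None) -> List[Any]:
    names = dir(module)  # dir() returns attribute names in sorted order
    if group is not None:
        # Unknown group: return nothing, as the CLI helper does.
        if group not in names:
            return []
        return [getattr(module, group)]
    # No group given: keep only attributes that are registry instances.
    attrs = [getattr(module, n) for n in names]
    return [a for a in attrs if isinstance(a, FakeRegistry)]
```

Scanning attributes this way means a newly added registry shows up in the CLI without any change to the CLI code.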
5 changes: 5 additions & 0 deletions llmfoundry/data/__init__.py
@@ -9,6 +9,11 @@
build_finetuning_dataloader)
from llmfoundry.data.text_data import (StreamingTextDataset,
build_text_dataloader)
from llmfoundry.registry import dataloaders

dataloaders.register('text', func=build_text_dataloader)
dataloaders.register('text_denoising', func=build_text_denoising_dataloader)
dataloaders.register('finetuning', func=build_finetuning_dataloader)

__all__ = [
'MixtureOfDenoisersCollator',