Skip to content

Commit

Permalink
Add CLI for train.py (#1337)
Browse files Browse the repository at this point in the history
* train cli

* precommit

* update train to reflect script

* rm extraneoud

* add eval back

* eval

* trim eval

* rm train

* fix repo

* fixing repo

* readme

* precommit

* test

* import typer

* typer

* use cmd line

* use cmd line

* fix llm-foundry train train to llm-foundry train

* test

* test

* test

* test

* test

* test

* test

* test

* test

* test

* test

* precommit

* deprecate old script to use new one

* readd main

* reformatting

* typo

* typo

* typo

* typo

* precommit

* precommit

* precommit

* precommit

* test

* test

* debug

* debug

* debug

* debug

* resolve 2/3 comments

* typo

* help

* rename

* import

* import pathing

* precommit

* precommit

* precommit

* rm type ignore

* rm List for list

* type ignore

* precommit

---------

Co-authored-by: v-chen_data <[email protected]>
  • Loading branch information
KuuCi and v-chen_data authored Jul 10, 2024
1 parent 08a3624 commit 129bb56
Show file tree
Hide file tree
Showing 6 changed files with 620 additions and 578 deletions.
17 changes: 17 additions & 0 deletions llmfoundry/cli/cli.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,29 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

from typing import Optional

import typer

from llmfoundry.cli import registry_cli
from llmfoundry.train import train_from_yaml

app = typer.Typer(pretty_exceptions_show_locals=False)
app.add_typer(registry_cli.app, name='registry')


@app.command(name='train')
def train(
yaml_path: str = typer.Argument(
...,
help='Path to the YAML configuration file',
), # type: ignore
args_list: Optional[list[str]] = typer.
Argument(None, help='Additional command line arguments'), # type: ignore
):
"""Run the training with optional overrides from CLI."""
train_from_yaml(yaml_path, args_list)


if __name__ == '__main__':
app()
17 changes: 17 additions & 0 deletions llmfoundry/train/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Copyright 2024 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
from llmfoundry.train.train import (
TRAIN_CONFIG_KEYS,
TrainConfig,
train,
train_from_yaml,
validate_config,
)

__all__ = [
'train',
'train_from_yaml',
'TrainConfig',
'TRAIN_CONFIG_KEYS',
'validate_config',
]
Loading

0 comments on commit 129bb56

Please sign in to comment.