diff --git a/README.md b/README.md index 7f3230423c..d502eec0b5 100644 --- a/README.md +++ b/README.md @@ -124,6 +124,11 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ pip3 install packaging pip3 install -e '.[flash-attn,deepspeed]' ``` + 4. (Optional) Log in to Hugging Face to use gated models/datasets. + ```bash + huggingface-cli login + ``` + Get the token at https://huggingface.co/settings/tokens - LambdaLabs
diff --git a/scripts/finetune.py b/scripts/finetune.py index 7b6751e31c..118a97b844 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -7,6 +7,7 @@ from axolotl.cli import ( check_accelerate_default_config, + check_user_token, do_inference, do_merge_lora, load_cfg, @@ -31,6 +32,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs): ) parsed_cfg = load_cfg(config, **kwargs) check_accelerate_default_config() + check_user_token() parser = transformers.HfArgumentParser((TrainerCliArgs)) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index 90e1d508b0..c3b580391a 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -14,6 +14,8 @@ # add src to the pythonpath so we don't need to pip install this from accelerate.commands.config import config_args from art import text2art +from huggingface_hub import HfApi +from huggingface_hub.utils import LocalTokenNotFoundError from transformers import GenerationConfig, TextStreamer from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer @@ -247,3 +249,16 @@ def check_accelerate_default_config(): LOG.warning( f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors" ) + + +def check_user_token(): + # Warn (non-fatal) if no valid Hugging Face token is stored locally + api = HfApi() + try: + user_info = api.whoami() + return bool(user_info) + except LocalTokenNotFoundError: + LOG.warning( + "Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
+ ) + return False diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 72a9250c8d..c64755872b 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -8,6 +8,7 @@ from axolotl.cli import ( check_accelerate_default_config, + check_user_token, load_cfg, load_datasets, print_axolotl_text_art, @@ -21,6 +22,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs): print_axolotl_text_art() parsed_cfg = load_cfg(config, **kwargs) check_accelerate_default_config() + check_user_token() parser = transformers.HfArgumentParser((TrainerCliArgs)) parsed_cli_args, _ = parser.parse_args_into_dataclasses( return_remaining_strings=True