diff --git a/.mypy.ini b/.mypy.ini index 1bbe04d2c7..bb9a21c657 100644 --- a/.mypy.ini +++ b/.mypy.ini @@ -32,6 +32,9 @@ ignore_missing_imports = True [mypy-bitsandbytes] ignore_missing_imports = True +[mypy-requests] +ignore_missing_imports = True + [mypy-datasets] ignore_missing_imports = True diff --git a/README.md b/README.md index 0bbba4fc6e..0d86c304fb 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,10 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ # gradio accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \ --lora_model_dir="./lora-out" --gradio + +# remote yaml files - the yaml config can be hosted on a public URL +# Note: the yaml config must directly link to the **raw** yaml +accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml ``` ## Installation @@ -988,6 +992,9 @@ Run accelerate launch -m axolotl.cli.train your_config.yml ``` +> [!TIP] +> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml` + #### Preprocess dataset You can optionally pre-tokenize dataset with the following before finetuning. diff --git a/requirements-dev.txt b/requirements-dev.txt index df7e312cb1..4b5df167b6 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ pre-commit black mypy +types-requests diff --git a/requirements.txt b/requirements.txt index 2e978c16da..4d1073500f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ deepspeed>=0.13.1 addict fire PyYAML>=6.0 +requests datasets>=2.15.0 flash-attn==2.3.3 sentencepiece diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index e900fba955..6b3894cb53 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -1,16 +1,20 @@ """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" import importlib +import json import logging import math import os import random import sys +import tempfile from pathlib import Path from threading import Thread from typing import Any, Dict, List, Optional, Union +from urllib.parse import urlparse import gradio as gr +import requests import torch import yaml @@ -59,6 +63,52 @@ def print_axolotl_text_art(suffix=None): print(ascii_art) +def check_remote_config(config: Union[str, Path]): + # Check if the config is a valid HTTPS URL to a .yml or .yaml file + if not (isinstance(config, str) and config.startswith("https://")): + return config # Return the original value if it's not a valid URL + + filename = os.path.basename(urlparse(config).path) + temp_dir = tempfile.mkdtemp() + + try: + response = requests.get(config, timeout=30) + response.raise_for_status() # Check for HTTP errors + + content = response.content + try: + # Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML + json.loads(content) + # Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link + LOG.warning( + f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended." + ) + except json.JSONDecodeError: + # If it's not valid JSON, verify it's valid YAML + try: + yaml.safe_load(content) + except yaml.YAMLError as err: + raise ValueError( + f"Failed to parse the content at {config} as YAML: {err}" + ) from err + + # Write the content to a file if it's valid YAML (or JSON treated as YAML) + output_path = Path(temp_dir) / filename + with open(output_path, "wb") as file: + file.write(content) + LOG.info( + f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n" + ) + return output_path + + except requests.RequestException as err: + # This catches all requests-related exceptions including HTTPError + raise RuntimeError(f"Failed to download {config}: {err}") from err + except Exception as err: + # Catch-all for any other exceptions + raise err + + def get_multi_line_input() -> Optional[str]: print("Give me an instruction (Ctrl + D to submit): ") instruction = "" @@ -270,9 +320,10 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b return not any(el in list2 for el in list1) -def load_cfg(config: Path = Path("examples/"), **kwargs): +def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs): + config = check_remote_config(config) if Path(config).is_dir(): - config = choose_config(config) + config = choose_config(Path(config)) # load the config from the yaml file with open(config, encoding="utf-8") as file: diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py index e7bc612b7e..89ab023e5f 100644 --- a/src/axolotl/cli/preprocess.py +++ b/src/axolotl/cli/preprocess.py @@ -3,6 +3,7 @@ """ import logging from pathlib import Path +from typing import Union import fire import transformers @@ -23,7 +24,7 @@ LOG = logging.getLogger("axolotl.cli.preprocess") -def do_cli(config: Path = Path("examples/"), **kwargs): +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): # pylint: disable=duplicate-code print_axolotl_text_art() parsed_cfg = load_cfg(config, **kwargs) diff --git a/src/axolotl/cli/shard.py b/src/axolotl/cli/shard.py index 85901b0f2a..48f22790ac 100644 --- a/src/axolotl/cli/shard.py +++ b/src/axolotl/cli/shard.py @@ -3,6 +3,7 @@ """ import logging from pathlib import Path +from typing import Union import fire import transformers @@ -25,7 +26,7 @@ def shard( model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization) -def do_cli(config: Path = Path("examples/"), **kwargs): +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): # pylint: disable=duplicate-code print_axolotl_text_art() parsed_cfg = load_cfg(config, **kwargs) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 6cbe2c9603..05fd63ae80 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -3,7 +3,7 @@ """ import logging from pathlib import Path -from typing import Tuple +from typing import Tuple, Union import fire from transformers.hf_argparser import HfArgumentParser @@ -25,7 +25,7 @@ LOG = logging.getLogger("axolotl.cli.train") -def do_cli(config: Path = Path("examples/"), **kwargs): +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): # pylint: disable=duplicate-code parsed_cfg = load_cfg(config, **kwargs) parser = HfArgumentParser((TrainerCliArgs))