From 738a0576743e40a88ca29075bd2bf4ff8c6e6fc4 Mon Sep 17 00:00:00 2001
From: Jason Stillerman
Date: Sat, 4 Nov 2023 23:59:22 -0400
Subject: [PATCH] Feat: Added Gradio support (#812)

* Added gradio support

* queuing and title

* pre-commit run
---
 README.md                    |  8 ++++
 requirements.txt             |  1 +
 src/axolotl/cli/__init__.py  | 89 +++++++++++++++++++++++++++++++++++-
 src/axolotl/cli/inference.py | 14 ++++--
 4 files changed, 108 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index bafec22a6a..c8862ae430 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
 # inference
 accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
     --lora_model_dir="./lora-out"
+
+# gradio
+accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
+    --lora_model_dir="./lora-out" --gradio
 ```

 ## Installation
@@ -919,6 +923,10 @@ Pass the appropriate flag to the train command:
   cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
     --base_model="./completed-model" --prompter=None --load_in_8bit=True
   ```
+- With gradio hosting
+  ```bash
+  python -m axolotl.cli.inference examples/your_config.yml --gradio
+  ```

 Please use `--sample_packing False` if you have it on and receive the error similar to below:

diff --git a/requirements.txt b/requirements.txt
index c69c995fcf..f478481e90 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
 pynvml
 art
 fschat==0.2.29
+gradio
diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py
index 27d5df386c..a055aea101 100644
--- a/src/axolotl/cli/__init__.py
+++ b/src/axolotl/cli/__init__.py
@@ -6,8 +6,10 @@
 import random
 import sys
 from pathlib import Path
+from threading import Thread
 from typing import Any, Dict, List, Optional, Union

+import gradio as gr
 import torch
 import yaml

@@ -16,7 +18,7 @@
 from art import text2art
 from huggingface_hub import HfApi
 from huggingface_hub.utils import LocalTokenNotFoundError
-from transformers import GenerationConfig, TextStreamer
+from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer

 from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
 from axolotl.logging_config import configure_logging
@@ -153,6 +155,91 @@ def do_inference(
         print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))


+def do_inference_gradio(
+    *,
+    cfg: DictDefault,
+    cli_args: TrainerCliArgs,
+):
+    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
+    prompter = cli_args.prompter
+    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
+
+    for token, symbol in default_tokens.items():
+        # If the token isn't already specified in the config, add it
+        if not (cfg.special_tokens and token in cfg.special_tokens):
+            tokenizer.add_special_tokens({token: symbol})
+
+    prompter_module = None
+    if prompter:
+        prompter_module = getattr(
+            importlib.import_module("axolotl.prompters"), prompter
+        )
+
+    if cfg.landmark_attention:
+        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
+
+        set_model_mem_id(model, tokenizer)
+        model.set_mem_cache_args(
+            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
+        )
+
+    model = model.to(cfg.device)
+
+    def generate(instruction):
+        if not instruction:
+            return
+        if prompter_module:
+            # pylint: disable=stop-iteration-return
+            prompt: str = next(
+                prompter_module().build_prompt(instruction=instruction.strip("\n"))
+            )
+        else:
+            prompt = instruction.strip()
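+        # Tokenize the assembled prompt; generation below runs on a background
+        # thread and streams decoded tokens back through TextIteratorStreamer so
+        # the Gradio UI can update incrementally.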
+        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
+
+        model.eval()
+        with torch.no_grad():
+            generation_config = GenerationConfig(
+                repetition_penalty=1.1,
+                max_new_tokens=1024,
+                temperature=0.9,
+                top_p=0.95,
+                top_k=40,
+                bos_token_id=tokenizer.bos_token_id,
+                eos_token_id=tokenizer.eos_token_id,
+                pad_token_id=tokenizer.pad_token_id,
+                do_sample=True,
+                use_cache=True,
+                return_dict_in_generate=True,
+                output_attentions=False,
+                output_hidden_states=False,
+                output_scores=False,
+            )
+            streamer = TextIteratorStreamer(tokenizer)
+            generation_kwargs = {
+                "inputs": batch["input_ids"].to(cfg.device),
+                "generation_config": generation_config,
+                "streamer": streamer,
+            }
+
+            thread = Thread(target=model.generate, kwargs=generation_kwargs)
+            thread.start()
+
+            all_text = ""
+
+            for new_text in streamer:
+                all_text += new_text
+                yield all_text
+
+    demo = gr.Interface(
+        fn=generate,
+        inputs="textbox",
+        outputs="text",
+        title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
+    )
+    demo.queue().launch(show_api=False, share=True)
+
+
 def choose_config(path: Path):
     yaml_files = list(path.glob("*.yml"))

diff --git a/src/axolotl/cli/inference.py b/src/axolotl/cli/inference.py
index 91405d8c66..86ad8409ff 100644
--- a/src/axolotl/cli/inference.py
+++ b/src/axolotl/cli/inference.py
@@ -6,11 +6,16 @@
 import fire
 import transformers

-from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
+from axolotl.cli import (
+    do_inference,
+    do_inference_gradio,
+    load_cfg,
+    print_axolotl_text_art,
+)
 from axolotl.common.cli import TrainerCliArgs


-def do_cli(config: Path = Path("examples/"), **kwargs):
+def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, **kwargs)
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
     )
     parsed_cli_args.inference = True

-    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    if gradio:
+        do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
+    else:
+        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)


 if __name__ == "__main__":
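For readers who want to experiment with the streaming approach this patch uses outside of axolotl, here is a minimal, self-contained sketch of the same pattern: `model.generate()` runs on a background thread while a `TextIteratorStreamer` feeds decoded text to a Gradio generator callback. The model id, prompt handling, and generation settings below are illustrative assumptions, not axolotl defaults.

```python
# Minimal sketch of the TextIteratorStreamer + background-thread pattern used in
# do_inference_gradio above. Model id and generation settings are placeholders.
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "openlm-research/open_llama_3b"  # assumption: any causal LM works here

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")


def generate(instruction):
    if not instruction:
        return
    batch = tokenizer(instruction, return_tensors="pt").to(model.device)
    # skip_prompt=True keeps the echoed prompt out of the streamed output
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    thread = Thread(
        target=model.generate,
        kwargs={
            "input_ids": batch["input_ids"],
            "max_new_tokens": 256,
            "streamer": streamer,
        },
    )
    thread.start()

    all_text = ""
    for new_text in streamer:  # yields decoded text chunks as they are generated
        all_text += new_text
        yield all_text  # Gradio re-renders the output on every yield


# queue() enables generator (streaming) outputs in the Gradio UI
gr.Interface(fn=generate, inputs="textbox", outputs="text").queue().launch()
```

Running generation on a worker thread is what makes streaming possible here: `model.generate()` blocks until it finishes, so the only way for the UI callback to yield partial text is to read from the streamer while another thread fills it.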