Feat: Added Gradio support (#812)
* Added gradio support

* queuing and title

* pre-commit run
Stillerman authored Nov 5, 2023
1 parent cdc71f7 commit 738a057
Showing 4 changed files with 108 additions and 4 deletions.
8 changes: 8 additions & 0 deletions README.md
@@ -97,6 +97,10 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
# inference
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out"

# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio
```

## Installation
@@ -919,6 +923,10 @@ Pass the appropriate flag to the train command:
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
--base_model="./completed-model" --prompter=None --load_in_8bit=True
```
- With gradio hosting
```bash
python -m axolotl.cli.inference examples/your_config.yml --gradio
```
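
The `--gradio` path serves the same inference entry point behind a small web UI. Note that, as added in this commit, the server launches with `share=True` (so Gradio also prints a temporary public link), and the page title can be overridden with a `gradio_title` key in your config.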

Please use `--sample_packing False` if you have it enabled and receive an error similar to the one below:

1 change: 1 addition & 0 deletions requirements.txt
@@ -31,3 +31,4 @@ scikit-learn==1.2.2
pynvml
art
fschat==0.2.29
gradio
89 changes: 88 additions & 1 deletion src/axolotl/cli/__init__.py
@@ -6,8 +6,10 @@
import random
import sys
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union

import gradio as gr
import torch
import yaml

@@ -16,7 +18,7 @@
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextStreamer
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer

from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
@@ -153,6 +155,91 @@ def do_inference(
        print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))


def do_inference_gradio(
    *,
    cfg: DictDefault,
    cli_args: TrainerCliArgs,
):
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
    prompter = cli_args.prompter
    default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}

    for token, symbol in default_tokens.items():
        # If the token isn't already specified in the config, add it
        if not (cfg.special_tokens and token in cfg.special_tokens):
            tokenizer.add_special_tokens({token: symbol})

    prompter_module = None
    if prompter:
        prompter_module = getattr(
            importlib.import_module("axolotl.prompters"), prompter
        )

    if cfg.landmark_attention:
        from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id

        set_model_mem_id(model, tokenizer)
        model.set_mem_cache_args(
            max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
        )

    model = model.to(cfg.device)

    def generate(instruction):
        if not instruction:
            return
        if prompter_module:
            # pylint: disable=stop-iteration-return
            prompt: str = next(
                prompter_module().build_prompt(instruction=instruction.strip("\n"))
            )
        else:
            prompt = instruction.strip()
        batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)

        model.eval()
        with torch.no_grad():
            generation_config = GenerationConfig(
                repetition_penalty=1.1,
                max_new_tokens=1024,
                temperature=0.9,
                top_p=0.95,
                top_k=40,
                bos_token_id=tokenizer.bos_token_id,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                do_sample=True,
                use_cache=True,
                return_dict_in_generate=True,
                output_attentions=False,
                output_hidden_states=False,
                output_scores=False,
            )
            streamer = TextIteratorStreamer(tokenizer)
            generation_kwargs = {
                "inputs": batch["input_ids"].to(cfg.device),
                "generation_config": generation_config,
                "streamer": streamer,
            }

            thread = Thread(target=model.generate, kwargs=generation_kwargs)
            thread.start()

            all_text = ""

            for new_text in streamer:
                all_text += new_text
                yield all_text

    demo = gr.Interface(
        fn=generate,
        inputs="textbox",
        outputs="text",
        title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
    )
    demo.queue().launch(show_api=False, share=True)


def choose_config(path: Path):
    yaml_files = list(path.glob("*.yml"))

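A note on the "queuing and title" item in the commit message: `gr.Interface` treats a generator `fn` as a streaming handler, and `demo.queue()` enables that mode, so each `yield all_text` above repaints the output textbox with everything generated so far. `TextIteratorStreamer` is the iterator counterpart of `TextStreamer` (hence the changed import): `model.generate` runs on a background `Thread` and pushes decoded tokens into the streamer while the main thread drains it. A minimal standalone sketch of the same pattern, with a hypothetical `stream_chars` function standing in for model generation:

```python
import time

import gradio as gr


def stream_chars(text):
    # Yielding partial strings makes Gradio re-render the output after
    # each yield, which is what produces the live "typing" effect.
    out = ""
    for char in text:
        out += char
        time.sleep(0.05)  # stand-in for per-token generation latency
        yield out


demo = gr.Interface(fn=stream_chars, inputs="textbox", outputs="text")
demo.queue().launch(show_api=False)  # queue() is needed for generator fns in Gradio 3.x
```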
14 changes: 11 additions & 3 deletions src/axolotl/cli/inference.py
@@ -6,11 +6,16 @@
import fire
import transformers

from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
from axolotl.cli import (
do_inference,
do_inference_gradio,
load_cfg,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs


def do_cli(config: Path = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
    # pylint: disable=duplicate-code
    print_axolotl_text_art()
    parsed_cfg = load_cfg(config, **kwargs)
@@ -21,7 +26,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
    )
    parsed_cli_args.inference = True

    do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
    if gradio:
        do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
    else:
        do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)


if __name__ == "__main__":
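Since the entry point is wrapped with `fire`, the new `gradio` parameter doubles as the `--gradio` CLI flag. A minimal sketch of the equivalent programmatic call (hypothetical driver script; assumes the example config exists locally and that `sys.argv` carries no conflicting flags, since `HfArgumentParser` still reads it):

```python
from pathlib import Path

from axolotl.cli.inference import do_cli

# Equivalent to: python -m axolotl.cli.inference examples/openllama-3b/lora.yml --gradio
do_cli(config=Path("examples/openllama-3b/lora.yml"), gradio=True)
```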
