Skip to content

Commit

Permalink
test
Browse files Browse the repository at this point in the history
  • Loading branch information
camenduru authored Oct 12, 2023
1 parent ee74ef8 commit aa9510f
Showing 1 changed file with 97 additions and 5 deletions.
102 changes: 97 additions & 5 deletions LLaVA_13b_8bit_colab.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
"outputs": [],
"source": [
"%cd /content\n",
"!git clone -b 5GB https://github.com/camenduru/LLaVA\n",
"!git clone dev https://github.com/camenduru/LLaVA\n",
"%cd /content/LLaVA\n",
"\n",
"!pip install ninja\n",
Expand All @@ -29,9 +29,9 @@
"!pip install git+https://github.com/huggingface/transformers\n",
"\n",
"# !python -m llava.serve.cli \\\n",
"# --model-path 4bit/llava-v1.5-13b-4GB-8bit \\\n",
"# --model-path 4bit/llava-v1.5-13b-5GB \\\n",
"# --image-file \"https://llava-vl.github.io/static/images/view.jpg\" \\\n",
"# --load-4bit"
"# --load-8bit"
]
},
{
Expand Down Expand Up @@ -59,8 +59,8 @@
" '--controller', 'http://localhost:10000',\n",
" '--port', '40000',\n",
" '--worker', 'http://localhost:40000',\n",
" '--model-path', '4bit/llava-v1.5-13b-4GB-8bit',\n",
" '--load-4bit'\n",
" '--model-path', '4bit/llava-v1.5-13b-5GB',\n",
" '--load-8bit'\n",
"]\n",
"threading.Thread(target=lambda: subprocess.run(command, check=True, shell=False), daemon=True).start()"
]
Expand All @@ -73,6 +73,98 @@
"source": [
"!python3 -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --share"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llava.model.builder import load_pretrained_model\n",
"model_path = \"4bit/llava-v1.5-13b-5GB\"\n",
"tokenizer, model, image_processor, context_len = load_pretrained_model(\n",
" model_path=model_path,\n",
" model_base=None,\n",
" model_name=model_path.split(\"/\")[-1],\n",
" load_8bit=True,\n",
" load_4bit=False\n",
")\n",
"\n",
"import requests\n",
"from PIL import Image\n",
"from io import BytesIO\n",
"\n",
"def load_image(image_file):\n",
" if image_file.startswith('http') or image_file.startswith('https'):\n",
" response = requests.get(image_file)\n",
" image = Image.open(BytesIO(response.content)).convert('RGB')\n",
" else:\n",
" image = Image.open(image_file).convert('RGB')\n",
" return image\n",
"\n",
"from llava.conversation import conv_templates, SeparatorStyle\n",
"from llava.utils import disable_torch_init\n",
"\n",
"disable_torch_init()\n",
"conv_mode = \"llava_v0\"\n",
"conv = conv_templates[conv_mode].copy()\n",
"roles = conv.roles\n",
"\n",
"image = load_image(\"https://llava-vl.github.io/static/images/view.jpg\")\n",
"image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
"\n",
"from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
"from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n",
"from transformers import TextStreamer\n",
"import torch\n",
"\n",
"while True:\n",
" try:\n",
" inp = input(f\"{roles[0]}: \")\n",
" except EOFError:\n",
" inp = \"\"\n",
" if not inp:\n",
" print(\"exit...\")\n",
" break\n",
"\n",
" print(f\"{roles[1]}: \", end=\"\")\n",
"\n",
" if image is not None:\n",
" # first message\n",
" if model.config.mm_use_im_start_end:\n",
" inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
" else:\n",
" inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n",
" conv.append_message(conv.roles[0], inp)\n",
" image = None\n",
" else:\n",
" # later messages\n",
" conv.append_message(conv.roles[0], inp)\n",
" conv.append_message(conv.roles[1], None)\n",
" prompt = conv.get_prompt()\n",
"\n",
" input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
" stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
" keywords = [stop_str]\n",
" stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
" streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
"\n",
" with torch.inference_mode():\n",
" output_ids = model.generate(\n",
" input_ids,\n",
" images=image_tensor,\n",
" do_sample=True,\n",
" temperature=0.2,\n",
" max_new_tokens=1024,\n",
" streamer=streamer,\n",
" use_cache=True,\n",
" stopping_criteria=[stopping_criteria])\n",
"\n",
" outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n",
" conv.messages[-1][-1] = outputs\n",
"\n",
" print(\"\\n\", {\"prompt\": prompt, \"outputs\": outputs}, \"\\n\")"
]
}
],
"metadata": {
Expand Down

0 comments on commit aa9510f

Please sign in to comment.