From aa9510fdce6771b01de2fa4d846b8d1ff31682fe Mon Sep 17 00:00:00 2001
From: camenduru <54370274+camenduru@users.noreply.github.com>
Date: Thu, 12 Oct 2023 20:06:35 +0300
Subject: [PATCH] test

---
 LLaVA_13b_8bit_colab.ipynb | 102 +++++++++++++++++++++++++++++++++++--
 1 file changed, 97 insertions(+), 5 deletions(-)

diff --git a/LLaVA_13b_8bit_colab.ipynb b/LLaVA_13b_8bit_colab.ipynb
index 07fda1d..0300901 100644
--- a/LLaVA_13b_8bit_colab.ipynb
+++ b/LLaVA_13b_8bit_colab.ipynb
@@ -18,7 +18,7 @@
    "outputs": [],
    "source": [
     "%cd /content\n",
-    "!git clone -b 5GB https://github.com/camenduru/LLaVA\n",
+    "!git clone -b dev https://github.com/camenduru/LLaVA\n",
     "%cd /content/LLaVA\n",
     "\n",
     "!pip install ninja\n",
@@ -29,9 +29,9 @@
     "!pip install git+https://github.com/huggingface/transformers\n",
     "\n",
     "# !python -m llava.serve.cli \\\n",
-    "# --model-path 4bit/llava-v1.5-13b-4GB-8bit \\\n",
+    "# --model-path 4bit/llava-v1.5-13b-5GB \\\n",
     "# --image-file \"https://llava-vl.github.io/static/images/view.jpg\" \\\n",
-    "# --load-4bit"
+    "# --load-8bit"
    ]
   },
   {
@@ -59,8 +59,8 @@
     "    '--controller', 'http://localhost:10000',\n",
     "    '--port', '40000',\n",
     "    '--worker', 'http://localhost:40000',\n",
-    "    '--model-path', '4bit/llava-v1.5-13b-4GB-8bit',\n",
-    "    '--load-4bit'\n",
+    "    '--model-path', '4bit/llava-v1.5-13b-5GB',\n",
+    "    '--load-8bit'\n",
     "]\n",
     "threading.Thread(target=lambda: subprocess.run(command, check=True, shell=False), daemon=True).start()"
    ]
@@ -73,6 +73,98 @@
    "source": [
     "!python3 -m llava.serve.gradio_web_server --controller http://localhost:10000 --model-list-mode reload --share"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from llava.model.builder import load_pretrained_model\n",
+    "model_path = \"4bit/llava-v1.5-13b-5GB\"\n",
+    "tokenizer, model, image_processor, context_len = load_pretrained_model(\n",
+    "    model_path=model_path,\n",
+    "    model_base=None,\n",
+    "    model_name=model_path.split(\"/\")[-1],\n",
+    "    load_8bit=True,\n",
+    "    load_4bit=False\n",
+    ")\n",
+    "\n",
+    "import requests\n",
+    "from PIL import Image\n",
+    "from io import BytesIO\n",
+    "\n",
+    "def load_image(image_file):\n",
+    "    if image_file.startswith('http') or image_file.startswith('https'):\n",
+    "        response = requests.get(image_file)\n",
+    "        image = Image.open(BytesIO(response.content)).convert('RGB')\n",
+    "    else:\n",
+    "        image = Image.open(image_file).convert('RGB')\n",
+    "    return image\n",
+    "\n",
+    "from llava.conversation import conv_templates, SeparatorStyle\n",
+    "from llava.utils import disable_torch_init\n",
+    "\n",
+    "disable_torch_init()\n",
+    "conv_mode = \"llava_v0\"\n",
+    "conv = conv_templates[conv_mode].copy()\n",
+    "roles = conv.roles\n",
+    "\n",
+    "image = load_image(\"https://llava-vl.github.io/static/images/view.jpg\")\n",
+    "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
+    "\n",
+    "from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
+    "from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n",
+    "from transformers import TextStreamer\n",
+    "import torch\n",
+    "\n",
+    "while True:\n",
+    "    try:\n",
+    "        inp = input(f\"{roles[0]}: \")\n",
+    "    except EOFError:\n",
+    "        inp = \"\"\n",
+    "    if not inp:\n",
+    "        print(\"exit...\")\n",
+    "        break\n",
+    "\n",
+    "    print(f\"{roles[1]}: \", end=\"\")\n",
+    "\n",
+    "    if image is not None:\n",
+    "        # first message\n",
+    "        if model.config.mm_use_im_start_end:\n",
+    "            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
+    "        else:\n",
+    "            inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n",
+    "        conv.append_message(conv.roles[0], inp)\n",
+    "        image = None\n",
+    "    else:\n",
+    "        # later messages\n",
+    "        conv.append_message(conv.roles[0], inp)\n",
+    "    conv.append_message(conv.roles[1], None)\n",
+    "    prompt = conv.get_prompt()\n",
+    "\n",
+    "    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
+    "    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
+    "    keywords = [stop_str]\n",
+    "    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
+    "    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
+    "\n",
+    "    with torch.inference_mode():\n",
+    "        output_ids = model.generate(\n",
+    "            input_ids,\n",
+    "            images=image_tensor,\n",
+    "            do_sample=True,\n",
+    "            temperature=0.2,\n",
+    "            max_new_tokens=1024,\n",
+    "            streamer=streamer,\n",
+    "            use_cache=True,\n",
+    "            stopping_criteria=[stopping_criteria])\n",
+    "\n",
+    "    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n",
+    "    conv.messages[-1][-1] = outputs\n",
+    "\n",
+    "    print(\"\\n\", {\"prompt\": prompt, \"outputs\": outputs}, \"\\n\")"
+   ]
   }
  ],
  "metadata": {