diff --git a/LLaVA_13b_4bit_vanilla_colab.ipynb b/LLaVA_13b_4bit_vanilla_colab.ipynb
index b07df00..20c1c85 100644
--- a/LLaVA_13b_4bit_vanilla_colab.ipynb
+++ b/LLaVA_13b_4bit_vanilla_colab.ipynb
@@ -12,24 +12,14 @@
   {
    "cell_type": "code",
    "execution_count": null,
-   "metadata": {
-    "id": "VjYy0F2gZIPR"
-   },
+   "metadata": {},
    "outputs": [],
    "source": [
     "%cd /content\n",
     "!git clone -b v1.0 https://github.com/camenduru/LLaVA\n",
     "%cd /content/LLaVA\n",
+    "!pip install -q gradio .\n",
     "\n",
-    "!pip install -e ."
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
     "from transformers import AutoTokenizer, BitsAndBytesConfig\n",
     "from llava.model import LlavaLlamaForCausalLM\n",
     "import torch\n",
@@ -49,89 +39,57 @@
     "if not vision_tower.is_loaded:\n",
     "    vision_tower.load_model()\n",
     "vision_tower.to(device='cuda')\n",
-    "image_processor = vision_tower.image_processor"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
+    "image_processor = vision_tower.image_processor\n",
+    "\n",
+    "import os\n",
     "import requests\n",
     "from PIL import Image\n",
     "from io import BytesIO\n",
-    "\n",
-    "def load_image(image_file):\n",
-    "    if image_file.startswith('http') or image_file.startswith('https'):\n",
-    "        response = requests.get(image_file)\n",
-    "        image = Image.open(BytesIO(response.content)).convert('RGB')\n",
-    "    else:\n",
-    "        image = Image.open(image_file).convert('RGB')\n",
-    "    return image\n",
-    "\n",
     "from llava.conversation import conv_templates, SeparatorStyle\n",
     "from llava.utils import disable_torch_init\n",
-    "\n",
-    "disable_torch_init()\n",
-    "conv_mode = \"llava_v0\"\n",
-    "conv = conv_templates[conv_mode].copy()\n",
-    "roles = conv.roles\n",
-    "\n",
-    "image = load_image(\"https://llava-vl.github.io/static/images/view.jpg\")\n",
-    "image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
-    "\n",
     "from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
     "from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n",
     "from transformers import TextStreamer\n",
     "import torch\n",
     "\n",
-    "while True:\n",
-    "    try:\n",
-    "        inp = input(f\"{roles[0]}: \")\n",
-    "    except EOFError:\n",
-    "        inp = \"\"\n",
-    "    if not inp:\n",
-    "        print(\"exit...\")\n",
-    "        break\n",
-    "\n",
-    "    print(f\"{roles[1]}: \", end=\"\")\n",
-    "\n",
-    "    if image is not None:\n",
-    "        # first message\n",
-    "        if model.config.mm_use_im_start_end:\n",
-    "            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
-    "        else:\n",
-    "            inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n",
-    "        conv.append_message(conv.roles[0], inp)\n",
-    "        image = None\n",
+    "def caption_image(image_file, prompt):\n",
+    "    if image_file.startswith('http') or image_file.startswith('https'):\n",
+    "        response = requests.get(image_file)\n",
+    "        image = Image.open(BytesIO(response.content)).convert('RGB')\n",
     "    else:\n",
-    "        # later messages\n",
-    "        conv.append_message(conv.roles[0], inp)\n",
+    "        image = Image.open(image_file).convert('RGB')\n",
+    "    disable_torch_init()\n",
+    "    conv_mode = \"llava_v0\"\n",
+    "    conv = conv_templates[conv_mode].copy()\n",
+    "    roles = conv.roles\n",
+    "    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
+    "    inp = f\"{roles[0]}: {prompt}\"\n",
+    "    inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
+    "    conv.append_message(conv.roles[0], inp)\n",
     "    conv.append_message(conv.roles[1], None)\n",
-    "    prompt = conv.get_prompt()\n",
-    "\n",
-    "    input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
+    "    raw_prompt = conv.get_prompt()\n",
+    "    input_ids = tokenizer_image_token(raw_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
     "    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
     "    keywords = [stop_str]\n",
     "    stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
-    "    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
-    "\n",
     "    with torch.inference_mode():\n",
-    "        output_ids = model.generate(\n",
-    "            input_ids,\n",
-    "            images=image_tensor,\n",
-    "            do_sample=True,\n",
-    "            temperature=0.2,\n",
-    "            max_new_tokens=1024,\n",
-    "            streamer=streamer,\n",
-    "            use_cache=True,\n",
-    "            stopping_criteria=[stopping_criteria])\n",
-    "\n",
+    "        output_ids = model.generate(input_ids, images=image_tensor, do_sample=True, temperature=0.2,\n",
+    "                                    max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])\n",
     "    outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n",
     "    conv.messages[-1][-1] = outputs\n",
-    "\n",
-    "    print(\"\\n\", {\"prompt\": prompt, \"outputs\": outputs}, \"\\n\")"
+    "    output = outputs.rsplit('</s>', 1)[0]\n",
+    "    return image, output"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "image, output = caption_image('https://llava-vl.github.io/static/images/view.jpg', 'Describe the image and color details.')\n",
+    "print(output)\n",
+    "image"
+   ]
   }
  ],
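
Note: the diff adds `gradio` to the pip install but none of the cells shown import it. If the intent is to expose `caption_image` through a small web UI, a minimal sketch could look like the cell below; the `gr.Interface` wiring, the component choices, and the `gradio_caption` wrapper are illustrative assumptions, not part of this diff.

```python
import gradio as gr

def gradio_caption(image_path, prompt):
    # caption_image accepts a local path or a URL string, so pass the uploaded file path through.
    _, output = caption_image(image_path, prompt)
    return output

demo = gr.Interface(
    fn=gradio_caption,
    inputs=[gr.Image(type="filepath"), gr.Textbox(value="Describe the image and color details.")],
    outputs=gr.Textbox(label="Caption"),
)
demo.launch(share=True)  # share=True exposes a public URL when running in Colab
```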