
Commit

4bit/llava-v1.5-13b-3GB
camenduru authored Oct 14, 2023
1 parent 47a149f commit c86c8d2
Showing 1 changed file with 104 additions and 22 deletions.
126 changes: 104 additions & 22 deletions LLaVA_13b_4bit_colab.ipynb
@@ -24,6 +24,109 @@
"!pip install -e ."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# from transformers import AutoTokenizer, BitsAndBytesConfig\n",
"# from llava.model import LlavaLlamaForCausalLM\n",
"# import torch\n",
"# model_path = \"4bit/llava-v1.5-13b-3GB\"\n",
"# kwargs = {\"device_map\": \"auto\"}\n",
"# kwargs['load_in_4bit'] = True\n",
"# kwargs['quantization_config'] = BitsAndBytesConfig(\n",
"# load_in_4bit=True,\n",
"# bnb_4bit_compute_dtype=torch.float16,\n",
"# bnb_4bit_use_double_quant=True,\n",
"# bnb_4bit_quant_type='nf4'\n",
"# )\n",
"# model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n",
"# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
"\n",
"# vision_tower = model.get_vision_tower()\n",
"# if not vision_tower.is_loaded:\n",
"# vision_tower.load_model()\n",
"# vision_tower.to(device='cuda')\n",
"# image_processor = vision_tower.image_processor\n",
"\n",
"# import requests\n",
"# from PIL import Image\n",
"# from io import BytesIO\n",
"\n",
"# def load_image(image_file):\n",
"# if image_file.startswith('http') or image_file.startswith('https'):\n",
"# response = requests.get(image_file)\n",
"# image = Image.open(BytesIO(response.content)).convert('RGB')\n",
"# else:\n",
"# image = Image.open(image_file).convert('RGB')\n",
"# return image\n",
"\n",
"# from llava.conversation import conv_templates, SeparatorStyle\n",
"# from llava.utils import disable_torch_init\n",
"\n",
"# disable_torch_init()\n",
"# conv_mode = \"llava_v0\"\n",
"# conv = conv_templates[conv_mode].copy()\n",
"# roles = conv.roles\n",
"\n",
"# image = load_image(\"https://llava-vl.github.io/static/images/view.jpg\")\n",
"# image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values'].half().cuda()\n",
"\n",
"# from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN\n",
"# from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria\n",
"# from transformers import TextStreamer\n",
"# import torch\n",
"\n",
"# while True:\n",
"# try:\n",
"# inp = input(f\"{roles[0]}: \")\n",
"# except EOFError:\n",
"# inp = \"\"\n",
"# if not inp:\n",
"# print(\"exit...\")\n",
"# break\n",
"\n",
"# print(f\"{roles[1]}: \", end=\"\")\n",
"\n",
"# if image is not None:\n",
"# # first message\n",
"# if model.config.mm_use_im_start_end:\n",
"# inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\\n' + inp\n",
"# else:\n",
"# inp = DEFAULT_IMAGE_TOKEN + '\\n' + inp\n",
"# conv.append_message(conv.roles[0], inp)\n",
"# image = None\n",
"# else:\n",
"# # later messages\n",
"# conv.append_message(conv.roles[0], inp)\n",
"# conv.append_message(conv.roles[1], None)\n",
"# prompt = conv.get_prompt()\n",
"\n",
"# input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda()\n",
"# stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2\n",
"# keywords = [stop_str]\n",
"# stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)\n",
"# streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)\n",
"\n",
"# with torch.inference_mode():\n",
"# output_ids = model.generate(\n",
"# input_ids,\n",
"# images=image_tensor,\n",
"# do_sample=True,\n",
"# temperature=0.2,\n",
"# max_new_tokens=1024,\n",
"# streamer=streamer,\n",
"# use_cache=True,\n",
"# stopping_criteria=[stopping_criteria])\n",
"\n",
"# outputs = tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()\n",
"# conv.messages[-1][-1] = outputs\n",
"\n",
"# print(\"\\n\", {\"prompt\": prompt, \"outputs\": outputs}, \"\\n\")"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -52,28 +155,7 @@
" '--model-path', '4bit/llava-v1.5-13b-3GB',\n",
" '--load-4bit'\n",
"]\n",
"threading.Thread(target=lambda: subprocess.run(command, check=True, shell=False), daemon=True).start()\n",
"\n",
"# from transformers import AutoTokenizer, BitsAndBytesConfig\n",
"# from llava.model import LlavaLlamaForCausalLM\n",
"# import torch\n",
"# model_path = \"4bit/llava-v1.5-13b-3GB\"\n",
"# kwargs = {\"device_map\": \"auto\"}\n",
"# kwargs['load_in_4bit'] = True\n",
"# kwargs['quantization_config'] = BitsAndBytesConfig(\n",
"# load_in_4bit=True,\n",
"# bnb_4bit_compute_dtype=torch.float16,\n",
"# bnb_4bit_use_double_quant=True,\n",
"# bnb_4bit_quant_type='nf4'\n",
"# )\n",
"# model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)\n",
"# tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
"\n",
"# vision_tower = model.get_vision_tower()\n",
"# if not vision_tower.is_loaded:\n",
"# vision_tower.load_model()\n",
"# vision_tower.to(device='cuda')\n",
"# image_processor = vision_tower.image_processor"
"threading.Thread(target=lambda: subprocess.run(command, check=True, shell=False), daemon=True).start()"
]
},
{
