diff --git a/LLaVA_13b_8bit_colab.ipynb b/LLaVA_13b_8bit_colab.ipynb
index 7d01b41..7eedae3 100644
--- a/LLaVA_13b_8bit_colab.ipynb
+++ b/LLaVA_13b_8bit_colab.ipynb
@@ -79,13 +79,26 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from transformers import AutoTokenizer\n",
+    "from transformers import AutoTokenizer, BitsAndBytesConfig\n",
     "from llava.model import LlavaLlamaForCausalLM\n",
     "import torch\n",
     "\n",
-    "model_path = \"4bit/llava-v1.5-13b-4GB-8bit\"\n",
+    "model_path = \"4bit/llava-v1.5-13b-5GB\"\n",
     "tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)\n",
-    "model = LlavaLlamaForCausalLM.from_pretrained(model_path)"
+    "\n",
+    "# model_path = \"4bit/llava-v1.5-13b-4GB-8bit\"\n",
+    "# model = LlavaLlamaForCausalLM.from_pretrained(model_path)\n",
+    "# model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, load_in_8bit=True, device_map=\"auto\")\n",
+    "\n",
+    "kwargs = {\"device_map\": \"auto\"}\n",
+    "kwargs['load_in_4bit'] = True\n",
+    "kwargs['quantization_config'] = BitsAndBytesConfig(\n",
+    "    load_in_4bit=True,\n",
+    "    bnb_4bit_compute_dtype=torch.float16,\n",
+    "    bnb_4bit_use_double_quant=True,\n",
+    "    bnb_4bit_quant_type='nf4'\n",
+    ")\n",
+    "model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)"
   ]
  },
 {
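For reference, a minimal sketch of the 8-bit alternative that the commented-out lines in the new cell allude to, rewritten to go through BitsAndBytesConfig instead of the older bare load_in_8bit= keyword. This is an assumption about how one would mirror the notebook's 4-bit pattern for the 8-bit checkpoint; the memory printout at the end is purely illustrative and is not part of the notebook.

    # Hypothetical 8-bit counterpart of the 4-bit load in the diff above.
    from transformers import AutoTokenizer, BitsAndBytesConfig
    from llava.model import LlavaLlamaForCausalLM

    model_path = "4bit/llava-v1.5-13b-4GB-8bit"  # 8-bit checkpoint kept commented out in the cell
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)

    kwargs = {
        "device_map": "auto",
        "low_cpu_mem_usage": True,
        # Newer transformers releases prefer a quantization_config over load_in_8bit=True.
        "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
    }
    model = LlavaLlamaForCausalLM.from_pretrained(model_path, **kwargs)

    # Rough sanity check on how much GPU memory the quantized weights take.
    print(f"~{model.get_memory_footprint() / 1e9:.1f} GB loaded")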