diff --git a/poc/accelerate.ipynb b/poc/accelerate.ipynb
index b33ad3c..b283162 100644
--- a/poc/accelerate.ipynb
+++ b/poc/accelerate.ipynb
@@ -1556,16 +1556,24 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 9,
    "metadata": {},
    "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n",
+      "`low_cpu_mem_usage` was None, now set to True since model is quantized.\n"
+     ]
+    },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
       "Collecting garbage...\n",
       "The current device is 0\n",
-      "GPU 0: 44.09 GB free, 44.35 GB total\n",
+      "GPU 0: 35.48 GB free, 44.35 GB total\n",
       "GPU 1: 44.09 GB free, 44.35 GB total\n",
       "GPU 2: 44.09 GB free, 44.35 GB total\n",
       "GPU 3: 44.09 GB free, 44.35 GB total\n",
@@ -1576,9 +1584,7 @@
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n",
-      "`low_cpu_mem_usage` was None, now set to True since model is quantized.\n",
-      "Loading checkpoint shards: 100%|██████████| 4/4 [00:50<00:00, 12.56s/it]\n"
+      "Loading checkpoint shards: 100%|██████████| 4/4 [00:05<00:00, 1.33s/it]\n"
      ]
     },
     {
@@ -1586,7 +1592,7 @@
      "output_type": "stream",
      "text": [
       "The current device is 0\n",
-      "GPU 0: 35.54 GB free, 44.35 GB total\n",
+      "GPU 0: 27.05 GB free, 44.35 GB total\n",
       "GPU 1: 44.09 GB free, 44.35 GB total\n",
       "GPU 2: 44.09 GB free, 44.35 GB total\n",
       "GPU 3: 44.09 GB free, 44.35 GB total\n",
@@ -1597,7 +1603,7 @@
    "source": [
     "garbage_collect()\n",
     "print_gpu_memory()\n",
-    "model = AutoModelForCausalLM.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\", cache_dir=\"/workspace/hf_cache\", load_in_8bit=True)\n",
+    "model = AutoModelForCausalLM.from_pretrained(\"meta-llama/Meta-Llama-3.1-8B-Instruct\", cache_dir=\"/workspace/hf_cache\", torch_dtype=torch.float16, load_in_8bit=True)\n",
     "print_gpu_memory()"
    ]
   }
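
The stderr output captured above warns that `load_in_8bit` is deprecated in favor of passing a `BitsAndBytesConfig` through `quantization_config`. A minimal sketch of that migration for the same notebook cell is below; it assumes the same model ID, cache directory, and dtype used in the diff, and is not part of the committed change.

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Non-deprecated equivalent of load_in_8bit=True: wrap the flag in a BitsAndBytesConfig.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    cache_dir="/workspace/hf_cache",          # same cache dir as the notebook cell
    torch_dtype=torch.float16,                # keep non-quantized modules in fp16, as in the diff
    quantization_config=quantization_config,  # replaces the deprecated load_in_8bit argument
)
```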