diff --git a/notebooks/unslothai_test.ipynb b/notebooks/unslothai_test.ipynb
new file mode 100644
index 0000000..d09517b
--- /dev/null
+++ b/notebooks/unslothai_test.ipynb
@@ -0,0 +1,346 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.\n"
+     ]
+    }
+   ],
+   "source": [
+    "from unsloth import FastLanguageModel"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "==((====))==  Unsloth 2024.10.7: Fast Llama patching. Transformers = 4.44.2.\n",
+      "   \\\\   /|    GPU: NVIDIA GeForce RTX 4070 SUPER. Max memory: 11.72 GB. Platform = Linux.\n",
+      "O^O/ \\_/ \\    Pytorch: 2.5.0+cu124. CUDA = 8.9. CUDA Toolkit = 12.4.\n",
+      "\\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post2. FA2 = False]\n",
+      " \"-____-\"     Free Apache license: http://github.com/unslothai/unsloth\n"
+     ]
+    },
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "ee1f70c1a44d47ac9514512c44c84bc8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "model.safetensors:   0%|          | 0.00/5.70G [00:00<?, ?B/s]"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "<string>:209: SyntaxWarning: invalid escape sequence '\\ '\n",
+      "<string>:210: SyntaxWarning: invalid escape sequence '\\_'\n",
+      "<string>:211: SyntaxWarning: invalid escape sequence '\\ '\n",
+      "<string>:209: SyntaxWarning: invalid escape sequence '\\ '\n",
+      "<string>:210: SyntaxWarning: invalid escape sequence '\\_'\n",
+      "<string>:211: SyntaxWarning: invalid escape sequence '\\ '\n",
+      "Unsloth: We fixed a gradient accumulation bug, but it seems like you don't have the latest transformers version!\n",
+      "Please update transformers, TRL and unsloth via:\n",
+      "`pip install --upgrade --no-cache-dir unsloth git+https://github.com/huggingface/transformers.git git+https://github.com/huggingface/trl.git`\n"
+     ]
+    }
+   ],
+   "source": [
+    "model, tokenizer = FastLanguageModel.from_pretrained(\n",
+    "    model_name=\"unsloth/Meta-Llama-3.1-8B-Instruct\", max_seq_length=8192, dtype=None,\n",
+    "    load_in_4bit=True)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "LlamaForCausalLM(\n",
+       "  (model): LlamaModel(\n",
+       "    (embed_tokens): Embedding(128256, 4096)\n",
+       "    (layers): ModuleList(\n",
+       "      (0-31): 32 x LlamaDecoderLayer(\n",
+       "        (self_attn): LlamaAttention(\n",
+       "          (q_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
+       "          (k_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)\n",
+       "          (v_proj): Linear4bit(in_features=4096, out_features=1024, bias=False)\n",
+       "          (o_proj): Linear4bit(in_features=4096, out_features=4096, bias=False)\n",
+       "          (rotary_emb): LlamaExtendedRotaryEmbedding()\n",
+       "        )\n",
+       "        (mlp): LlamaMLP(\n",
+       "          (gate_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)\n",
+       "          (up_proj): Linear4bit(in_features=4096, out_features=14336, bias=False)\n",
+       "          (down_proj): Linear4bit(in_features=14336, out_features=4096, bias=False)\n",
+       "          (act_fn): SiLU()\n",
+       "        )\n",
+       "        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
+       "        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)\n",
+       "      )\n",
+       "    )\n",
+       "    (norm): LlamaRMSNorm((4096,), eps=1e-05)\n",
+       "    (rotary_emb): LlamaRotaryEmbedding()\n",
+       "  )\n",
+       "  (lm_head): Linear(in_features=4096, out_features=128256, bias=False)\n",
+       ")"
+      ]
+     },
+     "execution_count": 3,
"metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "FastLanguageModel.for_inference(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from transformers import TextStreamer\n", + "from unsloth.chat_templates import get_chat_template\n", + "tokenizer = get_chat_template(tokenizer, chat_template=\"llama-3.1\", \n", + " mapping = {\"role\" : \"from\", \n", + " \"content\" : \"value\", \n", + " \"user\" : \"human\", \n", + " \"assistant\" : \"gpt\"},)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "msgs = [\n", + " {\"from\": \"human\", \"value\": \"What would a raspberry say to a pear if it could talk?\"}\n", + "]\n", + "template = tokenizer.apply_chat_template(msgs, tokenize=True, add_generation_prompt=True, \n", + " return_tensors=\"pt\")\n", + "inputs = template.to(\"cuda\")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[128000, 128006, 9125, 128007, 271, 38766, 1303, 33025, 2696,\n", + " 25, 6790, 220, 2366, 18, 198, 15724, 2696, 25,\n", + " 220, 1627, 5887, 220, 2366, 19, 271, 128009, 128006,\n", + " 26380, 128007, 271, 3923, 1053, 264, 94802, 2019, 311,\n", + " 264, 38790, 422, 433, 1436, 3137, 30, 128009, 128006,\n", + " 78191, 128007, 271]])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "template" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "text_streamer = TextStreamer(tokenizer)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n", + "\n", + "Cutting Knowledge Date: December 2023\n", + "Today Date: 26 July 2024\n", + "\n", + "<|eot_id|><|start_header_id|>human<|end_header_id|>\n", + "\n", + "What would a raspberry say to a pear if it could talk?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n", + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "That's a fun one. Here's a possible conversation:\n", + "\n", + "Raspberry: \"Hey there, Pear! You're looking lovely today. I must say, your smooth skin is quite the envy of us berries.\"\n", + "\n", + "Pear: \"Ah, thank you, Raspberry! You're looking quite tart and juicy yourself. I've always admired your vibrant color and the way you add a burst of flavor to desserts.\"\n", + "\n", + "Raspberry: \"Ha! Yes, we berries have a certain charm to us, don't we? But I must say, your sweetness is quite alluring. I've always wondered, what's it like being a pear?\"\n", + "\n", + "Pear: \"Well, it's a bit more laid-back than being a raspberry, I suppose. I get to ripen slowly and enjoy the sunshine, whereas you berries are often picked quickly and turned into jams and preserves.\"\n", + "\n", + "Raspberry: \"I see what you mean. It's a trade-off, isn't it? But I think our roles are complementary. Without us berries, your sweetness might get lost in the mix. And without your sweetness, our tartness might be too overwhelming.\"\n", + "\n", + "Pear: \"Exactly! We're both essential in our own ways. And who knows, maybe one day we'll be paired together in a delicious pie or tart.\"\n", + "\n", + "Raspberry: \"Now that's a thought I can get behind! 
+      "Raspberry: \"Now that's a thought I can get behind! Who knows what culinary wonders we might create together?\"\n",
+      "\n",
+      "And so the conversation continues, a delightful exchange between two fruits that might seem worlds apart but share a common bond in the world of fruit.<|eot_id|>\n"
+     ]
+    }
+   ],
+   "source": [
+    "response = model.generate(input_ids=inputs, streamer=text_streamer, max_new_tokens=1024)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([[128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,   2696,\n",
+       "             25,   6790,    220,   2366,     18,    198,  15724,   2696,     25,\n",
+       "            220,   1627,   5887,    220,   2366,     19,    271, 128009, 128006,\n",
+       "          26380, 128007,    271,   3923,   1053,    264,  94802,   2019,    311,\n",
+       "            264,  38790,    422,    433,   1436,   3137,     30, 128009, 128006,\n",
+       "          78191, 128007,    271,   3923,    264,  50189,    323,  77361,   3488,\n",
+       "           2268,   2746,    264,  94802,   1436,   3137,     11,    433,   2643,\n",
+       "           2019,    311,    264,  38790,   1473,  55005,     11,  23910,     11,\n",
+       "            499,   2351,    779,  11113,    323,  27877,     13,    358,   2846,\n",
+       "            264,   2766,  45915,    323,  89800,    398,     11,    719,    358,\n",
+       "           1093,    311,   1781,    358,    923,    264,  21165,    315,  28361,\n",
+       "            311,    904,   6671,     13,   1226,   2351,   1093,    379,    258,\n",
+       "            323,  10587,     11,   7784,    956,    584,     30,   1472,   2351,\n",
+       "            279,  19858,     11,  22443,    832,     11,   1418,    358,   2846,\n",
+       "            279,  49277,     11,  68188,    832,     13,   2030,   8994,   1057,\n",
+       "          12062,     11,    358,   1781,    584,  23606,   1855,   1023,  14268,\n",
+       "             13,   1472,   4546,    704,    279,  64550,    304,    757,     11,\n",
+       "            323,    358,   4546,    704,    279,  42786,    304,    499,     13,\n",
+       "           6914,    596,    387,   4885,    323,   1893,    264,  18406,  26348,\n",
+       "           3871,   9135, 128009]], device='cuda:0')"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "response"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyproject.toml b/pyproject.toml
index 3dda280..8be2bbf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -17,6 +17,7 @@ dependencies = [
     "ragas == 0.1.10",
     "nltk == 3.9.1",
     "nbformat == 4.2.0",
+    "unsloth == 2024.10.7",
 ]
 
 [project.optional-dependencies]