diff --git a/bootcamp/text_embeddings_tutorial.ipynb b/bootcamp/text_embeddings_tutorial.ipynb new file mode 100644 index 000000000..8b5dcbd92 --- /dev/null +++ b/bootcamp/text_embeddings_tutorial.ipynb @@ -0,0 +1,396 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "597347c3-f5a0-4fd5-a840-e23957536b3e", + "metadata": {}, + "source": [ + "# Text Embeddings with `sentence-transformers`" + ] + }, + { + "cell_type": "markdown", + "id": "cb9eff2b-f5df-4c9e-85a6-e39846749e75", + "metadata": {}, + "source": [ + "#### We'll start off by installing some dependencies: `sentence-transformers` for the models and `milvus` for the vector database. Milvus is known for its scalability and wide adoption among organiziations, but we have an \"embedded\" version too!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "e8973196-c967-40bf-9adf-e9ef2df55f59", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: sentence_transformers in /Users/zilliz/.pyenv/lib/python3.9/site-packages (2.2.2)\n", + "Requirement already satisfied: torchvision in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (0.16.2)\n", + "Requirement already satisfied: numpy in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (1.26.2)\n", + "Requirement already satisfied: sentencepiece in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (0.1.99)\n", + "Requirement already satisfied: scikit-learn in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (1.3.2)\n", + "Requirement already satisfied: torch>=1.6.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (2.1.2)\n", + "Requirement already satisfied: huggingface-hub>=0.4.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (0.19.4)\n", + "Requirement already satisfied: nltk in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (3.8.1)\n", + "Requirement already satisfied: tqdm in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (4.66.1)\n", + "Requirement already satisfied: transformers<5.0.0,>=4.6.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (4.36.1)\n", + "Requirement already satisfied: scipy in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sentence_transformers) (1.12.0)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2023.10.0)\n", + "Requirement already satisfied: requests in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (2.31.0)\n", + "Requirement already satisfied: packaging>=20.9 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (23.2)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (4.8.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (6.0.1)\n", + "Requirement already satisfied: filelock in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from huggingface-hub>=0.4.0->sentence_transformers) (3.13.1)\n", + "Requirement already satisfied: sympy in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from torch>=1.6.0->sentence_transformers) (1.12)\n", + "Requirement already satisfied: networkx in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from torch>=1.6.0->sentence_transformers) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from torch>=1.6.0->sentence_transformers) (3.1.2)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.4.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (2023.10.3)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from transformers<5.0.0,>=4.6.0->sentence_transformers) (0.15.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from jinja2->torch>=1.6.0->sentence_transformers) (2.1.3)\n", + "Requirement already satisfied: joblib in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from nltk->sentence_transformers) (1.3.2)\n", + "Requirement already satisfied: click in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from nltk->sentence_transformers) (8.1.7)\n", + "Requirement already satisfied: idna<4,>=2.5 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2.1.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (3.3.2)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from requests->huggingface-hub>=0.4.0->sentence_transformers) (2023.11.17)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from scikit-learn->sentence_transformers) (3.2.0)\n", + "Requirement already satisfied: mpmath>=0.19 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from sympy->torch>=1.6.0->sentence_transformers) (1.3.0)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /Users/zilliz/.pyenv/lib/python3.9/site-packages (from torchvision->sentence_transformers) (10.1.0)\n", + "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.3.2 is available.\n", + "You should consider upgrading via the '/Users/zilliz/.pyenv/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n", + "Requirement already satisfied: milvus in /Users/zilliz/.pyenv/lib/python3.9/site-packages (2.3.5)\n", + "\u001b[33mWARNING: You are using pip version 21.2.4; however, version 23.3.2 is available.\n", + "You should consider upgrading via the '/Users/zilliz/.pyenv/bin/python3 -m pip install --upgrade pip' command.\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install -U sentence_transformers\n", + "!pip install -U milvus" + ] + }, + { + "cell_type": "markdown", + "id": "0ef63d9e-898c-4167-a3a1-4268540ade60", + "metadata": {}, + "source": [ + "##### We'll go over the basics first: specifying a model and computing its embeddings." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "3162309c-c311-47e6-994f-b08817dbcb57", + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "from sentence_transformers import SentenceTransformer\n", + "model = SentenceTransformer(\"llmrails/ember-v1\")\n", + "model.max_seq_length = 1024" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "51182ae0-2834-41ff-b257-8178c83fff7d", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(1024,)" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "embedding_0 = model.encode(\"Zilliz is an awesome vector database.\")\n", + "embedding_0\n", + "embedding_0.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "7550b3a6-ff07-44cb-b757-fbe13a3266f4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tensor([[0.9391, 0.5762, 0.5002, 0.3822, 0.3003]])\n" + ] + } + ], + "source": [ + "from sentence_transformers.util import cos_sim\n", + "sentences = [\"Zilliz is a vector data store that is amazing.\",\n", + " \"Unstructured data can be semantically represented with embeddings.\",\n", + " \"Singular value decomposition factorizes the input matrix into three other matrices.\",\n", + " \"My favorite chess opening is the King's Gambit.\",\n", + " \"It doesn't matter if a cat is black or white, so long as it catches mice.\"]\n", + "embeddings = model.encode(sentences)\n", + "print(cos_sim(embedding_0, embeddings))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "72e43c27-7c1b-4065-ae2c-de8911aac593", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[1.0000]])" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cos_sim(model.encode(\"I like green eggs and ham.\"), model.encode(\"I like green eggs and ham.\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "2868a7f8-4fcd-4479-8475-a21fe94ce5d1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[0.8867]])" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cos_sim(model.encode(\"Let's eat, Chris.\"), model.encode(\"Let's eat Chris!\"))" + ] + }, + { + "cell_type": "markdown", + "id": "080e9943-9cda-4019-84cf-bc1386915ee2", + "metadata": {}, + "source": [ + "#### Now let's check out how to fine-tune our model." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "5c8c019a-3dfe-4528-aeea-f17ecdcef745", + "metadata": {}, + "outputs": [], + "source": [ + "from sentence_transformers import InputExample\n", + "train_examples = [\n", + " InputExample(texts=[\"Give me a quote on pragmatism.\", \"Whether the cat is black or white doesn't matter, so long as it catches mice.\"], label=1.0),\n", + " InputExample(texts=[\"Y Combinator's 7th birthday was March 11.\", \"As usual we were so busy we didn't notice till a few days after.\"], label=1.0)\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "0b2376ea-1a5a-4db4-a5ac-5f9d9f8754ea", + "metadata": {}, + "outputs": [], + "source": [ + "from sentence_transformers import losses\n", + "from torch.utils.data import DataLoader\n", + "train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=1)\n", + "train_loss = losses.CosineSimilarityLoss(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "67bdadea-f523-48b0-968c-ae2c1bbf6ea9", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b1076953670043509a1699051cfd043b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Epoch: 0%| | 0/1 [00:00