diff --git a/bertopic/backend/__init__.py b/bertopic/backend/__init__.py
index df123b8b..14de22b5 100644
--- a/bertopic/backend/__init__.py
+++ b/bertopic/backend/__init__.py
@@ -24,12 +24,20 @@
 msg = "`pip install bertopic[vision]` \n\n"
 MultiModalBackend = NotInstalled("Vision", "Vision", custom_msg=msg)
 
+# Model2Vec Embeddings
+try:
+    from bertopic.backend._model2vec import Model2VecBackend
+except ModuleNotFoundError:
+    msg = "`pip install model2vec` \n\n"
+    Model2VecBackend = NotInstalled("Model2Vec", "Model2Vec", custom_msg=msg)
+
 __all__ = [
     "BaseEmbedder",
     "WordDocEmbedder",
     "OpenAIBackend",
     "CohereBackend",
+    "Model2VecBackend",
     "MultiModalBackend",
     "languages",
 ]
diff --git a/bertopic/backend/_model2vec.py b/bertopic/backend/_model2vec.py
new file mode 100644
index 00000000..3e8a157b
--- /dev/null
+++ b/bertopic/backend/_model2vec.py
@@ -0,0 +1,129 @@
+import numpy as np
+from typing import List, Union
+from model2vec import StaticModel
+from sklearn.feature_extraction.text import CountVectorizer
+
+from bertopic.backend import BaseEmbedder
+
+
+class Model2VecBackend(BaseEmbedder):
+    """Model2Vec embedding model.
+
+    Arguments:
+        embedding_model: Either a model2vec model or a
+                         string pointing to a model2vec model
+        distill: Indicates whether to distill a sentence-transformers compatible model.
+                 The distillation will happen during fitting of the topic model.
+                 NOTE: Only works if `embedding_model` is a string.
+        distill_kwargs: Keyword arguments to pass to the distillation process
+                        of `model2vec.distill.distill`
+        distill_vectorizer: A CountVectorizer used for creating a custom vocabulary
+                            based on the same documents used for topic modeling.
+                            NOTE: If "vocabulary" is in `distill_kwargs`, this will be ignored.
+
+    Examples:
+    To create a model, you can load in a string pointing to a
+    model2vec model:
+
+    ```python
+    from bertopic.backend import Model2VecBackend
+
+    sentence_model = Model2VecBackend("minishlab/potion-base-8M")
+    ```
+
+    or you can instantiate a model yourself:
+
+    ```python
+    from bertopic.backend import Model2VecBackend
+    from model2vec import StaticModel
+
+    embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")
+    sentence_model = Model2VecBackend(embedding_model)
+    ```
+
+    If you want to distill a sentence-transformers model with the vocabulary of the documents,
+    run the following:
+
+    ```python
+    from bertopic.backend import Model2VecBackend
+
+    sentence_model = Model2VecBackend("sentence-transformers/all-MiniLM-L6-v2", distill=True)
+    ```
+    """
+
+    def __init__(
+        self,
+        embedding_model: Union[str, StaticModel],
+        distill: bool = False,
+        distill_kwargs: dict = None,
+        distill_vectorizer: CountVectorizer = None,
+    ):
+        super().__init__()
+
+        self.distill = distill
+        # Copy into a fresh dict so a caller's dict (or a shared default) is never mutated in `embed`
+        self.distill_kwargs = dict(distill_kwargs) if distill_kwargs is not None else {}
+        self.distill_vectorizer = distill_vectorizer
+        self._has_distilled = False
+
+        # When we distill, we need a string pointing to a sentence-transformer model
+        if self.distill:
+            self._check_model2vec_installation()
+            if not self.distill_vectorizer:
+                self.distill_vectorizer = CountVectorizer()
+            if isinstance(embedding_model, str):
+                self.embedding_model = embedding_model
+            else:
+                raise ValueError("Please pass a string pointing to a sentence-transformer model when distilling.")
+
+        # If we don't distill, we can pass a model2vec model directly or load from a string
+        elif isinstance(embedding_model, StaticModel):
+            self.embedding_model = embedding_model
+        elif isinstance(embedding_model, str):
+            self.embedding_model = StaticModel.from_pretrained(embedding_model)
+        else:
+            raise ValueError(
+                "Please select a correct Model2Vec model: \n"
+                "`from model2vec import StaticModel` \n"
+                "`model = StaticModel.from_pretrained('minishlab/potion-base-8M')`"
+            )
+
+    def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
+        """Embed a list of n documents/words into an n-dimensional
+        matrix of embeddings.
+
+        Arguments:
+            documents: A list of documents or words to be embedded
+            verbose: Controls the verbosity of the process
+
+        Returns:
+            Document/words embeddings with shape (n, m) with `n` documents/words
+            that each have an embedding size of `m`
+        """
+        # Distill the model
+        if self.distill and not self._has_distilled:
+            from model2vec.distill import distill
+
+            # Distill with the vocabulary of the documents (most frequent words first)
+            if not self.distill_kwargs.get("vocabulary"):
+                X = self.distill_vectorizer.fit_transform(documents)
+                word_counts = np.array(X.sum(axis=0)).flatten()
+                words = self.distill_vectorizer.get_feature_names_out()
+                vocabulary = [word for word, _ in sorted(zip(words, word_counts), key=lambda x: x[1], reverse=True)]
+                self.distill_kwargs["vocabulary"] = vocabulary
+
+            # Distill the model
+            self.embedding_model = distill(self.embedding_model, **self.distill_kwargs)
+
+            # Distillation should happen only once, on the entire vocabulary,
+            # and not for every subsequent call to `embed`
+            self._has_distilled = True
+
+        # Embed the documents
+        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
+        return embeddings
+
+    def _check_model2vec_installation(self):
+        try:
+            from model2vec.distill import distill  # noqa: F401
+        except ImportError:
+            raise ImportError("To distill a model using model2vec, you need to run `pip install model2vec[distill]`")
diff --git a/bertopic/backend/_sentencetransformers.py b/bertopic/backend/_sentencetransformers.py
index a54ad0ec..e82751ce 100644
--- a/bertopic/backend/_sentencetransformers.py
+++ b/bertopic/backend/_sentencetransformers.py
@@ -1,6 +1,7 @@
 import numpy as np
 from typing import List, Union
 from sentence_transformers import SentenceTransformer
+from sentence_transformers.models import StaticEmbedding
 
 from bertopic.backend import BaseEmbedder
 
@@ -13,6 +14,9 @@ class SentenceTransformerBackend(BaseEmbedder):
 
     Arguments:
         embedding_model: A sentence-transformers embedding model
+        model2vec: Indicates whether `embedding_model` is a model2vec model.
+                   NOTE: Only works if `embedding_model` is a string.
+                   Otherwise, you can pass the model2vec model directly to `embedding_model`.
 
     Examples:
     To create a model, you can load in a string pointing to a
@@ -25,6 +29,7 @@ class SentenceTransformerBackend(BaseEmbedder):
     ```
 
     or you can instantiate a model yourself:
+
     ```python
     from bertopic.backend import SentenceTransformerBackend
     from sentence_transformers import SentenceTransformer
@@ -32,13 +37,27 @@ class SentenceTransformerBackend(BaseEmbedder):
     embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
     sentence_model = SentenceTransformerBackend(embedding_model)
     ```
+
+    If you want to use a model2vec model through sentence-transformers,
+    you can pass the model name as a string and set `model2vec=True`:
+
+    ```python
+    from bertopic.backend import SentenceTransformerBackend
+
+    sentence_model = SentenceTransformerBackend("minishlab/potion-base-8M", model2vec=True)
+    ```
     """
 
-    def __init__(self, embedding_model: Union[str, SentenceTransformer]):
+    def __init__(self, embedding_model: Union[str, SentenceTransformer], model2vec: bool = False):
         super().__init__()
         self._hf_model = None
 
-        if isinstance(embedding_model, SentenceTransformer):
+        if model2vec and isinstance(embedding_model, str):
+            static_embedding = StaticEmbedding.from_model2vec(embedding_model)
+            self.embedding_model = SentenceTransformer(modules=[static_embedding])
+        elif isinstance(embedding_model, SentenceTransformer):
             self.embedding_model = embedding_model
         elif isinstance(embedding_model, str):
             self.embedding_model = SentenceTransformer(embedding_model)
diff --git a/bertopic/backend/_utils.py b/bertopic/backend/_utils.py
index 4190bd4e..79633a67 100644
--- a/bertopic/backend/_utils.py
+++ b/bertopic/backend/_utils.py
@@ -124,6 +124,12 @@ def select_backend(embedding_model, language: str = None, verbose: bool = False)
 
         return HFTransformerBackend(embedding_model)
 
+    # Model2Vec embeddings
+    if "model2vec" in str(type(embedding_model)):
+        from ._model2vec import Model2VecBackend
+
+        return Model2VecBackend(embedding_model)
+
     # Select embedding model based on language
     if language:
         try:
diff --git a/docs/getting_started/embeddings/embeddings.md b/docs/getting_started/embeddings/embeddings.md
index a6fcfb73..7c275af2 100644
--- a/docs/getting_started/embeddings/embeddings.md
+++ b/docs/getting_started/embeddings/embeddings.md
@@ -14,7 +14,7 @@ This modularity allows us not only to choose any embedding model to convert our
 When new state-of-the-art pre-trained embedding models are released, BERTopic will be able to use them. As a result, BERTopic grows with any new models being released.
 
 Out of the box, BERTopic supports several embedding techniques. In this section, we will go through several of them and how they can be implemented.
 
-### **Sentence Transformers**
+## **Sentence Transformers**
 
 You can select any model from sentence-transformers [here](https://www.sbert.net/docs/pretrained_models.html) and pass it through BERTopic with `embedding_model`:
@@ -47,7 +47,70 @@ topic_model = BERTopic(embedding_model=sentence_model)
 topic_model = BERTopic(embedding_model=embedding_model)
 ```
 
-### 🤗 Hugging Face Transformers
+## **Model2Vec**
+To use a blazingly fast [Model2Vec](https://github.com/MinishLab/model2vec) model, you first need to install model2vec:
+
+```
+pip install model2vec
+```
+
+Then, you can load in any of their models and pass it to BERTopic like so:
+
+```python
+from model2vec import StaticModel
+embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")
+
+topic_model = BERTopic(embedding_model=embedding_model)
+```
+
+### **Distillation**
+
+These models are extremely versatile and can be distilled from an existing embedding model (like those compatible with `sentence-transformers`).
+This distillation process doesn't require a vocabulary (as it uses the tokenizer's vocabulary) but can benefit from having one. Fortunately, this allows you to
+use the vocabulary from your input documents to distill a model yourself.
+
+Doing so requires you to install some additional dependencies of model2vec like so:
+
+```
+pip install model2vec[distill]
+```
+
+To then distill common embedding models, you need to import the `Model2VecBackend` from BERTopic:
+
+```python
+from bertopic.backend import Model2VecBackend
+
+# Choose a model to distill (a non-Model2Vec model)
+embedding_model = Model2VecBackend(
+    "sentence-transformers/all-MiniLM-L6-v2",
+    distill=True
+)
+
+topic_model = BERTopic(embedding_model=embedding_model)
+```
+
+You can also choose a custom vectorizer for creating the vocabulary and define custom arguments for the distillation process:
+
+```python
+from bertopic.backend import Model2VecBackend
+from sklearn.feature_extraction.text import CountVectorizer
+
+# Choose a model to distill (a non-Model2Vec model)
+embedding_model = Model2VecBackend(
+    "sentence-transformers/all-MiniLM-L6-v2",
+    distill=True,
+    distill_kwargs={"pca_dims": 256, "apply_zipf": True, "use_subword": True},
+    distill_vectorizer=CountVectorizer(ngram_range=(1, 3))
+)
+
+topic_model = BERTopic(embedding_model=embedding_model)
+```
+
+!!! tip "Tip!"
+    You can save the resulting model with `topic_model.embedding_model.embedding_model.save_pretrained("m2v_model")`.
+
+
+## **🤗 Hugging Face Transformers**
 
 To use a Hugging Face transformers model, load in a pipeline and point to any model found on their model hub (https://huggingface.co/models):
@@ -61,7 +124,7 @@ topic_model = BERTopic(embedding_model=embedding_model)
 
 !!! tip "Tip!"
     These transformers also work quite well using `sentence-transformers` which has great optimization tricks that make using it a bit faster.
 
-### **Flair**
+## **Flair**
 
 [Flair](https://github.com/flairNLP/flair) allows you to choose almost any embedding model that is publicly available. Flair can be used as follows:
@@ -87,7 +150,7 @@ document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding])
 
 topic_model = BERTopic(embedding_model=document_glove_embeddings)
 ```
 
-### **Spacy**
+## **Spacy**
 
 [Spacy](https://github.com/explosion/spaCy) is an amazing framework for processing text. There are many models available across many languages for modeling text.
@@ -128,7 +191,7 @@ require_gpu(0)
 topic_model = BERTopic(embedding_model=nlp)
 ```
 
-### **Universal Sentence Encoder (USE)**
+## **Universal Sentence Encoder (USE)**
 
 The Universal Sentence Encoder encodes text into high-dimensional vectors that are used here for embedding the documents. The model is trained and optimized for greater-than-word length text, such as sentences, phrases, or short paragraphs.
@@ -141,7 +204,7 @@ embedding_model = tensorflow_hub.load("https://tfhub.dev/google/universal-senten
 topic_model = BERTopic(embedding_model=embedding_model)
 ```
 
-### **Gensim**
+## **Gensim**
 
 BERTopic supports the `gensim.downloader` module, which allows it to download any word embedding model supported by Gensim. Typically, these are GloVe, Word2Vec, or FastText embeddings:
@@ -155,7 +218,7 @@ topic_model = BERTopic(embedding_model=ft)
 
 Gensim is primarily used for word embedding models. This typically works best for short documents since the word embeddings are pooled.
 
-### **Scikit-Learn Embeddings**
+## **Scikit-Learn Embeddings**
 
 Scikit-Learn is a framework for more than just machine learning. It offers many preprocessing tools, some of which can be used to create representations for text. Many of these tools are relatively lightweight and do not require a GPU.
@@ -187,7 +250,7 @@ topic_model = BERTopic(embedding_model=pipe)
     it does not support the `bertopic.representation` models.
 
-### OpenAI
+## **OpenAI**
 
 To use OpenAI's external API, we need to define our key and explicitly call `bertopic.backend.OpenAIBackend` to be used in our topic model:
@@ -202,7 +265,7 @@ topic_model = BERTopic(embedding_model=embedding_model)
 ```
 
-### Cohere
+## **Cohere**
 
 To use Cohere's external API, we need to define our key and explicitly call `bertopic.backend.CohereBackend` to be used in our topic model:
@@ -216,7 +279,7 @@ embedding_model = CohereBackend(client)
 topic_model = BERTopic(embedding_model=embedding_model)
 ```
 
-### Multimodal
+## **Multimodal**
 
 To create embeddings for both text and images in the same vector space, we can use the `MultiModalBackend`. This model uses a CLIP-ViT based model that is capable of embedding text, images, or both:
@@ -235,7 +298,7 @@ doc_image_embeddings = model.embed(docs, images)
 ```
 
-### **Custom Backend**
+## **Custom Backend**
 
 If your backend or model cannot be found in the ones currently available, you can use the `bertopic.backend.BaseEmbedder` class to create your own backend. Below, you will find an example of creating a SentenceTransformer backend for BERTopic:
@@ -260,7 +323,7 @@ custom_embedder = CustomEmbedder(embedding_model=embedding_model)
 topic_model = BERTopic(embedding_model=custom_embedder)
 ```
 
-### **Custom Embeddings**
+## **Custom Embeddings**
 
 The base models in BERTopic are BERT-based models that work well with document similarity tasks. Your documents, however, might be too specific for a general pre-trained model to be used. Fortunately, you can use the embedding model in BERTopic to create document features.
@@ -283,7 +346,7 @@ topics, probs = topic_model.fit_transform(docs, embeddings)
 
 As you can see above, we used a SentenceTransformer model to create the embeddings. You could also have used `🤗 transformers`, `Doc2Vec`, or any other embedding method.
 
-#### **TF-IDF**
+### **TF-IDF**
 
 As mentioned above, any embedding technique can be used. However, when running UMAP, the typical distance metric is `cosine` which does not work quite well for a TF-IDF matrix.
 Instead, BERTopic will recognize that a sparse matrix is passed and switch to `hellinger`, a distance measure that works quite well for the similarity between probability distributions.
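+
+As a minimal sketch of what this can look like (the 20 newsgroups dataset is purely a stand-in corpus; any list of strings works):
+
+```python
+from sklearn.datasets import fetch_20newsgroups
+from sklearn.feature_extraction.text import TfidfVectorizer
+from bertopic import BERTopic
+
+# Build a sparse TF-IDF matrix to serve as precomputed embeddings
+docs = fetch_20newsgroups(subset="all", remove=("headers", "footers", "quotes"))["data"]
+vectorizer = TfidfVectorizer(min_df=5)
+embeddings = vectorizer.fit_transform(docs)
+
+# BERTopic recognizes the sparse matrix and switches UMAP to `hellinger`
+topic_model = BERTopic()
+topics, probs = topic_model.fit_transform(docs, embeddings)
+```
+
+This keeps the entire pipeline lightweight: no GPU or pre-trained language model is needed to create the document representations.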