Add Model2Vec as an embedding backend #2245

Open
wants to merge 1 commit into
base: master
8 changes: 8 additions & 0 deletions bertopic/backend/__init__.py
@@ -24,12 +24,20 @@
    msg = "`pip install bertopic[vision]` \n\n"
    MultiModalBackend = NotInstalled("Vision", "Vision", custom_msg=msg)

# Model2Vec Embeddings
try:
    from bertopic.backend._model2vec import Model2VecBackend
except ModuleNotFoundError:
    msg = "`pip install model2vec` \n\n"
    Model2VecBackend = NotInstalled("Model2Vec", "Model2Vec", custom_msg=msg)


__all__ = [
    "BaseEmbedder",
    "WordDocEmbedder",
    "OpenAIBackend",
    "CohereBackend",
    "Model2VecBackend",
    "MultiModalBackend",
    "languages",
]
129 changes: 129 additions & 0 deletions bertopic/backend/_model2vec.py
@@ -0,0 +1,129 @@
import numpy as np
from typing import List, Optional, Union
from model2vec import StaticModel
from sklearn.feature_extraction.text import CountVectorizer

from bertopic.backend import BaseEmbedder


class Model2VecBackend(BaseEmbedder):
    """Model2Vec embedding model.

    Arguments:
        embedding_model: Either a model2vec model or a
                         string pointing to a model2vec model
        distill: Indicates whether to distill a sentence-transformers compatible model.
                 The distillation will happen during fitting of the topic model.
                 NOTE: Only works if `embedding_model` is a string.
        distill_kwargs: Keyword arguments to pass to the distillation process
                        of `model2vec.distill.distill`
        distill_vectorizer: A CountVectorizer used for creating a custom vocabulary
                            based on the same documents used for topic modeling.
                            NOTE: If "vocabulary" is in `distill_kwargs`, this will be ignored.

    Examples:
    To create a model, you can load in a string pointing to a
    model2vec model:

    ```python
    from bertopic.backend import Model2VecBackend

    sentence_model = Model2VecBackend("minishlab/potion-base-8M")
    ```

    or you can instantiate a model yourself:

    ```python
    from bertopic.backend import Model2VecBackend
    from model2vec import StaticModel

    embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")
    sentence_model = Model2VecBackend(embedding_model)
    ```

    If you want to distill a sentence-transformers model with the vocabulary of the documents,
    run the following:

    ```python
    from bertopic.backend import Model2VecBackend

    sentence_model = Model2VecBackend("sentence-transformers/all-MiniLM-L6-v2", distill=True)
    ```
    """

    def __init__(
        self,
        embedding_model: Union[str, StaticModel],
        distill: bool = False,
        distill_kwargs: Optional[dict] = None,
        distill_vectorizer: Optional[CountVectorizer] = None,
    ):
        super().__init__()

        self.distill = distill
        # Copy the kwargs: `embed` adds a "vocabulary" key, and neither a shared
        # default dict nor the caller's dict should be mutated by that
        self.distill_kwargs = dict(distill_kwargs) if distill_kwargs else {}
        self.distill_vectorizer = distill_vectorizer
        self._has_distilled = False

        # When we distill, we need a string pointing to a sentence-transformer model
        if self.distill:
            self._check_model2vec_installation()
            if self.distill_vectorizer is None:
                self.distill_vectorizer = CountVectorizer()
            if isinstance(embedding_model, str):
                self.embedding_model = embedding_model
            else:
                raise ValueError("Please pass a string pointing to a sentence-transformer model when distilling.")

        # If we don't distill, we can pass a model2vec model directly or load from a string
        elif isinstance(embedding_model, StaticModel):
            self.embedding_model = embedding_model
        elif isinstance(embedding_model, str):
            self.embedding_model = StaticModel.from_pretrained(embedding_model)
        else:
            raise ValueError(
                "Please select a correct Model2Vec model: \n"
                "`from model2vec import StaticModel` \n"
                "`model = StaticModel.from_pretrained('minishlab/potion-base-8M')`"
            )

    def embed(self, documents: List[str], verbose: bool = False) -> np.ndarray:
        """Embed a list of n documents/words into an n-dimensional
        matrix of embeddings.

        Arguments:
            documents: A list of documents or words to be embedded
            verbose: Controls the verbosity of the process

        Returns:
            Document/words embeddings with shape (n, m) with `n` documents/words
            that each have an embeddings size of `m`
        """
        # Distill the model
        if self.distill and not self._has_distilled:
            from model2vec.distill import distill

            # Distill with the vocabulary of the documents, most frequent words first
            if not self.distill_kwargs.get("vocabulary"):
                X = self.distill_vectorizer.fit_transform(documents)
                word_counts = np.array(X.sum(axis=0)).flatten()
                words = self.distill_vectorizer.get_feature_names_out()
                vocabulary = [word for word, _ in sorted(zip(words, word_counts), key=lambda x: x[1], reverse=True)]
                self.distill_kwargs["vocabulary"] = vocabulary

            # Distill the model
            self.embedding_model = distill(self.embedding_model, **self.distill_kwargs)

            # Distillation should happen only once and not for every embed call
            # The distillation should only happen the first time on the entire vocabulary
            self._has_distilled = True

        # Embed the documents
        embeddings = self.embedding_model.encode(documents, show_progress_bar=verbose)
        return embeddings

    def _check_model2vec_installation(self):
        try:
            from model2vec.distill import distill  # noqa: F401
        except ImportError:
            raise ImportError("To distill a model using model2vec, you need to run `pip install model2vec[distill]`")
23 changes: 21 additions & 2 deletions bertopic/backend/_sentencetransformers.py
@@ -1,6 +1,7 @@
import numpy as np
from typing import List, Union
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding

from bertopic.backend import BaseEmbedder

@@ -13,6 +14,9 @@ class SentenceTransformerBackend(BaseEmbedder):

Arguments:
embedding_model: A sentence-transformers embedding model
model2vec: Indicates whether `embedding_model` is a model2vec model.
NOTE: Only works if `embedding_model` is a string.
Otherwise, you can pass the model2vec model directly to `embedding_model`.

Examples:
To create a model, you can load in a string pointing to a
@@ -25,20 +29,35 @@ class SentenceTransformerBackend(BaseEmbedder):
```

or you can instantiate a model yourself:

```python
from bertopic.backend import SentenceTransformerBackend
from sentence_transformers import SentenceTransformer

embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
sentence_model = SentenceTransformerBackend(embedding_model)
```

If you want to use a model2vec model without having to install model2vec,
you can pass the model2vec model as a string and set `model2vec=True`:

```python
from bertopic.backend import SentenceTransformerBackend

sentence_model = SentenceTransformerBackend("minishlab/potion-base-8M", model2vec=True)
```
"""

-    def __init__(self, embedding_model: Union[str, SentenceTransformer]):
+    def __init__(self, embedding_model: Union[str, SentenceTransformer], model2vec: bool = False):
         super().__init__()

         self._hf_model = None
-        if isinstance(embedding_model, SentenceTransformer):
+        if model2vec and isinstance(embedding_model, str):
+            static_embedding = StaticEmbedding.from_model2vec(embedding_model)
+            self.embedding_model = SentenceTransformer(modules=[static_embedding])
+        elif isinstance(embedding_model, SentenceTransformer):
             self.embedding_model = embedding_model
         elif isinstance(embedding_model, str):
             self.embedding_model = SentenceTransformer(embedding_model)
6 changes: 6 additions & 0 deletions bertopic/backend/_utils.py
@@ -124,6 +124,12 @@ def select_backend(embedding_model, language: str = None, verbose: bool = False)

        return HFTransformerBackend(embedding_model)

    # Model2Vec embeddings
    if "model2vec" in str(type(embedding_model)):
        from ._model2vec import Model2VecBackend

        return Model2VecBackend(embedding_model)

    # Select embedding model based on language
    if language:
        try:
89 changes: 76 additions & 13 deletions docs/getting_started/embeddings/embeddings.md
@@ -14,7 +14,7 @@ This modularity allows us not only to choose any embedding model to convert our
When new state-of-the-art pre-trained embedding models are released, BERTopic will be able to use them. As a result, BERTopic grows with any new models being released.
Out of the box, BERTopic supports several embedding techniques. In this section, we will go through several of them and how they can be implemented.

### **Sentence Transformers**
## **Sentence Transformers**
You can select any model from sentence-transformers [here](https://www.sbert.net/docs/pretrained_models.html)
and pass it through BERTopic with `embedding_model`:

@@ -47,7 +47,70 @@ topic_model = BERTopic(embedding_model=sentence_model)
topic_model = BERTopic(embedding_model=embedding_model)
```

### 🤗 Hugging Face Transformers
## **Model2Vec**
To use a fast, static [Model2Vec](https://github.com/MinishLab/model2vec) embedding model, first install model2vec:

```
pip install model2vec
```

Then, you can load in any of their models and pass it to BERTopic like so:

```python
from bertopic import BERTopic
from model2vec import StaticModel

embedding_model = StaticModel.from_pretrained("minishlab/potion-base-8M")

topic_model = BERTopic(embedding_model=embedding_model)
```
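
When a `StaticModel` is passed directly like this, the `select_backend` change in this pull request chooses `Model2VecBackend` by checking whether `"model2vec"` appears in `str(type(embedding_model))`. The snippet below is a minimal sketch of how such a check behaves. `FakeStaticModel` and `looks_like_model2vec` are hypothetical names made up for illustration; they are not part of BERTopic or model2vec.

```python
class FakeStaticModel:
    """Hypothetical stand-in for `model2vec.StaticModel` (assumption for illustration)."""

    # Pretend the class lives in the model2vec package, so that the string
    # form of its type resembles that of a real model2vec model.
    __module__ = "model2vec.model"


def looks_like_model2vec(embedding_model) -> bool:
    """Mirror the substring check on the string form of the object's type."""
    return "model2vec" in str(type(embedding_model))


print(str(type(FakeStaticModel())))
print(looks_like_model2vec(FakeStaticModel()))
print(looks_like_model2vec("minishlab/potion-base-8M"))
```

A plain string does not match this check, because its type is `str` regardless of what the string contains. A model name given as a string is therefore not routed by this check and needs an explicit `Model2VecBackend(...)`.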

### **Distillation**

These models are versatile and can be distilled from an existing embedding model (such as those compatible with `sentence-transformers`).
Distillation does not require a vocabulary, since it can use the tokenizer's vocabulary, but it can benefit from having one. This means you can
distill a model yourself using the vocabulary of your input documents.
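
To make the idea of a document vocabulary concrete, here is a dependency-free sketch that collects the unique tokens of a corpus and orders them from most to least frequent. The backend in this pull request does this with a scikit-learn `CountVectorizer`; the regular expression below only approximates its default tokenization, and `frequency_sorted_vocabulary` is a hypothetical helper name.

```python
import re
from collections import Counter
from typing import List


def frequency_sorted_vocabulary(documents: List[str]) -> List[str]:
    """Return the unique tokens of `documents`, most frequent first."""
    counts = Counter()
    for document in documents:
        # Lowercase and keep tokens of two or more word characters,
        # roughly what CountVectorizer does with its default settings.
        counts.update(re.findall(r"\b\w\w+\b", document.lower()))
    # Sort by descending count; break ties alphabetically for determinism.
    return sorted(counts, key=lambda word: (-counts[word], word))


docs = ["the cat sat", "the cat ran", "a dog"]
vocabulary = frequency_sorted_vocabulary(docs)
print(vocabulary)
```

A list like this could be supplied as `distill_kwargs={"vocabulary": vocabulary}`, in which case the backend's docstring says the `distill_vectorizer` is ignored.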

Doing so requires you to install some additional dependencies of model2vec like so:

```
pip install model2vec[distill]
```

To then distill common embedding models, you need to import the `Model2VecBackend` from BERTopic:

```python
from bertopic import BERTopic
from bertopic.backend import Model2VecBackend

# Choose a model to distill (a non-Model2Vec model)
embedding_model = Model2VecBackend(
    "sentence-transformers/all-MiniLM-L6-v2",
    distill=True
)

topic_model = BERTopic(embedding_model=embedding_model)
```
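
Note that, according to the backend's docstring, distillation does not happen when the backend is constructed but during fitting, on the first `embed` call, and only once. The pattern can be sketched without any model at all. Everything below (`LazyDistillEmbedder`, `toy_build`) is a hypothetical toy that illustrates the "build once on first use" idea; it is not the model2vec API.

```python
from typing import Callable, List


class LazyDistillEmbedder:
    """Hypothetical stand-in that defers an expensive build step to the first `embed` call."""

    def __init__(self, build_model: Callable[[List[str]], Callable[[str], List[float]]]):
        self._build_model = build_model
        self._model = None
        self.build_calls = 0

    def embed(self, documents: List[str]) -> List[List[float]]:
        if self._model is None:
            # Runs exactly once, using the documents of the first call
            self._model = self._build_model(documents)
            self.build_calls += 1
        return [self._model(document) for document in documents]


def toy_build(documents: List[str]) -> Callable[[str], List[float]]:
    """Toy 'distillation': fix the vocabulary seen in the first batch."""
    vocabulary = sorted({word for document in documents for word in document.split()})
    return lambda document: [float(document.split().count(word)) for word in vocabulary]


embedder = LazyDistillEmbedder(toy_build)
first = embedder.embed(["a b", "b c"])
second = embedder.embed(["c c"])
print(first, second, embedder.build_calls)
```

In this toy version the vocabulary is fixed by the first batch, which is why the first `embed` call should see the full set of documents.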

You can also choose a custom vectorizer for creating the vocabulary and define custom arguments for the distillation process:

```python
from bertopic import BERTopic
from bertopic.backend import Model2VecBackend
from sklearn.feature_extraction.text import CountVectorizer

# Choose a model to distill (a non-Model2Vec model)
embedding_model = Model2VecBackend(
    "sentence-transformers/all-MiniLM-L6-v2",
    distill=True,
    distill_kwargs={"pca_dims": 256, "apply_zipf": True, "use_subword": True},
    distill_vectorizer=CountVectorizer(ngram_range=(1, 3))
)

topic_model = BERTopic(embedding_model=embedding_model)
```

!!! tip "Tip!"
    You can save the resulting model with `topic_model.embedding_model.embedding_model.save_pretrained("m2v_model")`.


## **🤗 Hugging Face Transformers**
To use a Hugging Face transformers model, load in a pipeline and point
to any model found on their model hub (https://huggingface.co/models):

@@ -61,7 +124,7 @@ topic_model = BERTopic(embedding_model=embedding_model)
!!! tip "Tip!"
    These transformers also work quite well with `sentence-transformers`, which has optimization tricks that make using them a bit faster.

### **Flair**
## **Flair**
[Flair](https://github.com/flairNLP/flair) allows you to choose almost any embedding model that
is publicly available. Flair can be used as follows:

@@ -87,7 +150,7 @@ document_glove_embeddings = DocumentPoolEmbeddings([glove_embedding])
topic_model = BERTopic(embedding_model=document_glove_embeddings)
```

### **Spacy**
## **Spacy**
[Spacy](https://github.com/explosion/spaCy) is an amazing framework for processing text. There are
many models available across many languages for modeling text.

@@ -128,7 +191,7 @@ require_gpu(0)
topic_model = BERTopic(embedding_model=nlp)
```

### **Universal Sentence Encoder (USE)**
## **Universal Sentence Encoder (USE)**
The Universal Sentence Encoder encodes text into high-dimensional vectors that are used here
for embedding the documents. The model is trained and optimized for greater-than-word length text,
such as sentences, phrases, or short paragraphs.
@@ -141,7 +204,7 @@ embedding_model = tensorflow_hub.load("https://tfhub.dev/google/universal-senten
topic_model = BERTopic(embedding_model=embedding_model)
```

### **Gensim**
## **Gensim**
BERTopic supports the `gensim.downloader` module, which allows it to download any word embedding model supported by Gensim.
Typically, these are Glove, Word2Vec, or FastText embeddings:

@@ -155,7 +218,7 @@ topic_model = BERTopic(embedding_model=ft)
Gensim is primarily used for Word Embedding models. This works typically best for short documents since the word embeddings are pooled.


### **Scikit-Learn Embeddings**
## **Scikit-Learn Embeddings**
Scikit-Learn is a framework for more than just machine learning.
It offers many preprocessing tools, some of which can be used to create representations
for text. Many of these tools are relatively lightweight and do not require a GPU.
@@ -187,7 +250,7 @@ topic_model = BERTopic(embedding_model=pipe)
it does not support the `bertopic.representation` models.


### OpenAI
## **OpenAI**
To use OpenAI's external API, we need to define our key and explicitly call `bertopic.backend.OpenAIBackend`
to be used in our topic model:

@@ -202,7 +265,7 @@ topic_model = BERTopic(embedding_model=embedding_model)
```


### Cohere
## **Cohere**
To use Cohere's external API, we need to define our key and explicitly call `bertopic.backend.CohereBackend`
to be used in our topic model:

@@ -216,7 +279,7 @@ embedding_model = CohereBackend(client)
topic_model = BERTopic(embedding_model=embedding_model)
```

### Multimodal
## **Multimodal**
To create embeddings for both text and images in the same vector space, we can use the `MultiModalBackend`.
This model uses a clip-vit based model that is capable of embedding text, images, or both:

@@ -235,7 +298,7 @@ doc_image_embeddings = model.embed(docs, images)
```


### **Custom Backend**
## **Custom Backend**
If your backend or model cannot be found in the ones currently available, you can use the `bertopic.backend.BaseEmbedder` class to
create your backend. Below, you will find an example of creating a SentenceTransformer backend for BERTopic:

@@ -260,7 +323,7 @@ custom_embedder = CustomEmbedder(embedding_model=embedding_model)
topic_model = BERTopic(embedding_model=custom_embedder)
```

### **Custom Embeddings**
## **Custom Embeddings**
The base models in BERTopic are BERT-based models that work well with document similarity tasks. Your documents,
however, might be too specific for a general pre-trained model to be used. Fortunately, you can use the embedding
model in BERTopic to create document features.
@@ -283,7 +346,7 @@ topics, probs = topic_model.fit_transform(docs, embeddings)
As you can see above, we used a SentenceTransformer model to create the embedding. You could also have used
`🤗 transformers`, `Doc2Vec`, or any other embedding method.

#### **TF-IDF**
### **TF-IDF**
As mentioned above, any embedding technique can be used. However, when running UMAP, the typical distance metric is
`cosine` which does not work quite well for a TF-IDF matrix. Instead, BERTopic will recognize that a sparse matrix
is passed and use `hellinger` instead which works quite well for the similarity between probability distributions.
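
For intuition, the textbook Hellinger distance between two discrete probability distributions can be computed as below. This is a sketch of the standard definition only; the implementation used inside UMAP may normalize its inputs differently.

```python
import math
from typing import Sequence


def hellinger(p: Sequence[float], q: Sequence[float]) -> float:
    """Hellinger distance between two discrete probability distributions."""
    if len(p) != len(q):
        raise ValueError("p and q must have the same length")
    # Sum of squared differences between the square roots of the probabilities
    squared = sum((math.sqrt(a) - math.sqrt(b)) ** 2 for a, b in zip(p, q))
    return math.sqrt(squared) / math.sqrt(2)


same = hellinger([0.5, 0.5], [0.5, 0.5])
disjoint = hellinger([1.0, 0.0], [0.0, 1.0])
print(same, disjoint)
```

Identical distributions are at distance 0 and distributions with no overlap are at distance 1, so the measure stays bounded regardless of how many terms the TF-IDF matrix has.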