From 7471363443a35d69ab3ac23e17329229061b6e22 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Wed, 13 Nov 2024 13:29:25 +0100 Subject: [PATCH 1/5] Add LiteLLM as a representation model --- bertopic/cluster/_base.py | 2 +- bertopic/representation/__init__.py | 8 + bertopic/representation/_litellm.py | 176 ++++++++++++++++++ docs/algorithm/algorithm.md | 7 +- docs/api/backends.md | 3 + docs/api/backends/base.md | 3 - docs/api/backends/cohere.md | 3 - docs/api/backends/openai.md | 3 - docs/api/backends/word_doc.md | 3 - docs/api/{cluster/base.md => cluster copy.md} | 0 docs/api/cluster.md | 3 + .../base.md => dimensionality.md} | 0 docs/api/onlinecv.md | 3 - docs/api/plotting.md | 3 + docs/api/representation/base.md | 3 - docs/api/representation/cohere.md | 3 - docs/api/representation/generation.md | 3 - docs/api/representation/keybert.md | 3 - docs/api/representation/langchain.md | 3 - docs/api/representation/mmr.md | 3 - docs/api/representation/openai.md | 3 - docs/api/representation/pos.md | 3 - docs/api/representation/zeroshot.md | 3 - docs/api/representations.md | 3 + docs/api/vectorizers.md | 3 + docs/getting_started/representation/llm.md | 29 +++ mkdocs.yml | 29 +-- 27 files changed, 241 insertions(+), 67 deletions(-) create mode 100644 bertopic/representation/_litellm.py create mode 100644 docs/api/backends.md delete mode 100644 docs/api/backends/base.md delete mode 100644 docs/api/backends/cohere.md delete mode 100644 docs/api/backends/openai.md delete mode 100644 docs/api/backends/word_doc.md rename docs/api/{cluster/base.md => cluster copy.md} (100%) create mode 100644 docs/api/cluster.md rename docs/api/{dimensionality/base.md => dimensionality.md} (100%) delete mode 100644 docs/api/onlinecv.md create mode 100644 docs/api/plotting.md delete mode 100644 docs/api/representation/base.md delete mode 100644 docs/api/representation/cohere.md delete mode 100644 docs/api/representation/generation.md delete mode 100644 docs/api/representation/keybert.md delete mode 100644 docs/api/representation/langchain.md delete mode 100644 docs/api/representation/mmr.md delete mode 100644 docs/api/representation/openai.md delete mode 100644 docs/api/representation/pos.md delete mode 100644 docs/api/representation/zeroshot.md create mode 100644 docs/api/representations.md create mode 100644 docs/api/vectorizers.md diff --git a/bertopic/cluster/_base.py b/bertopic/cluster/_base.py index a096d99e..7212b9bd 100644 --- a/bertopic/cluster/_base.py +++ b/bertopic/cluster/_base.py @@ -14,7 +14,7 @@ class BaseCluster: ```python from bertopic import BERTopic - from bertopic.dimensionality import BaseCluster + from bertopic.cluster import BaseCluster empty_cluster_model = BaseCluster() diff --git a/bertopic/representation/__init__.py b/bertopic/representation/__init__.py index da0c6365..3a2a22c5 100644 --- a/bertopic/representation/__init__.py +++ b/bertopic/representation/__init__.py @@ -33,6 +33,13 @@ msg = "`pip install openai` \n\n" OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg) +# LiteLLM Generator +try: + from bertopic.representation._litellm import LiteLLM +except ModuleNotFoundError: + msg = "`pip install litellm` \n\n" + OpenAI = NotInstalled("LiteLLM", "litellm", custom_msg=msg) + # LangChain Generator try: from bertopic.representation._langchain import LangChain @@ -63,6 +70,7 @@ "Cohere", "OpenAI", "LangChain", + "LiteLLM", "LlamaCPP", "VisualRepresentation", ] diff --git a/bertopic/representation/_litellm.py b/bertopic/representation/_litellm.py new file mode 100644 index 00000000..5a7a891a --- /dev/null +++ b/bertopic/representation/_litellm.py @@ -0,0 +1,176 @@ +import time +from litellm import completion +import pandas as pd +from scipy.sparse import csr_matrix +from typing import Mapping, List, Tuple, Any +from bertopic.representation._base import BaseRepresentation +from bertopic.representation._utils import retry_with_exponential_backoff + + +DEFAULT_PROMPT = """ +I have a topic that contains the following documents: +[DOCUMENTS] +The topic is described by the following keywords: [KEYWORDS] +Based on the information above, extract a short topic label in the following format: +topic: +""" + + +class LiteLLM(BaseRepresentation): + """Using the LiteLLM API to generate topic labels. + + For an overview of models see: + https://docs.litellm.ai/docs/providers + + Arguments: + model: Model to use. Defaults to OpenAI's "gpt-3.5-turbo". + generator_kwargs: Kwargs passed to `litellm.completion`. + prompt: The prompt to be used in the model. If no prompt is given, + `self.default_prompt_` is used instead. + NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt + to decide where the keywords and documents need to be + inserted. + delay_in_seconds: The delay in seconds between consecutive prompts + in order to prevent RateLimitErrors. + exponential_backoff: Retry requests with a random exponential backoff. + A short sleep is used when a rate limit error is hit, + then the requests is retried. Increase the sleep length + if errors are hit until 10 unsuccesfull requests. + If True, overrides `delay_in_seconds`. + nr_docs: The number of documents to pass to LiteLLM if a prompt + with the `["DOCUMENTS"]` tag is used. + diversity: The diversity of documents to pass to LiteLLM. + Accepts values between 0 and 1. A higher + values results in passing more diverse documents + whereas lower values passes more similar documents. + + Usage: + + To use this, you will need to install the openai package first: + + `pip install litellm` + + Then, get yourself an API key of any provider (for instance OpenAI) and use it as follows: + + ```python + import os + from bertopic.representation import LiteLLM + from bertopic import BERTopic + + # set ENV variables + os.environ["OPENAI_API_KEY"] = "your-openai-key" + + # Create your representation model + representation_model = LiteLLM(model="gpt-3.5-turbo") + + # Use the representation model in BERTopic on top of the default pipeline + topic_model = BERTopic(representation_model=representation_model) + ``` + + You can also use a custom prompt: + + ```python + prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '" + representation_model = LiteLLM(model="gpt", prompt=prompt) + ``` + """ + + def __init__( + self, + model: str = "gpt-3.5-turbo", + prompt: str = None, + generator_kwargs: Mapping[str, Any] = {}, + delay_in_seconds: float = None, + exponential_backoff: bool = False, + nr_docs: int = 4, + diversity: float = None, + ): + self.model = model + self.prompt = prompt if prompt else DEFAULT_PROMPT + self.default_prompt_ = DEFAULT_PROMPT + self.delay_in_seconds = delay_in_seconds + self.exponential_backoff = exponential_backoff + self.nr_docs = nr_docs + self.diversity = diversity + + self.generator_kwargs = generator_kwargs + if self.generator_kwargs.get("model"): + self.model = generator_kwargs.get("model") + if self.generator_kwargs.get("prompt"): + del self.generator_kwargs["prompt"] + + def extract_topics( + self, topic_model, documents: pd.DataFrame, c_tf_idf: csr_matrix, topics: Mapping[str, List[Tuple[str, float]]] + ) -> Mapping[str, List[Tuple[str, float]]]: + """Extract topics. + + Arguments: + topic_model: A BERTopic model + documents: All input documents + c_tf_idf: The topic c-TF-IDF representation + topics: The candidate topics as calculated with c-TF-IDF + + Returns: + updated_topics: Updated topic representations + """ + # Extract the top n representative documents per topic + repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs( + c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity + ) + + # Generate using a (Large) Language Model + updated_topics = {} + for topic, docs in repr_docs_mappings.items(): + prompt = self._create_prompt(docs, topic, topics) + + # Delay + if self.delay_in_seconds: + time.sleep(self.delay_in_seconds) + + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": prompt}, + ] + kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs} + if self.exponential_backoff: + response = chat_completions_with_backoff(**kwargs) + else: + response = completion(**kwargs) + label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "") + + updated_topics[topic] = [(label, 1)] + + return updated_topics + + def _create_prompt(self, docs, topic, topics): + keywords = list(zip(*topics[topic]))[0] + + # Use the Default Chat Prompt + if self.prompt == DEFAULT_PROMPT: + prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords)) + prompt = self._replace_documents(prompt, docs) + + # Use a custom prompt that leverages keywords, documents or both using + # custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively + else: + prompt = self.prompt + if "[KEYWORDS]" in prompt: + prompt = prompt.replace("[KEYWORDS]", " ".join(keywords)) + if "[DOCUMENTS]" in prompt: + prompt = self._replace_documents(prompt, docs) + + return prompt + + @staticmethod + def _replace_documents(prompt, docs): + to_replace = "" + for doc in docs: + to_replace += f"- {doc[:255]}\n" + prompt = prompt.replace("[DOCUMENTS]", to_replace) + return prompt + + +def chat_completions_with_backoff(**kwargs): + return retry_with_exponential_backoff( + completion, + )(**kwargs) diff --git a/docs/algorithm/algorithm.md b/docs/algorithm/algorithm.md index 11057b08..af8acd86 100644 --- a/docs/algorithm/algorithm.md +++ b/docs/algorithm/algorithm.md @@ -164,7 +164,12 @@ The following models are implemented in `bertopic.representation`: * `PartOfSpeech` * `KeyBERTInspired` * `ZeroShotClassification` -* `TextGeneration` +* `TextGeneration` (HuggingFace) * `Cohere` * `OpenAI` * `LangChain` +* `LiteLLM` +* `LlamaCPP` + +!!! tip Models + There are roughly two sets of models. **First** are the non-generative set of models that you can find [here](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html). These include models that focus on enhancing the keywords in the topic representations. **Second** are the generative models that attempt to label or summarize the topics instead. You can find an overview of [implemented LLMs here](https://maartengr.github.io/BERTopic/getting_started/representation/llm). \ No newline at end of file diff --git a/docs/api/backends.md b/docs/api/backends.md new file mode 100644 index 00000000..6c6a85bc --- /dev/null +++ b/docs/api/backends.md @@ -0,0 +1,3 @@ +# `Backends` + +::: bertopic.backend \ No newline at end of file diff --git a/docs/api/backends/base.md b/docs/api/backends/base.md deleted file mode 100644 index 1b54a45e..00000000 --- a/docs/api/backends/base.md +++ /dev/null @@ -1,3 +0,0 @@ -# `BaseEmbedder` - -::: bertopic.backend._base.BaseEmbedder diff --git a/docs/api/backends/cohere.md b/docs/api/backends/cohere.md deleted file mode 100644 index 8c7f4df2..00000000 --- a/docs/api/backends/cohere.md +++ /dev/null @@ -1,3 +0,0 @@ -# `CohereBackend` - -::: bertopic.backend._cohere.CohereBackend diff --git a/docs/api/backends/openai.md b/docs/api/backends/openai.md deleted file mode 100644 index 89316457..00000000 --- a/docs/api/backends/openai.md +++ /dev/null @@ -1,3 +0,0 @@ -# `OpenAIBackend` - -::: bertopic.backend._openai.OpenAIBackend diff --git a/docs/api/backends/word_doc.md b/docs/api/backends/word_doc.md deleted file mode 100644 index 89136388..00000000 --- a/docs/api/backends/word_doc.md +++ /dev/null @@ -1,3 +0,0 @@ -# `WordDocEmbedder` - -::: bertopic.backend._word_doc.WordDocEmbedder diff --git a/docs/api/cluster/base.md b/docs/api/cluster copy.md similarity index 100% rename from docs/api/cluster/base.md rename to docs/api/cluster copy.md diff --git a/docs/api/cluster.md b/docs/api/cluster.md new file mode 100644 index 00000000..466bcc32 --- /dev/null +++ b/docs/api/cluster.md @@ -0,0 +1,3 @@ +# `BaseCluster` + +::: bertopic.cluster._base.BaseCluster diff --git a/docs/api/dimensionality/base.md b/docs/api/dimensionality.md similarity index 100% rename from docs/api/dimensionality/base.md rename to docs/api/dimensionality.md diff --git a/docs/api/onlinecv.md b/docs/api/onlinecv.md deleted file mode 100644 index bb986370..00000000 --- a/docs/api/onlinecv.md +++ /dev/null @@ -1,3 +0,0 @@ -# `OnlineCountVectorizer` - -::: bertopic.vectorizers.OnlineCountVectorizer diff --git a/docs/api/plotting.md b/docs/api/plotting.md new file mode 100644 index 00000000..03376402 --- /dev/null +++ b/docs/api/plotting.md @@ -0,0 +1,3 @@ +# `Plotting` + +::: bertopic.plotting \ No newline at end of file diff --git a/docs/api/representation/base.md b/docs/api/representation/base.md deleted file mode 100644 index 42384c29..00000000 --- a/docs/api/representation/base.md +++ /dev/null @@ -1,3 +0,0 @@ -# `BaseRepresentation` - -::: bertopic.representation._base.BaseRepresentation diff --git a/docs/api/representation/cohere.md b/docs/api/representation/cohere.md deleted file mode 100644 index 2301eea4..00000000 --- a/docs/api/representation/cohere.md +++ /dev/null @@ -1,3 +0,0 @@ -# `Cohere` - -::: bertopic.representation._cohere.Cohere diff --git a/docs/api/representation/generation.md b/docs/api/representation/generation.md deleted file mode 100644 index 39ba1739..00000000 --- a/docs/api/representation/generation.md +++ /dev/null @@ -1,3 +0,0 @@ -# `TextGeneration` - -::: bertopic.representation._textgeneration.TextGeneration diff --git a/docs/api/representation/keybert.md b/docs/api/representation/keybert.md deleted file mode 100644 index 8a10e08f..00000000 --- a/docs/api/representation/keybert.md +++ /dev/null @@ -1,3 +0,0 @@ -# `KeyBERTInspired` - -::: bertopic.representation._keybert.KeyBERTInspired diff --git a/docs/api/representation/langchain.md b/docs/api/representation/langchain.md deleted file mode 100644 index 272b517b..00000000 --- a/docs/api/representation/langchain.md +++ /dev/null @@ -1,3 +0,0 @@ -# `LangChain` - -::: bertopic.representation._langchain.LangChain diff --git a/docs/api/representation/mmr.md b/docs/api/representation/mmr.md deleted file mode 100644 index afff1a00..00000000 --- a/docs/api/representation/mmr.md +++ /dev/null @@ -1,3 +0,0 @@ -# `MaximalMarginalRelevance` - -::: bertopic.representation._mmr.MaximalMarginalRelevance diff --git a/docs/api/representation/openai.md b/docs/api/representation/openai.md deleted file mode 100644 index 623dde9a..00000000 --- a/docs/api/representation/openai.md +++ /dev/null @@ -1,3 +0,0 @@ -# `OpenAI` - -::: bertopic.representation.OpenAI diff --git a/docs/api/representation/pos.md b/docs/api/representation/pos.md deleted file mode 100644 index 4ec1eb17..00000000 --- a/docs/api/representation/pos.md +++ /dev/null @@ -1,3 +0,0 @@ -# `PartOfSpeech` - -::: bertopic.representation._pos.PartOfSpeech diff --git a/docs/api/representation/zeroshot.md b/docs/api/representation/zeroshot.md deleted file mode 100644 index cd602591..00000000 --- a/docs/api/representation/zeroshot.md +++ /dev/null @@ -1,3 +0,0 @@ -# `ZeroShotClassification` - -::: bertopic.representation._zeroshot.ZeroShotClassification diff --git a/docs/api/representations.md b/docs/api/representations.md new file mode 100644 index 00000000..b06f85a7 --- /dev/null +++ b/docs/api/representations.md @@ -0,0 +1,3 @@ +# `Representations` + +::: bertopic.representation \ No newline at end of file diff --git a/docs/api/vectorizers.md b/docs/api/vectorizers.md new file mode 100644 index 00000000..44a8e4ac --- /dev/null +++ b/docs/api/vectorizers.md @@ -0,0 +1,3 @@ +# `Vectorizers` + +::: bertopic.vectorizers._online_cv.OnlineCountVectorizer diff --git a/docs/getting_started/representation/llm.md b/docs/getting_started/representation/llm.md index 0385482f..df72b3f5 100644 --- a/docs/getting_started/representation/llm.md +++ b/docs/getting_started/representation/llm.md @@ -377,6 +377,7 @@ topic_model = BERTopic(representation_model=representation_model, verbose=True) """ ``` + ## **OpenAI** Instead of using a language model from 🤗 transformers, we can use external APIs instead that @@ -469,6 +470,34 @@ The above is not constrained to just creating a short description or summary of If you want to have multiple representations of a single topic, it might be worthwhile to also check out [**multi-aspect**](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) topic modeling with BERTopic. +## **LiteLLM** + +An amazing framework to simplify connecting to external LLMs, is [LiteLLM](https://docs.litellm.ai). This package allows you to connect to OpenAI, Cohere, Anthropic, etc. all within one package. This makes iteration and testing out different models a breeze! + +o start with, we first need to install `litellm`: + +```bash +pip install litellm +``` + +After installation, usage is straightforward and you can select any model found in their [docs](https://docs.litellm.ai/docs/providers). +Let's show an example with OpenAI: + +```python +import os +from bertopic import BERTopic +from bertopic.representation import LiteLLM + +# set ENV variables +os.environ["OPENAI_API_KEY"] = "MY_KEY" + +# Create your representation model +representation_model = LiteLLM(model="gpt-4o-mini") + +# Create our BERTopic model +topic_model = BERTopic(representation_model=representation_model, verbose=True) +``` + ## **LangChain** [Langchain](https://github.com/hwchase17/langchain) is a package that helps users with chaining large language models. diff --git a/mkdocs.yml b/mkdocs.yml index 8dc5f4f0..92a1ce78 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -57,29 +57,12 @@ nav: - API: - BERTopic: api/bertopic.md - Sub-models: - - Backends: - - Base: api/backends/base.md - - Word Doc: api/backends/word_doc.md - - OpenAI: api/backends/openai.md - - Cohere: api/backends/cohere.md - - Dimensionality Reduction: - - Base: api/dimensionality/base.md - - Clustering: - - Base: api/cluster/base.md - - Vectorizers: - - cTFIDF: api/ctfidf.md - - OnlineCountVectorizer: api/onlinecv.md - - Topic Representation: - - Base: api/representation/base.md - - MaximalMarginalRelevance: api/representation/mmr.md - - KeyBERT: api/representation/keybert.md - - PartOfSpeech: api/representation/pos.md - - Text Generation: - - 🤗 Transformers: api/representation/generation.md - - LangChain: api/representation/langchain.md - - Cohere: api/representation/cohere.md - - OpenAI: api/representation/openai.md - - Zero-shot Classification: api/representation/zeroshot.md + - 1. Backends: api/backends.md + - 2. Dimensionality Reduction: api/dimensionality.md + - 3. Clustering: api/cluster.md + - 4. Vectorizers: api/vectorizers.md + - 5. c-TF-IDF: api/ctfidf.md + - 6. Fine-Tune Topic Representation: api/representations.md - Plotting: - Barchart: api/plotting/barchart.md - Documents: api/plotting/documents.md From 5e719e301504b8619048578679824d98f5531eb1 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Wed, 13 Nov 2024 13:35:07 +0100 Subject: [PATCH 2/5] Incorrect variable naming --- bertopic/representation/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bertopic/representation/__init__.py b/bertopic/representation/__init__.py index 3a2a22c5..60cdecc6 100644 --- a/bertopic/representation/__init__.py +++ b/bertopic/representation/__init__.py @@ -38,7 +38,7 @@ from bertopic.representation._litellm import LiteLLM except ModuleNotFoundError: msg = "`pip install litellm` \n\n" - OpenAI = NotInstalled("LiteLLM", "litellm", custom_msg=msg) + LiteLLM = NotInstalled("LiteLLM", "litellm", custom_msg=msg) # LangChain Generator try: From b5cc41589a41341c5da0a103cc17039b27855e89 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Wed, 13 Nov 2024 13:42:05 +0100 Subject: [PATCH 3/5] Ruff --- bertopic/representation/_litellm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bertopic/representation/_litellm.py b/bertopic/representation/_litellm.py index 5a7a891a..0aa1bebc 100644 --- a/bertopic/representation/_litellm.py +++ b/bertopic/representation/_litellm.py @@ -73,7 +73,7 @@ class LiteLLM(BaseRepresentation): prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '" representation_model = LiteLLM(model="gpt", prompt=prompt) ``` - """ + """ # noqa: D301 def __init__( self, From dad416bd1e2004bb6cdf690ad79140c3809ed884 Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Fri, 15 Nov 2024 11:50:10 +0100 Subject: [PATCH 4/5] Fix docstrings --- bertopic/representation/_litellm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bertopic/representation/_litellm.py b/bertopic/representation/_litellm.py index 0aa1bebc..c872e381 100644 --- a/bertopic/representation/_litellm.py +++ b/bertopic/representation/_litellm.py @@ -46,7 +46,7 @@ class LiteLLM(BaseRepresentation): Usage: - To use this, you will need to install the openai package first: + To use this, you will need to install the litellm package first: `pip install litellm` From 55515ac31cdc11a01482156300159f53d45ae82d Mon Sep 17 00:00:00 2001 From: MaartenGr Date: Fri, 29 Nov 2024 09:05:01 +0100 Subject: [PATCH 5/5] Add ollama instructions --- docs/getting_started/representation/llm.md | 53 ++++++++++++++++++---- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/docs/getting_started/representation/llm.md b/docs/getting_started/representation/llm.md index df72b3f5..0a9e59fa 100644 --- a/docs/getting_started/representation/llm.md +++ b/docs/getting_started/representation/llm.md @@ -140,7 +140,7 @@ As can be seen from the example above, if you would like to use a `text2text-gen pass a `transformers.pipeline` with the `"text2text-generation"` parameter. Moreover, you can use a custom prompt and decide where the keywords should be inserted by using the `[KEYWORDS]` or documents with the `[DOCUMENTS]` tag. -### **Zephyr** (Mistral 7B) +### **Mistral (GGUF)** We can go a step further with open-source Large Language Models (LLMs) that have shown to match the performance of closed-source LLMs like ChatGPT. @@ -206,13 +206,17 @@ representation_model = {"Zephyr": zephyr} topic_model = BERTopic(representation_model=representation_model, verbose=True) ``` -### **Llama 2** +### **Llama (Manual Quantization)** -Full Llama 2 Tutorial: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QCERSMUjqGetGGujdrvv_6_EeoIcd_9M?usp=sharing) +Full Llama Tutorial: [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1QCERSMUjqGetGGujdrvv_6_EeoIcd_9M?usp=sharing) Open-source LLMs are starting to become more and more popular. Here, we will go through a minimal example of using [Llama 2](https://huggingface.co/meta-llama/Llama-2-13b-chat-hf) together with BERTopic. -First, we need to load in our Llama 2 model: +!!! Note + Although this is an example of the older Llama 2 model, you can use the code below for any Llama variant. + + +First, we need to load in our Llama model: ```python from torch import bfloat16 @@ -227,10 +231,10 @@ bnb_config = transformers.BitsAndBytesConfig( bnb_4bit_compute_dtype=bfloat16 # Computation type ) -# Llama 2 Tokenizer +# Llama Tokenizer tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) -# Llama 2 Model +# Llama Model model = transformers.AutoModelForCausalLM.from_pretrained( model_id, trust_remote_code=True, @@ -249,7 +253,7 @@ generator = transformers.pipeline( ) ``` -After doing so, we will need to define a prompt that works with both Llama 2 as well as BERTopic: +After doing so, we will need to define a prompt that works with both Llama as well as BERTopic: ```python @@ -292,7 +296,7 @@ prompt = system_prompt + example_prompt + main_prompt Three pieces of the prompt were created: * `system_prompt` helps us guide the model during a conversation. For example, we can say that it is a helpful assistant that is specialized in labeling topics. -* `example_prompt` gives an example of a correctly labeled topic to guide Llama 2 +* `example_prompt` gives an example of a correctly labeled topic to guide Llama * `main_prompt` contains the main question we are going to ask it, namely to label a topic. Note that it uses the `[DOCUMENTS]` and `[KEYWORDS]` to provide the most relevant documents and keywords as additional context After having generated our prompt template, we can start running our topic model: @@ -301,7 +305,7 @@ After having generated our prompt template, we can start running our topic model from bertopic.representation import TextGeneration from bertopic import BERTopic -# Text generation with Llama 2 +# Text generation with Llama llama2 = TextGeneration(generator, prompt=prompt) representation_model = { "Llama2": llama2, @@ -469,6 +473,37 @@ representation_model = OpenAI(client, model="gpt-3.5-turbo", chat=True, prompt=s The above is not constrained to just creating a short description or summary of the topic, we can extract labels, keywords, poems, example documents, extensitive descriptions, and more using this method! If you want to have multiple representations of a single topic, it might be worthwhile to also check out [**multi-aspect**](https://maartengr.github.io/BERTopic/getting_started/multiaspect/multiaspect.html) topic modeling with BERTopic. +## **Ollama** + +To use [Ollama](https://github.com/ollama/ollama) within BERTopic, it is advised to use the `openai` package as it allows to pass through a model using the url on which the model is running. + +You will first need to install `openai`: + +```bash +pip install openai +``` + +After installation, usage is straightforward and you can select any model that you have prepared in your `ollama` model list. You can see all models by running `ollama list`. + +Select one from the list and you can use it in BERTopic as follows: + +```python +import openai +from bertopic.representation import OpenAI +from bertopic import BERTopic + +client = openai.OpenAI( + base_url = 'http://localhost:11434/v1', #wherever ollama is running + api_key='ollama', # required, but unused +) + + +# Create your representation model +representation_model = OpenAI(client, model='phi3:14b-medium-128k-instruct-q4_K_M') + +# Create your BERTopic model +topic_model = BERTopic(representation_model=representation_model, verbose=True) +``` ## **LiteLLM**