Skip to content

Commit

Permalink
Merge branch 'MaartenGr:master' into short-explanation-difference-bet…
Browse files Browse the repository at this point in the history
…ween-guided-and-zeroshot
  • Loading branch information
janspoerer authored Dec 13, 2024
2 parents 2faf380 + 84dbf36 commit f0359bb
Show file tree
Hide file tree
Showing 33 changed files with 311 additions and 81 deletions.
2 changes: 1 addition & 1 deletion bertopic/cluster/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class BaseCluster:
```python
from bertopic import BERTopic
from bertopic.dimensionality import BaseCluster
from bertopic.cluster import BaseCluster
empty_cluster_model = BaseCluster()
Expand Down
8 changes: 8 additions & 0 deletions bertopic/representation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,13 @@
msg = "`pip install openai` \n\n"
OpenAI = NotInstalled("OpenAI", "openai", custom_msg=msg)

# LiteLLM Generator
try:
from bertopic.representation._litellm import LiteLLM
except ModuleNotFoundError:
msg = "`pip install litellm` \n\n"
LiteLLM = NotInstalled("LiteLLM", "litellm", custom_msg=msg)

# LangChain Generator
try:
from bertopic.representation._langchain import LangChain
Expand Down Expand Up @@ -63,6 +70,7 @@
"Cohere",
"OpenAI",
"LangChain",
"LiteLLM",
"LlamaCPP",
"VisualRepresentation",
]
4 changes: 3 additions & 1 deletion bertopic/representation/_cohere.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters


DEFAULT_PROMPT = """
Expand Down Expand Up @@ -124,6 +124,8 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

def extract_topics(
Expand Down
3 changes: 2 additions & 1 deletion bertopic/representation/_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Callable, Mapping, List, Tuple, Union

from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters

DEFAULT_PROMPT = "What are these documents about? Please give a single label."

Expand Down Expand Up @@ -148,6 +148,7 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

def extract_topics(
self,
Expand Down
176 changes: 176 additions & 0 deletions bertopic/representation/_litellm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
import time
from litellm import completion
import pandas as pd
from scipy.sparse import csr_matrix
from typing import Mapping, List, Tuple, Any
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import retry_with_exponential_backoff


DEFAULT_PROMPT = """
I have a topic that contains the following documents:
[DOCUMENTS]
The topic is described by the following keywords: [KEYWORDS]
Based on the information above, extract a short topic label in the following format:
topic: <topic label>
"""


class LiteLLM(BaseRepresentation):
"""Using the LiteLLM API to generate topic labels.
For an overview of models see:
https://docs.litellm.ai/docs/providers
Arguments:
model: Model to use. Defaults to OpenAI's "gpt-3.5-turbo".
generator_kwargs: Kwargs passed to `litellm.completion`.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
NOTE: Use `"[KEYWORDS]"` and `"[DOCUMENTS]"` in the prompt
to decide where the keywords and documents need to be
inserted.
delay_in_seconds: The delay in seconds between consecutive prompts
in order to prevent RateLimitErrors.
exponential_backoff: Retry requests with a random exponential backoff.
A short sleep is used when a rate limit error is hit,
then the requests is retried. Increase the sleep length
if errors are hit until 10 unsuccesfull requests.
If True, overrides `delay_in_seconds`.
nr_docs: The number of documents to pass to LiteLLM if a prompt
with the `["DOCUMENTS"]` tag is used.
diversity: The diversity of documents to pass to LiteLLM.
Accepts values between 0 and 1. A higher
values results in passing more diverse documents
whereas lower values passes more similar documents.
Usage:
To use this, you will need to install the litellm package first:
`pip install litellm`
Then, get yourself an API key of any provider (for instance OpenAI) and use it as follows:
```python
import os
from bertopic.representation import LiteLLM
from bertopic import BERTopic
# set ENV variables
os.environ["OPENAI_API_KEY"] = "your-openai-key"
# Create your representation model
representation_model = LiteLLM(model="gpt-3.5-turbo")
# Use the representation model in BERTopic on top of the default pipeline
topic_model = BERTopic(representation_model=representation_model)
```
You can also use a custom prompt:
```python
prompt = "I have the following documents: [DOCUMENTS] \nThese documents are about the following topic: '"
representation_model = LiteLLM(model="gpt", prompt=prompt)
```
""" # noqa: D301

def __init__(
self,
model: str = "gpt-3.5-turbo",
prompt: str = None,
generator_kwargs: Mapping[str, Any] = {},
delay_in_seconds: float = None,
exponential_backoff: bool = False,
nr_docs: int = 4,
diversity: float = None,
):
self.model = model
self.prompt = prompt if prompt else DEFAULT_PROMPT
self.default_prompt_ = DEFAULT_PROMPT
self.delay_in_seconds = delay_in_seconds
self.exponential_backoff = exponential_backoff
self.nr_docs = nr_docs
self.diversity = diversity

self.generator_kwargs = generator_kwargs
if self.generator_kwargs.get("model"):
self.model = generator_kwargs.get("model")
if self.generator_kwargs.get("prompt"):
del self.generator_kwargs["prompt"]

def extract_topics(
self, topic_model, documents: pd.DataFrame, c_tf_idf: csr_matrix, topics: Mapping[str, List[Tuple[str, float]]]
) -> Mapping[str, List[Tuple[str, float]]]:
"""Extract topics.
Arguments:
topic_model: A BERTopic model
documents: All input documents
c_tf_idf: The topic c-TF-IDF representation
topics: The candidate topics as calculated with c-TF-IDF
Returns:
updated_topics: Updated topic representations
"""
# Extract the top n representative documents per topic
repr_docs_mappings, _, _, _ = topic_model._extract_representative_docs(
c_tf_idf, documents, topics, 500, self.nr_docs, self.diversity
)

# Generate using a (Large) Language Model
updated_topics = {}
for topic, docs in repr_docs_mappings.items():
prompt = self._create_prompt(docs, topic, topics)

# Delay
if self.delay_in_seconds:
time.sleep(self.delay_in_seconds)

messages = [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": prompt},
]
kwargs = {"model": self.model, "messages": messages, **self.generator_kwargs}
if self.exponential_backoff:
response = chat_completions_with_backoff(**kwargs)
else:
response = completion(**kwargs)
label = response["choices"][0]["message"]["content"].strip().replace("topic: ", "")

updated_topics[topic] = [(label, 1)]

return updated_topics

def _create_prompt(self, docs, topic, topics):
keywords = list(zip(*topics[topic]))[0]

# Use the Default Chat Prompt
if self.prompt == DEFAULT_PROMPT:
prompt = self.prompt.replace("[KEYWORDS]", " ".join(keywords))
prompt = self._replace_documents(prompt, docs)

# Use a custom prompt that leverages keywords, documents or both using
# custom tags, namely [KEYWORDS] and [DOCUMENTS] respectively
else:
prompt = self.prompt
if "[KEYWORDS]" in prompt:
prompt = prompt.replace("[KEYWORDS]", " ".join(keywords))
if "[DOCUMENTS]" in prompt:
prompt = self._replace_documents(prompt, docs)

return prompt

@staticmethod
def _replace_documents(prompt, docs):
to_replace = ""
for doc in docs:
to_replace += f"- {doc[:255]}\n"
prompt = prompt.replace("[DOCUMENTS]", to_replace)
return prompt


def chat_completions_with_backoff(**kwargs):
return retry_with_exponential_backoff(
completion,
)(**kwargs)
3 changes: 2 additions & 1 deletion bertopic/representation/_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from llama_cpp import Llama
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters


DEFAULT_PROMPT = """
Expand Down Expand Up @@ -116,6 +116,7 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

Expand Down
3 changes: 3 additions & 0 deletions bertopic/representation/_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from bertopic.representation._utils import (
retry_with_exponential_backoff,
truncate_document,
validate_truncate_document_parameters,
)


Expand Down Expand Up @@ -169,6 +170,8 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

self.generator_kwargs = generator_kwargs
Expand Down
3 changes: 2 additions & 1 deletion bertopic/representation/_textgeneration.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from transformers.pipelines.base import Pipeline
from typing import Mapping, List, Tuple, Any, Union, Callable
from bertopic.representation._base import BaseRepresentation
from bertopic.representation._utils import truncate_document
from bertopic.representation._utils import truncate_document, validate_truncate_document_parameters


DEFAULT_PROMPT = """
Expand Down Expand Up @@ -112,6 +112,7 @@ def __init__(
self.diversity = diversity
self.doc_length = doc_length
self.tokenizer = tokenizer
validate_truncate_document_parameters(self.tokenizer, self.doc_length)

self.prompts_ = []

Expand Down
15 changes: 14 additions & 1 deletion bertopic/representation/_utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import random
import time
from typing import Union


def truncate_document(topic_model, doc_length, tokenizer, document: str):
def truncate_document(topic_model, doc_length: Union[int, None], tokenizer: Union[str, callable], document: str) -> str:
"""Truncate a document to a certain length.
If you want to add a custom tokenizer, then it will need to have a `decode` and
Expand Down Expand Up @@ -58,6 +59,18 @@ def decode(self, doc_chunks):
return document


def validate_truncate_document_parameters(tokenizer, doc_length) -> Union[None, ValueError]:
"""Validates parameters that are used in the function `truncate_document`."""
if tokenizer is None and doc_length is not None:
raise ValueError(
"Please select from one of the valid options for the `tokenizer` parameter: \n"
"{'char', 'whitespace', 'vectorizer'} \n"
"If `tokenizer` is of type callable ensure it has methods to encode and decode a document \n"
)
elif tokenizer is not None and doc_length is None:
raise ValueError("If `tokenizer` is provided, `doc_length` of type int must be provided as well.")


def retry_with_exponential_backoff(
func,
initial_delay: float = 1,
Expand Down
7 changes: 6 additions & 1 deletion docs/algorithm/algorithm.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,12 @@ The following models are implemented in `bertopic.representation`:
* `PartOfSpeech`
* `KeyBERTInspired`
* `ZeroShotClassification`
* `TextGeneration`
* `TextGeneration` (HuggingFace)
* `Cohere`
* `OpenAI`
* `LangChain`
* `LiteLLM`
* `LlamaCPP`

!!! tip Models
There are roughly two sets of models. **First** are the non-generative set of models that you can find [here](https://maartengr.github.io/BERTopic/getting_started/representation/representation.html). These include models that focus on enhancing the keywords in the topic representations. **Second** are the generative models that attempt to label or summarize the topics instead. You can find an overview of [implemented LLMs here](https://maartengr.github.io/BERTopic/getting_started/representation/llm).
3 changes: 3 additions & 0 deletions docs/api/backends.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `Backends`

::: bertopic.backend
3 changes: 0 additions & 3 deletions docs/api/backends/base.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/backends/cohere.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/backends/openai.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/backends/word_doc.md

This file was deleted.

File renamed without changes.
3 changes: 3 additions & 0 deletions docs/api/cluster.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `BaseCluster`

::: bertopic.cluster._base.BaseCluster
File renamed without changes.
3 changes: 0 additions & 3 deletions docs/api/onlinecv.md

This file was deleted.

3 changes: 3 additions & 0 deletions docs/api/plotting.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# `Plotting`

::: bertopic.plotting
3 changes: 0 additions & 3 deletions docs/api/representation/base.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/cohere.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/generation.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/keybert.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/langchain.md

This file was deleted.

3 changes: 0 additions & 3 deletions docs/api/representation/mmr.md

This file was deleted.

Loading

0 comments on commit f0359bb

Please sign in to comment.