
Commit

Make self._topic_id_to_zeroshot_topic_idx private, add comments/docstrings, lower threshold zeroshot test, fix outliers for probabilities during zeroshot (#2)
ianrandman committed Jun 23, 2024
1 parent 7766277 commit fbc574b
Showing 2 changed files with 50 additions and 11 deletions.
59 changes: 49 additions & 10 deletions bertopic/_bertopic.py
@@ -274,7 +274,7 @@ def __init__(
self.topic_mapper_ = None
self.topic_representations_ = None
self.topic_embeddings_ = None
self.topic_id_to_zeroshot_topic_idx = {}
self._topic_id_to_zeroshot_topic_idx = {}
self.custom_labels_ = None
self.c_tf_idf_ = None
self.representative_images_ = None
@@ -291,13 +291,28 @@ def __init__(

@property
def _outliers(self):
# Some algorithms have outlier labels (-1) that can be tricky to work
# with if you are slicing data based on those labels. Therefore, we
# track if there are outlier labels and act accordingly when slicing.
"""
Some algorithms have outlier labels (-1) that can be tricky to work
with if you are slicing data based on those labels. Therefore, we
track if there are outlier labels and act accordingly when slicing.
Returns:
An integer indicating whether outliers are present in the topic model
"""

Check failure on line 301 (GitHub Actions / lint): Ruff D202 — bertopic/_bertopic.py:294:9: No blank lines allowed after function docstring (found 1)
Check failure on line 301 (GitHub Actions / lint): Ruff D212 — bertopic/_bertopic.py:294:9: Multi-line docstring summary should start at the first line

return 1 if -1 in self.topic_sizes_ else 0

@property
def topic_labels_(self):
"""
Map topic IDs to their labels.
A label is the topic ID, along with the first four words of the topic representation, joined using '_'.
Zeroshot topic labels come from self.zeroshot_topic_list rather than the calculated representation.
Returns:
topic_labels: a dict mapping a topic ID (int) to its label (str)
"""

Check failure on line 314 (GitHub Actions / lint): Ruff D212 — bertopic/_bertopic.py:307:9: Multi-line docstring summary should start at the first line
Check failure on line 314 (GitHub Actions / lint): Ruff D202 — bertopic/_bertopic.py:307:9: No blank lines allowed after function docstring (found 1)

topic_labels = {
key: f"{key}_" + "_".join([word[0] for word in values[:4]])
for key, values in self.topic_representations_.items()
@@ -306,7 +321,7 @@ def topic_labels_(self):
# Need to correct labels from zero-shot topics
topic_id_to_zeroshot_label = {
topic_id: self.zeroshot_topic_list[zeroshot_topic_idx]
for topic_id, zeroshot_topic_idx in self.topic_id_to_zeroshot_topic_idx.items()
for topic_id, zeroshot_topic_idx in self._topic_id_to_zeroshot_topic_idx.items()
}
topic_labels.update(topic_id_to_zeroshot_label)
return topic_labels
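
As an aside, a minimal sketch of the labeling scheme the new docstring describes; the topic representations and the zero-shot override below are hypothetical values, not taken from the diff.

# Labels follow "<topic_id>_<w1>_<w2>_<w3>_<w4>", except for topics matched to a
# zero-shot label, which keep their predefined name.
topic_representations = {
    0: [("dog", 0.9), ("puppy", 0.7), ("bark", 0.5), ("leash", 0.4), ("walk", 0.3)],
    1: [("car", 0.8), ("engine", 0.6), ("wheel", 0.5), ("drive", 0.4)],
}
zeroshot_overrides = {1: "vehicles"}  # topic 1 was matched to a zero-shot label

topic_labels = {
    topic_id: f"{topic_id}_" + "_".join(word for word, _ in words[:4])
    for topic_id, words in topic_representations.items()
}
topic_labels.update(zeroshot_overrides)
# -> {0: "0_dog_puppy_bark_leash", 1: "vehicles"}
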
@@ -510,7 +525,8 @@ def fit_transform(
# Use `topics_before_reduction` because `self.topics_` may have already been updated from
# reducing topics, and the original probabilities are needed for `self._map_probabilities()`
probabilities = sim_matrix[
np.arange(len(documents)), topics_before_reduction
np.arange(len(documents)),
np.array(topics_before_reduction) + self._outliers,
]

# Resulting output
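
One plausible reading of the indexing fix above: when an outlier topic exists, per-document topic IDs start at -1 while the similarity-matrix columns start at 0 (with column 0 assumed to be the outlier topic), so the column index needs the outlier offset. A minimal sketch under those assumptions, with hypothetical shapes and values:

import numpy as np

sim_matrix = np.random.rand(5, 3)            # 5 documents x 3 topics; column 0 = outlier topic
topics_before_reduction = [-1, 0, 1, 1, -1]  # per-document topic IDs, -1 = outlier
outliers = 1                                 # 1 when an outlier topic exists, else 0
probabilities = sim_matrix[
    np.arange(len(topics_before_reduction)),
    np.array(topics_before_reduction) + outliers,  # shifts -1 -> 0, 0 -> 1, ...
]
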
@@ -4086,14 +4102,36 @@ def _combine_zeroshot_topics(
assigned_documents: pd.DataFrame,
assigned_embeddings: np.ndarray,
) -> tuple[pd.DataFrame, np.ndarray]:
"""Combine the zero-shot topics with the clustered topics.
The zero-shot topics will be inserted between the outlier topic (that may or may not exist) and the rest of the
topics from clustering. The rest of the topics from clustering will be given new IDs to correspond to topics
after zero-shot topics.
Documents and embeddings used in zero-shot topic modeling and clustering are re-merged.
Arguments:
documents: DataFrame with clustered documents and their corresponding IDs
embeddings: The document embeddings for clustered documents
assigned_documents: DataFrame with documents and their corresponding IDs
that were assigned to a zero-shot topic
assigned_embeddings: The document embeddings for documents that were assigned to a zero-shot topic
Returns:
documents: DataFrame with all the original documents with their topic assignments
embeddings: np.ndarray of embeddings aligned with the documents
"""
logger.info(
"Zeroshot Step 2 - Combining topics from zero-shot topic modeling with topics from clustering..."
)
# Combine Zero-shot topics with topics from clustering
zeroshot_topic_idx_to_topic_id = {
zeroshot_topic_id: new_topic_id
for new_topic_id, zeroshot_topic_id in enumerate(
set(assigned_documents.Topic)
)
}
self.topic_id_to_zeroshot_topic_idx = {
self._topic_id_to_zeroshot_topic_idx = {
new_topic_id: zeroshot_topic_id
for new_topic_id, zeroshot_topic_id in enumerate(
set(assigned_documents.Topic)
@@ -4122,6 +4160,7 @@ def _combine_zeroshot_topics(
self._update_topic_size(documents)
self.topic_mapper_ = TopicMapper(self.topics_)

logger.info("Zeroshot Step 2 - Completed \u2713")
return documents, embeddings

def _guided_topic_modeling(
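
A sketch of the renumbering the docstring above describes, using hypothetical assignments (the real method builds these mappings from assigned_documents.Topic):

# Documents matched to zero-shot topics carry the index of the matched zero-shot
# label; the distinct indices are renumbered to consecutive topic IDs starting at 0.
assigned_zeroshot_indices = [2, 0, 2, 5]  # hypothetical stand-in for assigned_documents.Topic

zeroshot_topic_idx_to_topic_id = {
    zeroshot_idx: new_topic_id
    for new_topic_id, zeroshot_idx in enumerate(set(assigned_zeroshot_indices))
}
topic_id_to_zeroshot_topic_idx = {
    new_topic_id: zeroshot_idx
    for new_topic_id, zeroshot_idx in enumerate(set(assigned_zeroshot_indices))
}
# Clustered topics are then shifted so their IDs follow the zero-shot topic IDs,
# while the outlier topic (-1), if present, keeps its ID.
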
@@ -4720,7 +4759,7 @@ def _reduce_to_n_topics(
zeroshot_topic_ids = [
topic_id
for topic_id in topics_from
if topic_id in self.topic_id_to_zeroshot_topic_idx
if topic_id in self._topic_id_to_zeroshot_topic_idx
]
if len(zeroshot_topic_ids) == 0:
continue
@@ -4729,7 +4768,7 @@ def _reduce_to_n_topics(
# if the cosine similarity with the new topic exceeds the zero-shot threshold
zeroshot_labels = [
self.zeroshot_topic_list[
self.topic_id_to_zeroshot_topic_idx[topic_id]
self._topic_id_to_zeroshot_topic_idx[topic_id]
]
for topic_id in zeroshot_topic_ids
]
@@ -4744,7 +4783,7 @@ def _reduce_to_n_topics(
best_zeroshot_topic_idx
]

self.topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx
self._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx

self._update_topic_size(documents)
return documents
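
For context on the reduction path above: per the comment in the diff, whether a merged topic keeps a zero-shot label comes down to a cosine-similarity check against the zero-shot threshold. A self-contained sketch with made-up vectors (the real method works on the model's topic embeddings, not these values):

import numpy as np

def cosine(a, b):
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

merged_topic_embedding = np.array([0.2, 0.9, 0.1])
zeroshot_label_embeddings = {
    "sports": np.array([0.1, 0.95, 0.05]),
    "politics": np.array([0.9, 0.1, 0.2]),
}
zeroshot_min_similarity = 0.8

similarities = {
    label: cosine(merged_topic_embedding, emb)
    for label, emb in zeroshot_label_embeddings.items()
}
best_label, best_sim = max(similarities.items(), key=lambda kv: kv[1])
# Keep the best-matching zero-shot label only if it clears the threshold.
new_label = best_label if best_sim >= zeroshot_min_similarity else None
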
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -64,7 +64,7 @@ def zeroshot_topic_model(documents, document_embeddings, embedding_model):
embedding_model=embedding_model,
calculate_probabilities=True,
zeroshot_topic_list=zeroshot_topic_list,
zeroshot_min_similarity=0.5,
zeroshot_min_similarity=0.3,
)
model.umap_model.random_state = 42
model.hdbscan_model.min_cluster_size = 2
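
For reference, an illustrative zero-shot configuration mirroring the fixture above; the embedding model name and label list are placeholders, not the test's own data.

from bertopic import BERTopic

topic_model = BERTopic(
    embedding_model="all-MiniLM-L6-v2",          # placeholder embedding model
    calculate_probabilities=True,
    zeroshot_topic_list=["sports", "politics"],  # placeholder zero-shot labels
    zeroshot_min_similarity=0.3,  # lower threshold -> more documents matched to zero-shot topics
)
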
