Skip to content

Commit

Permalink
Adjust _topic_id_to_zeroshot_topic_idx when adding mapping to `Topi…
Browse files Browse the repository at this point in the history
…cMapper` (#7) (#2120)
  • Loading branch information
ianrandman authored Aug 15, 2024
1 parent fed0682 commit 63710da
Showing 1 changed file with 50 additions and 43 deletions.
93 changes: 50 additions & 43 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2143,7 +2143,7 @@ def merge_topics(

# Update topics
documents.Topic = documents.Topic.map(mapping)
self.topic_mapper_.add_mappings(mapping)
self.topic_mapper_.add_mappings(mapping, topic_model=self)
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._update_topic_size(documents)
Expand Down Expand Up @@ -4402,50 +4402,12 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
# Map topics
documents.Topic = new_topics
self._update_topic_size(documents)
self.topic_mapper_.add_mappings(mapped_topics)
self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)

# Update representations
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)

# When zero-shot topic(s) are present in the topics to merge,
# determine whether to take one of the zero-shot topic labels
# or use a calculated representation.
if self._is_zeroshot():
new_topic_id_to_zeroshot_topic_idx = {}
topics_to_map = {
topic_mapping[0]: topic_mapping[1] for topic_mapping in np.array(self.topic_mapper_.mappings_)[:, -2:]
}

for topic_to, topics_from in basic_mappings.items():
# When extracting topics, the reduced topics were reordered.
# Must get the updated topic_to.
topic_to = topics_to_map[topic_to]

# which of the original topics are zero-shot
zeroshot_topic_ids = [
topic_id for topic_id in topics_from if topic_id in self._topic_id_to_zeroshot_topic_idx
]
if len(zeroshot_topic_ids) == 0:
continue

# If any of the original topics are zero-shot, take the best fitting zero-shot label
# if the cosine similarity with the new topic exceeds the zero-shot threshold
zeroshot_labels = [
self.zeroshot_topic_list[self._topic_id_to_zeroshot_topic_idx[topic_id]]
for topic_id in zeroshot_topic_ids
]
zeroshot_embeddings = self._extract_embeddings(zeroshot_labels)
cosine_similarities = cosine_similarity(
zeroshot_embeddings, [self.topic_embeddings_[topic_to]]
).flatten()
best_zeroshot_topic_idx = np.argmax(cosine_similarities)
best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx]
if best_cosine_similarity >= self.zeroshot_min_similarity:
new_topic_id_to_zeroshot_topic_idx[topic_to] = zeroshot_topic_ids[best_zeroshot_topic_idx]

self._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx

self._update_topic_size(documents)
return documents

Expand Down Expand Up @@ -4498,7 +4460,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
}

# Update documents and topics
self.topic_mapper_.add_mappings(mapped_topics)
self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)
documents = self._sort_mappings_by_frequency(documents)
self._extract_topics(documents, mappings=mappings)
self._update_topic_size(documents)
Expand Down Expand Up @@ -4538,7 +4500,7 @@ def _sort_mappings_by_frequency(self, documents: pd.DataFrame) -> pd.DataFrame:
df = pd.DataFrame(self.topic_sizes_.items(), columns=["Old_Topic", "Size"]).sort_values("Size", ascending=False)
df = df[df.Old_Topic != -1]
sorted_topics = {**{-1: -1}, **dict(zip(df.Old_Topic, range(len(df))))}
self.topic_mapper_.add_mappings(sorted_topics)
self.topic_mapper_.add_mappings(sorted_topics, topic_model=self)

# Map documents
documents.Topic = documents.Topic.map(sorted_topics).fillna(documents.Topic).astype(int)
Expand Down Expand Up @@ -4728,11 +4690,12 @@ def get_mappings(self, original_topics: bool = True) -> Mapping[int, int]:
mappings = dict(zip(mappings[:, 0], mappings[:, 1]))
return mappings

def add_mappings(self, mappings: Mapping[int, int]):
def add_mappings(self, mappings: Mapping[int, int], topic_model: BERTopic):
"""Add new column(s) of topic mappings.

Arguments:
mappings: The mappings to add
topic_model: The topic model this TopicMapper belongs to
"""
for topics in self.mappings_:
topic = topics[-1]
Expand All @@ -4741,6 +4704,50 @@ def add_mappings(self, mappings: Mapping[int, int]):
else:
topics.append(-1)

# When zero-shot topic(s) are present in the topics to merge,
# determine whether to take one of the zero-shot topic labels
# or use a calculated representation.
if topic_model._is_zeroshot() and len(topic_model._topic_id_to_zeroshot_topic_idx) > 0:
new_topic_id_to_zeroshot_topic_idx = {}
topics_to_map = {
topic_mapping[0]: topic_mapping[1]
for topic_mapping in np.array(topic_model.topic_mapper_.mappings_)[:, -2:]
}

# Map topic_to to topics_from
mapping = defaultdict(list)
for key, value in topics_to_map.items():
mapping[value].append(key)

for topic_to, topics_from in mapping.items():
# which of the original topics are zero-shot
zeroshot_topic_ids = [
topic_id for topic_id in topics_from if topic_id in topic_model._topic_id_to_zeroshot_topic_idx
]
if len(zeroshot_topic_ids) == 0:
continue

# If any of the original topics are zero-shot, take the best fitting zero-shot label
# if the cosine similarity with the new topic exceeds the zero-shot threshold
zeroshot_labels = [
topic_model.zeroshot_topic_list[topic_model._topic_id_to_zeroshot_topic_idx[topic_id]]
for topic_id in zeroshot_topic_ids
]
zeroshot_embeddings = topic_model._extract_embeddings(zeroshot_labels)
cosine_similarities = cosine_similarity(
zeroshot_embeddings, [topic_model.topic_embeddings_[topic_to]]
).flatten()
best_zeroshot_topic_idx = np.argmax(cosine_similarities)
best_cosine_similarity = cosine_similarities[best_zeroshot_topic_idx]

if best_cosine_similarity >= topic_model.zeroshot_min_similarity:
# Using the topic ID from before mapping, get the idx into the zeroshot topic list
new_topic_id_to_zeroshot_topic_idx[topic_to] = topic_model._topic_id_to_zeroshot_topic_idx[
zeroshot_topic_ids[best_zeroshot_topic_idx]
]

topic_model._topic_id_to_zeroshot_topic_idx = new_topic_id_to_zeroshot_topic_idx

def add_new_topics(self, mappings: Mapping[int, int]):
"""Add new row(s) of topic mappings.

Expand Down

0 comments on commit 63710da

Please sign in to comment.