Skip to content

Commit

Permalink
Fixed issue #2144
Browse files — browse the repository at this point in the history
  • Loading branch information
PipaFlores committed Oct 22, 2024
1 parent 0bafe4f commit 1e2ec70
Showing 1 changed file with 26 additions and 19 deletions.
45 changes: 26 additions & 19 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,23 +477,21 @@ def fit_transform(
# Create documents from images if we have images only
if documents.Document.values[0] is None:
custom_documents = self._images_to_text(documents, embeddings)

# Reduce topics if needed, extract topics by calculating c-TF-IDF, and get representations.

# Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
self._extract_topics(custom_documents, embeddings=embeddings, fine_tune_representation=not self.nr_topics)
if self.nr_topics:
custom_documents = self._reduce_topics(custom_documents)
else:
self._extract_topics(custom_documents, embeddings=embeddings, verbose=self.verbose)
self._create_topic_vectors(documents=documents, embeddings=embeddings)

# Save the top 3 most representative documents per topic
self._save_representative_docs(custom_documents)

else:
# Reduce topics if needed, extract topics by calculating c-TF-IDF, and get representations.
# Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose, fine_tune_representation=not self.nr_topics)
if self.nr_topics:
documents = self._reduce_topics(documents)
else:
self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)

# Save the top 3 most representative documents per topic
self._save_representative_docs(documents)
Expand Down Expand Up @@ -3971,6 +3969,7 @@ def _extract_topics(
embeddings: np.ndarray = None,
mappings=None,
verbose: bool = False,
fine_tune_representation: bool = True,
):
"""Extract topics from the clusters using a class-based TF-IDF.
Expand All @@ -3979,16 +3978,25 @@ def _extract_topics(
embeddings: The document embeddings
mappings: The mappings from topic to word
verbose: Whether to log the process of extracting topics
fine_tune_representation: If True, the topic representation will be fine-tuned using representation models.
If False, the topic representation will remain as the base c-TF-IDF representation.
Returns:
c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic
"""
if verbose:
logger.info("Representation - Extracting topics from clusters using representation models.")
action = "Fine-tuning" if fine_tune_representation else "Extracting"
method = "representation models" if fine_tune_representation else "c-TF-IDF"
logger.info(f"Representation - {action} topics using {method}.")

documents_per_topic = documents.groupby(["Topic"], as_index=False).agg({"Document": " ".join})
self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic)
self.topic_representations_ = self._extract_words_per_topic(words, documents)
self.topic_representations_ = self._extract_words_per_topic(
words, documents,
fine_tune_representation=fine_tune_representation,
calculate_aspects=fine_tune_representation)
self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)

if verbose:
logger.info("Representation - Completed \u2713")

Expand Down Expand Up @@ -4244,6 +4252,7 @@ def _extract_words_per_topic(
words: List[str],
documents: pd.DataFrame,
c_tf_idf: csr_matrix = None,
fine_tune_representation: bool = True,
calculate_aspects: bool = True,
) -> Mapping[str, List[Tuple[str, float]]]:
"""Based on tf_idf scores per topic, extract the top n words per topic.
Expand All @@ -4257,6 +4266,8 @@ def _extract_words_per_topic(
words: List of all words (sorted according to tf_idf matrix position)
documents: DataFrame with documents and their topic IDs
c_tf_idf: A c-TF-IDF matrix from which to calculate the top words
fine_tune_representation: If True, the topic representation will be fine-tuned using representation models.
If False, the topic representation will remain as the base c-TF-IDF representation.
calculate_aspects: Whether to calculate additional topic aspects
Returns:
Expand Down Expand Up @@ -4287,15 +4298,15 @@ def _extract_words_per_topic(

# Fine-tune the topic representations
topics = base_topics.copy()
if not self.representation_model:
if not self.representation_model or not fine_tune_representation:
# Default representation: c_tf_idf + top_n_words
topics = {label: values[: self.top_n_words] for label, values in topics.items()}
elif isinstance(self.representation_model, list):
elif fine_tune_representation and isinstance(self.representation_model, list):
for tuner in self.representation_model:
topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, BaseRepresentation):
elif fine_tune_representation and isinstance(self.representation_model, BaseRepresentation):
topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
elif isinstance(self.representation_model, dict):
elif fine_tune_representation and isinstance(self.representation_model, dict):
if self.representation_model.get("Main"):
main_model = self.representation_model["Main"]
if isinstance(main_model, BaseRepresentation):
Expand Down Expand Up @@ -4339,21 +4350,17 @@ def _reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> p
documents: Dataframe with documents and their corresponding IDs and Topics
use_ctfidf: Whether to calculate distances between topics based on c-TF-IDF embeddings. If False, semantic
embeddings are used.
Returns:
documents: Updated dataframe with documents and the reduced number of Topics
"""
logger.info("Topic reduction - Reducing number of topics")
initial_nr_topics = len(documents["Topic"].unique())
initial_nr_topics = len(self.get_topics())

if isinstance(self.nr_topics, int):
if self.nr_topics < initial_nr_topics:
documents = self._reduce_to_n_topics(documents, use_ctfidf)
else:
logger.info(
f"Topic reduction - Number of topics ({self.nr_topics}) is equal or higher than the clustered topics({len(documents['Topic'].unique())})."
)
documents = self._sort_mappings_by_frequency(documents)
logger.info(f"Topic reduction - Number of topics ({self.nr_topics}) is equal or higher than the clustered topics({len(self.get_topics())}).")
self._extract_topics(documents, verbose=self.verbose)
return documents
elif isinstance(self.nr_topics, str):
Expand Down

0 comments on commit 1e2ec70

Please sign in to comment.