diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index 90379b86..37ad71e3 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -478,21 +478,20 @@ def fit_transform(
         if documents.Document.values[0] is None:
             custom_documents = self._images_to_text(documents, embeddings)

-            # Extract topics by calculating c-TF-IDF
-            self._extract_topics(custom_documents, embeddings=embeddings)
-            self._create_topic_vectors(documents=documents, embeddings=embeddings)
-
-            # Reduce topics
+            # Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
+            self._extract_topics(custom_documents, embeddings=embeddings, calculate_representation=not self.nr_topics)
             if self.nr_topics:
                 custom_documents = self._reduce_topics(custom_documents)
+            self._create_topic_vectors(documents=documents, embeddings=embeddings)

             # Save the top 3 most representative documents per topic
             self._save_representative_docs(custom_documents)
-        else:
-            # Extract topics by calculating c-TF-IDF
-            self._extract_topics(documents, embeddings=embeddings, verbose=self.verbose)

-            # Reduce topics
+        else:
+            # Extract topics by calculating c-TF-IDF, reduce topics if needed, and get representations.
+            self._extract_topics(
+                documents, embeddings=embeddings, verbose=self.verbose, calculate_representation=not self.nr_topics
+            )
             if self.nr_topics:
                 documents = self._reduce_topics(documents)

@@ -3972,6 +3971,7 @@ def _extract_topics(
         embeddings: np.ndarray = None,
         mappings=None,
         verbose: bool = False,
+        calculate_representation: bool = True,
     ):
         """Extract topics from the clusters using a class-based TF-IDF.

@@ -3980,18 +3980,29 @@ def _extract_topics(
             embeddings: The document embeddings
             mappings: The mappings from topic to word
             verbose: Whether to log the process of extracting topics
+            calculate_representation: Whether to extract the topic representations

        Returns:
            c_tf_idf: The resulting matrix giving a value (importance score) for each word per topic
        """
        if verbose:
-            logger.info("Representation - Extracting topics from clusters using representation models.")
+            action = "Representation" if calculate_representation else "Topics"
+            logger.info(
+                f"{action} - Extracting topics from clusters{' using representation models' if calculate_representation else ''}."
+            )
+
        documents_per_topic = documents.groupby(["Topic"], as_index=False).agg({"Document": " ".join})
        self.c_tf_idf_, words = self._c_tf_idf(documents_per_topic)
-        self.topic_representations_ = self._extract_words_per_topic(words, documents)
+        self.topic_representations_ = self._extract_words_per_topic(
+            words,
+            documents,
+            calculate_representation=calculate_representation,
+            calculate_aspects=calculate_representation,
+        )
        self._create_topic_vectors(documents=documents, embeddings=embeddings, mappings=mappings)
+
        if verbose:
-            logger.info("Representation - Completed \u2713")
+            logger.info(f"{action} - Completed \u2713")

    def _save_representative_docs(self, documents: pd.DataFrame):
        """Save the 3 most representative docs per topic.
@@ -4245,6 +4256,7 @@ def _extract_words_per_topic(
         words: List[str],
         documents: pd.DataFrame,
         c_tf_idf: csr_matrix = None,
+        calculate_representation: bool = True,
         calculate_aspects: bool = True,
     ) -> Mapping[str, List[Tuple[str, float]]]:
         """Based on tf_idf scores per topic, extract the top n words per topic.
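Taken together, the hunks above thread a single `calculate_representation` flag through the pipeline: when `nr_topics` is set, the first `_extract_topics` pass keeps only the raw c-TF-IDF top-n words and defers the representation models (and aspects) until after topic reduction. Below is a minimal, self-contained sketch of that gating, not the BERTopic source; `fine_tune` and `extract_topics_sketch` are illustrative names standing in for the representation-model chain and the patched extraction step.

```python
from typing import Dict, List, Optional, Tuple

Topics = Dict[int, List[Tuple[str, float]]]

def fine_tune(topics: Topics) -> Topics:
    # Stand-in for the costly representation-model chain (e.g. KeyBERTInspired or an LLM);
    # upper-casing is just a visible marker that the expensive pass ran.
    return {label: [(word.upper(), score) for word, score in words] for label, words in topics.items()}

def extract_topics_sketch(base_topics: Topics, nr_topics: Optional[int], top_n_words: int = 10) -> Topics:
    # Mirrors `calculate_representation=not self.nr_topics` in the patched fit_transform:
    # skip fine-tuning when a reduction pass will recompute representations anyway.
    calculate_representation = not nr_topics
    topics = {label: words[:top_n_words] for label, words in base_topics.items()}
    if calculate_representation:
        topics = fine_tune(topics)
    return topics

base = {0: [("cat", 0.9), ("dog", 0.7)], 1: [("car", 0.8)]}
assert extract_topics_sketch(base, nr_topics=2)[0] == [("cat", 0.9), ("dog", 0.7)]     # cheap pass
assert extract_topics_sketch(base, nr_topics=None)[0] == [("CAT", 0.9), ("DOG", 0.7)]  # full pass
```

The net effect is that the representation models, which may include LLM calls, run once per fit instead of once before and once after reduction. The remaining hunks handle `_extract_words_per_topic` and the reduction paths.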
@@ -4258,6 +4270,7 @@ def _extract_words_per_topic(
            words: List of all words (sorted according to tf_idf matrix position)
            documents: DataFrame with documents and their topic IDs
            c_tf_idf: A c-TF-IDF matrix from which to calculate the top words
+            calculate_representation: Whether to calculate the topic representations
            calculate_aspects: Whether to calculate additional topic aspects

        Returns:
@@ -4288,15 +4301,15 @@ def _extract_words_per_topic(

        # Fine-tune the topic representations
        topics = base_topics.copy()
-        if not self.representation_model:
+        if not self.representation_model or not calculate_representation:
            # Default representation: c_tf_idf + top_n_words
            topics = {label: values[: self.top_n_words] for label, values in topics.items()}
-        elif isinstance(self.representation_model, list):
+        elif calculate_representation and isinstance(self.representation_model, list):
            for tuner in self.representation_model:
                topics = tuner.extract_topics(self, documents, c_tf_idf, topics)
-        elif isinstance(self.representation_model, BaseRepresentation):
+        elif calculate_representation and isinstance(self.representation_model, BaseRepresentation):
            topics = self.representation_model.extract_topics(self, documents, c_tf_idf, topics)
-        elif isinstance(self.representation_model, dict):
+        elif calculate_representation and isinstance(self.representation_model, dict):
            if self.representation_model.get("Main"):
                main_model = self.representation_model["Main"]
                if isinstance(main_model, BaseRepresentation):
@@ -4350,6 +4363,13 @@ def _reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False) -> p
        if isinstance(self.nr_topics, int):
            if self.nr_topics < initial_nr_topics:
                documents = self._reduce_to_n_topics(documents, use_ctfidf)
+            else:
+                logger.info(
+                    f"Topic reduction - Number of topics ({self.nr_topics}) is equal to or higher than the number of clustered topics ({len(self.get_topics())})."
+                )
+                documents = self._sort_mappings_by_frequency(documents)
+                self._extract_topics(documents, verbose=self.verbose)
+                return documents
        elif isinstance(self.nr_topics, str):
            documents = self._auto_reduce_topics(documents, use_ctfidf)
        else:
@@ -4412,7 +4432,7 @@ def _reduce_to_n_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)

        # Update representations
        documents = self._sort_mappings_by_frequency(documents)
-        self._extract_topics(documents, mappings=mappings)
+        self._extract_topics(documents, mappings=mappings, verbose=self.verbose)
        self._update_topic_size(documents)

        return documents
@@ -4468,7 +4488,7 @@ def _auto_reduce_topics(self, documents: pd.DataFrame, use_ctfidf: bool = False)
        # Update documents and topics
        self.topic_mapper_.add_mappings(mapped_topics, topic_model=self)
        documents = self._sort_mappings_by_frequency(documents)
-        self._extract_topics(documents, mappings=mappings)
+        self._extract_topics(documents, mappings=mappings, verbose=self.verbose)
        self._update_topic_size(documents)

        return documents
diff --git a/tests/test_representation/test_representations.py b/tests/test_representation/test_representations.py
index 7c819964..fa756625 100644
--- a/tests/test_representation/test_representations.py
+++ b/tests/test_representation/test_representations.py
@@ -153,6 +153,7 @@ def test_topic_reduction_edge_cases(model, documents, request):
     topics = np.random.randint(-1, nr_topics - 1, len(documents))
     old_documents = pd.DataFrame({"Document": documents, "ID": range(len(documents)), "Topic": topics})
     topic_model._update_topic_size(old_documents)
+    old_documents = topic_model._sort_mappings_by_frequency(old_documents)
     topic_model._extract_topics(old_documents)
     old_freq = topic_model.get_topic_freq()
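The `_reduce_topics` else-branch closes the gap the new gating opens: if the requested `nr_topics` already equals or exceeds the number of clustered topics, no merging happens, so the method now logs that fact, re-sorts the mappings by frequency, and re-runs `_extract_topics` itself; without this, the cheap first pass would be left as the final representation. The test change mirrors that new call order by sorting `old_documents` with `_sort_mappings_by_frequency` before calling `_extract_topics` directly. A hedged usage sketch of the edge case follows; `BERTopic` and `fetch_20newsgroups` are the real APIs, while the corpus slice and the `nr_topics` value are illustrative.

```python
from sklearn.datasets import fetch_20newsgroups
from bertopic import BERTopic

docs = fetch_20newsgroups(subset="train", remove=("headers", "footers", "quotes"))["data"][:1000]

# Request more topics than clustering is likely to find on this small sample.
# Pre-patch, fine-tuned representations were computed in the first extraction pass
# and the reduction step silently did nothing; post-patch, the first pass is cheap,
# _reduce_topics logs that nr_topics is equal to or higher than the clustered count,
# and the representation models still run exactly once.
topic_model = BERTopic(nr_topics=500, verbose=True)
topics, probs = topic_model.fit_transform(docs)
```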