diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py index de57c35a..67955605 100644 --- a/bertopic/_bertopic.py +++ b/bertopic/_bertopic.py @@ -497,10 +497,14 @@ def transform(self, ``` """ check_is_fitted(self) - check_embeddings_shape(embeddings, documents) - + if isinstance(documents, str) or documents is None: documents = [documents] + + if len(documents) == 1 and isinstance(embeddings, np.ndarray) and embeddings.ndim == 1: + embeddings = embeddings.reshape(1,-1) + + check_embeddings_shape(embeddings, documents) if embeddings is None: embeddings = self._extract_embeddings(documents, diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py index 5d4bfac8..6d0bd01f 100644 --- a/tests/test_bertopic.py +++ b/tests/test_bertopic.py @@ -138,3 +138,27 @@ def test_full_model(model, documents, request): assert len(merged_model.get_topic_info()) > len(topic_model1.get_topic_info()) assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info()) + +def test_transform_flexibility(documents, document_embeddings, request): + + topic_model = copy.deepcopy(request.getfixturevalue('base_topic_model')) + print(document_embeddings[0].shape) + try: + topic_model.transform(documents[0], document_embeddings[0]) + except ValueError: + pytest.fail('Error thrown for transform with single document and embeddings') + + try: + topic_model.transform(documents[0:2], document_embeddings[0:2]) + except ValueError: + pytest.fail('Error thrown for transform with multiple documents and embeddings') + + with pytest.raises(ValueError): + topic_model.transform(documents[0], document_embeddings[0:2]) + + with pytest.raises(ValueError): + topic_model.transform(documents[0:2], document_embeddings[0]) + + with pytest.raises(ValueError): + topic_model.transform(documents[0], [1, 2, 3]) +