From 51599017818c5e7483547d00ff529f2d868c8ed8 Mon Sep 17 00:00:00 2001
From: 1jamesthompson1 <1jamesthompson1@gmail.com>
Date: Wed, 29 May 2024 18:09:51 +1200
Subject: [PATCH] Better handling of single documents and embeddings being given to the transform function.

Wrap a lone document in a list before validating shapes, and reshape a
1-D embedding that accompanies a single document into a one-row matrix.
The embeddings shape check now runs after both normalisation steps.

Add a test that calls transform with matching single and multiple
document/embedding pairs and expects ValueError for mismatched inputs.
---
 bertopic/_bertopic.py  |  8 +++++++-
 tests/test_bertopic.py | 23 +++++++++++++++++++++++
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/bertopic/_bertopic.py b/bertopic/_bertopic.py
index de57c35a..67955605 100644
--- a/bertopic/_bertopic.py
+++ b/bertopic/_bertopic.py
@@ -497,10 +497,16 @@ def transform(self,
         ```
         """
         check_is_fitted(self)
-        check_embeddings_shape(embeddings, documents)
 
         if isinstance(documents, str) or documents is None:
             documents = [documents]
 
+        # A single document may arrive with a 1-D embedding vector; reshape it
+        # into a one-row matrix before the shape check runs.
+        if len(documents) == 1 and isinstance(embeddings, np.ndarray) and embeddings.ndim == 1:
+            embeddings = embeddings.reshape(1, -1)
+
+        check_embeddings_shape(embeddings, documents)
+
         if embeddings is None:
             embeddings = self._extract_embeddings(documents,
diff --git a/tests/test_bertopic.py b/tests/test_bertopic.py
index 5d4bfac8..6d0bd01f 100644
--- a/tests/test_bertopic.py
+++ b/tests/test_bertopic.py
@@ -138,3 +138,26 @@ def test_full_model(model, documents, request):
 
     assert len(merged_model.get_topic_info()) > len(topic_model1.get_topic_info())
     assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())
+
+
+def test_transform_flexibility(documents, document_embeddings, request):
+    """Exercise transform with single and multiple document/embedding pairs."""
+    topic_model = copy.deepcopy(request.getfixturevalue('base_topic_model'))
+
+    # One document with the first embedding entry: must not raise.
+    topic_model.transform(documents[0], document_embeddings[0])
+
+    # Two documents with the first two embedding entries: must not raise.
+    topic_model.transform(documents[0:2], document_embeddings[0:2])
+
+    # One document but two embedding entries.
+    with pytest.raises(ValueError):
+        topic_model.transform(documents[0], document_embeddings[0:2])
+
+    # Two documents but a single embedding entry.
+    with pytest.raises(ValueError):
+        topic_model.transform(documents[0:2], document_embeddings[0])
+
+    # Embeddings supplied as a plain list.
+    with pytest.raises(ValueError):
+        topic_model.transform(documents[0], [1, 2, 3])