Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transform method handles a single document with provided embeddings #2043

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions bertopic/_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,10 +497,14 @@ def transform(self,
```
"""
check_is_fitted(self)
check_embeddings_shape(embeddings, documents)


if isinstance(documents, str) or documents is None:
documents = [documents]

if len(documents) == 1 and isinstance(embeddings, np.ndarray) and embeddings.ndim == 1:
embeddings = embeddings.reshape(1,-1)
Comment on lines +504 to +505
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The embeddings that you can pass to `.transform` are typed as `np.ndarray`, which is actually quite misleading (sorry!), as they can technically take the form of any iterable (but mostly just `np.ndarray`-like structures, such as a scipy sparse matrix). As such, doing something like `embeddings.ndim == 1` might break here.


check_embeddings_shape(embeddings, documents)

if embeddings is None:
embeddings = self._extract_embeddings(documents,
Expand Down
24 changes: 24 additions & 0 deletions tests/test_bertopic.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,27 @@ def test_full_model(model, documents, request):

assert len(merged_model.get_topic_info()) > len(topic_model1.get_topic_info())
assert len(merged_model.get_topic_info()) > len(topic_model.get_topic_info())

def test_transform_flexibility(documents, document_embeddings, request):
    """Check which document/embedding combinations ``transform`` accepts.

    Matching inputs (one document with one embedding, or two documents with
    two embeddings) must not raise ``ValueError``; mismatched inputs must.

    Args:
        documents: Fixture providing the test documents; indexed and sliced here.
        document_embeddings: Fixture providing one embedding per document
            (presumably a 2-D array, so ``document_embeddings[0]`` is a single
            1-D vector -- confirm against the fixture definition).
        request: Pytest request object, used to look up the
            ``base_topic_model`` fixture by name.
    """
    # Work on a deep copy so this test cannot mutate the shared fixture model.
    topic_model = copy.deepcopy(request.getfixturevalue('base_topic_model'))

    # One document paired with its own single embedding must be accepted.
    try:
        topic_model.transform(documents[0], document_embeddings[0])
    except ValueError:
        pytest.fail('Error thrown for transform with single document and embeddings')

    # Several documents paired with the same number of embeddings must be accepted.
    try:
        topic_model.transform(documents[0:2], document_embeddings[0:2])
    except ValueError:
        pytest.fail('Error thrown for transform with multiple documents and embeddings')

    # One document but two embeddings: the counts disagree.
    with pytest.raises(ValueError):
        topic_model.transform(documents[0], document_embeddings[0:2])

    # Two documents but only one embedding: the counts disagree.
    with pytest.raises(ValueError):
        topic_model.transform(documents[0:2], document_embeddings[0])

    # One document with a plain Python list passed as embeddings is expected to be rejected.
    with pytest.raises(ValueError):
        topic_model.transform(documents[0], [1, 2, 3])