From 8f189d5de5197ba43f9bdab5d0da80aa9274e995 Mon Sep 17 00:00:00 2001 From: Sumaiyah Date: Fri, 22 Mar 2024 16:28:00 +0000 Subject: [PATCH] [ENH] Add optional kwargs when initialising SentenceTransformerEmbeddingFunction class (#1891) ## Description of changes *Summarize the changes made by this PR.* - Improvements & Bug fixes - Add optional kwargs for `SetenceTransformer` when initialising `SentenceTransformerEmbeddingFunction` class (Issue [#1857](https://github.com/chroma-core/chroma/issues/1857)) ## Test plan *How are these changes tested?* - [x] Tests pass locally with `pytest` for python - installing chroma as an editable package locally and testing with the code ```python import chromadb from chromadb.utils import embedding_functions sentence_transformer_ef = embedding_functions.SentenceTransformerEmbeddingFunction(prompts={"query": "query: ", "passage": "passage: "}) print(sentence_transformer_ef.models['all-MiniLM-L6-v2'].prompts) ``` returned ```bash {'query': 'query: ', 'passage': 'passage: '} ``` ## Documentation Changes *Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs repository](https://github.com/chroma-core/docs)?* I have added the documentation for `SentenceTransformerEmbeddingFunction` initialisation. Co-authored-by: sumaiyah --- chromadb/utils/embedding_functions.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/chromadb/utils/embedding_functions.py b/chromadb/utils/embedding_functions.py index 5e98588538d..da5f1591f1c 100644 --- a/chromadb/utils/embedding_functions.py +++ b/chromadb/utils/embedding_functions.py @@ -61,7 +61,16 @@ def __init__( model_name: str = "all-MiniLM-L6-v2", device: str = "cpu", normalize_embeddings: bool = False, + **kwargs: Any ): + """Initialize SentenceTransformerEmbeddingFunction. + + Args: + model_name (str, optional): Identifier of the SentenceTransformer model, defaults to "all-MiniLM-L6-v2" + device (str, optional): Device used for computation, defaults to "cpu" + normalize_embeddings (bool, optional): Whether to normalize returned vectors, defaults to False + **kwargs: Additional arguments to pass to the SentenceTransformer model. + """ if model_name not in self.models: try: from sentence_transformers import SentenceTransformer @@ -69,7 +78,7 @@ def __init__( raise ValueError( "The sentence_transformers python package is not installed. Please install it with `pip install sentence_transformers`" ) - self.models[model_name] = SentenceTransformer(model_name, device=device) + self.models[model_name] = SentenceTransformer(model_name, device=device, **kwargs) self._model = self.models[model_name] self._normalize_embeddings = normalize_embeddings