From a4ff4d420dc0ac3ffeef2574a0fb0f44c8adc0c0 Mon Sep 17 00:00:00 2001
From: achibb <42097962+achibb@users.noreply.github.com>
Date: Wed, 3 Apr 2024 14:38:27 +0200
Subject: [PATCH] Update SentenceTransformer.py to use token length

---
 sentence_transformers/SentenceTransformer.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/sentence_transformers/SentenceTransformer.py b/sentence_transformers/SentenceTransformer.py
index 3163bb10a..830617eb9 100644
--- a/sentence_transformers/SentenceTransformer.py
+++ b/sentence_transformers/SentenceTransformer.py
@@ -344,7 +344,7 @@ def encode(
         self.to(device)
 
         all_embeddings = []
-        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
+        length_sorted_idx = np.argsort([-self._token_length(sen) for sen in sentences])
         sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
 
         for start_index in trange(0, len(sentences), batch_size, desc="Batches", disable=not show_progress_bar):
@@ -861,6 +861,24 @@ def _text_length(self, text: Union[List[int], List[List[int]]]):
             return len(text)
         else:
             return sum([len(t) for t in text])  # Sum of length of individual strings
+
+    def _token_length(self, text: Union[str, List[str]]) -> int:
+        """
+        Helper function to get the token length for the input text. Text can be a
+        string (a single text), a list of strings (several texts), or pre-tokenized
+        input (a dict of features or a list of token ids).
+        """
+
+        if isinstance(text, dict):  # {key: value} case: already tokenized features
+            return len(next(iter(text.values())))
+        elif not hasattr(text, "__len__"):  # Object has no len() method
+            return 1
+        elif len(text) == 0 or isinstance(text[0], int):  # Empty input or a list of token ids
+            return len(text)
+        elif isinstance(text, str):  # A single text: tokenize it and count the tokens
+            return len(next(iter(self.tokenize([text])["input_ids"])))
+        else:
+            return sum([len(next(iter(self.tokenize([t])["input_ids"]))) for t in text])  # Sum over the individual texts
 
     def fit(
         self,
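
Note for reviewers: a quick way to see what this changes is to compare the character-based and token-based orderings side by side. This is only a sketch; "all-MiniLM-L6-v2" is an arbitrary example checkpoint, and both _text_length and _token_length are private helpers, used here purely for inspection:

    from sentence_transformers import SentenceTransformer

    # Any SentenceTransformer checkpoint works; this one is just an example.
    model = SentenceTransformer("all-MiniLM-L6-v2")

    sentences = [
        "short",
        "supercalifragilisticexpialidocious antidisestablishmentarianism",
        "a long sentence made of short common words repeated a few times over",
    ]

    # Character count and token count can rank the same sentences differently;
    # per-batch padding cost is driven by tokens, not characters.
    for s in sentences:
        print(model._text_length(s), model._token_length(s), repr(s))

Since encode() pads each batch to its longest member, sorting by token count keeps batches closer to their true padded width; the gain is largest when character length is a poor proxy for token length (e.g. rare words or non-Latin scripts).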