From 3c608b39f6b1aea0f9bb94ab62b7dfd27fc825b9 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Thu, 14 Mar 2024 15:33:18 -0700 Subject: [PATCH 1/3] Make encode() in wrapped model compatible with super encode() * Adding missing params * Don't return unexpected tuple (with token count) unless asked * Adding check to not use our params if given an unwrapped model * Fixing some param position things Signed-off-by: Mark Sturdevant --- .../modules/text_embedding/embedding.py | 79 ++++++++++++++----- 1 file changed, 61 insertions(+), 18 deletions(-) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 14132f96..fc25c235 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -66,13 +66,11 @@ sentence_transformers = importlib.import_module("sentence_transformers") # Third Party from sentence_transformers import SentenceTransformer + from sentence_transformers.util import batch_to_device, cos_sim, dot_score from sentence_transformers.util import ( - batch_to_device, - cos_sim, - dot_score, - normalize_embeddings, - semantic_search, + normalize_embeddings as normalize, # avoid parameter shadowing ) + from sentence_transformers.util import semantic_search except ModuleNotFoundError: # When it is not available, create a dummy that raises an error on attempted init() class SentenceTransformerNotAvailable: @@ -271,7 +269,9 @@ def _with_retry(self, fn: Callable[..., RT], *args, **kwargs) -> RT: exception=first_exception, ) - def _encode_with_retry(self, *args, **kwargs) -> EmbeddingResultTuple: + def _encode_with_retry( + self, *args, **kwargs + ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]: """All encode calls should use this for consistent param adding and retry loop""" # Add the batch_size kwarg if not passed in and given a usable BATCH_SIZE @@ -281,6 +281,16 @@ def _encode_with_retry(self, *args, **kwargs) -> EmbeddingResultTuple: if "batch_size" not in kwargs: kwargs["batch_size"] = BATCH_SIZE + if isinstance(self.model, SentenceTransformerWithTruncate): + return self._with_retry(self.model.encode, *args, **kwargs) + + # Else... + # It's possible to init with a model that doesn't have the added kwargs. + # E.g. a SentenceTransformer or other transormer model. Remove those kwargs! + # This is not the normal use case but at least don't pass invalid kwargs, to encode() + # and don't return the unexpected tuple (adding token count). + del kwargs["truncate_input_tokens"] + del kwargs["return_token_count"] return self._with_retry(self.model.encode, *args, **kwargs) @EmbeddingTask.taskmethod() @@ -306,7 +316,9 @@ def run_embedding( error.type_check("", str, text=text) embeddings, input_token_count = self._encode_with_retry( - text, truncate_input_tokens=truncate_input_tokens + text, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, ) return EmbeddingResult( result=Vector1D.from_vector(embeddings), @@ -341,7 +353,9 @@ def run_embeddings( texts = [texts] embeddings, input_token_count = self._encode_with_retry( - texts, truncate_input_tokens=truncate_input_tokens + texts, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, ) vectors = [Vector1D.from_vector(e) for e in embeddings] @@ -375,10 +389,14 @@ def run_sentence_similarity( """ source_embedding, source_token_count = self._encode_with_retry( - source_sentence, truncate_input_tokens=truncate_input_tokens + source_sentence, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, ) embeddings, sentences_token_count = self._encode_with_retry( - sentences, truncate_input_tokens=truncate_input_tokens + sentences, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, ) input_token_count = source_token_count + sentences_token_count @@ -415,10 +433,14 @@ def run_sentence_similarities( """ source_embedding, source_token_count = self._encode_with_retry( - source_sentences, truncate_input_tokens=truncate_input_tokens + source_sentences, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, ) embeddings, sentences_token_count = self._encode_with_retry( - sentences, truncate_input_tokens=truncate_input_tokens + sentences, + truncate_input_tokens=truncate_input_tokens, + return_token_count=True, ) input_token_count = source_token_count + sentences_token_count @@ -582,16 +604,18 @@ def get_text(doc): doc_embeddings, doc_token_count = self._encode_with_retry( doc_texts, truncate_input_tokens=truncate_input_tokens, + return_token_count=True, convert_to_tensor=True, ) - doc_embeddings = normalize_embeddings(doc_embeddings.to(self.model.device)) + doc_embeddings = normalize(doc_embeddings.to(self.model.device)) query_embeddings, query_token_count = self._encode_with_retry( queries, truncate_input_tokens=truncate_input_tokens, + return_token_count=True, convert_to_tensor=True, ) - query_embeddings = normalize_embeddings(query_embeddings.to(self.model.device)) + query_embeddings = normalize(query_embeddings.to(self.model.device)) res = semantic_search( query_embeddings, doc_embeddings, top_k=top_n, score_function=dot_score @@ -837,32 +861,47 @@ def encode( self, sentences: Union[str, List[str]], batch_size: int = 32, - device: Optional[str] = None, + show_progress_bar: bool = None, + output_value: str = "sentence_embedding", convert_to_numpy: bool = True, convert_to_tensor: bool = False, + device: str = None, + normalize_embeddings: bool = False, truncate_input_tokens: int = 0, - ) -> EmbeddingResultTuple: + return_token_count: bool = False, + ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]: """ Computes sentence embeddings :param sentences: the sentences to embed :param batch_size: the batch size used for the computation - :param device: Which torch.device to use for the computation + :param show_progress_bar: Ignored here. Added for compatibility with super API. + :param output_value: Ignored here. Added for compatibility with super API. :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors. :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any setting from convert_to_numpy + :param device: Which torch.device to use for the computation + :param normalize_embeddings: Ignored here. Added for compatibility with super API. :param truncate_input_tokens: Truncation length for input tokens. Truncation length for input tokens. If less than zero, this truncation is left up to the tokenizer default (model max). If zero or greater than the model's maximum, then this is used as a test to see if truncation is needed. If needed is needed, an exception is thrown. Otherwise, we take this usable truncation limit to truncate the input tokens. + :param return_token_count: If true, a tuple is returned to add the input token count. :return: A tuple of the embedding, as a numpy matrix, and the input_token_count int. """ + # These args are for API compatability, but are currently ignored in our version of encode() + _ = ( + show_progress_bar, + output_value, + normalize_embeddings, + ) + self.eval() if convert_to_tensor: @@ -931,4 +970,8 @@ def encode( if input_was_string: all_embeddings = all_embeddings[0] - return EmbeddingResultTuple(all_embeddings, input_token_count) + return ( + EmbeddingResultTuple(all_embeddings, input_token_count) + if return_token_count + else all_embeddings + ) From c6ed5b26b89e12adf8cffc32f32413ff1b3ba0f9 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Sat, 16 Mar 2024 16:03:51 -0700 Subject: [PATCH 2/3] Catch KeyErrors before deleting keys. Add configurable default truncation behavior. * First draft could KeyError when deleting kwargs that don't exist. Tests added. * Adding a config option so the desired default behavior can be either: - Throw an error if truncation is happening implicitly, or - Nah. Just let it go. First was requested, so trunction does not happen quietly. This is common behavoir for some models. The second is more aligned with SentenceTransformers and is probably necessary for standard tests to run without errors. Signed-off-by: Mark Sturdevant --- caikit_nlp/config/config.yml | 2 + .../modules/text_embedding/embedding.py | 38 +++++++-- runtime_config.yaml | 2 + .../modules/text_embedding/test_embedding.py | 78 ++++++++++++++++++- 4 files changed, 109 insertions(+), 11 deletions(-) diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml index fd25a18c..f6e7f026 100644 --- a/caikit_nlp/config/config.yml +++ b/caikit_nlp/config/config.yml @@ -42,6 +42,8 @@ embedding: retries: 0 # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used batch_size: 0 + # Should implicit truncation (with truncate_input_tokens=0) throw error for truncation (default) or disable this + implicit_truncation_errors: true # Attempt to optimize with PyTorch compile() pt2_compile: false # Use IPEX optimize. Works best when used with autocast (bfloat16) below. diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index fc25c235..912b8835 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -87,6 +87,9 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument PT2_COMPILE = env_val_to_bool(val=embedding_cfg.get("pt2_compile")) RETRIES = env_val_to_int(val=embedding_cfg.get("retries"), default=0) BATCH_SIZE = env_val_to_int(val=embedding_cfg.get("batch_size"), default=0) +NO_IMPLICIT_TRUNCATION = env_val_to_bool( + val=embedding_cfg.get("implicit_truncation_errors", True) +) DEVICE = embedding_cfg.get("device", "") RT = TypeVar("RT") # return type @@ -282,15 +285,22 @@ def _encode_with_retry( kwargs["batch_size"] = BATCH_SIZE if isinstance(self.model, SentenceTransformerWithTruncate): + kwargs[ + "implicit_truncation_errors" + ] = NO_IMPLICIT_TRUNCATION # config/env overrides default return self._with_retry(self.model.encode, *args, **kwargs) # Else... # It's possible to init with a model that doesn't have the added kwargs. - # E.g. a SentenceTransformer or other transormer model. Remove those kwargs! + # E.g. a SentenceTransformer or other transformer model. Remove those kwargs! # This is not the normal use case but at least don't pass invalid kwargs, to encode() # and don't return the unexpected tuple (adding token count). - del kwargs["truncate_input_tokens"] - del kwargs["return_token_count"] + if "truncate_input_tokens" in kwargs: + del kwargs["truncate_input_tokens"] + if "return_token_count" in kwargs: + del kwargs["return_token_count"] + if "implicit_truncation_errors" in kwargs: + del kwargs["implicit_truncation_errors"] return self._with_retry(self.model.encode, *args, **kwargs) @EmbeddingTask.taskmethod() @@ -778,7 +788,10 @@ def sum_token_count( class SentenceTransformerWithTruncate(SentenceTransformer): def _truncate_input_tokens( - self, truncate_input_tokens: int, texts: List[str] + self, + truncate_input_tokens: int, + texts: List[str], + implicit_truncation_errors: bool = True, ) -> TruncatedTokensTuple: """Truncate input tokens Args: @@ -790,6 +803,8 @@ def _truncate_input_tokens( Otherwise, we take this usable truncation limit to truncate the input tokens. texts: List[str] Input texts to be checked and optionally truncated. + implicit_truncation_errors: bool + Configuration indicates whether implicit truncation should be rejected. Returns: Tuple containing a dictionary of lists/arrays/tensors returned by the tokenizer, with proper truncation ('input_ids', 'attention_mask', etc.), and the input_token_count int. @@ -805,7 +820,7 @@ def _truncate_input_tokens( okay_to_truncate = True max_length = truncate_input_tokens else: - okay_to_truncate = False + okay_to_truncate = not implicit_truncation_errors max_length = max_tokens assert len(texts) > 0, "Cannot truncate nothing" @@ -869,6 +884,7 @@ def encode( normalize_embeddings: bool = False, truncate_input_tokens: int = 0, return_token_count: bool = False, + implicit_truncation_errors: bool = True, ) -> Union[EmbeddingResultTuple, List[torch.Tensor], np.ndarray, torch.Tensor]: """ Computes sentence embeddings @@ -887,12 +903,16 @@ def encode( Truncation length for input tokens. If less than zero, this truncation is left up to the tokenizer default (model max). If zero or greater than the model's maximum, then this is used as a test - to see if truncation is needed. If needed is needed, an exception is thrown. + to see if truncation is needed. If truncation is needed, an exception is thrown, + unless implicit_truncation_errors=False (see below). Otherwise, we take this usable truncation limit to truncate the input tokens. :param return_token_count: If true, a tuple is returned to add the input token count. + :param implicit_truncation_errors: If true (default) implicit truncation throws an error. + If false, the model default behavior or used. :return: - A tuple of the embedding, as a numpy matrix, and the input_token_count int. + If return_token_count is False, the embedding is returned as a numpy matrix. + If return_token_count is True, a tuple is returned with both the embedding and the input token count. """ # These args are for API compatability, but are currently ignored in our version of encode() @@ -938,7 +958,9 @@ def encode( for start_index in range(0, len(list_of_sentences), batch_size): sentences_batch = sentences_sorted[start_index : start_index + batch_size] features, token_count = self._truncate_input_tokens( - truncate_input_tokens, sentences_batch + truncate_input_tokens, + sentences_batch, + implicit_truncation_errors=implicit_truncation_errors, ) input_token_count += token_count diff --git a/runtime_config.yaml b/runtime_config.yaml index fcea5ed5..cbd27421 100644 --- a/runtime_config.yaml +++ b/runtime_config.yaml @@ -48,6 +48,8 @@ embedding: retries: 0 # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used batch_size: 0 + # Should implicit truncation (with truncate_input_tokens=0) throw error for truncation (default) or disable this + implicit_truncation_errors: true # Attempt to optimize with PyTorch compile() pt2_compile: false # Use IPEX optimize. Works best when used with autocast (bfloat16) below. diff --git a/tests/modules/text_embedding/test_embedding.py b/tests/modules/text_embedding/test_embedding.py index 5fdf2b1b..99c6bebe 100644 --- a/tests/modules/text_embedding/test_embedding.py +++ b/tests/modules/text_embedding/test_embedding.py @@ -1,7 +1,7 @@ """Tests for text embedding module""" # Standard -from typing import List +from typing import List, Tuple import os import tempfile @@ -170,8 +170,15 @@ def _assert_valid_scores(scores, type_tests={}): return type_tests -def test_bootstrap_reuse(): - assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap reuse error" +def test_bootstrap_model(loaded_model): + assert isinstance(BOOTSTRAPPED_MODEL, EmbeddingModule), "bootstrap model type" + assert ( + BOOTSTRAPPED_MODEL.model.__class__.__name__ == "SentenceTransformer" + ), "bootstrap model class name" + # worth noting that bootstrap does not wrap, but load does + assert ( + loaded_model.model.__class__.__name__ == "SentenceTransformerWithTruncate" + ), "loaded model class name" def test_save_load_and_run(): @@ -488,6 +495,49 @@ def test__truncate_input_tokens_raises(truncate_input_tokens, loaded_model): loaded_model.model.encode( sentences=[too_long], truncate_input_tokens=truncate_input_tokens ) + # Same behavior when implicit_truncation_errors is True (the default) + with pytest.raises(ValueError, match=f"({over} > {model_max})"): + loaded_model.model.encode( + sentences=[too_long], + truncate_input_tokens=truncate_input_tokens, + implicit_truncation_errors=True, + ) + # Different behavior when implicit_truncation_errors is False -- no error raised! + loaded_model.model.encode( + sentences=[too_long], + truncate_input_tokens=truncate_input_tokens, + implicit_truncation_errors=False, + ) + + +def test__implicit_truncation(loaded_model): + """Test that implicit truncation happens (when allowed)""" + model_max = loaded_model.model.max_seq_length + + too_long = "x " * (model_max - 1) # This will go over a little + extra_long = ( + too_long + + "more clever words that surely change the meaning of this text" + * (model_max - 1) + ) + + # Allowed truncation using default tokens (0) and config to disable the error. + res = loaded_model.model.encode( + sentences=[too_long], truncate_input_tokens=0, implicit_truncation_errors=False + ) + # Allowed truncation using model max + res_extra_max = loaded_model.model.encode( + sentences=[extra_long], truncate_input_tokens=loaded_model.model.max_seq_length + ) + # Allowed truncation using -1 to just let the model do its thing + res_extra_neg = loaded_model.model.encode( + sentences=[extra_long], truncate_input_tokens=-1 + ) + + # Demonstrating that when implicit truncation is allowed, sentence-transformers is quietly truncating at model max + # The simple too_long string of x's, is equivalent to the string with significantly different extra text (truncated) + assert np.allclose(res, res_extra_max) + assert np.allclose(res, res_extra_neg) def test_not_too_many_tokens(loaded_model): @@ -930,3 +980,25 @@ def test_get_sample_start_indexes(mapping, expected): "overflow_to_sample_mapping": torch.Tensor(mapping).type(torch.int8) } assert get_sample_start_indexes(mock_tokenized) == expected + + +def test_encode_extensions(loaded_model): + # loaded model can return_token_count + ret = loaded_model._encode_with_retry("text here", return_token_count=True) + assert isinstance(ret, Tuple) + assert isinstance(ret[0], np.ndarray) + assert isinstance(ret[1], int) + ret = loaded_model._encode_with_retry("text here", return_token_count=False) + assert isinstance(ret, np.ndarray) + + # Make sure use with un-wrapped SentenceTransformer model is unaffected by extended params or return tokens + ret = BOOTSTRAPPED_MODEL._encode_with_retry( + "text here", + return_token_count=True, + truncate_input_tokens=123, + implicit_truncation_errors=False, + ) + assert isinstance(ret, np.ndarray) + BOOTSTRAPPED_MODEL._encode_with_retry( + "text here" + ) # and no KeyError trying to remove non-existing keys From fc7d81e4927803f107d5926b0cd4138ee0fddca3 Mon Sep 17 00:00:00 2001 From: Mark Sturdevant Date: Sat, 16 Mar 2024 16:10:11 -0700 Subject: [PATCH 3/3] fmt fix Signed-off-by: Mark Sturdevant --- caikit_nlp/modules/text_embedding/embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index 912b8835..1c7c45aa 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -912,7 +912,8 @@ def encode( :return: If return_token_count is False, the embedding is returned as a numpy matrix. - If return_token_count is True, a tuple is returned with both the embedding and the input token count. + If return_token_count is True, a tuple is returned with both the embedding and + the input token count. """ # These args are for API compatability, but are currently ignored in our version of encode()