diff --git a/caikit_nlp/config/config.yml b/caikit_nlp/config/config.yml index 6f440a22..8eb5a500 100644 --- a/caikit_nlp/config/config.yml +++ b/caikit_nlp/config/config.yml @@ -38,6 +38,8 @@ training_data_limit: # Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32 embedding: + # Allow models with remote code. + trust_remote_code: false # Number of times to retry on error. Most deployments should use 0 retries. retries: 0 # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used diff --git a/caikit_nlp/modules/text_embedding/embedding.py b/caikit_nlp/modules/text_embedding/embedding.py index efa97728..97fcb7ac 100644 --- a/caikit_nlp/modules/text_embedding/embedding.py +++ b/caikit_nlp/modules/text_embedding/embedding.py @@ -22,6 +22,7 @@ Dict, Iterable, List, + Literal, NamedTuple, Optional, TypeVar, @@ -82,6 +83,8 @@ sentence_transformers = importlib.import_module("sentence_transformers") # Third Party from sentence_transformers import SentenceTransformer + from sentence_transformers.model_card import SentenceTransformerModelCardData + from sentence_transformers.similarity_functions import SimilarityFunction from sentence_transformers.util import batch_to_device, cos_sim, dot_score from sentence_transformers.util import ( normalize_embeddings as normalize, # avoid parameter shadowing @@ -107,6 +110,7 @@ def __init__(self, *args, **kwargs): # pylint: disable=unused-argument val=embedding_cfg.get("implicit_truncation_errors", True) ) DEVICE = embedding_cfg.get("device", "") +TRUST_REMOTE_CODE = embedding_cfg.get("trust_remote_code") RT = TypeVar("RT") # return type @@ -183,7 +187,9 @@ def load(cls, model_path: str, *args, **kwargs) -> "EmbeddingModule": ipex = cls._get_ipex(IPEX) device = cls._select_device(ipex, DEVICE) model = SentenceTransformerWithTruncate( - model_name_or_path=artifacts_path, device=device + model_name_or_path=artifacts_path, + device=device, + trust_remote_code=TRUST_REMOTE_CODE, ) model.eval() # required for IPEX at least if device is not None: @@ -719,7 +725,12 @@ def bootstrap(cls, model_name_or_path: str) -> "EmbeddingModule": model_name_or_path: str Model name (Hugging Face hub) or path to model to load. """ - return cls(model=SentenceTransformer(model_name_or_path=model_name_or_path)) + return cls( + model=SentenceTransformer( + model_name_or_path=model_name_or_path, + trust_remote_code=TRUST_REMOTE_CODE, + ) + ) def save(self, model_path: str, *args, **kwargs): """Save model using config in model_path @@ -875,21 +886,39 @@ def __init__( model_name_or_path: Optional[str] = None, modules: Optional[Iterable[nn.Module]] = None, device: Optional[str] = None, + prompts: Optional[Dict[str, str]] = None, + default_prompt_name: Optional[str] = None, + similarity_fn_name: Optional[Union[str, SimilarityFunction]] = None, cache_folder: Optional[str] = None, trust_remote_code: bool = False, revision: Optional[str] = None, + local_files_only: bool = False, token: Optional[Union[bool, str]] = None, use_auth_token: Optional[Union[bool, str]] = None, + truncate_dim: Optional[int] = None, + model_kwargs: Optional[Dict[str, Any]] = None, + tokenizer_kwargs: Optional[Dict[str, Any]] = None, + config_kwargs: Optional[Dict[str, Any]] = None, + model_card_data: Optional[SentenceTransformerModelCardData] = None, ): super().__init__( model_name_or_path, modules, device, + prompts, + default_prompt_name, + similarity_fn_name, cache_folder, trust_remote_code, revision, + local_files_only, token, use_auth_token, + truncate_dim, + model_kwargs, + tokenizer_kwargs, + config_kwargs, + model_card_data, ) self.tokenizers = {} @@ -1014,9 +1043,12 @@ def _get_tokenized(self, texts): def encode( self, sentences: Union[str, List[str]], + prompt_name: Optional[str] = None, + prompt: Optional[str] = None, batch_size: int = 32, show_progress_bar: bool = None, output_value: str = "sentence_embedding", + precision: Literal["float32", "int8", "uint8", "binary", "ubinary"] = "float32", convert_to_numpy: bool = True, convert_to_tensor: bool = False, device: str = None, @@ -1029,9 +1061,12 @@ def encode( Computes sentence embeddings :param sentences: the sentences to embed + :param prompt_name: Ignored here. Added for compatibility with super API. + :param prompt: Ignored here. Added for compatibility with super API. :param batch_size: the batch size used for the computation :param show_progress_bar: Ignored here. Added for compatibility with super API. :param output_value: Ignored here. Added for compatibility with super API. + :param precision: Ignored here. Added for compatibility with super API. :param convert_to_numpy: If true, the output is a list of numpy vectors. Else, it is a list of pytorch tensors. :param convert_to_tensor: If true, you get one large tensor as return. Overwrites any @@ -1057,8 +1092,11 @@ def encode( # These args are for API compatability, but are currently ignored in our version of encode() _ = ( + prompt_name, + prompt, show_progress_bar, output_value, + precision, normalize_embeddings, ) diff --git a/pyproject.toml b/pyproject.toml index e688a4b5..3a2464f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,11 +27,11 @@ dependencies = [ "pandas>=1.5.0", "scikit-learn>=1.1", "scipy>=1.8.1", - "sentence-transformers>=2.3.1,<2.4.0", + "sentence-transformers>=3.0.0,<3.1.0", "tokenizers>=0.13.3", "torch>=2.3.1,<2.4.0", "tqdm>=4.65.0", - "transformers>=4.32.0", + "transformers>=4.38.0,<4.44.0", "peft==0.6.0", ] diff --git a/runtime_config.yaml b/runtime_config.yaml index cbd27421..b88fb545 100644 --- a/runtime_config.yaml +++ b/runtime_config.yaml @@ -44,6 +44,8 @@ model_management: # Config used only in EmbeddingModule. Set here or use env vars like EMBEDDING_RETRIES=32 embedding: + # Allow models with remote code. + trust_remote_code: false # Number of times to retry on error. Most deployments should use 0 retries. retries: 0 # Batch size for encode() if <= 0 or invalid, the sentence-transformers default is used diff --git a/tox.ini b/tox.ini index c220f022..df4f5d15 100644 --- a/tox.ini +++ b/tox.ini @@ -15,6 +15,7 @@ passenv = LOG_FORMATTER LOG_THREAD_ID LOG_CHANNEL_WIDTH + PYTORCH_ENABLE_MPS_FALLBACK commands = pytest --durations=42 --cov=caikit_nlp --cov-report=term --cov-report=html {posargs:tests} ; Unclear: We probably want to test wheel packaging