From 8ee52c96415acb3e0fb99c243fdccb64f3b56845 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=AD=B1=E6=96=87?= Date: Fri, 6 Dec 2024 18:23:44 +0800 Subject: [PATCH] Fix embedding utils --- .../embeddings/clip/cnclip_embedding.py | 10 +++--- .../embeddings/pai/embedding_utils.py | 32 ++++++++++++++----- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py b/src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py index b69c372d..841e8d62 100644 --- a/src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py +++ b/src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py @@ -26,6 +26,7 @@ def __init__( self, model_name: str = DEFAULT_CNCLIP_MODEL, embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, + model_path: str = DEFAULT_MODEL_DIR, **kwargs: Any, ) -> None: super().__init__( @@ -38,13 +39,10 @@ def __init__( raise ValueError(f"Unknown ChineseClip model: {model_name}.") self._device = "cuda" if torch.cuda.is_available() else "cpu" - pai_rag_model_dir = os.getenv("PAI_RAG_MODEL_DIR", DEFAULT_MODEL_DIR) self._model, self._preprocess = load_from_name( self.model_name, device=self._device, - download_root=os.path.join( - pai_rag_model_dir, "chinese-clip-vit-large-patch14" - ), + download_root=model_path, ) self._model.eval() @@ -97,7 +95,9 @@ def _get_image_embedding(self, img_file_path: ImageType) -> Embedding: if __name__ == "__main__": - clip_embedding = CnClipEmbedding() + clip_embedding = CnClipEmbedding( + os.path.join(DEFAULT_MODEL_DIR, "chinese-clip-vit-large-patch14") + ) image_embedding = clip_embedding.get_image_embedding( "example_data/cn_clip/pokemon.jpeg" diff --git a/src/pai_rag/integrations/embeddings/pai/embedding_utils.py b/src/pai_rag/integrations/embeddings/pai/embedding_utils.py index 2c0f3bd8..ab47dc1d 100644 --- a/src/pai_rag/integrations/embeddings/pai/embedding_utils.py +++ b/src/pai_rag/integrations/embeddings/pai/embedding_utils.py @@ -36,35 +36,51 @@ def create_embedding(embed_config: PaiBaseEmbeddingConfig): f"Initialized DashScope embedding model with {embed_config.embed_batch_size} batch size." ) elif isinstance(embed_config, HuggingFaceEmbeddingConfig): - pai_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository") - pai_model_name = os.path.join(pai_model_dir, embed_config.model) - if not os.path.exists(pai_model_name): + pai_rag_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository") + pai_model_path = os.path.join(pai_rag_model_dir, embed_config.model) + if not os.path.exists(pai_model_path): logger.info( - f"Embedding model {embed_config.model} not found in {pai_model_dir}, try download it." + f"Embedding model {embed_config.model} not found in {pai_rag_model_dir}, try download it." ) download_models = ModelScopeDownloader( - fetch_config=True, download_directory_path=pai_model_dir + fetch_config=True, download_directory_path=pai_rag_model_dir ) download_models.load_model(model=embed_config.model) logger.info( - f"Embedding model {embed_config.model} downloaded to {pai_model_name}." + f"Embedding model {embed_config.model} downloaded to {pai_model_path}." ) embed_model = HuggingFaceEmbedding( - model_name=pai_model_name, + model_name=pai_model_path, embed_batch_size=embed_config.embed_batch_size, trust_remote_code=True, callback_manager=Settings.callback_manager, ) logger.info( - f"Initialized HuggingFace embedding model {embed_config.model} from model_dir_path {pai_model_dir} with {embed_config.embed_batch_size} batch size." + f"Initialized HuggingFace embedding model {embed_config.model} from model_dir_path {pai_rag_model_dir} with {embed_config.embed_batch_size} batch size." ) elif isinstance(embed_config, CnClipEmbeddingConfig): + pai_rag_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository") + pai_model_path = os.path.join( + pai_rag_model_dir, "chinese-clip-vit-large-patch14" + ) + if not os.path.exists(pai_model_path): + logger.info( + f"Embedding model {embed_config.model} not found in {pai_rag_model_dir}, try download it." + ) + download_models = ModelScopeDownloader( + fetch_config=True, download_directory_path=pai_rag_model_dir + ) + download_models.load_model(model="chinese-clip-vit-large-patch14") + logger.info( + f"Embedding model {embed_config.model} downloaded to {pai_model_path}." + ) embed_model = CnClipEmbedding( model_name=embed_config.model, embed_batch_size=embed_config.embed_batch_size, callback_manager=Settings.callback_manager, + model_path=pai_model_path, ) logger.info( f"Initialized CnClip embedding model {embed_config.model} with {embed_config.embed_batch_size} batch size."