Skip to content

Commit

Permalink
Fix embedding utils
Browse files — browse the repository at this point in the history
  • Loading branch information
wwxxzz committed Dec 6, 2024
1 parent b95c8a7 commit 8ee52c9
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 13 deletions.
10 changes: 5 additions & 5 deletions src/pai_rag/integrations/embeddings/clip/cnclip_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
self,
model_name: str = DEFAULT_CNCLIP_MODEL,
embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE,
model_path: str = DEFAULT_MODEL_DIR,
**kwargs: Any,
) -> None:
super().__init__(
Expand All @@ -38,13 +39,10 @@ def __init__(
raise ValueError(f"Unknown ChineseClip model: {model_name}.")

self._device = "cuda" if torch.cuda.is_available() else "cpu"
pai_rag_model_dir = os.getenv("PAI_RAG_MODEL_DIR", DEFAULT_MODEL_DIR)
self._model, self._preprocess = load_from_name(
self.model_name,
device=self._device,
download_root=os.path.join(
pai_rag_model_dir, "chinese-clip-vit-large-patch14"
),
download_root=model_path,
)
self._model.eval()

Expand Down Expand Up @@ -97,7 +95,9 @@ def _get_image_embedding(self, img_file_path: ImageType) -> Embedding:


if __name__ == "__main__":
clip_embedding = CnClipEmbedding()
clip_embedding = CnClipEmbedding(
os.path.join(DEFAULT_MODEL_DIR, "chinese-clip-vit-large-patch14")
)

image_embedding = clip_embedding.get_image_embedding(
"example_data/cn_clip/pokemon.jpeg"
Expand Down
32 changes: 24 additions & 8 deletions src/pai_rag/integrations/embeddings/pai/embedding_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,35 +36,51 @@ def create_embedding(embed_config: PaiBaseEmbeddingConfig):
f"Initialized DashScope embedding model with {embed_config.embed_batch_size} batch size."
)
elif isinstance(embed_config, HuggingFaceEmbeddingConfig):
pai_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository")
pai_model_name = os.path.join(pai_model_dir, embed_config.model)
if not os.path.exists(pai_model_name):
pai_rag_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository")
pai_model_path = os.path.join(pai_rag_model_dir, embed_config.model)
if not os.path.exists(pai_model_path):
logger.info(
f"Embedding model {embed_config.model} not found in {pai_model_dir}, try download it."
f"Embedding model {embed_config.model} not found in {pai_rag_model_dir}, try download it."
)
download_models = ModelScopeDownloader(
fetch_config=True, download_directory_path=pai_model_dir
fetch_config=True, download_directory_path=pai_rag_model_dir
)
download_models.load_model(model=embed_config.model)
logger.info(
f"Embedding model {embed_config.model} downloaded to {pai_model_name}."
f"Embedding model {embed_config.model} downloaded to {pai_model_path}."
)
embed_model = HuggingFaceEmbedding(
model_name=pai_model_name,
model_name=pai_model_path,
embed_batch_size=embed_config.embed_batch_size,
trust_remote_code=True,
callback_manager=Settings.callback_manager,
)

logger.info(
f"Initialized HuggingFace embedding model {embed_config.model} from model_dir_path {pai_model_dir} with {embed_config.embed_batch_size} batch size."
f"Initialized HuggingFace embedding model {embed_config.model} from model_dir_path {pai_rag_model_dir} with {embed_config.embed_batch_size} batch size."
)

elif isinstance(embed_config, CnClipEmbeddingConfig):
pai_rag_model_dir = os.getenv("PAI_RAG_MODEL_DIR", "./model_repository")
pai_model_path = os.path.join(
pai_rag_model_dir, "chinese-clip-vit-large-patch14"
)
if not os.path.exists(pai_model_path):
logger.info(
f"Embedding model {embed_config.model} not found in {pai_rag_model_dir}, try download it."
)
download_models = ModelScopeDownloader(
fetch_config=True, download_directory_path=pai_rag_model_dir
)
download_models.load_model(model="chinese-clip-vit-large-patch14")
logger.info(
f"Embedding model {embed_config.model} downloaded to {pai_model_path}."
)
embed_model = CnClipEmbedding(
model_name=embed_config.model,
embed_batch_size=embed_config.embed_batch_size,
callback_manager=Settings.callback_manager,
model_path=pai_model_path,
)
logger.info(
f"Initialized CnClip embedding model {embed_config.model} with {embed_config.embed_batch_size} batch size."
Expand Down

0 comments on commit 8ee52c9

Please sign in to comment.