add clip model for encoder
peng3307165 committed May 15, 2024
1 parent 1d34f35 commit a08a739
Showing 9 changed files with 163 additions and 6 deletions.
5 changes: 4 additions & 1 deletion .gitignore
@@ -139,4 +139,7 @@ dmypy.json

*.ini

**/multicache_serving.py
**/multicache_serving.py
**/modelcache_serving.py

**/model/
12 changes: 12 additions & 0 deletions model/clip_zh/__init__.py
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""
Alipay.com Inc.
Copyright (c) 2004-2023 All Rights Reserved.
------------------------------------------------------
File Name   : __init__.py
Author : fuhui.phe
Create Time : 2024/5/7 14:05
Description : package initializer for the local Chinese CLIP (clip_zh) model directory
Change Activity:
version0 : 2024/5/7 14:05 by fuhui.phe init
"""
7 changes: 5 additions & 2 deletions modelcache/adapter/adapter_query.py
@@ -30,17 +30,20 @@ def adapt_query(cache_data_convert, *args, **kwargs):
report_func=chat_cache.report.embedding,
)(pre_embedding_data)

# print('embedding_data: {}'.format(embedding_data))

if cache_enable:
cache_data_list = time_cal(
chat_cache.data_manager.search,
func_name="milvus_search",
func_name="vector_search",
report_func=chat_cache.report.search,
)(
embedding_data,
extra_param=context.get("search_func", None),
top_k=kwargs.pop("top_k", -1),
model=model
)
print('cache_data_list: {}'.format(cache_data_list))
cache_answers = []
cache_questions = []
cache_ids = []
@@ -78,8 +81,8 @@ def adapt_query(cache_data_convert, *args, **kwargs):
return

for cache_data in cache_data_list:
print('cache_data: {}'.format(cache_data))
primary_id = cache_data[1]
start_time = time.time()
ret = chat_cache.data_manager.get_scalar_data(
cache_data, extra_param=context.get("get_scalar_data", None)
)
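For context, time_cal is used in these hunks to wrap both the embedding call and the vector search. A minimal sketch of such a timing wrapper, inferred from the call sites above; the actual ModelCache implementation may differ:

import time


def time_cal(func, func_name=None, report_func=None):
    # Return a wrapper that times each call to func and forwards the
    # latency to report_func (e.g. chat_cache.report.search).
    def inner(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        elapsed = time.time() - start
        if report_func is not None:
            report_func(elapsed)  # assumed to accept elapsed seconds
        return result
    return inner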
2 changes: 1 addition & 1 deletion modelcache/core.py
@@ -4,7 +4,7 @@
from modelcache.processor.post import first
from modelcache.similarity_evaluation import ExactMatchEvaluation
from modelcache.similarity_evaluation import SimilarityEvaluation
from modelcache.embedding.string import to_embeddings as string_embedding
from modelcache.embedding.string_text import to_embeddings as string_embedding
from modelcache.report import Report
from modelcache.config import Config
from modelcache.utils.cache_func import cache_all
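The import rename above suggests modelcache/embedding/string.py became string_text.py (presumably one of the two renames listed further down). A minimal sketch of the interface core.py expects, assuming the renamed module keeps the old pass-through contract; its actual contents are not shown in this commit:

def to_embeddings(data, **_):
    # Assumed pass-through: the raw string itself serves as the
    # "embedding" for exact string matching.
    return data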
90 changes: 90 additions & 0 deletions modelcache/embedding/clip.py
@@ -0,0 +1,90 @@
# -*- coding: utf-8 -*-
import torch
from modelcache.embedding.base import BaseEmbedding
from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline
from modelscope.preprocessors.image import load_image


class ClipAudio(BaseEmbedding):
    """Image/text encoder backed by the Chinese CLIP pipeline from ModelScope."""

    def __init__(self, model: str = 'damo/multi-modal_clip-vit-base-patch16_zh'):
        self.clip_pipeline = pipeline(task=Tasks.multi_modal_embedding,
                                      model=model, model_revision='v1.0.1')
        self.__dimension = 1024

    def to_embeddings(self, data_dict, **_):
        """Encode {'text': List[str], 'image': URL-or-path} into feature vectors."""
        text_list = data_dict['text']
        image_data = data_dict['image']

        if not image_data:
            raise ValueError('image_data is None, please check!')
        input_img = load_image(image_data)
        # 2D tensor, [num_images, feature_dim]; keep the first (only) row.
        img_embedding = self.clip_pipeline.forward({'img': input_img})['img_embedding'].tolist()[0]

        if not text_list:
            raise ValueError('text_list is None, please check!')
        # 2D tensor, [num_texts, feature_dim]; keep the first row.
        text_embedding = self.clip_pipeline.forward({'text': text_list})['text_embedding'].tolist()[0]

        return {'image_embedding': img_embedding, 'text_embeddings': text_embedding}

    def post_proc(self, token_embeddings, inputs):
        """Mean-pool token embeddings, weighted by the attention mask."""
        attention_mask = inputs["attention_mask"]
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        sentence_embs = torch.sum(
            token_embeddings * input_mask_expanded, 1
        ) / torch.clamp(input_mask_expanded.sum(1), min=1e-9)
        return sentence_embs

    @property
    def dimension(self):
        """Embedding dimension.
        :return: embedding dimension
        """
        return self.__dimension


if __name__ == '__main__':
    clip_vec = ClipAudio()
    text_list = ['hello', '你好']
    text = ['###'.join(text_list)]
    image = 'https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg'
    data_dict = {'text': text, 'image': image}
    resp = clip_vec.to_embeddings(data_dict)
    print('resp: {}'.format(resp))
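One plausible way the new encoder plugs into the query path from adapter_query.py above: that file wraps an embedding callable with time_cal and invokes it on pre_embedding_data. Assuming that callable is the cache's embedding function, the encoder's to_embeddings method has the right shape for it. The wiring and sample query values below are illustrative assumptions, not part of this commit:

from modelcache.embedding.clip import ClipAudio

clip_encoder = ClipAudio()

# Hypothetical hook-up: use the encoder as the cache's embedding function.
embedding_func = clip_encoder.to_embeddings

query = {'text': ['hello###你好'],  # hypothetical sample query
         'image': 'https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg'}
embedding_data = embedding_func(query)
print(embedding_data['image_embedding'][:5], embedding_data['text_embeddings'][:5])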
49 changes: 49 additions & 0 deletions modelcache/embedding/clip_demo.py
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
"""
Alipay.com Inc.
Copyright (c) 2004-2023 All Rights Reserved.
------------------------------------------------------
File Name : clip_demo.py
Author : fuhui.phe
Create Time : 2024/5/7 11:58
Description : standalone demo of the Chinese CLIP multi-modal embedding pipeline
Change Activity:
version0 : 2024/5/7 11:58 by fuhui.phe init
"""
import torch
from modelscope.utils.constant import Tasks
from modelscope.pipelines import pipeline
from modelscope.preprocessors.image import load_image


clip_pipeline = pipeline(task=Tasks.multi_modal_embedding,
                         model='damo/multi-modal_clip-vit-base-patch16_zh', model_revision='v1.0.1')


# Accepts a URL (the sample Pikachu image) or a local path; returns a PIL.Image.
input_img = load_image('https://clip-cn-beijing.oss-cn-beijing.aliyuncs.com/pokemon.jpeg')


input_texts = ["杰尼龟", "妙蛙种子", "小火龙", "皮卡丘"]

# Accepts one image (PIL.Image) or several (List[PIL.Image]); returns normalized feature vectors.
img_embedding = clip_pipeline.forward({'img': input_img})['img_embedding']  # 2D tensor, [num_images, feature_dim]
print('img_embedding: {}'.format(img_embedding))

# Accepts one text (str) or several (List[str]); returns normalized feature vectors.
text_embedding = clip_pipeline.forward({'text': input_texts})['text_embedding']  # 2D tensor, [num_texts, feature_dim]

# Compute image-text similarity.
with torch.no_grad():
    # The inner product gives the logits, scaled by the model temperature.
    logits_per_image = (img_embedding / clip_pipeline.model.temperature) @ text_embedding.t()
    # Softmax over the logits yields a probability distribution.
    probs = logits_per_image.softmax(dim=-1).cpu().numpy()

print("Image-text matching probabilities:", probs)

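Since the pipeline returns normalized feature vectors, the logits above are just temperature-scaled cosine similarities. A quick optional sanity check (an illustrative addition, not part of the commit):

# Each embedding row should have an L2 norm of approximately 1.0.
print(img_embedding.norm(dim=-1))
print(text_embedding.norm(dim=-1))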

File renamed without changes.
File renamed without changes.
4 changes: 2 additions & 2 deletions requirements.txt
@@ -8,7 +8,7 @@ pymilvus==2.3.1
PyMySQL==1.1.0
Requests==2.31.0
torch==2.1.0
transformers==4.34.1
transformers==4.38.2
faiss-cpu==1.7.4
redis==5.0.1

modelscope==1.14.0
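Optionally, a quick runtime check that the pinned versions above are the ones actually installed (an illustrative snippet, not part of the commit):

import modelscope
import torch
import transformers

print(modelscope.__version__)    # expect 1.14.0
print(torch.__version__)         # expect 2.1.0
print(transformers.__version__)  # expect 4.38.2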
