diff --git a/geniml/search/query2vec/text2vec.py b/geniml/search/query2vec/text2vec.py index 3b9a854..d867670 100644 --- a/geniml/search/query2vec/text2vec.py +++ b/geniml/search/query2vec/text2vec.py @@ -2,7 +2,9 @@ from typing import Union import numpy as np -from langchain_huggingface.embeddings import HuggingFaceEmbeddings + +# from langchain_huggingface.embeddings import HuggingFaceEmbeddings +from fastembed import TextEmbedding from ...const import PKG_NAME from ...text2bednn import Vec2VecFNN @@ -20,7 +22,7 @@ def __init__(self, hf_repo: str, v2v: Union[str, Vec2VecFNN, None]): :param v2v: a Vec2VecFNN (see geniml/text2bednn/text2bednn.py) or a model repository on Hugging Face """ # Set model that embed natural language - self.text_embedder = HuggingFaceEmbeddings(model_name=hf_repo) + self.text_embedder = TextEmbedding(model_name=hf_repo) # Set model that maps natural language embeddings into the embedding space of region sets if isinstance(v2v, Vec2VecFNN): self.v2v = v2v @@ -39,7 +41,7 @@ def forward(self, query: str) -> np.ndarray: :return: the embedding vector of query """ # embed query string - query_embedding = np.array(self.text_embedder.embed_query(query)) + query_embedding = list(self.text_embedder.embed(query))[0] if self.v2v is None: return query_embedding else: diff --git a/tests/test_search.py b/tests/test_search.py index c056514..612f40e 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -9,6 +9,7 @@ from geniml.search import BED2BEDSearchInterface, BED2Vec, Text2BEDSearchInterface, Text2Vec from geniml.search.backends import BiVectorBackend, HNSWBackend, QdrantBackend from geniml.search.backends.filebackend import DEP_HNSWLIB +from geniml.search.interfaces.mlfree import BiVectorSearchInterface DATA_FOLDER_PATH = os.path.join( os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "data" @@ -274,7 +275,7 @@ def query_bed(): def cosine_similarity(vec1: np.array, vec2: np.array) -> float: # Ensure the vectors have shape (100,) - assert vec1.shape == (100,) and vec2.shape == (100,), "Both vectors must have shape (100,)" + assert vec1.shape == (100,) and vec2.shape == (100,) # Compute the dot product of the two vectors dot_product = np.dot(vec1, vec2) @@ -474,6 +475,10 @@ def test_HNSWBackend_save(filenames, bed_hnswb, bed_embeddings, temp_bed_idx_pat "not config.getoption('--qdrant')", reason="Only run when --qdrant is given", ) +@pytest.mark.skipif( + "not config.getoption('--huggingface')", + reason="Only run when --huggingface is given", +) def test_BiVectorBackend( bed_hnswb, metadata_hnswb, @@ -483,16 +488,9 @@ def test_BiVectorBackend( metadata_collection, text_embeddings, metadata_payloads, + nl_embed_repo, ): - def bivec_test(bivec_backend, dist: bool = False, rank: bool = False): - query_vec = np.random.random( - 384, - ) - search_results = bivec_backend.search( - query_vec, 2, with_payload=True, with_vectors=True, distance=dist, rank=rank - ) - assert isinstance(search_results, list) - assert len(search_results) == 2 + def search_result_test(search_results: Dict, rank: bool): min_score = 100.0 max_rank = -1 for result in search_results: @@ -517,6 +515,18 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False): assert isinstance(result["payload"]["name"], str) assert isinstance(result["payload"]["metadata"], dict) + def bivec_test(bivec_backend, dist: bool = False, rank: bool = False): + query_vec = np.random.random( + 384, + ) + search_results = bivec_backend.search( + query_vec, 2, with_payload=True, with_vectors=True, distance=dist, rank=rank + ) + assert isinstance(search_results, list) + assert len(search_results) == 2 + + search_result_test(search_results, rank) + # test QdrantBackend bed_backend = QdrantBackend(collection=bed_collection) # load data @@ -528,6 +538,20 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False): bivec_qd_backend = BiVectorBackend(text_backend, bed_backend) bivec_test(bivec_qd_backend, rank=True) bivec_test(bivec_qd_backend, rank=False) + # test QdrantBackend + Interface + bivec_qd_interface = BiVectorSearchInterface(bivec_qd_backend, nl_embed_repo) + interface_result = bivec_qd_interface.query_search( + "lung cancer cell line", 2, with_payload=True, with_vectors=True, rank=True + ) + + search_result_test(interface_result, rank=True) + + interface_result = bivec_qd_interface.query_search( + "lung cancer cell line", 2, with_payload=True, with_vectors=True, rank=False + ) + + search_result_test(interface_result, rank=False) + bivec_qd_backend.metadata_backend.qd_client.delete_collection(text_backend.collection) bivec_qd_backend.bed_backend.qd_client.delete_collection(bed_backend.collection) @@ -536,6 +560,19 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False): bivec_hnsw_backend = BiVectorBackend(metadata_hnswb, bed_hnswb) bivec_test(bivec_hnsw_backend, dist=True, rank=True) bivec_test(bivec_hnsw_backend, dist=True, rank=False) + # test HNSWBackend + Interface + bivec_hnsw_interface = BiVectorSearchInterface(bivec_hnsw_backend, nl_embed_repo) + interface_result = bivec_hnsw_interface.query_search( + "lung cancer cell line", 2, with_payload=True, with_vectors=True, distance=True, rank=True + ) + + search_result_test(interface_result, rank=True) + + interface_result = bivec_hnsw_interface.query_search( + "lung cancer cell line", 2, with_payload=True, with_vectors=True, distance=True, rank=False + ) + + search_result_test(interface_result, rank=False) @pytest.mark.skipif(