Skip to content

Commit

Permalink
Merge pull request #196 from databio/fastembed
Browse files Browse the repository at this point in the history
switch to fastembed?
  • Loading branch information
khoroshevskyi authored Dec 3, 2024
2 parents f3a9cfe + 7dc3e03 commit d315f9f
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 13 deletions.
8 changes: 5 additions & 3 deletions geniml/search/query2vec/text2vec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
from typing import Union

import numpy as np
from langchain_huggingface.embeddings import HuggingFaceEmbeddings

# from langchain_huggingface.embeddings import HuggingFaceEmbeddings
from fastembed import TextEmbedding

from ...const import PKG_NAME
from ...text2bednn import Vec2VecFNN
Expand All @@ -20,7 +22,7 @@ def __init__(self, hf_repo: str, v2v: Union[str, Vec2VecFNN, None]):
:param v2v: a Vec2VecFNN (see geniml/text2bednn/text2bednn.py) or a model repository on Hugging Face
"""
# Set model that embed natural language
self.text_embedder = HuggingFaceEmbeddings(model_name=hf_repo)
self.text_embedder = TextEmbedding(model_name=hf_repo)
# Set model that maps natural language embeddings into the embedding space of region sets
if isinstance(v2v, Vec2VecFNN):
self.v2v = v2v
Expand All @@ -39,7 +41,7 @@ def forward(self, query: str) -> np.ndarray:
:return: the embedding vector of query
"""
# embed query string
query_embedding = np.array(self.text_embedder.embed_query(query))
query_embedding = list(self.text_embedder.embed(query))[0]
if self.v2v is None:
return query_embedding
else:
Expand Down
57 changes: 47 additions & 10 deletions tests/test_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from geniml.search import BED2BEDSearchInterface, BED2Vec, Text2BEDSearchInterface, Text2Vec
from geniml.search.backends import BiVectorBackend, HNSWBackend, QdrantBackend
from geniml.search.backends.filebackend import DEP_HNSWLIB
from geniml.search.interfaces.mlfree import BiVectorSearchInterface

DATA_FOLDER_PATH = os.path.join(
os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "tests", "data"
Expand Down Expand Up @@ -274,7 +275,7 @@ def query_bed():

def cosine_similarity(vec1: np.array, vec2: np.array) -> float:
# Ensure the vectors have shape (100,)
assert vec1.shape == (100,) and vec2.shape == (100,), "Both vectors must have shape (100,)"
assert vec1.shape == (100,) and vec2.shape == (100,)

# Compute the dot product of the two vectors
dot_product = np.dot(vec1, vec2)
Expand Down Expand Up @@ -474,6 +475,10 @@ def test_HNSWBackend_save(filenames, bed_hnswb, bed_embeddings, temp_bed_idx_pat
"not config.getoption('--qdrant')",
reason="Only run when --qdrant is given",
)
@pytest.mark.skipif(
"not config.getoption('--huggingface')",
reason="Only run when --huggingface is given",
)
def test_BiVectorBackend(
bed_hnswb,
metadata_hnswb,
Expand All @@ -483,16 +488,9 @@ def test_BiVectorBackend(
metadata_collection,
text_embeddings,
metadata_payloads,
nl_embed_repo,
):
def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
query_vec = np.random.random(
384,
)
search_results = bivec_backend.search(
query_vec, 2, with_payload=True, with_vectors=True, distance=dist, rank=rank
)
assert isinstance(search_results, list)
assert len(search_results) == 2
def search_result_test(search_results: Dict, rank: bool):
min_score = 100.0
max_rank = -1
for result in search_results:
Expand All @@ -517,6 +515,18 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
assert isinstance(result["payload"]["name"], str)
assert isinstance(result["payload"]["metadata"], dict)

def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
    """
    Query a bi-vector backend with a random vector and validate the hits.

    :param bivec_backend: backend object whose ``search`` method is exercised
    :param dist: value forwarded to ``search`` as the ``distance`` argument
    :param rank: value forwarded to ``search`` as the ``rank`` argument and
        to ``search_result_test``
    """
    # Random query vector with 384 components
    random_query = np.random.random(384)
    hits = bivec_backend.search(
        random_query,
        2,
        with_payload=True,
        with_vectors=True,
        distance=dist,
        rank=rank,
    )

    # Exactly the two requested results must come back, as a list
    assert isinstance(hits, list)
    assert len(hits) == 2

    # Per-result checks are delegated to the shared helper
    search_result_test(hits, rank)

# test QdrantBackend
bed_backend = QdrantBackend(collection=bed_collection)
# load data
Expand All @@ -528,6 +538,20 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
bivec_qd_backend = BiVectorBackend(text_backend, bed_backend)
bivec_test(bivec_qd_backend, rank=True)
bivec_test(bivec_qd_backend, rank=False)
# test QdrantBackend + Interface
bivec_qd_interface = BiVectorSearchInterface(bivec_qd_backend, nl_embed_repo)
interface_result = bivec_qd_interface.query_search(
"lung cancer cell line", 2, with_payload=True, with_vectors=True, rank=True
)

search_result_test(interface_result, rank=True)

interface_result = bivec_qd_interface.query_search(
"lung cancer cell line", 2, with_payload=True, with_vectors=True, rank=False
)

search_result_test(interface_result, rank=False)

bivec_qd_backend.metadata_backend.qd_client.delete_collection(text_backend.collection)
bivec_qd_backend.bed_backend.qd_client.delete_collection(bed_backend.collection)

Expand All @@ -536,6 +560,19 @@ def bivec_test(bivec_backend, dist: bool = False, rank: bool = False):
bivec_hnsw_backend = BiVectorBackend(metadata_hnswb, bed_hnswb)
bivec_test(bivec_hnsw_backend, dist=True, rank=True)
bivec_test(bivec_hnsw_backend, dist=True, rank=False)
# test HNSWBackend + Interface
bivec_hnsw_interface = BiVectorSearchInterface(bivec_hnsw_backend, nl_embed_repo)
interface_result = bivec_hnsw_interface.query_search(
"lung cancer cell line", 2, with_payload=True, with_vectors=True, distance=True, rank=True
)

search_result_test(interface_result, rank=True)

interface_result = bivec_hnsw_interface.query_search(
"lung cancer cell line", 2, with_payload=True, with_vectors=True, distance=True, rank=False
)

search_result_test(interface_result, rank=False)


@pytest.mark.skipif(
Expand Down

0 comments on commit d315f9f

Please sign in to comment.