Skip to content

Commit

Permalink
Merge pull request #5 from SCAI-BIO/mpnet-performance-improvement
Browse files Browse the repository at this point in the history
MPNet Performance Improvement
  • Loading branch information
tiadams authored Feb 22, 2024
2 parents 3bce331 + d00cfe3 commit e66c3f0
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
12 changes: 7 additions & 5 deletions index/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,25 +39,27 @@ def get_embeddings(self, messages: [str], model="text-embedding-ada-002"):


class MPNetAdapter(EmbeddingModel):
def __init__(self):
def __init__(self, model="sentence-transformers/all-mpnet-base-v2"):
logging.getLogger().setLevel(logging.INFO)
self.mpnet_model = SentenceTransformer(model)

def get_embedding(self, text: str, model="sentence-transformers/all-mpnet-base-v2"):
mpnet_model = SentenceTransformer(model)
def get_embedding(self, text: str):
logging.info(f"Getting embedding for {text}")
try:
if text is None or text == "" or text is np.nan:
logging.warn(f"Empty text passed to get_embedding")
return None
if isinstance(text, str):
text = text.replace("\n", " ")
return mpnet_model.encode(text)
return self.mpnet_model.encode(text)
except Exception as e:
logging.error(f"Error getting embedding for {text}: {e}")
return None

def get_embeddings(self, messages: [str]) -> [[float]]:
return [self.get_embedding(msg) for msg in messages]
embeddings = self.mpnet_model.encode(messages)
flattened_embeddings = [[float(element) for element in row] for row in embeddings]
return flattened_embeddings


class TextEmbedding:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_embedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
from index.embedding import MPNetAdapter, TextEmbedding
import numpy as np

class TestEmbedding(unittest.TestCase):

def setUp(self):
self.mpnet_adapter = MPNetAdapter(model="sentence-transformers/all-mpnet-base-v2")

def test_mpnet_adapter_get_embedding(self):
text = "This is a test sentence."
embedding = self.mpnet_adapter.get_embedding(text)
self.assertIsInstance(embedding, np.ndarray)
self.assertEqual(len(embedding), 768)

def test_mpnet_adapter_get_embeddings(self):
messages = ["This is message 1.", "This is message 2."]
embeddings = self.mpnet_adapter.get_embeddings(messages)
self.assertIsInstance(embeddings, list)
self.assertEqual(len(embeddings), len(messages))
self.assertEqual(len(embeddings[0]), 768)

def test_text_embedding(self):
text = "This is a test sentence."
embedding = [0.1, 0.2, 0.3, 0.4]
text_embedding = TextEmbedding(text, embedding)
self.assertEqual(text_embedding.text, text)
self.assertEqual(text_embedding.embedding, embedding)

0 comments on commit e66c3f0

Please sign in to comment.