From ddb5f8e97b8d8c08a54c990f1e1f1913fccdb89d Mon Sep 17 00:00:00 2001 From: Henri Lemoine Date: Tue, 22 Aug 2023 01:12:10 -0400 Subject: [PATCH] renamed embed_query to get_embedding --- align_data/common/utils.py | 5 ++--- align_data/pinecone/pinecone_db_handler.py | 4 ++-- align_data/pinecone/pinecone_models.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/align_data/common/utils.py b/align_data/common/utils.py index 76716d94..5fd0afcf 100644 --- a/align_data/common/utils.py +++ b/align_data/common/utils.py @@ -79,7 +79,6 @@ def compute_openai_embeddings(non_flagged_texts, engine, **kwargs): data = openai.Embedding.create(input=non_flagged_texts, engine=engine, **kwargs).data return [d["embedding"] for d in data] -from openai.embeddings_utils import get_embeddings def get_embeddings( texts: List[str], embed_all: bool = False, @@ -126,8 +125,8 @@ def get_embeddings( return final_embeddings -def embed_query(query: str, engine=OPENAI_EMBEDDINGS_MODEL, **kwargs) -> List[float]: - return get_embeddings([query], engine=engine, **kwargs)[0] +def get_embedding(text: str, **kwargs) -> List[float]: + return get_embeddings(texts=[text], **kwargs)[0] def get_recursive_type(obj): diff --git a/align_data/pinecone/pinecone_db_handler.py b/align_data/pinecone/pinecone_db_handler.py index d1bff5bf..98081449 100644 --- a/align_data/pinecone/pinecone_db_handler.py +++ b/align_data/pinecone/pinecone_db_handler.py @@ -4,7 +4,7 @@ import pinecone -from align_data.common.utils import get_embeddings, embed_query +from align_data.common.utils import get_embedding from align_data.pinecone.pinecone_models import PineconeEntry, PineconeMatch, PineconeMetadata from align_data.settings import ( PINECONE_INDEX_NAME, @@ -82,7 +82,7 @@ def query_text( include_values: bool = False, include_metadata: bool = True, **kwargs ) -> List[PineconeMatch]: - query_vector = embed_query(query) + query_vector = get_embedding(query) return self.query_vector( query=query_vector, top_k=top_k, include_values=include_values, include_metadata=include_metadata, **kwargs diff --git a/align_data/pinecone/pinecone_models.py b/align_data/pinecone/pinecone_models.py index 59077dc7..6ac6379a 100644 --- a/align_data/pinecone/pinecone_models.py +++ b/align_data/pinecone/pinecone_models.py @@ -23,7 +23,7 @@ class PineconeEntry(BaseModel): source: str title: str url: str - date_published: int + date_published: float authors: List[str] text_chunks: List[str] embeddings: List[List[float]]