Skip to content

Commit

Permalink
fix: embed and upsert_documents methods
Browse files: browse the repository at this point in the history
  • Loading branch information
KevKibe committed Apr 29, 2024
1 parent 9290a26 commit cf14b20
Showing 1 changed file with 13 additions and 5 deletions.
18 changes: 13 additions & 5 deletions src/docindex/doc_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def tiktoken_len(self, text: str) -> int:
)
return len(tokens)

def embed(self, sample_text: str):
def embed(self, sample_text: str = None):
"""
Generates embeddings for the provided sample text using either Google's Generative AI or OpenAI.
Expand All @@ -136,20 +136,23 @@ def embed(self, sample_text: str):
# Google Generative AI
if self.google_api_key:
genai.configure(api_key=self.google_api_key)
return genai.embed_content(
embed = genai.embed_content(
model='models/embedding-001',
content=sample_text,
task_type="retrieval_document"
)
return embed

# OpenAI Embeddings
elif self.openai_api_key:
return OpenAIEmbeddings(
embed = OpenAIEmbeddings(
openai_api_key=self.openai_api_key
)
return embed
elif self.cohere_api_key:
return CohereEmbeddings(model_name = "embed-english-light-v3.0",
embed = CohereEmbeddings(model_name = "embed-english-light-v3.0",
cohere_api_key=self.cohere_api_key)
return embed
else:
raise ValueError("A valid API key for either Google, Cohere or OpenAI must be provided to generate embeddings.")

Expand Down Expand Up @@ -189,10 +192,15 @@ def upsert_documents(self, documents: List[Page], batch_limit: int, chunk_size:
ids = [str(uuid4()) for _ in range(len(texts))]
embeddings = None
if self.google_api_key:
embeddings = self.embed(texts)['embedding']
embeds = self.embed(texts)
embeddings = embeds['embedding']
elif self.openai_api_key:
embed = self.embed()
embeddings = embed.embed_documents(texts)
elif self.cohere_api_key:
embed = self.embed()
embeddings = embed.embed_query(texts)

if embeddings is not None:
index = self.pc.Index(self.index_name)
index.upsert(vectors=zip(ids, embeddings , metadatas), async_req=True)
Expand Down

0 comments on commit cf14b20

Please sign in to comment.