diff --git a/align_data/embeddings/pinecone/pinecone_db_handler.py b/align_data/embeddings/pinecone/pinecone_db_handler.py index 3cf32112..c5fb20de 100644 --- a/align_data/embeddings/pinecone/pinecone_db_handler.py +++ b/align_data/embeddings/pinecone/pinecone_db_handler.py @@ -2,7 +2,7 @@ import logging from typing import List, Tuple -import pinecone +from pinecone import Pinecone from pinecone.core.client.models import ScoredVector from align_data.embeddings.embedding_utils import get_embedding @@ -35,7 +35,7 @@ def __init__( self.values_dims = values_dims self.metric = metric - pinecone.init( + self.pinecone = Pinecone( api_key=PINECONE_API_KEY, environment=PINECONE_ENVIRONMENT, ) @@ -43,7 +43,7 @@ def __init__( if create_index: self.create_index() - self.index = pinecone.Index(index_name=self.index_name) + self.index = self.pinecone.Index(index_name=self.index_name) if log_index_stats: index_stats_response = self.index.describe_index_stats() @@ -118,7 +118,7 @@ def create_index(self, replace_current_index: bool = True): if replace_current_index: self.delete_index() - pinecone.create_index( + self.pinecone.create_index( name=self.index_name, dimension=self.values_dims, metric=self.metric, @@ -126,9 +126,9 @@ def create_index(self, replace_current_index: bool = True): ) def delete_index(self): - if self.index_name in pinecone.list_indexes(): + if self.index_name in self.pinecone.list_indexes(): logger.info(f"Deleting index '{self.index_name}'.") - pinecone.delete_index(self.index_name) + self.pinecone.delete_index(self.index_name) def get_embeddings_by_ids(self, ids: List[str]) -> List[Tuple[str, List[float] | None]]: """