diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py index 0896c46cc850f..8e25d841dcef2 100644 --- a/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py +++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-postgres/llama_index/vector_stores/postgres/base.py @@ -130,6 +130,7 @@ class PGVectorStore(BasePydanticVectorStore): user="postgres", table_name="paul_graham_essay", embed_dim=1536 # openai embedding dimension + vector_search_method="cosine_distance" # Optional specify vector search method. Default is cosine_distance. ) ``` """ @@ -272,6 +273,7 @@ def from_params( use_jsonb: bool = False, hnsw_kwargs: Optional[Dict[str, Any]] = None, create_engine_kwargs: Optional[Dict[str, Any]] = None, + vector_search_method: Optional[str] = "cosine_distance", ) -> "PGVectorStore": """Construct from params. @@ -296,6 +298,7 @@ def from_params( contains "hnsw_ef_construction", "hnsw_ef_search", "hnsw_m", and optionally "hnsw_dist_method". Defaults to None, which turns off HNSW search. create_engine_kwargs (Optional[Dict[str, Any]], optional): Engine parameters to pass to create_engine. Defaults to None. + vector_search_method (Optional[str], optional): Vector search method. Defaults to cosine_distance. Returns: PGVectorStore: Instance of PGVectorStore constructed from params. @@ -307,6 +310,7 @@ def from_params( async_conn_str = async_connection_string or ( f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}" ) + cls.vector_search_method = vector_search_method return cls( connection_string=conn_str, async_connection_string=async_conn_str, @@ -597,13 +601,72 @@ def _build_query( ) -> Any: from sqlalchemy import select, text - stmt = select( # type: ignore - self._table_class.id, - self._table_class.node_id, - self._table_class.text, - self._table_class.metadata_, - self._table_class.embedding.cosine_distance(embedding).label("distance"), - ).order_by(text("distance asc")) + match self.vector_search_method: + case "cosine_distance": + stmt = select( # type: ignore + self._table_class.id, + self._table_class.node_id, + self._table_class.text, + self._table_class.metadata_, + self._table_class.embedding.cosine_distance(embedding).label( + "distance" + ), + ).order_by(text("distance asc")) + + case "max_inner_product": + stmt = select( # type: ignore + self._table_class.id, + self._table_class.node_id, + self._table_class.text, + self._table_class.metadata_, + self._table_class.embedding.max_inner_product(embedding).label( + "distance" + ), + ).order_by(text("distance asc")) + + case "l2_distance": + stmt = select( # type: ignore + self._table_class.id, + self._table_class.node_id, + self._table_class.text, + self._table_class.metadata_, + self._table_class.embedding.l2_distance(embedding).label( + "distance" + ), + ).order_by(text("distance asc")) + + case "l1_distance": + stmt = select( # type: ignore + self._table_class.id, + self._table_class.node_id, + self._table_class.text, + self._table_class.metadata_, + self._table_class.embedding.l1_distance(embedding).label( + "distance" + ), + ).order_by(text("distance asc")) + + case "hamming_distance": + stmt = select( # type: ignore + self._table_class.id, + self._table_class.node_id, + self._table_class.text, + self._table_class.metadata_, + self._table_class.embedding.hamming_distance(embedding).label( + "distance" + ), + ).order_by(text("distance asc")) + + case "jaccard_distance": + stmt = select( # type: ignore + self._table_class.id, + self._table_class.node_id, + self._table_class.text, + self._table_class.metadata_, + self._table_class.embedding.jaccard_distance(embedding).label( + "distance" + ), + ).order_by(text("distance asc")) return self._apply_filters_and_limit(stmt, limit, metadata_filters)