Skip to content

Commit

Permalink
feat: add index types for vector search (#55)
Browse files Browse the repository at this point in the history
* feat: adding index operations and tests
  • Loading branch information
totoleon authored Mar 28, 2024
1 parent 9cc52c1 commit 2e30b48
Show file tree
Hide file tree
Showing 6 changed files with 349 additions and 8 deletions.
2 changes: 1 addition & 1 deletion integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ availableSecrets:
env: "DB_PASSWORD"

substitutions:
_INSTANCE_ID: mysql-vector
_INSTANCE_ID: test-instance
_REGION: us-central1
_DB_NAME: test
_VERSION: "3.8"
Expand Down
6 changes: 6 additions & 0 deletions src/langchain_google_cloud_sql_mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,22 @@

from .chat_message_history import MySQLChatMessageHistory
from .engine import Column, MySQLEngine
from .indexes import DistanceMeasure, IndexType, QueryOptions, SearchType, VectorIndex
from .loader import MySQLDocumentSaver, MySQLLoader
from .vectorstore import MySQLVectorStore
from .version import __version__

__all__ = [
"Column",
"DistanceMeasure",
"IndexType",
"MySQLChatMessageHistory",
"MySQLDocumentSaver",
"MySQLEngine",
"MySQLLoader",
"MySQLVectorStore",
"QueryOptions",
"SearchType",
"VectorIndex",
"__version__",
]
8 changes: 7 additions & 1 deletion src/langchain_google_cloud_sql_mysql/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,17 @@ def connect(self) -> sqlalchemy.engine.Connection:
return self.engine.connect()

def _execute(self, query: str, params: Optional[dict] = None) -> None:
"""Execute a SQL query."""
"""Executes a SQL query within a transaction."""
with self.engine.connect() as conn:
conn.execute(sqlalchemy.text(query), params)
conn.commit()

def _execute_outside_tx(self, query: str, params: Optional[dict] = None) -> None:
"""Executes a SQL query with autocommit (outside of transaction)."""
with self.engine.connect() as conn:
conn = conn.execution_options(isolation_level="AUTOCOMMIT")
conn.execute(sqlalchemy.text(query), params)

def _fetch(self, query: str, params: Optional[dict] = None):
"""Fetch results from a SQL query."""
with self.engine.connect() as conn:
Expand Down
88 changes: 84 additions & 4 deletions src/langchain_google_cloud_sql_mysql/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,91 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC
from dataclasses import dataclass
from enum import Enum
from typing import Optional


class SearchType(Enum):
"""Defines the types of search algorithms that can be used.
Attributes:
KNN: K-Nearest Neighbors search.
ANN: Approximate Nearest Neighbors search.
"""

KNN = "KNN"
ANN = "ANN"


@dataclass
class QueryOptions(ABC):
def to_string(self) -> str:
raise NotImplementedError("to_string method must be implemented by subclass")
class QueryOptions:
"""Holds configuration options for executing a search query.
Attributes:
num_partitions (Optional[int]): The number of partitions to divide the search space into. None means default partitioning.
num_neighbors (Optional[int]): The number of nearest neighbors to retrieve. None means use the default.
search_type (SearchType): The type of search algorithm to use. Defaults to KNN.
"""

num_partitions: Optional[int] = None
num_neighbors: Optional[int] = None
search_type: SearchType = SearchType.KNN


DEFAULT_QUERY_OPTIONS = QueryOptions()


class IndexType(Enum):
"""Defines the types of indexes that can be used for vector storage.
Attributes:
BRUTE_FORCE_SCAN: A simple brute force scan approach.
TREE_AH: A tree-based index, specifically Annoy (Approximate Nearest Neighbors Oh Yeah).
TREE_SQ: A tree-based index, specifically ScaNN (Scalable Nearest Neighbors).
"""

BRUTE_FORCE_SCAN = "BRUTE_FORCE"
TREE_AH = "TREE_AH"
TREE_SQ = "TREE_SQ"


class DistanceMeasure(Enum):
"""Enumerates the types of distance measures that can be used in searches.
Attributes:
COSINE: Cosine similarity measure.
SQUARED_L2: Squared L2 norm (Euclidean) distance.
DOT_PRODUCT: Dot product similarity.
"""

COSINE = "cosine"
SQUARED_L2 = "squared_l2"
DOT_PRODUCT = "dot_product"


class VectorIndex:
"""Represents a vector index for storing and querying vectors.
Attributes:
name (Optional[str]): The name of the index.
index_type (Optional[IndexType]): The type of index.
distance_measure (Optional[DistanceMeasure]): The distance measure to use for the index.
num_partitions (Optional[int]): The number of partitions for the index. None for default.
num_neighbors (Optional[int]): The default number of neighbors to return for queries.
"""

def __init__(
self,
name: Optional[str] = None,
index_type: Optional[IndexType] = None,
distance_measure: Optional[DistanceMeasure] = None,
num_partitions: Optional[int] = None,
num_neighbors: Optional[int] = None,
):
"""Initializes a new instance of the VectorIndex class."""
self.name = name
self.index_type = index_type
self.distance_measure = distance_measure
self.num_partitions = num_partitions
self.num_neighbors = num_neighbors
73 changes: 71 additions & 2 deletions src/langchain_google_cloud_sql_mysql/vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
from langchain_core.vectorstores import VectorStore

from .engine import MySQLEngine
from .indexes import QueryOptions
from .indexes import DEFAULT_QUERY_OPTIONS, QueryOptions, SearchType, VectorIndex

DEFAULT_INDEX_NAME_SUFFIX = "langchainvectorindex"


class MySQLVectorStore(VectorStore):
Expand All @@ -38,7 +40,7 @@ def __init__(
ignore_metadata_columns: Optional[List[str]] = None,
id_column: str = "langchain_id",
metadata_json_column: Optional[str] = "langchain_metadata",
query_options: Optional[QueryOptions] = None,
query_options: QueryOptions = DEFAULT_QUERY_OPTIONS,
):
"""Constructor for MySQLVectorStore.
Args:
Expand Down Expand Up @@ -118,11 +120,16 @@ def __init__(
self.id_column = id_column
self.metadata_json_column = metadata_json_column
self.query_options = query_options
self.db_name = self.__get_db_name()

@property
def embeddings(self) -> Embeddings:
return self.embedding_service

def __get_db_name(self) -> str:
result = self.engine._fetch("SELECT DATABASE();")
return result[0]["DATABASE()"]

def _add_embeddings(
self,
texts: Iterable[str],
Expand Down Expand Up @@ -210,6 +217,64 @@ def delete(
self.engine._execute(query)
return True

def apply_vector_index(self, vector_index: VectorIndex):
# Construct the default index name
if not vector_index.name:
vector_index.name = f"{self.table_name}_{DEFAULT_INDEX_NAME_SUFFIX}"
query_template = f"CALL mysql.create_vector_index('{vector_index.name}', '{self.db_name}.{self.table_name}', '{self.embedding_column}', '{{}}');"
self.__exec_apply_vector_index(query_template, vector_index)
# After applying an index to the table, set the query option search type to be ANN
self.query_options.search_type = SearchType.ANN

def alter_vector_index(self, vector_index: VectorIndex):
existing_index_name = self._get_vector_index_name()
if not existing_index_name:
raise ValueError("No existing vector index found.")
if not vector_index.name:
vector_index.name = existing_index_name.split(".")[1]
if existing_index_name.split(".")[1] != vector_index.name:
raise ValueError(
f"Existing index name {existing_index_name} does not match the new index name {vector_index.name}."
)
query_template = (
f"CALL mysql.alter_vector_index('{existing_index_name}', '{{}}');"
)
self.__exec_apply_vector_index(query_template, vector_index)

def __exec_apply_vector_index(self, query_template: str, vector_index: VectorIndex):
index_options = []
if vector_index.index_type:
index_options.append(f"index_type={vector_index.index_type.value}")
if vector_index.distance_measure:
index_options.append(
f"distance_measure={vector_index.distance_measure.value}"
)
if vector_index.num_partitions:
index_options.append(f"num_partitions={vector_index.num_partitions}")
if vector_index.num_neighbors:
index_options.append(f"num_neighbors={vector_index.num_neighbors}")
index_options_query = ",".join(index_options)

stmt = query_template.format(index_options_query)
self.engine._execute_outside_tx(stmt)

def _get_vector_index_name(self):
query = f"SELECT index_name FROM mysql.vector_indexes WHERE table_name='{self.db_name}.{self.table_name}';"
result = self.engine._fetch(query)
if result:
return result[0]["index_name"]
else:
return None

def drop_vector_index(self):
existing_index_name = self._get_vector_index_name()
if existing_index_name:
self.engine._execute_outside_tx(
f"CALL mysql.drop_vector_index('{existing_index_name}');"
)
self.query_options.search_type = SearchType.KNN
return existing_index_name

@classmethod
def from_texts( # type: ignore[override]
cls: Type[MySQLVectorStore],
Expand All @@ -225,6 +290,7 @@ def from_texts( # type: ignore[override]
ignore_metadata_columns: Optional[List[str]] = None,
id_column: str = "langchain_id",
metadata_json_column: str = "langchain_metadata",
query_options: QueryOptions = DEFAULT_QUERY_OPTIONS,
**kwargs: Any,
):
vs = cls(
Expand All @@ -237,6 +303,7 @@ def from_texts( # type: ignore[override]
ignore_metadata_columns=ignore_metadata_columns,
id_column=id_column,
metadata_json_column=metadata_json_column,
query_options=query_options,
)
vs.add_texts(texts, metadatas=metadatas, ids=ids, **kwargs)
return vs
Expand All @@ -255,6 +322,7 @@ def from_documents( # type: ignore[override]
ignore_metadata_columns: Optional[List[str]] = None,
id_column: str = "langchain_id",
metadata_json_column: str = "langchain_metadata",
query_options: QueryOptions = DEFAULT_QUERY_OPTIONS,
**kwargs: Any,
) -> MySQLVectorStore:
vs = cls(
Expand All @@ -267,6 +335,7 @@ def from_documents( # type: ignore[override]
ignore_metadata_columns=ignore_metadata_columns,
id_column=id_column,
metadata_json_column=metadata_json_column,
query_options=query_options,
)
texts = [doc.page_content for doc in documents]
metadatas = [doc.metadata for doc in documents]
Expand Down
Loading

0 comments on commit 2e30b48

Please sign in to comment.