diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index fae720a..61e569b 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -43,7 +43,7 @@ availableSecrets: env: "DB_PASSWORD" substitutions: - _INSTANCE_ID: mysql-vector + _INSTANCE_ID: test-instance _REGION: us-central1 _DB_NAME: test _VERSION: "3.8" diff --git a/src/langchain_google_cloud_sql_mysql/__init__.py b/src/langchain_google_cloud_sql_mysql/__init__.py index 29d2540..b59cf2d 100644 --- a/src/langchain_google_cloud_sql_mysql/__init__.py +++ b/src/langchain_google_cloud_sql_mysql/__init__.py @@ -14,16 +14,22 @@ from .chat_message_history import MySQLChatMessageHistory from .engine import Column, MySQLEngine +from .indexes import DistanceMeasure, IndexType, QueryOptions, SearchType, VectorIndex from .loader import MySQLDocumentSaver, MySQLLoader from .vectorstore import MySQLVectorStore from .version import __version__ __all__ = [ "Column", + "DistanceMeasure", + "IndexType", "MySQLChatMessageHistory", "MySQLDocumentSaver", "MySQLEngine", "MySQLLoader", "MySQLVectorStore", + "QueryOptions", + "SearchType", + "VectorIndex", "__version__", ] diff --git a/src/langchain_google_cloud_sql_mysql/engine.py b/src/langchain_google_cloud_sql_mysql/engine.py index a410748..9dd7375 100644 --- a/src/langchain_google_cloud_sql_mysql/engine.py +++ b/src/langchain_google_cloud_sql_mysql/engine.py @@ -222,11 +222,17 @@ def connect(self) -> sqlalchemy.engine.Connection: return self.engine.connect() def _execute(self, query: str, params: Optional[dict] = None) -> None: - """Execute a SQL query.""" + """Executes a SQL query within a transaction.""" with self.engine.connect() as conn: conn.execute(sqlalchemy.text(query), params) conn.commit() + def _execute_outside_tx(self, query: str, params: Optional[dict] = None) -> None: + """Executes a SQL query with autocommit (outside of transaction).""" + with self.engine.connect() as conn: + conn = 
conn.execution_options(isolation_level="AUTOCOMMIT") + conn.execute(sqlalchemy.text(query), params) + def _fetch(self, query: str, params: Optional[dict] = None): """Fetch results from a SQL query.""" with self.engine.connect() as conn: diff --git a/src/langchain_google_cloud_sql_mysql/indexes.py b/src/langchain_google_cloud_sql_mysql/indexes.py index d038abb..e7a6fdc 100644 --- a/src/langchain_google_cloud_sql_mysql/indexes.py +++ b/src/langchain_google_cloud_sql_mysql/indexes.py @@ -12,11 +12,91 @@ # See the License for the specific language governing permissions and # limitations under the License. -from abc import ABC from dataclasses import dataclass +from enum import Enum +from typing import Optional + + +class SearchType(Enum): + """Defines the types of search algorithms that can be used. + + Attributes: + KNN: K-Nearest Neighbors search. + ANN: Approximate Nearest Neighbors search. + """ + + KNN = "KNN" + ANN = "ANN" @dataclass -class QueryOptions(ABC): - def to_string(self) -> str: - raise NotImplementedError("to_string method must be implemented by subclass") +class QueryOptions: + """Holds configuration options for executing a search query. + + Attributes: + num_partitions (Optional[int]): The number of partitions to divide the search space into. None means default partitioning. + num_neighbors (Optional[int]): The number of nearest neighbors to retrieve. None means use the default. + search_type (SearchType): The type of search algorithm to use. Defaults to KNN. + """ + + num_partitions: Optional[int] = None + num_neighbors: Optional[int] = None + search_type: SearchType = SearchType.KNN + + +DEFAULT_QUERY_OPTIONS = QueryOptions() + + +class IndexType(Enum): + """Defines the types of indexes that can be used for vector storage. + + Attributes: + BRUTE_FORCE_SCAN: A simple brute force scan approach. + TREE_AH: A tree-based index that uses asymmetric hashing (AH). 
+ TREE_SQ: A tree-based index that uses scalar quantization (SQ). + """ + + BRUTE_FORCE_SCAN = "BRUTE_FORCE" + TREE_AH = "TREE_AH" + TREE_SQ = "TREE_SQ" + + +class DistanceMeasure(Enum): + """Enumerates the types of distance measures that can be used in searches. + + Attributes: + COSINE: Cosine similarity measure. + SQUARED_L2: Squared L2 norm (Euclidean) distance. + DOT_PRODUCT: Dot product similarity. + """ + + COSINE = "cosine" + SQUARED_L2 = "squared_l2" + DOT_PRODUCT = "dot_product" + + +class VectorIndex: + """Represents a vector index for storing and querying vectors. + + Attributes: + name (Optional[str]): The name of the index. + index_type (Optional[IndexType]): The type of index. + distance_measure (Optional[DistanceMeasure]): The distance measure to use for the index. + num_partitions (Optional[int]): The number of partitions for the index. None for default. + num_neighbors (Optional[int]): The default number of neighbors to return for queries. + """ + + def __init__( + self, + name: Optional[str] = None, + index_type: Optional[IndexType] = None, + distance_measure: Optional[DistanceMeasure] = None, + num_partitions: Optional[int] = None, + num_neighbors: Optional[int] = None, + ): + """Initializes a new instance of the VectorIndex class.""" + self.name = name + self.index_type = index_type + self.distance_measure = distance_measure + self.num_partitions = num_partitions + self.num_neighbors = num_neighbors diff --git a/src/langchain_google_cloud_sql_mysql/vectorstore.py b/src/langchain_google_cloud_sql_mysql/vectorstore.py index 31d9af5..c5f363e 100644 --- a/src/langchain_google_cloud_sql_mysql/vectorstore.py +++ b/src/langchain_google_cloud_sql_mysql/vectorstore.py @@ -23,7 +23,9 @@ from langchain_core.vectorstores import VectorStore from .engine import MySQLEngine -from .indexes import QueryOptions +from .indexes import DEFAULT_QUERY_OPTIONS, QueryOptions, SearchType, VectorIndex + +DEFAULT_INDEX_NAME_SUFFIX = "langchainvectorindex" 
class MySQLVectorStore(VectorStore): @@ -38,7 +40,7 @@ def __init__( ignore_metadata_columns: Optional[List[str]] = None, id_column: str = "langchain_id", metadata_json_column: Optional[str] = "langchain_metadata", - query_options: Optional[QueryOptions] = None, + query_options: QueryOptions = DEFAULT_QUERY_OPTIONS, ): """Constructor for MySQLVectorStore. Args: @@ -118,11 +120,16 @@ def __init__( self.id_column = id_column self.metadata_json_column = metadata_json_column self.query_options = query_options + self.db_name = self.__get_db_name() @property def embeddings(self) -> Embeddings: return self.embedding_service + def __get_db_name(self) -> str: + result = self.engine._fetch("SELECT DATABASE();") + return result[0]["DATABASE()"] + def _add_embeddings( self, texts: Iterable[str], @@ -210,6 +217,64 @@ def delete( self.engine._execute(query) return True + def apply_vector_index(self, vector_index: VectorIndex): + # Construct the default index name + if not vector_index.name: + vector_index.name = f"{self.table_name}_{DEFAULT_INDEX_NAME_SUFFIX}" + query_template = f"CALL mysql.create_vector_index('{vector_index.name}', '{self.db_name}.{self.table_name}', '{self.embedding_column}', '{{}}');" + self.__exec_apply_vector_index(query_template, vector_index) + # After applying an index to the table, set the query option search type to be ANN + self.query_options.search_type = SearchType.ANN + + def alter_vector_index(self, vector_index: VectorIndex): + existing_index_name = self._get_vector_index_name() + if not existing_index_name: + raise ValueError("No existing vector index found.") + if not vector_index.name: + vector_index.name = existing_index_name.split(".")[1] + if existing_index_name.split(".")[1] != vector_index.name: + raise ValueError( + f"Existing index name {existing_index_name} does not match the new index name {vector_index.name}." 
+ ) + query_template = ( + f"CALL mysql.alter_vector_index('{existing_index_name}', '{{}}');" + ) + self.__exec_apply_vector_index(query_template, vector_index) + + def __exec_apply_vector_index(self, query_template: str, vector_index: VectorIndex): + index_options = [] + if vector_index.index_type: + index_options.append(f"index_type={vector_index.index_type.value}") + if vector_index.distance_measure: + index_options.append( + f"distance_measure={vector_index.distance_measure.value}" + ) + if vector_index.num_partitions: + index_options.append(f"num_partitions={vector_index.num_partitions}") + if vector_index.num_neighbors: + index_options.append(f"num_neighbors={vector_index.num_neighbors}") + index_options_query = ",".join(index_options) + + stmt = query_template.format(index_options_query) + self.engine._execute_outside_tx(stmt) + + def _get_vector_index_name(self): + query = f"SELECT index_name FROM mysql.vector_indexes WHERE table_name='{self.db_name}.{self.table_name}';" + result = self.engine._fetch(query) + if result: + return result[0]["index_name"] + else: + return None + + def drop_vector_index(self): + existing_index_name = self._get_vector_index_name() + if existing_index_name: + self.engine._execute_outside_tx( + f"CALL mysql.drop_vector_index('{existing_index_name}');" + ) + self.query_options.search_type = SearchType.KNN + return existing_index_name + @classmethod def from_texts( # type: ignore[override] cls: Type[MySQLVectorStore], @@ -225,6 +290,7 @@ def from_texts( # type: ignore[override] ignore_metadata_columns: Optional[List[str]] = None, id_column: str = "langchain_id", metadata_json_column: str = "langchain_metadata", + query_options: QueryOptions = DEFAULT_QUERY_OPTIONS, **kwargs: Any, ): vs = cls( @@ -237,6 +303,7 @@ def from_texts( # type: ignore[override] ignore_metadata_columns=ignore_metadata_columns, id_column=id_column, metadata_json_column=metadata_json_column, + query_options=query_options, ) vs.add_texts(texts, 
metadatas=metadatas, ids=ids, **kwargs) return vs @@ -255,6 +322,7 @@ def from_documents( # type: ignore[override] ignore_metadata_columns: Optional[List[str]] = None, id_column: str = "langchain_id", metadata_json_column: str = "langchain_metadata", + query_options: QueryOptions = DEFAULT_QUERY_OPTIONS, **kwargs: Any, ) -> MySQLVectorStore: vs = cls( @@ -267,6 +335,7 @@ def from_documents( # type: ignore[override] ignore_metadata_columns=ignore_metadata_columns, id_column=id_column, metadata_json_column=metadata_json_column, + query_options=query_options, ) texts = [doc.page_content for doc in documents] metadatas = [doc.metadata for doc in documents] diff --git a/tests/integration/test_mysql_vectorstore_index.py b/tests/integration/test_mysql_vectorstore_index.py new file mode 100644 index 0000000..f2eb302 --- /dev/null +++ b/tests/integration/test_mysql_vectorstore_index.py @@ -0,0 +1,180 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import uuid + +import pytest +from langchain_community.embeddings import DeterministicFakeEmbedding +from langchain_core.documents import Document + +from langchain_google_cloud_sql_mysql import ( + DistanceMeasure, + IndexType, + MySQLEngine, + MySQLVectorStore, + SearchType, + VectorIndex, +) + +DEFAULT_TABLE = "test_table_" + str(uuid.uuid4()).split("-")[0] +TABLE_1000_ROWS = "test_table_1000_rows" +VECTOR_SIZE = 8 + +embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE) + + +def get_env_var(key: str, desc: str) -> str: + v = os.environ.get(key) + if v is None: + raise ValueError(f"Must set env var {key} to: {desc}") + return v + + +class TestVectorStoreFromMethods: + @pytest.fixture(scope="module") + def db_project(self) -> str: + return get_env_var("PROJECT_ID", "project id for google cloud") + + @pytest.fixture(scope="module") + def db_region(self) -> str: + return get_env_var("REGION", "region for cloud sql instance") + + @pytest.fixture(scope="module") + def db_instance(self) -> str: + return get_env_var("INSTANCE_ID", "instance for cloud sql") + + @pytest.fixture(scope="module") + def db_name(self) -> str: + return get_env_var("DB_NAME", "database name on cloud sql instance") + + @pytest.fixture(scope="class") + def engine(self, db_project, db_region, db_instance, db_name): + engine = MySQLEngine.from_instance( + project_id=db_project, + instance=db_instance, + region=db_region, + database=db_name, + ) + yield engine + + @pytest.fixture(scope="class") + def vs(self, engine): + engine.init_vectorstore_table( + DEFAULT_TABLE, + VECTOR_SIZE, + overwrite_existing=True, + ) + vs = MySQLVectorStore( + engine, + embedding_service=embeddings_service, + table_name=DEFAULT_TABLE, + ) + yield vs + vs.drop_vector_index() + engine._execute(f"DROP TABLE IF EXISTS `{DEFAULT_TABLE}`") + + @pytest.fixture(scope="class") + def vs_1000(self, engine): + result = engine._fetch("SHOW TABLES") + tables = [list(r.values())[0] for r in result] + if 
TABLE_1000_ROWS not in tables: + engine.init_vectorstore_table( + TABLE_1000_ROWS, + VECTOR_SIZE, + ) + vs_1000 = MySQLVectorStore( + engine, + embedding_service=embeddings_service, + table_name=TABLE_1000_ROWS, + ) + row_count = vs_1000.engine._fetch(f"SELECT count(*) FROM `{TABLE_1000_ROWS}`")[ + 0 + ]["count(*)"] + # Add 1000 rows of data if the number of rows is less than 1000 + if row_count < 1000: + texts_1000 = [ + f"{text}_{i}" + for text in ["apple", "dog", "basketball", "coffee"] + for i in range(1, 251) + ] + ids = [str(uuid.uuid4()) for _ in range(len(texts_1000))] + vs_1000.add_texts(texts_1000, ids=ids) + vs_1000.drop_vector_index() + yield vs_1000 + vs_1000.drop_vector_index() + + def test_create_and_drop_index(self, vs): + vs.apply_vector_index(VectorIndex()) + assert ( + vs._get_vector_index_name() + == f"{vs.db_name}.{vs.table_name}_langchainvectorindex" + ) + assert vs.query_options.search_type == SearchType.ANN + vs.drop_vector_index() + assert vs._get_vector_index_name() is None + assert vs.query_options.search_type == SearchType.KNN + + def test_update_index(self, vs): + vs.apply_vector_index(VectorIndex()) + assert ( + vs._get_vector_index_name() + == f"{vs.db_name}.{vs.table_name}_langchainvectorindex" + ) + assert vs.query_options.search_type == SearchType.ANN + vs.alter_vector_index( + VectorIndex( + index_type=IndexType.BRUTE_FORCE_SCAN, + distance_measure=DistanceMeasure.SQUARED_L2, + num_neighbors=10, + ) + ) + assert ( + vs._get_vector_index_name() + == f"{vs.db_name}.{vs.table_name}_langchainvectorindex" + ) + vs.drop_vector_index() + assert vs.query_options.search_type == SearchType.KNN + + def test_create_and_drop_index_tree_sq(self, vs_1000): + vs_1000.apply_vector_index( + VectorIndex( + name="tree_sq", + index_type=IndexType.TREE_SQ, + distance_measure=DistanceMeasure.SQUARED_L2, + num_partitions=1, + num_neighbors=5, + ) + ) + assert vs_1000._get_vector_index_name() == f"{vs_1000.db_name}.tree_sq" + assert 
vs_1000.query_options.search_type == SearchType.ANN + vs_1000.drop_vector_index() + assert vs_1000._get_vector_index_name() is None + assert vs_1000.query_options.search_type == SearchType.KNN + + def test_create_and_drop_index_tree_ah(self, vs_1000): + vs_1000.apply_vector_index( + VectorIndex( + name="tree_ah", + index_type=IndexType.TREE_AH, + distance_measure=DistanceMeasure.COSINE, + num_partitions=2, + num_neighbors=10, + ) + ) + assert vs_1000._get_vector_index_name() == f"{vs_1000.db_name}.tree_ah" + assert vs_1000.query_options.search_type == SearchType.ANN + vs_1000.drop_vector_index() + assert vs_1000._get_vector_index_name() is None + assert vs_1000.query_options.search_type == SearchType.KNN