From 9ff383956710995ebb3266f56b3e319efa854ade Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Wed, 24 Jan 2024 20:15:41 -0500 Subject: [PATCH 01/12] feat: add CloudSQLMySQLEngine class (#7) --- .github/workflows/lint.yml | 2 +- mypy.ini | 5 + .../__init__.py | 4 + .../cloud_sql_mysql_engine.py | 158 ++++++++++++++++++ 4 files changed, 168 insertions(+), 1 deletion(-) create mode 100644 mypy.ini create mode 100644 src/langchain_google_cloud_sql_mysql/cloud_sql_mysql_engine.py diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 521319f..e903f9f 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -51,4 +51,4 @@ jobs: isort --check . - name: Run type-check - run: mypy . + run: mypy --install-types --non-interactive src/ diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..7752070 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,5 @@ +[mypy] +python_version = 3.11 + +[mypy-google.auth.*] +ignore_missing_imports = True diff --git a/src/langchain_google_cloud_sql_mysql/__init__.py b/src/langchain_google_cloud_sql_mysql/__init__.py index 6d5e14b..695dc1b 100644 --- a/src/langchain_google_cloud_sql_mysql/__init__.py +++ b/src/langchain_google_cloud_sql_mysql/__init__.py @@ -11,3 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ +from langchain_google_cloud_sql_mysql.cloud_sql_mysql_engine import CloudSQLMySQLEngine + +__all__ = ["CloudSQLMySQLEngine"] diff --git a/src/langchain_google_cloud_sql_mysql/cloud_sql_mysql_engine.py b/src/langchain_google_cloud_sql_mysql/cloud_sql_mysql_engine.py new file mode 100644 index 0000000..2cd5d08 --- /dev/null +++ b/src/langchain_google_cloud_sql_mysql/cloud_sql_mysql_engine.py @@ -0,0 +1,158 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional + +import google.auth +import google.auth.transport.requests +import requests +import sqlalchemy +from google.cloud.sql.connector import Connector + +if TYPE_CHECKING: + import google.auth.credentials + import pymysql + + +def _get_iam_principal_email( + credentials: google.auth.credentials.Credentials, +) -> str: + """Get email address associated with current authenticated IAM principal. + + Email will be used for automatic IAM database authentication to Cloud SQL. + + Args: + credentials (google.auth.credentials.Credentials): + The credentials object to use in finding the associated IAM + principal email address. + + Returns: + email (str): + The email address associated with the current authenticated IAM + principal. 
+ """ + # if credentials are associated with a service account email, return early + if hasattr(credentials, "_service_account_email"): + return credentials._service_account_email + # refresh credentials if they are not valid + if not credentials.valid: + request = google.auth.transport.requests.Request() + credentials.refresh(request) + # call OAuth2 api to get IAM principal email associated with OAuth2 token + url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" + response = requests.get(url) + response_json: Dict = response.json() + email = response_json.get("email") + if email is None: + raise ValueError( + "Failed to automatically obtain authenticated IAM princpal's " + "email address using environment's ADC credentials!" + ) + return email + + +class CloudSQLMySQLEngine: + """A class for managing connections to a Cloud SQL for MySQL database.""" + + def __init__( + self, + project_id: Optional[str] = None, + region: Optional[str] = None, + instance: Optional[str] = None, + database: Optional[str] = None, + ) -> None: + self._project_id = project_id + self._region = region + self._instance = instance + self._database = database + self.engine = self._create_connector_engine() + + def close(self) -> None: + """Utility method for closing the Cloud SQL Python Connector + background tasks. + """ + if hasattr(self, "_connector"): + self._connector.close() + + @classmethod + def from_instance( + cls, + project_id: str, + region: str, + instance: str, + database: str, + ) -> CloudSQLMySQLEngine: + """Create an instance of CloudSQLMySQLEngine from Cloud SQL instance + details. + + This method uses the Cloud SQL Python Connector to connect to Cloud SQL + using automatic IAM database authentication with the Google ADC + credentials sourced from the environment. 
+ + More details can be found at https://github.com/GoogleCloudPlatform/cloud-sql-python-connector#credentials + + Args: + project_id (str): Project ID of the Google Cloud Project where + the Cloud SQL instance is located. + region (str): Region where the Cloud SQL instance is located. + instance (str): The name of the Cloud SQL instance. + database (str): The name of the database to connect to on the + Cloud SQL instance. + + Returns: + (CloudSQLMySQLEngine): The engine configured to connect to a + Cloud SQL instance database. + """ + return cls( + project_id=project_id, + region=region, + instance=instance, + database=database, + ) + + def _create_connector_engine(self) -> sqlalchemy.engine.Engine: + """Create a SQLAlchemy engine using the Cloud SQL Python Connector. + + Defaults to use "pymysql" driver and to connect using automatic IAM + database authentication with the IAM principal associated with the + environment's Google Application Default Credentials. + + Returns: + (sqlalchemy.engine.Engine): Engine configured using the Cloud SQL + Python Connector. + """ + # get application default credentials + credentials, _ = google.auth.default( + scopes=["https://www.googleapis.com/auth/userinfo.email"] + ) + iam_database_user = _get_iam_principal_email(credentials) + self._connector = Connector() + + # anonymous function to be used for SQLAlchemy 'creator' argument + def getconn() -> pymysql.Connection: + conn = self._connector.connect( + f"{self._project_id}:{self._region}:{self._instance}", + "pymysql", + user=iam_database_user, + db=self._database, + enable_iam_auth=True, + ) + return conn + + return sqlalchemy.create_engine( + "mysql+pymysql://", + creator=getconn, + ) From eb2cba123042d7895e68445b1e30507252b8a631 Mon Sep 17 00:00:00 2001 From: Liang <154559835+loeng2023@users.noreply.github.com> Date: Thu, 25 Jan 2024 14:51:17 -0800 Subject: [PATCH 02/12] feat: support load document by query (#8) * load documents by query. 
* make integration test works * rebase to include CloudSQLMySQLEngine * fix in test * minor change * Update class naming * refine typing * fix lint --- integration.cloudbuild.yaml | 11 ++ .../__init__.py | 5 +- .../doc_loader.py | 88 +++++++++++++ ...ud_sql_mysql_engine.py => mysql_engine.py} | 20 +-- tests/integration_tests/test_doc_loader.py | 118 ++++++++++++++++++ 5 files changed, 233 insertions(+), 9 deletions(-) create mode 100644 src/langchain_google_cloud_sql_mysql/doc_loader.py rename src/langchain_google_cloud_sql_mysql/{cloud_sql_mysql_engine.py => mysql_engine.py} (91%) create mode 100644 tests/integration_tests/test_doc_loader.py diff --git a/integration.cloudbuild.yaml b/integration.cloudbuild.yaml index c78148e..3808ade 100644 --- a/integration.cloudbuild.yaml +++ b/integration.cloudbuild.yaml @@ -22,3 +22,14 @@ steps: name: python:3.11 entrypoint: python args: ["-m", "pytest"] + env: + - 'PROJECT_ID=$PROJECT_ID' + - 'INSTANCE_ID=$_INSTANCE_ID' + - 'DB_NAME=$_DB_NAME' + - 'TABLE_NAME=test-$BUILD_ID' + - 'REGION=$_REGION' + +substitutions: + _INSTANCE_ID: test-instance + _REGION: us-central1 + _DB_NAME: test diff --git a/src/langchain_google_cloud_sql_mysql/__init__.py b/src/langchain_google_cloud_sql_mysql/__init__.py index 695dc1b..12bf1cf 100644 --- a/src/langchain_google_cloud_sql_mysql/__init__.py +++ b/src/langchain_google_cloud_sql_mysql/__init__.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from langchain_google_cloud_sql_mysql.cloud_sql_mysql_engine import CloudSQLMySQLEngine +from langchain_google_cloud_sql_mysql.doc_loader import MySQLLoader +from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine -__all__ = ["CloudSQLMySQLEngine"] +__all__ = ["MySQLEngine", "MySQLLoader"] diff --git a/src/langchain_google_cloud_sql_mysql/doc_loader.py b/src/langchain_google_cloud_sql_mysql/doc_loader.py new file mode 100644 index 0000000..6cd3e76 --- /dev/null +++ b/src/langchain_google_cloud_sql_mysql/doc_loader.py @@ -0,0 +1,88 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections.abc import Iterable +from typing import Any, Dict, List, Optional, Sequence, cast + +import sqlalchemy +from langchain_community.document_loaders.base import BaseLoader +from langchain_core.documents import Document + +from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine + + +def _parse_doc_from_table( + page_content_columns: Iterable[str], + metadata_columns: Iterable[str], + column_names: Iterable[str], + rows: Sequence[Any], +) -> List[Document]: + docs = [] + for row in rows: + page_content = "\n".join( + f"{column}: {getattr(row, column)}" + for column in page_content_columns + if column in column_names + ) + metadata = { + column: getattr(row, column) + for column in metadata_columns + if column in column_names + } + doc = Document(page_content=page_content, metadata=metadata) + docs.append(doc) + return docs + + +class MySQLLoader(BaseLoader): + """A class for loading langchain documents from a Cloud SQL MySQL database.""" + + def __init__( + self, + engine: MySQLEngine, + query: str, + page_content_columns: Optional[List[str]] = None, + metadata_columns: Optional[List[str]] = None, + ): + """ + Args: + engine (MySQLEngine): MySQLEngine object to connect to the MySQL database. + query (str): The query to execute in MySQL format. + page_content_columns (List[str]): The columns to write into the `page_content` + of the document. Optional. + metadata_columns (List[str]): The columns to write into the `metadata` of the document. + Optional. + """ + self.engine = engine + self.query = query + self.page_content_columns = page_content_columns + self.metadata_columns = metadata_columns + + def load(self) -> List[Document]: + """ + Load langchain documents from a Cloud SQL MySQL database. + + Returns: + (List[langchain_core.documents.Document]): a list of Documents with metadata from + specific columns. 
+ """ + with self.engine.connect() as connection: + result_proxy = connection.execute(sqlalchemy.text(self.query)) + column_names = result_proxy.keys() + results = result_proxy.fetchall() + return _parse_doc_from_table( + self.page_content_columns or column_names, + self.metadata_columns or [], + column_names, + results, + ) diff --git a/src/langchain_google_cloud_sql_mysql/cloud_sql_mysql_engine.py b/src/langchain_google_cloud_sql_mysql/mysql_engine.py similarity index 91% rename from src/langchain_google_cloud_sql_mysql/cloud_sql_mysql_engine.py rename to src/langchain_google_cloud_sql_mysql/mysql_engine.py index 2cd5d08..61dff9e 100644 --- a/src/langchain_google_cloud_sql_mysql/cloud_sql_mysql_engine.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_engine.py @@ -44,9 +44,6 @@ def _get_iam_principal_email( The email address associated with the current authenticated IAM principal. """ - # if credentials are associated with a service account email, return early - if hasattr(credentials, "_service_account_email"): - return credentials._service_account_email # refresh credentials if they are not valid if not credentials.valid: request = google.auth.transport.requests.Request() @@ -64,7 +61,7 @@ def _get_iam_principal_email( return email -class CloudSQLMySQLEngine: +class MySQLEngine: """A class for managing connections to a Cloud SQL for MySQL database.""" def __init__( @@ -94,8 +91,8 @@ def from_instance( region: str, instance: str, database: str, - ) -> CloudSQLMySQLEngine: - """Create an instance of CloudSQLMySQLEngine from Cloud SQL instance + ) -> MySQLEngine: + """Create an instance of MySQLEngine from Cloud SQL instance details. This method uses the Cloud SQL Python Connector to connect to Cloud SQL @@ -113,7 +110,7 @@ def from_instance( Cloud SQL instance. Returns: - (CloudSQLMySQLEngine): The engine configured to connect to a + (MySQLEngine): The engine configured to connect to a Cloud SQL instance database. 
""" return cls( @@ -156,3 +153,12 @@ def getconn() -> pymysql.Connection: "mysql+pymysql://", creator=getconn, ) + + def connect(self) -> sqlalchemy.engine.Connection: + """Create a connection from SQLAlchemy connection pool. + + Returns: + (sqlalchemy.engine.Connection): a single DBAPI connection checked + out from the connection pool. + """ + return self.engine.connect() diff --git a/tests/integration_tests/test_doc_loader.py b/tests/integration_tests/test_doc_loader.py new file mode 100644 index 0000000..f11fadf --- /dev/null +++ b/tests/integration_tests/test_doc_loader.py @@ -0,0 +1,118 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import os +from typing import Generator + +import pymysql +import pytest +import sqlalchemy +from google.cloud.sql.connector import Connector +from langchain_core.documents import Document + +from langchain_google_cloud_sql_mysql import MySQLEngine, MySQLLoader + +project_id = os.environ.get("PROJECT_ID", None) +region = os.environ.get("REGION") +instance_id = os.environ.get("INSTANCE_ID") +table_name = os.environ.get("TABLE_NAME") +db_name = os.environ.get("DB_NAME") + +test_docs = [ + Document( + page_content="fruit_name: Apple\nvariety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 0.99\norganic: 1", + metadata={"fruit_id": 1}, + ), + Document( + page_content="fruit_name: Banana\nvariety: Cavendish\nquantity_in_stock: 200\nprice_per_unit: 0.59\norganic: 0", + metadata={"fruit_id": 2}, + ), + Document( + page_content="fruit_name: Orange\nvariety: Navel\nquantity_in_stock: 80\nprice_per_unit: 1.29\norganic: 1", + metadata={"fruit_id": 3}, + ), + Document( + page_content="fruit_name: Strawberry\nvariety: Camarosa\nquantity_in_stock: 35\nprice_per_unit: 2.49\norganic: 1", + metadata={"fruit_id": 4}, + ), + Document( + page_content="fruit_name: Grape\nvariety: Thompson Seedless\nquantity_in_stock: 120\nprice_per_unit: 1.99\norganic: 0", + metadata={"fruit_id": 5}, + ), +] + + +@pytest.fixture(name="engine") +def setup() -> Generator: + engine = MySQLEngine.from_instance( + project_id=project_id, region=region, instance=instance_id, database=db_name + ) + + with engine.connect() as conn: + conn.execute( + sqlalchemy.text( + f""" + CREATE TABLE IF NOT EXISTS `{table_name}`( + fruit_id INT AUTO_INCREMENT PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit DECIMAL(6,2) NOT NULL, + organic TINYINT(1) NOT NULL + ) + """ + ) + ) + conn.commit() + + yield engine + + with engine.connect() as conn: + conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`")) + conn.commit() + 
engine.close() + + +def test_load_from_query(engine): + with engine.connect() as conn: + conn.execute( + sqlalchemy.text( + f""" + INSERT INTO `{table_name}` (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES + ('Apple', 'Granny Smith', 150, 0.99, 1), + ('Banana', 'Cavendish', 200, 0.59, 0), + ('Orange', 'Navel', 80, 1.29, 1), + ('Strawberry', 'Camarosa', 35, 2.49, 1), + ('Grape', 'Thompson Seedless', 120, 1.99, 0); + """ + ) + ) + conn.commit() + query = f"SELECT * FROM `{table_name}`;" + loader = MySQLLoader( + engine=engine, + query=query, + page_content_columns=[ + "fruit_name", + "variety", + "quantity_in_stock", + "price_per_unit", + "organic", + ], + metadata_columns=["fruit_id"], + ) + + documents = loader.load() + + assert documents == test_docs From 53bfa4f56fba2273ae61801cf0b47f0b5b21ea19 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 26 Jan 2024 14:41:03 +0000 Subject: [PATCH 03/12] chore: update file and folder naming --- src/langchain_google_cloud_sql_mysql/__init__.py | 2 +- .../{doc_loader.py => mysql_loader.py} | 0 .../test_doc_loader.py => integration/test_mysql_loader.py} | 0 3 files changed, 1 insertion(+), 1 deletion(-) rename src/langchain_google_cloud_sql_mysql/{doc_loader.py => mysql_loader.py} (100%) rename tests/{integration_tests/test_doc_loader.py => integration/test_mysql_loader.py} (100%) diff --git a/src/langchain_google_cloud_sql_mysql/__init__.py b/src/langchain_google_cloud_sql_mysql/__init__.py index 12bf1cf..de30f8e 100644 --- a/src/langchain_google_cloud_sql_mysql/__init__.py +++ b/src/langchain_google_cloud_sql_mysql/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from langchain_google_cloud_sql_mysql.doc_loader import MySQLLoader +from langchain_google_cloud_sql_mysql.mysql_loader import MySQLLoader from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine __all__ = ["MySQLEngine", "MySQLLoader"] diff --git a/src/langchain_google_cloud_sql_mysql/doc_loader.py b/src/langchain_google_cloud_sql_mysql/mysql_loader.py similarity index 100% rename from src/langchain_google_cloud_sql_mysql/doc_loader.py rename to src/langchain_google_cloud_sql_mysql/mysql_loader.py diff --git a/tests/integration_tests/test_doc_loader.py b/tests/integration/test_mysql_loader.py similarity index 100% rename from tests/integration_tests/test_doc_loader.py rename to tests/integration/test_mysql_loader.py From bc10773ac8509a7b1ded49a89ba13277e800bb55 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 26 Jan 2024 14:43:43 +0000 Subject: [PATCH 04/12] chore: lint --- src/langchain_google_cloud_sql_mysql/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/langchain_google_cloud_sql_mysql/__init__.py b/src/langchain_google_cloud_sql_mysql/__init__.py index de30f8e..8114e32 100644 --- a/src/langchain_google_cloud_sql_mysql/__init__.py +++ b/src/langchain_google_cloud_sql_mysql/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from langchain_google_cloud_sql_mysql.mysql_loader import MySQLLoader from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine +from langchain_google_cloud_sql_mysql.mysql_loader import MySQLLoader __all__ = ["MySQLEngine", "MySQLLoader"] From 3bf75062ef6783f53024e9a71ec2222ee9292481 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 26 Jan 2024 14:48:56 +0000 Subject: [PATCH 05/12] chore: run mypy on tests --- .github/workflows/lint.yml | 2 +- tests/integration/test_mysql_loader.py | 12 +++++------- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index e903f9f..3538bc1 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -51,4 +51,4 @@ jobs: isort --check . - name: Run type-check - run: mypy --install-types --non-interactive src/ + run: mypy --install-types --non-interactive src/ tests/ diff --git a/tests/integration/test_mysql_loader.py b/tests/integration/test_mysql_loader.py index f11fadf..39e9d62 100644 --- a/tests/integration/test_mysql_loader.py +++ b/tests/integration/test_mysql_loader.py @@ -14,19 +14,17 @@ import os from typing import Generator -import pymysql import pytest import sqlalchemy -from google.cloud.sql.connector import Connector from langchain_core.documents import Document from langchain_google_cloud_sql_mysql import MySQLEngine, MySQLLoader -project_id = os.environ.get("PROJECT_ID", None) -region = os.environ.get("REGION") -instance_id = os.environ.get("INSTANCE_ID") -table_name = os.environ.get("TABLE_NAME") -db_name = os.environ.get("DB_NAME") +project_id = os.environ["PROJECT_ID"] +region = os.environ["REGION"] +instance_id = os.environ["INSTANCE_ID"] +table_name = os.environ["TABLE_NAME"] +db_name = os.environ["DB_NAME"] test_docs = [ Document( From bd9d471aaff90c54b33b81c5177268b8cbd0f4f9 Mon Sep 17 00:00:00 2001 From: Liang Wang Date: Fri, 26 Jan 2024 19:48:38 +0000 Subject: [PATCH 06/12] fix: load document schema. 
--- .../mysql_loader.py | 28 ++++--- tests/integration/test_mysql_loader.py | 78 ++++++++++++------- 2 files changed, 68 insertions(+), 38 deletions(-) diff --git a/src/langchain_google_cloud_sql_mysql/mysql_loader.py b/src/langchain_google_cloud_sql_mysql/mysql_loader.py index 6cd3e76..e626e2f 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_loader.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_loader.py @@ -22,16 +22,16 @@ def _parse_doc_from_table( - page_content_columns: Iterable[str], + content_columns: Iterable[str], metadata_columns: Iterable[str], column_names: Iterable[str], rows: Sequence[Any], ) -> List[Document]: docs = [] for row in rows: - page_content = "\n".join( - f"{column}: {getattr(row, column)}" - for column in page_content_columns + page_content = " ".join( + str(getattr(row, column)) + for column in content_columns if column in column_names ) metadata = { @@ -51,38 +51,46 @@ def __init__( self, engine: MySQLEngine, query: str, - page_content_columns: Optional[List[str]] = None, + content_columns: Optional[List[str]] = None, metadata_columns: Optional[List[str]] = None, ): """ Args: engine (MySQLEngine): MySQLEngine object to connect to the MySQL database. query (str): The query to execute in MySQL format. - page_content_columns (List[str]): The columns to write into the `page_content` + content_columns (List[str]): The columns to write into the `page_content` of the document. Optional. metadata_columns (List[str]): The columns to write into the `metadata` of the document. Optional. """ self.engine = engine self.query = query - self.page_content_columns = page_content_columns + self.content_columns = content_columns self.metadata_columns = metadata_columns def load(self) -> List[Document]: """ Load langchain documents from a Cloud SQL MySQL database. + Document page content defaults to the first columns present in the query or table and + metadata defaults to all other columns. 
Use with content_columns to overwrite the column + used for page content. Use metadata_columns to select specific metadata columns rather + than using all remaining columns. + + If multiple content columns are specified, page_content’s string format will default to + space-separated string concatenation. + Returns: (List[langchain_core.documents.Document]): a list of Documents with metadata from specific columns. """ with self.engine.connect() as connection: result_proxy = connection.execute(sqlalchemy.text(self.query)) - column_names = result_proxy.keys() + column_names = list(result_proxy.keys()) results = result_proxy.fetchall() return _parse_doc_from_table( - self.page_content_columns or column_names, - self.metadata_columns or [], + self.content_columns or [column_names[0]], + self.metadata_columns or column_names[1:], column_names, results, ) diff --git a/tests/integration/test_mysql_loader.py b/tests/integration/test_mysql_loader.py index 39e9d62..62d1a70 100644 --- a/tests/integration/test_mysql_loader.py +++ b/tests/integration/test_mysql_loader.py @@ -26,29 +26,6 @@ table_name = os.environ["TABLE_NAME"] db_name = os.environ["DB_NAME"] -test_docs = [ - Document( - page_content="fruit_name: Apple\nvariety: Granny Smith\nquantity_in_stock: 150\nprice_per_unit: 0.99\norganic: 1", - metadata={"fruit_id": 1}, - ), - Document( - page_content="fruit_name: Banana\nvariety: Cavendish\nquantity_in_stock: 200\nprice_per_unit: 0.59\norganic: 0", - metadata={"fruit_id": 2}, - ), - Document( - page_content="fruit_name: Orange\nvariety: Navel\nquantity_in_stock: 80\nprice_per_unit: 1.29\norganic: 1", - metadata={"fruit_id": 3}, - ), - Document( - page_content="fruit_name: Strawberry\nvariety: Camarosa\nquantity_in_stock: 35\nprice_per_unit: 2.49\norganic: 1", - metadata={"fruit_id": 4}, - ), - Document( - page_content="fruit_name: Grape\nvariety: Thompson Seedless\nquantity_in_stock: 120\nprice_per_unit: 1.99\norganic: 0", - metadata={"fruit_id": 5}, - ), -] - 
@pytest.fixture(name="engine") def setup() -> Generator: @@ -90,9 +67,7 @@ def test_load_from_query(engine): VALUES ('Apple', 'Granny Smith', 150, 0.99, 1), ('Banana', 'Cavendish', 200, 0.59, 0), - ('Orange', 'Navel', 80, 1.29, 1), - ('Strawberry', 'Camarosa', 35, 2.49, 1), - ('Grape', 'Thompson Seedless', 120, 1.99, 0); + ('Orange', 'Navel', 80, 1.29, 1); """ ) ) @@ -101,7 +76,7 @@ def test_load_from_query(engine): loader = MySQLLoader( engine=engine, query=query, - page_content_columns=[ + content_columns=[ "fruit_name", "variety", "quantity_in_stock", @@ -113,4 +88,51 @@ def test_load_from_query(engine): documents = loader.load() - assert documents == test_docs + assert documents == [ + Document( + page_content="Apple Granny Smith 150 0.99 1", + metadata={"fruit_id": 1}, + ), + Document( + page_content="Banana Cavendish 200 0.59 0", + metadata={"fruit_id": 2}, + ), + Document( + page_content="Orange Navel 80 1.29 1", + metadata={"fruit_id": 3}, + ), + ] + + +def test_load_from_query_default(engine): + with engine.connect() as conn: + conn.execute( + sqlalchemy.text( + f""" + INSERT INTO `{table_name}` (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES + ('Apple', 'Granny Smith', 150, 1, 1); + """ + ) + ) + conn.commit() + + query = f"SELECT * FROM `{table_name}`;" + loader = MySQLLoader( + engine=engine, + query=query, + ) + + documents = loader.load() + assert documents == [ + Document( + page_content="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + ] From a7a09e5c247931235e969119506a8b2585343d79 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 26 Jan 2024 21:52:08 +0000 Subject: [PATCH 07/12] chore: make connector a class attribute --- .../mysql_engine.py | 46 ++++++++++--------- 1 file changed, 25 insertions(+), 21 deletions(-) diff --git a/src/langchain_google_cloud_sql_mysql/mysql_engine.py 
b/src/langchain_google_cloud_sql_mysql/mysql_engine.py index 61dff9e..4c02d2e 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_engine.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_engine.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +# TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations from typing import TYPE_CHECKING, Dict, Optional @@ -64,25 +64,20 @@ def _get_iam_principal_email( class MySQLEngine: """A class for managing connections to a Cloud SQL for MySQL database.""" + _connector: Optional[Connector] = None + def __init__( self, - project_id: Optional[str] = None, - region: Optional[str] = None, - instance: Optional[str] = None, - database: Optional[str] = None, + engine: sqlalchemy.engine.Engine, ) -> None: - self._project_id = project_id - self._region = region - self._instance = instance - self._database = database - self.engine = self._create_connector_engine() + self.engine = engine def close(self) -> None: """Utility method for closing the Cloud SQL Python Connector background tasks. """ - if hasattr(self, "_connector"): - self._connector.close() + if MySQLEngine._connector: + MySQLEngine._connector.close() @classmethod def from_instance( @@ -113,20 +108,28 @@ def from_instance( (MySQLEngine): The engine configured to connect to a Cloud SQL instance database. 
""" - return cls( - project_id=project_id, - region=region, - instance=instance, + engine = cls._create_connector_engine( + instance_connection_name=f"{project_id}:{region}:{instance}", database=database, ) + return cls(engine=engine) - def _create_connector_engine(self) -> sqlalchemy.engine.Engine: + @classmethod + def _create_connector_engine( + cls, instance_connection_name: str, database: str + ) -> sqlalchemy.engine.Engine: """Create a SQLAlchemy engine using the Cloud SQL Python Connector. Defaults to use "pymysql" driver and to connect using automatic IAM database authentication with the IAM principal associated with the environment's Google Application Default Credentials. + Args: + instance_connection_name (str): The instance connection + name of the Cloud SQL instance to establish a connection to. + (ex. "project-id:instance-region:instance-name") + database (str): The name of the database to connect to on the + Cloud SQL instance. Returns: (sqlalchemy.engine.Engine): Engine configured using the Cloud SQL Python Connector. 
@@ -136,15 +139,16 @@ def _create_connector_engine(self) -> sqlalchemy.engine.Engine: scopes=["https://www.googleapis.com/auth/userinfo.email"] ) iam_database_user = _get_iam_principal_email(credentials) - self._connector = Connector() + if cls._connector is None: + cls._connector = Connector() # anonymous function to be used for SQLAlchemy 'creator' argument def getconn() -> pymysql.Connection: - conn = self._connector.connect( - f"{self._project_id}:{self._region}:{self._instance}", + conn = cls._connector.connect( # type: ignore + instance_connection_name, "pymysql", user=iam_database_user, - db=self._database, + db=database, enable_iam_auth=True, ) return conn From 51152074086f6716f24ea44409943003eae318fe Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 26 Jan 2024 22:35:26 +0000 Subject: [PATCH 08/12] chore: remove close and raise errors --- src/langchain_google_cloud_sql_mysql/mysql_engine.py | 8 +------- tests/integration/test_mysql_loader.py | 1 - 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/src/langchain_google_cloud_sql_mysql/mysql_engine.py b/src/langchain_google_cloud_sql_mysql/mysql_engine.py index 4c02d2e..0279423 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_engine.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_engine.py @@ -51,6 +51,7 @@ def _get_iam_principal_email( # call OAuth2 api to get IAM principal email associated with OAuth2 token url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" response = requests.get(url) + response.raise_for_status() response_json: Dict = response.json() email = response_json.get("email") if email is None: @@ -72,13 +73,6 @@ def __init__( ) -> None: self.engine = engine - def close(self) -> None: - """Utility method for closing the Cloud SQL Python Connector - background tasks. 
- """ - if MySQLEngine._connector: - MySQLEngine._connector.close() - @classmethod def from_instance( cls, diff --git a/tests/integration/test_mysql_loader.py b/tests/integration/test_mysql_loader.py index 62d1a70..8527bac 100644 --- a/tests/integration/test_mysql_loader.py +++ b/tests/integration/test_mysql_loader.py @@ -55,7 +55,6 @@ def setup() -> Generator: with engine.connect() as conn: conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`")) conn.commit() - engine.close() def test_load_from_query(engine): From 8f1926c2c04a8a4204f431a50ff75952df533bc9 Mon Sep 17 00:00:00 2001 From: Jack Wotherspoon Date: Mon, 29 Jan 2024 17:07:29 -0500 Subject: [PATCH 09/12] chore: add whitespace Co-authored-by: Kurtis Van Gent <31518063+kurtisvg@users.noreply.github.com> --- src/langchain_google_cloud_sql_mysql/mysql_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/langchain_google_cloud_sql_mysql/mysql_engine.py b/src/langchain_google_cloud_sql_mysql/mysql_engine.py index 0279423..aac96c4 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_engine.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_engine.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + # TODO: Remove below import when minimum supported Python version is 3.10 from __future__ import annotations From 31c49e433099e1c17a853c751bd8633c1eee6566 Mon Sep 17 00:00:00 2001 From: Liang Wang Date: Mon, 29 Jan 2024 22:26:56 +0000 Subject: [PATCH 10/12] Add int tests to cover combination of customized content columns and metadata columns in doc loader.
--- .../mysql_loader.py | 8 +- tests/integration/test_mysql_loader.py | 80 +++++++++++++++++-- 2 files changed, 81 insertions(+), 7 deletions(-) diff --git a/src/langchain_google_cloud_sql_mysql/mysql_loader.py b/src/langchain_google_cloud_sql_mysql/mysql_loader.py index e626e2f..c2af45a 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_loader.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_loader.py @@ -88,9 +88,13 @@ def load(self) -> List[Document]: result_proxy = connection.execute(sqlalchemy.text(self.query)) column_names = list(result_proxy.keys()) results = result_proxy.fetchall() + content_columns = self.content_columns or [column_names[0]] + metadata_columns = self.metadata_columns or [ + col for col in column_names if col not in content_columns + ] return _parse_doc_from_table( - self.content_columns or [column_names[0]], - self.metadata_columns or column_names[1:], + content_columns, + metadata_columns, column_names, results, ) diff --git a/tests/integration/test_mysql_loader.py b/tests/integration/test_mysql_loader.py index 8527bac..864b6ad 100644 --- a/tests/integration/test_mysql_loader.py +++ b/tests/integration/test_mysql_loader.py @@ -57,7 +57,40 @@ def setup() -> Generator: conn.commit() -def test_load_from_query(engine): +def test_load_from_query_default(engine): + with engine.connect() as conn: + conn.execute( + sqlalchemy.text( + f""" + INSERT INTO `{table_name}` (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES + ('Apple', 'Granny Smith', 150, 1, 1); + """ + ) + ) + conn.commit() + query = f"SELECT * FROM `{table_name}`;" + loader = MySQLLoader( + engine=engine, + query=query, + ) + + documents = loader.load() + assert documents == [ + Document( + page_content="1", + metadata={ + "fruit_name": "Apple", + "variety": "Granny Smith", + "quantity_in_stock": 150, + "price_per_unit": 1, + "organic": 1, + }, + ) + ] + + +def test_load_from_query_customized_content_customized_metadata(engine): with engine.connect() 
as conn: conn.execute( sqlalchemy.text( @@ -103,7 +136,43 @@ def test_load_from_query(engine): ] -def test_load_from_query_default(engine): +def test_load_from_query_customized_content_default_metadata(engine): + with engine.connect() as conn: + conn.execute( + sqlalchemy.text( + f""" + INSERT INTO `{table_name}` (fruit_name, variety, quantity_in_stock, price_per_unit, organic) + VALUES + ('Apple', 'Granny Smith', 150, 0.99, 1); + """ + ) + ) + conn.commit() + query = f"SELECT * FROM `{table_name}`;" + loader = MySQLLoader( + engine=engine, + query=query, + content_columns=[ + "variety", + "quantity_in_stock", + "price_per_unit", + ], + ) + + documents = loader.load() + assert documents == [ + Document( + page_content="Granny Smith 150 0.99", + metadata={ + "fruit_id": 1, + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] + + +def test_load_from_query_default_content_customized_metadata(engine): with engine.connect() as conn: conn.execute( sqlalchemy.text( @@ -120,6 +189,10 @@ def test_load_from_query_default(engine): loader = MySQLLoader( engine=engine, query=query, + metadata_columns=[ + "fruit_name", + "organic", + ], ) documents = loader.load() @@ -128,9 +201,6 @@ def test_load_from_query_default(engine): page_content="1", metadata={ "fruit_name": "Apple", - "variety": "Granny Smith", - "quantity_in_stock": 150, - "price_per_unit": 1, "organic": 1, }, ) From e6b46e5aebda2c339949b03ece184a9e50f6b1a1 Mon Sep 17 00:00:00 2001 From: Liang Wang Date: Mon, 29 Jan 2024 23:00:41 +0000 Subject: [PATCH 11/12] feat: support default metadata langchain_metadata in doc loader. 
--- .../mysql_loader.py | 7 ++ tests/integration/test_mysql_loader.py | 87 +++++++++++++++---- 2 files changed, 77 insertions(+), 17 deletions(-) diff --git a/src/langchain_google_cloud_sql_mysql/mysql_loader.py b/src/langchain_google_cloud_sql_mysql/mysql_loader.py index c2af45a..29c6dcb 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_loader.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_loader.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import json from collections.abc import Iterable from typing import Any, Dict, List, Optional, Sequence, cast @@ -20,6 +21,8 @@ from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine +DEFAULT_METADATA_COL = "langchain_metadata" + def _parse_doc_from_table( content_columns: Iterable[str], @@ -39,6 +42,10 @@ def _parse_doc_from_table( for column in metadata_columns if column in column_names } + if DEFAULT_METADATA_COL in metadata: + extra_metadata = json.loads(metadata[DEFAULT_METADATA_COL]) + del metadata[DEFAULT_METADATA_COL] + metadata |= extra_metadata doc = Document(page_content=page_content, metadata=metadata) docs.append(doc) return docs diff --git a/tests/integration/test_mysql_loader.py b/tests/integration/test_mysql_loader.py index 864b6ad..a9c7148 100644 --- a/tests/integration/test_mysql_loader.py +++ b/tests/integration/test_mysql_loader.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os from typing import Generator @@ -32,7 +33,15 @@ def setup() -> Generator: engine = MySQLEngine.from_instance( project_id=project_id, region=region, instance=instance_id, database=db_name ) + yield engine + + with engine.connect() as conn: + conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`")) + conn.commit() + +@pytest.fixture +def default_setup(engine): with engine.connect() as conn: conn.execute( sqlalchemy.text( @@ -49,16 +58,11 @@ def setup() -> Generator: ) ) conn.commit() - yield engine - with engine.connect() as conn: - conn.execute(sqlalchemy.text(f"DROP TABLE IF EXISTS `{table_name}`")) - conn.commit() - -def test_load_from_query_default(engine): - with engine.connect() as conn: +def test_load_from_query_default(default_setup): + with default_setup.connect() as conn: conn.execute( sqlalchemy.text( f""" @@ -71,7 +75,7 @@ def test_load_from_query_default(engine): conn.commit() query = f"SELECT * FROM `{table_name}`;" loader = MySQLLoader( - engine=engine, + engine=default_setup, query=query, ) @@ -90,8 +94,8 @@ def test_load_from_query_default(engine): ] -def test_load_from_query_customized_content_customized_metadata(engine): - with engine.connect() as conn: +def test_load_from_query_customized_content_customized_metadata(default_setup): + with default_setup.connect() as conn: conn.execute( sqlalchemy.text( f""" @@ -106,7 +110,7 @@ def test_load_from_query_customized_content_customized_metadata(engine): conn.commit() query = f"SELECT * FROM `{table_name}`;" loader = MySQLLoader( - engine=engine, + engine=default_setup, query=query, content_columns=[ "fruit_name", @@ -136,8 +140,8 @@ def test_load_from_query_customized_content_customized_metadata(engine): ] -def test_load_from_query_customized_content_default_metadata(engine): - with engine.connect() as conn: +def test_load_from_query_customized_content_default_metadata(default_setup): + with default_setup.connect() as conn: conn.execute( sqlalchemy.text( f""" @@ 
-150,7 +154,7 @@ def test_load_from_query_customized_content_default_metadata(engine): conn.commit() query = f"SELECT * FROM `{table_name}`;" loader = MySQLLoader( - engine=engine, + engine=default_setup, query=query, content_columns=[ "variety", @@ -172,8 +176,8 @@ def test_load_from_query_customized_content_default_metadata(engine): ] -def test_load_from_query_default_content_customized_metadata(engine): - with engine.connect() as conn: +def test_load_from_query_default_content_customized_metadata(default_setup): + with default_setup.connect() as conn: conn.execute( sqlalchemy.text( f""" @@ -187,7 +191,7 @@ def test_load_from_query_default_content_customized_metadata(engine): query = f"SELECT * FROM `{table_name}`;" loader = MySQLLoader( - engine=engine, + engine=default_setup, query=query, metadata_columns=[ "fruit_name", @@ -205,3 +209,52 @@ def test_load_from_query_default_content_customized_metadata(engine): }, ) ] + + +def test_load_from_query_with_langchain_metadata(engine): + with engine.connect() as conn: + conn.execute( + sqlalchemy.text( + f""" + CREATE TABLE IF NOT EXISTS `{table_name}`( + fruit_id INT AUTO_INCREMENT PRIMARY KEY, + fruit_name VARCHAR(100) NOT NULL, + variety VARCHAR(50), + quantity_in_stock INT NOT NULL, + price_per_unit DECIMAL(6,2) NOT NULL, + langchain_metadata JSON NOT NULL + ) + """ + ) + ) + metadata = json.dumps({"organic": 1}) + conn.execute( + sqlalchemy.text( + f""" + INSERT INTO `{table_name}` (fruit_name, variety, quantity_in_stock, price_per_unit, langchain_metadata) + VALUES + ('Apple', 'Granny Smith', 150, 1, '{metadata}'); + """ + ) + ) + conn.commit() + query = f"SELECT * FROM `{table_name}`;" + loader = MySQLLoader( + engine=engine, + query=query, + metadata_columns=[ + "fruit_name", + "langchain_metadata", + ], + ) + + documents = loader.load() + assert documents == [ + Document( + page_content="1", + metadata={ + "fruit_name": "Apple", + "organic": 1, + }, + ) + ] From 6f3bd2b08d0ba2059bc3268cf89518315a898c74 Mon 
Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 30 Jan 2024 14:27:30 +0000 Subject: [PATCH 12/12] chore: re-add service_account_email check --- src/langchain_google_cloud_sql_mysql/mysql_engine.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/langchain_google_cloud_sql_mysql/mysql_engine.py b/src/langchain_google_cloud_sql_mysql/mysql_engine.py index aac96c4..1a6dc7b 100644 --- a/src/langchain_google_cloud_sql_mysql/mysql_engine.py +++ b/src/langchain_google_cloud_sql_mysql/mysql_engine.py @@ -49,6 +49,9 @@ def _get_iam_principal_email( if not credentials.valid: request = google.auth.transport.requests.Request() credentials.refresh(request) + # if credentials are associated with a service account email, return early + if hasattr(credentials, "_service_account_email"): + return credentials._service_account_email # call OAuth2 api to get IAM principal email associated with OAuth2 token url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}" response = requests.get(url)