Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add MySQLEngine and Loader load functionality #9

Merged
merged 13 commits into from
Jan 30, 2024
Merged
11 changes: 11 additions & 0 deletions integration.cloudbuild.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,14 @@ steps:
name: python:3.11
entrypoint: python
args: ["-m", "pytest"]
env:
- 'PROJECT_ID=$PROJECT_ID'
- 'INSTANCE_ID=$_INSTANCE_ID'
- 'DB_NAME=$_DB_NAME'
- 'TABLE_NAME=test-$BUILD_ID'
- 'REGION=$_REGION'

substitutions:
_INSTANCE_ID: test-instance
_REGION: us-central1
_DB_NAME: test
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,7 @@ warn_unused_configs = true
exclude = [
"owlbot.py"
]

[[tool.mypy.overrides]]
module="google.auth.*"
ignore_missing_imports = true
5 changes: 5 additions & 0 deletions src/langchain_google_cloud_sql_mysql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine
from langchain_google_cloud_sql_mysql.mysql_loader import MySQLLoader

__all__ = ["MySQLEngine", "MySQLLoader"]
166 changes: 166 additions & 0 deletions src/langchain_google_cloud_sql_mysql/mysql_engine.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO: Remove below import when minimum supported Python version is 3.10
from __future__ import annotations
jackwotherspoon marked this conversation as resolved.
Show resolved Hide resolved

from typing import TYPE_CHECKING, Dict, Optional

import google.auth
import google.auth.transport.requests
import requests
import sqlalchemy
from google.cloud.sql.connector import Connector

if TYPE_CHECKING:
import google.auth.credentials
import pymysql


def _get_iam_principal_email(
credentials: google.auth.credentials.Credentials,
) -> str:
"""Get email address associated with current authenticated IAM principal.

Email will be used for automatic IAM database authentication to Cloud SQL.

Args:
credentials (google.auth.credentials.Credentials):
The credentials object to use in finding the associated IAM
principal email address.

Returns:
email (str):
The email address associated with the current authenticated IAM
principal.
"""
# refresh credentials if they are not valid
if not credentials.valid:
request = google.auth.transport.requests.Request()
credentials.refresh(request)
jackwotherspoon marked this conversation as resolved.
Show resolved Hide resolved
# if credentials are associated with a service account email, return early
if hasattr(credentials, "_service_account_email"):
return credentials._service_account_email
# call OAuth2 api to get IAM principal email associated with OAuth2 token
url = f"https://oauth2.googleapis.com/tokeninfo?access_token={credentials.token}"
response = requests.get(url)
jackwotherspoon marked this conversation as resolved.
Show resolved Hide resolved
response.raise_for_status()
response_json: Dict = response.json()
email = response_json.get("email")
if email is None:
raise ValueError(
"Failed to automatically obtain authenticated IAM princpal's "
"email address using environment's ADC credentials!"
kurtisvg marked this conversation as resolved.
Show resolved Hide resolved
)
return email


class MySQLEngine:
"""A class for managing connections to a Cloud SQL for MySQL database."""

_connector: Optional[Connector] = None

def __init__(
self,
engine: sqlalchemy.engine.Engine,
) -> None:
self.engine = engine

@classmethod
def from_instance(
cls,
project_id: str,
region: str,
instance: str,
database: str,
) -> MySQLEngine:
"""Create an instance of MySQLEngine from Cloud SQL instance
details.

This method uses the Cloud SQL Python Connector to connect to Cloud SQL
using automatic IAM database authentication with the Google ADC
credentials sourced from the environment.

More details can be found at https://github.com/GoogleCloudPlatform/cloud-sql-python-connector#credentials

Args:
project_id (str): Project ID of the Google Cloud Project where
the Cloud SQL instance is located.
region (str): Region where the Cloud SQL instance is located.
instance (str): The name of the Cloud SQL instance.
database (str): The name of the database to connect to on the
Cloud SQL instance.

Returns:
(MySQLEngine): The engine configured to connect to a
Cloud SQL instance database.
"""
engine = cls._create_connector_engine(
instance_connection_name=f"{project_id}:{region}:{instance}",
database=database,
)
return cls(engine=engine)

@classmethod
def _create_connector_engine(
cls, instance_connection_name: str, database: str
) -> sqlalchemy.engine.Engine:
"""Create a SQLAlchemy engine using the Cloud SQL Python Connector.

Defaults to use "pymysql" driver and to connect using automatic IAM
database authentication with the IAM principal associated with the
environment's Google Application Default Credentials.

Args:
instance_connection_name (str): The instance connection
name of the Cloud SQL instance to establish a connection to.
(ex. "project-id:instance-region:instance-name")
database (str): The name of the database to connect to on the
Cloud SQL instance.
Returns:
(sqlalchemy.engine.Engine): Engine configured using the Cloud SQL
Python Connector.
"""
# get application default credentials
credentials, _ = google.auth.default(
scopes=["https://www.googleapis.com/auth/userinfo.email"]
)
iam_database_user = _get_iam_principal_email(credentials)
if cls._connector is None:
cls._connector = Connector()

# anonymous function to be used for SQLAlchemy 'creator' argument
def getconn() -> pymysql.Connection:
conn = cls._connector.connect( # type: ignore
instance_connection_name,
"pymysql",
user=iam_database_user,
db=database,
enable_iam_auth=True,
)
return conn

return sqlalchemy.create_engine(
"mysql+pymysql://",
creator=getconn,
)

def connect(self) -> sqlalchemy.engine.Connection:
"""Create a connection from SQLAlchemy connection pool.

Returns:
(sqlalchemy.engine.Connection): a single DBAPI connection checked
out from the connection pool.
"""
return self.engine.connect()
107 changes: 107 additions & 0 deletions src/langchain_google_cloud_sql_mysql/mysql_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from collections.abc import Iterable
from typing import Any, Dict, List, Optional, Sequence, cast

import sqlalchemy
from langchain_community.document_loaders.base import BaseLoader
from langchain_core.documents import Document

from langchain_google_cloud_sql_mysql.mysql_engine import MySQLEngine

DEFAULT_METADATA_COL = "langchain_metadata"


def _parse_doc_from_table(
content_columns: Iterable[str],
metadata_columns: Iterable[str],
column_names: Iterable[str],
rows: Sequence[Any],
) -> List[Document]:
docs = []
for row in rows:
page_content = " ".join(
str(getattr(row, column))
for column in content_columns
if column in column_names
)
metadata = {
column: getattr(row, column)
for column in metadata_columns
if column in column_names
}
if DEFAULT_METADATA_COL in metadata:
extra_metadata = json.loads(metadata[DEFAULT_METADATA_COL])
del metadata[DEFAULT_METADATA_COL]
metadata |= extra_metadata
doc = Document(page_content=page_content, metadata=metadata)
docs.append(doc)
return docs


class MySQLLoader(BaseLoader):
"""A class for loading langchain documents from a Cloud SQL MySQL database."""

def __init__(
self,
engine: MySQLEngine,
query: str,
content_columns: Optional[List[str]] = None,
metadata_columns: Optional[List[str]] = None,
):
"""
Args:
engine (MySQLEngine): MySQLEngine object to connect to the MySQL database.
query (str): The query to execute in MySQL format.
content_columns (List[str]): The columns to write into the `page_content`
of the document. Optional.
metadata_columns (List[str]): The columns to write into the `metadata` of the document.
Optional.
"""
self.engine = engine
self.query = query
self.content_columns = content_columns
self.metadata_columns = metadata_columns

def load(self) -> List[Document]:
"""
Load langchain documents from a Cloud SQL MySQL database.

Document page content defaults to the first columns present in the query or table and
metadata defaults to all other columns. Use with content_columns to overwrite the column
used for page content. Use metadata_columns to select specific metadata columns rather
than using all remaining columns.

If multiple content columns are specified, page_content’s string format will default to
space-separated string concatenation.

Returns:
(List[langchain_core.documents.Document]): a list of Documents with metadata from
specific columns.
"""
with self.engine.connect() as connection:
result_proxy = connection.execute(sqlalchemy.text(self.query))
column_names = list(result_proxy.keys())
results = result_proxy.fetchall()
content_columns = self.content_columns or [column_names[0]]
metadata_columns = self.metadata_columns or [
col for col in column_names if col not in content_columns
]
return _parse_doc_from_table(
content_columns,
metadata_columns,
column_names,
results,
)
Loading
Loading