Skip to content

Commit

Permalink
allow sqlaclhemy async engine
Browse files Browse the repository at this point in the history
  • Loading branch information
boetro committed Nov 17, 2023
1 parent 37c4df3 commit f82f0fa
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
14 changes: 12 additions & 2 deletions buildflow/dependencies/sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from google.auth.credentials import Credentials
from google.cloud.sql.connector import Connector, IPTypes
from sqlalchemy import create_engine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.orm import sessionmaker

from buildflow.dependencies import Scope, dependency
Expand All @@ -14,6 +15,7 @@ def engine(
db: CloudSQLDatabase,
db_user: str,
db_password: str,
async_engine: bool = False,
credentials: Optional[Credentials] = None,
**kwargs,
):
Expand Down Expand Up @@ -41,11 +43,17 @@ def getconn() -> Connector:
)
return conn

if async_engine:
return create_async_engine("postgresql+pg8000://", creator=getconn, **kwargs)
return create_engine("postgresql+pg8000://", creator=getconn, **kwargs)


def SessionDep(
db_primitive: CloudSQLDatabase, db_user: str, db_password: str, **kwargs
db_primitive: CloudSQLDatabase,
db_user: str,
db_password: str,
use_async: bool = False,
**kwargs,
):
"""
Args:
Expand All @@ -65,7 +73,9 @@ def __init__(self, db: DBDependency, flow_credentials: FlowCredentials) -> None:
self.SessionLocal = sessionmaker(
autocommit=False,
autoflush=False,
bind=engine(db, db_user, db_password, creds, **kwargs),
bind=engine(
db, db_user, db_password, creds, async_engine=True, **kwargs
),
)

@dependency(scope=Scope.PROCESS)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ dependencies = [
# https://github.com/yaml/pyyaml/issues/724
"pyyaml<5.4.0,<6.0.0",
"s3fs",
"sqlalchemy",
"sqlalchemy[asyncio]",
"snowflake-ingest",
"ray[default]>=2.4.0",
"ray[serve]>=2.4.0",
Expand Down

0 comments on commit f82f0fa

Please sign in to comment.