Skip to content

Commit

Permalink
Merge pull request #54 from tjni/asyncmy-support
Browse files Browse the repository at this point in the history
Add support for the asyncmy library.
  • Loading branch information
tjni authored Jan 25, 2025
2 parents 3cb41f6 + a2ffa8a commit c9b479a
Show file tree
Hide file tree
Showing 15 changed files with 1,500 additions and 1,007 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ Implementation of LangGraph CheckpointSaver that uses MySQL.
## Dependencies

To use synchronous `PyMySQLSaver`, install `langgraph-checkpoint-mysql[pymysql]`. To use asynchronous `AIOMySQLSaver`, install `langgraph-checkpoint-mysql[aiomysql]`.
- To use synchronous `PyMySQLSaver`, install `langgraph-checkpoint-mysql[pymysql]`.
- To use asynchronous `AIOMySQLSaver`, install `langgraph-checkpoint-mysql[aiomysql]`.
- To use asynchronous `AsyncMySaver`, install `langgraph-checkpoint-mysql[asyncmy]`.

There is currently no support for other drivers.

Expand Down
150 changes: 139 additions & 11 deletions langgraph-tests/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from uuid import UUID, uuid4

import aiomysql # type: ignore
import asyncmy
import pymysql
import pymysql.constants.ER
import pytest
Expand All @@ -13,9 +14,11 @@

from langgraph.checkpoint.base import BaseCheckpointSaver
from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver
from langgraph.checkpoint.mysql.asyncmy import AsyncMySaver, ShallowAsyncMySaver
from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver, ShallowPyMySQLSaver
from langgraph.store.base import BaseStore
from langgraph.store.mysql.aio import AIOMySQLStore
from langgraph.store.mysql.asyncmy import AsyncMyStore
from langgraph.store.mysql.pymysql import PyMySQLStore

DEFAULT_MYSQL_URI = "mysql://mysql:mysql@localhost:5441/"
Expand Down Expand Up @@ -110,6 +113,89 @@ def checkpointer_pymysql_pool():
cursor.execute(f"DROP DATABASE {database}")


@asynccontextmanager
async def _checkpointer_asyncmy():
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await asyncmy.connect(
**AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True,
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
async with AsyncMySaver.from_conn_string(
DEFAULT_MYSQL_URI + database
) as checkpointer:
await checkpointer.setup()
yield checkpointer
finally:
# drop unique db
async with await asyncmy.connect(
**AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"DROP DATABASE {database}")


@asynccontextmanager
async def _checkpointer_asyncmy_shallow():
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await asyncmy.connect(
**AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True,
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
async with ShallowAsyncMySaver.from_conn_string(
DEFAULT_MYSQL_URI + database
) as checkpointer:
await checkpointer.setup()
yield checkpointer
finally:
# drop unique db
async with await asyncmy.connect(
**AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"DROP DATABASE {database}")


@asynccontextmanager
async def _checkpointer_asyncmy_pool():
database = f"test_{uuid4().hex[:16]}"
# create unique db
async with await asyncmy.connect(
**AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True,
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"CREATE DATABASE {database}")
try:
# yield checkpointer
async with asyncmy.create_pool(
**AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI + database),
maxsize=10,
autocommit=True,
) as pool:
checkpointer = AsyncMySaver(pool)
await checkpointer.setup()
yield checkpointer
finally:
# drop unique db
async with await asyncmy.connect(
**AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"DROP DATABASE {database}")

@asynccontextmanager
async def _checkpointer_aiomysql():
database = f"test_{uuid4().hex[:16]}"
Expand Down Expand Up @@ -207,6 +293,15 @@ async def awith_checkpointer(
elif checkpointer_name == "aiomysql_pool":
async with _checkpointer_aiomysql_pool() as checkpointer:
yield checkpointer
elif checkpointer_name == "asyncmy":
async with _checkpointer_asyncmy() as checkpointer:
yield checkpointer
elif checkpointer_name == "asyncmy_shallow":
async with _checkpointer_asyncmy_shallow() as checkpointer:
yield checkpointer
elif checkpointer_name == "asyncmy_pool":
async with _checkpointer_asyncmy_pool() as checkpointer:
yield checkpointer
else:
raise NotImplementedError(f"Unknown checkpointer: {checkpointer_name}")

Expand Down Expand Up @@ -253,47 +348,74 @@ def store_pymysql_pool():


@asynccontextmanager
async def _store_aiomysql():
async def _store_asyncmy():
database = f"test_{uuid4().hex[:16]}"
async with await aiomysql.connect(
**AIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI),
async with await asyncmy.connect(
**AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True,
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"CREATE DATABASE {database}")
try:
async with AIOMySQLStore.from_conn_string(
async with AsyncMyStore.from_conn_string(
DEFAULT_MYSQL_URI + database
) as store:
await store.setup()
yield store
finally:
async with await aiomysql.connect(
**AIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI),
async with await asyncmy.connect(
**AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"DROP DATABASE {database}")


@asynccontextmanager
async def _store_asyncmy_pool():
database = f"test_{uuid4().hex[:16]}"
async with await asyncmy.connect(
**AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True,
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"CREATE DATABASE {database}")
try:
async with asyncmy.create_pool(
**AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI + database),
maxsize=10,
autocommit=True,
) as pool:
store = AsyncMyStore(pool)
await store.setup()
yield store
finally:
async with await asyncmy.connect(
**AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"DROP DATABASE {database}")


@asynccontextmanager
async def _store_aiomysql_shallow():
async def _store_aiomysql():
database = f"test_{uuid4().hex[:16]}"
async with await aiomysql.connect(
**ShallowAIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI),
**AIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True,
) as conn:
async with conn.cursor() as cursor:
await cursor.execute(f"CREATE DATABASE {database}")
try:
async with ShallowAIOMySQLStore.from_conn_string(
async with AIOMySQLStore.from_conn_string(
DEFAULT_MYSQL_URI + database
) as store:
await store.setup()
yield store
finally:
async with await aiomysql.connect(
**ShallowAIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI),
**AIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI),
autocommit=True
) as conn:
async with conn.cursor() as cursor:
Expand Down Expand Up @@ -335,6 +457,12 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
elif store_name == "aiomysql_pool":
async with _store_aiomysql_pool() as store:
yield store
elif store_name == "asyncmy":
async with _store_asyncmy() as store:
yield store
elif store_name == "asyncmy_pool":
async with _store_asyncmy_pool() as store:
yield store
else:
raise NotImplementedError(f"Unknown store {store_name}")

Expand All @@ -355,4 +483,4 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]:
*SHALLOW_CHECKPOINTERS_ASYNC,
]
ALL_STORES_SYNC = ["pymysql", "pymysql_pool"]
ALL_STORES_ASYNC = ["aiomysql", "aiomysql_pool"]
ALL_STORES_ASYNC = ["aiomysql", "aiomysql_pool", "asyncmy", "asyncmy_pool"]
19 changes: 6 additions & 13 deletions langgraph/checkpoint/mysql/_ainternal.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,7 @@ def __aiter__(self) -> AsyncIterator[dict[str, Any]]: ...
R = TypeVar("R", bound=AsyncDictCursor) # cursor type


class AIOMySQLConnection(AsyncContextManager, Protocol):
"""From aiomysql package."""

class AsyncConnection(AsyncContextManager, Protocol):
async def begin(self) -> None:
"""Begin transaction."""
...
Expand All @@ -61,19 +59,17 @@ async def set_charset(self, charset: str) -> None:
...


C = TypeVar("C", bound=AIOMySQLConnection) # connection type
COut = TypeVar("COut", bound=AIOMySQLConnection, covariant=True) # connection type

C = TypeVar("C", bound=AsyncConnection) # connection type
COut = TypeVar("COut", bound=AsyncConnection, covariant=True) # connection type

class AIOMySQLPool(Protocol, Generic[COut]):
"""From aiomysql package."""

class AsyncPool(Protocol, Generic[COut]):
def acquire(self) -> COut:
"""Gets a connection from the connection pool."""
...


Conn = Union[C, AIOMySQLPool[C]]
Conn = Union[C, AsyncPool[C]]


@asynccontextmanager
Expand All @@ -83,10 +79,7 @@ async def get_connection(
if hasattr(conn, "cursor"):
yield cast(C, conn)
elif hasattr(conn, "acquire"):
async with cast(AIOMySQLPool[C], conn).acquire() as _conn:
# This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119
# is merged into aiomysql.
await _conn.set_charset("utf8mb4")
async with cast(AsyncPool[C], conn).acquire() as _conn:
yield _conn
else:
raise TypeError(f"Invalid connection type: {type(conn)}")
Loading

0 comments on commit c9b479a

Please sign in to comment.