From a2ffa8adabcbff7a92592b7bbf4f3cbbf5a4fea9 Mon Sep 17 00:00:00 2001 From: Theodore Ni <3806110+tjni@users.noreply.github.com> Date: Fri, 24 Jan 2025 17:11:44 -0800 Subject: [PATCH] Add support for the asyncmy library. --- README.md | 4 +- langgraph-tests/tests/conftest.py | 150 +++++- langgraph/checkpoint/mysql/_ainternal.py | 19 +- langgraph/checkpoint/mysql/aio.py | 439 +--------------- langgraph/checkpoint/mysql/aio_base.py | 442 +++++++++++++++++ langgraph/checkpoint/mysql/asyncmy.py | 97 ++++ langgraph/store/mysql/__init__.py | 3 +- langgraph/store/mysql/aio.py | 214 +------- langgraph/store/mysql/aio_base.py | 219 ++++++++ langgraph/store/mysql/asyncmy.py | 58 +++ langgraph/store/mysql/base.py | 42 +- poetry.lock | 69 ++- pyproject.toml | 5 +- tests/test_async.py | 140 +++++- tests/test_async_store.py | 606 ++++++++++++----------- 15 files changed, 1500 insertions(+), 1007 deletions(-) create mode 100644 langgraph/checkpoint/mysql/aio_base.py create mode 100644 langgraph/checkpoint/mysql/asyncmy.py create mode 100644 langgraph/store/mysql/aio_base.py create mode 100644 langgraph/store/mysql/asyncmy.py diff --git a/README.md b/README.md index 251ecb6..d502093 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,9 @@ Implementation of LangGraph CheckpointSaver that uses MySQL. ## Dependencies -To use synchronous `PyMySQLSaver`, install `langgraph-checkpoint-mysql[pymysql]`. To use asynchronous `AIOMySQLSaver`, install `langgraph-checkpoint-mysql[aiomysql]`. +- To use synchronous `PyMySQLSaver`, install `langgraph-checkpoint-mysql[pymysql]`. +- To use asynchronous `AIOMySQLSaver`, install `langgraph-checkpoint-mysql[aiomysql]`. +- To use asynchronous `AsyncMySaver`, install `langgraph-checkpoint-mysql[asyncmy]`. There is currently no support for other drivers. 
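Below is a minimal usage sketch for the new `AsyncMySaver`, following the `from_conn_string()` / `setup()` flow exercised by the test fixtures in this patch; the connection URI and database name are placeholders, not values defined by the library.

```python
import asyncio

from langgraph.checkpoint.mysql.asyncmy import AsyncMySaver

# Placeholder URI -- point this at your own MySQL server and database.
DB_URI = "mysql://mysql:mysql@localhost:3306/example"


async def main() -> None:
    # from_conn_string() is an async context manager that opens an asyncmy
    # connection with autocommit enabled and wraps it in an AsyncMySaver.
    async with AsyncMySaver.from_conn_string(DB_URI) as checkpointer:
        # setup() must be called once, the first time the checkpointer is
        # used, to create the checkpoint tables and run migrations.
        await checkpointer.setup()
        # ... pass `checkpointer` to your graph, e.g. when compiling it.


asyncio.run(main())
```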
diff --git a/langgraph-tests/tests/conftest.py b/langgraph-tests/tests/conftest.py index 4c2ceac..33cf2b5 100644 --- a/langgraph-tests/tests/conftest.py +++ b/langgraph-tests/tests/conftest.py @@ -3,6 +3,7 @@ from uuid import UUID, uuid4 import aiomysql # type: ignore +import asyncmy import pymysql import pymysql.constants.ER import pytest @@ -13,9 +14,11 @@ from langgraph.checkpoint.base import BaseCheckpointSaver from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver +from langgraph.checkpoint.mysql.asyncmy import AsyncMySaver, ShallowAsyncMySaver from langgraph.checkpoint.mysql.pymysql import PyMySQLSaver, ShallowPyMySQLSaver from langgraph.store.base import BaseStore from langgraph.store.mysql.aio import AIOMySQLStore +from langgraph.store.mysql.asyncmy import AsyncMyStore from langgraph.store.mysql.pymysql import PyMySQLStore DEFAULT_MYSQL_URI = "mysql://mysql:mysql@localhost:5441/" @@ -110,6 +113,89 @@ def checkpointer_pymysql_pool(): cursor.execute(f"DROP DATABASE {database}") +@asynccontextmanager +async def _checkpointer_asyncmy(): + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + async with AsyncMySaver.from_conn_string( + DEFAULT_MYSQL_URI + database + ) as checkpointer: + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + +@asynccontextmanager +async def _checkpointer_asyncmy_shallow(): + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + async with ShallowAsyncMySaver.from_conn_string( + DEFAULT_MYSQL_URI + database + ) as checkpointer: + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + +@asynccontextmanager +async def _checkpointer_asyncmy_pool(): + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + async with asyncmy.create_pool( + **AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI + database), + maxsize=10, + autocommit=True, + ) as pool: + checkpointer = AsyncMySaver(pool) + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + @asynccontextmanager async def _checkpointer_aiomysql(): database = f"test_{uuid4().hex[:16]}" @@ -207,6 +293,15 @@ async def awith_checkpointer( elif 
checkpointer_name == "aiomysql_pool": async with _checkpointer_aiomysql_pool() as checkpointer: yield checkpointer + elif checkpointer_name == "asyncmy": + async with _checkpointer_asyncmy() as checkpointer: + yield checkpointer + elif checkpointer_name == "asyncmy_shallow": + async with _checkpointer_asyncmy_shallow() as checkpointer: + yield checkpointer + elif checkpointer_name == "asyncmy_pool": + async with _checkpointer_asyncmy_pool() as checkpointer: + yield checkpointer else: raise NotImplementedError(f"Unknown checkpointer: {checkpointer_name}") @@ -253,23 +348,50 @@ def store_pymysql_pool(): @asynccontextmanager -async def _store_aiomysql(): +async def _store_asyncmy(): database = f"test_{uuid4().hex[:16]}" - async with await aiomysql.connect( - **AIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), + async with await asyncmy.connect( + **AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True, ) as conn: async with conn.cursor() as cursor: await cursor.execute(f"CREATE DATABASE {database}") try: - async with AIOMySQLStore.from_conn_string( + async with AsyncMyStore.from_conn_string( DEFAULT_MYSQL_URI + database ) as store: await store.setup() yield store finally: - async with await aiomysql.connect( - **AIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), + async with await asyncmy.connect( + **AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + +@asynccontextmanager +async def _store_asyncmy_pool(): + database = f"test_{uuid4().hex[:16]}" + async with await asyncmy.connect( + **AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + async with asyncmy.create_pool( + **AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI + database), + maxsize=10, + autocommit=True, + ) as pool: + store = AsyncMyStore(pool) + await store.setup() + yield store + finally: + async with await asyncmy.connect( + **AsyncMyStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True ) as conn: async with conn.cursor() as cursor: @@ -277,23 +399,23 @@ async def _store_aiomysql(): @asynccontextmanager -async def _store_aiomysql_shallow(): +async def _store_aiomysql(): database = f"test_{uuid4().hex[:16]}" async with await aiomysql.connect( - **ShallowAIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), + **AIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True, ) as conn: async with conn.cursor() as cursor: await cursor.execute(f"CREATE DATABASE {database}") try: - async with ShallowAIOMySQLStore.from_conn_string( + async with AIOMySQLStore.from_conn_string( DEFAULT_MYSQL_URI + database ) as store: await store.setup() yield store finally: async with await aiomysql.connect( - **ShallowAIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), + **AIOMySQLStore.parse_conn_string(DEFAULT_MYSQL_URI), autocommit=True ) as conn: async with conn.cursor() as cursor: @@ -335,6 +457,12 @@ async def awith_store(store_name: Optional[str]) -> AsyncIterator[BaseStore]: elif store_name == "aiomysql_pool": async with _store_aiomysql_pool() as store: yield store + elif store_name == "asyncmy": + async with _store_asyncmy() as store: + yield store + elif store_name == "asyncmy_pool": + async with _store_asyncmy_pool() as store: + yield store else: raise NotImplementedError(f"Unknown store {store_name}") @@ -355,4 +483,4 @@ async def awith_store(store_name: 
Optional[str]) -> AsyncIterator[BaseStore]: *SHALLOW_CHECKPOINTERS_ASYNC, ] ALL_STORES_SYNC = ["pymysql", "pymysql_pool"] -ALL_STORES_ASYNC = ["aiomysql", "aiomysql_pool"] +ALL_STORES_ASYNC = ["aiomysql", "aiomysql_pool", "asyncmy", "asyncmy_pool"] diff --git a/langgraph/checkpoint/mysql/_ainternal.py b/langgraph/checkpoint/mysql/_ainternal.py index 0e8b011..18031ce 100644 --- a/langgraph/checkpoint/mysql/_ainternal.py +++ b/langgraph/checkpoint/mysql/_ainternal.py @@ -41,9 +41,7 @@ def __aiter__(self) -> AsyncIterator[dict[str, Any]]: ... R = TypeVar("R", bound=AsyncDictCursor) # cursor type -class AIOMySQLConnection(AsyncContextManager, Protocol): - """From aiomysql package.""" - +class AsyncConnection(AsyncContextManager, Protocol): async def begin(self) -> None: """Begin transaction.""" ... @@ -61,19 +59,17 @@ async def set_charset(self, charset: str) -> None: ... -C = TypeVar("C", bound=AIOMySQLConnection) # connection type -COut = TypeVar("COut", bound=AIOMySQLConnection, covariant=True) # connection type - +C = TypeVar("C", bound=AsyncConnection) # connection type +COut = TypeVar("COut", bound=AsyncConnection, covariant=True) # connection type -class AIOMySQLPool(Protocol, Generic[COut]): - """From aiomysql package.""" +class AsyncPool(Protocol, Generic[COut]): def acquire(self) -> COut: """Gets a connection from the connection pool.""" ... -Conn = Union[C, AIOMySQLPool[C]] +Conn = Union[C, AsyncPool[C]] @asynccontextmanager @@ -83,10 +79,7 @@ async def get_connection( if hasattr(conn, "cursor"): yield cast(C, conn) elif hasattr(conn, "acquire"): - async with cast(AIOMySQLPool[C], conn).acquire() as _conn: - # This seems necessary until https://github.com/PyMySQL/PyMySQL/pull/1119 - # is merged into aiomysql. - await _conn.set_charset("utf8mb4") + async with cast(AsyncPool[C], conn).acquire() as _conn: yield _conn else: raise TypeError(f"Invalid connection type: {type(conn)}") diff --git a/langgraph/checkpoint/mysql/aio.py b/langgraph/checkpoint/mysql/aio.py index 98c365c..020347b 100644 --- a/langgraph/checkpoint/mysql/aio.py +++ b/langgraph/checkpoint/mysql/aio.py @@ -1,49 +1,20 @@ -import asyncio -import json import urllib.parse -from collections.abc import AsyncIterator, Iterator, Sequence +from collections.abc import AsyncIterator from contextlib import asynccontextmanager from typing import Any, Optional, cast import aiomysql # type: ignore -from langchain_core.runnables import RunnableConfig from typing_extensions import Self, override -from langgraph.checkpoint.base import ( - WRITES_IDX_MAP, - ChannelVersions, - Checkpoint, - CheckpointMetadata, - CheckpointTuple, - get_checkpoint_id, -) from langgraph.checkpoint.mysql import _ainternal -from langgraph.checkpoint.mysql.base import BaseMySQLSaver +from langgraph.checkpoint.mysql.aio_base import BaseAsyncMySQLSaver from langgraph.checkpoint.mysql.shallow import BaseShallowAsyncMySQLSaver -from langgraph.checkpoint.mysql.utils import ( - deserialize_channel_values, - deserialize_pending_sends, - deserialize_pending_writes, -) from langgraph.checkpoint.serde.base import SerializerProtocol Conn = _ainternal.Conn[aiomysql.Connection] # For backward compatibility -class AIOMySQLSaver(BaseMySQLSaver): - lock: asyncio.Lock - - def __init__( - self, - conn: _ainternal.Conn, - serde: Optional[SerializerProtocol] = None, - ) -> None: - super().__init__(serde=serde) - - self.conn = conn - self.lock = asyncio.Lock() - self.loop = asyncio.get_running_loop() - +class AIOMySQLSaver(BaseAsyncMySQLSaver[aiomysql.Connection, 
aiomysql.DictCursor]): @staticmethod def parse_conn_string(conn_string: str) -> dict[str, Any]: parsed = urllib.parse.urlparse(conn_string) @@ -87,406 +58,10 @@ async def from_conn_string( ) as conn: yield cls(conn=conn, serde=serde) - async def setup(self) -> None: - """Set up the checkpoint database asynchronously. - - This method creates the necessary tables in the MySQL database if they don't - already exist and runs database migrations. It MUST be called directly by the user - the first time checkpointer is used. - """ - async with self._cursor() as cur: - await cur.execute(self.MIGRATIONS[0]) - await cur.execute( - "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" - ) - row = await cur.fetchone() - if row is None: - version = -1 - else: - version = row["v"] - for v, migration in zip( - range(version + 1, len(self.MIGRATIONS)), - self.MIGRATIONS[version + 1 :], - ): - await cur.execute(migration) - await cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") - - async def alist( - self, - config: Optional[RunnableConfig], - *, - filter: Optional[dict[str, Any]] = None, - before: Optional[RunnableConfig] = None, - limit: Optional[int] = None, - ) -> AsyncIterator[CheckpointTuple]: - """List checkpoints from the database asynchronously. - - This method retrieves a list of checkpoint tuples from the MySQL database based - on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). - - Args: - config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. - filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. - before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. - limit (Optional[int]): Maximum number of checkpoints to return. - - Yields: - AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples. - """ - where, args = self._search_where(config, filter, before) - query = self._select_sql(where) + " ORDER BY checkpoint_id DESC" - if limit: - query += f" LIMIT {limit}" - # if we change this to use .stream() we need to make sure to close the cursor - async with self._cursor() as cur: - await cur.execute(query, args) - async for value in cur: - yield CheckpointTuple( - { - "configurable": { - "thread_id": value["thread_id"], - "checkpoint_ns": value["checkpoint_ns"], - "checkpoint_id": value["checkpoint_id"], - } - }, - await asyncio.to_thread( - self._load_checkpoint, - json.loads(value["checkpoint"]), - deserialize_channel_values(value["channel_values"]), - deserialize_pending_sends(value["pending_sends"]), - ), - self._load_metadata(value["metadata"]), - ( - { - "configurable": { - "thread_id": value["thread_id"], - "checkpoint_ns": value["checkpoint_ns"], - "checkpoint_id": value["parent_checkpoint_id"], - } - } - if value["parent_checkpoint_id"] - else None - ), - await asyncio.to_thread( - self._load_writes, - deserialize_pending_writes(value["pending_writes"]), - ), - ) - - async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: - """Get a checkpoint tuple from the database asynchronously. - - This method retrieves a checkpoint tuple from the MySQL database based on the - provided config. If the config contains a "checkpoint_id" key, the checkpoint with - the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint - for the given thread ID is retrieved. 
- - Args: - config (RunnableConfig): The config to use for retrieving the checkpoint. - - Returns: - Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. - """ - thread_id = config["configurable"]["thread_id"] - checkpoint_id = get_checkpoint_id(config) - checkpoint_ns = config["configurable"].get("checkpoint_ns", "") - if checkpoint_id: - args = { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": checkpoint_id, - } - where = "WHERE thread_id = %(thread_id)s AND checkpoint_ns_hash = UNHEX(MD5(%(checkpoint_ns)s)) AND checkpoint_id = %(checkpoint_id)s" - else: - args = { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - } - where = "WHERE thread_id = %(thread_id)s AND checkpoint_ns_hash = UNHEX(MD5(%(checkpoint_ns)s))" - - query = self._select_sql(where) - if not checkpoint_id: - query += " ORDER BY checkpoint_id DESC LIMIT 1" - async with self._cursor() as cur: - await cur.execute( - query, - args, - ) - - async for value in cur: - return CheckpointTuple( - { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": value["checkpoint_id"], - } - }, - await asyncio.to_thread( - self._load_checkpoint, - json.loads(value["checkpoint"]), - deserialize_channel_values(value["channel_values"]), - deserialize_pending_sends(value["pending_sends"]), - ), - self._load_metadata(value["metadata"]), - ( - { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": value["parent_checkpoint_id"], - } - } - if value["parent_checkpoint_id"] - else None - ), - await asyncio.to_thread( - self._load_writes, - deserialize_pending_writes(value["pending_writes"]), - ), - ) - - async def aput( - self, - config: RunnableConfig, - checkpoint: Checkpoint, - metadata: CheckpointMetadata, - new_versions: ChannelVersions, - ) -> RunnableConfig: - """Save a checkpoint to the database asynchronously. - - This method saves a checkpoint to the MySQL database. The checkpoint is associated - with the provided config and its parent config (if any). - - Args: - config (RunnableConfig): The config to associate with the checkpoint. - checkpoint (Checkpoint): The checkpoint to save. - metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. - new_versions (ChannelVersions): New channel versions as of this write. - - Returns: - RunnableConfig: Updated configuration after storing the checkpoint. 
- """ - configurable = config["configurable"].copy() - thread_id = configurable.pop("thread_id") - checkpoint_ns = configurable.pop("checkpoint_ns") - checkpoint_id = configurable.pop( - "checkpoint_id", configurable.pop("thread_ts", None) - ) - - copy = checkpoint.copy() - next_config = { - "configurable": { - "thread_id": thread_id, - "checkpoint_ns": checkpoint_ns, - "checkpoint_id": checkpoint["id"], - } - } - - async with self._cursor(pipeline=True) as cur: - await cur.executemany( - self.UPSERT_CHECKPOINT_BLOBS_SQL, - await asyncio.to_thread( - self._dump_blobs, - thread_id, - checkpoint_ns, - copy.pop("channel_values"), # type: ignore[misc] - new_versions, - ), - ) - await cur.execute( - self.UPSERT_CHECKPOINTS_SQL, - ( - thread_id, - checkpoint_ns, - checkpoint_ns, - checkpoint["id"], - checkpoint_id, - json.dumps(self._dump_checkpoint(copy)), - self._dump_metadata(metadata), - ), - ) - return next_config - - async def aput_writes( - self, - config: RunnableConfig, - writes: Sequence[tuple[str, Any]], - task_id: str, - task_path: str = "", - ) -> None: - """Store intermediate writes linked to a checkpoint asynchronously. - - This method saves intermediate writes associated with a checkpoint to the database. - - Args: - config (RunnableConfig): Configuration of the related checkpoint. - writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. - task_id (str): Identifier for the task creating the writes. - """ - query = ( - self.UPSERT_CHECKPOINT_WRITES_SQL - if all(w[0] in WRITES_IDX_MAP for w in writes) - else self.INSERT_CHECKPOINT_WRITES_SQL - ) - params = await asyncio.to_thread( - self._dump_writes, - config["configurable"]["thread_id"], - config["configurable"]["checkpoint_ns"], - config["configurable"]["checkpoint_id"], - task_id, - task_path, - writes, - ) - async with self._cursor(pipeline=True) as cur: - await cur.executemany(query, params) - - @asynccontextmanager - async def _cursor( - self, *, pipeline: bool = False - ) -> AsyncIterator[aiomysql.DictCursor]: - """Create a database cursor as a context manager. - - Args: - pipeline (bool): whether to use transaction context manager and handle concurrency - """ - async with _ainternal.get_connection(self.conn) as conn: - if pipeline: - async with self.lock: - await conn.begin() - try: - async with conn.cursor(aiomysql.DictCursor) as cur: - yield cur - await conn.commit() - except: - await conn.rollback() - raise - else: - async with ( - self.lock, - conn.cursor(aiomysql.DictCursor) as cur, - ): - yield cur - - def list( - self, - config: Optional[RunnableConfig], - *, - filter: Optional[dict[str, Any]] = None, - before: Optional[RunnableConfig] = None, - limit: Optional[int] = None, - ) -> Iterator[CheckpointTuple]: - """List checkpoints from the database. - - This method retrieves a list of checkpoint tuples from the MySQL database based - on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). - - Args: - config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. - filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. - before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. - limit (Optional[int]): Maximum number of checkpoints to return. - - Yields: - Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples. 
- """ - try: - # check if we are in the main thread, only bg threads can block - # we don't check in other methods to avoid the overhead - if asyncio.get_running_loop() is self.loop: - raise asyncio.InvalidStateError( - "Synchronous calls to AsyncSqliteSaver are only allowed from a " - "different thread. From the main thread, use the async interface. " - "For example, use `checkpointer.alist(...)` or `await " - "graph.ainvoke(...)`." - ) - except RuntimeError: - pass - aiter_ = self.alist(config, filter=filter, before=before, limit=limit) - while True: - try: - yield asyncio.run_coroutine_threadsafe( - anext(aiter_), # noqa: F821 - self.loop, - ).result() - except StopAsyncIteration: - break - - def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: - """Get a checkpoint tuple from the database. - - This method retrieves a checkpoint tuple from the MySQL database based on the - provided config. If the config contains a "checkpoint_id" key, the checkpoint with - the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint - for the given thread ID is retrieved. - - Args: - config (RunnableConfig): The config to use for retrieving the checkpoint. - - Returns: - Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. - """ - try: - # check if we are in the main thread, only bg threads can block - # we don't check in other methods to avoid the overhead - if asyncio.get_running_loop() is self.loop: - raise asyncio.InvalidStateError( - "Synchronous calls to AIOMySQLSaver are only allowed from a " - "different thread. From the main thread, use the async interface. " - "For example, use `await checkpointer.aget_tuple(...)` or `await " - "graph.ainvoke(...)`." - ) - except RuntimeError: - pass - return asyncio.run_coroutine_threadsafe( - self.aget_tuple(config), self.loop - ).result() - - def put( - self, - config: RunnableConfig, - checkpoint: Checkpoint, - metadata: CheckpointMetadata, - new_versions: ChannelVersions, - ) -> RunnableConfig: - """Save a checkpoint to the database. - - This method saves a checkpoint to the MySQL database. The checkpoint is associated - with the provided config and its parent config (if any). - - Args: - config (RunnableConfig): The config to associate with the checkpoint. - checkpoint (Checkpoint): The checkpoint to save. - metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. - new_versions (ChannelVersions): New channel versions as of this write. - - Returns: - RunnableConfig: Updated configuration after storing the checkpoint. - """ - return asyncio.run_coroutine_threadsafe( - self.aput(config, checkpoint, metadata, new_versions), self.loop - ).result() - - def put_writes( - self, - config: RunnableConfig, - writes: Sequence[tuple[str, Any]], - task_id: str, - task_path: str = "", - ) -> None: - """Store intermediate writes linked to a checkpoint. - - This method saves intermediate writes associated with a checkpoint to the database. - - Args: - config (RunnableConfig): Configuration of the related checkpoint. - writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. - task_id (str): Identifier for the task creating the writes. - task_path (str): Path of the task creating the writes. 
- """ - return asyncio.run_coroutine_threadsafe( - self.aput_writes(config, writes, task_id, task_path), self.loop - ).result() + @override + @staticmethod + def _get_cursor_from_connection(conn: aiomysql.Connection) -> aiomysql.DictCursor: + return cast(aiomysql.DictCursor, conn.cursor(aiomysql.DictCursor)) class ShallowAIOMySQLSaver( diff --git a/langgraph/checkpoint/mysql/aio_base.py b/langgraph/checkpoint/mysql/aio_base.py new file mode 100644 index 0000000..3782b31 --- /dev/null +++ b/langgraph/checkpoint/mysql/aio_base.py @@ -0,0 +1,442 @@ +import asyncio +import json +from collections.abc import AsyncIterator, Iterator, Sequence +from contextlib import asynccontextmanager +from typing import Any, Generic, Optional + +from langchain_core.runnables import RunnableConfig + +from langgraph.checkpoint.base import ( + WRITES_IDX_MAP, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, + get_checkpoint_id, +) +from langgraph.checkpoint.mysql import _ainternal +from langgraph.checkpoint.mysql.base import BaseMySQLSaver +from langgraph.checkpoint.mysql.utils import ( + deserialize_channel_values, + deserialize_pending_sends, + deserialize_pending_writes, +) +from langgraph.checkpoint.serde.base import SerializerProtocol + + +class BaseAsyncMySQLSaver(BaseMySQLSaver, Generic[_ainternal.C, _ainternal.R]): + lock: asyncio.Lock + + def __init__( + self, + conn: _ainternal.Conn[_ainternal.C], + serde: Optional[SerializerProtocol] = None, + ) -> None: + super().__init__(serde=serde) + + self.conn = conn + self.lock = asyncio.Lock() + self.loop = asyncio.get_running_loop() + + @staticmethod + def _get_cursor_from_connection(conn: _ainternal.C) -> _ainternal.R: + raise NotImplementedError + + async def setup(self) -> None: + """Set up the checkpoint database asynchronously. + + This method creates the necessary tables in the MySQL database if they don't + already exist and runs database migrations. It MUST be called directly by the user + the first time checkpointer is used. + """ + async with self._cursor() as cur: + await cur.execute(self.MIGRATIONS[0]) + await cur.execute( + "SELECT v FROM checkpoint_migrations ORDER BY v DESC LIMIT 1" + ) + row = await cur.fetchone() + if row is None: + version = -1 + else: + version = row["v"] + for v, migration in zip( + range(version + 1, len(self.MIGRATIONS)), + self.MIGRATIONS[version + 1 :], + ): + await cur.execute(migration) + await cur.execute(f"INSERT INTO checkpoint_migrations (v) VALUES ({v})") + + async def alist( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + """List checkpoints from the database asynchronously. + + This method retrieves a list of checkpoint tuples from the MySQL database based + on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). + + Args: + config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. + filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. + before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. + limit (Optional[int]): Maximum number of checkpoints to return. + + Yields: + AsyncIterator[CheckpointTuple]: An asynchronous iterator of matching checkpoint tuples. 
+ """ + where, args = self._search_where(config, filter, before) + query = self._select_sql(where) + " ORDER BY checkpoint_id DESC" + if limit: + query += f" LIMIT {limit}" + # if we change this to use .stream() we need to make sure to close the cursor + async with self._cursor() as cur: + await cur.execute(query, args) + async for value in cur: + yield CheckpointTuple( + { + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": value["checkpoint_id"], + } + }, + await asyncio.to_thread( + self._load_checkpoint, + json.loads(value["checkpoint"]), + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), + ), + self._load_metadata(value["metadata"]), + ( + { + "configurable": { + "thread_id": value["thread_id"], + "checkpoint_ns": value["checkpoint_ns"], + "checkpoint_id": value["parent_checkpoint_id"], + } + } + if value["parent_checkpoint_id"] + else None + ), + await asyncio.to_thread( + self._load_writes, + deserialize_pending_writes(value["pending_writes"]), + ), + ) + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database asynchronously. + + This method retrieves a checkpoint tuple from the MySQL database based on the + provided config. If the config contains a "checkpoint_id" key, the checkpoint with + the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint + for the given thread ID is retrieved. + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_id = get_checkpoint_id(config) + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + if checkpoint_id: + args = { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + where = "WHERE thread_id = %(thread_id)s AND checkpoint_ns_hash = UNHEX(MD5(%(checkpoint_ns)s)) AND checkpoint_id = %(checkpoint_id)s" + else: + args = { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + } + where = "WHERE thread_id = %(thread_id)s AND checkpoint_ns_hash = UNHEX(MD5(%(checkpoint_ns)s))" + + query = self._select_sql(where) + if not checkpoint_id: + query += " ORDER BY checkpoint_id DESC LIMIT 1" + async with self._cursor() as cur: + await cur.execute( + query, + args, + ) + + async for value in cur: + return CheckpointTuple( + { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": value["checkpoint_id"], + } + }, + await asyncio.to_thread( + self._load_checkpoint, + json.loads(value["checkpoint"]), + deserialize_channel_values(value["channel_values"]), + deserialize_pending_sends(value["pending_sends"]), + ), + self._load_metadata(value["metadata"]), + ( + { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": value["parent_checkpoint_id"], + } + } + if value["parent_checkpoint_id"] + else None + ), + await asyncio.to_thread( + self._load_writes, + deserialize_pending_writes(value["pending_writes"]), + ), + ) + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database asynchronously. 
+ + This method saves a checkpoint to the MySQL database. The checkpoint is associated + with the provided config and its parent config (if any). + + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + configurable = config["configurable"].copy() + thread_id = configurable.pop("thread_id") + checkpoint_ns = configurable.pop("checkpoint_ns") + checkpoint_id = configurable.pop( + "checkpoint_id", configurable.pop("thread_ts", None) + ) + + copy = checkpoint.copy() + next_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint["id"], + } + } + + async with self._cursor(pipeline=True) as cur: + await cur.executemany( + self.UPSERT_CHECKPOINT_BLOBS_SQL, + await asyncio.to_thread( + self._dump_blobs, + thread_id, + checkpoint_ns, + copy.pop("channel_values"), # type: ignore[misc] + new_versions, + ), + ) + await cur.execute( + self.UPSERT_CHECKPOINTS_SQL, + ( + thread_id, + checkpoint_ns, + checkpoint_ns, + checkpoint["id"], + checkpoint_id, + json.dumps(self._dump_checkpoint(copy)), + self._dump_metadata(metadata), + ), + ) + return next_config + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Store intermediate writes linked to a checkpoint asynchronously. + + This method saves intermediate writes associated with a checkpoint to the database. + + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. + task_id (str): Identifier for the task creating the writes. + """ + query = ( + self.UPSERT_CHECKPOINT_WRITES_SQL + if all(w[0] in WRITES_IDX_MAP for w in writes) + else self.INSERT_CHECKPOINT_WRITES_SQL + ) + params = await asyncio.to_thread( + self._dump_writes, + config["configurable"]["thread_id"], + config["configurable"]["checkpoint_ns"], + config["configurable"]["checkpoint_id"], + task_id, + task_path, + writes, + ) + async with self._cursor(pipeline=True) as cur: + await cur.executemany(query, params) + + @asynccontextmanager + async def _cursor(self, *, pipeline: bool = False) -> AsyncIterator[_ainternal.R]: + """Create a database cursor as a context manager. + + Args: + pipeline (bool): whether to use transaction context manager and handle concurrency + """ + async with _ainternal.get_connection(self.conn) as conn: + if pipeline: + async with self.lock: + await conn.begin() + try: + async with self._get_cursor_from_connection(conn) as cur: + yield cur + await conn.commit() + except: + await conn.rollback() + raise + else: + async with ( + self.lock, + self._get_cursor_from_connection(conn) as cur, + ): + yield cur + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """List checkpoints from the database. + + This method retrieves a list of checkpoint tuples from the MySQL database based + on the provided config. The checkpoints are ordered by checkpoint ID in descending order (newest first). 
+ + Args: + config (Optional[RunnableConfig]): Base configuration for filtering checkpoints. + filter (Optional[Dict[str, Any]]): Additional filtering criteria for metadata. + before (Optional[RunnableConfig]): If provided, only checkpoints before the specified checkpoint ID are returned. Defaults to None. + limit (Optional[int]): Maximum number of checkpoints to return. + + Yields: + Iterator[CheckpointTuple]: An iterator of matching checkpoint tuples. + """ + try: + # check if we are in the main thread, only bg threads can block + # we don't check in other methods to avoid the overhead + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to AsyncSqliteSaver are only allowed from a " + "different thread. From the main thread, use the async interface. " + "For example, use `checkpointer.alist(...)` or `await " + "graph.ainvoke(...)`." + ) + except RuntimeError: + pass + aiter_ = self.alist(config, filter=filter, before=before, limit=limit) + while True: + try: + yield asyncio.run_coroutine_threadsafe( + anext(aiter_), # noqa: F821 + self.loop, + ).result() + except StopAsyncIteration: + break + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Get a checkpoint tuple from the database. + + This method retrieves a checkpoint tuple from the MySQL database based on the + provided config. If the config contains a "checkpoint_id" key, the checkpoint with + the matching thread ID and "checkpoint_id" is retrieved. Otherwise, the latest checkpoint + for the given thread ID is retrieved. + + Args: + config (RunnableConfig): The config to use for retrieving the checkpoint. + + Returns: + Optional[CheckpointTuple]: The retrieved checkpoint tuple, or None if no matching checkpoint was found. + """ + try: + # check if we are in the main thread, only bg threads can block + # we don't check in other methods to avoid the overhead + if asyncio.get_running_loop() is self.loop: + raise asyncio.InvalidStateError( + "Synchronous calls to AIOMySQLSaver are only allowed from a " + "different thread. From the main thread, use the async interface. " + "For example, use `await checkpointer.aget_tuple(...)` or `await " + "graph.ainvoke(...)`." + ) + except RuntimeError: + pass + return asyncio.run_coroutine_threadsafe( + self.aget_tuple(config), self.loop + ).result() + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Save a checkpoint to the database. + + This method saves a checkpoint to the MySQL database. The checkpoint is associated + with the provided config and its parent config (if any). + + Args: + config (RunnableConfig): The config to associate with the checkpoint. + checkpoint (Checkpoint): The checkpoint to save. + metadata (CheckpointMetadata): Additional metadata to save with the checkpoint. + new_versions (ChannelVersions): New channel versions as of this write. + + Returns: + RunnableConfig: Updated configuration after storing the checkpoint. + """ + return asyncio.run_coroutine_threadsafe( + self.aput(config, checkpoint, metadata, new_versions), self.loop + ).result() + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Store intermediate writes linked to a checkpoint. + + This method saves intermediate writes associated with a checkpoint to the database. 
+ + Args: + config (RunnableConfig): Configuration of the related checkpoint. + writes (Sequence[Tuple[str, Any]]): List of writes to store, each as (channel, value) pair. + task_id (str): Identifier for the task creating the writes. + task_path (str): Path of the task creating the writes. + """ + return asyncio.run_coroutine_threadsafe( + self.aput_writes(config, writes, task_id, task_path), self.loop + ).result() diff --git a/langgraph/checkpoint/mysql/asyncmy.py b/langgraph/checkpoint/mysql/asyncmy.py new file mode 100644 index 0000000..3760a3f --- /dev/null +++ b/langgraph/checkpoint/mysql/asyncmy.py @@ -0,0 +1,97 @@ +import urllib.parse +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any, Optional, cast + +from asyncmy import Connection, connect # type: ignore +from asyncmy.cursors import DictCursor # type: ignore +from typing_extensions import Self, override + +from langgraph.checkpoint.mysql.aio_base import BaseAsyncMySQLSaver +from langgraph.checkpoint.mysql.shallow import BaseShallowAsyncMySQLSaver +from langgraph.checkpoint.serde.base import SerializerProtocol + + +class AsyncMySaver(BaseAsyncMySQLSaver[Connection, DictCursor]): + @staticmethod + def parse_conn_string(conn_string: str) -> dict[str, Any]: + parsed = urllib.parse.urlparse(conn_string) + + # In order to provide additional params via the connection string, + # we convert the parsed.query to a dict so we can access the values. + # This is necessary when using a unix socket, for example. + params_as_dict = dict(urllib.parse.parse_qsl(parsed.query)) + + return { + "host": parsed.hostname or "localhost", + "user": parsed.username, + "password": parsed.password or "", + "db": parsed.path[1:] or None, + "port": parsed.port or 3306, + "unix_socket": params_as_dict.get("unix_socket"), + } + + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, + conn_string: str, + *, + serde: Optional[SerializerProtocol] = None, + ) -> AsyncIterator[Self]: + """Create a new AsyncMySaver instance from a connection string. + + Args: + conn_string (str): The MySQL connection info string. + + Returns: + AsyncMySaver: A new AsyncMySaver instance. + + Example: + conn_string=mysql+asyncmy://user:password@localhost/db?unix_socket=/path/to/socket + """ + async with connect( + **cls.parse_conn_string(conn_string), + autocommit=True, + ) as conn: + yield cls(conn=conn, serde=serde) + + @override + @staticmethod + def _get_cursor_from_connection(conn: Connection) -> DictCursor: + return cast(DictCursor, conn.cursor(DictCursor)) + + +class ShallowAsyncMySaver(BaseShallowAsyncMySQLSaver[Connection, DictCursor]): + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, + conn_string: str, + *, + serde: Optional[SerializerProtocol] = None, + ) -> AsyncIterator[Self]: + """Create a new ShallowAsyncMySaver instance from a connection string. + + Args: + conn_string (str): The MySQL connection info string. + + Returns: + ShallowAsyncMySaver: A new ShallowAsyncMySaver instance. 
+ + Example: + conn_string=mysql+asyncmy://user:password@localhost/db?unix_socket=/path/to/socket + """ + async with connect( + **AsyncMySaver.parse_conn_string(conn_string), + autocommit=True, + ) as conn: + yield cls(conn=conn, serde=serde) + + @override + @staticmethod + def _get_cursor_from_connection(conn: Connection) -> DictCursor: + return cast(DictCursor, conn.cursor(DictCursor)) + + +__all__ = ["AsyncMySaver", "ShallowAsyncMySaver"] diff --git a/langgraph/store/mysql/__init__.py b/langgraph/store/mysql/__init__.py index 7b84176..e541819 100644 --- a/langgraph/store/mysql/__init__.py +++ b/langgraph/store/mysql/__init__.py @@ -1,4 +1,5 @@ from langgraph.store.mysql.aio import AIOMySQLStore +from langgraph.store.mysql.asyncmy import AsyncMyStore from langgraph.store.mysql.pymysql import PyMySQLStore -__all__ = ["AIOMySQLStore", "PyMySQLStore"] +__all__ = ["AIOMySQLStore", "AsyncMyStore", "PyMySQLStore"] diff --git a/langgraph/store/mysql/aio.py b/langgraph/store/mysql/aio.py index 425a11b..d5b5114 100644 --- a/langgraph/store/mysql/aio.py +++ b/langgraph/store/mysql/aio.py @@ -1,61 +1,18 @@ -import asyncio import logging import urllib.parse -from collections.abc import AsyncIterator, Iterable, Sequence +from collections.abc import AsyncIterator from contextlib import asynccontextmanager -from typing import Any, Callable, Optional, Union, cast +from typing import Any, cast import aiomysql # type: ignore -import orjson +from typing_extensions import Self, override -from langgraph.checkpoint.mysql import _ainternal -from langgraph.store.base import ( - GetOp, - ListNamespacesOp, - Op, - PutOp, - Result, - SearchOp, -) -from langgraph.store.base.batch import AsyncBatchedBaseStore -from langgraph.store.mysql.base import ( - BaseMySQLStore, - Row, - _decode_ns_bytes, - _group_ops, - _row_to_item, - _row_to_search_item, -) +from langgraph.store.mysql.aio_base import BaseAsyncMySQLStore logger = logging.getLogger(__name__) -class AIOMySQLStore(AsyncBatchedBaseStore, BaseMySQLStore[_ainternal.Conn]): - __slots__ = ("_deserializer", "lock") - - def __init__( - self, - conn: _ainternal.Conn, - *, - deserializer: Optional[ - Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]] - ] = None, - ) -> None: - super().__init__() - self._deserializer = deserializer - self.conn = conn - self.lock = asyncio.Lock() - self.loop = asyncio.get_running_loop() - - async def abatch(self, ops: Iterable[Op]) -> list[Result]: - grouped_ops, num_ops = _group_ops(ops) - results: list[Result] = [None] * num_ops - - async with _ainternal.get_connection(self.conn) as conn: - await self._execute_batch(grouped_ops, results, conn) - - return results - +class AIOMySQLStore(BaseAsyncMySQLStore[aiomysql.Connection, aiomysql.DictCursor]): @staticmethod def parse_conn_string(conn_string: str) -> dict[str, Any]: parsed = urllib.parse.urlparse(conn_string) @@ -79,7 +36,7 @@ def parse_conn_string(conn_string: str) -> dict[str, Any]: async def from_conn_string( cls, conn_string: str, - ) -> AsyncIterator["AIOMySQLStore"]: + ) -> AsyncIterator[Self]: """Create a new AIOMySQLStore instance from a connection string. Args: @@ -94,158 +51,7 @@ async def from_conn_string( ) as conn: yield cls(conn=conn) - async def setup(self) -> None: - """Set up the store database asynchronously. - - This method creates the necessary tables in the Postgres database if they don't - already exist and runs database migrations. It MUST be called directly by the user - the first time the store is used. 
- """ - - async def _get_version(cur: aiomysql.DictCursor, table: str) -> int: - await cur.execute( - f""" - CREATE TABLE IF NOT EXISTS {table} ( - v INTEGER PRIMARY KEY - ) - """ - ) - await cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1") - row = await cur.fetchone() - if row is None: - version = -1 - else: - version = row["v"] - return version - - async with _ainternal.get_connection(self.conn) as conn: - async with self._cursor(conn) as cur: - version = await _get_version(cur, table="store_migrations") - for v, sql in enumerate( - self.MIGRATIONS[version + 1 :], start=version + 1 - ): - await cur.execute(sql) - await cur.execute( - "INSERT INTO store_migrations (v) VALUES (%s)", (v,) - ) - - async def _execute_batch( - self, - grouped_ops: dict, - results: list[Result], - conn: aiomysql.Connection, - ) -> None: - async with self._cursor(conn, pipeline=True) as cur: - if GetOp in grouped_ops: - await self._batch_get_ops( - cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), - results, - cur, - ) - - if SearchOp in grouped_ops: - await self._batch_search_ops( - cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]), - results, - cur, - ) - - if ListNamespacesOp in grouped_ops: - await self._batch_list_namespaces_ops( - cast( - Sequence[tuple[int, ListNamespacesOp]], - grouped_ops[ListNamespacesOp], - ), - results, - cur, - ) - - if PutOp in grouped_ops: - await self._batch_put_ops( - cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]), - cur, - ) - - async def _batch_get_ops( - self, - get_ops: Sequence[tuple[int, GetOp]], - results: list[Result], - cur: aiomysql.DictCursor, - ) -> None: - for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): - await cur.execute(query, params) - rows = cast(list[Row], await cur.fetchall()) - key_to_row = {row["key"]: row for row in rows} - for idx, key in items: - row = key_to_row.get(key) - if row: - results[idx] = _row_to_item( - namespace, row, loader=self._deserializer - ) - else: - results[idx] = None - - async def _batch_put_ops( - self, - put_ops: Sequence[tuple[int, PutOp]], - cur: aiomysql.DictCursor, - ) -> None: - queries = self._prepare_batch_PUT_queries(put_ops) - for query, params in queries: - await cur.execute(query, params) - - async def _batch_search_ops( - self, - search_ops: Sequence[tuple[int, SearchOp]], - results: list[Result], - cur: aiomysql.DictCursor, - ) -> None: - queries = self._prepare_batch_search_queries(search_ops) - for (idx, _), (query, params) in zip(search_ops, queries): - await cur.execute(query, params) - rows = cast(list[Row], await cur.fetchall()) - items = [ - _row_to_search_item( - _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer - ) - for row in rows - ] - results[idx] = items - - async def _batch_list_namespaces_ops( - self, - list_ops: Sequence[tuple[int, ListNamespacesOp]], - results: list[Result], - cur: aiomysql.DictCursor, - ) -> None: - queries = self._get_batch_list_namespaces_queries(list_ops) - for (query, params), (idx, _) in zip(queries, list_ops): - await cur.execute(query, params) - rows = cast(list[dict], await cur.fetchall()) - namespaces = [_decode_ns_bytes(row["truncated_prefix"]) for row in rows] - results[idx] = namespaces - - @asynccontextmanager - async def _cursor( - self, conn: aiomysql.Connection, *, pipeline: bool = False - ) -> AsyncIterator[aiomysql.DictCursor]: - """Create a database cursor as a context manager. 
- Args: - conn: The database connection to use - pipeline: whether to use transaction context manager and handle concurrency - """ - if pipeline: - # a connection can only be used by one - # thread/coroutine at a time, so we acquire a lock - async with self.lock: - await conn.begin() - try: - async with conn.cursor(aiomysql.DictCursor) as cur: - yield cur - await conn.commit() - except: - await conn.rollback() - raise - else: - async with self.lock, conn.cursor(aiomysql.DictCursor) as cur: - yield cur + @override + @staticmethod + def _get_cursor_from_connection(conn: aiomysql.Connection) -> aiomysql.DictCursor: + return cast(aiomysql.DictCursor, conn.cursor(aiomysql.DictCursor)) diff --git a/langgraph/store/mysql/aio_base.py b/langgraph/store/mysql/aio_base.py new file mode 100644 index 0000000..8078d93 --- /dev/null +++ b/langgraph/store/mysql/aio_base.py @@ -0,0 +1,219 @@ +import asyncio +import logging +from collections.abc import AsyncIterator, Iterable, Sequence +from contextlib import asynccontextmanager +from typing import Any, Callable, Generic, Optional, Union, cast + +import orjson + +from langgraph.checkpoint.mysql import _ainternal +from langgraph.store.base import ( + GetOp, + ListNamespacesOp, + Op, + PutOp, + Result, + SearchOp, +) +from langgraph.store.base.batch import AsyncBatchedBaseStore +from langgraph.store.mysql.base import ( + BaseMySQLStore, + Row, + _decode_ns_bytes, + _group_ops, + _row_to_item, + _row_to_search_item, +) + +logger = logging.getLogger(__name__) + + +class BaseAsyncMySQLStore( + AsyncBatchedBaseStore, + BaseMySQLStore[_ainternal.Conn[_ainternal.C]], + Generic[_ainternal.C, _ainternal.R], +): + __slots__ = ("_deserializer", "lock") + + def __init__( + self, + conn: _ainternal.Conn[_ainternal.C], + *, + deserializer: Optional[ + Callable[[Union[bytes, orjson.Fragment]], dict[str, Any]] + ] = None, + ) -> None: + super().__init__() + self._deserializer = deserializer + self.conn = conn + self.lock = asyncio.Lock() + self.loop = asyncio.get_running_loop() + + @staticmethod + def _get_cursor_from_connection(conn: _ainternal.C) -> _ainternal.R: + raise NotImplementedError + + async def abatch(self, ops: Iterable[Op]) -> list[Result]: + grouped_ops, num_ops = _group_ops(ops) + results: list[Result] = [None] * num_ops + + async with _ainternal.get_connection(self.conn) as conn: + await self._execute_batch(grouped_ops, results, conn) + + return results + + async def setup(self) -> None: + """Set up the store database asynchronously. + + This method creates the necessary tables in the Postgres database if they don't + already exist and runs database migrations. It MUST be called directly by the user + the first time the store is used. 
+ """ + + async def _get_version(cur: _ainternal.R, table: str) -> int: + await cur.execute( + f""" + CREATE TABLE IF NOT EXISTS {table} ( + v INTEGER PRIMARY KEY + ) + """ + ) + await cur.execute(f"SELECT v FROM {table} ORDER BY v DESC LIMIT 1") + row = await cur.fetchone() + if row is None: + version = -1 + else: + version = row["v"] + return version + + async with _ainternal.get_connection(self.conn) as conn: + async with self._cursor(conn) as cur: + version = await _get_version(cur, table="store_migrations") + for v, sql in enumerate( + self.MIGRATIONS[version + 1 :], start=version + 1 + ): + await cur.execute(sql) + await cur.execute( + "INSERT INTO store_migrations (v) VALUES (%s)", (v,) + ) + + async def _execute_batch( + self, + grouped_ops: dict, + results: list[Result], + conn: _ainternal.C, + ) -> None: + async with self._cursor(conn, pipeline=True) as cur: + if GetOp in grouped_ops: + await self._batch_get_ops( + cast(Sequence[tuple[int, GetOp]], grouped_ops[GetOp]), + results, + cur, + ) + + if SearchOp in grouped_ops: + await self._batch_search_ops( + cast(Sequence[tuple[int, SearchOp]], grouped_ops[SearchOp]), + results, + cur, + ) + + if ListNamespacesOp in grouped_ops: + await self._batch_list_namespaces_ops( + cast( + Sequence[tuple[int, ListNamespacesOp]], + grouped_ops[ListNamespacesOp], + ), + results, + cur, + ) + + if PutOp in grouped_ops: + await self._batch_put_ops( + cast(Sequence[tuple[int, PutOp]], grouped_ops[PutOp]), + cur, + ) + + async def _batch_get_ops( + self, + get_ops: Sequence[tuple[int, GetOp]], + results: list[Result], + cur: _ainternal.R, + ) -> None: + for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): + await cur.execute(query, params) + rows = cast(list[Row], await cur.fetchall()) + key_to_row = {row["key"]: row for row in rows} + for idx, key in items: + row = key_to_row.get(key) + if row: + results[idx] = _row_to_item( + namespace, row, loader=self._deserializer + ) + else: + results[idx] = None + + async def _batch_put_ops( + self, + put_ops: Sequence[tuple[int, PutOp]], + cur: _ainternal.R, + ) -> None: + queries = self._prepare_batch_PUT_queries(put_ops) + for query, params in queries: + await cur.execute(query, params) + + async def _batch_search_ops( + self, + search_ops: Sequence[tuple[int, SearchOp]], + results: list[Result], + cur: _ainternal.R, + ) -> None: + queries = self._prepare_batch_search_queries(search_ops) + for (idx, _), (query, params) in zip(search_ops, queries): + await cur.execute(query, params) + rows = cast(list[Row], await cur.fetchall()) + items = [ + _row_to_search_item( + _decode_ns_bytes(row["prefix"]), row, loader=self._deserializer + ) + for row in rows + ] + results[idx] = items + + async def _batch_list_namespaces_ops( + self, + list_ops: Sequence[tuple[int, ListNamespacesOp]], + results: list[Result], + cur: _ainternal.R, + ) -> None: + queries = self._get_batch_list_namespaces_queries(list_ops) + for (query, params), (idx, _) in zip(queries, list_ops): + await cur.execute(query, params) + rows = cast(list[dict], await cur.fetchall()) + namespaces = [_decode_ns_bytes(row["truncated_prefix"]) for row in rows] + results[idx] = namespaces + + @asynccontextmanager + async def _cursor( + self, conn: _ainternal.C, *, pipeline: bool = False + ) -> AsyncIterator[_ainternal.R]: + """Create a database cursor as a context manager. 
+ Args: + conn: The database connection to use + pipeline: whether to use transaction context manager and handle concurrency + """ + if pipeline: + # a connection can only be used by one + # thread/coroutine at a time, so we acquire a lock + async with self.lock: + await conn.begin() + try: + async with self._get_cursor_from_connection(conn) as cur: + yield cur + await conn.commit() + except: + await conn.rollback() + raise + else: + async with self.lock, self._get_cursor_from_connection(conn) as cur: + yield cur diff --git a/langgraph/store/mysql/asyncmy.py b/langgraph/store/mysql/asyncmy.py new file mode 100644 index 0000000..4415f8f --- /dev/null +++ b/langgraph/store/mysql/asyncmy.py @@ -0,0 +1,58 @@ +import logging +import urllib.parse +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any, cast + +from asyncmy import Connection, connect # type: ignore +from asyncmy.cursors import DictCursor # type: ignore +from typing_extensions import Self, override + +from langgraph.store.mysql.aio_base import BaseAsyncMySQLStore + +logger = logging.getLogger(__name__) + + +class AsyncMyStore(BaseAsyncMySQLStore[Connection, DictCursor]): + @staticmethod + def parse_conn_string(conn_string: str) -> dict[str, Any]: + parsed = urllib.parse.urlparse(conn_string) + + # In order to provide additional params via the connection string, + # we convert the parsed.query to a dict so we can access the values. + # This is necessary when using a unix socket, for example. + params_as_dict = dict(urllib.parse.parse_qsl(parsed.query)) + + return { + "host": parsed.hostname or "localhost", + "user": parsed.username, + "password": parsed.password or "", + "db": parsed.path[1:] or None, + "port": parsed.port or 3306, + "unix_socket": params_as_dict.get("unix_socket"), + } + + @classmethod + @asynccontextmanager + async def from_conn_string( + cls, + conn_string: str, + ) -> AsyncIterator[Self]: + """Create a new AsyncMyStore instance from a connection string. + + Args: + conn_string (str): The MySQL connection info string. + + Returns: + AsyncMyStore: A new AsyncMyStore instance. + """ + async with connect( + **cls.parse_conn_string(conn_string), + autocommit=True, + ) as conn: + yield cls(conn=conn) + + @override + @staticmethod + def _get_cursor_from_connection(conn: Connection) -> DictCursor: + return cast(DictCursor, conn.cursor(DictCursor)) diff --git a/langgraph/store/mysql/base.py b/langgraph/store/mysql/base.py index 99996bf..7efb462 100644 --- a/langgraph/store/mysql/base.py +++ b/langgraph/store/mysql/base.py @@ -9,11 +9,8 @@ from typing import ( Any, Callable, - ContextManager, Generic, - Mapping, Optional, - Protocol, TypeVar, Union, cast, @@ -59,28 +56,7 @@ ] -class DictCursor(ContextManager, Protocol): - """ - Protocol that a cursor should implement. - - Modeled after DBAPICursor from Typeshed. - """ - - def execute( - self, - operation: str, - parameters: Union[Sequence[Any], Mapping[str, Any]] = ..., - /, - ) -> object: ... - def executemany( - self, operation: str, seq_of_parameters: Sequence[Sequence[Any]], / - ) -> object: ... - def fetchone(self) -> Optional[dict[str, Any]]: ... - def fetchall(self) -> Sequence[dict[str, Any]]: ... 
- - C = TypeVar("C", bound=Union[_internal.Conn, _ainternal.Conn]) # connection type -R = TypeVar("R", bound=DictCursor) # cursor type class BaseMySQLStore(Generic[C]): @@ -303,7 +279,9 @@ def _get_filter_condition(self, key: str, op: str, value: Any) -> tuple[str, lis class BaseSyncMySQLStore( - BaseStore, BaseMySQLStore[_internal.Conn[_internal.C]], Generic[_internal.C, R] + BaseStore, + BaseMySQLStore[_internal.Conn[_internal.C]], + Generic[_internal.C, _internal.R], ): __slots__ = ("_deserializer", "lock") @@ -321,11 +299,11 @@ def __init__( self.lock = threading.Lock() @staticmethod - def _get_cursor_from_connection(conn: _internal.C) -> R: + def _get_cursor_from_connection(conn: _internal.C) -> _internal.R: raise NotImplementedError @contextmanager - def _cursor(self, *, pipeline: bool = False) -> Iterator[R]: + def _cursor(self, *, pipeline: bool = False) -> Iterator[_internal.R]: """Create a database cursor as a context manager. Args: @@ -385,7 +363,7 @@ def _batch_get_ops( self, get_ops: Sequence[tuple[int, GetOp]], results: list[Result], - cur: R, + cur: _internal.R, ) -> None: for query, params, namespace, items in self._get_batch_GET_ops_queries(get_ops): cur.execute(query, params) @@ -403,7 +381,7 @@ def _batch_get_ops( def _batch_put_ops( self, put_ops: Sequence[tuple[int, PutOp]], - cur: R, + cur: _internal.R, ) -> None: queries = self._prepare_batch_PUT_queries(put_ops) for query, params in queries: @@ -413,7 +391,7 @@ def _batch_search_ops( self, search_ops: Sequence[tuple[int, SearchOp]], results: list[Result], - cur: R, + cur: _internal.R, ) -> None: for (query, params), (idx, _) in zip( self._prepare_batch_search_queries(search_ops), search_ops @@ -431,7 +409,7 @@ def _batch_list_namespaces_ops( self, list_ops: Sequence[tuple[int, ListNamespacesOp]], results: list[Result], - cur: R, + cur: _internal.R, ) -> None: for (query, params), (idx, _) in zip( self._get_batch_list_namespaces_queries(list_ops), list_ops @@ -451,7 +429,7 @@ def setup(self) -> None: the first time the store is used. 
""" - def _get_version(cur: R, table: str) -> int: + def _get_version(cur: _internal.R, table: str) -> int: cur.execute( f""" CREATE TABLE IF NOT EXISTS {table} ( diff --git a/poetry.lock b/poetry.lock index 9bf65bd..f2f9e14 100644 --- a/poetry.lock +++ b/poetry.lock @@ -54,6 +54,72 @@ doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] trio = ["trio (>=0.26.1)"] +[[package]] +name = "asyncmy" +version = "0.2.10" +description = "A fast asyncio MySQL driver" +optional = false +python-versions = "<4.0,>=3.8" +groups = ["main", "dev"] +files = [ + {file = "asyncmy-0.2.10-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:c2237c8756b8f374099bd320c53b16f7ec0cee8258f00d72eed5a2cd3d251066"}, + {file = "asyncmy-0.2.10-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:6e98d4fbf7ea0d99dfecb24968c9c350b019397ba1af9f181d51bb0f6f81919b"}, + {file = "asyncmy-0.2.10-cp310-cp310-manylinux_2_17_i686.manylinux_2_5_i686.manylinux1_i686.manylinux2014_i686.whl", hash = "sha256:b1b1ee03556c7eda6422afc3aca132982a84706f8abf30f880d642f50670c7ed"}, + {file = "asyncmy-0.2.10-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3e2b97672ea3f0b335c0ffd3da1a5727b530f82f5032cd87e86c3aa3ac6df7f3"}, + {file = "asyncmy-0.2.10-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:c6471ce1f9ae1e6f0d55adfb57c49d0bcf5753a253cccbd33799ddb402fe7da2"}, + {file = "asyncmy-0.2.10-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:10e2a10fe44a2b216a1ae58fbdafa3fed661a625ec3c030c560c26f6ab618522"}, + {file = "asyncmy-0.2.10-cp310-cp310-win32.whl", hash = "sha256:a791ab117787eb075bc37ed02caa7f3e30cca10f1b09ec7eeb51d733df1d49fc"}, + {file = "asyncmy-0.2.10-cp310-cp310-win_amd64.whl", hash = "sha256:bd16fdc0964a4a1a19aec9797ca631c3ff2530013fdcd27225fc2e48af592804"}, + {file = "asyncmy-0.2.10-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:7af0f1f31f800a8789620c195e92f36cce4def68ee70d625534544d43044ed2a"}, + {file = "asyncmy-0.2.10-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:800116ab85dc53b24f484fb644fefffac56db7367a31e7d62f4097d495105a2c"}, + {file = "asyncmy-0.2.10-cp311-cp311-manylinux_2_17_i686.manylinux_2_5_i686.manylinux1_i686.manylinux2014_i686.whl", hash = "sha256:39525e9d7e557b83db268ed14b149a13530e0d09a536943dba561a8a1c94cc07"}, + {file = "asyncmy-0.2.10-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76e199d6b57918999efc702d2dbb182cb7ba8c604cdfc912517955219b16eaea"}, + {file = "asyncmy-0.2.10-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:9ca8fdd7dbbf2d9b4c2d3a5fac42b058707d6a483b71fded29051b8ae198a250"}, + {file = "asyncmy-0.2.10-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0df23db54e38602c803dacf1bbc1dcc4237a87223e659681f00d1a319a4f3826"}, + {file = "asyncmy-0.2.10-cp311-cp311-win32.whl", hash = "sha256:a16633032be020b931acfd7cd1862c7dad42a96ea0b9b28786f2ec48e0a86757"}, + {file = "asyncmy-0.2.10-cp311-cp311-win_amd64.whl", hash = "sha256:cca06212575922216b89218abd86a75f8f7375fc9c28159ea469f860785cdbc7"}, + {file = "asyncmy-0.2.10-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:42295530c5f36784031f7fa42235ef8dd93a75d9b66904de087e68ff704b4f03"}, + {file = "asyncmy-0.2.10-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:641a853ffcec762905cbeceeb623839c9149b854d5c3716eb9a22c2b505802af"}, + {file = 
"asyncmy-0.2.10-cp312-cp312-manylinux_2_17_i686.manylinux_2_5_i686.manylinux1_i686.manylinux2014_i686.whl", hash = "sha256:c554874223dd36b1cfc15e2cd0090792ea3832798e8fe9e9d167557e9cf31b4d"}, + {file = "asyncmy-0.2.10-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dd16e84391dde8edb40c57d7db634706cbbafb75e6a01dc8b68a63f8dd9e44ca"}, + {file = "asyncmy-0.2.10-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:9f6b44c4bf4bb69a2a1d9d26dee302473099105ba95283b479458c448943ed3c"}, + {file = "asyncmy-0.2.10-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:16d398b1aad0550c6fe1655b6758455e3554125af8aaf1f5abdc1546078c7257"}, + {file = "asyncmy-0.2.10-cp312-cp312-win32.whl", hash = "sha256:59d2639dcc23939ae82b93b40a683c15a091460a3f77fa6aef1854c0a0af99cc"}, + {file = "asyncmy-0.2.10-cp312-cp312-win_amd64.whl", hash = "sha256:4c6674073be97ffb7ac7f909e803008b23e50281131fef4e30b7b2162141a574"}, + {file = "asyncmy-0.2.10-cp38-cp38-macosx_13_0_x86_64.whl", hash = "sha256:85bc4522d8b632cd3327001a00cb24416883fc3905857737b99aa00bc0703fe1"}, + {file = "asyncmy-0.2.10-cp38-cp38-macosx_14_0_x86_64.whl", hash = "sha256:c93768dde803c7c118e6ac1893f98252e48fecad7c20bb7e27d4bdf3d130a044"}, + {file = "asyncmy-0.2.10-cp38-cp38-manylinux_2_17_i686.manylinux_2_5_i686.manylinux1_i686.manylinux2014_i686.whl", hash = "sha256:93b6d7db19a093abdeceb454826ff752ce1917288635d5d63519068ef5b2f446"}, + {file = "asyncmy-0.2.10-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:acecd4bbb513a67a94097fd499dac854546e07d2ff63c7fb5f4d2c077e4bdf91"}, + {file = "asyncmy-0.2.10-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1b4b346c02fca1d160005d4921753bb00ed03422f0c6ec90936c43aad96b7d52"}, + {file = "asyncmy-0.2.10-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8d393570e1c96ca200075797cc4f80849fc0ea960a45c6035855b1d392f33768"}, + {file = "asyncmy-0.2.10-cp38-cp38-win32.whl", hash = "sha256:c8ee5282af5f38b4dc3ae94a3485688bd6c0d3509ba37226dbaa187f1708e32c"}, + {file = "asyncmy-0.2.10-cp38-cp38-win_amd64.whl", hash = "sha256:10b3dfb119d7a9cb3aaae355c0981e60934f57297ea560bfdb280c5d85f77a9d"}, + {file = "asyncmy-0.2.10-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:244289bd1bea84384866bde50b09fe5b24856640e30a04073eacb71987b7b6ad"}, + {file = "asyncmy-0.2.10-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:6c9d024b160b9f869a21e62c4ef34a7b7a4b5a886ae03019d4182621ea804d2c"}, + {file = "asyncmy-0.2.10-cp39-cp39-manylinux_2_17_i686.manylinux_2_5_i686.manylinux1_i686.manylinux2014_i686.whl", hash = "sha256:b57594eea942224626203503f24fa88a47eaab3f13c9f24435091ea910f4b966"}, + {file = "asyncmy-0.2.10-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:346192941470ac2d315f97afa14c0131ff846c911da14861baf8a1f8ed541664"}, + {file = "asyncmy-0.2.10-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:957c2b48c5228e5f91fdf389daf38261a7b8989ad0eb0d1ba4e5680ef2a4a078"}, + {file = "asyncmy-0.2.10-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:472989d7bfa405c108a7f3c408bbed52306504fb3aa28963d833cb7eeaafece0"}, + {file = "asyncmy-0.2.10-cp39-cp39-win32.whl", hash = "sha256:714b0fdadd72031e972de2bbbd14e35a19d5a7e001594f0c8a69f92f0d05acc9"}, + {file = "asyncmy-0.2.10-cp39-cp39-win_amd64.whl", hash = "sha256:9fb58645d3da0b91db384f8519b16edc7dc421c966ada8647756318915d63696"}, + {file = "asyncmy-0.2.10-pp310-pypy310_pp73-macosx_13_0_x86_64.whl", hash = "sha256:f10c977c60a95bd6ec6b8654e20c8f53bad566911562a7ad7117ca94618f05d3"}, + {file = 
"asyncmy-0.2.10-pp310-pypy310_pp73-macosx_14_0_arm64.whl", hash = "sha256:aab07fbdb9466beaffef136ffabe388f0d295d8d2adb8f62c272f1d4076515b9"}, + {file = "asyncmy-0.2.10-pp310-pypy310_pp73-manylinux_2_17_i686.manylinux_2_5_i686.manylinux1_i686.manylinux2014_i686.whl", hash = "sha256:63144322ade68262201baae73ad0c8a06b98a3c6ae39d1f3f21c41cc5287066a"}, + {file = "asyncmy-0.2.10-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux_2_5_x86_64.manylinux1_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9659d95c6f2a611aec15bdd928950df937bf68bc4bbb68b809ee8924b6756067"}, + {file = "asyncmy-0.2.10-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:8ced4bd938e95ede0fb9fa54755773df47bdb9f29f142512501e613dd95cf4a4"}, + {file = "asyncmy-0.2.10-pp38-pypy38_pp73-macosx_13_0_x86_64.whl", hash = "sha256:f76080d5d360635f0c67411fb3fb890d7a5a9e31135b4bb07c6a4e588287b671"}, + {file = "asyncmy-0.2.10-pp38-pypy38_pp73-macosx_14_0_arm64.whl", hash = "sha256:fde04da1a3e656ec7d7656b2d02ade87df9baf88cc1ebeff5d2288f856c086a4"}, + {file = "asyncmy-0.2.10-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux_2_5_i686.manylinux1_i686.manylinux2014_i686.whl", hash = "sha256:a83383cc6951bcde11c9cdda216a0849d29be2002a8fb6405ea6d9e5ced4ec69"}, + {file = "asyncmy-0.2.10-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux_2_5_x86_64.manylinux1_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58c3d8c12030c23df93929c8371da818211fa02c7b50cd178960c0a88e538adf"}, + {file = "asyncmy-0.2.10-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e0c8706ff7fc003775f3fc63804ea45be61e9ac9df9fd968977f781189d625ed"}, + {file = "asyncmy-0.2.10-pp39-pypy39_pp73-macosx_13_0_x86_64.whl", hash = "sha256:4651caaee6f4d7a8eb478a0dc460f8e91ab09a2d8d32444bc2b235544c791947"}, + {file = "asyncmy-0.2.10-pp39-pypy39_pp73-macosx_14_0_arm64.whl", hash = "sha256:ac091b327f01c38d91c697c810ba49e5f836890d48f6879ba0738040bb244290"}, + {file = "asyncmy-0.2.10-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux_2_5_i686.manylinux1_i686.manylinux2014_i686.whl", hash = "sha256:e1d2d9387cd3971297486c21098e035c620149c9033369491f58fe4fc08825b6"}, + {file = "asyncmy-0.2.10-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux_2_5_x86_64.manylinux1_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a760cb486ddb2c936711325236e6b9213564a9bb5deb2f6949dbd16c8e4d739e"}, + {file = "asyncmy-0.2.10-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:1586f26633c05b16bcfc46d86e9875f4941280e12afa79a741cdf77ae4ccfb4d"}, + {file = "asyncmy-0.2.10.tar.gz", hash = "sha256:f4b67edadf7caa56bdaf1c2e6cf451150c0a86f5353744deabe4426fe27aff4e"}, +] + [[package]] name = "certifi" version = "2024.8.30" @@ -1495,9 +1561,10 @@ watchmedo = ["PyYAML (>=3.10)"] [extras] aiomysql = ["aiomysql"] +asyncmy = ["asyncmy"] pymysql = ["pymysql"] [metadata] lock-version = "2.1" python-versions = "^3.9.0,<4.0" -content-hash = "e4d951f31876080d29f3175637c74c659639dbf0aee64444130a2e792109543e" +content-hash = "7e07a752e3cc602f0b370a270c262d4bd44ded23dfe8b30c474904c1d5e912c3" diff --git a/pyproject.toml b/pyproject.toml index 5f6c43c..b276ecc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "langgraph-checkpoint-mysql" -version = "2.0.12" +version = "2.0.13" description = "Library with a MySQL implementation of LangGraph checkpoint saver." 
authors = ["Theodore Ni "] license = "MIT" @@ -14,11 +14,13 @@ langgraph-checkpoint = "^2.0.10" orjson = ">=3.10.1" pymysql = { version = "^1.1.1", optional = true } aiomysql = { version = "^0.2.0", optional = true } +asyncmy = { version = "^0.2.10", optional = true } typing-extensions = "^4.12.2" [tool.poetry.extras] pymysql = ["pymysql"] aiomysql = ["aiomysql"] +asyncmy = ["asyncmy"] [tool.poetry.group.dev.dependencies] ruff = "^0.6.2" @@ -31,6 +33,7 @@ pytest-watch = "^4.2.0" mypy = "^1.10.0" pymysql = "^1.1.1" aiomysql = "^0.2.0" +asyncmy = "^0.2.10" types-PyMySQL = "^1.1.0" langgraph = "0.2.67" syrupy = "^4.0.2" diff --git a/tests/test_async.py b/tests/test_async.py index 5f12ebf..5d72bb9 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -6,6 +6,7 @@ from uuid import uuid4 import aiomysql # type: ignore +import asyncmy # type: ignore import pytest from langchain_core.runnables import RunnableConfig @@ -17,13 +18,25 @@ empty_checkpoint, ) from langgraph.checkpoint.mysql.aio import AIOMySQLSaver, ShallowAIOMySQLSaver +from langgraph.checkpoint.mysql.aio_base import BaseAsyncMySQLSaver +from langgraph.checkpoint.mysql.asyncmy import AsyncMySaver, ShallowAsyncMySaver +from langgraph.checkpoint.mysql.shallow import BaseShallowAsyncMySQLSaver from langgraph.checkpoint.serde.types import TASKS from langgraph.graph import END, START, MessagesState, StateGraph from tests.conftest import DEFAULT_BASE_URI +SAVERS = [ + "aiomysql", + "aiomysql_pool", + "aiomysql_shallow", + "asyncmy", + "asyncmy_pool", + "asyncmy_shallow", +] + @asynccontextmanager -async def _pool_saver() -> AsyncIterator[AIOMySQLSaver]: +async def _aiomysql_pool_saver() -> AsyncIterator[AIOMySQLSaver]: """Fixture for pool mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db @@ -53,7 +66,7 @@ async def _pool_saver() -> AsyncIterator[AIOMySQLSaver]: @asynccontextmanager -async def _base_saver() -> AsyncIterator[AIOMySQLSaver]: +async def _aiomysql_saver() -> AsyncIterator[AIOMySQLSaver]: """Fixture for regular connection mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db @@ -79,7 +92,7 @@ async def _base_saver() -> AsyncIterator[AIOMySQLSaver]: @asynccontextmanager -async def _shallow_saver() -> AsyncIterator[ShallowAIOMySQLSaver]: +async def _aiomysql_shallow_saver() -> AsyncIterator[ShallowAIOMySQLSaver]: """Fixture for shallow connection mode testing.""" database = f"test_{uuid4().hex[:16]}" # create unique db @@ -104,18 +117,109 @@ async def _shallow_saver() -> AsyncIterator[ShallowAIOMySQLSaver]: await cursor.execute(f"DROP DATABASE {database}") +@asynccontextmanager +async def _asyncmy_pool_saver() -> AsyncIterator[AsyncMySaver]: + """Fixture for pool mode testing.""" + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_BASE_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + # yield checkpointer + async with asyncmy.create_pool( + **AsyncMySaver.parse_conn_string(DEFAULT_BASE_URI + database), + maxsize=10, + autocommit=True, + ) as pool: + checkpointer = AsyncMySaver(pool) + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_BASE_URI), autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + +@asynccontextmanager +async def 
_asyncmy_saver() -> AsyncIterator[AsyncMySaver]: + """Fixture for regular connection mode testing.""" + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_BASE_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + async with AsyncMySaver.from_conn_string( + DEFAULT_BASE_URI + database + ) as checkpointer: + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_BASE_URI), autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + +@asynccontextmanager +async def _asyncmy_shallow_saver() -> AsyncIterator[ShallowAsyncMySaver]: + """Fixture for shallow connection mode testing.""" + database = f"test_{uuid4().hex[:16]}" + # create unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_BASE_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + try: + async with ShallowAsyncMySaver.from_conn_string( + DEFAULT_BASE_URI + database + ) as checkpointer: + await checkpointer.setup() + yield checkpointer + finally: + # drop unique db + async with await asyncmy.connect( + **AsyncMySaver.parse_conn_string(DEFAULT_BASE_URI), autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + + @asynccontextmanager async def _saver( name: str, -) -> AsyncIterator[Union[AIOMySQLSaver, ShallowAIOMySQLSaver]]: - if name == "base": - async with _base_saver() as saver: +) -> AsyncIterator[Union[BaseAsyncMySQLSaver, BaseShallowAsyncMySQLSaver]]: + if name == "aiomysql": + async with _aiomysql_saver() as saver: + yield saver + elif name == "aiomysql_shallow": + async with _aiomysql_shallow_saver() as saver: + yield saver + elif name == "aiomysql_pool": + async with _aiomysql_pool_saver() as saver: + yield saver + elif name == "asyncmy": + async with _asyncmy_saver() as saver: yield saver - elif name == "shallow": - async with _shallow_saver() as saver: + elif name == "asyncmy_shallow": + async with _asyncmy_shallow_saver() as saver: yield saver - elif name == "pool": - async with _pool_saver() as saver: + elif name == "asyncmy_pool": + async with _asyncmy_pool_saver() as saver: yield saver @@ -170,7 +274,7 @@ def test_data() -> dict[str, Any]: } -@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +@pytest.mark.parametrize("saver_name", SAVERS) async def test_asearch(saver_name: str, test_data: dict[str, Any]) -> None: async with _saver(saver_name) as saver: configs = test_data["configs"] @@ -215,7 +319,7 @@ async def test_asearch(saver_name: str, test_data: dict[str, Any]) -> None: } == {"", "inner"} -@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +@pytest.mark.parametrize("saver_name", SAVERS) async def test_null_chars(saver_name: str, test_data: dict[str, Any]) -> None: async with _saver(saver_name) as saver: config = await saver.aput( @@ -230,7 +334,7 @@ async def test_null_chars(saver_name: str, test_data: dict[str, Any]) -> None: ].metadata["my_key"] == "abc" -@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +@pytest.mark.parametrize("saver_name", SAVERS) async def test_write_and_read_pending_writes_and_sends( saver_name: str, test_data: dict[str, Any] ) -> 
None: @@ -259,7 +363,7 @@ async def test_write_and_read_pending_writes_and_sends( assert result.checkpoint["pending_sends"] == ["w3v"] -@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +@pytest.mark.parametrize("saver_name", SAVERS) @pytest.mark.parametrize( "channel_values", [ @@ -294,7 +398,7 @@ async def test_write_and_read_channel_values( assert result.checkpoint["channel_values"] == channel_values -@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +@pytest.mark.parametrize("saver_name", SAVERS) async def test_write_and_read_pending_writes(saver_name: str) -> None: async with _saver(saver_name) as saver: config: RunnableConfig = { @@ -325,7 +429,7 @@ async def test_write_and_read_pending_writes(saver_name: str) -> None: ] -@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +@pytest.mark.parametrize("saver_name", SAVERS) async def test_write_with_different_checkpoint_ns_inserts( saver_name: str, ) -> None: @@ -350,7 +454,7 @@ async def test_write_with_different_checkpoint_ns_inserts( assert len(results) == 2 -@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +@pytest.mark.parametrize("saver_name", SAVERS) async def test_write_with_same_checkpoint_ns_updates( saver_name: str, ) -> None: @@ -373,7 +477,7 @@ async def test_write_with_same_checkpoint_ns_updates( assert len(results) == 1 -@pytest.mark.parametrize("saver_name", ["base", "pool", "shallow"]) +@pytest.mark.parametrize("saver_name", SAVERS) async def test_graph_sync_get_state_history_raises(saver_name: str) -> None: """Regression test for https://github.com/langchain-ai/langgraph/issues/2992""" diff --git a/tests/test_async_store.py b/tests/test_async_store.py index a96f95b..334afc9 100644 --- a/tests/test_async_store.py +++ b/tests/test_async_store.py @@ -6,25 +6,38 @@ from concurrent.futures import ThreadPoolExecutor import aiomysql # type: ignore +import asyncmy import pytest from langgraph.store.base import GetOp, Item, ListNamespacesOp, PutOp, SearchOp -from langgraph.store.mysql import AIOMySQLStore -from tests.conftest import DEFAULT_BASE_URI, DEFAULT_URI +from langgraph.store.mysql.aio import AIOMySQLStore +from langgraph.store.mysql.aio_base import BaseAsyncMySQLStore +from langgraph.store.mysql.asyncmy import AsyncMyStore +from tests.conftest import DEFAULT_BASE_URI -@pytest.fixture(scope="function", params=["default", "pool"]) -async def store(request) -> AsyncIterator[AIOMySQLStore]: +@pytest.fixture( + scope="function", params=["aiomysql", "aiomysql_pool", "asyncmy", "asyncmy_pool"] +) +async def store(request) -> AsyncIterator[BaseAsyncMySQLStore]: database = f"test_{uuid.uuid4().hex[:16]}" - async with await aiomysql.connect( - **AIOMySQLStore.parse_conn_string(DEFAULT_BASE_URI), - autocommit=True, - ) as conn: - async with conn.cursor() as cursor: - await cursor.execute(f"CREATE DATABASE {database}") + if request.param.startswith("aiomysql"): + async with await aiomysql.connect( + **AIOMySQLStore.parse_conn_string(DEFAULT_BASE_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") + else: + async with await asyncmy.connect( + **AsyncMyStore.parse_conn_string(DEFAULT_BASE_URI), + autocommit=True, + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"CREATE DATABASE {database}") try: - if request.param == "pool": + if request.param == "aiomysql_pool": async with aiomysql.create_pool( **AIOMySQLStore.parse_conn_string(DEFAULT_BASE_URI + database), 
maxsize=10, @@ -33,21 +46,43 @@ async def store(request) -> AsyncIterator[AIOMySQLStore]: store = AIOMySQLStore(pool) await store.setup() yield store - else: + elif request.param == "aiomysql": async with AIOMySQLStore.from_conn_string( DEFAULT_BASE_URI + database ) as store: await store.setup() yield store + elif request.param == "asyncmy_pool": + async with asyncmy.create_pool( + **AsyncMyStore.parse_conn_string(DEFAULT_BASE_URI + database), + maxsize=10, + autocommit=True, + ) as pool: + store = AsyncMyStore(pool) + await store.setup() + yield store + elif request.param == "asyncmy": + async with AsyncMyStore.from_conn_string( + DEFAULT_BASE_URI + database + ) as store: + await store.setup() + yield store finally: - async with await aiomysql.connect( - **AIOMySQLStore.parse_conn_string(DEFAULT_BASE_URI), autocommit=True - ) as conn: - async with conn.cursor() as cursor: - await cursor.execute(f"DROP DATABASE {database}") + if request.param.startswith("aiomysql"): + async with await aiomysql.connect( + **AIOMySQLStore.parse_conn_string(DEFAULT_BASE_URI), autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") + else: + async with await asyncmy.connect( + **AsyncMyStore.parse_conn_string(DEFAULT_BASE_URI), autocommit=True + ) as conn: + async with conn.cursor() as cursor: + await cursor.execute(f"DROP DATABASE {database}") -async def test_no_running_loop(store: AIOMySQLStore) -> None: +async def test_no_running_loop(store: BaseAsyncMySQLStore) -> None: with pytest.raises(asyncio.InvalidStateError): store.put(("foo", "bar"), "baz", {"val": "baz"}) with pytest.raises(asyncio.InvalidStateError): @@ -72,7 +107,7 @@ async def test_no_running_loop(store: AIOMySQLStore) -> None: ) -async def test_large_batches(store: AIOMySQLStore) -> None: +async def test_large_batches(store: BaseAsyncMySQLStore) -> None: N = 100 # less important that we are performant here M = 10 @@ -121,7 +156,7 @@ async def test_large_batches(store: AIOMySQLStore) -> None: assert len(results) == M * N * 6 -async def test_large_batches_async(store: AIOMySQLStore) -> None: +async def test_large_batches_async(store: BaseAsyncMySQLStore) -> None: N = 1000 M = 10 coros = [] @@ -169,7 +204,7 @@ async def test_large_batches_async(store: AIOMySQLStore) -> None: assert len(results) == M * N * 6 -async def test_abatch_order(store: AIOMySQLStore) -> None: +async def test_abatch_order(store: BaseAsyncMySQLStore) -> None: # Setup test data await store.aput(("test", "foo"), "key1", {"data": "value1"}) await store.aput(("test", "bar"), "key2", {"data": "value2"}) @@ -222,7 +257,7 @@ async def test_abatch_order(store: AIOMySQLStore) -> None: assert results_reordered[4].key == "key1" -async def test_batch_get_ops(store: AIOMySQLStore) -> None: +async def test_batch_get_ops(store: BaseAsyncMySQLStore) -> None: # Setup test data await store.aput(("test",), "key1", {"data": "value1"}) await store.aput(("test",), "key2", {"data": "value2"}) @@ -243,7 +278,7 @@ async def test_batch_get_ops(store: AIOMySQLStore) -> None: assert results[1].key == "key2" -async def test_batch_put_ops(store: AIOMySQLStore) -> None: +async def test_batch_put_ops(store: BaseAsyncMySQLStore) -> None: ops = [ PutOp(namespace=("test",), key="key1", value={"data": "value1"}), PutOp(namespace=("test",), key="key2", value={"data": "value2"}), @@ -260,7 +295,7 @@ async def test_batch_put_ops(store: AIOMySQLStore) -> None: assert len(items) == 2 # key3 had None value so wasn't stored -async def 
test_batch_search_ops(store: AIOMySQLStore) -> None: +async def test_batch_search_ops(store: BaseAsyncMySQLStore) -> None: # Setup test data await store.aput(("test", "foo"), "key1", {"data": "value1"}) await store.aput(("test", "bar"), "key2", {"data": "value2"}) @@ -279,7 +314,7 @@ async def test_batch_search_ops(store: AIOMySQLStore) -> None: assert len(results[1]) == 2 # All results -async def test_batch_list_namespaces_ops(store: AIOMySQLStore) -> None: +async def test_batch_list_namespaces_ops(store: BaseAsyncMySQLStore) -> None: # Setup test data await store.aput(("test", "namespace1"), "key1", {"data": "value1"}) await store.aput(("test", "namespace2"), "key2", {"data": "value2"}) @@ -294,273 +329,258 @@ async def test_batch_list_namespaces_ops(store: AIOMySQLStore) -> None: assert ("test", "namespace2") in results[0] -class TestAIOMySQLStore: - @pytest.fixture(autouse=True) - async def setup(self) -> None: - async with AIOMySQLStore.from_conn_string(DEFAULT_URI) as store: - await store.setup() - - async def test_basic_store_ops(self) -> None: - async with AIOMySQLStore.from_conn_string(DEFAULT_URI) as store: - namespace = ("test", "documents") - item_id = "doc1" - item_value = {"title": "Test Document", "content": "Hello, World!"} - - await store.aput(namespace, item_id, item_value) - item = await store.aget(namespace, item_id) - - assert item - assert item.namespace == namespace - assert item.key == item_id - assert item.value == item_value - - updated_value = { - "title": "Updated Test Document", - "content": "Hello, LangGraph!", - } - time.sleep(1) # ensures new updated time is greater - await store.aput(namespace, item_id, updated_value) - updated_item = await store.aget(namespace, item_id) - - assert updated_item.value == updated_value - assert updated_item.updated_at > item.updated_at - different_namespace = ("test", "other_documents") - item_in_different_namespace = await store.aget(different_namespace, item_id) - assert item_in_different_namespace is None - - new_item_id = "doc2" - new_item_value = {"title": "Another Document", "content": "Greetings!"} - await store.aput(namespace, new_item_id, new_item_value) - - search_results = await store.asearch(["test"], limit=10) - items = search_results - assert len(items) == 2 - assert any(item.key == item_id for item in items) - assert any(item.key == new_item_id for item in items) - - namespaces = await store.alist_namespaces(prefix=["test"]) - assert ("test", "documents") in namespaces - - await store.adelete(namespace, item_id) - await store.adelete(namespace, new_item_id) - deleted_item = await store.aget(namespace, item_id) - assert deleted_item is None - - deleted_item = await store.aget(namespace, new_item_id) - assert deleted_item is None - - empty_search_results = await store.asearch(["test"], limit=10) - assert len(empty_search_results) == 0 - - async def test_list_namespaces(self) -> None: - async with AIOMySQLStore.from_conn_string(DEFAULT_URI) as store: - test_pref = str(uuid.uuid4()) - test_namespaces = [ - (test_pref, "test", "documents", "public", test_pref), - (test_pref, "test", "documents", "private", test_pref), - (test_pref, "test", "images", "public", test_pref), - (test_pref, "test", "images", "private", test_pref), - (test_pref, "prod", "documents", "public", test_pref), - ( - test_pref, - "prod", - "documents", - "some", - "nesting", - "public", - test_pref, - ), - (test_pref, "prod", "documents", "private", test_pref), - ] - - for namespace in test_namespaces: - await store.aput(namespace, "dummy", 
{"content": "dummy"}) - - prefix_result = await store.alist_namespaces(prefix=[test_pref, "test"]) - assert len(prefix_result) == 4 - assert all([ns[1] == "test" for ns in prefix_result]) - - specific_prefix_result = await store.alist_namespaces( - prefix=[test_pref, "test", "documents"] - ) - assert len(specific_prefix_result) == 2 - assert all( - [ns[1:3] == ("test", "documents") for ns in specific_prefix_result] - ) - - suffix_result = await store.alist_namespaces(suffix=["public", test_pref]) - assert len(suffix_result) == 4 - assert all(ns[-2] == "public" for ns in suffix_result) - - prefix_suffix_result = await store.alist_namespaces( - prefix=[test_pref, "test"], suffix=["public", test_pref] - ) - assert len(prefix_suffix_result) == 2 - assert all( - ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result - ) - - wildcard_prefix_result = await store.alist_namespaces( - prefix=[test_pref, "*", "documents"] - ) - assert len(wildcard_prefix_result) == 5 - assert all(ns[2] == "documents" for ns in wildcard_prefix_result) - - wildcard_suffix_result = await store.alist_namespaces( - suffix=["*", "public", test_pref] - ) - assert len(wildcard_suffix_result) == 4 - assert all(ns[-2] == "public" for ns in wildcard_suffix_result) - wildcard_single = await store.alist_namespaces( - suffix=["some", "*", "public", test_pref] - ) - assert len(wildcard_single) == 1 - assert wildcard_single[0] == ( - test_pref, - "prod", - "documents", - "some", - "nesting", - "public", - test_pref, - ) - - max_depth_result = await store.alist_namespaces(max_depth=3) - assert all([len(ns) <= 3 for ns in max_depth_result]) - max_depth_result = await store.alist_namespaces( - max_depth=4, prefix=[test_pref, "*", "documents"] - ) - assert ( - len(set(tuple(res) for res in max_depth_result)) - == len(max_depth_result) - == 5 - ) - - limit_result = await store.alist_namespaces(prefix=[test_pref], limit=3) - assert len(limit_result) == 3 - - offset_result = await store.alist_namespaces(prefix=[test_pref], offset=3) - assert len(offset_result) == len(test_namespaces) - 3 - - empty_prefix_result = await store.alist_namespaces(prefix=[test_pref]) - assert len(empty_prefix_result) == len(test_namespaces) - assert set(tuple(ns) for ns in empty_prefix_result) == set( - tuple(ns) for ns in test_namespaces - ) - - for namespace in test_namespaces: - await store.adelete(namespace, "dummy") - - async def test_search(self): - async with AIOMySQLStore.from_conn_string(DEFAULT_URI) as store: - test_namespaces = [ - ("test_search", "documents", "user1"), - ("test_search", "documents", "user2"), - ("test_search", "reports", "department1"), - ("test_search", "reports", "department2"), - ] - test_items = [ - {"title": "Doc 1", "author": "John Doe", "tags": ["important"]}, - {"title": "Doc 2", "author": "Jane Smith", "tags": ["draft"]}, - {"title": "Report A", "author": "John Doe", "tags": ["final"]}, - {"title": "Report B", "author": "Alice Johnson", "tags": ["draft"]}, - ] - empty = await store.asearch( - ( - "scoped", - "assistant_id", - "shared", - "6c5356f6-63ab-4158-868d-cd9fd14c736e", - ), - limit=10, - offset=0, - ) - assert len(empty) == 0 - - for namespace, item in zip(test_namespaces, test_items): - await store.aput(namespace, f"item_{namespace[-1]}", item) - - docs_result = await store.asearch(["test_search", "documents"]) - assert len(docs_result) == 2 - assert all([item.namespace[1] == "documents" for item in docs_result]), [ - item.namespace for item in docs_result - ] - - reports_result = await 
store.asearch(["test_search", "reports"]) - assert len(reports_result) == 2 - assert all(item.namespace[1] == "reports" for item in reports_result) - - limited_result = await store.asearch(["test_search"], limit=2) - assert len(limited_result) == 2 - offset_result = await store.asearch(["test_search"]) - assert len(offset_result) == 4 - - offset_result = await store.asearch(["test_search"], offset=2) - assert len(offset_result) == 2 - assert all(item not in limited_result for item in offset_result) +async def test_basic_store_ops(store: BaseAsyncMySQLStore) -> None: + namespace = ("test", "documents") + item_id = "doc1" + item_value = {"title": "Test Document", "content": "Hello, World!"} + + await store.aput(namespace, item_id, item_value) + item = await store.aget(namespace, item_id) + + assert item + assert item.namespace == namespace + assert item.key == item_id + assert item.value == item_value + + updated_value = { + "title": "Updated Test Document", + "content": "Hello, LangGraph!", + } + time.sleep(1) # ensures new updated time is greater + await store.aput(namespace, item_id, updated_value) + updated_item = await store.aget(namespace, item_id) + + assert updated_item.value == updated_value + assert updated_item.updated_at > item.updated_at + different_namespace = ("test", "other_documents") + item_in_different_namespace = await store.aget(different_namespace, item_id) + assert item_in_different_namespace is None + + new_item_id = "doc2" + new_item_value = {"title": "Another Document", "content": "Greetings!"} + await store.aput(namespace, new_item_id, new_item_value) + + search_results = await store.asearch(["test"], limit=10) + items = search_results + assert len(items) == 2 + assert any(item.key == item_id for item in items) + assert any(item.key == new_item_id for item in items) + + namespaces = await store.alist_namespaces(prefix=["test"]) + assert ("test", "documents") in namespaces + + await store.adelete(namespace, item_id) + await store.adelete(namespace, new_item_id) + deleted_item = await store.aget(namespace, item_id) + assert deleted_item is None + + deleted_item = await store.aget(namespace, new_item_id) + assert deleted_item is None + + empty_search_results = await store.asearch(["test"], limit=10) + assert len(empty_search_results) == 0 + + +async def test_list_namespaces(store: BaseAsyncMySQLStore) -> None: + test_pref = str(uuid.uuid4()) + test_namespaces = [ + (test_pref, "test", "documents", "public", test_pref), + (test_pref, "test", "documents", "private", test_pref), + (test_pref, "test", "images", "public", test_pref), + (test_pref, "test", "images", "private", test_pref), + (test_pref, "prod", "documents", "public", test_pref), + ( + test_pref, + "prod", + "documents", + "some", + "nesting", + "public", + test_pref, + ), + (test_pref, "prod", "documents", "private", test_pref), + ] - john_doe_result = await store.asearch( - ["test_search"], filter={"author": "John Doe"} - ) - assert len(john_doe_result) == 2 - assert all(item.value["author"] == "John Doe" for item in john_doe_result) + for namespace in test_namespaces: + await store.aput(namespace, "dummy", {"content": "dummy"}) + + prefix_result = await store.alist_namespaces(prefix=[test_pref, "test"]) + assert len(prefix_result) == 4 + assert all([ns[1] == "test" for ns in prefix_result]) + + specific_prefix_result = await store.alist_namespaces( + prefix=[test_pref, "test", "documents"] + ) + assert len(specific_prefix_result) == 2 + assert all([ns[1:3] == ("test", "documents") for ns in 
specific_prefix_result]) + + suffix_result = await store.alist_namespaces(suffix=["public", test_pref]) + assert len(suffix_result) == 4 + assert all(ns[-2] == "public" for ns in suffix_result) + + prefix_suffix_result = await store.alist_namespaces( + prefix=[test_pref, "test"], suffix=["public", test_pref] + ) + assert len(prefix_suffix_result) == 2 + assert all(ns[1] == "test" and ns[-2] == "public" for ns in prefix_suffix_result) + + wildcard_prefix_result = await store.alist_namespaces( + prefix=[test_pref, "*", "documents"] + ) + assert len(wildcard_prefix_result) == 5 + assert all(ns[2] == "documents" for ns in wildcard_prefix_result) + + wildcard_suffix_result = await store.alist_namespaces( + suffix=["*", "public", test_pref] + ) + assert len(wildcard_suffix_result) == 4 + assert all(ns[-2] == "public" for ns in wildcard_suffix_result) + wildcard_single = await store.alist_namespaces( + suffix=["some", "*", "public", test_pref] + ) + assert len(wildcard_single) == 1 + assert wildcard_single[0] == ( + test_pref, + "prod", + "documents", + "some", + "nesting", + "public", + test_pref, + ) + + max_depth_result = await store.alist_namespaces(max_depth=3) + assert all([len(ns) <= 3 for ns in max_depth_result]) + max_depth_result = await store.alist_namespaces( + max_depth=4, prefix=[test_pref, "*", "documents"] + ) + assert ( + len(set(tuple(res) for res in max_depth_result)) == len(max_depth_result) == 5 + ) + + limit_result = await store.alist_namespaces(prefix=[test_pref], limit=3) + assert len(limit_result) == 3 + + offset_result = await store.alist_namespaces(prefix=[test_pref], offset=3) + assert len(offset_result) == len(test_namespaces) - 3 + + empty_prefix_result = await store.alist_namespaces(prefix=[test_pref]) + assert len(empty_prefix_result) == len(test_namespaces) + assert set(tuple(ns) for ns in empty_prefix_result) == set( + tuple(ns) for ns in test_namespaces + ) + + for namespace in test_namespaces: + await store.adelete(namespace, "dummy") + + +async def test_search(store: BaseAsyncMySQLStore): + test_namespaces = [ + ("test_search", "documents", "user1"), + ("test_search", "documents", "user2"), + ("test_search", "reports", "department1"), + ("test_search", "reports", "department2"), + ] + test_items = [ + {"title": "Doc 1", "author": "John Doe", "tags": ["important"]}, + {"title": "Doc 2", "author": "Jane Smith", "tags": ["draft"]}, + {"title": "Report A", "author": "John Doe", "tags": ["final"]}, + {"title": "Report B", "author": "Alice Johnson", "tags": ["draft"]}, + ] + empty = await store.asearch( + ( + "scoped", + "assistant_id", + "shared", + "6c5356f6-63ab-4158-868d-cd9fd14c736e", + ), + limit=10, + offset=0, + ) + assert len(empty) == 0 + + for namespace, item in zip(test_namespaces, test_items): + await store.aput(namespace, f"item_{namespace[-1]}", item) + + docs_result = await store.asearch(["test_search", "documents"]) + assert len(docs_result) == 2 + assert all([item.namespace[1] == "documents" for item in docs_result]), [ + item.namespace for item in docs_result + ] - draft_result = await store.asearch( - ["test_search"], filter={"tags": ["draft"]} - ) - assert len(draft_result) == 2 - assert all("draft" in item.value["tags"] for item in draft_result) - - page1 = await store.asearch(["test_search"], limit=2, offset=0) - page2 = await store.asearch(["test_search"], limit=2, offset=2) - all_items = page1 + page2 - assert len(all_items) == 4 - assert len(set(item.key for item in all_items)) == 4 - empty = await store.asearch( - ( - "scoped", - 
"assistant_id", - "shared", - "again", - "maybe", - "some-long", - "6be5cb0e-2eb4-42e6-bb6b-fba3c269db25", - ), - limit=10, - offset=0, - ) - assert len(empty) == 0 - - # Test with a namespace beginning with a number (like a UUID) - uuid_namespace = (str(uuid.uuid4()), "documents") - uuid_item_id = "uuid_doc" - uuid_item_value = { - "title": "UUID Document", - "content": "This document has a UUID namespace.", - } - - # Insert the item with the UUID namespace - await store.aput(uuid_namespace, uuid_item_id, uuid_item_value) - - # Retrieve the item to verify it was stored correctly - retrieved_item = await store.aget(uuid_namespace, uuid_item_id) - assert retrieved_item is not None - assert retrieved_item.namespace == uuid_namespace - assert retrieved_item.key == uuid_item_id - assert retrieved_item.value == uuid_item_value - - # Search for the item using the UUID namespace - search_result = await store.asearch([uuid_namespace[0]]) - assert len(search_result) == 1 - assert search_result[0].key == uuid_item_id - assert search_result[0].value == uuid_item_value - - # Clean up: delete the item with the UUID namespace - await store.adelete(uuid_namespace, uuid_item_id) - - # Verify the item was deleted - deleted_item = await store.aget(uuid_namespace, uuid_item_id) - assert deleted_item is None - - for namespace in test_namespaces: - await store.adelete(namespace, f"item_{namespace[-1]}") + reports_result = await store.asearch(["test_search", "reports"]) + assert len(reports_result) == 2 + assert all(item.namespace[1] == "reports" for item in reports_result) + + limited_result = await store.asearch(["test_search"], limit=2) + assert len(limited_result) == 2 + offset_result = await store.asearch(["test_search"]) + assert len(offset_result) == 4 + + offset_result = await store.asearch(["test_search"], offset=2) + assert len(offset_result) == 2 + assert all(item not in limited_result for item in offset_result) + + john_doe_result = await store.asearch( + ["test_search"], filter={"author": "John Doe"} + ) + assert len(john_doe_result) == 2 + assert all(item.value["author"] == "John Doe" for item in john_doe_result) + + draft_result = await store.asearch(["test_search"], filter={"tags": ["draft"]}) + assert len(draft_result) == 2 + assert all("draft" in item.value["tags"] for item in draft_result) + + page1 = await store.asearch(["test_search"], limit=2, offset=0) + page2 = await store.asearch(["test_search"], limit=2, offset=2) + all_items = page1 + page2 + assert len(all_items) == 4 + assert len(set(item.key for item in all_items)) == 4 + empty = await store.asearch( + ( + "scoped", + "assistant_id", + "shared", + "again", + "maybe", + "some-long", + "6be5cb0e-2eb4-42e6-bb6b-fba3c269db25", + ), + limit=10, + offset=0, + ) + assert len(empty) == 0 + + # Test with a namespace beginning with a number (like a UUID) + uuid_namespace = (str(uuid.uuid4()), "documents") + uuid_item_id = "uuid_doc" + uuid_item_value = { + "title": "UUID Document", + "content": "This document has a UUID namespace.", + } + + # Insert the item with the UUID namespace + await store.aput(uuid_namespace, uuid_item_id, uuid_item_value) + + # Retrieve the item to verify it was stored correctly + retrieved_item = await store.aget(uuid_namespace, uuid_item_id) + assert retrieved_item is not None + assert retrieved_item.namespace == uuid_namespace + assert retrieved_item.key == uuid_item_id + assert retrieved_item.value == uuid_item_value + + # Search for the item using the UUID namespace + search_result = await 
store.asearch([uuid_namespace[0]]) + assert len(search_result) == 1 + assert search_result[0].key == uuid_item_id + assert search_result[0].value == uuid_item_value + + # Clean up: delete the item with the UUID namespace + await store.adelete(uuid_namespace, uuid_item_id) + + # Verify the item was deleted + deleted_item = await store.aget(uuid_namespace, uuid_item_id) + assert deleted_item is None + + for namespace in test_namespaces: + await store.adelete(namespace, f"item_{namespace[-1]}")
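
For quick reference, a minimal usage sketch of the two classes this patch introduces (AsyncMySaver and AsyncMyStore), adapted from the test fixtures above. The connection URI and the "example" database are placeholders: the sketch assumes the asyncmy extra is installed and that a reachable MySQL server with that database already exists.

import asyncio

from langgraph.checkpoint.mysql.asyncmy import AsyncMySaver
from langgraph.store.mysql.asyncmy import AsyncMyStore

# Placeholder URI; point this at a real MySQL instance and existing database.
CONN_URI = "mysql://mysql:mysql@localhost:3306/example"


async def main() -> None:
    # Checkpointer: open a connection, run the migrations, then pass the
    # saver to a compiled LangGraph graph as its checkpointer.
    async with AsyncMySaver.from_conn_string(CONN_URI) as checkpointer:
        await checkpointer.setup()

    # Store: same pattern, plus a basic put/get round trip.
    async with AsyncMyStore.from_conn_string(CONN_URI) as store:
        await store.setup()
        await store.aput(("docs",), "doc1", {"title": "Hello"})
        item = await store.aget(("docs",), "doc1")
        assert item is not None and item.value["title"] == "Hello"


asyncio.run(main())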