diff --git a/databases/core.py b/databases/core.py index cf5a7aa0..54e057a2 100644 --- a/databases/core.py +++ b/databases/core.py @@ -36,12 +36,9 @@ logger = logging.getLogger("databases") -_ACTIVE_CONNECTIONS: ContextVar[ - typing.Optional["weakref.WeakKeyDictionary['Database', 'Connection']"] -] = ContextVar("databases:open_connections", default=None) _ACTIVE_TRANSACTIONS: ContextVar[ typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"] -] = ContextVar("databases:open_transactions", default=None) +] = ContextVar("databases:active_transactions", default=None) class Database: @@ -54,6 +51,8 @@ class Database: "sqlite": "databases.backends.sqlite:SQLiteBackend", } + _connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']" + def __init__( self, url: typing.Union[str, "DatabaseURL"], @@ -64,6 +63,7 @@ def __init__( self.url = DatabaseURL(url) self.options = options self.is_connected = False + self._connection_map = weakref.WeakKeyDictionary() self._force_rollback = force_rollback @@ -78,28 +78,28 @@ def __init__( self._global_transaction: typing.Optional[Transaction] = None @property - def _connection(self) -> typing.Optional["Connection"]: - connections = _ACTIVE_CONNECTIONS.get() - if connections is None: - return None + def _current_task(self): + task = asyncio.current_task() + if not task: + raise RuntimeError("No currently active asyncio.Task found") + return task - return connections.get(self, None) + @property + def _connection(self) -> typing.Optional["Connection"]: + return self._connection_map.get(self._current_task) @_connection.setter def _connection( self, connection: typing.Optional["Connection"] ) -> typing.Optional["Connection"]: - connections = _ACTIVE_CONNECTIONS.get() - if connections is None: - connections = weakref.WeakKeyDictionary() - _ACTIVE_CONNECTIONS.set(connections) + task = self._current_task if connection is None: - connections.pop(self, None) + self._connection_map.pop(task, None) else: - 
connections[self] = connection + self._connection_map[task] = connection - return connections.get(self, None) + return self._connection async def connect(self) -> None: """ @@ -119,7 +119,7 @@ async def connect(self) -> None: assert self._global_connection is None assert self._global_transaction is None - self._global_connection = Connection(self._backend) + self._global_connection = Connection(self, self._backend) self._global_transaction = self._global_connection.transaction( force_rollback=True ) @@ -218,7 +218,7 @@ def connection(self) -> "Connection": return self._global_connection if not self._connection: - self._connection = Connection(self._backend) + self._connection = Connection(self, self._backend) return self._connection @@ -243,7 +243,8 @@ def _get_backend(self) -> str: class Connection: - def __init__(self, backend: DatabaseBackend) -> None: + def __init__(self, database: Database, backend: DatabaseBackend) -> None: + self._database = database self._backend = backend self._connection_lock = asyncio.Lock() @@ -277,6 +278,7 @@ async def __aexit__( self._connection_counter -= 1 if self._connection_counter == 0: await self._connection.release() + self._database._connection = None async def fetch_all( self, @@ -393,13 +395,15 @@ def _transaction( transactions = _ACTIVE_TRANSACTIONS.get() if transactions is None: transactions = weakref.WeakKeyDictionary() - _ACTIVE_TRANSACTIONS.set(transactions) + else: + transactions = transactions.copy() if transaction is None: transactions.pop(self, None) else: transactions[self] = transaction + _ACTIVE_TRANSACTIONS.set(transactions) return transactions.get(self, None) async def __aenter__(self) -> "Transaction": diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index e52243e3..11044655 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints. 
## Connecting and disconnecting -You can control the database connect/disconnect, by using it as a async context manager. +You can control the database connection pool with an async context manager: ```python async with Database(DATABASE_URL) as database: ... ``` -Or by using explicit connection and disconnection: +Or by using the explicit `.connect()` and `.disconnect()` methods: ```python database = Database(DATABASE_URL) @@ -23,6 +23,8 @@ await database.connect() await database.disconnect() ``` +Connections within this connection pool are acquired for each new `asyncio.Task`. + If you're integrating against a web framework, then you'll probably want to hook into framework startup or shutdown events. For example, with [Starlette][starlette] you would use the following: @@ -96,12 +98,13 @@ async def create_users(request): ... ``` -Transaction state is stored in the context of the currently executing asynchronous task. -This state is _inherited_ by tasks that are started from within an active transaction: +Transaction state is tied to the connection used in the currently executing asynchronous task. +If you would like to influence an active transaction from another task, the connection must be +shared. This state is _inherited_ by tasks that share the same connection: ```python -async def add_excitement(database: Database, id: int): - await database.execute( +async def add_excitement(connection: databases.core.Connection, id: int): + await connection.execute( "UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id", {"id": id} ) @@ -113,17 +116,13 @@ async with Database(database_url) as database: await database.execute( "INSERT INTO notes(id, text) values (1, 'databases is cool')" ) - # ...but child tasks inherit transaction state! - await asyncio.create_task(add_excitement(database, id=1)) + # ...but child tasks can use this connection now! 
+ await asyncio.create_task(add_excitement(database.connection(), id=1)) await database.fetch_val("SELECT text FROM notes WHERE id=1") # ^ returns: "databases is cool!!!" ``` -!!! note - In python 3.11, you can opt-out of context propagation by providing a new context to - [`asyncio.create_task`](https://docs.python.org/3.11/library/asyncio-task.html#creating-tasks). - Nested transactions are fully supported, and are implemented using database savepoints: ```python diff --git a/tests/test_databases.py b/tests/test_databases.py index c78ce4f3..4d737261 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -482,11 +482,29 @@ async def test_transaction_commit(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_transaction_context_child_task_interaction(database_url): +async def test_transaction_context_child_task_inheritance(database_url): + """ + Ensure that transactions are inherited by child tasks. + """ + async with Database(database_url) as database: + + async def check_transaction(transaction, active_transaction): + # Should have inherited the same transaction backend from the parent task + assert transaction._transaction is active_transaction + + async with database.transaction() as transaction: + await asyncio.create_task( + check_transaction(transaction, transaction._transaction) + ) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_context_child_task_inheritance_example(database_url): """ Ensure that child tasks may influence inherited transactions. """ - # This is an practical example of the next test. + # This is a practical example of the above test. 
async with Database(database_url) as database: async with database.transaction(): # Create a note @@ -503,37 +521,19 @@ async def test_transaction_context_child_task_interaction(database_url): result = await database.fetch_one(notes.select().where(notes.c.id == 1)) assert result.text == "prior" - async def run_update_from_child_task(): - # Chage the note from a child task - await database.execute( + async def run_update_from_child_task(connection): + # Change the note from a child task + await connection.execute( notes.update().where(notes.c.id == 1).values(text="test") ) - await asyncio.create_task(run_update_from_child_task()) + await asyncio.create_task(run_update_from_child_task(database.connection())) # Confirm the child's change result = await database.fetch_one(notes.select().where(notes.c.id == 1)) assert result.text == "test" -@pytest.mark.parametrize("database_url", DATABASE_URLS) -@async_adapter -async def test_transaction_context_child_task_inheritance(database_url): - """ - Ensure that transactions are inherited by child tasks. - """ - async with Database(database_url) as database: - - async def check_transaction(transaction, active_transaction): - # Should have inherited the same transaction backend from the parent task - assert transaction._transaction is active_transaction - - async with database.transaction() as transaction: - await asyncio.create_task( - check_transaction(transaction, transaction._transaction) - ) - - @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_transaction_context_sibling_task_isolation(database_url): @@ -568,56 +568,99 @@ async def check_transaction(transaction): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_cleanup_contextmanager(database_url): +async def test_transaction_context_sibling_task_isolation_example(database_url): + """ + Ensure that transactions running in sibling tasks are isolated from each other. 
+ """ + # This is a practical example of the above test. + setup = asyncio.Event() + done = asyncio.Event() + + async def tx1(connection): + async with connection.transaction(): + await db.execute( + notes.insert(), values={"id": 1, "text": "tx1", "completed": False} + ) + setup.set() + await done.wait() + + async def tx2(connection): + async with connection.transaction(): + await setup.wait() + result = await db.fetch_all(notes.select()) + assert result == [], result + done.set() + + async with Database(database_url) as db: + await asyncio.gather(tx1(db), tx2(db)) + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_cleanup_contextmanager(database_url): """ - Ensure that contextvar connections are not persisted unecessarily. + Ensure that task connections are not persisted unnecessarily. """ - from databases.core import _ACTIVE_CONNECTIONS - assert _ACTIVE_CONNECTIONS.get() is None + ready = asyncio.Event() + done = asyncio.Event() + + async def check_child_connection(database: Database): + async with database.connection(): + ready.set() + await done.wait() async with Database(database_url) as database: + # Should have a connection in this task # .connect is lazy, it doesn't create a Connection, but .connection does connection = database.connection() + assert isinstance(database._connection_map, MutableMapping) + assert database._connection_map.get(asyncio.current_task()) is connection - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert open_connections.get(database) is connection + # Create a child task and see if it registers a connection + task = asyncio.create_task(check_child_connection(database)) + await ready.wait() + assert database._connection_map.get(task) is not None + assert database._connection_map.get(task) is not connection - # Context manager closes, open_connections is cleaned up - open_connections = _ACTIVE_CONNECTIONS.get() - assert 
isinstance(open_connections, MutableMapping) - assert open_connections.get(database, None) is None + # Let the child task finish, and see if it cleaned up + done.set() + await task + # This is normal exit logic cleanup, the WeakKeyDictionary + # shouldn't have cleaned up yet since the task is still referenced + assert task not in database._connection_map + + # Context manager closes, all open connections are removed + assert isinstance(database._connection_map, MutableMapping) + assert len(database._connection_map) == 0 @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_cleanup_garbagecollector(database_url): +async def test_connection_cleanup_garbagecollector(database_url): """ - Ensure that contextvar connections are not persisted unecessarily, even + Ensure that connections for tasks are not persisted unnecessarily, even if exit handlers are not called. """ - from databases.core import _ACTIVE_CONNECTIONS - - assert _ACTIVE_CONNECTIONS.get() is None - database = Database(database_url) await database.connect() - connection = database.connection() - # Should be tracking the connection - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert open_connections.get(database) is connection + created = asyncio.Event() + + async def check_child_connection(database: Database): + # neither .disconnect nor .__aexit__ are called before deleting this task + database.connection() + created.set() - # neither .disconnect nor .__aexit__ are called before deleting the reference - del database + task = asyncio.create_task(check_child_connection(database)) + await created.wait() + assert task in database._connection_map + await task + del task gc.collect() - # Should have dropped reference to connection, even without proper cleanup - open_connections = _ACTIVE_CONNECTIONS.get() - assert isinstance(open_connections, MutableMapping) - assert len(open_connections) == 0 + # Should 
not have a connection for the task anymore + assert len(database._connection_map) == 0 @pytest.mark.parametrize("database_url", DATABASE_URLS) @@ -632,7 +675,6 @@ async def test_transaction_context_cleanup_contextmanager(database_url): async with Database(database_url) as database: async with database.transaction() as transaction: - open_transactions = _ACTIVE_TRANSACTIONS.get() assert isinstance(open_transactions, MutableMapping) assert open_transactions.get(transaction) is transaction._transaction @@ -818,17 +860,44 @@ async def insert_data(raise_exception): with pytest.raises(RuntimeError): await insert_data(raise_exception=True) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 0 await insert_data(raise_exception=False) - query = notes.select() - results = await database.fetch_all(query=query) + results = await database.fetch_all(query=notes.select()) assert len(results) == 1 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_transaction_decorator_concurrent(database_url): + """ + Ensure that @database.transaction() can be called concurrently. 
+ """ + + database = Database(database_url) + + @database.transaction() + async def insert_data(): + await database.execute( + query=notes.insert().values(text="example", completed=True) + ) + + async with database: + await asyncio.gather( + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + insert_data(), + ) + + results = await database.fetch_all(query=notes.select()) + assert len(results) == 6 + + @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_datetime_field(database_url): @@ -1007,7 +1076,7 @@ async def test_connection_context_same_task(database_url): @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter -async def test_connection_context_multiple_tasks(database_url): +async def test_connection_context_multiple_sibling_tasks(database_url): async with Database(database_url) as database: connection_1 = None connection_2 = None @@ -1037,6 +1106,47 @@ async def get_connection_2(): await task_2 +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_connection_context_multiple_tasks(database_url): + async with Database(database_url) as database: + parent_connection = database.connection() + connection_1 = None + connection_2 = None + task_1_ready = asyncio.Event() + task_2_ready = asyncio.Event() + test_complete = asyncio.Event() + + async def get_connection_1(): + nonlocal connection_1 + + async with database.connection() as connection: + connection_1 = connection + task_1_ready.set() + await test_complete.wait() + + async def get_connection_2(): + nonlocal connection_2 + + async with database.connection() as connection: + connection_2 = connection + task_2_ready.set() + await test_complete.wait() + + task_1 = asyncio.create_task(get_connection_1()) + task_2 = asyncio.create_task(get_connection_2()) + await task_1_ready.wait() + await task_2_ready.wait() + + assert connection_1 is not parent_connection + assert connection_2 is not 
parent_connection + assert connection_1 is not connection_2 + + test_complete.set() + await task_1 + await task_2 + + @pytest.mark.parametrize( "database_url1,database_url2", (