diff --git a/flask_sqlalchemy/__init__.py b/flask_sqlalchemy/__init__.py index f2c04b8a..12a54622 100644 --- a/flask_sqlalchemy/__init__.py +++ b/flask_sqlalchemy/__init__.py @@ -166,13 +166,14 @@ class _SessionSignalEvents(object): @classmethod def register(cls, session): if not hasattr(session, '_model_changes'): - session._model_changes = {} + session._model_changes = [] event.listen(session, 'before_flush', cls.record_ops) event.listen(session, 'before_commit', cls.record_ops) event.listen(session, 'before_commit', cls.before_commit) event.listen(session, 'after_commit', cls.after_commit) event.listen(session, 'after_rollback', cls.after_rollback) + event.listen(session, 'after_transaction_create', cls.after_transaction_create) @classmethod def unregister(cls, session): @@ -184,6 +185,7 @@ def unregister(cls, session): event.remove(session, 'before_commit', cls.before_commit) event.remove(session, 'after_commit', cls.after_commit) event.remove(session, 'after_rollback', cls.after_rollback) + event.remove(session, 'after_transaction_create', cls.after_transaction_create) @staticmethod def record_ops(session, flush_context=None, instances=None): @@ -196,28 +198,54 @@ def record_ops(session, flush_context=None, instances=None): for target in targets: state = inspect(target) key = state.identity_key if state.has_identity else id(target) - d[key] = (target, operation) + d[-1][key] = (target, operation) + + @staticmethod + def after_transaction_create(session, transaction): + if transaction.parent and not transaction.nested: + return + + try: + d = session._model_changes + except AttributeError: + return + + d.append({}) @staticmethod def before_commit(session): + if session.transaction.nested: + return + try: d = session._model_changes except AttributeError: return if d: - before_models_committed.send(session.app, changes=list(d.values())) + for level in d[1:]: + d[0].update(level) + + if d[0]: + before_models_committed.send(session.app, changes=list(d[0].values())) @staticmethod def after_commit(session): + if session.transaction.nested: + return + try: d = session._model_changes except AttributeError: return if d: - models_committed.send(session.app, changes=list(d.values())) - d.clear() + for level in d[1:]: + d[0].update(level) + + if d[0]: + models_committed.send(session.app, changes=list(d[0].values())) + del d[:] @staticmethod def after_rollback(session): @@ -226,7 +254,10 @@ def after_rollback(session): except AttributeError: return - d.clear() + try: + del d[-1] + except IndexError: + pass class _EngineDebuggingSignalEvents(object): diff --git a/tests/test_signals.py b/tests/test_signals.py index fa6611ed..0f88ce26 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -2,6 +2,7 @@ import pytest import flask_sqlalchemy as fsa +import sqlalchemy as sa pytestmark = pytest.mark.skipif( @@ -16,6 +17,24 @@ def app(app): return app +@pytest.fixture() +def db(db): + # required for correct handling of nested transactions, see + # https://docs.sqlalchemy.org/en/rel_1_0/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl + @sa.event.listens_for(db.engine, "connect") + def do_connect(dbapi_connection, connection_record): + # disable pysqlite's emitting of the BEGIN statement entirely. + # also stops it from emitting COMMIT before any DDL. + dbapi_connection.isolation_level = None + + @sa.event.listens_for(db.engine, "begin") + def do_begin(conn): + # emit our own BEGIN + conn.execute("BEGIN") + + return db + + def test_before_committed(app, db, Todo): class Namespace(object): is_received = False @@ -59,3 +78,51 @@ def committed(sender, changes): assert recorded[0][0] == todo assert recorded[0][1] == 'delete' fsa.models_committed.disconnect(committed) + + +def test_model_signals_nested_transaction(db, Todo): + before_commit_recorded = [] + commit_recorded = [] + + def before_committed(sender, changes): + before_commit_recorded.extend(changes) + + def committed(sender, changes): + commit_recorded.extend(changes) + + fsa.before_models_committed.connect(before_committed) + fsa.models_committed.connect(committed) + with db.session.begin_nested(): + todo = Todo('Awesome', 'the text') + db.session.add(todo) + try: + with db.session.begin_nested(): + todo2 = Todo('Bad', 'to rollback') + db.session.add(todo2) + raise Exception('raising to roll back') + except Exception: + pass + assert before_commit_recorded == [] + assert commit_recorded == [] + db.session.commit() + assert before_commit_recorded == [(todo, 'insert')] + assert commit_recorded == [(todo, 'insert')] + del before_commit_recorded[:] + del commit_recorded[:] + try: + with db.session.begin_nested(): + todo = Todo('Great', 'the text') + db.session.add(todo) + with db.session.begin_nested(): + todo2 = Todo('Bad', 'to rollback') + db.session.add(todo2) + raise Exception('raising to roll back') + except Exception: + pass + assert before_commit_recorded == [] + assert commit_recorded == [] + db.session.commit() + assert before_commit_recorded == [] + assert commit_recorded == [] + fsa.before_models_committed.disconnect(before_committed) + fsa.models_committed.disconnect(committed)