Skip to content

Commit

Permalink
correctly handle signals in nested transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
michamos committed Oct 15, 2018
1 parent 50944e7 commit c34ad3b
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 7 deletions.
44 changes: 38 additions & 6 deletions flask_sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import sys
import time
import warnings
from itertools import chain
from math import ceil
from operator import itemgetter
from threading import Lock
Expand Down Expand Up @@ -166,13 +167,14 @@ class _SessionSignalEvents(object):
@classmethod
def register(cls, session):
if not hasattr(session, '_model_changes'):
session._model_changes = {}
session._model_changes = []

event.listen(session, 'before_flush', cls.record_ops)
event.listen(session, 'before_commit', cls.record_ops)
event.listen(session, 'before_commit', cls.before_commit)
event.listen(session, 'after_commit', cls.after_commit)
event.listen(session, 'after_rollback', cls.after_rollback)
event.listen(session, 'after_transaction_create', cls.after_transaction_create)

@classmethod
def unregister(cls, session):
Expand All @@ -184,6 +186,7 @@ def unregister(cls, session):
event.remove(session, 'before_commit', cls.before_commit)
event.remove(session, 'after_commit', cls.after_commit)
event.remove(session, 'after_rollback', cls.after_rollback)
event.remove(session, 'after_transaction_create', cls.after_transaction_create)

@staticmethod
def record_ops(session, flush_context=None, instances=None):
Expand All @@ -196,28 +199,54 @@ def record_ops(session, flush_context=None, instances=None):
for target in targets:
state = inspect(target)
key = state.identity_key if state.has_identity else id(target)
d[key] = (target, operation)
d[-1][key] = (target, operation)

@staticmethod
def after_transaction_create(session, transaction):
if transaction.parent and not transaction.nested:
return

try:
d = session._model_changes
except AttributeError:
return

d.append({})

@staticmethod
def before_commit(session):
if session.transaction.nested:
return

try:
d = session._model_changes
except AttributeError:
return

if d:
before_models_committed.send(session.app, changes=list(d.values()))
for level in d[1:]:
d[0].update(level)

if d[0]:
before_models_committed.send(session.app, changes=list(d[0].values()))

@staticmethod
def after_commit(session):
if session.transaction.nested:
return

try:
d = session._model_changes
except AttributeError:
return

if d:
models_committed.send(session.app, changes=list(d.values()))
d.clear()
for level in d[1:]:
d[0].update(level)

if d[0]:
models_committed.send(session.app, changes=list(d[0].values()))
del d[:]

@staticmethod
def after_rollback(session):
Expand All @@ -226,7 +255,10 @@ def after_rollback(session):
except AttributeError:
return

d.clear()
try:
del d[-1]
except IndexError:
pass


class _EngineDebuggingSignalEvents(object):
Expand Down
68 changes: 67 additions & 1 deletion tests/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,38 @@
import pytest

import flask_sqlalchemy as fsa
import sqlalchemy as sa


pytestmark = pytest.mark.skipif(
not flask.signals_available,
reason='Signals require the blinker library.'
)


@pytest.fixture()
def app(app):
app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = True
return app


@pytest.fixture()
def db(db):
# required for correct handling of nested transactions, see
# https://docs.sqlalchemy.org/en/rel_1_0/dialects/sqlite.html#serializable-isolation-savepoints-transactional-ddl
@sa.event.listens_for(db.engine, "connect")
def do_connect(dbapi_connection, connection_record):
# disable pysqlite's emitting of the BEGIN statement entirely.
# also stops it from emitting COMMIT before any DDL.
dbapi_connection.isolation_level = None

@sa.event.listens_for(db.engine, "begin")
def do_begin(conn):
# emit our own BEGIN
conn.execute("BEGIN")

return db


def test_before_committed(app, db, Todo):
class Namespace(object):
is_received = False
Expand Down Expand Up @@ -59,3 +77,51 @@ def committed(sender, changes):
assert recorded[0][0] == todo
assert recorded[0][1] == 'delete'
fsa.models_committed.disconnect(committed)


def test_model_signals_nested_transaction(db, Todo):
before_commit_recorded = []
commit_recorded = []

def before_committed(sender, changes):
before_commit_recorded.extend(changes)

def committed(sender, changes):
commit_recorded.extend(changes)

fsa.before_models_committed.connect(before_committed)
fsa.models_committed.connect(committed)
with db.session.begin_nested():
todo = Todo('Awesome', 'the text')
db.session.add(todo)
try:
with db.session.begin_nested():
todo2 = Todo('Bad', 'to rollback')
db.session.add(todo2)
raise Exception('raising to roll back')
except Exception:
pass
assert before_commit_recorded == []
assert commit_recorded == []
db.session.commit()
assert before_commit_recorded == [(todo, 'insert')]
assert commit_recorded == [(todo, 'insert')]
del before_commit_recorded[:]
del commit_recorded[:]
try:
with db.session.begin_nested():
todo = Todo('Great', 'the text')
db.session.add(todo)
with db.session.begin_nested():
todo2 = Todo('Bad', 'to rollback')
db.session.add(todo2)
raise Exception('raising to roll back')
except Exception:
pass
assert before_commit_recorded == []
assert commit_recorded == []
db.session.commit()
assert before_commit_recorded == []
assert commit_recorded == []
fsa.before_models_committed.disconnect(before_committed)
fsa.models_committed.disconnect(committed)

0 comments on commit c34ad3b

Please sign in to comment.