diff --git a/mosbot/query.py b/mosbot/query.py index bb7bab7..fb3c9c3 100644 --- a/mosbot/query.py +++ b/mosbot/query.py @@ -96,6 +96,7 @@ async def save_user(*, user_dict: dict, conn=None) -> dict: async def get_or_save_user(*, user_dict: dict, conn=None) -> dict: + """Try to retrieve a given user. If it doesn't exist, try to create it.""" user = await get_user(user_dict=user_dict, conn=conn) if user: return user @@ -373,41 +374,42 @@ async def query_simplified_user_actions(playback_id, *, conn=None) -> List[dict] :param conn: A connection if any open :return: A list of the records """ - sub_query = sa.select([ - db.UserAction.c.user_id, - saf.max(db.UserAction.c.ts).label('ts'), - db.UserAction.c.playback_id, - ]).where( - db.UserAction.c.playback_id == playback_id - ).group_by( - db.UserAction.c.user_id, - db.UserAction.c.playback_id, - sa.case([ - (db.UserAction.c.user_id.is_(None), db.UserAction.c.id), - ], else_=0) - ).alias() - - query = sa.select([ - sa.distinct(db.UserAction.c.id), - db.UserAction.c.action, - db.UserAction.c.playback_id, - db.UserAction.c.ts, - db.UserAction.c.user_id, - ]).select_from( - db.UserAction.join( - sub_query, - sa.and_( - sub_query.c.ts == db.UserAction.c.ts, - db.UserAction.c.playback_id == sub_query.c.playback_id, - sa.case([ - (sa.and_( - db.UserAction.c.user_id.is_(None), - sub_query.c.user_id.is_(None) - ), sa.true()) - ], else_=db.UserAction.c.user_id == sub_query.c.user_id) - ) - ) - ) + query = f""" + select * + from ( + select + "song_title", + "username", + "action_timestamp", + "action", + "rn", + lag("action") over (partition by "username") as "next_action" + from ( + select * + from ( + select + p.id "playback_id", + u.username "username", + u.id "user_id", + t."name" "song_title", + ua.id "action_id", + ua."action" "action", + ua.ts "action_timestamp", + row_number() over (partition by ua.user_id, p.id order by u.username asc, ua.ts desc) rn + from playback p + left join user_action ua on p.id = ua.playback_id + left join track t on p.track_id = t.id + left join "user" u on ua.user_id = u.id + where p.id = {playback_id} + ) sub + ) sub2 + ) sub3 + where + "rn" = 1 or + "next_action" = 'skip' or + "username" = null + order by "action_timestamp" + """ async with ensure_connection(conn) as conn: result = [] async for user_action in await conn.execute(query): diff --git a/tests/test_query.py b/tests/test_query.py index 658d3eb..71d3af5 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -83,22 +83,22 @@ async def test_ensure_connection( @pytest.mark.parametrize('data_dict,expected_result', ( - ( - {'id': 1, 'dtid': '1234', 'username': 'username', 'country': 'ES'}, - {'id': 1, 'dtid': '1234', 'username': 'username', 'country': 'ES'}, - ), - ( - {'id': 1, 'dtid': '1234', 'username': 'username', 'country': None}, - {'id': 1, 'dtid': '1234', 'username': 'username', 'country': None}, - ), - ( - {'id': 1, 'dtid': '1234', 'username': 'username'}, - {'id': 1, 'dtid': '1234', 'username': 'username', 'country': None}, - ), - ( - {'dtid': '1234', 'username': 'username'}, - {'id': 1, 'dtid': '1234', 'username': 'username', 'country': None}, - ), + ( + {'id': 1, 'dtid': '1234', 'username': 'username', 'country': 'ES'}, + {'id': 1, 'dtid': '1234', 'username': 'username', 'country': 'ES'}, + ), + ( + {'id': 1, 'dtid': '1234', 'username': 'username', 'country': None}, + {'id': 1, 'dtid': '1234', 'username': 'username', 'country': None}, + ), + ( + {'id': 1, 'dtid': '1234', 'username': 'username'}, + {'id': 1, 'dtid': '1234', 'username': 'username', 'country': None}, + ), + ( + {'dtid': '1234', 'username': 'username'}, + {'id': 1, 'dtid': '1234', 'username': 'username', 'country': None}, + ), ), ids=[ 'all_values', 'all_values_with_null', @@ -127,10 +127,10 @@ async def test_execute_and_first(db_conn, data_dict, expected_result): @pytest.mark.asyncio @pytest.mark.parametrize('user_dict, raises_exception', ( - ({'id': 1}, False), - ({'username': 'Username 1'}, False), - ({'dtid': '00000001-0001-0001-0001-0000000001'}, False), - ({'country': 'Country 1'}, ValueError), + ({'id': 1}, False), + ({'username': 'Username 1'}, False), + ({'dtid': '00000001-0001-0001-0001-0000000001'}, False), + ({'country': 'Country 1'}, ValueError), ), ids=['by_id', 'by_usernam', 'by_dtid', 'failing_by_country']) async def test_get_user(db_conn, user_generator, user_dict, raises_exception): user = await user_generator() @@ -143,7 +143,7 @@ async def test_get_user(db_conn, user_generator, user_dict, raises_exception): @pytest.mark.parametrize('user_dict, raises_exception', ( - ({}, False), + ({}, False), )) @pytest.mark.asyncio async def test_save_user(db_conn, user_dict, raises_exception): @@ -203,11 +203,11 @@ async def test_get_or_save_user( @pytest.mark.parametrize('track_dict, raises_exception', ( - ({'id': 1}, False), - ({'extid': 'Extid 1'}, False), - ({'origin': 'youtube'}, AssertionError), - ({'length': 120}, AssertionError), - ({'extid': 'Extid 1', 'origin': 'youtube'}, False), + ({'id': 1}, False), + ({'extid': 'Extid 1'}, False), + ({'origin': 'youtube'}, AssertionError), + ({'length': 120}, AssertionError), + ({'extid': 'Extid 1', 'origin': 'youtube'}, False), ), ids=['by_id', 'by_extid', 'by_origin', 'by_length', 'by_extid+origin'], ) @pytest.mark.asyncio async def test_get_track(db_conn, track_generator, track_dict, raises_exception): @@ -282,10 +282,10 @@ async def test_get_or_save_track( @pytest.mark.parametrize('playback_dict, raises_exception', ( - ({'id': 1}, False), - ({'start': datetime.datetime(year=1, month=1, day=1)}, False), - ({'track_id': 1}, ValueError), - ({'user_id': 1}, ValueError), + ({'id': 1}, False), + ({'start': datetime.datetime(year=1, month=1, day=1)}, False), + ({'track_id': 1}, ValueError), + ({'user_id': 1}, ValueError), ), ids=['by_id', 'by_start', 'by_track_id', 'by_user_id'], ) @pytest.mark.asyncio async def test_get_playback( @@ -364,10 +364,10 @@ async def test_get_or_save_playback( @pytest.mark.parametrize('user_action_dict, raises_exception', ( - ({'id': 1}, False), - ({'ts': datetime.datetime(year=1, month=1, day=1)}, ValueError), - ({'track_id': 1}, ValueError), - ({'user_id': 1}, ValueError), + ({'id': 1}, False), + ({'ts': datetime.datetime(year=1, month=1, day=1)}, ValueError), + ({'track_id': 1}, ValueError), + ({'user_id': 1}, ValueError), ), ids=['by_id', 'by_start', 'by_track_id', 'by_user_id'], ) @pytest.mark.asyncio async def test_get_user_action( @@ -431,12 +431,12 @@ async def test_save_user_action( @pytest.mark.parametrize('bot_data, raises_exception', ( - (('id', 1), False), - (('extid', 'Extid 1'), False), - (('origin', {'a': 1, 'b': 2}), False), - (('length', {1, 2, 3, 120}), TypeError), - (('other', [1, 2, 3, 120]), False), - (('extid', None), False), + (('id', 1), False), + (('extid', 'Extid 1'), False), + (('origin', {'a': 1, 'b': 2}), False), + (('length', {1, 2, 3, 120}), TypeError), + (('other', [1, 2, 3, 120]), False), + (('extid', None), False), ), ids=['number', 'string', 'dict', 'set', 'list', 'null'], ) @pytest.mark.asyncio async def test_bot_data(db_conn, bot_data, raises_exception): @@ -518,13 +518,13 @@ async def test_user_dub_user_actions( @pytest.mark.parametrize('input,output', ( - ('upvote', Action.upvote), - ('updub', Action.upvote), - ('updubs', Action.upvote), - ('downvote', Action.downvote), - ('downdub', Action.downvote), - ('downdubs', Action.downvote), - ('none', None), + ('upvote', Action.upvote), + ('updub', Action.upvote), + ('updubs', Action.upvote), + ('downvote', Action.downvote), + ('downdub', Action.downvote), + ('downdubs', Action.downvote), + ('none', None), )) def test_get_dub_action(input, output): if output is None: @@ -535,13 +535,13 @@ def test_get_dub_action(input, output): @pytest.mark.parametrize('input,output', ( - ('upvote', Action.downvote), - ('updub', Action.downvote), - ('updubs', Action.downvote), - ('downvote', Action.upvote), - ('downdub', Action.upvote), - ('downdubs', Action.upvote), - ('none', None), + ('upvote', Action.downvote), + ('updub', Action.downvote), + ('updubs', Action.downvote), + ('downvote', Action.upvote), + ('downdub', Action.upvote), + ('downdubs', Action.upvote), + ('none', None), )) def test_get_opposite_dub_action(input, output): if output is None: @@ -552,11 +552,15 @@ def test_get_opposite_dub_action(input, output): @pytest.mark.parametrize('loops, output', ( - (1, {Action.upvote, }), - (2, {Action.downvote, }), - (3, {Action.upvote, }), - pytest.param(4, {Action.upvote, Action.skip}, marks=pytest.mark.xfail( - reason='The query is not complete enough to only aggregate upvotes and downvotes')), + (1, {(0, Action.upvote.name), }), + (2, {(0, Action.downvote.name), }), + (3, {(0, Action.upvote.name), }), + (4, {(1, Action.upvote.name), (0, Action.upvote.name), }), + (5, {(2, Action.upvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }), + (6, {(2, Action.downvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }), + (7, {(2, Action.upvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }), + (8, {(3, Action.downvote.name), (2, Action.upvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }), + (9, {(4, Action.skip.name), (3, Action.downvote.name), (2, Action.upvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }), )) @pytest.mark.asyncio async def test_query_simplified_user_actions( @@ -568,13 +572,15 @@ async def test_query_simplified_user_actions( loops, output ): - actions = ['upvote', 'downvote', 'upvote', 'skip'] + # User: 1 1 1 2 3 1 1 4 1 + actions = ['upvote', 'downvote', 'upvote', 'upvote', 'upvote', 'downvote', 'upvote', 'downvote', 'skip'] track = await track_generator() user = await user_generator() playback = await playback_generator(user=user, track=track) - for _, action in zip(range(loops), actions): - await user_action_generator(user=user, playback=playback, action=action) + for i, action in zip(range(loops), actions): + generated_user = user if i not in [3, 4, 7] else await user_generator() + await user_action_generator(user=generated_user, playback=playback, action=action) user_actions = await query_simplified_user_actions(playback_id=playback['id'], conn=db_conn) - result = {ua['action'] for ua in user_actions} + result = {(n, ua['action']) for n, ua in enumerate(user_actions)} assert result == output