diff --git a/mosbot/query.py b/mosbot/query.py index bb7bab7..9dd2631 100644 --- a/mosbot/query.py +++ b/mosbot/query.py @@ -11,7 +11,6 @@ from typing import List, Optional import sqlalchemy as sa -import sqlalchemy.sql.functions as saf from asyncio_extras import async_contextmanager from sqlalchemy.dialects import postgresql as psa @@ -96,6 +95,7 @@ async def save_user(*, user_dict: dict, conn=None) -> dict: async def get_or_save_user(*, user_dict: dict, conn=None) -> dict: + """Try to retrieve a given user. If it doesn't exist, try to create it.""" user = await get_user(user_dict=user_dict, conn=conn) if user: return user @@ -373,41 +373,42 @@ async def query_simplified_user_actions(playback_id, *, conn=None) -> List[dict] :param conn: A connection if any open :return: A list of the records """ - sub_query = sa.select([ - db.UserAction.c.user_id, - saf.max(db.UserAction.c.ts).label('ts'), - db.UserAction.c.playback_id, - ]).where( - db.UserAction.c.playback_id == playback_id - ).group_by( - db.UserAction.c.user_id, - db.UserAction.c.playback_id, - sa.case([ - (db.UserAction.c.user_id.is_(None), db.UserAction.c.id), - ], else_=0) - ).alias() - - query = sa.select([ - sa.distinct(db.UserAction.c.id), - db.UserAction.c.action, - db.UserAction.c.playback_id, - db.UserAction.c.ts, - db.UserAction.c.user_id, - ]).select_from( - db.UserAction.join( - sub_query, - sa.and_( - sub_query.c.ts == db.UserAction.c.ts, - db.UserAction.c.playback_id == sub_query.c.playback_id, - sa.case([ - (sa.and_( - db.UserAction.c.user_id.is_(None), - sub_query.c.user_id.is_(None) - ), sa.true()) - ], else_=db.UserAction.c.user_id == sub_query.c.user_id) - ) - ) - ) + query = f""" + select * + from ( + select + "song_title", + "username", + "action_timestamp", + "action", + "rn", + lag("action") over (partition by "username") as "next_action" + from ( + select * + from ( + select + p.id "playback_id", + u.username "username", + u.id "user_id", + t."name" "song_title", + ua.id "action_id", + ua."action" "action", + ua.ts "action_timestamp", + row_number() over (partition by ua.user_id, p.id order by u.username asc, ua.ts desc) rn + from playback p + left join user_action ua on p.id = ua.playback_id + left join track t on p.track_id = t.id + left join "user" u on ua.user_id = u.id + where p.id = {playback_id} + ) sub + ) sub2 + ) sub3 + where + "rn" = 1 or + "next_action" = 'skip' or + "username" = null + order by "action_timestamp" + """ async with ensure_connection(conn) as conn: result = [] async for user_action in await conn.execute(query): diff --git a/tests/test_query.py b/tests/test_query.py index 658d3eb..83f37c3 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -551,12 +551,61 @@ def test_get_opposite_dub_action(input, output): assert output == get_opposite_dub_action(input) -@pytest.mark.parametrize('loops, output', ( - (1, {Action.upvote, }), - (2, {Action.downvote, }), - (3, {Action.upvote, }), - pytest.param(4, {Action.upvote, Action.skip}, marks=pytest.mark.xfail( - reason='The query is not complete enough to only aggregate upvotes and downvotes')), +@pytest.mark.parametrize('user_action, expected_result', ( + ([ + {'user': 1, 'action': 'upvote'}, + ], { + 0: Action.upvote.name, + }), + ([ + {'user': 1, 'action': 'upvote'}, + {'user': 1, 'action': 'downvote'}, + ], { + 0: Action.downvote.name, + }), + ([ + {'user': 1, 'action': 'upvote'}, + {'user': 1, 'action': 'skip'}, + ], { + 0: Action.upvote.name, + 1: Action.skip.name, + }), + ([ + {'user': 1, 'action': 'upvote'}, + {'user': 1, 'action': 'downvote'}, + {'user': 1, 'action': 'skip'}, + ], { + 0: Action.downvote.name, + 1: Action.skip.name, + }), + ([ + {'user': 1, 'action': 'upvote'}, + {'user': None, 'action': 'upvote'}, + ], { + 0: Action.upvote.name, + 1: Action.upvote.name, + }), + ([ + {'user': 1, 'action': 'upvote'}, + {'user': None, 'action': 'upvote'}, + {'user': 1, 'action': 'downvote'}, + {'user': 1, 'action': 'skip'}, + ], { + 0: Action.upvote.name, + 1: Action.downvote.name, + 2: Action.skip.name, + }), + ([ + {'user': None, 'action': 'upvote'}, + {'user': None, 'action': 'upvote'}, + {'user': None, 'action': 'downvote'}, + {'user': 1, 'action': 'skip'}, + ], { + 0: Action.upvote.name, + 1: Action.upvote.name, + 2: Action.downvote.name, + 3: Action.skip.name, + }), )) @pytest.mark.asyncio async def test_query_simplified_user_actions( @@ -565,16 +614,19 @@ async def test_query_simplified_user_actions( user_generator, playback_generator, user_action_generator, - loops, - output -): - actions = ['upvote', 'downvote', 'upvote', 'skip'] + user_action, + expected_result): track = await track_generator() - user = await user_generator() - playback = await playback_generator(user=user, track=track) - for _, action in zip(range(loops), actions): - await user_action_generator(user=user, playback=playback, action=action) - - user_actions = await query_simplified_user_actions(playback_id=playback['id'], conn=db_conn) - result = {ua['action'] for ua in user_actions} - assert result == output + registered_user = await user_generator() + playback = await playback_generator(user=registered_user, track=track) + + for action in user_action: + if action['user'] is None: + user = await user_generator() + elif action['user'] == 1: + user = registered_user + await user_action_generator(user=user, playback=playback, action=action['action']) + + simplified_user_action = await query_simplified_user_actions(playback_id=playback['id'], conn=db_conn) + actual_result = {i: user_action['action'] for i, user_action in enumerate(simplified_user_action)} + assert expected_result == actual_result