Skip to content

Commit

Permalink
fix: query_simplified_user_actions and add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
dtgoitia committed Aug 9, 2018
1 parent 36c3987 commit 89114e4
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 55 deletions.
75 changes: 38 additions & 37 deletions mosbot/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import List, Optional

import sqlalchemy as sa
import sqlalchemy.sql.functions as saf
from asyncio_extras import async_contextmanager
from sqlalchemy.dialects import postgresql as psa

Expand Down Expand Up @@ -96,6 +95,7 @@ async def save_user(*, user_dict: dict, conn=None) -> dict:


async def get_or_save_user(*, user_dict: dict, conn=None) -> dict:
"""Try to retrieve a given user. If it doesn't exist, try to create it."""
user = await get_user(user_dict=user_dict, conn=conn)
if user:
return user
Expand Down Expand Up @@ -373,43 +373,44 @@ async def query_simplified_user_actions(playback_id, *, conn=None) -> List[dict]
:param conn: A connection if any open
:return: A list of the records
"""
sub_query = sa.select([
db.UserAction.c.user_id,
saf.max(db.UserAction.c.ts).label('ts'),
db.UserAction.c.playback_id,
]).where(
db.UserAction.c.playback_id == playback_id
).group_by(
db.UserAction.c.user_id,
db.UserAction.c.playback_id,
sa.case([
(db.UserAction.c.user_id.is_(None), db.UserAction.c.id),
], else_=0)
).alias()

query = sa.select([
sa.distinct(db.UserAction.c.id),
db.UserAction.c.action,
db.UserAction.c.playback_id,
db.UserAction.c.ts,
db.UserAction.c.user_id,
]).select_from(
db.UserAction.join(
sub_query,
sa.and_(
sub_query.c.ts == db.UserAction.c.ts,
db.UserAction.c.playback_id == sub_query.c.playback_id,
sa.case([
(sa.and_(
db.UserAction.c.user_id.is_(None),
sub_query.c.user_id.is_(None)
), sa.true())
], else_=db.UserAction.c.user_id == sub_query.c.user_id)
)
)
)
query = f"""
select *
from (
select
"song_title",
"username",
"action_timestamp",
"action",
"rn",
lag("action") over (partition by "username") as "next_action"
from (
select *
from (
select
p.id "playback_id",
u.username "username",
u.id "user_id",
t."name" "song_title",
ua.id "action_id",
ua."action" "action",
ua.ts "action_timestamp",
row_number() over (partition by ua.user_id, p.id order by u.username asc, ua.ts desc) rn
from playback p
left join user_action ua on p.id = ua.playback_id
left join track t on p.track_id = t.id
left join "user" u on ua.user_id = u.id
where p.id = {playback_id}
) sub
) sub2
) sub3
where
"rn" = 1 or
"next_action" = 'skip' or
"username" = null
order by "action_timestamp"
"""
async with ensure_connection(conn) as conn:
result = []
async for user_action in await conn.execute(query):
result.append(dict(user_action))
return result
return result
88 changes: 70 additions & 18 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,12 +551,61 @@ def test_get_opposite_dub_action(input, output):
assert output == get_opposite_dub_action(input)


@pytest.mark.parametrize('loops, output', (
(1, {Action.upvote, }),
(2, {Action.downvote, }),
(3, {Action.upvote, }),
pytest.param(4, {Action.upvote, Action.skip}, marks=pytest.mark.xfail(
reason='The query is not complete enough to only aggregate upvotes and downvotes')),
@pytest.mark.parametrize('user_action, expected_result', (
([
{'user': 1, 'action': 'upvote'},
], {
0: Action.upvote.name,
}),
([
{'user': 1, 'action': 'upvote'},
{'user': 1, 'action': 'downvote'},
], {
0: Action.downvote.name,
}),
([
{'user': 1, 'action': 'upvote'},
{'user': 1, 'action': 'skip'},
], {
0: Action.upvote.name,
1: Action.skip.name,
}),
([
{'user': 1, 'action': 'upvote'},
{'user': 1, 'action': 'downvote'},
{'user': 1, 'action': 'skip'},
], {
0: Action.downvote.name,
1: Action.skip.name,
}),
([
{'user': 1, 'action': 'upvote'},
{'user': None, 'action': 'upvote'},
], {
0: Action.upvote.name,
1: Action.upvote.name,
}),
([
{'user': 1, 'action': 'upvote'},
{'user': None, 'action': 'upvote'},
{'user': 1, 'action': 'downvote'},
{'user': 1, 'action': 'skip'},
], {
0: Action.upvote.name,
1: Action.downvote.name,
2: Action.skip.name,
}),
([
{'user': None, 'action': 'upvote'},
{'user': None, 'action': 'upvote'},
{'user': None, 'action': 'downvote'},
{'user': 1, 'action': 'skip'},
], {
0: Action.upvote.name,
1: Action.upvote.name,
2: Action.downvote.name,
3: Action.skip.name,
}),
))
@pytest.mark.asyncio
async def test_query_simplified_user_actions(
Expand All @@ -565,16 +614,19 @@ async def test_query_simplified_user_actions(
user_generator,
playback_generator,
user_action_generator,
loops,
output
):
actions = ['upvote', 'downvote', 'upvote', 'skip']
user_action,
expected_result):
track = await track_generator()
user = await user_generator()
playback = await playback_generator(user=user, track=track)
for _, action in zip(range(loops), actions):
await user_action_generator(user=user, playback=playback, action=action)

user_actions = await query_simplified_user_actions(playback_id=playback['id'], conn=db_conn)
result = {ua['action'] for ua in user_actions}
assert result == output
registered_user = await user_generator()
playback = await playback_generator(user=registered_user, track=track)

for action in user_action:
if action['user'] is None:
user = await user_generator()
elif action['user'] == 1:
user = registered_user
await user_action_generator(user=user, playback=playback, action=action['action'])

simplified_user_action = await query_simplified_user_actions(playback_id=playback['id'], conn=db_conn)
actual_result = {i: user_action['action'] for i, user_action in enumerate(simplified_user_action)}
assert expected_result == actual_result

0 comments on commit 89114e4

Please sign in to comment.