Skip to content

Commit

Permalink
fix: query_simplified_user_actions and add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
dtgoitia committed Mar 13, 2019
1 parent 19147d5 commit 18b4c2c
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 53 deletions.
72 changes: 36 additions & 36 deletions mosbot/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import List, Optional

import sqlalchemy as sa
import sqlalchemy.sql.functions as saf
from asyncio_extras import async_contextmanager
from sqlalchemy.dialects import postgresql as psa

Expand Down Expand Up @@ -387,41 +386,42 @@ async def query_simplified_user_actions(playback_id, *, conn=None) -> List[dict]
:param conn: A connection if any open
:return: A list of the records
"""
sub_query = sa.select([
db.UserAction.c.user_id,
saf.max(db.UserAction.c.ts).label('ts'),
db.UserAction.c.playback_id,
]).where(
db.UserAction.c.playback_id == playback_id
).group_by(
db.UserAction.c.user_id,
db.UserAction.c.playback_id,
sa.case([
(db.UserAction.c.user_id.is_(None), db.UserAction.c.id),
], else_=0)
).alias()

query = sa.select([
sa.distinct(db.UserAction.c.id),
db.UserAction.c.action,
db.UserAction.c.playback_id,
db.UserAction.c.ts,
db.UserAction.c.user_id,
]).select_from(
db.UserAction.join(
sub_query,
sa.and_(
sub_query.c.ts == db.UserAction.c.ts,
db.UserAction.c.playback_id == sub_query.c.playback_id,
sa.case([
(sa.and_(
db.UserAction.c.user_id.is_(None),
sub_query.c.user_id.is_(None)
), sa.true())
], else_=db.UserAction.c.user_id == sub_query.c.user_id)
)
)
)
query = f"""
select *
from (
select
"song_title",
"username",
"action_timestamp",
"action",
"rn",
lag("action") over (partition by "username") as "next_action"
from (
select *
from (
select
p.id "playback_id",
u.username "username",
u.id "user_id",
t."name" "song_title",
ua.id "action_id",
ua."action" "action",
ua.ts "action_timestamp",
row_number() over (partition by ua.user_id, p.id order by u.username asc, ua.ts desc) rn
from playback p
left join user_action ua on p.id = ua.playback_id
left join track t on p.track_id = t.id
left join "user" u on ua.user_id = u.id
where p.id = {playback_id}
) sub
) sub2
) sub3
where
"rn" = 1 or
"next_action" = 'skip' or
"username" = null
order by "action_timestamp"
"""
async with ensure_connection(conn) as conn:
result = []
async for user_action in await conn.execute(query):
Expand Down
85 changes: 68 additions & 17 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,12 +551,61 @@ def test_get_opposite_dub_action(input, output):
assert output == get_opposite_dub_action(input)


@pytest.mark.parametrize('loops, output', (
(1, {Action.upvote, }),
(2, {Action.downvote, }),
(3, {Action.upvote, }),
pytest.param(4, {Action.upvote, Action.skip}, marks=pytest.mark.xfail(
reason='The query is not complete enough to only aggregate upvotes and downvotes')),
@pytest.mark.parametrize('user_action, expected_result', (
([
{'user_id': 1, 'action': 'upvote'},
], {
0: Action.upvote.name,
}),
([
{'user_id': 1, 'action': 'upvote'},
{'user_id': 1, 'action': 'downvote'},
], {
0: Action.downvote.name,
}),
([
{'user_id': 1, 'action': 'upvote'},
{'user_id': 1, 'action': 'skip'},
], {
0: Action.upvote.name,
1: Action.skip.name,
}),
([
{'user_id': 1, 'action': 'upvote'},
{'user_id': 1, 'action': 'downvote'},
{'user_id': 1, 'action': 'skip'},
], {
0: Action.downvote.name,
1: Action.skip.name,
}),
([
{'user_id': 1, 'action': 'upvote'},
{'user_id': 2, 'action': 'upvote'},
], {
0: Action.upvote.name,
1: Action.upvote.name,
}),
([
{'user_id': 1, 'action': 'upvote'},
{'user_id': 2, 'action': 'upvote'},
{'user_id': 1, 'action': 'downvote'},
{'user_id': 1, 'action': 'skip'},
], {
0: Action.upvote.name,
1: Action.downvote.name,
2: Action.skip.name,
}),
([
{'user_id': 2, 'action': 'upvote'},
{'user_id': 3, 'action': 'upvote'},
{'user_id': 4, 'action': 'downvote'},
{'user_id': 1, 'action': 'skip'},
], {
0: Action.upvote.name,
1: Action.upvote.name,
2: Action.downvote.name,
3: Action.skip.name,
}),
))
@pytest.mark.asyncio
async def test_query_simplified_user_actions(
Expand All @@ -565,16 +614,18 @@ async def test_query_simplified_user_actions(
user_generator,
playback_generator,
user_action_generator,
loops,
output
):
actions = ['upvote', 'downvote', 'upvote', 'skip']
user_action,
expected_result):
track = await track_generator()
user = await user_generator()
playback = await playback_generator(user=user, track=track)
for _, action in zip(range(loops), actions):
await user_action_generator(user=user, playback=playback, action=action)
registered_user = await user_generator()
await user_generator()
await user_generator()
await user_generator()
playback = await playback_generator(user=registered_user, track=track)

user_actions = await query_simplified_user_actions(playback_id=playback['id'], conn=db_conn)
result = {ua['action'] for ua in user_actions}
assert result == output
for action in user_action:
await user_action_generator(user={'id': action['user_id']}, playback=playback, action=action['action'])

simplified_user_action = await query_simplified_user_actions(playback_id=playback['id'], conn=db_conn)
actual_result = {i: user_action['action'] for i, user_action in enumerate(simplified_user_action)}
assert expected_result == actual_result

0 comments on commit 18b4c2c

Please sign in to comment.