Skip to content

Commit

Permalink
fix: query_simplified_user_actions and add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
dtgoitia committed Aug 4, 2018
1 parent 36c3987 commit 153cba7
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 98 deletions.
72 changes: 37 additions & 35 deletions mosbot/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ async def save_user(*, user_dict: dict, conn=None) -> dict:


async def get_or_save_user(*, user_dict: dict, conn=None) -> dict:
"""Try to retrieve a given user. If it doesn't exist, try to create it."""
user = await get_user(user_dict=user_dict, conn=conn)
if user:
return user
Expand Down Expand Up @@ -373,41 +374,42 @@ async def query_simplified_user_actions(playback_id, *, conn=None) -> List[dict]
:param conn: A connection if any open
:return: A list of the records
"""
sub_query = sa.select([
db.UserAction.c.user_id,
saf.max(db.UserAction.c.ts).label('ts'),
db.UserAction.c.playback_id,
]).where(
db.UserAction.c.playback_id == playback_id
).group_by(
db.UserAction.c.user_id,
db.UserAction.c.playback_id,
sa.case([
(db.UserAction.c.user_id.is_(None), db.UserAction.c.id),
], else_=0)
).alias()

query = sa.select([
sa.distinct(db.UserAction.c.id),
db.UserAction.c.action,
db.UserAction.c.playback_id,
db.UserAction.c.ts,
db.UserAction.c.user_id,
]).select_from(
db.UserAction.join(
sub_query,
sa.and_(
sub_query.c.ts == db.UserAction.c.ts,
db.UserAction.c.playback_id == sub_query.c.playback_id,
sa.case([
(sa.and_(
db.UserAction.c.user_id.is_(None),
sub_query.c.user_id.is_(None)
), sa.true())
], else_=db.UserAction.c.user_id == sub_query.c.user_id)
)
)
)
query = f"""
select *
from (
select
"song_title",
"username",
"action_timestamp",
"action",
"rn",
lag("action") over (partition by "username") as "next_action"
from (
select *
from (
select
p.id "playback_id",
u.username "username",
u.id "user_id",
t."name" "song_title",
ua.id "action_id",
ua."action" "action",
ua.ts "action_timestamp",
row_number() over (partition by ua.user_id, p.id order by u.username asc, ua.ts desc) rn
from playback p
left join user_action ua on p.id = ua.playback_id
left join track t on p.track_id = t.id
left join "user" u on ua.user_id = u.id
where p.id = {playback_id}
) sub
) sub2
) sub3
where
"rn" = 1 or
"next_action" = 'skip' or
"username" = null
order by "action_timestamp"
"""
async with ensure_connection(conn) as conn:
result = []
async for user_action in await conn.execute(query):
Expand Down
132 changes: 69 additions & 63 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,22 +83,22 @@ async def test_ensure_connection(


@pytest.mark.parametrize('data_dict,expected_result', (
(
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': 'ES'},
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': 'ES'},
),
(
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': None},
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': None},
),
(
{'id': 1, 'dtid': '1234', 'username': 'username'},
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': None},
),
(
{'dtid': '1234', 'username': 'username'},
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': None},
),
(
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': 'ES'},
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': 'ES'},
),
(
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': None},
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': None},
),
(
{'id': 1, 'dtid': '1234', 'username': 'username'},
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': None},
),
(
{'dtid': '1234', 'username': 'username'},
{'id': 1, 'dtid': '1234', 'username': 'username', 'country': None},
),
), ids=[
'all_values',
'all_values_with_null',
Expand Down Expand Up @@ -127,10 +127,10 @@ async def test_execute_and_first(db_conn, data_dict, expected_result):

@pytest.mark.asyncio
@pytest.mark.parametrize('user_dict, raises_exception', (
({'id': 1}, False),
({'username': 'Username 1'}, False),
({'dtid': '00000001-0001-0001-0001-0000000001'}, False),
({'country': 'Country 1'}, ValueError),
({'id': 1}, False),
({'username': 'Username 1'}, False),
({'dtid': '00000001-0001-0001-0001-0000000001'}, False),
({'country': 'Country 1'}, ValueError),
), ids=['by_id', 'by_usernam', 'by_dtid', 'failing_by_country'])
async def test_get_user(db_conn, user_generator, user_dict, raises_exception):
user = await user_generator()
Expand All @@ -143,7 +143,7 @@ async def test_get_user(db_conn, user_generator, user_dict, raises_exception):


@pytest.mark.parametrize('user_dict, raises_exception', (
({}, False),
({}, False),
))
@pytest.mark.asyncio
async def test_save_user(db_conn, user_dict, raises_exception):
Expand Down Expand Up @@ -203,11 +203,11 @@ async def test_get_or_save_user(


@pytest.mark.parametrize('track_dict, raises_exception', (
({'id': 1}, False),
({'extid': 'Extid 1'}, False),
({'origin': 'youtube'}, AssertionError),
({'length': 120}, AssertionError),
({'extid': 'Extid 1', 'origin': 'youtube'}, False),
({'id': 1}, False),
({'extid': 'Extid 1'}, False),
({'origin': 'youtube'}, AssertionError),
({'length': 120}, AssertionError),
({'extid': 'Extid 1', 'origin': 'youtube'}, False),
), ids=['by_id', 'by_extid', 'by_origin', 'by_length', 'by_extid+origin'], )
@pytest.mark.asyncio
async def test_get_track(db_conn, track_generator, track_dict, raises_exception):
Expand Down Expand Up @@ -282,10 +282,10 @@ async def test_get_or_save_track(


@pytest.mark.parametrize('playback_dict, raises_exception', (
({'id': 1}, False),
({'start': datetime.datetime(year=1, month=1, day=1)}, False),
({'track_id': 1}, ValueError),
({'user_id': 1}, ValueError),
({'id': 1}, False),
({'start': datetime.datetime(year=1, month=1, day=1)}, False),
({'track_id': 1}, ValueError),
({'user_id': 1}, ValueError),
), ids=['by_id', 'by_start', 'by_track_id', 'by_user_id'], )
@pytest.mark.asyncio
async def test_get_playback(
Expand Down Expand Up @@ -364,10 +364,10 @@ async def test_get_or_save_playback(


@pytest.mark.parametrize('user_action_dict, raises_exception', (
({'id': 1}, False),
({'ts': datetime.datetime(year=1, month=1, day=1)}, ValueError),
({'track_id': 1}, ValueError),
({'user_id': 1}, ValueError),
({'id': 1}, False),
({'ts': datetime.datetime(year=1, month=1, day=1)}, ValueError),
({'track_id': 1}, ValueError),
({'user_id': 1}, ValueError),
), ids=['by_id', 'by_start', 'by_track_id', 'by_user_id'], )
@pytest.mark.asyncio
async def test_get_user_action(
Expand Down Expand Up @@ -431,12 +431,12 @@ async def test_save_user_action(


@pytest.mark.parametrize('bot_data, raises_exception', (
(('id', 1), False),
(('extid', 'Extid 1'), False),
(('origin', {'a': 1, 'b': 2}), False),
(('length', {1, 2, 3, 120}), TypeError),
(('other', [1, 2, 3, 120]), False),
(('extid', None), False),
(('id', 1), False),
(('extid', 'Extid 1'), False),
(('origin', {'a': 1, 'b': 2}), False),
(('length', {1, 2, 3, 120}), TypeError),
(('other', [1, 2, 3, 120]), False),
(('extid', None), False),
), ids=['number', 'string', 'dict', 'set', 'list', 'null'], )
@pytest.mark.asyncio
async def test_bot_data(db_conn, bot_data, raises_exception):
Expand Down Expand Up @@ -518,13 +518,13 @@ async def test_user_dub_user_actions(


@pytest.mark.parametrize('input,output', (
('upvote', Action.upvote),
('updub', Action.upvote),
('updubs', Action.upvote),
('downvote', Action.downvote),
('downdub', Action.downvote),
('downdubs', Action.downvote),
('none', None),
('upvote', Action.upvote),
('updub', Action.upvote),
('updubs', Action.upvote),
('downvote', Action.downvote),
('downdub', Action.downvote),
('downdubs', Action.downvote),
('none', None),
))
def test_get_dub_action(input, output):
if output is None:
Expand All @@ -535,13 +535,13 @@ def test_get_dub_action(input, output):


@pytest.mark.parametrize('input,output', (
('upvote', Action.downvote),
('updub', Action.downvote),
('updubs', Action.downvote),
('downvote', Action.upvote),
('downdub', Action.upvote),
('downdubs', Action.upvote),
('none', None),
('upvote', Action.downvote),
('updub', Action.downvote),
('updubs', Action.downvote),
('downvote', Action.upvote),
('downdub', Action.upvote),
('downdubs', Action.upvote),
('none', None),
))
def test_get_opposite_dub_action(input, output):
if output is None:
Expand All @@ -552,11 +552,15 @@ def test_get_opposite_dub_action(input, output):


@pytest.mark.parametrize('loops, output', (
(1, {Action.upvote, }),
(2, {Action.downvote, }),
(3, {Action.upvote, }),
pytest.param(4, {Action.upvote, Action.skip}, marks=pytest.mark.xfail(
reason='The query is not complete enough to only aggregate upvotes and downvotes')),
(1, {(0, Action.upvote.name), }),
(2, {(0, Action.downvote.name), }),
(3, {(0, Action.upvote.name), }),
(4, {(1, Action.upvote.name), (0, Action.upvote.name), }),
(5, {(2, Action.upvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }),
(6, {(2, Action.downvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }),
(7, {(2, Action.upvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }),
(8, {(3, Action.downvote.name), (2, Action.upvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }),
(9, {(4, Action.skip.name), (3, Action.downvote.name), (2, Action.upvote.name), (1, Action.upvote.name), (0, Action.upvote.name), }),
))
@pytest.mark.asyncio
async def test_query_simplified_user_actions(
Expand All @@ -568,13 +572,15 @@ async def test_query_simplified_user_actions(
loops,
output
):
actions = ['upvote', 'downvote', 'upvote', 'skip']
# User: 1 1 1 2 3 1 1 4 1
actions = ['upvote', 'downvote', 'upvote', 'upvote', 'upvote', 'downvote', 'upvote', 'downvote', 'skip']
track = await track_generator()
user = await user_generator()
playback = await playback_generator(user=user, track=track)
for _, action in zip(range(loops), actions):
await user_action_generator(user=user, playback=playback, action=action)
for i, action in zip(range(loops), actions):
generated_user = user if i not in [3, 4, 7] else await user_generator()
await user_action_generator(user=generated_user, playback=playback, action=action)

user_actions = await query_simplified_user_actions(playback_id=playback['id'], conn=db_conn)
result = {ua['action'] for ua in user_actions}
result = {(n, ua['action']) for n, ua in enumerate(user_actions)}
assert result == output

0 comments on commit 153cba7

Please sign in to comment.