diff --git a/database/repositories/achievements.py b/database/repositories/achievements.py index 5ed46f6..d3dac26 100644 --- a/database/repositories/achievements.py +++ b/database/repositories/achievements.py @@ -1,29 +1,32 @@ +from __future__ import annotations + from app.common.database.objects import DBAchievement from app.common.objects import bAchievement +from .wrapper import session_wrapper +from sqlalchemy.orm import Session from typing import List -import app - +@session_wrapper def create_many( achievements: List[bAchievement], - user_id: int + user_id: int, + session: Session | None = None ) -> None: - with app.session.database.managed_session() as session: - for a in achievements: - session.add( - DBAchievement( - user_id, - a.name, - a.category, - a.filename - ) + for a in achievements: + session.add( + DBAchievement( + user_id, + a.name, + a.category, + a.filename ) - session.commit() + ) + session.commit() -def fetch_many(user_id: int) -> List[DBAchievement]: - with app.session.database.managed_session() as session: - return session.query(DBAchievement) \ - .filter(DBAchievement.user_id == user_id) \ - .all() +@session_wrapper +def fetch_many(user_id: int, session: Session | None = None) -> List[DBAchievement]: + return session.query(DBAchievement) \ + .filter(DBAchievement.user_id == user_id) \ + .all() diff --git a/database/repositories/activities.py b/database/repositories/activities.py index a2df85c..3c294e0 100644 --- a/database/repositories/activities.py +++ b/database/repositories/activities.py @@ -1,41 +1,45 @@ +from __future__ import annotations + from app.common.database.objects import DBActivity from datetime import datetime, timedelta +from sqlalchemy.orm import Session from typing import List -import app +from .wrapper import session_wrapper +@session_wrapper def create( user_id: int, mode: int, text: str, args: str, - links: str + links: str, + session: Session | None = None ) -> DBActivity: - with app.session.database.managed_session() as session: - session.add( - ac := DBActivity( - user_id, - mode, - text, - args, - links - ) + session.add( + ac := DBActivity( + user_id, + mode, + text, + args, + links ) - session.commit() - session.refresh(ac) - + ) + session.commit() + session.refresh(ac) return ac +@session_wrapper def fetch_recent( user_id: int, mode: int, - until: timedelta = timedelta(days=30) + until: timedelta = timedelta(days=30), + session: Session | None = None ) -> List[DBActivity]: - with app.session.database.managed_session() as session: - return session.query(DBActivity) \ - .filter(DBActivity.time > datetime.now() - until) \ - .filter(DBActivity.user_id == user_id) \ - .filter(DBActivity.mode == mode) \ - .order_by(DBActivity.id.desc()) \ - .all() + return session.query(DBActivity) \ + .filter(DBActivity.time > datetime.now() - until) \ + .filter(DBActivity.user_id == user_id) \ + .filter(DBActivity.mode == mode) \ + .order_by(DBActivity.id.desc()) \ + .all() diff --git a/database/repositories/beatmaps.py b/database/repositories/beatmaps.py index b1b7258..acb3ffb 100644 --- a/database/repositories/beatmaps.py +++ b/database/repositories/beatmaps.py @@ -1,13 +1,19 @@ +from __future__ import annotations + from app.common.database.objects import DBBeatmap from sqlalchemy.orm import selectinload from sqlalchemy import func +from .wrapper import session_wrapper + +from sqlalchemy.orm import Session from typing import Optional, List from datetime import datetime import app +@session_wrapper def create( id: int, set_id: int, @@ -25,72 +31,71 @@ def create( ar: float, od: float, hp: float, - diff: float + diff: float, + session: Session | None = None ) -> DBBeatmap: - with app.session.database.managed_session() as session: - session.add( - m := DBBeatmap( - id, - set_id, - mode, - md5, - status, - version, - filename, - created_at, - last_update, - total_length, - max_combo, - bpm, - cs, - ar, - od, - hp, - diff - ) + session.add( + m := DBBeatmap( + id, + set_id, + mode, + md5, + status, + version, + filename, + created_at, + last_update, + total_length, + max_combo, + bpm, + cs, + ar, + od, + hp, + diff ) - session.commit() - session.refresh(m) - + ) + session.commit() + session.refresh(m) return m -def fetch_by_id(id: int) -> Optional[DBBeatmap]: - with app.session.database.managed_session() as session: - return session.query(DBBeatmap) \ - .options(selectinload(DBBeatmap.beatmapset)) \ - .filter(DBBeatmap.id == id) \ - .first() +@session_wrapper +def fetch_by_id(id: int, session: Session | None = None) -> Optional[DBBeatmap]: + return session.query(DBBeatmap) \ + .options(selectinload(DBBeatmap.beatmapset)) \ + .filter(DBBeatmap.id == id) \ + .first() -def fetch_by_file(filename: str) -> Optional[DBBeatmap]: - with app.session.database.managed_session() as session: - return session.query(DBBeatmap) \ - .options(selectinload(DBBeatmap.beatmapset)) \ - .filter(DBBeatmap.filename == filename) \ - .first() +@session_wrapper +def fetch_by_file(filename: str, session: Session | None = None) -> Optional[DBBeatmap]: + return session.query(DBBeatmap) \ + .options(selectinload(DBBeatmap.beatmapset)) \ + .filter(DBBeatmap.filename == filename) \ + .first() -def fetch_by_checksum(checksum: str) -> Optional[DBBeatmap]: - with app.session.database.managed_session() as session: - return session.query(DBBeatmap) \ - .options(selectinload(DBBeatmap.beatmapset)) \ - .filter(DBBeatmap.md5 == checksum) \ - .first() +@session_wrapper +def fetch_by_checksum(checksum: str, session: Session | None = None) -> Optional[DBBeatmap]: + return session.query(DBBeatmap) \ + .options(selectinload(DBBeatmap.beatmapset)) \ + .filter(DBBeatmap.md5 == checksum) \ + .first() -def fetch_by_set(set_id: int) -> List[DBBeatmap]: - with app.session.database.managed_session() as session: - return session.query(DBBeatmap) \ - .filter(DBBeatmap.set_id == set_id) \ - .all() +@session_wrapper +def fetch_by_set(set_id: int, session: Session | None = None) -> List[DBBeatmap]: + return session.query(DBBeatmap) \ + .filter(DBBeatmap.set_id == set_id) \ + .all() -def fetch_count() -> int: - with app.session.database.managed_session() as session: - return session.query(func.count(DBBeatmap.id)) \ - .scalar() +@session_wrapper +def fetch_count(session: Session | None = None) -> int: + return session.query(func.count(DBBeatmap.id)) \ + .scalar() -def update(beatmap_id: int, updates: dict) -> int: - with app.session.database.managed_session() as session: - rows = session.query(DBBeatmap) \ - .filter(DBBeatmap.id == beatmap_id) \ - .update(updates) - session.commit() +@session_wrapper +def update(beatmap_id: int, updates: dict, session: Session | None = None) -> int: + rows = session.query(DBBeatmap) \ + .filter(DBBeatmap.id == beatmap_id) \ + .update(updates) + session.commit() return rows diff --git a/database/repositories/beatmapsets.py b/database/repositories/beatmapsets.py index 0071efe..b854a37 100644 --- a/database/repositories/beatmapsets.py +++ b/database/repositories/beatmapsets.py @@ -1,4 +1,6 @@ +from __future__ import annotations + from app.common.constants import ( BeatmapCategory, BeatmapSortBy, @@ -13,16 +15,14 @@ DBPlay ) -from ...helpers.caching import ttl_cache - -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import selectinload, Session from sqlalchemy import func, or_, and_ +from .wrapper import session_wrapper from typing import Optional, List from datetime import datetime -import app - +@session_wrapper def create( id: int, title: str, @@ -41,123 +41,127 @@ def create( osz_filesize: int = 0, osz_filesize_novideo: int = 0, available: bool = True, - server: int = 0 + server: int = 0, + session: Session | None = None ) -> DBBeatmapset: - with app.session.database.managed_session() as session: - session.add( - s := DBBeatmapset( - id, - title, - artist, - creator, - source, - tags, - status, - has_video, - has_storyboard, - created_at, - approved_at, - last_update, - language_id, - genre_id, - osz_filesize, - osz_filesize_novideo, - available, - server - ) + session.add( + s := DBBeatmapset( + id, + title, + artist, + creator, + source, + tags, + status, + has_video, + has_storyboard, + created_at, + approved_at, + last_update, + language_id, + genre_id, + osz_filesize, + osz_filesize_novideo, + available, + server ) - session.commit() - session.refresh(s) - + ) + session.commit() + session.refresh(s) return s -def fetch_one(id: int) -> Optional[DBBeatmapset]: - with app.session.database.managed_session() as session: - return session.query(DBBeatmapset) \ - .filter(DBBeatmapset.id == id) \ - .first() +@session_wrapper +def fetch_one(id: int, session: Session | None = None) -> Optional[DBBeatmapset]: + return session.query(DBBeatmapset) \ + .filter(DBBeatmapset.id == id) \ + .first() +@session_wrapper def search( query_string: str, user_id: int, - display_mode = DisplayMode.All + display_mode = DisplayMode.All, + session: Session | None = None ) -> List[DBBeatmapset]: - with app.session.database.managed_session() as session: - query = session.query(DBBeatmapset) - - if query_string == 'Newest': - query = query.order_by(DBBeatmapset.created_at.desc()) - - elif query_string == 'Top Rated': - query = query.join(DBRating) \ - .group_by(DBBeatmapset.id) \ - .order_by(func.avg(DBRating.rating).desc()) - - elif query_string == 'Most Played': - query = query.join(DBBeatmap) \ - .group_by(DBBeatmapset.id) \ - .order_by(func.sum(DBBeatmap.playcount).desc()) - - else: - conditions = [] - - keywords = [ - f'{word}%' for word in query_string.strip() \ - .replace(' - ', ' ') \ - .lower() \ - .split() - ] + query = session.query(DBBeatmapset) - searchable_columns = [ - func.to_tsvector('simple', column) - for column in [ - func.lower(DBBeatmapset.title), - func.lower(DBBeatmapset.artist), - func.lower(DBBeatmapset.creator), - func.lower(DBBeatmapset.source), - func.lower(DBBeatmapset.tags), - func.lower(DBBeatmap.version) - ] + if query_string == 'Newest': + query = query.order_by(DBBeatmapset.created_at.desc()) + + elif query_string == 'Top Rated': + query = query.join(DBRating) \ + .group_by(DBBeatmapset.id) \ + .order_by(func.avg(DBRating.rating).desc()) + + elif query_string == 'Most Played': + query = query.join(DBBeatmap) \ + .group_by(DBBeatmapset.id) \ + .order_by(func.sum(DBBeatmap.playcount).desc()) + + else: + conditions = [] + + keywords = [ + f'%{word}%' for word in query_string.strip() \ + .replace(' - ', ' ') \ + .lower() \ + .split() + ] + + searchable_columns = [ + func.to_tsvector('simple', column) + for column in [ + func.lower(DBBeatmapset.title), + func.lower(DBBeatmapset.artist), + func.lower(DBBeatmapset.creator), + func.lower(DBBeatmapset.source), + func.lower(DBBeatmapset.tags), + func.lower(DBBeatmap.version) ] + ] - for word in keywords: - conditions.append(or_( - *[ - col.op('@@')(func.plainto_tsquery('simple', word)) - for col in searchable_columns - ] - )) + for word in keywords: + conditions.append(or_( + *[ + col.op('@@')(func.plainto_tsquery('simple', word)) + for col in searchable_columns + ] + )) - query = query.join(DBBeatmap) \ - .filter(and_(*conditions)) \ - .order_by(DBBeatmap.playcount.desc()) + query = query.join(DBBeatmap) \ + .filter(and_(*conditions)) \ + .order_by(DBBeatmap.playcount.desc()) - if display_mode == DisplayMode.Ranked: - query = query.filter(DBBeatmapset.status > 0) + if display_mode == DisplayMode.Ranked: + query = query.filter(DBBeatmapset.status > 0) - elif display_mode == DisplayMode.Pending: - query = query.filter(DBBeatmapset.status == 0) + elif display_mode == DisplayMode.Pending: + query = query.filter(DBBeatmapset.status == 0) - elif display_mode == DisplayMode.Graveyard: - query = query.filter(DBBeatmapset.status == -1) + elif display_mode == DisplayMode.Graveyard: + query = query.filter(DBBeatmapset.status == -1) - elif display_mode == DisplayMode.Played: - query = query.join(DBPlay) \ - .filter(DBPlay.user_id == user_id) \ - .filter(DBBeatmapset.status > 0) + elif display_mode == DisplayMode.Played: + query = query.join(DBPlay) \ + .filter(DBPlay.user_id == user_id) \ + .filter(DBBeatmapset.status > 0) - return query.limit(100) \ - .options( - selectinload(DBBeatmapset.beatmaps), - selectinload(DBBeatmapset.ratings) - ) \ - .all() + return query.limit(100) \ + .options( + selectinload(DBBeatmapset.beatmaps), + selectinload(DBBeatmapset.ratings) + ).all() -def search_one(query_string: str, offset: int = 0) -> Optional[DBBeatmapset]: +@session_wrapper +def search_one( + query_string: str, + offset: int = 0, + session: Session | None = None +) -> Optional[DBBeatmapset]: conditions = [] keywords = [ - f'{word}%' for word in query_string.strip() \ + f'%{word}%' for word in query_string.strip() \ .replace(' - ', ' ') \ .lower() \ .split() @@ -183,13 +187,14 @@ def search_one(query_string: str, offset: int = 0) -> Optional[DBBeatmapset]: ] )) - with app.session.database.managed_session() as session: - return session.query(DBBeatmapset) \ - .join(DBBeatmap) \ - .filter(and_(*conditions)) \ - .order_by(DBBeatmap.playcount.desc()) \ - .first() + return session.query(DBBeatmapset) \ + .join(DBBeatmap) \ + .filter(and_(*conditions)) \ + .order_by(DBBeatmap.playcount.desc()) \ + .offset(offset) \ + .first() +@session_wrapper def search_extended( query_string: Optional[str], genre: Optional[int], @@ -203,99 +208,99 @@ def search_extended( has_storyboard: bool, has_video: bool, offset: int = 0, - limit: int = 50 + limit: int = 50, + session: Session | None = None ) -> List[DBBeatmapset]: - with app.session.database.managed_session() as session: - query = session.query(DBBeatmapset) \ - .options( - selectinload(DBBeatmapset.beatmaps), - selectinload(DBBeatmapset.ratings), - selectinload(DBBeatmapset.favourites) - ) \ - .group_by(DBBeatmapset.id) \ - .join(DBBeatmap) - - if query_string: - conditions = [] - - keywords = [ - f'%{word}%' for word in query_string.strip() \ - .replace(' - ', ' ') \ - .lower() \ - .split() - ] + query = session.query(DBBeatmapset) \ + .options( + selectinload(DBBeatmapset.beatmaps), + selectinload(DBBeatmapset.ratings), + selectinload(DBBeatmapset.favourites) + ) \ + .group_by(DBBeatmapset.id) \ + .join(DBBeatmap) + + if query_string: + conditions = [] + + keywords = [ + f'%{word}%' for word in query_string.strip() \ + .replace(' - ', ' ') \ + .lower() \ + .split() + ] - searchable_columns = [ - func.to_tsvector('simple', column) - for column in [ - func.lower(DBBeatmapset.title), - func.lower(DBBeatmapset.artist), - func.lower(DBBeatmapset.creator), - func.lower(DBBeatmapset.source), - func.lower(DBBeatmapset.tags), - func.lower(DBBeatmap.version) - ] + searchable_columns = [ + func.to_tsvector('simple', column) + for column in [ + func.lower(DBBeatmapset.title), + func.lower(DBBeatmapset.artist), + func.lower(DBBeatmapset.creator), + func.lower(DBBeatmapset.source), + func.lower(DBBeatmapset.tags), + func.lower(DBBeatmap.version) ] + ] - for word in keywords: - conditions.append(or_( - *[ - col.op('@@')(func.plainto_tsquery('simple', word)) - for col in searchable_columns - ] - )) - - query = query.filter(and_(*conditions)) - - if sort == BeatmapSortBy.Rating: - query = query.join(DBRating) - - order_type = { - BeatmapSortBy.Created: DBBeatmapset.id, - BeatmapSortBy.Title: DBBeatmapset.title, - BeatmapSortBy.Artist: DBBeatmapset.artist, - BeatmapSortBy.Creator: DBBeatmapset.creator, - BeatmapSortBy.Ranked: DBBeatmapset.approved_at, - BeatmapSortBy.Difficulty: func.max(DBBeatmap.diff), - BeatmapSortBy.Rating: func.avg(DBRating.rating), - BeatmapSortBy.Plays: func.sum(DBBeatmap.playcount), - }[sort] - - query = query.order_by( - order_type.asc() if order == BeatmapOrder.Ascending else - order_type.desc() - ) - - if genre is not None: - query = query.filter(DBBeatmapset.genre_id == genre) - - if language is not None: - query = query.filter(DBBeatmapset.language_id == language) - - if mode is not None: - query = query.filter(DBBeatmapset.beatmaps.any(DBBeatmap.mode == mode)) - - if has_storyboard: - query = query.filter(DBBeatmapset.has_storyboard == True) - - if has_video: - query = query.filter(DBBeatmapset.has_video == True) - - if (played is not None and - user_id is not None): - query = query.join(DBPlay) \ - .filter(DBPlay.user_id == user_id) - - if category > BeatmapCategory.Any: - query = query.filter({ - BeatmapCategory.Leaderboard: (DBBeatmapset.status > 0), - BeatmapCategory.Pending: (DBBeatmapset.status == 0), - BeatmapCategory.Ranked: (DBBeatmapset.status == 1), - BeatmapCategory.Approved: (DBBeatmapset.status == 2), - BeatmapCategory.Qualified: (DBBeatmapset.status == 3), - BeatmapCategory.Loved: (DBBeatmapset.status == 4), - }[category]) - - return query.offset(offset) \ - .limit(limit) \ - .all() + for word in keywords: + conditions.append(or_( + *[ + col.op('@@')(func.plainto_tsquery('simple', word)) + for col in searchable_columns + ] + )) + + query = query.filter(and_(*conditions)) + + if sort == BeatmapSortBy.Rating: + query = query.join(DBRating) + + order_type = { + BeatmapSortBy.Created: DBBeatmapset.id, + BeatmapSortBy.Title: DBBeatmapset.title, + BeatmapSortBy.Artist: DBBeatmapset.artist, + BeatmapSortBy.Creator: DBBeatmapset.creator, + BeatmapSortBy.Ranked: DBBeatmapset.approved_at, + BeatmapSortBy.Difficulty: func.max(DBBeatmap.diff), + BeatmapSortBy.Rating: func.avg(DBRating.rating), + BeatmapSortBy.Plays: func.sum(DBBeatmap.playcount), + }[sort] + + query = query.order_by( + order_type.asc() if order == BeatmapOrder.Ascending else + order_type.desc() + ) + + if genre is not None: + query = query.filter(DBBeatmapset.genre_id == genre) + + if language is not None: + query = query.filter(DBBeatmapset.language_id == language) + + if mode is not None: + query = query.filter(DBBeatmapset.beatmaps.any(DBBeatmap.mode == mode)) + + if has_storyboard: + query = query.filter(DBBeatmapset.has_storyboard == True) + + if has_video: + query = query.filter(DBBeatmapset.has_video == True) + + if (played is not None and + user_id is not None): + query = query.join(DBPlay) \ + .filter(DBPlay.user_id == user_id) + + if category > BeatmapCategory.Any: + query = query.filter({ + BeatmapCategory.Leaderboard: (DBBeatmapset.status > 0), + BeatmapCategory.Pending: (DBBeatmapset.status == 0), + BeatmapCategory.Ranked: (DBBeatmapset.status == 1), + BeatmapCategory.Approved: (DBBeatmapset.status == 2), + BeatmapCategory.Qualified: (DBBeatmapset.status == 3), + BeatmapCategory.Loved: (DBBeatmapset.status == 4), + }[category]) + + return query.offset(offset) \ + .limit(limit) \ + .all() diff --git a/database/repositories/channels.py b/database/repositories/channels.py index f9cfc7f..abca11e 100644 --- a/database/repositories/channels.py +++ b/database/repositories/channels.py @@ -1,29 +1,32 @@ +from __future__ import annotations + from app.common.database.objects import DBChannel +from sqlalchemy.orm import Session from typing import List -import app +from .wrapper import session_wrapper +@session_wrapper def create( name: str, topic: str, read_permissions: int, - write_permissions: int + write_permissions: int, + session: Session | None = None ) -> DBChannel: - with app.session.database.managed_session() as session: - session.add( - chan := DBChannel( - name, - topic, - read_permissions, - write_permissions - ) + session.add( + chan := DBChannel( + name, + topic, + read_permissions, + write_permissions ) - session.commit() - + ) + session.commit() return chan -def fetch_all() -> List[DBChannel]: - with app.session.database.managed_session() as session: - return session.query(DBChannel) \ - .all() +@session_wrapper +def fetch_all(session: Session | None = None) -> List[DBChannel]: + return session.query(DBChannel) \ + .all() diff --git a/database/repositories/clients.py b/database/repositories/clients.py index fc0e79b..2b6b9e4 100644 --- a/database/repositories/clients.py +++ b/database/repositories/clients.py @@ -1,103 +1,112 @@ +from __future__ import annotations + from app.common.database.objects import DBClient +from sqlalchemy.orm import Session from typing import List, Optional -from sqlalchemy import or_ -import app +from .wrapper import session_wrapper +@session_wrapper def create( user_id: int, executable: str, adapters: str, unique_id: str, disk_signature: str, - banned: bool = False + banned: bool = False, + session: Session | None = None ) -> DBClient: - with app.session.database.managed_session() as session: - session.add( - client := DBClient( - user_id, - executable, - adapters, - unique_id, - disk_signature, - banned - ) + session.add( + client := DBClient( + user_id, + executable, + adapters, + unique_id, + disk_signature, + banned ) - session.commit() - + ) + session.commit() return client -def update_all(user_id: int, updates: dict) -> int: - with app.session.database.managed_session() as session: - rows = session.query(DBClient) \ - .filter(DBClient.user_id == user_id) \ - .update(updates) - session.commit() - +@session_wrapper +def update_all( + user_id: int, + updates: dict, + session: Session | None = None +) -> int: + rows = session.query(DBClient) \ + .filter(DBClient.user_id == user_id) \ + .update(updates) + session.commit() return rows +@session_wrapper def fetch_one( user_id: int, executable: str, adapters: str, unique_id: str, - disk_signature: str + disk_signature: str, + session: Session | None = None ) -> Optional[DBClient]: """Fetch one client where all hardware attributes need to match""" - with app.session.database.managed_session() as session: - return session.query(DBClient) \ - .filter(DBClient.user_id == user_id) \ - .filter(DBClient.executable == executable) \ - .filter(DBClient.adapters == adapters) \ - .filter(DBClient.unique_id == unique_id) \ - .filter(DBClient.disk_signature == disk_signature) \ - .first() + return session.query(DBClient) \ + .filter(DBClient.user_id == user_id) \ + .filter(DBClient.executable == executable) \ + .filter(DBClient.adapters == adapters) \ + .filter(DBClient.unique_id == unique_id) \ + .filter(DBClient.disk_signature == disk_signature) \ + .first() +@session_wrapper def fetch_without_executable( user_id: int, adapters: str, unique_id: str, - disk_signature: str + disk_signature: str, + session: Session | None = None ) -> Optional[DBClient]: """Fetch one client with matching hardware and user id""" - with app.session.database.managed_session() as session: - return session.query(DBClient) \ - .filter(DBClient.user_id == user_id) \ - .filter(DBClient.adapters == adapters) \ - .filter(DBClient.unique_id == unique_id) \ - .filter(DBClient.disk_signature == disk_signature) \ - .first() + return session.query(DBClient) \ + .filter(DBClient.user_id == user_id) \ + .filter(DBClient.adapters == adapters) \ + .filter(DBClient.unique_id == unique_id) \ + .filter(DBClient.disk_signature == disk_signature) \ + .first() +@session_wrapper def fetch_hardware_only( adapters: str, unique_id: str, - disk_signature: str + disk_signature: str, + session: Session | None = None ) -> List[DBClient]: """Fetch clients only by hardware attributes. Useful for multi-account detection.""" - with app.session.database.managed_session() as session: - return session.query(DBClient) \ - .filter(DBClient.adapters == adapters) \ - .filter(DBClient.unique_id == unique_id) \ - .filter(DBClient.disk_signature == disk_signature) \ - .all() + return session.query(DBClient) \ + .filter(DBClient.adapters == adapters) \ + .filter(DBClient.unique_id == unique_id) \ + .filter(DBClient.disk_signature == disk_signature) \ + .all() +@session_wrapper def fetch_many( user_id: int, limit: int = 50, - offset: int = 0 + offset: int = 0, + session: Session | None = None ) -> List[DBClient]: """Fetch every client from user id""" - with app.session.database.managed_session() as session: - return session.query(DBClient) \ - .filter(DBClient.user_id == user_id) \ - .limit(limit) \ - .offset(offset) \ - .all() + return session.query(DBClient) \ + .filter(DBClient.user_id == user_id) \ + .limit(limit) \ + .offset(offset) \ + .all() -def fetch_all(user_id: int) -> List[DBClient]: +@session_wrapper +def fetch_all(user_id: int, session: Session | None = None) -> List[DBClient]: """Fetch every client from user id""" - with app.session.database.managed_session() as session: - return session.query(DBClient) \ - .filter(DBClient.user_id == user_id) \ - .all() + return session.query(DBClient) \ + .filter(DBClient.user_id == user_id) \ + .all() diff --git a/database/repositories/comments.py b/database/repositories/comments.py index 1d1ae9a..3d37e5a 100644 --- a/database/repositories/comments.py +++ b/database/repositories/comments.py @@ -1,9 +1,13 @@ +from __future__ import annotations + from app.common.database.objects import DBComment +from sqlalchemy.orm import Session from typing import List -import app +from .wrapper import session_wrapper +@session_wrapper def create( target_id: int, target: str, @@ -12,30 +16,33 @@ def create( content: str, comment_format: str, playmode: int, - color: str + color: str, + session: Session | None = None ) -> DBComment: - with app.session.database.managed_session() as session: - session.add( - c := DBComment( - target_id, - target, - user_id, - time, - content, - comment_format, - playmode, - color - ) + session.add( + c := DBComment( + target_id, + target, + user_id, + time, + content, + comment_format, + playmode, + color ) - session.commit() - session.refresh(c) - + ) + session.commit() + session.refresh(c) return c -def fetch_many(target_id: int, type: str) -> List[DBComment]: - with app.session.database.managed_session() as session: - return session.query(DBComment) \ - .filter(DBComment.target_id == target_id) \ - .filter(DBComment.target_type == type) \ - .order_by(DBComment.time.asc()) \ - .all() +@session_wrapper +def fetch_many( + target_id: int, + type: str, + session: Session | None = None +) -> List[DBComment]: + return session.query(DBComment) \ + .filter(DBComment.target_id == target_id) \ + .filter(DBComment.target_type == type) \ + .order_by(DBComment.time.asc()) \ + .all() diff --git a/database/repositories/events.py b/database/repositories/events.py index a69bd5f..9eff44f 100644 --- a/database/repositories/events.py +++ b/database/repositories/events.py @@ -1,52 +1,55 @@ +from __future__ import annotations + from app.common.database.objects import DBMatchEvent from app.common.constants import EventType +from sqlalchemy.orm import Session from typing import List, Optional -import app +from .wrapper import session_wrapper +@session_wrapper def create( match_id: int, type: EventType, data: dict = {}, + session: Session | None = None ) -> DBMatchEvent: - with app.session.database.managed_session() as session: - session.add( - m := DBMatchEvent( - match_id, - type.value, - data - ) + session.add( + m := DBMatchEvent( + match_id, + type.value, + data ) - session.commit() - session.refresh(m) - + ) + session.commit() + session.refresh(m) return m -def fetch_last(match_id: int) -> Optional[DBMatchEvent]: - with app.session.database.managed_session() as session: - return session.query(DBMatchEvent) \ - .filter(DBMatchEvent.match_id == match_id) \ - .order_by(DBMatchEvent.time.desc()) \ - .first() - -def fetch_last_by_type(match_id: int, type: int) -> Optional[DBMatchEvent]: - with app.session.database.managed_session() as session: - return session.query(DBMatchEvent) \ - .filter(DBMatchEvent.match_id == match_id) \ - .filter(DBMatchEvent.type == type) \ - .order_by(DBMatchEvent.time.desc()) \ - .first() - -def fetch_all(match_id: int) -> List[DBMatchEvent]: - with app.session.database.managed_session() as session: - return session.query(DBMatchEvent) \ - .filter(DBMatchEvent.match_id == match_id) \ - .all() - -def delete_all(match_id: int) -> None: - with app.session.database.managed_session() as session: - session.query(DBMatchEvent) \ - .filter(DBMatchEvent.match_id == match_id) \ - .delete() - session.commit() +@session_wrapper +def fetch_last(match_id: int, session: Session | None = None) -> Optional[DBMatchEvent]: + return session.query(DBMatchEvent) \ + .filter(DBMatchEvent.match_id == match_id) \ + .order_by(DBMatchEvent.time.desc()) \ + .first() + +@session_wrapper +def fetch_last_by_type(match_id: int, type: int, session: Session | None = None) -> Optional[DBMatchEvent]: + return session.query(DBMatchEvent) \ + .filter(DBMatchEvent.match_id == match_id) \ + .filter(DBMatchEvent.type == type) \ + .order_by(DBMatchEvent.time.desc()) \ + .first() + +@session_wrapper +def fetch_all(match_id: int, session: Session | None = None) -> List[DBMatchEvent]: + return session.query(DBMatchEvent) \ + .filter(DBMatchEvent.match_id == match_id) \ + .all() + +@session_wrapper +def delete_all(match_id: int, session: Session | None = None) -> None: + session.query(DBMatchEvent) \ + .filter(DBMatchEvent.match_id == match_id) \ + .delete() + session.commit() diff --git a/database/repositories/favourites.py b/database/repositories/favourites.py index 95a3259..2dd0e39 100644 --- a/database/repositories/favourites.py +++ b/database/repositories/favourites.py @@ -1,62 +1,67 @@ +from __future__ import annotations + from app.common.database.objects import DBFavourite +from sqlalchemy.orm import Session from typing import List, Optional -import app +from .wrapper import session_wrapper +@session_wrapper def create( user_id: int, - set_id: int + set_id: int, + session: Session | None = None ) -> Optional[DBFavourite]: - with app.session.database.managed_session() as session: - # Check if favourite was already set - if session.query(DBFavourite.user_id) \ - .filter(DBFavourite.user_id == user_id) \ - .filter(DBFavourite.set_id == set_id) \ - .first(): - return None - - session.add( - fav := DBFavourite( - user_id, - set_id - ) + # Check if favourite was already set + if session.query(DBFavourite.user_id) \ + .filter(DBFavourite.user_id == user_id) \ + .filter(DBFavourite.set_id == set_id) \ + .first(): + return None + + session.add( + fav := DBFavourite( + user_id, + set_id ) - session.commit() + ) + session.commit() return fav +@session_wrapper def fetch_one( user_id: int, - set_id: int + set_id: int, + session: Session | None = None ) -> Optional[DBFavourite]: - with app.session.database.managed_session() as session: - return session.query(DBFavourite) \ - .filter(DBFavourite.user_id == user_id) \ - .filter(DBFavourite.set_id == set_id) \ - .first() - -def fetch_many(user_id: int) -> List[DBFavourite]: - with app.session.database.managed_session() as session: - return session.query(DBFavourite) \ - .filter(DBFavourite.user_id == user_id) \ - .all() - -def fetch_many_by_set(set_id: int, limit: int = 5) -> List[DBFavourite]: - with app.session.database.managed_session() as session: - return session.query(DBFavourite) \ - .filter(DBFavourite.set_id == set_id) \ - .limit(limit) \ - .all() - -def fetch_count(user_id: int) -> int: - with app.session.database.managed_session() as session: - return session.query(DBFavourite) \ - .filter(DBFavourite.user_id == user_id) \ - .count() - -def fetch_count_by_set(set_id: int) -> int: - with app.session.database.managed_session() as session: - return session.query(DBFavourite) \ - .filter(DBFavourite.set_id == set_id) \ - .count() + return session.query(DBFavourite) \ + .filter(DBFavourite.user_id == user_id) \ + .filter(DBFavourite.set_id == set_id) \ + .first() + +@session_wrapper +def fetch_many(user_id: int, session: Session | None = None) -> List[DBFavourite]: + return session.query(DBFavourite) \ + .filter(DBFavourite.user_id == user_id) \ + .all() + +@session_wrapper +def fetch_many_by_set(set_id: int, limit: int = 5, session: Session | None = None) -> List[DBFavourite]: + return session.query(DBFavourite) \ + .filter(DBFavourite.set_id == set_id) \ + .limit(limit) \ + .all() + +@session_wrapper +def fetch_count(user_id: int, session: Session | None = None) -> int: + return session.query(DBFavourite) \ + .filter(DBFavourite.user_id == user_id) \ + .count() + +@session_wrapper +def fetch_count_by_set(set_id: int, session: Session | None = None) -> int: + return session.query(DBFavourite) \ + .filter(DBFavourite.set_id == set_id) \ + .count() diff --git a/database/repositories/histories.py b/database/repositories/histories.py index 51f05a1..4282380 100644 --- a/database/repositories/histories.py +++ b/database/repositories/histories.py @@ -1,4 +1,6 @@ +from __future__ import annotations + from app.common.cache import leaderboards from app.common.database.objects import ( DBReplayHistory, @@ -7,106 +9,113 @@ DBStats ) +from sqlalchemy.orm import Session from datetime import datetime from typing import List -import app +from .wrapper import session_wrapper +@session_wrapper def update_plays( user_id: int, - mode: int + mode: int, + session: Session | None = None ) -> None: time = datetime.now() - with app.session.database.managed_session() as session: - updated = session.query(DBPlayHistory) \ - .filter(DBPlayHistory.user_id == user_id) \ - .filter(DBPlayHistory.mode == mode) \ - .filter(DBPlayHistory.year == time.year) \ - .filter(DBPlayHistory.month == time.month) \ - .update({ - 'plays': DBPlayHistory.plays + 1 - }) + updated = session.query(DBPlayHistory) \ + .filter(DBPlayHistory.user_id == user_id) \ + .filter(DBPlayHistory.mode == mode) \ + .filter(DBPlayHistory.year == time.year) \ + .filter(DBPlayHistory.month == time.month) \ + .update({ + 'plays': DBPlayHistory.plays + 1 + }) - if not updated: - session.add( - DBPlayHistory( - user_id, - mode, - plays=1 - ) + if not updated: + session.add( + DBPlayHistory( + user_id, + mode, + plays=1 ) + ) - session.commit() + session.commit() +@session_wrapper def fetch_plays_history( user_id: int, mode: int, - until: datetime + until: datetime, + session: Session | None = None ) -> List[DBPlayHistory]: - with app.session.database.managed_session() as session: - return session.query(DBPlayHistory) \ - .filter(DBPlayHistory.user_id == user_id) \ - .filter(DBPlayHistory.mode == mode) \ - .filter( - DBPlayHistory.year >= until.year, - DBPlayHistory.month >= until.month, - ) \ - .order_by( - DBPlayHistory.year.desc(), - DBPlayHistory.month.desc() - ) \ - .all() - + return session.query(DBPlayHistory) \ + .filter(DBPlayHistory.user_id == user_id) \ + .filter(DBPlayHistory.mode == mode) \ + .filter( + DBPlayHistory.year >= until.year, + DBPlayHistory.month >= until.month, + ) \ + .order_by( + DBPlayHistory.year.desc(), + DBPlayHistory.month.desc() + ) \ + .all() + +@session_wrapper def update_replay_views( user_id: int, - mode: int + mode: int, + session: Session | None = None ) -> None: time = datetime.now() - with app.session.database.managed_session() as session: - updated = session.query(DBReplayHistory) \ - .filter(DBReplayHistory.user_id == user_id) \ - .filter(DBReplayHistory.mode == mode) \ - .filter(DBReplayHistory.year == time.year) \ - .filter(DBReplayHistory.month == time.month) \ - .update({ - 'replay_views': DBReplayHistory.replay_views + 1 - }) - - if not updated: - session.add( - DBReplayHistory( - user_id, - mode, - replay_views=1 - ) + updated = session.query(DBReplayHistory) \ + .filter(DBReplayHistory.user_id == user_id) \ + .filter(DBReplayHistory.mode == mode) \ + .filter(DBReplayHistory.year == time.year) \ + .filter(DBReplayHistory.month == time.month) \ + .update({ + 'replay_views': DBReplayHistory.replay_views + 1 + }) + + if not updated: + session.add( + DBReplayHistory( + user_id, + mode, + replay_views=1 ) + ) - session.commit() + session.commit() +@session_wrapper def fetch_replay_history( user_id: int, mode: int, - until: datetime + until: datetime, + session: Session | None = None ) -> List[DBReplayHistory]: - with app.session.database.managed_session() as session: - return session.query(DBReplayHistory) \ - .filter(DBReplayHistory.user_id == user_id) \ - .filter(DBReplayHistory.mode == mode) \ - .filter( - DBReplayHistory.year >= until.year, - DBReplayHistory.month >= until.month, - ) \ - .order_by( - DBReplayHistory.year.desc(), - DBReplayHistory.month.desc() - ) \ - .all() - + return session.query(DBReplayHistory) \ + .filter(DBReplayHistory.user_id == user_id) \ + .filter(DBReplayHistory.mode == mode) \ + .filter( + DBReplayHistory.year >= until.year, + DBReplayHistory.month >= until.month, + ) \ + .order_by( + DBReplayHistory.year.desc(), + DBReplayHistory.month.desc() + ) \ + .all() + +@session_wrapper def update_rank( stats: DBStats, - country: str + country: str, + session: Session | None = None ) -> None: country_rank = leaderboards.country_rank(stats.user_id, stats.mode, country) global_rank = leaderboards.global_rank(stats.user_id, stats.mode) @@ -125,31 +134,31 @@ def update_rank( if ppv1_rank <= 0: return - with app.session.database.managed_session() as session: - session.add( - DBRankHistory( - stats.user_id, - stats.mode, - stats.rscore, - stats.pp, - stats.ppv1, - global_rank, - country_rank, - score_rank, - ppv1_rank - ) + session.add( + DBRankHistory( + stats.user_id, + stats.mode, + stats.rscore, + stats.pp, + stats.ppv1, + global_rank, + country_rank, + score_rank, + ppv1_rank ) - session.commit() + ) + session.commit() +@session_wrapper def fetch_rank_history( user_id: int, mode: int, - until: datetime + until: datetime, + session: Session | None = None ) -> List[DBRankHistory]: - with app.session.database.managed_session() as session: - return session.query(DBRankHistory) \ - .filter(DBRankHistory.user_id == user_id) \ - .filter(DBRankHistory.mode == mode) \ - .filter(DBRankHistory.time > until) \ - .order_by(DBRankHistory.time.desc()) \ - .all() + return session.query(DBRankHistory) \ + .filter(DBRankHistory.user_id == user_id) \ + .filter(DBRankHistory.mode == mode) \ + .filter(DBRankHistory.time > until) \ + .order_by(DBRankHistory.time.desc()) \ + .all() diff --git a/database/repositories/infringements.py b/database/repositories/infringements.py index 46a99b3..a4724d3 100644 --- a/database/repositories/infringements.py +++ b/database/repositories/infringements.py @@ -1,77 +1,86 @@ +from __future__ import annotations + from app.common.database.objects import DBInfringement from datetime import datetime, timedelta +from sqlalchemy.orm import Session from typing import Optional, List -import app +from .wrapper import session_wrapper +@session_wrapper def create( user_id: int, action: int, length: datetime, description: Optional[str] = None, - is_permanent: bool = False + is_permanent: bool = False, + session: Session | None = None ) -> DBInfringement: - with app.session.database.managed_session() as session: - session.add( - i := DBInfringement( - user_id, - action, - length, - description, - is_permanent - ) + session.add( + i := DBInfringement( + user_id, + action, + length, + description, + is_permanent ) - session.commit() - session.refresh(i) - + ) + session.commit() + session.refresh(i) return i -def fetch_recent(user_id: int) -> Optional[DBInfringement]: - with app.session.database.managed_session() as session: - return session.query(DBInfringement) \ - .filter(DBInfringement.user_id == user_id) \ - .order_by(DBInfringement.id.desc()) \ - .first() +@session_wrapper +def fetch_recent(user_id: int, session: Session | None = None) -> Optional[DBInfringement]: + return session.query(DBInfringement) \ + .filter(DBInfringement.user_id == user_id) \ + .order_by(DBInfringement.id.desc()) \ + .first() -def fetch_recent_by_action(user_id: int, action: int) -> Optional[DBInfringement]: - with app.session.database.managed_session() as session: - return session.query(DBInfringement) \ - .filter(DBInfringement.user_id == user_id) \ - .filter(DBInfringement.action == action) \ - .order_by(DBInfringement.id.desc()) \ - .first() +@session_wrapper +def fetch_recent_by_action(user_id: int, action: int, session: Session | None = None) -> Optional[DBInfringement]: + return session.query(DBInfringement) \ + .filter(DBInfringement.user_id == user_id) \ + .filter(DBInfringement.action == action) \ + .order_by(DBInfringement.id.desc()) \ + .first() -def fetch_all(user_id: int) -> List[DBInfringement]: - with app.session.database.managed_session() as session: - return session.query(DBInfringement) \ - .filter(DBInfringement.user_id == user_id) \ - .order_by(DBInfringement.id.desc()) \ - .all() +@session_wrapper +def fetch_all(user_id: int, session: Session | None = None) -> List[DBInfringement]: + return session.query(DBInfringement) \ + .filter(DBInfringement.user_id == user_id) \ + .order_by(DBInfringement.id.desc()) \ + .all() -def fetch_all_by_action(user_id: int, action: int) -> List[DBInfringement]: - with app.session.database.managed_session() as session: - return session.query(DBInfringement) \ - .filter(DBInfringement.user_id == user_id) \ - .filter(DBInfringement.action == action) \ - .order_by(DBInfringement.time.desc()) \ - .all() +@session_wrapper +def fetch_all_by_action(user_id: int, action: int, session: Session | None = None) -> List[DBInfringement]: + return session.query(DBInfringement) \ + .filter(DBInfringement.user_id == user_id) \ + .filter(DBInfringement.action == action) \ + .order_by(DBInfringement.time.desc()) \ + .all() -def delete_by_id(id: int) -> None: - with app.session.database.managed_session() as session: - session.query(DBInfringement) \ - .filter(DBInfringement.id == id) \ - .delete() +@session_wrapper +def delete_by_id(id: int, session: Session | None = None) -> None: + session.query(DBInfringement) \ + .filter(DBInfringement.id == id) \ + .delete() -def delete_old(user_id: int, delete_after=timedelta(weeks=5), remove_permanent=False) -> int: +@session_wrapper +def delete_old( + user_id: int, + delete_after=timedelta(weeks=5), + remove_permanent=False, + session: Session | None = None +) -> int: if not remove_permanent: - return app.session.database.session.query(DBInfringement) \ - .filter(DBInfringement.user_id == user_id) \ - .filter(DBInfringement.time < datetime.now() - delete_after) \ - .filter(DBInfringement.is_permanent == False) \ - .delete() + return session.query(DBInfringement) \ + .filter(DBInfringement.user_id == user_id) \ + .filter(DBInfringement.time < datetime.now() - delete_after) \ + .filter(DBInfringement.is_permanent == False) \ + .delete() - return app.session.database.session.query(DBInfringement) \ - .filter(DBInfringement.user_id == user_id) \ - .filter(DBInfringement.time < datetime.now() - delete_after) \ - .delete() + return session.query(DBInfringement) \ + .filter(DBInfringement.user_id == user_id) \ + .filter(DBInfringement.time < datetime.now() - delete_after) \ + .delete() diff --git a/database/repositories/logins.py b/database/repositories/logins.py index 29f3493..9007fc0 100644 --- a/database/repositories/logins.py +++ b/database/repositories/logins.py @@ -1,73 +1,80 @@ +from __future__ import annotations + from app.common.database.objects import DBLogin +from sqlalchemy.orm import Session from datetime import datetime from typing import List -import app +from .wrapper import session_wrapper +@session_wrapper def create( user_id: int, ip: str, - version: str + version: str, + session: Session | None = None ) -> DBLogin: - with app.session.database.managed_session() as session: - session.add( - login := DBLogin( - user_id, - ip, - version - ) + session.add( + login := DBLogin( + user_id, + ip, + version ) - session.commit() - + ) + session.commit() return login +@session_wrapper def fetch_many( user_id: int, limit: int = 50, - offset: int = 0 + offset: int = 0, + session: Session | None = None ) -> List[DBLogin]: - with app.session.database.managed_session() as session: - return session.query(DBLogin) \ - .filter(DBLogin.user_id == user_id) \ - .order_by(DBLogin.time.desc()) \ - .limit(limit) \ - .offset(offset) \ - .all() + return session.query(DBLogin) \ + .filter(DBLogin.user_id == user_id) \ + .order_by(DBLogin.time.desc()) \ + .limit(limit) \ + .offset(offset) \ + .all() +@session_wrapper def fetch_many_until( user_id: int, - until: datetime + until: datetime, + session: Session | None = None ) -> List[DBLogin]: - with app.session.database.managed_session() as session: - return session.query(DBLogin) \ - .filter(DBLogin.user_id == user_id) \ - .filter(DBLogin.time > until) \ - .order_by(DBLogin.time.desc()) \ - .all() + return session.query(DBLogin) \ + .filter(DBLogin.user_id == user_id) \ + .filter(DBLogin.time > until) \ + .order_by(DBLogin.time.desc()) \ + .all() +@session_wrapper def fetch_many_by_ip( ip: str, limit: int = 50, - offset: int = 0 + offset: int = 0, + session: Session | None = None ) -> List[DBLogin]: - with app.session.database.managed_session() as session: - return session.query(DBLogin) \ - .filter(DBLogin.ip == ip) \ - .order_by(DBLogin.time.desc()) \ - .limit(limit) \ - .offset(offset) \ - .all() + return session.query(DBLogin) \ + .filter(DBLogin.ip == ip) \ + .order_by(DBLogin.time.desc()) \ + .limit(limit) \ + .offset(offset) \ + .all() +@session_wrapper def fetch_many_by_version( version: str, limit: int = 50, - offset: int = 0 + offset: int = 0, + session: Session | None = None ) -> List[DBLogin]: - with app.session.database.managed_session() as session: - return session.query(DBLogin) \ - .filter(DBLogin.version == version) \ - .order_by(DBLogin.time.desc()) \ - .limit(limit) \ - .offset(offset) \ - .all() + return session.query(DBLogin) \ + .filter(DBLogin.version == version) \ + .order_by(DBLogin.time.desc()) \ + .limit(limit) \ + .offset(offset) \ + .all() diff --git a/database/repositories/logs.py b/database/repositories/logs.py index 47e039f..911ce71 100644 --- a/database/repositories/logs.py +++ b/database/repositories/logs.py @@ -1,23 +1,26 @@ +from __future__ import annotations + from app.common.database.objects import DBLog +from sqlalchemy.orm import Session -import app +from .wrapper import session_wrapper +@session_wrapper def create( message: str, level: str, - type: str + type: str, + session: Session | None = None ) -> DBLog: - with app.session.database.managed_session() as session: - session.add( - log := DBLog( - message, - level, - type - ) + session.add( + log := DBLog( + message, + level, + type ) - session.commit() - + ) + session.commit() return log # TODO: Create fetch queries diff --git a/database/repositories/matches.py b/database/repositories/matches.py index a0ed4d5..16e2656 100644 --- a/database/repositories/matches.py +++ b/database/repositories/matches.py @@ -1,54 +1,57 @@ +from __future__ import annotations + from app.common.database.repositories import events from app.common.database.objects import DBMatch +from sqlalchemy.orm import Session from typing import Optional -import app +from .wrapper import session_wrapper +@session_wrapper def create( name: str, bancho_id: int, creator_id: int, + session: Session | None = None ) -> DBMatch: - with app.session.database.managed_session() as session: - session.add( - m := DBMatch( - name, - creator_id, - bancho_id - ) + session.add( + m := DBMatch( + name, + creator_id, + bancho_id ) - session.commit() - session.refresh(m) - + ) + session.commit() + session.refresh(m) return m -def fetch_by_id(id: int) -> Optional[DBMatch]: - with app.session.database.managed_session() as session: - return session.query(DBMatch) \ - .filter(DBMatch.id == id) \ - .first() - -def fetch_by_bancho_id(id: int) -> Optional[DBMatch]: - with app.session.database.managed_session() as session: - return session.query(DBMatch) \ - .filter(DBMatch.bancho_id == id) \ - .first() - -def update(id: int, updates: dict) -> None: - with app.session.database.managed_session() as session: - session.query(DBMatch) \ - .filter(DBMatch.id == id) \ - .update(updates) - session.commit() - -def delete(id: int) -> None: +@session_wrapper +def fetch_by_id(id: int, session: Session | None = None) -> Optional[DBMatch]: + return session.query(DBMatch) \ + .filter(DBMatch.id == id) \ + .first() + +@session_wrapper +def fetch_by_bancho_id(id: int, session: Session | None = None) -> Optional[DBMatch]: + return session.query(DBMatch) \ + .filter(DBMatch.bancho_id == id) \ + .first() + +@session_wrapper +def update(id: int, updates: dict, session: Session | None = None) -> None: + session.query(DBMatch) \ + .filter(DBMatch.id == id) \ + .update(updates) + session.commit() + +@session_wrapper +def delete(id: int, session: Session | None = None) -> None: # Delete events first events.delete_all(id) - with app.session.database.managed_session() as session: - session.query(DBMatch) \ - .filter(DBMatch.id == id) \ - .delete() - session.commit() + session.query(DBMatch) \ + .filter(DBMatch.id == id) \ + .delete() + session.commit() diff --git a/database/repositories/messages.py b/database/repositories/messages.py index 9f7a576..196d19a 100644 --- a/database/repositories/messages.py +++ b/database/repositories/messages.py @@ -1,31 +1,38 @@ +from __future__ import annotations + from app.common.database.objects import DBMessage +from sqlalchemy.orm import Session from typing import List -import app +from .wrapper import session_wrapper +@session_wrapper def create( sender: str, target: str, - message: str + message: str, + session: Session | None = None ) -> DBMessage: - with app.session.database.managed_session() as session: - session.add( - msg := DBMessage( - sender, - target, - message - ) + session.add( + msg := DBMessage( + sender, + target, + message ) - session.commit() - session.refresh(msg) - + ) + session.commit() + session.refresh(msg) return msg -def fetch_recent(target: str = '#osu', limit: int = 10) -> List[DBMessage]: - with app.session.database.managed_session() as session: - return session.query(DBMessage) \ - .filter(DBMessage.target == target) \ - .order_by(DBMessage.id.desc()) \ - .limit(limit) \ - .all() +@session_wrapper +def fetch_recent( + target: str = '#osu', + limit: int = 10, + session: Session | None = None +) -> List[DBMessage]: + return session.query(DBMessage) \ + .filter(DBMessage.target == target) \ + .order_by(DBMessage.id.desc()) \ + .limit(limit) \ + .all() diff --git a/database/repositories/names.py b/database/repositories/names.py index 4ae7835..66c08db 100644 --- a/database/repositories/names.py +++ b/database/repositories/names.py @@ -1,24 +1,26 @@ +from __future__ import annotations + from app.common.database.objects import DBName +from sqlalchemy.orm import Session from typing import List -import app - -def create(user_id: int, old_name: str) -> DBName: - with app.session.database.session as session: - session.add(name := DBName(user_id, old_name)) - session.commit() +from .wrapper import session_wrapper +@session_wrapper +def create(user_id: int, old_name: str, session: Session | None = None) -> DBName: + session.add(name := DBName(user_id, old_name)) + session.commit() return name -def fetch_one(id: int): - with app.session.database.managed_session() as session: - return session.query(DBName) \ - .filter(DBName.id == id) \ - .first() +@session_wrapper +def fetch_one(id: int, session: Session | None = None) -> DBName: + return session.query(DBName) \ + .filter(DBName.id == id) \ + .first() -def fetch_all(user_id: int): - with app.session.database.managed_session() as session: - return session.query(DBName) \ - .filter(DBName.user_id == user_id) \ - .all() +@session_wrapper +def fetch_all(user_id: int, session: Session | None = None) -> List[DBName]: + return session.query(DBName) \ + .filter(DBName.user_id == user_id) \ + .all() diff --git a/database/repositories/plays.py b/database/repositories/plays.py index b51adef..2262043 100644 --- a/database/repositories/plays.py +++ b/database/repositories/plays.py @@ -1,87 +1,90 @@ -from app.common.database.objects import DBPlay +from __future__ import annotations -from ...helpers.caching import ttl_cache +from app.common.database.objects import DBPlay +from sqlalchemy.orm import Session from sqlalchemy import func from typing import List -import app +from .wrapper import session_wrapper +@session_wrapper def create( beatmap_file: str, beatmap_id: int, user_id: int, set_id: int, - count: int = 1 + count: int = 1, + session: Session | None = None ) -> DBPlay: - with app.session.database.managed_session() as session: - session.add( - p := DBPlay( - user_id, - beatmap_id, - set_id, - beatmap_file, - count - ) + session.add( + p := DBPlay( + user_id, + beatmap_id, + set_id, + beatmap_file, + count ) - session.commit() - session.refresh(p) - + ) + session.commit() + session.refresh(p) return p +@session_wrapper def update( beatmap_file: str, beatmap_id: int, user_id: int, set_id: int, - count: int = 1 + count: int = 1, + session: Session | None = None ) -> None: - with app.session.database.managed_session() as session: - updated = session.query(DBPlay) \ - .filter(DBPlay.beatmap_id == beatmap_id) \ - .filter(DBPlay.user_id == user_id) \ - .update({ - 'count': DBPlay.count + count - }) + updated = session.query(DBPlay) \ + .filter(DBPlay.beatmap_id == beatmap_id) \ + .filter(DBPlay.user_id == user_id) \ + .update({ + 'count': DBPlay.count + count + }) - if not updated: - create( - beatmap_file, - beatmap_id, - user_id, - set_id, - count - ) + if not updated: + create( + beatmap_file, + beatmap_id, + user_id, + set_id, + count + ) - session.commit() + session.commit() -def fetch_count_for_beatmap(beatmap_id: int) -> int: - with app.session.database.managed_session() as session: - count = session.query( - func.sum(DBPlay.count).label('playcount')) \ - .group_by(DBPlay.beatmap_id) \ - .filter(DBPlay.beatmap_id == beatmap_id) \ - .first() +@session_wrapper +def fetch_count_for_beatmap(beatmap_id: int, session: Session | None = None) -> int: + count = session.query( + func.sum(DBPlay.count).label('playcount')) \ + .group_by(DBPlay.beatmap_id) \ + .filter(DBPlay.beatmap_id == beatmap_id) \ + .first() return count[0] if count else 0 -def fetch_most_played(limit: int = 5) -> List[DBPlay]: - with app.session.database.managed_session() as session: - return session.query(DBPlay) \ - .order_by(DBPlay.count.desc()) \ - .limit(limit) \ - .all() +@session_wrapper +def fetch_most_played(limit: int = 5, session: Session | None = None) -> List[DBPlay]: + return session.query(DBPlay) \ + .order_by(DBPlay.count.desc()) \ + .limit(limit) \ + .all() +@session_wrapper def fetch_most_played_by_user( user_id: int, limit: int = 15, - offset: int = 0 + offset: int = 0, + session: Session | None = None ) -> List[DBPlay]: - with app.session.database.managed_session() as session: - return session.query(DBPlay) \ - .filter(DBPlay.user_id == user_id) \ - .order_by(DBPlay.count.desc()) \ - .limit(limit) \ - .offset(offset) \ - .all() + return session.query(DBPlay) \ + .filter(DBPlay.user_id == user_id) \ + .order_by(DBPlay.count.desc()) \ + .limit(limit) \ + .offset(offset) \ + .all() diff --git a/database/repositories/ratings.py b/database/repositories/ratings.py index 6c76792..44ccf81 100644 --- a/database/repositories/ratings.py +++ b/database/repositories/ratings.py @@ -1,54 +1,57 @@ +from __future__ import annotations + from app.common.database.objects import DBRating +from sqlalchemy.orm import Session from typing import List, Optional from sqlalchemy import func -import app +from .wrapper import session_wrapper +@session_wrapper def create( beatmap_hash: str, user_id: int, set_id: int, - rating: int + rating: int, + session: Session | None = None ) -> DBRating: - with app.session.database.managed_session() as session: - session.add( - rating := DBRating( - user_id, - set_id, - beatmap_hash, - rating - ) + session.add( + rating := DBRating( + user_id, + set_id, + beatmap_hash, + rating ) - session.commit() - session.refresh(rating) - + ) + session.commit() + session.refresh(rating) return rating -def fetch_one(beatmap_hash: str, user_id: int) -> Optional[int]: - with app.session.database.managed_session() as session: - result = session.query(DBRating.rating) \ - .filter(DBRating.map_checksum == beatmap_hash) \ - .filter(DBRating.user_id == user_id) \ - .first() +@session_wrapper +def fetch_one(beatmap_hash: str, user_id: int, session: Session | None = None) -> Optional[int]: + result = session.query(DBRating.rating) \ + .filter(DBRating.map_checksum == beatmap_hash) \ + .filter(DBRating.user_id == user_id) \ + .first() return result[0] if result else None -def fetch_many(beatmap_hash: str) -> List[int]: - with app.session.database.managed_session() as session: - return [ - rating[0] - for rating in session.query(DBRating.rating) \ - .filter(DBRating.map_checksum == beatmap_hash) \ - .all() - ] - -def fetch_average(beatmap_hash: str) -> float: - with app.session.database.managed_session() as session: - result = session.query( - func.avg(DBRating.rating).label('average')) \ +@session_wrapper +def fetch_many(beatmap_hash: str, session: Session | None = None) -> List[int]: + return [ + rating[0] + for rating in session.query(DBRating.rating) \ .filter(DBRating.map_checksum == beatmap_hash) \ - .first()[0] + .all() + ] + +@session_wrapper +def fetch_average(beatmap_hash: str, session: Session | None = None) -> float: + result = session.query( + func.avg(DBRating.rating).label('average')) \ + .filter(DBRating.map_checksum == beatmap_hash) \ + .first()[0] return float(result) if result else 0.0 diff --git a/database/repositories/relationships.py b/database/repositories/relationships.py index a058d36..62dfc51 100644 --- a/database/repositories/relationships.py +++ b/database/repositories/relationships.py @@ -1,73 +1,77 @@ +from __future__ import annotations + from app.common.database.objects import DBRelationship +from sqlalchemy.orm import Session from typing import List -import app +from .wrapper import session_wrapper +@session_wrapper def create( user_id: int, target_id: int, - status: int = 0 + status: int = 0, + session: Session | None = None ) -> DBRelationship: - with app.session.database.managed_session() as session: - session.add( - rel := DBRelationship( - user_id, - target_id, - status - ) + session.add( + rel := DBRelationship( + user_id, + target_id, + status ) - session.commit() - session.refresh(rel) - + ) + session.commit() + session.refresh(rel) return rel +@session_wrapper def delete( user_id: int, target_id: int, - status: int = 0 + status: int = 0, + session: Session | None = None ) -> bool: - with app.session.database.managed_session() as session: - rel = session.query(DBRelationship) \ - .filter(DBRelationship.user_id == user_id) \ - .filter(DBRelationship.target_id == target_id) \ - .filter(DBRelationship.status == status) + rel = session.query(DBRelationship) \ + .filter(DBRelationship.user_id == user_id) \ + .filter(DBRelationship.target_id == target_id) \ + .filter(DBRelationship.status == status) - if rel.first(): - rel.delete() - session.commit() - return True + if rel.first(): + rel.delete() + session.commit() + return True - return False + return False -def fetch_many_by_id(user_id: int) -> List[DBRelationship]: - with app.session.database.managed_session() as session: - return session.query(DBRelationship) \ - .filter(DBRelationship.user_id == user_id) \ - .all() +@session_wrapper +def fetch_many_by_id(user_id: int, session: Session | None = None) -> List[DBRelationship]: + return session.query(DBRelationship) \ + .filter(DBRelationship.user_id == user_id) \ + .all() -def fetch_many_by_target(target_id: int) -> List[DBRelationship]: - with app.session.database.managed_session() as session: +@session_wrapper +def fetch_many_by_target(target_id: int, session: Session | None = None) -> List[DBRelationship]: return session.query(DBRelationship) \ .filter(DBRelationship.target_id == target_id) \ .all() -def fetch_count_by_id(user_id: int) -> int: - with app.session.database.managed_session() as session: - return session.query(DBRelationship) \ - .filter(DBRelationship.user_id == user_id) \ - .count() +@session_wrapper +def fetch_count_by_id(user_id: int, session: Session | None = None) -> int: + return session.query(DBRelationship) \ + .filter(DBRelationship.user_id == user_id) \ + .count() -def fetch_count_by_target(target_id: int) -> int: - with app.session.database.managed_session() as session: - return session.query(DBRelationship) \ - .filter(DBRelationship.target_id == target_id) \ - .count() +@session_wrapper +def fetch_count_by_target(target_id: int, session: Session | None = None) -> int: + return session.query(DBRelationship) \ + .filter(DBRelationship.target_id == target_id) \ + .count() -def fetch_target_ids(user_id: int) -> List[int]: - with app.session.database.managed_session() as session: - result = session.query(DBRelationship.target_id) \ - .filter(DBRelationship.user_id == user_id) \ - .all() +@session_wrapper +def fetch_target_ids(user_id: int, session: Session | None = None) -> List[int]: + result = session.query(DBRelationship.target_id) \ + .filter(DBRelationship.user_id == user_id) \ + .all() return [id[0] for id in result] diff --git a/database/repositories/reports.py b/database/repositories/reports.py index f8bbed7..b6f3146 100644 --- a/database/repositories/reports.py +++ b/database/repositories/reports.py @@ -1,73 +1,77 @@ +from __future__ import annotations + from app.common.database import DBReport +from sqlalchemy.orm import Session from typing import Optional, List -import app +from .wrapper import session_wrapper +@session_wrapper def create( target_id: int, sender_id: int, - reason: Optional[str] = None + reason: Optional[str] = None, + session: Session | None = None ) -> DBReport: - with app.session.database.managed_session() as session: - session.add( - r := DBReport( - target_id, - sender_id, - reason - ) + session.add( + r := DBReport( + target_id, + sender_id, + reason ) - session.commit() - session.refresh(r) - + ) + session.commit() + session.refresh(r) return r -def fetch_by_id(id: int) -> Optional[DBReport]: - with app.session.database.managed_session() as session: - return session.query(DBReport) \ - .filter(DBReport.id == id) \ - .first() +@session_wrapper +def fetch_by_id(id: int, session: Session | None = None) -> Optional[DBReport]: + return session.query(DBReport) \ + .filter(DBReport.id == id) \ + .first() -def fetch_last_by_sender(sender_id: int) -> Optional[DBReport]: - with app.session.database.managed_session() as session: - return session.query(DBReport) \ - .filter(DBReport.sender_id == sender_id) \ - .filter(DBReport.resolved == False) \ - .order_by(DBReport.time.desc()) \ - .first() +@session_wrapper +def fetch_last_by_sender(sender_id: int, session: Session | None = None) -> Optional[DBReport]: + return session.query(DBReport) \ + .filter(DBReport.sender_id == sender_id) \ + .filter(DBReport.resolved == False) \ + .order_by(DBReport.time.desc()) \ + .first() -def fetch_last(target_id: int) -> Optional[DBReport]: - with app.session.database.managed_session() as session: - return session.query(DBReport) \ - .filter(DBReport.target_id == target_id) \ - .filter(DBReport.resolved == False) \ - .order_by(DBReport.time.desc()) \ - .first() +@session_wrapper +def fetch_last(target_id: int, session: Session | None = None) -> Optional[DBReport]: + return session.query(DBReport) \ + .filter(DBReport.target_id == target_id) \ + .filter(DBReport.resolved == False) \ + .order_by(DBReport.time.desc()) \ + .first() -def fetch_all_by_sender(sender_id: int) -> List[DBReport]: - with app.session.database.managed_session() as session: - return session.query(DBReport) \ - .filter(DBReport.sender_id == sender_id) \ - .filter(DBReport.resolved == False) \ - .order_by(DBReport.time.desc()) \ - .all() +@session_wrapper +def fetch_all_by_sender(sender_id: int, session: Session | None = None) -> List[DBReport]: + return session.query(DBReport) \ + .filter(DBReport.sender_id == sender_id) \ + .filter(DBReport.resolved == False) \ + .order_by(DBReport.time.desc()) \ + .all() -def fetch_all(target_id: int) -> List[DBReport]: - with app.session.database.managed_session() as session: - return session.query(DBReport) \ - .filter(DBReport.target_id == target_id) \ - .filter(DBReport.resolved == False) \ - .order_by(DBReport.time.desc()) \ - .all() +@session_wrapper +def fetch_all(target_id: int, session: Session | None = None) -> List[DBReport]: + return session.query(DBReport) \ + .filter(DBReport.target_id == target_id) \ + .filter(DBReport.resolved == False) \ + .order_by(DBReport.time.desc()) \ + .all() +@session_wrapper def fetch_by_sender_to_target( sender_id: int, - target_id: int + target_id: int, + session: Session | None = None ) -> Optional[DBReport]: - with app.session.database.managed_session() as session: - return session.query(DBReport) \ - .filter(DBReport.target_id == target_id) \ - .filter(DBReport.sender_id == sender_id) \ - .filter(DBReport.resolved == False) \ - .order_by(DBReport.time.desc()) \ - .first() + return session.query(DBReport) \ + .filter(DBReport.target_id == target_id) \ + .filter(DBReport.sender_id == sender_id) \ + .filter(DBReport.resolved == False) \ + .order_by(DBReport.time.desc()) \ + .first() diff --git a/database/repositories/scores.py b/database/repositories/scores.py index 436cb07..6f9f1ef 100644 --- a/database/repositories/scores.py +++ b/database/repositories/scores.py @@ -1,536 +1,555 @@ +from __future__ import annotations + from app.common.database.objects import ( DBBeatmap, DBScore, DBUser ) +from sqlalchemy.orm import selectinload, Session from sqlalchemy import or_, and_, func -from sqlalchemy.orm import selectinload from typing import Optional, List, Dict from datetime import datetime -import app +from .wrapper import session_wrapper -def create(score: DBScore) -> DBScore: - with app.session.database.managed_session() as session: - session.add(score) - session.commit() - session.refresh(score) +import app +@session_wrapper +def create(score: DBScore, session: Session | None = None) -> DBScore: + session.add(score) + session.commit() + session.refresh(score) return score -def update(score_id: int, updates: dict) -> int: - with app.session.database.session as session: - rows = session.query(DBScore) \ - .filter(DBScore.id == score_id) \ - .update(updates) - session.commit() - +@session_wrapper +def update(score_id: int, updates: dict, session: Session | None = None) -> int: + rows = session.query(DBScore) \ + .filter(DBScore.id == score_id) \ + .update(updates) + session.commit() return rows -def hide_all(user_id: int) -> int: - with app.session.database.managed_session() as session: - rows = session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .update({ - 'status': -1 - }) - session.commit() - - return rows - -def fetch_by_id(id: int) -> Optional[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .options( - selectinload(DBScore.beatmap), - selectinload(DBScore.user) - ) \ - .filter(DBScore.id == id) \ - .first() - -def fetch_by_replay_checksum(checksum: str) -> Optional[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .filter(DBScore.replay_md5 == checksum) \ - .first() - -def fetch_count(user_id: int, mode: int) -> int: - with app.session.database.managed_session() as session: - return session.query(func.count(DBScore.id)) \ +@session_wrapper +def hide_all(user_id: int, session: Session | None = None) -> int: + rows = session.query(DBScore) \ .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .scalar() - -def fetch_total_count() -> int: - with app.session.database.managed_session() as session: - return session.query(func.count(DBScore.id)) \ - .filter(DBScore.status != -1) \ - .scalar() + .update({ + 'status': -1 + }) + session.commit() + return rows +@session_wrapper +def fetch_by_id(id: int, session: Session | None = None) -> Optional[DBScore]: + return session.query(DBScore) \ + .options( + selectinload(DBScore.beatmap), + selectinload(DBScore.user) + ) \ + .filter(DBScore.id == id) \ + .first() + +@session_wrapper +def fetch_by_replay_checksum(checksum: str, session: Session | None = None) -> Optional[DBScore]: + return session.query(DBScore) \ + .filter(DBScore.replay_md5 == checksum) \ + .first() + +@session_wrapper +def fetch_count(user_id: int, mode: int, session: Session | None = None) -> int: + return session.query(func.count(DBScore.id)) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) \ + .scalar() + +@session_wrapper +def fetch_total_count(session: Session | None = None) -> int: + return session.query(func.count(DBScore.id)) \ + .filter(DBScore.status != -1) \ + .scalar() + +@session_wrapper def fetch_count_beatmap( beatmap_id: int, mode: int, mods: Optional[int] = None, country: Optional[str] = None, - friends: Optional[List[int]] = None + friends: Optional[List[int]] = None, + session: Session | None = None ) -> int: - with app.session.database.managed_session() as session: - query = session.query(func.count(DBScore.id)) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) + query = session.query(func.count(DBScore.id)) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) - if country is not None: - query = query.filter(DBUser.country == country) \ - .join(DBScore.user) + if country is not None: + query = query.filter(DBUser.country == country) \ + .join(DBScore.user) - if friends is not None: - query = query.filter(DBScore.user_id.in_(friends)) + if friends is not None: + query = query.filter(DBScore.user_id.in_(friends)) - if mods is not None: - query = query.filter(or_(DBScore.status == 3, DBScore.status == 4)) \ - .filter(DBScore.mods == mods) - else: - query = query.filter(DBScore.status == 3) + if mods is not None: + query = query.filter(or_(DBScore.status == 3, DBScore.status == 4)) \ + .filter(DBScore.mods == mods) + else: + query = query.filter(DBScore.status == 3) - return query.scalar() + return query.scalar() +@session_wrapper def fetch_top_scores( user_id: int, mode: int, exclude_approved: bool = False, limit: int = 100, - offset: int = 0 + offset: int = 0, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - query = session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) + query = session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) - if exclude_approved: - query = query.filter(DBBeatmap.status == 1) \ - .join(DBScore.beatmap) + if exclude_approved: + query = query.filter(DBBeatmap.status == 1) \ + .join(DBScore.beatmap) - return query.order_by(DBScore.pp.desc()) \ - .limit(limit) \ - .offset(offset) \ - .all() + return query.order_by(DBScore.pp.desc()) \ + .limit(limit) \ + .offset(offset) \ + .all() +@session_wrapper def fetch_leader_scores( user_id: int, mode: int, limit: int = 50, - offset: int = 0 + offset: int = 0, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - # Find the maximum total score for each beatmap - subquery = session.query( - DBScore.beatmap_id, - DBScore.mode, - func.max(DBScore.total_score).label('max_total_score') - ) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .group_by(DBScore.beatmap_id, DBScore.mode) \ - .subquery() - - # Get scores where the user has the highest total score - leader_scores = session.query(DBScore) \ - .join(subquery, and_( - DBScore.beatmap_id == subquery.c.beatmap_id, - DBScore.mode == subquery.c.mode, - DBScore.total_score == subquery.c.max_total_score - )) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .order_by(DBScore.id.desc()) \ - .limit(limit) \ - .offset(offset) \ - .all() + # Find the maximum total score for each beatmap + subquery = session.query( + DBScore.beatmap_id, + DBScore.mode, + func.max(DBScore.total_score).label('max_total_score') + ) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) \ + .group_by(DBScore.beatmap_id, DBScore.mode) \ + .subquery() + + # Get scores where the user has the highest total score + leader_scores = session.query(DBScore) \ + .join(subquery, and_( + DBScore.beatmap_id == subquery.c.beatmap_id, + DBScore.mode == subquery.c.mode, + DBScore.total_score == subquery.c.max_total_score + )) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) \ + .order_by(DBScore.id.desc()) \ + .limit(limit) \ + .offset(offset) \ + .all() return leader_scores +@session_wrapper def fetch_best( user_id: int, mode: int, - exclude_approved: bool = False + exclude_approved: bool = False, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - query = session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) + query = session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) - if exclude_approved: - query = query.filter(DBBeatmap.status == 1) \ - .join(DBScore.beatmap) + if exclude_approved: + query = query.filter(DBBeatmap.status == 1) \ + .join(DBScore.beatmap) - return query.order_by(DBScore.pp.desc()) \ - .all() + return query.order_by(DBScore.pp.desc()) \ + .all() +@session_wrapper def fetch_personal_best( beatmap_id: int, user_id: int, mode: int, - mods: Optional[int] = None + mods: int | None = None, + session: Session | None = None ) -> Optional[DBScore]: - with app.session.database.managed_session() as session: - if mods is None: - return session.query(DBScore) \ - .options(selectinload(DBScore.user)) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .first() - + if mods is None: return session.query(DBScore) \ .options(selectinload(DBScore.user)) \ .filter(DBScore.beatmap_id == beatmap_id) \ .filter(DBScore.user_id == user_id) \ .filter(DBScore.mode == mode) \ - .filter(or_(DBScore.status == 3, DBScore.status == 4)) \ - .filter(DBScore.mods == mods) \ + .filter(DBScore.status == 3) \ .first() + return session.query(DBScore) \ + .options(selectinload(DBScore.user)) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .filter(or_(DBScore.status == 3, DBScore.status == 4)) \ + .filter(DBScore.mods == mods) \ + .first() + +@session_wrapper def fetch_range_scores( beatmap_id: int, mode: int, offset: int = 0, - limit: int = 5 + limit: int = 5, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .options(selectinload(DBScore.user)) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .order_by(DBScore.total_score.desc()) \ - .offset(offset) \ - .limit(limit) \ - .all() - + return session.query(DBScore) \ + .options(selectinload(DBScore.user)) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) \ + .order_by(DBScore.total_score.desc()) \ + .offset(offset) \ + .limit(limit) \ + .all() + +@session_wrapper def fetch_range_scores_country( beatmap_id: int, mode: int, country: str, - limit: int = 5 + limit: int = 5, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .options(selectinload(DBScore.user)) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .filter(DBUser.country == country) \ - .join(DBScore.user) \ - .limit(limit) \ - .all() - + return session.query(DBScore) \ + .options(selectinload(DBScore.user)) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) \ + .filter(DBUser.country == country) \ + .join(DBScore.user) \ + .limit(limit) \ + .all() + +@session_wrapper def fetch_range_scores_friends( beatmap_id: int, mode: int, friends: List[int], - limit: int = 5 + limit: int = 5, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .options(selectinload(DBScore.user)) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .filter(DBScore.user_id.in_(friends)) \ - .limit(limit) \ - .all() - + return session.query(DBScore) \ + .options(selectinload(DBScore.user)) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) \ + .filter(DBScore.user_id.in_(friends)) \ + .limit(limit) \ + .all() + +@session_wrapper def fetch_range_scores_mods( beatmap_id: int, mode: int, mods: int, - limit: int = 5 + limit: int = 5, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .options(selectinload(DBScore.user)) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) \ - .filter(or_(DBScore.status == 3, DBScore.status == 4)) \ - .filter(DBScore.mods == mods) \ - .order_by(DBScore.total_score.desc()) \ - .limit(limit) \ - .all() - + return session.query(DBScore) \ + .options(selectinload(DBScore.user)) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) \ + .filter(or_(DBScore.status == 3, DBScore.status == 4)) \ + .filter(DBScore.mods == mods) \ + .order_by(DBScore.total_score.desc()) \ + .limit(limit) \ + .all() + +@session_wrapper def fetch_score_index( user_id: int, beatmap_id: int, mode: int, - mods: Optional[int] = None, - friends: Optional[List[int]] = None, - country: Optional[str] = None + mods: int | None = None, + friends: List[int] | None = None, + country: str | None = None, + session: Session | None = None ) -> int: - with app.session.database.managed_session() as session: - query = session.query(DBScore.user_id, DBScore.mods, func.rank() \ - .over( - order_by=DBScore.total_score.desc() - ).label('rank') - ) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) \ - .order_by(DBScore.total_score.desc()) + query = session.query(DBScore.user_id, DBScore.mods, func.rank() \ + .over( + order_by=DBScore.total_score.desc() + ).label('rank') + ) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) \ + .order_by(DBScore.total_score.desc()) - if mods != None: - query = query.filter(DBScore.mods == mods) \ - .filter(or_(DBScore.status == 3, DBScore.status == 4)) - else: - query = query.filter(DBScore.status == 3) + if mods != None: + query = query.filter(DBScore.mods == mods) \ + .filter(or_(DBScore.status == 3, DBScore.status == 4)) + else: + query = query.filter(DBScore.status == 3) - if country != None: - query = query.join(DBScore.user) \ - .filter(DBUser.country == country) + if country != None: + query = query.join(DBScore.user) \ + .filter(DBUser.country == country) - if friends != None: - query = query.filter( - or_( - DBScore.user_id.in_(friends), - DBScore.user_id == user_id - ) - ) + if friends != None: + query = query.filter( + or_( + DBScore.user_id.in_(friends), + DBScore.user_id == user_id + ) + ) - subquery = query.subquery() + subquery = query.subquery() - if not (result := session.query(subquery.c.rank) \ - .filter(subquery.c.user_id == user_id) \ - .first()): - # No score was found...? - return 0 + if not (result := session.query(subquery.c.rank) \ + .filter(subquery.c.user_id == user_id) \ + .first()): + # No score was found...? + return 0 - return result[-1] + return result[-1] +@session_wrapper def fetch_score_index_by_id( score_id: int, beatmap_id: int, mode: int, - mods: Optional[int] = None + mods: int | None = None, + session: Session | None = None ) -> int: - with app.session.database.managed_session() as session: - query = session.query(DBScore.id, DBScore.mods, func.rank() \ - .over( - order_by=DBScore.total_score.desc() - ).label('rank') - ) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) \ - .order_by(DBScore.total_score.desc()) + query = session.query(DBScore.id, DBScore.mods, func.rank() \ + .over( + order_by=DBScore.total_score.desc() + ).label('rank') + ) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) \ + .order_by(DBScore.total_score.desc()) - if mods != None: - query = query.filter(DBScore.mods == mods) \ - .filter(or_(DBScore.status == 3, DBScore.status == 4)) - else: - query = query.filter(DBScore.status == 3) + if mods != None: + query = query.filter(DBScore.mods == mods) \ + .filter(or_(DBScore.status == 3, DBScore.status == 4)) + else: + query = query.filter(DBScore.status == 3) - subquery = query.subquery() + subquery = query.subquery() - if not (result := session.query(subquery.c.rank) \ - .filter(subquery.c.id == score_id) \ - .first()): - return 0 + if not (result := session.query(subquery.c.rank) \ + .filter(subquery.c.id == score_id) \ + .first()): + return 0 - return result[-1] + return result[-1] +@session_wrapper def fetch_score_index_by_tscore( total_score: int, beatmap_id: int, - mode: int + mode: int, + session: Session | None = None ) -> int: - with app.session.database.managed_session() as session: - closest_score = session.query(DBScore) \ - .filter(DBScore.total_score > total_score) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .order_by(func.abs(DBScore.total_score - total_score)) \ - .first() - - if not closest_score: - return 1 - - # Fetch score rank for closest score - return fetch_score_index_by_id( - closest_score.id, - beatmap_id, - mode - ) + 1 - + closest_score = session.query(DBScore) \ + .filter(DBScore.total_score > total_score) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) \ + .order_by(func.abs(DBScore.total_score - total_score)) \ + .first() + + if not closest_score: + return 1 + + # Fetch score rank for closest score + return fetch_score_index_by_id( + closest_score.id, + beatmap_id, + mode + ) + 1 + +@session_wrapper def fetch_score_above( beatmap_id: int, mode: int, - total_score: int + total_score: int, + session: Session | None = None ) -> Optional[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .options(selectinload(DBScore.user)) \ - .filter(DBScore.beatmap_id == beatmap_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.total_score > total_score) \ - .filter(DBScore.status == 3) \ - .order_by(DBScore.total_score.asc()) \ - .first() - + return session.query(DBScore) \ + .options(selectinload(DBScore.user)) \ + .filter(DBScore.beatmap_id == beatmap_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.total_score > total_score) \ + .filter(DBScore.status == 3) \ + .order_by(DBScore.total_score.asc()) \ + .first() + +@session_wrapper def fetch_recent( user_id: int, mode: int, - limit: int = 3 + limit: int = 3, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .order_by(DBScore.id.desc()) \ - .limit(limit) \ - .all() - + return session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .order_by(DBScore.id.desc()) \ + .limit(limit) \ + .all() + +@session_wrapper def fetch_recent_until( user_id: int, mode: int, until: datetime, - min_status: int = 2 + min_status: int = 2, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .filter(DBScore.submitted_at > until) \ - .filter(DBScore.status >= min_status) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .order_by(DBScore.id.desc()) \ - .all() - + return session.query(DBScore) \ + .filter(DBScore.submitted_at > until) \ + .filter(DBScore.status >= min_status) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .order_by(DBScore.id.desc()) \ + .all() + +@session_wrapper def fetch_recent_all( user_id: int, - limit: int = 3 + limit: int = 3, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .order_by(DBScore.id.desc()) \ - .limit(limit) \ - .all() + return session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .order_by(DBScore.id.desc()) \ + .limit(limit) \ + .all() +@session_wrapper def fetch_recent_top_scores( user_id: int, - limit: int = 3 + limit: int = 3, + session: Session | None = None ) -> List[DBScore]: - with app.session.database.managed_session() as session: - return session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.status == 3) \ - .order_by(DBScore.id.desc()) \ - .limit(limit) \ - .all() - + return session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.status == 3) \ + .order_by(DBScore.id.desc()) \ + .limit(limit) \ + .all() + +@session_wrapper def fetch_pp_record( mode: int, - mods: Optional[int] = None + mods: int | None = None, + session: Session | None = None ) -> DBScore: - with app.session.database.managed_session() as session: - if mods == None: - return session.query(DBScore) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 3) \ - .order_by(DBScore.pp.desc()) \ - .first() - + if mods == None: return session.query(DBScore) \ .filter(DBScore.mode == mode) \ .filter(DBScore.status == 3) \ - .filter(DBScore.mods == mods) \ .order_by(DBScore.pp.desc()) \ .first() -def restore_hidden_scores(user_id: int): + return session.query(DBScore) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 3) \ + .filter(DBScore.mods == mods) \ + .order_by(DBScore.pp.desc()) \ + .first() + +@session_wrapper +def restore_hidden_scores(user_id: int, session: Session | None = None): # This will restore all score status attributes app.session.logger.info(f'Restoring scores for user: {user_id}...') - with app.session.database.managed_session() as session: + session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.failtime != None) \ + .filter(DBScore.status == -1) \ + .update({ + 'status': 1 + }) + session.commit() + + all_scores = session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.failtime == None) \ + .filter(DBScore.status == -1) \ + .all() + + # Sort scores by beatmap + beatmaps: Dict[int, List[DBScore]] = {score.beatmap_id: [] for score in all_scores} + + for score in all_scores: + beatmaps[score.beatmap_id].append(score) + + for beatmap, scores in beatmaps.items(): + # Get best score for each beatmap + scores.sort( + key=lambda score: score.pp, + reverse=True + ) + + best_score = scores[0] + session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.failtime != None) \ - .filter(DBScore.status == -1) \ + .filter(DBScore.id == best_score.id) \ .update({ - 'status': 1 + 'status': 3 }) session.commit() - all_scores = session.query(DBScore) \ + # Set other scores with same mods to 'submitted' + session.query(DBScore) \ + .filter(DBScore.beatmap_id == beatmap) \ .filter(DBScore.user_id == user_id) \ - .filter(DBScore.failtime == None) \ + .filter(DBScore.mods == best_score.mods) \ .filter(DBScore.status == -1) \ - .all() - - # Sort scores by beatmap - beatmaps: Dict[int, List[DBScore]] = {score.beatmap_id: [] for score in all_scores} + .update({ + 'status': 2 + }) + session.commit() - for score in all_scores: - beatmaps[score.beatmap_id].append(score) + all_mods = [score.mods for score in scores if score.mods != best_score.mods] - for beatmap, scores in beatmaps.items(): - # Get best score for each beatmap - scores.sort( - key=lambda score: score.pp, - reverse=True - ) + for mods in all_mods: + # Update best score with mods + best_score = session.query(DBScore) \ + .filter(DBScore.beatmap_id == beatmap) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mods == mods) \ + .filter(DBScore.status == -1) \ + .order_by(DBScore.pp.desc()) \ + .first() - best_score = scores[0] + if not best_score: + continue - session.query(DBScore) \ - .filter(DBScore.id == best_score.id) \ - .update({ - 'status': 3 - }) + best_score.status = 4 session.commit() - # Set other scores with same mods to 'submitted' session.query(DBScore) \ - .filter(DBScore.beatmap_id == beatmap) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mods == best_score.mods) \ - .filter(DBScore.status == -1) \ - .update({ - 'status': 2 - }) + .filter(DBScore.beatmap_id == beatmap) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mods == mods) \ + .filter(DBScore.status == -1) \ + .update({ + 'status': 2 + }) session.commit() - all_mods = [score.mods for score in scores if score.mods != best_score.mods] - - for mods in all_mods: - # Update best score with mods - best_score = session.query(DBScore) \ - .filter(DBScore.beatmap_id == beatmap) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mods == mods) \ - .filter(DBScore.status == -1) \ - .order_by(DBScore.pp.desc()) \ - .first() - - if not best_score: - continue - - best_score.status = 4 - session.commit() - - session.query(DBScore) \ - .filter(DBScore.beatmap_id == beatmap) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mods == mods) \ - .filter(DBScore.status == -1) \ - .update({ - 'status': 2 - }) - session.commit() - app.session.logger.info('Scores have been restored!') diff --git a/database/repositories/screenshots.py b/database/repositories/screenshots.py index 97de655..48c0eb9 100644 --- a/database/repositories/screenshots.py +++ b/database/repositories/screenshots.py @@ -1,24 +1,32 @@ +from __future__ import annotations + from app.common.database.objects import DBScreenshot -from typing import Optional +from sqlalchemy.orm import Session -import app +from .wrapper import session_wrapper -def create(user_id: int, hidden: bool) -> DBScreenshot: - with app.session.database.managed_session() as session: - session.add( - ss := DBScreenshot( - user_id, - hidden - ) +@session_wrapper +def create( + user_id: int, + hidden: bool, + session: Session | None = None +) -> DBScreenshot: + session.add( + ss := DBScreenshot( + user_id, + hidden ) - session.commit() - session.refresh(ss) - + ) + session.commit() + session.refresh(ss) return ss -def fetch_by_id(id: int) -> Optional[DBScreenshot]: - with app.session.database.managed_session() as session: - return session.query(DBScreenshot) \ - .filter(DBScreenshot.id == id) \ - .first() +@session_wrapper +def fetch_by_id( + id: int, + session: Session | None = None +) -> DBScreenshot | None: + return session.query(DBScreenshot) \ + .filter(DBScreenshot.id == id) \ + .first() diff --git a/database/repositories/stats.py b/database/repositories/stats.py index f7a795a..e764461 100644 --- a/database/repositories/stats.py +++ b/database/repositories/stats.py @@ -1,209 +1,225 @@ +from __future__ import annotations + from app.common.database.objects import ( DBBeatmap, DBStats, DBScore ) +from sqlalchemy.orm import Session from typing import Optional, List from sqlalchemy import func +from .wrapper import session_wrapper from . import scores import config -import app - -def create(user_id: int, mode: int) -> DBStats: - with app.session.database.managed_session() as session: - session.add( - stats := DBStats( - user_id, - mode - ) - ) - session.commit() - session.refresh(stats) +@session_wrapper +def create( + user_id: int, + mode: int, + session: Session | None = None +) -> DBStats: + session.add( + stats := DBStats( + user_id, + mode + ) + ) + session.commit() + session.refresh(stats) return stats -def update(user_id: int, mode: int, updates: dict) -> int: - with app.session.database.managed_session() as session: - rows = session.query(DBStats) \ - .filter(DBStats.user_id == user_id) \ - .filter(DBStats.mode == mode) \ - .update(updates) - session.commit() - +@session_wrapper +def update( + user_id: int, + mode: int, + updates: dict, + session: Session | None = None +) -> int: + rows = session.query(DBStats) \ + .filter(DBStats.user_id == user_id) \ + .filter(DBStats.mode == mode) \ + .update(updates) + session.commit() return rows -def update_all(user_id: int, updates: dict) -> int: - with app.session.database.managed_session() as session: - rows = session.query(DBStats) \ - .filter(DBStats.user_id == user_id) \ - .update(updates) - session.commit() - +@session_wrapper +def update_all( + user_id: int, + updates: dict, + session: Session | None = None +) -> int: + rows = session.query(DBStats) \ + .filter(DBStats.user_id == user_id) \ + .update(updates) + session.commit() return rows -def delete_all(user_id: int) -> int: - with app.session.database.managed_session() as session: - rows = session.query(DBStats) \ - .filter(DBStats.user_id == user_id) \ - .delete() - session.commit() - +@session_wrapper +def delete_all(user_id: int, session: Session | None = None) -> int: + rows = session.query(DBStats) \ + .filter(DBStats.user_id == user_id) \ + .delete() + session.commit() return rows -def fetch_by_mode(user_id: int, mode: int) -> Optional[DBStats]: - with app.session.database.managed_session() as session: - return session.query(DBStats) \ - .filter(DBStats.user_id == user_id) \ - .filter(DBStats.mode == mode) \ +@session_wrapper +def fetch_by_mode( + user_id: int, + mode: int, + session: Session | None = None +) -> Optional[DBStats]: + return session.query(DBStats) \ + .filter(DBStats.user_id == user_id) \ + .filter(DBStats.mode == mode) \ + .first() + +@session_wrapper +def fetch_all(user_id: int, session: Session | None = None) -> List[DBStats]: + return session.query(DBStats) \ + .filter(DBStats.user_id == user_id) \ + .all() + +@session_wrapper +def restore(user_id: int, session: Session | None = None) -> None: + all_stats = [DBStats(user_id, mode) for mode in range(4)] + + for mode in range(4): + score_count = session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .count() + + fail_times = session.query( + func.sum(DBScore.failtime) + ) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .filter(DBScore.status == 1) \ + .scalar() + + fail_times = (fail_times / 1000) \ + if fail_times else 0 + + map_times = session.query( + DBScore, + func.sum(DBBeatmap.total_length) + ) \ + .join(DBBeatmap) \ + .group_by(DBScore) \ + .filter(DBScore.user_id == 15) \ + .filter(DBScore.mode == 0) \ + .filter(DBScore.status > 1) \ .first() -def fetch_all(user_id: int) -> List[DBStats]: - with app.session.database.managed_session() as session: - return session.query(DBStats) \ - .filter(DBStats.user_id == user_id) \ - .all() - -def restore(user_id: int) -> None: - with app.session.database.managed_session() as session: - all_stats = [DBStats(user_id, mode) for mode in range(4)] - - for mode in range(4): - score_count = session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .count() - - fail_times = session.query( - func.sum(DBScore.failtime) - ) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .filter(DBScore.status == 1) \ - .scalar() - - fail_times = (fail_times / 1000) \ - if fail_times else 0 - - map_times = session.query( - DBScore, - func.sum(DBBeatmap.total_length) - ) \ - .join(DBBeatmap) \ - .group_by(DBScore) \ - .filter(DBScore.user_id == 15) \ - .filter(DBScore.mode == 0) \ - .filter(DBScore.status > 1) \ - .first() - - if map_times: - map_times = map_times[-1] - else: - map_times = 0 + if map_times: + map_times = map_times[-1] + else: + map_times = 0 + + playtime = map_times + fail_times - playtime = map_times + fail_times + combo_score = session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .order_by(DBScore.max_combo.desc()) \ + .first() - combo_score = session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .order_by(DBScore.max_combo.desc()) \ - .first() + if combo_score: + max_combo = combo_score.max_combo + else: + max_combo = 0 - if combo_score: - max_combo = combo_score.max_combo - else: - max_combo = 0 - - total_score = session.query( - func.sum(DBScore.total_score) - ) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.mode == mode) \ - .scalar() - - if total_score is None: - total_score = 0 - - stats: DBStats = all_stats[mode] - stats.playcount = score_count - stats.max_combo = max_combo - stats.tscore = total_score - stats.playtime = playtime - - top_scores = scores.fetch_top_scores( - user_id, - mode, - exclude_approved=( - not config.APPROVED_MAP_REWARDS - ) + total_score = session.query( + func.sum(DBScore.total_score) + ) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.mode == mode) \ + .scalar() + + if total_score is None: + total_score = 0 + + stats: DBStats = all_stats[mode] + stats.playcount = score_count + stats.max_combo = max_combo + stats.tscore = total_score + stats.playtime = playtime + + top_scores = scores.fetch_top_scores( + user_id, + mode, + exclude_approved=( + not config.APPROVED_MAP_REWARDS ) + ) - # Update acc and pp + # Update acc and pp - score_count_best = scores.fetch_count( - user_id, - mode - ) + score_count_best = scores.fetch_count( + user_id, + mode + ) - if score_count_best > 0: - total_acc = 0 - divide_total = 0 + if score_count_best > 0: + total_acc = 0 + divide_total = 0 - for index, s in enumerate(top_scores): - add = 0.95 ** index - total_acc += s.acc * add - divide_total += add + for index, s in enumerate(top_scores): + add = 0.95 ** index + total_acc += s.acc * add + divide_total += add - if divide_total != 0: - stats.acc = total_acc / divide_total - else: - stats.acc = 0.0 + if divide_total != 0: + stats.acc = total_acc / divide_total + else: + stats.acc = 0.0 - weighted_pp = sum(score.pp * 0.95**index for index, score in enumerate(top_scores)) - bonus_pp = 416.6667 * (1 - 0.9994**score_count_best) + weighted_pp = sum(score.pp * 0.95**index for index, score in enumerate(top_scores)) + bonus_pp = 416.6667 * (1 - 0.9994**score_count_best) - stats.pp = weighted_pp + bonus_pp + stats.pp = weighted_pp + bonus_pp - best_scores = session.query(DBScore) \ - .filter(DBScore.user_id == user_id) \ - .filter(DBScore.status == 3) \ - .all() + best_scores = session.query(DBScore) \ + .filter(DBScore.user_id == user_id) \ + .filter(DBScore.status == 3) \ + .all() - for score in best_scores: - stats: DBStats = all_stats[score.mode] + for score in best_scores: + stats: DBStats = all_stats[score.mode] - grade_count = eval(f'stats.{score.grade.lower()}_count') + grade_count = eval(f'stats.{score.grade.lower()}_count') - if not grade_count: - grade_count = 0 + if not grade_count: + grade_count = 0 - if not stats.rscore: - stats.rscore = 0 + if not stats.rscore: + stats.rscore = 0 - if not stats.total_hits: - stats.total_hits = 0 + if not stats.total_hits: + stats.total_hits = 0 - stats.rscore += score.total_score - grade_count += 1 + stats.rscore += score.total_score + grade_count += 1 - if stats.mode == 2: - # ctb - total_hits = score.n50 + score.n100 + score.n300 + score.nMiss + score.nKatu + if stats.mode == 2: + # ctb + total_hits = score.n50 + score.n100 + score.n300 + score.nMiss + score.nKatu - elif stats.mode == 3: - # mania - total_hits = score.n300 + score.n100 + score.n50 + score.nGeki + score.nKatu + score.nMiss + elif stats.mode == 3: + # mania + total_hits = score.n300 + score.n100 + score.n50 + score.nGeki + score.nKatu + score.nMiss - else: - # standard + taiko - total_hits = score.n50 + score.n100 + score.n300 + score.nMiss + else: + # standard + taiko + total_hits = score.n50 + score.n100 + score.n300 + score.nMiss - stats.total_hits += total_hits + stats.total_hits += total_hits - for stats in all_stats: - session.add(stats) + for stats in all_stats: + session.add(stats) - session.commit() + session.commit() diff --git a/database/repositories/usercount.py b/database/repositories/usercount.py index c75bb6e..cb23a03 100644 --- a/database/repositories/usercount.py +++ b/database/repositories/usercount.py @@ -1,41 +1,49 @@ +from __future__ import annotations + from app.common.database.objects import DBUserCount from datetime import datetime, timedelta +from sqlalchemy.orm import Session from typing import List, Optional from sqlalchemy import desc, and_ -import app - -def create(count: int) -> DBUserCount: - with app.session.database.managed_session() as session: - session.add(uc := DBUserCount(count)) - session.commit() +from .wrapper import session_wrapper +@session_wrapper +def create(count: int, session: Session | None = None) -> DBUserCount: + session.add(uc := DBUserCount(count)) + session.commit() return uc -def fetch_range(_until: datetime, _from: datetime) -> List[DBUserCount]: - with app.session.database.managed_session() as session: - return session.query(DBUserCount) \ - .filter(and_( - DBUserCount.time <= _from, - DBUserCount.time >= _until - )) \ - .order_by(desc(DBUserCount.time)) \ - .all() - -def fetch_last() -> Optional[DBUserCount]: - with app.session.database.managed_session() as session: - return session.query(DBUserCount) \ - .order_by(desc(DBUserCount.time)) \ - .first() - -def delete_old(delta: timedelta = timedelta(weeks=5)) -> int: +@session_wrapper +def fetch_range( + _until: datetime, + _from: datetime, + session: Session | None = None +) -> List[DBUserCount]: + return session.query(DBUserCount) \ + .filter(and_( + DBUserCount.time <= _from, + DBUserCount.time >= _until + )) \ + .order_by(desc(DBUserCount.time)) \ + .all() + +@session_wrapper +def fetch_last(session: Session | None = None) -> Optional[DBUserCount]: + return session.query(DBUserCount) \ + .order_by(desc(DBUserCount.time)) \ + .first() + +@session_wrapper +def delete_old( + delta: timedelta = timedelta(weeks=5), + session: Session | None = None +) -> int: """Delete usercount entries that are older than the given delta (default ~1 month)""" - with app.session.database.managed_session() as session: - rows = session.query(DBUserCount) \ - .filter(DBUserCount.time <= (datetime.now() - delta)) \ - .delete() - session.commit() - + rows = session.query(DBUserCount) \ + .filter(DBUserCount.time <= (datetime.now() - delta)) \ + .delete() + session.commit() return rows diff --git a/database/repositories/users.py b/database/repositories/users.py index c1c490a..5d3bfcb 100644 --- a/database/repositories/users.py +++ b/database/repositories/users.py @@ -1,13 +1,16 @@ +from __future__ import annotations + from app.common.database.objects import DBUser, DBStats from datetime import datetime, timedelta from typing import Optional, List -from sqlalchemy.orm import selectinload +from sqlalchemy.orm import selectinload, Session from sqlalchemy import func, or_ -import app +from .wrapper import session_wrapper +@session_wrapper def create( username: str, safe_name: str, @@ -16,122 +19,125 @@ def create( country: str, activated: bool = False, discord_id: Optional[int] = None, - permissions: int = 1 + permissions: int = 1, + session: Session | None = None ) -> Optional[DBUser]: - with app.session.database.managed_session() as session: - session.add( - user := DBUser( - username, - safe_name, - email, - pw_bcrypt, - country, - activated, - discord_id, - permissions - ) + session.add( + user := DBUser( + username, + safe_name, + email, + pw_bcrypt, + country, + activated, + discord_id, + permissions ) - session.commit() - session.refresh(user) - + ) + session.commit() + session.refresh(user) return user -def update(user_id: int, updates: dict) -> int: - with app.session.database.managed_session() as session: - rows = session.query(DBUser) \ - .filter(DBUser.id == user_id) \ - .update(updates) +@session_wrapper +def update( + user_id: int, + updates: dict, + session: Session | None = None +) -> int: + rows = session.query(DBUser) \ + .filter(DBUser.id == user_id) \ + .update(updates) return rows -def fetch_by_name(username: str) -> Optional[DBUser]: - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .filter(DBUser.name == username) \ - .first() +@session_wrapper +def fetch_by_name(username: str, session: Session | None = None) -> Optional[DBUser]: + return session.query(DBUser) \ + .filter(DBUser.name == username) \ + .first() -def fetch_by_name_extended(query: str) -> Optional[DBUser]: +@session_wrapper +def fetch_by_name_extended(query: str, session: Session | None = None) -> Optional[DBUser]: """Used for searching users""" - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .filter(or_( - DBUser.name.ilike(query), - DBUser.name.ilike(f'%{query}%') - )) \ - .order_by(func.length(DBUser.name).asc()) \ - .first() - -def fetch_by_safe_name(username: str) -> Optional[DBUser]: - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .filter(DBUser.safe_name == username) \ - .first() - -def fetch_by_id(id: int) -> Optional[DBUser]: - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .filter(DBUser.id == id) \ - .first() - -def fetch_by_email(email: str) -> Optional[DBUser]: - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .filter(DBUser.email == email) \ - .first() - -def fetch_all(restricted: bool = False) -> List[DBUser]: - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .filter(DBUser.restricted == restricted) \ - .all() - -def fetch_active(delta: timedelta = timedelta(days=30), *preload) -> List[DBUser]: - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .join(DBStats) \ - .options(selectinload(*preload)) \ - .filter(DBUser.restricted == False) \ - .filter(DBStats.playcount > 0) \ - .filter( - # Remove inactive users from query, if they are not in the top 100 - or_( - DBUser.latest_activity >= (datetime.now() - delta), - DBStats.rank >= 100 - ) - ) \ - .all() - -def fetch_by_discord_id(id: int) -> Optional[DBUser]: - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .filter(DBUser.discord_id == id) \ - .first() - -def fetch_count(exclude_restricted=True) -> int: - with app.session.database.managed_session() as session: - query = session.query( - func.count(DBUser.id) - ) - - if exclude_restricted: - query = query.filter(DBUser.restricted == False) - - return query.scalar() - -def fetch_username(user_id: int) -> Optional[str]: - with app.session.database.managed_session() as session: - return session.query(DBUser.name) \ - .filter(DBUser.id == user_id) \ - .scalar() - -def fetch_user_id(username: str) -> Optional[int]: - with app.session.database.managed_session() as session: - return session.query(DBUser.id) \ - .filter(DBUser.name == username) \ - .scalar() - -def fetch_many(user_ids: tuple, *options) -> List[DBUser]: - with app.session.database.managed_session() as session: - return session.query(DBUser) \ - .options(*[selectinload(item) for item in options]) \ - .filter(DBUser.id.in_(user_ids)) \ - .all() + return session.query(DBUser) \ + .filter(or_( + DBUser.name.ilike(query), + DBUser.name.ilike(f'%{query}%') + )) \ + .order_by(func.length(DBUser.name).asc()) \ + .first() + +@session_wrapper +def fetch_by_safe_name(username: str, session: Session | None = None) -> Optional[DBUser]: + return session.query(DBUser) \ + .filter(DBUser.safe_name == username) \ + .first() + +@session_wrapper +def fetch_by_id(id: int, session: Session | None = None) -> Optional[DBUser]: + return session.query(DBUser) \ + .filter(DBUser.id == id) \ + .first() + +@session_wrapper +def fetch_by_email(email: str, session: Session | None = None) -> Optional[DBUser]: + return session.query(DBUser) \ + .filter(DBUser.email == email) \ + .first() + +@session_wrapper +def fetch_all(restricted: bool = False, session: Session | None = None) -> List[DBUser]: + return session.query(DBUser) \ + .filter(DBUser.restricted == restricted) \ + .all() + +@session_wrapper +def fetch_active(delta: timedelta = timedelta(days=30), session: Session | None = None) -> List[DBUser]: + return session.query(DBUser) \ + .join(DBStats) \ + .options(selectinload(DBStats)) \ + .filter(DBUser.restricted == False) \ + .filter(DBStats.playcount > 0) \ + .filter( + # Remove inactive users from query, if they are not in the top 100 + or_( + DBUser.latest_activity >= (datetime.now() - delta), + DBStats.rank >= 100 + ) + ) \ + .all() + +@session_wrapper +def fetch_by_discord_id(id: int, session: Session | None = None) -> Optional[DBUser]: + return session.query(DBUser) \ + .filter(DBUser.discord_id == id) \ + .first() + +@session_wrapper +def fetch_count(exclude_restricted=True, session: Session | None = None) -> int: + query = session.query( + func.count(DBUser.id) + ) + + if exclude_restricted: + query = query.filter(DBUser.restricted == False) + + return query.scalar() + +@session_wrapper +def fetch_username(user_id: int, session: Session | None = None) -> Optional[str]: + return session.query(DBUser.name) \ + .filter(DBUser.id == user_id) \ + .scalar() + +@session_wrapper +def fetch_user_id(username: str, session: Session | None = None) -> Optional[int]: + return session.query(DBUser.id) \ + .filter(DBUser.name == username) \ + .scalar() + +@session_wrapper +def fetch_many(user_ids: tuple, *options, session: Session | None = None) -> List[DBUser]: + return session.query(DBUser) \ + .options(*[selectinload(item) for item in options]) \ + .filter(DBUser.id.in_(user_ids)) \ + .all() diff --git a/database/repositories/verifications.py b/database/repositories/verifications.py index c047ccf..29270d6 100644 --- a/database/repositories/verifications.py +++ b/database/repositories/verifications.py @@ -1,58 +1,69 @@ -from ..objects import DBVerification +from __future__ import annotations + +from app.common.database.objects import DBVerification +from sqlalchemy.orm import Session from typing import Optional, List +from .wrapper import session_wrapper + import random import string -import app - -def create(user_id: int, type: int, token_size: int = 32) -> DBVerification: - with app.session.database.managed_session() as session: - session.add( - v := DBVerification( - ''.join(random.choices( - string.ascii_lowercase + - string.digits, k=token_size - )), - user_id, - type - ) - ) - session.commit() - session.refresh(v) +@session_wrapper +def create( + user_id: int, + type: int, + token_size: int = 32, + session: Session | None = None +) -> DBVerification: + session.add( + v := DBVerification( + ''.join(random.choices( + string.ascii_lowercase + + string.digits, k=token_size + )), + user_id, + type + ) + ) + session.commit() + session.refresh(v) return v -def fetch_by_id(id: int) -> Optional[DBVerification]: - with app.session.database.managed_session() as session: - return session.query(DBVerification) \ - .filter(DBVerification.id == id) \ - .first() +@session_wrapper +def fetch_by_id(id: int, session: Session | None = None) -> Optional[DBVerification]: + return session.query(DBVerification) \ + .filter(DBVerification.id == id) \ + .first() -def fetch_by_token(token: str) -> Optional[DBVerification]: - with app.session.database.managed_session() as session: - return session.query(DBVerification) \ - .filter(DBVerification.token == token) \ - .first() - -def fetch_all(user_id: int) -> List[DBVerification]: - with app.session.database.managed_session() as session: - return session.query(DBVerification) \ - .filter(DBVerification.user_id == user_id) \ - .all() - -def fetch_all_by_type(user_id: int, verification_type: int) -> List[DBVerification]: - with app.session.database.managed_session() as session: - return session.query(DBVerification) \ - .filter(DBVerification.user_id == user_id) \ - .filter(DBVerification.type == verification_type) \ - .all() - -def delete(token: str) -> int: - with app.session.database.managed_session() as session: - rows = session.query(DBVerification) \ - .filter(DBVerification.token == token) \ - .delete() - session.commit() +@session_wrapper +def fetch_by_token(token: str, session: Session | None = None) -> Optional[DBVerification]: + return session.query(DBVerification) \ + .filter(DBVerification.token == token) \ + .first() + +@session_wrapper +def fetch_all(user_id: int, session: Session | None = None) -> List[DBVerification]: + return session.query(DBVerification) \ + .filter(DBVerification.user_id == user_id) \ + .all() +@session_wrapper +def fetch_all_by_type( + user_id: int, + verification_type: int, + session: Session | None = None +) -> List[DBVerification]: + return session.query(DBVerification) \ + .filter(DBVerification.user_id == user_id) \ + .filter(DBVerification.type == verification_type) \ + .all() + +@session_wrapper +def delete(token: str, session: Session | None = None) -> int: + rows = session.query(DBVerification) \ + .filter(DBVerification.token == token) \ + .delete() + session.commit() return rows