From 9eb3a9fdee11b73595e65d3a731778a9ec4a18e3 Mon Sep 17 00:00:00 2001 From: Lekuru Date: Mon, 30 Oct 2023 16:57:52 +0100 Subject: [PATCH] Add `search_extended` query --- database/repositories/beatmapsets.py | 93 +++++++++++++++++++++++++++- 1 file changed, 92 insertions(+), 1 deletion(-) diff --git a/database/repositories/beatmapsets.py b/database/repositories/beatmapsets.py index 42c02a7..880cf80 100644 --- a/database/repositories/beatmapsets.py +++ b/database/repositories/beatmapsets.py @@ -1,5 +1,5 @@ -from app.common.constants import DisplayMode +from app.common.constants import DisplayMode, BeatmapSortBy from app.common.database.objects import ( DBBeatmapset, DBBeatmap, @@ -183,3 +183,94 @@ def search_one(query_string: str, offset: int = 0) -> Optional[DBBeatmapset]: .filter(and_(*conditions)) \ .order_by(DBBeatmap.playcount.desc()) \ .first() + +def search_extended( + query_string: Optional[str], + genre: Optional[int], + language: Optional[int], + played: Optional[bool], + user_id: Optional[int], + mode: Optional[int], + status: Optional[int], + sort: BeatmapSortBy, + has_storyboard: bool, + has_video: bool, + offset: int = 0, + limit: int = 50 +) -> List[DBBeatmapset]: + conditions = [] + + if query_string: + stop_words = ['the', 'and', 'of', 'in', 'to', 'for'] + conditions = [] + + keywords = [ + f'%{word}%' for word in query_string.strip() \ + .replace(' - ', ' ') \ + .lower() \ + .split() + if word not in stop_words + ] + + searchable_columns = [ + func.to_tsvector('english', column) + for column in [ + func.lower(DBBeatmapset.title), + func.lower(DBBeatmapset.artist), + func.lower(DBBeatmapset.creator), + func.lower(DBBeatmapset.source), + func.lower(DBBeatmapset.tags), + func.lower(DBBeatmap.version) + ] + ] + + for word in keywords: + conditions.append(or_( + *[ + col.op('@@')(func.plainto_tsquery('english', word)) + for col in searchable_columns + ] + )) + + query = app.session.database.session.query(DBBeatmapset) \ + .join(DBBeatmap) \ + .join(DBRating) \ + .filter(and_(*conditions)) \ + .group_by(DBBeatmapset.id) \ + .order_by({ + BeatmapSortBy.Title: DBBeatmapset.title.asc(), + BeatmapSortBy.Artist: DBBeatmapset.artist.asc(), + BeatmapSortBy.Creator: DBBeatmapset.creator.asc(), + BeatmapSortBy.RankedAsc: DBBeatmapset.approved_at.asc(), + BeatmapSortBy.RankedDesc: DBBeatmapset.approved_at.desc(), + BeatmapSortBy.Difficulty: func.max(DBBeatmap.diff).desc(), + BeatmapSortBy.Rating: func.avg(DBRating.rating).desc(), + BeatmapSortBy.Plays: func.sum(DBBeatmap.playcount).desc(), + }[sort]) + + if genre is not None: + query = query.filter(DBBeatmapset.genre_id == genre) + + if language is not None: + query = query.filter(DBBeatmapset.language_id == language) + + if mode is not None: + query = query.filter(DBBeatmap.mode == mode) + + if status is not None: + query = query.filter(DBBeatmapset.status == status) + + if has_storyboard: + query = query.filter(DBBeatmapset.has_storyboard == True) + + if has_video: + query = query.filter(DBBeatmapset.has_video == True) + + if (played is not None and + user_id is not None): + query = query.join(DBPlay) \ + .filter(DBPlay.user_id == user_id) + + return query.offset(offset) \ + .limit(limit) \ + .all()