From d4e0c559b5f697cb7dec57e739bbec5f05a4fccf Mon Sep 17 00:00:00 2001 From: Sheppsu <49356627+Sheppsu@users.noreply.github.com> Date: Tue, 1 Oct 2024 21:26:39 -0400 Subject: [PATCH] rework serialization logic replaces serializer classes for each model with a SerializableModel class --- otdb/api/views/mappools.py | 34 ++++--------------- otdb/api/views/tournaments.py | 17 +++++----- otdb/api/views/users.py | 7 ++-- otdb/common/models.py | 64 +++++++++++++++++++++++++++++++++++ otdb/database/models.py | 60 +++++++++++++++++++++++++------- otdb/main/models.py | 17 ++++++++-- 6 files changed, 146 insertions(+), 53 deletions(-) diff --git a/otdb/api/views/mappools.py b/otdb/api/views/mappools.py index f3b92f2..98907e0 100644 --- a/otdb/api/views/mappools.py +++ b/otdb/api/views/mappools.py @@ -1,8 +1,8 @@ from django.contrib.postgres.search import SearchVector, SearchQuery -from ..serializers import * from .util import * from .listing import Listing +from database.models import * from common.validation import * import time @@ -12,8 +12,7 @@ "get_full_mappool", "mappools", - "favorite_mappool", - "search_mappools" + "favorite_mappool" ) @@ -65,7 +64,7 @@ async def get_full_mappool(user, mappool_id) -> dict | None: except Mappool.DoesNotExist: return - data = MappoolSerializer(mappool).serialize(include=include+prefetch) + data = mappool.serialize(includes=include+prefetch) if user.is_authenticated: data["is_favorited"] = await mappool.is_favorited(user.id) @@ -90,10 +89,9 @@ async def mappools(req, mappool_id=None): return JsonResponse( { - "data": MappoolSerializer( - mappool_list, - many=True - ).serialize(include=["favorite_count"]), + "data": list(( + mappool.serialize(includes=["favorite_count"]) for mappool in mappool_list + )), "total_pages": total_pages }, safe=False @@ -153,8 +151,7 @@ async def create_mappool(req, data): mappool_id=data.get("id") or 0 ) - serializer = MappoolSerializer(mappool) - return JsonResponse(serializer.serialize(), safe=False) + return JsonResponse(mappool.serialize(), safe=False) @requires_auth @@ -198,20 +195,3 @@ async def favorite_mappool(req, mappool_id, data): await favorite.adelete() return HttpResponse(b"", 200) - - -@require_method("GET") -async def search_mappools(req): - def search(req): - return list(Mappool.objects.annotate( - search=SearchVector( - "name", - "tournament_connections__name_override", - "tournament_connections__tournament__name", - "tournament_connections__tournament__abbreviation", - "tournament_connections__tournament__description" - ) - ).filter(search=SearchQuery(req.GET.get("q", "")))[:20]) - - result = await sync_to_async(search)(req) - return JsonResponse(MappoolSerializer(result, many=True).serialize(), safe=False) diff --git a/otdb/api/views/tournaments.py b/otdb/api/views/tournaments.py index d03980d..8de9911 100644 --- a/otdb/api/views/tournaments.py +++ b/otdb/api/views/tournaments.py @@ -1,9 +1,9 @@ from django.http import Http404 -from ..serializers import * from .util import * from common.validation import * from .listing import Listing +from database.models import * import time @@ -36,9 +36,9 @@ async def get_full_tournament(user, id): except Tournament.DoesNotExist: return - data = TournamentSerializer(tournament).serialize( - include=["involvements__user", "submitted_by", "mappool_connections__mappool__favorite_count"], - exclude=["mappool_connections__tournament_id"] + data = tournament.serialize( + includes=["involvements__user", "submitted_by", "mappool_connections__mappool__favorite_count"], + excludes=["mappool_connections__tournament_id"] ) if user.is_authenticated: @@ -65,10 +65,9 @@ async def tournaments(req, id=None): tournament_list, total_pages = await TournamentListing(req).aget() return JsonResponse({ - "data": TournamentSerializer( - tournament_list, - many=True - ).serialize(include=["favorite_count"]), + "data": list(( + tournament.serialize(includes=["favorite_count"]) for tournament in tournament_list + )), "total_pages": total_pages }, safe=False) @@ -120,7 +119,7 @@ async def create_tournament(req, data): data.get("id") or 0 ) - return JsonResponse(TournamentSerializer(tournament).serialize(), safe=False) + return JsonResponse(tournament.serialize(), safe=False) @requires_auth diff --git a/otdb/api/views/users.py b/otdb/api/views/users.py index c2aa12c..3556023 100644 --- a/otdb/api/views/users.py +++ b/otdb/api/views/users.py @@ -1,5 +1,6 @@ -from ..serializers import * from .util import * +from main.models import * +from database.models import * __all__ = ( @@ -34,8 +35,8 @@ async def users(req, id): except OsuUser.DoesNotExist: return error("Invalid user id", 400) - return JsonResponse(OsuUserSerializer(user).serialize( - include=[ + return JsonResponse(user.serialize( + includes=[ "involvements__tournament__favorite_count", "tournament_favorite_connections__tournament__favorite_count", "mappool_favorite_connections__mappool__favorite_count", diff --git a/otdb/common/models.py b/otdb/common/models.py index a6899d8..41be319 100644 --- a/otdb/common/models.py +++ b/otdb/common/models.py @@ -1,3 +1,8 @@ +from django.db import models +from collections import defaultdict +from datetime import datetime + + def enum_field(enum, field): def decorator(cls): return type( @@ -10,3 +15,62 @@ def decorator(cls): } ) return decorator + + +def _separate_field_args(fields, only_include_last=False): + now = [] + later = defaultdict(list) + for field in fields: + split = field.split("__", 1) + if len(split) > 1: + later[split[0]].append(split[1]) + if not only_include_last or len(split) == 1: + now.append(split[0]) + return now, later + + +class SerializableModel(models.Model): + class Meta: + abstract = True + + class Serialization: + FIELDS: list + EXCLUDES: list + TRANSFORM: dict[str, str] + + def _transform(self, fields: list, excludes: dict, includes: dict): + field_transforms = getattr(self.Serialization, "TRANSFORM", {}) + + data = {} + for field in fields: + value = getattr(self, field) + if isinstance(value, SerializableModel): + value = value.serialize(includes.get(field), excludes.get(field)) + elif isinstance(value, datetime): + value = value.isoformat() + elif value.__class__.__name__ == "RelatedManager" or value.__class__.__name__ == "ManyRelatedManager": + value = [obj.serialize(includes.get(field), excludes.get(field)) for obj in value.all()] + + data[field_transforms.get(field, field)] = value + + return data + + def serialize(self, includes: list | None = None, excludes: list | None = None): + if includes is None: + includes = [] + if excludes is None: + excludes = [] + + exclude_now, exclude_later = _separate_field_args(excludes, only_include_last=True) + include_now, include_later = _separate_field_args(includes) + + fields = list(self.Serialization.FIELDS) + for field in exclude_now: + try: + fields.remove(field) + except ValueError: + pass + for field in include_now: + fields.append(field) + + return self._transform(fields, exclude_later, include_later) diff --git a/otdb/database/models.py b/otdb/database/models.py index 2282c09..d9e2ff4 100644 --- a/otdb/database/models.py +++ b/otdb/database/models.py @@ -2,7 +2,7 @@ from django.contrib.auth import get_user_model from django.conf import settings -from common.models import enum_field +from common.models import enum_field, SerializableModel from common.exceptions import ClientException, ServerException from common.util import sql_s, unzip, find_invalids @@ -94,17 +94,20 @@ class UserRolesField: pass -class BeatmapsetMetadata(models.Model): +class BeatmapsetMetadata(SerializableModel): id = models.PositiveIntegerField(primary_key=True) artist = models.CharField(max_length=256) title = models.CharField(max_length=256) creator = models.CharField(max_length=15) + class Serialization: + FIELDS = ["id", "artist", "title", "creator"] + def __str__(self): return str(self.id) -class BeatmapMetadata(models.Model): +class BeatmapMetadata(SerializableModel): id = models.PositiveIntegerField(primary_key=True) difficulty = models.CharField(max_length=256) ar = models.FloatField() @@ -114,26 +117,35 @@ class BeatmapMetadata(models.Model): length = models.PositiveIntegerField() bpm = models.FloatField() + class Serialization: + FIELDS = ["id", "difficulty", "ar", "od", "cs", "hp", "length", "bpm"] + def __str__(self): return str(self.id) -class BeatmapMod(models.Model): +class BeatmapMod(SerializableModel): acronym = models.CharField(max_length=2) settings = models.JSONField(default=dict) + class Serialization: + FIELDS = ["id", "acronym", "settings"] + class Meta: constraints = [ models.UniqueConstraint(fields=["acronym", "settings"], name="beatmapmod_unique_constraint") ] -class MappoolBeatmap(models.Model): +class MappoolBeatmap(SerializableModel): beatmapset_metadata = models.ForeignKey(BeatmapsetMetadata, models.PROTECT, related_name="mappool_beatmaps") beatmap_metadata = models.ForeignKey(BeatmapMetadata, models.PROTECT, related_name="mappool_beatmaps") mods = models.ManyToManyField(BeatmapMod, "related_beatmaps") star_rating = models.FloatField() + class Serialization: + FIELDS = ["id", "star_rating"] + @staticmethod async def get_rows_data(beatmap: Beatmap, mods: tuple[str | None, ...]): mods_flag = 0 @@ -178,13 +190,16 @@ def __str__(self): return self.slot -class MappoolBeatmapConnection(models.Model): +class MappoolBeatmapConnection(SerializableModel): mappool = models.ForeignKey("Mappool", models.CASCADE, related_name="beatmap_connections") beatmap = models.ForeignKey(MappoolBeatmap, models.CASCADE, related_name="mappool_connections") slot = models.CharField(max_length=8) + class Serialization: + FIELDS = ["slot"] + -class Mappool(models.Model): +class Mappool(SerializableModel): name = models.CharField(max_length=64) description = models.CharField(max_length=512, default="") beatmaps = models.ManyToManyField(MappoolBeatmap, "mappools", through=MappoolBeatmapConnection) @@ -192,6 +207,9 @@ class Mappool(models.Model): favorites = models.ManyToManyField(OsuUser, through="MappoolFavorite", related_name="mappool_favorites") avg_star_rating = models.FloatField() + class Serialization: + FIELDS = ["id", "name", "description", "avg_star_rating"] + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -253,7 +271,7 @@ def __str__(self): return self.name -class Tournament(models.Model): +class Tournament(SerializableModel): name = models.CharField(max_length=128, unique=True) abbreviation = models.CharField(max_length=16, default="") link = models.CharField(max_length=256, default="") @@ -263,6 +281,12 @@ class Tournament(models.Model): submitted_by = models.ForeignKey(OsuUser, models.PROTECT, related_name="submitted_tournaments") favorites = models.ManyToManyField(OsuUser, through="TournamentFavorite", related_name="tournament_favorites") + class Serialization: + FIELDS = ["id", "name", "abbreviation", "link", "description"] + TRANSFORM = { + "involvements": "staff" + } + @staticmethod def _new_tournament( cls, @@ -355,33 +379,45 @@ def __str__(self): return self.name -class TournamentInvolvement(models.Model): +class TournamentInvolvement(SerializableModel): tournament = models.ForeignKey(Tournament, on_delete=models.CASCADE, related_name="involvements") user = models.ForeignKey(OsuUser, on_delete=models.CASCADE, related_name="involvements") roles = UserRolesField(default=0) + class Serialization: + FIELDS = ["roles"] + class Meta: constraints = [ models.UniqueConstraint(fields=["tournament", "user"], name="tournamentinvolvement_unique_constraint") ] -class MappoolConnection(models.Model): +class MappoolConnection(SerializableModel): tournament = models.ForeignKey(Tournament, on_delete=models.CASCADE, related_name="mappool_connections") mappool = models.ForeignKey(Mappool, on_delete=models.CASCADE, related_name="tournament_connections") name_override = models.CharField(max_length=64, null=True) + class Serialization: + FIELDS = ["name_override"] + def __str__(self): return self.name_override if self.name_override is not None else "" -class MappoolFavorite(models.Model): +class MappoolFavorite(SerializableModel): mappool = models.ForeignKey(Mappool, models.CASCADE, related_name="favorite_connections") user = models.ForeignKey(OsuUser, models.CASCADE, related_name="mappool_favorite_connections") timestamp = models.PositiveBigIntegerField() + class Serialization: + FIELDS = ["timestamp"] -class TournamentFavorite(models.Model): + +class TournamentFavorite(SerializableModel): tournament = models.ForeignKey(Tournament, models.CASCADE, related_name="favorite_connections") user = models.ForeignKey(OsuUser, models.CASCADE, related_name="tournament_favorite_connections") timestamp = models.PositiveBigIntegerField() + + class Serialization: + FIELDS = ["timestamp"] diff --git a/otdb/main/models.py b/otdb/main/models.py index 66250c4..1ca6f18 100644 --- a/otdb/main/models.py +++ b/otdb/main/models.py @@ -5,6 +5,8 @@ from datetime import datetime, timezone, time from asgiref.sync import sync_to_async +from common.models import SerializableModel + osu_client: AsynchronousClient = settings.OSU_CLIENT @@ -29,7 +31,7 @@ async def create_user(self, code): return user -class OsuUser(models.Model): +class OsuUser(SerializableModel): is_anonymous = False is_authenticated = True @@ -46,6 +48,14 @@ class OsuUser(models.Model): USERNAME_FIELD = "id" objects = UserManager() + class Serialization: + FIELDS = ["id", "username", "avatar", "cover", "is_admin"] + TRANSFORM = { + "involvements": "staff_roles", + "mappool_favorite_connections": "mappool_favorites", + "tournament_favorite_connections": "tournament_favorites" + } + @classmethod async def from_data(cls, data): try: @@ -66,10 +76,13 @@ def __str__(self): return self.username -class TrafficStatistic(models.Model): +class TrafficStatistic(SerializableModel): timestamp = models.DateTimeField() traffic = models.PositiveBigIntegerField(default=0) + class Serialization: + FIELDS = ["timestamp", "traffic"] + @classmethod def _now(cls): now = datetime.now(tz=timezone.utc)