Skip to content

Commit

Permalink
rework serialization logic
Browse files Browse the repository at this point in the history
replaces serializer classes for each model with a SerializableModel class
  • Loading branch information
Sheppsu committed Oct 2, 2024
1 parent e463c3f commit d4e0c55
Show file tree
Hide file tree
Showing 6 changed files with 146 additions and 53 deletions.
34 changes: 7 additions & 27 deletions otdb/api/views/mappools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from django.contrib.postgres.search import SearchVector, SearchQuery

from ..serializers import *
from .util import *
from .listing import Listing
from database.models import *
from common.validation import *

import time
Expand All @@ -12,8 +12,7 @@
"get_full_mappool",

"mappools",
"favorite_mappool",
"search_mappools"
"favorite_mappool"
)


Expand Down Expand Up @@ -65,7 +64,7 @@ async def get_full_mappool(user, mappool_id) -> dict | None:
except Mappool.DoesNotExist:
return

data = MappoolSerializer(mappool).serialize(include=include+prefetch)
data = mappool.serialize(includes=include+prefetch)

if user.is_authenticated:
data["is_favorited"] = await mappool.is_favorited(user.id)
Expand All @@ -90,10 +89,9 @@ async def mappools(req, mappool_id=None):

return JsonResponse(
{
"data": MappoolSerializer(
mappool_list,
many=True
).serialize(include=["favorite_count"]),
"data": list((
mappool.serialize(includes=["favorite_count"]) for mappool in mappool_list
)),
"total_pages": total_pages
},
safe=False
Expand Down Expand Up @@ -153,8 +151,7 @@ async def create_mappool(req, data):
mappool_id=data.get("id") or 0
)

serializer = MappoolSerializer(mappool)
return JsonResponse(serializer.serialize(), safe=False)
return JsonResponse(mappool.serialize(), safe=False)


@requires_auth
Expand Down Expand Up @@ -198,20 +195,3 @@ async def favorite_mappool(req, mappool_id, data):
await favorite.adelete()

return HttpResponse(b"", 200)


@require_method("GET")
async def search_mappools(req):
def search(req):
return list(Mappool.objects.annotate(
search=SearchVector(
"name",
"tournament_connections__name_override",
"tournament_connections__tournament__name",
"tournament_connections__tournament__abbreviation",
"tournament_connections__tournament__description"
)
).filter(search=SearchQuery(req.GET.get("q", "")))[:20])

result = await sync_to_async(search)(req)
return JsonResponse(MappoolSerializer(result, many=True).serialize(), safe=False)
17 changes: 8 additions & 9 deletions otdb/api/views/tournaments.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from django.http import Http404

from ..serializers import *
from .util import *
from common.validation import *
from .listing import Listing
from database.models import *

import time

Expand Down Expand Up @@ -36,9 +36,9 @@ async def get_full_tournament(user, id):
except Tournament.DoesNotExist:
return

data = TournamentSerializer(tournament).serialize(
include=["involvements__user", "submitted_by", "mappool_connections__mappool__favorite_count"],
exclude=["mappool_connections__tournament_id"]
data = tournament.serialize(
includes=["involvements__user", "submitted_by", "mappool_connections__mappool__favorite_count"],
excludes=["mappool_connections__tournament_id"]
)

if user.is_authenticated:
Expand All @@ -65,10 +65,9 @@ async def tournaments(req, id=None):
tournament_list, total_pages = await TournamentListing(req).aget()

return JsonResponse({
"data": TournamentSerializer(
tournament_list,
many=True
).serialize(include=["favorite_count"]),
"data": list((
tournament.serialize(includes=["favorite_count"]) for tournament in tournament_list
)),
"total_pages": total_pages
}, safe=False)

Expand Down Expand Up @@ -120,7 +119,7 @@ async def create_tournament(req, data):
data.get("id") or 0
)

return JsonResponse(TournamentSerializer(tournament).serialize(), safe=False)
return JsonResponse(tournament.serialize(), safe=False)


@requires_auth
Expand Down
7 changes: 4 additions & 3 deletions otdb/api/views/users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..serializers import *
from .util import *
from main.models import *
from database.models import *


__all__ = (
Expand Down Expand Up @@ -34,8 +35,8 @@ async def users(req, id):
except OsuUser.DoesNotExist:
return error("Invalid user id", 400)

return JsonResponse(OsuUserSerializer(user).serialize(
include=[
return JsonResponse(user.serialize(
includes=[
"involvements__tournament__favorite_count",
"tournament_favorite_connections__tournament__favorite_count",
"mappool_favorite_connections__mappool__favorite_count",
Expand Down
64 changes: 64 additions & 0 deletions otdb/common/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from django.db import models
from collections import defaultdict
from datetime import datetime


def enum_field(enum, field):
def decorator(cls):
return type(
Expand All @@ -10,3 +15,62 @@ def decorator(cls):
}
)
return decorator


def _separate_field_args(fields, only_include_last=False):
now = []
later = defaultdict(list)
for field in fields:
split = field.split("__", 1)
if len(split) > 1:
later[split[0]].append(split[1])
if not only_include_last or len(split) == 1:
now.append(split[0])
return now, later


class SerializableModel(models.Model):
class Meta:
abstract = True

class Serialization:
FIELDS: list
EXCLUDES: list
TRANSFORM: dict[str, str]

def _transform(self, fields: list, excludes: dict, includes: dict):
field_transforms = getattr(self.Serialization, "TRANSFORM", {})

data = {}
for field in fields:
value = getattr(self, field)
if isinstance(value, SerializableModel):
value = value.serialize(includes.get(field), excludes.get(field))
elif isinstance(value, datetime):
value = value.isoformat()
elif value.__class__.__name__ == "RelatedManager" or value.__class__.__name__ == "ManyRelatedManager":
value = [obj.serialize(includes.get(field), excludes.get(field)) for obj in value.all()]

data[field_transforms.get(field, field)] = value

return data

def serialize(self, includes: list | None = None, excludes: list | None = None):
if includes is None:
includes = []
if excludes is None:
excludes = []

exclude_now, exclude_later = _separate_field_args(excludes, only_include_last=True)
include_now, include_later = _separate_field_args(includes)

fields = list(self.Serialization.FIELDS)
for field in exclude_now:
try:
fields.remove(field)
except ValueError:
pass
for field in include_now:
fields.append(field)

return self._transform(fields, exclude_later, include_later)
60 changes: 48 additions & 12 deletions otdb/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from django.contrib.auth import get_user_model
from django.conf import settings

from common.models import enum_field
from common.models import enum_field, SerializableModel
from common.exceptions import ClientException, ServerException
from common.util import sql_s, unzip, find_invalids

Expand Down Expand Up @@ -94,17 +94,20 @@ class UserRolesField:
pass


class BeatmapsetMetadata(models.Model):
class BeatmapsetMetadata(SerializableModel):
id = models.PositiveIntegerField(primary_key=True)
artist = models.CharField(max_length=256)
title = models.CharField(max_length=256)
creator = models.CharField(max_length=15)

class Serialization:
FIELDS = ["id", "artist", "title", "creator"]

def __str__(self):
return str(self.id)


class BeatmapMetadata(models.Model):
class BeatmapMetadata(SerializableModel):
id = models.PositiveIntegerField(primary_key=True)
difficulty = models.CharField(max_length=256)
ar = models.FloatField()
Expand All @@ -114,26 +117,35 @@ class BeatmapMetadata(models.Model):
length = models.PositiveIntegerField()
bpm = models.FloatField()

class Serialization:
FIELDS = ["id", "difficulty", "ar", "od", "cs", "hp", "length", "bpm"]

def __str__(self):
return str(self.id)


class BeatmapMod(models.Model):
class BeatmapMod(SerializableModel):
acronym = models.CharField(max_length=2)
settings = models.JSONField(default=dict)

class Serialization:
FIELDS = ["id", "acronym", "settings"]

class Meta:
constraints = [
models.UniqueConstraint(fields=["acronym", "settings"], name="beatmapmod_unique_constraint")
]


class MappoolBeatmap(models.Model):
class MappoolBeatmap(SerializableModel):
beatmapset_metadata = models.ForeignKey(BeatmapsetMetadata, models.PROTECT, related_name="mappool_beatmaps")
beatmap_metadata = models.ForeignKey(BeatmapMetadata, models.PROTECT, related_name="mappool_beatmaps")
mods = models.ManyToManyField(BeatmapMod, "related_beatmaps")
star_rating = models.FloatField()

class Serialization:
FIELDS = ["id", "star_rating"]

@staticmethod
async def get_rows_data(beatmap: Beatmap, mods: tuple[str | None, ...]):
mods_flag = 0
Expand Down Expand Up @@ -178,20 +190,26 @@ def __str__(self):
return self.slot


class MappoolBeatmapConnection(models.Model):
class MappoolBeatmapConnection(SerializableModel):
mappool = models.ForeignKey("Mappool", models.CASCADE, related_name="beatmap_connections")
beatmap = models.ForeignKey(MappoolBeatmap, models.CASCADE, related_name="mappool_connections")
slot = models.CharField(max_length=8)

class Serialization:
FIELDS = ["slot"]


class Mappool(models.Model):
class Mappool(SerializableModel):
name = models.CharField(max_length=64)
description = models.CharField(max_length=512, default="")
beatmaps = models.ManyToManyField(MappoolBeatmap, "mappools", through=MappoolBeatmapConnection)
submitted_by = models.ForeignKey(OsuUser, models.PROTECT, related_name="submitted_mappools")
favorites = models.ManyToManyField(OsuUser, through="MappoolFavorite", related_name="mappool_favorites")
avg_star_rating = models.FloatField()

class Serialization:
FIELDS = ["id", "name", "description", "avg_star_rating"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -253,7 +271,7 @@ def __str__(self):
return self.name


class Tournament(models.Model):
class Tournament(SerializableModel):
name = models.CharField(max_length=128, unique=True)
abbreviation = models.CharField(max_length=16, default="")
link = models.CharField(max_length=256, default="")
Expand All @@ -263,6 +281,12 @@ class Tournament(models.Model):
submitted_by = models.ForeignKey(OsuUser, models.PROTECT, related_name="submitted_tournaments")
favorites = models.ManyToManyField(OsuUser, through="TournamentFavorite", related_name="tournament_favorites")

class Serialization:
FIELDS = ["id", "name", "abbreviation", "link", "description"]
TRANSFORM = {
"involvements": "staff"
}

@staticmethod
def _new_tournament(
cls,
Expand Down Expand Up @@ -355,33 +379,45 @@ def __str__(self):
return self.name


class TournamentInvolvement(models.Model):
class TournamentInvolvement(SerializableModel):
tournament = models.ForeignKey(Tournament, on_delete=models.CASCADE, related_name="involvements")
user = models.ForeignKey(OsuUser, on_delete=models.CASCADE, related_name="involvements")
roles = UserRolesField(default=0)

class Serialization:
FIELDS = ["roles"]

class Meta:
constraints = [
models.UniqueConstraint(fields=["tournament", "user"], name="tournamentinvolvement_unique_constraint")
]


class MappoolConnection(models.Model):
class MappoolConnection(SerializableModel):
tournament = models.ForeignKey(Tournament, on_delete=models.CASCADE, related_name="mappool_connections")
mappool = models.ForeignKey(Mappool, on_delete=models.CASCADE, related_name="tournament_connections")
name_override = models.CharField(max_length=64, null=True)

class Serialization:
FIELDS = ["name_override"]

def __str__(self):
return self.name_override if self.name_override is not None else ""


class MappoolFavorite(models.Model):
class MappoolFavorite(SerializableModel):
mappool = models.ForeignKey(Mappool, models.CASCADE, related_name="favorite_connections")
user = models.ForeignKey(OsuUser, models.CASCADE, related_name="mappool_favorite_connections")
timestamp = models.PositiveBigIntegerField()

class Serialization:
FIELDS = ["timestamp"]

class TournamentFavorite(models.Model):

class TournamentFavorite(SerializableModel):
tournament = models.ForeignKey(Tournament, models.CASCADE, related_name="favorite_connections")
user = models.ForeignKey(OsuUser, models.CASCADE, related_name="tournament_favorite_connections")
timestamp = models.PositiveBigIntegerField()

class Serialization:
FIELDS = ["timestamp"]
Loading

0 comments on commit d4e0c55

Please sign in to comment.