Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rework serialization logic #29

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 7 additions & 27 deletions otdb/api/views/mappools.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from django.contrib.postgres.search import SearchVector, SearchQuery

from ..serializers import *
from .util import *
from .listing import Listing
from database.models import *
from common.validation import *

import time
Expand All @@ -12,8 +12,7 @@
"get_full_mappool",

"mappools",
"favorite_mappool",
"search_mappools"
"favorite_mappool"
)


Expand Down Expand Up @@ -65,7 +64,7 @@ async def get_full_mappool(user, mappool_id) -> dict | None:
except Mappool.DoesNotExist:
return

data = MappoolSerializer(mappool).serialize(include=include+prefetch)
data = mappool.serialize(includes=include+prefetch)

if user.is_authenticated:
data["is_favorited"] = await mappool.is_favorited(user.id)
Expand All @@ -90,10 +89,9 @@ async def mappools(req, mappool_id=None):

return JsonResponse(
{
"data": MappoolSerializer(
mappool_list,
many=True
).serialize(include=["favorite_count"]),
"data": list((
mappool.serialize(includes=["favorite_count"]) for mappool in mappool_list
)),
"total_pages": total_pages
},
safe=False
Expand Down Expand Up @@ -153,8 +151,7 @@ async def create_mappool(req, data):
mappool_id=data.get("id") or 0
)

serializer = MappoolSerializer(mappool)
return JsonResponse(serializer.serialize(), safe=False)
return JsonResponse(mappool.serialize(), safe=False)


@requires_auth
Expand Down Expand Up @@ -198,20 +195,3 @@ async def favorite_mappool(req, mappool_id, data):
await favorite.adelete()

return HttpResponse(b"", 200)


@require_method("GET")
async def search_mappools(req):
def search(req):
return list(Mappool.objects.annotate(
search=SearchVector(
"name",
"tournament_connections__name_override",
"tournament_connections__tournament__name",
"tournament_connections__tournament__abbreviation",
"tournament_connections__tournament__description"
)
).filter(search=SearchQuery(req.GET.get("q", "")))[:20])

result = await sync_to_async(search)(req)
return JsonResponse(MappoolSerializer(result, many=True).serialize(), safe=False)
17 changes: 8 additions & 9 deletions otdb/api/views/tournaments.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from django.http import Http404

from ..serializers import *
from .util import *
from common.validation import *
from .listing import Listing
from database.models import *

import time

Expand Down Expand Up @@ -36,9 +36,9 @@ async def get_full_tournament(user, id):
except Tournament.DoesNotExist:
return

data = TournamentSerializer(tournament).serialize(
include=["involvements__user", "submitted_by", "mappool_connections__mappool__favorite_count"],
exclude=["mappool_connections__tournament_id"]
data = tournament.serialize(
includes=["involvements__user", "submitted_by", "mappool_connections__mappool__favorite_count"],
excludes=["mappool_connections__tournament_id"]
)

if user.is_authenticated:
Expand All @@ -65,10 +65,9 @@ async def tournaments(req, id=None):
tournament_list, total_pages = await TournamentListing(req).aget()

return JsonResponse({
"data": TournamentSerializer(
tournament_list,
many=True
).serialize(include=["favorite_count"]),
"data": list((
tournament.serialize(includes=["favorite_count"]) for tournament in tournament_list
)),
"total_pages": total_pages
}, safe=False)

Expand Down Expand Up @@ -120,7 +119,7 @@ async def create_tournament(req, data):
data.get("id") or 0
)

return JsonResponse(TournamentSerializer(tournament).serialize(), safe=False)
return JsonResponse(tournament.serialize(), safe=False)


@requires_auth
Expand Down
7 changes: 4 additions & 3 deletions otdb/api/views/users.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from ..serializers import *
from .util import *
from main.models import *
from database.models import *


__all__ = (
Expand Down Expand Up @@ -34,8 +35,8 @@ async def users(req, id):
except OsuUser.DoesNotExist:
return error("Invalid user id", 400)

return JsonResponse(OsuUserSerializer(user).serialize(
include=[
return JsonResponse(user.serialize(
includes=[
"involvements__tournament__favorite_count",
"tournament_favorite_connections__tournament__favorite_count",
"mappool_favorite_connections__mappool__favorite_count",
Expand Down
64 changes: 64 additions & 0 deletions otdb/common/models.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
from django.db import models
from collections import defaultdict
from datetime import datetime


def enum_field(enum, field):
def decorator(cls):
return type(
Expand All @@ -10,3 +15,62 @@ def decorator(cls):
}
)
return decorator


def _separate_field_args(fields, only_include_last=False):
now = []
later = defaultdict(list)
for field in fields:
split = field.split("__", 1)
if len(split) > 1:
later[split[0]].append(split[1])
if not only_include_last or len(split) == 1:
now.append(split[0])
return now, later


class SerializableModel(models.Model):
class Meta:
abstract = True

class Serialization:
FIELDS: list
EXCLUDES: list
TRANSFORM: dict[str, str]

def _transform(self, fields: list, excludes: dict, includes: dict):
field_transforms = getattr(self.Serialization, "TRANSFORM", {})

data = {}
for field in fields:
value = getattr(self, field)
if isinstance(value, SerializableModel):
value = value.serialize(includes.get(field), excludes.get(field))
elif isinstance(value, datetime):
value = value.isoformat()
elif value.__class__.__name__ == "RelatedManager" or value.__class__.__name__ == "ManyRelatedManager":
value = [obj.serialize(includes.get(field), excludes.get(field)) for obj in value.all()]

data[field_transforms.get(field, field)] = value

return data

def serialize(self, includes: list | None = None, excludes: list | None = None):
if includes is None:
includes = []
if excludes is None:
excludes = []

exclude_now, exclude_later = _separate_field_args(excludes, only_include_last=True)
include_now, include_later = _separate_field_args(includes)

fields = list(self.Serialization.FIELDS)
for field in exclude_now:
try:
fields.remove(field)
except ValueError:
pass
for field in include_now:
fields.append(field)

return self._transform(fields, exclude_later, include_later)
60 changes: 48 additions & 12 deletions otdb/database/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from django.contrib.auth import get_user_model
from django.conf import settings

from common.models import enum_field
from common.models import enum_field, SerializableModel
from common.exceptions import ClientException, ServerException
from common.util import sql_s, unzip, find_invalids

Expand Down Expand Up @@ -94,17 +94,20 @@ class UserRolesField:
pass


class BeatmapsetMetadata(models.Model):
class BeatmapsetMetadata(SerializableModel):
id = models.PositiveIntegerField(primary_key=True)
artist = models.CharField(max_length=256)
title = models.CharField(max_length=256)
creator = models.CharField(max_length=15)

class Serialization:
FIELDS = ["id", "artist", "title", "creator"]

def __str__(self):
return str(self.id)


class BeatmapMetadata(models.Model):
class BeatmapMetadata(SerializableModel):
id = models.PositiveIntegerField(primary_key=True)
difficulty = models.CharField(max_length=256)
ar = models.FloatField()
Expand All @@ -114,26 +117,35 @@ class BeatmapMetadata(models.Model):
length = models.PositiveIntegerField()
bpm = models.FloatField()

class Serialization:
FIELDS = ["id", "difficulty", "ar", "od", "cs", "hp", "length", "bpm"]

def __str__(self):
return str(self.id)


class BeatmapMod(models.Model):
class BeatmapMod(SerializableModel):
acronym = models.CharField(max_length=2)
settings = models.JSONField(default=dict)

class Serialization:
FIELDS = ["id", "acronym", "settings"]

class Meta:
constraints = [
models.UniqueConstraint(fields=["acronym", "settings"], name="beatmapmod_unique_constraint")
]


class MappoolBeatmap(models.Model):
class MappoolBeatmap(SerializableModel):
beatmapset_metadata = models.ForeignKey(BeatmapsetMetadata, models.PROTECT, related_name="mappool_beatmaps")
beatmap_metadata = models.ForeignKey(BeatmapMetadata, models.PROTECT, related_name="mappool_beatmaps")
mods = models.ManyToManyField(BeatmapMod, "related_beatmaps")
star_rating = models.FloatField()

class Serialization:
FIELDS = ["id", "star_rating"]

@staticmethod
async def get_rows_data(beatmap: Beatmap, mods: tuple[str | None, ...]):
mods_flag = 0
Expand Down Expand Up @@ -178,20 +190,26 @@ def __str__(self):
return self.slot


class MappoolBeatmapConnection(models.Model):
class MappoolBeatmapConnection(SerializableModel):
mappool = models.ForeignKey("Mappool", models.CASCADE, related_name="beatmap_connections")
beatmap = models.ForeignKey(MappoolBeatmap, models.CASCADE, related_name="mappool_connections")
slot = models.CharField(max_length=8)

class Serialization:
FIELDS = ["slot"]


class Mappool(models.Model):
class Mappool(SerializableModel):
name = models.CharField(max_length=64)
description = models.CharField(max_length=512, default="")
beatmaps = models.ManyToManyField(MappoolBeatmap, "mappools", through=MappoolBeatmapConnection)
submitted_by = models.ForeignKey(OsuUser, models.PROTECT, related_name="submitted_mappools")
favorites = models.ManyToManyField(OsuUser, through="MappoolFavorite", related_name="mappool_favorites")
avg_star_rating = models.FloatField()

class Serialization:
FIELDS = ["id", "name", "description", "avg_star_rating"]

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

Expand Down Expand Up @@ -253,7 +271,7 @@ def __str__(self):
return self.name


class Tournament(models.Model):
class Tournament(SerializableModel):
name = models.CharField(max_length=128, unique=True)
abbreviation = models.CharField(max_length=16, default="")
link = models.CharField(max_length=256, default="")
Expand All @@ -263,6 +281,12 @@ class Tournament(models.Model):
submitted_by = models.ForeignKey(OsuUser, models.PROTECT, related_name="submitted_tournaments")
favorites = models.ManyToManyField(OsuUser, through="TournamentFavorite", related_name="tournament_favorites")

class Serialization:
FIELDS = ["id", "name", "abbreviation", "link", "description"]
TRANSFORM = {
"involvements": "staff"
}

@staticmethod
def _new_tournament(
cls,
Expand Down Expand Up @@ -355,33 +379,45 @@ def __str__(self):
return self.name


class TournamentInvolvement(models.Model):
class TournamentInvolvement(SerializableModel):
tournament = models.ForeignKey(Tournament, on_delete=models.CASCADE, related_name="involvements")
user = models.ForeignKey(OsuUser, on_delete=models.CASCADE, related_name="involvements")
roles = UserRolesField(default=0)

class Serialization:
FIELDS = ["roles"]

class Meta:
constraints = [
models.UniqueConstraint(fields=["tournament", "user"], name="tournamentinvolvement_unique_constraint")
]


class MappoolConnection(models.Model):
class MappoolConnection(SerializableModel):
tournament = models.ForeignKey(Tournament, on_delete=models.CASCADE, related_name="mappool_connections")
mappool = models.ForeignKey(Mappool, on_delete=models.CASCADE, related_name="tournament_connections")
name_override = models.CharField(max_length=64, null=True)

class Serialization:
FIELDS = ["name_override"]

def __str__(self):
return self.name_override if self.name_override is not None else ""


class MappoolFavorite(models.Model):
class MappoolFavorite(SerializableModel):
mappool = models.ForeignKey(Mappool, models.CASCADE, related_name="favorite_connections")
user = models.ForeignKey(OsuUser, models.CASCADE, related_name="mappool_favorite_connections")
timestamp = models.PositiveBigIntegerField()

class Serialization:
FIELDS = ["timestamp"]

class TournamentFavorite(models.Model):

class TournamentFavorite(SerializableModel):
tournament = models.ForeignKey(Tournament, models.CASCADE, related_name="favorite_connections")
user = models.ForeignKey(OsuUser, models.CASCADE, related_name="tournament_favorite_connections")
timestamp = models.PositiveBigIntegerField()

class Serialization:
FIELDS = ["timestamp"]
Loading
Loading