Skip to content

Commit

Permalink
fix(serializer): fix RecursionError without changing signature
Browse files Browse the repository at this point in the history
The `inspect` library is used instead to get the `caller` information.
This information gives us a clue, if it is recursion call or not.

Signed-off-by: Sergei Shishov <[email protected]>
  • Loading branch information
sshishov committed Aug 4, 2024
1 parent 9ce4680 commit 89de758
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 22 deletions.
23 changes: 9 additions & 14 deletions drf_rw_serializers/generics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-

from inspect import currentframe

from rest_framework import generics, mixins

from .mixins import (
Expand All @@ -16,6 +18,7 @@ def get_serializer_class(self):
Return the class to use for the serializer.
Defaults to using `self.serializer_class`.
If the previous call was already done then to break the recursion we fallback to parent `get_serializer_class`
If the request method is GET, it tries to use `self.read_serializer_class`.
If the request method is not GET, it tries to use `self.write_serializer_class`.
If the specific serializer class for the request method is not set, it falls back to
Expand All @@ -26,6 +29,8 @@ def get_serializer_class(self):
(Eg. admins get full serialization, others get basic serialization)
"""
if currentframe().f_back.f_code.co_name in {'get_serializer_class', 'get_read_serializer_class', 'get_write_serializer_class'}:
return super().get_serializer_class()
if hasattr(self, "request"):
if self.request.method in ["GET", "HEAD", "OPTIONS", "TRACE"]:
assert (
Expand All @@ -36,8 +41,7 @@ def get_serializer_class(self):
"attribute, or override the `get_read_serializer_class()` or "
"`get_serializer_class()` method." % self.__class__.__name__
)
# `default_to_serializer_class` is used to prevent a `RecursionError`
return self.get_read_serializer_class(default_to_serializer_class=True)
return self.get_read_serializer_class()

if self.request.method in ["POST", "PUT", "PATCH", "DELETE"]:
assert (
Expand All @@ -48,8 +52,7 @@ def get_serializer_class(self):
"attribute, or override the `get_write_serializer_class()` or "
"`get_serializer_class()` method." % self.__class__.__name__
)
# `default_to_serializer_class` is used to prevent a `RecursionError`
return self.get_write_serializer_class(default_to_serializer_class=True)
return self.get_write_serializer_class()

assert (
self.serializer_class is not None
Expand All @@ -70,7 +73,7 @@ def get_read_serializer(self, *args, **kwargs):
kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)

def get_read_serializer_class(self, default_to_serializer_class: bool = False):
def get_read_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.read_serializer_class`.
Expand All @@ -81,11 +84,7 @@ def get_read_serializer_class(self, default_to_serializer_class: bool = False):
(Eg. admins get full serialization, others get basic serialization)
"""
if getattr(self, "read_serializer_class", None) is None:
if default_to_serializer_class:
return self.serializer_class

return self.get_serializer_class()

return self.read_serializer_class

def get_write_serializer(self, *args, **kwargs):
Expand All @@ -97,7 +96,7 @@ def get_write_serializer(self, *args, **kwargs):
kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)

def get_write_serializer_class(self, default_to_serializer_class: bool = False):
def get_write_serializer_class(self):
"""
Return the class to use for the serializer.
Defaults to using `self.write_serializer_class`.
Expand All @@ -108,11 +107,7 @@ def get_write_serializer_class(self, default_to_serializer_class: bool = False):
(Eg. admins can send extra fields, others cannot)
"""
if getattr(self, "write_serializer_class", None) is None:
if default_to_serializer_class:
return self.serializer_class

return self.get_serializer_class()

return self.write_serializer_class


Expand Down
16 changes: 8 additions & 8 deletions tests/test_generics.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,10 @@ class GetSerializerClassView(generics.GenericAPIView):
def get_serializer_class(self):
return OrderedMealDetailsSerializer

def get_read_serializer_class(self, default_to_serializer_class: bool = False):
def get_read_serializer_class(self):
return OrderListSerializer

def get_write_serializer_class(self, default_to_serializer_class: bool = False):
def get_write_serializer_class(self):
return OrderCreateSerializer

self.assertEqual(
Expand Down Expand Up @@ -169,20 +169,20 @@ class SerializerClassView(generics.GenericAPIView):
serializer_class = OrderedMealDetailsSerializer

self.assertEqual(
SerializerClassView().get_read_serializer_class(default_to_serializer_class=True),
SerializerClassView().get_read_serializer_class(),
OrderedMealDetailsSerializer,
)

with mock.patch.object(
SerializerClassView, "get_serializer_class"
) as mock_get_serializer_class:
SerializerClassView().get_read_serializer_class(default_to_serializer_class=False)
SerializerClassView().get_read_serializer_class()

mock_get_serializer_class.assert_called_once()

def test_get_read_serializer_class_override_provided(self):
class GetReadSerializerClassView(generics.GenericAPIView):
def get_read_serializer_class(self, default_to_serializer_class: bool = False):
def get_read_serializer_class(self):
return OrderListSerializer

self.assertEqual(
Expand Down Expand Up @@ -216,20 +216,20 @@ class SerializerClassView(generics.GenericAPIView):
serializer_class = OrderedMealDetailsSerializer

self.assertEqual(
SerializerClassView().get_write_serializer_class(default_to_serializer_class=True),
SerializerClassView().get_write_serializer_class(),
OrderedMealDetailsSerializer,
)

with mock.patch.object(
SerializerClassView, "get_serializer_class"
) as mock_get_serializer_class:
SerializerClassView().get_write_serializer_class(default_to_serializer_class=False)
SerializerClassView().get_write_serializer_class()

mock_get_serializer_class.assert_called_once()

def test_get_write_serializer_class_override_provided(self):
class GetWriteSerializerClassView(generics.GenericAPIView):
def get_write_serializer_class(self, default_to_serializer_class: bool = False):
def get_write_serializer_class(self):
return OrderCreateSerializer

self.assertEqual(
Expand Down

0 comments on commit 89de758

Please sign in to comment.