diff --git a/drf_rw_serializers/generics.py b/drf_rw_serializers/generics.py index 58f5b9e..c8a5638 100644 --- a/drf_rw_serializers/generics.py +++ b/drf_rw_serializers/generics.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- +from inspect import currentframe + from rest_framework import generics, mixins from .mixins import ( @@ -36,8 +38,7 @@ def get_serializer_class(self): "attribute, or override the `get_read_serializer_class()` or " "`get_serializer_class()` method." % self.__class__.__name__ ) - # `default_to_serializer_class` is used to prevent a `RecursionError` - return self.get_read_serializer_class(default_to_serializer_class=True) + return self.get_read_serializer_class() if self.request.method in ["POST", "PUT", "PATCH", "DELETE"]: assert ( @@ -48,8 +49,7 @@ def get_serializer_class(self): "attribute, or override the `get_write_serializer_class()` or " "`get_serializer_class()` method." % self.__class__.__name__ ) - # `default_to_serializer_class` is used to prevent a `RecursionError` - return self.get_write_serializer_class(default_to_serializer_class=True) + return self.get_write_serializer_class() assert ( self.serializer_class is not None @@ -70,7 +70,7 @@ def get_read_serializer(self, *args, **kwargs): kwargs["context"] = self.get_serializer_context() return serializer_class(*args, **kwargs) - def get_read_serializer_class(self, default_to_serializer_class: bool = False): + def get_read_serializer_class(self): """ Return the class to use for the serializer. Defaults to using `self.read_serializer_class`. @@ -81,11 +81,9 @@ def get_read_serializer_class(self, default_to_serializer_class: bool = False): (Eg. admins get full serialization, others get basic serialization) """ if getattr(self, "read_serializer_class", None) is None: - if default_to_serializer_class: - return self.serializer_class - + if currentframe().f_back.f_code.co_name == 'get_serializer_class': + return super().get_serializer_class() return self.get_serializer_class() - return self.read_serializer_class def get_write_serializer(self, *args, **kwargs): @@ -97,7 +95,7 @@ def get_write_serializer(self, *args, **kwargs): kwargs["context"] = self.get_serializer_context() return serializer_class(*args, **kwargs) - def get_write_serializer_class(self, default_to_serializer_class: bool = False): + def get_write_serializer_class(self): """ Return the class to use for the serializer. Defaults to using `self.write_serializer_class`. @@ -108,11 +106,9 @@ def get_write_serializer_class(self, default_to_serializer_class: bool = False): (Eg. admins can send extra fields, others cannot) """ if getattr(self, "write_serializer_class", None) is None: - if default_to_serializer_class: - return self.serializer_class - + if currentframe().f_back.f_code.co_name == 'get_serializer_class': + return super().get_serializer_class() return self.get_serializer_class() - return self.write_serializer_class diff --git a/tests/test_generics.py b/tests/test_generics.py index e1f2d61..e8a66c6 100644 --- a/tests/test_generics.py +++ b/tests/test_generics.py @@ -128,10 +128,10 @@ class GetSerializerClassView(generics.GenericAPIView): def get_serializer_class(self): return OrderedMealDetailsSerializer - def get_read_serializer_class(self, default_to_serializer_class: bool = False): + def get_read_serializer_class(self): return OrderListSerializer - def get_write_serializer_class(self, default_to_serializer_class: bool = False): + def get_write_serializer_class(self): return OrderCreateSerializer self.assertEqual( @@ -169,20 +169,20 @@ class SerializerClassView(generics.GenericAPIView): serializer_class = OrderedMealDetailsSerializer self.assertEqual( - SerializerClassView().get_read_serializer_class(default_to_serializer_class=True), + SerializerClassView().get_read_serializer_class(), OrderedMealDetailsSerializer, ) with mock.patch.object( SerializerClassView, "get_serializer_class" ) as mock_get_serializer_class: - SerializerClassView().get_read_serializer_class(default_to_serializer_class=False) + SerializerClassView().get_read_serializer_class() mock_get_serializer_class.assert_called_once() def test_get_read_serializer_class_override_provided(self): class GetReadSerializerClassView(generics.GenericAPIView): - def get_read_serializer_class(self, default_to_serializer_class: bool = False): + def get_read_serializer_class(self): return OrderListSerializer self.assertEqual( @@ -216,20 +216,20 @@ class SerializerClassView(generics.GenericAPIView): serializer_class = OrderedMealDetailsSerializer self.assertEqual( - SerializerClassView().get_write_serializer_class(default_to_serializer_class=True), + SerializerClassView().get_write_serializer_class(), OrderedMealDetailsSerializer, ) with mock.patch.object( SerializerClassView, "get_serializer_class" ) as mock_get_serializer_class: - SerializerClassView().get_write_serializer_class(default_to_serializer_class=False) + SerializerClassView().get_write_serializer_class() mock_get_serializer_class.assert_called_once() def test_get_write_serializer_class_override_provided(self): class GetWriteSerializerClassView(generics.GenericAPIView): - def get_write_serializer_class(self, default_to_serializer_class: bool = False): + def get_write_serializer_class(self): return OrderCreateSerializer self.assertEqual(