diff --git a/vng_api_common/inspectors/geojson.py b/vng_api_common/inspectors/geojson.py index 7a55fb35..f4c678dd 100644 --- a/vng_api_common/inspectors/geojson.py +++ b/vng_api_common/inspectors/geojson.py @@ -1,34 +1,11 @@ from drf_spectacular.extensions import OpenApiSerializerFieldExtension from drf_spectacular.plumbing import ResolvedComponent -from rest_framework import serializers -from rest_framework_gis.fields import GeometryField from vng_api_common.oas import TYPE_ARRAY, TYPE_NUMBER, TYPE_OBJECT, TYPE_STRING -def has_geo_fields(serializer) -> bool: - """ - Check if any of the serializer fields are a GeometryField. - If the serializer has nested serializers, a depth-first search is done - to check if the nested serializers has `GeometryField`\ s. - """ - for field in serializer.fields.values(): - if isinstance(field, serializers.Serializer): - has_nested_geo_fields = has_geo_fields(field) - if has_nested_geo_fields: - return True - - elif isinstance(field, (serializers.ListSerializer, serializers.ListField)): - field = field.child - - if isinstance(field, GeometryField): - return True - - return False - - class GeometryFieldExtension(OpenApiSerializerFieldExtension): - target_class = GeometryField + target_class = "rest_framework_gis.fields.GeometryField" match_subclasses = True def map_serializer_field(self, auto_schema, direction): diff --git a/vng_api_common/inspectors/utils.py b/vng_api_common/inspectors/utils.py index a6fd0982..e3af6cf9 100644 --- a/vng_api_common/inspectors/utils.py +++ b/vng_api_common/inspectors/utils.py @@ -2,6 +2,7 @@ from django.db import models +from rest_framework import serializers from rest_framework.utils.model_meta import get_field_info @@ -38,3 +39,29 @@ def get_target_field(model: Type[models.Model], field: str) -> Optional[models.F return get_target_field(relation_info.related_model, "__".join(remaining)) return None + + +def has_geo_fields(serializer) -> bool: + """ + Check if any of the serializer fields are a GeometryField. + If the serializer has nested serializers, a depth-first search is done + to check if the nested serializers has `GeometryField`\ s. + """ + try: + from rest_framework_gis.fields import GeometryField + except ImportError: + return False + + for field in serializer.fields.values(): + if isinstance(field, serializers.Serializer): + has_nested_geo_fields = has_geo_fields(field) + if has_nested_geo_fields: + return True + + elif isinstance(field, (serializers.ListSerializer, serializers.ListField)): + field = field.child + + if isinstance(field, GeometryField): + return True + + return False diff --git a/vng_api_common/inspectors/view.py b/vng_api_common/inspectors/view.py index ae6080f4..2dda6817 100644 --- a/vng_api_common/inspectors/view.py +++ b/vng_api_common/inspectors/view.py @@ -10,14 +10,13 @@ from drf_spectacular.utils import OpenApiParameter, OpenApiResponse from rest_framework import exceptions, status, viewsets -from vng_api_common.inspectors.geojson import has_geo_fields - from ..constants import HEADER_AUDIT, HEADER_LOGRECORD_ID, VERSION_HEADER from ..exceptions import Conflict, Gone, PreconditionFailed from ..geo import DEFAULT_CRS, HEADER_ACCEPT, HEADER_CONTENT, GeoMixin from ..permissions import BaseAuthRequired, get_required_scopes from ..serializers import FoutSerializer, ValidatieFoutSerializer from .cache import CACHE_REQUEST_HEADERS, get_cache_headers, has_cache_header +from .utils import has_geo_fields logger = logging.getLogger(__name__)