diff --git a/config/settings/base.py b/config/settings/base.py index 43ce7b438..82cf36c44 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -7,6 +7,8 @@ import environ from corsheaders.defaults import default_headers as default_cors_headers +from safe_transaction_service import __version__ + from ..gunicorn import ( gunicorn_request_timeout, gunicorn_worker_connections, @@ -99,9 +101,9 @@ "django_extensions", "corsheaders", "rest_framework", - "drf_yasg", "django_s3_storage", "rest_framework.authtoken", + "drf_spectacular", ] LOCAL_APPS = [ "safe_transaction_service.account_abstraction.apps.AccountAbstractionConfig", @@ -322,7 +324,9 @@ "rest_framework.authentication.TokenAuthentication", ), "DEFAULT_VERSIONING_CLASS": "rest_framework.versioning.NamespaceVersioning", + "ALLOWED_VERSIONS": ["v1", "v2"], "EXCEPTION_HANDLER": "safe_transaction_service.history.exceptions.custom_exception_handler", + "DEFAULT_SCHEMA_CLASS": "drf_spectacular.openapi.AutoSchema", } # INDEXER LOG LEVEL @@ -654,13 +658,6 @@ ETHERSCAN_API_KEY = env("ETHERSCAN_API_KEY", default=None) IPFS_GATEWAY = env("IPFS_GATEWAY", default="https://ipfs.io/ipfs/") -SWAGGER_SETTINGS = { - "SECURITY_DEFINITIONS": { - "api_key": {"type": "apiKey", "in": "header", "name": "Authorization"} - }, - "DEFAULT_AUTO_SCHEMA_CLASS": "safe_transaction_service.utils.swagger.CustomSwaggerSchema", -} - # Shell Plus # ------------------------------------------------------------------------------ SHELL_PLUS_PRINT_SQL_TRUNCATE = env.int("SHELL_PLUS_PRINT_SQL_TRUNCATE", default=10_000) @@ -684,3 +681,19 @@ REINDEX_CONTRACTS_METADATA_COUNTDOWN = env.int( "REINDEX_CONTRACTS_METADATA_COUNTDOWN", default=0 ) + +# DRF ESPECTACULAR +SPECTACULAR_SETTINGS = { + "TITLE": "Safe Transaction Service", + "DESCRIPTION": "API to keep track of transactions sent via Safe smart contracts", + "VERSION": __version__, + "SWAGGER_UI_FAVICON_HREF": "static/safe/favicon.png", + "OAS_VERSION": "3.1.0", + "SERVE_INCLUDE_SCHEMA": False, + "SCHEMA_PATH_PREFIX": "/api/v[0-9]", + "DEFAULT_GENERATOR_CLASS": "safe_transaction_service.utils.swagger.IgnoreVersionSchemaGenerator", + "POSTPROCESSING_HOOKS": [ + "drf_spectacular.contrib.djangorestframework_camel_case.camelize_serializer_fields" + ], + "SORT_OPERATION_PARAMETERS": False, +} diff --git a/config/settings/local.py b/config/settings/local.py index 21dcc9902..a5c63f9d5 100644 --- a/config/settings/local.py +++ b/config/settings/local.py @@ -33,7 +33,11 @@ # http://niwinz.github.io/django-redis/latest/#_memcached_exceptions_behavior "IGNORE_EXCEPTIONS": True, }, - } + }, + "local_storage": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "local_mem", + }, } # django-debug-toolbar diff --git a/config/settings/production.py b/config/settings/production.py index 361725b97..0fe1d862c 100644 --- a/config/settings/production.py +++ b/config/settings/production.py @@ -27,7 +27,11 @@ # http://niwinz.github.io/django-redis/latest/#_memcached_exceptions_behavior "IGNORE_EXCEPTIONS": True, }, - } + }, + "local_storage": { + "BACKEND": "django.core.cache.backends.locmem.LocMemCache", + "LOCATION": "local_mem", + }, } # SECURITY diff --git a/config/settings/test.py b/config/settings/test.py index 0c339eb0b..0954ef94a 100644 --- a/config/settings/test.py +++ b/config/settings/test.py @@ -24,6 +24,9 @@ "default": { "BACKEND": "django.core.cache.backends.dummy.DummyCache", }, + "local_storage": { + "BACKEND": "django.core.cache.backends.dummy.DummyCache", + }, } # PASSWORDS diff --git a/config/urls.py b/config/urls.py index 9d80b3a1d..522fe5144 100644 --- a/config/urls.py +++ b/config/urls.py @@ -4,39 +4,35 @@ from django.http import HttpResponse from django.urls import path, re_path from django.views import defaults as default_views +from django.views.decorators.cache import cache_page -from drf_yasg import openapi -from drf_yasg.views import get_schema_view -from rest_framework import permissions - -schema_view = get_schema_view( - openapi.Info( - title="Safe Transaction Service API", - default_version="v1", - description="API to keep track of transactions sent via Safe smart contracts", - license=openapi.License(name="MIT License"), - ), - validators=["flex", "ssv"], - public=True, - permission_classes=[permissions.AllowAny], +from drf_spectacular.views import ( + SpectacularAPIView, + SpectacularRedocView, + SpectacularSwaggerView, ) -schema_cache_timeout = 60 * 5 # 5 minutes - +schema_cache_timeout = 60 * 60 * 24 * 7 # 1 week swagger_urlpatterns = [ path( "", - schema_view.with_ui("swagger", cache_timeout=schema_cache_timeout), + cache_page(schema_cache_timeout, cache="local_storage")( + SpectacularSwaggerView.as_view(url_name="schema-json") + ), name="schema-swagger-ui", ), re_path( - r"^swagger(?P\.json|\.yaml)$", - schema_view.without_ui(cache_timeout=schema_cache_timeout), + r"^schema\/(?:\?format=(?Pjson|yaml))?$", + cache_page(schema_cache_timeout, cache="local_storage")( + SpectacularAPIView().as_view() + ), name="schema-json", ), path( "redoc/", - schema_view.with_ui("redoc", cache_timeout=schema_cache_timeout), + cache_page(schema_cache_timeout, cache="local_storage")( + SpectacularRedocView.as_view(url_name="schema-redoc") + ), name="schema-redoc", ), ] @@ -84,6 +80,7 @@ ), ] + urlpatterns = swagger_urlpatterns + [ path(settings.ADMIN_URL, admin.site.urls), path("api/v1/", include((urlpatterns_v1, "v1"))), diff --git a/requirements.txt b/requirements.txt index 1ad0e9095..1a4e05224 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,7 +20,7 @@ django-timezone-field==7.0 djangorestframework==3.15.2 djangorestframework-camel-case==1.4.2 docutils==0.21.2 -drf-yasg[validation]==1.21.7 +drf-spectacular==0.27.2 firebase-admin==6.5.0 flower==2.0.1 gunicorn[gevent]==22.0.0 diff --git a/safe_transaction_service/account_abstraction/views.py b/safe_transaction_service/account_abstraction/views.py index aec7ea22f..2599f718f 100644 --- a/safe_transaction_service/account_abstraction/views.py +++ b/safe_transaction_service/account_abstraction/views.py @@ -1,5 +1,5 @@ import django_filters -from drf_yasg.utils import swagger_auto_schema +from drf_spectacular.utils import OpenApiResponse, extend_schema from rest_framework import status from rest_framework.filters import OrderingFilter from rest_framework.generics import ListAPIView, ListCreateAPIView, RetrieveAPIView @@ -10,6 +10,7 @@ from .models import SafeOperation, SafeOperationConfirmation, UserOperation +@extend_schema(tags=["4337"]) class SafeOperationView(RetrieveAPIView): """ Returns a SafeOperation given its Safe operation hash @@ -54,6 +55,7 @@ def get_serializer_class(self): elif self.request.method == "POST": return serializers.SafeOperationSerializer + @extend_schema(tags=["4337"]) def get(self, request, address, *args, **kwargs): """ Returns the list of SafeOperations for a given Safe account @@ -69,8 +71,9 @@ def get(self, request, address, *args, **kwargs): ) return super().get(request, address, *args, **kwargs) - @swagger_auto_schema( - request_body=serializers.SafeOperationSerializer, + @extend_schema( + tags=["4337"], + request=serializers.SafeOperationSerializer, responses={201: "Created"}, ) def post(self, request, address, *args, **kwargs): @@ -112,15 +115,26 @@ def get_serializer_class(self): elif self.request.method == "POST": return serializers.SafeOperationConfirmationSerializer - @swagger_auto_schema(responses={400: "Invalid data"}) + @extend_schema( + tags=["4337"], + responses={ + 200: serializers.SafeOperationConfirmationResponseSerializer, + 400: OpenApiResponse(description="Invalid data"), + }, + ) def get(self, request, *args, **kwargs): """ Get the list of confirmations for a multisig transaction """ return super().get(request, *args, **kwargs) - @swagger_auto_schema( - responses={201: "Created", 400: "Malformed data", 422: "Error processing data"} + @extend_schema( + tags=["4337"], + responses={ + 201: OpenApiResponse(description="Created"), + 400: OpenApiResponse(description="Malformed data"), + 422: OpenApiResponse(description="Error processing data"), + }, ) def post(self, request, *args, **kwargs): """ @@ -130,6 +144,7 @@ def post(self, request, *args, **kwargs): return super().post(request, *args, **kwargs) +@extend_schema(tags=["4337"]) class UserOperationView(RetrieveAPIView): """ Returns a UserOperation given its user operation hash @@ -171,6 +186,7 @@ def get_serializer_context(self): context["safe_address"] = self.kwargs["address"] return context + @extend_schema(tags=["4337"]) def get(self, request, address, *args, **kwargs): """ Returns the list of UserOperations for a given Safe account diff --git a/safe_transaction_service/history/serializers.py b/safe_transaction_service/history/serializers.py index 26ae64cc3..dd6e3d9a0 100644 --- a/safe_transaction_service/history/serializers.py +++ b/safe_transaction_service/history/serializers.py @@ -7,7 +7,6 @@ from django.http import Http404 from django.utils import timezone -from drf_yasg.utils import swagger_serializer_method from eth_typing import ChecksumAddress from rest_framework import serializers from rest_framework.exceptions import NotFound, ValidationError @@ -678,9 +677,6 @@ def get_block_number(self, obj: MultisigTransaction) -> Optional[int]: if obj.ethereum_tx_id: return obj.ethereum_tx.block_id - @swagger_serializer_method( - serializer_or_field=SafeMultisigConfirmationResponseSerializer - ) def get_confirmations(self, obj: MultisigTransaction) -> Dict[str, Any]: """ Filters confirmations queryset @@ -1245,3 +1241,9 @@ class SafeDeploymentContractSerializer(serializers.Serializer): class SafeDeploymentSerializer(serializers.Serializer): version = serializers.CharField(max_length=10) # Example 1.3.0 contracts = SafeDeploymentContractSerializer(many=True) + + +class CodeErrorResponse(serializers.Serializer): + code = serializers.IntegerField() + message = serializers.CharField() + arguments = serializers.ListField() diff --git a/safe_transaction_service/history/tests/test_views.py b/safe_transaction_service/history/tests/test_views.py index f36dddc40..65526c6e7 100644 --- a/safe_transaction_service/history/tests/test_views.py +++ b/safe_transaction_service/history/tests/test_views.py @@ -91,7 +91,7 @@ def test_about_view(self): self.assertEqual(response.status_code, status.HTTP_200_OK) def test_swagger_json_schema(self): - url = reverse("schema-json", args=(".json",)) + url = reverse("schema-json") + "?format=json" response = self.client.get(url, format="json") self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/safe_transaction_service/history/views.py b/safe_transaction_service/history/views.py index 05308fd08..c4e62f048 100644 --- a/safe_transaction_service/history/views.py +++ b/safe_transaction_service/history/views.py @@ -7,8 +7,13 @@ from django.views.decorators.cache import cache_page import django_filters -from drf_yasg import openapi -from drf_yasg.utils import swagger_auto_schema +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import ( + OpenApiParameter, + OpenApiResponse, + extend_schema, + extend_schema_view, +) from eth_typing import ChecksumAddress, HexStr from rest_framework import status from rest_framework.filters import OrderingFilter @@ -150,6 +155,7 @@ def get(self, request, format=None): return Response(self._get_info(ethereum_client)) +@extend_schema(responses={200: serializers.IndexingStatusSerializer}) class IndexingView(GenericAPIView): serializer_class = serializers.IndexingStatusSerializer pagination_class = None # Don't show limit/offset in swagger @@ -165,6 +171,7 @@ def get(self, request): return Response(status=status.HTTP_200_OK, data=serializer.data) +@extend_schema(responses={200: serializers.MasterCopyResponseSerializer}) class SingletonsView(ListAPIView): """ Returns a list of Master Copies configured in the service @@ -177,6 +184,26 @@ def get_queryset(self): return SafeMasterCopy.objects.relevant() +@extend_schema( + responses={ + 200: serializers.SafeDeploymentSerializer, + 404: OpenApiResponse(description="Provided version does not exist"), + }, + parameters=[ + OpenApiParameter( + "version", + OpenApiTypes.STR, + default=None, + description="Filter by Safe version", + ), + OpenApiParameter( + "contract", + OpenApiTypes.STR, + default=None, + description="Filter by Safe contract name", + ), + ], +) class SafeDeploymentsView(ListAPIView): """ Returns a list of safe deployments by version @@ -185,28 +212,6 @@ class SafeDeploymentsView(ListAPIView): serializer_class = serializers.SafeDeploymentSerializer pagination_class = None # Don't show limit/offset in swagger - _schema_version_param = openapi.Parameter( - "version", - openapi.IN_QUERY, - type=openapi.TYPE_STRING, - default=None, - description="Filter by Safe version", - ) - _schema_contract_param = openapi.Parameter( - "contract", - openapi.IN_QUERY, - type=openapi.TYPE_STRING, - default=None, - description="Filter by Safe contract name", - ) - - @swagger_auto_schema( - responses={404: "Provided version does not exist"}, - manual_parameters=[ - _schema_version_param, - _schema_contract_param, - ], - ) @method_decorator(cache_page(60)) # 60 seconds def get(self, request): filter_version = self.request.query_params.get("version") @@ -247,6 +252,24 @@ def get(self, request): return Response(status=status.HTTP_200_OK, data=serializer.data) +@extend_schema( + tags=["transactions"], + responses={ + 200: OpenApiResponse( + response=serializers.AllTransactionsSchemaSerializer, + description="A list with every element with the structure of one of these transaction" + "types", + ), + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Checksum address validation failed", + ), + 400: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Ordering field is not valid", + ), + }, +) class AllTransactionsListView(ListAPIView): filter_backends = ( django_filters.rest_framework.DjangoFilterBackend, @@ -261,12 +284,6 @@ class AllTransactionsListView(ListAPIView): serializers.AllTransactionsSchemaSerializer ) # Just for docs, not used - _schema_200_response = openapi.Response( - "A list with every element with the structure of one of these transaction" - "types", - serializers.AllTransactionsSchemaSerializer, - ) - def get_ordering_parameter(self) -> Optional[str]: return self.request.query_params.get(OrderingFilter.ordering_param) @@ -350,12 +367,6 @@ def list(self, request, *args, **kwargs): ) return paginated_response - @swagger_auto_schema( - responses={ - 200: _schema_200_response, - 422: "code = 1: Checksum address validation failed", - }, - ) def get(self, request, *args, **kwargs): """ Returns all the *executed* transactions for a given Safe address. @@ -400,17 +411,21 @@ def get(self, request, *args, **kwargs): return response +@extend_schema( + tags=["transactions"], + responses={ + 200: serializers.SafeModuleTransactionResponseSerializer, + 404: OpenApiResponse(description="ModuleTransaction does not exist"), + 400: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Invalid moduleTransactionId", + ), + }, +) class SafeModuleTransactionView(RetrieveAPIView): serializer_class = serializers.SafeModuleTransactionResponseSerializer pagination_class = None # Don't show limit/offset in swagger - @swagger_auto_schema( - responses={ - 200: serializer_class(), - 404: "ModuleTransaction does not exist", - 400: "Invalid moduleTransactionId", - } - ) @method_decorator(cache_page(60 * 60)) # 1 hour def get(self, request, module_transaction_id: str, *args, **kwargs) -> Response: """ @@ -440,6 +455,16 @@ def get(self, request, module_transaction_id: str, *args, **kwargs) -> Response: return Response(status=status.HTTP_404_NOT_FOUND) +@extend_schema( + tags=["transactions"], + responses={ + 200: serializers.SafeModuleTransactionResponseSerializer, + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Checksum address validation failed", + ), + }, +) class SafeModuleTransactionListView(ListAPIView): filter_backends = ( django_filters.rest_framework.DjangoFilterBackend, @@ -451,15 +476,16 @@ class SafeModuleTransactionListView(ListAPIView): serializer_class = serializers.SafeModuleTransactionResponseSerializer def get_queryset(self): + # Just for swagger doc + if getattr(self, "swagger_fake_view", False): + return ModuleTransaction.objects.none() + return ( ModuleTransaction.objects.filter(safe=self.kwargs["address"]) .select_related("internal_tx__ethereum_tx") .order_by("-created") ) - @swagger_auto_schema( - responses={400: "Invalid data", 422: "Invalid ethereum address"} - ) def get(self, request, address, format=None): """ Returns all the transactions executed from modules given a Safe address @@ -496,7 +522,13 @@ def get_serializer_class(self): elif self.request.method == "POST": return serializers.SafeMultisigConfirmationSerializer - @swagger_auto_schema(responses={400: "Invalid data"}) + @extend_schema( + tags=["transactions"], + responses={ + 200: serializers.SafeMultisigConfirmationResponseSerializer, + 400: OpenApiResponse(description="Invalid data"), + }, + ) def get(self, request, *args, **kwargs): """ Returns the list of confirmations for the multi-signature transaction associated with @@ -504,8 +536,13 @@ def get(self, request, *args, **kwargs): """ return super().get(request, *args, **kwargs) - @swagger_auto_schema( - responses={201: "Created", 400: "Malformed data", 422: "Error processing data"} + @extend_schema( + tags=["transactions"], + responses={ + 201: OpenApiResponse(description="Created"), + 400: OpenApiResponse(description="Malformed data"), + 422: OpenApiResponse(description="Error processing data"), + }, ) def post(self, request, *args, **kwargs): """ @@ -516,6 +553,18 @@ def post(self, request, *args, **kwargs): return super().post(request, *args, **kwargs) +@extend_schema_view( + get=extend_schema(tags=["transactions"]), + delete=extend_schema( + tags=["transactions"], + request=serializers.SafeMultisigTransactionDeleteSerializer, + responses={ + 204: OpenApiResponse(description="Deleted"), + 404: OpenApiResponse(description="Transaction not found"), + 400: OpenApiResponse(description="Error processing data"), + }, + ), +) class SafeMultisigTransactionDetailView(RetrieveAPIView): """ Returns a multi-signature transaction given its Safe transaction hash @@ -532,14 +581,6 @@ def get_queryset(self): .select_related("ethereum_tx__block") ) - @swagger_auto_schema( - request_body=serializers.SafeMultisigTransactionDeleteSerializer(), - responses={ - 204: "Deleted", - 404: "Transaction not found", - 400: "Error processing data", - }, - ) def delete(self, request, safe_tx_hash: HexStr): """ Removes the queued but not executed multi-signature transaction associated with the given Safe tansaction hash. @@ -597,6 +638,9 @@ class SafeMultisigTransactionListView(ListAPIView): pagination_class = pagination.DefaultPagination def get_queryset(self): + if getattr(self, "swagger_fake_view", False): + # Just for openApi doc purposes + return MultisigTransaction.objects.none() return ( MultisigTransaction.objects.filter(safe=self.kwargs["address"]) .with_confirmations_required() @@ -627,8 +671,18 @@ def get_serializer_class(self): elif self.request.method == "POST": return serializers.SafeMultisigTransactionSerializer - @swagger_auto_schema( - responses={400: "Invalid data", 422: "Invalid ethereum address"} + @extend_schema( + tags=["transactions"], + responses={ + 200: OpenApiResponse( + response=serializers.SafeMultisigTransactionResponseSerializer + ), + 400: OpenApiResponse(description="Invalid data"), + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Invalid ethereum address", + ), + }, ) def get(self, request, *args, **kwargs): """ @@ -650,13 +704,21 @@ def get(self, request, *args, **kwargs): response.data["count_unique_nonce"] = self.get_unique_nonce(address) return response - @swagger_auto_schema( + @extend_schema( + tags=["transactions"], + request=serializers.SafeMultisigTransactionSerializer, responses={ - 201: "Created or signature updated", - 400: "Invalid data", - 422: "Invalid ethereum address/User is not an owner/Invalid safeTxHash/" - "Invalid signature/Nonce already executed/Sender is not an owner", - } + 201: OpenApiResponse( + response=serializers.SafeMultisigTransactionSerializer, + description="Created or signature updated", + ), + 400: OpenApiResponse(description="Invalid data"), + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Invalid ethereum address | User is not an owner | Invalid safeTxHash |" + "Invalid signature | Nonce already executed | Sender is not an owner", + ), + }, ) def post(self, request, address, format=None): """ @@ -685,33 +747,29 @@ def post(self, request, address, format=None): return Response(status=status.HTTP_201_CREATED) -def swagger_safe_balance_schema(serializer_class, deprecated: bool = False): - _schema_token_trusted_param = openapi.Parameter( - "trusted", - openapi.IN_QUERY, - type=openapi.TYPE_BOOLEAN, - default=False, - description="If `True` just trusted tokens will be returned", - ) - _schema_token_exclude_spam_param = openapi.Parameter( - "exclude_spam", - openapi.IN_QUERY, - type=openapi.TYPE_BOOLEAN, - default=False, - description="If `True` spam tokens will not be returned", - ) - return swagger_auto_schema( - responses={ - 200: serializer_class(many=True), - 404: "Safe not found", - 422: "Safe address checksum not valid", - }, - manual_parameters=[ - _schema_token_trusted_param, - _schema_token_exclude_spam_param, - ], - deprecated=deprecated, - ) +def swagger_assets_parameters(): + """ + Return the swagger doc of ERC20, ERC721 default filters + Used for documentation purposes + + :return: + """ + return [ + OpenApiParameter( + "trusted", + location="query", + type=OpenApiTypes.BOOL, + default=False, + description="If `True` just trusted tokens will be returned", + ), + OpenApiParameter( + "exclude_spam", + location="query", + type=OpenApiTypes.BOOL, + default=False, + description="If `True` spam tokens will not be returned", + ), + ] class SafeBalanceView(GenericAPIView): @@ -734,7 +792,17 @@ def get_parameters(self) -> Tuple[bool, bool]: def get_result(self, *args, **kwargs): return BalanceServiceProvider().get_balances(*args, **kwargs) - @swagger_safe_balance_schema(serializer_class) + @extend_schema( + parameters=swagger_assets_parameters(), + responses={ + 200: OpenApiResponse( + response=serializers.SafeBalanceResponseSerializer(many=True) + ), + 404: OpenApiResponse(description="Safe not found"), + 422: OpenApiResponse(description="Safe address checksum not valid"), + }, + deprecated=False, + ) def get(self, request, address): """ Get balance for Ether and ERC20 tokens of a given Safe account @@ -820,12 +888,15 @@ def get_queryset(self, transfer_id: str) -> TransferDict: log_index = int(transfer_id[65:]) return self.get_erc20_erc721_transfer(tx_hash, log_index) - @swagger_auto_schema( + @extend_schema( + tags=["transactions"], responses={ - 200: serializers.TransferWithTokenInfoResponseSerializer(), - 404: "Transfer does not exist", - 400: "Invalid transferId", - } + 200: OpenApiResponse( + response=serializers.TransferWithTokenInfoResponseSerializer + ), + 404: OpenApiResponse(description="Transfer does not exist"), + 400: OpenApiResponse(description="Invalid transferId"), + }, ) @method_decorator(cache_page(60 * 60)) # 1 hour def get(self, request, transfer_id: str, *args, **kwargs) -> Response: @@ -870,6 +941,9 @@ def get_transfers(self, address: str): ) def get_queryset(self): + if getattr(self, "swagger_fake_view", False): + # Just for openApi doc purposes + return InternalTx.objects.none() address = self.kwargs["address"] return self.get_transfers(address) @@ -886,11 +960,15 @@ def list(self, request, *args, **kwargs): ) return Response(serializer.data) - @swagger_auto_schema( + @extend_schema( + tags=["transactions"], responses={ 200: serializers.TransferWithTokenInfoResponseSerializer(many=True), - 422: "Safe address checksum not valid", - } + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Safe address checksum not valid", + ), + }, ) def get(self, request, address, format=None): """ @@ -911,11 +989,16 @@ def get(self, request, address, format=None): class SafeIncomingTransferListView(SafeTransferListView): - @swagger_auto_schema( + + @extend_schema( + tags=["transactions"], responses={ 200: serializers.TransferWithTokenInfoResponseSerializer(many=True), - 422: "Safe address checksum not valid", - } + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Safe address checksum not valid", + ), + }, ) def get(self, *args, **kwargs): """ @@ -944,12 +1027,12 @@ class SafeCreationView(GenericAPIView): serializer_class = serializers.SafeCreationInfoResponseSerializer pagination_class = None # Don't show limit/offset in swagger - @swagger_auto_schema( + @extend_schema( responses={ 200: serializer_class(), - 404: "Safe creation not found", - 422: "Owner address checksum not valid", - 503: "Problem connecting to Ethereum network", + 404: OpenApiResponse(description="Safe creation not found"), + 422: OpenApiResponse(description="Owner address checksum not valid"), + 503: OpenApiResponse(description="Problem connecting to Ethereum network"), } ) @method_decorator(cache_page(60 * 60)) # 1 hour @@ -984,11 +1067,13 @@ class SafeInfoView(GenericAPIView): serializer_class = serializers.SafeInfoResponseSerializer pagination_class = None # Don't show limit/offset in swagger - @swagger_auto_schema( + @extend_schema( responses={ 200: serializer_class(), - 404: "Safe not found", - 422: "code = 1: Checksum address validation failed\ncode = 50: Cannot get Safe info", + 404: OpenApiResponse(description="Safe not found"), + 422: OpenApiResponse( + description="code = 1: Checksum address validation failed\ncode = 50: Cannot get Safe info" + ), } ) def get(self, request, address, *args, **kwargs): @@ -1028,10 +1113,13 @@ class ModulesView(GenericAPIView): serializer_class = serializers.ModulesResponseSerializer pagination_class = None # Don't show limit/offset in swagger - @swagger_auto_schema( + @extend_schema( responses={ 200: serializers.ModulesResponseSerializer(), - 422: "Module address checksum not valid", + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Module address checksum not valid", + ), } ) @method_decorator(cache_page(15)) # 15 seconds @@ -1059,10 +1147,13 @@ class OwnersView(GenericAPIView): serializer_class = serializers.OwnerResponseSerializer pagination_class = None # Don't show limit/offset in swagger - @swagger_auto_schema( + @extend_schema( responses={ 200: serializers.OwnerResponseSerializer(), - 422: "Owner address checksum not valid", + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Owner address checksum not valid", + ), } ) @method_decorator(cache_page(15)) # 15 seconds @@ -1089,11 +1180,17 @@ def get(self, request, address, *args, **kwargs): class DataDecoderView(GenericAPIView): serializer_class = serializers.DataDecoderSerializer - @swagger_auto_schema( + @extend_schema( responses={ - 200: "Decoded data", - 404: "Cannot find function selector to decode data", - 422: "Invalid data", + 200: OpenApiResponse( + description="Decoded data", response=serializers.DataDecoderSerializer + ), + 404: OpenApiResponse( + description="Cannot find function selector to decode data" + ), + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, description="Invalid data" + ), } ) def post(self, request, format=None): @@ -1133,13 +1230,16 @@ def get_serializer_context(self): context["safe_address"] = self.kwargs["address"] return context - @swagger_auto_schema( + @extend_schema( + tags=["transactions"], responses={ 200: response_serializer, - 400: "Data not valid", - 404: "Safe not found", - 422: "Tx not valid", - } + 400: OpenApiResponse(description="Data not valid"), + 404: OpenApiResponse(description="Safe not found"), + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, description="Tx not valid" + ), + }, ) def post(self, request, address, *args, **kwargs): """ @@ -1212,15 +1312,21 @@ def get_serializer_class(self): elif self.request.method == "POST": return serializers.DelegateSerializer - @swagger_auto_schema(deprecated=True, responses={400: "Invalid data"}) + @extend_schema( + deprecated=True, responses={400: OpenApiResponse(description="Invalid data")} + ) def get(self, request, **kwargs): """ Returns a list with all the delegates """ return super().get(request, **kwargs) - @swagger_auto_schema( - deprecated=True, responses={202: "Accepted", 400: "Malformed data"} + @extend_schema( + deprecated=True, + responses={ + 202: OpenApiResponse(description="Accepted"), + 400: OpenApiResponse(description="Malformed data"), + }, ) def post(self, request, **kwargs): """ @@ -1247,14 +1353,17 @@ class DelegateDeleteView(GenericAPIView): serializer_class = serializers.DelegateDeleteSerializer - @swagger_auto_schema( + @extend_schema( deprecated=True, - request_body=serializer_class(), + request=serializer_class(), responses={ - 204: "Deleted", - 400: "Malformed data", - 404: "Delegate not found", - 422: "Invalid Ethereum address/Error processing data", + 204: OpenApiResponse(description="Deleted"), + 400: OpenApiResponse(description="Malformed data"), + 404: OpenApiResponse(description="Delegate not found"), + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Invalid Ethereum address/Error processing data", + ), }, ) def delete(self, request, delegate_address, *args, **kwargs): @@ -1302,14 +1411,18 @@ def get_object(self): delegate=self.kwargs["delegate_address"], ) - @swagger_auto_schema( + @extend_schema( + tags=["delegates"], deprecated=True, - request_body=serializer_class(), + request=serializer_class(), responses={ - 204: "Deleted", - 400: "Malformed data", - 404: "Delegate not found", - 422: "Invalid Ethereum address/Error processing data", + 204: OpenApiResponse(description="Deleted"), + 400: OpenApiResponse(description="Malformed data"), + 404: OpenApiResponse(description="Delegate not found"), + 422: OpenApiResponse( + response=serializers.CodeErrorResponse, + description="Invalid Ethereum address | Error processing data", + ), }, ) def delete(self, request, address, delegate_address, *args, **kwargs): diff --git a/safe_transaction_service/history/views_v2.py b/safe_transaction_service/history/views_v2.py index a7c74c1a9..2fabd7f35 100644 --- a/safe_transaction_service/history/views_v2.py +++ b/safe_transaction_service/history/views_v2.py @@ -4,7 +4,8 @@ from django.db.models import Q import django_filters -from drf_yasg.utils import swagger_auto_schema +from drf_spectacular.types import OpenApiTypes +from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema from rest_framework import status from rest_framework.generics import GenericAPIView, ListCreateAPIView from rest_framework.response import Response @@ -17,15 +18,49 @@ from .services import BalanceServiceProvider from .services.balance_service import Balance from .services.collectibles_service import CollectiblesServiceProvider -from .views import swagger_safe_balance_schema +from .views import swagger_assets_parameters logger = logging.getLogger(__name__) +def swagger_pagination_parameters(): + """ + Pagination parameters are ignored with custom pagination + + :return: swagger pagination parameters + """ + return [ + OpenApiParameter( + "limit", + location="query", + type=OpenApiTypes.INT, + description=pagination.ListPagination.limit_query_description, + ), + OpenApiParameter( + "offset", + location="query", + type=OpenApiTypes.INT, + description=pagination.ListPagination.offset_query_description, + ), + ] + + +@extend_schema( + parameters=swagger_assets_parameters() + swagger_pagination_parameters(), + responses={ + 200: OpenApiResponse( + response=serializers.SafeCollectibleResponseSerializer(many=True) + ), + 404: OpenApiResponse(description="Safe not found"), + 422: OpenApiResponse( + description="Safe address checksum not valid", + response=serializers.CodeErrorResponse, + ), + }, +) class SafeCollectiblesView(GenericAPIView): serializer_class = serializers.SafeCollectibleResponseSerializer - @swagger_safe_balance_schema(serializer_class) def get(self, request, address): """ Get paginated collectibles (ERC721 tokens) and information about them of a given Safe account. @@ -79,14 +114,25 @@ def get_serializer_class(self): elif self.request.method == "POST": return serializers.DelegateSerializerV2 - @swagger_auto_schema(responses={400: "Invalid data"}) + @extend_schema( + responses={ + 200: serializers.SafeDelegateResponseSerializer, + 400: OpenApiResponse(description="Invalid data"), + } + ) def get(self, request, **kwargs): """ Returns a list with all the delegates """ return super().get(request, **kwargs) - @swagger_auto_schema(responses={202: "Accepted", 400: "Malformed data"}) + @extend_schema( + request=serializers.DelegateSerializerV2, + responses={ + 202: OpenApiResponse(description="Accepted"), + 400: OpenApiResponse(description="Malformed data"), + }, + ) def post(self, request, **kwargs): """ Adds a new Safe delegate with a custom label. Calls with same delegate but different label or @@ -128,13 +174,15 @@ def post(self, request, **kwargs): class DelegateDeleteView(GenericAPIView): serializer_class = serializers.DelegateDeleteSerializerV2 - @swagger_auto_schema( - request_body=serializer_class(), + @extend_schema( + request=serializer_class(), responses={ - 204: "Deleted", - 400: "Malformed data", - 404: "Delegate not found", - 422: "Invalid Ethereum address/Error processing data", + 204: OpenApiResponse(description="Deleted"), + 400: OpenApiResponse(description="Malformed data"), + 404: OpenApiResponse(description="Delegate not found"), + 422: OpenApiResponse( + description="Invalid Ethereum address/Error processing data" + ), }, ) def delete(self, request, delegate_address, *args, **kwargs): @@ -191,7 +239,19 @@ def get_parameters(self) -> Tuple[bool, bool]: def get_result(self, *args, **kwargs) -> Tuple[List[Balance], int]: return BalanceServiceProvider().get_balances(*args, **kwargs) - @swagger_safe_balance_schema(serializer_class) + @extend_schema( + parameters=swagger_assets_parameters() + swagger_pagination_parameters(), + responses={ + 200: OpenApiResponse( + response=serializers.SafeCollectibleResponseSerializer(many=True) + ), + 404: OpenApiResponse(description="Safe not found"), + 422: OpenApiResponse( + description="Safe address checksum not valid", + response=serializers.CodeErrorResponse, + ), + }, + ) def get(self, request, address): """ Get paginated balances for Ether and ERC20 tokens. diff --git a/safe_transaction_service/notifications/views.py b/safe_transaction_service/notifications/views.py index e2a4b5862..1e453712c 100644 --- a/safe_transaction_service/notifications/views.py +++ b/safe_transaction_service/notifications/views.py @@ -1,6 +1,6 @@ import logging -from drf_yasg.utils import swagger_auto_schema +from drf_spectacular.utils import OpenApiResponse, extend_schema from rest_framework import status from rest_framework.generics import CreateAPIView, DestroyAPIView from rest_framework.response import Response @@ -29,8 +29,11 @@ class FirebaseDeviceCreateView(CreateAPIView): serializers.FirebaseDeviceSerializerWithOwnersResponseSerializer ) - @swagger_auto_schema( - responses={200: response_serializer_class(), 400: "Invalid data"} + @extend_schema( + responses={ + 200: response_serializer_class(), + 400: OpenApiResponse(description="Invalid data"), + } ) def post(self, request, *args, **kwargs): return super().post(request, *args, **kwargs) diff --git a/safe_transaction_service/safe_messages/views.py b/safe_transaction_service/safe_messages/views.py index adafeb370..b147bc36f 100644 --- a/safe_transaction_service/safe_messages/views.py +++ b/safe_transaction_service/safe_messages/views.py @@ -3,7 +3,7 @@ import django_filters from djangorestframework_camel_case.parser import CamelCaseJSONParser from djangorestframework_camel_case.render import CamelCaseJSONRenderer -from drf_yasg.utils import swagger_auto_schema +from drf_spectacular.utils import OpenApiResponse, extend_schema from rest_framework import status from rest_framework.filters import OrderingFilter from rest_framework.generics import CreateAPIView, ListCreateAPIView, RetrieveAPIView @@ -46,8 +46,9 @@ def get_serializer_context(self): ) return context - @swagger_auto_schema( - responses={201: "Created"}, + @extend_schema( + tags=["messages"], + responses={201: OpenApiResponse(description="Created")}, ) def post(self, request, *args, **kwargs): """ @@ -84,6 +85,10 @@ def get_serializer_class(self): elif self.request.method == "POST": return serializers.SafeMessageSerializer + @extend_schema( + tags=["messages"], + responses={200: serializers.SafeMessageResponseSerializer}, + ) def get(self, request, address, *args, **kwargs): """ Returns the list of messages for a given Safe account @@ -99,9 +104,10 @@ def get(self, request, address, *args, **kwargs): ) return super().get(request, address, *args, **kwargs) - @swagger_auto_schema( - request_body=serializers.SafeMessageSerializer, - responses={201: "Created"}, + @extend_schema( + tags=["messages"], + request=serializers.SafeMessageSerializer, + responses={201: OpenApiResponse(description="Created")}, ) def post(self, request, address, *args, **kwargs): """ diff --git a/safe_transaction_service/tokens/views.py b/safe_transaction_service/tokens/views.py index f11929433..ace3b6ac5 100644 --- a/safe_transaction_service/tokens/views.py +++ b/safe_transaction_service/tokens/views.py @@ -2,15 +2,25 @@ from django.views.decorators.cache import cache_page import django_filters.rest_framework +from drf_spectacular.utils import OpenApiResponse, extend_schema from rest_framework import response, status from rest_framework.filters import OrderingFilter, SearchFilter from rest_framework.generics import ListAPIView, RetrieveAPIView from safe_eth.eth.utils import fast_is_checksum_address +from ..history.serializers import CodeErrorResponse from . import filters, serializers from .models import Token +@extend_schema( + responses={ + 200: OpenApiResponse(response=serializers.TokenInfoResponseSerializer), + 422: OpenApiResponse( + response=CodeErrorResponse, description="Invalid ethereum address" + ), + } +) class TokenView(RetrieveAPIView): serializer_class = serializers.TokenInfoResponseSerializer lookup_field = "address" diff --git a/safe_transaction_service/utils/swagger.py b/safe_transaction_service/utils/swagger.py index 469df2839..70270663b 100644 --- a/safe_transaction_service/utils/swagger.py +++ b/safe_transaction_service/utils/swagger.py @@ -1,47 +1,86 @@ +import os import re -from drf_yasg.inspectors import SwaggerAutoSchema - - -class CustomSwaggerSchema(SwaggerAutoSchema): - VERSION_REGULAR_EXPRESSION = re.compile(r"v[\d]+") - CUSTOM_TAGS = { - "messages": ["messages"], - "owners": ["owners"], - "transaction": ["transactions"], - "transfer": ["transactions"], - "multisig-transaction": ["transactions"], - "user-operation": ["4337"], - "safe-operation": ["4337"], - } - - def get_tags(self, operation_keys=None): - """ - The method `get_tags` defined by default just gets the `operation_keys` (generated from the - url) and return the first element, for example in our case being all the tags `v1`, `v2`, etc. - - We are now defining some logic to generate `tags`: - - If they are explicitly defined in the view, we keep that (`self.overrides`). - - If the `operation_id` contains any of the words defined, we override the tag. - - Otherwise, just iterate the `operation_keys` and return - - :param operation_keys: - :return: - """ - operation_keys = operation_keys or self.operation_keys - - if tags := self.overrides.get("tags"): - return tags - - if len(operation_keys) == 1: - return list(operation_keys) - - operation_id = self.get_operation_id() - for key, tags in self.CUSTOM_TAGS.items(): - if key in operation_id: - return tags[:] - - for operation_key in operation_keys: - if not self.VERSION_REGULAR_EXPRESSION.match(operation_key): - return [operation_key] - return [] # This should never happen +from drf_spectacular.drainage import add_trace_message, error, get_override, warn +from drf_spectacular.generators import SchemaGenerator +from drf_spectacular.openapi import AutoSchema +from drf_spectacular.plumbing import camelize_operation +from drf_spectacular.settings import spectacular_settings + + +class IgnoreVersionSchemaGenerator(SchemaGenerator): + + def parse(self, input_request, public): + """Iterate endpoints generating per method path operations.""" + result = {} + self._initialise_endpoints() + endpoints = self._get_paths_and_endpoints() + + if spectacular_settings.SCHEMA_PATH_PREFIX is None: + # estimate common path prefix if none was given. only use it if we encountered more + # than one view to prevent emission of erroneous and unnecessary fallback names. + non_trivial_prefix = ( + len(set([view.__class__ for _, _, _, view in endpoints])) > 1 + ) + if non_trivial_prefix: + path_prefix = os.path.commonpath([path for path, _, _, _ in endpoints]) + path_prefix = re.escape( + path_prefix + ) # guard for RE special chars in path + else: + path_prefix = "/" + else: + path_prefix = spectacular_settings.SCHEMA_PATH_PREFIX + if not path_prefix.startswith("^"): + path_prefix = ( + "^" + path_prefix + ) # make sure regex only matches from the start + + for path, path_regex, method, view in endpoints: + # emit queued up warnings/error that happened prior to generation (decoration) + for w in get_override(view, "warnings", []): + warn(w) + for e in get_override(view, "errors", []): + error(e) + + view.request = spectacular_settings.GET_MOCK_REQUEST( + method, path, view, input_request + ) + + if not (public or self.has_view_permissions(path, method, view)): + continue + + # Remove versioning api + + assert isinstance(view.schema, AutoSchema), ( + f"Incompatible AutoSchema used on View {view.__class__}. Is DRF's " + f'DEFAULT_SCHEMA_CLASS pointing to "drf_spectacular.openapi.AutoSchema" ' + f"or any other drf-spectacular compatible AutoSchema?" + ) + with add_trace_message(getattr(view, "__class__", view)): + operation = view.schema.get_operation( + path, path_regex, path_prefix, method, self.registry + ) + + # operation was manually removed via @extend_schema + if not operation: + continue + + if spectacular_settings.SCHEMA_PATH_PREFIX_TRIM: + path = re.sub( + pattern=path_prefix, repl="", string=path, flags=re.IGNORECASE + ) + + if spectacular_settings.SCHEMA_PATH_PREFIX_INSERT: + path = spectacular_settings.SCHEMA_PATH_PREFIX_INSERT + path + + if not path.startswith("/"): + path = "/" + path + + if spectacular_settings.CAMELIZE_NAMES: + path, operation = camelize_operation(path, operation) + + result.setdefault(path, {}) + result[path][method.lower()] = operation + + return result